From 7a82f1faa0f6157e3e3104d04531f62a8b1db90c Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Wed, 17 Sep 2025 12:46:18 -0400 Subject: [PATCH] ZJIT: Const-fold IsMethodCfunc --- test/ruby/test_zjit.rb | 12 ++++++++ vm_insnhelper.c | 12 ++++++++ zjit.c | 5 +++ zjit/bindgen/src/main.rs | 2 ++ zjit/src/codegen.rs | 2 +- zjit/src/cruby_bindings.inc.rs | 9 ++++++ zjit/src/hir.rs | 56 ++++++++++++++++------------------ 7 files changed, 68 insertions(+), 30 deletions(-) diff --git a/test/ruby/test_zjit.rb b/test/ruby/test_zjit.rb index 42d10490c5..e530fb797f 100644 --- a/test/ruby/test_zjit.rb +++ b/test/ruby/test_zjit.rb @@ -851,6 +851,18 @@ class TestZJIT < Test::Unit::TestCase }, insns: [:opt_new] end + def test_opt_new_invalidate_new + assert_compiles '["Foo", "foo"]', %q{ + class Foo; end + def test = Foo.new + test; test + result = [test.class.name] + def Foo.new = "foo" + result << test + result + }, insns: [:opt_new], call_threshold: 2 + end + def test_new_hash_empty assert_compiles '{}', %q{ def test = {} diff --git a/vm_insnhelper.c b/vm_insnhelper.c index 362af31188..8022a29a6e 100644 --- a/vm_insnhelper.c +++ b/vm_insnhelper.c @@ -2353,6 +2353,12 @@ vm_search_method(VALUE cd_owner, struct rb_call_data *cd, VALUE recv) return vm_cc_cme(cc); } +const struct rb_callable_method_entry_struct * +rb_zjit_vm_search_method(VALUE cd_owner, struct rb_call_data *cd, VALUE recv) +{ + return vm_search_method(cd_owner, cd, recv); +} + #if __has_attribute(transparent_union) typedef union { VALUE (*anyargs)(ANYARGS); @@ -2417,6 +2423,12 @@ vm_method_cfunc_is(const rb_iseq_t *iseq, CALL_DATA cd, VALUE recv, cfunc_type f return check_cfunc(cme, func); } +bool +rb_zjit_cme_is_cfunc(const rb_callable_method_entry_t *me, const cfunc_type func) +{ + return check_cfunc(me, func); +} + int rb_vm_method_cfunc_is(const rb_iseq_t *iseq, CALL_DATA cd, VALUE recv, cfunc_type func) { diff --git a/zjit.c b/zjit.c index 6bbe508f24..4b29578b4a 100644 --- a/zjit.c +++ b/zjit.c @@ -170,6 +170,11 @@ rb_zjit_local_id(const rb_iseq_t *iseq, unsigned idx) return ISEQ_BODY(iseq)->local_table[idx]; } +bool rb_zjit_cme_is_cfunc(const rb_callable_method_entry_t *me, const void *func); + +const struct rb_callable_method_entry_struct * +rb_zjit_vm_search_method(VALUE cd_owner, struct rb_call_data *cd, VALUE recv); + // Primitives used by zjit.rb. Don't put other functions below, which wouldn't use them. VALUE rb_zjit_assert_compiles(rb_execution_context_t *ec, VALUE self); VALUE rb_zjit_stats(rb_execution_context_t *ec, VALUE self, VALUE target_key); diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs index c6f02be415..6e9a5a529f 100644 --- a/zjit/bindgen/src/main.rs +++ b/zjit/bindgen/src/main.rs @@ -342,6 +342,8 @@ fn main() { .allowlist_function("rb_get_cfp_ep_level") .allowlist_function("rb_get_cme_def_type") .allowlist_function("rb_zjit_constcache_shareable") + .allowlist_function("rb_zjit_vm_search_method") + .allowlist_function("rb_zjit_cme_is_cfunc") .allowlist_function("rb_get_cme_def_body_attr_id") .allowlist_function("rb_get_symbol_id") .allowlist_function("rb_get_cme_def_body_optimized_type") diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index 0f167ceec3..52d1dd315b 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -389,7 +389,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::FixnumAnd { left, right } => gen_fixnum_and(asm, opnd!(left), opnd!(right)), Insn::FixnumOr { left, right } => gen_fixnum_or(asm, opnd!(left), opnd!(right)), Insn::IsNil { val } => gen_isnil(asm, opnd!(val)), - &Insn::IsMethodCfunc { val, cd, cfunc } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc), + &Insn::IsMethodCfunc { val, cd, cfunc, state: _ } => gen_is_method_cfunc(jit, asm, opnd!(val), cd, cfunc), Insn::Test { val } => gen_test(asm, opnd!(val)), Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index 4bb1c3dffd..dfa1be9b8f 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -937,6 +937,15 @@ unsafe extern "C" { pub fn rb_zjit_defined_ivar(obj: VALUE, id: ID, pushval: VALUE) -> VALUE; pub fn rb_zjit_insn_leaf(insn: ::std::os::raw::c_int, opes: *const VALUE) -> bool; pub fn rb_zjit_local_id(iseq: *const rb_iseq_t, idx: ::std::os::raw::c_uint) -> ID; + pub fn rb_zjit_cme_is_cfunc( + me: *const rb_callable_method_entry_t, + func: *const ::std::os::raw::c_void, + ) -> bool; + pub fn rb_zjit_vm_search_method( + cd_owner: VALUE, + cd: *mut rb_call_data, + recv: VALUE, + ) -> *const rb_callable_method_entry_struct; pub fn rb_iseq_encoded_size(iseq: *const rb_iseq_t) -> ::std::os::raw::c_uint; pub fn rb_iseq_pc_at_idx(iseq: *const rb_iseq_t, insn_idx: u32) -> *mut VALUE; pub fn rb_iseq_opcode_at_pc(iseq: *const rb_iseq_t, pc: *const VALUE) -> ::std::os::raw::c_int; diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 114cfda549..7dcf1c6ba8 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -576,7 +576,7 @@ pub enum Insn { /// Return C `true` if `val` is `Qnil`, else `false`. IsNil { val: InsnId }, /// Return C `true` if `val`'s method on cd resolves to the cfunc. - IsMethodCfunc { val: InsnId, cd: *const rb_call_data, cfunc: *const u8 }, + IsMethodCfunc { val: InsnId, cd: *const rb_call_data, cfunc: *const u8, state: InsnId }, Defined { op_type: usize, obj: VALUE, pushval: VALUE, v: InsnId, state: InsnId }, GetConstantPath { ic: *const iseq_inline_constant_cache, state: InsnId }, @@ -1350,7 +1350,7 @@ impl Function { &ToRegexp { opt, ref values, state } => ToRegexp { opt, values: find_vec!(values), state }, &Test { val } => Test { val: find!(val) }, &IsNil { val } => IsNil { val: find!(val) }, - &IsMethodCfunc { val, cd, cfunc } => IsMethodCfunc { val: find!(val), cd, cfunc }, + &IsMethodCfunc { val, cd, cfunc, state } => IsMethodCfunc { val: find!(val), cd, cfunc, state }, Jump(target) => Jump(find_branch_edge!(target)), &IfTrue { val, ref target } => IfTrue { val: find!(val), target: find_branch_edge!(target) }, &IfFalse { val, ref target } => IfFalse { val: find!(val), target: find_branch_edge!(target) }, @@ -1906,6 +1906,16 @@ impl Function { self.push_insn_id(block, insn_id); } } + Insn::IsMethodCfunc { val, cd, cfunc, state } if self.type_of(val).ruby_object_known() => { + let class = self.type_of(val).ruby_object().unwrap(); + let cme = unsafe { rb_zjit_vm_search_method(self.iseq.into(), cd as *mut rb_call_data, class) }; + let is_expected_cfunc = unsafe { rb_zjit_cme_is_cfunc(cme, cfunc as *const c_void) }; + let method = unsafe { rb_vm_ci_mid((*cd).ci) }; + self.push_insn(block, Insn::PatchPoint { invariant: Invariant::MethodRedefined { klass: class, method, cme }, state }); + let replacement = self.push_insn(block, Insn::Const { val: Const::CBool(is_expected_cfunc) }); + self.insn_types[replacement.0] = self.infer_type(replacement); + self.make_equal_to(insn_id, replacement); + } Insn::ObjectAlloc { val, state } => { let val_type = self.type_of(val); if val_type.is_subtype(types::Class) && val_type.ruby_object_known() { @@ -2295,8 +2305,7 @@ impl Function { | &Insn::Return { val } | &Insn::Test { val } | &Insn::SetLocal { val, .. } - | &Insn::IsNil { val } - | &Insn::IsMethodCfunc { val, .. } => + | &Insn::IsNil { val } => worklist.push_back(val), &Insn::SetGlobal { val, state, .. } | &Insn::Defined { v: val, state, .. } @@ -2308,6 +2317,7 @@ impl Function { | &Insn::GuardBitEquals { val, state, .. } | &Insn::GuardShape { val, state, .. } | &Insn::ToArray { val, state } + | &Insn::IsMethodCfunc { val, state, .. } | &Insn::ToNewArray { val, state } => { worklist.push_back(val); worklist.push_back(state); @@ -3450,7 +3460,8 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { // TODO: Guard on a profiled class and add a patch point for #new redefinition let argc = unsafe { vm_ci_argc((*cd).ci) } as usize; let val = state.stack_topn(argc)?; - let test_id = fun.push_insn(block, Insn::IsMethodCfunc { val, cd, cfunc: rb_class_new_instance_pass_kw as *const u8 }); + let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); + let test_id = fun.push_insn(block, Insn::IsMethodCfunc { val, cd, cfunc: rb_class_new_instance_pass_kw as *const u8, state: exit_id }); // Jump to the fallback block if it's not the expected function. // Skip CheckInterrupts since the #new call will do it very soon anyway. @@ -3463,7 +3474,6 @@ pub fn iseq_to_hir(iseq: *const rb_iseq_t) -> Result { queue.push_back((state.clone(), target, target_idx, local_inval)); // Move on to the fast path - let exit_id = fun.push_insn(block, Insn::Snapshot { state: exit_state }); let insn_id = fun.push_insn(block, Insn::ObjectAlloc { val, state: exit_id }); state.stack_setn(argc, insn_id); state.stack_setn(argc + 1, insn_id); @@ -5400,8 +5410,8 @@ mod tests { bb0(v0:BasicObject): v5:BasicObject = GetConstantPath 0x1000 v6:NilClass = Const Value(nil) - v7:CBool = IsMethodCFunc v5, :new - IfFalse v7, bb1(v0, v6, v5) + v8:CBool = IsMethodCFunc v5, :new + IfFalse v8, bb1(v0, v6, v5) v10:HeapObject = ObjectAlloc v5 v12:BasicObject = SendWithoutBlock v10, :initialize CheckInterrupts @@ -8092,18 +8102,12 @@ mod opt_tests { PatchPoint StableConstantNames(0x1000, C) v34:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v6:NilClass = Const Value(nil) - v7:CBool = IsMethodCFunc v34, :new - IfFalse v7, bb1(v0, v6, v34) - v35:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008) - v12:BasicObject = SendWithoutBlock v35, :initialize + PatchPoint MethodRedefined(C@0x1008, new@0x1010, cme:0x1018) + v37:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008) + v12:BasicObject = SendWithoutBlock v37, :initialize CheckInterrupts - Jump bb2(v0, v35, v12) - bb1(v16:BasicObject, v17:NilClass, v18:Class[VALUE(0x1008)]): - v21:BasicObject = SendWithoutBlock v18, :new - Jump bb2(v16, v21, v17) - bb2(v23:BasicObject, v24:BasicObject, v25:BasicObject): CheckInterrupts - Return v24 + Return v37 "); } @@ -8126,19 +8130,13 @@ mod opt_tests { v36:Class[VALUE(0x1008)] = Const Value(VALUE(0x1008)) v6:NilClass = Const Value(nil) v7:Fixnum[1] = Const Value(1) - v8:CBool = IsMethodCFunc v36, :new - IfFalse v8, bb1(v0, v6, v36, v7) - v37:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008) - PatchPoint MethodRedefined(C@0x1008, initialize@0x1010, cme:0x1018) - v39:BasicObject = SendWithoutBlockDirect v37, :initialize (0x1040), v7 + PatchPoint MethodRedefined(C@0x1008, new@0x1010, cme:0x1018) + v39:HeapObject[class_exact:C] = ObjectAllocClass VALUE(0x1008) + PatchPoint MethodRedefined(C@0x1008, initialize@0x1040, cme:0x1048) + v41:BasicObject = SendWithoutBlockDirect v39, :initialize (0x1070), v7 CheckInterrupts - Jump bb2(v0, v37, v39) - bb1(v17:BasicObject, v18:NilClass, v19:Class[VALUE(0x1008)], v20:Fixnum[1]): - v23:BasicObject = SendWithoutBlock v19, :new, v20 - Jump bb2(v17, v23, v18) - bb2(v25:BasicObject, v26:BasicObject, v27:BasicObject): CheckInterrupts - Return v26 + Return v39 "); }