From 9ad902e55c610e66114f528f77f7895295a242de Mon Sep 17 00:00:00 2001 From: Max Bernstein Date: Mon, 20 Oct 2025 11:45:59 -0400 Subject: [PATCH] ZJIT: Inline String#==, String#=== --- depend | 1 + jit.c | 9 ++ yjit.c | 8 -- yjit/src/cruby_bindings.inc.rs | 2 +- zjit/src/cruby.rs | 1 + zjit/src/cruby_bindings.inc.rs | 1 + zjit/src/cruby_methods.rs | 21 ++++ zjit/src/hir.rs | 191 ++++++++++++++++++++++++++++++++- 8 files changed, 221 insertions(+), 13 deletions(-) diff --git a/depend b/depend index fa61de77a0..5ed27d04e0 100644 --- a/depend +++ b/depend @@ -7393,6 +7393,7 @@ jit.$(OBJEXT): $(top_srcdir)/internal/sanitizers.h jit.$(OBJEXT): $(top_srcdir)/internal/serial.h jit.$(OBJEXT): $(top_srcdir)/internal/set_table.h jit.$(OBJEXT): $(top_srcdir)/internal/static_assert.h +jit.$(OBJEXT): $(top_srcdir)/internal/string.h jit.$(OBJEXT): $(top_srcdir)/internal/variable.h jit.$(OBJEXT): $(top_srcdir)/internal/vm.h jit.$(OBJEXT): $(top_srcdir)/internal/warnings.h diff --git a/jit.c b/jit.c index 2ff38c28e2..3111dcc3e3 100644 --- a/jit.c +++ b/jit.c @@ -15,6 +15,7 @@ #include "internal/gc.h" #include "vm_sync.h" #include "internal/fixnum.h" +#include "internal/string.h" enum jit_bindgen_constants { // Field offsets for the RObject struct @@ -750,3 +751,11 @@ rb_jit_fix_mod_fix(VALUE recv, VALUE obj) { return rb_fix_mod_fix(recv, obj); } + +// YJIT/ZJIT need this function to never allocate and never raise +VALUE +rb_yarv_str_eql_internal(VALUE str1, VALUE str2) +{ + // We wrap this since it's static inline + return rb_str_eql_internal(str1, str2); +} diff --git a/yjit.c b/yjit.c index 807aec9e39..3793b0f1ac 100644 --- a/yjit.c +++ b/yjit.c @@ -283,14 +283,6 @@ rb_yjit_str_simple_append(VALUE str1, VALUE str2) extern VALUE *rb_vm_base_ptr(struct rb_control_frame_struct *cfp); -// YJIT needs this function to never allocate and never raise -VALUE -rb_yarv_str_eql_internal(VALUE str1, VALUE str2) -{ - // We wrap this since it's static inline - return rb_str_eql_internal(str1, 
str2); -} - VALUE rb_str_neq_internal(VALUE str1, VALUE str2) { diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 272586a79f..6542e5ef09 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -1134,7 +1134,6 @@ extern "C" { pub fn rb_yjit_builtin_function(iseq: *const rb_iseq_t) -> *const rb_builtin_function; pub fn rb_yjit_str_simple_append(str1: VALUE, str2: VALUE) -> VALUE; pub fn rb_vm_base_ptr(cfp: *mut rb_control_frame_struct) -> *mut VALUE; - pub fn rb_yarv_str_eql_internal(str1: VALUE, str2: VALUE) -> VALUE; pub fn rb_str_neq_internal(str1: VALUE, str2: VALUE) -> VALUE; pub fn rb_ary_unshift_m(argc: ::std::os::raw::c_int, argv: *mut VALUE, ary: VALUE) -> VALUE; pub fn rb_yjit_rb_ary_subseq_length(ary: VALUE, beg: ::std::os::raw::c_long) -> VALUE; @@ -1274,4 +1273,5 @@ extern "C" { end: *mut ::std::os::raw::c_void, ); pub fn rb_jit_fix_mod_fix(recv: VALUE, obj: VALUE) -> VALUE; + pub fn rb_yarv_str_eql_internal(str1: VALUE, str2: VALUE) -> VALUE; } diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 645891496e..41e0e847aa 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -1342,6 +1342,7 @@ pub(crate) mod ids { name: NULL content: b"" name: respond_to_missing content: b"respond_to_missing?" name: eq content: b"==" + name: string_eq content: b"String#==" name: include_p content: b"include?" 
name: to_ary name: to_s diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index c9e5bc8fd1..d9bd2d33c0 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -1437,4 +1437,5 @@ unsafe extern "C" { start: *mut ::std::os::raw::c_void, end: *mut ::std::os::raw::c_void, ); + pub fn rb_yarv_str_eql_internal(str1: VALUE, str2: VALUE) -> VALUE; } diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index 9d3f5a756b..eabddce739 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -193,6 +193,7 @@ pub fn init() -> Annotations { annotate!(rb_cString, "getbyte", inline_string_getbyte); annotate!(rb_cString, "empty?", types::BoolExact, no_gc, leaf, elidable); annotate!(rb_cString, "<<", inline_string_append); + annotate!(rb_cString, "==", inline_string_eq); annotate!(rb_cModule, "name", types::StringExact.union(types::NilClass), no_gc, leaf, elidable); annotate!(rb_cModule, "===", types::BoolExact, no_gc, leaf); annotate!(rb_cArray, "length", types::Fixnum, no_gc, leaf, elidable); @@ -291,6 +292,26 @@ fn inline_string_append(fun: &mut hir::Function, block: hir::BlockId, recv: hir: } } +fn inline_string_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> { + let &[other] = args else { return None; }; + if fun.likely_a(recv, types::String, state) && fun.likely_a(other, types::String, state) { + let recv = fun.coerce_to(block, recv, types::String, state); + let other = fun.coerce_to(block, other, types::String, state); + let return_type = types::BoolExact; + let elidable = true; + // TODO(max): Make StringEqual its own opcode so that we can later constant-fold StringEqual(a, a) => true + let result = fun.push_insn(block, hir::Insn::CCall { + cfunc: rb_yarv_str_eql_internal as *const u8, + args: vec![recv, other], + name: ID!(string_eq), + return_type, + elidable, + }); + return Some(result); + } + None +} + fn 
inline_integer_succ(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option<hir::InsnId> { if !args.is_empty() { return None; } if fun.likely_a(recv, types::Fixnum, state) { diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 6b873401ca..91484ca970 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -14090,7 +14090,6 @@ mod opt_tests { "); } - // TODO: Should be optimized, but is waiting on String#== inlining #[test] fn test_optimize_string_append_string_subclass() { eval(r#" @@ -14114,9 +14113,11 @@ mod opt_tests { PatchPoint MethodRedefined(String@0x1000, <<@0x1008, cme:0x1010) PatchPoint NoSingletonClass(String@0x1000) v28:StringExact = GuardType v11, StringExact - v29:BasicObject = CCallWithFrame <<@0x1038, v28, v12 + v29:String = GuardType v12, String + v30:StringExact = StringAppend v28, v29 + IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - Return v29 + Return v28 "); } @@ -14142,7 +14143,7 @@ mod opt_tests { bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): PatchPoint MethodRedefined(MyString@0x1000, <<@0x1008, cme:0x1010) PatchPoint NoSingletonClass(MyString@0x1000) - v28:HeapObject[class_exact:MyString] = GuardType v11, HeapObject[class_exact:MyString] + v28:StringSubclass[class_exact:MyString] = GuardType v11, StringSubclass[class_exact:MyString] v29:BasicObject = CCallWithFrame <<@0x1038, v28, v12 CheckInterrupts Return v29 @@ -15015,4 +15016,186 @@ mod opt_tests { Return v21 "); } + + #[test] + fn test_optimize_stringexact_eq_stringexact() { + eval(r#" + def test(l, r) = l == r + test("a", "b") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint 
MethodRedefined(String@0x1000, ==@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(String@0x1000) + v28:StringExact = GuardType v11, StringExact + v29:String = GuardType v12, String + v30:BoolExact = CCall String#==@0x1038, v28, v29 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v30 + "); + } + + #[test] + fn test_optimize_string_eq_string() { + eval(r#" + class C < String + end + def test(l, r) = l == r + test(C.new("a"), C.new("b")) + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(C@0x1000, ==@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(C@0x1000) + v28:StringSubclass[class_exact:C] = GuardType v11, StringSubclass[class_exact:C] + v29:String = GuardType v12, String + v30:BoolExact = CCall String#==@0x1038, v28, v29 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v30 + "); + } + + #[test] + fn test_optimize_stringexact_eq_string() { + eval(r#" + class C < String + end + def test(l, r) = l == r + test("a", C.new("b")) + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(String@0x1000, ==@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(String@0x1000) + v28:StringExact = GuardType v11, StringExact + v29:String = GuardType v12, String + v30:BoolExact = CCall String#==@0x1038, v28, v29 + IncrCounter 
inline_cfunc_optimized_send_count + CheckInterrupts + Return v30 + "); + } + + #[test] + fn test_optimize_stringexact_eqq_stringexact() { + eval(r#" + def test(l, r) = l === r + test("a", "b") + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(String@0x1000, ===@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(String@0x1000) + v26:StringExact = GuardType v11, StringExact + v27:String = GuardType v12, String + v28:BoolExact = CCall String#==@0x1038, v26, v27 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v28 + "); + } + + #[test] + fn test_optimize_string_eqq_string() { + eval(r#" + class C < String + end + def test(l, r) = l === r + test(C.new("a"), C.new("b")) + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(C@0x1000, ===@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(C@0x1000) + v26:StringSubclass[class_exact:C] = GuardType v11, StringSubclass[class_exact:C] + v27:String = GuardType v12, String + v28:BoolExact = CCall String#==@0x1038, v26, v27 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v28 + "); + } + + #[test] + fn test_optimize_stringexact_eqq_string() { + eval(r#" + class C < String + end + def test(l, r) = l === r + test("a", C.new("b")) + "#); + assert_snapshot!(hir_string("test"), @r" + fn 
test@<compiled>:4: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + v2:BasicObject = GetLocal l0, SP@5 + v3:BasicObject = GetLocal l0, SP@4 + Jump bb2(v1, v2, v3) + bb1(v6:BasicObject, v7:BasicObject, v8:BasicObject): + EntryPoint JIT(0) + Jump bb2(v6, v7, v8) + bb2(v10:BasicObject, v11:BasicObject, v12:BasicObject): + PatchPoint MethodRedefined(String@0x1000, ===@0x1008, cme:0x1010) + PatchPoint NoSingletonClass(String@0x1000) + v26:StringExact = GuardType v11, StringExact + v27:String = GuardType v12, String + v28:BoolExact = CCall String#==@0x1038, v26, v27 + IncrCounter inline_cfunc_optimized_send_count + CheckInterrupts + Return v28 + "); + } }