diff --git a/jit.c b/jit.c index 43c932e5a0..dd28bd6ad0 100644 --- a/jit.c +++ b/jit.c @@ -765,3 +765,5 @@ rb_yarv_str_eql_internal(VALUE str1, VALUE str2) // We wrap this since it's static inline return rb_str_eql_internal(str1, str2); } + +void rb_jit_str_concat_codepoint(VALUE str, VALUE codepoint); diff --git a/string.c b/string.c index 2dec3a11e6..c794b36748 100644 --- a/string.c +++ b/string.c @@ -12580,9 +12580,9 @@ rb_enc_interned_str_cstr(const char *ptr, rb_encoding *enc) return rb_enc_interned_str(ptr, strlen(ptr), enc); } -#if USE_YJIT +#if USE_YJIT || USE_ZJIT void -rb_yjit_str_concat_codepoint(VALUE str, VALUE codepoint) +rb_jit_str_concat_codepoint(VALUE str, VALUE codepoint) { if (RB_LIKELY(ENCODING_GET_INLINED(str) == rb_ascii8bit_encindex())) { ssize_t code = RB_NUM2SSIZE(codepoint); diff --git a/yjit/bindgen/src/main.rs b/yjit/bindgen/src/main.rs index 67a461cd16..9c28177a60 100644 --- a/yjit/bindgen/src/main.rs +++ b/yjit/bindgen/src/main.rs @@ -273,7 +273,7 @@ fn main() { .allowlist_function("rb_yjit_sendish_sp_pops") .allowlist_function("rb_yjit_invokeblock_sp_pops") .allowlist_function("rb_yjit_set_exception_return") - .allowlist_function("rb_yjit_str_concat_codepoint") + .allowlist_function("rb_jit_str_concat_codepoint") .allowlist_type("rstring_offsets") .allowlist_function("rb_assert_holding_vm_lock") .allowlist_function("rb_jit_shape_too_complex_p") diff --git a/yjit/src/codegen.rs b/yjit/src/codegen.rs index a426ad0773..50762c64d3 100644 --- a/yjit/src/codegen.rs +++ b/yjit/src/codegen.rs @@ -6244,7 +6244,7 @@ fn jit_rb_str_concat_codepoint( guard_object_is_fixnum(jit, asm, codepoint, StackOpnd(0)); - asm.ccall(rb_yjit_str_concat_codepoint as *const u8, vec![recv, codepoint]); + asm.ccall(rb_jit_str_concat_codepoint as *const u8, vec![recv, codepoint]); // The receiver is the return value, so we only need to pop the codepoint argument off the stack. // We can reuse the receiver slot in the stack as the return value. diff --git a/yjit/src/cruby.rs b/yjit/src/cruby.rs index cfaf48c3f0..5562f73be2 100644 --- a/yjit/src/cruby.rs +++ b/yjit/src/cruby.rs @@ -123,7 +123,6 @@ extern "C" { pub fn rb_float_new(d: f64) -> VALUE; pub fn rb_hash_empty_p(hash: VALUE) -> VALUE; - pub fn rb_yjit_str_concat_codepoint(str: VALUE, codepoint: VALUE); pub fn rb_str_setbyte(str: VALUE, index: VALUE, value: VALUE) -> VALUE; pub fn rb_vm_splat_array(flag: VALUE, ary: VALUE) -> VALUE; pub fn rb_vm_concat_array(ary1: VALUE, ary2st: VALUE) -> VALUE; diff --git a/yjit/src/cruby_bindings.inc.rs b/yjit/src/cruby_bindings.inc.rs index 66d4e5111d..04ca3494ac 100644 --- a/yjit/src/cruby_bindings.inc.rs +++ b/yjit/src/cruby_bindings.inc.rs @@ -1277,4 +1277,5 @@ extern "C" { ); pub fn rb_jit_fix_mod_fix(recv: VALUE, obj: VALUE) -> VALUE; pub fn rb_yarv_str_eql_internal(str1: VALUE, str2: VALUE) -> VALUE; + pub fn rb_jit_str_concat_codepoint(str_: VALUE, codepoint: VALUE); } diff --git a/zjit/bindgen/src/main.rs b/zjit/bindgen/src/main.rs index fe082504b8..6659049242 100644 --- a/zjit/bindgen/src/main.rs +++ b/zjit/bindgen/src/main.rs @@ -278,6 +278,7 @@ fn main() { .allowlist_function("rb_jit_get_page_size") .allowlist_function("rb_jit_array_len") .allowlist_function("rb_jit_iseq_builtin_attrs") + .allowlist_function("rb_jit_str_concat_codepoint") .allowlist_function("rb_zjit_iseq_inspect") .allowlist_function("rb_zjit_iseq_insn_set") .allowlist_function("rb_zjit_local_id") diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index f77b8cc4bf..6fc8566469 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -368,6 +368,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio &Insn::StringGetbyteFixnum { string, index } => gen_string_getbyte_fixnum(asm, opnd!(string), opnd!(index)), Insn::StringSetbyteFixnum { string, index, value } => gen_string_setbyte_fixnum(asm, opnd!(string), opnd!(index), opnd!(value)), Insn::StringAppend { recv, other, state } => gen_string_append(jit, asm, opnd!(recv), opnd!(other), &function.frame_state(*state)), + Insn::StringAppendCodepoint { recv, other, state } => gen_string_append_codepoint(jit, asm, opnd!(recv), opnd!(other), &function.frame_state(*state)), Insn::StringIntern { val, state } => gen_intern(asm, opnd!(val), &function.frame_state(*state)), Insn::ToRegexp { opt, values, state } => gen_toregexp(jit, asm, *opt, opnds!(values), &function.frame_state(*state)), Insn::Param => unreachable!("block.insns should not have Insn::Param"), @@ -2495,6 +2496,11 @@ fn gen_string_append(jit: &mut JITState, asm: &mut Assembler, string: Opnd, val: asm_ccall!(asm, rb_str_buf_append, string, val) } +fn gen_string_append_codepoint(jit: &mut JITState, asm: &mut Assembler, string: Opnd, val: Opnd, state: &FrameState) -> Opnd { + gen_prepare_non_leaf_call(jit, asm, state); + asm_ccall!(asm, rb_jit_str_concat_codepoint, string, val) +} + /// Generate a JIT entry that just increments exit_compilation_failure and exits fn gen_compile_error_counter(cb: &mut CodeBlock, compile_error: &CompileError) -> Result { let mut asm = Assembler::new(); diff --git a/zjit/src/cruby.rs b/zjit/src/cruby.rs index 9070cb38be..72c44ccc6e 100644 --- a/zjit/src/cruby.rs +++ b/zjit/src/cruby.rs @@ -130,7 +130,6 @@ unsafe extern "C" { pub fn rb_float_new(d: f64) -> VALUE; pub fn rb_hash_empty_p(hash: VALUE) -> VALUE; - pub fn rb_yjit_str_concat_codepoint(str: VALUE, codepoint: VALUE); pub fn rb_str_setbyte(str: VALUE, index: VALUE, value: VALUE) -> VALUE; pub fn rb_str_getbyte(str: VALUE, index: VALUE) -> VALUE; pub fn rb_vm_splat_array(flag: VALUE, ary: VALUE) -> VALUE; diff --git a/zjit/src/cruby_bindings.inc.rs b/zjit/src/cruby_bindings.inc.rs index aaecfa2f89..2256b7e32d 100644 --- a/zjit/src/cruby_bindings.inc.rs +++ b/zjit/src/cruby_bindings.inc.rs @@ -2224,4 +2224,5 @@ unsafe extern "C" { end: *mut ::std::os::raw::c_void, ); pub fn rb_yarv_str_eql_internal(str1: VALUE, str2: VALUE) -> VALUE; + pub fn rb_jit_str_concat_codepoint(str_: VALUE, codepoint: VALUE); } diff --git a/zjit/src/cruby_methods.rs b/zjit/src/cruby_methods.rs index f7bf92b31a..f86d383876 100644 --- a/zjit/src/cruby_methods.rs +++ b/zjit/src/cruby_methods.rs @@ -425,10 +425,15 @@ fn inline_string_append(fun: &mut hir::Function, block: hir::BlockId, recv: hir: let recv = fun.coerce_to(block, recv, types::StringExact, state); let other = fun.coerce_to(block, other, types::String, state); let _ = fun.push_insn(block, hir::Insn::StringAppend { recv, other, state }); - Some(recv) - } else { - None + return Some(recv); } + if fun.likely_a(recv, types::StringExact, state) && fun.likely_a(other, types::Fixnum, state) { + let recv = fun.coerce_to(block, recv, types::StringExact, state); + let other = fun.coerce_to(block, other, types::Fixnum, state); + let _ = fun.push_insn(block, hir::Insn::StringAppendCodepoint { recv, other, state }); + return Some(recv); + } + None } fn inline_string_eq(fun: &mut hir::Function, block: hir::BlockId, recv: hir::InsnId, args: &[hir::InsnId], state: hir::InsnId) -> Option { diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 4bbbbd150e..6604c52a82 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -652,6 +652,7 @@ pub enum Insn { StringGetbyteFixnum { string: InsnId, index: InsnId }, StringSetbyteFixnum { string: InsnId, index: InsnId, value: InsnId }, StringAppend { recv: InsnId, other: InsnId, state: InsnId }, + StringAppendCodepoint { recv: InsnId, other: InsnId, state: InsnId }, /// Combine count stack values into a regexp ToRegexp { opt: usize, values: Vec, state: InsnId }, @@ -1124,6 +1125,9 @@ impl<'a> std::fmt::Display for InsnPrinter<'a> { Insn::StringAppend { recv, other, .. } => { write!(f, "StringAppend {recv}, {other}") } + Insn::StringAppendCodepoint { recv, other, .. } => { + write!(f, "StringAppendCodepoint {recv}, {other}") + } Insn::ToRegexp { values, opt, .. } => { write!(f, "ToRegexp")?; let mut prefix = " "; @@ -1814,6 +1818,7 @@ impl Function { &StringGetbyteFixnum { string, index } => StringGetbyteFixnum { string: find!(string), index: find!(index) }, &StringSetbyteFixnum { string, index, value } => StringSetbyteFixnum { string: find!(string), index: find!(index), value: find!(value) }, &StringAppend { recv, other, state } => StringAppend { recv: find!(recv), other: find!(other), state: find!(state) }, + &StringAppendCodepoint { recv, other, state } => StringAppendCodepoint { recv: find!(recv), other: find!(other), state: find!(state) }, &ToRegexp { opt, ref values, state } => ToRegexp { opt, values: find_vec!(values), state }, &Test { val } => Test { val: find!(val) }, &IsNil { val } => IsNil { val: find!(val) }, @@ -2032,6 +2037,7 @@ impl Function { Insn::StringGetbyteFixnum { .. } => types::Fixnum.union(types::NilClass), Insn::StringSetbyteFixnum { .. } => types::Fixnum, Insn::StringAppend { .. } => types::StringExact, + Insn::StringAppendCodepoint { .. } => types::StringExact, Insn::ToRegexp { .. } => types::RegexpExact, Insn::NewArray { .. } => types::ArrayExact, Insn::ArrayDup { .. } => types::ArrayExact, @@ -3564,7 +3570,9 @@ impl Function { worklist.push_back(index); worklist.push_back(value); } - &Insn::StringAppend { recv, other, state } => { + &Insn::StringAppend { recv, other, state } + | &Insn::StringAppendCodepoint { recv, other, state } + => { worklist.push_back(recv); worklist.push_back(other); worklist.push_back(state); @@ -4328,6 +4336,10 @@ impl Function { self.assert_subtype(insn_id, recv, types::StringExact)?; self.assert_subtype(insn_id, other, types::String) } + Insn::StringAppendCodepoint { recv, other, .. } => { + self.assert_subtype(insn_id, recv, types::StringExact)?; + self.assert_subtype(insn_id, other, types::Fixnum) + } // Instructions with Array operands Insn::ArrayDup { val, .. } => self.assert_subtype(insn_id, val, types::ArrayExact), Insn::ArrayExtend { left, right, .. } => { diff --git a/zjit/src/hir/opt_tests.rs b/zjit/src/hir/opt_tests.rs index 90675965c2..60f5814973 100644 --- a/zjit/src/hir/opt_tests.rs +++ b/zjit/src/hir/opt_tests.rs @@ -5969,7 +5969,7 @@ mod hir_opt_tests { v13:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) v26:Fixnum = GuardType v9, Fixnum PatchPoint MethodRedefined(Integer@0x1008, to_s@0x1010, cme:0x1018) - v30:StringExact = CCallVariadic Integer#to_s@0x1040, v26 + v30:StringExact = CCallVariadic v26, :Integer#to_s@0x1040 v21:StringExact = StringConcat v13, v30 CheckInterrupts Return v21 @@ -6854,9 +6854,8 @@ mod hir_opt_tests { "); } - // TODO: This should be inlined just as in the interpreter #[test] - fn test_optimize_string_append_non_string() { + fn test_optimize_string_append_codepoint() { eval(r#" def test(x, y) = x << y test("iron", 4) @@ -6876,9 +6875,11 @@ mod hir_opt_tests { PatchPoint MethodRedefined(String@0x1000, <<@0x1008, cme:0x1010) PatchPoint NoSingletonClass(String@0x1000) v27:StringExact = GuardType v11, StringExact - v28:BasicObject = CCallWithFrame v27, :String#<<@0x1038, v12 + v28:Fixnum = GuardType v12, Fixnum + v29:StringExact = StringAppendCodepoint v27, v28 + IncrCounter inline_cfunc_optimized_send_count CheckInterrupts - Return v28 + Return v27 "); } @@ -6942,6 +6943,32 @@ mod hir_opt_tests { "); } + #[test] + fn test_dont_optimize_string_append_non_string() { + eval(r#" + def test = "iron" << :a + "#); + assert_snapshot!(hir_string("test"), @r" + fn test@:2: + bb0(): + EntryPoint interpreter + v1:BasicObject = LoadSelf + Jump bb2(v1) + bb1(v4:BasicObject): + EntryPoint JIT(0) + Jump bb2(v4) + bb2(v6:BasicObject): + v10:StringExact[VALUE(0x1000)] = Const Value(VALUE(0x1000)) + v11:StringExact = StringCopy v10 + v13:StaticSymbol[:a] = Const Value(VALUE(0x1008)) + PatchPoint MethodRedefined(String@0x1010, <<@0x1018, cme:0x1020) + PatchPoint NoSingletonClass(String@0x1010) + v24:BasicObject = CCallWithFrame v11, :String#<<@0x1048, v13 + CheckInterrupts + Return v24 + "); + } + #[test] fn test_dont_optimize_when_passing_too_many_args() { eval(r#"