ZJIT: Use a shared trampoline across all ISEQs (#15042)

Author: Takashi Kokubun, 2025-11-04 16:09:13 -08:00 (committed by GitHub)
Parent: be905b2e58
Commit: bd3b44cb0a
Notes: git 2025-11-05 00:09:39 +00:00
Merged-By: k0kubun <takashikkbn@gmail.com>
10 changed files with 132 additions and 70 deletions

vm.c

@@ -503,7 +503,7 @@ rb_yjit_threshold_hit(const rb_iseq_t *iseq, uint64_t entry_calls)
#define rb_yjit_threshold_hit(iseq, entry_calls) false
#endif
#if USE_YJIT || USE_ZJIT
#if USE_YJIT
// Generate JIT code that supports the following kinds of ISEQ entries:
// * The first ISEQ on vm_exec (e.g. <main>, or Ruby methods/blocks
// called by a C method). The current frame has VM_FRAME_FLAG_FINISH.
@@ -513,13 +513,32 @@ rb_yjit_threshold_hit(const rb_iseq_t *iseq, uint64_t entry_calls)
// The current frame doesn't have VM_FRAME_FLAG_FINISH. The current
// vm_exec does NOT stop whether JIT code returns Qundef or not.
static inline rb_jit_func_t
jit_compile(rb_execution_context_t *ec)
yjit_compile(rb_execution_context_t *ec)
{
const rb_iseq_t *iseq = ec->cfp->iseq;
struct rb_iseq_constant_body *body = ISEQ_BODY(iseq);
// Increment the ISEQ's call counter and trigger JIT compilation if not compiled
if (body->jit_entry == NULL) {
body->jit_entry_calls++;
if (rb_yjit_threshold_hit(iseq, body->jit_entry_calls)) {
rb_yjit_compile_iseq(iseq, ec, false);
}
}
return body->jit_entry;
}
#else
# define yjit_compile(ec) ((rb_jit_func_t)0)
#endif
#if USE_ZJIT
if (body->jit_entry == NULL && rb_zjit_enabled_p) {
static inline rb_jit_func_t
zjit_compile(rb_execution_context_t *ec)
{
const rb_iseq_t *iseq = ec->cfp->iseq;
struct rb_iseq_constant_body *body = ISEQ_BODY(iseq);
if (body->jit_entry == NULL) {
body->jit_entry_calls++;
// At profile-threshold, rewrite some of the YARV instructions
@@ -533,38 +552,38 @@ jit_compile(rb_execution_context_t *ec)
rb_zjit_compile_iseq(iseq, false);
}
}
#endif
#if USE_YJIT
// Increment the ISEQ's call counter and trigger JIT compilation if not compiled
if (body->jit_entry == NULL && rb_yjit_enabled_p) {
body->jit_entry_calls++;
if (rb_yjit_threshold_hit(iseq, body->jit_entry_calls)) {
rb_yjit_compile_iseq(iseq, ec, false);
}
}
#endif
return body->jit_entry;
}
#else
# define zjit_compile(ec) ((rb_jit_func_t)0)
#endif
// Execute JIT code compiled by jit_compile()
// Execute JIT code compiled by yjit_compile() or zjit_compile()
static inline VALUE
jit_exec(rb_execution_context_t *ec)
{
rb_jit_func_t func = jit_compile(ec);
if (func) {
// Call the JIT code
return func(ec, ec->cfp);
}
else {
#if USE_YJIT
if (rb_yjit_enabled_p) {
rb_jit_func_t func = yjit_compile(ec);
if (func) {
return func(ec, ec->cfp);
}
return Qundef;
}
}
#else
# define jit_compile(ec) ((rb_jit_func_t)0)
# define jit_exec(ec) Qundef
#endif
#if USE_ZJIT
void *zjit_entry = rb_zjit_entry;
if (zjit_entry) {
rb_jit_func_t func = zjit_compile(ec);
if (func) {
return ((rb_zjit_func_t)zjit_entry)(ec, ec->cfp, func);
}
}
#endif
return Qundef;
}
#if USE_YJIT
// Generate JIT code that supports the following kind of ISEQ entry:
// * The first ISEQ pushed by vm_exec_handle_exception. The frame would
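The core of the change is visible in jit_exec() above: YJIT-compiled code keeps the two-argument (EC, CFP) signature and is called directly, while ZJIT-compiled code is now entered through a single trampoline shared by every ISEQ, with the per-ISEQ function pointer passed as a third argument. The following is a minimal, self-contained C sketch of that dispatch shape, not the real vm.c code; the stand-in names (ec_t, rb_zjit_entry_sketch, and so on) are illustrative only.

#include <stdbool.h>
#include <stddef.h>

/* Stand-ins for the real VM types; only the two typedef shapes and the
 * "check one global, pass the per-ISEQ pointer through" pattern mirror
 * the actual change. */
typedef struct { int dummy; } ec_t;    /* rb_execution_context_t */
typedef struct { int dummy; } cfp_t;   /* rb_control_frame_t */
typedef unsigned long VALUE;
static const VALUE QUNDEF_SKETCH = 0x34;

/* Per-ISEQ JIT code: (EC, CFP) -> VALUE */
typedef VALUE (*jit_func_sketch_t)(ec_t *, cfp_t *);
/* Shared ZJIT entry trampoline: (EC, CFP, per-ISEQ code) -> VALUE */
typedef VALUE (*zjit_func_sketch_t)(ec_t *, cfp_t *, jit_func_sketch_t);

static void *rb_zjit_entry_sketch = NULL;   /* set once when ZJIT boots */
static bool rb_yjit_enabled_sketch = false;

static VALUE
jit_exec_sketch(ec_t *ec, cfp_t *cfp,
                jit_func_sketch_t yjit_code, jit_func_sketch_t zjit_code)
{
    if (rb_yjit_enabled_sketch && yjit_code) {
        /* YJIT code is directly callable with (EC, CFP). */
        return yjit_code(ec, cfp);
    }
    if (rb_zjit_entry_sketch != NULL && zjit_code) {
        /* ZJIT code is entered through the one shared trampoline; the
         * per-ISEQ pointer rides along as the third argument, so no
         * per-ISEQ entry stub has to be generated. */
        return ((zjit_func_sketch_t)rb_zjit_entry_sketch)(ec, cfp, zjit_code);
    }
    return QUNDEF_SKETCH;   /* Qundef: fall back to the interpreter */
}

int main(void)
{
    ec_t ec = {0};
    cfp_t cfp = {0};
    /* With neither JIT holding compiled code, the sketch falls through. */
    return jit_exec_sketch(&ec, &cfp, NULL, NULL) == QUNDEF_SKETCH ? 0 : 1;
}

Compared with the previous jit_compile()/jit_exec() pair, the win is that ZJIT no longer compiles a dedicated entry stub per ISEQ; one trampoline generated at boot serves all of them.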


@@ -398,6 +398,7 @@ enum rb_builtin_attr {
};
typedef VALUE (*rb_jit_func_t)(struct rb_execution_context_struct *, struct rb_control_frame_struct *);
typedef VALUE (*rb_zjit_func_t)(struct rb_execution_context_struct *, struct rb_control_frame_struct *, rb_jit_func_t);
struct rb_iseq_constant_body {
enum rb_iseq_type type;


@@ -175,11 +175,22 @@ default: \
// Run the JIT from the interpreter
#define JIT_EXEC(ec, val) do { \
rb_jit_func_t func; \
/* don't run tailcalls since that breaks FINISH */ \
if (UNDEF_P(val) && GET_CFP() != ec->cfp && (func = jit_compile(ec))) { \
val = func(ec, ec->cfp); \
if (ec->tag->state) THROW_EXCEPTION(val); \
if (UNDEF_P(val) && GET_CFP() != ec->cfp) { \
rb_zjit_func_t zjit_entry; \
if (rb_yjit_enabled_p) { \
rb_jit_func_t func = yjit_compile(ec); \
if (func) { \
val = func(ec, ec->cfp); \
if (ec->tag->state) THROW_EXCEPTION(val); \
} \
} \
else if ((zjit_entry = rb_zjit_entry)) { \
rb_jit_func_t func = zjit_compile(ec); \
if (func) { \
val = zjit_entry(ec, ec->cfp, func); \
} \
} \
} \
} while (0)

zjit.h

@@ -10,7 +10,7 @@
#endif
#if USE_ZJIT
extern bool rb_zjit_enabled_p;
extern void *rb_zjit_entry;
extern uint64_t rb_zjit_call_threshold;
extern uint64_t rb_zjit_profile_threshold;
void rb_zjit_compile_iseq(const rb_iseq_t *iseq, bool jit_exception);
@@ -29,7 +29,7 @@ void rb_zjit_before_ractor_spawn(void);
void rb_zjit_tracing_invalidate_all(void);
void rb_zjit_invalidate_no_singleton_class(VALUE klass);
#else
#define rb_zjit_enabled_p false
#define rb_zjit_entry 0
static inline void rb_zjit_compile_iseq(const rb_iseq_t *iseq, bool jit_exception) {}
static inline void rb_zjit_profile_insn(uint32_t insn, rb_execution_context_t *ec) {}
static inline void rb_zjit_profile_enable(const rb_iseq_t *iseq) {}
@@ -42,4 +42,6 @@ static inline void rb_zjit_tracing_invalidate_all(void) {}
static inline void rb_zjit_invalidate_no_singleton_class(VALUE klass) {}
#endif // #if USE_ZJIT
#define rb_zjit_enabled_p (rb_zjit_entry != 0)
#endif // #ifndef ZJIT_H
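A side effect worth noting in zjit.h: rb_zjit_enabled_p is no longer an exported bool set at boot; it is now a macro over rb_zjit_entry, so "ZJIT is enabled" and "the shared trampoline exists" are the same condition and only one global has to be read. A small self-contained sketch of the pattern, with _sketch names standing in for the real symbols:

#include <stdio.h>

#ifdef USE_ZJIT_SKETCH
static void *rb_zjit_entry_sketch;                /* filled in by JIT init */
#else
#define rb_zjit_entry_sketch ((void *)0)          /* collapses to a constant */
#endif
#define rb_zjit_enabled_sketch (rb_zjit_entry_sketch != 0)

int main(void)
{
    /* Before JIT init (or in a non-ZJIT build) the check is simply false. */
    printf("ZJIT enabled: %d\n", rb_zjit_enabled_sketch ? 1 : 0);
    return 0;
}

Folding the flag into the pointer also keeps non-ZJIT builds compiling, because the fallback macro turns the check into a compile-time constant.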


@@ -1428,17 +1428,25 @@ impl Assembler {
}
},
Insn::CCall { fptr, .. } => {
// The offset to the call target in bytes
let src_addr = cb.get_write_ptr().raw_ptr(cb) as i64;
let dst_addr = *fptr as i64;
match fptr {
Opnd::UImm(fptr) => {
// The offset to the call target in bytes
let src_addr = cb.get_write_ptr().raw_ptr(cb) as i64;
let dst_addr = *fptr as i64;
// Use BL if the offset is short enough to encode as an immediate.
// Otherwise, use BLR with a register.
if b_offset_fits_bits((dst_addr - src_addr) / 4) {
bl(cb, InstructionOffset::from_bytes((dst_addr - src_addr) as i32));
} else {
emit_load_value(cb, Self::EMIT_OPND, dst_addr as u64);
blr(cb, Self::EMIT_OPND);
// Use BL if the offset is short enough to encode as an immediate.
// Otherwise, use BLR with a register.
if b_offset_fits_bits((dst_addr - src_addr) / 4) {
bl(cb, InstructionOffset::from_bytes((dst_addr - src_addr) as i32));
} else {
emit_load_value(cb, Self::EMIT_OPND, dst_addr as u64);
blr(cb, Self::EMIT_OPND);
}
}
Opnd::Reg(_) => {
blr(cb, fptr.into());
}
_ => unreachable!("unsupported ccall fptr: {fptr:?}")
}
},
Insn::CRet { .. } => {


@@ -386,7 +386,9 @@ pub enum Insn {
// C function call with N arguments (variadic)
CCall {
opnds: Vec<Opnd>,
fptr: *const u8,
/// The function pointer to be called. This should be Opnd::const_ptr
/// (Opnd::UImm) in most cases. gen_entry_trampoline() uses Opnd::Reg.
fptr: Opnd,
/// Optional PosMarker to remember the start address of the C call.
/// It's embedded here to insert the PosMarker after push instructions
/// that are split from this CCall on alloc_regs().
@@ -1989,11 +1991,20 @@ impl Assembler {
pub fn ccall(&mut self, fptr: *const u8, opnds: Vec<Opnd>) -> Opnd {
let canary_opnd = self.set_stack_canary();
let out = self.new_vreg(Opnd::match_num_bits(&opnds));
let fptr = Opnd::const_ptr(fptr);
self.push_insn(Insn::CCall { fptr, opnds, start_marker: None, end_marker: None, out });
self.clear_stack_canary(canary_opnd);
out
}
/// Call a C function stored in a register
pub fn ccall_reg(&mut self, fptr: Opnd, num_bits: u8) -> Opnd {
assert!(matches!(fptr, Opnd::Reg(_)), "ccall_reg must be called with Opnd::Reg: {fptr:?}");
let out = self.new_vreg(num_bits);
self.push_insn(Insn::CCall { fptr, opnds: vec![], start_marker: None, end_marker: None, out });
out
}
/// Call a C function with PosMarkers. This is used for recording the start and end
/// addresses of the C call and rewriting it with a different function address later.
pub fn ccall_with_pos_markers(
@@ -2005,7 +2016,7 @@
) -> Opnd {
let out = self.new_vreg(Opnd::match_num_bits(&opnds));
self.push_insn(Insn::CCall {
fptr,
fptr: Opnd::const_ptr(fptr),
opnds,
start_marker: Some(Rc::new(start_marker)),
end_marker: Some(Rc::new(end_marker)),


@@ -863,7 +863,15 @@ impl Assembler {
// C function call
Insn::CCall { fptr, .. } => {
call_ptr(cb, RAX, *fptr);
match fptr {
Opnd::UImm(fptr) => {
call_ptr(cb, RAX, *fptr as *const u8);
}
Opnd::Reg(_) => {
call(cb, fptr.into());
}
_ => unreachable!("unsupported ccall fptr: {fptr:?}")
}
},
Insn::CRet(opnd) => {


@@ -106,8 +106,7 @@ pub extern "C" fn rb_zjit_iseq_gen_entry_point(iseq: IseqPtr, jit_exception: boo
}
// Always mark the code region executable if asm.compile() has been used.
// We need to do this even if code_ptr is None because, whether gen_entry()
// fails or not, gen_iseq() may have already used asm.compile().
// We need to do this even if code_ptr is None because gen_iseq() may have already used asm.compile().
cb.mark_all_executable();
code_ptr.map_or(std::ptr::null(), |ptr| ptr.raw_ptr(cb))
@@ -131,10 +130,7 @@ fn gen_iseq_entry_point(cb: &mut CodeBlock, iseq: IseqPtr, jit_exception: bool)
debug!("{err:?}: gen_iseq failed: {}", iseq_get_location(iseq, 0));
})?;
// Compile an entry point to the JIT code
gen_entry(cb, iseq, start_ptr).inspect_err(|err| {
debug!("{err:?}: gen_entry failed: {}", iseq_get_location(iseq, 0));
})
Ok(start_ptr)
}
/// Stub a branch for a JIT-to-JIT call
@@ -170,14 +166,16 @@ fn register_with_perf(iseq_name: String, start_ptr: usize, code_size: usize) {
};
}
/// Compile a JIT entry
fn gen_entry(cb: &mut CodeBlock, iseq: IseqPtr, function_ptr: CodePtr) -> Result<CodePtr, CompileError> {
/// Compile a shared JIT entry trampoline
pub fn gen_entry_trampoline(cb: &mut CodeBlock) -> Result<CodePtr, CompileError> {
// Set up registers for CFP, EC, SP, and basic block arguments
let mut asm = Assembler::new();
gen_entry_prologue(&mut asm, iseq);
gen_entry_prologue(&mut asm);
// Jump to the first block using a call instruction
asm.ccall(function_ptr.raw_ptr(cb), vec![]);
// Jump to the first block using a call instruction. This trampoline is used
// as rb_zjit_func_t in jit_exec(), which takes (EC, CFP, rb_jit_func_t).
// So C_ARG_OPNDS[2] is rb_jit_func_t, which is (EC, CFP) -> VALUE.
asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS);
// Restore registers for CFP, EC, and SP after use
asm_comment!(asm, "return to the interpreter");
@@ -190,8 +188,7 @@ fn gen_entry(cb: &mut CodeBlock, iseq: IseqPtr, function_ptr: CodePtr) -> Result
let start_ptr = code_ptr.raw_addr(cb);
let end_ptr = cb.get_write_ptr().raw_addr(cb);
let code_size = end_ptr - start_ptr;
let iseq_name = iseq_get_location(iseq, 0);
register_with_perf(format!("entry for {iseq_name}"), start_ptr, code_size);
register_with_perf("ZJIT entry trampoline".into(), start_ptr, code_size);
}
Ok(code_ptr)
}
@@ -990,8 +987,8 @@ fn gen_load_field(asm: &mut Assembler, recv: Opnd, id: ID, offset: i32) -> Opnd
}
/// Compile an interpreter entry block to be inserted into an ISEQ
fn gen_entry_prologue(asm: &mut Assembler, iseq: IseqPtr) {
asm_comment!(asm, "ZJIT entry point: {}", iseq_get_location(iseq, 0));
fn gen_entry_prologue(asm: &mut Assembler) {
asm_comment!(asm, "ZJIT entry trampoline");
// Save the registers we'll use for CFP, EP, SP
asm.frame_setup(lir::JIT_PRESERVED_REGS);
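Because the entry point is now shared, gen_entry_trampoline() does not bake any ISEQ-specific address into the emitted code: the per-ISEQ function arrives at runtime in the third C argument register (C_ARG_OPNDS[2]) and is invoked through the new ccall_reg() helper. This is also why CCall's fptr became an Opnd in lir.rs: an immediate address still lowers to bl/call through a scratch register, while a register operand lowers to blr on arm64 or a register call on x86_64. Roughly, the emitted trampoline behaves like the C below; the register shuffling is machine code and appears here only as comments, and the _sketch names are placeholders:

typedef unsigned long VALUE;
typedef struct { int dummy; } ec_t;    /* rb_execution_context_t */
typedef struct { int dummy; } cfp_t;   /* rb_control_frame_t */
typedef VALUE (*jit_func_sketch_t)(ec_t *, cfp_t *);

/* The one trampoline shared by all ISEQs (the rb_zjit_func_t shape). */
VALUE
zjit_entry_trampoline_sketch(ec_t *ec, cfp_t *cfp, jit_func_sketch_t iseq_code)
{
    /* 1. frame_setup(JIT_PRESERVED_REGS): save callee-saved registers and
     *    load CFP/EC/SP into the registers ZJIT-compiled code expects.   */
    /* 2. ccall_reg(C_ARG_OPNDS[2], VALUE_BITS): call the per-ISEQ code
     *    whose address arrived as the third argument.                    */
    VALUE result = iseq_code(ec, cfp);
    /* 3. Restore the CFP/EC/SP registers ("return to the interpreter")
     *    and hand the resulting VALUE back to jit_exec()/JIT_EXEC().     */
    return result;
}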


@@ -1071,7 +1071,7 @@ pub use manual_defs::*;
pub mod test_utils {
use std::{ptr::null, sync::Once};
use crate::{options::{rb_zjit_call_threshold, rb_zjit_prepare_options, set_call_threshold, DEFAULT_CALL_THRESHOLD}, state::{rb_zjit_enabled_p, ZJITState}};
use crate::{options::{rb_zjit_call_threshold, rb_zjit_prepare_options, set_call_threshold, DEFAULT_CALL_THRESHOLD}, state::{rb_zjit_entry, ZJITState}};
use super::*;
@@ -1114,10 +1114,10 @@ }
}
// Set up globals for convenience
ZJITState::init();
let zjit_entry = ZJITState::init();
// Enable zjit_* instructions
unsafe { rb_zjit_enabled_p = true; }
unsafe { rb_zjit_entry = zjit_entry; }
}
/// Make sure the Ruby VM is set up and run a given callback with rb_protect()


@@ -1,6 +1,6 @@
//! Runtime state of ZJIT.
use crate::codegen::{gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline};
use crate::codegen::{gen_entry_trampoline, gen_exit_trampoline, gen_exit_trampoline_with_counter, gen_function_stub_hit_trampoline};
use crate::cruby::{self, rb_bug_panic_hook, rb_vm_insn_count, EcPtr, Qnil, rb_vm_insn_addr2opcode, rb_profile_frames, VALUE, VM_INSTRUCTION_SIZE, size_t, rb_gc_mark};
use crate::cruby_methods;
use crate::invariants::Invariants;
@@ -9,14 +9,16 @@ use crate::options::get_option;
use crate::stats::{Counters, InsnCounters, SideExitLocations};
use crate::virtualmem::CodePtr;
use std::collections::HashMap;
use std::ptr::null;
/// Shared trampoline to enter ZJIT. Not null when ZJIT is enabled.
#[allow(non_upper_case_globals)]
#[unsafe(no_mangle)]
pub static mut rb_zjit_enabled_p: bool = false;
pub static mut rb_zjit_entry: *const u8 = null();
/// Like rb_zjit_enabled_p, but for Rust code.
pub fn zjit_enabled_p() -> bool {
unsafe { rb_zjit_enabled_p }
unsafe { rb_zjit_entry != null() }
}
/// Global state needed for code generation
@@ -65,8 +67,8 @@ pub struct ZJITState {
static mut ZJIT_STATE: Option<ZJITState> = None;
impl ZJITState {
/// Initialize the ZJIT globals
pub fn init() {
/// Initialize the ZJIT globals. Return the address of the JIT entry trampoline.
pub fn init() -> *const u8 {
let mut cb = {
use crate::options::*;
use crate::virtualmem::*;
@@ -79,6 +81,7 @@
CodeBlock::new(mem_block.clone(), get_option!(dump_disasm))
};
let entry_trampoline = gen_entry_trampoline(&mut cb).unwrap().raw_ptr(&cb);
let exit_trampoline = gen_exit_trampoline(&mut cb).unwrap();
let function_stub_hit_trampoline = gen_function_stub_hit_trampoline(&mut cb).unwrap();
@@ -114,6 +117,8 @@
let code_ptr = gen_exit_trampoline_with_counter(cb, exit_trampoline).unwrap();
ZJITState::get_instance().exit_trampoline_with_counter = code_ptr;
}
entry_trampoline
}
/// Return true if zjit_state has been initialized
@@ -252,7 +257,7 @@ pub extern "C" fn rb_zjit_init() {
let result = std::panic::catch_unwind(|| {
// Initialize ZJIT states
cruby::ids::init();
ZJITState::init();
let zjit_entry = ZJITState::init();
// Install a panic hook for ZJIT
rb_bug_panic_hook();
@@ -261,8 +266,8 @@
unsafe { rb_vm_insn_count = 0; }
// ZJIT enabled and initialized successfully
assert!(unsafe{ !rb_zjit_enabled_p });
unsafe { rb_zjit_enabled_p = true; }
assert!(unsafe{ rb_zjit_entry == null() });
unsafe { rb_zjit_entry = zjit_entry; }
});
if result.is_err() {