ZJIT: Support inference of ModuleExact type

This commit is contained in:
Stan Lo 2025-07-04 21:41:32 +01:00 committed by Max Bernstein
parent f5acefca44
commit 6c20082852
5 changed files with 73 additions and 31 deletions

View File

@ -158,6 +158,7 @@ unsafe extern "C" {
pub fn rb_vm_ic_hit_p(ic: IC, reg_ep: *const VALUE) -> bool;
pub fn rb_vm_stack_canary() -> VALUE;
pub fn rb_vm_push_cfunc_frame(cme: *const rb_callable_method_entry_t, recv_idx: c_int);
pub fn rb_obj_class(klass: VALUE) -> VALUE;
}
// Renames

View File

@ -5921,7 +5921,7 @@ mod opt_tests {
}
#[test]
fn module_instances_not_class_exact() {
fn module_instances_are_module_exact() {
eval("
def test = [Enumerable, Kernel]
test # Warm the constant cache
@ -5931,15 +5931,33 @@ mod opt_tests {
bb0(v0:BasicObject):
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, Enumerable)
v11:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008))
v11:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1010, Kernel)
v14:BasicObject[VALUE(0x1018)] = Const Value(VALUE(0x1018))
v14:ModuleExact[VALUE(0x1018)] = Const Value(VALUE(0x1018))
v7:ArrayExact = NewArray v11, v14
Return v7
"#]]);
}
#[test]
fn module_subclasses_are_not_module_exact() {
eval("
class ModuleSubclass < Module; end
MY_MODULE = ModuleSubclass.new
def test = MY_MODULE
test # Warm the constant cache
");
assert_optimized_method_hir("test", expect![[r#"
fn test:
bb0(v0:BasicObject):
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, MY_MODULE)
v7:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008))
Return v7
"#]]);
}
#[test]
fn eliminate_array_size() {
eval("
@ -6067,7 +6085,7 @@ mod opt_tests {
bb0(v0:BasicObject):
PatchPoint SingleRactorMode
PatchPoint StableConstantNames(0x1000, Kernel)
v7:BasicObject[VALUE(0x1008)] = Const Value(VALUE(0x1008))
v7:ModuleExact[VALUE(0x1008)] = Const Value(VALUE(0x1008))
Return v7
"#]]);
}

View File

@ -75,6 +75,7 @@ base_type "Range"
base_type "Set"
base_type "Regexp"
base_type "Class"
base_type "Module"
(integer, integer_exact) = base_type "Integer"
# CRuby partitions Integer into immediate and non-immediate variants.

View File

@ -9,7 +9,7 @@ mod bits {
pub const BasicObjectSubclass: u64 = 1u64 << 3;
pub const Bignum: u64 = 1u64 << 4;
pub const BoolExact: u64 = FalseClassExact | TrueClassExact;
pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | ClassExact | FalseClassExact | FloatExact | HashExact | IntegerExact | NilClassExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | SymbolExact | TrueClassExact;
pub const BuiltinExact: u64 = ArrayExact | BasicObjectExact | ClassExact | FalseClassExact | FloatExact | HashExact | IntegerExact | ModuleExact | NilClassExact | ObjectExact | RangeExact | RegexpExact | SetExact | StringExact | SymbolExact | TrueClassExact;
pub const CBool: u64 = 1u64 << 5;
pub const CDouble: u64 = 1u64 << 6;
pub const CInt: u64 = CSigned | CUnsigned;
@ -48,35 +48,38 @@ mod bits {
pub const Integer: u64 = IntegerExact | IntegerSubclass;
pub const IntegerExact: u64 = Bignum | Fixnum;
pub const IntegerSubclass: u64 = 1u64 << 29;
pub const Module: u64 = ModuleExact | ModuleSubclass;
pub const ModuleExact: u64 = 1u64 << 30;
pub const ModuleSubclass: u64 = 1u64 << 31;
pub const NilClass: u64 = NilClassExact | NilClassSubclass;
pub const NilClassExact: u64 = 1u64 << 30;
pub const NilClassSubclass: u64 = 1u64 << 31;
pub const Object: u64 = Array | Class | FalseClass | Float | Hash | Integer | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass;
pub const ObjectExact: u64 = 1u64 << 32;
pub const ObjectSubclass: u64 = 1u64 << 33;
pub const NilClassExact: u64 = 1u64 << 32;
pub const NilClassSubclass: u64 = 1u64 << 33;
pub const Object: u64 = Array | Class | FalseClass | Float | Hash | Integer | Module | NilClass | ObjectExact | ObjectSubclass | Range | Regexp | Set | String | Symbol | TrueClass;
pub const ObjectExact: u64 = 1u64 << 34;
pub const ObjectSubclass: u64 = 1u64 << 35;
pub const Range: u64 = RangeExact | RangeSubclass;
pub const RangeExact: u64 = 1u64 << 34;
pub const RangeSubclass: u64 = 1u64 << 35;
pub const RangeExact: u64 = 1u64 << 36;
pub const RangeSubclass: u64 = 1u64 << 37;
pub const Regexp: u64 = RegexpExact | RegexpSubclass;
pub const RegexpExact: u64 = 1u64 << 36;
pub const RegexpSubclass: u64 = 1u64 << 37;
pub const RegexpExact: u64 = 1u64 << 38;
pub const RegexpSubclass: u64 = 1u64 << 39;
pub const RubyValue: u64 = BasicObject | CallableMethodEntry | Undef;
pub const Set: u64 = SetExact | SetSubclass;
pub const SetExact: u64 = 1u64 << 38;
pub const SetSubclass: u64 = 1u64 << 39;
pub const StaticSymbol: u64 = 1u64 << 40;
pub const SetExact: u64 = 1u64 << 40;
pub const SetSubclass: u64 = 1u64 << 41;
pub const StaticSymbol: u64 = 1u64 << 42;
pub const String: u64 = StringExact | StringSubclass;
pub const StringExact: u64 = 1u64 << 41;
pub const StringSubclass: u64 = 1u64 << 42;
pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | ClassSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
pub const StringExact: u64 = 1u64 << 43;
pub const StringSubclass: u64 = 1u64 << 44;
pub const Subclass: u64 = ArraySubclass | BasicObjectSubclass | ClassSubclass | FalseClassSubclass | FloatSubclass | HashSubclass | IntegerSubclass | ModuleSubclass | NilClassSubclass | ObjectSubclass | RangeSubclass | RegexpSubclass | SetSubclass | StringSubclass | SymbolSubclass | TrueClassSubclass;
pub const Symbol: u64 = SymbolExact | SymbolSubclass;
pub const SymbolExact: u64 = DynamicSymbol | StaticSymbol;
pub const SymbolSubclass: u64 = 1u64 << 43;
pub const SymbolSubclass: u64 = 1u64 << 45;
pub const TrueClass: u64 = TrueClassExact | TrueClassSubclass;
pub const TrueClassExact: u64 = 1u64 << 44;
pub const TrueClassSubclass: u64 = 1u64 << 45;
pub const Undef: u64 = 1u64 << 46;
pub const AllBitPatterns: [(&'static str, u64); 76] = [
pub const TrueClassExact: u64 = 1u64 << 46;
pub const TrueClassSubclass: u64 = 1u64 << 47;
pub const Undef: u64 = 1u64 << 48;
pub const AllBitPatterns: [(&'static str, u64); 79] = [
("Any", Any),
("RubyValue", RubyValue),
("Immediate", Immediate),
@ -110,6 +113,9 @@ mod bits {
("NilClass", NilClass),
("NilClassSubclass", NilClassSubclass),
("NilClassExact", NilClassExact),
("Module", Module),
("ModuleSubclass", ModuleSubclass),
("ModuleExact", ModuleExact),
("Integer", Integer),
("IntegerSubclass", IntegerSubclass),
("Float", Float),
@ -154,7 +160,7 @@ mod bits {
("ArrayExact", ArrayExact),
("Empty", Empty),
];
pub const NumTypeBits: u64 = 47;
pub const NumTypeBits: u64 = 49;
}
pub mod types {
use super::*;
@ -206,6 +212,9 @@ pub mod types {
pub const Integer: Type = Type::from_bits(bits::Integer);
pub const IntegerExact: Type = Type::from_bits(bits::IntegerExact);
pub const IntegerSubclass: Type = Type::from_bits(bits::IntegerSubclass);
pub const Module: Type = Type::from_bits(bits::Module);
pub const ModuleExact: Type = Type::from_bits(bits::ModuleExact);
pub const ModuleSubclass: Type = Type::from_bits(bits::ModuleSubclass);
pub const NilClass: Type = Type::from_bits(bits::NilClass);
pub const NilClassExact: Type = Type::from_bits(bits::NilClassExact);
pub const NilClassSubclass: Type = Type::from_bits(bits::NilClassSubclass);

View File

@ -1,10 +1,11 @@
#![allow(non_upper_case_globals)]
use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH, RUBY_T_CLASS};
use crate::cruby::{Qfalse, Qnil, Qtrue, VALUE, RUBY_T_ARRAY, RUBY_T_STRING, RUBY_T_HASH, RUBY_T_CLASS, RUBY_T_MODULE};
use crate::cruby::{rb_cInteger, rb_cFloat, rb_cArray, rb_cHash, rb_cString, rb_cSymbol, rb_cObject, rb_cTrueClass, rb_cFalseClass, rb_cNilClass, rb_cRange, rb_cSet, rb_cRegexp, rb_cClass, rb_cModule};
use crate::cruby::ClassRelationship;
use crate::cruby::get_class_name;
use crate::cruby::ruby_sym_to_rust_string;
use crate::cruby::rb_mRubyVMFrozenCore;
use crate::cruby::rb_obj_class;
use crate::hir::PtrPrintMap;
#[derive(Copy, Clone, Debug, PartialEq)]
@ -145,9 +146,16 @@ fn is_range_exact(val: VALUE) -> bool {
val.class_of() == unsafe { rb_cRange }
}
fn is_class_exact(val: VALUE) -> bool {
// Objects with RUBY_T_CLASS type and not instances of Module
val.builtin_type() == RUBY_T_CLASS && val.class_of() != unsafe { rb_cModule }
fn is_module_exact(val: VALUE) -> bool {
if val.builtin_type() != RUBY_T_MODULE {
return false;
}
// For Class and Module instances, `class_of` will return the singleton class of the object.
// Using `rb_obj_class` will give us the actual class of the module so we can check if the
// object is an instance of Module, or an instance of Module subclass.
let klass = unsafe { rb_obj_class(val) };
klass == unsafe { rb_cModule }
}
impl Type {
@ -202,7 +210,10 @@ impl Type {
else if is_string_exact(val) {
Type { bits: bits::StringExact, spec: Specialization::Object(val) }
}
else if is_class_exact(val) {
else if is_module_exact(val) {
Type { bits: bits::ModuleExact, spec: Specialization::Object(val) }
}
else if val.builtin_type() == RUBY_T_CLASS {
Type { bits: bits::ClassExact, spec: Specialization::Object(val) }
}
else if val.class_of() == unsafe { rb_cRegexp } {
@ -301,6 +312,7 @@ impl Type {
if class == unsafe { rb_cFloat } { return true; }
if class == unsafe { rb_cHash } { return true; }
if class == unsafe { rb_cInteger } { return true; }
if class == unsafe { rb_cModule } { return true; }
if class == unsafe { rb_cNilClass } { return true; }
if class == unsafe { rb_cObject } { return true; }
if class == unsafe { rb_cRange } { return true; }
@ -410,6 +422,7 @@ impl Type {
if self.is_subtype(types::FloatExact) { return Some(unsafe { rb_cFloat }); }
if self.is_subtype(types::HashExact) { return Some(unsafe { rb_cHash }); }
if self.is_subtype(types::IntegerExact) { return Some(unsafe { rb_cInteger }); }
if self.is_subtype(types::ModuleExact) { return Some(unsafe { rb_cModule }); }
if self.is_subtype(types::NilClassExact) { return Some(unsafe { rb_cNilClass }); }
if self.is_subtype(types::ObjectExact) { return Some(unsafe { rb_cObject }); }
if self.is_subtype(types::RangeExact) { return Some(unsafe { rb_cRange }); }