diff --git a/src/ir/emit_llvm.zig b/src/ir/emit_llvm.zig index fae297f..23f9375 100644 --- a/src/ir/emit_llvm.zig +++ b/src/ir/emit_llvm.zig @@ -25,6 +25,10 @@ const interp_mod = @import("interp.zig"); const Interpreter = interp_mod.Interpreter; const Value = interp_mod.Value; +fn isIdentByte(b: u8) bool { + return (b >= 'a' and b <= 'z') or (b >= 'A' and b <= 'Z') or (b >= '0' and b <= '9') or b == '_'; +} + // ── LLVMEmitter ───────────────────────────────────────────────────────── // Emits LLVM IR from an IR Module. This is the Phase 3 replacement for // the AST-based codegen. @@ -84,6 +88,11 @@ pub const LLVMEmitter = struct { // dispatch through it with their own LLVMBuildCall2 function type // (opaque pointers — the function value is just a `ptr`). objc_msg_send_value: ?c.LLVMValueRef, + // `(name, sig)` → `{cls_slot, mid_slot}` cache for `#jni_call` + // interning (step 1.17). Keys are heap-allocated `name\x00sig` + // strings owned by this map and freed in `deinit`; call sites with + // the same literal pair share one lazily-populated slot pair. 
+ jni_slots: std.StringHashMap(JniSlotPair), // Cached field name arrays for reflection (TypeId → LLVM global) field_name_arrays: std.AutoHashMap(u32, c.LLVMValueRef), @@ -100,6 +109,11 @@ pub const LLVMEmitter = struct { param_index: u32, }; + const JniSlotPair = struct { + cls_slot: c.LLVMValueRef, // @SX_JNI_CLS_<mangled>: ptr global (jclass GlobalRef) + mid_slot: c.LLVMValueRef, // @SX_JNI_MID_<mangled>: ptr global (jmethodID) + }; + pub fn init(alloc: Allocator, ir_mod: *const Module, module_name: [*:0]const u8, target_config: TargetConfig) LLVMEmitter { // Initialize LLVM targets if (target_config.triple == null) { @@ -165,6 +179,7 @@ pub const LLVMEmitter = struct { .any_struct_type = null, .closure_struct_type = null, .objc_msg_send_value = null, + .jni_slots = std.StringHashMap(JniSlotPair).init(alloc), .field_name_arrays = std.AutoHashMap(u32, c.LLVMValueRef).init(alloc), .target_config = target_config, .build_config = .{}, @@ -176,6 +191,9 @@ pub const LLVMEmitter = struct { self.ref_map.deinit(); self.func_map.deinit(); self.field_name_arrays.deinit(); + var jni_it = self.jni_slots.keyIterator(); + while (jni_it.next()) |k| self.alloc.free(k.*); + self.jni_slots.deinit(); self.global_map.deinit(); self.block_map.deinit(); if (self.target_machine) |tm| c.LLVMDisposeTargetMachine(tm); @@ -322,6 +340,50 @@ pub const LLVMEmitter = struct { } } + /// Return the `{cls_slot, mid_slot}` global pair for the + /// `(name, sig)` literal — created on first lookup, shared across + /// later `#jni_call` sites with the same literal pair. Both slots + /// are internal-linkage `ptr` globals initialized to null; the + /// call-site lowering fills them on the first missed lookup. + fn getOrCreateJniSlots(self: *LLVMEmitter, name: []const u8, sig: []const u8) JniSlotPair { + // Key is `name` + NUL + `sig`; NUL can't appear in a JNI method + // name or signature, so distinct pairs never share a key. The + // buffer is freed on a hit and handed to the map on a miss. 
+ const key = std.fmt.allocPrint(self.alloc, "{s}\x00{s}", .{ name, sig }) catch unreachable; + if (self.jni_slots.get(key)) |existing| { + self.alloc.free(key); + return existing; + } + const mangled = self.mangleJniKey(name, sig); + defer self.alloc.free(mangled); + const cls_name = std.fmt.allocPrintSentinel(self.alloc, "SX_JNI_CLS_{s}", .{mangled}, 0) catch unreachable; + defer self.alloc.free(cls_name); + const mid_name = std.fmt.allocPrintSentinel(self.alloc, "SX_JNI_MID_{s}", .{mangled}, 0) catch unreachable; + defer self.alloc.free(mid_name); + const cls_slot = c.LLVMAddGlobal(self.llvm_module, self.cached_ptr, cls_name.ptr); + c.LLVMSetLinkage(cls_slot, c.LLVMInternalLinkage); + c.LLVMSetInitializer(cls_slot, c.LLVMConstNull(self.cached_ptr)); + const mid_slot = c.LLVMAddGlobal(self.llvm_module, self.cached_ptr, mid_name.ptr); + c.LLVMSetLinkage(mid_slot, c.LLVMInternalLinkage); + c.LLVMSetInitializer(mid_slot, c.LLVMConstNull(self.cached_ptr)); + const pair = JniSlotPair{ .cls_slot = cls_slot, .mid_slot = mid_slot }; + self.jni_slots.put(key, pair) catch unreachable; + return pair; + } + + /// Build an LLVM-friendly identifier suffix from a JNI + /// `(method_name, signature)` pair: `<name>__<sig>` with each + /// non-identifier byte rewritten to `_`. Lossy — distinct pairs can + /// mangle alike (e.g. `/` vs `_`); NOTE(review): confirm LLVM + /// uniquifies clashing global names. Caller frees the result. + fn mangleJniKey(self: *LLVMEmitter, name: []const u8, sig: []const u8) []u8 { + var buf = std.ArrayList(u8).empty; + for (name) |b| buf.append(self.alloc, if (isIdentByte(b)) b else '_') catch unreachable; + buf.appendSlice(self.alloc, "__") catch unreachable; + for (sig) |b| buf.append(self.alloc, if (isIdentByte(b)) b else '_') catch unreachable; + return buf.toOwnedSlice(self.alloc) catch unreachable; + } + /// If `val` is a `{ptr, i64}` slice struct, extract field 0 /// (the ptr); otherwise return it unchanged. 
Used by JNI dispatch /// to feed string-literal method names + signatures to @@ -1174,19 +1236,66 @@ pub const LLVMEmitter = struct { const ifs = c.LLVMBuildLoad2(self.builder, self.cached_ptr, env, "jni.ifs"); - // GetObjectClass: (JNIEnv*, jobject) -> jclass - const get_obj_cls = self.loadJniFn(ifs, 31, "jni.GetObjectClass"); - var gocls_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr }; - const gocls_ty = c.LLVMFunctionType(self.cached_ptr, &gocls_params, 2, 0); - var gocls_args = [_]c.LLVMValueRef{ env, target }; - const cls = c.LLVMBuildCall2(self.builder, gocls_ty, get_obj_cls, &gocls_args, 2, "jni.cls"); + // Method-ID resolution. When `name` and `sig` are both + // string literals the call site participates in + // `(name, sig)` slot interning (step 1.17): a shared + // pair of static globals holds the `jclass` GlobalRef + // and the `jmethodID`, filled on the first miss. Other + // sites keep the per-call `GetObjectClass + GetMethodID` + // sequence (1.15 shape). NOTE(review): the key ignores + // the target's class; confirm one mid fits all targets. 
+ const mid = if (msg.cache_key) |ck| blk: { + const pair = self.getOrCreateJniSlots(ck.name_str, ck.sig_str); + const cached_mid = c.LLVMBuildLoad2(self.builder, self.cached_ptr, pair.mid_slot, "jni.cached.mid"); + const is_cached = c.LLVMBuildICmp(self.builder, c.LLVMIntNE, cached_mid, c.LLVMConstNull(self.cached_ptr), "jni.is.cached"); - // GetMethodID: (JNIEnv*, jclass, const char*, const char*) -> jmethodID - const get_mid = self.loadJniFn(ifs, 33, "jni.GetMethodID"); - var gmid_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr, self.cached_ptr, self.cached_ptr }; - const gmid_ty = c.LLVMFunctionType(self.cached_ptr, &gmid_params, 4, 0); - var gmid_args = [_]c.LLVMValueRef{ env, cls, name_ptr, sig_ptr }; - const mid = c.LLVMBuildCall2(self.builder, gmid_ty, get_mid, &gmid_args, 4, "jni.mid"); + const cur_fn = c.LLVMGetBasicBlockParent(c.LLVMGetInsertBlock(self.builder)); + const miss_bb = c.LLVMAppendBasicBlockInContext(self.context, cur_fn, "jni.miss"); + const cont_bb = c.LLVMAppendBasicBlockInContext(self.context, cur_fn, "jni.cont"); + const before_bb = c.LLVMGetInsertBlock(self.builder); + _ = c.LLVMBuildCondBr(self.builder, is_cached, cont_bb, miss_bb); + + // Miss path: GetObjectClass → NewGlobalRef → GetMethodID, then store both. NOTE(review): results are stored unchecked; a null mid re-runs this path next call and overwrites the earlier GlobalRef without releasing it. 
+ c.LLVMPositionBuilderAtEnd(self.builder, miss_bb); + const get_obj_cls = self.loadJniFn(ifs, 31, "jni.GetObjectClass"); + var gocls_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr }; + const gocls_ty = c.LLVMFunctionType(self.cached_ptr, &gocls_params, 2, 0); + var gocls_args = [_]c.LLVMValueRef{ env, target }; + const local_cls = c.LLVMBuildCall2(self.builder, gocls_ty, get_obj_cls, &gocls_args, 2, "jni.cls"); + const new_global_ref = self.loadJniFn(ifs, 21, "jni.NewGlobalRef"); + var ngref_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr }; + const ngref_ty = c.LLVMFunctionType(self.cached_ptr, &ngref_params, 2, 0); + var ngref_args = [_]c.LLVMValueRef{ env, local_cls }; + const global_cls = c.LLVMBuildCall2(self.builder, ngref_ty, new_global_ref, &ngref_args, 2, "jni.global.cls"); + _ = c.LLVMBuildStore(self.builder, global_cls, pair.cls_slot); + const get_mid = self.loadJniFn(ifs, 33, "jni.GetMethodID"); + var gmid_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr, self.cached_ptr, self.cached_ptr }; + const gmid_ty = c.LLVMFunctionType(self.cached_ptr, &gmid_params, 4, 0); + var gmid_args = [_]c.LLVMValueRef{ env, global_cls, name_ptr, sig_ptr }; + const fresh_mid = c.LLVMBuildCall2(self.builder, gmid_ty, get_mid, &gmid_args, 4, "jni.fresh.mid"); + _ = c.LLVMBuildStore(self.builder, fresh_mid, pair.mid_slot); + const miss_end_bb = c.LLVMGetInsertBlock(self.builder); + _ = c.LLVMBuildBr(self.builder, cont_bb); + + // Cont: phi merges the cached mid (hit edge) with the fresh mid (miss edge). 
+ c.LLVMPositionBuilderAtEnd(self.builder, cont_bb); + const phi = c.LLVMBuildPhi(self.builder, self.cached_ptr, "jni.mid"); + var phi_vals = [_]c.LLVMValueRef{ cached_mid, fresh_mid }; + var phi_blocks = [_]c.LLVMBasicBlockRef{ before_bb, miss_end_bb }; + c.LLVMAddIncoming(phi, &phi_vals, &phi_blocks, 2); + break :blk phi; + } else blk: { + const get_obj_cls = self.loadJniFn(ifs, 31, "jni.GetObjectClass"); + var gocls_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr }; + const gocls_ty = c.LLVMFunctionType(self.cached_ptr, &gocls_params, 2, 0); + var gocls_args = [_]c.LLVMValueRef{ env, target }; + const cls = c.LLVMBuildCall2(self.builder, gocls_ty, get_obj_cls, &gocls_args, 2, "jni.cls"); + const get_mid = self.loadJniFn(ifs, 33, "jni.GetMethodID"); + var gmid_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr, self.cached_ptr, self.cached_ptr }; + const gmid_ty = c.LLVMFunctionType(self.cached_ptr, &gmid_params, 4, 0); + var gmid_args = [_]c.LLVMValueRef{ env, cls, name_ptr, sig_ptr }; + break :blk c.LLVMBuildCall2(self.builder, gmid_ty, get_mid, &gmid_args, 4, "jni.mid"); + }; // CallMethod: (JNIEnv*, jobject, jmethodID, args...) -> RetTy const call_fn = self.loadJniFn(ifs, call_method_offset, "jni.callfn"); diff --git a/src/ir/inst.zig b/src/ir/inst.zig index e8fcb8a..bfb6fe1 100644 --- a/src/ir/inst.zig +++ b/src/ir/inst.zig @@ -315,8 +315,12 @@ pub const ObjcMsgSend = struct { /// JNI dispatch payload. `env` is `JNIEnv*` (typed as ptr); `target` /// is a `jobject` for instance calls and a `jclass` for static calls. /// `name` and `sig` are pointers to NUL-terminated bytes (typically -/// `[*]u8` from a string-literal `.ptr`). The dispatch sequence is -/// expanded in emit_llvm.zig — see `Inst.jni_msg_send`. +/// `[*]u8` from a string-literal `.ptr`). 
When the source-level +/// `name` and `sig` are string literals, `cache_key` carries their +/// content so emit_llvm.zig can intern a shared `jclass GlobalRef` + +/// `jmethodID` slot keyed on `(name, sig)`; otherwise the lookup +/// stays uncached. The dispatch sequence is expanded in +/// emit_llvm.zig — see `Inst.jni_msg_send`. pub const JniMsgSend = struct { env: Ref, target: Ref, @@ -324,6 +328,12 @@ pub const JniMsgSend = struct { sig: Ref, args: []const Ref, is_static: bool, + cache_key: ?CacheKey = null, +}; + +pub const CacheKey = struct { + name_str: []const u8, + sig_str: []const u8, }; pub const BuiltinCall = struct { diff --git a/src/ir/lower.zig b/src/ir/lower.zig index 567ec67..34e2921 100644 --- a/src/ir/lower.zig +++ b/src/ir/lower.zig @@ -3863,8 +3863,21 @@ pub const Lowering = struct { const ret_ty = self.resolveType(fic.return_type); const env_ref = self.lowerExpr(fic.args[0]); const target_ref = self.lowerExpr(fic.args[1]); - const name_ref = self.lowerExpr(fic.args[2]); - const sig_ref = self.lowerExpr(fic.args[3]); + const name_node = fic.args[2]; + const sig_node = fic.args[3]; + const name_ref = self.lowerExpr(name_node); + const sig_ref = self.lowerExpr(sig_node); + + // Capture the (name, sig) literal content when both args are + // string literals — emit_llvm interns the slot pair on it (step + // 1.17). NOTE(review): `.raw` is borrowed; confirm it is unescaped. 
+ const cache_key: ?inst_mod.CacheKey = if (name_node.data == .string_literal and sig_node.data == .string_literal) + inst_mod.CacheKey{ + .name_str = name_node.data.string_literal.raw, + .sig_str = sig_node.data.string_literal.raw, + } + else + null; var extra = std.ArrayList(Ref).empty; var ai: usize = 4; @@ -3880,6 +3893,7 @@ pub const Lowering = struct { .sig = sig_ref, .args = extra_owned, .is_static = fic.kind == .jni_static_call, + .cache_key = cache_key, } }, ret_ty); } diff --git a/tests/expected/ffi-jni-call-03-methodid-sharing.ir b/tests/expected/ffi-jni-call-03-methodid-sharing.ir index 272cbcb..9d815cf 100644 --- a/tests/expected/ffi-jni-call-03-methodid-sharing.ir +++ b/tests/expected/ffi-jni-call-03-methodid-sharing.ir @@ -3,6 +3,8 @@ @g_should_call = internal global i1 false @str = private unnamed_addr constant [5 x i8] c"noop\00", align 1 @str.1 = private unnamed_addr constant [4 x i8] c"()V\00", align 1 +@SX_JNI_CLS_noop____V = internal global ptr null +@SX_JNI_MID_noop____V = internal global ptr null @str.2 = private unnamed_addr constant [5 x i8] c"noop\00", align 1 @str.3 = private unnamed_addr constant [4 x i8] c"()V\00", align 1 @str.4 = private unnamed_addr constant [4 x i8] c"ok\0A\00", align 1 @@ -208,27 +210,55 @@ entry: %load = load ptr, ptr %alloca, align 8 %loadN = load ptr, ptr %allocaN, align 8 %jni.ifs = load ptr, ptr %load, align 8 + %jni.cached.mid = load ptr, ptr @SX_JNI_MID_noop____V, align 8 + %jni.is.cached = icmp ne ptr %jni.cached.mid, null + br i1 %jni.is.cached, label %jni.cont, label %jni.miss + +jni.miss: ; preds = %entry %2 = getelementptr inbounds ptr, ptr %jni.ifs, i32 31 %jni.GetObjectClass = load ptr, ptr %2, align 8 %jni.cls = call ptr %jni.GetObjectClass(ptr %load, ptr %loadN) - %3 = getelementptr inbounds ptr, ptr %jni.ifs, i32 33 - %jni.GetMethodID = load ptr, ptr %3, align 8 - %jni.mid = call ptr %jni.GetMethodID(ptr %load, ptr %jni.cls, ptr @str, ptr @str.1) - %4 = getelementptr inbounds ptr, ptr %jni.ifs, i32 
61 - %jni.callfn = load ptr, ptr %4, align 8 + %3 = getelementptr inbounds ptr, ptr %jni.ifs, i32 21 + %jni.NewGlobalRef = load ptr, ptr %3, align 8 + %jni.global.cls = call ptr %jni.NewGlobalRef(ptr %load, ptr %jni.cls) + store ptr %jni.global.cls, ptr @SX_JNI_CLS_noop____V, align 8 + %4 = getelementptr inbounds ptr, ptr %jni.ifs, i32 33 + %jni.GetMethodID = load ptr, ptr %4, align 8 + %jni.fresh.mid = call ptr %jni.GetMethodID(ptr %load, ptr %jni.global.cls, ptr @str, ptr @str.1) + store ptr %jni.fresh.mid, ptr @SX_JNI_MID_noop____V, align 8 + br label %jni.cont + +jni.cont: ; preds = %jni.miss, %entry + %jni.mid = phi ptr [ %jni.cached.mid, %entry ], [ %jni.fresh.mid, %jni.miss ] + %5 = getelementptr inbounds ptr, ptr %jni.ifs, i32 61 + %jni.callfn = load ptr, ptr %5, align 8 call void %jni.callfn(ptr %load, ptr %loadN, ptr %jni.mid) %loadN = load ptr, ptr %alloca, align 8 %loadN = load ptr, ptr %allocaN, align 8 %jni.ifs5 = load ptr, ptr %loadN, align 8 - %5 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 31 - %jni.GetObjectClass6 = load ptr, ptr %5, align 8 - %jni.cls7 = call ptr %jni.GetObjectClass6(ptr %loadN, ptr %loadN) - %6 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 33 - %jni.GetMethodID8 = load ptr, ptr %6, align 8 - %jni.mid9 = call ptr %jni.GetMethodID8(ptr %loadN, ptr %jni.cls7, ptr @str.2, ptr @str.3) - %7 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 61 - %jni.callfn10 = load ptr, ptr %7, align 8 - call void %jni.callfn10(ptr %loadN, ptr %loadN, ptr %jni.mid9) + %jni.cached.mid6 = load ptr, ptr @SX_JNI_MID_noop____V, align 8 + %jni.is.cached7 = icmp ne ptr %jni.cached.mid6, null + br i1 %jni.is.cached7, label %jni.cont9, label %jni.miss8 + +jni.miss8: ; preds = %jni.cont + %6 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 31 + %jni.GetObjectClass10 = load ptr, ptr %6, align 8 + %jni.cls11 = call ptr %jni.GetObjectClass10(ptr %loadN, ptr %loadN) + %7 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 21 + %jni.NewGlobalRef12 = load ptr, ptr 
%7, align 8 + %jni.global.cls13 = call ptr %jni.NewGlobalRef12(ptr %loadN, ptr %jni.cls11) + store ptr %jni.global.cls13, ptr @SX_JNI_CLS_noop____V, align 8 + %8 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 33 + %jni.GetMethodID14 = load ptr, ptr %8, align 8 + %jni.fresh.mid15 = call ptr %jni.GetMethodID14(ptr %loadN, ptr %jni.global.cls13, ptr @str.2, ptr @str.3) + store ptr %jni.fresh.mid15, ptr @SX_JNI_MID_noop____V, align 8 + br label %jni.cont9 + +jni.cont9: ; preds = %jni.miss8, %jni.cont + %jni.mid16 = phi ptr [ %jni.cached.mid6, %jni.cont ], [ %jni.fresh.mid15, %jni.miss8 ] + %9 = getelementptr inbounds ptr, ptr %jni.ifs5, i32 61 + %jni.callfn17 = load ptr, ptr %9, align 8 + call void %jni.callfn17(ptr %loadN, ptr %loadN, ptr %jni.mid16) ret void }