diff --git a/src/ir/emit_llvm.zig b/src/ir/emit_llvm.zig index ad7cc10..fae297f 100644 --- a/src/ir/emit_llvm.zig +++ b/src/ir/emit_llvm.zig @@ -322,6 +322,29 @@ pub const LLVMEmitter = struct { } } + /// If `val` is a `{ptr, i64}` slice struct, extract field 0 + /// (the ptr); otherwise return it unchanged. Used by JNI dispatch + /// to feed string-literal method names + signatures to + /// `GetMethodID`, which expects raw C strings. + fn extractSlicePtr(self: *LLVMEmitter, val: c.LLVMValueRef) c.LLVMValueRef { + const val_ty = c.LLVMTypeOf(val); + if (c.LLVMGetTypeKind(val_ty) != c.LLVMStructTypeKind) return val; + if (c.LLVMCountStructElementTypes(val_ty) != 2) return val; + const f0 = c.LLVMStructGetTypeAtIndex(val_ty, 0); + if (c.LLVMGetTypeKind(f0) != c.LLVMPointerTypeKind) return val; + return c.LLVMBuildExtractValue(self.builder, val, 0, "jni.str.ptr"); + } + + /// Load a JNI vtable function pointer at the given offset. `ifs` + /// is the `JNINativeInterface*` loaded from `JNIEnv*`. Treats the + /// vtable as an array of opaque `ptr`s and indexes into it. + fn loadJniFn(self: *LLVMEmitter, ifs: c.LLVMValueRef, offset: u32, name: [*:0]const u8) c.LLVMValueRef { + const offset_val = c.LLVMConstInt(self.cached_i32, offset, 0); + var idx = [_]c.LLVMValueRef{offset_val}; + const slot = c.LLVMBuildInBoundsGEP2(self.builder, self.cached_ptr, ifs, &idx, 1, ""); + return c.LLVMBuildLoad2(self.builder, self.cached_ptr, slot, name); + } + /// Lazily look up / declare the shared `@objc_msgSend` function. /// Cached on the emitter; all `objc_msg_send` instructions hand /// LLVMBuildCall2 their own per-call-site function type — the @@ -1120,6 +1143,77 @@ pub const LLVMEmitter = struct { // ref lookup in this function. self.mapRef(result); }, + .jni_msg_send => |msg| { + // JNI vtable indirection: + // ifs = *env // JNINativeInterface* + // cls = ifs[31](env, target) // GetObjectClass + // mid = ifs[33](env, cls, name, sig) // GetMethodID + // ifs[61](env, target, mid, args...) // CallVoidMethod + // Static dispatch (1.23) and non-void returns (1.18+) widen + // the switch below. + if (msg.is_static) { + self.mapRef(c.LLVMGetUndef(self.toLLVMType(instruction.ty))); + return; + } + const ret_ty_id = instruction.ty; + const call_method_offset: u32 = switch (ret_ty_id) { + .void => 61, // CallVoidMethod + else => { + self.mapRef(c.LLVMGetUndef(self.toLLVMType(instruction.ty))); + return; + }, + }; + + const env = self.resolveRef(msg.env); + const target = self.resolveRef(msg.target); + // String literals lower as `{ptr, i64}` slices in sx IR; + // JNI's `GetMethodID` expects raw C strings, so extract + // field 0 when the source is a slice. + const name_ptr = self.extractSlicePtr(self.resolveRef(msg.name)); + const sig_ptr = self.extractSlicePtr(self.resolveRef(msg.sig)); + + const ifs = c.LLVMBuildLoad2(self.builder, self.cached_ptr, env, "jni.ifs"); + + // GetObjectClass: (JNIEnv*, jobject) -> jclass + const get_obj_cls = self.loadJniFn(ifs, 31, "jni.GetObjectClass"); + var gocls_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr }; + const gocls_ty = c.LLVMFunctionType(self.cached_ptr, &gocls_params, 2, 0); + var gocls_args = [_]c.LLVMValueRef{ env, target }; + const cls = c.LLVMBuildCall2(self.builder, gocls_ty, get_obj_cls, &gocls_args, 2, "jni.cls"); + + // GetMethodID: (JNIEnv*, jclass, const char*, const char*) -> jmethodID + const get_mid = self.loadJniFn(ifs, 33, "jni.GetMethodID"); + var gmid_params = [_]c.LLVMTypeRef{ self.cached_ptr, self.cached_ptr, self.cached_ptr, self.cached_ptr }; + const gmid_ty = c.LLVMFunctionType(self.cached_ptr, &gmid_params, 4, 0); + var gmid_args = [_]c.LLVMValueRef{ env, cls, name_ptr, sig_ptr }; + const mid = c.LLVMBuildCall2(self.builder, gmid_ty, get_mid, &gmid_args, 4, "jni.mid"); + + // CallMethod: (JNIEnv*, jobject, jmethodID, args...) -> RetTy + const call_fn = self.loadJniFn(ifs, call_method_offset, "jni.callfn"); + const raw_ret = self.toLLVMType(ret_ty_id); + const total_call_params: usize = 3 + msg.args.len; + const call_param_types = self.alloc.alloc(c.LLVMTypeRef, total_call_params) catch unreachable; + defer self.alloc.free(call_param_types); + const call_args = self.alloc.alloc(c.LLVMValueRef, total_call_params) catch unreachable; + defer self.alloc.free(call_args); + call_param_types[0] = self.cached_ptr; + call_param_types[1] = self.cached_ptr; + call_param_types[2] = self.cached_ptr; + call_args[0] = env; + call_args[1] = target; + call_args[2] = mid; + for (msg.args, 0..) |arg_ref, i| { + const raw_ty = self.getRefIRType(arg_ref) orelse .void; + const raw_llvm = self.toLLVMType(raw_ty); + const coerced_ty = self.abiCoerceParamType(raw_ty, raw_llvm); + call_param_types[i + 3] = coerced_ty; + call_args[i + 3] = self.coerceArg(self.resolveRef(arg_ref), coerced_ty); + } + const call_fn_ty = c.LLVMFunctionType(raw_ret, call_param_types.ptr, @intCast(total_call_params), 0); + const label: [*:0]const u8 = if (ret_ty_id == .void) "" else "jni.ret"; + const result = c.LLVMBuildCall2(self.builder, call_fn_ty, call_fn, call_args.ptr, @intCast(total_call_params), label); + self.mapRef(result); + }, .call => |call_op| { // Evaluate comptime functions at compile time const callee_func = &self.ir_mod.functions.items[call_op.callee.index()]; diff --git a/src/ir/inst.zig b/src/ir/inst.zig index aacbfa5..e8fcb8a 100644 --- a/src/ir/inst.zig +++ b/src/ir/inst.zig @@ -188,6 +188,14 @@ pub const Op = union(enum) { /// per signature shape. objc_msg_send: ObjcMsgSend, + /// `#jni_call(ReturnT)(env, target, name, sig, args...)` and + /// `#jni_static_call(ReturnT)(env, class, name, sig, args...)`. + /// emit_llvm.zig expands this into the JNI vtable indirection: + /// `(*env)->GetObjectClass` (instance only) → `GetMethodID` / + /// `GetStaticMethodID` → `CallMethod` / `CallStaticMethod`. + /// Method-ID caching across call sites is added in step 1.17. + jni_msg_send: JniMsgSend, + // ── Protocol dispatch ─────────────────────────────────────────── protocol_call_dynamic: ProtocolCall, // vtable/inline dispatch protocol_erase: ProtocolErase, // concrete → protocol value (xx) @@ -304,6 +312,20 @@ pub const ObjcMsgSend = struct { args: []const Ref, // additional args after recv + sel }; +/// JNI dispatch payload. `env` is `JNIEnv*` (typed as ptr); `target` +/// is a `jobject` for instance calls and a `jclass` for static calls. +/// `name` and `sig` are pointers to NUL-terminated bytes (typically +/// `[*]u8` from a string-literal `.ptr`). The dispatch sequence is +/// expanded in emit_llvm.zig — see `Inst.jni_msg_send`. +pub const JniMsgSend = struct { + env: Ref, + target: Ref, + name: Ref, + sig: Ref, + args: []const Ref, + is_static: bool, +}; + pub const BuiltinCall = struct { builtin: BuiltinId, args: []const Ref, diff --git a/src/ir/interp.zig b/src/ir/interp.zig index b326a9a..2e89eb5 100644 --- a/src/ir/interp.zig +++ b/src/ir/interp.zig @@ -534,6 +534,8 @@ pub const Interpreter = struct { // `#objc_call` reached during `#run` execution can't // resolve. Fail fast so callers see a useful diagnostic. .objc_msg_send => return error.CannotEvalComptime, + // Same story for JNI — no JVM at compile time. + .jni_msg_send => return error.CannotEvalComptime, // ── Block params ──────────────────────────────────── .block_param => { diff --git a/src/ir/lower.zig b/src/ir/lower.zig index 89efb5b..567ec67 100644 --- a/src/ir/lower.zig +++ b/src/ir/lower.zig @@ -3790,11 +3790,8 @@ pub const Lowering = struct { /// fully wired. Extra arities + non-void returns will land in /// subsequent phase-1 steps. fn lowerFfiIntrinsicCall(self: *Lowering, fic: *const ast.FfiIntrinsicCall) Ref { - if (fic.kind != .objc_call) { - if (self.diagnostics) |d| { - d.add(.err, "#jni_call / #jni_static_call lowering not implemented yet (Phase 1.15+)", null); - } - return Ref.none; + if (fic.kind == .jni_call or fic.kind == .jni_static_call) { + return self.lowerJniCall(fic); } if (fic.args.len < 2) { @@ -3855,6 +3852,37 @@ pub const Lowering = struct { } }, ret_ty); } + fn lowerJniCall(self: *Lowering, fic: *const ast.FfiIntrinsicCall) Ref { + if (fic.args.len < 4) { + if (self.diagnostics) |d| { + d.add(.err, "#jni_call requires env, target, method name, and signature", null); + } + return Ref.none; + } + + const ret_ty = self.resolveType(fic.return_type); + const env_ref = self.lowerExpr(fic.args[0]); + const target_ref = self.lowerExpr(fic.args[1]); + const name_ref = self.lowerExpr(fic.args[2]); + const sig_ref = self.lowerExpr(fic.args[3]); + + var extra = std.ArrayList(Ref).empty; + var ai: usize = 4; + while (ai < fic.args.len) : (ai += 1) { + extra.append(self.alloc, self.lowerExpr(fic.args[ai])) catch unreachable; + } + const extra_owned = extra.toOwnedSlice(self.alloc) catch unreachable; + + return self.builder.emit(.{ .jni_msg_send = .{ + .env = env_ref, + .target = target_ref, + .name = name_ref, + .sig = sig_ref, + .args = extra_owned, + .is_static = fic.kind == .jni_static_call, + } }, ret_ty); + } + // ── Calls ─────────────────────────────────────────────────────── fn lowerCall(self: *Lowering, c: *const ast.Call) Ref { diff --git a/src/ir/print.zig b/src/ir/print.zig index 3039f18..a63969c 100644 --- a/src/ir/print.zig +++ b/src/ir/print.zig @@ -321,6 +321,14 @@ fn printInst(instruction: *const Inst, ref_idx: u32, tt: *const TypeTable, write try writeArgs(c.args, writer); try writer.writeAll(") : "); }, + .jni_msg_send => |c| { + const kind: []const u8 = if (c.is_static) "static" else "instance"; + try writer.print("jni_msg_send {s} env=%{d} target=%{d} name=%{d} sig=%{d}(", .{ + kind, c.env.index(), c.target.index(), c.name.index(), c.sig.index(), + }); + try writeArgs(c.args, writer); + try writer.writeAll(") : "); + }, .compiler_call => |cc| { const name = tt.getString(@enumFromInt(cc.name)); try writer.print("compiler_call \"{s}\"(", .{name});