diff --git a/src/ir/lower.zig b/src/ir/lower.zig index a01fe7c..d87c88a 100644 --- a/src/ir/lower.zig +++ b/src/ir/lower.zig @@ -47,6 +47,7 @@ const lower_call = @import("lower/call.zig"); const lower_pack = @import("lower/pack.zig"); const lower_generic = @import("lower/generic.zig"); const lower_expr = @import("lower/expr.zig"); +const lower_closure = @import("lower/closure.zig"); const TypeId = types.TypeId; const StringId = types.StringId; @@ -501,690 +502,12 @@ pub const Lowering = struct { // ── Public entry point ────────────────────────────────────────── - pub fn lowerLambda(self: *Lowering, lam: *const ast.Lambda) Ref { - // Lower the lambda body as a new anonymous function - var buf: [64]u8 = undefined; - const name = std.fmt.bufPrint(&buf, "__lambda_{d}", .{self.block_counter}) catch "__lambda"; - self.block_counter += 1; - - // Collect lambda param names for exclusion from captures - var param_names = std.StringHashMap(void).init(self.alloc); - defer param_names.deinit(); - for (lam.params) |p| { - param_names.put(p.name, {}) catch {}; - } - - // Pre-scan lambda body AST for free variables (captures) - var captures = std.ArrayList(CaptureInfo).empty; - defer captures.deinit(self.alloc); - self.collectCaptures(lam.body, ¶m_names, &captures); - - // Deduplicate captures - var seen = std.StringHashMap(void).init(self.alloc); - defer seen.deinit(); - var deduped = std.ArrayList(CaptureInfo).empty; - defer deduped.deinit(self.alloc); - for (captures.items) |cap| { - if (!seen.contains(cap.name)) { - seen.put(cap.name, {}) catch {}; - deduped.append(self.alloc, cap) catch {}; - } - } - const capture_list = deduped.items; - - // Build env struct type if there are captures - var env_struct_ty: TypeId = .void; - if (capture_list.len > 0) { - const env_field_data = self.alloc.alloc(types.TypeInfo.StructInfo.Field, capture_list.len) catch unreachable; - for (capture_list, 0..) |cap, i| { - var nbuf: [32]u8 = undefined; - const fname = std.fmt.bufPrint(&nbuf, "cap_{d}", .{i}) catch "cap"; - env_field_data[i] = .{ - .name = self.module.types.internString(fname), - .ty = cap.ty, - }; - } - const env_name = std.fmt.bufPrint(&buf, "__env_{d}", .{self.block_counter}) catch "__env"; - const env_name_id = self.module.types.internString(env_name); - env_struct_ty = self.module.types.intern(.{ .@"struct" = .{ - .name = env_name_id, - .fields = env_field_data, - } }); - } - - // Save current builder state - const saved_func = self.builder.func; - const saved_block = self.builder.current_block; - const saved_counter = self.builder.inst_counter; - const saved_scope = self.scope; - - // Build param list. Convention when implicit_ctx is enabled: - // slot 0 = __sx_ctx: *void - // slot 1 = env: *void - // slot 2+ = user params - // Without implicit_ctx, env is slot 0 and user params follow. - var params = std.ArrayList(Function.Param).empty; - const env_ptr_ty = self.module.types.ptrTo(.void); - const lambda_wants_ctx = self.implicit_ctx_enabled and lam.call_conv != .c; - if (lambda_wants_ctx) { - params.append(self.alloc, .{ - .name = self.module.types.internString("__sx_ctx"), - .ty = env_ptr_ty, - }) catch unreachable; - } - params.append(self.alloc, .{ - .name = self.module.types.internString("env"), - .ty = env_ptr_ty, - }) catch unreachable; - // Get target closure param types for inference (from Closure(T1, T2) -> R annotations) - const target_closure_params: ?[]const TypeId = if (self.target_type) |tt| blk: { - if (!tt.isBuiltin()) { - const tti = self.module.types.get(tt); - if (tti == .closure) break :blk tti.closure.params; - // Unwrap ?Closure(...) → Closure(...) - if (tti == .optional) { - const inner = tti.optional.child; - if (!inner.isBuiltin()) { - const inner_info = self.module.types.get(inner); - if (inner_info == .closure) break :blk inner_info.closure.params; - } - } - } - break :blk null; - } else null; - // User params follow the ctx (optional) + env slots in `params`. - const user_param_base: usize = (if (lambda_wants_ctx) @as(usize, 1) else 0) + 1; - for (lam.params, 0..) |p, pi| { - const pty: TypeId = blk: { - // Unannotated lambda params take their type positionally from - // the target `Closure(T0, …)` signature. Resolve them here so - // `resolveParamType` (which would diagnose a missing annotation) - // is only called for params that carry one. - if (p.type_expr.data == .inferred_type) { - if (target_closure_params != null and pi < target_closure_params.?.len) { - break :blk target_closure_params.?[pi]; - } - if (self.diagnostics) |d| { - d.addFmt(.err, p.type_expr.span, "cannot infer type of lambda parameter '{s}'; annotate it or use the lambda where a closure type is expected", .{p.name}); - } - break :blk .unresolved; - } - break :blk self.resolveParamType(&p); - }; - params.append(self.alloc, .{ - .name = self.module.types.internString(p.name), - .ty = pty, - }) catch unreachable; - } - - const ret_ty = blk: { - if (lam.return_type) |rt| { - break :blk type_bridge.resolveAstType(rt, &self.module.types, &self.program_index.type_alias_map, &self.program_index.module_const_map); - } - // Use target closure return type if available — but only when it's - // a resolved type. An `.unresolved` ret comes from an unbound - // generic (`Closure(..) -> $R`); fall through to infer it from the - // body so the concrete return drives `$R` inference at the call site. - if (self.target_type) |tt| { - if (!tt.isBuiltin()) { - const tti = self.module.types.get(tt); - if (tti == .closure and tti.closure.ret != .unresolved) break :blk tti.closure.ret; - // Unwrap ?Closure(...) → Closure(...) - if (tti == .optional) { - const inner = tti.optional.child; - if (!inner.isBuiltin()) { - const inner_info = self.module.types.get(inner); - if (inner_info == .closure and inner_info.closure.ret != .unresolved) break :blk inner_info.closure.ret; - } - } - } - } - // Arrow lambda without explicit return type — infer from body expression - // Temporarily bind params in scope so inferExprType can resolve param types - var temp_scope = Scope.init(self.alloc, self.scope); - const saved = self.scope; - self.scope = &temp_scope; - for (lam.params, 0..) |p, i| { - const pty = params.items[user_param_base + i].ty; - temp_scope.put(p.name, .{ .ref = @enumFromInt(0), .ty = pty, .is_alloca = false }); - } - const inferred = self.inferExprType(lam.body); - self.scope = saved; - temp_scope.deinit(); - break :blk inferred; - }; - const name_id = self.module.types.internString(name); - const func_id = self.builder.beginFunction(name_id, params.items, ret_ty); - if (lam.call_conv == .c) { - self.module.getFunctionMut(func_id).call_conv = .c; - } - self.builder.currentFunc().has_implicit_ctx = lambda_wants_ctx; - - // Param-slot layout: ctx at 0 (if present), env at ctx_slots, - // user args at ctx_slots+1. - const lambda_ctx_slots: u32 = if (lambda_wants_ctx) 1 else 0; - const env_param_idx: u32 = lambda_ctx_slots; - const user_param_base_lam: u32 = lambda_ctx_slots + 1; - - // Save + rebind current_ctx_ref so the body's sx-to-sx calls - // forward the trampoline's own ctx (slot 0). - const saved_ctx_ref_lam = self.current_ctx_ref; - defer self.current_ctx_ref = saved_ctx_ref_lam; - if (lambda_wants_ctx) self.current_ctx_ref = Ref.fromIndex(0); - - // A lambda is its own function: its `return` must drain only ITS OWN - // `defer`s, not the enclosing function's. Open a fresh defer window - // (like `lowerFunction`/`monomorphizeFunction`) and restore on exit — - // otherwise lowering a closure literal inside a `defer` body re-enters - // the enclosing function's defer drain (infinite recursion — issue 0073). - const saved_func_defer_base = self.func_defer_base; - const saved_defer_len = self.defer_stack.items.len; - defer { - self.func_defer_base = saved_func_defer_base; - self.defer_stack.shrinkRetainingCapacity(saved_defer_len); - } - self.func_defer_base = saved_defer_len; - - // Create entry block - const entry_name = self.module.types.internString("entry"); - const entry = self.builder.appendBlock(entry_name, &.{}); - self.builder.switchToBlock(entry); - - // Create scope WITHOUT parent — captures are bound from env, not parent scope - var lambda_scope = Scope.init(self.alloc, null); - self.scope = &lambda_scope; - - // Bind captures from env struct (at env_param_idx) - if (capture_list.len > 0) { - const env_param_ref = Ref.fromIndex(env_param_idx); - // Alloca env struct locally so struct_gep can resolve the type - const env_local = self.builder.alloca(env_struct_ty); - // Compute env size - const env_byte_size_inner = self.computeEnvSize(capture_list); - const env_size_val = self.builder.constInt(@intCast(env_byte_size_inner), .s64); - // memcpy(local_alloca, env_param, size) - _ = self.callForeign("memcpy", &.{ env_local, env_param_ref, env_size_val }, self.module.types.ptrTo(.void)); - - for (capture_list, 0..) |cap, i| { - // GEP into env struct to get field pointer - const field_ptr = self.builder.structGepTyped(env_local, @intCast(i), self.module.types.ptrTo(cap.ty), env_struct_ty); - // Load the captured value into a local alloca - const loaded = self.builder.load(field_ptr, cap.ty); - const slot = self.builder.alloca(cap.ty); - self.builder.store(slot, loaded); - lambda_scope.put(cap.name, .{ .ref = slot, .ty = cap.ty, .is_alloca = true }); - } - } - - // Also need parent scope for function lookups (but not variable lookups) - // Set up fn_names from parent scope chain - { - var s: ?*Scope = saved_scope; - while (s) |scope| { - var it = scope.fn_names.iterator(); - while (it.next()) |e| { - if (!lambda_scope.fn_names.contains(e.key_ptr.*)) { - lambda_scope.fn_names.put(e.key_ptr.*, e.value_ptr.*) catch {}; - } - } - s = scope.parent; - } - } - - // Bind params (user args start at user_param_base_lam, shifted past ctx + env). - // Use the signature types computed above (`params`), which already - // applied contextual typing from the target closure to untyped params — - // `resolveParamType` alone would drop it and default each to s64. - for (lam.params, 0..) |p, i| { - const pty = params.items[user_param_base + i].ty; - const slot = self.builder.alloca(pty); - const param_ref = Ref.fromIndex(user_param_base_lam + @as(u32, @intCast(i))); - self.builder.store(slot, param_ref); - lambda_scope.put(p.name, .{ .ref = slot, .ty = pty, .is_alloca = true }); - } - - // Lower body — capture last expression as return value. The - // `in_lambda_body` flag scopes the lambda-specific `raise`-not-failable - // hint; save/restore so a lambda nested inside a regular function (or a - // lambda inside a lambda) restores the enclosing context. - const saved_in_lambda = self.in_lambda_body; - self.in_lambda_body = true; - if (ret_ty != .void) { - if (self.lowerBlockValue(lam.body)) |val| { - if (!self.currentBlockHasTerminator()) { - const val_ty = self.builder.getRefType(val); - // A value-carrying failable arrow lambda (`-> (T, !) => expr`) - // yields the bare success value; the compiler appends the - // no-error slot (0) — same as a `return v` in a block body. - if (!ret_ty.isBuiltin() and self.module.types.get(ret_ty) == .tuple and self.errorChannelOf(ret_ty) != null) { - self.lowerFailableSuccessReturn(val, ret_ty, lam.body.span); - } else { - const coerced = if (val_ty != .void) self.coerceToType(val, val_ty, ret_ty) else val; - self.builder.ret(coerced, ret_ty); - } - } - } - } else { - self.lowerBlock(lam.body); - } - self.in_lambda_body = saved_in_lambda; - self.ensureTerminator(ret_ty); - self.builder.finalize(); - - // Restore builder state - self.scope = saved_scope; - lambda_scope.deinit(); - self.builder.func = saved_func; - self.builder.current_block = saved_block; - self.builder.inst_counter = saved_counter; - // Restore the caller's `current_ctx_ref` BEFORE we emit the env - // alloc/memcpy below — those run in the caller's scope, and - // `allocViaContext` reads `current_ctx_ref` to find the - // installed allocator. Without this, the env_heap dispatch - // would still see `Ref.fromIndex(0)` (the lambda's own ctx - // param), which doesn't exist in the caller's frame and - // silently routes through the default context instead of any - // surrounding `push Context.{ allocator = ... }`. - self.current_ctx_ref = saved_ctx_ref_lam; - - // Closure flowing into a BARE function-pointer slot (`(T) -> U`, no env): - // the slot is called without the closure env arg, so the closure fn can't - // be passed directly. For a capture-free closure whose return type matches - // the slot, emit an adapter with the bare ABI. Reject the cases the bare - // ABI can't represent: a capturing closure (env has nowhere to live), and - // a failable closure into a non-failable slot (foreign code can't observe - // the error channel — ERR E5.1 FFI-boundary rule). - if (self.target_type) |tt| { - if (!tt.isBuiltin() and self.module.types.get(tt) == .function) { - const slot_ret = self.module.types.get(tt).function.ret; - const widen_ok = self.errorChannelOf(slot_ret) != null and self.errorChannelOf(ret_ty) == null and self.failableSuccessType(slot_ret) == ret_ty; - if (capture_list.len > 0) { - if (self.diagnostics) |d| d.addFmt(.err, lam.body.span, "a capturing closure cannot be passed as a bare function pointer; declare the parameter type as `Closure(...)` so its environment is carried", .{}); - } else if (ret_ty == slot_ret or widen_ok) { - // Matching ABI, or a non-failable closure widening into a - // failable slot (∅ ⊆ slot set) — the adapter wraps {value, 0}. - const adapter = self.createClosureToBareFnAdapter(func_id, self.module.types.get(tt).function, ret_ty, lam.body.span); - return self.builder.emit(.{ .func_ref = adapter }, tt); - } else if (self.errorChannelOf(ret_ty) != null and self.errorChannelOf(slot_ret) == null) { - if (self.diagnostics) |d| d.addFmt(.err, lam.body.span, "failable closure cannot be assigned to a non-failable function-type slot; foreign code can't observe the error channel — handle the error in a wrapper closure that absorbs it", .{}); - } else if (self.diagnostics) |d| { - d.addFmt(.err, lam.body.span, "closure return type does not match the function-type slot", .{}); - } - } - } - - // Create proper closure type (user-visible params only — skip ctx + env). - const skip_count: usize = if (lambda_wants_ctx) 2 else 1; - var param_types_list = std.ArrayList(TypeId).empty; - for (params.items[skip_count..]) |p| { - param_types_list.append(self.alloc, p.ty) catch unreachable; - } - const closure_ty = self.module.types.closureType(param_types_list.items, ret_ty); - - // Build env and closure in the caller's scope - if (capture_list.len > 0) { - // Alloca env struct on stack (so struct_gep can resolve the type) - const env_local = self.builder.alloca(env_struct_ty); - - // Store captured values into env struct fields - for (capture_list, 0..) |cap, i| { - const gep = self.builder.structGepTyped(env_local, @intCast(i), self.module.types.ptrTo(cap.ty), env_struct_ty); - const val = if (cap.is_alloca) - self.builder.load(cap.ref, cap.ty) - else - cap.ref; - self.builder.store(gep, val); - } - - // Copy env to heap (so it outlives the stack frame). - // Route through `context.allocator.alloc` rather than calling - // libc malloc directly so closures respect a surrounding - // `push Context.{ allocator = ... }` and a tracker / arena - // counts the env allocation alongside everything else. - const env_byte_size = self.computeEnvSize(capture_list); - const env_size = self.builder.constInt(@intCast(env_byte_size), .s64); - const ptr_void = self.module.types.ptrTo(.void); - const env_heap = self.allocViaContext(env_size, ptr_void); - // memcpy(heap, stack_alloca, size) - _ = self.callForeign("memcpy", &.{ env_heap, env_local, env_size }, ptr_void); - - return self.builder.closureCreate(func_id, env_heap, closure_ty); - } else { - return self.builder.closureCreate(func_id, Ref.none, closure_ty); - } - } - - /// Create a trampoline function that wraps a bare function for closure auto-promotion. - /// The trampoline has signature `(env: *void, args...) -> ret` and simply calls the - /// bare function with `(args...)`, ignoring the env parameter. - pub fn createBareFnTrampoline(self: *Lowering, bare_func_id: FuncId, closure_info: types.TypeInfo.ClosureInfo) FuncId { - // Build trampoline params: [__sx_ctx]? + env + closure params. - // When the program uses Context, every sx-side trampoline carries - // the implicit ctx at slot 0 and forwards it to the wrapped - // function (which is also sx-side and expects it at slot 0). - var params = std.ArrayList(inst_mod.Function.Param).empty; - defer params.deinit(self.alloc); - const void_ptr_ty = self.module.types.ptrTo(.void); - const wants_ctx = self.implicit_ctx_enabled; - if (wants_ctx) { - params.append(self.alloc, .{ .name = self.module.types.internString("__sx_ctx"), .ty = void_ptr_ty }) catch unreachable; - } - const env_name = self.module.types.internString("env"); - params.append(self.alloc, .{ .name = env_name, .ty = void_ptr_ty }) catch unreachable; - for (closure_info.params, 0..) |pty, i| { - var buf: [32]u8 = undefined; - const pname = std.fmt.bufPrint(&buf, "a{d}", .{i}) catch "arg"; - params.append(self.alloc, .{ .name = self.module.types.internString(pname), .ty = pty }) catch unreachable; - } - - // Generate unique trampoline name - const bare_func = self.module.functions.items[bare_func_id.index()]; - const bare_name = self.module.types.getString(bare_func.name); - var name_buf: [128]u8 = undefined; - const tramp_name = std.fmt.bufPrint(&name_buf, "__tramp_{s}", .{bare_name}) catch "__tramp"; - const tramp_name_id = self.module.types.internString(tramp_name); - - // Save builder state - const saved_func = self.builder.func; - const saved_block = self.builder.current_block; - const saved_counter = self.builder.inst_counter; - - // Create function - const owned_params = self.alloc.dupe(inst_mod.Function.Param, params.items) catch unreachable; - var func = inst_mod.Function.init(tramp_name_id, owned_params, closure_info.ret); - func.has_implicit_ctx = wants_ctx; - const func_id = self.module.addFunction(func); - self.builder.func = func_id; - self.builder.inst_counter = @intCast(owned_params.len); // params occupy refs 0..N-1 - const entry_name = self.module.types.internString("entry"); - const entry_block = self.builder.appendBlock(entry_name, &.{}); - self.builder.switchToBlock(entry_block); - - // Build call args: forward [__sx_ctx]? + user_params (skip env). - // Trampoline slots: 0=ctx (if present), {0|1}=env, then user args. - const ctx_slots: usize = if (wants_ctx) 1 else 0; - const user_arg_start: u32 = @intCast(ctx_slots + 1); // skip ctx + env - var call_args = std.ArrayList(Ref).empty; - defer call_args.deinit(self.alloc); - if (wants_ctx and bare_func.has_implicit_ctx) { - call_args.append(self.alloc, Ref.fromIndex(0)) catch unreachable; // forward our ctx - } - for (closure_info.params, 0..) |_, i| { - call_args.append(self.alloc, Ref.fromIndex(user_arg_start + @as(u32, @intCast(i)))) catch unreachable; - } - const owned_args = self.alloc.dupe(Ref, call_args.items) catch unreachable; - const result = self.builder.emit(.{ .call = .{ .callee = bare_func_id, .args = owned_args } }, closure_info.ret); - - // Return result (or void) - if (closure_info.ret != .void) { - self.builder.ret(result, closure_info.ret); - } else { - self.builder.retVoid(); - } - self.builder.finalize(); - - // Restore builder state - self.builder.func = saved_func; - self.builder.current_block = saved_block; - self.builder.inst_counter = saved_counter; - - return func_id; - } - - /// Adapter for coercing a closure into a BARE function-pointer slot - /// (`(T) -> U`, no env). The closure's underlying function has signature - /// `[ctx?] + env + user-params`, but a bare fn-ptr slot is *called* without - /// the env arg — so the closure fn can't be used directly (the env slot - /// would swallow the first user arg). This adapter carries the bare ABI - /// (`[ctx?] + user-params`) and forwards to the closure fn with a null env. - /// Only sound for capture-free closures (a null env is correct iff the body - /// reads no captures); the caller rejects capturing closures. - /// - /// When `closure_ret` differs from `fn_info.ret`, this is the ∅-widening - /// case (a non-failable closure into a failable slot): the closure returns - /// the success value and the adapter wraps it into the slot's `{value, 0}` - /// failable tuple (ERR E5.1 non-failable→failable widening). - fn createClosureToBareFnAdapter(self: *Lowering, closure_func_id: FuncId, fn_info: types.TypeInfo.FunctionInfo, closure_ret: TypeId, span: ast.Span) FuncId { - var params = std.ArrayList(inst_mod.Function.Param).empty; - defer params.deinit(self.alloc); - const void_ptr_ty = self.module.types.ptrTo(.void); - const wants_ctx = self.implicit_ctx_enabled; - if (wants_ctx) { - params.append(self.alloc, .{ .name = self.module.types.internString("__sx_ctx"), .ty = void_ptr_ty }) catch unreachable; - } - for (fn_info.params, 0..) |pty, i| { - var buf: [32]u8 = undefined; - const pname = std.fmt.bufPrint(&buf, "a{d}", .{i}) catch "arg"; - params.append(self.alloc, .{ .name = self.module.types.internString(pname), .ty = pty }) catch unreachable; - } - - const closure_func = self.module.functions.items[closure_func_id.index()]; - const closure_name = self.module.types.getString(closure_func.name); - var name_buf: [128]u8 = undefined; - const adapter_name = std.fmt.bufPrint(&name_buf, "__cl2fn_{s}", .{closure_name}) catch "__cl2fn"; - const adapter_name_id = self.module.types.internString(adapter_name); - - const saved_func = self.builder.func; - const saved_block = self.builder.current_block; - const saved_counter = self.builder.inst_counter; - - const owned_params = self.alloc.dupe(inst_mod.Function.Param, params.items) catch unreachable; - var func = inst_mod.Function.init(adapter_name_id, owned_params, fn_info.ret); - func.has_implicit_ctx = wants_ctx; - const func_id = self.module.addFunction(func); - self.builder.func = func_id; - self.builder.inst_counter = @intCast(owned_params.len); - const entry_name = self.module.types.internString("entry"); - const entry_block = self.builder.appendBlock(entry_name, &.{}); - self.builder.switchToBlock(entry_block); - - // Forward [ctx?] + null env + user params to the closure fn. - const ctx_slots: usize = if (wants_ctx) 1 else 0; - var call_args = std.ArrayList(Ref).empty; - defer call_args.deinit(self.alloc); - if (wants_ctx) call_args.append(self.alloc, Ref.fromIndex(0)) catch unreachable; - call_args.append(self.alloc, self.builder.constNull(void_ptr_ty)) catch unreachable; - for (fn_info.params, 0..) |_, i| { - call_args.append(self.alloc, Ref.fromIndex(@intCast(ctx_slots + i))) catch unreachable; - } - const owned_args = self.alloc.dupe(Ref, call_args.items) catch unreachable; - const result = self.builder.emit(.{ .call = .{ .callee = closure_func_id, .args = owned_args } }, closure_ret); - if (closure_ret == fn_info.ret) { - if (fn_info.ret != .void) { - self.builder.ret(result, fn_info.ret); - } else { - self.builder.retVoid(); - } - } else { - // ∅-widening: closure returns the success value; wrap `{value, 0}` - // into the slot's failable tuple. - self.lowerFailableSuccessReturn(result, fn_info.ret, span); - } - self.builder.finalize(); - - self.builder.func = saved_func; - self.builder.current_block = saved_block; - self.builder.inst_counter = saved_counter; - return func_id; - } - - /// Walk an AST node and collect free variable references (identifiers that are - /// in the current scope but not in lambda params). - fn collectCaptures(self: *Lowering, node: *const Node, param_names: *std.StringHashMap(void), captures: *std.ArrayList(CaptureInfo)) void { - switch (node.data) { - .identifier => |id| { - // Skip lambda params - if (param_names.contains(id.name)) return; - // Skip function names - if (self.program_index.fn_ast_map.contains(id.name)) return; - // Skip type names - if (self.program_index.struct_template_map.contains(id.name)) return; - // Check if it's a variable in the parent scope - if (self.scope) |scope| { - if (scope.lookup(id.name)) |binding| { - captures.append(self.alloc, .{ - .name = id.name, - .ty = binding.ty, - .ref = binding.ref, - .is_alloca = binding.is_alloca, - }) catch {}; - } - } - }, - .binary_op => |bo| { - self.collectCaptures(bo.lhs, param_names, captures); - self.collectCaptures(bo.rhs, param_names, captures); - }, - .unary_op => |uo| { - self.collectCaptures(uo.operand, param_names, captures); - }, - .call => |cl| { - self.collectCaptures(cl.callee, param_names, captures); - for (cl.args) |arg| { - self.collectCaptures(arg, param_names, captures); - } - }, - .block => |blk| { - for (blk.stmts) |stmt| { - self.collectCaptures(stmt, param_names, captures); - } - }, - .if_expr => |ie| { - self.collectCaptures(ie.condition, param_names, captures); - self.collectCaptures(ie.then_branch, param_names, captures); - if (ie.else_branch) |eb| self.collectCaptures(eb, param_names, captures); - }, - .while_expr => |we| { - self.collectCaptures(we.condition, param_names, captures); - self.collectCaptures(we.body, param_names, captures); - }, - .return_stmt => |rs| { - if (rs.value) |v| self.collectCaptures(v, param_names, captures); - }, - .var_decl => |vd| { - if (vd.value) |v| self.collectCaptures(v, param_names, captures); - // Register the local var name so it's not captured - param_names.put(vd.name, {}) catch {}; - }, - .const_decl => |cd| { - self.collectCaptures(cd.value, param_names, captures); - param_names.put(cd.name, {}) catch {}; - }, - .assignment => |a| { - self.collectCaptures(a.target, param_names, captures); - self.collectCaptures(a.value, param_names, captures); - }, - .destructure_decl => |dd| { - self.collectCaptures(dd.value, param_names, captures); - for (dd.names) |name| { - param_names.put(name, {}) catch {}; - } - }, - .field_access => |fa| { - self.collectCaptures(fa.object, param_names, captures); - }, - .index_expr => |ie| { - self.collectCaptures(ie.object, param_names, captures); - self.collectCaptures(ie.index, param_names, captures); - }, - .struct_literal => |sl| { - for (sl.field_inits) |fi| { - self.collectCaptures(fi.value, param_names, captures); - } - }, - .array_literal => |al| { - for (al.elements) |elem| { - self.collectCaptures(elem, param_names, captures); - } - }, - .lambda => |inner_lam| { - // For nested lambdas, the inner lambda captures from our scope too - // But its own params should be excluded - var inner_params = std.StringHashMap(void).init(self.alloc); - defer inner_params.deinit(); - // Copy current param_names - var it = param_names.iterator(); - while (it.next()) |e| { - inner_params.put(e.key_ptr.*, {}) catch {}; - } - for (inner_lam.params) |p| { - inner_params.put(p.name, {}) catch {}; - } - self.collectCaptures(inner_lam.body, &inner_params, captures); - }, - .match_expr => |me| { - self.collectCaptures(me.subject, param_names, captures); - for (me.arms) |arm| { - self.collectCaptures(arm.body, param_names, captures); - } - }, - .null_coalesce => |nc| { - self.collectCaptures(nc.lhs, param_names, captures); - self.collectCaptures(nc.rhs, param_names, captures); - }, - .deref_expr => |de| { - self.collectCaptures(de.operand, param_names, captures); - }, - .for_expr => |fe| { - self.collectCaptures(fe.iterable, param_names, captures); - // Register capture name as local so it's not captured - param_names.put(fe.capture_name, {}) catch {}; - self.collectCaptures(fe.body, param_names, captures); - }, - .slice_expr => |se| { - self.collectCaptures(se.object, param_names, captures); - if (se.start) |s| self.collectCaptures(s, param_names, captures); - if (se.end) |e| self.collectCaptures(e, param_names, captures); - }, - .tuple_literal => |tl| { - for (tl.elements) |elem| { - self.collectCaptures(elem.value, param_names, captures); - } - }, - .force_unwrap => |fu| { - self.collectCaptures(fu.operand, param_names, captures); - }, - .chained_comparison => |cc| { - for (cc.operands) |op| { - self.collectCaptures(op, param_names, captures); - } - }, - .defer_stmt => |ds| { - self.collectCaptures(ds.expr, param_names, captures); - }, - .ffi_intrinsic_call => |fic| { - self.collectCaptures(fic.return_type, param_names, captures); - for (fic.args) |arg| { - self.collectCaptures(arg, param_names, captures); - } - }, - else => {}, - } - } - - /// Compute the byte size of the env struct based on captured value types. - fn computeEnvSize(self: *Lowering, capture_list: []const CaptureInfo) usize { - // Must match LLVM's struct layout: fields are aligned to their natural alignment - var offset: usize = 0; - var max_align: usize = 1; - for (capture_list) |cap| { - const field_size = self.typeSizeBytes(cap.ty); - const field_align = self.typeAlignBytes(cap.ty); - if (field_align > max_align) max_align = field_align; - // Align offset to field alignment - offset = (offset + field_align - 1) & ~(field_align - 1); - offset += field_size; - } - // Align total to max field alignment (matches LLVM's struct alignment) - return (offset + max_align - 1) & ~(max_align - 1); - } - /// Byte size of an IR type matching LLVM's type layout. pub fn typeSizeBytes(self: *Lowering, ty: TypeId) usize { return self.module.types.typeSizeBytes(ty); } - fn typeAlignBytes(self: *Lowering, ty: TypeId) usize { + pub fn typeAlignBytes(self: *Lowering, ty: TypeId) usize { return self.module.types.typeAlignBytes(ty); } @@ -2209,7 +1532,7 @@ pub const Lowering = struct { pub const lookupGlobalIdByName = lower_objc_class.lookupGlobalIdByName; // --- moved to lower/call.zig (lower_call) --- - pub const CaptureInfo = lower_call.CaptureInfo; + pub const CaptureInfo = lower_closure.CaptureInfo; pub const lowerCall = lower_call.lowerCall; pub const diagnoseMissingContext = lower_call.diagnoseMissingContext; pub const allocViaContext = lower_call.allocViaContext; @@ -2332,4 +1655,11 @@ pub const Lowering = struct { pub const lowerTupleMembership = lower_expr.lowerTupleMembership; pub const lowerChainedComparison = lower_expr.lowerChainedComparison; pub const emitCmp = lower_expr.emitCmp; + + // --- moved to lower/closure.zig (lower_closure) --- + pub const lowerLambda = lower_closure.lowerLambda; + pub const createBareFnTrampoline = lower_closure.createBareFnTrampoline; + pub const createClosureToBareFnAdapter = lower_closure.createClosureToBareFnAdapter; + pub const collectCaptures = lower_closure.collectCaptures; + pub const computeEnvSize = lower_closure.computeEnvSize; }; diff --git a/src/ir/lower/call.zig b/src/ir/lower/call.zig index 2ed9ca6..052e6d5 100644 --- a/src/ir/lower/call.zig +++ b/src/ir/lower/call.zig @@ -43,7 +43,6 @@ const Function = inst_mod.Function; const Module = mod_mod.Module; const Builder = mod_mod.Builder; - const lower = @import("../lower.zig"); const Lowering = lower.Lowering; const Scope = lower.Scope; @@ -1206,13 +1205,6 @@ pub fn resolveBuiltin(name: []const u8) ?inst_mod.BuiltinId { // ── Lambda/closure ──────────────────────────────────────────── -pub const CaptureInfo = struct { - name: []const u8, - ty: TypeId, - ref: Ref, // alloca or value ref in the parent scope - is_alloca: bool, -}; - /// Build `tp.name -> TypeId` bindings for a generic call. /// `args_ast` must be parallel to `fd.params`; for dot-calls the caller /// prepends the receiver's AST node so positions align with `fd.params[0] = self`. diff --git a/src/ir/lower/closure.zig b/src/ir/lower/closure.zig new file mode 100644 index 0000000..8bc40ed --- /dev/null +++ b/src/ir/lower/closure.zig @@ -0,0 +1,733 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ast = @import("../../ast.zig"); +const Node = ast.Node; +const types = @import("../types.zig"); +const inst_mod = @import("../inst.zig"); +const mod_mod = @import("../module.zig"); +const type_bridge = @import("../type_bridge.zig"); +const unescape = @import("../../unescape.zig"); +const parser_mod = @import("../../parser.zig"); +const interp_mod = @import("../interp.zig"); +const errors = @import("../../errors.zig"); +const jni_descriptor = @import("../jni_descriptor.zig"); +const program_index_mod = @import("../program_index.zig"); +const resolver_mod = @import("../resolver.zig"); +const imports_mod = @import("../../imports.zig"); +const ProgramIndex = program_index_mod.ProgramIndex; +const GlobalInfo = program_index_mod.GlobalInfo; +const StructTemplate = program_index_mod.StructTemplate; +const TemplateParam = program_index_mod.TemplateParam; +const ProtocolDeclInfo = program_index_mod.ProtocolDeclInfo; +const ProtocolMethodInfo = program_index_mod.ProtocolMethodInfo; +const ModuleConstInfo = program_index_mod.ModuleConstInfo; +const TypeResolver = @import("../type_resolver.zig").TypeResolver; +const ResolveEnv = @import("../type_resolver.zig").ResolveEnv; +const PackResolver = @import("../packs.zig").PackResolver; +const ExprTyper = @import("../expr_typer.zig").ExprTyper; +const CallResolver = @import("../calls.zig").CallResolver; +const GenericResolver = @import("../generics.zig").GenericResolver; +const ProtocolResolver = @import("../protocols.zig").ProtocolResolver; +const CoercionResolver = @import("../conversions.zig").CoercionResolver; +const ErrorAnalysis = @import("../error_analysis.zig").ErrorAnalysis; +const ErrorFlow = @import("../error_flow.zig").ErrorFlow; +const ObjcLowering = @import("../ffi_objc.zig").ObjcLowering; +const semantic_diagnostics = @import("../semantic_diagnostics.zig"); + +const TypeId = types.TypeId; +const StringId = types.StringId; +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const Function = inst_mod.Function; +const Module = mod_mod.Module; +const Builder = mod_mod.Builder; + +const lower = @import("../lower.zig"); +const Lowering = lower.Lowering; +const Scope = lower.Scope; + +pub fn lowerLambda(self: *Lowering, lam: *const ast.Lambda) Ref { + // Lower the lambda body as a new anonymous function + var buf: [64]u8 = undefined; + const name = std.fmt.bufPrint(&buf, "__lambda_{d}", .{self.block_counter}) catch "__lambda"; + self.block_counter += 1; + + // Collect lambda param names for exclusion from captures + var param_names = std.StringHashMap(void).init(self.alloc); + defer param_names.deinit(); + for (lam.params) |p| { + param_names.put(p.name, {}) catch {}; + } + + // Pre-scan lambda body AST for free variables (captures) + var captures = std.ArrayList(CaptureInfo).empty; + defer captures.deinit(self.alloc); + self.collectCaptures(lam.body, ¶m_names, &captures); + + // Deduplicate captures + var seen = std.StringHashMap(void).init(self.alloc); + defer seen.deinit(); + var deduped = std.ArrayList(CaptureInfo).empty; + defer deduped.deinit(self.alloc); + for (captures.items) |cap| { + if (!seen.contains(cap.name)) { + seen.put(cap.name, {}) catch {}; + deduped.append(self.alloc, cap) catch {}; + } + } + const capture_list = deduped.items; + + // Build env struct type if there are captures + var env_struct_ty: TypeId = .void; + if (capture_list.len > 0) { + const env_field_data = self.alloc.alloc(types.TypeInfo.StructInfo.Field, capture_list.len) catch unreachable; + for (capture_list, 0..) |cap, i| { + var nbuf: [32]u8 = undefined; + const fname = std.fmt.bufPrint(&nbuf, "cap_{d}", .{i}) catch "cap"; + env_field_data[i] = .{ + .name = self.module.types.internString(fname), + .ty = cap.ty, + }; + } + const env_name = std.fmt.bufPrint(&buf, "__env_{d}", .{self.block_counter}) catch "__env"; + const env_name_id = self.module.types.internString(env_name); + env_struct_ty = self.module.types.intern(.{ .@"struct" = .{ + .name = env_name_id, + .fields = env_field_data, + } }); + } + + // Save current builder state + const saved_func = self.builder.func; + const saved_block = self.builder.current_block; + const saved_counter = self.builder.inst_counter; + const saved_scope = self.scope; + + // Build param list. Convention when implicit_ctx is enabled: + // slot 0 = __sx_ctx: *void + // slot 1 = env: *void + // slot 2+ = user params + // Without implicit_ctx, env is slot 0 and user params follow. + var params = std.ArrayList(Function.Param).empty; + const env_ptr_ty = self.module.types.ptrTo(.void); + const lambda_wants_ctx = self.implicit_ctx_enabled and lam.call_conv != .c; + if (lambda_wants_ctx) { + params.append(self.alloc, .{ + .name = self.module.types.internString("__sx_ctx"), + .ty = env_ptr_ty, + }) catch unreachable; + } + params.append(self.alloc, .{ + .name = self.module.types.internString("env"), + .ty = env_ptr_ty, + }) catch unreachable; + // Get target closure param types for inference (from Closure(T1, T2) -> R annotations) + const target_closure_params: ?[]const TypeId = if (self.target_type) |tt| blk: { + if (!tt.isBuiltin()) { + const tti = self.module.types.get(tt); + if (tti == .closure) break :blk tti.closure.params; + // Unwrap ?Closure(...) → Closure(...) + if (tti == .optional) { + const inner = tti.optional.child; + if (!inner.isBuiltin()) { + const inner_info = self.module.types.get(inner); + if (inner_info == .closure) break :blk inner_info.closure.params; + } + } + } + break :blk null; + } else null; + // User params follow the ctx (optional) + env slots in `params`. + const user_param_base: usize = (if (lambda_wants_ctx) @as(usize, 1) else 0) + 1; + for (lam.params, 0..) |p, pi| { + const pty: TypeId = blk: { + // Unannotated lambda params take their type positionally from + // the target `Closure(T0, …)` signature. Resolve them here so + // `resolveParamType` (which would diagnose a missing annotation) + // is only called for params that carry one. + if (p.type_expr.data == .inferred_type) { + if (target_closure_params != null and pi < target_closure_params.?.len) { + break :blk target_closure_params.?[pi]; + } + if (self.diagnostics) |d| { + d.addFmt(.err, p.type_expr.span, "cannot infer type of lambda parameter '{s}'; annotate it or use the lambda where a closure type is expected", .{p.name}); + } + break :blk .unresolved; + } + break :blk self.resolveParamType(&p); + }; + params.append(self.alloc, .{ + .name = self.module.types.internString(p.name), + .ty = pty, + }) catch unreachable; + } + + const ret_ty = blk: { + if (lam.return_type) |rt| { + break :blk type_bridge.resolveAstType(rt, &self.module.types, &self.program_index.type_alias_map, &self.program_index.module_const_map); + } + // Use target closure return type if available — but only when it's + // a resolved type. An `.unresolved` ret comes from an unbound + // generic (`Closure(..) -> $R`); fall through to infer it from the + // body so the concrete return drives `$R` inference at the call site. + if (self.target_type) |tt| { + if (!tt.isBuiltin()) { + const tti = self.module.types.get(tt); + if (tti == .closure and tti.closure.ret != .unresolved) break :blk tti.closure.ret; + // Unwrap ?Closure(...) → Closure(...) + if (tti == .optional) { + const inner = tti.optional.child; + if (!inner.isBuiltin()) { + const inner_info = self.module.types.get(inner); + if (inner_info == .closure and inner_info.closure.ret != .unresolved) break :blk inner_info.closure.ret; + } + } + } + } + // Arrow lambda without explicit return type — infer from body expression + // Temporarily bind params in scope so inferExprType can resolve param types + var temp_scope = Scope.init(self.alloc, self.scope); + const saved = self.scope; + self.scope = &temp_scope; + for (lam.params, 0..) |p, i| { + const pty = params.items[user_param_base + i].ty; + temp_scope.put(p.name, .{ .ref = @enumFromInt(0), .ty = pty, .is_alloca = false }); + } + const inferred = self.inferExprType(lam.body); + self.scope = saved; + temp_scope.deinit(); + break :blk inferred; + }; + const name_id = self.module.types.internString(name); + const func_id = self.builder.beginFunction(name_id, params.items, ret_ty); + if (lam.call_conv == .c) { + self.module.getFunctionMut(func_id).call_conv = .c; + } + self.builder.currentFunc().has_implicit_ctx = lambda_wants_ctx; + + // Param-slot layout: ctx at 0 (if present), env at ctx_slots, + // user args at ctx_slots+1. + const lambda_ctx_slots: u32 = if (lambda_wants_ctx) 1 else 0; + const env_param_idx: u32 = lambda_ctx_slots; + const user_param_base_lam: u32 = lambda_ctx_slots + 1; + + // Save + rebind current_ctx_ref so the body's sx-to-sx calls + // forward the trampoline's own ctx (slot 0). + const saved_ctx_ref_lam = self.current_ctx_ref; + defer self.current_ctx_ref = saved_ctx_ref_lam; + if (lambda_wants_ctx) self.current_ctx_ref = Ref.fromIndex(0); + + // A lambda is its own function: its `return` must drain only ITS OWN + // `defer`s, not the enclosing function's. Open a fresh defer window + // (like `lowerFunction`/`monomorphizeFunction`) and restore on exit — + // otherwise lowering a closure literal inside a `defer` body re-enters + // the enclosing function's defer drain (infinite recursion — issue 0073). + const saved_func_defer_base = self.func_defer_base; + const saved_defer_len = self.defer_stack.items.len; + defer { + self.func_defer_base = saved_func_defer_base; + self.defer_stack.shrinkRetainingCapacity(saved_defer_len); + } + self.func_defer_base = saved_defer_len; + + // Create entry block + const entry_name = self.module.types.internString("entry"); + const entry = self.builder.appendBlock(entry_name, &.{}); + self.builder.switchToBlock(entry); + + // Create scope WITHOUT parent — captures are bound from env, not parent scope + var lambda_scope = Scope.init(self.alloc, null); + self.scope = &lambda_scope; + + // Bind captures from env struct (at env_param_idx) + if (capture_list.len > 0) { + const env_param_ref = Ref.fromIndex(env_param_idx); + // Alloca env struct locally so struct_gep can resolve the type + const env_local = self.builder.alloca(env_struct_ty); + // Compute env size + const env_byte_size_inner = self.computeEnvSize(capture_list); + const env_size_val = self.builder.constInt(@intCast(env_byte_size_inner), .s64); + // memcpy(local_alloca, env_param, size) + _ = self.callForeign("memcpy", &.{ env_local, env_param_ref, env_size_val }, self.module.types.ptrTo(.void)); + + for (capture_list, 0..) |cap, i| { + // GEP into env struct to get field pointer + const field_ptr = self.builder.structGepTyped(env_local, @intCast(i), self.module.types.ptrTo(cap.ty), env_struct_ty); + // Load the captured value into a local alloca + const loaded = self.builder.load(field_ptr, cap.ty); + const slot = self.builder.alloca(cap.ty); + self.builder.store(slot, loaded); + lambda_scope.put(cap.name, .{ .ref = slot, .ty = cap.ty, .is_alloca = true }); + } + } + + // Also need parent scope for function lookups (but not variable lookups) + // Set up fn_names from parent scope chain + { + var s: ?*Scope = saved_scope; + while (s) |scope| { + var it = scope.fn_names.iterator(); + while (it.next()) |e| { + if (!lambda_scope.fn_names.contains(e.key_ptr.*)) { + lambda_scope.fn_names.put(e.key_ptr.*, e.value_ptr.*) catch {}; + } + } + s = scope.parent; + } + } + + // Bind params (user args start at user_param_base_lam, shifted past ctx + env). + // Use the signature types computed above (`params`), which already + // applied contextual typing from the target closure to untyped params — + // `resolveParamType` alone would drop it and default each to s64. + for (lam.params, 0..) |p, i| { + const pty = params.items[user_param_base + i].ty; + const slot = self.builder.alloca(pty); + const param_ref = Ref.fromIndex(user_param_base_lam + @as(u32, @intCast(i))); + self.builder.store(slot, param_ref); + lambda_scope.put(p.name, .{ .ref = slot, .ty = pty, .is_alloca = true }); + } + + // Lower body — capture last expression as return value. The + // `in_lambda_body` flag scopes the lambda-specific `raise`-not-failable + // hint; save/restore so a lambda nested inside a regular function (or a + // lambda inside a lambda) restores the enclosing context. + const saved_in_lambda = self.in_lambda_body; + self.in_lambda_body = true; + if (ret_ty != .void) { + if (self.lowerBlockValue(lam.body)) |val| { + if (!self.currentBlockHasTerminator()) { + const val_ty = self.builder.getRefType(val); + // A value-carrying failable arrow lambda (`-> (T, !) => expr`) + // yields the bare success value; the compiler appends the + // no-error slot (0) — same as a `return v` in a block body. + if (!ret_ty.isBuiltin() and self.module.types.get(ret_ty) == .tuple and self.errorChannelOf(ret_ty) != null) { + self.lowerFailableSuccessReturn(val, ret_ty, lam.body.span); + } else { + const coerced = if (val_ty != .void) self.coerceToType(val, val_ty, ret_ty) else val; + self.builder.ret(coerced, ret_ty); + } + } + } + } else { + self.lowerBlock(lam.body); + } + self.in_lambda_body = saved_in_lambda; + self.ensureTerminator(ret_ty); + self.builder.finalize(); + + // Restore builder state + self.scope = saved_scope; + lambda_scope.deinit(); + self.builder.func = saved_func; + self.builder.current_block = saved_block; + self.builder.inst_counter = saved_counter; + // Restore the caller's `current_ctx_ref` BEFORE we emit the env + // alloc/memcpy below — those run in the caller's scope, and + // `allocViaContext` reads `current_ctx_ref` to find the + // installed allocator. Without this, the env_heap dispatch + // would still see `Ref.fromIndex(0)` (the lambda's own ctx + // param), which doesn't exist in the caller's frame and + // silently routes through the default context instead of any + // surrounding `push Context.{ allocator = ... }`. + self.current_ctx_ref = saved_ctx_ref_lam; + + // Closure flowing into a BARE function-pointer slot (`(T) -> U`, no env): + // the slot is called without the closure env arg, so the closure fn can't + // be passed directly. For a capture-free closure whose return type matches + // the slot, emit an adapter with the bare ABI. Reject the cases the bare + // ABI can't represent: a capturing closure (env has nowhere to live), and + // a failable closure into a non-failable slot (foreign code can't observe + // the error channel — ERR E5.1 FFI-boundary rule). + if (self.target_type) |tt| { + if (!tt.isBuiltin() and self.module.types.get(tt) == .function) { + const slot_ret = self.module.types.get(tt).function.ret; + const widen_ok = self.errorChannelOf(slot_ret) != null and self.errorChannelOf(ret_ty) == null and self.failableSuccessType(slot_ret) == ret_ty; + if (capture_list.len > 0) { + if (self.diagnostics) |d| d.addFmt(.err, lam.body.span, "a capturing closure cannot be passed as a bare function pointer; declare the parameter type as `Closure(...)` so its environment is carried", .{}); + } else if (ret_ty == slot_ret or widen_ok) { + // Matching ABI, or a non-failable closure widening into a + // failable slot (∅ ⊆ slot set) — the adapter wraps {value, 0}. + const adapter = self.createClosureToBareFnAdapter(func_id, self.module.types.get(tt).function, ret_ty, lam.body.span); + return self.builder.emit(.{ .func_ref = adapter }, tt); + } else if (self.errorChannelOf(ret_ty) != null and self.errorChannelOf(slot_ret) == null) { + if (self.diagnostics) |d| d.addFmt(.err, lam.body.span, "failable closure cannot be assigned to a non-failable function-type slot; foreign code can't observe the error channel — handle the error in a wrapper closure that absorbs it", .{}); + } else if (self.diagnostics) |d| { + d.addFmt(.err, lam.body.span, "closure return type does not match the function-type slot", .{}); + } + } + } + + // Create proper closure type (user-visible params only — skip ctx + env). + const skip_count: usize = if (lambda_wants_ctx) 2 else 1; + var param_types_list = std.ArrayList(TypeId).empty; + for (params.items[skip_count..]) |p| { + param_types_list.append(self.alloc, p.ty) catch unreachable; + } + const closure_ty = self.module.types.closureType(param_types_list.items, ret_ty); + + // Build env and closure in the caller's scope + if (capture_list.len > 0) { + // Alloca env struct on stack (so struct_gep can resolve the type) + const env_local = self.builder.alloca(env_struct_ty); + + // Store captured values into env struct fields + for (capture_list, 0..) |cap, i| { + const gep = self.builder.structGepTyped(env_local, @intCast(i), self.module.types.ptrTo(cap.ty), env_struct_ty); + const val = if (cap.is_alloca) + self.builder.load(cap.ref, cap.ty) + else + cap.ref; + self.builder.store(gep, val); + } + + // Copy env to heap (so it outlives the stack frame). + // Route through `context.allocator.alloc` rather than calling + // libc malloc directly so closures respect a surrounding + // `push Context.{ allocator = ... }` and a tracker / arena + // counts the env allocation alongside everything else. + const env_byte_size = self.computeEnvSize(capture_list); + const env_size = self.builder.constInt(@intCast(env_byte_size), .s64); + const ptr_void = self.module.types.ptrTo(.void); + const env_heap = self.allocViaContext(env_size, ptr_void); + // memcpy(heap, stack_alloca, size) + _ = self.callForeign("memcpy", &.{ env_heap, env_local, env_size }, ptr_void); + + return self.builder.closureCreate(func_id, env_heap, closure_ty); + } else { + return self.builder.closureCreate(func_id, Ref.none, closure_ty); + } +} + +/// Create a trampoline function that wraps a bare function for closure auto-promotion. +/// The trampoline has signature `(env: *void, args...) -> ret` and simply calls the +/// bare function with `(args...)`, ignoring the env parameter. +pub fn createBareFnTrampoline(self: *Lowering, bare_func_id: FuncId, closure_info: types.TypeInfo.ClosureInfo) FuncId { + // Build trampoline params: [__sx_ctx]? + env + closure params. + // When the program uses Context, every sx-side trampoline carries + // the implicit ctx at slot 0 and forwards it to the wrapped + // function (which is also sx-side and expects it at slot 0). + var params = std.ArrayList(inst_mod.Function.Param).empty; + defer params.deinit(self.alloc); + const void_ptr_ty = self.module.types.ptrTo(.void); + const wants_ctx = self.implicit_ctx_enabled; + if (wants_ctx) { + params.append(self.alloc, .{ .name = self.module.types.internString("__sx_ctx"), .ty = void_ptr_ty }) catch unreachable; + } + const env_name = self.module.types.internString("env"); + params.append(self.alloc, .{ .name = env_name, .ty = void_ptr_ty }) catch unreachable; + for (closure_info.params, 0..) |pty, i| { + var buf: [32]u8 = undefined; + const pname = std.fmt.bufPrint(&buf, "a{d}", .{i}) catch "arg"; + params.append(self.alloc, .{ .name = self.module.types.internString(pname), .ty = pty }) catch unreachable; + } + + // Generate unique trampoline name + const bare_func = self.module.functions.items[bare_func_id.index()]; + const bare_name = self.module.types.getString(bare_func.name); + var name_buf: [128]u8 = undefined; + const tramp_name = std.fmt.bufPrint(&name_buf, "__tramp_{s}", .{bare_name}) catch "__tramp"; + const tramp_name_id = self.module.types.internString(tramp_name); + + // Save builder state + const saved_func = self.builder.func; + const saved_block = self.builder.current_block; + const saved_counter = self.builder.inst_counter; + + // Create function + const owned_params = self.alloc.dupe(inst_mod.Function.Param, params.items) catch unreachable; + var func = inst_mod.Function.init(tramp_name_id, owned_params, closure_info.ret); + func.has_implicit_ctx = wants_ctx; + const func_id = self.module.addFunction(func); + self.builder.func = func_id; + self.builder.inst_counter = @intCast(owned_params.len); // params occupy refs 0..N-1 + const entry_name = self.module.types.internString("entry"); + const entry_block = self.builder.appendBlock(entry_name, &.{}); + self.builder.switchToBlock(entry_block); + + // Build call args: forward [__sx_ctx]? + user_params (skip env). + // Trampoline slots: 0=ctx (if present), {0|1}=env, then user args. + const ctx_slots: usize = if (wants_ctx) 1 else 0; + const user_arg_start: u32 = @intCast(ctx_slots + 1); // skip ctx + env + var call_args = std.ArrayList(Ref).empty; + defer call_args.deinit(self.alloc); + if (wants_ctx and bare_func.has_implicit_ctx) { + call_args.append(self.alloc, Ref.fromIndex(0)) catch unreachable; // forward our ctx + } + for (closure_info.params, 0..) |_, i| { + call_args.append(self.alloc, Ref.fromIndex(user_arg_start + @as(u32, @intCast(i)))) catch unreachable; + } + const owned_args = self.alloc.dupe(Ref, call_args.items) catch unreachable; + const result = self.builder.emit(.{ .call = .{ .callee = bare_func_id, .args = owned_args } }, closure_info.ret); + + // Return result (or void) + if (closure_info.ret != .void) { + self.builder.ret(result, closure_info.ret); + } else { + self.builder.retVoid(); + } + self.builder.finalize(); + + // Restore builder state + self.builder.func = saved_func; + self.builder.current_block = saved_block; + self.builder.inst_counter = saved_counter; + + return func_id; +} + +/// Adapter for coercing a closure into a BARE function-pointer slot +/// (`(T) -> U`, no env). The closure's underlying function has signature +/// `[ctx?] + env + user-params`, but a bare fn-ptr slot is *called* without +/// the env arg — so the closure fn can't be used directly (the env slot +/// would swallow the first user arg). This adapter carries the bare ABI +/// (`[ctx?] + user-params`) and forwards to the closure fn with a null env. +/// Only sound for capture-free closures (a null env is correct iff the body +/// reads no captures); the caller rejects capturing closures. +/// +/// When `closure_ret` differs from `fn_info.ret`, this is the ∅-widening +/// case (a non-failable closure into a failable slot): the closure returns +/// the success value and the adapter wraps it into the slot's `{value, 0}` +/// failable tuple (ERR E5.1 non-failable→failable widening). +pub fn createClosureToBareFnAdapter(self: *Lowering, closure_func_id: FuncId, fn_info: types.TypeInfo.FunctionInfo, closure_ret: TypeId, span: ast.Span) FuncId { + var params = std.ArrayList(inst_mod.Function.Param).empty; + defer params.deinit(self.alloc); + const void_ptr_ty = self.module.types.ptrTo(.void); + const wants_ctx = self.implicit_ctx_enabled; + if (wants_ctx) { + params.append(self.alloc, .{ .name = self.module.types.internString("__sx_ctx"), .ty = void_ptr_ty }) catch unreachable; + } + for (fn_info.params, 0..) |pty, i| { + var buf: [32]u8 = undefined; + const pname = std.fmt.bufPrint(&buf, "a{d}", .{i}) catch "arg"; + params.append(self.alloc, .{ .name = self.module.types.internString(pname), .ty = pty }) catch unreachable; + } + + const closure_func = self.module.functions.items[closure_func_id.index()]; + const closure_name = self.module.types.getString(closure_func.name); + var name_buf: [128]u8 = undefined; + const adapter_name = std.fmt.bufPrint(&name_buf, "__cl2fn_{s}", .{closure_name}) catch "__cl2fn"; + const adapter_name_id = self.module.types.internString(adapter_name); + + const saved_func = self.builder.func; + const saved_block = self.builder.current_block; + const saved_counter = self.builder.inst_counter; + + const owned_params = self.alloc.dupe(inst_mod.Function.Param, params.items) catch unreachable; + var func = inst_mod.Function.init(adapter_name_id, owned_params, fn_info.ret); + func.has_implicit_ctx = wants_ctx; + const func_id = self.module.addFunction(func); + self.builder.func = func_id; + self.builder.inst_counter = @intCast(owned_params.len); + const entry_name = self.module.types.internString("entry"); + const entry_block = self.builder.appendBlock(entry_name, &.{}); + self.builder.switchToBlock(entry_block); + + // Forward [ctx?] + null env + user params to the closure fn. + const ctx_slots: usize = if (wants_ctx) 1 else 0; + var call_args = std.ArrayList(Ref).empty; + defer call_args.deinit(self.alloc); + if (wants_ctx) call_args.append(self.alloc, Ref.fromIndex(0)) catch unreachable; + call_args.append(self.alloc, self.builder.constNull(void_ptr_ty)) catch unreachable; + for (fn_info.params, 0..) |_, i| { + call_args.append(self.alloc, Ref.fromIndex(@intCast(ctx_slots + i))) catch unreachable; + } + const owned_args = self.alloc.dupe(Ref, call_args.items) catch unreachable; + const result = self.builder.emit(.{ .call = .{ .callee = closure_func_id, .args = owned_args } }, closure_ret); + if (closure_ret == fn_info.ret) { + if (fn_info.ret != .void) { + self.builder.ret(result, fn_info.ret); + } else { + self.builder.retVoid(); + } + } else { + // ∅-widening: closure returns the success value; wrap `{value, 0}` + // into the slot's failable tuple. + self.lowerFailableSuccessReturn(result, fn_info.ret, span); + } + self.builder.finalize(); + + self.builder.func = saved_func; + self.builder.current_block = saved_block; + self.builder.inst_counter = saved_counter; + return func_id; +} + +/// Walk an AST node and collect free variable references (identifiers that are +/// in the current scope but not in lambda params). +pub fn collectCaptures(self: *Lowering, node: *const Node, param_names: *std.StringHashMap(void), captures: *std.ArrayList(CaptureInfo)) void { + switch (node.data) { + .identifier => |id| { + // Skip lambda params + if (param_names.contains(id.name)) return; + // Skip function names + if (self.program_index.fn_ast_map.contains(id.name)) return; + // Skip type names + if (self.program_index.struct_template_map.contains(id.name)) return; + // Check if it's a variable in the parent scope + if (self.scope) |scope| { + if (scope.lookup(id.name)) |binding| { + captures.append(self.alloc, .{ + .name = id.name, + .ty = binding.ty, + .ref = binding.ref, + .is_alloca = binding.is_alloca, + }) catch {}; + } + } + }, + .binary_op => |bo| { + self.collectCaptures(bo.lhs, param_names, captures); + self.collectCaptures(bo.rhs, param_names, captures); + }, + .unary_op => |uo| { + self.collectCaptures(uo.operand, param_names, captures); + }, + .call => |cl| { + self.collectCaptures(cl.callee, param_names, captures); + for (cl.args) |arg| { + self.collectCaptures(arg, param_names, captures); + } + }, + .block => |blk| { + for (blk.stmts) |stmt| { + self.collectCaptures(stmt, param_names, captures); + } + }, + .if_expr => |ie| { + self.collectCaptures(ie.condition, param_names, captures); + self.collectCaptures(ie.then_branch, param_names, captures); + if (ie.else_branch) |eb| self.collectCaptures(eb, param_names, captures); + }, + .while_expr => |we| { + self.collectCaptures(we.condition, param_names, captures); + self.collectCaptures(we.body, param_names, captures); + }, + .return_stmt => |rs| { + if (rs.value) |v| self.collectCaptures(v, param_names, captures); + }, + .var_decl => |vd| { + if (vd.value) |v| self.collectCaptures(v, param_names, captures); + // Register the local var name so it's not captured + param_names.put(vd.name, {}) catch {}; + }, + .const_decl => |cd| { + self.collectCaptures(cd.value, param_names, captures); + param_names.put(cd.name, {}) catch {}; + }, + .assignment => |a| { + self.collectCaptures(a.target, param_names, captures); + self.collectCaptures(a.value, param_names, captures); + }, + .destructure_decl => |dd| { + self.collectCaptures(dd.value, param_names, captures); + for (dd.names) |name| { + param_names.put(name, {}) catch {}; + } + }, + .field_access => |fa| { + self.collectCaptures(fa.object, param_names, captures); + }, + .index_expr => |ie| { + self.collectCaptures(ie.object, param_names, captures); + self.collectCaptures(ie.index, param_names, captures); + }, + .struct_literal => |sl| { + for (sl.field_inits) |fi| { + self.collectCaptures(fi.value, param_names, captures); + } + }, + .array_literal => |al| { + for (al.elements) |elem| { + self.collectCaptures(elem, param_names, captures); + } + }, + .lambda => |inner_lam| { + // For nested lambdas, the inner lambda captures from our scope too + // But its own params should be excluded + var inner_params = std.StringHashMap(void).init(self.alloc); + defer inner_params.deinit(); + // Copy current param_names + var it = param_names.iterator(); + while (it.next()) |e| { + inner_params.put(e.key_ptr.*, {}) catch {}; + } + for (inner_lam.params) |p| { + inner_params.put(p.name, {}) catch {}; + } + self.collectCaptures(inner_lam.body, &inner_params, captures); + }, + .match_expr => |me| { + self.collectCaptures(me.subject, param_names, captures); + for (me.arms) |arm| { + self.collectCaptures(arm.body, param_names, captures); + } + }, + .null_coalesce => |nc| { + self.collectCaptures(nc.lhs, param_names, captures); + self.collectCaptures(nc.rhs, param_names, captures); + }, + .deref_expr => |de| { + self.collectCaptures(de.operand, param_names, captures); + }, + .for_expr => |fe| { + self.collectCaptures(fe.iterable, param_names, captures); + // Register capture name as local so it's not captured + param_names.put(fe.capture_name, {}) catch {}; + self.collectCaptures(fe.body, param_names, captures); + }, + .slice_expr => |se| { + self.collectCaptures(se.object, param_names, captures); + if (se.start) |s| self.collectCaptures(s, param_names, captures); + if (se.end) |e| self.collectCaptures(e, param_names, captures); + }, + .tuple_literal => |tl| { + for (tl.elements) |elem| { + self.collectCaptures(elem.value, param_names, captures); + } + }, + .force_unwrap => |fu| { + self.collectCaptures(fu.operand, param_names, captures); + }, + .chained_comparison => |cc| { + for (cc.operands) |op| { + self.collectCaptures(op, param_names, captures); + } + }, + .defer_stmt => |ds| { + self.collectCaptures(ds.expr, param_names, captures); + }, + .ffi_intrinsic_call => |fic| { + self.collectCaptures(fic.return_type, param_names, captures); + for (fic.args) |arg| { + self.collectCaptures(arg, param_names, captures); + } + }, + else => {}, + } +} + +/// Compute the byte size of the env struct based on captured value types. +pub fn computeEnvSize(self: *Lowering, capture_list: []const CaptureInfo) usize { + // Must match LLVM's struct layout: fields are aligned to their natural alignment + var offset: usize = 0; + var max_align: usize = 1; + for (capture_list) |cap| { + const field_size = self.typeSizeBytes(cap.ty); + const field_align = self.typeAlignBytes(cap.ty); + if (field_align > max_align) max_align = field_align; + // Align offset to field alignment + offset = (offset + field_align - 1) & ~(field_align - 1); + offset += field_size; + } + // Align total to max field alignment (matches LLVM's struct alignment) + return (offset + max_align - 1) & ~(max_align - 1); +} + +pub const CaptureInfo = struct { + name: []const u8, + ty: TypeId, + ref: Ref, // alloca or value ref in the parent scope + is_alloca: bool, +};