diff --git a/examples/158-pack-mono-dedup.sx b/examples/158-pack-mono-dedup.sx new file mode 100644 index 0000000..82cdf2e --- /dev/null +++ b/examples/158-pack-mono-dedup.sx @@ -0,0 +1,30 @@ +// Variadic heterogeneous type packs — step 2b: per-call-shape +// monomorphisation. Each unique call signature gets ONE mono fn; +// repeat calls with the same signature share it. The runtime output +// confirms correct semantics; the IR (visible via `sx ir`) shows +// the distinct mono symbols: +// +// call @count__pack(ctx) +// call @count__pack_s64(ctx, 1) +// call @count__pack_s64(ctx, 2) ← shares with the 1-arg s64 call +// call @count__pack_s64_s64_s64(ctx, 1, 2, 3) +// call @count__pack_string_bool(ctx, ..) +// +// Before step 2b, each call inlined a fresh copy of the body into +// main's basic block — no shared symbols, IR size grew linearly in +// call sites. After 2b, distinct shapes get distinct functions, +// repeats share, IR scales with unique shapes. + +#import "modules/std.sx"; + +count :: (..$args) -> s64 => args.len; + +main :: () -> s32 { + a := count(); + b := count(1); + c := count(2); + d := count(1, 2, 3); + e := count("x", true); + print("{} {} {} {} {}\n", a, b, c, d, e); + return 0; +} diff --git a/src/ir/lower.zig b/src/ir/lower.zig index 482ae8a..780ba4d 100644 --- a/src/ir/lower.zig +++ b/src/ir/lower.zig @@ -164,6 +164,12 @@ pub const Lowering = struct { /// the call arg's real type instead of `Any`. The `[]Any` slice path /// remains the runtime-indexed fallback for non-literal indices. pack_arg_nodes: ?std.StringHashMap([]const *const Node) = null, + /// Active pack-arity bindings during a pack-fn mono's body lowering. + /// Maps the pack-param name (e.g. `args`) to N. `lowerFieldAccess` + /// uses this to resolve `args.len` to a compile-time constant Ref + /// when no `args` slice is in scope (the mono path doesn't + /// materialise the slice). + pack_param_count: ?std.StringHashMap(u32) = null, struct_const_map: std.StringHashMap(StructConstInfo) = std.StringHashMap(StructConstInfo).init(std.heap.page_allocator), // "Struct.CONST" → value info module_const_map: std.StringHashMap(ModuleConstInfo) = std.StringHashMap(ModuleConstInfo).init(std.heap.page_allocator), // module-level value constants (e.g. AF_INET :s32: 2) foreign_name_map: std.StringHashMap([]const u8) = std.StringHashMap([]const u8).init(std.heap.page_allocator), // sx name → C name for #foreign renames @@ -3675,6 +3681,18 @@ pub const Lowering = struct { } fn lowerFieldAccess(self: *Lowering, fa: *const ast.FieldAccess, span: ast.Span) Ref { + // Pack-arity intercept: `.len` in a pack-fn mono's + // body resolves to the comptime-known N. The mono doesn't + // materialise the `[]Any` slice that the inline path used, so + // `args` isn't in scope as a value. + if (self.pack_param_count) |ppc| { + if (fa.object.data == .identifier and std.mem.eql(u8, fa.field, "len")) { + if (ppc.get(fa.object.data.identifier.name)) |n| { + return self.builder.constInt(@as(i64, @intCast(n)), .s64); + } + } + } + // Check for struct constant access: Struct.CONST if (fa.object.data == .identifier) { const qualified = std.fmt.allocPrint(self.alloc, "{s}.{s}", .{ fa.object.data.identifier.name, fa.field }) catch fa.field; @@ -5423,6 +5441,9 @@ pub const Lowering = struct { }; if (self.fn_ast_map.get(early_name)) |fd| { if (hasComptimeParams(fd)) { + if (isPackFn(fd)) { + return self.lowerPackFnCall(fd, c); + } return self.lowerComptimeCall(fd, c); } // Early detection of generic function calls — skip arg lowering for type params @@ -8045,6 +8066,217 @@ pub const Lowering = struct { return self.builder.constInt(0, .void); } + /// Per-call-shape monomorphisation entry for pack-fns + /// (`isPackFn(fd) == true`). Computes a mangled name from the + /// call-site arg types, builds the mono if it's not cached, and + /// emits a direct call. Pack params expand into N positional IR + /// params with concrete types; the body's `args[]` and + /// `args.len` resolve to those params via the pack bindings. + fn lowerPackFnCall(self: *Lowering, fd: *const ast.FnDecl, call_node: *const ast.Call) Ref { + var arg_types_list = std.ArrayList(TypeId).empty; + defer arg_types_list.deinit(self.alloc); + for (call_node.args) |a| { + arg_types_list.append(self.alloc, self.inferExprType(a)) catch return self.builder.constInt(0, .void); + } + const arg_types = arg_types_list.items; + + // Mangle: `__pack__`. Distinct call shapes + // get distinct symbols; the same shape called repeatedly + // shares one mono. + var name_buf = std.ArrayList(u8).empty; + defer name_buf.deinit(self.alloc); + name_buf.appendSlice(self.alloc, fd.name) catch return self.builder.constInt(0, .void); + name_buf.appendSlice(self.alloc, "__pack") catch return self.builder.constInt(0, .void); + for (arg_types) |t| { + name_buf.append(self.alloc, '_') catch return self.builder.constInt(0, .void); + name_buf.appendSlice(self.alloc, self.mangleTypeName(t)) catch return self.builder.constInt(0, .void); + } + const mangled = name_buf.items; + + if (!self.lowered_functions.contains(mangled)) { + self.monomorphizePackFn(fd, mangled, arg_types); + } + + // Lower args BEFORE re-fetching the func pointer — lowering + // call-site args can trigger more module functions to be + // appended, which reallocates `module.functions.items` and + // invalidates any `&self.module.functions.items[i]` pointer. + var args = std.ArrayList(Ref).empty; + defer args.deinit(self.alloc); + for (call_node.args) |a| { + args.append(self.alloc, self.lowerExpr(a)) catch return self.builder.constInt(0, .void); + } + + const fid = self.resolveFuncByName(mangled) orelse return self.builder.constInt(0, .void); + const func = &self.module.functions.items[@intFromEnum(fid)]; + const ret_ty = func.ret; + const params = func.params; + const final_args = self.prependCtxIfNeeded(func, args.items); + self.coerceCallArgs(final_args, params); + return self.builder.call(fid, final_args, ret_ty); + } + + /// Build a single mono fn for the given pack-fn + concrete arg types. + /// The mono carries N positional pack-params (synthesised names + /// `__pack__`) plus any fixed-prefix non-pack params from + /// the original declaration. The body lowers normally — real + /// `return X;` emits real `ret X`; `args[]` substitutes via + /// `pack_arg_nodes`; `args.len` resolves via `pack_param_count`. + fn monomorphizePackFn(self: *Lowering, fd: *const ast.FnDecl, mangled_name: []const u8, arg_types: []const TypeId) void { + const owned_name = self.alloc.dupe(u8, mangled_name) catch return; + self.lowered_functions.put(owned_name, {}) catch {}; + + // Find the pack param's name and position in fd.params. + var pack_name: []const u8 = ""; + var pack_param_idx: usize = std.math.maxInt(usize); + for (fd.params, 0..) |p, i| { + if (p.is_variadic and p.is_comptime) { + pack_name = p.name; + pack_param_idx = i; + break; + } + } + if (pack_param_idx == std.math.maxInt(usize)) return; + + // Save state — mirrors monomorphizeFunction but also captures + // pack/inline-return state since the mono body must NOT route + // returns through any caller's inline slot. + const saved_func = self.builder.func; + const saved_block = self.builder.current_block; + const saved_counter = self.builder.inst_counter; + const saved_scope = self.scope; + const saved_defer_base = self.func_defer_base; + const saved_block_terminated = self.block_terminated; + const saved_target = self.target_type; + const saved_pan = self.pack_arg_nodes; + const saved_ppc = self.pack_param_count; + const saved_iri = self.inline_return_target; + const saved_ctx_ref = self.current_ctx_ref; + self.func_defer_base = self.defer_stack.items.len; + self.block_terminated = false; + self.inline_return_target = null; + defer { + self.scope = saved_scope; + self.func_defer_base = saved_defer_base; + self.block_terminated = saved_block_terminated; + self.target_type = saved_target; + self.pack_arg_nodes = saved_pan; + self.pack_param_count = saved_ppc; + self.inline_return_target = saved_iri; + self.current_ctx_ref = saved_ctx_ref; + self.builder.func = saved_func; + self.builder.current_block = saved_block; + self.builder.inst_counter = saved_counter; + } + + const ret_ty = self.resolveReturnType(fd); + self.target_type = ret_ty; + + const wants_ctx = self.funcWantsImplicitCtx(fd); + + // Synthesise pack-param names + AST ident nodes used to bind + // `args[]` substitutions during body lowering. + var pack_synth_names = std.ArrayList([]const u8).empty; + defer pack_synth_names.deinit(self.alloc); + var pack_arg_idents = std.ArrayList(*const Node).empty; + defer pack_arg_idents.deinit(self.alloc); + for (arg_types, 0..) |_, i| { + const synth_name = std.fmt.allocPrint(self.alloc, "__pack_{s}_{d}", .{ pack_name, i }) catch return; + pack_synth_names.append(self.alloc, synth_name) catch return; + const ident_node = self.alloc.create(Node) catch return; + ident_node.* = .{ + .span = fd.body.span, + .data = .{ .identifier = .{ .name = synth_name } }, + }; + pack_arg_idents.append(self.alloc, ident_node) catch return; + } + + // Param list: ctx (if needed) + fixed prefix + N pack params. + // NOT deinit'd — `params.items` is stored by reference in + // `Function.init` and read back later via `func.params`. + // Freeing here would leave the function holding a freed slice. + // (Matches the leak convention in `monomorphizeFunction`.) + var params = std.ArrayList(Function.Param).empty; + if (wants_ctx) { + params.append(self.alloc, .{ + .name = self.module.types.internString("__sx_ctx"), + .ty = self.module.types.ptrTo(.void), + }) catch return; + } + for (fd.params, 0..) |p, i| { + if (i == pack_param_idx) continue; + const pty = self.resolveParamType(&p); + params.append(self.alloc, .{ + .name = self.module.types.internString(p.name), + .ty = pty, + }) catch return; + } + for (arg_types, 0..) |ty, i| { + params.append(self.alloc, .{ + .name = self.module.types.internString(pack_synth_names.items[i]), + .ty = ty, + }) catch return; + } + + const name_id = self.module.types.internString(owned_name); + _ = self.builder.beginFunction(name_id, params.items, ret_ty); + self.builder.currentFunc().has_implicit_ctx = wants_ctx; + + const entry_name = self.module.types.internString("entry"); + const entry = self.builder.appendBlock(entry_name, &.{}); + self.builder.switchToBlock(entry); + if (wants_ctx) self.current_ctx_ref = Ref.fromIndex(0); + + var scope = Scope.init(self.alloc, null); + defer scope.deinit(); + self.scope = &scope; + + var param_idx: u32 = if (wants_ctx) 1 else 0; + for (fd.params, 0..) |p, i| { + if (i == pack_param_idx) continue; + const pty = self.resolveParamType(&p); + const slot = self.builder.alloca(pty); + self.builder.store(slot, Ref.fromIndex(param_idx)); + scope.put(p.name, .{ .ref = slot, .ty = pty, .is_alloca = true }); + param_idx += 1; + } + for (arg_types, 0..) |ty, i| { + const synth_name = pack_synth_names.items[i]; + const slot = self.builder.alloca(ty); + self.builder.store(slot, Ref.fromIndex(param_idx)); + scope.put(synth_name, .{ .ref = slot, .ty = ty, .is_alloca = true }); + param_idx += 1; + } + + // Install pack bindings for the body lowering. + var pan_map = std.StringHashMap([]const *const Node).init(self.alloc); + defer pan_map.deinit(); + pan_map.put(pack_name, pack_arg_idents.items) catch return; + self.pack_arg_nodes = pan_map; + + var ppc_map = std.StringHashMap(u32).init(self.alloc); + defer ppc_map.deinit(); + ppc_map.put(pack_name, @intCast(arg_types.len)) catch return; + self.pack_param_count = ppc_map; + + if (ret_ty != .void) { + const body_val = self.lowerBlockValue(fd.body); + if (!self.currentBlockHasTerminator()) { + if (body_val) |val| { + const val_ty = self.builder.getRefType(val); + const coerced = if (val_ty != .void) self.coerceToType(val, val_ty, ret_ty) else val; + self.builder.ret(coerced, ret_ty); + } else { + self.ensureTerminator(ret_ty); + } + } + } else { + self.lowerBlock(fd.body); + self.ensureTerminator(ret_ty); + } + self.builder.finalize(); + } + fn monomorphizeFunction(self: *Lowering, fd: *const ast.FnDecl, mangled_name: []const u8, bindings: *std.StringHashMap(TypeId)) void { // Mark as lowered before lowering (prevents infinite recursion) // Need to dupe the name since mangled_name may be stack-allocated @@ -9036,6 +9268,24 @@ pub const Lowering = struct { return false; } + /// Pure pack-fn: the ONLY comptime param is a trailing heterogeneous + /// pack (`is_variadic AND is_comptime`). Detected at call sites that + /// today route to `lowerComptimeCall`; siphoned off to + /// `lowerPackFnCall` for per-call-shape monomorphisation. Mixed + /// `($fmt, ..$args)` stays on the inline path for now — different + /// substitution mechanism for the comptime non-pack param. + fn isPackFn(fd: *const ast.FnDecl) bool { + var seen_pack = false; + for (fd.params) |p| { + if (p.is_comptime and p.is_variadic) { + seen_pack = true; + } else if (p.is_comptime) { + return false; // mixed — defer to lowerComptimeCall + } + } + return seen_pack; + } + /// Creates a temporary function marked `is_comptime = true` that wraps /// the given expression as its return value. Returns the FuncId. pub fn createComptimeFunction(self: *Lowering, prefix: []const u8, expr: *const Node, ret_ty: TypeId) FuncId { @@ -11038,6 +11288,14 @@ pub const Lowering = struct { return .s64; }, .field_access => |fa| { + // Pack-arity intercept: `.len` is s64. Mirrors + // the lowerFieldAccess intercept so AST-level type + // inference picks the same shape. + if (self.pack_param_count) |ppc| { + if (fa.object.data == .identifier and std.mem.eql(u8, fa.field, "len")) { + if (ppc.contains(fa.object.data.identifier.name)) return .s64; + } + } // M1.3 — `obj.class` on an Obj-C-class pointer returns Class (*void). if (std.mem.eql(u8, fa.field, "class")) { if (self.isObjcClassPointer(self.inferExprType(fa.object))) { diff --git a/tests/expected/158-pack-mono-dedup.exit b/tests/expected/158-pack-mono-dedup.exit new file mode 100644 index 0000000..573541a --- /dev/null +++ b/tests/expected/158-pack-mono-dedup.exit @@ -0,0 +1 @@ +0 diff --git a/tests/expected/158-pack-mono-dedup.txt b/tests/expected/158-pack-mono-dedup.txt new file mode 100644 index 0000000..afac048 --- /dev/null +++ b/tests/expected/158-pack-mono-dedup.txt @@ -0,0 +1 @@ +0 1 1 3 2