ffi M5.A.next.2b: per-call-shape monomorphisation for pack-fns

Pack-fns (`isPackFn(fd) == true` — last param `is_variadic AND
is_comptime`, no other comptime params) now emit ONE
monomorphised function per unique call-site signature. Repeat
calls with the same arg-type tuple share the mono; distinct
shapes get distinct symbols. Pre-2b each call inlined a fresh
body copy into the caller's basic block; IR size grew linearly
in call sites.

Plumbing in `src/ir/lower.zig`:

- `isPackFn(fd)` — true when the only comptime param is a
  trailing pack. Mixed `($fmt, ..$args)` shapes stay on the
  inline `lowerComptimeCall` path (different substitution
  mechanism for the comptime non-pack param; deferred).
- `lowerPackFnCall(fd, call_node)`:
  - Builds a mangled name `<fn_name>__pack__<arg_types>` from
    call-site `inferExprType` results. Distinct shapes get
    distinct symbols.
  - Cache-checks `lowered_functions`; calls
    `monomorphizePackFn` on miss.
  - Lowers call args, then re-fetches the func pointer (the
    fetch BEFORE arg lowering would invalidate after any
    transitively-triggered module.functions.items realloc),
    prepends ctx if needed, coerces, emits direct call.
- `monomorphizePackFn(fd, mangled, arg_types)`:
  - Mirrors `monomorphizeFunction` for the standard fn build:
    save state, build param list (ctx + fixed prefix + N pack
    params with synthesised names `__pack_<name>_<i>`),
    `beginFunction`, entry block, bind params to scope.
  - Installs `pack_arg_nodes[<name>]` with synthesised AST
    identifier nodes pointing at the pack-param slots so the
    body's `args[<int_literal>]` substitutes through the
    existing 2a.B mechanism — substitution resolves to the
    mono's own param slot loads.
  - Installs `pack_param_count[<name>] = N` so the body's
    `args.len` resolves to a compile-time constant via a new
    intercept in `lowerFieldAccess` (and the parallel arm in
    `inferExprType`).
  - Lowers the body with `inline_return_target = null` so
    `return X;` emits a real `ret X` instead of the inline-slot
    routing — the mono is a real fn now.
- Routed at three call sites: each `if (hasComptimeParams(fd))
  { return self.lowerComptimeCall(...); }` now first checks
  `isPackFn(fd)` and routes to `lowerPackFnCall` when true.

Lifetime gotcha caught and fixed: `params.items` is stored by
reference in `Function.init` (no copy), so the local
`ArrayList(Function.Param)` must NOT be deinit'd in
`monomorphizePackFn` — matches the leak convention already used
by `monomorphizeFunction`.

`examples/158-pack-mono-dedup.sx` confirms the dedup
end-to-end: `count(), count(1), count(2), count(1,2,3),
count("x", true)` produces `0 1 1 3 2` at runtime AND emits
exactly 4 monos in IR (`count__pack`, `count__pack_s64`,
`count__pack_s64_s64_s64`, `count__pack_string_bool`) — the
two s64 calls share. `args.len` resolves to the comptime
constant N inside each mono.

`examples/156-pack-typed-index.sx` and
`examples/157-pack-if-return.sx` continue to pass unchanged.

Out of scope:
- Mixed `$fmt + ..$args` shapes (stays on inline path).
- Generic `$R` return types (concrete returns only).
- Bare `args` reference (passing the slice as a whole).
- `args[<runtime_int>]` (non-literal index).

197/197 example tests + `zig build test` green.
This commit is contained in:
agra
2026-05-27 15:44:05 +03:00
parent 39a804f25e
commit 79896188eb
4 changed files with 290 additions and 0 deletions

View File

@@ -0,0 +1,30 @@
// Variadic heterogeneous type packs — step 2b: per-call-shape
// monomorphisation. Each unique call signature gets ONE mono fn;
// repeat calls with the same signature share it. The runtime output
// confirms correct semantics; the IR (visible via `sx ir`) shows
// the distinct mono symbols:
//
// call @count__pack(ctx)
// call @count__pack_s64(ctx, 1)
// call @count__pack_s64(ctx, 2) ← shares with the 1-arg s64 call
// call @count__pack_s64_s64_s64(ctx, 1, 2, 3)
// call @count__pack_string_bool(ctx, ..)
//
// Before step 2b, each call inlined a fresh copy of the body into
// main's basic block — no shared symbols, IR size grew linearly in
// call sites. After 2b, distinct shapes get distinct functions,
// repeats share, IR scales with unique shapes.
#import "modules/std.sx";
count :: (..$args) -> s64 => args.len;
main :: () -> s32 {
a := count();
b := count(1);
c := count(2);
d := count(1, 2, 3);
e := count("x", true);
print("{} {} {} {} {}\n", a, b, c, d, e);
return 0;
}

View File

@@ -164,6 +164,12 @@ pub const Lowering = struct {
/// the call arg's real type instead of `Any`. The `[]Any` slice path
/// remains the runtime-indexed fallback for non-literal indices.
pack_arg_nodes: ?std.StringHashMap([]const *const Node) = null,
/// Active pack-arity bindings during a pack-fn mono's body lowering.
/// Maps the pack-param name (e.g. `args`) to N. `lowerFieldAccess`
/// uses this to resolve `args.len` to a compile-time constant Ref
/// when no `args` slice is in scope (the mono path doesn't
/// materialise the slice).
pack_param_count: ?std.StringHashMap(u32) = null,
struct_const_map: std.StringHashMap(StructConstInfo) = std.StringHashMap(StructConstInfo).init(std.heap.page_allocator), // "Struct.CONST" → value info
module_const_map: std.StringHashMap(ModuleConstInfo) = std.StringHashMap(ModuleConstInfo).init(std.heap.page_allocator), // module-level value constants (e.g. AF_INET :s32: 2)
foreign_name_map: std.StringHashMap([]const u8) = std.StringHashMap([]const u8).init(std.heap.page_allocator), // sx name → C name for #foreign renames
@@ -3675,6 +3681,18 @@ pub const Lowering = struct {
}
fn lowerFieldAccess(self: *Lowering, fa: *const ast.FieldAccess, span: ast.Span) Ref {
// Pack-arity intercept: `<pack_name>.len` in a pack-fn mono's
// body resolves to the comptime-known N. The mono doesn't
// materialise the `[]Any` slice that the inline path used, so
// `args` isn't in scope as a value.
if (self.pack_param_count) |ppc| {
if (fa.object.data == .identifier and std.mem.eql(u8, fa.field, "len")) {
if (ppc.get(fa.object.data.identifier.name)) |n| {
return self.builder.constInt(@as(i64, @intCast(n)), .s64);
}
}
}
// Check for struct constant access: Struct.CONST
if (fa.object.data == .identifier) {
const qualified = std.fmt.allocPrint(self.alloc, "{s}.{s}", .{ fa.object.data.identifier.name, fa.field }) catch fa.field;
@@ -5423,6 +5441,9 @@ pub const Lowering = struct {
};
if (self.fn_ast_map.get(early_name)) |fd| {
if (hasComptimeParams(fd)) {
if (isPackFn(fd)) {
return self.lowerPackFnCall(fd, c);
}
return self.lowerComptimeCall(fd, c);
}
// Early detection of generic function calls — skip arg lowering for type params
@@ -8045,6 +8066,217 @@ pub const Lowering = struct {
return self.builder.constInt(0, .void);
}
/// Per-call-shape monomorphisation entry for pack-fns
/// (`isPackFn(fd) == true`). Computes a mangled name from the
/// call-site arg types, builds the mono if it's not cached, and
/// emits a direct call. Pack params expand into N positional IR
/// params with concrete types; the body's `args[<lit>]` and
/// `args.len` resolve to those params via the pack bindings.
fn lowerPackFnCall(self: *Lowering, fd: *const ast.FnDecl, call_node: *const ast.Call) Ref {
var arg_types_list = std.ArrayList(TypeId).empty;
defer arg_types_list.deinit(self.alloc);
for (call_node.args) |a| {
arg_types_list.append(self.alloc, self.inferExprType(a)) catch return self.builder.constInt(0, .void);
}
const arg_types = arg_types_list.items;
// Mangle: `<fn_name>__pack__<arg_types>`. Distinct call shapes
// get distinct symbols; the same shape called repeatedly
// shares one mono.
var name_buf = std.ArrayList(u8).empty;
defer name_buf.deinit(self.alloc);
name_buf.appendSlice(self.alloc, fd.name) catch return self.builder.constInt(0, .void);
name_buf.appendSlice(self.alloc, "__pack") catch return self.builder.constInt(0, .void);
for (arg_types) |t| {
name_buf.append(self.alloc, '_') catch return self.builder.constInt(0, .void);
name_buf.appendSlice(self.alloc, self.mangleTypeName(t)) catch return self.builder.constInt(0, .void);
}
const mangled = name_buf.items;
if (!self.lowered_functions.contains(mangled)) {
self.monomorphizePackFn(fd, mangled, arg_types);
}
// Lower args BEFORE re-fetching the func pointer — lowering
// call-site args can trigger more module functions to be
// appended, which reallocates `module.functions.items` and
// invalidates any `&self.module.functions.items[i]` pointer.
var args = std.ArrayList(Ref).empty;
defer args.deinit(self.alloc);
for (call_node.args) |a| {
args.append(self.alloc, self.lowerExpr(a)) catch return self.builder.constInt(0, .void);
}
const fid = self.resolveFuncByName(mangled) orelse return self.builder.constInt(0, .void);
const func = &self.module.functions.items[@intFromEnum(fid)];
const ret_ty = func.ret;
const params = func.params;
const final_args = self.prependCtxIfNeeded(func, args.items);
self.coerceCallArgs(final_args, params);
return self.builder.call(fid, final_args, ret_ty);
}
/// Build a single mono fn for the given pack-fn + concrete arg types.
/// The mono carries N positional pack-params (synthesised names
/// `__pack_<name>_<i>`) plus any fixed-prefix non-pack params from
/// the original declaration. The body lowers normally — real
/// `return X;` emits real `ret X`; `args[<lit>]` substitutes via
/// `pack_arg_nodes`; `args.len` resolves via `pack_param_count`.
fn monomorphizePackFn(self: *Lowering, fd: *const ast.FnDecl, mangled_name: []const u8, arg_types: []const TypeId) void {
const owned_name = self.alloc.dupe(u8, mangled_name) catch return;
self.lowered_functions.put(owned_name, {}) catch {};
// Find the pack param's name and position in fd.params.
var pack_name: []const u8 = "";
var pack_param_idx: usize = std.math.maxInt(usize);
for (fd.params, 0..) |p, i| {
if (p.is_variadic and p.is_comptime) {
pack_name = p.name;
pack_param_idx = i;
break;
}
}
if (pack_param_idx == std.math.maxInt(usize)) return;
// Save state — mirrors monomorphizeFunction but also captures
// pack/inline-return state since the mono body must NOT route
// returns through any caller's inline slot.
const saved_func = self.builder.func;
const saved_block = self.builder.current_block;
const saved_counter = self.builder.inst_counter;
const saved_scope = self.scope;
const saved_defer_base = self.func_defer_base;
const saved_block_terminated = self.block_terminated;
const saved_target = self.target_type;
const saved_pan = self.pack_arg_nodes;
const saved_ppc = self.pack_param_count;
const saved_iri = self.inline_return_target;
const saved_ctx_ref = self.current_ctx_ref;
self.func_defer_base = self.defer_stack.items.len;
self.block_terminated = false;
self.inline_return_target = null;
defer {
self.scope = saved_scope;
self.func_defer_base = saved_defer_base;
self.block_terminated = saved_block_terminated;
self.target_type = saved_target;
self.pack_arg_nodes = saved_pan;
self.pack_param_count = saved_ppc;
self.inline_return_target = saved_iri;
self.current_ctx_ref = saved_ctx_ref;
self.builder.func = saved_func;
self.builder.current_block = saved_block;
self.builder.inst_counter = saved_counter;
}
const ret_ty = self.resolveReturnType(fd);
self.target_type = ret_ty;
const wants_ctx = self.funcWantsImplicitCtx(fd);
// Synthesise pack-param names + AST ident nodes used to bind
// `args[<lit>]` substitutions during body lowering.
var pack_synth_names = std.ArrayList([]const u8).empty;
defer pack_synth_names.deinit(self.alloc);
var pack_arg_idents = std.ArrayList(*const Node).empty;
defer pack_arg_idents.deinit(self.alloc);
for (arg_types, 0..) |_, i| {
const synth_name = std.fmt.allocPrint(self.alloc, "__pack_{s}_{d}", .{ pack_name, i }) catch return;
pack_synth_names.append(self.alloc, synth_name) catch return;
const ident_node = self.alloc.create(Node) catch return;
ident_node.* = .{
.span = fd.body.span,
.data = .{ .identifier = .{ .name = synth_name } },
};
pack_arg_idents.append(self.alloc, ident_node) catch return;
}
// Param list: ctx (if needed) + fixed prefix + N pack params.
// NOT deinit'd — `params.items` is stored by reference in
// `Function.init` and read back later via `func.params`.
// Freeing here would leave the function holding a freed slice.
// (Matches the leak convention in `monomorphizeFunction`.)
var params = std.ArrayList(Function.Param).empty;
if (wants_ctx) {
params.append(self.alloc, .{
.name = self.module.types.internString("__sx_ctx"),
.ty = self.module.types.ptrTo(.void),
}) catch return;
}
for (fd.params, 0..) |p, i| {
if (i == pack_param_idx) continue;
const pty = self.resolveParamType(&p);
params.append(self.alloc, .{
.name = self.module.types.internString(p.name),
.ty = pty,
}) catch return;
}
for (arg_types, 0..) |ty, i| {
params.append(self.alloc, .{
.name = self.module.types.internString(pack_synth_names.items[i]),
.ty = ty,
}) catch return;
}
const name_id = self.module.types.internString(owned_name);
_ = self.builder.beginFunction(name_id, params.items, ret_ty);
self.builder.currentFunc().has_implicit_ctx = wants_ctx;
const entry_name = self.module.types.internString("entry");
const entry = self.builder.appendBlock(entry_name, &.{});
self.builder.switchToBlock(entry);
if (wants_ctx) self.current_ctx_ref = Ref.fromIndex(0);
var scope = Scope.init(self.alloc, null);
defer scope.deinit();
self.scope = &scope;
var param_idx: u32 = if (wants_ctx) 1 else 0;
for (fd.params, 0..) |p, i| {
if (i == pack_param_idx) continue;
const pty = self.resolveParamType(&p);
const slot = self.builder.alloca(pty);
self.builder.store(slot, Ref.fromIndex(param_idx));
scope.put(p.name, .{ .ref = slot, .ty = pty, .is_alloca = true });
param_idx += 1;
}
for (arg_types, 0..) |ty, i| {
const synth_name = pack_synth_names.items[i];
const slot = self.builder.alloca(ty);
self.builder.store(slot, Ref.fromIndex(param_idx));
scope.put(synth_name, .{ .ref = slot, .ty = ty, .is_alloca = true });
param_idx += 1;
}
// Install pack bindings for the body lowering.
var pan_map = std.StringHashMap([]const *const Node).init(self.alloc);
defer pan_map.deinit();
pan_map.put(pack_name, pack_arg_idents.items) catch return;
self.pack_arg_nodes = pan_map;
var ppc_map = std.StringHashMap(u32).init(self.alloc);
defer ppc_map.deinit();
ppc_map.put(pack_name, @intCast(arg_types.len)) catch return;
self.pack_param_count = ppc_map;
if (ret_ty != .void) {
const body_val = self.lowerBlockValue(fd.body);
if (!self.currentBlockHasTerminator()) {
if (body_val) |val| {
const val_ty = self.builder.getRefType(val);
const coerced = if (val_ty != .void) self.coerceToType(val, val_ty, ret_ty) else val;
self.builder.ret(coerced, ret_ty);
} else {
self.ensureTerminator(ret_ty);
}
}
} else {
self.lowerBlock(fd.body);
self.ensureTerminator(ret_ty);
}
self.builder.finalize();
}
fn monomorphizeFunction(self: *Lowering, fd: *const ast.FnDecl, mangled_name: []const u8, bindings: *std.StringHashMap(TypeId)) void {
// Mark as lowered before lowering (prevents infinite recursion)
// Need to dupe the name since mangled_name may be stack-allocated
@@ -9036,6 +9268,24 @@ pub const Lowering = struct {
return false;
}
/// Pure pack-fn: the ONLY comptime param is a trailing heterogeneous
/// pack (`is_variadic AND is_comptime`). Detected at call sites that
/// today route to `lowerComptimeCall`; siphoned off to
/// `lowerPackFnCall` for per-call-shape monomorphisation. Mixed
/// `($fmt, ..$args)` stays on the inline path for now — different
/// substitution mechanism for the comptime non-pack param.
fn isPackFn(fd: *const ast.FnDecl) bool {
var seen_pack = false;
for (fd.params) |p| {
if (p.is_comptime and p.is_variadic) {
seen_pack = true;
} else if (p.is_comptime) {
return false; // mixed — defer to lowerComptimeCall
}
}
return seen_pack;
}
/// Creates a temporary function marked `is_comptime = true` that wraps
/// the given expression as its return value. Returns the FuncId.
pub fn createComptimeFunction(self: *Lowering, prefix: []const u8, expr: *const Node, ret_ty: TypeId) FuncId {
@@ -11038,6 +11288,14 @@ pub const Lowering = struct {
return .s64;
},
.field_access => |fa| {
// Pack-arity intercept: `<pack_name>.len` is s64. Mirrors
// the lowerFieldAccess intercept so AST-level type
// inference picks the same shape.
if (self.pack_param_count) |ppc| {
if (fa.object.data == .identifier and std.mem.eql(u8, fa.field, "len")) {
if (ppc.contains(fa.object.data.identifier.name)) return .s64;
}
}
// M1.3 — `obj.class` on an Obj-C-class pointer returns Class (*void).
if (std.mem.eql(u8, fa.field, "class")) {
if (self.isObjcClassPointer(self.inferExprType(fa.object))) {

View File

@@ -0,0 +1 @@
0

View File

@@ -0,0 +1 @@
0 1 1 3 2