From 3ca68189c0b99c7b76d8f182ff08a85690ea7dc2 Mon Sep 17 00:00:00 2001 From: agra Date: Tue, 2 Jun 2026 21:28:31 +0300 Subject: [PATCH] refactor(ir): extract GenericResolver (generics.zig) for substitution + mono keys (A4.1 step 2) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Generic substitution and monomorphization-key construction now live in one module, src/ir/generics.zig, behind a *Lowering facade (GenericResolver), mirroring CallResolver / ExprTyper. Moved verbatim: - mangleTypeName + mangleParamList (the mono-key fragment builder), - mangleGenericName (generic mono key), appendComptimeValueMangle (comptime-value fragment), - buildTypeBindings (call-site type-param inference), inferGenericReturnType (generic return resolution). inferGenericReturnType now uses a scoped TypeBindingScope (enter/exit with defer) instead of a manual type_bindings save/restore — the PLAN-ARCH A4.1 "scoped substitution env" shape; a generics.test.zig assertion confirms the prior bindings are restored (the issue-0048/0050 leak class, for this field). Lowering keeps a thin pub mangleTypeName wrapper delegating to genericResolver().mangleTypeName, because ~30 cross-cutting callers (impl-map keys, conversion keys, shape keys) reach it well beyond generics. mangleParamList (sole caller was mangleTypeName) moved fully. The other 4 originals are deleted (no fallback); their 6 call sites now go through self.genericResolver() (calls.zig via self.l.genericResolver()). matchTypeParam / extractTypeParam / isTypeParamDecl widened to pub (the moved substitution logic calls them); genericResolver() accessor added. The 2 mangleTypeName / inferGenericReturnType unit tests moved from lower.test.zig to generics.test.zig (driving GenericResolver directly) and wired into the barrel. monomorphizeFunction / monomorphizePackFn intentionally stay in lower.zig (they save/restore three fields across nested mono and call emission helpers) — a heavier scoped-env adoption deferred to an optional sub-step 3. zig build, zig build test, and tests/run_examples.sh (357/0) all green — no .ir snapshot churn, confirming the move preserved mono-key/substitution output. --- src/ir/calls.zig | 2 +- src/ir/generics.test.zig | 111 +++++++++++++ src/ir/generics.zig | 331 +++++++++++++++++++++++++++++++++++++++ src/ir/ir.zig | 3 + src/ir/lower.test.zig | 89 ----------- src/ir/lower.zig | 296 ++-------------------------------- 6 files changed, 462 insertions(+), 370 deletions(-) create mode 100644 src/ir/generics.test.zig create mode 100644 src/ir/generics.zig diff --git a/src/ir/calls.zig b/src/ir/calls.zig index 5214d5c6..25302b08 100644 --- a/src/ir/calls.zig +++ b/src/ir/calls.zig @@ -153,7 +153,7 @@ pub const CallResolver = struct { if (fd.type_params.len > 0) { return .{ .kind = .generic_fn, - .return_type = self.l.inferGenericReturnType(fd, c), + .return_type = self.l.genericResolver().inferGenericReturnType(fd, c), .target = .{ .named = name }, .expands_defaults = defaultsFor(fd, c.args.len), }; diff --git a/src/ir/generics.test.zig b/src/ir/generics.test.zig new file mode 100644 index 00000000..5548b38e --- /dev/null +++ b/src/ir/generics.test.zig @@ -0,0 +1,111 @@ +// Tests for generics.zig — the generic substitution + mono-key owner +// (`GenericResolver`). Reached via `ir.GenericResolver{ .l = &lowering }`, +// mirroring how calls.test.zig drives `CallResolver`. Moved here from +// lower.test.zig when the helpers moved out of `Lowering` (A4.1 sub-step 2). + +const std = @import("std"); +const ast = @import("../ast.zig"); +const Node = ast.Node; + +const ir_mod = @import("ir.zig"); +const TypeId = ir_mod.TypeId; +const Lowering = ir_mod.Lowering; +const GenericResolver = ir_mod.GenericResolver; + +fn typeKeyword(alloc: std.mem.Allocator, name: []const u8) *Node { + const n = alloc.create(Node) catch unreachable; + n.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = name, .is_generic = false } } }; + return n; +} + +test "generics: mangleTypeName encodes the mono-key fragment per type shape" { + // Arena: the compound arms allocate fragment strings via the module + // allocator (`allocPrint` / ArrayList) and never free them — the real + // compiler runs in the compile arena, so an arena keeps the leak checker + // clean without changing the encoding under test. + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + const alloc = arena.allocator(); + var module = ir_mod.Module.init(alloc); + defer module.deinit(); + var l = Lowering.init(&module); + const gr = GenericResolver{ .l = &l }; + const tt = &module.types; + + // Builtins — the leaf fragments `mangleGenericName` concatenates per + // bound type param (`base__...`). + try std.testing.expectEqualStrings("s64", gr.mangleTypeName(.s64)); + try std.testing.expectEqualStrings("u8", gr.mangleTypeName(.u8)); + try std.testing.expectEqualStrings("f32", gr.mangleTypeName(.f32)); + try std.testing.expectEqualStrings("bool", gr.mangleTypeName(.bool)); + try std.testing.expectEqualStrings("Any", gr.mangleTypeName(.any)); + try std.testing.expectEqualStrings("string", gr.mangleTypeName(.string)); + + // Compound shapes — prefix + recursive inner fragment. + try std.testing.expectEqualStrings("ptr_s64", gr.mangleTypeName(tt.ptrTo(.s64))); + try std.testing.expectEqualStrings("opt_s64", gr.mangleTypeName(tt.optionalOf(.s64))); + try std.testing.expectEqualStrings("ptr_opt_u8", gr.mangleTypeName(tt.ptrTo(tt.optionalOf(.u8)))); + try std.testing.expectEqualStrings("SL_f64", gr.mangleTypeName(tt.intern(.{ .slice = .{ .element = .f64 } }))); + try std.testing.expectEqualStrings("mptr_u8", gr.mangleTypeName(tt.intern(.{ .many_pointer = .{ .element = .u8 } }))); + try std.testing.expectEqualStrings("AR_4_s32", gr.mangleTypeName(tt.intern(.{ .array = .{ .element = .s32, .length = 4 } }))); + try std.testing.expectEqualStrings("vec_3_f32", gr.mangleTypeName(tt.intern(.{ .vector = .{ .element = .f32, .length = 3 } }))); + + // Named aggregate → its declared name. + const pt = tt.intern(.{ .@"struct" = .{ .name = tt.internString("Point"), .fields = &.{} } }); + try std.testing.expectEqualStrings("Point", gr.mangleTypeName(pt)); + + // Tuple: "tu" + "_" per field. + const tup = tt.intern(.{ .tuple = .{ .fields = &[_]TypeId{ .s64, .bool }, .names = null } }); + try std.testing.expectEqualStrings("tu_s64_bool", gr.mangleTypeName(tup)); + + // The `Lowering` wrapper delegates here — same result. + try std.testing.expectEqualStrings("ptr_s64", l.mangleTypeName(tt.ptrTo(.s64))); +} + +test "generics: inferGenericReturnType binds explicit type args, resolves return, restores bindings" { + var arena = std.heap.ArenaAllocator.init(std.testing.allocator); + defer arena.deinit(); + const alloc = arena.allocator(); + var module = ir_mod.Module.init(alloc); + defer module.deinit(); + var l = Lowering.init(&module); + const gr = GenericResolver{ .l = &l }; + + // pair :: ($T: Type, a: T, b: T) -> T — the return type IS the bound `T`. + const tps = [_]ast.StructTypeParam{.{ .name = "T", .constraint = typeKeyword(alloc, "Type") }}; + const params = [_]ast.Param{ + .{ .name = "T", .name_span = .{ .start = 0, .end = 0 }, .type_expr = typeKeyword(alloc, "Type") }, + .{ .name = "a", .name_span = .{ .start = 0, .end = 0 }, .type_expr = typeKeyword(alloc, "T") }, + .{ .name = "b", .name_span = .{ .start = 0, .end = 0 }, .type_expr = typeKeyword(alloc, "T") }, + }; + const body = alloc.create(Node) catch unreachable; + body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = &.{} } } }; + const fd = ast.FnDecl{ .name = "pair", .params = ¶ms, .return_type = typeKeyword(alloc, "T"), .body = body, .type_params = &tps }; + + // Explicit type arg in position 0 binds `T`; the inferred return follows it. + const mkCall = struct { + fn f(a: std.mem.Allocator, type_name: []const u8) ast.Call { + const targ = typeKeyword(a, type_name); + const x = a.create(Node) catch unreachable; + x.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .int_literal = .{ .value = 1 } } }; + const y = a.create(Node) catch unreachable; + y.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .int_literal = .{ .value = 2 } } }; + const callee = a.create(Node) catch unreachable; + callee.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .identifier = .{ .name = "pair" } } }; + const args = a.alloc(*Node, 3) catch unreachable; + args[0] = targ; + args[1] = x; + args[2] = y; + return .{ .callee = callee, .args = args }; + } + }.f; + + const c_s64 = mkCall(alloc, "s64"); + try std.testing.expectEqual(TypeId.s64, gr.inferGenericReturnType(&fd, &c_s64)); + const c_f64 = mkCall(alloc, "f64"); + try std.testing.expectEqual(TypeId.f64, gr.inferGenericReturnType(&fd, &c_f64)); + + // The scoped binding env restores the prior `type_bindings` (null here) — + // it must NOT leak the call's temporary bindings (the issue-0048/0050 class). + try std.testing.expect(l.type_bindings == null); +} diff --git a/src/ir/generics.zig b/src/ir/generics.zig new file mode 100644 index 00000000..e8a3d4e1 --- /dev/null +++ b/src/ir/generics.zig @@ -0,0 +1,331 @@ +const std = @import("std"); +const ast = @import("../ast.zig"); +const types = @import("types.zig"); +const type_bridge = @import("type_bridge.zig"); +const lower = @import("lower.zig"); + +const Node = ast.Node; +const TypeId = types.TypeId; +const Lowering = lower.Lowering; + +/// Generic substitution + monomorphization-key construction (architecture +/// phase A4.1), extracted from `Lowering`. Owns: +/// - the type-name mangler (`mangleTypeName` / `mangleParamList`) — the leaf +/// fragment every mono key is built from, +/// - the generic mono key (`mangleGenericName`) and the comptime-value mono +/// fragment (`appendComptimeValueMangle`), +/// - type-parameter substitution: `buildTypeBindings` (call-site inference) +/// and `inferGenericReturnType` (generic return resolution). +/// +/// A `*Lowering` facade (Principle 5, like `CallResolver` / `ExprTyper`): +/// substitution reads live type-binding / scope state and the type resolver +/// helpers, so it borrows `*Lowering` rather than re-threading every field. +/// `Lowering` keeps a thin `mangleTypeName` wrapper (it has ~30 cross-cutting +/// callers — impl-map keys, conversion keys, shape keys — well beyond +/// generics); the rest call through `Lowering.genericResolver()`. +pub const GenericResolver = struct { + l: *Lowering, + + // ── Mono-key construction ─────────────────────────────────────────── + + /// Mangle a TypeId into its mono-key fragment ("s64", "ptr_T", "SL_T", + /// "AR_n_T", struct name, "tu_X_Y", …). Recursive for compound shapes. + pub fn mangleTypeName(self: GenericResolver, ty: TypeId) []const u8 { + // Builtin types + if (ty == .s8) return "s8"; + if (ty == .s16) return "s16"; + if (ty == .s32) return "s32"; + if (ty == .s64) return "s64"; + if (ty == .u8) return "u8"; + if (ty == .u16) return "u16"; + if (ty == .u32) return "u32"; + if (ty == .u64) return "u64"; + if (ty == .f32) return "f32"; + if (ty == .f64) return "f64"; + if (ty == .bool) return "bool"; + if (ty == .void) return "void"; + if (ty == .string) return "string"; + if (ty == .any) return "Any"; + if (ty == .usize) return "usize"; + if (ty == .isize) return "isize"; + + const info = self.l.module.types.get(ty); + return switch (info) { + .@"struct" => |s| self.l.module.types.getString(s.name), + .@"union" => |u| self.l.module.types.getString(u.name), + .tagged_union => |u| self.l.module.types.getString(u.name), + .@"enum" => |e| self.l.module.types.getString(e.name), + .pointer => |p| blk: { + const inner = self.mangleTypeName(p.pointee); + break :blk std.fmt.allocPrint(self.l.alloc, "ptr_{s}", .{inner}) catch "pointer"; + }, + .many_pointer => |p| blk: { + const inner = self.mangleTypeName(p.element); + break :blk std.fmt.allocPrint(self.l.alloc, "mptr_{s}", .{inner}) catch "many_pointer"; + }, + .slice => |s| blk: { + const inner = self.mangleTypeName(s.element); + break :blk std.fmt.allocPrint(self.l.alloc, "SL_{s}", .{inner}) catch "slice"; + }, + .array => |a| blk: { + const inner = self.mangleTypeName(a.element); + break :blk std.fmt.allocPrint(self.l.alloc, "AR_{d}_{s}", .{ a.length, inner }) catch "array"; + }, + .signed => |w| std.fmt.allocPrint(self.l.alloc, "s{d}", .{w}) catch "signed", + .unsigned => |w| std.fmt.allocPrint(self.l.alloc, "u{d}", .{w}) catch "unsigned", + .optional => |o| blk: { + const inner = self.mangleTypeName(o.child); + break :blk std.fmt.allocPrint(self.l.alloc, "opt_{s}", .{inner}) catch "optional"; + }, + .vector => |v| blk: { + const inner = self.mangleTypeName(v.element); + break :blk std.fmt.allocPrint(self.l.alloc, "vec_{d}_{s}", .{ v.length, inner }) catch "vector"; + }, + .closure => |c| self.mangleParamList("cl", c.params, c.ret), + .function => |f| self.mangleParamList("fn", f.params, f.ret), + .tuple => |t| blk: { + var buf = std.ArrayList(u8).empty; + buf.appendSlice(self.l.alloc, "tu") catch break :blk "tuple"; + for (t.fields) |fid| { + buf.append(self.l.alloc, '_') catch break :blk "tuple"; + buf.appendSlice(self.l.alloc, self.mangleTypeName(fid)) catch break :blk "tuple"; + } + break :blk buf.items; + }, + else => @tagName(info), + }; + } + + fn mangleParamList(self: GenericResolver, prefix: []const u8, params: []const TypeId, ret: TypeId) []const u8 { + var buf = std.ArrayList(u8).empty; + buf.appendSlice(self.l.alloc, prefix) catch return prefix; + for (params) |p| { + buf.append(self.l.alloc, '_') catch return prefix; + buf.appendSlice(self.l.alloc, self.mangleTypeName(p)) catch return prefix; + } + buf.appendSlice(self.l.alloc, "__") catch return prefix; + buf.appendSlice(self.l.alloc, self.mangleTypeName(ret)) catch return prefix; + return buf.items; + } + + /// Mangle a generic call site into "base__Type1_Type2". + /// Returns a heap-allocated string owned by the lowering allocator. + pub fn mangleGenericName( + self: GenericResolver, + base_name: []const u8, + fd: *const ast.FnDecl, + bindings: *const std.StringHashMap(TypeId), + ) []const u8 { + var mangled_buf: [256]u8 = undefined; + var mangled_len: usize = 0; + for (base_name) |ch| { + if (mangled_len < mangled_buf.len) { + mangled_buf[mangled_len] = ch; + mangled_len += 1; + } + } + for (fd.type_params) |tp| { + for ("__") |ch| { + if (mangled_len < mangled_buf.len) { + mangled_buf[mangled_len] = ch; + mangled_len += 1; + } + } + const ty = bindings.get(tp.name) orelse .unresolved; + const type_name_str = self.mangleTypeName(ty); + for (type_name_str) |ch| { + if (mangled_len < mangled_buf.len) { + mangled_buf[mangled_len] = ch; + mangled_len += 1; + } + } + } + return self.l.alloc.dupe(u8, mangled_buf[0..mangled_len]) catch base_name; + } + + /// Append a comptime parameter VALUE's mono fragment to `buf` (int/bool + /// verbatim, float with `.`/`-` escaped, string hashed) so distinct + /// comptime-value call sites get distinct monos. + pub fn appendComptimeValueMangle(self: GenericResolver, buf: *std.ArrayList(u8), node: *const Node) void { + switch (node.data) { + .int_literal => |lit| { + var tmp: [32]u8 = undefined; + const written = std.fmt.bufPrint(&tmp, "{d}", .{lit.value}) catch return; + buf.appendSlice(self.l.alloc, written) catch return; + }, + .bool_literal => |lit| { + buf.appendSlice(self.l.alloc, if (lit.value) "true" else "false") catch return; + }, + .float_literal => |lit| { + var tmp: [64]u8 = undefined; + const written = std.fmt.bufPrint(&tmp, "{d}", .{lit.value}) catch return; + for (written) |c| { + buf.append(self.l.alloc, if (c == '.') '_' else if (c == '-') 'n' else c) catch return; + } + }, + .string_literal => |lit| { + // Hash the string to a fixed-length tag — keeps the + // mangle short and stable for arbitrary content. + var h = std.hash.Wyhash.init(0); + h.update(lit.raw); + var tmp: [32]u8 = undefined; + const written = std.fmt.bufPrint(&tmp, "s{x}", .{h.final()}) catch return; + buf.appendSlice(self.l.alloc, written) catch return; + }, + else => buf.append(self.l.alloc, '?') catch return, + } + } + + // ── Type-parameter substitution ───────────────────────────────────── + + /// Build the `$T → concrete TypeId` bindings for a generic call site. + /// Strategy 1: explicit type args (the param named `$T` IS a type + /// expression). Strategy 2: infer from value params that use `T` + /// (`a: $T`, `items: []$T`), picking the widest match. + pub fn buildTypeBindings( + self: GenericResolver, + fd: *const ast.FnDecl, + args_ast: []const *const Node, + ) std.StringHashMap(TypeId) { + var bindings = std.StringHashMap(TypeId).init(self.l.alloc); + const types_passed_explicitly = args_ast.len == fd.params.len; + for (fd.type_params) |tp| { + var found = false; + // Strategy 1: explicit — the param whose name matches `tp.name` IS + // the `$T: Type` declaration; the arg at that position is a type expression. + if (types_passed_explicitly) { + for (fd.params, 0..) |param, pi| { + if (std.mem.eql(u8, param.name, tp.name)) { + if (pi < args_ast.len and type_bridge.isTypeShapedAstNode(args_ast[pi], &self.l.module.types)) { + const ty = self.l.resolveTypeArg(args_ast[pi]); + bindings.put(tp.name, ty) catch {}; + found = true; + } + break; + } + } + } + if (found) continue; + // Strategy 2: infer from value params that USE the type param + // (e.g. a: $T, b: T, items: []$T). Pick widest type across matches. + var inferred_ty: ?TypeId = null; + var s2_arg_idx: usize = 0; + for (fd.params) |param| { + const is_type_decl = Lowering.isTypeParamDecl(¶m, fd.type_params); + defer if (!is_type_decl) { + s2_arg_idx += 1; + }; + if (is_type_decl) { + if (types_passed_explicitly) s2_arg_idx += 1; + continue; + } + const matched = self.l.matchTypeParam(param.type_expr, tp.name); + if (matched) { + if (s2_arg_idx < args_ast.len) { + const arg_ty = self.l.inferExprType(args_ast[s2_arg_idx]); + const extracted = self.l.extractTypeParam(param.type_expr, arg_ty, tp.name); + if (extracted) |ety| { + if (inferred_ty) |prev| { + if (ety == .f64 and prev != .f64) { + inferred_ty = ety; + } else if (ety == .f32 and prev != .f64 and prev != .f32) { + inferred_ty = ety; + } + } else { + inferred_ty = ety; + } + } + } + } + } + if (inferred_ty) |ty| { + bindings.put(tp.name, ty) catch {}; + } + } + return bindings; + } + + /// Infer the return type of a generic function call by resolving type bindings. + pub fn inferGenericReturnType(self: GenericResolver, fd: *const ast.FnDecl, c: *const ast.Call) TypeId { + if (fd.return_type == null) return .void; + + // Build ALL type bindings from call args before resolving return type + var tmp_bindings = std.StringHashMap(TypeId).init(self.l.alloc); + defer tmp_bindings.deinit(); + + for (fd.type_params) |tp| { + // Strategy 1: direct type param decl ($T: Type) — param.name == tp.name. + // Only fires when the caller actually supplied a type expression at + // that position; otherwise fall through to value-based inference. + var found = false; + for (fd.params, 0..) |param, pi| { + if (std.mem.eql(u8, param.name, tp.name)) { + if (pi < c.args.len and type_bridge.isTypeShapedAstNode(c.args[pi], &self.l.module.types)) { + const ty = self.l.resolveTypeArg(c.args[pi]); + tmp_bindings.put(tp.name, ty) catch {}; + found = true; + } + break; + } + } + if (found) continue; + + // Strategy 2: inferred from usage (a: $T, b: T) — check ALL matching params, pick widest + var inferred_ty: ?TypeId = null; + for (fd.params, 0..) |param, pi| { + if (param.type_expr.data == .type_expr) { + const te = param.type_expr.data.type_expr; + if (std.mem.eql(u8, te.name, tp.name)) { + if (pi < c.args.len) { + const arg_ty = self.l.inferExprType(c.args[pi]); + if (inferred_ty) |prev| { + if (arg_ty == .f64 and prev != .f64) { + inferred_ty = arg_ty; + } else if (arg_ty == .f32 and prev != .f64 and prev != .f32) { + inferred_ty = arg_ty; + } + } else { + inferred_ty = arg_ty; + } + } + } + } + } + if (inferred_ty) |ty| { + tmp_bindings.put(tp.name, ty) catch {}; + } + } + + // Resolve return type with whatever bindings we built. Even an + // empty `tmp_bindings` is a valid input — non-generic literal + // return types (e.g. `walk(..$args) -> string`) still need to + // resolve through `resolveTypeWithBindings`, not fall through + // to the historical `.s64` default. The default silently + // misclassified pack-fn calls whose return type was a fixed + // literal — every consumer (e.g. print's pack-shape mangling) + // inferred `s64` and routed the value through the wrong Any + // tag. + var scope = TypeBindingScope.enter(self.l, tmp_bindings); + defer scope.exit(); + return self.l.resolveTypeWithBindings(fd.return_type.?); + } +}; + +/// Scoped override of `Lowering.type_bindings`: install a binding set for the +/// duration of a substitution, restoring the prior set on `exit`. Replaces the +/// manual save/restore the generic-return resolution used (PLAN-ARCH A4.1 +/// "scoped substitution envs"). +const TypeBindingScope = struct { + l: *Lowering, + saved: ?std.StringHashMap(TypeId), + + fn enter(l: *Lowering, bindings: std.StringHashMap(TypeId)) TypeBindingScope { + const saved = l.type_bindings; + l.type_bindings = bindings; + return .{ .l = l, .saved = saved }; + } + + fn exit(self: *TypeBindingScope) void { + self.l.type_bindings = self.saved; + } +}; diff --git a/src/ir/ir.zig b/src/ir/ir.zig index a92eb38e..d906dbff 100644 --- a/src/ir/ir.zig +++ b/src/ir/ir.zig @@ -9,6 +9,7 @@ pub const type_resolver = @import("type_resolver.zig"); pub const packs = @import("packs.zig"); pub const expr_typer = @import("expr_typer.zig"); pub const calls = @import("calls.zig"); +pub const generics = @import("generics.zig"); pub const semantic_diagnostics = @import("semantic_diagnostics.zig"); pub const TypeId = types.TypeId; @@ -43,6 +44,7 @@ pub const PackResolver = packs.PackResolver; pub const ExprTyper = expr_typer.ExprTyper; pub const CallResolver = calls.CallResolver; pub const CallPlan = calls.CallPlan; +pub const GenericResolver = generics.GenericResolver; pub const compiler_hooks = @import("compiler_hooks.zig"); pub const emit_llvm = @import("emit_llvm.zig"); @@ -66,6 +68,7 @@ pub const type_resolver_tests = @import("type_resolver.test.zig"); pub const packs_tests = @import("packs.test.zig"); pub const expr_typer_tests = @import("expr_typer.test.zig"); pub const calls_tests = @import("calls.test.zig"); +pub const generics_tests = @import("generics.test.zig"); pub const type_bridge_tests = @import("type_bridge.test.zig"); pub const emit_llvm_tests = @import("emit_llvm.test.zig"); pub const jni_descriptor_tests = @import("jni_descriptor.test.zig"); diff --git a/src/ir/lower.test.zig b/src/ir/lower.test.zig index 10d4df0f..c3a4eccc 100644 --- a/src/ir/lower.test.zig +++ b/src/ir/lower.test.zig @@ -828,92 +828,3 @@ test "E1.4c noreturn typing: divergence shapes + if-else unification + block pro defer alloc.destroy(both_div); try std.testing.expectEqual(TypeId.noreturn, lowering.inferExprType(both_div)); } - -// ── A4.1 test-first scaffolding: generic substitution + mono keys ──── -// Lock the CURRENT behavior of the generic mono-key building blocks -// (`mangleTypeName`) and generic-return substitution (`inferGenericReturnType`) -// before they move to `src/ir/generics.zig`. Reached through the existing -// public surface — no new exposure (mirrors the A3.2 sub-step-1 cadence). - -test "generics: mangleTypeName encodes the mono-key fragment per type shape" { - // Arena: the compound arms allocate fragment strings via the module - // allocator (`allocPrint` / ArrayList) and never free them — the real - // compiler runs in the compile arena, so an arena keeps the leak checker - // clean without changing the encoding under test. - var arena = std.heap.ArenaAllocator.init(std.testing.allocator); - defer arena.deinit(); - const alloc = arena.allocator(); - var module = ir_mod.Module.init(alloc); - defer module.deinit(); - var l = Lowering.init(&module); - const tt = &module.types; - - // Builtins — the leaf fragments `mangleGenericName` concatenates per - // bound type param (`base__...`). - try std.testing.expectEqualStrings("s64", l.mangleTypeName(.s64)); - try std.testing.expectEqualStrings("u8", l.mangleTypeName(.u8)); - try std.testing.expectEqualStrings("f32", l.mangleTypeName(.f32)); - try std.testing.expectEqualStrings("bool", l.mangleTypeName(.bool)); - try std.testing.expectEqualStrings("Any", l.mangleTypeName(.any)); - try std.testing.expectEqualStrings("string", l.mangleTypeName(.string)); - - // Compound shapes — prefix + recursive inner fragment. - try std.testing.expectEqualStrings("ptr_s64", l.mangleTypeName(tt.ptrTo(.s64))); - try std.testing.expectEqualStrings("opt_s64", l.mangleTypeName(tt.optionalOf(.s64))); - try std.testing.expectEqualStrings("ptr_opt_u8", l.mangleTypeName(tt.ptrTo(tt.optionalOf(.u8)))); - try std.testing.expectEqualStrings("SL_f64", l.mangleTypeName(tt.intern(.{ .slice = .{ .element = .f64 } }))); - try std.testing.expectEqualStrings("mptr_u8", l.mangleTypeName(tt.intern(.{ .many_pointer = .{ .element = .u8 } }))); - try std.testing.expectEqualStrings("AR_4_s32", l.mangleTypeName(tt.intern(.{ .array = .{ .element = .s32, .length = 4 } }))); - try std.testing.expectEqualStrings("vec_3_f32", l.mangleTypeName(tt.intern(.{ .vector = .{ .element = .f32, .length = 3 } }))); - - // Named aggregate → its declared name. - const pt = tt.intern(.{ .@"struct" = .{ .name = tt.internString("Point"), .fields = &.{} } }); - try std.testing.expectEqualStrings("Point", l.mangleTypeName(pt)); - - // Tuple: "tu" + "_" per field. - const tup = tt.intern(.{ .tuple = .{ .fields = &[_]TypeId{ .s64, .bool }, .names = null } }); - try std.testing.expectEqualStrings("tu_s64_bool", l.mangleTypeName(tup)); -} - -test "generics: inferGenericReturnType binds explicit type args, resolves return" { - var arena = std.heap.ArenaAllocator.init(std.testing.allocator); - defer arena.deinit(); - const alloc = arena.allocator(); - var module = ir_mod.Module.init(alloc); - defer module.deinit(); - var l = Lowering.init(&module); - - // pair :: ($T: Type, a: T, b: T) -> T — the return type IS the bound `T`. - const tps = [_]ast.StructTypeParam{.{ .name = "T", .constraint = typeKeyword(alloc, "Type") }}; - const params = [_]ast.Param{ - .{ .name = "T", .name_span = .{ .start = 0, .end = 0 }, .type_expr = typeKeyword(alloc, "Type") }, - .{ .name = "a", .name_span = .{ .start = 0, .end = 0 }, .type_expr = typeKeyword(alloc, "T") }, - .{ .name = "b", .name_span = .{ .start = 0, .end = 0 }, .type_expr = typeKeyword(alloc, "T") }, - }; - const body = alloc.create(Node) catch unreachable; - body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = &.{} } } }; - const fd = ast.FnDecl{ .name = "pair", .params = ¶ms, .return_type = typeKeyword(alloc, "T"), .body = body, .type_params = &tps }; - - // Explicit type arg in position 0 binds `T`; the inferred return follows it. - const mkCall = struct { - fn f(a: std.mem.Allocator, type_name: []const u8) ast.Call { - const targ = typeKeyword(a, type_name); - const x = a.create(Node) catch unreachable; - x.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .int_literal = .{ .value = 1 } } }; - const y = a.create(Node) catch unreachable; - y.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .int_literal = .{ .value = 2 } } }; - const callee = a.create(Node) catch unreachable; - callee.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .identifier = .{ .name = "pair" } } }; - const args = a.alloc(*Node, 3) catch unreachable; - args[0] = targ; - args[1] = x; - args[2] = y; - return .{ .callee = callee, .args = args }; - } - }.f; - - const c_s64 = mkCall(alloc, "s64"); - try std.testing.expectEqual(TypeId.s64, l.inferGenericReturnType(&fd, &c_s64)); - const c_f64 = mkCall(alloc, "f64"); - try std.testing.expectEqual(TypeId.f64, l.inferGenericReturnType(&fd, &c_f64)); -} diff --git a/src/ir/lower.zig b/src/ir/lower.zig index d795af52..9c419b3c 100644 --- a/src/ir/lower.zig +++ b/src/ir/lower.zig @@ -24,6 +24,7 @@ const ResolveEnv = @import("type_resolver.zig").ResolveEnv; const PackResolver = @import("packs.zig").PackResolver; const ExprTyper = @import("expr_typer.zig").ExprTyper; const CallResolver = @import("calls.zig").CallResolver; +const GenericResolver = @import("generics.zig").GenericResolver; const semantic_diagnostics = @import("semantic_diagnostics.zig"); const TypeId = types.TypeId; @@ -7908,10 +7909,10 @@ pub const Lowering = struct { eff_args.append(self.alloc, effective_obj_node) catch unreachable; for (c.args) |a| eff_args.append(self.alloc, a) catch unreachable; - var gbindings = self.buildTypeBindings(gen_fd, eff_args.items); + var gbindings = self.genericResolver().buildTypeBindings(gen_fd, eff_args.items); defer gbindings.deinit(); - const gmangled = self.mangleGenericName(qualified, gen_fd, &gbindings); + const gmangled = self.genericResolver().mangleGenericName(qualified, gen_fd, &gbindings); if (!self.lowered_functions.contains(gmangled)) { self.monomorphizeFunction(gen_fd, gmangled, &gbindings); } @@ -9790,111 +9791,13 @@ pub const Lowering = struct { /// `args_ast` must be parallel to `fd.params`; for dot-calls the caller /// prepends the receiver's AST node so positions align with `fd.params[0] = self`. /// Caller owns the returned map and must call `.deinit()`. - fn buildTypeBindings( - self: *Lowering, - fd: *const ast.FnDecl, - args_ast: []const *const Node, - ) std.StringHashMap(TypeId) { - var bindings = std.StringHashMap(TypeId).init(self.alloc); - const types_passed_explicitly = args_ast.len == fd.params.len; - for (fd.type_params) |tp| { - var found = false; - // Strategy 1: explicit — the param whose name matches `tp.name` IS - // the `$T: Type` declaration; the arg at that position is a type expression. - if (types_passed_explicitly) { - for (fd.params, 0..) |param, pi| { - if (std.mem.eql(u8, param.name, tp.name)) { - if (pi < args_ast.len and type_bridge.isTypeShapedAstNode(args_ast[pi], &self.module.types)) { - const ty = self.resolveTypeArg(args_ast[pi]); - bindings.put(tp.name, ty) catch {}; - found = true; - } - break; - } - } - } - if (found) continue; - // Strategy 2: infer from value params that USE the type param - // (e.g. a: $T, b: T, items: []$T). Pick widest type across matches. - var inferred_ty: ?TypeId = null; - var s2_arg_idx: usize = 0; - for (fd.params) |param| { - const is_type_decl = isTypeParamDecl(¶m, fd.type_params); - defer if (!is_type_decl) { - s2_arg_idx += 1; - }; - if (is_type_decl) { - if (types_passed_explicitly) s2_arg_idx += 1; - continue; - } - const matched = self.matchTypeParam(param.type_expr, tp.name); - if (matched) { - if (s2_arg_idx < args_ast.len) { - const arg_ty = self.inferExprType(args_ast[s2_arg_idx]); - const extracted = self.extractTypeParam(param.type_expr, arg_ty, tp.name); - if (extracted) |ety| { - if (inferred_ty) |prev| { - if (ety == .f64 and prev != .f64) { - inferred_ty = ety; - } else if (ety == .f32 and prev != .f64 and prev != .f32) { - inferred_ty = ety; - } - } else { - inferred_ty = ety; - } - } - } - } - } - if (inferred_ty) |ty| { - bindings.put(tp.name, ty) catch {}; - } - } - return bindings; - } - - /// Mangle a generic call site into "base__Type1_Type2". - /// Returns a heap-allocated string owned by self.alloc. - fn mangleGenericName( - self: *Lowering, - base_name: []const u8, - fd: *const ast.FnDecl, - bindings: *const std.StringHashMap(TypeId), - ) []const u8 { - var mangled_buf: [256]u8 = undefined; - var mangled_len: usize = 0; - for (base_name) |ch| { - if (mangled_len < mangled_buf.len) { - mangled_buf[mangled_len] = ch; - mangled_len += 1; - } - } - for (fd.type_params) |tp| { - for ("__") |ch| { - if (mangled_len < mangled_buf.len) { - mangled_buf[mangled_len] = ch; - mangled_len += 1; - } - } - const ty = bindings.get(tp.name) orelse .unresolved; - const type_name_str = self.mangleTypeName(ty); - for (type_name_str) |ch| { - if (mangled_len < mangled_buf.len) { - mangled_buf[mangled_len] = ch; - mangled_len += 1; - } - } - } - return self.alloc.dupe(u8, mangled_buf[0..mangled_len]) catch base_name; - } - /// Lower a call to a generic function by monomorphizing it with inferred type arguments. fn lowerGenericCall(self: *Lowering, fd: *const ast.FnDecl, base_name: []const u8, call_node: *const ast.Call, lowered_args: []Ref) Ref { - var bindings = self.buildTypeBindings(fd, call_node.args); + var bindings = self.genericResolver().buildTypeBindings(fd, call_node.args); defer bindings.deinit(); const types_passed_explicitly = call_node.args.len == fd.params.len; - const mangled_name = self.mangleGenericName(base_name, fd, &bindings); + const mangled_name = self.genericResolver().mangleGenericName(base_name, fd, &bindings); if (!self.lowered_functions.contains(mangled_name)) { self.monomorphizeFunction(fd, mangled_name, &bindings); @@ -10498,7 +10401,7 @@ pub const Lowering = struct { if (ct_fi >= call_node.args.len) break; if (p.is_comptime) { name_buf.appendSlice(self.alloc, "__ct_") catch return self.builder.constInt(0, .void); - self.appendComptimeValueMangle(&name_buf, call_node.args[ct_fi]); + self.genericResolver().appendComptimeValueMangle(&name_buf, call_node.args[ct_fi]); } ct_fi += 1; } @@ -10529,40 +10432,6 @@ pub const Lowering = struct { return self.builder.call(fid, final_args, ret_ty); } - /// Append a stable mangle segment for a comptime call-arg literal. - /// Supports int / bool / float / string literals; non-literals - /// degrade to "?" (the mono is still cached but two different - /// non-literal expressions sharing one call site would collide, - /// which is acceptable since they'd lower the same body anyway). - fn appendComptimeValueMangle(self: *Lowering, buf: *std.ArrayList(u8), node: *const Node) void { - switch (node.data) { - .int_literal => |lit| { - var tmp: [32]u8 = undefined; - const written = std.fmt.bufPrint(&tmp, "{d}", .{lit.value}) catch return; - buf.appendSlice(self.alloc, written) catch return; - }, - .bool_literal => |lit| { - buf.appendSlice(self.alloc, if (lit.value) "true" else "false") catch return; - }, - .float_literal => |lit| { - var tmp: [64]u8 = undefined; - const written = std.fmt.bufPrint(&tmp, "{d}", .{lit.value}) catch return; - for (written) |c| { - buf.append(self.alloc, if (c == '.') '_' else if (c == '-') 'n' else c) catch return; - } - }, - .string_literal => |lit| { - // Hash the string to a fixed-length tag — keeps the - // mangle short and stable for arbitrary content. - var h = std.hash.Wyhash.init(0); - h.update(lit.raw); - var tmp: [32]u8 = undefined; - const written = std.fmt.bufPrint(&tmp, "s{x}", .{h.final()}) catch return; - buf.appendSlice(self.alloc, written) catch return; - }, - else => buf.append(self.alloc, '?') catch return, - } - } /// Build a single mono fn for the given pack-fn + concrete arg types. /// The mono carries N positional pack-params (synthesised names @@ -11485,7 +11354,7 @@ pub const Lowering = struct { /// Format a type name for function name mangling (identifier-safe). /// E.g. *Point → "ptr_Point", []s32 → "slice_s32", [3]f64 → "array_3_f64". /// Check if a param type expression references a type param name (possibly nested). - fn matchTypeParam(_: *Lowering, type_node: *const Node, tp_name: []const u8) bool { + pub fn matchTypeParam(_: *Lowering, type_node: *const Node, tp_name: []const u8) bool { return switch (type_node.data) { .type_expr => |te| std.mem.eql(u8, te.name, tp_name), .identifier => |id| std.mem.eql(u8, id.name, tp_name), @@ -11523,7 +11392,7 @@ pub const Lowering = struct { /// Extract the concrete type that corresponds to a type param from an arg type. /// E.g., param type []$T with arg type []s64 → T = s64. - fn extractTypeParam(self: *Lowering, type_node: *const Node, arg_ty: TypeId, tp_name: []const u8) ?TypeId { + pub fn extractTypeParam(self: *Lowering, type_node: *const Node, arg_ty: TypeId, tp_name: []const u8) ?TypeId { return switch (type_node.data) { .type_expr => |te| if (std.mem.eql(u8, te.name, tp_name)) arg_ty else null, .identifier => |id| if (std.mem.eql(u8, id.name, tp_name)) arg_ty else null, @@ -11590,70 +11459,12 @@ pub const Lowering = struct { }; } + /// Mangle a TypeId into its mono-key fragment. Thin delegation to the + /// canonical owner (`GenericResolver`, `generics.zig`); kept on `Lowering` + /// because ~30 cross-cutting callers (impl-map keys, conversion keys, shape + /// keys) reach it here, well beyond generic monomorphization. pub fn mangleTypeName(self: *Lowering, ty: TypeId) []const u8 { - // Builtin types - if (ty == .s8) return "s8"; - if (ty == .s16) return "s16"; - if (ty == .s32) return "s32"; - if (ty == .s64) return "s64"; - if (ty == .u8) return "u8"; - if (ty == .u16) return "u16"; - if (ty == .u32) return "u32"; - if (ty == .u64) return "u64"; - if (ty == .f32) return "f32"; - if (ty == .f64) return "f64"; - if (ty == .bool) return "bool"; - if (ty == .void) return "void"; - if (ty == .string) return "string"; - if (ty == .any) return "Any"; - if (ty == .usize) return "usize"; - if (ty == .isize) return "isize"; - - const info = self.module.types.get(ty); - return switch (info) { - .@"struct" => |s| self.module.types.getString(s.name), - .@"union" => |u| self.module.types.getString(u.name), - .tagged_union => |u| self.module.types.getString(u.name), - .@"enum" => |e| self.module.types.getString(e.name), - .pointer => |p| blk: { - const inner = self.mangleTypeName(p.pointee); - break :blk std.fmt.allocPrint(self.alloc, "ptr_{s}", .{inner}) catch "pointer"; - }, - .many_pointer => |p| blk: { - const inner = self.mangleTypeName(p.element); - break :blk std.fmt.allocPrint(self.alloc, "mptr_{s}", .{inner}) catch "many_pointer"; - }, - .slice => |s| blk: { - const inner = self.mangleTypeName(s.element); - break :blk std.fmt.allocPrint(self.alloc, "SL_{s}", .{inner}) catch "slice"; - }, - .array => |a| blk: { - const inner = self.mangleTypeName(a.element); - break :blk std.fmt.allocPrint(self.alloc, "AR_{d}_{s}", .{ a.length, inner }) catch "array"; - }, - .signed => |w| std.fmt.allocPrint(self.alloc, "s{d}", .{w}) catch "signed", - .unsigned => |w| std.fmt.allocPrint(self.alloc, "u{d}", .{w}) catch "unsigned", - .optional => |o| blk: { - const inner = self.mangleTypeName(o.child); - break :blk std.fmt.allocPrint(self.alloc, "opt_{s}", .{inner}) catch "optional"; - }, - .vector => |v| blk: { - const inner = self.mangleTypeName(v.element); - break :blk std.fmt.allocPrint(self.alloc, "vec_{d}_{s}", .{ v.length, inner }) catch "vector"; - }, - .closure => |c| self.mangleParamList("cl", c.params, c.ret), - .function => |f| self.mangleParamList("fn", f.params, f.ret), - .tuple => |t| blk: { - var buf = std.ArrayList(u8).empty; - buf.appendSlice(self.alloc, "tu") catch break :blk "tuple"; - for (t.fields) |fid| { - buf.append(self.alloc, '_') catch break :blk "tuple"; - buf.appendSlice(self.alloc, self.mangleTypeName(fid)) catch break :blk "tuple"; - } - break :blk buf.items; - }, - else => @tagName(info), - }; + return self.genericResolver().mangleTypeName(ty); } /// Collect impl entries visible from `current_source_file` — defined in @@ -11697,18 +11508,6 @@ pub const Lowering = struct { } } - fn mangleParamList(self: *Lowering, prefix: []const u8, params: []const TypeId, ret: TypeId) []const u8 { - var buf = std.ArrayList(u8).empty; - buf.appendSlice(self.alloc, prefix) catch return prefix; - for (params) |p| { - buf.append(self.alloc, '_') catch return prefix; - buf.appendSlice(self.alloc, self.mangleTypeName(p)) catch return prefix; - } - buf.appendSlice(self.alloc, "__") catch return prefix; - buf.appendSlice(self.alloc, self.mangleTypeName(ret)) catch return prefix; - return buf.items; - } - /// Resolve type category names (like "int", "struct", "float") to matching TypeId tag values. /// Returns a list of TypeId index values that match the category. fn resolveTypeCategoryTags(self: *Lowering, name: []const u8) []const u64 { @@ -12167,7 +11966,7 @@ pub const Lowering = struct { /// Check if a param is a type param declaration ($T: Type). /// A type param declaration has param.name == one of the type_params names. - fn isTypeParamDecl(param: *const ast.Param, type_params: []const ast.StructTypeParam) bool { + pub fn isTypeParamDecl(param: *const ast.Param, type_params: []const ast.StructTypeParam) bool { for (type_params) |tp| { if (std.mem.eql(u8, param.name, tp.name)) return true; } @@ -14454,71 +14253,8 @@ pub const Lowering = struct { return .{ .l = self }; } - /// Infer the return type of a generic function call by resolving type bindings. - pub fn inferGenericReturnType(self: *Lowering, fd: *const ast.FnDecl, c: *const ast.Call) TypeId { - if (fd.return_type == null) return .void; - - // Build ALL type bindings from call args before resolving return type - var tmp_bindings = std.StringHashMap(TypeId).init(self.alloc); - defer tmp_bindings.deinit(); - - for (fd.type_params) |tp| { - // Strategy 1: direct type param decl ($T: Type) — param.name == tp.name. - // Only fires when the caller actually supplied a type expression at - // that position; otherwise fall through to value-based inference. - var found = false; - for (fd.params, 0..) |param, pi| { - if (std.mem.eql(u8, param.name, tp.name)) { - if (pi < c.args.len and type_bridge.isTypeShapedAstNode(c.args[pi], &self.module.types)) { - const ty = self.resolveTypeArg(c.args[pi]); - tmp_bindings.put(tp.name, ty) catch {}; - found = true; - } - break; - } - } - if (found) continue; - - // Strategy 2: inferred from usage (a: $T, b: T) — check ALL matching params, pick widest - var inferred_ty: ?TypeId = null; - for (fd.params, 0..) |param, pi| { - if (param.type_expr.data == .type_expr) { - const te = param.type_expr.data.type_expr; - if (std.mem.eql(u8, te.name, tp.name)) { - if (pi < c.args.len) { - const arg_ty = self.inferExprType(c.args[pi]); - if (inferred_ty) |prev| { - if (arg_ty == .f64 and prev != .f64) { - inferred_ty = arg_ty; - } else if (arg_ty == .f32 and prev != .f64 and prev != .f32) { - inferred_ty = arg_ty; - } - } else { - inferred_ty = arg_ty; - } - } - } - } - } - if (inferred_ty) |ty| { - tmp_bindings.put(tp.name, ty) catch {}; - } - } - - // Resolve return type with whatever bindings we built. Even an - // empty `tmp_bindings` is a valid input — non-generic literal - // return types (e.g. `walk(..$args) -> string`) still need to - // resolve through `resolveTypeWithBindings`, not fall through - // to the historical `.s64` default. The default silently - // misclassified pack-fn calls whose return type was a fixed - // literal — every consumer (e.g. print's pack-shape mangling) - // inferred `s64` and routed the value through the wrong Any - // tag. - const saved = self.type_bindings; - self.type_bindings = tmp_bindings; - const ret = self.resolveTypeWithBindings(fd.return_type.?); - self.type_bindings = saved; - return ret; + pub fn genericResolver(self: *Lowering) GenericResolver { + return .{ .l = self }; } /// Lower the `xx` operator (type coercion).