diff --git a/examples/28-sdl-graphics.sx b/examples/28-sdl-graphics.sx index 5379096..5110aa1 100644 --- a/examples/28-sdl-graphics.sx +++ b/examples/28-sdl-graphics.sx @@ -155,7 +155,7 @@ main :: () { glBindVertexArray(vao); glBindBuffer(GL_ARRAY_BUFFER, vbo); - glBufferData(GL_ARRAY_BUFFER, 1152, @vertices, GL_STATIC_DRAW); + glBufferData(GL_ARRAY_BUFFER, 1152, vertices.ptr, GL_STATIC_DRAW); // Position attribute (location 0): 3 floats, stride 32 bytes, offset 0 glVertexAttribPointer(0, 3, GL_FLOAT, 0, 32, xx 0); @@ -204,7 +204,7 @@ main :: () { view := mat4_translate(0.0, 0.0, -3.0); rot_y := mat4_rotate_y(angle); rot_x := mat4_rotate_x(angle * 0.7); - model := (rot_y, rot_x).multiply(); + model := mat4_multiply(rot_y, rot_x); vm := mat4_multiply(view, model); mvp := mat4_multiply(proj, vm); diff --git a/examples/issue-0008.sx b/examples/issue-0008.sx new file mode 100644 index 0000000..022a572 --- /dev/null +++ b/examples/issue-0008.sx @@ -0,0 +1,26 @@ +// Issue 0008: Chained ?? (null coalescing) doesn't work +// +// `a ?? b ?? c` where a: ?f32, b: ?f32, c: f32 fails with: +// "narrowing conversion from '?f32' to 'f32' requires explicit 'xx' cast" +// +// It parses as (a ?? b) ?? c, and the first ?? rejects ?f32 as the rhs. +// +// Expected: ?? should either be right-associative so it parses as a ?? (b ?? c), +// or allow ?T as the rhs (returning ?T when rhs is optional, T when rhs is concrete). +// +// Workaround: use parentheses — a ?? (b ?? c) + +Foo :: struct { + x: ?f32; + y: ?f32; +} + +main :: () -> void { + f := Foo.{ x = 1.0, y = 2.0 }; + + // This works: + ok := f.x ?? (f.y ?? 0.0); + + // This should also work but fails: + bad := f.x ?? f.y ?? 
0.0; +} diff --git a/examples/issue-0009.sx b/examples/issue-0009.sx new file mode 100644 index 0000000..2068a16 --- /dev/null +++ b/examples/issue-0009.sx @@ -0,0 +1,20 @@ +// Issue 0009: Struct-level constant declarations +// +// Constants declared inside a struct body with `NAME :Type: value;` syntax +// fail with "expected field name in struct". +// +// Expected: structs should support constant declarations alongside fields and methods. + +Foo :: struct { + x: f32; + + // This method works: + get_x :: (self: *Foo) -> f32 { self.x; } + + // This constant should work but fails: + DEFAULT_X :f32: 42.0; +} + +main :: () -> void { + f := Foo.{ x = Foo.DEFAULT_X }; +} diff --git a/examples/issue-0010.sx b/examples/issue-0010.sx new file mode 100644 index 0000000..20dbb53 --- /dev/null +++ b/examples/issue-0010.sx @@ -0,0 +1,38 @@ +// Issue 0010: inline if-else in struct literal field produces type error +// The `null` branch is typed as `*void` instead of being coerced to `?f32` +// +// Error: narrowing conversion from '*void' to 'f32' requires explicit 'xx' cast + +#import "modules/std.sx"; + +Foo :: struct { + width: ?f32; +} + +main :: () -> void { + x :f32: 10.0; + + // null in then branch, value in else + f1 := Foo.{ width = if true then null else x }; + print("{}\n", f1.width ?? 99.0); + + // value in then branch, null in else + f2 := Foo.{ width = if true then x else null }; + print("{}\n", f2.width ?? 99.0); + + // both branches are values + f3 := Foo.{ width = if false then 5.0 else x }; + print("{}\n", f3.width ?? 99.0); + + // standalone variable, not just struct fields + val: ?f32 = if true then null else 42.0; + print("{}\n", val ?? 0.0); + + val2: ?f32 = if false then null else 42.0; + print("{}\n", val2 ?? 0.0); + + // negation in condition + cond := false; + val3: ?f32 = if !cond then null else 42.0; + print("{}\n", val3 ?? 
0.0); +} diff --git a/src/codegen.zig b/src/codegen.zig index b602f5f..3bb2335 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -12,6 +12,10 @@ const errors = @import("errors.zig"); const sema = @import("sema.zig"); const comptime_mod = @import("comptime.zig"); const unescape = @import("unescape.zig"); +const ir = @import("ir/ir.zig"); + +/// Feature flag: use the IR interpreter for comptime evaluation instead of the bytecode VM. +const USE_IR_COMPTIME = true; pub const TargetConfig = struct { /// Target triple (e.g. "aarch64-apple-darwin"). Null = host default. @@ -1505,10 +1509,22 @@ pub const CodeGen = struct { } } - /// Evaluate a comptime expression using the bytecode VM. - /// No LLVM state save/restore needed — the VM operates independently. + /// Evaluate a comptime expression using the bytecode VM or the IR interpreter. + /// When USE_IR_COMPTIME is true, tries the IR interpreter first and falls back + /// to the bytecode VM if the interpreter can't handle the expression. fn comptimeEval(self: *CodeGen, expr: *Node, expected_type: Type) !comptime_mod.Value { - _ = expected_type; // VM infers types from values; expected_type used by caller for LLVM conversion + if (USE_IR_COMPTIME) { + if (self.tryIrComptimeEval(expr)) |result| { + return result; + } + // IR interpreter can't handle this expression — fall back to VM + } + return self.vmComptimeEval(expr, expected_type); + } + + /// Evaluate a comptime expression using the bytecode VM (original path). + fn vmComptimeEval(self: *CodeGen, expr: *Node, expected_type: Type) !comptime_mod.Value { + _ = expected_type; var compiler = comptime_mod.Compiler.init(self.allocator, if (self.sema_result) |sr| sr else null, self.root_decls, self); const chunk = compiler.compile(expr) catch |err| { @@ -1522,6 +1538,42 @@ pub const CodeGen = struct { }; } + /// Try to evaluate a comptime expression using the IR interpreter. + /// Returns null if the interpreter can't handle the expression (no diagnostics emitted). 
+    fn tryIrComptimeEval(self: *CodeGen, expr: *Node) ?comptime_mod.Value {
+        // Lower all root declarations so functions called by `expr` are available.
+        // NOTE(review): `module` and `interp` are never deinit'd; confirm the result
+        // does not borrow from them before adding `defer` cleanup.
+        var module = ir.Module.init(self.allocator);
+        var lowering = ir.Lowering.init(&module);
+        lowering.lowerDecls(self.root_decls);
+
+        // Wrap the expression in a synthetic comptime function and interpret it.
+        const func_id = lowering.createComptimeFunction("ct_eval", expr, .s64);
+        var interp = ir.Interpreter.init(&module, self.allocator);
+        const result = interp.call(func_id, &.{}) catch return null;
+        // Values with no comptime_mod.Value equivalent yet must not be reported as
+        // `void`: return null so comptimeEval falls back to the bytecode VM.
+        return switch (result) {
+            .aggregate, .undef, .slot_ptr, .func_ref, .closure, .type_tag => null,
+            else => irValueToComptimeValue(result),
+        };
+    }
+
+    /// Convert an IR interpreter value to a comptime module value.
+    fn irValueToComptimeValue(val: ir.Value) comptime_mod.Value {
+        return switch (val) {
+            .int => |v| .{ .int_val = v },
+            .float => |v| .{ .float_val = v },
+            .boolean => |v| .{ .bool_val = v },
+            .string => |v| .{ .string_val = v },
+            .void_val => .{ .void_val = {} },
+            .null_val => .{ .null_val = {} },
+            .aggregate => .{ .void_val = {} }, // TODO: struct/array conversion
+            .undef, .slot_ptr, .func_ref, .closure, .type_tag => .{ .void_val = {} },
+        };
+    }
+
     /// Try to evaluate a :: call expression entirely at compile time.
     /// Works for any function where all args are comptime-known.
     /// Returns the result string if successful, null to fall through to runtime codegen.
@@ -11683,10 +11735,10 @@ pub const CodeGen = struct { } pub fn printIR(self: *CodeGen) void { - const ir = c.LLVMPrintModuleToString(self.module); - defer c.LLVMDisposeMessage(ir); - const len = std.mem.len(ir); - std.debug.print("{s}\n", .{ir[0..len]}); + const ir_str = c.LLVMPrintModuleToString(self.module); + defer c.LLVMDisposeMessage(ir_str); + const len = std.mem.len(ir_str); + std.debug.print("{s}\n", .{ir_str[0..len]}); } fn emitToFile(self: *CodeGen, output_path: [*:0]const u8, file_type: c.LLVMCodeGenFileType) !void { diff --git a/src/core.zig b/src/core.zig index b36d006..0208a97 100644 --- a/src/core.zig +++ b/src/core.zig @@ -6,6 +6,7 @@ const sema = @import("sema.zig"); const codegen = @import("codegen.zig"); const errors = @import("errors.zig"); const c_import = @import("c_import.zig"); +const ir = @import("ir/ir.zig"); const Node = ast.Node; pub const TargetConfig = codegen.TargetConfig; @@ -112,6 +113,15 @@ pub const Compilation = struct { return c_import.collectCImportSources(self.allocator, root); } + /// Lower the parsed AST to the sx IR module (shadow pipeline). 
+ pub fn lowerToIR(self: *Compilation) ir.Module { + const root = self.resolved_root orelse self.root orelse return ir.Module.init(self.allocator); + var module = ir.Module.init(self.allocator); + var lowering = ir.Lowering.init(&module); + lowering.lowerRoot(root); + return module; + } + pub fn renderErrors(self: *const Compilation) void { self.diagnostics.renderDebug(); } diff --git a/src/ir/inst.test.zig b/src/ir/inst.test.zig new file mode 100644 index 0000000..a954312 --- /dev/null +++ b/src/ir/inst.test.zig @@ -0,0 +1,62 @@ +// Tests for inst.zig + +const std = @import("std"); +const types = @import("types.zig"); +const inst_mod = @import("inst.zig"); + +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const Inst = inst_mod.Inst; +const Block = inst_mod.Block; +const Function = inst_mod.Function; + +test "Ref none sentinel" { + try std.testing.expect(Ref.none.isNone()); + try std.testing.expect(!Ref.fromIndex(0).isNone()); +} + +test "basic instruction creation" { + const inst = Inst{ + .op = .{ .add = .{ .lhs = Ref.fromIndex(0), .rhs = Ref.fromIndex(1) } }, + .ty = .s32, + }; + try std.testing.expectEqual(types.TypeId.s32, inst.ty); + switch (inst.op) { + .add => |bin| { + try std.testing.expectEqual(Ref.fromIndex(0), bin.lhs); + try std.testing.expectEqual(Ref.fromIndex(1), bin.rhs); + }, + else => unreachable, + } +} + +test "block creation" { + const alloc = std.testing.allocator; + var block = Block.init(@enumFromInt(1), &.{}); + defer block.deinit(alloc); + + block.insts.append(alloc, .{ + .op = .{ .const_int = 42 }, + .ty = .s64, + }) catch unreachable; + block.insts.append(alloc, .{ + .op = .{ .ret = .{ .operand = Ref.fromIndex(0) } }, + .ty = .s64, + }) catch unreachable; + + try std.testing.expectEqual(@as(usize, 2), block.insts.items.len); +} + +test "function creation" { + const alloc = std.testing.allocator; + const params = &[_]Function.Param{ + .{ .name = @enumFromInt(1), .ty = .s32 }, + .{ .name = 
@enumFromInt(2), .ty = .s32 }, + }; + var func = Function.init(@enumFromInt(3), params, .s64); + defer func.deinit(alloc); + + try std.testing.expectEqual(types.TypeId.s64, func.ret); + try std.testing.expectEqual(@as(usize, 2), func.params.len); +} diff --git a/src/ir/inst.zig b/src/ir/inst.zig new file mode 100644 index 0000000..5175c02 --- /dev/null +++ b/src/ir/inst.zig @@ -0,0 +1,436 @@ +const std = @import("std"); +const types = @import("types.zig"); +const TypeId = types.TypeId; +const StringId = types.StringId; + +// ── Handles ───────────────────────────────────────────────────────────── + +/// Reference to an SSA value (instruction result). +pub const Ref = enum(u32) { + /// Sentinel for "no value" / unused operand. + none = std.math.maxInt(u32), + _, + + pub fn index(self: Ref) u32 { + return @intFromEnum(self); + } + + pub fn fromIndex(i: u32) Ref { + return @enumFromInt(i); + } + + pub fn isNone(self: Ref) bool { + return self == .none; + } +}; + +pub const BlockId = enum(u32) { + _, + + pub fn index(self: BlockId) u32 { + return @intFromEnum(self); + } + + pub fn fromIndex(i: u32) BlockId { + return @enumFromInt(i); + } +}; + +pub const FuncId = enum(u32) { + _, + + pub fn index(self: FuncId) u32 { + return @intFromEnum(self); + } + + pub fn fromIndex(i: u32) FuncId { + return @enumFromInt(i); + } +}; + +pub const GlobalId = enum(u32) { + _, + + pub fn index(self: GlobalId) u32 { + return @intFromEnum(self); + } + + pub fn fromIndex(i: u32) GlobalId { + return @enumFromInt(i); + } +}; + +// ── Span ──────────────────────────────────────────────────────────────── + +pub const Span = struct { + start: u32 = 0, + end: u32 = 0, +}; + +// ── Instruction ───────────────────────────────────────────────────────── + +pub const Inst = struct { + op: Op, + ty: TypeId, + span: Span = .{}, +}; + +// ── Op (tagged union) ─────────────────────────────────────────────────── + +pub const Op = union(enum) { + // ── Constants 
─────────────────────────────────────────────────── + const_int: i64, + const_float: f64, + const_bool: bool, + const_string: StringId, + const_null, + const_undef, // `---` undefined initializer + + // ── Arithmetic ────────────────────────────────────────────────── + add: BinOp, + sub: BinOp, + mul: BinOp, + div: BinOp, + mod: BinOp, + neg: UnaryOp, // unary -x + + // ── Bitwise ───────────────────────────────────────────────────── + bit_and: BinOp, + bit_or: BinOp, + bit_xor: BinOp, + bit_not: UnaryOp, + shl: BinOp, + shr: BinOp, + + // ── Comparison ────────────────────────────────────────────────── + cmp_eq: BinOp, + cmp_ne: BinOp, + cmp_lt: BinOp, + cmp_le: BinOp, + cmp_gt: BinOp, + cmp_ge: BinOp, + + // ── Logical ───────────────────────────────────────────────────── + bool_and: BinOp, // short-circuit && + bool_or: BinOp, // short-circuit || + bool_not: UnaryOp, + + // ── Conversions ───────────────────────────────────────────────── + widen: Conversion, // safe widening (s32 → s64) + narrow: Conversion, // truncation via `xx` (s64 → s32) + bitcast: Conversion, // reinterpret bits + int_to_float: Conversion, + float_to_int: Conversion, + + // ── Memory ────────────────────────────────────────────────────── + alloca: TypeId, // stack allocation, result is *T + load: UnaryOp, // load from pointer + store: Store, // store value to pointer + heap_alloc: UnaryOp, // context.allocator.alloc(size) → *void + heap_free: UnaryOp, // context.allocator.free(ptr) + + // ── Struct ops ────────────────────────────────────────────────── + struct_init: Aggregate, // construct struct from field values + struct_get: FieldAccess, // read struct field by index + struct_gep: FieldAccess, // get pointer to struct field (GEP) + + // ── Enum ops ──────────────────────────────────────────────────── + enum_init: EnumInit, // construct enum value (tag + optional payload) + enum_tag: UnaryOp, // extract tag from enum/union + enum_payload: FieldAccess, // extract payload from tagged 
union + + // ── Union ops ─────────────────────────────────────────────────── + union_get: FieldAccess, // read union field (reinterpret) + union_gep: FieldAccess, // pointer to union field + + // ── Array/Slice ops ───────────────────────────────────────────── + index_get: BinOp, // arr[idx] → value + index_gep: BinOp, // &arr[idx] → pointer + length: UnaryOp, // .len on slice/string/array + data_ptr: UnaryOp, // .ptr on slice/string + subslice: Subslice, // arr[lo..hi] + array_to_slice: UnaryOp, // [N]T → []T + + // ── Tuple ops ─────────────────────────────────────────────────── + tuple_init: Aggregate, // construct tuple from values + tuple_get: FieldAccess, // read tuple element by index + + // ── Optional ops ──────────────────────────────────────────────── + optional_wrap: UnaryOp, // T → ?T + optional_unwrap: UnaryOp, // ?T → T (UB if null) + optional_has_value: UnaryOp, // ?T → bool + optional_coalesce: BinOp, // a ?? b + + // ── Pointer ops ───────────────────────────────────────────────── + addr_of: UnaryOp, // @x → *T + deref: UnaryOp, // p.* → T + + // ── Vector ops ────────────────────────────────────────────────── + vec_splat: UnaryOp, // scalar → vector (broadcast) + vec_extract: BinOp, // vec[idx] → scalar + vec_insert: TriOp, // vec, idx, val → new_vec + + // ── Calls ─────────────────────────────────────────────────────── + call: Call, + call_indirect: CallIndirect, + call_closure: CallIndirect, + call_builtin: BuiltinCall, + + // ── Protocol dispatch ─────────────────────────────────────────── + protocol_call_dynamic: ProtocolCall, // vtable/inline dispatch + protocol_erase: ProtocolErase, // concrete → protocol value (xx) + + // ── Closure creation ──────────────────────────────────────────── + closure_create: ClosureCreate, + + // ── Context ───────────────────────────────────────────────────── + context_load: ContextOp, // read context field + context_store: ContextOp, // write context field + context_save, // save context state (for push) + 
context_restore: UnaryOp, // restore context state (after push) + + // ── Globals ───────────────────────────────────────────────────── + global_get: GlobalId, + global_set: GlobalSet, + + // ── Block params (SSA phi alternative) ────────────────────────── + block_param: BlockParam, + + // ── Any type ──────────────────────────────────────────────────── + box_any: BoxAny, // T → Any (erase type) + unbox_any: UnaryOp, // Any → T (restore type) + + // ── Terminators ───────────────────────────────────────────────── + br: Branch, + cond_br: CondBranch, + switch_br: SwitchBranch, + ret: UnaryOp, + ret_void, + @"unreachable", + + // ── Misc ──────────────────────────────────────────────────────── + /// No-op placeholder for unlowered AST nodes. + placeholder: StringId, // name of the unlowered construct +}; + +// ── Operand structs ───────────────────────────────────────────────────── + +pub const UnaryOp = struct { + operand: Ref, +}; + +pub const BinOp = struct { + lhs: Ref, + rhs: Ref, +}; + +pub const TriOp = struct { + a: Ref, + b: Ref, + c: Ref, +}; + +pub const Store = struct { + ptr: Ref, + val: Ref, +}; + +pub const Conversion = struct { + operand: Ref, + from: TypeId, + to: TypeId, +}; + +pub const FieldAccess = struct { + base: Ref, + field_index: u32, +}; + +pub const Aggregate = struct { + fields: []const Ref, +}; + +pub const EnumInit = struct { + tag: u32, + payload: Ref, // Ref.none if no payload +}; + +pub const Subslice = struct { + base: Ref, + lo: Ref, + hi: Ref, +}; + +pub const Call = struct { + callee: FuncId, + args: []const Ref, +}; + +pub const CallIndirect = struct { + callee: Ref, + args: []const Ref, +}; + +pub const BuiltinCall = struct { + builtin: BuiltinId, + args: []const Ref, +}; + +pub const BuiltinId = enum(u16) { + print, + out, + sqrt, + size_of, + cast, + malloc, + free, + memcpy, + memset, +}; + +pub const ProtocolCall = struct { + receiver: Ref, // protocol value (ctx + vtable/fn_ptrs) + method_index: u32, + args: []const Ref, 
+}; + +pub const ProtocolErase = struct { + concrete: Ref, + protocol_type: TypeId, +}; + +pub const ClosureCreate = struct { + func: FuncId, // trampoline function + env: Ref, // allocated env pointer (or Ref.none for no captures) +}; + +pub const ContextOp = struct { + field: StringId, + value: Ref, // Ref.none for loads +}; + +pub const GlobalSet = struct { + global: GlobalId, + value: Ref, +}; + +pub const BlockParam = struct { + block: BlockId, + param_index: u32, +}; + +pub const BoxAny = struct { + operand: Ref, + source_type: TypeId, +}; + +pub const Branch = struct { + target: BlockId, + args: []const Ref, // block param values +}; + +pub const CondBranch = struct { + cond: Ref, + then_target: BlockId, + then_args: []const Ref, + else_target: BlockId, + else_args: []const Ref, +}; + +pub const SwitchBranch = struct { + operand: Ref, + cases: []const Case, + default: BlockId, + default_args: []const Ref, + + pub const Case = struct { + value: i64, + target: BlockId, + args: []const Ref, + }; +}; + +// ── Block ─────────────────────────────────────────────────────────────── + +pub const Block = struct { + name: StringId, + params: []const TypeId, // block parameter types (SSA phi alternative) + insts: std.ArrayList(Inst), + + pub fn init(name: StringId, params: []const TypeId) Block { + return .{ + .name = name, + .params = params, + .insts = std.ArrayList(Inst).empty, + }; + } + + pub fn deinit(self: *Block, alloc: std.mem.Allocator) void { + self.insts.deinit(alloc); + } +}; + +// ── Function ──────────────────────────────────────────────────────────── + +pub const Function = struct { + name: StringId, + params: []const Param, + ret: TypeId, + blocks: std.ArrayList(Block), + is_extern: bool = false, + is_comptime: bool = false, + linkage: Linkage = .internal, + + pub const Param = struct { + name: StringId, + ty: TypeId, + }; + + pub const Linkage = enum { + internal, + external, + private, + }; + + pub fn init(name: StringId, params: []const Param, ret: 
TypeId) Function { + return .{ + .name = name, + .params = params, + .ret = ret, + .blocks = std.ArrayList(Block).empty, + }; + } + + pub fn deinit(self: *Function, alloc: std.mem.Allocator) void { + for (self.blocks.items) |*block| { + block.deinit(alloc); + } + self.blocks.deinit(alloc); + } +}; + +// ── Global ────────────────────────────────────────────────────────────── + +pub const Global = struct { + name: StringId, + ty: TypeId, + init_val: ?ConstantValue = null, + is_extern: bool = false, + is_const: bool = false, + /// For comptime globals: the function to interpret to get the init value. + comptime_func: ?FuncId = null, +}; + +// ── ConstantValue ─────────────────────────────────────────────────────── + +pub const ConstantValue = union(enum) { + int: i64, + float: f64, + boolean: bool, + string: StringId, + null_val, + undef, + zeroinit, + aggregate: []const ConstantValue, +}; + diff --git a/src/ir/interp.test.zig b/src/ir/interp.test.zig new file mode 100644 index 0000000..8e62d66 --- /dev/null +++ b/src/ir/interp.test.zig @@ -0,0 +1,649 @@ +// Tests for the IR interpreter (interp.zig). +// Includes basic interpreter tests and comptime parity tests. 
+ +const std = @import("std"); +const types = @import("types.zig"); +const inst_mod = @import("inst.zig"); +const mod_mod = @import("module.zig"); +const interp_mod = @import("interp.zig"); + +const TypeId = types.TypeId; +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const Function = inst_mod.Function; +const Module = mod_mod.Module; +const Builder = mod_mod.Builder; +const Interpreter = interp_mod.Interpreter; +const Value = interp_mod.Value; + +// ── Helper ────────────────────────────────────────────────────────────── + +fn str(module: *Module, s: []const u8) types.StringId { + return module.types.internString(s); +} + +// ── Basic interpreter tests (migrated from interp.zig) ────────────────── + +test "interpret: compute(5) = 25" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + + var b = Builder.init(&module); + + // func compute(x: s64) -> s64 { return x * x; } + const params = &[_]Function.Param{.{ .name = str(&module, "compute"), .ty = .s64 }}; + _ = b.beginFunction(str(&module, "compute"), params, .s64); + const entry = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry); + + const x_ref = Ref.fromIndex(0); + const result = b.mul(x_ref, x_ref, .s64); + b.ret(result, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + + const val = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 5 }}); + try std.testing.expectEqual(@as(i64, 25), val.asInt().?); +} + +test "interpret: if/else branching" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + + var b = Builder.init(&module); + + const params = &[_]Function.Param{.{ .name = str(&module, "x"), .ty = .s64 }}; + _ = b.beginFunction(str(&module, "abs"), params, .s64); + + const entry = b.appendBlock(str(&module, "entry"), &.{}); + const then_bb = b.appendBlock(str(&module, "then"), &.{}); + const 
else_bb = b.appendBlock(str(&module, "else"), &.{}); + + b.switchToBlock(entry); + const x = Ref.fromIndex(0); + const zero = b.constInt(0, .s64); + const is_neg = b.cmpLt(x, zero); + b.condBr(is_neg, then_bb, &.{}, else_bb, &.{}); + + b.switchToBlock(then_bb); + const neg_x = b.emit(.{ .neg = .{ .operand = x } }, .s64); + b.ret(neg_x, .s64); + + b.switchToBlock(else_bb); + b.ret(x, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + + const val1 = try interp.call(FuncId.fromIndex(0), &.{.{ .int = -7 }}); + try std.testing.expectEqual(@as(i64, 7), val1.asInt().?); + + const val2 = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 3 }}); + try std.testing.expectEqual(@as(i64, 3), val2.asInt().?); +} + +test "interpret: function calling another function" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + + var b = Builder.init(&module); + + // func square(x: s64) -> s64 { return x * x; } + const params_sq = &[_]Function.Param{.{ .name = str(&module, "x"), .ty = .s64 }}; + _ = b.beginFunction(str(&module, "square"), params_sq, .s64); + const entry1 = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry1); + const x = Ref.fromIndex(0); + const sq = b.mul(x, x, .s64); + b.ret(sq, .s64); + b.finalize(); + + // func sum_of_squares(a, b) -> s64 { return square(a) + square(b); } + const params_ss = &[_]Function.Param{ + .{ .name = str(&module, "a"), .ty = .s64 }, + .{ .name = str(&module, "b"), .ty = .s64 }, + }; + _ = b.beginFunction(str(&module, "sum_of_squares"), params_ss, .s64); + const entry2 = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry2); + const a = Ref.fromIndex(0); + const b_param = Ref.fromIndex(1); + const sq_a = b.call(FuncId.fromIndex(0), &.{a}, .s64); + const sq_b = b.call(FuncId.fromIndex(0), &.{b_param}, .s64); + const sum = b.add(sq_a, sq_b, .s64); + b.ret(sum, .s64); + b.finalize(); + + var interp = 
Interpreter.init(&module, alloc); + defer interp.deinit(); + + const val = try interp.call(FuncId.fromIndex(1), &.{ .{ .int = 3 }, .{ .int = 4 } }); + try std.testing.expectEqual(@as(i64, 25), val.asInt().?); +} + +test "interpret: alloca/store/load" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + + var b = Builder.init(&module); + + _ = b.beginFunction(str(&module, "test"), &.{}, .s64); + const entry = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry); + + const slot = b.alloca(.s64); + const ten = b.constInt(10, .s64); + b.store(slot, ten); + const loaded = b.load(slot, .s64); + const five = b.constInt(5, .s64); + const sum = b.add(loaded, five, .s64); + b.store(slot, sum); + const result = b.load(slot, .s64); + b.ret(result, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + + const val = try interp.call(FuncId.fromIndex(0), &.{}); + try std.testing.expectEqual(@as(i64, 15), val.asInt().?); +} + +// ── Comptime parity tests ─────────────────────────────────────────────── + +// ── Test: while loop (sumOf10 from 15-while.sx) ───────────────────────── +// sumOf10 :: () -> s32 { i:=1; s:=0; while i<=10 { s+=i; i+=1; } s; } +// Expected: 55 + +test "comptime: while loop — sumOf10 = 55" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + var b = Builder.init(&module); + + _ = b.beginFunction(str(&module, "sumOf10"), &.{}, .s64); + + const entry = b.appendBlock(str(&module, "entry"), &.{}); + const hdr = b.appendBlock(str(&module, "while.hdr"), &.{}); + const body = b.appendBlock(str(&module, "while.body"), &.{}); + const exit = b.appendBlock(str(&module, "while.exit"), &.{}); + + // entry: i=1, s=0, br while.hdr + b.switchToBlock(entry); + const i_slot = b.alloca(.s64); + const one = b.constInt(1, .s64); + b.store(i_slot, one); + const s_slot = b.alloca(.s64); + const zero = b.constInt(0, .s64); 
+ b.store(s_slot, zero); + b.br(hdr, &.{}); + + // while.hdr: if i <= 10 → body, else → exit + b.switchToBlock(hdr); + const i_load = b.load(i_slot, .s64); + const ten = b.constInt(10, .s64); + const cond = b.emit(.{ .cmp_le = .{ .lhs = i_load, .rhs = ten } }, .bool); + b.condBr(cond, body, &.{}, exit, &.{}); + + // while.body: s += i; i += 1; br while.hdr + b.switchToBlock(body); + const s_load = b.load(s_slot, .s64); + const i_load2 = b.load(i_slot, .s64); + const s_new = b.add(s_load, i_load2, .s64); + b.store(s_slot, s_new); + const i_load3 = b.load(i_slot, .s64); + const one2 = b.constInt(1, .s64); + const i_new = b.add(i_load3, one2, .s64); + b.store(i_slot, i_new); + b.br(hdr, &.{}); + + // while.exit: return s + b.switchToBlock(exit); + const s_final = b.load(s_slot, .s64); + b.ret(s_final, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + const val = try interp.call(FuncId.fromIndex(0), &.{}); + try std.testing.expectEqual(@as(i64, 55), val.asInt().?); +} + +// ── Test: optional coalesce (ct_sum from 32-optionals.sx) ──────────────── +// ct_sum :: () -> s32 { x:?s32=42; y:?s32=null; return (x??0)+(y??99); } +// Expected: 42 + 99 = 141 + +test "comptime: optional coalesce — ct_sum = 141" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + var b = Builder.init(&module); + + _ = b.beginFunction(str(&module, "ct_sum"), &.{}, .s64); + const entry = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry); + + // x: ?s32 = 42 → alloca, store 42 + const x_slot = b.alloca(.s64); + const forty_two = b.constInt(42, .s64); + b.store(x_slot, forty_two); + + // y: ?s32 = null → alloca, store null + const y_slot = b.alloca(.s64); + const null_val = b.constNull(.s64); + b.store(y_slot, null_val); + + // (x ?? 
0) + const x_load = b.load(x_slot, .s64); + const zero = b.constInt(0, .s64); + const x_coalesced = b.emit(.{ .optional_coalesce = .{ .lhs = x_load, .rhs = zero } }, .s64); + + // (y ?? 99) + const y_load = b.load(y_slot, .s64); + const ninety_nine = b.constInt(99, .s64); + const y_coalesced = b.emit(.{ .optional_coalesce = .{ .lhs = y_load, .rhs = ninety_nine } }, .s64); + + // return x_coalesced + y_coalesced + const sum = b.add(x_coalesced, y_coalesced, .s64); + b.ret(sum, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + const val = try interp.call(FuncId.fromIndex(0), &.{}); + try std.testing.expectEqual(@as(i64, 141), val.asInt().?); +} + +// ── Test: optional unwrap (ct_opt_unwrap from 50-smoke.sx) ─────────────── +// ct_opt_unwrap :: () -> s32 { x:?s32 = 77; return x!; } +// Expected: 77 + +test "comptime: optional unwrap — 77" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + var b = Builder.init(&module); + + _ = b.beginFunction(str(&module, "ct_opt_unwrap"), &.{}, .s64); + const entry = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry); + + const slot = b.alloca(.s64); + const val77 = b.constInt(77, .s64); + b.store(slot, val77); + + const loaded = b.load(slot, .s64); + const unwrapped = b.optionalUnwrap(loaded, .s64); + b.ret(unwrapped, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + const val = try interp.call(FuncId.fromIndex(0), &.{}); + try std.testing.expectEqual(@as(i64, 77), val.asInt().?); +} + +// ── Test: recursive fibonacci ──────────────────────────────────────────── +// fib :: (n: s64) -> s64 { if n <= 1 return n; return fib(n-1) + fib(n-2); } +// Expected: fib(10) = 55 + +test "comptime: recursive fibonacci — fib(10) = 55" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + var b = Builder.init(&module); + + const 
params = &[_]Function.Param{.{ .name = str(&module, "n"), .ty = .s64 }}; + _ = b.beginFunction(str(&module, "fib"), params, .s64); + + const entry = b.appendBlock(str(&module, "entry"), &.{}); + const base_bb = b.appendBlock(str(&module, "base"), &.{}); + const rec_bb = b.appendBlock(str(&module, "recurse"), &.{}); + + // entry: if n <= 1 → base, else → recurse + b.switchToBlock(entry); + const n = Ref.fromIndex(0); + const one = b.constInt(1, .s64); + const is_base = b.emit(.{ .cmp_le = .{ .lhs = n, .rhs = one } }, .bool); + b.condBr(is_base, base_bb, &.{}, rec_bb, &.{}); + + // base: return n + b.switchToBlock(base_bb); + b.ret(n, .s64); + + // recurse: return fib(n-1) + fib(n-2) + b.switchToBlock(rec_bb); + const n_minus_1 = b.sub(n, one, .s64); + const two = b.constInt(2, .s64); + const n_minus_2 = b.sub(n, two, .s64); + const fib1 = b.call(FuncId.fromIndex(0), &.{n_minus_1}, .s64); + const fib2 = b.call(FuncId.fromIndex(0), &.{n_minus_2}, .s64); + const sum = b.add(fib1, fib2, .s64); + b.ret(sum, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + defer interp.deinit(); + const val = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 10 }}); + try std.testing.expectEqual(@as(i64, 55), val.asInt().?); +} + +// ── Test: compute(5) = 7 (from 05-run.sx) ────────────────────────────── +// compute :: (v: s32) -> s32 => v + 2; +// Expected: compute(5) = 7 + +test "comptime: compute(5) = 7" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + var b = Builder.init(&module); + + const params = &[_]Function.Param{.{ .name = str(&module, "v"), .ty = .s64 }}; + _ = b.beginFunction(str(&module, "compute"), params, .s64); + const entry = b.appendBlock(str(&module, "entry"), &.{}); + b.switchToBlock(entry); + + const v = Ref.fromIndex(0); + const two = b.constInt(2, .s64); + const result = b.add(v, two, .s64); + b.ret(result, .s64); + b.finalize(); + + var interp = Interpreter.init(&module, alloc); + 
defer interp.deinit();
+    const val = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 5 }});
+    try std.testing.expectEqual(@as(i64, 7), val.asInt().?);
+}
+
+// ── Test: chained comptime (CT_CHAIN from 50-smoke.sx) ───────────────────
+// add :: (a: s32, b: s32) -> s32 => a + b;
+// CT_VAL :: #run add(10, 15); → 25
+// CT_CHAIN :: #run add(CT_VAL, 5); → 30
+// Simulates calling add(25, 5) to verify chaining works.
+
+test "comptime: chained — add(add(10,15), 5) = 30" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    // func add(a, b) -> s64 { return a + b; }
+    const params = &[_]Function.Param{
+        .{ .name = str(&module, "a"), .ty = .s64 },
+        .{ .name = str(&module, "b"), .ty = .s64 },
+    };
+    _ = b.beginFunction(str(&module, "add"), params, .s64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+    b.switchToBlock(entry);
+    // Refs 0 and 1 are the two parameters; `b_ref` avoids shadowing the builder `b`.
+    const a = Ref.fromIndex(0);
+    const b_ref = Ref.fromIndex(1);
+    const sum = b.add(a, b_ref, .s64);
+    b.ret(sum, .s64);
+    b.finalize();
+
+    // One interpreter instance serves both calls.
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+
+    // First: add(10, 15) = 25
+    const ct_val = try interp.call(FuncId.fromIndex(0), &.{ .{ .int = 10 }, .{ .int = 15 } });
+    try std.testing.expectEqual(@as(i64, 25), ct_val.asInt().?);
+
+    // Then: add(25, 5) = 30 (chained) — the Value returned above is passed straight back in.
+    const ct_chain = try interp.call(FuncId.fromIndex(0), &.{ ct_val, .{ .int = 5 } });
+    try std.testing.expectEqual(@as(i64, 30), ct_chain.asInt().?);
+}
+
+// ── Test: struct init + field access ─────────────────────────────────────
+// p := Point{x: 3, y: 4}; return p.x + p.y;
+// Expected: 7
+
+test "comptime: struct init and field access — 7" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    _ = b.beginFunction(str(&module, "test_struct"), &.{}, .s64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+
b.switchToBlock(entry);
+
+    // Point{x: 3, y: 4}
+    const three = b.constInt(3, .s64);
+    const four = b.constInt(4, .s64);
+    const point = b.structInit(&.{ three, four }, .s64);
+
+    // p.x + p.y
+    const px = b.structGet(point, 0, .s64);
+    const py = b.structGet(point, 1, .s64);
+    const sum = b.add(px, py, .s64);
+    b.ret(sum, .s64);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+    const val = try interp.call(FuncId.fromIndex(0), &.{});
+    try std.testing.expectEqual(@as(i64, 7), val.asInt().?);
+}
+
+// ── Test: float arithmetic ──────────────────────────────────────────────
+// compute :: (x: f64) -> f64 { return x * 2.5 + 1.0; }
+// Expected: compute(3.0) = 8.5
+
+test "comptime: float arithmetic — 8.5" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    const params = &[_]Function.Param{.{ .name = str(&module, "x"), .ty = .f64 }};
+    _ = b.beginFunction(str(&module, "compute_f"), params, .f64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+    b.switchToBlock(entry);
+
+    const x = Ref.fromIndex(0);
+    const two_five = b.constFloat(2.5, .f64);
+    const product = b.mul(x, two_five, .f64);
+    const one = b.constFloat(1.0, .f64);
+    const result = b.add(product, one, .f64);
+    b.ret(result, .f64);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+    const val = try interp.call(FuncId.fromIndex(0), &.{.{ .float = 3.0 }});
+    try std.testing.expectEqual(@as(f64, 8.5), val.asFloat().?);
+}
+
+// ── Test: boolean logic ─────────────────────────────────────────────────
+// test :: (a: bool, b: bool) -> bool { return (a and b) or (not a); }
+// Expected:
+//   test(true, false) → (true and false) or (not true) = false or false = false
+//   test(false, true) → (false and true) or (not false) = false or true = true
+
+test "comptime: boolean logic" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    const params = &[_]Function.Param{
+        .{ .name = str(&module, "a"), .ty = .bool },
+        .{ .name = str(&module, "b"), .ty = .bool },
+    };
+    _ = b.beginFunction(str(&module, "bool_test"), params, .bool);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+    b.switchToBlock(entry);
+
+    // (a and b) or (not a), built as three separate instructions.
+    const a_ref = Ref.fromIndex(0);
+    const b_ref = Ref.fromIndex(1);
+    const and_ab = b.emit(.{ .bool_and = .{ .lhs = a_ref, .rhs = b_ref } }, .bool);
+    const not_a = b.emit(.{ .bool_not = .{ .operand = a_ref } }, .bool);
+    const result = b.emit(.{ .bool_or = .{ .lhs = and_ab, .rhs = not_a } }, .bool);
+    b.ret(result, .bool);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+
+    // test(true, false) = false or false = false
+    const val1 = try interp.call(FuncId.fromIndex(0), &.{ .{ .boolean = true }, .{ .boolean = false } });
+    try std.testing.expectEqual(false, val1.asBool().?);
+
+    // test(false, true) = false or true = true
+    const val2 = try interp.call(FuncId.fromIndex(0), &.{ .{ .boolean = false }, .{ .boolean = true } });
+    try std.testing.expectEqual(true, val2.asBool().?);
+}
+
+// ── Test: negation ──────────────────────────────────────────────────────
+
+test "comptime: negation — int and float" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    // func neg_int(x: s64) -> s64 { return -x; }
+    const params = &[_]Function.Param{.{ .name = str(&module, "x"), .ty = .s64 }};
+    _ = b.beginFunction(str(&module, "neg_int"), params, .s64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+
b.switchToBlock(entry);
+    const x = Ref.fromIndex(0);
+    const neg = b.emit(.{ .neg = .{ .operand = x } }, .s64);
+    b.ret(neg, .s64);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+    // Only the integer path of `.neg` is exercised here (an int argument, an s64 function).
+    const val = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 42 }});
+    try std.testing.expectEqual(@as(i64, -42), val.asInt().?);
+}
+
+// ── Test: modulo ────────────────────────────────────────────────────────
+
+test "comptime: modulo — 17 mod 5 = 2" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    _ = b.beginFunction(str(&module, "test_mod"), &.{}, .s64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+    b.switchToBlock(entry);
+
+    // Both operands are constants, so the function takes no parameters.
+    const seventeen = b.constInt(17, .s64);
+    const five = b.constInt(5, .s64);
+    const result = b.emit(.{ .mod = .{ .lhs = seventeen, .rhs = five } }, .s64);
+    b.ret(result, .s64);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+    const val = try interp.call(FuncId.fromIndex(0), &.{});
+    try std.testing.expectEqual(@as(i64, 2), val.asInt().?);
+}
+
+// ── Test: switch_br (enum tag dispatch) ──────────────────────────────────
+// Simulates: match tag { 0 => 10, 1 => 20, else => 30 }
+
+test "comptime: switch_br dispatch" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    const params = &[_]Function.Param{.{ .name = str(&module, "tag"), .ty = .s64 }};
+    _ = b.beginFunction(str(&module, "dispatch"), params, .s64);
+
+    // One block per case plus a default; none of them takes block arguments.
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+    const case0 = b.appendBlock(str(&module, "case0"), &.{});
+    const case1 = b.appendBlock(str(&module, "case1"), &.{});
+    const default = b.appendBlock(str(&module, "default"), &.{});
+
+    b.switchToBlock(entry);
+    const tag = Ref.fromIndex(0);
+    b.switchBr(tag, &.{
+        .{ .value = 0, .target = case0, .args = &.{} },
+
.{ .value = 1, .target = case1, .args = &.{} },
+    }, default, &.{});
+
+    b.switchToBlock(case0);
+    const ten = b.constInt(10, .s64);
+    b.ret(ten, .s64);
+
+    b.switchToBlock(case1);
+    const twenty = b.constInt(20, .s64);
+    b.ret(twenty, .s64);
+
+    b.switchToBlock(default);
+    const thirty = b.constInt(30, .s64);
+    b.ret(thirty, .s64);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+
+    // tag 0 → case0, tag 1 → case1, any other tag (99 here) → default.
+    const v0 = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 0 }});
+    try std.testing.expectEqual(@as(i64, 10), v0.asInt().?);
+
+    const v1 = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 1 }});
+    try std.testing.expectEqual(@as(i64, 20), v1.asInt().?);
+
+    const v2 = try interp.call(FuncId.fromIndex(0), &.{.{ .int = 99 }});
+    try std.testing.expectEqual(@as(i64, 30), v2.asInt().?);
+}
+
+// ── Test: enum init + tag extraction ────────────────────────────────────
+
+test "comptime: enum init and tag" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    _ = b.beginFunction(str(&module, "test_enum"), &.{}, .s64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+    b.switchToBlock(entry);
+
+    // Create enum with tag=2, no payload
+    // (with Ref.none as payload the interpreter represents the enum as a bare int tag)
+    const e = b.enumInit(2, Ref.none, .s64);
+    const tag = b.emit(.{ .enum_tag = .{ .operand = e } }, .s64);
+    b.ret(tag, .s64);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+    const val = try interp.call(FuncId.fromIndex(0), &.{});
+    try std.testing.expectEqual(@as(i64, 2), val.asInt().?);
+}
+
+// ── Test: conversion (widen/narrow passthrough) ─────────────────────────
+
+test "comptime: widen/narrow passthrough" {
+    const alloc = std.testing.allocator;
+    var module = Module.init(alloc);
+    defer module.deinit();
+    var b = Builder.init(&module);
+
+    _ = b.beginFunction(str(&module, "test_conv"), &.{}, .s64);
+    const entry = b.appendBlock(str(&module, "entry"), &.{});
+
b.switchToBlock(entry);
+
+    // 42 is widened s32 → s64 and narrowed back; the interpreter passes the value through unchanged.
+    const val = b.constInt(42, .s32);
+    const widened = b.emit(.{ .widen = .{ .operand = val, .from = .s32, .to = .s64 } }, .s64);
+    const narrowed = b.emit(.{ .narrow = .{ .operand = widened, .from = .s64, .to = .s32 } }, .s32);
+    b.ret(narrowed, .s32);
+    b.finalize();
+
+    var interp = Interpreter.init(&module, alloc);
+    defer interp.deinit();
+    const result = try interp.call(FuncId.fromIndex(0), &.{});
+    try std.testing.expectEqual(@as(i64, 42), result.asInt().?);
+}
diff --git a/src/ir/interp.zig b/src/ir/interp.zig
new file mode 100644
index 0000000..7b338ef
--- /dev/null
+++ b/src/ir/interp.zig
@@ -0,0 +1,541 @@
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+const types = @import("types.zig");
+const inst_mod = @import("inst.zig");
+const mod_mod = @import("module.zig");
+
+const TypeId = types.TypeId;
+const TypeTable = types.TypeTable;
+const StringId = types.StringId;
+const Ref = inst_mod.Ref;
+const BlockId = inst_mod.BlockId;
+const FuncId = inst_mod.FuncId;
+const Inst = inst_mod.Inst;
+const Op = inst_mod.Op;
+const Function = inst_mod.Function;
+const Block = inst_mod.Block;
+const Module = mod_mod.Module;
+const Builder = mod_mod.Builder;
+
+// ── Value ───────────────────────────────────────────────────────────────
+
+/// A value produced while interpreting IR. Every integer is carried as an
+/// i64 and every float as an f64, whatever the IR type of the instruction.
+pub const Value = union(enum) {
+    int: i64,
+    float: f64,
+    boolean: bool,
+    string: []const u8,
+    null_val,
+    void_val,
+    undef,
+    // Struct fields, or `{ tag, payload }` for an enum that carries a payload.
+    aggregate: []const Value,
+    slot_ptr: u32, // index into the frame's local slots
+    func_ref: FuncId,
+    closure: ClosureVal,
+    type_tag: TypeId,
+
+    pub const ClosureVal = struct {
+        func: FuncId,
+        env: ?[]const Value,
+    };
+
+    /// Returns the payload when the value is an `.int`, otherwise null.
+    pub fn asInt(self: Value) ?i64 {
+        return switch (self) {
+            .int => |v| v,
+            else => null,
+        };
+    }
+
+    /// Returns the value as an f64. An `.int` is converted as well, so
+    /// callers that must tell ints from floats need to try `asInt` first.
+    pub fn asFloat(self: Value) ?f64 {
+        return switch (self) {
+            .float => |v| v,
+            .int => |v| @floatFromInt(v), // implicit int→float for convenience
+            else => null,
+        };
+    }
+
+    /// Returns the payload when the value is a `.boolean`, otherwise null.
+    pub fn asBool(self: Value) ?bool {
+        return switch (self)
{
+            .boolean => |v| v,
+            else => null,
+        };
+    }
+
+    /// True only for the `.null_val` tag.
+    pub fn isNull(self: Value) bool {
+        return self == .null_val;
+    }
+};
+
+// ── Error ───────────────────────────────────────────────────────────────
+
+/// Everything that can abort a comptime evaluation.
+pub const InterpError = error{
+    CannotEvalComptime,
+    TypeError,
+    OutOfBounds,
+    DivisionByZero,
+    StackOverflow,
+    Unreachable,
+};
+
+// ── Interpreter ─────────────────────────────────────────────────────────
+
+/// Walks the IR of a `Module` and evaluates functions directly.
+pub const Interpreter = struct {
+    module: *const Module,
+    alloc: Allocator,
+    output: std.ArrayList(u8),
+    call_depth: u32 = 0,
+    max_call_depth: u32 = 256,
+
+    pub fn init(module: *const Module, alloc: Allocator) Interpreter {
+        return .{
+            .module = module,
+            .alloc = alloc,
+            .output = std.ArrayList(u8).empty,
+        };
+    }
+
+    pub fn deinit(self: *Interpreter) void {
+        self.output.deinit(self.alloc);
+    }
+
+    /// Evaluates `func_id` with `args` and returns its result.
+    ///
+    /// Errors: `StackOverflow` once `max_call_depth` nested calls are active;
+    /// `CannotEvalComptime` for extern functions and functions without
+    /// blocks; otherwise whatever the executed instructions report.
+    pub fn call(self: *Interpreter, func_id: FuncId, args: []const Value) InterpError!Value {
+        if (self.call_depth >= self.max_call_depth) return error.StackOverflow;
+        self.call_depth += 1;
+        defer self.call_depth -= 1;
+
+        const func = self.module.getFunction(func_id);
+        if (func.is_extern) return error.CannotEvalComptime;
+        if (func.blocks.items.len == 0) return error.CannotEvalComptime;
+
+        var frame = Frame.init(self.alloc);
+        defer frame.deinit();
+
+        // Bind parameters as initial refs
+        for (args) |arg| {
+            frame.pushRef(self.alloc, arg);
+        }
+
+        // Start at the entry block (index 0)
+        var current_block: BlockId = BlockId.fromIndex(0);
+        var block_args: []const Value = &.{};
+
+        while (true) {
+            const block = &func.blocks.items[current_block.index()];
+
+            // Bind block params
+            for (block_args) |arg| {
+                frame.pushRef(self.alloc, arg);
+            }
+            // The slice was allocated from self.alloc by the terminator that
+            // branched here and its values are now copied into the frame, so
+            // release it; nothing else frees it. The initial `&.{}` and
+            // zero-argument branches have length 0 and are skipped.
+            if (block_args.len != 0) self.alloc.free(block_args);
+            block_args = &.{};
+
+            for (block.insts.items) |*instruction| {
+                const result = try self.execInst(instruction, &frame, &current_block, &block_args);
+                switch (result) {
+                    .value => |val| frame.pushRef(self.alloc, val),
+                    .branch => break, // current_block and block_args updated by execInst
+                    .ret_val => |val| return val,
+
.ret_nothing => return .void_val,
+                }
+            } else {
+                // Fell through the block with no terminator — treat as implicit return void
+                return .void_val;
+            }
+        }
+    }
+
+    /// What a single instruction did: produced a value, transferred control
+    /// to another block, or returned from the function.
+    const ExecResult = union(enum) {
+        value: Value,
+        branch,
+        ret_val: Value,
+        ret_nothing,
+    };
+
+    /// Executes one instruction against `frame`. Terminators that jump write
+    /// the destination into `current_block` and a freshly allocated argument
+    /// slice into `block_args`, then return `.branch`.
+    fn execInst(self: *Interpreter, instruction: *const Inst, frame: *Frame, current_block: *BlockId, block_args: *[]const Value) InterpError!ExecResult {
+        const op = instruction.op;
+
+        switch (op) {
+            // ── Constants ───────────────────────────────────────
+            .const_int => |v| return .{ .value = .{ .int = v } },
+            .const_float => |v| return .{ .value = .{ .float = v } },
+            .const_bool => |v| return .{ .value = .{ .boolean = v } },
+            .const_string => |sid| return .{ .value = .{ .string = self.module.types.getString(sid) } },
+            .const_null => return .{ .value = .null_val },
+            .const_undef => return .{ .value = .undef },
+
+            // ── Arithmetic ──────────────────────────────────────
+            .add => |b| return .{ .value = try self.evalArith(frame, b, .add) },
+            .sub => |b| return .{ .value = try self.evalArith(frame, b, .sub) },
+            .mul => |b| return .{ .value = try self.evalArith(frame, b, .mul) },
+            .div => |b| return .{ .value = try self.evalArith(frame, b, .div) },
+            .mod => |b| return .{ .value = try self.evalArith(frame, b, .mod) },
+            .neg => |u| {
+                const val = frame.getRef(u.operand);
+                return .{ .value = switch (val) {
+                    // Wrapping negation, matching the wrapping add/sub/mul in
+                    // evalArith: plain `-v` is an overflow for minInt(i64).
+                    .int => |v| .{ .int = -%v },
+                    .float => |v| .{ .float = -v },
+                    else => return error.TypeError,
+                } };
+            },
+
+            // ── Comparison ──────────────────────────────────────
+            .cmp_eq => |b| return .{ .value = .{ .boolean = try self.evalCmp(frame, b, .eq) } },
+            .cmp_ne => |b| return .{ .value = .{ .boolean = try self.evalCmp(frame, b, .ne) } },
+            .cmp_lt => |b| return .{ .value = .{ .boolean = try self.evalCmp(frame, b, .lt) } },
+            .cmp_le => |b| return .{ .value = .{ .boolean = try self.evalCmp(frame, b, .le) } },
+            .cmp_gt => |b| return .{ .value = .{ .boolean = try self.evalCmp(frame, b, .gt) } },
+            .cmp_ge => |b| return .{ .value = .{ .boolean = try self.evalCmp(frame, b, .ge) } },
+
+            // ── Logical ─────────────────────────────────────────
+            // Both operands are already-computed refs; the early returns
+            // below only skip the type check of the right-hand side.
+            .bool_and => |b| {
+                const lhs = frame.getRef(b.lhs).asBool() orelse return error.TypeError;
+                if (!lhs) return .{ .value = .{ .boolean = false } };
+                const rhs = frame.getRef(b.rhs).asBool() orelse return error.TypeError;
+                return .{ .value = .{ .boolean = rhs } };
+            },
+            .bool_or => |b| {
+                const lhs = frame.getRef(b.lhs).asBool() orelse return error.TypeError;
+                if (lhs) return .{ .value = .{ .boolean = true } };
+                const rhs = frame.getRef(b.rhs).asBool() orelse return error.TypeError;
+                return .{ .value = .{ .boolean = rhs } };
+            },
+            .bool_not => |u| {
+                const val = frame.getRef(u.operand).asBool() orelse return error.TypeError;
+                return .{ .value = .{ .boolean = !val } };
+            },
+
+            // ── Conversions ─────────────────────────────────────
+            .widen, .narrow => |c| {
+                const val = frame.getRef(c.operand);
+                return .{ .value = val }; // comptime values don't truncate
+            },
+            .bitcast => |c| {
+                const val = frame.getRef(c.operand);
+                return .{ .value = val };
+            },
+            .int_to_float => |c| {
+                const val = frame.getRef(c.operand);
+                const i = val.asInt() orelse return error.TypeError;
+                return .{ .value = .{ .float = @floatFromInt(i) } };
+            },
+            .float_to_int => |c| {
+                const val = frame.getRef(c.operand);
+                const f = val.asFloat() orelse return error.TypeError;
+                // @intFromFloat is illegal behavior when the truncated value
+                // does not fit the destination, which includes NaN and ±inf.
+                // Refuse to fold those instead of panicking (or worse in
+                // release builds). The negated form is deliberate: every
+                // comparison with NaN is false, so NaN is rejected too.
+                const limit: f64 = 0x1p63;
+                if (!(f >= -limit and f < limit)) return error.CannotEvalComptime;
+                return .{ .value = .{ .int = @intFromFloat(f) } };
+            },
+
+            // ── Memory (stack simulation) ───────────────────────
+            .alloca => {
+                const slot = frame.allocSlot(self.alloc);
+                return .{ .value = .{ .slot_ptr = slot } };
+            },
+            .load => |u| {
+                const ptr = frame.getRef(u.operand);
+                switch (ptr) {
+                    .slot_ptr => |slot| return .{ .value = frame.loadSlot(slot) },
+                    else => return error.CannotEvalComptime,
+                }
+            },
+            .store => |s| {
+                const ptr = frame.getRef(s.ptr);
+                const val = frame.getRef(s.val);
+                switch (ptr) {
+                    .slot_ptr => |slot| frame.storeSlot(slot, val),
+                    else => return error.CannotEvalComptime,
+                }
+                return .{ .value = .void_val };
+            },
+
+            // ── Struct ops ──────────────────────────────────────
+            .struct_init => |agg| {
+                // NOTE(review): this slice is allocated from self.alloc and becomes part of
+                // the returned Value; no matching free is visible in this file — confirm
+                // who owns aggregate storage.
+                const fields = self.alloc.alloc(Value, agg.fields.len) catch return error.CannotEvalComptime;
+                for (agg.fields, 0..) |ref, i| {
+                    fields[i] = frame.getRef(ref);
+                }
+                return .{ .value = .{ .aggregate = fields } };
+            },
+            .struct_get => |fa| {
+                const base = frame.getRef(fa.base);
+                switch (base) {
+                    .aggregate => |fields| {
+                        if (fa.field_index >= fields.len) return error.OutOfBounds;
+                        return .{ .value = fields[fa.field_index] };
+                    },
+                    else => return error.TypeError,
+                }
+            },
+
+            // ── Enum ops ────────────────────────────────────────
+            // Representation: a payload-less enum is a bare `.int` tag; an enum with a
+            // payload is a two-element aggregate `{ tag, payload }`.
+            .enum_init => |ei| {
+                if (ei.payload.isNone()) {
+                    return .{ .value = .{ .int = @intCast(ei.tag) } };
+                } else {
+                    const payload = frame.getRef(ei.payload);
+                    const fields = self.alloc.alloc(Value, 2) catch return error.CannotEvalComptime;
+                    fields[0] = .{ .int = @intCast(ei.tag) };
+                    fields[1] = payload;
+                    return .{ .value = .{ .aggregate = fields } };
+                }
+            },
+            .enum_tag => |u| {
+                const val = frame.getRef(u.operand);
+                switch (val) {
+                    .int => return .{ .value = val },
+                    .aggregate => |fields| {
+                        if (fields.len == 0) return error.TypeError;
+                        return .{ .value = fields[0] };
+                    },
+                    else => return error.TypeError,
+                }
+            },
+            .enum_payload => |fa| {
+                const base = frame.getRef(fa.base);
+                switch (base) {
+                    .aggregate => |fields| {
+                        // Slot 0 holds the tag, so payload field i lives at index i + 1.
+                        if (fa.field_index + 1 >= fields.len) return error.OutOfBounds;
+                        return .{ .value = fields[fa.field_index + 1] };
+                    },
+                    else => return error.TypeError,
+                }
+            },
+
+            // ── Optional ops ────────────────────────────────────
+            // Optionals have no wrapper of their own: a present value is the value
+            // itself and an absent one is `.null_val`.
+            .optional_wrap => |u| {
+                const val = frame.getRef(u.operand);
+                return .{ .value = val }; // wrapped value is just the value
+            },
+            .optional_unwrap => |u| {
+                const val = frame.getRef(u.operand);
+                if (val.isNull()) return error.TypeError; // unwrapping null
+                return .{ .value = val };
+            },
+
.optional_has_value => |u| {
+                const val = frame.getRef(u.operand);
+                return .{ .value = .{ .boolean = !val.isNull() } };
+            },
+            .optional_coalesce => |b| {
+                // lhs if it is not null, otherwise rhs. Both are refs that have already been evaluated.
+                const lhs = frame.getRef(b.lhs);
+                if (!lhs.isNull()) return .{ .value = lhs };
+                return .{ .value = frame.getRef(b.rhs) };
+            },
+
+            // ── Calls ───────────────────────────────────────────
+            .call => |c| {
+                // The argument buffer only has to outlive the nested call; the callee
+                // copies the values into its own frame.
+                const args = self.alloc.alloc(Value, c.args.len) catch return error.CannotEvalComptime;
+                defer self.alloc.free(args);
+                for (c.args, 0..) |ref, i| {
+                    args[i] = frame.getRef(ref);
+                }
+                const result = try self.call(c.callee, args);
+                return .{ .value = result };
+            },
+
+            // ── Block params ────────────────────────────────────
+            .block_param => {
+                // Block params are pushed at the start of block execution.
+                // This instruction is a no-op; the value was already pushed
+                // during block arg binding.
+                // NOTE(review): returning `.value` still makes the caller push a
+                // `.void_val` ref for this instruction — confirm that matches the
+                // Ref numbering the builder assigns to block params.
+                return .{ .value = .void_val };
+            },
+
+            // ── Terminators ─────────────────────────────────────
+            // Each branch evaluates its argument refs into a freshly allocated slice
+            // and hands it to the caller through `block_args`.
+            .br => |b| {
+                const args = self.alloc.alloc(Value, b.args.len) catch return error.CannotEvalComptime;
+                for (b.args, 0..) |ref, i| {
+                    args[i] = frame.getRef(ref);
+                }
+                current_block.* = b.target;
+                block_args.* = args;
+                return .branch;
+            },
+            .cond_br => |cb| {
+                const cond = frame.getRef(cb.cond).asBool() orelse return error.TypeError;
+                if (cond) {
+                    const args = self.alloc.alloc(Value, cb.then_args.len) catch return error.CannotEvalComptime;
+                    for (cb.then_args, 0..) |ref, i| {
+                        args[i] = frame.getRef(ref);
+                    }
+                    current_block.* = cb.then_target;
+                    block_args.* = args;
+                } else {
+                    const args = self.alloc.alloc(Value, cb.else_args.len) catch return error.CannotEvalComptime;
+                    for (cb.else_args, 0..)
|ref, i| {
+                        args[i] = frame.getRef(ref);
+                    }
+                    current_block.* = cb.else_target;
+                    block_args.* = args;
+                }
+                return .branch;
+            },
+            .switch_br => |sb| {
+                // The first case whose value equals the operand wins; the default target
+                // is taken when none matches.
+                const operand = frame.getRef(sb.operand).asInt() orelse return error.TypeError;
+                for (sb.cases) |case| {
+                    if (operand == case.value) {
+                        const args = self.alloc.alloc(Value, case.args.len) catch return error.CannotEvalComptime;
+                        for (case.args, 0..) |ref, i| {
+                            args[i] = frame.getRef(ref);
+                        }
+                        current_block.* = case.target;
+                        block_args.* = args;
+                        return .branch;
+                    }
+                }
+                // Default
+                const args = self.alloc.alloc(Value, sb.default_args.len) catch return error.CannotEvalComptime;
+                for (sb.default_args, 0..) |ref, i| {
+                    args[i] = frame.getRef(ref);
+                }
+                current_block.* = sb.default;
+                block_args.* = args;
+                return .branch;
+            },
+            .ret => |u| {
+                return .{ .ret_val = frame.getRef(u.operand) };
+            },
+            .ret_void => return .ret_nothing,
+            .@"unreachable" => return error.Unreachable,
+
+            // ── Not evaluable at comptime ───────────────────────
+            // Listed explicitly rather than with `else`, so this switch has to name
+            // every Op it does not evaluate.
+            .heap_alloc, .heap_free, .call_indirect, .call_closure, .call_builtin, .protocol_call_dynamic, .protocol_erase, .closure_create, .context_load, .context_store, .context_save, .context_restore, .global_get, .global_set, .box_any, .unbox_any, .struct_gep, .union_get, .union_gep, .index_get, .index_gep, .length, .data_ptr, .subslice, .array_to_slice, .tuple_init, .tuple_get, .addr_of, .deref, .vec_splat, .vec_extract, .vec_insert, .bit_and, .bit_or, .bit_xor, .bit_not, .shl, .shr, .placeholder => {
+                return error.CannotEvalComptime;
+            },
+        }
+    }
+
+    // ── Arithmetic helpers ──────────────────────────────────────────
+
+    const ArithOp = enum { add, sub, mul, div, mod };
+
+    /// Evaluates a binary arithmetic op. Two ints use integer arithmetic; any
+    /// other pairing that `asFloat` accepts (float/float or int/float) is
+    /// computed in f64.
+    fn evalArith(self: *Interpreter, frame: *Frame, b: inst_mod.BinOp, comptime aop: ArithOp) InterpError!Value {
+        _ = self;
+        const lhs = frame.getRef(b.lhs);
+        const rhs = frame.getRef(b.rhs);
+
+        // Both int
+        if (lhs.asInt()) |li| {
+            if (rhs.asInt()) |ri| {
+                return .{ .int = switch (aop) {
+                    .add
=> li +% ri,
+                    .sub => li -% ri,
+                    .mul => li *% ri,
+                    // A divisor of -1 is handled separately: minInt(i64) / -1 overflows,
+                    // which @divTrunc does not permit. The quotient wraps like the other
+                    // integer ops, and the remainder by -1 is always 0.
+                    .div => if (ri == 0) return error.DivisionByZero else if (ri == -1) 0 -% li else @divTrunc(li, ri),
+                    .mod => if (ri == 0) return error.DivisionByZero else if (ri == -1) 0 else @mod(li, ri),
+                } };
+            }
+        }
+
+        // Both float (or int promoted to float)
+        if (lhs.asFloat()) |lf| {
+            if (rhs.asFloat()) |rf| {
+                return .{ .float = switch (aop) {
+                    .add => lf + rf,
+                    .sub => lf - rf,
+                    .mul => lf * rf,
+                    .div => if (rf == 0.0) return error.DivisionByZero else lf / rf,
+                    // Same zero check as float division and integer modulo.
+                    .mod => if (rf == 0.0) return error.DivisionByZero else @mod(lf, rf),
+                } };
+            }
+        }
+
+        return error.TypeError;
+    }
+
+    // ── Comparison helpers ──────────────────────────────────────────
+
+    const CmpOp = enum { eq, ne, lt, le, gt, ge };
+
+    /// Compares two refs. Ints compare as ints; any other numeric pairing is
+    /// compared as f64 (`asFloat` promotes ints); booleans support only
+    /// `eq`/`ne`. Everything else is a `TypeError`.
+    fn evalCmp(self: *Interpreter, frame: *Frame, b: inst_mod.BinOp, comptime cop: CmpOp) InterpError!bool {
+        _ = self;
+        const lhs = frame.getRef(b.lhs);
+        const rhs = frame.getRef(b.rhs);
+
+        // Both int
+        if (lhs.asInt()) |li| {
+            if (rhs.asInt()) |ri| {
+                return switch (cop) {
+                    .eq => li == ri,
+                    .ne => li != ri,
+                    .lt => li < ri,
+                    .le => li <= ri,
+                    .gt => li > ri,
+                    .ge => li >= ri,
+                };
+            }
+        }
+
+        // Both float
+        if (lhs.asFloat()) |lf| {
+            if (rhs.asFloat()) |rf| {
+                return switch (cop) {
+                    .eq => lf == rf,
+                    .ne => lf != rf,
+                    .lt => lf < rf,
+                    .le => lf <= rf,
+                    .gt => lf > rf,
+                    .ge => lf >= rf,
+                };
+            }
+        }
+
+        // Bool equality
+        if (lhs.asBool()) |lb| {
+            if (rhs.asBool()) |rb| {
+                return switch (cop) {
+                    .eq => lb == rb,
+                    .ne => lb != rb,
+                    else => return error.TypeError,
+                };
+            }
+        }
+
+        return error.TypeError;
+    }
+};
+
+// ── Frame ───────────────────────────────────────────────────────────────
+// Holds SSA values (by Ref index) and local mutable slots (for alloca).
+
+const Frame = struct {
+    refs: std.ArrayList(Value),
+    slots: std.ArrayList(Value),
+    /// Allocator given to `init`; `deinit` uses it to release both lists.
+    /// It is the same allocator the interpreter passes to `pushRef` and
+    /// `allocSlot`.
+    alloc: Allocator,
+
+    /// Creates an empty frame. Nothing is allocated until the first push.
+    fn init(alloc: Allocator) Frame {
+        return .{
+            .refs = std.ArrayList(Value).empty,
+            .slots = std.ArrayList(Value).empty,
+            .alloc = alloc,
+        };
+    }
+
+    /// Releases the backing storage of `refs` and `slots`. Values returned
+    /// to the caller were copied out of the lists and stay valid.
+    fn deinit(self: *Frame) void {
+        self.refs.deinit(self.alloc);
+        self.slots.deinit(self.alloc);
+    }
+
+    /// Appends `val` as the next SSA value; its Ref index is its position.
+    fn pushRef(self: *Frame, alloc: Allocator, val: Value) void {
+        // The signature cannot report failure, so running out of memory is a
+        // defined panic rather than `unreachable` (undefined in release builds).
+        self.refs.append(alloc, val) catch @panic("interp Frame: out of memory");
+    }
+
+    /// Looks up an SSA value. `Ref.none` reads as `.void_val`, and an index
+    /// that has not been produced yet reads as `.undef`.
+    fn getRef(self: *const Frame, ref: Ref) Value {
+        if (ref.isNone()) return .void_val;
+        const idx = ref.index();
+        if (idx >= self.refs.items.len) return .undef;
+        return self.refs.items[idx];
+    }
+
+    /// Reserves a new local slot, initialised to `.undef`, and returns its index.
+    fn allocSlot(self: *Frame, alloc: Allocator) u32 {
+        const idx: u32 = @intCast(self.slots.items.len);
+        self.slots.append(alloc, .undef) catch @panic("interp Frame: out of memory");
+        return idx;
+    }
+
+    /// Reads a slot; an unknown slot index reads as `.undef`.
+    fn loadSlot(self: *const Frame, slot: u32) Value {
+        if (slot >= self.slots.items.len) return .undef;
+        return self.slots.items[slot];
+    }
+
+    /// Writes a slot; a write to an unknown slot index is ignored.
+    fn storeSlot(self: *Frame, slot: u32, val: Value) void {
+        if (slot < self.slots.items.len) {
+            self.slots.items[slot] = val;
+        }
+    }
+};
+
diff --git a/src/ir/ir.zig b/src/ir/ir.zig
new file mode 100644
index 0000000..d65c24d
--- /dev/null
+++ b/src/ir/ir.zig
@@ -0,0 +1,48 @@
+pub const types = @import("types.zig");
+pub const inst = @import("inst.zig");
+pub const module = @import("module.zig");
+pub const print = @import("print.zig");
+pub const interp = @import("interp.zig");
+pub const lower = @import("lower.zig");
+
+pub const TypeId = types.TypeId;
+pub const TypeInfo = types.TypeInfo;
+pub const TypeTable = types.TypeTable;
+pub const StringId = types.StringId;
+pub const StringPool = types.StringPool;
+
+pub const Ref = inst.Ref;
+pub const BlockId = inst.BlockId;
+pub const FuncId = inst.FuncId;
+pub const GlobalId = inst.GlobalId;
+pub const Inst = inst.Inst;
+pub const Op = inst.Op;
+pub const Block =
inst.Block;
+pub const Function = inst.Function;
+pub const Global = inst.Global;
+pub const ConstantValue = inst.ConstantValue;
+
+pub const Module = module.Module;
+pub const Builder = module.Builder;
+pub const ImplTable = module.ImplTable;
+
+pub const printModule = print.printModule;
+pub const Interpreter = interp.Interpreter;
+pub const Value = interp.Value;
+pub const Lowering = lower.Lowering;
+
+pub const type_bridge = @import("type_bridge.zig");
+pub const resolveAstType = type_bridge.resolveAstType;
+pub const bridgeType = type_bridge.bridgeType;
+
+// The test files are imported as declarations so that the refAllDecls call
+// below references them and their `test` blocks are compiled in.
+pub const types_tests = @import("types.test.zig");
+pub const inst_tests = @import("inst.test.zig");
+pub const module_tests = @import("module.test.zig");
+pub const print_tests = @import("print.test.zig");
+pub const interp_tests = @import("interp.test.zig");
+pub const lower_tests = @import("lower.test.zig");
+pub const type_bridge_tests = @import("type_bridge.test.zig");
+
+test {
+    @import("std").testing.refAllDecls(@This());
+}
diff --git a/src/ir/lower.test.zig b/src/ir/lower.test.zig
new file mode 100644
index 0000000..1ceaf87
--- /dev/null
+++ b/src/ir/lower.test.zig
@@ -0,0 +1,223 @@
+// Tests for lower.zig
+
+const std = @import("std");
+const ast = @import("../ast.zig");
+const Node = ast.Node;
+
+const ir_mod = @import("ir.zig");
+const TypeId = ir_mod.TypeId;
+const Ref = ir_mod.Ref;
+const FuncId = ir_mod.FuncId;
+const Lowering = ir_mod.Lowering;
+
+test "lower: simple function with arithmetic" {
+    const alloc = std.testing.allocator;
+    var module = ir_mod.Module.init(alloc);
+    defer module.deinit();
+
+    // Build a minimal AST: add :: (a: s64, b: s64) -> s64 { return a + b; }
+    // Every node is heap-allocated by hand; the matching destroys are
+    // registered together further down, after the whole tree is built.
+    const a_type = alloc.create(Node) catch unreachable;
+    a_type.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "s64", .is_generic = false } } };
+    const b_type = alloc.create(Node) catch unreachable;
+    b_type.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "s64",
.is_generic = false } } };
+    const ret_type = alloc.create(Node) catch unreachable;
+    ret_type.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "s64", .is_generic = false } } };
+
+    const a_ident = alloc.create(Node) catch unreachable;
+    a_ident.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .identifier = .{ .name = "a" } } };
+    const b_ident = alloc.create(Node) catch unreachable;
+    b_ident.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .identifier = .{ .name = "b" } } };
+
+    const add_expr = alloc.create(Node) catch unreachable;
+    add_expr.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .binary_op = .{
+        .op = .add,
+        .lhs = a_ident,
+        .rhs = b_ident,
+    } } };
+
+    const ret_stmt = alloc.create(Node) catch unreachable;
+    ret_stmt.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .return_stmt = .{ .value = add_expr } } };
+
+    const body = alloc.create(Node) catch unreachable;
+    const stmts: []const *Node = &.{ret_stmt};
+    body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = stmts } } };
+
+    // Cleanup for every node created above.
+    // NOTE(review): the defers are registered only after all allocations, which
+    // is why the creates use `catch unreachable`; switching them to `try` would
+    // need each defer moved next to its create.
+    defer alloc.destroy(a_type);
+    defer alloc.destroy(b_type);
+    defer alloc.destroy(ret_type);
+    defer alloc.destroy(a_ident);
+    defer alloc.destroy(b_ident);
+    defer alloc.destroy(add_expr);
+    defer alloc.destroy(ret_stmt);
+    defer alloc.destroy(body);
+
+    const params: []const ast.Param = &.{
+        .{ .name = "a", .name_span = .{ .start = 0, .end = 0 }, .type_expr = a_type },
+        .{ .name = "b", .name_span = .{ .start = 0, .end = 0 }, .type_expr = b_type },
+    };
+
+    const fn_decl = ast.FnDecl{
+        .name = "add",
+        .params = params,
+        .return_type = ret_type,
+        .body = body,
+    };
+
+    var lowering = Lowering.init(&module);
+    lowering.lowerFunction(&fn_decl, "add");
+
+    // Verify
+    try std.testing.expectEqual(@as(usize, 1), module.functions.items.len);
+    const func = module.getFunction(FuncId.fromIndex(0));
+    try std.testing.expectEqual(@as(usize, 2), func.params.len);
+    try std.testing.expectEqual(TypeId.s64,
func.ret);
+    try std.testing.expect(func.blocks.items.len > 0);
+
+    // Print the IR to verify it looks reasonable
+    const print_mod = @import("print.zig");
+    var aw = std.Io.Writer.Allocating.init(alloc);
+    try print_mod.printModule(&module, &aw.writer);
+    var result = aw.writer.toArrayList();
+    defer result.deinit(alloc);
+
+    const output = result.items;
+    try std.testing.expect(std.mem.indexOf(u8, output, "func @add") != null);
+    try std.testing.expect(std.mem.indexOf(u8, output, "entry:") != null);
+    try std.testing.expect(std.mem.indexOf(u8, output, "add %") != null or std.mem.indexOf(u8, output, "ret %") != null);
+}
+
+test "lower: if/else generates basic blocks" {
+    const alloc = std.testing.allocator;
+    var module = ir_mod.Module.init(alloc);
+    defer module.deinit();
+
+    // Build AST: test :: () -> s64 { if true { return 1; } else { return 2; } }
+    // Each node's destroy is deferred right after its create, so a failed
+    // allocation can simply propagate with `try`: everything created before
+    // it is still released.
+    const cond_node = try alloc.create(Node);
+    defer alloc.destroy(cond_node);
+    cond_node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .bool_literal = .{ .value = true } } };
+
+    const ret1_val = try alloc.create(Node);
+    defer alloc.destroy(ret1_val);
+    ret1_val.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .int_literal = .{ .value = 1 } } };
+
+    const ret2_val = try alloc.create(Node);
+    defer alloc.destroy(ret2_val);
+    ret2_val.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .int_literal = .{ .value = 2 } } };
+
+    const then_ret = try alloc.create(Node);
+    defer alloc.destroy(then_ret);
+    then_ret.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .return_stmt = .{ .value = ret1_val } } };
+
+    const else_ret = try alloc.create(Node);
+    defer alloc.destroy(else_ret);
+    else_ret.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .return_stmt = .{ .value = ret2_val } } };
+
+    const then_body = try alloc.create(Node);
+    defer alloc.destroy(then_body);
+    then_body.* = .{ .span = .{ .start = 0, .end = 0 }, .data
= .{ .block = .{ .stmts = &.{then_ret} } } }; + + const else_body = alloc.create(Node) catch unreachable; + defer alloc.destroy(else_body); + else_body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = &.{else_ret} } } }; + + const if_node = alloc.create(Node) catch unreachable; + defer alloc.destroy(if_node); + if_node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .if_expr = .{ + .condition = cond_node, + .then_branch = then_body, + .else_branch = else_body, + .is_inline = false, + } } }; + + const fn_body = alloc.create(Node) catch unreachable; + defer alloc.destroy(fn_body); + fn_body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = &.{if_node} } } }; + + const ret_type = alloc.create(Node) catch unreachable; + defer alloc.destroy(ret_type); + ret_type.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "s64", .is_generic = false } } }; + + const fn_decl = ast.FnDecl{ + .name = "test_if", + .params = &.{}, + .return_type = ret_type, + .body = fn_body, + }; + + var lowering = Lowering.init(&module); + lowering.lowerFunction(&fn_decl, "test_if"); + + // Verify: should have 4 blocks (entry, if.then, if.else, if.merge) + const func = module.getFunction(FuncId.fromIndex(0)); + try std.testing.expectEqual(@as(usize, 4), func.blocks.items.len); + + // Print and verify structure + const print_mod = @import("print.zig"); + var aw = std.Io.Writer.Allocating.init(alloc); + try print_mod.printModule(&module, &aw.writer); + var result = aw.writer.toArrayList(); + defer result.deinit(alloc); + const output = result.items; + + try std.testing.expect(std.mem.indexOf(u8, output, "cond_br") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "if.then") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "if.else") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "if.merge") != null); +} + +test "lower: while loop generates header/body/exit blocks" { + const alloc = 
std.testing.allocator; + var module = ir_mod.Module.init(alloc); + defer module.deinit(); + + // Build AST: loop :: () { while true { break; } } + const cond_node = alloc.create(Node) catch unreachable; + defer alloc.destroy(cond_node); + cond_node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .bool_literal = .{ .value = true } } }; + + const break_node = alloc.create(Node) catch unreachable; + defer alloc.destroy(break_node); + break_node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .break_expr }; + + const while_body = alloc.create(Node) catch unreachable; + defer alloc.destroy(while_body); + while_body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = &.{break_node} } } }; + + const while_node = alloc.create(Node) catch unreachable; + defer alloc.destroy(while_node); + while_node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .while_expr = .{ + .condition = cond_node, + .body = while_body, + } } }; + + const fn_body = alloc.create(Node) catch unreachable; + defer alloc.destroy(fn_body); + fn_body.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .block = .{ .stmts = &.{while_node} } } }; + + const fn_decl = ast.FnDecl{ + .name = "loop_test", + .params = &.{}, + .return_type = null, + .body = fn_body, + }; + + var lowering = Lowering.init(&module); + lowering.lowerFunction(&fn_decl, "loop_test"); + + // Verify: should have 4 blocks (entry, while.hdr, while.body, while.exit) + const func = module.getFunction(FuncId.fromIndex(0)); + try std.testing.expectEqual(@as(usize, 4), func.blocks.items.len); + + // Print and verify structure + const print_mod = @import("print.zig"); + var aw = std.Io.Writer.Allocating.init(alloc); + try print_mod.printModule(&module, &aw.writer); + var result = aw.writer.toArrayList(); + defer result.deinit(alloc); + const output = result.items; + + try std.testing.expect(std.mem.indexOf(u8, output, "while.hdr") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "while.body") != null); 
+ try std.testing.expect(std.mem.indexOf(u8, output, "while.exit") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "cond_br") != null); +} diff --git a/src/ir/lower.zig b/src/ir/lower.zig new file mode 100644 index 0000000..0026a97 --- /dev/null +++ b/src/ir/lower.zig @@ -0,0 +1,1142 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ast = @import("../ast.zig"); +const Node = ast.Node; +const types = @import("types.zig"); +const inst_mod = @import("inst.zig"); +const mod_mod = @import("module.zig"); +const type_bridge = @import("type_bridge.zig"); +const unescape = @import("../unescape.zig"); + +const TypeId = types.TypeId; +const StringId = types.StringId; +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const Function = inst_mod.Function; +const Module = mod_mod.Module; +const Builder = mod_mod.Builder; + +// ── Scope ─────────────────────────────────────────────────────────────── + +const Binding = struct { + ref: Ref, + ty: TypeId, + is_alloca: bool, // true if ref is a pointer that needs load +}; + +const Scope = struct { + map: std.StringHashMap(Binding), + parent: ?*Scope, + + fn init(alloc: Allocator, parent: ?*Scope) Scope { + return .{ + .map = std.StringHashMap(Binding).init(alloc), + .parent = parent, + }; + } + + fn deinit(self: *Scope) void { + self.map.deinit(); + } + + fn put(self: *Scope, name: []const u8, binding: Binding) void { + self.map.put(name, binding) catch unreachable; + } + + fn lookup(self: *const Scope, name: []const u8) ?Binding { + if (self.map.get(name)) |b| return b; + if (self.parent) |p| return p.lookup(name); + return null; + } +}; + +// ── Lowering ──────────────────────────────────────────────────────────── + +pub const Lowering = struct { + module: *Module, + builder: Builder, + alloc: Allocator, + scope: ?*Scope = null, + break_target: ?BlockId = null, + continue_target: ?BlockId = null, + block_counter: u32 = 0, + comptime_counter: u32 = 
0, + + pub fn init(module: *Module) Lowering { + return .{ + .module = module, + .builder = Builder.init(module), + .alloc = module.alloc, + }; + } + + // ── Public entry point ────────────────────────────────────────── + + /// Lower all top-level declarations from a root node. + pub fn lowerRoot(self: *Lowering, root: *const Node) void { + const decls = switch (root.data) { + .root => |r| r.decls, + else => return, + }; + self.lowerDecls(decls); + } + + /// Lower a list of top-level declarations (used by both lowerRoot and irComptimeEval). + pub fn lowerDecls(self: *Lowering, decls: []const *const Node) void { + for (decls) |decl| { + switch (decl.data) { + .fn_decl => |fd| self.lowerFunction(&fd, fd.name), + .const_decl => |cd| { + // const_decl where value is fn_decl → named function + if (cd.value.data == .fn_decl) { + self.lowerFunction(&cd.value.data.fn_decl, cd.name); + } else if (cd.value.data == .comptime_expr) { + // NAME :: #run expr; → create comptime function + global constant + self.lowerComptimeGlobal(cd.name, cd.value.data.comptime_expr.expr, cd.type_annotation); + } + // Other const decls (types, values) → skip for now + }, + .comptime_expr => |ct| { + // Standalone #run expr; at top level → comptime side-effect function + self.lowerComptimeSideEffect(ct.expr); + }, + // Skip type declarations, imports, etc for now + else => {}, + } + } + } + + /// Lower a single function declaration. 
+ pub fn lowerFunction(self: *Lowering, fd: *const ast.FnDecl, name: []const u8) void { + const name_id = self.module.types.internString(name); + const ret_ty = self.resolveReturnType(fd); + + // Build param list + var params = std.ArrayList(Function.Param).empty; + for (fd.params) |p| { + const pty = self.resolveParamType(&p); + params.append(self.alloc, .{ + .name = self.module.types.internString(p.name), + .ty = pty, + }) catch unreachable; + } + + const func_id = self.builder.beginFunction( + name_id, + params.items, + ret_ty, + ); + _ = func_id; + + // Create entry block + const entry_name = self.module.types.internString("entry"); + const entry = self.builder.appendBlock(entry_name, &.{}); + self.builder.switchToBlock(entry); + + // Create scope and bind params + var scope = Scope.init(self.alloc, self.scope); + defer scope.deinit(); + self.scope = &scope; + defer self.scope = scope.parent; + + for (fd.params, 0..) |p, i| { + const pty = self.resolveParamType(&p); + // Allocate stack slot for param, store initial value + const slot = self.builder.alloca(pty); + const param_ref = Ref.fromIndex(@intCast(i)); + // Params are the first N refs (from beginFunction bindings) + // For now, use const_int placeholders since we don't have real param refs yet + const placeholder = self.builder.constInt(0, pty); + _ = param_ref; + self.builder.store(slot, placeholder); + scope.put(p.name, .{ .ref = slot, .ty = pty, .is_alloca = true }); + } + + // Lower the function body + self.lowerBlock(fd.body); + + // If no terminator in current block, add implicit return + self.ensureTerminator(ret_ty); + + self.builder.finalize(); + } + + // ── Statement lowering ────────────────────────────────────────── + + fn lowerBlock(self: *Lowering, node: *const Node) void { + switch (node.data) { + .block => |blk| { + for (blk.stmts) |stmt| { + self.lowerStmt(stmt); + } + }, + else => { + // Single expression as body (arrow functions) + self.lowerStmt(node); + }, + } + } + + fn 
lowerStmt(self: *Lowering, node: *const Node) void { + switch (node.data) { + .var_decl => |vd| self.lowerVarDecl(&vd), + .const_decl => |cd| self.lowerConstDecl(&cd), + .return_stmt => |rs| self.lowerReturn(&rs), + .assignment => |asgn| self.lowerAssignment(&asgn), + .defer_stmt => |ds| self.lowerDefer(&ds), + .push_stmt => |ps| self.lowerPush(&ps), + .multi_assign => |ma| self.lowerMultiAssign(&ma), + // Expression statement + else => { + _ = self.lowerExpr(node); + }, + } + } + + fn lowerVarDecl(self: *Lowering, vd: *const ast.VarDecl) void { + const ty = self.resolveType(vd.type_annotation); + const slot = self.builder.alloca(ty); + + if (vd.value) |val| { + const ref = self.lowerExpr(val); + self.builder.store(slot, ref); + } + + if (self.scope) |scope| { + scope.put(vd.name, .{ .ref = slot, .ty = ty, .is_alloca = true }); + } + } + + fn lowerConstDecl(self: *Lowering, cd: *const ast.ConstDecl) void { + const ty = self.resolveType(cd.type_annotation); + const ref = self.lowerExpr(cd.value); + + if (self.scope) |scope| { + scope.put(cd.name, .{ .ref = ref, .ty = ty, .is_alloca = false }); + } + } + + fn lowerReturn(self: *Lowering, rs: *const ast.ReturnStmt) void { + if (rs.value) |val| { + const ref = self.lowerExpr(val); + self.builder.ret(ref, .s64); // stub type + } else { + self.builder.retVoid(); + } + } + + fn lowerAssignment(self: *Lowering, asgn: *const ast.Assignment) void { + const val = self.lowerExpr(asgn.value); + + switch (asgn.target.data) { + .identifier => |id| { + if (self.scope) |scope| { + if (scope.lookup(id.name)) |binding| { + if (binding.is_alloca) { + if (asgn.op == .assign) { + self.builder.store(binding.ref, val); + } else { + // Compound assignment: load, op, store + const loaded = self.builder.load(binding.ref, binding.ty); + const result = self.emitCompoundOp(loaded, val, asgn.op, binding.ty); + self.builder.store(binding.ref, result); + } + } + } + } + }, + .field_access => |fa| { + const obj = self.lowerExpr(fa.object); + if 
(std.mem.eql(u8, fa.field, "len")) { + // .len is special — struct_gep index 1 for fat pointers + const gep = self.builder.structGep(obj, 1, .s64); + self.builder.store(gep, val); + } else if (std.mem.eql(u8, fa.field, "ptr")) { + const gep = self.builder.structGep(obj, 0, .s64); + self.builder.store(gep, val); + } else { + // Generic field — stub index 0 (real resolution needs type info) + const gep = self.builder.structGep(obj, 0, .s64); + self.builder.store(gep, val); + } + }, + .index_expr => |ie| { + const obj = self.lowerExpr(ie.object); + const idx = self.lowerExpr(ie.index); + const gep = self.builder.emit(.{ .index_gep = .{ .lhs = obj, .rhs = idx } }, .s64); + self.builder.store(gep, val); + }, + .deref_expr => |de| { + const ptr = self.lowerExpr(de.operand); + self.builder.store(ptr, val); + }, + else => { + _ = self.emitPlaceholder("assignment_target"); + }, + } + } + + fn emitCompoundOp(self: *Lowering, lhs: Ref, rhs: Ref, op: ast.Assignment.Op, ty: TypeId) Ref { + return switch (op) { + .add_assign => self.builder.add(lhs, rhs, ty), + .sub_assign => self.builder.sub(lhs, rhs, ty), + .mul_assign => self.builder.mul(lhs, rhs, ty), + .div_assign => self.builder.div(lhs, rhs, ty), + else => self.emitPlaceholder("compound_assign"), + }; + } + + // ── Expression lowering ───────────────────────────────────────── + + fn lowerExpr(self: *Lowering, node: *const Node) Ref { + return switch (node.data) { + .int_literal => |lit| self.builder.constInt(lit.value, .s64), + .float_literal => |lit| self.builder.constFloat(lit.value, .f64), + .bool_literal => |lit| self.builder.constBool(lit.value), + .string_literal => |lit| blk: { + const str = if (lit.is_raw) + lit.raw + else + unescape.unescapeString(self.alloc, lit.raw) catch lit.raw; + const sid = self.module.types.internString(str); + break :blk self.builder.constString(sid); + }, + .null_literal => self.builder.constNull(.void), + .undef_literal => self.builder.constUndef(.void), + + .identifier => |id| blk: { + 
if (self.scope) |scope| { + if (scope.lookup(id.name)) |binding| { + if (binding.is_alloca) { + break :blk self.builder.load(binding.ref, binding.ty); + } + break :blk binding.ref; + } + } + // Unknown identifier — emit placeholder + break :blk self.emitPlaceholder(id.name); + }, + + .binary_op => |bop| self.lowerBinaryOp(&bop), + + .unary_op => |uop| blk: { + const operand = self.lowerExpr(uop.operand); + break :blk switch (uop.op) { + .negate => self.builder.emit(.{ .neg = .{ .operand = operand } }, .s64), + .not => self.builder.emit(.{ .bool_not = .{ .operand = operand } }, .bool), + else => self.emitPlaceholder("unary_op"), + }; + }, + + .if_expr => |ie| self.lowerIfExpr(&ie), + .match_expr => |me| self.lowerMatch(&me), + .while_expr => |we| self.lowerWhile(&we), + .for_expr => |fe| self.lowerFor(&fe), + .break_expr => self.lowerBreak(), + .continue_expr => self.lowerContinue(), + .call => |c| self.lowerCall(&c), + .field_access => |fa| self.lowerFieldAccess(&fa), + .struct_literal => |sl| self.lowerStructLiteral(&sl), + .array_literal => |al| self.lowerArrayLiteral(&al), + .index_expr => |ie| self.lowerIndexExpr(&ie), + .slice_expr => |se| self.lowerSliceExpr(&se), + .lambda => |lam| self.lowerLambda(&lam), + .force_unwrap => |fu| self.lowerForceUnwrap(&fu), + .null_coalesce => |nc| self.lowerNullCoalesce(&nc), + .deref_expr => |de| self.lowerDerefExpr(&de), + .enum_literal => |el| self.lowerEnumLiteral(&el), + .comptime_expr => |ct| self.lowerInlineComptime(ct.expr), + .tuple_literal => |tl| self.lowerTupleLiteral(&tl), + .spread_expr => self.emitPlaceholder("spread_expr"), + .chained_comparison => |cc| self.lowerChainedComparison(&cc), + + // Statements that can appear in expression position + .block => |blk| blk: { + for (blk.stmts) |stmt| { + self.lowerStmt(stmt); + } + break :blk self.builder.constInt(0, .void); + }, + + else => self.emitPlaceholder("unknown_expr"), + }; + } + + fn lowerBinaryOp(self: *Lowering, bop: *const ast.BinaryOp) Ref { + const lhs 
= self.lowerExpr(bop.lhs); + const rhs = self.lowerExpr(bop.rhs); + const ty: TypeId = .s64; // stub — real type resolution in Step 1.3 + + return switch (bop.op) { + .add => self.builder.add(lhs, rhs, ty), + .sub => self.builder.sub(lhs, rhs, ty), + .mul => self.builder.mul(lhs, rhs, ty), + .div => self.builder.div(lhs, rhs, ty), + .mod => self.builder.emit(.{ .mod = .{ .lhs = lhs, .rhs = rhs } }, ty), + .eq => self.builder.cmpEq(lhs, rhs), + .neq => self.builder.emit(.{ .cmp_ne = .{ .lhs = lhs, .rhs = rhs } }, .bool), + .lt => self.builder.cmpLt(lhs, rhs), + .lte => self.builder.emit(.{ .cmp_le = .{ .lhs = lhs, .rhs = rhs } }, .bool), + .gt => self.builder.cmpGt(lhs, rhs), + .gte => self.builder.emit(.{ .cmp_ge = .{ .lhs = lhs, .rhs = rhs } }, .bool), + .and_op => self.builder.emit(.{ .bool_and = .{ .lhs = lhs, .rhs = rhs } }, .bool), + .or_op => self.builder.emit(.{ .bool_or = .{ .lhs = lhs, .rhs = rhs } }, .bool), + .bit_and => self.builder.emit(.{ .bit_and = .{ .lhs = lhs, .rhs = rhs } }, ty), + .bit_or => self.builder.emit(.{ .bit_or = .{ .lhs = lhs, .rhs = rhs } }, ty), + .bit_xor => self.builder.emit(.{ .bit_xor = .{ .lhs = lhs, .rhs = rhs } }, ty), + .shl => self.builder.emit(.{ .shl = .{ .lhs = lhs, .rhs = rhs } }, ty), + .shr => self.builder.emit(.{ .shr = .{ .lhs = lhs, .rhs = rhs } }, ty), + .in_op => self.emitPlaceholder("in_op"), + }; + } + + // ── Control flow ──────────────────────────────────────────────── + + fn lowerIfExpr(self: *Lowering, ie: *const ast.IfExpr) Ref { + const cond = self.lowerExpr(ie.condition); + const has_else = ie.else_branch != null; + const is_value = ie.is_inline and has_else; + + const then_bb = self.freshBlock("if.then"); + const else_bb: ?BlockId = if (has_else) self.freshBlock("if.else") else null; + const merge_params: []const TypeId = if (is_value) &.{.s64} else &.{}; + const merge_bb = self.freshBlockWithParams("if.merge", merge_params); + + // Conditional branch + self.builder.condBr( + cond, + then_bb, + &.{}, + 
if (else_bb) |eb| eb else merge_bb,
            &.{},
        );

        // Then branch
        self.builder.switchToBlock(then_bb);
        if (is_value) {
            const v = self.lowerExpr(ie.then_branch);
            if (!self.currentBlockHasTerminator()) {
                self.builder.br(merge_bb, &.{v});
            }
        } else {
            self.lowerBlock(ie.then_branch);
            if (!self.currentBlockHasTerminator()) {
                self.builder.br(merge_bb, &.{});
            }
        }

        // Else branch
        if (has_else) {
            self.builder.switchToBlock(else_bb.?);
            if (is_value) {
                const v = self.lowerExpr(ie.else_branch.?);
                if (!self.currentBlockHasTerminator()) {
                    self.builder.br(merge_bb, &.{v});
                }
            } else {
                self.lowerBlock(ie.else_branch.?);
                if (!self.currentBlockHasTerminator()) {
                    self.builder.br(merge_bb, &.{});
                }
            }
        }

        // Continue at merge
        self.builder.switchToBlock(merge_bb);
        if (is_value) {
            return self.builder.blockParam(merge_bb, 0, .s64);
        }
        return self.builder.constInt(0, .void);
    }

    /// Lower a `while` loop into three blocks: `while.hdr` (condition),
    /// `while.body` and `while.exit`. `break` targets the exit block and
    /// `continue` targets the header. Always yields a void-typed constant.
    fn lowerWhile(self: *Lowering, we: *const ast.WhileExpr) Ref {
        const header_bb = self.freshBlock("while.hdr");
        const body_bb = self.freshBlock("while.body");
        const exit_bb = self.freshBlock("while.exit");

        // Branch to header
        self.builder.br(header_bb, &.{});

        // Header: evaluate condition
        self.builder.switchToBlock(header_bb);
        const cond = self.lowerExpr(we.condition);
        self.builder.condBr(cond, body_bb, &.{}, exit_bb, &.{});

        // Body
        self.builder.switchToBlock(body_bb);

        // Save and set loop targets; the enclosing loop's targets are
        // restored when this function returns.
        const old_break = self.break_target;
        const old_continue = self.continue_target;
        self.break_target = exit_bb;
        self.continue_target = header_bb;
        defer {
            self.break_target = old_break;
            self.continue_target = old_continue;
        }

        self.lowerBlock(we.body);
        if (!self.currentBlockHasTerminator()) {
            self.builder.br(header_bb, &.{});
        }

        // Continue at exit
        self.builder.switchToBlock(exit_bb);
        return self.builder.constInt(0, .void);
    }

    /// Lower a `for` loop over an iterable into an index-driven loop.
    ///
    /// Emitted blocks: `for.hdr` (index < length test), `for.body`,
    /// `for.inc` (index increment) and `for.exit`. `break` targets the exit
    /// block; `continue` targets `for.inc` so the index always advances
    /// before the condition is re-tested. Element and index types are
    /// stubbed as `.s64`. Always yields a void-typed constant.
    fn lowerFor(self: *Lowering, fe: *const ast.ForExpr) Ref {
        // Lower iterable
        const iterable = self.lowerExpr(fe.iterable);

        // Get length
        const len = self.builder.emit(.{ .length = .{ .operand = iterable } }, .s64);

        // Create index variable
        const idx_slot = self.builder.alloca(.s64);
        const zero = self.builder.constInt(0, .s64);
        self.builder.store(idx_slot, zero);

        const header_bb = self.freshBlock("for.hdr");
        const body_bb = self.freshBlock("for.body");
        const inc_bb = self.freshBlock("for.inc");
        const exit_bb = self.freshBlock("for.exit");

        self.builder.br(header_bb, &.{});

        // Header: compare index < length
        self.builder.switchToBlock(header_bb);
        const idx_val = self.builder.load(idx_slot, .s64);
        const cmp = self.builder.cmpLt(idx_val, len);
        self.builder.condBr(cmp, body_bb, &.{}, exit_bb, &.{});

        // Body
        self.builder.switchToBlock(body_bb);

        // Bind element
        const elem = self.builder.emit(.{ .index_get = .{ .lhs = iterable, .rhs = idx_val } }, .s64);

        var body_scope = Scope.init(self.alloc, self.scope);
        const old_scope = self.scope;
        self.scope = &body_scope;

        body_scope.put(fe.capture_name, .{ .ref = elem, .ty = .s64, .is_alloca = false });

        // Bind index if requested
        if (fe.index_name) |iname| {
            body_scope.put(iname, .{ .ref = idx_val, .ty = .s64, .is_alloca = false });
        }

        // Save and set loop targets. `continue` must go through the
        // increment block: jumping straight to the header would re-test the
        // same index and never make progress.
        const old_break = self.break_target;
        const old_continue = self.continue_target;
        self.break_target = exit_bb;
        self.continue_target = inc_bb;

        self.lowerBlock(fe.body);

        self.break_target = old_break;
        self.continue_target = old_continue;
        self.scope = old_scope;
        body_scope.deinit();

        // Fall through from the end of the body into the increment block
        if (!self.currentBlockHasTerminator()) {
            self.builder.br(inc_bb, &.{});
        }

        // Increment index, then re-test the condition
        self.builder.switchToBlock(inc_bb);
        const cur_idx = self.builder.load(idx_slot, .s64);
        const one = self.builder.constInt(1, .s64);
        const next_idx = self.builder.add(cur_idx, one, .s64);
        self.builder.store(idx_slot, next_idx);
        self.builder.br(header_bb, &.{});

        // Continue at exit
        self.builder.switchToBlock(exit_bb);
        return self.builder.constInt(0, .void);
    }

    /// Lower a `match` expression as a switch over the subject's enum tag.
    ///
    /// Every arm gets its own `match.arm` block; an arm without a pattern
    /// becomes the switch default. When no such arm exists a synthetic
    /// `match.unr` block containing `unreachable` is used as the default.
    /// Case values and payload field indices are the arm's position in the
    /// source (stub until real variant resolution exists). All arms that
    /// fall through branch to `match.merge`. Always yields a void-typed
    /// constant.
    fn lowerMatch(self: *Lowering, me: *const ast.MatchExpr) Ref {
        const subject = self.lowerExpr(me.subject);

        const merge_bb = self.freshBlock("match.merge");

        // Build arm blocks and case list
        var default_bb: ?BlockId = null;
        var cases = std.ArrayList(inst_mod.SwitchBranch.Case).empty;
        defer cases.deinit(self.alloc);
        var arm_blocks = std.ArrayList(BlockId).empty;
        defer arm_blocks.deinit(self.alloc);

        for (me.arms, 0..) |arm, i| {
            const arm_bb = self.freshBlock("match.arm");
            arm_blocks.append(self.alloc, arm_bb) catch unreachable;

            if (arm.pattern == null) {
                // Default/else arm
                default_bb = arm_bb;
            } else {
                cases.append(self.alloc, .{
                    .value = @intCast(i),
                    .target = arm_bb,
                    .args = &.{},
                }) catch unreachable;
            }
        }

        // Remember whether the source supplied a default arm before a
        // synthetic one is substituted below.
        const has_explicit_default = default_bb != null;

        // If no default arm, create an unreachable default
        const switch_default = default_bb orelse self.freshBlock("match.unr");

        // Extract tag and switch
        const tag = self.builder.enumTag(subject);
        self.builder.switchBr(tag, cases.items, switch_default, &.{});

        // Lower each arm's body
        for (me.arms, 0..) |arm, i| {
            const arm_bb = arm_blocks.items[i];
            self.builder.switchToBlock(arm_bb);

            var arm_scope = Scope.init(self.alloc, self.scope);
            const old_scope = self.scope;
            self.scope = &arm_scope;

            if (arm.capture) |capture_name| {
                const payload = self.builder.emit(.{ .enum_payload = .{
                    .base = subject,
                    .field_index = @intCast(i),
                } }, .s64);
                arm_scope.put(capture_name, .{ .ref = payload, .ty = .s64, .is_alloca = false });
            }

            self.lowerBlock(arm.body);

            self.scope = old_scope;
            arm_scope.deinit();

            if (!self.currentBlockHasTerminator()) {
                self.builder.br(merge_bb, &.{});
            }
        }

        // Emit unreachable in synthetic default block if needed
        if (!has_explicit_default) {
            self.builder.switchToBlock(switch_default);
            self.builder.emitUnreachable();
        }

        self.builder.switchToBlock(merge_bb);
        return self.builder.constInt(0, .void);
    }

    /// Lower `break`: branch to the innermost loop's exit block, if any.
    /// Emits nothing when no loop is active. Returns `Ref.none`.
    fn lowerBreak(self: *Lowering) Ref {
        if (self.break_target) |target| {
            self.builder.br(target, &.{});
        }
        return Ref.none;
    }

    /// Lower `continue`: branch to the innermost loop's continue block, if
    /// any. Emits nothing when no loop is active. Returns `Ref.none`.
    fn lowerContinue(self: *Lowering) Ref {
        if (self.continue_target) |target| {
            self.builder.br(target, &.{});
        }
        return Ref.none;
    }

    // ── Struct/enum/union ops ───────────────────────────────────────

    fn lowerStructLiteral(self: *Lowering, sl: *const ast.StructLiteral) Ref {
        var fields = std.ArrayList(Ref).empty;
        defer fields.deinit(self.alloc);

        for (sl.field_inits) |fi| {
            const val = self.lowerExpr(fi.value);
            fields.append(self.alloc, val) catch unreachable;
        }

        const ty: TypeId = if (sl.struct_name) |name| blk: {
            const name_id = self.module.types.internString(name);
            break :blk self.module.types.intern(.{ .@"struct" = .{ .name = name_id, .fields = &.{} } });
        } else .s64;

        const result = self.builder.structInit(fields.items, ty);

        // Lower init block if present
        if (sl.init_block) |ib| {
self.lowerBlock(ib); + } + + return result; + } + + fn lowerFieldAccess(self: *Lowering, fa: *const ast.FieldAccess) Ref { + const obj = self.lowerExpr(fa.object); + + // Special fields on slices/strings + if (std.mem.eql(u8, fa.field, "len")) { + return self.builder.emit(.{ .length = .{ .operand = obj } }, .s64); + } + if (std.mem.eql(u8, fa.field, "ptr")) { + return self.builder.emit(.{ .data_ptr = .{ .operand = obj } }, .s64); + } + + // Generic struct field — stub index 0 (real resolution needs type info) + return self.builder.structGet(obj, 0, .s64); + } + + fn lowerEnumLiteral(self: *Lowering, el: *const ast.EnumLiteral) Ref { + _ = el; + return self.builder.enumInit(0, Ref.none, .s64); + } + + fn lowerArrayLiteral(self: *Lowering, al: *const ast.ArrayLiteral) Ref { + var elems = std.ArrayList(Ref).empty; + defer elems.deinit(self.alloc); + + for (al.elements) |elem| { + const val = self.lowerExpr(elem); + elems.append(self.alloc, val) catch unreachable; + } + + return self.builder.structInit(elems.items, .s64); + } + + fn lowerIndexExpr(self: *Lowering, ie: *const ast.IndexExpr) Ref { + const obj = self.lowerExpr(ie.object); + const idx = self.lowerExpr(ie.index); + return self.builder.emit(.{ .index_get = .{ .lhs = obj, .rhs = idx } }, .s64); + } + + fn lowerSliceExpr(self: *Lowering, se: *const ast.SliceExpr) Ref { + const obj = self.lowerExpr(se.object); + const lo = if (se.start) |s| self.lowerExpr(s) else self.builder.constInt(0, .s64); + const hi = if (se.end) |e| self.lowerExpr(e) else self.builder.emit(.{ .length = .{ .operand = obj } }, .s64); + return self.builder.emit(.{ .subslice = .{ .base = obj, .lo = lo, .hi = hi } }, .s64); + } + + fn lowerTupleLiteral(self: *Lowering, tl: *const ast.TupleLiteral) Ref { + var elems = std.ArrayList(Ref).empty; + defer elems.deinit(self.alloc); + + for (tl.elements) |elem| { + const val = self.lowerExpr(elem.value); + elems.append(self.alloc, val) catch unreachable; + } + + const owned = self.alloc.dupe(Ref, 
elems.items) catch unreachable; + return self.builder.emit(.{ .tuple_init = .{ .fields = owned } }, .s64); + } + + fn lowerDerefExpr(self: *Lowering, de: *const ast.DerefExpr) Ref { + const ptr = self.lowerExpr(de.operand); + return self.builder.emit(.{ .deref = .{ .operand = ptr } }, .s64); + } + + fn lowerForceUnwrap(self: *Lowering, fu: *const ast.ForceUnwrap) Ref { + const val = self.lowerExpr(fu.operand); + return self.builder.optionalUnwrap(val, .s64); + } + + fn lowerNullCoalesce(self: *Lowering, nc: *const ast.NullCoalesce) Ref { + const lhs = self.lowerExpr(nc.lhs); + const rhs = self.lowerExpr(nc.rhs); + return self.builder.emit(.{ .optional_coalesce = .{ .lhs = lhs, .rhs = rhs } }, .s64); + } + + // ── Calls ─────────────────────────────────────────────────────── + + fn lowerCall(self: *Lowering, c: *const ast.Call) Ref { + // Lower args + var args = std.ArrayList(Ref).empty; + defer args.deinit(self.alloc); + for (c.args) |arg| { + const val = self.lowerExpr(arg); + args.append(self.alloc, val) catch unreachable; + } + + switch (c.callee.data) { + .identifier => |id| { + // Check builtins first + if (resolveBuiltin(id.name)) |bid| { + return self.builder.callBuiltin(bid, args.items, .s64); + } + // Look up direct function + if (self.resolveFuncByName(id.name)) |fid| { + return self.builder.call(fid, args.items, .s64); + } + // May be a variable holding a function pointer + if (self.scope) |scope| { + if (scope.lookup(id.name)) |binding| { + const callee_ref = if (binding.is_alloca) self.builder.load(binding.ref, binding.ty) else binding.ref; + const owned = self.alloc.dupe(Ref, args.items) catch unreachable; + return self.builder.emit(.{ .call_indirect = .{ .callee = callee_ref, .args = owned } }, .s64); + } + } + // Unresolved — emit placeholder + return self.emitPlaceholder(id.name); + }, + .field_access => |fa| { + // Method call: obj.method(args) → prepend obj + const obj = self.lowerExpr(fa.object); + var method_args = std.ArrayList(Ref).empty; + 
defer method_args.deinit(self.alloc); + method_args.append(self.alloc, obj) catch unreachable; + for (args.items) |a| { + method_args.append(self.alloc, a) catch unreachable; + } + // Try to resolve as qualified function + if (self.resolveFuncByName(fa.field)) |fid| { + return self.builder.call(fid, method_args.items, .s64); + } + return self.emitPlaceholder(fa.field); + }, + .enum_literal => { + // .Variant(payload) — tagged enum construction + const payload = if (args.items.len > 0) args.items[0] else Ref.none; + return self.builder.enumInit(0, payload, .s64); + }, + else => { + // Indirect call through expression + const callee_ref = self.lowerExpr(c.callee); + const owned = self.alloc.dupe(Ref, args.items) catch unreachable; + return self.builder.emit(.{ .call_indirect = .{ .callee = callee_ref, .args = owned } }, .s64); + }, + } + } + + fn resolveFuncByName(self: *Lowering, name: []const u8) ?FuncId { + const name_id = self.module.types.internString(name); + for (self.module.functions.items, 0..) 
|func, i| {
            if (func.name == name_id) return FuncId.fromIndex(@intCast(i));
        }
        return null;
    }

    /// Map a callee name to a builtin id, or return null when the name is
    /// not one of the builtins listed below.
    fn resolveBuiltin(name: []const u8) ?inst_mod.BuiltinId {
        const builtins = .{
            .{ "print", inst_mod.BuiltinId.print },
            .{ "out", inst_mod.BuiltinId.out },
            .{ "sqrt", inst_mod.BuiltinId.sqrt },
            .{ "size_of", inst_mod.BuiltinId.size_of },
            .{ "cast", inst_mod.BuiltinId.cast },
            .{ "malloc", inst_mod.BuiltinId.malloc },
            .{ "free", inst_mod.BuiltinId.free },
            .{ "memcpy", inst_mod.BuiltinId.memcpy },
            .{ "memset", inst_mod.BuiltinId.memset },
        };
        inline for (builtins) |entry| {
            if (std.mem.eql(u8, name, entry[0])) return entry[1];
        }
        return null;
    }

    // ── Lambda/closure ────────────────────────────────────────────

    /// Lower a lambda as a fresh anonymous function named `__lambda_<n>`
    /// and emit a `closure_create` referencing it in the enclosing function.
    ///
    /// The builder's function, current block and instruction counter, the
    /// active scope, and the enclosing loop's break/continue targets are
    /// saved before the lambda body is lowered and restored afterwards.
    /// Parameters are bound to stack slots initialised with placeholder
    /// constants, as in `lowerFunction`.
    fn lowerLambda(self: *Lowering, lam: *const ast.Lambda) Ref {
        // Lower the lambda body as a new anonymous function
        var buf: [64]u8 = undefined;
        const name = std.fmt.bufPrint(&buf, "__lambda_{d}", .{self.block_counter}) catch "__lambda";
        self.block_counter += 1;

        // Save current builder state
        const saved_func = self.builder.func;
        const saved_block = self.builder.current_block;
        const saved_counter = self.builder.inst_counter;
        const saved_scope = self.scope;

        // Loop targets belong to the enclosing function. Clear them so a
        // stray `break`/`continue` in the lambda body cannot branch into a
        // block of a different function.
        const saved_break = self.break_target;
        const saved_continue = self.continue_target;
        self.break_target = null;
        self.continue_target = null;

        // Build param list (not deinited — function owns the slice)
        var params = std.ArrayList(Function.Param).empty;
        for (lam.params) |p| {
            const pty = self.resolveParamType(&p);
            params.append(self.alloc, .{
                .name = self.module.types.internString(p.name),
                .ty = pty,
            }) catch unreachable;
        }

        const ret_ty = self.resolveReturnType2(lam.return_type);
        const name_id = self.module.types.internString(name);
        const func_id = self.builder.beginFunction(name_id, params.items, ret_ty);

        // Create entry block
        const entry_name = self.module.types.internString("entry");
        const entry = self.builder.appendBlock(entry_name, &.{});
        self.builder.switchToBlock(entry);

        // Create scope and bind params
        var lambda_scope = Scope.init(self.alloc, saved_scope);
        self.scope = &lambda_scope;

        for (lam.params) |p| {
            const pty = self.resolveParamType(&p);
            const slot = self.builder.alloca(pty);
            const placeholder = self.builder.constInt(0, pty);
            self.builder.store(slot, placeholder);
            lambda_scope.put(p.name, .{ .ref = slot, .ty = pty, .is_alloca = true });
        }

        // Lower body
        self.lowerBlock(lam.body);
        self.ensureTerminator(ret_ty);
        self.builder.finalize();

        // Restore builder state
        self.scope = saved_scope;
        lambda_scope.deinit();
        self.break_target = saved_break;
        self.continue_target = saved_continue;
        self.builder.func = saved_func;
        self.builder.current_block = saved_block;
        self.builder.inst_counter = saved_counter;

        // Emit closure_create referencing the new function
        return self.builder.closureCreate(func_id, Ref.none, .s64);
    }

    /// Resolve an optional return-type node; a missing node means `.void`.
    fn resolveReturnType2(self: *Lowering, rt: ?*const Node) TypeId {
        if (rt) |r| return type_bridge.resolveAstType(r, &self.module.types);
        return .void;
    }

    // ── Chained comparison ──────────────────────────────────────────

    /// Lower `a < b < c` as `(a < b) and (b < c)`.
    ///
    /// Each operand is lowered exactly once, left to right, and the shared
    /// middle operand's `Ref` is reused by both neighbouring comparisons.
    /// The links are combined with `bool_and`; nothing is short-circuited.
    /// A chain with fewer than two operands or no operators lowers to
    /// constant `true`. If the operand and operator counts disagree, only
    /// the comparisons that have both operands available are emitted.
    fn lowerChainedComparison(self: *Lowering, cc: *const ast.ChainedComparison) Ref {
        // a < b < c → (a < b) and (b < c)
        if (cc.operands.len < 2 or cc.ops.len == 0) {
            return self.builder.constBool(true);
        }

        // N operators need N + 1 operands; never index past either slice.
        const link_count = @min(cc.ops.len, cc.operands.len - 1);

        var prev = self.lowerExpr(cc.operands[0]);
        var next = self.lowerExpr(cc.operands[1]);
        var result = self.emitCmp(prev, next, cc.ops[0]);

        var i: usize = 1;
        while (i < link_count) : (i += 1) {
            // The right-hand side of the previous link is the left-hand
            // side of this one; reuse it instead of lowering it again.
            prev = next;
            next = self.lowerExpr(cc.operands[i + 1]);
            const link = self.emitCmp(prev, next, cc.ops[i]);
            result = self.builder.emit(.{ .bool_and = .{ .lhs = result, .rhs = link } }, .bool);
        }

        return result;
    }

    /// Emit a single comparison for `op`. Operators that are not
    /// comparisons lower to constant `false`.
    fn emitCmp(self: *Lowering, lhs: Ref, rhs: Ref, op: ast.BinaryOp.Op) Ref {
        return switch (op) {
            .eq => self.builder.cmpEq(lhs, rhs),
            .neq => self.builder.emit(.{ .cmp_ne = .{ .lhs = lhs, .rhs = rhs } }, .bool),
            .lt => self.builder.cmpLt(lhs, rhs),
            .lte => self.builder.emit(.{ .cmp_le = .{ .lhs = lhs, .rhs = rhs } }, .bool),
            .gt => self.builder.cmpGt(lhs, rhs),
            .gte => self.builder.emit(.{ .cmp_ge = .{ .lhs = lhs, .rhs = rhs } }, .bool),
            else => self.builder.constBool(false),
        };
    }

    // ── Defer/Push/MultiAssign ──────────────────────────────────────

    fn lowerDefer(self: *Lowering, ds: *const ast.DeferStmt) void {
        // For now, lower the deferred expression immediately as a placeholder.
        // Real defer needs recording and replay at scope exits.
        _ = self.lowerExpr(ds.expr);
    }

    fn lowerPush(self: *Lowering, ps: *const ast.PushStmt) void {
        // push context_expr { body }
        // → context_save, context_store, body, context_restore
        const save = self.builder.emit(.context_save, .s64);
        const ctx_val = self.lowerExpr(ps.context_expr);
        const field_id = self.module.types.internString("allocator");
        self.builder.contextStore(field_id, ctx_val);
        self.lowerBlock(ps.body);
        _ = self.builder.emit(.{ .context_restore = .{ .operand = save } }, .void);
    }

    /// Lower `a, b = x, y`. Targets beyond the number of values are
    /// ignored; only identifier targets bound to stack slots are stored to,
    /// any other target kind emits a placeholder.
    fn lowerMultiAssign(self: *Lowering, ma: *const ast.MultiAssign) void {
        // Evaluate all RHS values first, then assign to LHS targets
        var vals = std.ArrayList(Ref).empty;
        defer vals.deinit(self.alloc);
        for (ma.values) |v| {
            vals.append(self.alloc, self.lowerExpr(v)) catch unreachable;
        }

        for (ma.targets, 0..) |target, i| {
            if (i >= vals.items.len) break;
            const val = vals.items[i];
            switch (target.data) {
                .identifier => |id| {
                    if (self.scope) |scope| {
                        if (scope.lookup(id.name)) |binding| {
                            if (binding.is_alloca) {
                                self.builder.store(binding.ref, val);
                            }
                        }
                    }
                },
                else => {
                    _ = self.emitPlaceholder("multi_assign_target");
                },
            }
        }
    }

    // ── Comptime lowering ────────────────────────────────────────────

    /// Lower a `#run expr` that appears as a top-level constant binding:
    ///   NAME :: #run expr;
    /// Creates a comptime function wrapping the expression (for later
    /// interpretation), plus a global constant to hold the result.
+ fn lowerComptimeGlobal(self: *Lowering, name: []const u8, expr: *const Node, type_ann: ?*const Node) void { + const ret_ty = self.resolveType(type_ann); + const func_id = self.createComptimeFunction(name, expr, ret_ty); + + // Add a global constant whose initializer will be filled by the interpreter. + const name_id = self.module.types.internString(name); + _ = self.module.addGlobal(.{ + .name = name_id, + .ty = ret_ty, + .init_val = null, // will be filled by interpreter (Phase 2.3) + .is_const = true, + .comptime_func = func_id, + }); + } + + /// Lower a standalone `#run expr;` at the top level (side-effect only). + /// Creates a comptime function that the interpreter should execute. + fn lowerComptimeSideEffect(self: *Lowering, expr: *const Node) void { + _ = self.createComptimeFunction("__run", expr, .void); + } + + /// Lower a `#run expr` that appears inline within an expression. + /// Creates a comptime function and emits a `call` to it, so the + /// interpreter can evaluate it and replace with the constant result. + fn lowerInlineComptime(self: *Lowering, expr: *const Node) Ref { + const ret_ty: TypeId = .s64; // stub — real type inferred later + const func_id = self.createComptimeFunction("__ct", expr, ret_ty); + // Emit a call to the comptime function. At interpretation time, + // this will be evaluated and the result inlined as a constant. + return self.builder.call(func_id, &.{}, ret_ty); + } + + /// Creates a temporary function marked `is_comptime = true` that wraps + /// the given expression as its return value. Returns the FuncId. 
+ pub fn createComptimeFunction(self: *Lowering, prefix: []const u8, expr: *const Node, ret_ty: TypeId) FuncId { + var buf: [64]u8 = undefined; + const name = std.fmt.bufPrint(&buf, "{s}_{d}", .{ prefix, self.comptime_counter }) catch prefix; + self.comptime_counter += 1; + + // Save current builder state + const saved_func = self.builder.func; + const saved_block = self.builder.current_block; + const saved_counter = self.builder.inst_counter; + const saved_scope = self.scope; + + // Create the comptime function (no params, returns ret_ty) + const name_id = self.module.types.internString(name); + const func_id = self.builder.beginFunction(name_id, &.{}, ret_ty); + + // Mark as comptime + self.module.getFunctionMut(func_id).is_comptime = true; + + // Create entry block + const entry_name = self.module.types.internString("entry"); + const entry = self.builder.appendBlock(entry_name, &.{}); + self.builder.switchToBlock(entry); + + // Create a scope that chains to the enclosing scope (so the + // expression can reference names visible at the #run site). 
+ var ct_scope = Scope.init(self.alloc, saved_scope); + self.scope = &ct_scope; + + // Lower the expression and return it + const result = self.lowerExpr(expr); + if (ret_ty == .void) { + self.builder.retVoid(); + } else { + self.builder.ret(result, ret_ty); + } + + self.builder.finalize(); + + // Restore builder state + self.scope = saved_scope; + ct_scope.deinit(); + self.builder.func = saved_func; + self.builder.current_block = saved_block; + self.builder.inst_counter = saved_counter; + + return func_id; + } + + // ── Block helpers ─────────────────────────────────────────────── + + fn freshBlock(self: *Lowering, prefix: []const u8) BlockId { + return self.freshBlockWithParams(prefix, &.{}); + } + + fn freshBlockWithParams(self: *Lowering, prefix: []const u8, params: []const TypeId) BlockId { + var buf: [64]u8 = undefined; + const name = std.fmt.bufPrint(&buf, "{s}.{d}", .{ prefix, self.block_counter }) catch prefix; + self.block_counter += 1; + const name_id = self.module.types.internString(name); + return self.builder.appendBlock(name_id, params); + } + + fn currentBlockHasTerminator(self: *Lowering) bool { + const func = self.builder.module.getFunctionMut(self.builder.func.?); + const block_idx = self.builder.current_block orelse return true; + const block = &func.blocks.items[block_idx.index()]; + if (block.insts.items.len > 0) { + const last_op = block.insts.items[block.insts.items.len - 1].op; + return switch (last_op) { + .ret, .ret_void, .br, .cond_br, .switch_br, .@"unreachable" => true, + else => false, + }; + } + return false; + } + + // ── Type resolution ───────────────────────────────────────────── + // Delegates to type_bridge for full AST type node resolution. 
+ + fn resolveReturnType(self: *Lowering, fd: *const ast.FnDecl) TypeId { + if (fd.return_type) |rt| { + return type_bridge.resolveAstType(rt, &self.module.types); + } + return .void; + } + + fn resolveParamType(self: *Lowering, p: *const ast.Param) TypeId { + return type_bridge.resolveAstType(p.type_expr, &self.module.types); + } + + fn resolveType(self: *Lowering, type_ann: ?*const Node) TypeId { + return type_bridge.resolveAstType(type_ann, &self.module.types); + } + + // ── Helpers ───────────────────────────────────────────────────── + + fn emitPlaceholder(self: *Lowering, name: []const u8) Ref { + const sid = self.module.types.internString(name); + return self.builder.emit(.{ .placeholder = sid }, .s64); + } + + fn ensureTerminator(self: *Lowering, ret_ty: TypeId) void { + if (self.currentBlockHasTerminator()) return; + if (ret_ty == .void) { + self.builder.retVoid(); + } else { + const zero = self.builder.constInt(0, ret_ty); + self.builder.ret(zero, ret_ty); + } + } +}; diff --git a/src/ir/module.test.zig b/src/ir/module.test.zig new file mode 100644 index 0000000..2046815 --- /dev/null +++ b/src/ir/module.test.zig @@ -0,0 +1,116 @@ +// Tests for module.zig +const std = @import("std"); +const types = @import("types.zig"); +const inst_mod = @import("inst.zig"); +const mod_mod = @import("module.zig"); + +const TypeId = types.TypeId; +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const Function = inst_mod.Function; +const GlobalId = inst_mod.GlobalId; +const Module = mod_mod.Module; +const Builder = mod_mod.Builder; + +test "Builder: build add(a: s64, b: s64) -> s64" { + const alloc = std.testing.allocator; + var mod = Module.init(alloc); + defer mod.deinit(); + + var b = Builder.init(&mod); + + const name_add = mod.types.internString("add"); + const name_a = mod.types.internString("a"); + const name_b = mod.types.internString("b"); + const name_entry = mod.types.internString("entry"); + + const params = 
&[_]Function.Param{ + .{ .name = name_a, .ty = .s64 }, + .{ .name = name_b, .ty = .s64 }, + }; + const func_id = b.beginFunction(name_add, params, .s64); + + const entry = b.appendBlock(name_entry, &.{}); + b.switchToBlock(entry); + + // Load params (in real lowering, params are block params of entry) + const a_ref = b.constInt(0, .s64); // placeholder for param a + const b_ref = b.constInt(0, .s64); // placeholder for param b + const sum = b.add(a_ref, b_ref, .s64); + b.ret(sum, .s64); + + b.finalize(); + + // Verify + const func = mod.getFunction(func_id); + try std.testing.expectEqual(@as(usize, 2), func.params.len); + try std.testing.expectEqual(TypeId.s64, func.ret); + try std.testing.expectEqual(@as(usize, 1), func.blocks.items.len); + + const blk = &func.blocks.items[0]; + try std.testing.expectEqual(@as(usize, 4), blk.insts.items.len); // 2 consts + add + ret +} + +test "Builder: conditional branch" { + const alloc = std.testing.allocator; + var mod = Module.init(alloc); + defer mod.deinit(); + + var b = Builder.init(&mod); + + const name_fn = mod.types.internString("test_fn"); + const name_entry = mod.types.internString("entry"); + const name_then = mod.types.internString("then"); + const name_else = mod.types.internString("else"); + const name_merge = mod.types.internString("merge"); + + _ = b.beginFunction(name_fn, &.{}, .s32); + + const entry = b.appendBlock(name_entry, &.{}); + const then_bb = b.appendBlock(name_then, &.{}); + const else_bb = b.appendBlock(name_else, &.{}); + const merge_bb = b.appendBlock(name_merge, &[_]TypeId{.s32}); + + b.switchToBlock(entry); + const cond = b.constBool(true); + b.condBr(cond, then_bb, &.{}, else_bb, &.{}); + + b.switchToBlock(then_bb); + const v1 = b.constInt(42, .s32); + b.br(merge_bb, &.{v1}); + + b.switchToBlock(else_bb); + const v2 = b.constInt(0, .s32); + b.br(merge_bb, &.{v2}); + + b.switchToBlock(merge_bb); + const result = b.emit(.{ .block_param = .{ .block = merge_bb, .param_index = 0 } }, .s32); + 
b.ret(result, .s32); + + b.finalize(); + + // Verify: 4 blocks, correct instruction counts + const func = mod.getFunction(@enumFromInt(0)); + try std.testing.expectEqual(@as(usize, 4), func.blocks.items.len); + try std.testing.expectEqual(@as(usize, 2), func.blocks.items[0].insts.items.len); // const_bool + cond_br + try std.testing.expectEqual(@as(usize, 2), func.blocks.items[1].insts.items.len); // const_int + br + try std.testing.expectEqual(@as(usize, 2), func.blocks.items[2].insts.items.len); // const_int + br + try std.testing.expectEqual(@as(usize, 2), func.blocks.items[3].insts.items.len); // block_param + ret +} + +test "Module: globals" { + const alloc = std.testing.allocator; + var mod = Module.init(alloc); + defer mod.deinit(); + + const name = mod.types.internString("counter"); + const id = mod.addGlobal(.{ + .name = name, + .ty = .s32, + .init_val = .{ .int = 0 }, + }); + + try std.testing.expectEqual(GlobalId.fromIndex(0), id); + try std.testing.expectEqual(TypeId.s32, mod.globals.items[0].ty); +} diff --git a/src/ir/module.zig b/src/ir/module.zig new file mode 100644 index 0000000..9b86f83 --- /dev/null +++ b/src/ir/module.zig @@ -0,0 +1,416 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const types = @import("types.zig"); +const inst = @import("inst.zig"); + +const TypeId = types.TypeId; +const TypeInfo = types.TypeInfo; +const TypeTable = types.TypeTable; +const StringId = types.StringId; +const Ref = inst.Ref; +const BlockId = inst.BlockId; +const FuncId = inst.FuncId; +const GlobalId = inst.GlobalId; +const Inst = inst.Inst; +const Op = inst.Op; +const Block = inst.Block; +const Function = inst.Function; +const Global = inst.Global; +const Span = inst.Span; + +// ── Module ────────────────────────────────────────────────────────────── + +pub const Module = struct { + types: TypeTable, + functions: std.ArrayList(Function), + globals: std.ArrayList(Global), + /// Maps (protocol_ty, concrete_ty) → list of method FuncIds. 
+ impl_table: ImplTable, + alloc: Allocator, + + pub fn init(alloc: Allocator) Module { + return .{ + .types = TypeTable.init(alloc), + .functions = std.ArrayList(Function).empty, + .globals = std.ArrayList(Global).empty, + .impl_table = ImplTable.init(alloc), + .alloc = alloc, + }; + } + + pub fn deinit(self: *Module) void { + for (self.functions.items) |*func| { + func.deinit(self.alloc); + } + self.functions.deinit(self.alloc); + self.globals.deinit(self.alloc); + self.impl_table.deinit(); + self.types.deinit(); + } + + pub fn addFunction(self: *Module, func: Function) FuncId { + const id = FuncId.fromIndex(@intCast(self.functions.items.len)); + self.functions.append(self.alloc, func) catch unreachable; + return id; + } + + pub fn getFunction(self: *const Module, id: FuncId) *const Function { + return &self.functions.items[id.index()]; + } + + pub fn getFunctionMut(self: *Module, id: FuncId) *Function { + return &self.functions.items[id.index()]; + } + + pub fn addGlobal(self: *Module, global: Global) GlobalId { + const id = GlobalId.fromIndex(@intCast(self.globals.items.len)); + self.globals.append(self.alloc, global) catch unreachable; + return id; + } +}; + +// ── ImplTable ─────────────────────────────────────────────────────────── + +pub const ImplKey = struct { + protocol: TypeId, + concrete: TypeId, +}; + +pub const ImplTable = struct { + map: std.HashMap(ImplKey, []const FuncId, ImplKeyContext, 80), + alloc: Allocator, + + pub fn init(alloc: Allocator) ImplTable { + return .{ + .map = std.HashMap(ImplKey, []const FuncId, ImplKeyContext, 80).init(alloc), + .alloc = alloc, + }; + } + + pub fn deinit(self: *ImplTable) void { + self.map.deinit(); + } + + pub fn put(self: *ImplTable, key: ImplKey, methods: []const FuncId) void { + self.map.put(key, methods) catch unreachable; + } + + pub fn get(self: *const ImplTable, key: ImplKey) ?[]const FuncId { + return self.map.get(key); + } + + const ImplKeyContext = struct { + pub fn hash(_: ImplKeyContext, key: 
ImplKey) u64 { + var h = std.hash.Wyhash.init(0); + h.update(std.mem.asBytes(&key.protocol)); + h.update(std.mem.asBytes(&key.concrete)); + return h.final(); + } + + pub fn eql(_: ImplKeyContext, a: ImplKey, b: ImplKey) bool { + return a.protocol == b.protocol and a.concrete == b.concrete; + } + }; +}; + +// ── Builder ───────────────────────────────────────────────────────────── +// Fluent API for constructing one function at a time. + +pub const Builder = struct { + module: *Module, + func: ?FuncId = null, + current_block: ?BlockId = null, + /// Running instruction counter within the current function (for Ref assignment). + inst_counter: u32 = 0, + + pub fn init(module: *Module) Builder { + return .{ .module = module }; + } + + // ── Function setup ────────────────────────────────────────────── + + pub fn beginFunction(self: *Builder, name: StringId, params: []const Function.Param, ret_ty: TypeId) FuncId { + const func = Function.init(name, params, ret_ty); + const id = self.module.addFunction(func); + self.func = id; + self.inst_counter = 0; + self.current_block = null; + return id; + } + + pub fn finalize(self: *Builder) void { + self.func = null; + self.current_block = null; + self.inst_counter = 0; + } + + // ── Blocks ────────────────────────────────────────────────────── + + pub fn appendBlock(self: *Builder, name: StringId, params: []const TypeId) BlockId { + const f = self.currentFunc(); + const id = BlockId.fromIndex(@intCast(f.blocks.items.len)); + // Dupe params so the block owns the memory (callers may pass stack slices). 
+ const owned_params = if (params.len > 0) + (self.module.alloc.dupe(TypeId, params) catch unreachable) + else + params; + f.blocks.append(self.module.alloc, Block.init(name, owned_params)) catch unreachable; + return id; + } + + pub fn switchToBlock(self: *Builder, block: BlockId) void { + self.current_block = block; + } + + // ── Emit helpers ──────────────────────────────────────────────── + + pub fn emit(self: *Builder, op: Op, ty: TypeId) Ref { + return self.emitSpan(op, ty, .{}); + } + + fn emitSpan(self: *Builder, op: Op, ty: TypeId, span: Span) Ref { + const block = self.currentBlock(); + const ref = Ref.fromIndex(self.inst_counter); + self.inst_counter += 1; + block.insts.append(self.module.alloc, .{ .op = op, .ty = ty, .span = span }) catch unreachable; + return ref; + } + + /// Emit an instruction with no meaningful result (terminators, stores). + fn emitVoid(self: *Builder, op: Op, ty: TypeId) void { + const block = self.currentBlock(); + self.inst_counter += 1; + block.insts.append(self.module.alloc, .{ .op = op, .ty = ty }) catch unreachable; + } + + // ── Constants ─────────────────────────────────────────────────── + + pub fn constInt(self: *Builder, val: i64, ty: TypeId) Ref { + return self.emit(.{ .const_int = val }, ty); + } + + pub fn constFloat(self: *Builder, val: f64, ty: TypeId) Ref { + return self.emit(.{ .const_float = val }, ty); + } + + pub fn constBool(self: *Builder, val: bool) Ref { + return self.emit(.{ .const_bool = val }, .bool); + } + + pub fn constString(self: *Builder, val: StringId) Ref { + return self.emit(.{ .const_string = val }, .string); + } + + pub fn constNull(self: *Builder, ty: TypeId) Ref { + return self.emit(.const_null, ty); + } + + pub fn constUndef(self: *Builder, ty: TypeId) Ref { + return self.emit(.const_undef, ty); + } + + // ── Arithmetic ────────────────────────────────────────────────── + + pub fn add(self: *Builder, lhs: Ref, rhs: Ref, ty: TypeId) Ref { + return self.emit(.{ .add = .{ .lhs = lhs, .rhs = 
rhs } }, ty); + } + + pub fn sub(self: *Builder, lhs: Ref, rhs: Ref, ty: TypeId) Ref { + return self.emit(.{ .sub = .{ .lhs = lhs, .rhs = rhs } }, ty); + } + + pub fn mul(self: *Builder, lhs: Ref, rhs: Ref, ty: TypeId) Ref { + return self.emit(.{ .mul = .{ .lhs = lhs, .rhs = rhs } }, ty); + } + + pub fn div(self: *Builder, lhs: Ref, rhs: Ref, ty: TypeId) Ref { + return self.emit(.{ .div = .{ .lhs = lhs, .rhs = rhs } }, ty); + } + + // ── Comparison ────────────────────────────────────────────────── + + pub fn cmpEq(self: *Builder, lhs: Ref, rhs: Ref) Ref { + return self.emit(.{ .cmp_eq = .{ .lhs = lhs, .rhs = rhs } }, .bool); + } + + pub fn cmpLt(self: *Builder, lhs: Ref, rhs: Ref) Ref { + return self.emit(.{ .cmp_lt = .{ .lhs = lhs, .rhs = rhs } }, .bool); + } + + pub fn cmpGt(self: *Builder, lhs: Ref, rhs: Ref) Ref { + return self.emit(.{ .cmp_gt = .{ .lhs = lhs, .rhs = rhs } }, .bool); + } + + // ── Memory ────────────────────────────────────────────────────── + + pub fn alloca(self: *Builder, ty: TypeId) Ref { + const ptr_ty = self.module.types.ptrTo(ty); + return self.emit(.{ .alloca = ty }, ptr_ty); + } + + pub fn load(self: *Builder, ptr: Ref, ty: TypeId) Ref { + return self.emit(.{ .load = .{ .operand = ptr } }, ty); + } + + pub fn store(self: *Builder, ptr: Ref, val: Ref) void { + self.emitVoid(.{ .store = .{ .ptr = ptr, .val = val } }, .void); + } + + // ── Struct ops ────────────────────────────────────────────────── + + pub fn structInit(self: *Builder, fields: []const Ref, ty: TypeId) Ref { + const owned = self.module.alloc.dupe(Ref, fields) catch unreachable; + return self.emit(.{ .struct_init = .{ .fields = owned } }, ty); + } + + pub fn structGet(self: *Builder, base: Ref, field_index: u32, ty: TypeId) Ref { + return self.emit(.{ .struct_get = .{ .base = base, .field_index = field_index } }, ty); + } + + pub fn structGep(self: *Builder, base: Ref, field_index: u32, ty: TypeId) Ref { + return self.emit(.{ .struct_gep = .{ .base = base, .field_index = 
field_index } }, ty); + } + + // ── Enum ops ──────────────────────────────────────────────────── + + pub fn enumInit(self: *Builder, tag: u32, payload: Ref, ty: TypeId) Ref { + return self.emit(.{ .enum_init = .{ .tag = tag, .payload = payload } }, ty); + } + + pub fn enumTag(self: *Builder, val: Ref) Ref { + return self.emit(.{ .enum_tag = .{ .operand = val } }, .s32); + } + + // ── Optional ops ──────────────────────────────────────────────── + + pub fn optionalWrap(self: *Builder, val: Ref, ty: TypeId) Ref { + return self.emit(.{ .optional_wrap = .{ .operand = val } }, ty); + } + + pub fn optionalUnwrap(self: *Builder, val: Ref, ty: TypeId) Ref { + return self.emit(.{ .optional_unwrap = .{ .operand = val } }, ty); + } + + pub fn optionalHasValue(self: *Builder, val: Ref) Ref { + return self.emit(.{ .optional_has_value = .{ .operand = val } }, .bool); + } + + // ── Calls ─────────────────────────────────────────────────────── + + pub fn call(self: *Builder, callee: FuncId, args: []const Ref, ret_ty: TypeId) Ref { + const owned = self.module.alloc.dupe(Ref, args) catch unreachable; + return self.emit(.{ .call = .{ .callee = callee, .args = owned } }, ret_ty); + } + + pub fn callClosure(self: *Builder, callee: Ref, args: []const Ref, ret_ty: TypeId) Ref { + const owned = self.module.alloc.dupe(Ref, args) catch unreachable; + return self.emit(.{ .call_closure = .{ .callee = callee, .args = owned } }, ret_ty); + } + + pub fn callBuiltin(self: *Builder, builtin: inst.BuiltinId, args: []const Ref, ret_ty: TypeId) Ref { + const owned = self.module.alloc.dupe(Ref, args) catch unreachable; + return self.emit(.{ .call_builtin = .{ .builtin = builtin, .args = owned } }, ret_ty); + } + + // ── Protocol ──────────────────────────────────────────────────── + + pub fn protocolCallDynamic(self: *Builder, receiver: Ref, method_index: u32, args: []const Ref, ret_ty: TypeId) Ref { + const owned = self.module.alloc.dupe(Ref, args) catch unreachable; + return self.emit(.{ 
.protocol_call_dynamic = .{ .receiver = receiver, .method_index = method_index, .args = owned } }, ret_ty); + } + + pub fn protocolErase(self: *Builder, concrete: Ref, protocol_type: TypeId) Ref { + return self.emit(.{ .protocol_erase = .{ .concrete = concrete, .protocol_type = protocol_type } }, protocol_type); + } + + // ── Closure ───────────────────────────────────────────────────── + + pub fn closureCreate(self: *Builder, func_id: FuncId, env: Ref, ty: TypeId) Ref { + return self.emit(.{ .closure_create = .{ .func = func_id, .env = env } }, ty); + } + + // ── Conversions ───────────────────────────────────────────────── + + pub fn widen(self: *Builder, operand: Ref, from: TypeId, to: TypeId) Ref { + return self.emit(.{ .widen = .{ .operand = operand, .from = from, .to = to } }, to); + } + + pub fn narrow(self: *Builder, operand: Ref, from: TypeId, to: TypeId) Ref { + return self.emit(.{ .narrow = .{ .operand = operand, .from = from, .to = to } }, to); + } + + // ── Any ───────────────────────────────────────────────────────── + + pub fn boxAny(self: *Builder, operand: Ref, source_type: TypeId) Ref { + return self.emit(.{ .box_any = .{ .operand = operand, .source_type = source_type } }, .any); + } + + // ── Context ───────────────────────────────────────────────────── + + pub fn contextLoad(self: *Builder, field: StringId, ty: TypeId) Ref { + return self.emit(.{ .context_load = .{ .field = field, .value = .none } }, ty); + } + + pub fn contextStore(self: *Builder, field: StringId, value: Ref) void { + self.emitVoid(.{ .context_store = .{ .field = field, .value = value } }, .void); + } + + // ── Terminators ───────────────────────────────────────────────── + + pub fn br(self: *Builder, target: BlockId, args: []const Ref) void { + const owned = self.module.alloc.dupe(Ref, args) catch unreachable; + self.emitVoid(.{ .br = .{ .target = target, .args = owned } }, .void); + } + + pub fn condBr(self: *Builder, cond: Ref, then_target: BlockId, then_args: []const Ref, 
else_target: BlockId, else_args: []const Ref) void { + const t_args = self.module.alloc.dupe(Ref, then_args) catch unreachable; + const e_args = self.module.alloc.dupe(Ref, else_args) catch unreachable; + self.emitVoid(.{ .cond_br = .{ + .cond = cond, + .then_target = then_target, + .then_args = t_args, + .else_target = else_target, + .else_args = e_args, + } }, .void); + } + + pub fn ret(self: *Builder, val: Ref, ty: TypeId) void { + self.emitVoid(.{ .ret = .{ .operand = val } }, ty); + } + + pub fn retVoid(self: *Builder) void { + self.emitVoid(.ret_void, .void); + } + + pub fn switchBr(self: *Builder, operand: Ref, cases: []const inst.SwitchBranch.Case, default: BlockId, default_args: []const Ref) void { + const owned_cases = self.module.alloc.dupe(inst.SwitchBranch.Case, cases) catch unreachable; + const owned_default_args = self.module.alloc.dupe(Ref, default_args) catch unreachable; + self.emitVoid(.{ .switch_br = .{ + .operand = operand, + .cases = owned_cases, + .default = default, + .default_args = owned_default_args, + } }, .void); + } + + pub fn emitUnreachable(self: *Builder) void { + self.emitVoid(.@"unreachable", .void); + } + + // ── Block params ─────────────────────────────────────────────── + + pub fn blockParam(self: *Builder, block: BlockId, param_index: u32, ty: TypeId) Ref { + return self.emit(.{ .block_param = .{ .block = block, .param_index = param_index } }, ty); + } + + // ── Internal helpers ──────────────────────────────────────────── + + fn currentFunc(self: *Builder) *Function { + return self.module.getFunctionMut(self.func.?); + } + + fn currentBlock(self: *Builder) *Block { + const f = self.currentFunc(); + return &f.blocks.items[self.current_block.?.index()]; + } +}; diff --git a/src/ir/print.test.zig b/src/ir/print.test.zig new file mode 100644 index 0000000..3bdbb49 --- /dev/null +++ b/src/ir/print.test.zig @@ -0,0 +1,89 @@ +// Tests for print.zig +const std = @import("std"); +const types = @import("types.zig"); +const inst_mod = 
@import("inst.zig"); +const mod_mod = @import("module.zig"); +const print_mod = @import("print.zig"); + +const TypeId = types.TypeId; +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const Function = inst_mod.Function; +const Module = mod_mod.Module; +const Builder = mod_mod.Builder; + +test "print simple add function" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + + var b = Builder.init(&module); + + const name_add = module.types.internString("add"); + const name_a = module.types.internString("a"); + const name_b = module.types.internString("b"); + const name_entry = module.types.internString("entry"); + + const params = &[_]Function.Param{ + .{ .name = name_a, .ty = .s64 }, + .{ .name = name_b, .ty = .s64 }, + }; + _ = b.beginFunction(name_add, params, .s64); + const entry = b.appendBlock(name_entry, &.{}); + b.switchToBlock(entry); + + const a_ref = b.constInt(10, .s64); + const b_ref = b.constInt(20, .s64); + const sum = b.add(a_ref, b_ref, .s64); + b.ret(sum, .s64); + b.finalize(); + + var aw = std.Io.Writer.Allocating.init(alloc); + try print_mod.printModule(&module, &aw.writer); + var result = aw.writer.toArrayList(); + defer result.deinit(alloc); + + const output = result.items; + try std.testing.expect(std.mem.indexOf(u8, output, "func @add(a: s64, b: s64) -> s64") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "entry:") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "const 10 : s64") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "add %0, %1 : s64") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "ret %2") != null); +} + +test "print conditional branch" { + const alloc = std.testing.allocator; + var module = Module.init(alloc); + defer module.deinit(); + + var b = Builder.init(&module); + + _ = b.beginFunction(module.types.internString("test"), &.{}, .s32); + const entry = 
b.appendBlock(module.types.internString("entry"), &.{}); + const then_bb = b.appendBlock(module.types.internString("then"), &.{}); + const else_bb = b.appendBlock(module.types.internString("else"), &.{}); + + b.switchToBlock(entry); + const cond = b.constBool(true); + b.condBr(cond, then_bb, &.{}, else_bb, &.{}); + + b.switchToBlock(then_bb); + const v1 = b.constInt(1, .s32); + b.ret(v1, .s32); + + b.switchToBlock(else_bb); + const v2 = b.constInt(0, .s32); + b.ret(v2, .s32); + b.finalize(); + + var aw = std.Io.Writer.Allocating.init(alloc); + try print_mod.printModule(&module, &aw.writer); + var result = aw.writer.toArrayList(); + defer result.deinit(alloc); + + const output = result.items; + try std.testing.expect(std.mem.indexOf(u8, output, "cond_br %0, bb1, bb2") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "then:") != null); + try std.testing.expect(std.mem.indexOf(u8, output, "else:") != null); +} diff --git a/src/ir/print.zig b/src/ir/print.zig new file mode 100644 index 0000000..4e5356d --- /dev/null +++ b/src/ir/print.zig @@ -0,0 +1,531 @@ +const std = @import("std"); +const types = @import("types.zig"); +const inst_mod = @import("inst.zig"); +const mod_mod = @import("module.zig"); + +const TypeId = types.TypeId; +const TypeTable = types.TypeTable; +const StringId = types.StringId; +const Ref = inst_mod.Ref; +const BlockId = inst_mod.BlockId; +const FuncId = inst_mod.FuncId; +const GlobalId = inst_mod.GlobalId; +const Inst = inst_mod.Inst; +const Op = inst_mod.Op; +const Function = inst_mod.Function; +const Block = inst_mod.Block; +const Global = inst_mod.Global; +const ConstantValue = inst_mod.ConstantValue; +const Module = mod_mod.Module; + +const Writer = *std.Io.Writer; + +// ── Public API ────────────────────────────────────────────────────────── + +pub fn printModule(module: *const Module, writer: Writer) !void { + // Print globals + for (module.globals.items, 0..) 
|global, i| { + try printGlobal(&global, @intCast(i), module, writer); + } + if (module.globals.items.len > 0 and module.functions.items.len > 0) { + try writer.writeByte('\n'); + } + // Print functions + for (module.functions.items, 0..) |*func, i| { + if (i > 0) try writer.writeByte('\n'); + try printFunction(func, @intCast(i), module, writer); + } +} + +pub fn printFunction(func: *const Function, func_idx: u32, module: *const Module, writer: Writer) !void { + const tt = &module.types; + + // Signature + if (func.is_extern) try writer.writeAll("extern "); + if (func.is_comptime) try writer.writeAll("comptime "); + try writer.writeAll("func @"); + try writer.writeAll(tt.getString(func.name)); + try writer.writeByte('('); + for (func.params, 0..) |param, i| { + if (i > 0) try writer.writeAll(", "); + const pname = tt.getString(param.name); + if (pname.len > 0) { + try writer.writeAll(pname); + try writer.writeAll(": "); + } + try writeType(param.ty, tt, writer); + } + try writer.writeAll(") -> "); + try writeType(func.ret, tt, writer); + + if (func.is_extern) { + try writer.writeAll(";\n"); + return; + } + + try writer.writeAll(" {\n"); + + // Blocks + var ref_counter: u32 = 0; + _ = func_idx; + for (func.blocks.items, 0..) 
|*block, bi| { + try printBlock(block, @intCast(bi), tt, &ref_counter, writer); + } + + try writer.writeAll("}\n"); +} + +fn printGlobal(global: *const Global, _: u32, module: *const Module, writer: Writer) !void { + const tt = &module.types; + if (global.is_extern) try writer.writeAll("extern "); + if (global.is_const) try writer.writeAll("const ") else try writer.writeAll("global "); + try writer.writeAll("@"); + try writer.writeAll(tt.getString(global.name)); + try writer.writeAll(": "); + try writeType(global.ty, tt, writer); + if (global.init_val) |init| { + try writer.writeAll(" = "); + try writeConstant(init, writer); + } + if (global.comptime_func) |fid| { + try writer.print(" = #run @{d}", .{fid.index()}); + } + try writer.writeAll(";\n"); +} + +fn printBlock(block: *const Block, block_idx: u32, tt: *const TypeTable, ref_counter: *u32, writer: Writer) !void { + // Block header + try writer.writeAll(" "); + const name = tt.getString(block.name); + if (name.len > 0) { + try writer.writeAll(name); + } else { + try writer.print("bb{d}", .{block_idx}); + } + if (block.params.len > 0) { + try writer.writeByte('('); + for (block.params, 0..) 
|pty, i| { + if (i > 0) try writer.writeAll(", "); + try writeType(pty, tt, writer); + } + try writer.writeByte(')'); + } + try writer.writeAll(":\n"); + + // Instructions + for (block.insts.items) |*instruction| { + try printInst(instruction, ref_counter.*, tt, writer); + ref_counter.* += 1; + } +} + +fn printInst(instruction: *const Inst, ref_idx: u32, tt: *const TypeTable, writer: Writer) !void { + const op = instruction.op; + const ty = instruction.ty; + + // Check if this is a void/terminator instruction (no result) + const has_result = !isVoidOp(op); + + try writer.writeAll(" "); + if (has_result) { + try writer.print("%{d} = ", .{ref_idx}); + } + + switch (op) { + // ── Constants ─────────────────────────────────────────── + .const_int => |v| try writer.print("const {d} : ", .{v}), + .const_float => |v| try writer.print("const {d:.6} : ", .{v}), + .const_bool => |v| try writer.print("const {s} : ", .{if (v) "true" else "false"}), + .const_string => |sid| { + try writer.writeAll("const \""); + try writer.writeAll(tt.getString(sid)); + try writer.writeAll("\" : "); + }, + .const_null => try writer.writeAll("const null : "), + .const_undef => try writer.writeAll("const undef : "), + + // ── Arithmetic ────────────────────────────────────────── + .add => |b| try writer.print("add %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .sub => |b| try writer.print("sub %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .mul => |b| try writer.print("mul %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .div => |b| try writer.print("div %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .mod => |b| try writer.print("mod %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .neg => |u| try writer.print("neg %{d} : ", .{u.operand.index()}), + + // ── Bitwise ───────────────────────────────────────────── + .bit_and => |b| try writer.print("bit_and %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .bit_or => |b| try writer.print("bit_or %{d}, %{d} : ", .{ 
b.lhs.index(), b.rhs.index() }), + .bit_xor => |b| try writer.print("bit_xor %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .bit_not => |u| try writer.print("bit_not %{d} : ", .{u.operand.index()}), + .shl => |b| try writer.print("shl %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .shr => |b| try writer.print("shr %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + + // ── Comparison ────────────────────────────────────────── + .cmp_eq => |b| try writer.print("cmp_eq %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .cmp_ne => |b| try writer.print("cmp_ne %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .cmp_lt => |b| try writer.print("cmp_lt %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .cmp_le => |b| try writer.print("cmp_le %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .cmp_gt => |b| try writer.print("cmp_gt %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .cmp_ge => |b| try writer.print("cmp_ge %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + + // ── Logical ───────────────────────────────────────────── + .bool_and => |b| try writer.print("bool_and %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .bool_or => |b| try writer.print("bool_or %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + .bool_not => |u| try writer.print("bool_not %{d} : ", .{u.operand.index()}), + + // ── Conversions ───────────────────────────────────────── + .widen => |c| { + try writer.print("widen %{d} : ", .{c.operand.index()}); + try writeType(c.from, tt, writer); + try writer.writeAll(" -> "); + try writeType(c.to, tt, writer); + try writer.writeByte('\n'); + return; + }, + .narrow => |c| { + try writer.print("narrow %{d} : ", .{c.operand.index()}); + try writeType(c.from, tt, writer); + try writer.writeAll(" -> "); + try writeType(c.to, tt, writer); + try writer.writeByte('\n'); + return; + }, + .bitcast => |c| { + try writer.print("bitcast %{d} : ", .{c.operand.index()}); + try writeType(c.from, tt, writer); + try 
writer.writeAll(" -> "); + try writeType(c.to, tt, writer); + try writer.writeByte('\n'); + return; + }, + .int_to_float => |c| { + try writer.print("int_to_float %{d} : ", .{c.operand.index()}); + try writeType(c.from, tt, writer); + try writer.writeAll(" -> "); + try writeType(c.to, tt, writer); + try writer.writeByte('\n'); + return; + }, + .float_to_int => |c| { + try writer.print("float_to_int %{d} : ", .{c.operand.index()}); + try writeType(c.from, tt, writer); + try writer.writeAll(" -> "); + try writeType(c.to, tt, writer); + try writer.writeByte('\n'); + return; + }, + + // ── Memory ────────────────────────────────────────────── + .alloca => |aty| { + try writer.writeAll("alloca "); + try writeType(aty, tt, writer); + try writer.writeAll(" : "); + }, + .load => |u| try writer.print("load %{d} : ", .{u.operand.index()}), + .store => |s| { + try writer.print("store %{d}, %{d}\n", .{ s.ptr.index(), s.val.index() }); + return; + }, + .heap_alloc => |u| try writer.print("heap_alloc %{d} : ", .{u.operand.index()}), + .heap_free => |u| { + try writer.print("heap_free %{d}\n", .{u.operand.index()}); + return; + }, + + // ── Struct ops ────────────────────────────────────────── + .struct_init => |agg| { + try writer.writeAll("struct_init ["); + for (agg.fields, 0..) 
|f, i| { + if (i > 0) try writer.writeAll(", "); + try writer.print("%{d}", .{f.index()}); + } + try writer.writeAll("] : "); + }, + .struct_get => |fa| try writer.print("struct_get %{d}, {d} : ", .{ fa.base.index(), fa.field_index }), + .struct_gep => |fa| try writer.print("struct_gep %{d}, {d} : ", .{ fa.base.index(), fa.field_index }), + + // ── Enum ops ──────────────────────────────────────────── + .enum_init => |ei| { + if (ei.payload.isNone()) { + try writer.print("enum_init tag={d} : ", .{ei.tag}); + } else { + try writer.print("enum_init tag={d}, payload=%{d} : ", .{ ei.tag, ei.payload.index() }); + } + }, + .enum_tag => |u| try writer.print("enum_tag %{d} : ", .{u.operand.index()}), + .enum_payload => |fa| try writer.print("enum_payload %{d}, {d} : ", .{ fa.base.index(), fa.field_index }), + + // ── Union ops ─────────────────────────────────────────── + .union_get => |fa| try writer.print("union_get %{d}, {d} : ", .{ fa.base.index(), fa.field_index }), + .union_gep => |fa| try writer.print("union_gep %{d}, {d} : ", .{ fa.base.index(), fa.field_index }), + + // ── Array/Slice ops ───────────────────────────────────── + .index_get => |b| try writer.print("index_get %{d}[%{d}] : ", .{ b.lhs.index(), b.rhs.index() }), + .index_gep => |b| try writer.print("index_gep %{d}[%{d}] : ", .{ b.lhs.index(), b.rhs.index() }), + .length => |u| try writer.print("length %{d} : ", .{u.operand.index()}), + .data_ptr => |u| try writer.print("data_ptr %{d} : ", .{u.operand.index()}), + .subslice => |s| try writer.print("subslice %{d}[%{d}..%{d}] : ", .{ s.base.index(), s.lo.index(), s.hi.index() }), + .array_to_slice => |u| try writer.print("array_to_slice %{d} : ", .{u.operand.index()}), + + // ── Tuple ops ─────────────────────────────────────────── + .tuple_init => |agg| { + try writer.writeAll("tuple_init ["); + for (agg.fields, 0..) 
|f, i| { + if (i > 0) try writer.writeAll(", "); + try writer.print("%{d}", .{f.index()}); + } + try writer.writeAll("] : "); + }, + .tuple_get => |fa| try writer.print("tuple_get %{d}, {d} : ", .{ fa.base.index(), fa.field_index }), + + // ── Optional ops ──────────────────────────────────────── + .optional_wrap => |u| try writer.print("optional_wrap %{d} : ", .{u.operand.index()}), + .optional_unwrap => |u| try writer.print("optional_unwrap %{d} : ", .{u.operand.index()}), + .optional_has_value => |u| try writer.print("optional_has_value %{d} : ", .{u.operand.index()}), + .optional_coalesce => |b| try writer.print("optional_coalesce %{d}, %{d} : ", .{ b.lhs.index(), b.rhs.index() }), + + // ── Pointer ops ───────────────────────────────────────── + .addr_of => |u| try writer.print("addr_of %{d} : ", .{u.operand.index()}), + .deref => |u| try writer.print("deref %{d} : ", .{u.operand.index()}), + + // ── Vector ops ────────────────────────────────────────── + .vec_splat => |u| try writer.print("vec_splat %{d} : ", .{u.operand.index()}), + .vec_extract => |b| try writer.print("vec_extract %{d}[%{d}] : ", .{ b.lhs.index(), b.rhs.index() }), + .vec_insert => |t| try writer.print("vec_insert %{d}[%{d}] = %{d} : ", .{ t.a.index(), t.b.index(), t.c.index() }), + + // ── Calls ─────────────────────────────────────────────── + .call => |c| { + try writer.print("call @{d}(", .{c.callee.index()}); + try writeArgs(c.args, writer); + try writer.writeAll(") : "); + }, + .call_indirect => |c| { + try writer.print("call_indirect %{d}(", .{c.callee.index()}); + try writeArgs(c.args, writer); + try writer.writeAll(") : "); + }, + .call_closure => |c| { + try writer.print("call_closure %{d}(", .{c.callee.index()}); + try writeArgs(c.args, writer); + try writer.writeAll(") : "); + }, + .call_builtin => |c| { + try writer.print("call_builtin {s}(", .{@tagName(c.builtin)}); + try writeArgs(c.args, writer); + try writer.writeAll(") : "); + }, + + // ── Protocol 
──────────────────────────────────────────── + .protocol_call_dynamic => |c| { + try writer.print("protocol_call_dynamic %{d}.{d}(", .{ c.receiver.index(), c.method_index }); + try writeArgs(c.args, writer); + try writer.writeAll(") : "); + }, + .protocol_erase => |pe| { + try writer.print("protocol_erase %{d} -> ", .{pe.concrete.index()}); + try writeType(pe.protocol_type, tt, writer); + try writer.writeByte('\n'); + return; + }, + + // ── Closure ───────────────────────────────────────────── + .closure_create => |cc| { + try writer.print("closure_create @{d}", .{cc.func.index()}); + if (!cc.env.isNone()) { + try writer.print(", env=%{d}", .{cc.env.index()}); + } + try writer.writeAll(" : "); + }, + + // ── Context ───────────────────────────────────────────── + .context_load => |co| { + try writer.writeAll("context_load ."); + try writer.writeAll(tt.getString(co.field)); + try writer.writeAll(" : "); + }, + .context_store => |co| { + try writer.writeAll("context_store ."); + try writer.writeAll(tt.getString(co.field)); + try writer.print(", %{d}\n", .{co.value.index()}); + return; + }, + .context_save => { + try writer.writeAll("context_save : "); + }, + .context_restore => |u| { + try writer.print("context_restore %{d}\n", .{u.operand.index()}); + return; + }, + + // ── Globals ───────────────────────────────────────────── + .global_get => |gid| try writer.print("global_get @{d} : ", .{gid.index()}), + .global_set => |gs| { + try writer.print("global_set @{d}, %{d}\n", .{ gs.global.index(), gs.value.index() }); + return; + }, + + // ── Block params ──────────────────────────────────────── + .block_param => |bp| try writer.print("block_param bb{d}[{d}] : ", .{ bp.block.index(), bp.param_index }), + + // ── Any ───────────────────────────────────────────────── + .box_any => |ba| try writer.print("box_any %{d} : ", .{ba.operand.index()}), + .unbox_any => |u| try writer.print("unbox_any %{d} : ", .{u.operand.index()}), + + // ── Terminators 
───────────────────────────────────────── + .br => |b| { + try writer.print("br bb{d}", .{b.target.index()}); + if (b.args.len > 0) { + try writer.writeByte('('); + try writeArgs(b.args, writer); + try writer.writeByte(')'); + } + try writer.writeByte('\n'); + return; + }, + .cond_br => |cb| { + try writer.print("cond_br %{d}, bb{d}", .{ cb.cond.index(), cb.then_target.index() }); + if (cb.then_args.len > 0) { + try writer.writeByte('('); + try writeArgs(cb.then_args, writer); + try writer.writeByte(')'); + } + try writer.print(", bb{d}", .{cb.else_target.index()}); + if (cb.else_args.len > 0) { + try writer.writeByte('('); + try writeArgs(cb.else_args, writer); + try writer.writeByte(')'); + } + try writer.writeByte('\n'); + return; + }, + .switch_br => |sb| { + try writer.print("switch_br %{d} [", .{sb.operand.index()}); + for (sb.cases, 0..) |case, i| { + if (i > 0) try writer.writeAll(", "); + try writer.print("{d} -> bb{d}", .{ case.value, case.target.index() }); + } + try writer.print("] default bb{d}\n", .{sb.default.index()}); + return; + }, + .ret => |u| { + try writer.print("ret %{d}\n", .{u.operand.index()}); + return; + }, + .ret_void => { + try writer.writeAll("ret void\n"); + return; + }, + .@"unreachable" => { + try writer.writeAll("unreachable\n"); + return; + }, + + // ── Misc ──────────────────────────────────────────────── + .placeholder => |sid| { + try writer.writeAll("placeholder \""); + try writer.writeAll(tt.getString(sid)); + try writer.writeAll("\" : "); + }, + } + + // Default: print the result type + try writeType(ty, tt, writer); + try writer.writeByte('\n'); +} + +// ── Helpers ───────────────────────────────────────────────────────────── + +fn writeType(id: TypeId, tt: *const TypeTable, writer: Writer) !void { + // Fast path for builtins + if (id.isBuiltin()) { + try writer.writeAll(tt.typeName(id)); + return; + } + // Composite types — format recursively + const info = tt.get(id); + switch (info) { + .@"struct" => |s| try 
writer.writeAll(tt.getString(s.name)), + .@"enum" => |e| try writer.writeAll(tt.getString(e.name)), + .@"union" => |u| try writer.writeAll(tt.getString(u.name)), + .protocol => |p| try writer.writeAll(tt.getString(p.name)), + .pointer => |p| { + try writer.writeByte('*'); + try writeType(p.pointee, tt, writer); + }, + .many_pointer => |p| { + try writer.writeAll("[*]"); + try writeType(p.element, tt, writer); + }, + .slice => |s| { + try writer.writeAll("[]"); + try writeType(s.element, tt, writer); + }, + .array => |a| { + try writer.print("[{d}]", .{a.length}); + try writeType(a.element, tt, writer); + }, + .optional => |o| { + try writer.writeByte('?'); + try writeType(o.child, tt, writer); + }, + .vector => |v| { + try writer.print("Vector({d}, ", .{v.length}); + try writeType(v.element, tt, writer); + try writer.writeByte(')'); + }, + .function => |f| { + try writer.writeByte('('); + for (f.params, 0..) |p, i| { + if (i > 0) try writer.writeAll(", "); + try writeType(p, tt, writer); + } + try writer.writeAll(") -> "); + try writeType(f.ret, tt, writer); + }, + .closure => |c| { + try writer.writeAll("closure("); + for (c.params, 0..) |p, i| { + if (i > 0) try writer.writeAll(", "); + try writeType(p, tt, writer); + } + try writer.writeAll(") -> "); + try writeType(c.ret, tt, writer); + }, + .tuple => |t| { + try writer.writeByte('('); + for (t.fields, 0..) |f, i| { + if (i > 0) try writer.writeAll(", "); + try writeType(f, tt, writer); + } + try writer.writeByte(')'); + }, + else => try writer.writeAll(tt.typeName(id)), + } +} + +fn writeArgs(args: []const Ref, writer: Writer) !void { + for (args, 0..) 
|arg, i| { + if (i > 0) try writer.writeAll(", "); + try writer.print("%{d}", .{arg.index()}); + } +} + +fn writeConstant(val: ConstantValue, writer: Writer) !void { + switch (val) { + .int => |v| try writer.print("{d}", .{v}), + .float => |v| try writer.print("{d:.6}", .{v}), + .boolean => |v| try writer.writeAll(if (v) "true" else "false"), + .string => try writer.writeAll("\"...\""), + .null_val => try writer.writeAll("null"), + .undef => try writer.writeAll("undef"), + .zeroinit => try writer.writeAll("zeroinit"), + .aggregate => try writer.writeAll("{...}"), + } +} + +fn isVoidOp(op: Op) bool { + return switch (op) { + .store, .heap_free, .context_store, .context_restore, .global_set, .br, .cond_br, .switch_br, .ret, .ret_void, .@"unreachable" => true, + else => false, + }; +} diff --git a/src/ir/type_bridge.test.zig b/src/ir/type_bridge.test.zig new file mode 100644 index 0000000..3ceb260 --- /dev/null +++ b/src/ir/type_bridge.test.zig @@ -0,0 +1,111 @@ +// Tests for type_bridge.zig +const std = @import("std"); +const types = @import("types.zig"); +const type_bridge = @import("type_bridge.zig"); +const ast = @import("../ast.zig"); +const Node = ast.Node; + +const TypeId = types.TypeId; +const TypeInfo = types.TypeInfo; +const TypeTable = types.TypeTable; + +test "bridgeType: primitives" { + const alloc = std.testing.allocator; + var table = TypeTable.init(alloc); + defer table.deinit(); + + try std.testing.expectEqual(TypeId.s32, type_bridge.bridgeType(.{ .signed = 32 }, &table)); + try std.testing.expectEqual(TypeId.u8, type_bridge.bridgeType(.{ .unsigned = 8 }, &table)); + try std.testing.expectEqual(TypeId.f64, type_bridge.bridgeType(.f64, &table)); + try std.testing.expectEqual(TypeId.void, type_bridge.bridgeType(.void_type, &table)); + try std.testing.expectEqual(TypeId.bool, type_bridge.bridgeType(.boolean, &table)); + try std.testing.expectEqual(TypeId.string, type_bridge.bridgeType(.string_type, &table)); + try std.testing.expectEqual(TypeId.any, 
type_bridge.bridgeType(.any_type, &table)); +} + +test "bridgeType: composite types" { + const alloc = std.testing.allocator; + var table = TypeTable.init(alloc); + defer table.deinit(); + + // Pointer + const ptr_id = type_bridge.bridgeType(.{ .pointer_type = .{ .pointee_name = "s32" } }, &table); + try std.testing.expectEqual(TypeInfo{ .pointer = .{ .pointee = .s32 } }, table.get(ptr_id)); + + // Slice + const slice_id = type_bridge.bridgeType(.{ .slice_type = .{ .element_name = "u8" } }, &table); + try std.testing.expectEqual(TypeInfo{ .slice = .{ .element = .u8 } }, table.get(slice_id)); + + // Array + const arr_id = type_bridge.bridgeType(.{ .array_type = .{ .element_name = "f32", .length = 4 } }, &table); + try std.testing.expectEqual(TypeInfo{ .array = .{ .element = .f32, .length = 4 } }, table.get(arr_id)); + + // Optional + const opt_id = type_bridge.bridgeType(.{ .optional_type = .{ .child_name = "s64" } }, &table); + try std.testing.expectEqual(TypeInfo{ .optional = .{ .child = .s64 } }, table.get(opt_id)); +} + +test "resolveAstType: primitive type_expr" { + const alloc = std.testing.allocator; + var table = TypeTable.init(alloc); + defer table.deinit(); + + const node = try alloc.create(Node); + defer alloc.destroy(node); + node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "f64" } } }; + + try std.testing.expectEqual(TypeId.f64, type_bridge.resolveAstType(node, &table)); +} + +test "resolveAstType: pointer type" { + const alloc = std.testing.allocator; + var table = TypeTable.init(alloc); + defer table.deinit(); + + const inner = try alloc.create(Node); + defer alloc.destroy(inner); + inner.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "s32" } } }; + + const node = try alloc.create(Node); + defer alloc.destroy(node); + node.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .pointer_type_expr = .{ .pointee_type = inner } } }; + + const id = type_bridge.resolveAstType(node, &table); + try 
std.testing.expectEqual(TypeInfo{ .pointer = .{ .pointee = .s32 } }, table.get(id)); +} + +test "resolveAstType: optional slice" { + const alloc = std.testing.allocator; + var table = TypeTable.init(alloc); + defer table.deinit(); + + const elem = try alloc.create(Node); + defer alloc.destroy(elem); + elem.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .type_expr = .{ .name = "u8" } } }; + + const slice = try alloc.create(Node); + defer alloc.destroy(slice); + slice.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .slice_type_expr = .{ .element_type = elem } } }; + + const opt = try alloc.create(Node); + defer alloc.destroy(opt); + opt.* = .{ .span = .{ .start = 0, .end = 0 }, .data = .{ .optional_type_expr = .{ .inner_type = slice } } }; + + const id = type_bridge.resolveAstType(opt, &table); + const info = table.get(id); + switch (info) { + .optional => |o| { + const child_info = table.get(o.child); + try std.testing.expectEqual(TypeInfo{ .slice = .{ .element = .u8 } }, child_info); + }, + else => return error.TestUnexpectedResult, + } +} + +test "resolveAstType: null returns default" { + const alloc = std.testing.allocator; + var table = TypeTable.init(alloc); + defer table.deinit(); + + try std.testing.expectEqual(TypeId.s64, type_bridge.resolveAstType(null, &table)); +} diff --git a/src/ir/type_bridge.zig b/src/ir/type_bridge.zig new file mode 100644 index 0000000..4d8cac1 --- /dev/null +++ b/src/ir/type_bridge.zig @@ -0,0 +1,288 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const ast = @import("../ast.zig"); +const Node = ast.Node; +const sx_types = @import("../types.zig"); +const ir_types = @import("types.zig"); +const TypeId = ir_types.TypeId; +const TypeInfo = ir_types.TypeInfo; +const TypeTable = ir_types.TypeTable; +const StringId = ir_types.StringId; + +// ── AST Node → TypeId ─────────────────────────────────────────────────── +// Resolve an AST type node into an IR TypeId. 
Used during lowering when +// we only have the parsed AST (no codegen type registry). + +pub fn resolveAstType(node: ?*const Node, table: *TypeTable) TypeId { + const n = node orelse return .s64; // no annotation → default + return switch (n.data) { + .type_expr => |te| resolveTypeName(te.name, table), + .array_type_expr => |at| resolveArrayType(&at, table), + .slice_type_expr => |st| resolveSliceType(&st, table), + .pointer_type_expr => |pt| resolvePointerType(&pt, table), + .many_pointer_type_expr => |mpt| resolveManyPointerType(&mpt, table), + .optional_type_expr => |ot| resolveOptionalType(&ot, table), + .function_type_expr => |ft| resolveFunctionType(&ft, table), + .closure_type_expr => |ct| resolveClosureType(&ct, table), + .tuple_type_expr => |tt| resolveTupleType(&tt, table), + .parameterized_type_expr => |pt| resolveParameterizedType(&pt, table), + .inferred_type => .s64, // inferred — default until we have type inference + else => .s64, // fallback for unknown nodes + }; +} + +// ── types.Type → TypeId ───────────────────────────────────────────────── +// Translate an existing codegen Type value into an IR TypeId. Used when +// we have access to the codegen's resolved type info (Phase 3+). 
+ +pub fn bridgeType(ty: sx_types.Type, table: *TypeTable) TypeId { + return switch (ty) { + .signed => |w| switch (w) { + 8 => .s8, + 16 => .s16, + 32 => .s32, + 64 => .s64, + else => .s64, + }, + .unsigned => |w| switch (w) { + 8 => .u8, + 16 => .u16, + 32 => .u32, + 64 => .u64, + else => .u64, + }, + .f32 => .f32, + .f64 => .f64, + .void_type => .void, + .boolean => .bool, + .string_type => .string, + .any_type => .any, + .enum_type => |name| resolveNamedType(name, .@"enum", table), + .struct_type => |name| resolveNamedType(name, .@"struct", table), + .union_type => |name| resolveNamedType(name, .@"union", table), + .array_type => |info| blk: { + const elem = resolveTypeName(info.element_name, table); + break :blk table.arrayOf(elem, info.length); + }, + .slice_type => |info| blk: { + const elem = resolveTypeName(info.element_name, table); + break :blk table.sliceOf(elem); + }, + .pointer_type => |info| blk: { + const pointee = resolveTypeName(info.pointee_name, table); + break :blk table.ptrTo(pointee); + }, + .many_pointer_type => |info| blk: { + const elem = resolveTypeName(info.element_name, table); + break :blk table.manyPtrTo(elem); + }, + .optional_type => |info| blk: { + const child = resolveTypeName(info.child_name, table); + break :blk table.optionalOf(child); + }, + .vector_type => |info| blk: { + const elem = resolveTypeName(info.element_name, table); + break :blk table.vectorOf(elem, info.length); + }, + .function_type => |info| blk: { + const alloc = table.alloc; + var param_ids = std.ArrayList(TypeId).empty; + for (info.param_types) |pt| { + param_ids.append(alloc, bridgeType(pt, table)) catch unreachable; + } + const ret_id = bridgeType(info.return_type.*, table); + break :blk table.functionType(param_ids.items, ret_id); + }, + .closure_type => |info| blk: { + const alloc = table.alloc; + var param_ids = std.ArrayList(TypeId).empty; + for (info.param_types) |pt| { + param_ids.append(alloc, bridgeType(pt, table)) catch unreachable; + } + const 
ret_id = bridgeType(info.return_type.*, table); + break :blk table.closureType(param_ids.items, ret_id); + }, + .tuple_type => |info| blk: { + const alloc = table.alloc; + var field_ids = std.ArrayList(TypeId).empty; + for (info.field_types) |ft| { + field_ids.append(alloc, bridgeType(ft, table)) catch unreachable; + } + var name_ids: ?[]const StringId = null; + if (info.field_names) |names| { + var ids = std.ArrayList(StringId).empty; + for (names) |n| { + ids.append(alloc, table.internString(n)) catch unreachable; + } + name_ids = ids.items; + } + break :blk table.intern(.{ .tuple = .{ + .fields = field_ids.items, + .names = name_ids, + } }); + }, + .meta_type => .any, // meta types map to Any for now + }; +} + +// ── Internal helpers ───────────────────────────────────────────────────── + +const NamedKind = enum { @"struct", @"enum", @"union" }; + +fn resolveNamedType(name: []const u8, kind: NamedKind, table: *TypeTable) TypeId { + // Check if primitive first + if (resolveTypePrimitive(name)) |id| return id; + + // Register as a named type + const name_id = table.internString(name); + return switch (kind) { + .@"struct" => table.intern(.{ .@"struct" = .{ .name = name_id, .fields = &.{} } }), + .@"enum" => table.intern(.{ .@"enum" = .{ .name = name_id, .variants = &.{} } }), + .@"union" => table.intern(.{ .@"union" = .{ .name = name_id, .fields = &.{}, .tag_type = null } }), + }; +} + +fn resolveTypeName(name: []const u8, table: *TypeTable) TypeId { + // Try primitive first + if (resolveTypePrimitive(name)) |id| return id; + + // Sentinel-terminated slice: [:0]u8 → string + if (name.len >= 5 and name[0] == '[' and name[1] == ':') { + if (std.mem.indexOfScalar(u8, name, ']')) |close| { + const sentinel = name[2..close]; + const elem = name[close + 1 ..]; + if (std.mem.eql(u8, sentinel, "0") and std.mem.eql(u8, elem, "u8")) { + return .string; + } + } + } + + // Many-pointer: [*]T + if (name.len >= 4 and name[0] == '[' and name[1] == '*' and name[2] == ']') { + 
const elem = resolveTypeName(name[3..], table); + return table.manyPtrTo(elem); + } + + // Pointer: *T + if (name.len >= 2 and name[0] == '*') { + const pointee = resolveTypeName(name[1..], table); + return table.ptrTo(pointee); + } + + // Optional: ?T + if (name.len >= 2 and name[0] == '?') { + const child = resolveTypeName(name[1..], table); + return table.optionalOf(child); + } + + // Assume it's a named struct/enum/union type + const name_id = table.internString(name); + return table.intern(.{ .@"struct" = .{ .name = name_id, .fields = &.{} } }); +} + +fn resolveTypePrimitive(name: []const u8) ?TypeId { + if (name.len == 0) return null; + // Fast path for common types + if (std.mem.eql(u8, name, "s64")) return .s64; + if (std.mem.eql(u8, name, "s32")) return .s32; + if (std.mem.eql(u8, name, "s16")) return .s16; + if (std.mem.eql(u8, name, "s8")) return .s8; + if (std.mem.eql(u8, name, "u64")) return .u64; + if (std.mem.eql(u8, name, "u32")) return .u32; + if (std.mem.eql(u8, name, "u16")) return .u16; + if (std.mem.eql(u8, name, "u8")) return .u8; + if (std.mem.eql(u8, name, "f32")) return .f32; + if (std.mem.eql(u8, name, "f64")) return .f64; + if (std.mem.eql(u8, name, "bool")) return .bool; + if (std.mem.eql(u8, name, "string")) return .string; + if (std.mem.eql(u8, name, "void")) return .void; + if (std.mem.eql(u8, name, "Any")) return .any; + if (std.mem.eql(u8, name, "noreturn")) return .noreturn; + return null; +} + +fn resolveArrayType(at: *const ast.ArrayTypeExpr, table: *TypeTable) TypeId { + const elem = resolveAstType(at.element_type, table); + const length: u32 = switch (at.length.data) { + .int_literal => |lit| @intCast(@as(u64, @bitCast(lit.value))), + else => 0, + }; + return table.arrayOf(elem, length); +} + +fn resolveSliceType(st: *const ast.SliceTypeExpr, table: *TypeTable) TypeId { + const elem = resolveAstType(st.element_type, table); + return table.sliceOf(elem); +} + +fn resolvePointerType(pt: *const ast.PointerTypeExpr, table: 
*TypeTable) TypeId { + const pointee = resolveAstType(pt.pointee_type, table); + return table.ptrTo(pointee); +} + +fn resolveManyPointerType(mpt: *const ast.ManyPointerTypeExpr, table: *TypeTable) TypeId { + const elem = resolveAstType(mpt.element_type, table); + return table.manyPtrTo(elem); +} + +fn resolveOptionalType(ot: *const ast.OptionalTypeExpr, table: *TypeTable) TypeId { + const child = resolveAstType(ot.inner_type, table); + return table.optionalOf(child); +} + +fn resolveFunctionType(ft: *const ast.FunctionTypeExpr, table: *TypeTable) TypeId { + const alloc = table.alloc; + var param_ids = std.ArrayList(TypeId).empty; + for (ft.param_types) |pt| { + param_ids.append(alloc, resolveAstType(pt, table)) catch unreachable; + } + const ret_id = if (ft.return_type) |rt| resolveAstType(rt, table) else TypeId.void; + return table.functionType(param_ids.items, ret_id); +} + +fn resolveClosureType(ct: *const ast.ClosureTypeExpr, table: *TypeTable) TypeId { + const alloc = table.alloc; + var param_ids = std.ArrayList(TypeId).empty; + for (ct.param_types) |pt| { + param_ids.append(alloc, resolveAstType(pt, table)) catch unreachable; + } + const ret_id = if (ct.return_type) |rt| resolveAstType(rt, table) else TypeId.void; + return table.closureType(param_ids.items, ret_id); +} + +fn resolveTupleType(tt: *const ast.TupleTypeExpr, table: *TypeTable) TypeId { + const alloc = table.alloc; + var field_ids = std.ArrayList(TypeId).empty; + for (tt.field_types) |ft| { + field_ids.append(alloc, resolveAstType(ft, table)) catch unreachable; + } + var name_ids: ?[]const StringId = null; + if (tt.field_names) |names| { + var ids = std.ArrayList(StringId).empty; + for (names) |n| { + ids.append(alloc, table.internString(n)) catch unreachable; + } + name_ids = ids.items; + } + return table.intern(.{ .tuple = .{ + .fields = field_ids.items, + .names = name_ids, + } }); +} + +fn resolveParameterizedType(pt: *const ast.ParameterizedTypeExpr, table: *TypeTable) TypeId { + // 
Vector(N, T) is a built-in parameterized type
+    if (std.mem.eql(u8, pt.name, "Vector")) {
+        if (pt.args.len == 2) {
+            const length: u32 = switch (pt.args[0].data) {
+                .int_literal => |lit| @intCast(@as(u64, @bitCast(lit.value))),
+                else => 0,
+            };
+            const elem = resolveAstType(pt.args[1], table);
+            return table.vectorOf(elem, length);
+        }
+    }
+    // Generic struct instantiation — register as named type
+    const name_id = table.internString(pt.name);
+    return table.intern(.{ .@"struct" = .{ .name = name_id, .fields = &.{} } });
+}
diff --git a/src/ir/types.test.zig b/src/ir/types.test.zig
new file mode 100644
index 0000000..91aa065
--- /dev/null
+++ b/src/ir/types.test.zig
@@ -0,0 +1,121 @@
+// Tests for types.zig
+// Every test builds its own TypeTable on `std.testing.allocator` and releases it
+// with `deinit`.
+const std = @import("std");
+const types = @import("types.zig");
+const TypeId = types.TypeId;
+const TypeTable = types.TypeTable;
+const TypeInfo = types.TypeInfo;
+
+// A freshly initialised table already answers `get` for the builtin TypeIds.
+test "builtin types pre-populated" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    // Verify builtin slots
+    try std.testing.expectEqual(TypeInfo.void, table.get(.void));
+    try std.testing.expectEqual(TypeInfo.bool, table.get(.bool));
+    try std.testing.expectEqual(TypeInfo{ .signed = 32 }, table.get(.s32));
+    try std.testing.expectEqual(TypeInfo{ .unsigned = 8 }, table.get(.u8));
+    try std.testing.expectEqual(TypeInfo.f64, table.get(.f64));
+    try std.testing.expectEqual(TypeInfo.string, table.get(.string));
+    try std.testing.expectEqual(TypeInfo.any, table.get(.any));
+}
+
+// Requesting the same pointer type twice yields one TypeId; a different pointee
+// yields a different one.
+test "intern deduplicates structural types" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    const ptr1 = table.ptrTo(.s32);
+    const ptr2 = table.ptrTo(.s32);
+    try std.testing.expectEqual(ptr1, ptr2);
+
+    const ptr3 = table.ptrTo(.f64);
+    try std.testing.expect(ptr1 != ptr3);
+}
+
+// Slices dedup by element type; arrays dedup by element type and length.
+test "slice and array interning" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    const slice1 = table.sliceOf(.s32);
+    const slice2 = table.sliceOf(.s32);
+    try std.testing.expectEqual(slice1, slice2);
+
+    const arr1 = table.arrayOf(.u8, 10);
+    const arr2 = table.arrayOf(.u8, 10);
+    const arr3 = table.arrayOf(.u8, 20);
+    try std.testing.expectEqual(arr1, arr2);
+    try std.testing.expect(arr1 != arr3);
+}
+
+// Optionals dedup by child type.
+test "optional interning" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    const opt1 = table.optionalOf(.s32);
+    const opt2 = table.optionalOf(.s32);
+    try std.testing.expectEqual(opt1, opt2);
+
+    const opt3 = table.optionalOf(.f64);
+    try std.testing.expect(opt1 != opt3);
+}
+
+// Function types with equal params and return dedup; a different return type
+// produces a distinct TypeId.
+test "function type interning" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    const params = &[_]TypeId{ .s32, .s32 };
+    const fn1 = table.functionType(params, .s64);
+    const fn2 = table.functionType(params, .s64);
+    try std.testing.expectEqual(fn1, fn2);
+
+    const fn3 = table.functionType(params, .f64);
+    try std.testing.expect(fn1 != fn3);
+}
+
+// Equal strings share a StringId, and `getString` returns the interned text.
+test "string pool interning" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    const id1 = table.internString("Point");
+    const id2 = table.internString("Point");
+    const id3 = table.internString("Rect");
+
+    try std.testing.expectEqual(id1, id2);
+    try std.testing.expect(id1 != id3);
+    try std.testing.expectEqualStrings("Point", table.getString(id1));
+    try std.testing.expectEqualStrings("Rect", table.getString(id3));
+}
+
+// Byte sizes reported for builtins, a pointer and a slice.
+test "sizeOf builtins" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    try std.testing.expectEqual(@as(u32, 0), table.sizeOf(.void));
+    try std.testing.expectEqual(@as(u32, 1), table.sizeOf(.bool));
+    try std.testing.expectEqual(@as(u32, 4), table.sizeOf(.s32));
+    try std.testing.expectEqual(@as(u32, 8), table.sizeOf(.s64));
+    try std.testing.expectEqual(@as(u32, 1), table.sizeOf(.u8));
+    try std.testing.expectEqual(@as(u32, 4), table.sizeOf(.f32));
+    try std.testing.expectEqual(@as(u32, 8), table.sizeOf(.f64));
+    try std.testing.expectEqual(@as(u32, 16), table.sizeOf(.string));
+    try std.testing.expectEqual(@as(u32, 8), table.sizeOf(table.ptrTo(.s32)));
+    try std.testing.expectEqual(@as(u32, 16), table.sizeOf(table.sliceOf(.s32)));
+}
+
+// Display names returned for builtin TypeIds (note `.any` prints as "Any").
+test "typeName for builtins" {
+    const alloc = std.testing.allocator;
+    var table = TypeTable.init(alloc);
+    defer table.deinit();
+
+    try std.testing.expectEqualStrings("s32", table.typeName(.s32));
+    try std.testing.expectEqualStrings("bool", table.typeName(.bool));
+    try std.testing.expectEqualStrings("string", table.typeName(.string));
+    try std.testing.expectEqualStrings("void", table.typeName(.void));
+    try std.testing.expectEqualStrings("Any", table.typeName(.any));
+}
diff --git a/src/ir/types.zig b/src/ir/types.zig
new file mode 100644
index 0000000..54850fe
--- /dev/null
+++ b/src/ir/types.zig
@@ -0,0 +1,484 @@
+const std = @import("std");
+const Allocator = std.mem.Allocator;
+
+// ── TypeId ──────────────────────────────────────────────────────────────
+// Opaque handle into the TypeTable. First 16 slots are reserved for builtins.
+
+/// Index of a type inside a `TypeTable`, stored as a `u32`.
+/// Named tags cover the builtin slots; every other value is a user type.
+pub const TypeId = enum(u32) {
+    // Builtin slots 0–15
+    void = 0,
+    bool = 1,
+    s8 = 2,
+    s16 = 3,
+    s32 = 4,
+    s64 = 5,
+    u8 = 6,
+    u16 = 7,
+    u32 = 8,
+    u64 = 9,
+    f32 = 10,
+    f64 = 11,
+    string = 12, // [:0]u8
+    any = 13,
+    noreturn = 14,
+    _reserved = 15,
+    _, // user-defined types start at 16
+
+    /// First index that does not belong to a builtin slot.
+    pub const first_user: u32 = 16;
+
+    /// Return the raw integer value of this id.
+    pub fn index(self: TypeId) u32 {
+        return @intFromEnum(self);
+    }
+
+    /// Build a TypeId from a raw integer; no range check is performed.
+    pub fn fromIndex(i: u32) TypeId {
+        return @enumFromInt(i);
+    }
+
+    /// True for ids below `first_user` (this includes the `_reserved` slot).
+    pub fn isBuiltin(self: TypeId) bool {
+        return self.index() < first_user;
+    }
+};
+
+// ── TypeInfo ────────────────────────────────────────────────────────────
+// Resolved type information stored in the TypeTable.
+// Unlike the AST-level `types.Type` which uses string names for references,
+// TypeInfo uses TypeId handles, making it fully resolved and internable.
+
+/// One resolved type. Payloads refer to other types by `TypeId` and to names by
+/// `StringId`. Slice payloads (`fields`, `params`, `variants`, ...) are stored as
+/// given; this union does not copy or free them.
+///
+/// The variant order is significant: `hashTypeInfo` hashes the tag's integer value.
+pub const TypeInfo = union(enum) {
+    signed: u8, // bit width: 1–64
+    unsigned: u8, // bit width
+    f32,
+    f64,
+    void,
+    bool,
+    string, // [:0]u8 — fat pointer {ptr, len}
+
+    @"struct": StructInfo,
+    @"enum": EnumInfo,
+    @"union": UnionInfo,
+    array: ArrayInfo,
+    slice: SliceInfo,
+    pointer: PointerInfo,
+    many_pointer: ManyPointerInfo,
+    vector: VectorInfo,
+    function: FunctionInfo,
+    closure: ClosureInfo,
+    optional: OptionalInfo,
+    tuple: TupleInfo,
+    any,
+    protocol: ProtocolInfo,
+    noreturn,
+
+    /// Named struct: interned name plus its ordered fields.
+    pub const StructInfo = struct {
+        name: StringId,
+        fields: []const Field,
+
+        /// A named, typed member (also reused for union fields).
+        pub const Field = struct {
+            name: StringId,
+            ty: TypeId,
+        };
+    };
+
+    /// Named enum: interned name plus its variant names.
+    pub const EnumInfo = struct {
+        name: StringId,
+        variants: []const StringId,
+    };
+
+    /// Named union, optionally tagged by an enum type.
+    pub const UnionInfo = struct {
+        name: StringId,
+        fields: []const StructInfo.Field,
+        tag_type: ?TypeId, // tagged union enum type, null if untagged
+    };
+
+    /// Fixed-length array of `length` elements.
+    pub const ArrayInfo = struct {
+        element: TypeId,
+        length: u32,
+    };
+
+    /// Slice of `element`.
+    pub const SliceInfo = struct {
+        element: TypeId,
+    };
+
+    /// Pointer to a single `pointee`.
+    pub const PointerInfo = struct {
+        pointee: TypeId,
+    };
+
+    /// Many-item pointer to `element`.
+    pub const ManyPointerInfo = struct {
+        element: TypeId,
+    };
+
+    /// Vector of `length` lanes of `element`.
+    pub const VectorInfo = struct {
+        element: TypeId,
+        length: u32,
+    };
+
+    /// Function signature: parameter types and return type.
+    pub const FunctionInfo = struct {
+        params: []const TypeId,
+        ret: TypeId,
+    };
+
+    /// Closure signature; same shape as `FunctionInfo` but a distinct variant.
+    pub const ClosureInfo = struct {
+        params: []const TypeId,
+        ret: TypeId,
+    };
+
+    /// Optional wrapping `child`.
+    pub const OptionalInfo = struct {
+        child: TypeId,
+    };
+
+    /// Tuple of field types with optional per-field names.
+    pub const TupleInfo = struct {
+        fields: []const TypeId,
+        names: ?[]const StringId,
+    };
+
+    /// Named protocol: interned name plus its method signatures.
+    pub const ProtocolInfo = struct {
+        name: StringId,
+        methods: []const Method,
+
+        pub const Method = struct {
+            name: StringId,
+            sig: TypeId, // function type
+        };
+    };
+};
+
+// ── StringId ────────────────────────────────────────────────────────────
+
+/// Handle to a string stored in a `StringPool`. `.empty` (0) is the empty string.
+pub const StringId = enum(u32) {
+    empty = 0,
+    _,
+
+    /// Return the raw integer value of this id.
+    pub fn index(self: StringId) u32 {
+        return @intFromEnum(self);
+    }
+};
+
+// ── StringPool ──────────────────────────────────────────────────────────
+// Intern strings for type/field/variant names. Deduplicates by content.
+
+pub const StringPool = struct {
+    /// Maps string content → StringId for dedup. Keys point to owned allocations in `strings`.
+    map: std.StringHashMap(StringId),
+    /// Owned string data indexed by StringId. Each entry is separately heap-allocated.
+    strings: std.ArrayList([]const u8),
+    /// Id that the next newly interned string will receive.
+    next_id: u32,
+
+    /// Create a pool whose slot 0 holds the empty string.
+    /// Panics if the initial slot cannot be allocated.
+    pub fn init(alloc: Allocator) StringPool {
+        var pool = StringPool{
+            .map = std.StringHashMap(StringId).init(alloc),
+            .strings = std.ArrayList([]const u8).empty,
+            .next_id = 1, // 0 is reserved for empty
+        };
+        // Slot 0 = empty string (not heap-allocated)
+        pool.strings.append(alloc, "") catch @panic("StringPool: out of memory");
+        return pool;
+    }
+
+    /// Release every interned string and both containers.
+    /// `alloc` must be the allocator that was passed to `init` / `intern`.
+    pub fn deinit(self: *StringPool, alloc: Allocator) void {
+        // Slot 0 is a string literal, so only slots 1.. were heap-allocated.
+        for (self.strings.items[1..]) |owned| alloc.free(owned);
+        self.strings.deinit(alloc);
+        self.map.deinit();
+    }
+
+    /// Return the id for `str`, copying it into the pool on first sight.
+    /// The empty string always maps to `.empty` without allocating.
+    /// Allocation failure panics in every build mode; the signature has no error
+    /// channel, and `catch unreachable` would be undefined behaviour in ReleaseFast.
+    pub fn intern(self: *StringPool, alloc: Allocator, str: []const u8) StringId {
+        if (str.len == 0) return .empty;
+        if (self.map.get(str)) |known| return known;
+
+        // Allocate a stable copy — used as both map key and lookup value
+        const owned = alloc.dupe(u8, str) catch @panic("StringPool: out of memory");
+        self.strings.append(alloc, owned) catch @panic("StringPool: out of memory");
+
+        const id: StringId = @enumFromInt(self.next_id);
+        self.map.put(owned, id) catch @panic("StringPool: out of memory");
+        // Advance only after the entry is fully recorded.
+        self.next_id += 1;
+
+        return id;
+    }
+
+    /// Look up the text for `id`; ids that are out of range yield "".
+    pub fn get(self: *const StringPool, id: StringId) []const u8 {
+        const idx = id.index();
+        if (idx >= self.strings.items.len) return "";
+        return self.strings.items[idx];
+    }
+};
+
+// ── TypeTable ───────────────────────────────────────────────────────────
+// Holds all resolved types. Builtins in slots 0–15, user types interned from 16+.
+
+pub const TypeTable = struct {
+    infos: std.ArrayList(TypeInfo),
+    strings: StringPool,
+    /// Maps TypeInfo → TypeId for dedup of structural types
+    intern_map: std.HashMap(TypeKey, TypeId, TypeKeyContext, 80),
+    alloc: Allocator,
+    /// Parameter slices duplicated by `functionType` / `closureType`; freed in `deinit`.
+    /// Slices passed straight to `intern` stay caller-owned and are never freed here.
+    owned_params: std.ArrayList([]const TypeId) = std.ArrayList([]const TypeId).empty,
+
+    /// Create a table with the 16 builtin slots filled in.
+    /// Panics if the builtin slots cannot be allocated.
+    pub fn init(alloc: Allocator) TypeTable {
+        var table = TypeTable{
+            .infos = std.ArrayList(TypeInfo).empty,
+            .strings = StringPool.init(alloc),
+            .intern_map = std.HashMap(TypeKey, TypeId, TypeKeyContext, 80).init(alloc),
+            .alloc = alloc,
+        };
+
+        // Pre-populate builtin slots 0–15 (must match TypeId enum order)
+        const builtins = [_]TypeInfo{
+            .void, // 0
+            .bool, // 1
+            .{ .signed = 8 }, // 2: s8
+            .{ .signed = 16 }, // 3: s16
+            .{ .signed = 32 }, // 4: s32
+            .{ .signed = 64 }, // 5: s64
+            .{ .unsigned = 8 }, // 6: u8
+            .{ .unsigned = 16 }, // 7: u16
+            .{ .unsigned = 32 }, // 8: u32
+            .{ .unsigned = 64 }, // 9: u64
+            .f32, // 10
+            .f64, // 11
+            .string, // 12
+            .any, // 13
+            .noreturn, // 14
+            .void, // 15: reserved (placeholder)
+        };
+        // Builtins are not entered into `intern_map`; they are reached via their TypeId tags.
+        table.infos.appendSlice(alloc, &builtins) catch @panic("TypeTable: out of memory");
+
+        return table;
+    }
+
+    /// Release everything the table owns, including parameter slices copied by
+    /// `functionType` / `closureType`.
+    pub fn deinit(self: *TypeTable) void {
+        for (self.owned_params.items) |params| self.alloc.free(params);
+        self.owned_params.deinit(self.alloc);
+        self.infos.deinit(self.alloc);
+        self.strings.deinit(self.alloc);
+        self.intern_map.deinit();
+    }
+
+    /// Look up the TypeInfo for a given TypeId.
+    pub fn get(self: *const TypeTable, id: TypeId) TypeInfo {
+        return self.infos.items[id.index()];
+    }
+
+    /// Intern a TypeInfo, returning the existing TypeId if structurally equal.
+    /// Slices inside `info` are stored as given (not copied), so they must outlive
+    /// the table. Panics on allocation failure.
+    pub fn intern(self: *TypeTable, info: TypeInfo) TypeId {
+        const key = TypeKey{ .info = info };
+        if (self.intern_map.get(key)) |existing| {
+            return existing;
+        }
+        const id = TypeId.fromIndex(@intCast(self.infos.items.len));
+        self.infos.append(self.alloc, info) catch @panic("TypeTable: out of memory");
+        self.intern_map.putNoClobber(key, id) catch @panic("TypeTable: out of memory");
+        return id;
+    }
+
+    // ── Convenience constructors ────────────────────────────────────────
+
+    pub fn ptrTo(self: *TypeTable, pointee: TypeId) TypeId {
+        return self.intern(.{ .pointer = .{ .pointee = pointee } });
+    }
+
+    pub fn manyPtrTo(self: *TypeTable, element: TypeId) TypeId {
+        return self.intern(.{ .many_pointer = .{ .element = element } });
+    }
+
+    pub fn sliceOf(self: *TypeTable, element: TypeId) TypeId {
+        return self.intern(.{ .slice = .{ .element = element } });
+    }
+
+    pub fn arrayOf(self: *TypeTable, element: TypeId, length: u32) TypeId {
+        return self.intern(.{ .array = .{ .element = element, .length = length } });
+    }
+
+    pub fn optionalOf(self: *TypeTable, child: TypeId) TypeId {
+        return self.intern(.{ .optional = .{ .child = child } });
+    }
+
+    /// Intern a function type. `params` is only borrowed for the call: it is copied
+    /// into table-owned memory the first time this signature is seen.
+    pub fn functionType(self: *TypeTable, params: []const TypeId, ret: TypeId) TypeId {
+        // Probe with the caller's slice first so a dedup hit allocates nothing.
+        const probe = TypeKey{ .info = .{ .function = .{ .params = params, .ret = ret } } };
+        if (self.intern_map.get(probe)) |existing| return existing;
+        return self.intern(.{ .function = .{ .params = self.ownParams(params), .ret = ret } });
+    }
+
+    /// Intern a closure type. Same ownership rules as `functionType`.
+    pub fn closureType(self: *TypeTable, params: []const TypeId, ret: TypeId) TypeId {
+        const probe = TypeKey{ .info = .{ .closure = .{ .params = params, .ret = ret } } };
+        if (self.intern_map.get(probe)) |existing| return existing;
+        return self.intern(.{ .closure = .{ .params = self.ownParams(params), .ret = ret } });
+    }
+
+    /// Copy `params` into table-owned memory and record the copy for `deinit`.
+    fn ownParams(self: *TypeTable, params: []const TypeId) []const TypeId {
+        const owned = self.alloc.dupe(TypeId, params) catch @panic("TypeTable: out of memory");
+        self.owned_params.append(self.alloc, owned) catch @panic("TypeTable: out of memory");
+        return owned;
+    }
+
+    pub fn vectorOf(self: *TypeTable, element: TypeId, length: u32) TypeId {
+        return self.intern(.{ .vector = .{ .element = element, .length = length } });
+    }
+
+    /// Size in bytes for a type (pointer-sized = 8 on 64-bit).
+    /// Structs, unions, enums, tuples and protocols report 0 as a placeholder.
+    pub fn sizeOf(self: *const TypeTable, id: TypeId) u32 {
+        const info = self.get(id);
+        return switch (info) {
+            .void, .noreturn => 0,
+            .bool => 1,
+            // Widths below 8 bits still occupy one byte.
+            .signed => |w| @max(1, w / 8),
+            .unsigned => |w| @max(1, w / 8),
+            .f32 => 4,
+            .f64 => 8,
+            .string => 16, // {ptr, len}
+            .pointer, .many_pointer, .function => 8,
+            .closure => 16, // {fn_ptr, env}
+            .optional => |opt| self.sizeOf(opt.child) + 8, // child + has_value flag (aligned)
+            .slice => 16, // {ptr, len}
+            .array => |arr| arr.length * self.sizeOf(arr.element),
+            .vector => |vec| vec.length * self.sizeOf(vec.element),
+            .any => 16, // {type_tag, data_ptr}
+            // Sizes of composite types depend on layout — return 0 as placeholder.
+            // Real size computation needs struct layout info from codegen/sema.
+            .@"struct", .@"union", .@"enum", .tuple, .protocol => 0,
+        };
+    }
+
+    /// Intern a string into the pool.
+    pub fn internString(self: *TypeTable, str: []const u8) StringId {
+        return self.strings.intern(self.alloc, str);
+    }
+
+    /// Look up a string from its id.
+    pub fn getString(self: *const TypeTable, id: StringId) []const u8 {
+        return self.strings.get(id);
+    }
+
+    /// Return a display name for a TypeId without allocating.
+    /// Builtins use their fixed spelling ("s32", "Any", ...); structs, enums, unions
+    /// and protocols use their interned name. Every other type (pointers, slices,
+    /// functions, the reserved slot, ...) is reported as "?".
+    pub fn typeName(self: *const TypeTable, id: TypeId) []const u8 {
+        // Fast path for builtins
+        return switch (id) {
+            .void => "void",
+            .bool => "bool",
+            .s8 => "s8",
+            .s16 => "s16",
+            .s32 => "s32",
+            .s64 => "s64",
+            .u8 => "u8",
+            .u16 => "u16",
+            .u32 => "u32",
+            .u64 => "u64",
+            .f32 => "f32",
+            .f64 => "f64",
+            .string => "string",
+            .any => "Any",
+            .noreturn => "noreturn",
+            // User types — format from TypeInfo
+            else => switch (self.get(id)) {
+                .@"struct" => |s| self.getString(s.name),
+                .@"enum" => |e| self.getString(e.name),
+                .@"union" => |u| self.getString(u.name),
+                .protocol => |p| self.getString(p.name),
+                else => "?",
+            },
+        };
+    }
+};
+
+// ── Intern map support ──────────────────────────────────────────────────
+// We use a custom hash/eql context so structurally identical types dedup.
+
+/// Hash-map key wrapping a TypeInfo so it can use `TypeKeyContext`.
+const TypeKey = struct {
+    info: TypeInfo,
+};
+
+/// Hash/equality context for `intern_map`; both sides look through slices.
+const TypeKeyContext = struct {
+    pub fn hash(_: TypeKeyContext, key: TypeKey) u64 {
+        var h = std.hash.Wyhash.init(0);
+        hashTypeInfo(&h, key.info);
+        return h.final();
+    }
+
+    pub fn eql(_: TypeKeyContext, a: TypeKey, b: TypeKey) bool {
+        return typeInfoEql(a.info, b.info);
+    }
+};
+
+/// Feed the identity of `info` into `h`.
+/// Named types hash by name only; structural types hash their components.
+/// Anything compared by `typeInfoEql` is hashed consistently with it.
+fn hashTypeInfo(h: *std.hash.Wyhash, info: TypeInfo) void {
+    // Hash the tag
+    const tag: u8 = @intFromEnum(std.meta.activeTag(info));
+    h.update(&.{tag});
+
+    switch (info) {
+        .signed => |w| h.update(&.{w}),
+        .unsigned => |w| h.update(&.{w}),
+        .f32, .f64, .void, .bool, .string, .any, .noreturn => {},
+        .pointer => |p| h.update(std.mem.asBytes(&p.pointee)),
+        .many_pointer => |p| h.update(std.mem.asBytes(&p.element)),
+        .slice => |s| h.update(std.mem.asBytes(&s.element)),
+        .array => |a| {
+            h.update(std.mem.asBytes(&a.element));
+            h.update(std.mem.asBytes(&a.length));
+        },
+        .vector => |v| {
+            h.update(std.mem.asBytes(&v.element));
+            h.update(std.mem.asBytes(&v.length));
+        },
+        .optional => |o| h.update(std.mem.asBytes(&o.child)),
+        .function => |f| {
+            for (f.params) |p| h.update(std.mem.asBytes(&p));
+            h.update(std.mem.asBytes(&f.ret));
+        },
+        .closure => |c| {
+            for (c.params) |p| h.update(std.mem.asBytes(&p));
+            h.update(std.mem.asBytes(&c.ret));
+        },
+        .@"struct" => |s| h.update(std.mem.asBytes(&s.name)),
+        .@"enum" => |e| h.update(std.mem.asBytes(&e.name)),
+        .@"union" => |u| h.update(std.mem.asBytes(&u.name)),
+        .protocol => |p| h.update(std.mem.asBytes(&p.name)),
+        .tuple => |t| {
+            for (t.fields) |f| h.update(std.mem.asBytes(&f));
+            // Field names are part of a tuple's identity (see typeInfoEql).
+            if (t.names) |names| {
+                for (names) |n| h.update(std.mem.asBytes(&n));
+            }
+        },
+    }
+}
+
+/// True when two optional name lists are both absent or hold the same ids in order.
+fn tupleNamesEql(a: ?[]const StringId, b: ?[]const StringId) bool {
+    const lhs = a orelse return b == null;
+    const rhs = b orelse return false;
+    return std.mem.eql(StringId, lhs, rhs);
+}
+
+/// Structural equality used for interning.
+/// Named types (struct/enum/union/protocol) compare by name only. Tuples compare
+/// field types and field names, so differently labelled tuples stay distinct.
+fn typeInfoEql(a: TypeInfo, b: TypeInfo) bool {
+    const Tag = std.meta.Tag(TypeInfo);
+    const a_tag: Tag = a;
+    const b_tag: Tag = b;
+    if (a_tag != b_tag) return false;
+
+    return switch (a) {
+        .signed => |w| w == b.signed,
+        .unsigned => |w| w == b.unsigned,
+        .f32, .f64, .void, .bool, .string, .any, .noreturn => true,
+        .pointer => |p| p.pointee == b.pointer.pointee,
+        .many_pointer => |p| p.element == b.many_pointer.element,
+        .slice => |s| s.element == b.slice.element,
+        .array => |ar| ar.element == b.array.element and ar.length == b.array.length,
+        .vector => |v| v.element == b.vector.element and v.length == b.vector.length,
+        .optional => |o| o.child == b.optional.child,
+        .function => |f| f.ret == b.function.ret and
+            std.mem.eql(TypeId, f.params, b.function.params),
+        .closure => |c| c.ret == b.closure.ret and
+            std.mem.eql(TypeId, c.params, b.closure.params),
+        .@"struct" => |s| s.name == b.@"struct".name,
+        .@"enum" => |e| e.name == b.@"enum".name,
+        .@"union" => |u| u.name == b.@"union".name,
+        .protocol => |p| p.name == b.protocol.name,
+        .tuple => |t| std.mem.eql(TypeId, t.fields, b.tuple.fields) and
+            tupleNamesEql(t.names, b.tuple.names),
+    };
+}
diff --git a/src/main.zig b/src/main.zig index
d2d9723..0e53ca6 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -91,6 +91,8 @@ pub fn main(init: std.process.Init) !void {
         std.debug.print("compiled: {s}\n", .{output_name});
     } else if (std.mem.eql(u8, command, "ir")) {
         emitIR(allocator, io, path, target_config) catch return;
+    } else if (std.mem.eql(u8, command, "ir-dump")) {
+        dumpSxIR(allocator, io, path) catch return;
     } else if (std.mem.eql(u8, command, "asm")) {
         emitAsm(allocator, io, path, target_config) catch return;
     } else if (std.mem.eql(u8, command, "run")) {
@@ -304,6 +306,38 @@ fn compilePipeline(allocator: std.mem.Allocator, io: std.Io, input_path: []const
     return comp;
 }
 
+/// Parse `input_path`, resolve its imports, lower the result to sx IR and print the
+/// textual IR with `std.debug.print`.
+/// Parse and import failures are rendered via `comp.renderErrors()` and returned as
+/// `error.CompileError`; a failure while printing the module is propagated.
+fn dumpSxIR(allocator: std.mem.Allocator, io: std.Io, input_path: []const u8) !void {
+    const source = try readSource(allocator, io, input_path);
+    var comp = sx.core.Compilation.init(allocator, io, input_path, source, .{});
+    defer comp.deinit();
+
+    comp.parse() catch {
+        comp.renderErrors();
+        return error.CompileError;
+    };
+    comp.resolveImports() catch {
+        comp.renderErrors();
+        return error.CompileError;
+    };
+
+    var ir_module = comp.lowerToIR();
+    defer ir_module.deinit();
+
+    // Render into memory first, then emit the text with a single print call.
+    var aw = std.Io.Writer.Allocating.init(allocator);
+    // Until the buffer is handed to `result` below, a failure must release it.
+    errdefer aw.deinit();
+    try sx.ir.printModule(&ir_module, &aw.writer);
+
+    var result = aw.toArrayList();
+    defer result.deinit(allocator);
+    std.debug.print("{s}", .{result.items});
+}
+
 fn emitIR(allocator: std.mem.Allocator, io: std.Io, input_path: []const u8, target_config: sx.codegen.TargetConfig) !void {
     var timer = Timing.init(false);
     var comp = try compilePipeline(allocator, io, input_path, target_config, &timer);
diff --git a/src/root.zig b/src/root.zig
index 709f4dc..5df44cf 100644
--- a/src/root.zig
+++ b/src/root.zig
@@ -11,6 +11,7 @@ pub const sema = @import("sema.zig");
 pub const imports = @import("imports.zig");
 pub const core = @import("core.zig");
 pub const c_import = @import("c_import.zig");
+pub const ir = @import("ir/ir.zig");
 
 pub const lsp = struct {
     pub const server = @import("lsp/server.zig");