From 0e777e9d2ec92bcea7566f3d2daf7fcb5959234a Mon Sep 17 00:00:00 2001 From: agra Date: Sat, 14 Feb 2026 19:33:33 +0200 Subject: [PATCH] ... --- examples/28-sdl-graphics.sx | 19 +- examples/modules/sdl3.sx | 300 ++++++++++++++++++++++- src/ast.zig | 1 + src/codegen.zig | 463 ++++++++++++++++++++++++++++++++---- src/lsp/server.zig | 90 +++++++ src/parser.zig | 64 ++++- src/sema.zig | 92 ++++++- 7 files changed, 957 insertions(+), 72 deletions(-) diff --git a/examples/28-sdl-graphics.sx b/examples/28-sdl-graphics.sx index ff4fc56..2c44ab2 100644 --- a/examples/28-sdl-graphics.sx +++ b/examples/28-sdl-graphics.sx @@ -267,13 +267,22 @@ GLSL; // Render loop running := true; - event : [128]u8 = ---; + event : SDL_Event = .none; while running { - while SDL_PollEvent(xx @event[0]) { - etype : u32 = xx event[0]; - if etype == SDL_EVENT_QUIT { - running = false; + while SDL_PollEvent(event) { + if event == { + case .quit: running = false; + case .key_up: (e) { + if e.key == { + case .escape: running = false; + } + } + case .key_down: (e) { + k : u32 = xx e.key; + print("ts={} wid={} sc={} key={}\n", + e.timestamp, e.window_id, e.scancode, k); + } } } diff --git a/examples/modules/sdl3.sx b/examples/modules/sdl3.sx index 4b58003..1e261c3 100644 --- a/examples/modules/sdl3.sx +++ b/examples/modules/sdl3.sx @@ -20,9 +20,301 @@ SDL_GL_CONTEXT_PROFILE_CORE :s32: 0x1; // SDL_GLContextFlag SDL_GL_CONTEXT_FORWARD_COMPATIBLE_FLAG :s32: 0x2; -// SDL_EventType -SDL_EVENT_QUIT :u32: 0x100; -SDL_EVENT_KEY_DOWN :u32: 0x300; +// SDL_Keycode — virtual key codes (layout-dependent) +SDL_Keycode :: enum u32 { + // Common + unknown :: 0x00; + return_key :: 0x0D; + escape :: 0x1B; + backspace :: 0x08; + tab :: 0x09; + space :: 0x20; + delete_key :: 0x7F; + + // Punctuation + exclaim :: 0x21; + double_quote :: 0x22; + hash :: 0x23; + dollar :: 0x24; + percent :: 0x25; + ampersand :: 0x26; + apostrophe :: 0x27; + leftparen :: 0x28; + rightparen :: 0x29; + asterisk :: 0x2A; + plus :: 0x2B; + comma :: 0x2C; + minus :: 0x2D; + period :: 0x2E; + slash :: 0x2F; + colon :: 0x3A; + semicolon :: 0x3B; + less :: 0x3C; + equals :: 0x3D; + greater :: 0x3E; + question :: 0x3F; + at :: 0x40; + leftbracket :: 0x5B; + backslash :: 0x5C; + rightbracket :: 0x5D; + caret :: 0x5E; + underscore :: 0x5F; + grave :: 0x60; + leftbrace :: 0x7B; + pipe :: 0x7C; + rightbrace :: 0x7D; + tilde :: 0x7E; + plusminus :: 0xB1; + + // Numbers + key_0 :: 0x30; + key_1 :: 0x31; + key_2 :: 0x32; + key_3 :: 0x33; + key_4 :: 0x34; + key_5 :: 0x35; + key_6 :: 0x36; + key_7 :: 0x37; + key_8 :: 0x38; + key_9 :: 0x39; + + // Letters + a :: 0x61; + b :: 0x62; + c :: 0x63; + d :: 0x64; + e :: 0x65; + f :: 0x66; + g :: 0x67; + h :: 0x68; + i :: 0x69; + j :: 0x6A; + k :: 0x6B; + l :: 0x6C; + m :: 0x6D; + n :: 0x6E; + o :: 0x6F; + p :: 0x70; + q :: 0x71; + r :: 0x72; + s :: 0x73; + t :: 0x74; + u :: 0x75; + v :: 0x76; + w :: 0x77; + x :: 0x78; + y :: 0x79; + z :: 0x7A; + + // Function keys + f1 :: 0x4000003A; + f2 :: 0x4000003B; + f3 :: 0x4000003C; + f4 :: 0x4000003D; + f5 :: 0x4000003E; + f6 :: 0x4000003F; + f7 :: 0x40000040; + f8 :: 0x40000041; + f9 :: 0x40000042; + f10 :: 0x40000043; + f11 :: 0x40000044; + f12 :: 0x40000045; + f13 :: 0x40000068; + f14 :: 0x40000069; + f15 :: 0x4000006A; + f16 :: 0x4000006B; + f17 :: 0x4000006C; + f18 :: 0x4000006D; + f19 :: 0x4000006E; + f20 :: 0x4000006F; + f21 :: 0x40000070; + f22 :: 0x40000071; + f23 :: 0x40000072; + f24 :: 0x40000073; + + // Navigation + capslock :: 0x40000039; + printscreen :: 0x40000046; + scrolllock :: 0x40000047; + pause :: 0x40000048; + insert :: 0x40000049; + home :: 0x4000004A; + pageup :: 0x4000004B; + end :: 0x4000004D; + pagedown :: 0x4000004E; + right :: 0x4000004F; + left :: 0x40000050; + down :: 0x40000051; + up :: 0x40000052; + + // Keypad + numlock :: 0x40000053; + kp_divide :: 0x40000054; + kp_multiply :: 0x40000055; + kp_minus :: 0x40000056; + kp_plus :: 0x40000057; + kp_enter :: 0x40000058; + kp_1 :: 0x40000059; + kp_2 :: 0x4000005A; + kp_3 :: 0x4000005B; + kp_4 :: 0x4000005C; + kp_5 :: 0x4000005D; + kp_6 :: 0x4000005E; + kp_7 :: 0x4000005F; + kp_8 :: 0x40000060; + kp_9 :: 0x40000061; + kp_0 :: 0x40000062; + kp_period :: 0x40000063; + kp_equals :: 0x40000067; + kp_comma :: 0x40000085; + + // Modifiers + lctrl :: 0x400000E0; + lshift :: 0x400000E1; + lalt :: 0x400000E2; + lgui :: 0x400000E3; + rctrl :: 0x400000E4; + rshift :: 0x400000E5; + ralt :: 0x400000E6; + rgui :: 0x400000E7; + mode :: 0x40000101; + + // Editing + undo :: 0x4000007A; + cut :: 0x4000007B; + copy :: 0x4000007C; + paste :: 0x4000007D; + find :: 0x4000007E; + + // Media + mute :: 0x4000007F; + volumeup :: 0x40000080; + volumedown :: 0x40000081; + media_play :: 0x40000106; + media_pause :: 0x40000107; + media_fast_forward :: 0x40000109; + media_rewind :: 0x4000010A; + media_next_track :: 0x4000010B; + media_previous_track :: 0x4000010C; + media_stop :: 0x4000010D; + media_eject :: 0x4000010E; + media_play_pause :: 0x4000010F; + + // System + application :: 0x40000065; + power :: 0x40000066; + execute :: 0x40000074; + help :: 0x40000075; + menu :: 0x40000076; + select :: 0x40000077; + sleep :: 0x40000102; + wake :: 0x40000103; +} + +// Event payload structs — match SDL3 layout from byte 4 onward (after the u32 type tag) +// Common header: reserved (u32), timestamp (u64) + +SDL_WindowData :: struct { + timestamp: u64; // event time in nanoseconds + window_id: u32; + data1: s32; // event-dependent: x position for moved, width for resized + data2: s32; // event-dependent: y position for moved, height for resized +} + +SDL_Keymod :: enum flags u16 { + lshift :: 0b0000_0000_0000_0001; // left Shift + rshift :: 0b0000_0000_0000_0010; // right Shift + level5 :: 0b0000_0000_0000_0100; // Level 5 Shift + lctrl :: 0b0000_0000_0100_0000; // left Ctrl + rctrl :: 0b0000_0000_1000_0000; // right Ctrl + lalt :: 0b0000_0001_0000_0000; // left Alt + ralt :: 0b0000_0010_0000_0000; // right Alt + lgui :: 0b0000_0100_0000_0000; // left GUI (Windows/Cmd key) + rgui :: 0b0000_1000_0000_0000; // right GUI (Windows/Cmd key) + num :: 0b0001_0000_0000_0000; // Num Lock + caps :: 0b0010_0000_0000_0000; // Caps Lock + mode :: 0b0100_0000_0000_0000; // AltGr + scroll :: 0b1000_0000_0000_0000; // Scroll Lock +} + +SDL_KeyData :: struct { + timestamp: u64; // event time in nanoseconds + window_id: u32; // window with keyboard focus + which: u32; // keyboard instance id, 0 if unknown or virtual + scancode: u32; // physical key code (layout-independent) + key: SDL_Keycode; // virtual key code (layout-dependent) + mod: SDL_Keymod; // active modifier keys + raw: u16; // platform-specific scancode + down: u8; // 1 if pressed, 0 if released + repeat: u8; // 1 if this is a key repeat +} + +SDL_MouseMotionData :: struct { + timestamp: u64; // event time in nanoseconds + window_id: u32; // window with mouse focus + which: u32; // mouse instance id, 0 for touch events + state: u32; // button state bitmask (bit 0 = left, 1 = middle, 2 = right) + x: f32; // x position relative to window + y: f32; // y position relative to window + xrel: f32; // relative motion in x + yrel: f32; // relative motion in y +} + +SDL_MouseButtonData :: struct { + timestamp: u64; // event time in nanoseconds + window_id: u32; // window with mouse focus + which: u32; // mouse instance id, 0 for touch events + button: u8; // button index (1 = left, 2 = middle, 3 = right) + down: u8; // 1 if pressed, 0 if released + clicks: u8; // 1 for single-click, 2 for double-click, etc. + _: u8; + x: f32; // x position relative to window + y: f32; // y position relative to window +} + +SDL_MouseWheelData :: struct { + timestamp: u64; // event time in nanoseconds + window_id: u32; // window with mouse focus + which: u32; // mouse instance id + x: f32; // horizontal scroll (positive = right) + y: f32; // vertical scroll (positive = away from user) + direction: u32; // 0 = normal, 1 = flipped (multiply by -1 to normalize) + mouse_x: f32; // mouse x position relative to window + mouse_y: f32; // mouse y position relative to window +} + +SDL_Event :: enum struct { tag: u32; _: u32; payload: [30]u32; } { + none :: 0; + + // Application + quit :: 0x100; + + // Window + window_shown :: 0x202: SDL_WindowData; + window_hidden :: 0x203: SDL_WindowData; + window_exposed :: 0x204: SDL_WindowData; + window_moved :: 0x205: SDL_WindowData; + window_resized :: 0x206: SDL_WindowData; + window_minimized :: 0x209: SDL_WindowData; + window_maximized :: 0x20A: SDL_WindowData; + window_restored :: 0x20B: SDL_WindowData; + window_mouse_enter :: 0x20C: SDL_WindowData; + window_mouse_leave :: 0x20D: SDL_WindowData; + window_focus_gained :: 0x20E: SDL_WindowData; + window_focus_lost :: 0x20F: SDL_WindowData; + window_close_requested :: 0x210: SDL_WindowData; + window_destroyed :: 0x219: SDL_WindowData; + + // Keyboard + key_down :: 0x300: SDL_KeyData; + key_up :: 0x301: SDL_KeyData; + + // Mouse + mouse_motion :: 0x400: SDL_MouseMotionData; + mouse_button_down :: 0x401: SDL_MouseButtonData; + mouse_button_up :: 0x402: SDL_MouseButtonData; + mouse_wheel :: 0x403: SDL_MouseWheelData; +} // Functions SDL_Init :: (flags: u32) -> bool #foreign; @@ -36,6 +328,6 @@ SDL_GL_MakeCurrent :: (window: *void, context: *void) -> bool #foreign; SDL_GL_SwapWindow :: (window: *void) -> bool #foreign; SDL_GL_SetSwapInterval :: (interval: s32) -> bool #foreign; SDL_GL_GetProcAddress :: (proc: [:0]u8) -> *void #foreign; -SDL_PollEvent :: (event: *void) -> bool #foreign; +SDL_PollEvent :: (event: *SDL_Event) -> bool #foreign; SDL_GetTicks :: () -> u64 #foreign; SDL_Delay :: (ms: u32) -> void #foreign; diff --git a/src/ast.zig b/src/ast.zig index a7e130a..661f61e 100644 --- a/src/ast.zig +++ b/src/ast.zig @@ -196,6 +196,7 @@ pub const MatchArm = struct { pattern: ?*Node, // null = else (default) arm body: *Node, is_break: bool, + capture: ?[]const u8 = null, // payload binding name: case .variant: (name) { ... } }; pub const ConstDecl = struct { diff --git a/src/codegen.zig b/src/codegen.zig index be5b17a..c529cdb 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -247,8 +247,9 @@ pub const CodeGen = struct { const TaggedEnumInfo = struct { variant_names: []const []const u8, variant_types: []const Type, // void_type for void variants - llvm_type: c.LLVMTypeRef, // { i32, [max_payload_size x i8] } + llvm_type: c.LLVMTypeRef, // layout struct or { tag, [max_payload_size x i8] } max_payload_size: u64, + payload_field_index: c_uint = 1, // struct field index of the payload array }; const PromotedField = struct { @@ -281,13 +282,33 @@ pub const CodeGen = struct { const module = c.LLVMModuleCreateWithNameInContext(module_name, ctx); const builder = c.LLVMCreateBuilderInContext(ctx); - // Set target triple on module so it appears in IR output - if (target_config.triple) |t| { - c.LLVMSetTarget(module, t); + // Initialize LLVM targets and set data layout early so alignment queries work + llvm.initAllTargets(); + + const triple_owned = target_config.triple == null; + const triple = target_config.triple orelse c.LLVMGetDefaultTargetTriple(); + defer if (triple_owned) c.LLVMDisposeMessage(@constCast(triple)); + + c.LLVMSetTarget(module, triple); + + var target: c.LLVMTargetRef = null; + var err_msg: [*c]u8 = null; + if (c.LLVMGetTargetFromTriple(triple, &target, &err_msg) == 0) { + const tm = c.LLVMCreateTargetMachine( + target, + triple, + target_config.getCpu(), + target_config.getFeatures(), + target_config.opt_level.toLLVM(), + c.LLVMRelocPIC, + c.LLVMCodeModelDefault, + ); + const dl = c.LLVMCreateTargetDataLayout(tm); + c.LLVMSetModuleDataLayout(module, dl); + c.LLVMDisposeTargetData(dl); + c.LLVMDisposeTargetMachine(tm); } else { - const default_triple = c.LLVMGetDefaultTargetTriple(); - c.LLVMSetTarget(module, default_triple); - c.LLVMDisposeMessage(default_triple); + if (err_msg != null) c.LLVMDisposeMessage(err_msg); } return .{ .context = ctx, @@ -1285,6 +1306,7 @@ pub const CodeGen = struct { .variant_types = build.variant_sx_types, .llvm_type = build.llvm_type, .max_payload_size = build.max_payload_size, + .payload_field_index = build.payload_field_index, }); _ = try self.getAnyTypeId(mangled_name, .{ .union_type = mangled_name }); @@ -1928,6 +1950,16 @@ pub const CodeGen = struct { const info = self.struct_types.get(sname) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{sname}); const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, val, "retval"); _ = c.LLVMBuildRet(self.builder, loaded); + } else if (ret_sx_type.isUnion()) { + // Tagged enum implicit return: val may be alloca or loaded value + const uname = ret_sx_type.union_type; + const resolved = self.type_aliases.get(uname) orelse uname; + const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved}); + const ret_val2 = if (c.LLVMGetTypeKind(c.LLVMTypeOf(val)) == c.LLVMPointerTypeKind) + c.LLVMBuildLoad2(self.builder, info.llvm_type, val, "retval") + else + val; + _ = c.LLVMBuildRet(self.builder, ret_val2); } else { const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(val)); const ret_val = self.convertValue(val, src_ty, self.current_return_type); @@ -2090,6 +2122,17 @@ pub const CodeGen = struct { const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval"); try self.emitAllDefers(); _ = c.LLVMBuildRet(self.builder, loaded); + } else if (self.current_return_type.isUnion()) { + // Tagged enum return: raw_val may be alloca (enum literal) or loaded value (identifier/call) + const uname = self.current_return_type.union_type; + const resolved = self.type_aliases.get(uname) orelse uname; + const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved}); + const ret_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(raw_val)) == c.LLVMPointerTypeKind) + c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval") + else + raw_val; + try self.emitAllDefers(); + _ = c.LLVMBuildRet(self.builder, ret_val); } else { const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val)); const val = self.convertValue(raw_val, src_ty, self.current_return_type); @@ -2253,15 +2296,23 @@ pub const CodeGen = struct { try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty }); return null; } else if (vd.value.?.data == .call) { - // Call returning a union (e.g., Shape.circle(3.14)) — genExpr returns alloca - const result_alloca = try self.genExpr(vd.value.?); - const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result_alloca, "union_load"); - _ = c.LLVMBuildStore(self.builder, loaded, alloca); + // Call returning a union — could be enum construction (alloca) or function call (value) + const result = try self.genExpr(vd.value.?); + if (c.LLVMGetTypeKind(c.LLVMTypeOf(result)) == c.LLVMPointerTypeKind) { + const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result, "union_load"); + _ = c.LLVMBuildStore(self.builder, loaded, alloca); + } else { + _ = c.LLVMBuildStore(self.builder, result, alloca); + } } else { // Other expression — try genExprAsType - const result_alloca = try self.genExprAsType(vd.value.?, sx_ty); - const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result_alloca, "union_load"); - _ = c.LLVMBuildStore(self.builder, loaded, alloca); + const result = try self.genExprAsType(vd.value.?, sx_ty); + if (c.LLVMGetTypeKind(c.LLVMTypeOf(result)) == c.LLVMPointerTypeKind) { + const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result, "union_load"); + _ = c.LLVMBuildStore(self.builder, loaded, alloca); + } else { + _ = c.LLVMBuildStore(self.builder, result, alloca); + } } try self.saveShadowed(vd.name); @@ -2564,13 +2615,16 @@ pub const CodeGen = struct { return null; } - // Tagged enum reassignment: s = .circle(3.14) or s = .none + // Tagged enum reassignment: s = .circle(3.14) or s = .none or s = fn_call() if (entry.ty.isUnion() and asgn.op == .assign) { if (self.tagged_enum_types.get(entry.ty.union_type)) |info| { - const new_alloca = try self.genExprAsType(asgn.value, entry.ty); - // Copy from new alloca to existing alloca - const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load"); - _ = c.LLVMBuildStore(self.builder, loaded, entry.ptr); + const new_val = try self.genExprAsType(asgn.value, entry.ty); + // genExprAsType returns alloca for enum literals, loaded value for calls + const store_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(new_val)) == c.LLVMPointerTypeKind) + c.LLVMBuildLoad2(self.builder, info.llvm_type, new_val, "union_load") + else + new_val; + _ = c.LLVMBuildStore(self.builder, store_val, entry.ptr); return null; } // C-style union: full assignment not supported, use field assignment @@ -2860,6 +2914,33 @@ pub const CodeGen = struct { const lhs_ty = self.inferType(binop.lhs); const rhs_ty = self.inferType(binop.rhs); const result_type = Type.widen(lhs_ty, rhs_ty); + + // Tagged enum comparison: compare tags only + if (result_type.isUnion() and (binop.op == .eq or binop.op == .neq)) { + const uname = result_type.union_type; + const resolved = self.type_aliases.get(uname) orelse uname; + const info = self.tagged_enum_types.get(resolved) orelse return self.emitError("unknown tagged enum type"); + const tag_ty = self.getEnumLLVMType(resolved); + + var lhs_val = try self.genExprAsType(binop.lhs, result_type); + var rhs_val = try self.genExprAsType(binop.rhs, result_type); + + // If either side is a pointer (alloca from genTaggedEnumLiteral), load it + if (c.LLVMGetTypeKind(c.LLVMTypeOf(lhs_val)) == c.LLVMPointerTypeKind) { + lhs_val = c.LLVMBuildLoad2(self.builder, info.llvm_type, lhs_val, "union_load_l"); + } + if (c.LLVMGetTypeKind(c.LLVMTypeOf(rhs_val)) == c.LLVMPointerTypeKind) { + rhs_val = c.LLVMBuildLoad2(self.builder, info.llvm_type, rhs_val, "union_load_r"); + } + + // Extract tags (field 0) and compare + const lhs_tag = c.LLVMBuildExtractValue(self.builder, lhs_val, 0, "lhs_tag"); + const rhs_tag = c.LLVMBuildExtractValue(self.builder, rhs_val, 0, "rhs_tag"); + _ = tag_ty; + const pred: c_uint = if (binop.op == .eq) c.LLVMIntEQ else c.LLVMIntNE; + return c.LLVMBuildICmp(self.builder, pred, lhs_tag, rhs_tag, "tag_cmp"); + } + const lhs = try self.genExprAsType(binop.lhs, result_type); const rhs = try self.genExprAsType(binop.rhs, result_type); return self.genBinaryOp(binop.op, lhs, rhs, result_type); @@ -2895,6 +2976,15 @@ pub const CodeGen = struct { .xx, .address_of => unreachable, }; }, + .enum_literal => |el| { + if (self.current_return_type.isUnion()) { + return self.genTaggedEnumLiteral(el, self.current_return_type.union_type); + } + if (self.current_return_type.isEnum()) { + return self.genEnumLiteral(el.name, self.current_return_type.enum_type); + } + return self.emitError("cannot infer enum type for literal"); + }, .struct_literal => |sl| { const ctx_name: ?[]const u8 = if (self.current_return_type.isStruct()) self.current_return_type.struct_type else null; return self.genStructLiteral(sl, ctx_name); @@ -2976,10 +3066,29 @@ pub const CodeGen = struct { .return_stmt => |rs| { if (rs.value) |val_node| { const raw_val = try self.genExpr(val_node); - const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val)); - const val = self.convertValue(raw_val, src_ty, self.current_return_type); - try self.emitAllDefers(); - _ = c.LLVMBuildRet(self.builder, val); + if (self.current_return_type.isStruct()) { + const sname = self.current_return_type.struct_type; + const resolved = self.type_aliases.get(sname) orelse sname; + const sinfo = self.struct_types.get(resolved) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{resolved}); + const loaded = c.LLVMBuildLoad2(self.builder, sinfo.llvm_type, raw_val, "retval"); + try self.emitAllDefers(); + _ = c.LLVMBuildRet(self.builder, loaded); + } else if (self.current_return_type.isUnion()) { + const uname = self.current_return_type.union_type; + const resolved = self.type_aliases.get(uname) orelse uname; + const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved}); + const ret_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(raw_val)) == c.LLVMPointerTypeKind) + c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval") + else + raw_val; + try self.emitAllDefers(); + _ = c.LLVMBuildRet(self.builder, ret_val); + } else { + const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val)); + const val = self.convertValue(raw_val, src_ty, self.current_return_type); + try self.emitAllDefers(); + _ = c.LLVMBuildRet(self.builder, val); + } } else { try self.emitAllDefers(); _ = c.LLVMBuildRetVoid(self.builder); @@ -3111,6 +3220,7 @@ pub const CodeGen = struct { variant_sx_types: []const Type, llvm_type: c.LLVMTypeRef, max_payload_size: u64, + payload_field_index: c_uint, }; fn buildUnionFields(self: *CodeGen, name: []const u8, variant_type_nodes: []const ?*Node) !UnionBuildResult { @@ -3134,6 +3244,7 @@ pub const CodeGen = struct { const union_ty = c.LLVMStructCreateNamed(self.context, name_z.ptr); const tag_ty = self.getEnumLLVMType(name); const i8_ty = c.LLVMInt8TypeInContext(self.context); + const payload_array_ty = c.LLVMArrayType2(i8_ty, max_payload_size); var union_fields = [2]c.LLVMTypeRef{ tag_ty, payload_array_ty }; c.LLVMStructSetBody(union_ty, &union_fields, 2, 0); @@ -3142,6 +3253,7 @@ pub const CodeGen = struct { .variant_sx_types = try variant_sx_types.toOwnedSlice(self.allocator), .llvm_type = union_ty, .max_payload_size = max_payload_size, + .payload_field_index = 1, }; } @@ -3226,21 +3338,176 @@ pub const CodeGen = struct { } } - // Register backing type before buildUnionFields (which uses getEnumLLVMType) - if (ud.backing_type) |bt_node| { - const bt = self.resolveType(bt_node); - try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt)); + // Check if backing type is a struct layout specification + const layout_info = try self.resolveEnumLayout(ud); + + if (layout_info) |layout| { + // Struct-backed layout: use the struct's LLVM type directly + try self.enum_backing_types.put(ud.name, layout.tag_llvm_type); + + // Resolve variant sx types + var variant_sx_types = std.ArrayList(Type).empty; + for (ud.variant_types) |vt| { + if (vt) |type_node| { + try variant_sx_types.append(self.allocator, self.resolveType(type_node)); + } else { + try variant_sx_types.append(self.allocator, .void_type); + } + } + + try self.tagged_enum_types.put(ud.name, .{ + .variant_names = ud.variant_names, + .variant_types = try variant_sx_types.toOwnedSlice(self.allocator), + .llvm_type = layout.llvm_type, + .max_payload_size = layout.payload_size, + .payload_field_index = layout.payload_field_index, + }); + } else { + // Primitive backing type (e.g. enum u32 { ... }) + if (ud.backing_type) |bt_node| { + const bt = self.resolveType(bt_node); + try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt)); + } + + const build = try self.buildUnionFields(ud.name, ud.variant_types); + + try self.tagged_enum_types.put(ud.name, .{ + .variant_names = ud.variant_names, + .variant_types = build.variant_sx_types, + .llvm_type = build.llvm_type, + .max_payload_size = build.max_payload_size, + .payload_field_index = build.payload_field_index, + }); } - const build = try self.buildUnionFields(ud.name, ud.variant_types); - - try self.tagged_enum_types.put(ud.name, .{ - .variant_names = ud.variant_names, - .variant_types = build.variant_sx_types, - .llvm_type = build.llvm_type, - .max_payload_size = build.max_payload_size, - }); _ = try self.getAnyTypeId(ud.name, .{ .union_type = ud.name }); + + // Compute and store variant values (explicit or sequential) + const values = try self.allocator.alloc(i64, ud.variant_names.len); + for (ud.variant_names, 0..) |_, i| { + if (ud.variant_values.len > i and ud.variant_values[i] != null) { + const val_node = ud.variant_values[i].?; + values[i] = switch (val_node.data) { + .int_literal => |il| il.value, + else => @as(i64, @intCast(i)), + }; + } else { + values[i] = @intCast(i); + } + } + try self.enum_variant_values.put(ud.name, values); + } + + const EnumLayoutInfo = struct { + llvm_type: c.LLVMTypeRef, + tag_llvm_type: c.LLVMTypeRef, + payload_field_index: c_uint, + payload_size: u64, + }; + + /// Resolve a struct-backed layout for a tagged enum. + /// Returns null if the backing type is a primitive (e.g. u32), in which case + /// the caller should fall back to buildUnionFields. + /// + /// The layout struct must have: + /// - A field named `tag` (integer type) — the discriminant + /// - A field named `payload` (array type) — the overlay area for variant data + /// - Any other fields are treated as padding/reserved + fn resolveEnumLayout(self: *CodeGen, ud: ast.EnumDecl) !?EnumLayoutInfo { + const bt_node = ud.backing_type orelse return null; + + // Check for inline struct: enum struct { ... } { ... } + if (bt_node.data == .struct_decl) { + const layout_name = try std.fmt.allocPrint(self.allocator, "{s}.__layout", .{ud.name}); + var sd = bt_node.data.struct_decl; + sd.name = layout_name; + try self.registerStructType(sd); + return try self.validateEnumLayout(ud.name, layout_name); + } + + // Check for named struct reference: enum MyLayout { ... } + if (bt_node.data == .type_expr) { + const name = bt_node.data.type_expr.name; + // If it resolves to a primitive type, it's not a layout struct + if (Type.fromName(name) != null) return null; + // Check type aliases + const resolved = self.type_aliases.get(name) orelse name; + if (Type.fromName(resolved) != null) return null; + // Must be a registered struct + if (self.struct_types.contains(resolved)) { + return try self.validateEnumLayout(ud.name, resolved); + } + } + + return null; + } + + fn validateEnumLayout(self: *CodeGen, enum_name: []const u8, layout_name: []const u8) !EnumLayoutInfo { + const layout = self.struct_types.get(layout_name) orelse { + return self.emitErrorFmt("enum '{s}': layout type '{s}' is not a registered struct", .{ enum_name, layout_name }); + }; + + // Find 'tag' field + var tag_index: ?usize = null; + var payload_index: ?usize = null; + for (layout.field_names, 0..) |fname, i| { + if (std.mem.eql(u8, fname, "tag")) { + tag_index = i; + } else if (std.mem.eql(u8, fname, "payload")) { + payload_index = i; + } + } + + if (tag_index == null) { + return self.emitErrorFmt( + "enum '{s}': layout struct '{s}' must have a field named 'tag' (the discriminant). Expected layout: struct {{ tag: ; payload: [N]; }}", + .{ enum_name, layout_name }, + ); + } + if (payload_index == null) { + return self.emitErrorFmt( + "enum '{s}': layout struct '{s}' must have a field named 'payload' (the variant data area). Expected layout: struct {{ tag: ; payload: [N]; }}", + .{ enum_name, layout_name }, + ); + } + + const tag_ty = layout.field_types[tag_index.?]; + const payload_ty = layout.field_types[payload_index.?]; + + // Validate tag is an integer type + switch (tag_ty) { + .signed, .unsigned => {}, + else => return self.emitErrorFmt( + "enum '{s}': layout field 'tag' must be an integer type (e.g. u32), got '{s}'", + .{ enum_name, tag_ty.displayName(self.allocator) catch "?" }, + ), + } + + // Validate payload is an array type + const payload_size = switch (payload_ty) { + .array_type => |info| blk: { + const elem_ty = Type.fromName(info.element_name) orelse { + return self.emitErrorFmt( + "enum '{s}': layout field 'payload' has unresolved element type '{s}'", + .{ enum_name, info.element_name }, + ); + }; + const elem_llvm = self.typeToLLVM(elem_ty); + const data_layout = c.LLVMGetModuleDataLayout(self.module); + break :blk c.LLVMStoreSizeOfType(data_layout, elem_llvm) * info.length; + }, + else => return self.emitErrorFmt( + "enum '{s}': layout field 'payload' must be an array type (e.g. [30]u32), got '{s}'", + .{ enum_name, payload_ty.displayName(self.allocator) catch "?" }, + ), + }; + + return .{ + .llvm_type = layout.llvm_type, + .tag_llvm_type = self.typeToLLVM(tag_ty), + .payload_field_index = @intCast(payload_index.?), + .payload_size = payload_size, + }; } fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void { @@ -3320,16 +3587,17 @@ pub const CodeGen = struct { const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp"); const tag_ty = self.getEnumLLVMType(resolved_name); - // Store tag (field 0) + // Store tag (field 0) — use explicit value if available, otherwise index const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag"); - _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, idx, 0), tag_gep); + const tag_val: u64 = if (self.enum_variant_values.get(resolved_name)) |vals| @bitCast(vals[idx]) else idx; + _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, tag_val, 0), tag_gep); // Store payload (field 1) if not void if (el.payload) |payload_node| { const variant_ty = info.variant_types[idx]; if (variant_ty != .void_type) { const payload_val = try self.genExprAsType(payload_node, variant_ty); - const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload"); + const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, info.payload_field_index, "payload"); // genExprAsType returns a loaded value for all types (including structs) _ = c.LLVMBuildStore(self.builder, payload_val, payload_gep); } @@ -3572,7 +3840,7 @@ pub const CodeGen = struct { .type_expr = null, .field_inits = sl.field_inits, }, payload_struct_name); - const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload"); + const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, info.payload_field_index, "payload"); const payload_llvm_ty = self.typeToLLVM(variant_ty); const struct_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_alloca, "struct_load"); _ = c.LLVMBuildStore(self.builder, struct_val, payload_gep); @@ -3687,7 +3955,12 @@ pub const CodeGen = struct { std.mem.eql(u8, src_ty.struct_type, pointee_name) or (if (self.type_aliases.get(src_ty.struct_type)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or (if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, src_ty.struct_type) else false) - else if (Type.fromName(pointee_name)) |pointee_ty| + else if (src_ty.isUnion()) blk: { + const uname = src_ty.union_type; + break :blk std.mem.eql(u8, uname, pointee_name) or + (if (self.type_aliases.get(uname)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or + (if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, uname) else false); + } else if (Type.fromName(pointee_name)) |pointee_ty| src_ty.eql(pointee_ty) else false; @@ -4178,7 +4451,7 @@ pub const CodeGen = struct { // Read tag (field 0) const tag_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 0, "fv_tag_ptr"); const tag_val = c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(val_ty.union_type), tag_ptr, "fv_tag"); - const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 1, "fv_payload_ptr"); + const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, uinfo.payload_field_index, "fv_payload_ptr"); const n = uinfo.variant_names.len; const function = self.current_function; @@ -4371,6 +4644,72 @@ pub const CodeGen = struct { return phi; } + fn genFieldIndex(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef { + if (call_node.args.len != 2) return self.emitError("field_index expects 2 arguments: field_index(T, value)"); + const ty = self.resolveType(call_node.args[0]); + const i64_type = c.LLVMInt64TypeInContext(self.context); + if (!ty.isEnum()) { + _ = try self.genExpr(call_node.args[1]); + return c.LLVMConstInt(i64_type, 0, 0); + } + const enum_name = ty.enum_type; + // Flags enums don't use sequential indices + if (self.flags_enum_types.contains(enum_name)) { + _ = try self.genExpr(call_node.args[1]); + return c.LLVMConstInt(i64_type, 0, 0); + } + const values = self.enum_variant_values.get(enum_name); + const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]); + const n = variants.len; + + const val = try self.genExpr(call_node.args[1]); + // Ensure the switch value uses the enum's backing type + const enum_llvm_ty = self.getEnumLLVMType(enum_name); + const sw_val = if (c.LLVMTypeOf(val) != enum_llvm_ty) + c.LLVMBuildIntCast2(self.builder, val, enum_llvm_ty, 0, "fi_cast") + else + val; + + const function = self.current_function; + const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_merge"); + const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_default"); + const sw = c.LLVMBuildSwitch(self.builder, sw_val, default_bb, @intCast(n)); + + var phi_vals = std.ArrayList(c.LLVMValueRef).empty; + var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty; + var seen_values = std.ArrayList(u64).empty; + + for (0..n) |i| { + const explicit_val: u64 = if (values) |vals| @bitCast(vals[i]) else i; + // Skip duplicate values (first one wins) + var is_dup = false; + for (seen_values.items) |sv| { + if (sv == explicit_val) { is_dup = true; break; } + } + if (is_dup) continue; + try seen_values.append(self.allocator, explicit_val); + const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_case"); + c.LLVMAddCase(sw, c.LLVMConstInt(enum_llvm_ty, explicit_val, 0), case_bb); + c.LLVMPositionBuilderAtEnd(self.builder, case_bb); + try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, i, 0)); + try phi_bbs.append(self.allocator, case_bb); + _ = c.LLVMBuildBr(self.builder, merge_bb); + } + + c.LLVMPositionBuilderAtEnd(self.builder, default_bb); + const neg_one = c.LLVMConstInt(i64_type, @bitCast(@as(i64, -1)), 0); + try phi_vals.append(self.allocator, neg_one); + try phi_bbs.append(self.allocator, default_bb); + _ = c.LLVMBuildBr(self.builder, merge_bb); + + c.LLVMPositionBuilderAtEnd(self.builder, merge_bb); + const vals_slice = try phi_vals.toOwnedSlice(self.allocator); + const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator); + const phi = c.LLVMBuildPhi(self.builder, i64_type, "fi_result"); + c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len)); + return phi; + } + fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef { if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr"); const target_ty = self.resolveType(call_node.args[0]); @@ -4505,8 +4844,8 @@ pub const CodeGen = struct { const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ fa.field, uname }); const variant_ty = info.variant_types[idx]; if (variant_ty == .void_type) return self.emitErrorFmt("cannot access payload of void variant '{s}'", .{fa.field}); - // GEP to field 1 (payload area), load as variant type - const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 1, "payload"); + // GEP to payload area, load as variant type + const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, info.payload_field_index, "payload"); return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(variant_ty), payload_gep, "union_payload"); } if (entry.ty.isVector()) { @@ -6123,12 +6462,9 @@ pub const CodeGen = struct { // Determine subject type for enum vs union dispatch var enum_name: ?[]const u8 = null; var union_name: ?[]const u8 = null; - if (match.subject.data == .identifier) { - if (self.named_values.get(match.subject.data.identifier.name)) |entry| { - if (entry.ty.isEnum()) enum_name = entry.ty.enum_type; - if (entry.ty.isUnion()) union_name = entry.ty.union_type; - } - } + const subject_ty = self.inferType(match.subject); + if (subject_ty.isEnum()) enum_name = subject_ty.enum_type; + if (subject_ty.isUnion()) union_name = subject_ty.union_type; // Get the switch value: for unions, load the tag from field 0; for enums, use the value directly const subject_val: c.LLVMValueRef = if (union_name != null) blk: { @@ -6224,6 +6560,32 @@ pub const CodeGen = struct { // Category/type arm with no matching types — BB is unreachable, skip body _ = c.LLVMBuildBr(self.builder, merge_bb); } else { + // Payload capture: bind variant payload as a local variable + if (arm.capture) |cap_name| { + if (union_name) |un| { + const uinfo = self.tagged_enum_types.get(un).?; + const pat = arm.pattern.?; + if (pat.data == .enum_literal) { + const vname = pat.data.enum_literal.name; + var vidx: ?usize = null; + for (uinfo.variant_names, 0..) |vn, vi| { + if (std.mem.eql(u8, vn, vname)) { vidx = vi; break; } + } + if (vidx) |vi| { + const variant_ty = uinfo.variant_types[vi]; + if (variant_ty != .void_type) { + const subject_entry = self.named_values.get(match.subject.data.identifier.name).?; + const payload_gep = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, subject_entry.ptr, uinfo.payload_field_index, "cap_payload"); + const payload_llvm_ty = self.typeToLLVM(variant_ty); + const payload_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_gep, "cap_load"); + const cap_alloca = c.LLVMBuildAlloca(self.builder, payload_llvm_ty, @ptrCast(cap_name.ptr)); + _ = c.LLVMBuildStore(self.builder, payload_val, cap_alloca); + try self.named_values.put(cap_name, .{ .ptr = cap_alloca, .ty = variant_ty }); + } + } + } + } + } // Set match arm context for runtime type dispatch const saved_match_tags = self.current_match_tags; self.current_match_tags = arm_tag_values.items[i]; @@ -6428,6 +6790,7 @@ pub const CodeGen = struct { if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node); if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node); if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node); + if (std.mem.eql(u8, base, "field_index")) return self.genFieldIndex(call_node); return self.emitErrorFmt("unknown builtin function '{s}'", .{name}); } @@ -6663,6 +7026,8 @@ pub const CodeGen = struct { if (std.mem.eql(u8, base_name, "is_flags")) return .boolean; // Built-in: field_value_int returns s64 if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64); + // Built-in: field_index returns s64 + if (std.mem.eql(u8, base_name, "field_index")) return Type.s(64); // Built-in: cast returns the target type (first arg) if (std.mem.eql(u8, base_name, "cast")) { if (call_node.args.len > 0) return self.resolveType(call_node.args[0]); diff --git a/src/lsp/server.zig b/src/lsp/server.zig index 1a7caf4..79d37af 100644 --- a/src/lsp/server.zig +++ b/src/lsp/server.zig @@ -493,12 +493,75 @@ pub const Server = struct { } } } + } else { + // Bare dot (no prefix) — check if we're inside a match expression + // and offer the subject's enum variants (e.g. case .quit) + try self.collectMatchEnumCompletions(&items, doc, cursor_offset); } const items_json = try lsp.completionListJson(self.allocator, items.items); try self.sendResponse(id_json, items_json); } + fn collectMatchEnumCompletions(self: *Server, items: *std.ArrayList(lsp.CompletionItem), doc: *const Document, cursor_offset: u32) !void { + const root = doc.root orelse return; + const sema = doc.sema orelse return; + + // Find enclosing match expression's subject + const subject = sx.sema.findEnclosingMatchSubject(root, cursor_offset) orelse return; + + // Resolve the subject to an enum type name + const enum_name: ?[]const u8 = switch (subject.data) { + .identifier => |id| blk: { + // Look up variable type, then check if it's an enum + for (sema.symbols) |sym| { + if (!std.mem.eql(u8, sym.name, id.name)) continue; + if (sym.kind != .variable and sym.kind != .param) continue; + const ty = sym.ty orelse break; + break :blk switch (ty) { + .enum_type => |n| n, + .union_type => |n| n, + else => null, + }; + } + break :blk null; + }, + .field_access => |fa| blk: { + // e.g. e.key — resolve the field's type + if (fa.object.data == .identifier) { + const var_name = fa.object.data.identifier.name; + // Find variable's struct type, then look up the field type + for (sema.symbols) |sym| { + if (!std.mem.eql(u8, sym.name, var_name)) continue; + const ty = sym.ty orelse break; + const struct_name = switch (ty) { + .struct_type => |n| n, + else => break, + }; + // Look up the struct's field type + if (sema.struct_types.get(struct_name)) |info| { + for (info.field_names, 0..) |fname, fi| { + if (std.mem.eql(u8, fname, fa.field) and fi < info.field_types.len) { + break :blk switch (info.field_types[fi]) { + .enum_type => |n| n, + .union_type => |n| n, + else => null, + }; + } + } + } + break; + } + } + break :blk null; + }, + else => null, + }; + + const name = enum_name orelse return; + try self.collectMemberCompletions(items, sema, root, name); + } + fn collectDeclCompletions(allocator: std.mem.Allocator, items: *std.ArrayList(lsp.CompletionItem), decls: []const *sx.ast.Node) !void { for (decls) |decl| { switch (decl.data) { @@ -1643,6 +1706,33 @@ pub const Server = struct { if (bt.data == .type_expr) { try buf.appendSlice(allocator, bt.data.type_expr.name); try buf.appendSlice(allocator, " "); + } else if (bt.data == .struct_decl) { + const sd = bt.data.struct_decl; + try buf.appendSlice(allocator, "struct { "); + for (sd.field_names, 0..) |fn_, fi| { + if (fi > 0) try buf.appendSlice(allocator, "; "); + try buf.appendSlice(allocator, fn_); + try buf.appendSlice(allocator, ": "); + if (fi < sd.field_types.len) { + if (sd.field_types[fi].data == .type_expr) { + try buf.appendSlice(allocator, sd.field_types[fi].data.type_expr.name); + } else if (sd.field_types[fi].data == .array_type_expr) { + const ate = sd.field_types[fi].data.array_type_expr; + try buf.append(allocator, '['); + if (ate.length.data == .int_literal) { + const val = ate.length.data.int_literal.value; + var num_buf: [20]u8 = undefined; + const num_str = std.fmt.bufPrint(&num_buf, "{d}", .{val}) catch "?"; + try buf.appendSlice(allocator, num_str); + } + try buf.append(allocator, ']'); + if (ate.element_type.data == .type_expr) { + try buf.appendSlice(allocator, ate.element_type.data.type_expr.name); + } + } + } + } + try buf.appendSlice(allocator, " } "); } } try buf.appendSlice(allocator, "{ "); diff --git a/src/parser.zig b/src/parser.zig index ec5e156..6074cca 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -408,15 +408,23 @@ pub const Parser = struct { try variant_names.append(self.allocator, self.tokenSlice(self.current)); self.advance(); if (self.current.tag == .colon_colon) { - // Explicit value: name :: expr; - if (!is_flags) { - return self.fail("explicit enum values require 'enum flags'"); - } + // Explicit value: name :: expr; or name :: expr: type; self.advance(); const val_expr = try self.parseExpr(); try variant_values.append(self.allocator, val_expr); - try variant_types.append(self.allocator, null); has_any_value = true; + // Check for payload type after value: name :: 0x300: KeyData + if (self.current.tag == .colon) { + if (is_flags) { + return self.fail("flags enum variants cannot have payloads"); + } + self.advance(); + const vtype = try self.parseTypeExpr(); + try variant_types.append(self.allocator, vtype); + has_any_type = true; + } else { + try variant_types.append(self.allocator, null); + } } else if (self.current.tag == .colon) { // Typed variant: name: type; if (is_flags) { @@ -553,7 +561,12 @@ pub const Parser = struct { // All names in the group share the same type and default for (group_names.items) |fname| { - try field_names.append(self.allocator, fname); + // `_` is an ignore identifier — auto-rename to unique internal name + const actual_name = if (std.mem.eql(u8, fname, "_")) + try std.fmt.allocPrint(self.allocator, "_{d}", .{field_names.items.len}) + else + fname; + try field_names.append(self.allocator, actual_name); try field_types.append(self.allocator, field_type); try field_defaults.append(self.allocator, default_val); } @@ -892,6 +905,24 @@ pub const Parser = struct { return try self.createNode(start, .{ .insert_expr = .{ .expr = inner } }); } + // Block-form if/while/for as statements — parse directly to prevent + // postfix chaining (e.g. `if cond { ... }.field` being misparsed) + if (self.current.tag == .kw_if) { + const expr = try self.parseIfExpr(); + try self.expectSemicolonAfter(expr); + return expr; + } + if (self.current.tag == .kw_while) { + const expr = try self.parsePrimary(); + try self.expectSemicolonAfter(expr); + return expr; + } + if (self.current.tag == .kw_for) { + const expr = try self.parsePrimary(); + try self.expectSemicolonAfter(expr); + return expr; + } + // Expression statement const expr = try self.parseExpr(); @@ -1427,11 +1458,28 @@ pub const Parser = struct { } else try self.parsePrimary(); // .variant try self.expect(.colon); + // Optional payload capture: (ident) + var capture: ?[]const u8 = null; + if (self.current.tag == .l_paren) { + self.advance(); + if (self.current.tag != .identifier) return self.fail("expected capture name"); + capture = self.tokenSlice(self.current); + self.advance(); + try self.expect(.r_paren); + } + if (self.current.tag == .kw_break) { self.advance(); try self.expect(.semicolon); const body = try self.createNode(arm_start, .{ .block = .{ .stmts = &.{} } }); - try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = true }); + try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = true, .capture = capture }); + } else if (self.current.tag == .fat_arrow) { + // Short form: (ident) => expr; + self.advance(); + const expr = try self.parseExpr(); + try self.expect(.semicolon); + const body = try self.createNode(arm_start, .{ .block = .{ .stmts = try self.allocator.dupe(*Node, &.{expr}) } }); + try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = false, .capture = capture }); } else { const stmts_start = self.current.loc.start; var stmts = std.ArrayList(*Node).empty; @@ -1439,7 +1487,7 @@ pub const Parser = struct { try stmts.append(self.allocator, try self.parseStmt()); } const body = try self.createNode(stmts_start, .{ .block = .{ .stmts = try stmts.toOwnedSlice(self.allocator) } }); - try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = false }); + try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = false, .capture = capture }); } } // Optional else arm (default) diff --git a/src/sema.zig b/src/sema.zig index b0c142b..ff0a297 100644 --- a/src/sema.zig +++ b/src/sema.zig @@ -148,7 +148,7 @@ pub const Analyzer = struct { }); }, .const_decl => |cd| { - const ty = resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value); + const ty = self.resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value); const kind = classifyConstDecl(cd); try self.addSymbol(cd.name, kind, ty, node.span); // Populate type_aliases registry @@ -175,7 +175,7 @@ pub const Analyzer = struct { } }, .var_decl => |vd| { - const ty = resolveTypeAnnotation(vd.type_annotation); + const ty = self.resolveTypeAnnotation(vd.type_annotation); try self.addSymbol(vd.name, .variable, ty, node.span); }, .enum_decl => |ed| { @@ -581,7 +581,7 @@ pub const Analyzer = struct { .const_decl => |cd| { // Analyze value first (so it can't reference itself) try self.analyzeNode(cd.value); - const ty = resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value); + const ty = self.resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value); const kind = classifyConstDecl(cd); try self.addSymbol(cd.name, kind, ty, node.span); }, @@ -589,7 +589,7 @@ pub const Analyzer = struct { if (vd.value) |val| { try self.analyzeNode(val); } - const ty = resolveTypeAnnotation(vd.type_annotation) orelse + const ty = self.resolveTypeAnnotation(vd.type_annotation) orelse if (vd.value) |val| self.inferExprType(val) else null; try self.addSymbol(vd.name, .variable, ty, node.span); }, @@ -637,7 +637,12 @@ pub const Analyzer = struct { .match_expr => |me| { try self.analyzeNode(me.subject); for (me.arms) |arm| { + try self.pushScope(); + if (arm.capture) |cap_name| { + try self.addSymbol(cap_name, .variable, null, arm.body.span); + } try self.analyzeNode(arm.body); + self.popScope(); } }, .while_expr => |we| { @@ -773,9 +778,19 @@ pub const Analyzer = struct { return null; } - fn resolveTypeAnnotation(type_node: ?*Node) ?Type { + fn resolveTypeAnnotation(self: *Analyzer, type_node: ?*Node) ?Type { if (type_node) |tn| { - return Type.fromTypeExpr(tn); + if (Type.fromTypeExpr(tn)) |t| return t; + // Check registered types (structs, enums, tagged enums) + if (tn.data == .type_expr) { + const name = tn.data.type_expr.name; + // Check type aliases first + const resolved = self.type_aliases.get(name) orelse name; + for (self.symbols.items) |sym| { + if (!std.mem.eql(u8, sym.name, resolved)) continue; + if (sym.ty) |ty| return ty; + } + } } return null; } @@ -986,6 +1001,71 @@ pub fn findNodeAtOffset(node: *Node, offset: u32) ?*Node { return node; } +/// Find the nearest match_expr ancestor that contains the given offset. +/// Returns the match subject node if found, null otherwise. +pub fn findEnclosingMatchSubject(node: *Node, offset: u32) ?*Node { + if (offset < node.span.start or offset >= node.span.end) return null; + + switch (node.data) { + .match_expr => |me| { + // First recurse into arm bodies — there might be a nested match + for (me.arms) |arm| { + if (findEnclosingMatchSubject(arm.body, offset)) |inner| return inner; + } + // If offset is inside this match_expr (but not in the subject itself), + // it's in an arm pattern, between arms, or in a partially-typed arm + if (me.subject.span.start <= offset and offset < me.subject.span.end) { + // Cursor is on the subject itself, not in an arm + } else { + return me.subject; + } + }, + .root => |r| { + for (r.decls) |decl| { + if (findEnclosingMatchSubject(decl, offset)) |found| return found; + } + }, + .fn_decl => |fd| { + if (findEnclosingMatchSubject(fd.body, offset)) |found| return found; + }, + .block => |blk| { + for (blk.stmts) |stmt| { + if (findEnclosingMatchSubject(stmt, offset)) |found| return found; + } + }, + .if_expr => |ie| { + if (findEnclosingMatchSubject(ie.then_branch, offset)) |found| return found; + if (ie.else_branch) |eb| { + if (findEnclosingMatchSubject(eb, offset)) |found| return found; + } + }, + .while_expr => |we| { + if (findEnclosingMatchSubject(we.body, offset)) |found| return found; + }, + .for_expr => |fe| { + if (findEnclosingMatchSubject(fe.body, offset)) |found| return found; + }, + .const_decl => |cd| { + if (findEnclosingMatchSubject(cd.value, offset)) |found| return found; + }, + .var_decl => |vd| { + if (vd.value) |val| { + if (findEnclosingMatchSubject(val, offset)) |found| return found; + } + }, + .lambda => |lam| { + if (findEnclosingMatchSubject(lam.body, offset)) |found| return found; + }, + .namespace_decl => |ns| { + for (ns.decls) |decl| { + if (findEnclosingMatchSubject(decl, offset)) |found| return found; + } + }, + else => {}, + } + return null; +} + test "sema: collect top-level declarations" { const parser_mod = @import("parser.zig");