diff --git a/examples/21-categories.sx b/examples/21-categories.sx index 205f7f5..2c207bb 100644 --- a/examples/21-categories.sx +++ b/examples/21-categories.sx @@ -11,12 +11,12 @@ Color :: struct { main :: () { p := Point.{10, 20}; c := Color.{255, 128, 0}; - pc := &p; + pc := @p; print("p: {}\n", p); print("c: {}\n", c); print("n: {}\n", 42); print("s: {}\n", "hello"); print("b: {}\n", true); print("&p: {}\n", pc); - print("&p: {}\n", &p); + print("&p: {}\n", @p); } diff --git a/examples/26-pointers.sx b/examples/26-pointers.sx index 3a380ff..483bf26 100644 --- a/examples/26-pointers.sx +++ b/examples/26-pointers.sx @@ -10,10 +10,10 @@ main :: () { v := Vec2.{ 1.0, 2.0 }; print("before: {}\n", v); - set_x(&v, 99.0); + set_x(@v, 99.0); print("after: {}\n", v); - ptr := &v; + ptr := @v; copy := ptr.*; print("copy: {}\n", copy); @@ -22,7 +22,7 @@ main :: () { // many-pointer indexing arr : [5]s32 = .[10, 20, 30, 40, 50]; - mp : [*]s32 = &arr[0]; + mp : [*]s32 = @arr[0]; print("mp[0] = {}\n", mp[0]); print("mp[2] = {}\n", mp[2]); } diff --git a/examples/28-sdl-graphics.sx b/examples/28-sdl-graphics.sx index e1cc2e8..ff4fc56 100644 --- a/examples/28-sdl-graphics.sx +++ b/examples/28-sdl-graphics.sx @@ -270,7 +270,7 @@ GLSL; event : [128]u8 = ---; while running { - while SDL_PollEvent(xx &event[0]) { + while SDL_PollEvent(xx @event[0]) { etype : u32 = xx event[0]; if etype == SDL_EVENT_QUIT { running = false; diff --git a/examples/31-flags.sx b/examples/31-flags.sx index 3e74f96..2d5ef9b 100644 --- a/examples/31-flags.sx +++ b/examples/31-flags.sx @@ -7,13 +7,16 @@ Perms :: enum flags { execute; } -// Explicit values (e.g. for C interop) -WindowFlags :: enum flags { +// Explicit values with u32 backing type (e.g. for C interop) +WindowFlags :: enum flags u32 { vsync :: 64; resizable :: 4; hidden :: 128; } +// Backing type on plain enums too +Color :: enum u8 { red; green; blue; } + check_perms :: (p: Perms) { print(" checking: {}\n", p); if p & .read { print(" - can read\n"); } @@ -53,6 +56,11 @@ main :: () { print("\nwindow: {}\n", w); print("raw value: {}\n", cast(s64) w); + // Backing type on plain enums + c :Color = .blue; + print("\ncolor: {}\n", c); + print("raw: {}\n", cast(s64) c); + // Bitwise ops work on plain integers too x := 0xFF & 0x0F; y := 1 | 2 | 4; diff --git a/specs.md b/specs.md index 53562cd..0d0ea98 100644 --- a/specs.md +++ b/specs.md @@ -291,10 +291,10 @@ word := msg[6..11]; // string → "world" | `*[N]T` | pointer to array of N T | yes | yes | | `*[]T` | pointer to slice | yes | yes | -**Address-of**: `&x` returns a pointer to the variable. +**Address-of**: `@x` returns a pointer to the variable. ```sx v := Vec2.{ 1.0, 2.0 }; -ptr := &v; // *Vec2 +ptr := @v; // *Vec2 ``` **Dereference**: `p.*` loads the value through the pointer. @@ -307,7 +307,7 @@ copy := ptr.*; // Vec2 set_x :: (p: *Vec2, val: f32) { p.x = val; // auto-deref: p.*.x = val } -set_x(&v, 99.0); +set_x(@v, 99.0); ``` **Null**: All pointer types are nullable. `null` is the null pointer literal. @@ -318,7 +318,7 @@ np : *Vec2 = null; **Many-pointer**: `[*]T` supports indexing for buffers of unknown size. ```sx arr : [5]s32 = .[10, 20, 30, 40, 50]; -mp : [*]s32 = &arr[0]; // *s32 → [*]s32 implicit +mp : [*]s32 = @arr[0]; // *s32 → [*]s32 implicit val := mp[2]; // 30 ``` @@ -555,6 +555,19 @@ Name :: enum { Defines a new enum type with the given variants. Trailing comma is allowed. +### Enum Backing Type + +An optional backing type can be specified after the `enum` keyword (Jai-style): + +```sx +Color :: enum u8 { red; green; blue; } +Status :: enum s16 { ok; error; timeout; } +``` + +Syntax: `Name :: enum [flags] [type] { ... }` + +The backing type must be an integer type (`u8`, `u16`, `u32`, `s8`, `s16`, `s32`, `s64`, etc.). When omitted, the default is `s64`. This is useful for C interop (matching C enum sizes) and memory efficiency. + ### Enum Flags ```sx @@ -565,6 +578,15 @@ Perms :: enum flags { } ``` +Flags can also specify a backing type: + +```sx +SDL_InitFlags :: enum flags u32 { + video :: 0x20; + audio :: 0x10; +} +``` + The `flags` modifier assigns auto power-of-2 values (1, 2, 4, 8, ...) instead of sequential indices (0, 1, 2, ...). Flags can be combined with `|` and tested with `&`: ```sx diff --git a/src/ast.zig b/src/ast.zig index 9a27524..a7e130a 100644 --- a/src/ast.zig +++ b/src/ast.zig @@ -231,6 +231,7 @@ pub const EnumDecl = struct { variant_types: []const ?*Node = &.{}, // null entries = no payload; empty = payload-less enum is_flags: bool = false, variant_values: []const ?*Node = &.{}, // explicit value per variant (null = auto), empty = all auto + backing_type: ?*Node = null, // optional backing type: enum u8 { ... } }; pub const UnionDecl = struct { diff --git a/src/codegen.zig b/src/codegen.zig index 698aada..be5b17a 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -116,6 +116,8 @@ pub const CodeGen = struct { flags_enum_types: std.StringHashMap(void), // Enum variant values: maps enum name → resolved i64 values per variant enum_variant_values: std.StringHashMap([]const i64), + // Enum backing types: maps enum name → LLVM type for the backing integer (default i64) + enum_backing_types: std.StringHashMap(c.LLVMTypeRef), // Built-in functions (printf, etc.) builtins: ?Builtins, // Current function being generated (for alloca insertion) @@ -300,6 +302,7 @@ pub const CodeGen = struct { .union_types = std.StringHashMap(UnionInfo).init(allocator), .flags_enum_types = std.StringHashMap(void).init(allocator), .enum_variant_values = std.StringHashMap([]const i64).init(allocator), + .enum_backing_types = std.StringHashMap(c.LLVMTypeRef).init(allocator), .builtins = null, .current_function = null, .scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty, @@ -333,6 +336,7 @@ pub const CodeGen = struct { self.tagged_enum_types.deinit(); self.union_types.deinit(); self.comptime_globals.deinit(); + self.enum_backing_types.deinit(); self.generic_templates.deinit(); self.generic_instances.deinit(); self.generic_struct_templates.deinit(); @@ -384,7 +388,7 @@ pub const CodeGen = struct { .void_type => c.LLVMVoidTypeInContext(self.context), .boolean => c.LLVMInt1TypeInContext(self.context), .string_type, .slice_type => self.getStringStructType(), // slices use same {ptr, i32} layout - .enum_type => c.LLVMInt64TypeInContext(self.context), + .enum_type => |name| self.getEnumLLVMType(name), .struct_type => |name| if (self.struct_types.get(name)) |info| info.llvm_type else unreachable, .union_type => |name| if (self.tagged_enum_types.get(name)) |info| info.llvm_type else if (self.union_types.get(name)) |info| info.llvm_type else unreachable, .array_type => |info| { @@ -401,6 +405,11 @@ pub const CodeGen = struct { }; } + fn getEnumLLVMType(self: *CodeGen, enum_name: []const u8) c.LLVMTypeRef { + if (self.enum_backing_types.get(enum_name)) |llvm_ty| return llvm_ty; + return c.LLVMInt64TypeInContext(self.context); + } + fn getAnyStructType(self: *CodeGen) c.LLVMTypeRef { if (self.any_struct_type) |t| return t; var field_types = [_]c.LLVMTypeRef{ @@ -558,9 +567,14 @@ pub const CodeGen = struct { _ = c.LLVMBuildStore(self.builder, val, alloca); break :blk c.LLVMBuildPtrToInt(self.builder, alloca, i64_ty, "any_struct"); }, - .enum_type => blk: { - // Enum is i32 tag — extend to i64 - break :blk c.LLVMBuildZExt(self.builder, val, i64_ty, "any_enum"); + .enum_type => |ename| blk: { + // Enum — extend to i64 for Any storage (no-op if already i64) + const enum_llvm_ty = self.getEnumLLVMType(ename); + const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty); + if (enum_bits < 64) + break :blk c.LLVMBuildZExt(self.builder, val, i64_ty, "any_enum") + else + break :blk val; }, .union_type => |uname| blk: { // Union — store to alloca, pass pointer as i64 @@ -720,6 +734,12 @@ pub const CodeGen = struct { try self.flags_enum_types.put(ed.name, {}); } + // Register backing type if specified + if (ed.backing_type) |bt_node| { + const bt = self.resolveType(bt_node); + try self.enum_backing_types.put(ed.name, self.typeToLLVM(bt)); + } + // Compute and store variant values const values = try self.allocator.alloc(i64, ed.variant_names.len); for (ed.variant_names, 0..) |_, i| { @@ -1634,6 +1654,10 @@ pub const CodeGen = struct { } else { const qualified = try std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns.name, ed.name }); try self.enum_types.put(qualified, ed.variant_names); + if (ed.backing_type) |bt_node| { + const bt = self.resolveType(bt_node); + try self.enum_backing_types.put(qualified, self.typeToLLVM(bt)); + } } }, .struct_decl => |sd| { @@ -3108,10 +3132,10 @@ pub const CodeGen = struct { const name_z = try self.allocator.dupeZ(u8, name); const union_ty = c.LLVMStructCreateNamed(self.context, name_z.ptr); - const i64_ty = c.LLVMInt64TypeInContext(self.context); + const tag_ty = self.getEnumLLVMType(name); const i8_ty = c.LLVMInt8TypeInContext(self.context); const payload_array_ty = c.LLVMArrayType2(i8_ty, max_payload_size); - var union_fields = [2]c.LLVMTypeRef{ i64_ty, payload_array_ty }; + var union_fields = [2]c.LLVMTypeRef{ tag_ty, payload_array_ty }; c.LLVMStructSetBody(union_ty, &union_fields, 2, 0); return .{ @@ -3145,6 +3169,10 @@ pub const CodeGen = struct { } else { try self.enum_types.put(synthetic_name, inline_ed.variant_names); _ = try self.getAnyTypeId(synthetic_name, .{ .enum_type = synthetic_name }); + if (inline_ed.backing_type) |bt_node| { + const bt = self.resolveType(bt_node); + try self.enum_backing_types.put(synthetic_name, self.typeToLLVM(bt)); + } } type_node.data = .{ .type_expr = .{ .name = synthetic_name } }; }, @@ -3198,6 +3226,12 @@ pub const CodeGen = struct { } } + // Register backing type before buildUnionFields (which uses getEnumLLVMType) + if (ud.backing_type) |bt_node| { + const bt = self.resolveType(bt_node); + try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt)); + } + const build = try self.buildUnionFields(ud.name, ud.variant_types); try self.tagged_enum_types.put(ud.name, .{ @@ -3284,11 +3318,11 @@ pub const CodeGen = struct { // Alloca union const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp"); - const i64_ty = c.LLVMInt64TypeInContext(self.context); + const tag_ty = self.getEnumLLVMType(resolved_name); // Store tag (field 0) const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag"); - _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(i64_ty, idx, 0), tag_gep); + _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, idx, 0), tag_gep); // Store payload (field 1) if not void if (el.payload) |payload_node| { @@ -3526,9 +3560,9 @@ pub const CodeGen = struct { // Alloca union, store tag const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_lit"); - const i32_ty = c.LLVMInt32TypeInContext(self.context); + const tag_llvm_ty = self.getEnumLLVMType(uname); const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag"); - _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(i32_ty, idx, 0), tag_gep); + _ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_llvm_ty, idx, 0), tag_gep); // Store struct payload if (variant_ty != .void_type) { @@ -3750,6 +3784,9 @@ pub const CodeGen = struct { } } if (target_ty.isEnum()) { + const enum_llvm_ty = self.getEnumLLVMType(target_ty.enum_type); + const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty); + if (enum_bits < 64) return c.LLVMBuildTrunc(self.builder, i64_val, enum_llvm_ty, "any_to_enum"); return i64_val; } if (target_ty.isUnion()) { @@ -3805,12 +3842,14 @@ pub const CodeGen = struct { if (src_ty.isUnion() and target_ty.isInt()) { const uname = src_ty.union_type; if (self.tagged_enum_types.get(uname)) |info| { + const tag_llvm_ty = self.getEnumLLVMType(uname); + const tag_bits = c.LLVMGetIntTypeWidth(tag_llvm_ty); const tmp = self.buildEntryBlockAlloca(info.llvm_type, "union_cast"); _ = c.LLVMBuildStore(self.builder, val, tmp); const tag_ptr = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, tmp, 0, "tag_ptr"); - const tag_val = c.LLVMBuildLoad2(self.builder, c.LLVMInt32TypeInContext(self.context), tag_ptr, "tag_val"); - if (target_ty.bitWidth() == 32) return tag_val; - if (target_ty.bitWidth() > 32) return c.LLVMBuildSExt(self.builder, tag_val, target_llvm, "tag_ext"); + const tag_val = c.LLVMBuildLoad2(self.builder, tag_llvm_ty, tag_ptr, "tag_val"); + if (target_ty.bitWidth() == tag_bits) return tag_val; + if (target_ty.bitWidth() > tag_bits) return c.LLVMBuildSExt(self.builder, tag_val, target_llvm, "tag_ext"); return c.LLVMBuildTrunc(self.builder, tag_val, target_llvm, "tag_trunc"); } } @@ -3844,6 +3883,31 @@ pub const CodeGen = struct { return c.LLVMBuildExtractValue(self.builder, val, 0, "slice_to_ptr"); } + // Enum → int: extend or truncate from backing type to target int + if (src_ty.isEnum() and target_ty.isInt()) { + const enum_bits = c.LLVMGetIntTypeWidth(self.getEnumLLVMType(src_ty.enum_type)); + const target_bits = target_ty.bitWidth(); + if (target_bits > enum_bits) { + return c.LLVMBuildZExt(self.builder, val, target_llvm, "enum_to_int"); + } else if (target_bits < enum_bits) { + return c.LLVMBuildTrunc(self.builder, val, target_llvm, "enum_to_int"); + } + return val; + } + + // Int → enum: extend or truncate from source int to backing type + if (src_ty.isInt() and target_ty.isEnum()) { + const enum_llvm_ty = self.getEnumLLVMType(target_ty.enum_type); + const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty); + const src_bits = src_ty.bitWidth(); + if (enum_bits > src_bits) { + return c.LLVMBuildZExt(self.builder, val, enum_llvm_ty, "int_to_enum"); + } else if (enum_bits < src_bits) { + return c.LLVMBuildTrunc(self.builder, val, enum_llvm_ty, "int_to_enum"); + } + return val; + } + // *[N]T → [*]T: pointer to array decays to many-pointer (both opaque ptrs, no-op) if (src_ty.isPointer() and target_ty.isManyPointer()) { return val; @@ -4113,7 +4177,7 @@ pub const CodeGen = struct { // Read tag (field 0) const tag_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 0, "fv_tag_ptr"); - const tag_val = c.LLVMBuildLoad2(self.builder, c.LLVMInt64TypeInContext(self.context), tag_ptr, "fv_tag"); + const tag_val = c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(val_ty.union_type), tag_ptr, "fv_tag"); const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 1, "fv_payload_ptr"); const n = uinfo.variant_names.len; @@ -4126,9 +4190,10 @@ pub const CodeGen = struct { var phi_vals = std.ArrayList(c.LLVMValueRef).empty; var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty; + const tag_llvm_ty = self.getEnumLLVMType(val_ty.union_type); for (uinfo.variant_types, 0..) |vty, vi| { const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fv_ucase"); - c.LLVMAddCase(sw, c.LLVMConstInt(c.LLVMInt64TypeInContext(self.context), @intCast(vi), 0), case_bb); + c.LLVMAddCase(sw, c.LLVMConstInt(tag_llvm_ty, @intCast(vi), 0), case_bb); c.LLVMPositionBuilderAtEnd(self.builder, case_bb); const any_val = if (vty == .void_type) blk: { @@ -5608,7 +5673,14 @@ pub const CodeGen = struct { const ptr = c.LLVMBuildIntToPtr(self.builder, any_i64, c.LLVMPointerTypeInContext(self.context, 0), "any_to_struct_ptr"); break :blk c.LLVMBuildLoad2(self.builder, info.llvm_type, ptr, "any_to_struct"); }, - .enum_type => any_i64, + .enum_type => |ename| blk: { + const enum_llvm_ty = self.getEnumLLVMType(ename); + const enum_bits = c.LLVMGetIntTypeWidth(enum_llvm_ty); + if (enum_bits < 64) + break :blk c.LLVMBuildTrunc(self.builder, any_i64, enum_llvm_ty, "any_to_enum") + else + break :blk any_i64; + }, .union_type => |uname| blk: { const info = self.tagged_enum_types.get(uname) orelse return self.emitErrorFmt("unknown enum '{s}'", .{uname}); const ptr = c.LLVMBuildIntToPtr(self.builder, any_i64, c.LLVMPointerTypeInContext(self.context, 0), "any_to_union_ptr"); @@ -6018,16 +6090,16 @@ pub const CodeGen = struct { } fn genEnumLiteral(self: *CodeGen, variant_name: []const u8, enum_type_name: []const u8) c.LLVMValueRef { - const i64_type = c.LLVMInt64TypeInContext(self.context); - const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(i64_type, 0, 0); + const enum_ty = self.getEnumLLVMType(enum_type_name); + const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(enum_ty, 0, 0); const values = self.enum_variant_values.get(enum_type_name); for (variants, 0..) |v, i| { if (std.mem.eql(u8, v, variant_name)) { const val: u64 = if (values) |vals| @bitCast(vals[i]) else @intCast(i); - return c.LLVMConstInt(i64_type, val, 0); + return c.LLVMConstInt(enum_ty, val, 0); } } - return c.LLVMConstInt(i64_type, 0, 0); + return c.LLVMConstInt(enum_ty, 0, 0); } fn lookupVariantValue(self: *CodeGen, enum_name: ?[]const u8, variants: ?[]const []const u8, name: []const u8) u64 { @@ -6064,7 +6136,7 @@ pub const CodeGen = struct { const entry = self.named_values.get(match.subject.data.identifier.name).?; const info = self.tagged_enum_types.get(union_name.?).?; const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 0, "tag"); - break :blk c.LLVMBuildLoad2(self.builder, c.LLVMInt64TypeInContext(self.context), tag_gep, "tag_val"); + break :blk c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(union_name.?), tag_gep, "tag_val"); } else try self.genExpr(match.subject); const variants: ?[]const []const u8 = if (union_name) |un| @@ -6076,6 +6148,8 @@ pub const CodeGen = struct { const function = self.current_function; const i64_type = c.LLVMInt64TypeInContext(self.context); + // Enum/union case constants use the backing type; Any dispatch uses i64 + const case_int_type = if (enum_name) |en| self.getEnumLLVMType(en) else if (union_name) |un| self.getEnumLLVMType(un) else i64_type; const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "match_end"); // Create case basic blocks @@ -6103,7 +6177,7 @@ pub const CodeGen = struct { const pat = arm.pattern orelse continue; // skip else arm if (pat.data == .enum_literal) { const idx = self.lookupVariantValue(enum_name orelse union_name, variants, pat.data.enum_literal.name); - const case_val = c.LLVMConstInt(i64_type, idx, 0); + const case_val = c.LLVMConstInt(case_int_type, idx, 0); c.LLVMAddCase(sw, case_val, case_bbs.items[i]); } else if (pat.data == .type_expr) { // Type-match: resolve type name to Any tag value(s) diff --git a/src/comptime.zig b/src/comptime.zig index 55503e2..ead96be 100644 --- a/src/comptime.zig +++ b/src/comptime.zig @@ -104,7 +104,7 @@ pub const Value = union(enum) { }, .pointer_val => |pv| { const inner = try pv.target[0].format(allocator); - return std.fmt.allocPrint(allocator, "&{s}", .{inner}); + return std.fmt.allocPrint(allocator, "@{s}", .{inner}); }, .null_val => allocator.dupe(u8, "null"), }; @@ -1947,7 +1947,7 @@ test "VM: address-of and deref" { defer arena.deinit(); const alloc = arena.allocator(); - // x := 42; ptr := &x; ptr.* + // x := 42; ptr := @x; ptr.* const code = [_]Instruction{ .{ .push_int = 42 }, .{ .set_local = 0 }, // x = 42 @@ -1973,7 +1973,7 @@ test "VM: deref_set modifies through pointer" { defer arena.deinit(); const alloc = arena.allocator(); - // x := 10; ptr := &x; ptr.* = 99; x + // x := 10; ptr := @x; ptr.* = 99; x const code = [_]Instruction{ .{ .push_int = 10 }, .{ .set_local = 0 }, // x = 10 @@ -2021,7 +2021,7 @@ test "VM: pointer to struct field access" { defer arena.deinit(); const alloc = arena.allocator(); - // Build: struct{x: 10, y: 20}, &struct, get_field(1) — auto-deref + // Build: struct{x: 10, y: 20}, @struct, get_field(1) — auto-deref const code = [_]Instruction{ .{ .push_int = 10 }, .{ .push_int = 20 }, diff --git a/src/lexer.zig b/src/lexer.zig index bc4d478..8e16b5b 100644 --- a/src/lexer.zig +++ b/src/lexer.zig @@ -177,6 +177,7 @@ pub const Lexer = struct { return self.makeToken(.percent, start, self.index); }, '&' => return self.makeToken(.ampersand, start, self.index), + '@' => return self.makeToken(.at, start, self.index), '|' => return self.makeToken(.pipe, start, self.index), '!' => { if (self.peek() == '=') { diff --git a/src/lsp/server.zig b/src/lsp/server.zig index c9b7fed..1a7caf4 100644 --- a/src/lsp/server.zig +++ b/src/lsp/server.zig @@ -786,6 +786,7 @@ pub const Server = struct { .percent, .percent_equal, .ampersand, + .at, .pipe, .arrow, .fat_arrow, @@ -1634,10 +1635,17 @@ pub const Server = struct { .enum_decl => |ed| { try buf.appendSlice(allocator, ed.name); if (ed.is_flags) { - try buf.appendSlice(allocator, " :: enum flags { "); + try buf.appendSlice(allocator, " :: enum flags "); } else { - try buf.appendSlice(allocator, " :: enum { "); + try buf.appendSlice(allocator, " :: enum "); } + if (ed.backing_type) |bt| { + if (bt.data == .type_expr) { + try buf.appendSlice(allocator, bt.data.type_expr.name); + try buf.appendSlice(allocator, " "); + } + } + try buf.appendSlice(allocator, "{ "); for (ed.variant_names, 0..) |v, i| { if (i > 0) try buf.appendSlice(allocator, ", "); try buf.append(allocator, '.'); diff --git a/src/parser.zig b/src/parser.zig index 0efaa26..ec5e156 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -389,6 +389,12 @@ pub const Parser = struct { self.advance(); } + // Check for optional backing type: enum u8 { ... } or enum flags u32 { ... } + var backing_type: ?*Node = null; + if (self.current.tag != .l_brace) { + backing_type = try self.parseTypeExpr(); + } + try self.expect(.l_brace); var variant_names = std.ArrayList([]const u8).empty; var variant_types = std.ArrayList(?*Node).empty; @@ -438,6 +444,7 @@ pub const Parser = struct { .variant_types = if (has_any_type) try variant_types.toOwnedSlice(self.allocator) else &.{}, .is_flags = is_flags, .variant_values = if (has_any_value) try variant_values.toOwnedSlice(self.allocator) else &.{}, + .backing_type = backing_type, } }); } @@ -974,7 +981,7 @@ pub const Parser = struct { const operand = try self.parseUnary(); return try self.createNode(start, .{ .unary_op = .{ .op = .xx, .operand = operand } }); } - if (self.current.tag == .ampersand) { + if (self.current.tag == .at) { const start = self.current.loc.start; self.advance(); const operand = try self.parseUnary(); diff --git a/src/token.zig b/src/token.zig index d622573..bc74814 100644 --- a/src/token.zig +++ b/src/token.zig @@ -60,6 +60,7 @@ pub const Tag = enum { percent, // % percent_equal, // %= ampersand, // & + at, // @ pipe, // | // Delimiters @@ -116,6 +117,7 @@ pub const Tag = enum { .percent => "%", .percent_equal => "%=", .ampersand => "&", + .at => "@", .pipe => "|", .kw_null => "null", .l_paren => "(",