diff --git a/examples/31-flags.sx b/examples/31-flags.sx new file mode 100644 index 0000000..3e74f96 --- /dev/null +++ b/examples/31-flags.sx @@ -0,0 +1,61 @@ +#import "modules/std.sx"; + +// Auto power-of-2 values: read=1, write=2, execute=4 +Perms :: enum flags { + read; + write; + execute; +} + +// Explicit values (e.g. for C interop) +WindowFlags :: enum flags { + vsync :: 64; + resizable :: 4; + hidden :: 128; +} + +check_perms :: (p: Perms) { + print(" checking: {}\n", p); + if p & .read { print(" - can read\n"); } + if p & .write { print(" - can write\n"); } + if p & .execute { print(" - can execute\n"); } +} + +main :: () { + // Combine flags with | + p :Perms = .read | .write; + print("perms: {}\n", p); + + // Test individual flags with & + check_perms(p); + + // All flags + all :Perms = .read | .write | .execute; + print("\nall: {}\n", all); + check_perms(all); + + // Single flag + r :Perms = .read; + print("\nread only: {}\n", r); + check_perms(r); + + // Pass flags to functions, match on them + print("\nmatch on flags:\n"); + f :Perms = .execute; + if f == { + case .read: print(" read\n"); + case .write: print(" write\n"); + case .execute: print(" execute\n"); + } + + // Explicit values + w :WindowFlags = .vsync | .resizable; + print("\nwindow: {}\n", w); + print("raw value: {}\n", cast(s64) w); + + // Bitwise ops work on plain integers too + x := 0xFF & 0x0F; + y := 1 | 2 | 4; + print("\n0xFF & 0x0F = {}\n", x); + print("1 | 2 | 4 = {}\n", y); +} diff --git a/examples/modules/std.sx b/examples/modules/std.sx index ba53b3a..53f7491 100644 --- a/examples/modules/std.sx +++ b/examples/modules/std.sx @@ -13,6 +13,8 @@ type_name :: ($T: Type) -> string #builtin; field_count :: ($T: Type) -> s64 #builtin; field_name :: ($T: Type, idx: s64) -> string #builtin; field_value :: (s: $T, idx: s64) -> Any #builtin; +is_flags :: ($T: Type) -> bool #builtin; +field_value_int :: ($T: Type, idx: s64) -> s64 #builtin; string :: []u8 #builtin; int_to_string :: (n: s64) -> string { @@ -201,7 +203,24 @@ pointer_to_string :: (p: $T) -> string { } } +flags_to_string :: (val: $T) -> string { + v := cast(s64) val; + result := ""; + i := 0; + while i < field_count(T) { + fv := field_value_int(T, i); + if v & fv { + if result.len > 0 { result = concat(result, " | "); } + result = concat(result, concat(".", field_name(T, i))); + } + i += 1; + } + if result.len == 0 { result = "0"; } + result; +} + enum_to_string :: (u: $T) -> string { + if is_flags(T) { return flags_to_string(u); } tag := cast(s64) u; result := concat(".", field_name(T, tag)); payload := field_value(u, tag); diff --git a/specs.md b/specs.md index 2f86b07..53562cd 100644 --- a/specs.md +++ b/specs.md @@ -64,6 +64,8 @@ GLSL; | `>` | greater than | | `<=` | less or equal | | `>=` | greater or equal | +| `&` | bitwise AND | +| `\|` | bitwise OR | | `and` | logical AND (short-circuit) | | `or` | logical OR (short-circuit) | | `+=` | add-assign | @@ -553,6 +555,47 @@ Name :: enum { Defines a new enum type with the given variants. Trailing comma is allowed. +### Enum Flags + +```sx +Perms :: enum flags { + read; // 1 + write; // 2 + execute; // 4 +} +``` + +The `flags` modifier assigns auto power-of-2 values (1, 2, 4, 8, ...) instead of sequential indices (0, 1, 2, ...). Flags can be combined with `|` and tested with `&`: + +```sx +p :Perms = .read | .write; +if p & .execute { ... } +print("{}\n", p); // .read | .write +``` + +Explicit values use `::` syntax (Jai-style): + +```sx +WindowFlags :: enum flags { + vsync :: 64; + resizable :: 4; + hidden :: 128; +} +``` + +Restrictions: +- Flags enum variants cannot have payloads +- `flags` is a contextual identifier, not a keyword + +### Bitwise Operators + +`&` (bitwise AND) and `|` (bitwise OR) work on all integer types, not just flags. They sit at precedence level 3, between comparisons and logical operators. + +```sx +x := 0xFF & 0x0F; // 15 +y := 1 | 2 | 4; // 7 +``` + --- ## 4. Expressions @@ -563,9 +606,10 @@ Everything in `sx` is expression-oriented where possible. | Prec | Operators | Notes | |------|-----------|-------| -| 6 (highest) | `*`, `/` | multiplication, division | +| 6 (highest) | `*`, `/`, `%` | multiplication, division, modulo | | 5 | `+`, `-` | addition, subtraction | | 4 | `<`, `<=`, `>`, `>=`, `==`, `!=` | comparisons (chainable) | +| 3 | `&`, `\|` | bitwise AND, bitwise OR | | 2 | `and` | logical AND (short-circuit) | | 1 (lowest) | `or` | logical OR (short-circuit) | diff --git a/src/ast.zig b/src/ast.zig index e045423..9a27524 100644 --- a/src/ast.zig +++ b/src/ast.zig @@ -148,6 +148,8 @@ pub const BinaryOp = struct { gte, and_op, or_op, + bit_and, + bit_or, }; }; @@ -227,6 +229,8 @@ pub const EnumDecl = struct { name: []const u8, variant_names: []const []const u8, variant_types: []const ?*Node = &.{}, // null entries = no payload; empty = payload-less enum + is_flags: bool = false, + variant_values: []const ?*Node = &.{}, // explicit value per variant (null = auto), empty = all auto }; pub const UnionDecl = struct { diff --git a/src/codegen.zig b/src/codegen.zig index 6742687..698aada 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -112,6 +112,10 @@ pub const CodeGen = struct { tagged_enum_types: std.StringHashMap(TaggedEnumInfo), // Union registry: maps name to field info + LLVM type (untagged, C-style) union_types: std.StringHashMap(UnionInfo), + // Flags enum registry: tracks which enum names are flags + flags_enum_types: std.StringHashMap(void), + // Enum variant values: maps enum name → resolved i64 values per variant + enum_variant_values: std.StringHashMap([]const i64), // Built-in functions (printf, etc.) builtins: ?Builtins, // Current function being generated (for alloca insertion) @@ -294,6 +298,8 @@ pub const CodeGen = struct { .struct_types = std.StringHashMap(StructInfo).init(allocator), .tagged_enum_types = std.StringHashMap(TaggedEnumInfo).init(allocator), .union_types = std.StringHashMap(UnionInfo).init(allocator), + .flags_enum_types = std.StringHashMap(void).init(allocator), + .enum_variant_values = std.StringHashMap([]const i64).init(allocator), .builtins = null, .current_function = null, .scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty, @@ -709,6 +715,30 @@ pub const CodeGen = struct { // Payload-less enum try self.enum_types.put(ed.name, ed.variant_names); _ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name }); + + if (ed.is_flags) { + try self.flags_enum_types.put(ed.name, {}); + } + + // Compute and store variant values + const values = try self.allocator.alloc(i64, ed.variant_names.len); + for (ed.variant_names, 0..) |_, i| { + if (ed.variant_values.len > i and ed.variant_values[i] != null) { + // Explicit value: evaluate comptime int literal + const val_node = ed.variant_values[i].?; + values[i] = switch (val_node.data) { + .int_literal => |il| il.value, + else => @as(i64, @intCast(i)), // fallback + }; + } else if (ed.is_flags) { + // Auto power-of-2: 1, 2, 4, 8, ... + values[i] = @as(i64, 1) << @intCast(i); + } else { + // Regular enum: sequential 0, 1, 2, ... + values[i] = @intCast(i); + } + } + try self.enum_variant_values.put(ed.name, values); } }, .struct_decl => |sd| try self.registerStructType(sd), @@ -3446,6 +3476,23 @@ pub const CodeGen = struct { return c.LLVMBuildGlobalStringPtr(self.builder, str_z.ptr, "str"); } + // Enum literal assigned to enum type: resolve variant value + if (node.data == .enum_literal and target_ty.isEnum()) { + return self.genEnumLiteral(node.data.enum_literal.name, target_ty.enum_type); + } + + // Bitwise op on enum type: recursively generate both sides with enum context + if (node.data == .binary_op and (node.data.binary_op.op == .bit_or or node.data.binary_op.op == .bit_and) and target_ty.isEnum()) { + const binop = node.data.binary_op; + const lhs = try self.genExprAsType(binop.lhs, target_ty); + const rhs = try self.genExprAsType(binop.rhs, target_ty); + const b = self.builder; + return if (binop.op == .bit_or) + c.LLVMBuildOr(b, lhs, rhs, "bortmp") + else + c.LLVMBuildAnd(b, lhs, rhs, "bandtmp"); + } + // Enum/union literal assigned to union type: construct tagged enum if (node.data == .enum_literal and target_ty.isUnion()) { const el = node.data.enum_literal; @@ -4203,6 +4250,62 @@ pub const CodeGen = struct { return phi; } + fn genIsFlags(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef { + if (call_node.args.len != 1) return self.emitError("is_flags expects exactly 1 argument"); + const ty = self.resolveType(call_node.args[0]); + const i1_type = c.LLVMInt1TypeInContext(self.context); + if (ty.isEnum()) { + const is_flags = self.flags_enum_types.contains(ty.enum_type); + return c.LLVMConstInt(i1_type, @intFromBool(is_flags), 0); + } + return c.LLVMConstInt(i1_type, 0, 0); + } + + fn genFieldValueInt(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef { + if (call_node.args.len != 2) return self.emitError("field_value_int expects 2 arguments: field_value_int(T, idx)"); + const ty = self.resolveType(call_node.args[0]); + const i64_type = c.LLVMInt64TypeInContext(self.context); + // For non-enum types (e.g. tagged enums compiled via dead code), return the index as value + if (!ty.isEnum()) { + return try self.genExpr(call_node.args[1]); + } + const enum_name = ty.enum_type; + const values = self.enum_variant_values.get(enum_name); + const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]); + const n = variants.len; + + const idx = try self.genExpr(call_node.args[1]); + const function = self.current_function; + const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_merge"); + const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_default"); + const sw = c.LLVMBuildSwitch(self.builder, idx, default_bb, @intCast(n)); + + var phi_vals = std.ArrayList(c.LLVMValueRef).empty; + var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty; + + for (0..n) |i| { + const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_case"); + c.LLVMAddCase(sw, c.LLVMConstInt(i64_type, i, 0), case_bb); + c.LLVMPositionBuilderAtEnd(self.builder, case_bb); + const val: u64 = if (values) |vals| @bitCast(vals[i]) else i; + try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, val, 0)); + try phi_bbs.append(self.allocator, case_bb); + _ = c.LLVMBuildBr(self.builder, merge_bb); + } + + c.LLVMPositionBuilderAtEnd(self.builder, default_bb); + try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, 0, 0)); + try phi_bbs.append(self.allocator, default_bb); + _ = c.LLVMBuildBr(self.builder, merge_bb); + + c.LLVMPositionBuilderAtEnd(self.builder, merge_bb); + const vals_slice = try phi_vals.toOwnedSlice(self.allocator); + const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator); + const phi = c.LLVMBuildPhi(self.builder, i64_type, "fvi_result"); + c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len)); + return phi; + } + fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef { if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr"); const target_ty = self.resolveType(call_node.args[0]); @@ -4591,6 +4694,8 @@ pub const CodeGen = struct { .lte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOLE, lhs, rhs, "letmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntULE, lhs, rhs, "letmp") else c.LLVMBuildICmp(b, c.LLVMIntSLE, lhs, rhs, "letmp"), .gt => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGT, lhs, rhs, "gttmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGT, lhs, rhs, "gttmp") else c.LLVMBuildICmp(b, c.LLVMIntSGT, lhs, rhs, "gttmp"), .gte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGE, lhs, rhs, "getmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGE, lhs, rhs, "getmp") else c.LLVMBuildICmp(b, c.LLVMIntSGE, lhs, rhs, "getmp"), + .bit_and => c.LLVMBuildAnd(b, lhs, rhs, "bandtmp"), + .bit_or => c.LLVMBuildOr(b, lhs, rhs, "bortmp"), .and_op, .or_op => unreachable, }; } @@ -5198,6 +5303,7 @@ pub const CodeGen = struct { try self.instantiateGeneric(fd, bindings, mangled); // Generate arguments with type conversion to match parameter types + const saved_call_bindings = self.type_param_bindings; self.type_param_bindings = bindings; var arg_vals = std.ArrayList(c.LLVMValueRef).empty; for (call_node.args, 0..) |arg, i| { @@ -5208,7 +5314,7 @@ pub const CodeGen = struct { try arg_vals.append(self.allocator, try self.genExpr(arg)); } } - self.type_param_bindings = null; + self.type_param_bindings = saved_call_bindings; const args_slice = try arg_vals.toOwnedSlice(self.allocator); const fn_type = c.LLVMGlobalGetValueType(callee_fn); @@ -5584,9 +5690,10 @@ pub const CodeGen = struct { self.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty; self.defer_stack = std.ArrayList(std.ArrayList(*Node)).empty; - // Set type param bindings + // Set type param bindings (save/restore to support nested generic instantiation) + const saved_bindings = self.type_param_bindings; self.type_param_bindings = bindings; - defer self.type_param_bindings = null; + defer self.type_param_bindings = saved_bindings; // Build the specialized function type const fn_type = try self.buildFnType(fd.params, fd.return_type, mangled); @@ -5913,18 +6020,28 @@ pub const CodeGen = struct { fn genEnumLiteral(self: *CodeGen, variant_name: []const u8, enum_type_name: []const u8) c.LLVMValueRef { const i64_type = c.LLVMInt64TypeInContext(self.context); const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(i64_type, 0, 0); + const values = self.enum_variant_values.get(enum_type_name); for (variants, 0..) |v, i| { if (std.mem.eql(u8, v, variant_name)) { - return c.LLVMConstInt(i64_type, @intCast(i), 0); + const val: u64 = if (values) |vals| @bitCast(vals[i]) else @intCast(i); + return c.LLVMConstInt(i64_type, val, 0); } } return c.LLVMConstInt(i64_type, 0, 0); } - fn lookupVariantIndex(variants: ?[]const []const u8, name: []const u8) u64 { + fn lookupVariantValue(self: *CodeGen, enum_name: ?[]const u8, variants: ?[]const []const u8, name: []const u8) u64 { if (variants) |vs| { for (vs, 0..) |v, i| { - if (std.mem.eql(u8, v, name)) return i; + if (std.mem.eql(u8, v, name)) { + // Use resolved values if available (flags enums, explicit values) + if (enum_name) |en| { + if (self.enum_variant_values.get(en)) |vals| { + return @bitCast(vals[i]); + } + } + return i; + } } } return 0; @@ -5985,7 +6102,7 @@ pub const CodeGen = struct { for (match.arms, 0..) |arm, i| { const pat = arm.pattern orelse continue; // skip else arm if (pat.data == .enum_literal) { - const idx = lookupVariantIndex(variants, pat.data.enum_literal.name); + const idx = self.lookupVariantValue(enum_name orelse union_name, variants, pat.data.enum_literal.name); const case_val = c.LLVMConstInt(i64_type, idx, 0); c.LLVMAddCase(sw, case_val, case_bbs.items[i]); } else if (pat.data == .type_expr) { @@ -6235,6 +6352,8 @@ pub const CodeGen = struct { if (std.mem.eql(u8, base, "field_count")) return self.genFieldCount(call_node); if (std.mem.eql(u8, base, "field_name")) return self.genFieldName(call_node); if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node); + if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node); + if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node); return self.emitErrorFmt("unknown builtin function '{s}'", .{name}); } @@ -6466,6 +6585,10 @@ pub const CodeGen = struct { if (std.mem.eql(u8, base_name, "field_name")) return .string_type; // Built-in: field_value returns Any if (std.mem.eql(u8, base_name, "field_value")) return .{ .any_type = {} }; + // Built-in: is_flags returns bool + if (std.mem.eql(u8, base_name, "is_flags")) return .boolean; + // Built-in: field_value_int returns s64 + if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64); // Built-in: cast returns the target type (first arg) if (std.mem.eql(u8, base_name, "cast")) { if (call_node.args.len > 0) return self.resolveType(call_node.args[0]); diff --git a/src/comptime.zig b/src/comptime.zig index e41b0fb..55503e2 100644 --- a/src/comptime.zig +++ b/src/comptime.zig @@ -147,6 +147,10 @@ pub const Instruction = union(enum) { gt, gte, + // Bitwise + bit_and, + bit_or, + // Logic not, @@ -451,6 +455,8 @@ pub const Compiler = struct { .lte => .lte, .gt => .gt, .gte => .gte, + .bit_and => .bit_and, + .bit_or => .bit_or, .and_op, .or_op => unreachable, }); } @@ -1026,6 +1032,20 @@ pub const VM = struct { const a = try self.pop(); try self.push(try self.arith(a, b, .mod_op)); }, + .bit_and => { + const b = try self.pop(); + const a = try self.pop(); + if (a == .int_val and b == .int_val) { + try self.push(.{ .int_val = a.int_val & b.int_val }); + } else return error.TypeError; + }, + .bit_or => { + const b = try self.pop(); + const a = try self.pop(); + if (a == .int_val and b == .int_val) { + try self.push(.{ .int_val = a.int_val | b.int_val }); + } else return error.TypeError; + }, .negate => { const v = try self.pop(); try self.push(switch (v) { diff --git a/src/lexer.zig b/src/lexer.zig index 2a6ea07..bc4d478 100644 --- a/src/lexer.zig +++ b/src/lexer.zig @@ -177,6 +177,7 @@ pub const Lexer = struct { return self.makeToken(.percent, start, self.index); }, '&' => return self.makeToken(.ampersand, start, self.index), + '|' => return self.makeToken(.pipe, start, self.index), '!' => { if (self.peek() == '=') { self.index += 1; diff --git a/src/lsp/server.zig b/src/lsp/server.zig index c652618..c9b7fed 100644 --- a/src/lsp/server.zig +++ b/src/lsp/server.zig @@ -786,6 +786,7 @@ pub const Server = struct { .percent, .percent_equal, .ampersand, + .pipe, .arrow, .fat_arrow, .colon_colon, @@ -1632,7 +1633,11 @@ pub const Server = struct { }, .enum_decl => |ed| { try buf.appendSlice(allocator, ed.name); - try buf.appendSlice(allocator, " :: enum { "); + if (ed.is_flags) { + try buf.appendSlice(allocator, " :: enum flags { "); + } else { + try buf.appendSlice(allocator, " :: enum { "); + } for (ed.variant_names, 0..) |v, i| { if (i > 0) try buf.appendSlice(allocator, ", "); try buf.append(allocator, '.'); diff --git a/src/parser.zig b/src/parser.zig index 25c0a0d..0efaa26 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -381,25 +381,50 @@ pub const Parser = struct { fn parseEnumDecl(self: *Parser, name: []const u8, start_pos: u32) anyerror!*Node { self.advance(); // skip 'enum' + + // Check for 'flags' modifier: enum flags { ... } + var is_flags = false; + if (self.current.tag == .identifier and std.mem.eql(u8, self.tokenSlice(self.current), "flags")) { + is_flags = true; + self.advance(); + } + try self.expect(.l_brace); var variant_names = std.ArrayList([]const u8).empty; var variant_types = std.ArrayList(?*Node).empty; + var variant_values = std.ArrayList(?*Node).empty; var has_any_type = false; + var has_any_value = false; while (self.current.tag != .r_brace and self.current.tag != .eof) { if (self.current.tag != .identifier) { return self.fail("expected variant name"); } try variant_names.append(self.allocator, self.tokenSlice(self.current)); self.advance(); - if (self.current.tag == .colon) { + if (self.current.tag == .colon_colon) { + // Explicit value: name :: expr; + if (!is_flags) { + return self.fail("explicit enum values require 'enum flags'"); + } + self.advance(); + const val_expr = try self.parseExpr(); + try variant_values.append(self.allocator, val_expr); + try variant_types.append(self.allocator, null); + has_any_value = true; + } else if (self.current.tag == .colon) { // Typed variant: name: type; + if (is_flags) { + return self.fail("flags enum variants cannot have payloads"); + } self.advance(); const vtype = try self.parseTypeExpr(); try variant_types.append(self.allocator, vtype); + try variant_values.append(self.allocator, null); has_any_type = true; } else { // Void variant: name; try variant_types.append(self.allocator, null); + try variant_values.append(self.allocator, null); } if (self.current.tag == .semicolon) { self.advance(); @@ -411,6 +436,8 @@ pub const Parser = struct { .name = name, .variant_names = try variant_names.toOwnedSlice(self.allocator), .variant_types = if (has_any_type) try variant_types.toOwnedSlice(self.allocator) else &.{}, + .is_flags = is_flags, + .variant_values = if (has_any_value) try variant_values.toOwnedSlice(self.allocator) else &.{}, } }); } @@ -1539,6 +1566,8 @@ pub const Parser = struct { return switch (self.current.tag) { .kw_or => 1, .kw_and => 2, + .pipe => 3, + .ampersand => 3, .equal_equal, .bang_equal, .less, .less_equal, .greater, .greater_equal => 4, .plus, .minus => 5, .star, .slash, .percent => 6, @@ -1550,6 +1579,8 @@ pub const Parser = struct { return switch (self.current.tag) { .kw_and => .and_op, .kw_or => .or_op, + .pipe => .bit_or, + .ampersand => .bit_and, .plus => .add, .minus => .sub, .star => .mul, diff --git a/src/token.zig b/src/token.zig index 97b66be..d622573 100644 --- a/src/token.zig +++ b/src/token.zig @@ -60,6 +60,7 @@ pub const Tag = enum { percent, // % percent_equal, // %= ampersand, // & + pipe, // | // Delimiters l_paren, // ( @@ -115,6 +116,7 @@ pub const Tag = enum { .percent => "%", .percent_equal => "%=", .ampersand => "&", + .pipe => "|", .kw_null => "null", .l_paren => "(", .r_paren => ")",