flags
This commit is contained in:
@@ -148,6 +148,8 @@ pub const BinaryOp = struct {
|
||||
gte,
|
||||
and_op,
|
||||
or_op,
|
||||
bit_and,
|
||||
bit_or,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -227,6 +229,8 @@ pub const EnumDecl = struct {
|
||||
name: []const u8,
|
||||
variant_names: []const []const u8,
|
||||
variant_types: []const ?*Node = &.{}, // null entries = no payload; empty = payload-less enum
|
||||
is_flags: bool = false,
|
||||
variant_values: []const ?*Node = &.{}, // explicit value per variant (null = auto), empty = all auto
|
||||
};
|
||||
|
||||
pub const UnionDecl = struct {
|
||||
|
||||
137
src/codegen.zig
137
src/codegen.zig
@@ -112,6 +112,10 @@ pub const CodeGen = struct {
|
||||
tagged_enum_types: std.StringHashMap(TaggedEnumInfo),
|
||||
// Union registry: maps name to field info + LLVM type (untagged, C-style)
|
||||
union_types: std.StringHashMap(UnionInfo),
|
||||
// Flags enum registry: tracks which enum names are flags
|
||||
flags_enum_types: std.StringHashMap(void),
|
||||
// Enum variant values: maps enum name → resolved i64 values per variant
|
||||
enum_variant_values: std.StringHashMap([]const i64),
|
||||
// Built-in functions (printf, etc.)
|
||||
builtins: ?Builtins,
|
||||
// Current function being generated (for alloca insertion)
|
||||
@@ -294,6 +298,8 @@ pub const CodeGen = struct {
|
||||
.struct_types = std.StringHashMap(StructInfo).init(allocator),
|
||||
.tagged_enum_types = std.StringHashMap(TaggedEnumInfo).init(allocator),
|
||||
.union_types = std.StringHashMap(UnionInfo).init(allocator),
|
||||
.flags_enum_types = std.StringHashMap(void).init(allocator),
|
||||
.enum_variant_values = std.StringHashMap([]const i64).init(allocator),
|
||||
.builtins = null,
|
||||
.current_function = null,
|
||||
.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty,
|
||||
@@ -709,6 +715,30 @@ pub const CodeGen = struct {
|
||||
// Payload-less enum
|
||||
try self.enum_types.put(ed.name, ed.variant_names);
|
||||
_ = try self.getAnyTypeId(ed.name, .{ .enum_type = ed.name });
|
||||
|
||||
if (ed.is_flags) {
|
||||
try self.flags_enum_types.put(ed.name, {});
|
||||
}
|
||||
|
||||
// Compute and store variant values
|
||||
const values = try self.allocator.alloc(i64, ed.variant_names.len);
|
||||
for (ed.variant_names, 0..) |_, i| {
|
||||
if (ed.variant_values.len > i and ed.variant_values[i] != null) {
|
||||
// Explicit value: evaluate comptime int literal
|
||||
const val_node = ed.variant_values[i].?;
|
||||
values[i] = switch (val_node.data) {
|
||||
.int_literal => |il| il.value,
|
||||
else => @as(i64, @intCast(i)), // fallback
|
||||
};
|
||||
} else if (ed.is_flags) {
|
||||
// Auto power-of-2: 1, 2, 4, 8, ...
|
||||
values[i] = @as(i64, 1) << @intCast(i);
|
||||
} else {
|
||||
// Regular enum: sequential 0, 1, 2, ...
|
||||
values[i] = @intCast(i);
|
||||
}
|
||||
}
|
||||
try self.enum_variant_values.put(ed.name, values);
|
||||
}
|
||||
},
|
||||
.struct_decl => |sd| try self.registerStructType(sd),
|
||||
@@ -3446,6 +3476,23 @@ pub const CodeGen = struct {
|
||||
return c.LLVMBuildGlobalStringPtr(self.builder, str_z.ptr, "str");
|
||||
}
|
||||
|
||||
// Enum literal assigned to enum type: resolve variant value
|
||||
if (node.data == .enum_literal and target_ty.isEnum()) {
|
||||
return self.genEnumLiteral(node.data.enum_literal.name, target_ty.enum_type);
|
||||
}
|
||||
|
||||
// Bitwise op on enum type: recursively generate both sides with enum context
|
||||
if (node.data == .binary_op and (node.data.binary_op.op == .bit_or or node.data.binary_op.op == .bit_and) and target_ty.isEnum()) {
|
||||
const binop = node.data.binary_op;
|
||||
const lhs = try self.genExprAsType(binop.lhs, target_ty);
|
||||
const rhs = try self.genExprAsType(binop.rhs, target_ty);
|
||||
const b = self.builder;
|
||||
return if (binop.op == .bit_or)
|
||||
c.LLVMBuildOr(b, lhs, rhs, "bortmp")
|
||||
else
|
||||
c.LLVMBuildAnd(b, lhs, rhs, "bandtmp");
|
||||
}
|
||||
|
||||
// Enum/union literal assigned to union type: construct tagged enum
|
||||
if (node.data == .enum_literal and target_ty.isUnion()) {
|
||||
const el = node.data.enum_literal;
|
||||
@@ -4203,6 +4250,62 @@ pub const CodeGen = struct {
|
||||
return phi;
|
||||
}
|
||||
|
||||
fn genIsFlags(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 1) return self.emitError("is_flags expects exactly 1 argument");
|
||||
const ty = self.resolveType(call_node.args[0]);
|
||||
const i1_type = c.LLVMInt1TypeInContext(self.context);
|
||||
if (ty.isEnum()) {
|
||||
const is_flags = self.flags_enum_types.contains(ty.enum_type);
|
||||
return c.LLVMConstInt(i1_type, @intFromBool(is_flags), 0);
|
||||
}
|
||||
return c.LLVMConstInt(i1_type, 0, 0);
|
||||
}
|
||||
|
||||
fn genFieldValueInt(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 2) return self.emitError("field_value_int expects 2 arguments: field_value_int(T, idx)");
|
||||
const ty = self.resolveType(call_node.args[0]);
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
// For non-enum types (e.g. tagged enums compiled via dead code), return the index as value
|
||||
if (!ty.isEnum()) {
|
||||
return try self.genExpr(call_node.args[1]);
|
||||
}
|
||||
const enum_name = ty.enum_type;
|
||||
const values = self.enum_variant_values.get(enum_name);
|
||||
const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]);
|
||||
const n = variants.len;
|
||||
|
||||
const idx = try self.genExpr(call_node.args[1]);
|
||||
const function = self.current_function;
|
||||
const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_merge");
|
||||
const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_default");
|
||||
const sw = c.LLVMBuildSwitch(self.builder, idx, default_bb, @intCast(n));
|
||||
|
||||
var phi_vals = std.ArrayList(c.LLVMValueRef).empty;
|
||||
var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty;
|
||||
|
||||
for (0..n) |i| {
|
||||
const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fvi_case");
|
||||
c.LLVMAddCase(sw, c.LLVMConstInt(i64_type, i, 0), case_bb);
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, case_bb);
|
||||
const val: u64 = if (values) |vals| @bitCast(vals[i]) else i;
|
||||
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, val, 0));
|
||||
try phi_bbs.append(self.allocator, case_bb);
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
}
|
||||
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, default_bb);
|
||||
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, 0, 0));
|
||||
try phi_bbs.append(self.allocator, default_bb);
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, merge_bb);
|
||||
const vals_slice = try phi_vals.toOwnedSlice(self.allocator);
|
||||
const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator);
|
||||
const phi = c.LLVMBuildPhi(self.builder, i64_type, "fvi_result");
|
||||
c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len));
|
||||
return phi;
|
||||
}
|
||||
|
||||
fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr");
|
||||
const target_ty = self.resolveType(call_node.args[0]);
|
||||
@@ -4591,6 +4694,8 @@ pub const CodeGen = struct {
|
||||
.lte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOLE, lhs, rhs, "letmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntULE, lhs, rhs, "letmp") else c.LLVMBuildICmp(b, c.LLVMIntSLE, lhs, rhs, "letmp"),
|
||||
.gt => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGT, lhs, rhs, "gttmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGT, lhs, rhs, "gttmp") else c.LLVMBuildICmp(b, c.LLVMIntSGT, lhs, rhs, "gttmp"),
|
||||
.gte => if (is_float) c.LLVMBuildFCmp(b, c.LLVMRealOGE, lhs, rhs, "getmp") else if (is_unsigned) c.LLVMBuildICmp(b, c.LLVMIntUGE, lhs, rhs, "getmp") else c.LLVMBuildICmp(b, c.LLVMIntSGE, lhs, rhs, "getmp"),
|
||||
.bit_and => c.LLVMBuildAnd(b, lhs, rhs, "bandtmp"),
|
||||
.bit_or => c.LLVMBuildOr(b, lhs, rhs, "bortmp"),
|
||||
.and_op, .or_op => unreachable,
|
||||
};
|
||||
}
|
||||
@@ -5198,6 +5303,7 @@ pub const CodeGen = struct {
|
||||
try self.instantiateGeneric(fd, bindings, mangled);
|
||||
|
||||
// Generate arguments with type conversion to match parameter types
|
||||
const saved_call_bindings = self.type_param_bindings;
|
||||
self.type_param_bindings = bindings;
|
||||
var arg_vals = std.ArrayList(c.LLVMValueRef).empty;
|
||||
for (call_node.args, 0..) |arg, i| {
|
||||
@@ -5208,7 +5314,7 @@ pub const CodeGen = struct {
|
||||
try arg_vals.append(self.allocator, try self.genExpr(arg));
|
||||
}
|
||||
}
|
||||
self.type_param_bindings = null;
|
||||
self.type_param_bindings = saved_call_bindings;
|
||||
const args_slice = try arg_vals.toOwnedSlice(self.allocator);
|
||||
|
||||
const fn_type = c.LLVMGlobalGetValueType(callee_fn);
|
||||
@@ -5584,9 +5690,10 @@ pub const CodeGen = struct {
|
||||
self.scope_saves = std.ArrayList(std.ArrayList(ScopeEntry)).empty;
|
||||
self.defer_stack = std.ArrayList(std.ArrayList(*Node)).empty;
|
||||
|
||||
// Set type param bindings
|
||||
// Set type param bindings (save/restore to support nested generic instantiation)
|
||||
const saved_bindings = self.type_param_bindings;
|
||||
self.type_param_bindings = bindings;
|
||||
defer self.type_param_bindings = null;
|
||||
defer self.type_param_bindings = saved_bindings;
|
||||
|
||||
// Build the specialized function type
|
||||
const fn_type = try self.buildFnType(fd.params, fd.return_type, mangled);
|
||||
@@ -5913,18 +6020,28 @@ pub const CodeGen = struct {
|
||||
fn genEnumLiteral(self: *CodeGen, variant_name: []const u8, enum_type_name: []const u8) c.LLVMValueRef {
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
const variants = self.enum_types.get(enum_type_name) orelse return c.LLVMConstInt(i64_type, 0, 0);
|
||||
const values = self.enum_variant_values.get(enum_type_name);
|
||||
for (variants, 0..) |v, i| {
|
||||
if (std.mem.eql(u8, v, variant_name)) {
|
||||
return c.LLVMConstInt(i64_type, @intCast(i), 0);
|
||||
const val: u64 = if (values) |vals| @bitCast(vals[i]) else @intCast(i);
|
||||
return c.LLVMConstInt(i64_type, val, 0);
|
||||
}
|
||||
}
|
||||
return c.LLVMConstInt(i64_type, 0, 0);
|
||||
}
|
||||
|
||||
fn lookupVariantIndex(variants: ?[]const []const u8, name: []const u8) u64 {
|
||||
fn lookupVariantValue(self: *CodeGen, enum_name: ?[]const u8, variants: ?[]const []const u8, name: []const u8) u64 {
|
||||
if (variants) |vs| {
|
||||
for (vs, 0..) |v, i| {
|
||||
if (std.mem.eql(u8, v, name)) return i;
|
||||
if (std.mem.eql(u8, v, name)) {
|
||||
// Use resolved values if available (flags enums, explicit values)
|
||||
if (enum_name) |en| {
|
||||
if (self.enum_variant_values.get(en)) |vals| {
|
||||
return @bitCast(vals[i]);
|
||||
}
|
||||
}
|
||||
return i;
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
@@ -5985,7 +6102,7 @@ pub const CodeGen = struct {
|
||||
for (match.arms, 0..) |arm, i| {
|
||||
const pat = arm.pattern orelse continue; // skip else arm
|
||||
if (pat.data == .enum_literal) {
|
||||
const idx = lookupVariantIndex(variants, pat.data.enum_literal.name);
|
||||
const idx = self.lookupVariantValue(enum_name orelse union_name, variants, pat.data.enum_literal.name);
|
||||
const case_val = c.LLVMConstInt(i64_type, idx, 0);
|
||||
c.LLVMAddCase(sw, case_val, case_bbs.items[i]);
|
||||
} else if (pat.data == .type_expr) {
|
||||
@@ -6235,6 +6352,8 @@ pub const CodeGen = struct {
|
||||
if (std.mem.eql(u8, base, "field_count")) return self.genFieldCount(call_node);
|
||||
if (std.mem.eql(u8, base, "field_name")) return self.genFieldName(call_node);
|
||||
if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node);
|
||||
if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node);
|
||||
if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node);
|
||||
return self.emitErrorFmt("unknown builtin function '{s}'", .{name});
|
||||
}
|
||||
|
||||
@@ -6466,6 +6585,10 @@ pub const CodeGen = struct {
|
||||
if (std.mem.eql(u8, base_name, "field_name")) return .string_type;
|
||||
// Built-in: field_value returns Any
|
||||
if (std.mem.eql(u8, base_name, "field_value")) return .{ .any_type = {} };
|
||||
// Built-in: is_flags returns bool
|
||||
if (std.mem.eql(u8, base_name, "is_flags")) return .boolean;
|
||||
// Built-in: field_value_int returns s64
|
||||
if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64);
|
||||
// Built-in: cast returns the target type (first arg)
|
||||
if (std.mem.eql(u8, base_name, "cast")) {
|
||||
if (call_node.args.len > 0) return self.resolveType(call_node.args[0]);
|
||||
|
||||
@@ -147,6 +147,10 @@ pub const Instruction = union(enum) {
|
||||
gt,
|
||||
gte,
|
||||
|
||||
// Bitwise
|
||||
bit_and,
|
||||
bit_or,
|
||||
|
||||
// Logic
|
||||
not,
|
||||
|
||||
@@ -451,6 +455,8 @@ pub const Compiler = struct {
|
||||
.lte => .lte,
|
||||
.gt => .gt,
|
||||
.gte => .gte,
|
||||
.bit_and => .bit_and,
|
||||
.bit_or => .bit_or,
|
||||
.and_op, .or_op => unreachable,
|
||||
});
|
||||
}
|
||||
@@ -1026,6 +1032,20 @@ pub const VM = struct {
|
||||
const a = try self.pop();
|
||||
try self.push(try self.arith(a, b, .mod_op));
|
||||
},
|
||||
.bit_and => {
|
||||
const b = try self.pop();
|
||||
const a = try self.pop();
|
||||
if (a == .int_val and b == .int_val) {
|
||||
try self.push(.{ .int_val = a.int_val & b.int_val });
|
||||
} else return error.TypeError;
|
||||
},
|
||||
.bit_or => {
|
||||
const b = try self.pop();
|
||||
const a = try self.pop();
|
||||
if (a == .int_val and b == .int_val) {
|
||||
try self.push(.{ .int_val = a.int_val | b.int_val });
|
||||
} else return error.TypeError;
|
||||
},
|
||||
.negate => {
|
||||
const v = try self.pop();
|
||||
try self.push(switch (v) {
|
||||
|
||||
@@ -177,6 +177,7 @@ pub const Lexer = struct {
|
||||
return self.makeToken(.percent, start, self.index);
|
||||
},
|
||||
'&' => return self.makeToken(.ampersand, start, self.index),
|
||||
'|' => return self.makeToken(.pipe, start, self.index),
|
||||
'!' => {
|
||||
if (self.peek() == '=') {
|
||||
self.index += 1;
|
||||
|
||||
@@ -786,6 +786,7 @@ pub const Server = struct {
|
||||
.percent,
|
||||
.percent_equal,
|
||||
.ampersand,
|
||||
.pipe,
|
||||
.arrow,
|
||||
.fat_arrow,
|
||||
.colon_colon,
|
||||
@@ -1632,7 +1633,11 @@ pub const Server = struct {
|
||||
},
|
||||
.enum_decl => |ed| {
|
||||
try buf.appendSlice(allocator, ed.name);
|
||||
try buf.appendSlice(allocator, " :: enum { ");
|
||||
if (ed.is_flags) {
|
||||
try buf.appendSlice(allocator, " :: enum flags { ");
|
||||
} else {
|
||||
try buf.appendSlice(allocator, " :: enum { ");
|
||||
}
|
||||
for (ed.variant_names, 0..) |v, i| {
|
||||
if (i > 0) try buf.appendSlice(allocator, ", ");
|
||||
try buf.append(allocator, '.');
|
||||
|
||||
@@ -381,25 +381,50 @@ pub const Parser = struct {
|
||||
|
||||
fn parseEnumDecl(self: *Parser, name: []const u8, start_pos: u32) anyerror!*Node {
|
||||
self.advance(); // skip 'enum'
|
||||
|
||||
// Check for 'flags' modifier: enum flags { ... }
|
||||
var is_flags = false;
|
||||
if (self.current.tag == .identifier and std.mem.eql(u8, self.tokenSlice(self.current), "flags")) {
|
||||
is_flags = true;
|
||||
self.advance();
|
||||
}
|
||||
|
||||
try self.expect(.l_brace);
|
||||
var variant_names = std.ArrayList([]const u8).empty;
|
||||
var variant_types = std.ArrayList(?*Node).empty;
|
||||
var variant_values = std.ArrayList(?*Node).empty;
|
||||
var has_any_type = false;
|
||||
var has_any_value = false;
|
||||
while (self.current.tag != .r_brace and self.current.tag != .eof) {
|
||||
if (self.current.tag != .identifier) {
|
||||
return self.fail("expected variant name");
|
||||
}
|
||||
try variant_names.append(self.allocator, self.tokenSlice(self.current));
|
||||
self.advance();
|
||||
if (self.current.tag == .colon) {
|
||||
if (self.current.tag == .colon_colon) {
|
||||
// Explicit value: name :: expr;
|
||||
if (!is_flags) {
|
||||
return self.fail("explicit enum values require 'enum flags'");
|
||||
}
|
||||
self.advance();
|
||||
const val_expr = try self.parseExpr();
|
||||
try variant_values.append(self.allocator, val_expr);
|
||||
try variant_types.append(self.allocator, null);
|
||||
has_any_value = true;
|
||||
} else if (self.current.tag == .colon) {
|
||||
// Typed variant: name: type;
|
||||
if (is_flags) {
|
||||
return self.fail("flags enum variants cannot have payloads");
|
||||
}
|
||||
self.advance();
|
||||
const vtype = try self.parseTypeExpr();
|
||||
try variant_types.append(self.allocator, vtype);
|
||||
try variant_values.append(self.allocator, null);
|
||||
has_any_type = true;
|
||||
} else {
|
||||
// Void variant: name;
|
||||
try variant_types.append(self.allocator, null);
|
||||
try variant_values.append(self.allocator, null);
|
||||
}
|
||||
if (self.current.tag == .semicolon) {
|
||||
self.advance();
|
||||
@@ -411,6 +436,8 @@ pub const Parser = struct {
|
||||
.name = name,
|
||||
.variant_names = try variant_names.toOwnedSlice(self.allocator),
|
||||
.variant_types = if (has_any_type) try variant_types.toOwnedSlice(self.allocator) else &.{},
|
||||
.is_flags = is_flags,
|
||||
.variant_values = if (has_any_value) try variant_values.toOwnedSlice(self.allocator) else &.{},
|
||||
} });
|
||||
}
|
||||
|
||||
@@ -1539,6 +1566,8 @@ pub const Parser = struct {
|
||||
return switch (self.current.tag) {
|
||||
.kw_or => 1,
|
||||
.kw_and => 2,
|
||||
.pipe => 3,
|
||||
.ampersand => 3,
|
||||
.equal_equal, .bang_equal, .less, .less_equal, .greater, .greater_equal => 4,
|
||||
.plus, .minus => 5,
|
||||
.star, .slash, .percent => 6,
|
||||
@@ -1550,6 +1579,8 @@ pub const Parser = struct {
|
||||
return switch (self.current.tag) {
|
||||
.kw_and => .and_op,
|
||||
.kw_or => .or_op,
|
||||
.pipe => .bit_or,
|
||||
.ampersand => .bit_and,
|
||||
.plus => .add,
|
||||
.minus => .sub,
|
||||
.star => .mul,
|
||||
|
||||
@@ -60,6 +60,7 @@ pub const Tag = enum {
|
||||
percent, // %
|
||||
percent_equal, // %=
|
||||
ampersand, // &
|
||||
pipe, // |
|
||||
|
||||
// Delimiters
|
||||
l_paren, // (
|
||||
@@ -115,6 +116,7 @@ pub const Tag = enum {
|
||||
.percent => "%",
|
||||
.percent_equal => "%=",
|
||||
.ampersand => "&",
|
||||
.pipe => "|",
|
||||
.kw_null => "null",
|
||||
.l_paren => "(",
|
||||
.r_paren => ")",
|
||||
|
||||
Reference in New Issue
Block a user