...
This commit is contained in:
@@ -196,6 +196,7 @@ pub const MatchArm = struct {
|
||||
pattern: ?*Node, // null = else (default) arm
|
||||
body: *Node,
|
||||
is_break: bool,
|
||||
capture: ?[]const u8 = null, // payload binding name: case .variant: (name) { ... }
|
||||
};
|
||||
|
||||
pub const ConstDecl = struct {
|
||||
|
||||
463
src/codegen.zig
463
src/codegen.zig
@@ -247,8 +247,9 @@ pub const CodeGen = struct {
|
||||
const TaggedEnumInfo = struct {
|
||||
variant_names: []const []const u8,
|
||||
variant_types: []const Type, // void_type for void variants
|
||||
llvm_type: c.LLVMTypeRef, // { i32, [max_payload_size x i8] }
|
||||
llvm_type: c.LLVMTypeRef, // layout struct or { tag, [max_payload_size x i8] }
|
||||
max_payload_size: u64,
|
||||
payload_field_index: c_uint = 1, // struct field index of the payload array
|
||||
};
|
||||
|
||||
const PromotedField = struct {
|
||||
@@ -281,13 +282,33 @@ pub const CodeGen = struct {
|
||||
const module = c.LLVMModuleCreateWithNameInContext(module_name, ctx);
|
||||
const builder = c.LLVMCreateBuilderInContext(ctx);
|
||||
|
||||
// Set target triple on module so it appears in IR output
|
||||
if (target_config.triple) |t| {
|
||||
c.LLVMSetTarget(module, t);
|
||||
// Initialize LLVM targets and set data layout early so alignment queries work
|
||||
llvm.initAllTargets();
|
||||
|
||||
const triple_owned = target_config.triple == null;
|
||||
const triple = target_config.triple orelse c.LLVMGetDefaultTargetTriple();
|
||||
defer if (triple_owned) c.LLVMDisposeMessage(@constCast(triple));
|
||||
|
||||
c.LLVMSetTarget(module, triple);
|
||||
|
||||
var target: c.LLVMTargetRef = null;
|
||||
var err_msg: [*c]u8 = null;
|
||||
if (c.LLVMGetTargetFromTriple(triple, &target, &err_msg) == 0) {
|
||||
const tm = c.LLVMCreateTargetMachine(
|
||||
target,
|
||||
triple,
|
||||
target_config.getCpu(),
|
||||
target_config.getFeatures(),
|
||||
target_config.opt_level.toLLVM(),
|
||||
c.LLVMRelocPIC,
|
||||
c.LLVMCodeModelDefault,
|
||||
);
|
||||
const dl = c.LLVMCreateTargetDataLayout(tm);
|
||||
c.LLVMSetModuleDataLayout(module, dl);
|
||||
c.LLVMDisposeTargetData(dl);
|
||||
c.LLVMDisposeTargetMachine(tm);
|
||||
} else {
|
||||
const default_triple = c.LLVMGetDefaultTargetTriple();
|
||||
c.LLVMSetTarget(module, default_triple);
|
||||
c.LLVMDisposeMessage(default_triple);
|
||||
if (err_msg != null) c.LLVMDisposeMessage(err_msg);
|
||||
}
|
||||
return .{
|
||||
.context = ctx,
|
||||
@@ -1285,6 +1306,7 @@ pub const CodeGen = struct {
|
||||
.variant_types = build.variant_sx_types,
|
||||
.llvm_type = build.llvm_type,
|
||||
.max_payload_size = build.max_payload_size,
|
||||
.payload_field_index = build.payload_field_index,
|
||||
});
|
||||
_ = try self.getAnyTypeId(mangled_name, .{ .union_type = mangled_name });
|
||||
|
||||
@@ -1928,6 +1950,16 @@ pub const CodeGen = struct {
|
||||
const info = self.struct_types.get(sname) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{sname});
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, val, "retval");
|
||||
_ = c.LLVMBuildRet(self.builder, loaded);
|
||||
} else if (ret_sx_type.isUnion()) {
|
||||
// Tagged enum implicit return: val may be alloca or loaded value
|
||||
const uname = ret_sx_type.union_type;
|
||||
const resolved = self.type_aliases.get(uname) orelse uname;
|
||||
const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved});
|
||||
const ret_val2 = if (c.LLVMGetTypeKind(c.LLVMTypeOf(val)) == c.LLVMPointerTypeKind)
|
||||
c.LLVMBuildLoad2(self.builder, info.llvm_type, val, "retval")
|
||||
else
|
||||
val;
|
||||
_ = c.LLVMBuildRet(self.builder, ret_val2);
|
||||
} else {
|
||||
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(val));
|
||||
const ret_val = self.convertValue(val, src_ty, self.current_return_type);
|
||||
@@ -2090,6 +2122,17 @@ pub const CodeGen = struct {
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval");
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRet(self.builder, loaded);
|
||||
} else if (self.current_return_type.isUnion()) {
|
||||
// Tagged enum return: raw_val may be alloca (enum literal) or loaded value (identifier/call)
|
||||
const uname = self.current_return_type.union_type;
|
||||
const resolved = self.type_aliases.get(uname) orelse uname;
|
||||
const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved});
|
||||
const ret_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(raw_val)) == c.LLVMPointerTypeKind)
|
||||
c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval")
|
||||
else
|
||||
raw_val;
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRet(self.builder, ret_val);
|
||||
} else {
|
||||
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val));
|
||||
const val = self.convertValue(raw_val, src_ty, self.current_return_type);
|
||||
@@ -2253,15 +2296,23 @@ pub const CodeGen = struct {
|
||||
try self.named_values.put(vd.name, .{ .ptr = lit_alloca, .ty = sx_ty });
|
||||
return null;
|
||||
} else if (vd.value.?.data == .call) {
|
||||
// Call returning a union (e.g., Shape.circle(3.14)) — genExpr returns alloca
|
||||
const result_alloca = try self.genExpr(vd.value.?);
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result_alloca, "union_load");
|
||||
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
|
||||
// Call returning a union — could be enum construction (alloca) or function call (value)
|
||||
const result = try self.genExpr(vd.value.?);
|
||||
if (c.LLVMGetTypeKind(c.LLVMTypeOf(result)) == c.LLVMPointerTypeKind) {
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result, "union_load");
|
||||
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
|
||||
} else {
|
||||
_ = c.LLVMBuildStore(self.builder, result, alloca);
|
||||
}
|
||||
} else {
|
||||
// Other expression — try genExprAsType
|
||||
const result_alloca = try self.genExprAsType(vd.value.?, sx_ty);
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result_alloca, "union_load");
|
||||
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
|
||||
const result = try self.genExprAsType(vd.value.?, sx_ty);
|
||||
if (c.LLVMGetTypeKind(c.LLVMTypeOf(result)) == c.LLVMPointerTypeKind) {
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, result, "union_load");
|
||||
_ = c.LLVMBuildStore(self.builder, loaded, alloca);
|
||||
} else {
|
||||
_ = c.LLVMBuildStore(self.builder, result, alloca);
|
||||
}
|
||||
}
|
||||
|
||||
try self.saveShadowed(vd.name);
|
||||
@@ -2564,13 +2615,16 @@ pub const CodeGen = struct {
|
||||
return null;
|
||||
}
|
||||
|
||||
// Tagged enum reassignment: s = .circle(3.14) or s = .none
|
||||
// Tagged enum reassignment: s = .circle(3.14) or s = .none or s = fn_call()
|
||||
if (entry.ty.isUnion() and asgn.op == .assign) {
|
||||
if (self.tagged_enum_types.get(entry.ty.union_type)) |info| {
|
||||
const new_alloca = try self.genExprAsType(asgn.value, entry.ty);
|
||||
// Copy from new alloca to existing alloca
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, info.llvm_type, new_alloca, "union_load");
|
||||
_ = c.LLVMBuildStore(self.builder, loaded, entry.ptr);
|
||||
const new_val = try self.genExprAsType(asgn.value, entry.ty);
|
||||
// genExprAsType returns alloca for enum literals, loaded value for calls
|
||||
const store_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(new_val)) == c.LLVMPointerTypeKind)
|
||||
c.LLVMBuildLoad2(self.builder, info.llvm_type, new_val, "union_load")
|
||||
else
|
||||
new_val;
|
||||
_ = c.LLVMBuildStore(self.builder, store_val, entry.ptr);
|
||||
return null;
|
||||
}
|
||||
// C-style union: full assignment not supported, use field assignment
|
||||
@@ -2860,6 +2914,33 @@ pub const CodeGen = struct {
|
||||
const lhs_ty = self.inferType(binop.lhs);
|
||||
const rhs_ty = self.inferType(binop.rhs);
|
||||
const result_type = Type.widen(lhs_ty, rhs_ty);
|
||||
|
||||
// Tagged enum comparison: compare tags only
|
||||
if (result_type.isUnion() and (binop.op == .eq or binop.op == .neq)) {
|
||||
const uname = result_type.union_type;
|
||||
const resolved = self.type_aliases.get(uname) orelse uname;
|
||||
const info = self.tagged_enum_types.get(resolved) orelse return self.emitError("unknown tagged enum type");
|
||||
const tag_ty = self.getEnumLLVMType(resolved);
|
||||
|
||||
var lhs_val = try self.genExprAsType(binop.lhs, result_type);
|
||||
var rhs_val = try self.genExprAsType(binop.rhs, result_type);
|
||||
|
||||
// If either side is a pointer (alloca from genTaggedEnumLiteral), load it
|
||||
if (c.LLVMGetTypeKind(c.LLVMTypeOf(lhs_val)) == c.LLVMPointerTypeKind) {
|
||||
lhs_val = c.LLVMBuildLoad2(self.builder, info.llvm_type, lhs_val, "union_load_l");
|
||||
}
|
||||
if (c.LLVMGetTypeKind(c.LLVMTypeOf(rhs_val)) == c.LLVMPointerTypeKind) {
|
||||
rhs_val = c.LLVMBuildLoad2(self.builder, info.llvm_type, rhs_val, "union_load_r");
|
||||
}
|
||||
|
||||
// Extract tags (field 0) and compare
|
||||
const lhs_tag = c.LLVMBuildExtractValue(self.builder, lhs_val, 0, "lhs_tag");
|
||||
const rhs_tag = c.LLVMBuildExtractValue(self.builder, rhs_val, 0, "rhs_tag");
|
||||
_ = tag_ty;
|
||||
const pred: c_uint = if (binop.op == .eq) c.LLVMIntEQ else c.LLVMIntNE;
|
||||
return c.LLVMBuildICmp(self.builder, pred, lhs_tag, rhs_tag, "tag_cmp");
|
||||
}
|
||||
|
||||
const lhs = try self.genExprAsType(binop.lhs, result_type);
|
||||
const rhs = try self.genExprAsType(binop.rhs, result_type);
|
||||
return self.genBinaryOp(binop.op, lhs, rhs, result_type);
|
||||
@@ -2895,6 +2976,15 @@ pub const CodeGen = struct {
|
||||
.xx, .address_of => unreachable,
|
||||
};
|
||||
},
|
||||
.enum_literal => |el| {
|
||||
if (self.current_return_type.isUnion()) {
|
||||
return self.genTaggedEnumLiteral(el, self.current_return_type.union_type);
|
||||
}
|
||||
if (self.current_return_type.isEnum()) {
|
||||
return self.genEnumLiteral(el.name, self.current_return_type.enum_type);
|
||||
}
|
||||
return self.emitError("cannot infer enum type for literal");
|
||||
},
|
||||
.struct_literal => |sl| {
|
||||
const ctx_name: ?[]const u8 = if (self.current_return_type.isStruct()) self.current_return_type.struct_type else null;
|
||||
return self.genStructLiteral(sl, ctx_name);
|
||||
@@ -2976,10 +3066,29 @@ pub const CodeGen = struct {
|
||||
.return_stmt => |rs| {
|
||||
if (rs.value) |val_node| {
|
||||
const raw_val = try self.genExpr(val_node);
|
||||
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val));
|
||||
const val = self.convertValue(raw_val, src_ty, self.current_return_type);
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRet(self.builder, val);
|
||||
if (self.current_return_type.isStruct()) {
|
||||
const sname = self.current_return_type.struct_type;
|
||||
const resolved = self.type_aliases.get(sname) orelse sname;
|
||||
const sinfo = self.struct_types.get(resolved) orelse return self.emitErrorFmt("unknown struct type '{s}'", .{resolved});
|
||||
const loaded = c.LLVMBuildLoad2(self.builder, sinfo.llvm_type, raw_val, "retval");
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRet(self.builder, loaded);
|
||||
} else if (self.current_return_type.isUnion()) {
|
||||
const uname = self.current_return_type.union_type;
|
||||
const resolved = self.type_aliases.get(uname) orelse uname;
|
||||
const info = self.tagged_enum_types.get(resolved) orelse return self.emitErrorFmt("unknown enum type '{s}'", .{resolved});
|
||||
const ret_val = if (c.LLVMGetTypeKind(c.LLVMTypeOf(raw_val)) == c.LLVMPointerTypeKind)
|
||||
c.LLVMBuildLoad2(self.builder, info.llvm_type, raw_val, "retval")
|
||||
else
|
||||
raw_val;
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRet(self.builder, ret_val);
|
||||
} else {
|
||||
const src_ty = self.llvmTypeToSxType(c.LLVMTypeOf(raw_val));
|
||||
const val = self.convertValue(raw_val, src_ty, self.current_return_type);
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRet(self.builder, val);
|
||||
}
|
||||
} else {
|
||||
try self.emitAllDefers();
|
||||
_ = c.LLVMBuildRetVoid(self.builder);
|
||||
@@ -3111,6 +3220,7 @@ pub const CodeGen = struct {
|
||||
variant_sx_types: []const Type,
|
||||
llvm_type: c.LLVMTypeRef,
|
||||
max_payload_size: u64,
|
||||
payload_field_index: c_uint,
|
||||
};
|
||||
|
||||
fn buildUnionFields(self: *CodeGen, name: []const u8, variant_type_nodes: []const ?*Node) !UnionBuildResult {
|
||||
@@ -3134,6 +3244,7 @@ pub const CodeGen = struct {
|
||||
const union_ty = c.LLVMStructCreateNamed(self.context, name_z.ptr);
|
||||
const tag_ty = self.getEnumLLVMType(name);
|
||||
const i8_ty = c.LLVMInt8TypeInContext(self.context);
|
||||
|
||||
const payload_array_ty = c.LLVMArrayType2(i8_ty, max_payload_size);
|
||||
var union_fields = [2]c.LLVMTypeRef{ tag_ty, payload_array_ty };
|
||||
c.LLVMStructSetBody(union_ty, &union_fields, 2, 0);
|
||||
@@ -3142,6 +3253,7 @@ pub const CodeGen = struct {
|
||||
.variant_sx_types = try variant_sx_types.toOwnedSlice(self.allocator),
|
||||
.llvm_type = union_ty,
|
||||
.max_payload_size = max_payload_size,
|
||||
.payload_field_index = 1,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -3226,21 +3338,176 @@ pub const CodeGen = struct {
|
||||
}
|
||||
}
|
||||
|
||||
// Register backing type before buildUnionFields (which uses getEnumLLVMType)
|
||||
if (ud.backing_type) |bt_node| {
|
||||
const bt = self.resolveType(bt_node);
|
||||
try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt));
|
||||
// Check if backing type is a struct layout specification
|
||||
const layout_info = try self.resolveEnumLayout(ud);
|
||||
|
||||
if (layout_info) |layout| {
|
||||
// Struct-backed layout: use the struct's LLVM type directly
|
||||
try self.enum_backing_types.put(ud.name, layout.tag_llvm_type);
|
||||
|
||||
// Resolve variant sx types
|
||||
var variant_sx_types = std.ArrayList(Type).empty;
|
||||
for (ud.variant_types) |vt| {
|
||||
if (vt) |type_node| {
|
||||
try variant_sx_types.append(self.allocator, self.resolveType(type_node));
|
||||
} else {
|
||||
try variant_sx_types.append(self.allocator, .void_type);
|
||||
}
|
||||
}
|
||||
|
||||
try self.tagged_enum_types.put(ud.name, .{
|
||||
.variant_names = ud.variant_names,
|
||||
.variant_types = try variant_sx_types.toOwnedSlice(self.allocator),
|
||||
.llvm_type = layout.llvm_type,
|
||||
.max_payload_size = layout.payload_size,
|
||||
.payload_field_index = layout.payload_field_index,
|
||||
});
|
||||
} else {
|
||||
// Primitive backing type (e.g. enum u32 { ... })
|
||||
if (ud.backing_type) |bt_node| {
|
||||
const bt = self.resolveType(bt_node);
|
||||
try self.enum_backing_types.put(ud.name, self.typeToLLVM(bt));
|
||||
}
|
||||
|
||||
const build = try self.buildUnionFields(ud.name, ud.variant_types);
|
||||
|
||||
try self.tagged_enum_types.put(ud.name, .{
|
||||
.variant_names = ud.variant_names,
|
||||
.variant_types = build.variant_sx_types,
|
||||
.llvm_type = build.llvm_type,
|
||||
.max_payload_size = build.max_payload_size,
|
||||
.payload_field_index = build.payload_field_index,
|
||||
});
|
||||
}
|
||||
|
||||
const build = try self.buildUnionFields(ud.name, ud.variant_types);
|
||||
|
||||
try self.tagged_enum_types.put(ud.name, .{
|
||||
.variant_names = ud.variant_names,
|
||||
.variant_types = build.variant_sx_types,
|
||||
.llvm_type = build.llvm_type,
|
||||
.max_payload_size = build.max_payload_size,
|
||||
});
|
||||
_ = try self.getAnyTypeId(ud.name, .{ .union_type = ud.name });
|
||||
|
||||
// Compute and store variant values (explicit or sequential)
|
||||
const values = try self.allocator.alloc(i64, ud.variant_names.len);
|
||||
for (ud.variant_names, 0..) |_, i| {
|
||||
if (ud.variant_values.len > i and ud.variant_values[i] != null) {
|
||||
const val_node = ud.variant_values[i].?;
|
||||
values[i] = switch (val_node.data) {
|
||||
.int_literal => |il| il.value,
|
||||
else => @as(i64, @intCast(i)),
|
||||
};
|
||||
} else {
|
||||
values[i] = @intCast(i);
|
||||
}
|
||||
}
|
||||
try self.enum_variant_values.put(ud.name, values);
|
||||
}
|
||||
|
||||
const EnumLayoutInfo = struct {
|
||||
llvm_type: c.LLVMTypeRef,
|
||||
tag_llvm_type: c.LLVMTypeRef,
|
||||
payload_field_index: c_uint,
|
||||
payload_size: u64,
|
||||
};
|
||||
|
||||
/// Resolve a struct-backed layout for a tagged enum.
|
||||
/// Returns null if the backing type is a primitive (e.g. u32), in which case
|
||||
/// the caller should fall back to buildUnionFields.
|
||||
///
|
||||
/// The layout struct must have:
|
||||
/// - A field named `tag` (integer type) — the discriminant
|
||||
/// - A field named `payload` (array type) — the overlay area for variant data
|
||||
/// - Any other fields are treated as padding/reserved
|
||||
fn resolveEnumLayout(self: *CodeGen, ud: ast.EnumDecl) !?EnumLayoutInfo {
|
||||
const bt_node = ud.backing_type orelse return null;
|
||||
|
||||
// Check for inline struct: enum struct { ... } { ... }
|
||||
if (bt_node.data == .struct_decl) {
|
||||
const layout_name = try std.fmt.allocPrint(self.allocator, "{s}.__layout", .{ud.name});
|
||||
var sd = bt_node.data.struct_decl;
|
||||
sd.name = layout_name;
|
||||
try self.registerStructType(sd);
|
||||
return try self.validateEnumLayout(ud.name, layout_name);
|
||||
}
|
||||
|
||||
// Check for named struct reference: enum MyLayout { ... }
|
||||
if (bt_node.data == .type_expr) {
|
||||
const name = bt_node.data.type_expr.name;
|
||||
// If it resolves to a primitive type, it's not a layout struct
|
||||
if (Type.fromName(name) != null) return null;
|
||||
// Check type aliases
|
||||
const resolved = self.type_aliases.get(name) orelse name;
|
||||
if (Type.fromName(resolved) != null) return null;
|
||||
// Must be a registered struct
|
||||
if (self.struct_types.contains(resolved)) {
|
||||
return try self.validateEnumLayout(ud.name, resolved);
|
||||
}
|
||||
}
|
||||
|
||||
return null;
|
||||
}
|
||||
|
||||
fn validateEnumLayout(self: *CodeGen, enum_name: []const u8, layout_name: []const u8) !EnumLayoutInfo {
|
||||
const layout = self.struct_types.get(layout_name) orelse {
|
||||
return self.emitErrorFmt("enum '{s}': layout type '{s}' is not a registered struct", .{ enum_name, layout_name });
|
||||
};
|
||||
|
||||
// Find 'tag' field
|
||||
var tag_index: ?usize = null;
|
||||
var payload_index: ?usize = null;
|
||||
for (layout.field_names, 0..) |fname, i| {
|
||||
if (std.mem.eql(u8, fname, "tag")) {
|
||||
tag_index = i;
|
||||
} else if (std.mem.eql(u8, fname, "payload")) {
|
||||
payload_index = i;
|
||||
}
|
||||
}
|
||||
|
||||
if (tag_index == null) {
|
||||
return self.emitErrorFmt(
|
||||
"enum '{s}': layout struct '{s}' must have a field named 'tag' (the discriminant). Expected layout: struct {{ tag: <int_type>; payload: [N]<type>; }}",
|
||||
.{ enum_name, layout_name },
|
||||
);
|
||||
}
|
||||
if (payload_index == null) {
|
||||
return self.emitErrorFmt(
|
||||
"enum '{s}': layout struct '{s}' must have a field named 'payload' (the variant data area). Expected layout: struct {{ tag: <int_type>; payload: [N]<type>; }}",
|
||||
.{ enum_name, layout_name },
|
||||
);
|
||||
}
|
||||
|
||||
const tag_ty = layout.field_types[tag_index.?];
|
||||
const payload_ty = layout.field_types[payload_index.?];
|
||||
|
||||
// Validate tag is an integer type
|
||||
switch (tag_ty) {
|
||||
.signed, .unsigned => {},
|
||||
else => return self.emitErrorFmt(
|
||||
"enum '{s}': layout field 'tag' must be an integer type (e.g. u32), got '{s}'",
|
||||
.{ enum_name, tag_ty.displayName(self.allocator) catch "?" },
|
||||
),
|
||||
}
|
||||
|
||||
// Validate payload is an array type
|
||||
const payload_size = switch (payload_ty) {
|
||||
.array_type => |info| blk: {
|
||||
const elem_ty = Type.fromName(info.element_name) orelse {
|
||||
return self.emitErrorFmt(
|
||||
"enum '{s}': layout field 'payload' has unresolved element type '{s}'",
|
||||
.{ enum_name, info.element_name },
|
||||
);
|
||||
};
|
||||
const elem_llvm = self.typeToLLVM(elem_ty);
|
||||
const data_layout = c.LLVMGetModuleDataLayout(self.module);
|
||||
break :blk c.LLVMStoreSizeOfType(data_layout, elem_llvm) * info.length;
|
||||
},
|
||||
else => return self.emitErrorFmt(
|
||||
"enum '{s}': layout field 'payload' must be an array type (e.g. [30]u32), got '{s}'",
|
||||
.{ enum_name, payload_ty.displayName(self.allocator) catch "?" },
|
||||
),
|
||||
};
|
||||
|
||||
return .{
|
||||
.llvm_type = layout.llvm_type,
|
||||
.tag_llvm_type = self.typeToLLVM(tag_ty),
|
||||
.payload_field_index = @intCast(payload_index.?),
|
||||
.payload_size = payload_size,
|
||||
};
|
||||
}
|
||||
|
||||
fn registerUnionType(self: *CodeGen, ud: ast.UnionDecl) !void {
|
||||
@@ -3320,16 +3587,17 @@ pub const CodeGen = struct {
|
||||
const alloca = self.buildEntryBlockAlloca(info.llvm_type, "union_tmp");
|
||||
const tag_ty = self.getEnumLLVMType(resolved_name);
|
||||
|
||||
// Store tag (field 0)
|
||||
// Store tag (field 0) — use explicit value if available, otherwise index
|
||||
const tag_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 0, "tag");
|
||||
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, idx, 0), tag_gep);
|
||||
const tag_val: u64 = if (self.enum_variant_values.get(resolved_name)) |vals| @bitCast(vals[idx]) else idx;
|
||||
_ = c.LLVMBuildStore(self.builder, c.LLVMConstInt(tag_ty, tag_val, 0), tag_gep);
|
||||
|
||||
// Store payload (field 1) if not void
|
||||
if (el.payload) |payload_node| {
|
||||
const variant_ty = info.variant_types[idx];
|
||||
if (variant_ty != .void_type) {
|
||||
const payload_val = try self.genExprAsType(payload_node, variant_ty);
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload");
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, info.payload_field_index, "payload");
|
||||
// genExprAsType returns a loaded value for all types (including structs)
|
||||
_ = c.LLVMBuildStore(self.builder, payload_val, payload_gep);
|
||||
}
|
||||
@@ -3572,7 +3840,7 @@ pub const CodeGen = struct {
|
||||
.type_expr = null,
|
||||
.field_inits = sl.field_inits,
|
||||
}, payload_struct_name);
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, 1, "payload");
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, alloca, info.payload_field_index, "payload");
|
||||
const payload_llvm_ty = self.typeToLLVM(variant_ty);
|
||||
const struct_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_alloca, "struct_load");
|
||||
_ = c.LLVMBuildStore(self.builder, struct_val, payload_gep);
|
||||
@@ -3687,7 +3955,12 @@ pub const CodeGen = struct {
|
||||
std.mem.eql(u8, src_ty.struct_type, pointee_name) or
|
||||
(if (self.type_aliases.get(src_ty.struct_type)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or
|
||||
(if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, src_ty.struct_type) else false)
|
||||
else if (Type.fromName(pointee_name)) |pointee_ty|
|
||||
else if (src_ty.isUnion()) blk: {
|
||||
const uname = src_ty.union_type;
|
||||
break :blk std.mem.eql(u8, uname, pointee_name) or
|
||||
(if (self.type_aliases.get(uname)) |alias| std.mem.eql(u8, alias, pointee_name) else false) or
|
||||
(if (self.type_aliases.get(pointee_name)) |alias| std.mem.eql(u8, alias, uname) else false);
|
||||
} else if (Type.fromName(pointee_name)) |pointee_ty|
|
||||
src_ty.eql(pointee_ty)
|
||||
else
|
||||
false;
|
||||
@@ -4178,7 +4451,7 @@ pub const CodeGen = struct {
|
||||
// Read tag (field 0)
|
||||
const tag_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 0, "fv_tag_ptr");
|
||||
const tag_val = c.LLVMBuildLoad2(self.builder, self.getEnumLLVMType(val_ty.union_type), tag_ptr, "fv_tag");
|
||||
const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, 1, "fv_payload_ptr");
|
||||
const payload_ptr = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, union_alloca, uinfo.payload_field_index, "fv_payload_ptr");
|
||||
|
||||
const n = uinfo.variant_names.len;
|
||||
const function = self.current_function;
|
||||
@@ -4371,6 +4644,72 @@ pub const CodeGen = struct {
|
||||
return phi;
|
||||
}
|
||||
|
||||
fn genFieldIndex(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 2) return self.emitError("field_index expects 2 arguments: field_index(T, value)");
|
||||
const ty = self.resolveType(call_node.args[0]);
|
||||
const i64_type = c.LLVMInt64TypeInContext(self.context);
|
||||
if (!ty.isEnum()) {
|
||||
_ = try self.genExpr(call_node.args[1]);
|
||||
return c.LLVMConstInt(i64_type, 0, 0);
|
||||
}
|
||||
const enum_name = ty.enum_type;
|
||||
// Flags enums don't use sequential indices
|
||||
if (self.flags_enum_types.contains(enum_name)) {
|
||||
_ = try self.genExpr(call_node.args[1]);
|
||||
return c.LLVMConstInt(i64_type, 0, 0);
|
||||
}
|
||||
const values = self.enum_variant_values.get(enum_name);
|
||||
const variants = self.enum_types.get(enum_name) orelse return try self.genExpr(call_node.args[1]);
|
||||
const n = variants.len;
|
||||
|
||||
const val = try self.genExpr(call_node.args[1]);
|
||||
// Ensure the switch value uses the enum's backing type
|
||||
const enum_llvm_ty = self.getEnumLLVMType(enum_name);
|
||||
const sw_val = if (c.LLVMTypeOf(val) != enum_llvm_ty)
|
||||
c.LLVMBuildIntCast2(self.builder, val, enum_llvm_ty, 0, "fi_cast")
|
||||
else
|
||||
val;
|
||||
|
||||
const function = self.current_function;
|
||||
const merge_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_merge");
|
||||
const default_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_default");
|
||||
const sw = c.LLVMBuildSwitch(self.builder, sw_val, default_bb, @intCast(n));
|
||||
|
||||
var phi_vals = std.ArrayList(c.LLVMValueRef).empty;
|
||||
var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty;
|
||||
var seen_values = std.ArrayList(u64).empty;
|
||||
|
||||
for (0..n) |i| {
|
||||
const explicit_val: u64 = if (values) |vals| @bitCast(vals[i]) else i;
|
||||
// Skip duplicate values (first one wins)
|
||||
var is_dup = false;
|
||||
for (seen_values.items) |sv| {
|
||||
if (sv == explicit_val) { is_dup = true; break; }
|
||||
}
|
||||
if (is_dup) continue;
|
||||
try seen_values.append(self.allocator, explicit_val);
|
||||
const case_bb = c.LLVMAppendBasicBlockInContext(self.context, function, "fi_case");
|
||||
c.LLVMAddCase(sw, c.LLVMConstInt(enum_llvm_ty, explicit_val, 0), case_bb);
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, case_bb);
|
||||
try phi_vals.append(self.allocator, c.LLVMConstInt(i64_type, i, 0));
|
||||
try phi_bbs.append(self.allocator, case_bb);
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
}
|
||||
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, default_bb);
|
||||
const neg_one = c.LLVMConstInt(i64_type, @bitCast(@as(i64, -1)), 0);
|
||||
try phi_vals.append(self.allocator, neg_one);
|
||||
try phi_bbs.append(self.allocator, default_bb);
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
|
||||
c.LLVMPositionBuilderAtEnd(self.builder, merge_bb);
|
||||
const vals_slice = try phi_vals.toOwnedSlice(self.allocator);
|
||||
const bbs_slice = try phi_bbs.toOwnedSlice(self.allocator);
|
||||
const phi = c.LLVMBuildPhi(self.builder, i64_type, "fi_result");
|
||||
c.LLVMAddIncoming(phi, vals_slice.ptr, bbs_slice.ptr, @intCast(vals_slice.len));
|
||||
return phi;
|
||||
}
|
||||
|
||||
fn genCast(self: *CodeGen, call_node: ast.Call) !c.LLVMValueRef {
|
||||
if (call_node.args.len != 2) return self.emitError("cast expects: cast(Type) expr");
|
||||
const target_ty = self.resolveType(call_node.args[0]);
|
||||
@@ -4505,8 +4844,8 @@ pub const CodeGen = struct {
|
||||
const idx = vidx orelse return self.emitErrorFmt("no variant '{s}' in enum '{s}'", .{ fa.field, uname });
|
||||
const variant_ty = info.variant_types[idx];
|
||||
if (variant_ty == .void_type) return self.emitErrorFmt("cannot access payload of void variant '{s}'", .{fa.field});
|
||||
// GEP to field 1 (payload area), load as variant type
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, 1, "payload");
|
||||
// GEP to payload area, load as variant type
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, info.llvm_type, entry.ptr, info.payload_field_index, "payload");
|
||||
return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(variant_ty), payload_gep, "union_payload");
|
||||
}
|
||||
if (entry.ty.isVector()) {
|
||||
@@ -6123,12 +6462,9 @@ pub const CodeGen = struct {
|
||||
// Determine subject type for enum vs union dispatch
|
||||
var enum_name: ?[]const u8 = null;
|
||||
var union_name: ?[]const u8 = null;
|
||||
if (match.subject.data == .identifier) {
|
||||
if (self.named_values.get(match.subject.data.identifier.name)) |entry| {
|
||||
if (entry.ty.isEnum()) enum_name = entry.ty.enum_type;
|
||||
if (entry.ty.isUnion()) union_name = entry.ty.union_type;
|
||||
}
|
||||
}
|
||||
const subject_ty = self.inferType(match.subject);
|
||||
if (subject_ty.isEnum()) enum_name = subject_ty.enum_type;
|
||||
if (subject_ty.isUnion()) union_name = subject_ty.union_type;
|
||||
|
||||
// Get the switch value: for unions, load the tag from field 0; for enums, use the value directly
|
||||
const subject_val: c.LLVMValueRef = if (union_name != null) blk: {
|
||||
@@ -6224,6 +6560,32 @@ pub const CodeGen = struct {
|
||||
// Category/type arm with no matching types — BB is unreachable, skip body
|
||||
_ = c.LLVMBuildBr(self.builder, merge_bb);
|
||||
} else {
|
||||
// Payload capture: bind variant payload as a local variable
|
||||
if (arm.capture) |cap_name| {
|
||||
if (union_name) |un| {
|
||||
const uinfo = self.tagged_enum_types.get(un).?;
|
||||
const pat = arm.pattern.?;
|
||||
if (pat.data == .enum_literal) {
|
||||
const vname = pat.data.enum_literal.name;
|
||||
var vidx: ?usize = null;
|
||||
for (uinfo.variant_names, 0..) |vn, vi| {
|
||||
if (std.mem.eql(u8, vn, vname)) { vidx = vi; break; }
|
||||
}
|
||||
if (vidx) |vi| {
|
||||
const variant_ty = uinfo.variant_types[vi];
|
||||
if (variant_ty != .void_type) {
|
||||
const subject_entry = self.named_values.get(match.subject.data.identifier.name).?;
|
||||
const payload_gep = c.LLVMBuildStructGEP2(self.builder, uinfo.llvm_type, subject_entry.ptr, uinfo.payload_field_index, "cap_payload");
|
||||
const payload_llvm_ty = self.typeToLLVM(variant_ty);
|
||||
const payload_val = c.LLVMBuildLoad2(self.builder, payload_llvm_ty, payload_gep, "cap_load");
|
||||
const cap_alloca = c.LLVMBuildAlloca(self.builder, payload_llvm_ty, @ptrCast(cap_name.ptr));
|
||||
_ = c.LLVMBuildStore(self.builder, payload_val, cap_alloca);
|
||||
try self.named_values.put(cap_name, .{ .ptr = cap_alloca, .ty = variant_ty });
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
// Set match arm context for runtime type dispatch
|
||||
const saved_match_tags = self.current_match_tags;
|
||||
self.current_match_tags = arm_tag_values.items[i];
|
||||
@@ -6428,6 +6790,7 @@ pub const CodeGen = struct {
|
||||
if (std.mem.eql(u8, base, "field_value")) return self.genFieldValue(call_node);
|
||||
if (std.mem.eql(u8, base, "is_flags")) return self.genIsFlags(call_node);
|
||||
if (std.mem.eql(u8, base, "field_value_int")) return self.genFieldValueInt(call_node);
|
||||
if (std.mem.eql(u8, base, "field_index")) return self.genFieldIndex(call_node);
|
||||
return self.emitErrorFmt("unknown builtin function '{s}'", .{name});
|
||||
}
|
||||
|
||||
@@ -6663,6 +7026,8 @@ pub const CodeGen = struct {
|
||||
if (std.mem.eql(u8, base_name, "is_flags")) return .boolean;
|
||||
// Built-in: field_value_int returns s64
|
||||
if (std.mem.eql(u8, base_name, "field_value_int")) return Type.s(64);
|
||||
// Built-in: field_index returns s64
|
||||
if (std.mem.eql(u8, base_name, "field_index")) return Type.s(64);
|
||||
// Built-in: cast returns the target type (first arg)
|
||||
if (std.mem.eql(u8, base_name, "cast")) {
|
||||
if (call_node.args.len > 0) return self.resolveType(call_node.args[0]);
|
||||
|
||||
@@ -493,12 +493,75 @@ pub const Server = struct {
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Bare dot (no prefix) — check if we're inside a match expression
|
||||
// and offer the subject's enum variants (e.g. case .quit)
|
||||
try self.collectMatchEnumCompletions(&items, doc, cursor_offset);
|
||||
}
|
||||
|
||||
const items_json = try lsp.completionListJson(self.allocator, items.items);
|
||||
try self.sendResponse(id_json, items_json);
|
||||
}
|
||||
|
||||
fn collectMatchEnumCompletions(self: *Server, items: *std.ArrayList(lsp.CompletionItem), doc: *const Document, cursor_offset: u32) !void {
|
||||
const root = doc.root orelse return;
|
||||
const sema = doc.sema orelse return;
|
||||
|
||||
// Find enclosing match expression's subject
|
||||
const subject = sx.sema.findEnclosingMatchSubject(root, cursor_offset) orelse return;
|
||||
|
||||
// Resolve the subject to an enum type name
|
||||
const enum_name: ?[]const u8 = switch (subject.data) {
|
||||
.identifier => |id| blk: {
|
||||
// Look up variable type, then check if it's an enum
|
||||
for (sema.symbols) |sym| {
|
||||
if (!std.mem.eql(u8, sym.name, id.name)) continue;
|
||||
if (sym.kind != .variable and sym.kind != .param) continue;
|
||||
const ty = sym.ty orelse break;
|
||||
break :blk switch (ty) {
|
||||
.enum_type => |n| n,
|
||||
.union_type => |n| n,
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
break :blk null;
|
||||
},
|
||||
.field_access => |fa| blk: {
|
||||
// e.g. e.key — resolve the field's type
|
||||
if (fa.object.data == .identifier) {
|
||||
const var_name = fa.object.data.identifier.name;
|
||||
// Find variable's struct type, then look up the field type
|
||||
for (sema.symbols) |sym| {
|
||||
if (!std.mem.eql(u8, sym.name, var_name)) continue;
|
||||
const ty = sym.ty orelse break;
|
||||
const struct_name = switch (ty) {
|
||||
.struct_type => |n| n,
|
||||
else => break,
|
||||
};
|
||||
// Look up the struct's field type
|
||||
if (sema.struct_types.get(struct_name)) |info| {
|
||||
for (info.field_names, 0..) |fname, fi| {
|
||||
if (std.mem.eql(u8, fname, fa.field) and fi < info.field_types.len) {
|
||||
break :blk switch (info.field_types[fi]) {
|
||||
.enum_type => |n| n,
|
||||
.union_type => |n| n,
|
||||
else => null,
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
break;
|
||||
}
|
||||
}
|
||||
break :blk null;
|
||||
},
|
||||
else => null,
|
||||
};
|
||||
|
||||
const name = enum_name orelse return;
|
||||
try self.collectMemberCompletions(items, sema, root, name);
|
||||
}
|
||||
|
||||
fn collectDeclCompletions(allocator: std.mem.Allocator, items: *std.ArrayList(lsp.CompletionItem), decls: []const *sx.ast.Node) !void {
|
||||
for (decls) |decl| {
|
||||
switch (decl.data) {
|
||||
@@ -1643,6 +1706,33 @@ pub const Server = struct {
|
||||
if (bt.data == .type_expr) {
|
||||
try buf.appendSlice(allocator, bt.data.type_expr.name);
|
||||
try buf.appendSlice(allocator, " ");
|
||||
} else if (bt.data == .struct_decl) {
|
||||
const sd = bt.data.struct_decl;
|
||||
try buf.appendSlice(allocator, "struct { ");
|
||||
for (sd.field_names, 0..) |fn_, fi| {
|
||||
if (fi > 0) try buf.appendSlice(allocator, "; ");
|
||||
try buf.appendSlice(allocator, fn_);
|
||||
try buf.appendSlice(allocator, ": ");
|
||||
if (fi < sd.field_types.len) {
|
||||
if (sd.field_types[fi].data == .type_expr) {
|
||||
try buf.appendSlice(allocator, sd.field_types[fi].data.type_expr.name);
|
||||
} else if (sd.field_types[fi].data == .array_type_expr) {
|
||||
const ate = sd.field_types[fi].data.array_type_expr;
|
||||
try buf.append(allocator, '[');
|
||||
if (ate.length.data == .int_literal) {
|
||||
const val = ate.length.data.int_literal.value;
|
||||
var num_buf: [20]u8 = undefined;
|
||||
const num_str = std.fmt.bufPrint(&num_buf, "{d}", .{val}) catch "?";
|
||||
try buf.appendSlice(allocator, num_str);
|
||||
}
|
||||
try buf.append(allocator, ']');
|
||||
if (ate.element_type.data == .type_expr) {
|
||||
try buf.appendSlice(allocator, ate.element_type.data.type_expr.name);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
try buf.appendSlice(allocator, " } ");
|
||||
}
|
||||
}
|
||||
try buf.appendSlice(allocator, "{ ");
|
||||
|
||||
@@ -408,15 +408,23 @@ pub const Parser = struct {
|
||||
try variant_names.append(self.allocator, self.tokenSlice(self.current));
|
||||
self.advance();
|
||||
if (self.current.tag == .colon_colon) {
|
||||
// Explicit value: name :: expr;
|
||||
if (!is_flags) {
|
||||
return self.fail("explicit enum values require 'enum flags'");
|
||||
}
|
||||
// Explicit value: name :: expr; or name :: expr: type;
|
||||
self.advance();
|
||||
const val_expr = try self.parseExpr();
|
||||
try variant_values.append(self.allocator, val_expr);
|
||||
try variant_types.append(self.allocator, null);
|
||||
has_any_value = true;
|
||||
// Check for payload type after value: name :: 0x300: KeyData
|
||||
if (self.current.tag == .colon) {
|
||||
if (is_flags) {
|
||||
return self.fail("flags enum variants cannot have payloads");
|
||||
}
|
||||
self.advance();
|
||||
const vtype = try self.parseTypeExpr();
|
||||
try variant_types.append(self.allocator, vtype);
|
||||
has_any_type = true;
|
||||
} else {
|
||||
try variant_types.append(self.allocator, null);
|
||||
}
|
||||
} else if (self.current.tag == .colon) {
|
||||
// Typed variant: name: type;
|
||||
if (is_flags) {
|
||||
@@ -553,7 +561,12 @@ pub const Parser = struct {
|
||||
|
||||
// All names in the group share the same type and default
|
||||
for (group_names.items) |fname| {
|
||||
try field_names.append(self.allocator, fname);
|
||||
// `_` is an ignore identifier — auto-rename to unique internal name
|
||||
const actual_name = if (std.mem.eql(u8, fname, "_"))
|
||||
try std.fmt.allocPrint(self.allocator, "_{d}", .{field_names.items.len})
|
||||
else
|
||||
fname;
|
||||
try field_names.append(self.allocator, actual_name);
|
||||
try field_types.append(self.allocator, field_type);
|
||||
try field_defaults.append(self.allocator, default_val);
|
||||
}
|
||||
@@ -892,6 +905,24 @@ pub const Parser = struct {
|
||||
return try self.createNode(start, .{ .insert_expr = .{ .expr = inner } });
|
||||
}
|
||||
|
||||
// Block-form if/while/for as statements — parse directly to prevent
|
||||
// postfix chaining (e.g. `if cond { ... }.field` being misparsed)
|
||||
if (self.current.tag == .kw_if) {
|
||||
const expr = try self.parseIfExpr();
|
||||
try self.expectSemicolonAfter(expr);
|
||||
return expr;
|
||||
}
|
||||
if (self.current.tag == .kw_while) {
|
||||
const expr = try self.parsePrimary();
|
||||
try self.expectSemicolonAfter(expr);
|
||||
return expr;
|
||||
}
|
||||
if (self.current.tag == .kw_for) {
|
||||
const expr = try self.parsePrimary();
|
||||
try self.expectSemicolonAfter(expr);
|
||||
return expr;
|
||||
}
|
||||
|
||||
// Expression statement
|
||||
const expr = try self.parseExpr();
|
||||
|
||||
@@ -1427,11 +1458,28 @@ pub const Parser = struct {
|
||||
} else try self.parsePrimary(); // .variant
|
||||
try self.expect(.colon);
|
||||
|
||||
// Optional payload capture: (ident)
|
||||
var capture: ?[]const u8 = null;
|
||||
if (self.current.tag == .l_paren) {
|
||||
self.advance();
|
||||
if (self.current.tag != .identifier) return self.fail("expected capture name");
|
||||
capture = self.tokenSlice(self.current);
|
||||
self.advance();
|
||||
try self.expect(.r_paren);
|
||||
}
|
||||
|
||||
if (self.current.tag == .kw_break) {
|
||||
self.advance();
|
||||
try self.expect(.semicolon);
|
||||
const body = try self.createNode(arm_start, .{ .block = .{ .stmts = &.{} } });
|
||||
try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = true });
|
||||
try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = true, .capture = capture });
|
||||
} else if (self.current.tag == .fat_arrow) {
|
||||
// Short form: (ident) => expr;
|
||||
self.advance();
|
||||
const expr = try self.parseExpr();
|
||||
try self.expect(.semicolon);
|
||||
const body = try self.createNode(arm_start, .{ .block = .{ .stmts = try self.allocator.dupe(*Node, &.{expr}) } });
|
||||
try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = false, .capture = capture });
|
||||
} else {
|
||||
const stmts_start = self.current.loc.start;
|
||||
var stmts = std.ArrayList(*Node).empty;
|
||||
@@ -1439,7 +1487,7 @@ pub const Parser = struct {
|
||||
try stmts.append(self.allocator, try self.parseStmt());
|
||||
}
|
||||
const body = try self.createNode(stmts_start, .{ .block = .{ .stmts = try stmts.toOwnedSlice(self.allocator) } });
|
||||
try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = false });
|
||||
try arms.append(self.allocator, .{ .pattern = pattern, .body = body, .is_break = false, .capture = capture });
|
||||
}
|
||||
}
|
||||
// Optional else arm (default)
|
||||
|
||||
92
src/sema.zig
92
src/sema.zig
@@ -148,7 +148,7 @@ pub const Analyzer = struct {
|
||||
});
|
||||
},
|
||||
.const_decl => |cd| {
|
||||
const ty = resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value);
|
||||
const ty = self.resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value);
|
||||
const kind = classifyConstDecl(cd);
|
||||
try self.addSymbol(cd.name, kind, ty, node.span);
|
||||
// Populate type_aliases registry
|
||||
@@ -175,7 +175,7 @@ pub const Analyzer = struct {
|
||||
}
|
||||
},
|
||||
.var_decl => |vd| {
|
||||
const ty = resolveTypeAnnotation(vd.type_annotation);
|
||||
const ty = self.resolveTypeAnnotation(vd.type_annotation);
|
||||
try self.addSymbol(vd.name, .variable, ty, node.span);
|
||||
},
|
||||
.enum_decl => |ed| {
|
||||
@@ -581,7 +581,7 @@ pub const Analyzer = struct {
|
||||
.const_decl => |cd| {
|
||||
// Analyze value first (so it can't reference itself)
|
||||
try self.analyzeNode(cd.value);
|
||||
const ty = resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value);
|
||||
const ty = self.resolveTypeAnnotation(cd.type_annotation) orelse inferValueType(cd.value);
|
||||
const kind = classifyConstDecl(cd);
|
||||
try self.addSymbol(cd.name, kind, ty, node.span);
|
||||
},
|
||||
@@ -589,7 +589,7 @@ pub const Analyzer = struct {
|
||||
if (vd.value) |val| {
|
||||
try self.analyzeNode(val);
|
||||
}
|
||||
const ty = resolveTypeAnnotation(vd.type_annotation) orelse
|
||||
const ty = self.resolveTypeAnnotation(vd.type_annotation) orelse
|
||||
if (vd.value) |val| self.inferExprType(val) else null;
|
||||
try self.addSymbol(vd.name, .variable, ty, node.span);
|
||||
},
|
||||
@@ -637,7 +637,12 @@ pub const Analyzer = struct {
|
||||
.match_expr => |me| {
|
||||
try self.analyzeNode(me.subject);
|
||||
for (me.arms) |arm| {
|
||||
try self.pushScope();
|
||||
if (arm.capture) |cap_name| {
|
||||
try self.addSymbol(cap_name, .variable, null, arm.body.span);
|
||||
}
|
||||
try self.analyzeNode(arm.body);
|
||||
self.popScope();
|
||||
}
|
||||
},
|
||||
.while_expr => |we| {
|
||||
@@ -773,9 +778,19 @@ pub const Analyzer = struct {
|
||||
return null;
|
||||
}
|
||||
|
||||
fn resolveTypeAnnotation(type_node: ?*Node) ?Type {
|
||||
fn resolveTypeAnnotation(self: *Analyzer, type_node: ?*Node) ?Type {
|
||||
if (type_node) |tn| {
|
||||
return Type.fromTypeExpr(tn);
|
||||
if (Type.fromTypeExpr(tn)) |t| return t;
|
||||
// Check registered types (structs, enums, tagged enums)
|
||||
if (tn.data == .type_expr) {
|
||||
const name = tn.data.type_expr.name;
|
||||
// Check type aliases first
|
||||
const resolved = self.type_aliases.get(name) orelse name;
|
||||
for (self.symbols.items) |sym| {
|
||||
if (!std.mem.eql(u8, sym.name, resolved)) continue;
|
||||
if (sym.ty) |ty| return ty;
|
||||
}
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
@@ -986,6 +1001,71 @@ pub fn findNodeAtOffset(node: *Node, offset: u32) ?*Node {
|
||||
return node;
|
||||
}
|
||||
|
||||
/// Find the nearest match_expr ancestor that contains the given offset.
|
||||
/// Returns the match subject node if found, null otherwise.
|
||||
pub fn findEnclosingMatchSubject(node: *Node, offset: u32) ?*Node {
|
||||
if (offset < node.span.start or offset >= node.span.end) return null;
|
||||
|
||||
switch (node.data) {
|
||||
.match_expr => |me| {
|
||||
// First recurse into arm bodies — there might be a nested match
|
||||
for (me.arms) |arm| {
|
||||
if (findEnclosingMatchSubject(arm.body, offset)) |inner| return inner;
|
||||
}
|
||||
// If offset is inside this match_expr (but not in the subject itself),
|
||||
// it's in an arm pattern, between arms, or in a partially-typed arm
|
||||
if (me.subject.span.start <= offset and offset < me.subject.span.end) {
|
||||
// Cursor is on the subject itself, not in an arm
|
||||
} else {
|
||||
return me.subject;
|
||||
}
|
||||
},
|
||||
.root => |r| {
|
||||
for (r.decls) |decl| {
|
||||
if (findEnclosingMatchSubject(decl, offset)) |found| return found;
|
||||
}
|
||||
},
|
||||
.fn_decl => |fd| {
|
||||
if (findEnclosingMatchSubject(fd.body, offset)) |found| return found;
|
||||
},
|
||||
.block => |blk| {
|
||||
for (blk.stmts) |stmt| {
|
||||
if (findEnclosingMatchSubject(stmt, offset)) |found| return found;
|
||||
}
|
||||
},
|
||||
.if_expr => |ie| {
|
||||
if (findEnclosingMatchSubject(ie.then_branch, offset)) |found| return found;
|
||||
if (ie.else_branch) |eb| {
|
||||
if (findEnclosingMatchSubject(eb, offset)) |found| return found;
|
||||
}
|
||||
},
|
||||
.while_expr => |we| {
|
||||
if (findEnclosingMatchSubject(we.body, offset)) |found| return found;
|
||||
},
|
||||
.for_expr => |fe| {
|
||||
if (findEnclosingMatchSubject(fe.body, offset)) |found| return found;
|
||||
},
|
||||
.const_decl => |cd| {
|
||||
if (findEnclosingMatchSubject(cd.value, offset)) |found| return found;
|
||||
},
|
||||
.var_decl => |vd| {
|
||||
if (vd.value) |val| {
|
||||
if (findEnclosingMatchSubject(val, offset)) |found| return found;
|
||||
}
|
||||
},
|
||||
.lambda => |lam| {
|
||||
if (findEnclosingMatchSubject(lam.body, offset)) |found| return found;
|
||||
},
|
||||
.namespace_decl => |ns| {
|
||||
for (ns.decls) |decl| {
|
||||
if (findEnclosingMatchSubject(decl, offset)) |found| return found;
|
||||
}
|
||||
},
|
||||
else => {},
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
test "sema: collect top-level declarations" {
|
||||
const parser_mod = @import("parser.zig");
|
||||
|
||||
|
||||
Reference in New Issue
Block a user