Files
sx/src/types.zig
2026-02-10 22:47:43 +02:00

399 lines
13 KiB
Zig
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
const std = @import("std");
const ast = @import("ast.zig");
const Node = ast.Node;
/// Language-level type used for implicit-conversion checks, arithmetic
/// promotion (`widen`) and name mangling (`displayName`).
///
/// Name-bearing payloads (`enum_type`/`struct_type`/`union_type` and the
/// `*_name` / `name` fields of the info structs) are borrowed slices; nothing
/// in this type copies or frees them. Compare values with `Type.eql`, which
/// compares those names by content. `std.meta.eql` compares slices by pointer
/// and length only, so equal names from different buffers would compare unequal.
pub const Type = union(enum) {
    // Variable-width integers; the payload is the bit width (`fromName`
    // accepts 1-64, `widen` may produce wider signed results).
    signed: u8,
    unsigned: u8,
    // Fixed-width floats
    f32,
    f64,
    // Other
    void_type,
    boolean,
    string_type,
    enum_type: []const u8,
    struct_type: []const u8,
    union_type: []const u8,
    array_type: ArrayTypeInfo,
    slice_type: SliceTypeInfo,
    pointer_type: PointerTypeInfo,
    many_pointer_type: ManyPointerTypeInfo,
    vector_type: VectorTypeInfo,
    any_type,
    meta_type: MetaTypeInfo,

    /// `[]T`: the element type is referenced by name.
    pub const SliceTypeInfo = struct {
        element_name: []const u8,
    };

    /// `*T`: the pointee type is referenced by name.
    pub const PointerTypeInfo = struct {
        pointee_name: []const u8,
    };

    /// `[*]T`: the element type is referenced by name.
    pub const ManyPointerTypeInfo = struct {
        element_name: []const u8,
    };

    /// `[N]T`: element type name plus fixed length.
    pub const ArrayTypeInfo = struct {
        element_name: []const u8,
        length: u32,
    };

    /// `Vector(N, T)`: element type name plus lane count.
    pub const VectorTypeInfo = struct {
        element_name: []const u8,
        length: u32,
    };

    /// A type-valued entity, identified by name.
    pub const MetaTypeInfo = struct {
        name: []const u8,
    };

    // Convenience constructors

    /// Construct a signed integer type of the given bit width.
    pub fn s(width: u8) Type {
        return .{ .signed = width };
    }

    /// Construct an unsigned integer type of the given bit width.
    pub fn u(width: u8) Type {
        return .{ .unsigned = width };
    }

    /// Resolve a built-in type name, or return null if `name` is not one.
    ///
    /// Recognized: `string`, `bool`, `f32`, `f64`, `Any`, many-pointers `[*]T`,
    /// pointers `*T`, and integers `s1`..`s64` / `u1`..`u64` written as plain
    /// decimal without a leading zero. Pointer payloads borrow a sub-slice of
    /// `name`, so `name` must outlive the returned value. Slice, array, vector
    /// and user-defined names are not resolved here (null is returned).
    pub fn fromName(name: []const u8) ?Type {
        // Named types (check before variable-width integers since "string" starts with 's')
        if (std.mem.eql(u8, name, "string")) return .string_type;
        if (std.mem.eql(u8, name, "bool")) return .boolean;
        if (std.mem.eql(u8, name, "f32")) return .f32;
        if (std.mem.eql(u8, name, "f64")) return .f64;
        if (std.mem.eql(u8, name, "Any")) return .any_type;

        // Many-pointer: [*]T (requires a non-empty element name)
        if (name.len >= 4 and std.mem.startsWith(u8, name, "[*]")) {
            return .{ .many_pointer_type = .{ .element_name = name[3..] } };
        }

        // Pointer: *T (requires a non-empty pointee name)
        if (name.len >= 2 and name[0] == '*') {
            return .{ .pointer_type = .{ .pointee_name = name[1..] } };
        }

        // Variable-width integers: s1..s64, u1..u64
        if (name.len >= 2 and (name[0] == 's' or name[0] == 'u')) {
            const digits = name[1..];
            // Accept plain decimal digits only. `std.fmt.parseInt` alone would
            // also accept a sign prefix and `_` separators ("u+8", "u1_6"), and
            // a leading zero ("u08") would alias "u8" under another spelling.
            if (digits[0] == '0') return null;
            for (digits) |c| {
                if (!std.ascii.isDigit(c)) return null;
            }
            const width = std.fmt.parseInt(u8, digits, 10) catch return null;
            if (width < 1 or width > 64) return null;
            return if (name[0] == 's') Type.s(width) else Type.u(width);
        }

        return null;
    }

    /// Resolve a `type_expr` AST node via `fromName`; null for any other node kind.
    pub fn fromTypeExpr(node: *Node) ?Type {
        if (node.data != .type_expr) return null;
        return fromName(node.data.type_expr.name);
    }

    /// Structural equality: same variant and equal payloads, with every name
    /// slice compared by content rather than by address.
    pub fn eql(self: Type, other: Type) bool {
        if (std.meta.activeTag(self) != std.meta.activeTag(other)) return false;
        return switch (self) {
            .signed => |w| w == other.signed,
            .unsigned => |w| w == other.unsigned,
            .f32, .f64, .void_type, .boolean, .string_type, .any_type => true,
            .enum_type => |name| std.mem.eql(u8, name, other.enum_type),
            .struct_type => |name| std.mem.eql(u8, name, other.struct_type),
            .union_type => |name| std.mem.eql(u8, name, other.union_type),
            .array_type => |info| info.length == other.array_type.length and
                std.mem.eql(u8, info.element_name, other.array_type.element_name),
            .slice_type => |info| std.mem.eql(u8, info.element_name, other.slice_type.element_name),
            .pointer_type => |info| std.mem.eql(u8, info.pointee_name, other.pointer_type.pointee_name),
            .many_pointer_type => |info| std.mem.eql(u8, info.element_name, other.many_pointer_type.element_name),
            .vector_type => |info| info.length == other.vector_type.length and
                std.mem.eql(u8, info.element_name, other.vector_type.element_name),
            .meta_type => |info| std.mem.eql(u8, info.name, other.meta_type.name),
        };
    }

    pub fn isEnum(self: Type) bool {
        return self == .enum_type;
    }

    pub fn isStruct(self: Type) bool {
        return self == .struct_type;
    }

    pub fn isUnion(self: Type) bool {
        return self == .union_type;
    }

    pub fn isAny(self: Type) bool {
        return self == .any_type;
    }

    pub fn isSlice(self: Type) bool {
        return self == .slice_type;
    }

    /// Element type of a slice, resolved through `fromName` (so only built-in
    /// element names resolve). Null for non-slices or unresolved names.
    pub fn sliceElementType(self: Type) ?Type {
        return switch (self) {
            .slice_type => |info| fromName(info.element_name),
            else => null,
        };
    }

    pub fn isPointer(self: Type) bool {
        return self == .pointer_type;
    }

    /// Pointee type of a `*T`, resolved through `fromName` (so only built-in
    /// names resolve). Null for non-pointers or unresolved names.
    pub fn pointerPointeeType(self: Type) ?Type {
        return switch (self) {
            .pointer_type => |info| fromName(info.pointee_name),
            else => null,
        };
    }

    pub fn isManyPointer(self: Type) bool {
        return self == .many_pointer_type;
    }

    /// Element type of a `[*]T`, resolved through `fromName` (so only built-in
    /// names resolve). Null for non-many-pointers or unresolved names.
    pub fn manyPointerElementType(self: Type) ?Type {
        return switch (self) {
            .many_pointer_type => |info| fromName(info.element_name),
            else => null,
        };
    }

    pub fn isArray(self: Type) bool {
        return self == .array_type;
    }

    pub fn isVector(self: Type) bool {
        return self == .vector_type;
    }

    /// Element type of a vector, resolved through `fromName` (so only built-in
    /// names resolve). Null for non-vectors or unresolved names.
    pub fn vectorElementType(self: Type) ?Type {
        return switch (self) {
            .vector_type => |info| fromName(info.element_name),
            else => null,
        };
    }

    pub fn isFloat(self: Type) bool {
        return switch (self) {
            .f32, .f64 => true,
            else => false,
        };
    }

    pub fn isInt(self: Type) bool {
        return self.isSigned() or self.isUnsigned();
    }

    pub fn isSigned(self: Type) bool {
        return self == .signed;
    }

    pub fn isUnsigned(self: Type) bool {
        return self == .unsigned;
    }

    /// Bit width of scalar types (ints, floats, bool = 1); 0 for everything else.
    pub fn bitWidth(self: Type) u32 {
        return switch (self) {
            .signed => |w| w,
            .unsigned => |w| w,
            .f32 => 32,
            .f64 => 64,
            .boolean => 1,
            else => 0,
        };
    }

    /// Check if this type can be implicitly converted to `target` without `xx`.
    /// Safe (implicit) conversions:
    /// - Same type (per `eql`, i.e. names compared by content)
    /// - `[]T` -> `[]T`, `*T` -> `*T`, `[*]T` -> `[*]T` and `*T` -> `[*]T`
    ///   (element/pointee names compared by content)
    /// - `*void` -> any `*T`
    /// - Both unsigned int, target width >= source width
    /// - Both signed int, target width >= source width
    /// - Unsigned to signed, target width strictly > source width
    /// - Any int to any float
    /// - Float to wider-or-equal float (f32 → f64)
    /// Everything else requires `xx`.
    pub fn isImplicitlyConvertibleTo(self: Type, target: Type) bool {
        // Content-based equality: `std.meta.eql` would compare the name slices
        // of enum/struct/union/array/vector/meta types by address.
        if (self.eql(target)) return true;

        // Slice types: compare element names by content (not pointer)
        if (self.isSlice() and target.isSlice()) {
            return std.mem.eql(u8, self.slice_type.element_name, target.slice_type.element_name);
        }
        // Pointer types: compare pointee names by content, null (*void) → any pointer
        if (self.isPointer() and target.isPointer()) {
            if (std.mem.eql(u8, self.pointer_type.pointee_name, "void")) return true;
            return std.mem.eql(u8, self.pointer_type.pointee_name, target.pointer_type.pointee_name);
        }
        // Many-pointer types: compare element names by content
        if (self.isManyPointer() and target.isManyPointer()) {
            return std.mem.eql(u8, self.many_pointer_type.element_name, target.many_pointer_type.element_name);
        }
        // *T → [*]T: pointer to element is implicitly convertible to many-pointer
        if (self.isPointer() and target.isManyPointer()) {
            return std.mem.eql(u8, self.pointer_type.pointee_name, target.many_pointer_type.element_name);
        }

        const src_float = self.isFloat();
        const dst_float = target.isFloat();
        const src_int = self.isInt();

        // Float → wider float
        if (src_float and dst_float) {
            return target.bitWidth() >= self.bitWidth();
        }
        // Int → float (always safe)
        if (src_int and dst_float) return true;
        // Both unsigned → target width >= source width
        if (self.isUnsigned() and target.isUnsigned()) {
            return target.bitWidth() >= self.bitWidth();
        }
        // Both signed → target width >= source width
        if (self.isSigned() and target.isSigned()) {
            return target.bitWidth() >= self.bitWidth();
        }
        // Unsigned → signed: target must be strictly wider
        if (self.isUnsigned() and target.isSigned()) {
            return target.bitWidth() > self.bitWidth();
        }
        // Everything else requires xx
        return false;
    }

    /// Format type name for mangling and display (e.g. "s32", "u8", "f64",
    /// "[]u8", "*Foo", "[*]u8", "[4]f32", "Vector(4,f32)").
    ///
    /// Ownership is mixed: integer, slice, pointer, many-pointer, array and
    /// vector names are freshly allocated with `allocator`, while the other
    /// variants return a string literal or the stored (borrowed) name. Callers
    /// therefore cannot uniformly free the result; an arena allocator suits
    /// this API. The only possible error is `error.OutOfMemory`.
    pub fn displayName(self: Type, allocator: std.mem.Allocator) ![]const u8 {
        // `allocPrint` frees its own buffer on failure, unlike the previous
        // multi-step ArrayList building, which leaked on a failing second append.
        return switch (self) {
            .signed => |w| try std.fmt.allocPrint(allocator, "s{d}", .{w}),
            .unsigned => |w| try std.fmt.allocPrint(allocator, "u{d}", .{w}),
            .f32 => "f32",
            .f64 => "f64",
            .boolean => "bool",
            .string_type => "string",
            .void_type => "void",
            .any_type => "Any",
            .enum_type => |name| name,
            .struct_type => |name| name,
            .union_type => |name| name,
            .slice_type => |info| try std.fmt.allocPrint(allocator, "[]{s}", .{info.element_name}),
            .pointer_type => |info| try std.fmt.allocPrint(allocator, "*{s}", .{info.pointee_name}),
            .many_pointer_type => |info| try std.fmt.allocPrint(allocator, "[*]{s}", .{info.element_name}),
            .array_type => |info| try std.fmt.allocPrint(allocator, "[{d}]{s}", .{ info.length, info.element_name }),
            .vector_type => |info| try std.fmt.allocPrint(allocator, "Vector({d},{s})", .{ info.length, info.element_name }),
            .meta_type => |info| info.name,
        };
    }

    /// Widen two types to a common type for binary operations.
    /// Used for arithmetic type promotion (e.g., s16 + s32 → s32, int + float → float,
    /// s8 + u16 → s17). Pairs with no promotion rule return `a` unchanged.
    pub fn widen(a: Type, b: Type) Type {
        // Same type → return it (names compared by content)
        if (a.eql(b)) return a;

        // A vector operand wins over a scalar (the scalar is broadcast). For
        // vector + vector, `a` is returned; dimensions are NOT checked here.
        if (a.isVector()) return a;
        if (b.isVector()) return b;

        const a_float = a.isFloat();
        const b_float = b.isFloat();
        const a_int = a.isInt();
        const b_int = b.isInt();

        // Both float → wider float
        if (a_float and b_float) {
            return if (a.bitWidth() >= b.bitWidth()) a else b;
        }
        // int + float → float
        if (a_int and b_float) return b;
        if (b_int and a_float) return a;

        // Both signed → wider signed
        if (a.isSigned() and b.isSigned()) {
            return Type.s(@intCast(@max(a.bitWidth(), b.bitWidth())));
        }
        // Both unsigned → wider unsigned
        if (a.isUnsigned() and b.isUnsigned()) {
            return Type.u(@intCast(@max(a.bitWidth(), b.bitWidth())));
        }

        // signed + unsigned (mixed): exactly one operand is signed here.
        if (a_int and b_int) {
            const signed_w = if (a.isSigned()) a.bitWidth() else b.bitWidth();
            const unsigned_w = if (a.isUnsigned()) a.bitWidth() else b.bitWidth();
            // Every uN value needs s(N+1); the signed operand needs its own
            // width. This matches `isImplicitlyConvertibleTo` (uN → sM needs M > N).
            const need: u32 = @max(signed_w, unsigned_w + 1);
            // Clamp keeps the width representable in `u8`; results above 128
            // bits lose range (unreachable for widths accepted by `fromName`).
            const capped: u8 = @intCast(@min(need, 128));
            return Type.s(capped);
        }

        return a;
    }
};