From 1cc67f9b5a29f7b59eff1e4542ff16395d6cebf9 Mon Sep 17 00:00:00 2001 From: agra Date: Sun, 22 Feb 2026 22:16:30 +0200 Subject: [PATCH] optionals --- examples/26-pointers.sx | 2 +- examples/32-optionals.sx | 97 ++++ examples/50-smoke.sx | 239 ++++++++++ specs.md | 116 ++++- src/ast.zig | 19 + src/codegen.zig | 789 ++++++++++++++++++++++++++++++- src/comptime.zig | 70 ++- src/lexer.zig | 11 + src/lsp/server.zig | 324 +++++++++++++ src/lsp/types.zig | 29 +- src/parser.zig | 93 +++- src/sema.zig | 79 +++- src/token.zig | 6 + src/types.zig | 49 ++ tests/expected/32-optionals.exit | 1 + tests/expected/32-optionals.txt | 20 + tests/expected/50-smoke.txt | 40 ++ 17 files changed, 1952 insertions(+), 32 deletions(-) create mode 100644 examples/32-optionals.sx create mode 100644 tests/expected/32-optionals.exit create mode 100644 tests/expected/32-optionals.txt diff --git a/examples/26-pointers.sx b/examples/26-pointers.sx index 483bf26..43d463a 100644 --- a/examples/26-pointers.sx +++ b/examples/26-pointers.sx @@ -16,7 +16,7 @@ main :: () { ptr := @v; copy := ptr.*; print("copy: {}\n", copy); - + // null pointer np : *Vec2 = null; diff --git a/examples/32-optionals.sx b/examples/32-optionals.sx new file mode 100644 index 0000000..101fc83 --- /dev/null +++ b/examples/32-optionals.sx @@ -0,0 +1,97 @@ +#import "modules/std.sx"; + +// --- Type declarations --- +OptNode :: struct { value: s32; next: ?s32; } +OptInner :: struct { val: s32; } +OptOuter :: struct { inner: ?OptInner; } + +// --- Comptime optionals --- +ct_sum :: () -> s32 { + x: ?s32 = 42; + y: ?s32 = null; + return (x ?? 0) + (y ?? 99); +} +CT_RESULT :: #run ct_sum(); + +main :: () -> s32 { + // Basic optional creation + x: ?s32 = 42; + y: ?s32 = null; + print("x = {}\n", x); + print("y = {}\n", y); + + // Force unwrap + print("x! = {}\n", x!); + + // Null coalescing + print("x ?? 0 = {}\n", x ?? 0); + print("y ?? 99 = {}\n", y ?? 99); + + // If-binding (safe unwrap) + if val := x { + print("if-bind x: {}\n", val); + } + if val := y { + print("should not print\n"); + } else { + print("if-bind y: none\n"); + } + + // Pattern matching + check :: (v: ?s32) -> s32 { + return if v == { + case .some: (val) { val; } + case .none: { 0; } + }; + } + print("match some: {}\n", check(42)); + print("match none: {}\n", check(null)); + + // Optional chaining + p: ?OptNode = OptNode.{ value = 10, next = 20 }; + q: ?OptNode = null; + print("p?.value = {}\n", p?.value ?? 0); + print("q?.value = {}\n", q?.value ?? 0); + + // Deep chaining + o1 := OptOuter.{ inner = OptInner.{ val = 99 } }; + o2 := OptOuter.{ inner = null }; + print("o1.inner?.val = {}\n", o1.inner?.val ?? 0); + print("o2.inner?.val = {}\n", o2.inner?.val ?? 0); + + // Flow-sensitive narrowing + a: ?s32 = 10; + b: ?s32 = 20; + if a != null { + print("narrowed a: {}\n", a); + } + + // Guard narrowing + guard :: (v: ?s32) -> s32 { + if v == null { return 0; } + return v; + } + print("guard 42: {}\n", guard(42)); + print("guard null: {}\n", guard(null)); + + // Compound narrowing + if a != null and b != null { + print("both: {} {}\n", a, b); + } + + // Compound guard + guard2 :: (a: ?s32, b: ?s32) -> s32 { + if a == null or b == null { return 0; } + return a + b; + } + print("guard2: {}\n", guard2(3, 4)); + + // Struct field defaults + n := OptNode.{ value = 10 }; + print("default next: {}\n", n.next); + + // Comptime result + print("comptime: {}\n", CT_RESULT); + + return 0; +} diff --git a/examples/50-smoke.sx b/examples/50-smoke.sx index bd16fca..50189c3 100644 --- a/examples/50-smoke.sx +++ b/examples/50-smoke.sx @@ -1,4 +1,5 @@ #import "modules/std.sx"; +#import "modules/math"; pkg :: #import "modules/testpkg"; // ============================================================ @@ -33,6 +34,14 @@ Defaults :: struct { c: s32 = ---; } +OptNode :: struct { + value: s32; + next: ?s32; +} + +OptInner :: struct { val: s32; } +OptOuter :: struct { inner: ?OptInner; } + MyFloat :: f64; Perms :: enum flags { read; write; execute; } @@ -86,6 +95,25 @@ CT_VAL :: #run add(10, 15); CT_MUL :: #run mul(6, 7); CT_CHAIN :: #run add(CT_VAL, 5); +// #run compile-time optional tests +ct_opt_coalesce :: () -> s32 { + x: ?s32 = 42; + y: ?s32 = null; + return (x ?? 0) + (y ?? 99); +} +ct_opt_unwrap :: () -> s32 { + x: ?s32 = 77; + return x!; +} +ct_opt_guard :: () -> s32 { + x: ?s32 = 10; + if x == null { return -1; } + return x; +} +CT_OPT_COALESCE :: #run ct_opt_coalesce(); +CT_OPT_UNWRAP :: #run ct_opt_unwrap(); +CT_OPT_GUARD :: #run ct_opt_guard(); + // #insert helpers gen_code :: () -> string { return "print(\"insert-ok\\n\");"; @@ -1019,6 +1047,11 @@ END; // #run chained dependency print("run-chain: {}\n", CT_CHAIN); + // #run comptime optionals + print("ct-opt-coalesce: {}\n", CT_OPT_COALESCE); // ct-opt-coalesce: 141 + print("ct-opt-unwrap: {}\n", CT_OPT_UNWRAP); // ct-opt-unwrap: 77 + print("ct-opt-guard: {}\n", CT_OPT_GUARD); // ct-opt-guard: 10 + // #insert with function #insert gen_code(); @@ -1480,5 +1513,211 @@ END; } } + // ======================================================== + // OPTIONALS + // ======================================================== + print("--- optionals ---\n"); + + // Basic optional creation and null + { + x: ?s32 = 42; + y: ?s32 = null; + print("opt x: {}\n", x); // opt x: 42 + print("opt y: {}\n", y); // opt y: null + } + + // Force unwrap + { + x: ?s32 = 10; + val := x!; + print("unwrap: {}\n", val); // unwrap: 10 + } + + // Null coalescing + { + x: ?s32 = 42; + y: ?s32 = null; + a := x ?? 0; + b := y ?? 99; + print("coalesce a: {}\n", a); // coalesce a: 42 + print("coalesce b: {}\n", b); // coalesce b: 99 + } + + // If-binding (safe unwrap) + { + x: ?s32 = 7; + y: ?s32 = null; + if val := x { + print("if-bind x: {}\n", val); // if-bind x: 7 + } + if val := y { + print("if-bind y: should not print\n"); + } else { + print("if-bind y: none\n"); // if-bind y: none + } + } + + // Pattern matching on optionals + { + check :: (v: ?s32) -> s32 { + return if v == { + case .some: (val) { val; } + case .none: { 0; } + }; + } + a: ?s32 = 55; + b: ?s32 = null; + print("match some: {}\n", check(a)); // match some: 55 + print("match none: {}\n", check(b)); // match none: 0 + } + + // Optional with implicit wrapping + { + opt_wrap :: (n: s32) -> ?s32 { + if n > 0 { + return n; + } + return null; + } + r1 := opt_wrap(5); + r2 := opt_wrap(0); + print("wrap pos: {}\n", r1); // wrap pos: 5 + print("wrap neg: {}\n", r2); // wrap neg: null + } + + // Struct field defaults for ?T + { + n := OptNode.{ value = 10 }; + print("opt field default: {}\n", n.next); // opt field default: null + m := OptNode.{ value = 20, next = 42 }; + print("opt field set: {}\n", m.next); // opt field set: 42 + } + + // ?T as function parameter + { + opt_process :: (val: ?s32) -> s32 { + return val ?? 0; + } + a: ?s32 = 42; + b: ?s32 = null; + print("opt param a: {}\n", opt_process(a)); // opt param a: 42 + print("opt param b: {}\n", opt_process(b)); // opt param b: 0 + print("opt param 7: {}\n", opt_process(7)); // opt param 7: 7 + } + + // Generic function with ?T return + { + first_pos :: ($T: Type, a: T, b: T) -> ?T { + if a > 0 { return a; } + if b > 0 { return b; } + return null; + } + print("generic opt 1: {}\n", first_pos(s32, 5, 10)); // generic opt 1: 5 + print("generic opt 2: {}\n", first_pos(s32, 0, 7)); // generic opt 2: 7 + print("generic opt 3: {}\n", first_pos(s32, 0, 0)); // generic opt 3: null + } + + // Optional chaining (?.) + { + p: ?OptNode = OptNode.{ value = 10, next = 20 }; + q: ?OptNode = null; + print("chain some: {}\n", p?.value ?? 0); // chain some: 10 + print("chain none: {}\n", q?.value ?? 0); // chain none: 0 + print("chain print: {}\n", p?.next); // chain print: 20 + print("chain null: {}\n", q?.next); // chain null: null + + // Chained: obj.field?.field + o1 := OptOuter.{ inner = OptInner.{ val = 99 } }; + o2 := OptOuter.{ inner = null }; + print("deep chain 1: {}\n", o1.inner?.val ?? 0); // deep chain 1: 99 + print("deep chain 2: {}\n", o2.inner?.val ?? 0); // deep chain 2: 0 + } + + // Flow-sensitive narrowing + { + x: ?s32 = 42; + y: ?s32 = null; + + // if x != null → x is narrowed to s32 + if x != null { + print("narrow x: {}\n", x); // narrow x: 42 + } + + // if y != null → not entered + if y != null { + print("should not print\n"); + } else { + print("narrow y else: null\n"); // narrow y else: null + } + + // if x == null ... else → else-branch narrowed + if x == null { + print("should not print\n"); + } else { + print("narrow else x: {}\n", x); // narrow else x: 42 + } + } + + // Guard narrowing + { + guard_fn :: (v: ?s32) -> s32 { + if v == null { return 0; } + return v; + } + print("guard some: {}\n", guard_fn(42)); // guard some: 42 + print("guard none: {}\n", guard_fn(null)); // guard none: 0 + } + + // Compound narrowing: && chains + { + a: ?s32 = 10; + b: ?s32 = 20; + c: ?s32 = null; + if a != null and b != null { + print("and both: {} {}\n", a, b); // and both: 10 20 + } + if a != null and c != null { + print("should not print\n"); + } else { + print("and one null\n"); // and one null + } + } + + // Compound guard narrowing: || chains + { + guard2 :: (a: ?s32, b: ?s32) -> s32 { + if a == null or b == null { return 0; } + return a + b; + } + print("or guard: {}\n", guard2(3, 4)); // or guard: 7 + print("or guard null: {}\n", guard2(3, null)); // or guard null: 0 + } + + // Nested if narrowing + { + a: ?s32 = 10; + b: ?s32 = 20; + if a != null { + if b != null { + print("nested narrow: {} {}\n", a, b); // nested narrow: 10 20 + } + } + } + + // Guard narrowing used in loop + { + guard_loop :: (v: ?s32) -> s32 { + if v == null { return 0; } + sum := 0; + i := 0; + while i < v { + sum = sum + 1; + i = i + 1; + } + return sum; + } + print("guard loop: {}\n", guard_loop(3)); // guard loop: 3 + } + print("=== DONE ===\n"); } diff --git a/specs.md b/specs.md index 4fb5602..746fbbd 100644 --- a/specs.md +++ b/specs.md @@ -429,7 +429,7 @@ set_x :: (p: *Vec2, val: f32) { set_x(@v, 99.0); ``` -**Null**: All pointer types are nullable. `null` is the null pointer literal. +**Null**: Pointer types are currently nullable by default. `null` is the null pointer literal. ```sx np : *Vec2 = null; ``` @@ -451,6 +451,120 @@ val := mp[2]; // 30 **Fat pointer layout**: `[:0]u8`, `string`, and `[]T` are `{ptr, i64}` structs. The raw pointer is always the first field at offset 0. This means `*[:0]u8` works as C's `char**` — a C function dereferences through the outer pointer and reads the raw `char*` from offset 0. +### Optional Types + +Optional types represent values that may or may not be present. + +#### Type Syntax +```sx +x: ?s32 = 42; // optional s32, has value +y: ?s32 = null; // optional s32, no value +``` + +Any type `T` can be made optional: `?s32`, `?string`, `?Point`, `?*T`, `?[]T`. + +#### LLVM Representation +- Non-pointer optionals (`?s32`, `?Point`): `{ T, i1 }` struct — payload + has_value flag +- Pointer optionals (`?*T`): bare pointer — null represents absence + +#### Implicit Wrapping +A value of type `T` implicitly converts to `?T`: +```sx +wrap :: (n: s32) -> ?s32 { + if n > 0 { return n; } // s32 → ?s32 (wraps) + return null; // null → ?s32 +} +``` + +#### Force Unwrap (`!`) +Extracts the payload, traps at runtime if null: +```sx +x: ?s32 = 42; +val := x!; // val : s32 = 42 +``` + +#### Null Coalescing (`??`) +Returns the payload if present, otherwise evaluates the right-hand side: +```sx +x: ?s32 = 42; +y: ?s32 = null; +a := x ?? 0; // 42 +b := y ?? 99; // 99 +``` + +#### Safe Unwrap (`if val := expr`) +Binds the payload to a variable if present: +```sx +x: ?s32 = 42; +if val := x { + print("{}\n", val); // val : s32 = 42 +} else { + print("none\n"); +} +``` + +#### While-Optional Binding +```sx +while val := get_next() { + // val is the unwrapped value +} +``` + +#### Pattern Matching +Optionals support `.some` and `.none` virtual enum variants: +```sx +result := if opt == { + case .some: (val) { val * 2; } + case .none: { 0; } +}; +``` + +#### Optional Chaining (`?.`) +Short-circuits field access on optionals: +```sx +x: ?Point = Point.{ x = 1, y = 2 }; +y: ?Point = null; +a := x?.x ?? 0; // 1 +b := y?.x ?? 0; // 0 +``` + +Result type of `x?.field` is always `?FieldType`. + +#### Flow-Sensitive Narrowing +The compiler narrows `?T` to `T` in control flow branches: +```sx +x: ?s32 = 42; +if x != null { + print("{}\n", x); // x is s32 here (narrowed) +} +if x == null { return; } +print("{}\n", x); // x is s32 here (guard narrowing) +``` + +Compound conditions: +```sx +if a != null and b != null { + // both a and b are narrowed to their inner types +} +if a == null or b == null { return; } +// both a and b are narrowed after the guard +``` + +Reassignment kills narrowing. + +#### Struct Field Defaults +Optional fields in structs default to `null`: +```sx +Node :: struct { value: s32; next: ?s32; } +n := Node.{ value = 10 }; // n.next is null +``` + +#### Printing +`print("{}", opt)` prints the payload value if present, or `"null"`. + +#### Comptime +Optionals work in `#run` blocks — `??`, `!`, `if val :=`, null checks all supported. + ### Foreign Function Interface (C Interop) To call C functions, declare a library constant with `#library` and bind functions with `#foreign`: diff --git a/src/ast.zig b/src/ast.zig index a1d3ca1..c42f126 100644 --- a/src/ast.zig +++ b/src/ast.zig @@ -53,6 +53,9 @@ pub const Node = struct { slice_expr: SliceExpr, pointer_type_expr: PointerTypeExpr, many_pointer_type_expr: ManyPointerTypeExpr, + optional_type_expr: OptionalTypeExpr, + force_unwrap: ForceUnwrap, + null_coalesce: NullCoalesce, deref_expr: DerefExpr, null_literal: void, while_expr: WhileExpr, @@ -192,6 +195,7 @@ pub const Call = struct { pub const FieldAccess = struct { object: *Node, field: []const u8, + is_optional: bool = false, }; pub const IfExpr = struct { @@ -199,6 +203,7 @@ pub const IfExpr = struct { then_branch: *Node, else_branch: ?*Node, is_inline: bool, // true for `if cond then a else b` + binding_name: ?[]const u8 = null, // for `if val := expr { ... }` optional binding }; pub const MatchExpr = struct { @@ -371,6 +376,19 @@ pub const ManyPointerTypeExpr = struct { element_type: *Node, }; +pub const OptionalTypeExpr = struct { + inner_type: *Node, +}; + +pub const ForceUnwrap = struct { + operand: *Node, +}; + +pub const NullCoalesce = struct { + lhs: *Node, + rhs: *Node, +}; + pub const DerefExpr = struct { operand: *Node, }; @@ -378,6 +396,7 @@ pub const DerefExpr = struct { pub const WhileExpr = struct { condition: *Node, body: *Node, + binding_name: ?[]const u8 = null, // for `while val := expr { ... }` optional binding }; pub const ForExpr = struct { diff --git a/src/codegen.zig b/src/codegen.zig index 5a28e48..3d95013 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -122,6 +122,8 @@ pub const CodeGen = struct { // Symbol table: maps variable names to their alloca pointers named_values: std.StringHashMap(NamedValue), + // Flow-sensitive narrowing: tracks variables narrowed from ?T to T + narrowed_types: std.StringHashMap(NarrowedInfo), // Unified type registry: single lookup for all named types (structs, enums, unions, aliases) type_registry: std.StringHashMap(TypeRegistryEntry), // Flags enum registry: tracks which enum names are flags @@ -316,6 +318,12 @@ pub const CodeGen = struct { is_const: bool = false, }; + /// Info for a flow-narrowed optional variable: ?T → T in a checked scope + const NarrowedInfo = struct { + narrowed_ty: Type, // the inner type T (not ?T) + payload_ptr: c.LLVMValueRef, // alloca holding the unwrapped value + }; + /// Unified value lookup result — avoids sequential hash lookups at hot paths. const ValueLookup = union(enum) { local: NamedValue, @@ -391,6 +399,7 @@ pub const CodeGen = struct { .ts_context = ts_ctx, .target_machine = tm, .named_values = std.StringHashMap(NamedValue).init(allocator), + .narrowed_types = std.StringHashMap(NarrowedInfo).init(allocator), .type_registry = std.StringHashMap(TypeRegistryEntry).init(allocator), .flags_enum_types = std.StringHashMap(void).init(allocator), .enum_variant_values = std.StringHashMap([]const i64).init(allocator), @@ -589,6 +598,19 @@ pub const CodeGen = struct { return c.LLVMVectorType(self.typeToLLVM(elem_ty), info.length); }, .pointer_type, .many_pointer_type, .function_type => self.ptrType(), + .optional_type => |info| { + // ?*T, ?[*]T → bare pointer (null = none) + const child_type = self.resolveTypeFromName(info.child_name) orelse unreachable; + if (child_type.isPointer() or child_type.isManyPointer() or child_type.isFunctionType()) { + return self.ptrType(); + } + // ?T → { T, i1 } struct + var field_types: [2]c.LLVMTypeRef = .{ + self.typeToLLVM(child_type), + self.i1Type(), + }; + return c.LLVMStructTypeInContext(self.context, &field_types, 2, 0); + }, .any_type => self.getAnyStructType(), .meta_type => self.ptrType(), .tuple_type => |info| { @@ -704,6 +726,42 @@ pub const CodeGen = struct { const i64_ty = self.i64Type(); const undef = self.getUndef(any_ty); + // Optional: branch on has_value — some prints inner value, none prints "null" + if (in_ty.isOptional()) { + const has_val = self.optionalHasValue(val, in_ty); + const some_bb = self.appendBB("opt_any_some"); + const none_bb = self.appendBB("opt_any_none"); + const merge_bb = self.appendBB("opt_any_merge"); + _ = c.LLVMBuildCondBr(self.builder, has_val, some_bb, none_bb); + + // Some: extract payload and wrap as Any + self.positionAt(some_bb); + const payload = self.optionalPayload(val, in_ty); + const child_name = in_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse .void_type; + const some_any = try self.buildAnyValue(payload, child_ty); + const some_out = self.getCurrentBlock(); + self.br(merge_bb); + + // None: wrap "null" string as Any + self.positionAt(none_bb); + const null_str = c.LLVMBuildGlobalStringPtr(self.builder, "null", "null_str"); + const null_len = self.constInt64(4); + const null_slice = self.buildStringSlice(null_str, null_len); + const none_any = try self.buildAnyValue(null_slice, .string_type); + const none_out = self.getCurrentBlock(); + self.br(merge_bb); + + self.positionAt(merge_bb); + var phi_vals = std.ArrayList(c.LLVMValueRef).empty; + var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty; + try phi_vals.append(self.allocator, some_any); + try phi_bbs.append(self.allocator, some_out); + try phi_vals.append(self.allocator, none_any); + try phi_bbs.append(self.allocator, none_out); + return try self.buildPhiNode(&phi_vals, &phi_bbs, any_ty, "opt_any_phi"); + } + // []u8 boxes as string (same repr, same Any tag) const ty: Type = if (in_ty.isSlice() and std.mem.eql(u8, in_ty.slice_type.element_name, "u8")) .string_type @@ -850,6 +908,50 @@ pub const CodeGen = struct { _ = c.LLVMBuildStore(self.builder, c.LLVMConstNull(ty), ptr); } + /// Wrap a value into an optional { value, i1 1 }. + /// For pointer optionals, the value is already a pointer — return as-is. + fn wrapOptional(self: *CodeGen, value: c.LLVMValueRef, opt_ty: Type) c.LLVMValueRef { + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse unreachable; + if (child_ty.isPointer() or child_ty.isManyPointer() or child_ty.isFunctionType()) { + return value; // pointer is already nullable — just pass through + } + const llvm_opt_ty = self.typeToLLVM(opt_ty); + var result = self.getUndef(llvm_opt_ty); + result = c.LLVMBuildInsertValue(self.builder, result, value, 0, "opt_val"); + result = c.LLVMBuildInsertValue(self.builder, result, c.LLVMConstInt(self.i1Type(), 1, 0), 1, "opt_some"); + return result; + } + + /// Create a null optional value (none). + /// For pointer optionals: null pointer. For value optionals: { undef, i1 0 }. + fn makeNullOptional(self: *CodeGen, opt_ty: Type) c.LLVMValueRef { + const llvm_opt_ty = self.typeToLLVM(opt_ty); + return c.LLVMConstNull(llvm_opt_ty); + } + + /// Check if an optional has a value. Returns an i1. + /// For pointer optionals: ptr != null. For value optionals: extractvalue i1 flag. + fn optionalHasValue(self: *CodeGen, opt_val: c.LLVMValueRef, opt_ty: Type) c.LLVMValueRef { + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse unreachable; + if (child_ty.isPointer() or child_ty.isManyPointer() or child_ty.isFunctionType()) { + return c.LLVMBuildICmp(self.builder, c.LLVMIntNE, opt_val, c.LLVMConstNull(self.ptrType()), "opt_nonnull"); + } + return c.LLVMBuildExtractValue(self.builder, opt_val, 1, "opt_flag"); + } + + /// Extract the payload from an optional (no check — caller must ensure has_value). + /// For pointer optionals: returns the pointer. For value optionals: extractvalue payload. + fn optionalPayload(self: *CodeGen, opt_val: c.LLVMValueRef, opt_ty: Type) c.LLVMValueRef { + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse unreachable; + if (child_ty.isPointer() or child_ty.isManyPointer() or child_ty.isFunctionType()) { + return opt_val; // for pointer optionals, the value IS the pointer + } + return c.LLVMBuildExtractValue(self.builder, opt_val, 0, "opt_payload"); + } + fn loadTyped(self: *CodeGen, ty: Type, ptr: c.LLVMValueRef, name: [*c]const u8) c.LLVMValueRef { return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(ty), ptr, name); } @@ -1475,6 +1577,13 @@ pub const CodeGen = struct { const elem_name = elem_type.displayName(self.allocator) catch unreachable; return .{ .slice_type = .{ .element_name = elem_name } }; } + // Optional type: ?T + if (tn.data == .optional_type_expr) { + const ote = tn.data.optional_type_expr; + const inner_type = self.resolveType(ote.inner_type); + const inner_name = inner_type.displayName(self.allocator) catch unreachable; + return .{ .optional_type = .{ .child_name = inner_name } }; + } // Pointer type: *T if (tn.data == .pointer_type_expr) { const pte = tn.data.pointer_type_expr; @@ -2492,6 +2601,7 @@ pub const CodeGen = struct { fn genFnBody(self: *CodeGen, fd: ast.FnDecl, llvm_name: []const u8) !void { self.named_values.clearRetainingCapacity(); + self.narrowed_types.clearRetainingCapacity(); const ret_sx_type = self.resolveType(fd.return_type); const is_main = std.mem.eql(u8, llvm_name, "main"); @@ -2583,6 +2693,7 @@ pub const CodeGen = struct { fn genLambdaBody(self: *CodeGen, name: []const u8, lambda: ast.Lambda) !void { self.named_values.clearRetainingCapacity(); + self.narrowed_types.clearRetainingCapacity(); const ret_sx_type = self.inferType(lambda.body); self.current_return_type = ret_sx_type; @@ -2628,7 +2739,9 @@ pub const CodeGen = struct { const saved_bb = self.getCurrentBlock(); const saved_ret = self.current_return_type; const saved_named = self.named_values; + const saved_narrowed = self.narrowed_types; self.named_values = std.StringHashMap(NamedValue).init(self.allocator); + self.narrowed_types = std.StringHashMap(NarrowedInfo).init(self.allocator); // Infer return type from body for => lambdas without explicit annotation const ret_sx_type = if (fd.return_type != null) self.resolveType(fd.return_type) else if (fd.is_arrow) self.inferType(fd.body) else Type.void_type; @@ -2688,6 +2801,7 @@ pub const CodeGen = struct { // Restore outer function state self.named_values = saved_named; + self.narrowed_types = saved_narrowed; self.current_return_type = saved_ret; self.current_function = saved_fn; self.positionAt(saved_bb); @@ -2711,8 +2825,12 @@ pub const CodeGen = struct { .return_stmt => |rs| { // Evaluate return value first, then emit all defers, then return if (rs.value) |val_node| { - const raw_val = try self.genExpr(val_node); - const ret_val = try self.prepareReturnValue(raw_val, self.current_return_type); + const ret_val = if (self.current_return_type.isOptional()) blk: { + break :blk try self.genExprAsType(val_node, self.current_return_type); + } else blk: { + const raw_val = try self.genExpr(val_node); + break :blk try self.prepareReturnValue(raw_val, self.current_return_type); + }; try self.emitAllDefers(); self.ret(ret_val); } else { @@ -3003,6 +3121,25 @@ pub const CodeGen = struct { return null; } + // Optional-typed variable + if (sx_ty.isOptional()) { + const llvm_ty = self.typeToLLVM(sx_ty); + const alloca = try self.buildNamedAlloca(llvm_ty, vd.name); + + if (vd.value == null) { + // Default-init: null optional + self.storeNull(llvm_ty, alloca); + } else if (vd.value.?.data == .undef_literal) { + self.storeUndef(llvm_ty, alloca); + } else { + const val = try self.genExprAsType(vd.value.?, sx_ty); + _ = c.LLVMBuildStore(self.builder, val, alloca); + } + + try self.registerVariable(vd.name, alloca, sx_ty); + return null; + } + // Guard: void type cannot be allocated (would crash LLVM) if (sx_ty == .void_type) { return self.emitErrorFmt("cannot declare variable '{s}' with void type", .{vd.name}); @@ -3077,13 +3214,16 @@ pub const CodeGen = struct { const saved_bb = self.getCurrentBlock(); const saved_ret = self.current_return_type; const saved_named = self.named_values; + const saved_narrowed2 = self.narrowed_types; self.named_values = std.StringHashMap(NamedValue).init(self.allocator); + self.narrowed_types = std.StringHashMap(NarrowedInfo).init(self.allocator); try self.registerLambdaAsFunction(cd.name, cd.value.data.lambda); try self.genLambdaBody(cd.name, cd.value.data.lambda); self.named_values.deinit(); self.named_values = saved_named; + self.narrowed_types = saved_narrowed2; self.current_return_type = saved_ret; self.current_function = saved_fn; self.positionAt(saved_bb); @@ -3249,6 +3389,10 @@ pub const CodeGen = struct { if (entry.is_const) return self.emitErrorFmt("cannot assign to '{s}'", .{name}); + // Kill narrowing on reassignment: if x was narrowed from ?T to T, + // reassignment invalidates the narrowed value + _ = self.narrowed_types.remove(name); + // Meta type reassignment: x = Vec4, x = f64, x = test if (entry.ty == .meta_type and asgn.op == .assign) { const raw_name = self.asTypeName(asgn.value) orelse blk: { @@ -3565,6 +3709,10 @@ pub const CodeGen = struct { return self.buildStringSlice(ptr, self.constInt64(@intCast(content.len))); }, .identifier => |ident| { + // Flow-sensitive narrowing: if variable is narrowed, load from payload alloca + if (self.narrowed_types.get(ident.name)) |ni| { + return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(ni.narrowed_ty), ni.payload_ptr, "narrow_val"); + } if (self.lookupValue(ident.name)) |v| { switch (v) { .local => |nv| return c.LLVMBuildLoad2(self.builder, self.typeToLLVM(nv.ty), nv.ptr, "loadtmp"), @@ -3601,6 +3749,29 @@ pub const CodeGen = struct { .binary_op => |binop| { if (binop.op == .and_op) return self.genShortCircuitOp(binop, true); if (binop.op == .or_op) return self.genShortCircuitOp(binop, false); + + // Optional-null comparison: x == null / x != null + if (binop.op == .eq or binop.op == .neq) { + const lhs_ty = self.inferType(binop.lhs); + const rhs_ty = self.inferType(binop.rhs); + const opt_side: ?struct { expr: *Node, ty: Type } = + if (lhs_ty.isOptional() and binop.rhs.data == .null_literal) + .{ .expr = binop.lhs, .ty = lhs_ty } + else if (rhs_ty.isOptional() and binop.lhs.data == .null_literal) + .{ .expr = binop.rhs, .ty = rhs_ty } + else + null; + if (opt_side) |os| { + const opt_val = try self.genExpr(os.expr); + const has_val = self.optionalHasValue(opt_val, os.ty); + // == null → NOT has_value; != null → has_value + if (binop.op == .eq) { + return c.LLVMBuildNot(self.builder, has_val, "opt_is_null"); + } + return has_val; + } + } + const lhs_ty = self.inferType(binop.lhs); const rhs_ty = self.inferType(binop.rhs); const result_type = Type.widen(lhs_ty, rhs_ty); @@ -3763,6 +3934,9 @@ pub const CodeGen = struct { return self.genArrayLiteral(al, null); }, .field_access => |fa| { + if (fa.is_optional) { + return self.genOptionalChain(fa); + } return self.genFieldAccess(fa); }, .index_expr => |ie| { @@ -3829,8 +4003,12 @@ pub const CodeGen = struct { }, .return_stmt => |rs| { if (rs.value) |val_node| { - const raw_val = try self.genExpr(val_node); - const ret_val = try self.prepareReturnValue(raw_val, self.current_return_type); + const ret_val = if (self.current_return_type.isOptional()) blk: { + break :blk try self.genExprAsType(val_node, self.current_return_type); + } else blk: { + const raw_val = try self.genExpr(val_node); + break :blk try self.prepareReturnValue(raw_val, self.current_return_type); + }; try self.emitAllDefers(); self.ret(ret_val); } else { @@ -3840,6 +4018,63 @@ pub const CodeGen = struct { _ = self.appendBlock(self.current_function, "after_ret"); return null; }, + .null_coalesce => |nc| { + const opt_val = try self.genExpr(nc.lhs); + const opt_ty = self.inferType(nc.lhs); + if (!opt_ty.isOptional()) return self.emitError("'??' requires an optional type on the left side"); + + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse + return self.emitError("unknown optional inner type"); + + const has_val = self.optionalHasValue(opt_val, opt_ty); + const then_bb = self.appendBB("coalesce_some"); + const else_bb = self.appendBB("coalesce_none"); + const merge_bb = self.appendBB("coalesce_merge"); + _ = c.LLVMBuildCondBr(self.builder, has_val, then_bb, else_bb); + + // Some path: extract payload + c.LLVMPositionBuilderAtEnd(self.builder, then_bb); + const payload = self.optionalPayload(opt_val, opt_ty); + const then_end_bb = c.LLVMGetInsertBlock(self.builder); + _ = c.LLVMBuildBr(self.builder, merge_bb); + + // None path: evaluate default + c.LLVMPositionBuilderAtEnd(self.builder, else_bb); + const default_val = try self.genExprAsType(nc.rhs, child_ty); + const else_end_bb = c.LLVMGetInsertBlock(self.builder); + _ = c.LLVMBuildBr(self.builder, merge_bb); + + // Merge with PHI + c.LLVMPositionBuilderAtEnd(self.builder, merge_bb); + const llvm_child_ty = self.typeToLLVM(child_ty); + const phi = c.LLVMBuildPhi(self.builder, llvm_child_ty, "coalesce"); + var vals = [2]c.LLVMValueRef{ payload, default_val }; + var bbs = [2]c.LLVMBasicBlockRef{ then_end_bb, else_end_bb }; + c.LLVMAddIncoming(phi, &vals, &bbs, 2); + return phi; + }, + .force_unwrap => |fu| { + const opt_val = try self.genExpr(fu.operand); + const opt_ty = self.inferType(fu.operand); + if (!opt_ty.isOptional()) return self.emitError("force unwrap (!) requires an optional type"); + + // Check has_value — if false, trap + const has_val = self.optionalHasValue(opt_val, opt_ty); + const then_bb = self.appendBB("unwrap_ok"); + const trap_bb = self.appendBB("unwrap_trap"); + _ = c.LLVMBuildCondBr(self.builder, has_val, then_bb, trap_bb); + + // Trap block: call llvm.trap + unreachable + c.LLVMPositionBuilderAtEnd(self.builder, trap_bb); + const trap_fn = c.LLVMGetIntrinsicDeclaration(self.module, c.LLVMLookupIntrinsicID("llvm.trap", 9), null, 0); + _ = c.LLVMBuildCall2(self.builder, c.LLVMFunctionType(self.voidType(), null, 0, 0), trap_fn, null, 0, ""); + _ = c.LLVMBuildUnreachable(self.builder); + + // OK block: extract payload + c.LLVMPositionBuilderAtEnd(self.builder, then_bb); + return self.optionalPayload(opt_val, opt_ty); + }, .deref_expr => |de| { const ptr_val = try self.genExpr(de.operand); const ptr_ty = self.inferType(de.operand); @@ -4654,6 +4889,29 @@ pub const CodeGen = struct { return self.buildGlobalString(str_z.ptr, "str"); } + // Optional target type: wrap value or produce null + if (target_ty.isOptional()) { + // null literal → null optional + if (node.data == .null_literal) { + return self.makeNullOptional(target_ty); + } + // If source expression already produces the same optional type, pass through + const src_ty = self.inferType(node); + if (src_ty.eql(target_ty)) { + return try self.genExpr(node); + } + // If source is a different optional, generate as-is (e.g. ?s32 → ?s64 widening) + if (src_ty.isOptional()) { + return try self.genExpr(node); + } + // Expression producing a value — generate it as the inner type, then wrap + const child_name = target_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse + return self.emitErrorFmt("unknown optional inner type '{s}'", .{child_name}); + const val = try self.genExprAsType(node, child_ty); + return self.wrapOptional(val, target_ty); + } + // Enum literal assigned to enum type: resolve variant value if (node.data == .enum_literal and target_ty.isEnum()) { return self.genEnumLiteral(node.data.enum_literal.name, target_ty.enum_type); @@ -5586,6 +5844,108 @@ pub const CodeGen = struct { return self.emitErrorFmt("unsupported vector swizzle '{s}'", .{field}); } + /// Optional chaining: expr?.field — short-circuit to null if expr is null + fn genOptionalChain(self: *CodeGen, fa: ast.FieldAccess) !c.LLVMValueRef { + const opt_ty = self.inferType(fa.object); + if (!opt_ty.isOptional()) { + return self.emitError("'?.' used on non-optional type"); + } + + const opt_val = try self.genExpr(fa.object); + const has_val = self.optionalHasValue(opt_val, opt_ty); + + const some_bb = self.appendBB("chain_some"); + const none_bb = self.appendBB("chain_none"); + const merge_bb = self.appendBB("chain_merge"); + _ = c.LLVMBuildCondBr(self.builder, has_val, some_bb, none_bb); + + // Some: unwrap, access field, re-wrap as ?FieldType + self.positionAt(some_bb); + const payload = self.optionalPayload(opt_val, opt_ty); + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse + return self.emitErrorFmt("unknown optional inner type '{s}'", .{child_name}); + + // Create a synthetic non-optional field_access to reuse genFieldAccess + const inner_fa = ast.FieldAccess{ .object = fa.object, .field = fa.field, .is_optional = false }; + // We need the field type to construct ?FieldType + const field_ty = self.inferFieldType(child_ty, fa.field) orelse + return self.emitErrorFmt("type '{s}' has no field '{s}'", .{ child_name, fa.field }); + + // Generate the field access on the unwrapped payload + const field_val = try self.genFieldAccessOnValue(payload, child_ty, inner_fa.field); + const result_opt_ty = Type{ .optional_type = .{ .child_name = try field_ty.displayName(self.allocator) } }; + const some_result = self.wrapOptional(field_val, result_opt_ty); + const some_out_bb = self.getCurrentBlock(); + self.br(merge_bb); + + // None: produce null optional + self.positionAt(none_bb); + const none_result = self.makeNullOptional(result_opt_ty); + const none_out_bb = self.getCurrentBlock(); + self.br(merge_bb); + + // Merge + self.positionAt(merge_bb); + var phi_vals = std.ArrayList(c.LLVMValueRef).empty; + var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty; + try phi_vals.append(self.allocator, some_result); + try phi_bbs.append(self.allocator, some_out_bb); + try phi_vals.append(self.allocator, none_result); + try phi_bbs.append(self.allocator, none_out_bb); + const result_llvm_ty = self.typeToLLVM(result_opt_ty); + return try self.buildPhiNode(&phi_vals, &phi_bbs, result_llvm_ty, "chain_result"); + } + + /// Infer the type of a field on a given type + fn inferFieldType(self: *CodeGen, ty: Type, field: []const u8) ?Type { + if (ty.isStruct()) { + const info = self.lookupStructInfo(ty.struct_type) orelse return null; + for (info.field_names, 0..) |fname, i| { + if (std.mem.eql(u8, fname, field)) { + return info.field_types[i]; + } + } + } + // string/slice .ptr/.len + if (ty == .string_type or ty.isSlice()) { + if (std.mem.eql(u8, field, "len")) return .{ .signed = 64 }; + if (std.mem.eql(u8, field, "ptr")) return .{ .pointer_type = .{ .pointee_name = "u8" } }; + } + return null; + } + + /// Generate field access on a raw value (not from the AST) + fn genFieldAccessOnValue(self: *CodeGen, val: c.LLVMValueRef, val_ty: Type, field: []const u8) !c.LLVMValueRef { + if (val_ty.isStruct()) { + const sname = val_ty.struct_type; + const info = self.lookupStructInfo(sname) orelse + return self.emitErrorFmt("unknown struct '{s}'", .{sname}); + for (info.field_names, 0..) |fname, fi| { + if (std.mem.eql(u8, fname, field)) { + // val is a loaded struct value — store to temp alloca, GEP, load field + const alloca = self.buildEntryBlockAlloca(info.llvm_type, "chain_tmp"); + _ = c.LLVMBuildStore(self.builder, val, alloca); + const field_llvm_ty = self.typeToLLVM(info.field_types[fi]); + return self.loadStructField(info.llvm_type, alloca, @intCast(fi), field_llvm_ty); + } + } + return self.emitErrorFmt("struct '{s}' has no field '{s}'", .{ sname, field }); + } + // string/slice .ptr/.len + if (val_ty == .string_type or val_ty.isSlice()) { + const str_ty = self.getStringStructType(); + if (std.mem.eql(u8, field, "ptr")) { + return c.LLVMBuildExtractValue(self.builder, val, 0, "chain_ptr"); + } + if (std.mem.eql(u8, field, "len")) { + return c.LLVMBuildExtractValue(self.builder, val, 1, "chain_len"); + } + _ = str_ty; + } + return self.emitErrorFmt("cannot access field '{s}' via optional chaining", .{field}); + } + fn genFieldAccess(self: *CodeGen, fa: ast.FieldAccess) !c.LLVMValueRef { // Check if the object is a struct or vector variable if (fa.object.data == .identifier) { @@ -7357,7 +7717,171 @@ pub const CodeGen = struct { return function; } + /// Result of detecting a null-check pattern in a condition expression + const NullCheck = struct { + var_name: []const u8, + is_eq: bool, // true for == null, false for != null + }; + + /// Detect a single `x == null` or `x != null` pattern in a condition expression + fn detectNullCheck(self: *CodeGen, cond: *Node) ?NullCheck { + if (cond.data != .binary_op) return null; + const bop = cond.data.binary_op; + if (bop.op != .eq and bop.op != .neq) return null; + + // Check: identifier op null_literal OR null_literal op identifier + const var_name: ?[]const u8 = if (bop.lhs.data == .identifier and bop.rhs.data == .null_literal) + bop.lhs.data.identifier.name + else if (bop.lhs.data == .null_literal and bop.rhs.data == .identifier) + bop.rhs.data.identifier.name + else + null; + + const name = var_name orelse return null; + + // Verify the variable is actually optional + if (self.named_values.get(name)) |entry| { + if (entry.ty.isOptional()) { + return NullCheck{ + .var_name = name, + .is_eq = bop.op == .eq, + }; + } + } + return null; + } + + /// Collect null checks from compound conditions: + /// `x != null && y != null` → [x(!=), y(!=)] — narrow both in then-branch + /// `x == null || y == null` → [x(==), y(==)] — narrow both after guard + /// Only collects when ALL leaves are null checks with the SAME polarity connected + /// by the expected operator (&& for !=null, || for ==null). + fn collectNullChecks(self: *CodeGen, cond: *Node, buf: []NullCheck) usize { + // Try single null check first + if (self.detectNullCheck(cond)) |nc| { + buf[0] = nc; + return 1; + } + // Try compound: binary_op with and_op or or_op + if (cond.data != .binary_op) return 0; + const bop = cond.data.binary_op; + if (bop.op != .and_op and bop.op != .or_op) return 0; + + var left_buf: [8]NullCheck = undefined; + var right_buf: [8]NullCheck = undefined; + const left_n = self.collectNullChecks(bop.lhs, &left_buf); + const right_n = self.collectNullChecks(bop.rhs, &right_buf); + if (left_n == 0 or right_n == 0) return 0; + if (left_n + right_n > buf.len) return 0; + + // All checks must have same polarity: + // && chains: all must be != null (is_eq=false) + // || chains: all must be == null (is_eq=true) + const expected_eq = bop.op == .or_op; // || → ==null, && → !=null + for (left_buf[0..left_n]) |nc| { + if (nc.is_eq != expected_eq) return 0; + } + for (right_buf[0..right_n]) |nc| { + if (nc.is_eq != expected_eq) return 0; + } + + @memcpy(buf[0..left_n], left_buf[0..left_n]); + @memcpy(buf[left_n..][0..right_n], right_buf[0..right_n]); + return left_n + right_n; + } + + /// Push a narrowing: load the optional, extract payload, store in temp alloca + fn pushNarrowing(self: *CodeGen, var_name: []const u8) !void { + const entry = self.named_values.get(var_name) orelse return; + const opt_ty = entry.ty; + if (!opt_ty.isOptional()) return; + + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse return; + const child_llvm_ty = self.typeToLLVM(child_ty); + + // Load the optional value and extract the payload + const opt_llvm_ty = self.typeToLLVM(opt_ty); + const opt_val = c.LLVMBuildLoad2(self.builder, opt_llvm_ty, entry.ptr, "narrow_load"); + const payload = self.optionalPayload(opt_val, opt_ty); + + // Store payload in a temp alloca + const alloca = self.buildEntryBlockAlloca(child_llvm_ty, "narrowed"); + _ = c.LLVMBuildStore(self.builder, payload, alloca); + + try self.narrowed_types.put(var_name, .{ + .narrowed_ty = child_ty, + .payload_ptr = alloca, + }); + } + + /// Pop a narrowing entry + fn popNarrowing(self: *CodeGen, var_name: []const u8) void { + _ = self.narrowed_types.remove(var_name); + } + fn genIfExpr(self: *CodeGen, if_expr: ast.IfExpr) !c.LLVMValueRef { + // Optional binding: if val := expr { ... } + if (if_expr.binding_name) |binding_name| { + const opt_val = try self.genExpr(if_expr.condition); + const opt_ty = self.inferType(if_expr.condition); + if (!opt_ty.isOptional()) return self.emitError("'if val := expr' requires an optional expression"); + + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse + return self.emitError("unknown optional inner type"); + + const has_val = self.optionalHasValue(opt_val, opt_ty); + const has_else = if_expr.else_branch != null; + + var then_bb = self.appendBB("if_some"); + var else_bb: c.LLVMBasicBlockRef = if (has_else) self.appendBB("if_none") else null; + const merge_bb = self.appendBB("if_merge"); + + const false_dest = if (has_else) else_bb else merge_bb; + self.condBr(has_val, then_bb, false_dest); + + // Then branch: bind the unwrapped value + self.positionAt(then_bb); + const payload = self.optionalPayload(opt_val, opt_ty); + const alloca = try self.buildNamedAlloca(self.typeToLLVM(child_ty), binding_name); + _ = c.LLVMBuildStore(self.builder, payload, alloca); + try self.registerVariable(binding_name, alloca, child_ty); + const then_val = try self.genExpr(if_expr.then_branch); + then_bb = self.getCurrentBlock(); + self.br(merge_bb); + + // Else branch + var else_val: c.LLVMValueRef = null; + if (if_expr.else_branch) |else_branch| { + self.positionAt(else_bb); + else_val = try self.genExpr(else_branch); + else_bb = self.getCurrentBlock(); + self.br(merge_bb); + } + + self.positionAt(merge_bb); + + if (then_val != null and else_val != null) { + const ty = c.LLVMTypeOf(then_val); + if (c.LLVMGetTypeKind(ty) != c.LLVMVoidTypeKind) { + const phi = c.LLVMBuildPhi(self.builder, ty, "iftmp"); + var vals = [2]c.LLVMValueRef{ then_val, else_val }; + var blocks = [2]c.LLVMBasicBlockRef{ then_bb, else_bb }; + c.LLVMAddIncoming(phi, &vals, &blocks, 2); + return phi; + } + } + + return null; + } + + // Detect null-check narrowing: if x != null { ... } or if x == null { ... } + // Also handles compound: if x != null && y != null { ... } + var null_checks_buf: [8]NullCheck = undefined; + const null_check_count = self.collectNullChecks(if_expr.condition, &null_checks_buf); + const null_checks = null_checks_buf[0..null_check_count]; + // Generate condition const cond_val = self.valueToBool(try self.genExpr(if_expr.condition)); @@ -7373,17 +7897,33 @@ pub const CodeGen = struct { const false_dest = if (has_else) else_bb else merge_bb; self.condBr(cond_val, then_bb, false_dest); - // Then branch + // Then branch — apply narrowing for != null checks (including && chains) self.positionAt(then_bb); + for (null_checks) |nc| { + if (!nc.is_eq) { // x != null → narrow in then + try self.pushNarrowing(nc.var_name); + } + } const then_val = try self.genExpr(if_expr.then_branch); + for (null_checks) |nc| { + if (!nc.is_eq) self.popNarrowing(nc.var_name); + } then_bb = self.getCurrentBlock(); // may have changed due to nested control flow self.br(merge_bb); - // Else branch + // Else branch — apply narrowing for == null checks (x is non-null in else) var else_val: c.LLVMValueRef = null; if (if_expr.else_branch) |else_branch| { self.positionAt(else_bb); + for (null_checks) |nc| { + if (nc.is_eq) { // x == null → narrow in else + try self.pushNarrowing(nc.var_name); + } + } else_val = try self.genExpr(else_branch); + for (null_checks) |nc| { + if (nc.is_eq) self.popNarrowing(nc.var_name); + } else_bb = self.getCurrentBlock(); self.br(merge_bb); } @@ -7391,6 +7931,17 @@ pub const CodeGen = struct { // Merge block self.positionAt(merge_bb); + // Guard narrowing: if x == null { return; } → x narrowed after + // Also handles: if x == null || y == null { return; } → both narrowed after + if (!has_else and null_check_count > 0 and bodyAlwaysExits(if_expr.then_branch)) { + for (null_checks) |nc| { + if (nc.is_eq) { + try self.pushNarrowing(nc.var_name); + // Persists for rest of enclosing block, cleaned up at function boundary + } + } + } + // PHI node if both branches produced values (skip for void type) if (then_val != null and else_val != null) { const ty = c.LLVMTypeOf(then_val); @@ -7406,6 +7957,20 @@ pub const CodeGen = struct { return null; } + /// Check if a body expression unconditionally exits the current scope + fn bodyAlwaysExits(body: *Node) bool { + if (body.data == .return_stmt) return true; + if (body.data == .break_expr) return true; + if (body.data == .continue_expr) return true; + if (body.data == .block) { + const stmts = body.data.block.stmts; + if (stmts.len > 0) { + return bodyAlwaysExits(stmts[stmts.len - 1]); + } + } + return false; + } + fn genWhileExpr(self: *CodeGen, while_expr: ast.WhileExpr) !c.LLVMValueRef { // Create basic blocks: condition, body, after const cond_bb = self.appendBB("while.cond"); @@ -7417,6 +7982,46 @@ pub const CodeGen = struct { // Condition block self.positionAt(cond_bb); + + // Optional binding: while val := expr { ... } + if (while_expr.binding_name) |binding_name| { + const opt_val = try self.genExpr(while_expr.condition); + const opt_ty = self.inferType(while_expr.condition); + if (!opt_ty.isOptional()) return self.emitError("'while val := expr' requires an optional expression"); + + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse + return self.emitError("unknown optional inner type"); + + const has_val = self.optionalHasValue(opt_val, opt_ty); + self.condBr(has_val, body_bb, after_bb); + + // Body block: bind the unwrapped value + self.positionAt(body_bb); + const saved_break_bb = self.loop_break_bb; + const saved_continue_bb = self.loop_continue_bb; + self.loop_break_bb = after_bb; + self.loop_continue_bb = cond_bb; + + const payload = self.optionalPayload(opt_val, opt_ty); + const alloca = try self.buildNamedAlloca(self.typeToLLVM(child_ty), binding_name); + _ = c.LLVMBuildStore(self.builder, payload, alloca); + try self.registerVariable(binding_name, alloca, child_ty); + + _ = try self.genExpr(while_expr.body); + + self.loop_break_bb = saved_break_bb; + self.loop_continue_bb = saved_continue_bb; + + const current_bb = self.getCurrentBlock(); + if (c.LLVMGetBasicBlockTerminator(current_bb) == null) { + self.br(cond_bb); + } + + self.positionAt(after_bb); + return null; + } + const cond_val = self.valueToBool(try self.genExpr(while_expr.condition)); self.condBr(cond_val, body_bb, after_bb); @@ -7617,6 +8222,145 @@ pub const CodeGen = struct { return 0; } + /// Generate match expression for optional types: case .some: (val) { ... } case .none: { ... } + fn genOptionalMatch(self: *CodeGen, match: ast.MatchExpr, opt_ty: Type) !c.LLVMValueRef { + const opt_val = try self.genExpr(match.subject); + const has_val = self.optionalHasValue(opt_val, opt_ty); + + const merge_bb = self.appendBB("opt_match_end"); + const some_bb = self.appendBB("opt_some"); + const none_bb = self.appendBB("opt_none"); + + // Find .some and .none arms + var some_arm: ?ast.MatchArm = null; + var none_arm: ?ast.MatchArm = null; + var else_arm: ?ast.MatchArm = null; + for (match.arms) |arm| { + if (arm.pattern) |pat| { + if (pat.data == .enum_literal) { + if (std.mem.eql(u8, pat.data.enum_literal.name, "some")) { + some_arm = arm; + } else if (std.mem.eql(u8, pat.data.enum_literal.name, "none")) { + none_arm = arm; + } + } + } else { + else_arm = arm; + } + } + + // Branch on has_value: 1 = some, 0 = none + _ = c.LLVMBuildCondBr(self.builder, has_val, some_bb, none_bb); + + var phi_vals = std.ArrayList(c.LLVMValueRef).empty; + var phi_bbs = std.ArrayList(c.LLVMBasicBlockRef).empty; + var has_result = false; + var result_type: c.LLVMTypeRef = null; + + // Generate .some arm + self.positionAt(some_bb); + const some_val = blk: { + if (some_arm) |arm| { + if (arm.is_break) { + self.br(merge_bb); + break :blk @as(?c.LLVMValueRef, null); + } + // Payload capture for .some + if (arm.capture) |cap_name| { + const payload = self.optionalPayload(opt_val, opt_ty); + const child_name = opt_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse unreachable; + const payload_llvm_ty = self.typeToLLVM(child_ty); + const cap_alloca = c.LLVMBuildAlloca(self.builder, payload_llvm_ty, @ptrCast(cap_name.ptr)); + _ = c.LLVMBuildStore(self.builder, payload, cap_alloca); + try self.named_values.put(cap_name, .{ .ptr = cap_alloca, .ty = child_ty }); + } + const val = try self.genExpr(arm.body); + break :blk val; + } else if (else_arm) |arm| { + const val = try self.genExpr(arm.body); + break :blk val; + } else { + self.br(merge_bb); + break :blk @as(?c.LLVMValueRef, null); + } + }; + // Record .some arm result (before branching, capture current BB) + if (some_val != null and c.LLVMGetTypeKind(c.LLVMTypeOf(some_val.?)) != c.LLVMVoidTypeKind) { + has_result = true; + if (result_type == null) result_type = c.LLVMTypeOf(some_val.?); + } + const some_out_bb = self.getCurrentBlock(); + if (some_val != null or (some_arm != null and !some_arm.?.is_break) or else_arm != null) { + // Only br if we didn't already (break case already branched) + if (some_arm == null or !some_arm.?.is_break) { + self.br(merge_bb); + } + } + + // Generate .none arm + self.positionAt(none_bb); + const none_val = blk: { + if (none_arm) |arm| { + if (arm.is_break) { + self.br(merge_bb); + break :blk @as(?c.LLVMValueRef, null); + } + const val = try self.genExpr(arm.body); + break :blk val; + } else if (else_arm) |arm| { + const val = try self.genExpr(arm.body); + break :blk val; + } else { + self.br(merge_bb); + break :blk @as(?c.LLVMValueRef, null); + } + }; + if (none_val != null and c.LLVMGetTypeKind(c.LLVMTypeOf(none_val.?)) != c.LLVMVoidTypeKind) { + has_result = true; + if (result_type == null) result_type = c.LLVMTypeOf(none_val.?); + } + const none_out_bb = self.getCurrentBlock(); + if (none_val != null or (none_arm != null and !none_arm.?.is_break) or else_arm != null) { + if (none_arm == null or !none_arm.?.is_break) { + self.br(merge_bb); + } + } + + self.positionAt(merge_bb); + + if (has_result and result_type != null) { + // Convert values to match result_type (handle int width mismatches) + const vals_to_add = [_]struct { val: ?c.LLVMValueRef, bb: c.LLVMBasicBlockRef }{ + .{ .val = some_val, .bb = some_out_bb }, + .{ .val = none_val, .bb = none_out_bb }, + }; + for (&vals_to_add) |entry| { + const v = entry.val orelse continue; + const vty = c.LLVMTypeOf(v); + if (c.LLVMGetTypeKind(vty) == c.LLVMVoidTypeKind) continue; + var converted = v; + if (vty != result_type) { + const src_kind = c.LLVMGetTypeKind(vty); + const dst_kind = c.LLVMGetTypeKind(result_type); + if (src_kind == c.LLVMIntegerTypeKind and dst_kind == c.LLVMIntegerTypeKind) { + const src_bits = c.LLVMGetIntTypeWidth(vty); + const dst_bits = c.LLVMGetIntTypeWidth(result_type); + if (src_bits > dst_bits) { + converted = self.trunc(v, result_type, "match_trunc"); + } else { + converted = c.LLVMBuildSExt(self.builder, v, result_type, "match_sext"); + } + } + } + try phi_vals.append(self.allocator, converted); + try phi_bbs.append(self.allocator, entry.bb); + } + return try self.buildPhiNode(&phi_vals, &phi_bbs, result_type, "opt_matchtmp"); + } + return null; + } + fn genMatchExpr(self: *CodeGen, match: ast.MatchExpr) !c.LLVMValueRef { // Determine subject type for enum vs union dispatch var enum_name: ?[]const u8 = null; @@ -7625,6 +8369,11 @@ pub const CodeGen = struct { if (subject_ty.isEnum()) enum_name = subject_ty.enum_type; if (subject_ty.isUnion()) union_name = subject_ty.union_type; + // Special case: optional type matching with .some/.none + if (subject_ty.isOptional()) { + return self.genOptionalMatch(match, subject_ty); + } + // Get the switch value: for unions, load the tag from field 0; for enums, use the value directly // For union subjects, we need a pointer for both tag loading and payload capture. // If the subject is a simple identifier, use its existing alloca; otherwise generate @@ -8051,6 +8800,7 @@ pub const CodeGen = struct { return switch (node.data) { .type_expr => |te| te.name, .identifier => |id| id.name, + .optional_type_expr => |ote| std.fmt.allocPrint(self.allocator, "?{s}", .{self.typeNodeToString(ote.inner_type)}) catch "?", .pointer_type_expr => |pte| std.fmt.allocPrint(self.allocator, "*{s}", .{self.typeNodeToString(pte.pointee_type)}) catch "?", .many_pointer_type_expr => |mpte| std.fmt.allocPrint(self.allocator, "[*]{s}", .{self.typeNodeToString(mpte.element_type)}) catch "?", .slice_type_expr => |ste| std.fmt.allocPrint(self.allocator, "[]{s}", .{self.typeNodeToString(ste.element_type)}) catch "?", @@ -8194,6 +8944,8 @@ pub const CodeGen = struct { }, .chained_comparison => return .boolean, .identifier => |ident| { + // Flow-sensitive narrowing: return narrowed type + if (self.narrowed_types.get(ident.name)) |ni| return ni.narrowed_ty; if (self.lookupValue(ident.name)) |v| return v.ty(); return Type.s(64); }, @@ -8378,6 +9130,20 @@ pub const CodeGen = struct { } return self.inferType(unop.operand); }, + .null_coalesce => |nc| { + const opt_ty = self.inferType(nc.lhs); + if (opt_ty.isOptional()) { + return self.resolveTypeFromName(opt_ty.optional_type.child_name) orelse Type.s(64); + } + return Type.s(64); + }, + .force_unwrap => |fu| { + const opt_ty = self.inferType(fu.operand); + if (opt_ty.isOptional()) { + return self.resolveTypeFromName(opt_ty.optional_type.child_name) orelse Type.s(64); + } + return Type.s(64); + }, .deref_expr => |de| { const ptr_ty = self.inferType(de.operand); if (ptr_ty.isPointer()) return self.resolveTypeFromName(ptr_ty.pointer_type.pointee_name) orelse Type.s(64); @@ -8385,6 +9151,17 @@ pub const CodeGen = struct { }, .null_literal => return .{ .pointer_type = .{ .pointee_name = "void" } }, .field_access => |fa| { + // Optional chaining: x?.field → ?FieldType + if (fa.is_optional) { + const opt_obj_ty = self.inferType(fa.object); + if (opt_obj_ty.isOptional()) { + const child_name = opt_obj_ty.optional_type.child_name; + const child_ty = self.resolveTypeFromName(child_name) orelse return Type.s(64); + const field_ty = self.inferFieldType(child_ty, fa.field) orelse return Type.s(64); + const dn = field_ty.displayName(self.allocator) catch return Type.s(64); + return Type{ .optional_type = .{ .child_name = dn } }; + } + } var obj_ty = self.inferType(fa.object); // Auto-deref: if pointer, unwrap to pointee if (obj_ty.isPointer()) { diff --git a/src/comptime.zig b/src/comptime.zig index 9f39eb6..41482fb 100644 --- a/src/comptime.zig +++ b/src/comptime.zig @@ -268,6 +268,9 @@ pub const Instruction = union(enum) { // Code insertion eval_insert: InsertInfo, // pop string, parse as code, compile + execute inline + // Optionals + opt_unwrap, // pop value, error if null_val, else push back + // Unions make_union: UnionMake, get_union_field: UnionFieldAccess, @@ -826,18 +829,43 @@ pub const Compiler = struct { } }, .if_expr => |ie| { - try self.compileNode(ie.condition); - const jump_false_idx = self.instructions.items.len; - try self.emit(.{ .jump_if_false = 0 }); // placeholder - try self.compileNode(ie.then_branch); - if (ie.else_branch) |eb| { - const jump_end_idx = self.instructions.items.len; - try self.emit(.{ .jump = 0 }); // placeholder - self.patchJumpIfFalse(jump_false_idx); - try self.compileNode(eb); - self.patchJump(jump_end_idx); + if (ie.binding_name) |binding_name| { + // if val := optional_expr { ... } else { ... } + try self.compileNode(ie.condition); + // Dup the optional value, test truthiness + try self.emit(.dup); + const jump_false_idx = self.instructions.items.len; + try self.emit(.{ .jump_if_false = 0 }); // placeholder + // Non-null path: the value is on the stack, bind as local + const slot: u16 = @intCast(self.locals.items.len); + try self.locals.append(self.allocator, .{ .name = binding_name, .depth = self.scope_depth }); + try self.emit(.{ .set_local = slot }); + try self.compileNode(ie.then_branch); + if (ie.else_branch) |eb| { + const jump_end_idx = self.instructions.items.len; + try self.emit(.{ .jump = 0 }); // placeholder + self.patchJumpIfFalse(jump_false_idx); + try self.emit(.pop); // discard the null value + try self.compileNode(eb); + self.patchJump(jump_end_idx); + } else { + self.patchJumpIfFalse(jump_false_idx); + try self.emit(.pop); // discard the null value + } } else { - self.patchJumpIfFalse(jump_false_idx); + try self.compileNode(ie.condition); + const jump_false_idx = self.instructions.items.len; + try self.emit(.{ .jump_if_false = 0 }); // placeholder + try self.compileNode(ie.then_branch); + if (ie.else_branch) |eb| { + const jump_end_idx = self.instructions.items.len; + try self.emit(.{ .jump = 0 }); // placeholder + self.patchJumpIfFalse(jump_false_idx); + try self.compileNode(eb); + self.patchJump(jump_end_idx); + } else { + self.patchJumpIfFalse(jump_false_idx); + } } }, .call => |call_node| { @@ -1066,6 +1094,20 @@ pub const Compiler = struct { .field_names = fnames, } }); }, + .force_unwrap => |fu| { + try self.compileNode(fu.operand); + try self.emit(.opt_unwrap); + }, + .null_coalesce => |nc| { + // x ?? y: evaluate x, if non-null keep it, else evaluate y + try self.compileNode(nc.lhs); + try self.emit(.dup); + const jump_idx = self.instructions.items.len; + try self.emit(.{ .jump_if_true = 0 }); // placeholder + try self.emit(.pop); // discard the null + try self.compileNode(nc.rhs); + self.patchJumpIfTrue(jump_idx); + }, .ufcs_alias => {}, // UFCS aliases are resolved at codegen, no-op in comptime else => { return error.UnsupportedExpression; @@ -1574,6 +1616,12 @@ pub const VM = struct { } }, + .opt_unwrap => { + const val = try self.pop(); + if (val == .null_val) return error.NullDereference; + try self.push(val); + }, + // Code insertion .eval_insert => |info| { // Pop the code string (result of evaluating the inner expression) diff --git a/src/lexer.zig b/src/lexer.zig index 015a75c..e91d7cc 100644 --- a/src/lexer.zig +++ b/src/lexer.zig @@ -204,6 +204,17 @@ pub const Lexer = struct { return self.makeToken(.caret, start, self.index); }, '~' => return self.makeToken(.tilde, start, self.index), + '?' => { + if (self.peek() == '?') { + self.index += 1; + return self.makeToken(.question_question, start, self.index); + } + if (self.peek() == '.') { + self.index += 1; + return self.makeToken(.question_dot, start, self.index); + } + return self.makeToken(.question, start, self.index); + }, '!' => { if (self.peek() == '=') { self.index += 1; diff --git a/src/lsp/server.zig b/src/lsp/server.zig index 12e3943..b4589e2 100644 --- a/src/lsp/server.zig +++ b/src/lsp/server.zig @@ -76,6 +76,8 @@ pub const Server = struct { if (params) |p| self.handleSignatureHelp(id, p) catch |e| self.logError(method, e); } else if (std.mem.eql(u8, method, "textDocument/semanticTokens/full")) { if (params) |p| self.handleSemanticTokens(id, p) catch |e| self.logError(method, e); + } else if (std.mem.eql(u8, method, "textDocument/inlayHint")) { + if (params) |p| self.handleInlayHint(id, p) catch |e| self.logError(method, e); } return true; @@ -1015,6 +1017,325 @@ pub const Server = struct { try self.sendResponse(id_json, result_json); } + // ---- Inlay hints ---- + + fn handleInlayHint(self: *Server, id: ?std.json.Value, params: std.json.Value) !void { + const ctx = try self.extractRequest(id, params) orelse return; + const id_json = ctx.id_json; + const file_path = uriToFilePath(ctx.uri) orelse ""; + + const doc = self.documents.get(file_path) orelse { + return try self.sendResponse(id_json, "[]"); + }; + const sema = doc.sema orelse doc.last_good_sema orelse { + return try self.sendResponse(id_json, "[]"); + }; + const root = doc.root orelse { + return try self.sendResponse(id_json, "[]"); + }; + + var hints = std.ArrayList(lsp.InlayHint).empty; + collectInlayHints(self.allocator, root, sema.symbols, doc.source, &hints); + self.collectCallHints(doc, root, &hints); + const result_json = try lsp.inlayHintsJson(self.allocator, hints.items); + try self.sendResponse(id_json, result_json); + } + + fn collectInlayHints( + allocator: std.mem.Allocator, + node: *const sx.ast.Node, + symbols: []const sx.sema.Symbol, + source: [:0]const u8, + hints: *std.ArrayList(lsp.InlayHint), + ) void { + switch (node.data) { + .root => |r| { + for (r.decls) |decl| collectInlayHints(allocator, decl, symbols, source, hints); + }, + .block => |b| { + for (b.stmts) |stmt| collectInlayHints(allocator, stmt, symbols, source, hints); + }, + .fn_decl => |fd| { + collectInlayHints(allocator, fd.body, symbols, source, hints); + }, + .lambda => |lm| { + collectInlayHints(allocator, lm.body, symbols, source, hints); + }, + .if_expr => |ie| { + if (ie.binding_name) |bname| { + addBindingHint(allocator, bname, node.span, symbols, source, hints); + } + collectInlayHints(allocator, ie.then_branch, symbols, source, hints); + if (ie.else_branch) |eb| collectInlayHints(allocator, eb, symbols, source, hints); + }, + .while_expr => |we| { + if (we.binding_name) |bname| { + addBindingHint(allocator, bname, node.span, symbols, source, hints); + } + collectInlayHints(allocator, we.body, symbols, source, hints); + }, + .for_expr => |fe| { + collectInlayHints(allocator, fe.body, symbols, source, hints); + }, + .var_decl => |vd| { + // Only show hint when type is inferred (:= syntax) + if (vd.type_annotation != null) return; + if (vd.value == null) return; + addHintForDecl(allocator, vd.name, node.span, symbols, source, hints, true); + }, + .const_decl => |cd| { + // Skip if explicit type annotation + if (cd.type_annotation != null) return; + // Skip functions, types, structs, enums, unions, comptime, foreign, library + switch (cd.value.data) { + .lambda, .fn_decl, .type_expr, .struct_decl, .enum_decl, .union_decl, + .comptime_expr, .foreign_expr, .library_decl, + => return, + else => {}, + } + addHintForDecl(allocator, cd.name, node.span, symbols, source, hints, false); + }, + else => {}, + } + } + + fn addHintForDecl( + allocator: std.mem.Allocator, + name: []const u8, + span: sx.ast.Span, + symbols: []const sx.sema.Symbol, + source: [:0]const u8, + hints: *std.ArrayList(lsp.InlayHint), + is_colon_equal: bool, + ) void { + // Find symbol by matching span start + const sym = findSymbolAtSpan(symbols, span.start, name) orelse return; + const ty = sym.ty orelse return; + + // Skip void types — not useful to display + if (ty == .void_type) return; + + const type_name = ty.displayName(allocator) catch return; + + if (is_colon_equal) { + // For `:=` declarations: place hint between `:` and `=` + // Scan from after the name to find `:=` + var pos = span.start + @as(u32, @intCast(name.len)); + while (pos + 1 < source.len) : (pos += 1) { + if (source[pos] == ':' and source[pos + 1] == '=') { + // Place hint at the `=` position (between `:` and `=`) + const eq_offset = pos + 1; + const loc = sx.errors.SourceLoc.compute(source, eq_offset); + if (loc.line == 0 or loc.col == 0) return; + hints.append(allocator, .{ + .line = loc.line - 1, + .character = loc.col - 1, + .label = type_name, + .padding_left = true, + .padding_right = true, + }) catch {}; + return; + } + } + } else { + // For `::` declarations: place hint between first `:` and second `:` + var pos = span.start + @as(u32, @intCast(name.len)); + while (pos + 1 < source.len) : (pos += 1) { + if (source[pos] == ':' and source[pos + 1] == ':') { + const second_colon = pos + 1; + const loc = sx.errors.SourceLoc.compute(source, second_colon); + if (loc.line == 0 or loc.col == 0) return; + hints.append(allocator, .{ + .line = loc.line - 1, + .character = loc.col - 1, + .label = type_name, + .padding_left = true, + .padding_right = true, + }) catch {}; + return; + } + } + } + } + + fn addBindingHint( + allocator: std.mem.Allocator, + name: []const u8, + span: sx.ast.Span, + symbols: []const sx.sema.Symbol, + source: [:0]const u8, + hints: *std.ArrayList(lsp.InlayHint), + ) void { + // Look up symbol by name + span (sema stores binding with if/while node span) + const sym = findSymbolAtSpan(symbols, span.start, name) orelse return; + const ty = sym.ty orelse return; + if (ty == .void_type) return; + + const type_name = ty.displayName(allocator) catch return; + + // Scan from span start to find the `:=` used in the binding + var pos = span.start; + while (pos + 1 < source.len) : (pos += 1) { + if (source[pos] == ':' and source[pos + 1] == '=') { + const eq_offset = pos + 1; + const loc = sx.errors.SourceLoc.compute(source, eq_offset); + if (loc.line == 0 or loc.col == 0) return; + hints.append(allocator, .{ + .line = loc.line - 1, + .character = loc.col - 1, + .label = type_name, + .padding_left = true, + .padding_right = true, + }) catch {}; + return; + } + } + } + + fn findSymbolAtSpan(symbols: []const sx.sema.Symbol, span_start: u32, name: []const u8) ?sx.sema.Symbol { + for (symbols) |sym| { + if (sym.def_span.start == span_start and std.mem.eql(u8, sym.name, name)) { + return sym; + } + } + return null; + } + + // ---- Parameter name hints at call sites ---- + + fn collectCallHints(self: *Server, doc: *const Document, node: *const sx.ast.Node, hints: *std.ArrayList(lsp.InlayHint)) void { + switch (node.data) { + .root => |r| { + for (r.decls) |decl| self.collectCallHints(doc, decl, hints); + }, + .block => |b| { + for (b.stmts) |stmt| self.collectCallHints(doc, stmt, hints); + }, + .fn_decl => |fd| { + self.collectCallHints(doc, fd.body, hints); + }, + .lambda => |lm| { + self.collectCallHints(doc, lm.body, hints); + }, + .if_expr => |ie| { + self.collectCallHints(doc, ie.condition, hints); + self.collectCallHints(doc, ie.then_branch, hints); + if (ie.else_branch) |eb| self.collectCallHints(doc, eb, hints); + }, + .while_expr => |we| { + self.collectCallHints(doc, we.condition, hints); + self.collectCallHints(doc, we.body, hints); + }, + .for_expr => |fe| { + self.collectCallHints(doc, fe.iterable, hints); + self.collectCallHints(doc, fe.body, hints); + }, + .var_decl => |vd| { + if (vd.value) |val| self.collectCallHints(doc, val, hints); + }, + .const_decl => |cd| { + self.collectCallHints(doc, cd.value, hints); + }, + .return_stmt => |rs| { + if (rs.value) |val| self.collectCallHints(doc, val, hints); + }, + .assignment => |a| { + self.collectCallHints(doc, a.value, hints); + }, + .binary_op => |bop| { + self.collectCallHints(doc, bop.lhs, hints); + self.collectCallHints(doc, bop.rhs, hints); + }, + .unary_op => |uop| { + self.collectCallHints(doc, uop.operand, hints); + }, + .call => |c| { + // Recurse into arguments (they may contain nested calls) + for (c.args) |arg| self.collectCallHints(doc, arg, hints); + // Emit parameter name hints for this call + self.emitCallParamHints(doc, c, hints); + }, + .push_stmt => |ps| { + self.collectCallHints(doc, ps.context_expr, hints); + self.collectCallHints(doc, ps.body, hints); + }, + .defer_stmt => |ds| { + self.collectCallHints(doc, ds.expr, hints); + }, + else => {}, + } + } + + fn emitCallParamHints(self: *Server, doc: *const Document, call: sx.ast.Call, hints: *std.ArrayList(lsp.InlayHint)) void { + if (call.args.len == 0) return; + + // Resolve callee name and find function declaration + var param_offset: usize = 0; + const fd = self.resolveCallTarget(doc, call, ¶m_offset) orelse return; + + // Emit hints for each argument + for (call.args, 0..) |arg, i| { + const param_idx = i + param_offset; + if (param_idx >= fd.params.len) break; + + const param = fd.params[param_idx]; + + // Skip variadic params + if (param.is_variadic) break; + + // Skip if arg is an identifier matching the param name + if (arg.data == .identifier) { + if (std.mem.eql(u8, arg.data.identifier.name, param.name)) continue; + } + + // Skip _ params + if (std.mem.eql(u8, param.name, "_")) continue; + + const loc = sx.errors.SourceLoc.compute(doc.source, arg.span.start); + if (loc.line == 0 or loc.col == 0) continue; + + const label = std.fmt.allocPrint(self.allocator, "{s}:", .{param.name}) catch continue; + hints.append(self.allocator, .{ + .line = loc.line - 1, + .character = loc.col - 1, + .label = label, + .padding_left = false, + }) catch {}; + } + } + + fn resolveCallTarget(self: *Server, doc: *const Document, call: sx.ast.Call, param_offset: *usize) ?sx.ast.FnDecl { + param_offset.* = 0; + + if (call.callee.data == .identifier) { + const name = call.callee.data.identifier.name; + return self.findFnDeclByName(doc, name); + } + + if (call.callee.data == .field_access) { + const fa = call.callee.data.field_access; + + // Try namespaced: "ns.func" + if (fa.object.data == .identifier) { + const ns_name = fa.object.data.identifier.name; + const qualified = std.fmt.allocPrint(self.allocator, "{s}.{s}", .{ ns_name, fa.field }) catch return null; + if (self.findFnDeclByName(doc, qualified)) |fd| { + return fd; + } + } + + // Try UFCS: bare function name, skip first param (receiver) + if (self.findFnDeclByName(doc, fa.field)) |fd| { + if (fd.params.len == call.args.len + 1) { + param_offset.* = 1; + } + return fd; + } + } + + return null; + } + fn classifyToken(tok: sx.token.Token, sema: SemaResult, source: [:0]const u8) ?u32 { const ST = lsp.SemanticTokenType; return switch (tok.tag) { @@ -1084,6 +1405,9 @@ pub const Server = struct { .pipe_arrow, .caret, .caret_equal, + .question, + .question_question, + .question_dot, .tilde, .less_less, .less_less_equal, diff --git a/src/lsp/types.zig b/src/lsp/types.zig index 232e395..4418331 100644 --- a/src/lsp/types.zig +++ b/src/lsp/types.zig @@ -108,7 +108,8 @@ pub fn initializeResultJson(allocator: std.mem.Allocator) ![]const u8 { "\"semanticTokensProvider\":{{\"legend\":{{" ++ "\"tokenTypes\":[\"namespace\",\"type\",\"enum\",\"struct\",\"parameter\",\"variable\",\"enumMember\",\"function\",\"keyword\",\"number\",\"string\",\"operator\"]," ++ "\"tokenModifiers\":[\"declaration\",\"readonly\"]" ++ - "}},\"full\":true}}}}}}", + "}},\"full\":true}}," ++ + "\"inlayHintProvider\":true}}}}", .{}, ); } @@ -358,3 +359,29 @@ pub fn publishDiagnosticsJson(allocator: std.mem.Allocator, uri: []const u8, dia try buf.appendSlice(allocator, "]}"); return buf.items; } + +pub const InlayHint = struct { + line: u32, + character: u32, + label: []const u8, + kind: u32 = 1, // 1 = Type + padding_left: bool = true, + padding_right: bool = false, +}; + +/// Build inlay hints JSON array response. +pub fn inlayHintsJson(allocator: std.mem.Allocator, hints: []const InlayHint) ![]const u8 { + var buf = std.ArrayList(u8).empty; + try buf.append(allocator, '['); + for (hints, 0..) |hint, idx| { + if (idx > 0) try buf.append(allocator, ','); + const label_escaped = try jsonString(allocator, hint.label); + const json = try std.fmt.allocPrint(allocator, + "{{\"position\":{{\"line\":{d},\"character\":{d}}},\"label\":{s},\"kind\":{d},\"paddingLeft\":{s},\"paddingRight\":{s}}}", + .{ hint.line, hint.character, label_escaped, hint.kind, if (hint.padding_left) "true" else "false", if (hint.padding_right) "true" else "false" }, + ); + try buf.appendSlice(allocator, json); + } + try buf.append(allocator, ']'); + return buf.items; +} diff --git a/src/parser.zig b/src/parser.zig index e4327ba..8d10900 100644 --- a/src/parser.zig +++ b/src/parser.zig @@ -327,6 +327,13 @@ pub const Parser = struct { fn parseTypeExpr(self: *Parser) anyerror!*Node { const start = self.current.loc.start; + // Optional type: ?T + if (self.current.tag == .question) { + self.advance(); // skip '?' + const inner_type = try self.parseTypeExpr(); + return try self.createNode(start, .{ .optional_type_expr = .{ .inner_type = inner_type } }); + } + // Pointer type: *T if (self.current.tag == .star) { self.advance(); // skip '*' @@ -1128,6 +1135,14 @@ pub const Parser = struct { continue; } + // Null coalescing: expr ?? default + if (self.current.tag == .question_question and Prec.null_coalesce >= min_prec) { + self.advance(); + const rhs = try self.parseBinary(Prec.null_coalesce + 1); + lhs = try self.createNode(lhs.span.start, .{ .null_coalesce = .{ .lhs = lhs, .rhs = rhs } }); + continue; + } + const prec = self.binaryPrec(); if (prec == 0 or prec < min_prec) break; @@ -1291,6 +1306,20 @@ pub const Parser = struct { } else { return self.fail("expected field name or index after '.'"); } + } else if (self.current.tag == .question_dot) { + // Optional chaining: expr?.field + self.advance(); + if (self.current.tag == .identifier) { + const field = self.tokenSlice(self.current); + self.advance(); + expr = try self.createNode(expr.span.start, .{ .field_access = .{ .object = expr, .field = field, .is_optional = true } }); + } else if (self.current.tag == .int_literal) { + const field = self.tokenSlice(self.current); + self.advance(); + expr = try self.createNode(expr.span.start, .{ .field_access = .{ .object = expr, .field = field, .is_optional = true } }); + } else { + return self.fail("expected field name after '?.'"); + } } else if (self.current.tag == .l_bracket) { // Index or slice access: expr[expr] or expr[start..end] self.advance(); @@ -1323,6 +1352,11 @@ pub const Parser = struct { } }); } } + } else if (self.current.tag == .bang) { + // Force unwrap: expr! + // Only if it's not != (bang_equal would have been lexed as a single token) + self.advance(); + expr = try self.createNode(expr.span.start, .{ .force_unwrap = .{ .operand = expr } }); } else { break; } @@ -1532,6 +1566,32 @@ pub const Parser = struct { const start = self.current.loc.start; self.advance(); // skip 'if' + // Optional binding: if val := expr { ... } + // Detect: identifier followed by := + if (self.current.tag == .identifier and self.peekNext() == .colon_equal) { + const binding_name = self.tokenSlice(self.current); + self.advance(); // skip identifier + self.advance(); // skip := + const source_expr = try self.parseExpr(); + const then_branch = try self.parseBlock(); + var else_branch: ?*Node = null; + if (self.current.tag == .kw_else) { + self.advance(); + if (self.current.tag == .kw_if) { + else_branch = try self.parseIfExpr(); + } else { + else_branch = try self.parseBlock(); + } + } + return try self.createNode(start, .{ .if_expr = .{ + .condition = source_expr, + .then_branch = then_branch, + .else_branch = else_branch, + .is_inline = false, + .binding_name = binding_name, + } }); + } + // Parse condition above comparison level, leaving comparisons // unconsumed for manual handling with match disambiguation. var condition = try self.parseBinary(Prec.shift); @@ -1627,6 +1687,20 @@ pub const Parser = struct { const start = self.current.loc.start; self.advance(); // skip 'while' + // Optional binding: while val := expr { ... } + if (self.current.tag == .identifier and self.peekNext() == .colon_equal) { + const binding_name = self.tokenSlice(self.current); + self.advance(); // skip identifier + self.advance(); // skip := + const source_expr = try self.parseExpr(); + const body = try self.parseBlock(); + return try self.createNode(start, .{ .while_expr = .{ + .condition = source_expr, + .body = body, + .binding_name = binding_name, + } }); + } + const condition = try self.parseExpr(); const body = try self.parseBlock(); @@ -1934,15 +2008,16 @@ pub const Parser = struct { const Prec = struct { const none: u8 = 0; const pipe: u8 = 1; // |> - const logical_or: u8 = 2; // or - const logical_and: u8 = 3; // and - const bit_or: u8 = 4; // | - const bit_xor: u8 = 5; // ^ - const bit_and: u8 = 6; // & - const comparison: u8 = 7; // == != < <= > >= in - const shift: u8 = 8; // << >> - const additive: u8 = 9; // + - - const multiplicative: u8 = 10; // * / % + const null_coalesce: u8 = 2; // ?? + const logical_or: u8 = 3; // or + const logical_and: u8 = 4; // and + const bit_or: u8 = 5; // | + const bit_xor: u8 = 6; // ^ + const bit_and: u8 = 7; // & + const comparison: u8 = 8; // == != < <= > >= in + const shift: u8 = 9; // << >> + const additive: u8 = 10; // + - + const multiplicative: u8 = 11; // * / % }; fn binaryPrec(self: *const Parser) u8 { diff --git a/src/sema.zig b/src/sema.zig index 6310b59..57f683f 100644 --- a/src/sema.zig +++ b/src/sema.zig @@ -278,6 +278,13 @@ pub const Analyzer = struct { const elem_name = elem_type.displayName(self.allocator) catch return .void_type; return .{ .slice_type = .{ .element_name = elem_name } }; } + // Optional type: ?T + if (tn.data == .optional_type_expr) { + const ote = tn.data.optional_type_expr; + const inner_type = self.resolveTypeNode(ote.inner_type); + const inner_name = inner_type.displayName(self.allocator) catch return .void_type; + return .{ .optional_type = .{ .child_name = inner_name } }; + } // Pointer type: *T if (tn.data == .pointer_type_expr) { const pte = tn.data.pointer_type_expr; @@ -456,6 +463,16 @@ pub const Analyzer = struct { } return .void_type; }, + .force_unwrap => |fu| { + const opt_ty = self.inferExprType(fu.operand); + if (opt_ty.isOptional()) return Type.fromName(opt_ty.optional_type.child_name) orelse .void_type; + return .void_type; + }, + .null_coalesce => |nc| { + const opt_ty = self.inferExprType(nc.lhs); + if (opt_ty.isOptional()) return Type.fromName(opt_ty.optional_type.child_name) orelse .void_type; + return self.inferExprType(nc.rhs); + }, .deref_expr => |de| { const ptr_ty = self.inferExprType(de.operand); if (ptr_ty.isPointer()) return ptr_ty.pointerPointeeType() orelse .void_type; @@ -711,7 +728,20 @@ pub const Analyzer = struct { }, .if_expr => |ie| { try self.analyzeNode(ie.condition); - try self.analyzeNode(ie.then_branch); + if (ie.binding_name) |bname| { + // `if val := expr { ... }` — val is the unwrapped optional + const cond_ty = self.inferExprType(ie.condition); + const inner_ty: ?Type = if (cond_ty.isOptional()) + Type.fromName(cond_ty.optional_type.child_name) + else + null; + try self.pushScope(); + try self.addSymbol(bname, .variable, inner_ty, node.span); + try self.analyzeNode(ie.then_branch); + self.popScope(); + } else { + try self.analyzeNode(ie.then_branch); + } if (ie.else_branch) |eb| { try self.analyzeNode(eb); } @@ -729,7 +759,19 @@ pub const Analyzer = struct { }, .while_expr => |we| { try self.analyzeNode(we.condition); - try self.analyzeNode(we.body); + if (we.binding_name) |bname| { + const cond_ty = self.inferExprType(we.condition); + const inner_ty: ?Type = if (cond_ty.isOptional()) + Type.fromName(cond_ty.optional_type.child_name) + else + null; + try self.pushScope(); + try self.addSymbol(bname, .variable, inner_ty, node.span); + try self.analyzeNode(we.body); + self.popScope(); + } else { + try self.analyzeNode(we.body); + } }, .for_expr => |fe| { try self.analyzeNode(fe.iterable); @@ -812,6 +854,7 @@ pub const Analyzer = struct { .slice_type_expr, .pointer_type_expr, .many_pointer_type_expr, + .optional_type_expr, .null_literal, .array_literal, .parameterized_type_expr, @@ -829,6 +872,13 @@ pub const Analyzer = struct { try self.analyzeNode(elem.value); } }, + .force_unwrap => |fu| { + try self.analyzeNode(fu.operand); + }, + .null_coalesce => |nc| { + try self.analyzeNode(nc.lhs); + try self.analyzeNode(nc.rhs); + }, .deref_expr => |de| { try self.analyzeNode(de.operand); }, @@ -864,6 +914,8 @@ pub const Analyzer = struct { .index_expr, .slice_expr, .deref_expr, + .force_unwrap, + .null_coalesce, .null_literal, .type_expr, .insert_expr, @@ -905,7 +957,17 @@ pub const Analyzer = struct { } } } - // For compound types (pointers, slices, arrays), resolve inner type refs + // Compound types: ?T, *T, [*]T, []T, [N]T — delegate to resolveTypeNode + switch (tn.data) { + .optional_type_expr, .pointer_type_expr, .many_pointer_type_expr, + .slice_type_expr, .array_type_expr, + => { + const resolved = self.resolveTypeNode(tn); + if (resolved != .void_type) return resolved; + }, + else => {}, + } + // For compound types, resolve inner type refs self.resolveTypeRef(tn); } return null; @@ -950,6 +1012,9 @@ pub const Analyzer = struct { .array_type_expr => |ate| { self.resolveTypeRef(ate.element_type); }, + .optional_type_expr => |ote| { + self.resolveTypeRef(ote.inner_type); + }, else => {}, } } @@ -1152,6 +1217,7 @@ pub fn findNodeAtOffset(node: *Node, offset: u32) ?*Node { .slice_type_expr, .pointer_type_expr, .many_pointer_type_expr, + .optional_type_expr, .null_literal, .array_literal, .parameterized_type_expr, @@ -1165,6 +1231,13 @@ pub fn findNodeAtOffset(node: *Node, offset: u32) ?*Node { if (findNodeAtOffset(elem.value, offset)) |found| return found; } }, + .null_coalesce => |nc| { + if (findNodeAtOffset(nc.lhs, offset)) |found| return found; + if (findNodeAtOffset(nc.rhs, offset)) |found| return found; + }, + .force_unwrap => |fu| { + if (findNodeAtOffset(fu.operand, offset)) |found| return found; + }, .deref_expr => |de| { if (findNodeAtOffset(de.operand, offset)) |found| return found; }, diff --git a/src/token.zig b/src/token.zig index c5e6a5e..e05878a 100644 --- a/src/token.zig +++ b/src/token.zig @@ -70,6 +70,9 @@ pub const Tag = enum { pipe_arrow, // |> caret, // ^ caret_equal, // ^= + question, // ? + question_question, // ?? + question_dot, // ?. tilde, // ~ less_less, // << less_less_equal, // <<= @@ -142,6 +145,9 @@ pub const Tag = enum { .pipe_arrow => "|>", .caret => "^", .caret_equal => "^=", + .question => "?", + .question_question => "??", + .question_dot => "?.", .tilde => "~", .less_less => "<<", .less_less_equal => "<<=", diff --git a/src/types.zig b/src/types.zig index 7fefdbb..5de3f3f 100644 --- a/src/types.zig +++ b/src/types.zig @@ -23,6 +23,7 @@ pub const Type = union(enum) { vector_type: VectorTypeInfo, function_type: FunctionTypeInfo, any_type, + optional_type: OptionalTypeInfo, meta_type: MetaTypeInfo, tuple_type: TupleTypeInfo, @@ -53,6 +54,10 @@ pub const Type = union(enum) { length: u32, }; + pub const OptionalTypeInfo = struct { + child_name: []const u8, + }; + pub const MetaTypeInfo = struct { name: []const u8, }; @@ -90,6 +95,7 @@ pub const Type = union(enum) { } return info.return_type.eql(o.return_type.*); }, + .optional_type => |info| std.mem.eql(u8, info.child_name, other.optional_type.child_name), .meta_type => |info| std.mem.eql(u8, info.name, other.meta_type.name), .tuple_type => |info| { const o = other.tuple_type; @@ -141,6 +147,7 @@ pub const Type = union(enum) { if (std.mem.eql(u8, name, "f64")) return .f64; return null; }, + '?' => if (name.len >= 2) .{ .optional_type = .{ .child_name = name[1..] } } else null, 'A' => if (std.mem.eql(u8, name, "Any")) .any_type else null, 'v' => if (std.mem.eql(u8, name, "void")) .void_type else null, '[' => { @@ -212,6 +219,20 @@ pub const Type = union(enum) { }; } + pub fn isOptional(self: Type) bool { + return switch (self) { + .optional_type => true, + else => false, + }; + } + + pub fn optionalChild(self: Type) ?[]const u8 { + return switch (self) { + .optional_type => |info| info.child_name, + else => null, + }; + } + pub fn isAny(self: Type) bool { return switch (self) { .any_type => true, @@ -382,6 +403,30 @@ pub const Type = union(enum) { return true; } + // T → ?T: any type implicitly wraps into its optional + if (target.isOptional()) { + const child_name = target.optional_type.child_name; + // null → ?T + if (self.isPointer() and std.mem.eql(u8, self.pointer_type.pointee_name, "void")) return true; + // ?T → ?U when T → U + if (self.isOptional()) { + const self_child = fromName(self.optional_type.child_name) orelse return false; + const target_child = fromName(child_name) orelse return false; + return self_child.isImplicitlyConvertibleTo(target_child); + } + // T → ?T: check if self matches the child type + if (fromName(child_name)) |child_type| { + return self.eql(child_type) or self.isImplicitlyConvertibleTo(child_type); + } + // Non-primitive child (struct/enum name): compare by name + return switch (self) { + .struct_type => |n| std.mem.eql(u8, n, child_name), + .enum_type => |n| std.mem.eql(u8, n, child_name), + .union_type => |n| std.mem.eql(u8, n, child_name), + else => false, + }; + } + const src_float = self.isFloat(); const dst_float = target.isFloat(); const src_int = self.isInt(); @@ -461,6 +506,7 @@ pub const Type = union(enum) { } return try buf.toOwnedSlice(allocator); }, + .optional_type => |info| return fmtAlloc(allocator, "?{s}", .{info.child_name}), .meta_type => |info| info.name, .tuple_type => |info| { var buf = std.ArrayList(u8).empty; @@ -531,6 +577,9 @@ pub const Type = union(enum) { return Type.s(capped); } + // Optional types: widen inner types + if (a.isOptional() and b.isOptional()) return a; + // Pointer types: both are pointers → return first (all are opaque ptr at LLVM level) if ((a.isPointer() or a.isManyPointer()) and (b.isPointer() or b.isManyPointer())) return a; diff --git a/tests/expected/32-optionals.exit b/tests/expected/32-optionals.exit new file mode 100644 index 0000000..573541a --- /dev/null +++ b/tests/expected/32-optionals.exit @@ -0,0 +1 @@ +0 diff --git a/tests/expected/32-optionals.txt b/tests/expected/32-optionals.txt new file mode 100644 index 0000000..fe6a4f6 --- /dev/null +++ b/tests/expected/32-optionals.txt @@ -0,0 +1,20 @@ +x = 42 +y = null +x! = 42 +x ?? 0 = 42 +y ?? 99 = 99 +if-bind x: 42 +if-bind y: none +match some: 42 +match none: 0 +p?.value = 10 +q?.value = 0 +o1.inner?.val = 99 +o2.inner?.val = 0 +narrowed a: 10 +guard 42: 42 +guard null: 0 +both: 10 20 +guard2: 7 +default next: null +comptime: 141 diff --git a/tests/expected/50-smoke.txt b/tests/expected/50-smoke.txt index ed35685..3d40bef 100644 --- a/tests/expected/50-smoke.txt +++ b/tests/expected/50-smoke.txt @@ -249,6 +249,9 @@ cast-int-f64: 42.000000 run-const: 25 run-expr: 42 run-chain: 30 +ct-opt-coalesce: 141 +ct-opt-unwrap: 77 +ct-opt-guard: 10 insert-ok insert-gen: 42 === 9. Flags === @@ -386,4 +389,41 @@ buf reset: 0 1 == (1) (1) == 1 1 == 1 +--- optionals --- +opt x: 42 +opt y: null +unwrap: 10 +coalesce a: 42 +coalesce b: 99 +if-bind x: 7 +if-bind y: none +match some: 55 +match none: 0 +wrap pos: 5 +wrap neg: null +opt field default: null +opt field set: 42 +opt param a: 42 +opt param b: 0 +opt param 7: 7 +generic opt 1: 5 +generic opt 2: 7 +generic opt 3: null +chain some: 10 +chain none: 0 +chain print: 20 +chain null: null +deep chain 1: 99 +deep chain 2: 0 +narrow x: 42 +narrow y else: null +narrow else x: 42 +guard some: 42 +guard none: 0 +and both: 10 20 +and one null +or guard: 7 +or guard null: 0 +nested narrow: 10 20 +guard loop: 3 === DONE ===