Skip to content

Commit 2713100

Browse files
haijieggbonik
authored andcommitted
Simplify undefined variable in IR
Remove the mutable _undefined flag from Var and instead use InvalidType directly at creation sites (scope lookup, store_invalid, aggregate placeholders) to represent variables with partially defined type. DCE no longer mutates Var state — instead of _mark_unused_vars_as_undefined, a new _replace_pruned_with_dummy function inserts MakeDummy() op for carried variables whose defining ops were pruned, using body vars as type fallbacks when the pruned var has InvalidType. This eliminates the special bytecode helpers (get_value_allow_undefined, get_value_or_zero, undefined_value), allowing Loop/Continue/Break to use plain get_value since all vars are now guaranteed to have defining ops. Downstream passes (alias analysis, code motion) are simplified to use structural checks (map membership, dict.get) instead of querying the removed flag. Signed-off-by: Jay Gu <jagu@nvidia.com>
1 parent 1294598 commit 2713100

File tree

8 files changed

+118
-100
lines changed

8 files changed

+118
-100
lines changed

src/cuda/tile/_ir/ir.py

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -41,18 +41,18 @@ def __init__(self, config: TileContextConfig, tileiras_version: BytecodeVersion)
4141
self.tileiras_version: BytecodeVersion = tileiras_version
4242

4343
# Make a Var with a unique name based on `name`.
44-
def make_var(self, name: str, loc: Loc, undefined: bool = False) -> Var:
44+
def make_var(self, name: str, loc: Loc) -> Var:
4545
var_name = name
4646
while var_name in self._all_vars:
4747
var_name = f"{name}.{next(self._counter_by_name[name])}"
4848
self._all_vars[var_name] = name
49-
return Var(var_name, loc, self, undefined)
49+
return Var(var_name, loc, self)
5050

5151
def make_var_like(self, var: Var) -> Var:
52-
return self.make_var(self.get_original_name(var.name), var.loc, var.is_undefined())
52+
return self.make_var(self.get_original_name(var.name), var.loc)
5353

54-
def make_temp(self, loc: Loc, undefined: bool = False) -> Var:
55-
return self.make_var(f"${next(self._temp_counter)}", loc, undefined=undefined)
54+
def make_temp(self, loc: Loc) -> Var:
55+
return self.make_var(f"${next(self._temp_counter)}", loc)
5656

5757
def get_original_name(self, var_name: str) -> str:
5858
return self._all_vars[var_name]
@@ -154,11 +154,10 @@ def as_tuple(self) -> tuple["Var", ...]:
154154

155155

156156
class Var:
157-
def __init__(self, name: str, loc: Loc, ctx: IRContext, undefined: bool = False):
157+
def __init__(self, name: str, loc: Loc, ctx: IRContext):
158158
self.name = name
159159
self.loc = loc
160160
self.ctx = ctx
161-
self._undefined = undefined
162161

163162
def try_get_type(self) -> Optional[Type]:
164163
return self.ctx.typemap.get(self.name)
@@ -173,9 +172,6 @@ def get_type_allow_invalid(self) -> Type:
173172
try:
174173
return self.ctx.typemap[self.name]
175174
except KeyError:
176-
if self._undefined:
177-
return InvalidType(f"Use of potentially undefined variable"
178-
f" `{self.get_original_name()}`", loc=Loc.unknown())
179175
raise TileInternalError(f"Type of variable {self.name} not found")
180176

181177
def set_type(self, ty: Type, force: bool = False):
@@ -213,12 +209,6 @@ def set_loose_type(self, ty: Type, force: bool = False):
213209
assert self.name not in self.ctx._loose_typemap
214210
self.ctx._loose_typemap[self.name] = ty
215211

216-
def is_undefined(self) -> bool:
217-
return self._undefined
218-
219-
def set_undefined(self):
220-
self._undefined = True
221-
222212
def get_original_name(self) -> str:
223213
return self.ctx.get_original_name(self.name)
224214

src/cuda/tile/_ir/ops.py

Lines changed: 59 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
TupleValue, make_aggregate, RangeValue, BoundMethodValue, ArrayValue, ConstantState,
2626
ListValue, ClosureValue, MemoryEffect, attribute, operand, BlockRestriction
2727
)
28+
from .type import PointerTy
2829
from . import hir
2930
from .hir import ResolvedName
3031
from .op_impl import (
@@ -102,8 +103,8 @@ def body_vars(self) -> tuple[Var, ...]:
102103
@override
103104
def generate_bytecode(self, ctx: BytecodeContext) -> tuple[bc.Value, ...]:
104105
types = tuple(x.get_type() for x in self.body_vars)
105-
initial_values = [ctx.get_value_allow_undefined(input_var, ty)
106-
for input_var, ty in zip(self.initial_values, types, strict=True)]
106+
initial_values = [ctx.get_value(input_var)
107+
for input_var in self.initial_values]
107108
result_type_ids = [typeid(ctx.type_table, ty) for ty in types]
108109

109110
if self.is_for_loop:
@@ -224,9 +225,7 @@ async def loop_impl(body: hir.Block, iterable: Var):
224225
ty = body_var.get_type_allow_invalid()
225226
is_valid = not isinstance(ty, InvalidType)
226227
mask.append(is_valid)
227-
if not is_valid:
228-
body_var.set_undefined()
229-
elif not was_valid and ty.is_aggregate():
228+
if not was_valid and is_valid and ty.is_aggregate():
230229
# The initial variable is invalid but the loop variable is preserved,
231230
# and the loop variable is aggregate. In this case, `flat_body_vars[i]`
232231
# will contain a single variable (previously of InvalidType,
@@ -246,23 +245,28 @@ async def loop_impl(body: hir.Block, iterable: Var):
246245
new_body.params = tuple(body_params)
247246
valid_var_types = tuple(v.get_type() for v, is_valid in zip(body_vars, mask, strict=True)
248247
if is_valid)
248+
flat_var_types = flatten_aggregate_types(valid_var_types)
249249

250250
# Update Continue/Break statements
251251
for jump_info in loop_info.jumps:
252252
values = tuple(out
253253
for out, is_valid in zip(jump_info.outputs, mask, strict=True)
254254
if is_valid)
255255
flat_values = flatten_aggregates(values, valid_var_types)
256+
# For undefined break/continue value, add a MakeDummy op as its producer
257+
flat_values = _add_dummy_op_to_invalid_vars(flat_values, flat_var_types)
256258
assert len(flat_values) == len(all_flattened_body_vars)
257259
jump_info.jump_op.values = flat_values
258260

259261
# Create the loop Operation
260-
valid_initial_values = tuple(v for v, is_valid in zip(initial_values, mask, strict=True)
262+
valid_initial_values = tuple(v for v, is_valid
263+
in zip(initial_values, mask, strict=True)
261264
if is_valid)
262265
flat_initial_values = flatten_aggregates(valid_initial_values, valid_var_types)
266+
# For undefined initial value, add a MakeDummy op as its producer
267+
flat_initial_values = _add_dummy_op_to_invalid_vars(flat_initial_values, flat_var_types)
263268
assert len(flat_initial_values) == len(all_flattened_body_vars)
264269

265-
flat_var_types = flatten_aggregate_types(valid_var_types)
266270
if range_ty is None:
267271
start = stop = step = None
268272
else:
@@ -297,7 +301,8 @@ async def loop_impl(body: hir.Block, iterable: Var):
297301
for body_var, state, local_idx, is_valid in zip(body_vars, var_states, stored_locals, mask,
298302
strict=True):
299303
if not is_valid:
300-
store_undefined(local_idx, body_var.get_type_allow_invalid(), state.result_phi.last_loc)
304+
store_invalid(local_idx, body_var.get_type_allow_invalid(),
305+
state.result_phi.last_loc)
301306

302307
# Do this check at the end because this may be an automatically inserted loop
303308
# around the helper function's body.
@@ -465,14 +470,16 @@ async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block)
465470
assert num_explicit_results == 1
466471
ret = all_results[0]
467472
if ret is None:
468-
ret = builder.ir_ctx.make_temp(builder.loc, undefined=True)
473+
assert isinstance(result_phis[0].ty, InvalidType)
474+
ret = builder.ir_ctx.make_temp(builder.loc)
475+
ret.set_type(result_phis[0].ty)
469476

470477
# Update the scope for stored named
471478
for res_var, phi, local_idx in zip(all_results[num_explicit_results:],
472479
result_phis[num_explicit_results:],
473480
stored_locals, strict=True):
474481
if res_var is None:
475-
store_undefined(local_idx, phi.ty, phi.last_loc)
482+
store_invalid(local_idx, phi.ty, phi.last_loc)
476483
else:
477484
store_var(local_idx, res_var, phi.last_loc)
478485

@@ -486,9 +493,7 @@ class Continue(Operation, opcode="continue", terminator=True):
486493

487494
@override
488495
def generate_bytecode(self, ctx: BytecodeContext) -> tuple[()]:
489-
next_values = [ctx.get_value_allow_undefined(var, ctx.typeof(body_var))
490-
for var, body_var
491-
in zip(self.values, ctx.innermost_loop.body_vars, strict=True)]
496+
next_values = [ctx.get_value(var) for var in self.values]
492497
bc.encode_ContinueOp(ctx.builder, next_values)
493498
return ()
494499

@@ -517,11 +522,7 @@ class Break(Operation, opcode="break", terminator=True):
517522

518523
@override
519524
def generate_bytecode(self, ctx: BytecodeContext) -> tuple[()]:
520-
# body_vars is not a typo. We use body variables because they always contain the actual
521-
# types of the loop variables, whereas result variables may have an InvalidType.
522-
output_values = [ctx.get_value_allow_undefined(var, ctx.typeof(body_var))
523-
for var, body_var
524-
in zip(self.values, ctx.innermost_loop.body_vars, strict=True)]
525+
output_values = [ctx.get_value(var) for var in self.values]
525526
bc.encode_BreakOp(ctx.builder, output_values)
526527
return ()
527528

@@ -616,6 +617,27 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
616617
return ctx.constant(self.value, ctx.typeof(self.result_var))
617618

618619

620+
@dataclass(eq=False)
621+
class MakeDummy(Operation, opcode="make_dummy"):
622+
"""Placeholder value inserted for undefined variable in loop.
623+
624+
The use case for undefined variables is to represent loop's
625+
initial_values or continue/break's next_values during type inference or
626+
post dead code elimination.
627+
"""
628+
629+
@override
630+
def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
631+
ty = ctx.typeof(self.result_var)
632+
if isinstance(ty, TokenTy):
633+
return bc.encode_MakeTokenOp(ctx.builder, ctx.type_table.Token)
634+
if isinstance(ty, TileTy) and isinstance(ty.dtype, PointerTy):
635+
int_ty = TileTy(dtype=datatype.int64, shape=ty.shape)
636+
const = ctx.constant(0, int_ty)
637+
return bc.encode_IntToPtrOp(ctx.builder, typeid(ctx.type_table, ty), const)
638+
return ctx.constant(0, ty)
639+
640+
619641
def loosely_typed_const(value: Any,
620642
ty: Optional[Type] = None,
621643
loose_ty: Optional[Type] = None) -> Var:
@@ -1583,8 +1605,6 @@ def _to_string_rhs(self) -> str:
15831605
def assign(value: Var, res: Var) -> None:
15841606
Builder.get_current().append_verbatim(Assign(value=value, result_vars=(res,), loc=res.loc))
15851607
res.ctx.copy_type_information(value, res)
1586-
if value.is_undefined():
1587-
res.set_undefined()
15881608

15891609

15901610
@impl(hir.identity)
@@ -1776,8 +1796,12 @@ def flatten_aggregates(vars: Sequence[Var], types: Sequence[Type]) -> tuple[Var,
17761796
ret = []
17771797
for x, ty in zip(vars, types, strict=True):
17781798
item_types = tuple(ty.flatten_aggregate())
1779-
if isinstance(x.get_type_allow_invalid(), InvalidType):
1780-
ret.extend(x.ctx.make_temp(x.loc, undefined=True) for _ in item_types)
1799+
x_ty = x.get_type_allow_invalid()
1800+
if isinstance(x_ty, InvalidType):
1801+
for _ in item_types:
1802+
t = x.ctx.make_temp(x.loc)
1803+
t.set_type(x_ty)
1804+
ret.append(t)
17811805
else:
17821806
items = tuple(x.flatten_aggregate())
17831807
assert len(items) == len(item_types)
@@ -1837,10 +1861,10 @@ def _unflatten_proper_aggregate(flattened_iter: Iterator[Var], nominal: Type, ac
18371861
# Pop values from the iterator and throw them out
18381862
for _ in nominal_item_types:
18391863
next(flattened_iter)
1840-
# Return an undefined variable. It is OK that we don't create an aggregate value for it --
1841-
# any use of it should be invalid anyway.
18421864
builder = Builder.get_current()
1843-
return builder.ir_crx.make_temp(builder.loc, undefined=True)
1865+
t = builder.ir_ctx.make_temp(builder.loc)
1866+
t.set_type(actual)
1867+
return t
18441868

18451869
items = tuple(_maybe_unflatten_aggregate(flattened_iter, item_nominal, item_actual)
18461870
for item_nominal, item_actual
@@ -3991,10 +4015,10 @@ def store_var(local_idx: int, value: Var, loc: Loc | None = None):
39914015
assign(value, new_var)
39924016

39934017

3994-
def store_undefined(local_idx: int, ty: Type, loc: Loc | None = None):
4018+
def store_invalid(local_idx: int, ty: Type, loc: Loc | None = None):
4019+
assert isinstance(ty, InvalidType)
39954020
scope = Scope.get_current()
39964021
new_var = scope.local.redefine(local_idx, loc or Builder.get_current().loc)
3997-
new_var.set_undefined()
39984022
new_var.set_type(ty)
39994023

40004024

@@ -4014,8 +4038,6 @@ def load_var_impl(name):
40144038
if rn.depth >= 0:
40154039
ret = scope.local_scopes[rn.depth][rn.index]
40164040
ret.get_type() # Trigger an InvalidType check
4017-
if ret.is_undefined():
4018-
raise TileSyntaxError(f"Undefined variable {name} used")
40194041
return ret
40204042
elif rn.index >= 0:
40214043
val = scope.func_hir.frozen_global_values[rn.index]
@@ -4206,3 +4228,11 @@ def sym2var(x: Any) -> Var:
42064228

42074229
x = get_constant_value(x)
42084230
return loosely_typed_const(x)
4231+
4232+
4233+
def _add_dummy_op_to_invalid_vars(vars: Sequence[Var],
4234+
actual_types: Sequence[Type]) -> tuple[Var, ...]:
4235+
return tuple(add_operation(MakeDummy, actual)
4236+
if isinstance(v.get_type_allow_invalid(), InvalidType)
4237+
else v
4238+
for v, actual in zip(vars, actual_types, strict=True))

src/cuda/tile/_ir/scope.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from cuda.tile._ir import hir
1414
from cuda.tile._ir.hir import ResolvedName
1515
from cuda.tile._ir.ir import Operation, Var, IRContext
16+
from cuda.tile._ir.type import InvalidType
1617

1718

1819
@dataclass
@@ -71,7 +72,9 @@ def get(self, index: int, loc: Loc):
7172
assert index >= 0
7273
var = self._map[index]
7374
if var is None:
74-
return self._ir_ctx.make_var(self._local_names[index], loc, undefined=True)
75+
name = self._local_names[index]
76+
var = self._ir_ctx.make_var(name, loc)
77+
var.set_type(InvalidType(f"Use of potentially undefined variable `{name}`", loc=loc))
7578
return var
7679

7780
@contextmanager

src/cuda/tile/_ir2bytecode.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -313,9 +313,6 @@ def get_constant_or_default(self, var: Var, default=None):
313313
def get_value(self, var: Var) -> bc.Value:
314314
return self._value_map[var.name]
315315

316-
def get_value_allow_undefined(self, var: Var, ty: Type) -> bc.Value:
317-
return self.undefined_value(ty) if var.is_undefined() else self.get_value(var)
318-
319316
def get_optional_value(self, var: Var) -> Optional[bc.Value]:
320317
if var.name in self._constants and self._constants[var.name] is None:
321318
return None
@@ -365,16 +362,6 @@ def constant_tuple(self, value, ty: Type) -> Tuple[bc.Value, ...]:
365362
for item_ty, item_val in zip(ty.value_types, value, strict=True)), ())
366363
return self.constant(value, ty),
367364

368-
def undefined_value(self, ty: Type) -> bc.Value:
369-
if isinstance(ty, TokenTy):
370-
return bc.encode_MakeTokenOp(self.builder, typeid(self.type_table, ty))
371-
372-
if isinstance(ty, TileTy) and isinstance(ty.dtype, PointerTy):
373-
const = self.constant(0, TileTy(dtype=datatype.int64, shape=ty.shape))
374-
return bc.encode_IntToPtrOp(self.builder, typeid(self.type_table, ty), const)
375-
376-
return self.constant(0, ty)
377-
378365
def index_tuple(self, index: tuple[Var, ...]) -> Tuple[bc.Value, ...]:
379366
i32_tile_ty = self.type_table.tile(self.type_table.I32, ())
380367
item_types = tuple(x.get_type() for x in index)

src/cuda/tile/_passes/alias_analysis.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,9 +83,6 @@ def finalize(self):
8383
def _propagate(alias_tracker: _AliasTracker,
8484
src: Var,
8585
dst: Var):
86-
if src.is_undefined():
87-
alias_tracker[src.name] = frozenset()
88-
8986
src_aliases = alias_tracker[src.name]
9087
dst_aliases = alias_tracker.get(dst.name, frozenset())
9188
alias_tracker[dst.name] = dst_aliases | src_aliases

src/cuda/tile/_passes/code_motion.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55

6-
from cuda.tile._ir.ir import Block, Var, MemoryEffect
6+
from cuda.tile._ir.ir import Block, MemoryEffect
77
from cuda.tile._ir.ops import Loop, IfElse, Continue, Break, EndBranch, Return, TileReduce, TileScan
88

99
from dataclasses import dataclass
@@ -93,10 +93,10 @@ def _hoist(block: Block, stack: list[_StackItem], def_depth: dict[str, int], is_
9393

9494
inputs = op.initial_values if isinstance(op, Loop) else op.xs
9595
for v in inputs:
96-
depinfo.update(_get_def_depth(def_depth, v), depth)
96+
depinfo.update(def_depth[v.name], depth)
9797
depinfo.update(body_res.min_depth, depth)
9898
elif isinstance(op, IfElse):
99-
depinfo.update(_get_def_depth(def_depth, op.cond), depth)
99+
depinfo.update(def_depth[op.cond.name], depth)
100100
for branch in (op.then_block, op.else_block):
101101
branch_res = _hoist(branch, stack, def_depth, False)
102102
depinfo.update(branch_res.min_depth, depth)
@@ -115,12 +115,12 @@ def _hoist(block: Block, stack: list[_StackItem], def_depth: dict[str, int], is_
115115
elif isinstance(op, EndBranch):
116116
depinfo.must_stay = True
117117
for v in op.outputs:
118-
depinfo.update(_get_def_depth(def_depth, v), depth)
118+
depinfo.update(def_depth[v.name], depth)
119119
else:
120120
# "Pure" operation without any nested blocks, side effects and jumps.
121121
assert len(op.nested_blocks) == 0
122122
for v in op.all_inputs():
123-
depinfo.update(_get_def_depth(def_depth, v), depth)
123+
depinfo.update(def_depth[v.name], depth)
124124

125125
target_depth = depth
126126
if depinfo.must_stay:
@@ -140,12 +140,3 @@ def _hoist(block: Block, stack: list[_StackItem], def_depth: dict[str, int], is_
140140
stack.pop()
141141
block[:] = new_block.detach_all()
142142
return ret
143-
144-
145-
def _get_def_depth(def_depth: dict[str, int], var: Var) -> int:
146-
try:
147-
return def_depth[var.name]
148-
except KeyError:
149-
pass
150-
assert var.is_undefined(), var.name
151-
return 0

0 commit comments

Comments
 (0)