2525 TupleValue , make_aggregate , RangeValue , BoundMethodValue , ArrayValue , ConstantState ,
2626 ListValue , ClosureValue , MemoryEffect , attribute , operand , BlockRestriction
2727)
28+ from .type import PointerTy
2829from . import hir
2930from .hir import ResolvedName
3031from .op_impl import (
@@ -102,8 +103,8 @@ def body_vars(self) -> tuple[Var, ...]:
102103 @override
103104 def generate_bytecode (self , ctx : BytecodeContext ) -> tuple [bc .Value , ...]:
104105 types = tuple (x .get_type () for x in self .body_vars )
105- initial_values = [ctx .get_value_allow_undefined (input_var , ty )
106- for input_var , ty in zip ( self .initial_values , types , strict = True ) ]
106+ initial_values = [ctx .get_value (input_var )
107+ for input_var in self .initial_values ]
107108 result_type_ids = [typeid (ctx .type_table , ty ) for ty in types ]
108109
109110 if self .is_for_loop :
@@ -224,9 +225,7 @@ async def loop_impl(body: hir.Block, iterable: Var):
224225 ty = body_var .get_type_allow_invalid ()
225226 is_valid = not isinstance (ty , InvalidType )
226227 mask .append (is_valid )
227- if not is_valid :
228- body_var .set_undefined ()
229- elif not was_valid and ty .is_aggregate ():
228+ if not was_valid and is_valid and ty .is_aggregate ():
230229 # The initial variable is invalid but the loop variable is preserved,
231230 # and the loop variable is aggregate. In this case, `flat_body_vars[i]`
232231 # will contain a single variable (previously of InvalidType,
@@ -246,23 +245,28 @@ async def loop_impl(body: hir.Block, iterable: Var):
246245 new_body .params = tuple (body_params )
247246 valid_var_types = tuple (v .get_type () for v , is_valid in zip (body_vars , mask , strict = True )
248247 if is_valid )
248+ flat_var_types = flatten_aggregate_types (valid_var_types )
249249
250250 # Update Continue/Break statements
251251 for jump_info in loop_info .jumps :
252252 values = tuple (out
253253 for out , is_valid in zip (jump_info .outputs , mask , strict = True )
254254 if is_valid )
255255 flat_values = flatten_aggregates (values , valid_var_types )
256+ # For undefined break/continue value, add a MakeDummy op as its producer
257+ flat_values = _add_dummy_op_to_invalid_vars (flat_values , flat_var_types )
256258 assert len (flat_values ) == len (all_flattened_body_vars )
257259 jump_info .jump_op .values = flat_values
258260
259261 # Create the loop Operation
260- valid_initial_values = tuple (v for v , is_valid in zip (initial_values , mask , strict = True )
262+ valid_initial_values = tuple (v for v , is_valid
263+ in zip (initial_values , mask , strict = True )
261264 if is_valid )
262265 flat_initial_values = flatten_aggregates (valid_initial_values , valid_var_types )
266+ # For undefined initial value, add a MakeDummy op as its producer
267+ flat_initial_values = _add_dummy_op_to_invalid_vars (flat_initial_values , flat_var_types )
263268 assert len (flat_initial_values ) == len (all_flattened_body_vars )
264269
265- flat_var_types = flatten_aggregate_types (valid_var_types )
266270 if range_ty is None :
267271 start = stop = step = None
268272 else :
@@ -297,7 +301,8 @@ async def loop_impl(body: hir.Block, iterable: Var):
297301 for body_var , state , local_idx , is_valid in zip (body_vars , var_states , stored_locals , mask ,
298302 strict = True ):
299303 if not is_valid :
300- store_undefined (local_idx , body_var .get_type_allow_invalid (), state .result_phi .last_loc )
304+ store_invalid (local_idx , body_var .get_type_allow_invalid (),
305+ state .result_phi .last_loc )
301306
302307 # Do this check at the end because this may be an automatically inserted loop
303308 # around the helper function's body.
@@ -465,14 +470,16 @@ async def if_else_impl(cond: Var, then_block: hir.Block, else_block: hir.Block)
465470 assert num_explicit_results == 1
466471 ret = all_results [0 ]
467472 if ret is None :
468- ret = builder .ir_ctx .make_temp (builder .loc , undefined = True )
473+ assert isinstance (result_phis [0 ].ty , InvalidType )
474+ ret = builder .ir_ctx .make_temp (builder .loc )
475+ ret .set_type (result_phis [0 ].ty )
469476
470477 # Update the scope for stored named
471478 for res_var , phi , local_idx in zip (all_results [num_explicit_results :],
472479 result_phis [num_explicit_results :],
473480 stored_locals , strict = True ):
474481 if res_var is None :
475- store_undefined (local_idx , phi .ty , phi .last_loc )
482+ store_invalid (local_idx , phi .ty , phi .last_loc )
476483 else :
477484 store_var (local_idx , res_var , phi .last_loc )
478485
@@ -486,9 +493,7 @@ class Continue(Operation, opcode="continue", terminator=True):
486493
487494 @override
488495 def generate_bytecode (self , ctx : BytecodeContext ) -> tuple [()]:
489- next_values = [ctx .get_value_allow_undefined (var , ctx .typeof (body_var ))
490- for var , body_var
491- in zip (self .values , ctx .innermost_loop .body_vars , strict = True )]
496+ next_values = [ctx .get_value (var ) for var in self .values ]
492497 bc .encode_ContinueOp (ctx .builder , next_values )
493498 return ()
494499
@@ -517,11 +522,7 @@ class Break(Operation, opcode="break", terminator=True):
517522
518523 @override
519524 def generate_bytecode (self , ctx : BytecodeContext ) -> tuple [()]:
520- # body_vars is not a typo. We use body variables because they always contain the actual
521- # types of the loop variables, whereas result variables may have an InvalidType.
522- output_values = [ctx .get_value_allow_undefined (var , ctx .typeof (body_var ))
523- for var , body_var
524- in zip (self .values , ctx .innermost_loop .body_vars , strict = True )]
525+ output_values = [ctx .get_value (var ) for var in self .values ]
525526 bc .encode_BreakOp (ctx .builder , output_values )
526527 return ()
527528
@@ -616,6 +617,27 @@ def generate_bytecode(self, ctx: BytecodeContext) -> bc.Value:
616617 return ctx .constant (self .value , ctx .typeof (self .result_var ))
617618
618619
620+ @dataclass (eq = False )
621+ class MakeDummy (Operation , opcode = "make_dummy" ):
622+ """Placeholder value inserted for undefined variable in loop.
623+
624+ The use case for undefined variables is to represent loop's
625+ initial_values or continue/break's next_values during type inference or
626+ post dead code elimination.
627+ """
628+
629+ @override
630+ def generate_bytecode (self , ctx : BytecodeContext ) -> bc .Value :
631+ ty = ctx .typeof (self .result_var )
632+ if isinstance (ty , TokenTy ):
633+ return bc .encode_MakeTokenOp (ctx .builder , ctx .type_table .Token )
634+ if isinstance (ty , TileTy ) and isinstance (ty .dtype , PointerTy ):
635+ int_ty = TileTy (dtype = datatype .int64 , shape = ty .shape )
636+ const = ctx .constant (0 , int_ty )
637+ return bc .encode_IntToPtrOp (ctx .builder , typeid (ctx .type_table , ty ), const )
638+ return ctx .constant (0 , ty )
639+
640+
619641def loosely_typed_const (value : Any ,
620642 ty : Optional [Type ] = None ,
621643 loose_ty : Optional [Type ] = None ) -> Var :
@@ -1583,8 +1605,6 @@ def _to_string_rhs(self) -> str:
15831605def assign (value : Var , res : Var ) -> None :
15841606 Builder .get_current ().append_verbatim (Assign (value = value , result_vars = (res ,), loc = res .loc ))
15851607 res .ctx .copy_type_information (value , res )
1586- if value .is_undefined ():
1587- res .set_undefined ()
15881608
15891609
15901610@impl (hir .identity )
@@ -1776,8 +1796,12 @@ def flatten_aggregates(vars: Sequence[Var], types: Sequence[Type]) -> tuple[Var,
17761796 ret = []
17771797 for x , ty in zip (vars , types , strict = True ):
17781798 item_types = tuple (ty .flatten_aggregate ())
1779- if isinstance (x .get_type_allow_invalid (), InvalidType ):
1780- ret .extend (x .ctx .make_temp (x .loc , undefined = True ) for _ in item_types )
1799+ x_ty = x .get_type_allow_invalid ()
1800+ if isinstance (x_ty , InvalidType ):
1801+ for _ in item_types :
1802+ t = x .ctx .make_temp (x .loc )
1803+ t .set_type (x_ty )
1804+ ret .append (t )
17811805 else :
17821806 items = tuple (x .flatten_aggregate ())
17831807 assert len (items ) == len (item_types )
@@ -1837,10 +1861,10 @@ def _unflatten_proper_aggregate(flattened_iter: Iterator[Var], nominal: Type, ac
18371861 # Pop values from the iterator and throw them out
18381862 for _ in nominal_item_types :
18391863 next (flattened_iter )
1840- # Return an undefined variable. It is OK that we don't create an aggregate value for it --
1841- # any use of it should be invalid anyway.
18421864 builder = Builder .get_current ()
1843- return builder .ir_crx .make_temp (builder .loc , undefined = True )
1865+ t = builder .ir_ctx .make_temp (builder .loc )
1866+ t .set_type (actual )
1867+ return t
18441868
18451869 items = tuple (_maybe_unflatten_aggregate (flattened_iter , item_nominal , item_actual )
18461870 for item_nominal , item_actual
@@ -3991,10 +4015,10 @@ def store_var(local_idx: int, value: Var, loc: Loc | None = None):
39914015 assign (value , new_var )
39924016
39934017
3994- def store_undefined (local_idx : int , ty : Type , loc : Loc | None = None ):
4018+ def store_invalid (local_idx : int , ty : Type , loc : Loc | None = None ):
4019+ assert isinstance (ty , InvalidType )
39954020 scope = Scope .get_current ()
39964021 new_var = scope .local .redefine (local_idx , loc or Builder .get_current ().loc )
3997- new_var .set_undefined ()
39984022 new_var .set_type (ty )
39994023
40004024
@@ -4014,8 +4038,6 @@ def load_var_impl(name):
40144038 if rn .depth >= 0 :
40154039 ret = scope .local_scopes [rn .depth ][rn .index ]
40164040 ret .get_type () # Trigger an InvalidType check
4017- if ret .is_undefined ():
4018- raise TileSyntaxError (f"Undefined variable { name } used" )
40194041 return ret
40204042 elif rn .index >= 0 :
40214043 val = scope .func_hir .frozen_global_values [rn .index ]
@@ -4206,3 +4228,11 @@ def sym2var(x: Any) -> Var:
42064228
42074229 x = get_constant_value (x )
42084230 return loosely_typed_const (x )
4231+
4232+
4233+ def _add_dummy_op_to_invalid_vars (vars : Sequence [Var ],
4234+ actual_types : Sequence [Type ]) -> tuple [Var , ...]:
4235+ return tuple (add_operation (MakeDummy , actual )
4236+ if isinstance (v .get_type_allow_invalid (), InvalidType )
4237+ else v
4238+ for v , actual in zip (vars , actual_types , strict = True ))
0 commit comments