diff --git a/language/deeplink/__init__.py b/language/deeplink/__init__.py index 044f8557..b6b96b9c 100644 --- a/language/deeplink/__init__.py +++ b/language/deeplink/__init__.py @@ -7,6 +7,7 @@ sync_block_all, set_cross_flag, wait_cross_flag, + barrier_cross_sync, parallel, inline_lambda, alloc, diff --git a/language/deeplink/core.py b/language/deeplink/core.py index 6cb1557a..9c6dec7b 100644 --- a/language/deeplink/core.py +++ b/language/deeplink/core.py @@ -330,6 +330,11 @@ def wait_cross_flag(sync_flag_type: SyncFlagType, event_id: int, _semantic=None) ) +@builtin +def barrier_cross_sync(_semantic=None): + pass + + class parallel(range): """ Iterator that counts upward forever, with parallel execution semantics. diff --git a/patch/triton/python_triton_compiler_code_generator_py.patch b/patch/triton/python_triton_compiler_code_generator_py.patch index c2f30a91..055e9d78 100644 --- a/patch/triton/python_triton_compiler_code_generator_py.patch +++ b/patch/triton/python_triton_compiler_code_generator_py.patch @@ -1,5 +1,5 @@ diff --git a/python/triton/compiler/code_generator.py b/python/triton/compiler/code_generator.py -index df09b3198..3dac7c740 100644 +index df09b3198..9fe71e077 100644 --- a/python/triton/compiler/code_generator.py +++ b/python/triton/compiler/code_generator.py @@ -20,7 +20,6 @@ from .._utils import find_paths_if, get_iterable_path, set_iterable_path @@ -61,7 +61,184 @@ index df09b3198..3dac7c740 100644 def visit_While(self, node): with enter_sub_region(self) as sr: liveins, insert_block = sr -@@ -1137,7 +1173,8 @@ class CodeGenerator(ast.NodeVisitor): +@@ -1117,6 +1153,176 @@ class CodeGenerator(ast.NodeVisitor): + + def visit_ExtSlice(self, node): + return [self.visit(dim) for dim in node.dims] ++ ++ def _update_barrier_cross_sync_in_for(self, node, id_list=[i for i in range(16)]): ++ # event_id: should be 0 ~ 15 ++ str_node = ast.dump(node) ++ num_barrier = len(re.findall(r"barrier_cross_sync", str_node)) ++ if num_barrier < 1: ++ return ++ ++ num_set = len(re.findall(r"set_cross_flag", str_node)) ++ num_wait = len(re.findall(r"wait_cross_flag", str_node)) ++ assert num_set+num_wait == 0, "set_cross_flag and wait_cross_flag are used in for loop with barrier_cross_sync!" ++ ++ def __get_async_task(node): ++ if not isinstance(node, ast.With): ++ return "" ++ for item in node.items: ++ context_expr = item.context_expr ++ if isinstance(context_expr, ast.Call): ++ func = context_expr.func ++ if isinstance(func, ast.Attribute): ++ if func.attr == 'async_task': ++ for keyword in context_expr.keywords: ++ if isinstance(keyword.value, ast.Attribute): ++ return keyword.value.attr ++ return "" ++ ++ def __get_load_and_store(node): ++ load_list = [[]] ++ store_list = [[]] ++ barrier_list = [] ++ for idx,line in enumerate(node.body): ++ if isinstance(line, ast.Expr) or isinstance(line, ast.Assign): ++ if isinstance(line.value, ast.Call): ++ if isinstance(line.value.func, ast.Attribute): ++ if line.value.func.attr == "barrier_cross_sync": ++ barrier_list.append([idx]) ++ load_list.append([]) ++ store_list.append([]) ++ if line.value.func.attr in ("store", "load"): ++ info = line.value.args[0].id ++ loadandstore_list = load_list if line.value.func.attr == "load" else store_list ++ loadandstore_list[-1].append(info) ++ assert len(load_list) == len(store_list), "len(load_list) != len(store_list) in _update_barrier_cross_sync_in_for!" ++ ++ return barrier_list, load_list, store_list ++ ++ def __create_sync(sync_type, flag_type, event_id): ++ new_expr = ast.Expr( ++ value=ast.Call( ++ func=ast.Attribute( ++ value=ast.Name(id='dl', ctx=ast.Load()), ++ attr=sync_type, ++ ctx=ast.Load() ++ ), ++ args=[ ++ ast.Attribute( ++ value=ast.Attribute( ++ value=ast.Name(id='dl', ctx=ast.Load()), ++ attr='SyncFlag', ++ ctx=ast.Load() ++ ), ++ attr=flag_type, ++ ctx=ast.Load() ++ ), ++ ast.Constant(value=event_id) ++ ], ++ keywords=[] ++ ) ++ ) ++ return new_expr ++ ++ # get async_task ++ async_tasks = {} ++ for i in node.body: ++ if __get_async_task(i) == "vector": ++ async_tasks["vector"] = i ++ if __get_async_task(i) == "cube": ++ async_tasks["cube"] = i ++ if len(async_tasks.items()) < 2: ++ return ++ ++ cube_barriers, cube_loads, cube_stores = __get_load_and_store(async_tasks["cube"]) ++ vector_barriers, vector_loads, vector_stores = __get_load_and_store(async_tasks["vector"]) ++ ++ def __get_set(rawlist): ++ return set([item for sublist in rawlist for item in (sublist if isinstance(sublist, list) else [sublist])]) ++ ++ def __clear_not_in_set(rawlist, rawset): ++ for idx, i in enumerate(rawlist): ++ if isinstance(i, list): ++ rawlist[idx] = [ii for ii in i if ii in rawset] ++ ++ cube_loads_set = __get_set(cube_loads) ++ cube_stores_set = __get_set(cube_stores) ++ vector_loads_set = __get_set(vector_loads) ++ vector_stores_set = __get_set(vector_stores) ++ ++ __clear_not_in_set(cube_loads, vector_stores_set) ++ __clear_not_in_set(cube_stores, vector_loads_set) ++ __clear_not_in_set(vector_loads, cube_stores_set) ++ __clear_not_in_set(vector_stores, cube_loads_set) ++ ++ id_cube = 0 ++ id_vector = 0 ++ id_event = 0 ++ while (id_cube < len(cube_loads) or id_vector < len(vector_loads)): ++ if id_cube >= len(cube_loads): ++ id_vector += 1 ++ continue ++ if id_vector >= len(vector_loads): ++ id_cube += 1 ++ continue ++ if len(cube_stores[id_cube]) + len(cube_loads[id_cube]) == 0: ++ id_cube += 1 ++ continue ++ if len(vector_stores[id_vector]) + len(vector_loads[id_vector]) == 0: ++ id_vector += 1 ++ continue ++ ++ cube_to_vector = False ++ vector_to_cube = False ++ event_id = id_list[id_event] ++ ++ cube_store=cube_stores[id_cube] ++ vector_load=vector_loads[id_vector] ++ for ptr in cube_store: ++ if ptr in vector_load: ++ cube_to_vector = True ++ break ++ ++ vector_store=vector_stores[id_vector] ++ cube_load=cube_loads[id_cube] ++ for ptr in vector_store: ++ if ptr in cube_load: ++ vector_to_cube = True ++ break ++ ++ assert cube_to_vector + vector_to_cube == 1, "cube_to_vector + vector_to_cube != 1" ++ ++ if cube_to_vector: ++ flag_type = "C2V" ++ cube_barriers[id_cube].append(("set_cross_flag", flag_type, id_event)) ++ vector_barriers[id_vector-1].append(("wait_cross_flag", flag_type, id_event)) ++ id_cube += 1 ++ if vector_to_cube: ++ flag_type = "V2C" ++ vector_barriers[id_vector].append(("set_cross_flag", flag_type, id_event)) ++ cube_barriers[id_cube-1].append(("wait_cross_flag", flag_type, id_event)) ++ id_vector += 1 ++ id_event += 1 ++ ++ offset = 0 ++ for barrier in cube_barriers: ++ idx = barrier[0] + offset ++ async_tasks["cube"].body.pop(idx) ++ for index, sync in enumerate(barrier[1:]): ++ sync_flag = __create_sync(*sync) ++ async_tasks["cube"].body.insert(idx+index, sync_flag) ++ if index > 0: ++ offset += 1 ++ ++ offset = 0 ++ for barrier in vector_barriers: ++ idx = barrier[0] + offset ++ async_tasks["vector"].body.pop(idx) ++ for index, sync in enumerate(barrier[1:]): ++ sync_flag = __create_sync(*sync) ++ async_tasks["vector"].body.insert(idx+index, sync_flag) ++ if index > 0: ++ offset += 1 + + def visit_For(self, node): + IteratorClass = self.visit(node.iter.func) +@@ -1137,7 +1343,8 @@ class CodeGenerator(ast.NodeVisitor): flatten = False warp_specialize = False disable_licm = False @@ -71,7 +248,7 @@ index df09b3198..3dac7c740 100644 iterator = IteratorClass(*iter_args, **iter_kwargs) # visit iterator arguments # note: only `range` iterator is supported now -@@ -1151,6 +1188,8 @@ class CodeGenerator(ast.NodeVisitor): +@@ -1151,6 +1358,8 @@ class CodeGenerator(ast.NodeVisitor): flatten = iterator.flatten warp_specialize = iterator.warp_specialize disable_licm = iterator.disable_licm @@ -80,7 +257,7 @@ index df09b3198..3dac7c740 100644 elif IteratorClass is range: # visit iterator arguments # note: only `range` iterator is supported now -@@ -1210,6 +1249,9 @@ class CodeGenerator(ast.NodeVisitor): +@@ -1210,6 +1419,9 @@ class CodeGenerator(ast.NodeVisitor): if disable_licm: for_op.set_attr("llvm.loop_annotation", self.builder.get_disable_loop_licm_attr()) @@ -90,7 +267,16 @@ index df09b3198..3dac7c740 100644 self.scf_stack.append(node) for_op_body = for_op.get_body(0) self.builder.set_insertion_point_to_start(for_op_body) -@@ -1363,6 +1405,13 @@ class CodeGenerator(ast.NodeVisitor): +@@ -1218,6 +1430,8 @@ class CodeGenerator(ast.NodeVisitor): + for name, val in zip(names, block_args): + self._maybe_set_loc_to_name(val, name) + self.set_value(name, val) ++ # update barrier cross sync ++ self._update_barrier_cross_sync_in_for(node) + self.visit_compound_statement(node.body) + self.scf_stack.pop() + yield_handles = flatten_values_to_ir(self.lscope[name] for name in names) +@@ -1363,6 +1577,13 @@ class CodeGenerator(ast.NodeVisitor): def visit_Call(self, node): fn = _unwrap_if_constexpr(self.visit(node.func))