diff --git a/config/igemm_fwd_gtc_gfx1030_nchwc_fp16x8_fsr.config b/config/igemm_fwd_gtc_gfx1030_nchwc_fp16x8_fsr.config new file mode 100755 index 00000000..a273f0d7 --- /dev/null +++ b/config/igemm_fwd_gtc_gfx1030_nchwc_fp16x8_fsr.config @@ -0,0 +1,83 @@ +[codegen] +arch = 'gfx1030' +code_object = 'cov3' +mode = 'flat' + +######################################################################################### +#--------------------------- 16x512x256 +[igemm_fwd_gtc] +gemm_m_per_block = 16 +gemm_n_per_block = 512 +gemm_k_per_block = 256 +lanegroup_tile_m = 8 +lanegroup_wave_m = 1 +lanegroup_repeat_m = 2 +lanegroup_tile_n = 8 +lanegroup_wave_n = 8 +lanegroup_repeat_n = 1 +tensor_a_thread_lengths = [1, 1, 1, 8] # 1xCEx1xK/Vec-c +tensor_a_cluster_lengths = [1,32, 1, 16] # 1xCEx1xK +tensor_b_thread_lengths = [1, 1, 1, 8] # 1xCExNB0xVec-c +tensor_b_cluster_lengths = [1, 1, 1,512] # 1xCEx1xNB1 +direction = "fwd" +precision = "fp16" +tensor_layout = 'nchwc_cyxkc' +nxb = 0 +nxe = 1 +wavefront_size = 64 +cumode = 0 +vector_c = 8 +mini_weights = 1 +tensor_b_pass_through = 1 + +#--------------------------- 16x256x128 +[igemm_fwd_gtc] +gemm_m_per_block = 16 +gemm_n_per_block = 256 +gemm_k_per_block = 128 +lanegroup_tile_m = 8 +lanegroup_wave_m = 1 +lanegroup_repeat_m = 2 +lanegroup_tile_n = 8 +lanegroup_wave_n = 8 +lanegroup_repeat_n = 1 +tensor_a_thread_lengths = [1, 1, 1, 8] # 1xCEx1xK/Vec-c +tensor_a_cluster_lengths = [1,16, 1, 16] # 1xCEx1xK +tensor_b_thread_lengths = [1, 1, 1, 8] # 1xCExNB0xVec-c +tensor_b_cluster_lengths = [1, 1, 1,256] # 1xCEx1xNB1 +direction = "fwd" +precision = "fp16" +tensor_layout = 'nchwc_cyxkc' +nxb = 0 +nxe = 1 +wavefront_size = 32 +cumode = 0 +vector_c = 8 +mini_weights = 1 +tensor_b_pass_through = 1 + +#--------------------------- 16x64x32 +[igemm_fwd_gtc] +gemm_m_per_block = 16 +gemm_n_per_block = 64 +gemm_k_per_block = 32 +lanegroup_tile_m = 8 +lanegroup_wave_m = 1 +lanegroup_repeat_m = 2 +lanegroup_tile_n = 8 +lanegroup_wave_n = 8 +lanegroup_repeat_n = 1 +tensor_a_thread_lengths = [1, 1, 1, 8] # 1xCEx1xK/Vec-c +tensor_a_cluster_lengths = [1, 4, 1, 16] # 1xCEx1xK +tensor_b_thread_lengths = [1, 1, 1, 8] # 1xCExNB0xVec-c +tensor_b_cluster_lengths = [1, 1, 1, 64] # 1xCEx1xNB1 +direction = "fwd" +precision = "fp16" +tensor_layout = 'nchwc_cyxkc' +nxb = 0 +nxe = 1 +wavefront_size = 64 +cumode = 0 +vector_c = 8 +mini_weights = 1 +tensor_b_pass_through = 1 diff --git a/python/igemm/igemm_base.py b/python/igemm/igemm_base.py index f562598a..94fe3e26 100755 --- a/python/igemm/igemm_base.py +++ b/python/igemm/igemm_base.py @@ -235,6 +235,10 @@ def __init__(self, tunable_dict): self.vector_c = utility_dict_with_default_t(tunable_dict)('vector_c', 1) self.wavefront_size = utility_dict_with_default_t(tunable_dict)('wavefront_size', 64) self.cumode = utility_dict_with_default_t(tunable_dict)('cumode', 0) + + self.mini_weights = utility_dict_with_default_t(tunable_dict)('mini_weights', 0) + if self.mini_weights == 1: + self.tensor_b_pass_through = 1 assert type(self.tensor_a_thread_lengths) is list and type(self.tensor_a_cluster_lengths) is list assert type(self.tensor_b_thread_lengths) is list and type(self.tensor_b_cluster_lengths) is list @@ -383,7 +387,8 @@ def _unmerge_x1_from_e(unroll_k, nxe): gemm_msg = f"gemm_m_per_block:{self.gemm_m_per_block} - {self.wave_tile_m}x{self.wave_step_m}x{self.wave_repeat_m}, gemm_n_per_block:{self.gemm_n_per_block} - {self.wave_tile_n}x{self.wave_step_n}x{self.wave_repeat_n}, gemm_k_per_block:{self.gemm_k_per_block}" assert 
self.num_global_load_a * self.block_size == self.gemm_m_per_block * self.gemm_k_per_block, gemm_msg - assert self.num_global_load_b * self.block_size == self.gemm_n_per_block * self.gemm_k_per_block, gemm_msg + if self.mini_weights != 1: + assert self.num_global_load_b * self.block_size == self.gemm_n_per_block * self.gemm_k_per_block, gemm_msg # LDS size self.lds_pad_m, self.lds_pad_n = self.get_lds_pad() # LDS pad @@ -409,6 +414,11 @@ def _unmerge_x1_from_e(unroll_k, nxe): self.lds_total = self.lds_buffer_num * self.lds_single # print(f"lds_a:{self.lds_a}, lds_b:{self.lds_b}, lds_a_np2:{self.lds_a_np2}, lds_b_np2:{self.lds_b_np2}, lds_single:{self.lds_single}, lds_total:{self.lds_total}") # TODO: LDS size check + + if self.mini_weights == 1: + self.lds_single = 8 * 1024 + self.lds_total = 8 * 1024 + self.lds_buffer_num = 1 # some parameter not in modular_conv if self.fma_type in (IGEMM_GTC_TUNABLE_FMA_TYPE_MAC, IGEMM_GTC_TUNABLE_FMA_TYPE_DLOPS): diff --git a/python/igemm/igemm_fwd_gtc_nchwc.py b/python/igemm/igemm_fwd_gtc_nchwc.py index b3c0ec18..936b3e20 100644 --- a/python/igemm/igemm_fwd_gtc_nchwc.py +++ b/python/igemm/igemm_fwd_gtc_nchwc.py @@ -128,6 +128,16 @@ def try_shift_stride(self, gpr, shifter): else: self._emit(f"s_lshr_b32 s[{gpr()}], s[{gpr()}], {-shifter}") return self._get_deferred() + + def num_weight_global_load(self): + na_k_vec_c, na_ce, _, _, _ = self.get_dims_lengths() + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + bytes_per_gld_wei = na_k_vec_c * na_ce * data_byte + num_insts = self.tunable.lds_total // bytes_per_gld_wei + return num_insts + + def num_weight_ds_write(self): + return self.num_weight_global_load() class macro_set_flag_nhw(macro_base_t): def __init__(self, mc, inline, **options): @@ -372,6 +382,148 @@ def expr(self): self._emit(f"v_bfe_u32 v[{self.v_tmp6(0)}], v[{self.v_in_flag_n()}], {i+1}, 1 ; extract flag_n ({i+1})") self._emit(m_set_flag_nhw(self.v_in_flag(i), self.v_in_flag(i), self.v_in_ihi_list(i), self.v_in_iwi_list(i), self.s_sps_hi(), self.s_sps_wi())) + class macro_move_slice_window_block_wise_acc_yx_t(macro_base_t): + ''' + cannot inline. + Prefer to put this before the global load wait; for simplicity, there is no auto scheduling.
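+ Illustrative flow (a sketch inferred from expr() below; not emitted code): + s_ix/s_iy hold negative running counters initialized to -x/-y; each call steps s_ix += 1, + and a carry-out (s_ix wrapping to 0) means x overflowed: s_ix resets to -x and s_iy += 1; + a further carry-out resets s_iy to -y while the input offsets step into the next c slice; + the per-pixel bounds flags are then re-derived via m_set_flag_nhw.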
+ ''' + def __init__(self, mc, tunable, inline, **options): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + self.declare_arg("s_ix") + self.declare_arg("s_iy") + self.declare_arg("s_x") + self.declare_arg("s_y") + self.declare_arg("v_in_os") + if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING: + self.declare_arg("v_in_i_hw_list") + else: + assert False, "not tested" + self.declare_arg("v_in_ihi_list") + self.declare_arg("v_in_iwi_list") + self.declare_arg("v_in_flag") + self.declare_arg("v_in_flag_n") + self.declare_arg("s_diff_in_iwi_acc_x") + self.declare_arg("s_diff_in_iwi_ovf_x") + self.declare_arg("s_dilation_w") + self.declare_arg("s_diff_in_ihi_acc_y") + self.declare_arg("s_diff_in_ihi_ovf_y") + self.declare_arg("s_dilation_h") + self.declare_arg("s_diff_in_os_acc_c_y_x") + self.declare_arg("s_diff_in_os_ovf_x_acc_y") + self.declare_arg("s_diff_in_os_ovf_y_acc_c") + self.declare_arg("s_sps_hi") + self.declare_arg("s_sps_wi") + + self.declare_arg("v_tmp") # 2 needed + self.declare_arg("s_tmp") + self.options = options + def name(self): + return '.v_fwd_gtc_nhwc_move_slice_window_block_wise_acc_yx' + + def expr(self): + assert "label_acc_yx" in self.options + label_acc_yx_x_end = self.options["label_acc_yx"] + '_x_end' + '_{}'.format(self.expr_cnt) + + assert "nb_per_thread" in self.options + nb_per_thread = self.options["nb_per_thread"] + + assert 'm_set_flag_nhw' in self.options + m_set_flag_nhw = self.options['m_set_flag_nhw'] + + ''' + ix accumulate, will only accumulate in width, and will never carry on to height + iy accumulate, will only accumulate in height, and will never carry on to batch + this makes life easier + ''' + # s_ix = -s_x, s_iy = -s_y, s_ix++ + self._emit(f"s_mov_b32 s[{self.s_diff_in_ihi_acc_y()}], 0") + self._emit(f"s_add_u32 s[{self.s_ix()}], 1, s[{self.s_ix()}]") + + # ihi = iho * s_stride_h + iy * s_dilation_h - s_pad_h + # iwi = iwo * s_stride_w + ix * s_dilation_w - s_pad_w + + # update iwi + self._emit(f"s_cselect_b32 s[{self.s_diff_in_iwi_acc_x()}], s[{self.s_diff_in_iwi_ovf_x()}], s[{self.s_dilation_w()}]") + #for i in range(nb_per_thread): + # self._emit(f"v_add_nc_u32 v[{self.v_in_iwi_list(i)}], s[{self.s_tmp()}], v[{self.v_in_iwi_list(i)}]") + + # update in_os + for i in range(nb_per_thread): + self._emit(f"v_add_nc_u32 v[{self.v_in_os(i)}], s[{self.s_diff_in_os_acc_c_y_x()}], v[{self.v_in_os(i)}]") + + # update ihi, accumulate + self._emit(f"s_cbranch_scc0 {label_acc_yx_x_end}") + self._emit(f"s_mul_i32 s[{self.s_ix()}], -1, s[{self.s_x()}]") + self._emit(f"s_add_u32 s[{self.s_iy()}], 1, s[{self.s_iy()}]") + self._emit(f"s_cselect_b32 s[{self.s_diff_in_ihi_acc_y()}], s[{self.s_diff_in_ihi_ovf_y()}], s[{self.s_dilation_h()}]") + + # update in_os + for i in range(nb_per_thread): + self._emit(f"v_add_nc_u32 v[{self.v_in_os(i)}], s[{self.s_diff_in_os_ovf_x_acc_y()}], v[{self.v_in_os(i)}]") + + self._emit(f"s_cbranch_scc0 {label_acc_yx_x_end}") + self._emit(f"s_mul_i32 s[{self.s_iy()}], -1, s[{self.s_y()}]") + for i in range(nb_per_thread): + self._emit(f"v_add_nc_u32 v[{self.v_in_os(i)}], s[{self.s_diff_in_os_ovf_y_acc_c()}], v[{self.v_in_os(i)}]") + + self._emit_front(f"{label_acc_yx_x_end}:") + + # now set flags + self._emit(f"s_pack_ll_b32_b16 s[{self.s_tmp()}], s[{self.s_diff_in_iwi_acc_x()}], s[{self.s_diff_in_ihi_acc_y()}]") + self._emit(f"v_mov_b32 v[{self.v_tmp()}], s[{self.s_tmp()}]") + + if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING: + self._emit(f"v_bfe_u32 v[{self.v_in_flag(0)}], v[{self.v_in_flag_n()}], {0}, 1 ; extract 
flag_n (0)") + if nb_per_thread >= 2: + self._emit(f"v_bfe_u32 v[{self.v_in_flag(1)}], v[{self.v_in_flag_n()}], {0}, 1 ; extract flag_n (0)") + for i in range(nb_per_thread): + self._emit(f"v_pk_add_u16 v[{self.v_in_i_hw_list(i)}], v[{self.v_tmp()}], v[{self.v_in_i_hw_list(i)}]") + else: + assert False, "not support" + self._emit_empty_line() + + if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING: + self._emit(m_set_flag_nhw(self.v_in_flag(), self.v_in_flag_n(), self.v_in_i_hw_list(), self.s_sps_hi(), self.s_sps_wi(), self.v_tmp(), self.v_tmp(1))) + else: + assert False, "un debuged code" + + + self._emit_empty_line() + + class macro_move_slice_window_block_wise_a_t(macro_base_t): + def __init__(self, mc, tunable): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + self.declare_arg("v_sld_a_os") + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + self.step = tunable.vector_c * tunable.gemm_m_per_block * data_byte + + def name(self): + return '.macro_move_slice_window_block_wise_a_t{self.tunable.tensor_a_pass_through}_{self.tunable.tensor_b_pass_through}' + + def expr(self): + self._emit(f"v_add_nc_u32 v[{self.v_sld_a_os()}], v[{self.v_sld_a_os()}], {self.step}") + self._emit_empty_line() + + class macro_move_gld_b_to_v_b_t(macro_base_t): + def __init__(self, mc, tunable): + macro_base_t.__init__(self, mc, True) + self.tunable = tunable + self.declare_arg("v_b") + self.declare_arg("v_gld_b") + data_byte = amdgpu_precision_data_byte(self.tunable.precision) + self.length = tunable.tensor_b_thread_lengths[2] * tunable.tensor_b_thread_lengths[3] // (4 // data_byte) + + def name(self): + return '.macro_move_gld_b_to_v_b_t{self.tunable.tensor_a_pass_through}_{self.tunable.tensor_b_pass_through}' + + def expr(self): + self._emit(f"s_waitcnt vmcnt(0)") + for i in range(self.length): + self._emit(f"v_mov_b32 v[{self.v_b(i)}], v[{self.v_gld_b(i)}]") + self._emit_empty_line() + class global_load_in_t(mc_base_t): def __init__(self, mc, outer): mc_base_t.__init__(self, mc) @@ -413,7 +565,10 @@ def __call__(self): s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 = self.outer.get_symbol_global_load_s_stride_d0_d1() with self._deferred_context(): self._emit(f"; load weight") - self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset() if self.outer.tunable.tensor_layout == 'nchwc_kcyxc' else None )) + if self.outer.tunable.mini_weights == 1: + self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset())) + else: + self._emit(m_wei_2d_global_load(v.v_gld_a(), s.s_p_wei(), v.v_wei_os(), s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset() if self.outer.tunable.tensor_layout == 'nchwc_kcyxc' else None )) return self._get_deferred() class shared_store_in_t(mc_base_t): @@ -632,6 +787,11 @@ def __init__(self, mc, outer): m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() wei_npc = m_wei_2d_global_load.get_num_precache_soffset() self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) + + if outer.tunable.mini_weights == 1: + m_wei_2d_global_load, m_in_2d_global_load = outer.get_macro_global_load() + wei_npc = m_wei_2d_global_load.get_num_precache_soffset() + self.s_wei_offset = sym_t("s_wei_offset" ,sseq(wei_npc)) if outer.tunable.gemm_k_global_split: self.s_block_gtc_ic = sym_t("s_block_gtc_ic" ,sseq(1)) # add c split @@ -643,6 +803,10 @@ def __init__(self, mc, outer): self.s_x_dilation_w = sym_t("s_x_dilation_w" 
,self.s_tile_os_hi.value) self.s_y_dilation_h = sym_t("s_y_dilation_h" ,self.s_tile_os_wi.value) + + if outer.tunable.mini_weights == 1: + self.s_ix = sym_t("s_ix" ,sseq(1)) + self.s_iy = sym_t("s_iy" ,sseq(1)) self.s_end = sym_t("s_end" ,sseq()) @@ -678,7 +842,7 @@ def __init__(self, mc, outer): share_load_packed_vgpr = share_load_packed // int(4 // data_byte) num_vgpr_acc_a = share_load_packed_vgpr * outer.tunable.num_vgpr_accumulate_a if not outer.tunable.tensor_a_pass_through else 0 - num_vgpr_acc_b = share_load_packed_vgpr * outer.tunable.num_vgpr_accumulate_b if not outer.tunable.tensor_b_pass_through else 0 + num_vgpr_acc_b = share_load_packed_vgpr * outer.tunable.num_vgpr_accumulate_b # if not outer.tunable.tensor_b_pass_through else 0 # print(f"share_load_packed_vgpr:{share_load_packed_vgpr}, tunable.num_vgpr_accumulate_b:{outer.tunable.num_vgpr_accumulate_b}, num_vgpr_acc_b:{num_vgpr_acc_b}") if is_vgpr_acc_c: @@ -697,9 +861,13 @@ def __init__(self, mc, outer): if not outer.tunable.tensor_a_pass_through: self.v_a = sym_t("v_a" ,vseq(num_vgpr_acc_a+1)) - if not outer.tunable.tensor_b_pass_through: + #if not outer.tunable.tensor_b_pass_through: + if outer.tunable.mini_weights == 0: self.v_b = sym_t("v_b" ,vseq(num_vgpr_acc_b)) - self.v_gld_a = sym_t("v_gld_a" ,vseq(num_vgpr_global_load_a)) + self.v_gld_a = sym_t("v_gld_a" ,vseq(num_vgpr_global_load_a)) + else: + self.v_b = sym_t("v_b" ,vseq(num_vgpr_acc_b + num_vgpr_global_load_a)) + self.v_gld_a = sym_t("v_gld_a" ,2) if outer.tunable.global_prefetch_a_num == 2: self.v_gld_a_gpf = sym_t("v_gld_a_gpf" ,vseq(num_vgpr_global_load_a)) self.v_gld_b = sym_t("v_gld_b" ,vseq(num_vgpr_global_load_b)) @@ -732,7 +900,7 @@ def __init__(self, mc, outer): if not (outer.tunable.tensor_a_pass_through and outer.tunable.tensor_b_pass_through): self.v_gtc_ic = sym_t("v_gtc_ic" ,vseq(1)) - assert not outer.tunable.tensor_b_pass_through + assert not outer.tunable.tensor_a_pass_through self.v_gtc_iec = sym_t("v_gtc_iec" ,vseq(1)) self.v_gtc_iy = sym_t("v_gtc_iy" ,vseq(1)) self.v_gtc_ix = sym_t("v_gtc_ix" ,vseq(1)) @@ -926,6 +1094,10 @@ def get_macro_global_load(self): ctrl_wei_gld.dim_conti_flag = 0 if self.tunable.tensor_layout == 'nchwc_kcyxc' else 1 ctrl_wei_gld.workgroup_length = ca_k + + if self.tunable.mini_weights == 1: + ctrl_wei_gld.dim_conti_flag = 0 + ctrl_wei_gld.length_d1 *= self.num_weight_global_load() if self.tunable.precache_soffset: return macro_igemm_2d_global_load_precache_soffset_t(self.mc, ctrl_wei_gld, inline), \ @@ -969,8 +1141,8 @@ def get_macro_shared_store(self): in_sst_ctrl.stride_d1 = nb_nb0 * nb_nb1_vec_c // self.tunable.vector_c * k_pack * data_byte inline = True if self.tunable.fma_interleave else False - return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline) if not self.tunable.tensor_a_pass_through else None, \ - macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) if not self.tunable.tensor_b_pass_through else None + return macro_igemm_3d_shared_store_t(self.mc, in_sst_ctrl, inline) if not self.tunable.tensor_b_pass_through else None, \ + macro_igemm_3d_shared_store_t(self.mc, wei_sst_ctrl, inline) if not self.tunable.tensor_a_pass_through else None def get_macro_move_slice_window(self): inline = True if self.tunable.fma_interleave else False @@ -989,6 +1161,35 @@ def get_macro_move_slice_window(self): # return single functor !
return move_slice_window + + def get_macro_move_slice_window_acc(self): + inline = True if self.tunable.fma_interleave else False + ta_k_vec_c, tb_nb0, tb_nb_vec_c = self.get_thread_lengths() + nb_per_thread = tb_nb0 + nk_per_thread = ta_k_vec_c + unroll_k = self.tunable.gemm_k_per_block + m_set_flag_nhw = self.macro_set_flag_nhw_16_sched(self.mc, inline, nb_per_thread=nb_per_thread) if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING else \ + self.macro_set_flag_nhw(self.mc, inline) + if self.tunable.nxe != 0: + move_slice_window = self.macro_move_slice_window_block_wise_acc_yx_t(self.mc, self.tunable, inline, label_acc_yx = self.name() + "_acc_yx", + unroll_k=unroll_k, nb_per_thread=nb_per_thread, nk_per_thread=nk_per_thread, m_set_flag_nhw=m_set_flag_nhw) + else: + assert False, "not implemented ex0" + #move_slice_window = self.macro_move_slice_window_block_wise_1x1_t(self.mc, self.tunable, inline, + # unroll_k=unroll_k, nb_per_thread=nb_per_thread, nk_per_thread=nk_per_thread) + + # return single functor ! + return move_slice_window + + def get_macro_move_slice_window_a(self): + move_slice_window_a = self.macro_move_slice_window_block_wise_a_t(self.mc, self.tunable) + # return single functor ! + return move_slice_window_a + + def get_macro_move_gld_b_to_v_b(self): + move_gld_b_to_v_b = self.macro_move_gld_b_to_v_b_t(self.mc, self.tunable) + # return single functor ! + return move_gld_b_to_v_b def get_macro_set_flag_nhw(self): inline = True if self.tunable.fma_interleave else False @@ -1043,6 +1244,10 @@ def get_symbol_global_load_s_stride_d0_d1(self): else: s_wei_stride_d0 = s_dummy s_wei_stride_d1 = s_dummy + + if self.tunable.mini_weights == 1: + s_wei_stride_d0 = s_dummy + s_wei_stride_d1 = s.s_wei_stride_x return s_in_stride_d0, s_in_stride_d1, s_wei_stride_d0, s_wei_stride_d1 @@ -1214,11 +1419,7 @@ def unpack_hw(s_hw, s_h, s_w, s_0xffff): self._emit(f"; inp(1, ce, nb0, nb1) thread_length: {1}x{1}x{tb_nb0}x{tb_nb_vec_c}, cluster_length: {1}x{cb_ce}x{1}x{cb_nb1}, k_pack:{self.tunable.vector_c}") self._emit(f"v_mov_b32 v[{v.v_tmp()}], v0") - if self.tunable.tensor_b_pass_through: - # - assert False - else: - self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), cb_nb1, 1, True)) + self._emit(tc_index_dispatcher(v.v_in_inb(), v.v_tmp(), cb_nb1, 1, True)) self._emit(f"s_mov_b32 s[{s.s_0xffff()}], {0xffff}") self._emit(f"s_mov_b32 s[{s.s_tmp(1)}], {0xff}") @@ -1396,20 +1597,21 @@ def v_i_wi_psu_0(): self._emit(f"v_sub_nc_u16 v[{v_i_wi_psu_0()}], v[{v_i_wi_psu_0()}], s[{s.s_sps_px()}]") self._emit_empty_line() - # transform ce - if self.tunable.nxe != 0: - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") - self._emit(m_mdiv_u32_rem_vs(v.v_gtc_ix(), v.v_tmp(4), v.v_gtc_iec(), s.s_magic_5(), s.s_tmp(3), s.s_x(), v.v_tmp(3))) - self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") - self._emit(m_mdiv_u32_rem_vs(v.v_gtc_iy(), v.v_gtc_ic(), v.v_tmp(4), s.s_magic_4(), s.s_tmp(3), s.s_y(), v.v_tmp(3))) + if self.tunable.mini_weights == 0: + # transform ce + if self.tunable.nxe != 0: + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080008 ; offset:8, width:8") + self._emit(m_mdiv_u32_rem_vs(v.v_gtc_ix(), v.v_tmp(4), v.v_gtc_iec(), s.s_magic_5(), s.s_tmp(3), s.s_x(), v.v_tmp(3))) + self._emit(f"s_bfe_u32 s[{s.s_tmp(3)}], s[{s.s_shift_pack_1()}], 0x00080000 ; offset:0, width:8") + self._emit(m_mdiv_u32_rem_vs(v.v_gtc_iy(), v.v_gtc_ic(), v.v_tmp(4), s.s_magic_4(), s.s_tmp(3), s.s_y(), 
v.v_tmp(3))) - self._emit(f"v_mul_u32_u24 v[{v.v_sld_a_os()}], s[{s.s_dilation_w()}], v[{v.v_gtc_ix()}]") - self._emit(f"v_mul_u32_u24 v[{v.v_sst_a_os()}], s[{s.s_dilation_h()}], v[{v.v_gtc_iy()}]") + self._emit(f"v_mul_u32_u24 v[{v.v_sld_a_os()}], s[{s.s_dilation_w()}], v[{v.v_gtc_ix()}]") + self._emit(f"v_mul_u32_u24 v[{v.v_sst_a_os()}], s[{s.s_dilation_h()}], v[{v.v_gtc_iy()}]") - self._emit(f"v_add_nc_u16 v[{v_i_wi_psu_0()}], v[{v_i_wi_psu_0()}], v[{v.v_sld_a_os()}]") - self._emit(f"v_add_nc_u16 v[{v_i_hi_psu_0()}], v[{v_i_hi_psu_0()}], v[{v.v_sst_a_os()}]") - else: - self._emit(f"v_mov_b32 v[{v.v_gtc_ic()}], v[{v.v_gtc_iec()}]") + self._emit(f"v_add_nc_u16 v[{v_i_wi_psu_0()}], v[{v_i_wi_psu_0()}], v[{v.v_sld_a_os()}]") + self._emit(f"v_add_nc_u16 v[{v_i_hi_psu_0()}], v[{v_i_hi_psu_0()}], v[{v.v_sst_a_os()}]") + else: + self._emit(f"v_mov_b32 v[{v.v_gtc_ic()}], v[{v.v_gtc_iec()}]") self._emit(f"v_cmp_gt_u32 s[{s.s_n()}], v[{v.v_in_in()}]") self._emit(f"v_cndmask_b32 v[{v.v_tmp(3)}], 0, 1") @@ -1617,6 +1819,9 @@ def calculate_and_load_weight(): if self.tunable.tensor_layout == 'nchwc_kcyxc': self._emit(m_mul_u32_si(s.s_wei_stride_k(), s.s_wei_stride_k(), ca_k)) + + if self.tunable.mini_weights == 1: + self._emit(m_mul_u32_si(s.s_wei_stride_x(), s.s_wei_stride_x(), na_ce)) self._emit_empty_line() if self.wei_thread_copy_ndim != 1: @@ -1629,10 +1834,13 @@ def calculate_and_load_weight(): # cyxkc layout do not need s_wei_offset if self.tunable.tensor_layout == 'nchwc_kcyxc' and self.tunable.precache_soffset: self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) + + if self.tunable.mini_weights == 1 and self.tunable.precache_soffset: + self._emit(m_wei_2d_global_load.init_precache_soffset(s_wei_stride_d0(), s_wei_stride_d1(), s.s_wei_offset(), s.s_tmp())) - self._emit(f".v_clear_nc {v.v_gld_a()}, {self.get_num_vgpr_global_load_a()}") + # self._emit(f".v_clear_nc {v.v_gld_a()}, {self.get_num_vgpr_global_load_a()}") - if self.tunable.tensor_b_pass_through and self.tunable.tensor_b_pass_through_interleave_gld: + if self.tunable.tensor_a_pass_through and self.tunable.tensor_a_pass_through_interleave_gld: mbb_gld_wei = create_machine_basic_block(self.global_load_wei()) gld_per_k = self.tunable.wave_repeat_n * self.tunable.wave_step_n for i_mbb in mbb_gld_wei[0:(-1 * gld_per_k)]: @@ -1644,6 +1852,28 @@ def calculate_and_load_weight(): # do load calculate_and_load_weight() + + if self.tunable.mini_weights == 1: + v_igemm_k = v.v_gtc_iec + self._emit(f"; LDS store, wei: 1,ce,1,k: {1}x{1}x{1}x{ta_k_vec_c}, {1}x{ca_ce}x{1}x{ca_k}, k_pack:{k_pack}, k_pack_gld_a:{k_pack_gld_a}, {self.tunable.precision}") + + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {utility_log2(k_pack_src_mat)}, v[{v.v_wei_ik()}]") + na_k = na_k_vec_c // self.tunable.vector_c + self._emit(f"v_mad_u32_u24 v[{v.v_tmp()}], v[{v_igemm_k()}], {na_k * self.tunable.vector_c}, v[{v.v_tmp(2)}]") + + self._emit(m_mul_u32_vi(v.v_sst_a_os(), v.v_tmp(), data_byte)) + self._emit_empty_line() + + num_gld_wei = self.global_load_wei.get_issues() + for i in range(self.num_weight_ds_write()): + offset = na_k_vec_c * na_ce * data_byte * i + self._emit(f"s_waitcnt vmcnt({num_gld_wei - 1 - i})") + self._emit(f"ds_write_b128 v[{v.v_sst_a_os()}], v[{v.v_gld_a(i * 4)}:{v.v_gld_a(i * 4 + 3)}] offset:{offset}") + + self._emit_empty_line() + self._emit(f"v_mov_b32 v[{v.v_gtc_ic()}], 0") + + calculate_and_load_input() if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: @@ -1670,14 
+1900,15 @@ def calculate_and_load_weight(): ''' v_igemm_k = v.v_gtc_iec if not self.tunable.tensor_a_pass_through: - self._emit(f"; LDS store, wei: 1,ce,1,k: {1}x{1}x{1}x{ta_k_vec_c}, {1}x{ca_ce}x{1}x{ca_k}, k_pack:{k_pack}, k_pack_gld_a:{k_pack_gld_a}, {self.tunable.precision}") - if k_pack_src_mat != 1: - self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {utility_log2(k_pack_src_mat)}, v[{v.v_wei_ik()}]") - na_k = na_k_vec_c // self.tunable.vector_c - self._emit(f"v_mad_u32_u24 v[{v.v_tmp()}], v[{v_igemm_k()}], {na_k * self.tunable.vector_c}, v[{v.v_tmp(2)}]") - else: - assert False, "need k pack larger than 1" - self._emit(m_mul_u32_vi(v.v_sst_a_os(), v.v_tmp(), data_byte)) + if self.tunable.mini_weights == 0: + self._emit(f"; LDS store, wei: 1,ce,1,k: {1}x{1}x{1}x{ta_k_vec_c}, {1}x{ca_ce}x{1}x{ca_k}, k_pack:{k_pack}, k_pack_gld_a:{k_pack_gld_a}, {self.tunable.precision}") + if k_pack_src_mat != 1: + self._emit(f"v_lshlrev_b32 v[{v.v_tmp(2)}], {utility_log2(k_pack_src_mat)}, v[{v.v_wei_ik()}]") + na_k = na_k_vec_c // self.tunable.vector_c + self._emit(f"v_mad_u32_u24 v[{v.v_tmp()}], v[{v_igemm_k()}], {na_k * self.tunable.vector_c}, v[{v.v_tmp(2)}]") + else: + assert False, "need k pack larger than 1" + self._emit(m_mul_u32_vi(v.v_sst_a_os(), v.v_tmp(), data_byte)) self._emit_empty_line() self._emit(f"v_lshlrev_b32 v[{v.v_sld_a_os()}], {utility_log2(data_byte * k_pack_src_mat * self.dotx_mapping.ctrl.thread_m())}, v[{v.v_gemm_im()}] ; LDS load wei") @@ -1767,75 +1998,96 @@ def calculate_and_load_weight(): self._emit(f"v_mov_b32 v[{v.v_coalescing_store_index()}], v[0]") self._emit(f"; move slice stride") + + if self.tunable.mini_weights == 0: + if self.tunable.nxe != 0: + pass + else: + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_in_stride_c()}], {int(data_byte * na_ce)}") + self._emit(f"s_mov_b32 s[{s.s_move_slice_k_acc_c()}], {self.tunable.gemm_k_per_block // self.tunable.vector_c}") - if self.tunable.nxe != 0: - pass - else: - self._emit(f"s_mul_i32 s[{s.s_move_slice_k_stride_c()}], s[{s.s_in_stride_c()}], {int(data_byte * na_ce)}") - self._emit(f"s_mov_b32 s[{s.s_move_slice_k_acc_c()}], {self.tunable.gemm_k_per_block // self.tunable.vector_c}") - - if self.tunable.tensor_layout == 'nchwc_kcyxc': - self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_gemm_k()}], {int(self.tunable.gemm_k_per_block * data_byte)}") - else: - self._emit(f"s_lshl_b32 s[{s.s_move_slice_k_stride_gemm_k()}], s[{s.s_k()}], {utility_log2(self.tunable.gemm_k_per_block * data_byte)}") + if self.tunable.tensor_layout == 'nchwc_kcyxc': + self._emit(f"s_mov_b32 s[{s.s_move_slice_k_stride_gemm_k()}], {int(self.tunable.gemm_k_per_block * data_byte)}") + else: + self._emit(f"s_lshl_b32 s[{s.s_move_slice_k_stride_gemm_k()}], s[{s.s_k()}], {utility_log2(self.tunable.gemm_k_per_block * data_byte)}") - if self.tunable.nxe != 0: - # s_diff_in_os_acc_c_y_x : s_move_slice_k_c * in_stride_c + s_move_slice_k_x * s_dilation_w * in_stride_wi + s_move_slice_k_y * s_dilation_h * in_stride_hi - # s_diff_in_os_ovf_y_acc_c : -s_y * s_dilation_h * in_stride_hi + in_stride_c - # s_diff_in_os_ovf_x_acc_y : -s_x * s_dilation_w * in_stride_wi + s_dilation_h * in_stride_hi - # s_diff_in_iwi_acc_x : s_move_slice_k_x * s_dilation_w - # s_diff_in_iwi_ovf_x : s_diff_in_iwi_acc_x - s_x * s_dilation_w - # s_x_dilation_w : -1 * s_x * s_dilation_w - # s_diff_in_ihi_acc_y : s_move_slice_k_y * s_dilation_h - # s_diff_in_ihi_ovf_y : s_diff_in_ihi_acc_y - s_y * s_dilation_h - # s_y_dilation_h : -1 * s_y * s_dilation_h - 
self._emit(f"s_mul_i32 s[{s.s_x_dilation_w()}], s[{s.s_x()}], s[{s.s_dilation_w()}]") - self._emit(f"s_mul_i32 s[{s.s_y_dilation_h()}], s[{s.s_y()}], s[{s.s_dilation_h()}]") - self._emit(f"v_mov_b32 v[{v.v_tmp(0)}], s[{s.s_x_dilation_w()}]") - - self._emit(f"s_mul_i32 s[{s.s_x_dilation_w()}], -1, s[{s.s_x_dilation_w()}]") - self._emit(f"v_mov_b32 v[{v.v_tmp(1)}], s[{s.s_y_dilation_h()}]") - - self._emit(f"s_mul_i32 s[{s.s_y_dilation_h()}], -1, s[{s.s_y_dilation_h()}]") - self._emit(f"v_mul_u32_u24 v[{v.v_gtc_ix()}], s[{s.s_dilation_w()}], v[{v.v_gtc_ix()}]") # CAUSION: ix * dilation_w - self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_wi()}], {utility_log2(self.tunable.vector_c * data_byte)}") # in_stride_hi - self._emit(f"v_mul_u32_u24 v[{v.v_gtc_iy()}], s[{s.s_dilation_h()}], v[{v.v_gtc_iy()}]") # CAUSION: iy * dilation_h - #self._emit(f"s_lshl_b32 s[{s.s_tmp(0)}], s[{s.s_in_stride_c()}], {utility_log2(data_byte)}") # in_stride_c - self._emit(m_mul_u32_si(s.s_tmp(0), s.s_in_stride_c(), data_byte)) - self._emit(f"s_mul_i32 s[{s.s_tmp(5)}], s[{s.s_dilation_h()}], s[{s.s_tmp(5)}]") # s_dilation_h * in_stride_hi - self._emit(f"s_mul_i32 s[{s.s_diff_in_iwi_acc_x()}], s[{s.s_move_slice_k_x()}], s[{s.s_dilation_w()}]") - self._emit(f"s_mul_i32 s[{s.s_diff_in_ihi_acc_y()}], s[{s.s_move_slice_k_y()}], s[{s.s_dilation_h()}]") + if self.tunable.nxe != 0: + # s_diff_in_os_acc_c_y_x : s_move_slice_k_c * in_stride_c + s_move_slice_k_x * s_dilation_w * in_stride_wi + s_move_slice_k_y * s_dilation_h * in_stride_hi + # s_diff_in_os_ovf_y_acc_c : -s_y * s_dilation_h * in_stride_hi + in_stride_c + # s_diff_in_os_ovf_x_acc_y : -s_x * s_dilation_w * in_stride_wi + s_dilation_h * in_stride_hi + # s_diff_in_iwi_acc_x : s_move_slice_k_x * s_dilation_w + # s_diff_in_iwi_ovf_x : s_diff_in_iwi_acc_x - s_x * s_dilation_w + # s_x_dilation_w : -1 * s_x * s_dilation_w + # s_diff_in_ihi_acc_y : s_move_slice_k_y * s_dilation_h + # s_diff_in_ihi_ovf_y : s_diff_in_ihi_acc_y - s_y * s_dilation_h + # s_y_dilation_h : -1 * s_y * s_dilation_h + self._emit(f"s_mul_i32 s[{s.s_x_dilation_w()}], s[{s.s_x()}], s[{s.s_dilation_w()}]") + self._emit(f"s_mul_i32 s[{s.s_y_dilation_h()}], s[{s.s_y()}], s[{s.s_dilation_h()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp(0)}], s[{s.s_x_dilation_w()}]") + + self._emit(f"s_mul_i32 s[{s.s_x_dilation_w()}], -1, s[{s.s_x_dilation_w()}]") + self._emit(f"v_mov_b32 v[{v.v_tmp(1)}], s[{s.s_y_dilation_h()}]") + + self._emit(f"s_mul_i32 s[{s.s_y_dilation_h()}], -1, s[{s.s_y_dilation_h()}]") + self._emit(f"v_mul_u32_u24 v[{v.v_gtc_ix()}], s[{s.s_dilation_w()}], v[{v.v_gtc_ix()}]") # CAUSION: ix * dilation_w + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_wi()}], {utility_log2(self.tunable.vector_c * data_byte)}") # in_stride_hi + self._emit(f"v_mul_u32_u24 v[{v.v_gtc_iy()}], s[{s.s_dilation_h()}], v[{v.v_gtc_iy()}]") # CAUSION: iy * dilation_h + #self._emit(f"s_lshl_b32 s[{s.s_tmp(0)}], s[{s.s_in_stride_c()}], {utility_log2(data_byte)}") # in_stride_c + self._emit(m_mul_u32_si(s.s_tmp(0), s.s_in_stride_c(), data_byte)) + self._emit(f"s_mul_i32 s[{s.s_tmp(5)}], s[{s.s_dilation_h()}], s[{s.s_tmp(5)}]") # s_dilation_h * in_stride_hi + self._emit(f"s_mul_i32 s[{s.s_diff_in_iwi_acc_x()}], s[{s.s_move_slice_k_x()}], s[{s.s_dilation_w()}]") + self._emit(f"s_mul_i32 s[{s.s_diff_in_ihi_acc_y()}], s[{s.s_move_slice_k_y()}], s[{s.s_dilation_h()}]") - self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_y()}], s[{s.s_tmp(5)}]") # s_y * s_dilation_h * in_stride_hi - self._emit(f"v_sub_nc_u16 v[{v.v_tmp(0)}], 
s[{s.s_diff_in_iwi_acc_x()}],v[{v.v_tmp(0)}]") - #self._emit(f"s_add_i32 s[{s.s_diff_in_iwi_ovf_x()}], s[{s.s_diff_in_iwi_acc_x()}], s[{s.s_x_dilation_w()}]") - #self._emit(f"s_add_i32 s[{s.s_diff_in_ihi_ovf_y()}], s[{s.s_diff_in_ihi_acc_y()}], s[{s.s_y_dilation_h()}]") - self._emit(f"s_sub_i32 s[{s.s_diff_in_os_ovf_y_acc_c()}], s[{s.s_tmp(0)}], s[{s.s_tmp(4)}]") - self._emit(f"v_sub_nc_u16 v[{v.v_tmp(1)}], s[{s.s_diff_in_ihi_acc_y()}],v[{v.v_tmp(1)}]") - - self._emit(f"s_lshl_b32 s[{s.s_tmp(2)}], s[{s.s_move_slice_k_x()}], {utility_log2(data_byte * self.tunable.vector_c)}") # s_move_slice_k_x * s_dilation_w * in_stride_wi - self._emit(f"s_mul_i32 s[{s.s_tmp(3)}], s[{s.s_move_slice_k_c()}], s[{s.s_tmp(0)}]") # s_move_slice_k_c * in_stride_c - self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_dilation_w()}], s[{s.s_tmp(2)}]") # s_move_slice_k_x * s_dilation_w * in_stride_wi - self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_move_slice_k_y()}], s[{s.s_tmp(5)}]") # s_move_slice_k_y * s_dilation_h * in_stride_hi - - self._emit(f"v_readfirstlane_b32 s[{s.s_diff_in_iwi_ovf_x()}], v[{v.v_tmp(0)}]") - - #self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_x_dilation_w()}], {self.tunable.vector_c * data_byte}") # s_x * s_dilation_w * in_stride_wi - self._emit(self.mul_si_func(s.s_tmp(0), s.s_x_dilation_w(), self.tunable.vector_c * data_byte)) - self._emit(f"v_readfirstlane_b32 s[{s.s_diff_in_ihi_ovf_y()}], v[{v.v_tmp(1)}]") - - self._emit(f"s_add_u32 s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_tmp(3)}], s[{s.s_tmp(1)}]") - self._emit(f"s_add_u32 s[{s.s_diff_in_os_ovf_x_acc_y()}], s[{s.s_tmp(5)}], s[{s.s_tmp(0)}]") - self._emit(f"s_add_u32 s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_tmp(2)}]") - - self._emit(f"s_mul_i32 s[{s.s_y_x_c()}], s[{s.s_x()}], s[{s.s_y()}]") - self._emit(f"s_mul_i32 s[{s.s_move_slice_k_y_dh()}], s[{s.s_dilation_h()}], s[{s.s_move_slice_k_y()}]") - self._emit(f"s_mul_i32 s[{s.s_y_x_c()}], s[{s.s_y_x_c()}], s[{s.s_c()}]") - self._emit(f"v_add_nc_u32 v[{v.v_gtc_ix()}], s[{s.s_x_dilation_w()}], v[{v.v_gtc_ix()}]") - self._emit(f"s_mul_i32 s[{s.s_move_slice_k_x_dw()}], s[{s.s_dilation_w()}], s[{s.s_move_slice_k_x()}]") - #self._emit(f"s_lshl_b32 s[{s.s_y_x_c()}], s[{s.s_y_x_c()}], {utility_log2(data_byte)}") - self._emit(m_mul_u32_si(s.s_y_x_c(), s.s_y_x_c(), data_byte)) - self._emit(f"v_add_nc_u32 v[{v.v_gtc_iy()}], s[{s.s_y_dilation_h()}], v[{v.v_gtc_iy()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_y()}], s[{s.s_tmp(5)}]") # s_y * s_dilation_h * in_stride_hi + self._emit(f"v_sub_nc_u16 v[{v.v_tmp(0)}], s[{s.s_diff_in_iwi_acc_x()}],v[{v.v_tmp(0)}]") + #self._emit(f"s_add_i32 s[{s.s_diff_in_iwi_ovf_x()}], s[{s.s_diff_in_iwi_acc_x()}], s[{s.s_x_dilation_w()}]") + #self._emit(f"s_add_i32 s[{s.s_diff_in_ihi_ovf_y()}], s[{s.s_diff_in_ihi_acc_y()}], s[{s.s_y_dilation_h()}]") + self._emit(f"s_sub_i32 s[{s.s_diff_in_os_ovf_y_acc_c()}], s[{s.s_tmp(0)}], s[{s.s_tmp(4)}]") + self._emit(f"v_sub_nc_u16 v[{v.v_tmp(1)}], s[{s.s_diff_in_ihi_acc_y()}],v[{v.v_tmp(1)}]") + + self._emit(f"s_lshl_b32 s[{s.s_tmp(2)}], s[{s.s_move_slice_k_x()}], {utility_log2(data_byte * self.tunable.vector_c)}") # s_move_slice_k_x * s_dilation_w * in_stride_wi + self._emit(f"s_mul_i32 s[{s.s_tmp(3)}], s[{s.s_move_slice_k_c()}], s[{s.s_tmp(0)}]") # s_move_slice_k_c * in_stride_c + self._emit(f"s_mul_i32 s[{s.s_tmp(2)}], s[{s.s_dilation_w()}], s[{s.s_tmp(2)}]") # s_move_slice_k_x * s_dilation_w * in_stride_wi + self._emit(f"s_mul_i32 s[{s.s_tmp(1)}], s[{s.s_move_slice_k_y()}], s[{s.s_tmp(5)}]") 
# s_move_slice_k_y * s_dilation_h * in_stride_hi + + self._emit(f"v_readfirstlane_b32 s[{s.s_diff_in_iwi_ovf_x()}], v[{v.v_tmp(0)}]") + + #self._emit(f"s_mul_i32 s[{s.s_tmp(0)}], s[{s.s_x_dilation_w()}], {self.tunable.vector_c * data_byte}") # s_x * s_dilation_w * in_stride_wi + self._emit(self.mul_si_func(s.s_tmp(0), s.s_x_dilation_w(), self.tunable.vector_c * data_byte)) + self._emit(f"v_readfirstlane_b32 s[{s.s_diff_in_ihi_ovf_y()}], v[{v.v_tmp(1)}]") + + self._emit(f"s_add_u32 s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_tmp(3)}], s[{s.s_tmp(1)}]") + self._emit(f"s_add_u32 s[{s.s_diff_in_os_ovf_x_acc_y()}], s[{s.s_tmp(5)}], s[{s.s_tmp(0)}]") + self._emit(f"s_add_u32 s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_tmp(2)}]") + + self._emit(f"s_mul_i32 s[{s.s_y_x_c()}], s[{s.s_x()}], s[{s.s_y()}]") + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_y_dh()}], s[{s.s_dilation_h()}], s[{s.s_move_slice_k_y()}]") + self._emit(f"s_mul_i32 s[{s.s_y_x_c()}], s[{s.s_y_x_c()}], s[{s.s_c()}]") + self._emit(f"v_add_nc_u32 v[{v.v_gtc_ix()}], s[{s.s_x_dilation_w()}], v[{v.v_gtc_ix()}]") + self._emit(f"s_mul_i32 s[{s.s_move_slice_k_x_dw()}], s[{s.s_dilation_w()}], s[{s.s_move_slice_k_x()}]") + #self._emit(f"s_lshl_b32 s[{s.s_y_x_c()}], s[{s.s_y_x_c()}], {utility_log2(data_byte)}") + self._emit(m_mul_u32_si(s.s_y_x_c(), s.s_y_x_c(), data_byte)) + self._emit(f"v_add_nc_u32 v[{v.v_gtc_iy()}], s[{s.s_y_dilation_h()}], v[{v.v_gtc_iy()}]") + + else: + if self.tunable.nxe != 0: + self._emit(f"s_mul_i32 s[{s.s_ix()}], -1, s[{s.s_x()}]") + self._emit(f"s_mul_i32 s[{s.s_iy()}], -1, s[{s.s_y()}]") + self._emit(f"s_mul_i32 s[{s.s_x_dilation_w()}], s[{s.s_x()}], s[{s.s_dilation_w()}]") + self._emit(f"s_mul_i32 s[{s.s_y_dilation_h()}], s[{s.s_y()}], s[{s.s_dilation_h()}]") + self._emit(f"s_mul_i32 s[{s.s_x_dilation_w()}], -1, s[{s.s_x_dilation_w()}]") + self._emit(f"s_mul_i32 s[{s.s_y_dilation_h()}], -1, s[{s.s_y_dilation_h()}]") + self._emit(f"s_lshl_b32 s[{s.s_tmp(5)}], s[{s.s_wi()}], {utility_log2(self.tunable.vector_c * data_byte)}") + self._emit(f"s_lshl_b32 s[{s.s_tmp()}], s[{s.s_in_stride_c()}], {utility_log2(data_byte)}") + self._emit(f"s_mul_i32 s[{s.s_tmp(5)}], s[{s.s_dilation_h()}], s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_diff_in_ihi_ovf_y()}], s[{s.s_dilation_h()}], s[{s.s_y_dilation_h()}]") + self._emit(f"s_mul_i32 s[{s.s_tmp(4)}], s[{s.s_y_dilation_h()}], s[{s.s_tmp(5)}]") + self._emit(f"s_add_u32 s[{s.s_diff_in_iwi_ovf_x()}], s[{s.s_dilation_w()}], s[{s.s_x_dilation_w()}]") + self._emit(f"s_add_u32 s[{s.s_diff_in_os_ovf_y_acc_c()}], s[{s.s_tmp()}], s[{s.s_tmp(4)}]") + self._emit(f"v_sub_nc_u16 v[{v.v_tmp(1)}], s[{s.s_diff_in_ihi_acc_y()}],v[{v.v_tmp(1)}]") + self._emit(f"s_lshl_b32 s[{s.s_diff_in_os_acc_c_y_x()}], s[{s.s_dilation_w()}], {utility_log2(self.tunable.vector_c * data_byte)}") + self._emit(f"s_mul_i32 s[{s.s_tmp()}], s[{s.s_x_dilation_w()}], {self.tunable.vector_c * data_byte}") + self._emit(f"s_add_u32 s[{s.s_diff_in_os_ovf_x_acc_y()}], s[{s.s_tmp(5)}], s[{s.s_tmp()}]") self._emit_empty_line() @@ -1845,11 +2097,12 @@ def calculate_and_load_weight(): self._emit(f"s_mov_b32 s[{s.s_p_out(2)}], 0xffffffff") self._emit(f"s_mov_b32 s[{s.s_p_out(3)}], 0x31014000") - # pad gemmk - k_acc_per_block = self.tunable.gemm_k_per_block // self.tunable.vector_c # need to divide by vector_c - self._emit(f"s_add_i32 s[{s.s_knum()}], s[{s.s_knum()}], {k_acc_per_block - 1}") - self._emit(f"s_lshr_b32 s[{s.s_knum()}], s[{s.s_knum()}], {utility_log2(k_acc_per_block)}") - self._emit(f"s_lshl_b32
s[{s.s_knum()}], s[{s.s_knum()}], {utility_log2(k_acc_per_block)}") + if self.tunable.mini_weights == 0: + # pad gemmk + k_acc_per_block = self.tunable.gemm_k_per_block // self.tunable.vector_c # need to divide by vector_c + self._emit(f"s_add_i32 s[{s.s_knum()}], s[{s.s_knum()}], {k_acc_per_block - 1}") + self._emit(f"s_lshr_b32 s[{s.s_knum()}], s[{s.s_knum()}], {utility_log2(k_acc_per_block)}") + self._emit(f"s_lshl_b32 s[{s.s_knum()}], s[{s.s_knum()}], {utility_log2(k_acc_per_block)}") self._emit_empty_line() def emit_kernel_fma_main_loop(self): @@ -1863,91 +2116,83 @@ def emit_kernel_fma_main_loop(self): k_pack_src_mat = k_pack #if k_pack != 1 else k_pack_lanegroup m_move_slice_window = self.get_macro_move_slice_window() + m_move_slice_window_acc = self.get_macro_move_slice_window_acc() + m_move_slice_window_a = self.get_macro_move_slice_window_a() + m_move_gld_b_to_v_b = self.get_macro_move_gld_b_to_v_b() def move_slice_window_b(): ''' in nchw we only need call one move slice window ''' - if self.tunable.nxe != 0: + if self.tunable.mini_weights == 0: + if self.tunable.nxe != 0: + with self._deferred_context(): + self._emit(m_move_slice_window( + v.v_gtc_iy(), v.v_gtc_ix(), v.v_gtc_ic(), v.v_in_os(), + *(v.v_in_i_hw_list(),) if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING else (v.v_in_ihi_list(), v.v_in_iwi_list()), + v.v_in_flag(), v.v_in_flag_n(), + v.v_wei_os(), + s.s_diff_in_iwi_acc_x(), s.s_diff_in_iwi_ovf_x(), s.s_x_dilation_w(), + s.s_diff_in_ihi_acc_y(), s.s_diff_in_ihi_ovf_y(), s.s_y_dilation_h(), + s.s_move_slice_k_y_dh(), s.s_move_slice_k_x_dw(), s.s_move_slice_k_c(), + s.s_move_slice_k_stride_gemm_k(), + s.s_diff_in_os_acc_c_y_x(), s.s_diff_in_os_ovf_y_acc_c(), s.s_diff_in_os_ovf_x_acc_y(), + s.s_dilation_h(), + s.s_c(), s.s_sps_hi(), s.s_sps_wi(), + v.v_tmp())) + return self._get_deferred() + else: + with self._deferred_context(): + self._emit(m_move_slice_window( + s.s_p_in() if self.tunable.tensor_a_pass_through else s.s_in_offset(), + v.v_wei_os(), + s.s_move_slice_k_stride_c(), + s.s_move_slice_k_stride_gemm_k(), + v.v_gtc_ic(), + s.s_move_slice_k_acc_c(), + s.s_c(), + v.v_in_flag(), + v.v_tmp() + )) + return self._get_deferred() + else: + if self.tunable.nxe != 0: + with self._deferred_context(): + self._emit(m_move_slice_window_acc( + s.s_ix(), s.s_iy(), s.s_x(), s.s_y(), v.v_in_os(), + *(v.v_in_i_hw_list(),) if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING else (v.v_in_ihi_list(), v.v_in_iwi_list()), + v.v_in_flag(), v.v_in_flag_n(), + s.s_diff_in_iwi_acc_x(), s.s_diff_in_iwi_ovf_x(), s.s_dilation_w(), + s.s_diff_in_ihi_acc_y(), s.s_diff_in_ihi_ovf_y(), s.s_dilation_h(), + s.s_diff_in_os_acc_c_y_x(), s.s_diff_in_os_ovf_x_acc_y(), s.s_diff_in_os_ovf_y_acc_c(), + s.s_sps_hi(), s.s_sps_wi(), + v.v_tmp(), + s.s_tmp())) + return self._get_deferred() + else: + assert False, "not implemented" + + def move_slice_window_a(): + if self.tunable.mini_weights == 0: + return '' + else: with self._deferred_context(): - self._emit(m_move_slice_window( - v.v_gtc_iy(), v.v_gtc_ix(), v.v_gtc_ic(), v.v_in_os(), - *(v.v_in_i_hw_list(),) if IGEMM_FWD_GTC_NCHWC_16BIT_SPATIAL_INDEXING else (v.v_in_ihi_list(), v.v_in_iwi_list()), - v.v_in_flag(), v.v_in_flag_n(), - v.v_wei_os(), - s.s_diff_in_iwi_acc_x(), s.s_diff_in_iwi_ovf_x(), s.s_x_dilation_w(), - s.s_diff_in_ihi_acc_y(), s.s_diff_in_ihi_ovf_y(), s.s_y_dilation_h(), - s.s_move_slice_k_y_dh(), s.s_move_slice_k_x_dw(), s.s_move_slice_k_c(), - s.s_move_slice_k_stride_gemm_k(), - s.s_diff_in_os_acc_c_y_x(), s.s_diff_in_os_ovf_y_acc_c(), 
s.s_diff_in_os_ovf_x_acc_y(), - s.s_dilation_h(), - s.s_c(), s.s_sps_hi(), s.s_sps_wi(), - v.v_tmp())) + self._emit(m_move_slice_window_a(v.v_sld_a_os())) return self._get_deferred() + + def move_gld_b_to_v_b(): + if self.tunable.mini_weights == 0: + return '' else: with self._deferred_context(): - self._emit(m_move_slice_window( - s.s_p_in() if self.tunable.tensor_a_pass_through else s.s_in_offset(), - v.v_wei_os(), - s.s_move_slice_k_stride_c(), - s.s_move_slice_k_stride_gemm_k(), - v.v_gtc_ic(), - s.s_move_slice_k_acc_c(), - s.s_c(), - v.v_in_flag(), - v.v_tmp() - )) + self._emit(m_move_gld_b_to_v_b(v.v_b(), v.v_gld_b())) return self._get_deferred() - def move_slice_window_a(): - return '' - - def move_slice_window_acc(): - return '' - if self.tunable.fma_type != IGEMM_GTC_TUNABLE_FMA_TYPE_XDLOPS: # TODO: reopen legacy fma instruction if hasattr(self.tunable, 'thread_tile_m'): - fctrl = ctrl_fma_main_loop_t() - fctrl.thread_m = self.tunable.thread_tile_m - fctrl.thread_n = self.tunable.thread_tile_n - fctrl.unroll_k = self.tunable.gemm_k_per_block // self.tunable.vector_c - fctrl.label_prefix = self.name() - fctrl.gemm_m_repeat = self.tunable.gemm_m_repeat - fctrl.gemm_m_level0_cluster = self.tunable.gemm_m_level0_cluster - fctrl.gemm_m_level1_cluster = self.tunable.gemm_m_level1_cluster - fctrl.gemm_n_repeat = self.tunable.gemm_n_repeat - fctrl.gemm_n_level0_cluster = self.tunable.gemm_n_level0_cluster - fctrl.gemm_n_level1_cluster = self.tunable.gemm_n_level1_cluster - fctrl.lds_single_size = self.tunable.lds_single # in byte, should be power of 2 - fctrl.lds_buffer_num = self.tunable.lds_buffer_num - fctrl.precision = self.tunable.precision - - # functor - fctrl.global_load_a_functor = self.global_load_wei - fctrl.global_load_b_functor = self.global_load_in - fctrl.shared_store_a_functor = self.shared_store_wei - fctrl.shared_store_b_functor = self.shared_store_in - fctrl.shared_load_a_functor = inst_ds_read_t(self.tunable.thread_sub_tile_m * data_byte) - fctrl.shared_load_b_functor = inst_ds_read_t(self.tunable.thread_sub_tile_n * data_byte) - fctrl.move_slice_window_a_functor = move_slice_window_a - fctrl.move_slice_window_b_functor = move_slice_window_b - - # sympol type - fctrl.v_a = v.v_a - fctrl.v_b = v.v_b - fctrl.v_c = v.v_c - fctrl.v_gld_a = v.v_gld_a - fctrl.v_gld_b = v.v_gld_b - fctrl.v_sld_a_os = v.v_sld_a_os - fctrl.v_sld_b_os = v.v_sld_b_os - fctrl.v_sst_a_os = v.v_sst_a_os - fctrl.v_sst_b_os = v.v_sst_b_os - fctrl.s_kitr = s.s_kitr - fctrl.s_knum = s.s_knum - - fma_main_loop = fma_main_loop_t(self.mc, fctrl) - fma_main_loop.emit() + # obsolete mode + assert False, "do not use this mode any more" else: fctrl = ctrl_dotx_main_loop_t() ctrl_dotx_mapping = get_ctrl_dotx_mapping_from_wave_tile(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, @@ -1957,7 +2202,10 @@ def move_slice_window_acc(): self.tunable.lanegroup_repeat_m, self.tunable.lanegroup_repeat_n, self.tunable.precision) fctrl.dotx_m = ctrl_dotx_mapping - fctrl.unroll_k = self.tunable.gemm_k_per_block // k_pack_src_mat + if self.tunable.mini_weights == 0: + fctrl.unroll_k = self.tunable.gemm_k_per_block // k_pack_src_mat + else: + fctrl.unroll_k = 1 fctrl.label_prefix = self.name() fctrl.lds_single_size = self.tunable.lds_single # in byte, should be power of 2 fctrl.lds_buffer_num = self.tunable.lds_buffer_num @@ -1966,18 +2214,22 @@ def move_slice_window_acc(): fctrl.local_prefetch_num_m = self.tunable.local_prefetch_num_m fctrl.lds_k_pack = k_pack_src_mat - fctrl.k_per_step = 
self.tunable.gemm_k_per_block // self.tunable.vector_c + if self.tunable.mini_weights == 0: + fctrl.k_per_step = self.tunable.gemm_k_per_block // self.tunable.vector_c + else: + fctrl.k_per_step = 1 # functor # compute dpp index - fctrl.global_load_a_functor = self.global_load_wei + fctrl.global_load_a_functor = self.global_load_wei if self.tunable.mini_weights == 0 else "" fctrl.global_load_b_functor = self.global_load_in - fctrl.shared_store_a_functor = self.shared_store_wei + fctrl.shared_store_a_functor = self.shared_store_wei if self.tunable.mini_weights == 0 else "" fctrl.shared_store_b_functor = self.shared_store_in fctrl.shared_load_a_functor = inst_ds_read_mc_t(self.mc, data_byte * k_pack_src_mat * ctrl_dotx_mapping.thread_m()) - fctrl.shared_load_b_functor = inst_ds_read_mc_t(self.mc, data_byte * k_pack_src_mat * ctrl_dotx_mapping.thread_n()) + fctrl.shared_load_b_functor = inst_ds_read_mc_t(self.mc, data_byte * k_pack_src_mat * ctrl_dotx_mapping.thread_n()) if self.tunable.mini_weights == 0 else "" fctrl.move_slice_window_a_functor = move_slice_window_a fctrl.move_slice_window_b_functor = move_slice_window_b + fctrl.move_gld_b_to_v_b_functor = move_gld_b_to_v_b # sympol type fctrl.v_a = v.v_a @@ -1986,90 +2238,22 @@ def move_slice_window_acc(): fctrl.v_gld_a = v.v_gld_a fctrl.v_gld_b = v.v_gld_b fctrl.v_sld_a_os = v.v_sld_a_os - fctrl.v_sld_b_os = v.v_sld_b_os - fctrl.v_sst_a_os = v.v_sst_a_os - fctrl.v_sst_b_os = v.v_sst_b_os + fctrl.v_sld_b_os = v.v_sld_b_os if self.tunable.tensor_b_pass_through == 0 else None + fctrl.v_sst_a_os = v.v_sst_a_os if self.tunable.mini_weights == 0 else None + fctrl.v_sst_b_os = v.v_sst_b_os if self.tunable.tensor_b_pass_through == 0 else None fctrl.s_kitr = s.s_kitr fctrl.s_knum = s.s_knum + + if self.tunable.mini_weights == 1: + fctrl.mini_weights = 1 + fctrl.tensor_b_bypass_lds = 1 + fctrl.tensor_a_bypass_gld = 1 fma_main_loop = dotx_main_loop_t(self.mc, fctrl) fma_main_loop.emit() else: - a = self.agpr - fctrl = ctrl_mfma_main_loop_t() - ctrl_xdlops_mapping = get_ctrl_xdlops_mapping_from_wave_tile(self.tunable.gemm_m_per_block, self.tunable.gemm_n_per_block, - self.tunable.wave_tile_m, self.tunable.wave_tile_n, self.tunable.wave_tile_k, - self.tunable.wave_repeat_m, self.tunable.wave_repeat_n, - self.tunable.wave_step_m, self.tunable.wave_step_n, self.tunable.block_size // AMDGPU_WAVE_SIZE, - self.tunable.precision) - fctrl.cxm = ctrl_xdlops_mapping - fctrl.unroll_k = self.tunable.gemm_k_per_block - fctrl.label_prefix = self.name() - fctrl.lds_single_size = self.tunable.lds_single # in byte, should be power of 2 - fctrl.lds_buffer_num = self.tunable.lds_buffer_num - fctrl.local_prefetch_num = self.tunable.local_prefetch_num - fctrl.interleave = self.tunable.fma_interleave - fctrl.accvgpr_unified = IGEMM_FWD_GTC_NCHWC_ACCVGPR_UNIFIED and self.mc.arch_config.arch == AMDGPU_ARCH_GFX90A - - # functor - # fctrl.global_load_a_functor = self.global_load_wei - # fctrl.global_load_b_functor = self.global_load_in - # fctrl.shared_store_a_functor = self.shared_store_wei - # fctrl.shared_store_b_functor = self.shared_store_in - fctrl.global_load_a_functor = self.global_load_in - fctrl.global_load_b_functor = self.global_load_wei - fctrl.shared_store_a_functor = self.shared_store_in - fctrl.shared_store_b_functor = self.shared_store_wei - - # ta_k0, ta_k1, ta_ce0, ta_ce1, tb_ce0, tb_ce1, tb_nb0, tb_nb1 = self.get_thread_lengths() - fctrl.lds_k_pack = k_pack_src_mat - - share_load_packed = k_pack if self.tunable.tensor_a_pass_through or 
self.tunable.tensor_b_pass_through else ctrl_xdlops_mapping.lanegroup_k_per_thread() - - if ctrl_xdlops_mapping.wave_step_m == 1: - fctrl.shared_load_a_functor = inst_ds_read_t(data_byte * share_load_packed) # xdlops load from LDS always single load - else: - assert ctrl_xdlops_mapping.wave_step_m == 2, "currently only support wave_step_m is 2" - fctrl.shared_load_a_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte * share_load_packed, k_pack*ctrl_xdlops_mapping.wave_tile_m * data_byte, sym_t(self.vgpr.v_tmp(4))) - - if ctrl_xdlops_mapping.wave_step_n == 1: - fctrl.shared_load_b_functor = inst_ds_read_t(data_byte * share_load_packed) # xdlops load from LDS always single load - else: - assert ctrl_xdlops_mapping.wave_step_n == 2, "currently only support wave_step_n is 2" - fctrl.shared_load_b_functor = inst_ds_read2_likely_accumulate_offset_t(self.mc, 2, data_byte * share_load_packed, k_pack*ctrl_xdlops_mapping.wave_tile_n * data_byte, sym_t(self.vgpr.v_tmp(5))) - fctrl.move_slice_window_a_functor = move_slice_window_a - fctrl.move_slice_window_b_functor = move_slice_window_b - fctrl.move_slice_window_accumule_functor = None - - # sympol type - fctrl.v_a = v.v_a if not self.tunable.tensor_a_pass_through else None - fctrl.v_b = v.v_b if not self.tunable.tensor_b_pass_through else None - fctrl.a_c = a.a_c - fctrl.v_gld_a = v.v_gld_a - fctrl.v_gld_b = v.v_gld_b - fctrl.v_gld_a_gpf = v.v_gld_a_gpf if self.tunable.global_prefetch_a_num == 2 else None - fctrl.v_gld_b_gpf = v.v_gld_b_gpf if self.tunable.global_prefetch_b_num == 2 else None - fctrl.v_gld_a_num = self.get_num_vgpr_global_load_a() - fctrl.v_gld_b_num = self.get_num_vgpr_global_load_b() - fctrl.v_sld_a_os = v.v_sld_a_os if not self.tunable.tensor_a_pass_through else None - fctrl.v_sld_b_os = v.v_sld_b_os if not self.tunable.tensor_b_pass_through else None - fctrl.v_sst_a_os = v.v_sst_a_os if not self.tunable.tensor_a_pass_through else None - fctrl.v_sst_b_os = v.v_sst_b_os if not self.tunable.tensor_b_pass_through else None - fctrl.s_kitr = s.s_kitr - fctrl.s_knum = s.s_knum - fctrl.pass_through_a = self.tunable.tensor_a_pass_through - fctrl.pass_through_b = self.tunable.tensor_b_pass_through - fctrl.pass_through_a_v_pack = self.get_k_pack() - fctrl.pass_through_b_v_pack = self.get_k_pack() - - fctrl.pass_through_a_interleave_gld = 1 if self.tunable.tensor_a_pass_through_interleave_gld else 0 - fctrl.pass_through_b_interleave_gld = 1 if self.tunable.tensor_b_pass_through_interleave_gld else 0 - - fctrl.precision = self.tunable.precision - - mfma_main_loop = mfma_main_loop_t(self.mc, fctrl) - mfma_main_loop.emit() + assert False, "do not support for mfma" def emit_kernel_epilogue(self): diff --git a/python/operations/dotx_mapping.py b/python/operations/dotx_mapping.py index 12657ba2..9033d89a 100644 --- a/python/operations/dotx_mapping.py +++ b/python/operations/dotx_mapping.py @@ -288,7 +288,13 @@ def serialize(self): #ctrl_dotx_mapping_t( 32, 64, 8, 8, 8, 1, 2, 2, 4, 1, 2, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 #ctrl_dotx_mapping_t( 32, 64, 8, 8, 8, 1, 2, 4, 4, 1, 1, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 + ctrl_dotx_mapping_t( 16, 64, 8, 8, 8, 1, 2, 4, 1, 1, 2, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 ctrl_dotx_mapping_t( 32, 32, 8, 8, 8, 1, 2, 2, 1, 2, 2, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 + + # for inference configs + ctrl_dotx_mapping_t( 16, 512, 8, 8, 8, 1, 1, 8, 8, 2, 1, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 + ctrl_dotx_mapping_t( 16, 256, 8, 8, 8, 1, 1, 8, 
4, 2, 1, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 + ctrl_dotx_mapping_t( 16, 64, 8, 8, 8, 1, 1, 8, 1, 2, 1, v_dot2c_f32_f16), # extra k pack can be 1, 2, 4 ] # mt_m,mt_n,lt_m,lt_n,ld_m,ld_n,lw_m,lw_n, ws,lr_m,lr_n, inst_mfma diff --git a/python/operations/main_loop_graph.py b/python/operations/main_loop_graph.py index e27c2c10..92f65d8d 100644 --- a/python/operations/main_loop_graph.py +++ b/python/operations/main_loop_graph.py @@ -52,6 +52,7 @@ def __init__(self): self.shared_load_b_functor = None self.move_slice_window_a_functor = None self.move_slice_window_b_functor = None + self.move_gld_b_to_v_b_functor = None # sympol type self.v_a = None @@ -75,6 +76,13 @@ def __init__(self): self.k_per_step = 1 self.lds_pad_m = 0 # pad how many pixels per m row self.lds_pad_n = 0 # pad how many pixels per n row + + # control bit for mini weights kernels + self.mini_weights = 0 + self.tensor_a_bypass_lds = 0 + self.tensor_b_bypass_lds = 0 + self.tensor_a_bypass_gld = 0 + self.tensor_b_bypass_gld = 0 class ds_waitcnt_t(object): ''' @@ -227,6 +235,7 @@ def finish_stack(self, node_stack): def form_loop_body(self, ctrl): assert isinstance(ctrl, ctrl_dotx_main_loop_t), "wrong ctrl type" + expr_empty_line = dotx_core_loop_expr(self.mc, "empty line", "") f_gld_a = ctrl.global_load_a_functor f_gld_b = ctrl.global_load_b_functor f_sst_a = ctrl.shared_store_a_functor @@ -237,6 +246,8 @@ def form_loop_body(self, ctrl): f_move_slice_window_a = ctrl.move_slice_window_a_functor f_move_slice_window_b = ctrl.move_slice_window_b_functor + + f_move_gld_b_to_v_b = ctrl.move_gld_b_to_v_b_functor v_a = ctrl.v_a v_b = ctrl.v_b @@ -287,30 +298,44 @@ def form_loop_body(self, ctrl): gld_a_b = dotx_core_loop_node("global load a/b", gld_a, gld_b) sld_a = dotx_core_loop_expr(self.mc, "sld_a", f_sld_a) sld_a.expr_set_args(v_a(), v_sld_a_os(), lds_base_m) - sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b) - sld_b.expr_set_args(v_b(), v_sld_b_os(), lds_base_n) - sst_a = dotx_core_loop_node("sst a node", - dotx_core_loop_expr(self.mc, "wait a global load", f"s_waitcnt vmcnt({f_gld_b.get_issues()})"), - dotx_core_loop_expr(self.mc, "sst_a", f_sst_a)) - sst_b = dotx_core_loop_node("sst b node", - dotx_core_loop_expr(self.mc, "wait b global load", f"s_waitcnt vmcnt(0)"), - dotx_core_loop_expr(self.mc, "sst_b", f_sst_b)) + if ctrl.tensor_b_bypass_lds == 0: + sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b) + sld_b.expr_set_args(v_b(), v_sld_b_os(), lds_base_n) + + if ctrl.mini_weights == 0: + sst_a = dotx_core_loop_node("sst a node", + dotx_core_loop_expr(self.mc, "wait a global load", f"s_waitcnt vmcnt({f_gld_b.get_issues()})"), + dotx_core_loop_expr(self.mc, "sst_a", f_sst_a)) + else: + sst_a = expr_empty_line + if ctrl.tensor_b_bypass_lds == 0: + sst_b = dotx_core_loop_node("sst b node", + dotx_core_loop_expr(self.mc, "wait b global load", f"s_waitcnt vmcnt(0)"), + dotx_core_loop_expr(self.mc, "sst_b", f_sst_b)) + else: + sst_b = expr_empty_line sst_a_b = dotx_core_loop_node("sst a/b before core loop", sst_a, sst_b) - msw_a_b = dotx_core_loop_node("msw a/b node", - dotx_core_loop_expr(self.mc, "msw a", f_move_slice_window_a), - dotx_core_loop_expr(self.mc, "msw b", f_move_slice_window_b)) + msw_a = dotx_core_loop_expr(self.mc, "msw a", f_move_slice_window_a) + msw_b = dotx_core_loop_expr(self.mc, "msw b", f_move_slice_window_b) + move_gld_b_to_v_b = dotx_core_loop_expr(self.mc, "move gld to v", f_move_gld_b_to_v_b) wait_all_lgkm = dotx_core_loop_expr(self.mc, "wait all lds", f"s_waitcnt lgkmcnt(0)") barrier = 
dotx_core_loop_expr(self.mc, "barrier", f"s_barrier") - wait_sst_node = dotx_core_loop_node("wait sst node", wait_all_lgkm, barrier) + if ctrl.mini_weights == 0: + wait_node = dotx_core_loop_node("wait sst node", wait_all_lgkm, barrier) + else: + wait_node = move_gld_b_to_v_b # sst a/b double buffer switch - sst_buffer_switch_b = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") - sst_buffer_switch_a = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") - sst_buffer_switch_node = dotx_core_loop_node("sst buffer switch node", sst_buffer_switch_b, sst_buffer_switch_a) + if ctrl.mini_weights == 0: + sst_buffer_switch_b = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") + sst_buffer_switch_a = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") + sst_buffer_switch_node = dotx_core_loop_node("sst buffer switch node", sst_buffer_switch_b, sst_buffer_switch_a) + else: + sst_buffer_switch_node = expr_empty_line dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k) @@ -320,16 +345,21 @@ def form_loop_body(self, ctrl): stack = [loop_body] # form repeat k - fma_main_loop_node = self.form_loop_fma_body(ctrl, dotx_m.lanegroup_repeat_m, dotx_m.lanegroup_repeat_n - 1) - self.append_new_node(fma_main_loop_node, stack, "after fma") - - if ctrl.lds_buffer_num == 2: - sld_buffer_switch_b = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sld_b_os()}], {hex(lds_single_size)}, v[{v_sld_b_os()}] ; switch double buffer b load") - sld_buffer_switch_a = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sld_a_os()}], {hex(lds_single_size)}, v[{v_sld_a_os()}] ; switch double buffer a load") - sld_buffer_switch_node = dotx_core_loop_node("sst buffer switch node", sld_buffer_switch_b, sld_buffer_switch_a) - self.append_new_node(sld_buffer_switch_node, stack, "after sld buffer switch") + if ctrl.unroll_k == 1 and dotx_m.lanegroup_repeat_n == 1: + fma_main_loop_node = self.form_loop_fma_body(ctrl, dotx_m.lanegroup_repeat_m, dotx_m.lanegroup_repeat_n) else: - self.append_new_node(wait_sst_node, stack, "after wait lds op") + fma_main_loop_node = self.form_loop_fma_body(ctrl, dotx_m.lanegroup_repeat_m, dotx_m.lanegroup_repeat_n - 1) + self.append_new_node(fma_main_loop_node, stack, "after fma") + self.append_new_node(msw_a, stack, "msw a") + + if ctrl.mini_weights == 0: + if ctrl.lds_buffer_num == 2: + sld_buffer_switch_b = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sld_b_os()}], {hex(lds_single_size)}, v[{v_sld_b_os()}] ; switch double buffer b load") + sld_buffer_switch_a = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sld_a_os()}], {hex(lds_single_size)}, v[{v_sld_a_os()}] ; switch double buffer a load") + sld_buffer_switch_node = dotx_core_loop_node("sst buffer switch node", sld_buffer_switch_b, sld_buffer_switch_a) + self.append_new_node(sld_buffer_switch_node, stack, "after sld buffer switch") + else: + self.append_new_node(wait_node, stack, "after wait lds op") # sst node self.append_new_node(sst_a_b, stack, "after sst") @@ -338,42 +368,49 @@ def form_loop_body(self, ctrl): 
         self.append_new_node(self.form_loop_jump_finish_check(), stack, "global load and last repeat n")

         # move slice window part
-        self.append_new_node(msw_a_b, stack, "after msw")
+        self.append_new_node(msw_b, stack, "after msw")

         # last k last n dotx
-        wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for all sld", f's_waitcnt lgkmcnt({f_sst_a.get_issues() + f_sst_b.get_issues()})')
-        self.append_new_node(wait_lgkmcnt, stack, "last n dotx")
+        if ctrl.mini_weights == 0:
+            wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for all sld", f's_waitcnt lgkmcnt({f_sst_a.get_issues() + f_sst_b.get_issues()})')
+            self.append_new_node(wait_lgkmcnt, stack, "last n dotx")

-        for i_rm in range(dotx_m.lanegroup_repeat_m - 1):
-            # compute index for three matrice
-            i_rn = dotx_m.lanegroup_repeat_n - 1
-            c_index = i_rm * c_thread_n + i_rn * c_per_inst
-            a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
-            b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n
+        if ctrl.unroll_k == 1 and dotx_m.lanegroup_repeat_n == 1:
+            pass
+        else:
+            for i_rm in range(dotx_m.lanegroup_repeat_m - 1):
+                # compute index for three matrices
+                i_rn = dotx_m.lanegroup_repeat_n - 1
+                c_index = i_rm * c_thread_n + i_rn * c_per_inst
+                a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
+                b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n

-            dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
-            dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
-            self.append_new_node(dotx, stack, "last dotx")
+                dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
+                dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
+                self.append_new_node(dotx, stack, "last dotx")

         # sst double buffer
         if ctrl.lds_buffer_num == 2:
             self.append_new_node(sst_buffer_switch_node, stack, "after sst switch node")

         # wait for sst done
-        self.append_new_node(wait_sst_node, stack, "after waiting for sst")
+        self.append_new_node(wait_node, stack, "after waiting for sst")

         # global load node
         self.append_new_node(gld_a_b, stack, "after global load")

         # last n repeat
-        i_rn = dotx_m.lanegroup_repeat_n - 1
-        i_rm = dotx_m.lanegroup_repeat_m - 1
-        c_index = i_rm * c_thread_n + i_rn * c_per_inst
-        a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
-        b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n
-        dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
-        dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
-        self.append_new_node(dotx, stack, "last dotx")
+        if ctrl.unroll_k == 1 and dotx_m.lanegroup_repeat_n == 1:
+            pass
+        else:
+            i_rn = dotx_m.lanegroup_repeat_n - 1
+            i_rm = dotx_m.lanegroup_repeat_m - 1
+            c_index = i_rm * c_thread_n + i_rn * c_per_inst
+            a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
+            b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n
+            dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
+            dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
+            self.append_new_node(dotx, stack, "last dotx")

         # jump to begin
         self.append_new_node(self.form_loop_jump_to_begin(), stack, "finishing branch")
@@ -382,28 +419,29 @@ def form_loop_body(self, ctrl):
         loop_finishing_label = dotx_core_loop_expr(self.mc, "loop finish label", self.loop_label_finish+':')
         self.append_new_node(loop_finishing_label, stack, "finishing branch")

-        wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for all sld", f's_waitcnt lgkmcnt({f_sst_a.get_issues() + f_sst_b.get_issues()})')
-        self.append_new_node(wait_lgkmcnt, stack, "last n dotx in finish branch")
-        for i_rm in range(dotx_m.lanegroup_repeat_m):
-            # compute index for three matrice
-            i_rn = dotx_m.lanegroup_repeat_n - 1
-            c_index = i_rm * c_thread_n + i_rn * c_per_inst
-            a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
-            b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n
+        if ctrl.mini_weights == 0:
+            wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for all sld", f's_waitcnt lgkmcnt({f_sst_a.get_issues() + f_sst_b.get_issues()})')
+            self.append_new_node(wait_lgkmcnt, stack, "last n dotx in finish branch")
+
+        if ctrl.unroll_k == 1 and dotx_m.lanegroup_repeat_n == 1:
+            pass
+        else:
+            for i_rm in range(dotx_m.lanegroup_repeat_m):
+                # compute index for three matrices
+                i_rn = dotx_m.lanegroup_repeat_n - 1
+                c_index = i_rm * c_thread_n + i_rn * c_per_inst
+                a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
+                b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n

-            dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
-            dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
-            self.append_new_node(dotx, stack, "last dotx")
+                dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
+                dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
+                self.append_new_node(dotx, stack, "last dotx")

         # loop end branch
         loop_end_label = dotx_core_loop_expr(self.mc, "loop end label", self.loop_label_end+':')
         self.append_new_node(loop_end_label, stack, "node end fma body")

-        wait_all_lgkm = dotx_core_loop_expr(self.mc, "wait all lds", f"s_waitcnt lgkmcnt(0)")
-        barrier = dotx_core_loop_expr(self.mc, "barrier", f"s_barrier")
-        wait_sst_node = dotx_core_loop_node("wait sst node", wait_all_lgkm, barrier)
-
-        self.append_new_node(wait_sst_node, stack, "end loop dotx")
+        self.append_new_node(wait_node, stack, "end loop dotx")

         loop_end_node = self.form_loop_fma_body(ctrl, dotx_m.lanegroup_repeat_m, dotx_m.lanegroup_repeat_n)
         self.append_new_node(loop_end_node, stack, "finish")
         self.finish_stack(stack)
@@ -460,8 +498,9 @@ def form_loop_fma_body(self, ctrl, repeat_m, repeat_n):

         sld_a = dotx_core_loop_expr(self.mc, "sld_a", f_sld_a)
         sld_a.expr_set_args(v_a(), v_sld_a_os(), lds_base_m)
-        sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
-        sld_b.expr_set_args(v_b(), v_sld_b_os(), lds_base_n)
+        if ctrl.tensor_b_bypass_lds == 0:
+            sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
+            sld_b.expr_set_args(v_b(), v_sld_b_os(), lds_base_n)

         loop_fma_body = dotx_core_loop_node("loop fma body")
         stack = [loop_fma_body]
@@ -480,9 +519,10 @@ def form_loop_fma_body(self, ctrl, repeat_m, repeat_n):
             prefetch_a.append(sld_a)

         for i_prefetch in range(local_prefetch_num):
-            sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
-            sld_b.expr_set_args(v_b(i_prefetch * local_buffer_n), v_sld_b_os(), lds_base_n + i_prefetch * lds_width_n_per_read)
-            prefetch_b.append(sld_b)
+            if ctrl.tensor_b_bypass_lds == 0:
+                sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
+                sld_b.expr_set_args(v_b(i_prefetch * local_buffer_n), v_sld_b_os(), lds_base_n + i_prefetch * lds_width_n_per_read)
+                prefetch_b.append(sld_b)

         local_prefetch[0:0] = prefetch_b
         local_prefetch[1:1] = prefetch_a
@@ -503,16 +543,22 @@ def form_loop_fma_body(self, ctrl, repeat_m, repeat_n):
         for i_k in range(unroll_k - 1):
             for i_rn in range(dotx_m.lanegroup_repeat_n):
                 if i_rn > 0:
-                    sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
-                    sld_b.expr_set_args(v_b(((i_k * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num)* local_buffer_n), v_sld_b_os(), f'{lds_base_n}+{i_k}*{lds_width_n}+{(i_rn+1)*lds_width_n_per_read}')
-                    ds_waitcnt.push_new_vgpr(v_b(((i_k * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num)* local_buffer_n), v_b_wait_index)
-                    self.append_new_node(sld_b, stack, "after prefetch b")
+                    if ctrl.tensor_b_bypass_lds == 0:
+                        sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
+                        sld_b.expr_set_args(v_b(((i_k * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num)* local_buffer_n), v_sld_b_os(), f'{lds_base_n}+{i_k}*{lds_width_n}+{(i_rn+1)*lds_width_n_per_read}')
+                        ds_waitcnt.push_new_vgpr(v_b(((i_k * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num)* local_buffer_n), v_b_wait_index)
+                        self.append_new_node(sld_b, stack, "after prefetch b")
                 for i_rm in range(dotx_m.lanegroup_repeat_m):
                     # compute index for three matrice
                     c_index = i_rm * c_thread_n + i_rn * c_per_inst
                     a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
                     b_index = ((i_k * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n
-                    lgkmcnt = ds_waitcnt.compute_waitcnt([v_a(a_index), v_b(b_index)])
+
+                    if ctrl.tensor_b_bypass_lds == 0:
+                        lgkmcnt = ds_waitcnt.compute_waitcnt([v_a(a_index), v_b(b_index)])
+                    else:
+                        lgkmcnt = ds_waitcnt.compute_waitcnt([v_a(a_index)])
+
                     if lgkmcnt != -1:
                         wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for dotx {i_k, i_rm, i_rn}", f's_waitcnt lgkmcnt({lgkmcnt})')
                         self.append_new_node(wait_lgkmcnt, stack, "after wait cnt")
@@ -530,33 +576,41 @@ def form_loop_fma_body(self, ctrl, repeat_m, repeat_n):
                 sld_a = dotx_core_loop_expr(self.mc, "sld a", f_sld_a)
                 sld_a.expr_set_args(v_a((local_prefetch_num_m - 1) * local_buffer_m), v_sld_a_os(), f'{lds_base_m}+{i_k + 1}*{lds_width_m}+{(local_prefetch_num_m - 1)*lds_width_m_per_read}')
                 ds_waitcnt.push_new_vgpr(v_a((local_prefetch_num_m - 1) * local_buffer_m),v_a_wait_index)
-                sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
-                sld_b.expr_set_args(v_b((i_k * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num * local_buffer_n), v_sld_b_os(), f'{lds_base_n}+{i_k + 1}*{lds_width_n}+{(local_prefetch_num - 1)*lds_width_n_per_read}')
-                ds_waitcnt.push_new_vgpr(v_b((i_k * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num * local_buffer_n), v_b_wait_index)
+                if ctrl.tensor_b_bypass_lds == 0:
+                    sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
+                    sld_b.expr_set_args(v_b((i_k * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num * local_buffer_n), v_sld_b_os(), f'{lds_base_n}+{i_k + 1}*{lds_width_n}+{(local_prefetch_num - 1)*lds_width_n_per_read}')
+                    ds_waitcnt.push_new_vgpr(v_b((i_k * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num * local_buffer_n), v_b_wait_index)
+                else:
+                    sld_b = dotx_core_loop_expr(self.mc, "empty line", "")

                 sld_a_b = dotx_core_loop_node("last sld a/b", sld_a, sld_b)
                 self.append_new_node(sld_a_b, stack, f"next dotx node {i_k}")

         for i_rn in range(repeat_n):
             if i_rn > 0 and dotx_m.lanegroup_repeat_n > local_prefetch_num and i_rn < dotx_m.lanegroup_repeat_n - 1:
-                sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
-                sld_b.expr_set_args(v_b((((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num * local_buffer_n)), v_sld_b_os(), f'{lds_base_n}+{unroll_k - 1}*{lds_width_n}+{(i_rn+1)*lds_width_n_per_read}')
-                ds_waitcnt.push_new_vgpr(v_b((((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num * local_buffer_n)), v_b_wait_index)
-
-                self.append_new_node(sld_b, stack, "after prefetch b")
+                if ctrl.tensor_b_bypass_lds == 0:
+                    sld_b = dotx_core_loop_expr(self.mc, "sld_b", f_sld_b)
+                    sld_b.expr_set_args(v_b((((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num * local_buffer_n)), v_sld_b_os(), f'{lds_base_n}+{unroll_k - 1}*{lds_width_n}+{(i_rn+1)*lds_width_n_per_read}')
+                    ds_waitcnt.push_new_vgpr(v_b((((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn - 1) % local_prefetch_num * local_buffer_n)), v_b_wait_index)
+                    self.append_new_node(sld_b, stack, "after prefetch b")
             for i_rm in range(repeat_m):
                 # compute index for three matrice
                 c_index = i_rm * c_thread_n + i_rn * c_per_inst
                 a_index = (i_rm % local_prefetch_num_m) * local_buffer_m
                 b_index = (((unroll_k - 1) * dotx_m.lanegroup_repeat_n + i_rn) % local_prefetch_num) * local_buffer_n
-                lgkmcnt = ds_waitcnt.compute_waitcnt([v_a(a_index), v_b(b_index)])
+
+                if ctrl.tensor_b_bypass_lds == 0:
+                    lgkmcnt = ds_waitcnt.compute_waitcnt([v_a(a_index), v_b(b_index)])
+                else:
+                    lgkmcnt = ds_waitcnt.compute_waitcnt([v_a(a_index)])
+
                 if lgkmcnt != -1:
-                    wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for dotx {i_k, i_rm, i_rn}", f's_waitcnt lgkmcnt({lgkmcnt})')
+                    wait_lgkmcnt = dotx_core_loop_expr(self.mc, f"wait for dotx {unroll_k - 1, i_rm, i_rn}", f's_waitcnt lgkmcnt({lgkmcnt})')
                     self.append_new_node(wait_lgkmcnt, stack, "after wait cnt")

                 dotx = dotx_core_loop_expr(self.mc, "dotx", v_dotx_k)
                 dotx.expr_set_args(v_c(c_index), v_a(a_index), v_b(b_index))
-                self.append_new_node(dotx, stack, f"dotx node next {i_k, i_rm, i_rn}")
+                self.append_new_node(dotx, stack, f"dotx node next {unroll_k - 1, i_rm, i_rn}")

         self.finish_stack(stack)
@@ -576,7 +630,7 @@ def add_node_comment(self, node, str_comment):
         return new_node

     def creat_base_graph(self):
-
+        expr_empty_line = dotx_core_loop_expr(self.mc, "empty line", "")
         label_fma_body = 'L_{}_fma_body'.format(self.ctrl.label_prefix)
         label_fma_finishing = 'L_{}_fma_finishing'.format(self.ctrl.label_prefix)
         label_fma_end = 'L_{}_end'.format(self.ctrl.label_prefix)
@@ -596,13 +650,14 @@ def creat_base_graph(self):

         f_gld_a = self.ctrl.global_load_a_functor
         f_gld_b = self.ctrl.global_load_b_functor
-        f_gld_b = self.ctrl.global_load_b_functor
         f_sst_a = self.ctrl.shared_store_a_functor
         f_sst_b = self.ctrl.shared_store_b_functor

         f_move_slice_window_a = self.ctrl.move_slice_window_a_functor
         f_move_slice_window_b = self.ctrl.move_slice_window_b_functor

+        f_move_gld_b_to_v_b = self.ctrl.move_gld_b_to_v_b_functor
+
         v_sst_a_os = self.ctrl.v_sst_a_os
         v_sst_b_os = self.ctrl.v_sst_b_os
@@ -611,16 +666,20 @@ def creat_base_graph(self):

         gld_a = dotx_core_loop_expr(self.mc, "gld_a", f_gld_a)
         gld_b = dotx_core_loop_expr(self.mc, "gld_b", f_gld_b)
-        sst_a = dotx_core_loop_node("sst a node",
-                    dotx_core_loop_expr(self.mc, "wait a global load", f"s_waitcnt vmcnt({f_gld_b.get_issues()})"),
-                    dotx_core_loop_expr(self.mc, "sst_a", f_sst_a))
-        sst_b = dotx_core_loop_node("sst b node",
-                    dotx_core_loop_expr(self.mc, "wait b global load", f"s_waitcnt vmcnt(0)"),
-                    dotx_core_loop_expr(self.mc, "sst_b", f_sst_b))
+        if self.ctrl.mini_weights == 0:
+            sst_a = dotx_core_loop_node("sst a node",
+                        dotx_core_loop_expr(self.mc, "wait a global load", f"s_waitcnt vmcnt({f_gld_b.get_issues()})"),
+                        dotx_core_loop_expr(self.mc, "sst_a", f_sst_a))
+        else:
+            sst_a = expr_empty_line
+        if self.ctrl.tensor_b_bypass_lds == 0:
+            sst_b = dotx_core_loop_node("sst b node",
+                        dotx_core_loop_expr(self.mc, "wait b global load", f"s_waitcnt vmcnt(0)"),
dotx_core_loop_expr(self.mc, "sst_b", f_sst_b)) + else: + sst_b = expr_empty_line - msw_a_b = dotx_core_loop_node("msw a/b node", - dotx_core_loop_expr(self.mc, "msw a", f_move_slice_window_a), - dotx_core_loop_expr(self.mc, "msw b", f_move_slice_window_b)) + msw_b = dotx_core_loop_expr(self.mc, "msw b", f_move_slice_window_b) base_node = dotx_core_loop_node("core_loop") node_clear_c = dotx_core_loop_expr(self.mc, ".clear_c", f".v_clear_nc {v_c()}, {dotx_m.total_acc_c()}") @@ -642,23 +701,30 @@ def creat_base_graph(self): node_before_for_loop = dotx_core_loop_node("sst a/b before core loop0", first_sst, node_clear_c) check_loop_end_node = base_for_loop.form_loop_jump_end_check() end_check_before_msw = dotx_core_loop_node("end_check_before_msw", node_before_for_loop, check_loop_end_node) - base_node.first = dotx_core_loop_node("sst a/b before core loop1", end_check_before_msw, msw_a_b) + base_node.first = dotx_core_loop_node("sst a/b before core loop1", end_check_before_msw, msw_b) # sst a/b double buffer switch - sst_buffer_switch_b = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") - sst_buffer_switch_a = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") - sst_buffer_switch_node = dotx_core_loop_node("sst buffer switch node", sst_buffer_switch_b, sst_buffer_switch_a) + if self.ctrl.mini_weights == 0: + sst_buffer_switch_b = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_b_os()}], {hex(lds_single_size)}, v[{v_sst_b_os()}] ; switch double buffer b store") + sst_buffer_switch_a = dotx_core_loop_expr(self.mc, "sst a buffer switch", f"v_xor_b32 v[{v_sst_a_os()}], {hex(lds_single_size)}, v[{v_sst_a_os()}] ; switch double buffer a store") + sst_buffer_switch_node = dotx_core_loop_node("sst buffer switch node", sst_buffer_switch_b, sst_buffer_switch_a) + else: + sst_buffer_switch_node = expr_empty_line # first barrier and waitcnt + move_gld_b_to_v_b = dotx_core_loop_expr(self.mc, "move gld to v", f_move_gld_b_to_v_b) wait_all_lgkm = dotx_core_loop_expr(self.mc, "wait all lds", f"s_waitcnt lgkmcnt(0)") barrier = dotx_core_loop_expr(self.mc, "barrier", f"s_barrier") wait_sst_node = dotx_core_loop_node("wait sst node", wait_all_lgkm, barrier) # - if self.ctrl.lds_buffer_num == 2: - wait_node = dotx_core_loop_node("wait node", sst_buffer_switch_node, wait_sst_node) + if self.ctrl.mini_weights == 0: + if self.ctrl.lds_buffer_num == 2: + wait_node = dotx_core_loop_node("wait node", sst_buffer_switch_node, wait_sst_node) + else: + wait_node = wait_sst_node else: - wait_node = wait_sst_node + wait_node = dotx_core_loop_node("wait node", wait_sst_node, move_gld_b_to_v_b) # global load before loop global_load_a_b = dotx_core_loop_node("global load a/b", gld_a, gld_b) diff --git a/script/gtc_conv_fsr.sh b/script/gtc_conv_fsr.sh new file mode 100755 index 00000000..8a0d223c --- /dev/null +++ b/script/gtc_conv_fsr.sh @@ -0,0 +1,100 @@ + +#!/bin/sh +if [ $# -ge 1 ] ; then + DIR=$1 +else + DIR=bwd +fi + +if [ $# -ge 2 ] ; then + LAYOUT=$2 +else + LAYOUT="nchw" +fi + +if [ $# -ge 3 ] ; then + PREC=$3 +else + PREC="fp32" +fi + +if [ $# -ge 4 ] ; then + ARCH=$4 +else + ARCH="gfx908" +fi + +if [[ "${LAYOUT}" = "nchw" ]] ; then + LAYOUT_HSACO="" + LAYOUT_ARG="" +elif [[ "${LAYOUT}" = "nhwc" ]] ; then + LAYOUT_HSACO="_nhwc" + LAYOUT_ARG="--in_layout NHWC --fil_layout NHWC 
--out_layout NHWC" +elif [[ "${LAYOUT}" = "nchwc_kcyxc" ]] ; then + LAYOUT_HSACO="_nchwc" + LAYOUT_ARG="--in_layout NCHWC --fil_layout NCHWC --out_layout NCHWC" +elif [[ "${LAYOUT}" = "nchwc_cyxkc" ]] ; then + LAYOUT_HSACO="_nchwc" + LAYOUT_ARG="--in_layout NCHWC --fil_layout CHWNC --out_layout NCHWC" +else + echo "wrong layout: ${LAYOUT}" + exit 1 +fi + +if [[ "${PREC}" = "fp32" ]] ; then + PREC_HSACO="" + CONV="conv" +elif [[ "${PREC}" = "int4"* ]] ; then + PREC_HSACO="_${PREC}" + CONV="conv${PREC}" +elif [[ "${PREC}" = "fp16"* ]] ; then + PREC_HSACO="_${PREC}" + CONV="conv${PREC}" +elif [[ "${PREC}" = "int8"* ]] ; then + PREC_HSACO="_${PREC}" + CONV="conv${PREC}" +elif [[ "${PREC}" = "bf16"* ]] ; then + PREC_HSACO="_${PREC}" + CONV="convbfp16${PREC:4}" +else + echo "wrong precision: ${PREC}" + exit 1 +fi + +if [ "${ARCH}" != "gfx90a" ] && [ "${ARCH}" != "gfx908" ] && [ "${ARCH}" != "gfx1030" ] ; then + echo "wrong arch: ${ARCH}" + exit 1 +fi + +echo IGEMM_HSACO=out/igemm_${DIR}_gtc_${ARCH}${LAYOUT_HSACO}${PREC_HSACO}_fsr.hsaco +export IGEMM_HSACO=out/igemm_${DIR}_gtc_${ARCH}${LAYOUT_HSACO}${PREC_HSACO}_fsr.hsaco +export IGEMM_TENSOR_CAST_HSACO=out/igemm_gtc_tensor_cast.hsaco +export IGEMM_GPU_NAIVE_CONV_HSACO=out/naive_conv.hsaco +export IGEMM_SCLK_MHZ=2450 +export IGEMM_LOG_FASTEST_CONFIG=1 +export IGEMM_SLEEP_MS=117 +export PER_PIXEL_CHECK=1 +export PER_PIXEL_CHECK_PRINT=1 + +export DBG_MODE=0 +export IGEMM_ASSERT_WHEN_INVALID=1 +export IGEMM_WARMUP=1 +export IGEMM_REPEAT=4 +export IGEMM_GKS_ITERATIVE=1 +#export IGEMM_BENCH_CSV=1 +export IGEMM_RAND_INT=1 + +# Flag enables fwd, bwd, wrw convolutions +if [ "${DIR}" = "fwd" ] ; then + FORW=1 +elif [ "${DIR}" = "bwd" ] ; then + FORW=2 +elif [ "${DIR}" = "wrw" ] ; then + FORW=4 +else + echo "wrong direction" + exit 1 +fi + +./out/conv_driver.exe $CONV -n 1 -c 8 -H 1080 -W 1920 -k 16 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -t 1 -F $FORW ${LAYOUT_ARG} +./out/conv_driver.exe $CONV -n 1 -c 16 -H 135 -W 240 -k 16 -y 3 -x 3 -p 1 -q 1 -u 1 -v 1 -l 1 -j 1 -t 1 -F $FORW ${LAYOUT_ARG}