From 0aacbdcb9ed334013f710381288cebb5e9d0d983 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Mon, 30 Jan 2023 09:16:55 +0000 Subject: [PATCH 01/28] LUT by CC; tune more iters --- .../specializer/funtional/sparse_ctx_base.py | 2 +- sparta/specializer/kernels/kernel_base.py | 6 +- ...ta_matmul_lut.csv => matmul.sparta.70.csv} | 0 .../look_up_tables/matmul.sparta.default.csv | 325 ++++++++++++++++++ sparta/specializer/kernels/matmul.py | 21 +- sparta/testing/utils.py | 2 +- test/bench/attention/attention.py | 4 +- test/bench/matmul/matmul.py | 4 +- test/unit/test_sparse_attention.py | 3 - test/unit/test_sparse_matmul.py | 3 - 10 files changed, 350 insertions(+), 20 deletions(-) rename sparta/specializer/kernels/look_up_tables/{sparta_matmul_lut.csv => matmul.sparta.70.csv} (100%) create mode 100644 sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv diff --git a/sparta/specializer/funtional/sparse_ctx_base.py b/sparta/specializer/funtional/sparse_ctx_base.py index 7535fc80..f744f111 100644 --- a/sparta/specializer/funtional/sparse_ctx_base.py +++ b/sparta/specializer/funtional/sparse_ctx_base.py @@ -65,7 +65,7 @@ def set_sample_inputs(self, sample_inputs: List[torch.Tensor]): for x in sample_inputs ] - def test(self, num_warmups: int = 10, num_iters: int = 10, cuda: bool = True): + def test(self, num_warmups: int = 20, num_iters: int = 100, cuda: bool = False): return self.active_kernel().test( inputs=self.sample_inputs, num_warmups=num_warmups, diff --git a/sparta/specializer/kernels/kernel_base.py b/sparta/specializer/kernels/kernel_base.py index 9a4ade6c..c73eb675 100644 --- a/sparta/specializer/kernels/kernel_base.py +++ b/sparta/specializer/kernels/kernel_base.py @@ -200,9 +200,9 @@ def _convert_data(self, inputs: List[torch.Tensor], outputs: List[torch.Tensor]) def test( self, inputs: List[torch.Tensor], - num_warmups: int = 10, - num_iters: int = 10, - cuda: bool = True, + num_warmups: int = 20, + num_iters: int = 100, + cuda: bool = False, ): """Note that all inputs and outputs are dense tensors here.""" sparse_inputs = [x for x in inputs] diff --git a/sparta/specializer/kernels/look_up_tables/sparta_matmul_lut.csv b/sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/sparta_matmul_lut.csv rename to sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv b/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv new file mode 100644 index 00000000..7a628265 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv @@ -0,0 +1,325 @@ +mode,trans_A,trans_B,BM,BK,BN,TM,TK,TN,latency +dds,False,False,16,16,16,4,4,4,114.3322 +dds,False,False,16,16,32,4,4,4,82.87 +dds,False,False,16,16,64,4,4,4,74.2951 +dds,False,False,16,32,16,4,4,4,120.1529 +dds,False,False,16,32,32,4,4,4,87.5296 +dds,False,False,16,32,64,4,4,4,79.7643 +dds,False,False,16,64,16,4,4,4,148.3481 +dds,False,False,16,64,32,4,4,4,92.0502 +dds,False,False,16,64,64,4,4,4,84.0627 +dds,False,False,32,16,16,4,4,4,77.7297 +dds,False,False,32,16,32,4,4,4,54.2717 +dds,False,False,32,16,64,4,4,4,52.089 +dds,False,False,32,32,16,4,4,4,83.1408 +dds,False,False,32,32,32,4,4,4,59.8523 +dds,False,False,32,32,64,4,4,4,54.2046 +dds,False,False,32,64,16,4,4,4,98.1477 +dds,False,False,32,64,32,4,4,4,69.5779 +dds,False,False,32,64,64,4,4,4,56.5643 +dds,False,False,64,16,16,4,4,4,64.2577 
+dds,False,False,64,16,32,4,4,4,50.0309 +dds,False,False,64,16,64,4,4,4,50.0309 +dds,False,False,64,32,16,4,4,4,75.437 +dds,False,False,64,32,32,4,4,4,54.0036 +dds,False,False,64,32,64,4,4,4,54.0036 +dds,False,False,64,64,16,4,4,4,88.9933 +dds,False,False,64,64,32,4,4,4,56.6697 +dds,False,False,64,64,64,4,4,4,56.6697 +dds,False,True,16,16,16,4,4,4,134.6172 +dds,False,True,16,16,32,4,4,4,106.9163 +dds,False,True,16,16,64,4,4,4,101.0877 +dds,False,True,16,32,16,4,4,4,177.8903 +dds,False,True,16,32,32,4,4,4,96.7729 +dds,False,True,16,32,64,4,4,4,89.2383 +dds,False,True,16,64,16,4,4,4,281.8069 +dds,False,True,16,64,32,4,4,4,137.432 +dds,False,True,16,64,64,4,4,4,132.7488 +dds,False,True,32,16,16,4,4,4,83.1024 +dds,False,True,32,16,32,4,4,4,65.4798 +dds,False,True,32,16,64,4,4,4,59.4675 +dds,False,True,32,32,16,4,4,4,93.876 +dds,False,True,32,32,32,4,4,4,72.5978 +dds,False,True,32,32,64,4,4,4,65.1659 +dds,False,True,32,64,16,4,4,4,120.1478 +dds,False,True,32,64,32,4,4,4,98.9469 +dds,False,True,32,64,64,4,4,4,83.0436 +dds,False,True,64,16,16,4,4,4,65.8737 +dds,False,True,64,16,32,4,4,4,53.26 +dds,False,True,64,16,64,4,4,4,53.26 +dds,False,True,64,32,16,4,4,4,81.8998 +dds,False,True,64,32,32,4,4,4,60.1544 +dds,False,True,64,32,64,4,4,4,60.1544 +dds,False,True,64,64,16,4,4,4,99.884 +dds,False,True,64,64,32,4,4,4,70.5451 +dds,False,True,64,64,64,4,4,4,70.5451 +dds,True,False,16,16,16,4,4,4,135.2541 +dds,True,False,16,16,32,4,4,4,88.3433 +dds,True,False,16,16,64,4,4,4,80.7328 +dds,True,False,16,32,16,4,4,4,147.7196 +dds,True,False,16,32,32,4,4,4,83.4137 +dds,True,False,16,32,64,4,4,4,80.5606 +dds,True,False,16,64,16,4,4,4,178.6693 +dds,True,False,16,64,32,4,4,4,96.307 +dds,True,False,16,64,64,4,4,4,85.6432 +dds,True,False,32,16,16,4,4,4,97.407 +dds,True,False,32,16,32,4,4,4,64.0028 +dds,True,False,32,16,64,4,4,4,56.0378 +dds,True,False,32,32,16,4,4,4,112.9191 +dds,True,False,32,32,32,4,4,4,71.4564 +dds,True,False,32,32,64,4,4,4,59.7966 +dds,True,False,32,64,16,4,4,4,124.0738 +dds,True,False,32,64,32,4,4,4,80.3177 +dds,True,False,32,64,64,4,4,4,61.5147 +dds,True,False,64,16,16,4,4,4,124.6127 +dds,True,False,64,16,32,4,4,4,76.9557 +dds,True,False,64,16,64,4,4,4,76.9557 +dds,True,False,64,32,16,4,4,4,143.3676 +dds,True,False,64,32,32,4,4,4,84.2901 +dds,True,False,64,32,64,4,4,4,84.2901 +dds,True,False,64,64,16,4,4,4,150.3928 +dds,True,False,64,64,32,4,4,4,87.5746 +dds,True,False,64,64,64,4,4,4,87.5746 +dds,True,True,16,16,16,4,4,4,163.2706 +dds,True,True,16,16,32,4,4,4,111.4495 +dds,True,True,16,16,64,4,4,4,105.2302 +dds,True,True,16,32,16,4,4,4,215.4126 +dds,True,True,16,32,32,4,4,4,103.8571 +dds,True,True,16,32,64,4,4,4,92.4045 +dds,True,True,16,64,16,4,4,4,310.5157 +dds,True,True,16,64,32,4,4,4,149.429 +dds,True,True,16,64,64,4,4,4,134.981 +dds,True,True,32,16,16,4,4,4,103.7712 +dds,True,True,32,16,32,4,4,4,70.4374 +dds,True,True,32,16,64,4,4,4,63.2451 +dds,True,True,32,32,16,4,4,4,127.7935 +dds,True,True,32,32,32,4,4,4,85.6814 +dds,True,True,32,32,64,4,4,4,69.7582 +dds,True,True,32,64,16,4,4,4,155.0835 +dds,True,True,32,64,32,4,4,4,112.8103 +dds,True,True,32,64,64,4,4,4,91.3851 +dds,True,True,64,16,16,4,4,4,126.8568 +dds,True,True,64,16,32,4,4,4,81.2913 +dds,True,True,64,16,64,4,4,4,81.2913 +dds,True,True,64,32,16,4,4,4,149.8753 +dds,True,True,64,32,32,4,4,4,91.8136 +dds,True,True,64,32,64,4,4,4,91.8136 +dds,True,True,64,64,16,4,4,4,167.0518 +dds,True,True,64,64,32,4,4,4,105.5301 +dds,True,True,64,64,64,4,4,4,105.5301 +dsd,False,False,16,16,16,4,4,4,108.1017 +dsd,False,False,16,16,32,4,4,4,87.8983 
+dsd,False,False,16,16,64,4,4,4,80.2562 +dsd,False,False,16,32,16,4,4,4,111.9973 +dsd,False,False,16,32,32,4,4,4,83.6637 +dsd,False,False,16,32,64,4,4,4,83.4532 +dsd,False,False,16,64,16,4,4,4,135.7782 +dsd,False,False,16,64,32,4,4,4,96.0233 +dsd,False,False,16,64,64,4,4,4,96.0233 +dsd,False,False,32,16,16,4,4,4,72.3478 +dsd,False,False,32,16,32,4,4,4,57.4545 +dsd,False,False,32,16,64,4,4,4,52.2037 +dsd,False,False,32,32,16,4,4,4,78.8152 +dsd,False,False,32,32,32,4,4,4,59.3432 +dsd,False,False,32,32,64,4,4,4,54.5784 +dsd,False,False,32,64,16,4,4,4,91.2725 +dsd,False,False,32,64,32,4,4,4,67.9594 +dsd,False,False,32,64,64,4,4,4,67.9594 +dsd,False,False,64,16,16,4,4,4,60.3805 +dsd,False,False,64,16,32,4,4,4,49.5804 +dsd,False,False,64,16,64,4,4,4,46.6983 +dsd,False,False,64,32,16,4,4,4,73.6614 +dsd,False,False,64,32,32,4,4,4,53.7312 +dsd,False,False,64,32,64,4,4,4,46.5469 +dsd,False,False,64,64,16,4,4,4,85.0814 +dsd,False,False,64,64,32,4,4,4,56.0396 +dsd,False,False,64,64,64,4,4,4,56.0396 +dsd,False,True,16,16,16,4,4,4,130.935 +dsd,False,True,16,16,32,4,4,4,109.5658 +dsd,False,True,16,16,64,4,4,4,92.5529 +dsd,False,True,16,32,16,4,4,4,178.6903 +dsd,False,True,16,32,32,4,4,4,101.6086 +dsd,False,True,16,32,64,4,4,4,96.7423 +dsd,False,True,16,64,16,4,4,4,278.8124 +dsd,False,True,16,64,32,4,4,4,142.8228 +dsd,False,True,16,64,64,4,4,4,142.8228 +dsd,False,True,32,16,16,4,4,4,87.9781 +dsd,False,True,32,16,32,4,4,4,66.9843 +dsd,False,True,32,16,64,4,4,4,59.9348 +dsd,False,True,32,32,16,4,4,4,94.2587 +dsd,False,True,32,32,32,4,4,4,73.2396 +dsd,False,True,32,32,64,4,4,4,65.7842 +dsd,False,True,32,64,16,4,4,4,124.7007 +dsd,False,True,32,64,32,4,4,4,98.0345 +dsd,False,True,32,64,64,4,4,4,98.0345 +dsd,False,True,64,16,16,4,4,4,68.01 +dsd,False,True,64,16,32,4,4,4,54.7027 +dsd,False,True,64,16,64,4,4,4,49.913 +dsd,False,True,64,32,16,4,4,4,82.1206 +dsd,False,True,64,32,32,4,4,4,61.7535 +dsd,False,True,64,32,64,4,4,4,53.0829 +dsd,False,True,64,64,16,4,4,4,101.6558 +dsd,False,True,64,64,32,4,4,4,70.1075 +dsd,False,True,64,64,64,4,4,4,70.1075 +dsd,True,False,16,16,16,4,4,4,128.4223 +dsd,True,False,16,16,32,4,4,4,90.0412 +dsd,True,False,16,16,64,4,4,4,83.3222 +dsd,True,False,16,32,16,4,4,4,142.1878 +dsd,True,False,16,32,32,4,4,4,86.9283 +dsd,True,False,16,32,64,4,4,4,84.5636 +dsd,True,False,16,64,16,4,4,4,163.221 +dsd,True,False,16,64,32,4,4,4,99.8338 +dsd,True,False,16,64,64,4,4,4,99.8338 +dsd,True,False,32,16,16,4,4,4,93.9119 +dsd,True,False,32,16,32,4,4,4,62.1346 +dsd,True,False,32,16,64,4,4,4,56.0104 +dsd,True,False,32,32,16,4,4,4,111.0007 +dsd,True,False,32,32,32,4,4,4,70.8843 +dsd,True,False,32,32,64,4,4,4,59.1251 +dsd,True,False,32,64,16,4,4,4,119.8824 +dsd,True,False,32,64,32,4,4,4,80.5292 +dsd,True,False,32,64,64,4,4,4,80.5292 +dsd,True,False,64,16,16,4,4,4,123.8651 +dsd,True,False,64,16,32,4,4,4,76.5209 +dsd,True,False,64,16,64,4,4,4,60.1391 +dsd,True,False,64,32,16,4,4,4,141.5866 +dsd,True,False,64,32,32,4,4,4,83.946 +dsd,True,False,64,32,64,4,4,4,60.4473 +dsd,True,False,64,64,16,4,4,4,148.974 +dsd,True,False,64,64,32,4,4,4,87.7046 +dsd,True,False,64,64,64,4,4,4,87.7046 +dsd,True,True,16,16,16,4,4,4,159.7035 +dsd,True,True,16,16,32,4,4,4,112.7343 +dsd,True,True,16,16,64,4,4,4,96.109 +dsd,True,True,16,32,16,4,4,4,216.4289 +dsd,True,True,16,32,32,4,4,4,113.0748 +dsd,True,True,16,32,64,4,4,4,102.4129 +dsd,True,True,16,64,16,4,4,4,311.4233 +dsd,True,True,16,64,32,4,4,4,152.9842 +dsd,True,True,16,64,64,4,4,4,152.9842 +dsd,True,True,32,16,16,4,4,4,101.2226 +dsd,True,True,32,16,32,4,4,4,73.1764 
+dsd,True,True,32,16,64,4,4,4,62.08 +dsd,True,True,32,32,16,4,4,4,128.2265 +dsd,True,True,32,32,32,4,4,4,85.9566 +dsd,True,True,32,32,64,4,4,4,71.4414 +dsd,True,True,32,64,16,4,4,4,156.3029 +dsd,True,True,32,64,32,4,4,4,112.7689 +dsd,True,True,32,64,64,4,4,4,112.7689 +dsd,True,True,64,16,16,4,4,4,127.2963 +dsd,True,True,64,16,32,4,4,4,79.6703 +dsd,True,True,64,16,64,4,4,4,61.7858 +dsd,True,True,64,32,16,4,4,4,149.5118 +dsd,True,True,64,32,32,4,4,4,91.4158 +dsd,True,True,64,32,64,4,4,4,67.6785 +dsd,True,True,64,64,16,4,4,4,168.1443 +dsd,True,True,64,64,32,4,4,4,105.573 +dsd,True,True,64,64,64,4,4,4,105.573 +sdd,False,False,16,16,16,4,4,4,110.7007 +sdd,False,False,16,16,32,4,4,4,83.8641 +sdd,False,False,16,16,64,4,4,4,79.1289 +sdd,False,False,16,32,16,4,4,4,119.6233 +sdd,False,False,16,32,32,4,4,4,82.5102 +sdd,False,False,16,32,64,4,4,4,78.6386 +sdd,False,False,16,64,16,4,4,4,145.0991 +sdd,False,False,16,64,32,4,4,4,89.7248 +sdd,False,False,16,64,64,4,4,4,83.423 +sdd,False,False,32,16,16,4,4,4,72.6872 +sdd,False,False,32,16,32,4,4,4,54.0949 +sdd,False,False,32,16,64,4,4,4,51.6267 +sdd,False,False,32,32,16,4,4,4,82.8882 +sdd,False,False,32,32,32,4,4,4,59.9318 +sdd,False,False,32,32,64,4,4,4,54.2268 +sdd,False,False,32,64,16,4,4,4,92.1725 +sdd,False,False,32,64,32,4,4,4,67.4228 +sdd,False,False,32,64,64,4,4,4,55.8651 +sdd,False,False,64,16,16,4,4,4,57.2198 +sdd,False,False,64,16,32,4,4,4,49.1785 +sdd,False,False,64,16,64,4,4,4,46.1422 +sdd,False,False,64,32,16,4,4,4,74.8587 +sdd,False,False,64,32,32,4,4,4,53.8419 +sdd,False,False,64,32,64,4,4,4,46.7183 +sdd,False,False,64,64,16,4,4,4,74.8587 +sdd,False,False,64,64,32,4,4,4,53.8419 +sdd,False,False,64,64,64,4,4,4,46.7183 +sdd,False,True,16,16,16,4,4,4,131.2918 +sdd,False,True,16,16,32,4,4,4,104.9111 +sdd,False,True,16,16,64,4,4,4,100.4713 +sdd,False,True,16,32,16,4,4,4,178.6799 +sdd,False,True,16,32,32,4,4,4,96.054 +sdd,False,True,16,32,64,4,4,4,88.2646 +sdd,False,True,16,64,16,4,4,4,271.7691 +sdd,False,True,16,64,32,4,4,4,138.3314 +sdd,False,True,16,64,64,4,4,4,132.6129 +sdd,False,True,32,16,16,4,4,4,77.3388 +sdd,False,True,32,16,32,4,4,4,59.9648 +sdd,False,True,32,16,64,4,4,4,60.089 +sdd,False,True,32,32,16,4,4,4,94.166 +sdd,False,True,32,32,32,4,4,4,72.2477 +sdd,False,True,32,32,64,4,4,4,64.8427 +sdd,False,True,32,64,16,4,4,4,120.019 +sdd,False,True,32,64,32,4,4,4,97.3089 +sdd,False,True,32,64,64,4,4,4,83.1809 +sdd,False,True,64,16,16,4,4,4,59.4227 +sdd,False,True,64,16,32,4,4,4,52.1191 +sdd,False,True,64,16,64,4,4,4,49.1211 +sdd,False,True,64,32,16,4,4,4,81.5417 +sdd,False,True,64,32,32,4,4,4,59.0152 +sdd,False,True,64,32,64,4,4,4,52.0989 +sdd,False,True,64,64,16,4,4,4,81.5417 +sdd,False,True,64,64,32,4,4,4,59.0152 +sdd,False,True,64,64,64,4,4,4,52.0989 +sdd,True,False,16,16,16,4,4,4,130.6109 +sdd,True,False,16,16,32,4,4,4,84.93 +sdd,True,False,16,16,64,4,4,4,79.0723 +sdd,True,False,16,32,16,4,4,4,142.6209 +sdd,True,False,16,32,32,4,4,4,82.0366 +sdd,True,False,16,32,64,4,4,4,81.5747 +sdd,True,False,16,64,16,4,4,4,165.6981 +sdd,True,False,16,64,32,4,4,4,92.0803 +sdd,True,False,16,64,64,4,4,4,84.1962 +sdd,True,False,32,16,16,4,4,4,96.9497 +sdd,True,False,32,16,32,4,4,4,62.008 +sdd,True,False,32,16,64,4,4,4,55.4291 +sdd,True,False,32,32,16,4,4,4,113.3967 +sdd,True,False,32,32,32,4,4,4,70.2921 +sdd,True,False,32,32,64,4,4,4,58.7307 +sdd,True,False,32,64,16,4,4,4,120.7301 +sdd,True,False,32,64,32,4,4,4,78.2809 +sdd,True,False,32,64,64,4,4,4,60.8745 +sdd,True,False,64,16,16,4,4,4,125.0771 +sdd,True,False,64,16,32,4,4,4,76.4491 
+sdd,True,False,64,16,64,4,4,4,60.097 +sdd,True,False,64,32,16,4,4,4,141.8331 +sdd,True,False,64,32,32,4,4,4,83.7384 +sdd,True,False,64,32,64,4,4,4,60.9805 +sdd,True,False,64,64,16,4,4,4,141.8331 +sdd,True,False,64,64,32,4,4,4,83.7384 +sdd,True,False,64,64,64,4,4,4,60.9805 +sdd,True,True,16,16,16,4,4,4,159.6109 +sdd,True,True,16,16,32,4,4,4,105.5516 +sdd,True,True,16,16,64,4,4,4,100.7452 +sdd,True,True,16,32,16,4,4,4,211.4903 +sdd,True,True,16,32,32,4,4,4,101.2889 +sdd,True,True,16,32,64,4,4,4,90.9672 +sdd,True,True,16,64,16,4,4,4,303.9266 +sdd,True,True,16,64,32,4,4,4,146.3723 +sdd,True,True,16,64,64,4,4,4,133.093 +sdd,True,True,32,16,16,4,4,4,103.5772 +sdd,True,True,32,16,32,4,4,4,71.01 +sdd,True,True,32,16,64,4,4,4,63.2092 +sdd,True,True,32,32,16,4,4,4,128.0946 +sdd,True,True,32,32,32,4,4,4,86.1027 +sdd,True,True,32,32,64,4,4,4,68.897 +sdd,True,True,32,64,16,4,4,4,153.6706 +sdd,True,True,32,64,32,4,4,4,111.8289 +sdd,True,True,32,64,64,4,4,4,91.2396 +sdd,True,True,64,16,16,4,4,4,127.9097 +sdd,True,True,64,16,32,4,4,4,80.6409 +sdd,True,True,64,16,64,4,4,4,62.273 +sdd,True,True,64,32,16,4,4,4,148.6517 +sdd,True,True,64,32,32,4,4,4,90.9354 +sdd,True,True,64,32,64,4,4,4,67.6934 +sdd,True,True,64,64,16,4,4,4,148.6517 +sdd,True,True,64,64,32,4,4,4,90.9354 +sdd,True,True,64,64,64,4,4,4,67.6934 diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 0550718e..5830df80 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -15,7 +15,18 @@ from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig -TILE_LUT = pd.read_csv(io.StringIO(res.read_text(look_up_tables, 'sparta_matmul_lut.csv'))) +def _get_sparta_matmul_lut(): + major, minor = torch.cuda.get_device_capability() + try: + lut_file = f'matmul.sparta.{major}{minor}.csv' + lut_text = res.read_text(look_up_tables, lut_file) + except FileNotFoundError: + lut_file = f'matmul.sparta.default.csv' + lut_text = res.read_text(look_up_tables, lut_file) + return pd.read_csv(io.StringIO(lut_text)) + + +SPARTA_LUT = _get_sparta_matmul_lut() class SparseMatMulKernel(KernelBase): @@ -44,10 +55,10 @@ def __init__( self._sparse_block_H = '' self._sparse_block_W = '' self._tesa_vars = [] - mode_filter = TILE_LUT['mode'] == self._mode - trans_A_filter = TILE_LUT['trans_A'] == self._transpose_A - trans_B_filter = TILE_LUT['trans_B'] == self._transpose_B - self._lut = TILE_LUT[mode_filter & trans_A_filter & trans_B_filter] + mode_filter = SPARTA_LUT['mode'] == self._mode + trans_A_filter = SPARTA_LUT['trans_A'] == self._transpose_A + trans_B_filter = SPARTA_LUT['trans_B'] == self._transpose_B + self._lut = SPARTA_LUT[mode_filter & trans_A_filter & trans_B_filter] super().__init__() def _set_ports(self): diff --git a/sparta/testing/utils.py b/sparta/testing/utils.py index 4eb4f707..474eae26 100644 --- a/sparta/testing/utils.py +++ b/sparta/testing/utils.py @@ -70,4 +70,4 @@ def check(func: Callable, inputs: List, target_outputs: List): outputs = [outputs] assert len(outputs) == len(target_outputs), f'expected {len(target_outputs)} outputs, got {len(outputs)}' for output, target_output in zip(outputs, target_outputs): - torch.testing.assert_close(output, target_output, atol=1e-4, rtol=1e-8) + torch.testing.assert_close(output, target_output, atol=1e-4, rtol=1e-4) diff --git a/test/bench/attention/attention.py b/test/bench/attention/attention.py index b4874394..936fb0e2 100644 --- a/test/bench/attention/attention.py +++ b/test/bench/attention/attention.py @@ -188,8 +188,8 
@@ def dense_attention(query, key, value): def load_sparta_config(device: Any = 'cuda'): - device_name = torch.cuda.get_device_name(device) - device_cfg_path = os.path.join(WORK_DIR, 'params', f'{device_name}.csv') + major, minor = torch.cuda.get_device_capability() + device_cfg_path = os.path.join(WORK_DIR, 'params', f'{major}{minor}.csv') default_cfg_path = os.path.join(WORK_DIR, 'params', 'default.csv') if os.path.exists(device_cfg_path): return pd.read_csv(device_cfg_path) diff --git a/test/bench/matmul/matmul.py b/test/bench/matmul/matmul.py index 078e6cf6..cbe0715f 100644 --- a/test/bench/matmul/matmul.py +++ b/test/bench/matmul/matmul.py @@ -156,8 +156,8 @@ def profile_dense_matmul( def load_sparta_config(device: Any = 'cuda'): - device_name = torch.cuda.get_device_name(device) - device_cfg_path = os.path.join(WORK_DIR, 'params', f'{device_name}.csv') + major, minor = torch.cuda.get_device_capability() + device_cfg_path = os.path.join(WORK_DIR, 'params', f'{major}{minor}.csv') default_cfg_path = os.path.join(WORK_DIR, 'params', 'default.csv') if os.path.exists(device_cfg_path): return pd.read_csv(device_cfg_path) diff --git a/test/unit/test_sparse_attention.py b/test/unit/test_sparse_attention.py index 43420f6d..d8e16f51 100644 --- a/test/unit/test_sparse_attention.py +++ b/test/unit/test_sparse_attention.py @@ -20,9 +20,6 @@ def get_params(): 'BLOCK_SIZE_M_VALUE': 32, 'BLOCK_SIZE_K_VALUE': 32, 'BLOCK_SIZE_N_VALUE': 32, - 'THREAD_SIZE_M_VALUE': 4, - 'THREAD_SIZE_K_VALUE': 4, - 'THREAD_SIZE_N_VALUE': 4, } for kernel_name in matmul_kernel_names } diff --git a/test/unit/test_sparse_matmul.py b/test/unit/test_sparse_matmul.py index 3a984dca..d57f3b9e 100644 --- a/test/unit/test_sparse_matmul.py +++ b/test/unit/test_sparse_matmul.py @@ -102,9 +102,6 @@ def get_params(impl: str): 'BLOCK_SIZE_M_VALUE': 32, 'BLOCK_SIZE_K_VALUE': 32, 'BLOCK_SIZE_N_VALUE': 32, - 'THREAD_SIZE_M_VALUE': 4, - 'THREAD_SIZE_K_VALUE': 4, - 'THREAD_SIZE_N_VALUE': 4, } else: return {'_impl': impl} From deca07de32324e28dd4f39a0c0f35e42f42e2ef7 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 31 Jan 2023 06:13:42 +0000 Subject: [PATCH 02/28] fix matmul param checker; bench unbind device --- sparta/specializer/kernels/matmul.py | 45 ++++++++++++------- test/bench/attention/attention.py | 12 +---- .../{params/default.csv => sparta_params.csv} | 0 test/bench/matmul/matmul.py | 12 +---- .../{params/default.csv => sparta_params.csv} | 0 5 files changed, 30 insertions(+), 39 deletions(-) rename test/bench/attention/{params/default.csv => sparta_params.csv} (100%) rename test/bench/matmul/{params/default.csv => sparta_params.csv} (100%) diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 5830df80..9b096d64 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -286,23 +286,34 @@ def threads_per_block(self): return (BN // TN, BM // TM, 1) def _check_parameters(self, params: Dict[str, Any]): - if 'THREAD_SIZE_M_VALUE' in params: - if 'THREAD_SIZE_K_VALUE' in params: - if 'THREAD_SIZE_N_VALUE' in params: - BM = params['BLOCK_SIZE_M_VALUE'] - BK = params['BLOCK_SIZE_K_VALUE'] - BN = params['BLOCK_SIZE_N_VALUE'] - TM = params['THREAD_SIZE_M_VALUE'] - TK = params['THREAD_SIZE_K_VALUE'] - TN = params['THREAD_SIZE_N_VALUE'] - assert BM > TM - assert BK > TK - assert BN > TN - A_thread_per_rows = (BM if self._transpose_A else BK) // 4 - B_thread_per_rows = (BK if self._transpose_A else BN) // 4 - threads_per_block = (BM // TM) * (BN // TN) 
- assert threads_per_block >= A_thread_per_rows - assert threads_per_block >= B_thread_per_rows + BM = params['BLOCK_SIZE_M_VALUE'] + BK = params['BLOCK_SIZE_K_VALUE'] + BN = params['BLOCK_SIZE_N_VALUE'] + assert BM >= 4 + assert BN >= 4 + assert BK >= 4 + assert BM & (BM - 1) == 0 + assert BK & (BK - 1) == 0 + assert BN & (BN - 1) == 0 + if all([f'THREAD_SIZE_{dim}_VALUE' in params for dim in ['M', 'K', 'N']]): + TM = params['THREAD_SIZE_M_VALUE'] + TK = params['THREAD_SIZE_K_VALUE'] + TN = params['THREAD_SIZE_N_VALUE'] + assert BM >= TM + assert BK >= TK + assert BN >= TN + assert TM & (TM - 1) == 0 + assert TK & (TK - 1) == 0 + assert TN & (TN - 1) == 0 + A_threads_per_row = (BM if self._transpose_A else BK) // 4 + B_threads_per_row = (BK if self._transpose_A else BN) // 4 + threads_per_block = (BM // TM) * (BN // TN) + assert threads_per_block >= A_threads_per_row + assert threads_per_block >= B_threads_per_row + A_tile_row_stride = threads_per_block // A_threads_per_row + B_tile_row_stride = threads_per_block // B_threads_per_row + assert A_tile_row_stride <= (BK if self._transpose_A else BM) + assert B_tile_row_stride <= (BN if self._transpose_B else BK) class OpenAISparseMatMulKernel(SparseMatMulKernel): diff --git a/test/bench/attention/attention.py b/test/bench/attention/attention.py index 936fb0e2..127967da 100644 --- a/test/bench/attention/attention.py +++ b/test/bench/attention/attention.py @@ -187,16 +187,6 @@ def dense_attention(query, key, value): return profile_attention(dense_attention, data) -def load_sparta_config(device: Any = 'cuda'): - major, minor = torch.cuda.get_device_capability() - device_cfg_path = os.path.join(WORK_DIR, 'params', f'{major}{minor}.csv') - default_cfg_path = os.path.join(WORK_DIR, 'params', 'default.csv') - if os.path.exists(device_cfg_path): - return pd.read_csv(device_cfg_path) - else: - return pd.read_csv(default_cfg_path) - - def get_sparta_config(configs: pd.DataFrame, granularity: int, sparsity: float): condition = (configs['granularity'] == granularity) & (configs['sparsity'] == sparsity) config: Dict[str, Any] = {} @@ -216,7 +206,7 @@ def profile_all(log_path: str, device: Any = 'cuda'): cols = ['method', 'Ns', 'Nt', 'E', 'granularity', 'sparsity', 'forward', 'backward'] with open(log_path, 'w') as f: f.write(','.join(cols) + '\n') - sparta_configs = load_sparta_config(device) + sparta_configs = pd.read_csv(os.path.join(WORK_DIR, 'sparta_params.csv')) for g in GRANULARITY_LIST: for s in SPARSITY_LIST: print(f'========== Granularuty: {g} Sparsity: {s} ==========') diff --git a/test/bench/attention/params/default.csv b/test/bench/attention/sparta_params.csv similarity index 100% rename from test/bench/attention/params/default.csv rename to test/bench/attention/sparta_params.csv diff --git a/test/bench/matmul/matmul.py b/test/bench/matmul/matmul.py index cbe0715f..e8c8243f 100644 --- a/test/bench/matmul/matmul.py +++ b/test/bench/matmul/matmul.py @@ -155,16 +155,6 @@ def profile_dense_matmul( return profile_matmul(dense_matmul, data) -def load_sparta_config(device: Any = 'cuda'): - major, minor = torch.cuda.get_device_capability() - device_cfg_path = os.path.join(WORK_DIR, 'params', f'{major}{minor}.csv') - default_cfg_path = os.path.join(WORK_DIR, 'params', 'default.csv') - if os.path.exists(device_cfg_path): - return pd.read_csv(device_cfg_path) - else: - return pd.read_csv(default_cfg_path) - - def get_sparta_config(configs: pd.DataFrame, granularity: int, sparsity: float): condition = (configs['granularity'] == granularity) & 
(configs['sparsity'] == sparsity) config: Dict[str, Any] = {} @@ -184,7 +174,7 @@ def profile_all(log_path: str, device: Any = 'cuda'): cols = ['method', 'M', 'K', 'N', 'granularity', 'sparsity', 'forward', 'backward'] with open(log_path, 'w') as f: f.write(','.join(cols) + '\n') - sparta_configs = load_sparta_config(device) + sparta_configs = pd.read_csv(os.path.join(WORK_DIR, 'sparta_params.csv')) for g in GRANULARITY_LIST: for s in SPARSITY_LIST: print(f'========== Granularuty: {g} Sparsity: {s} ==========') diff --git a/test/bench/matmul/params/default.csv b/test/bench/matmul/sparta_params.csv similarity index 100% rename from test/bench/matmul/params/default.csv rename to test/bench/matmul/sparta_params.csv From 8528fa724b30d91fa8e56e091269759124dfc3c1 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 31 Jan 2023 15:13:32 +0800 Subject: [PATCH 03/28] rearrange tunning code --- sparta/common/__init__.py | 2 - sparta/common/utils.py | 99 ------------ sparta/nn/module_tuner.py | 152 +----------------- sparta/specializer/kernels/kernel_base.py | 2 +- sparta/specializer/kernels/matmul.py | 2 +- sparta/specializer/kernels/softmax.py | 2 +- sparta/testing/utils.py | 43 +++-- sparta/tuning/__init__.py | 5 + .../{common/tuning.py => tuning/tunable.py} | 10 +- sparta/tuning/tuners.py | 81 ++++++++++ 10 files changed, 115 insertions(+), 283 deletions(-) delete mode 100644 sparta/common/__init__.py delete mode 100644 sparta/common/utils.py create mode 100644 sparta/tuning/__init__.py rename sparta/{common/tuning.py => tuning/tunable.py} (94%) create mode 100644 sparta/tuning/tuners.py diff --git a/sparta/common/__init__.py b/sparta/common/__init__.py deleted file mode 100644 index 9a045456..00000000 --- a/sparta/common/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. diff --git a/sparta/common/utils.py b/sparta/common/utils.py deleted file mode 100644 index bf4afa1b..00000000 --- a/sparta/common/utils.py +++ /dev/null @@ -1,99 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import ctypes -import logging -import uuid -from typing import List, Dict - -_logger = logging.Logger(__name__) - - -def cuda_detect(): - """ - Detect the cuda environment. - """ - def ConvertSMVer2Cores(major, minor): - # Returns the number of CUDA cores per multiprocessor for a given - # Compute Capability version. There is no way to retrieve that via - # the API, so it needs to be hard-coded. - # See _ConvertSMVer2Cores in helper_cuda.h in NVIDIA's CUDA Samples. 
- return {(1, 0): 8, # Tesla - (1, 1): 8, - (1, 2): 8, - (1, 3): 8, - (2, 0): 32, # Fermi - (2, 1): 48, - (3, 0): 192, # Kepler - (3, 2): 192, - (3, 5): 192, - (3, 7): 192, - (5, 0): 128, # Maxwell - (5, 2): 128, - (5, 3): 128, - (6, 0): 64, # Pascal - (6, 1): 128, - (6, 2): 128, - (7, 0): 64, # Volta - (7, 2): 64, - (7, 5): 64, # Turing - (8, 0): 64, # Ampere - (8, 6): 64, - }.get((major, minor), 0) - # Some constants taken from cuda.h - CUDA_SUCCESS = 0 - libnames = ('libcuda.so', 'libcuda.dylib', 'cuda.dll') - for libname in libnames: - try: - cuda = ctypes.CDLL(libname) - except OSError: - continue - else: - break - else: - raise OSError("could not load any of: " + ' '.join(libnames)) - - nGpus = ctypes.c_int() - name = b' ' * 100 - cc_major = ctypes.c_int() - cc_minor = ctypes.c_int() - - result = ctypes.c_int() - device = ctypes.c_int() - context = ctypes.c_void_p() - error_str = ctypes.c_char_p() - - result = cuda.cuInit(0) - if result != CUDA_SUCCESS: - cuda.cuGetErrorString(result, ctypes.byref(error_str)) - _logger.warning("cuInit failed with error code %d: %s", result, error_str.value.decode()) - return - result = cuda.cuDeviceGetCount(ctypes.byref(nGpus)) - if result != CUDA_SUCCESS: - cuda.cuGetErrorString(result, ctypes.byref(error_str)) - _logger.warning("cuDeviceGetCount failed with error code %d: %s", result, error_str.value.decode()) - return - devices = [] - for i in range(nGpus.value): - result = cuda.cuDeviceGet(ctypes.byref(device), i) - if result != CUDA_SUCCESS: - cuda.cuGetErrorString(result, ctypes.byref(error_str)) - _logger.warning("cuDeviceGet failed with error code %d: %s", result, error_str.value.decode()) - return - device_name = None - device_code = None - if cuda.cuDeviceGetName(ctypes.c_char_p(name), len(name), device) == CUDA_SUCCESS: - device_name = name.split(b'\0', 1)[0].decode() - if cuda.cuDeviceComputeCapability(ctypes.byref(cc_major), ctypes.byref(cc_minor), device) == CUDA_SUCCESS: - device_code = "%d%d" % (cc_major.value, cc_minor.value) - devices.append((device_name, device_code)) - return devices - - -def get_uname(num: int = 8): - return str(uuid.uuid4())[:8] - - -def check_type(obj, cls): - """assert obj is instance of cls""" - assert isinstance(obj, cls), f'{obj} is not instance of {cls}' diff --git a/sparta/nn/module_tuner.py b/sparta/nn/module_tuner.py index d071dabf..fa851a1e 100644 --- a/sparta/nn/module_tuner.py +++ b/sparta/nn/module_tuner.py @@ -1,18 +1,15 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -import abc import sys -import random import logging import warnings -from typing import Any, List, Dict, Callable, Iterator, Optional -from dataclasses import dataclass, field +from typing import Any, List, Dict, Callable, Optional import torch import numpy as np -from sparta.common.tuning import Tunable, TunableItemCfg +from sparta.tuning import TunableItemCfg, GridSearchTuner, RandomSearchTuner from sparta.specializer import OperatorBase @@ -21,151 +18,6 @@ _logger.addHandler(_handler) -# def tune_combined_module( -# module: torch.nn.Module, sample_inputs: List[torch.Tensor], sample_grads: List[torch.Tensor], -# algo: str = 'grid', max_trials: int = sys.maxsize, backward_weight: float = 0, -# tester_kw: Dict = None, build_kw: Dict = None, tuner_kw: Dict = None, verbose: bool = False -# ): -# """Find, tune and build all sparse operators in the model. - -# Args: -# module (torch.nn.Module): A PyTorch module that contains one or more sparse sub-modules. 
-# sample_inputs (List[torch.Tensor]): Sample input tensors to determine shape parameters. -# algo: (Optional[str]): The algorithm to search the best parameters. Defaults to 'grid'. -# max_trials: (Optional[int]): The maximum number of trials to run. Defaults to sys.maxsize. -# tester_kw: (Optional[Dict]): The keyword arguments for the tester. Defaults to None. -# build_kw: (Optional[Dict]): The keyword arguments for the builder (after tuning). Defaults to None. -# tuner_kw: (Optional[Dict]): The keyword arguments for the tuner. Defaults to None. -# """ -# from nni import NoMoreTrialError - -# @dataclass -# class _TuningContext: -# """Context for tuning.""" -# module_dict: Dict[str, OperatorBase] = field(default_factory=dict) -# space_dict: Dict[str, TunableItemCfg] = field(default_factory=dict) -# input_dict: Dict[str, list] = field(default_factory=dict) -# best_latency: float = math.inf -# best_params: Dict = None - -# def add(self, name, module, space, inputs): -# """Add a module to the context.""" -# _logger.info(f'tunable operator deduced {type(module)} {name} ') -# self.module_dict[name] = module -# self.space_dict[name] = space -# self.input_dict[name] = inputs - -# ctx = _TuningContext() - -# if isinstance(module, OperatorBase): -# ctx.add('root', module, module.get_search_space(), sample_inputs) -# else: -# sample_inputs_dict = {} -# for child_name, child_module in module.named_children(): -# sample_inputs_dict[child_name] = [] -# child_module.register_forward_hook(get_input_hook(sample_inputs_dict, child_name)) -# with warnings.catch_warnings(): -# warnings.simplefilter('ignore') -# module.forward(*sample_inputs) -# for child_name, child_module in module.named_children(): -# if isinstance(child_module, OperatorBase): -# ctx.add(child_name, child_module, child_module.get_search_space(), sample_inputs_dict[child_name]) - -# tuner = Tunable.create_tuner(algo, ctx.space_dict, tuner_kw) -# tester_kw = tester_kw or {} -# for i in range(max_trials): -# try: -# params = tuner.generate_parameters(i) -# except NoMoreTrialError: -# break -# latency = 0.0 -# try: -# for name, module in ctx.module_dict.items(): -# latency += module.test( -# params[name], -# sample_inputs=ctx.input_dict[name], -# **tester_kw -# ) -# except AssertionError: -# _logger.warn(f'Invalid config') -# continue -# _logger.info(f'params:{params} -> latency: {latency}') -# tuner.receive_trial_result(i, params, latency) # TODO: add status here -# if latency < ctx.best_latency: -# ctx.best_latency = latency -# ctx.best_params = params -# tuner.trial_end(i, True) - -# build_kw = build_kw or {} -# for name, module in ctx.module_dict.items(): -# module.build(ctx.best_params[name], sample_inputs=ctx.input_dict[name], **build_kw) -# return ctx.best_params - - -class Tuner(object): - - def __init__( - self, - search_space: Dict[Any, TunableItemCfg], - eval_func: Callable[[int, Dict[Any, Any]], float], - max_trials: int = sys.maxsize, - ): - self._search_space = search_space - self._eval_func = eval_func - space_shape = [len(param_space._value) for param_space in search_space.values()] - space_size = int(np.prod(space_shape)) - self._max_trials = min(max_trials, space_size) - self.best_result = np.inf - self.best_config = None - - @abc.abstractmethod - def next_config(self) -> Iterator[Dict[str, Any]]: - """Yields the next config.""" - - def tune(self): - for i, config in zip(range(self._max_trials), self.next_config()): - result = self._eval_func(i, config) - if result < self.best_result: - self.best_result = result - 
self.best_config = config - - -class RandomSearchTuner(Tuner): - - def next_config(self): - while True: - yield { - param_name: random.choice(param_space._value) - for param_name, param_space in self._search_space.items() - } - - -class GridSearchTuner(Tuner): - - def next_config(self): - if len(self._search_space) == 0: - yield {} - else: - param_names = [] - param_idxs = [] - param_space_sizes = [] - for param_name, param_space in self._search_space.items(): - param_names.append(param_name) - param_space_sizes.append(len(param_space._value)) - param_idxs.append(0) - while param_idxs[0] < param_space_sizes[0]: - yield { - param_name: self._search_space[param_name]._value[param_idx] - for param_idx, param_name in zip(param_idxs, param_names) - } - k = len(param_idxs) - 1 - param_idxs[k] += 1 - while param_idxs[k] == param_space_sizes[k] and k > 0: - param_idxs[k - 1] += 1 - param_idxs[k] = 0 - k -= 1 - - def tune_sparse_module( module: OperatorBase, name: str, diff --git a/sparta/specializer/kernels/kernel_base.py b/sparta/specializer/kernels/kernel_base.py index c73eb675..5ee8898b 100644 --- a/sparta/specializer/kernels/kernel_base.py +++ b/sparta/specializer/kernels/kernel_base.py @@ -17,7 +17,7 @@ from pycuda.compiler import SourceModule from sparta.tesa import get_bcs_function, BCSIndexes -from sparta.common.tuning import TunableItemCfg +from sparta.tuning import TunableItemCfg from sparta.testing import profile diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 9b096d64..2f3873e6 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -10,7 +10,7 @@ import numpy as np import pandas as pd -from sparta.common.tuning import TunableItemCfg +from sparta.tuning import TunableItemCfg from sparta.specializer.kernels import templates, look_up_tables from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig diff --git a/sparta/specializer/kernels/softmax.py b/sparta/specializer/kernels/softmax.py index 22fbb031..759f3a4b 100644 --- a/sparta/specializer/kernels/softmax.py +++ b/sparta/specializer/kernels/softmax.py @@ -8,7 +8,7 @@ import jinja2 import numpy as np -from sparta.common.tuning import TunableItemCfg +from sparta.tuning import TunableItemCfg from sparta.specializer.kernels import templates from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference diff --git a/sparta/testing/utils.py b/sparta/testing/utils.py index 474eae26..4e328597 100644 --- a/sparta/testing/utils.py +++ b/sparta/testing/utils.py @@ -29,31 +29,28 @@ def profile( """ if target_outputs is not None: check(func, inputs, target_outputs) - try: - torch.cuda.synchronize() - for _ in range(num_warmups): - func(*inputs) - torch.cuda.synchronize() - if cuda: - with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p: - for _ in range(num_iters): - func(*inputs) - latency = 0 - for event in p.key_averages(): - if event.key != 'cudaDeviceSynchronize': - latency += event.cuda_time * event.count - latency /= num_iters * 1000 - else: - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - start.record() + torch.cuda.synchronize() + for _ in range(num_warmups): + func(*inputs) + torch.cuda.synchronize() + if cuda: + with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as p: for _ in range(num_iters): func(*inputs) - end.record() - 
torch.cuda.synchronize() - latency = start.elapsed_time(end) / num_iters - except: - latency = float('inf') + latency = 0 + for event in p.key_averages(): + if event.key != 'cudaDeviceSynchronize': + latency += event.cuda_time * event.count + latency /= num_iters * 1000 + else: + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + start.record() + for _ in range(num_iters): + func(*inputs) + end.record() + torch.cuda.synchronize() + latency = start.elapsed_time(end) / num_iters return latency diff --git a/sparta/tuning/__init__.py b/sparta/tuning/__init__.py new file mode 100644 index 00000000..cd761772 --- /dev/null +++ b/sparta/tuning/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from sparta.tuning.tunable import Tunable, TunableItemCfg +from sparta.tuning.tuners import Tuner, GridSearchTuner, RandomSearchTuner diff --git a/sparta/common/tuning.py b/sparta/tuning/tunable.py similarity index 94% rename from sparta/common/tuning.py rename to sparta/tuning/tunable.py index 53d4b51a..7dade30b 100644 --- a/sparta/common/tuning.py +++ b/sparta/tuning/tunable.py @@ -4,8 +4,6 @@ from dataclasses import dataclass from typing import Optional, Any, Union, Dict, List -from sparta.common.utils import check_type, get_uname - @dataclass class TunableItemCfg: @@ -47,13 +45,13 @@ class TunableItemCfg: def __post_init__(self): assert self._type in ['choice'] if self._is_nested: - check_type(self._value, dict) + assert isinstance(self._value, dict) for ss_name, ss_params in self._value.items(): - check_type(ss_params, dict) + assert isinstance(ss_params, dict) for p_name, p_cfg in ss_params.items(): - check_type(p_cfg, TunableItemCfg) + assert isinstance(p_cfg, TunableItemCfg) else: - check_type(self._value, list) + assert isinstance(self._value, list) def to_nni_search_space(self): """convert to nni search space""" diff --git a/sparta/tuning/tuners.py b/sparta/tuning/tuners.py new file mode 100644 index 00000000..edabe637 --- /dev/null +++ b/sparta/tuning/tuners.py @@ -0,0 +1,81 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import abc +import sys +import random +import logging +from typing import Any, Dict, Callable, Iterator + +import numpy as np + +from sparta.tuning.tunable import TunableItemCfg + + +_logger = logging.Logger(__name__) +_handler = logging.StreamHandler() +_logger.addHandler(_handler) + + +class Tuner(object): + + def __init__( + self, + search_space: Dict[Any, TunableItemCfg], + eval_func: Callable[[int, Dict[Any, Any]], float], + max_trials: int = sys.maxsize, + ): + self._search_space = search_space + self._eval_func = eval_func + space_shape = [len(param_space._value) for param_space in search_space.values()] + space_size = int(np.prod(space_shape)) + self._max_trials = min(max_trials, space_size) + self.best_result = np.inf + self.best_config = None + + @abc.abstractmethod + def next_config(self) -> Iterator[Dict[str, Any]]: + """Yields the next config.""" + + def tune(self): + for i, config in zip(range(self._max_trials), self.next_config()): + result = self._eval_func(i, config) + if result < self.best_result: + self.best_result = result + self.best_config = config + + +class RandomSearchTuner(Tuner): + + def next_config(self): + while True: + yield { + param_name: random.choice(param_space._value) + for param_name, param_space in self._search_space.items() + } + + +class GridSearchTuner(Tuner): + + def next_config(self): + if len(self._search_space) == 0: + yield {} + else: + param_names = [] + param_idxs = [] + param_space_sizes = [] + for param_name, param_space in self._search_space.items(): + param_names.append(param_name) + param_space_sizes.append(len(param_space._value)) + param_idxs.append(0) + while param_idxs[0] < param_space_sizes[0]: + yield { + param_name: self._search_space[param_name]._value[param_idx] + for param_idx, param_name in zip(param_idxs, param_names) + } + k = len(param_idxs) - 1 + param_idxs[k] += 1 + while param_idxs[k] == param_space_sizes[k] and k > 0: + param_idxs[k - 1] += 1 + param_idxs[k] = 0 + k -= 1 From b718c59587ed24d65e0ccc8c25cde0758978d8d7 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 31 Jan 2023 16:28:50 +0800 Subject: [PATCH 04/28] LUT maker --- .gitignore | 1 + sparta/tuning/tuners.py | 5 - test/lut_maker/sparta_matmul.py | 223 ++++++++++++++++++++++++++++++++ 3 files changed, 224 insertions(+), 5 deletions(-) create mode 100644 test/lut_maker/sparta_matmul.py diff --git a/.gitignore b/.gitignore index c26c720d..c64fdcdd 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,4 @@ _build generated test/bench/*/latency.csv test/bench/*/latency.png +test/lut_maker/*.log.csv diff --git a/sparta/tuning/tuners.py b/sparta/tuning/tuners.py index edabe637..5693c184 100644 --- a/sparta/tuning/tuners.py +++ b/sparta/tuning/tuners.py @@ -12,11 +12,6 @@ from sparta.tuning.tunable import TunableItemCfg -_logger = logging.Logger(__name__) -_handler = logging.StreamHandler() -_logger.addHandler(_handler) - - class Tuner(object): def __init__( diff --git a/test/lut_maker/sparta_matmul.py b/test/lut_maker/sparta_matmul.py new file mode 100644 index 00000000..afee4f4e --- /dev/null +++ b/test/lut_maker/sparta_matmul.py @@ -0,0 +1,223 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +import logging +import itertools +from typing import Dict, Tuple + +import torch +import pandas as pd + +from sparta.specializer.kernels import SparTASparseMatMulKernel +from sparta.tesa import BCSIndexes +from sparta.testing import block_mask, profile + + +SEARCH_SPACE = { + 'mode': ['sdd', 'dsd', 'dds'], + 'trans_A': [False, True], + 'trans_B': [False, True], + 'BM': [8, 16, 32, 64, 128], + 'BK': [8, 16, 32, 64, 128], + 'BN': [8, 16, 32, 64, 128], + 'TM': [2, 4, 8, 16], + 'TK': [2, 4, 8, 16], + 'TN': [2, 4, 8, 16], +} + + +_logger = logging.Logger(__name__) +_handler = logging.StreamHandler() +_logger.addHandler(_handler) + + +def prepare_data( + batch: int = 4, + M: int = 128, + K: int = 256, + N: int = 192, + granularity: Tuple[int, int] = (8, 8), + sparsity: float = 0.9, + mode: str = 'dds', + trans_A: bool = False, + trans_B: bool = False, + biased: bool = False, + requires_grad: bool = False, + random_seed: int = 2022, +): + inputs = ['A', 'B'] + outputs = ['C'] + shapes = { + 'A': (K, M) if trans_A else (M, K), + 'B': (N, K) if trans_B else (K, N), + 'C': (M, N), + } + if biased: + inputs.append('bias') + shapes['bias'] = (N, ) + + torch.manual_seed(random_seed) + data: Dict[str, torch.Tensor] = {} + for x in inputs: + data[f'input_{x}'] = torch.rand(size=(batch, *shapes[x]), device='cuda') + if requires_grad: + for y in outputs: + data[f'input_grad_{y}'] = torch.rand(size=(batch, *shapes[y]), device='cuda') + + sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] + mask = block_mask(shapes[sparse_port], block=granularity, sparsity=sparsity, device='cuda') + add_mask(data, {sparse_port: mask}, sparse_port, 'input') + + if requires_grad: + for x in inputs: + data[f'input_{x}'].requires_grad = True + + input_A = data['input_A'].swapaxes(1, 2) if trans_A else data['input_A'] + input_B = data['input_B'].swapaxes(1, 2) if trans_B else data['input_B'] + data['target_C'] = torch.bmm(input_A, input_B) + if biased: + data['target_C'] += data['input_bias'].unsqueeze(1) + + if requires_grad: + data['target_C'].backward(data['input_grad_C']) + data['target_grad_A'] = data['input_A'].grad + data['input_A'].grad = None + data['target_grad_B'] = data['input_B'].grad + data['input_B'].grad = None + if biased: + data['target_grad_bias'] = data['input_bias'].grad + data['input_bias'].grad = None + + add_mask(data, {sparse_port: mask}, sparse_port, 'target') + + return data, {sparse_port: mask} + + +def add_mask( + data: Dict[str, torch.Tensor], + masks: Dict[str, torch.Tensor], + sparse_port: str, + stage: str, +): + for name, val in data.items(): + if name.startswith(stage) and name.endswith(sparse_port): + val *= masks[sparse_port] + + +def compress_data( + indexes: BCSIndexes, + sparse_port: str, + data: Dict[str, torch.Tensor], + masks: Dict[str, torch.Tensor], +): + for name in data: + if name.endswith(sparse_port): + data[name] = indexes.convert(data[name].detach()) + masks[sparse_port] = indexes.convert(masks[sparse_port].to(torch.float32)).to(torch.uint8) + if sparse_port in ['A', 'B']: + data[f'input_{sparse_port}'].requires_grad = True + + +def check_results(data: Dict[str, torch.Tensor]): + for name, val in data.items(): + if name.startswith('target_'): + torch.testing.assert_close(val, data[name.replace('target', 'output')], msg=name) + + +def test_sparse_matmul_kernel( + mode: str, + trans_A: bool, + trans_B: bool, + BM: int, + BK: int, + BN: int, + TM: int, + TK: int, + TN: int, + biased: bool = False, + compressed: bool = True, + batch: int = 1, + M: int = 4096, + K: 
int = 4096, + N: int = 4096, + granularity: Tuple[int, int] = (1, 1), + sparsity: float = 0, +): + data, masks = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, False) + + try: + kernel = SparTASparseMatMulKernel( + mode=mode, + biased=biased, + transpose_A=trans_A, + transpose_B=trans_B, + compressed=compressed, + ) + + for sparse_port, mask in masks.items(): + kernel.ports[sparse_port].set_mask(mask) + kernel.set_shape(batch, M, K, N) + kernel.compile({ + 'BLOCK_SIZE_M_VALUE': BM, + 'BLOCK_SIZE_K_VALUE': BK, + 'BLOCK_SIZE_N_VALUE': BN, + 'THREAD_SIZE_M_VALUE': TM, + 'THREAD_SIZE_K_VALUE': TK, + 'THREAD_SIZE_N_VALUE': TN, + }) + + sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] + if compressed: + compress_data(kernel.ports[sparse_port].indexes, sparse_port, data, masks) + + inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] + input_data = [data[f'input_{x}'].detach() for x in inputs] + + latency = profile(kernel, input_data, num_warmups=10, num_iters=10, cuda=False) + torch.cuda.synchronize() + except: + latency = float('inf') + + return latency + + +if __name__ == '__main__': + _logger.setLevel(logging.DEBUG) + major, minor = torch.cuda.get_device_capability() + lut_file = os.path.join( + 'sparta', + 'specializer', + 'kernels', + 'look_up_tables', + f'matmul.sparta.{major}{minor}.csv' + ) + log_file = os.path.join( + 'test', + 'lut_maker', + f'matmul.sparta.{major}{minor}.log.csv' + ) + _logger.info(f'========== Making LUT: {lut_file} ==========') + + num = 1 + keys, values = [], [] + for k, v in SEARCH_SPACE.items(): + keys.append(k) + values.append(v) + num *= len(v) + + with open(log_file, 'w') as f: + f.write(','.join(keys) + ',latency\n') + + for i, params in enumerate(itertools.product(*values)): + latency = test_sparse_matmul_kernel(**{k: v for k, v in zip(keys, params)}) + with open(log_file, 'a') as f: + f.write(','.join([str(x) for x in params]) + f',{latency}\n') + _logger.info(f'[{i} / {num}] {params} => {latency} ms') + + df = pd.read_csv(log_file) + df = df.groupby(['mode', 'trans_A', 'trans_B', 'BM', 'BK', 'BN']).min('latency') + with open(lut_file, 'w') as f: + f.write(df.reset_index().to_csv(index=False)) + + _logger.info(f'========== Finished. 
Output: {lut_file} ==========') From 67a981ea00527fac1987ed48c22fa864e6f399b9 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Thu, 2 Feb 2023 15:52:58 +0800 Subject: [PATCH 05/28] fix matmul parameter checker --- sparta/specializer/kernels/matmul.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 2f3873e6..05e0bbdc 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -306,7 +306,7 @@ def _check_parameters(self, params: Dict[str, Any]): assert TK & (TK - 1) == 0 assert TN & (TN - 1) == 0 A_threads_per_row = (BM if self._transpose_A else BK) // 4 - B_threads_per_row = (BK if self._transpose_A else BN) // 4 + B_threads_per_row = (BK if self._transpose_B else BN) // 4 threads_per_block = (BM // TM) * (BN // TN) assert threads_per_block >= A_threads_per_row assert threads_per_block >= B_threads_per_row From bb91310729705a08160eb67a04bf4e005de88e46 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Thu, 2 Feb 2023 17:28:01 +0800 Subject: [PATCH 06/28] update LUT maker --- test/lut_maker/sparta_matmul.py | 152 +++++--------------------------- 1 file changed, 24 insertions(+), 128 deletions(-) diff --git a/test/lut_maker/sparta_matmul.py b/test/lut_maker/sparta_matmul.py index afee4f4e..c8796415 100644 --- a/test/lut_maker/sparta_matmul.py +++ b/test/lut_maker/sparta_matmul.py @@ -4,16 +4,16 @@ import os import logging import itertools -from typing import Dict, Tuple import torch import pandas as pd from sparta.specializer.kernels import SparTASparseMatMulKernel -from sparta.tesa import BCSIndexes -from sparta.testing import block_mask, profile +from sparta.testing import block_mask +SIZE = 4096 +RANDOM_SEED = 2022 SEARCH_SPACE = { 'mode': ['sdd', 'dsd', 'dds'], 'trans_A': [False, True], @@ -32,100 +32,10 @@ _logger.addHandler(_handler) -def prepare_data( - batch: int = 4, - M: int = 128, - K: int = 256, - N: int = 192, - granularity: Tuple[int, int] = (8, 8), - sparsity: float = 0.9, - mode: str = 'dds', - trans_A: bool = False, - trans_B: bool = False, - biased: bool = False, - requires_grad: bool = False, - random_seed: int = 2022, -): - inputs = ['A', 'B'] - outputs = ['C'] - shapes = { - 'A': (K, M) if trans_A else (M, K), - 'B': (N, K) if trans_B else (K, N), - 'C': (M, N), - } - if biased: - inputs.append('bias') - shapes['bias'] = (N, ) - - torch.manual_seed(random_seed) - data: Dict[str, torch.Tensor] = {} - for x in inputs: - data[f'input_{x}'] = torch.rand(size=(batch, *shapes[x]), device='cuda') - if requires_grad: - for y in outputs: - data[f'input_grad_{y}'] = torch.rand(size=(batch, *shapes[y]), device='cuda') - - sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - mask = block_mask(shapes[sparse_port], block=granularity, sparsity=sparsity, device='cuda') - add_mask(data, {sparse_port: mask}, sparse_port, 'input') - - if requires_grad: - for x in inputs: - data[f'input_{x}'].requires_grad = True - - input_A = data['input_A'].swapaxes(1, 2) if trans_A else data['input_A'] - input_B = data['input_B'].swapaxes(1, 2) if trans_B else data['input_B'] - data['target_C'] = torch.bmm(input_A, input_B) - if biased: - data['target_C'] += data['input_bias'].unsqueeze(1) - - if requires_grad: - data['target_C'].backward(data['input_grad_C']) - data['target_grad_A'] = data['input_A'].grad - data['input_A'].grad = None - data['target_grad_B'] = data['input_B'].grad - data['input_B'].grad = None - if biased: - data['target_grad_bias'] = 
data['input_bias'].grad - data['input_bias'].grad = None - - add_mask(data, {sparse_port: mask}, sparse_port, 'target') - - return data, {sparse_port: mask} - - -def add_mask( - data: Dict[str, torch.Tensor], - masks: Dict[str, torch.Tensor], - sparse_port: str, - stage: str, -): - for name, val in data.items(): - if name.startswith(stage) and name.endswith(sparse_port): - val *= masks[sparse_port] - - -def compress_data( - indexes: BCSIndexes, - sparse_port: str, - data: Dict[str, torch.Tensor], - masks: Dict[str, torch.Tensor], -): - for name in data: - if name.endswith(sparse_port): - data[name] = indexes.convert(data[name].detach()) - masks[sparse_port] = indexes.convert(masks[sparse_port].to(torch.float32)).to(torch.uint8) - if sparse_port in ['A', 'B']: - data[f'input_{sparse_port}'].requires_grad = True - - -def check_results(data: Dict[str, torch.Tensor]): - for name, val in data.items(): - if name.startswith('target_'): - torch.testing.assert_close(val, data[name.replace('target', 'output')], msg=name) - - -def test_sparse_matmul_kernel( +def test_sparta_matmul_kernel( + A: torch.Tensor, + B: torch.Tensor, + mask: torch.Tensor, mode: str, trans_A: bool, trans_B: bool, @@ -135,29 +45,19 @@ def test_sparse_matmul_kernel( TM: int, TK: int, TN: int, - biased: bool = False, - compressed: bool = True, - batch: int = 1, - M: int = 4096, - K: int = 4096, - N: int = 4096, - granularity: Tuple[int, int] = (1, 1), - sparsity: float = 0, ): - data, masks = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, False) + sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] + kernel = SparTASparseMatMulKernel( + mode=mode, + biased=False, + transpose_A=trans_A, + transpose_B=trans_B, + compressed=True, + ) try: - kernel = SparTASparseMatMulKernel( - mode=mode, - biased=biased, - transpose_A=trans_A, - transpose_B=trans_B, - compressed=compressed, - ) - - for sparse_port, mask in masks.items(): - kernel.ports[sparse_port].set_mask(mask) - kernel.set_shape(batch, M, K, N) + kernel.ports[sparse_port].set_mask(mask) + kernel.set_shape(1, SIZE, SIZE, SIZE) kernel.compile({ 'BLOCK_SIZE_M_VALUE': BM, 'BLOCK_SIZE_K_VALUE': BK, @@ -166,16 +66,7 @@ def test_sparse_matmul_kernel( 'THREAD_SIZE_K_VALUE': TK, 'THREAD_SIZE_N_VALUE': TN, }) - - sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - if compressed: - compress_data(kernel.ports[sparse_port].indexes, sparse_port, data, masks) - - inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] - input_data = [data[f'input_{x}'].detach() for x in inputs] - - latency = profile(kernel, input_data, num_warmups=10, num_iters=10, cuda=False) - torch.cuda.synchronize() + latency = kernel.test([A, B], num_warmups=10, num_iters=10, cuda=False) except: latency = float('inf') @@ -209,8 +100,13 @@ def test_sparse_matmul_kernel( with open(log_file, 'w') as f: f.write(','.join(keys) + ',latency\n') + torch.manual_seed(2022) + A = torch.rand(size=(1, SIZE, SIZE), device='cuda') + B = torch.rand(size=(1, SIZE, SIZE), device='cuda') + mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') + for i, params in enumerate(itertools.product(*values)): - latency = test_sparse_matmul_kernel(**{k: v for k, v in zip(keys, params)}) + latency = test_sparta_matmul_kernel(A, B, mask, **{k: v for k, v in zip(keys, params)}) with open(log_file, 'a') as f: f.write(','.join([str(x) for x in params]) + f',{latency}\n') _logger.info(f'[{i} / {num}] {params} => {latency} ms') From 624f23460d2221430d3d327b2a44673eed2bf601 Mon Sep 17 00:00:00 2001 From: 
Chengruidong Zhang Date: Thu, 2 Feb 2023 17:55:28 +0800 Subject: [PATCH 07/28] fix LUT maker: aggregate log by idxmin --- test/lut_maker/sparta_matmul.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/test/lut_maker/sparta_matmul.py b/test/lut_maker/sparta_matmul.py index c8796415..47c35235 100644 --- a/test/lut_maker/sparta_matmul.py +++ b/test/lut_maker/sparta_matmul.py @@ -25,6 +25,7 @@ 'TK': [2, 4, 8, 16], 'TN': [2, 4, 8, 16], } +HYPER_PARAMS = ['mode', 'trans_A', 'trans_B', 'BM', 'BK', 'BN'] _logger = logging.Logger(__name__) @@ -100,7 +101,7 @@ def test_sparta_matmul_kernel( with open(log_file, 'w') as f: f.write(','.join(keys) + ',latency\n') - torch.manual_seed(2022) + torch.manual_seed(RANDOM_SEED) A = torch.rand(size=(1, SIZE, SIZE), device='cuda') B = torch.rand(size=(1, SIZE, SIZE), device='cuda') mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') @@ -112,8 +113,8 @@ def test_sparta_matmul_kernel( _logger.info(f'[{i} / {num}] {params} => {latency} ms') df = pd.read_csv(log_file) - df = df.groupby(['mode', 'trans_A', 'trans_B', 'BM', 'BK', 'BN']).min('latency') + df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})] with open(lut_file, 'w') as f: - f.write(df.reset_index().to_csv(index=False)) + f.write(df.reset_index(drop=True).to_csv(index=False)) _logger.info(f'========== Finished. Output: {lut_file} ==========') From d6e63866f0dcbff6b9446e0ed24925aeea296cce Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Mon, 6 Feb 2023 12:19:02 +0800 Subject: [PATCH 08/28] fix sparse softmax & BCSR kernel; add 61 LUTs --- .../look_up_tables/matmul.sparta.61.csv | 1501 +++++++++++++++++ .../softmax.backward.sparta.61.csv | 26 + .../softmax.forward.sparta.61.csv | 26 + sparta/specializer/kernels/matmul.py | 13 +- sparta/specializer/kernels/softmax.py | 63 +- .../sparta_sparse_softmax_backward.cuh.j2 | 18 +- .../sparta_sparse_softmax_forward.cuh.j2 | 20 +- sparta/tesa/block_compressed.py | 21 +- sparta/tesa/templates/block_compressed.cu.j2 | 12 +- test/lut_maker/sparta_matmul.py | 10 +- test/lut_maker/sparta_softmax.py | 122 ++ test/unit/test_tesa.py | 9 +- test/unit/test_tuners.py | 16 +- 13 files changed, 1796 insertions(+), 61 deletions(-) create mode 100644 sparta/specializer/kernels/look_up_tables/matmul.sparta.61.csv create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.61.csv create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.61.csv create mode 100644 test/lut_maker/sparta_softmax.py diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.61.csv b/sparta/specializer/kernels/look_up_tables/matmul.sparta.61.csv new file mode 100644 index 00000000..5c542710 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/matmul.sparta.61.csv @@ -0,0 +1,1501 @@ +mode,trans_A,trans_B,BM,BK,BN,TM,TK,TN,latency +dds,False,False,8,8,8,2,2,2,346.447265625 +dds,False,False,8,8,16,4,8,2,185.52266845703124 +dds,False,False,8,8,32,2,2,8,136.56832275390624 +dds,False,False,8,8,64,2,2,16,130.3582763671875 +dds,False,False,8,8,128,2,2,2,inf +dds,False,False,8,16,8,2,4,2,325.844775390625 +dds,False,False,8,16,16,2,4,2,176.13494873046875 +dds,False,False,8,16,32,2,8,4,134.83704833984376 +dds,False,False,8,16,64,4,8,4,136.79031982421876 +dds,False,False,8,16,128,4,2,8,135.889501953125 +dds,False,False,8,32,8,2,16,2,317.956201171875 +dds,False,False,8,32,16,4,2,2,174.03084716796874 +dds,False,False,8,32,32,4,8,2,140.1439208984375 
+dds,False,False,8,32,64,4,2,2,140.1817138671875 +dds,False,False,8,32,128,4,2,8,132.7256591796875 +dds,False,False,8,64,8,2,16,2,316.356591796875 +dds,False,False,8,64,16,2,16,2,176.83394775390624 +dds,False,False,8,64,32,2,16,4,140.68807373046874 +dds,False,False,8,64,64,2,4,2,140.6648193359375 +dds,False,False,8,64,128,4,2,2,140.1185302734375 +dds,False,False,8,128,8,2,2,2,inf +dds,False,False,8,128,16,2,2,2,185.685791015625 +dds,False,False,8,128,32,2,4,2,141.045654296875 +dds,False,False,8,128,64,2,4,2,142.7243896484375 +dds,False,False,8,128,128,2,2,2,inf +dds,False,False,16,8,8,4,8,2,180.28349609375 +dds,False,False,16,8,16,4,4,2,90.68286743164064 +dds,False,False,16,8,32,2,8,8,63.76509399414063 +dds,False,False,16,8,64,2,2,16,57.508349609375 +dds,False,False,16,8,128,4,4,16,66.71492919921874 +dds,False,False,16,16,8,2,8,2,164.16531982421876 +dds,False,False,16,16,16,2,2,4,90.65840454101564 +dds,False,False,16,16,32,2,4,8,65.7502197265625 +dds,False,False,16,16,64,4,2,8,68.99260864257812 +dds,False,False,16,16,128,16,16,2,68.54666137695312 +dds,False,False,16,32,8,2,8,4,157.861474609375 +dds,False,False,16,32,16,2,8,2,86.76648559570313 +dds,False,False,16,32,32,2,16,4,70.3776611328125 +dds,False,False,16,32,64,4,4,2,70.49021606445312 +dds,False,False,16,32,128,8,4,2,71.19052734375 +dds,False,False,16,64,8,2,16,2,157.70018310546874 +dds,False,False,16,64,16,4,4,2,89.58597412109376 +dds,False,False,16,64,32,2,4,2,71.65633544921874 +dds,False,False,16,64,64,2,16,2,72.54589233398437 +dds,False,False,16,64,128,4,16,4,71.78956909179688 +dds,False,False,16,128,8,2,2,2,130.49764404296874 +dds,False,False,16,128,16,2,4,2,92.35322875976564 +dds,False,False,16,128,32,2,4,2,72.428955078125 +dds,False,False,16,128,64,2,2,2,71.17189331054688 +dds,False,False,16,128,128,2,2,2,inf +dds,False,False,32,8,8,8,2,2,101.254248046875 +dds,False,False,32,8,16,8,4,2,53.191357421875 +dds,False,False,32,8,32,16,4,2,33.65990295410156 +dds,False,False,32,8,64,4,8,16,33.06301574707031 +dds,False,False,32,8,128,2,2,2,inf +dds,False,False,32,16,8,2,16,4,91.6569091796875 +dds,False,False,32,16,16,8,4,2,49.24136047363281 +dds,False,False,32,16,32,4,8,8,34.12129211425781 +dds,False,False,32,16,64,4,8,4,34.86095275878906 +dds,False,False,32,16,128,2,2,2,inf +dds,False,False,32,32,8,4,4,2,92.7152099609375 +dds,False,False,32,32,16,8,16,2,48.76431274414063 +dds,False,False,32,32,32,8,2,2,37.743002319335936 +dds,False,False,32,32,64,8,4,2,35.66120910644531 +dds,False,False,32,32,128,2,2,2,inf +dds,False,False,32,64,8,4,4,2,93.60137939453124 +dds,False,False,32,64,16,4,16,2,52.41231079101563 +dds,False,False,32,64,32,8,4,2,38.35975646972656 +dds,False,False,32,64,64,4,8,2,37.809765625 +dds,False,False,32,64,128,2,2,2,inf +dds,False,False,32,128,8,2,16,2,110.852392578125 +dds,False,False,32,128,16,4,16,2,62.65753784179688 +dds,False,False,32,128,32,4,8,2,41.68714294433594 +dds,False,False,32,128,64,4,4,2,38.31879577636719 +dds,False,False,32,128,128,2,2,2,inf +dds,False,False,64,8,8,16,2,2,66.14743041992188 +dds,False,False,64,8,16,16,8,2,32.39331970214844 +dds,False,False,64,8,32,8,8,8,23.92657928466797 +dds,False,False,64,8,64,2,2,2,inf +dds,False,False,64,8,128,2,2,2,inf +dds,False,False,64,16,8,8,4,2,59.19006958007813 +dds,False,False,64,16,16,16,2,2,30.32555541992188 +dds,False,False,64,16,32,8,4,4,24.202114868164063 +dds,False,False,64,16,64,2,2,2,inf +dds,False,False,64,16,128,2,2,2,inf +dds,False,False,64,32,8,8,16,2,76.00281372070313 +dds,False,False,64,32,16,8,8,2,43.748556518554686 
+dds,False,False,64,32,32,16,4,2,28.074496459960937 +dds,False,False,64,32,64,2,2,2,inf +dds,False,False,64,32,128,2,2,2,inf +dds,False,False,64,64,8,4,2,2,84.72975463867188 +dds,False,False,64,64,16,4,16,2,49.43759460449219 +dds,False,False,64,64,32,8,8,2,30.441778564453124 +dds,False,False,64,64,64,2,2,2,inf +dds,False,False,64,64,128,2,2,2,inf +dds,False,False,64,128,8,2,2,2,130.02669677734374 +dds,False,False,64,128,16,4,16,2,60.84157104492188 +dds,False,False,64,128,32,4,16,2,38.609613037109376 +dds,False,False,64,128,64,2,2,2,inf +dds,False,False,64,128,128,2,2,2,inf +dds,False,False,128,8,8,16,4,4,62.37511596679688 +dds,False,False,128,8,16,16,2,4,29.57547607421875 +dds,False,False,128,8,32,2,2,2,inf +dds,False,False,128,8,64,2,2,2,inf +dds,False,False,128,8,128,2,2,2,inf +dds,False,False,128,16,8,16,8,2,57.271295166015626 +dds,False,False,128,16,16,16,4,2,28.65172424316406 +dds,False,False,128,16,32,2,2,2,inf +dds,False,False,128,16,64,2,2,2,inf +dds,False,False,128,16,128,2,2,2,inf +dds,False,False,128,32,8,8,16,2,78.53148193359375 +dds,False,False,128,32,16,8,8,2,42.23107299804688 +dds,False,False,128,32,32,2,2,2,inf +dds,False,False,128,32,64,2,2,2,inf +dds,False,False,128,32,128,2,2,2,inf +dds,False,False,128,64,8,4,2,2,99.802197265625 +dds,False,False,128,64,16,8,8,2,52.353125 +dds,False,False,128,64,32,2,2,2,inf +dds,False,False,128,64,64,2,2,2,inf +dds,False,False,128,64,128,2,2,2,inf +dds,False,False,128,128,8,2,2,2,inf +dds,False,False,128,128,16,2,2,2,inf +dds,False,False,128,128,32,2,2,2,inf +dds,False,False,128,128,64,2,2,2,inf +dds,False,False,128,128,128,2,2,2,inf +dds,False,True,8,8,8,2,2,2,397.6466552734375 +dds,False,True,8,8,16,2,4,8,267.2733154296875 +dds,False,True,8,8,32,2,4,16,297.122412109375 +dds,False,True,8,8,64,8,2,8,310.2191650390625 +dds,False,True,8,8,128,8,4,8,113.92952880859374 +dds,False,True,8,16,8,2,2,2,205.70625 +dds,False,True,8,16,16,2,16,4,155.09473876953126 +dds,False,True,8,16,32,8,4,2,120.915966796875 +dds,False,True,8,16,64,8,8,4,187.51968994140623 +dds,False,True,8,16,128,4,4,16,205.458935546875 +dds,False,True,8,32,8,2,8,2,263.9460205078125 +dds,False,True,8,32,16,2,2,2,139.62332763671876 +dds,False,True,8,32,32,4,16,2,140.8869384765625 +dds,False,True,8,32,64,4,4,2,112.25426025390624 +dds,False,True,8,32,128,8,4,2,140.99168701171874 +dds,False,True,8,64,8,2,4,2,378.4112060546875 +dds,False,True,8,64,16,2,16,2,185.98646240234376 +dds,False,True,8,64,32,4,16,2,157.18707275390625 +dds,False,True,8,64,64,4,16,2,155.3437744140625 +dds,False,True,8,64,128,4,2,2,157.50615234375 +dds,False,True,8,128,8,2,2,2,inf +dds,False,True,8,128,16,2,16,2,308.24130859375 +dds,False,True,8,128,32,2,2,2,287.0211669921875 +dds,False,True,8,128,64,2,2,2,282.979833984375 +dds,False,True,8,128,128,2,2,2,inf +dds,False,True,16,8,8,4,2,2,200.21104736328124 +dds,False,True,16,8,16,2,8,4,196.3701293945313 +dds,False,True,16,8,32,2,8,16,199.3069580078125 +dds,False,True,16,8,64,2,8,16,146.78487548828124 +dds,False,True,16,8,128,16,8,8,124.33868408203124 +dds,False,True,16,16,8,2,8,2,102.47249755859374 +dds,False,True,16,16,16,4,8,4,101.76768188476562 +dds,False,True,16,16,32,16,16,2,95.5430908203125 +dds,False,True,16,16,64,4,2,16,101.10484619140624 +dds,False,True,16,16,128,4,4,8,76.63267822265625 +dds,False,True,16,32,8,2,2,2,122.96375732421876 +dds,False,True,16,32,16,4,2,2,81.21199340820313 +dds,False,True,16,32,32,8,16,2,62.187310791015626 +dds,False,True,16,32,64,4,16,2,69.6859619140625 +dds,False,True,16,32,128,4,4,4,57.88958740234375 
+dds,False,True,16,64,8,2,2,2,164.5559814453125 +dds,False,True,16,64,16,4,8,2,115.05264892578126 +dds,False,True,16,64,32,8,16,2,95.08392944335938 +dds,False,True,16,64,64,8,2,2,90.81487426757812 +dds,False,True,16,64,128,4,4,4,89.23197631835937 +dds,False,True,16,128,8,2,8,2,238.734326171875 +dds,False,True,16,128,16,4,2,2,186.3947265625 +dds,False,True,16,128,32,4,2,2,157.79359130859376 +dds,False,True,16,128,64,4,2,2,152.9038818359375 +dds,False,True,16,128,128,2,2,2,inf +dds,False,True,32,8,8,4,2,4,103.8098388671875 +dds,False,True,32,8,16,2,8,16,87.20722045898438 +dds,False,True,32,8,32,2,2,16,46.59742736816406 +dds,False,True,32,8,64,16,8,4,96.6781982421875 +dds,False,True,32,8,128,2,2,2,inf +dds,False,True,32,16,8,4,2,2,63.67701416015625 +dds,False,True,32,16,16,8,8,2,51.837823486328126 +dds,False,True,32,16,32,2,8,16,49.586892700195314 +dds,False,True,32,16,64,2,2,8,47.16410827636719 +dds,False,True,32,16,128,2,2,2,inf +dds,False,True,32,32,8,4,2,2,94.58953247070312 +dds,False,True,32,32,16,8,2,2,59.1730712890625 +dds,False,True,32,32,32,16,8,2,44.52433776855469 +dds,False,True,32,32,64,8,2,4,38.19192199707031 +dds,False,True,32,32,128,2,2,2,inf +dds,False,True,32,64,8,4,4,2,116.1966552734375 +dds,False,True,32,64,16,8,8,2,78.77529296875 +dds,False,True,32,64,32,8,16,2,60.08504028320313 +dds,False,True,32,64,64,4,16,4,56.33953857421875 +dds,False,True,32,64,128,2,2,2,inf +dds,False,True,32,128,8,2,2,2,180.709375 +dds,False,True,32,128,16,4,8,2,118.262158203125 +dds,False,True,32,128,32,4,16,2,97.29157104492188 +dds,False,True,32,128,64,4,8,4,88.31262817382813 +dds,False,True,32,128,128,2,2,2,inf +dds,False,True,64,8,8,8,2,4,65.49246215820312 +dds,False,True,64,8,16,2,4,16,48.14632873535156 +dds,False,True,64,8,32,2,4,16,52.657666015625 +dds,False,True,64,8,64,2,2,2,inf +dds,False,True,64,8,128,2,2,2,inf +dds,False,True,64,16,8,8,8,2,51.87501831054688 +dds,False,True,64,16,16,16,4,2,32.64102478027344 +dds,False,True,64,16,32,16,16,2,29.38050537109375 +dds,False,True,64,16,64,2,2,2,inf +dds,False,True,64,16,128,2,2,2,inf +dds,False,True,64,32,8,8,4,2,84.91397094726562 +dds,False,True,64,32,16,16,16,2,49.67843933105469 +dds,False,True,64,32,32,8,16,4,35.14662475585938 +dds,False,True,64,32,64,2,2,2,inf +dds,False,True,64,32,128,2,2,2,inf +dds,False,True,64,64,8,4,16,2,99.67421264648438 +dds,False,True,64,64,16,8,2,2,65.002392578125 +dds,False,True,64,64,32,16,4,2,45.50236206054687 +dds,False,True,64,64,64,2,2,2,inf +dds,False,True,64,64,128,2,2,2,inf +dds,False,True,64,128,8,4,8,2,164.678125 +dds,False,True,64,128,16,4,8,2,91.44155883789062 +dds,False,True,64,128,32,8,8,2,62.4695068359375 +dds,False,True,64,128,64,2,2,2,inf +dds,False,True,64,128,128,2,2,2,inf +dds,False,True,128,8,8,16,2,4,62.599884033203125 +dds,False,True,128,8,16,16,8,4,32.633242797851565 +dds,False,True,128,8,32,2,2,2,inf +dds,False,True,128,8,64,2,2,2,inf +dds,False,True,128,8,128,2,2,2,inf +dds,False,True,128,16,8,16,2,2,47.714697265625 +dds,False,True,128,16,16,16,8,2,29.39852905273437 +dds,False,True,128,16,32,2,2,2,inf +dds,False,True,128,16,64,2,2,2,inf +dds,False,True,128,16,128,2,2,2,inf +dds,False,True,128,32,8,8,8,2,84.24837036132813 +dds,False,True,128,32,16,16,4,2,45.77205810546875 +dds,False,True,128,32,32,2,2,2,inf +dds,False,True,128,32,64,2,2,2,inf +dds,False,True,128,32,128,2,2,2,inf +dds,False,True,128,64,8,4,16,4,112.75662841796876 +dds,False,True,128,64,16,8,16,2,60.74285888671875 +dds,False,True,128,64,32,2,2,2,inf +dds,False,True,128,64,64,2,2,2,inf +dds,False,True,128,64,128,2,2,2,inf 
+dds,False,True,128,128,8,2,2,2,inf +dds,False,True,128,128,16,2,2,2,inf +dds,False,True,128,128,32,2,2,2,inf +dds,False,True,128,128,64,2,2,2,inf +dds,False,True,128,128,128,2,2,2,inf +dds,True,False,8,8,8,2,2,2,364.7280029296875 +dds,True,False,8,8,16,4,8,4,196.1740234375 +dds,True,False,8,8,32,2,4,8,137.6901123046875 +dds,True,False,8,8,64,2,2,16,131.92947998046876 +dds,True,False,8,8,128,2,2,2,inf +dds,True,False,8,16,8,2,8,2,351.3238525390625 +dds,True,False,8,16,16,2,4,2,193.1760620117188 +dds,True,False,8,16,32,2,8,8,136.7863037109375 +dds,True,False,8,16,64,8,4,2,139.0878662109375 +dds,True,False,8,16,128,4,16,8,138.1424072265625 +dds,True,False,8,32,8,2,2,2,349.33310546875 +dds,True,False,8,32,16,4,2,2,189.50655517578124 +dds,True,False,8,32,32,2,4,4,140.9553466796875 +dds,True,False,8,32,64,4,2,2,140.61055908203124 +dds,True,False,8,32,128,8,4,4,135.57197265625 +dds,True,False,8,64,8,2,16,2,352.42158203125 +dds,True,False,8,64,16,2,8,2,191.4159912109375 +dds,True,False,8,64,32,2,16,2,141.60804443359376 +dds,True,False,8,64,64,2,8,2,140.1777099609375 +dds,True,False,8,64,128,4,8,2,138.94029541015624 +dds,True,False,8,128,8,2,16,2,345.75576171875 +dds,True,False,8,128,16,2,4,2,195.1240234375 +dds,True,False,8,128,32,2,2,2,141.89752197265625 +dds,True,False,8,128,64,2,16,2,138.83187255859374 +dds,True,False,8,128,128,2,2,2,inf +dds,True,False,16,8,8,4,4,2,184.3178466796875 +dds,True,False,16,8,16,2,2,4,97.78770141601562 +dds,True,False,16,8,32,2,2,8,66.57802124023438 +dds,True,False,16,8,64,2,8,16,59.66036987304688 +dds,True,False,16,8,128,4,4,16,67.42958374023438 +dds,True,False,16,16,8,2,8,2,175.861865234375 +dds,True,False,16,16,16,8,4,2,93.49168701171877 +dds,True,False,16,16,32,2,2,8,63.482470703125 +dds,True,False,16,16,64,16,16,2,70.76022338867188 +dds,True,False,16,16,128,16,4,2,69.24093627929688 +dds,True,False,16,32,8,2,8,2,169.58443603515624 +dds,True,False,16,32,16,2,8,4,90.878564453125 +dds,True,False,16,32,32,2,16,4,69.61937255859375 +dds,True,False,16,32,64,2,8,4,70.58042602539062 +dds,True,False,16,32,128,2,16,8,71.03016967773438 +dds,True,False,16,64,8,2,4,2,170.38345947265626 +dds,True,False,16,64,16,4,4,2,92.02206420898438 +dds,True,False,16,64,32,2,16,2,72.21043090820312 +dds,True,False,16,64,64,2,8,2,72.28897094726562 +dds,True,False,16,64,128,2,8,4,70.91937255859375 +dds,True,False,16,128,8,2,2,2,150.3752197265625 +dds,True,False,16,128,16,2,8,2,93.9304931640625 +dds,True,False,16,128,32,2,8,2,72.27095336914063 +dds,True,False,16,128,64,2,2,2,70.86336059570313 +dds,True,False,16,128,128,2,2,2,inf +dds,True,False,32,8,8,8,2,2,171.532666015625 +dds,True,False,32,8,16,8,8,2,60.76968994140625 +dds,True,False,32,8,32,8,4,4,38.84349365234375 +dds,True,False,32,8,64,4,8,16,34.237945556640625 +dds,True,False,32,8,128,2,2,2,inf +dds,True,False,32,16,8,4,16,2,104.82421875 +dds,True,False,32,16,16,8,16,2,58.68594970703125 +dds,True,False,32,16,32,8,4,4,38.49779052734375 +dds,True,False,32,16,64,4,16,16,36.58433532714844 +dds,True,False,32,16,128,2,2,2,inf +dds,True,False,32,32,8,4,8,2,131.81746826171874 +dds,True,False,32,32,16,8,8,2,67.30023803710938 +dds,True,False,32,32,32,8,4,4,42.1317626953125 +dds,True,False,32,32,64,4,4,4,36.757080078125 +dds,True,False,32,32,128,2,2,2,inf +dds,True,False,32,64,8,4,16,2,131.2527099609375 +dds,True,False,32,64,16,8,2,2,69.99695434570313 +dds,True,False,32,64,32,8,8,2,45.15665893554687 +dds,True,False,32,64,64,4,8,4,37.94001770019531 +dds,True,False,32,64,128,2,2,2,inf +dds,True,False,32,128,8,2,2,2,149.5387939453125 
+dds,True,False,32,128,16,4,16,2,77.89833374023438 +dds,True,False,32,128,32,4,8,2,50.59368896484375 +dds,True,False,32,128,64,4,2,2,42.87426452636719 +dds,True,False,32,128,128,2,2,2,inf +dds,True,False,64,8,8,8,8,4,272.1482666015625 +dds,True,False,64,8,16,8,8,4,78.639306640625 +dds,True,False,64,8,32,8,2,8,45.4640625 +dds,True,False,64,8,64,2,2,2,inf +dds,True,False,64,8,128,2,2,2,inf +dds,True,False,64,16,8,8,2,2,152.18533935546876 +dds,True,False,64,16,16,16,4,2,79.01703491210938 +dds,True,False,64,16,32,8,2,8,48.4632568359375 +dds,True,False,64,16,64,2,2,2,inf +dds,True,False,64,16,128,2,2,2,inf +dds,True,False,64,32,8,8,8,2,180.0457275390625 +dds,True,False,64,32,16,16,16,2,92.72125244140624 +dds,True,False,64,32,32,16,4,2,51.453643798828125 +dds,True,False,64,32,64,2,2,2,inf +dds,True,False,64,32,128,2,2,2,inf +dds,True,False,64,64,8,8,8,2,184.9154541015625 +dds,True,False,64,64,16,8,4,2,97.60112915039062 +dds,True,False,64,64,32,8,4,2,54.843597412109375 +dds,True,False,64,64,64,2,2,2,inf +dds,True,False,64,64,128,2,2,2,inf +dds,True,False,64,128,8,4,4,2,198.31817626953125 +dds,True,False,64,128,16,4,2,2,104.98519287109374 +dds,True,False,64,128,32,8,2,2,57.97058715820312 +dds,True,False,64,128,64,2,2,2,inf +dds,True,False,64,128,128,2,2,2,inf +dds,True,False,128,8,8,2,2,2,inf +dds,True,False,128,8,16,8,2,8,132.52567138671876 +dds,True,False,128,8,32,2,2,2,inf +dds,True,False,128,8,64,2,2,2,inf +dds,True,False,128,8,128,2,2,2,inf +dds,True,False,128,16,8,16,8,2,264.026416015625 +dds,True,False,128,16,16,16,4,2,136.8701904296875 +dds,True,False,128,16,32,2,2,2,inf +dds,True,False,128,16,64,2,2,2,inf +dds,True,False,128,16,128,2,2,2,inf +dds,True,False,128,32,8,8,8,2,294.5623779296875 +dds,True,False,128,32,16,16,4,2,149.01636962890626 +dds,True,False,128,32,32,2,2,2,inf +dds,True,False,128,32,64,2,2,2,inf +dds,True,False,128,32,128,2,2,2,inf +dds,True,False,128,64,8,8,8,2,304.08681640625 +dds,True,False,128,64,16,8,16,2,155.3440673828125 +dds,True,False,128,64,32,2,2,2,inf +dds,True,False,128,64,64,2,2,2,inf +dds,True,False,128,64,128,2,2,2,inf +dds,True,False,128,128,8,2,2,2,inf +dds,True,False,128,128,16,2,2,2,inf +dds,True,False,128,128,32,2,2,2,inf +dds,True,False,128,128,64,2,2,2,inf +dds,True,False,128,128,128,2,2,2,inf +dds,True,True,8,8,8,2,2,2,399.963330078125 +dds,True,True,8,8,16,2,4,8,406.4051025390625 +dds,True,True,8,8,32,2,2,16,314.0001708984375 +dds,True,True,8,8,64,8,2,8,292.6605224609375 +dds,True,True,8,8,128,8,4,8,320.0722900390625 +dds,True,True,8,16,8,2,2,2,208.4537109375 +dds,True,True,8,16,16,2,8,4,209.2707763671875 +dds,True,True,8,16,32,8,4,2,193.5134765625 +dds,True,True,8,16,64,2,8,16,213.47880859375 +dds,True,True,8,16,128,4,8,16,224.5900390625 +dds,True,True,8,32,8,2,16,2,273.5004638671875 +dds,True,True,8,32,16,2,4,2,140.41446533203126 +dds,True,True,8,32,32,4,16,2,141.41317138671874 +dds,True,True,8,32,64,2,16,4,105.32105712890623 +dds,True,True,8,32,128,8,4,2,141.51424560546874 +dds,True,True,8,64,8,2,2,2,386.3058349609375 +dds,True,True,8,64,16,2,2,2,192.13475341796877 +dds,True,True,8,64,32,4,2,2,159.93128662109376 +dds,True,True,8,64,64,4,8,2,156.082666015625 +dds,True,True,8,64,128,4,2,2,159.44642333984376 +dds,True,True,8,128,8,2,2,2,inf +dds,True,True,8,128,16,2,8,2,310.764306640625 +dds,True,True,8,128,32,2,2,2,290.8535888671875 +dds,True,True,8,128,64,2,2,2,286.5373291015625 +dds,True,True,8,128,128,2,2,2,inf +dds,True,True,16,8,8,4,8,2,201.5918090820313 +dds,True,True,16,8,16,2,4,8,201.9950439453125 
+dds,True,True,16,8,32,2,4,16,104.681982421875 +dds,True,True,16,8,64,2,2,16,184.5109619140625 +dds,True,True,16,8,128,16,2,8,167.90589599609376 +dds,True,True,16,16,8,2,2,2,109.98385009765624 +dds,True,True,16,16,16,4,16,2,105.0525634765625 +dds,True,True,16,16,32,2,2,8,75.39658813476562 +dds,True,True,16,16,64,4,2,16,103.42451171875 +dds,True,True,16,16,128,4,4,8,86.31336669921875 +dds,True,True,16,32,8,2,2,2,144.20714111328124 +dds,True,True,16,32,16,4,16,2,89.20360107421875 +dds,True,True,16,32,32,8,2,2,70.08348388671875 +dds,True,True,16,32,64,4,16,2,70.05478515625 +dds,True,True,16,32,128,4,8,4,57.00474853515625 +dds,True,True,16,64,8,2,4,2,175.2057861328125 +dds,True,True,16,64,16,4,4,2,118.56875 +dds,True,True,16,64,32,8,16,2,97.03975830078124 +dds,True,True,16,64,64,4,2,4,92.81841430664062 +dds,True,True,16,64,128,4,2,4,90.65912475585938 +dds,True,True,16,128,8,2,16,2,243.5973876953125 +dds,True,True,16,128,16,4,16,2,182.8135986328125 +dds,True,True,16,128,32,4,2,2,157.23704833984374 +dds,True,True,16,128,64,4,16,2,154.74503173828126 +dds,True,True,16,128,128,2,2,2,inf +dds,True,True,32,8,8,8,8,2,175.79775390625 +dds,True,True,32,8,16,2,2,8,95.48892211914062 +dds,True,True,32,8,32,2,2,16,79.95289306640625 +dds,True,True,32,8,64,4,8,16,95.31821899414062 +dds,True,True,32,8,128,2,2,2,inf +dds,True,True,32,16,8,4,16,2,109.945654296875 +dds,True,True,32,16,16,8,2,2,63.509710693359374 +dds,True,True,32,16,32,4,4,8,46.339776611328126 +dds,True,True,32,16,64,8,2,2,43.942501831054685 +dds,True,True,32,16,128,2,2,2,inf +dds,True,True,32,32,8,4,2,2,144.12000732421876 +dds,True,True,32,32,16,8,16,2,81.15087280273437 +dds,True,True,32,32,32,16,16,2,52.22266845703125 +dds,True,True,32,32,64,8,4,4,41.66102905273438 +dds,True,True,32,32,128,2,2,2,inf +dds,True,True,32,64,8,4,2,2,158.60653076171874 +dds,True,True,32,64,16,8,8,2,97.88969116210936 +dds,True,True,32,64,32,8,8,2,70.4606201171875 +dds,True,True,32,64,64,4,16,4,60.7115234375 +dds,True,True,32,64,128,2,2,2,inf +dds,True,True,32,128,8,4,4,2,197.12603759765625 +dds,True,True,32,128,16,4,4,2,132.664111328125 +dds,True,True,32,128,32,4,2,2,105.73443603515624 +dds,True,True,32,128,64,4,4,4,92.85127563476564 +dds,True,True,32,128,128,2,2,2,inf +dds,True,True,64,8,8,8,2,4,273.481103515625 +dds,True,True,64,8,16,8,8,4,80.97750854492188 +dds,True,True,64,8,32,8,2,4,51.42496337890625 +dds,True,True,64,8,64,2,2,2,inf +dds,True,True,64,8,128,2,2,2,inf +dds,True,True,64,16,8,8,2,2,155.27698974609376 +dds,True,True,64,16,16,16,16,2,81.091162109375 +dds,True,True,64,16,32,8,8,4,50.919525146484375 +dds,True,True,64,16,64,2,2,2,inf +dds,True,True,64,16,128,2,2,2,inf +dds,True,True,64,32,8,8,16,2,186.0886474609375 +dds,True,True,64,32,16,16,8,2,98.43384399414064 +dds,True,True,64,32,32,16,4,2,58.27532958984375 +dds,True,True,64,32,64,2,2,2,inf +dds,True,True,64,32,128,2,2,2,inf +dds,True,True,64,64,8,4,4,2,200.7953369140625 +dds,True,True,64,64,16,8,16,2,112.19527587890624 +dds,True,True,64,64,32,16,16,2,69.02579345703126 +dds,True,True,64,64,64,2,2,2,inf +dds,True,True,64,64,128,2,2,2,inf +dds,True,True,64,128,8,4,4,2,226.6389404296875 +dds,True,True,64,128,16,4,8,2,133.59820556640625 +dds,True,True,64,128,32,8,8,2,85.59329223632812 +dds,True,True,64,128,64,2,2,2,inf +dds,True,True,64,128,128,2,2,2,inf +dds,True,True,128,8,8,2,2,2,inf +dds,True,True,128,8,16,8,2,8,131.95755615234376 +dds,True,True,128,8,32,2,2,2,inf +dds,True,True,128,8,64,2,2,2,inf +dds,True,True,128,8,128,2,2,2,inf +dds,True,True,128,16,8,16,4,2,264.8448974609375 
+dds,True,True,128,16,16,8,16,4,138.47879638671876 +dds,True,True,128,16,32,2,2,2,inf +dds,True,True,128,16,64,2,2,2,inf +dds,True,True,128,16,128,2,2,2,inf +dds,True,True,128,32,8,16,16,2,297.62099609375 +dds,True,True,128,32,16,16,16,2,152.25159912109376 +dds,True,True,128,32,32,2,2,2,inf +dds,True,True,128,32,64,2,2,2,inf +dds,True,True,128,32,128,2,2,2,inf +dds,True,True,128,64,8,8,16,2,310.376025390625 +dds,True,True,128,64,16,8,8,2,161.4609375 +dds,True,True,128,64,32,2,2,2,inf +dds,True,True,128,64,64,2,2,2,inf +dds,True,True,128,64,128,2,2,2,inf +dds,True,True,128,128,8,2,2,2,inf +dds,True,True,128,128,16,2,2,2,inf +dds,True,True,128,128,32,2,2,2,inf +dds,True,True,128,128,64,2,2,2,inf +dds,True,True,128,128,128,2,2,2,inf +dsd,False,False,8,8,8,2,2,2,191.5309936523437 +dsd,False,False,8,8,16,4,4,2,142.9000244140625 +dsd,False,False,8,8,32,2,4,8,139.456103515625 +dsd,False,False,8,8,64,2,2,16,135.39779052734374 +dsd,False,False,8,8,128,2,2,2,inf +dsd,False,False,8,16,8,2,8,2,159.67723388671874 +dsd,False,False,8,16,16,2,2,2,141.12962646484374 +dsd,False,False,8,16,32,2,16,4,137.17186279296874 +dsd,False,False,8,16,64,8,16,4,138.9296630859375 +dsd,False,False,8,16,128,2,2,16,137.25614013671876 +dsd,False,False,8,32,8,2,16,2,165.84366455078126 +dsd,False,False,8,32,16,2,16,2,141.4387451171875 +dsd,False,False,8,32,32,8,8,2,141.15389404296874 +dsd,False,False,8,32,64,2,8,4,140.5824951171875 +dsd,False,False,8,32,128,2,2,2,inf +dsd,False,False,8,64,8,2,2,2,163.56669921875 +dsd,False,False,8,64,16,2,8,2,140.79385986328126 +dsd,False,False,8,64,32,4,4,2,141.07493896484374 +dsd,False,False,8,64,64,2,2,2,inf +dsd,False,False,8,64,128,2,2,2,inf +dsd,False,False,8,128,8,2,2,2,inf +dsd,False,False,8,128,16,2,2,2,140.94683837890625 +dsd,False,False,8,128,32,2,2,2,inf +dsd,False,False,8,128,64,2,2,2,inf +dsd,False,False,8,128,128,2,2,2,inf +dsd,False,False,16,8,8,4,2,2,109.829833984375 +dsd,False,False,16,8,16,2,2,4,67.8739990234375 +dsd,False,False,16,8,32,2,2,8,67.15709228515625 +dsd,False,False,16,8,64,2,2,16,59.45750122070312 +dsd,False,False,16,8,128,16,2,4,68.67128295898438 +dsd,False,False,16,16,8,2,2,2,82.72291870117188 +dsd,False,False,16,16,16,2,8,4,71.8519287109375 +dsd,False,False,16,16,32,2,4,8,66.51893920898438 +dsd,False,False,16,16,64,16,2,2,70.88312377929688 +dsd,False,False,16,16,128,16,4,2,68.99373779296874 +dsd,False,False,16,32,8,2,4,2,99.39988403320312 +dsd,False,False,16,32,16,4,16,2,72.16117553710937 +dsd,False,False,16,32,32,8,4,2,67.4820068359375 +dsd,False,False,16,32,64,2,2,4,71.96723022460938 +dsd,False,False,16,32,128,2,2,2,inf +dsd,False,False,16,64,8,2,2,2,97.66973266601562 +dsd,False,False,16,64,16,2,4,2,72.28272705078125 +dsd,False,False,16,64,32,2,2,2,72.5897216796875 +dsd,False,False,16,64,64,2,2,2,inf +dsd,False,False,16,64,128,2,2,2,inf +dsd,False,False,16,128,8,2,8,2,117.42637939453124 +dsd,False,False,16,128,16,2,2,2,76.32855224609375 +dsd,False,False,16,128,32,2,2,2,inf +dsd,False,False,16,128,64,2,2,2,inf +dsd,False,False,16,128,128,2,2,2,inf +dsd,False,False,32,8,8,8,2,2,77.55427856445313 +dsd,False,False,32,8,16,8,4,2,43.420654296875 +dsd,False,False,32,8,32,4,2,8,36.51061096191406 +dsd,False,False,32,8,64,16,8,2,34.29179382324219 +dsd,False,False,32,8,128,4,8,16,35.03697814941406 +dsd,False,False,32,16,8,4,2,2,58.53501586914062 +dsd,False,False,32,16,16,8,2,2,38.48028259277344 +dsd,False,False,32,16,32,16,16,2,34.754150390625 +dsd,False,False,32,16,64,8,8,2,35.0497802734375 +dsd,False,False,32,16,128,16,8,2,34.821832275390626 
+dsd,False,False,32,32,8,4,4,2,81.96536254882812 +dsd,False,False,32,32,16,8,4,2,45.167718505859376 +dsd,False,False,32,32,32,8,8,2,37.82635498046875 +dsd,False,False,32,32,64,4,16,2,37.30174865722656 +dsd,False,False,32,32,128,2,2,2,inf +dsd,False,False,32,64,8,4,16,2,84.30346069335937 +dsd,False,False,32,64,16,4,16,2,51.648583984375 +dsd,False,False,32,64,32,4,2,2,38.18772583007812 +dsd,False,False,32,64,64,2,2,2,inf +dsd,False,False,32,64,128,2,2,2,inf +dsd,False,False,32,128,8,2,2,2,107.4925537109375 +dsd,False,False,32,128,16,4,4,2,63.33204345703125 +dsd,False,False,32,128,32,2,2,2,inf +dsd,False,False,32,128,64,2,2,2,inf +dsd,False,False,32,128,128,2,2,2,inf +dsd,False,False,64,8,8,8,8,4,66.64652709960937 +dsd,False,False,64,8,16,16,8,2,31.68522338867188 +dsd,False,False,64,8,32,8,2,8,26.4263671875 +dsd,False,False,64,8,64,8,2,8,25.222451782226564 +dsd,False,False,64,8,128,16,2,4,23.744717407226563 +dsd,False,False,64,16,8,8,4,2,51.46163330078125 +dsd,False,False,64,16,16,16,2,2,29.58365478515625 +dsd,False,False,64,16,32,8,4,4,25.577880859375 +dsd,False,False,64,16,64,8,2,4,24.477081298828125 +dsd,False,False,64,16,128,8,2,8,23.58855743408203 +dsd,False,False,64,32,8,8,8,2,74.31342163085938 +dsd,False,False,64,32,16,8,8,2,42.91983337402344 +dsd,False,False,64,32,32,8,8,4,28.376065063476563 +dsd,False,False,64,32,64,8,8,4,23.48838348388672 +dsd,False,False,64,32,128,2,2,2,inf +dsd,False,False,64,64,8,4,8,2,83.88802490234374 +dsd,False,False,64,64,16,4,4,2,49.419775390625 +dsd,False,False,64,64,32,8,16,2,30.32053833007813 +dsd,False,False,64,64,64,2,2,2,inf +dsd,False,False,64,64,128,2,2,2,inf +dsd,False,False,64,128,8,2,2,2,122.1655517578125 +dsd,False,False,64,128,16,4,16,2,64.27389526367188 +dsd,False,False,64,128,32,2,2,2,inf +dsd,False,False,64,128,64,2,2,2,inf +dsd,False,False,64,128,128,2,2,2,inf +dsd,False,False,128,8,8,16,4,4,59.268896484375 +dsd,False,False,128,8,16,8,4,8,30.05284118652344 +dsd,False,False,128,8,32,8,8,8,24.729702758789063 +dsd,False,False,128,8,64,16,2,4,22.941696166992188 +dsd,False,False,128,8,128,16,4,4,24.450729370117188 +dsd,False,False,128,16,8,16,8,2,48.74741821289062 +dsd,False,False,128,16,16,16,4,2,28.809829711914062 +dsd,False,False,128,16,32,8,2,4,24.48732147216797 +dsd,False,False,128,16,64,8,8,8,23.141786193847658 +dsd,False,False,128,16,128,8,2,8,22.406935119628905 +dsd,False,False,128,32,8,8,8,2,81.20964965820312 +dsd,False,False,128,32,16,8,8,2,41.73648376464844 +dsd,False,False,128,32,32,8,8,4,27.508633422851563 +dsd,False,False,128,32,64,16,4,4,24.02519073486328 +dsd,False,False,128,32,128,2,2,2,inf +dsd,False,False,128,64,8,4,2,2,101.65493774414062 +dsd,False,False,128,64,16,8,8,2,52.14627685546875 +dsd,False,False,128,64,32,8,4,2,30.089627075195317 +dsd,False,False,128,64,64,2,2,2,inf +dsd,False,False,128,64,128,2,2,2,inf +dsd,False,False,128,128,8,2,2,2,inf +dsd,False,False,128,128,16,2,2,2,inf +dsd,False,False,128,128,32,2,2,2,inf +dsd,False,False,128,128,64,2,2,2,inf +dsd,False,False,128,128,128,2,2,2,inf +dsd,False,True,8,8,8,2,4,2,195.2174072265625 +dsd,False,True,8,8,16,4,4,2,150.169189453125 +dsd,False,True,8,8,32,2,8,8,139.71568603515624 +dsd,False,True,8,8,64,2,8,16,138.71278076171876 +dsd,False,True,8,8,128,4,4,16,137.86224365234375 +dsd,False,True,8,16,8,2,2,2,200.561767578125 +dsd,False,True,8,16,16,2,2,2,144.09287109375 +dsd,False,True,8,16,32,8,4,2,148.59796142578125 +dsd,False,True,8,16,64,8,2,4,143.811376953125 +dsd,False,True,8,16,128,8,2,4,143.46802978515626 +dsd,False,True,8,32,8,2,8,2,276.1943115234375 
+dsd,False,True,8,32,16,2,2,2,150.904833984375 +dsd,False,True,8,32,32,2,2,2,152.08131103515626 +dsd,False,True,8,32,64,2,16,4,136.442138671875 +dsd,False,True,8,32,128,2,2,2,inf +dsd,False,True,8,64,8,2,16,2,380.849365234375 +dsd,False,True,8,64,16,2,2,2,190.4585693359375 +dsd,False,True,8,64,32,4,2,2,164.3568115234375 +dsd,False,True,8,64,64,2,2,2,inf +dsd,False,True,8,64,128,2,2,2,inf +dsd,False,True,8,128,8,2,2,2,inf +dsd,False,True,8,128,16,2,2,2,304.826025390625 +dsd,False,True,8,128,32,2,2,2,inf +dsd,False,True,8,128,64,2,2,2,inf +dsd,False,True,8,128,128,2,2,2,inf +dsd,False,True,16,8,8,4,2,2,116.1891845703125 +dsd,False,True,16,8,16,2,8,4,76.41241455078125 +dsd,False,True,16,8,32,2,2,8,74.0337646484375 +dsd,False,True,16,8,64,2,8,16,71.31473999023437 +dsd,False,True,16,8,128,4,8,16,70.32268676757812 +dsd,False,True,16,16,8,2,2,2,92.1997314453125 +dsd,False,True,16,16,16,4,8,2,75.77088012695313 +dsd,False,True,16,16,32,4,2,2,78.14052124023438 +dsd,False,True,16,16,64,16,2,2,78.590771484375 +dsd,False,True,16,16,128,2,2,16,70.82373046875 +dsd,False,True,16,32,8,2,2,2,124.4768310546875 +dsd,False,True,16,32,16,4,2,2,85.75979614257812 +dsd,False,True,16,32,32,8,2,2,81.34788818359375 +dsd,False,True,16,32,64,4,8,2,78.4594970703125 +dsd,False,True,16,32,128,2,2,2,inf +dsd,False,True,16,64,8,2,2,2,166.36641845703124 +dsd,False,True,16,64,16,4,8,2,117.9651123046875 +dsd,False,True,16,64,32,8,8,2,97.21170043945312 +dsd,False,True,16,64,64,2,2,2,inf +dsd,False,True,16,64,128,2,2,2,inf +dsd,False,True,16,128,8,2,2,2,227.211865234375 +dsd,False,True,16,128,16,4,2,2,183.67445068359376 +dsd,False,True,16,128,32,2,2,2,inf +dsd,False,True,16,128,64,2,2,2,inf +dsd,False,True,16,128,128,2,2,2,inf +dsd,False,True,32,8,8,8,2,2,82.42647094726563 +dsd,False,True,32,8,16,8,8,2,45.360128784179686 +dsd,False,True,32,8,32,16,4,2,42.12101135253906 +dsd,False,True,32,8,64,2,4,16,40.78192749023437 +dsd,False,True,32,8,128,4,4,16,37.76192932128906 +dsd,False,True,32,16,8,4,2,2,63.4661865234375 +dsd,False,True,32,16,16,8,2,2,46.11286926269531 +dsd,False,True,32,16,32,16,2,2,44.24960021972656 +dsd,False,True,32,16,64,8,8,2,42.553549194335936 +dsd,False,True,32,16,128,16,2,2,37.30780029296875 +dsd,False,True,32,32,8,4,4,2,94.68699951171877 +dsd,False,True,32,32,16,8,8,2,60.38486938476562 +dsd,False,True,32,32,32,8,8,4,48.58204040527344 +dsd,False,True,32,32,64,4,8,8,45.13853454589844 +dsd,False,True,32,32,128,2,2,2,inf +dsd,False,True,32,64,8,4,2,2,121.8534423828125 +dsd,False,True,32,64,16,8,4,2,80.84459228515625 +dsd,False,True,32,64,32,8,16,2,62.3984619140625 +dsd,False,True,32,64,64,2,2,2,inf +dsd,False,True,32,64,128,2,2,2,inf +dsd,False,True,32,128,8,2,2,2,171.650146484375 +dsd,False,True,32,128,16,4,4,2,116.75975341796877 +dsd,False,True,32,128,32,2,2,2,inf +dsd,False,True,32,128,64,2,2,2,inf +dsd,False,True,32,128,128,2,2,2,inf +dsd,False,True,64,8,8,16,2,2,66.351708984375 +dsd,False,True,64,8,16,8,2,4,34.971533203125 +dsd,False,True,64,8,32,4,2,16,29.758053588867188 +dsd,False,True,64,8,64,4,8,16,26.7695068359375 +dsd,False,True,64,8,128,8,2,8,25.301702880859374 +dsd,False,True,64,16,8,8,16,2,58.3183349609375 +dsd,False,True,64,16,16,16,2,2,33.02317810058594 +dsd,False,True,64,16,32,8,8,4,30.07764587402344 +dsd,False,True,64,16,64,8,16,4,27.921511840820312 +dsd,False,True,64,16,128,8,4,8,25.865625 +dsd,False,True,64,32,8,8,4,2,85.84161376953125 +dsd,False,True,64,32,16,8,8,2,50.681549072265625 +dsd,False,True,64,32,32,8,16,4,35.13456726074219 +dsd,False,True,64,32,64,8,4,4,30.035763549804688 
+dsd,False,True,64,32,128,2,2,2,inf +dsd,False,True,64,64,8,4,2,2,105.23985595703124 +dsd,False,True,64,64,16,8,2,2,67.3208251953125 +dsd,False,True,64,64,32,16,4,2,47.07389526367187 +dsd,False,True,64,64,64,2,2,2,inf +dsd,False,True,64,64,128,2,2,2,inf +dsd,False,True,64,128,8,2,2,2,145.19971923828126 +dsd,False,True,64,128,16,4,2,2,89.2669677734375 +dsd,False,True,64,128,32,2,2,2,inf +dsd,False,True,64,128,64,2,2,2,inf +dsd,False,True,64,128,128,2,2,2,inf +dsd,False,True,128,8,8,16,2,4,64.46704711914063 +dsd,False,True,128,8,16,4,4,16,36.77900695800781 +dsd,False,True,128,8,32,8,4,8,30.70411376953125 +dsd,False,True,128,8,64,16,2,4,25.41475830078125 +dsd,False,True,128,8,128,16,4,4,25.26688995361328 +dsd,False,True,128,16,8,16,2,2,55.80257568359375 +dsd,False,True,128,16,16,16,2,2,33.952972412109375 +dsd,False,True,128,16,32,8,2,4,27.87225646972656 +dsd,False,True,128,16,64,8,2,8,24.682701110839844 +dsd,False,True,128,16,128,8,16,8,23.585382080078126 +dsd,False,True,128,32,8,8,4,2,88.70225830078125 +dsd,False,True,128,32,16,8,4,2,46.95654296875 +dsd,False,True,128,32,32,8,8,4,31.39991455078125 +dsd,False,True,128,32,64,8,4,8,26.41059875488281 +dsd,False,True,128,32,128,2,2,2,inf +dsd,False,True,128,64,8,4,2,2,110.9296142578125 +dsd,False,True,128,64,16,8,16,2,60.833282470703125 +dsd,False,True,128,64,32,16,16,2,40.478823852539065 +dsd,False,True,128,64,64,2,2,2,inf +dsd,False,True,128,64,128,2,2,2,inf +dsd,False,True,128,128,8,2,2,2,inf +dsd,False,True,128,128,16,2,2,2,inf +dsd,False,True,128,128,32,2,2,2,inf +dsd,False,True,128,128,64,2,2,2,inf +dsd,False,True,128,128,128,2,2,2,inf +dsd,True,False,8,8,8,2,4,2,196.04478759765624 +dsd,True,False,8,8,16,4,8,2,149.0472900390625 +dsd,True,False,8,8,32,2,8,8,142.06729736328126 +dsd,True,False,8,8,64,2,2,16,138.90089111328126 +dsd,True,False,8,8,128,2,2,2,inf +dsd,True,False,8,16,8,2,4,2,166.29586181640624 +dsd,True,False,8,16,16,2,4,2,145.54091796875 +dsd,True,False,8,16,32,2,16,4,140.79332275390624 +dsd,True,False,8,16,64,8,4,4,141.57967529296874 +dsd,True,False,8,16,128,2,4,16,138.7620361328125 +dsd,True,False,8,32,8,2,4,2,177.61934814453124 +dsd,True,False,8,32,16,2,16,2,144.644287109375 +dsd,True,False,8,32,32,2,8,4,142.95357666015624 +dsd,True,False,8,32,64,2,2,4,141.93233642578124 +dsd,True,False,8,32,128,2,2,2,inf +dsd,True,False,8,64,8,2,2,2,177.02919921875 +dsd,True,False,8,64,16,2,8,2,142.3205322265625 +dsd,True,False,8,64,32,4,4,2,142.78870849609376 +dsd,True,False,8,64,64,2,2,2,inf +dsd,True,False,8,64,128,2,2,2,inf +dsd,True,False,8,128,8,2,4,2,191.33798828125 +dsd,True,False,8,128,16,2,2,2,142.40296630859376 +dsd,True,False,8,128,32,2,2,2,inf +dsd,True,False,8,128,64,2,2,2,inf +dsd,True,False,8,128,128,2,2,2,inf +dsd,True,False,16,8,8,4,4,2,155.755419921875 +dsd,True,False,16,8,16,4,8,2,77.15879516601562 +dsd,True,False,16,8,32,2,2,8,69.79932250976563 +dsd,True,False,16,8,64,2,4,16,62.8180908203125 +dsd,True,False,16,8,128,16,2,4,69.74443359375 +dsd,True,False,16,16,8,2,16,2,104.33758544921876 +dsd,True,False,16,16,16,4,4,2,74.32406616210938 +dsd,True,False,16,16,32,2,16,8,69.3928955078125 +dsd,True,False,16,16,64,16,2,2,71.24234008789062 +dsd,True,False,16,16,128,2,2,16,70.21209716796875 +dsd,True,False,16,32,8,2,8,2,120.39412841796874 +dsd,True,False,16,32,16,4,8,2,74.49405517578126 +dsd,True,False,16,32,32,4,2,4,71.88016967773437 +dsd,True,False,16,32,64,2,2,4,72.29121704101563 +dsd,True,False,16,32,128,2,2,2,inf +dsd,True,False,16,64,8,2,4,2,118.81768798828124 +dsd,True,False,16,64,16,4,16,2,75.29912109375 
+dsd,True,False,16,64,32,2,16,2,72.90520629882812 +dsd,True,False,16,64,64,2,2,2,inf +dsd,True,False,16,64,128,2,2,2,inf +dsd,True,False,16,128,8,2,2,2,127.24039306640626 +dsd,True,False,16,128,16,2,16,2,81.93239135742188 +dsd,True,False,16,128,32,2,2,2,inf +dsd,True,False,16,128,64,2,2,2,inf +dsd,True,False,16,128,128,2,2,2,inf +dsd,True,False,32,8,8,4,2,4,183.60247802734372 +dsd,True,False,32,8,16,4,2,4,63.06016235351562 +dsd,True,False,32,8,32,4,4,8,40.452703857421874 +dsd,True,False,32,8,64,4,2,16,34.856756591796874 +dsd,True,False,32,8,128,4,2,16,35.30137634277344 +dsd,True,False,32,16,8,4,2,2,104.81192626953126 +dsd,True,False,32,16,16,8,2,2,58.9707275390625 +dsd,True,False,32,16,32,8,16,4,38.89725341796875 +dsd,True,False,32,16,64,4,4,8,35.332916259765625 +dsd,True,False,32,16,128,16,4,2,32.474420166015626 +dsd,True,False,32,32,8,4,4,2,132.05728759765626 +dsd,True,False,32,32,16,8,8,2,67.71791381835938 +dsd,True,False,32,32,32,16,4,2,42.35079650878906 +dsd,True,False,32,32,64,4,2,4,37.73572998046875 +dsd,True,False,32,32,128,2,2,2,inf +dsd,True,False,32,64,8,4,2,2,133.3865478515625 +dsd,True,False,32,64,16,8,4,2,70.73126220703125 +dsd,True,False,32,64,32,8,8,2,44.9080322265625 +dsd,True,False,32,64,64,2,2,2,inf +dsd,True,False,32,64,128,2,2,2,inf +dsd,True,False,32,128,8,2,2,2,151.80430908203124 +dsd,True,False,32,128,16,4,4,2,79.52486572265624 +dsd,True,False,32,128,32,2,2,2,inf +dsd,True,False,32,128,64,2,2,2,inf +dsd,True,False,32,128,128,2,2,2,inf +dsd,True,False,64,8,8,8,2,4,276.3603759765625 +dsd,True,False,64,8,16,8,2,4,80.73357543945312 +dsd,True,False,64,8,32,8,2,8,47.04132995605469 +dsd,True,False,64,8,64,8,2,8,33.257470703125 +dsd,True,False,64,8,128,16,2,4,27.2468994140625 +dsd,True,False,64,16,8,8,16,2,154.29814453125 +dsd,True,False,64,16,16,16,16,2,80.89815063476563 +dsd,True,False,64,16,32,8,4,8,48.67113037109375 +dsd,True,False,64,16,64,8,16,8,34.17445678710938 +dsd,True,False,64,16,128,8,4,8,27.85955810546875 +dsd,True,False,64,32,8,8,8,2,182.21177978515624 +dsd,True,False,64,32,16,16,4,2,93.49498901367188 +dsd,True,False,64,32,32,16,16,2,52.84812622070312 +dsd,True,False,64,32,64,16,4,4,35.58829956054687 +dsd,True,False,64,32,128,2,2,2,inf +dsd,True,False,64,64,8,8,4,2,186.49927978515623 +dsd,True,False,64,64,16,8,16,2,99.7465087890625 +dsd,True,False,64,64,32,16,4,2,55.89176025390625 +dsd,True,False,64,64,64,2,2,2,inf +dsd,True,False,64,64,128,2,2,2,inf +dsd,True,False,64,128,8,4,16,2,203.6135986328125 +dsd,True,False,64,128,16,4,2,2,106.8759033203125 +dsd,True,False,64,128,32,2,2,2,inf +dsd,True,False,64,128,64,2,2,2,inf +dsd,True,False,64,128,128,2,2,2,inf +dsd,True,False,128,8,8,2,2,2,inf +dsd,True,False,128,8,16,8,2,8,132.375146484375 +dsd,True,False,128,8,32,8,4,8,76.41856079101562 +dsd,True,False,128,8,64,8,2,8,46.568756103515625 +dsd,True,False,128,8,128,16,4,4,34.65922546386719 +dsd,True,False,128,16,8,16,2,2,263.89287109375 +dsd,True,False,128,16,16,16,8,2,138.90672607421874 +dsd,True,False,128,16,32,8,16,8,77.42740478515626 +dsd,True,False,128,16,64,8,8,8,47.06150512695312 +dsd,True,False,128,16,128,8,2,8,34.272256469726564 +dsd,True,False,128,32,8,16,4,2,295.225244140625 +dsd,True,False,128,32,16,16,4,2,149.99757080078126 +dsd,True,False,128,32,32,8,16,4,81.86920776367188 +dsd,True,False,128,32,64,8,2,8,48.23459777832032 +dsd,True,False,128,32,128,2,2,2,inf +dsd,True,False,128,64,8,8,8,2,306.1161865234375 +dsd,True,False,128,64,16,8,8,2,157.15245361328124 +dsd,True,False,128,64,32,16,16,2,83.70462646484376 +dsd,True,False,128,64,64,2,2,2,inf 
+dsd,True,False,128,64,128,2,2,2,inf +dsd,True,False,128,128,8,2,2,2,inf +dsd,True,False,128,128,16,2,2,2,inf +dsd,True,False,128,128,32,2,2,2,inf +dsd,True,False,128,128,64,2,2,2,inf +dsd,True,False,128,128,128,2,2,2,inf +dsd,True,True,8,8,8,2,2,2,205.176123046875 +dsd,True,True,8,8,16,4,8,2,154.469580078125 +dsd,True,True,8,8,32,2,4,8,143.70413818359376 +dsd,True,True,8,8,64,2,2,16,140.26270751953126 +dsd,True,True,8,8,128,4,2,16,139.13067626953125 +dsd,True,True,8,16,8,2,16,2,212.8994384765625 +dsd,True,True,8,16,16,2,8,2,148.4000244140625 +dsd,True,True,8,16,32,4,16,2,151.075537109375 +dsd,True,True,8,16,64,8,16,4,147.540478515625 +dsd,True,True,8,16,128,4,16,8,144.277099609375 +dsd,True,True,8,32,8,2,2,2,275.79443359375 +dsd,True,True,8,32,16,2,16,2,155.88167724609374 +dsd,True,True,8,32,32,4,2,2,152.702978515625 +dsd,True,True,8,32,64,2,2,4,137.5283203125 +dsd,True,True,8,32,128,2,2,2,inf +dsd,True,True,8,64,8,2,16,2,391.077587890625 +dsd,True,True,8,64,16,2,8,2,198.9496826171875 +dsd,True,True,8,64,32,4,2,2,172.26177978515625 +dsd,True,True,8,64,64,2,2,2,inf +dsd,True,True,8,64,128,2,2,2,inf +dsd,True,True,8,128,8,2,2,2,inf +dsd,True,True,8,128,16,2,2,2,310.083984375 +dsd,True,True,8,128,32,2,2,2,inf +dsd,True,True,8,128,64,2,2,2,inf +dsd,True,True,8,128,128,2,2,2,inf +dsd,True,True,16,8,8,4,2,2,162.6450927734375 +dsd,True,True,16,8,16,4,2,2,78.73065185546875 +dsd,True,True,16,8,32,2,4,8,72.56627197265625 +dsd,True,True,16,8,64,2,8,16,69.21594848632813 +dsd,True,True,16,8,128,4,2,16,68.5728759765625 +dsd,True,True,16,16,8,2,2,2,114.4301513671875 +dsd,True,True,16,16,16,4,4,2,78.94384765625 +dsd,True,True,16,16,32,4,4,2,80.41858520507813 +dsd,True,True,16,16,64,16,2,2,78.91220703125 +dsd,True,True,16,16,128,2,8,16,70.98726196289063 +dsd,True,True,16,32,8,2,8,2,144.57723388671874 +dsd,True,True,16,32,16,4,2,2,94.278759765625 +dsd,True,True,16,32,32,8,2,2,81.99147338867188 +dsd,True,True,16,32,64,4,8,2,78.5966064453125 +dsd,True,True,16,32,128,2,2,2,inf +dsd,True,True,16,64,8,2,4,2,177.346044921875 +dsd,True,True,16,64,16,4,8,2,122.8338134765625 +dsd,True,True,16,64,32,8,4,2,98.97113647460938 +dsd,True,True,16,64,64,2,2,2,inf +dsd,True,True,16,64,128,2,2,2,inf +dsd,True,True,16,128,8,2,8,2,236.5805419921875 +dsd,True,True,16,128,16,4,2,2,187.52010498046877 +dsd,True,True,16,128,32,2,2,2,inf +dsd,True,True,16,128,64,2,2,2,inf +dsd,True,True,16,128,128,2,2,2,inf +dsd,True,True,32,8,8,8,8,2,179.2468994140625 +dsd,True,True,32,8,16,4,2,4,64.5190673828125 +dsd,True,True,32,8,32,4,8,8,42.91881103515625 +dsd,True,True,32,8,64,4,4,16,38.00729675292969 +dsd,True,True,32,8,128,4,4,16,36.25267333984375 +dsd,True,True,32,16,8,4,16,2,113.594140625 +dsd,True,True,32,16,16,8,8,2,63.77820434570312 +dsd,True,True,32,16,32,8,2,4,46.9275634765625 +dsd,True,True,32,16,64,16,16,2,42.92484130859375 +dsd,True,True,32,16,128,8,16,4,36.99824523925781 +dsd,True,True,32,32,8,4,2,2,144.70738525390624 +dsd,True,True,32,32,16,8,8,2,80.85911254882812 +dsd,True,True,32,32,32,16,8,2,54.71138305664063 +dsd,True,True,32,32,64,8,2,4,45.3243896484375 +dsd,True,True,32,32,128,2,2,2,inf +dsd,True,True,32,64,8,4,4,2,160.2193359375 +dsd,True,True,32,64,16,8,4,2,97.89020385742188 +dsd,True,True,32,64,32,8,4,2,71.82909545898437 +dsd,True,True,32,64,64,2,2,2,inf +dsd,True,True,32,64,128,2,2,2,inf +dsd,True,True,32,128,8,4,8,2,201.171240234375 +dsd,True,True,32,128,16,4,2,2,133.8135498046875 +dsd,True,True,32,128,32,2,2,2,inf +dsd,True,True,32,128,64,2,2,2,inf +dsd,True,True,32,128,128,2,2,2,inf 
+dsd,True,True,64,8,8,8,2,4,273.8006103515625 +dsd,True,True,64,8,16,8,2,4,82.39605712890625 +dsd,True,True,64,8,32,8,2,8,47.792843627929685 +dsd,True,True,64,8,64,8,2,8,34.33840637207031 +dsd,True,True,64,8,128,8,8,8,28.0742919921875 +dsd,True,True,64,16,8,8,2,2,156.6045166015625 +dsd,True,True,64,16,16,16,16,2,84.16409301757812 +dsd,True,True,64,16,32,8,16,8,51.99247436523437 +dsd,True,True,64,16,64,8,16,8,37.035919189453125 +dsd,True,True,64,16,128,8,4,8,30.449459838867188 +dsd,True,True,64,32,8,8,4,2,187.31827392578128 +dsd,True,True,64,32,16,16,16,2,99.79442138671877 +dsd,True,True,64,32,32,8,4,4,59.79473876953125 +dsd,True,True,64,32,64,16,4,4,41.73803405761719 +dsd,True,True,64,32,128,2,2,2,inf +dsd,True,True,64,64,8,8,8,2,199.9794921875 +dsd,True,True,64,64,16,8,16,2,112.7275146484375 +dsd,True,True,64,64,32,16,4,2,69.20181884765626 +dsd,True,True,64,64,64,2,2,2,inf +dsd,True,True,64,64,128,2,2,2,inf +dsd,True,True,64,128,8,4,8,2,228.7771484375 +dsd,True,True,64,128,16,4,8,2,133.517724609375 +dsd,True,True,64,128,32,2,2,2,inf +dsd,True,True,64,128,64,2,2,2,inf +dsd,True,True,64,128,128,2,2,2,inf +dsd,True,True,128,8,8,2,2,2,inf +dsd,True,True,128,8,16,8,2,8,133.31158447265625 +dsd,True,True,128,8,32,8,2,8,76.98534545898437 +dsd,True,True,128,8,64,16,2,4,47.61773986816407 +dsd,True,True,128,8,128,16,2,4,35.36865234375 +dsd,True,True,128,16,8,16,2,2,265.8430908203125 +dsd,True,True,128,16,16,8,2,4,139.7401611328125 +dsd,True,True,128,16,32,8,16,8,78.98224487304688 +dsd,True,True,128,16,64,8,16,8,48.78755798339844 +dsd,True,True,128,16,128,8,2,8,35.972607421875 +dsd,True,True,128,32,8,16,8,2,299.2164794921875 +dsd,True,True,128,32,16,16,8,2,153.24393310546876 +dsd,True,True,128,32,32,8,16,4,85.35081176757812 +dsd,True,True,128,32,64,8,2,8,51.50474243164062 +dsd,True,True,128,32,128,2,2,2,inf +dsd,True,True,128,64,8,8,4,2,312.0203857421875 +dsd,True,True,128,64,16,8,4,2,163.46982421875 +dsd,True,True,128,64,32,16,8,2,90.49733276367188 +dsd,True,True,128,64,64,2,2,2,inf +dsd,True,True,128,64,128,2,2,2,inf +dsd,True,True,128,128,8,2,2,2,inf +dsd,True,True,128,128,16,2,2,2,inf +dsd,True,True,128,128,32,2,2,2,inf +dsd,True,True,128,128,64,2,2,2,inf +dsd,True,True,128,128,128,2,2,2,inf +sdd,False,False,8,8,8,2,2,2,335.815380859375 +sdd,False,False,8,8,16,4,4,2,178.04390869140624 +sdd,False,False,8,8,32,2,8,8,138.64775390625 +sdd,False,False,8,8,64,2,4,16,133.78795166015624 +sdd,False,False,8,8,128,2,2,2,inf +sdd,False,False,8,16,8,4,8,2,301.8798095703125 +sdd,False,False,8,16,16,2,4,2,169.05899658203126 +sdd,False,False,8,16,32,4,4,2,139.20625 +sdd,False,False,8,16,64,2,16,8,137.662255859375 +sdd,False,False,8,16,128,8,4,4,139.66182861328124 +sdd,False,False,8,32,8,2,2,4,294.104052734375 +sdd,False,False,8,32,16,4,8,2,158.1086669921875 +sdd,False,False,8,32,32,4,8,2,138.7292724609375 +sdd,False,False,8,32,64,2,16,4,138.87989501953126 +sdd,False,False,8,32,128,8,8,2,139.49183349609376 +sdd,False,False,8,64,8,2,4,2,298.197900390625 +sdd,False,False,8,64,16,2,4,2,149.730615234375 +sdd,False,False,8,64,32,2,4,4,139.41165771484376 +sdd,False,False,8,64,64,2,2,2,139.8656982421875 +sdd,False,False,8,64,128,4,8,2,139.4408203125 +sdd,False,False,8,128,8,2,2,2,inf +sdd,False,False,8,128,16,2,8,2,163.76964111328124 +sdd,False,False,8,128,32,2,16,2,140.90045166015625 +sdd,False,False,8,128,64,2,16,2,140.58863525390626 +sdd,False,False,8,128,128,2,2,2,inf +sdd,False,False,16,8,8,4,4,2,166.25458984375 +sdd,False,False,16,8,16,4,8,2,90.27747192382812 +sdd,False,False,16,8,32,4,8,8,69.10761108398438 
+sdd,False,False,16,8,64,4,2,8,67.66807250976562 +sdd,False,False,16,8,128,8,4,8,66.71800537109375 +sdd,False,False,16,16,8,4,4,2,158.8569091796875 +sdd,False,False,16,16,16,2,2,2,84.75748901367187 +sdd,False,False,16,16,32,4,2,4,70.751025390625 +sdd,False,False,16,16,64,4,4,4,69.515673828125 +sdd,False,False,16,16,128,4,16,8,71.41632080078125 +sdd,False,False,16,32,8,2,2,2,147.57744140625 +sdd,False,False,16,32,16,2,8,2,79.40413208007813 +sdd,False,False,16,32,32,2,2,2,70.26575317382813 +sdd,False,False,16,32,64,2,4,4,70.35594482421875 +sdd,False,False,16,32,128,2,4,8,69.65299072265626 +sdd,False,False,16,64,8,2,16,2,136.4761474609375 +sdd,False,False,16,64,16,2,4,2,75.92714233398438 +sdd,False,False,16,64,32,2,2,2,71.1752685546875 +sdd,False,False,16,64,64,2,8,2,71.73375854492187 +sdd,False,False,16,64,128,2,16,4,70.2403564453125 +sdd,False,False,16,128,8,2,2,2,124.0837158203125 +sdd,False,False,16,128,16,2,4,2,81.1736083984375 +sdd,False,False,16,128,32,2,8,2,71.90896606445312 +sdd,False,False,16,128,64,2,8,2,70.342041015625 +sdd,False,False,16,128,128,2,2,2,inf +sdd,False,False,32,8,8,8,2,2,89.910986328125 +sdd,False,False,32,8,16,4,8,4,47.08485107421875 +sdd,False,False,32,8,32,4,4,8,35.66530151367188 +sdd,False,False,32,8,64,8,8,8,36.20536193847656 +sdd,False,False,32,8,128,8,2,8,37.362667846679685 +sdd,False,False,32,16,8,2,8,4,80.07280883789062 +sdd,False,False,32,16,16,8,2,2,47.84732055664063 +sdd,False,False,32,16,32,4,8,8,38.042727661132815 +sdd,False,False,32,16,64,8,8,2,37.51362609863281 +sdd,False,False,32,16,128,4,4,8,37.37781677246094 +sdd,False,False,32,32,8,4,8,2,85.47921752929688 +sdd,False,False,32,32,16,8,4,2,45.53605041503906 +sdd,False,False,32,32,32,8,4,2,36.92113952636719 +sdd,False,False,32,32,64,4,4,4,36.73414611816406 +sdd,False,False,32,32,128,16,2,2,36.91141052246094 +sdd,False,False,32,64,8,4,4,2,90.2134765625 +sdd,False,False,32,64,16,4,4,2,52.18754272460937 +sdd,False,False,32,64,32,8,8,2,38.48304748535156 +sdd,False,False,32,64,64,4,2,2,37.11744079589844 +sdd,False,False,32,64,128,4,8,4,35.8972412109375 +sdd,False,False,32,128,8,2,2,2,inf +sdd,False,False,32,128,16,2,2,2,inf +sdd,False,False,32,128,32,2,2,2,inf +sdd,False,False,32,128,64,2,2,2,inf +sdd,False,False,32,128,128,2,2,2,inf +sdd,False,False,64,8,8,8,8,4,65.85814819335937 +sdd,False,False,64,8,16,8,8,4,31.638424682617188 +sdd,False,False,64,8,32,8,2,8,24.71656951904297 +sdd,False,False,64,8,64,8,2,8,24.32788543701172 +sdd,False,False,64,8,128,8,2,8,23.616490173339844 +sdd,False,False,64,16,8,8,2,2,54.6862060546875 +sdd,False,False,64,16,16,16,2,2,31.148028564453124 +sdd,False,False,64,16,32,8,8,4,25.38966064453125 +sdd,False,False,64,16,64,8,16,4,24.080995178222658 +sdd,False,False,64,16,128,8,16,8,24.22169647216797 +sdd,False,False,64,32,8,8,2,2,77.34814453125 +sdd,False,False,64,32,16,8,8,2,44.23567504882813 +sdd,False,False,64,32,32,8,8,4,28.473959350585936 +sdd,False,False,64,32,64,8,4,4,23.699337768554688 +sdd,False,False,64,32,128,8,2,8,24.34996490478516 +sdd,False,False,64,64,8,2,2,2,inf +sdd,False,False,64,64,16,2,2,2,inf +sdd,False,False,64,64,32,2,2,2,inf +sdd,False,False,64,64,64,2,2,2,inf +sdd,False,False,64,64,128,2,2,2,inf +sdd,False,False,64,128,8,2,2,2,inf +sdd,False,False,64,128,16,2,2,2,inf +sdd,False,False,64,128,32,2,2,2,inf +sdd,False,False,64,128,64,2,2,2,inf +sdd,False,False,64,128,128,2,2,2,inf +sdd,False,False,128,8,8,16,2,4,58.91676025390625 +sdd,False,False,128,8,16,8,8,8,26.94410095214844 +sdd,False,False,128,8,32,8,4,8,23.53326110839844 
+sdd,False,False,128,8,64,16,2,4,21.76747589111328 +sdd,False,False,128,8,128,16,4,4,23.42686767578125 +sdd,False,False,128,16,8,16,2,2,50.80626525878906 +sdd,False,False,128,16,16,16,8,2,28.22850646972656 +sdd,False,False,128,16,32,8,4,4,24.08775634765625 +sdd,False,False,128,16,64,8,4,8,23.01982727050781 +sdd,False,False,128,16,128,8,2,8,23.040103149414065 +sdd,False,False,128,32,8,2,2,2,inf +sdd,False,False,128,32,16,2,2,2,inf +sdd,False,False,128,32,32,2,2,2,inf +sdd,False,False,128,32,64,2,2,2,inf +sdd,False,False,128,32,128,2,2,2,inf +sdd,False,False,128,64,8,2,2,2,inf +sdd,False,False,128,64,16,2,2,2,inf +sdd,False,False,128,64,32,2,2,2,inf +sdd,False,False,128,64,64,2,2,2,inf +sdd,False,False,128,64,128,2,2,2,inf +sdd,False,False,128,128,8,2,2,2,inf +sdd,False,False,128,128,16,2,2,2,inf +sdd,False,False,128,128,32,2,2,2,inf +sdd,False,False,128,128,64,2,2,2,inf +sdd,False,False,128,128,128,2,2,2,inf +sdd,False,True,8,8,8,2,2,2,399.1469970703125 +sdd,False,True,8,8,16,4,4,4,393.4630859375 +sdd,False,True,8,8,32,2,2,8,379.371923828125 +sdd,False,True,8,8,64,8,8,16,406.3237060546875 +sdd,False,True,8,8,128,4,4,16,394.24541015625 +sdd,False,True,8,16,8,2,2,2,206.7267578125 +sdd,False,True,8,16,16,4,8,2,198.42396240234376 +sdd,False,True,8,16,32,8,8,2,187.6971557617188 +sdd,False,True,8,16,64,2,16,16,193.23115234375 +sdd,False,True,8,16,128,4,4,16,197.74112548828128 +sdd,False,True,8,32,8,2,2,2,271.7641845703125 +sdd,False,True,8,32,16,2,16,2,139.34661865234375 +sdd,False,True,8,32,32,4,8,2,138.78026123046874 +sdd,False,True,8,32,64,2,16,4,139.386962890625 +sdd,False,True,8,32,128,8,8,2,140.2882080078125 +sdd,False,True,8,64,8,2,2,2,379.437060546875 +sdd,False,True,8,64,16,2,2,2,185.58453369140625 +sdd,False,True,8,64,32,4,2,2,160.37774658203125 +sdd,False,True,8,64,64,4,2,2,159.0248291015625 +sdd,False,True,8,64,128,4,16,2,162.72178955078124 +sdd,False,True,8,128,8,2,2,2,inf +sdd,False,True,8,128,16,2,16,2,306.841796875 +sdd,False,True,8,128,32,2,8,2,288.7798828125 +sdd,False,True,8,128,64,2,4,2,288.84541015625 +sdd,False,True,8,128,128,2,2,2,inf +sdd,False,True,16,8,8,2,4,4,202.17528076171877 +sdd,False,True,16,8,16,4,2,4,194.959765625 +sdd,False,True,16,8,32,8,2,2,181.95906982421877 +sdd,False,True,16,8,64,4,4,8,177.53671875 +sdd,False,True,16,8,128,4,2,16,169.419775390625 +sdd,False,True,16,16,8,2,2,2,101.73458862304688 +sdd,False,True,16,16,16,4,2,2,100.47365112304688 +sdd,False,True,16,16,32,8,4,2,92.6509033203125 +sdd,False,True,16,16,64,2,16,8,95.26671142578124 +sdd,False,True,16,16,128,8,16,4,100.47825927734377 +sdd,False,True,16,32,8,2,2,2,122.93560791015624 +sdd,False,True,16,32,16,4,8,2,80.80701293945313 +sdd,False,True,16,32,32,4,16,4,71.40126953125 +sdd,False,True,16,32,64,4,16,2,71.59755249023438 +sdd,False,True,16,32,128,4,4,4,71.12864990234375 +sdd,False,True,16,64,8,2,8,2,165.405078125 +sdd,False,True,16,64,16,4,2,2,114.4490966796875 +sdd,False,True,16,64,32,8,4,2,95.41509399414062 +sdd,False,True,16,64,64,4,8,4,92.25265502929688 +sdd,False,True,16,64,128,8,16,2,92.0755126953125 +sdd,False,True,16,128,8,2,4,2,235.228369140625 +sdd,False,True,16,128,16,4,16,2,184.256689453125 +sdd,False,True,16,128,32,4,2,2,156.644970703125 +sdd,False,True,16,128,64,4,2,2,155.04056396484376 +sdd,False,True,16,128,128,2,2,2,inf +sdd,False,True,32,8,8,8,4,2,103.24459228515624 +sdd,False,True,32,8,16,4,8,4,97.24036865234376 +sdd,False,True,32,8,32,2,8,16,89.53897094726562 +sdd,False,True,32,8,64,8,8,8,89.11841430664063 +sdd,False,True,32,8,128,8,4,16,93.94298706054688 
+sdd,False,True,32,16,8,4,16,2,68.28972778320312 +sdd,False,True,32,16,16,8,2,2,51.45518188476562 +sdd,False,True,32,16,32,8,16,4,49.334170532226565 +sdd,False,True,32,16,64,2,4,8,49.61821899414063 +sdd,False,True,32,16,128,2,4,16,49.259622192382814 +sdd,False,True,32,32,8,4,4,2,96.26163330078126 +sdd,False,True,32,32,16,8,8,2,60.615576171875 +sdd,False,True,32,32,32,16,4,2,44.796722412109375 +sdd,False,True,32,32,64,8,2,4,38.72890930175781 +sdd,False,True,32,32,128,8,16,4,38.10281982421875 +sdd,False,True,32,64,8,4,2,2,116.53406982421876 +sdd,False,True,32,64,16,8,8,2,80.78519287109376 +sdd,False,True,32,64,32,8,16,2,60.293634033203126 +sdd,False,True,32,64,64,4,16,4,56.82205810546875 +sdd,False,True,32,64,128,8,4,4,56.6828857421875 +sdd,False,True,32,128,8,2,2,2,inf +sdd,False,True,32,128,16,2,2,2,inf +sdd,False,True,32,128,32,2,2,2,inf +sdd,False,True,32,128,64,2,2,2,inf +sdd,False,True,32,128,128,2,2,2,inf +sdd,False,True,64,8,8,8,4,4,66.53447875976562 +sdd,False,True,64,8,16,2,8,16,50.67438049316407 +sdd,False,True,64,8,32,8,2,8,49.14698181152344 +sdd,False,True,64,8,64,2,4,16,49.4206787109375 +sdd,False,True,64,8,128,8,2,8,48.49561462402344 +sdd,False,True,64,16,8,8,2,2,54.80970458984375 +sdd,False,True,64,16,16,16,4,2,33.95225524902344 +sdd,False,True,64,16,32,8,4,8,28.88048706054688 +sdd,False,True,64,16,64,8,16,4,27.716162109375 +sdd,False,True,64,16,128,8,2,8,27.1791015625 +sdd,False,True,64,32,8,8,8,2,84.09588623046875 +sdd,False,True,64,32,16,8,4,2,50.90191345214844 +sdd,False,True,64,32,32,8,8,4,35.54404296875 +sdd,False,True,64,32,64,16,2,4,29.81189880371094 +sdd,False,True,64,32,128,8,4,8,27.36670837402344 +sdd,False,True,64,64,8,2,2,2,inf +sdd,False,True,64,64,16,2,2,2,inf +sdd,False,True,64,64,32,2,2,2,inf +sdd,False,True,64,64,64,2,2,2,inf +sdd,False,True,64,64,128,2,2,2,inf +sdd,False,True,64,128,8,2,2,2,inf +sdd,False,True,64,128,16,2,2,2,inf +sdd,False,True,64,128,32,2,2,2,inf +sdd,False,True,64,128,64,2,2,2,inf +sdd,False,True,64,128,128,2,2,2,inf +sdd,False,True,128,8,8,16,8,4,58.8000244140625 +sdd,False,True,128,8,16,8,8,8,32.438168334960935 +sdd,False,True,128,8,32,4,2,16,30.158950805664062 +sdd,False,True,128,8,64,16,4,4,28.051763916015624 +sdd,False,True,128,8,128,16,4,4,27.196417236328124 +sdd,False,True,128,16,8,16,2,2,49.500979614257815 +sdd,False,True,128,16,16,16,4,2,31.385806274414065 +sdd,False,True,128,16,32,8,2,4,24.700518798828124 +sdd,False,True,128,16,64,8,16,8,23.897190856933594 +sdd,False,True,128,16,128,8,8,8,23.10245056152344 +sdd,False,True,128,32,8,2,2,2,inf +sdd,False,True,128,32,16,2,2,2,inf +sdd,False,True,128,32,32,2,2,2,inf +sdd,False,True,128,32,64,2,2,2,inf +sdd,False,True,128,32,128,2,2,2,inf +sdd,False,True,128,64,8,2,2,2,inf +sdd,False,True,128,64,16,2,2,2,inf +sdd,False,True,128,64,32,2,2,2,inf +sdd,False,True,128,64,64,2,2,2,inf +sdd,False,True,128,64,128,2,2,2,inf +sdd,False,True,128,128,8,2,2,2,inf +sdd,False,True,128,128,16,2,2,2,inf +sdd,False,True,128,128,32,2,2,2,inf +sdd,False,True,128,128,64,2,2,2,inf +sdd,False,True,128,128,128,2,2,2,inf +sdd,True,False,8,8,8,2,2,2,331.9246826171875 +sdd,True,False,8,8,16,4,2,2,174.82608642578126 +sdd,True,False,8,8,32,2,2,8,137.649462890625 +sdd,True,False,8,8,64,2,2,16,135.867578125 +sdd,True,False,8,8,128,2,2,2,inf +sdd,True,False,8,16,8,4,4,2,286.2996337890625 +sdd,True,False,8,16,16,4,16,2,159.4926025390625 +sdd,True,False,8,16,32,4,16,2,137.52545166015625 +sdd,True,False,8,16,64,8,2,4,138.6809326171875 +sdd,True,False,8,16,128,2,4,16,89.31604614257813 
+sdd,True,False,8,32,8,2,4,2,272.43427734375 +sdd,True,False,8,32,16,4,4,2,154.5417724609375 +sdd,True,False,8,32,32,4,4,2,139.97801513671874 +sdd,True,False,8,32,64,2,4,4,138.064794921875 +sdd,True,False,8,32,128,4,2,8,114.269384765625 +sdd,True,False,8,64,8,2,4,2,268.17177734375 +sdd,True,False,8,64,16,2,4,2,147.4407470703125 +sdd,True,False,8,64,32,2,4,2,140.5033447265625 +sdd,True,False,8,64,64,2,4,2,139.827099609375 +sdd,True,False,8,64,128,4,8,2,140.272216796875 +sdd,True,False,8,128,8,2,8,2,306.680810546875 +sdd,True,False,8,128,16,2,16,2,159.9556640625 +sdd,True,False,8,128,32,2,2,2,140.867578125 +sdd,True,False,8,128,64,2,8,2,140.7084228515625 +sdd,True,False,8,128,128,2,2,2,inf +sdd,True,False,16,8,8,4,8,2,167.96077880859374 +sdd,True,False,16,8,16,4,4,2,89.021435546875 +sdd,True,False,16,8,32,2,2,8,68.0089599609375 +sdd,True,False,16,8,64,2,4,16,60.29833984375 +sdd,True,False,16,8,128,8,2,8,67.69019165039063 +sdd,True,False,16,16,8,2,8,2,147.019873046875 +sdd,True,False,16,16,16,4,8,2,81.14429931640625 +sdd,True,False,16,16,32,2,16,8,65.33638916015624 +sdd,True,False,16,16,64,2,16,8,67.8930419921875 +sdd,True,False,16,16,128,4,2,8,69.011865234375 +sdd,True,False,16,32,8,2,8,2,143.04644775390625 +sdd,True,False,16,32,16,4,2,2,78.46348876953125 +sdd,True,False,16,32,32,2,16,4,70.1560546875 +sdd,True,False,16,32,64,4,4,2,70.5028076171875 +sdd,True,False,16,32,128,2,8,8,69.9884521484375 +sdd,True,False,16,64,8,2,8,2,128.64287109375 +sdd,True,False,16,64,16,4,2,2,76.81822509765625 +sdd,True,False,16,64,32,2,4,2,71.68389282226562 +sdd,True,False,16,64,64,2,2,4,72.18800048828125 +sdd,True,False,16,64,128,2,16,4,70.818896484375 +sdd,True,False,16,128,8,2,4,2,133.293359375 +sdd,True,False,16,128,16,2,4,2,82.89126586914062 +sdd,True,False,16,128,32,2,16,2,72.04392700195312 +sdd,True,False,16,128,64,2,16,2,70.67965698242188 +sdd,True,False,16,128,128,2,2,2,inf +sdd,True,False,32,8,8,4,4,4,176.229248046875 +sdd,True,False,32,8,16,4,4,4,60.4487548828125 +sdd,True,False,32,8,32,8,4,4,38.89284973144531 +sdd,True,False,32,8,64,4,4,16,35.1984619140625 +sdd,True,False,32,8,128,16,4,4,34.594305419921874 +sdd,True,False,32,16,8,4,8,2,106.14056396484376 +sdd,True,False,32,16,16,8,2,2,58.323150634765625 +sdd,True,False,32,16,32,8,16,4,37.98722534179687 +sdd,True,False,32,16,64,4,4,8,35.280792236328125 +sdd,True,False,32,16,128,4,2,8,35.26717529296875 +sdd,True,False,32,32,8,4,4,2,131.6276123046875 +sdd,True,False,32,32,16,8,16,2,67.39619750976563 +sdd,True,False,32,32,32,16,16,2,42.02823791503906 +sdd,True,False,32,32,64,4,2,4,37.22391357421875 +sdd,True,False,32,32,128,16,2,2,36.79732055664063 +sdd,True,False,32,64,8,4,4,2,132.349853515625 +sdd,True,False,32,64,16,8,4,2,70.10160522460937 +sdd,True,False,32,64,32,8,16,2,44.479888916015625 +sdd,True,False,32,64,64,4,16,4,37.58006286621094 +sdd,True,False,32,64,128,8,8,2,36.60861511230469 +sdd,True,False,32,128,8,2,2,2,inf +sdd,True,False,32,128,16,2,2,2,inf +sdd,True,False,32,128,32,2,2,2,inf +sdd,True,False,32,128,64,2,2,2,inf +sdd,True,False,32,128,128,2,2,2,inf +sdd,True,False,64,8,8,8,4,4,268.561083984375 +sdd,True,False,64,8,16,8,4,4,78.65259399414063 +sdd,True,False,64,8,32,8,2,8,45.51227111816407 +sdd,True,False,64,8,64,8,2,8,32.8384521484375 +sdd,True,False,64,8,128,16,2,4,27.167947387695317 +sdd,True,False,64,16,8,8,8,2,153.704248046875 +sdd,True,False,64,16,16,16,4,2,80.04147338867188 +sdd,True,False,64,16,32,8,4,8,48.03604431152344 +sdd,True,False,64,16,64,8,16,8,33.970159912109374 +sdd,True,False,64,16,128,8,16,8,27.25079040527344 
+sdd,True,False,64,32,8,8,4,2,181.22647705078128 +sdd,True,False,64,32,16,16,16,2,93.31394653320312 +sdd,True,False,64,32,32,16,8,2,52.166839599609375 +sdd,True,False,64,32,64,16,4,4,34.93907165527344 +sdd,True,False,64,32,128,8,8,8,27.719146728515625 +sdd,True,False,64,64,8,2,2,2,inf +sdd,True,False,64,64,16,2,2,2,inf +sdd,True,False,64,64,32,2,2,2,inf +sdd,True,False,64,64,64,2,2,2,inf +sdd,True,False,64,64,128,2,2,2,inf +sdd,True,False,64,128,8,2,2,2,inf +sdd,True,False,64,128,16,2,2,2,inf +sdd,True,False,64,128,32,2,2,2,inf +sdd,True,False,64,128,64,2,2,2,inf +sdd,True,False,64,128,128,2,2,2,inf +sdd,True,False,128,8,8,2,2,2,inf +sdd,True,False,128,8,16,8,2,8,131.68056640625 +sdd,True,False,128,8,32,8,8,8,74.3488525390625 +sdd,True,False,128,8,64,16,2,4,46.31574401855469 +sdd,True,False,128,8,128,16,4,4,34.64796142578125 +sdd,True,False,128,16,8,16,2,2,263.7873046875 +sdd,True,False,128,16,16,8,16,4,138.96048583984376 +sdd,True,False,128,16,32,8,4,4,76.269775390625 +sdd,True,False,128,16,64,8,8,8,47.12089538574219 +sdd,True,False,128,16,128,8,2,8,33.60276489257812 +sdd,True,False,128,32,8,2,2,2,inf +sdd,True,False,128,32,16,2,2,2,inf +sdd,True,False,128,32,32,2,2,2,inf +sdd,True,False,128,32,64,2,2,2,inf +sdd,True,False,128,32,128,2,2,2,inf +sdd,True,False,128,64,8,2,2,2,inf +sdd,True,False,128,64,16,2,2,2,inf +sdd,True,False,128,64,32,2,2,2,inf +sdd,True,False,128,64,64,2,2,2,inf +sdd,True,False,128,64,128,2,2,2,inf +sdd,True,False,128,128,8,2,2,2,inf +sdd,True,False,128,128,16,2,2,2,inf +sdd,True,False,128,128,32,2,2,2,inf +sdd,True,False,128,128,64,2,2,2,inf +sdd,True,False,128,128,128,2,2,2,inf +sdd,True,True,8,8,8,2,4,2,394.5120849609375 +sdd,True,True,8,8,16,4,4,4,392.31640625 +sdd,True,True,8,8,32,8,2,8,393.6521240234375 +sdd,True,True,8,8,64,8,8,16,399.4498046875 +sdd,True,True,8,8,128,8,4,16,368.1553466796875 +sdd,True,True,8,16,8,2,2,2,211.0720947265625 +sdd,True,True,8,16,16,4,8,2,198.96678466796877 +sdd,True,True,8,16,32,2,4,8,135.47017822265624 +sdd,True,True,8,16,64,2,16,16,196.57532958984376 +sdd,True,True,8,16,128,4,2,16,212.0193115234375 +sdd,True,True,8,32,8,2,4,2,276.9596435546875 +sdd,True,True,8,32,16,2,8,2,140.61495361328124 +sdd,True,True,8,32,32,4,16,2,140.7639404296875 +sdd,True,True,8,32,64,2,16,4,132.995361328125 +sdd,True,True,8,32,128,2,8,8,139.485595703125 +sdd,True,True,8,64,8,2,8,2,388.26220703125 +sdd,True,True,8,64,16,2,8,2,188.8723876953125 +sdd,True,True,8,64,32,4,2,2,161.6533447265625 +sdd,True,True,8,64,64,4,2,2,158.7583740234375 +sdd,True,True,8,64,128,4,2,2,161.5098876953125 +sdd,True,True,8,128,8,2,2,2,inf +sdd,True,True,8,128,16,2,8,2,303.521484375 +sdd,True,True,8,128,32,2,8,2,286.98173828125 +sdd,True,True,8,128,64,2,16,2,287.9983642578125 +sdd,True,True,8,128,128,2,2,2,inf +sdd,True,True,16,8,8,4,8,2,199.2115234375 +sdd,True,True,16,8,16,2,8,8,163.0445556640625 +sdd,True,True,16,8,32,8,4,2,191.45748291015624 +sdd,True,True,16,8,64,2,8,16,174.1875244140625 +sdd,True,True,16,8,128,16,2,4,181.5934936523437 +sdd,True,True,16,16,8,2,2,2,109.806982421875 +sdd,True,True,16,16,16,2,8,2,100.20831298828124 +sdd,True,True,16,16,32,2,2,8,90.9896728515625 +sdd,True,True,16,16,64,4,4,16,98.95485229492188 +sdd,True,True,16,16,128,16,4,8,101.83095092773438 +sdd,True,True,16,32,8,2,2,2,144.5086181640625 +sdd,True,True,16,32,16,4,2,2,88.58931274414063 +sdd,True,True,16,32,32,8,16,2,71.0134765625 +sdd,True,True,16,32,64,4,2,2,71.24826049804688 +sdd,True,True,16,32,128,4,2,4,58.90867309570312 +sdd,True,True,16,64,8,2,8,2,174.74434814453124 
+sdd,True,True,16,64,16,4,2,2,119.31842041015624 +sdd,True,True,16,64,32,8,4,2,97.280615234375 +sdd,True,True,16,64,64,8,8,2,92.27171630859377 +sdd,True,True,16,64,128,8,16,2,90.8737548828125 +sdd,True,True,16,128,8,2,4,2,246.051220703125 +sdd,True,True,16,128,16,4,2,2,185.49586181640623 +sdd,True,True,16,128,32,4,2,2,157.97952880859376 +sdd,True,True,16,128,64,4,2,2,155.3543212890625 +sdd,True,True,16,128,128,2,2,2,inf +sdd,True,True,32,8,8,8,8,2,176.8541015625 +sdd,True,True,32,8,16,8,2,2,91.57816162109376 +sdd,True,True,32,8,32,2,2,16,84.89727783203125 +sdd,True,True,32,8,64,4,8,16,94.82096557617189 +sdd,True,True,32,8,128,16,8,4,46.93627014160156 +sdd,True,True,32,16,8,4,16,2,111.23956298828124 +sdd,True,True,32,16,16,8,4,2,62.79331665039062 +sdd,True,True,32,16,32,4,8,8,47.13816528320312 +sdd,True,True,32,16,64,4,2,8,46.6570068359375 +sdd,True,True,32,16,128,4,2,8,47.15612182617188 +sdd,True,True,32,32,8,4,4,2,144.31630859375 +sdd,True,True,32,32,16,8,4,2,80.0415771484375 +sdd,True,True,32,32,32,16,4,2,54.29685668945312 +sdd,True,True,32,32,64,8,8,4,42.0193115234375 +sdd,True,True,32,32,128,8,4,4,39.480831909179685 +sdd,True,True,32,64,8,4,4,2,158.8496337890625 +sdd,True,True,32,64,16,8,4,2,97.16746215820312 +sdd,True,True,32,64,32,8,4,2,70.76669311523438 +sdd,True,True,32,64,64,4,16,4,60.86021118164062 +sdd,True,True,32,64,128,8,4,4,57.81491088867188 +sdd,True,True,32,128,8,2,2,2,inf +sdd,True,True,32,128,16,2,2,2,inf +sdd,True,True,32,128,32,2,2,2,inf +sdd,True,True,32,128,64,2,2,2,inf +sdd,True,True,32,128,128,2,2,2,inf +sdd,True,True,64,8,8,8,4,4,271.0641357421875 +sdd,True,True,64,8,16,8,8,4,79.34187622070313 +sdd,True,True,64,8,32,8,2,8,52.03404541015625 +sdd,True,True,64,8,64,4,4,8,48.49459228515625 +sdd,True,True,64,8,128,4,8,16,44.99393615722656 +sdd,True,True,64,16,8,8,8,2,156.43052978515624 +sdd,True,True,64,16,16,16,8,2,82.55426635742188 +sdd,True,True,64,16,32,8,4,8,50.67550659179688 +sdd,True,True,64,16,64,8,16,8,36.721435546875 +sdd,True,True,64,16,128,8,16,8,29.921484375 +sdd,True,True,64,32,8,8,8,2,186.93447265625 +sdd,True,True,64,32,16,16,16,2,98.87877197265624 +sdd,True,True,64,32,32,8,2,4,58.901708984375 +sdd,True,True,64,32,64,8,8,4,40.33217468261719 +sdd,True,True,64,32,128,8,2,8,33.15773315429688 +sdd,True,True,64,64,8,2,2,2,inf +sdd,True,True,64,64,16,2,2,2,inf +sdd,True,True,64,64,32,2,2,2,inf +sdd,True,True,64,64,64,2,2,2,inf +sdd,True,True,64,64,128,2,2,2,inf +sdd,True,True,64,128,8,2,2,2,inf +sdd,True,True,64,128,16,2,2,2,inf +sdd,True,True,64,128,32,2,2,2,inf +sdd,True,True,64,128,64,2,2,2,inf +sdd,True,True,64,128,128,2,2,2,inf +sdd,True,True,128,8,8,2,2,2,inf +sdd,True,True,128,8,16,16,2,4,132.443115234375 +sdd,True,True,128,8,32,8,4,8,74.8948486328125 +sdd,True,True,128,8,64,16,2,4,47.04788513183594 +sdd,True,True,128,8,128,16,4,4,34.53890686035156 +sdd,True,True,128,16,8,16,2,2,265.3760498046875 +sdd,True,True,128,16,16,8,16,4,139.68486328125 +sdd,True,True,128,16,32,8,4,4,77.770751953125 +sdd,True,True,128,16,64,8,2,8,48.09429626464844 +sdd,True,True,128,16,128,8,2,8,34.45206909179687 +sdd,True,True,128,32,8,2,2,2,inf +sdd,True,True,128,32,16,2,2,2,inf +sdd,True,True,128,32,32,2,2,2,inf +sdd,True,True,128,32,64,2,2,2,inf +sdd,True,True,128,32,128,2,2,2,inf +sdd,True,True,128,64,8,2,2,2,inf +sdd,True,True,128,64,16,2,2,2,inf +sdd,True,True,128,64,32,2,2,2,inf +sdd,True,True,128,64,64,2,2,2,inf +sdd,True,True,128,64,128,2,2,2,inf +sdd,True,True,128,128,8,2,2,2,inf +sdd,True,True,128,128,16,2,2,2,inf +sdd,True,True,128,128,32,2,2,2,inf 
+sdd,True,True,128,128,64,2,2,2,inf +sdd,True,True,128,128,128,2,2,2,inf diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.61.csv b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.61.csv new file mode 100644 index 00000000..58dcfcd7 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.61.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,1,8.167164611816407 +8,16,1,4.421769714355468 +8,32,1,2.800918388366699 +8,64,1,2.9330944061279296 +8,128,1,3.078156852722168 +16,8,1,8.232147216796875 +16,16,1,4.562102508544922 +16,32,1,2.7749536514282225 +16,64,1,2.9247968673706053 +16,128,1,2.995363235473633 +32,8,1,8.344572448730469 +32,16,1,4.420640182495117 +32,32,1,2.753011131286621 +32,64,1,2.8409088134765623 +32,128,1,2.924198341369629 +64,8,1,8.27704315185547 +64,16,1,4.348825454711914 +64,32,1,2.731558418273926 +64,64,1,2.8246944427490237 +64,128,1,3.0188703536987305 +128,8,1,8.223318481445313 +128,16,1,4.277807998657226 +128,32,16,2.7176639556884767 +128,64,1,2.8792800903320312 +128,128,1,3.009411239624024 diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.61.csv b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.61.csv new file mode 100644 index 00000000..f143e832 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.61.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,4,2.994144058227539 +8,16,8,2.9954687118530274 +8,32,8,2.824659156799316 +8,64,8,2.9168575286865237 +8,128,8,3.026012802124024 +16,8,4,3.0000864028930665 +16,16,16,2.986384010314941 +16,32,8,2.817740821838379 +16,64,16,2.9176799774169924 +16,128,16,2.981760025024414 +32,8,4,3.01429443359375 +32,16,16,2.968582344055176 +32,32,8,2.814208030700684 +32,64,8,2.889743995666504 +32,128,16,2.94454402923584 +64,8,4,3.002393531799316 +64,16,16,2.9718015670776365 +64,32,8,2.797088050842285 +64,64,16,2.8678495407104494 +64,128,16,2.9687328338623047 +128,8,4,3.0193119049072266 +128,16,16,2.948438453674316 +128,32,8,2.786729621887207 +128,64,16,2.8843936920166016 +128,128,8,2.9729888916015623 diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 05e0bbdc..18da4f15 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -26,7 +26,7 @@ def _get_sparta_matmul_lut(): return pd.read_csv(io.StringIO(lut_text)) -SPARTA_LUT = _get_sparta_matmul_lut() +SPARTA_MATMUL_LUT = _get_sparta_matmul_lut() class SparseMatMulKernel(KernelBase): @@ -55,10 +55,10 @@ def __init__( self._sparse_block_H = '' self._sparse_block_W = '' self._tesa_vars = [] - mode_filter = SPARTA_LUT['mode'] == self._mode - trans_A_filter = SPARTA_LUT['trans_A'] == self._transpose_A - trans_B_filter = SPARTA_LUT['trans_B'] == self._transpose_B - self._lut = SPARTA_LUT[mode_filter & trans_A_filter & trans_B_filter] + mode_filter = SPARTA_MATMUL_LUT['mode'] == self._mode + trans_A_filter = SPARTA_MATMUL_LUT['trans_A'] == self._transpose_A + trans_B_filter = SPARTA_MATMUL_LUT['trans_B'] == self._transpose_B + self._lut = SPARTA_MATMUL_LUT[mode_filter & trans_A_filter & trans_B_filter] super().__init__() def _set_ports(self): @@ -248,7 +248,7 @@ def _add_parameters(self): self._add_parameter( f'BLOCK_SIZE_{dim}_VALUE', is_tunable=True, - search_space=TunableItemCfg('choice', [16, 32, 64]) + search_space=TunableItemCfg('choice', [8, 16, 32, 64]) ) self._add_parameter( f'THREAD_SIZE_{dim}_VALUE', @@ -268,6 +268,7 @@ def set_parameters(self, params: Dict[str, Any]): 
BN_filter = self._lut['BN'] == BN row = self._lut[BM_filter & BK_filter & BN_filter] assert len(row) > 0, f'block shape ({BM}, {BK}, {BN}) not found in LUT' + assert float(row['latency']) < float('inf'), f'block shape ({BM}, {BK}, {BN}) is invalid' row = row.reset_index(drop=True).iloc[0, :] TM, TK, TN = row['TM'], row['TK'], row['TN'] self.set_parameter('THREAD_SIZE_M_VALUE', int(TM)) diff --git a/sparta/specializer/kernels/softmax.py b/sparta/specializer/kernels/softmax.py index 759f3a4b..1181e930 100644 --- a/sparta/specializer/kernels/softmax.py +++ b/sparta/specializer/kernels/softmax.py @@ -1,19 +1,43 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. +import io from typing import Any, Dict, Tuple import importlib.resources as res import torch import jinja2 import numpy as np +import pandas as pd from sparta.tuning import TunableItemCfg -from sparta.specializer.kernels import templates +from sparta.specializer.kernels import templates, look_up_tables from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference +def _get_sparta_softmax_lut(): + major, minor = torch.cuda.get_device_capability() + try: + forward_lut_file = f'softmax.forward.sparta.{major}{minor}.csv' + forward_lut_text = res.read_text(look_up_tables, forward_lut_file) + except FileNotFoundError: + forward_lut_file = f'softmax.backward.sparta.default.csv' + forward_lut_text = res.read_text(look_up_tables, forward_lut_file) + forward_lut = pd.read_csv(io.StringIO(forward_lut_text)) + try: + backward_lut_file = f'softmax.backward.sparta.{major}{minor}.csv' + backward_lut_text = res.read_text(look_up_tables, backward_lut_file) + except FileNotFoundError: + backward_lut_file = f'softmax.backward.sparta.default.csv' + backward_lut_text = res.read_text(look_up_tables, backward_lut_file) + backward_lut = pd.read_csv(io.StringIO(backward_lut_text)) + return forward_lut, backward_lut + + +SPARTA_SOFTMAX_FORWARD_LUT, SPARTA_SOFTMAX_BACKWARD_LUT = _get_sparta_softmax_lut() + + class SparseSoftmaxKernel(KernelBase): __algo__: str = '' @@ -161,6 +185,10 @@ class SparTASoftmaxKernel(SparseSoftmaxKernel): __algo__ = 'sparta' + def __init__(self, compressed: bool = False, dtype: str = 'float'): + super().__init__(compressed, dtype) + self._lut: pd.DataFrame = None + def _add_parameters(self): super()._add_parameters() self._add_parameter( @@ -175,24 +203,39 @@ def _add_parameters(self): ) self._add_parameter( 'ROW_TILE_VALUE', - is_tunable=True, - search_space=TunableItemCfg('choice', [1, 2, 4, 8, 16]) ) + def set_parameters(self, params: Dict[str, Any]): + super().set_parameters(params) + if 'ROW_TILE_VALUE' in params: + return + BH, BW = self.get_block_shape() + BH_filter = self._lut['BH'] == BH + BW_filter = self._lut['BW'] == BW + row = self._lut[BH_filter & BW_filter] + assert len(row) > 0, f'block shape ({BH}, {BW}) not found in LUT' + assert row['latency'] < float('inf'), f'block shape ({BH}, {BW}) is invalid' + row = row.reset_index(drop=True).iloc[0, :] + self.set_parameter('ROW_TILE_VALUE', int(row['RT'])) + def blocks_per_grid(self): batch_size, H, W = self.get_shape() RT = self.get_parameter('ROW_TILE_VALUE') return (H // RT, batch_size, 1) def threads_per_block(self) -> Tuple[int]: + BW = self.get_parameter('BLOCK_SIZE_W_VALUE') RT = self.get_parameter('ROW_TILE_VALUE') - return (RT * 32, 1, 1) + return (RT * min(BW, 32), 1, 1) def _check_parameters(self, params: Dict[str, Any]): BH = 
params['BLOCK_SIZE_H_VALUE'] BW = params['BLOCK_SIZE_W_VALUE'] - RT = params['ROW_TILE_VALUE'] - assert BH > RT + assert BH & (BH - 1) == 0 + assert BW & (BW - 1) == 0 + if 'ROW_TILE_VALUE' in params: + RT = params['ROW_TILE_VALUE'] + assert BH >= RT def get_kernel_code(self): template_file = f'{self.__algo__}_sparse_softmax_{self.__direction__}.cuh.j2' @@ -202,9 +245,13 @@ def get_kernel_code(self): class SparTASparseSoftmaxForwardKernel(SparseSoftmaxForwardKernel, SparTASoftmaxKernel): - pass + def __init__(self, compressed: bool = False, dtype: str = 'float'): + super().__init__(compressed, dtype) + self._lut = SPARTA_SOFTMAX_FORWARD_LUT class SparTASparseSoftmaxBackwardKernel(SparseSoftmaxBackwardKernel, SparTASoftmaxKernel): - pass + def __init__(self, compressed: bool = False, dtype: str = 'float'): + super().__init__(compressed, dtype) + self._lut = SPARTA_SOFTMAX_BACKWARD_LUT diff --git a/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 index a51833e1..c4b5b8d8 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 @@ -1,11 +1,9 @@ {# Copyright (c) Microsoft Corporation. #} {# Licensed under the MIT license. #} -{% if BLOCK_SIZE_W_VALUE < 32 %} -#define FULL_MASK 0x{% for _ in range(BLOCK_SIZE_W_VALUE // 4) %}f{% endfor %} -{% else %} -#define FULL_MASK 0xffffffff -{% endif %} +{% set WARP_SIZE = [32, BLOCK_SIZE_W_VALUE]|min %} +{% set INI_OFFSET = WARP_SIZE // 2 %} +#define FULL_MASK 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} const int H = {{ GLOBAL_H_VALUE }}; const int W = {{ GLOBAL_W_VALUE }}; @@ -35,15 +33,15 @@ __global__ void SPARSE_SOFTMAX( uint blk_row_idx = blockIdx.x / (block_h/row_tile) ; int block_inter_row = (blockIdx.x % (block_h/row_tile)) * row_tile; - uint bm = threadIdx.x / 32; - uint bn = threadIdx.x % 32; + uint bm = threadIdx.x / {{ WARP_SIZE }}; + uint bn = threadIdx.x % {{ WARP_SIZE }}; float regSum = 0.0f; int block_seq_start = row_ptr[blk_row_idx]; int block_seq_end = row_ptr[blk_row_idx+1]; - uint index_list[W / 32]; + uint index_list[W / {{ WARP_SIZE }}]; int val_num = 0; - for (int block_inter_col = bn; block_inter_col < block_w; block_inter_col += 32) { + for (int block_inter_col = bn; block_inter_col < block_w; block_inter_col += {{ WARP_SIZE }}) { for (int block_seq = block_seq_start; block_seq < block_seq_end; block_seq++) { {% if COMPRESSED %} uint index = block_h * block_w * block_seq + @@ -63,7 +61,7 @@ __global__ void SPARSE_SOFTMAX( regSum += out_val[index] * out_grad[index]; } - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = {{ INI_OFFSET }}; offset > 0; offset >>= 1) { regSum += __shfl_down_sync(FULL_MASK, regSum, offset); } regSum = __shfl_sync(FULL_MASK, regSum, 0); diff --git a/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 index 975519d4..6e5642f2 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 @@ -1,11 +1,9 @@ {# Copyright (c) Microsoft Corporation. #} {# Licensed under the MIT license. 
#} -{% if BLOCK_SIZE_W_VALUE < 32 %} -#define FULL_MASK 0x{% for _ in range(BLOCK_SIZE_W_VALUE // 4) %}f{% endfor %} -{% else %} -#define FULL_MASK 0xffffffff -{% endif %} +{% set WARP_SIZE = [32, BLOCK_SIZE_W_VALUE]|min %} +{% set INI_OFFSET = WARP_SIZE // 2 %} +#define FULL_MASK 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} const int H = {{ GLOBAL_H_VALUE }}; const int W = {{ GLOBAL_W_VALUE }}; @@ -32,16 +30,16 @@ __global__ void SPARSE_SOFTMAX( uint blk_row_idx = blockIdx.x / (block_h/row_tile) ; int block_inter_row = (blockIdx.x % (block_h/row_tile)) * row_tile; - uint bm = threadIdx.x / 32; - uint bn = threadIdx.x % 32; + uint bm = threadIdx.x / {{ WARP_SIZE }}; + uint bn = threadIdx.x % {{ WARP_SIZE }}; float regSum = 0.0f; float regMax = -100000.0; int block_seq_start = row_ptr[blk_row_idx]; int block_seq_end = row_ptr[blk_row_idx+1]; - uint index_list[W / 32]; + uint index_list[W / {{ WARP_SIZE }}]; int val_num = 0; - for (int block_inter_col = bn; block_inter_col < block_w; block_inter_col += 32) { + for (int block_inter_col = bn; block_inter_col < block_w; block_inter_col += {{ WARP_SIZE }}) { for (int block_seq = block_seq_start; block_seq < block_seq_end; block_seq++) { {% if COMPRESSED %} uint index = block_h * block_w * block_seq + @@ -65,7 +63,7 @@ __global__ void SPARSE_SOFTMAX( regMax = max(regMax, in_val[index]); } - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = {{ INI_OFFSET }}; offset > 0; offset >>= 1) { regMax = max(regMax, __shfl_down_sync(FULL_MASK, regMax, offset)); } regMax = __shfl_sync(FULL_MASK, regMax, 0); @@ -75,7 +73,7 @@ __global__ void SPARSE_SOFTMAX( regSum += expf((in_val[index] - regMax) * temperature); } - for (int offset = 16; offset > 0; offset >>= 1) { + for (int offset = {{ INI_OFFSET }}; offset > 0; offset >>= 1) { regSum += __shfl_down_sync(FULL_MASK, regSum, offset); } regSum = __shfl_sync(FULL_MASK, regSum, 0); diff --git a/sparta/tesa/block_compressed.py b/sparta/tesa/block_compressed.py index 156166e4..7f370a45 100644 --- a/sparta/tesa/block_compressed.py +++ b/sparta/tesa/block_compressed.py @@ -96,11 +96,15 @@ def get_block_mask(self): block_mask_val = torch.ones( size=(self.block_nnz, ), dtype=torch.uint8, - device=self.row_ptr.device + device=self.row_ptr.device, ) - row_ptr = self.row_ptr.to(torch.int64) - col_idx = self.BCSR_idx.bitwise_and(0xffff).to(torch.int64) - return torch.sparse_csr_tensor(row_ptr, col_idx, block_mask_val).to_dense() + col_idx = self.BCSR_idx.bitwise_and(0xffff) + return torch.sparse_csr_tensor( + crow_indices=self.row_ptr, + col_indices=col_idx, + values=block_mask_val, + size=(self.row_num, self.col_num), + ).to_dense() def convert(self, dense: torch.Tensor): return self.function_context.convert( @@ -144,10 +148,15 @@ def get_block_mask(self): block_mask_val = torch.ones( size=(self.block_nnz, ), dtype=torch.uint8, - device=self.col_ptr.device + device=self.col_ptr.device, ) row_idx = self.BCSC_idx.bitwise_and(0xffff) - return torch.sparse_csr_tensor(self.col_ptr, row_idx, block_mask_val).to_dense().T + return torch.sparse_csr_tensor( + crow_indices=self.col_ptr, + col_indices=row_idx, + values=block_mask_val, + size=(self.col_num, self.row_num), + ).to_dense().T def convert(self, dense: torch.Tensor): return self.function_context.convert( diff --git a/sparta/tesa/templates/block_compressed.cu.j2 b/sparta/tesa/templates/block_compressed.cu.j2 index 1f3ec47f..3f90b968 100644 --- a/sparta/tesa/templates/block_compressed.cu.j2 +++ b/sparta/tesa/templates/block_compressed.cu.j2 @@ 
-88,27 +88,27 @@ __global__ void bcs_index_1( uint block_offset = _pos / ({{ BW }} / {{ MASK_BATCH }}) * W + _pos % ({{ BW }} / {{ MASK_BATCH }}) * {{ MASK_BATCH }}; {% if BW == 4 %} uint data = __ldg((const uint*)(add_ptr_b(mask, global_offset + block_offset))); - flag += data; + flag = (flag || data); {% elif BW == 8 %} uint2 data = __ldg((const uint2*)(add_ptr_b(mask, global_offset + block_offset))); - flag += data.x + data.y; + flag = (flag || data.x || data.y); {% else %} uint4 data = __ldg((const uint4*)(add_ptr_b(mask, global_offset + block_offset))); - flag += data.x + data.y + data.z + data.w; + flag = (flag || data.x || data.y || data.z || data.w); {% endif %} } - reduce[tid] = flag > 0 ? 1 : 0; + reduce[tid] = flag; // Fast tree reduce accross the block __syncthreads(); for (uint s = blockDim.x >> 1; s > 32; s >>= 1) { - if (tid < s) reduce[tid] += reduce[tid + s]; + if (tid < s) reduce[tid] = (reduce[tid] || reduce[tid + s]); __syncthreads(); } if (tid < {{ [BLOCK_DIM / 2, 32]|min }}) warpReduceMask(reduce, tid); __syncthreads(); - if (tid == 0 & reduce[0] > 0) { + if (tid == 0 && reduce[0] > 0) { {% if BCSR %} // Record BCSR column index, +1 because 0 means empty int col_pos_id = atomicAdd(&extra_buffer[by], 1); diff --git a/test/lut_maker/sparta_matmul.py b/test/lut_maker/sparta_matmul.py index 47c35235..b6a468da 100644 --- a/test/lut_maker/sparta_matmul.py +++ b/test/lut_maker/sparta_matmul.py @@ -74,8 +74,7 @@ def test_sparta_matmul_kernel( return latency -if __name__ == '__main__': - _logger.setLevel(logging.DEBUG) +def make_sparta_matmul_lut(): major, minor = torch.cuda.get_device_capability() lut_file = os.path.join( 'sparta', @@ -113,8 +112,13 @@ def test_sparta_matmul_kernel( _logger.info(f'[{i} / {num}] {params} => {latency} ms') df = pd.read_csv(log_file) - df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})] + df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] with open(lut_file, 'w') as f: f.write(df.reset_index(drop=True).to_csv(index=False)) _logger.info(f'========== Finished. Output: {lut_file} ==========') + + +if __name__ == '__main__': + _logger.setLevel(logging.DEBUG) + make_sparta_matmul_lut() diff --git a/test/lut_maker/sparta_softmax.py b/test/lut_maker/sparta_softmax.py new file mode 100644 index 00000000..fc072954 --- /dev/null +++ b/test/lut_maker/sparta_softmax.py @@ -0,0 +1,122 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +import os +import logging +import itertools +from typing import Dict + +import torch +import numpy as np +import pandas as pd + +from sparta.specializer.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel +from sparta.testing import block_mask + + +SIZE = 4096 +RANDOM_SEED = 2022 +SEARCH_SPACE = { + 'BH': [8, 16, 32, 64, 128], + 'BW': [8, 16, 32, 64, 128], + 'RT': [1, 2, 4, 8, 16], +} +HYPER_PARAMS = ['BH', 'BW'] + + +_logger = logging.Logger(__name__) +_handler = logging.StreamHandler() +_logger.addHandler(_handler) + + +def test_sparta_softmax_kernel( + data: Dict, + mask: torch.Tensor, + direction: str, + BH: int, + BW: int, + RT: int, +): + if direction == 'forward': + kernel = SparTASparseSoftmaxForwardKernel(compressed=True) + elif direction == 'backward': + kernel = SparTASparseSoftmaxBackwardKernel(compressed=True) + else: + raise ValueError(f'unrecognized direction: {direction}') + + try: + kernel.ports['y'].set_mask(mask) + kernel.set_shape(1, SIZE, SIZE) + kernel.compile({ + 'BLOCK_SIZE_H_VALUE': BH, + 'BLOCK_SIZE_W_VALUE': BW, + 'ROW_TILE_VALUE': RT, + }) + if direction == 'forward': + inputs = [data['x'], data['T']] + else: + inputs = [data['grad_y'], data['y'], data['T']] + latency = kernel.test(inputs, num_warmups=10, num_iters=10, cuda=False) + except: + latency = float('inf') + + return latency + + +def make_sparta_softmax_lut(direction: str): + major, minor = torch.cuda.get_device_capability() + lut_file = os.path.join( + 'sparta', + 'specializer', + 'kernels', + 'look_up_tables', + f'softmax.{direction}.sparta.{major}{minor}.csv' + ) + log_file = os.path.join( + 'test', + 'lut_maker', + f'softmax.{direction}.sparta.{major}{minor}.log.csv' + ) + _logger.info(f'========== Making LUT: {lut_file} ==========') + + num = 1 + keys, values = [], [] + for k, v in SEARCH_SPACE.items(): + keys.append(k) + values.append(v) + num *= len(v) + + with open(log_file, 'w') as f: + f.write(','.join(keys) + ',latency\n') + + torch.manual_seed(RANDOM_SEED) + data = {} + data['x'] = torch.rand(size=(1, SIZE, SIZE), device='cuda') + data['T'] = np.float32(1 / np.sqrt(SIZE)) + data['y'] = torch.rand(size=(1, SIZE, SIZE), device='cuda') + data['grad_y'] = torch.rand(size=(1, SIZE, SIZE), device='cuda') + mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') + + for i, params in enumerate(itertools.product(*values)): + latency = test_sparta_softmax_kernel( + data, + mask, + direction, + **{k: v for k, v in zip(keys, params)} + ) + with open(log_file, 'a') as f: + f.write(','.join([str(x) for x in params]) + f',{latency}\n') + _logger.info(f'[{i} / {num}] {params} => {latency} ms') + + df = pd.read_csv(log_file) + df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] + with open(lut_file, 'w') as f: + f.write(df.reset_index(drop=True).to_csv(index=False)) + + _logger.info(f'========== Finished. 
Output: {lut_file} ==========') + + +if __name__ == '__main__': + _logger.setLevel(logging.DEBUG) + make_sparta_softmax_lut('forward') + make_sparta_softmax_lut('backward') diff --git a/test/unit/test_tesa.py b/test/unit/test_tesa.py index f01fb595..ffd4f2b7 100644 --- a/test/unit/test_tesa.py +++ b/test/unit/test_tesa.py @@ -17,15 +17,16 @@ def reduce_mask_ref(mask: torch.Tensor, BH: int, BW: int): return reduced +@pytest.mark.parametrize("sparsity", [0, 0.999]) @pytest.mark.parametrize("BH", [4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("BW", [4, 8, 16, 32, 64, 128]) def test_bcsr( + sparsity: float, BH: int, BW: int, H: int = 1024, W: int = 768, batch_size: int = 2, - sparsity: float = 0.999, random_seed: int = 2023, ): torch.manual_seed(random_seed) @@ -47,15 +48,16 @@ def test_bcsr( torch.testing.assert_close(BCSR_indexes.sum(sparse_val, axis=-2), dense.sum(-2)) +@pytest.mark.parametrize("sparsity", [0, 0.999]) @pytest.mark.parametrize("BH", [4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("BW", [4, 8, 16, 32, 64, 128]) def test_bcsc( + sparsity: float, BH: int, BW: int, H: int = 1024, W: int = 768, batch_size: int = 2, - sparsity: float = 0.999, random_seed: int = 2023, ): torch.manual_seed(random_seed) @@ -77,15 +79,16 @@ def test_bcsc( torch.testing.assert_close(BCSC_indexes.sum(sparse_val, axis=-2), dense.sum(-2)) +@pytest.mark.parametrize("sparsity", [0, 0.999]) @pytest.mark.parametrize("BH", [4, 8, 16, 32, 64, 128]) @pytest.mark.parametrize("BW", [4, 8, 16, 32, 64, 128]) def test_bcsrc( + sparsity: float, BH: int, BW: int, H: int = 1024, W: int = 768, batch_size: int = 2, - sparsity: float = 0.999, random_seed: int = 2023, ): torch.manual_seed(random_seed) diff --git a/test/unit/test_tuners.py b/test/unit/test_tuners.py index f94bbf55..8047acb7 100644 --- a/test/unit/test_tuners.py +++ b/test/unit/test_tuners.py @@ -55,10 +55,10 @@ def test_tune_sparse_linear( sparse_linear = SparseLinear(dense_linear, weight_mask=mask) if backward: count = debug_tune(sparse_linear, [sample_input], [sample_grad]) - assert count == (3 * 3 * 3 + 1) * 3 + assert count == (4 * 4 * 4 + 1) * 3 else: count = debug_tune(sparse_linear, [sample_input], None) - assert count == (3 * 3 * 3 + 1) * 1 + assert count == (4 * 4 * 4 + 1) * 1 @pytest.mark.parametrize("backward", [False, True]) @@ -101,10 +101,10 @@ def test_tune_sparse_matmul( sparse_matmul = SparseBatchMatMul(**matmul_args) if backward: count = debug_tune(sparse_matmul, [A, B], [grad_C]) - assert count == (3 * 3 * 3 + 1) * 3 + assert count == (4 * 4 * 4 + 1) * 3 else: count = debug_tune(sparse_matmul, [A, B], None) - assert count == (3 * 3 * 3 + 1) * 1 + assert count == (4 * 4 * 4 + 1) * 1 @pytest.mark.parametrize("backward", [False, True]) @@ -124,10 +124,10 @@ def test_tune_sparse_softmax( sparse_softmax = SparseSoftmax(mask=mask, temperature=dims, compressed=compressed) if backward: count = debug_tune(sparse_softmax, [sample_input], [sample_grad]) - assert count == (5 * 4 * 5) * 2 + assert count == (5 * 4) * 2 else: count = debug_tune(sparse_softmax, [sample_input], None) - assert count == (5 * 4 * 5) * 1 + assert count == (5 * 4) * 1 @pytest.mark.parametrize("backward", [False, True]) @@ -148,7 +148,7 @@ def test_tune_sparse_attention( sparse_attention = SparseAttention(mask=mask) if backward: count = debug_tune(sparse_attention, [query, key, value], [grad_out]) - assert count == (3 * 3 * 3 + 1) * 6 + (3 * 3 * 5) * 2 + assert count == (3 * 4 * 4 + 1) * 6 + (3 * 4) * 2 else: count = debug_tune(sparse_attention, [query, key, value], None) - 
assert count == (3 * 3 * 3 + 1) * 2 + (3 * 3 * 5) * 1 + assert count == (3 * 4 * 4 + 1) * 2 + (3 * 4) * 1 From e1cc4cd5148872de916da3a10d5ff6ccc44c35ef Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Mon, 6 Feb 2023 05:03:26 +0000 Subject: [PATCH 09/28] add 70 LUTs (as default) --- .../look_up_tables/matmul.sparta.70.csv | 1824 ++++++++++++++--- .../look_up_tables/matmul.sparta.default.csv | 1824 ++++++++++++++--- .../softmax.backward.sparta.70.csv | 26 + .../softmax.backward.sparta.default.csv | 26 + .../softmax.forward.sparta.70.csv | 26 + .../softmax.forward.sparta.default.csv | 26 + 6 files changed, 3104 insertions(+), 648 deletions(-) create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.70.csv create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.default.csv create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.70.csv create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.default.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv b/sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv index 7a628265..a42e7ef2 100644 --- a/sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv +++ b/sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv @@ -1,325 +1,1501 @@ mode,trans_A,trans_B,BM,BK,BN,TM,TK,TN,latency -dds,False,False,16,16,16,4,4,4,114.3322 -dds,False,False,16,16,32,4,4,4,82.87 -dds,False,False,16,16,64,4,4,4,74.2951 -dds,False,False,16,32,16,4,4,4,120.1529 -dds,False,False,16,32,32,4,4,4,87.5296 -dds,False,False,16,32,64,4,4,4,79.7643 -dds,False,False,16,64,16,4,4,4,148.3481 -dds,False,False,16,64,32,4,4,4,92.0502 -dds,False,False,16,64,64,4,4,4,84.0627 -dds,False,False,32,16,16,4,4,4,77.7297 -dds,False,False,32,16,32,4,4,4,54.2717 -dds,False,False,32,16,64,4,4,4,52.089 -dds,False,False,32,32,16,4,4,4,83.1408 -dds,False,False,32,32,32,4,4,4,59.8523 -dds,False,False,32,32,64,4,4,4,54.2046 -dds,False,False,32,64,16,4,4,4,98.1477 -dds,False,False,32,64,32,4,4,4,69.5779 -dds,False,False,32,64,64,4,4,4,56.5643 -dds,False,False,64,16,16,4,4,4,64.2577 -dds,False,False,64,16,32,4,4,4,50.0309 -dds,False,False,64,16,64,4,4,4,50.0309 -dds,False,False,64,32,16,4,4,4,75.437 -dds,False,False,64,32,32,4,4,4,54.0036 -dds,False,False,64,32,64,4,4,4,54.0036 -dds,False,False,64,64,16,4,4,4,88.9933 -dds,False,False,64,64,32,4,4,4,56.6697 -dds,False,False,64,64,64,4,4,4,56.6697 -dds,False,True,16,16,16,4,4,4,134.6172 -dds,False,True,16,16,32,4,4,4,106.9163 -dds,False,True,16,16,64,4,4,4,101.0877 -dds,False,True,16,32,16,4,4,4,177.8903 -dds,False,True,16,32,32,4,4,4,96.7729 -dds,False,True,16,32,64,4,4,4,89.2383 -dds,False,True,16,64,16,4,4,4,281.8069 -dds,False,True,16,64,32,4,4,4,137.432 -dds,False,True,16,64,64,4,4,4,132.7488 -dds,False,True,32,16,16,4,4,4,83.1024 -dds,False,True,32,16,32,4,4,4,65.4798 -dds,False,True,32,16,64,4,4,4,59.4675 -dds,False,True,32,32,16,4,4,4,93.876 -dds,False,True,32,32,32,4,4,4,72.5978 -dds,False,True,32,32,64,4,4,4,65.1659 -dds,False,True,32,64,16,4,4,4,120.1478 -dds,False,True,32,64,32,4,4,4,98.9469 -dds,False,True,32,64,64,4,4,4,83.0436 -dds,False,True,64,16,16,4,4,4,65.8737 -dds,False,True,64,16,32,4,4,4,53.26 -dds,False,True,64,16,64,4,4,4,53.26 -dds,False,True,64,32,16,4,4,4,81.8998 -dds,False,True,64,32,32,4,4,4,60.1544 -dds,False,True,64,32,64,4,4,4,60.1544 -dds,False,True,64,64,16,4,4,4,99.884 -dds,False,True,64,64,32,4,4,4,70.5451 -dds,False,True,64,64,64,4,4,4,70.5451 
-dds,True,False,16,16,16,4,4,4,135.2541 -dds,True,False,16,16,32,4,4,4,88.3433 -dds,True,False,16,16,64,4,4,4,80.7328 -dds,True,False,16,32,16,4,4,4,147.7196 -dds,True,False,16,32,32,4,4,4,83.4137 -dds,True,False,16,32,64,4,4,4,80.5606 -dds,True,False,16,64,16,4,4,4,178.6693 -dds,True,False,16,64,32,4,4,4,96.307 -dds,True,False,16,64,64,4,4,4,85.6432 -dds,True,False,32,16,16,4,4,4,97.407 -dds,True,False,32,16,32,4,4,4,64.0028 -dds,True,False,32,16,64,4,4,4,56.0378 -dds,True,False,32,32,16,4,4,4,112.9191 -dds,True,False,32,32,32,4,4,4,71.4564 -dds,True,False,32,32,64,4,4,4,59.7966 -dds,True,False,32,64,16,4,4,4,124.0738 -dds,True,False,32,64,32,4,4,4,80.3177 -dds,True,False,32,64,64,4,4,4,61.5147 -dds,True,False,64,16,16,4,4,4,124.6127 -dds,True,False,64,16,32,4,4,4,76.9557 -dds,True,False,64,16,64,4,4,4,76.9557 -dds,True,False,64,32,16,4,4,4,143.3676 -dds,True,False,64,32,32,4,4,4,84.2901 -dds,True,False,64,32,64,4,4,4,84.2901 -dds,True,False,64,64,16,4,4,4,150.3928 -dds,True,False,64,64,32,4,4,4,87.5746 -dds,True,False,64,64,64,4,4,4,87.5746 -dds,True,True,16,16,16,4,4,4,163.2706 -dds,True,True,16,16,32,4,4,4,111.4495 -dds,True,True,16,16,64,4,4,4,105.2302 -dds,True,True,16,32,16,4,4,4,215.4126 -dds,True,True,16,32,32,4,4,4,103.8571 -dds,True,True,16,32,64,4,4,4,92.4045 -dds,True,True,16,64,16,4,4,4,310.5157 -dds,True,True,16,64,32,4,4,4,149.429 -dds,True,True,16,64,64,4,4,4,134.981 -dds,True,True,32,16,16,4,4,4,103.7712 -dds,True,True,32,16,32,4,4,4,70.4374 -dds,True,True,32,16,64,4,4,4,63.2451 -dds,True,True,32,32,16,4,4,4,127.7935 -dds,True,True,32,32,32,4,4,4,85.6814 -dds,True,True,32,32,64,4,4,4,69.7582 -dds,True,True,32,64,16,4,4,4,155.0835 -dds,True,True,32,64,32,4,4,4,112.8103 -dds,True,True,32,64,64,4,4,4,91.3851 -dds,True,True,64,16,16,4,4,4,126.8568 -dds,True,True,64,16,32,4,4,4,81.2913 -dds,True,True,64,16,64,4,4,4,81.2913 -dds,True,True,64,32,16,4,4,4,149.8753 -dds,True,True,64,32,32,4,4,4,91.8136 -dds,True,True,64,32,64,4,4,4,91.8136 -dds,True,True,64,64,16,4,4,4,167.0518 -dds,True,True,64,64,32,4,4,4,105.5301 -dds,True,True,64,64,64,4,4,4,105.5301 -dsd,False,False,16,16,16,4,4,4,108.1017 -dsd,False,False,16,16,32,4,4,4,87.8983 -dsd,False,False,16,16,64,4,4,4,80.2562 -dsd,False,False,16,32,16,4,4,4,111.9973 -dsd,False,False,16,32,32,4,4,4,83.6637 -dsd,False,False,16,32,64,4,4,4,83.4532 -dsd,False,False,16,64,16,4,4,4,135.7782 -dsd,False,False,16,64,32,4,4,4,96.0233 -dsd,False,False,16,64,64,4,4,4,96.0233 -dsd,False,False,32,16,16,4,4,4,72.3478 -dsd,False,False,32,16,32,4,4,4,57.4545 -dsd,False,False,32,16,64,4,4,4,52.2037 -dsd,False,False,32,32,16,4,4,4,78.8152 -dsd,False,False,32,32,32,4,4,4,59.3432 -dsd,False,False,32,32,64,4,4,4,54.5784 -dsd,False,False,32,64,16,4,4,4,91.2725 -dsd,False,False,32,64,32,4,4,4,67.9594 -dsd,False,False,32,64,64,4,4,4,67.9594 -dsd,False,False,64,16,16,4,4,4,60.3805 -dsd,False,False,64,16,32,4,4,4,49.5804 -dsd,False,False,64,16,64,4,4,4,46.6983 -dsd,False,False,64,32,16,4,4,4,73.6614 -dsd,False,False,64,32,32,4,4,4,53.7312 -dsd,False,False,64,32,64,4,4,4,46.5469 -dsd,False,False,64,64,16,4,4,4,85.0814 -dsd,False,False,64,64,32,4,4,4,56.0396 -dsd,False,False,64,64,64,4,4,4,56.0396 -dsd,False,True,16,16,16,4,4,4,130.935 -dsd,False,True,16,16,32,4,4,4,109.5658 -dsd,False,True,16,16,64,4,4,4,92.5529 -dsd,False,True,16,32,16,4,4,4,178.6903 -dsd,False,True,16,32,32,4,4,4,101.6086 -dsd,False,True,16,32,64,4,4,4,96.7423 -dsd,False,True,16,64,16,4,4,4,278.8124 -dsd,False,True,16,64,32,4,4,4,142.8228 -dsd,False,True,16,64,64,4,4,4,142.8228 
-dsd,False,True,32,16,16,4,4,4,87.9781 -dsd,False,True,32,16,32,4,4,4,66.9843 -dsd,False,True,32,16,64,4,4,4,59.9348 -dsd,False,True,32,32,16,4,4,4,94.2587 -dsd,False,True,32,32,32,4,4,4,73.2396 -dsd,False,True,32,32,64,4,4,4,65.7842 -dsd,False,True,32,64,16,4,4,4,124.7007 -dsd,False,True,32,64,32,4,4,4,98.0345 -dsd,False,True,32,64,64,4,4,4,98.0345 -dsd,False,True,64,16,16,4,4,4,68.01 -dsd,False,True,64,16,32,4,4,4,54.7027 -dsd,False,True,64,16,64,4,4,4,49.913 -dsd,False,True,64,32,16,4,4,4,82.1206 -dsd,False,True,64,32,32,4,4,4,61.7535 -dsd,False,True,64,32,64,4,4,4,53.0829 -dsd,False,True,64,64,16,4,4,4,101.6558 -dsd,False,True,64,64,32,4,4,4,70.1075 -dsd,False,True,64,64,64,4,4,4,70.1075 -dsd,True,False,16,16,16,4,4,4,128.4223 -dsd,True,False,16,16,32,4,4,4,90.0412 -dsd,True,False,16,16,64,4,4,4,83.3222 -dsd,True,False,16,32,16,4,4,4,142.1878 -dsd,True,False,16,32,32,4,4,4,86.9283 -dsd,True,False,16,32,64,4,4,4,84.5636 -dsd,True,False,16,64,16,4,4,4,163.221 -dsd,True,False,16,64,32,4,4,4,99.8338 -dsd,True,False,16,64,64,4,4,4,99.8338 -dsd,True,False,32,16,16,4,4,4,93.9119 -dsd,True,False,32,16,32,4,4,4,62.1346 -dsd,True,False,32,16,64,4,4,4,56.0104 -dsd,True,False,32,32,16,4,4,4,111.0007 -dsd,True,False,32,32,32,4,4,4,70.8843 -dsd,True,False,32,32,64,4,4,4,59.1251 -dsd,True,False,32,64,16,4,4,4,119.8824 -dsd,True,False,32,64,32,4,4,4,80.5292 -dsd,True,False,32,64,64,4,4,4,80.5292 -dsd,True,False,64,16,16,4,4,4,123.8651 -dsd,True,False,64,16,32,4,4,4,76.5209 -dsd,True,False,64,16,64,4,4,4,60.1391 -dsd,True,False,64,32,16,4,4,4,141.5866 -dsd,True,False,64,32,32,4,4,4,83.946 -dsd,True,False,64,32,64,4,4,4,60.4473 -dsd,True,False,64,64,16,4,4,4,148.974 -dsd,True,False,64,64,32,4,4,4,87.7046 -dsd,True,False,64,64,64,4,4,4,87.7046 -dsd,True,True,16,16,16,4,4,4,159.7035 -dsd,True,True,16,16,32,4,4,4,112.7343 -dsd,True,True,16,16,64,4,4,4,96.109 -dsd,True,True,16,32,16,4,4,4,216.4289 -dsd,True,True,16,32,32,4,4,4,113.0748 -dsd,True,True,16,32,64,4,4,4,102.4129 -dsd,True,True,16,64,16,4,4,4,311.4233 -dsd,True,True,16,64,32,4,4,4,152.9842 -dsd,True,True,16,64,64,4,4,4,152.9842 -dsd,True,True,32,16,16,4,4,4,101.2226 -dsd,True,True,32,16,32,4,4,4,73.1764 -dsd,True,True,32,16,64,4,4,4,62.08 -dsd,True,True,32,32,16,4,4,4,128.2265 -dsd,True,True,32,32,32,4,4,4,85.9566 -dsd,True,True,32,32,64,4,4,4,71.4414 -dsd,True,True,32,64,16,4,4,4,156.3029 -dsd,True,True,32,64,32,4,4,4,112.7689 -dsd,True,True,32,64,64,4,4,4,112.7689 -dsd,True,True,64,16,16,4,4,4,127.2963 -dsd,True,True,64,16,32,4,4,4,79.6703 -dsd,True,True,64,16,64,4,4,4,61.7858 -dsd,True,True,64,32,16,4,4,4,149.5118 -dsd,True,True,64,32,32,4,4,4,91.4158 -dsd,True,True,64,32,64,4,4,4,67.6785 -dsd,True,True,64,64,16,4,4,4,168.1443 -dsd,True,True,64,64,32,4,4,4,105.573 -dsd,True,True,64,64,64,4,4,4,105.573 -sdd,False,False,16,16,16,4,4,4,110.7007 -sdd,False,False,16,16,32,4,4,4,83.8641 -sdd,False,False,16,16,64,4,4,4,79.1289 -sdd,False,False,16,32,16,4,4,4,119.6233 -sdd,False,False,16,32,32,4,4,4,82.5102 -sdd,False,False,16,32,64,4,4,4,78.6386 -sdd,False,False,16,64,16,4,4,4,145.0991 -sdd,False,False,16,64,32,4,4,4,89.7248 -sdd,False,False,16,64,64,4,4,4,83.423 -sdd,False,False,32,16,16,4,4,4,72.6872 -sdd,False,False,32,16,32,4,4,4,54.0949 -sdd,False,False,32,16,64,4,4,4,51.6267 -sdd,False,False,32,32,16,4,4,4,82.8882 -sdd,False,False,32,32,32,4,4,4,59.9318 -sdd,False,False,32,32,64,4,4,4,54.2268 -sdd,False,False,32,64,16,4,4,4,92.1725 -sdd,False,False,32,64,32,4,4,4,67.4228 -sdd,False,False,32,64,64,4,4,4,55.8651 
-sdd,False,False,64,16,16,4,4,4,57.2198 -sdd,False,False,64,16,32,4,4,4,49.1785 -sdd,False,False,64,16,64,4,4,4,46.1422 -sdd,False,False,64,32,16,4,4,4,74.8587 -sdd,False,False,64,32,32,4,4,4,53.8419 -sdd,False,False,64,32,64,4,4,4,46.7183 -sdd,False,False,64,64,16,4,4,4,74.8587 -sdd,False,False,64,64,32,4,4,4,53.8419 -sdd,False,False,64,64,64,4,4,4,46.7183 -sdd,False,True,16,16,16,4,4,4,131.2918 -sdd,False,True,16,16,32,4,4,4,104.9111 -sdd,False,True,16,16,64,4,4,4,100.4713 -sdd,False,True,16,32,16,4,4,4,178.6799 -sdd,False,True,16,32,32,4,4,4,96.054 -sdd,False,True,16,32,64,4,4,4,88.2646 -sdd,False,True,16,64,16,4,4,4,271.7691 -sdd,False,True,16,64,32,4,4,4,138.3314 -sdd,False,True,16,64,64,4,4,4,132.6129 -sdd,False,True,32,16,16,4,4,4,77.3388 -sdd,False,True,32,16,32,4,4,4,59.9648 -sdd,False,True,32,16,64,4,4,4,60.089 -sdd,False,True,32,32,16,4,4,4,94.166 -sdd,False,True,32,32,32,4,4,4,72.2477 -sdd,False,True,32,32,64,4,4,4,64.8427 -sdd,False,True,32,64,16,4,4,4,120.019 -sdd,False,True,32,64,32,4,4,4,97.3089 -sdd,False,True,32,64,64,4,4,4,83.1809 -sdd,False,True,64,16,16,4,4,4,59.4227 -sdd,False,True,64,16,32,4,4,4,52.1191 -sdd,False,True,64,16,64,4,4,4,49.1211 -sdd,False,True,64,32,16,4,4,4,81.5417 -sdd,False,True,64,32,32,4,4,4,59.0152 -sdd,False,True,64,32,64,4,4,4,52.0989 -sdd,False,True,64,64,16,4,4,4,81.5417 -sdd,False,True,64,64,32,4,4,4,59.0152 -sdd,False,True,64,64,64,4,4,4,52.0989 -sdd,True,False,16,16,16,4,4,4,130.6109 -sdd,True,False,16,16,32,4,4,4,84.93 -sdd,True,False,16,16,64,4,4,4,79.0723 -sdd,True,False,16,32,16,4,4,4,142.6209 -sdd,True,False,16,32,32,4,4,4,82.0366 -sdd,True,False,16,32,64,4,4,4,81.5747 -sdd,True,False,16,64,16,4,4,4,165.6981 -sdd,True,False,16,64,32,4,4,4,92.0803 -sdd,True,False,16,64,64,4,4,4,84.1962 -sdd,True,False,32,16,16,4,4,4,96.9497 -sdd,True,False,32,16,32,4,4,4,62.008 -sdd,True,False,32,16,64,4,4,4,55.4291 -sdd,True,False,32,32,16,4,4,4,113.3967 -sdd,True,False,32,32,32,4,4,4,70.2921 -sdd,True,False,32,32,64,4,4,4,58.7307 -sdd,True,False,32,64,16,4,4,4,120.7301 -sdd,True,False,32,64,32,4,4,4,78.2809 -sdd,True,False,32,64,64,4,4,4,60.8745 -sdd,True,False,64,16,16,4,4,4,125.0771 -sdd,True,False,64,16,32,4,4,4,76.4491 -sdd,True,False,64,16,64,4,4,4,60.097 -sdd,True,False,64,32,16,4,4,4,141.8331 -sdd,True,False,64,32,32,4,4,4,83.7384 -sdd,True,False,64,32,64,4,4,4,60.9805 -sdd,True,False,64,64,16,4,4,4,141.8331 -sdd,True,False,64,64,32,4,4,4,83.7384 -sdd,True,False,64,64,64,4,4,4,60.9805 -sdd,True,True,16,16,16,4,4,4,159.6109 -sdd,True,True,16,16,32,4,4,4,105.5516 -sdd,True,True,16,16,64,4,4,4,100.7452 -sdd,True,True,16,32,16,4,4,4,211.4903 -sdd,True,True,16,32,32,4,4,4,101.2889 -sdd,True,True,16,32,64,4,4,4,90.9672 -sdd,True,True,16,64,16,4,4,4,303.9266 -sdd,True,True,16,64,32,4,4,4,146.3723 -sdd,True,True,16,64,64,4,4,4,133.093 -sdd,True,True,32,16,16,4,4,4,103.5772 -sdd,True,True,32,16,32,4,4,4,71.01 -sdd,True,True,32,16,64,4,4,4,63.2092 -sdd,True,True,32,32,16,4,4,4,128.0946 -sdd,True,True,32,32,32,4,4,4,86.1027 -sdd,True,True,32,32,64,4,4,4,68.897 -sdd,True,True,32,64,16,4,4,4,153.6706 -sdd,True,True,32,64,32,4,4,4,111.8289 -sdd,True,True,32,64,64,4,4,4,91.2396 -sdd,True,True,64,16,16,4,4,4,127.9097 -sdd,True,True,64,16,32,4,4,4,80.6409 -sdd,True,True,64,16,64,4,4,4,62.273 -sdd,True,True,64,32,16,4,4,4,148.6517 -sdd,True,True,64,32,32,4,4,4,90.9354 -sdd,True,True,64,32,64,4,4,4,67.6934 -sdd,True,True,64,64,16,4,4,4,148.6517 -sdd,True,True,64,64,32,4,4,4,90.9354 -sdd,True,True,64,64,64,4,4,4,67.6934 
+dds,False,False,8,8,8,2,2,2,83.95069580078125 +dds,False,False,8,8,16,4,8,2,43.84989013671875 +dds,False,False,8,8,32,8,8,2,27.1609619140625 +dds,False,False,8,8,64,8,8,4,28.324365234375 +dds,False,False,8,8,128,2,2,2,inf +dds,False,False,8,16,8,4,16,2,84.48544921875 +dds,False,False,8,16,16,4,4,2,42.477719116210935 +dds,False,False,8,16,32,8,16,2,32.50401000976562 +dds,False,False,8,16,64,8,4,4,30.6483642578125 +dds,False,False,8,16,128,8,8,4,31.26983337402344 +dds,False,False,8,32,8,2,4,4,90.04239501953126 +dds,False,False,8,32,16,4,4,2,45.71172180175781 +dds,False,False,8,32,32,2,16,2,36.58458557128906 +dds,False,False,8,32,64,8,2,2,36.38044128417969 +dds,False,False,8,32,128,8,4,2,36.610171508789065 +dds,False,False,8,64,8,2,16,2,90.79548950195311 +dds,False,False,8,64,16,2,8,2,45.22069091796875 +dds,False,False,8,64,32,2,4,2,39.20011291503906 +dds,False,False,8,64,64,4,2,4,31.470745849609376 +dds,False,False,8,64,128,4,16,2,23.4410400390625 +dds,False,False,8,128,8,2,2,2,inf +dds,False,False,8,128,16,2,4,2,54.0128662109375 +dds,False,False,8,128,32,2,8,2,41.18137817382812 +dds,False,False,8,128,64,2,4,2,40.19525146484375 +dds,False,False,8,128,128,2,2,2,inf +dds,False,False,16,8,8,4,4,2,42.92951354980469 +dds,False,False,16,8,16,4,2,2,25.0050048828125 +dds,False,False,16,8,32,8,8,2,18.46839294433594 +dds,False,False,16,8,64,8,4,4,16.389712524414062 +dds,False,False,16,8,128,16,2,4,15.46875457763672 +dds,False,False,16,16,8,2,16,2,44.96512756347656 +dds,False,False,16,16,16,4,8,2,23.26790008544922 +dds,False,False,16,16,32,8,16,2,18.670246887207032 +dds,False,False,16,16,64,4,2,8,17.152764892578126 +dds,False,False,16,16,128,16,16,2,17.652220153808592 +dds,False,False,16,32,8,2,8,2,49.882073974609376 +dds,False,False,16,32,16,4,4,2,25.979672241210935 +dds,False,False,16,32,32,8,16,2,18.137794494628903 +dds,False,False,16,32,64,4,2,4,19.2738525390625 +dds,False,False,16,32,128,8,4,2,17.795779418945312 +dds,False,False,16,64,8,2,2,2,45.37416687011719 +dds,False,False,16,64,16,4,8,2,25.164306640625 +dds,False,False,16,64,32,4,16,2,20.45664367675781 +dds,False,False,16,64,64,4,16,2,19.650909423828125 +dds,False,False,16,64,128,2,16,4,17.490185546875 +dds,False,False,16,128,8,2,16,2,48.46643981933594 +dds,False,False,16,128,16,2,4,2,31.333306884765623 +dds,False,False,16,128,32,2,8,2,22.09304351806641 +dds,False,False,16,128,64,4,2,2,22.02271728515625 +dds,False,False,16,128,128,2,2,2,inf +dds,False,False,32,8,8,8,2,2,32.411660766601564 +dds,False,False,32,8,16,8,2,2,18.840211486816408 +dds,False,False,32,8,32,8,2,4,13.949766540527344 +dds,False,False,32,8,64,16,8,4,12.54494400024414 +dds,False,False,32,8,128,2,2,2,inf +dds,False,False,32,16,8,4,2,2,31.90923767089844 +dds,False,False,32,16,16,8,8,2,18.17984313964844 +dds,False,False,32,16,32,8,2,4,13.35161895751953 +dds,False,False,32,16,64,8,4,4,12.96336669921875 +dds,False,False,32,16,128,2,2,2,inf +dds,False,False,32,32,8,4,4,2,34.13486022949219 +dds,False,False,32,32,16,8,4,2,18.598626708984376 +dds,False,False,32,32,32,8,16,2,14.275120544433594 +dds,False,False,32,32,64,8,4,2,13.326914978027345 +dds,False,False,32,32,128,2,2,2,inf +dds,False,False,32,64,8,4,2,2,34.046783447265625 +dds,False,False,32,64,16,4,8,2,20.470233154296874 +dds,False,False,32,64,32,4,8,2,15.157046508789062 +dds,False,False,32,64,64,8,8,2,13.67220458984375 +dds,False,False,32,64,128,2,2,2,inf +dds,False,False,32,128,8,2,16,2,39.59329833984375 +dds,False,False,32,128,16,4,4,2,24.433506774902344 +dds,False,False,32,128,32,4,8,2,16.391807556152344 
+dds,False,False,32,128,64,4,16,2,14.969456481933594 +dds,False,False,32,128,128,2,2,2,inf +dds,False,False,64,8,8,16,4,2,26.64129638671875 +dds,False,False,64,8,16,8,4,4,15.326361083984375 +dds,False,False,64,8,32,16,8,4,12.275103759765624 +dds,False,False,64,8,64,2,2,2,inf +dds,False,False,64,8,128,2,2,2,inf +dds,False,False,64,16,8,8,2,2,24.83520965576172 +dds,False,False,64,16,16,16,4,2,15.0552734375 +dds,False,False,64,16,32,8,16,4,12.50583038330078 +dds,False,False,64,16,64,2,2,2,inf +dds,False,False,64,16,128,2,2,2,inf +dds,False,False,64,32,8,8,8,2,29.507452392578124 +dds,False,False,64,32,16,8,8,2,16.862696838378906 +dds,False,False,64,32,32,8,4,2,12.744691467285156 +dds,False,False,64,32,64,2,2,2,inf +dds,False,False,64,32,128,2,2,2,inf +dds,False,False,64,64,8,4,2,2,32.07197265625 +dds,False,False,64,64,16,8,16,2,18.548597717285155 +dds,False,False,64,64,32,8,2,2,13.186189270019533 +dds,False,False,64,64,64,2,2,2,inf +dds,False,False,64,64,128,2,2,2,inf +dds,False,False,64,128,8,2,8,2,39.9909912109375 +dds,False,False,64,128,16,4,8,2,20.25826568603516 +dds,False,False,64,128,32,8,2,2,14.141506958007811 +dds,False,False,64,128,64,2,2,2,inf +dds,False,False,64,128,128,2,2,2,inf +dds,False,False,128,8,8,16,2,4,26.484652709960937 +dds,False,False,128,8,16,16,2,4,14.087107849121091 +dds,False,False,128,8,32,2,2,2,inf +dds,False,False,128,8,64,2,2,2,inf +dds,False,False,128,8,128,2,2,2,inf +dds,False,False,128,16,8,16,8,2,21.534544372558592 +dds,False,False,128,16,16,16,16,2,14.640284729003906 +dds,False,False,128,16,32,2,2,2,inf +dds,False,False,128,16,64,2,2,2,inf +dds,False,False,128,16,128,2,2,2,inf +dds,False,False,128,32,8,8,8,2,28.789382934570312 +dds,False,False,128,32,16,8,16,2,15.902000427246094 +dds,False,False,128,32,32,2,2,2,inf +dds,False,False,128,32,64,2,2,2,inf +dds,False,False,128,32,128,2,2,2,inf +dds,False,False,128,64,8,4,2,2,33.127508544921874 +dds,False,False,128,64,16,8,2,2,17.735232543945312 +dds,False,False,128,64,32,2,2,2,inf +dds,False,False,128,64,64,2,2,2,inf +dds,False,False,128,64,128,2,2,2,inf +dds,False,False,128,128,8,2,2,2,inf +dds,False,False,128,128,16,2,2,2,inf +dds,False,False,128,128,32,2,2,2,inf +dds,False,False,128,128,64,2,2,2,inf +dds,False,False,128,128,128,2,2,2,inf +dds,False,True,8,8,8,2,8,2,68.33310546875 +dds,False,True,8,8,16,4,2,2,56.5980712890625 +dds,False,True,8,8,32,8,8,4,68.17498168945312 +dds,False,True,8,8,64,4,8,16,84.39244384765625 +dds,False,True,8,8,128,8,8,16,92.27847290039062 +dds,False,True,8,16,8,2,16,2,67.75989379882813 +dds,False,True,8,16,16,2,4,2,49.00977478027344 +dds,False,True,8,16,32,8,16,2,46.584982299804686 +dds,False,True,8,16,64,8,4,4,42.45918579101563 +dds,False,True,8,16,128,2,8,16,41.13911743164063 +dds,False,True,8,32,8,2,16,2,89.56703491210938 +dds,False,True,8,32,16,2,4,2,46.791586303710936 +dds,False,True,8,32,32,4,8,2,39.323019409179686 +dds,False,True,8,32,64,4,2,2,38.71488342285157 +dds,False,True,8,32,128,8,8,2,39.51044921875 +dds,False,True,8,64,8,2,16,2,130.14737548828126 +dds,False,True,8,64,16,2,8,2,66.99788818359374 +dds,False,True,8,64,32,4,16,2,55.87267456054688 +dds,False,True,8,64,64,4,16,2,54.4582763671875 +dds,False,True,8,64,128,4,4,2,54.76102905273437 +dds,False,True,8,128,8,2,2,2,inf +dds,False,True,8,128,16,2,16,2,109.83070068359376 +dds,False,True,8,128,32,2,8,2,99.8999755859375 +dds,False,True,8,128,64,2,4,2,98.21196899414062 +dds,False,True,8,128,128,2,2,2,inf +dds,False,True,16,8,8,4,4,2,43.08326416015625 +dds,False,True,16,8,16,4,8,2,29.3208740234375 
+dds,False,True,16,8,32,8,4,4,43.06368103027344 +dds,False,True,16,8,64,4,4,8,44.58840637207031 +dds,False,True,16,8,128,4,4,16,49.77604370117187 +dds,False,True,16,16,8,2,16,2,37.32265014648438 +dds,False,True,16,16,16,4,8,2,25.79122314453125 +dds,False,True,16,16,32,8,16,2,25.268634033203124 +dds,False,True,16,16,64,16,8,2,21.63783721923828 +dds,False,True,16,16,128,16,16,2,21.9729248046875 +dds,False,True,16,32,8,2,2,2,47.71018371582032 +dds,False,True,16,32,16,4,2,2,29.6340576171875 +dds,False,True,16,32,32,8,16,2,22.302653503417968 +dds,False,True,16,32,64,8,8,2,22.06714172363281 +dds,False,True,16,32,128,8,4,2,21.238131713867187 +dds,False,True,16,64,8,2,8,2,58.50890502929688 +dds,False,True,16,64,16,4,2,2,40.11568603515625 +dds,False,True,16,64,32,8,8,2,33.484283447265625 +dds,False,True,16,64,64,4,2,2,32.49169006347656 +dds,False,True,16,64,128,4,2,4,32.50621032714844 +dds,False,True,16,128,8,2,16,2,86.73631591796875 +dds,False,True,16,128,16,4,8,2,63.19295043945313 +dds,False,True,16,128,32,4,2,2,55.3454345703125 +dds,False,True,16,128,64,4,8,2,53.07149658203125 +dds,False,True,16,128,128,2,2,2,inf +dds,False,True,32,8,8,8,8,2,31.70926513671875 +dds,False,True,32,8,16,8,8,2,20.866188049316406 +dds,False,True,32,8,32,8,2,2,18.719453430175783 +dds,False,True,32,8,64,16,8,2,23.87512969970703 +dds,False,True,32,8,128,2,2,2,inf +dds,False,True,32,16,8,4,2,2,28.93572998046875 +dds,False,True,32,16,16,8,16,2,19.09322509765625 +dds,False,True,32,16,32,16,8,2,15.363580322265625 +dds,False,True,32,16,64,8,2,4,14.772169494628908 +dds,False,True,32,16,128,2,2,2,inf +dds,False,True,32,32,8,4,2,2,36.425103759765626 +dds,False,True,32,32,16,8,4,2,21.15513610839844 +dds,False,True,32,32,32,8,8,2,16.99063720703125 +dds,False,True,32,32,64,8,2,4,16.2457763671875 +dds,False,True,32,32,128,2,2,2,inf +dds,False,True,32,64,8,4,4,2,42.11276245117188 +dds,False,True,32,64,16,8,8,2,26.644598388671877 +dds,False,True,32,64,32,4,16,2,22.55145568847656 +dds,False,True,32,64,64,4,2,4,21.086448669433597 +dds,False,True,32,64,128,2,2,2,inf +dds,False,True,32,128,8,2,8,2,58.4863525390625 +dds,False,True,32,128,16,4,8,2,39.704345703125 +dds,False,True,32,128,32,4,2,2,32.989443969726565 +dds,False,True,32,128,64,4,8,4,31.92890930175781 +dds,False,True,32,128,128,2,2,2,inf +dds,False,True,64,8,8,16,8,2,26.21756591796875 +dds,False,True,64,8,16,8,4,4,16.323394775390625 +dds,False,True,64,8,32,16,2,4,13.234153747558594 +dds,False,True,64,8,64,2,2,2,inf +dds,False,True,64,8,128,2,2,2,inf +dds,False,True,64,16,8,8,4,2,23.53669128417969 +dds,False,True,64,16,16,16,4,2,15.508131408691408 +dds,False,True,64,16,32,8,16,4,13.336930847167968 +dds,False,True,64,16,64,2,2,2,inf +dds,False,True,64,16,128,2,2,2,inf +dds,False,True,64,32,8,8,4,2,30.660494995117187 +dds,False,True,64,32,16,8,16,2,18.16160888671875 +dds,False,True,64,32,32,8,2,2,14.187823486328124 +dds,False,True,64,32,64,2,2,2,inf +dds,False,True,64,32,128,2,2,2,inf +dds,False,True,64,64,8,4,16,2,35.908871459960935 +dds,False,True,64,64,16,8,2,2,22.574208068847657 +dds,False,True,64,64,32,8,8,2,16.651814270019532 +dds,False,True,64,64,64,2,2,2,inf +dds,False,True,64,64,128,2,2,2,inf +dds,False,True,64,128,8,2,2,2,47.98629455566406 +dds,False,True,64,128,16,4,2,2,29.4864501953125 +dds,False,True,64,128,32,4,16,2,22.77893829345703 +dds,False,True,64,128,64,2,2,2,inf +dds,False,True,64,128,128,2,2,2,inf +dds,False,True,128,8,8,16,2,4,26.28378295898437 +dds,False,True,128,8,16,16,2,4,14.428512573242188 +dds,False,True,128,8,32,2,2,2,inf 
+dds,False,True,128,8,64,2,2,2,inf +dds,False,True,128,8,128,2,2,2,inf +dds,False,True,128,16,8,16,4,2,21.152851867675786 +dds,False,True,128,16,16,16,8,2,14.555728149414062 +dds,False,True,128,16,32,2,2,2,inf +dds,False,True,128,16,64,2,2,2,inf +dds,False,True,128,16,128,2,2,2,inf +dds,False,True,128,32,8,8,4,2,29.437484741210938 +dds,False,True,128,32,16,8,16,2,16.786746215820312 +dds,False,True,128,32,32,2,2,2,inf +dds,False,True,128,32,64,2,2,2,inf +dds,False,True,128,32,128,2,2,2,inf +dds,False,True,128,64,8,4,2,2,35.015570068359374 +dds,False,True,128,64,16,8,2,2,19.62778167724609 +dds,False,True,128,64,32,2,2,2,inf +dds,False,True,128,64,64,2,2,2,inf +dds,False,True,128,64,128,2,2,2,inf +dds,False,True,128,128,8,2,2,2,inf +dds,False,True,128,128,16,2,2,2,inf +dds,False,True,128,128,32,2,2,2,inf +dds,False,True,128,128,64,2,2,2,inf +dds,False,True,128,128,128,2,2,2,inf +dds,True,False,8,8,8,2,2,2,85.9234619140625 +dds,True,False,8,8,16,4,8,2,46.31357421875 +dds,True,False,8,8,32,8,8,2,30.778073120117188 +dds,True,False,8,8,64,4,8,8,29.98555908203125 +dds,True,False,8,8,128,2,2,2,inf +dds,True,False,8,16,8,4,16,2,90.65723876953123 +dds,True,False,8,16,16,2,4,2,45.50893249511719 +dds,True,False,8,16,32,8,2,2,35.563201904296875 +dds,True,False,8,16,64,8,16,4,32.1559814453125 +dds,True,False,8,16,128,8,2,4,32.832516479492185 +dds,True,False,8,32,8,2,16,2,89.41387939453125 +dds,True,False,8,32,16,2,4,2,50.946450805664064 +dds,True,False,8,32,32,2,2,4,38.40348815917969 +dds,True,False,8,32,64,8,8,2,36.46643981933594 +dds,True,False,8,32,128,8,16,2,36.96016235351563 +dds,True,False,8,64,8,2,8,2,96.24725341796876 +dds,True,False,8,64,16,2,2,2,49.09877014160156 +dds,True,False,8,64,32,2,4,2,39.97953491210937 +dds,True,False,8,64,64,2,4,2,38.450704956054686 +dds,True,False,8,64,128,4,2,2,28.423953247070312 +dds,True,False,8,128,8,2,16,2,97.62970581054688 +dds,True,False,8,128,16,2,4,2,50.08351135253906 +dds,True,False,8,128,32,2,16,2,42.19480895996094 +dds,True,False,8,128,64,2,4,2,40.920501708984375 +dds,True,False,8,128,128,2,2,2,inf +dds,True,False,16,8,8,4,8,2,53.02158203125 +dds,True,False,16,8,16,4,8,2,26.3121826171875 +dds,True,False,16,8,32,8,8,2,18.794140625 +dds,True,False,16,8,64,8,2,4,16.877682495117188 +dds,True,False,16,8,128,16,2,4,15.879420471191406 +dds,True,False,16,16,8,2,4,2,47.39478454589844 +dds,True,False,16,16,16,4,8,2,25.626373291015625 +dds,True,False,16,16,32,4,4,2,20.6164794921875 +dds,True,False,16,16,64,16,16,2,18.20005798339844 +dds,True,False,16,16,128,16,4,2,18.098419189453125 +dds,True,False,16,32,8,2,2,2,55.869049072265625 +dds,True,False,16,32,16,4,4,2,29.0021240234375 +dds,True,False,16,32,32,8,16,2,18.64588775634765 +dds,True,False,16,32,64,8,4,2,19.498381042480467 +dds,True,False,16,32,128,8,8,2,18.230441284179687 +dds,True,False,16,64,8,2,2,2,54.580517578125 +dds,True,False,16,64,16,4,8,2,28.303189086914063 +dds,True,False,16,64,32,4,16,2,21.05837097167969 +dds,True,False,16,64,64,4,16,2,19.862054443359376 +dds,True,False,16,64,128,2,8,4,17.901318359375 +dds,True,False,16,128,8,2,4,2,54.739007568359376 +dds,True,False,16,128,16,2,8,2,32.546041870117186 +dds,True,False,16,128,32,2,16,2,23.56947174072265 +dds,True,False,16,128,64,4,2,2,21.99974365234375 +dds,True,False,16,128,128,2,2,2,inf +dds,True,False,32,8,8,8,8,2,59.6756103515625 +dds,True,False,32,8,16,8,8,2,22.40240936279297 +dds,True,False,32,8,32,8,8,4,15.45 +dds,True,False,32,8,64,8,2,8,13.239605712890626 +dds,True,False,32,8,128,2,2,2,inf +dds,True,False,32,16,8,4,16,2,42.68543090820312 
+dds,True,False,32,16,16,8,4,2,22.36754608154297 +dds,True,False,32,16,32,8,2,4,15.607037353515626 +dds,True,False,32,16,64,8,2,4,13.975074768066406 +dds,True,False,32,16,128,2,2,2,inf +dds,True,False,32,32,8,4,16,2,51.02578125 +dds,True,False,32,32,16,8,4,2,26.00336608886719 +dds,True,False,32,32,32,8,8,2,16.560348510742188 +dds,True,False,32,32,64,8,4,2,14.603456115722656 +dds,True,False,32,32,128,2,2,2,inf +dds,True,False,32,64,8,4,4,2,50.49428100585938 +dds,True,False,32,64,16,8,2,2,28.138201904296874 +dds,True,False,32,64,32,4,4,2,18.134669494628906 +dds,True,False,32,64,64,8,4,2,14.894908142089845 +dds,True,False,32,64,128,2,2,2,inf +dds,True,False,32,128,8,2,2,2,55.946539306640624 +dds,True,False,32,128,16,4,16,2,30.93107299804688 +dds,True,False,32,128,32,4,4,2,18.512527465820312 +dds,True,False,32,128,64,4,2,2,15.689494323730468 +dds,True,False,32,128,128,2,2,2,inf +dds,True,False,64,8,8,16,4,2,95.04129638671876 +dds,True,False,64,8,16,8,4,4,30.25923156738281 +dds,True,False,64,8,32,16,8,4,19.586994934082032 +dds,True,False,64,8,64,2,2,2,inf +dds,True,False,64,8,128,2,2,2,inf +dds,True,False,64,16,8,8,4,2,58.078216552734375 +dds,True,False,64,16,16,16,4,2,29.82334289550781 +dds,True,False,64,16,32,8,16,4,19.79328308105469 +dds,True,False,64,16,64,2,2,2,inf +dds,True,False,64,16,128,2,2,2,inf +dds,True,False,64,32,8,8,16,2,66.37091674804688 +dds,True,False,64,32,16,8,16,2,34.887594604492186 +dds,True,False,64,32,32,16,16,2,20.612745666503905 +dds,True,False,64,32,64,2,2,2,inf +dds,True,False,64,32,128,2,2,2,inf +dds,True,False,64,64,8,8,2,2,67.04953002929688 +dds,True,False,64,64,16,8,16,2,35.14598083496094 +dds,True,False,64,64,32,8,8,2,21.04156494140625 +dds,True,False,64,64,64,2,2,2,inf +dds,True,False,64,64,128,2,2,2,inf +dds,True,False,64,128,8,4,2,2,72.3403076171875 +dds,True,False,64,128,16,4,4,2,37.190869140625 +dds,True,False,64,128,32,8,2,2,22.02156219482422 +dds,True,False,64,128,64,2,2,2,inf +dds,True,False,64,128,128,2,2,2,inf +dds,True,False,128,8,8,2,2,2,inf +dds,True,False,128,8,16,16,8,4,52.30179443359375 +dds,True,False,128,8,32,2,2,2,inf +dds,True,False,128,8,64,2,2,2,inf +dds,True,False,128,8,128,2,2,2,inf +dds,True,False,128,16,8,16,8,2,95.9490966796875 +dds,True,False,128,16,16,16,8,2,51.84322509765625 +dds,True,False,128,16,32,2,2,2,inf +dds,True,False,128,16,64,2,2,2,inf +dds,True,False,128,16,128,2,2,2,inf +dds,True,False,128,32,8,8,16,2,105.0478271484375 +dds,True,False,128,32,16,16,16,2,53.6555908203125 +dds,True,False,128,32,32,2,2,2,inf +dds,True,False,128,32,64,2,2,2,inf +dds,True,False,128,32,128,2,2,2,inf +dds,True,False,128,64,8,4,4,2,106.87479248046876 +dds,True,False,128,64,16,8,2,2,54.582958984375 +dds,True,False,128,64,32,2,2,2,inf +dds,True,False,128,64,64,2,2,2,inf +dds,True,False,128,64,128,2,2,2,inf +dds,True,False,128,128,8,2,2,2,inf +dds,True,False,128,128,16,2,2,2,inf +dds,True,False,128,128,32,2,2,2,inf +dds,True,False,128,128,64,2,2,2,inf +dds,True,False,128,128,128,2,2,2,inf +dds,True,True,8,8,8,2,4,2,93.86426391601564 +dds,True,True,8,8,16,2,4,4,62.9012451171875 +dds,True,True,8,8,32,4,4,8,75.62587280273438 +dds,True,True,8,8,64,4,4,16,85.3392333984375 +dds,True,True,8,8,128,4,2,16,92.88526000976564 +dds,True,True,8,16,8,2,4,2,79.35208740234376 +dds,True,True,8,16,16,2,4,2,47.40335388183594 +dds,True,True,8,16,32,8,16,2,49.03898620605469 +dds,True,True,8,16,64,8,8,4,43.70534973144531 +dds,True,True,8,16,128,2,16,16,41.79808044433594 +dds,True,True,8,32,8,2,16,2,103.93631591796876 +dds,True,True,8,32,16,2,16,2,52.19888305664063 
+dds,True,True,8,32,32,4,8,2,41.50272827148437 +dds,True,True,8,32,64,4,8,2,39.83175354003906 +dds,True,True,8,32,128,8,8,2,39.99073791503906 +dds,True,True,8,64,8,2,4,2,144.64049072265624 +dds,True,True,8,64,16,2,4,2,71.35440673828126 +dds,True,True,8,64,32,4,2,2,58.09429931640625 +dds,True,True,8,64,64,4,2,2,55.59429931640625 +dds,True,True,8,64,128,4,16,2,54.93161010742188 +dds,True,True,8,128,8,2,2,2,inf +dds,True,True,8,128,16,2,8,2,113.36585693359376 +dds,True,True,8,128,32,2,16,2,102.328076171875 +dds,True,True,8,128,64,2,4,2,99.2333984375 +dds,True,True,8,128,128,2,2,2,inf +dds,True,True,16,8,8,4,2,2,55.75620727539062 +dds,True,True,16,8,16,4,4,2,31.414694213867183 +dds,True,True,16,8,32,8,4,4,43.36945495605469 +dds,True,True,16,8,64,2,8,16,41.72640686035156 +dds,True,True,16,8,128,4,4,16,48.80404052734375 +dds,True,True,16,16,8,2,16,2,44.52323303222656 +dds,True,True,16,16,16,4,16,2,27.96387023925781 +dds,True,True,16,16,32,8,16,2,26.15814208984375 +dds,True,True,16,16,64,16,8,2,21.87930908203125 +dds,True,True,16,16,128,16,4,2,21.86894989013672 +dds,True,True,16,32,8,2,2,2,57.255419921875 +dds,True,True,16,32,16,4,4,2,34.42570190429687 +dds,True,True,16,32,32,8,16,2,24.054620361328126 +dds,True,True,16,32,64,8,16,2,22.852418518066408 +dds,True,True,16,32,128,8,16,2,21.65142974853516 +dds,True,True,16,64,8,2,4,2,67.79102783203125 +dds,True,True,16,64,16,4,2,2,44.90336303710937 +dds,True,True,16,64,32,8,2,2,35.3868896484375 +dds,True,True,16,64,64,4,2,2,33.657223510742185 +dds,True,True,16,64,128,4,2,4,32.798883056640626 +dds,True,True,16,128,8,2,16,2,89.61912841796875 +dds,True,True,16,128,16,4,8,2,68.82657470703126 +dds,True,True,16,128,32,4,2,2,56.98488159179688 +dds,True,True,16,128,64,4,2,2,53.906036376953125 +dds,True,True,16,128,128,2,2,2,inf +dds,True,True,32,8,8,8,4,2,61.55089111328125 +dds,True,True,32,8,16,8,2,2,24.22500457763672 +dds,True,True,32,8,32,16,8,2,19.48570861816406 +dds,True,True,32,8,64,16,2,2,19.49196472167969 +dds,True,True,32,8,128,2,2,2,inf +dds,True,True,32,16,8,4,8,2,42.84806823730469 +dds,True,True,32,16,16,8,2,2,23.67737274169922 +dds,True,True,32,16,32,8,16,4,17.302281188964844 +dds,True,True,32,16,64,8,4,4,15.770700073242187 +dds,True,True,32,16,128,2,2,2,inf +dds,True,True,32,32,8,4,16,2,53.422998046875 +dds,True,True,32,32,16,8,16,2,29.681704711914065 +dds,True,True,32,32,32,8,8,2,20.000189208984374 +dds,True,True,32,32,64,8,2,4,17.455282592773436 +dds,True,True,32,32,128,2,2,2,inf +dds,True,True,32,64,8,4,8,2,58.26892700195312 +dds,True,True,32,64,16,8,2,2,34.81661071777344 +dds,True,True,32,64,32,8,8,2,26.6898193359375 +dds,True,True,32,64,64,8,2,2,23.006997680664064 +dds,True,True,32,64,128,2,2,2,inf +dds,True,True,32,128,8,4,2,2,71.09729614257813 +dds,True,True,32,128,16,4,2,2,48.069085693359376 +dds,True,True,32,128,32,4,16,2,36.78126525878906 +dds,True,True,32,128,64,8,2,2,33.828955078125 +dds,True,True,32,128,128,2,2,2,inf +dds,True,True,64,8,8,16,4,2,97.09566650390624 +dds,True,True,64,8,16,8,8,4,31.19393920898437 +dds,True,True,64,8,32,8,8,8,20.78213806152344 +dds,True,True,64,8,64,2,2,2,inf +dds,True,True,64,8,128,2,2,2,inf +dds,True,True,64,16,8,8,8,2,58.19834594726562 +dds,True,True,64,16,16,16,8,2,31.186398315429688 +dds,True,True,64,16,32,8,4,4,20.88908233642578 +dds,True,True,64,16,64,2,2,2,inf +dds,True,True,64,16,128,2,2,2,inf +dds,True,True,64,32,8,8,8,2,67.47835083007813 +dds,True,True,64,32,16,16,8,2,36.777783203125 +dds,True,True,64,32,32,16,8,2,22.203033447265625 +dds,True,True,64,32,64,2,2,2,inf 
+dds,True,True,64,32,128,2,2,2,inf +dds,True,True,64,64,8,8,16,2,71.129345703125 +dds,True,True,64,64,16,8,4,2,39.23030395507813 +dds,True,True,64,64,32,8,4,2,25.64223022460937 +dds,True,True,64,64,64,2,2,2,inf +dds,True,True,64,64,128,2,2,2,inf +dds,True,True,64,128,8,4,8,2,81.20865478515626 +dds,True,True,64,128,16,4,2,2,46.47157287597656 +dds,True,True,64,128,32,8,2,2,31.310092163085937 +dds,True,True,64,128,64,2,2,2,inf +dds,True,True,64,128,128,2,2,2,inf +dds,True,True,128,8,8,2,2,2,inf +dds,True,True,128,8,16,16,2,4,52.752923583984376 +dds,True,True,128,8,32,2,2,2,inf +dds,True,True,128,8,64,2,2,2,inf +dds,True,True,128,8,128,2,2,2,inf +dds,True,True,128,16,8,16,4,2,96.22736206054688 +dds,True,True,128,16,16,16,4,2,51.95736083984375 +dds,True,True,128,16,32,2,2,2,inf +dds,True,True,128,16,64,2,2,2,inf +dds,True,True,128,16,128,2,2,2,inf +dds,True,True,128,32,8,8,4,2,105.681982421875 +dds,True,True,128,32,16,16,4,2,54.82050170898437 +dds,True,True,128,32,32,2,2,2,inf +dds,True,True,128,32,64,2,2,2,inf +dds,True,True,128,32,128,2,2,2,inf +dds,True,True,128,64,8,4,16,2,108.55108642578124 +dds,True,True,128,64,16,8,4,2,56.63348388671875 +dds,True,True,128,64,32,2,2,2,inf +dds,True,True,128,64,64,2,2,2,inf +dds,True,True,128,64,128,2,2,2,inf +dds,True,True,128,128,8,2,2,2,inf +dds,True,True,128,128,16,2,2,2,inf +dds,True,True,128,128,32,2,2,2,inf +dds,True,True,128,128,64,2,2,2,inf +dds,True,True,128,128,128,2,2,2,inf +dsd,False,False,8,8,8,2,8,2,58.68651123046875 +dsd,False,False,8,8,16,4,4,2,42.6728759765625 +dsd,False,False,8,8,32,4,4,4,38.15791931152344 +dsd,False,False,8,8,64,8,8,4,34.25900573730469 +dsd,False,False,8,8,128,2,2,2,inf +dsd,False,False,8,16,8,2,4,2,50.28748474121094 +dsd,False,False,8,16,16,4,16,2,39.87994384765625 +dsd,False,False,8,16,32,4,8,4,32.09373474121094 +dsd,False,False,8,16,64,4,2,4,33.66380615234375 +dsd,False,False,8,16,128,8,2,4,32.56879577636719 +dsd,False,False,8,32,8,2,8,2,54.3849609375 +dsd,False,False,8,32,16,2,2,2,35.072271728515624 +dsd,False,False,8,32,32,8,8,2,37.27109375 +dsd,False,False,8,32,64,8,16,2,38.08072509765625 +dsd,False,False,8,32,128,2,2,2,inf +dsd,False,False,8,64,8,2,4,2,53.181591796875 +dsd,False,False,8,64,16,2,8,2,41.441680908203125 +dsd,False,False,8,64,32,2,2,2,40.47706909179688 +dsd,False,False,8,64,64,2,2,2,inf +dsd,False,False,8,64,128,2,2,2,inf +dsd,False,False,8,128,8,2,2,2,inf +dsd,False,False,8,128,16,2,2,2,39.33102111816406 +dsd,False,False,8,128,32,2,2,2,inf +dsd,False,False,8,128,64,2,2,2,inf +dsd,False,False,8,128,128,2,2,2,inf +dsd,False,False,16,8,8,4,4,2,38.00410461425781 +dsd,False,False,16,8,16,4,2,2,23.97266540527344 +dsd,False,False,16,8,32,4,4,4,19.30738220214844 +dsd,False,False,16,8,64,4,8,8,17.43036804199219 +dsd,False,False,16,8,128,16,2,4,16.734518432617186 +dsd,False,False,16,16,8,2,2,2,33.422088623046875 +dsd,False,False,16,16,16,4,4,2,21.60905303955078 +dsd,False,False,16,16,32,4,16,2,19.633290100097657 +dsd,False,False,16,16,64,4,16,8,18.44384002685547 +dsd,False,False,16,16,128,2,4,16,19.249530029296874 +dsd,False,False,16,32,8,2,2,2,39.05540161132812 +dsd,False,False,16,32,16,4,8,2,22.90558776855469 +dsd,False,False,16,32,32,8,8,2,18.76983337402344 +dsd,False,False,16,32,64,4,4,2,19.42920684814453 +dsd,False,False,16,32,128,2,2,2,inf +dsd,False,False,16,64,8,2,8,2,38.244808959960935 +dsd,False,False,16,64,16,4,4,2,24.67497863769531 +dsd,False,False,16,64,32,2,2,2,21.384271240234376 +dsd,False,False,16,64,64,2,2,2,inf +dsd,False,False,16,64,128,2,2,2,inf 
+dsd,False,False,16,128,8,2,16,2,40.819033813476565 +dsd,False,False,16,128,16,2,2,2,27.883056640625 +dsd,False,False,16,128,32,2,2,2,inf +dsd,False,False,16,128,64,2,2,2,inf +dsd,False,False,16,128,128,2,2,2,inf +dsd,False,False,32,8,8,8,4,2,29.94057006835937 +dsd,False,False,32,8,16,8,4,2,18.484355163574214 +dsd,False,False,32,8,32,8,2,4,14.205218505859374 +dsd,False,False,32,8,64,8,2,8,12.878863525390624 +dsd,False,False,32,8,128,8,8,8,12.873526000976565 +dsd,False,False,32,16,8,4,16,2,26.0935546875 +dsd,False,False,32,16,16,8,4,2,17.127200317382812 +dsd,False,False,32,16,32,16,4,2,13.659455871582033 +dsd,False,False,32,16,64,8,16,4,12.99901123046875 +dsd,False,False,32,16,128,8,4,4,12.17227554321289 +dsd,False,False,32,32,8,4,16,2,32.057009887695315 +dsd,False,False,32,32,16,8,8,2,17.632479858398437 +dsd,False,False,32,32,32,8,16,2,13.9248291015625 +dsd,False,False,32,32,64,8,8,2,13.21557159423828 +dsd,False,False,32,32,128,2,2,2,inf +dsd,False,False,32,64,8,4,16,2,32.44314270019531 +dsd,False,False,32,64,16,4,2,2,19.502828979492183 +dsd,False,False,32,64,32,4,2,2,15.068704223632812 +dsd,False,False,32,64,64,2,2,2,inf +dsd,False,False,32,64,128,2,2,2,inf +dsd,False,False,32,128,8,2,16,2,37.05467529296875 +dsd,False,False,32,128,16,4,4,2,23.824124145507813 +dsd,False,False,32,128,32,2,2,2,inf +dsd,False,False,32,128,64,2,2,2,inf +dsd,False,False,32,128,128,2,2,2,inf +dsd,False,False,64,8,8,16,2,2,25.585133361816407 +dsd,False,False,64,8,16,8,2,4,15.365625 +dsd,False,False,64,8,32,16,8,4,12.445209503173828 +dsd,False,False,64,8,64,8,8,8,12.021456146240237 +dsd,False,False,64,8,128,16,2,4,11.294950103759763 +dsd,False,False,64,16,8,8,16,2,22.789225769042968 +dsd,False,False,64,16,16,16,8,2,14.578399658203123 +dsd,False,False,64,16,32,8,4,4,12.400835418701172 +dsd,False,False,64,16,64,8,2,4,11.663340759277345 +dsd,False,False,64,16,128,16,8,4,11.174070739746094 +dsd,False,False,64,32,8,8,4,2,28.659817504882813 +dsd,False,False,64,32,16,8,16,2,16.314057922363283 +dsd,False,False,64,32,32,8,2,2,12.693769836425782 +dsd,False,False,64,32,64,8,8,4,11.483523559570312 +dsd,False,False,64,32,128,2,2,2,inf +dsd,False,False,64,64,8,4,8,2,31.395059204101564 +dsd,False,False,64,64,16,8,16,2,18.30839385986328 +dsd,False,False,64,64,32,8,8,2,12.973321533203125 +dsd,False,False,64,64,64,2,2,2,inf +dsd,False,False,64,64,128,2,2,2,inf +dsd,False,False,64,128,8,2,2,2,43.200830078125 +dsd,False,False,64,128,16,4,2,2,20.80726013183594 +dsd,False,False,64,128,32,2,2,2,inf +dsd,False,False,64,128,64,2,2,2,inf +dsd,False,False,64,128,128,2,2,2,inf +dsd,False,False,128,8,8,16,2,4,25.693502807617183 +dsd,False,False,128,8,16,16,2,4,14.128346252441409 +dsd,False,False,128,8,32,8,2,8,12.39136962890625 +dsd,False,False,128,8,64,8,8,8,11.45447998046875 +dsd,False,False,128,8,128,8,8,8,12.337554931640623 +dsd,False,False,128,16,8,16,2,2,20.80894012451172 +dsd,False,False,128,16,16,16,16,2,14.157225036621094 +dsd,False,False,128,16,32,8,4,4,11.805955505371092 +dsd,False,False,128,16,64,16,8,4,11.266896057128909 +dsd,False,False,128,16,128,4,16,16,12.197372436523438 +dsd,False,False,128,32,8,8,16,2,28.485455322265626 +dsd,False,False,128,32,16,8,8,2,15.9461669921875 +dsd,False,False,128,32,32,16,16,2,11.877593231201171 +dsd,False,False,128,32,64,16,2,2,11.360527801513673 +dsd,False,False,128,32,128,2,2,2,inf +dsd,False,False,128,64,8,4,2,2,33.3581298828125 +dsd,False,False,128,64,16,8,16,2,17.284320068359374 +dsd,False,False,128,64,32,8,8,2,12.954953002929688 +dsd,False,False,128,64,64,2,2,2,inf 
+dsd,False,False,128,64,128,2,2,2,inf +dsd,False,False,128,128,8,2,2,2,inf +dsd,False,False,128,128,16,2,2,2,inf +dsd,False,False,128,128,32,2,2,2,inf +dsd,False,False,128,128,64,2,2,2,inf +dsd,False,False,128,128,128,2,2,2,inf +dsd,False,True,8,8,8,2,2,2,64.83031005859375 +dsd,False,True,8,8,16,4,8,2,45.480560302734375 +dsd,False,True,8,8,32,8,8,2,43.537939453125 +dsd,False,True,8,8,64,4,4,8,41.256735229492186 +dsd,False,True,8,8,128,8,8,8,34.57571716308594 +dsd,False,True,8,16,8,2,16,2,66.96640625 +dsd,False,True,8,16,16,2,4,2,42.99291076660156 +dsd,False,True,8,16,32,2,4,4,41.06326904296875 +dsd,False,True,8,16,64,4,8,4,37.438262939453125 +dsd,False,True,8,16,128,8,4,4,35.173666381835936 +dsd,False,True,8,32,8,2,4,2,89.96324462890625 +dsd,False,True,8,32,16,2,8,2,47.99249877929688 +dsd,False,True,8,32,32,4,8,2,44.1489990234375 +dsd,False,True,8,32,64,4,2,2,41.00673217773438 +dsd,False,True,8,32,128,2,2,2,inf +dsd,False,True,8,64,8,2,16,2,131.4682373046875 +dsd,False,True,8,64,16,2,16,2,67.10629272460938 +dsd,False,True,8,64,32,4,4,2,58.920703125 +dsd,False,True,8,64,64,2,2,2,inf +dsd,False,True,8,64,128,2,2,2,inf +dsd,False,True,8,128,8,2,2,2,inf +dsd,False,True,8,128,16,2,2,2,108.13199462890626 +dsd,False,True,8,128,32,2,2,2,inf +dsd,False,True,8,128,64,2,2,2,inf +dsd,False,True,8,128,128,2,2,2,inf +dsd,False,True,16,8,8,4,8,2,42.252691650390624 +dsd,False,True,16,8,16,4,8,2,26.201754760742187 +dsd,False,True,16,8,32,8,2,4,26.348110961914063 +dsd,False,True,16,8,64,16,8,4,23.99023284912109 +dsd,False,True,16,8,128,4,4,16,19.612937927246094 +dsd,False,True,16,16,8,2,16,2,36.32303161621094 +dsd,False,True,16,16,16,4,8,2,24.885910034179688 +dsd,False,True,16,16,32,8,2,2,24.24401245117188 +dsd,False,True,16,16,64,8,8,4,21.024485778808597 +dsd,False,True,16,16,128,16,8,2,20.268089294433597 +dsd,False,True,16,32,8,2,16,2,47.83821105957031 +dsd,False,True,16,32,16,4,8,2,29.85913391113281 +dsd,False,True,16,32,32,8,16,2,24.07604217529297 +dsd,False,True,16,32,64,4,16,2,23.95204162597656 +dsd,False,True,16,32,128,2,2,2,inf +dsd,False,True,16,64,8,2,16,2,58.60889282226562 +dsd,False,True,16,64,16,4,4,2,39.94730529785157 +dsd,False,True,16,64,32,4,16,2,33.722048950195315 +dsd,False,True,16,64,64,2,2,2,inf +dsd,False,True,16,64,128,2,2,2,inf +dsd,False,True,16,128,8,2,4,2,80.9656982421875 +dsd,False,True,16,128,16,4,2,2,63.73284912109375 +dsd,False,True,16,128,32,2,2,2,inf +dsd,False,True,16,128,64,2,2,2,inf +dsd,False,True,16,128,128,2,2,2,inf +dsd,False,True,32,8,8,8,2,2,32.44737243652344 +dsd,False,True,32,8,16,8,8,2,20.403779602050783 +dsd,False,True,32,8,32,8,2,2,17.346044921875 +dsd,False,True,32,8,64,8,8,4,15.755056762695313 +dsd,False,True,32,8,128,8,2,8,14.2900634765625 +dsd,False,True,32,16,8,4,16,2,27.93015747070313 +dsd,False,True,32,16,16,4,2,2,19.360806274414063 +dsd,False,True,32,16,32,8,16,4,16.002423095703126 +dsd,False,True,32,16,64,8,16,4,14.859500122070312 +dsd,False,True,32,16,128,8,4,4,13.441619873046877 +dsd,False,True,32,32,8,4,8,2,36.38385009765625 +dsd,False,True,32,32,16,8,16,2,21.3083740234375 +dsd,False,True,32,32,32,8,16,2,17.330905151367187 +dsd,False,True,32,32,64,4,4,4,16.176535034179686 +dsd,False,True,32,32,128,2,2,2,inf +dsd,False,True,32,64,8,4,2,2,42.07832946777344 +dsd,False,True,32,64,16,8,16,2,27.796780395507813 +dsd,False,True,32,64,32,4,4,2,22.622706604003906 +dsd,False,True,32,64,64,2,2,2,inf +dsd,False,True,32,64,128,2,2,2,inf +dsd,False,True,32,128,8,2,16,2,59.004339599609374 +dsd,False,True,32,128,16,4,8,2,40.63676452636719 
+dsd,False,True,32,128,32,2,2,2,inf +dsd,False,True,32,128,64,2,2,2,inf +dsd,False,True,32,128,128,2,2,2,inf +dsd,False,True,64,8,8,16,2,2,28.06235656738281 +dsd,False,True,64,8,16,8,4,4,17.711637878417967 +dsd,False,True,64,8,32,8,8,4,14.624826049804687 +dsd,False,True,64,8,64,8,2,4,13.30050811767578 +dsd,False,True,64,8,128,8,2,8,11.8025634765625 +dsd,False,True,64,16,8,8,8,2,23.832765197753908 +dsd,False,True,64,16,16,8,4,2,15.731478881835937 +dsd,False,True,64,16,32,8,8,4,13.64959716796875 +dsd,False,True,64,16,64,8,16,4,12.479481506347655 +dsd,False,True,64,16,128,16,4,4,12.026067352294922 +dsd,False,True,64,32,8,8,8,2,30.883258056640624 +dsd,False,True,64,32,16,8,8,2,18.295631408691406 +dsd,False,True,64,32,32,8,8,2,14.4781982421875 +dsd,False,True,64,32,64,8,16,4,12.99908447265625 +dsd,False,True,64,32,128,2,2,2,inf +dsd,False,True,64,64,8,4,16,2,36.06336059570312 +dsd,False,True,64,64,16,8,8,2,22.488557434082036 +dsd,False,True,64,64,32,8,4,2,16.641293334960938 +dsd,False,True,64,64,64,2,2,2,inf +dsd,False,True,64,64,128,2,2,2,inf +dsd,False,True,64,128,8,2,2,2,49.170901489257815 +dsd,False,True,64,128,16,4,2,2,30.151165771484376 +dsd,False,True,64,128,32,2,2,2,inf +dsd,False,True,64,128,64,2,2,2,inf +dsd,False,True,64,128,128,2,2,2,inf +dsd,False,True,128,8,8,16,2,4,28.43712463378906 +dsd,False,True,128,8,16,16,2,4,18.233078002929688 +dsd,False,True,128,8,32,16,8,4,14.328166198730468 +dsd,False,True,128,8,64,8,8,8,12.280473327636718 +dsd,False,True,128,8,128,8,2,8,12.583350372314452 +dsd,False,True,128,16,8,16,2,2,21.776124572753908 +dsd,False,True,128,16,16,16,2,2,15.49456024169922 +dsd,False,True,128,16,32,8,4,4,12.65200653076172 +dsd,False,True,128,16,64,16,2,4,12.016060638427737 +dsd,False,True,128,16,128,8,8,8,12.376700592041017 +dsd,False,True,128,32,8,8,4,2,29.58363342285156 +dsd,False,True,128,32,16,8,8,2,17.026530456542968 +dsd,False,True,128,32,32,16,16,2,13.03179473876953 +dsd,False,True,128,32,64,8,16,4,12.44648666381836 +dsd,False,True,128,32,128,2,2,2,inf +dsd,False,True,128,64,8,4,16,2,35.76015930175781 +dsd,False,True,128,64,16,8,8,2,19.7291259765625 +dsd,False,True,128,64,32,8,4,2,15.192633056640624 +dsd,False,True,128,64,64,2,2,2,inf +dsd,False,True,128,64,128,2,2,2,inf +dsd,False,True,128,128,8,2,2,2,inf +dsd,False,True,128,128,16,2,2,2,inf +dsd,False,True,128,128,32,2,2,2,inf +dsd,False,True,128,128,64,2,2,2,inf +dsd,False,True,128,128,128,2,2,2,inf +dsd,True,False,8,8,8,2,4,2,64.50003662109376 +dsd,True,False,8,8,16,4,8,2,43.85590515136719 +dsd,True,False,8,8,32,8,4,2,38.58082275390625 +dsd,True,False,8,8,64,4,4,8,35.12033386230469 +dsd,True,False,8,8,128,2,2,2,inf +dsd,True,False,8,16,8,2,4,2,60.79771728515625 +dsd,True,False,8,16,16,2,4,2,42.94162902832032 +dsd,True,False,8,16,32,4,8,4,37.130557250976565 +dsd,True,False,8,16,64,2,2,8,34.81022644042969 +dsd,True,False,8,16,128,8,2,4,33.78464965820312 +dsd,True,False,8,32,8,2,4,2,66.80986328125 +dsd,True,False,8,32,16,2,4,2,44.515805053710935 +dsd,True,False,8,32,32,2,8,4,40.30407104492188 +dsd,True,False,8,32,64,8,2,2,39.0494140625 +dsd,True,False,8,32,128,2,2,2,inf +dsd,True,False,8,64,8,2,4,2,66.54736328125 +dsd,True,False,8,64,16,2,4,2,44.97365112304688 +dsd,True,False,8,64,32,2,2,2,42.03638610839844 +dsd,True,False,8,64,64,2,2,2,inf +dsd,True,False,8,64,128,2,2,2,inf +dsd,True,False,8,128,8,2,8,2,69.602685546875 +dsd,True,False,8,128,16,2,2,2,42.459521484375 +dsd,True,False,8,128,32,2,2,2,inf +dsd,True,False,8,128,64,2,2,2,inf +dsd,True,False,8,128,128,2,2,2,inf 
+dsd,True,False,16,8,8,4,8,2,49.50525817871094 +dsd,True,False,16,8,16,4,2,2,25.052706909179687 +dsd,True,False,16,8,32,8,2,2,19.538114929199217 +dsd,True,False,16,8,64,4,2,8,17.691792297363282 +dsd,True,False,16,8,128,8,2,8,16.321661376953124 +dsd,True,False,16,16,8,2,2,2,39.98442993164063 +dsd,True,False,16,16,16,4,4,2,23.760643005371094 +dsd,True,False,16,16,32,4,4,2,20.69975433349609 +dsd,True,False,16,16,64,4,2,8,19.054969787597656 +dsd,True,False,16,16,128,4,4,8,19.52235260009765 +dsd,True,False,16,32,8,2,16,2,48.63282470703125 +dsd,True,False,16,32,16,4,8,2,26.63685607910156 +dsd,True,False,16,32,32,8,4,2,19.93187561035156 +dsd,True,False,16,32,64,4,16,4,20.57270965576172 +dsd,True,False,16,32,128,2,2,2,inf +dsd,True,False,16,64,8,2,16,2,47.72169799804688 +dsd,True,False,16,64,16,4,2,2,26.349447631835936 +dsd,True,False,16,64,32,4,4,2,22.27334442138672 +dsd,True,False,16,64,64,2,2,2,inf +dsd,True,False,16,64,128,2,2,2,inf +dsd,True,False,16,128,8,2,4,2,49.356442260742185 +dsd,True,False,16,128,16,2,16,2,30.78685607910156 +dsd,True,False,16,128,32,2,2,2,inf +dsd,True,False,16,128,64,2,2,2,inf +dsd,True,False,16,128,128,2,2,2,inf +dsd,True,False,32,8,8,8,4,2,58.26485595703125 +dsd,True,False,32,8,16,8,4,2,22.083929443359374 +dsd,True,False,32,8,32,8,4,4,15.802262878417968 +dsd,True,False,32,8,64,8,2,8,13.385020446777345 +dsd,True,False,32,8,128,8,2,8,12.973992919921876 +dsd,True,False,32,16,8,4,16,2,40.613363647460936 +dsd,True,False,32,16,16,8,4,2,21.56761932373047 +dsd,True,False,32,16,32,8,16,4,15.306796264648437 +dsd,True,False,32,16,64,8,2,4,13.917820739746094 +dsd,True,False,32,16,128,8,2,4,12.738409423828124 +dsd,True,False,32,32,8,4,2,2,49.412847900390624 +dsd,True,False,32,32,16,8,2,2,25.411875915527343 +dsd,True,False,32,32,32,8,4,2,16.540342712402342 +dsd,True,False,32,32,64,8,16,2,14.542752075195311 +dsd,True,False,32,32,128,2,2,2,inf +dsd,True,False,32,64,8,4,8,2,48.77427368164062 +dsd,True,False,32,64,16,8,8,2,27.46026611328125 +dsd,True,False,32,64,32,4,4,2,17.95720977783203 +dsd,True,False,32,64,64,2,2,2,inf +dsd,True,False,32,64,128,2,2,2,inf +dsd,True,False,32,128,8,2,16,2,53.78228149414063 +dsd,True,False,32,128,16,4,2,2,30.272491455078125 +dsd,True,False,32,128,32,2,2,2,inf +dsd,True,False,32,128,64,2,2,2,inf +dsd,True,False,32,128,128,2,2,2,inf +dsd,True,False,64,8,8,16,4,2,95.88981323242189 +dsd,True,False,64,8,16,8,2,4,30.042803955078124 +dsd,True,False,64,8,32,8,8,8,19.73251800537109 +dsd,True,False,64,8,64,8,2,4,15.112445068359374 +dsd,True,False,64,8,128,16,8,4,13.197616577148438 +dsd,True,False,64,16,8,8,16,2,57.47857055664063 +dsd,True,False,64,16,16,16,4,2,30.21904296875 +dsd,True,False,64,16,32,8,2,4,19.610873413085937 +dsd,True,False,64,16,64,16,2,2,15.586691284179688 +dsd,True,False,64,16,128,16,2,4,13.0103515625 +dsd,True,False,64,32,8,8,16,2,65.5755126953125 +dsd,True,False,64,32,16,8,4,2,34.70885009765625 +dsd,True,False,64,32,32,16,4,2,20.928297424316405 +dsd,True,False,64,32,64,8,16,4,15.283062744140626 +dsd,True,False,64,32,128,2,2,2,inf +dsd,True,False,64,64,8,8,16,2,66.76674194335938 +dsd,True,False,64,64,16,8,8,2,34.96489562988281 +dsd,True,False,64,64,32,8,4,2,21.019471740722658 +dsd,True,False,64,64,64,2,2,2,inf +dsd,True,False,64,64,128,2,2,2,inf +dsd,True,False,64,128,8,2,8,2,72.95369873046874 +dsd,True,False,64,128,16,4,16,2,37.51438598632812 +dsd,True,False,64,128,32,2,2,2,inf +dsd,True,False,64,128,64,2,2,2,inf +dsd,True,False,64,128,128,2,2,2,inf +dsd,True,False,128,8,8,2,2,2,inf +dsd,True,False,128,8,16,16,8,4,52.0931640625 
+dsd,True,False,128,8,32,16,4,4,31.23505554199219 +dsd,True,False,128,8,64,16,2,4,20.91832885742188 +dsd,True,False,128,8,128,16,2,8,16.886825561523438 +dsd,True,False,128,16,8,16,4,2,97.2037109375 +dsd,True,False,128,16,16,16,2,2,51.648046875 +dsd,True,False,128,16,32,16,2,2,30.59534606933594 +dsd,True,False,128,16,64,16,2,4,20.84552001953125 +dsd,True,False,128,16,128,8,16,8,16.009730529785156 +dsd,True,False,128,32,8,8,16,2,104.88055419921876 +dsd,True,False,128,32,16,16,8,2,53.86217041015625 +dsd,True,False,128,32,32,16,8,2,30.620196533203124 +dsd,True,False,128,32,64,16,8,4,20.649945068359376 +dsd,True,False,128,32,128,2,2,2,inf +dsd,True,False,128,64,8,4,4,2,106.9627685546875 +dsd,True,False,128,64,16,8,16,2,54.41331176757812 +dsd,True,False,128,64,32,16,8,2,31.31766357421875 +dsd,True,False,128,64,64,2,2,2,inf +dsd,True,False,128,64,128,2,2,2,inf +dsd,True,False,128,128,8,2,2,2,inf +dsd,True,False,128,128,16,2,2,2,inf +dsd,True,False,128,128,32,2,2,2,inf +dsd,True,False,128,128,64,2,2,2,inf +dsd,True,False,128,128,128,2,2,2,inf +dsd,True,True,8,8,8,2,4,2,71.19293212890625 +dsd,True,True,8,8,16,4,8,2,47.526544189453126 +dsd,True,True,8,8,32,8,4,2,45.897119140625 +dsd,True,True,8,8,64,4,4,8,42.27239074707031 +dsd,True,True,8,8,128,8,8,8,34.71629638671875 +dsd,True,True,8,16,8,2,16,2,77.76537475585937 +dsd,True,True,8,16,16,2,8,2,44.7300048828125 +dsd,True,True,8,16,32,8,16,2,47.677008056640624 +dsd,True,True,8,16,64,2,4,8,43.859231567382814 +dsd,True,True,8,16,128,4,2,8,38.1770751953125 +dsd,True,True,8,32,8,2,16,2,104.03848876953126 +dsd,True,True,8,32,16,2,4,2,53.5608642578125 +dsd,True,True,8,32,32,4,8,2,47.158810424804685 +dsd,True,True,8,32,64,4,2,2,44.61595764160156 +dsd,True,True,8,32,128,2,2,2,inf +dsd,True,True,8,64,8,2,8,2,145.02247314453126 +dsd,True,True,8,64,16,2,16,2,71.74268188476563 +dsd,True,True,8,64,32,2,4,2,62.39791259765625 +dsd,True,True,8,64,64,2,2,2,inf +dsd,True,True,8,64,128,2,2,2,inf +dsd,True,True,8,128,8,2,2,2,inf +dsd,True,True,8,128,16,2,8,2,112.65904541015624 +dsd,True,True,8,128,32,2,2,2,inf +dsd,True,True,8,128,64,2,2,2,inf +dsd,True,True,8,128,128,2,2,2,inf +dsd,True,True,16,8,8,4,4,2,53.01531982421875 +dsd,True,True,16,8,16,4,4,2,27.02744445800781 +dsd,True,True,16,8,32,8,4,2,25.09804229736328 +dsd,True,True,16,8,64,16,8,2,20.802879333496094 +dsd,True,True,16,8,128,16,2,4,17.75169219970703 +dsd,True,True,16,16,8,2,16,2,43.57724304199219 +dsd,True,True,16,16,16,4,16,2,27.37100830078125 +dsd,True,True,16,16,32,4,2,2,26.01639404296875 +dsd,True,True,16,16,64,4,8,8,22.524560546875 +dsd,True,True,16,16,128,4,2,8,20.59218292236328 +dsd,True,True,16,32,8,2,8,2,57.39716796875 +dsd,True,True,16,32,16,4,8,2,34.71137390136719 +dsd,True,True,16,32,32,8,2,2,27.15743103027344 +dsd,True,True,16,32,64,4,4,4,25.47241668701172 +dsd,True,True,16,32,128,2,2,2,inf +dsd,True,True,16,64,8,2,4,2,67.52178955078125 +dsd,True,True,16,64,16,4,4,2,44.9527587890625 +dsd,True,True,16,64,32,4,16,2,36.574038696289065 +dsd,True,True,16,64,64,2,2,2,inf +dsd,True,True,16,64,128,2,2,2,inf +dsd,True,True,16,128,8,2,16,2,88.12987060546875 +dsd,True,True,16,128,16,2,4,2,70.0388427734375 +dsd,True,True,16,128,32,2,2,2,inf +dsd,True,True,16,128,64,2,2,2,inf +dsd,True,True,16,128,128,2,2,2,inf +dsd,True,True,32,8,8,8,2,2,60.1640625 +dsd,True,True,32,8,16,8,8,2,22.79217987060547 +dsd,True,True,32,8,32,8,8,4,17.33423309326172 +dsd,True,True,32,8,64,8,2,8,14.491964721679688 +dsd,True,True,32,8,128,8,8,8,13.583482360839843 +dsd,True,True,32,16,8,4,16,2,42.594668579101565 
+dsd,True,True,32,16,16,8,4,2,23.31687622070313 +dsd,True,True,32,16,32,8,8,4,17.858428955078125 +dsd,True,True,32,16,64,8,16,4,15.46243133544922 +dsd,True,True,32,16,128,8,16,4,13.765618896484376 +dsd,True,True,32,32,8,4,16,2,53.68150024414062 +dsd,True,True,32,32,16,8,4,2,29.81298828125 +dsd,True,True,32,32,32,8,4,2,20.115510559082036 +dsd,True,True,32,32,64,4,2,4,17.780838012695312 +dsd,True,True,32,32,128,2,2,2,inf +dsd,True,True,32,64,8,4,16,2,58.4751220703125 +dsd,True,True,32,64,16,8,16,2,35.29141845703125 +dsd,True,True,32,64,32,8,4,2,26.837109375 +dsd,True,True,32,64,64,2,2,2,inf +dsd,True,True,32,64,128,2,2,2,inf +dsd,True,True,32,128,8,2,4,2,74.15425415039063 +dsd,True,True,32,128,16,4,2,2,49.33531799316406 +dsd,True,True,32,128,32,2,2,2,inf +dsd,True,True,32,128,64,2,2,2,inf +dsd,True,True,32,128,128,2,2,2,inf +dsd,True,True,64,8,8,16,8,2,96.929638671875 +dsd,True,True,64,8,16,8,4,4,30.505270385742183 +dsd,True,True,64,8,32,8,2,4,20.189126586914064 +dsd,True,True,64,8,64,8,8,4,15.148918151855469 +dsd,True,True,64,8,128,16,8,4,13.314019775390625 +dsd,True,True,64,16,8,8,8,2,58.3099365234375 +dsd,True,True,64,16,16,16,4,2,31.09625244140625 +dsd,True,True,64,16,32,8,2,4,20.450559997558592 +dsd,True,True,64,16,64,8,16,4,16.086697387695313 +dsd,True,True,64,16,128,16,4,4,13.981097412109374 +dsd,True,True,64,32,8,8,4,2,67.69793090820312 +dsd,True,True,64,32,16,16,16,2,36.72890014648438 +dsd,True,True,64,32,32,16,16,2,22.644483947753905 +dsd,True,True,64,32,64,8,16,4,17.29567108154297 +dsd,True,True,64,32,128,2,2,2,inf +dsd,True,True,64,64,8,8,16,2,71.69373168945313 +dsd,True,True,64,64,16,8,16,2,39.72655334472656 +dsd,True,True,64,64,32,8,8,2,25.80042419433594 +dsd,True,True,64,64,64,2,2,2,inf +dsd,True,True,64,64,128,2,2,2,inf +dsd,True,True,64,128,8,2,8,2,83.09647216796876 +dsd,True,True,64,128,16,4,16,2,47.12303771972656 +dsd,True,True,64,128,32,2,2,2,inf +dsd,True,True,64,128,64,2,2,2,inf +dsd,True,True,64,128,128,2,2,2,inf +dsd,True,True,128,8,8,2,2,2,inf +dsd,True,True,128,8,16,16,8,4,52.12281494140625 +dsd,True,True,128,8,32,16,8,4,31.438470458984376 +dsd,True,True,128,8,64,16,8,4,20.78094024658203 +dsd,True,True,128,8,128,16,2,8,17.265670776367188 +dsd,True,True,128,16,8,16,16,2,97.34880981445312 +dsd,True,True,128,16,16,16,8,2,51.683660888671874 +dsd,True,True,128,16,32,16,2,2,30.91318664550781 +dsd,True,True,128,16,64,16,16,4,21.14001007080078 +dsd,True,True,128,16,128,4,2,16,16.960723876953125 +dsd,True,True,128,32,8,8,8,2,105.91396484375 +dsd,True,True,128,32,16,16,4,2,54.87500610351562 +dsd,True,True,128,32,32,16,8,2,31.5608642578125 +dsd,True,True,128,32,64,8,4,4,21.677349853515626 +dsd,True,True,128,32,128,2,2,2,inf +dsd,True,True,128,64,8,4,16,2,109.50223388671876 +dsd,True,True,128,64,16,8,16,2,56.69862670898438 +dsd,True,True,128,64,32,16,8,2,33.59224548339844 +dsd,True,True,128,64,64,2,2,2,inf +dsd,True,True,128,64,128,2,2,2,inf +dsd,True,True,128,128,8,2,2,2,inf +dsd,True,True,128,128,16,2,2,2,inf +dsd,True,True,128,128,32,2,2,2,inf +dsd,True,True,128,128,64,2,2,2,inf +dsd,True,True,128,128,128,2,2,2,inf +sdd,False,False,8,8,8,2,2,2,72.04984130859376 +sdd,False,False,8,8,16,4,8,2,45.3943115234375 +sdd,False,False,8,8,32,8,4,2,38.74576110839844 +sdd,False,False,8,8,64,4,2,8,36.7857421875 +sdd,False,False,8,8,128,2,2,2,inf +sdd,False,False,8,16,8,2,16,2,70.55133056640625 +sdd,False,False,8,16,16,4,16,2,46.160275268554685 +sdd,False,False,8,16,32,4,2,2,36.251708984375 +sdd,False,False,8,16,64,2,2,8,34.42546691894531 +sdd,False,False,8,16,128,4,2,8,35.05966491699219 
+sdd,False,False,8,32,8,4,4,2,81.95740966796875 +sdd,False,False,8,32,16,2,4,2,46.26115417480469 +sdd,False,False,8,32,32,4,8,2,38.46751708984375 +sdd,False,False,8,32,64,8,4,2,37.41645812988281 +sdd,False,False,8,32,128,8,16,2,37.6352294921875 +sdd,False,False,8,64,8,2,4,2,74.85977172851562 +sdd,False,False,8,64,16,2,8,2,43.94977111816407 +sdd,False,False,8,64,32,2,16,2,38.66516418457032 +sdd,False,False,8,64,64,2,16,2,37.02501525878906 +sdd,False,False,8,64,128,4,2,2,37.57594604492188 +sdd,False,False,8,128,8,2,2,2,inf +sdd,False,False,8,128,16,2,8,2,50.735546875 +sdd,False,False,8,128,32,2,4,2,44.66907958984375 +sdd,False,False,8,128,64,2,16,2,42.6173095703125 +sdd,False,False,8,128,128,2,2,2,inf +sdd,False,False,16,8,8,4,4,2,40.41903381347656 +sdd,False,False,16,8,16,4,8,2,24.17235565185547 +sdd,False,False,16,8,32,8,8,2,19.93219146728516 +sdd,False,False,16,8,64,2,2,16,19.359555053710935 +sdd,False,False,16,8,128,16,8,4,19.00537872314453 +sdd,False,False,16,16,8,4,4,2,42.32752380371094 +sdd,False,False,16,16,16,2,16,4,25.86504821777344 +sdd,False,False,16,16,32,4,2,2,20.175260925292967 +sdd,False,False,16,16,64,4,16,8,19.01503295898437 +sdd,False,False,16,16,128,4,2,8,19.52960357666016 +sdd,False,False,16,32,8,4,4,2,49.81890258789063 +sdd,False,False,16,32,16,4,2,2,26.464312744140624 +sdd,False,False,16,32,32,8,2,2,20.246003723144533 +sdd,False,False,16,32,64,4,4,4,19.618144226074214 +sdd,False,False,16,32,128,4,4,4,19.372467041015625 +sdd,False,False,16,64,8,4,2,2,43.28736572265625 +sdd,False,False,16,64,16,4,8,2,26.00096740722656 +sdd,False,False,16,64,32,4,2,2,20.53853759765625 +sdd,False,False,16,64,64,4,16,2,19.274000549316405 +sdd,False,False,16,64,128,4,4,2,19.3044677734375 +sdd,False,False,16,128,8,2,16,2,43.4588134765625 +sdd,False,False,16,128,16,2,8,2,29.749468994140624 +sdd,False,False,16,128,32,2,8,2,22.932847595214845 +sdd,False,False,16,128,64,4,2,2,22.40806121826172 +sdd,False,False,16,128,128,2,2,2,inf +sdd,False,False,32,8,8,8,2,2,28.4953125 +sdd,False,False,32,8,16,8,4,2,16.595855712890625 +sdd,False,False,32,8,32,8,2,4,13.706492614746091 +sdd,False,False,32,8,64,8,4,8,13.076988220214844 +sdd,False,False,32,8,128,8,2,8,13.090707397460935 +sdd,False,False,32,16,8,8,8,2,28.655096435546877 +sdd,False,False,32,16,16,8,16,2,17.110902404785158 +sdd,False,False,32,16,32,8,4,4,13.602815246582033 +sdd,False,False,32,16,64,8,8,4,13.031321716308591 +sdd,False,False,32,16,128,8,8,4,12.628358459472656 +sdd,False,False,32,32,8,4,16,2,34.516058349609374 +sdd,False,False,32,32,16,8,16,2,18.954006958007813 +sdd,False,False,32,32,32,8,2,2,14.182521057128906 +sdd,False,False,32,32,64,8,4,2,13.5316162109375 +sdd,False,False,32,32,128,8,4,4,12.78335647583008 +sdd,False,False,32,64,8,4,8,2,33.892111206054686 +sdd,False,False,32,64,16,4,16,2,20.76946868896484 +sdd,False,False,32,64,32,4,8,2,15.313661193847656 +sdd,False,False,32,64,64,8,16,2,13.90250244140625 +sdd,False,False,32,64,128,8,16,2,13.511581420898438 +sdd,False,False,32,128,8,2,2,2,inf +sdd,False,False,32,128,16,2,2,2,inf +sdd,False,False,32,128,32,2,2,2,inf +sdd,False,False,32,128,64,2,2,2,inf +sdd,False,False,32,128,128,2,2,2,inf +sdd,False,False,64,8,8,8,4,4,23.24774017333984 +sdd,False,False,64,8,16,8,8,4,13.581919860839845 +sdd,False,False,64,8,32,8,2,8,12.061090850830078 +sdd,False,False,64,8,64,8,2,8,11.904332733154297 +sdd,False,False,64,8,128,8,8,8,11.4974365234375 +sdd,False,False,64,16,8,8,8,2,22.200015258789065 +sdd,False,False,64,16,16,16,16,2,13.94998779296875 +sdd,False,False,64,16,32,8,2,4,12.431423950195311 
+sdd,False,False,64,16,64,8,16,4,11.724240112304688 +sdd,False,False,64,16,128,16,4,4,11.464300537109375 +sdd,False,False,64,32,8,8,16,2,29.659991455078124 +sdd,False,False,64,32,16,8,4,2,16.99700164794922 +sdd,False,False,64,32,32,8,16,2,12.782505798339844 +sdd,False,False,64,32,64,8,16,4,11.693708801269532 +sdd,False,False,64,32,128,16,16,2,11.619532775878906 +sdd,False,False,64,64,8,2,2,2,inf +sdd,False,False,64,64,16,2,2,2,inf +sdd,False,False,64,64,32,2,2,2,inf +sdd,False,False,64,64,64,2,2,2,inf +sdd,False,False,64,64,128,2,2,2,inf +sdd,False,False,64,128,8,2,2,2,inf +sdd,False,False,64,128,16,2,2,2,inf +sdd,False,False,64,128,32,2,2,2,inf +sdd,False,False,64,128,64,2,2,2,inf +sdd,False,False,64,128,128,2,2,2,inf +sdd,False,False,128,8,8,16,4,4,22.090733337402344 +sdd,False,False,128,8,16,16,4,4,12.416374206542969 +sdd,False,False,128,8,32,8,2,8,11.87421112060547 +sdd,False,False,128,8,64,8,8,8,11.37545928955078 +sdd,False,False,128,8,128,8,8,8,12.38086395263672 +sdd,False,False,128,16,8,16,4,2,19.331510925292967 +sdd,False,False,128,16,16,16,16,2,13.283071899414065 +sdd,False,False,128,16,32,8,16,4,11.54400634765625 +sdd,False,False,128,16,64,16,8,4,11.143462371826171 +sdd,False,False,128,16,128,8,4,8,11.590351867675782 +sdd,False,False,128,32,8,2,2,2,inf +sdd,False,False,128,32,16,2,2,2,inf +sdd,False,False,128,32,32,2,2,2,inf +sdd,False,False,128,32,64,2,2,2,inf +sdd,False,False,128,32,128,2,2,2,inf +sdd,False,False,128,64,8,2,2,2,inf +sdd,False,False,128,64,16,2,2,2,inf +sdd,False,False,128,64,32,2,2,2,inf +sdd,False,False,128,64,64,2,2,2,inf +sdd,False,False,128,64,128,2,2,2,inf +sdd,False,False,128,128,8,2,2,2,inf +sdd,False,False,128,128,16,2,2,2,inf +sdd,False,False,128,128,32,2,2,2,inf +sdd,False,False,128,128,64,2,2,2,inf +sdd,False,False,128,128,128,2,2,2,inf +sdd,False,True,8,8,8,2,2,2,91.77452392578124 +sdd,False,True,8,8,16,2,8,8,96.86319580078126 +sdd,False,True,8,8,32,2,4,16,99.28195190429688 +sdd,False,True,8,8,64,4,8,16,129.086669921875 +sdd,False,True,8,8,128,8,4,16,91.06226806640623 +sdd,False,True,8,16,8,2,16,2,66.354736328125 +sdd,False,True,8,16,16,4,16,2,48.97417907714844 +sdd,False,True,8,16,32,4,16,2,46.55622253417969 +sdd,False,True,8,16,64,8,16,4,45.48564758300781 +sdd,False,True,8,16,128,2,4,16,44.08933715820312 +sdd,False,True,8,32,8,2,8,2,89.6635986328125 +sdd,False,True,8,32,16,2,4,2,46.87720947265625 +sdd,False,True,8,32,32,4,16,2,38.79921264648438 +sdd,False,True,8,32,64,4,8,2,38.43687744140625 +sdd,False,True,8,32,128,8,4,2,38.750390625 +sdd,False,True,8,64,8,2,2,2,131.63740234375 +sdd,False,True,8,64,16,2,16,2,66.56493530273437 +sdd,False,True,8,64,32,4,2,2,56.46593017578125 +sdd,False,True,8,64,64,4,16,2,54.95045166015625 +sdd,False,True,8,64,128,4,4,2,55.65963745117188 +sdd,False,True,8,128,8,2,2,2,inf +sdd,False,True,8,128,16,2,8,2,106.23653564453124 +sdd,False,True,8,128,32,2,2,2,101.65936889648438 +sdd,False,True,8,128,64,2,4,2,100.1741455078125 +sdd,False,True,8,128,128,2,2,2,inf +sdd,False,True,16,8,8,4,8,2,47.9957275390625 +sdd,False,True,16,8,16,8,4,4,49.27310485839844 +sdd,False,True,16,8,32,16,8,4,47.62242431640625 +sdd,False,True,16,8,64,2,4,16,54.047509765625 +sdd,False,True,16,8,128,8,2,8,52.71130981445312 +sdd,False,True,16,16,8,2,4,2,35.24564819335937 +sdd,False,True,16,16,16,2,16,4,27.9607421875 +sdd,False,True,16,16,32,8,4,2,24.983120727539063 +sdd,False,True,16,16,64,4,2,8,24.548185729980467 +sdd,False,True,16,16,128,4,2,8,23.86432342529297 +sdd,False,True,16,32,8,2,8,2,47.88221130371094 
+sdd,False,True,16,32,16,4,8,2,29.615542602539065 +sdd,False,True,16,32,32,8,16,2,22.84046783447265 +sdd,False,True,16,32,64,4,4,4,21.95774688720703 +sdd,False,True,16,32,128,4,8,4,21.562205505371093 +sdd,False,True,16,64,8,2,16,2,58.48067626953125 +sdd,False,True,16,64,16,4,16,2,39.72404174804687 +sdd,False,True,16,64,32,8,4,2,33.36689147949219 +sdd,False,True,16,64,64,4,4,2,32.759228515625 +sdd,False,True,16,64,128,4,2,4,32.71089172363281 +sdd,False,True,16,128,8,2,16,2,77.55657348632812 +sdd,False,True,16,128,16,4,2,2,62.15914916992188 +sdd,False,True,16,128,32,4,4,2,55.14508056640625 +sdd,False,True,16,128,64,4,16,2,53.7142578125 +sdd,False,True,16,128,128,2,2,2,inf +sdd,False,True,32,8,8,8,4,2,30.496945190429688 +sdd,False,True,32,8,16,16,8,2,25.50102081298828 +sdd,False,True,32,8,32,16,4,4,24.51743621826172 +sdd,False,True,32,8,64,8,4,16,27.03175048828125 +sdd,False,True,32,8,128,8,2,16,34.25230102539062 +sdd,False,True,32,16,8,4,4,2,26.0266845703125 +sdd,False,True,32,16,16,8,2,2,17.680572509765625 +sdd,False,True,32,16,32,8,2,4,14.919331359863282 +sdd,False,True,32,16,64,8,16,4,14.73011474609375 +sdd,False,True,32,16,128,8,8,4,14.204704284667969 +sdd,False,True,32,32,8,4,16,2,36.43758544921875 +sdd,False,True,32,32,16,8,16,2,21.21480255126953 +sdd,False,True,32,32,32,8,16,2,16.7617919921875 +sdd,False,True,32,32,64,8,2,4,16.181004333496094 +sdd,False,True,32,32,128,16,16,2,15.448342895507812 +sdd,False,True,32,64,8,4,8,2,41.87442932128906 +sdd,False,True,32,64,16,8,4,2,26.48095703125 +sdd,False,True,32,64,32,4,16,2,22.530685424804688 +sdd,False,True,32,64,64,8,16,2,21.07642822265625 +sdd,False,True,32,64,128,8,2,4,21.36976013183594 +sdd,False,True,32,128,8,2,2,2,inf +sdd,False,True,32,128,16,2,2,2,inf +sdd,False,True,32,128,32,2,2,2,inf +sdd,False,True,32,128,64,2,2,2,inf +sdd,False,True,32,128,128,2,2,2,inf +sdd,False,True,64,8,8,16,2,2,24.09486389160156 +sdd,False,True,64,8,16,8,2,4,15.721728515625 +sdd,False,True,64,8,32,16,8,4,14.562950134277344 +sdd,False,True,64,8,64,8,2,4,13.80590362548828 +sdd,False,True,64,8,128,16,2,4,13.107994079589844 +sdd,False,True,64,16,8,8,2,2,21.141920471191405 +sdd,False,True,64,16,16,16,2,2,14.441622924804689 +sdd,False,True,64,16,32,8,2,4,13.028182983398438 +sdd,False,True,64,16,64,8,8,4,12.355359649658205 +sdd,False,True,64,16,128,16,4,4,12.280850982666015 +sdd,False,True,64,32,8,8,8,2,30.81255187988281 +sdd,False,True,64,32,16,8,16,2,18.20790405273437 +sdd,False,True,64,32,32,8,4,2,14.05712890625 +sdd,False,True,64,32,64,8,16,4,12.909103393554688 +sdd,False,True,64,32,128,8,2,4,13.388832092285156 +sdd,False,True,64,64,8,2,2,2,inf +sdd,False,True,64,64,16,2,2,2,inf +sdd,False,True,64,64,32,2,2,2,inf +sdd,False,True,64,64,64,2,2,2,inf +sdd,False,True,64,64,128,2,2,2,inf +sdd,False,True,64,128,8,2,2,2,inf +sdd,False,True,64,128,16,2,2,2,inf +sdd,False,True,64,128,32,2,2,2,inf +sdd,False,True,64,128,64,2,2,2,inf +sdd,False,True,64,128,128,2,2,2,inf +sdd,False,True,128,8,8,16,4,4,22.91182098388672 +sdd,False,True,128,8,16,16,2,4,13.245916748046875 +sdd,False,True,128,8,32,8,8,8,12.6768798828125 +sdd,False,True,128,8,64,8,8,8,11.811756896972655 +sdd,False,True,128,8,128,8,2,8,12.738713836669922 +sdd,False,True,128,16,8,16,8,2,18.884083557128907 +sdd,False,True,128,16,16,16,4,2,13.472808837890623 +sdd,False,True,128,16,32,8,16,4,12.021001434326172 +sdd,False,True,128,16,64,16,4,4,11.51459503173828 +sdd,False,True,128,16,128,8,4,8,12.582121276855467 +sdd,False,True,128,32,8,2,2,2,inf +sdd,False,True,128,32,16,2,2,2,inf 
+sdd,False,True,128,32,32,2,2,2,inf +sdd,False,True,128,32,64,2,2,2,inf +sdd,False,True,128,32,128,2,2,2,inf +sdd,False,True,128,64,8,2,2,2,inf +sdd,False,True,128,64,16,2,2,2,inf +sdd,False,True,128,64,32,2,2,2,inf +sdd,False,True,128,64,64,2,2,2,inf +sdd,False,True,128,64,128,2,2,2,inf +sdd,False,True,128,128,8,2,2,2,inf +sdd,False,True,128,128,16,2,2,2,inf +sdd,False,True,128,128,32,2,2,2,inf +sdd,False,True,128,128,64,2,2,2,inf +sdd,False,True,128,128,128,2,2,2,inf +sdd,True,False,8,8,8,2,2,2,73.9848388671875 +sdd,True,False,8,8,16,4,4,2,46.17842712402344 +sdd,True,False,8,8,32,8,4,2,37.888674926757815 +sdd,True,False,8,8,64,4,8,8,34.64777526855469 +sdd,True,False,8,8,128,2,2,2,inf +sdd,True,False,8,16,8,2,2,2,73.74010620117187 +sdd,True,False,8,16,16,4,4,2,46.40711059570312 +sdd,True,False,8,16,32,2,2,4,34.76233825683594 +sdd,True,False,8,16,64,2,2,8,32.09572143554688 +sdd,True,False,8,16,128,8,8,4,32.73079223632813 +sdd,True,False,8,32,8,4,8,2,86.20098876953125 +sdd,True,False,8,32,16,4,2,2,44.73761596679688 +sdd,True,False,8,32,32,4,2,2,37.705255126953126 +sdd,True,False,8,32,64,8,2,2,37.21822814941406 +sdd,True,False,8,32,128,8,4,2,38.06083374023437 +sdd,True,False,8,64,8,4,2,2,84.308056640625 +sdd,True,False,8,64,16,2,4,2,45.06987915039063 +sdd,True,False,8,64,32,2,8,2,40.30692443847656 +sdd,True,False,8,64,64,2,4,2,38.734326171875 +sdd,True,False,8,64,128,4,2,2,25.005967712402345 +sdd,True,False,8,128,8,2,8,2,70.26617431640625 +sdd,True,False,8,128,16,2,2,2,50.441259765625 +sdd,True,False,8,128,32,2,16,2,35.400308227539064 +sdd,True,False,8,128,64,2,4,2,39.35030517578125 +sdd,True,False,8,128,128,2,2,2,inf +sdd,True,False,16,8,8,4,4,2,51.97205810546875 +sdd,True,False,16,8,16,4,4,2,25.86021728515625 +sdd,True,False,16,8,32,8,2,2,19.348057556152344 +sdd,True,False,16,8,64,16,8,2,17.644924926757813 +sdd,True,False,16,8,128,16,8,4,16.097267150878906 +sdd,True,False,16,16,8,2,4,2,49.89132080078125 +sdd,True,False,16,16,16,4,2,2,26.679196166992188 +sdd,True,False,16,16,32,8,4,2,19.98988494873047 +sdd,True,False,16,16,64,16,4,2,17.62073669433594 +sdd,True,False,16,16,128,4,8,8,18.66569213867188 +sdd,True,False,16,32,8,2,2,2,53.354669189453126 +sdd,True,False,16,32,16,4,16,2,27.72266540527344 +sdd,True,False,16,32,32,8,8,2,18.98997802734375 +sdd,True,False,16,32,64,4,8,2,19.319593811035155 +sdd,True,False,16,32,128,2,2,8,18.159635925292967 +sdd,True,False,16,64,8,2,2,2,49.46534423828125 +sdd,True,False,16,64,16,4,2,2,26.9336669921875 +sdd,True,False,16,64,32,4,8,2,21.0331298828125 +sdd,True,False,16,64,64,4,2,2,19.843218994140624 +sdd,True,False,16,64,128,4,2,2,19.66313934326172 +sdd,True,False,16,128,8,2,16,2,51.30648193359375 +sdd,True,False,16,128,16,2,2,2,30.73262023925781 +sdd,True,False,16,128,32,2,8,2,23.573583984375 +sdd,True,False,16,128,64,2,2,2,22.738954162597658 +sdd,True,False,16,128,128,2,2,2,inf +sdd,True,False,32,8,8,8,2,2,59.884222412109374 +sdd,True,False,32,8,16,8,4,2,22.643101501464844 +sdd,True,False,32,8,32,8,8,4,15.67876739501953 +sdd,True,False,32,8,64,8,8,8,13.152056884765624 +sdd,True,False,32,8,128,8,8,8,12.867955017089844 +sdd,True,False,32,16,8,4,4,2,42.48688049316407 +sdd,True,False,32,16,16,8,16,2,22.31826171875 +sdd,True,False,32,16,32,8,16,4,15.291162109375 +sdd,True,False,32,16,64,8,4,4,13.823011779785157 +sdd,True,False,32,16,128,8,8,4,12.749072265625 +sdd,True,False,32,32,8,4,16,2,51.068521118164064 +sdd,True,False,32,32,16,8,2,2,26.098941040039065 +sdd,True,False,32,32,32,8,4,2,16.640992736816408 +sdd,True,False,32,32,64,8,4,4,14.561097717285156 
+sdd,True,False,32,32,128,8,2,4,12.976034545898438 +sdd,True,False,32,64,8,4,16,2,50.79121704101563 +sdd,True,False,32,64,16,8,16,2,27.523394775390624 +sdd,True,False,32,64,32,4,2,2,18.135264587402343 +sdd,True,False,32,64,64,8,8,2,14.8186279296875 +sdd,True,False,32,64,128,8,16,2,13.958329772949218 +sdd,True,False,32,128,8,2,2,2,inf +sdd,True,False,32,128,16,2,2,2,inf +sdd,True,False,32,128,32,2,2,2,inf +sdd,True,False,32,128,64,2,2,2,inf +sdd,True,False,32,128,128,2,2,2,inf +sdd,True,False,64,8,8,16,4,2,96.2021240234375 +sdd,True,False,64,8,16,8,2,4,30.0603515625 +sdd,True,False,64,8,32,8,8,4,19.49802856445313 +sdd,True,False,64,8,64,8,2,4,15.104576110839844 +sdd,True,False,64,8,128,8,8,8,12.94129638671875 +sdd,True,False,64,16,8,8,4,2,58.28064575195312 +sdd,True,False,64,16,16,16,8,2,30.451214599609376 +sdd,True,False,64,16,32,8,2,4,19.571139526367187 +sdd,True,False,64,16,64,16,2,2,15.562969970703126 +sdd,True,False,64,16,128,16,4,4,13.360887145996092 +sdd,True,False,64,32,8,8,16,2,66.38638305664062 +sdd,True,False,64,32,16,8,8,2,34.95736999511719 +sdd,True,False,64,32,32,16,8,2,20.5401123046875 +sdd,True,False,64,32,64,8,16,4,15.419955444335937 +sdd,True,False,64,32,128,8,2,4,13.664102172851562 +sdd,True,False,64,64,8,2,2,2,inf +sdd,True,False,64,64,16,2,2,2,inf +sdd,True,False,64,64,32,2,2,2,inf +sdd,True,False,64,64,64,2,2,2,inf +sdd,True,False,64,64,128,2,2,2,inf +sdd,True,False,64,128,8,2,2,2,inf +sdd,True,False,64,128,16,2,2,2,inf +sdd,True,False,64,128,32,2,2,2,inf +sdd,True,False,64,128,64,2,2,2,inf +sdd,True,False,64,128,128,2,2,2,inf +sdd,True,False,128,8,8,2,2,2,inf +sdd,True,False,128,8,16,16,8,4,52.1150146484375 +sdd,True,False,128,8,32,8,8,8,30.83111572265625 +sdd,True,False,128,8,64,8,8,8,20.565753173828124 +sdd,True,False,128,8,128,16,2,8,16.8664794921875 +sdd,True,False,128,16,8,16,8,2,97.72430419921876 +sdd,True,False,128,16,16,16,2,2,51.40657958984375 +sdd,True,False,128,16,32,16,2,2,30.401171875 +sdd,True,False,128,16,64,16,4,4,20.604071044921877 +sdd,True,False,128,16,128,8,16,8,16.067298889160156 +sdd,True,False,128,32,8,2,2,2,inf +sdd,True,False,128,32,16,2,2,2,inf +sdd,True,False,128,32,32,2,2,2,inf +sdd,True,False,128,32,64,2,2,2,inf +sdd,True,False,128,32,128,2,2,2,inf +sdd,True,False,128,64,8,2,2,2,inf +sdd,True,False,128,64,16,2,2,2,inf +sdd,True,False,128,64,32,2,2,2,inf +sdd,True,False,128,64,64,2,2,2,inf +sdd,True,False,128,64,128,2,2,2,inf +sdd,True,False,128,128,8,2,2,2,inf +sdd,True,False,128,128,16,2,2,2,inf +sdd,True,False,128,128,32,2,2,2,inf +sdd,True,False,128,128,64,2,2,2,inf +sdd,True,False,128,128,128,2,2,2,inf +sdd,True,True,8,8,8,2,2,2,70.79592895507812 +sdd,True,True,8,8,16,2,4,4,65.1414794921875 +sdd,True,True,8,8,32,8,4,4,83.34755859375 +sdd,True,True,8,8,64,2,2,16,119.6942626953125 +sdd,True,True,8,8,128,8,4,8,92.86500854492188 +sdd,True,True,8,16,8,2,8,2,73.23809814453125 +sdd,True,True,8,16,16,4,4,2,50.65773010253906 +sdd,True,True,8,16,32,2,16,4,46.528628540039065 +sdd,True,True,8,16,64,8,8,4,44.46695556640625 +sdd,True,True,8,16,128,2,16,16,42.68679504394531 +sdd,True,True,8,32,8,2,16,2,96.91433715820312 +sdd,True,True,8,32,16,2,16,2,48.09185791015625 +sdd,True,True,8,32,32,4,4,2,39.37124938964844 +sdd,True,True,8,32,64,4,2,2,39.133636474609375 +sdd,True,True,8,32,128,8,8,2,40.11669006347656 +sdd,True,True,8,64,8,2,2,2,138.27012939453124 +sdd,True,True,8,64,16,2,16,2,67.56966552734374 +sdd,True,True,8,64,32,4,8,2,58.15477294921875 +sdd,True,True,8,64,64,4,4,2,56.105419921875 +sdd,True,True,8,64,128,4,8,2,56.26817626953125 
+sdd,True,True,8,128,8,2,2,2,inf +sdd,True,True,8,128,16,2,16,2,108.03819580078124 +sdd,True,True,8,128,32,2,4,2,101.84613647460938 +sdd,True,True,8,128,64,2,16,2,99.8545654296875 +sdd,True,True,8,128,128,2,2,2,inf +sdd,True,True,16,8,8,4,2,2,53.96688232421875 +sdd,True,True,16,8,16,8,2,2,36.35191345214844 +sdd,True,True,16,8,32,8,2,8,48.81896362304688 +sdd,True,True,16,8,64,2,8,16,46.240542602539065 +sdd,True,True,16,8,128,8,2,8,42.79500122070313 +sdd,True,True,16,16,8,2,16,2,42.394622802734375 +sdd,True,True,16,16,16,4,4,2,27.01558532714844 +sdd,True,True,16,16,32,8,16,2,24.6191162109375 +sdd,True,True,16,16,64,8,2,4,23.42810821533203 +sdd,True,True,16,16,128,8,4,4,22.794557189941408 +sdd,True,True,16,32,8,2,16,2,55.01532592773437 +sdd,True,True,16,32,16,4,2,2,33.358612060546875 +sdd,True,True,16,32,32,8,2,2,23.708087158203124 +sdd,True,True,16,32,64,4,2,4,22.57019805908203 +sdd,True,True,16,32,128,4,4,4,21.4740966796875 +sdd,True,True,16,64,8,2,16,2,65.340087890625 +sdd,True,True,16,64,16,4,8,2,43.76452026367188 +sdd,True,True,16,64,32,8,16,2,34.87169189453125 +sdd,True,True,16,64,64,4,4,4,33.64231872558594 +sdd,True,True,16,64,128,4,2,4,32.75396728515625 +sdd,True,True,16,128,8,2,16,2,84.63165283203125 +sdd,True,True,16,128,16,4,2,2,65.06669311523437 +sdd,True,True,16,128,32,4,2,2,56.94905395507813 +sdd,True,True,16,128,64,4,4,2,55.36854858398438 +sdd,True,True,16,128,128,2,2,2,inf +sdd,True,True,32,8,8,8,8,2,61.72941284179687 +sdd,True,True,32,8,16,8,2,2,24.83875274658203 +sdd,True,True,32,8,32,8,4,2,20.58190460205078 +sdd,True,True,32,8,64,8,8,4,26.2571044921875 +sdd,True,True,32,8,128,8,4,16,26.21186828613281 +sdd,True,True,32,16,8,4,16,2,43.00454711914063 +sdd,True,True,32,16,16,8,4,2,23.92322540283203 +sdd,True,True,32,16,32,8,8,4,17.551664733886717 +sdd,True,True,32,16,64,8,2,4,15.735151672363282 +sdd,True,True,32,16,128,8,4,4,14.5494873046875 +sdd,True,True,32,32,8,4,8,2,53.545587158203126 +sdd,True,True,32,32,16,8,16,2,29.740457153320317 +sdd,True,True,32,32,32,8,4,2,20.164134216308597 +sdd,True,True,32,32,64,8,4,4,17.43151092529297 +sdd,True,True,32,32,128,16,4,2,16.06116180419922 +sdd,True,True,32,64,8,4,16,2,58.31881713867188 +sdd,True,True,32,64,16,8,16,2,34.770635986328124 +sdd,True,True,32,64,32,8,8,2,26.29350891113281 +sdd,True,True,32,64,64,8,16,2,22.838330078125 +sdd,True,True,32,64,128,8,2,4,22.09839630126953 +sdd,True,True,32,128,8,2,2,2,inf +sdd,True,True,32,128,16,2,2,2,inf +sdd,True,True,32,128,32,2,2,2,inf +sdd,True,True,32,128,64,2,2,2,inf +sdd,True,True,32,128,128,2,2,2,inf +sdd,True,True,64,8,8,16,2,2,97.34681396484376 +sdd,True,True,64,8,16,8,8,4,30.87897644042969 +sdd,True,True,64,8,32,8,8,8,20.65984649658203 +sdd,True,True,64,8,64,8,2,4,16.07484130859375 +sdd,True,True,64,8,128,16,8,4,13.826019287109377 +sdd,True,True,64,16,8,8,2,2,58.582769775390624 +sdd,True,True,64,16,16,16,8,2,31.01230773925781 +sdd,True,True,64,16,32,8,2,4,20.705232238769533 +sdd,True,True,64,16,64,8,16,4,16.316441345214844 +sdd,True,True,64,16,128,16,8,4,14.138400268554689 +sdd,True,True,64,32,8,8,8,2,67.5900146484375 +sdd,True,True,64,32,16,16,4,2,36.62651977539063 +sdd,True,True,64,32,32,16,8,2,22.28140869140625 +sdd,True,True,64,32,64,8,16,4,17.178799438476563 +sdd,True,True,64,32,128,8,4,4,15.548410034179687 +sdd,True,True,64,64,8,2,2,2,inf +sdd,True,True,64,64,16,2,2,2,inf +sdd,True,True,64,64,32,2,2,2,inf +sdd,True,True,64,64,64,2,2,2,inf +sdd,True,True,64,64,128,2,2,2,inf +sdd,True,True,64,128,8,2,2,2,inf +sdd,True,True,64,128,16,2,2,2,inf +sdd,True,True,64,128,32,2,2,2,inf 
+sdd,True,True,64,128,64,2,2,2,inf +sdd,True,True,64,128,128,2,2,2,inf +sdd,True,True,128,8,8,2,2,2,inf +sdd,True,True,128,8,16,16,8,4,52.57091064453125 +sdd,True,True,128,8,32,8,8,8,31.27769775390625 +sdd,True,True,128,8,64,16,8,4,20.810726928710935 +sdd,True,True,128,8,128,16,8,8,17.33007354736328 +sdd,True,True,128,16,8,16,16,2,97.82230224609376 +sdd,True,True,128,16,16,16,2,2,52.05543212890625 +sdd,True,True,128,16,32,16,2,2,30.84331970214844 +sdd,True,True,128,16,64,16,4,4,21.341754150390624 +sdd,True,True,128,16,128,4,16,16,17.054681396484376 +sdd,True,True,128,32,8,2,2,2,inf +sdd,True,True,128,32,16,2,2,2,inf +sdd,True,True,128,32,32,2,2,2,inf +sdd,True,True,128,32,64,2,2,2,inf +sdd,True,True,128,32,128,2,2,2,inf +sdd,True,True,128,64,8,2,2,2,inf +sdd,True,True,128,64,16,2,2,2,inf +sdd,True,True,128,64,32,2,2,2,inf +sdd,True,True,128,64,64,2,2,2,inf +sdd,True,True,128,64,128,2,2,2,inf +sdd,True,True,128,128,8,2,2,2,inf +sdd,True,True,128,128,16,2,2,2,inf +sdd,True,True,128,128,32,2,2,2,inf +sdd,True,True,128,128,64,2,2,2,inf +sdd,True,True,128,128,128,2,2,2,inf diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv b/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv index 7a628265..a42e7ef2 100644 --- a/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv +++ b/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv @@ -1,325 +1,1501 @@ mode,trans_A,trans_B,BM,BK,BN,TM,TK,TN,latency -dds,False,False,16,16,16,4,4,4,114.3322 -dds,False,False,16,16,32,4,4,4,82.87 -dds,False,False,16,16,64,4,4,4,74.2951 -dds,False,False,16,32,16,4,4,4,120.1529 -dds,False,False,16,32,32,4,4,4,87.5296 -dds,False,False,16,32,64,4,4,4,79.7643 -dds,False,False,16,64,16,4,4,4,148.3481 -dds,False,False,16,64,32,4,4,4,92.0502 -dds,False,False,16,64,64,4,4,4,84.0627 -dds,False,False,32,16,16,4,4,4,77.7297 -dds,False,False,32,16,32,4,4,4,54.2717 -dds,False,False,32,16,64,4,4,4,52.089 -dds,False,False,32,32,16,4,4,4,83.1408 -dds,False,False,32,32,32,4,4,4,59.8523 -dds,False,False,32,32,64,4,4,4,54.2046 -dds,False,False,32,64,16,4,4,4,98.1477 -dds,False,False,32,64,32,4,4,4,69.5779 -dds,False,False,32,64,64,4,4,4,56.5643 -dds,False,False,64,16,16,4,4,4,64.2577 -dds,False,False,64,16,32,4,4,4,50.0309 -dds,False,False,64,16,64,4,4,4,50.0309 -dds,False,False,64,32,16,4,4,4,75.437 -dds,False,False,64,32,32,4,4,4,54.0036 -dds,False,False,64,32,64,4,4,4,54.0036 -dds,False,False,64,64,16,4,4,4,88.9933 -dds,False,False,64,64,32,4,4,4,56.6697 -dds,False,False,64,64,64,4,4,4,56.6697 -dds,False,True,16,16,16,4,4,4,134.6172 -dds,False,True,16,16,32,4,4,4,106.9163 -dds,False,True,16,16,64,4,4,4,101.0877 -dds,False,True,16,32,16,4,4,4,177.8903 -dds,False,True,16,32,32,4,4,4,96.7729 -dds,False,True,16,32,64,4,4,4,89.2383 -dds,False,True,16,64,16,4,4,4,281.8069 -dds,False,True,16,64,32,4,4,4,137.432 -dds,False,True,16,64,64,4,4,4,132.7488 -dds,False,True,32,16,16,4,4,4,83.1024 -dds,False,True,32,16,32,4,4,4,65.4798 -dds,False,True,32,16,64,4,4,4,59.4675 -dds,False,True,32,32,16,4,4,4,93.876 -dds,False,True,32,32,32,4,4,4,72.5978 -dds,False,True,32,32,64,4,4,4,65.1659 -dds,False,True,32,64,16,4,4,4,120.1478 -dds,False,True,32,64,32,4,4,4,98.9469 -dds,False,True,32,64,64,4,4,4,83.0436 -dds,False,True,64,16,16,4,4,4,65.8737 -dds,False,True,64,16,32,4,4,4,53.26 -dds,False,True,64,16,64,4,4,4,53.26 -dds,False,True,64,32,16,4,4,4,81.8998 -dds,False,True,64,32,32,4,4,4,60.1544 -dds,False,True,64,32,64,4,4,4,60.1544 -dds,False,True,64,64,16,4,4,4,99.884 
-dds,False,True,64,64,32,4,4,4,70.5451 -dds,False,True,64,64,64,4,4,4,70.5451 -dds,True,False,16,16,16,4,4,4,135.2541 -dds,True,False,16,16,32,4,4,4,88.3433 -dds,True,False,16,16,64,4,4,4,80.7328 -dds,True,False,16,32,16,4,4,4,147.7196 -dds,True,False,16,32,32,4,4,4,83.4137 -dds,True,False,16,32,64,4,4,4,80.5606 -dds,True,False,16,64,16,4,4,4,178.6693 -dds,True,False,16,64,32,4,4,4,96.307 -dds,True,False,16,64,64,4,4,4,85.6432 -dds,True,False,32,16,16,4,4,4,97.407 -dds,True,False,32,16,32,4,4,4,64.0028 -dds,True,False,32,16,64,4,4,4,56.0378 -dds,True,False,32,32,16,4,4,4,112.9191 -dds,True,False,32,32,32,4,4,4,71.4564 -dds,True,False,32,32,64,4,4,4,59.7966 -dds,True,False,32,64,16,4,4,4,124.0738 -dds,True,False,32,64,32,4,4,4,80.3177 -dds,True,False,32,64,64,4,4,4,61.5147 -dds,True,False,64,16,16,4,4,4,124.6127 -dds,True,False,64,16,32,4,4,4,76.9557 -dds,True,False,64,16,64,4,4,4,76.9557 -dds,True,False,64,32,16,4,4,4,143.3676 -dds,True,False,64,32,32,4,4,4,84.2901 -dds,True,False,64,32,64,4,4,4,84.2901 -dds,True,False,64,64,16,4,4,4,150.3928 -dds,True,False,64,64,32,4,4,4,87.5746 -dds,True,False,64,64,64,4,4,4,87.5746 -dds,True,True,16,16,16,4,4,4,163.2706 -dds,True,True,16,16,32,4,4,4,111.4495 -dds,True,True,16,16,64,4,4,4,105.2302 -dds,True,True,16,32,16,4,4,4,215.4126 -dds,True,True,16,32,32,4,4,4,103.8571 -dds,True,True,16,32,64,4,4,4,92.4045 -dds,True,True,16,64,16,4,4,4,310.5157 -dds,True,True,16,64,32,4,4,4,149.429 -dds,True,True,16,64,64,4,4,4,134.981 -dds,True,True,32,16,16,4,4,4,103.7712 -dds,True,True,32,16,32,4,4,4,70.4374 -dds,True,True,32,16,64,4,4,4,63.2451 -dds,True,True,32,32,16,4,4,4,127.7935 -dds,True,True,32,32,32,4,4,4,85.6814 -dds,True,True,32,32,64,4,4,4,69.7582 -dds,True,True,32,64,16,4,4,4,155.0835 -dds,True,True,32,64,32,4,4,4,112.8103 -dds,True,True,32,64,64,4,4,4,91.3851 -dds,True,True,64,16,16,4,4,4,126.8568 -dds,True,True,64,16,32,4,4,4,81.2913 -dds,True,True,64,16,64,4,4,4,81.2913 -dds,True,True,64,32,16,4,4,4,149.8753 -dds,True,True,64,32,32,4,4,4,91.8136 -dds,True,True,64,32,64,4,4,4,91.8136 -dds,True,True,64,64,16,4,4,4,167.0518 -dds,True,True,64,64,32,4,4,4,105.5301 -dds,True,True,64,64,64,4,4,4,105.5301 -dsd,False,False,16,16,16,4,4,4,108.1017 -dsd,False,False,16,16,32,4,4,4,87.8983 -dsd,False,False,16,16,64,4,4,4,80.2562 -dsd,False,False,16,32,16,4,4,4,111.9973 -dsd,False,False,16,32,32,4,4,4,83.6637 -dsd,False,False,16,32,64,4,4,4,83.4532 -dsd,False,False,16,64,16,4,4,4,135.7782 -dsd,False,False,16,64,32,4,4,4,96.0233 -dsd,False,False,16,64,64,4,4,4,96.0233 -dsd,False,False,32,16,16,4,4,4,72.3478 -dsd,False,False,32,16,32,4,4,4,57.4545 -dsd,False,False,32,16,64,4,4,4,52.2037 -dsd,False,False,32,32,16,4,4,4,78.8152 -dsd,False,False,32,32,32,4,4,4,59.3432 -dsd,False,False,32,32,64,4,4,4,54.5784 -dsd,False,False,32,64,16,4,4,4,91.2725 -dsd,False,False,32,64,32,4,4,4,67.9594 -dsd,False,False,32,64,64,4,4,4,67.9594 -dsd,False,False,64,16,16,4,4,4,60.3805 -dsd,False,False,64,16,32,4,4,4,49.5804 -dsd,False,False,64,16,64,4,4,4,46.6983 -dsd,False,False,64,32,16,4,4,4,73.6614 -dsd,False,False,64,32,32,4,4,4,53.7312 -dsd,False,False,64,32,64,4,4,4,46.5469 -dsd,False,False,64,64,16,4,4,4,85.0814 -dsd,False,False,64,64,32,4,4,4,56.0396 -dsd,False,False,64,64,64,4,4,4,56.0396 -dsd,False,True,16,16,16,4,4,4,130.935 -dsd,False,True,16,16,32,4,4,4,109.5658 -dsd,False,True,16,16,64,4,4,4,92.5529 -dsd,False,True,16,32,16,4,4,4,178.6903 -dsd,False,True,16,32,32,4,4,4,101.6086 -dsd,False,True,16,32,64,4,4,4,96.7423 -dsd,False,True,16,64,16,4,4,4,278.8124 
-dsd,False,True,16,64,32,4,4,4,142.8228 -dsd,False,True,16,64,64,4,4,4,142.8228 -dsd,False,True,32,16,16,4,4,4,87.9781 -dsd,False,True,32,16,32,4,4,4,66.9843 -dsd,False,True,32,16,64,4,4,4,59.9348 -dsd,False,True,32,32,16,4,4,4,94.2587 -dsd,False,True,32,32,32,4,4,4,73.2396 -dsd,False,True,32,32,64,4,4,4,65.7842 -dsd,False,True,32,64,16,4,4,4,124.7007 -dsd,False,True,32,64,32,4,4,4,98.0345 -dsd,False,True,32,64,64,4,4,4,98.0345 -dsd,False,True,64,16,16,4,4,4,68.01 -dsd,False,True,64,16,32,4,4,4,54.7027 -dsd,False,True,64,16,64,4,4,4,49.913 -dsd,False,True,64,32,16,4,4,4,82.1206 -dsd,False,True,64,32,32,4,4,4,61.7535 -dsd,False,True,64,32,64,4,4,4,53.0829 -dsd,False,True,64,64,16,4,4,4,101.6558 -dsd,False,True,64,64,32,4,4,4,70.1075 -dsd,False,True,64,64,64,4,4,4,70.1075 -dsd,True,False,16,16,16,4,4,4,128.4223 -dsd,True,False,16,16,32,4,4,4,90.0412 -dsd,True,False,16,16,64,4,4,4,83.3222 -dsd,True,False,16,32,16,4,4,4,142.1878 -dsd,True,False,16,32,32,4,4,4,86.9283 -dsd,True,False,16,32,64,4,4,4,84.5636 -dsd,True,False,16,64,16,4,4,4,163.221 -dsd,True,False,16,64,32,4,4,4,99.8338 -dsd,True,False,16,64,64,4,4,4,99.8338 -dsd,True,False,32,16,16,4,4,4,93.9119 -dsd,True,False,32,16,32,4,4,4,62.1346 -dsd,True,False,32,16,64,4,4,4,56.0104 -dsd,True,False,32,32,16,4,4,4,111.0007 -dsd,True,False,32,32,32,4,4,4,70.8843 -dsd,True,False,32,32,64,4,4,4,59.1251 -dsd,True,False,32,64,16,4,4,4,119.8824 -dsd,True,False,32,64,32,4,4,4,80.5292 -dsd,True,False,32,64,64,4,4,4,80.5292 -dsd,True,False,64,16,16,4,4,4,123.8651 -dsd,True,False,64,16,32,4,4,4,76.5209 -dsd,True,False,64,16,64,4,4,4,60.1391 -dsd,True,False,64,32,16,4,4,4,141.5866 -dsd,True,False,64,32,32,4,4,4,83.946 -dsd,True,False,64,32,64,4,4,4,60.4473 -dsd,True,False,64,64,16,4,4,4,148.974 -dsd,True,False,64,64,32,4,4,4,87.7046 -dsd,True,False,64,64,64,4,4,4,87.7046 -dsd,True,True,16,16,16,4,4,4,159.7035 -dsd,True,True,16,16,32,4,4,4,112.7343 -dsd,True,True,16,16,64,4,4,4,96.109 -dsd,True,True,16,32,16,4,4,4,216.4289 -dsd,True,True,16,32,32,4,4,4,113.0748 -dsd,True,True,16,32,64,4,4,4,102.4129 -dsd,True,True,16,64,16,4,4,4,311.4233 -dsd,True,True,16,64,32,4,4,4,152.9842 -dsd,True,True,16,64,64,4,4,4,152.9842 -dsd,True,True,32,16,16,4,4,4,101.2226 -dsd,True,True,32,16,32,4,4,4,73.1764 -dsd,True,True,32,16,64,4,4,4,62.08 -dsd,True,True,32,32,16,4,4,4,128.2265 -dsd,True,True,32,32,32,4,4,4,85.9566 -dsd,True,True,32,32,64,4,4,4,71.4414 -dsd,True,True,32,64,16,4,4,4,156.3029 -dsd,True,True,32,64,32,4,4,4,112.7689 -dsd,True,True,32,64,64,4,4,4,112.7689 -dsd,True,True,64,16,16,4,4,4,127.2963 -dsd,True,True,64,16,32,4,4,4,79.6703 -dsd,True,True,64,16,64,4,4,4,61.7858 -dsd,True,True,64,32,16,4,4,4,149.5118 -dsd,True,True,64,32,32,4,4,4,91.4158 -dsd,True,True,64,32,64,4,4,4,67.6785 -dsd,True,True,64,64,16,4,4,4,168.1443 -dsd,True,True,64,64,32,4,4,4,105.573 -dsd,True,True,64,64,64,4,4,4,105.573 -sdd,False,False,16,16,16,4,4,4,110.7007 -sdd,False,False,16,16,32,4,4,4,83.8641 -sdd,False,False,16,16,64,4,4,4,79.1289 -sdd,False,False,16,32,16,4,4,4,119.6233 -sdd,False,False,16,32,32,4,4,4,82.5102 -sdd,False,False,16,32,64,4,4,4,78.6386 -sdd,False,False,16,64,16,4,4,4,145.0991 -sdd,False,False,16,64,32,4,4,4,89.7248 -sdd,False,False,16,64,64,4,4,4,83.423 -sdd,False,False,32,16,16,4,4,4,72.6872 -sdd,False,False,32,16,32,4,4,4,54.0949 -sdd,False,False,32,16,64,4,4,4,51.6267 -sdd,False,False,32,32,16,4,4,4,82.8882 -sdd,False,False,32,32,32,4,4,4,59.9318 -sdd,False,False,32,32,64,4,4,4,54.2268 -sdd,False,False,32,64,16,4,4,4,92.1725 
-sdd,False,False,32,64,32,4,4,4,67.4228 -sdd,False,False,32,64,64,4,4,4,55.8651 -sdd,False,False,64,16,16,4,4,4,57.2198 -sdd,False,False,64,16,32,4,4,4,49.1785 -sdd,False,False,64,16,64,4,4,4,46.1422 -sdd,False,False,64,32,16,4,4,4,74.8587 -sdd,False,False,64,32,32,4,4,4,53.8419 -sdd,False,False,64,32,64,4,4,4,46.7183 -sdd,False,False,64,64,16,4,4,4,74.8587 -sdd,False,False,64,64,32,4,4,4,53.8419 -sdd,False,False,64,64,64,4,4,4,46.7183 -sdd,False,True,16,16,16,4,4,4,131.2918 -sdd,False,True,16,16,32,4,4,4,104.9111 -sdd,False,True,16,16,64,4,4,4,100.4713 -sdd,False,True,16,32,16,4,4,4,178.6799 -sdd,False,True,16,32,32,4,4,4,96.054 -sdd,False,True,16,32,64,4,4,4,88.2646 -sdd,False,True,16,64,16,4,4,4,271.7691 -sdd,False,True,16,64,32,4,4,4,138.3314 -sdd,False,True,16,64,64,4,4,4,132.6129 -sdd,False,True,32,16,16,4,4,4,77.3388 -sdd,False,True,32,16,32,4,4,4,59.9648 -sdd,False,True,32,16,64,4,4,4,60.089 -sdd,False,True,32,32,16,4,4,4,94.166 -sdd,False,True,32,32,32,4,4,4,72.2477 -sdd,False,True,32,32,64,4,4,4,64.8427 -sdd,False,True,32,64,16,4,4,4,120.019 -sdd,False,True,32,64,32,4,4,4,97.3089 -sdd,False,True,32,64,64,4,4,4,83.1809 -sdd,False,True,64,16,16,4,4,4,59.4227 -sdd,False,True,64,16,32,4,4,4,52.1191 -sdd,False,True,64,16,64,4,4,4,49.1211 -sdd,False,True,64,32,16,4,4,4,81.5417 -sdd,False,True,64,32,32,4,4,4,59.0152 -sdd,False,True,64,32,64,4,4,4,52.0989 -sdd,False,True,64,64,16,4,4,4,81.5417 -sdd,False,True,64,64,32,4,4,4,59.0152 -sdd,False,True,64,64,64,4,4,4,52.0989 -sdd,True,False,16,16,16,4,4,4,130.6109 -sdd,True,False,16,16,32,4,4,4,84.93 -sdd,True,False,16,16,64,4,4,4,79.0723 -sdd,True,False,16,32,16,4,4,4,142.6209 -sdd,True,False,16,32,32,4,4,4,82.0366 -sdd,True,False,16,32,64,4,4,4,81.5747 -sdd,True,False,16,64,16,4,4,4,165.6981 -sdd,True,False,16,64,32,4,4,4,92.0803 -sdd,True,False,16,64,64,4,4,4,84.1962 -sdd,True,False,32,16,16,4,4,4,96.9497 -sdd,True,False,32,16,32,4,4,4,62.008 -sdd,True,False,32,16,64,4,4,4,55.4291 -sdd,True,False,32,32,16,4,4,4,113.3967 -sdd,True,False,32,32,32,4,4,4,70.2921 -sdd,True,False,32,32,64,4,4,4,58.7307 -sdd,True,False,32,64,16,4,4,4,120.7301 -sdd,True,False,32,64,32,4,4,4,78.2809 -sdd,True,False,32,64,64,4,4,4,60.8745 -sdd,True,False,64,16,16,4,4,4,125.0771 -sdd,True,False,64,16,32,4,4,4,76.4491 -sdd,True,False,64,16,64,4,4,4,60.097 -sdd,True,False,64,32,16,4,4,4,141.8331 -sdd,True,False,64,32,32,4,4,4,83.7384 -sdd,True,False,64,32,64,4,4,4,60.9805 -sdd,True,False,64,64,16,4,4,4,141.8331 -sdd,True,False,64,64,32,4,4,4,83.7384 -sdd,True,False,64,64,64,4,4,4,60.9805 -sdd,True,True,16,16,16,4,4,4,159.6109 -sdd,True,True,16,16,32,4,4,4,105.5516 -sdd,True,True,16,16,64,4,4,4,100.7452 -sdd,True,True,16,32,16,4,4,4,211.4903 -sdd,True,True,16,32,32,4,4,4,101.2889 -sdd,True,True,16,32,64,4,4,4,90.9672 -sdd,True,True,16,64,16,4,4,4,303.9266 -sdd,True,True,16,64,32,4,4,4,146.3723 -sdd,True,True,16,64,64,4,4,4,133.093 -sdd,True,True,32,16,16,4,4,4,103.5772 -sdd,True,True,32,16,32,4,4,4,71.01 -sdd,True,True,32,16,64,4,4,4,63.2092 -sdd,True,True,32,32,16,4,4,4,128.0946 -sdd,True,True,32,32,32,4,4,4,86.1027 -sdd,True,True,32,32,64,4,4,4,68.897 -sdd,True,True,32,64,16,4,4,4,153.6706 -sdd,True,True,32,64,32,4,4,4,111.8289 -sdd,True,True,32,64,64,4,4,4,91.2396 -sdd,True,True,64,16,16,4,4,4,127.9097 -sdd,True,True,64,16,32,4,4,4,80.6409 -sdd,True,True,64,16,64,4,4,4,62.273 -sdd,True,True,64,32,16,4,4,4,148.6517 -sdd,True,True,64,32,32,4,4,4,90.9354 -sdd,True,True,64,32,64,4,4,4,67.6934 -sdd,True,True,64,64,16,4,4,4,148.6517 -sdd,True,True,64,64,32,4,4,4,90.9354 
-sdd,True,True,64,64,64,4,4,4,67.6934 +dds,False,False,8,8,8,2,2,2,83.95069580078125 +dds,False,False,8,8,16,4,8,2,43.84989013671875 +dds,False,False,8,8,32,8,8,2,27.1609619140625 +dds,False,False,8,8,64,8,8,4,28.324365234375 +dds,False,False,8,8,128,2,2,2,inf +dds,False,False,8,16,8,4,16,2,84.48544921875 +dds,False,False,8,16,16,4,4,2,42.477719116210935 +dds,False,False,8,16,32,8,16,2,32.50401000976562 +dds,False,False,8,16,64,8,4,4,30.6483642578125 +dds,False,False,8,16,128,8,8,4,31.26983337402344 +dds,False,False,8,32,8,2,4,4,90.04239501953126 +dds,False,False,8,32,16,4,4,2,45.71172180175781 +dds,False,False,8,32,32,2,16,2,36.58458557128906 +dds,False,False,8,32,64,8,2,2,36.38044128417969 +dds,False,False,8,32,128,8,4,2,36.610171508789065 +dds,False,False,8,64,8,2,16,2,90.79548950195311 +dds,False,False,8,64,16,2,8,2,45.22069091796875 +dds,False,False,8,64,32,2,4,2,39.20011291503906 +dds,False,False,8,64,64,4,2,4,31.470745849609376 +dds,False,False,8,64,128,4,16,2,23.4410400390625 +dds,False,False,8,128,8,2,2,2,inf +dds,False,False,8,128,16,2,4,2,54.0128662109375 +dds,False,False,8,128,32,2,8,2,41.18137817382812 +dds,False,False,8,128,64,2,4,2,40.19525146484375 +dds,False,False,8,128,128,2,2,2,inf +dds,False,False,16,8,8,4,4,2,42.92951354980469 +dds,False,False,16,8,16,4,2,2,25.0050048828125 +dds,False,False,16,8,32,8,8,2,18.46839294433594 +dds,False,False,16,8,64,8,4,4,16.389712524414062 +dds,False,False,16,8,128,16,2,4,15.46875457763672 +dds,False,False,16,16,8,2,16,2,44.96512756347656 +dds,False,False,16,16,16,4,8,2,23.26790008544922 +dds,False,False,16,16,32,8,16,2,18.670246887207032 +dds,False,False,16,16,64,4,2,8,17.152764892578126 +dds,False,False,16,16,128,16,16,2,17.652220153808592 +dds,False,False,16,32,8,2,8,2,49.882073974609376 +dds,False,False,16,32,16,4,4,2,25.979672241210935 +dds,False,False,16,32,32,8,16,2,18.137794494628903 +dds,False,False,16,32,64,4,2,4,19.2738525390625 +dds,False,False,16,32,128,8,4,2,17.795779418945312 +dds,False,False,16,64,8,2,2,2,45.37416687011719 +dds,False,False,16,64,16,4,8,2,25.164306640625 +dds,False,False,16,64,32,4,16,2,20.45664367675781 +dds,False,False,16,64,64,4,16,2,19.650909423828125 +dds,False,False,16,64,128,2,16,4,17.490185546875 +dds,False,False,16,128,8,2,16,2,48.46643981933594 +dds,False,False,16,128,16,2,4,2,31.333306884765623 +dds,False,False,16,128,32,2,8,2,22.09304351806641 +dds,False,False,16,128,64,4,2,2,22.02271728515625 +dds,False,False,16,128,128,2,2,2,inf +dds,False,False,32,8,8,8,2,2,32.411660766601564 +dds,False,False,32,8,16,8,2,2,18.840211486816408 +dds,False,False,32,8,32,8,2,4,13.949766540527344 +dds,False,False,32,8,64,16,8,4,12.54494400024414 +dds,False,False,32,8,128,2,2,2,inf +dds,False,False,32,16,8,4,2,2,31.90923767089844 +dds,False,False,32,16,16,8,8,2,18.17984313964844 +dds,False,False,32,16,32,8,2,4,13.35161895751953 +dds,False,False,32,16,64,8,4,4,12.96336669921875 +dds,False,False,32,16,128,2,2,2,inf +dds,False,False,32,32,8,4,4,2,34.13486022949219 +dds,False,False,32,32,16,8,4,2,18.598626708984376 +dds,False,False,32,32,32,8,16,2,14.275120544433594 +dds,False,False,32,32,64,8,4,2,13.326914978027345 +dds,False,False,32,32,128,2,2,2,inf +dds,False,False,32,64,8,4,2,2,34.046783447265625 +dds,False,False,32,64,16,4,8,2,20.470233154296874 +dds,False,False,32,64,32,4,8,2,15.157046508789062 +dds,False,False,32,64,64,8,8,2,13.67220458984375 +dds,False,False,32,64,128,2,2,2,inf +dds,False,False,32,128,8,2,16,2,39.59329833984375 +dds,False,False,32,128,16,4,4,2,24.433506774902344 
+dds,False,False,32,128,32,4,8,2,16.391807556152344 +dds,False,False,32,128,64,4,16,2,14.969456481933594 +dds,False,False,32,128,128,2,2,2,inf +dds,False,False,64,8,8,16,4,2,26.64129638671875 +dds,False,False,64,8,16,8,4,4,15.326361083984375 +dds,False,False,64,8,32,16,8,4,12.275103759765624 +dds,False,False,64,8,64,2,2,2,inf +dds,False,False,64,8,128,2,2,2,inf +dds,False,False,64,16,8,8,2,2,24.83520965576172 +dds,False,False,64,16,16,16,4,2,15.0552734375 +dds,False,False,64,16,32,8,16,4,12.50583038330078 +dds,False,False,64,16,64,2,2,2,inf +dds,False,False,64,16,128,2,2,2,inf +dds,False,False,64,32,8,8,8,2,29.507452392578124 +dds,False,False,64,32,16,8,8,2,16.862696838378906 +dds,False,False,64,32,32,8,4,2,12.744691467285156 +dds,False,False,64,32,64,2,2,2,inf +dds,False,False,64,32,128,2,2,2,inf +dds,False,False,64,64,8,4,2,2,32.07197265625 +dds,False,False,64,64,16,8,16,2,18.548597717285155 +dds,False,False,64,64,32,8,2,2,13.186189270019533 +dds,False,False,64,64,64,2,2,2,inf +dds,False,False,64,64,128,2,2,2,inf +dds,False,False,64,128,8,2,8,2,39.9909912109375 +dds,False,False,64,128,16,4,8,2,20.25826568603516 +dds,False,False,64,128,32,8,2,2,14.141506958007811 +dds,False,False,64,128,64,2,2,2,inf +dds,False,False,64,128,128,2,2,2,inf +dds,False,False,128,8,8,16,2,4,26.484652709960937 +dds,False,False,128,8,16,16,2,4,14.087107849121091 +dds,False,False,128,8,32,2,2,2,inf +dds,False,False,128,8,64,2,2,2,inf +dds,False,False,128,8,128,2,2,2,inf +dds,False,False,128,16,8,16,8,2,21.534544372558592 +dds,False,False,128,16,16,16,16,2,14.640284729003906 +dds,False,False,128,16,32,2,2,2,inf +dds,False,False,128,16,64,2,2,2,inf +dds,False,False,128,16,128,2,2,2,inf +dds,False,False,128,32,8,8,8,2,28.789382934570312 +dds,False,False,128,32,16,8,16,2,15.902000427246094 +dds,False,False,128,32,32,2,2,2,inf +dds,False,False,128,32,64,2,2,2,inf +dds,False,False,128,32,128,2,2,2,inf +dds,False,False,128,64,8,4,2,2,33.127508544921874 +dds,False,False,128,64,16,8,2,2,17.735232543945312 +dds,False,False,128,64,32,2,2,2,inf +dds,False,False,128,64,64,2,2,2,inf +dds,False,False,128,64,128,2,2,2,inf +dds,False,False,128,128,8,2,2,2,inf +dds,False,False,128,128,16,2,2,2,inf +dds,False,False,128,128,32,2,2,2,inf +dds,False,False,128,128,64,2,2,2,inf +dds,False,False,128,128,128,2,2,2,inf +dds,False,True,8,8,8,2,8,2,68.33310546875 +dds,False,True,8,8,16,4,2,2,56.5980712890625 +dds,False,True,8,8,32,8,8,4,68.17498168945312 +dds,False,True,8,8,64,4,8,16,84.39244384765625 +dds,False,True,8,8,128,8,8,16,92.27847290039062 +dds,False,True,8,16,8,2,16,2,67.75989379882813 +dds,False,True,8,16,16,2,4,2,49.00977478027344 +dds,False,True,8,16,32,8,16,2,46.584982299804686 +dds,False,True,8,16,64,8,4,4,42.45918579101563 +dds,False,True,8,16,128,2,8,16,41.13911743164063 +dds,False,True,8,32,8,2,16,2,89.56703491210938 +dds,False,True,8,32,16,2,4,2,46.791586303710936 +dds,False,True,8,32,32,4,8,2,39.323019409179686 +dds,False,True,8,32,64,4,2,2,38.71488342285157 +dds,False,True,8,32,128,8,8,2,39.51044921875 +dds,False,True,8,64,8,2,16,2,130.14737548828126 +dds,False,True,8,64,16,2,8,2,66.99788818359374 +dds,False,True,8,64,32,4,16,2,55.87267456054688 +dds,False,True,8,64,64,4,16,2,54.4582763671875 +dds,False,True,8,64,128,4,4,2,54.76102905273437 +dds,False,True,8,128,8,2,2,2,inf +dds,False,True,8,128,16,2,16,2,109.83070068359376 +dds,False,True,8,128,32,2,8,2,99.8999755859375 +dds,False,True,8,128,64,2,4,2,98.21196899414062 +dds,False,True,8,128,128,2,2,2,inf +dds,False,True,16,8,8,4,4,2,43.08326416015625 
+dds,False,True,16,8,16,4,8,2,29.3208740234375 +dds,False,True,16,8,32,8,4,4,43.06368103027344 +dds,False,True,16,8,64,4,4,8,44.58840637207031 +dds,False,True,16,8,128,4,4,16,49.77604370117187 +dds,False,True,16,16,8,2,16,2,37.32265014648438 +dds,False,True,16,16,16,4,8,2,25.79122314453125 +dds,False,True,16,16,32,8,16,2,25.268634033203124 +dds,False,True,16,16,64,16,8,2,21.63783721923828 +dds,False,True,16,16,128,16,16,2,21.9729248046875 +dds,False,True,16,32,8,2,2,2,47.71018371582032 +dds,False,True,16,32,16,4,2,2,29.6340576171875 +dds,False,True,16,32,32,8,16,2,22.302653503417968 +dds,False,True,16,32,64,8,8,2,22.06714172363281 +dds,False,True,16,32,128,8,4,2,21.238131713867187 +dds,False,True,16,64,8,2,8,2,58.50890502929688 +dds,False,True,16,64,16,4,2,2,40.11568603515625 +dds,False,True,16,64,32,8,8,2,33.484283447265625 +dds,False,True,16,64,64,4,2,2,32.49169006347656 +dds,False,True,16,64,128,4,2,4,32.50621032714844 +dds,False,True,16,128,8,2,16,2,86.73631591796875 +dds,False,True,16,128,16,4,8,2,63.19295043945313 +dds,False,True,16,128,32,4,2,2,55.3454345703125 +dds,False,True,16,128,64,4,8,2,53.07149658203125 +dds,False,True,16,128,128,2,2,2,inf +dds,False,True,32,8,8,8,8,2,31.70926513671875 +dds,False,True,32,8,16,8,8,2,20.866188049316406 +dds,False,True,32,8,32,8,2,2,18.719453430175783 +dds,False,True,32,8,64,16,8,2,23.87512969970703 +dds,False,True,32,8,128,2,2,2,inf +dds,False,True,32,16,8,4,2,2,28.93572998046875 +dds,False,True,32,16,16,8,16,2,19.09322509765625 +dds,False,True,32,16,32,16,8,2,15.363580322265625 +dds,False,True,32,16,64,8,2,4,14.772169494628908 +dds,False,True,32,16,128,2,2,2,inf +dds,False,True,32,32,8,4,2,2,36.425103759765626 +dds,False,True,32,32,16,8,4,2,21.15513610839844 +dds,False,True,32,32,32,8,8,2,16.99063720703125 +dds,False,True,32,32,64,8,2,4,16.2457763671875 +dds,False,True,32,32,128,2,2,2,inf +dds,False,True,32,64,8,4,4,2,42.11276245117188 +dds,False,True,32,64,16,8,8,2,26.644598388671877 +dds,False,True,32,64,32,4,16,2,22.55145568847656 +dds,False,True,32,64,64,4,2,4,21.086448669433597 +dds,False,True,32,64,128,2,2,2,inf +dds,False,True,32,128,8,2,8,2,58.4863525390625 +dds,False,True,32,128,16,4,8,2,39.704345703125 +dds,False,True,32,128,32,4,2,2,32.989443969726565 +dds,False,True,32,128,64,4,8,4,31.92890930175781 +dds,False,True,32,128,128,2,2,2,inf +dds,False,True,64,8,8,16,8,2,26.21756591796875 +dds,False,True,64,8,16,8,4,4,16.323394775390625 +dds,False,True,64,8,32,16,2,4,13.234153747558594 +dds,False,True,64,8,64,2,2,2,inf +dds,False,True,64,8,128,2,2,2,inf +dds,False,True,64,16,8,8,4,2,23.53669128417969 +dds,False,True,64,16,16,16,4,2,15.508131408691408 +dds,False,True,64,16,32,8,16,4,13.336930847167968 +dds,False,True,64,16,64,2,2,2,inf +dds,False,True,64,16,128,2,2,2,inf +dds,False,True,64,32,8,8,4,2,30.660494995117187 +dds,False,True,64,32,16,8,16,2,18.16160888671875 +dds,False,True,64,32,32,8,2,2,14.187823486328124 +dds,False,True,64,32,64,2,2,2,inf +dds,False,True,64,32,128,2,2,2,inf +dds,False,True,64,64,8,4,16,2,35.908871459960935 +dds,False,True,64,64,16,8,2,2,22.574208068847657 +dds,False,True,64,64,32,8,8,2,16.651814270019532 +dds,False,True,64,64,64,2,2,2,inf +dds,False,True,64,64,128,2,2,2,inf +dds,False,True,64,128,8,2,2,2,47.98629455566406 +dds,False,True,64,128,16,4,2,2,29.4864501953125 +dds,False,True,64,128,32,4,16,2,22.77893829345703 +dds,False,True,64,128,64,2,2,2,inf +dds,False,True,64,128,128,2,2,2,inf +dds,False,True,128,8,8,16,2,4,26.28378295898437 +dds,False,True,128,8,16,16,2,4,14.428512573242188 
+dds,False,True,128,8,32,2,2,2,inf +dds,False,True,128,8,64,2,2,2,inf +dds,False,True,128,8,128,2,2,2,inf +dds,False,True,128,16,8,16,4,2,21.152851867675786 +dds,False,True,128,16,16,16,8,2,14.555728149414062 +dds,False,True,128,16,32,2,2,2,inf +dds,False,True,128,16,64,2,2,2,inf +dds,False,True,128,16,128,2,2,2,inf +dds,False,True,128,32,8,8,4,2,29.437484741210938 +dds,False,True,128,32,16,8,16,2,16.786746215820312 +dds,False,True,128,32,32,2,2,2,inf +dds,False,True,128,32,64,2,2,2,inf +dds,False,True,128,32,128,2,2,2,inf +dds,False,True,128,64,8,4,2,2,35.015570068359374 +dds,False,True,128,64,16,8,2,2,19.62778167724609 +dds,False,True,128,64,32,2,2,2,inf +dds,False,True,128,64,64,2,2,2,inf +dds,False,True,128,64,128,2,2,2,inf +dds,False,True,128,128,8,2,2,2,inf +dds,False,True,128,128,16,2,2,2,inf +dds,False,True,128,128,32,2,2,2,inf +dds,False,True,128,128,64,2,2,2,inf +dds,False,True,128,128,128,2,2,2,inf +dds,True,False,8,8,8,2,2,2,85.9234619140625 +dds,True,False,8,8,16,4,8,2,46.31357421875 +dds,True,False,8,8,32,8,8,2,30.778073120117188 +dds,True,False,8,8,64,4,8,8,29.98555908203125 +dds,True,False,8,8,128,2,2,2,inf +dds,True,False,8,16,8,4,16,2,90.65723876953123 +dds,True,False,8,16,16,2,4,2,45.50893249511719 +dds,True,False,8,16,32,8,2,2,35.563201904296875 +dds,True,False,8,16,64,8,16,4,32.1559814453125 +dds,True,False,8,16,128,8,2,4,32.832516479492185 +dds,True,False,8,32,8,2,16,2,89.41387939453125 +dds,True,False,8,32,16,2,4,2,50.946450805664064 +dds,True,False,8,32,32,2,2,4,38.40348815917969 +dds,True,False,8,32,64,8,8,2,36.46643981933594 +dds,True,False,8,32,128,8,16,2,36.96016235351563 +dds,True,False,8,64,8,2,8,2,96.24725341796876 +dds,True,False,8,64,16,2,2,2,49.09877014160156 +dds,True,False,8,64,32,2,4,2,39.97953491210937 +dds,True,False,8,64,64,2,4,2,38.450704956054686 +dds,True,False,8,64,128,4,2,2,28.423953247070312 +dds,True,False,8,128,8,2,16,2,97.62970581054688 +dds,True,False,8,128,16,2,4,2,50.08351135253906 +dds,True,False,8,128,32,2,16,2,42.19480895996094 +dds,True,False,8,128,64,2,4,2,40.920501708984375 +dds,True,False,8,128,128,2,2,2,inf +dds,True,False,16,8,8,4,8,2,53.02158203125 +dds,True,False,16,8,16,4,8,2,26.3121826171875 +dds,True,False,16,8,32,8,8,2,18.794140625 +dds,True,False,16,8,64,8,2,4,16.877682495117188 +dds,True,False,16,8,128,16,2,4,15.879420471191406 +dds,True,False,16,16,8,2,4,2,47.39478454589844 +dds,True,False,16,16,16,4,8,2,25.626373291015625 +dds,True,False,16,16,32,4,4,2,20.6164794921875 +dds,True,False,16,16,64,16,16,2,18.20005798339844 +dds,True,False,16,16,128,16,4,2,18.098419189453125 +dds,True,False,16,32,8,2,2,2,55.869049072265625 +dds,True,False,16,32,16,4,4,2,29.0021240234375 +dds,True,False,16,32,32,8,16,2,18.64588775634765 +dds,True,False,16,32,64,8,4,2,19.498381042480467 +dds,True,False,16,32,128,8,8,2,18.230441284179687 +dds,True,False,16,64,8,2,2,2,54.580517578125 +dds,True,False,16,64,16,4,8,2,28.303189086914063 +dds,True,False,16,64,32,4,16,2,21.05837097167969 +dds,True,False,16,64,64,4,16,2,19.862054443359376 +dds,True,False,16,64,128,2,8,4,17.901318359375 +dds,True,False,16,128,8,2,4,2,54.739007568359376 +dds,True,False,16,128,16,2,8,2,32.546041870117186 +dds,True,False,16,128,32,2,16,2,23.56947174072265 +dds,True,False,16,128,64,4,2,2,21.99974365234375 +dds,True,False,16,128,128,2,2,2,inf +dds,True,False,32,8,8,8,8,2,59.6756103515625 +dds,True,False,32,8,16,8,8,2,22.40240936279297 +dds,True,False,32,8,32,8,8,4,15.45 +dds,True,False,32,8,64,8,2,8,13.239605712890626 +dds,True,False,32,8,128,2,2,2,inf 
+dds,True,False,32,16,8,4,16,2,42.68543090820312 +dds,True,False,32,16,16,8,4,2,22.36754608154297 +dds,True,False,32,16,32,8,2,4,15.607037353515626 +dds,True,False,32,16,64,8,2,4,13.975074768066406 +dds,True,False,32,16,128,2,2,2,inf +dds,True,False,32,32,8,4,16,2,51.02578125 +dds,True,False,32,32,16,8,4,2,26.00336608886719 +dds,True,False,32,32,32,8,8,2,16.560348510742188 +dds,True,False,32,32,64,8,4,2,14.603456115722656 +dds,True,False,32,32,128,2,2,2,inf +dds,True,False,32,64,8,4,4,2,50.49428100585938 +dds,True,False,32,64,16,8,2,2,28.138201904296874 +dds,True,False,32,64,32,4,4,2,18.134669494628906 +dds,True,False,32,64,64,8,4,2,14.894908142089845 +dds,True,False,32,64,128,2,2,2,inf +dds,True,False,32,128,8,2,2,2,55.946539306640624 +dds,True,False,32,128,16,4,16,2,30.93107299804688 +dds,True,False,32,128,32,4,4,2,18.512527465820312 +dds,True,False,32,128,64,4,2,2,15.689494323730468 +dds,True,False,32,128,128,2,2,2,inf +dds,True,False,64,8,8,16,4,2,95.04129638671876 +dds,True,False,64,8,16,8,4,4,30.25923156738281 +dds,True,False,64,8,32,16,8,4,19.586994934082032 +dds,True,False,64,8,64,2,2,2,inf +dds,True,False,64,8,128,2,2,2,inf +dds,True,False,64,16,8,8,4,2,58.078216552734375 +dds,True,False,64,16,16,16,4,2,29.82334289550781 +dds,True,False,64,16,32,8,16,4,19.79328308105469 +dds,True,False,64,16,64,2,2,2,inf +dds,True,False,64,16,128,2,2,2,inf +dds,True,False,64,32,8,8,16,2,66.37091674804688 +dds,True,False,64,32,16,8,16,2,34.887594604492186 +dds,True,False,64,32,32,16,16,2,20.612745666503905 +dds,True,False,64,32,64,2,2,2,inf +dds,True,False,64,32,128,2,2,2,inf +dds,True,False,64,64,8,8,2,2,67.04953002929688 +dds,True,False,64,64,16,8,16,2,35.14598083496094 +dds,True,False,64,64,32,8,8,2,21.04156494140625 +dds,True,False,64,64,64,2,2,2,inf +dds,True,False,64,64,128,2,2,2,inf +dds,True,False,64,128,8,4,2,2,72.3403076171875 +dds,True,False,64,128,16,4,4,2,37.190869140625 +dds,True,False,64,128,32,8,2,2,22.02156219482422 +dds,True,False,64,128,64,2,2,2,inf +dds,True,False,64,128,128,2,2,2,inf +dds,True,False,128,8,8,2,2,2,inf +dds,True,False,128,8,16,16,8,4,52.30179443359375 +dds,True,False,128,8,32,2,2,2,inf +dds,True,False,128,8,64,2,2,2,inf +dds,True,False,128,8,128,2,2,2,inf +dds,True,False,128,16,8,16,8,2,95.9490966796875 +dds,True,False,128,16,16,16,8,2,51.84322509765625 +dds,True,False,128,16,32,2,2,2,inf +dds,True,False,128,16,64,2,2,2,inf +dds,True,False,128,16,128,2,2,2,inf +dds,True,False,128,32,8,8,16,2,105.0478271484375 +dds,True,False,128,32,16,16,16,2,53.6555908203125 +dds,True,False,128,32,32,2,2,2,inf +dds,True,False,128,32,64,2,2,2,inf +dds,True,False,128,32,128,2,2,2,inf +dds,True,False,128,64,8,4,4,2,106.87479248046876 +dds,True,False,128,64,16,8,2,2,54.582958984375 +dds,True,False,128,64,32,2,2,2,inf +dds,True,False,128,64,64,2,2,2,inf +dds,True,False,128,64,128,2,2,2,inf +dds,True,False,128,128,8,2,2,2,inf +dds,True,False,128,128,16,2,2,2,inf +dds,True,False,128,128,32,2,2,2,inf +dds,True,False,128,128,64,2,2,2,inf +dds,True,False,128,128,128,2,2,2,inf +dds,True,True,8,8,8,2,4,2,93.86426391601564 +dds,True,True,8,8,16,2,4,4,62.9012451171875 +dds,True,True,8,8,32,4,4,8,75.62587280273438 +dds,True,True,8,8,64,4,4,16,85.3392333984375 +dds,True,True,8,8,128,4,2,16,92.88526000976564 +dds,True,True,8,16,8,2,4,2,79.35208740234376 +dds,True,True,8,16,16,2,4,2,47.40335388183594 +dds,True,True,8,16,32,8,16,2,49.03898620605469 +dds,True,True,8,16,64,8,8,4,43.70534973144531 +dds,True,True,8,16,128,2,16,16,41.79808044433594 +dds,True,True,8,32,8,2,16,2,103.93631591796876 
+dds,True,True,8,32,16,2,16,2,52.19888305664063 +dds,True,True,8,32,32,4,8,2,41.50272827148437 +dds,True,True,8,32,64,4,8,2,39.83175354003906 +dds,True,True,8,32,128,8,8,2,39.99073791503906 +dds,True,True,8,64,8,2,4,2,144.64049072265624 +dds,True,True,8,64,16,2,4,2,71.35440673828126 +dds,True,True,8,64,32,4,2,2,58.09429931640625 +dds,True,True,8,64,64,4,2,2,55.59429931640625 +dds,True,True,8,64,128,4,16,2,54.93161010742188 +dds,True,True,8,128,8,2,2,2,inf +dds,True,True,8,128,16,2,8,2,113.36585693359376 +dds,True,True,8,128,32,2,16,2,102.328076171875 +dds,True,True,8,128,64,2,4,2,99.2333984375 +dds,True,True,8,128,128,2,2,2,inf +dds,True,True,16,8,8,4,2,2,55.75620727539062 +dds,True,True,16,8,16,4,4,2,31.414694213867183 +dds,True,True,16,8,32,8,4,4,43.36945495605469 +dds,True,True,16,8,64,2,8,16,41.72640686035156 +dds,True,True,16,8,128,4,4,16,48.80404052734375 +dds,True,True,16,16,8,2,16,2,44.52323303222656 +dds,True,True,16,16,16,4,16,2,27.96387023925781 +dds,True,True,16,16,32,8,16,2,26.15814208984375 +dds,True,True,16,16,64,16,8,2,21.87930908203125 +dds,True,True,16,16,128,16,4,2,21.86894989013672 +dds,True,True,16,32,8,2,2,2,57.255419921875 +dds,True,True,16,32,16,4,4,2,34.42570190429687 +dds,True,True,16,32,32,8,16,2,24.054620361328126 +dds,True,True,16,32,64,8,16,2,22.852418518066408 +dds,True,True,16,32,128,8,16,2,21.65142974853516 +dds,True,True,16,64,8,2,4,2,67.79102783203125 +dds,True,True,16,64,16,4,2,2,44.90336303710937 +dds,True,True,16,64,32,8,2,2,35.3868896484375 +dds,True,True,16,64,64,4,2,2,33.657223510742185 +dds,True,True,16,64,128,4,2,4,32.798883056640626 +dds,True,True,16,128,8,2,16,2,89.61912841796875 +dds,True,True,16,128,16,4,8,2,68.82657470703126 +dds,True,True,16,128,32,4,2,2,56.98488159179688 +dds,True,True,16,128,64,4,2,2,53.906036376953125 +dds,True,True,16,128,128,2,2,2,inf +dds,True,True,32,8,8,8,4,2,61.55089111328125 +dds,True,True,32,8,16,8,2,2,24.22500457763672 +dds,True,True,32,8,32,16,8,2,19.48570861816406 +dds,True,True,32,8,64,16,2,2,19.49196472167969 +dds,True,True,32,8,128,2,2,2,inf +dds,True,True,32,16,8,4,8,2,42.84806823730469 +dds,True,True,32,16,16,8,2,2,23.67737274169922 +dds,True,True,32,16,32,8,16,4,17.302281188964844 +dds,True,True,32,16,64,8,4,4,15.770700073242187 +dds,True,True,32,16,128,2,2,2,inf +dds,True,True,32,32,8,4,16,2,53.422998046875 +dds,True,True,32,32,16,8,16,2,29.681704711914065 +dds,True,True,32,32,32,8,8,2,20.000189208984374 +dds,True,True,32,32,64,8,2,4,17.455282592773436 +dds,True,True,32,32,128,2,2,2,inf +dds,True,True,32,64,8,4,8,2,58.26892700195312 +dds,True,True,32,64,16,8,2,2,34.81661071777344 +dds,True,True,32,64,32,8,8,2,26.6898193359375 +dds,True,True,32,64,64,8,2,2,23.006997680664064 +dds,True,True,32,64,128,2,2,2,inf +dds,True,True,32,128,8,4,2,2,71.09729614257813 +dds,True,True,32,128,16,4,2,2,48.069085693359376 +dds,True,True,32,128,32,4,16,2,36.78126525878906 +dds,True,True,32,128,64,8,2,2,33.828955078125 +dds,True,True,32,128,128,2,2,2,inf +dds,True,True,64,8,8,16,4,2,97.09566650390624 +dds,True,True,64,8,16,8,8,4,31.19393920898437 +dds,True,True,64,8,32,8,8,8,20.78213806152344 +dds,True,True,64,8,64,2,2,2,inf +dds,True,True,64,8,128,2,2,2,inf +dds,True,True,64,16,8,8,8,2,58.19834594726562 +dds,True,True,64,16,16,16,8,2,31.186398315429688 +dds,True,True,64,16,32,8,4,4,20.88908233642578 +dds,True,True,64,16,64,2,2,2,inf +dds,True,True,64,16,128,2,2,2,inf +dds,True,True,64,32,8,8,8,2,67.47835083007813 +dds,True,True,64,32,16,16,8,2,36.777783203125 +dds,True,True,64,32,32,16,8,2,22.203033447265625 
+dds,True,True,64,32,64,2,2,2,inf +dds,True,True,64,32,128,2,2,2,inf +dds,True,True,64,64,8,8,16,2,71.129345703125 +dds,True,True,64,64,16,8,4,2,39.23030395507813 +dds,True,True,64,64,32,8,4,2,25.64223022460937 +dds,True,True,64,64,64,2,2,2,inf +dds,True,True,64,64,128,2,2,2,inf +dds,True,True,64,128,8,4,8,2,81.20865478515626 +dds,True,True,64,128,16,4,2,2,46.47157287597656 +dds,True,True,64,128,32,8,2,2,31.310092163085937 +dds,True,True,64,128,64,2,2,2,inf +dds,True,True,64,128,128,2,2,2,inf +dds,True,True,128,8,8,2,2,2,inf +dds,True,True,128,8,16,16,2,4,52.752923583984376 +dds,True,True,128,8,32,2,2,2,inf +dds,True,True,128,8,64,2,2,2,inf +dds,True,True,128,8,128,2,2,2,inf +dds,True,True,128,16,8,16,4,2,96.22736206054688 +dds,True,True,128,16,16,16,4,2,51.95736083984375 +dds,True,True,128,16,32,2,2,2,inf +dds,True,True,128,16,64,2,2,2,inf +dds,True,True,128,16,128,2,2,2,inf +dds,True,True,128,32,8,8,4,2,105.681982421875 +dds,True,True,128,32,16,16,4,2,54.82050170898437 +dds,True,True,128,32,32,2,2,2,inf +dds,True,True,128,32,64,2,2,2,inf +dds,True,True,128,32,128,2,2,2,inf +dds,True,True,128,64,8,4,16,2,108.55108642578124 +dds,True,True,128,64,16,8,4,2,56.63348388671875 +dds,True,True,128,64,32,2,2,2,inf +dds,True,True,128,64,64,2,2,2,inf +dds,True,True,128,64,128,2,2,2,inf +dds,True,True,128,128,8,2,2,2,inf +dds,True,True,128,128,16,2,2,2,inf +dds,True,True,128,128,32,2,2,2,inf +dds,True,True,128,128,64,2,2,2,inf +dds,True,True,128,128,128,2,2,2,inf +dsd,False,False,8,8,8,2,8,2,58.68651123046875 +dsd,False,False,8,8,16,4,4,2,42.6728759765625 +dsd,False,False,8,8,32,4,4,4,38.15791931152344 +dsd,False,False,8,8,64,8,8,4,34.25900573730469 +dsd,False,False,8,8,128,2,2,2,inf +dsd,False,False,8,16,8,2,4,2,50.28748474121094 +dsd,False,False,8,16,16,4,16,2,39.87994384765625 +dsd,False,False,8,16,32,4,8,4,32.09373474121094 +dsd,False,False,8,16,64,4,2,4,33.66380615234375 +dsd,False,False,8,16,128,8,2,4,32.56879577636719 +dsd,False,False,8,32,8,2,8,2,54.3849609375 +dsd,False,False,8,32,16,2,2,2,35.072271728515624 +dsd,False,False,8,32,32,8,8,2,37.27109375 +dsd,False,False,8,32,64,8,16,2,38.08072509765625 +dsd,False,False,8,32,128,2,2,2,inf +dsd,False,False,8,64,8,2,4,2,53.181591796875 +dsd,False,False,8,64,16,2,8,2,41.441680908203125 +dsd,False,False,8,64,32,2,2,2,40.47706909179688 +dsd,False,False,8,64,64,2,2,2,inf +dsd,False,False,8,64,128,2,2,2,inf +dsd,False,False,8,128,8,2,2,2,inf +dsd,False,False,8,128,16,2,2,2,39.33102111816406 +dsd,False,False,8,128,32,2,2,2,inf +dsd,False,False,8,128,64,2,2,2,inf +dsd,False,False,8,128,128,2,2,2,inf +dsd,False,False,16,8,8,4,4,2,38.00410461425781 +dsd,False,False,16,8,16,4,2,2,23.97266540527344 +dsd,False,False,16,8,32,4,4,4,19.30738220214844 +dsd,False,False,16,8,64,4,8,8,17.43036804199219 +dsd,False,False,16,8,128,16,2,4,16.734518432617186 +dsd,False,False,16,16,8,2,2,2,33.422088623046875 +dsd,False,False,16,16,16,4,4,2,21.60905303955078 +dsd,False,False,16,16,32,4,16,2,19.633290100097657 +dsd,False,False,16,16,64,4,16,8,18.44384002685547 +dsd,False,False,16,16,128,2,4,16,19.249530029296874 +dsd,False,False,16,32,8,2,2,2,39.05540161132812 +dsd,False,False,16,32,16,4,8,2,22.90558776855469 +dsd,False,False,16,32,32,8,8,2,18.76983337402344 +dsd,False,False,16,32,64,4,4,2,19.42920684814453 +dsd,False,False,16,32,128,2,2,2,inf +dsd,False,False,16,64,8,2,8,2,38.244808959960935 +dsd,False,False,16,64,16,4,4,2,24.67497863769531 +dsd,False,False,16,64,32,2,2,2,21.384271240234376 +dsd,False,False,16,64,64,2,2,2,inf +dsd,False,False,16,64,128,2,2,2,inf 
+dsd,False,False,16,128,8,2,16,2,40.819033813476565 +dsd,False,False,16,128,16,2,2,2,27.883056640625 +dsd,False,False,16,128,32,2,2,2,inf +dsd,False,False,16,128,64,2,2,2,inf +dsd,False,False,16,128,128,2,2,2,inf +dsd,False,False,32,8,8,8,4,2,29.94057006835937 +dsd,False,False,32,8,16,8,4,2,18.484355163574214 +dsd,False,False,32,8,32,8,2,4,14.205218505859374 +dsd,False,False,32,8,64,8,2,8,12.878863525390624 +dsd,False,False,32,8,128,8,8,8,12.873526000976565 +dsd,False,False,32,16,8,4,16,2,26.0935546875 +dsd,False,False,32,16,16,8,4,2,17.127200317382812 +dsd,False,False,32,16,32,16,4,2,13.659455871582033 +dsd,False,False,32,16,64,8,16,4,12.99901123046875 +dsd,False,False,32,16,128,8,4,4,12.17227554321289 +dsd,False,False,32,32,8,4,16,2,32.057009887695315 +dsd,False,False,32,32,16,8,8,2,17.632479858398437 +dsd,False,False,32,32,32,8,16,2,13.9248291015625 +dsd,False,False,32,32,64,8,8,2,13.21557159423828 +dsd,False,False,32,32,128,2,2,2,inf +dsd,False,False,32,64,8,4,16,2,32.44314270019531 +dsd,False,False,32,64,16,4,2,2,19.502828979492183 +dsd,False,False,32,64,32,4,2,2,15.068704223632812 +dsd,False,False,32,64,64,2,2,2,inf +dsd,False,False,32,64,128,2,2,2,inf +dsd,False,False,32,128,8,2,16,2,37.05467529296875 +dsd,False,False,32,128,16,4,4,2,23.824124145507813 +dsd,False,False,32,128,32,2,2,2,inf +dsd,False,False,32,128,64,2,2,2,inf +dsd,False,False,32,128,128,2,2,2,inf +dsd,False,False,64,8,8,16,2,2,25.585133361816407 +dsd,False,False,64,8,16,8,2,4,15.365625 +dsd,False,False,64,8,32,16,8,4,12.445209503173828 +dsd,False,False,64,8,64,8,8,8,12.021456146240237 +dsd,False,False,64,8,128,16,2,4,11.294950103759763 +dsd,False,False,64,16,8,8,16,2,22.789225769042968 +dsd,False,False,64,16,16,16,8,2,14.578399658203123 +dsd,False,False,64,16,32,8,4,4,12.400835418701172 +dsd,False,False,64,16,64,8,2,4,11.663340759277345 +dsd,False,False,64,16,128,16,8,4,11.174070739746094 +dsd,False,False,64,32,8,8,4,2,28.659817504882813 +dsd,False,False,64,32,16,8,16,2,16.314057922363283 +dsd,False,False,64,32,32,8,2,2,12.693769836425782 +dsd,False,False,64,32,64,8,8,4,11.483523559570312 +dsd,False,False,64,32,128,2,2,2,inf +dsd,False,False,64,64,8,4,8,2,31.395059204101564 +dsd,False,False,64,64,16,8,16,2,18.30839385986328 +dsd,False,False,64,64,32,8,8,2,12.973321533203125 +dsd,False,False,64,64,64,2,2,2,inf +dsd,False,False,64,64,128,2,2,2,inf +dsd,False,False,64,128,8,2,2,2,43.200830078125 +dsd,False,False,64,128,16,4,2,2,20.80726013183594 +dsd,False,False,64,128,32,2,2,2,inf +dsd,False,False,64,128,64,2,2,2,inf +dsd,False,False,64,128,128,2,2,2,inf +dsd,False,False,128,8,8,16,2,4,25.693502807617183 +dsd,False,False,128,8,16,16,2,4,14.128346252441409 +dsd,False,False,128,8,32,8,2,8,12.39136962890625 +dsd,False,False,128,8,64,8,8,8,11.45447998046875 +dsd,False,False,128,8,128,8,8,8,12.337554931640623 +dsd,False,False,128,16,8,16,2,2,20.80894012451172 +dsd,False,False,128,16,16,16,16,2,14.157225036621094 +dsd,False,False,128,16,32,8,4,4,11.805955505371092 +dsd,False,False,128,16,64,16,8,4,11.266896057128909 +dsd,False,False,128,16,128,4,16,16,12.197372436523438 +dsd,False,False,128,32,8,8,16,2,28.485455322265626 +dsd,False,False,128,32,16,8,8,2,15.9461669921875 +dsd,False,False,128,32,32,16,16,2,11.877593231201171 +dsd,False,False,128,32,64,16,2,2,11.360527801513673 +dsd,False,False,128,32,128,2,2,2,inf +dsd,False,False,128,64,8,4,2,2,33.3581298828125 +dsd,False,False,128,64,16,8,16,2,17.284320068359374 +dsd,False,False,128,64,32,8,8,2,12.954953002929688 +dsd,False,False,128,64,64,2,2,2,inf 
+dsd,False,False,128,64,128,2,2,2,inf +dsd,False,False,128,128,8,2,2,2,inf +dsd,False,False,128,128,16,2,2,2,inf +dsd,False,False,128,128,32,2,2,2,inf +dsd,False,False,128,128,64,2,2,2,inf +dsd,False,False,128,128,128,2,2,2,inf +dsd,False,True,8,8,8,2,2,2,64.83031005859375 +dsd,False,True,8,8,16,4,8,2,45.480560302734375 +dsd,False,True,8,8,32,8,8,2,43.537939453125 +dsd,False,True,8,8,64,4,4,8,41.256735229492186 +dsd,False,True,8,8,128,8,8,8,34.57571716308594 +dsd,False,True,8,16,8,2,16,2,66.96640625 +dsd,False,True,8,16,16,2,4,2,42.99291076660156 +dsd,False,True,8,16,32,2,4,4,41.06326904296875 +dsd,False,True,8,16,64,4,8,4,37.438262939453125 +dsd,False,True,8,16,128,8,4,4,35.173666381835936 +dsd,False,True,8,32,8,2,4,2,89.96324462890625 +dsd,False,True,8,32,16,2,8,2,47.99249877929688 +dsd,False,True,8,32,32,4,8,2,44.1489990234375 +dsd,False,True,8,32,64,4,2,2,41.00673217773438 +dsd,False,True,8,32,128,2,2,2,inf +dsd,False,True,8,64,8,2,16,2,131.4682373046875 +dsd,False,True,8,64,16,2,16,2,67.10629272460938 +dsd,False,True,8,64,32,4,4,2,58.920703125 +dsd,False,True,8,64,64,2,2,2,inf +dsd,False,True,8,64,128,2,2,2,inf +dsd,False,True,8,128,8,2,2,2,inf +dsd,False,True,8,128,16,2,2,2,108.13199462890626 +dsd,False,True,8,128,32,2,2,2,inf +dsd,False,True,8,128,64,2,2,2,inf +dsd,False,True,8,128,128,2,2,2,inf +dsd,False,True,16,8,8,4,8,2,42.252691650390624 +dsd,False,True,16,8,16,4,8,2,26.201754760742187 +dsd,False,True,16,8,32,8,2,4,26.348110961914063 +dsd,False,True,16,8,64,16,8,4,23.99023284912109 +dsd,False,True,16,8,128,4,4,16,19.612937927246094 +dsd,False,True,16,16,8,2,16,2,36.32303161621094 +dsd,False,True,16,16,16,4,8,2,24.885910034179688 +dsd,False,True,16,16,32,8,2,2,24.24401245117188 +dsd,False,True,16,16,64,8,8,4,21.024485778808597 +dsd,False,True,16,16,128,16,8,2,20.268089294433597 +dsd,False,True,16,32,8,2,16,2,47.83821105957031 +dsd,False,True,16,32,16,4,8,2,29.85913391113281 +dsd,False,True,16,32,32,8,16,2,24.07604217529297 +dsd,False,True,16,32,64,4,16,2,23.95204162597656 +dsd,False,True,16,32,128,2,2,2,inf +dsd,False,True,16,64,8,2,16,2,58.60889282226562 +dsd,False,True,16,64,16,4,4,2,39.94730529785157 +dsd,False,True,16,64,32,4,16,2,33.722048950195315 +dsd,False,True,16,64,64,2,2,2,inf +dsd,False,True,16,64,128,2,2,2,inf +dsd,False,True,16,128,8,2,4,2,80.9656982421875 +dsd,False,True,16,128,16,4,2,2,63.73284912109375 +dsd,False,True,16,128,32,2,2,2,inf +dsd,False,True,16,128,64,2,2,2,inf +dsd,False,True,16,128,128,2,2,2,inf +dsd,False,True,32,8,8,8,2,2,32.44737243652344 +dsd,False,True,32,8,16,8,8,2,20.403779602050783 +dsd,False,True,32,8,32,8,2,2,17.346044921875 +dsd,False,True,32,8,64,8,8,4,15.755056762695313 +dsd,False,True,32,8,128,8,2,8,14.2900634765625 +dsd,False,True,32,16,8,4,16,2,27.93015747070313 +dsd,False,True,32,16,16,4,2,2,19.360806274414063 +dsd,False,True,32,16,32,8,16,4,16.002423095703126 +dsd,False,True,32,16,64,8,16,4,14.859500122070312 +dsd,False,True,32,16,128,8,4,4,13.441619873046877 +dsd,False,True,32,32,8,4,8,2,36.38385009765625 +dsd,False,True,32,32,16,8,16,2,21.3083740234375 +dsd,False,True,32,32,32,8,16,2,17.330905151367187 +dsd,False,True,32,32,64,4,4,4,16.176535034179686 +dsd,False,True,32,32,128,2,2,2,inf +dsd,False,True,32,64,8,4,2,2,42.07832946777344 +dsd,False,True,32,64,16,8,16,2,27.796780395507813 +dsd,False,True,32,64,32,4,4,2,22.622706604003906 +dsd,False,True,32,64,64,2,2,2,inf +dsd,False,True,32,64,128,2,2,2,inf +dsd,False,True,32,128,8,2,16,2,59.004339599609374 +dsd,False,True,32,128,16,4,8,2,40.63676452636719 
+dsd,False,True,32,128,32,2,2,2,inf +dsd,False,True,32,128,64,2,2,2,inf +dsd,False,True,32,128,128,2,2,2,inf +dsd,False,True,64,8,8,16,2,2,28.06235656738281 +dsd,False,True,64,8,16,8,4,4,17.711637878417967 +dsd,False,True,64,8,32,8,8,4,14.624826049804687 +dsd,False,True,64,8,64,8,2,4,13.30050811767578 +dsd,False,True,64,8,128,8,2,8,11.8025634765625 +dsd,False,True,64,16,8,8,8,2,23.832765197753908 +dsd,False,True,64,16,16,8,4,2,15.731478881835937 +dsd,False,True,64,16,32,8,8,4,13.64959716796875 +dsd,False,True,64,16,64,8,16,4,12.479481506347655 +dsd,False,True,64,16,128,16,4,4,12.026067352294922 +dsd,False,True,64,32,8,8,8,2,30.883258056640624 +dsd,False,True,64,32,16,8,8,2,18.295631408691406 +dsd,False,True,64,32,32,8,8,2,14.4781982421875 +dsd,False,True,64,32,64,8,16,4,12.99908447265625 +dsd,False,True,64,32,128,2,2,2,inf +dsd,False,True,64,64,8,4,16,2,36.06336059570312 +dsd,False,True,64,64,16,8,8,2,22.488557434082036 +dsd,False,True,64,64,32,8,4,2,16.641293334960938 +dsd,False,True,64,64,64,2,2,2,inf +dsd,False,True,64,64,128,2,2,2,inf +dsd,False,True,64,128,8,2,2,2,49.170901489257815 +dsd,False,True,64,128,16,4,2,2,30.151165771484376 +dsd,False,True,64,128,32,2,2,2,inf +dsd,False,True,64,128,64,2,2,2,inf +dsd,False,True,64,128,128,2,2,2,inf +dsd,False,True,128,8,8,16,2,4,28.43712463378906 +dsd,False,True,128,8,16,16,2,4,18.233078002929688 +dsd,False,True,128,8,32,16,8,4,14.328166198730468 +dsd,False,True,128,8,64,8,8,8,12.280473327636718 +dsd,False,True,128,8,128,8,2,8,12.583350372314452 +dsd,False,True,128,16,8,16,2,2,21.776124572753908 +dsd,False,True,128,16,16,16,2,2,15.49456024169922 +dsd,False,True,128,16,32,8,4,4,12.65200653076172 +dsd,False,True,128,16,64,16,2,4,12.016060638427737 +dsd,False,True,128,16,128,8,8,8,12.376700592041017 +dsd,False,True,128,32,8,8,4,2,29.58363342285156 +dsd,False,True,128,32,16,8,8,2,17.026530456542968 +dsd,False,True,128,32,32,16,16,2,13.03179473876953 +dsd,False,True,128,32,64,8,16,4,12.44648666381836 +dsd,False,True,128,32,128,2,2,2,inf +dsd,False,True,128,64,8,4,16,2,35.76015930175781 +dsd,False,True,128,64,16,8,8,2,19.7291259765625 +dsd,False,True,128,64,32,8,4,2,15.192633056640624 +dsd,False,True,128,64,64,2,2,2,inf +dsd,False,True,128,64,128,2,2,2,inf +dsd,False,True,128,128,8,2,2,2,inf +dsd,False,True,128,128,16,2,2,2,inf +dsd,False,True,128,128,32,2,2,2,inf +dsd,False,True,128,128,64,2,2,2,inf +dsd,False,True,128,128,128,2,2,2,inf +dsd,True,False,8,8,8,2,4,2,64.50003662109376 +dsd,True,False,8,8,16,4,8,2,43.85590515136719 +dsd,True,False,8,8,32,8,4,2,38.58082275390625 +dsd,True,False,8,8,64,4,4,8,35.12033386230469 +dsd,True,False,8,8,128,2,2,2,inf +dsd,True,False,8,16,8,2,4,2,60.79771728515625 +dsd,True,False,8,16,16,2,4,2,42.94162902832032 +dsd,True,False,8,16,32,4,8,4,37.130557250976565 +dsd,True,False,8,16,64,2,2,8,34.81022644042969 +dsd,True,False,8,16,128,8,2,4,33.78464965820312 +dsd,True,False,8,32,8,2,4,2,66.80986328125 +dsd,True,False,8,32,16,2,4,2,44.515805053710935 +dsd,True,False,8,32,32,2,8,4,40.30407104492188 +dsd,True,False,8,32,64,8,2,2,39.0494140625 +dsd,True,False,8,32,128,2,2,2,inf +dsd,True,False,8,64,8,2,4,2,66.54736328125 +dsd,True,False,8,64,16,2,4,2,44.97365112304688 +dsd,True,False,8,64,32,2,2,2,42.03638610839844 +dsd,True,False,8,64,64,2,2,2,inf +dsd,True,False,8,64,128,2,2,2,inf +dsd,True,False,8,128,8,2,8,2,69.602685546875 +dsd,True,False,8,128,16,2,2,2,42.459521484375 +dsd,True,False,8,128,32,2,2,2,inf +dsd,True,False,8,128,64,2,2,2,inf +dsd,True,False,8,128,128,2,2,2,inf 
+dsd,True,False,16,8,8,4,8,2,49.50525817871094 +dsd,True,False,16,8,16,4,2,2,25.052706909179687 +dsd,True,False,16,8,32,8,2,2,19.538114929199217 +dsd,True,False,16,8,64,4,2,8,17.691792297363282 +dsd,True,False,16,8,128,8,2,8,16.321661376953124 +dsd,True,False,16,16,8,2,2,2,39.98442993164063 +dsd,True,False,16,16,16,4,4,2,23.760643005371094 +dsd,True,False,16,16,32,4,4,2,20.69975433349609 +dsd,True,False,16,16,64,4,2,8,19.054969787597656 +dsd,True,False,16,16,128,4,4,8,19.52235260009765 +dsd,True,False,16,32,8,2,16,2,48.63282470703125 +dsd,True,False,16,32,16,4,8,2,26.63685607910156 +dsd,True,False,16,32,32,8,4,2,19.93187561035156 +dsd,True,False,16,32,64,4,16,4,20.57270965576172 +dsd,True,False,16,32,128,2,2,2,inf +dsd,True,False,16,64,8,2,16,2,47.72169799804688 +dsd,True,False,16,64,16,4,2,2,26.349447631835936 +dsd,True,False,16,64,32,4,4,2,22.27334442138672 +dsd,True,False,16,64,64,2,2,2,inf +dsd,True,False,16,64,128,2,2,2,inf +dsd,True,False,16,128,8,2,4,2,49.356442260742185 +dsd,True,False,16,128,16,2,16,2,30.78685607910156 +dsd,True,False,16,128,32,2,2,2,inf +dsd,True,False,16,128,64,2,2,2,inf +dsd,True,False,16,128,128,2,2,2,inf +dsd,True,False,32,8,8,8,4,2,58.26485595703125 +dsd,True,False,32,8,16,8,4,2,22.083929443359374 +dsd,True,False,32,8,32,8,4,4,15.802262878417968 +dsd,True,False,32,8,64,8,2,8,13.385020446777345 +dsd,True,False,32,8,128,8,2,8,12.973992919921876 +dsd,True,False,32,16,8,4,16,2,40.613363647460936 +dsd,True,False,32,16,16,8,4,2,21.56761932373047 +dsd,True,False,32,16,32,8,16,4,15.306796264648437 +dsd,True,False,32,16,64,8,2,4,13.917820739746094 +dsd,True,False,32,16,128,8,2,4,12.738409423828124 +dsd,True,False,32,32,8,4,2,2,49.412847900390624 +dsd,True,False,32,32,16,8,2,2,25.411875915527343 +dsd,True,False,32,32,32,8,4,2,16.540342712402342 +dsd,True,False,32,32,64,8,16,2,14.542752075195311 +dsd,True,False,32,32,128,2,2,2,inf +dsd,True,False,32,64,8,4,8,2,48.77427368164062 +dsd,True,False,32,64,16,8,8,2,27.46026611328125 +dsd,True,False,32,64,32,4,4,2,17.95720977783203 +dsd,True,False,32,64,64,2,2,2,inf +dsd,True,False,32,64,128,2,2,2,inf +dsd,True,False,32,128,8,2,16,2,53.78228149414063 +dsd,True,False,32,128,16,4,2,2,30.272491455078125 +dsd,True,False,32,128,32,2,2,2,inf +dsd,True,False,32,128,64,2,2,2,inf +dsd,True,False,32,128,128,2,2,2,inf +dsd,True,False,64,8,8,16,4,2,95.88981323242189 +dsd,True,False,64,8,16,8,2,4,30.042803955078124 +dsd,True,False,64,8,32,8,8,8,19.73251800537109 +dsd,True,False,64,8,64,8,2,4,15.112445068359374 +dsd,True,False,64,8,128,16,8,4,13.197616577148438 +dsd,True,False,64,16,8,8,16,2,57.47857055664063 +dsd,True,False,64,16,16,16,4,2,30.21904296875 +dsd,True,False,64,16,32,8,2,4,19.610873413085937 +dsd,True,False,64,16,64,16,2,2,15.586691284179688 +dsd,True,False,64,16,128,16,2,4,13.0103515625 +dsd,True,False,64,32,8,8,16,2,65.5755126953125 +dsd,True,False,64,32,16,8,4,2,34.70885009765625 +dsd,True,False,64,32,32,16,4,2,20.928297424316405 +dsd,True,False,64,32,64,8,16,4,15.283062744140626 +dsd,True,False,64,32,128,2,2,2,inf +dsd,True,False,64,64,8,8,16,2,66.76674194335938 +dsd,True,False,64,64,16,8,8,2,34.96489562988281 +dsd,True,False,64,64,32,8,4,2,21.019471740722658 +dsd,True,False,64,64,64,2,2,2,inf +dsd,True,False,64,64,128,2,2,2,inf +dsd,True,False,64,128,8,2,8,2,72.95369873046874 +dsd,True,False,64,128,16,4,16,2,37.51438598632812 +dsd,True,False,64,128,32,2,2,2,inf +dsd,True,False,64,128,64,2,2,2,inf +dsd,True,False,64,128,128,2,2,2,inf +dsd,True,False,128,8,8,2,2,2,inf +dsd,True,False,128,8,16,16,8,4,52.0931640625 
+dsd,True,False,128,8,32,16,4,4,31.23505554199219 +dsd,True,False,128,8,64,16,2,4,20.91832885742188 +dsd,True,False,128,8,128,16,2,8,16.886825561523438 +dsd,True,False,128,16,8,16,4,2,97.2037109375 +dsd,True,False,128,16,16,16,2,2,51.648046875 +dsd,True,False,128,16,32,16,2,2,30.59534606933594 +dsd,True,False,128,16,64,16,2,4,20.84552001953125 +dsd,True,False,128,16,128,8,16,8,16.009730529785156 +dsd,True,False,128,32,8,8,16,2,104.88055419921876 +dsd,True,False,128,32,16,16,8,2,53.86217041015625 +dsd,True,False,128,32,32,16,8,2,30.620196533203124 +dsd,True,False,128,32,64,16,8,4,20.649945068359376 +dsd,True,False,128,32,128,2,2,2,inf +dsd,True,False,128,64,8,4,4,2,106.9627685546875 +dsd,True,False,128,64,16,8,16,2,54.41331176757812 +dsd,True,False,128,64,32,16,8,2,31.31766357421875 +dsd,True,False,128,64,64,2,2,2,inf +dsd,True,False,128,64,128,2,2,2,inf +dsd,True,False,128,128,8,2,2,2,inf +dsd,True,False,128,128,16,2,2,2,inf +dsd,True,False,128,128,32,2,2,2,inf +dsd,True,False,128,128,64,2,2,2,inf +dsd,True,False,128,128,128,2,2,2,inf +dsd,True,True,8,8,8,2,4,2,71.19293212890625 +dsd,True,True,8,8,16,4,8,2,47.526544189453126 +dsd,True,True,8,8,32,8,4,2,45.897119140625 +dsd,True,True,8,8,64,4,4,8,42.27239074707031 +dsd,True,True,8,8,128,8,8,8,34.71629638671875 +dsd,True,True,8,16,8,2,16,2,77.76537475585937 +dsd,True,True,8,16,16,2,8,2,44.7300048828125 +dsd,True,True,8,16,32,8,16,2,47.677008056640624 +dsd,True,True,8,16,64,2,4,8,43.859231567382814 +dsd,True,True,8,16,128,4,2,8,38.1770751953125 +dsd,True,True,8,32,8,2,16,2,104.03848876953126 +dsd,True,True,8,32,16,2,4,2,53.5608642578125 +dsd,True,True,8,32,32,4,8,2,47.158810424804685 +dsd,True,True,8,32,64,4,2,2,44.61595764160156 +dsd,True,True,8,32,128,2,2,2,inf +dsd,True,True,8,64,8,2,8,2,145.02247314453126 +dsd,True,True,8,64,16,2,16,2,71.74268188476563 +dsd,True,True,8,64,32,2,4,2,62.39791259765625 +dsd,True,True,8,64,64,2,2,2,inf +dsd,True,True,8,64,128,2,2,2,inf +dsd,True,True,8,128,8,2,2,2,inf +dsd,True,True,8,128,16,2,8,2,112.65904541015624 +dsd,True,True,8,128,32,2,2,2,inf +dsd,True,True,8,128,64,2,2,2,inf +dsd,True,True,8,128,128,2,2,2,inf +dsd,True,True,16,8,8,4,4,2,53.01531982421875 +dsd,True,True,16,8,16,4,4,2,27.02744445800781 +dsd,True,True,16,8,32,8,4,2,25.09804229736328 +dsd,True,True,16,8,64,16,8,2,20.802879333496094 +dsd,True,True,16,8,128,16,2,4,17.75169219970703 +dsd,True,True,16,16,8,2,16,2,43.57724304199219 +dsd,True,True,16,16,16,4,16,2,27.37100830078125 +dsd,True,True,16,16,32,4,2,2,26.01639404296875 +dsd,True,True,16,16,64,4,8,8,22.524560546875 +dsd,True,True,16,16,128,4,2,8,20.59218292236328 +dsd,True,True,16,32,8,2,8,2,57.39716796875 +dsd,True,True,16,32,16,4,8,2,34.71137390136719 +dsd,True,True,16,32,32,8,2,2,27.15743103027344 +dsd,True,True,16,32,64,4,4,4,25.47241668701172 +dsd,True,True,16,32,128,2,2,2,inf +dsd,True,True,16,64,8,2,4,2,67.52178955078125 +dsd,True,True,16,64,16,4,4,2,44.9527587890625 +dsd,True,True,16,64,32,4,16,2,36.574038696289065 +dsd,True,True,16,64,64,2,2,2,inf +dsd,True,True,16,64,128,2,2,2,inf +dsd,True,True,16,128,8,2,16,2,88.12987060546875 +dsd,True,True,16,128,16,2,4,2,70.0388427734375 +dsd,True,True,16,128,32,2,2,2,inf +dsd,True,True,16,128,64,2,2,2,inf +dsd,True,True,16,128,128,2,2,2,inf +dsd,True,True,32,8,8,8,2,2,60.1640625 +dsd,True,True,32,8,16,8,8,2,22.79217987060547 +dsd,True,True,32,8,32,8,8,4,17.33423309326172 +dsd,True,True,32,8,64,8,2,8,14.491964721679688 +dsd,True,True,32,8,128,8,8,8,13.583482360839843 +dsd,True,True,32,16,8,4,16,2,42.594668579101565 
+dsd,True,True,32,16,16,8,4,2,23.31687622070313 +dsd,True,True,32,16,32,8,8,4,17.858428955078125 +dsd,True,True,32,16,64,8,16,4,15.46243133544922 +dsd,True,True,32,16,128,8,16,4,13.765618896484376 +dsd,True,True,32,32,8,4,16,2,53.68150024414062 +dsd,True,True,32,32,16,8,4,2,29.81298828125 +dsd,True,True,32,32,32,8,4,2,20.115510559082036 +dsd,True,True,32,32,64,4,2,4,17.780838012695312 +dsd,True,True,32,32,128,2,2,2,inf +dsd,True,True,32,64,8,4,16,2,58.4751220703125 +dsd,True,True,32,64,16,8,16,2,35.29141845703125 +dsd,True,True,32,64,32,8,4,2,26.837109375 +dsd,True,True,32,64,64,2,2,2,inf +dsd,True,True,32,64,128,2,2,2,inf +dsd,True,True,32,128,8,2,4,2,74.15425415039063 +dsd,True,True,32,128,16,4,2,2,49.33531799316406 +dsd,True,True,32,128,32,2,2,2,inf +dsd,True,True,32,128,64,2,2,2,inf +dsd,True,True,32,128,128,2,2,2,inf +dsd,True,True,64,8,8,16,8,2,96.929638671875 +dsd,True,True,64,8,16,8,4,4,30.505270385742183 +dsd,True,True,64,8,32,8,2,4,20.189126586914064 +dsd,True,True,64,8,64,8,8,4,15.148918151855469 +dsd,True,True,64,8,128,16,8,4,13.314019775390625 +dsd,True,True,64,16,8,8,8,2,58.3099365234375 +dsd,True,True,64,16,16,16,4,2,31.09625244140625 +dsd,True,True,64,16,32,8,2,4,20.450559997558592 +dsd,True,True,64,16,64,8,16,4,16.086697387695313 +dsd,True,True,64,16,128,16,4,4,13.981097412109374 +dsd,True,True,64,32,8,8,4,2,67.69793090820312 +dsd,True,True,64,32,16,16,16,2,36.72890014648438 +dsd,True,True,64,32,32,16,16,2,22.644483947753905 +dsd,True,True,64,32,64,8,16,4,17.29567108154297 +dsd,True,True,64,32,128,2,2,2,inf +dsd,True,True,64,64,8,8,16,2,71.69373168945313 +dsd,True,True,64,64,16,8,16,2,39.72655334472656 +dsd,True,True,64,64,32,8,8,2,25.80042419433594 +dsd,True,True,64,64,64,2,2,2,inf +dsd,True,True,64,64,128,2,2,2,inf +dsd,True,True,64,128,8,2,8,2,83.09647216796876 +dsd,True,True,64,128,16,4,16,2,47.12303771972656 +dsd,True,True,64,128,32,2,2,2,inf +dsd,True,True,64,128,64,2,2,2,inf +dsd,True,True,64,128,128,2,2,2,inf +dsd,True,True,128,8,8,2,2,2,inf +dsd,True,True,128,8,16,16,8,4,52.12281494140625 +dsd,True,True,128,8,32,16,8,4,31.438470458984376 +dsd,True,True,128,8,64,16,8,4,20.78094024658203 +dsd,True,True,128,8,128,16,2,8,17.265670776367188 +dsd,True,True,128,16,8,16,16,2,97.34880981445312 +dsd,True,True,128,16,16,16,8,2,51.683660888671874 +dsd,True,True,128,16,32,16,2,2,30.91318664550781 +dsd,True,True,128,16,64,16,16,4,21.14001007080078 +dsd,True,True,128,16,128,4,2,16,16.960723876953125 +dsd,True,True,128,32,8,8,8,2,105.91396484375 +dsd,True,True,128,32,16,16,4,2,54.87500610351562 +dsd,True,True,128,32,32,16,8,2,31.5608642578125 +dsd,True,True,128,32,64,8,4,4,21.677349853515626 +dsd,True,True,128,32,128,2,2,2,inf +dsd,True,True,128,64,8,4,16,2,109.50223388671876 +dsd,True,True,128,64,16,8,16,2,56.69862670898438 +dsd,True,True,128,64,32,16,8,2,33.59224548339844 +dsd,True,True,128,64,64,2,2,2,inf +dsd,True,True,128,64,128,2,2,2,inf +dsd,True,True,128,128,8,2,2,2,inf +dsd,True,True,128,128,16,2,2,2,inf +dsd,True,True,128,128,32,2,2,2,inf +dsd,True,True,128,128,64,2,2,2,inf +dsd,True,True,128,128,128,2,2,2,inf +sdd,False,False,8,8,8,2,2,2,72.04984130859376 +sdd,False,False,8,8,16,4,8,2,45.3943115234375 +sdd,False,False,8,8,32,8,4,2,38.74576110839844 +sdd,False,False,8,8,64,4,2,8,36.7857421875 +sdd,False,False,8,8,128,2,2,2,inf +sdd,False,False,8,16,8,2,16,2,70.55133056640625 +sdd,False,False,8,16,16,4,16,2,46.160275268554685 +sdd,False,False,8,16,32,4,2,2,36.251708984375 +sdd,False,False,8,16,64,2,2,8,34.42546691894531 +sdd,False,False,8,16,128,4,2,8,35.05966491699219 
+sdd,False,False,8,32,8,4,4,2,81.95740966796875 +sdd,False,False,8,32,16,2,4,2,46.26115417480469 +sdd,False,False,8,32,32,4,8,2,38.46751708984375 +sdd,False,False,8,32,64,8,4,2,37.41645812988281 +sdd,False,False,8,32,128,8,16,2,37.6352294921875 +sdd,False,False,8,64,8,2,4,2,74.85977172851562 +sdd,False,False,8,64,16,2,8,2,43.94977111816407 +sdd,False,False,8,64,32,2,16,2,38.66516418457032 +sdd,False,False,8,64,64,2,16,2,37.02501525878906 +sdd,False,False,8,64,128,4,2,2,37.57594604492188 +sdd,False,False,8,128,8,2,2,2,inf +sdd,False,False,8,128,16,2,8,2,50.735546875 +sdd,False,False,8,128,32,2,4,2,44.66907958984375 +sdd,False,False,8,128,64,2,16,2,42.6173095703125 +sdd,False,False,8,128,128,2,2,2,inf +sdd,False,False,16,8,8,4,4,2,40.41903381347656 +sdd,False,False,16,8,16,4,8,2,24.17235565185547 +sdd,False,False,16,8,32,8,8,2,19.93219146728516 +sdd,False,False,16,8,64,2,2,16,19.359555053710935 +sdd,False,False,16,8,128,16,8,4,19.00537872314453 +sdd,False,False,16,16,8,4,4,2,42.32752380371094 +sdd,False,False,16,16,16,2,16,4,25.86504821777344 +sdd,False,False,16,16,32,4,2,2,20.175260925292967 +sdd,False,False,16,16,64,4,16,8,19.01503295898437 +sdd,False,False,16,16,128,4,2,8,19.52960357666016 +sdd,False,False,16,32,8,4,4,2,49.81890258789063 +sdd,False,False,16,32,16,4,2,2,26.464312744140624 +sdd,False,False,16,32,32,8,2,2,20.246003723144533 +sdd,False,False,16,32,64,4,4,4,19.618144226074214 +sdd,False,False,16,32,128,4,4,4,19.372467041015625 +sdd,False,False,16,64,8,4,2,2,43.28736572265625 +sdd,False,False,16,64,16,4,8,2,26.00096740722656 +sdd,False,False,16,64,32,4,2,2,20.53853759765625 +sdd,False,False,16,64,64,4,16,2,19.274000549316405 +sdd,False,False,16,64,128,4,4,2,19.3044677734375 +sdd,False,False,16,128,8,2,16,2,43.4588134765625 +sdd,False,False,16,128,16,2,8,2,29.749468994140624 +sdd,False,False,16,128,32,2,8,2,22.932847595214845 +sdd,False,False,16,128,64,4,2,2,22.40806121826172 +sdd,False,False,16,128,128,2,2,2,inf +sdd,False,False,32,8,8,8,2,2,28.4953125 +sdd,False,False,32,8,16,8,4,2,16.595855712890625 +sdd,False,False,32,8,32,8,2,4,13.706492614746091 +sdd,False,False,32,8,64,8,4,8,13.076988220214844 +sdd,False,False,32,8,128,8,2,8,13.090707397460935 +sdd,False,False,32,16,8,8,8,2,28.655096435546877 +sdd,False,False,32,16,16,8,16,2,17.110902404785158 +sdd,False,False,32,16,32,8,4,4,13.602815246582033 +sdd,False,False,32,16,64,8,8,4,13.031321716308591 +sdd,False,False,32,16,128,8,8,4,12.628358459472656 +sdd,False,False,32,32,8,4,16,2,34.516058349609374 +sdd,False,False,32,32,16,8,16,2,18.954006958007813 +sdd,False,False,32,32,32,8,2,2,14.182521057128906 +sdd,False,False,32,32,64,8,4,2,13.5316162109375 +sdd,False,False,32,32,128,8,4,4,12.78335647583008 +sdd,False,False,32,64,8,4,8,2,33.892111206054686 +sdd,False,False,32,64,16,4,16,2,20.76946868896484 +sdd,False,False,32,64,32,4,8,2,15.313661193847656 +sdd,False,False,32,64,64,8,16,2,13.90250244140625 +sdd,False,False,32,64,128,8,16,2,13.511581420898438 +sdd,False,False,32,128,8,2,2,2,inf +sdd,False,False,32,128,16,2,2,2,inf +sdd,False,False,32,128,32,2,2,2,inf +sdd,False,False,32,128,64,2,2,2,inf +sdd,False,False,32,128,128,2,2,2,inf +sdd,False,False,64,8,8,8,4,4,23.24774017333984 +sdd,False,False,64,8,16,8,8,4,13.581919860839845 +sdd,False,False,64,8,32,8,2,8,12.061090850830078 +sdd,False,False,64,8,64,8,2,8,11.904332733154297 +sdd,False,False,64,8,128,8,8,8,11.4974365234375 +sdd,False,False,64,16,8,8,8,2,22.200015258789065 +sdd,False,False,64,16,16,16,16,2,13.94998779296875 +sdd,False,False,64,16,32,8,2,4,12.431423950195311 
+sdd,False,False,64,16,64,8,16,4,11.724240112304688 +sdd,False,False,64,16,128,16,4,4,11.464300537109375 +sdd,False,False,64,32,8,8,16,2,29.659991455078124 +sdd,False,False,64,32,16,8,4,2,16.99700164794922 +sdd,False,False,64,32,32,8,16,2,12.782505798339844 +sdd,False,False,64,32,64,8,16,4,11.693708801269532 +sdd,False,False,64,32,128,16,16,2,11.619532775878906 +sdd,False,False,64,64,8,2,2,2,inf +sdd,False,False,64,64,16,2,2,2,inf +sdd,False,False,64,64,32,2,2,2,inf +sdd,False,False,64,64,64,2,2,2,inf +sdd,False,False,64,64,128,2,2,2,inf +sdd,False,False,64,128,8,2,2,2,inf +sdd,False,False,64,128,16,2,2,2,inf +sdd,False,False,64,128,32,2,2,2,inf +sdd,False,False,64,128,64,2,2,2,inf +sdd,False,False,64,128,128,2,2,2,inf +sdd,False,False,128,8,8,16,4,4,22.090733337402344 +sdd,False,False,128,8,16,16,4,4,12.416374206542969 +sdd,False,False,128,8,32,8,2,8,11.87421112060547 +sdd,False,False,128,8,64,8,8,8,11.37545928955078 +sdd,False,False,128,8,128,8,8,8,12.38086395263672 +sdd,False,False,128,16,8,16,4,2,19.331510925292967 +sdd,False,False,128,16,16,16,16,2,13.283071899414065 +sdd,False,False,128,16,32,8,16,4,11.54400634765625 +sdd,False,False,128,16,64,16,8,4,11.143462371826171 +sdd,False,False,128,16,128,8,4,8,11.590351867675782 +sdd,False,False,128,32,8,2,2,2,inf +sdd,False,False,128,32,16,2,2,2,inf +sdd,False,False,128,32,32,2,2,2,inf +sdd,False,False,128,32,64,2,2,2,inf +sdd,False,False,128,32,128,2,2,2,inf +sdd,False,False,128,64,8,2,2,2,inf +sdd,False,False,128,64,16,2,2,2,inf +sdd,False,False,128,64,32,2,2,2,inf +sdd,False,False,128,64,64,2,2,2,inf +sdd,False,False,128,64,128,2,2,2,inf +sdd,False,False,128,128,8,2,2,2,inf +sdd,False,False,128,128,16,2,2,2,inf +sdd,False,False,128,128,32,2,2,2,inf +sdd,False,False,128,128,64,2,2,2,inf +sdd,False,False,128,128,128,2,2,2,inf +sdd,False,True,8,8,8,2,2,2,91.77452392578124 +sdd,False,True,8,8,16,2,8,8,96.86319580078126 +sdd,False,True,8,8,32,2,4,16,99.28195190429688 +sdd,False,True,8,8,64,4,8,16,129.086669921875 +sdd,False,True,8,8,128,8,4,16,91.06226806640623 +sdd,False,True,8,16,8,2,16,2,66.354736328125 +sdd,False,True,8,16,16,4,16,2,48.97417907714844 +sdd,False,True,8,16,32,4,16,2,46.55622253417969 +sdd,False,True,8,16,64,8,16,4,45.48564758300781 +sdd,False,True,8,16,128,2,4,16,44.08933715820312 +sdd,False,True,8,32,8,2,8,2,89.6635986328125 +sdd,False,True,8,32,16,2,4,2,46.87720947265625 +sdd,False,True,8,32,32,4,16,2,38.79921264648438 +sdd,False,True,8,32,64,4,8,2,38.43687744140625 +sdd,False,True,8,32,128,8,4,2,38.750390625 +sdd,False,True,8,64,8,2,2,2,131.63740234375 +sdd,False,True,8,64,16,2,16,2,66.56493530273437 +sdd,False,True,8,64,32,4,2,2,56.46593017578125 +sdd,False,True,8,64,64,4,16,2,54.95045166015625 +sdd,False,True,8,64,128,4,4,2,55.65963745117188 +sdd,False,True,8,128,8,2,2,2,inf +sdd,False,True,8,128,16,2,8,2,106.23653564453124 +sdd,False,True,8,128,32,2,2,2,101.65936889648438 +sdd,False,True,8,128,64,2,4,2,100.1741455078125 +sdd,False,True,8,128,128,2,2,2,inf +sdd,False,True,16,8,8,4,8,2,47.9957275390625 +sdd,False,True,16,8,16,8,4,4,49.27310485839844 +sdd,False,True,16,8,32,16,8,4,47.62242431640625 +sdd,False,True,16,8,64,2,4,16,54.047509765625 +sdd,False,True,16,8,128,8,2,8,52.71130981445312 +sdd,False,True,16,16,8,2,4,2,35.24564819335937 +sdd,False,True,16,16,16,2,16,4,27.9607421875 +sdd,False,True,16,16,32,8,4,2,24.983120727539063 +sdd,False,True,16,16,64,4,2,8,24.548185729980467 +sdd,False,True,16,16,128,4,2,8,23.86432342529297 +sdd,False,True,16,32,8,2,8,2,47.88221130371094 
+sdd,False,True,16,32,16,4,8,2,29.615542602539065 +sdd,False,True,16,32,32,8,16,2,22.84046783447265 +sdd,False,True,16,32,64,4,4,4,21.95774688720703 +sdd,False,True,16,32,128,4,8,4,21.562205505371093 +sdd,False,True,16,64,8,2,16,2,58.48067626953125 +sdd,False,True,16,64,16,4,16,2,39.72404174804687 +sdd,False,True,16,64,32,8,4,2,33.36689147949219 +sdd,False,True,16,64,64,4,4,2,32.759228515625 +sdd,False,True,16,64,128,4,2,4,32.71089172363281 +sdd,False,True,16,128,8,2,16,2,77.55657348632812 +sdd,False,True,16,128,16,4,2,2,62.15914916992188 +sdd,False,True,16,128,32,4,4,2,55.14508056640625 +sdd,False,True,16,128,64,4,16,2,53.7142578125 +sdd,False,True,16,128,128,2,2,2,inf +sdd,False,True,32,8,8,8,4,2,30.496945190429688 +sdd,False,True,32,8,16,16,8,2,25.50102081298828 +sdd,False,True,32,8,32,16,4,4,24.51743621826172 +sdd,False,True,32,8,64,8,4,16,27.03175048828125 +sdd,False,True,32,8,128,8,2,16,34.25230102539062 +sdd,False,True,32,16,8,4,4,2,26.0266845703125 +sdd,False,True,32,16,16,8,2,2,17.680572509765625 +sdd,False,True,32,16,32,8,2,4,14.919331359863282 +sdd,False,True,32,16,64,8,16,4,14.73011474609375 +sdd,False,True,32,16,128,8,8,4,14.204704284667969 +sdd,False,True,32,32,8,4,16,2,36.43758544921875 +sdd,False,True,32,32,16,8,16,2,21.21480255126953 +sdd,False,True,32,32,32,8,16,2,16.7617919921875 +sdd,False,True,32,32,64,8,2,4,16.181004333496094 +sdd,False,True,32,32,128,16,16,2,15.448342895507812 +sdd,False,True,32,64,8,4,8,2,41.87442932128906 +sdd,False,True,32,64,16,8,4,2,26.48095703125 +sdd,False,True,32,64,32,4,16,2,22.530685424804688 +sdd,False,True,32,64,64,8,16,2,21.07642822265625 +sdd,False,True,32,64,128,8,2,4,21.36976013183594 +sdd,False,True,32,128,8,2,2,2,inf +sdd,False,True,32,128,16,2,2,2,inf +sdd,False,True,32,128,32,2,2,2,inf +sdd,False,True,32,128,64,2,2,2,inf +sdd,False,True,32,128,128,2,2,2,inf +sdd,False,True,64,8,8,16,2,2,24.09486389160156 +sdd,False,True,64,8,16,8,2,4,15.721728515625 +sdd,False,True,64,8,32,16,8,4,14.562950134277344 +sdd,False,True,64,8,64,8,2,4,13.80590362548828 +sdd,False,True,64,8,128,16,2,4,13.107994079589844 +sdd,False,True,64,16,8,8,2,2,21.141920471191405 +sdd,False,True,64,16,16,16,2,2,14.441622924804689 +sdd,False,True,64,16,32,8,2,4,13.028182983398438 +sdd,False,True,64,16,64,8,8,4,12.355359649658205 +sdd,False,True,64,16,128,16,4,4,12.280850982666015 +sdd,False,True,64,32,8,8,8,2,30.81255187988281 +sdd,False,True,64,32,16,8,16,2,18.20790405273437 +sdd,False,True,64,32,32,8,4,2,14.05712890625 +sdd,False,True,64,32,64,8,16,4,12.909103393554688 +sdd,False,True,64,32,128,8,2,4,13.388832092285156 +sdd,False,True,64,64,8,2,2,2,inf +sdd,False,True,64,64,16,2,2,2,inf +sdd,False,True,64,64,32,2,2,2,inf +sdd,False,True,64,64,64,2,2,2,inf +sdd,False,True,64,64,128,2,2,2,inf +sdd,False,True,64,128,8,2,2,2,inf +sdd,False,True,64,128,16,2,2,2,inf +sdd,False,True,64,128,32,2,2,2,inf +sdd,False,True,64,128,64,2,2,2,inf +sdd,False,True,64,128,128,2,2,2,inf +sdd,False,True,128,8,8,16,4,4,22.91182098388672 +sdd,False,True,128,8,16,16,2,4,13.245916748046875 +sdd,False,True,128,8,32,8,8,8,12.6768798828125 +sdd,False,True,128,8,64,8,8,8,11.811756896972655 +sdd,False,True,128,8,128,8,2,8,12.738713836669922 +sdd,False,True,128,16,8,16,8,2,18.884083557128907 +sdd,False,True,128,16,16,16,4,2,13.472808837890623 +sdd,False,True,128,16,32,8,16,4,12.021001434326172 +sdd,False,True,128,16,64,16,4,4,11.51459503173828 +sdd,False,True,128,16,128,8,4,8,12.582121276855467 +sdd,False,True,128,32,8,2,2,2,inf +sdd,False,True,128,32,16,2,2,2,inf 
+sdd,False,True,128,32,32,2,2,2,inf +sdd,False,True,128,32,64,2,2,2,inf +sdd,False,True,128,32,128,2,2,2,inf +sdd,False,True,128,64,8,2,2,2,inf +sdd,False,True,128,64,16,2,2,2,inf +sdd,False,True,128,64,32,2,2,2,inf +sdd,False,True,128,64,64,2,2,2,inf +sdd,False,True,128,64,128,2,2,2,inf +sdd,False,True,128,128,8,2,2,2,inf +sdd,False,True,128,128,16,2,2,2,inf +sdd,False,True,128,128,32,2,2,2,inf +sdd,False,True,128,128,64,2,2,2,inf +sdd,False,True,128,128,128,2,2,2,inf +sdd,True,False,8,8,8,2,2,2,73.9848388671875 +sdd,True,False,8,8,16,4,4,2,46.17842712402344 +sdd,True,False,8,8,32,8,4,2,37.888674926757815 +sdd,True,False,8,8,64,4,8,8,34.64777526855469 +sdd,True,False,8,8,128,2,2,2,inf +sdd,True,False,8,16,8,2,2,2,73.74010620117187 +sdd,True,False,8,16,16,4,4,2,46.40711059570312 +sdd,True,False,8,16,32,2,2,4,34.76233825683594 +sdd,True,False,8,16,64,2,2,8,32.09572143554688 +sdd,True,False,8,16,128,8,8,4,32.73079223632813 +sdd,True,False,8,32,8,4,8,2,86.20098876953125 +sdd,True,False,8,32,16,4,2,2,44.73761596679688 +sdd,True,False,8,32,32,4,2,2,37.705255126953126 +sdd,True,False,8,32,64,8,2,2,37.21822814941406 +sdd,True,False,8,32,128,8,4,2,38.06083374023437 +sdd,True,False,8,64,8,4,2,2,84.308056640625 +sdd,True,False,8,64,16,2,4,2,45.06987915039063 +sdd,True,False,8,64,32,2,8,2,40.30692443847656 +sdd,True,False,8,64,64,2,4,2,38.734326171875 +sdd,True,False,8,64,128,4,2,2,25.005967712402345 +sdd,True,False,8,128,8,2,8,2,70.26617431640625 +sdd,True,False,8,128,16,2,2,2,50.441259765625 +sdd,True,False,8,128,32,2,16,2,35.400308227539064 +sdd,True,False,8,128,64,2,4,2,39.35030517578125 +sdd,True,False,8,128,128,2,2,2,inf +sdd,True,False,16,8,8,4,4,2,51.97205810546875 +sdd,True,False,16,8,16,4,4,2,25.86021728515625 +sdd,True,False,16,8,32,8,2,2,19.348057556152344 +sdd,True,False,16,8,64,16,8,2,17.644924926757813 +sdd,True,False,16,8,128,16,8,4,16.097267150878906 +sdd,True,False,16,16,8,2,4,2,49.89132080078125 +sdd,True,False,16,16,16,4,2,2,26.679196166992188 +sdd,True,False,16,16,32,8,4,2,19.98988494873047 +sdd,True,False,16,16,64,16,4,2,17.62073669433594 +sdd,True,False,16,16,128,4,8,8,18.66569213867188 +sdd,True,False,16,32,8,2,2,2,53.354669189453126 +sdd,True,False,16,32,16,4,16,2,27.72266540527344 +sdd,True,False,16,32,32,8,8,2,18.98997802734375 +sdd,True,False,16,32,64,4,8,2,19.319593811035155 +sdd,True,False,16,32,128,2,2,8,18.159635925292967 +sdd,True,False,16,64,8,2,2,2,49.46534423828125 +sdd,True,False,16,64,16,4,2,2,26.9336669921875 +sdd,True,False,16,64,32,4,8,2,21.0331298828125 +sdd,True,False,16,64,64,4,2,2,19.843218994140624 +sdd,True,False,16,64,128,4,2,2,19.66313934326172 +sdd,True,False,16,128,8,2,16,2,51.30648193359375 +sdd,True,False,16,128,16,2,2,2,30.73262023925781 +sdd,True,False,16,128,32,2,8,2,23.573583984375 +sdd,True,False,16,128,64,2,2,2,22.738954162597658 +sdd,True,False,16,128,128,2,2,2,inf +sdd,True,False,32,8,8,8,2,2,59.884222412109374 +sdd,True,False,32,8,16,8,4,2,22.643101501464844 +sdd,True,False,32,8,32,8,8,4,15.67876739501953 +sdd,True,False,32,8,64,8,8,8,13.152056884765624 +sdd,True,False,32,8,128,8,8,8,12.867955017089844 +sdd,True,False,32,16,8,4,4,2,42.48688049316407 +sdd,True,False,32,16,16,8,16,2,22.31826171875 +sdd,True,False,32,16,32,8,16,4,15.291162109375 +sdd,True,False,32,16,64,8,4,4,13.823011779785157 +sdd,True,False,32,16,128,8,8,4,12.749072265625 +sdd,True,False,32,32,8,4,16,2,51.068521118164064 +sdd,True,False,32,32,16,8,2,2,26.098941040039065 +sdd,True,False,32,32,32,8,4,2,16.640992736816408 +sdd,True,False,32,32,64,8,4,4,14.561097717285156 
+sdd,True,False,32,32,128,8,2,4,12.976034545898438 +sdd,True,False,32,64,8,4,16,2,50.79121704101563 +sdd,True,False,32,64,16,8,16,2,27.523394775390624 +sdd,True,False,32,64,32,4,2,2,18.135264587402343 +sdd,True,False,32,64,64,8,8,2,14.8186279296875 +sdd,True,False,32,64,128,8,16,2,13.958329772949218 +sdd,True,False,32,128,8,2,2,2,inf +sdd,True,False,32,128,16,2,2,2,inf +sdd,True,False,32,128,32,2,2,2,inf +sdd,True,False,32,128,64,2,2,2,inf +sdd,True,False,32,128,128,2,2,2,inf +sdd,True,False,64,8,8,16,4,2,96.2021240234375 +sdd,True,False,64,8,16,8,2,4,30.0603515625 +sdd,True,False,64,8,32,8,8,4,19.49802856445313 +sdd,True,False,64,8,64,8,2,4,15.104576110839844 +sdd,True,False,64,8,128,8,8,8,12.94129638671875 +sdd,True,False,64,16,8,8,4,2,58.28064575195312 +sdd,True,False,64,16,16,16,8,2,30.451214599609376 +sdd,True,False,64,16,32,8,2,4,19.571139526367187 +sdd,True,False,64,16,64,16,2,2,15.562969970703126 +sdd,True,False,64,16,128,16,4,4,13.360887145996092 +sdd,True,False,64,32,8,8,16,2,66.38638305664062 +sdd,True,False,64,32,16,8,8,2,34.95736999511719 +sdd,True,False,64,32,32,16,8,2,20.5401123046875 +sdd,True,False,64,32,64,8,16,4,15.419955444335937 +sdd,True,False,64,32,128,8,2,4,13.664102172851562 +sdd,True,False,64,64,8,2,2,2,inf +sdd,True,False,64,64,16,2,2,2,inf +sdd,True,False,64,64,32,2,2,2,inf +sdd,True,False,64,64,64,2,2,2,inf +sdd,True,False,64,64,128,2,2,2,inf +sdd,True,False,64,128,8,2,2,2,inf +sdd,True,False,64,128,16,2,2,2,inf +sdd,True,False,64,128,32,2,2,2,inf +sdd,True,False,64,128,64,2,2,2,inf +sdd,True,False,64,128,128,2,2,2,inf +sdd,True,False,128,8,8,2,2,2,inf +sdd,True,False,128,8,16,16,8,4,52.1150146484375 +sdd,True,False,128,8,32,8,8,8,30.83111572265625 +sdd,True,False,128,8,64,8,8,8,20.565753173828124 +sdd,True,False,128,8,128,16,2,8,16.8664794921875 +sdd,True,False,128,16,8,16,8,2,97.72430419921876 +sdd,True,False,128,16,16,16,2,2,51.40657958984375 +sdd,True,False,128,16,32,16,2,2,30.401171875 +sdd,True,False,128,16,64,16,4,4,20.604071044921877 +sdd,True,False,128,16,128,8,16,8,16.067298889160156 +sdd,True,False,128,32,8,2,2,2,inf +sdd,True,False,128,32,16,2,2,2,inf +sdd,True,False,128,32,32,2,2,2,inf +sdd,True,False,128,32,64,2,2,2,inf +sdd,True,False,128,32,128,2,2,2,inf +sdd,True,False,128,64,8,2,2,2,inf +sdd,True,False,128,64,16,2,2,2,inf +sdd,True,False,128,64,32,2,2,2,inf +sdd,True,False,128,64,64,2,2,2,inf +sdd,True,False,128,64,128,2,2,2,inf +sdd,True,False,128,128,8,2,2,2,inf +sdd,True,False,128,128,16,2,2,2,inf +sdd,True,False,128,128,32,2,2,2,inf +sdd,True,False,128,128,64,2,2,2,inf +sdd,True,False,128,128,128,2,2,2,inf +sdd,True,True,8,8,8,2,2,2,70.79592895507812 +sdd,True,True,8,8,16,2,4,4,65.1414794921875 +sdd,True,True,8,8,32,8,4,4,83.34755859375 +sdd,True,True,8,8,64,2,2,16,119.6942626953125 +sdd,True,True,8,8,128,8,4,8,92.86500854492188 +sdd,True,True,8,16,8,2,8,2,73.23809814453125 +sdd,True,True,8,16,16,4,4,2,50.65773010253906 +sdd,True,True,8,16,32,2,16,4,46.528628540039065 +sdd,True,True,8,16,64,8,8,4,44.46695556640625 +sdd,True,True,8,16,128,2,16,16,42.68679504394531 +sdd,True,True,8,32,8,2,16,2,96.91433715820312 +sdd,True,True,8,32,16,2,16,2,48.09185791015625 +sdd,True,True,8,32,32,4,4,2,39.37124938964844 +sdd,True,True,8,32,64,4,2,2,39.133636474609375 +sdd,True,True,8,32,128,8,8,2,40.11669006347656 +sdd,True,True,8,64,8,2,2,2,138.27012939453124 +sdd,True,True,8,64,16,2,16,2,67.56966552734374 +sdd,True,True,8,64,32,4,8,2,58.15477294921875 +sdd,True,True,8,64,64,4,4,2,56.105419921875 +sdd,True,True,8,64,128,4,8,2,56.26817626953125 
+sdd,True,True,8,128,8,2,2,2,inf +sdd,True,True,8,128,16,2,16,2,108.03819580078124 +sdd,True,True,8,128,32,2,4,2,101.84613647460938 +sdd,True,True,8,128,64,2,16,2,99.8545654296875 +sdd,True,True,8,128,128,2,2,2,inf +sdd,True,True,16,8,8,4,2,2,53.96688232421875 +sdd,True,True,16,8,16,8,2,2,36.35191345214844 +sdd,True,True,16,8,32,8,2,8,48.81896362304688 +sdd,True,True,16,8,64,2,8,16,46.240542602539065 +sdd,True,True,16,8,128,8,2,8,42.79500122070313 +sdd,True,True,16,16,8,2,16,2,42.394622802734375 +sdd,True,True,16,16,16,4,4,2,27.01558532714844 +sdd,True,True,16,16,32,8,16,2,24.6191162109375 +sdd,True,True,16,16,64,8,2,4,23.42810821533203 +sdd,True,True,16,16,128,8,4,4,22.794557189941408 +sdd,True,True,16,32,8,2,16,2,55.01532592773437 +sdd,True,True,16,32,16,4,2,2,33.358612060546875 +sdd,True,True,16,32,32,8,2,2,23.708087158203124 +sdd,True,True,16,32,64,4,2,4,22.57019805908203 +sdd,True,True,16,32,128,4,4,4,21.4740966796875 +sdd,True,True,16,64,8,2,16,2,65.340087890625 +sdd,True,True,16,64,16,4,8,2,43.76452026367188 +sdd,True,True,16,64,32,8,16,2,34.87169189453125 +sdd,True,True,16,64,64,4,4,4,33.64231872558594 +sdd,True,True,16,64,128,4,2,4,32.75396728515625 +sdd,True,True,16,128,8,2,16,2,84.63165283203125 +sdd,True,True,16,128,16,4,2,2,65.06669311523437 +sdd,True,True,16,128,32,4,2,2,56.94905395507813 +sdd,True,True,16,128,64,4,4,2,55.36854858398438 +sdd,True,True,16,128,128,2,2,2,inf +sdd,True,True,32,8,8,8,8,2,61.72941284179687 +sdd,True,True,32,8,16,8,2,2,24.83875274658203 +sdd,True,True,32,8,32,8,4,2,20.58190460205078 +sdd,True,True,32,8,64,8,8,4,26.2571044921875 +sdd,True,True,32,8,128,8,4,16,26.21186828613281 +sdd,True,True,32,16,8,4,16,2,43.00454711914063 +sdd,True,True,32,16,16,8,4,2,23.92322540283203 +sdd,True,True,32,16,32,8,8,4,17.551664733886717 +sdd,True,True,32,16,64,8,2,4,15.735151672363282 +sdd,True,True,32,16,128,8,4,4,14.5494873046875 +sdd,True,True,32,32,8,4,8,2,53.545587158203126 +sdd,True,True,32,32,16,8,16,2,29.740457153320317 +sdd,True,True,32,32,32,8,4,2,20.164134216308597 +sdd,True,True,32,32,64,8,4,4,17.43151092529297 +sdd,True,True,32,32,128,16,4,2,16.06116180419922 +sdd,True,True,32,64,8,4,16,2,58.31881713867188 +sdd,True,True,32,64,16,8,16,2,34.770635986328124 +sdd,True,True,32,64,32,8,8,2,26.29350891113281 +sdd,True,True,32,64,64,8,16,2,22.838330078125 +sdd,True,True,32,64,128,8,2,4,22.09839630126953 +sdd,True,True,32,128,8,2,2,2,inf +sdd,True,True,32,128,16,2,2,2,inf +sdd,True,True,32,128,32,2,2,2,inf +sdd,True,True,32,128,64,2,2,2,inf +sdd,True,True,32,128,128,2,2,2,inf +sdd,True,True,64,8,8,16,2,2,97.34681396484376 +sdd,True,True,64,8,16,8,8,4,30.87897644042969 +sdd,True,True,64,8,32,8,8,8,20.65984649658203 +sdd,True,True,64,8,64,8,2,4,16.07484130859375 +sdd,True,True,64,8,128,16,8,4,13.826019287109377 +sdd,True,True,64,16,8,8,2,2,58.582769775390624 +sdd,True,True,64,16,16,16,8,2,31.01230773925781 +sdd,True,True,64,16,32,8,2,4,20.705232238769533 +sdd,True,True,64,16,64,8,16,4,16.316441345214844 +sdd,True,True,64,16,128,16,8,4,14.138400268554689 +sdd,True,True,64,32,8,8,8,2,67.5900146484375 +sdd,True,True,64,32,16,16,4,2,36.62651977539063 +sdd,True,True,64,32,32,16,8,2,22.28140869140625 +sdd,True,True,64,32,64,8,16,4,17.178799438476563 +sdd,True,True,64,32,128,8,4,4,15.548410034179687 +sdd,True,True,64,64,8,2,2,2,inf +sdd,True,True,64,64,16,2,2,2,inf +sdd,True,True,64,64,32,2,2,2,inf +sdd,True,True,64,64,64,2,2,2,inf +sdd,True,True,64,64,128,2,2,2,inf +sdd,True,True,64,128,8,2,2,2,inf +sdd,True,True,64,128,16,2,2,2,inf +sdd,True,True,64,128,32,2,2,2,inf 
+sdd,True,True,64,128,64,2,2,2,inf +sdd,True,True,64,128,128,2,2,2,inf +sdd,True,True,128,8,8,2,2,2,inf +sdd,True,True,128,8,16,16,8,4,52.57091064453125 +sdd,True,True,128,8,32,8,8,8,31.27769775390625 +sdd,True,True,128,8,64,16,8,4,20.810726928710935 +sdd,True,True,128,8,128,16,8,8,17.33007354736328 +sdd,True,True,128,16,8,16,16,2,97.82230224609376 +sdd,True,True,128,16,16,16,2,2,52.05543212890625 +sdd,True,True,128,16,32,16,2,2,30.84331970214844 +sdd,True,True,128,16,64,16,4,4,21.341754150390624 +sdd,True,True,128,16,128,4,16,16,17.054681396484376 +sdd,True,True,128,32,8,2,2,2,inf +sdd,True,True,128,32,16,2,2,2,inf +sdd,True,True,128,32,32,2,2,2,inf +sdd,True,True,128,32,64,2,2,2,inf +sdd,True,True,128,32,128,2,2,2,inf +sdd,True,True,128,64,8,2,2,2,inf +sdd,True,True,128,64,16,2,2,2,inf +sdd,True,True,128,64,32,2,2,2,inf +sdd,True,True,128,64,64,2,2,2,inf +sdd,True,True,128,64,128,2,2,2,inf +sdd,True,True,128,128,8,2,2,2,inf +sdd,True,True,128,128,16,2,2,2,inf +sdd,True,True,128,128,32,2,2,2,inf +sdd,True,True,128,128,64,2,2,2,inf +sdd,True,True,128,128,128,2,2,2,inf diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.70.csv b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.70.csv new file mode 100644 index 00000000..76905a84 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.70.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,1,2.397910308837891 +8,16,1,1.357868766784668 +8,32,8,0.8915360450744629 +8,64,8,0.9132736206054688 +8,128,8,0.9273440361022948 +16,8,1,2.307228851318359 +16,16,1,1.3961952209472657 +16,32,8,0.8938655853271484 +16,64,8,0.912384033203125 +16,128,8,0.9198207855224608 +32,8,1,2.3392576217651366 +32,16,1,1.3536288261413574 +32,32,8,0.9012639999389648 +32,64,8,0.9076224327087402 +32,128,8,0.9179776191711426 +64,8,1,2.295167922973633 +64,16,1,1.3171648025512694 +64,32,8,0.891209602355957 +64,64,8,0.9077119827270508 +64,128,8,0.927840042114258 +128,8,1,2.5166624069213865 +128,16,1,1.517516803741455 +128,32,8,1.0153599739074708 +128,64,16,1.0029184341430664 +128,128,8,0.9450048446655274 diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.default.csv b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.default.csv new file mode 100644 index 00000000..76905a84 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.default.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,1,2.397910308837891 +8,16,1,1.357868766784668 +8,32,8,0.8915360450744629 +8,64,8,0.9132736206054688 +8,128,8,0.9273440361022948 +16,8,1,2.307228851318359 +16,16,1,1.3961952209472657 +16,32,8,0.8938655853271484 +16,64,8,0.912384033203125 +16,128,8,0.9198207855224608 +32,8,1,2.3392576217651366 +32,16,1,1.3536288261413574 +32,32,8,0.9012639999389648 +32,64,8,0.9076224327087402 +32,128,8,0.9179776191711426 +64,8,1,2.295167922973633 +64,16,1,1.3171648025512694 +64,32,8,0.891209602355957 +64,64,8,0.9077119827270508 +64,128,8,0.927840042114258 +128,8,1,2.5166624069213865 +128,16,1,1.517516803741455 +128,32,8,1.0153599739074708 +128,64,16,1.0029184341430664 +128,128,8,0.9450048446655274 diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.70.csv b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.70.csv new file mode 100644 index 00000000..b0f6fc89 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.70.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,8,1.279923152923584 +8,16,8,0.99749755859375 
+8,32,1,1.0297887802124024
+8,64,1,1.0467743873596191
+8,128,1,1.0584704399108886
+16,8,8,1.2794048309326171
+16,16,8,1.0008064270019532
+16,32,1,1.0202848434448242
+16,64,1,1.0429951667785644
+16,128,1,1.0366432189941406
+32,8,8,1.2889599800109863
+32,16,8,1.001814365386963
+32,32,1,1.0289952278137209
+32,64,1,1.0377951622009278
+32,128,1,1.0327327728271485
+64,8,8,1.2794719696044925
+64,16,8,1.0031680107116698
+64,32,1,1.0151167869567872
+64,64,1,1.0283136367797852
+64,128,1,1.0416159629821775
+128,8,16,1.3982175827026366
+128,16,8,1.114236831665039
+128,32,1,1.099283218383789
+128,64,1,1.1017439842224122
+128,128,1,1.077948760986328
diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.default.csv b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.default.csv
new file mode 100644
index 00000000..b0f6fc89
--- /dev/null
+++ b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.default.csv
@@ -0,0 +1,26 @@
+BH,BW,RT,latency
+8,8,8,1.279923152923584
+8,16,8,0.99749755859375
+8,32,1,1.0297887802124024
+8,64,1,1.0467743873596191
+8,128,1,1.0584704399108886
+16,8,8,1.2794048309326171
+16,16,8,1.0008064270019532
+16,32,1,1.0202848434448242
+16,64,1,1.0429951667785644
+16,128,1,1.0366432189941406
+32,8,8,1.2889599800109863
+32,16,8,1.001814365386963
+32,32,1,1.0289952278137209
+32,64,1,1.0377951622009278
+32,128,1,1.0327327728271485
+64,8,8,1.2794719696044925
+64,16,8,1.0031680107116698
+64,32,1,1.0151167869567872
+64,64,1,1.0283136367797852
+64,128,1,1.0416159629821775
+128,8,16,1.3982175827026366
+128,16,8,1.114236831665039
+128,32,1,1.099283218383789
+128,64,1,1.1017439842224122
+128,128,1,1.077948760986328
From bd11b5f75e7b4cb99192935fff5aba9e2e48167d Mon Sep 17 00:00:00 2001
From: Chengruidong Zhang
Date: Mon, 6 Feb 2023 13:13:47 +0800
Subject: [PATCH 10/28] add CC 75 LUTs

---
 .../look_up_tables/matmul.sparta.75.csv | 1501 +++++++++++++++++
 .../softmax.backward.sparta.75.csv | 26 +
 .../softmax.forward.sparta.75.csv | 26 +
 3 files changed, 1553 insertions(+)
 create mode 100644 sparta/specializer/kernels/look_up_tables/matmul.sparta.75.csv
 create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.75.csv
 create mode 100644 sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.75.csv
diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.75.csv b/sparta/specializer/kernels/look_up_tables/matmul.sparta.75.csv
new file mode 100644
index 00000000..8f94c56c
--- /dev/null
+++ b/sparta/specializer/kernels/look_up_tables/matmul.sparta.75.csv
@@ -0,0 +1,1501 @@
+mode,trans_A,trans_B,BM,BK,BN,TM,TK,TN,latency
+dds,False,False,8,8,8,2,4,2,137.782421875
+dds,False,False,8,8,16,4,4,2,81.21327514648438
+dds,False,False,8,8,32,4,8,4,58.17118530273437
+dds,False,False,8,8,64,4,8,8,53.44810180664062
+dds,False,False,8,8,128,2,2,2,inf
+dds,False,False,8,16,8,2,2,2,127.66444091796876
+dds,False,False,8,16,16,4,4,2,75.37125854492187
+dds,False,False,8,16,32,8,2,2,57.64833374023438
+dds,False,False,8,16,64,8,2,2,56.29953002929688
+dds,False,False,8,16,128,4,2,8,52.68043823242188
+dds,False,False,8,32,8,4,2,2,128.90439453125
+dds,False,False,8,32,16,4,16,2,77.77259521484375
+dds,False,False,8,32,32,2,8,4,57.87215576171875
+dds,False,False,8,32,64,8,8,2,56.64256591796875
+dds,False,False,8,32,128,8,4,4,35.3204833984375
+dds,False,False,8,64,8,2,8,2,140.40911865234375
+dds,False,False,8,64,16,2,2,2,78.588525390625
+dds,False,False,8,64,32,4,4,2,59.27306518554688
+dds,False,False,8,64,64,8,2,2,40.53104553222656
+dds,False,False,8,64,128,4,4,2,39.970193481445314 +dds,False,False,8,128,8,2,2,2,inf +dds,False,False,8,128,16,2,16,2,78.0264892578125 +dds,False,False,8,128,32,4,16,2,59.5465576171875 +dds,False,False,8,128,64,2,2,2,57.46602783203125 +dds,False,False,8,128,128,2,2,2,inf +dds,False,False,16,8,8,4,2,2,77.00511474609375 +dds,False,False,16,8,16,4,8,2,43.67920532226562 +dds,False,False,16,8,32,4,8,4,29.630026245117183 +dds,False,False,16,8,64,8,8,4,25.52135620117188 +dds,False,False,16,8,128,16,4,4,22.91919708251953 +dds,False,False,16,16,8,2,2,2,72.24513549804688 +dds,False,False,16,16,16,4,8,2,39.28416442871094 +dds,False,False,16,16,32,8,4,2,29.208779907226564 +dds,False,False,16,16,64,16,2,2,25.984786987304688 +dds,False,False,16,16,128,16,16,4,27.671630859375 +dds,False,False,16,32,8,2,2,2,78.5121337890625 +dds,False,False,16,32,16,4,16,2,43.13347473144531 +dds,False,False,16,32,32,8,4,2,30.027035522460935 +dds,False,False,16,32,64,4,2,2,29.388629150390624 +dds,False,False,16,32,128,8,2,2,27.249371337890626 +dds,False,False,16,64,8,2,2,2,78.403759765625 +dds,False,False,16,64,16,4,16,2,42.969488525390624 +dds,False,False,16,64,32,4,4,2,30.39943542480469 +dds,False,False,16,64,64,4,16,2,29.69788513183594 +dds,False,False,16,64,128,4,16,4,24.464175415039065 +dds,False,False,16,128,8,2,2,2,79.0888916015625 +dds,False,False,16,128,16,4,2,2,42.49003295898437 +dds,False,False,16,128,32,4,8,2,32.81447143554688 +dds,False,False,16,128,64,4,2,2,30.41606750488281 +dds,False,False,16,128,128,2,2,2,inf +dds,False,False,32,8,8,8,2,2,52.80548706054688 +dds,False,False,32,8,16,4,2,4,31.57626953125 +dds,False,False,32,8,32,8,8,4,19.97953948974609 +dds,False,False,32,8,64,16,4,4,15.821974182128908 +dds,False,False,32,8,128,2,2,2,inf +dds,False,False,32,16,8,4,2,2,51.09674377441407 +dds,False,False,32,16,16,8,4,2,27.853668212890625 +dds,False,False,32,16,32,8,16,4,18.36992645263672 +dds,False,False,32,16,64,8,8,8,15.766653442382813 +dds,False,False,32,16,128,2,2,2,inf +dds,False,False,32,32,8,4,2,2,59.74097900390625 +dds,False,False,32,32,16,8,4,2,32.705108642578125 +dds,False,False,32,32,32,16,16,2,19.0359619140625 +dds,False,False,32,32,64,8,2,4,16.518963623046876 +dds,False,False,32,32,128,2,2,2,inf +dds,False,False,32,64,8,4,2,2,60.2196533203125 +dds,False,False,32,64,16,8,2,2,32.974847412109376 +dds,False,False,32,64,32,16,4,2,20.70876159667969 +dds,False,False,32,64,64,4,16,4,19.525535583496094 +dds,False,False,32,64,128,2,2,2,inf +dds,False,False,32,128,8,4,16,2,60.68859252929688 +dds,False,False,32,128,16,4,2,2,37.68844909667969 +dds,False,False,32,128,32,4,16,2,25.16766662597656 +dds,False,False,32,128,64,8,4,2,21.745401000976564 +dds,False,False,32,128,128,2,2,2,inf +dds,False,False,64,8,8,16,2,2,45.05986328125 +dds,False,False,64,8,16,8,2,4,25.698541259765623 +dds,False,False,64,8,32,8,4,8,17.289830017089844 +dds,False,False,64,8,64,2,2,2,inf +dds,False,False,64,8,128,2,2,2,inf +dds,False,False,64,16,8,8,2,2,41.67466735839844 +dds,False,False,64,16,16,16,16,2,22.591337585449217 +dds,False,False,64,16,32,16,8,4,15.5552001953125 +dds,False,False,64,16,64,2,2,2,inf +dds,False,False,64,16,128,2,2,2,inf +dds,False,False,64,32,8,8,2,2,52.57105712890625 +dds,False,False,64,32,16,16,16,2,27.855935668945317 +dds,False,False,64,32,32,16,2,2,17.336764526367187 +dds,False,False,64,32,64,2,2,2,inf +dds,False,False,64,32,128,2,2,2,inf +dds,False,False,64,64,8,8,16,2,51.7627197265625 +dds,False,False,64,64,16,16,8,2,29.431808471679688 +dds,False,False,64,64,32,8,4,2,19.314483642578125 
+dds,False,False,64,64,64,2,2,2,inf +dds,False,False,64,64,128,2,2,2,inf +dds,False,False,64,128,8,2,2,2,73.40010375976563 +dds,False,False,64,128,16,4,2,2,38.94587097167969 +dds,False,False,64,128,32,8,4,2,22.064947509765624 +dds,False,False,64,128,64,2,2,2,inf +dds,False,False,64,128,128,2,2,2,inf +dds,False,False,128,8,8,16,2,4,41.38941650390625 +dds,False,False,128,8,16,16,2,4,22.06685485839844 +dds,False,False,128,8,32,2,2,2,inf +dds,False,False,128,8,64,2,2,2,inf +dds,False,False,128,8,128,2,2,2,inf +dds,False,False,128,16,8,16,2,2,37.3514892578125 +dds,False,False,128,16,16,16,8,4,21.61268768310547 +dds,False,False,128,16,32,2,2,2,inf +dds,False,False,128,16,64,2,2,2,inf +dds,False,False,128,16,128,2,2,2,inf +dds,False,False,128,32,8,16,16,2,48.12799987792969 +dds,False,False,128,32,16,16,2,2,27.00267639160156 +dds,False,False,128,32,32,2,2,2,inf +dds,False,False,128,32,64,2,2,2,inf +dds,False,False,128,32,128,2,2,2,inf +dds,False,False,128,64,8,4,16,4,60.2175048828125 +dds,False,False,128,64,16,8,2,2,32.68710021972656 +dds,False,False,128,64,32,2,2,2,inf +dds,False,False,128,64,64,2,2,2,inf +dds,False,False,128,64,128,2,2,2,inf +dds,False,False,128,128,8,2,2,2,inf +dds,False,False,128,128,16,2,2,2,inf +dds,False,False,128,128,32,2,2,2,inf +dds,False,False,128,128,64,2,2,2,inf +dds,False,False,128,128,128,2,2,2,inf +dds,False,True,8,8,8,2,2,2,116.650390625 +dds,False,True,8,8,16,4,8,2,96.10188598632811 +dds,False,True,8,8,32,2,4,8,99.15728759765624 +dds,False,True,8,8,64,2,4,16,141.58489990234375 +dds,False,True,8,8,128,8,2,16,194.570849609375 +dds,False,True,8,16,8,2,2,2,119.985302734375 +dds,False,True,8,16,16,2,2,2,90.74093627929688 +dds,False,True,8,16,32,8,16,2,83.85215454101562 +dds,False,True,8,16,64,8,16,4,80.86709594726562 +dds,False,True,8,16,128,2,4,16,79.41593017578126 +dds,False,True,8,32,8,2,2,2,155.47176513671874 +dds,False,True,8,32,16,2,4,2,83.863671875 +dds,False,True,8,32,32,4,8,2,64.37730102539062 +dds,False,True,8,32,64,4,8,4,59.86529541015625 +dds,False,True,8,32,128,8,4,4,60.0380615234375 +dds,False,True,8,64,8,2,2,2,220.03232421875 +dds,False,True,8,64,16,2,2,2,113.5142578125 +dds,False,True,8,64,32,4,8,2,95.73397827148438 +dds,False,True,8,64,64,8,4,2,92.929345703125 +dds,False,True,8,64,128,4,8,2,95.42840576171876 +dds,False,True,8,128,8,2,2,2,inf +dds,False,True,8,128,16,2,4,2,180.6630615234375 +dds,False,True,8,128,32,4,16,2,165.9755615234375 +dds,False,True,8,128,64,4,8,2,171.0952392578125 +dds,False,True,8,128,128,2,2,2,inf +dds,False,True,16,8,8,4,2,2,74.5037109375 +dds,False,True,16,8,16,4,8,2,50.4748046875 +dds,False,True,16,8,32,2,2,8,50.283041381835936 +dds,False,True,16,8,64,2,4,16,57.63734741210938 +dds,False,True,16,8,128,8,4,16,81.13134155273437 +dds,False,True,16,16,8,2,2,2,66.37405395507812 +dds,False,True,16,16,16,4,16,2,46.44863586425781 +dds,False,True,16,16,32,4,4,4,42.01728515625 +dds,False,True,16,16,64,4,8,8,35.7154541015625 +dds,False,True,16,16,128,4,2,16,40.30254821777344 +dds,False,True,16,32,8,2,2,2,82.6472412109375 +dds,False,True,16,32,16,4,8,2,53.80067138671875 +dds,False,True,16,32,32,8,8,2,38.530355834960936 +dds,False,True,16,32,64,8,2,4,34.18015747070312 +dds,False,True,16,32,128,8,8,4,33.76231384277344 +dds,False,True,16,64,8,2,2,2,96.83884887695312 +dds,False,True,16,64,16,4,2,2,68.49608154296875 +dds,False,True,16,64,32,8,4,2,55.22697143554687 +dds,False,True,16,64,64,8,2,2,53.000177001953126 +dds,False,True,16,64,128,8,16,2,52.92344360351562 +dds,False,True,16,128,8,2,4,2,130.4015380859375 
+dds,False,True,16,128,16,4,16,2,105.049951171875 +dds,False,True,16,128,32,4,16,2,92.82927856445312 +dds,False,True,16,128,64,4,2,2,92.16082153320312 +dds,False,True,16,128,128,2,2,2,inf +dds,False,True,32,8,8,8,2,2,53.98472290039062 +dds,False,True,32,8,16,8,2,2,34.580419921875 +dds,False,True,32,8,32,8,8,4,27.266140747070317 +dds,False,True,32,8,64,4,2,16,26.711062622070312 +dds,False,True,32,8,128,2,2,2,inf +dds,False,True,32,16,8,4,2,2,50.495458984375 +dds,False,True,32,16,16,8,2,2,30.542901611328126 +dds,False,True,32,16,32,8,8,4,23.12939453125 +dds,False,True,32,16,64,16,8,4,21.41405792236328 +dds,False,True,32,16,128,2,2,2,inf +dds,False,True,32,32,8,4,2,2,62.9032958984375 +dds,False,True,32,32,16,8,2,2,38.04856262207032 +dds,False,True,32,32,32,16,4,2,25.37069091796875 +dds,False,True,32,32,64,8,2,4,22.618585205078126 +dds,False,True,32,32,128,2,2,2,inf +dds,False,True,32,64,8,4,8,2,70.26052856445312 +dds,False,True,32,64,16,8,8,2,46.60901794433594 +dds,False,True,32,64,32,8,2,2,35.37620239257812 +dds,False,True,32,64,64,8,8,2,33.871795654296875 +dds,False,True,32,64,128,2,2,2,inf +dds,False,True,32,128,8,2,2,2,94.4341796875 +dds,False,True,32,128,16,4,8,2,67.17088623046875 +dds,False,True,32,128,32,8,2,2,52.83756713867187 +dds,False,True,32,128,64,4,16,4,51.19150085449219 +dds,False,True,32,128,128,2,2,2,inf +dds,False,True,64,8,8,16,2,2,45.628598022460935 +dds,False,True,64,8,16,8,8,4,27.13495178222656 +dds,False,True,64,8,32,16,4,4,19.34381103515625 +dds,False,True,64,8,64,2,2,2,inf +dds,False,True,64,8,128,2,2,2,inf +dds,False,True,64,16,8,8,2,2,41.39543762207031 +dds,False,True,64,16,16,16,8,2,24.13658905029297 +dds,False,True,64,16,32,16,16,4,17.45899200439453 +dds,False,True,64,16,64,2,2,2,inf +dds,False,True,64,16,128,2,2,2,inf +dds,False,True,64,32,8,8,2,2,53.82050170898437 +dds,False,True,64,32,16,16,16,2,30.43846740722656 +dds,False,True,64,32,32,16,8,2,20.72993927001953 +dds,False,True,64,32,64,2,2,2,inf +dds,False,True,64,32,128,2,2,2,inf +dds,False,True,64,64,8,4,8,2,61.343194580078126 +dds,False,True,64,64,16,16,4,2,36.582061767578125 +dds,False,True,64,64,32,8,8,2,26.40849914550781 +dds,False,True,64,64,64,2,2,2,inf +dds,False,True,64,64,128,2,2,2,inf +dds,False,True,64,128,8,2,16,2,89.57435302734375 +dds,False,True,64,128,16,4,2,2,55.05433349609375 +dds,False,True,64,128,32,8,16,2,37.877691650390624 +dds,False,True,64,128,64,2,2,2,inf +dds,False,True,64,128,128,2,2,2,inf +dds,False,True,128,8,8,16,2,4,41.845391845703126 +dds,False,True,128,8,16,16,2,4,22.87071990966797 +dds,False,True,128,8,32,2,2,2,inf +dds,False,True,128,8,64,2,2,2,inf +dds,False,True,128,8,128,2,2,2,inf +dds,False,True,128,16,8,16,2,2,37.15332946777344 +dds,False,True,128,16,16,16,4,4,22.21670379638672 +dds,False,True,128,16,32,2,2,2,inf +dds,False,True,128,16,64,2,2,2,inf +dds,False,True,128,16,128,2,2,2,inf +dds,False,True,128,32,8,16,8,2,48.627114868164064 +dds,False,True,128,32,16,16,2,2,28.17583923339844 +dds,False,True,128,32,32,2,2,2,inf +dds,False,True,128,32,64,2,2,2,inf +dds,False,True,128,32,128,2,2,2,inf +dds,False,True,128,64,8,4,2,2,65.926904296875 +dds,False,True,128,64,16,8,2,2,36.267822265625 +dds,False,True,128,64,32,2,2,2,inf +dds,False,True,128,64,64,2,2,2,inf +dds,False,True,128,64,128,2,2,2,inf +dds,False,True,128,128,8,2,2,2,inf +dds,False,True,128,128,16,2,2,2,inf +dds,False,True,128,128,32,2,2,2,inf +dds,False,True,128,128,64,2,2,2,inf +dds,False,True,128,128,128,2,2,2,inf +dds,True,False,8,8,8,2,4,2,141.81129150390626 +dds,True,False,8,8,16,4,2,2,82.14745483398437 
+dds,True,False,8,8,32,8,4,2,58.6197021484375 +dds,True,False,8,8,64,4,4,8,53.07578125 +dds,True,False,8,8,128,2,2,2,inf +dds,True,False,8,16,8,2,4,2,136.486279296875 +dds,True,False,8,16,16,2,16,2,77.42894287109375 +dds,True,False,8,16,32,8,2,2,58.02311401367187 +dds,True,False,8,16,64,8,2,2,56.43548583984375 +dds,True,False,8,16,128,4,2,8,53.762457275390624 +dds,True,False,8,32,8,2,2,2,142.93323974609376 +dds,True,False,8,32,16,2,2,2,82.658251953125 +dds,True,False,8,32,32,4,2,2,58.54985961914063 +dds,True,False,8,32,64,4,8,4,57.65079956054687 +dds,True,False,8,32,128,4,4,8,39.97030029296875 +dds,True,False,8,64,8,2,2,2,148.51365966796874 +dds,True,False,8,64,16,2,8,2,83.1766357421875 +dds,True,False,8,64,32,4,4,2,59.449554443359375 +dds,True,False,8,64,64,8,16,2,56.786688232421874 +dds,True,False,8,64,128,4,4,2,40.5866455078125 +dds,True,False,8,128,8,2,2,2,144.783544921875 +dds,True,False,8,128,16,2,16,2,81.474560546875 +dds,True,False,8,128,32,2,2,2,59.823046875 +dds,True,False,8,128,64,2,2,2,55.313385009765625 +dds,True,False,8,128,128,2,2,2,inf +dds,True,False,16,8,8,4,2,2,89.65245361328125 +dds,True,False,16,8,16,4,2,2,44.886016845703125 +dds,True,False,16,8,32,8,2,2,29.844512939453125 +dds,True,False,16,8,64,8,8,4,25.365597534179688 +dds,True,False,16,8,128,16,4,4,21.949267578125 +dds,True,False,16,16,8,2,2,2,78.07701416015625 +dds,True,False,16,16,16,4,8,2,42.8124755859375 +dds,True,False,16,16,32,8,16,2,29.851785278320317 +dds,True,False,16,16,64,8,16,4,25.923660278320312 +dds,True,False,16,16,128,16,8,4,28.110232543945312 +dds,True,False,16,32,8,2,2,2,92.39326171875 +dds,True,False,16,32,16,4,16,2,50.35707702636719 +dds,True,False,16,32,32,8,4,2,30.1670166015625 +dds,True,False,16,32,64,16,2,2,29.626019287109376 +dds,True,False,16,32,128,8,2,2,26.2468994140625 +dds,True,False,16,64,8,2,2,2,93.433447265625 +dds,True,False,16,64,16,4,8,2,49.92449951171875 +dds,True,False,16,64,32,8,4,2,31.033657836914063 +dds,True,False,16,64,64,4,16,2,30.16905822753906 +dds,True,False,16,64,128,4,16,4,25.200038146972656 +dds,True,False,16,128,8,2,2,2,92.0052734375 +dds,True,False,16,128,16,4,8,2,49.56792297363281 +dds,True,False,16,128,32,4,16,2,34.03858337402344 +dds,True,False,16,128,64,4,2,2,31.314492797851564 +dds,True,False,16,128,128,2,2,2,inf +dds,True,False,32,8,8,8,2,2,99.0732177734375 +dds,True,False,32,8,16,8,8,2,39.013021850585936 +dds,True,False,32,8,32,8,2,4,23.90037384033203 +dds,True,False,32,8,64,16,4,4,17.611180114746094 +dds,True,False,32,8,128,2,2,2,inf +dds,True,False,32,16,8,4,2,2,73.20350952148438 +dds,True,False,32,16,16,8,16,2,38.88066711425781 +dds,True,False,32,16,32,8,8,4,23.941529846191408 +dds,True,False,32,16,64,8,4,8,17.943696594238283 +dds,True,False,32,16,128,2,2,2,inf +dds,True,False,32,32,8,4,8,2,88.30719604492188 +dds,True,False,32,32,16,8,2,2,46.75213623046875 +dds,True,False,32,32,32,16,2,2,25.38228454589844 +dds,True,False,32,32,64,8,4,4,19.16673583984375 +dds,True,False,32,32,128,2,2,2,inf +dds,True,False,32,64,8,4,2,2,87.69459838867188 +dds,True,False,32,64,16,8,8,2,46.815863037109374 +dds,True,False,32,64,32,8,8,4,27.17962951660156 +dds,True,False,32,64,64,8,16,2,22.669293212890626 +dds,True,False,32,64,128,2,2,2,inf +dds,True,False,32,128,8,4,2,2,87.91981201171875 +dds,True,False,32,128,16,4,4,2,50.69249267578125 +dds,True,False,32,128,32,8,4,2,30.337432861328125 +dds,True,False,32,128,64,8,2,2,23.933750915527344 +dds,True,False,32,128,128,2,2,2,inf +dds,True,False,64,8,8,16,4,2,157.3812744140625 +dds,True,False,64,8,16,8,2,4,48.66491088867188 
+dds,True,False,64,8,32,8,4,8,28.486843872070317 +dds,True,False,64,8,64,2,2,2,inf +dds,True,False,64,8,128,2,2,2,inf +dds,True,False,64,16,8,8,8,2,97.76129760742188 +dds,True,False,64,16,16,16,16,2,50.456939697265625 +dds,True,False,64,16,32,16,8,4,28.66033935546875 +dds,True,False,64,16,64,2,2,2,inf +dds,True,False,64,16,128,2,2,2,inf +dds,True,False,64,32,8,8,4,2,114.5839111328125 +dds,True,False,64,32,16,16,16,2,58.21708984375 +dds,True,False,64,32,32,16,8,2,31.95519104003906 +dds,True,False,64,32,64,2,2,2,inf +dds,True,False,64,32,128,2,2,2,inf +dds,True,False,64,64,8,8,2,2,115.44453125 +dds,True,False,64,64,16,16,8,2,59.109521484375 +dds,True,False,64,64,32,8,8,2,33.661102294921875 +dds,True,False,64,64,64,2,2,2,inf +dds,True,False,64,64,128,2,2,2,inf +dds,True,False,64,128,8,4,16,2,119.41300048828126 +dds,True,False,64,128,16,8,4,2,63.08079833984375 +dds,True,False,64,128,32,8,8,2,33.90361633300781 +dds,True,False,64,128,64,2,2,2,inf +dds,True,False,64,128,128,2,2,2,inf +dds,True,False,128,8,8,2,2,2,inf +dds,True,False,128,8,16,16,2,4,78.7171142578125 +dds,True,False,128,8,32,2,2,2,inf +dds,True,False,128,8,64,2,2,2,inf +dds,True,False,128,8,128,2,2,2,inf +dds,True,False,128,16,8,16,2,2,162.1299072265625 +dds,True,False,128,16,16,16,8,2,82.79563598632812 +dds,True,False,128,16,32,2,2,2,inf +dds,True,False,128,16,64,2,2,2,inf +dds,True,False,128,16,128,2,2,2,inf +dds,True,False,128,32,8,16,8,2,180.652783203125 +dds,True,False,128,32,16,16,16,2,91.19170532226562 +dds,True,False,128,32,32,2,2,2,inf +dds,True,False,128,32,64,2,2,2,inf +dds,True,False,128,32,128,2,2,2,inf +dds,True,False,128,64,8,8,16,2,182.3877075195313 +dds,True,False,128,64,16,8,4,2,95.69812622070312 +dds,True,False,128,64,32,2,2,2,inf +dds,True,False,128,64,64,2,2,2,inf +dds,True,False,128,64,128,2,2,2,inf +dds,True,False,128,128,8,2,2,2,inf +dds,True,False,128,128,16,2,2,2,inf +dds,True,False,128,128,32,2,2,2,inf +dds,True,False,128,128,64,2,2,2,inf +dds,True,False,128,128,128,2,2,2,inf +dds,True,True,8,8,8,2,2,2,127.7572998046875 +dds,True,True,8,8,16,4,2,2,98.1478515625 +dds,True,True,8,8,32,2,2,8,105.85396728515624 +dds,True,True,8,8,64,4,8,16,145.3745849609375 +dds,True,True,8,8,128,8,2,16,185.062158203125 +dds,True,True,8,16,8,2,2,2,139.81124267578124 +dds,True,True,8,16,16,2,16,2,91.10975341796876 +dds,True,True,8,16,32,8,8,2,85.40437622070313 +dds,True,True,8,16,64,4,16,8,83.19768676757812 +dds,True,True,8,16,128,2,16,16,81.06796264648438 +dds,True,True,8,32,8,2,2,2,179.30546875 +dds,True,True,8,32,16,2,4,2,91.29049682617188 +dds,True,True,8,32,32,4,16,2,68.43511352539062 +dds,True,True,8,32,64,4,8,4,60.7952392578125 +dds,True,True,8,32,128,8,4,4,60.4896484375 +dds,True,True,8,64,8,2,16,2,244.8777587890625 +dds,True,True,8,64,16,2,2,2,122.7343505859375 +dds,True,True,8,64,32,4,8,2,99.85744018554688 +dds,True,True,8,64,64,4,4,4,95.29630126953126 +dds,True,True,8,64,128,4,16,2,95.92606811523436 +dds,True,True,8,128,8,2,2,2,inf +dds,True,True,8,128,16,2,2,2,189.3578125 +dds,True,True,8,128,32,4,16,2,169.03778076171875 +dds,True,True,8,128,64,4,16,2,172.27548828125 +dds,True,True,8,128,128,2,2,2,inf +dds,True,True,16,8,8,4,2,2,94.51527099609376 +dds,True,True,16,8,16,4,4,2,51.60560913085938 +dds,True,True,16,8,32,8,4,4,53.02169799804688 +dds,True,True,16,8,64,4,8,16,55.79038696289062 +dds,True,True,16,8,128,8,4,16,86.81287841796875 +dds,True,True,16,16,8,2,2,2,77.85866088867188 +dds,True,True,16,16,16,4,8,2,49.15158996582032 +dds,True,True,16,16,32,8,8,2,42.97212524414063 
+dds,True,True,16,16,64,16,16,2,37.35468139648437 +dds,True,True,16,16,128,4,8,16,40.749514770507815 +dds,True,True,16,32,8,2,2,2,98.41561279296874 +dds,True,True,16,32,16,4,2,2,61.32503662109375 +dds,True,True,16,32,32,8,4,2,42.562939453125 +dds,True,True,16,32,64,8,2,4,35.741265869140626 +dds,True,True,16,32,128,8,2,4,34.263336181640625 +dds,True,True,16,64,8,2,4,2,114.30032958984376 +dds,True,True,16,64,16,4,8,2,76.36978759765626 +dds,True,True,16,64,32,8,2,2,58.7079345703125 +dds,True,True,16,64,64,8,2,2,54.8456787109375 +dds,True,True,16,64,128,8,8,2,53.772900390625 +dds,True,True,16,128,8,2,2,2,146.853857421875 +dds,True,True,16,128,16,4,4,2,110.31800537109376 +dds,True,True,16,128,32,4,16,2,96.67322387695312 +dds,True,True,16,128,64,4,2,2,93.23410034179688 +dds,True,True,16,128,128,2,2,2,inf +dds,True,True,32,8,8,8,2,2,101.45195922851562 +dds,True,True,32,8,16,8,4,2,42.46036376953125 +dds,True,True,32,8,32,8,2,4,28.42144470214844 +dds,True,True,32,8,64,4,8,16,25.17054748535156 +dds,True,True,32,8,128,2,2,2,inf +dds,True,True,32,16,8,4,2,2,73.46295776367188 +dds,True,True,32,16,16,8,2,2,42.08187561035156 +dds,True,True,32,16,32,8,16,4,28.084405517578126 +dds,True,True,32,16,64,16,16,4,22.42252807617188 +dds,True,True,32,16,128,2,2,2,inf +dds,True,True,32,32,8,4,2,2,91.77128295898436 +dds,True,True,32,32,16,8,2,2,52.41979370117188 +dds,True,True,32,32,32,16,16,2,32.20746765136719 +dds,True,True,32,32,64,8,4,4,25.96764221191406 +dds,True,True,32,32,128,2,2,2,inf +dds,True,True,32,64,8,4,2,2,99.95119018554688 +dds,True,True,32,64,16,8,4,2,60.39188842773437 +dds,True,True,32,64,32,16,2,2,40.84258728027344 +dds,True,True,32,64,64,8,4,4,36.96003723144531 +dds,True,True,32,64,128,2,2,2,inf +dds,True,True,32,128,8,4,4,2,116.718896484375 +dds,True,True,32,128,16,4,4,2,81.38380737304688 +dds,True,True,32,128,32,8,16,2,60.08902587890625 +dds,True,True,32,128,64,4,16,4,54.5151123046875 +dds,True,True,32,128,128,2,2,2,inf +dds,True,True,64,8,8,16,2,2,159.270458984375 +dds,True,True,64,8,16,8,4,4,50.46022338867188 +dds,True,True,64,8,32,16,4,4,31.01878051757813 +dds,True,True,64,8,64,2,2,2,inf +dds,True,True,64,8,128,2,2,2,inf +dds,True,True,64,16,8,8,2,2,98.39224243164062 +dds,True,True,64,16,16,16,4,2,51.91385498046875 +dds,True,True,64,16,32,16,8,4,31.189453125 +dds,True,True,64,16,64,2,2,2,inf +dds,True,True,64,16,128,2,2,2,inf +dds,True,True,64,32,8,8,4,2,116.67642822265626 +dds,True,True,64,32,16,16,8,2,61.442633056640624 +dds,True,True,64,32,32,16,8,2,35.341748046875 +dds,True,True,64,32,64,2,2,2,inf +dds,True,True,64,32,128,2,2,2,inf +dds,True,True,64,64,8,8,4,2,122.48782958984376 +dds,True,True,64,64,16,8,2,2,67.2903076171875 +dds,True,True,64,64,32,8,16,2,40.63457336425781 +dds,True,True,64,64,64,2,2,2,inf +dds,True,True,64,64,128,2,2,2,inf +dds,True,True,64,128,8,4,8,2,135.30081787109376 +dds,True,True,64,128,16,4,8,2,82.08575439453125 +dds,True,True,64,128,32,8,4,2,53.673974609375 +dds,True,True,64,128,64,2,2,2,inf +dds,True,True,64,128,128,2,2,2,inf +dds,True,True,128,8,8,2,2,2,inf +dds,True,True,128,8,16,16,4,4,80.01533203125 +dds,True,True,128,8,32,2,2,2,inf +dds,True,True,128,8,64,2,2,2,inf +dds,True,True,128,8,128,2,2,2,inf +dds,True,True,128,16,8,16,2,2,162.4451416015625 +dds,True,True,128,16,16,16,4,2,83.57948608398438 +dds,True,True,128,16,32,2,2,2,inf +dds,True,True,128,16,64,2,2,2,inf +dds,True,True,128,16,128,2,2,2,inf +dds,True,True,128,32,8,8,4,2,182.0763916015625 +dds,True,True,128,32,16,16,2,2,92.733056640625 +dds,True,True,128,32,32,2,2,2,inf 
+dds,True,True,128,32,64,2,2,2,inf +dds,True,True,128,32,128,2,2,2,inf +dds,True,True,128,64,8,8,16,2,186.8861083984375 +dds,True,True,128,64,16,8,2,2,99.483642578125 +dds,True,True,128,64,32,2,2,2,inf +dds,True,True,128,64,64,2,2,2,inf +dds,True,True,128,64,128,2,2,2,inf +dds,True,True,128,128,8,2,2,2,inf +dds,True,True,128,128,16,2,2,2,inf +dds,True,True,128,128,32,2,2,2,inf +dds,True,True,128,128,64,2,2,2,inf +dds,True,True,128,128,128,2,2,2,inf +dsd,False,False,8,8,8,2,2,2,93.71647338867189 +dsd,False,False,8,8,16,4,2,2,64.03095092773438 +dsd,False,False,8,8,32,4,2,4,58.64285888671875 +dsd,False,False,8,8,64,4,4,8,56.20660400390625 +dsd,False,False,8,8,128,2,2,2,inf +dsd,False,False,8,16,8,2,2,2,87.55302124023437 +dsd,False,False,8,16,16,2,8,2,61.10001831054687 +dsd,False,False,8,16,32,4,16,4,58.4094970703125 +dsd,False,False,8,16,64,2,8,8,58.043218994140624 +dsd,False,False,8,16,128,4,4,8,55.48840942382813 +dsd,False,False,8,32,8,2,16,2,94.81031494140623 +dsd,False,False,8,32,16,2,4,2,60.44342041015625 +dsd,False,False,8,32,32,2,4,4,57.83162231445313 +dsd,False,False,8,32,64,4,8,4,57.91334228515625 +dsd,False,False,8,32,128,2,2,2,inf +dsd,False,False,8,64,8,2,16,2,92.61300659179688 +dsd,False,False,8,64,16,2,16,2,60.11107177734375 +dsd,False,False,8,64,32,4,2,2,59.4615478515625 +dsd,False,False,8,64,64,2,2,2,inf +dsd,False,False,8,64,128,2,2,2,inf +dsd,False,False,8,128,8,2,2,2,inf +dsd,False,False,8,128,16,2,2,2,60.06309204101562 +dsd,False,False,8,128,32,2,2,2,inf +dsd,False,False,8,128,64,2,2,2,inf +dsd,False,False,8,128,128,2,2,2,inf +dsd,False,False,16,8,8,4,2,2,63.24615478515625 +dsd,False,False,16,8,16,4,4,2,39.68951416015625 +dsd,False,False,16,8,32,4,8,4,30.354635620117183 +dsd,False,False,16,8,64,4,2,8,27.33725891113281 +dsd,False,False,16,8,128,16,4,4,24.254103088378905 +dsd,False,False,16,16,8,2,8,2,58.42618408203125 +dsd,False,False,16,16,16,4,16,2,34.07795715332031 +dsd,False,False,16,16,32,4,4,4,29.76649169921875 +dsd,False,False,16,16,64,16,16,2,27.015093994140624 +dsd,False,False,16,16,128,16,8,4,27.910430908203125 +dsd,False,False,16,32,8,2,4,2,68.20228881835938 +dsd,False,False,16,32,16,4,2,2,38.896176147460935 +dsd,False,False,16,32,32,8,16,2,30.33519592285156 +dsd,False,False,16,32,64,16,4,2,29.708221435546875 +dsd,False,False,16,32,128,2,2,2,inf +dsd,False,False,16,64,8,2,2,2,66.80291137695312 +dsd,False,False,16,64,16,4,16,2,38.57438659667969 +dsd,False,False,16,64,32,4,2,2,30.53526916503906 +dsd,False,False,16,64,64,2,2,2,inf +dsd,False,False,16,64,128,2,2,2,inf +dsd,False,False,16,128,8,2,2,2,65.34825439453125 +dsd,False,False,16,128,16,4,4,2,37.8095458984375 +dsd,False,False,16,128,32,2,2,2,inf +dsd,False,False,16,128,64,2,2,2,inf +dsd,False,False,16,128,128,2,2,2,inf +dsd,False,False,32,8,8,8,2,2,47.86884155273437 +dsd,False,False,32,8,16,4,2,4,30.295150756835938 +dsd,False,False,32,8,32,8,8,4,19.929270935058597 +dsd,False,False,32,8,64,16,2,4,15.657574462890626 +dsd,False,False,32,8,128,16,4,8,14.73047332763672 +dsd,False,False,32,16,8,4,2,2,46.04456176757813 +dsd,False,False,32,16,16,8,16,2,26.240228271484376 +dsd,False,False,32,16,32,8,4,4,18.08135070800781 +dsd,False,False,32,16,64,16,8,4,15.81854705810547 +dsd,False,False,32,16,128,8,16,8,15.232643127441406 +dsd,False,False,32,32,8,4,4,2,55.965277099609374 +dsd,False,False,32,32,16,8,16,2,31.10966796875 +dsd,False,False,32,32,32,16,8,2,18.982470703125 +dsd,False,False,32,32,64,8,4,4,16.208685302734374 +dsd,False,False,32,32,128,2,2,2,inf +dsd,False,False,32,64,8,4,4,2,55.36128540039063 
+dsd,False,False,32,64,16,8,8,2,31.33301086425781 +dsd,False,False,32,64,32,8,8,2,20.929423522949214 +dsd,False,False,32,64,64,2,2,2,inf +dsd,False,False,32,64,128,2,2,2,inf +dsd,False,False,32,128,8,4,4,2,54.879913330078125 +dsd,False,False,32,128,16,4,4,2,36.07020263671875 +dsd,False,False,32,128,32,2,2,2,inf +dsd,False,False,32,128,64,2,2,2,inf +dsd,False,False,32,128,128,2,2,2,inf +dsd,False,False,64,8,8,16,2,2,42.93273315429688 +dsd,False,False,64,8,16,8,8,4,24.822169494628906 +dsd,False,False,64,8,32,16,8,4,16.78294982910156 +dsd,False,False,64,8,64,16,8,8,13.175990295410156 +dsd,False,False,64,8,128,16,2,8,12.944717407226562 +dsd,False,False,64,16,8,8,2,2,39.867437744140624 +dsd,False,False,64,16,16,16,8,2,21.873036193847657 +dsd,False,False,64,16,32,16,16,4,15.200108337402344 +dsd,False,False,64,16,64,8,8,8,13.46641845703125 +dsd,False,False,64,16,128,16,16,4,12.786656188964844 +dsd,False,False,64,32,8,8,8,2,50.62677612304688 +dsd,False,False,64,32,16,16,16,2,27.102239990234374 +dsd,False,False,64,32,32,16,16,2,17.164028930664063 +dsd,False,False,64,32,64,8,16,4,14.032992553710937 +dsd,False,False,64,32,128,2,2,2,inf +dsd,False,False,64,64,8,8,8,2,50.184375 +dsd,False,False,64,64,16,16,16,2,29.096368408203126 +dsd,False,False,64,64,32,8,16,2,19.208355712890626 +dsd,False,False,64,64,64,2,2,2,inf +dsd,False,False,64,64,128,2,2,2,inf +dsd,False,False,64,128,8,4,4,2,67.56678466796875 +dsd,False,False,64,128,16,8,16,2,35.93500671386719 +dsd,False,False,64,128,32,2,2,2,inf +dsd,False,False,64,128,64,2,2,2,inf +dsd,False,False,64,128,128,2,2,2,inf +dsd,False,False,128,8,8,16,2,4,40.03868103027344 +dsd,False,False,128,8,16,16,2,4,21.720985412597656 +dsd,False,False,128,8,32,16,8,8,15.086416625976565 +dsd,False,False,128,8,64,16,8,8,13.1880859375 +dsd,False,False,128,8,128,16,8,8,12.556444549560547 +dsd,False,False,128,16,8,16,2,2,36.49515380859375 +dsd,False,False,128,16,16,16,8,4,20.951942443847656 +dsd,False,False,128,16,32,16,8,4,14.756060791015624 +dsd,False,False,128,16,64,16,8,4,13.128387451171877 +dsd,False,False,128,16,128,16,8,8,12.023948669433594 +dsd,False,False,128,32,8,16,16,2,48.86194763183594 +dsd,False,False,128,32,16,16,16,2,26.5552001953125 +dsd,False,False,128,32,32,16,16,2,16.184320068359376 +dsd,False,False,128,32,64,16,8,4,12.830514526367187 +dsd,False,False,128,32,128,2,2,2,inf +dsd,False,False,128,64,8,4,8,2,58.8990478515625 +dsd,False,False,128,64,16,8,2,2,31.26006774902344 +dsd,False,False,128,64,32,16,2,2,18.415318298339844 +dsd,False,False,128,64,64,2,2,2,inf +dsd,False,False,128,64,128,2,2,2,inf +dsd,False,False,128,128,8,2,2,2,inf +dsd,False,False,128,128,16,2,2,2,inf +dsd,False,False,128,128,32,2,2,2,inf +dsd,False,False,128,128,64,2,2,2,inf +dsd,False,False,128,128,128,2,2,2,inf +dsd,False,True,8,8,8,2,2,2,106.71483154296877 +dsd,False,True,8,8,16,4,2,2,73.85866088867188 +dsd,False,True,8,8,32,8,4,2,61.3180908203125 +dsd,False,True,8,8,64,8,4,4,59.63038940429688 +dsd,False,True,8,8,128,4,8,16,57.55181884765625 +dsd,False,True,8,16,8,2,2,2,116.87203369140624 +dsd,False,True,8,16,16,2,2,2,64.51404418945313 +dsd,False,True,8,16,32,4,8,2,61.75587158203125 +dsd,False,True,8,16,64,2,8,8,62.07415771484375 +dsd,False,True,8,16,128,8,16,4,59.1298583984375 +dsd,False,True,8,32,8,2,2,2,155.31005859375 +dsd,False,True,8,32,16,2,8,2,83.82421875 +dsd,False,True,8,32,32,4,8,2,64.32686767578124 +dsd,False,True,8,32,64,8,2,2,62.68171997070313 +dsd,False,True,8,32,128,2,2,2,inf +dsd,False,True,8,64,8,2,2,2,219.0763671875 +dsd,False,True,8,64,16,2,2,2,113.57874755859376 
+dsd,False,True,8,64,32,4,2,2,95.80451049804688 +dsd,False,True,8,64,64,2,2,2,inf +dsd,False,True,8,64,128,2,2,2,inf +dsd,False,True,8,128,8,2,2,2,inf +dsd,False,True,8,128,16,2,2,2,179.626806640625 +dsd,False,True,8,128,32,2,2,2,inf +dsd,False,True,8,128,64,2,2,2,inf +dsd,False,True,8,128,128,2,2,2,inf +dsd,False,True,16,8,8,4,2,2,70.9658447265625 +dsd,False,True,16,8,16,4,8,2,43.00933837890625 +dsd,False,True,16,8,32,4,8,4,32.59908752441406 +dsd,False,True,16,8,64,16,8,4,31.052691650390624 +dsd,False,True,16,8,128,16,8,4,28.495932006835936 +dsd,False,True,16,16,8,2,2,2,65.24340209960937 +dsd,False,True,16,16,16,4,16,2,40.7180908203125 +dsd,False,True,16,16,32,4,16,4,34.13094482421875 +dsd,False,True,16,16,64,8,16,4,31.407711791992188 +dsd,False,True,16,16,128,16,4,4,31.13136901855469 +dsd,False,True,16,32,8,2,2,2,82.91044921875 +dsd,False,True,16,32,16,4,8,2,53.95642700195312 +dsd,False,True,16,32,32,8,16,2,38.63623657226562 +dsd,False,True,16,32,64,8,8,4,33.974850463867185 +dsd,False,True,16,32,128,2,2,2,inf +dsd,False,True,16,64,8,2,2,2,97.25258178710938 +dsd,False,True,16,64,16,4,2,2,69.24348754882813 +dsd,False,True,16,64,32,8,2,2,55.416680908203126 +dsd,False,True,16,64,64,2,2,2,inf +dsd,False,True,16,64,128,2,2,2,inf +dsd,False,True,16,128,8,2,2,2,130.18917236328124 +dsd,False,True,16,128,16,4,2,2,104.25732421875 +dsd,False,True,16,128,32,2,2,2,inf +dsd,False,True,16,128,64,2,2,2,inf +dsd,False,True,16,128,128,2,2,2,inf +dsd,False,True,32,8,8,8,2,2,53.04258422851562 +dsd,False,True,32,8,16,8,2,2,32.39458312988281 +dsd,False,True,32,8,32,8,4,4,21.76031646728516 +dsd,False,True,32,8,64,16,4,4,18.56409606933594 +dsd,False,True,32,8,128,16,2,8,17.860627746582033 +dsd,False,True,32,16,8,4,2,2,49.93617858886719 +dsd,False,True,32,16,16,8,8,2,29.666705322265624 +dsd,False,True,32,16,32,8,16,4,21.81918792724609 +dsd,False,True,32,16,64,16,2,4,19.822023010253908 +dsd,False,True,32,16,128,16,2,4,17.916249084472657 +dsd,False,True,32,32,8,4,2,2,64.38092651367188 +dsd,False,True,32,32,16,8,16,2,38.615664672851565 +dsd,False,True,32,32,32,16,4,2,26.05098571777344 +dsd,False,True,32,32,64,8,16,4,22.82095642089844 +dsd,False,True,32,32,128,2,2,2,inf +dsd,False,True,32,64,8,4,4,2,70.98646240234375 +dsd,False,True,32,64,16,8,16,2,46.50224914550781 +dsd,False,True,32,64,32,16,8,2,35.40425415039063 +dsd,False,True,32,64,64,2,2,2,inf +dsd,False,True,32,64,128,2,2,2,inf +dsd,False,True,32,128,8,4,2,2,92.30762329101564 +dsd,False,True,32,128,16,4,2,2,67.62554321289062 +dsd,False,True,32,128,32,2,2,2,inf +dsd,False,True,32,128,64,2,2,2,inf +dsd,False,True,32,128,128,2,2,2,inf +dsd,False,True,64,8,8,16,2,2,44.95367126464844 +dsd,False,True,64,8,16,8,4,4,25.882449340820312 +dsd,False,True,64,8,32,16,4,4,18.52846069335937 +dsd,False,True,64,8,64,16,4,8,14.562205505371091 +dsd,False,True,64,8,128,16,4,8,13.817481994628906 +dsd,False,True,64,16,8,8,4,2,41.5960693359375 +dsd,False,True,64,16,16,16,8,2,23.975321960449214 +dsd,False,True,64,16,32,16,4,4,17.057513427734374 +dsd,False,True,64,16,64,16,16,4,15.122802734375 +dsd,False,True,64,16,128,16,2,4,14.433549499511718 +dsd,False,True,64,32,8,8,4,2,54.07551879882813 +dsd,False,True,64,32,16,16,16,2,30.9832275390625 +dsd,False,True,64,32,32,16,8,2,20.66288604736328 +dsd,False,True,64,32,64,8,2,4,17.48600616455078 +dsd,False,True,64,32,128,2,2,2,inf +dsd,False,True,64,64,8,8,2,2,60.210791015625 +dsd,False,True,64,64,16,8,4,2,36.817529296875 +dsd,False,True,64,64,32,8,16,2,26.52713623046875 +dsd,False,True,64,64,64,2,2,2,inf +dsd,False,True,64,64,128,2,2,2,inf 
+dsd,False,True,64,128,8,4,8,2,78.29977416992188 +dsd,False,True,64,128,16,4,16,2,51.589532470703126 +dsd,False,True,64,128,32,2,2,2,inf +dsd,False,True,64,128,64,2,2,2,inf +dsd,False,True,64,128,128,2,2,2,inf +dsd,False,True,128,8,8,16,2,4,41.05773010253906 +dsd,False,True,128,8,16,16,2,4,22.999497985839845 +dsd,False,True,128,8,32,16,8,4,17.009033203125 +dsd,False,True,128,8,64,16,8,8,13.924940490722657 +dsd,False,True,128,8,128,16,8,8,13.011990356445311 +dsd,False,True,128,16,8,16,8,2,37.16488037109375 +dsd,False,True,128,16,16,16,8,4,22.16284484863281 +dsd,False,True,128,16,32,16,4,4,16.032858276367186 +dsd,False,True,128,16,64,8,4,8,14.073855590820312 +dsd,False,True,128,16,128,16,4,8,12.9249755859375 +dsd,False,True,128,32,8,16,8,2,49.282763671875 +dsd,False,True,128,32,16,16,8,2,28.181707763671877 +dsd,False,True,128,32,32,16,4,2,18.208338928222656 +dsd,False,True,128,32,64,16,2,4,14.724563598632812 +dsd,False,True,128,32,128,2,2,2,inf +dsd,False,True,128,64,8,8,8,2,62.97972412109375 +dsd,False,True,128,64,16,8,2,2,34.55672912597656 +dsd,False,True,128,64,32,16,16,2,21.978770446777343 +dsd,False,True,128,64,64,2,2,2,inf +dsd,False,True,128,64,128,2,2,2,inf +dsd,False,True,128,128,8,2,2,2,inf +dsd,False,True,128,128,16,2,2,2,inf +dsd,False,True,128,128,32,2,2,2,inf +dsd,False,True,128,128,64,2,2,2,inf +dsd,False,True,128,128,128,2,2,2,inf +dsd,True,False,8,8,8,2,2,2,105.0726318359375 +dsd,True,False,8,8,16,4,8,2,68.37989501953125 +dsd,True,False,8,8,32,4,2,4,59.30982666015625 +dsd,True,False,8,8,64,8,2,4,56.40568237304687 +dsd,True,False,8,8,128,2,2,2,inf +dsd,True,False,8,16,8,2,2,2,107.74365234375 +dsd,True,False,8,16,16,2,8,2,62.30513916015625 +dsd,True,False,8,16,32,8,4,2,58.94583740234375 +dsd,True,False,8,16,64,4,16,8,58.718438720703126 +dsd,True,False,8,16,128,8,16,4,56.52176513671875 +dsd,True,False,8,32,8,2,2,2,118.640625 +dsd,True,False,8,32,16,2,4,2,65.05188598632813 +dsd,True,False,8,32,32,2,16,4,58.65556640625 +dsd,True,False,8,32,64,8,2,2,58.23341674804688 +dsd,True,False,8,32,128,2,2,2,inf +dsd,True,False,8,64,8,2,2,2,116.23341064453123 +dsd,True,False,8,64,16,2,16,2,63.991015625 +dsd,True,False,8,64,32,4,2,2,59.7391357421875 +dsd,True,False,8,64,64,2,2,2,inf +dsd,True,False,8,64,128,2,2,2,inf +dsd,True,False,8,128,8,2,2,2,115.34276123046877 +dsd,True,False,8,128,16,2,16,2,63.252734375 +dsd,True,False,8,128,32,2,2,2,inf +dsd,True,False,8,128,64,2,2,2,inf +dsd,True,False,8,128,128,2,2,2,inf +dsd,True,False,16,8,8,4,2,2,83.14073486328125 +dsd,True,False,16,8,16,4,8,2,41.93546142578125 +dsd,True,False,16,8,32,4,2,4,30.63380126953125 +dsd,True,False,16,8,64,8,8,4,26.991793823242187 +dsd,True,False,16,8,128,16,8,4,24.62370910644531 +dsd,True,False,16,16,8,2,2,2,70.25438842773437 +dsd,True,False,16,16,16,4,8,2,39.605865478515625 +dsd,True,False,16,16,32,8,8,2,30.216619873046877 +dsd,True,False,16,16,64,16,16,2,26.670489501953124 +dsd,True,False,16,16,128,16,8,4,28.378298950195312 +dsd,True,False,16,32,8,2,2,2,83.79273681640625 +dsd,True,False,16,32,16,4,16,2,47.30290222167969 +dsd,True,False,16,32,32,8,2,2,30.566775512695312 +dsd,True,False,16,32,64,16,4,2,29.870693969726563 +dsd,True,False,16,32,128,2,2,2,inf +dsd,True,False,16,64,8,2,2,2,82.36276245117188 +dsd,True,False,16,64,16,4,16,2,46.767092895507815 +dsd,True,False,16,64,32,8,16,2,30.788980102539064 +dsd,True,False,16,64,64,2,2,2,inf +dsd,True,False,16,64,128,2,2,2,inf +dsd,True,False,16,128,8,2,2,2,81.1625 +dsd,True,False,16,128,16,4,16,2,45.954278564453126 +dsd,True,False,16,128,32,2,2,2,inf 
+dsd,True,False,16,128,64,2,2,2,inf +dsd,True,False,16,128,128,2,2,2,inf +dsd,True,False,32,8,8,8,2,2,95.72559814453123 +dsd,True,False,32,8,16,8,2,2,38.18620300292969 +dsd,True,False,32,8,32,8,4,4,23.87391662597656 +dsd,True,False,32,8,64,16,8,4,17.351699829101562 +dsd,True,False,32,8,128,16,4,8,15.071437072753906 +dsd,True,False,32,16,8,4,2,2,69.16351318359375 +dsd,True,False,32,16,16,8,2,2,37.622555541992185 +dsd,True,False,32,16,32,8,16,4,23.72455749511719 +dsd,True,False,32,16,64,8,8,8,17.51331481933594 +dsd,True,False,32,16,128,16,4,4,15.982438659667968 +dsd,True,False,32,32,8,4,2,2,84.504345703125 +dsd,True,False,32,32,16,8,2,2,45.35299987792969 +dsd,True,False,32,32,32,16,4,2,25.59196166992188 +dsd,True,False,32,32,64,8,2,4,19.055007934570312 +dsd,True,False,32,32,128,2,2,2,inf +dsd,True,False,32,64,8,4,8,2,84.14051513671875 +dsd,True,False,32,64,16,8,16,2,45.101986694335935 +dsd,True,False,32,64,32,16,2,2,26.621539306640624 +dsd,True,False,32,64,64,2,2,2,inf +dsd,True,False,32,64,128,2,2,2,inf +dsd,True,False,32,128,8,4,16,2,84.91974487304688 +dsd,True,False,32,128,16,8,16,2,49.0318603515625 +dsd,True,False,32,128,32,2,2,2,inf +dsd,True,False,32,128,64,2,2,2,inf +dsd,True,False,32,128,128,2,2,2,inf +dsd,True,False,64,8,8,16,4,2,156.4103271484375 +dsd,True,False,64,8,16,8,2,4,47.864923095703126 +dsd,True,False,64,8,32,8,8,8,28.413397216796877 +dsd,True,False,64,8,64,16,8,8,19.639353942871093 +dsd,True,False,64,8,128,16,8,8,15.72833251953125 +dsd,True,False,64,16,8,8,8,2,97.26983642578124 +dsd,True,False,64,16,16,16,16,2,49.42695007324219 +dsd,True,False,64,16,32,16,16,4,28.5388427734375 +dsd,True,False,64,16,64,16,16,4,19.786595153808594 +dsd,True,False,64,16,128,8,2,8,16.351426696777345 +dsd,True,False,64,32,8,8,2,2,113.108349609375 +dsd,True,False,64,32,16,16,2,2,57.786370849609376 +dsd,True,False,64,32,32,16,8,2,31.747702026367183 +dsd,True,False,64,32,64,16,2,4,20.874855041503903 +dsd,True,False,64,32,128,2,2,2,inf +dsd,True,False,64,64,8,8,16,2,114.476025390625 +dsd,True,False,64,64,16,16,16,2,58.68845825195312 +dsd,True,False,64,64,32,8,2,2,33.57569885253906 +dsd,True,False,64,64,64,2,2,2,inf +dsd,True,False,64,64,128,2,2,2,inf +dsd,True,False,64,128,8,4,4,2,120.13372802734376 +dsd,True,False,64,128,16,4,16,2,65.25214233398438 +dsd,True,False,64,128,32,2,2,2,inf +dsd,True,False,64,128,64,2,2,2,inf +dsd,True,False,64,128,128,2,2,2,inf +dsd,True,False,128,8,8,2,2,2,inf +dsd,True,False,128,8,16,16,2,4,79.05245971679688 +dsd,True,False,128,8,32,8,2,8,43.53925170898437 +dsd,True,False,128,8,64,16,8,8,26.67002868652344 +dsd,True,False,128,8,128,16,8,8,19.352313232421874 +dsd,True,False,128,16,8,16,2,2,161.6327392578125 +dsd,True,False,128,16,16,16,8,2,82.59317626953126 +dsd,True,False,128,16,32,16,8,4,44.50947265625 +dsd,True,False,128,16,64,16,4,8,27.16807861328125 +dsd,True,False,128,16,128,16,8,8,19.533062744140626 +dsd,True,False,128,32,8,8,2,2,180.34525146484376 +dsd,True,False,128,32,16,16,8,2,90.9325927734375 +dsd,True,False,128,32,32,16,4,2,47.33459777832032 +dsd,True,False,128,32,64,16,4,4,28.327731323242187 +dsd,True,False,128,32,128,2,2,2,inf +dsd,True,False,128,64,8,8,16,2,183.2932373046875 +dsd,True,False,128,64,16,8,8,2,94.98787841796874 +dsd,True,False,128,64,32,16,2,2,48.95641479492188 +dsd,True,False,128,64,64,2,2,2,inf +dsd,True,False,128,64,128,2,2,2,inf +dsd,True,False,128,128,8,2,2,2,inf +dsd,True,False,128,128,16,2,2,2,inf +dsd,True,False,128,128,32,2,2,2,inf +dsd,True,False,128,128,64,2,2,2,inf +dsd,True,False,128,128,128,2,2,2,inf 
+dsd,True,True,8,8,8,2,2,2,118.23125 +dsd,True,True,8,8,16,4,2,2,80.09966430664062 +dsd,True,True,8,8,32,8,4,2,62.34053955078125 +dsd,True,True,8,8,64,8,8,4,59.93763427734375 +dsd,True,True,8,8,128,8,8,8,52.66724853515625 +dsd,True,True,8,16,8,2,2,2,137.11962890625 +dsd,True,True,8,16,16,2,16,2,70.31336669921875 +dsd,True,True,8,16,32,4,2,2,62.28681640625 +dsd,True,True,8,16,64,2,8,8,64.71558227539063 +dsd,True,True,8,16,128,8,2,4,61.026214599609375 +dsd,True,True,8,32,8,2,2,2,179.8203369140625 +dsd,True,True,8,32,16,2,16,2,92.24152221679688 +dsd,True,True,8,32,32,4,2,2,68.66116333007812 +dsd,True,True,8,32,64,8,2,2,63.09805908203125 +dsd,True,True,8,32,128,2,2,2,inf +dsd,True,True,8,64,8,2,4,2,244.2031982421875 +dsd,True,True,8,64,16,2,2,2,122.4283447265625 +dsd,True,True,8,64,32,4,4,2,99.98135375976562 +dsd,True,True,8,64,64,2,2,2,inf +dsd,True,True,8,64,128,2,2,2,inf +dsd,True,True,8,128,8,2,2,2,inf +dsd,True,True,8,128,16,2,2,2,187.6283447265625 +dsd,True,True,8,128,32,2,2,2,inf +dsd,True,True,8,128,64,2,2,2,inf +dsd,True,True,8,128,128,2,2,2,inf +dsd,True,True,16,8,8,4,2,2,90.2330078125 +dsd,True,True,16,8,16,4,8,2,44.93805847167969 +dsd,True,True,16,8,32,8,8,2,33.386087036132814 +dsd,True,True,16,8,64,4,8,8,30.773260498046877 +dsd,True,True,16,8,128,16,4,4,25.82737426757813 +dsd,True,True,16,16,8,2,2,2,76.90401000976563 +dsd,True,True,16,16,16,4,4,2,46.816644287109376 +dsd,True,True,16,16,32,8,2,2,35.595672607421875 +dsd,True,True,16,16,64,8,8,4,31.4635498046875 +dsd,True,True,16,16,128,16,2,4,31.38978271484375 +dsd,True,True,16,32,8,2,2,2,99.06978759765624 +dsd,True,True,16,32,16,4,16,2,62.08654174804688 +dsd,True,True,16,32,32,8,16,2,42.750265502929686 +dsd,True,True,16,32,64,8,2,4,35.54405212402344 +dsd,True,True,16,32,128,2,2,2,inf +dsd,True,True,16,64,8,2,2,2,113.74410400390624 +dsd,True,True,16,64,16,4,8,2,77.19734497070313 +dsd,True,True,16,64,32,8,16,2,59.00011596679688 +dsd,True,True,16,64,64,2,2,2,inf +dsd,True,True,16,64,128,2,2,2,inf +dsd,True,True,16,128,8,2,2,2,146.14710693359376 +dsd,True,True,16,128,16,4,4,2,110.43963623046876 +dsd,True,True,16,128,32,2,2,2,inf +dsd,True,True,16,128,64,2,2,2,inf +dsd,True,True,16,128,128,2,2,2,inf +dsd,True,True,32,8,8,8,2,2,99.25596313476562 +dsd,True,True,32,8,16,4,2,4,39.78704833984375 +dsd,True,True,32,8,32,8,8,4,25.51207427978516 +dsd,True,True,32,8,64,8,8,8,18.553628540039064 +dsd,True,True,32,8,128,16,4,8,16.5560546875 +dsd,True,True,32,16,8,4,2,2,73.2929443359375 +dsd,True,True,32,16,16,8,2,2,41.48325805664062 +dsd,True,True,32,16,32,8,16,4,27.090509033203126 +dsd,True,True,32,16,64,16,8,4,20.703590393066406 +dsd,True,True,32,16,128,16,2,4,18.536476135253903 +dsd,True,True,32,32,8,4,2,2,92.30847778320312 +dsd,True,True,32,32,16,8,16,2,52.80338134765625 +dsd,True,True,32,32,32,16,16,2,32.82637329101563 +dsd,True,True,32,32,64,8,2,4,26.021478271484376 +dsd,True,True,32,32,128,2,2,2,inf +dsd,True,True,32,64,8,4,16,2,100.34822998046874 +dsd,True,True,32,64,16,8,2,2,60.166455078125 +dsd,True,True,32,64,32,8,8,2,42.613311767578125 +dsd,True,True,32,64,64,2,2,2,inf +dsd,True,True,32,64,128,2,2,2,inf +dsd,True,True,32,128,8,4,4,2,117.73109130859376 +dsd,True,True,32,128,16,4,2,2,81.627783203125 +dsd,True,True,32,128,32,2,2,2,inf +dsd,True,True,32,128,64,2,2,2,inf +dsd,True,True,32,128,128,2,2,2,inf +dsd,True,True,64,8,8,16,4,2,158.24068603515624 +dsd,True,True,64,8,16,8,4,4,49.08482666015625 +dsd,True,True,64,8,32,8,8,8,29.31374206542969 +dsd,True,True,64,8,64,8,8,8,20.53645477294922 +dsd,True,True,64,8,128,16,4,8,16.3684326171875 
+dsd,True,True,64,16,8,8,2,2,98.58057250976564 +dsd,True,True,64,16,16,16,2,2,51.93522338867187 +dsd,True,True,64,16,32,16,16,4,30.307171630859376 +dsd,True,True,64,16,64,16,4,8,21.627452087402343 +dsd,True,True,64,16,128,16,2,8,17.584332275390626 +dsd,True,True,64,32,8,8,8,2,116.79525146484374 +dsd,True,True,64,32,16,16,2,2,61.6658935546875 +dsd,True,True,64,32,32,16,4,2,35.399261474609375 +dsd,True,True,64,32,64,16,2,4,24.33681030273437 +dsd,True,True,64,32,128,2,2,2,inf +dsd,True,True,64,64,8,8,2,2,122.92733154296874 +dsd,True,True,64,64,16,8,2,2,67.2964599609375 +dsd,True,True,64,64,32,8,16,2,40.79920654296875 +dsd,True,True,64,64,64,2,2,2,inf +dsd,True,True,64,64,128,2,2,2,inf +dsd,True,True,64,128,8,4,4,2,133.99801025390624 +dsd,True,True,64,128,16,4,16,2,81.50208740234375 +dsd,True,True,64,128,32,2,2,2,inf +dsd,True,True,64,128,64,2,2,2,inf +dsd,True,True,64,128,128,2,2,2,inf +dsd,True,True,128,8,8,2,2,2,inf +dsd,True,True,128,8,16,16,4,4,79.65204467773438 +dsd,True,True,128,8,32,8,8,8,44.142999267578126 +dsd,True,True,128,8,64,16,8,8,27.371810913085938 +dsd,True,True,128,8,128,16,2,8,19.97423095703125 +dsd,True,True,128,16,8,16,2,2,162.4993896484375 +dsd,True,True,128,16,16,16,2,2,83.48057861328125 +dsd,True,True,128,16,32,16,8,4,45.415707397460935 +dsd,True,True,128,16,64,16,4,8,28.13584289550781 +dsd,True,True,128,16,128,16,4,8,20.519363403320312 +dsd,True,True,128,32,8,8,2,2,182.2271728515625 +dsd,True,True,128,32,16,16,2,2,92.73519287109374 +dsd,True,True,128,32,32,16,16,2,49.20975341796875 +dsd,True,True,128,32,64,16,8,4,30.175570678710937 +dsd,True,True,128,32,128,2,2,2,inf +dsd,True,True,128,64,8,8,4,2,187.40059814453124 +dsd,True,True,128,64,16,8,2,2,98.85172119140626 +dsd,True,True,128,64,32,16,2,2,53.22855224609375 +dsd,True,True,128,64,64,2,2,2,inf +dsd,True,True,128,64,128,2,2,2,inf +dsd,True,True,128,128,8,2,2,2,inf +dsd,True,True,128,128,16,2,2,2,inf +dsd,True,True,128,128,32,2,2,2,inf +dsd,True,True,128,128,64,2,2,2,inf +dsd,True,True,128,128,128,2,2,2,inf +sdd,False,False,8,8,8,2,2,2,100.51373901367188 +sdd,False,False,8,8,16,4,2,2,64.60966796875 +sdd,False,False,8,8,32,8,8,2,57.00081787109375 +sdd,False,False,8,8,64,4,4,8,56.4310546875 +sdd,False,False,8,8,128,2,2,2,inf +sdd,False,False,8,16,8,2,2,2,98.814892578125 +sdd,False,False,8,16,16,2,4,2,58.948492431640624 +sdd,False,False,8,16,32,4,8,4,56.5875732421875 +sdd,False,False,8,16,64,8,8,4,55.15122680664062 +sdd,False,False,8,16,128,2,16,16,55.09329833984375 +sdd,False,False,8,32,8,2,4,2,110.617626953125 +sdd,False,False,8,32,16,4,2,2,72.39398803710938 +sdd,False,False,8,32,32,8,8,2,55.9151123046875 +sdd,False,False,8,32,64,4,4,4,55.30057373046875 +sdd,False,False,8,32,128,8,16,2,55.566705322265626 +sdd,False,False,8,64,8,2,16,2,114.453466796875 +sdd,False,False,8,64,16,4,8,2,69.06260986328125 +sdd,False,False,8,64,32,2,16,2,56.490509033203125 +sdd,False,False,8,64,64,4,8,2,56.05826416015625 +sdd,False,False,8,64,128,4,8,2,57.33589477539063 +sdd,False,False,8,128,8,2,2,2,inf +sdd,False,False,8,128,16,2,16,2,64.98045654296875 +sdd,False,False,8,128,32,2,16,2,57.79884643554688 +sdd,False,False,8,128,64,2,2,2,60.33175048828125 +sdd,False,False,8,128,128,2,2,2,inf +sdd,False,False,16,8,8,4,2,2,59.2751953125 +sdd,False,False,16,8,16,4,2,2,37.26842346191406 +sdd,False,False,16,8,32,4,8,4,29.124264526367188 +sdd,False,False,16,8,64,16,4,4,28.78825988769531 +sdd,False,False,16,8,128,16,2,4,29.269412231445312 +sdd,False,False,16,16,8,2,16,2,62.56676025390625 +sdd,False,False,16,16,16,4,16,2,35.13594055175781 
+sdd,False,False,16,16,32,4,16,4,28.555059814453124 +sdd,False,False,16,16,64,16,16,2,27.95408935546875 +sdd,False,False,16,16,128,4,16,8,28.408853149414064 +sdd,False,False,16,32,8,2,8,2,77.71033325195313 +sdd,False,False,16,32,16,4,4,2,42.81219482421875 +sdd,False,False,16,32,32,8,8,2,28.557516479492183 +sdd,False,False,16,32,64,4,4,4,28.33338928222656 +sdd,False,False,16,32,128,4,4,4,28.3524169921875 +sdd,False,False,16,64,8,2,2,2,77.14835815429687 +sdd,False,False,16,64,16,4,2,2,42.45551452636719 +sdd,False,False,16,64,32,8,8,2,29.25578002929688 +sdd,False,False,16,64,64,8,16,2,28.65099487304688 +sdd,False,False,16,64,128,4,16,4,29.612652587890626 +sdd,False,False,16,128,8,2,2,2,75.71721801757812 +sdd,False,False,16,128,16,4,8,2,41.27252502441407 +sdd,False,False,16,128,32,4,16,2,32.78056945800781 +sdd,False,False,16,128,64,4,8,2,32.97606506347656 +sdd,False,False,16,128,128,2,2,2,inf +sdd,False,False,32,8,8,8,2,2,40.01423034667969 +sdd,False,False,32,8,16,4,8,4,26.011032104492188 +sdd,False,False,32,8,32,8,2,4,17.514921569824217 +sdd,False,False,32,8,64,16,2,4,15.479808044433591 +sdd,False,False,32,8,128,16,2,8,15.296368408203126 +sdd,False,False,32,16,8,4,2,2,46.48823547363281 +sdd,False,False,32,16,16,8,4,2,26.083248901367188 +sdd,False,False,32,16,32,8,16,4,17.562828063964844 +sdd,False,False,32,16,64,8,4,8,15.211427307128906 +sdd,False,False,32,16,128,8,2,8,15.429017639160156 +sdd,False,False,32,32,8,4,4,2,61.226904296875 +sdd,False,False,32,32,16,8,16,2,33.44827575683594 +sdd,False,False,32,32,32,16,16,2,19.189141845703126 +sdd,False,False,32,32,64,8,2,4,16.45886688232422 +sdd,False,False,32,32,128,8,2,4,16.234320068359374 +sdd,False,False,32,64,8,4,4,2,60.69839477539063 +sdd,False,False,32,64,16,8,2,2,33.48793334960938 +sdd,False,False,32,64,32,16,2,2,21.469392395019533 +sdd,False,False,32,64,64,8,8,2,20.12193908691406 +sdd,False,False,32,64,128,8,8,4,19.00091857910156 +sdd,False,False,32,128,8,2,2,2,inf +sdd,False,False,32,128,16,2,2,2,inf +sdd,False,False,32,128,32,2,2,2,inf +sdd,False,False,32,128,64,2,2,2,inf +sdd,False,False,32,128,128,2,2,2,inf +sdd,False,False,64,8,8,16,2,2,32.308685302734375 +sdd,False,False,64,8,16,8,4,4,19.22510986328125 +sdd,False,False,64,8,32,16,2,4,14.082421875 +sdd,False,False,64,8,64,16,8,8,12.987596130371092 +sdd,False,False,64,8,128,16,8,8,12.834736633300782 +sdd,False,False,64,16,8,8,4,2,38.203704833984375 +sdd,False,False,64,16,16,16,16,2,20.37720031738281 +sdd,False,False,64,16,32,16,16,4,14.33055419921875 +sdd,False,False,64,16,64,8,8,8,13.391404724121092 +sdd,False,False,64,16,128,16,16,4,12.859599304199218 +sdd,False,False,64,32,8,8,2,2,53.3166015625 +sdd,False,False,64,32,16,16,2,2,28.50967712402344 +sdd,False,False,64,32,32,16,8,2,17.530674743652344 +sdd,False,False,64,32,64,8,2,4,14.465974426269533 +sdd,False,False,64,32,128,8,4,8,12.986674499511718 +sdd,False,False,64,64,8,2,2,2,inf +sdd,False,False,64,64,16,2,2,2,inf +sdd,False,False,64,64,32,2,2,2,inf +sdd,False,False,64,64,64,2,2,2,inf +sdd,False,False,64,64,128,2,2,2,inf +sdd,False,False,64,128,8,2,2,2,inf +sdd,False,False,64,128,16,2,2,2,inf +sdd,False,False,64,128,32,2,2,2,inf +sdd,False,False,64,128,64,2,2,2,inf +sdd,False,False,64,128,128,2,2,2,inf +sdd,False,False,128,8,8,16,2,4,29.03531494140625 +sdd,False,False,128,8,16,16,2,4,16.07601318359375 +sdd,False,False,128,8,32,16,2,8,13.012652587890624 +sdd,False,False,128,8,64,16,8,8,12.412217712402343 +sdd,False,False,128,8,128,16,2,8,12.105235290527345 +sdd,False,False,128,16,8,16,8,2,33.21717529296875 
+sdd,False,False,128,16,16,16,16,4,19.459730529785155 +sdd,False,False,128,16,32,16,4,4,13.754547119140623 +sdd,False,False,128,16,64,8,2,8,12.522496032714844 +sdd,False,False,128,16,128,16,8,8,12.006649780273438 +sdd,False,False,128,32,8,2,2,2,inf +sdd,False,False,128,32,16,2,2,2,inf +sdd,False,False,128,32,32,2,2,2,inf +sdd,False,False,128,32,64,2,2,2,inf +sdd,False,False,128,32,128,2,2,2,inf +sdd,False,False,128,64,8,2,2,2,inf +sdd,False,False,128,64,16,2,2,2,inf +sdd,False,False,128,64,32,2,2,2,inf +sdd,False,False,128,64,64,2,2,2,inf +sdd,False,False,128,64,128,2,2,2,inf +sdd,False,False,128,128,8,2,2,2,inf +sdd,False,False,128,128,16,2,2,2,inf +sdd,False,False,128,128,32,2,2,2,inf +sdd,False,False,128,128,64,2,2,2,inf +sdd,False,False,128,128,128,2,2,2,inf +sdd,False,True,8,8,8,2,2,2,171.00050048828126 +sdd,False,True,8,8,16,2,2,8,163.78897705078126 +sdd,False,True,8,8,32,2,8,16,159.98873291015624 +sdd,False,True,8,8,64,8,2,16,164.0970947265625 +sdd,False,True,8,8,128,8,4,16,183.36787109375 +sdd,False,True,8,16,8,2,2,2,117.20167236328123 +sdd,False,True,8,16,16,2,2,2,81.7120361328125 +sdd,False,True,8,16,32,8,16,2,79.3997314453125 +sdd,False,True,8,16,64,4,16,8,78.4077392578125 +sdd,False,True,8,16,128,4,2,8,77.11439819335938 +sdd,False,True,8,32,8,2,2,2,156.24798583984375 +sdd,False,True,8,32,16,2,8,2,84.08251953125 +sdd,False,True,8,32,32,4,2,2,64.31005249023437 +sdd,False,True,8,32,64,4,8,4,58.52037353515625 +sdd,False,True,8,32,128,8,8,2,58.091217041015625 +sdd,False,True,8,64,8,2,2,2,220.189990234375 +sdd,False,True,8,64,16,2,2,2,114.05440673828124 +sdd,False,True,8,64,32,4,8,2,95.51515502929688 +sdd,False,True,8,64,64,4,4,2,93.97026977539062 +sdd,False,True,8,64,128,4,16,2,97.15015869140623 +sdd,False,True,8,128,8,2,2,2,inf +sdd,False,True,8,128,16,2,16,2,181.31094970703128 +sdd,False,True,8,128,32,4,8,2,164.9397705078125 +sdd,False,True,8,128,64,4,16,2,171.19639892578124 +sdd,False,True,8,128,128,2,2,2,inf +sdd,False,True,16,8,8,4,2,2,86.89203491210938 +sdd,False,True,16,8,16,8,8,4,81.80593872070312 +sdd,False,True,16,8,32,8,4,8,79.93838500976562 +sdd,False,True,16,8,64,8,8,8,101.238623046875 +sdd,False,True,16,8,128,8,4,16,102.6164306640625 +sdd,False,True,16,16,8,2,2,2,62.99890747070312 +sdd,False,True,16,16,16,4,4,2,42.55126953125 +sdd,False,True,16,16,32,4,4,4,40.209408569335935 +sdd,False,True,16,16,64,2,16,16,40.08545837402344 +sdd,False,True,16,16,128,4,16,16,40.05580749511719 +sdd,False,True,16,32,8,2,2,2,83.31412963867187 +sdd,False,True,16,32,16,4,4,2,53.95750732421875 +sdd,False,True,16,32,32,8,2,2,38.29801025390625 +sdd,False,True,16,32,64,8,4,4,33.32770690917969 +sdd,False,True,16,32,128,8,2,4,33.29131164550781 +sdd,False,True,16,64,8,2,8,2,97.72012329101562 +sdd,False,True,16,64,16,4,4,2,68.52507934570312 +sdd,False,True,16,64,32,8,2,2,54.92421264648438 +sdd,False,True,16,64,64,8,8,2,52.87239379882813 +sdd,False,True,16,64,128,8,16,2,53.95496826171875 +sdd,False,True,16,128,8,2,2,2,131.61343994140626 +sdd,False,True,16,128,16,4,8,2,103.31490478515624 +sdd,False,True,16,128,32,4,16,2,92.17283935546877 +sdd,False,True,16,128,64,4,4,2,92.34268188476562 +sdd,False,True,16,128,128,2,2,2,inf +sdd,False,True,32,8,8,8,2,2,46.26974182128906 +sdd,False,True,32,8,16,2,8,8,41.85513305664063 +sdd,False,True,32,8,32,16,4,8,49.90415954589844 +sdd,False,True,32,8,64,2,4,16,49.80457458496094 +sdd,False,True,32,8,128,8,8,16,54.3072265625 +sdd,False,True,32,16,8,4,2,2,46.49102783203125 +sdd,False,True,32,16,16,8,8,2,28.61341552734375 +sdd,False,True,32,16,32,8,8,4,21.76053466796875 
+sdd,False,True,32,16,64,16,8,4,21.050335693359376 +sdd,False,True,32,16,128,8,4,4,21.02478790283203 +sdd,False,True,32,32,8,4,2,2,63.95247802734375 +sdd,False,True,32,32,16,8,4,2,38.48013610839844 +sdd,False,True,32,32,32,16,8,2,25.52860107421875 +sdd,False,True,32,32,64,8,2,4,22.682418823242188 +sdd,False,True,32,32,128,8,16,4,22.095289611816405 +sdd,False,True,32,64,8,4,4,2,70.35035400390625 +sdd,False,True,32,64,16,8,2,2,46.534426879882815 +sdd,False,True,32,64,32,16,2,2,34.376165771484374 +sdd,False,True,32,64,64,8,4,4,33.43510131835937 +sdd,False,True,32,64,128,8,4,4,31.86273193359375 +sdd,False,True,32,128,8,2,2,2,inf +sdd,False,True,32,128,16,2,2,2,inf +sdd,False,True,32,128,32,2,2,2,inf +sdd,False,True,32,128,64,2,2,2,inf +sdd,False,True,32,128,128,2,2,2,inf +sdd,False,True,64,8,8,16,2,2,34.15183715820312 +sdd,False,True,64,8,16,8,4,4,23.15445709228516 +sdd,False,True,64,8,32,8,4,4,21.65299224853516 +sdd,False,True,64,8,64,16,4,4,25.66669616699219 +sdd,False,True,64,8,128,8,2,8,29.09327392578125 +sdd,False,True,64,16,8,8,4,2,37.92826538085937 +sdd,False,True,64,16,16,16,16,2,21.889523315429688 +sdd,False,True,64,16,32,16,4,4,16.37224884033203 +sdd,False,True,64,16,64,8,4,8,15.151309204101562 +sdd,False,True,64,16,128,8,4,8,14.999139404296876 +sdd,False,True,64,32,8,8,8,2,54.12269897460938 +sdd,False,True,64,32,16,16,4,2,30.51377868652344 +sdd,False,True,64,32,32,16,8,2,20.60735321044922 +sdd,False,True,64,32,64,16,2,4,17.45869140625 +sdd,False,True,64,32,128,8,2,8,16.512409973144532 +sdd,False,True,64,64,8,2,2,2,inf +sdd,False,True,64,64,16,2,2,2,inf +sdd,False,True,64,64,32,2,2,2,inf +sdd,False,True,64,64,64,2,2,2,inf +sdd,False,True,64,64,128,2,2,2,inf +sdd,False,True,64,128,8,2,2,2,inf +sdd,False,True,64,128,16,2,2,2,inf +sdd,False,True,64,128,32,2,2,2,inf +sdd,False,True,64,128,64,2,2,2,inf +sdd,False,True,64,128,128,2,2,2,inf +sdd,False,True,128,8,8,16,2,4,29.791290283203125 +sdd,False,True,128,8,16,16,2,4,16.895657348632813 +sdd,False,True,128,8,32,16,4,8,14.107527160644532 +sdd,False,True,128,8,64,8,8,8,13.506764221191409 +sdd,False,True,128,8,128,16,2,8,13.45937957763672 +sdd,False,True,128,16,8,16,16,2,33.48744201660156 +sdd,False,True,128,16,16,16,16,4,20.25232391357422 +sdd,False,True,128,16,32,16,2,4,14.85677490234375 +sdd,False,True,128,16,64,8,2,8,13.494082641601562 +sdd,False,True,128,16,128,16,4,8,13.044131469726562 +sdd,False,True,128,32,8,2,2,2,inf +sdd,False,True,128,32,16,2,2,2,inf +sdd,False,True,128,32,32,2,2,2,inf +sdd,False,True,128,32,64,2,2,2,inf +sdd,False,True,128,32,128,2,2,2,inf +sdd,False,True,128,64,8,2,2,2,inf +sdd,False,True,128,64,16,2,2,2,inf +sdd,False,True,128,64,32,2,2,2,inf +sdd,False,True,128,64,64,2,2,2,inf +sdd,False,True,128,64,128,2,2,2,inf +sdd,False,True,128,128,8,2,2,2,inf +sdd,False,True,128,128,16,2,2,2,inf +sdd,False,True,128,128,32,2,2,2,inf +sdd,False,True,128,128,64,2,2,2,inf +sdd,False,True,128,128,128,2,2,2,inf +sdd,True,False,8,8,8,2,2,2,107.84891357421876 +sdd,True,False,8,8,16,4,4,2,67.75611572265625 +sdd,True,False,8,8,32,4,4,4,59.0141845703125 +sdd,True,False,8,8,64,8,2,4,57.180572509765625 +sdd,True,False,8,8,128,2,2,2,inf +sdd,True,False,8,16,8,2,2,2,109.46407470703124 +sdd,True,False,8,16,16,2,16,2,61.050567626953125 +sdd,True,False,8,16,32,4,16,4,58.43782348632813 +sdd,True,False,8,16,64,2,2,8,57.12587890625 +sdd,True,False,8,16,128,8,8,4,55.89215087890625 +sdd,True,False,8,32,8,2,2,2,120.77034912109374 +sdd,True,False,8,32,16,4,16,2,75.55987548828125 +sdd,True,False,8,32,32,2,4,4,57.96864013671875 
+sdd,True,False,8,32,64,8,16,2,58.12390747070312 +sdd,True,False,8,32,128,4,2,8,44.92681884765625 +sdd,True,False,8,64,8,2,2,2,122.9762451171875 +sdd,True,False,8,64,16,4,4,2,70.36700439453125 +sdd,True,False,8,64,32,4,8,2,59.4720458984375 +sdd,True,False,8,64,64,8,16,2,36.19658203125 +sdd,True,False,8,64,128,4,4,2,39.398468017578125 +sdd,True,False,8,128,8,2,2,2,119.67017822265623 +sdd,True,False,8,128,16,2,8,2,65.66317138671874 +sdd,True,False,8,128,32,4,2,2,58.19050903320313 +sdd,True,False,8,128,64,2,4,2,60.50167236328125 +sdd,True,False,8,128,128,2,2,2,inf +sdd,True,False,16,8,8,4,2,2,85.56646118164062 +sdd,True,False,16,8,16,4,8,2,42.578939819335936 +sdd,True,False,16,8,32,8,8,2,30.206103515625 +sdd,True,False,16,8,64,4,8,8,27.041256713867188 +sdd,True,False,16,8,128,8,2,8,26.4382080078125 +sdd,True,False,16,16,8,2,2,2,73.14339599609374 +sdd,True,False,16,16,16,4,4,2,40.42448120117187 +sdd,True,False,16,16,32,4,8,4,29.81109619140625 +sdd,True,False,16,16,64,4,8,8,27.577120971679687 +sdd,True,False,16,16,128,16,16,4,28.35008850097656 +sdd,True,False,16,32,8,2,2,2,87.06130981445312 +sdd,True,False,16,32,16,4,4,2,48.084799194335936 +sdd,True,False,16,32,32,8,4,2,30.470199584960938 +sdd,True,False,16,32,64,16,8,2,29.42889404296875 +sdd,True,False,16,32,128,8,2,2,27.97276306152344 +sdd,True,False,16,64,8,2,2,2,87.30538940429688 +sdd,True,False,16,64,16,4,8,2,47.590399169921874 +sdd,True,False,16,64,32,8,16,2,31.03179016113281 +sdd,True,False,16,64,64,16,16,2,30.02081298828125 +sdd,True,False,16,64,128,8,2,2,25.25854034423828 +sdd,True,False,16,128,8,2,2,2,86.09505615234374 +sdd,True,False,16,128,16,4,8,2,46.5567138671875 +sdd,True,False,16,128,32,4,16,2,33.30592651367188 +sdd,True,False,16,128,64,4,8,2,32.04450988769531 +sdd,True,False,16,128,128,2,2,2,inf +sdd,True,False,32,8,8,8,2,2,99.80089721679688 +sdd,True,False,32,8,16,8,4,2,39.57965393066407 +sdd,True,False,32,8,32,8,8,4,24.052531433105468 +sdd,True,False,32,8,64,16,8,4,17.61296691894531 +sdd,True,False,32,8,128,16,8,8,15.401779174804688 +sdd,True,False,32,16,8,4,2,2,73.40809936523438 +sdd,True,False,32,16,16,8,8,2,39.44325256347656 +sdd,True,False,32,16,32,8,8,4,23.993994140625 +sdd,True,False,32,16,64,16,8,4,17.59660186767578 +sdd,True,False,32,16,128,16,8,4,16.005331420898436 +sdd,True,False,32,32,8,4,2,2,88.47379150390626 +sdd,True,False,32,32,16,8,16,2,47.06881713867188 +sdd,True,False,32,32,32,16,2,2,25.58919677734375 +sdd,True,False,32,32,64,8,4,4,19.35011901855469 +sdd,True,False,32,32,128,8,4,4,16.88312683105469 +sdd,True,False,32,64,8,4,8,2,87.8712158203125 +sdd,True,False,32,64,16,8,16,2,46.817807006835935 +sdd,True,False,32,64,32,16,2,2,26.837152099609376 +sdd,True,False,32,64,64,8,8,4,22.69001007080078 +sdd,True,False,32,64,128,8,8,4,19.504771423339843 +sdd,True,False,32,128,8,2,2,2,inf +sdd,True,False,32,128,16,2,2,2,inf +sdd,True,False,32,128,32,2,2,2,inf +sdd,True,False,32,128,64,2,2,2,inf +sdd,True,False,32,128,128,2,2,2,inf +sdd,True,False,64,8,8,16,2,2,157.8266845703125 +sdd,True,False,64,8,16,8,2,4,49.07491149902344 +sdd,True,False,64,8,32,8,8,8,28.78367919921875 +sdd,True,False,64,8,64,8,4,8,20.032687377929687 +sdd,True,False,64,8,128,16,8,8,15.866058349609377 +sdd,True,False,64,16,8,8,8,2,98.84671020507812 +sdd,True,False,64,16,16,16,4,2,50.58737487792969 +sdd,True,False,64,16,32,16,16,4,28.727398681640626 +sdd,True,False,64,16,64,16,8,8,20.045263671875 +sdd,True,False,64,16,128,16,16,4,16.233351135253905 +sdd,True,False,64,32,8,8,8,2,115.4033447265625 +sdd,True,False,64,32,16,16,2,2,58.267724609375 
+sdd,True,False,64,32,32,16,8,2,31.884262084960938 +sdd,True,False,64,32,64,16,2,4,21.14292755126953 +sdd,True,False,64,32,128,8,2,8,16.677273559570313 +sdd,True,False,64,64,8,2,2,2,inf +sdd,True,False,64,64,16,2,2,2,inf +sdd,True,False,64,64,32,2,2,2,inf +sdd,True,False,64,64,64,2,2,2,inf +sdd,True,False,64,64,128,2,2,2,inf +sdd,True,False,64,128,8,2,2,2,inf +sdd,True,False,64,128,16,2,2,2,inf +sdd,True,False,64,128,32,2,2,2,inf +sdd,True,False,64,128,64,2,2,2,inf +sdd,True,False,64,128,128,2,2,2,inf +sdd,True,False,128,8,8,2,2,2,inf +sdd,True,False,128,8,16,16,4,4,79.24072265625 +sdd,True,False,128,8,32,8,4,8,43.543756103515626 +sdd,True,False,128,8,64,16,8,8,27.0137451171875 +sdd,True,False,128,8,128,16,8,8,19.31786193847656 +sdd,True,False,128,16,8,16,2,2,162.3300537109375 +sdd,True,False,128,16,16,16,4,2,83.507421875 +sdd,True,False,128,16,32,16,8,4,44.29751586914063 +sdd,True,False,128,16,64,16,4,8,27.23143615722656 +sdd,True,False,128,16,128,16,8,8,19.55742034912109 +sdd,True,False,128,32,8,2,2,2,inf +sdd,True,False,128,32,16,2,2,2,inf +sdd,True,False,128,32,32,2,2,2,inf +sdd,True,False,128,32,64,2,2,2,inf +sdd,True,False,128,32,128,2,2,2,inf +sdd,True,False,128,64,8,2,2,2,inf +sdd,True,False,128,64,16,2,2,2,inf +sdd,True,False,128,64,32,2,2,2,inf +sdd,True,False,128,64,64,2,2,2,inf +sdd,True,False,128,64,128,2,2,2,inf +sdd,True,False,128,128,8,2,2,2,inf +sdd,True,False,128,128,16,2,2,2,inf +sdd,True,False,128,128,32,2,2,2,inf +sdd,True,False,128,128,64,2,2,2,inf +sdd,True,False,128,128,128,2,2,2,inf +sdd,True,True,8,8,8,2,2,2,119.818115234375 +sdd,True,True,8,8,16,4,4,2,95.63402099609377 +sdd,True,True,8,8,32,2,8,8,103.4179443359375 +sdd,True,True,8,8,64,8,8,16,155.1286376953125 +sdd,True,True,8,8,128,8,4,16,168.80965576171874 +sdd,True,True,8,16,8,2,2,2,129.49100341796876 +sdd,True,True,8,16,16,2,8,2,90.29096069335938 +sdd,True,True,8,16,32,8,4,2,85.05003051757812 +sdd,True,True,8,16,64,4,4,8,81.44771728515624 +sdd,True,True,8,16,128,4,16,8,79.3880615234375 +sdd,True,True,8,32,8,2,2,2,167.95340576171876 +sdd,True,True,8,32,16,2,4,2,85.97968139648438 +sdd,True,True,8,32,32,4,4,2,65.43773803710937 +sdd,True,True,8,32,64,8,4,2,60.2838134765625 +sdd,True,True,8,32,128,8,4,4,59.63026733398438 +sdd,True,True,8,64,8,2,2,2,232.9419189453125 +sdd,True,True,8,64,16,2,2,2,116.23697509765626 +sdd,True,True,8,64,32,4,8,2,96.583154296875 +sdd,True,True,8,64,64,8,16,2,93.96688842773438 +sdd,True,True,8,64,128,4,2,2,96.8304443359375 +sdd,True,True,8,128,8,2,2,2,inf +sdd,True,True,8,128,16,2,2,2,182.8800537109375 +sdd,True,True,8,128,32,4,8,2,165.28294677734374 +sdd,True,True,8,128,64,4,4,2,172.30274658203126 +sdd,True,True,8,128,128,2,2,2,inf +sdd,True,True,16,8,8,4,2,2,91.8294677734375 +sdd,True,True,16,8,16,4,2,2,50.0265380859375 +sdd,True,True,16,8,32,8,8,4,52.8593017578125 +sdd,True,True,16,8,64,16,4,8,64.15350341796875 +sdd,True,True,16,8,128,16,8,8,79.35950927734375 +sdd,True,True,16,16,8,2,2,2,74.62520751953124 +sdd,True,True,16,16,16,4,16,2,47.85111083984375 +sdd,True,True,16,16,32,8,16,2,42.73371887207031 +sdd,True,True,16,16,64,8,4,4,38.24332275390625 +sdd,True,True,16,16,128,2,2,16,41.39579162597656 +sdd,True,True,16,32,8,2,2,2,95.1736328125 +sdd,True,True,16,32,16,4,16,2,59.775384521484376 +sdd,True,True,16,32,32,8,2,2,41.61084594726562 +sdd,True,True,16,32,64,8,2,4,34.766482543945315 +sdd,True,True,16,32,128,8,4,4,33.867242431640626 +sdd,True,True,16,64,8,2,2,2,110.17044677734376 +sdd,True,True,16,64,16,4,16,2,74.63895263671876 +sdd,True,True,16,64,32,8,4,2,57.64906616210938 
+sdd,True,True,16,64,64,8,8,2,54.14190673828125 +sdd,True,True,16,64,128,8,16,2,53.55643920898437 +sdd,True,True,16,128,8,2,4,2,142.96043701171874 +sdd,True,True,16,128,16,4,16,2,108.21922607421877 +sdd,True,True,16,128,32,4,8,2,95.71756591796876 +sdd,True,True,16,128,64,4,16,2,93.30078735351564 +sdd,True,True,16,128,128,2,2,2,inf +sdd,True,True,32,8,8,8,2,2,102.19821166992188 +sdd,True,True,32,8,16,8,4,2,42.9096923828125 +sdd,True,True,32,8,32,8,8,4,29.99979248046875 +sdd,True,True,32,8,64,4,8,16,31.118014526367187 +sdd,True,True,32,8,128,16,4,8,35.449188232421875 +sdd,True,True,32,16,8,4,16,2,74.87221069335938 +sdd,True,True,32,16,16,8,16,2,42.631622314453125 +sdd,True,True,32,16,32,8,16,4,28.22671813964844 +sdd,True,True,32,16,64,16,8,4,22.189605712890625 +sdd,True,True,32,16,128,8,8,4,21.13179168701172 +sdd,True,True,32,32,8,4,4,2,92.75965576171876 +sdd,True,True,32,32,16,8,8,2,52.53775634765625 +sdd,True,True,32,32,32,16,4,2,32.35062866210937 +sdd,True,True,32,32,64,16,2,4,25.97874145507813 +sdd,True,True,32,32,128,8,16,4,23.706874084472656 +sdd,True,True,32,64,8,4,8,2,100.84080810546877 +sdd,True,True,32,64,16,8,8,2,60.23905029296875 +sdd,True,True,32,64,32,16,8,2,41.790463256835935 +sdd,True,True,32,64,64,8,2,4,36.65054626464844 +sdd,True,True,32,64,128,8,8,4,33.875634765625 +sdd,True,True,32,128,8,2,2,2,inf +sdd,True,True,32,128,16,2,2,2,inf +sdd,True,True,32,128,32,2,2,2,inf +sdd,True,True,32,128,64,2,2,2,inf +sdd,True,True,32,128,128,2,2,2,inf +sdd,True,True,64,8,8,16,2,2,160.8046630859375 +sdd,True,True,64,8,16,8,2,4,50.83729248046875 +sdd,True,True,64,8,32,8,4,8,30.98486633300781 +sdd,True,True,64,8,64,16,2,8,22.053683471679687 +sdd,True,True,64,8,128,16,2,4,18.328111267089845 +sdd,True,True,64,16,8,8,16,2,99.50169677734377 +sdd,True,True,64,16,16,16,8,2,51.96105346679688 +sdd,True,True,64,16,32,16,8,4,31.01470642089844 +sdd,True,True,64,16,64,16,4,8,22.368418884277343 +sdd,True,True,64,16,128,16,2,8,18.470967102050786 +sdd,True,True,64,32,8,8,4,2,117.6858642578125 +sdd,True,True,64,32,16,16,4,2,61.45115966796875 +sdd,True,True,64,32,32,16,2,2,35.50120239257812 +sdd,True,True,64,32,64,16,8,4,24.213912963867188 +sdd,True,True,64,32,128,8,16,8,20.149658203125 +sdd,True,True,64,64,8,2,2,2,inf +sdd,True,True,64,64,16,2,2,2,inf +sdd,True,True,64,64,32,2,2,2,inf +sdd,True,True,64,64,64,2,2,2,inf +sdd,True,True,64,64,128,2,2,2,inf +sdd,True,True,64,128,8,2,2,2,inf +sdd,True,True,64,128,16,2,2,2,inf +sdd,True,True,64,128,32,2,2,2,inf +sdd,True,True,64,128,64,2,2,2,inf +sdd,True,True,64,128,128,2,2,2,inf +sdd,True,True,128,8,8,2,2,2,inf +sdd,True,True,128,8,16,16,4,4,80.0758056640625 +sdd,True,True,128,8,32,8,4,8,44.93201599121094 +sdd,True,True,128,8,64,16,4,8,28.457400512695312 +sdd,True,True,128,8,128,16,4,8,20.684226989746094 +sdd,True,True,128,16,8,16,2,2,163.924169921875 +sdd,True,True,128,16,16,16,4,2,84.27235717773438 +sdd,True,True,128,16,32,16,16,4,45.40900268554688 +sdd,True,True,128,16,64,16,2,8,28.49513854980469 +sdd,True,True,128,16,128,16,4,8,20.79976348876953 +sdd,True,True,128,32,8,2,2,2,inf +sdd,True,True,128,32,16,2,2,2,inf +sdd,True,True,128,32,32,2,2,2,inf +sdd,True,True,128,32,64,2,2,2,inf +sdd,True,True,128,32,128,2,2,2,inf +sdd,True,True,128,64,8,2,2,2,inf +sdd,True,True,128,64,16,2,2,2,inf +sdd,True,True,128,64,32,2,2,2,inf +sdd,True,True,128,64,64,2,2,2,inf +sdd,True,True,128,64,128,2,2,2,inf +sdd,True,True,128,128,8,2,2,2,inf +sdd,True,True,128,128,16,2,2,2,inf +sdd,True,True,128,128,32,2,2,2,inf +sdd,True,True,128,128,64,2,2,2,inf 
+sdd,True,True,128,128,128,2,2,2,inf diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.75.csv b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.75.csv new file mode 100644 index 00000000..c9c0d8ec --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.75.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,1,2.932566452026367 +8,16,1,1.9577472686767576 +8,32,8,1.262806415557861 +8,64,4,1.323628807067871 +8,128,8,1.3792127609252929 +16,8,1,2.8719263076782227 +16,16,1,2.08015365600586 +16,32,16,1.258345603942871 +16,64,16,1.326899242401123 +16,128,16,1.3656064033508302 +32,8,1,3.0519296646118166 +32,16,1,1.8957759857177732 +32,32,16,1.257376003265381 +32,64,16,1.3052191734313965 +32,128,16,1.350175952911377 +64,8,1,2.818243217468262 +64,16,1,1.8339935302734376 +64,32,16,1.2456512451171875 +64,64,16,1.295359992980957 +64,128,16,1.386297607421875 +128,8,1,2.712630462646485 +128,16,1,1.7739871978759765 +128,32,16,1.2388575553894043 +128,64,16,1.31909122467041 +128,128,8,1.383884811401367 diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.75.csv b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.75.csv new file mode 100644 index 00000000..6c6f258d --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.75.csv @@ -0,0 +1,26 @@ +BH,BW,RT,latency +8,8,8,1.47172155380249 +8,16,8,1.378435230255127 +8,32,8,1.3317407608032226 +8,64,4,1.3999648094177246 +8,128,8,1.4530559539794925 +16,8,16,1.4786432266235352 +16,16,16,1.3883392333984377 +16,32,8,1.337548828125 +16,64,4,1.4118751525878903 +16,128,8,1.435654354095459 +32,8,16,1.4958016395568847 +32,16,16,1.380352020263672 +32,32,8,1.332419204711914 +32,64,4,1.3878047943115237 +32,128,8,1.4331423759460449 +64,8,16,1.4967616081237791 +64,16,16,1.378217601776123 +64,32,8,1.3172800064086914 +64,64,4,1.3741184234619142 +64,128,4,1.4449919700622558 +128,8,16,1.5054783821105957 +128,16,16,1.3612863540649414 +128,32,8,1.3031231880187988 +128,64,4,1.3879072189331054 +128,128,8,1.4422240257263184 From 530e530894fd65c25e873107d22c7cc426ad72f2 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Mon, 6 Feb 2023 05:25:29 +0000 Subject: [PATCH 11/28] fix kernel.set_parameters() --- sparta/specializer/kernels/matmul.py | 2 +- sparta/specializer/kernels/softmax.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 18da4f15..00bf6190 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -268,8 +268,8 @@ def set_parameters(self, params: Dict[str, Any]): BN_filter = self._lut['BN'] == BN row = self._lut[BM_filter & BK_filter & BN_filter] assert len(row) > 0, f'block shape ({BM}, {BK}, {BN}) not found in LUT' - assert float(row['latency']) < float('inf'), f'block shape ({BM}, {BK}, {BN}) is invalid' row = row.reset_index(drop=True).iloc[0, :] + assert float(row['latency']) < float('inf'), f'block shape ({BM}, {BK}, {BN}) is invalid' TM, TK, TN = row['TM'], row['TK'], row['TN'] self.set_parameter('THREAD_SIZE_M_VALUE', int(TM)) self.set_parameter('THREAD_SIZE_K_VALUE', int(TK)) diff --git a/sparta/specializer/kernels/softmax.py b/sparta/specializer/kernels/softmax.py index 1181e930..5f90a481 100644 --- a/sparta/specializer/kernels/softmax.py +++ b/sparta/specializer/kernels/softmax.py @@ -214,8 +214,8 @@ def set_parameters(self, params: Dict[str, Any]): BW_filter = self._lut['BW'] == BW row = 
self._lut[BH_filter & BW_filter] assert len(row) > 0, f'block shape ({BH}, {BW}) not found in LUT' - assert row['latency'] < float('inf'), f'block shape ({BH}, {BW}) is invalid' row = row.reset_index(drop=True).iloc[0, :] + assert float(row['latency']) < float('inf'), f'block shape ({BH}, {BW}) is invalid' self.set_parameter('ROW_TILE_VALUE', int(row['RT'])) def blocks_per_grid(self): From 66f1748826381c31c5300f36447a3b2cacd495f7 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 21 Feb 2023 17:02:10 +0800 Subject: [PATCH 12/28] refactoring: functional --- docs/1-code-specializer.md | 4 +- sparta/nn/module_tuner.py | 2 +- sparta/specializer/functional/__init__.py | 6 + sparta/specializer/functional/batch_matmul.py | 376 ++++++++++++++++++ .../specializer/functional/batch_softmax.py | 188 +++++++++ .../specializer/functional/function_base.py | 201 ++++++++++ sparta/specializer/funtional/__init__.py | 6 - sparta/specializer/funtional/batch_matmul.py | 179 --------- sparta/specializer/funtional/batch_softmax.py | 103 ----- .../specializer/funtional/sparse_ctx_base.py | 157 -------- sparta/specializer/kernels/__init__.py | 2 +- sparta/specializer/kernels/kernel_base.py | 105 +---- .../look_up_tables/matmul.openai.61.csv | 13 + .../look_up_tables/matmul.openai.default.csv | 13 + sparta/specializer/kernels/matmul.py | 297 ++++---------- sparta/specializer/kernels/softmax.py | 273 +++++-------- .../templates/openai_sparse_matmul_dds.cuh.j2 | 2 + .../templates/openai_sparse_matmul_dsd.cuh.j2 | 2 + .../templates/openai_sparse_matmul_sdd.cuh.j2 | 2 + .../templates/sparta_sparse_matmul_dds.cuh.j2 | 2 + .../templates/sparta_sparse_matmul_dsd.cuh.j2 | 2 + .../templates/sparta_sparse_matmul_sdd.cuh.j2 | 2 + .../sparta_sparse_softmax_backward.cuh.j2 | 25 +- .../sparta_sparse_softmax_forward.cuh.j2 | 29 +- sparta/specializer/operators/operator_base.py | 2 +- .../seqlen_dynamic_sparse_attention.py | 2 +- .../specializer/operators/sparse_attention.py | 2 +- sparta/specializer/operators/sparse_linear.py | 4 +- sparta/specializer/operators/sparse_matmul.py | 4 +- sparta/specializer/operators/sparse_moe.py | 2 +- .../specializer/operators/sparse_softmax.py | 4 +- test/lut_maker/matmul.py | 127 ++++++ test/lut_maker/softmax.py | 106 +++++ test/lut_maker/sparta_matmul.py | 124 ------ test/lut_maker/sparta_softmax.py | 122 ------ test/unit/test_sparse_matmul.py | 356 +++++++++-------- test/unit/test_sparse_softmax.py | 172 ++++---- 37 files changed, 1572 insertions(+), 1446 deletions(-) create mode 100644 sparta/specializer/functional/__init__.py create mode 100644 sparta/specializer/functional/batch_matmul.py create mode 100644 sparta/specializer/functional/batch_softmax.py create mode 100644 sparta/specializer/functional/function_base.py delete mode 100644 sparta/specializer/funtional/__init__.py delete mode 100644 sparta/specializer/funtional/batch_matmul.py delete mode 100644 sparta/specializer/funtional/batch_softmax.py delete mode 100644 sparta/specializer/funtional/sparse_ctx_base.py create mode 100644 sparta/specializer/kernels/look_up_tables/matmul.openai.61.csv create mode 100644 sparta/specializer/kernels/look_up_tables/matmul.openai.default.csv create mode 100644 test/lut_maker/matmul.py create mode 100644 test/lut_maker/softmax.py delete mode 100644 test/lut_maker/sparta_matmul.py delete mode 100644 test/lut_maker/sparta_softmax.py diff --git a/docs/1-code-specializer.md b/docs/1-code-specializer.md index 7cccd8f1..a2ff27c0 100644 --- a/docs/1-code-specializer.md +++ 
b/docs/1-code-specializer.md @@ -11,8 +11,8 @@ To balance between the flexibility, performance, and developing efficiency, we a | Layer | Base Class | Role | | :- | :- | :- | | Sparse Operator | [`sparta.nn.OperatorBase`](reference/nn.rst) | User interface as `torch.nn.Module` | -| Sparse Context | `sparta.specializer.funtional.SparseCtxBase` | Function context to interact with `torch.autograd.Function` | -| Sparse Kernel Placeholder | `sparta.specializer.funtional.KernelPlaceholder` | Collection of multiple kernel implementations | +| Sparse Context | `sparta.specializer.functional.SparseCtxBase` | Function context to interact with `torch.autograd.Function` | +| Sparse Kernel Placeholder | `sparta.specializer.functional.KernelPlaceholder` | Collection of multiple kernel implementations | | Sparse Kernel | `sparta.specializer.kernels.KernelBase` | Tunable sparse CUDA kernel interface | ## Generating CUDA Codes diff --git a/sparta/nn/module_tuner.py b/sparta/nn/module_tuner.py index fa851a1e..122b36d1 100644 --- a/sparta/nn/module_tuner.py +++ b/sparta/nn/module_tuner.py @@ -76,7 +76,7 @@ def lower_search(upper_idx: int, upper_params: Dict[Any, Any]): def try_params(lower_idx: int, params: Dict[Any, Any]): try: kernel.build(params) - latency = kernel.test() + latency = kernel.profile() except AssertionError: latency = np.inf _logger.info(f'{impl} #{lower_idx}: {list(params.values())} => {latency} ms') diff --git a/sparta/specializer/functional/__init__.py b/sparta/specializer/functional/__init__.py new file mode 100644 index 00000000..f6bc7d35 --- /dev/null +++ b/sparta/specializer/functional/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from sparta.specializer.functional.function_base import Port, SparsityAttr, SparseFunctionBase +from sparta.specializer.functional.batch_matmul import SparseMatMul, SparseBatchMatMul +from sparta.specializer.functional.batch_softmax import SparseSoftmax, SparseBatchSoftmax diff --git a/sparta/specializer/functional/batch_matmul.py b/sparta/specializer/functional/batch_matmul.py new file mode 100644 index 00000000..c017d27b --- /dev/null +++ b/sparta/specializer/functional/batch_matmul.py @@ -0,0 +1,376 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from typing import Any, List, Dict, Tuple, Optional + +import torch +import numpy as np + +from sparta.specializer.kernels import KernelBase, SparTASparseMatMulKernel, OpenAISparseMatMulKernel +from sparta.specializer.functional.function_base import Port, SparsityAttr, SparseFunctionBase, SparseAutoGradFunction + + +class SparseBatchMatMulForward(SparseFunctionBase): + + __batched__ = True + + def __init__( + self, + mode: str, + transpose_A: bool, + transpose_B: bool, + biased: bool, + compressed: bool, + ): + if mode not in ['sdd', 'dsd', 'dds']: + raise ValueError(f'invalid sparse matmul mode: {mode}') + + super().__init__() + + self._transpose_A = transpose_A + self._transpose_B = transpose_B + self._biased = biased + self._compressed = compressed + + self._sparse_axis = { + 'sdd': ['K', 'M'] if transpose_A else ['M', 'K'], + 'dsd': ['N', 'K'] if transpose_B else ['K', 'N'], + 'dds': ['M', 'N'], + }[mode] + self._BCSR = { + 'sdd': not transpose_A, + 'dsd': transpose_B, + 'dds': True, + }[mode] + + self._sparse_port = 'ABC'[mode.find('s')] + self.ports['A'] = Port(self, 'A') + self.ports['B'] = Port(self, 'B') + self.ports['C'] = Port(self, 'C', fine_mask=False) # DDS known issue + if biased: + self.ports['bias'] = Port(self, 'bias') + self.ports[self._sparse_port].attr = SparsityAttr(self._BCSR, not self._BCSR) + + specs = { + 'mode': mode, + 'biased': biased, + 'transpose_A': transpose_A, + 'transpose_B': transpose_B, + 'compressed': compressed, + 'batched': self.__batched__, + } + self.kernels['forward'] = { + 'sparta': SparTASparseMatMulKernel(**specs), + 'openai': OpenAISparseMatMulKernel(**specs), + } + + self.shape: Tuple[int, int, int, int] = None + + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + # TODO: check shape conflicts + if self.__batched__: + batch_size = sample_inputs[0].shape[0] + else: + batch_size = 1 + if self._transpose_A: + K, M = sample_inputs[0].shape[-2:] + else: + M, K = sample_inputs[0].shape[-2:] + if self._transpose_B: + N, K = sample_inputs[1].shape[-2:] + else: + K, N = sample_inputs[1].shape[-2:] + self.shape = (batch_size, M, K, N) + self.ports['A'].set_data(sample_inputs[0]) + self.ports['B'].set_data(sample_inputs[1]) + if self._biased: + self.ports['bias'].set_data(sample_inputs[2]) + + def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): + sparse_attr = self.get_sparse_attr() + kernel.set_parameter('BCSR', self._BCSR or sparse_attr.BCSR) + kernel.set_parameter('BCSC', not self._BCSR) + block_size = [params[f'BLOCK_SIZE_{axis}_VALUE'] for axis in self._sparse_axis] + sparse_attr.set_block_size(*block_size) + kernel.compile(params, self.shape, sparse_attr) + + def _set_forward(self): + if 'forward' in self._compiled_kernels: + self.forward = self._compiled_kernels['forward'] + + def _kernel_func_call(self, kernel_name: str): + A = self.ports['A'].get_data(compressed=self._compressed) + B = self.ports['B'].get_data(compressed=self._compressed) + kernel = self._compiled_kernels[kernel_name] + if self._biased: + bias = self.ports['bias'].get_data() + return lambda : kernel(A, B, bias) + else: + return lambda : kernel(A, B) + + def _kernel_reference(self, kernel_name: str): + return self.ports['C'].get_data(compressed=self._compressed) + + def _calc_kernel_flops(self, kernel_name: str): + indexes = self.get_sparse_attr().indexes + sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num + return np.prod(self.shape) * sparse_rate + + def reference_forward(self, sample_inputs: 
Optional[List[torch.Tensor]] = None): + if sample_inputs is not None: + self._read_sample_inputs(sample_inputs) + A = self.ports['A'].get_data(compressed=False) + B = self.ports['B'].get_data(compressed=False) + if self._transpose_A: + A = A.swapaxes(self.__batched__ + 0, self.__batched__ + 1) + if self._transpose_B: + B = B.swapaxes(self.__batched__ + 0, self.__batched__ + 1) + C = torch.bmm(A, B) if self.__batched__ else torch.mm(A, B) + if self._biased: + bias = self.ports['bias'].get_data() + C += bias.unsqueeze(1) if self.__batched__ else bias + self.ports['C'].set_data(C) + + +class SparseMatMulForward(SparseBatchMatMulForward): + + __batched__ = False + + +class SparseBatchMatMulBackward(SparseFunctionBase): + + __batched__ = True + + def __init__( + self, + mode: str, + transpose_A: bool, + transpose_B: bool, + biased: bool, + compressed: bool, + ports: Dict[str, Port], + ): + if mode not in ['sdd', 'dsd', 'dds']: + raise ValueError(f'invalid sparse matmul mode: {mode}') + + super().__init__() + + self._mode = mode + self._transpose_A = transpose_A + self._transpose_B = transpose_B + self._biased = biased + self._compressed = compressed + + self.ports = ports + self._sparse_port = 'ABC'[mode.find('s')] + + self._BCSR = { + 'backward:A': { + 'sdd': True, + 'dsd': not transpose_B, + 'dds': True, + }[mode], + 'backward:B': { + 'sdd': transpose_A, + 'dsd': True, + 'dds': False, + }[mode], + } + self.get_sparse_attr().update_axis(True, True) + + A_spec = { + 'mode': ''.join(mode[i] for i in ([1, 2, 0] if transpose_A else [2, 1, 0])), + 'biased': False, + 'transpose_A': transpose_A and transpose_B, + 'transpose_B': transpose_A or not transpose_B, + 'compressed': compressed, + 'batched': self.__batched__, + } + B_spec = { + 'mode': ''.join(mode[i] for i in ([2, 0, 1] if transpose_B else [0, 2, 1])), + 'biased': False, + 'transpose_A': not transpose_A or transpose_B, + 'transpose_B': transpose_A and transpose_B, + 'compressed': compressed, + 'batched': self.__batched__, + } + + self.kernels['backward:A'] = { + 'sparta': SparTASparseMatMulKernel(**A_spec), + 'openai': OpenAISparseMatMulKernel(**A_spec), + } + self.kernels['backward:B'] = { + 'sparta': SparTASparseMatMulKernel(**B_spec), + 'openai': OpenAISparseMatMulKernel(**B_spec), + } + + self.shape: Tuple[int, int, int, int] = None + + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + pass + + def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): + batch, M, K, N = self.shape + shape = { + 'backward:A': (batch, K, N, M) if self._transpose_A else (batch, M, N, K), + 'backward:B': (batch, N, M, K) if self._transpose_B else (batch, K, M, N), + }[kernel_name] + sparse_attr = self.get_sparse_attr() + kernel.set_parameter('BCSR', self._BCSR[kernel_name] or sparse_attr.BCSR) + kernel.set_parameter('BCSC', not self._BCSR[kernel_name]) + kernel.compile(params, shape, sparse_attr) + + def _set_forward(self): + if 'backward:A' in self._compiled_kernels: + kernel_A = self._compiled_kernels['backward:A'] + else: + kernel_A = lambda *inputs: None + if 'backward:B' in self._compiled_kernels: + kernel_B = self._compiled_kernels['backward:B'] + else: + kernel_B = lambda *inputs: None + if self._transpose_A: + backward_A = lambda grad_C, B: kernel_A(B, grad_C) + else: + backward_A = lambda grad_C, B: kernel_A(grad_C, B) + if self._transpose_B: + backward_B = lambda grad_C, A: kernel_B(grad_C, A) + else: + backward_B = lambda grad_C, A: kernel_B(A, grad_C) + if self._mode == 'dds' and self._compressed: 
+            C_indexes = self.ports['C'].attr.indexes
+            backward_bias = lambda grad_C: C_indexes.sum(grad_C, axis=-2)
+        else:
+            backward_bias = lambda grad_C: grad_C.sum(-2)
+
+        def backward(grad, A, B, needs_grad):
+            grad_A, grad_B, grad_bias = None, None, None
+            if needs_grad[1]:
+                grad_A = backward_A(grad, B)
+            if needs_grad[2]:
+                grad_B = backward_B(grad, A)
+            if self._biased and needs_grad[3]:
+                grad_bias = backward_bias(grad)
+            return grad_A, grad_B, grad_bias
+
+        self.forward = backward
+
+    def _kernel_func_call(self, kernel_name: str):
+        grad_C = self.ports['C'].get_data(grad=True, compressed=self._compressed)
+        kernel = self._compiled_kernels[kernel_name]
+        if kernel_name == 'backward:A':
+            B = self.ports['B'].get_data(compressed=self._compressed)
+            if self._transpose_A:
+                return lambda : kernel(B, grad_C)
+            else:
+                return lambda : kernel(grad_C, B)
+        elif kernel_name == 'backward:B':
+            A = self.ports['A'].get_data(compressed=self._compressed)
+            if self._transpose_B:
+                return lambda : kernel(grad_C, A)
+            else:
+                return lambda : kernel(A, grad_C)
+        else:
+            raise ValueError(f'kernel not found: {kernel_name}')
+
+    def _kernel_reference(self, kernel_name: str):
+        if kernel_name == 'backward:A':
+            return self.ports['A'].get_data(grad=True, compressed=False)
+        elif kernel_name == 'backward:B':
+            return self.ports['B'].get_data(grad=True, compressed=False)
+        else:
+            raise ValueError(f'kernel not found: {kernel_name}')
+
+    def _calc_kernel_flops(self, kernel_name: str):
+        indexes = self.get_sparse_attr().indexes
+        sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num
+        return np.prod(self.shape) * sparse_rate
+
+    def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None):
+        pass
+
+
+class SparseMatMulBackward(SparseBatchMatMulBackward):
+
+    __batched__ = False
+
+
+class _SparseMatMul(torch.autograd.Function):
+
+    @staticmethod
+    def forward(
+        ctx: torch.autograd.function.FunctionCtx,
+        func: SparseAutoGradFunction,
+        *inputs,
+    ):
+        ctx.save_for_backward(inputs[0], inputs[1])
+        ctx.backward = func.backward
+        return func.forward(*inputs)
+
+    @staticmethod
+    def backward(ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor):
+        A, B = ctx.saved_tensors
+        return None, *ctx.backward(grad, A, B, ctx.needs_input_grad)
+
+
+class SparseBatchMatMul(SparseAutoGradFunction, SparseBatchMatMulForward):
+
+    __static_func__ = _SparseMatMul
+
+    def __init__(
+        self,
+        mode: str,
+        transpose_A: bool,
+        transpose_B: bool,
+        biased: bool,
+        compressed: bool,
+    ):
+        super().__init__(mode, transpose_A, transpose_B, biased, compressed)
+        self.backward = SparseBatchMatMulBackward(
+            mode=mode,
+            transpose_A=transpose_A,
+            transpose_B=transpose_B,
+            biased=biased,
+            compressed=compressed,
+            ports=self.ports,
+        )
+
+    def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]):
+        super()._read_sample_inputs(sample_inputs)
+        self.backward.shape = self.shape
+
+    def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None):
+        self.ports['C'].set_data(sample_grads[0], grad=True)
+        self.ports['C'].get_data().backward(sample_grads[0])
+
+
+class SparseMatMul(SparseAutoGradFunction, SparseMatMulForward):
+
+    __static_func__ = _SparseMatMul
+
+    def __init__(
+        self,
+        mode: str,
+        transpose_A: bool,
+        transpose_B: bool,
+        biased: bool,
+        compressed: bool,
+    ):
+        super().__init__(mode, transpose_A, transpose_B, biased, compressed)
+        self.backward = SparseMatMulBackward(
+            mode=mode,
+            transpose_A=transpose_A,
+            transpose_B=transpose_B,
+            biased=biased,
+            compressed=compressed,
+            ports=self.ports,
+        )
+
+    def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]):
+        super()._read_sample_inputs(sample_inputs)
+        self.backward.shape = self.shape
+
+    def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None):
+        self.ports['C'].set_data(sample_grads[0], grad=True)
+        self.ports['C'].get_data().backward(sample_grads[0])
+
diff --git a/sparta/specializer/functional/batch_softmax.py b/sparta/specializer/functional/batch_softmax.py
new file mode 100644
index 00000000..d92457ed
--- /dev/null
+++ b/sparta/specializer/functional/batch_softmax.py
@@ -0,0 +1,188 @@
+# Copyright (c) Microsoft Corporation.
+# Licensed under the MIT license.
+
+from typing import Any, List, Dict, Tuple, Optional
+
+import torch
+import numpy as np
+
+from sparta.specializer.kernels import KernelBase, SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel
+from sparta.specializer.functional.function_base import Port, SparsityAttr, SparseFunctionBase, SparseAutoGradFunction
+from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference
+
+class SparseBatchSoftmaxForward(SparseFunctionBase):
+
+    __batched__ = True
+    __direction__ = 'forward'
+
+    def __init__(self, compressed: bool, temperature: float = 1):
+        super().__init__()
+
+        self._compressed = compressed
+        self._T = np.float32(1 / temperature)
+
+        self._sparse_port = 'y'
+        sparse_attr = SparsityAttr(True, False)
+        for port_name in ['x', 'y']:
+            self.ports[port_name] = Port(self, port_name)
+            self.ports[port_name].attr = sparse_attr
+
+        self.kernels[self.__direction__] = {
+            'sparta': {
+                'forward': SparTASparseSoftmaxForwardKernel,
+                'backward': SparTASparseSoftmaxBackwardKernel,
+            }[self.__direction__](
+                compressed=compressed,
+                batched=self.__batched__,
+            ),
+        }
+
+        self.shape: Tuple[int, int, int] = None
+
+    def set_temperature(self, temperature: float):
+        self._T = np.float32(1 / temperature)
+
+    def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]):
+        x = sample_inputs[0]
+        if self.__batched__:
+            batch_size = x.shape[0]
+        else:
+            batch_size = 1
+        H, W = x.shape[-2:]
+        self.shape = (batch_size, H, W)
+        self.ports['x'].set_data(x)
+
+    def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]):
+        sparse_attr = self.get_sparse_attr()
+        sparse_attr.set_block_size(
+            block_H=params['BLOCK_SIZE_H_VALUE'],
+            block_W=params['BLOCK_SIZE_W_VALUE'],
+        )
+        kernel.set_parameter('MAX_W_VALUE', self.shape[-1])
+        kernel.compile(params, self.shape, sparse_attr)
+
+    def _set_forward(self):
+        if self.__direction__ in self._compiled_kernels:
+            kernel = self._compiled_kernels[self.__direction__]
+            sparse_attr = self.get_sparse_attr()
+            self.forward = lambda *inputs: kernel(*inputs, sparse_attr.mask, self._T)
+
+    def _kernel_func_call(self, kernel_name: str):
+        x = self.ports['x'].get_data(compressed=self._compressed)
+        sparse_attr = self.get_sparse_attr()
+        kernel = self._compiled_kernels[kernel_name]
+        return lambda : kernel(x, sparse_attr.mask, self._T)
+
+    def _kernel_reference(self, kernel_name: str):
+        return self.ports['y'].get_data(compressed=self._compressed)
+
+    def _calc_kernel_flops(self, kernel_name: str):
+        indexes = self.get_sparse_attr().indexes
+        sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num
+        return np.prod(self.shape) * sparse_rate
+
+    def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None):
+        if sample_inputs is not None:
+            self._read_sample_inputs(sample_inputs)
+ x = self.ports['x'].get_data(compressed=False) + mask = self.get_sparse_attr().mask + y = sparse_softmax_forward_reference(x, mask, 1 / self._T) + self.ports['y'].set_data(y) + + +class SparseSoftmaxForward(SparseBatchSoftmaxForward): + + __batched__ = False + + +class SparseBatchSoftmaxBackward(SparseBatchSoftmaxForward): + + __batched__ = True + __direction__ = 'backward' + + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + pass + + def _kernel_func_call(self, kernel_name: str): + gy = self.ports['y'].get_data(grad=True, compressed=self._compressed) + y = self.ports['y'].get_data(grad=False, compressed=self._compressed) + sparse_attr = self.get_sparse_attr() + kernel = self._compiled_kernels[kernel_name] + return lambda : kernel(gy, y, sparse_attr.mask, self._T) + + def _kernel_reference(self, kernel_name: str): + return self.ports['x'].get_data(grad=True, compressed=self._compressed) + + def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): + pass + + +class SparseSoftmaxBackward(SparseBatchSoftmaxBackward): + + __batched__ = False + + +class _SparseSoftmax(torch.autograd.Function): + + @staticmethod + def forward( + ctx: torch.autograd.function.FunctionCtx, + func: SparseAutoGradFunction, + x: torch.Tensor, + ): + ctx.backward = func.backward + y = func.forward(x) + ctx.save_for_backward(y) + return y + + @staticmethod + def backward(ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor): + y = ctx.saved_tensors[0] + if ctx.needs_input_grad[1]: + return None, ctx.backward(grad, y) + else: + return None, None + + +class SparseBatchSoftmax(SparseAutoGradFunction, SparseBatchSoftmaxForward): + + __static_func__ = _SparseSoftmax + + def __init__(self, compressed: bool, temperature: float = 1): + super().__init__(compressed, temperature) + self.backward = SparseBatchSoftmaxBackward(compressed, temperature) + self.backward.ports = self.ports + + def set_temperature(self, temperature: float): + super().set_temperature(temperature) + self.backward.set_temperature(temperature) + + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + super()._read_sample_inputs(sample_inputs) + self.backward.shape = self.shape + + def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): + self.ports['y'].set_data(sample_grads[0], grad=True) + self.ports['y'].get_data().backward(sample_grads[0]) + + +class SparseSoftmax(SparseAutoGradFunction, SparseSoftmaxForward): + + __static_func__ = _SparseSoftmax + + def __init__(self, compressed: bool, temperature: float = 1): + super().__init__(compressed, temperature) + self.backward = SparseSoftmaxBackward(compressed, temperature) + self.backward.ports = self.ports + + def set_temperature(self, temperature: float): + super().set_temperature(temperature) + self.backward.set_temperature(temperature) + + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + super()._read_sample_inputs(sample_inputs) + self.backward.shape = self.shape + + def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): + self.ports['y'].set_data(sample_grads[0], grad=True) + self.ports['y'].get_data().backward(sample_grads[0]) diff --git a/sparta/specializer/functional/function_base.py b/sparta/specializer/functional/function_base.py new file mode 100644 index 00000000..77dd9130 --- /dev/null +++ b/sparta/specializer/functional/function_base.py @@ -0,0 +1,201 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from __future__ import annotations + +import abc +from typing import Any, Dict, List, Callable, Optional + +import torch + +from sparta.tesa import get_bcs_function, BCSIndexes +from sparta.specializer.kernels import KernelBase +from sparta.testing import profile + + +class SparsityAttr(object): + + def __init__(self, BCSR: bool, BCSC: bool): + self.BCSR = BCSR + self.BCSC = BCSC + self.mask: torch.Tensor = None + self._block_H: int = 0 + self._block_W: int = 0 + self.indexes: BCSIndexes = None + + def update_axis(self, BCSR: bool, BCSC: bool): + self.BCSR |= BCSR + self.BCSC |= BCSC + + def set_block_size(self, block_H: int, block_W: int): + if block_H != self._block_H or block_W != self._block_W: + self._block_H = block_H + self._block_W = block_W + self._update_indexes() + + def set_mask(self, mask: torch.Tensor): + self.mask = mask + self._update_indexes() + + def _update_indexes(self): + if self._block_H > 0 and self._block_W > 0 and self.mask is not None: + self.indexes = get_bcs_function( + self._block_H, self._block_W, + self.BCSR, self.BCSC, + ).build_indexes(self.mask) + + +class Port(object): + + def __init__(self, func: SparseFunctionBase, name: str, fine_mask: bool = True): + self.name = name + self.funcs: List[SparseFunctionBase] = [func] + self.attr: SparsityAttr = None + self._sample_data: torch.Tensor = None # Always dense + self._fine_mask = fine_mask + + def set_data(self, data: torch.Tensor, grad: bool = False): + if grad: + self._sample_data.grad = data + else: + self._sample_data = data + + def get_data(self, grad: bool = False, compressed: bool = False): + data: torch.Tensor = self._sample_data.grad if grad else self._sample_data + if self.attr is not None and data is not None: + if self._fine_mask: + data = data * self.attr.mask + if compressed: + data = self.attr.indexes.convert(data.detach()) + elif not self._fine_mask: + data = data * self.attr.indexes.get_mask() + return data + + def clear_data(self): + self._sample_data = None + + def connect(self, other: Port): + for func in other.funcs: + func.ports[other.name] = self + self.funcs.append(func) + if self.attr is not None and other.attr is not None: + self.attr.update_axis(other.attr.BCSR, other.attr.BCSC) + + +class SparseFunctionBase(Callable): + + def __init__(self): + self.kernels: Dict[str, Dict[str, KernelBase]] = {} + self._compiled_kernels: Dict[str, KernelBase] = {} + self.ports: Dict[str, Port] = {} + self._sparse_port: str = '' + self.forward: Callable = None + self.backward: SparseFunctionBase = None + + def get_sparse_attr(self): + return self.ports[self._sparse_port].attr + + def __call__(self, *inputs): + return self.forward(*inputs) + + @abc.abstractmethod + def _set_forward(self): + """Build forward function with compiled kernels.""" + + @abc.abstractmethod + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + """Get shape parameters from sample inputs.""" + + @abc.abstractmethod + def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): + """Compile kernel with params, shapes and sparse indexes.""" + + def build( + self, + config: Dict[str, Dict[str, Any]], + sample_inputs: Optional[List[torch.Tensor]] = None, + ): + if sample_inputs is not None: + self._read_sample_inputs(sample_inputs) + self._compiled_kernels: Dict[str, KernelBase] = {} + for kernel_name, params in config.items(): + if kernel_name in self.kernels: + kernel = self.kernels[kernel_name][params['_impl']] + self._compile_kernel(kernel_name, kernel, params) + 
self._compiled_kernels[kernel_name] = kernel + self._set_forward() + + def clear_sample_data(self): + for port in self.ports.values(): + port.clear_data() + + @abc.abstractmethod + def _kernel_reference(self, kernel_name: str) -> torch.Tensor: + """Get kernel reference output from related port(s).""" + + @abc.abstractmethod + def _kernel_func_call(self, kernel_name: str) -> Callable[[], torch.Tensor]: + """Callable kernel function based on sample data of ports.""" + + def profile_kernel( + self, + kernel_name: str, + num_warmups: int = 20, + num_iters: int = 100, + cuda: bool = False, + ): + """Profile kernel latency. Note that all inputs and outputs are dense tensors here.""" + kernel_func = self._kernel_func_call(kernel_name) + target_output = self._kernel_reference(kernel_name) + return profile(kernel_func, [], [target_output], num_warmups, num_iters, cuda) + + @abc.abstractmethod + def _calc_kernel_flops(self, kernel_name: str): + """Calculate kernel flops using sparse rate and shape.""" + + def estimate_kernel(self, kernel_name: str): + kernel = self._compiled_kernels[kernel_name] + flops = self._calc_kernel_flops(kernel_name) + return kernel.estimated_latency_per_flop * flops + + @abc.abstractmethod + def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): + """Read input data from input port(s) and set output data to output port(s).""" + + +class SparseAutoGradFunction(SparseFunctionBase): + + __static_func__: torch.autograd.Function = None + + def __call__(self, *inputs): + return self.__static_func__.apply(self, *inputs) + + def build( + self, + config: Dict[str, Dict[str, Any]], + sample_inputs: Optional[List[torch.Tensor]] = None, + ): + super().build(config, sample_inputs) + self.backward.build(config, sample_inputs) + + def profile_kernel( + self, + kernel_name: str, + num_warmups: int = 20, + num_iters: int = 100, + cuda: bool = False + ): + if kernel_name in self.kernels: + return super().profile_kernel(kernel_name, num_warmups, num_iters, cuda) + elif self.backward is not None: + return self.backward.profile_kernel(kernel_name, num_warmups, num_iters, cuda) + + def estimate_kernel(self, kernel_name: str): + if kernel_name in self.kernels: + return super().estimate_kernel(kernel_name) + elif self.backward is not None: + return self.backward.estimate_kernel(kernel_name) + + @abc.abstractmethod + def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): + """Read grad data from output port(s) and backward by auto-grad.""" diff --git a/sparta/specializer/funtional/__init__.py b/sparta/specializer/funtional/__init__.py deleted file mode 100644 index 186a1919..00000000 --- a/sparta/specializer/funtional/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from sparta.specializer.funtional.sparse_ctx_base import KernelPlaceholder, SparseCtxBase -from sparta.specializer.funtional.batch_matmul import SparseBatchMatMulCtx, SparseBatchMatMulFunc -from sparta.specializer.funtional.batch_softmax import SparseBatchSoftmaxCtx, SparseBatchSoftmaxFunc diff --git a/sparta/specializer/funtional/batch_matmul.py b/sparta/specializer/funtional/batch_matmul.py deleted file mode 100644 index 0756c7a2..00000000 --- a/sparta/specializer/funtional/batch_matmul.py +++ /dev/null @@ -1,179 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from typing import Any, List, Dict, Tuple, Optional - -import torch - -from sparta.specializer.kernels import SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from sparta.specializer.funtional import SparseCtxBase, KernelPlaceholder - - -class SparseBatchMatMulCtx(SparseCtxBase): - - def __init__( - self, - mode: str, - transpose_A: bool, - transpose_B: bool, - biased: bool, - compressed: bool, - ): - super().__init__() - - self._biased = biased - self._compressed = compressed - self._transpose_A = transpose_A - self._transpose_B = transpose_B - self._mode = mode - - def select(x: str, source: str, target: str): - return target[source.find(x)] - - def rearange(s: str, source_order: str, target_order: str): - return ''.join(select(x, source_order, s) for x in target_order) - - def calc_tesa_shape(mode: str, trans_A: bool, trans_B: bool): - if mode == 'sdd': - return ('K', 'M') if trans_A else ('M', 'K') - elif mode == 'dsd': - return ('N', 'K') if trans_B else ('K', 'N') - else: - return ('M', 'N') - - sparse_tensor = select('s', mode, 'ABC') - - self._tesa_shapes: Dict[str, Tuple[str, str]] = {} - for kernel_name, bias, target_order, trans_A, trans_B in zip( - ['forward:C', 'backward:A', 'backward:B'], - [biased, False, False], - ['ABC', 'BCA' if transpose_A else 'CBA', 'CAB' if transpose_B else 'ACB'], - [transpose_A, transpose_A and transpose_B, not transpose_A or transpose_B], - [transpose_B, transpose_A or not transpose_B, transpose_A and transpose_B], - ): - s_type = rearange(mode, 'ABC', target_order) - self._kernels[kernel_name] = KernelPlaceholder( - name=kernel_name, - impls={ - 'sparta': SparTASparseMatMulKernel, - 'openai': OpenAISparseMatMulKernel, - }, - args={ - 'biased': bias, - 'compressed': compressed, - 'transpose_A': trans_A, - 'transpose_B': trans_B, - 'mode': s_type, - }, - port_map={sparse_tensor: select(sparse_tensor, target_order, 'ABC')}, - connectable=compressed, - ) - self._tesa_shapes[kernel_name] = calc_tesa_shape(s_type, trans_A, trans_B) - - self._init_sparse_ports([sparse_tensor]) - - def set_shape(self, batch_size: int, M: int, K: int, N: int): - self._kernels['forward:C'].set_shape(batch_size, M, K, N) - if self._transpose_A: - self._kernels['backward:A'].set_shape(batch_size, K, N, M) - else: - self._kernels['backward:A'].set_shape(batch_size, M, N, K) - if self._transpose_B: - self._kernels['backward:B'].set_shape(batch_size, N, M, K) - else: - self._kernels['backward:B'].set_shape(batch_size, K, M, N) - - def build(self, config: Dict[str, Dict[str, Any]]): - super().build(config) - forward_kernel = self._kernels['forward:C'].active_kernel() - if forward_kernel is not None: - if self._biased: - self.forward_C = lambda A, B, bias: forward_kernel(A, B, bias) - else: - self.forward_C = lambda A, B: forward_kernel(A, B) - if self._mode == 'dds' and self._compressed: - C_indexes = self._kernels['forward:C'].active_kernel().ports['C'].indexes - self.backward_bias = lambda grad_C: C_indexes.sum(grad_C, axis=-2) - else: - self.backward_bias = lambda grad_C: grad_C.sum(-2) - backward_A_kernel = self._kernels['backward:A'].active_kernel() - if backward_A_kernel is not None: - if self._transpose_A: - self.backward_A = lambda grad_C, B: backward_A_kernel(B, grad_C) - else: - self.backward_A = lambda grad_C, B: backward_A_kernel(grad_C, B) - backward_B_kernel = self._kernels['backward:B'].active_kernel() - if backward_B_kernel is not None: - if self._transpose_B: - self.backward_B = lambda grad_C, A: backward_B_kernel(grad_C, A) - else: - self.backward_B = 
lambda grad_C, A: backward_B_kernel(A, grad_C) - - def set_sample_inputs( - self, - sample_inputs: List[torch.Tensor], - sample_grads: Optional[List[torch.Tensor]] = None, - ): - A = sample_inputs[0] - B = sample_inputs[1] - if self._biased: - bias = sample_inputs[2].detach() - self._kernels['forward:C'].set_sample_inputs([A, B, bias]) - else: - self._kernels['forward:C'].set_sample_inputs([A, B]) - if sample_grads is not None: - grad_C = sample_grads[0] - if self._transpose_A: - self._kernels['backward:A'].set_sample_inputs([B, grad_C]) - else: - self._kernels['backward:A'].set_sample_inputs([grad_C, B]) - if self._transpose_B: - self._kernels['backward:B'].set_sample_inputs([grad_C, A]) - else: - self._kernels['backward:B'].set_sample_inputs([A, grad_C]) - - def get_connections(self, backward: bool = False): - if self._compressed and backward: - conditions = [{}, {}] - for kernel_name, tesa_shapes in self._tesa_shapes.items(): - for k, dim in enumerate(tesa_shapes): - conditions[k][kernel_name] = f'BLOCK_SIZE_{dim}_VALUE' - return conditions - else: - return [] - - def dense_forward(self, *args): - return self._kernels['forward:C'].dense_func(*args) - - -class SparseBatchMatMulFunc(torch.autograd.Function): - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - sparta_ctx: SparseBatchMatMulCtx, - input: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - ): - input = input.detach() - weight = weight.detach() - ctx.save_for_backward(input, weight, bias) - ctx.sparta_ctx = sparta_ctx - if bias is None: - return sparta_ctx.forward_C(input, weight) - else: - return sparta_ctx.forward_C(input, weight, bias.detach()) - - @staticmethod - def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor): - input, weight, bias = ctx.saved_tensors - grad_input = grad_weight = grad_bias = None - grad_output = grad_output.detach() - if ctx.needs_input_grad[1]: - grad_input = ctx.sparta_ctx.backward_A(grad_output, weight) - if ctx.needs_input_grad[2]: - grad_weight = ctx.sparta_ctx.backward_B(grad_output, input) - if bias is not None and ctx.needs_input_grad[3]: - grad_bias = ctx.sparta_ctx.backward_bias(grad_output) - return None, grad_input, grad_weight, grad_bias diff --git a/sparta/specializer/funtional/batch_softmax.py b/sparta/specializer/funtional/batch_softmax.py deleted file mode 100644 index a4bed5a4..00000000 --- a/sparta/specializer/funtional/batch_softmax.py +++ /dev/null @@ -1,103 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from typing import Any, List, Dict, Optional - -import torch -import numpy as np - -from sparta.specializer.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.specializer.funtional import SparseCtxBase, KernelPlaceholder - - -class SparseBatchSoftmaxCtx(SparseCtxBase): - - def __init__(self, compressed: bool, temperature: float = 1): - super().__init__() - - self._compressed = compressed - self._T = np.float32(1 / temperature) - self._batch_size: int = None - - for kernel_name, kernel_class in zip( - ['forward:y', 'backward:x'], - [SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel], - ): - self._kernels[kernel_name] = KernelPlaceholder( - name=kernel_name, - impls={'sparta': kernel_class}, - args={'compressed': compressed}, - port_map={'y': 'y'}, - connectable=compressed, - ) - - self._init_sparse_ports(['y']) - - def set_temperature(self, temperature: float): - self._T = np.float32(1 / temperature) - - def set_shape(self, batch_size: int, H: int, W: int): - self._kernels['forward:y'].set_shape(batch_size, H, W) - self._kernels['backward:x'].set_shape(batch_size, H, W) - self._batch_size = batch_size - - def get_conditions(self, impls: Dict[str, str]): - if self._compressed and len(impls) > 1: - return [ - ['forward:y;BLOCK_SIZE_H_VALUE', 'backward:x;BLOCK_SIZE_H_VALUE'], - ['forward:y;BLOCK_SIZE_W_VALUE', 'backward:x;BLOCK_SIZE_W_VALUE'], - ] - else: - return [] - - def build(self, config: Dict[str, Dict[str, Any]]): - super().build(config) - forward_kernel = self._kernels['forward:y'].active_kernel() - if forward_kernel is not None: - self.forward = lambda x: forward_kernel(x, self._T) - backward_kernel = self._kernels['backward:x'].active_kernel() - if backward_kernel is not None: - self.backward = lambda grad, output: backward_kernel(grad, output, self._T) - - def set_sample_inputs( - self, - sample_inputs: List[torch.Tensor], - sample_grads: Optional[List[torch.Tensor]] = None, - ): - x = sample_inputs[0] - self._kernels['forward:y'].set_sample_inputs([x, self._T]) - if sample_grads is not None: - grad_y = sample_grads[0] - y = self._kernels['forward:y'].dense_func(x, self._T) - self._kernels['backward:x'].set_sample_inputs([grad_y, y, self._T]) - - def get_connections(self, backward: bool = False): - if self._compressed and backward: - return [ - {'forward:y': 'BLOCK_SIZE_H_VALUE', 'backward:x': 'BLOCK_SIZE_H_VALUE'}, - {'forward:y': 'BLOCK_SIZE_W_VALUE', 'backward:x': 'BLOCK_SIZE_W_VALUE'}, - ] - else: - return [] - - def dense_forward(self, *args): - return self._kernels['forward:y'].dense_func(*args, self._T) - - -class SparseBatchSoftmaxFunc(torch.autograd.Function): - - @staticmethod - def forward( - ctx: torch.autograd.function.FunctionCtx, - sparta_ctx: SparseBatchSoftmaxCtx, - x: torch.Tensor, - ): - ctx.sparta_ctx = sparta_ctx - output = sparta_ctx.forward(x.detach()) - ctx.save_for_backward(output) - return output - - @staticmethod - def backward(ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor): - output, = ctx.saved_tensors - return None, ctx.sparta_ctx.backward(grad.detach(), output.detach()) diff --git a/sparta/specializer/funtional/sparse_ctx_base.py b/sparta/specializer/funtional/sparse_ctx_base.py deleted file mode 100644 index f744f111..00000000 --- a/sparta/specializer/funtional/sparse_ctx_base.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -import abc -from typing import Any, Dict, List, Type, Optional - -import torch - -from sparta.specializer.kernels import KernelBase, PortConfig - - -class KernelPlaceholder(object): - - def __init__( - self, - name: str, - impls: Dict[str, Type[KernelBase]], - args: Dict[str, Any], - port_map: Dict[str, str], - connectable: bool, - ): - self.name = name - self.possible_kernels = {key: impl(**args) for key, impl in impls.items()} - self.port_map = port_map - self.connectable = connectable - self._active_impl: str = None - self.sample_inputs: List[torch.Tensor] = [] - self.dense_func = list(self.possible_kernels.values())[0].reference - self.ready = False - - def set_shape(self, *args, **kwargs): - if self._active_impl is None: - for kernel in self.possible_kernels.values(): - kernel.set_shape(*args, **kwargs) - else: - self.possible_kernels[self._active_impl].set_shape(*args, **kwargs) - - def select_impl(self, impl: str): - self._active_impl = impl - - def active_kernel(self): - if self._active_impl is None: - return None - else: - return self.possible_kernels[self._active_impl] - - def update_func(self): - self.active_kernel().update_func() - - def build(self, config: Dict[str, Any]): - self.active_kernel().compile(config) - self.ready = True - - def get_search_space(self, fixed_params: Optional[Dict[str, Any]] = None): - search_space = {} - for impl, kernel in self.possible_kernels.items(): - kernel_search_space = kernel.get_search_space(fixed_params) - if kernel_search_space is not None: - search_space[impl] = kernel_search_space - return search_space - - def set_sample_inputs(self, sample_inputs: List[torch.Tensor]): - self.sample_inputs = [ - x.detach() if type(x) is torch.Tensor else x - for x in sample_inputs - ] - - def test(self, num_warmups: int = 20, num_iters: int = 100, cuda: bool = False): - return self.active_kernel().test( - inputs=self.sample_inputs, - num_warmups=num_warmups, - num_iters=num_iters, - cuda=cuda, - ) - - -class SparseCtxBase(object): - - def __init__(self): - self._kernels: Dict[str, KernelPlaceholder] = {} - self.sparse_ports: Dict[str, List[PortConfig]] = {} - - def _init_sparse_ports(self, port_names: List[str]): - self.sparse_ports = { - port_name: [ - kernel.ports[kernel_placeholder.port_map[port_name]] - for kernel_placeholder in self._kernels.values() - for kernel in kernel_placeholder.possible_kernels.values() - ] - for port_name in port_names - } - - @abc.abstractmethod - def set_shape(self, *args, **kwargs): - """Set shape parameters.""" - - def select_impls(self, impls: Dict[str, str]): - connected_ports: Dict[str, PortConfig] = {} - self.sparse_ports = {port_name: [] for port_name in self.sparse_ports.keys()} - for kernel_name, kernel_impl in impls.items(): - kernel_placeholder = self._kernels[kernel_name] - kernel_placeholder.select_impl(kernel_impl) - for global_port_name, kernel_port_name in kernel_placeholder.port_map.items(): - kernel = kernel_placeholder.active_kernel() - if kernel_placeholder.connectable: - if global_port_name in connected_ports: - connected_ports[global_port_name].connect(kernel, kernel_port_name) - else: - connected_ports[global_port_name] = kernel.ports[kernel_port_name] - else: - self.sparse_ports[global_port_name].append(kernel.ports[kernel_port_name]) - for global_port_name, port in connected_ports.items(): - self.sparse_ports[global_port_name].append(port) - - def update_func(self): - for kernel_placeholder in self._kernels.values(): - kernel_placeholder.update_func() - - def build(self, config: Dict[str, 
Dict[str, Any]]): - for kernel_name, kernel_config in config.items(): - self._kernels[kernel_name].build(kernel_config) - - def get_sparse_indexes(self, port_name: str): - if port_name in self.sparse_ports: - return self.sparse_ports[port_name][0].indexes - else: - return None - - def get_kernel_placeholders(self, backward: bool = False): - return { - kernel_name: kernel - for kernel_name, kernel in self._kernels.items() - if backward or kernel_name.startswith('forward') - } - - @abc.abstractmethod - def set_sample_inputs( - self, - sample_inputs: List[torch.Tensor], - sample_grads: Optional[List[torch.Tensor]] = None, - ): - """Set sample inputs and gradients for tuning.""" - - @abc.abstractmethod - def get_connections(self, backward: bool = False) -> List[Dict[str, str]]: - """Get connected params among different kernels.""" - - def get_search_space(self, backward: bool = False): - return { - kernel_name: kernel.get_search_space() - for kernel_name, kernel in self._kernels.items() - if backward or kernel_name.startswith('forward') - } - - @abc.abstractmethod - def dense_forward(self, *args) -> Any: - """Use dense method to forward (requires gradient).""" diff --git a/sparta/specializer/kernels/__init__.py b/sparta/specializer/kernels/__init__.py index 233e2cf7..0b934bbc 100644 --- a/sparta/specializer/kernels/__init__.py +++ b/sparta/specializer/kernels/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig +from sparta.specializer.kernels.kernel_base import KernelBase from sparta.specializer.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel from sparta.specializer.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel diff --git a/sparta/specializer/kernels/kernel_base.py b/sparta/specializer/kernels/kernel_base.py index 5ee8898b..fca5b89a 100644 --- a/sparta/specializer/kernels/kernel_base.py +++ b/sparta/specializer/kernels/kernel_base.py @@ -6,7 +6,7 @@ import abc import warnings import dataclasses -from typing import Any, Dict, List, Tuple, Callable, Optional, Type +from typing import Any, Dict, List, Tuple, Callable, Optional import torch @@ -16,9 +16,7 @@ import pycuda.autoprimaryctx from pycuda.compiler import SourceModule -from sparta.tesa import get_bcs_function, BCSIndexes from sparta.tuning import TunableItemCfg -from sparta.testing import profile @dataclasses.dataclass @@ -34,66 +32,15 @@ def __post_init__(self): assert self.is_tunable -@dataclasses.dataclass -class PortConfig(object): - name: str - is_input: bool - is_sparse: bool = False - BCSR: bool = False - BCSC: bool = False - - def __post_init__(self): - self.mask: torch.Tensor = None - self._block_H: int = 0 - self._block_W: int = 0 - self.indexes: BCSIndexes = None - - def set_sparse(self, BCSR: bool, BCSC: bool): - self.is_sparse = True - self.BCSR = BCSR - self.BCSC = BCSC - - def set_block_size(self, block_H: int, block_W: int): - if self.is_sparse: - if block_H != self._block_H or block_W != self._block_W: - self._block_H = block_H - self._block_W = block_W - self._update_indexes() - - def set_mask(self, mask: torch.Tensor): - if self.is_sparse: - self.mask = mask - self._update_indexes() - - def _update_indexes(self): - if self._block_H > 0 and self._block_W > 0 and self.mask is not None: - self.indexes = get_bcs_function( - self._block_H, self._block_W, - 
self.BCSR, self.BCSC, - ).build_indexes(self.mask) - - def connect(self, kernel: KernelBase, port_name: str): - other_port = kernel.ports[port_name] - assert self.is_sparse == other_port.is_sparse - self.BCSR |= other_port.BCSR - self.BCSC |= other_port.BCSC - kernel.ports[port_name] = self - - class KernelBase(Callable): def __init__(self): self._parameters: Dict[str, _Parameter] = {} self._kernel: Callable = None self._func: Callable = None - self.ports: Dict[str, PortConfig] = {} self.ready = False self._add_parameters() - self._set_ports() - - @abc.abstractmethod - def _set_ports(self): - """Set input and output ports.""" + self.estimated_latency_per_flop = float('inf') @abc.abstractmethod def _add_parameters(self): @@ -135,9 +82,8 @@ def get_search_space(self, fixed_params: Optional[Dict[str, Any]] = None): return search_space def set_parameter(self, name: str, value: Any): - if name not in self._parameters and name in ['_name', '_impl']: - return # ignore some special key words - self._parameters[name].value = value + if name in self._parameters: + self._parameters[name].value = value def set_parameters(self, params: Dict[str, Any]): for name, value in params.items(): @@ -148,22 +94,13 @@ def get_parameter(self, name: str): def get_parameters(self, names: Optional[List[str]] = None): if names is None: - return {k: v.value for k, v in self._parameters.items()} - else: - return {k: self._parameters[k].value for k in names} - - @abc.abstractmethod - def set_shape(self, *args, **kwargs): - """Set shape parameters.""" + names = self._parameters.keys() + return {name: self.get_parameter(name) for name in names} @abc.abstractmethod def get_kernel_code(self) -> str: """Get CUDA code of the kernel.""" - @abc.abstractmethod - def blocks_per_grid(self: int) -> Tuple[int]: - """Get launch config: number of blocks per grid.""" - @abc.abstractmethod def threads_per_block(self) -> Tuple[int]: """Get launch config: number of threads per block.""" @@ -173,10 +110,10 @@ def _check_parameters(self, params: Dict[str, Any]): """Raise an error if the input paramater dict is invalid.""" @abc.abstractmethod - def update_func(self): + def set_kernel_call(self, shape: Tuple, sparse_attr: Any): """Convert pycuda kernel (self._kernel) to python function call (self._func).""" - def compile(self, params: Dict[str, Any]): + def compile(self, params: Dict[str, Any], shape: Any, sparse_attr: Any): self._check_parameters(params) self.set_parameters(params) kernel_code = self.get_kernel_code() @@ -186,33 +123,9 @@ def compile(self, params: Dict[str, Any]): warnings.simplefilter('ignore') source_module = SourceModule(kernel_code, options=['-O3']) self._kernel = source_module.get_function(kernel_name) - self.update_func() + self.set_kernel_call(shape, sparse_attr) self.ready = True - @abc.abstractmethod - def reference(self, *args) -> Any: - """Dense reference. 
Note that all inputs and outputs are dense tensors here.""" - - @abc.abstractmethod - def _convert_data(self, inputs: List[torch.Tensor], outputs: List[torch.Tensor]): - """Convert sample inputs and target outputs to sparse tenors in place if necessary.""" - - def test( - self, - inputs: List[torch.Tensor], - num_warmups: int = 20, - num_iters: int = 100, - cuda: bool = False, - ): - """Note that all inputs and outputs are dense tensors here.""" - sparse_inputs = [x for x in inputs] - sparse_outputs = self.reference(*sparse_inputs) - if type(sparse_outputs) is not tuple: - sparse_outputs = (sparse_outputs, ) - sparse_outputs = [y for y in sparse_outputs] - self._convert_data(sparse_inputs, sparse_outputs) - return profile(self, sparse_inputs, sparse_outputs, num_warmups, num_iters, cuda) - def __call__(self, *args) -> torch.Tensor: if self.ready: return self._func(*args) diff --git a/sparta/specializer/kernels/look_up_tables/matmul.openai.61.csv b/sparta/specializer/kernels/look_up_tables/matmul.openai.61.csv new file mode 100644 index 00000000..63f04dfb --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/matmul.openai.61.csv @@ -0,0 +1,13 @@ +mode,trans_A,trans_B,BM,BK,BN,latency +dds,False,False,32,64,32,38.08401794433594 +dds,False,True,32,64,32,38.0838134765625 +dds,True,False,32,64,32,37.28413391113281 +dds,True,True,32,64,32,38.15277099609375 +dsd,False,False,32,64,32,38.40625610351562 +dsd,False,True,32,64,32,40.095452880859376 +dsd,True,False,32,64,32,38.40674743652344 +dsd,True,True,32,64,32,40.13760986328125 +sdd,False,False,32,64,32,38.06143798828125 +sdd,False,True,32,64,32,37.84909057617188 +sdd,True,False,32,64,32,38.10197448730469 +sdd,True,True,32,64,32,37.697930908203126 diff --git a/sparta/specializer/kernels/look_up_tables/matmul.openai.default.csv b/sparta/specializer/kernels/look_up_tables/matmul.openai.default.csv new file mode 100644 index 00000000..5f7f1bf9 --- /dev/null +++ b/sparta/specializer/kernels/look_up_tables/matmul.openai.default.csv @@ -0,0 +1,13 @@ +mode,trans_A,trans_B,BM,BK,BN,latency +sdd,False,False,32,64,32,50.0 +sdd,False,True,32,64,32,50.0 +sdd,True,False,32,64,32,50.0 +sdd,True,True,32,64,32,50.0 +dsd,False,False,32,64,32,50.0 +dsd,False,True,32,64,32,50.0 +dsd,True,False,32,64,32,50.0 +dsd,True,True,32,64,32,50.0 +dds,False,False,32,64,32,50.0 +dds,False,True,32,64,32,50.0 +dds,True,False,32,64,32,50.0 +dds,True,True,32,64,32,50.0 diff --git a/sparta/specializer/kernels/matmul.py b/sparta/specializer/kernels/matmul.py index 00bf6190..60d25390 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/specializer/kernels/matmul.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. 
import io -from typing import Any, Dict +import textwrap import importlib.resources as res +from typing import Any, Dict, Tuple import torch import jinja2 @@ -12,21 +13,24 @@ from sparta.tuning import TunableItemCfg from sparta.specializer.kernels import templates, look_up_tables -from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig +from sparta.specializer.kernels.kernel_base import KernelBase -def _get_sparta_matmul_lut(): +def _get_matmul_lut(impl: str): major, minor = torch.cuda.get_device_capability() try: - lut_file = f'matmul.sparta.{major}{minor}.csv' + lut_file = f'matmul.{impl}.{major}{minor}.csv' lut_text = res.read_text(look_up_tables, lut_file) except FileNotFoundError: - lut_file = f'matmul.sparta.default.csv' + lut_file = f'matmul.{impl}.default.csv' lut_text = res.read_text(look_up_tables, lut_file) return pd.read_csv(io.StringIO(lut_text)) -SPARTA_MATMUL_LUT = _get_sparta_matmul_lut() +_MATMUL_LUTS = { + 'sparta': _get_matmul_lut('sparta'), + 'openai': _get_matmul_lut('openai'), +} class SparseMatMulKernel(KernelBase): @@ -36,207 +40,102 @@ class SparseMatMulKernel(KernelBase): def __init__( self, mode: str, + biased: bool, + transpose_A: bool, + transpose_B: bool, + compressed: bool, + batched: bool, dtype: str = 'float', - biased: bool = True, - transpose_A: bool = False, - transpose_B: bool = True, - compressed: bool = True, ): - if mode not in ['sdd', 'dsd', 'dds']: - raise ValueError(f'invalid sparse type: {mode}') self._biased = biased self._transpose_A = transpose_A self._transpose_B = transpose_B self._compressed = compressed - self._bcs_mode = '' + self._batched = batched self._mode = mode self._dtype = dtype - self._sparse_port = '' - self._sparse_block_H = '' - self._sparse_block_W = '' - self._tesa_vars = [] - mode_filter = SPARTA_MATMUL_LUT['mode'] == self._mode - trans_A_filter = SPARTA_MATMUL_LUT['trans_A'] == self._transpose_A - trans_B_filter = SPARTA_MATMUL_LUT['trans_B'] == self._transpose_B - self._lut = SPARTA_MATMUL_LUT[mode_filter & trans_A_filter & trans_B_filter] - super().__init__() - def _set_ports(self): - self.ports['A'] = PortConfig(name='A', is_input=True) - self.ports['B'] = PortConfig(name='B', is_input=True) - if self._biased: - self.ports['bias'] = PortConfig(name='bias', is_input=True) - self.ports['C'] = PortConfig(name='C', is_input=False) - - if self._mode == 'sdd': - self._sparse_port = 'A' - if self._transpose_A: - self._sparse_block_H = 'BLOCK_SIZE_K_VALUE' - self._sparse_block_W = 'BLOCK_SIZE_M_VALUE' - self._bcs_mode = 'BCSC' - else: - self._sparse_block_H = 'BLOCK_SIZE_M_VALUE' - self._sparse_block_W = 'BLOCK_SIZE_K_VALUE' - self._bcs_mode = 'BCSR' - elif self._mode == 'dsd': - self._sparse_port = 'B' - if self._transpose_B: - self._sparse_block_H = 'BLOCK_SIZE_N_VALUE' - self._sparse_block_W = 'BLOCK_SIZE_K_VALUE' - self._bcs_mode = 'BCSR' - else: - self._sparse_block_H = 'BLOCK_SIZE_K_VALUE' - self._sparse_block_W = 'BLOCK_SIZE_N_VALUE' - self._bcs_mode = 'BCSC' - elif self._mode == 'dds': - self._sparse_port = 'C' - self._sparse_block_H = 'BLOCK_SIZE_M_VALUE' - self._sparse_block_W = 'BLOCK_SIZE_N_VALUE' - self._bcs_mode = 'BCSR' - - BCSR = self._bcs_mode == 'BCSR' - BCSC = self._bcs_mode == 'BCSC' - self.ports[self._sparse_port].set_sparse(BCSR, BCSC) - - if BCSR: - self._tesa_vars = ['row_ptr', 'BCSR_idx', 'nnz'] - elif BCSC: - self._tesa_vars = ['col_ptr', 'BCSC_idx', 'nnz'] - else: - raise ValueError('failed to initialize SparseMatMulKernel') - if self._mode == 'dds': - self._tesa_vars = 
self._tesa_vars[1:] + lut = _MATMUL_LUTS[self.__algo__] + mode_filter = lut['mode'] == self._mode + trans_A_filter = lut['trans_A'] == self._transpose_A + trans_B_filter = lut['trans_B'] == self._transpose_B + self._lut = lut[mode_filter & trans_A_filter & trans_B_filter] + + super().__init__() def _add_parameters(self): - self._add_parameter('BATCH_SIZE') - self._add_parameter('GLOBAL_M_VALUE') - self._add_parameter('GLOBAL_K_VALUE') - self._add_parameter('GLOBAL_N_VALUE') + self._add_parameter('MODE', value=self._mode) self._add_parameter('BIASED', value=self._biased) + self._add_parameter('BATCHED', value=self._batched) self._add_parameter('TRANSPOSE_A', value=self._transpose_A) self._add_parameter('TRANSPOSE_B', value=self._transpose_B) self._add_parameter('COMPRESSED', value=self._compressed) self._add_parameter('BCSR') self._add_parameter('BCSC') - def set_parameters(self, params: Dict[str, Any]): - super().set_parameters(params) - sparse_port = self.ports[self._sparse_port] - if self._bcs_mode == 'BCSR': - self.set_parameter('BCSR', True) - self.set_parameter('BCSC', False) - elif self._bcs_mode == 'BCSC': - self.set_parameter('BCSR', sparse_port.BCSR) - self.set_parameter('BCSC', True) - BH = self.get_parameter(self._sparse_block_H) - BW = self.get_parameter(self._sparse_block_W) - sparse_port.set_block_size(BH, BW) - - def set_shape(self, batch_size: int, M: int, K: int, N: int): - self.set_parameter('BATCH_SIZE', batch_size) - self.set_parameter('GLOBAL_M_VALUE', M) - self.set_parameter('GLOBAL_K_VALUE', K) - self.set_parameter('GLOBAL_N_VALUE', N) - - def get_shape(self): - batch_size = self.get_parameter('BATCH_SIZE') - M = self.get_parameter('GLOBAL_M_VALUE') - K = self.get_parameter('GLOBAL_K_VALUE') - N = self.get_parameter('GLOBAL_N_VALUE') - return batch_size, M, K, N - def get_block_shape(self): BM = self.get_parameter('BLOCK_SIZE_M_VALUE') BK = self.get_parameter('BLOCK_SIZE_K_VALUE') BN = self.get_parameter('BLOCK_SIZE_N_VALUE') return BM, BK, BN - def blocks_per_grid(self): - batch_size, M, K, N = self.get_shape() - if self._mode == 'dds': - return (self.ports['C'].indexes.block_nnz, batch_size, 1) - else: - BM, BK, BN = self.get_block_shape() - return (N // BN, M // BM, batch_size) - - def update_func(self): - batch_size, M, K, N = self.get_shape() + def set_kernel_call(self, shape: Tuple[int, int, int, int], sparse_attr: Any): + batch, M, K, N = shape + M_32, K_32, N_32 = np.int32(M), np.int32(K), np.int32(N) BM, BK, BN = self.get_block_shape() - - indexes = self.ports[self._sparse_port].indexes - if self._mode == 'dds' and self._compressed: - C_shape = (batch_size, indexes.block_nnz * BM * BN) - else: - C_shape = (batch_size, M, N) - - M_32 = np.int32(M) - K_32 = np.int32(K) - N_32 = np.int32(N) - tesa_vars = [getattr(indexes, x) for x in self._tesa_vars] + row_num, col_num = M // BM, N // BN block = self.threads_per_block() - grid = self.blocks_per_grid() raw_func = self._kernel - - if self._biased: - def matmul_func(A: torch.Tensor, B: torch.Tensor, bias: torch.Tensor): - C = torch.zeros(C_shape, device=A.device) - raw_func( - A, B, bias, C, *tesa_vars, - M_32, K_32, N_32, - block=block, grid=grid - ) - return C - else: - def matmul_func(A: torch.Tensor, B: torch.Tensor): - C = torch.zeros(C_shape, device=A.device) + zeros = torch.zeros + int32 = np.int32 + + func_code = jinja2.Template(textwrap.dedent(''' + def matmul_func(A, B{% if BIASED %}, bias{% endif %}): + {% if MODE == "sdd" %} + {% if BATCHED %}batch, {% endif %}{% if TRANSPOSE_B %}N, _{% else %}_, N{% 
endif %} = B.shape + N_32 = int32(N) + col_num = N // BN + {% elif MODE == 'dsd' %} + {% if BATCHED %}batch, {% endif %}{% if TRANSPOSE_A %}_, M{% else %}M, _{% endif %} = A.shape + M_32 = int32(M) + row_num = M // BM + {% else %} + {% if BATCHED %}batch, {% endif %}{% if TRANSPOSE_A %}K, _{% else %}_, K{% endif %} = A.shape + K_32 = int32(K) + {% endif %} + {% if MODE == 'dds' and COMPRESSED %} + C = zeros(({% if BATCHED %}batch, {% endif %}sparse_attr.indexes.block_nnz * BM * BN), device=A.device) + {% else %} + C = zeros(({% if BATCHED %}batch, {% endif %}M, N), device=A.device) + {% endif %} raw_func( - A, B, C, *tesa_vars, + A.detach(), B.detach(), {% if BIASED %}bias.detach(), {% endif %}C, + {% if (MODE == "sdd" and TRANSPOSE_A) or (MODE == "dsd" and not TRANSPOSE_B) %} + sparse_attr.indexes.col_ptr, sparse_attr.indexes.BCSC_idx, sparse_attr.indexes.nnz, + {% elif (MODE == "sdd" and not TRANSPOSE_A) or (MODE == "dsd" and TRANSPOSE_B) %} + sparse_attr.indexes.row_ptr, sparse_attr.indexes.BCSR_idx, sparse_attr.indexes.nnz, + {% else %} + sparse_attr.indexes.BCSR_idx, sparse_attr.indexes.nnz, + {% endif %} M_32, K_32, N_32, - block=block, grid=grid + block=block, + {% if MODE == 'dds' %} + grid=(sparse_attr.indexes.block_nnz, {% if BATCHED %}batch{% else %}1{% endif %}, 1), + {% else %} + grid=(col_num, row_num, {% if BATCHED %}batch{% else %}1{% endif %}), + {% endif %} ) return C + ''')).render(self.get_parameters()) - self._func = matmul_func + exec(func_code, locals()) + self._func = locals()['matmul_func'] def get_kernel_code(self): template_file = f'{self.__algo__}_sparse_matmul_{self._mode}.cuh.j2' kernel_template = res.read_text(templates, template_file) return jinja2.Template(kernel_template).render(self.get_parameters()) - def _convert_data(self, inputs, outputs): - if self._mode == 'sdd': - inputs[0] = inputs[0] * self.ports['A'].mask - elif self._mode == 'dsd': - inputs[1] = inputs[1] * self.ports['B'].mask - for i in range(len(inputs)): - inputs[i] = inputs[i].detach() - outputs[0] = outputs[0].detach() - if self._compressed: - if self._sparse_port == 'A': - inputs[0] = self.ports['A'].indexes.convert(inputs[0]) - elif self._sparse_port == 'B': - inputs[1] = self.ports['B'].indexes.convert(inputs[1]) - elif self._sparse_port == 'C': - outputs[0] = self.ports['C'].indexes.convert(outputs[0]) - - def reference(self, *args): - A, B = args[0], args[1] - if self._mode == 'sdd': - A = A * self.ports['A'].mask - elif self._mode == 'dsd': - B = B * self.ports['B'].mask - A_str = 'bkm' if self._transpose_A else 'bmk' - B_str = 'bnk' if self._transpose_B else 'bkn' - C: torch.Tensor = torch.einsum(f'{A_str}, {B_str} -> bmn', A, B) - if self._biased: - C = C + args[2].unsqueeze(1) - if self._mode == 'dds': - if self.ready: - C = C * self.ports['C'].indexes.get_mask() # DDS known issue - else: - C = C * self.ports['C'].mask - return C - class SparTASparseMatMulKernel(SparseMatMulKernel): @@ -248,33 +147,12 @@ def _add_parameters(self): self._add_parameter( f'BLOCK_SIZE_{dim}_VALUE', is_tunable=True, - search_space=TunableItemCfg('choice', [8, 16, 32, 64]) + search_space=TunableItemCfg('choice', [8, 16, 32, 64]), ) self._add_parameter( f'THREAD_SIZE_{dim}_VALUE', ) - def set_parameters(self, params: Dict[str, Any]): - super().set_parameters(params) - if 'THREAD_SIZE_M_VALUE' in params: - if 'THREAD_SIZE_K_VALUE' in params: - if 'THREAD_SIZE_N_VALUE' in params: - return - BM = params['BLOCK_SIZE_M_VALUE'] - BK = params['BLOCK_SIZE_K_VALUE'] - BN = params['BLOCK_SIZE_N_VALUE'] - 
BM_filter = self._lut['BM'] == BM - BK_filter = self._lut['BK'] == BK - BN_filter = self._lut['BN'] == BN - row = self._lut[BM_filter & BK_filter & BN_filter] - assert len(row) > 0, f'block shape ({BM}, {BK}, {BN}) not found in LUT' - row = row.reset_index(drop=True).iloc[0, :] - assert float(row['latency']) < float('inf'), f'block shape ({BM}, {BK}, {BN}) is invalid' - TM, TK, TN = row['TM'], row['TK'], row['TN'] - self.set_parameter('THREAD_SIZE_M_VALUE', int(TM)) - self.set_parameter('THREAD_SIZE_K_VALUE', int(TK)) - self.set_parameter('THREAD_SIZE_N_VALUE', int(TN)) - def get_thread_shape(self): TM = self.get_parameter('THREAD_SIZE_M_VALUE') TK = self.get_parameter('THREAD_SIZE_K_VALUE') @@ -315,6 +193,16 @@ def _check_parameters(self, params: Dict[str, Any]): B_tile_row_stride = threads_per_block // B_threads_per_row assert A_tile_row_stride <= (BK if self._transpose_A else BM) assert B_tile_row_stride <= (BN if self._transpose_B else BK) + else: + row = self._lut[(self._lut['BM'] == BM) & (self._lut['BK'] == BK) & (self._lut['BN'] == BN)] + assert len(row) > 0, f'block shape ({BM}, {BK}, {BN}) not found in LUT' + row = row.reset_index(drop=True).iloc[0, :] + assert float(row['latency']) < float('inf'), f'block shape ({BM}, {BK}, {BN}) is invalid' + TM, TK, TN = row['TM'], row['TK'], row['TN'] + self.set_parameter('THREAD_SIZE_M_VALUE', int(TM)) + self.set_parameter('THREAD_SIZE_K_VALUE', int(TK)) + self.set_parameter('THREAD_SIZE_N_VALUE', int(TN)) + self.estimated_latency_per_flop = row['latency'] / 4096 / 4096 / 4096 class OpenAISparseMatMulKernel(SparseMatMulKernel): @@ -323,24 +211,13 @@ class OpenAISparseMatMulKernel(SparseMatMulKernel): def _add_parameters(self): super()._add_parameters() - self._add_parameter( - 'BLOCK_SIZE_M_VALUE', - value=32, - is_tunable=True, - search_space=TunableItemCfg('choice', [32]), - ) - self._add_parameter( - 'BLOCK_SIZE_K_VALUE', - value=64, - is_tunable=True, - search_space=TunableItemCfg('choice', [64]), - ) - self._add_parameter( - 'BLOCK_SIZE_N_VALUE', - value=32, - is_tunable=True, - search_space=TunableItemCfg('choice', [32]), - ) + for dim, val in zip(['M', 'K', 'N'], [32, 64, 32]): + self._add_parameter( + f'BLOCK_SIZE_{dim}_VALUE', + value=val, + is_tunable=True, + search_space=TunableItemCfg('choice', [val]), + ) def threads_per_block(self): return (256, 1, 1) @@ -352,3 +229,5 @@ def _check_parameters(self, params: Dict[str, Any]): assert params['BLOCK_SIZE_K_VALUE'] == 64 if 'BLOCK_SIZE_N_VALUE' in params: assert params['BLOCK_SIZE_N_VALUE'] == 32 + row = self._lut.reset_index(drop=True).iloc[0, :] + self.estimated_latency_per_flop = row['latency'] / 32 / 64 / 32 diff --git a/sparta/specializer/kernels/softmax.py b/sparta/specializer/kernels/softmax.py index 5f90a481..ff32e960 100644 --- a/sparta/specializer/kernels/softmax.py +++ b/sparta/specializer/kernels/softmax.py @@ -2,8 +2,9 @@ # Licensed under the MIT license. 
import io -from typing import Any, Dict, Tuple +import textwrap import importlib.resources as res +from typing import Any, Dict, Tuple import torch import jinja2 @@ -12,30 +13,26 @@ from sparta.tuning import TunableItemCfg from sparta.specializer.kernels import templates, look_up_tables -from sparta.specializer.kernels.kernel_base import KernelBase, PortConfig -from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference +from sparta.specializer.kernels.kernel_base import KernelBase -def _get_sparta_softmax_lut(): +def _get_softmax_lut(impl: str, direction: str): major, minor = torch.cuda.get_device_capability() try: - forward_lut_file = f'softmax.forward.sparta.{major}{minor}.csv' - forward_lut_text = res.read_text(look_up_tables, forward_lut_file) - except FileNotFoundError: - forward_lut_file = f'softmax.backward.sparta.default.csv' - forward_lut_text = res.read_text(look_up_tables, forward_lut_file) - forward_lut = pd.read_csv(io.StringIO(forward_lut_text)) - try: - backward_lut_file = f'softmax.backward.sparta.{major}{minor}.csv' - backward_lut_text = res.read_text(look_up_tables, backward_lut_file) + lut_file = f'softmax.{direction}.{impl}.{major}{minor}.csv' + lut_text = res.read_text(look_up_tables, lut_file) except FileNotFoundError: - backward_lut_file = f'softmax.backward.sparta.default.csv' - backward_lut_text = res.read_text(look_up_tables, backward_lut_file) - backward_lut = pd.read_csv(io.StringIO(backward_lut_text)) - return forward_lut, backward_lut + lut_file = f'softmax.{direction}.{impl}.default.csv' + lut_text = res.read_text(look_up_tables, lut_file) + return pd.read_csv(io.StringIO(lut_text)) -SPARTA_SOFTMAX_FORWARD_LUT, SPARTA_SOFTMAX_BACKWARD_LUT = _get_sparta_softmax_lut() +_SOFTMAX_LUTS = { + 'sparta': { + 'forward': _get_softmax_lut('sparta', 'forward'), + 'backward': _get_softmax_lut('sparta', 'backward'), + }, +} class SparseSoftmaxKernel(KernelBase): @@ -43,34 +40,17 @@ class SparseSoftmaxKernel(KernelBase): __algo__: str = '' __direction__: str = '' - def __init__(self, compressed: bool = False, dtype: str = 'float'): + def __init__(self, compressed: bool, batched: bool, dtype: str = 'float'): self._compressed = compressed + self._batched = batched self._dtype = dtype + self._lut = _SOFTMAX_LUTS[self.__algo__][self.__direction__] super().__init__() def _add_parameters(self): - self._add_parameter('BATCH_SIZE') - self._add_parameter('GLOBAL_H_VALUE') - self._add_parameter('GLOBAL_W_VALUE') + self._add_parameter('BATCHED', value=self._batched) self._add_parameter('COMPRESSED', value=self._compressed) - def set_parameters(self, params: Dict[str, Any]): - super().set_parameters(params) - sparse_port = self.ports['y'] - BH, BW = self.get_block_shape() - sparse_port.set_block_size(BH, BW) - - def set_shape(self, batch_size: int, H: int, W: int): - self.set_parameter('BATCH_SIZE', batch_size) - self.set_parameter('GLOBAL_H_VALUE', H) - self.set_parameter('GLOBAL_W_VALUE', W) - - def get_shape(self): - batch_size = self.get_parameter('BATCH_SIZE') - H = self.get_parameter('GLOBAL_H_VALUE') - W = self.get_parameter('GLOBAL_W_VALUE') - return batch_size, H, W - def get_block_shape(self): BH = self.get_parameter('BLOCK_SIZE_H_VALUE') BW = self.get_parameter('BLOCK_SIZE_W_VALUE') @@ -81,112 +61,18 @@ class SparseSoftmaxForwardKernel(SparseSoftmaxKernel): __direction__ = 'forward' - def _set_ports(self): - self.ports['x'] = PortConfig(name='x', is_input=True, is_sparse=True, BCSR=True) - self.ports['y'] = PortConfig(name='y', 
is_input=False, is_sparse=True, BCSR=True) - self.ports['x'].connect(self, 'y') - - def update_func(self): - batch_size, H, W = self.get_shape() - BH, BW = self.get_block_shape() - - indexes = self.ports['y'].indexes - row_ptr = indexes.row_ptr - BCSR_idx = indexes.BCSR_idx - if self._compressed: - shape = (batch_size, indexes.block_nnz * BH * BW) - else: - shape = (batch_size, H, W) - mask = indexes.raw_mask - if self._compressed: - mask = indexes.convert(mask.to(torch.float32)).to(torch.uint8) - block = self.threads_per_block() - grid = self.blocks_per_grid() - raw_func = self._kernel - - def softmax_forward_func(x: torch.Tensor, T: np.int32): - y = torch.zeros(shape, device=x.device) - raw_func(x, row_ptr, BCSR_idx, mask, T, y, block=block, grid=grid) - return y - - self._func = softmax_forward_func - - def _convert_data(self, inputs, outputs): - mask = self.ports['y'].mask - inputs[0] = (inputs[0].reshape(self.get_shape()) * mask).detach() - outputs[0] = (outputs[0].reshape(self.get_shape()) * mask).detach() - if self._compressed: - indexes = self.ports['y'].indexes - inputs[0] = indexes.convert(inputs[0]) - outputs[0] = indexes.convert(outputs[0]) - - def reference(self, *args): - x, T = args - mask = self.ports['y'].mask - y = sparse_softmax_forward_reference(x, mask, 1 / T) - return y - class SparseSoftmaxBackwardKernel(SparseSoftmaxKernel): __direction__ = 'backward' - def _set_ports(self): - self.ports['grad_y'] = PortConfig(name='grad_y', is_input=True, is_sparse=True, BCSR=True) - self.ports['y'] = PortConfig(name='y', is_input=True, is_sparse=True, BCSR=True) - self.ports['grad_x'] = PortConfig(name='grad_x', is_input=False, is_sparse=True, BCSR=True) - self.ports['grad_y'].connect(self, 'y') - self.ports['grad_y'].connect(self, 'grad_x') - - def update_func(self): - batch_size, H, W = self.get_shape() - BH, BW = self.get_block_shape() - - indexes = self.ports['y'].indexes - row_ptr = indexes.row_ptr - BCSR_idx = indexes.BCSR_idx - if self._compressed: - shape = (batch_size, indexes.block_nnz * BH * BW) - else: - shape = (batch_size, H, W) - mask = indexes.raw_mask - if self._compressed: - mask = indexes.convert(mask.to(torch.float32)).to(torch.uint8) - block = self.threads_per_block() - grid = self.blocks_per_grid() - raw_func = self._kernel - - def softmax_backward_func(grad_y: torch.Tensor, y: torch.Tensor, T: np.int32): - x = torch.zeros(shape, device=grad_y.device) - raw_func(grad_y, row_ptr, BCSR_idx, y, mask, T, x, block=block, grid=grid) - return x - - self._func = softmax_backward_func - - def _convert_data(self, inputs, outputs): - mask = self.ports['y'].mask - inputs[0] = (inputs[0].reshape(self.get_shape()) * mask).detach() - inputs[1] = (inputs[1].reshape(self.get_shape()) * mask).detach() - outputs[0] = (outputs[0].reshape(self.get_shape()) * mask).detach() - if self._compressed: - indexes = self.ports['y'].indexes - inputs[0] = indexes.convert(inputs[0]) - inputs[1] = indexes.convert(inputs[1]) - outputs[0] = indexes.convert(outputs[0]) - - def reference(self, *args): - grad_y, y, T = args - mask = self.ports['y'].indexes.raw_mask - grad_x = sparse_softmax_backward_reference(grad_y, y, mask, 1 / T) - return grad_x - class SparTASoftmaxKernel(SparseSoftmaxKernel): __algo__ = 'sparta' - def __init__(self, compressed: bool = False, dtype: str = 'float'): - super().__init__(compressed, dtype) + def __init__(self, compressed: bool, batched: bool, dtype: str = 'float'): + super().__init__(compressed, batched, dtype) self._lut: pd.DataFrame = None def 
_add_parameters(self): @@ -199,29 +85,10 @@ def _add_parameters(self): self._add_parameter( 'BLOCK_SIZE_W_VALUE', is_tunable=True, - search_space=TunableItemCfg('choice', [16, 32, 64, 128]) - ) - self._add_parameter( - 'ROW_TILE_VALUE', + search_space=TunableItemCfg('choice', [8, 16, 32, 64, 128]) ) - - def set_parameters(self, params: Dict[str, Any]): - super().set_parameters(params) - if 'ROW_TILE_VALUE' in params: - return - BH, BW = self.get_block_shape() - BH_filter = self._lut['BH'] == BH - BW_filter = self._lut['BW'] == BW - row = self._lut[BH_filter & BW_filter] - assert len(row) > 0, f'block shape ({BH}, {BW}) not found in LUT' - row = row.reset_index(drop=True).iloc[0, :] - assert float(row['latency']) < float('inf'), f'block shape ({BH}, {BW}) is invalid' - self.set_parameter('ROW_TILE_VALUE', int(row['RT'])) - - def blocks_per_grid(self): - batch_size, H, W = self.get_shape() - RT = self.get_parameter('ROW_TILE_VALUE') - return (H // RT, batch_size, 1) + self._add_parameter('ROW_TILE_VALUE') + self._add_parameter('MAX_W_VALUE', value=1024) def threads_per_block(self) -> Tuple[int]: BW = self.get_parameter('BLOCK_SIZE_W_VALUE') @@ -236,6 +103,16 @@ def _check_parameters(self, params: Dict[str, Any]): if 'ROW_TILE_VALUE' in params: RT = params['ROW_TILE_VALUE'] assert BH >= RT + else: + BH, BW = self.get_block_shape() + BH_filter = self._lut['BH'] == BH + BW_filter = self._lut['BW'] == BW + row = self._lut[BH_filter & BW_filter] + assert len(row) > 0, f'block shape ({BH}, {BW}) not found in LUT' + row = row.reset_index(drop=True).iloc[0, :] + assert float(row['latency']) < float('inf'), f'block shape ({BH}, {BW}) is invalid' + self.set_parameter('ROW_TILE_VALUE', int(row['RT'])) + self.estimated_latency_per_flop = row['latency'] / BH / BW def get_kernel_code(self): template_file = f'{self.__algo__}_sparse_softmax_{self.__direction__}.cuh.j2' @@ -245,13 +122,83 @@ def get_kernel_code(self): class SparTASparseSoftmaxForwardKernel(SparseSoftmaxForwardKernel, SparTASoftmaxKernel): - def __init__(self, compressed: bool = False, dtype: str = 'float'): - super().__init__(compressed, dtype) - self._lut = SPARTA_SOFTMAX_FORWARD_LUT + def set_kernel_call(self, shape: Tuple[int, int, int], sparse_attr: Any): + batch, H, W = shape + H_32, W_32 = np.int32(H), np.int32(W) + BH, BW = self.get_block_shape() + RT = self.get_parameter('ROW_TILE_VALUE') + block = self.threads_per_block() + row_num = H // RT + raw_func = self._kernel + zeros = torch.zeros + + func_code = jinja2.Template(textwrap.dedent(''' + def softmax_forward_func(x, mask, T): + {% if BATCHED %} + batch = x.shape[0] + {% if COMPRESSED %} + y = zeros((batch, sparse_attr.indexes.block_nnz * BH * BW), device=x.device) + {% else %} + y = zeros((batch, H, W), device=x.device) + {% endif %} + {% else %} + {% if COMPRESSED %} + y = zeros((sparse_attr.indexes.block_nnz * BH * BW), device=x.device) + {% else %} + y = zeros((H, W), device=x.device) + {% endif %} + {% endif %} + raw_func( + x.detach(), mask, T, y, + sparse_attr.indexes.row_ptr, sparse_attr.indexes.BCSR_idx, + H_32, W_32, + block=block, + grid=(row_num, {% if BATCHED %}batch{% else %}1{% endif %}, 1), + ) + return y + ''')).render(self.get_parameters()) + + exec(func_code, locals()) + self._func = locals()['softmax_forward_func'] class SparTASparseSoftmaxBackwardKernel(SparseSoftmaxBackwardKernel, SparTASoftmaxKernel): - def __init__(self, compressed: bool = False, dtype: str = 'float'): - super().__init__(compressed, dtype) - self._lut = SPARTA_SOFTMAX_BACKWARD_LUT + def 
set_kernel_call(self, shape: Tuple[int, int, int], sparse_attr: Any): + batch, H, W = shape + H_32, W_32 = np.int32(H), np.int32(W) + BH, BW = self.get_block_shape() + RT = self.get_parameter('ROW_TILE_VALUE') + block = self.threads_per_block() + row_num = H // RT + raw_func = self._kernel + zeros = torch.zeros + + func_code = jinja2.Template(textwrap.dedent(''' + def softmax_backward_func(gy, y, mask, T): + {% if BATCHED %} + batch = gy.shape[0] + {% if COMPRESSED %} + gx = zeros((batch, sparse_attr.indexes.block_nnz * BH * BW), device=gy.device) + {% else %} + gx = zeros((batch, H, W), device=gy.device) + {% endif %} + {% else %} + {% if COMPRESSED %} + gx = zeros((sparse_attr.indexes.block_nnz * BH * BW), device=gy.device) + {% else %} + gx = zeros((H, W), device=gy.device) + {% endif %} + {% endif %} + raw_func( + gy, y.detach(), mask, T, gx, + sparse_attr.indexes.row_ptr, sparse_attr.indexes.BCSR_idx, + H_32, W_32, + block=block, + grid=(row_num, {% if BATCHED %}batch{% else %}1{% endif %}, 1), + ) + return gx + ''')).render(self.get_parameters()) + + exec(func_code, locals()) + self._func = locals()['softmax_backward_func'] diff --git a/sparta/specializer/kernels/templates/openai_sparse_matmul_dds.cuh.j2 b/sparta/specializer/kernels/templates/openai_sparse_matmul_dds.cuh.j2 index a2835eb2..65be6d92 100644 --- a/sparta/specializer/kernels/templates/openai_sparse_matmul_dds.cuh.j2 +++ b/sparta/specializer/kernels/templates/openai_sparse_matmul_dds.cuh.j2 @@ -51,6 +51,7 @@ __global__ void BLOCK_SPARSE_MATMUL_OUT_32_64_32( const int BLOCK_SIZE_N = 32; //128 const int THREAD_SIZE_K = 64; + {% if BATCHED %} A += M*K*blockIdx.y; B += K*N*blockIdx.y; {% if COMPRESSED %} @@ -61,6 +62,7 @@ __global__ void BLOCK_SPARSE_MATMUL_OUT_32_64_32( {% if BIASED %} bias += N*blockIdx.y; {% endif %} + {% endif %} assert(blockDim.x % 32 == 0); uint n_warp = 8; // blockDim.x / 32 diff --git a/sparta/specializer/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 b/sparta/specializer/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 index 75b48e07..828a4280 100644 --- a/sparta/specializer/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 +++ b/sparta/specializer/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 @@ -52,6 +52,7 @@ __global__ void BLOCK_SPARSE_MATMUL_32_64_32( const int BLOCK_SIZE_N = 32; //128 const int THREAD_SIZE_K = 64; + {% if BATCHED %} A += M*K*blockIdx.z; {% if COMPRESSED %} B_val += B_block_nnz*64*32*blockIdx.z; @@ -62,6 +63,7 @@ __global__ void BLOCK_SPARSE_MATMUL_32_64_32( {% if BIASED %} bias += N*blockIdx.z; {% endif %} + {% endif %} assert(blockDim.x % 32 == 0); uint n_warp = 8; // blockDim.x / 32 diff --git a/sparta/specializer/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 b/sparta/specializer/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 index fb8f904c..d4dd2e15 100644 --- a/sparta/specializer/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 +++ b/sparta/specializer/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 @@ -52,6 +52,7 @@ __global__ void BLOCK_SPARSE_MATMUL_32_64_32( const int BLOCK_SIZE_N = 32; //128 const int THREAD_SIZE_K = 64; + {% if BATCHED %} {% if COMPRESSED %} A_val += A_block_nnz*64*32*blockIdx.z; {% else %} @@ -62,6 +63,7 @@ __global__ void BLOCK_SPARSE_MATMUL_32_64_32( {% if BIASED %} bias += N*blockIdx.z; {% endif %} + {% endif %} assert(blockDim.x % 32 == 0); uint n_warp = 8; // blockDim.x / 32 diff --git a/sparta/specializer/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 
index 1f168445..f32a0293 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 @@ -45,6 +45,7 @@ __global__ void BLOCK_SPARSE_MATMUL( int ty = threadIdx.y; int tx = threadIdx.x; + {% if BATCHED %} A += M * K * blockIdx.y; B += K * N * blockIdx.y; {% if COMPRESSED %} @@ -55,6 +56,7 @@ __global__ void BLOCK_SPARSE_MATMUL( {% if BIASED %} bias += N * blockIdx.y; {% endif %} + {% endif %} __shared__ float As[BM * BK]; __shared__ float Bs[BN * BK]; diff --git a/sparta/specializer/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 index 45c1f66c..66096508 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 @@ -34,6 +34,7 @@ __global__ void BLOCK_SPARSE_MATMUL( int ty = threadIdx.y; int tx = threadIdx.x; + {% if BATCHED %} A += M * K * blockIdx.z; {% if COMPRESSED %} B_val += B_block_nnz * BN * BK * blockIdx.z; @@ -44,6 +45,7 @@ __global__ void BLOCK_SPARSE_MATMUL( {% if BIASED %} bias += N * blockIdx.z; {% endif %} + {% endif %} __shared__ float As[BM * BK]; __shared__ float Bs[BN * BK]; diff --git a/sparta/specializer/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 index bf963634..b666a20b 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 @@ -34,6 +34,7 @@ __global__ void BLOCK_SPARSE_MATMUL( int ty = threadIdx.y; int tx = threadIdx.x; + {% if BATCHED %} {% if COMPRESSED %} A_val += A_block_nnz * BM * BK * blockIdx.z; {% else %} @@ -44,6 +45,7 @@ __global__ void BLOCK_SPARSE_MATMUL( {% if BIASED %} bias += N * blockIdx.z; {% endif %} + {% endif %} __shared__ float As[BM * BK]; __shared__ float Bs[BN * BK]; diff --git a/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 index c4b5b8d8..67c6f1be 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 @@ -5,21 +5,22 @@ {% set INI_OFFSET = WARP_SIZE // 2 %} #define FULL_MASK 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} -const int H = {{ GLOBAL_H_VALUE }}; -const int W = {{ GLOBAL_W_VALUE }}; const int block_h = {{ BLOCK_SIZE_H_VALUE }}; const int block_w = {{ BLOCK_SIZE_W_VALUE }}; const int row_tile = {{ ROW_TILE_VALUE }}; __global__ void SPARSE_SOFTMAX( float* out_grad, - int* row_ptr, - int* BCSR_idx, float* out_val, unsigned char* mask, float temperature, - float* in_grad + float* in_grad, + int* row_ptr, + int* BCSR_idx, + int H, + int W ) { + {% if BATCHED %} {% if COMPRESSED %} int num_nnz = row_ptr[H / block_h]; out_grad += blockIdx.y * num_nnz * block_h * block_w; @@ -30,6 +31,7 @@ __global__ void SPARSE_SOFTMAX( out_val += blockIdx.y * H * W; in_grad += blockIdx.y * H * W; {% endif %} + {% endif %} uint blk_row_idx = blockIdx.x / (block_h/row_tile) ; int block_inter_row = (blockIdx.x % (block_h/row_tile)) * row_tile; @@ -39,19 +41,20 @@ __global__ void SPARSE_SOFTMAX( int block_seq_start = row_ptr[blk_row_idx]; int block_seq_end = row_ptr[blk_row_idx+1]; - uint index_list[W / {{ WARP_SIZE }}]; + uint index_list[{{ MAX_W_VALUE // WARP_SIZE }}]; int val_num = 0; for (int 
block_inter_col = bn; block_inter_col < block_w; block_inter_col += {{ WARP_SIZE }}) { for (int block_seq = block_seq_start; block_seq < block_seq_end; block_seq++) { + uint mask_index = (blk_row_idx * block_h + block_inter_row + bm) * W + + ((BCSR_idx[block_seq] & 0xffff) * block_w + block_inter_col); {% if COMPRESSED %} - uint index = block_h * block_w * block_seq + + uint val_index = block_h * block_w * block_seq + (block_inter_row + bm) * block_w + block_inter_col; {% else %} - uint index = (blk_row_idx * block_h + block_inter_row + bm) * W + - ((BCSR_idx[block_seq] & 0xffff) * block_w + block_inter_col); + uint val_index = mask_index; {% endif %} - if (mask[index]) { - index_list[val_num++] = index; + if (mask[mask_index]) { + index_list[val_num++] = val_index; } } } diff --git a/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 b/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 index 6e5642f2..5e527e26 100644 --- a/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 +++ b/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 @@ -5,20 +5,21 @@ {% set INI_OFFSET = WARP_SIZE // 2 %} #define FULL_MASK 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} -const int H = {{ GLOBAL_H_VALUE }}; -const int W = {{ GLOBAL_W_VALUE }}; const int block_h = {{ BLOCK_SIZE_H_VALUE }}; const int block_w = {{ BLOCK_SIZE_W_VALUE }}; const int row_tile = {{ ROW_TILE_VALUE }}; __global__ void SPARSE_SOFTMAX( float* in_val, - int* row_ptr, - int* BCSR_idx, unsigned char* mask, float temperature, - float* out_val + float* out_val, + int* row_ptr, + int* BCSR_idx, + int H, + int W ) { + {% if BATCHED %} {% if COMPRESSED %} int num_nnz = row_ptr[H / block_h]; in_val += blockIdx.y * num_nnz * block_h * block_w; @@ -27,6 +28,7 @@ __global__ void SPARSE_SOFTMAX( in_val += blockIdx.y * H * W; out_val += blockIdx.y * H * W; {% endif %} + {% endif %} uint blk_row_idx = blockIdx.x / (block_h/row_tile) ; int block_inter_row = (blockIdx.x % (block_h/row_tile)) * row_tile; @@ -37,23 +39,20 @@ __global__ void SPARSE_SOFTMAX( int block_seq_start = row_ptr[blk_row_idx]; int block_seq_end = row_ptr[blk_row_idx+1]; - uint index_list[W / {{ WARP_SIZE }}]; + uint index_list[{{ MAX_W_VALUE // WARP_SIZE }}]; int val_num = 0; for (int block_inter_col = bn; block_inter_col < block_w; block_inter_col += {{ WARP_SIZE }}) { for (int block_seq = block_seq_start; block_seq < block_seq_end; block_seq++) { + uint mask_index = (blk_row_idx * block_h + block_inter_row + bm) * W + + ((BCSR_idx[block_seq] & 0xffff) * block_w + block_inter_col); {% if COMPRESSED %} - uint index = block_h * block_w * block_seq + + uint val_index = block_h * block_w * block_seq + (block_inter_row + bm) * block_w + block_inter_col; {% else %} - uint index = (blk_row_idx * block_h + block_inter_row + bm) * W + - ((BCSR_idx[block_seq] & 0xffff) * block_w + block_inter_col); + uint val_index = mask_index; {% endif %} - /* - index_list += index * mask[index]; - val_num += mask[index]; - */ - if (mask[index]) { - index_list[val_num++] = index; + if (mask[mask_index]) { + index_list[val_num++] = val_index; } } } diff --git a/sparta/specializer/operators/operator_base.py b/sparta/specializer/operators/operator_base.py index 356b6cc1..23e194b8 100644 --- a/sparta/specializer/operators/operator_base.py +++ b/sparta/specializer/operators/operator_base.py @@ -7,7 +7,7 @@ import torch -from sparta.specializer.funtional import SparseCtxBase +from sparta.specializer.functional import 
SparseFunctionBase class OperatorBase(torch.nn.Module): diff --git a/sparta/specializer/operators/seqlen_dynamic_sparse_attention.py b/sparta/specializer/operators/seqlen_dynamic_sparse_attention.py index 86ba1245..ace18acc 100644 --- a/sparta/specializer/operators/seqlen_dynamic_sparse_attention.py +++ b/sparta/specializer/operators/seqlen_dynamic_sparse_attention.py @@ -5,7 +5,7 @@ import torch -import seqlen_dynamic_sparse_attention_cpp +# import seqlen_dynamic_sparse_attention_cpp class SeqlenDynamicSparseAttentionFunction(torch.autograd.Function): diff --git a/sparta/specializer/operators/sparse_attention.py b/sparta/specializer/operators/sparse_attention.py index 581d4304..2f1927b7 100644 --- a/sparta/specializer/operators/sparse_attention.py +++ b/sparta/specializer/operators/sparse_attention.py @@ -8,7 +8,7 @@ import numpy as np from sparta.specializer.operators import OperatorBase, SparseBatchMatMul, SparseSoftmax -from sparta.specializer.kernels import KernelBase, PortConfig +from sparta.specializer.kernels import KernelBase class SparseAttention(OperatorBase): diff --git a/sparta/specializer/operators/sparse_linear.py b/sparta/specializer/operators/sparse_linear.py index cbca75cb..fc9780da 100644 --- a/sparta/specializer/operators/sparse_linear.py +++ b/sparta/specializer/operators/sparse_linear.py @@ -6,7 +6,7 @@ import torch from sparta.specializer.operators import OperatorBase -from sparta.specializer.funtional import SparseBatchMatMulCtx, SparseBatchMatMulFunc +from sparta.specializer.functional import SparseBatchMatMul class SparseLinear(OperatorBase): @@ -58,7 +58,7 @@ class SparseLinear(OperatorBase): """ __base_class__ = torch.nn.Linear - __sparse_func__ = SparseBatchMatMulFunc + __sparse_func__ = SparseBatchMatMul def __init__( self, diff --git a/sparta/specializer/operators/sparse_matmul.py b/sparta/specializer/operators/sparse_matmul.py index 03922a93..9aba9e49 100644 --- a/sparta/specializer/operators/sparse_matmul.py +++ b/sparta/specializer/operators/sparse_matmul.py @@ -6,7 +6,7 @@ import torch from sparta.specializer.operators import OperatorBase -from sparta.specializer.funtional import SparseBatchMatMulCtx, SparseBatchMatMulFunc +from sparta.specializer.functional import SparseBatchMatMul class SparseBatchMatMul(OperatorBase): @@ -60,7 +60,7 @@ class SparseBatchMatMul(OperatorBase): """ - __sparse_func__ = SparseBatchMatMulFunc + __sparse_func__ = SparseBatchMatMul def __init__( self, diff --git a/sparta/specializer/operators/sparse_moe.py b/sparta/specializer/operators/sparse_moe.py index 06749764..cb954f37 100644 --- a/sparta/specializer/operators/sparse_moe.py +++ b/sparta/specializer/operators/sparse_moe.py @@ -5,7 +5,7 @@ import torch -import sparse_moe_cpp +# import sparse_moe_cpp class DynamicSparseMoE(torch.nn.Module): diff --git a/sparta/specializer/operators/sparse_softmax.py b/sparta/specializer/operators/sparse_softmax.py index 0c8f1795..6b650e6a 100644 --- a/sparta/specializer/operators/sparse_softmax.py +++ b/sparta/specializer/operators/sparse_softmax.py @@ -6,7 +6,7 @@ import numpy as np from sparta.specializer.operators import OperatorBase -from sparta.specializer.funtional import SparseBatchSoftmaxCtx, SparseBatchSoftmaxFunc +from sparta.specializer.functional import SparseBatchSoftmax class SparseSoftmax(OperatorBase): @@ -43,7 +43,7 @@ class SparseSoftmax(OperatorBase): """ - __sparse_func__ = SparseBatchSoftmaxFunc + __sparse_func__ = SparseBatchSoftmax def __init__(self, mask: torch.Tensor, temperature: float = 1, compressed: bool = 
False): super().__init__() diff --git a/test/lut_maker/matmul.py b/test/lut_maker/matmul.py new file mode 100644 index 00000000..ea3382c8 --- /dev/null +++ b/test/lut_maker/matmul.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import logging +import itertools +from typing import Dict, Any + +import torch +import pandas as pd + +from sparta.specializer.functional.batch_matmul import SparseBatchMatMulForward +from sparta.testing import block_mask + + +SIZE = 4096 +RANDOM_SEED = 2022 +SPEC_SEARCH_SPACE = { + 'mode': ['sdd', 'dsd', 'dds'], + 'trans_A': [False, True], + 'trans_B': [False, True], +} +PARAM_SEARCH_SPACE = { + 'sparta': { + 'BLOCK_SIZE_M_VALUE': [8, 16, 32, 64, 128], + 'BLOCK_SIZE_K_VALUE': [8, 16, 32, 64, 128], + 'BLOCK_SIZE_N_VALUE': [8, 16, 32, 64, 128], + 'THREAD_SIZE_M_VALUE': [2, 4, 8, 16], + 'THREAD_SIZE_K_VALUE': [2, 4, 8, 16], + 'THREAD_SIZE_N_VALUE': [2, 4, 8, 16], + }, + 'openai': { + 'BLOCK_SIZE_M_VALUE': [32], + 'BLOCK_SIZE_K_VALUE': [64], + 'BLOCK_SIZE_N_VALUE': [32], + }, +} +HYPER_PARAMS = ['mode', 'trans_A', 'trans_B', 'BM', 'BK', 'BN'] + + +_logger = logging.Logger(__name__) +_handler = logging.StreamHandler() +_logger.addHandler(_handler) + + +def test_matmul_kernel( + impl: str, + func: SparseBatchMatMulForward, + params: Dict[str, Any], +): + try: + func.build(config={'forward': {'_impl': impl, **params}}) + latency = func.profile_kernel('forward', num_warmups=10, num_iters=10, cuda=False) + except: + latency = float('inf') + + return latency + + +def make_matmul_lut(impl: str): + major, minor = torch.cuda.get_device_capability() + lut_file = os.path.join( + 'sparta', + 'specializer', + 'kernels', + 'look_up_tables', + f'matmul.{impl}.{major}{minor}.csv' + ) + log_file = os.path.join( + 'test', + 'lut_maker', + f'matmul.{impl}.{major}{minor}.log.csv' + ) + _logger.info(f'========== Making LUT: {lut_file} ==========') + + num = 1 + spec_keys, spec_values = [], [] + for k, v in SPEC_SEARCH_SPACE.items(): + spec_keys.append(k) + spec_values.append(v) + num *= len(v) + param_keys, param_alts, param_values = [], [], [] + for k, v in PARAM_SEARCH_SPACE[impl].items(): + param_keys.append(k) + param_alts.append(f'{k[0]}{k[-7]}') + param_values.append(v) + num *= len(v) + bits = len(str(num)) + + with open(log_file, 'w') as f: + header = ','.join(spec_keys) + ',' + ','.join(param_alts) + ',latency\n' + header = header.replace('BLOCK', 'B').replace('THREAD', 'T') + header = header.replace('_SIZE_', '').replace('_VALUE', '') + f.write(header) + + torch.manual_seed(RANDOM_SEED) + A = torch.rand(size=(1, SIZE, SIZE), device='cuda') + B = torch.rand(size=(1, SIZE, SIZE), device='cuda') + mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') + + iters = 0 + for specs in itertools.product(*spec_values): + mode, trans_A, trans_B = specs + func = SparseBatchMatMulForward(mode, trans_A, trans_B, False, True) + func.get_sparse_attr().set_mask(mask) + func.reference_forward([A, B]) + for params in itertools.product(*param_values): + param_dict = {k: v for k, v in zip(param_keys, params)} + latency = test_matmul_kernel(impl, func, param_dict) + with open(log_file, 'a') as f: + items = [mode, trans_A, trans_B, *params, latency] + f.write(','.join([str(x) for x in items]) + '\n') + iters += 1 + _logger.info(f'[{str(iters).zfill(bits)} / {num}] {params} => {latency} ms') + + df = pd.read_csv(log_file) + df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] + with open(lut_file, 'w') as f: + 
f.write(df.reset_index(drop=True).to_csv(index=False)) + + _logger.info(f'========== Finished. Output: {lut_file} ==========') + + +if __name__ == '__main__': + _logger.setLevel(logging.DEBUG) + make_matmul_lut('sparta') + make_matmul_lut('openai') diff --git a/test/lut_maker/softmax.py b/test/lut_maker/softmax.py new file mode 100644 index 00000000..38360178 --- /dev/null +++ b/test/lut_maker/softmax.py @@ -0,0 +1,106 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import os +import logging +import itertools +from typing import Dict, Any + +import torch +import numpy as np +import pandas as pd + +from sparta.specializer.functional import SparseBatchSoftmax +from sparta.testing import block_mask + + +SIZE = 4096 +RANDOM_SEED = 2022 +SEARCH_SPACE = { + 'BLOCK_SIZE_H_VALUE': [8, 16, 32, 64, 128], + 'BLOCK_SIZE_W_VALUE': [8, 16, 32, 64, 128], + 'ROW_TILE_VALUE': [1, 2, 4, 8, 16], +} +HYPER_PARAMS = ['BH', 'BW'] + + +_logger = logging.Logger(__name__) +_handler = logging.StreamHandler() +_logger.addHandler(_handler) + + +def test_softmax_kernel( + impl: str, + func: SparseBatchSoftmax, + direction: str, + params: Dict[str, Any], +): + try: + func.build(config={direction: {'_impl': impl, **params}}) + latency = func.profile_kernel(direction, num_warmups=10, num_iters=10, cuda=False) + except: + latency = float('inf') + + return latency + + +def make_softmax_lut(impl: str, direction: str): + major, minor = torch.cuda.get_device_capability() + lut_file = os.path.join( + impl, + 'specializer', + 'kernels', + 'look_up_tables', + f'softmax.{direction}.{impl}.{major}{minor}.csv' + ) + log_file = os.path.join( + 'test', + 'lut_maker', + f'softmax.{direction}.{impl}.{major}{minor}.log.csv' + ) + _logger.info(f'========== Making LUT: {lut_file} ==========') + + num = 1 + keys, alts, values = [], [], [] + for k, v in SEARCH_SPACE.items(): + keys.append(k) + alt = [s[0] for s in k.split('_')] + alts.append(f'{alt[0]}{alt[-2]}') + values.append(v) + num *= len(v) + bits = len(str(num)) + + with open(log_file, 'w') as f: + f.write(','.join(alts) + ',latency\n') + + torch.manual_seed(RANDOM_SEED) + x = torch.rand(size=(1, SIZE, SIZE), device='cuda', requires_grad=True) + grad_y = torch.rand(size=(1, SIZE, SIZE), device='cuda') + mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') + + func = SparseBatchSoftmax(compressed=True, temperature=np.float32(1 / np.sqrt(SIZE))) + func.get_sparse_attr().set_mask(mask) + func.reference_forward([x]) + func.reference_backward([grad_y]) + + iters = 0 + for params in itertools.product(*values): + param_dict = {k: v for k, v in zip(keys, params)} + latency = test_softmax_kernel(impl, func, direction, param_dict) + with open(log_file, 'a') as f: + f.write(','.join([str(x) for x in params]) + f',{latency}\n') + iters += 1 + _logger.info(f'[{str(iters).zfill(bits)} / {num}] {params} => {latency} ms') + + df = pd.read_csv(log_file) + df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] + with open(lut_file, 'w') as f: + f.write(df.reset_index(drop=True).to_csv(index=False)) + + _logger.info(f'========== Finished. 
Output: {lut_file} ==========') + + +if __name__ == '__main__': + _logger.setLevel(logging.DEBUG) + make_softmax_lut('sparta', 'forward') + make_softmax_lut('sparta', 'backward') diff --git a/test/lut_maker/sparta_matmul.py b/test/lut_maker/sparta_matmul.py deleted file mode 100644 index b6a468da..00000000 --- a/test/lut_maker/sparta_matmul.py +++ /dev/null @@ -1,124 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import os -import logging -import itertools - -import torch -import pandas as pd - -from sparta.specializer.kernels import SparTASparseMatMulKernel -from sparta.testing import block_mask - - -SIZE = 4096 -RANDOM_SEED = 2022 -SEARCH_SPACE = { - 'mode': ['sdd', 'dsd', 'dds'], - 'trans_A': [False, True], - 'trans_B': [False, True], - 'BM': [8, 16, 32, 64, 128], - 'BK': [8, 16, 32, 64, 128], - 'BN': [8, 16, 32, 64, 128], - 'TM': [2, 4, 8, 16], - 'TK': [2, 4, 8, 16], - 'TN': [2, 4, 8, 16], -} -HYPER_PARAMS = ['mode', 'trans_A', 'trans_B', 'BM', 'BK', 'BN'] - - -_logger = logging.Logger(__name__) -_handler = logging.StreamHandler() -_logger.addHandler(_handler) - - -def test_sparta_matmul_kernel( - A: torch.Tensor, - B: torch.Tensor, - mask: torch.Tensor, - mode: str, - trans_A: bool, - trans_B: bool, - BM: int, - BK: int, - BN: int, - TM: int, - TK: int, - TN: int, -): - sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - kernel = SparTASparseMatMulKernel( - mode=mode, - biased=False, - transpose_A=trans_A, - transpose_B=trans_B, - compressed=True, - ) - - try: - kernel.ports[sparse_port].set_mask(mask) - kernel.set_shape(1, SIZE, SIZE, SIZE) - kernel.compile({ - 'BLOCK_SIZE_M_VALUE': BM, - 'BLOCK_SIZE_K_VALUE': BK, - 'BLOCK_SIZE_N_VALUE': BN, - 'THREAD_SIZE_M_VALUE': TM, - 'THREAD_SIZE_K_VALUE': TK, - 'THREAD_SIZE_N_VALUE': TN, - }) - latency = kernel.test([A, B], num_warmups=10, num_iters=10, cuda=False) - except: - latency = float('inf') - - return latency - - -def make_sparta_matmul_lut(): - major, minor = torch.cuda.get_device_capability() - lut_file = os.path.join( - 'sparta', - 'specializer', - 'kernels', - 'look_up_tables', - f'matmul.sparta.{major}{minor}.csv' - ) - log_file = os.path.join( - 'test', - 'lut_maker', - f'matmul.sparta.{major}{minor}.log.csv' - ) - _logger.info(f'========== Making LUT: {lut_file} ==========') - - num = 1 - keys, values = [], [] - for k, v in SEARCH_SPACE.items(): - keys.append(k) - values.append(v) - num *= len(v) - - with open(log_file, 'w') as f: - f.write(','.join(keys) + ',latency\n') - - torch.manual_seed(RANDOM_SEED) - A = torch.rand(size=(1, SIZE, SIZE), device='cuda') - B = torch.rand(size=(1, SIZE, SIZE), device='cuda') - mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') - - for i, params in enumerate(itertools.product(*values)): - latency = test_sparta_matmul_kernel(A, B, mask, **{k: v for k, v in zip(keys, params)}) - with open(log_file, 'a') as f: - f.write(','.join([str(x) for x in params]) + f',{latency}\n') - _logger.info(f'[{i} / {num}] {params} => {latency} ms') - - df = pd.read_csv(log_file) - df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] - with open(lut_file, 'w') as f: - f.write(df.reset_index(drop=True).to_csv(index=False)) - - _logger.info(f'========== Finished. 
Output: {lut_file} ==========') - - -if __name__ == '__main__': - _logger.setLevel(logging.DEBUG) - make_sparta_matmul_lut() diff --git a/test/lut_maker/sparta_softmax.py b/test/lut_maker/sparta_softmax.py deleted file mode 100644 index fc072954..00000000 --- a/test/lut_maker/sparta_softmax.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import os -import logging -import itertools -from typing import Dict - -import torch -import numpy as np -import pandas as pd - -from sparta.specializer.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.testing import block_mask - - -SIZE = 4096 -RANDOM_SEED = 2022 -SEARCH_SPACE = { - 'BH': [8, 16, 32, 64, 128], - 'BW': [8, 16, 32, 64, 128], - 'RT': [1, 2, 4, 8, 16], -} -HYPER_PARAMS = ['BH', 'BW'] - - -_logger = logging.Logger(__name__) -_handler = logging.StreamHandler() -_logger.addHandler(_handler) - - -def test_sparta_softmax_kernel( - data: Dict, - mask: torch.Tensor, - direction: str, - BH: int, - BW: int, - RT: int, -): - if direction == 'forward': - kernel = SparTASparseSoftmaxForwardKernel(compressed=True) - elif direction == 'backward': - kernel = SparTASparseSoftmaxBackwardKernel(compressed=True) - else: - raise ValueError(f'unrecognized direction: {direction}') - - try: - kernel.ports['y'].set_mask(mask) - kernel.set_shape(1, SIZE, SIZE) - kernel.compile({ - 'BLOCK_SIZE_H_VALUE': BH, - 'BLOCK_SIZE_W_VALUE': BW, - 'ROW_TILE_VALUE': RT, - }) - if direction == 'forward': - inputs = [data['x'], data['T']] - else: - inputs = [data['grad_y'], data['y'], data['T']] - latency = kernel.test(inputs, num_warmups=10, num_iters=10, cuda=False) - except: - latency = float('inf') - - return latency - - -def make_sparta_softmax_lut(direction: str): - major, minor = torch.cuda.get_device_capability() - lut_file = os.path.join( - 'sparta', - 'specializer', - 'kernels', - 'look_up_tables', - f'softmax.{direction}.sparta.{major}{minor}.csv' - ) - log_file = os.path.join( - 'test', - 'lut_maker', - f'softmax.{direction}.sparta.{major}{minor}.log.csv' - ) - _logger.info(f'========== Making LUT: {lut_file} ==========') - - num = 1 - keys, values = [], [] - for k, v in SEARCH_SPACE.items(): - keys.append(k) - values.append(v) - num *= len(v) - - with open(log_file, 'w') as f: - f.write(','.join(keys) + ',latency\n') - - torch.manual_seed(RANDOM_SEED) - data = {} - data['x'] = torch.rand(size=(1, SIZE, SIZE), device='cuda') - data['T'] = np.float32(1 / np.sqrt(SIZE)) - data['y'] = torch.rand(size=(1, SIZE, SIZE), device='cuda') - data['grad_y'] = torch.rand(size=(1, SIZE, SIZE), device='cuda') - mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') - - for i, params in enumerate(itertools.product(*values)): - latency = test_sparta_softmax_kernel( - data, - mask, - direction, - **{k: v for k, v in zip(keys, params)} - ) - with open(log_file, 'a') as f: - f.write(','.join([str(x) for x in params]) + f',{latency}\n') - _logger.info(f'[{i} / {num}] {params} => {latency} ms') - - df = pd.read_csv(log_file) - df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] - with open(lut_file, 'w') as f: - f.write(df.reset_index(drop=True).to_csv(index=False)) - - _logger.info(f'========== Finished. 
Output: {lut_file} ==========') - - -if __name__ == '__main__': - _logger.setLevel(logging.DEBUG) - make_sparta_softmax_lut('forward') - make_sparta_softmax_lut('backward') diff --git a/test/unit/test_sparse_matmul.py b/test/unit/test_sparse_matmul.py index d57f3b9e..99db55c2 100644 --- a/test/unit/test_sparse_matmul.py +++ b/test/unit/test_sparse_matmul.py @@ -1,20 +1,20 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from typing import Dict, Tuple, Type +from typing import Dict, Tuple, Type, Optional import torch import pytest from sparta.specializer.kernels import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from sparta.specializer.funtional import SparseBatchMatMulCtx, SparseBatchMatMulFunc -from sparta.nn import SparseBatchMatMul, SparseLinear +from sparta.specializer.functional import SparsityAttr, SparseMatMul, SparseBatchMatMul +# from sparta.nn import SparseBatchMatMul, SparseLinear from sparta.tesa import BCSIndexes from sparta.testing import block_mask def prepare_data( - batch: int = 4, + batch: Optional[int] = 4, M: int = 128, K: int = 256, N: int = 192, @@ -41,33 +41,33 @@ def prepare_data( torch.manual_seed(random_seed) data: Dict[str, torch.Tensor] = {} for x in inputs: - shape = (batch, *shapes[x]) if batch > 0 else shapes[x] + shape = shapes[x] if batch is None else (batch, *shapes[x]) data[f'input_{x}'] = torch.rand(size=shape, device='cuda') if requires_grad: for y in outputs: - shape = (batch, *shapes[y]) if batch > 0 else shapes[y] + shape = shapes[y] if batch is None else (batch, *shapes[y]) data[f'input_grad_{y}'] = torch.rand(size=shape, device='cuda') sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] mask = block_mask(shapes[sparse_port], block=granularity, sparsity=sparsity, device='cuda') - add_mask(data, {sparse_port: mask}, sparse_port, 'input') + add_mask(data, mask, sparse_port, 'input') if requires_grad: for x in inputs: data[f'input_{x}'].requires_grad = True - if batch > 0: - input_A = data['input_A'].swapaxes(1, 2) if trans_A else data['input_A'] - input_B = data['input_B'].swapaxes(1, 2) if trans_B else data['input_B'] - data['target_C'] = torch.bmm(input_A, input_B) - if biased: - data['target_C'] += data['input_bias'].unsqueeze(1) - else: + if batch is None: input_A = data['input_A'].T if trans_A else data['input_A'] input_B = data['input_B'].T if trans_B else data['input_B'] data['target_C'] = torch.mm(input_A, input_B) if biased: data['target_C'] += data['input_bias'] + else: + input_A = data['input_A'].swapaxes(1, 2) if trans_A else data['input_A'] + input_B = data['input_B'].swapaxes(1, 2) if trans_B else data['input_B'] + data['target_C'] = torch.bmm(input_A, input_B) + if biased: + data['target_C'] += data['input_bias'].unsqueeze(1) if requires_grad: data['target_C'].backward(data['input_grad_C']) @@ -79,20 +79,20 @@ def prepare_data( data['target_grad_bias'] = data['input_bias'].grad data['input_bias'].grad = None - add_mask(data, {sparse_port: mask}, sparse_port, 'target') + add_mask(data, mask, sparse_port, 'target') - return data, {sparse_port: mask} + return data, mask def add_mask( data: Dict[str, torch.Tensor], - masks: Dict[str, torch.Tensor], + mask: torch.Tensor, sparse_port: str, stage: str, ): for name, val in data.items(): if name.startswith(stage) and name.endswith(sparse_port): - val *= masks[sparse_port] + val *= mask def get_params(impl: str): @@ -111,14 +111,16 @@ def compress_data( indexes: BCSIndexes, sparse_port: str, data: Dict[str, torch.Tensor], - masks: 
Dict[str, torch.Tensor], + mask: torch.Tensor, + requires_grad: bool, ): for name in data: if name.endswith(sparse_port): data[name] = indexes.convert(data[name].detach()) - masks[sparse_port] = indexes.convert(masks[sparse_port].to(torch.float32)).to(torch.uint8) - if sparse_port in ['A', 'B']: + mask = indexes.convert(mask.to(torch.float32)).to(torch.uint8) + if sparse_port in ['A', 'B'] and requires_grad: data[f'input_{sparse_port}'].requires_grad = True + return data, mask def check_results(data: Dict[str, torch.Tensor]): @@ -133,6 +135,7 @@ def check_results(data: Dict[str, torch.Tensor]): @pytest.mark.parametrize("trans_A", [False, True]) @pytest.mark.parametrize("trans_B", [False, True]) @pytest.mark.parametrize("compressed", [False, True]) +@pytest.mark.parametrize("batch", [None, 4]) def test_sparse_matmul_kernel( impl: str, mode: str, @@ -140,35 +143,63 @@ def test_sparse_matmul_kernel( compressed: bool, trans_A: bool, trans_B: bool, - batch: int = 4, + batch: Optional[int], M: int = 128, K: int = 256, N: int = 192, granularity: Tuple[int, int] = (8, 8), sparsity: float = 0.9, ): - data, masks = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, False) + data, mask = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, False) kernelClass: Type[SparseMatMulKernel] = { 'sparta': SparTASparseMatMulKernel, 'openai': OpenAISparseMatMulKernel, }[impl] + batched = batch is not None kernel = kernelClass( mode=mode, biased=biased, transpose_A=trans_A, transpose_B=trans_B, compressed=compressed, + batched=batched, ) + shape = (batch, M, K, N) + sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] + BCSR = { + 'A': not trans_A, + 'B': trans_B, + 'C': True, + }[sparse_port] + BCSC = not BCSR + attr = SparsityAttr(BCSR, BCSC) + attr.set_mask(mask) + kernel.set_parameter('BCSR', BCSR) + kernel.set_parameter('BCSC', BCSC) + kernel.compile(get_params(impl), shape, attr) + + sparse_axis = { + 'A': ['K', 'M'] if trans_A else ['M', 'K'], + 'B': ['N', 'K'] if trans_B else ['K', 'N'], + 'C': ['M', 'N'], + }[sparse_port] + attr.BCSR = BCSR + attr.BCSC = BCSC + attr.set_block_size(*[ + kernel.get_parameter(f'BLOCK_SIZE_{i}_VALUE') + for i in sparse_axis + ]) - for sparse_port, mask in masks.items(): - kernel.ports[sparse_port].set_mask(mask) - kernel.set_shape(batch, M, K, N) - kernel.compile(get_params(impl)) + if compressed: + data, mask = compress_data(attr.indexes, sparse_port, data, mask, False) inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] input_data = [data[f'input_{x}'] for x in inputs] - kernel.test(input_data, num_warmups=0, num_iters=1, cuda=False) + + data['output_C'] = kernel(*input_data) + add_mask(data, mask, sparse_port, 'output') + check_results(data) @pytest.mark.parametrize("mode", ['sdd', 'dsd', 'dds']) @@ -176,161 +207,158 @@ def test_sparse_matmul_kernel( @pytest.mark.parametrize("trans_A", [False, True]) @pytest.mark.parametrize("trans_B", [False, True]) @pytest.mark.parametrize("compressed", [False, True]) +@pytest.mark.parametrize("batch", [None, 4]) def test_sparse_matmul_function( mode: str, biased: bool, compressed: bool, trans_A: bool, trans_B: bool, - batch: int = 4, + batch: Optional[int], M: int = 128, K: int = 256, N: int = 192, granularity: Tuple[int, int] = (8, 8), sparsity: float = 0.9, ): - data, masks = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, True) - - sparse_ctx = SparseBatchMatMulCtx(mode, trans_A, trans_B, biased, compressed) - kernel_names = 
sparse_ctx.get_kernel_placeholders(backward=True).keys() - sparse_ctx.select_impls({ - kernel_name: 'sparta' - for kernel_name in kernel_names - }) - sparse_ctx.set_shape(batch, M, K, N) - for port_name, ports in sparse_ctx.sparse_ports.items(): - for port in ports: - port.set_mask(masks[port_name]) - sparse_ctx.build({ - kernel_name: get_params('sparta') - for kernel_name in kernel_names - }) + data, mask = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, True) + + if batch is None: + func = SparseMatMul(mode, trans_A, trans_B, biased, compressed) + else: + func = SparseBatchMatMul(mode, trans_A, trans_B, biased, compressed) + + sparse_attr = func.get_sparse_attr() + sparse_attr.set_mask(mask) + + kernel_names = ['forward', 'backward:A', 'backward:B'] + inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] + func.build( + config={kernel_name: get_params('sparta') for kernel_name in kernel_names}, + sample_inputs=[data[f'input_{x}'] for x in inputs] + ) sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] if compressed: - compress_data(sparse_ctx.sparse_ports[sparse_port][0].indexes, sparse_port, data, masks) + data, mask = compress_data(sparse_attr.indexes, sparse_port, data, mask, True) - inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] - input_data = [data[f'input_{x}'] for x in inputs] - data['output_C'] = SparseBatchMatMulFunc.apply(sparse_ctx, *input_data) + data['output_C'] = func(*[data[f'input_{x}'] for x in inputs]) data['output_C'].backward(data['input_grad_C']) for x in inputs: data[f'output_grad_{x}'] = data[f'input_{x}'].grad - - add_mask(data, masks, sparse_port, 'output') - + add_mask(data, mask, sparse_port, 'output') check_results(data) -@pytest.mark.parametrize("mode", ['sdd', 'dsd', 'dds']) -@pytest.mark.parametrize("trans_A", [False, True]) -@pytest.mark.parametrize("trans_B", [False, True]) -@pytest.mark.parametrize("compressed", [False, True]) -def test_sparse_matmul_operator( - mode: str, - compressed: bool, - trans_A: bool, - trans_B: bool, - batch: int = 4, - M: int = 128, - K: int = 256, - N: int = 192, - granularity: Tuple[int, int] = (8, 8), - sparsity: float = 0.9, -): - data, masks = prepare_data( - batch, M, K, N, granularity, sparsity, - mode, trans_A, trans_B, False, True, - ) - - sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - sparse_matmul = SparseBatchMatMul( - **{f'{name}_mask': val for name, val in masks.items()}, - transpose_A=trans_A, - transpose_B=trans_B, - compressed=compressed, - ) - sparse_matmul.build( - config={ - kernel_name: get_params('sparta') - for kernel_name in sparse_matmul.get_kernel_placeholders(backward=True) - }, - sample_inputs=[data['input_A'], data['input_B']], - ) - - for random_seed in range(3): # Test dynamic sparse - if compressed: - compress_data(sparse_matmul.get_sparse_indexes(sparse_port), sparse_port, data, masks) - - data['output_C'] = sparse_matmul.forward(data['input_A'], data['input_B']) - data['output_C'].backward(data['input_grad_C']) - for x in ['A', 'B']: - data[f'output_grad_{x}'] = data[f'input_{x}'].grad - - add_mask(data, masks, sparse_port, 'output') - - check_results(data) - - data, masks = prepare_data( - batch, M, K, N, granularity, sparsity, - mode, trans_A, trans_B, False, True, random_seed, - ) - sparse_matmul.update_mask(**{f'{name}_mask': val for name, val in masks.items()}) - - -@pytest.mark.parametrize('mode', ['sdd', 'dsd', 'dds']) -@pytest.mark.parametrize('biased', [False, True]) -def test_sparse_linear_operator( - mode: str, - biased: 
bool, - batch: int = 128, - in_dims: int = 256, - out_dims: int = 192, - granularity: Tuple[int, int] = (8, 8), - sparsity: float = 0.9, -): - data, masks = prepare_data( - -1, batch, in_dims, out_dims, granularity, sparsity, - mode, False, True, biased, True, - ) - - dense_linear = torch.nn.Linear(in_dims, out_dims, bias=biased, device='cuda') - if biased: - dense_linear.load_state_dict({'weight': data['input_B'], 'bias': data['input_bias']}) - else: - dense_linear.load_state_dict({'weight': data['input_B']}) - - sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - mask_name = {'sdd': 'input_mask', 'dsd': 'weight_mask', 'dds': 'output_mask'}[mode] - sparse_linear = SparseLinear(dense_linear, **{mask_name: masks[sparse_port]}) - sparse_linear.build( - config={ - kernel_name: get_params('sparta') - for kernel_name in sparse_linear.get_kernel_placeholders(backward=True) - }, - sample_inputs=[data['input_A']], - ) - - for random_seed in range(3): # Test dynamic sparse - if mode == 'dsd': - compress_data(sparse_linear.get_sparse_indexes('B'), 'B', data, masks) - - data['output_C'] = sparse_linear.forward(data['input_A']) - data['output_C'].backward(data['input_grad_C']) - data[f'output_grad_A'] = data[f'input_A'].grad - data[f'output_grad_B'] = sparse_linear.weight.grad - if biased: - data[f'output_grad_bias'] = sparse_linear.bias.grad - - add_mask(data, masks, sparse_port, 'output') - - check_results(data) - - data, masks = prepare_data( - -1, batch, in_dims, out_dims, granularity, sparsity, - mode, False, True, biased, True, random_seed, - ) - if biased: - sparse_linear.bias = torch.nn.Parameter(data['input_bias']) - sparse_linear._raw_weight = data['input_B'] - sparse_linear.update_mask(**{mask_name: masks[sparse_port]}) +# @pytest.mark.parametrize("mode", ['sdd', 'dsd', 'dds']) +# @pytest.mark.parametrize("trans_A", [False, True]) +# @pytest.mark.parametrize("trans_B", [False, True]) +# @pytest.mark.parametrize("compressed", [False, True]) +# def test_sparse_matmul_operator( +# mode: str, +# compressed: bool, +# trans_A: bool, +# trans_B: bool, +# batch: int = 4, +# M: int = 128, +# K: int = 256, +# N: int = 192, +# granularity: Tuple[int, int] = (8, 8), +# sparsity: float = 0.9, +# ): +# data, masks = prepare_data( +# batch, M, K, N, granularity, sparsity, +# mode, trans_A, trans_B, False, True, +# ) + +# sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] +# sparse_matmul = SparseBatchMatMul( +# **{f'{name}_mask': val for name, val in masks.items()}, +# transpose_A=trans_A, +# transpose_B=trans_B, +# compressed=compressed, +# ) +# sparse_matmul.build( +# config={ +# kernel_name: get_params('sparta') +# for kernel_name in sparse_matmul.get_kernel_placeholders(backward=True) +# }, +# sample_inputs=[data['input_A'], data['input_B']], +# ) + +# for random_seed in range(3): # Test dynamic sparse +# if compressed: +# compress_data(sparse_matmul.get_sparse_indexes(sparse_port), sparse_port, data, masks) + +# data['output_C'] = sparse_matmul.forward(data['input_A'], data['input_B']) +# data['output_C'].backward(data['input_grad_C']) +# for x in ['A', 'B']: +# data[f'output_grad_{x}'] = data[f'input_{x}'].grad + +# add_mask(data, masks, sparse_port, 'output') + +# check_results(data) + +# data, masks = prepare_data( +# batch, M, K, N, granularity, sparsity, +# mode, trans_A, trans_B, False, True, random_seed, +# ) +# sparse_matmul.update_mask(**{f'{name}_mask': val for name, val in masks.items()}) + + +# @pytest.mark.parametrize('mode', ['sdd', 'dsd', 'dds']) +# 
@pytest.mark.parametrize('biased', [False, True]) +# def test_sparse_linear_operator( +# mode: str, +# biased: bool, +# batch: int = 128, +# in_dims: int = 256, +# out_dims: int = 192, +# granularity: Tuple[int, int] = (8, 8), +# sparsity: float = 0.9, +# ): +# data, masks = prepare_data( +# -1, batch, in_dims, out_dims, granularity, sparsity, +# mode, False, True, biased, True, +# ) + +# dense_linear = torch.nn.Linear(in_dims, out_dims, bias=biased, device='cuda') +# if biased: +# dense_linear.load_state_dict({'weight': data['input_B'], 'bias': data['input_bias']}) +# else: +# dense_linear.load_state_dict({'weight': data['input_B']}) + +# sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] +# mask_name = {'sdd': 'input_mask', 'dsd': 'weight_mask', 'dds': 'output_mask'}[mode] +# sparse_linear = SparseLinear(dense_linear, **{mask_name: masks[sparse_port]}) +# sparse_linear.build( +# config={ +# kernel_name: get_params('sparta') +# for kernel_name in sparse_linear.get_kernel_placeholders(backward=True) +# }, +# sample_inputs=[data['input_A']], +# ) + +# for random_seed in range(3): # Test dynamic sparse +# if mode == 'dsd': +# compress_data(sparse_linear.get_sparse_indexes('B'), 'B', data, masks) + +# data['output_C'] = sparse_linear.forward(data['input_A']) +# data['output_C'].backward(data['input_grad_C']) +# data[f'output_grad_A'] = data[f'input_A'].grad +# data[f'output_grad_B'] = sparse_linear.weight.grad +# if biased: +# data[f'output_grad_bias'] = sparse_linear.bias.grad + +# add_mask(data, masks, sparse_port, 'output') + +# check_results(data) + +# data, masks = prepare_data( +# -1, batch, in_dims, out_dims, granularity, sparsity, +# mode, False, True, biased, True, random_seed, +# ) +# if biased: +# sparse_linear.bias = torch.nn.Parameter(data['input_bias']) +# sparse_linear._raw_weight = data['input_B'] +# sparse_linear.update_mask(**{mask_name: masks[sparse_port]}) diff --git a/test/unit/test_sparse_softmax.py b/test/unit/test_sparse_softmax.py index 623b6943..2886c2df 100644 --- a/test/unit/test_sparse_softmax.py +++ b/test/unit/test_sparse_softmax.py @@ -8,8 +8,8 @@ import numpy as np from sparta.specializer.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.specializer.funtional import SparseBatchSoftmaxCtx, SparseBatchSoftmaxFunc -from sparta.nn import SparseSoftmax +from sparta.specializer.functional import SparsityAttr, SparseSoftmax, SparseBatchSoftmax +# from sparta.nn import SparseSoftmax from sparta.testing import block_mask, sparse_softmax_forward_reference @@ -30,16 +30,19 @@ def prepare_data( mask = block_mask((H, W), block=granularity, sparsity=sparsity, device='cuda') - if requires_grad: - data['grad_y'] = torch.rand(shape, device='cuda') - data['input_x'].requires_grad = True + data['grad_y'] = torch.rand(shape, device='cuda') + data['input_x'].requires_grad = True data['target_y'] = sparse_softmax_forward_reference(data['input_x'], mask, temperature) + data['target_y'].backward(data['grad_y']) + data['target_grad_x'] = data['input_x'].grad + if requires_grad: - data['target_y'].backward(data['grad_y']) - data['target_grad_x'] = data['input_x'].grad data['input_x'].grad = None + else: + data['input_x'] = data['input_x'].detach() + data['target_y'] = data['target_y'].detach() return data, mask @@ -60,44 +63,50 @@ def get_params(): @pytest.mark.parametrize("compressed", [False, True]) +@pytest.mark.parametrize("batch", [None, 4]) def test_sparse_softmax_kernels( compressed: bool, - batch: Optional[int] = 4, + batch: 
Optional[int], H: int = 128, W: int = 256, granularity: Tuple[int, int] = (8, 8), - sparsity: float = 0.9, + sparsity: float = 0, ): - data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=True) + data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=False) - forward_kernel = SparTASparseSoftmaxForwardKernel(compressed=compressed) - backward_kernel = SparTASparseSoftmaxBackwardKernel(compressed=compressed) + batched = batch is not None + forward_kernel = SparTASparseSoftmaxForwardKernel(compressed, batched) + backward_kernel = SparTASparseSoftmaxBackwardKernel(compressed, batched) - if compressed: - sparse_port = backward_kernel.ports['y'] - sparse_port.connect(forward_kernel, 'x') - sparse_port.connect(forward_kernel, 'y') - sparse_port.set_mask(mask) - else: - forward_kernel.ports['y'].set_mask(mask) - backward_kernel.ports['y'].set_mask(mask) - forward_kernel.set_shape(batch, H, W) - forward_kernel.compile(get_params()) - backward_kernel.set_shape(batch, H, W) - backward_kernel.compile(get_params()) + attr = SparsityAttr(True, False) + attr.set_mask(mask) + shape = (batch, H, W) + forward_kernel.set_parameter('MAX_W_VALUE', W) + backward_kernel.set_parameter('MAX_W_VALUE', W) + forward_kernel.compile(get_params(), shape, attr) + backward_kernel.compile(get_params(), shape, attr) + + attr.set_block_size( + forward_kernel.get_parameter('BLOCK_SIZE_H_VALUE'), + forward_kernel.get_parameter('BLOCK_SIZE_W_VALUE'), + ) temperature = np.float32(1 / np.sqrt(W)) - forward_inputs = [data['input_x'], temperature] - forward_kernel.test(forward_inputs, num_warmups=0, num_iters=1, cuda=False) - backward_inputs = [data['grad_y'], data['target_y'], temperature] - backward_kernel.test(backward_inputs, num_warmups=0, num_iters=1, cuda=False) + if compressed: + for name in data: + data[name] = attr.indexes.convert(data[name].detach()) + + data['output_y'] = forward_kernel(data['input_x'], mask, temperature) + data['output_grad_x'] = backward_kernel(data['grad_y'], data['target_y'], mask, temperature) + check_results(data) @pytest.mark.parametrize("compressed", [False, True]) +@pytest.mark.parametrize("batch", [None, 4]) def test_sparse_softmax_function( compressed: bool, - batch: Optional[int] = 4, + batch: Optional[int], H: int = 128, W: int = 256, granularity: Tuple[int, int] = (8, 8), @@ -105,67 +114,64 @@ def test_sparse_softmax_function( ): data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=True) - sparse_ctx = SparseBatchSoftmaxCtx(compressed, np.sqrt(W)) - kernel_names = sparse_ctx.get_kernel_placeholders(backward=True).keys() - sparse_ctx.select_impls({ - kernel_name: 'sparta' - for kernel_name in kernel_names - }) - sparse_ctx.set_shape(batch, H, W) - for port_name, ports in sparse_ctx.sparse_ports.items(): - for port in ports: - port.set_mask(mask) - sparse_ctx.build({ - kernel_name: get_params() - for kernel_name in kernel_names - }) + if batch is None: + func = SparseSoftmax(compressed, np.sqrt(W)) + else: + func = SparseBatchSoftmax(compressed, np.sqrt(W)) + + sparse_attr = func.get_sparse_attr() + sparse_attr.set_mask(mask) + + kernel_names = ['forward', 'backward'] + func.build( + config={kernel_name: get_params() for kernel_name in kernel_names}, + sample_inputs=[data['input_x']] + ) if compressed: - indexes = sparse_ctx.sparse_ports['y'][0].indexes for name in data: - data[name] = indexes.convert(data[name].detach()) + data[name] = sparse_attr.indexes.convert(data[name].detach()) data['input_x'].requires_grad = True - 
data['output_y'] = SparseBatchSoftmaxFunc.apply(sparse_ctx, data['input_x']) + data['output_y'] = func(data['input_x']) data['output_y'].backward(data['grad_y']) data['output_grad_x'] = data['input_x'].grad - check_results(data) -@pytest.mark.parametrize("batch", [None, 4]) -@pytest.mark.parametrize("compressed", [False, True]) -def test_sparse_softmax_operator( - compressed: bool, - batch: Optional[int], - H: int = 128, - W: int = 256, - granularity: Tuple[int, int] = (8, 8), - sparsity: float = 0.9, -): - data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=True) - - sparse_softmax = SparseSoftmax(mask, np.sqrt(W), compressed) - sparse_softmax.build( - config={ - kernel_name: get_params() - for kernel_name in sparse_softmax.get_kernel_placeholders(backward=True) - }, - sample_inputs=[data['input_x']], - ) - - for random_seed in range(3): # Test dynamic sparse - if compressed: - indexes = sparse_softmax.get_sparse_indexes('y') - for name in data: - data[name] = indexes.convert(data[name].detach()) - data['input_x'].requires_grad = True - - data['output_y'] = sparse_softmax.forward(data['input_x']) - data['output_y'].backward(data['grad_y']) - data['output_grad_x'] = data['input_x'].grad - - check_results(data) - - data, mask = prepare_data(batch, H, W, granularity, sparsity, True, random_seed) - sparse_softmax.update_mask(mask) +# @pytest.mark.parametrize("batch", [None, 4]) +# @pytest.mark.parametrize("compressed", [False, True]) +# def test_sparse_softmax_operator( +# compressed: bool, +# batch: Optional[int], +# H: int = 128, +# W: int = 256, +# granularity: Tuple[int, int] = (8, 8), +# sparsity: float = 0.9, +# ): +# data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=True) + +# sparse_softmax = SparseSoftmax(mask, np.sqrt(W), compressed) +# sparse_softmax.build( +# config={ +# kernel_name: get_params() +# for kernel_name in sparse_softmax.get_kernel_placeholders(backward=True) +# }, +# sample_inputs=[data['input_x']], +# ) + +# for random_seed in range(3): # Test dynamic sparse +# if compressed: +# indexes = sparse_softmax.get_sparse_indexes('y') +# for name in data: +# data[name] = indexes.convert(data[name].detach()) +# data['input_x'].requires_grad = True + +# data['output_y'] = sparse_softmax.forward(data['input_x']) +# data['output_y'].backward(data['grad_y']) +# data['output_grad_x'] = data['input_x'].grad + +# check_results(data) + +# data, mask = prepare_data(batch, H, W, granularity, sparsity, True, random_seed) +# sparse_softmax.update_mask(mask) From 1cbcbfe48a259fa54f6ee6ef6ce29a83e5a97173 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Fri, 24 Feb 2023 11:06:16 +0800 Subject: [PATCH 13/28] update operators; combine functions and operators --- docs/1-code-specializer.md | 2 +- docs/reference/nn.rst | 2 +- setup.py | 4 +- sparta/kernels/__init__.py | 6 + .../{specializer => }/kernels/kernel_base.py | 0 .../kernels/look_up_tables/__init__.py | 0 .../look_up_tables/matmul.openai.61.csv | 0 .../look_up_tables/matmul.openai.default.csv | 0 .../look_up_tables/matmul.sparta.61.csv | 0 .../look_up_tables/matmul.sparta.70.csv | 0 .../look_up_tables/matmul.sparta.75.csv | 0 .../look_up_tables/matmul.sparta.default.csv | 0 .../softmax.backward.sparta.61.csv | 0 .../softmax.backward.sparta.70.csv | 0 .../softmax.backward.sparta.75.csv | 0 .../softmax.backward.sparta.default.csv | 0 .../softmax.forward.sparta.61.csv | 0 .../softmax.forward.sparta.70.csv | 0 .../softmax.forward.sparta.75.csv | 0 
.../softmax.forward.sparta.default.csv | 0 sparta/{specializer => }/kernels/matmul.py | 3 +- sparta/{specializer => }/kernels/softmax.py | 3 +- .../kernels/templates/__init__.py | 0 .../templates/openai_sparse_matmul_dds.cuh.j2 | 0 .../templates/openai_sparse_matmul_dsd.cuh.j2 | 0 .../templates/openai_sparse_matmul_sdd.cuh.j2 | 0 .../templates/sparta_sparse_matmul_dds.cuh.j2 | 0 .../templates/sparta_sparse_matmul_dsd.cuh.j2 | 0 .../templates/sparta_sparse_matmul_sdd.cuh.j2 | 0 .../sparta_sparse_softmax_backward.cuh.j2 | 0 .../sparta_sparse_softmax_forward.cuh.j2 | 0 sparta/nn/__init__.py | 2 +- sparta/nn/module_tuner.py | 16 +- sparta/operators/__init__.py | 9 + .../operator_base.py} | 77 ++-- sparta/operators/sparse_attention.py | 69 ++++ .../sparse_matmul.py} | 107 ++++-- .../{specializer => }/operators/sparse_moe.py | 0 .../sparse_seqlen_attention.py} | 0 .../sparse_softmax.py} | 63 ++-- sparta/specializer/__init__.py | 4 - sparta/specializer/functional/__init__.py | 6 - sparta/specializer/kernels/__init__.py | 6 - sparta/specializer/operators/__init__.py | 10 - sparta/specializer/operators/operator_base.py | 129 ------- .../specializer/operators/sparse_attention.py | 238 ------------- sparta/specializer/operators/sparse_linear.py | 172 --------- sparta/specializer/operators/sparse_matmul.py | 136 ------- .../specializer/operators/sparse_softmax.py | 78 ---- test/lut_maker/matmul.py | 19 +- test/lut_maker/softmax.py | 20 +- test/unit/test_sparse_matmul.py | 332 ++++++++++-------- test/unit/test_sparse_softmax.py | 87 ++--- 53 files changed, 477 insertions(+), 1123 deletions(-) create mode 100644 sparta/kernels/__init__.py rename sparta/{specializer => }/kernels/kernel_base.py (100%) rename sparta/{specializer => }/kernels/look_up_tables/__init__.py (100%) rename sparta/{specializer => }/kernels/look_up_tables/matmul.openai.61.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/matmul.openai.default.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/matmul.sparta.61.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/matmul.sparta.70.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/matmul.sparta.75.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/matmul.sparta.default.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.backward.sparta.61.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.backward.sparta.70.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.backward.sparta.75.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.backward.sparta.default.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.forward.sparta.61.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.forward.sparta.70.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.forward.sparta.75.csv (100%) rename sparta/{specializer => }/kernels/look_up_tables/softmax.forward.sparta.default.csv (100%) rename sparta/{specializer => }/kernels/matmul.py (98%) rename sparta/{specializer => }/kernels/softmax.py (98%) rename sparta/{specializer => }/kernels/templates/__init__.py (100%) rename sparta/{specializer => }/kernels/templates/openai_sparse_matmul_dds.cuh.j2 (100%) rename sparta/{specializer => }/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 (100%) rename sparta/{specializer => }/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 (100%) rename sparta/{specializer => 
}/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 (100%) rename sparta/{specializer => }/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 (100%) rename sparta/{specializer => }/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 (100%) rename sparta/{specializer => }/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 (100%) rename sparta/{specializer => }/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 (100%) create mode 100644 sparta/operators/__init__.py rename sparta/{specializer/functional/function_base.py => operators/operator_base.py} (74%) create mode 100644 sparta/operators/sparse_attention.py rename sparta/{specializer/functional/batch_matmul.py => operators/sparse_matmul.py} (79%) rename sparta/{specializer => }/operators/sparse_moe.py (100%) rename sparta/{specializer/operators/seqlen_dynamic_sparse_attention.py => operators/sparse_seqlen_attention.py} (100%) rename sparta/{specializer/functional/batch_softmax.py => operators/sparse_softmax.py} (70%) delete mode 100644 sparta/specializer/__init__.py delete mode 100644 sparta/specializer/functional/__init__.py delete mode 100644 sparta/specializer/kernels/__init__.py delete mode 100644 sparta/specializer/operators/__init__.py delete mode 100644 sparta/specializer/operators/operator_base.py delete mode 100644 sparta/specializer/operators/sparse_attention.py delete mode 100644 sparta/specializer/operators/sparse_linear.py delete mode 100644 sparta/specializer/operators/sparse_matmul.py delete mode 100644 sparta/specializer/operators/sparse_softmax.py diff --git a/docs/1-code-specializer.md b/docs/1-code-specializer.md index a2ff27c0..158aab24 100644 --- a/docs/1-code-specializer.md +++ b/docs/1-code-specializer.md @@ -10,7 +10,7 @@ To balance between the flexibility, performance, and developing efficiency, we a | Layer | Base Class | Role | | :- | :- | :- | -| Sparse Operator | [`sparta.nn.OperatorBase`](reference/nn.rst) | User interface as `torch.nn.Module` | +| Sparse Operator | [`sparta.nn.SparseOperator`](reference/nn.rst) | User interface as `torch.nn.Module` | | Sparse Context | `sparta.specializer.functional.SparseCtxBase` | Function context to interact with `torch.autograd.Function` | | Sparse Kernel Placeholder | `sparta.specializer.functional.KernelPlaceholder` | Collection of multiple kernel implementations | | Sparse Kernel | `sparta.specializer.kernels.KernelBase` | Tunable sparse CUDA kernel interface | diff --git a/docs/reference/nn.rst b/docs/reference/nn.rst index 4ac92164..5a1c7091 100644 --- a/docs/reference/nn.rst +++ b/docs/reference/nn.rst @@ -2,7 +2,7 @@ sparta.nn =================================== -.. autoclass:: sparta.nn.OperatorBase +.. autoclass:: sparta.nn.SparseOperator :members: .. autoclass:: sparta.nn.SparseLinear diff --git a/setup.py b/setup.py index b66d7b9f..c6918043 100644 --- a/setup.py +++ b/setup.py @@ -63,8 +63,8 @@ cmdclass={'build_ext': BuildExtension}, include_package_data=True, package_data={ - 'sparta.specializer.kernels.templates': ['*.j2'], - 'sparta.specializer.kernels.look_up_tables': ['*.csv'], + 'sparta.kernels.templates': ['*.j2'], + 'sparta.kernels.look_up_tables': ['*.csv'], 'sparta.tesa.templates': ['*.j2'], }, ) diff --git a/sparta/kernels/__init__.py b/sparta/kernels/__init__.py new file mode 100644 index 00000000..482785bb --- /dev/null +++ b/sparta/kernels/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
+ +from sparta.kernels.kernel_base import KernelBase +from sparta.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel +from sparta.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel diff --git a/sparta/specializer/kernels/kernel_base.py b/sparta/kernels/kernel_base.py similarity index 100% rename from sparta/specializer/kernels/kernel_base.py rename to sparta/kernels/kernel_base.py diff --git a/sparta/specializer/kernels/look_up_tables/__init__.py b/sparta/kernels/look_up_tables/__init__.py similarity index 100% rename from sparta/specializer/kernels/look_up_tables/__init__.py rename to sparta/kernels/look_up_tables/__init__.py diff --git a/sparta/specializer/kernels/look_up_tables/matmul.openai.61.csv b/sparta/kernels/look_up_tables/matmul.openai.61.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/matmul.openai.61.csv rename to sparta/kernels/look_up_tables/matmul.openai.61.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.openai.default.csv b/sparta/kernels/look_up_tables/matmul.openai.default.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/matmul.openai.default.csv rename to sparta/kernels/look_up_tables/matmul.openai.default.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.61.csv b/sparta/kernels/look_up_tables/matmul.sparta.61.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/matmul.sparta.61.csv rename to sparta/kernels/look_up_tables/matmul.sparta.61.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv b/sparta/kernels/look_up_tables/matmul.sparta.70.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/matmul.sparta.70.csv rename to sparta/kernels/look_up_tables/matmul.sparta.70.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.75.csv b/sparta/kernels/look_up_tables/matmul.sparta.75.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/matmul.sparta.75.csv rename to sparta/kernels/look_up_tables/matmul.sparta.75.csv diff --git a/sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv b/sparta/kernels/look_up_tables/matmul.sparta.default.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/matmul.sparta.default.csv rename to sparta/kernels/look_up_tables/matmul.sparta.default.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.61.csv b/sparta/kernels/look_up_tables/softmax.backward.sparta.61.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.61.csv rename to sparta/kernels/look_up_tables/softmax.backward.sparta.61.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.70.csv b/sparta/kernels/look_up_tables/softmax.backward.sparta.70.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.70.csv rename to sparta/kernels/look_up_tables/softmax.backward.sparta.70.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.75.csv b/sparta/kernels/look_up_tables/softmax.backward.sparta.75.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.75.csv rename to sparta/kernels/look_up_tables/softmax.backward.sparta.75.csv diff --git 
a/sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.default.csv b/sparta/kernels/look_up_tables/softmax.backward.sparta.default.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.backward.sparta.default.csv rename to sparta/kernels/look_up_tables/softmax.backward.sparta.default.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.61.csv b/sparta/kernels/look_up_tables/softmax.forward.sparta.61.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.61.csv rename to sparta/kernels/look_up_tables/softmax.forward.sparta.61.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.70.csv b/sparta/kernels/look_up_tables/softmax.forward.sparta.70.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.70.csv rename to sparta/kernels/look_up_tables/softmax.forward.sparta.70.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.75.csv b/sparta/kernels/look_up_tables/softmax.forward.sparta.75.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.75.csv rename to sparta/kernels/look_up_tables/softmax.forward.sparta.75.csv diff --git a/sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.default.csv b/sparta/kernels/look_up_tables/softmax.forward.sparta.default.csv similarity index 100% rename from sparta/specializer/kernels/look_up_tables/softmax.forward.sparta.default.csv rename to sparta/kernels/look_up_tables/softmax.forward.sparta.default.csv diff --git a/sparta/specializer/kernels/matmul.py b/sparta/kernels/matmul.py similarity index 98% rename from sparta/specializer/kernels/matmul.py rename to sparta/kernels/matmul.py index 60d25390..e3236178 100644 --- a/sparta/specializer/kernels/matmul.py +++ b/sparta/kernels/matmul.py @@ -12,8 +12,7 @@ import pandas as pd from sparta.tuning import TunableItemCfg -from sparta.specializer.kernels import templates, look_up_tables -from sparta.specializer.kernels.kernel_base import KernelBase +from sparta.kernels import KernelBase, templates, look_up_tables def _get_matmul_lut(impl: str): diff --git a/sparta/specializer/kernels/softmax.py b/sparta/kernels/softmax.py similarity index 98% rename from sparta/specializer/kernels/softmax.py rename to sparta/kernels/softmax.py index ff32e960..6f1b469d 100644 --- a/sparta/specializer/kernels/softmax.py +++ b/sparta/kernels/softmax.py @@ -12,8 +12,7 @@ import pandas as pd from sparta.tuning import TunableItemCfg -from sparta.specializer.kernels import templates, look_up_tables -from sparta.specializer.kernels.kernel_base import KernelBase +from sparta.kernels import KernelBase, templates, look_up_tables def _get_softmax_lut(impl: str, direction: str): diff --git a/sparta/specializer/kernels/templates/__init__.py b/sparta/kernels/templates/__init__.py similarity index 100% rename from sparta/specializer/kernels/templates/__init__.py rename to sparta/kernels/templates/__init__.py diff --git a/sparta/specializer/kernels/templates/openai_sparse_matmul_dds.cuh.j2 b/sparta/kernels/templates/openai_sparse_matmul_dds.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/openai_sparse_matmul_dds.cuh.j2 rename to sparta/kernels/templates/openai_sparse_matmul_dds.cuh.j2 diff --git a/sparta/specializer/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 b/sparta/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 similarity 
index 100% rename from sparta/specializer/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 rename to sparta/kernels/templates/openai_sparse_matmul_dsd.cuh.j2 diff --git a/sparta/specializer/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 b/sparta/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 rename to sparta/kernels/templates/openai_sparse_matmul_sdd.cuh.j2 diff --git a/sparta/specializer/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 b/sparta/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 rename to sparta/kernels/templates/sparta_sparse_matmul_dds.cuh.j2 diff --git a/sparta/specializer/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 b/sparta/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 rename to sparta/kernels/templates/sparta_sparse_matmul_dsd.cuh.j2 diff --git a/sparta/specializer/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 b/sparta/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 rename to sparta/kernels/templates/sparta_sparse_matmul_sdd.cuh.j2 diff --git a/sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 b/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 rename to sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 diff --git a/sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 b/sparta/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 similarity index 100% rename from sparta/specializer/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 rename to sparta/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 diff --git a/sparta/nn/__init__.py b/sparta/nn/__init__.py index cf1dd627..28427b94 100644 --- a/sparta/nn/__init__.py +++ b/sparta/nn/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
-from sparta.specializer import OperatorBase, SparseLinear, SparseBatchMatMul, SparseSoftmax, SparseAttention, DynamicSparseMoE, SeqlenDynamicSparseAttention +from sparta.operators import SparseOperator, SparseLinear, SparseMatMul, SparseBatchMatMul, SparseSoftmax, SparseBatchSoftmax, SparseAttention, DynamicSparseMoE, SeqlenDynamicSparseAttention from sparta.nn.module_tuner import tune_combined_module as tune, build_combined_module as build diff --git a/sparta/nn/module_tuner.py b/sparta/nn/module_tuner.py index 122b36d1..d353a1ef 100644 --- a/sparta/nn/module_tuner.py +++ b/sparta/nn/module_tuner.py @@ -10,7 +10,7 @@ import numpy as np from sparta.tuning import TunableItemCfg, GridSearchTuner, RandomSearchTuner -from sparta.specializer import OperatorBase +from sparta.operators import SparseOperator _logger = logging.Logger(__name__) @@ -19,7 +19,7 @@ def tune_sparse_module( - module: OperatorBase, + module: SparseOperator, name: str, sample_inputs: List[torch.Tensor], sample_grads: Optional[List[torch.Tensor]] = None, @@ -138,7 +138,7 @@ def tune_combined_module( sample_grads_dict = {'root': sample_grads} hook_handlers = [] - def register_hooks(op: OperatorBase, name: str): + def register_hooks(op: SparseOperator, name: str): hook_handlers.append(op.register_forward_hook(get_input_hook(sample_inputs_dict, name))) hook_handlers.append(op.register_full_backward_hook(get_grad_hook(sample_grads_dict, name))) @@ -169,7 +169,7 @@ def register_hooks(op: OperatorBase, name: str): else: _logger.setLevel(logging.WARNING) - def tune(op: OperatorBase, name: str): + def tune(op: SparseOperator, name: str): best_configs[name] = tune_sparse_module( module=op, name=name, @@ -203,7 +203,7 @@ def build_combined_module( sample_inputs_dict = {'root': sample_inputs} hook_handlers = [] - def register_hooks(op: OperatorBase, name: str): + def register_hooks(op: SparseOperator, name: str): hook_handlers.append(op.register_forward_hook(get_input_hook(sample_inputs_dict, name))) iter_sparse_modules(module, 'root', register_hooks) @@ -215,7 +215,7 @@ def register_hooks(op: OperatorBase, name: str): for handler in hook_handlers: handler.remove() - def build(op: OperatorBase, name: str): + def build(op: SparseOperator, name: str): op.build(configs[name], sample_inputs=sample_inputs_dict[name]) iter_sparse_modules(module, 'root', build) @@ -224,9 +224,9 @@ def build(op: OperatorBase, name: str): def iter_sparse_modules( module: torch.nn.Module, module_name: str, - func: Callable[[OperatorBase, str], None], + func: Callable[[SparseOperator, str], None], ): - if isinstance(module, OperatorBase): + if isinstance(module, SparseOperator): func(module, module_name) return for child_name, child_module in module.named_children(): diff --git a/sparta/operators/__init__.py b/sparta/operators/__init__.py new file mode 100644 index 00000000..644f1796 --- /dev/null +++ b/sparta/operators/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. 
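[Editor's note] The renamed package keeps `sparta.nn.tune` as the entry point, but the modules it walks are now `SparseOperator` subclasses of `torch.nn.Module`. A minimal usage sketch, assuming the constructor signature `SparseBatchMatMul(mode, transpose_A, transpose_B, biased, compressed=True)` introduced later in this patch (shapes and sparsity are illustrative, not taken from the patch):

    import torch
    import sparta

    B, M, K, N = 4, 1024, 1024, 1024

    # 'dds' mode: dense A x dense B => sparse C, so the mask lives on the output.
    matmul = sparta.nn.SparseBatchMatMul('dds', False, False, False)
    matmul.set_mask(sparta.testing.block_mask((M, N), sparsity=0.99))

    # SparseOperator is a torch.nn.Module, so tune() can hook it directly.
    sparta.nn.tune(matmul, sample_inputs=[
        torch.rand((B, M, K), device='cuda'),
        torch.rand((B, K, N), device='cuda'),
    ])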
+ +from sparta.operators.operator_base import SparseOperator, SparseAutoGrad, Port, SparsityAttr +from sparta.operators.sparse_matmul import SparseBatchMatMul, SparseMatMul, SparseLinear +from sparta.operators.sparse_softmax import SparseBatchSoftmax, SparseSoftmax +from sparta.operators.sparse_attention import SparseAttention +from sparta.operators.sparse_moe import DynamicSparseMoE +from sparta.operators.sparse_seqlen_attention import SeqlenDynamicSparseAttention diff --git a/sparta/specializer/functional/function_base.py b/sparta/operators/operator_base.py similarity index 74% rename from sparta/specializer/functional/function_base.py rename to sparta/operators/operator_base.py index 77dd9130..3fc4f9b9 100644 --- a/sparta/specializer/functional/function_base.py +++ b/sparta/operators/operator_base.py @@ -4,18 +4,24 @@ from __future__ import annotations import abc -from typing import Any, Dict, List, Callable, Optional +from typing import Any, Dict, List, Tuple, Callable, Optional import torch from sparta.tesa import get_bcs_function, BCSIndexes -from sparta.specializer.kernels import KernelBase +from sparta.kernels import KernelBase from sparta.testing import profile class SparsityAttr(object): - def __init__(self, BCSR: bool, BCSC: bool): + def __init__( + self, + operator: SparseOperator, + param_map: Dict[str, str], + BCSR: bool, + BCSC: bool, + ): self.BCSR = BCSR self.BCSC = BCSC self.mask: torch.Tensor = None @@ -47,9 +53,9 @@ def _update_indexes(self): class Port(object): - def __init__(self, func: SparseFunctionBase, name: str, fine_mask: bool = True): + def __init__(self, func: SparseOperator, name: str, fine_mask: bool = True): self.name = name - self.funcs: List[SparseFunctionBase] = [func] + self.funcs: List[SparseOperator] = [func] self.attr: SparsityAttr = None self._sample_data: torch.Tensor = None # Always dense self._fine_mask = fine_mask @@ -78,25 +84,31 @@ def connect(self, other: Port): for func in other.funcs: func.ports[other.name] = self self.funcs.append(func) - if self.attr is not None and other.attr is not None: + if self.attr is None: + self.attr = other.attr + elif other.attr is not None: self.attr.update_axis(other.attr.BCSR, other.attr.BCSC) -class SparseFunctionBase(Callable): +class SparseOperator(torch.nn.Module): def __init__(self): - self.kernels: Dict[str, Dict[str, KernelBase]] = {} + super().__init__() + self._kernels: Dict[str, Dict[str, KernelBase]] = {} self._compiled_kernels: Dict[str, KernelBase] = {} self.ports: Dict[str, Port] = {} self._sparse_port: str = '' - self.forward: Callable = None - self.backward: SparseFunctionBase = None + self.forward_func = self.reference + self.shape: Tuple = None def get_sparse_attr(self): return self.ports[self._sparse_port].attr - def __call__(self, *inputs): - return self.forward(*inputs) + def set_mask(self, mask: torch.Tensor): + self.get_sparse_attr().set_mask(mask) + + def forward(self, *inputs): + return self.forward_func(*inputs) @abc.abstractmethod def _set_forward(self): @@ -119,8 +131,8 @@ def build( self._read_sample_inputs(sample_inputs) self._compiled_kernels: Dict[str, KernelBase] = {} for kernel_name, params in config.items(): - if kernel_name in self.kernels: - kernel = self.kernels[kernel_name][params['_impl']] + if kernel_name in self._kernels: + kernel = self._kernels[kernel_name][params['_impl']] self._compile_kernel(kernel_name, kernel, params) self._compiled_kernels[kernel_name] = kernel self._set_forward() @@ -159,15 +171,26 @@ def estimate_kernel(self, kernel_name: str): return 
kernel.estimated_latency_per_flop * flops @abc.abstractmethod - def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): - """Read input data from input port(s) and set output data to output port(s).""" + def reference(self, *inputs): + """Reference forward function. Syncs data with the ports at the same time.""" + + def get_search_space(self, backward: bool = False): + """Get the search space of the sparse operator.""" + pass + + def get_connections(self, backward: bool = False): + """Get cross-kernel connected hyper parameters of the sparse operator.""" + pass -class SparseAutoGradFunction(SparseFunctionBase): +class SparseAutoGrad(SparseOperator): __static_func__: torch.autograd.Function = None - def __call__(self, *inputs): + def _set_backward(self, backward_op: SparseOperator): + self.backward_op = backward_op + + def forward(self, *inputs): return self.__static_func__.apply(self, *inputs) def build( @@ -176,7 +199,7 @@ def build( sample_inputs: Optional[List[torch.Tensor]] = None, ): super().build(config, sample_inputs) - self.backward.build(config, sample_inputs) + self.backward_op.build(config, sample_inputs) def profile_kernel( self, @@ -185,17 +208,13 @@ def profile_kernel( num_iters: int = 100, cuda: bool = False ): - if kernel_name in self.kernels: + if kernel_name in self._kernels: return super().profile_kernel(kernel_name, num_warmups, num_iters, cuda) - elif self.backward is not None: - return self.backward.profile_kernel(kernel_name, num_warmups, num_iters, cuda) + else: + return self.backward_op.profile_kernel(kernel_name, num_warmups, num_iters, cuda) def estimate_kernel(self, kernel_name: str): - if kernel_name in self.kernels: + if kernel_name in self._kernels: return super().estimate_kernel(kernel_name) - elif self.backward is not None: - return self.backward.estimate_kernel(kernel_name) - - @abc.abstractmethod - def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): - """Read grad data from output port(s) and backward by auto-grad.""" + else: + return self.backward_op.estimate_kernel(kernel_name) diff --git a/sparta/operators/sparse_attention.py b/sparta/operators/sparse_attention.py new file mode 100644 index 00000000..1de71db4 --- /dev/null +++ b/sparta/operators/sparse_attention.py @@ -0,0 +1,69 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +from typing import Any, Dict, List, Optional +import warnings + +import torch +import numpy as np + +from sparta.operators import SparseBatchMatMul, SparseBatchSoftmax + + +class SparseAttention(torch.nn.Module): + r"""The sparse attention operator. + + .. math:: + \text{Attention}(Q, K, V) = \text{Softmax}(Q K) V + + Args: + mask (torch.Tensor): The mask tensor of shape :math:`(N_{target}, N_{source})`, + where :math:`N_{target}` is the target sequence length + and :math:`N_{source}` is the source sequence length. + + Shape: + - Input1: :math:`(B \times H, N_{target}, E)` where :math:`B` is the batch size, + :math:`H` is the number of heads and :math:`E` is the embed dimension. + - Input2: :math:`(B \times H, N_{source}, E)`. + - Input3: :math:`(B \times H, N_{source}, E)`, same shape as the second input. + - Output: :math:`(B \times H, N_{target}, E)`, same shape as the first input. + + Examples: + + ..
code-block:: python + + B, H, Ns, Nt, E = 4, 4, 1024, 1024, 1024 + + # Create a mask + mask = sparta.testing.block_mask((Nt, Ns), sparsity=0.99) + + # Create a sparse attention operator using the mask + sparse_attention = sparta.nn.SparseAttention(mask=mask) + + # Tune the sparse attention operator + sparta.nn.tune(sparse_attention, sample_inputs=[ + torch.rand((B * H, Nt, E), device='cuda'), + torch.rand((B * H, Ns, E), device='cuda'), + torch.rand((B * H, Ns, E), device='cuda'), + ]) + + """ + + def __init__(self, mask: Optional[torch.Tensor] = None): + super().__init__() + self._matmul_qk = SparseBatchMatMul('dds', False, True, False, True) + self._softmax = SparseBatchSoftmax(True, temperature=None) + self._matmul_out = SparseBatchMatMul('sdd', False, False, False, True) + self._matmul_qk.ports['C'].connect(self._softmax.ports['x']) + self._softmax.ports['y'].connect(self._matmul_out.ports['A']) + self._sparse_attr = self._matmul_qk.get_sparse_attr() + if mask is not None: + self.set_mask(mask) + + def set_mask(self, mask: torch.Tensor): + self._sparse_attr.set_mask(mask) + + def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor): + QK = self._matmul_qk(Q, K) + SM = self._softmax(QK) + return self._matmul_out(SM, V) diff --git a/sparta/specializer/functional/batch_matmul.py b/sparta/operators/sparse_matmul.py similarity index 79% rename from sparta/specializer/functional/batch_matmul.py rename to sparta/operators/sparse_matmul.py index c017d27b..692804b3 100644 --- a/sparta/specializer/functional/batch_matmul.py +++ b/sparta/operators/sparse_matmul.py @@ -6,11 +6,11 @@ import torch import numpy as np -from sparta.specializer.kernels import KernelBase, SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from sparta.specializer.functional.function_base import Port, SparsityAttr, SparseFunctionBase, SparseAutoGradFunction +from sparta.kernels import KernelBase, SparTASparseMatMulKernel, OpenAISparseMatMulKernel +from sparta.operators import Port, SparsityAttr, SparseOperator, SparseAutoGrad -class SparseBatchMatMulForward(SparseFunctionBase): +class SparseBatchMatMulForward(SparseOperator): __batched__ = True @@ -59,7 +59,7 @@ def __init__( 'compressed': compressed, 'batched': self.__batched__, } - self.kernels['forward'] = { + self._kernels['forward'] = { 'sparta': SparTASparseMatMulKernel(**specs), 'openai': OpenAISparseMatMulKernel(**specs), } @@ -96,7 +96,7 @@ def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str def _set_forward(self): if 'forward' in self._compiled_kernels: - self.forward = self._compiled_kernels['forward'] + self.forward_func = self._compiled_kernels['forward'] def _kernel_func_call(self, kernel_name: str): A = self.ports['A'].get_data(compressed=self._compressed) @@ -116,9 +116,9 @@ def _calc_kernel_flops(self, kernel_name: str): sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num return np.prod(self.shape) * sparse_rate - def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): - if sample_inputs is not None: - self._read_sample_inputs(sample_inputs) + def reference(self, *inputs): + if len(inputs) > 0: + self._read_sample_inputs(inputs) A = self.ports['A'].get_data(compressed=False) B = self.ports['B'].get_data(compressed=False) if self._transpose_A: @@ -130,6 +130,7 @@ def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): bias = self.ports['bias'].get_data() C += bias.unsqueeze(1) if self.__batched__ else bias self.ports['C'].set_data(C) + return C class 
SparseMatMulForward(SparseBatchMatMulForward): @@ -137,7 +138,7 @@ class SparseMatMulForward(SparseBatchMatMulForward): __batched__ = False -class SparseBatchMatMulBackward(SparseFunctionBase): +class SparseBatchMatMulBackward(SparseOperator): __batched__ = True @@ -195,11 +196,11 @@ def __init__( 'batched': self.__batched__, } - self.kernels['backward:A'] = { + self._kernels['backward:A'] = { 'sparta': SparTASparseMatMulKernel(**A_spec), 'openai': OpenAISparseMatMulKernel(**A_spec), } - self.kernels['backward:B'] = { + self._kernels['backward:B'] = { 'sparta': SparTASparseMatMulKernel(**B_spec), 'openai': OpenAISparseMatMulKernel(**B_spec), } @@ -238,8 +239,8 @@ def _set_forward(self): else: backward_B = lambda grad_C, A: kernel_B(A, grad_C) if self._mode == 'dds' and self._compressed: - C_indexes = self.ports['C'].attr.indexes - backward_bias = lambda grad_C: C_indexes.sum(grad_C, axis=-2) + C_attr = self.ports['C'].attr + backward_bias = lambda grad_C: C_attr.indexes.sum(grad_C, axis=-2) else: backward_bias = lambda grad_C: grad_C.sum(-2) @@ -253,7 +254,7 @@ def backward(grad, A, B, needs_grad): grad_bias = backward_bias(grad) return grad_A, grad_B, grad_bias - self.forward = backward + self.forward_func = backward def _kernel_func_call(self, kernel_name: str): grad_C = self.ports['C'].get_data(grad=True, compressed=self._compressed) @@ -286,7 +287,7 @@ def _calc_kernel_flops(self, kernel_name: str): sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num return np.prod(self.shape) * sparse_rate - def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): + def reference(self, *inputs): pass @@ -300,12 +301,12 @@ class _SparseMatMul(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, - func: SparseAutoGradFunction, + func: SparseAutoGrad, *inputs, ): ctx.save_for_backward(inputs[0], inputs[1]) - ctx.backward = func.backward - return func.forward(*inputs) + ctx.backward = func.backward_op + return func.forward_func(*inputs) @staticmethod def backward(ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor): @@ -313,7 +314,7 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor): return None, *ctx.backward(grad, A, B, ctx.needs_input_grad) -class SparseBatchMatMul(SparseAutoGradFunction, SparseBatchMatMulForward): +class SparseBatchMatMul(SparseAutoGrad, SparseBatchMatMulForward): __static_func__ = _SparseMatMul @@ -323,28 +324,24 @@ def __init__( transpose_A: bool, transpose_B: bool, biased: bool, - compressed: bool, + compressed: bool = True, ): super().__init__(mode, transpose_A, transpose_B, biased, compressed) - self.backward = SparseBatchMatMulBackward( + self._set_backward(SparseBatchMatMulBackward( mode=mode, transpose_A=transpose_A, transpose_B=transpose_B, biased=biased, compressed=compressed, ports=self.ports, - ) + )) def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): super()._read_sample_inputs(sample_inputs) - self.backward.shape = self.shape - - def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): - self.ports['C'].set_data(sample_grads[0], grad=True) - self.ports['C'].get_data().backward(sample_grads[0]) + self.backward_op.shape = self.shape -class SparseMatMul(SparseAutoGradFunction, SparseMatMulForward): +class SparseMatMul(SparseAutoGrad, SparseMatMulForward): __static_func__ = _SparseMatMul @@ -354,23 +351,63 @@ def __init__( transpose_A: bool, transpose_B: bool, biased: bool, - compressed: bool, + compressed: bool = True, 
): super().__init__(mode, transpose_A, transpose_B, biased, compressed) - self.backward = SparseMatMulBackward( + self._set_backward(SparseMatMulBackward( mode=mode, transpose_A=transpose_A, transpose_B=transpose_B, biased=biased, compressed=compressed, ports=self.ports, - ) + )) + + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + super()._read_sample_inputs(sample_inputs) + self.backward_op.shape = self.shape + + +class SparseLinear(SparseMatMul): + + def __init__(self, raw_module: torch.nn.Linear, mode: str = 'dsd'): + self._biased = raw_module.bias is not None + self._compressed = mode == 'dsd' + super().__init__(mode, False, True, self._biased, self._compressed) + self.weight: torch.nn.Parameter = None + self.bias = raw_module.bias + self.ports['B'].set_data(raw_module.weight) + if self._biased: + self.ports['bias'].set_data(raw_module.bias) + self.forward = self._forward_with_bias + else: + self.forward = self._forward_without_bias + + def _update_weight(self): + if 'forward' in self._compiled_kernels: + weight = self.ports['B'].get_data(compressed=self._compressed) + self.weight = torch.nn.Parameter(weight, requires_grad=True) + + def set_mask(self, mask: torch.Tensor): + super().set_mask(mask) + self._update_weight() + + def build( + self, + config: Dict[str, Dict[str, Any]], + sample_inputs: Optional[List[torch.Tensor]] = None, + ): + super().build(config, sample_inputs) + self._update_weight() def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + sample_inputs.append(self.ports['B'].get_data()) + if self._biased: + sample_inputs.append(self.ports['bias'].get_data()) super()._read_sample_inputs(sample_inputs) - self.backward.shape = self.shape - def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): - self.ports['C'].set_data(sample_grads[0], grad=True) - self.ports['C'].get_data().backward(sample_grads[0]) + def _forward_with_bias(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x, self.weight, self.bias) + def _forward_without_bias(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x, self.weight) diff --git a/sparta/specializer/operators/sparse_moe.py b/sparta/operators/sparse_moe.py similarity index 100% rename from sparta/specializer/operators/sparse_moe.py rename to sparta/operators/sparse_moe.py diff --git a/sparta/specializer/operators/seqlen_dynamic_sparse_attention.py b/sparta/operators/sparse_seqlen_attention.py similarity index 100% rename from sparta/specializer/operators/seqlen_dynamic_sparse_attention.py rename to sparta/operators/sparse_seqlen_attention.py diff --git a/sparta/specializer/functional/batch_softmax.py b/sparta/operators/sparse_softmax.py similarity index 70% rename from sparta/specializer/functional/batch_softmax.py rename to sparta/operators/sparse_softmax.py index d92457ed..2fe49c93 100644 --- a/sparta/specializer/functional/batch_softmax.py +++ b/sparta/operators/sparse_softmax.py @@ -6,20 +6,20 @@ import torch import numpy as np -from sparta.specializer.kernels import KernelBase, SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.specializer.functional.function_base import Port, SparsityAttr, SparseFunctionBase, SparseAutoGradFunction +from sparta.kernels import KernelBase, SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel +from sparta.operators.operator_base import Port, SparsityAttr, SparseOperator, SparseAutoGrad from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference 
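[Editor's note] `SparseLinear` now wraps the dense module directly and derives its sparse mode from the constructor instead of from three separate mask arguments. A sketch under the same assumptions as the previous example, mirroring the old docstring's weight-mask case (`mode='dsd'` marks the weight, i.e. input B, as the sparse operand, so `set_mask` refreshes the compressed weight parameter):

    import torch
    import sparta

    batch, in_features, out_features = 1024, 1024, 1024
    dense = torch.nn.Linear(in_features, out_features, device='cuda')

    # The weight mask has shape (out_features, in_features), matching dense.weight.
    linear = sparta.nn.SparseLinear(dense, mode='dsd')
    linear.set_mask(sparta.testing.block_mask((out_features, in_features), sparsity=0.99))

    sparta.nn.tune(linear, sample_inputs=[
        torch.rand((batch, in_features), device='cuda'),
    ])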
-class SparseBatchSoftmaxForward(SparseFunctionBase): +class SparseBatchSoftmaxForward(SparseOperator): __batched__ = True __direction__ = 'forward' - def __init__(self, compressed: bool, temperature: float = 1): + def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): super().__init__() self._compressed = compressed - self._T = np.float32(1 / temperature) + self._T = None if temperature is None else np.float32(1 / temperature) self._sparse_port = 'y' sparse_attr = SparsityAttr(True, False) @@ -27,7 +27,7 @@ def __init__(self, compressed: bool, temperature: float = 1): self.ports[port_name] = Port(self, port_name) self.ports[port_name].attr = sparse_attr - self.kernels[self.__direction__] = { + self._kernels[self.__direction__] = { 'sparta': { 'forward': SparTASparseSoftmaxForwardKernel, 'backward': SparTASparseSoftmaxBackwardKernel, @@ -51,6 +51,8 @@ def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): H, W = x.shape[-2:] self.shape = (batch_size, H, W) self.ports['x'].set_data(x) + if self._T is None: + self.set_temperature(np.sqrt(W)) def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): sparse_attr = self.get_sparse_attr() @@ -65,7 +67,7 @@ def _set_forward(self): if self.__direction__ in self._compiled_kernels: kernel = self._compiled_kernels[self.__direction__] sparse_attr = self.get_sparse_attr() - self.forward = lambda *inputs: kernel(*inputs, sparse_attr.mask, self._T) + self.forward_func = lambda *inputs: kernel(*inputs, sparse_attr.mask, self._T) def _kernel_func_call(self, kernel_name: str): x = self.ports['x'].get_data(compressed=self._compressed) @@ -81,13 +83,14 @@ def _calc_kernel_flops(self, kernel_name: str): sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num return np.prod(self.shape) * sparse_rate - def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): - if sample_inputs is not None: - self._read_sample_inputs(sample_inputs) + def reference(self, *inputs): + if len(inputs) > 0: + self._read_sample_inputs(inputs) x = self.ports['x'].get_data(compressed=False) mask = self.get_sparse_attr().mask y = sparse_softmax_forward_reference(x, mask, 1 / self._T) self.ports['y'].set_data(y) + return y class SparseSoftmaxForward(SparseBatchSoftmaxForward): @@ -113,7 +116,7 @@ def _kernel_func_call(self, kernel_name: str): def _kernel_reference(self, kernel_name: str): return self.ports['x'].get_data(grad=True, compressed=self._compressed) - def reference_forward(self, sample_inputs: Optional[List[torch.Tensor]] = None): + def reference(self, *inputs): pass @@ -127,11 +130,11 @@ class _SparseSoftmax(torch.autograd.Function): @staticmethod def forward( ctx: torch.autograd.function.FunctionCtx, - func: SparseAutoGradFunction, + func: SparseAutoGrad, x: torch.Tensor, ): - ctx.backward = func.backward - y = func.forward(x) + ctx.backward = func.backward_op + y = func.forward_func(x) ctx.save_for_backward(y) return y @@ -144,45 +147,37 @@ def backward(ctx: torch.autograd.function.FunctionCtx, grad: torch.Tensor): return None, None -class SparseBatchSoftmax(SparseAutoGradFunction, SparseBatchSoftmaxForward): +class SparseBatchSoftmax(SparseAutoGrad, SparseBatchSoftmaxForward): __static_func__ = _SparseSoftmax - def __init__(self, compressed: bool, temperature: float = 1): + def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): super().__init__(compressed, temperature) - self.backward = SparseBatchSoftmaxBackward(compressed, temperature) - 
self.backward.ports = self.ports + self._set_backward(SparseBatchSoftmaxBackward(compressed, temperature)) + self.backward_op.ports = self.ports def set_temperature(self, temperature: float): super().set_temperature(temperature) - self.backward.set_temperature(temperature) + self.backward_op.set_temperature(temperature) def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): super()._read_sample_inputs(sample_inputs) - self.backward.shape = self.shape + self.backward_op.shape = self.shape - def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): - self.ports['y'].set_data(sample_grads[0], grad=True) - self.ports['y'].get_data().backward(sample_grads[0]) - -class SparseSoftmax(SparseAutoGradFunction, SparseSoftmaxForward): +class SparseSoftmax(SparseAutoGrad, SparseSoftmaxForward): __static_func__ = _SparseSoftmax - def __init__(self, compressed: bool, temperature: float = 1): + def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): super().__init__(compressed, temperature) - self.backward = SparseSoftmaxBackward(compressed, temperature) - self.backward.ports = self.ports + self._set_backward(SparseSoftmaxBackward(compressed, temperature)) + self.backward_op.ports = self.ports def set_temperature(self, temperature: float): super().set_temperature(temperature) - self.backward.set_temperature(temperature) + self.backward_op.set_temperature(temperature) def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): super()._read_sample_inputs(sample_inputs) - self.backward.shape = self.shape - - def reference_backward(self, sample_grads: Optional[List[torch.Tensor]] = None): - self.ports['y'].set_data(sample_grads[0], grad=True) - self.ports['y'].get_data().backward(sample_grads[0]) + self.backward_op.shape = self.shape diff --git a/sparta/specializer/__init__.py b/sparta/specializer/__init__.py deleted file mode 100644 index f37bb0aa..00000000 --- a/sparta/specializer/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from sparta.specializer.operators import OperatorBase, SparseLinear, SparseBatchMatMul, SparseSoftmax, SparseAttention, DynamicSparseMoE, SeqlenDynamicSparseAttention diff --git a/sparta/specializer/functional/__init__.py b/sparta/specializer/functional/__init__.py deleted file mode 100644 index f6bc7d35..00000000 --- a/sparta/specializer/functional/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from sparta.specializer.functional.function_base import Port, SparsityAttr, SparseFunctionBase -from sparta.specializer.functional.batch_matmul import SparseMatMul, SparseBatchMatMul -from sparta.specializer.functional.batch_softmax import SparseSoftmax, SparseBatchSoftmax diff --git a/sparta/specializer/kernels/__init__.py b/sparta/specializer/kernels/__init__.py deleted file mode 100644 index 0b934bbc..00000000 --- a/sparta/specializer/kernels/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
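[Editor's note] With `temperature=None`, the softmax operators above defer the temperature to `sqrt(W)` of the first sample input (resolved in `_read_sample_inputs`), which is what the new `SparseAttention` relies on when it constructs `SparseBatchSoftmax(True, temperature=None)`. A small sketch, same assumptions as the earlier examples:

    import torch
    import sparta

    B, H, W = 4, 1024, 1024

    # temperature=None resolves to sqrt(W) once a sample input is read.
    softmax = sparta.nn.SparseBatchSoftmax(compressed=False, temperature=None)
    softmax.set_mask(sparta.testing.block_mask((H, W), sparsity=0.99))

    sparta.nn.tune(softmax, sample_inputs=[
        torch.rand((B, H, W), device='cuda'),
    ])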
- -from sparta.specializer.kernels.kernel_base import KernelBase -from sparta.specializer.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from sparta.specializer.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel diff --git a/sparta/specializer/operators/__init__.py b/sparta/specializer/operators/__init__.py deleted file mode 100644 index 80462a7b..00000000 --- a/sparta/specializer/operators/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from sparta.specializer.operators.operator_base import OperatorBase -from sparta.specializer.operators.sparse_linear import SparseLinear -from sparta.specializer.operators.sparse_matmul import SparseBatchMatMul -from sparta.specializer.operators.sparse_softmax import SparseSoftmax -from sparta.specializer.operators.sparse_attention import SparseAttention -from sparta.specializer.operators.sparse_moe import DynamicSparseMoE -from sparta.specializer.operators.seqlen_dynamic_sparse_attention import SeqlenDynamicSparseAttention diff --git a/sparta/specializer/operators/operator_base.py b/sparta/specializer/operators/operator_base.py deleted file mode 100644 index 23e194b8..00000000 --- a/sparta/specializer/operators/operator_base.py +++ /dev/null @@ -1,129 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import abc -import warnings -from typing import Any, List, Dict, Optional, Type - -import torch - -from sparta.specializer.functional import SparseFunctionBase - - -class OperatorBase(torch.nn.Module): - """Base class of sparse operators. - - Each sparse operator contains a sparse function context and applies a PyTorch - autograd function to do forward and backward calculation. - - SparTA does not handle parameter initialization. Instead, a PyTorch dense operator - is required to provide parameter(s) for parametric operators. - - Args: - raw_module (Optional[torch.nn.Module]): The corresponding dense operator. - - """ - - __base_class__: Type[torch.nn.Module] = None - __sparse_func__: Type[torch.autograd.Function] = None - - def __init__(self, raw_module: Optional[torch.nn.Module] = None): - if self.__base_class__ is not None and type(raw_module) is not self.__base_class__: - raise ValueError(f'expected a {self.__base_class__} module') - super().__init__() - self._raw_module = raw_module - self._sparse_ctx: SparseCtxBase = None - self._shape: Dict[str, int] = None - self.ready: bool = False - - @abc.abstractmethod - def update_mask(self, *args, **kwargs): - """Translate and set input mask(s) to sparse port(s).""" - - def _set_mask(self, masks: Dict[str, torch.Tensor]): - for port_name, ports in self._sparse_ctx.sparse_ports.items(): - for port in ports: - port.set_mask(masks[port_name]) - if self.ready: - self._sparse_ctx.update_func() - - @abc.abstractmethod - def _read_sample_inputs(self, *args): - """Read missing shape value from sample inputs.""" - - def build(self, config: Dict[str, Dict[str, Any]], sample_inputs: List[Any]): - """The build function includes following steps: - 1. Read and confirm the operator shape from sample inputs. - 2. Set implementations for the sparse context. - 3. Set shape for the sparse context. - 4. Compile with input config. - - Args: - config (Dict[str, Any]): A dictionary gives value of each required - hyper parameter of the sparse context. 
- sample_inputs (List[Any]): List of sample inputs. - """ - self._read_sample_inputs(*sample_inputs) - self._sparse_ctx.select_impls({k: v['_impl'] for k, v in config.items()}) - self._sparse_ctx.set_shape(**self._shape) - self._compile(config=config) - - def _compile(self, config: Optional[Dict[str, Dict[str, Any]]] = None): - """The compile function includes following steps: - 1. Build the sparse context with input config. - 2. Replace the forward function from dense to sparse version. - 3. Disconnect the raw module which may contain dense parameter(s). - 4. Mark the sparse operator as ready. - - Args: - config (Dict[str, Any]): A dictionary gives value of each required - hyper parameter of the sparse context. - """ - self._sparse_ctx.build(config) - self.forward = self._sparse_forward - self._raw_module = None - self.ready = True - - def _sparse_forward(self, *args): - """Apply the sparse autograd function.""" - return self.__sparse_func__.apply(self._sparse_ctx, *args) - - def _dense_forward(self, *args): - """The dense forward function for reference.""" - if self._raw_module is None: - return self._sparse_ctx.dense_forward(*args) - else: - return self._raw_module.forward(*args) - - def forward(self, *args) -> torch.Tensor: - """Forward function. Calls the corresponding dense operator if not built.""" - warnings.warn('the sparse module is not compiled, using the dense module to forward') - return self._dense_forward(*args) - - def set_sample_inputs( - self, - sample_inputs: List[torch.Tensor], - sample_grads: Optional[List[torch.Tensor]] = None, - ): - """Set sample inputs and gradients for tuning.""" - self._read_sample_inputs(*sample_inputs) - self._sparse_ctx.set_shape(**self._shape) - self._sparse_ctx.set_sample_inputs(sample_inputs, sample_grads) - - def get_search_space(self, backward: bool = False): - """Get search space of the sparse context.""" - return self._sparse_ctx.get_search_space(backward) - - def get_connections(self, backward: bool = False): - """Get cross-kernel connected hyper parameters of the sparse context.""" - return self._sparse_ctx.get_connections(backward) - - def get_sparse_indexes(self, port_name: str): - """Get TeSA indexes of specified sparse port.""" - return self._sparse_ctx.get_sparse_indexes(port_name) - - def get_kernel_placeholders(self, backward: bool = False): - """Get kernel placeholders. - Returns only forward kernel placeholders(s) if backward is not required. - """ - return self._sparse_ctx.get_kernel_placeholders(backward) diff --git a/sparta/specializer/operators/sparse_attention.py b/sparta/specializer/operators/sparse_attention.py deleted file mode 100644 index 2f1927b7..00000000 --- a/sparta/specializer/operators/sparse_attention.py +++ /dev/null @@ -1,238 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Any, Dict, List, Optional -import warnings - -import torch -import numpy as np - -from sparta.specializer.operators import OperatorBase, SparseBatchMatMul, SparseSoftmax -from sparta.specializer.kernels import KernelBase - - -class SparseAttention(OperatorBase): - r"""The sparse attention operator. - - .. math:: - \text{Attention}(Q, K, V) = \text{Softmax}(Q K) V - - Args: - mask (torch.Tensor): The mask tensor of shape :math:`(N_{target}, N_{sourse})`, - where :math:`N_{target}` is the target sequence length - and :math:`N_{sourse}` is the sourse sequence length. 
- - Shape: - - Input1: :math:`(B \times H, N_{target}, E)` where :math:`B` is the batch size, - :math:`H` is the number of heads and :math:`E` is the embed dimension. - - Input2: :math:`(B \times H, N_{sourse}, E)`. - - Input3: :math:`(B \times H, N_{sourse}, E)`, same shape as the second input. - - Output: :math:`(B \times H, N_{target}, E)`, same shape as the first input. - - Examples: - - .. code-block:: python - - B, H, Ns, Nt, E = 4, 4, 1024, 1024, 1024 - - # Create a mask - mask = sparta.testing.block_mask((Nt, Ns), sparsity=0.99) - - # Create a sparse attention operator using the mask - sparse_attention = sparta.nn.SparseAttention(mask=mask) - - # Tune the sparse attention operator - sparta.nn.tune(sparse_attention, sample_inputs=[ - torch.rand((B * H, Nt, E), device='cuda'), - torch.rand((B * H, Ns, E), device='cuda'), - torch.rand((B * H, Ns, E), device='cuda'), - ]) - - """ - - def __init__(self, mask: torch.Tensor): - super().__init__() - self.mask = mask - self._Nt, self._Ns = mask.shape - self._matmul_qk = SparseBatchMatMul( - C_mask=mask, - transpose_A=False, - transpose_B=True, - compressed=True, - ) - self._softmax = SparseSoftmax( - mask=mask, - compressed=True, - ) - self._matmul_out = SparseBatchMatMul( - A_mask=mask, - transpose_A=False, - transpose_B=False, - compressed=True, - ) - self._sparse_port: PortConfig = None - self.ready: bool = False - - def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor): - qk = self._matmul_qk.forward(query, key) - sm = self._softmax.forward(qk) - out = self._matmul_out.forward(sm, value) - return out - - def update_mask(self, mask: torch.Tensor): - if self.ready: - self._sparse_port.set_mask(mask) - self._matmul_qk._sparse_ctx.update_func() - self._softmax._sparse_ctx.update_func() - self._matmul_out._sparse_ctx.update_func() - else: - op_list: List[OperatorBase] = [self._matmul_qk, self._softmax, self._matmul_out] - for op in op_list: - for ports in op._sparse_ctx.sparse_ports.values(): - for port in ports: - port.set_mask(mask) - - def build(self, config: Dict[str, Any], sample_inputs: List[torch.Tensor]): - query, key, value = sample_inputs - - qB = np.prod(query.shape[:-2]) - qN, qE = query.shape[-2:] - kB = np.prod(key.shape[:-2]) - kN, kE = key.shape[-2:] - vB = np.prod(value.shape[:-2]) - vN, vE = value.shape[-2:] - assert qB == kB == vB, f'query, key and value should have the same batch size' - assert self._Nt == qN, f'expect query shape (?, {self._Nt}, ?), got {query.shape}' - assert self._Ns == kN, f'expect key shape (?, {self._Ns}, ?), got {key.shape}' - assert self._Ns == vN, f'expect value shape (?, {self._Ns}, ?), got {value.shape}' - assert qE == kE == vE, f'query, key and value should have the same embed dim' - - self._softmax.set_temperature(np.sqrt(qE)) - - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - qk = self._matmul_qk.forward(query, key) - sm = self._softmax.forward(qk) - - op_dict: Dict[str, OperatorBase] = { - 'qk': self._matmul_qk, - 'sm': self._softmax, - 'out': self._matmul_out, - } - inputs_dict: Dict[str, List[torch.Tensor]] = { - 'qk': [query, key], - 'sm': [qk], - 'out': [sm, value], - } - sparse_port_map: Dict[str, str] = { - 'qk': 'C', - 'sm': 'y', - 'out': 'A', - } - - kernels: List[KernelBase] = [] - port_names: List[str] = [] - for op_name, op in op_dict.items(): - for k, v in op._sparse_ctx.get_kernel_placeholders().items(): - kernel = v.possible_kernels[config[f'{op_name}/{k}']['_impl']] - port_names.append(v.port_map[sparse_port_map[op_name]]) - 
kernels.append(kernel) - self._sparse_port = kernels[0].ports[port_names[0]] - for kernel, port_name in zip(kernels[1:], port_names[1:]): - self._sparse_port.connect(kernel, port_name) - - for op_name, op in op_dict.items(): - op.build( - config={ - k.split('/')[1]: v - for k, v in config.items() - if k.startswith(op_name) - }, - sample_inputs=inputs_dict[op_name], - ) - - self.ready = True - - def set_sample_inputs( - self, - sample_inputs: List[torch.Tensor], - sample_grads: Optional[List[torch.Tensor]] = None, - ): - query, key, value = sample_inputs - self._softmax.set_temperature(np.sqrt(query.shape[-1])) - - if sample_grads is not None: - query.requires_grad = True - key.requires_grad = True - value.requires_grad = True - - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - qk = self._matmul_qk.forward(query, key) - sample_grad_qk = None - sm = self._softmax.forward(qk) - sample_grad_sm = None - - if sample_grads is not None: - qk.retain_grad() - sm.retain_grad() - grad_out, = sample_grads - - with warnings.catch_warnings(): - warnings.simplefilter('ignore') - out = self._matmul_out.forward(sm, value) - out.backward(grad_out) - - sample_grad_sm = [sm.grad] - sample_grad_qk = [qk.grad] - query = query.detach() - key = key.detach() - value = value.detach() - - self._matmul_qk.set_sample_inputs([query, key], sample_grad_qk) - self._softmax.set_sample_inputs([qk], sample_grad_sm) - self._matmul_out.set_sample_inputs([sm, value], sample_grads) - - def _combine_dict( - self, - qk_dict: Dict[str, Any], - sm_dict: Dict[str, Any], - out_dict: Dict[str, Any], - backward: bool = False, - ): - return dict( - **{f'qk/{k}': v for k, v in qk_dict.items() if k.startswith('forward') or backward}, - **{f'sm/{k}': v for k, v in sm_dict.items() if k.startswith('forward') or backward}, - **{f'out/{k}': v for k, v in out_dict.items() if k.startswith('forward') or backward}, - ) - - def get_search_space(self, backward: bool = False): - return self._combine_dict( - qk_dict=self._matmul_qk.get_search_space(backward=True), - sm_dict=self._softmax.get_search_space(backward=True), - out_dict=self._matmul_out.get_search_space(backward=True), - backward=backward, - ) - - def get_connections(self, backward: bool = False): - return [ - self._combine_dict( - qk_dict=qk_params, - sm_dict=sm_params, - out_dict=out_params, - backward=backward, - ) - for qk_params, sm_params, out_params in zip( - self._matmul_qk.get_connections(backward=True), - self._softmax.get_connections(backward=True), - self._matmul_out.get_connections(backward=True), - ) - ] - - def get_kernel_placeholders(self, backward: bool = False): - return self._combine_dict( - qk_dict=self._matmul_qk.get_kernel_placeholders(backward=True), - sm_dict=self._softmax.get_kernel_placeholders(backward=True), - out_dict=self._matmul_out.get_kernel_placeholders(backward=True), - backward=backward, - ) diff --git a/sparta/specializer/operators/sparse_linear.py b/sparta/specializer/operators/sparse_linear.py deleted file mode 100644 index fc9780da..00000000 --- a/sparta/specializer/operators/sparse_linear.py +++ /dev/null @@ -1,172 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. 
- -from typing import Any, Dict, List, Optional - -import torch - -from sparta.specializer.operators import OperatorBase -from sparta.specializer.functional import SparseBatchMatMul - - -class SparseLinear(OperatorBase): - r"""The sparse linear operator: :math:`y = xA^T + b` - - Args: - raw_module (torch.nn.Linear): The corresponding dense linear operator. - input_mask (Optional[torch.Tensor]): The mask of input tensor. - If `input_mask` is set, the other two masks should be `None` - and the internal MatMul kernel will choose SD=>D mode. - weight_mask (Optional[torch.Tensor]): The mask of weight tensor. - If `weight_mask` is set, the other two masks should be `None` - and the internal MatMul kernel will choose DS=>D mode. - output_mask (Optional[torch.Tensor]): The mask of output tensor. - If `output_mask` is set, the other two masks should be `None` - and the internal MatMul kernel will choose DD=>S mode. - - Shape: - - Input: :math:`(B, H_{in})` where :math:`B = \text{batch_size}` - and :math:`H_{in} = \text{in_features}`. - - Output: :math:`(B, H_{out})` where :math:`H_{out} = \text{out_features}`. - - Attributes: - weight: The learnable weights of the module of shape :math:`(\text{out_features}, \text{in_features})`. - If `weight_mask` is set, the weight will be compressed to BCSR format. - bias: The learnable bias of the module of shape :math:`(\text{out_features})`. - It is a copy of the bias tensor in the raw module. - - Examples: - - .. code-block:: python - - batch_size, in_features, out_features = 1024, 1024, 1024 - - # Create a dense linear operator - dense_linear = torch.nn.Linear(in_features, out_features, device='cuda') - - # Create a weight mask - mask = sparta.testing.block_mask((out_features, in_features), sparsity=0.99) - - # Create a sparse linear operator using the dense operator and the weight mask - sparse_linear = sparta.nn.SparseLinear(dense_linear, weight_mask=mask) - - # Tune the sparse linear operator - sparta.nn.tune(sparse_linear, sample_inputs=[ - torch.rand((batch_size, in_features), device='cuda'), - ]) - - """ - - __base_class__ = torch.nn.Linear - __sparse_func__ = SparseBatchMatMul - - def __init__( - self, - raw_module: torch.nn.Linear, - input_mask: Optional[torch.Tensor] = None, - weight_mask: Optional[torch.Tensor] = None, - output_mask: Optional[torch.Tensor] = None, - ): - super().__init__(raw_module) - - M = None - N, K = raw_module.weight.shape - biased = raw_module.bias is not None - self._raw_weight = torch.clone(raw_module.weight) - self.weight = None - self.bias = raw_module.bias - - if input_mask is not None: - self._sparse_ctx = SparseBatchMatMulCtx('sdd', False, True, biased, False) - M = input_mask.shape[0] - elif weight_mask is not None: - self._sparse_ctx = SparseBatchMatMulCtx('dsd', False, True, biased, True) - elif output_mask is not None: - self._sparse_ctx = SparseBatchMatMulCtx('dds', False, True, biased, False) - M = output_mask.shape[0] - else: - raise ValueError(f'expected a sparse mask on input / weight / output') - - self._shape = {'batch_size': 1, 'M': M, 'K': K, 'N': N} - self.update_mask(input_mask, weight_mask, output_mask) - - def update_mask( - self, - input_mask: Optional[torch.Tensor] = None, - weight_mask: Optional[torch.Tensor] = None, - output_mask: Optional[torch.Tensor] = None, - ): - if sum(map(lambda x: x is not None, [input_mask, weight_mask, output_mask])) > 1: - raise ValueError(f'linear operators with multiple sparse masks are not supported') - - M, K, N = self._shape['M'], self._shape['K'], self._shape['N'] 
- - def check_mask_shape(mask: torch.Tensor, name: str, H: Optional[int], W: Optional[int]): - if H is None: - check_H = lambda x: True - H = '?' - else: - check_H = lambda x: x == H - if W is None: - check_W = lambda x: True - W = '?' - else: - check_W = lambda x: x == W - err_msg = f'expected {name} mask shape ({H}, {W}), got {mask.shape}' - assert check_H(mask.shape[0]) and check_W(mask.shape[1]), err_msg - - if input_mask is not None: - check_mask_shape(input_mask, 'input', M, K) - self._set_mask({'A': input_mask}) - elif weight_mask is not None: - check_mask_shape(weight_mask, 'weight', N, K) - self._set_mask({'B': weight_mask}) - elif output_mask is not None: - check_mask_shape(output_mask, 'output', M, N) - self._set_mask({'C': output_mask}) - else: - raise ValueError(f'expected a sparse mask on input / weight / output') - - if self.ready: - self._update_parameters() - - def _read_sample_inputs(self, A: torch.Tensor): - M, K = self._shape['M'], self._shape['K'] - if M is None: - assert K == A.shape[1], f'expect input shape (?, {K}), got {A.shape}' - self._shape['M'] = A.shape[0] - else: - assert (M, K) == A.shape, f'expect input shape ({M}, {K}), got {A.shape}' - - def _update_parameters(self): - weight = self._raw_weight - weight_indexes = self.get_sparse_indexes('B') - if weight_indexes is not None: - weight = weight_indexes.convert(weight.detach()) - self.weight = torch.nn.Parameter(weight, requires_grad=True) - - def _compile(self, config: Optional[Dict[str, Dict[str, Any]]] = None): - super()._compile(config) - self._update_parameters() - - def _sparse_forward(self, input_tensor: torch.Tensor): - inputs = [self._sparse_ctx, input_tensor, self.weight] - if self.bias is not None: - inputs.append(self.bias) - return self.__sparse_func__.apply(*inputs).squeeze(0) - - def set_sample_inputs( - self, - sample_inputs: List[torch.Tensor], - sample_grads: Optional[List[torch.Tensor]] = None, - ): - self._read_sample_inputs(*sample_inputs) - self._sparse_ctx.set_shape(**self._shape) - if self.bias is None: - sample_inputs = [sample_inputs[0], self._raw_weight] - else: - sample_inputs = [sample_inputs[0], self._raw_weight, self.bias] - sample_inputs = [x.unsqueeze(0) for x in sample_inputs] - if sample_grads is not None: - sample_grads = [x.unsqueeze(0) for x in sample_grads] - self._sparse_ctx.set_sample_inputs(sample_inputs, sample_grads) diff --git a/sparta/specializer/operators/sparse_matmul.py b/sparta/specializer/operators/sparse_matmul.py deleted file mode 100644 index 9aba9e49..00000000 --- a/sparta/specializer/operators/sparse_matmul.py +++ /dev/null @@ -1,136 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -from typing import Optional - -import torch - -from sparta.specializer.operators import OperatorBase -from sparta.specializer.functional import SparseBatchMatMul - - -class SparseBatchMatMul(OperatorBase): - r"""The sparse batch matrix multiplication operator: :math:`C = AB` - - Args: - A_mask (Optional[torch.Tensor]): The mask of the first input tensor. - If `A_mask` is set, the other two masks should be `None` - and the internal MatMul kernel will choose SD=>D mode. - B_mask (Optional[torch.Tensor]): The mask of the second input tensor. - If `B_mask` is set, the other two masks should be `None` - and the internal MatMul kernel will choose DS=>D mode. - C_mask (Optional[torch.Tensor]): The mask of the output tensor. - If `C_mask` is set, the other two masks should be `None` - and the internal MatMul kernel will choose DD=>S mode. 
- transpose_A (bool): Determines whether the first input tensor is transposed. - transpose_B (bool): Determines whether the second input tensor is transposed. - compressed (bool): Determines whether the sparse tensor is compressed to - BCSR / BCSC format. - - Shape: - - Input1: :math:`(B, K, M)` (if `transpose_A == True`) - or :math:`(B, M, K)` (if `transpose_A == False`). - If `A_mask` is set and `compressed == True`, the first input will be - compressed to BCSR / BCSC format and the shape will be :math:`(B, *)`. - - Input2: :math:`(B, N, K)` (if `transpose_B == True`) - or :math:`(B, K, N)` (if `transpose_B == False`). - If `B_mask` is set and `compressed == True`, the second input will be - compressed to BCSR / BCSC format and the shape will be :math:`(B, *)`. - - Output: :math:`(B, M, N)`. - If `C_mask` is set and `compressed == True`, the output will be - compressed to BCSR format and the shape will be :math:`(B, *)`. - - Examples: - - .. code-block:: python - - B, M, K, N = 4, 1024, 1024, 1024 - - # Create a output mask - mask = sparta.testing.block_mask((M, N), sparsity=0.99) - - # Create a sparse batch matmul operator using the mask - sparse_matmul = sparta.nn.SparseBatchMatMul(C_mask=mask) - - # Tune the sparse batch matmul operator - sparta.nn.tune(sparse_matmul, sample_inputs=[ - torch.rand((B, M, K), device='cuda'), - torch.rand((B, K, N), device='cuda'), - ]) - - """ - - __sparse_func__ = SparseBatchMatMul - - def __init__( - self, - A_mask: Optional[torch.Tensor] = None, - B_mask: Optional[torch.Tensor] = None, - C_mask: Optional[torch.Tensor] = None, - transpose_A: bool = False, - transpose_B: bool = False, - compressed: bool = True, - ): - super().__init__() - - self._transpose_A = transpose_A - self._transpose_B = transpose_B - ctx_args = { - 'transpose_A': transpose_A, - 'transpose_B': transpose_B, - 'biased': False, - 'compressed': compressed, - } - batch_size, M, K, N = None, None, None, None - - if A_mask is not None: - self._sparse_ctx = SparseBatchMatMulCtx(mode='sdd', **ctx_args) - if transpose_A: - K, M = A_mask.shape - else: - M, K = A_mask.shape - elif B_mask is not None: - self._sparse_ctx = SparseBatchMatMulCtx(mode='dsd', **ctx_args) - if transpose_B: - N, K = B_mask.shape - else: - K, N = B_mask.shape - elif C_mask is not None: - self._sparse_ctx = SparseBatchMatMulCtx(mode='dds', **ctx_args) - M, N = C_mask.shape - else: - raise ValueError(f'expected a sparse mask on A / B / C') - - self._shape = {'batch_size': batch_size, 'M': M, 'K': K, 'N': N} - self.update_mask(A_mask, B_mask, C_mask) - - def update_mask( - self, - A_mask: Optional[torch.Tensor] = None, - B_mask: Optional[torch.Tensor] = None, - C_mask: Optional[torch.Tensor] = None, - ): - # TODO: check shape conflicts - if sum(map(lambda x: x is not None, [A_mask, B_mask, C_mask])) > 1: - raise ValueError(f'linear operators with multiple sparse masks are not supported') - - if A_mask is not None: - self._set_mask({'A': A_mask}) - elif B_mask is not None: - self._set_mask({'B': B_mask}) - elif C_mask is not None: - self._set_mask({'C': C_mask}) - else: - raise ValueError(f'expected a sparse mask on A / B / C') - - def _read_sample_inputs(self, A: torch.Tensor, B: torch.Tensor): - # TODO: check shape conflicts - if self._transpose_A: - batch_size, K, M = A.shape - else: - batch_size, M, K = A.shape - if self._transpose_B: - batch_size, N, K = B.shape - else: - batch_size, K, N = B.shape - self._shape = {'batch_size': batch_size, 'M': M, 'K': K, 'N': N} diff --git 
a/sparta/specializer/operators/sparse_softmax.py b/sparta/specializer/operators/sparse_softmax.py deleted file mode 100644 index 6b650e6a..00000000 --- a/sparta/specializer/operators/sparse_softmax.py +++ /dev/null @@ -1,78 +0,0 @@ - -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT license. - -import torch -import numpy as np - -from sparta.specializer.operators import OperatorBase -from sparta.specializer.functional import SparseBatchSoftmax - - -class SparseSoftmax(OperatorBase): - r"""The sparse softmax operator. - - .. math:: - \text{Softmax}(x_{i}, T) = \frac{\exp(\frac{x_i}{T})}{\sum_j \exp(\frac{x_j}{T})} - - Args: - mask (torch.Tensor): The mask tensor marking all positions to be calculated. - temperature (float): The hyper parameter :math:`T` to control the smoothness of the results. - compressed (bool): Determines whether input and output tensors are compressed to BCSR format. - - Shape: - - Input: :math:`(*, H, W)` where `*` means any number of additional dimensions. - - Output: :math:`(*, H, W)`, same shape as the input. - - Examples: - - .. code-block:: python - - B, H, W = 4, 1024, 1024 - - # Create a mask - mask = sparta.testing.block_mask((H, W), sparsity=0.99) - - # Create a sparse softmax operator using the mask - sparse_softmax = sparta.nn.SparseSoftmax(mask=mask) - - # Tune the sparse softmax operator - sparta.nn.tune(sparse_softmax, sample_inputs=[ - torch.rand((B, H, W), device='cuda'), - ]) - - """ - - __sparse_func__ = SparseBatchSoftmax - - def __init__(self, mask: torch.Tensor, temperature: float = 1, compressed: bool = False): - super().__init__() - H, W = mask.shape - self._sparse_ctx = SparseBatchSoftmaxCtx(compressed, temperature) - self._shape = {'H': H, 'W': W} - self.update_mask(mask) - - def update_mask(self, mask: torch.Tensor): - H, W = self._shape['H'], self._shape['W'] - assert mask.shape == (H, W), f'expected mask shape ({H}, {W}), got {mask.shape}' - self._set_mask({'y': mask}) - - def set_temperature(self, temperature): - self._sparse_ctx.set_temperature(temperature) - - def _read_sample_inputs(self, x: torch.Tensor): - H, W = self._shape['H'], self._shape['W'] - if len(x.shape) == 2: - assert (H, W) == x.shape, f'expect input shape ({H}, {W}), got {x.shape}' - self._shape['batch_size'] = 1 - self._sparse_forward = self._sparse_forward_squeeze - self._dense_forward = self._dense_forward_squeeze - else: - assert (H, W) == x.shape[-2:], f'expect input shape (?, {H}, {W}), got {x.shape}' - self._shape['batch_size'] = int(np.prod(x.shape[:-2])) - - def _sparse_forward_squeeze(self, x: torch.Tensor): - return self.__sparse_func__.apply(self._sparse_ctx, x).squeeze(0) - - def _dense_forward_squeeze(self, *args): - return self._sparse_ctx.dense_forward(*args).squeeze(0) diff --git a/test/lut_maker/matmul.py b/test/lut_maker/matmul.py index ea3382c8..3b18e478 100644 --- a/test/lut_maker/matmul.py +++ b/test/lut_maker/matmul.py @@ -9,7 +9,7 @@ import torch import pandas as pd -from sparta.specializer.functional.batch_matmul import SparseBatchMatMulForward +from sparta.operators.sparse_matmul import SparseBatchMatMulForward from sparta.testing import block_mask @@ -45,12 +45,12 @@ def test_matmul_kernel( impl: str, - func: SparseBatchMatMulForward, + operator: SparseBatchMatMulForward, params: Dict[str, Any], ): try: - func.build(config={'forward': {'_impl': impl, **params}}) - latency = func.profile_kernel('forward', num_warmups=10, num_iters=10, cuda=False) + operator.build(config={'forward': {'_impl': impl, **params}}) + latency = 
operator.profile_kernel('forward', num_warmups=10, num_iters=10, cuda=False) except: latency = float('inf') @@ -61,7 +61,6 @@ def make_matmul_lut(impl: str): major, minor = torch.cuda.get_device_capability() lut_file = os.path.join( 'sparta', - 'specializer', 'kernels', 'look_up_tables', f'matmul.{impl}.{major}{minor}.csv' @@ -101,17 +100,17 @@ def make_matmul_lut(impl: str): iters = 0 for specs in itertools.product(*spec_values): mode, trans_A, trans_B = specs - func = SparseBatchMatMulForward(mode, trans_A, trans_B, False, True) - func.get_sparse_attr().set_mask(mask) - func.reference_forward([A, B]) + operator = SparseBatchMatMulForward(mode, trans_A, trans_B, False, True) + operator.set_mask(mask) + operator.reference(A, B) for params in itertools.product(*param_values): param_dict = {k: v for k, v in zip(param_keys, params)} - latency = test_matmul_kernel(impl, func, param_dict) + latency = test_matmul_kernel(impl, operator, param_dict) with open(log_file, 'a') as f: items = [mode, trans_A, trans_B, *params, latency] f.write(','.join([str(x) for x in items]) + '\n') iters += 1 - _logger.info(f'[{str(iters).zfill(bits)} / {num}] {params} => {latency} ms') + _logger.info(f'[{str(iters).rjust(bits)} / {num}] {params} => {latency} ms') df = pd.read_csv(log_file) df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] diff --git a/test/lut_maker/softmax.py b/test/lut_maker/softmax.py index 38360178..26d941fb 100644 --- a/test/lut_maker/softmax.py +++ b/test/lut_maker/softmax.py @@ -10,7 +10,7 @@ import numpy as np import pandas as pd -from sparta.specializer.functional import SparseBatchSoftmax +from sparta.operators.sparse_softmax import SparseBatchSoftmax from sparta.testing import block_mask @@ -31,13 +31,13 @@ def test_softmax_kernel( impl: str, - func: SparseBatchSoftmax, + operator: SparseBatchSoftmax, direction: str, params: Dict[str, Any], ): try: - func.build(config={direction: {'_impl': impl, **params}}) - latency = func.profile_kernel(direction, num_warmups=10, num_iters=10, cuda=False) + operator.build(config={direction: {'_impl': impl, **params}}) + latency = operator.profile_kernel(direction, num_warmups=10, num_iters=10, cuda=False) except: latency = float('inf') @@ -48,7 +48,6 @@ def make_softmax_lut(impl: str, direction: str): major, minor = torch.cuda.get_device_capability() lut_file = os.path.join( impl, - 'specializer', 'kernels', 'look_up_tables', f'softmax.{direction}.{impl}.{major}{minor}.csv' @@ -78,19 +77,18 @@ def make_softmax_lut(impl: str, direction: str): grad_y = torch.rand(size=(1, SIZE, SIZE), device='cuda') mask = block_mask((SIZE, SIZE), sparsity=0, device='cuda') - func = SparseBatchSoftmax(compressed=True, temperature=np.float32(1 / np.sqrt(SIZE))) - func.get_sparse_attr().set_mask(mask) - func.reference_forward([x]) - func.reference_backward([grad_y]) + operator = SparseBatchSoftmax(compressed=True, temperature=np.float32(1 / np.sqrt(SIZE))) + operator.get_sparse_attr().set_mask(mask) + operator.reference(x).backward(grad_y) iters = 0 for params in itertools.product(*values): param_dict = {k: v for k, v in zip(keys, params)} - latency = test_softmax_kernel(impl, func, direction, param_dict) + latency = test_softmax_kernel(impl, operator, direction, param_dict) with open(log_file, 'a') as f: f.write(','.join([str(x) for x in params]) + f',{latency}\n') iters += 1 - _logger.info(f'[{str(iters).zfill(bits)} / {num}] {params} => {latency} ms') + _logger.info(f'[{str(iters).rjust(bits)} / {num}] {params} => {latency} ms') df = 
pd.read_csv(log_file) df = df.loc[df.groupby(HYPER_PARAMS).aggregate({'latency': 'idxmin'})['latency']] diff --git a/test/unit/test_sparse_matmul.py b/test/unit/test_sparse_matmul.py index 99db55c2..b36b04ab 100644 --- a/test/unit/test_sparse_matmul.py +++ b/test/unit/test_sparse_matmul.py @@ -6,9 +6,8 @@ import torch import pytest -from sparta.specializer.kernels import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from sparta.specializer.functional import SparsityAttr, SparseMatMul, SparseBatchMatMul -# from sparta.nn import SparseBatchMatMul, SparseLinear +from sparta.kernels import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel +from sparta.operators import SparsityAttr, SparseLinear, SparseMatMul, SparseBatchMatMul from sparta.tesa import BCSIndexes from sparta.testing import block_mask @@ -25,6 +24,7 @@ def prepare_data( trans_B: bool = False, biased: bool = False, requires_grad: bool = False, + mask: Optional[torch.Tensor] = None, random_seed: int = 2022, ): inputs = ['A', 'B'] @@ -49,25 +49,44 @@ def prepare_data( data[f'input_grad_{y}'] = torch.rand(size=shape, device='cuda') sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - mask = block_mask(shapes[sparse_port], block=granularity, sparsity=sparsity, device='cuda') + if mask is None: + mask = block_mask( + shape=shapes[sparse_port], + block=granularity, + sparsity=sparsity, + device='cuda', + ) add_mask(data, mask, sparse_port, 'input') + calc_target_data(data, requires_grad, trans_A, trans_B) + add_mask(data, mask, sparse_port, 'target') + + return data, mask + + +def calc_target_data( + data: Dict[str, torch.Tensor], + requires_grad: bool, + trans_A: bool, + trans_B: bool, +): if requires_grad: - for x in inputs: - data[f'input_{x}'].requires_grad = True + for k, v in data.items(): + if k.startswith('input'): + v.requires_grad = True - if batch is None: - input_A = data['input_A'].T if trans_A else data['input_A'] - input_B = data['input_B'].T if trans_B else data['input_B'] - data['target_C'] = torch.mm(input_A, input_B) - if biased: - data['target_C'] += data['input_bias'] - else: + if len(data['input_A'].shape) == 3: input_A = data['input_A'].swapaxes(1, 2) if trans_A else data['input_A'] input_B = data['input_B'].swapaxes(1, 2) if trans_B else data['input_B'] data['target_C'] = torch.bmm(input_A, input_B) - if biased: + if 'input_bias' in data: data['target_C'] += data['input_bias'].unsqueeze(1) + else: + input_A = data['input_A'].T if trans_A else data['input_A'] + input_B = data['input_B'].T if trans_B else data['input_B'] + data['target_C'] = torch.mm(input_A, input_B) + if 'input_bias' in data: + data['target_C'] += data['input_bias'] if requires_grad: data['target_C'].backward(data['input_grad_C']) @@ -75,18 +94,14 @@ def prepare_data( data['input_A'].grad = None data['target_grad_B'] = data['input_B'].grad data['input_B'].grad = None - if biased: + if 'input_bias' in data: data['target_grad_bias'] = data['input_bias'].grad data['input_bias'].grad = None - add_mask(data, mask, sparse_port, 'target') - - return data, mask - def add_mask( data: Dict[str, torch.Tensor], - mask: torch.Tensor, + mask: torch.Tensor, sparse_port: str, stage: str, ): @@ -126,7 +141,8 @@ def compress_data( def check_results(data: Dict[str, torch.Tensor]): for name, val in data.items(): if name.startswith('target_'): - torch.testing.assert_close(val, data[name.replace('target', 'output')], msg=name) + out = data[name.replace('target', 'output')] + torch.testing.assert_close(out, val, 
atol=1e-4, rtol=1e-4, msg=name) @pytest.mark.parametrize("impl", ['sparta', 'openai']) @@ -150,7 +166,12 @@ def test_sparse_matmul_kernel( granularity: Tuple[int, int] = (8, 8), sparsity: float = 0.9, ): - data, mask = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, False) + data, mask = prepare_data( + batch, M, K, N, + granularity, sparsity, + mode, trans_A, trans_B, biased, + False, + ) kernelClass: Type[SparseMatMulKernel] = { 'sparta': SparTASparseMatMulKernel, @@ -208,7 +229,7 @@ def test_sparse_matmul_kernel( @pytest.mark.parametrize("trans_B", [False, True]) @pytest.mark.parametrize("compressed", [False, True]) @pytest.mark.parametrize("batch", [None, 4]) -def test_sparse_matmul_function( +def test_sparse_matmul_operator( mode: str, biased: bool, compressed: bool, @@ -221,144 +242,155 @@ def test_sparse_matmul_function( granularity: Tuple[int, int] = (8, 8), sparsity: float = 0.9, ): - data, mask = prepare_data(batch, M, K, N, granularity, sparsity, mode, trans_A, trans_B, biased, True) + data, mask = prepare_data( + batch, M, K, N, + granularity, sparsity, + mode, trans_A, trans_B, biased, + True, + ) if batch is None: - func = SparseMatMul(mode, trans_A, trans_B, biased, compressed) + sparse_matmul = SparseMatMul(mode, trans_A, trans_B, biased, compressed) else: - func = SparseBatchMatMul(mode, trans_A, trans_B, biased, compressed) - - sparse_attr = func.get_sparse_attr() - sparse_attr.set_mask(mask) + sparse_matmul = SparseBatchMatMul(mode, trans_A, trans_B, biased, compressed) + sparse_matmul.set_mask(mask) kernel_names = ['forward', 'backward:A', 'backward:B'] inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] - func.build( + sparse_matmul.build( config={kernel_name: get_params('sparta') for kernel_name in kernel_names}, sample_inputs=[data[f'input_{x}'] for x in inputs] ) + sparse_matmul.clear_sample_data() sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - if compressed: - data, mask = compress_data(sparse_attr.indexes, sparse_port, data, mask, True) - data['output_C'] = func(*[data[f'input_{x}'] for x in inputs]) - data['output_C'].backward(data['input_grad_C']) - for x in inputs: - data[f'output_grad_{x}'] = data[f'input_{x}'].grad - add_mask(data, mask, sparse_port, 'output') - check_results(data) + def run_test(): + nonlocal sparse_matmul, data, mask + if compressed: + indexes = sparse_matmul.get_sparse_attr().indexes + data, cmask = compress_data(indexes, sparse_port, data, mask, True) + else: + cmask = mask + data['output_C'] = sparse_matmul(*[data[f'input_{x}'] for x in inputs]) + data['output_C'].backward(data['input_grad_C']) + for x in inputs: + data[f'output_grad_{x}'] = data[f'input_{x}'].grad + add_mask(data, cmask, sparse_port, 'output') + check_results(data) + + run_test() + + # Dynamic mask + data, mask = prepare_data( + batch, M, K, N, + granularity, sparsity, + mode, trans_A, trans_B, biased, + True, random_seed=2023, + ) + sparse_matmul.set_mask(mask) + run_test() + + # Dynamic Dim + if sparse_port == 'A': + N = 1024 + elif sparse_port == 'B': + M = 1024 + elif sparse_port == 'C': + K = 1024 + data, mask = prepare_data( + batch, M, K, N, + granularity, sparsity, + mode, trans_A, trans_B, biased, + True, mask, + ) + run_test() + +@pytest.mark.parametrize('mode', ['sdd', 'dsd', 'dds']) +@pytest.mark.parametrize('biased', [False, True]) +def test_sparse_linear_operator( + mode: str, + biased: bool, + batch: int = 128, + in_dims: int = 256, + out_dims: int = 192, + granularity: Tuple[int, int] = (8, 8), + 
sparsity: float = 0.9, +): + data, mask = prepare_data( + None, batch, in_dims, out_dims, + granularity, sparsity, + mode, False, True, biased, + True, + ) -# @pytest.mark.parametrize("mode", ['sdd', 'dsd', 'dds']) -# @pytest.mark.parametrize("trans_A", [False, True]) -# @pytest.mark.parametrize("trans_B", [False, True]) -# @pytest.mark.parametrize("compressed", [False, True]) -# def test_sparse_matmul_operator( -# mode: str, -# compressed: bool, -# trans_A: bool, -# trans_B: bool, -# batch: int = 4, -# M: int = 128, -# K: int = 256, -# N: int = 192, -# granularity: Tuple[int, int] = (8, 8), -# sparsity: float = 0.9, -# ): -# data, masks = prepare_data( -# batch, M, K, N, granularity, sparsity, -# mode, trans_A, trans_B, False, True, -# ) - -# sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] -# sparse_matmul = SparseBatchMatMul( -# **{f'{name}_mask': val for name, val in masks.items()}, -# transpose_A=trans_A, -# transpose_B=trans_B, -# compressed=compressed, -# ) -# sparse_matmul.build( -# config={ -# kernel_name: get_params('sparta') -# for kernel_name in sparse_matmul.get_kernel_placeholders(backward=True) -# }, -# sample_inputs=[data['input_A'], data['input_B']], -# ) - -# for random_seed in range(3): # Test dynamic sparse -# if compressed: -# compress_data(sparse_matmul.get_sparse_indexes(sparse_port), sparse_port, data, masks) - -# data['output_C'] = sparse_matmul.forward(data['input_A'], data['input_B']) -# data['output_C'].backward(data['input_grad_C']) -# for x in ['A', 'B']: -# data[f'output_grad_{x}'] = data[f'input_{x}'].grad - -# add_mask(data, masks, sparse_port, 'output') - -# check_results(data) - -# data, masks = prepare_data( -# batch, M, K, N, granularity, sparsity, -# mode, trans_A, trans_B, False, True, random_seed, -# ) -# sparse_matmul.update_mask(**{f'{name}_mask': val for name, val in masks.items()}) - - -# @pytest.mark.parametrize('mode', ['sdd', 'dsd', 'dds']) -# @pytest.mark.parametrize('biased', [False, True]) -# def test_sparse_linear_operator( -# mode: str, -# biased: bool, -# batch: int = 128, -# in_dims: int = 256, -# out_dims: int = 192, -# granularity: Tuple[int, int] = (8, 8), -# sparsity: float = 0.9, -# ): -# data, masks = prepare_data( -# -1, batch, in_dims, out_dims, granularity, sparsity, -# mode, False, True, biased, True, -# ) - -# dense_linear = torch.nn.Linear(in_dims, out_dims, bias=biased, device='cuda') -# if biased: -# dense_linear.load_state_dict({'weight': data['input_B'], 'bias': data['input_bias']}) -# else: -# dense_linear.load_state_dict({'weight': data['input_B']}) - -# sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] -# mask_name = {'sdd': 'input_mask', 'dsd': 'weight_mask', 'dds': 'output_mask'}[mode] -# sparse_linear = SparseLinear(dense_linear, **{mask_name: masks[sparse_port]}) -# sparse_linear.build( -# config={ -# kernel_name: get_params('sparta') -# for kernel_name in sparse_linear.get_kernel_placeholders(backward=True) -# }, -# sample_inputs=[data['input_A']], -# ) - -# for random_seed in range(3): # Test dynamic sparse -# if mode == 'dsd': -# compress_data(sparse_linear.get_sparse_indexes('B'), 'B', data, masks) - -# data['output_C'] = sparse_linear.forward(data['input_A']) -# data['output_C'].backward(data['input_grad_C']) -# data[f'output_grad_A'] = data[f'input_A'].grad -# data[f'output_grad_B'] = sparse_linear.weight.grad -# if biased: -# data[f'output_grad_bias'] = sparse_linear.bias.grad - -# add_mask(data, masks, sparse_port, 'output') - -# check_results(data) - -# data, masks = prepare_data( -# -1, 
batch, in_dims, out_dims, granularity, sparsity,
-#         mode, False, True, biased, True, random_seed,
-#     )
-#     if biased:
-#         sparse_linear.bias = torch.nn.Parameter(data['input_bias'])
-#     sparse_linear._raw_weight = data['input_B']
-#     sparse_linear.update_mask(**{mask_name: masks[sparse_port]})
+    dense_linear = torch.nn.Linear(in_dims, out_dims, bias=biased, device='cuda')
+    if biased:
+        dense_linear.load_state_dict({'weight': data['input_B'], 'bias': data['input_bias']})
+    else:
+        dense_linear.load_state_dict({'weight': data['input_B']})
+
+    sparse_linear = SparseLinear(dense_linear, mode)
+    sparse_linear.set_mask(mask)
+
+    kernel_names = ['forward', 'backward:A', 'backward:B']
+    sparse_linear.build(
+        config={kernel_name: get_params('sparta') for kernel_name in kernel_names},
+        sample_inputs=[data['input_A']],
+    )
+    sparse_linear.clear_sample_data()
+
+    sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode]
+
+    def run_test():
+        nonlocal sparse_linear, data, mask
+        if mode == 'dsd':
+            indexes = sparse_linear.get_sparse_attr().indexes
+            data, cmask = compress_data(indexes, sparse_port, data, mask, True)
+        else:
+            cmask = mask
+        data['output_C'] = sparse_linear(data['input_A'])
+        data['output_C'].backward(data['input_grad_C'])
+        data[f'output_grad_A'] = data[f'input_A'].grad
+        data[f'output_grad_B'] = sparse_linear.weight.grad
+        if biased:
+            data[f'output_grad_bias'] = sparse_linear.bias.grad
+        add_mask(data, cmask, sparse_port, 'output')
+        check_results(data)
+
+    run_test()
+
+    # Dynamic mask
+    data, mask = prepare_data(
+        None, batch, in_dims, out_dims,
+        granularity, sparsity,
+        mode, False, True, biased,
+        True, random_seed=2023
+    )
+    sparse_linear.ports['B'].set_data(data['input_B'])
+    if biased:
+        sparse_linear.bias = torch.nn.Parameter(data['input_bias'])
+    sparse_linear.set_mask(mask)
+    run_test()
+
+    # Dynamic dim: resize the dimension that the sparse mask does not cover
+    if sparse_port == 'A':
+        out_dims = 1024
+    elif sparse_port == 'B':
+        batch = 1024
+    elif sparse_port == 'C':
+        in_dims = 1024
+    data, mask = prepare_data(
+        None, batch, in_dims, out_dims,
+        granularity, sparsity,
+        mode, False, True, biased,
+        True, mask,
+    )
+    weight = data['input_B']
+    if mode == 'dsd':
+        weight = sparse_linear.get_sparse_attr().indexes.convert(weight.detach())
+    sparse_linear.weight = torch.nn.Parameter(weight)
+    if biased:
+        sparse_linear.bias = torch.nn.Parameter(data['input_bias'])
+    run_test()
diff --git a/test/unit/test_sparse_softmax.py b/test/unit/test_sparse_softmax.py
index 2886c2df..e8b253bb 100644
--- a/test/unit/test_sparse_softmax.py
+++ b/test/unit/test_sparse_softmax.py
@@ -7,9 +7,8 @@
 import pytest
 import numpy as np
 
-from sparta.specializer.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel
-from sparta.specializer.functional import SparsityAttr, SparseSoftmax, SparseBatchSoftmax
-# from sparta.nn import SparseSoftmax
+from sparta.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel
+from sparta.operators import SparsityAttr, SparseSoftmax, SparseBatchSoftmax
 
 from sparta.testing import block_mask, sparse_softmax_forward_reference
 
@@ -20,6 +19,7 @@ def prepare_data(
     granularity: Tuple[int, int] = (8, 8),
     sparsity: float = 0.9,
     requires_grad: bool = False,
+    mask: Optional[torch.Tensor] = None,
     random_seed: int = 2022,
 ):
     torch.manual_seed(random_seed)
@@ -28,7 +28,8 @@ def prepare_data(
     data['input_x'] = torch.rand(shape, device='cuda')
     temperature = np.sqrt(W)
 
-    mask = block_mask((H, W), block=granularity, sparsity=sparsity, device='cuda')
+    if mask is None:
+        mask = block_mask((H, W),
block=granularity, sparsity=sparsity, device='cuda') data['grad_y'] = torch.rand(shape, device='cuda') data['input_x'].requires_grad = True @@ -72,7 +73,7 @@ def test_sparse_softmax_kernels( granularity: Tuple[int, int] = (8, 8), sparsity: float = 0, ): - data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=False) + data, mask = prepare_data(batch, H, W, granularity, sparsity, False) batched = batch is not None forward_kernel = SparTASparseSoftmaxForwardKernel(compressed, batched) @@ -104,7 +105,7 @@ def test_sparse_softmax_kernels( @pytest.mark.parametrize("compressed", [False, True]) @pytest.mark.parametrize("batch", [None, 4]) -def test_sparse_softmax_function( +def test_sparse_softmax_operator( compressed: bool, batch: Optional[int], H: int = 128, @@ -112,66 +113,36 @@ def test_sparse_softmax_function( granularity: Tuple[int, int] = (8, 8), sparsity: float = 0.9, ): - data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=True) + data, mask = prepare_data(batch, H, W, granularity, sparsity, True) if batch is None: - func = SparseSoftmax(compressed, np.sqrt(W)) + sparse_softmax = SparseSoftmax(compressed, np.sqrt(W)) else: - func = SparseBatchSoftmax(compressed, np.sqrt(W)) + sparse_softmax = SparseBatchSoftmax(compressed, np.sqrt(W)) - sparse_attr = func.get_sparse_attr() + sparse_attr = sparse_softmax.get_sparse_attr() sparse_attr.set_mask(mask) kernel_names = ['forward', 'backward'] - func.build( + sparse_softmax.build( config={kernel_name: get_params() for kernel_name in kernel_names}, sample_inputs=[data['input_x']] ) - if compressed: - for name in data: - data[name] = sparse_attr.indexes.convert(data[name].detach()) - data['input_x'].requires_grad = True - - data['output_y'] = func(data['input_x']) - data['output_y'].backward(data['grad_y']) - data['output_grad_x'] = data['input_x'].grad - check_results(data) - - -# @pytest.mark.parametrize("batch", [None, 4]) -# @pytest.mark.parametrize("compressed", [False, True]) -# def test_sparse_softmax_operator( -# compressed: bool, -# batch: Optional[int], -# H: int = 128, -# W: int = 256, -# granularity: Tuple[int, int] = (8, 8), -# sparsity: float = 0.9, -# ): -# data, mask = prepare_data(batch, H, W, granularity, sparsity, requires_grad=True) - -# sparse_softmax = SparseSoftmax(mask, np.sqrt(W), compressed) -# sparse_softmax.build( -# config={ -# kernel_name: get_params() -# for kernel_name in sparse_softmax.get_kernel_placeholders(backward=True) -# }, -# sample_inputs=[data['input_x']], -# ) - -# for random_seed in range(3): # Test dynamic sparse -# if compressed: -# indexes = sparse_softmax.get_sparse_indexes('y') -# for name in data: -# data[name] = indexes.convert(data[name].detach()) -# data['input_x'].requires_grad = True - -# data['output_y'] = sparse_softmax.forward(data['input_x']) -# data['output_y'].backward(data['grad_y']) -# data['output_grad_x'] = data['input_x'].grad - -# check_results(data) - -# data, mask = prepare_data(batch, H, W, granularity, sparsity, True, random_seed) -# sparse_softmax.update_mask(mask) + def run_test(): + nonlocal sparse_softmax, data, mask + if compressed: + for name in data: + data[name] = sparse_attr.indexes.convert(data[name].detach()) + data['input_x'].requires_grad = True + data['output_y'] = sparse_softmax(data['input_x']) + data['output_y'].backward(data['grad_y']) + data['output_grad_x'] = data['input_x'].grad + check_results(data) + + run_test() + + # Dynamic mask + data, mask = prepare_data(batch, H, W, granularity, sparsity, True, 
random_seed=2023) + sparse_softmax.set_mask(mask) + run_test() From 8de6950cd44ebebc48ed27aab1e1d09cf10b643c Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 7 Mar 2023 10:14:53 +0000 Subject: [PATCH 14/28] move sparse attr to kernel level --- sparta/kernels/__init__.py | 2 +- sparta/kernels/kernel_base.py | 150 +++++++++++++- sparta/kernels/matmul.py | 51 ++++- sparta/kernels/softmax.py | 43 +++- sparta/operators/operator_base.py | 217 ++++++------------- sparta/operators/sparse_attention.py | 3 +- sparta/operators/sparse_matmul.py | 297 ++++++++++----------------- sparta/operators/sparse_softmax.py | 134 +++++------- test/unit/test_sparse_matmul.py | 42 +--- test/unit/test_sparse_softmax.py | 24 +-- 10 files changed, 476 insertions(+), 487 deletions(-) diff --git a/sparta/kernels/__init__.py b/sparta/kernels/__init__.py index 482785bb..983d8cdf 100644 --- a/sparta/kernels/__init__.py +++ b/sparta/kernels/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from sparta.kernels.kernel_base import KernelBase +from sparta.kernels.kernel_base import KernelBase, SparsityAttr, KernelGroup from sparta.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel from sparta.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel diff --git a/sparta/kernels/kernel_base.py b/sparta/kernels/kernel_base.py index fca5b89a..a1b6d40a 100644 --- a/sparta/kernels/kernel_base.py +++ b/sparta/kernels/kernel_base.py @@ -9,6 +9,7 @@ from typing import Any, Dict, List, Tuple, Callable, Optional import torch +import numpy as np from sparta import __env_ready__ if __env_ready__: @@ -16,7 +17,9 @@ import pycuda.autoprimaryctx from pycuda.compiler import SourceModule +from sparta.tesa import get_bcs_function, BCSIndexes from sparta.tuning import TunableItemCfg +from sparta.testing import profile @dataclasses.dataclass @@ -32,19 +35,129 @@ def __post_init__(self): assert self.is_tunable +@dataclasses.dataclass +class _KernelAttrEdge: + attr: SparsityAttr + kernel: KernelBase + block_H: str + block_W: str + + def get_block_size(self): + BH = self.kernel.get_parameter(self.block_H) + BW = self.kernel.get_parameter(self.block_W) + return BH, BW + + +class SparsityAttr(object): + + def __init__( + self, + kernel: KernelBase, + block_H: str, + block_W: str, + BCSR: bool, + BCSC: bool, + ): + self.edges = {kernel.id: _KernelAttrEdge(self, kernel, block_H, block_W)} + self.groups: List[KernelGroup] = [] + self.BCSR = BCSR + self.BCSC = BCSC + self.mask: torch.Tensor = None + self._block_H: int = 0 + self._block_W: int = 0 + self.indexes: BCSIndexes = None + self.ready = False + + def update_block_size(self, kernel_id: int): + block_H, block_W = self.edges[kernel_id].get_block_size() + if block_H != self._block_H or block_W != self._block_W: + self._block_H = block_H + self._block_W = block_W + self._update_indexes() + + def set_mask(self, mask: torch.Tensor): + self.mask = mask + self._update_indexes() + + def _update_indexes(self): + if self._block_H > 0 and self._block_W > 0 and self.mask is not None: + self.indexes = get_bcs_function( + self._block_H, self._block_W, + self.BCSR, self.BCSC, + ).build_indexes(self.mask) + self.ready = True + + def connect(self, other: SparsityAttr): + self.BCSR |= other.BCSR + self.BCSC |= other.BCSC + for kernel_id, kernel_edge in other.edges.items(): + self.edges[kernel_id] = kernel_edge + 
kernel_edge.kernel.attr = self
+        for kernel_group in other.groups:
+            self.groups.append(kernel_group)
+            kernel_group.attr = self
+
+    def get_search_space(self):
+        pass
+
+
+class KernelGroup(object):
+
+    def __init__(
+        self,
+        kernels: Dict[str, KernelBase],
+        input_getter: Callable[[], List[torch.Tensor]],
+    ):
+        self._kernels = kernels
+        self._get_inputs = input_getter
+        kernel_list = list(kernels.values())
+        self.attr = kernel_list[0].attr
+        self.attr.groups.append(self)
+        for kernel in kernel_list[1:]:
+            self.attr.connect(kernel.attr)
+        self.active_kernel: Callable = kernel_list[0].reference
+        self.ready = False
+
+    def set_sample_shape(self, shape: Tuple):
+        self._shape = shape
+
+    def build(self, params: Dict[str, Any]):
+        self.active_kernel = self._kernels[params['_impl']]
+        self.active_kernel.compile(params, self._shape)
+        self.ready = True
+
+    def get_search_space(self):
+        return {
+            impl: kernel.get_search_space()
+            for impl, kernel in self._kernels.items()
+        }
+
+
+_kernel_num = 0
+def _next_kernel_id():
+    global _kernel_num
+    _kernel_num += 1
+    return _kernel_num
+
+
 class KernelBase(Callable):
 
+    __lut_shape__: Tuple = ()
+
     def __init__(self):
+        self.id = _next_kernel_id()
+        self.attr: SparsityAttr = None
         self._parameters: Dict[str, _Parameter] = {}
         self._kernel: Callable = None
-        self._func: Callable = None
+        self._func: Callable = self.reference
         self.ready = False
         self._add_parameters()
-        self.estimated_latency_per_flop = float('inf')
+        self._lut_latency = float('inf')
+        self.estimate_latency = float('inf')
 
     @abc.abstractmethod
     def _add_parameters(self):
-        """Add kernel-specialized parameters."""
+        """Add kernel-specialized parameters and set sparsity attribute."""
 
     def _add_parameter(
         self,
@@ -110,12 +223,13 @@ def _check_parameters(self, params: Dict[str, Any]):
         """Raise an error if the input parameter dict is invalid."""
 
     @abc.abstractmethod
-    def set_kernel_call(self, shape: Tuple, sparse_attr: Any):
+    def set_kernel_call(self, shape: Tuple):
         """Convert pycuda kernel (self._kernel) to python function call (self._func)."""
 
-    def compile(self, params: Dict[str, Any], shape: Any, sparse_attr: Any):
+    def compile(self, params: Dict[str, Any], shape: Tuple):
         self._check_parameters(params)
         self.set_parameters(params)
+        self.attr.update_block_size(self.id)
         kernel_code = self.get_kernel_code()
         kernel_name = kernel_code[kernel_code.find('__global__ void') + 15:]
         kernel_name = kernel_name[:kernel_name.find('(')].strip()
@@ -123,11 +237,27 @@
         warnings.simplefilter('ignore')
         source_module = SourceModule(kernel_code, options=['-O3'])
         self._kernel = source_module.get_function(kernel_name)
-        self.set_kernel_call(shape, sparse_attr)
+        self.set_kernel_call(shape)
         self.ready = True
+
+        # Calculate estimated latency
+        indexes = self.attr.indexes
+        sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num
+        shape_rate = np.prod(shape) / np.prod(self.__lut_shape__)
+        self.estimate_latency = self._lut_latency * shape_rate * sparse_rate
+
+    @abc.abstractmethod
+    def reference(self, *inputs, sparse: bool = False):
+        """Reference forward function."""
+
+    def profile(
+        self,
+        inputs: List[torch.Tensor],
+        num_warmups: int = 20,
+        num_iters: int = 100,
+        cuda: bool = False
+    ):
+        target_output = self.reference(*inputs, sparse=True)
+        return profile(self, inputs, [target_output], num_warmups, num_iters, cuda)
 
     def __call__(self, *args) -> torch.Tensor:
-        if self.ready:
-            return self._func(*args)
-        else:
-            raise
ValueError('The kernel is not compiled.') + return self._func(*args) diff --git a/sparta/kernels/matmul.py b/sparta/kernels/matmul.py index e3236178..940664ad 100644 --- a/sparta/kernels/matmul.py +++ b/sparta/kernels/matmul.py @@ -4,7 +4,7 @@ import io import textwrap import importlib.resources as res -from typing import Any, Dict, Tuple +from typing import Any, Dict, Tuple, Optional import torch import jinja2 @@ -12,7 +12,7 @@ import pandas as pd from sparta.tuning import TunableItemCfg -from sparta.kernels import KernelBase, templates, look_up_tables +from sparta.kernels import KernelBase, SparsityAttr, templates, look_up_tables def _get_matmul_lut(impl: str): @@ -34,6 +34,7 @@ def _get_matmul_lut(impl: str): class SparseMatMulKernel(KernelBase): + __lut_shape__ = (4096, 4096, 4096) __algo__: str = '' def __init__( @@ -60,6 +61,17 @@ def __init__( trans_B_filter = lut['trans_B'] == self._transpose_B self._lut = lut[mode_filter & trans_A_filter & trans_B_filter] + self._sparse_axis = { + 'sdd': ['K', 'M'] if transpose_A else ['M', 'K'], + 'dsd': ['N', 'K'] if transpose_B else ['K', 'N'], + 'dds': ['M', 'N'], + }[mode] + self._BCSR = { + 'sdd': not transpose_A, + 'dsd': transpose_B, + 'dds': True, + }[mode] + super().__init__() def _add_parameters(self): @@ -71,6 +83,8 @@ def _add_parameters(self): self._add_parameter('COMPRESSED', value=self._compressed) self._add_parameter('BCSR') self._add_parameter('BCSC') + block_H, block_W = [f'BLOCK_SIZE_{axis}_VALUE' for axis in self._sparse_axis] + self.attr = SparsityAttr(self, block_H, block_W, self._BCSR, not self._BCSR) def get_block_shape(self): BM = self.get_parameter('BLOCK_SIZE_M_VALUE') @@ -78,7 +92,7 @@ def get_block_shape(self): BN = self.get_parameter('BLOCK_SIZE_N_VALUE') return BM, BK, BN - def set_kernel_call(self, shape: Tuple[int, int, int, int], sparse_attr: Any): + def set_kernel_call(self, shape: Tuple[int, int, int, int]): batch, M, K, N = shape M_32, K_32, N_32 = np.int32(M), np.int32(K), np.int32(N) BM, BK, BN = self.get_block_shape() @@ -87,6 +101,7 @@ def set_kernel_call(self, shape: Tuple[int, int, int, int], sparse_attr: Any): raw_func = self._kernel zeros = torch.zeros int32 = np.int32 + sparse_attr = self.attr func_code = jinja2.Template(textwrap.dedent(''' def matmul_func(A, B{% if BIASED %}, bias{% endif %}): @@ -131,10 +146,36 @@ def matmul_func(A, B{% if BIASED %}, bias{% endif %}): self._func = locals()['matmul_func'] def get_kernel_code(self): + self.set_parameter('BCSR', self._BCSR or self.attr.BCSR) + self.set_parameter('BCSC', not self._BCSR) template_file = f'{self.__algo__}_sparse_matmul_{self._mode}.cuh.j2' kernel_template = res.read_text(templates, template_file) return jinja2.Template(kernel_template).render(self.get_parameters()) + def reference( + self, + A: torch.Tensor, + B: torch.Tensor, + bias: Optional[torch.Tensor] = None, + sparse: bool = False, + ): + if sparse and self._compressed: + if self._mode == 'sdd': + A = self.attr.indexes.inverse(A) + elif self._mode == 'dsd': + B = self.attr.indexes.inverse(B) + if self._transpose_A: + A = A.swapaxes(self._batched + 0, self._batched + 1) + if self._transpose_B: + B = B.swapaxes(self._batched + 0, self._batched + 1) + C = torch.bmm(A, B) if self._batched else torch.mm(A, B) + if bias is not None: + C += bias.unsqueeze(1) if self._batched else bias + if sparse and self._compressed: + if self._mode == 'dds': + C = self.attr.indexes.convert(C) + return C + class SparTASparseMatMulKernel(SparseMatMulKernel): @@ -201,7 +242,7 @@ def _check_parameters(self, 
params: Dict[str, Any]): self.set_parameter('THREAD_SIZE_M_VALUE', int(TM)) self.set_parameter('THREAD_SIZE_K_VALUE', int(TK)) self.set_parameter('THREAD_SIZE_N_VALUE', int(TN)) - self.estimated_latency_per_flop = row['latency'] / 4096 / 4096 / 4096 + self._lut_latency = row['latency'] class OpenAISparseMatMulKernel(SparseMatMulKernel): @@ -229,4 +270,4 @@ def _check_parameters(self, params: Dict[str, Any]): if 'BLOCK_SIZE_N_VALUE' in params: assert params['BLOCK_SIZE_N_VALUE'] == 32 row = self._lut.reset_index(drop=True).iloc[0, :] - self.estimated_latency_per_flop = row['latency'] / 32 / 64 / 32 + self._lut_latency = row['latency'] diff --git a/sparta/kernels/softmax.py b/sparta/kernels/softmax.py index 6f1b469d..7bfc549c 100644 --- a/sparta/kernels/softmax.py +++ b/sparta/kernels/softmax.py @@ -12,7 +12,8 @@ import pandas as pd from sparta.tuning import TunableItemCfg -from sparta.kernels import KernelBase, templates, look_up_tables +from sparta.kernels import KernelBase, SparsityAttr, templates, look_up_tables +from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference def _get_softmax_lut(impl: str, direction: str): @@ -36,6 +37,7 @@ def _get_softmax_lut(impl: str, direction: str): class SparseSoftmaxKernel(KernelBase): + __lut_shape__ = (4096, 4096) __algo__: str = '' __direction__: str = '' @@ -60,11 +62,41 @@ class SparseSoftmaxForwardKernel(SparseSoftmaxKernel): __direction__ = 'forward' + def reference( + self, + x: torch.Tensor, + mask: torch.Tensor, + T: np.float32, + sparse: bool = False, + ): + if sparse and self._compressed: + x = self.attr.indexes.inverse(x) + y = sparse_softmax_forward_reference(x, mask, T) + if sparse and self._compressed: + y = self.attr.indexes.convert(y) + return y + class SparseSoftmaxBackwardKernel(SparseSoftmaxKernel): __direction__ = 'backward' + def reference( + self, + grad_y: torch.Tensor, + y: torch.Tensor, + mask: torch.Tensor, + T: np.float32, + sparse: bool = False, + ): + if sparse and self._compressed: + grad_y = self.attr.indexes.inverse(grad_y) + y = self.attr.indexes.inverse(y) + grad_x = sparse_softmax_backward_reference(grad_y, y, mask, T) + if sparse and self._compressed: + grad_x = self.attr.indexes.convert(grad_x) + return grad_x + class SparTASoftmaxKernel(SparseSoftmaxKernel): @@ -88,6 +120,7 @@ def _add_parameters(self): ) self._add_parameter('ROW_TILE_VALUE') self._add_parameter('MAX_W_VALUE', value=1024) + self.attr = SparsityAttr(self, 'BLOCK_SIZE_H_VALUE', 'BLOCK_SIZE_W_VALUE', True, False) def threads_per_block(self) -> Tuple[int]: BW = self.get_parameter('BLOCK_SIZE_W_VALUE') @@ -111,7 +144,7 @@ def _check_parameters(self, params: Dict[str, Any]): row = row.reset_index(drop=True).iloc[0, :] assert float(row['latency']) < float('inf'), f'block shape ({BH}, {BW}) is invalid' self.set_parameter('ROW_TILE_VALUE', int(row['RT'])) - self.estimated_latency_per_flop = row['latency'] / BH / BW + self._lut_latency = row['latency'] def get_kernel_code(self): template_file = f'{self.__algo__}_sparse_softmax_{self.__direction__}.cuh.j2' @@ -121,7 +154,7 @@ def get_kernel_code(self): class SparTASparseSoftmaxForwardKernel(SparseSoftmaxForwardKernel, SparTASoftmaxKernel): - def set_kernel_call(self, shape: Tuple[int, int, int], sparse_attr: Any): + def set_kernel_call(self, shape: Tuple[int, int, int]): batch, H, W = shape H_32, W_32 = np.int32(H), np.int32(W) BH, BW = self.get_block_shape() @@ -130,6 +163,7 @@ def set_kernel_call(self, shape: Tuple[int, int, int], sparse_attr: Any): row_num = H // RT 
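# A condensed sketch of the LUT-based latency model that KernelBase.compile
# implements above: `_lut_latency` is the latency measured at the kernel's
# reference shape `__lut_shape__`, scaled by the relative problem size and by
# the block-level density of the mask. The helper below is illustrative only
# (not part of SparTA); it assumes `indexes` is a built BCSIndexes exposing
# `block_nnz`, `row_num` and `col_num`:
#
#     import numpy as np
#
#     def estimate_latency(lut_latency, lut_shape, shape, indexes):
#         # Fraction of mask blocks that are non-zero.
#         sparse_rate = indexes.block_nnz / (indexes.row_num * indexes.col_num)
#         # Linear scaling from the LUT reference shape to the actual shape.
#         shape_rate = np.prod(shape) / np.prod(lut_shape)
#         return lut_latency * shape_rate * sparse_rate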
raw_func = self._kernel zeros = torch.zeros + sparse_attr = self.attr func_code = jinja2.Template(textwrap.dedent(''' def softmax_forward_func(x, mask, T): @@ -163,7 +197,7 @@ def softmax_forward_func(x, mask, T): class SparTASparseSoftmaxBackwardKernel(SparseSoftmaxBackwardKernel, SparTASoftmaxKernel): - def set_kernel_call(self, shape: Tuple[int, int, int], sparse_attr: Any): + def set_kernel_call(self, shape: Tuple[int, int, int]): batch, H, W = shape H_32, W_32 = np.int32(H), np.int32(W) BH, BW = self.get_block_shape() @@ -172,6 +206,7 @@ def set_kernel_call(self, shape: Tuple[int, int, int], sparse_attr: Any): row_num = H // RT raw_func = self._kernel zeros = torch.zeros + sparse_attr = self.attr func_code = jinja2.Template(textwrap.dedent(''' def softmax_backward_func(gy, y, mask, T): diff --git a/sparta/operators/operator_base.py b/sparta/operators/operator_base.py index 3fc4f9b9..5850261e 100644 --- a/sparta/operators/operator_base.py +++ b/sparta/operators/operator_base.py @@ -8,180 +8,103 @@ import torch -from sparta.tesa import get_bcs_function, BCSIndexes -from sparta.kernels import KernelBase -from sparta.testing import profile - - -class SparsityAttr(object): - - def __init__( - self, - operator: SparseOperator, - param_map: Dict[str, str], - BCSR: bool, - BCSC: bool, - ): - self.BCSR = BCSR - self.BCSC = BCSC - self.mask: torch.Tensor = None - self._block_H: int = 0 - self._block_W: int = 0 - self.indexes: BCSIndexes = None - - def update_axis(self, BCSR: bool, BCSC: bool): - self.BCSR |= BCSR - self.BCSC |= BCSC - - def set_block_size(self, block_H: int, block_W: int): - if block_H != self._block_H or block_W != self._block_W: - self._block_H = block_H - self._block_W = block_W - self._update_indexes() - - def set_mask(self, mask: torch.Tensor): - self.mask = mask - self._update_indexes() - - def _update_indexes(self): - if self._block_H > 0 and self._block_W > 0 and self.mask is not None: - self.indexes = get_bcs_function( - self._block_H, self._block_W, - self.BCSR, self.BCSC, - ).build_indexes(self.mask) +from sparta.kernels import KernelBase, SparsityAttr, KernelGroup class Port(object): - def __init__(self, func: SparseOperator, name: str, fine_mask: bool = True): + def __init__(self, operator: SparseOperator, name: str): self.name = name - self.funcs: List[SparseOperator] = [func] - self.attr: SparsityAttr = None - self._sample_data: torch.Tensor = None # Always dense - self._fine_mask = fine_mask - - def set_data(self, data: torch.Tensor, grad: bool = False): - if grad: - self._sample_data.grad = data - else: - self._sample_data = data - - def get_data(self, grad: bool = False, compressed: bool = False): - data: torch.Tensor = self._sample_data.grad if grad else self._sample_data - if self.attr is not None and data is not None: - if self._fine_mask: - data = data * self.attr.mask - if compressed: - data = self.attr.indexes.convert(data.detach()) - elif not self._fine_mask: - data = data * self.attr.indexes.get_mask() + self.ops: List[SparseOperator] = [operator] + self.sample_data: torch.Tensor = None # Dense, not masked + self.attr: Optional[SparsityAttr] = None + self.compressed: bool = False + + def get_sample_data(self, grad: bool = False): + data = self.sample_data.grad if grad else self.sample_data + data = data.detach() + if self.attr is not None: + data = data * self.attr.mask + if self.compressed and self.attr.ready: + data = self.attr.indexes.convert(data) return data def clear_data(self): - self._sample_data = None + self.sample_data = None def 
connect(self, other: Port): - for func in other.funcs: - func.ports[other.name] = self - self.funcs.append(func) + for operator in other.ops: + operator.ports[other.name] = self + self.ops.append(operator) if self.attr is None: self.attr = other.attr elif other.attr is not None: - self.attr.update_axis(other.attr.BCSR, other.attr.BCSC) + self.attr.connect(other.attr) + # TODO: compressed class SparseOperator(torch.nn.Module): def __init__(self): super().__init__() - self._kernels: Dict[str, Dict[str, KernelBase]] = {} - self._compiled_kernels: Dict[str, KernelBase] = {} + self.kernel_groups: Dict[str, KernelGroup] = {} self.ports: Dict[str, Port] = {} - self._sparse_port: str = '' - self.forward_func = self.reference - self.shape: Tuple = None + self._attr: Optional[SparsityAttr] = None + self.forward_func: Callable = None - def get_sparse_attr(self): - return self.ports[self._sparse_port].attr + def _set_kernel_group(self, kernel_name: str, kernels: Dict[str, KernelBase]): + input_getter = lambda : self._get_sample_inputs(kernel_name) + self.kernel_groups[kernel_name] = KernelGroup(kernels, input_getter) + + @abc.abstractmethod + def _get_sample_inputs(self, kernel_name: str) -> List[torch.Tensor]: + """Get sample inputs from ports for specified kernel.""" + + def get_sparse_indexes(self): + assert self._attr is not None + return self._attr.indexes def set_mask(self, mask: torch.Tensor): - self.get_sparse_attr().set_mask(mask) + if self._attr is None: + for kernel_group in self.kernel_groups.values(): + kernel_group.attr.set_mask(mask) + else: + self._attr.set_mask(mask) def forward(self, *inputs): return self.forward_func(*inputs) @abc.abstractmethod def _set_forward(self): - """Build forward function with compiled kernels.""" + """Build forward function with compiled kernels. 
Record sample data on ports when falling back to the reference kernel."""
 
     @abc.abstractmethod
-    def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]):
+    def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]) -> Tuple:
         """Get shape parameters from sample inputs."""
 
     @abc.abstractmethod
-    def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]):
-        """Compile kernel with params, shapes and sparse indexes."""
+    def _set_sample_shape(self, sample_shape: Tuple):
+        """Set the sample shape on the kernel groups."""
 
     def build(
         self,
-        config: Dict[str, Dict[str, Any]],
+        config: Optional[Dict[str, Dict[str, Any]]] = None,
         sample_inputs: Optional[List[torch.Tensor]] = None,
     ):
         if sample_inputs is not None:
-            self._read_sample_inputs(sample_inputs)
-        self._compiled_kernels: Dict[str, KernelBase] = {}
-        for kernel_name, params in config.items():
-            if kernel_name in self._kernels:
-                kernel = self._kernels[kernel_name][params['_impl']]
-                self._compile_kernel(kernel_name, kernel, params)
-                self._compiled_kernels[kernel_name] = kernel
+            self._set_sample_shape(self._read_sample_inputs(sample_inputs))
+        if config is not None:
+            for kernel_name, params in config.items():
+                self.kernel_groups[kernel_name].build(params)
+        for kernel_name, kernel_group in self.kernel_groups.items():
+            assert kernel_group.ready, f'{kernel_name} kernel is not built'
+        self._post_build()
+
+    def _post_build(self):
         self._set_forward()
-
-    def clear_sample_data(self):
         for port in self.ports.values():
             port.clear_data()
 
     @abc.abstractmethod
-    def _kernel_reference(self, kernel_name: str) -> torch.Tensor:
-        """Get kernel reference output from related port(s)."""
-
-    @abc.abstractmethod
-    def _kernel_func_call(self, kernel_name: str) -> Callable[[], torch.Tensor]:
-        """Callable kernel function based on sample data of ports."""
-
-    def profile_kernel(
-        self,
-        kernel_name: str,
-        num_warmups: int = 20,
-        num_iters: int = 100,
-        cuda: bool = False,
-    ):
-        """Profile kernel latency. Note that all inputs and outputs are dense tensors here."""
-        kernel_func = self._kernel_func_call(kernel_name)
-        target_output = self._kernel_reference(kernel_name)
-        return profile(kernel_func, [], [target_output], num_warmups, num_iters, cuda)
-
-    @abc.abstractmethod
-    def _calc_kernel_flops(self, kernel_name: str):
-        """Calculate kernel flops using sparse rate and shape."""
-
-    def estimate_kernel(self, kernel_name: str):
-        kernel = self._compiled_kernels[kernel_name]
-        flops = self._calc_kernel_flops(kernel_name)
-        return kernel.estimated_latency_per_flop * flops
-
-    @abc.abstractmethod
-    def reference(self, *inputs):
-        """Reference forward function.
Sync data with ports at the same time.""" - - def get_search_space(self, backward: bool = False): - """Get search space of the sparse context.""" - pass - - def get_connections(self, backward: bool = False): - """Get cross-kernel connected hyper parameters of the sparse context.""" - pass - class SparseAutoGrad(SparseOperator): @@ -189,32 +112,16 @@ class SparseAutoGrad(SparseOperator): def _set_backward(self, backward_op: SparseOperator): self.backward_op = backward_op + backward_op.ports = self.ports + for kernel_name, kernel_group in backward_op.kernel_groups.items(): + self.kernel_groups[kernel_name] = kernel_group + if self._attr is not None: + self._attr.connect(kernel_group.attr) + backward_op._set_forward() def forward(self, *inputs): return self.__static_func__.apply(self, *inputs) - def build( - self, - config: Dict[str, Dict[str, Any]], - sample_inputs: Optional[List[torch.Tensor]] = None, - ): - super().build(config, sample_inputs) - self.backward_op.build(config, sample_inputs) - - def profile_kernel( - self, - kernel_name: str, - num_warmups: int = 20, - num_iters: int = 100, - cuda: bool = False - ): - if kernel_name in self._kernels: - return super().profile_kernel(kernel_name, num_warmups, num_iters, cuda) - else: - return self.backward_op.profile_kernel(kernel_name, num_warmups, num_iters, cuda) - - def estimate_kernel(self, kernel_name: str): - if kernel_name in self._kernels: - return super().estimate_kernel(kernel_name) - else: - return self.backward_op.estimate_kernel(kernel_name) + def _post_build(self): + self.backward_op._set_forward() + super()._post_build() diff --git a/sparta/operators/sparse_attention.py b/sparta/operators/sparse_attention.py index 1de71db4..9269a0d6 100644 --- a/sparta/operators/sparse_attention.py +++ b/sparta/operators/sparse_attention.py @@ -56,12 +56,11 @@ def __init__(self, mask: Optional[torch.Tensor] = None): self._matmul_out = SparseBatchMatMul('sdd', False, False, False, True) self._matmul_qk.ports['C'].connect(self._softmax.ports['x']) self._softmax.ports['y'].connect(self._matmul_out.ports['A']) - self._sparse_attr = self._matmul_qk.get_sparse_attr() if mask is not None: self.set_mask(mask) def set_mask(self, mask: torch.Tensor): - self._sparse_attr.set_mask(mask) + self._softmax.set_mask(mask) def forward(self, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor): QK = self._matmul_qk(Q, K) diff --git a/sparta/operators/sparse_matmul.py b/sparta/operators/sparse_matmul.py index 692804b3..585680c3 100644 --- a/sparta/operators/sparse_matmul.py +++ b/sparta/operators/sparse_matmul.py @@ -1,13 +1,12 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. 
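# A minimal usage sketch of the operator API after this refactor (the kernel
# names and parameter keys come from the tests in this patch series; the
# concrete block/thread sizes below are illustrative, not tuned values):
#
#     import torch
#     from sparta.operators import SparseBatchMatMul
#     from sparta.testing import block_mask
#
#     A = torch.rand((4, 128, 256), device='cuda')
#     B = torch.rand((4, 256, 192), device='cuda')
#     mask = block_mask((128, 192), block=(8, 8), sparsity=0.9, device='cuda')
#
#     matmul = SparseBatchMatMul('dds', False, False, False, True)
#     matmul.set_mask(mask)  # stored on the SparsityAttr shared by all kernels
#     params = {
#         '_impl': 'sparta',
#         'BLOCK_SIZE_M_VALUE': 32, 'BLOCK_SIZE_K_VALUE': 32, 'BLOCK_SIZE_N_VALUE': 32,
#         'THREAD_SIZE_M_VALUE': 4, 'THREAD_SIZE_K_VALUE': 4, 'THREAD_SIZE_N_VALUE': 4,
#     }
#     matmul.build(
#         config={name: dict(params) for name in ['forward', 'backward:A', 'backward:B']},
#         sample_inputs=[A, B],
#     )
#     C = matmul(A, B)  # C is BCSR-compressed since mode='dds' and compressed=True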
-from typing import Any, List, Dict, Tuple, Optional +from typing import Any, Dict, List, Tuple, Optional import torch -import numpy as np -from sparta.kernels import KernelBase, SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from sparta.operators import Port, SparsityAttr, SparseOperator, SparseAutoGrad +from sparta.kernels import SparTASparseMatMulKernel, OpenAISparseMatMulKernel +from sparta.operators import Port, SparseOperator, SparseAutoGrad class SparseBatchMatMulForward(SparseOperator): @@ -32,25 +31,6 @@ def __init__( self._biased = biased self._compressed = compressed - self._sparse_axis = { - 'sdd': ['K', 'M'] if transpose_A else ['M', 'K'], - 'dsd': ['N', 'K'] if transpose_B else ['K', 'N'], - 'dds': ['M', 'N'], - }[mode] - self._BCSR = { - 'sdd': not transpose_A, - 'dsd': transpose_B, - 'dds': True, - }[mode] - - self._sparse_port = 'ABC'[mode.find('s')] - self.ports['A'] = Port(self, 'A') - self.ports['B'] = Port(self, 'B') - self.ports['C'] = Port(self, 'C', fine_mask=False) # DDS known issue - if biased: - self.ports['bias'] = Port(self, 'bias') - self.ports[self._sparse_port].attr = SparsityAttr(self._BCSR, not self._BCSR) - specs = { 'mode': mode, 'biased': biased, @@ -59,12 +39,26 @@ def __init__( 'compressed': compressed, 'batched': self.__batched__, } - self._kernels['forward'] = { + self._set_kernel_group('forward', { 'sparta': SparTASparseMatMulKernel(**specs), 'openai': OpenAISparseMatMulKernel(**specs), - } + }) + + for p in ['A', 'B', 'C', 'bias'] if biased else ['A', 'B', 'C']: + self.ports[p] = Port(self, p) + sparse_port = self.ports['ABC'[mode.find('s')]] + sparse_port.attr = self.kernel_groups['forward'].attr + sparse_port.compressed = compressed + if compressed: + self._attr = sparse_port.attr - self.shape: Tuple[int, int, int, int] = None + self._set_forward() + + def _get_sample_inputs(self, kernel_name: str): + return [ + self.ports[p].get_sample_data() + for p in (['A', 'B', 'bias'] if self._biased else ['A', 'B']) + ] def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): # TODO: check shape conflicts @@ -80,57 +74,25 @@ def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): N, K = sample_inputs[1].shape[-2:] else: K, N = sample_inputs[1].shape[-2:] - self.shape = (batch_size, M, K, N) - self.ports['A'].set_data(sample_inputs[0]) - self.ports['B'].set_data(sample_inputs[1]) - if self._biased: - self.ports['bias'].set_data(sample_inputs[2]) + return (batch_size, M, K, N) - def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): - sparse_attr = self.get_sparse_attr() - kernel.set_parameter('BCSR', self._BCSR or sparse_attr.BCSR) - kernel.set_parameter('BCSC', not self._BCSR) - block_size = [params[f'BLOCK_SIZE_{axis}_VALUE'] for axis in self._sparse_axis] - sparse_attr.set_block_size(*block_size) - kernel.compile(params, self.shape, sparse_attr) + def _set_sample_shape(self, sample_shape: Tuple): + self.kernel_groups['forward'].set_sample_shape(sample_shape) def _set_forward(self): - if 'forward' in self._compiled_kernels: - self.forward_func = self._compiled_kernels['forward'] - - def _kernel_func_call(self, kernel_name: str): - A = self.ports['A'].get_data(compressed=self._compressed) - B = self.ports['B'].get_data(compressed=self._compressed) - kernel = self._compiled_kernels[kernel_name] - if self._biased: - bias = self.ports['bias'].get_data() - return lambda : kernel(A, B, bias) + kernel_group = self.kernel_groups['forward'] + if kernel_group.ready: + self.forward_func = 
kernel_group.active_kernel else: - return lambda : kernel(A, B) - - def _kernel_reference(self, kernel_name: str): - return self.ports['C'].get_data(compressed=self._compressed) - - def _calc_kernel_flops(self, kernel_name: str): - indexes = self.get_sparse_attr().indexes - sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num - return np.prod(self.shape) * sparse_rate - - def reference(self, *inputs): - if len(inputs) > 0: - self._read_sample_inputs(inputs) - A = self.ports['A'].get_data(compressed=False) - B = self.ports['B'].get_data(compressed=False) - if self._transpose_A: - A = A.swapaxes(self.__batched__ + 0, self.__batched__ + 1) - if self._transpose_B: - B = B.swapaxes(self.__batched__ + 0, self.__batched__ + 1) - C = torch.bmm(A, B) if self.__batched__ else torch.mm(A, B) - if self._biased: - bias = self.ports['bias'].get_data() - C += bias.unsqueeze(1) if self.__batched__ else bias - self.ports['C'].set_data(C) - return C + def forward_func(*inputs): + self.ports['A'].sample_data = inputs[0] + self.ports['B'].sample_data = inputs[1] + if self._biased: + self.ports['bias'].sample_data = inputs[2] + C = kernel_group.active_kernel(*inputs) + self.ports['C'].sample_data = C + return C + self.forward_func = forward_func class SparseMatMulForward(SparseBatchMatMulForward): @@ -149,36 +111,17 @@ def __init__( transpose_B: bool, biased: bool, compressed: bool, - ports: Dict[str, Port], ): if mode not in ['sdd', 'dsd', 'dds']: raise ValueError(f'invalid sparse matmul mode: {mode}') super().__init__() - self._mode = mode self._transpose_A = transpose_A self._transpose_B = transpose_B self._biased = biased self._compressed = compressed - self.ports = ports - self._sparse_port = 'ABC'[mode.find('s')] - - self._BCSR = { - 'backward:A': { - 'sdd': True, - 'dsd': not transpose_B, - 'dds': True, - }[mode], - 'backward:B': { - 'sdd': transpose_A, - 'dsd': True, - 'dds': False, - }[mode], - } - self.get_sparse_attr().update_axis(True, True) - A_spec = { 'mode': ''.join(mode[i] for i in ([1, 2, 0] if transpose_A else [2, 1, 0])), 'biased': False, @@ -196,50 +139,63 @@ def __init__( 'batched': self.__batched__, } - self._kernels['backward:A'] = { + self._set_kernel_group('backward:A', { 'sparta': SparTASparseMatMulKernel(**A_spec), 'openai': OpenAISparseMatMulKernel(**A_spec), - } - self._kernels['backward:B'] = { + }) + self._set_kernel_group('backward:B', { 'sparta': SparTASparseMatMulKernel(**B_spec), 'openai': OpenAISparseMatMulKernel(**B_spec), - } + }) + + # TODO: connect sparsity attrs for single use + # TODO: set ports for single use + # TODO: set forward function for single use - self.shape: Tuple[int, int, int, int] = None + def _get_sample_inputs(self, kernel_name: str): + grad_C = self.ports['C'].get_sample_data(grad=True) + if kernel_name == 'backward:A': + B = self.ports['B'].get_sample_data() + if self._transpose_A: + return [B, grad_C] + else: + return [grad_C, B] + elif kernel_name == 'backward:B': + A = self.ports['A'].get_sample_data() + if self._transpose_B: + return [grad_C, A] + else: + return [A, grad_C] + else: + raise ValueError(f'unrecognized kernel name: {kernel_name}') def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): pass - def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): - batch, M, K, N = self.shape - shape = { - 'backward:A': (batch, K, N, M) if self._transpose_A else (batch, M, N, K), - 'backward:B': (batch, N, M, K) if self._transpose_B else (batch, K, M, N), - }[kernel_name] - sparse_attr = 
self.get_sparse_attr() - kernel.set_parameter('BCSR', self._BCSR[kernel_name] or sparse_attr.BCSR) - kernel.set_parameter('BCSC', not self._BCSR[kernel_name]) - kernel.compile(params, shape, sparse_attr) - - def _set_forward(self): - if 'backward:A' in self._compiled_kernels: - kernel_A = self._compiled_kernels['backward:A'] + def _set_sample_shape(self, sample_shape: Tuple): + batch_size, M, K, N = sample_shape + if self._transpose_A: + self.kernel_groups['backward:A'].set_sample_shape((batch_size, K, N, M)) else: - kernel_A = lambda *inputs: None - if 'backward:B' in self._compiled_kernels: - kernel_B = self._compiled_kernels['backward:B'] + self.kernel_groups['backward:A'].set_sample_shape((batch_size, M, N, K)) + if self._transpose_B: + self.kernel_groups['backward:B'].set_sample_shape((batch_size, N, M, K)) else: - kernel_B = lambda *inputs: None + self.kernel_groups['backward:B'].set_sample_shape((batch_size, K, M, N)) + + def _set_forward(self): + kg_A = self.kernel_groups['backward:A'] + kg_B = self.kernel_groups['backward:B'] if self._transpose_A: - backward_A = lambda grad_C, B: kernel_A(B, grad_C) + backward_A = lambda grad_C, B: kg_A.active_kernel(B, grad_C) else: - backward_A = lambda grad_C, B: kernel_A(grad_C, B) + backward_A = lambda grad_C, B: kg_A.active_kernel(grad_C, B) if self._transpose_B: - backward_B = lambda grad_C, A: kernel_B(grad_C, A) + backward_B = lambda grad_C, A: kg_B.active_kernel(grad_C, A) else: - backward_B = lambda grad_C, A: kernel_B(A, grad_C) - if self._mode == 'dds' and self._compressed: - C_attr = self.ports['C'].attr + backward_B = lambda grad_C, A: kg_B.active_kernel(A, grad_C) + C_attr = self.ports['C'].attr + if C_attr is not None and self._compressed: backward_bias = lambda grad_C: C_attr.indexes.sum(grad_C, axis=-2) else: backward_bias = lambda grad_C: grad_C.sum(-2) @@ -256,40 +212,6 @@ def backward(grad, A, B, needs_grad): self.forward_func = backward - def _kernel_func_call(self, kernel_name: str): - grad_C = self.ports['C'].get_data(grad=True, compressed=self._compressed) - kernel = self._compiled_kernels[kernel_name] - if kernel_name == 'backward:A': - B = self.ports['B'].get_data(compressed=self._compressed) - if self._transpose_A: - return lambda : kernel(B, grad_C) - else: - return lambda : kernel(grad_C, B) - elif kernel_name == 'backward:B': - A = self.ports['A'].get_data(compressed=self._compressed) - if self._transpose_A: - return lambda : kernel(grad_C, A) - else: - return lambda : kernel(A, grad_C) - else: - raise ValueError(f'kernel not found: {kernel_name}') - - def _kernel_reference(self, kernel_name: str): - if kernel_name == 'backward:A': - return self.ports['A'].get_data(grad=True, compressed=False) - elif kernel_name == 'backward:B': - return self.ports['B'].get_data(grad=True, compressed=False) - else: - raise ValueError(f'kernel not found: {kernel_name}') - - def _calc_kernel_flops(self, kernel_name: str): - indexes = self.get_sparse_attr().indexes - sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num - return np.prod(self.shape) * sparse_rate - - def reference(self, *inputs): - pass - class SparseMatMulBackward(SparseBatchMatMulBackward): @@ -326,19 +248,24 @@ def __init__( biased: bool, compressed: bool = True, ): - super().__init__(mode, transpose_A, transpose_B, biased, compressed) + super().__init__( + mode=mode, + transpose_A=transpose_A, + transpose_B=transpose_B, + biased=biased, + compressed=compressed, + ) self._set_backward(SparseBatchMatMulBackward( mode=mode, transpose_A=transpose_A, 
transpose_B=transpose_B, biased=biased, compressed=compressed, - ports=self.ports, )) - def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): - super()._read_sample_inputs(sample_inputs) - self.backward_op.shape = self.shape + def _set_sample_shape(self, sample_shape: Tuple): + super()._set_sample_shape(sample_shape) + self.backward_op._set_sample_shape(sample_shape) class SparseMatMul(SparseAutoGrad, SparseMatMulForward): @@ -353,19 +280,24 @@ def __init__( biased: bool, compressed: bool = True, ): - super().__init__(mode, transpose_A, transpose_B, biased, compressed) + super().__init__( + mode=mode, + transpose_A=transpose_A, + transpose_B=transpose_B, + biased=biased, + compressed=compressed, + ) self._set_backward(SparseMatMulBackward( mode=mode, transpose_A=transpose_A, transpose_B=transpose_B, biased=biased, compressed=compressed, - ports=self.ports, )) - def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): - super()._read_sample_inputs(sample_inputs) - self.backward_op.shape = self.shape + def _set_sample_shape(self, sample_shape: Tuple): + super()._set_sample_shape(sample_shape) + self.backward_op._set_sample_shape(sample_shape) class SparseLinear(SparseMatMul): @@ -376,38 +308,31 @@ def __init__(self, raw_module: torch.nn.Linear, mode: str = 'dsd'): super().__init__(mode, False, True, self._biased, self._compressed) self.weight: torch.nn.Parameter = None self.bias = raw_module.bias - self.ports['B'].set_data(raw_module.weight) + self.ports['B'].sample_data = raw_module.weight if self._biased: - self.ports['bias'].set_data(raw_module.bias) + self.ports['bias'].sample_data = raw_module.bias self.forward = self._forward_with_bias else: self.forward = self._forward_without_bias + def _forward_with_bias(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x, self.weight, self.bias) + + def _forward_without_bias(self, x: torch.Tensor) -> torch.Tensor: + return super().forward(x, self.weight) + def _update_weight(self): - if 'forward' in self._compiled_kernels: - weight = self.ports['B'].get_data(compressed=self._compressed) - self.weight = torch.nn.Parameter(weight, requires_grad=True) + weight = self.ports['B'].get_sample_data() + self.weight = torch.nn.Parameter(weight, requires_grad=True) def set_mask(self, mask: torch.Tensor): super().set_mask(mask) self._update_weight() - def build( - self, - config: Dict[str, Dict[str, Any]], - sample_inputs: Optional[List[torch.Tensor]] = None, - ): - super().build(config, sample_inputs) + def _post_build(self): self._update_weight() + return super()._post_build() def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): - sample_inputs.append(self.ports['B'].get_data()) - if self._biased: - sample_inputs.append(self.ports['bias'].get_data()) - super()._read_sample_inputs(sample_inputs) - - def _forward_with_bias(self, x: torch.Tensor) -> torch.Tensor: - return super().forward(x, self.weight, self.bias) - - def _forward_without_bias(self, x: torch.Tensor) -> torch.Tensor: - return super().forward(x, self.weight) + sample_inputs.append(self.ports['B'].sample_data) + return super()._read_sample_inputs(sample_inputs) diff --git a/sparta/operators/sparse_softmax.py b/sparta/operators/sparse_softmax.py index 2fe49c93..de5c74cb 100644 --- a/sparta/operators/sparse_softmax.py +++ b/sparta/operators/sparse_softmax.py @@ -6,9 +6,8 @@ import torch import numpy as np -from sparta.kernels import KernelBase, SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.operators.operator_base 
import Port, SparsityAttr, SparseOperator, SparseAutoGrad -from sparta.testing import sparse_softmax_forward_reference, sparse_softmax_backward_reference +from sparta.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel +from sparta.operators.operator_base import Port, SparseOperator, SparseAutoGrad class SparseBatchSoftmaxForward(SparseOperator): @@ -21,76 +20,58 @@ def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): self._compressed = compressed self._T = None if temperature is None else np.float32(1 / temperature) - self._sparse_port = 'y' - sparse_attr = SparsityAttr(True, False) - for port_name in ['x', 'y']: - self.ports[port_name] = Port(self, port_name) - self.ports[port_name].attr = sparse_attr - - self._kernels[self.__direction__] = { + self._set_kernel_group(self.__direction__, { 'sparta': { 'forward': SparTASparseSoftmaxForwardKernel, 'backward': SparTASparseSoftmaxBackwardKernel, }[self.__direction__]( compressed=compressed, batched=self.__batched__, - ), - } + ) + }) - self.shape: Tuple[int, int, int] = None + self._sparse_port = 'y' + sparse_attr = self.kernel_groups[self.__direction__].attr + for port_name in ['x', 'y']: + self.ports[port_name] = Port(self, port_name) + self.ports[port_name].attr = sparse_attr + if compressed: + self._attr = sparse_attr def set_temperature(self, temperature: float): self._T = np.float32(1 / temperature) + def _get_sample_inputs(self, kernel_name: str): + return [ + self.ports['x'].get_sample_data(), + self.ports['x'].attr.mask, + self._T, + ] + def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): x = sample_inputs[0] - if self.__batched__: - batch_size = x.shape[0] - else: - batch_size = 1 + batch_size = x.shape[0] if self.__batched__ else 1 H, W = x.shape[-2:] - self.shape = (batch_size, H, W) - self.ports['x'].set_data(x) - if self._T is None: - self.set_temperature(np.sqrt(W)) + return (batch_size, H, W) - def _compile_kernel(self, kernel_name: str, kernel: KernelBase, params: Dict[str, Any]): - sparse_attr = self.get_sparse_attr() - sparse_attr.set_block_size( - block_H=params['BLOCK_SIZE_H_VALUE'], - block_W=params['BLOCK_SIZE_W_VALUE'], - ) - kernel.set_parameter('MAX_W_VALUE', self.shape[-1]) - kernel.compile(params, self.shape, sparse_attr) + def _set_sample_shape(self, sample_shape: Tuple): + self.kernel_groups[self.__direction__].set_sample_shape(sample_shape) + if self._T is None: + self.set_temperature(np.sqrt(sample_shape[-1])) def _set_forward(self): - if self.__direction__ in self._compiled_kernels: - kernel = self._compiled_kernels[self.__direction__] - sparse_attr = self.get_sparse_attr() - self.forward_func = lambda *inputs: kernel(*inputs, sparse_attr.mask, self._T) - - def _kernel_func_call(self, kernel_name: str): - x = self.ports['x'].get_data(compressed=self._compressed) - sparse_attr = self.get_sparse_attr() - kernel = self._compiled_kernels[kernel_name] - return lambda : kernel(x, sparse_attr.mask, self._T) - - def _kernel_reference(self, kernel_name: str): - return self.ports['y'].get_data(compressed=self._compressed) - - def _calc_kernel_flops(self, kernel_name: str): - indexes = self.get_sparse_attr().indexes - sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num - return np.prod(self.shape) * sparse_rate - - def reference(self, *inputs): - if len(inputs) > 0: - self._read_sample_inputs(inputs) - x = self.ports['x'].get_data(compressed=False) - mask = self.get_sparse_attr().mask - y = sparse_softmax_forward_reference(x, mask, 1 
/ self._T) - self.ports['y'].set_data(y) - return y + kernel_group = self.kernel_groups[self.__direction__] + attr = kernel_group.attr + if kernel_group.ready: + def forward_func(x): + return kernel_group.active_kernel(x, attr.mask, self._T) + else: + def forward_func(x): + self.ports['x'].sample_data = x + y = kernel_group.active_kernel(x, attr.mask, self._T) + self.ports['y'].sample_data = y + return y + self.forward_func = forward_func class SparseSoftmaxForward(SparseBatchSoftmaxForward): @@ -103,21 +84,20 @@ class SparseBatchSoftmaxBackward(SparseBatchSoftmaxForward): __batched__ = True __direction__ = 'backward' - def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): - pass - - def _kernel_func_call(self, kernel_name: str): - gy = self.ports['y'].get_data(grad=True, compressed=self._compressed) - y = self.ports['y'].get_data(grad=False, compressed=self._compressed) - sparse_attr = self.get_sparse_attr() - kernel = self._compiled_kernels[kernel_name] - return lambda : kernel(gy, y, sparse_attr.mask, self._T) - - def _kernel_reference(self, kernel_name: str): - return self.ports['x'].get_data(grad=True, compressed=self._compressed) + def _get_sample_inputs(self, kernel_name: str): + return [ + self.ports['y'].get_sample_data(grad=True), + self.ports['y'].get_sample_data(), + self.ports['y'].attr.mask, + self._T, + ] - def reference(self, *inputs): - pass + def _set_forward(self): + kernel_group = self.kernel_groups[self.__direction__] + attr = kernel_group.attr + def backward_func(gy, y): + return kernel_group.active_kernel(gy, y, attr.mask, self._T) + self.forward_func = backward_func class SparseSoftmaxBackward(SparseBatchSoftmaxBackward): @@ -154,15 +134,14 @@ class SparseBatchSoftmax(SparseAutoGrad, SparseBatchSoftmaxForward): def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): super().__init__(compressed, temperature) self._set_backward(SparseBatchSoftmaxBackward(compressed, temperature)) - self.backward_op.ports = self.ports def set_temperature(self, temperature: float): super().set_temperature(temperature) self.backward_op.set_temperature(temperature) - def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): - super()._read_sample_inputs(sample_inputs) - self.backward_op.shape = self.shape + def _set_sample_shape(self, sample_shape: Tuple): + super()._set_sample_shape(sample_shape) + self.backward_op._set_sample_shape(sample_shape) class SparseSoftmax(SparseAutoGrad, SparseSoftmaxForward): @@ -172,12 +151,11 @@ class SparseSoftmax(SparseAutoGrad, SparseSoftmaxForward): def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): super().__init__(compressed, temperature) self._set_backward(SparseSoftmaxBackward(compressed, temperature)) - self.backward_op.ports = self.ports def set_temperature(self, temperature: float): super().set_temperature(temperature) self.backward_op.set_temperature(temperature) - def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): - super()._read_sample_inputs(sample_inputs) - self.backward_op.shape = self.shape + def _set_sample_shape(self, sample_shape: Tuple): + super()._set_sample_shape(sample_shape) + self.backward_op._set_sample_shape(sample_shape) diff --git a/test/unit/test_sparse_matmul.py b/test/unit/test_sparse_matmul.py index b36b04ab..cbe78e0e 100644 --- a/test/unit/test_sparse_matmul.py +++ b/test/unit/test_sparse_matmul.py @@ -7,7 +7,7 @@ import pytest from sparta.kernels import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel -from 
sparta.operators import SparsityAttr, SparseLinear, SparseMatMul, SparseBatchMatMul +from sparta.operators import SparseLinear, SparseMatMul, SparseBatchMatMul from sparta.tesa import BCSIndexes from sparta.testing import block_mask @@ -186,34 +186,14 @@ def test_sparse_matmul_kernel( compressed=compressed, batched=batched, ) - shape = (batch, M, K, N) + kernel.attr.set_mask(mask) + batch = 1 if batch is None else batch + kernel.compile(get_params(impl), (batch, M, K, N)) + sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] - BCSR = { - 'A': not trans_A, - 'B': trans_B, - 'C': True, - }[sparse_port] - BCSC = not BCSR - attr = SparsityAttr(BCSR, BCSC) - attr.set_mask(mask) - kernel.set_parameter('BCSR', BCSR) - kernel.set_parameter('BCSC', BCSC) - kernel.compile(get_params(impl), shape, attr) - - sparse_axis = { - 'A': ['K', 'M'] if trans_A else ['M', 'K'], - 'B': ['N', 'K'] if trans_B else ['K', 'N'], - 'C': ['M', 'N'], - }[sparse_port] - attr.BCSR = BCSR - attr.BCSC = BCSC - attr.set_block_size(*[ - kernel.get_parameter(f'BLOCK_SIZE_{i}_VALUE') - for i in sparse_axis - ]) if compressed: - data, mask = compress_data(attr.indexes, sparse_port, data, mask, False) + data, mask = compress_data(kernel.attr.indexes, sparse_port, data, mask, False) inputs = ['A', 'B', 'bias'] if biased else ['A', 'B'] input_data = [data[f'input_{x}'] for x in inputs] @@ -261,14 +241,13 @@ def test_sparse_matmul_operator( config={kernel_name: get_params('sparta') for kernel_name in kernel_names}, sample_inputs=[data[f'input_{x}'] for x in inputs] ) - sparse_matmul.clear_sample_data() sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] def run_test(): nonlocal sparse_matmul, data, mask if compressed: - indexes = sparse_matmul.get_sparse_attr().indexes + indexes = sparse_matmul.get_sparse_indexes() data, cmask = compress_data(indexes, sparse_port, data, mask, True) else: cmask = mask @@ -339,14 +318,13 @@ def test_sparse_linear_operator( config={kernel_name: get_params('sparta') for kernel_name in kernel_names}, sample_inputs=[data['input_A']], ) - sparse_linear.clear_sample_data() sparse_port = {'sdd': 'A', 'dsd': 'B', 'dds': 'C'}[mode] def run_test(): nonlocal sparse_linear, data, mask if mode == 'dsd': - indexes = sparse_linear.get_sparse_attr().indexes + indexes = sparse_linear.get_sparse_indexes() data, cmask = compress_data(indexes, sparse_port, data, mask, True) else: cmask = mask @@ -368,7 +346,7 @@ def run_test(): mode, False, True, biased, True, random_seed=2023 ) - sparse_linear.ports['B'].set_data(data['input_B']) + sparse_linear.ports['B'].sample_data = data['input_B'] if biased: sparse_linear.bias = torch.nn.Parameter(data['input_bias']) sparse_linear.set_mask(mask) @@ -389,7 +367,7 @@ def run_test(): ) weight = data['input_B'] if mode == 'dsd': - weight = sparse_linear.get_sparse_attr().indexes.convert(weight.detach()) + weight = sparse_linear.get_sparse_indexes().convert(weight.detach()) sparse_linear.weight = torch.nn.Parameter(weight) if biased: sparse_linear.bias = torch.nn.Parameter(data['input_bias']) diff --git a/test/unit/test_sparse_softmax.py b/test/unit/test_sparse_softmax.py index e8b253bb..7bf8e1a6 100644 --- a/test/unit/test_sparse_softmax.py +++ b/test/unit/test_sparse_softmax.py @@ -8,7 +8,7 @@ import numpy as np from sparta.kernels import SparTASparseSoftmaxForwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.operators import SparsityAttr, SparseSoftmax, SparseBatchSoftmax +from sparta.operators import SparseSoftmax, SparseBatchSoftmax from sparta.testing import 
block_mask, sparse_softmax_forward_reference @@ -79,24 +79,20 @@ def test_sparse_softmax_kernels( forward_kernel = SparTASparseSoftmaxForwardKernel(compressed, batched) backward_kernel = SparTASparseSoftmaxBackwardKernel(compressed, batched) - attr = SparsityAttr(True, False) - attr.set_mask(mask) + forward_kernel.attr.connect(backward_kernel.attr) + forward_kernel.attr.set_mask(mask) - shape = (batch, H, W) + batch = 1 if batch is None else batch forward_kernel.set_parameter('MAX_W_VALUE', W) backward_kernel.set_parameter('MAX_W_VALUE', W) - forward_kernel.compile(get_params(), shape, attr) - backward_kernel.compile(get_params(), shape, attr) + forward_kernel.compile(get_params(), (batch, H, W)) + backward_kernel.compile(get_params(), (batch, H, W)) - attr.set_block_size( - forward_kernel.get_parameter('BLOCK_SIZE_H_VALUE'), - forward_kernel.get_parameter('BLOCK_SIZE_W_VALUE'), - ) temperature = np.float32(1 / np.sqrt(W)) if compressed: for name in data: - data[name] = attr.indexes.convert(data[name].detach()) + data[name] = forward_kernel.attr.indexes.convert(data[name].detach()) data['output_y'] = forward_kernel(data['input_x'], mask, temperature) data['output_grad_x'] = backward_kernel(data['grad_y'], data['target_y'], mask, temperature) @@ -120,8 +116,7 @@ def test_sparse_softmax_operator( else: sparse_softmax = SparseBatchSoftmax(compressed, np.sqrt(W)) - sparse_attr = sparse_softmax.get_sparse_attr() - sparse_attr.set_mask(mask) + sparse_softmax.set_mask(mask) kernel_names = ['forward', 'backward'] sparse_softmax.build( @@ -132,8 +127,9 @@ def test_sparse_softmax_operator( def run_test(): nonlocal sparse_softmax, data, mask if compressed: + indexes = sparse_softmax.get_sparse_indexes() for name in data: - data[name] = sparse_attr.indexes.convert(data[name].detach()) + data[name] = indexes.convert(data[name].detach()) data['input_x'].requires_grad = True data['output_y'] = sparse_softmax(data['input_x']) data['output_y'].backward(data['grad_y']) From 604a5b303fc570eb63f2533e2d971d15513a5be9 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 21 Mar 2023 08:28:02 +0000 Subject: [PATCH 15/28] fix port connection --- setup.py | 1 + sparta/kernels/kernel_base.py | 4 ++- sparta/operators/operator_base.py | 46 +++++++++++++++++------------- sparta/operators/sparse_matmul.py | 14 ++++----- sparta/operators/sparse_softmax.py | 13 ++++----- sparta/testing/mask.py | 26 ++++++++--------- test/unit/test_sparse_matmul.py | 2 +- test/unit/test_sparse_softmax.py | 2 +- 8 files changed, 56 insertions(+), 52 deletions(-) diff --git a/setup.py b/setup.py index c6918043..f3849b07 100644 --- a/setup.py +++ b/setup.py @@ -21,6 +21,7 @@ os.makedirs(os.path.join('csrc', 'build'), exist_ok=True) with open(os.path.join('csrc', 'build', 'moe_sparse_forward_kernel.cu'), 'w') as f: f.write(moe_kernel) + moe_ext = CUDAExtension( name='sparse_moe_cpp', sources=[ diff --git a/sparta/kernels/kernel_base.py b/sparta/kernels/kernel_base.py index a1b6d40a..f165bb06 100644 --- a/sparta/kernels/kernel_base.py +++ b/sparta/kernels/kernel_base.py @@ -97,7 +97,7 @@ def connect(self, other: SparsityAttr): self.groups.append(kernel_group) kernel_group.attr = self - def get_search_space(self): + def get_search_space(self, backward: bool = False): pass @@ -105,9 +105,11 @@ class KernelGroup(object): def __init__( self, + kernel_name: str, kernels: Dict[str, KernelBase], input_getter: Callable[[], List[torch.Tensor]], ): + self.for_backward = kernel_name.startswith('backward') self._kernels = kernels self._get_inputs = 
input_getter kernel_list = list(kernels.values()) diff --git a/sparta/operators/operator_base.py b/sparta/operators/operator_base.py index 5850261e..835338f7 100644 --- a/sparta/operators/operator_base.py +++ b/sparta/operators/operator_base.py @@ -17,16 +17,17 @@ def __init__(self, operator: SparseOperator, name: str): self.name = name self.ops: List[SparseOperator] = [operator] self.sample_data: torch.Tensor = None # Dense, not masked - self.attr: Optional[SparsityAttr] = None + self.get_attr: Callable[[], SparsityAttr] = lambda : None self.compressed: bool = False def get_sample_data(self, grad: bool = False): data = self.sample_data.grad if grad else self.sample_data data = data.detach() - if self.attr is not None: - data = data * self.attr.mask - if self.compressed and self.attr.ready: - data = self.attr.indexes.convert(data) + attr = self.get_attr() + if attr is not None: + data = data * attr.mask + if self.compressed and attr.ready: + data = attr.indexes.convert(data) return data def clear_data(self): @@ -36,40 +37,44 @@ def connect(self, other: Port): for operator in other.ops: operator.ports[other.name] = self self.ops.append(operator) - if self.attr is None: - self.attr = other.attr - elif other.attr is not None: - self.attr.connect(other.attr) - # TODO: compressed + self_attr = self.get_attr() + other_attr = other.get_attr() + if self_attr is None: + self.get_attr = other.get_attr + elif other_attr is None: + other.get_attr = self.get_attr + else: + self_attr.connect(other_attr) class SparseOperator(torch.nn.Module): - def __init__(self): + def __init__(self, compressed: bool): super().__init__() + self._compressed = compressed self.kernel_groups: Dict[str, KernelGroup] = {} self.ports: Dict[str, Port] = {} - self._attr: Optional[SparsityAttr] = None + self._sparse_port: Port = None self.forward_func: Callable = None def _set_kernel_group(self, kernel_name: str, kernels: Dict[str, KernelBase]): input_getter = lambda : self._get_sample_inputs(kernel_name) - self.kernel_groups[kernel_name] = KernelGroup(kernels, input_getter) + self.kernel_groups[kernel_name] = KernelGroup(kernel_name, kernels, input_getter) @abc.abstractmethod def _get_sample_inputs(self, kernel_name: str) -> List[torch.Tensor]: """Get sample inputs from ports for specified kernel.""" def get_sparse_indexes(self): - assert self._attr is not None - return self._attr.indexes + assert self._compressed, 'only compressed sparse operators can export sparse indexes' + return self._sparse_port.get_attr().indexes def set_mask(self, mask: torch.Tensor): - if self._attr is None: + if self._compressed: + self._sparse_port.get_attr().set_mask(mask) + else: for kernel_group in self.kernel_groups.values(): kernel_group.attr.set_mask(mask) - else: - self._attr.set_mask(mask) def forward(self, *inputs): return self.forward_func(*inputs) @@ -115,8 +120,9 @@ def _set_backward(self, backward_op: SparseOperator): backward_op.ports = self.ports for kernel_name, kernel_group in backward_op.kernel_groups.items(): self.kernel_groups[kernel_name] = kernel_group - if self._attr is not None: - self._attr.connect(kernel_group.attr) + if self._compressed: + self_attr = self._sparse_port.get_attr() + self_attr.connect(kernel_group.attr) backward_op._set_forward() def forward(self, *inputs): diff --git a/sparta/operators/sparse_matmul.py b/sparta/operators/sparse_matmul.py index 585680c3..06d7245e 100644 --- a/sparta/operators/sparse_matmul.py +++ b/sparta/operators/sparse_matmul.py @@ -24,7 +24,7 @@ def __init__( if mode not in ['sdd', 'dsd', 
'dds']: raise ValueError(f'invalid sparse matmul mode: {mode}') - super().__init__() + super().__init__(compressed) self._transpose_A = transpose_A self._transpose_B = transpose_B @@ -46,11 +46,9 @@ def __init__( for p in ['A', 'B', 'C', 'bias'] if biased else ['A', 'B', 'C']: self.ports[p] = Port(self, p) - sparse_port = self.ports['ABC'[mode.find('s')]] - sparse_port.attr = self.kernel_groups['forward'].attr - sparse_port.compressed = compressed - if compressed: - self._attr = sparse_port.attr + self._sparse_port = self.ports['ABC'[mode.find('s')]] + self._sparse_port.get_attr = lambda : self.kernel_groups['forward'].attr + self._sparse_port.compressed = compressed self._set_forward() @@ -115,7 +113,7 @@ def __init__( if mode not in ['sdd', 'dsd', 'dds']: raise ValueError(f'invalid sparse matmul mode: {mode}') - super().__init__() + super().__init__(compressed) self._transpose_A = transpose_A self._transpose_B = transpose_B @@ -194,7 +192,7 @@ def _set_forward(self): backward_B = lambda grad_C, A: kg_B.active_kernel(grad_C, A) else: backward_B = lambda grad_C, A: kg_B.active_kernel(A, grad_C) - C_attr = self.ports['C'].attr + C_attr = self.ports['C'].get_attr() if C_attr is not None and self._compressed: backward_bias = lambda grad_C: C_attr.indexes.sum(grad_C, axis=-2) else: diff --git a/sparta/operators/sparse_softmax.py b/sparta/operators/sparse_softmax.py index de5c74cb..ffc22939 100644 --- a/sparta/operators/sparse_softmax.py +++ b/sparta/operators/sparse_softmax.py @@ -15,7 +15,7 @@ class SparseBatchSoftmaxForward(SparseOperator): __direction__ = 'forward' def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): - super().__init__() + super().__init__(compressed) self._compressed = compressed self._T = None if temperature is None else np.float32(1 / temperature) @@ -30,13 +30,10 @@ def __init__(self, compressed: bool = False, temperature: Optional[float] = 1): ) }) - self._sparse_port = 'y' - sparse_attr = self.kernel_groups[self.__direction__].attr for port_name in ['x', 'y']: self.ports[port_name] = Port(self, port_name) - self.ports[port_name].attr = sparse_attr - if compressed: - self._attr = sparse_attr + self.ports[port_name].get_attr = lambda : self.kernel_groups[self.__direction__].attr + self._sparse_port = self.ports['y'] def set_temperature(self, temperature: float): self._T = np.float32(1 / temperature) @@ -44,7 +41,7 @@ def set_temperature(self, temperature: float): def _get_sample_inputs(self, kernel_name: str): return [ self.ports['x'].get_sample_data(), - self.ports['x'].attr.mask, + self.ports['x'].get_attr().mask, self._T, ] @@ -88,7 +85,7 @@ def _get_sample_inputs(self, kernel_name: str): return [ self.ports['y'].get_sample_data(grad=True), self.ports['y'].get_sample_data(), - self.ports['y'].attr.mask, + self.ports['y'].get_attr().mask, self._T, ] diff --git a/sparta/testing/mask.py b/sparta/testing/mask.py index a2b213f9..35a06df6 100644 --- a/sparta/testing/mask.py +++ b/sparta/testing/mask.py @@ -8,7 +8,7 @@ def block_mask( shape: Tuple[int], - block: Tuple[int] = (1, 1), + granularity: Tuple[int] = (1, 1), sparsity: float = 0.99, algo: str = 'rand', device: Any = 'cuda', @@ -17,23 +17,23 @@ def block_mask( Args: shape (Tuple[int]): Mask shape. - block (Tuple[int]): Block shape. (1, 1) means finegrained mask. - sparsity (float): The ratio of empty block number to total block number. + granularity (Tuple[int]): block shape. (1, 1) means finegrained mask. + sparsity (float): The ratio of empty blocks. 
algo (str): Algorithm to generate mask. Only random generator is supported now. """ assert len(shape) == 2, 'only 2D mask is supported' - assert len(block) == 2, 'only 2D mask is supported' - assert shape[0] % block[0] == 0, f'invalid block shape {block}' - assert shape[1] % block[1] == 0, f'invalid block shape {block}' + assert len(granularity) == 2, 'only 2D mask is supported' + assert shape[0] % granularity[0] == 0, f'invalid granularity shape {granularity}' + assert shape[1] % granularity[1] == 0, f'invalid granularity shape {granularity}' if algo == 'rand': - return random_block_mask(shape, block, sparsity, device) + return random_block_mask(shape, granularity, sparsity, device) else: raise ValueError(f'unsupported mask generator: {algo}') def random_block_mask( shape: Tuple[int], - block: Tuple[int], + granularity: Tuple[int], sparsity: float = 0.99, device: Any = 'cuda', ): @@ -41,12 +41,12 @@ def random_block_mask( Args: shape (Tuple[int]): Mask shape. - block (Tuple[int]): Block shape. - sparsity (float): The ratio of empty block number to total block number. + granularity (Tuple[int]): block shape. + sparsity (float): The ratio of empty blocks. """ - compressed_shape = (shape[0] // block[0], shape[1] // block[1]) + compressed_shape = (shape[0] // granularity[0], shape[1] // granularity[1]) mask = random_mask(compressed_shape, sparsity, device) - mask = mask.reshape(compressed_shape + (1, 1)).tile((1, 1) + block) + mask = mask.reshape(compressed_shape + (1, 1)).tile((1, 1) + granularity) return mask.swapaxes(1, 2).reshape(shape).contiguous() @@ -55,6 +55,6 @@ def random_mask(shape: Tuple[int], sparsity: float = 0.99, device: Any = 'cuda') Args: shape (Tuple[int]): Mask shape. - sparsity (float): The ratio of empty block number to total block number. + sparsity (float): The ratio of empty items. 
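+            Each element is kept (set to 1) with probability 1 - sparsity.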
""" return (torch.rand(shape, device=device) > sparsity).to(torch.uint8) diff --git a/test/unit/test_sparse_matmul.py b/test/unit/test_sparse_matmul.py index cbe78e0e..ef1bf881 100644 --- a/test/unit/test_sparse_matmul.py +++ b/test/unit/test_sparse_matmul.py @@ -52,7 +52,7 @@ def prepare_data( if mask is None: mask = block_mask( shape=shapes[sparse_port], - block=granularity, + granularity=granularity, sparsity=sparsity, device='cuda', ) diff --git a/test/unit/test_sparse_softmax.py b/test/unit/test_sparse_softmax.py index 7bf8e1a6..f4ed3599 100644 --- a/test/unit/test_sparse_softmax.py +++ b/test/unit/test_sparse_softmax.py @@ -29,7 +29,7 @@ def prepare_data( temperature = np.sqrt(W) if mask is None: - mask = block_mask((H, W), block=granularity, sparsity=sparsity, device='cuda') + mask = block_mask((H, W), granularity=granularity, sparsity=sparsity, device='cuda') data['grad_y'] = torch.rand(shape, device='cuda') data['input_x'].requires_grad = True From ff52951e7a034b87d970094b701c160515200cb5 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Thu, 23 Mar 2023 08:51:04 +0000 Subject: [PATCH 16/28] SparseLinear DSD support dynamic input shape --- sparta/operators/sparse_matmul.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/sparta/operators/sparse_matmul.py b/sparta/operators/sparse_matmul.py index 06d7245e..76090d85 100644 --- a/sparta/operators/sparse_matmul.py +++ b/sparta/operators/sparse_matmul.py @@ -307,11 +307,23 @@ def __init__(self, raw_module: torch.nn.Linear, mode: str = 'dsd'): self.weight: torch.nn.Parameter = None self.bias = raw_module.bias self.ports['B'].sample_data = raw_module.weight + self.out_features, self.in_features = raw_module.weight.shape if self._biased: self.ports['bias'].sample_data = raw_module.bias - self.forward = self._forward_with_bias + self._forward_static_shape = self._forward_with_bias else: - self.forward = self._forward_without_bias + self._forward_static_shape = self._forward_without_bias + if mode == 'dsd': + self.forward = self._forward_dynamic_shape + else: + self.forward = self._forward_static_shape + + def _forward_dynamic_shape(self, x: torch.Tensor): + batch_shape = x.shape[:-1] + x = x.reshape(-1, self.in_features) + y = self._forward_static_shape(x) + y = y.reshape(*batch_shape, -1) + return y def _forward_with_bias(self, x: torch.Tensor) -> torch.Tensor: return super().forward(x, self.weight, self.bias) @@ -332,5 +344,6 @@ def _post_build(self): return super()._post_build() def _read_sample_inputs(self, sample_inputs: List[torch.Tensor]): + sample_inputs[0] = sample_inputs[0].reshape((-1, self.in_features)) sample_inputs.append(self.ports['B'].sample_data) return super()._read_sample_inputs(sample_inputs) From 8cad262e404848a361dbba2e7513ce277b36274e Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Fri, 14 Apr 2023 03:12:01 +0900 Subject: [PATCH 17/28] FlashSparseAttentionForwardKernel; fix softmax kernels --- sparta/kernels/__init__.py | 1 + sparta/kernels/attention.py | 127 ++++++++ .../flash_sparse_attention_forward.cuh.j2 | 284 ++++++++++++++++++ .../sparta_sparse_softmax_backward.cuh.j2 | 11 +- .../sparta_sparse_softmax_forward.cuh.j2 | 17 +- sparta/testing/mask.py | 10 +- 6 files changed, 432 insertions(+), 18 deletions(-) create mode 100644 sparta/kernels/attention.py create mode 100644 sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 diff --git a/sparta/kernels/__init__.py b/sparta/kernels/__init__.py index 983d8cdf..f7d9577c 100644 --- 
a/sparta/kernels/__init__.py +++ b/sparta/kernels/__init__.py @@ -4,3 +4,4 @@ from sparta.kernels.kernel_base import KernelBase, SparsityAttr, KernelGroup from sparta.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel from sparta.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel +from sparta.kernels.attention import FlashSparseAttentionForwardKernel diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py new file mode 100644 index 00000000..12eb358d --- /dev/null +++ b/sparta/kernels/attention.py @@ -0,0 +1,127 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT license. + +import io +import textwrap +import importlib.resources as res +from typing import Any, Dict, Tuple, Optional + +import torch +import jinja2 +import numpy as np +import pandas as pd + +from sparta.tuning import TunableItemCfg +from sparta.kernels import KernelBase, SparsityAttr, templates, look_up_tables +from sparta.testing import sparse_multi_head_attention_reference + + +class FlashSparseAttentionForwardKernel(KernelBase): + + __lut_shape__ = (64 * 12, 1024, 1024, 64) # BxH, Nt, Ns, D + __algo__ = 'flash' + __direction__ = 'forward' + + def __init__(self, buffer: torch.Tensor, dtype: str = 'float'): + self._buffer = buffer + self._dtype = dtype + super().__init__() + + def _add_parameters(self): + self._add_parameter( + f'GLOBAL_SIZE_D_VALUE', + ) + for dim in ['S', 'T']: + self._add_parameter( + f'BLOCK_SIZE_{dim}_VALUE', + is_tunable=True, + search_space=TunableItemCfg('choice', [8, 16, 32, 64, 128, 256]), + ) + self._add_parameter( + f'THREAD_SIZE_{dim}_VALUE', + ) + self.attr = SparsityAttr(self, 'BLOCK_SIZE_T_VALUE', 'BLOCK_SIZE_S_VALUE', BCSR=False, BCSC=True) + + def _check_parameters(self, params: Dict[str, Any]): + Bt = params['BLOCK_SIZE_T_VALUE'] + Bs = params['BLOCK_SIZE_S_VALUE'] + assert Bt >= 4 + assert Bs >= 4 + assert Bt & (Bt - 1) == 0 + assert Bs & (Bs - 1) == 0 + Tt = params['THREAD_SIZE_T_VALUE'] + Ts = params['THREAD_SIZE_S_VALUE'] + assert Bt >= Tt + assert Bs >= Ts + assert Tt & (Tt - 1) == 0 + assert Ts & (Ts - 1) == 0 + + def _check_shape(self, Nt: int, Ns: int, D: int): + Bt, Bs = self.get_block_shape() + Tt, Ts = self.get_thread_shape() + assert D & (D - 1) == 0 # TODO: pad + threads_per_block = Bs // Ts * Bt // Tt + smem_threads_D = D // 4 + assert threads_per_block >= smem_threads_D + smem_threads_N = threads_per_block // smem_threads_D + assert smem_threads_N <= Bt + assert smem_threads_N <= Bs + assert Bs // Ts <= 32 + assert Bs // Ts >= 4 + assert D * Ts >= Bs + + def get_block_shape(self): + Bt = self.get_parameter('BLOCK_SIZE_T_VALUE') + Bs = self.get_parameter('BLOCK_SIZE_S_VALUE') + return Bt, Bs + + def get_thread_shape(self): + Tt = self.get_parameter('THREAD_SIZE_T_VALUE') + Ts = self.get_parameter('THREAD_SIZE_S_VALUE') + return Tt, Ts + + def threads_per_block(self): + Bt, Bs = self.get_block_shape() + Tt, Ts = self.get_thread_shape() + return (Bs // Ts, Bt // Tt, 1) + + def set_kernel_call(self, shape: Tuple[int, int, int, int]): + batch, Nt, Ns, D = shape + self._check_shape(Nt, Ns, D) + Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) + block = self.threads_per_block() + + def attn_func(Q, K, V): + O = torch.zeros_like(Q) + self._buffer.fill_(0) # TODO: try BCSR and delete this + self._kernel( + Q, K, V, O, self._buffer, + self.attr.indexes.BCSC_idx, + Ns_32, Nt_32, # D_32, + 
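+                # One CUDA thread block per batch * head (grid below); the sparse
+                # layout is passed as packed BCSC block indices plus the
+                # nonzero-block count.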
self.attr.indexes.nnz, + block=block, + grid=(Q.shape[0], 1, 1), + ) + return O + + self._func = attn_func + + def get_kernel_code(self): + template_file = f'{self.__algo__}_sparse_attention_{self.__direction__}.cuh.j2' + kernel_template = res.read_text(templates, template_file) + with open('tmp.cu', 'w') as f: + f.write(jinja2.Template(kernel_template).render(self.get_parameters())) + return jinja2.Template(kernel_template).render(self.get_parameters()) + + def reference( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + sparse: bool = False, + ): + return sparse_multi_head_attention_reference(Q, K, V, self.attr.mask) + + def compile(self, params: Dict[str, Any], shape: Tuple): + params['GLOBAL_SIZE_D_VALUE'] = shape[-1] + super().compile(params, shape) diff --git a/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 new file mode 100644 index 00000000..4e81b830 --- /dev/null +++ b/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 @@ -0,0 +1,284 @@ +{# Copyright (c) Microsoft Corporation. #} +{# Licensed under the MIT license. #} + +{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE <= 32 #} + +{% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * BLOCK_SIZE_T_VALUE // THREAD_SIZE_T_VALUE %} +{% set THREAD_SIZE_D_VALUE = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} + +const int BS = {{ BLOCK_SIZE_S_VALUE }}; +const int BT = {{ BLOCK_SIZE_T_VALUE }}; +const int TS = {{ THREAD_SIZE_S_VALUE }}; +const int TT = {{ THREAD_SIZE_T_VALUE }}; + +const int D = {{ GLOBAL_SIZE_D_VALUE }}; +const int TD = {{ THREAD_SIZE_D_VALUE }};{# D * TS >= BS #} + +__global__ void BLOCK_SPARSE_FLASH_ATTENTION( + float* Q, + float* K, + float* V, + float* O, + float* ML, + {# unsigned char* mask, #} + uint* block_idx, + uint Ns, + uint Nt, + uint block_nnz +) { + Q += Nt * D * blockIdx.x; + K += Ns * D * blockIdx.x; + V += Ns * D * blockIdx.x; + O += Nt * D * blockIdx.x; + ML += Nt * 2 * blockIdx.x; + + uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; + uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; + + __shared__ float shared_Q[BT * D]; + __shared__ float shared_K[BS * D]; + __shared__ float shared_V[BS * D]; + {# __shared__ float shared_ML[BT * 2]; #} + float* shared_ML = shared_Q; + + int SMEM_THREADS_D = D / 4; + int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; + + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int SMEM_TID_N = tid / SMEM_THREADS_D; + int SMEM_TID_D = tid % SMEM_THREADS_D * 4; + + float4 tmp_float4; + float frag_QO[TT][TD]; + float frag_KV[TS][TD]; + float frag_P[TT][TS]; + float frag_S[TT][TS]; + + float temperature = __frsqrt_rn((float)D); + float row_max; + float row_sum; + float row_sum_new; + float seg_max; + float seg_sum; + float row_coef; + float seg_coef; + int block_row_idx; + + int last_col_idx = -1; + {# BCSC #} + for (int block = 0; block < block_nnz; block++) { + uint idx = block_idx[block]; + int row_idx = idx & 0xffff; + int col_idx = idx >> 16; + // if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0) + // printf("#%d: (%d, %d)\n", block, row_idx, col_idx); + + {# Load Q #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + } + if (col_idx != last_col_idx) { + {# Load K #} + #pragma unroll + for (int k = 
SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + tmp_float4 = (reinterpret_cast(&K[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + shared_K[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; + shared_K[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; + shared_K[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; + shared_K[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; + } + {# Load V #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&shared_V[k * D + SMEM_TID_D])) = *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + } + last_col_idx = col_idx; + } + __syncthreads(); + + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_P[jt][js] = 0; + } + } + + {# Calc P = Q K^T #} + #pragma unroll + for (int k = 0; k < D; k += TD) { + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_QO[jt][i] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i] * temperature; + } + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_KV[js][i] = shared_K[(k + i) * BS + threadIdx.x + blockDim.x * js]; + } + } + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int i = 0; i < TD; i++) { + frag_P[jt][js] += frag_QO[jt][i] * frag_KV[js][i]; + } + } + } + } + __syncthreads(); + + {# Load M, L #} + #pragma unroll + for (int jt = tid * 2; jt < BT; jt += {{ THREADS_PER_BLOCK * 2 }}) { + *((float4*)(&shared_ML[jt * 2])) = *((float4*)(&ML[(row_idx * BT + jt) * 2])); + } + __syncthreads(); + // if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0 && row_idx == 0) { + // printf("P%d[0] = %f\n", col_idx + 1, frag_P[0][0]); + // printf("M = %f, L = %f\n", shared_ML[0], shared_ML[1]); + // } + + {# Load O #} + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + frag_QO[jt] = O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&frag_QO[jt][0])) = + *((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * 2])); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + *((float4*)(&frag_QO[jt][i])) = + *((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i])); + } + {% endif %} + } + + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {# Calc M~ = max_j(P) #} + seg_max = -100000.0; + #pragma unroll + for (int js = 0; js < TS; js++) { + seg_max = max(seg_max, frag_P[jt][js]); + } + #pragma unroll + for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { + seg_max = max(seg_max, __shfl_xor_sync(WARP_MASK, seg_max, offset)); + } + {# Calc S = exp(P - M~) #} + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_P[jt][js] = expf(frag_P[jt][js] - seg_max); + } + {# Calc L~ = sum_j(P) #} + seg_sum = 0.0f; + #pragma unroll + for (int js = 0; js < TS; js++) { + seg_sum += frag_P[jt][js]; + } + #pragma unroll + for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { + seg_sum += __shfl_down_sync(WARP_MASK, seg_sum, offset); + } + {# Calc M' = max(M, M~), L' = exp(M - M') * L + exp(M~ - M') * L~ #} + if (threadIdx.x == 0) { + block_row_idx = (threadIdx.y + blockDim.y * jt) * 2; + row_max = shared_ML[block_row_idx]; + row_sum = shared_ML[block_row_idx + 1]; + if (row_max < seg_max) { + shared_ML[block_row_idx] = seg_max; + row_coef = expf(row_max - seg_max); + 
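+                    // In this branch the block max wins: M' = M~, so
+                    // L' = e^(M - M') * L + L~, the running output is rescaled by
+                    // row_coef = e^(M - M') * L / L', and the new block's
+                    // contribution is scaled by seg_coef = 1 / L'.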
row_sum_new = row_coef * row_sum + seg_sum; + row_coef *= row_sum / row_sum_new; + seg_coef = 1.0f / row_sum_new; + } else { + seg_coef = expf(seg_max - row_max); + row_sum_new = row_sum + seg_coef * seg_sum; + row_coef = row_sum / row_sum_new; + seg_coef /= row_sum_new; + } + shared_ML[block_row_idx + 1] = row_sum_new; + } + row_coef = __shfl_sync(WARP_MASK, row_coef, WARP_OFFSET); + seg_coef = __shfl_sync(WARP_MASK, seg_coef, WARP_OFFSET); + // if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0 && row_idx == 0 && jt == 0) { + // printf("M%d = %f, L%d = %f\n", col_idx + 1, seg_max, row_idx + 1, seg_sum); + // printf("S%d[0] = %f\n", col_idx + 1, frag_P[0][0]); + // printf("row_coef = %f, seg_coef = %f\n", row_coef, seg_coef); + // } + {# Calc O' = L / L' * exp(M - M') * O, S' = exp(M~ - M') / L' * S #} + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_P[jt][js] *= seg_coef; + } + #pragma unroll + for (int i = 0; i < TD; i++) { + frag_QO[jt][i] *= row_coef; + } + } + __syncthreads(); + + {# Calc O = O' + S' V #} + #pragma unroll + for (int k = 0; k < BS; k += TS) { + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_S[jt][js] = + __shfl_sync(WARP_MASK, frag_P[jt][(k + js) / blockDim.x], (k + js) % blockDim.x + WARP_OFFSET); + } + } + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int i = 0; i < TD; i++) { + frag_KV[js][i] = shared_V[(k + js) * D + threadIdx.x * TD + i]; + } + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_QO[jt][i] += frag_S[jt][js] * frag_KV[js][i]; + } + } + } + } + + {# Save O #} + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x] = frag_QO[jt]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * 2])) = + *((float2*)(&frag_QO[jt][0])); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + *((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i])) = + *((float4*)(&frag_QO[jt][i])); + } + {% endif %} + } + + {# Save M, L #} + #pragma unroll + for (int jt = tid * 2; jt < BT; jt += {{ THREADS_PER_BLOCK * 2 }}) { + *((float4*)(&ML[(row_idx * BT + jt) * 2])) = *((float4*)(&shared_ML[jt * 2])); + } + } +} diff --git a/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 b/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 index 67c6f1be..876db9da 100644 --- a/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 +++ b/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 @@ -2,8 +2,6 @@ {# Licensed under the MIT license. 
#} {% set WARP_SIZE = [32, BLOCK_SIZE_W_VALUE]|min %} -{% set INI_OFFSET = WARP_SIZE // 2 %} -#define FULL_MASK 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} const int block_h = {{ BLOCK_SIZE_H_VALUE }}; const int block_w = {{ BLOCK_SIZE_W_VALUE }}; @@ -20,6 +18,9 @@ __global__ void SPARSE_SOFTMAX( int H, int W ) { + uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_SIZE }}) * {{ WARP_SIZE }}; + uint WARP_MASK = 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; + {% if BATCHED %} {% if COMPRESSED %} int num_nnz = row_ptr[H / block_h]; @@ -64,10 +65,10 @@ __global__ void SPARSE_SOFTMAX( regSum += out_val[index] * out_grad[index]; } - for (int offset = {{ INI_OFFSET }}; offset > 0; offset >>= 1) { - regSum += __shfl_down_sync(FULL_MASK, regSum, offset); + #pragma unroll + for (int offset = {{ WARP_SIZE // 2 }}; offset > 0; offset >>= 1) { + regSum += __shfl_xor_sync(WARP_MASK, regSum, offset); } - regSum = __shfl_sync(FULL_MASK, regSum, 0); for (int k = 0; k < val_num; k++) { uint index = index_list[k]; diff --git a/sparta/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 b/sparta/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 index 5e527e26..b7a7c5c7 100644 --- a/sparta/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 +++ b/sparta/kernels/templates/sparta_sparse_softmax_forward.cuh.j2 @@ -2,8 +2,6 @@ {# Licensed under the MIT license. #} {% set WARP_SIZE = [32, BLOCK_SIZE_W_VALUE]|min %} -{% set INI_OFFSET = WARP_SIZE // 2 %} -#define FULL_MASK 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} const int block_h = {{ BLOCK_SIZE_H_VALUE }}; const int block_w = {{ BLOCK_SIZE_W_VALUE }}; @@ -19,6 +17,9 @@ __global__ void SPARSE_SOFTMAX( int H, int W ) { + uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_SIZE }}) * {{ WARP_SIZE }}; + uint WARP_MASK = 0x{% for _ in range(WARP_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; + {% if BATCHED %} {% if COMPRESSED %} int num_nnz = row_ptr[H / block_h]; @@ -62,20 +63,20 @@ __global__ void SPARSE_SOFTMAX( regMax = max(regMax, in_val[index]); } - for (int offset = {{ INI_OFFSET }}; offset > 0; offset >>= 1) { - regMax = max(regMax, __shfl_down_sync(FULL_MASK, regMax, offset)); + #pragma unroll + for (int offset = {{ WARP_SIZE // 2 }}; offset > 0; offset >>= 1) { + regMax = max(regMax, __shfl_xor_sync(WARP_MASK, regMax, offset)); } - regMax = __shfl_sync(FULL_MASK, regMax, 0); for (int k = 0; k < val_num; k++) { uint index = index_list[k]; regSum += expf((in_val[index] - regMax) * temperature); } - for (int offset = {{ INI_OFFSET }}; offset > 0; offset >>= 1) { - regSum += __shfl_down_sync(FULL_MASK, regSum, offset); + #pragma unroll + for (int offset = {{ WARP_SIZE // 2 }}; offset > 0; offset >>= 1) { + regSum += __shfl_xor_sync(WARP_MASK, regSum, offset); } - regSum = __shfl_sync(FULL_MASK, regSum, 0); for (int k = 0; k < val_num; k++) { uint index = index_list[k]; diff --git a/sparta/testing/mask.py b/sparta/testing/mask.py index 35a06df6..52842f55 100644 --- a/sparta/testing/mask.py +++ b/sparta/testing/mask.py @@ -7,8 +7,8 @@ def block_mask( - shape: Tuple[int], - granularity: Tuple[int] = (1, 1), + shape: Tuple[int, int], + granularity: Tuple[int, int] = (1, 1), sparsity: float = 0.99, algo: str = 'rand', device: Any = 'cuda', @@ -32,8 +32,8 @@ def block_mask( def random_block_mask( - shape: Tuple[int], - granularity: Tuple[int], + shape: Tuple[int, int], + granularity: Tuple[int, int], sparsity: float = 0.99, device: Any = 'cuda', ): @@ -50,7 +50,7 @@ def random_block_mask( return mask.swapaxes(1, 
2).reshape(shape).contiguous()
 
 
-def random_mask(shape: Tuple[int], sparsity: float = 0.99, device: Any = 'cuda'):
+def random_mask(shape: Tuple[int, int], sparsity: float = 0.99, device: Any = 'cuda'):
     """Randomly generate a 2D uint8 tensor as finegrained mask.
 
     Args:

From 370a6db7d92e8b1d22a1d9b4c53be12ceddd7ad3 Mon Sep 17 00:00:00 2001
From: Chengruidong Zhang
Date: Thu, 27 Apr 2023 18:21:19 +0900
Subject: [PATCH 18/28] FlashSparseAttentionBackwardKernel with limited performance

---
 examples/sparse_attention.ipynb               |   4 +-
 sparta/kernels/__init__.py                    |   2 +-
 sparta/kernels/attention.py                   |  93 +++-
 .../flash_sparse_attention_backward.cuh.j2    | 477 ++++++++++++++++++
 .../flash_sparse_attention_forward.cuh.j2     | 162 ++++--
 .../sparta_sparse_softmax_backward.cuh.j2     |   2 +-
 sparta/testing/__init__.py                    |   2 +-
 sparta/testing/math.py                        |  48 +-
 test/bench/attention/attention.py             |   6 +-
 test/unit/test_seqlen_attention.py            |   4 +-
 test/unit/test_sparse_attention.py            |   4 +-
 11 files changed, 715 insertions(+), 89 deletions(-)
 create mode 100644 sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2

diff --git a/examples/sparse_attention.ipynb b/examples/sparse_attention.ipynb
index 0445cf94..1c0263f5 100644
--- a/examples/sparse_attention.ipynb
+++ b/examples/sparse_attention.ipynb
@@ -122,7 +122,7 @@
    "source": [
     "Check whether the sparse operator works correctly.\n",
     "\n",
-    "We provide `sparta.testing.sparse_multi_head_attention_reference()` function to calculate masked attention using dense method."
+    "We provide the `sparta.testing.sparse_multi_head_attention_forward_reference()` function to calculate masked attention using a dense method."
    ]
   },
   {
@@ -141,7 +141,7 @@
    "value.requires_grad = True\n",
    "\n",
    "def dense_attention(query, key, value):\n",
-    "    return sparta.testing.sparse_multi_head_attention_reference(query, key, value, mask)\n",
+    "    return sparta.testing.sparse_multi_head_attention_forward_reference(query, key, value, mask)\n",
    "\n",
    "for sparse_out, dense_out in zip(forward_backward(dense_attention), forward_backward(sparse_attention)):\n",
    "    torch.testing.assert_close(sparse_out, dense_out)"
   ]
  },
diff --git a/sparta/kernels/__init__.py b/sparta/kernels/__init__.py
index f7d9577c..4b47e5d2 100644
--- a/sparta/kernels/__init__.py
+++ b/sparta/kernels/__init__.py
@@ -4,4 +4,4 @@
 from sparta.kernels.kernel_base import KernelBase, SparsityAttr, KernelGroup
 from sparta.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel
 from sparta.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel
-from sparta.kernels.attention import FlashSparseAttentionForwardKernel
+from sparta.kernels.attention import FlashSparseAttentionForwardKernel, FlashSparseAttentionBackwardKernel
diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py
index 12eb358d..457776fc 100644
--- a/sparta/kernels/attention.py
+++ b/sparta/kernels/attention.py
@@ -13,14 +13,14 @@
 from sparta.tuning import TunableItemCfg
 from sparta.kernels import KernelBase, SparsityAttr, templates, look_up_tables
-from sparta.testing import sparse_multi_head_attention_reference
+from sparta.testing import sparse_multi_head_attention_forward_reference, sparse_multi_head_attention_backward_reference
 
 
-class FlashSparseAttentionForwardKernel(KernelBase):
+class FlashSparseAttentionKernel(KernelBase):
 
     __lut_shape__ = (64 * 12, 1024, 1024, 64)  # BxH, Nt, Ns, D
     __algo__ = 'flash'
-    __direction__ = 'forward'
+    __direction__ = ''
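For reference, the (M, L) statistics that this patch's backward kernel reloads from the ML buffer are produced by the streaming-softmax update in the forward template. A minimal PyTorch sketch of that update follows; the names are illustrative only (not part of the patch), and P is assumed to be already scaled by the kernel's temperature:

    import torch

    def online_softmax_step(O, M, L, P, V_blk):
        # Fold one score block P (BT x BS) and its value block V_blk (BS x D)
        # into the running output O (BT x D); M and L are the per-row max and
        # sum that the kernels store interleaved in ML.
        M_blk = P.max(dim=-1).values                 # block max, "M~"
        S = torch.exp(P - M_blk[:, None])            # unnormalized probabilities
        L_blk = S.sum(dim=-1)                        # block sum, "L~"
        M_new = torch.maximum(M, M_blk)              # M' = max(M, M~)
        L_new = torch.exp(M - M_new) * L + torch.exp(M_blk - M_new) * L_blk
        row_coef = torch.exp(M - M_new) * L / L_new  # rescales the old output
        seg_coef = torch.exp(M_blk - M_new) / L_new  # scales the new block
        O = row_coef[:, None] * O + seg_coef[:, None] * (S @ V_blk)
        return O, M_new, L_new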
def __init__(self, buffer: torch.Tensor, dtype: str = 'float'): self._buffer = buffer @@ -28,18 +28,19 @@ def __init__(self, buffer: torch.Tensor, dtype: str = 'float'): super().__init__() def _add_parameters(self): + self._add_parameter('GLOBAL_SIZE_D_VALUE') self._add_parameter( - f'GLOBAL_SIZE_D_VALUE', + 'BLOCK_SIZE_S_VALUE', + is_tunable=True, + search_space=TunableItemCfg('choice', [8, 16, 32, 64, 128, 256]), ) - for dim in ['S', 'T']: - self._add_parameter( - f'BLOCK_SIZE_{dim}_VALUE', - is_tunable=True, - search_space=TunableItemCfg('choice', [8, 16, 32, 64, 128, 256]), - ) - self._add_parameter( - f'THREAD_SIZE_{dim}_VALUE', - ) + self._add_parameter( + 'BLOCK_SIZE_T_VALUE', + is_tunable=True, + search_space=TunableItemCfg('choice', [8, 16, 32, 64, 128, 256]), + ) + self._add_parameter('THREAD_SIZE_S_VALUE') + self._add_parameter('THREAD_SIZE_T_VALUE') self.attr = SparsityAttr(self, 'BLOCK_SIZE_T_VALUE', 'BLOCK_SIZE_S_VALUE', BCSR=False, BCSC=True) def _check_parameters(self, params: Dict[str, Any]): @@ -85,6 +86,22 @@ def threads_per_block(self): Tt, Ts = self.get_thread_shape() return (Bs // Ts, Bt // Tt, 1) + def get_kernel_code(self): + template_file = f'{self.__algo__}_sparse_attention_{self.__direction__}.cuh.j2' + kernel_template = res.read_text(templates, template_file) + with open('tmp.cu', 'w') as f: + f.write(jinja2.Template(kernel_template).render(self.get_parameters())) + return jinja2.Template(kernel_template).render(self.get_parameters()) + + def compile(self, params: Dict[str, Any], shape: Tuple): + params['GLOBAL_SIZE_D_VALUE'] = shape[-1] + super().compile(params, shape) + + +class FlashSparseAttentionForwardKernel(FlashSparseAttentionKernel): + + __direction__ = 'forward' + def set_kernel_call(self, shape: Tuple[int, int, int, int]): batch, Nt, Ns, D = shape self._check_shape(Nt, Ns, D) @@ -106,22 +123,56 @@ def attn_func(Q, K, V): self._func = attn_func - def get_kernel_code(self): - template_file = f'{self.__algo__}_sparse_attention_{self.__direction__}.cuh.j2' - kernel_template = res.read_text(templates, template_file) - with open('tmp.cu', 'w') as f: - f.write(jinja2.Template(kernel_template).render(self.get_parameters())) - return jinja2.Template(kernel_template).render(self.get_parameters()) + def reference( + self, + Q: torch.Tensor, + K: torch.Tensor, + V: torch.Tensor, + sparse: bool = False, + ): + return sparse_multi_head_attention_forward_reference(Q, K, V, self.attr.mask) + + +class FlashSparseAttentionBackwardKernel(FlashSparseAttentionKernel): + + __direction__ = 'backward' + + def set_kernel_call(self, shape: Tuple[int, int, int, int]): + batch, Nt, Ns, D = shape + self._check_shape(Nt, Ns, D) + Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) + block = self.threads_per_block() + + def attn_func(grad, O, Q, K, V): + grad_Q = torch.zeros_like(Q) + grad_K = torch.zeros_like(Q) + grad_V = torch.zeros_like(Q) + self._kernel( + Q, K, V, O, grad_Q, grad_K, grad_V, grad, self._buffer, + self.attr.indexes.BCSC_idx, + Ns_32, Nt_32, # D_32, + self.attr.indexes.nnz, + block=block, + grid=(Q.shape[0], 1, 1), + ) + return grad_Q, grad_K, grad_V + + self._func = attn_func def reference( self, + grad: torch.Tensor, + O: torch.Tensor, Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, sparse: bool = False, ): - return sparse_multi_head_attention_reference(Q, K, V, self.attr.mask) + return sparse_multi_head_attention_backward_reference(grad, O, Q, K, V, self.attr.mask) def compile(self, params: Dict[str, Any], shape: Tuple): - 
params['GLOBAL_SIZE_D_VALUE'] = shape[-1] + Bs = params['BLOCK_SIZE_S_VALUE'] + Bt = params['BLOCK_SIZE_T_VALUE'] + Ts = params['THREAD_SIZE_S_VALUE'] + params['THREAD_SIZE_T_VALUE'] = Ts * Bt // Bs super().compile(params, shape) diff --git a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 new file mode 100644 index 00000000..663ff7ba --- /dev/null +++ b/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 @@ -0,0 +1,477 @@ +{# Copyright (c) Microsoft Corporation. #} +{# Licensed under the MIT license. #} + +{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE <= 32 #} + +{% set THREAD_SIZE_D_VALUE = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} +{% set THREAD_SIZE_T_VALUE = BLOCK_SIZE_T_VALUE // WARP_REDUCE_SIZE %} + +{% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * WARP_REDUCE_SIZE %} + +const int BS = {{ BLOCK_SIZE_S_VALUE }}; +const int BT = {{ BLOCK_SIZE_T_VALUE }}; +const int D = {{ GLOBAL_SIZE_D_VALUE }}; +const int TS = {{ THREAD_SIZE_S_VALUE }}; + +const int TT = {{ THREAD_SIZE_T_VALUE }}; +const int TD = {{ THREAD_SIZE_D_VALUE }};{# D * TS >= BS #} + +__device__ __forceinline__ float2 _add_float2(float2 x, float2 y) \ +{ \ + float2 res; \ + res.x = x.x + y.x; \ + res.y = x.y + y.y; \ + return res; \ +} + +__device__ __forceinline__ float4 _add_float4(float4 x, float4 y) \ +{ \ + float4 res; \ + res.x = x.x + y.x; \ + res.y = x.y + y.y; \ + res.z = x.z + y.z; \ + res.w = x.w + y.w; \ + return res; \ +} + +__device__ __forceinline__ float4 _mul_float4(float4 x, float4 y) \ +{ \ + float4 res; \ + res.x = x.x * y.x; \ + res.y = x.y * y.y; \ + res.z = x.z * y.z; \ + res.w = x.w * y.w; \ + return res; \ +} + +__global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( + float* Q, + float* K, + float* V, + float* O, + float* dQ, + float* dK, + float* dV, + float* dO, + float* ML, + {# unsigned char* mask, #} + uint* block_idx, + uint Ns, + uint Nt, + uint block_nnz +) { + Q += Nt * D * blockIdx.x; + K += Ns * D * blockIdx.x; + V += Ns * D * blockIdx.x; + O += Nt * D * blockIdx.x; + dQ += Nt * D * blockIdx.x; + dK += Ns * D * blockIdx.x; + dV += Ns * D * blockIdx.x; + dO += Nt * D * blockIdx.x; + ML += Nt * 2 * blockIdx.x; + + uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; + uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; + + {# extern __shared__ float buffer[]; + float* shared_Q = &buffer[0]; + float* shared_K = &buffer[BT * D]; + float* shared_V = &buffer[BT * D + BS * D]; + float* shared_O = &buffer[BT * D + 2 * BS * D]; + float* shared_dK = &buffer[2 * BT * D + 2 * BS * D]; + float* shared_dV = &buffer[2 * BT * D + 3 * BS * D]; #} + __shared__ float shared_Q[BT * D]; + __shared__ float shared_K[BS * D]; + __shared__ float shared_V[BS * D]; + __shared__ float shared_O[BT * D]; + __shared__ float shared_dK[BS * D]; + __shared__ float shared_dV[BS * D]; + {# __shared__ float shared_ML[BT * 2]; #} + float* shared_ML = shared_O; + + int SMEM_THREADS_D = D / 4; + int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; + + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int SMEM_TID_N = tid / SMEM_THREADS_D; + int SMEM_TID_D = tid % SMEM_THREADS_D * 4; + + float4 tmp_float4; + float frag_QO[TT][TD]; + float frag_KV[TS][TD]; + float frag_P[TT][TS]; + float frag_S[TT][TS]; + + float temperature = __frsqrt_rn((float)Ns); + float row_max; + float row_sum; + int block_row_idx; + + 
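+    {# The loop below recomputes, per nonzero block, P = Q K^T * temperature and #}
+    {# S = exp(P - M) / L from the saved ML statistics, then accumulates #}
+    {# dV += S^T dO, dP = S * (dO V^T - rowsum(dO * O)) * temperature, #}
+    {# dK += dP^T Q, and dQ += dP K; dK and dV are written back once per column. #}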
int last_col_idx = -1; + {# BCSC #} + for (int block = 0; block < block_nnz; block++) { + uint idx = block_idx[block]; + int row_idx = idx & 0xffff; + int col_idx = idx >> 16; + // if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0) + // printf("#%d: (%d, %d)\n", block, row_idx, col_idx); + + {# Load Q #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + } + if (col_idx != last_col_idx) { + if (last_col_idx >= 0) { + {# Save dK, dV #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dK[k * D + SMEM_TID_D])); + *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dV[k * D + SMEM_TID_D])); + } + } + {# Load K, V, dK, dV #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + tmp_float4 = ((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + shared_K[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; + shared_K[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; + shared_K[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; + shared_K[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; + tmp_float4 = ((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + shared_V[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; + shared_V[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; + shared_V[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; + shared_V[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; + *((float4*)(&shared_dK[k * D + SMEM_TID_D])) = *((float4*)(&dK[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_dV[k * D + SMEM_TID_D])) = *((float4*)(&dV[(col_idx * BS + k) * D + SMEM_TID_D])); + } + last_col_idx = col_idx; + } + __syncthreads(); + + {# Initialize P #} + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_P[jt][js] = 0; + } + } + + {# Calc P = Q K^T #} + #pragma unroll + for (int k = 0; k < D; k += TD) { + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_QO[jt][i] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i] * temperature; + } + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_KV[js][i] = shared_K[(k + i) * BS + threadIdx.x + blockDim.x * js]; + } + } + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int i = 0; i < TD; i++) { + frag_P[jt][js] += frag_QO[jt][i] * frag_KV[js][i]; + } + } + } + } + __syncthreads(); + + {# Load M, L #} + #pragma unroll + for (int jt = tid * 2; jt < BT; jt += {{ THREADS_PER_BLOCK * 2 }}) { + *((float4*)(&shared_ML[jt * 2])) = *((float4*)(&ML[(row_idx * BT + jt) * 2])); + } + __syncthreads(); + + {# Calc S = exp(P - M) / L #} + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + block_row_idx = (threadIdx.y + blockDim.y * jt) * 2; + row_max = shared_ML[block_row_idx]; + row_sum = shared_ML[block_row_idx + 1]; + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_S[jt][js] = expf(frag_P[jt][js] - row_max) / row_sum; + } + } + + {# Load dO #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_O[k * D + SMEM_TID_D])) = *((float4*)(&dO[(row_idx * BT + k) * D + SMEM_TID_D])); + } + + {# Calc dV = dV + S^T dO #} + #pragma unroll + for (int kk = 0, k = threadIdx.y * TD; kk < D; k = (k + TD) % D, kk += TD) 
{ + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_KV[js][i] = 0; + } + } + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + frag_QO[jt][0] = shared_O[(threadIdx.y + blockDim.y * jt) * D + k]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&frag_QO[jt][0])) = *((float2*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k])); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + *((float4*)(&frag_QO[jt][i])) = *((float4*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k + i])); + } + {% endif %} + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_KV[js][i] += frag_S[jt][js] * frag_QO[jt][i]; + } + } + } + #pragma unroll + for (int js = 0; js < TS; js++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + shared_dV[(threadIdx.x + blockDim.x * js) * D + k] += frag_KV[js][0]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + ((float2*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k]))[0] = + _add_float2( + ((float2*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k]))[0], + ((float2*)(&frag_KV[js][0]))[0] + ); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + ((float4*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k + i]))[0] = + _add_float4( + ((float4*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k + i]))[0], + ((float4*)(&frag_KV[js][i]))[0] + ); + } + {% endif %} + } + __syncthreads(); + } + + {# Initialize dS #} + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_P[jt][js] = 0; + } + } + + {# Calc dS = dO V^T #} + #pragma unroll + for (int k = 0; k < D; k += TD) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + frag_QO[jt][0] = shared_O[(threadIdx.y + blockDim.y * jt) * D + k]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&frag_QO[jt][0])) = *((float2*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k])); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + *((float4*)(&frag_QO[jt][i])) = *((float4*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k + i])); + } + {% endif %} + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_KV[js][i] = shared_V[(k + i) * BS + threadIdx.x + blockDim.x * js]; + } + } + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int i = 0; i < TD; i++) { + frag_P[jt][js] += frag_QO[jt][i] * frag_KV[js][i]; + } + } + } + } + __syncthreads(); + + {# Calc dO = dO * O #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + ((float4*)(&shared_O[k * D + SMEM_TID_D]))[0] = + _mul_float4( + ((float4*)(&shared_O[k * D + SMEM_TID_D]))[0], + ((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D]))[0] + ); + } + + {# Calc dP = S (dS - sum_j(dO)) #} + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + row_sum = 0.0f; + #pragma unroll + for (int i = 0; i < TD; i++) { + row_sum += shared_O[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i]; + } + #pragma unroll + for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { + row_sum += __shfl_xor_sync(WARP_MASK, row_sum, offset); + } + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_P[jt][js] = frag_S[jt][js] * (frag_P[jt][js] - row_sum) * temperature; + } + } + + {# Calc 
dK = dK + dP^T Q #} + #pragma unroll + for (int kk = 0, k = threadIdx.y * TD; kk < D; k = (k + TD) % D, kk += TD) { + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_KV[js][i] = 0; + } + } + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + frag_QO[jt][0] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&frag_QO[jt][0])) = + *((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k])); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + *((float4*)(&frag_QO[jt][i])) = + *((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i])); + } + {% endif %} + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_KV[js][i] += frag_P[jt][js] * frag_QO[jt][i]; + } + } + } + #pragma unroll + for (int js = 0; js < TS; js++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + shared_dK[(threadIdx.x + blockDim.x * js) * D + k] += frag_KV[js][0]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + ((float2*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k]))[0] = + _add_float2( + ((float2*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k]))[0], + ((float2*)(&frag_KV[js][0]))[0] + ); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + ((float4*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k + i]))[0] = + _add_float4( + ((float4*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k + i]))[0], + ((float4*)(&frag_KV[js][i]))[0] + ); + } + {% endif %} + } + __syncthreads(); + } + + {# Load dQ #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])); + } + __syncthreads(); + + {# Calc dQ = dQ + dP K #} + #pragma unroll + for (int kk = 0, k = threadIdx.x * TD; kk < D; k = (k + TD) % D, kk += TD) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int i = 0; i < TD; i++) { + frag_QO[jt][i] = 0; + } + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_KV[js][i] = shared_K[(k + i) * BS + threadIdx.x + blockDim.x * js]; + } + } + #pragma unroll + for (int i = 0; i < TD; i++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_QO[jt][i] += frag_P[jt][js] * frag_KV[js][i]; + } + } + } + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + shared_Q[(threadIdx.y + blockDim.y * jt) * D + k] += frag_QO[jt][0]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]))[0] = + _add_float2( + ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]))[0], + ((float2*)(&frag_QO[jt][0]))[0] + ); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]))[0] = + _add_float4( + ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]))[0], + ((float4*)(&frag_QO[jt][i]))[0] + ); + } + {% endif %} + } + __syncthreads(); + } + + {# Save dQ #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k * D + SMEM_TID_D])); + } + } + + {# Save dK, dV for the last column #} + #pragma unroll + for (int k = 
SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dK[k * D + SMEM_TID_D])); + *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dV[k * D + SMEM_TID_D])); + } +} diff --git a/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 index 4e81b830..59a88365 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 @@ -3,17 +3,54 @@ {% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE <= 32 #} -{% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * BLOCK_SIZE_T_VALUE // THREAD_SIZE_T_VALUE %} {% set THREAD_SIZE_D_VALUE = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} +{% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * BLOCK_SIZE_T_VALUE // THREAD_SIZE_T_VALUE %} + const int BS = {{ BLOCK_SIZE_S_VALUE }}; const int BT = {{ BLOCK_SIZE_T_VALUE }}; +const int D = {{ GLOBAL_SIZE_D_VALUE }}; const int TS = {{ THREAD_SIZE_S_VALUE }}; const int TT = {{ THREAD_SIZE_T_VALUE }}; -const int D = {{ GLOBAL_SIZE_D_VALUE }}; const int TD = {{ THREAD_SIZE_D_VALUE }};{# D * TS >= BS #} +__device__ __forceinline__ float2 _add_float2(float2 x, float2 y) \ +{ \ + float2 res; \ + res.x = x.x + y.x; \ + res.y = x.y + y.y; \ + return res; \ +} + +__device__ __forceinline__ float2 _scale_float2(float2 x, float y) \ +{ \ + float2 res; \ + res.x = x.x * y; \ + res.y = x.y * y; \ + return res; \ +} + +__device__ __forceinline__ float4 _add_float4(float4 x, float4 y) \ +{ \ + float4 res; \ + res.x = x.x + y.x; \ + res.y = x.y + y.y; \ + res.z = x.z + y.z; \ + res.w = x.w + y.w; \ + return res; \ +} + +__device__ __forceinline__ float4 _scale_float4(float4 x, float y) \ +{ \ + float4 res; \ + res.x = x.x * y; \ + res.y = x.y * y; \ + res.z = x.z * y; \ + res.w = x.w * y; \ + return res; \ +} + __global__ void BLOCK_SPARSE_FLASH_ATTENTION( float* Q, float* K, @@ -35,6 +72,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; + {# extern __shared__ float buffer[]; + float* shared_Q = &buffer[0]; + float* shared_K = &buffer[BT * D]; + float* shared_V = &buffer[BT * D + BS * D]; #} __shared__ float shared_Q[BT * D]; __shared__ float shared_K[BS * D]; __shared__ float shared_V[BS * D]; @@ -52,9 +93,9 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( float frag_QO[TT][TD]; float frag_KV[TS][TD]; float frag_P[TT][TS]; - float frag_S[TT][TS]; + float frag_ML[TT]; - float temperature = __frsqrt_rn((float)D); + float temperature = __frsqrt_rn((float)Ns); float row_max; float row_sum; float row_sum_new; @@ -82,7 +123,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( {# Load K #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - tmp_float4 = (reinterpret_cast(&K[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + tmp_float4 = ((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; shared_K[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; shared_K[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_K[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; @@ -146,23 +187,6 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( // printf("M = %f, L = %f\n", shared_ML[0], shared_ML[1]); // } - {# Load O #} - #pragma unroll - for (int jt = 0; jt < TT; jt++) { - {% if THREAD_SIZE_D_VALUE == 1 %} - 
frag_QO[jt] = O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x]; - {% elif THREAD_SIZE_D_VALUE == 2 %} - *((float2*)(&frag_QO[jt][0])) = - *((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * 2])); - {% else %} - #pragma unroll - for (int i = 0; i < TD; i += 4) { - *((float4*)(&frag_QO[jt][i])) = - *((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i])); - } - {% endif %} - } - #pragma unroll for (int jt = 0; jt < TT; jt++) { {# Calc M~ = max_j(P) #} @@ -221,30 +245,63 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( for (int js = 0; js < TS; js++) { frag_P[jt][js] *= seg_coef; } + frag_ML[jt] = row_coef; + } + __syncthreads(); + + {# Save M, L #} + #pragma unroll + for (int jt = tid * 2; jt < BT; jt += {{ THREADS_PER_BLOCK * 2 }}) { + *((float4*)(&ML[(row_idx * BT + jt) * 2])) = *((float4*)(&shared_ML[jt * 2])); + } + + {# Load O #} + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD] = + O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD] * frag_ML[jt]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD]))[0] = + _scale_float2( + ((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD]))[0], + frag_ML[jt] + ); + {% else %} #pragma unroll - for (int i = 0; i < TD; i++) { - frag_QO[jt][i] *= row_coef; + for (int i = 0; i < TD; i += 4) { + ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i]))[0] = + _scale_float4( + ((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i]))[0], + frag_ML[jt] + ); } + {% endif %} } __syncthreads(); {# Calc O = O' + S' V #} #pragma unroll - for (int k = 0; k < BS; k += TS) { + for (int kk = 0, k = threadIdx.x * TD; kk < D; k = (k + TD) % D, kk += TD) { #pragma unroll - for (int js = 0; js < TS; js++) { + for (int jt = 0; jt < TT; jt++) { #pragma unroll - for (int jt = 0; jt < TT; jt++) { - frag_S[jt][js] = - __shfl_sync(WARP_MASK, frag_P[jt][(k + js) / blockDim.x], (k + js) % blockDim.x + WARP_OFFSET); + for (int i = 0; i < TD; i++) { + frag_QO[jt][i] = 0; } } #pragma unroll for (int js = 0; js < TS; js++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + frag_KV[js][0] = shared_V[(threadIdx.x + blockDim.x * js) * D + k]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&frag_KV[js][0])) = *((float2*)(&shared_V[(threadIdx.x + blockDim.x * js) * D + k])); + {% else %} #pragma unroll - for (int i = 0; i < TD; i++) { - frag_KV[js][i] = shared_V[(k + js) * D + threadIdx.x * TD + i]; + for (int i = 0; i < TD; i += 4) { + *((float4*)(&frag_KV[js][i])) = *((float4*)(&shared_V[(threadIdx.x + blockDim.x * js) * D + k + i])); } + {% endif %} } #pragma unroll for (int i = 0; i < TD; i++) { @@ -252,33 +309,38 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( for (int jt = 0; jt < TT; jt++) { #pragma unroll for (int js = 0; js < TS; js++) { - frag_QO[jt][i] += frag_S[jt][js] * frag_KV[js][i]; + frag_QO[jt][i] += frag_P[jt][js] * frag_KV[js][i]; } } } - } - - {# Save O #} - #pragma unroll - for (int jt = 0; jt < TT; jt++) { - {% if THREAD_SIZE_D_VALUE == 1 %} - O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x] = frag_QO[jt]; - {% elif THREAD_SIZE_D_VALUE == 2 %} - *((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * 2])) = - *((float2*)(&frag_QO[jt][0])); - {% else %} #pragma 
unroll - for (int i = 0; i < TD; i += 4) { - *((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i])) = - *((float4*)(&frag_QO[jt][i])); + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + shared_Q[(threadIdx.y + blockDim.y * jt) * D + k] += frag_QO[jt][0]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]))[0] = + _add_float2( + ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]))[0], + ((float2*)(&frag_QO[jt][0]))[0] + ); + {% else %} + #pragma unroll + for (int i = 0; i < TD; i += 4) { + ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]))[0] = + _add_float4( + ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]))[0], + ((float4*)(&frag_QO[jt][i]))[0] + ); + } + {% endif %} } - {% endif %} + __syncthreads(); } - {# Save M, L #} + {# Save O #} #pragma unroll - for (int jt = tid * 2; jt < BT; jt += {{ THREADS_PER_BLOCK * 2 }}) { - *((float4*)(&ML[(row_idx * BT + jt) * 2])) = *((float4*)(&shared_ML[jt * 2])); + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k * D + SMEM_TID_D])); } } } diff --git a/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 b/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 index 876db9da..7f72ab2a 100644 --- a/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 +++ b/sparta/kernels/templates/sparta_sparse_softmax_backward.cuh.j2 @@ -7,7 +7,7 @@ const int block_h = {{ BLOCK_SIZE_H_VALUE }}; const int block_w = {{ BLOCK_SIZE_W_VALUE }}; const int row_tile = {{ ROW_TILE_VALUE }}; -__global__ void SPARSE_SOFTMAX( +__global__ void SPARSE_SOFTMAX_BACKWARD( float* out_grad, float* out_val, unsigned char* mask, diff --git a/sparta/testing/__init__.py b/sparta/testing/__init__.py index 0244e8ef..7c6416e7 100644 --- a/sparta/testing/__init__.py +++ b/sparta/testing/__init__.py @@ -3,4 +3,4 @@ from sparta.testing.mask import block_mask from sparta.testing.utils import check, profile -from sparta.testing.math import sparse_softmax_forward_reference, sparse_softmax_backward_reference, sparse_multi_head_attention_reference +from sparta.testing.math import sparse_softmax_forward_reference, sparse_softmax_backward_reference, sparse_multi_head_attention_forward_reference, sparse_multi_head_attention_backward_reference diff --git a/sparta/testing/math.py b/sparta/testing/math.py index 5ecd2698..4fa324cc 100644 --- a/sparta/testing/math.py +++ b/sparta/testing/math.py @@ -50,7 +50,7 @@ def sparse_softmax_backward_reference( return (C_prod - masked_output * C_sum) / temperature -def sparse_multi_head_attention_reference( +def sparse_multi_head_attention_forward_reference( query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, @@ -66,14 +66,50 @@ def sparse_multi_head_attention_reference( key (torch.Tensor): The input key tensor of shape :math:`(B, H, N_{source}, E)`. value (torch.Tensor): The input value tensor of shape :math:`(B, H, N_{source}, E)`. mask (torch.Tensor): The mask tensor of shape :math:`(N_{target}, N_{source})`. - temperature (float): The softmax temperature which is set to :math:`\sqrt{E}` by default. + temperature (float): The softmax temperature which is set to :math:`\sqrt{N_{source}}` by default. Returns: torch.Tensor: Sparse multi-head attention output of shape :math:`(B, H, N_{target}, E)`. 
""" if np.isnan(temperature): - temperature = np.sqrt(query.shape[-1]) + temperature = np.sqrt(mask.shape[-1]) high_dims = ''.join([chr(ord('a') + i) for i in range(len(query.shape) - 2)]) - qk = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) - sm = sparse_softmax_forward_reference(qk, mask, temperature) - return torch.einsum(f'{high_dims}mn, {high_dims}nk -> {high_dims}mk', sm, value) + p = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) + s = sparse_softmax_forward_reference(p, mask, temperature) + return torch.einsum(f'{high_dims}mn, {high_dims}nk -> {high_dims}mk', s, value) + + +def sparse_multi_head_attention_backward_reference( + grad: torch.Tensor, + output: torch.Tensor, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + mask: torch.Tensor, + temperature: float = np.nan, +) -> torch.Tensor: + r"""Sparse multi-head attention backward reference function. + + Args: + grad (torch.Tensor): The gradient of output tensor. Shape: :math:`(B, H, N_{target}, E)`. + output (torch.Tensor): The output tensor of forward function. Shape: :math:`(B, H, N_{target}, E)`. + query (torch.Tensor): The input query tensor of forward function. Shape: :math:`(B, H, N_{target}, E)`. + key (torch.Tensor): The input key tensor of forward function. Shape: :math:`(B, H, N_{source}, E)`. + value (torch.Tensor): The input value tensor of forward function. Shape: :math:`(B, H, N_{source}, E)`. + mask (torch.Tensor): The mask tensor. Shape :math:`(N_{target}, N_{source})`. + temperature (float): The softmax temperature which is set to :math:`\sqrt{N_{source}}` by default. + + Returns: + Tuple: The gradient of query, key and value respectively.. + """ + if np.isnan(temperature): + temperature = np.sqrt(mask.shape[-1]) + high_dims = ''.join([chr(ord('a') + i) for i in range(len(query.shape) - 2)]) + p = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) + s = sparse_softmax_forward_reference(p, mask, temperature) + grad_v = torch.einsum(f'{high_dims}mn, {high_dims}mk -> {high_dims}nk', p, grad) + grad_s = torch.einsum(f'{high_dims}nk, {high_dims}mk -> {high_dims}mn', value, grad) + grad_p = sparse_softmax_backward_reference(grad_s, s, mask, temperature) + grad_q = torch.einsum(f'{high_dims}nk, {high_dims}mn -> {high_dims}mk', key, grad_p) + grad_k = torch.einsum(f'{high_dims}mk, {high_dims}mn -> {high_dims}nk', query, grad_p) + return grad_q, grad_k, grad_v diff --git a/test/bench/attention/attention.py b/test/bench/attention/attention.py index 127967da..5cdcbd0d 100644 --- a/test/bench/attention/attention.py +++ b/test/bench/attention/attention.py @@ -12,7 +12,7 @@ import matplotlib.pyplot as plt from sparta.nn import SparseAttention -from sparta.testing import block_mask, profile, sparse_multi_head_attention_reference +from sparta.testing import block_mask, profile, sparse_multi_head_attention_forward_reference Ns, Nt, E = 4096, 3072, 768 @@ -88,7 +88,7 @@ def prepare_data( data['key'].requires_grad = True data['value'].requires_grad = True inputs = [data['query'], data['key'], data['value']] - data['out'] = sparse_multi_head_attention_reference(*inputs, mask) + data['out'] = sparse_multi_head_attention_forward_reference(*inputs, mask) data['out'].backward(data['grad_out']) data['grad_query'] = data['query'].grad data['grad_key'] = data['key'].grad @@ -182,7 +182,7 @@ def profile_dense_attention( return 0., 0. 
def dense_attention(query, key, value): - return sparse_multi_head_attention_reference(query, key, value, mask) + return sparse_multi_head_attention_forward_reference(query, key, value, mask) return profile_attention(dense_attention, data) diff --git a/test/unit/test_seqlen_attention.py b/test/unit/test_seqlen_attention.py index c590072e..8c8f6080 100644 --- a/test/unit/test_seqlen_attention.py +++ b/test/unit/test_seqlen_attention.py @@ -5,7 +5,7 @@ import pytest from sparta.nn import SeqlenDynamicSparseAttention -from sparta.testing import sparse_multi_head_attention_reference +from sparta.testing import sparse_multi_head_attention_forward_reference def random_seqlens(B: int, N: int): @@ -38,7 +38,7 @@ def test_seqlen_attention_operator(B: int, H: int, N: int, E: int, global_mode: key = torch.rand(size=(B, H, N, E), dtype=torch.float32, device='cuda') value = torch.rand(size=(B, H, N, E), dtype=torch.float32, device='cuda') - target_out = sparse_multi_head_attention_reference( + target_out = sparse_multi_head_attention_forward_reference( query=query.reshape((-1, N, E)), key=key.reshape((-1, N, E)), value=value.reshape((-1, N, E)), diff --git a/test/unit/test_sparse_attention.py b/test/unit/test_sparse_attention.py index d8e16f51..f7b66265 100644 --- a/test/unit/test_sparse_attention.py +++ b/test/unit/test_sparse_attention.py @@ -7,7 +7,7 @@ import pytest from sparta.nn import SparseAttention -from sparta.testing import block_mask, sparse_multi_head_attention_reference +from sparta.testing import block_mask, sparse_multi_head_attention_forward_reference def get_params(): @@ -72,7 +72,7 @@ def test_sparse_attention_operator( for random_seed in range(3): # Test dynamic sparse query.grad, key.grad, value.grad = None, None, None - target_out = sparse_multi_head_attention_reference(query, key, value, mask) + target_out = sparse_multi_head_attention_forward_reference(query, key, value, mask) target_out.backward(grad_out) target_grad_query = query.grad From cba098b5a62bf8f442aa6f6aebeecc4fcb178dc8 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Thu, 27 Apr 2023 19:32:44 +0900 Subject: [PATCH 19/28] update FlashSparseAttentionBackwardKernel --- .../flash_sparse_attention_backward.cuh.j2 | 72 +++++++++---------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 index 663ff7ba..5de9f152 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 @@ -126,8 +126,16 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Save dK, dV #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dK[k * D + SMEM_TID_D])); - *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dV[k * D + SMEM_TID_D])); + tmp_float4.x = shared_dK[(SMEM_TID_D+0) * BS + k]; + tmp_float4.y = shared_dK[(SMEM_TID_D+1) * BS + k]; + tmp_float4.z = shared_dK[(SMEM_TID_D+2) * BS + k]; + tmp_float4.w = shared_dK[(SMEM_TID_D+3) * BS + k]; + ((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = tmp_float4; + tmp_float4.x = shared_dV[(SMEM_TID_D+0) * BS + k]; + tmp_float4.y = shared_dV[(SMEM_TID_D+1) * BS + k]; + tmp_float4.z = shared_dV[(SMEM_TID_D+2) * BS + k]; + tmp_float4.w = shared_dV[(SMEM_TID_D+3) * BS + k]; + ((float4*)(&dV[(last_col_idx * BS + k) * D + 
SMEM_TID_D]))[0] = tmp_float4; } } {# Load K, V, dK, dV #} @@ -143,8 +151,16 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( shared_V[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_V[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; shared_V[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; - *((float4*)(&shared_dK[k * D + SMEM_TID_D])) = *((float4*)(&dK[(col_idx * BS + k) * D + SMEM_TID_D])); - *((float4*)(&shared_dV[k * D + SMEM_TID_D])) = *((float4*)(&dV[(col_idx * BS + k) * D + SMEM_TID_D])); + tmp_float4 = ((float4*)(&dK[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + shared_dK[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; + shared_dK[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; + shared_dK[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; + shared_dK[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; + tmp_float4 = ((float4*)(&dV[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + shared_dV[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; + shared_dV[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; + shared_dV[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; + shared_dV[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; } last_col_idx = col_idx; } @@ -249,24 +265,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( } #pragma unroll for (int js = 0; js < TS; js++) { - {% if THREAD_SIZE_D_VALUE == 1 %} - shared_dV[(threadIdx.x + blockDim.x * js) * D + k] += frag_KV[js][0]; - {% elif THREAD_SIZE_D_VALUE == 2 %} - ((float2*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k]))[0] = - _add_float2( - ((float2*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k]))[0], - ((float2*)(&frag_KV[js][0]))[0] - ); - {% else %} #pragma unroll - for (int i = 0; i < TD; i += 4) { - ((float4*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k + i]))[0] = - _add_float4( - ((float4*)(&shared_dV[(threadIdx.x + blockDim.x * js) * D + k + i]))[0], - ((float4*)(&frag_KV[js][i]))[0] - ); + for (int i = 0; i < TD; i++) { + shared_dV[(k + i) * BS + threadIdx.x + blockDim.x * js] += frag_KV[js][i]; } - {% endif %} } __syncthreads(); } @@ -381,24 +383,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( } #pragma unroll for (int js = 0; js < TS; js++) { - {% if THREAD_SIZE_D_VALUE == 1 %} - shared_dK[(threadIdx.x + blockDim.x * js) * D + k] += frag_KV[js][0]; - {% elif THREAD_SIZE_D_VALUE == 2 %} - ((float2*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k]))[0] = - _add_float2( - ((float2*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k]))[0], - ((float2*)(&frag_KV[js][0]))[0] - ); - {% else %} #pragma unroll - for (int i = 0; i < TD; i += 4) { - ((float4*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k + i]))[0] = - _add_float4( - ((float4*)(&shared_dK[(threadIdx.x + blockDim.x * js) * D + k + i]))[0], - ((float4*)(&frag_KV[js][i]))[0] - ); + for (int i = 0; i < TD; i++) { + shared_dK[(k + i) * BS + threadIdx.x + blockDim.x * js] += frag_KV[js][i]; } - {% endif %} } __syncthreads(); } @@ -471,7 +459,15 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Save dK, dV for the last column #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dK[k * D + SMEM_TID_D])); - *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = *((float4*)(&shared_dV[k * D + SMEM_TID_D])); + tmp_float4.x = shared_dK[(SMEM_TID_D+0) * BS + k]; + tmp_float4.y = shared_dK[(SMEM_TID_D+1) * BS + k]; + tmp_float4.z = shared_dK[(SMEM_TID_D+2) * BS + k]; + tmp_float4.w = shared_dK[(SMEM_TID_D+3) * BS + k]; + ((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = 
tmp_float4; + tmp_float4.x = shared_dV[(SMEM_TID_D+0) * BS + k]; + tmp_float4.y = shared_dV[(SMEM_TID_D+1) * BS + k]; + tmp_float4.z = shared_dV[(SMEM_TID_D+2) * BS + k]; + tmp_float4.w = shared_dV[(SMEM_TID_D+3) * BS + k]; + ((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = tmp_float4; } } From 6f8b721f8914ca89e11c17998539417d0a47aa70 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Thu, 4 May 2023 13:34:55 +0900 Subject: [PATCH 20/28] update FlashSparseAttentionBackwardKernel --- .../flash_sparse_attention_backward.cuh.j2 | 180 +++++++++++------- 1 file changed, 108 insertions(+), 72 deletions(-) diff --git a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 index 5de9f152..7568cfe2 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 @@ -98,7 +98,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( float4 tmp_float4; float frag_QO[TT][TD]; - float frag_KV[TS][TD]; + float frag_KV[TD][TS]; float frag_P[TT][TS]; float frag_S[TT][TS]; @@ -179,18 +179,30 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( #pragma unroll for (int k = 0; k < D; k += TD) { #pragma unroll - for (int i = 0; i < TD; i++) { + for (int jt = 0; jt < TT; jt++) { + {% if THREAD_SIZE_D_VALUE == 1 %} + frag_QO[jt][0] = shared_Q[(threadIdx.y * TT + jt) * D + k]; + {% elif THREAD_SIZE_D_VALUE == 2 %} + *((float2*)(&frag_QO[jt][0])) = *((float2*)(&shared_Q[(threadIdx.y * TT + jt) * D + k])); + {% else %} #pragma unroll - for (int jt = 0; jt < TT; jt++) { - frag_QO[jt][i] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i] * temperature; + for (int i = 0; i < TD; i += 4) { + *((float4*)(&frag_QO[jt][i])) = *((float4*)(&shared_Q[(threadIdx.y * TT + jt) * D + k + i])); } + {% endif %} } #pragma unroll for (int i = 0; i < TD; i++) { + {% if THREAD_SIZE_S_VALUE == 1 %} + frag_KV[i][0] = shared_K[(k + i) * BS + threadIdx.x * TS]; + {% elif THREAD_SIZE_S_VALUE == 2 %} + *((float2*)(&frag_KV[i][0])) = *((float2*)(&shared_K[(k + i) * BS + threadIdx.x * TS])); + {% else %} #pragma unroll - for (int js = 0; js < TS; js++) { - frag_KV[js][i] = shared_K[(k + i) * BS + threadIdx.x + blockDim.x * js]; + for (int js = 0; js < TS; js += 4) { + *((float4*)(&frag_KV[i][js])) = *((float4*)(&shared_K[(k + i) * BS + threadIdx.x * TS + js])); } + {% endif %} } #pragma unroll for (int js = 0; js < TS; js++) { @@ -198,7 +210,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int jt = 0; jt < TT; jt++) { #pragma unroll for (int i = 0; i < TD; i++) { - frag_P[jt][js] += frag_QO[jt][i] * frag_KV[js][i]; + frag_P[jt][js] += frag_QO[jt][i] * frag_KV[i][js]; } } } @@ -215,12 +227,12 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Calc S = exp(P - M) / L #} #pragma unroll for (int jt = 0; jt < TT; jt++) { - block_row_idx = (threadIdx.y + blockDim.y * jt) * 2; + block_row_idx = (threadIdx.y * TT + jt) * 2; row_max = shared_ML[block_row_idx]; row_sum = shared_ML[block_row_idx + 1]; #pragma unroll for (int js = 0; js < TS; js++) { - frag_S[jt][js] = expf(frag_P[jt][js] - row_max) / row_sum; + frag_S[jt][js] = expf(frag_P[jt][js] * temperature - row_max) / row_sum; } } @@ -229,27 +241,37 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { *((float4*)(&shared_O[k * D + SMEM_TID_D])) = *((float4*)(&dO[(row_idx * BT + k) * D + SMEM_TID_D])); } + __syncthreads(); + + 
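{# Editor note (my reading of this revision, not part of the original template): frag_KV
   is re-indexed as [TD][TS] so that the js loop walks consecutive columns of the
   BS-major shared tiles, letting the accumulations below collapse into _add_float2 /
   _add_float4 vector stores; the dV += S^T dO update and the dS = dO V^T product are
   also fused into one k-loop, so shared_O (holding dO) is staged only once. #}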
{# Initialize dS #} + #pragma unroll + for (int js = 0; js < TS; js++) { + #pragma unroll + for (int jt = 0; jt < TT; jt++) { + frag_P[jt][js] = 0; + } + } - {# Calc dV = dV + S^T dO #} + {# Calc dV = dV + S^T dO, dS = dO V^T #} #pragma unroll for (int kk = 0, k = threadIdx.y * TD; kk < D; k = (k + TD) % D, kk += TD) { #pragma unroll for (int i = 0; i < TD; i++) { #pragma unroll for (int js = 0; js < TS; js++) { - frag_KV[js][i] = 0; + frag_KV[i][js] = 0; } } #pragma unroll for (int jt = 0; jt < TT; jt++) { {% if THREAD_SIZE_D_VALUE == 1 %} - frag_QO[jt][0] = shared_O[(threadIdx.y + blockDim.y * jt) * D + k]; + frag_QO[jt][0] = shared_O[(threadIdx.y * TT + jt) * D + k]; {% elif THREAD_SIZE_D_VALUE == 2 %} - *((float2*)(&frag_QO[jt][0])) = *((float2*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k])); + *((float2*)(&frag_QO[jt][0])) = *((float2*)(&shared_O[(threadIdx.y * TT + jt) * D + k])); {% else %} #pragma unroll for (int i = 0; i < TD; i += 4) { - *((float4*)(&frag_QO[jt][i])) = *((float4*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k + i])); + *((float4*)(&frag_QO[jt][i])) = *((float4*)(&shared_O[(threadIdx.y * TT + jt) * D + k + i])); } {% endif %} } @@ -259,51 +281,44 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int js = 0; js < TS; js++) { #pragma unroll for (int jt = 0; jt < TT; jt++) { - frag_KV[js][i] += frag_S[jt][js] * frag_QO[jt][i]; + frag_KV[i][js] += frag_S[jt][js] * frag_QO[jt][i]; } } } #pragma unroll - for (int js = 0; js < TS; js++) { - #pragma unroll - for (int i = 0; i < TD; i++) { - shared_dV[(k + i) * BS + threadIdx.x + blockDim.x * js] += frag_KV[js][i]; - } - } - __syncthreads(); - } - - {# Initialize dS #} - #pragma unroll - for (int js = 0; js < TS; js++) { - #pragma unroll - for (int jt = 0; jt < TT; jt++) { - frag_P[jt][js] = 0; - } - } - - {# Calc dS = dO V^T #} - #pragma unroll - for (int k = 0; k < D; k += TD) { - #pragma unroll - for (int jt = 0; jt < TT; jt++) { - {% if THREAD_SIZE_D_VALUE == 1 %} - frag_QO[jt][0] = shared_O[(threadIdx.y + blockDim.y * jt) * D + k]; - {% elif THREAD_SIZE_D_VALUE == 2 %} - *((float2*)(&frag_QO[jt][0])) = *((float2*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k])); + for (int i = 0; i < TD; i++) { + {% if THREAD_SIZE_S_VALUE == 1 %} + shared_dV[(k + i) * BS + threadIdx.x * TS] += frag_KV[i][0]; + {% elif THREAD_SIZE_S_VALUE == 2 %} + ((float2*)(&shared_dV[(k + i) * BS + threadIdx.x * TS]))[0] = + _add_float2( + ((float2*)(&shared_dV[(k + i) * BS + threadIdx.x * TS]))[0], + ((float2*)(&frag_KV[i][0]))[0] + ); {% else %} #pragma unroll - for (int i = 0; i < TD; i += 4) { - *((float4*)(&frag_QO[jt][i])) = *((float4*)(&shared_O[(threadIdx.y + blockDim.y * jt) * D + k + i])); + for (int js = 0; js < TS; js += 4) { + ((float4*)(&shared_dV[(k + i) * BS + threadIdx.x * TS + js]))[0] = + _add_float4( + ((float4*)(&shared_dV[(k + i) * BS + threadIdx.x * TS + js]))[0], + ((float4*)(&frag_KV[i][js]))[0] + ); } {% endif %} } + __syncthreads(); #pragma unroll for (int i = 0; i < TD; i++) { + {% if THREAD_SIZE_S_VALUE == 1 %} + frag_KV[i][0] = shared_V[(k + i) * BS + threadIdx.x * TS]; + {% elif THREAD_SIZE_S_VALUE == 2 %} + *((float2*)(&frag_KV[i][0])) = *((float2*)(&shared_V[(k + i) * BS + threadIdx.x * TS])); + {% else %} #pragma unroll - for (int js = 0; js < TS; js++) { - frag_KV[js][i] = shared_V[(k + i) * BS + threadIdx.x + blockDim.x * js]; + for (int js = 0; js < TS; js += 4) { + *((float4*)(&frag_KV[i][js])) = *((float4*)(&shared_V[(k + i) * BS + threadIdx.x * TS + js])); } + {% endif %} } #pragma 
unroll for (int js = 0; js < TS; js++) { @@ -311,7 +326,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int jt = 0; jt < TT; jt++) { #pragma unroll for (int i = 0; i < TD; i++) { - frag_P[jt][js] += frag_QO[jt][i] * frag_KV[js][i]; + frag_P[jt][js] += frag_QO[jt][i] * frag_KV[i][js]; } } } @@ -327,6 +342,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( ((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D]))[0] ); } + __syncthreads(); {# Calc dP = S (dS - sum_j(dO)) #} #pragma unroll @@ -334,7 +350,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( row_sum = 0.0f; #pragma unroll for (int i = 0; i < TD; i++) { - row_sum += shared_O[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i]; + row_sum += shared_O[(threadIdx.y * TT + jt) * D + threadIdx.x * TD + i]; } #pragma unroll for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { @@ -346,6 +362,13 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( } } + {# Load dQ #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_O[k * D + SMEM_TID_D])) = *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])); + } + __syncthreads(); + {# Calc dK = dK + dP^T Q #} #pragma unroll for (int kk = 0, k = threadIdx.y * TD; kk < D; k = (k + TD) % D, kk += TD) { @@ -353,21 +376,21 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int i = 0; i < TD; i++) { #pragma unroll for (int js = 0; js < TS; js++) { - frag_KV[js][i] = 0; + frag_KV[i][js] = 0; } } #pragma unroll for (int jt = 0; jt < TT; jt++) { {% if THREAD_SIZE_D_VALUE == 1 %} - frag_QO[jt][0] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]; + frag_QO[jt][0] = shared_Q[(threadIdx.y * TT + jt) * D + k]; {% elif THREAD_SIZE_D_VALUE == 2 %} *((float2*)(&frag_QO[jt][0])) = - *((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k])); + *((float2*)(&shared_Q[(threadIdx.y * TT + jt) * D + k])); {% else %} #pragma unroll for (int i = 0; i < TD; i += 4) { *((float4*)(&frag_QO[jt][i])) = - *((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i])); + *((float4*)(&shared_Q[(threadIdx.y * TT + jt) * D + k + i])); } {% endif %} } @@ -377,27 +400,34 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int js = 0; js < TS; js++) { #pragma unroll for (int jt = 0; jt < TT; jt++) { - frag_KV[js][i] += frag_P[jt][js] * frag_QO[jt][i]; + frag_KV[i][js] += frag_P[jt][js] * frag_QO[jt][i]; } } } #pragma unroll - for (int js = 0; js < TS; js++) { + for (int i = 0; i < TD; i++) { + {% if THREAD_SIZE_S_VALUE == 1 %} + shared_dK[(k + i) * BS + threadIdx.x * TS] += frag_KV[i][0]; + {% elif THREAD_SIZE_S_VALUE == 2 %} + ((float2*)(&shared_dK[(k + i) * BS + threadIdx.x * TS]))[0] = + _add_float2( + ((float2*)(&shared_dK[(k + i) * BS + threadIdx.x * TS]))[0], + ((float2*)(&frag_KV[i][0]))[0] + ); + {% else %} #pragma unroll - for (int i = 0; i < TD; i++) { - shared_dK[(k + i) * BS + threadIdx.x + blockDim.x * js] += frag_KV[js][i]; + for (int js = 0; js < TS; js += 4) { + ((float4*)(&shared_dK[(k + i) * BS + threadIdx.x * TS + js]))[0] = + _add_float4( + ((float4*)(&shared_dK[(k + i) * BS + threadIdx.x * TS + js]))[0], + ((float4*)(&frag_KV[i][js]))[0] + ); } + {% endif %} } __syncthreads(); } - {# Load dQ #} - #pragma unroll - for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])); - } - __syncthreads(); - {# Calc dQ = dQ + dP K #} #pragma unroll for (int kk = 0, k = threadIdx.x * 
TD; kk < D; k = (k + TD) % D, kk += TD) { @@ -410,10 +440,16 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( } #pragma unroll for (int i = 0; i < TD; i++) { + {% if THREAD_SIZE_S_VALUE == 1 %} + frag_KV[i][0] = shared_K[(k + i) * BS + threadIdx.x * TS]; + {% elif THREAD_SIZE_S_VALUE == 2 %} + *((float2*)(&frag_KV[i][0])) = *((float2*)(&shared_K[(k + i) * BS + threadIdx.x * TS])); + {% else %} #pragma unroll - for (int js = 0; js < TS; js++) { - frag_KV[js][i] = shared_K[(k + i) * BS + threadIdx.x + blockDim.x * js]; + for (int js = 0; js < TS; js += 4) { + *((float4*)(&frag_KV[i][js])) = *((float4*)(&shared_K[(k + i) * BS + threadIdx.x * TS + js])); } + {% endif %} } #pragma unroll for (int i = 0; i < TD; i++) { @@ -421,26 +457,26 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int jt = 0; jt < TT; jt++) { #pragma unroll for (int js = 0; js < TS; js++) { - frag_QO[jt][i] += frag_P[jt][js] * frag_KV[js][i]; + frag_QO[jt][i] += frag_P[jt][js] * frag_KV[i][js]; } } } #pragma unroll for (int jt = 0; jt < TT; jt++) { {% if THREAD_SIZE_D_VALUE == 1 %} - shared_Q[(threadIdx.y + blockDim.y * jt) * D + k] += frag_QO[jt][0]; + shared_O[(threadIdx.y * TT + jt) * D + k] += frag_QO[jt][0]; {% elif THREAD_SIZE_D_VALUE == 2 %} - ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]))[0] = + ((float2*)(&shared_O[(threadIdx.y * TT + jt) * D + k]))[0] = _add_float2( - ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k]))[0], + ((float2*)(&shared_O[(threadIdx.y * TT + jt) * D + k]))[0], ((float2*)(&frag_QO[jt][0]))[0] ); {% else %} #pragma unroll for (int i = 0; i < TD; i += 4) { - ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]))[0] = + ((float4*)(&shared_O[(threadIdx.y * TT + jt) * D + k + i]))[0] = _add_float4( - ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]))[0], + ((float4*)(&shared_O[(threadIdx.y * TT + jt) * D + k + i]))[0], ((float4*)(&frag_QO[jt][i]))[0] ); } @@ -452,7 +488,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Save dQ #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k * D + SMEM_TID_D])); + *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_O[k * D + SMEM_TID_D])); } } From 7aaf4c5bdc37605dd76d1880cf2c65b73f784f6f Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 9 May 2023 18:28:31 +0900 Subject: [PATCH 21/28] Flash Attention fp16 forward version 1 --- ..._sparse_attention_backward_float32.cuh.j2} | 23 +- ...sh_sparse_attention_forward_float16.cuh.j2 | 283 ++++++++++++++++++ ...h_sparse_attention_forward_float32.cuh.j2} | 39 +-- 3 files changed, 314 insertions(+), 31 deletions(-) rename sparta/kernels/templates/{flash_sparse_attention_backward.cuh.j2 => flash_sparse_attention_backward_float32.cuh.j2} (97%) create mode 100644 sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 rename sparta/kernels/templates/{flash_sparse_attention_forward.cuh.j2 => flash_sparse_attention_forward_float32.cuh.j2} (93%) diff --git a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 similarity index 97% rename from sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 rename to sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 index 7568cfe2..e924653a 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward.cuh.j2 +++ 
b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 @@ -1,20 +1,20 @@ {# Copyright (c) Microsoft Corporation. #} {# Licensed under the MIT license. #} -{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE <= 32 #} - -{% set THREAD_SIZE_D_VALUE = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} -{% set THREAD_SIZE_T_VALUE = BLOCK_SIZE_T_VALUE // WARP_REDUCE_SIZE %} - -{% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * WARP_REDUCE_SIZE %} +{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE = Bs / Ts <= 32 #} +{% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * BLOCK_SIZE_T_VALUE // THREAD_SIZE_T_VALUE %} +{% set THREAD_SIZE_S_TO_D = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} const int BS = {{ BLOCK_SIZE_S_VALUE }}; const int BT = {{ BLOCK_SIZE_T_VALUE }}; const int D = {{ GLOBAL_SIZE_D_VALUE }}; const int TS = {{ THREAD_SIZE_S_VALUE }}; - const int TT = {{ THREAD_SIZE_T_VALUE }}; -const int TD = {{ THREAD_SIZE_D_VALUE }};{# D * TS >= BS #} +const int TD = {{ THREAD_SIZE_D_VALUE }};{# D / Td >= Bs / Ts, D / Td >= Bt / Tt #} +const int SD = {{ THREAD_SIZE_S_TO_D }}; + +const int SMEM_THREADS_D = D / 4; +const int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; __device__ __forceinline__ float2 _add_float2(float2 x, float2 y) \ { \ @@ -89,9 +89,6 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# __shared__ float shared_ML[BT * 2]; #} float* shared_ML = shared_O; - int SMEM_THREADS_D = D / 4; - int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; - int tid = threadIdx.y * blockDim.x + threadIdx.x; int SMEM_TID_N = tid / SMEM_THREADS_D; int SMEM_TID_D = tid % SMEM_THREADS_D * 4; @@ -349,8 +346,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( for (int jt = 0; jt < TT; jt++) { row_sum = 0.0f; #pragma unroll - for (int i = 0; i < TD; i++) { - row_sum += shared_O[(threadIdx.y * TT + jt) * D + threadIdx.x * TD + i]; + for (int i = 0; i < SD; i++) { + row_sum += shared_O[(threadIdx.y * TT + jt) * D + threadIdx.x * SD + i]; } #pragma unroll for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { diff --git a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 new file mode 100644 index 00000000..b91b027d --- /dev/null +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 @@ -0,0 +1,283 @@ +{# Copyright (c) Microsoft Corporation. #} +{# Licensed under the MIT license. 
#} + +#include +#include + +using namespace nvcuda; + +{% set WARP_SIZE = 32 %} +{% set FRAG_SIZE = 256 %} +{% set QK_WARP_SIZE_N_VALUE = FRAG_SIZE // QK_WARP_SIZE_M_VALUE %} +{% set SV_WARP_SIZE_N_VALUE = FRAG_SIZE // SV_WARP_SIZE_M_VALUE %} +{% set BLOCK_SIZE = BLOCK_SIZE_T_VALUE * BLOCK_SIZE_S_VALUE %} +{% set THREAD_SIZE = BLOCK_SIZE // THREADS_PER_BLOCK %}{# 8 <= THREAD_SIZE <= BLOCK_SIZE_S_VALUE #} +{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE %}{# WARP_REDUCE_SIZE <= WARP_SIZE #} + +const int BS = {{ BLOCK_SIZE_S_VALUE }}; +const int BT = {{ BLOCK_SIZE_T_VALUE }}; +const int D = {{ GLOBAL_SIZE_D_VALUE }}; +const int QK_WARP_M = {{ QK_WARP_SIZE_M_VALUE }}; +const int QK_WARP_N = {{ QK_WARP_SIZE_N_VALUE }}; +const int QK_WARP_K = 16; +const int SV_WARP_M = {{ SV_WARP_SIZE_M_VALUE }}; +const int SV_WARP_N = {{ SV_WARP_SIZE_N_VALUE }}; +const int SV_WARP_K = 16; + +const int B = {{ BLOCK_SIZE }}; +const int T = {{ THREAD_SIZE }}; +const int THREADS = {{ THREADS_PER_BLOCK }};{# THREADS_PER_BLOCK >= WARP_SIZE #} +const int WARPS = THREADS / {{ WARP_SIZE }}; +const int SD = T * D / BS; + +const int SMEM_THREADS_D = D / 8; +const int SMEM_THREADS_N = THREADS / SMEM_THREADS_D; +const int QK_WARPS_N = BS / QK_WARP_N; +const int QK_STRIDE_M = QK_WARP_M * (WARPS / QK_WARPS_N); +const int SV_WARPS_N = D / SV_WARP_N; +const int SV_STRIDE_M = SV_WARP_M * (WARPS / SV_WARPS_N); + +const int D_PAD = 8; +const int S_PAD = 8; + +__device__ __forceinline__ half max(half x, half y) \ +{ \ + return x > y ? x : y; \ +} + +extern "C" { + +__global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( + half* Q, + half* K, + half* V, + half* O, + half* ML, + {# unsigned char* mask, #} + uint* block_idx, + uint Ns, + uint Nt, + uint block_nnz +) { + Q += Nt * D * blockIdx.x; + K += Ns * D * blockIdx.x; + V += Ns * D * blockIdx.x; + O += Nt * D * blockIdx.x; + ML += Nt * 2 * blockIdx.x; + + uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; + uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; + + __shared__ half shared_Q[BT][D + D_PAD]; + __shared__ half shared_P[BT][BS + S_PAD]; + __shared__ half shared_K[BS][D + D_PAD]; + __shared__ half shared_V[BS][D + D_PAD]; + + int SMEM_TID_N = threadIdx.x / SMEM_THREADS_D; + int SMEM_TID_D = threadIdx.x % SMEM_THREADS_D * 8; + + int wid = threadIdx.x / {{ WARP_SIZE }}; + int qk_wx = wid % QK_WARPS_N; + int qk_wy = wid / QK_WARPS_N; + int sv_wx = wid % SV_WARPS_N; + int sv_wy = wid / SV_WARPS_N; + int tx = threadIdx.x % {{ WARP_REDUCE_SIZE }}; + int ty = threadIdx.x / {{ WARP_REDUCE_SIZE }}; + + wmma::fragment frag_Q; + wmma::fragment frag_K; + wmma::fragment frag_P; + wmma::fragment frag_S; + wmma::fragment frag_V; + wmma::fragment frag_O; + float4 tmp_float4; + half frag[T]; + + float temperature = __frsqrt_rn((float)Ns); + half row_max; + half row_sum; + half row_sum_new; + half seg_max; + half seg_sum; + half row_coef; + half seg_coef; + int block_row_idx; + + int last_col_idx = -1; + {# BCSC #} + for (int block = 0; block < block_nnz; block++) { + uint idx = block_idx[block]; + int row_idx = idx & 0xffff; + int col_idx = idx >> 16; + // if (blockIdx.x == 0 && threadIdx.x == 0) + // printf("#%d: (%d, %d)\n", block, row_idx, col_idx); + + {# Load Q #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_Q[k][SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + } + if (col_idx != last_col_idx) { + {# Load K #} + 
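{# Editor note on two fp16-specific conventions in the loads below (my reading): a
   single float4 transaction moves 8 half values, hence SMEM_THREADS_D = D / 8 here
   versus D / 4 in the fp32 templates; and the D_PAD / S_PAD = 8 padding on the shared
   tiles is the usual half-precision guard against shared-memory bank conflicts — the
   series' later "pad (bank conflict)" commit message points at the same issue. #}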
#pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&shared_K[k][SMEM_TID_D])) = *((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D])); + } + {# Load V #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&shared_V[k][SMEM_TID_D])) = *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + } + last_col_idx = col_idx; + } + __syncthreads(); + + {# Calc P = Q K^T #} + #pragma unroll + for (int jt = 0; jt < BT; jt += QK_STRIDE_M) { + wmma::fill_fragment(frag_P, 0.0); + #pragma unroll + for (int k = 0; k < D; k += QK_WARP_K) { + wmma::load_matrix_sync(frag_Q, &shared_Q[jt + qk_wy * QK_WARP_M][k], D + D_PAD); + wmma::load_matrix_sync(frag_K, &shared_K[qk_wx * QK_WARP_N][k], D + D_PAD); + wmma::mma_sync(frag_P, frag_Q, frag_K, frag_P); + } + for(int i = 0; i < {{ FRAG_SIZE }}; i++) { + frag_P.x[i] *= temperature; + } + wmma::store_matrix_sync( + &shared_P[jt + qk_wy * QK_WARP_M][qk_wx * QK_WARP_N], frag_P, BS + S_PAD, wmma::mem_row_major); + } + __syncthreads(); + /*if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { + printf("P[0][0] = %f\n", (float)(shared_P[0][0])); + printf("P[0][1] = %f\n", (float)(shared_P[0][1])); + printf("P[1][0] = %f\n", (float)(shared_P[1][0])); + printf("P[1][1] = %f\n", (float)(shared_P[1][1])); + }*/ + + {# Load M, L, P #} + #pragma unroll + for (int jt = threadIdx.x * 4; jt < BT; jt += {{ THREADS_PER_BLOCK * 4 }}) { + tmp_float4 = ((float4*)(&ML[(row_idx * BT + jt) * 2]))[0]; + ((float*)(&shared_Q[jt + 0][0]))[0] = tmp_float4.x; + ((float*)(&shared_Q[jt + 1][0]))[0] = tmp_float4.y; + ((float*)(&shared_Q[jt + 2][0]))[0] = tmp_float4.z; + ((float*)(&shared_Q[jt + 3][0]))[0] = tmp_float4.w; + } + __syncthreads(); + #pragma unroll + for (int i = 0; i < T; i += 8) { + *((float4*)(&frag[i])) = *((float4*)(&shared_P[ty][tx * T + i])); + } + + {# Calc M~ = max_j(P) #} + seg_max = (half)(-1000.0); + #pragma unroll + for (int i = 0; i < T; i++) { + seg_max = max(seg_max, frag[i]); + } + #pragma unroll + for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { + seg_max = max(seg_max, __shfl_xor_sync(WARP_MASK, seg_max, offset)); + } + {# Calc S = exp(P - M~) #} + #pragma unroll + for (int i = 0; i < T; i++) { + frag[i] = hexp(frag[i] - seg_max); + } + {# Calc L~ = sum_j(P) #} + seg_sum = (half)(0.0f); + #pragma unroll + for (int i = 0; i < T; i++) { + seg_sum += frag[i]; + } + #pragma unroll + for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { + seg_sum += __shfl_down_sync(WARP_MASK, seg_sum, offset); + } + {# Calc M' = max(M, M~), L' = exp(M - M') * L + exp(M~ - M') * L~ #} + if (tx == 0) { + row_max = shared_Q[ty][0]; + row_sum = shared_Q[ty][1]; + if (row_max < seg_max) { + shared_Q[ty][0] = seg_max; + row_coef = hexp(row_max - seg_max); + row_sum_new = row_coef * row_sum + seg_sum; + row_coef *= row_sum / row_sum_new; + seg_coef = (half)(1.0f) / row_sum_new; + } else { + seg_coef = hexp(seg_max - row_max); + row_sum_new = row_sum + seg_coef * seg_sum; + row_coef = row_sum / row_sum_new; + seg_coef /= row_sum_new; + } + shared_Q[ty][1] = row_sum_new; + } + row_coef = __shfl_sync(WARP_MASK, row_coef, WARP_OFFSET); + seg_coef = __shfl_sync(WARP_MASK, seg_coef, WARP_OFFSET); + {# Calc O' = L / L' * exp(M - M') * O, S' = exp(M~ - M') / L' * S #} + #pragma unroll + for (int i = 0; i < T; i++) { + frag[i] *= seg_coef; + } + __syncthreads(); + + {# Save M, L, S #} + #pragma unroll + for (int jt = threadIdx.x * 4; jt < BT; jt += {{ 
THREADS_PER_BLOCK * 4 }}) { + tmp_float4.x = ((float*)(&shared_Q[jt + 0][0]))[0]; + tmp_float4.y = ((float*)(&shared_Q[jt + 1][0]))[0]; + tmp_float4.z = ((float*)(&shared_Q[jt + 2][0]))[0]; + tmp_float4.w = ((float*)(&shared_Q[jt + 3][0]))[0]; + ((float4*)(&ML[(row_idx * BT + jt) * 2]))[0] = tmp_float4; + } + #pragma unroll + for (int i = 0; i < T; i += 8) { + *((float4*)(&shared_P[ty][tx * T + i])) = *((float4*)(&frag[i])); + } + __syncthreads(); + /*if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { + printf("S[0][0] = %f\n", (float)(shared_P[0][0])); + printf("S[0][1] = %f\n", (float)(shared_P[0][1])); + printf("S[1][0] = %f\n", (float)(shared_P[1][0])); + printf("S[1][1] = %f\n", (float)(shared_P[1][1])); + }*/ + + {# Load O #} + #pragma unroll + for (int i = 0; i < SD; i += 8) { + *((float4*)(&frag[0])) = *((float4*)(&O[(row_idx * BT + ty) * D + tx * SD + i])); + #pragma unroll + for (int j = 0; j < 8; j++) { + frag[j] *= row_coef; + } + *((float4*)(&shared_Q[ty][tx * SD + i])) = *((float4*)(&frag[0])); + } + __syncthreads(); + + {# Calc O = O' + S' V #} + #pragma unroll + for (int jt = 0; jt < BT; jt += SV_STRIDE_M) { + wmma::load_matrix_sync( + frag_O, &shared_Q[jt + sv_wy * SV_WARP_M][sv_wx * SV_WARP_N], D + D_PAD, wmma::mem_row_major); + #pragma unroll + for (int k = 0; k < BS; k += SV_WARP_K) { + wmma::load_matrix_sync(frag_S, &shared_P[jt + sv_wy * SV_WARP_M][k], BS + S_PAD); + wmma::load_matrix_sync(frag_V, &shared_V[k][sv_wx * SV_WARP_N], D + D_PAD); + wmma::mma_sync(frag_O, frag_S, frag_V, frag_O); + } + wmma::store_matrix_sync( + &shared_Q[jt + sv_wy * SV_WARP_M][sv_wx * SV_WARP_N], frag_O, D + D_PAD, wmma::mem_row_major); + } + __syncthreads(); + + {# Save O #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k][SMEM_TID_D])); + } + } +} + +} // extern "C" diff --git a/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 similarity index 93% rename from sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 rename to sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 index 59a88365..9314b0fb 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 @@ -1,19 +1,20 @@ {# Copyright (c) Microsoft Corporation. #} {# Licensed under the MIT license. 
#} -{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE <= 32 #} - -{% set THREAD_SIZE_D_VALUE = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} - +{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE_S_VALUE %}{# WARP_REDUCE_SIZE = Bs / Ts <= 32 #} {% set THREADS_PER_BLOCK = WARP_REDUCE_SIZE * BLOCK_SIZE_T_VALUE // THREAD_SIZE_T_VALUE %} +{% set THREAD_SIZE_S_TO_D = GLOBAL_SIZE_D_VALUE // WARP_REDUCE_SIZE %} const int BS = {{ BLOCK_SIZE_S_VALUE }}; const int BT = {{ BLOCK_SIZE_T_VALUE }}; const int D = {{ GLOBAL_SIZE_D_VALUE }}; const int TS = {{ THREAD_SIZE_S_VALUE }}; const int TT = {{ THREAD_SIZE_T_VALUE }}; +const int TD = {{ THREAD_SIZE_D_VALUE }};{# D / Td >= Bs / Ts #} +const int SD = {{ THREAD_SIZE_S_TO_D }}; -const int TD = {{ THREAD_SIZE_D_VALUE }};{# D * TS >= BS #} +const int SMEM_THREADS_D = D / 4; +const int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; __device__ __forceinline__ float2 _add_float2(float2 x, float2 y) \ { \ @@ -82,9 +83,6 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( {# __shared__ float shared_ML[BT * 2]; #} float* shared_ML = shared_Q; - int SMEM_THREADS_D = D / 4; - int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; - int tid = threadIdx.y * blockDim.x + threadIdx.x; int SMEM_TID_N = tid / SMEM_THREADS_D; int SMEM_TID_D = tid % SMEM_THREADS_D * 4; @@ -153,7 +151,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( for (int i = 0; i < TD; i++) { #pragma unroll for (int jt = 0; jt < TT; jt++) { - frag_QO[jt][i] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i] * temperature; + frag_QO[jt][i] = shared_Q[(threadIdx.y + blockDim.y * jt) * D + k + i]; } } #pragma unroll @@ -189,6 +187,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( #pragma unroll for (int jt = 0; jt < TT; jt++) { + #pragma unroll + for (int js = 0; js < TS; js++) { + frag_P[jt][js] *= temperature; + } {# Calc M~ = max_j(P) #} seg_max = -100000.0; #pragma unroll @@ -254,25 +256,26 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( for (int jt = tid * 2; jt < BT; jt += {{ THREADS_PER_BLOCK * 2 }}) { *((float4*)(&ML[(row_idx * BT + jt) * 2])) = *((float4*)(&shared_ML[jt * 2])); } + __syncthreads(); {# Load O #} #pragma unroll for (int jt = 0; jt < TT; jt++) { - {% if THREAD_SIZE_D_VALUE == 1 %} - shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD] = - O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD] * frag_ML[jt]; - {% elif THREAD_SIZE_D_VALUE == 2 %} - ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD]))[0] = + {% if THREAD_SIZE_S_TO_D == 1 %} + shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD] = + O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD] * frag_ML[jt]; + {% elif THREAD_SIZE_S_TO_D == 2 %} + ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD]))[0] = _scale_float2( - ((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD]))[0], + ((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD]))[0], frag_ML[jt] ); {% else %} #pragma unroll - for (int i = 0; i < TD; i += 4) { - ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i]))[0] = + for (int i = 0; i < SD; i += 4) { + ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD + i]))[0] = _scale_float4( - ((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * TD + i]))[0], + ((float4*)(&O[(row_idx * BT + threadIdx.y + 
blockDim.y * jt) * D + threadIdx.x * SD + i]))[0], frag_ML[jt] ); } From dcc4197be566ceab67d340fb4d81b1a551e43b0b Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Wed, 10 May 2023 01:17:17 +0900 Subject: [PATCH 22/28] Flash Attention fp16 forward version 2: pad (bank conflict) & fp32-softmax --- sparta/kernels/__init__.py | 2 +- sparta/kernels/attention.py | 158 ++++++++++++++---- sparta/kernels/kernel_base.py | 14 +- ...sh_sparse_attention_forward_float16.cuh.j2 | 120 ++++++------- 4 files changed, 191 insertions(+), 103 deletions(-) diff --git a/sparta/kernels/__init__.py b/sparta/kernels/__init__.py index 4b47e5d2..1cd59160 100644 --- a/sparta/kernels/__init__.py +++ b/sparta/kernels/__init__.py @@ -4,4 +4,4 @@ from sparta.kernels.kernel_base import KernelBase, SparsityAttr, KernelGroup from sparta.kernels.matmul import SparseMatMulKernel, SparTASparseMatMulKernel, OpenAISparseMatMulKernel from sparta.kernels.softmax import SparseSoftmaxForwardKernel, SparTASparseSoftmaxForwardKernel, SparseSoftmaxBackwardKernel, SparTASparseSoftmaxBackwardKernel -from sparta.kernels.attention import FlashSparseAttentionForwardKernel, FlashSparseAttentionBackwardKernel +from sparta.kernels.attention import FlashSparseAttentionFP32ForwardKernel, FlashSparseAttentionFP32BackwardKernel, FlashSparseAttentionFP16ForwardKernel, FlashSparseAttentionFP16BackwardKernel diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py index 457776fc..42dfa68b 100644 --- a/sparta/kernels/attention.py +++ b/sparta/kernels/attention.py @@ -2,15 +2,20 @@ # Licensed under the MIT license. import io -import textwrap +import abc +import warnings import importlib.resources as res -from typing import Any, Dict, Tuple, Optional +from typing import Any, Dict, Tuple import torch import jinja2 import numpy as np import pandas as pd +from sparta import __env_ready__ +if __env_ready__: + from pycuda.compiler import SourceModule + from sparta.tuning import TunableItemCfg from sparta.kernels import KernelBase, SparsityAttr, templates, look_up_tables from sparta.testing import sparse_multi_head_attention_forward_reference, sparse_multi_head_attention_backward_reference @@ -20,11 +25,11 @@ class FlashSparseAttentionKernel(KernelBase): __lut_shape__ = (64 * 12, 1024, 1024, 64) # BxH, Nt, Ns, D __algo__ = 'flash' + __dtype__ = '' __direction__ = '' - def __init__(self, buffer: torch.Tensor, dtype: str = 'float'): + def __init__(self, buffer: torch.Tensor): self._buffer = buffer - self._dtype = dtype super().__init__() def _add_parameters(self): @@ -39,9 +44,39 @@ def _add_parameters(self): is_tunable=True, search_space=TunableItemCfg('choice', [8, 16, 32, 64, 128, 256]), ) + self.attr = SparsityAttr(self, 'BLOCK_SIZE_T_VALUE', 'BLOCK_SIZE_S_VALUE', BCSR=False, BCSC=True) + + @abc.abstractmethod + def _check_shape(self, Nt: int, Ns: int, D: int): + '''Check if input shape is valid.''' + + def get_block_shape(self): + Bt = self.get_parameter('BLOCK_SIZE_T_VALUE') + Bs = self.get_parameter('BLOCK_SIZE_S_VALUE') + return Bt, Bs + + def get_kernel_code(self): + self._buffer.to() + template_file = f'{self.__algo__}_sparse_attention_{self.__direction__}_{self.__dtype__}.cuh.j2' + kernel_template = res.read_text(templates, template_file) + with open('tmp.cu', 'w') as f: + f.write(jinja2.Template(kernel_template).render(self.get_parameters())) + return jinja2.Template(kernel_template).render(self.get_parameters()) + + def compile(self, params: Dict[str, Any], shape: Tuple): + params['GLOBAL_SIZE_D_VALUE'] = shape[-1] + 
super().compile(params, shape) + + +class FlashSparseAttentionFP32Kernel(FlashSparseAttentionKernel): + + __dtype__ = 'float32' + + def _add_parameters(self): + super()._add_parameters() self._add_parameter('THREAD_SIZE_S_VALUE') self._add_parameter('THREAD_SIZE_T_VALUE') - self.attr = SparsityAttr(self, 'BLOCK_SIZE_T_VALUE', 'BLOCK_SIZE_S_VALUE', BCSR=False, BCSC=True) + self._add_parameter('THREAD_SIZE_D_VALUE') def _check_parameters(self, params: Dict[str, Any]): Bt = params['BLOCK_SIZE_T_VALUE'] @@ -59,7 +94,7 @@ def _check_parameters(self, params: Dict[str, Any]): def _check_shape(self, Nt: int, Ns: int, D: int): Bt, Bs = self.get_block_shape() - Tt, Ts = self.get_thread_shape() + Tt, Ts, Td = self.get_thread_shape() assert D & (D - 1) == 0 # TODO: pad threads_per_block = Bs // Ts * Bt // Tt smem_threads_D = D // 4 @@ -69,33 +104,79 @@ def _check_shape(self, Nt: int, Ns: int, D: int): assert smem_threads_N <= Bs assert Bs // Ts <= 32 assert Bs // Ts >= 4 - assert D * Ts >= Bs - - def get_block_shape(self): - Bt = self.get_parameter('BLOCK_SIZE_T_VALUE') - Bs = self.get_parameter('BLOCK_SIZE_S_VALUE') - return Bt, Bs + assert D // Td >= Bs // Ts + if self.__direction__ == 'backward': + assert D // Td >= Bt // Tt def get_thread_shape(self): Tt = self.get_parameter('THREAD_SIZE_T_VALUE') Ts = self.get_parameter('THREAD_SIZE_S_VALUE') - return Tt, Ts + Td = self.get_parameter('THREAD_SIZE_D_VALUE') + return Tt, Ts, Td def threads_per_block(self): Bt, Bs = self.get_block_shape() - Tt, Ts = self.get_thread_shape() + Tt, Ts, _ = self.get_thread_shape() return (Bs // Ts, Bt // Tt, 1) - def get_kernel_code(self): - template_file = f'{self.__algo__}_sparse_attention_{self.__direction__}.cuh.j2' - kernel_template = res.read_text(templates, template_file) - with open('tmp.cu', 'w') as f: - f.write(jinja2.Template(kernel_template).render(self.get_parameters())) - return jinja2.Template(kernel_template).render(self.get_parameters()) - def compile(self, params: Dict[str, Any], shape: Tuple): - params['GLOBAL_SIZE_D_VALUE'] = shape[-1] - super().compile(params, shape) +class FlashSparseAttentionFP16Kernel(FlashSparseAttentionKernel): + + __dtype__ = 'float16' + + def _add_parameters(self): + super()._add_parameters() + self._add_parameter('THREADS_PER_BLOCK') + self._add_parameter('QK_WARP_SIZE_M_VALUE') + self._add_parameter('SV_WARP_SIZE_M_VALUE') + + def _check_parameters(self, params: Dict[str, Any]): + Bs = params['BLOCK_SIZE_S_VALUE'] + Bt = params['BLOCK_SIZE_T_VALUE'] + assert Bs & (Bs - 1) == 0 + assert Bt & (Bt - 1) == 0 + assert Bs >= 16 + assert Bt >= 16 + threads_per_block = params['THREADS_PER_BLOCK'] + assert threads_per_block in [32, 64, 128, 256] + Wn1 = params['QK_WARP_SIZE_M_VALUE'] + assert Wn1 in [8, 16, 32] + Wn2 = params['SV_WARP_SIZE_M_VALUE'] + assert Wn2 in [8, 16, 32] + + def _check_shape(self, Nt: int, Ns: int, D: int): + Bt, Bs = self.get_block_shape() + assert D & (D - 1) == 0 # TODO: pad + assert D >= 16 + threads_per_block = self.get_parameter('THREADS_PER_BLOCK') + smem_threads_D = D // 8 + assert threads_per_block >= smem_threads_D + smem_threads_N = threads_per_block // smem_threads_D + assert smem_threads_N <= Bt + assert smem_threads_N <= Bs + thread_size = Bs * Bt // threads_per_block + assert Bs // thread_size <= 32 + assert 8 <= thread_size <= Bs + + def threads_per_block(self): + return (self.get_parameter('THREADS_PER_BLOCK'), 1, 1) + + def _build_kernel(self, kernel_code: str): + kernel_name = kernel_code[kernel_code.find('__global__ void') + 15:] + 
kernel_name = kernel_name[:kernel_name.find('(')].strip() + with warnings.catch_warnings(): + warnings.simplefilter('ignore') + source_module = SourceModule( + kernel_code, + options=[ + '-std=c++14', + '-O3', + '-U__CUDA_NO_HALF_OPERATORS__', + '-U__CUDA_NO_HALF_CONVERSIONS__', + ], + no_extern_c=True, + ) + return source_module.get_function(kernel_name) class FlashSparseAttentionForwardKernel(FlashSparseAttentionKernel): @@ -107,6 +188,8 @@ def set_kernel_call(self, shape: Tuple[int, int, int, int]): self._check_shape(Nt, Ns, D) Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) block = self.threads_per_block() + Bt, Bs = self.get_block_shape() + # shared = 4 * (Bt * D + 2 * Bs * D) def attn_func(Q, K, V): O = torch.zeros_like(Q) @@ -118,6 +201,7 @@ def attn_func(Q, K, V): self.attr.indexes.nnz, block=block, grid=(Q.shape[0], 1, 1), + # shared=shared, ) return O @@ -142,6 +226,8 @@ def set_kernel_call(self, shape: Tuple[int, int, int, int]): self._check_shape(Nt, Ns, D) Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) block = self.threads_per_block() + Bt, Bs = self.get_block_shape() + # shared = 4 * (2 * Bt * D + 4 * Bs * D) def attn_func(grad, O, Q, K, V): grad_Q = torch.zeros_like(Q) @@ -154,6 +240,7 @@ def attn_func(grad, O, Q, K, V): self.attr.indexes.nnz, block=block, grid=(Q.shape[0], 1, 1), + # shared=shared, ) return grad_Q, grad_K, grad_V @@ -170,9 +257,22 @@ def reference( ): return sparse_multi_head_attention_backward_reference(grad, O, Q, K, V, self.attr.mask) - def compile(self, params: Dict[str, Any], shape: Tuple): - Bs = params['BLOCK_SIZE_S_VALUE'] - Bt = params['BLOCK_SIZE_T_VALUE'] - Ts = params['THREAD_SIZE_S_VALUE'] - params['THREAD_SIZE_T_VALUE'] = Ts * Bt // Bs - super().compile(params, shape) + +class FlashSparseAttentionFP32ForwardKernel(FlashSparseAttentionFP32Kernel, FlashSparseAttentionForwardKernel): + + pass + + +class FlashSparseAttentionFP32BackwardKernel(FlashSparseAttentionFP32Kernel, FlashSparseAttentionBackwardKernel): + + pass + + +class FlashSparseAttentionFP16ForwardKernel(FlashSparseAttentionFP16Kernel, FlashSparseAttentionForwardKernel): + + pass + + +class FlashSparseAttentionFP16BackwardKernel(FlashSparseAttentionFP16Kernel, FlashSparseAttentionBackwardKernel): + + pass diff --git a/sparta/kernels/kernel_base.py b/sparta/kernels/kernel_base.py index f165bb06..f794f13b 100644 --- a/sparta/kernels/kernel_base.py +++ b/sparta/kernels/kernel_base.py @@ -228,17 +228,19 @@ def _check_parameters(self, params: Dict[str, Any]): def set_kernel_call(self, shape: Tuple): """Convert pycuda kernel (self._kernel) to python function call (self._func).""" - def compile(self, params: Dict[str, Any], shape: Tuple): - self._check_parameters(params) - self.set_parameters(params) - self.attr.update_block_size(self.id) - kernel_code = self.get_kernel_code() + def _build_kernel(self, kernel_code: str): kernel_name = kernel_code[kernel_code.find('__global__ void') + 15:] kernel_name = kernel_name[:kernel_name.find('(')].strip() with warnings.catch_warnings(): warnings.simplefilter('ignore') source_module = SourceModule(kernel_code, options=['-O3']) - self._kernel = source_module.get_function(kernel_name) + return source_module.get_function(kernel_name) + + def compile(self, params: Dict[str, Any], shape: Tuple): + self._check_parameters(params) + self.set_parameters(params) + self.attr.update_block_size(self.id) + self._kernel = self._build_kernel(self.get_kernel_code()) self.set_kernel_call(shape) self.ready = True # Calc estimated latency diff 
--git a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 index b91b027d..78874b99 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 @@ -40,11 +40,6 @@ const int SV_STRIDE_M = SV_WARP_M * (WARPS / SV_WARPS_N); const int D_PAD = 8; const int S_PAD = 8; -__device__ __forceinline__ half max(half x, half y) \ -{ \ - return x > y ? x : y; \ -} - extern "C" { __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( @@ -52,7 +47,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( half* K, half* V, half* O, - half* ML, + float* ML, {# unsigned char* mask, #} uint* block_idx, uint Ns, @@ -90,18 +85,17 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( wmma::fragment frag_S; wmma::fragment frag_V; wmma::fragment frag_O; - float4 tmp_float4; - half frag[T]; + float2 tmp_float2; + half tmp_half8[8]; + float frag[T]; float temperature = __frsqrt_rn((float)Ns); - half row_max; - half row_sum; - half row_sum_new; - half seg_max; - half seg_sum; - half row_coef; - half seg_coef; - int block_row_idx; + float row_max; + float row_sum; + float seg_max; + float seg_sum; + float row_coef; + float seg_coef; int last_col_idx = -1; {# BCSC #} @@ -149,30 +143,25 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( &shared_P[jt + qk_wy * QK_WARP_M][qk_wx * QK_WARP_N], frag_P, BS + S_PAD, wmma::mem_row_major); } __syncthreads(); - /*if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { - printf("P[0][0] = %f\n", (float)(shared_P[0][0])); - printf("P[0][1] = %f\n", (float)(shared_P[0][1])); - printf("P[1][0] = %f\n", (float)(shared_P[1][0])); - printf("P[1][1] = %f\n", (float)(shared_P[1][1])); - }*/ + // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { + // printf("P[0][0] = %f\n", (float)(shared_P[0][0])); + // printf("P[0][1] = %f\n", (float)(shared_P[0][1])); + // printf("P[1][0] = %f\n", (float)(shared_P[1][0])); + // printf("P[1][1] = %f\n", (float)(shared_P[1][1])); + // } - {# Load M, L, P #} - #pragma unroll - for (int jt = threadIdx.x * 4; jt < BT; jt += {{ THREADS_PER_BLOCK * 4 }}) { - tmp_float4 = ((float4*)(&ML[(row_idx * BT + jt) * 2]))[0]; - ((float*)(&shared_Q[jt + 0][0]))[0] = tmp_float4.x; - ((float*)(&shared_Q[jt + 1][0]))[0] = tmp_float4.y; - ((float*)(&shared_Q[jt + 2][0]))[0] = tmp_float4.z; - ((float*)(&shared_Q[jt + 3][0]))[0] = tmp_float4.w; - } - __syncthreads(); + {# Load P #} #pragma unroll for (int i = 0; i < T; i += 8) { - *((float4*)(&frag[i])) = *((float4*)(&shared_P[ty][tx * T + i])); + *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_P[ty][tx * T + i])); + #pragma unroll + for (int j = 0; j < 8; j++) { + frag[i + j] = __half2float(tmp_half8[j]); + } } {# Calc M~ = max_j(P) #} - seg_max = (half)(-1000.0); + seg_max = -100000.0; #pragma unroll for (int i = 0; i < T; i++) { seg_max = max(seg_max, frag[i]); @@ -184,10 +173,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Calc S = exp(P - M~) #} #pragma unroll for (int i = 0; i < T; i++) { - frag[i] = hexp(frag[i] - seg_max); + frag[i] = expf(frag[i] - seg_max); } {# Calc L~ = sum_j(P) #} - seg_sum = (half)(0.0f); + seg_sum = 0.0f; #pragma unroll for (int i = 0; i < T; i++) { seg_sum += frag[i]; @@ -198,21 +187,22 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( } {# Calc M' = max(M, M~), L' = exp(M - M') * L + exp(M~ - M') * L~ #} if (tx == 0) { - row_max = 
shared_Q[ty][0]; - row_sum = shared_Q[ty][1]; + tmp_float2 = ((float2*)(&ML[(row_idx * BT + ty) * 2]))[0]; + row_max = tmp_float2.x; + row_sum = tmp_float2.y; if (row_max < seg_max) { - shared_Q[ty][0] = seg_max; - row_coef = hexp(row_max - seg_max); - row_sum_new = row_coef * row_sum + seg_sum; - row_coef *= row_sum / row_sum_new; - seg_coef = (half)(1.0f) / row_sum_new; + tmp_float2.x = seg_max; + row_coef = expf(row_max - seg_max); + tmp_float2.y = row_coef * row_sum + seg_sum; + row_coef *= row_sum / tmp_float2.y; + seg_coef = 1.0f / tmp_float2.y; } else { - seg_coef = hexp(seg_max - row_max); - row_sum_new = row_sum + seg_coef * seg_sum; - row_coef = row_sum / row_sum_new; - seg_coef /= row_sum_new; + seg_coef = expf(seg_max - row_max); + tmp_float2.y = row_sum + seg_coef * seg_sum; + row_coef = row_sum / tmp_float2.y; + seg_coef /= tmp_float2.y; } - shared_Q[ty][1] = row_sum_new; + ((float2*)(&ML[(row_idx * BT + ty) * 2]))[0] = tmp_float2; } row_coef = __shfl_sync(WARP_MASK, row_coef, WARP_OFFSET); seg_coef = __shfl_sync(WARP_MASK, seg_coef, WARP_OFFSET); @@ -223,36 +213,32 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( } __syncthreads(); - {# Save M, L, S #} - #pragma unroll - for (int jt = threadIdx.x * 4; jt < BT; jt += {{ THREADS_PER_BLOCK * 4 }}) { - tmp_float4.x = ((float*)(&shared_Q[jt + 0][0]))[0]; - tmp_float4.y = ((float*)(&shared_Q[jt + 1][0]))[0]; - tmp_float4.z = ((float*)(&shared_Q[jt + 2][0]))[0]; - tmp_float4.w = ((float*)(&shared_Q[jt + 3][0]))[0]; - ((float4*)(&ML[(row_idx * BT + jt) * 2]))[0] = tmp_float4; - } + {# Save S #} #pragma unroll for (int i = 0; i < T; i += 8) { - *((float4*)(&shared_P[ty][tx * T + i])) = *((float4*)(&frag[i])); + #pragma unroll + for (int j = 0; j < 8; j++) { + tmp_half8[j] = __float2half(frag[i + j]); + } + *((float4*)(&shared_P[ty][tx * T + i])) = *((float4*)(&tmp_half8[0])); } __syncthreads(); - /*if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { - printf("S[0][0] = %f\n", (float)(shared_P[0][0])); - printf("S[0][1] = %f\n", (float)(shared_P[0][1])); - printf("S[1][0] = %f\n", (float)(shared_P[1][0])); - printf("S[1][1] = %f\n", (float)(shared_P[1][1])); - }*/ + // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { + // printf("S[0][0] = %f\n", (float)(shared_P[0][0])); + // printf("S[0][1] = %f\n", (float)(shared_P[0][1])); + // printf("S[1][0] = %f\n", (float)(shared_P[1][0])); + // printf("S[1][1] = %f\n", (float)(shared_P[1][1])); + // } {# Load O #} #pragma unroll for (int i = 0; i < SD; i += 8) { - *((float4*)(&frag[0])) = *((float4*)(&O[(row_idx * BT + ty) * D + tx * SD + i])); + *((float4*)(&tmp_half8[0])) = *((float4*)(&O[(row_idx * BT + ty) * D + tx * SD + i])); #pragma unroll for (int j = 0; j < 8; j++) { - frag[j] *= row_coef; + tmp_half8[j] = __float2half(__half2float(tmp_half8[j]) * row_coef); } - *((float4*)(&shared_Q[ty][tx * SD + i])) = *((float4*)(&frag[0])); + *((float4*)(&shared_Q[ty][tx * SD + i])) = *((float4*)(&tmp_half8[0])); } __syncthreads(); From a1ff80c4e594ec77e4cf0f2fca838d510c8c52ce Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Thu, 11 May 2023 16:02:35 +0900 Subject: [PATCH 23/28] Flash Attention fp16 backward version 1 --- sparta/kernels/attention.py | 13 +- ...h_sparse_attention_backward_float16.cuh.j2 | 394 ++++++++++++++++++ ...h_sparse_attention_backward_float32.cuh.j2 | 14 +- ...sh_sparse_attention_forward_float16.cuh.j2 | 105 +++-- ...sh_sparse_attention_forward_float32.cuh.j2 | 8 +- 5 files changed, 475 insertions(+), 59 
deletions(-) create mode 100644 sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py index 42dfa68b..1aacc776 100644 --- a/sparta/kernels/attention.py +++ b/sparta/kernels/attention.py @@ -127,8 +127,10 @@ class FlashSparseAttentionFP16Kernel(FlashSparseAttentionKernel): def _add_parameters(self): super()._add_parameters() self._add_parameter('THREADS_PER_BLOCK') - self._add_parameter('QK_WARP_SIZE_M_VALUE') - self._add_parameter('SV_WARP_SIZE_M_VALUE') + self._add_parameter('TS_WARP_SIZE_M_VALUE') + self._add_parameter('TD_WARP_SIZE_M_VALUE') + if self.__direction__ == 'backward': + self._add_parameter('SD_WARP_SIZE_M_VALUE') def _check_parameters(self, params: Dict[str, Any]): Bs = params['BLOCK_SIZE_S_VALUE'] @@ -139,10 +141,13 @@ def _check_parameters(self, params: Dict[str, Any]): assert Bt >= 16 threads_per_block = params['THREADS_PER_BLOCK'] assert threads_per_block in [32, 64, 128, 256] - Wn1 = params['QK_WARP_SIZE_M_VALUE'] + Wn1 = params['TS_WARP_SIZE_M_VALUE'] assert Wn1 in [8, 16, 32] - Wn2 = params['SV_WARP_SIZE_M_VALUE'] + Wn2 = params['TD_WARP_SIZE_M_VALUE'] assert Wn2 in [8, 16, 32] + if self.__direction__ == 'backward': + Wn3 = params['SD_WARP_SIZE_M_VALUE'] + assert Wn3 in [8, 16, 32] def _check_shape(self, Nt: int, Ns: int, D: int): Bt, Bs = self.get_block_shape() diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 new file mode 100644 index 00000000..9d133661 --- /dev/null +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 @@ -0,0 +1,394 @@ +{# Copyright (c) Microsoft Corporation. #} +{# Licensed under the MIT license. 
#} + +#include +#include + +using namespace nvcuda; + +{% set WARP_SIZE = 32 %} +{% set FRAG_SIZE = 256 %} +{% set TS_WARP_SIZE_N_VALUE = FRAG_SIZE // TS_WARP_SIZE_M_VALUE %} +{% set TD_WARP_SIZE_N_VALUE = FRAG_SIZE // TD_WARP_SIZE_M_VALUE %} +{% set SD_WARP_SIZE_N_VALUE = FRAG_SIZE // SD_WARP_SIZE_M_VALUE %} +{% set BLOCK_SIZE = BLOCK_SIZE_T_VALUE * BLOCK_SIZE_S_VALUE %} +{% set THREAD_SIZE = BLOCK_SIZE // THREADS_PER_BLOCK %}{# 8 <= THREAD_SIZE <= BLOCK_SIZE_S_VALUE #} +{% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE %}{# WARP_REDUCE_SIZE <= WARP_SIZE #} + +const int BS = {{ BLOCK_SIZE_S_VALUE }}; +const int BT = {{ BLOCK_SIZE_T_VALUE }}; +const int D = {{ GLOBAL_SIZE_D_VALUE }}; +const int TS_WARP_M = {{ TS_WARP_SIZE_M_VALUE }}; +const int TS_WARP_N = {{ TS_WARP_SIZE_N_VALUE }}; +const int TS_WARP_K = 16; +const int TD_WARP_M = {{ TD_WARP_SIZE_M_VALUE }}; +const int TD_WARP_N = {{ TD_WARP_SIZE_N_VALUE }}; +const int TD_WARP_K = 16; +const int SD_WARP_M = {{ SD_WARP_SIZE_M_VALUE }}; +const int SD_WARP_N = {{ SD_WARP_SIZE_N_VALUE }}; +const int SD_WARP_K = 16; + +const int B = {{ BLOCK_SIZE }}; +const int T = {{ THREAD_SIZE }}; +const int THREADS = {{ THREADS_PER_BLOCK }};{# THREADS_PER_BLOCK >= WARP_SIZE #} +const int WARPS = THREADS / {{ WARP_SIZE }}; +const int SD = T * D / BS; + +const int SMEM_THREADS_D = D / 8; +const int SMEM_THREADS_N = {{ THREADS_PER_BLOCK }} / SMEM_THREADS_D; +const int TS_WARPS_N = BS / TS_WARP_N; +const int TS_STRIDE_M = TS_WARP_M * (WARPS / TS_WARPS_N); +const int TD_WARPS_N = D / TD_WARP_N; +const int TD_STRIDE_M = TD_WARP_M * (WARPS / TD_WARPS_N); +const int SD_WARPS_N = D / SD_WARP_N; +const int SD_STRIDE_M = SD_WARP_M * (WARPS / SD_WARPS_N); + +const int D_PAD = 8; +const int S_PAD = 8; + +extern "C" { + +__global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( + half* Q, + half* K, + half* V, + half* O, + half* dQ, + half* dK, + half* dV, + half* dO, + float* ML, + {# unsigned char* mask, #} + uint* block_idx, + uint Ns, + uint Nt, + uint block_nnz +) { + Q += Nt * D * blockIdx.x; + K += Ns * D * blockIdx.x; + V += Ns * D * blockIdx.x; + O += Nt * D * blockIdx.x; + dQ += Nt * D * blockIdx.x; + dK += Ns * D * blockIdx.x; + dV += Ns * D * blockIdx.x; + dO += Nt * D * blockIdx.x; + ML += Nt * 2 * blockIdx.x; + + uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; + uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; + + __shared__ half shared_Q[BT * (D + D_PAD)]; + __shared__ half shared_P[BT * (BS + S_PAD)]; + __shared__ half shared_S[BT * (BS + S_PAD)]; + __shared__ half shared_K[BS * (D + D_PAD)]; + __shared__ half shared_V[BS * (D + D_PAD)]; + __shared__ half shared_O[BT * (D + D_PAD)]; + __shared__ half shared_dK[BS * (D + D_PAD)]; + __shared__ half shared_dV[BS * (D + D_PAD)]; + + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int SMEM_TID_N = tid / SMEM_THREADS_D; + int SMEM_TID_D = tid % SMEM_THREADS_D * 8; + + int wid = threadIdx.x / {{ WARP_SIZE }}; + int ts_wx = wid % TS_WARPS_N; + int ts_wy = wid / TS_WARPS_N; + int td_wx = wid % TD_WARPS_N; + int td_wy = wid / TD_WARPS_N; + int sd_wx = wid % SD_WARPS_N; + int sd_wy = wid / SD_WARPS_N; + int tx = threadIdx.x % {{ WARP_REDUCE_SIZE }}; + int ty = threadIdx.x / {{ WARP_REDUCE_SIZE }}; + + wmma::fragment frag_ts_a; + wmma::fragment frag_ts_b; + wmma::fragment frag_ts_c; + wmma::fragment frag_td_a; + wmma::fragment frag_td_b; + wmma::fragment frag_td_c; + wmma::fragment frag_sd_a; + wmma::fragment 
frag_sd_b; + wmma::fragment frag_sd_c; + float2 tmp_float2; + half tmp_half8[8]; + half tmp_half8_2[8]; + float frag_P[T]; + float frag_S[T]; + + float temperature = __frsqrt_rn((float)Ns); + float row_sum; + + int last_col_idx = -1; + {# BCSC #} + for (int block = 0; block < block_nnz; block++) { + uint idx = block_idx[block]; + int row_idx = idx & 0xffff; + int col_idx = idx >> 16; + // if (blockIdx.x == 0 && threadIdx.x == 0 && threadIdx.y == 0) + // printf("#%d: (%d, %d)\n", block, row_idx, col_idx); + + {# Load Q #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + } + if (col_idx != last_col_idx) { + if (last_col_idx >= 0) { + {# Save dK, dV #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&shared_dK[k * (D + D_PAD) + SMEM_TID_D])); + *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&shared_dV[k * (D + D_PAD) + SMEM_TID_D])); + } + } + {# Load K, V, dK, dV #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&shared_dK[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&dK[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_dV[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&dV[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_K[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_V[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + } + last_col_idx = col_idx; + } + __syncthreads(); + + {# Calc P = Q K^T #} + #pragma unroll + for (int j = 0; j < BT; j += TS_STRIDE_M) { + wmma::fill_fragment(frag_ts_c, 0.0); + #pragma unroll + for (int k = 0; k < D; k += TS_WARP_K) { + wmma::load_matrix_sync(frag_ts_a, &shared_Q[(j + ts_wy * TS_WARP_M) * (D + D_PAD) + k], D + D_PAD); + wmma::load_matrix_sync(frag_ts_b, &shared_K[(ts_wx * TS_WARP_N) * (D + D_PAD) + k], D + D_PAD); + wmma::mma_sync(frag_ts_c, frag_ts_a, frag_ts_b, frag_ts_c); + } + wmma::store_matrix_sync( + &shared_P[(j + ts_wy * TS_WARP_M) * (BS + S_PAD) + ts_wx * TS_WARP_N], + frag_ts_c, + BS + S_PAD, + wmma::mem_row_major + ); + } + __syncthreads(); + + {# Load M, L, P #} + #pragma unroll + for (int i = 0; i < T; i += 8) { + *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_P[ty * (BS + S_PAD) + tx * T + i])); + #pragma unroll + for (int j = 0; j < 8; j++) { + frag_P[i + j] = __half2float(tmp_half8[j]); + } + } + if (tx == 0) { + tmp_float2 = ((float2*)(&ML[(row_idx * BT + ty) * 2]))[0]; + } + __syncthreads(); + tmp_float2.x = __shfl_sync(WARP_MASK, tmp_float2.x, WARP_OFFSET); + tmp_float2.y = __shfl_sync(WARP_MASK, tmp_float2.y, WARP_OFFSET); + + {# Calc S = exp(P - M) / L #} + #pragma unroll + for (int i = 0; i < T; i++) { + frag_S[i] = expf(frag_P[i] * temperature - tmp_float2.x) / tmp_float2.y; + } + + {# Save S #} + #pragma unroll + for (int i = 0; i < T; i += 8) { + #pragma unroll + for (int j = 0; j < 8; j++) { + tmp_half8[j] = __float2half(frag_S[i + j]); + } + *((float4*)(&shared_P[ty * (BS + S_PAD) + tx * T + i])) = *((float4*)(&tmp_half8[0])); + } + + {# Load dO #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_O[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&dO[(row_idx * BT + k) * D + SMEM_TID_D])); + } + __syncthreads(); + + {# Calc dV = dV 
+ S^T dO #} + #pragma unroll + for (int j = 0; j < BS; j += SD_STRIDE_M) { + wmma::load_matrix_sync( + frag_sd_c, + &shared_dV[(j + sd_wy * SD_WARP_M) * (D + D_PAD) + sd_wx * SD_WARP_N], + D + D_PAD, + wmma::mem_row_major + ); + #pragma unroll + for (int k = 0; k < BT; k += SD_WARP_K) { + wmma::load_matrix_sync(frag_sd_a, &shared_P[k * (BS + S_PAD) + j + sd_wy * SD_WARP_M], BS + S_PAD); + wmma::load_matrix_sync(frag_sd_b, &shared_O[k * (D + D_PAD) + sd_wx * SD_WARP_N], D + D_PAD); + wmma::mma_sync(frag_sd_c, frag_sd_a, frag_sd_b, frag_sd_c); + } + wmma::store_matrix_sync( + &shared_dV[(j + sd_wy * SD_WARP_M) * (D + D_PAD) + sd_wx * SD_WARP_N], + frag_sd_c, + D + D_PAD, + wmma::mem_row_major + ); + } + __syncthreads(); + + {# Calc dS = dO V^T #} + #pragma unroll + for (int j = 0; j < BT; j += TS_STRIDE_M) { + wmma::fill_fragment(frag_ts_c, 0.0); + #pragma unroll + for (int k = 0; k < D; k += TS_WARP_K) { + wmma::load_matrix_sync(frag_ts_a, &shared_O[(j + ts_wy * TS_WARP_M) * (D + D_PAD) + k], D + D_PAD); + wmma::load_matrix_sync(frag_ts_b, &shared_V[(ts_wx * TS_WARP_N) * (D + D_PAD) + k], D + D_PAD); + wmma::mma_sync(frag_ts_c, frag_ts_a, frag_ts_b, frag_ts_c); + } + wmma::store_matrix_sync( + &shared_S[(j + ts_wy * TS_WARP_M) * (BS + S_PAD) + ts_wx * TS_WARP_N], + frag_ts_c, + BS + S_PAD, + wmma::mem_row_major + ); + } + __syncthreads(); + + // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { + // printf("dS[0][0] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 0])); + // printf("dS[0][1] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 1])); + // printf("dS[1][0] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 0])); + // printf("dS[1][1] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 1])); + // } + + {# Load dS #} + #pragma unroll + for (int i = 0; i < T; i += 8) { + *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_S[ty * (BS + S_PAD) + tx * T + i])); + #pragma unroll + for (int j = 0; j < 8; j++) { + frag_P[i + j] = __half2float(tmp_half8[j]); + } + } + + {# Calc dP = S (dS - sum_j(dO * O)) #} + row_sum = 0.0f; + #pragma unroll + for (int i = 0; i < SD; i += 8) { + *((float4*)(&tmp_half8[0])) = *((float4*)(&O[(row_idx * BT + ty) * D + tx * SD + i])); + *((float4*)(&tmp_half8_2[0])) = *((float4*)(&shared_O[ty * (D + D_PAD) + tx * SD + i])); + #pragma unroll + for (int j = 0; j < 8; j++) { + row_sum += __half2float(tmp_half8[j] * tmp_half8_2[j]); + } + } + #pragma unroll + for (int offset = {{ WARP_REDUCE_SIZE // 2 }}; offset > 0; offset >>= 1) { + row_sum += __shfl_xor_sync(WARP_MASK, row_sum, offset); + } + #pragma unroll + for (int i = 0; i < T; i ++) { + frag_P[i] = frag_S[i] * (frag_P[i] - row_sum) * temperature; + } + __syncthreads(); + + {# Save dP #} + #pragma unroll + for (int i = 0; i < T; i += 8) { + #pragma unroll + for (int j = 0; j < 8; j++) { + tmp_half8[j] = __float2half(frag_P[i + j]); + } + *((float4*)(&shared_P[ty * (BS + S_PAD) + tx * T + i])) = *((float4*)(&tmp_half8[0])); + } + + // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { + // printf("dP[0][0] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 0])); + // printf("dP[0][1] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 1])); + // printf("dP[1][0] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 0])); + // printf("dP[1][1] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 1])); + // } + + {# Load dQ #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&shared_O[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&dQ[(row_idx * BT + k) * D + 
SMEM_TID_D])); + } + __syncthreads(); + + {# Calc dK = dK + dP^T Q #} + #pragma unroll + for (int j = 0; j < BS; j += SD_STRIDE_M) { + wmma::load_matrix_sync( + frag_sd_c, + &shared_dK[(j + sd_wy * SD_WARP_M) * (D + D_PAD) + sd_wx * SD_WARP_N], + D + D_PAD, + wmma::mem_row_major + ); + #pragma unroll + for (int k = 0; k < BT; k += SD_WARP_K) { + wmma::load_matrix_sync(frag_sd_a, &shared_P[k * (BS + S_PAD) + j + sd_wy * SD_WARP_M], BS + S_PAD); + wmma::load_matrix_sync(frag_sd_b, &shared_Q[k * (D + D_PAD) + sd_wx * SD_WARP_N], D + D_PAD); + wmma::mma_sync(frag_sd_c, frag_sd_a, frag_sd_b, frag_sd_c); + } + wmma::store_matrix_sync( + &shared_dK[(j + sd_wy * SD_WARP_M) * (D + D_PAD) + sd_wx * SD_WARP_N], + frag_sd_c, + D + D_PAD, + wmma::mem_row_major + ); + } + __syncthreads(); + + {# Calc dQ = dQ + dP K #} + #pragma unroll + for (int j = 0; j < BT; j += TD_STRIDE_M) { + wmma::load_matrix_sync( + frag_td_c, + &shared_O[(j + td_wy * TD_WARP_M) * (D + D_PAD) + td_wx * TD_WARP_N], + D + D_PAD, + wmma::mem_row_major + ); + #pragma unroll + for (int k = 0; k < BS; k += TD_WARP_K) { + wmma::load_matrix_sync(frag_td_a, &shared_P[(j + td_wy * TD_WARP_M) * (BS + S_PAD) + k], BS + S_PAD); + wmma::load_matrix_sync(frag_td_b, &shared_K[k * (D + D_PAD) + td_wx * TD_WARP_N], D + D_PAD); + wmma::mma_sync(frag_td_c, frag_td_a, frag_td_b, frag_td_c); + } + wmma::store_matrix_sync( + &shared_O[(j + td_wy * TD_WARP_M) * (D + D_PAD) + td_wx * TD_WARP_N], + frag_td_c, + D + D_PAD, + wmma::mem_row_major + ); + } + __syncthreads(); + + {# Save dQ #} + #pragma unroll + for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { + *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])) = + *((float4*)(&shared_O[k * (D + D_PAD) + SMEM_TID_D])); + } + } + + {# Save dK, dV for the last column #} + #pragma unroll + for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { + *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&shared_dK[k * (D + D_PAD) + SMEM_TID_D])); + *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&shared_dV[k * (D + D_PAD) + SMEM_TID_D])); + } +} + +} // extern "C" diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 index e924653a..735c2fe5 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 @@ -73,13 +73,13 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; - {# extern __shared__ float buffer[]; - float* shared_Q = &buffer[0]; - float* shared_K = &buffer[BT * D]; - float* shared_V = &buffer[BT * D + BS * D]; - float* shared_O = &buffer[BT * D + 2 * BS * D]; - float* shared_dK = &buffer[2 * BT * D + 2 * BS * D]; - float* shared_dV = &buffer[2 * BT * D + 3 * BS * D]; #} + {# extern __shared__ float shared[]; + float* shared_Q = &shared[0]; + float* shared_K = &shared[BT * D]; + float* shared_V = &shared[BT * D + BS * D]; + float* shared_O = &shared[BT * D + 2 * BS * D]; + float* shared_dK = &shared[2 * BT * D + 2 * BS * D]; + float* shared_dV = &shared[2 * BT * D + 3 * BS * D]; #} __shared__ float shared_Q[BT * D]; __shared__ float shared_K[BS * D]; __shared__ float shared_V[BS * D]; diff --git 
a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 index 78874b99..cea9c0ba 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 @@ -8,8 +8,8 @@ using namespace nvcuda; {% set WARP_SIZE = 32 %} {% set FRAG_SIZE = 256 %} -{% set QK_WARP_SIZE_N_VALUE = FRAG_SIZE // QK_WARP_SIZE_M_VALUE %} -{% set SV_WARP_SIZE_N_VALUE = FRAG_SIZE // SV_WARP_SIZE_M_VALUE %} +{% set TS_WARP_SIZE_N_VALUE = FRAG_SIZE // TS_WARP_SIZE_M_VALUE %} +{% set TD_WARP_SIZE_N_VALUE = FRAG_SIZE // TD_WARP_SIZE_M_VALUE %} {% set BLOCK_SIZE = BLOCK_SIZE_T_VALUE * BLOCK_SIZE_S_VALUE %} {% set THREAD_SIZE = BLOCK_SIZE // THREADS_PER_BLOCK %}{# 8 <= THREAD_SIZE <= BLOCK_SIZE_S_VALUE #} {% set WARP_REDUCE_SIZE = BLOCK_SIZE_S_VALUE // THREAD_SIZE %}{# WARP_REDUCE_SIZE <= WARP_SIZE #} @@ -17,12 +17,12 @@ using namespace nvcuda; const int BS = {{ BLOCK_SIZE_S_VALUE }}; const int BT = {{ BLOCK_SIZE_T_VALUE }}; const int D = {{ GLOBAL_SIZE_D_VALUE }}; -const int QK_WARP_M = {{ QK_WARP_SIZE_M_VALUE }}; -const int QK_WARP_N = {{ QK_WARP_SIZE_N_VALUE }}; -const int QK_WARP_K = 16; -const int SV_WARP_M = {{ SV_WARP_SIZE_M_VALUE }}; -const int SV_WARP_N = {{ SV_WARP_SIZE_N_VALUE }}; -const int SV_WARP_K = 16; +const int TS_WARP_M = {{ TS_WARP_SIZE_M_VALUE }}; +const int TS_WARP_N = {{ TS_WARP_SIZE_N_VALUE }}; +const int TS_WARP_K = 16; +const int TD_WARP_M = {{ TD_WARP_SIZE_M_VALUE }}; +const int TD_WARP_N = {{ TD_WARP_SIZE_N_VALUE }}; +const int TD_WARP_K = 16; const int B = {{ BLOCK_SIZE }}; const int T = {{ THREAD_SIZE }}; @@ -32,10 +32,10 @@ const int SD = T * D / BS; const int SMEM_THREADS_D = D / 8; const int SMEM_THREADS_N = THREADS / SMEM_THREADS_D; -const int QK_WARPS_N = BS / QK_WARP_N; -const int QK_STRIDE_M = QK_WARP_M * (WARPS / QK_WARPS_N); -const int SV_WARPS_N = D / SV_WARP_N; -const int SV_STRIDE_M = SV_WARP_M * (WARPS / SV_WARPS_N); +const int TS_WARPS_N = BS / TS_WARP_N; +const int TS_STRIDE_M = TS_WARP_M * (WARPS / TS_WARPS_N); +const int TD_WARPS_N = D / TD_WARP_N; +const int TD_STRIDE_M = TD_WARP_M * (WARPS / TD_WARPS_N); const int D_PAD = 8; const int S_PAD = 8; @@ -63,28 +63,33 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; - __shared__ half shared_Q[BT][D + D_PAD]; - __shared__ half shared_P[BT][BS + S_PAD]; - __shared__ half shared_K[BS][D + D_PAD]; - __shared__ half shared_V[BS][D + D_PAD]; + {# extern __shared__ half shared[]; + half* shared_Q = &shared[0]; + half* shared_P = &shared_Q[BT * (D + D_PAD)]; + half* shared_K = &shared_P[BT * (BS + S_PAD)]; + half* shared_V = &shared_K[BS * (D + D_PAD)]; #} + __shared__ half shared_Q[BT * (D + D_PAD)]; + __shared__ half shared_P[BT * (BS + S_PAD)]; + __shared__ half shared_K[BS * (D + D_PAD)]; + __shared__ half shared_V[BS * (D + D_PAD)]; int SMEM_TID_N = threadIdx.x / SMEM_THREADS_D; int SMEM_TID_D = threadIdx.x % SMEM_THREADS_D * 8; int wid = threadIdx.x / {{ WARP_SIZE }}; - int qk_wx = wid % QK_WARPS_N; - int qk_wy = wid / QK_WARPS_N; - int sv_wx = wid % SV_WARPS_N; - int sv_wy = wid / SV_WARPS_N; + int ts_wx = wid % TS_WARPS_N; + int ts_wy = wid / TS_WARPS_N; + int td_wx = wid % TD_WARPS_N; + int td_wy = wid / TD_WARPS_N; int tx = threadIdx.x % {{ WARP_REDUCE_SIZE }}; 
int ty = threadIdx.x / {{ WARP_REDUCE_SIZE }}; - wmma::fragment frag_Q; - wmma::fragment frag_K; - wmma::fragment frag_P; - wmma::fragment frag_S; - wmma::fragment frag_V; - wmma::fragment frag_O; + wmma::fragment frag_Q; + wmma::fragment frag_K; + wmma::fragment frag_P; + wmma::fragment frag_S; + wmma::fragment frag_V; + wmma::fragment frag_O; float2 tmp_float2; half tmp_half8[8]; float frag[T]; @@ -109,18 +114,18 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Load Q #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_Q[k][SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); } if (col_idx != last_col_idx) { {# Load K #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&shared_K[k][SMEM_TID_D])) = *((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_K[k * (D + D_PAD) + SMEM_TID_D])) = *((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D])); } {# Load V #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&shared_V[k][SMEM_TID_D])) = *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_V[k * (D + D_PAD) + SMEM_TID_D])) = *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); } last_col_idx = col_idx; } @@ -128,19 +133,23 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Calc P = Q K^T #} #pragma unroll - for (int jt = 0; jt < BT; jt += QK_STRIDE_M) { + for (int j = 0; j < BT; j += TS_STRIDE_M) { wmma::fill_fragment(frag_P, 0.0); #pragma unroll - for (int k = 0; k < D; k += QK_WARP_K) { - wmma::load_matrix_sync(frag_Q, &shared_Q[jt + qk_wy * QK_WARP_M][k], D + D_PAD); - wmma::load_matrix_sync(frag_K, &shared_K[qk_wx * QK_WARP_N][k], D + D_PAD); + for (int k = 0; k < D; k += TS_WARP_K) { + wmma::load_matrix_sync(frag_Q, &shared_Q[(j + ts_wy * TS_WARP_M) * (D + D_PAD) + k], D + D_PAD); + wmma::load_matrix_sync(frag_K, &shared_K[(ts_wx * TS_WARP_N) * (D + D_PAD) + k], D + D_PAD); wmma::mma_sync(frag_P, frag_Q, frag_K, frag_P); } for(int i = 0; i < {{ FRAG_SIZE }}; i++) { frag_P.x[i] *= temperature; } wmma::store_matrix_sync( - &shared_P[jt + qk_wy * QK_WARP_M][qk_wx * QK_WARP_N], frag_P, BS + S_PAD, wmma::mem_row_major); + &shared_P[(j + ts_wy * TS_WARP_M) * (BS + S_PAD) + ts_wx * TS_WARP_N], + frag_P, + BS + S_PAD, + wmma::mem_row_major + ); } __syncthreads(); // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { @@ -153,7 +162,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Load P #} #pragma unroll for (int i = 0; i < T; i += 8) { - *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_P[ty][tx * T + i])); + *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_P[ty * (BS + S_PAD) + tx * T + i])); #pragma unroll for (int j = 0; j < 8; j++) { frag[i + j] = __half2float(tmp_half8[j]); @@ -220,7 +229,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( for (int j = 0; j < 8; j++) { tmp_half8[j] = __float2half(frag[i + j]); } - *((float4*)(&shared_P[ty][tx * T + i])) = *((float4*)(&tmp_half8[0])); + *((float4*)(&shared_P[ty * (BS + S_PAD) + tx * T + i])) = *((float4*)(&tmp_half8[0])); } __syncthreads(); // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { @@ -238,30 +247,38 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( for (int j = 0; j < 8; j++) { tmp_half8[j] = __float2half(__half2float(tmp_half8[j]) * row_coef); } - 
*((float4*)(&shared_Q[ty][tx * SD + i])) = *((float4*)(&tmp_half8[0])); + *((float4*)(&shared_Q[ty * (D + D_PAD) + tx * SD + i])) = *((float4*)(&tmp_half8[0])); } __syncthreads(); {# Calc O = O' + S' V #} #pragma unroll - for (int jt = 0; jt < BT; jt += SV_STRIDE_M) { + for (int j = 0; j < BT; j += TD_STRIDE_M) { wmma::load_matrix_sync( - frag_O, &shared_Q[jt + sv_wy * SV_WARP_M][sv_wx * SV_WARP_N], D + D_PAD, wmma::mem_row_major); + frag_O, + &shared_Q[(j + td_wy * TD_WARP_M) * (D + D_PAD) + td_wx * TD_WARP_N], + D + D_PAD, + wmma::mem_row_major + ); #pragma unroll - for (int k = 0; k < BS; k += SV_WARP_K) { - wmma::load_matrix_sync(frag_S, &shared_P[jt + sv_wy * SV_WARP_M][k], BS + S_PAD); - wmma::load_matrix_sync(frag_V, &shared_V[k][sv_wx * SV_WARP_N], D + D_PAD); + for (int k = 0; k < BS; k += TD_WARP_K) { + wmma::load_matrix_sync(frag_S, &shared_P[(j + td_wy * TD_WARP_M) * (BS + S_PAD) + k], BS + S_PAD); + wmma::load_matrix_sync(frag_V, &shared_V[k * (D + D_PAD) + td_wx * TD_WARP_N], D + D_PAD); wmma::mma_sync(frag_O, frag_S, frag_V, frag_O); } wmma::store_matrix_sync( - &shared_Q[jt + sv_wy * SV_WARP_M][sv_wx * SV_WARP_N], frag_O, D + D_PAD, wmma::mem_row_major); + &shared_Q[(j + td_wy * TD_WARP_M) * (D + D_PAD) + td_wx * TD_WARP_N], + frag_O, + D + D_PAD, + wmma::mem_row_major + ); } __syncthreads(); {# Save O #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k][SMEM_TID_D])); + *((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])); } } } diff --git a/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 index 9314b0fb..5cf38732 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 @@ -73,10 +73,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; - {# extern __shared__ float buffer[]; - float* shared_Q = &buffer[0]; - float* shared_K = &buffer[BT * D]; - float* shared_V = &buffer[BT * D + BS * D]; #} + {# extern __shared__ float shared[]; + float* shared_Q = &shared[0]; + float* shared_K = &shared[BT * D]; + float* shared_V = &shared[BT * D + BS * D]; #} __shared__ float shared_Q[BT * D]; __shared__ float shared_K[BS * D]; __shared__ float shared_V[BS * D]; From 2a161179c3b44b4a336fc28d568110bfb375245c Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Fri, 12 May 2023 18:21:00 +0900 Subject: [PATCH 24/28] use dynamic shared memory in Flash Attention --- sparta/kernels/attention.py | 33 ++++++++++++++----- ...h_sparse_attention_backward_float16.cuh.j2 | 15 +++++++-- ...h_sparse_attention_backward_float32.cuh.j2 | 18 +++++----- ...sh_sparse_attention_forward_float16.cuh.j2 | 10 +++--- ...sh_sparse_attention_forward_float32.cuh.j2 | 12 +++---- sparta/testing/math.py | 4 +-- 6 files changed, 58 insertions(+), 34 deletions(-) diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py index 1aacc776..857cd532 100644 --- a/sparta/kernels/attention.py +++ b/sparta/kernels/attention.py @@ -14,6 +14,7 @@ from sparta import __env_ready__ if __env_ready__: + from pycuda.driver import function_attribute from pycuda.compiler import SourceModule from 
sparta.tuning import TunableItemCfg @@ -48,13 +49,17 @@ def _add_parameters(self): @abc.abstractmethod def _check_shape(self, Nt: int, Ns: int, D: int): - '''Check if input shape is valid.''' + """Check if input shape is valid.""" def get_block_shape(self): Bt = self.get_parameter('BLOCK_SIZE_T_VALUE') Bs = self.get_parameter('BLOCK_SIZE_S_VALUE') return Bt, Bs + @abc.abstractmethod + def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): + """Calc shared memory size in bytes.""" + def get_kernel_code(self): self._buffer.to() template_file = f'{self.__algo__}_sparse_attention_{self.__direction__}_{self.__dtype__}.cuh.j2' @@ -194,7 +199,8 @@ def set_kernel_call(self, shape: Tuple[int, int, int, int]): Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) block = self.threads_per_block() Bt, Bs = self.get_block_shape() - # shared = 4 * (Bt * D + 2 * Bs * D) + shared = self._calc_shared_mem_size(Bs, Bt, D) + self._kernel.set_attribute(function_attribute.MAX_DYNAMIC_SHARED_SIZE_BYTES, shared) def attn_func(Q, K, V): O = torch.zeros_like(Q) @@ -206,7 +212,7 @@ def attn_func(Q, K, V): self.attr.indexes.nnz, block=block, grid=(Q.shape[0], 1, 1), - # shared=shared, + shared=shared, ) return O @@ -232,7 +238,8 @@ def set_kernel_call(self, shape: Tuple[int, int, int, int]): Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) block = self.threads_per_block() Bt, Bs = self.get_block_shape() - # shared = 4 * (2 * Bt * D + 4 * Bs * D) + shared = self._calc_shared_mem_size(Bs, Bt, D) + self._kernel.set_attribute(function_attribute.MAX_DYNAMIC_SHARED_SIZE_BYTES, shared) def attn_func(grad, O, Q, K, V): grad_Q = torch.zeros_like(Q) @@ -245,7 +252,7 @@ def attn_func(grad, O, Q, K, V): self.attr.indexes.nnz, block=block, grid=(Q.shape[0], 1, 1), - # shared=shared, + shared=shared, ) return grad_Q, grad_K, grad_V @@ -265,19 +272,27 @@ def reference( class FlashSparseAttentionFP32ForwardKernel(FlashSparseAttentionFP32Kernel, FlashSparseAttentionForwardKernel): - pass + def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): + shared = 4 * (Bt * D + 2 * Bs * D) + return shared class FlashSparseAttentionFP32BackwardKernel(FlashSparseAttentionFP32Kernel, FlashSparseAttentionBackwardKernel): - pass + def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): + shared = 4 * (2 * Bt * D + 4 * Bs * D) + return shared class FlashSparseAttentionFP16ForwardKernel(FlashSparseAttentionFP16Kernel, FlashSparseAttentionForwardKernel): - pass + def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): + shared = 2 * (Bt * (D + 8) + 2 * (Bs * (D + 8)) + Bt * (Bs + 8)) + return shared class FlashSparseAttentionFP16BackwardKernel(FlashSparseAttentionFP16Kernel, FlashSparseAttentionBackwardKernel): - pass + def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): + shared = 2 * (2 * Bt * (D + 8) + 4 * (Bs * (D + 8)) + 2 * Bt * (Bs + 8)) + return shared diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 index 9d133661..d29b0aad 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 @@ -77,14 +77,23 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; - __shared__ half shared_Q[BT * (D + D_PAD)]; + 
extern __shared__ half shared[]; + half* shared_Q = &shared[0]; + half* shared_P = &shared_Q[BT * (D + D_PAD)]; + half* shared_S = &shared_P[BT * (BS + S_PAD)]; + half* shared_K = &shared_S[BT * (BS + S_PAD)]; + half* shared_V = &shared_K[BS * (D + D_PAD)]; + half* shared_O = &shared_V[BS * (D + D_PAD)]; + half* shared_dK = &shared_O[BT * (D + D_PAD)]; + half* shared_dV = &shared_dK[BS * (D + D_PAD)]; + {# __shared__ half shared_Q[BT * (D + D_PAD)]; __shared__ half shared_P[BT * (BS + S_PAD)]; __shared__ half shared_S[BT * (BS + S_PAD)]; __shared__ half shared_K[BS * (D + D_PAD)]; __shared__ half shared_V[BS * (D + D_PAD)]; __shared__ half shared_O[BT * (D + D_PAD)]; __shared__ half shared_dK[BS * (D + D_PAD)]; - __shared__ half shared_dV[BS * (D + D_PAD)]; + __shared__ half shared_dV[BS * (D + D_PAD)]; #} int tid = threadIdx.y * blockDim.x + threadIdx.x; int SMEM_TID_N = tid / SMEM_THREADS_D; @@ -115,7 +124,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( float frag_P[T]; float frag_S[T]; - float temperature = __frsqrt_rn((float)Ns); + float temperature = __frsqrt_rn((float)D); float row_sum; int last_col_idx = -1; diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 index 735c2fe5..27e7eab2 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 @@ -73,19 +73,19 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; - {# extern __shared__ float shared[]; + extern __shared__ float shared[]; float* shared_Q = &shared[0]; - float* shared_K = &shared[BT * D]; - float* shared_V = &shared[BT * D + BS * D]; - float* shared_O = &shared[BT * D + 2 * BS * D]; - float* shared_dK = &shared[2 * BT * D + 2 * BS * D]; - float* shared_dV = &shared[2 * BT * D + 3 * BS * D]; #} - __shared__ float shared_Q[BT * D]; + float* shared_K = &shared_Q[BT * D]; + float* shared_V = &shared_K[BS * D]; + float* shared_O = &shared_V[BS * D]; + float* shared_dK = &shared_O[BT * D]; + float* shared_dV = &shared_dK[BS * D]; + {# __shared__ float shared_Q[BT * D]; __shared__ float shared_K[BS * D]; __shared__ float shared_V[BS * D]; __shared__ float shared_O[BT * D]; __shared__ float shared_dK[BS * D]; - __shared__ float shared_dV[BS * D]; + __shared__ float shared_dV[BS * D]; #} {# __shared__ float shared_ML[BT * 2]; #} float* shared_ML = shared_O; @@ -99,7 +99,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( float frag_P[TT][TS]; float frag_S[TT][TS]; - float temperature = __frsqrt_rn((float)Ns); + float temperature = __frsqrt_rn((float)D); float row_max; float row_sum; int block_row_idx; diff --git a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 index cea9c0ba..56d39ca8 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 @@ -63,15 +63,15 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; - {# extern __shared__ 
half shared[]; + extern __shared__ half shared[]; half* shared_Q = &shared[0]; half* shared_P = &shared_Q[BT * (D + D_PAD)]; half* shared_K = &shared_P[BT * (BS + S_PAD)]; - half* shared_V = &shared_K[BS * (D + D_PAD)]; #} - __shared__ half shared_Q[BT * (D + D_PAD)]; + half* shared_V = &shared_K[BS * (D + D_PAD)]; + {# __shared__ half shared_Q[BT * (D + D_PAD)]; __shared__ half shared_P[BT * (BS + S_PAD)]; __shared__ half shared_K[BS * (D + D_PAD)]; - __shared__ half shared_V[BS * (D + D_PAD)]; + __shared__ half shared_V[BS * (D + D_PAD)]; #} int SMEM_TID_N = threadIdx.x / SMEM_THREADS_D; int SMEM_TID_D = threadIdx.x % SMEM_THREADS_D * 8; @@ -94,7 +94,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( half tmp_half8[8]; float frag[T]; - float temperature = __frsqrt_rn((float)Ns); + float temperature = __frsqrt_rn((float)D); float row_max; float row_sum; float seg_max; diff --git a/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 index 5cf38732..0f00769e 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 @@ -73,13 +73,13 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; - {# extern __shared__ float shared[]; + extern __shared__ float shared[]; float* shared_Q = &shared[0]; - float* shared_K = &shared[BT * D]; - float* shared_V = &shared[BT * D + BS * D]; #} - __shared__ float shared_Q[BT * D]; + float* shared_K = &shared_Q[BT * D]; + float* shared_V = &shared_K[BS * D]; + {# __shared__ float shared_Q[BT * D]; __shared__ float shared_K[BS * D]; - __shared__ float shared_V[BS * D]; + __shared__ float shared_V[BS * D]; #} {# __shared__ float shared_ML[BT * 2]; #} float* shared_ML = shared_Q; @@ -93,7 +93,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( float frag_P[TT][TS]; float frag_ML[TT]; - float temperature = __frsqrt_rn((float)Ns); + float temperature = __frsqrt_rn((float)D); float row_max; float row_sum; float row_sum_new; diff --git a/sparta/testing/math.py b/sparta/testing/math.py index 4fa324cc..913d954a 100644 --- a/sparta/testing/math.py +++ b/sparta/testing/math.py @@ -72,7 +72,7 @@ def sparse_multi_head_attention_forward_reference( torch.Tensor: Sparse multi-head attention output of shape :math:`(B, H, N_{target}, E)`. """ if np.isnan(temperature): - temperature = np.sqrt(mask.shape[-1]) + temperature = np.sqrt(key.shape[-1]) high_dims = ''.join([chr(ord('a') + i) for i in range(len(query.shape) - 2)]) p = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) s = sparse_softmax_forward_reference(p, mask, temperature) @@ -103,7 +103,7 @@ def sparse_multi_head_attention_backward_reference( Tuple: The gradient of query, key and value respectively.. 
""" if np.isnan(temperature): - temperature = np.sqrt(mask.shape[-1]) + temperature = np.sqrt(key.shape[-1]) high_dims = ''.join([chr(ord('a') + i) for i in range(len(query.shape) - 2)]) p = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) s = sparse_softmax_forward_reference(p, mask, temperature) From 92a6781ac8e4f90e8f67e488ccbb9194eac55c43 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Fri, 12 May 2023 19:00:35 +0900 Subject: [PATCH 25/28] Flash Attention fp16 backward version 2 --- sparta/kernels/attention.py | 2 +- ...h_sparse_attention_backward_float16.cuh.j2 | 24 +++++++++---------- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py index 857cd532..94def595 100644 --- a/sparta/kernels/attention.py +++ b/sparta/kernels/attention.py @@ -294,5 +294,5 @@ def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): class FlashSparseAttentionFP16BackwardKernel(FlashSparseAttentionFP16Kernel, FlashSparseAttentionBackwardKernel): def _calc_shared_mem_size(self, Bs: int, Bt: int, D: int): - shared = 2 * (2 * Bt * (D + 8) + 4 * (Bs * (D + 8)) + 2 * Bt * (Bs + 8)) + shared = 2 * (2 * Bt * (D + 8) + 4 * (Bs * (D + 8)) + Bt * (Bs + 8)) return shared diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 index d29b0aad..6f7f90b9 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 @@ -80,15 +80,13 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( extern __shared__ half shared[]; half* shared_Q = &shared[0]; half* shared_P = &shared_Q[BT * (D + D_PAD)]; - half* shared_S = &shared_P[BT * (BS + S_PAD)]; - half* shared_K = &shared_S[BT * (BS + S_PAD)]; + half* shared_K = &shared_P[BT * (BS + S_PAD)]; half* shared_V = &shared_K[BS * (D + D_PAD)]; half* shared_O = &shared_V[BS * (D + D_PAD)]; half* shared_dK = &shared_O[BT * (D + D_PAD)]; half* shared_dV = &shared_dK[BS * (D + D_PAD)]; {# __shared__ half shared_Q[BT * (D + D_PAD)]; __shared__ half shared_P[BT * (BS + S_PAD)]; - __shared__ half shared_S[BT * (BS + S_PAD)]; __shared__ half shared_K[BS * (D + D_PAD)]; __shared__ half shared_V[BS * (D + D_PAD)]; __shared__ half shared_O[BT * (D + D_PAD)]; @@ -263,7 +261,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( wmma::mma_sync(frag_ts_c, frag_ts_a, frag_ts_b, frag_ts_c); } wmma::store_matrix_sync( - &shared_S[(j + ts_wy * TS_WARP_M) * (BS + S_PAD) + ts_wx * TS_WARP_N], + &shared_P[(j + ts_wy * TS_WARP_M) * (BS + S_PAD) + ts_wx * TS_WARP_N], frag_ts_c, BS + S_PAD, wmma::mem_row_major @@ -272,16 +270,16 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( __syncthreads(); // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { - // printf("dS[0][0] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 0])); - // printf("dS[0][1] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 1])); - // printf("dS[1][0] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 0])); - // printf("dS[1][1] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 1])); + // printf("dS[0][0] = %f\n", (float)(shared_P[0 * (BS + S_PAD) + 0])); + // printf("dS[0][1] = %f\n", (float)(shared_P[0 * (BS + S_PAD) + 1])); + // printf("dS[1][0] = %f\n", (float)(shared_P[1 * (BS + S_PAD) + 0])); + // printf("dS[1][1] = %f\n", (float)(shared_P[1 * (BS + S_PAD) + 1])); // } {# Load dS #} #pragma unroll for (int i = 
0; i < T; i += 8) { - *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_S[ty * (BS + S_PAD) + tx * T + i])); + *((float4*)(&tmp_half8[0])) = *((float4*)(&shared_P[ty * (BS + S_PAD) + tx * T + i])); #pragma unroll for (int j = 0; j < 8; j++) { frag_P[i + j] = __half2float(tmp_half8[j]); @@ -320,10 +318,10 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( } // if (blockIdx.x == 0 && threadIdx.x == 0 && row_idx == 0 && col_idx == 0) { - // printf("dP[0][0] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 0])); - // printf("dP[0][1] = %f\n", (float)(shared_S[0 * (BS + S_PAD) + 1])); - // printf("dP[1][0] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 0])); - // printf("dP[1][1] = %f\n", (float)(shared_S[1 * (BS + S_PAD) + 1])); + // printf("dP[0][0] = %f\n", (float)(shared_P[0 * (BS + S_PAD) + 0])); + // printf("dP[0][1] = %f\n", (float)(shared_P[0 * (BS + S_PAD) + 1])); + // printf("dP[1][0] = %f\n", (float)(shared_P[1 * (BS + S_PAD) + 0])); + // printf("dP[1][1] = %f\n", (float)(shared_P[1 * (BS + S_PAD) + 1])); // } {# Load dQ #} From 1f636e118b834187ac4aefc6ddc28039973a7a72 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 16 May 2023 15:04:54 +0900 Subject: [PATCH 26/28] update Flash Attention kernel: transpose (N, H) is not required --- sparta/kernels/attention.py | 18 +++--- sparta/kernels/kernel_base.py | 4 +- ...h_sparse_attention_backward_float16.cuh.j2 | 60 ++++++++++++------- ...sh_sparse_attention_forward_float16.cuh.j2 | 42 ++++++++----- sparta/testing/math.py | 40 +++++++++---- 5 files changed, 106 insertions(+), 58 deletions(-) diff --git a/sparta/kernels/attention.py b/sparta/kernels/attention.py index 94def595..65bc555f 100644 --- a/sparta/kernels/attention.py +++ b/sparta/kernels/attention.py @@ -24,17 +24,19 @@ class FlashSparseAttentionKernel(KernelBase): - __lut_shape__ = (64 * 12, 1024, 1024, 64) # BxH, Nt, Ns, D + __lut_shape__ = (64, 1024, 1024, 12, 64) # B, Nt, Ns, H, D __algo__ = 'flash' __dtype__ = '' __direction__ = '' - def __init__(self, buffer: torch.Tensor): + def __init__(self, buffer: torch.Tensor, transposed: bool = False): self._buffer = buffer + self._transposed = transposed super().__init__() def _add_parameters(self): self._add_parameter('GLOBAL_SIZE_D_VALUE') + self._add_parameter('TRANSPOSED', value=self._transposed) self._add_parameter( 'BLOCK_SIZE_S_VALUE', is_tunable=True, @@ -193,8 +195,8 @@ class FlashSparseAttentionForwardKernel(FlashSparseAttentionKernel): __direction__ = 'forward' - def set_kernel_call(self, shape: Tuple[int, int, int, int]): - batch, Nt, Ns, D = shape + def set_kernel_call(self, shape: Tuple[int, int, int, int, int]): + batch, Nt, Ns, heads, D = shape self._check_shape(Nt, Ns, D) Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) block = self.threads_per_block() @@ -211,7 +213,7 @@ def attn_func(Q, K, V): Ns_32, Nt_32, # D_32, self.attr.indexes.nnz, block=block, - grid=(Q.shape[0], 1, 1), + grid=(heads, Q.shape[0], 1), shared=shared, ) return O @@ -232,8 +234,8 @@ class FlashSparseAttentionBackwardKernel(FlashSparseAttentionKernel): __direction__ = 'backward' - def set_kernel_call(self, shape: Tuple[int, int, int, int]): - batch, Nt, Ns, D = shape + def set_kernel_call(self, shape: Tuple[int, int, int, int, int]): + batch, Nt, Ns, heads, D = shape self._check_shape(Nt, Ns, D) Ns_32, Nt_32, D_32 = np.int32(Ns), np.int32(Nt), np.int32(D) block = self.threads_per_block() @@ -251,7 +253,7 @@ def attn_func(grad, O, Q, K, V): Ns_32, Nt_32, # D_32, self.attr.indexes.nnz, block=block, - grid=(Q.shape[0], 1, 1), 
+ grid=(heads, Q.shape[0], 1), shared=shared, ) return grad_Q, grad_K, grad_V diff --git a/sparta/kernels/kernel_base.py b/sparta/kernels/kernel_base.py index f794f13b..1ac879b3 100644 --- a/sparta/kernels/kernel_base.py +++ b/sparta/kernels/kernel_base.py @@ -242,8 +242,10 @@ def compile(self, params: Dict[str, Any], shape: Tuple): self.attr.update_block_size(self.id) self._kernel = self._build_kernel(self.get_kernel_code()) self.set_kernel_call(shape) + self._estimate_latency(shape) self.ready = True - # Calc estimated latency + + def _estimate_latency(self, shape: Tuple): indexes = self.attr.indexes sparse_rate = indexes.block_nnz / indexes.row_num / indexes.col_num shape_rate = np.prod(shape) / np.prod(self.__lut_shape__) diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 index 6f7f90b9..55fdef77 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float16.cuh.j2 @@ -28,7 +28,6 @@ const int SD_WARP_M = {{ SD_WARP_SIZE_M_VALUE }}; const int SD_WARP_N = {{ SD_WARP_SIZE_N_VALUE }}; const int SD_WARP_K = 16; -const int B = {{ BLOCK_SIZE }}; const int T = {{ THREAD_SIZE }}; const int THREADS = {{ THREADS_PER_BLOCK }};{# THREADS_PER_BLOCK >= WARP_SIZE #} const int WARPS = THREADS / {{ WARP_SIZE }}; @@ -64,15 +63,30 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( uint Nt, uint block_nnz ) { - Q += Nt * D * blockIdx.x; - K += Ns * D * blockIdx.x; - V += Ns * D * blockIdx.x; - O += Nt * D * blockIdx.x; - dQ += Nt * D * blockIdx.x; - dK += Ns * D * blockIdx.x; - dV += Ns * D * blockIdx.x; - dO += Nt * D * blockIdx.x; - ML += Nt * 2 * blockIdx.x; + int H = gridDim.x; + int HEAD_IDX = (blockIdx.y * H + blockIdx.x); + {% if TRANSPOSED %} + Q += HEAD_IDX * Nt * D; + K += HEAD_IDX * Ns * D; + V += HEAD_IDX * Ns * D; + O += HEAD_IDX * Nt * D; + dQ += HEAD_IDX * Nt * D; + dK += HEAD_IDX * Ns * D; + dV += HEAD_IDX * Ns * D; + dO += HEAD_IDX * Nt * D; + int stride = D; + {% else %} + Q += blockIdx.y * Nt * H * D + blockIdx.x * D; + K += blockIdx.y * Ns * H * D + blockIdx.x * D; + V += blockIdx.y * Ns * H * D + blockIdx.x * D; + O += blockIdx.y * Nt * H * D + blockIdx.x * D; + dQ += blockIdx.y * Nt * H * D + blockIdx.x * D; + dK += blockIdx.y * Ns * H * D + blockIdx.x * D; + dV += blockIdx.y * Ns * H * D + blockIdx.x * D; + dO += blockIdx.y * Nt * H * D + blockIdx.x * D; + int stride = H * D; + {% endif %} + ML += Nt * 2 * HEAD_IDX; uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; @@ -138,16 +152,16 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&Q[(row_idx * BT + k) * stride + SMEM_TID_D])); } if (col_idx != last_col_idx) { if (last_col_idx >= 0) { {# Save dK, dV #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&dK[(last_col_idx * BS + k) * stride + SMEM_TID_D])) = *((float4*)(&shared_dK[k * (D + D_PAD) + SMEM_TID_D])); - *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&dV[(last_col_idx * BS + k) * stride + SMEM_TID_D])) = 
*((float4*)(&shared_dV[k * (D + D_PAD) + SMEM_TID_D])); } } @@ -155,13 +169,13 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { *((float4*)(&shared_dK[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&dK[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&dK[(col_idx * BS + k) * stride + SMEM_TID_D])); *((float4*)(&shared_dV[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&dV[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&dV[(col_idx * BS + k) * stride + SMEM_TID_D])); *((float4*)(&shared_K[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&K[(col_idx * BS + k) * stride + SMEM_TID_D])); *((float4*)(&shared_V[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&V[(col_idx * BS + k) * stride + SMEM_TID_D])); } last_col_idx = col_idx; } @@ -222,7 +236,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { *((float4*)(&shared_O[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&dO[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&dO[(row_idx * BT + k) * stride + SMEM_TID_D])); } __syncthreads(); @@ -290,7 +304,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( row_sum = 0.0f; #pragma unroll for (int i = 0; i < SD; i += 8) { - *((float4*)(&tmp_half8[0])) = *((float4*)(&O[(row_idx * BT + ty) * D + tx * SD + i])); + *((float4*)(&tmp_half8[0])) = *((float4*)(&O[(row_idx * BT + ty) * stride + tx * SD + i])); *((float4*)(&tmp_half8_2[0])) = *((float4*)(&shared_O[ty * (D + D_PAD) + tx * SD + i])); #pragma unroll for (int j = 0; j < 8; j++) { @@ -328,7 +342,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { *((float4*)(&shared_O[k * (D + D_PAD) + SMEM_TID_D])) = - *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&dQ[(row_idx * BT + k) * stride + SMEM_TID_D])); } __syncthreads(); @@ -383,7 +397,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Save dQ #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])) = + *((float4*)(&dQ[(row_idx * BT + k) * stride + SMEM_TID_D])) = *((float4*)(&shared_O[k * (D + D_PAD) + SMEM_TID_D])); } } @@ -391,9 +405,9 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Save dK, dV for the last column #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&dK[(last_col_idx * BS + k) * stride + SMEM_TID_D])) = *((float4*)(&shared_dK[k * (D + D_PAD) + SMEM_TID_D])); - *((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D])) = + *((float4*)(&dV[(last_col_idx * BS + k) * stride + SMEM_TID_D])) = *((float4*)(&shared_dV[k * (D + D_PAD) + SMEM_TID_D])); } } diff --git a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 index 56d39ca8..3bc6c969 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float16.cuh.j2 @@ -24,7 +24,6 @@ const int TD_WARP_M = {{ TD_WARP_SIZE_M_VALUE }}; const int TD_WARP_N = {{ TD_WARP_SIZE_N_VALUE }}; const int TD_WARP_K = 16; -const int B = {{ BLOCK_SIZE }}; const int T = {{ THREAD_SIZE }}; const int 
THREADS = {{ THREADS_PER_BLOCK }};{# THREADS_PER_BLOCK >= WARP_SIZE #} const int WARPS = THREADS / {{ WARP_SIZE }}; @@ -54,11 +53,22 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( uint Nt, uint block_nnz ) { - Q += Nt * D * blockIdx.x; - K += Ns * D * blockIdx.x; - V += Ns * D * blockIdx.x; - O += Nt * D * blockIdx.x; - ML += Nt * 2 * blockIdx.x; + int H = gridDim.x; + int HEAD_IDX = (blockIdx.y * H + blockIdx.x); + {% if TRANSPOSED %} + Q += HEAD_IDX * Nt * D; + K += HEAD_IDX * Ns * D; + V += HEAD_IDX * Ns * D; + O += HEAD_IDX * Nt * D; + int stride = D; + {% else %} + Q += blockIdx.y * Nt * H * D + blockIdx.x * D; + K += blockIdx.y * Ns * H * D + blockIdx.x * D; + V += blockIdx.y * Ns * H * D + blockIdx.x * D; + O += blockIdx.y * Nt * H * D + blockIdx.x * D; + int stride = H * D; + {% endif %} + ML += Nt * 2 * HEAD_IDX; uint WARP_OFFSET = ((threadIdx.x / {{ WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}) % {{ WARP_SIZE }}; uint WARP_MASK = 0b{% for _ in range(WARP_REDUCE_SIZE) %}1{% endfor %} << WARP_OFFSET; @@ -114,18 +124,17 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Load Q #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&Q[(row_idx * BT + k) * stride + SMEM_TID_D])); } if (col_idx != last_col_idx) { - {# Load K #} + {# Load K, V #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&shared_K[k * (D + D_PAD) + SMEM_TID_D])) = *((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D])); - } - {# Load V #} - #pragma unroll - for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&shared_V[k * (D + D_PAD) + SMEM_TID_D])) = *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_K[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&K[(col_idx * BS + k) * stride + SMEM_TID_D])); + *((float4*)(&shared_V[k * (D + D_PAD) + SMEM_TID_D])) = + *((float4*)(&V[(col_idx * BS + k) * stride + SMEM_TID_D])); } last_col_idx = col_idx; } @@ -242,7 +251,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Load O #} #pragma unroll for (int i = 0; i < SD; i += 8) { - *((float4*)(&tmp_half8[0])) = *((float4*)(&O[(row_idx * BT + ty) * D + tx * SD + i])); + *((float4*)(&tmp_half8[0])) = *((float4*)(&O[(row_idx * BT + ty) * stride + tx * SD + i])); #pragma unroll for (int j = 0; j < 8; j++) { tmp_half8[j] = __float2half(__half2float(tmp_half8[j]) * row_coef); @@ -278,7 +287,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_FP16( {# Save O #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])); + *((float4*)(&O[(row_idx * BT + k) * stride + SMEM_TID_D])) = + *((float4*)(&shared_Q[k * (D + D_PAD) + SMEM_TID_D])); } } } diff --git a/sparta/testing/math.py b/sparta/testing/math.py index 913d954a..46cb0354 100644 --- a/sparta/testing/math.py +++ b/sparta/testing/math.py @@ -56,27 +56,36 @@ def sparse_multi_head_attention_forward_reference( value: torch.Tensor, mask: torch.Tensor, temperature: float = np.nan, + transposed: bool = False, ) -> torch.Tensor: r"""Sparse multi-head attention reference function with batch size :math:`B`, head number :math:`H`, sourse sequence length :math:`N_{source}`, target sequence length :math:`N_{target}` and embed dimention :math:`E`. 
Args: - query (torch.Tensor): The input query tensor of shape :math:`(B, H, N_{target}, E)`. - key (torch.Tensor): The input key tensor of shape :math:`(B, H, N_{source}, E)`. - value (torch.Tensor): The input value tensor of shape :math:`(B, H, N_{source}, E)`. + query (torch.Tensor): The input query tensor of shape :math:`(B, N_{target}, H, E)`. + key (torch.Tensor): The input key tensor of shape :math:`(B, N_{source}, H, E)`. + value (torch.Tensor): The input value tensor of shape :math:`(B, N_{source}, H, E)`. mask (torch.Tensor): The mask tensor of shape :math:`(N_{target}, N_{source})`. temperature (float): The softmax temperature which is set to :math:`\sqrt{N_{source}}` by default. + transposed (bool): If true, the head dimension and the sequence length dimension are transposed. Returns: - torch.Tensor: Sparse multi-head attention output of shape :math:`(B, H, N_{target}, E)`. + torch.Tensor: Sparse multi-head attention output of shape :math:`(B, N_{target}, H, E)`. """ if np.isnan(temperature): temperature = np.sqrt(key.shape[-1]) + if not transposed: + query = query.swapaxes(-2, -3) + key = key.swapaxes(-2, -3) + value = value.swapaxes(-2, -3) high_dims = ''.join([chr(ord('a') + i) for i in range(len(query.shape) - 2)]) p = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) s = sparse_softmax_forward_reference(p, mask, temperature) - return torch.einsum(f'{high_dims}mn, {high_dims}nk -> {high_dims}mk', s, value) + out = torch.einsum(f'{high_dims}mn, {high_dims}nk -> {high_dims}mk', s, value) + if not transposed: + out = out.swapaxes(-2, -3) + return out def sparse_multi_head_attention_backward_reference( @@ -87,23 +96,30 @@ def sparse_multi_head_attention_backward_reference( value: torch.Tensor, mask: torch.Tensor, temperature: float = np.nan, + transposed: bool = False, ) -> torch.Tensor: r"""Sparse multi-head attention backward reference function. Args: - grad (torch.Tensor): The gradient of output tensor. Shape: :math:`(B, H, N_{target}, E)`. - output (torch.Tensor): The output tensor of forward function. Shape: :math:`(B, H, N_{target}, E)`. - query (torch.Tensor): The input query tensor of forward function. Shape: :math:`(B, H, N_{target}, E)`. - key (torch.Tensor): The input key tensor of forward function. Shape: :math:`(B, H, N_{source}, E)`. - value (torch.Tensor): The input value tensor of forward function. Shape: :math:`(B, H, N_{source}, E)`. + grad (torch.Tensor): The gradient of output tensor. Shape: :math:`(B, N_{target}, H, E)`. + output (torch.Tensor): The output tensor of forward function. Shape: :math:`(B, N_{target}, H, E)`. + query (torch.Tensor): The input query tensor of forward function. Shape: :math:`(B, N_{target}, H, E)`. + key (torch.Tensor): The input key tensor of forward function. Shape: :math:`(B, N_{source}, H, E)`. + value (torch.Tensor): The input value tensor of forward function. Shape: :math:`(B, N_{source}, H, E)`. mask (torch.Tensor): The mask tensor. Shape :math:`(N_{target}, N_{source})`. temperature (float): The softmax temperature which is set to :math:`\sqrt{N_{source}}` by default. + transposed (bool): If true, the head dimension and the sequence length dimension are transposed. Returns: Tuple: The gradient of query, key and value respectively.. 
""" if np.isnan(temperature): temperature = np.sqrt(key.shape[-1]) + if not transposed: + grad = grad.swapaxes(-2, -3) + query = query.swapaxes(-2, -3) + key = key.swapaxes(-2, -3) + value = value.swapaxes(-2, -3) high_dims = ''.join([chr(ord('a') + i) for i in range(len(query.shape) - 2)]) p = torch.einsum(f'{high_dims}mk, {high_dims}nk -> {high_dims}mn', query, key) s = sparse_softmax_forward_reference(p, mask, temperature) @@ -112,4 +128,8 @@ def sparse_multi_head_attention_backward_reference( grad_p = sparse_softmax_backward_reference(grad_s, s, mask, temperature) grad_q = torch.einsum(f'{high_dims}nk, {high_dims}mn -> {high_dims}mk', key, grad_p) grad_k = torch.einsum(f'{high_dims}mk, {high_dims}mn -> {high_dims}nk', query, grad_p) + if not transposed: + grad_q = grad_q.swapaxes(-2, -3) + grad_k = grad_k.swapaxes(-2, -3) + grad_v = grad_v.swapaxes(-2, -3) return grad_q, grad_k, grad_v From d4027ba8d333fefa999bce0a71fefe1faf73e99e Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Tue, 16 May 2023 15:37:51 +0900 Subject: [PATCH 27/28] update Flash Attention FP32 kernel: transpose (N, H) is not required --- ...h_sparse_attention_backward_float32.cuh.j2 | 63 ++++++++++++------- ...sh_sparse_attention_forward_float32.cuh.j2 | 38 +++++++---- 2 files changed, 67 insertions(+), 34 deletions(-) diff --git a/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 index 27e7eab2..c29eb3b5 100644 --- a/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_backward_float32.cuh.j2 @@ -60,15 +60,30 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( uint Nt, uint block_nnz ) { - Q += Nt * D * blockIdx.x; - K += Ns * D * blockIdx.x; - V += Ns * D * blockIdx.x; - O += Nt * D * blockIdx.x; - dQ += Nt * D * blockIdx.x; - dK += Ns * D * blockIdx.x; - dV += Ns * D * blockIdx.x; - dO += Nt * D * blockIdx.x; - ML += Nt * 2 * blockIdx.x; + int H = gridDim.x; + int HEAD_IDX = (blockIdx.y * H + blockIdx.x); + {% if TRANSPOSED %} + Q += HEAD_IDX * Nt * D; + K += HEAD_IDX * Ns * D; + V += HEAD_IDX * Ns * D; + O += HEAD_IDX * Nt * D; + dQ += HEAD_IDX * Nt * D; + dK += HEAD_IDX * Ns * D; + dV += HEAD_IDX * Ns * D; + dO += HEAD_IDX * Nt * D; + int stride = D; + {% else %} + Q += blockIdx.y * Nt * H * D + blockIdx.x * D; + K += blockIdx.y * Ns * H * D + blockIdx.x * D; + V += blockIdx.y * Ns * H * D + blockIdx.x * D; + O += blockIdx.y * Nt * H * D + blockIdx.x * D; + dQ += blockIdx.y * Nt * H * D + blockIdx.x * D; + dK += blockIdx.y * Ns * H * D + blockIdx.x * D; + dV += blockIdx.y * Ns * H * D + blockIdx.x * D; + dO += blockIdx.y * Nt * H * D + blockIdx.x * D; + int stride = H * D; + {% endif %} + ML += Nt * 2 * HEAD_IDX; uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; @@ -116,7 +131,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Load Q #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = + *((float4*)(&Q[(row_idx * BT + k) * stride + SMEM_TID_D])); } if (col_idx != last_col_idx) { if (last_col_idx >= 0) { @@ -127,33 +143,33 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( tmp_float4.y = shared_dK[(SMEM_TID_D+1) * BS + 
k]; tmp_float4.z = shared_dK[(SMEM_TID_D+2) * BS + k]; tmp_float4.w = shared_dK[(SMEM_TID_D+3) * BS + k]; - ((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = tmp_float4; + ((float4*)(&dK[(last_col_idx * BS + k) * stride + SMEM_TID_D]))[0] = tmp_float4; tmp_float4.x = shared_dV[(SMEM_TID_D+0) * BS + k]; tmp_float4.y = shared_dV[(SMEM_TID_D+1) * BS + k]; tmp_float4.z = shared_dV[(SMEM_TID_D+2) * BS + k]; tmp_float4.w = shared_dV[(SMEM_TID_D+3) * BS + k]; - ((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = tmp_float4; + ((float4*)(&dV[(last_col_idx * BS + k) * stride + SMEM_TID_D]))[0] = tmp_float4; } } {# Load K, V, dK, dV #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - tmp_float4 = ((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + tmp_float4 = ((float4*)(&K[(col_idx * BS + k) * stride + SMEM_TID_D]))[0]; shared_K[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; shared_K[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_K[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; shared_K[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; - tmp_float4 = ((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + tmp_float4 = ((float4*)(&V[(col_idx * BS + k) * stride + SMEM_TID_D]))[0]; shared_V[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; shared_V[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_V[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; shared_V[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; - tmp_float4 = ((float4*)(&dK[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + tmp_float4 = ((float4*)(&dK[(col_idx * BS + k) * stride + SMEM_TID_D]))[0]; shared_dK[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; shared_dK[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_dK[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; shared_dK[(SMEM_TID_D+3) * BS + k] = tmp_float4.w; - tmp_float4 = ((float4*)(&dV[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + tmp_float4 = ((float4*)(&dV[(col_idx * BS + k) * stride + SMEM_TID_D]))[0]; shared_dV[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; shared_dV[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_dV[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; @@ -236,7 +252,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Load dO #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_O[k * D + SMEM_TID_D])) = *((float4*)(&dO[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&shared_O[k * D + SMEM_TID_D])) = + *((float4*)(&dO[(row_idx * BT + k) * stride + SMEM_TID_D])); } __syncthreads(); @@ -336,7 +353,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( ((float4*)(&shared_O[k * D + SMEM_TID_D]))[0] = _mul_float4( ((float4*)(&shared_O[k * D + SMEM_TID_D]))[0], - ((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D]))[0] + ((float4*)(&O[(row_idx * BT + k) * stride + SMEM_TID_D]))[0] ); } __syncthreads(); @@ -362,7 +379,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Load dQ #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_O[k * D + SMEM_TID_D])) = *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&shared_O[k * D + SMEM_TID_D])) = + *((float4*)(&dQ[(row_idx * BT + k) * stride + SMEM_TID_D])); } __syncthreads(); @@ -485,7 +503,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( {# Save dQ #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&dQ[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_O[k * D + SMEM_TID_D])); + *((float4*)(&dQ[(row_idx * BT + k) * stride + SMEM_TID_D])) = + 
*((float4*)(&shared_O[k * D + SMEM_TID_D])); } } @@ -496,11 +515,11 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION_BACKWARD( tmp_float4.y = shared_dK[(SMEM_TID_D+1) * BS + k]; tmp_float4.z = shared_dK[(SMEM_TID_D+2) * BS + k]; tmp_float4.w = shared_dK[(SMEM_TID_D+3) * BS + k]; - ((float4*)(&dK[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = tmp_float4; + ((float4*)(&dK[(last_col_idx * BS + k) * stride + SMEM_TID_D]))[0] = tmp_float4; tmp_float4.x = shared_dV[(SMEM_TID_D+0) * BS + k]; tmp_float4.y = shared_dV[(SMEM_TID_D+1) * BS + k]; tmp_float4.z = shared_dV[(SMEM_TID_D+2) * BS + k]; tmp_float4.w = shared_dV[(SMEM_TID_D+3) * BS + k]; - ((float4*)(&dV[(last_col_idx * BS + k) * D + SMEM_TID_D]))[0] = tmp_float4; + ((float4*)(&dV[(last_col_idx * BS + k) * stride + SMEM_TID_D]))[0] = tmp_float4; } } diff --git a/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 index 0f00769e..e2943aa2 100644 --- a/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 +++ b/sparta/kernels/templates/flash_sparse_attention_forward_float32.cuh.j2 @@ -64,11 +64,22 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( uint Nt, uint block_nnz ) { - Q += Nt * D * blockIdx.x; - K += Ns * D * blockIdx.x; - V += Ns * D * blockIdx.x; - O += Nt * D * blockIdx.x; - ML += Nt * 2 * blockIdx.x; + int H = gridDim.x; + int HEAD_IDX = (blockIdx.y * H + blockIdx.x); + {% if TRANSPOSED %} + Q += HEAD_IDX * Nt * D; + K += HEAD_IDX * Ns * D; + V += HEAD_IDX * Ns * D; + O += HEAD_IDX * Nt * D; + int stride = D; + {% else %} + Q += blockIdx.y * Nt * H * D + blockIdx.x * D; + K += blockIdx.y * Ns * H * D + blockIdx.x * D; + V += blockIdx.y * Ns * H * D + blockIdx.x * D; + O += blockIdx.y * Nt * H * D + blockIdx.x * D; + int stride = H * D; + {% endif %} + ML += Nt * 2 * HEAD_IDX; uint WARP_OFFSET = (threadIdx.y % {{ 32 // WARP_REDUCE_SIZE }}) * {{ WARP_REDUCE_SIZE }}; uint WARP_MASK = 0x{% for _ in range(WARP_REDUCE_SIZE // 4) %}f{% endfor %} << WARP_OFFSET; @@ -115,13 +126,14 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( {# Load Q #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = *((float4*)(&Q[(row_idx * BT + k) * D + SMEM_TID_D])); + *((float4*)(&shared_Q[k * D + SMEM_TID_D])) = + *((float4*)(&Q[(row_idx * BT + k) * stride + SMEM_TID_D])); } if (col_idx != last_col_idx) { {# Load K #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - tmp_float4 = ((float4*)(&K[(col_idx * BS + k) * D + SMEM_TID_D]))[0]; + tmp_float4 = ((float4*)(&K[(col_idx * BS + k) * stride + SMEM_TID_D]))[0]; shared_K[(SMEM_TID_D+0) * BS + k] = tmp_float4.x; shared_K[(SMEM_TID_D+1) * BS + k] = tmp_float4.y; shared_K[(SMEM_TID_D+2) * BS + k] = tmp_float4.z; @@ -130,7 +142,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( {# Load V #} #pragma unroll for (int k = SMEM_TID_N; k < BS; k += SMEM_THREADS_N) { - *((float4*)(&shared_V[k * D + SMEM_TID_D])) = *((float4*)(&V[(col_idx * BS + k) * D + SMEM_TID_D])); + *((float4*)(&shared_V[k * D + SMEM_TID_D])) = + *((float4*)(&V[(col_idx * BS + k) * stride + SMEM_TID_D])); } last_col_idx = col_idx; } @@ -263,11 +276,11 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( for (int jt = 0; jt < TT; jt++) { {% if THREAD_SIZE_S_TO_D == 1 %} shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD] = - O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD] * frag_ML[jt]; + O[(row_idx * BT + threadIdx.y + 
blockDim.y * jt) * stride + threadIdx.x * SD] * frag_ML[jt]; {% elif THREAD_SIZE_S_TO_D == 2 %} ((float2*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD]))[0] = _scale_float2( - ((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD]))[0], + ((float2*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * stride + threadIdx.x * SD]))[0], frag_ML[jt] ); {% else %} @@ -275,7 +288,7 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( for (int i = 0; i < SD; i += 4) { ((float4*)(&shared_Q[(threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD + i]))[0] = _scale_float4( - ((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * D + threadIdx.x * SD + i]))[0], + ((float4*)(&O[(row_idx * BT + threadIdx.y + blockDim.y * jt) * stride + threadIdx.x * SD + i]))[0], frag_ML[jt] ); } @@ -343,7 +356,8 @@ __global__ void BLOCK_SPARSE_FLASH_ATTENTION( {# Save O #} #pragma unroll for (int k = SMEM_TID_N; k < BT; k += SMEM_THREADS_N) { - *((float4*)(&O[(row_idx * BT + k) * D + SMEM_TID_D])) = *((float4*)(&shared_Q[k * D + SMEM_TID_D])); + *((float4*)(&O[(row_idx * BT + k) * stride + SMEM_TID_D])) = + *((float4*)(&shared_Q[k * D + SMEM_TID_D])); } } } From 5bcae7d7a4f39ea6c10071611593dd71e80d26b5 Mon Sep 17 00:00:00 2001 From: Chengruidong Zhang Date: Sun, 18 Feb 2024 09:08:48 +0000 Subject: [PATCH 28/28] Fix sparse kernel unit tests --- setup.py | 8 +-- sparta/operators/operator_base.py | 3 + sparta/operators/sparse_moe.py | 6 +- sparta/operators/sparse_seqlen_attention.py | 4 +- sparta/testing/math.py | 4 +- sparta/tuning/__init__.py | 2 +- sparta/tuning/tunable.py | 65 ++++----------------- test/bench/matmul/matmul.py | 10 ++-- test/bench/matmul/sparta_params.csv | 2 +- test/tmp.py | 29 +++++++++ test/unit/test_seqlen_attention.py | 9 +-- test/unit/test_sparse_attention.py | 4 +- 12 files changed, 70 insertions(+), 76 deletions(-) create mode 100644 test/tmp.py diff --git a/setup.py b/setup.py index f3849b07..e76d6994 100644 --- a/setup.py +++ b/setup.py @@ -23,13 +23,13 @@ f.write(moe_kernel) moe_ext = CUDAExtension( - name='sparse_moe_cpp', + name='sparta.sp_moe_ops', sources=[ os.path.join('csrc', 'moe_sparse_forward.cpp'), os.path.join('csrc', 'build', 'moe_sparse_forward_kernel.cu'), ], extra_compile_args=[ - '-std=c++14', + '-std=c++17', '-O3', '-U__CUDA_NO_HALF_OPERATORS__', '-U__CUDA_NO_HALF_CONVERSIONS__', @@ -38,12 +38,12 @@ ext_modules.append(moe_ext) seqlen_dynamic_attention_ext = CUDAExtension( - name='seqlen_dynamic_sparse_attention_cpp', + name='sparta.sp_attn_ops', sources=[ os.path.join('csrc', 'seqlen_dynamic_sparse_attention_forward.cpp'), os.path.join('csrc', 'seqlen_dynamic_sparse_attention_forward_kernel.cu'), ], - extra_compile_args=['-std=c++14', '-O3'], + extra_compile_args=['-std=c++17', '-O3'], ) ext_modules.append(seqlen_dynamic_attention_ext) diff --git a/sparta/operators/operator_base.py b/sparta/operators/operator_base.py index 835338f7..5459b049 100644 --- a/sparta/operators/operator_base.py +++ b/sparta/operators/operator_base.py @@ -110,6 +110,9 @@ def _post_build(self): for port in self.ports.values(): port.clear_data() + def get_search_space(self): + pass + class SparseAutoGrad(SparseOperator): diff --git a/sparta/operators/sparse_moe.py b/sparta/operators/sparse_moe.py index cb954f37..b397930e 100644 --- a/sparta/operators/sparse_moe.py +++ b/sparta/operators/sparse_moe.py @@ -5,7 +5,7 @@ import torch -# import sparse_moe_cpp +from sparta import sp_moe_ops class DynamicSparseMoE(torch.nn.Module): @@ -24,8 +24,8 @@ def 
__init__(self, exp_modules: List[torch.nn.Linear]): self.expert_count = torch.zeros(self.num_exps, dtype=torch.int32, device=self.device) def forward(self, tokens: torch.Tensor, exp_ids: torch.Tensor): - sparse_moe_cpp.convert_index(exp_ids, self.sparse_index, self.expert_count) - return sparse_moe_cpp.forward( + sp_moe_ops.convert_index(exp_ids, self.sparse_index, self.expert_count) + return sp_moe_ops.forward( tokens, self.weight, exp_ids, diff --git a/sparta/operators/sparse_seqlen_attention.py b/sparta/operators/sparse_seqlen_attention.py index ace18acc..950faa8f 100644 --- a/sparta/operators/sparse_seqlen_attention.py +++ b/sparta/operators/sparse_seqlen_attention.py @@ -5,7 +5,7 @@ import torch -# import seqlen_dynamic_sparse_attention_cpp +from sparta import sp_attn_ops class SeqlenDynamicSparseAttentionFunction(torch.autograd.Function): @@ -21,7 +21,7 @@ def forward( H: int ): # ctx.save_for_backward() - return seqlen_dynamic_sparse_attention_cpp.forward(Q, K, V, inter_result, seqlens, H) + return sp_attn_ops.forward(Q, K, V, inter_result, seqlens, H) @staticmethod def backward(ctx, *grad_outputs): diff --git a/sparta/testing/math.py b/sparta/testing/math.py index 46cb0354..62cc4276 100644 --- a/sparta/testing/math.py +++ b/sparta/testing/math.py @@ -21,9 +21,9 @@ def sparse_softmax_forward_reference( torch.Tensor: The output tensor having the same shape with the input tensor. Notice that the return value on completely masked rows will be 0. """ - C_max = x.max(axis=-1).values.unsqueeze(-1) + C_max = x.max(axis=-1, keepdim=True).values C_exp = torch.exp((x - C_max) / temperature) * mask - C_exp_sum = C_exp.sum(axis=-1).unsqueeze(-1) + 1e-10 + C_exp_sum = C_exp.sum(axis=-1, keepdim=True) + 1e-10 return C_exp / C_exp_sum diff --git a/sparta/tuning/__init__.py b/sparta/tuning/__init__.py index cd761772..2ef2de95 100644 --- a/sparta/tuning/__init__.py +++ b/sparta/tuning/__init__.py @@ -1,5 +1,5 @@ # Copyright (c) Microsoft Corporation. # Licensed under the MIT license. -from sparta.tuning.tunable import Tunable, TunableItemCfg +from sparta.tuning.tunable import TunableItemCfg from sparta.tuning.tuners import Tuner, GridSearchTuner, RandomSearchTuner diff --git a/sparta/tuning/tunable.py b/sparta/tuning/tunable.py index 7dade30b..6ef483a7 100644 --- a/sparta/tuning/tunable.py +++ b/sparta/tuning/tunable.py @@ -16,18 +16,21 @@ class TunableItemCfg: cfg = TunableItemCfg('choice', _is_nested=True, _value={ 'openai': {}, 'sparta': { - 'BM': TunableItemCfg('choice', [32,64]), - 'BN': TunableItemCfg('choice', [8,16]), + 'BM': TunableItemCfg('choice', [32, 64]), + 'BN': TunableItemCfg('choice', [8, 16]), } }) nni_space = { - 'test': {'_type':'choice', '_value': [ - {'_name': 'openai'}, - { - '_name': 'sparta', - 'BM': {'_type': 'choice', '_value': [32,64]}, - 'BN': {'_type': 'choice', '_value': [8,16]}, - }]} + 'test': { + '_type': 'choice', '_value': [ + {'_name': 'openai'}, + { + '_name': 'sparta', + 'BM': {'_type': 'choice', '_value': [32, 64]}, + 'BN': {'_type': 'choice', '_value': [8, 16]}, + } + ] + } } # converted to a `NNI` search space (See more in https://nni.readthedocs.io/en/stable/hpo/search_space.html) @@ -71,47 +74,3 @@ def includes(self, value: Any): return False else: return value in self._value - - -class Tunable: - """The wrapper of NNI tuners that supports nested choice search space. 
- """ - - @staticmethod - def supported_tuners(): - from nni.algorithms.hpo.gridsearch_tuner import GridSearchTuner - from nni.algorithms.hpo.random_tuner import RandomTuner - from nni.algorithms.hpo.tpe_tuner import TpeTuner - from nni.algorithms.hpo.evolution_tuner import EvolutionTuner - return { - 'grid': GridSearchTuner, - 'rand': RandomTuner, - 'tpe': TpeTuner, - 'evolution': EvolutionTuner, - } - - @staticmethod - def create_tuner(algo: str, search_space_cfg: Dict[str, TunableItemCfg], tuner_kw: Dict = None): - """create NNI Tuner - - Args: - algo (str): tuning algorithm, allowed algo values and their corresponding tuners are: - - ========= =================================================== - algo tuner - ========= =================================================== - grid nni.algorithms.hpo.gridsearch_tuner.GridSearchTuner - rand nni.algorithms.hpo.random_tuner.RandomTuner - tpe nni.algorithms.hpo.tpe_tuner.TpeTuner - evolution nni.algorithms.hpo.evolution_tuner.EvolutionTuner - ========= =================================================== - - search_space_cfg (TunableItemCfg): search space config - tuner_kw (Dict): parameters passed to NNI tuner - """ - supported_tuners = Tunable.supported_tuners() - assert algo in supported_tuners, f'{algo} is not supported' - tuner_kw = tuner_kw or {} - tuner = supported_tuners[algo](**tuner_kw) - tuner.update_search_space({k: v.to_nni_search_space() if v else {} for k, v in search_space_cfg.items()}) - return tuner diff --git a/test/bench/matmul/matmul.py b/test/bench/matmul/matmul.py index e8c8243f..c8226a9c 100644 --- a/test/bench/matmul/matmul.py +++ b/test/bench/matmul/matmul.py @@ -124,14 +124,16 @@ def profile_sparta_matmul( return 0., 0. sparta_matmul = SparseBatchMatMul( - B_mask=mask, + mode='dsd', transpose_A=False, transpose_B=True, + biased=False, compressed=True, ) + sparta_matmul.set_mask(mask) sparta_matmul.build(config, sample_inputs=[data['A'], data['B']]) - indexes = sparta_matmul.get_sparse_indexes('B') + indexes = sparta_matmul.get_sparse_indexes() data['B'] = indexes.convert(data['B']) data['grad_B'] = indexes.convert(data['grad_B']) @@ -182,8 +184,8 @@ def profile_all(log_path: str, device: Any = 'cuda'): latency['dense'] = profile_dense_matmul(M, K, N, (g, g), s, device) config = get_sparta_config(sparta_configs, g, s) latency['sparta'] = profile_sparta_matmul(config, M, K, N, (g, g), s, device) - for block in [16, 32, 64]: - latency[f'triton-{block}'] = profile_triton_matmul(block, M, K, N, (g, g), s, device) + # for block in [16, 32, 64]: + # latency[f'triton-{block}'] = profile_triton_matmul(block, M, K, N, (g, g), s, device) with open(log_path, 'a') as f: for method, (lat_f, lat_b) in latency.items(): f.write(f'{method},{M},{K},{N},{g},{s},{lat_f},{lat_b}\n') diff --git a/test/bench/matmul/sparta_params.csv b/test/bench/matmul/sparta_params.csv index 45f9cf46..6c087c58 100644 --- a/test/bench/matmul/sparta_params.csv +++ b/test/bench/matmul/sparta_params.csv @@ -1,4 +1,4 @@ -granularity,sparsity,forward:C;BLOCK_SIZE_M_VALUE,forward:C;BLOCK_SIZE_K_VALUE,forward:C;BLOCK_SIZE_N_VALUE,backward:A;BLOCK_SIZE_M_VALUE,backward:A;BLOCK_SIZE_K_VALUE,backward:A;BLOCK_SIZE_N_VALUE,backward:B;BLOCK_SIZE_M_VALUE,backward:B;BLOCK_SIZE_K_VALUE,backward:B;BLOCK_SIZE_N_VALUE,forward:C;_impl,backward:A;_impl,backward:B;_impl 
+granularity,sparsity,forward;BLOCK_SIZE_M_VALUE,forward;BLOCK_SIZE_K_VALUE,forward;BLOCK_SIZE_N_VALUE,backward:A;BLOCK_SIZE_M_VALUE,backward:A;BLOCK_SIZE_K_VALUE,backward:A;BLOCK_SIZE_N_VALUE,backward:B;BLOCK_SIZE_M_VALUE,backward:B;BLOCK_SIZE_K_VALUE,backward:B;BLOCK_SIZE_N_VALUE,forward;_impl,backward:A;_impl,backward:B;_impl 1,0.0,0,0,0,64,32,64,32,16,64,openai,sparta,sparta 1,0.2,0,0,0,64,32,64,32,16,64,openai,sparta,sparta 1,0.4,0,0,0,64,32,64,32,16,64,openai,sparta,sparta diff --git a/test/tmp.py b/test/tmp.py new file mode 100644 index 00000000..314d8b13 --- /dev/null +++ b/test/tmp.py @@ -0,0 +1,29 @@ +import torch +import sparta + +batch_size, in_features, out_features = 1024, 1024, 1024 +sparsity = 0.9 +granularity = (8, 8) + +# prepare data +x = torch.rand((batch_size, in_features), device='cuda') +weight = torch.rand((out_features, in_features), device='cuda') +bias = torch.rand((out_features, ), device='cuda') + +# generate and apply weight mask +mask = sparta.testing.block_mask(weight.shape, granularity, sparsity, device='cuda') +weight = torch.mul(weight, mask) + +# create a dense operator +dense_linear = torch.nn.Linear(in_features, out_features, device='cuda') +dense_linear.load_state_dict({'weight': weight, 'bias': bias}) + +# create a sparse operator +sparse_linear = sparta.nn.SparseLinear(dense_linear) +sparse_linear.set_mask(mask) + +# tune the sparse operator +best_config = sparta.nn.tune(sparse_linear, sample_inputs=[x], max_trials=10, algo='rand') + +# check if the sparse operator runs correctly +torch.testing.assert_close(sparse_linear(x), dense_linear(x)) \ No newline at end of file diff --git a/test/unit/test_seqlen_attention.py b/test/unit/test_seqlen_attention.py index 8c8f6080..bacbd50c 100644 --- a/test/unit/test_seqlen_attention.py +++ b/test/unit/test_seqlen_attention.py @@ -39,11 +39,12 @@ def test_seqlen_attention_operator(B: int, H: int, N: int, E: int, global_mode: value = torch.rand(size=(B, H, N, E), dtype=torch.float32, device='cuda') target_out = sparse_multi_head_attention_forward_reference( - query=query.reshape((-1, N, E)), - key=key.reshape((-1, N, E)), - value=value.reshape((-1, N, E)), - mask=mask.tile((1, H, 1, 1)).reshape(-1, N, N), + query=query, + key=key, + value=value, + mask=mask.tile((1, H, 1, 1)), temperature=1.0, + transposed=True, ).reshape((B, H, N, E)) if global_mode: diff --git a/test/unit/test_sparse_attention.py b/test/unit/test_sparse_attention.py index f7b66265..a0c3434f 100644 --- a/test/unit/test_sparse_attention.py +++ b/test/unit/test_sparse_attention.py @@ -53,7 +53,7 @@ def test_sparse_attention_operator( sparsity: float = 0.95, ): torch.manual_seed(2022) - mask = block_mask((Nt, Ns), block=granularity, sparsity=sparsity, device='cuda') + mask = block_mask(shape=(Nt, Ns), granularity=granularity, sparsity=sparsity, device='cuda') query = torch.rand(size=(H, Nt, E), device='cuda') key = torch.rand(size=(H, Ns, E), device='cuda') value = torch.rand(size=(H, Ns, E), device='cuda') @@ -89,5 +89,5 @@ def test_sparse_attention_operator( torch.testing.assert_close(value.grad, target_grad_value, atol=1e-4, rtol=1e-8) torch.manual_seed(random_seed) - mask = block_mask((Nt, Ns), block=granularity, sparsity=sparsity, device='cuda') + mask = block_mask(shape=(Nt, Ns), granularity=granularity, sparsity=sparsity, device='cuda') sparse_attention.update_mask(mask)
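
Patches 26 and 27 switch the flash attention kernels and the reference helpers to consume tensors in (B, N_target, H, E) layout by default, keeping the old (B, H, N_target, E) layout reachable through the new transposed flag. The snippet below is a minimal sketch of how the two layouts are expected to agree; it assumes block_mask and sparse_multi_head_attention_forward_reference can be imported from sparta.testing, which the unit tests in patch 28 suggest but do not show explicitly.

import torch
from sparta.testing import block_mask, sparse_multi_head_attention_forward_reference

B, H, N, E = 2, 4, 256, 64

# Block-sparse attention mask over (N_target, N_source), as in the unit tests.
mask = block_mask(shape=(N, N), granularity=(32, 32), sparsity=0.9, device='cuda')

# New default layout (B, N, H, E): no (N, H) transpose is required any more.
query = torch.rand(B, N, H, E, device='cuda')
key = torch.rand(B, N, H, E, device='cuda')
value = torch.rand(B, N, H, E, device='cuda')
out = sparse_multi_head_attention_forward_reference(query, key, value, mask)

# Legacy layout (B, H, N, E) goes through the same math with transposed=True.
out_legacy = sparse_multi_head_attention_forward_reference(
    query.swapaxes(-2, -3),
    key.swapaxes(-2, -3),
    value.swapaxes(-2, -3),
    mask,
    transposed=True,
)

# Both layouts should describe the same attention output.
torch.testing.assert_close(out, out_legacy.swapaxes(-2, -3))

Either layout feeds the same einsum-based reference; only the swapaxes bookkeeping differs, which mirrors what the TRANSPOSED branch in the CUDA templates does with its per-head pointer offsets (stride = D versus stride = H * D).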