diff --git a/ddw/utils/subtomos.py b/ddw/utils/subtomos.py index 0fdcc9a..569b887 100644 --- a/ddw/utils/subtomos.py +++ b/ddw/utils/subtomos.py @@ -111,13 +111,11 @@ def get_linear_ramp_weights(subtomo_size, subtomo_overlap): weight_map_1d[-subtomo_overlap:] = ramp[::-1] # and at the end, inverted # Create a 3D weight map by extending the 1D weight map to 3 dimensions - weight_map_3d = np.ones((subtomo_size, subtomo_size, subtomo_size)) - for i in range(subtomo_size): - for j in range(subtomo_size): - for k in range(subtomo_size): - weight_map_3d[i, j, k] = ( - weight_map_1d[i] * weight_map_1d[j] * weight_map_1d[k] - ) + weight_map_3d = ( + weight_map_1d[:, None, None] + * weight_map_1d[None, :, None] + * weight_map_1d[None, None, :] + ) return torch.from_numpy(weight_map_3d)