diff --git a/.gitignore b/.gitignore index 6b46b4b..a37252d 100644 --- a/.gitignore +++ b/.gitignore @@ -209,4 +209,5 @@ __marimo__/ /data/train /data/test /profiler/trace.json -/logs/profiler \ No newline at end of file +/logs/profiler +/.vscode \ No newline at end of file diff --git a/compress/graph_extractor.py b/compress/graph_extractor.py index 1b684ed..6e324c7 100644 --- a/compress/graph_extractor.py +++ b/compress/graph_extractor.py @@ -1,4 +1,5 @@ from typing import Optional + import numpy as np import torch import torch.fx as fx @@ -6,11 +7,13 @@ from compress.heuristics import collect_convolution_layers_to_prune from compress.osscar_utils import ( - get_XtY, get_coeff_h, + get_optimal_W, + get_XtY, + recompute_H, + recompute_W, reshape_conv_layer_input, reshape_filter, - get_optimal_W, ) from model.resnet import resnet50 @@ -251,121 +254,318 @@ def get_all_subnets(prune_modules_name, graph_module): return prune_subnets, dense_subnets +def evaluate_loss(submatrix_w_pruned, dense_weights, subgram_xx, subgram_xy): + """ + Compute the OSSCAR-style loss for a candidate pruned set of channels. + + Parameters + ---------- + submatrix_w_pruned : torch.Tensor + Weight matrix corresponding to the pruned channels. + dense_weights : torch.Tensor + Original dense weight matrix of the layer. + subgram_xx : torch.Tensor + Gram matrix of pruned activations (X^T X for pruned channels). + subgram_xy : torch.Tensor + Cross Gram matrix between pruned activations and dense outputs (X^T Y). + + Returns + ------- + torch.Tensor + Scalar loss measuring reconstruction error for the pruned channels. + """ + A = (submatrix_w_pruned.T @ subgram_xx) @ submatrix_w_pruned + B = (dense_weights.T @ subgram_xy) @ dense_weights + return torch.trace(A) - 2 * torch.trace(B) + + def perform_local_search( - w_optimal, + dense_weights, layer, p, - prune_by_iter: Optional[list], - sym_diff_per_iter: Optional[list], + gram_xx, + gram_xy, + prune_by_iter: Optional[list] = None, + sym_diff_per_iter: Optional[list] = None, prune_per_iter=2, ): + """ + Greedy local search to select which input channels to prune in a Conv2d layer. + + Iteratively evaluates the importance of each input channel using Gram matrices + (X^T X and X^T Y) and prunes the least important channels according to the + specified schedule. + + Args: + dense_weights (torch.Tensor): Original dense weight matrix of the layer. + layer (nn.Conv2d): Convolutional layer to prune. + p (int): Total number of channels to prune. + gram_xx (torch.Tensor): Full Gram matrix of layer inputs (X^T X). + gram_xy (torch.Tensor): Cross Gram matrix between layer inputs and outputs (X^T Y). + prune_by_iter (list, optional): Custom pruning schedule per iteration. + sym_diff_per_iter (list, optional): Symmetric difference allowed per iteration. + prune_per_iter (int, default=2): Number of channels to prune per iteration + (ignored if `prune_by_iter` is provided). + + Returns: + keep_mask (torch.BoolTensor): Boolean mask of channels to keep (`True`) or prune (`False`). + kept_channels (set): Indices of channels retained after pruning. + removed_channels (set): Indices of channels removed during pruning. + """ assert isinstance(layer, nn.Conv2d) assert isinstance(p, int) and p > 0 + + # Determine pruning schedule if prune_by_iter is None: prune_list = [] num_iterations = p // prune_per_iter rem = p % prune_per_iter - zero_set = set() - if p >= prune_per_iter: - prune_list.extend([prune_per_iter for i in range(num_iterations)]) + prune_list.extend([prune_per_iter for _ in range(num_iterations)]) if rem > 0: - prune_list.extend([rem]) + prune_list.append(rem) else: - assert np.sum(np.array(prune_by_iter)) == p prune_list = prune_by_iter if sym_diff_per_iter is not None: assert len(prune_list) == len(sym_diff_per_iter) assert all(p <= t for p, t in zip(prune_list, sym_diff_per_iter)) - else: sym_diff_per_iter = prune_list - # for i in range(len(sym_diff_per_iter)): - # num_prune_iter = sym_diff_per_iter[] + kept_channels = set(range(layer.in_channels)) + removed_channels = set() + total_channels = kept_channels.copy() + keep_mask = torch.ones(layer.in_channels, dtype=torch.bool) + + # Iterative greedy pruning, where t=p and hence s1 = 0 + for i in range(len(prune_list)): + num_prune_iter = prune_list[i] + sym_diff_iter = sym_diff_per_iter[i] + s1 = (sym_diff_iter - num_prune_iter) // 2 + s2 = (sym_diff_iter + num_prune_iter) // 2 + + assert kept_channels.union(removed_channels) == total_channels + + channel_importance_dict = {} + + # Evaluate loss increase if each kept channel were pruned + for channel in kept_channels: + temp_keep_mask = keep_mask.clone() + temp_keep_mask[channel] = False + + subgram_xx = recompute_H( + prune_mask=temp_keep_mask, + H=gram_xx, + kernel_height=layer.kernel_size[0], + kernel_width=layer.kernel_size[1], + activation_in_channels=layer.in_channels, + is_pure_gram=True, + ) + subgram_xy = recompute_H( + prune_mask=temp_keep_mask, + H=gram_xy, + kernel_height=layer.kernel_size[0], + kernel_width=layer.kernel_size[1], + activation_in_channels=layer.in_channels, + is_pure_gram=False, + ) + # sub_w_optimal = recompute_W( + # prune_mask=temp_keep_mask, + # W=w_optimal, + # activation_in_channels=layer.in_channels, + # kernel_height=layer.kernel_size[0], + # kernel_width=layer.kernel_size[1], + # ) + submatrix_w_pruned = recompute_W( + prune_mask=temp_keep_mask, + W=dense_weights, + activation_in_channels=layer.in_channels, + kernel_height=layer.kernel_size[0], + kernel_width=layer.kernel_size[1], + ) + + loss = evaluate_loss( + submatrix_w_pruned, dense_weights, subgram_xx, subgram_xy + ) + channel_importance_dict[channel] = loss.item() + # Sort channels by importance(ascending = least important first) + sorted_channels = [ + k + for k, _ in sorted( + channel_importance_dict.items(), key=lambda item: item[1] + ) + ] -def osscar_prune( - dense_subnet, - prune_subnet, - dense_input, - pruned_input, - cached_dense_out, - cached_prune_out, - layer_name, - prune_layer_counts, -): - assert isinstance(dense_subnet, nn.Module), "Needs a valid subnet(nn.Module)" - assert isinstance(prune_subnet, nn.Module), "Needs a valid subnet(nn.Module)" - assert isinstance(dense_input, torch.Tensor) - assert isinstance(pruned_input, torch.Tensor) - assert pruned_input.ndim == 4, "Pruned input tensor must be of the shape B, C, H, W" - assert dense_input.ndim == 4, "Dense input tensor must be of the shape B, C, H, W" - - num_batches = dense_input.shape[0] - conv_module = None - reshaped_conv_wt = None - - # The way subnets are designed nmakes the first module of a subnet the conv2d thats needed to be pruned - for _, module in dense_subnet.named_modules(): - conv_module = module - assert isinstance(conv_module, nn.Conv2d) - reshaped_conv_wt = reshape_filter(conv_module.weight) - break - - print(conv_module) + # Prune s2 least important channels + to_prune = sorted_channels[:s2] + for channel in to_prune: + keep_mask[channel] = False + kept_channels.remove(channel) + removed_channels.add(channel) + + print( + f"Iteration {i+1}: pruned {len(to_prune)} channels, {len(kept_channels)} remaining." + ) + + return keep_mask, kept_channels, removed_channels + + +def get_parent_module(model, target_module): + """Find the direct parent and name of a given submodule. + + Iterates through all modules in `model` to locate the one that directly contains + `target_module`. + + Args: + model (nn.Module): Root model to search. + target_module (nn.Module): Submodule to locate. + + Returns: + (nn.Module, str): Parent module and the submodule's attribute name. + + Raises: + ValueError: If the target module isn’t found. + """ + for _, module in model.named_modules(): + for child_name, child in module.named_children(): + if child is target_module: + return module, child_name + raise ValueError("Target module not found") + + +def replace_module(model, target_module, new_module): + """Replace a submodule in-place within a model. + Finds the parent of `target_module` using `get_parent_module` and swaps it with + `new_module` via `setattr`. + + Args: + model (nn.Module): Root model. + target_module (nn.Module): Module to replace. + new_module (nn.Module): Replacement module. + """ + parent, name = get_parent_module(model, target_module) + setattr(parent, name, new_module) + + +def prune_one_layer( + dense_subnet, pruned_subnet, dense_input, pruned_input, layer_prune_channels +): + """ + Prune a single Conv2d layer from a subnet and replace it with a smaller version. + + Uses Gram matrices of the layer inputs and outputs to evaluate channel importance + and greedily prune the least important input channels. Replaces the original + layer in `pruned_subnet` with a new Conv2d containing only the kept channels. + + Args: + dense_subnet (nn.Module): Full, unpruned reference subnet. + pruned_subnet (nn.Module): Subnet to be pruned and modified. + dense_input (Tensor): Dense model inputs of shape (num_batches, batch_size, C, H, W). + pruned_input (Tensor): Pruned model inputs of the same shape. + layer_prune_channels (int): Number of channels to prune in this layer. + + Returns: + pruned_subnet (nn.Module): Updated subnet with the pruned layer replaced. + keep_mask (torch.BoolTensor): Boolean mask indicating which channels were kept. + """ assert ( - num_batches == pruned_input.shape[0] - ), "Dense and prune inputs must have the same number of images" + dense_input.ndim == 5 and pruned_input.ndim == 5 + ), "Inputs must be (num_batches, batch_size, C, H, W)" - total_xtx = 0 - total_xty = 0 + num_batches, batch_size, C, H, W = dense_input.shape + N = num_batches * batch_size - for batch_idx in range(num_batches): - dense_batch = dense_input[batch_idx, :, :, :] - pruned_batch = pruned_input[batch_idx, :, :, :] - dense_batch = reshape_conv_layer_input(input=dense_batch, layer=conv_module) - pruned_batch = reshape_conv_layer_input(input=pruned_batch, layer=conv_module) - total_xtx += get_coeff_h(pruned_batch) - total_xty += get_XtY(pruned_batch, dense_batch) - - # Note to self: Remember to double check - total_xtx /= num_batches - total_xty /= num_batches - - w_optimal = get_optimal_W( - gram_xx=total_xtx, gram_xy=total_xty, dense_weights=reshaped_conv_wt + # Get conv module to prune + conv_module = next( + m for _, m in pruned_subnet.named_modules() if isinstance(m, nn.Conv2d) ) + reshaped_conv_wt = reshape_filter(conv_module.weight) + + # Flatten all images into one dimension and reshape for gram_matrices calculation + dense_input_flat = dense_input.reshape(N, C, H, W) + pruned_input_flat = pruned_input.reshape(N, C, H, W) + dense_X = reshape_conv_layer_input(dense_input_flat, conv_module) + pruned_X = reshape_conv_layer_input(pruned_input_flat, conv_module) + + # Compute gram matrices over all N images + total_xtx = get_coeff_h(pruned_X) / N + total_xty = get_XtY(pruned_X, dense_X) / N - # Note to self: prune_layer_counts is supposed to be a dicttionary now, beware, dont forget to change - p = prune_layer_counts["layer_name"] - # pruned_layer = perform_local_search( - # w_optimal, layer_name, prune_layer_counts + # # Optimal weights + # w_optimal = get_optimal_W( + # gram_xx=total_xtx, gram_xy=total_xty, dense_weights=reshaped_conv_wt # ) - # Perform forward till the next layer to be pruned and cache results - # dense_out = dense_subnet(dense_input[batch_idx, :, :, :]) - # prune_out = prune_subnet(pruned_input[batch_idx, :, :, :]) + keep_mask, kept_channels, removed_channels = perform_local_search( + dense_weights=reshaped_conv_wt, + layer=conv_module, + p=layer_prune_channels, + gram_xx=total_xtx, + gram_xy=total_xty, + ) + + cached_out_pruned = [] + cached_out_dense = [] + + for batch_idx in range(num_batches): + cached_out_pruned.append(pruned_subnet(pruned_input[batch_idx])) + cached_out_dense.append(dense_subnet(dense_input[batch_idx])) + + cached_out_pruned = torch.cat(cached_out_pruned, dim=0) + cached_out_dense = torch.cat(cached_out_dense, dim=0) - # cached_out = + new_weight = conv_module.weight[:, keep_mask, :, :] + if conv_module.bias is not None: + new_bias = conv_module.bias + else: + new_bias = None + + kernel_size = conv_module.kernel_size + stride = conv_module.stride + padding = conv_module.padding + dilation = conv_module.dilation + + # Pylance fix + kernel_size_2 = (kernel_size[0], kernel_size[1]) + stride_2 = (stride[0], stride[1]) + padding_2 = (padding[0], padding[1]) + dilation_2 = (dilation[0], dilation[1]) + + # Replacement module + new_conv_module = nn.Conv2d( + in_channels=conv_module.in_channels, + out_channels=conv_module.out_channels - 1, + kernel_size=kernel_size_2, + stride=stride_2, + padding=padding_2, + dilation=dilation_2, + groups=conv_module.groups, + bias=conv_module.bias is not None, + ) + + # Replace weights + new_conv_module.weight.data = new_weight.clone() + if new_bias is not None: + new_conv_module.bias.data = new_bias.clone() + + replace_module( + model=pruned_subnet, target_module=conv_module, new_module=new_conv_module + ) + + return pruned_subnet, keep_mask if __name__ == "__main__": model = resnet50(pretrained=True) input = torch.randn(1, 3, 224, 224) - weights = model.conv1.weight - # print(weights.shape) - - # model.conv1.weight = weights[:, :2, :, :] - prune_conv_modules, prune_modules_name = collect_convolution_layers_to_prune( model=model ) - reformatted = [] for name in prune_modules_name: reformatted_name = "_".join(name.split(".")) @@ -378,64 +578,3 @@ def osscar_prune( graph_module=gm, prune_modules_name=prune_modules_name ) assert len(prune_subnets) == len(dense_subnets) - - print(prune_subnets[0]) - print(dense_subnets[1]) - # prefix_gm, remap = get_initial_prefix_submodule(graph_module=gm, end_node=end) - - # start_node = reformatted[0] - # end_node = reformatted[1] - - # subnet, remap_dict = get_fx_submodule( - # graph_module=gm, value_remap=remap, start_node=start_node, end_node=end_node - # ) - - # out = prefix_gm(input) - # # print(out.shape) - - # # out = out[:, :2, :, :] - # # model.conv1.weight = torch.nn.Parameter(weights[:, :2, :, :], requires_grad=True) - - # out = subnet(out) - - # print(subnet.conv1) - - # print(out.shape) - - # outputs = {} - - # def hook_fn(model, input, output): - # outputs["maxpool"] = output - - # model.maxpool.register_forward_hook(hook_fn) - # model(input) - # print(outputs["maxpool"].shape) - - # suffix_module, remap_final = get_suffix_submodule( - # graph_module=gm, value_remap=remap, start_node=end_node - # ) - # out = suffix_module(out) - # # print(suffix_module) - - # direct_out = model(input) - - # assert torch.allclose(out, direct_out) - # print(out.shape) - # print(direct_out.shape) - - # perform_local_search() - p = 1 - prune_per_iter = 3 - num_iterations = p // prune_per_iter - rem = p % prune_per_iter - zero_set = set() - sym_diff_per_iter = [] - # while p > 1: - # p = p - prune_per_iter - # sym_diff_per_iter.extend([prune_per_iter]) - if p >= prune_per_iter: - sym_diff_per_iter.extend([prune_per_iter for i in range(num_iterations)]) - if rem > 0: - sym_diff_per_iter.extend([rem]) - print(sym_diff_per_iter) - # for diff --git a/compress/osscar.py b/compress/osscar.py index 90ca8e8..c823d63 100644 --- a/compress/osscar.py +++ b/compress/osscar.py @@ -1,662 +1,183 @@ -import random -from typing import cast +import argparse -import numpy as np import torch -from torch import nn +import torch.nn as nn +import numpy as np +import torch.fx as fx -from compress.heuristics import * +from compress.graph_extractor import get_all_subnets, prune_one_layer +from compress.heuristics import collect_convolution_layers_to_prune +from compress.osscar_utils import get_count_prune_channels +from data.load_data import build_calibration_dataloader, build_eval_dataset from model.resnet import resnet50 -from data.load_data import build_eval_dataset, build_eval_dataloader +from utils.utils import set_global_seed, safe_free -def reshape_filter(filter): +def run_forward_with_mask(subnet, input, input_mask=None, is_input_loader=False): """ - Rearrange a Conv2d weight tensor according to OSSCAR paper. + Runs forward passes through a subnet, optionally masking input channels. - Parameters - ---------- - filter : torch.Tensor - Weight tensor of shape (C_out, C_in, K_h, K_w). + Supports either a preloaded tensor of shape (num_batches, batch_size, C, H, W) + or a DataLoader. The outputs from all batches are stacked along a new leading + dimension for later calibration or analysis. - Returns - ------- - torch.Tensor - Reshaped tensor of shape (C_in*K_h*K_w, C_out). - """ - assert isinstance(filter, torch.Tensor) - assert filter.ndim == 4, "Filter shape must be (Cout, Cin, Kh, Kw)" - cout, _, _, _ = filter.size() - reshaped_filter = filter.permute(1, 2, 3, 0) - reshaped_filter = reshaped_filter.reshape(-1, cout) - - return reshaped_filter + Args: + subnet (nn.Module): The model or subnet to evaluate. + input (torch.Tensor | DataLoader): Batched tensor input or a DataLoader. + input_mask (torch.BoolTensor, optional): Channel-wise boolean mask; only + channels with `True` are retained. + is_input_loader (bool, default=False): If True, treats `input` as a DataLoader. -def reshape_conv_layer_input(input, layer): + Returns: + cached_input (torch.Tensor): Stacked outputs for all batches, of shape + (num_batches, batch_size, C_out, H_out, W_out). """ - Unfold an input tensor using a Conv2d layer's settings. - - Parameters - ---------- - input : torch.Tensor - Input tensor of shape (C, H, W) or (B, C, H, W). - layer : nn.Conv2d - Conv2d layer whose kernel/stride/dilation/padding define the unfolding. - - Returns - ------- - torch.Tensor - Unfolded tensor of shape (B*L, C_in*K_h*K_w), - where L is the number of sliding locations. - """ - assert isinstance(input, torch.Tensor), "Input must be a tensor" - assert isinstance(layer, nn.Conv2d), "Layer must be a nn.Conv2d layer" - assert ( - input.ndim == 3 or input.ndim == 4 - ), "Input tensors must be either (C, H, W) or (B, C, H, W)" - - if input.ndim == 3: - input = input.unsqueeze(dim=0) - - _, _, h, w = input.shape - - # Effective size of a kernel changes in a dilated conv op - k_eff_y = (layer.kernel_size[0] - 1) * layer.dilation[0] + 1 - k_eff_x = (layer.kernel_size[1] - 1) * layer.dilation[1] + 1 - - if isinstance(layer.padding, str) and layer.padding == "same": - y_padding = ((layer.stride[0] * h - h) + k_eff_y - layer.stride[0]) // 2 - x_padding = ((layer.stride[1] * w - w) + k_eff_x - layer.stride[1]) // 2 - elif isinstance(layer.padding, tuple): - y_padding, x_padding = layer.padding + calibration_batches = [] + if is_input_loader: + assert isinstance(input, torch.utils.data.DataLoader) + for images, _ in input: + outs = subnet(images) + calibration_batches.append(outs) else: - y_padding = x_padding = layer.padding - - # Silence pylance's static check - y_padding = cast(int, y_padding) - x_padding = cast(int, x_padding) - - unfold = nn.Unfold( - kernel_size=layer.kernel_size, - dilation=layer.dilation, - padding=(y_padding, x_padding), - stride=layer.stride, - ) - - input = unfold(input) - input = input.permute(1, 0, 2) - input = input.flatten(1) - input = input.T - - return input - - -def get_coeff_h(design_matrix): - """ - Compute the input autocorrelation (H) matrix from a 2D design matrix. - - Parameters - ---------- - design_matrix : torch.Tensor - 2D tensor of shape (N, D), where N is the number of samples - (e.g., unfolded spatial positions across the batch) and - D is the feature dimension (e.g., C_in * K_h * K_w). - - Returns - ------- - torch.Tensor - Square tensor of shape (D, D) representing the autocorrelation - matrix H = XᵀX of the input features. - """ - assert isinstance(design_matrix, torch.Tensor) - assert design_matrix.ndim == 2, "Requires the reshaped design matrix" - - return design_matrix.T @ design_matrix - - -def get_XtY(X, Y): - """ - Compute the cross-correlation matrix XᵀY between two unfolded tensors. - - Parameters - ---------- - X : torch.Tensor - 2D tensor of shape (N, D₁), typically an unfolded input/design matrix - from `reshape_conv_layer_input`, where N is the number of samples - (e.g., batch × sliding locations) and D₁ is the feature dimension. - Y : torch.Tensor - 2D tensor of shape (N, D₂), typically another unfolded tensor of - the same number of rows as `X`, but possibly with a different - feature dimension D₂. - - Returns - ------- - torch.Tensor - Matrix of shape (D₁, D₂) representing the cross-correlation - XᵀY between the two inputs. - """ - assert isinstance(X, torch.Tensor) - assert isinstance(Y, torch.Tensor) - - A = X.T + assert isinstance(input, torch.Tensor) + assert input.ndim == 5, "Input shape must be (Num_batches, Batch_size, C, H, W)" + num_batches = input.shape[0] - assert A.shape[1] == Y.shape[0] - return A @ Y + if input_mask is not None: + input_tensor = input[:, :, input_mask, :, :] + else: + input_tensor = input + # Note to self: could include all elements in 1 batch effectively making the input_tensor 4D + # The for loop can then be avoided but might make the solution too memory-intensive + for i in range(num_batches): + input_tensor = input_tensor[i, :, :, :, :] + output_tensor = subnet(input_tensor) + calibration_batches.append(output_tensor) + cached_input = torch.stack(calibration_batches, dim=0) -def get_coeff_g(dense_layer_weights, layer_input): - """ - Compute the G coefficient for a dense (reshaped) layer weight matrix. - - Parameters - ---------- - dense_layer_weights : torch.Tensor - Layer weights reshaped to 2D, shape (out_features, in_features). - layer_input : torch.Tensor - Layer input activations before non-linearity. - Shape (B, C, H, W) or (C, H, W). - - Returns - ------- - torch.Tensor - G matrix capturing the projection of inputs onto the weight space. - """ - assert isinstance(layer_input, torch.Tensor) - assert isinstance(dense_layer_weights, torch.Tensor) - - layer_input_dim = layer_input.ndim - - assert dense_layer_weights.ndim == 2, "get_coeff_g takes in the reshaped weights" - assert ( - layer_input_dim == 3 or layer_input_dim == 4 - ), "Layer input must be of shape (B, C, H, W) or (C, H, W)" - - G = torch.transpose(dense_layer_weights, dim0=0, dim1=1) @ layer_input - - return G - -def compute_layer_loss(dense_weights, pruned_weights, input): - """ - Compute the layer reconstruction loss using H and G coefficients. - - Parameters - ---------- - dense_weights : torch.Tensor - Original (dense) layer weights reshaped to 2D. - pruned_weights : torch.Tensor - Pruned layer weights reshaped to 2D. - input : torch.Tensor - Layer input activations before non-linearity. - Shape (B, C, H, W) or (C, H, W). - - Returns - ------- - torch.Tensor - Scalar loss measuring how well the pruned layer reconstructs the original. - """ - G = get_coeff_g(dense_layer_weights=dense_weights, layer_input=input) - H = get_coeff_h(design_matrix=input) - A = (pruned_weights.T @ H) @ pruned_weights - B = G.T @ pruned_weights - - assert A.ndim == B.ndim == 2, "Trace can be computed only for 2D matrices" - loss = 0.5 * torch.trace(A) + torch.trace(B) + return cached_input - return loss - -def num_params_in_prune_channels(layers): +def run_osscar(model, calibration_loader, args): """ - Compute the total number of parameters across a list of convolutional layers. + Apply OSSCAR-style structured pruning to a model using a calibration dataset. - Parameters - ---------- - layers : list of nn.Conv2d - The convolutional layers whose parameters you want to count. - - Returns - ------- - int - Total number of parameters (weights + biases if present) in the given layers. - """ - params = 0 - - for layer in layers: - assert isinstance(layer, nn.Conv2d) - params += count_parameters(layer, in_millions=False) - - return params - - -def recalculate_importance(rem_channels_layer_wise): - """ - Recalculate normalized importance weights for each layer - based on the remaining number of channels per layer. - - Parameters - ---------- - rem_channels_layer_wise : array-like of int - Number of channels remaining in each layer. - - Returns - ------- - numpy.ndarray - Normalized importance values for each layer (sum to 1). - """ - total = np.sum(rem_channels_layer_wise) - rem_imp = np.divide(rem_channels_layer_wise, total) - - return rem_imp - - -def distribute_remaining_parameters( - rem_params, rem_channels_per_layer, layers, num_iters=20, allowable_tol=250 -): - """ - Stochastically allocate leftover parameter removals across layers. - - At each iteration, a layer is sampled according to a probability - distribution proportional to its remaining channels. One input channel - is removed from the chosen layer if possible, and the remaining parameter - budget is updated. The loop stops when the budget is within the allowable - tolerance or after `num_iters` iterations. - - Parameters - ---------- - rem_params : int - Remaining number of parameters to remove. - rem_channels_per_layer : list of int - Remaining number of channels per layer (mutable, will be updated). - layers : list of nn.Conv2d - Convolutional layers eligible for pruning. - num_iters : int, optional - Maximum number of allocation iterations. Default is 20. - allowable_tol : int, optional - Stop when the remaining parameter budget is within this tolerance. Default is 250. - - Returns - ------- - tuple - p : numpy.ndarray - Array of additional channels removed per layer. - rem_params : int - Remaining number of parameters still to remove after allocation. - """ - num_layers = len(layers) - layer_choices = np.arange(num_layers) - p = np.zeros(num_layers, dtype=np.int32) - rng = random.Random(3) - - rem_imp = recalculate_importance(rem_channels_layer_wise=rem_channels_per_layer) - - for i in range(num_iters): - random_layer_idx = rng.choices(layer_choices, weights=rem_imp)[0] - assert isinstance(random_layer_idx, (int, np.integer)) - - layer = layers[random_layer_idx] - - assert isinstance(layer, nn.Conv2d) - params_removed = ( - layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels - ) - count_remove_params = rem_params - params_removed - if rem_channels_per_layer[random_layer_idx] > 1 and count_remove_params >= 0: - p[random_layer_idx] += 1 - rem_channels_per_layer[random_layer_idx] -= 1 - - if rem_params - count_remove_params < 0: - continue - - rem_params = count_remove_params - rem_imp = recalculate_importance(rem_channels_per_layer) - - if rem_params <= allowable_tol: - break - - return p, rem_params - - -def get_count_prune_channels(model, prune_percentage, allowable_tol=250): - """ - Compute per-layer channel pruning counts to reach a target global prune percentage. - - This function: - 1. Collects the convolutional layers eligible for pruning. - 2. Computes how many parameters to remove from each layer based on an - importance heuristic. - 3. Converts parameter removals into channel counts per layer. - 4. Distributes any remaining parameter removals stochastically to meet - the target budget within the allowable tolerance. + The function: + 1. Determines how many channels to prune per layer based on `args.prune_percentage`. + 2. Identifies convolution layers eligible for pruning. + 3. Symbolically traces the model to extract subnets for pruning. + 4. Iteratively prunes layers using local greedy search while caching intermediate activations. + 5. Returns a list of pruned subnets forming the pruned model and the per-layer keep masks. Parameters ---------- model : nn.Module - The model to prune. - prune_percentage : float - Target fraction of total model parameters to remove (0–1). - allowable_tol : int, optional - Tolerance for how far from the target parameter count to allow. Default is 250. + The original dense PyTorch model to prune. + calibration_loader : torch.utils.data.DataLoader + Data loader providing batches for calibration / activation caching. + args : Namespace + Arguments containing `prune_percentage` among other potential config values. Returns ------- - tuple - num_channels_left_per_layer : list of int - Number of input channels to keep per layer after pruning. - p : list or numpy.ndarray - Number of input channels to prune per layer (sum of deterministic and random). - remaining_params_to_prune : int - Number of parameters still to prune after allocation (ideally <= allowable_tol). + pruned_model : list[nn.Module] + List of subnets forming the pruned model. + keep_masks : list[torch.BoolTensor] + Per-layer boolean masks indicating which input channels were kept. """ - model_total_params = count_parameters(model, in_millions=False) - - # Collect layers eligible to be pruned - layers, _ = collect_convolution_layers_to_prune(model) - - # Quick sanity check : ensures that total params to be removed doesnt exceed those that can be provided by eligible layers - eligible_prune_params_count = num_params_in_prune_channels(layers=layers) - total_params_to_prune = int(model_total_params * prune_percentage) - - assert eligible_prune_params_count > total_params_to_prune - - # Computes relative importance of each layer, higher importance => More channels pruned from this layer - importance_list = compute_layer_importance_heuristic(layers) - assert math.isclose( - np.sum(importance_list), 1.0, abs_tol=0.00001 - ), "importance scores must sum to 1" - - # Number of params to remove from every eligible layer - num_prune_params_by_layer = total_params_to_prune * importance_list - num_prune_params_by_layer = np.floor(total_params_to_prune * importance_list) - - revised_prune_params_count = 0 - p = [] - num_channels_left_per_layer = [] - - for idx, layer in enumerate(layers): - assert isinstance(layer, nn.Conv2d) + prune_percentage = args.prune_percentage + channels_post_prune, prune_channels_by_layer, remaining_params = ( + get_count_prune_channels(model=model, prune_percentage=prune_percentage) + ) + _, prune_modules_name = collect_convolution_layers_to_prune(model=model) - num_spatial_params = layer.kernel_size[0] * layer.kernel_size[1] - num_channels_per_filter = layer.in_channels - num_filters = layer.out_channels + gm = fx.symbolic_trace(model) - # Num channels to remove from every player(defined as per osscar) - num_channels_to_remove = num_prune_params_by_layer[idx] // ( - num_spatial_params * num_filters + prune_subnets, dense_subnets = get_all_subnets( + graph_module=gm, prune_modules_name=prune_modules_name + ) + assert len(prune_subnets) == len(dense_subnets) + + keep_masks = [] + pruned_model = [] + prefix_subnet = prune_subnets[0] + pruned_model.append(prefix_subnet) + + dense_cached_input = run_forward_with_mask( + subnet=prefix_subnet, + input=calibration_loader, + input_mask=None, + is_input_loader=True, + ).detach() + cached_input = dense_cached_input + + for i in range(1, len(dense_subnets)): + subnet_post_pruning, keep_mask = prune_one_layer( + dense_subnet=dense_subnets[i], + pruned_subnet=prune_subnets[i], + dense_input=dense_cached_input, + pruned_input=cached_input, + layer_prune_channels=prune_channels_by_layer[i], ) - p.append(num_channels_to_remove) - assert ( - num_channels_to_remove < num_channels_per_filter - ), "Cant remove all channels in a filter" + pruned_model.append(subnet_post_pruning) + keep_masks.append(keep_mask) + new_dense = run_forward_with_mask( + subnet=dense_subnets[i], input=dense_cached_input, input_mask=None + ).detach() + new_pruned = run_forward_with_mask( + subnet=prune_subnets[i], input=cached_input, input_mask=keep_mask + ).detach() - num_params_removed = num_spatial_params * num_channels_to_remove * num_filters - revised_prune_params_count += num_params_removed - num_channels_left = num_channels_per_filter - num_channels_to_remove - num_channels_left_per_layer.append(num_channels_left) + safe_free(dense_cached_input, cached_input) + dense_cached_input, cached_input = new_dense, new_pruned - remaining_params_to_prune = total_params_to_prune - revised_prune_params_count + return pruned_model, keep_masks - if remaining_params_to_prune > allowable_tol: - p_rem, remaining_params_to_prune = distribute_remaining_parameters( - rem_params=remaining_params_to_prune, - rem_channels_per_layer=num_channels_left_per_layer, - layers=layers, - allowable_tol=allowable_tol, - ) - p = np.array(p) + np.array(p_rem) +class PrunedResnet50(nn.Module): + def __init__(self, pruned_modules_list, keep_mask_list): + super().__init__() + assert len(pruned_modules_list) - 1 == len(keep_mask_list) + self.module_list = nn.ModuleList(pruned_modules_list) - return num_channels_left_per_layer, p, remaining_params_to_prune + # Register masks as buffers so that they can be saved/loaded along with the model + for i, mask in enumerate(keep_mask_list): + self.register_buffer(f"input_mask_{i}", mask) -def save_layer_input(activations, layer_name): - """ - Factory function to create a forward hook that saves a layer’s output. + self.keep_mask_list = keep_mask_list - Parameters - ---------- - activations : dict - A dictionary (mutable) where the captured outputs will be stored. - Keys are layer names, values are the output tensors. - layer_name : str - Name under which to store this layer’s output in the `activations` dict. + def forward(self, x): + # The first module has no input mask + x = self.module_list[0](x) - Returns - ------- - hook : callable - A forward hook function with signature (module, input, output) - that can be passed to `register_forward_hook`. - """ - def hook(model, input, output): - activations[layer_name] = output.detach() - return hook + # The mask of the ith module in module_list is the i-1th entry of keep_mask_list + for i in range(1, len(self.module_list)): + x = x[:, self.keep_mask_list[i - 1], :, :] + x = self.module_list[i](x) - -def register_hooks_to_collect_outs(prune_modules, prune_module_names, hook_fn): - """ - Register a forward hook on each module in `prune_modules` to collect outputs. - - Parameters - ---------- - prune_modules : list[nn.Module] - List of modules to attach hooks to (e.g. layers to prune). - prune_module_names : list[str] - Names corresponding to each module in `prune_modules`. - Must be the same length as `prune_modules`. - hook_fn : callable - A factory function that accepts `activations` (dict) and `layer_name` (str) - and returns a forward hook with signature (module, input, output). - - Returns - ------- - activations : dict - A dictionary that will be populated with {layer_name: output_tensor} - during a forward pass. - """ - activations = {} - for idx, module in enumerate(prune_modules): - assert isinstance(module, nn.Conv2d) - module_name = prune_module_names[idx] - module.register_forward_hook(hook=hook_fn(activations=activations, layer_name=module_name)) - - return activations - -def recompute_X(prune_mask, X, layer_in_channels, kernel_height, kernel_width): - """ - Recompute the unfolded input matrix X after pruning input channels. - - Parameters - ---------- - prune_mask : array-like of bool - Boolean mask of length `layer_in_channels` indicating which input channels to keep (`True`) or drop (`False`). - X : torch.Tensor or np.ndarray - 2D unfolded input matrix of shape (N, M) where N corresponds to flattened channel–kernel elements and - M to the number of sliding positions or samples. - layer_in_channels : int - Number of input channels in the convolution layer prior to pruning. - kernel_height : int - Height of the convolution kernel. - kernel_width : int - Width of the convolution kernel. - - Returns - ------- - torch.Tensor or np.ndarray - Pruned unfolded input matrix with rows corresponding only to kept input channels. - """ - assert X.ndim == 2, "Weight matrix must have already been reshaped" - assert len(prune_mask) == layer_in_channels, "The length of the indicator vector and the number of in_channels must be the same" - - # Number of elements needed to represent one filter in_channel in the reshaped 2d weight matrix - numel_one_channel = kernel_height * kernel_width - _, X_width = X.shape - - slice_indices = np.arange(layer_in_channels) - mask = np.ones(X_width, dtype=bool) - slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices] - - for idx, indicator in enumerate(prune_mask): - if not indicator: - start, stop = slice_indices[idx] - mask[start:stop] = False - - return X[:, mask] - -def recompute_H(prune_mask, H, kernel_height, kernel_width, activation_out_channels): - """ - Recompute the coefficient matrix H after pruning output channels. - - Parameters - ---------- - prune_mask : array-like of bool - Boolean mask of length `activation_out_channels` indicating which output channels to keep (`True`) or drop (`False`). - H : torch.Tensor or np.ndarray - 2D square matrix of shape (N, N) typically representing a Hessian or covariance term over activations. - kernel_height : int - Height of the convolution kernel. - kernel_width : int - Width of the convolution kernel. - activation_out_channels : int - Number of output channels (activations) prior to pruning. - - Returns - ------- - torch.Tensor or np.ndarray - Pruned square matrix containing only rows and columns corresponding to kept output channels. - """ - assert len(prune_mask) == activation_out_channels - assert H.ndim == 2 - - kept_indices = [] - numel_one_channel = kernel_height * kernel_width - slice_indices = np.arange(activation_out_channels) - slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices] - - for idx, indicator in enumerate(prune_mask): - if not indicator: - start, stop = slice_indices[idx] - kept_indices.extend(np.arange(start=start, stop=stop)) - - H_updated = H[np.ix_(kept_indices, kept_indices)] - - return H_updated - -def recompute_W(prune_mask, W, layer_in_channels, kernel_height, kernel_width): - """ - Recompute the weight matrix W after pruning input channels. - - Parameters - ---------- - prune_mask : array-like of bool - Boolean mask of length `layer_in_channels` indicating which input channels to keep (`True`) or drop (`False`). - W : torch.Tensor or np.ndarray - 2D weight matrix of shape (N, M) where N corresponds to flattened channel–kernel elements. - layer_in_channels : int - Number of input channels in the convolution layer prior to pruning. - kernel_height : int - Height of the convolution kernel. - kernel_width : int - Width of the convolution kernel. - - Returns - ------- - torch.Tensor or np.ndarray - Pruned weight matrix with rows corresponding only to kept input channels. - """ - assert W.ndim == 2, "Weight matrix must have already been reshaped" - assert len(prune_mask) == layer_in_channels, "The length of the indicator vector and the number of in_channels must be the same" - - # Number of elements needed to represent one filter in_channel in the reshaped 2d weight matrix - numel_one_channel = kernel_height * kernel_width - W_height, _ = W.shape - - slice_indices = np.arange(layer_in_channels) - mask = np.ones(W_height, dtype=bool) - - slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices] - print(slice_indices) - - for idx, indicator in enumerate(prune_mask): - if not indicator: - start, stop = slice_indices[idx] - mask[start:stop] = False - - return W[mask, :] - - -def get_optimal_W(pruned_activation, dense_activation, dense_weights): - assert isinstance(pruned_activation, torch.Tensor) - assert isinstance(dense_activation, torch.Tensor) - assert isinstance(dense_weights, torch.Tensor) - - assert pruned_activation.ndim == 2, "Pruned activation should be 2D" - assert dense_activation.ndim == 2, "Dense activation should have been reshaped to 2D" - assert dense_weights.ndim == 2, "Weights should have been reshaped to 2D" - -def get_calibration_dataset(): + return x if __name__ == "__main__": - # rand_filter = torch.randn((8, 3, 3, 3)) - # reshaped = reshape_filter(rand_filter) - # print(reshaped.size()) - # unfold = nn.Unfold( - # kernel_size=3, - # dilation=1, - # padding=1, - # stride=1 - # ) - # inp = unfold(rand_filter) - # inp = inp.permute([1, 0, 2]) - # inp = inp.flatten(1) - - # # print(inp.shape) - # model = resnet50(pretrained=True) - # overall_prune_percentage = 0.3 - # # _, p, rem_params = get_count_prune_channels( - # # model=model, prune_percentage=overall_prune_percentage - # # ) - - # # print(f"P: {p}") - # # print(f"Rem parameters to prune: {rem_params}") - # layer_in_channels = 3 - # numel_one_channel = 9 - # slice_indices = np.arange(layer_in_channels, dtype=np.int32) - # slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel - 1) for start in slice_indices] - # print(slice_indices) - - - # z = np.random.randint(low=0, high=2, size=layer_in_channels) - # print("Z:", z) - # W = np.random.randn(27, 1) - # print("W before: ", W) - # print("W after: ", get_matrix_I(z, W, layer_in_channels, 3, 3)) - - # layer_input = torch.randn((3, 3, 4)) - # output = get_coeff_h(layer_input) - # print(output.shape) - - input_activation = torch.randn(2, 3, 32, 32) - pruned_activation = torch.randn(2, 2, 32, 32) - mini_conv = nn.Conv2d(padding=1, kernel_size=3, in_channels=3, out_channels=8) - - out = reshape_conv_layer_input(input_activation, mini_conv) - - W = reshape_filter(mini_conv.weight) - print(f"Reshaped W before pruning: {W.shape}") - - prune_mask = [1, 1, 0] - W_pruned = recompute_W(prune_mask, W, layer_in_channels=3, kernel_height=3, kernel_width=3) - - print(f"Reshaped W after pruning: {W_pruned.shape}") - - print("Reshaped activation shape before pruning:", out.shape) - - X = recompute_X(prune_mask=prune_mask, X=out, layer_in_channels=3, kernel_height=3, kernel_width=3) - - print("Reshaped activation shape after pruning:", X.shape) - - corr = get_XtY(out, X) - print("XtY shape", corr.shape) - - print(get_coeff_h(out).shape) \ No newline at end of file + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + set_global_seed(seed=232) + g = torch.Generator() + g.manual_seed(11) + + calibration_dataset = build_eval_dataset(cfg="config/config.yaml") + calibration_dataloader = build_calibration_dataloader( + dataset=calibration_dataset, num_samples=500, g=g, batch_size=32 + ) + + model = resnet50(pretrained=True) + parser = argparse.ArgumentParser(description="Arguments for OSSCAR") + parser.add_argument("--prune_percentage", default=0.25, type=float) + args = parser.parse_args() + pruned_model_list, keep_mask_list = run_osscar( + model=model, calibration_loader=calibration_dataloader, args=args + ) + pruned_model = PrunedResnet50(pruned_model_list, keep_mask_list) diff --git a/compress/osscar_utils.py b/compress/osscar_utils.py index 6370eed..f04c27f 100644 --- a/compress/osscar_utils.py +++ b/compress/osscar_utils.py @@ -145,66 +145,34 @@ def get_XtY(X, Y): return A @ Y -def get_coeff_g(dense_layer_weights, layer_input): - """ - Compute the G coefficient for a dense (reshaped) layer weight matrix. - - Parameters - ---------- - dense_layer_weights : torch.Tensor - Layer weights reshaped to 2D, shape (out_features, in_features). - layer_input : torch.Tensor - Layer input activations before non-linearity. - Shape (B, C, H, W) or (C, H, W). - - Returns - ------- - torch.Tensor - G matrix capturing the projection of inputs onto the weight space. - """ - assert isinstance(layer_input, torch.Tensor) - assert isinstance(dense_layer_weights, torch.Tensor) - - layer_input_dim = layer_input.ndim - - assert dense_layer_weights.ndim == 2, "get_coeff_g takes in the reshaped weights" - assert ( - layer_input_dim == 3 or layer_input_dim == 4 - ), "Layer input must be of shape (B, C, H, W) or (C, H, W)" - - G = torch.transpose(dense_layer_weights, dim0=0, dim1=1) @ layer_input - - return G - - -def compute_layer_loss(dense_weights, pruned_weights, input): - """ - Compute the layer reconstruction loss using H and G coefficients. - - Parameters - ---------- - dense_weights : torch.Tensor - Original (dense) layer weights reshaped to 2D. - pruned_weights : torch.Tensor - Pruned layer weights reshaped to 2D. - input : torch.Tensor - Layer input activations before non-linearity. - Shape (B, C, H, W) or (C, H, W). - - Returns - ------- - torch.Tensor - Scalar loss measuring how well the pruned layer reconstructs the original. - """ - G = get_coeff_g(dense_layer_weights=dense_weights, layer_input=input) - H = get_coeff_h(design_matrix=input) - A = (pruned_weights.T @ H) @ pruned_weights - B = G.T @ pruned_weights - - assert A.ndim == B.ndim == 2, "Trace can be computed only for 2D matrices" - loss = 0.5 * torch.trace(A) + torch.trace(B) - - return loss +# def compute_layer_loss(dense_weights, pruned_weights, input): +# """ +# Compute the layer reconstruction loss using H and G coefficients. + +# Parameters +# ---------- +# dense_weights : torch.Tensor +# Original (dense) layer weights reshaped to 2D. +# pruned_weights : torch.Tensor +# Pruned layer weights reshaped to 2D. +# input : torch.Tensor +# Layer input activations before non-linearity. +# Shape (B, C, H, W) or (C, H, W). + +# Returns +# ------- +# torch.Tensor +# Scalar loss measuring how well the pruned layer reconstructs the original. +# """ +# G = get_coeff_g(dense_layer_weights=dense_weights, layer_input=input) +# H = get_coeff_h(design_matrix=input) +# A = (pruned_weights.T @ H) @ pruned_weights +# B = G.T @ pruned_weights + +# assert A.ndim == B.ndim == 2, "Trace can be computed only for 2D matrices" +# loss = 0.5 * torch.trace(A) + torch.trace(B) + +# return loss def num_params_in_prune_channels(layers): @@ -335,7 +303,7 @@ def get_count_prune_channels(model, prune_percentage, allowable_tol=250): model : nn.Module The model to prune. prune_percentage : float - Target fraction of total model parameters to remove (0–1). + Target fraction of total model parameters to remove (0-1). allowable_tol : int, optional Tolerance for how far from the target parameter count to allow. Default is 250. @@ -482,27 +450,59 @@ def register_hooks_to_collect_outs(prune_modules, prune_module_names, hook_fn): return gram_activations -def accumulate_xtx_statistics(model, calibration_dataloader): - """ - Run the model over the calibration data to accumulate per-layer - XᵀX (Gram) matrices using registered forward hooks. - """ - for images, _ in calibration_dataloader: - model(images) - - -def recompute_X(prune_mask, X, layer_in_channels, kernel_height, kernel_width): +# def accumulate_xtx_statistics(model, calibration_dataloader): +# """ +# Run the model over the calibration data to accumulate per-layer +# XᵀX (Gram) matrices using registered forward hooks. +# """ +# for images, _ in calibration_dataloader: +# model(images) + +# def get_kept_indices(prune_mask, kernel_height, kernel_width): +# """ +# Compute row indices corresponding to channels to keep. + +# Parameters +# ---------- +# prune_mask : array-like of bool +# Boolean mask of length `Cin`. True = keep, False = prune. +# kernel_height : int +# Conv kernel height. +# kernel_width : int +# Conv kernel width. + +# Returns +# ------- +# np.ndarray +# 1D array of row indices corresponding to the kept channels. +# """ +# prune_mask = np.asarray(prune_mask) +# assert prune_mask.ndim == 1, "prune_mask must be 1D" +# Cin = len(prune_mask) +# numel_one_channel = kernel_height * kernel_width + +# kept_indices = [] +# for idx, keep in enumerate(prune_mask): +# if keep: +# start = idx * numel_one_channel +# stop = start + numel_one_channel +# kept_indices.extend(range(start, stop)) + +# return np.array(kept_indices) + + +def recompute_X(prune_mask, X, activation_in_channels, kernel_height, kernel_width): """ Recompute the unfolded input matrix X after pruning input channels. Parameters ---------- prune_mask : array-like of bool - Boolean mask of length `layer_in_channels` indicating which input channels to keep (`True`) or drop (`False`). + Boolean mask of length `activation_in_channels` indicating which input channels to keep (`True`) or drop (`False`). X : torch.Tensor or np.ndarray - 2D unfolded input matrix of shape (N, M) where N corresponds to flattened channel–kernel elements and + 2D unfolded input matrix of shape (N, M) where N corresponds to flattened channel-kernel elements and M to the number of sliding positions or samples. - layer_in_channels : int + activation_in_channels : int Number of input channels in the convolution layer prior to pruning. kernel_height : int Height of the convolution kernel. @@ -512,18 +512,18 @@ def recompute_X(prune_mask, X, layer_in_channels, kernel_height, kernel_width): Returns ------- torch.Tensor or np.ndarray - Pruned unfolded input matrix with rows corresponding only to kept input channels. + Pruned unfolded input matrix with columns corresponding only to kept input channels. """ - assert X.ndim == 2, "Weight matrix must have already been reshaped" + assert X.ndim == 2, "Input matrix must have already been reshaped" assert ( - len(prune_mask) == layer_in_channels + len(prune_mask) == activation_in_channels ), "The length of the indicator vector and the number of in_channels must be the same" # Number of elements needed to represent one filter in_channel in the reshaped 2d weight matrix numel_one_channel = kernel_height * kernel_width _, X_width = X.shape - slice_indices = np.arange(layer_in_channels) + slice_indices = np.arange(activation_in_channels) mask = np.ones(X_width, dtype=bool) slice_indices = [ (start * numel_one_channel, start * numel_one_channel + numel_one_channel) @@ -538,60 +538,96 @@ def recompute_X(prune_mask, X, layer_in_channels, kernel_height, kernel_width): return X[:, mask] -def recompute_H(prune_mask, H, kernel_height, kernel_width, activation_out_channels): +# def recompute_X(prune_mask, X, kernel_height, kernel_width): +# kept_idx = get_kept_indices(prune_mask, kernel_height, kernel_width) +# return X[:, kept_idx] + +# def recompute_W(prune_mask, W, kernel_height, kernel_width): +# kept_idx = get_kept_indices(prune_mask, kernel_height, kernel_width) +# return W[kept_idx, :] + +# def recompute_H(prune_mask, H, kernel_height, kernel_width): +# kept_idx = get_kept_indices(prune_mask, kernel_height, kernel_width) +# return H[np.ix_(kept_idx, kept_idx)] + + +def recompute_H( + prune_mask, + H, + activation_in_channels, + kernel_height, + kernel_width, + is_pure_gram=True, +): """ - Recompute the coefficient matrix H after pruning output channels. + Recompute the coefficient matrix H after pruning input channels. Parameters ---------- prune_mask : array-like of bool - Boolean mask of length `activation_out_channels` indicating which output channels to keep (`True`) or drop (`False`). + Boolean mask of length `activation_in_channels` indicating which channels to keep (`True`) or drop (`False`). H : torch.Tensor or np.ndarray - 2D square matrix of shape (N, N) typically representing a Hessian or covariance term over activations. + 2D square matrix of shape (N, N) representing gram matrices over activations. kernel_height : int Height of the convolution kernel. kernel_width : int Width of the convolution kernel. - activation_out_channels : int - Number of output channels (activations) prior to pruning. + activation_in_channels : int + Number of input channels (activations) prior to pruning. Returns ------- torch.Tensor or np.ndarray - Pruned square matrix containing only rows and columns corresponding to kept output channels. + Pruned square matrix containing only rows and columns corresponding to kept input channels. """ - assert len(prune_mask) == activation_out_channels + assert len(prune_mask) == activation_in_channels assert H.ndim == 2 kept_indices = [] numel_one_channel = kernel_height * kernel_width - slice_indices = np.arange(activation_out_channels) + slice_indices = np.arange(activation_in_channels) slice_indices = [ (start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices ] - for idx, indicator in enumerate(prune_mask): - if not indicator: - start, stop = slice_indices[idx] - kept_indices.extend(np.arange(start=start, stop=stop)) + if is_pure_gram: + for idx, indicator in enumerate(prune_mask): + if indicator: + start, stop = slice_indices[idx] + kept_indices.extend(np.arange(start=start, stop=stop)) + + H_updated = H[np.ix_(kept_indices, kept_indices)] + + return H_updated - H_updated = H[np.ix_(kept_indices, kept_indices)] + else: + H_width = H.shape[1] + mask = np.ones(H_width, dtype=bool) + slice_indices = [ + (start * numel_one_channel, start * numel_one_channel + numel_one_channel) + for start in slice_indices + ] + + for idx, indicator in enumerate(prune_mask): + if not indicator: + start, stop = slice_indices[idx] + mask[start:stop] = False - return H_updated + return H[:, mask] -def recompute_W(prune_mask, W, layer_in_channels, kernel_height, kernel_width): +def recompute_W(prune_mask, W, activation_in_channels, kernel_height, kernel_width): """ Recompute the weight matrix W after pruning input channels. Parameters ---------- prune_mask : array-like of bool - Boolean mask of length `layer_in_channels` indicating which input channels to keep (`True`) or drop (`False`). + Boolean mask of length `activation_in_channels` indicating which input channels to keep (`True`) or drop (`False`). W : torch.Tensor or np.ndarray - 2D weight matrix of shape (N, M) where N corresponds to flattened channel–kernel elements. - layer_in_channels : int + 2D weight matrix of shape (N, M) where N corresponds to flattened channel-kernel elements. + activation_in_channels : int Number of input channels in the convolution layer prior to pruning. kernel_height : int Height of the convolution kernel. @@ -605,21 +641,20 @@ def recompute_W(prune_mask, W, layer_in_channels, kernel_height, kernel_width): """ assert W.ndim == 2, "Weight matrix must have already been reshaped" assert ( - len(prune_mask) == layer_in_channels + len(prune_mask) == activation_in_channels ), "The length of the indicator vector and the number of in_channels must be the same" # Number of elements needed to represent one filter in_channel in the reshaped 2d weight matrix numel_one_channel = kernel_height * kernel_width W_height, _ = W.shape - slice_indices = np.arange(layer_in_channels) + slice_indices = np.arange(activation_in_channels) mask = np.ones(W_height, dtype=bool) slice_indices = [ (start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices ] - print(slice_indices) for idx, indicator in enumerate(prune_mask): if not indicator: @@ -637,14 +672,18 @@ def compute_X_via_cholesky(A, B, C, eig_fallback_tol=1e-12): Returns X with same dtype/device as inputs. """ assert A.ndim == 2 and A.shape[0] == A.shape[1], "A must be square" - Y = B @ C + Y = B @ C n = A.shape[0] device = A.device dtype = A.dtype jitter = 1e-6 # Adpative scaling for proper regularization to treat ill-conditioned matrices - scale = max(torch.trace(A).abs().item() / n, torch.linalg.matrix_norm(A, ord='fro').item() / (n**0.5), 1.0) + scale = max( + torch.trace(A).abs().item() / n, + torch.linalg.matrix_norm(A, ord="fro").item() / (n**0.5), + 1.0, + ) lambda_j = jitter * scale A_reg = A + lambda_j * torch.eye(n, device=device, dtype=dtype) @@ -652,18 +691,22 @@ def compute_X_via_cholesky(A, B, C, eig_fallback_tol=1e-12): L, info = torch.linalg.cholesky_ex(A_reg) if info == 0: # LL^T X = Y => L Z = Y, then L^T X = Z - Z = torch.linalg.solve_triangular(L, Y, upper=False, left=True) + Z = torch.linalg.solve_triangular(L, Y, upper=False, left=True) X = torch.linalg.solve_triangular(L.transpose(-2, -1), Z, upper=True, left=True) return X - # Fallback: eigen-decomposition for PSD - eigvals, eigvecs = torch.linalg.eigh(A) + # Fallback: eigen-decomposition for PSD + eigvals, eigvecs = torch.linalg.eigh(A) reg_eig = eigvals + lambda_j # If reg_eig very small, clamp with tol - reg_eig_clamped = torch.where(reg_eig.abs() < eig_fallback_tol, torch.full_like(reg_eig, eig_fallback_tol), reg_eig) + reg_eig_clamped = torch.where( + reg_eig.abs() < eig_fallback_tol, + torch.full_like(reg_eig, eig_fallback_tol), + reg_eig, + ) - # X = V diag(1/reg_eig_clamped) V^T Y - VtY = eigvecs.transpose(-2, -1) @ Y + # X = V diag(1/reg_eig_clamped) V^T Y + VtY = eigvecs.transpose(-2, -1) @ Y scaled = VtY / reg_eig_clamped.unsqueeze(-1) X = eigvecs @ scaled return X @@ -717,14 +760,12 @@ def get_optimal_W(gram_xx, gram_xy, dense_weights): assert isinstance(dense_weights, torch.Tensor) assert gram_xx.ndim == 2, "Gram matrix XtX should be 2D" - assert ( - gram_xy.ndim == 2 - ), "Gram matrix XtY should be 2D" + assert gram_xy.ndim == 2, "Gram matrix XtY should be 2D" assert dense_weights.ndim == 2, "Weights should have been reshaped to 2D" assert gram_xx.shape[0] == gram_xx.shape[1], "Gram_xx must be square" assert gram_xy.shape[0] == gram_xy.shape[1], "Gram_xy must be square" - Y = gram_xy @ dense_weights + Y = gram_xy @ dense_weights n = gram_xx.shape[0] device = gram_xx.device dtype = gram_xx.dtype @@ -732,7 +773,11 @@ def get_optimal_W(gram_xx, gram_xy, dense_weights): eig_fallback_tol = 1e-12 # Adpative scaling for proper regularization to treat ill-conditioned matrices - scale = max(torch.trace(gram_xx).abs().item() / n, torch.linalg.matrix_norm(gram_xx, ord='fro').item() / (n**0.5), 1.0) + scale = max( + torch.trace(gram_xx).abs().item() / n, + torch.linalg.matrix_norm(gram_xx, ord="fro").item() / (n**0.5), + 1.0, + ) lambda_j = jitter * scale gram_xx_reg = gram_xx + lambda_j * torch.eye(n, device=device, dtype=dtype) @@ -740,22 +785,27 @@ def get_optimal_W(gram_xx, gram_xy, dense_weights): L, info = torch.linalg.cholesky_ex(gram_xx_reg) if info == 0: # LL^T X = Y => L Z = Y, then L^T X = Z - Z = torch.linalg.solve_triangular(L, Y, upper=False, left=True) + Z = torch.linalg.solve_triangular(L, Y, upper=False, left=True) X = torch.linalg.solve_triangular(L.transpose(-2, -1), Z, upper=True, left=True) return X - # Fallback: eigen-decomposition for PSD - eigvals, eigvecs = torch.linalg.eigh(gram_xx) + # Fallback: eigen-decomposition for PSD + eigvals, eigvecs = torch.linalg.eigh(gram_xx) reg_eig = eigvals + lambda_j # If reg_eig very small, clamp with tol - reg_eig_clamped = torch.where(reg_eig.abs() < eig_fallback_tol, torch.full_like(reg_eig, eig_fallback_tol), reg_eig) + reg_eig_clamped = torch.where( + reg_eig.abs() < eig_fallback_tol, + torch.full_like(reg_eig, eig_fallback_tol), + reg_eig, + ) - # X = V diag(1/reg_eig_clamped) V^T Y - VtY = eigvecs.transpose(-2, -1) @ Y + # X = V diag(1/reg_eig_clamped) V^T Y + VtY = eigvecs.transpose(-2, -1) @ Y scaled = VtY / reg_eig_clamped.unsqueeze(-1) X = eigvecs @ scaled return X + # def run_submodules_from_start_to_end_layer( # start_layer_name, end_layer_name, model, input # ): @@ -795,17 +845,17 @@ def get_optimal_W(gram_xx, gram_xy, dense_weights): # # print(f"P: {p}") # # print(f"Rem parameters to prune: {rem_params}") - # layer_in_channels = 3 + # activation_in_channels = 3 # numel_one_channel = 9 - # slice_indices = np.arange(layer_in_channels, dtype=np.int32) + # slice_indices = np.arange(activation_in_channels, dtype=np.int32) # slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel - 1) for start in slice_indices] # print(slice_indices) - # z = np.random.randint(low=0, high=2, size=layer_in_channels) + # z = np.random.randint(low=0, high=2, size=activation_in_channels) # print("Z:", z) # W = np.random.randn(27, 1) # print("W before: ", W) - # print("W after: ", get_matrix_I(z, W, layer_in_channels, 3, 3)) + # print("W after: ", get_matrix_I(z, W, activation_in_channels, 3, 3)) # layer_input = torch.randn((3, 3, 4)) # output = get_coeff_h(layer_input) @@ -821,13 +871,13 @@ def get_optimal_W(gram_xx, gram_xy, dense_weights): # print(f"Reshaped W before pruning: {W.shape}") # prune_mask = [1, 1, 0] - # W_pruned = recompute_W(prune_mask, W, layer_in_channels=3, kernel_height=3, kernel_width=3) + # W_pruned = recompute_W(prune_mask, W, activation_in_channels=3, kernel_height=3, kernel_width=3) # print(f"Reshaped W after pruning: {W_pruned.shape}") # print("Reshaped activation shape before pruning:", out.shape) - # X = recompute_X(prune_mask=prune_mask, X=out, layer_in_channels=3, kernel_height=3, kernel_width=3) + # X = recompute_X(prune_mask=prune_mask, X=out, activation_in_channels=3, kernel_height=3, kernel_width=3) # print("Reshaped activation shape after pruning:", X.shape) @@ -837,4 +887,3 @@ def get_optimal_W(gram_xx, gram_xy, dense_weights): # print(get_coeff_h(out).shape) # print(model.named_children) - diff --git a/data/load_data.py b/data/load_data.py index f1bba43..77bef4a 100644 --- a/data/load_data.py +++ b/data/load_data.py @@ -61,9 +61,9 @@ def build_eval_dataset(cfg): return dataset -def build_calibration_dataloader(dataset, num_samples, batch_size=32): +def build_calibration_dataloader(dataset, num_samples, g, batch_size=32): """ - Create a PyTorch DataLoader for a subset of a dataset to be used for calibration. + Create a reproducible PyTorch DataLoader for a subset of a dataset to be used for calibration. Parameters ---------- @@ -71,20 +71,31 @@ def build_calibration_dataloader(dataset, num_samples, batch_size=32): The full dataset to sample from. num_samples : int Number of samples from the start of the dataset to use for calibration. + g : torch.Generator + A PyTorch random number generator with a fixed manual seed. Ensures + deterministic shuffling of samples within the DataLoader. batch_size : int, optional (default=32) Batch size for the DataLoader. Returns ------- calibration_dataloader : torch.utils.data.DataLoader + A DataLoader whose batch contents and order are fixed across runs + when combined with deterministic seeding. """ + from utils.utils import seed_worker + + subset_indices = list(range(num_samples)) + subset = torch.utils.data.Subset(dataset, subset_indices) calibration_dataloader = torch.utils.data.DataLoader( - dataset[:num_samples], + subset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True, drop_last=True, + worker_init_fn=seed_worker, + generator=g, ) return calibration_dataloader diff --git a/utils/utils.py b/utils/utils.py index ebcff1e..120783f 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -1,3 +1,6 @@ +import random + +import numpy as np import torch import yaml @@ -95,3 +98,24 @@ def return_train_val_cfg(path): val_cfg = cfg["eval"] return train_cfg, val_cfg + + +def set_global_seed(seed): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def seed_worker(worker_id): + worker_seed = torch.initial_seed() % 2**32 + np.random.seed(worker_seed) + random.seed(worker_seed) + + +def safe_free(*tensors): + """Explicitly free tensors and clear cache if on GPU.""" + for t in tensors: + del t + if torch.cuda.is_available(): + torch.cuda.empty_cache()