From 9560b7000ee73afce7156fa43700c15e06663799 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 13 Sep 2025 15:45:36 +0530 Subject: [PATCH 1/7] refactor count_parameters --- utils/utils.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/utils/utils.py b/utils/utils.py index 920a4a6..ebcff1e 100644 --- a/utils/utils.py +++ b/utils/utils.py @@ -41,7 +41,7 @@ def load_yaml(path): return None -def count_parameters(model): +def count_parameters(model, in_millions=True): """ Counts the number of trainable parameters in a model. @@ -52,7 +52,10 @@ def count_parameters(model): float: Number of parameters (in millions). """ params = sum(p.numel() for p in model.parameters() if p.requires_grad) - return params / 1000000 + if in_millions: + return params / 1000000 + else: + return params def get_topk_accuracy(logits, labels, k): From d77aa16d83e11aa738d5e3f2a97a3c25625e2ca4 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Fri, 19 Sep 2025 22:58:36 +0530 Subject: [PATCH 2/7] Add calibration dataloader wrapper --- data/load_data.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/data/load_data.py b/data/load_data.py index b1a7e9b..f1bba43 100644 --- a/data/load_data.py +++ b/data/load_data.py @@ -61,6 +61,35 @@ def build_eval_dataset(cfg): return dataset +def build_calibration_dataloader(dataset, num_samples, batch_size=32): + """ + Create a PyTorch DataLoader for a subset of a dataset to be used for calibration. + + Parameters + ---------- + dataset : torch.utils.data.Dataset + The full dataset to sample from. + num_samples : int + Number of samples from the start of the dataset to use for calibration. + batch_size : int, optional (default=32) + Batch size for the DataLoader. 
+ + Returns + ------- + calibration_dataloader : torch.utils.data.DataLoader + """ + calibration_dataloader = torch.utils.data.DataLoader( + dataset[:num_samples], + batch_size=batch_size, + shuffle=True, + num_workers=2, + pin_memory=True, + drop_last=True, + ) + + return calibration_dataloader + + def build_train_dataloader(dataset, config): """ Wraps a dataset in a DataLoader for training. From 5414e73fe75e514696d5e3540d47a199393f470b Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Fri, 19 Sep 2025 23:53:07 +0530 Subject: [PATCH 3/7] Add hooks to capture pruned module outputs --- compress/osscar.py | 324 ++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 291 insertions(+), 33 deletions(-) diff --git a/compress/osscar.py b/compress/osscar.py index 11e9f64..90ca8e8 100644 --- a/compress/osscar.py +++ b/compress/osscar.py @@ -1,4 +1,5 @@ import random +from typing import cast import numpy as np import torch @@ -6,6 +7,7 @@ from compress.heuristics import * from model.resnet import resnet50 +from data.load_data import build_eval_dataset, build_eval_dataloader def reshape_filter(filter): @@ -30,7 +32,6 @@ def reshape_filter(filter): return reshaped_filter - def reshape_conv_layer_input(input, layer): """ Unfold an input tensor using a Conv2d layer's settings. @@ -45,13 +46,13 @@ def reshape_conv_layer_input(input, layer): Returns ------- torch.Tensor - Unfolded tensor of shape (C_in*K_h*K_w, B*L), + Unfolded tensor of shape (B*L, C_in*K_h*K_w), where L is the number of sliding locations. 
""" assert isinstance(input, torch.Tensor), "Input must be a tensor" assert isinstance(layer, nn.Conv2d), "Layer must be a nn.Conv2d layer" assert ( - input.ndim == 3 or input.ndim == 3 + input.ndim == 3 or input.ndim == 4 ), "Input tensors must be either (C, H, W) or (B, C, H, W)" if input.ndim == 3: @@ -63,11 +64,17 @@ def reshape_conv_layer_input(input, layer): k_eff_y = (layer.kernel_size[0] - 1) * layer.dilation[0] + 1 k_eff_x = (layer.kernel_size[1] - 1) * layer.dilation[1] + 1 - if layer.padding == "same": + if isinstance(layer.padding, str) and layer.padding == "same": y_padding = ((layer.stride[0] * h - h) + k_eff_y - layer.stride[0]) // 2 x_padding = ((layer.stride[1] * w - w) + k_eff_x - layer.stride[1]) // 2 + elif isinstance(layer.padding, tuple): + y_padding, x_padding = layer.padding else: - y_padding = x_padding = 0 + y_padding = x_padding = layer.padding + + # Silence pylance's static check + y_padding = cast(int, y_padding) + x_padding = cast(int, x_padding) unfold = nn.Unfold( kernel_size=layer.kernel_size, @@ -79,37 +86,62 @@ def reshape_conv_layer_input(input, layer): input = unfold(input) input = input.permute(1, 0, 2) input = input.flatten(1) + input = input.T return input -def get_coeff_h(layer_input): +def get_coeff_h(design_matrix): """ - Compute the H coefficient (input autocorrelation matrix) from activations. + Compute the input autocorrelation (H) matrix from a 2D design matrix. Parameters ---------- - layer_input : torch.Tensor - Layer input activations before non-linearity. - Shape (B, C, H, W) or (C, H, W). + design_matrix : torch.Tensor + 2D tensor of shape (N, D), where N is the number of samples + (e.g., unfolded spatial positions across the batch) and + D is the feature dimension (e.g., C_in * K_h * K_w). Returns ------- torch.Tensor - H matrix capturing correlations of the input activations. + Square tensor of shape (D, D) representing the autocorrelation + matrix H = XᵀX of the input features. 
""" - assert isinstance(layer_input, torch.Tensor) - layer_input_dim = layer_input.ndim - assert ( - layer_input_dim == 3 or layer_input_dim == 4 - ), "Layer input must be of shape (B, C, H, W) or (C, H, W)" + assert isinstance(design_matrix, torch.Tensor) + assert design_matrix.ndim == 2, "Requires the reshaped design matrix" + + return design_matrix.T @ design_matrix - if layer_input_dim == 3: - H = torch.matmul(torch.transpose(layer_input, dim0=1, dim1=2), layer_input) - else: - H = torch.transpose(layer_input, dim0=2, dim1=3) @ layer_input - return H +def get_XtY(X, Y): + """ + Compute the cross-correlation matrix XᵀY between two unfolded tensors. + + Parameters + ---------- + X : torch.Tensor + 2D tensor of shape (N, D₁), typically an unfolded input/design matrix + from `reshape_conv_layer_input`, where N is the number of samples + (e.g., batch × sliding locations) and D₁ is the feature dimension. + Y : torch.Tensor + 2D tensor of shape (N, D₂), typically another unfolded tensor of + the same number of rows as `X`, but possibly with a different + feature dimension D₂. + + Returns + ------- + torch.Tensor + Matrix of shape (D₁, D₂) representing the cross-correlation + XᵀY between the two inputs. + """ + assert isinstance(X, torch.Tensor) + assert isinstance(Y, torch.Tensor) + + A = X.T + + assert A.shape[1] == Y.shape[0] + return A @ Y def get_coeff_g(dense_layer_weights, layer_input): @@ -142,8 +174,7 @@ def get_coeff_g(dense_layer_weights, layer_input): G = torch.transpose(dense_layer_weights, dim0=0, dim1=1) @ layer_input return G - - + def compute_layer_loss(dense_weights, pruned_weights, input): """ Compute the layer reconstruction loss using H and G coefficients. @@ -164,7 +195,7 @@ def compute_layer_loss(dense_weights, pruned_weights, input): Scalar loss measuring how well the pruned layer reconstructs the original. 
""" G = get_coeff_g(dense_layer_weights=dense_weights, layer_input=input) - H = get_coeff_h(layer_input=input) + H = get_coeff_h(design_matrix=input) A = (pruned_weights.T @ H) @ pruned_weights B = G.T @ pruned_weights @@ -378,6 +409,192 @@ def get_count_prune_channels(model, prune_percentage, allowable_tol=250): return num_channels_left_per_layer, p, remaining_params_to_prune +def save_layer_input(activations, layer_name): + """ + Factory function to create a forward hook that saves a layer’s output. + + Parameters + ---------- + activations : dict + A dictionary (mutable) where the captured outputs will be stored. + Keys are layer names, values are the output tensors. + layer_name : str + Name under which to store this layer’s output in the `activations` dict. + + Returns + ------- + hook : callable + A forward hook function with signature (module, input, output) + that can be passed to `register_forward_hook`. + """ + def hook(model, input, output): + activations[layer_name] = output.detach() + return hook + + +def register_hooks_to_collect_outs(prune_modules, prune_module_names, hook_fn): + """ + Register a forward hook on each module in `prune_modules` to collect outputs. + + Parameters + ---------- + prune_modules : list[nn.Module] + List of modules to attach hooks to (e.g. layers to prune). + prune_module_names : list[str] + Names corresponding to each module in `prune_modules`. + Must be the same length as `prune_modules`. + hook_fn : callable + A factory function that accepts `activations` (dict) and `layer_name` (str) + and returns a forward hook with signature (module, input, output). + + Returns + ------- + activations : dict + A dictionary that will be populated with {layer_name: output_tensor} + during a forward pass. 
+ """ + activations = {} + for idx, module in enumerate(prune_modules): + assert isinstance(module, nn.Conv2d) + module_name = prune_module_names[idx] + module.register_forward_hook(hook=hook_fn(activations=activations, layer_name=module_name)) + + return activations + +def recompute_X(prune_mask, X, layer_in_channels, kernel_height, kernel_width): + """ + Recompute the unfolded input matrix X after pruning input channels. + + Parameters + ---------- + prune_mask : array-like of bool + Boolean mask of length `layer_in_channels` indicating which input channels to keep (`True`) or drop (`False`). + X : torch.Tensor or np.ndarray + 2D unfolded input matrix of shape (N, M) where N corresponds to flattened channel–kernel elements and + M to the number of sliding positions or samples. + layer_in_channels : int + Number of input channels in the convolution layer prior to pruning. + kernel_height : int + Height of the convolution kernel. + kernel_width : int + Width of the convolution kernel. + + Returns + ------- + torch.Tensor or np.ndarray + Pruned unfolded input matrix with rows corresponding only to kept input channels. 
+ """ + assert X.ndim == 2, "Weight matrix must have already been reshaped" + assert len(prune_mask) == layer_in_channels, "The length of the indicator vector and the number of in_channels must be the same" + + # Number of elements needed to represent one filter in_channel in the reshaped 2d weight matrix + numel_one_channel = kernel_height * kernel_width + _, X_width = X.shape + + slice_indices = np.arange(layer_in_channels) + mask = np.ones(X_width, dtype=bool) + slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices] + + for idx, indicator in enumerate(prune_mask): + if not indicator: + start, stop = slice_indices[idx] + mask[start:stop] = False + + return X[:, mask] + +def recompute_H(prune_mask, H, kernel_height, kernel_width, activation_out_channels): + """ + Recompute the coefficient matrix H after pruning output channels. + + Parameters + ---------- + prune_mask : array-like of bool + Boolean mask of length `activation_out_channels` indicating which output channels to keep (`True`) or drop (`False`). + H : torch.Tensor or np.ndarray + 2D square matrix of shape (N, N) typically representing a Hessian or covariance term over activations. + kernel_height : int + Height of the convolution kernel. + kernel_width : int + Width of the convolution kernel. + activation_out_channels : int + Number of output channels (activations) prior to pruning. + + Returns + ------- + torch.Tensor or np.ndarray + Pruned square matrix containing only rows and columns corresponding to kept output channels. 
+ """ + assert len(prune_mask) == activation_out_channels + assert H.ndim == 2 + + kept_indices = [] + numel_one_channel = kernel_height * kernel_width + slice_indices = np.arange(activation_out_channels) + slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices] + + for idx, indicator in enumerate(prune_mask): + if not indicator: + start, stop = slice_indices[idx] + kept_indices.extend(np.arange(start=start, stop=stop)) + + H_updated = H[np.ix_(kept_indices, kept_indices)] + + return H_updated + +def recompute_W(prune_mask, W, layer_in_channels, kernel_height, kernel_width): + """ + Recompute the weight matrix W after pruning input channels. + + Parameters + ---------- + prune_mask : array-like of bool + Boolean mask of length `layer_in_channels` indicating which input channels to keep (`True`) or drop (`False`). + W : torch.Tensor or np.ndarray + 2D weight matrix of shape (N, M) where N corresponds to flattened channel–kernel elements. + layer_in_channels : int + Number of input channels in the convolution layer prior to pruning. + kernel_height : int + Height of the convolution kernel. + kernel_width : int + Width of the convolution kernel. + + Returns + ------- + torch.Tensor or np.ndarray + Pruned weight matrix with rows corresponding only to kept input channels. 
+ """ + assert W.ndim == 2, "Weight matrix must have already been reshaped" + assert len(prune_mask) == layer_in_channels, "The length of the indicator vector and the number of in_channels must be the same" + + # Number of elements needed to represent one filter in_channel in the reshaped 2d weight matrix + numel_one_channel = kernel_height * kernel_width + W_height, _ = W.shape + + slice_indices = np.arange(layer_in_channels) + mask = np.ones(W_height, dtype=bool) + + slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel) for start in slice_indices] + print(slice_indices) + + for idx, indicator in enumerate(prune_mask): + if not indicator: + start, stop = slice_indices[idx] + mask[start:stop] = False + + return W[mask, :] + + +def get_optimal_W(pruned_activation, dense_activation, dense_weights): + assert isinstance(pruned_activation, torch.Tensor) + assert isinstance(dense_activation, torch.Tensor) + assert isinstance(dense_weights, torch.Tensor) + + assert pruned_activation.ndim == 2, "Pruned activation should be 2D" + assert dense_activation.ndim == 2, "Dense activation should have been reshaped to 2D" + assert dense_weights.ndim == 2, "Weights should have been reshaped to 2D" + +def get_calibration_dataset(): + if __name__ == "__main__": # rand_filter = torch.randn((8, 3, 3, 3)) @@ -393,12 +610,53 @@ def get_count_prune_channels(model, prune_percentage, allowable_tol=250): # inp = inp.permute([1, 0, 2]) # inp = inp.flatten(1) - # print(inp.shape) - model = resnet50(pretrained=True) - overall_prune_percentage = 0.3 - _, p, rem_params = get_count_prune_channels( - model=model, prune_percentage=overall_prune_percentage - ) - - print(f"P: {p}") - print(f"Rem parameters to prune: {rem_params}") + # # print(inp.shape) + # model = resnet50(pretrained=True) + # overall_prune_percentage = 0.3 + # # _, p, rem_params = get_count_prune_channels( + # # model=model, prune_percentage=overall_prune_percentage + # # ) + + # # print(f"P: 
{p}") + # # print(f"Rem parameters to prune: {rem_params}") + # layer_in_channels = 3 + # numel_one_channel = 9 + # slice_indices = np.arange(layer_in_channels, dtype=np.int32) + # slice_indices = [(start * numel_one_channel, start * numel_one_channel + numel_one_channel - 1) for start in slice_indices] + # print(slice_indices) + + + # z = np.random.randint(low=0, high=2, size=layer_in_channels) + # print("Z:", z) + # W = np.random.randn(27, 1) + # print("W before: ", W) + # print("W after: ", get_matrix_I(z, W, layer_in_channels, 3, 3)) + + # layer_input = torch.randn((3, 3, 4)) + # output = get_coeff_h(layer_input) + # print(output.shape) + + input_activation = torch.randn(2, 3, 32, 32) + pruned_activation = torch.randn(2, 2, 32, 32) + mini_conv = nn.Conv2d(padding=1, kernel_size=3, in_channels=3, out_channels=8) + + out = reshape_conv_layer_input(input_activation, mini_conv) + + W = reshape_filter(mini_conv.weight) + print(f"Reshaped W before pruning: {W.shape}") + + prune_mask = [1, 1, 0] + W_pruned = recompute_W(prune_mask, W, layer_in_channels=3, kernel_height=3, kernel_width=3) + + print(f"Reshaped W after pruning: {W_pruned.shape}") + + print("Reshaped activation shape before pruning:", out.shape) + + X = recompute_X(prune_mask=prune_mask, X=out, layer_in_channels=3, kernel_height=3, kernel_width=3) + + print("Reshaped activation shape after pruning:", X.shape) + + corr = get_XtY(out, X) + print("XtY shape", corr.shape) + + print(get_coeff_h(out).shape) \ No newline at end of file From bc9066614ff5f59968e04ae24f98122d91ec7d2b Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sun, 21 Sep 2025 17:13:12 +0530 Subject: [PATCH 4/7] Add logic to derive subgraphs from the full forward --- compress/graph_extractor.py | 230 ++++++++++++++++++++++++++++++++++++ 1 file changed, 230 insertions(+) create mode 100644 compress/graph_extractor.py diff --git a/compress/graph_extractor.py b/compress/graph_extractor.py new file mode 
100644 index 0000000..16a7a2d --- /dev/null +++ b/compress/graph_extractor.py @@ -0,0 +1,230 @@ +import torch +import torch.fx as fx + +from compress.heuristics import collect_convolution_layers_to_prune +from model.resnet import resnet50 + + +def get_initial_prefix_submodule(graph_module, end_node): + """ + Extracts a prefix subgraph from the start of the FX GraphModule up to (but not including) `end_node`. + + Parameters + ---------- + graph_module : torch.fx.GraphModule + The traced FX GraphModule from which to extract the prefix subgraph. + end_node : str + The name of the node where the prefix subgraph should stop (exclusive). + + Returns + ------- + prefix_gm : torch.fx.GraphModule + A new GraphModule containing only the nodes from the start up to `end_node`. + value_remap : dict + A mapping from original nodes in `graph_module` to the corresponding nodes + in the prefix subgraph. Useful for connecting this prefix to subsequent subgraphs. + + Notes + ----- + - The resulting GraphModule can be called like a normal module, taking the input tensor + that corresponds to the placeholder node. + - The output of the prefix subgraph is the last node before `end_node`. + """ + assert isinstance(graph_module, fx.GraphModule) + graph = graph_module.graph + prefix_nodes = [] + prefix_graph = fx.Graph() + value_remap = {} + + for node in graph.nodes: + if node.name == end_node: + break + else: + prefix_nodes.append(node) + + assert len(prefix_nodes) > 0, "Prefix nodes must not be empty" + + for node in prefix_nodes: + new_node = prefix_graph.node_copy(node, lambda n: value_remap[n]) + value_remap[node] = new_node + + last_node = prefix_nodes[-1] + prefix_graph.output(value_remap[last_node]) + + prefix_gm = fx.GraphModule(root=graph_module, graph=prefix_graph) + return prefix_gm, value_remap + + +def get_fx_submodule(graph_module, value_remap, start_node, end_node): + """ + Extracts a middle subgraph from an FX GraphModule between `start_node` and `end_node`. 
+ + Parameters + ---------- + graph_module : torch.fx.GraphModule + The traced FX GraphModule containing the nodes. + value_remap : dict + A mapping from previously copied nodes (e.g., from a prefix subgraph) to the + corresponding new nodes. Must include any nodes that are inputs to this subgraph. + start_node : str + The name of the node where the subgraph should start (inclusive). + end_node : str + The name of the node where the subgraph should end (exclusive). + + Returns + ------- + new_gm : torch.fx.GraphModule + A new GraphModule containing the nodes between `start_node` and `end_node`. + Automatically adds placeholder nodes for inputs if needed. + value_remap : dict + Updated mapping of original nodes to the corresponding nodes in the new subgraph. + Can be used to chain multiple subgraph extractions together. + + Notes + ----- + - Placeholder nodes are automatically created for any input nodes that are not in `value_remap`. + - The output of the subgraph is set to the last node before `end_node`. + - The resulting GraphModule can be called with the input tensors corresponding to the placeholders. 
+ """ + assert isinstance(graph_module, fx.GraphModule) + assert isinstance(value_remap, dict) + assert ( + len(value_remap) > 0 + ), "Remap dict cant be empty for slices in the middle of the model" + graph = graph_module.graph + new_nodes = [] + new_graph = fx.Graph() + keep = False + + for node in graph.nodes: + if node.name == start_node: + keep = True + if node.name == end_node: + break + if keep: + new_nodes.append(node) + + assert len(new_nodes) > 0, "Node list must not be empty" + + # Adds placeholder to the beginning of subgraph so that its forward can take an input + first_node = new_nodes[0] + for arg in first_node.args: + if isinstance(arg, fx.Node): + ph = new_graph.placeholder(f"input_{arg.name}") + value_remap[arg] = ph + + for node in new_nodes: + new_node = new_graph.node_copy(node, lambda n: value_remap[n]) + value_remap[node] = new_node + + last_node = new_nodes[-1] + new_graph.output(value_remap[last_node]) + + new_gm = fx.GraphModule(root=graph_module, graph=new_graph) + + return new_gm, value_remap + + +def get_suffix_submodule( + graph_module: fx.GraphModule, value_remap: dict, start_node: str +): + """ + Extracts the subgraph from `start_node` (inclusive) to the final output of the model. + + Parameters + ---------- + graph_module : fx.GraphModule + The FX-traced full model. + value_remap : dict + Mapping from previous nodes to their placeholders/substitutes (for start_node input). + start_node : str + Name of the node where the suffix begins. + + Returns + ------- + suffix_gm : fx.GraphModule + FX GraphModule for the suffix. + value_remap : dict + Updated mapping including suffix nodes. 
+ """ + graph = graph_module.graph + new_graph = fx.Graph() + new_nodes = [] + keep = False + + for node in graph.nodes: + if node.name == start_node: + keep = True + if keep: + new_nodes.append(node) + + assert len(new_nodes) > 0, "Suffix nodes cannot be empty" + + # Add a placeholder for the start node input if it hasn't been mapped yet + first_node = new_nodes[0] + for arg in first_node.args: + if isinstance(arg, fx.Node): + ph = new_graph.placeholder(f"input_{arg.name}") + value_remap[arg] = ph + + for node in new_nodes: + new_node = new_graph.node_copy(node, lambda n: value_remap[n]) + value_remap[node] = new_node + + # Explicitly define output of the subgraph + last_node = new_nodes[-1] + new_graph.output(value_remap[last_node]) + + suffix_gm = fx.GraphModule(graph_module, new_graph) + return suffix_gm, value_remap + + +if __name__ == "__main__": + model = resnet50(pretrained=True) + input = torch.randn(1, 3, 224, 224) + + prune_conv_modules, prune_modules_name = collect_convolution_layers_to_prune( + model=model + ) + + reformatted = [] + for name in prune_modules_name: + reformatted_name = "_".join(name.split(".")) + reformatted.append(reformatted_name) + + end = reformatted[0] + gm = fx.symbolic_trace(model) + prefix_gm, remap = get_initial_prefix_submodule(graph_module=gm, end_node=end) + + start_node = reformatted[0] + end_node = reformatted[1] + + subgraph, remap_dict = get_fx_submodule( + graph_module=gm, value_remap=remap, start_node=start_node, end_node=end_node + ) + + out = prefix_gm(input) + out = subgraph(out) + + print(out.shape) + + outputs = {} + + def hook_fn(model, input, output): + outputs["maxpool"] = output + + model.maxpool.register_forward_hook(hook_fn) + model(input) + print(outputs["maxpool"].shape) + + suffix_module, remap_final = get_suffix_submodule( + graph_module=gm, value_remap=remap, start_node=end_node + ) + out = suffix_module(out) + # print(suffix_module) + + direct_out = model(input) + + assert torch.allclose(out, 
direct_out) + print(out.shape) + print(direct_out.shape) From 7157e73dbe6bc120948995ae78f1853df42f3526 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:09:29 +0530 Subject: [PATCH 5/7] Minor docs fix --- compress/heuristics.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/compress/heuristics.py b/compress/heuristics.py index 60f9202..7d8e358 100644 --- a/compress/heuristics.py +++ b/compress/heuristics.py @@ -42,7 +42,7 @@ def collect_convolution_layers_to_prune(model: nn.Module): Collect convolutional layers from a model that are eligible for pruning. By default this picks all modules whose name contains 'conv' but does not - end with 'conv3' (e.g. skips the final 1×1 conv in ResNet bottlenecks). + end with 'conv3' (e.g. skips the final 1x1 conv in ResNet bottlenecks). Parameters ---------- From a3637e9ad0d771946ebccc39dd5cdf0f0ae501cf Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:10:04 +0530 Subject: [PATCH 6/7] Add submodule extraction logic --- compress/graph_extractor.py | 261 ++++++++++++++++++++++++++++++++---- 1 file changed, 236 insertions(+), 25 deletions(-) diff --git a/compress/graph_extractor.py b/compress/graph_extractor.py index 16a7a2d..1b684ed 100644 --- a/compress/graph_extractor.py +++ b/compress/graph_extractor.py @@ -1,7 +1,17 @@ +from typing import Optional +import numpy as np import torch import torch.fx as fx +from torch import nn from compress.heuristics import collect_convolution_layers_to_prune +from compress.osscar_utils import ( + get_XtY, + get_coeff_h, + reshape_conv_layer_input, + reshape_filter, + get_optimal_W, +) from model.resnet import resnet50 @@ -160,7 +170,7 @@ def get_suffix_submodule( assert len(new_nodes) > 0, "Suffix nodes cannot be empty" - # Add a placeholder for the start node input if it hasn't been mapped yet + # Adds placeholder to the beginning of subgraph so that 
its forward can take an input first_node = new_nodes[0] for arg in first_node.args: if isinstance(arg, fx.Node): @@ -179,10 +189,179 @@ def get_suffix_submodule( return suffix_gm, value_remap +def get_all_subnets(prune_modules_name, graph_module): + """ + Build prefix / middle / suffix sub-networks around layers to be pruned. + + Parameters + ---------- + prune_modules_name : list[str] + Names of modules in `graph_module` that are candidates for pruning. + + Returns + ------- + prune_subnets : list[torch.fx.GraphModule] + A list of GraphModules corresponding to each subnetwork *to be pruned*. + These subnets may change (weights or architecture) during pruning passes. + + dense_subnets : list[torch.fx.GraphModule] + A list of GraphModules with exactly the same slices as `prune_subnets`, + but using the original (dense) architecture and weights. This list remains + fixed across pruning passes. + + Notes + ----- + - The sequence of subnets covers: the prefix up to the first prune module, + the individual prune-module slices, and the suffix after the last prune module. + - `value_remap` is used internally to maintain graph node identity across slices. 
+ """ + assert isinstance( + graph_module, fx.GraphModule + ), "Graph module must be an instance of fx.GraphModule" + assert len(prune_modules_name) > 0, "Prune list must not be empty" + prune_subnets = [] + dense_subnets = [] + remap = {} + + gm = graph_module + + for idx, name in enumerate(prune_modules_name): + fx_name = "_".join(name.split(".")) + + if idx == 0: + subnet, remap = get_initial_prefix_submodule( + graph_module=gm, end_node=fx_name + ) + elif idx == len(prune_modules_name) - 1: + subnet, remap = get_suffix_submodule( + graph_module=gm, value_remap=remap, start_node=fx_name + ) + else: + end_node = "_".join(prune_modules_name[idx + 1].split(".")) + subnet, remap = get_fx_submodule( + graph_module=gm, + value_remap=remap, + start_node=fx_name, + end_node=end_node, + ) + + prune_subnets.append(subnet) + dense_subnets.append(subnet) + + return prune_subnets, dense_subnets + + +def perform_local_search( + w_optimal, + layer, + p, + prune_by_iter: Optional[list], + sym_diff_per_iter: Optional[list], + prune_per_iter=2, +): + assert isinstance(layer, nn.Conv2d) + assert isinstance(p, int) and p > 0 + if prune_by_iter is None: + prune_list = [] + num_iterations = p // prune_per_iter + rem = p % prune_per_iter + zero_set = set() + + if p >= prune_per_iter: + prune_list.extend([prune_per_iter for i in range(num_iterations)]) + if rem > 0: + prune_list.extend([rem]) + else: + + assert np.sum(np.array(prune_by_iter)) == p + prune_list = prune_by_iter + + if sym_diff_per_iter is not None: + assert len(prune_list) == len(sym_diff_per_iter) + assert all(p <= t for p, t in zip(prune_list, sym_diff_per_iter)) + + else: + sym_diff_per_iter = prune_list + + # for i in range(len(sym_diff_per_iter)): + # num_prune_iter = sym_diff_per_iter[] + + +def osscar_prune( + dense_subnet, + prune_subnet, + dense_input, + pruned_input, + cached_dense_out, + cached_prune_out, + layer_name, + prune_layer_counts, +): + assert isinstance(dense_subnet, nn.Module), "Needs a valid 
subnet(nn.Module)" + assert isinstance(prune_subnet, nn.Module), "Needs a valid subnet(nn.Module)" + assert isinstance(dense_input, torch.Tensor) + assert isinstance(pruned_input, torch.Tensor) + assert pruned_input.ndim == 4, "Pruned input tensor must be of the shape B, C, H, W" + assert dense_input.ndim == 4, "Dense input tensor must be of the shape B, C, H, W" + + num_batches = dense_input.shape[0] + conv_module = None + reshaped_conv_wt = None + + # The way subnets are designed nmakes the first module of a subnet the conv2d thats needed to be pruned + for _, module in dense_subnet.named_modules(): + conv_module = module + assert isinstance(conv_module, nn.Conv2d) + reshaped_conv_wt = reshape_filter(conv_module.weight) + break + + print(conv_module) + + assert ( + num_batches == pruned_input.shape[0] + ), "Dense and prune inputs must have the same number of images" + + total_xtx = 0 + total_xty = 0 + + for batch_idx in range(num_batches): + dense_batch = dense_input[batch_idx, :, :, :] + pruned_batch = pruned_input[batch_idx, :, :, :] + dense_batch = reshape_conv_layer_input(input=dense_batch, layer=conv_module) + pruned_batch = reshape_conv_layer_input(input=pruned_batch, layer=conv_module) + total_xtx += get_coeff_h(pruned_batch) + total_xty += get_XtY(pruned_batch, dense_batch) + + # Note to self: Remember to double check + total_xtx /= num_batches + total_xty /= num_batches + + w_optimal = get_optimal_W( + gram_xx=total_xtx, gram_xy=total_xty, dense_weights=reshaped_conv_wt + ) + + # Note to self: prune_layer_counts is supposed to be a dicttionary now, beware, dont forget to change + p = prune_layer_counts["layer_name"] + # pruned_layer = perform_local_search( + # w_optimal, layer_name, prune_layer_counts + # ) + + # Perform forward till the next layer to be pruned and cache results + # dense_out = dense_subnet(dense_input[batch_idx, :, :, :]) + # prune_out = prune_subnet(pruned_input[batch_idx, :, :, :]) + + # cached_out = + + if __name__ == "__main__": 
model = resnet50(pretrained=True) input = torch.randn(1, 3, 224, 224) + weights = model.conv1.weight + # print(weights.shape) + + # model.conv1.weight = weights[:, :2, :, :] + prune_conv_modules, prune_modules_name = collect_convolution_layers_to_prune( model=model ) @@ -194,37 +373,69 @@ def get_suffix_submodule( end = reformatted[0] gm = fx.symbolic_trace(model) - prefix_gm, remap = get_initial_prefix_submodule(graph_module=gm, end_node=end) - - start_node = reformatted[0] - end_node = reformatted[1] - subgraph, remap_dict = get_fx_submodule( - graph_module=gm, value_remap=remap, start_node=start_node, end_node=end_node + prune_subnets, dense_subnets = get_all_subnets( + graph_module=gm, prune_modules_name=prune_modules_name ) + assert len(prune_subnets) == len(dense_subnets) - out = prefix_gm(input) - out = subgraph(out) + print(prune_subnets[0]) + print(dense_subnets[1]) + # prefix_gm, remap = get_initial_prefix_submodule(graph_module=gm, end_node=end) - print(out.shape) + # start_node = reformatted[0] + # end_node = reformatted[1] - outputs = {} + # subnet, remap_dict = get_fx_submodule( + # graph_module=gm, value_remap=remap, start_node=start_node, end_node=end_node + # ) - def hook_fn(model, input, output): - outputs["maxpool"] = output + # out = prefix_gm(input) + # # print(out.shape) - model.maxpool.register_forward_hook(hook_fn) - model(input) - print(outputs["maxpool"].shape) + # # out = out[:, :2, :, :] + # # model.conv1.weight = torch.nn.Parameter(weights[:, :2, :, :], requires_grad=True) - suffix_module, remap_final = get_suffix_submodule( - graph_module=gm, value_remap=remap, start_node=end_node - ) - out = suffix_module(out) - # print(suffix_module) + # out = subnet(out) + + # print(subnet.conv1) + + # print(out.shape) + + # outputs = {} + + # def hook_fn(model, input, output): + # outputs["maxpool"] = output + + # model.maxpool.register_forward_hook(hook_fn) + # model(input) + # print(outputs["maxpool"].shape) + + # suffix_module, remap_final = 
get_suffix_submodule( + # graph_module=gm, value_remap=remap, start_node=end_node + # ) + # out = suffix_module(out) + # # print(suffix_module) + + # direct_out = model(input) - direct_out = model(input) + # assert torch.allclose(out, direct_out) + # print(out.shape) + # print(direct_out.shape) - assert torch.allclose(out, direct_out) - print(out.shape) - print(direct_out.shape) + # perform_local_search() + p = 1 + prune_per_iter = 3 + num_iterations = p // prune_per_iter + rem = p % prune_per_iter + zero_set = set() + sym_diff_per_iter = [] + # while p > 1: + # p = p - prune_per_iter + # sym_diff_per_iter.extend([prune_per_iter]) + if p >= prune_per_iter: + sym_diff_per_iter.extend([prune_per_iter for i in range(num_iterations)]) + if rem > 0: + sym_diff_per_iter.extend([rem]) + print(sym_diff_per_iter) + # for From 8ed5cefe7c3aef916c322712fe9134773e1fd549 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Fri, 3 Oct 2025 18:11:19 +0530 Subject: [PATCH 7/7] Add Osscar specific utilities --- compress/osscar_utils.py | 840 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 840 insertions(+) create mode 100644 compress/osscar_utils.py diff --git a/compress/osscar_utils.py b/compress/osscar_utils.py new file mode 100644 index 0000000..6370eed --- /dev/null +++ b/compress/osscar_utils.py @@ -0,0 +1,840 @@ +import random +from typing import cast + +import numpy as np +import torch +from torch import nn + +from compress.heuristics import * +from data.load_data import build_eval_dataloader, build_eval_dataset +from model.resnet import resnet50 + + +def reshape_filter(filter): + """ + Rearrange a Conv2d weight tensor according to OSSCAR paper. + + Parameters + ---------- + filter : torch.Tensor + Weight tensor of shape (C_out, C_in, K_h, K_w). + + Returns + ------- + torch.Tensor + Reshaped tensor of shape (C_in*K_h*K_w, C_out). 
def reshape_conv_layer_input(input, layer):
    """
    Unfold an input tensor using a Conv2d layer's settings.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor of shape (C, H, W) or (B, C, H, W).
    layer : nn.Conv2d
        Conv2d layer whose kernel/stride/dilation/padding define the unfolding.

    Returns
    -------
    torch.Tensor
        Unfolded tensor of shape (B*L, C_in*K_h*K_w),
        where L is the number of sliding locations per image.
    """
    assert isinstance(input, torch.Tensor), "Input must be a tensor"
    assert isinstance(layer, nn.Conv2d), "Layer must be a nn.Conv2d layer"
    assert input.ndim in (3, 4), "Input tensors must be either (C, H, W) or (B, C, H, W)"

    # Promote a single image to a batch of one so Unfold always sees 4-D input
    batched = input if input.ndim == 4 else input.unsqueeze(dim=0)
    _, _, height, width = batched.shape

    # A dilated kernel covers a larger effective window than its nominal size
    eff_kh = (layer.kernel_size[0] - 1) * layer.dilation[0] + 1
    eff_kw = (layer.kernel_size[1] - 1) * layer.dilation[1] + 1

    # Resolve the layer's padding spec (string / tuple / scalar) to explicit ints
    if isinstance(layer.padding, str) and layer.padding == "same":
        pad_h = ((layer.stride[0] * height - height) + eff_kh - layer.stride[0]) // 2
        pad_w = ((layer.stride[1] * width - width) + eff_kw - layer.stride[1]) // 2
    elif isinstance(layer.padding, tuple):
        pad_h, pad_w = layer.padding
    else:
        pad_h = pad_w = layer.padding

    # Silence pylance's static check
    pad_h = cast(int, pad_h)
    pad_w = cast(int, pad_w)

    unfolder = nn.Unfold(
        kernel_size=layer.kernel_size,
        dilation=layer.dilation,
        padding=(pad_h, pad_w),
        stride=layer.stride,
    )

    # (B, C*Kh*Kw, L) -> (C*Kh*Kw, B, L) -> (C*Kh*Kw, B*L) -> (B*L, C*Kh*Kw)
    patches = unfolder(batched)
    patches = patches.permute(1, 0, 2).flatten(1)
    return patches.T
def get_coeff_h(design_matrix):
    """
    Compute the input autocorrelation (H) matrix from a 2D design matrix.

    Parameters
    ----------
    design_matrix : torch.Tensor
        2D tensor of shape (N, D): N samples (e.g. unfolded spatial positions
        across the batch) by D features (e.g. C_in * K_h * K_w).

    Returns
    -------
    torch.Tensor
        (D, D) autocorrelation matrix H = XᵀX of the input features.
    """
    assert isinstance(design_matrix, torch.Tensor)
    assert design_matrix.ndim == 2, "Requires the reshaped design matrix"

    return torch.matmul(design_matrix.transpose(0, 1), design_matrix)


def get_XtY(X, Y):
    """
    Compute the cross-correlation matrix XᵀY between two unfolded tensors.

    Parameters
    ----------
    X : torch.Tensor
        2D tensor of shape (N, D₁) — an unfolded design matrix from
        `reshape_conv_layer_input`.
    Y : torch.Tensor
        2D tensor of shape (N, D₂) — same number of rows as `X`, possibly a
        different feature dimension.

    Returns
    -------
    torch.Tensor
        (D₁, D₂) cross-correlation matrix XᵀY.
    """
    assert isinstance(X, torch.Tensor)
    assert isinstance(Y, torch.Tensor)

    Xt = X.transpose(0, 1)

    # Inner dimensions must agree for the product below
    assert Xt.shape[1] == Y.shape[0]
    return torch.matmul(Xt, Y)
def compute_layer_loss(dense_weights, pruned_weights, input):
    """
    Compute the layer reconstruction loss using H and G coefficients.

    Parameters
    ----------
    dense_weights : torch.Tensor
        Original (dense) layer weights reshaped to 2D.
    pruned_weights : torch.Tensor
        Pruned layer weights reshaped to 2D.
    input : torch.Tensor
        Layer input activations before non-linearity.
        Shape (B, C, H, W) or (C, H, W).

    Returns
    -------
    torch.Tensor
        Scalar loss measuring how well the pruned layer reconstructs the original.

    Notes
    -----
    NOTE(review): `get_coeff_h` asserts a 2-D design matrix, while this
    docstring allows a 3-D/4-D `input` — confirm whether `input` must be
    pre-unfolded (e.g. via `reshape_conv_layer_input`) before this call.
    NOTE(review): the quadratic expansion of ||X W_d - X W_p||² normally has a
    *negative* cross term (0.5·tr(WᵀHW) - tr(GᵀW)); the `+` below looks
    suspicious — confirm the intended sign.
    """
    # G = W_dᵀ · input (cross term), H = XᵀX (input autocorrelation)
    G = get_coeff_g(dense_layer_weights=dense_weights, layer_input=input)
    H = get_coeff_h(design_matrix=input)
    # Quadratic term W_pᵀ H W_p and cross term Gᵀ W_p
    A = (pruned_weights.T @ H) @ pruned_weights
    B = G.T @ pruned_weights

    assert A.ndim == B.ndim == 2, "Trace can be computed only for 2D matrices"
    loss = 0.5 * torch.trace(A) + torch.trace(B)

    return loss
def num_params_in_prune_channels(layers):
    """
    Compute the total number of parameters across a list of convolutional layers.

    Parameters
    ----------
    layers : list of nn.Conv2d
        The convolutional layers whose parameters you want to count.

    Returns
    -------
    int
        Total number of parameters (weights + biases if present) in the given layers.
    """
    # Validate every entry up front, then tally at once
    for conv in layers:
        assert isinstance(conv, nn.Conv2d)

    return sum(count_parameters(conv, in_millions=False) for conv in layers)


def recalculate_importance(rem_channels_layer_wise):
    """
    Recalculate normalized importance weights for each layer
    based on the remaining number of channels per layer.

    Parameters
    ----------
    rem_channels_layer_wise : array-like of int
        Number of channels remaining in each layer.

    Returns
    -------
    numpy.ndarray
        Normalized importance values for each layer (sum to 1).
    """
    channel_counts = np.asarray(rem_channels_layer_wise)
    # Normalize so the per-layer importances form a probability distribution
    return channel_counts / channel_counts.sum()
def distribute_remaining_parameters(
    rem_params, rem_channels_per_layer, layers, num_iters=20, allowable_tol=250
):
    """
    Stochastically allocate leftover parameter removals across layers.

    At each iteration, a layer is sampled according to a probability
    distribution proportional to its remaining channels. One input channel
    is removed from the chosen layer if possible, and the remaining parameter
    budget is updated. The loop stops when the budget is within the allowable
    tolerance or after `num_iters` iterations.

    Parameters
    ----------
    rem_params : int
        Remaining number of parameters to remove.
    rem_channels_per_layer : list of int
        Remaining number of channels per layer (mutable, will be updated).
    layers : list of nn.Conv2d
        Convolutional layers eligible for pruning.
    num_iters : int, optional
        Maximum number of allocation iterations. Default is 20.
    allowable_tol : int, optional
        Stop when the remaining parameter budget is within this tolerance. Default is 250.

    Returns
    -------
    tuple
        p : numpy.ndarray
            Array of additional channels removed per layer.
        rem_params : int
            Remaining number of parameters still to remove after allocation.
    """
    num_layers = len(layers)
    layer_choices = np.arange(num_layers)
    # Per-layer count of extra channels removed by this routine
    p = np.zeros(num_layers, dtype=np.int32)
    # Fixed seed keeps the stochastic allocation reproducible across runs
    rng = random.Random(3)

    rem_imp = recalculate_importance(rem_channels_layer_wise=rem_channels_per_layer)

    for i in range(num_iters):
        # Sample a layer with probability proportional to its remaining channels
        random_layer_idx = rng.choices(layer_choices, weights=rem_imp)[0]
        assert isinstance(random_layer_idx, (int, np.integer))

        layer = layers[random_layer_idx]

        assert isinstance(layer, nn.Conv2d)
        # Removing one input channel frees Kh*Kw weights in every filter
        params_removed = (
            layer.kernel_size[0] * layer.kernel_size[1] * layer.out_channels
        )
        count_remove_params = rem_params - params_removed
        # Remove a channel only if the layer keeps >= 1 channel and budget allows
        if rem_channels_per_layer[random_layer_idx] > 1 and count_remove_params >= 0:
            p[random_layer_idx] += 1
            rem_channels_per_layer[random_layer_idx] -= 1

        # NOTE(review): rem_params - count_remove_params == params_removed >= 0,
        # so this guard can never fire — confirm the intended condition.
        if rem_params - count_remove_params < 0:
            continue

        # NOTE(review): the budget is decremented even when no channel was
        # actually removed above (the `if` guard failed) — confirm intentional.
        rem_params = count_remove_params
        rem_imp = recalculate_importance(rem_channels_per_layer)

        if rem_params <= allowable_tol:
            break

    return p, rem_params
def get_count_prune_channels(model, prune_percentage, allowable_tol=250):
    """
    Compute per-layer channel pruning counts to reach a target global prune percentage.

    This function:
    1. Collects the convolutional layers eligible for pruning.
    2. Computes how many parameters to remove from each layer based on an
       importance heuristic.
    3. Converts parameter removals into channel counts per layer.
    4. Distributes any remaining parameter removals stochastically to meet
       the target budget within the allowable tolerance.

    Parameters
    ----------
    model : nn.Module
        The model to prune.
    prune_percentage : float
        Target fraction of total model parameters to remove (0–1).
    allowable_tol : int, optional
        Tolerance for how far from the target parameter count to allow. Default is 250.

    Returns
    -------
    tuple
        num_channels_left_per_layer : list of int
            Number of input channels to keep per layer after pruning.
        p : list or numpy.ndarray
            Number of input channels to prune per layer (sum of deterministic and random).
        remaining_params_to_prune : int
            Number of parameters still to prune after allocation (ideally <= allowable_tol).
    """
    model_total_params = count_parameters(model, in_millions=False)

    # Collect layers eligible to be pruned
    layers, _ = collect_convolution_layers_to_prune(model)

    # Quick sanity check: ensures that total params to be removed doesn't
    # exceed those that can be provided by eligible layers
    eligible_prune_params_count = num_params_in_prune_channels(layers=layers)
    total_params_to_prune = int(model_total_params * prune_percentage)

    assert eligible_prune_params_count > total_params_to_prune

    # Computes relative importance of each layer, higher importance => More
    # channels pruned from this layer
    importance_list = compute_layer_importance_heuristic(layers)

    assert math.isclose(
        np.sum(importance_list), 1.0, abs_tol=0.00001
    ), "importance scores must sum to 1"

    # Number of params to remove from every eligible layer (floored so we
    # never over-prune; the shortfall is redistributed below).
    # Bug fix: a dead un-floored assignment that was immediately overwritten
    # has been removed.
    num_prune_params_by_layer = np.floor(total_params_to_prune * importance_list)

    revised_prune_params_count = 0
    p = []
    num_channels_left_per_layer = []

    for idx, layer in enumerate(layers):
        assert isinstance(layer, nn.Conv2d)

        num_spatial_params = layer.kernel_size[0] * layer.kernel_size[1]
        num_channels_per_filter = layer.in_channels
        num_filters = layer.out_channels

        # Num channels to remove from every layer (defined as per osscar)
        num_channels_to_remove = num_prune_params_by_layer[idx] // (
            num_spatial_params * num_filters
        )
        p.append(num_channels_to_remove)

        assert (
            num_channels_to_remove < num_channels_per_filter
        ), "Cant remove all channels in a filter"

        num_params_removed = num_spatial_params * num_channels_to_remove * num_filters
        revised_prune_params_count += num_params_removed

        num_channels_left = num_channels_per_filter - num_channels_to_remove
        num_channels_left_per_layer.append(num_channels_left)

    remaining_params_to_prune = total_params_to_prune - revised_prune_params_count

    # Flooring above under-shoots the budget; hand the leftover to the
    # stochastic allocator when it exceeds the tolerance.
    if remaining_params_to_prune > allowable_tol:
        p_rem, remaining_params_to_prune = distribute_remaining_parameters(
            rem_params=remaining_params_to_prune,
            rem_channels_per_layer=num_channels_left_per_layer,
            layers=layers,
            allowable_tol=allowable_tol,
        )

        p = np.array(p) + np.array(p_rem)

    return num_channels_left_per_layer, p, remaining_params_to_prune


def save_and_accumulate_layer_input(activations, layer_name):
    """
    Create a forward hook that computes and accumulates a per-layer statistic
    (e.g. Gram matrix) from the layer's input.

    Parameters
    ----------
    activations : dict
        A mutable dictionary where the accumulated statistic will be stored.
        Keys are layer names; values are initialized to 0 or an existing tensor.
    layer_name : str
        Key under which to accumulate this layer's statistic in `activations`.

    Returns
    -------
    hook : callable
        A forward hook with signature (module, inputs, output) suitable for
        passing to `module.register_forward_hook`. On each forward pass it
        adds `get_coeff_h(design_matrix=...)` of the module's first positional
        input to `activations[layer_name]`.

    Notes
    -----
    The hook expects the module's (first) input to be a 2-D design matrix.
    NOTE(review): Conv2d modules receive 4-D (B, C, H, W) inputs — such inputs
    presumably need `reshape_conv_layer_input` before `get_coeff_h`; confirm
    where the unfolding is meant to happen.
    """

    def hook(module, inputs, output):
        # Bug fix: PyTorch passes the forward args as a *tuple*; the previous
        # code called `.ndim` on the tuple itself, which can never succeed.
        design_matrix = inputs[0]
        assert design_matrix.ndim == 2
        coeff_h = get_coeff_h(design_matrix=design_matrix)
        activations[layer_name] += coeff_h

    return hook
def register_hooks_to_collect_outs(prune_modules, prune_module_names, hook_fn):
    """
    Attach forward hooks to a list of Conv2d modules to accumulate per-layer statistics.

    Parameters
    ----------
    prune_modules : list[nn.Module]
        The modules (layers) to attach hooks to (e.g. Conv2d layers to be pruned).
    prune_module_names : list[str]
        Names corresponding to each module in `prune_modules`. Must match length.
    hook_fn : callable
        A factory that accepts (activations: dict, layer_name: str) and returns
        a forward hook function. For example, `save_and_accumulate_layer_input`.

    Returns
    -------
    activations : dict
        A dictionary keyed by `prune_module_names` whose values are initialized
        to 0. On each forward pass, the registered hooks will add to these values.

    Notes
    -----
    Useful for computing and caching Gram matrices or other layer input
    statistics over a calibration dataset before pruning.
    NOTE(review): the `RemovableHandle`s returned by `register_forward_hook`
    are discarded, so the hooks can never be detached — consider returning
    them if the hooks must be removed after calibration.
    """
    # Robustness fix: the docstring promised matching lengths but nothing
    # enforced it; a mismatch previously surfaced as an obscure IndexError.
    assert len(prune_modules) == len(prune_module_names), (
        "prune_modules and prune_module_names must have the same length"
    )

    gram_activations = {name: 0.0 for name in prune_module_names}
    for module, module_name in zip(prune_modules, prune_module_names):
        assert isinstance(module, nn.Conv2d)
        module.register_forward_hook(
            hook=hook_fn(activations=gram_activations, layer_name=module_name)
        )

    return gram_activations


def accumulate_xtx_statistics(model, calibration_dataloader):
    """
    Run the model over the calibration data to accumulate per-layer
    XᵀX (Gram) matrices using registered forward hooks.

    Parameters
    ----------
    model : nn.Module
        Model whose hooked layers accumulate statistics as a side effect.
    calibration_dataloader : iterable
        Yields (images, labels) batches; labels are ignored.
    """
    # Performance fix: the outputs are discarded, so building autograd graphs
    # for every calibration batch only wasted memory — run without gradients.
    with torch.no_grad():
        for images, _ in calibration_dataloader:
            model(images)
def recompute_X(prune_mask, X, layer_in_channels, kernel_height, kernel_width):
    """
    Recompute the unfolded input matrix X after pruning input channels.

    Parameters
    ----------
    prune_mask : array-like of bool
        Boolean mask of length `layer_in_channels` indicating which input
        channels to keep (`True`) or drop (`False`).
    X : torch.Tensor or np.ndarray
        2D unfolded input matrix of shape (N, M) where N corresponds to the
        sliding positions/samples and M to the flattened channel–kernel
        elements (as produced by `reshape_conv_layer_input`).
    layer_in_channels : int
        Number of input channels in the convolution layer prior to pruning.
    kernel_height : int
        Height of the convolution kernel.
    kernel_width : int
        Width of the convolution kernel.

    Returns
    -------
    torch.Tensor or np.ndarray
        Pruned unfolded input matrix with *columns* corresponding only to kept
        input channels. (Doc fix: the previous docstring said "rows", but the
        code selects columns — `X[:, mask]`.)
    """
    assert X.ndim == 2, "Weight matrix must have already been reshaped"
    assert (
        len(prune_mask) == layer_in_channels
    ), "The length of the indicator vector and the number of in_channels must be the same"

    # Number of elements needed to represent one filter in_channel in the
    # reshaped 2d matrix
    numel_one_channel = kernel_height * kernel_width
    _, X_width = X.shape

    mask = np.ones(X_width, dtype=bool)
    for channel, keep in enumerate(prune_mask):
        if not keep:
            start = channel * numel_one_channel
            mask[start : start + numel_one_channel] = False

    return X[:, mask]
def recompute_H(prune_mask, H, kernel_height, kernel_width, activation_out_channels):
    """
    Recompute the coefficient matrix H after pruning output channels.

    Parameters
    ----------
    prune_mask : array-like of bool
        Boolean mask of length `activation_out_channels` indicating which
        channels to keep (`True`) or drop (`False`).
    H : torch.Tensor or np.ndarray
        2D square matrix of shape (N, N) typically representing a Hessian or
        covariance term over activations.
    kernel_height : int
        Height of the convolution kernel.
    kernel_width : int
        Width of the convolution kernel.
    activation_out_channels : int
        Number of channels prior to pruning.

    Returns
    -------
    torch.Tensor or np.ndarray
        Pruned square matrix containing only rows and columns corresponding
        to kept channels.
    """
    assert len(prune_mask) == activation_out_channels
    assert H.ndim == 2

    numel_one_channel = kernel_height * kernel_width
    kept_indices = []

    # Bug fix: the condition was inverted (`if not indicator`), which kept the
    # *dropped* channels — contradicting the docstring and the convention used
    # by recompute_X / recompute_W (truthy == keep).
    for channel, keep in enumerate(prune_mask):
        if keep:
            start = channel * numel_one_channel
            kept_indices.extend(range(start, start + numel_one_channel))

    # np.ix_ selects the kept rows AND columns in one shot
    H_updated = H[np.ix_(kept_indices, kept_indices)]

    return H_updated
def recompute_W(prune_mask, W, layer_in_channels, kernel_height, kernel_width):
    """
    Recompute the weight matrix W after pruning input channels.

    Parameters
    ----------
    prune_mask : array-like of bool
        Boolean mask of length `layer_in_channels` indicating which input
        channels to keep (`True`) or drop (`False`).
    W : torch.Tensor or np.ndarray
        2D weight matrix of shape (N, M) where N corresponds to flattened
        channel–kernel elements (as produced by `reshape_filter`).
    layer_in_channels : int
        Number of input channels in the convolution layer prior to pruning.
    kernel_height : int
        Height of the convolution kernel.
    kernel_width : int
        Width of the convolution kernel.

    Returns
    -------
    torch.Tensor or np.ndarray
        Pruned weight matrix with rows corresponding only to kept input channels.
    """
    assert W.ndim == 2, "Weight matrix must have already been reshaped"
    assert (
        len(prune_mask) == layer_in_channels
    ), "The length of the indicator vector and the number of in_channels must be the same"

    # Number of elements needed to represent one filter in_channel in the
    # reshaped 2d weight matrix
    numel_one_channel = kernel_height * kernel_width
    W_height, _ = W.shape

    # Bug fix: removed a leftover debug `print(slice_indices)`
    mask = np.ones(W_height, dtype=bool)
    for channel, keep in enumerate(prune_mask):
        if not keep:
            start = channel * numel_one_channel
            mask[start : start + numel_one_channel] = False

    return W[mask, :]
def compute_X_via_cholesky(A, B, C, eig_fallback_tol=1e-12):
    """
    Compute X = A^{-1} @ (B @ C) assuming A is PSD (possibly singular).

    Parameters
    ----------
    A : torch.Tensor
        Square PSD matrix of shape (n, n).
    B, C : torch.Tensor
        Factors whose product B @ C forms the right-hand side.
    eig_fallback_tol : float, optional
        Lower clamp on regularized eigenvalues in the eigendecomposition
        fallback. Default is 1e-12.

    Returns
    -------
    torch.Tensor
        Solution X with the same dtype/device as the inputs.

    Notes
    -----
    - Adds adaptive jitter to A (A_reg = A + jitter*I) before solving.
    - Attempts Cholesky; if A_reg is not SPD, falls back to an
      eigendecomposition-based solve with eigenvalue clamping.
    """
    assert A.ndim == 2 and A.shape[0] == A.shape[1], "A must be square"
    Y = B @ C
    n = A.shape[0]
    device = A.device
    dtype = A.dtype
    jitter = 1e-6

    # Adaptive scaling for proper regularization to treat ill-conditioned
    # matrices (typo fix: was "Adpative")
    scale = max(
        torch.trace(A).abs().item() / n,
        torch.linalg.matrix_norm(A, ord='fro').item() / (n**0.5),
        1.0,
    )
    lambda_j = jitter * scale
    A_reg = A + lambda_j * torch.eye(n, device=device, dtype=dtype)

    # Cholesky decomposition only if A_reg is positive definite (SPD)
    # (comment fix: previously misnamed "A_ref")
    L, info = torch.linalg.cholesky_ex(A_reg)
    if info == 0:
        # LL^T X = Y => L Z = Y, then L^T X = Z
        Z = torch.linalg.solve_triangular(L, Y, upper=False, left=True)
        X = torch.linalg.solve_triangular(L.transpose(-2, -1), Z, upper=True, left=True)
        return X

    # Fallback: eigen-decomposition for PSD
    # (A_reg = A + λI shares A's eigenvectors, so eigh(A) suffices)
    eigvals, eigvecs = torch.linalg.eigh(A)
    reg_eig = eigvals + lambda_j
    # If reg_eig very small, clamp with tol so the reciprocal stays finite
    reg_eig_clamped = torch.where(
        reg_eig.abs() < eig_fallback_tol,
        torch.full_like(reg_eig, eig_fallback_tol),
        reg_eig,
    )

    # X = V diag(1/reg_eig_clamped) V^T Y
    VtY = eigvecs.transpose(-2, -1) @ Y
    scaled = VtY / reg_eig_clamped.unsqueeze(-1)
    X = eigvecs @ scaled
    return X
def get_optimal_W(gram_xx, gram_xy, dense_weights):
    """
    Compute the optimal weight matrix for a layer in OSSCAR using precomputed Gram matrices.

    Solves, in the least-squares sense:
        min_W ||X W - Y||_F^2
    where:
      - gram_xx = X^T X
      - gram_xy = X^T Y_dense
      - dense_weights = Y_dense reshaped as a 2D matrix

    Ill-conditioned or nearly singular Gram matrices are handled with adaptive
    Tikhonov regularization: Cholesky is attempted first, with a fallback to
    an eigendecomposition-based solve with eigenvalue clamping.

    Parameters
    ----------
    gram_xx : torch.Tensor, shape (N, N)
        The input Gram matrix X^T X. Must be square.
    gram_xy : torch.Tensor, shape (N, M)
        The cross Gram matrix X^T Y_dense. Rows must match `gram_xx`; columns
        must match the rows of `dense_weights`. (Square in the common case
        where pruned and dense design matrices share a feature dimension.)
    dense_weights : torch.Tensor, shape (M, C_out)
        The target dense weight matrix reshaped to 2D.

    Returns
    -------
    torch.Tensor, shape (N, C_out)
        The optimal weight matrix W* for the current layer.

    Raises
    ------
    AssertionError
        If input tensors are not 2D, or if shapes are incompatible.
    """
    assert isinstance(gram_xx, torch.Tensor)
    assert isinstance(gram_xy, torch.Tensor)
    assert isinstance(dense_weights, torch.Tensor)

    assert gram_xx.ndim == 2, "Gram matrix XtX should be 2D"
    assert gram_xy.ndim == 2, "Gram matrix XtY should be 2D"
    assert dense_weights.ndim == 2, "Weights should have been reshaped to 2D"
    assert gram_xx.shape[0] == gram_xx.shape[1], "Gram_xx must be square"
    # Fix: relaxed the over-restrictive "Gram_xy must be square" assert —
    # X^T Y may be rectangular (see get_XtY); only compatibility with the
    # solve below is required.
    assert gram_xy.shape[0] == gram_xx.shape[0], "Gram_xy rows must match Gram_xx"
    assert (
        gram_xy.shape[1] == dense_weights.shape[0]
    ), "Gram_xy columns must match dense weight rows"

    Y = gram_xy @ dense_weights
    n = gram_xx.shape[0]
    device = gram_xx.device
    dtype = gram_xx.dtype
    jitter = 1e-6
    eig_fallback_tol = 1e-12

    # Adaptive scaling for proper regularization to treat ill-conditioned
    # matrices (typo fix: was "Adpative")
    scale = max(
        torch.trace(gram_xx).abs().item() / n,
        torch.linalg.matrix_norm(gram_xx, ord='fro').item() / (n**0.5),
        1.0,
    )
    lambda_j = jitter * scale
    gram_xx_reg = gram_xx + lambda_j * torch.eye(n, device=device, dtype=dtype)

    # Cholesky decomposition only if gram_xx_reg is positive definite (SPD)
    L, info = torch.linalg.cholesky_ex(gram_xx_reg)
    if info == 0:
        # LL^T X = Y => L Z = Y, then L^T X = Z
        Z = torch.linalg.solve_triangular(L, Y, upper=False, left=True)
        X = torch.linalg.solve_triangular(L.transpose(-2, -1), Z, upper=True, left=True)
        return X

    # Fallback: eigen-decomposition for PSD
    eigvals, eigvecs = torch.linalg.eigh(gram_xx)
    reg_eig = eigvals + lambda_j
    # If reg_eig very small, clamp with tol so the reciprocal stays finite
    reg_eig_clamped = torch.where(
        reg_eig.abs() < eig_fallback_tol,
        torch.full_like(reg_eig, eig_fallback_tol),
        reg_eig,
    )

    # X = V diag(1/reg_eig_clamped) V^T Y
    VtY = eigvecs.transpose(-2, -1) @ Y
    scaled = VtY / reg_eig_clamped.unsqueeze(-1)
    X = eigvecs @ scaled
    return X
3)) + + # layer_input = torch.randn((3, 3, 4)) + # output = get_coeff_h(layer_input) + # print(output.shape) + + # input_activation = torch.randn(2, 3, 32, 32) + # pruned_activation = torch.randn(2, 2, 32, 32) + # mini_conv = nn.Conv2d(padding=1, kernel_size=3, in_channels=3, out_channels=8) + + # out = reshape_conv_layer_input(input_activation, mini_conv) + + # W = reshape_filter(mini_conv.weight) + # print(f"Reshaped W before pruning: {W.shape}") + + # prune_mask = [1, 1, 0] + # W_pruned = recompute_W(prune_mask, W, layer_in_channels=3, kernel_height=3, kernel_width=3) + + # print(f"Reshaped W after pruning: {W_pruned.shape}") + + # print("Reshaped activation shape before pruning:", out.shape) + + # X = recompute_X(prune_mask=prune_mask, X=out, layer_in_channels=3, kernel_height=3, kernel_width=3) + + # print("Reshaped activation shape after pruning:", X.shape) + + # corr = get_XtY(out, X) + # print("XtY shape", corr.shape) + + # print(get_coeff_h(out).shape) + + # print(model.named_children) +