From 5c59bc322cf6dead7a8ebea24bc59faa8b6610ed Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:19:27 +0530 Subject: [PATCH 1/7] Add display_subnet_info util --- compress/graph_extractor.py | 82 +++++++++++++------------------------ 1 file changed, 29 insertions(+), 53 deletions(-) diff --git a/compress/graph_extractor.py b/compress/graph_extractor.py index fec30c9..a8b8392 100644 --- a/compress/graph_extractor.py +++ b/compress/graph_extractor.py @@ -1,6 +1,5 @@ from typing import Optional from itertools import chain -import copy import numpy as np import torch @@ -9,7 +8,7 @@ from compress.heuristics import collect_convolution_layers_to_prune from compress.osscar_utils import get_external_nodes - + from model.resnet import resnet50 @@ -18,6 +17,7 @@ def clone_subnet(gm): env = {} for node in gm.graph.nodes: + def safe_lookup(old_node): if old_node in env: return env[old_node] @@ -67,7 +67,8 @@ def get_initial_prefix_submodule(gm, end_node, external_nodes): assert isinstance(gm, fx.GraphModule) prefix_graph = fx.Graph() env = {} - external_deps = set(chain.from_iterable(external_nodes.values())) + external_deps = set(external_nodes.values()) + external_deps_names = set(node.name for node in external_deps) out_dict = {} last_node = None @@ -80,11 +81,9 @@ def fetch_arg(n): if n in env: return env[n] - # If the original node is listed as an external dependency, create a placeholder - if n.name in external_deps: - ph = prefix_graph.placeholder(f"external__{n.name}") - env[n] = ph - return ph + ph = prefix_graph.placeholder(f"external__{n.name}") + env[n] = ph + return ph for node in gm.graph.nodes: if node.name == end_node: @@ -93,7 +92,7 @@ def fetch_arg(n): new_node = prefix_graph.node_copy(node, fetch_arg) env[node] = new_node last_node = new_node - if node.name in external_deps: + if node.name in external_deps_names: out_dict[node.name] = new_node # Explicitly adds output node @@ -144,7 +143,8 @@ def get_fx_submodule(gm, start_node, end_node, external_nodes): new_graph = fx.Graph() env = {} last_new_node = None - external_deps = set(chain.from_iterable(external_nodes.values())) + external_deps = set(external_nodes.values()) + external_deps_names = set(node.name for node in external_deps) out_dict = {} def fetch_arg(n): @@ -156,13 +156,7 @@ def fetch_arg(n): if n in env: return env[n] - # If the original node is listed as an external dependency, create a placeholder - if n.name in external_deps: - ph = new_graph.placeholder(f"external__{n.name}") - env[n] = ph - return ph - - # Is an fx Node but not added to env yet + # External dependencies ph = new_graph.placeholder(f"external__{n.name}") env[n] = ph return ph @@ -181,8 +175,8 @@ def fetch_arg(n): new_node = new_graph.node_copy(node, fetch_arg) env[node] = new_node last_new_node = new_node - if node.name in external_deps: - out_dict[node.name] = new_node + if node.name in external_deps_names: + out_dict[f"external__{node.name}"] = new_node # Explicitly add output node(s) if last_new_node is None: @@ -201,7 +195,7 @@ def fetch_arg(n): def get_suffix_submodule(gm, start_node, external_nodes): """ Extracts the suffix subgraph of an FX `GraphModule` beginning at - `start_node` and continuing through the model’s final output. + `start_node` and continuing through the model's final output. Parameters ---------- @@ -218,13 +212,13 @@ def get_suffix_submodule(gm, start_node, external_nodes): ------- suffix_gm : torch.fx.GraphModule A new GraphModule containing the suffix portion of the graph. - The output is the model’s final node as replicated in the suffix. + The output is the model's final node as replicated in the suffix. Notes ----- - Any argument not already present in the slice is turned into a placeholder named ``external__``. - - The slice runs from `start_node` to the original graph’s terminal + - The slice runs from `start_node` to the original graph's terminal output node. """ assert isinstance(external_nodes, dict), "External nodes must be a dictionary" @@ -232,7 +226,8 @@ def get_suffix_submodule(gm, start_node, external_nodes): new_graph = fx.Graph() env = {} last_node = None - external_deps = set(chain.from_iterable(external_nodes.values())) + external_deps = set(external_nodes.values()) + external_deps_names = set(node.name for node in external_deps) def fetch_arg(n): # For non-Node literals @@ -244,12 +239,6 @@ def fetch_arg(n): return env[n] # Create a placeholder for an external dependency - if n.name in external_deps: - ph = new_graph.placeholder(f"external__{n.name}") - env[n] = ph - return ph - - # Is an fx Node but not added to env yet ph = new_graph.placeholder(f"external__{n.name}") env[n] = ph return ph @@ -273,7 +262,7 @@ def fetch_arg(n): return suffix_gm -def get_all_subnets(prune_modules_name, graph_module): +def get_all_subnets(prune_modules_name, gm, external_nodes): """ Constructs the prefix, middle, and suffix subgraphs around the modules specified for pruning, producing both pruned and dense (unmodified) @@ -284,8 +273,10 @@ def get_all_subnets(prune_modules_name, graph_module): prune_modules_name : list[str] List of module names in `graph_module` that will be pruned, in the order they appear in the model. - graph_module : fx.GraphModule + gm : fx.GraphModule The full FX-traced model. + external_nodes : dict + Mapping of : node in present subraph -> upstream dependencies Returns ------- @@ -303,14 +294,12 @@ def get_all_subnets(prune_modules_name, graph_module): are independent and do not share graph state. """ assert isinstance( - graph_module, fx.GraphModule + gm, fx.GraphModule ), "graph_module must be an instance of fx.GraphModule" assert len(prune_modules_name) > 0, "Prune list must not be empty" prune_subnets = [] dense_subnets = [] - gm = graph_module - external_nodes = get_external_nodes(gm) first_node_name = "_".join(prune_modules_name[0].split(".")) prefix_subnet_dense = get_initial_prefix_submodule( gm=gm, end_node=first_node_name, external_nodes=external_nodes @@ -352,23 +341,10 @@ def get_all_subnets(prune_modules_name, graph_module): return prune_subnets, dense_subnets -if __name__ == "__main__": - model = resnet50(pretrained=True) - input = torch.randn(1, 3, 224, 224) - weights = model.conv1.weight - prune_conv_modules, prune_modules_name = collect_convolution_layers_to_prune( - model=model - ) - reformatted = [] - for name in prune_modules_name: - reformatted_name = "_".join(name.split(".")) - reformatted.append(reformatted_name) +def display_subnet_info(subnet: fx.GraphModule): + for node in subnet.graph.nodes: + print("Node_name: ", node.name) + print("Node_op:", node.op) + print("Node_args: ", node.args) + print("Node_kwargs: ", node.kwargs) - end = reformatted[0] - gm = fx.symbolic_trace(model) - print(reformatted[3], reformatted[10]) - - prune_subnets, dense_subnets = get_all_subnets( - graph_module=gm, prune_modules_name=prune_modules_name - ) - assert len(prune_subnets) == len(dense_subnets) From 9a1f8f73f40b787aaf669b00d43ba49bfa27f226 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:20:13 +0530 Subject: [PATCH 2/7] Add subnet execution logic --- compress/osscar_utils.py | 192 +++++++++++++++++++++++++++++++++++---- 1 file changed, 174 insertions(+), 18 deletions(-) diff --git a/compress/osscar_utils.py b/compress/osscar_utils.py index 587ba59..7eb0b47 100644 --- a/compress/osscar_utils.py +++ b/compress/osscar_utils.py @@ -1,5 +1,6 @@ import random from typing import cast, Optional +from itertools import chain import numpy as np import torch @@ -760,7 +761,7 @@ def get_parent_module(model, target_module): (nn.Module, str): Parent module and the submodule's attribute name. Raises: - ValueError: If the target module isn’t found. + ValueError: If the target module isn't found. """ for _, module in model.named_modules(): for child_name, child in module.named_children(): @@ -783,6 +784,7 @@ def replace_module(model, target_module, new_module): parent, name = get_parent_module(model, target_module) setattr(parent, name, new_module) + def perform_local_search( dense_weights, layer, @@ -843,6 +845,8 @@ def perform_local_search( total_channels = kept_channels.copy() keep_mask = torch.ones(layer.in_channels, dtype=torch.bool) + print(f"Pruning layer : {layer._get_name}...") + # Iterative greedy pruning, where t=p and hence s1 = 0 for i in range(len(prune_list)): num_prune_iter = prune_list[i] @@ -946,6 +950,8 @@ def prune_one_layer( N = num_batches * batch_size # Get conv module to prune + # Debugging note : Is the first named_module guaranteed to be a nn.Conv2d + # even after the changes involving the placeholders? conv_module = next( m for _, m in pruned_subnet.named_modules() if isinstance(m, nn.Conv2d) ) @@ -961,12 +967,7 @@ def prune_one_layer( total_xtx = get_coeff_h(pruned_X) / N total_xty = get_XtY(pruned_X, dense_X) / N - # # Optimal weights - # w_optimal = get_optimal_W( - # gram_xx=total_xtx, gram_xy=total_xty, dense_weights=reshaped_conv_wt - # ) - - keep_mask, kept_channels, removed_channels = perform_local_search( + keep_mask, _, _ = perform_local_search( dense_weights=reshaped_conv_wt, layer=conv_module, p=layer_prune_channels, @@ -974,16 +975,6 @@ def prune_one_layer( gram_xy=total_xty, ) - # cached_out_pruned = [] - # cached_out_dense = [] - - # for batch_idx in range(num_batches): - # cached_out_pruned.append(pruned_subnet(pruned_input[batch_idx])) - # cached_out_dense.append(dense_subnet(dense_input[batch_idx])) - - # cached_out_pruned = torch.cat(cached_out_pruned, dim=0) - # cached_out_dense = torch.cat(cached_out_dense, dim=0) - new_weight = conv_module.weight[:, keep_mask, :, :] if conv_module.bias is not None: new_bias = conv_module.bias @@ -1018,6 +1009,8 @@ def prune_one_layer( if new_bias is not None: new_conv_module.bias.data = new_bias.clone() + # Debugging note : Does replace_module still work aftet the inclusion of the placeholder + # + residual connection logic? replace_module( model=pruned_subnet, target_module=conv_module, new_module=new_conv_module ) @@ -1025,6 +1018,167 @@ def prune_one_layer( return pruned_subnet, keep_mask +class DenseSubnetRunner(nn.Module): + def __init__(self, subnet_gm: torch.fx.GraphModule, external_nodes: dict): + super().__init__() + self.gm = subnet_gm + self.external_nodes = external_nodes + self.upstream_deps = external_nodes.values() + self.upstream_node_names = [node.name for node in self.upstream_deps] + + def get_prefix_removed_name(self, prefix_name, prefix="external__"): + return prefix_name[len(prefix) :] + + def forward(self, x, ctx: dict, batch_idx: int = 0): + """ + Execute subnet with `x` and `ctx`, returning output + updated `ctx`. + """ + inputs = {} + for node in self.gm.graph.nodes: + if node.op == "placeholder": + node_name = node.name + if self.get_prefix_removed_name(node_name) in self.upstream_node_names: + ctx_tensor = ctx[node_name] + inputs[node_name] = ctx_tensor[batch_idx] + else: + inputs[node_name] = x + + # for k, t in inputs.items(): + # if isinstance(t, torch.Tensor): + # print(f"Debse Runner input {k}: ndim={t.ndim}, shape={t.shape}") + # else: + # print(f"Dense Runner input {k}: type={type(t)} (BAD)") + + # print("Input keys : ", inputs.keys()) + out = self.gm(**inputs) + + if isinstance(out, dict): + for k, v in out.items(): + if k != "output": + if k not in ctx: + ctx[k] = [] + ctx[k].append(v) + return out["output"], ctx + + return out, ctx + + +class SubnetRunner(nn.Module): + def __init__( + self, + subnet_gm: torch.fx.GraphModule, + external_nodes: dict, + pruned_conv_target, + ): + """ + Runs a sliced FX subgraph and automatically supplies required inputs. + + Parameters + ---------- + subnet_gm : torch.fx.GraphModule + The sliced subgraph whose placeholders include: + - One real input placeholder (receives `x`) + - External placeholders (filled from `ctx`) + external_nodes : dict + Mapping: sliced_placeholder_name → original_node_name. + Used to fetch upstream activations from `ctx` during execution. + + Notes + ----- + - Placeholders not in `external_nodes.values()` get `x`. + - External placeholders get `ctx[original_name]`. + - Dict outputs store all keys except `"output"` into `ctx`. + """ + super().__init__() + self.gm = subnet_gm + self.external_nodes = external_nodes + self.upstream_deps = external_nodes.values() + self.downstream_consumers = external_nodes.keys() + print("External nodes: ", self.external_nodes) + self.upstream_node_names = [node.name for node in self.upstream_deps] + self.pruned_conv_target = pruned_conv_target + print(f"Prune_target is {self.pruned_conv_target}") + self.placeholder_mask_map = self.build_placeholder_mask_map() + + def build_placeholder_mask_map(self): + """ + Returns: + dict[str, bool] : placeholder_name -> should_mask + """ + + graph = self.gm.graph + + # Locate which node(conv2d layer) was pruned in this subnet + pruned_node = None + for node in graph.nodes: + if node.op == "call_module" and node.name == self.pruned_conv_target: + pruned_node = node + break + + if pruned_node is None: + raise RuntimeError( + f"Pruned conv {self.pruned_conv_target} not found in subnet graph" + ) + + # Backward BFS - to locate all ancestors of the pruned_conv_target node + visited = set() + stack = [pruned_node] + + while stack: + n = stack.pop() + if n in visited: + continue + visited.add(n) + + for arg in n.all_input_nodes: + stack.append(arg) + + placeholder_mask_map = {} + + # Only masks + for node in graph.nodes: + if node.op == "placeholder": + placeholder_mask_map[node.name] = node in visited + + return placeholder_mask_map + + def get_prefix_removed_name(self, prefix_name, prefix="external__"): + return prefix_name[len(prefix) :] + + def forward(self, x, ctx: dict, batch_idx: int = 0, keep_mask: list = []): + inputs = {} + + for node in self.gm.graph.nodes: + if node.op != "placeholder": + continue + + name = node.name + + # Real inputs + if self.get_prefix_removed_name(name) not in self.upstream_node_names: + inputs[name] = x + continue + + # This placeholder comes from ctx + tensor = ctx[name][batch_idx] + + # Mask if BFS says so + if self.placeholder_mask_map.get(name, False): + tensor = tensor[:, keep_mask, :, :] + + inputs[name] = tensor + + out = self.gm(**inputs) + + if isinstance(out, dict): + for k, v in out.items(): + if k != "output": + ctx.setdefault(k, []).append(v) + return out["output"], ctx + + return out, ctx + + def is_real_consumer(node): # Ignores shape-only ops if node.op == "call_method" and node.target in [ @@ -1043,6 +1197,8 @@ def get_external_nodes(gm: torch.fx.GraphModule): for node in gm.graph.nodes: real_users = [u for u in node.users if is_real_consumer(u)] if len(real_users) > 1: - external_nodes[node.name] = real_users + for user in real_users: + if "downsample" in str(user.target) or "add" in str(user.target): + external_nodes[user.name] = node return external_nodes From e9c1472957da31c516cde6aed8dba875bf1438a3 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:20:52 +0530 Subject: [PATCH 3/7] Update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index a37252d..ace8474 100644 --- a/.gitignore +++ b/.gitignore @@ -210,4 +210,5 @@ __marimo__/ /data/test /profiler/trace.json /logs/profiler -/.vscode \ No newline at end of file +/.vscode +/.VSCodeCounter \ No newline at end of file From a99b040c3809db3d2114f2382ea1e839bd0a0814 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:21:54 +0530 Subject: [PATCH 4/7] Fix node dependency bug --- compress/osscar.py | 203 ++++++++++++++++++++++++++++----------------- 1 file changed, 128 insertions(+), 75 deletions(-) diff --git a/compress/osscar.py b/compress/osscar.py index 5ec5ca6..d5161a4 100644 --- a/compress/osscar.py +++ b/compress/osscar.py @@ -2,46 +2,66 @@ import torch import torch.nn as nn -import numpy as np import torch.fx as fx from compress.graph_extractor import get_all_subnets from compress.heuristics import collect_convolution_layers_to_prune from compress.osscar_utils import ( get_count_prune_channels, + SubnetRunner, + DenseSubnetRunner, is_real_consumer, get_external_nodes, - prune_one_layer + prune_one_layer, ) from data.load_data import build_calibration_dataloader, build_eval_dataset from model.resnet import resnet50 from utils.utils import set_global_seed, safe_free, load_yaml -def run_forward_with_mask(subnet, input, input_mask=None, is_input_loader=False): +def run_forward_with_mask( + subnet, + input, + input_mask=None, + is_input_loader=False, + external_nodes=None, + ctx=None, + prune_channel_name=None, +): """ - Runs forward passes through a subnet, optionally masking input channels. - - Supports either a preloaded tensor of shape (num_batches, batch_size, C, H, W) - or a DataLoader. The outputs from all batches are stacked along a new leading - dimension for forming gram matrices. + Run a subnet forward pass over batched inputs (tensor or DataLoader), optionally + masking input channels. Uses SubnetRunner so external dependencies are pulled + from `ctx`, which is updated with new intermediate outputs. Args: - subnet (nn.Module): The model or subnet to evaluate. - input (torch.Tensor | DataLoader): Batched tensor input or a DataLoader. - input_mask (torch.BoolTensor, optional): Channel-wise boolean mask; only - channels with `True` are retained. - is_input_loader (bool, default=False): If True, treats `input` as a DataLoader. + subnet (fx.GraphModule): Subnet to execute. + input (Tensor | DataLoader): 5D tensor (num_batches, B, C, H, W) or loader. + input_mask (BoolTensor, optional): Channel mask. + is_input_loader (bool): Whether `input` is a DataLoader. + external_nodes (dict): External dependency mapping. + ctx (dict): Context shared across subnets. Returns: - cached_input (torch.Tensor): Stacked outputs for all batches, of shape - (num_batches, batch_size, C_out, H_out, W_out). + cached_input (Tensor): Stacked per-batch outputs. + ctx (dict): Updated context. """ + assert external_nodes is not None, "external_nodes must exist even as an empty dict" + assert ctx is not None, "Context dictionary cannot be None" calibration_batches = [] + if input_mask is None: + subnet_runner = DenseSubnetRunner( + subnet_gm=subnet, external_nodes=external_nodes + ) + else: + subnet_runner = SubnetRunner( + subnet_gm=subnet, + external_nodes=external_nodes, + pruned_conv_target=prune_channel_name, + ) if is_input_loader: assert isinstance(input, torch.utils.data.DataLoader) for images, _ in input: - outs = subnet(images) + outs, ctx = subnet_runner(images, ctx) calibration_batches.append(outs) else: assert isinstance(input, torch.Tensor) @@ -49,60 +69,73 @@ def run_forward_with_mask(subnet, input, input_mask=None, is_input_loader=False) num_batches = input.shape[0] if input_mask is not None: - input_tensor = input[:, :, input_mask, :, :] + masked_input = input[:, :, input_mask, :, :] + keep_mask = input_mask + for i in range(num_batches): + current_batch = masked_input[i, :, :, :, :] + output_tensor, ctx = subnet_runner( + x=current_batch, ctx=ctx, keep_mask=keep_mask, batch_idx=i + ) + calibration_batches.append(output_tensor) else: - input_tensor = input - - # Note to self: could include all elements in 1 batch effectively making the input_tensor 4D - # The for loop can then be avoided but might make the solution too memory-intensive - for i in range(num_batches): - current_batch = input_tensor[i, :, :, :, :] - output_tensor = subnet(current_batch) - calibration_batches.append(output_tensor) + masked_input = input + for i in range(num_batches): + current_batch = masked_input[i, :, :, :, :] + output_tensor, ctx = subnet_runner( + x=current_batch, ctx=ctx, batch_idx=i + ) + calibration_batches.append(output_tensor) + cached_input = torch.stack(calibration_batches, dim=0) + for k, v in ctx.items(): + if isinstance(v, list): + ctx[k] = torch.stack(v, dim=0) - return cached_input + return cached_input, ctx + + +def transform_prune_name(name): + name = name.split(".") + return "_".join(name) def run_osscar(model, calibration_loader, args): """ - Apply OSSCAR-style structured pruning to a model using a calibration dataset. - - The function: - 1. Determines how many channels to prune per layer based on `args.prune_percentage`. - 2. Identifies convolution layers eligible for pruning. - 3. Symbolically traces the model to extract subnets for pruning. - 4. Iteratively prunes layers using local greedy search while caching intermediate activations. - 5. Returns a list of pruned subnets forming the pruned model and the per-layer keep masks. - - Parameters - ---------- - model : nn.Module - The original dense PyTorch model to prune. - calibration_loader : torch.utils.data.DataLoader - Data loader providing batches for calibration / activation caching. - args : Namespace - Arguments containing `prune_percentage` among other potential config values. - - Returns - ------- - pruned_model : list[nn.Module] - List of subnets forming the pruned model. - keep_masks : list[torch.BoolTensor] - Per-layer boolean masks indicating which input channels were kept. - """ + Run OSSCAR-style structured channel pruning on a model using a calibration set. + + The procedure: + 1. Compute how many channels to prune per Conv layer. + 2. Symbolically trace the model and split it into prefix/middle/suffix subnets. + 3. Run each subnet on calibration data to cache activations. + 4. Iteratively prune layers using a local greedy search. + 5. Recompute cached inputs after each pruning step. + Args: + model (nn.Module): Dense model to prune. + calibration_loader (DataLoader): Data for activation calibration. + args (Namespace): Must contain `prune_percentage`. + + Returns: + pruned_model (list[nn.Module]): Ordered list of pruned subnets. + keep_masks (list[BoolTensor]): Per-layer channel masks indicating which + input channels were kept after pruning. + """ prune_percentage = args.prune_percentage - channels_post_prune, prune_channels_by_layer, remaining_params = ( - get_count_prune_channels(model=model, prune_percentage=prune_percentage) + _, prune_channels_by_layer, _ = get_count_prune_channels( + model=model, prune_percentage=prune_percentage ) + # if "conv1" in prune_channels_by_layer: + # prune_channels_by_layer.remove("conv1") # print(prune_channels_by_layer) _, prune_modules_name = collect_convolution_layers_to_prune(model=model) - # print(prune_modules_name) + # if "conv1" in prune_modules_name: + # prune_modules_name.remove("conv1") + # print("Channels to prune : ", prune_modules_name) gm = fx.symbolic_trace(model) + external_nodes = get_external_nodes(gm) prune_subnets, dense_subnets = get_all_subnets( - graph_module=gm, prune_modules_name=prune_modules_name + gm=gm, prune_modules_name=prune_modules_name, external_nodes=external_nodes ) assert len(prune_subnets) == len(dense_subnets) @@ -110,16 +143,19 @@ def run_osscar(model, calibration_loader, args): pruned_model = [] prefix_subnet = prune_subnets[0] pruned_model.append(prefix_subnet) + ctx = {} + dense_ctx = {} - dense_cached_input = run_forward_with_mask( + dense_cached_input, dense_ctx = run_forward_with_mask( subnet=prefix_subnet, input=calibration_loader, input_mask=None, is_input_loader=True, - ).detach() - cached_input = dense_cached_input + external_nodes=external_nodes, + ctx=dense_ctx, + ) + cached_input = dense_cached_input.detach() - # Bug: This code shouldn't execute when layer_prune_channels = 0 for i in range(1, len(dense_subnets)): subnet_post_pruning, keep_mask = prune_one_layer( dense_subnet=dense_subnets[i], @@ -128,18 +164,43 @@ def run_osscar(model, calibration_loader, args): pruned_input=cached_input, layer_prune_channels=int(prune_channels_by_layer[i - 1]), ) - + prune_channel_name = prune_modules_name[i - 1] + prune_channel_name = transform_prune_name(prune_channel_name) + print(f"Prune_channel_name = {prune_channel_name}") + from compress.graph_extractor import display_subnet_info + + display_subnet_info(subnet_post_pruning) + print(f"Keep mask after iter {i} : {keep_mask.sum()}") + print(f"Shape of input to the dense forward : {dense_cached_input.shape}") pruned_model.append(subnet_post_pruning) keep_masks.append(keep_mask) - new_dense = run_forward_with_mask( - subnet=dense_subnets[i], input=dense_cached_input, input_mask=None - ).detach() - new_pruned = run_forward_with_mask( - subnet=subnet_post_pruning, input=cached_input, input_mask=keep_mask - ).detach() + new_dense, dense_ctx = run_forward_with_mask( + subnet=dense_subnets[i], + input=dense_cached_input, + input_mask=None, + external_nodes=external_nodes, + ctx=dense_ctx, + ) + print(f"Shape of dense_cached_input after iter {i} : {new_dense.shape}") + ctx_vals = [(k, v.shape) for k, v in ctx.items()] + print(f"Ctx vals shape after iter {i}: {ctx_vals}") + print(f"Shape of input to prune forward : {cached_input.shape}") + + new_pruned, ctx = run_forward_with_mask( + subnet=subnet_post_pruning, + input=cached_input, + input_mask=keep_mask, + external_nodes=external_nodes, + ctx=ctx, + prune_channel_name=prune_channel_name, + ) + print( + f"Shape of cached_input after prune forward {i} times : {new_pruned.shape}" + ) + print(f"Ctx after running prune forward {i} times : {ctx.keys()}") safe_free(dense_cached_input, cached_input) - dense_cached_input, cached_input = new_dense, new_pruned + dense_cached_input, cached_input = new_dense.detach(), new_pruned.detach() return pruned_model, keep_masks @@ -184,14 +245,6 @@ def forward(self, x): ) model = resnet50(pretrained=True) - # activation_in_channels = 3 - # numel_one_channel = 9 - # slice_indices = np.arange(activation_in_channels) - # slice_indices = [ - # (start * numel_one_channel, start * numel_one_channel + numel_one_channel) - # for start in slice_indices - # ] - # print(slice_indices) parser = argparse.ArgumentParser(description="Arguments for OSSCAR") parser.add_argument("--prune_percentage", default=0.25, type=float) args = parser.parse_args() From 20869e26ca6a115f16400a57872567bc57d4af7f Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:41:11 +0530 Subject: [PATCH 5/7] Add script to run OSSCAR --- scripts/run_osscar.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) create mode 100644 scripts/run_osscar.py diff --git a/scripts/run_osscar.py b/scripts/run_osscar.py new file mode 100644 index 0000000..e24aade --- /dev/null +++ b/scripts/run_osscar.py @@ -0,0 +1,33 @@ +import argparse + +import torch + +from data.load_data import build_calibration_dataloader, build_eval_dataset +from model.resnet import resnet50 +from utils.utils import set_global_seed, load_yaml +from compress.osscar import run_osscar + + +if __name__ == "__main__": + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + set_global_seed(seed=232) + g = torch.Generator() + g.manual_seed(11) + + cfg_path = "config/config.yaml" + cfg = load_yaml(cfg_path) + assert cfg is not None, "Config cannot be empty or None" + calibration_dataset = build_eval_dataset(cfg=cfg["eval"]) + calibration_dataloader = build_calibration_dataloader( + dataset=calibration_dataset, num_samples=500, g=g, batch_size=32 + ) + + model = resnet50(pretrained=True) + parser = argparse.ArgumentParser(description="Arguments for OSSCAR") + parser.add_argument("--prune_percentage", default=0.25, type=float) + # parser.add_argument("--prune_layers", default=None, type=list, help="Provide names of layers to be pruned using OSSCAR. If None, the algorithm makes this decision by itself") + args = parser.parse_args() + pruned_model_list, keep_mask_list = run_osscar( + model=model, calibration_loader=calibration_dataloader, args=args + ) From 1f55d743b3294433e1b81a57d6c7694107f92ec2 Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:41:46 +0530 Subject: [PATCH 6/7] Add initial logic to physically prune modules --- compress/prune.py | 25 +++++++++++++++++++++++++ 1 file changed, 25 insertions(+) create mode 100644 compress/prune.py diff --git a/compress/prune.py b/compress/prune.py new file mode 100644 index 0000000..e04f31d --- /dev/null +++ b/compress/prune.py @@ -0,0 +1,25 @@ +import torch.nn as nn + + +class PrunedResnet50(nn.Module): + def __init__(self, pruned_modules_list, keep_mask_list): + super().__init__() + assert len(pruned_modules_list) - 1 == len(keep_mask_list) + self.module_list = nn.ModuleList(pruned_modules_list) + + # Register masks as buffers so that they can be saved/loaded along with the model + for i, mask in enumerate(keep_mask_list): + self.register_buffer(f"input_mask_{i}", mask) + + self.keep_mask_list = keep_mask_list + + def forward(self, x): + # The first module has no input mask + x = self.module_list[0](x) + + # The mask of the ith module in module_list is the i-1th entry of keep_mask_list + for i in range(1, len(self.module_list)): + x = x[:, self.keep_mask_list[i - 1], :, :] + x = self.module_list[i](x) + + return x From 21c175b6cb5468afd882a0e16e82f41a8f8633ff Mon Sep 17 00:00:00 2001 From: Soulsharp <138900246+soulsharp@users.noreply.github.com> Date: Sat, 6 Dec 2025 23:42:09 +0530 Subject: [PATCH 7/7] Lint related changes --- compress/graph_extractor.py | 1 - compress/osscar.py | 25 ------------------------- compress/osscar_utils.py | 19 +------------------ 3 files changed, 1 insertion(+), 44 deletions(-) diff --git a/compress/graph_extractor.py b/compress/graph_extractor.py index a8b8392..b52d385 100644 --- a/compress/graph_extractor.py +++ b/compress/graph_extractor.py @@ -347,4 +347,3 @@ def display_subnet_info(subnet: fx.GraphModule): print("Node_op:", node.op) print("Node_args: ", node.args) print("Node_kwargs: ", node.kwargs) - diff --git a/compress/osscar.py b/compress/osscar.py index d5161a4..eff4664 100644 --- a/compress/osscar.py +++ b/compress/osscar.py @@ -205,30 +205,6 @@ def run_osscar(model, calibration_loader, args): return pruned_model, keep_masks -class PrunedResnet50(nn.Module): - def __init__(self, pruned_modules_list, keep_mask_list): - super().__init__() - assert len(pruned_modules_list) - 1 == len(keep_mask_list) - self.module_list = nn.ModuleList(pruned_modules_list) - - # Register masks as buffers so that they can be saved/loaded along with the model - for i, mask in enumerate(keep_mask_list): - self.register_buffer(f"input_mask_{i}", mask) - - self.keep_mask_list = keep_mask_list - - def forward(self, x): - # The first module has no input mask - x = self.module_list[0](x) - - # The mask of the ith module in module_list is the i-1th entry of keep_mask_list - for i in range(1, len(self.module_list)): - x = x[:, self.keep_mask_list[i - 1], :, :] - x = self.module_list[i](x) - - return x - - if __name__ == "__main__": torch.backends.cudnn.deterministic = True torch.backends.cudnn.benchmark = False @@ -251,4 +227,3 @@ def forward(self, x): pruned_model_list, keep_mask_list = run_osscar( model=model, calibration_loader=calibration_dataloader, args=args ) - pruned_model = PrunedResnet50(pruned_model_list, keep_mask_list) diff --git a/compress/osscar_utils.py b/compress/osscar_utils.py index 7eb0b47..e6c541a 100644 --- a/compress/osscar_utils.py +++ b/compress/osscar_utils.py @@ -879,13 +879,7 @@ def perform_local_search( activation_in_channels=layer.in_channels, is_pure_gram=False, ) - # sub_w_optimal = recompute_W( - # prune_mask=temp_keep_mask, - # W=w_optimal, - # activation_in_channels=layer.in_channels, - # kernel_height=layer.kernel_size[0], - # kernel_width=layer.kernel_size[1], - # ) + submatrix_w_pruned = recompute_W( prune_mask=temp_keep_mask, W=dense_weights, @@ -948,10 +942,6 @@ def prune_one_layer( num_batches, batch_size, C, H, W = dense_input.shape N = num_batches * batch_size - - # Get conv module to prune - # Debugging note : Is the first named_module guaranteed to be a nn.Conv2d - # even after the changes involving the placeholders? conv_module = next( m for _, m in pruned_subnet.named_modules() if isinstance(m, nn.Conv2d) ) @@ -1043,13 +1033,6 @@ def forward(self, x, ctx: dict, batch_idx: int = 0): else: inputs[node_name] = x - # for k, t in inputs.items(): - # if isinstance(t, torch.Tensor): - # print(f"Debse Runner input {k}: ndim={t.ndim}, shape={t.shape}") - # else: - # print(f"Dense Runner input {k}: type={type(t)} (BAD)") - - # print("Input keys : ", inputs.keys()) out = self.gm(**inputs) if isinstance(out, dict):