Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -210,4 +210,5 @@ __marimo__/
/data/test
/profiler/trace.json
/logs/profiler
/.vscode
/.vscode
/.VSCodeCounter
83 changes: 29 additions & 54 deletions compress/graph_extractor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Optional
from itertools import chain
import copy

import numpy as np
import torch
Expand All @@ -9,7 +8,7 @@

from compress.heuristics import collect_convolution_layers_to_prune
from compress.osscar_utils import get_external_nodes

from model.resnet import resnet50


Expand All @@ -18,6 +17,7 @@ def clone_subnet(gm):
env = {}

for node in gm.graph.nodes:

def safe_lookup(old_node):
if old_node in env:
return env[old_node]
Expand Down Expand Up @@ -67,7 +67,8 @@ def get_initial_prefix_submodule(gm, end_node, external_nodes):
assert isinstance(gm, fx.GraphModule)
prefix_graph = fx.Graph()
env = {}
external_deps = set(chain.from_iterable(external_nodes.values()))
external_deps = set(external_nodes.values())
external_deps_names = set(node.name for node in external_deps)
out_dict = {}
last_node = None

Expand All @@ -80,11 +81,9 @@ def fetch_arg(n):
if n in env:
return env[n]

# If the original node is listed as an external dependency, create a placeholder
if n.name in external_deps:
ph = prefix_graph.placeholder(f"external__{n.name}")
env[n] = ph
return ph
ph = prefix_graph.placeholder(f"external__{n.name}")
env[n] = ph
return ph

for node in gm.graph.nodes:
if node.name == end_node:
Expand All @@ -93,7 +92,7 @@ def fetch_arg(n):
new_node = prefix_graph.node_copy(node, fetch_arg)
env[node] = new_node
last_node = new_node
if node.name in external_deps:
if node.name in external_deps_names:
out_dict[node.name] = new_node

# Explicitly adds output node
Expand Down Expand Up @@ -144,7 +143,8 @@ def get_fx_submodule(gm, start_node, end_node, external_nodes):
new_graph = fx.Graph()
env = {}
last_new_node = None
external_deps = set(chain.from_iterable(external_nodes.values()))
external_deps = set(external_nodes.values())
external_deps_names = set(node.name for node in external_deps)
out_dict = {}

def fetch_arg(n):
Expand All @@ -156,13 +156,7 @@ def fetch_arg(n):
if n in env:
return env[n]

# If the original node is listed as an external dependency, create a placeholder
if n.name in external_deps:
ph = new_graph.placeholder(f"external__{n.name}")
env[n] = ph
return ph

# Is an fx Node but not added to env yet
# External dependencies
ph = new_graph.placeholder(f"external__{n.name}")
env[n] = ph
return ph
Expand All @@ -181,8 +175,8 @@ def fetch_arg(n):
new_node = new_graph.node_copy(node, fetch_arg)
env[node] = new_node
last_new_node = new_node
if node.name in external_deps:
out_dict[node.name] = new_node
if node.name in external_deps_names:
out_dict[f"external__{node.name}"] = new_node

# Explicitly add output node(s)
if last_new_node is None:
Expand All @@ -201,7 +195,7 @@ def fetch_arg(n):
def get_suffix_submodule(gm, start_node, external_nodes):
"""
Extracts the suffix subgraph of an FX `GraphModule` beginning at
`start_node` and continuing through the models final output.
`start_node` and continuing through the model's final output.

Parameters
----------
Expand All @@ -218,21 +212,22 @@ def get_suffix_submodule(gm, start_node, external_nodes):
-------
suffix_gm : torch.fx.GraphModule
A new GraphModule containing the suffix portion of the graph.
The output is the models final node as replicated in the suffix.
The output is the model's final node as replicated in the suffix.

Notes
-----
- Any argument not already present in the slice is turned into a
placeholder named ``external__<node_name>``.
- The slice runs from `start_node` to the original graphs terminal
- The slice runs from `start_node` to the original graph's terminal
output node.
"""
assert isinstance(external_nodes, dict), "External nodes must be a dictionary"
graph = gm.graph
new_graph = fx.Graph()
env = {}
last_node = None
external_deps = set(chain.from_iterable(external_nodes.values()))
external_deps = set(external_nodes.values())
external_deps_names = set(node.name for node in external_deps)

def fetch_arg(n):
# For non-Node literals
Expand All @@ -244,12 +239,6 @@ def fetch_arg(n):
return env[n]

# Create a placeholder for an external dependency
if n.name in external_deps:
ph = new_graph.placeholder(f"external__{n.name}")
env[n] = ph
return ph

# Is an fx Node but not added to env yet
ph = new_graph.placeholder(f"external__{n.name}")
env[n] = ph
return ph
Expand All @@ -273,7 +262,7 @@ def fetch_arg(n):
return suffix_gm


def get_all_subnets(prune_modules_name, graph_module):
def get_all_subnets(prune_modules_name, gm, external_nodes):
"""
Constructs the prefix, middle, and suffix subgraphs around the modules
specified for pruning, producing both pruned and dense (unmodified)
Expand All @@ -284,8 +273,10 @@ def get_all_subnets(prune_modules_name, graph_module):
prune_modules_name : list[str]
List of module names in `graph_module` that will be pruned, in the
order they appear in the model.
graph_module : fx.GraphModule
gm : fx.GraphModule
The full FX-traced model.
external_nodes : dict
Mapping of : node in present subraph -> upstream dependencies

Returns
-------
Expand All @@ -303,14 +294,12 @@ def get_all_subnets(prune_modules_name, graph_module):
are independent and do not share graph state.
"""
assert isinstance(
graph_module, fx.GraphModule
gm, fx.GraphModule
), "graph_module must be an instance of fx.GraphModule"
assert len(prune_modules_name) > 0, "Prune list must not be empty"
prune_subnets = []
dense_subnets = []

gm = graph_module
external_nodes = get_external_nodes(gm)
first_node_name = "_".join(prune_modules_name[0].split("."))
prefix_subnet_dense = get_initial_prefix_submodule(
gm=gm, end_node=first_node_name, external_nodes=external_nodes
Expand Down Expand Up @@ -352,23 +341,9 @@ def get_all_subnets(prune_modules_name, graph_module):
return prune_subnets, dense_subnets


if __name__ == "__main__":
model = resnet50(pretrained=True)
input = torch.randn(1, 3, 224, 224)
weights = model.conv1.weight
prune_conv_modules, prune_modules_name = collect_convolution_layers_to_prune(
model=model
)
reformatted = []
for name in prune_modules_name:
reformatted_name = "_".join(name.split("."))
reformatted.append(reformatted_name)

end = reformatted[0]
gm = fx.symbolic_trace(model)
print(reformatted[3], reformatted[10])

prune_subnets, dense_subnets = get_all_subnets(
graph_module=gm, prune_modules_name=prune_modules_name
)
assert len(prune_subnets) == len(dense_subnets)
def display_subnet_info(subnet: fx.GraphModule):
for node in subnet.graph.nodes:
print("Node_name: ", node.name)
print("Node_op:", node.op)
print("Node_args: ", node.args)
print("Node_kwargs: ", node.kwargs)
Loading