Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 0 additions & 90 deletions tests/torch/fx/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

import torch.fx
import torch.nn.parallel
import torch.utils.data
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from fastdownload import FastDownload
from torch.fx.passes.graph_drawer import FxGraphDrawer

from nncf.experimental.torch.fx.transformations import apply_quantization_transformations
Comment on lines 13 to 16
Expand All @@ -30,90 +24,6 @@
from tests.cross_fw.test_templates.models import NNCFGraphToTestSumAggregation


class TinyImagenetDatasetManager:
DATASET_URL = "http://cs231n.stanford.edu/tiny-imagenet-200.zip"
DATASET_PATH = "~/.cache/nncf/tests/datasets"

def __init__(self, image_size: int, batch_size: int) -> None:
self.image_size = image_size
self.batch_size = batch_size

@staticmethod
def download_dataset() -> Path:
downloader = FastDownload(base=TinyImagenetDatasetManager.DATASET_PATH, archive="downloaded", data="extracted")
return downloader.get(TinyImagenetDatasetManager.DATASET_URL)

@staticmethod
def prepare_tiny_imagenet_200(dataset_dir: Path):
# Format validation set the same way as train set is formatted.
val_data_dir = dataset_dir / "val"
val_images_dir = val_data_dir / "images"
if not val_images_dir.exists():
return

val_annotations_file = val_data_dir / "val_annotations.txt"
with open(val_annotations_file) as f:
val_annotation_data = map(lambda line: line.split("\t")[:2], f.readlines())
for image_filename, image_label in val_annotation_data:
from_image_filepath = val_images_dir / image_filename
to_image_dir = val_data_dir / image_label
if not to_image_dir.exists():
to_image_dir.mkdir()
to_image_filepath = to_image_dir / image_filename
from_image_filepath.rename(to_image_filepath)
val_annotations_file.unlink()
val_images_dir.rmdir()

def create_data_loaders(self):
dataset_path = TinyImagenetDatasetManager.download_dataset()

TinyImagenetDatasetManager.prepare_tiny_imagenet_200(dataset_path)
print(f"Successfully downloaded and prepared dataset at: {dataset_path}")

train_dir = dataset_path / "train"
val_dir = dataset_path / "val"

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

train_dataset = datasets.ImageFolder(
train_dir,
transforms.Compose(
[
transforms.Resize(self.image_size),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
]
),
)
val_dataset = datasets.ImageFolder(
val_dir,
transforms.Compose(
[
transforms.Resize(self.image_size),
transforms.ToTensor(),
normalize,
]
),
)

train_loader = torch.utils.data.DataLoader(
train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=0, pin_memory=True, sampler=None
)

val_loader = torch.utils.data.DataLoader(
val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=0, pin_memory=True
)

# Creating separate dataloader with batch size = 1
# as dataloaders with batches > 1 are not supported yet.
calibration_dataset = torch.utils.data.DataLoader(
val_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=True
)

return train_loader, val_loader, calibration_dataset


def visualize_fx_model(model: torch.fx.GraphModule, output_svg_path: str):
g = FxGraphDrawer(model, output_svg_path)
g.get_dot_graph().write_svg(output_svg_path)
Comment on lines 27 to 29
Expand Down
13 changes: 12 additions & 1 deletion tests/torch/fx/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from tests.cross_fw.test_templates.helpers import YOLO26AttentionBlock
from tests.torch import test_models
from tests.torch.fx.helpers import get_torch_fx_model
from tests.torch.fx.test_sanity import count_q_dq
from tests.torch.test_models.synthetic import ConcatSameTensorModel
from tests.torch.test_models.synthetic import ConvReluBranchModel
from tests.torch.test_models.synthetic import EmbeddingSumModel
Expand Down Expand Up @@ -198,6 +197,18 @@ def test_model(test_case: ModelCase, regen_ref_data: bool):
)


def count_q_dq(model: torch.fx.GraphModule):
q, dq = 0, 0
for node in model.graph.nodes:
if node.op == "call_function" and hasattr(node.target, "overloadpacket"):
node_type = str(node.target.overloadpacket).split(".")[1]
if node_type in ["quantize_per_tensor", "quantize_per_channel"]:
q += 1
elif node_type in ["dequantize_per_tensor", "dequantize_per_channel"]:
dq += 1
return q, dq


@pytest.mark.parametrize("enable_dynamic_shapes", [True, False])
@pytest.mark.parametrize("compress_weights", [True, False])
@pytest.mark.parametrize(
Expand Down
139 changes: 0 additions & 139 deletions tests/torch/fx/test_sanity.py

This file was deleted.