Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 42 additions & 116 deletions main_benchmark.py
Original file line number Diff line number Diff line change
@@ -1,125 +1,51 @@
#
# For licensing see accompanying LICENSE file.
# Copyright (C) 2023 Apple Inc. All Rights Reserved.
#

# main_benchmark.py
"""
Benchmark model inference speed (throughput and latency).
Usage:
python main_benchmark.py --ckpt checkpoints/best...pth --img-size 224 --batch-size 64 --iters 200
"""
import time
from typing import Optional

import torch
from torch.cuda.amp import autocast

from cvnets import get_model
from engine.utils import autocast_fn
from options.opts import get_benchmarking_arguments
from utils import logger
from utils.common_utils import device_setup
from utils.pytorch_to_coreml import convert_pytorch_to_coreml
from utils.tensor_utils import create_rand_tensor


def cpu_timestamp(*args, **kwargs):
# perf_counter returns time in seconds
return time.perf_counter()


def cuda_timestamp(cuda_sync=False, device=None, *args, **kwargs):
if cuda_sync:
torch.cuda.synchronize(device=device)
# perf_counter returns time in seconds
return time.perf_counter()


def step(
time_fn,
model,
example_inputs,
autocast_enable: False,
amp_precision: Optional[str] = "float16",
):
start_time = time_fn()
with autocast_fn(enabled=autocast_enable, amp_precision=amp_precision):
model(example_inputs)
end_time = time_fn(cuda_sync=True)
return end_time - start_time


def main_benchmark():
# set-up
opts = get_benchmarking_arguments()
# device set-up
opts = device_setup(opts)

norm_layer = getattr(opts, "model.normalization.name", "batch_norm")
if norm_layer.find("sync") > -1:
norm_layer = norm_layer.replace("sync_", "")
setattr(opts, "model.normalization.name", norm_layer)
device = getattr(opts, "dev.device", torch.device("cpu"))
if torch.cuda.device_count() == 0:
device = torch.device("cpu")
time_fn = cpu_timestamp if device == torch.device("cpu") else cuda_timestamp
warmup_iterations = getattr(opts, "benchmark.warmup_iter", 10)
iterations = getattr(opts, "benchmark.n_iter", 50)
batch_size = getattr(opts, "benchmark.batch_size", 1)
mixed_precision = (
False
if device == torch.device("cpu")
else getattr(opts, "common.mixed_precision", False)
)
mixed_precision_dtype = getattr(opts, "common.mixed_precision_dtype", "float16")

# load the model
model = get_model(opts)
model.eval()
# print model information
model.info()

example_inp = create_rand_tensor(opts=opts, device="cpu", batch_size=batch_size)

# cool down for 5 seconds
time.sleep(5)

if getattr(opts, "benchmark.use_jit_model", False):
converted_models_dict = convert_pytorch_to_coreml(
opts=None,
pytorch_model=model,
input_tensor=example_inp,
jit_model_only=True,
)
model = converted_models_dict["jit"]
model = model.to(device=device)
example_inp = example_inp.to(device=device)
import argparse
from main_train import EnhancedMobileViT

def main(args):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = EnhancedMobileViT(num_classes=2, img_size=args.img_size, pretrained=False, cbam=args.cbam, fusion=args.fusion)
model = model.to(device)
if args.ckpt:
state = torch.load(args.ckpt, map_location=device)
model.load_state_dict(state.get("model_state", state))
model.eval()

# create synthetic input
x = torch.randn(args.batch_size, 3, args.img_size, args.img_size).to(device)
# Warmup
with torch.no_grad():
# warm-up
for i in range(warmup_iterations):
step(
time_fn=time_fn,
model=model,
example_inputs=example_inp,
autocast_enable=mixed_precision,
amp_precision=mixed_precision_dtype,
)

n_steps = n_samples = 0.0
for _ in range(10):
_ = model(x)

# run benchmark
for i in range(iterations):
step_time = step(
time_fn=time_fn,
model=model,
example_inputs=example_inp,
autocast_enable=mixed_precision,
amp_precision=mixed_precision_dtype,
)
n_steps += step_time
n_samples += batch_size

logger.info(
"Number of samples processed per second: {:.3f}".format(n_samples / n_steps)
)
# Timed runs
times = []
with torch.no_grad():
for _ in range(args.iters):
t0 = time.time()
_ = model(x)
t1 = time.time()
times.append(t1 - t0)

avg_time = sum(times) / len(times)
per_image = avg_time / args.batch_size
throughput = args.batch_size / avg_time
print(f"Batch size: {args.batch_size} | Avg batch time: {avg_time:.6f}s | Throughput: {throughput:.2f} img/s | Per-image latency: {per_image*1000:.3f} ms")

if __name__ == "__main__":
main_benchmark()
parser = argparse.ArgumentParser()
parser.add_argument("--ckpt", type=str, default=None)
parser.add_argument("--img-size", type=int, default=224)
parser.add_argument("--batch-size", type=int, default=64)
parser.add_argument("--iters", type=int, default=100)
parser.add_argument("--cbam", action="store_true")
parser.add_argument("--fusion", action="store_true")
args = parser.parse_args()
main(args)