From 4f0ed9f982cd9ca7f99012815b69b57be6f1249e Mon Sep 17 00:00:00 2001 From: lyakaap Date: Tue, 20 Nov 2018 20:00:58 +0900 Subject: [PATCH 1/2] Fix a issue (#6) --- torchstat/compute_memory.py | 58 ++++++++++++++++++++----------------- 1 file changed, 31 insertions(+), 27 deletions(-) diff --git a/torchstat/compute_memory.py b/torchstat/compute_memory.py index dc9dac6..53382bc 100644 --- a/torchstat/compute_memory.py +++ b/torchstat/compute_memory.py @@ -18,7 +18,7 @@ def compute_memory(module, inp, out): return compute_Pool2d_memory(module, inp, out) else: print(f"[Memory]: {type(module).__name__} is not supported!") - return (0, 0) + return 0, 0 pass @@ -28,20 +28,21 @@ def num_params(module): def compute_ReLU_memory(module, inp, out): assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)) - batch_size = inp.size()[0] - mread = batch_size * inp.size()[1:].numel() - mwrite = batch_size * inp.size()[1:].numel() - - return (mread, mwrite) + + mread = inp.numel() + mwrite = out.numel() + + return mread, mwrite def compute_PReLU_memory(module, inp, out): - assert isinstance(module, (nn.PReLU)) + assert isinstance(module, nn.PReLU) + batch_size = inp.size()[0] - mread = batch_size * (inp.size()[1:].numel() + num_params(module)) - mwrite = batch_size * inp.size()[1:].numel() + mread = batch_size * (inp[0].numel() + num_params(module)) + mwrite = out.numel() - return (mread, mwrite) + return mread, mwrite def compute_Conv2d_memory(module, inp, out): @@ -50,39 +51,42 @@ def compute_Conv2d_memory(module, inp, out): assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) batch_size = inp.size()[0] - in_c = inp.size()[1] - out_c, out_h, out_w = out.size()[1:] - # This includes weighs with bias if the module contains it. - mread = batch_size * (inp.size()[1:].numel() + num_params(module)) - mwrite = batch_size * out_c * out_h * out_w - return (mread, mwrite) + # This includes weights with bias if the module contains it. + mread = batch_size * (inp[0].numel() + num_params(module)) + mwrite = out.numel() + return mread, mwrite def compute_BatchNorm2d_memory(module, inp, out): assert isinstance(module, nn.BatchNorm2d) assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) + batch_size, in_c, in_h, in_w = inp.size() - - mread = batch_size * (inp.size()[1:].numel() + 2 * in_c) - mwrite = inp.size().numel() - return (mread, mwrite) + mread = batch_size * (inp[0].numel() + 2 * in_c) + mwrite = out.numel() + + return mread, mwrite def compute_Linear_memory(module, inp, out): assert isinstance(module, nn.Linear) assert len(inp.size()) == 2 and len(out.size()) == 2 + batch_size = inp.size()[0] - mread = batch_size * (inp.size()[1:].numel() + num_params(module)) - mwrite = out.size().numel() - return (mread, mwrite) + # This includes weights with bias if the module contains it. + mread = batch_size * (inp[0].numel() + num_params(module)) + mwrite = out.numel() + + return mread, mwrite def compute_Pool2d_memory(module, inp, out): assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)) assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) - batch_size = inp.size()[0] - mread = batch_size * inp.size()[1:].numel() - mwrite = batch_size * out.size()[1:].numel() - return (mread, mwrite) + + mread = inp.numel() + mwrite = out.numel() + + return mread, mwrite From 37895205d6c574986e857fe31729693433bf8fd7 Mon Sep 17 00:00:00 2001 From: lyakaap Date: Tue, 20 Nov 2018 20:08:24 +0900 Subject: [PATCH 2/2] Follow pycodestyle --- torchstat/compute_memory.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/torchstat/compute_memory.py b/torchstat/compute_memory.py index 53382bc..fc55d9d 100644 --- a/torchstat/compute_memory.py +++ b/torchstat/compute_memory.py @@ -28,16 +28,16 @@ def num_params(module): def compute_ReLU_memory(module, inp, out): assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)) - + mread = inp.numel() mwrite = out.numel() - + return mread, mwrite def compute_PReLU_memory(module, inp, out): assert isinstance(module, nn.PReLU) - + batch_size = inp.size()[0] mread = batch_size * (inp[0].numel() + num_params(module)) mwrite = out.numel() @@ -61,18 +61,18 @@ def compute_Conv2d_memory(module, inp, out): def compute_BatchNorm2d_memory(module, inp, out): assert isinstance(module, nn.BatchNorm2d) assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) - + batch_size, in_c, in_h, in_w = inp.size() mread = batch_size * (inp[0].numel() + 2 * in_c) mwrite = out.numel() - + return mread, mwrite def compute_Linear_memory(module, inp, out): assert isinstance(module, nn.Linear) assert len(inp.size()) == 2 and len(out.size()) == 2 - + batch_size = inp.size()[0] # This includes weights with bias if the module contains it. @@ -85,8 +85,8 @@ def compute_Linear_memory(module, inp, out): def compute_Pool2d_memory(module, inp, out): assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)) assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) - + mread = inp.numel() mwrite = out.numel() - + return mread, mwrite