diff --git a/torchstat/compute_memory.py b/torchstat/compute_memory.py index dc9dac6..fc55d9d 100644 --- a/torchstat/compute_memory.py +++ b/torchstat/compute_memory.py @@ -18,7 +18,7 @@ def compute_memory(module, inp, out): return compute_Pool2d_memory(module, inp, out) else: print(f"[Memory]: {type(module).__name__} is not supported!") - return (0, 0) + return 0, 0 pass @@ -28,20 +28,21 @@ def num_params(module): def compute_ReLU_memory(module, inp, out): assert isinstance(module, (nn.ReLU, nn.ReLU6, nn.ELU, nn.LeakyReLU)) - batch_size = inp.size()[0] - mread = batch_size * inp.size()[1:].numel() - mwrite = batch_size * inp.size()[1:].numel() - return (mread, mwrite) + mread = inp.numel() + mwrite = out.numel() + + return mread, mwrite def compute_PReLU_memory(module, inp, out): - assert isinstance(module, (nn.PReLU)) + assert isinstance(module, nn.PReLU) + batch_size = inp.size()[0] - mread = batch_size * (inp.size()[1:].numel() + num_params(module)) - mwrite = batch_size * inp.size()[1:].numel() + mread = batch_size * (inp[0].numel() + num_params(module)) + mwrite = out.numel() - return (mread, mwrite) + return mread, mwrite def compute_Conv2d_memory(module, inp, out): @@ -50,39 +51,42 @@ def compute_Conv2d_memory(module, inp, out): assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) batch_size = inp.size()[0] - in_c = inp.size()[1] - out_c, out_h, out_w = out.size()[1:] - # This includes weighs with bias if the module contains it. - mread = batch_size * (inp.size()[1:].numel() + num_params(module)) - mwrite = batch_size * out_c * out_h * out_w - return (mread, mwrite) + # This includes weights with bias if the module contains it. + mread = batch_size * (inp[0].numel() + num_params(module)) + mwrite = out.numel() + return mread, mwrite def compute_BatchNorm2d_memory(module, inp, out): assert isinstance(module, nn.BatchNorm2d) assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) + batch_size, in_c, in_h, in_w = inp.size() + mread = batch_size * (inp[0].numel() + 2 * in_c) + mwrite = out.numel() - mread = batch_size * (inp.size()[1:].numel() + 2 * in_c) - mwrite = inp.size().numel() - return (mread, mwrite) + return mread, mwrite def compute_Linear_memory(module, inp, out): assert isinstance(module, nn.Linear) assert len(inp.size()) == 2 and len(out.size()) == 2 + batch_size = inp.size()[0] - mread = batch_size * (inp.size()[1:].numel() + num_params(module)) - mwrite = out.size().numel() - return (mread, mwrite) + # This includes weights with bias if the module contains it. + mread = batch_size * (inp[0].numel() + num_params(module)) + mwrite = out.numel() + + return mread, mwrite def compute_Pool2d_memory(module, inp, out): assert isinstance(module, (nn.MaxPool2d, nn.AvgPool2d)) assert len(inp.size()) == 4 and len(inp.size()) == len(out.size()) - batch_size = inp.size()[0] - mread = batch_size * inp.size()[1:].numel() - mwrite = batch_size * out.size()[1:].numel() - return (mread, mwrite) + + mread = inp.numel() + mwrite = out.numel() + + return mread, mwrite