From 9626d36761fcad687b0220f99dc11c0f8ae35d6c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=8C=B6=E5=8F=B6?= Date: Sat, 31 Jan 2026 18:25:01 +0800 Subject: [PATCH] fix: add @torch.no_grad() to prevent memory leak on MPS/CPU - Add @torch.no_grad() decorator to generate() methods in utils.py - Remove torch.autocast(device_type='cuda') wrapper that causes warnings on non-CUDA devices - Fixes memory leak issue on Apple Silicon (MPS) where memory grows to 90GB+ Tested on M4 Max with 64GB RAM, memory usage now stable at 2-3GB. Closes #46 --- gradio/inference.py | 9 ++++----- gradio/utils.py | 4 +++- inference/utils.py | 4 +++- 3 files changed, 10 insertions(+), 7 deletions(-) diff --git a/gradio/inference.py b/gradio/inference.py index 2bef0fa..caf3785 100644 --- a/gradio/inference.py +++ b/gradio/inference.py @@ -221,11 +221,10 @@ def inference_patch(period, composer, instrumentation): with torch.inference_mode(): while True: - with torch.autocast(device_type='cuda', dtype=torch.float16): - predicted_patch = model.generate(input_patches.unsqueeze(0), - top_k=TOP_K, - top_p=TOP_P, - temperature=TEMPERATURE) + predicted_patch = model.generate(input_patches.unsqueeze(0), + top_k=TOP_K, + top_p=TOP_P, + temperature=TEMPERATURE) if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'): # 初次进入tunebody,必须以[r:0/开头 tunebody_flag = True r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device) diff --git a/gradio/utils.py b/gradio/utils.py index 3f34b08..d113cb4 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -305,6 +305,7 @@ def forward(self, return output + @torch.no_grad() def generate(self, encoded_patch: torch.Tensor, # [hidden_size] tokens: torch.Tensor): # [1] @@ -379,7 +380,8 @@ def forward(self, patches = patches[masks == 1] return self.char_level_decoder(encoded_patches, patches) - + + @torch.no_grad() def generate(self, patches: torch.Tensor, top_k=0, diff --git a/inference/utils.py b/inference/utils.py index d28a180..19a11b6 100644 --- a/inference/utils.py +++ b/inference/utils.py @@ -305,6 +305,7 @@ def forward(self, return output + @torch.no_grad() def generate(self, encoded_patch: torch.Tensor, # [hidden_size] tokens: torch.Tensor): # [1] @@ -379,7 +380,8 @@ def forward(self, patches = patches[masks == 1] return self.char_level_decoder(encoded_patches, patches) - + + @torch.no_grad() def generate(self, patches: torch.Tensor, top_k=0,