diff --git a/gradio/inference.py b/gradio/inference.py index 2bef0fa..caf3785 100644 --- a/gradio/inference.py +++ b/gradio/inference.py @@ -221,11 +221,10 @@ def inference_patch(period, composer, instrumentation): with torch.inference_mode(): while True: - with torch.autocast(device_type='cuda', dtype=torch.float16): - predicted_patch = model.generate(input_patches.unsqueeze(0), - top_k=TOP_K, - top_p=TOP_P, - temperature=TEMPERATURE) + predicted_patch = model.generate(input_patches.unsqueeze(0), + top_k=TOP_K, + top_p=TOP_P, + temperature=TEMPERATURE) if not tunebody_flag and patchilizer.decode([predicted_patch]).startswith('[r:'): # 初次进入tunebody,必须以[r:0/开头 tunebody_flag = True r0_patch = torch.tensor([ord(c) for c in '[r:0/']).unsqueeze(0).to(device) diff --git a/gradio/utils.py b/gradio/utils.py index 3f34b08..d113cb4 100644 --- a/gradio/utils.py +++ b/gradio/utils.py @@ -305,6 +305,7 @@ def forward(self, return output + @torch.no_grad() def generate(self, encoded_patch: torch.Tensor, # [hidden_size] tokens: torch.Tensor): # [1] @@ -379,7 +380,8 @@ def forward(self, patches = patches[masks == 1] return self.char_level_decoder(encoded_patches, patches) - + + @torch.no_grad() def generate(self, patches: torch.Tensor, top_k=0, diff --git a/inference/utils.py b/inference/utils.py index d28a180..19a11b6 100644 --- a/inference/utils.py +++ b/inference/utils.py @@ -305,6 +305,7 @@ def forward(self, return output + @torch.no_grad() def generate(self, encoded_patch: torch.Tensor, # [hidden_size] tokens: torch.Tensor): # [1] @@ -379,7 +380,8 @@ def forward(self, patches = patches[masks == 1] return self.char_level_decoder(encoded_patches, patches) - + + @torch.no_grad() def generate(self, patches: torch.Tensor, top_k=0,