From e893a805c3c2df3c0927ac51f5c7128e9d68dd59 Mon Sep 17 00:00:00 2001 From: zhangfanTJU Date: Wed, 5 Apr 2023 17:40:33 +0800 Subject: [PATCH] fix-loader-fp16 --- libai/models/utils/model_loader/base_loader.py | 9 ++++++--- projects/GLM/infer_glm.py | 2 +- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/libai/models/utils/model_loader/base_loader.py b/libai/models/utils/model_loader/base_loader.py index 543739298..a39b60b76 100644 --- a/libai/models/utils/model_loader/base_loader.py +++ b/libai/models/utils/model_loader/base_loader.py @@ -23,7 +23,7 @@ from termcolor import colored import libai.utils.distributed as dist -from libai.config import LazyCall +from libai.config import LazyCall, try_get_key from libai.models.build import build_model logger = logging.getLogger(__name__) @@ -389,8 +389,7 @@ def _convert_tensor(self, tensor): Returns: flow.Tensor: The target tensor. """ - tensor = tensor.float() - return flow.Tensor(tensor.detach().cpu().numpy()) + return flow.tensor(tensor.detach().cpu().numpy()) def _convert_tensors(self, torch_state_dict): @@ -578,6 +577,10 @@ def load(self): else: self.model = build_model(LazyCall(self.model)(cfg=self.libai_cfg)) + # Convert to fp16 + if try_get_key(self.libai_cfg, "amp_enabled"): + self.model.half() + # State_dict to global logger.info("transfering state_dict local to global...") flow_state_dict = self._state_dict_to_global(flow_state_dict, mode="pytorch") diff --git a/projects/GLM/infer_glm.py b/projects/GLM/infer_glm.py index 02dc65b0b..a0fb54613 100644 --- a/projects/GLM/infer_glm.py +++ b/projects/GLM/infer_glm.py @@ -43,7 +43,7 @@ output_dropout_prob=0, ) model = loader.load() -model = model.half().cuda() +model = model.cuda() model.eval() dist.set_device_type("cuda")