diff --git a/openprompt/data_utils/utils.py b/openprompt/data_utils/utils.py index 0751105..bb13430 100644 --- a/openprompt/data_utils/utils.py +++ b/openprompt/data_utils/utils.py @@ -178,11 +178,12 @@ def to_tensor(self, device: str = 'cuda'): def to(self, device: str = "cuda:0"): r"""move the tensor keys to runtime device, such as gpu:0 """ - for key in self.tensorable_keys: - value = getattr(self, key) + target = copy.deepcopy(self) + for key in target.tensorable_keys: + value = getattr(target, key) if value is not None: - setattr(self, key, value.to(device)) - return self + setattr(target, key, value.to(device)) + return target def cuda(self, device: str = "cuda:0"): r"""mimic the tensor behavior