From 80b91a14b183550f10d15ad1f9376c42bfb07f26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E9=A3=98=E8=8D=A1=E7=9D=80=E5=91=A2?= Date: Thu, 20 Jul 2023 12:34:27 +0800 Subject: [PATCH] Update llama_model.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 添加LlamaParallelTransformerLayerPipe到checkpointable_layers --- megatron/model/llama_model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/megatron/model/llama_model.py b/megatron/model/llama_model.py index dafc1ce7eb2..1e88e8305ab 100644 --- a/megatron/model/llama_model.py +++ b/megatron/model/llama_model.py @@ -876,7 +876,8 @@ def _to_float16(inputs): loss_fn=CrossEntropy, topology=topo, activation_checkpoint_interval=interval, - partition_method='type:transformer') + partition_method='type:transformer', + checkpointable_layers=['LlamaParallelTransformerLayerPipe']) class LlamaModel(MegatronModule):