diff --git a/megatron/model/llama_model.py b/megatron/model/llama_model.py index dafc1ce7eb2..1e88e8305ab 100644 --- a/megatron/model/llama_model.py +++ b/megatron/model/llama_model.py @@ -876,7 +876,8 @@ def _to_float16(inputs): loss_fn=CrossEntropy, topology=topo, activation_checkpoint_interval=interval, - partition_method='type:transformer') + partition_method='type:transformer', + checkpointable_layers=['LlamaParallelTransformerLayerPipe']) class LlamaModel(MegatronModule):