
convert_exported_program_to_serialized_trt_engine fails with AssertionError on TorchExportableModuleWithStaticCache exports #4162

@Mgluhovskoi


Bug Description

convert_exported_program_to_serialized_trt_engine fails with an AssertionError when given an ExportedProgram produced from TorchExportableModuleWithStaticCache (Hugging Face transformers). The error is raised inside ExportedProgram.run_decompositions(), which the converter calls internally.

The assertion that fails is in torch._functorch._aot_autograd.graph_compile.aot_stage2_export:

```text
assert isinstance(compiled_fn, torch.fx.GraphModule)
AssertionError
```

TorchExportableModuleWithStaticCache wraps a causal LM and registers the StaticCache tensors as module state (buffers that are updated in place). torch.export.export(..., strict=False) produces the ExportedProgram without error, but run_decompositions() then fails while re-exporting the graph through AOT Autograd.

This is not the same as #3226 (torchao import order), which was fixed. This bug occurs without torchao.

To Reproduce

```python
import torch
import torch_tensorrt
from transformers import AutoConfig, AutoModelForCausalLM
from transformers.integrations.executorch import TorchExportableModuleWithStaticCache

config = AutoConfig.from_pretrained("gpt2")
config.n_layer = 1
model = AutoModelForCausalLM.from_config(config).eval().half()
model.generation_config.cache_implementation = "static"
model.generation_config.use_cache = True

wrapper = TorchExportableModuleWithStaticCache(model, batch_size=1, max_cache_len=16)

input_ids = torch.tensor([[42]], dtype=torch.long)
cache_position = torch.tensor([0], dtype=torch.long)

exported = torch.export.export(
    wrapper, (), kwargs={"input_ids": input_ids, "cache_position": cache_position}, strict=False
)

# This fails:
engine = torch_tensorrt.dynamo.convert_exported_program_to_serialized_trt_engine(
    exported,
    inputs=[
        torch_tensorrt.Input(shape=input_ids.shape, dtype=input_ids.dtype),
        torch_tensorrt.Input(shape=cache_position.shape, dtype=cache_position.dtype),
    ],
    use_explicit_typing=True,
    min_block_size=1,
)
```

Error

```text
File ".../torch/export/exported_program.py", line 1484, in run_decompositions
    return _decompose_exported_program(
File ".../torch/export/exported_program.py", line 967, in _decompose_exported_program
    ) = _decompose_and_get_gm_with_new_signature_constants(
File ".../torch/export/exported_program.py", line 476, in _decompose_and_get_gm_with_new_signature_constants
    aten_export_artifact = _export_to_aten_ir(
File ".../torch/export/_trace.py", line 985, in _export_to_aten_ir
    gm, graph_signature = transform(_aot_export_joint_with_descriptors)(
File ".../torch/export/_trace.py", line 924, in _aot_export_joint_with_descriptors
    gm, fw_metadata = aot_stage2_export(
File ".../torch/_functorch/_aot_autograd/graph_compile.py", line 288, in aot_stage2_export
    assert isinstance(compiled_fn, torch.fx.GraphModule)
AssertionError
```

Workaround

Instead of TorchExportableModuleWithStaticCache, use a fully stateless wrapper that takes the KV-cache tensors as explicit inputs and returns the updated caches as outputs, so the module performs no in-place buffer mutations. With this wrapper, run_decompositions() no longer hits the failing assertion, and the program exports and converts to a TensorRT engine successfully.

Environment

  • torch: 2.10.0+cu128
  • torch_tensorrt: 2.10.0+cu130
  • transformers: 5.2.0
  • CUDA: 12.8
  • GPU: NVIDIA GeForce RTX 4090
  • OS: Ubuntu 22.04 (Docker container)
