diff --git a/models/whisper/export.py b/models/whisper/export.py index 2ab3153..352d915 100644 --- a/models/whisper/export.py +++ b/models/whisper/export.py @@ -107,8 +107,17 @@ def create_whisper( example_inputs = reference_inputs(model_name, dtype) + example_inputs["decoder_input_ids"] = torch.tensor( + [[50258, 50259, 50360, 50364]], dtype=torch.int32 + ) + dynamic_shapes = { + "input_features": {}, + "decoder_input_ids": {1: torch.export.Dim("dec_seq_len", min=1, max=448)}, + } with torch.autocast(device_type="cpu", dtype=dtype): - exported = torch.export.export(model, args=(), kwargs=example_inputs) + exported = torch.export.export( + model, args=(), kwargs=example_inputs, dynamic_shapes=dynamic_shapes + ) exported = exported.run_decompositions(get_decomp_table()) print("[INFO] Model exported. Converting to Core AI...")