Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1042,11 +1042,19 @@ def standardize_model_attributes(

elif library_name == "sentence_transformers":
if "Transformer" in model[0].__class__.__name__:
model.config = model[0].auto_model.config
model.config.export_model_type = "transformer"
inner_config = model[0].auto_model.config
try:
model.config = inner_config
except AttributeError:
pass
inner_config.export_model_type = "transformer"
elif "CLIP" in model[0].__class__.__name__:
model.config = model[0].model.config
model.config.export_model_type = "clip"
inner_config = model[0].model.config
try:
model.config = inner_config
except AttributeError:
pass
inner_config.export_model_type = "clip"
else:
raise ValueError(
f"The export of a sentence_transformers model with the first module being {model[0].__class__.__name__} is currently not supported in Optimum. Please open an issue or submit a PR to add the support."
Expand Down
29 changes: 29 additions & 0 deletions tests/exporters/common/test_tasks_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,3 +184,32 @@ def test_library_detection(self):
)
self.assertEqual(TasksManager.infer_library_from_model("gpt2"), "transformers")
self.assertEqual(TasksManager.infer_library_from_model("timm/mobilenetv3_large_100.ra_in1k"), "timm")

def test_standardize_sentence_transformers_readonly_config(self):
# Regression: sentence-transformers >= 5 exposes `config` as a read-only
# property, so a direct assignment used to raise AttributeError before
# `export_model_type` could be set on the underlying transformer config.
class FakeTransformer:
def __init__(self, config):
self.auto_model = type("AutoModel", (), {"config": config})()

class FakeSentenceTransformer:
def __init__(self, inner):
self._modules_list = [inner]

def __getitem__(self, idx):
return self._modules_list[idx]

@property
def config(self):
return self._modules_list[0].auto_model.config

inner_config = BertConfig()
transformer = FakeTransformer(inner_config)
transformer.__class__.__name__ = "Transformer"
st_model = FakeSentenceTransformer(transformer)

TasksManager.standardize_model_attributes(st_model, library_name="sentence_transformers")

self.assertEqual(inner_config.export_model_type, "transformer")
self.assertIs(st_model.config, inner_config)