From 60775756eafb48a93d7b19171418ea4f31087345 Mon Sep 17 00:00:00 2001 From: doyoung kim Date: Fri, 26 Jun 2026 12:28:02 +0900 Subject: [PATCH] Update model instantiation --- submissions/client_encode_encrypt_input.py | 7 +++---- submissions/client_preprocess_input.py | 3 +-- submissions/server_preprocess_model.py | 4 ++-- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/submissions/client_encode_encrypt_input.py b/submissions/client_encode_encrypt_input.py index 4898b05..ddd7580 100644 --- a/submissions/client_encode_encrypt_input.py +++ b/submissions/client_encode_encrypt_input.py @@ -4,12 +4,11 @@ import numpy as np import torch from desilofhe import Engine -from transformers import BertForNextSentencePrediction +from transformers import BertForSequenceClassification from params import InstanceParams -# For encoding, base model is used. -EMBED_MODEL_ID = "google-bert/bert-base-cased" +EMBED_MODEL_ID = "google-bert/bert-base-cased-finetuned-mrpc" EMBED_LEVEL = 9 @@ -82,7 +81,7 @@ def main(): ) secret_key = engine.read_secret_key(io_dir / "secret_key") - embedding_model = BertForNextSentencePrediction.from_pretrained(EMBED_MODEL_ID).bert.embeddings + embedding_model = BertForSequenceClassification.from_pretrained(EMBED_MODEL_ID).bert.embeddings embedding_model.eval() records = [] diff --git a/submissions/client_preprocess_input.py b/submissions/client_preprocess_input.py index b80d5bc..b802182 100644 --- a/submissions/client_preprocess_input.py +++ b/submissions/client_preprocess_input.py @@ -4,8 +4,7 @@ from params import InstanceParams from transformers import AutoTokenizer -# For encoding, base model is used. -MODEL_ID = "google-bert/bert-base-cased" +MODEL_ID = "google-bert/bert-base-cased-finetuned-mrpc" MAX_LENGTH = 128 diff --git a/submissions/server_preprocess_model.py b/submissions/server_preprocess_model.py index 15d90c7..5e3ff50 100644 --- a/submissions/server_preprocess_model.py +++ b/submissions/server_preprocess_model.py @@ -4,7 +4,7 @@ from pathlib import Path from desilofhe import Engine -from transformers import BertForNextSentencePrediction +from transformers import BertForSequenceClassification from params import InstanceParams from encode_weights import ( @@ -77,7 +77,7 @@ def main(): warm_cache(lp_path) return - model = BertForNextSentencePrediction.from_pretrained(MODEL_ID) + model = BertForSequenceClassification.from_pretrained(MODEL_ID, output_hidden_states=True) model.eval() weights = {k: v.detach().cpu().numpy() for k, v in model.state_dict().items()}