Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
358 changes: 358 additions & 0 deletions restoration_training.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,358 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "c775af9f",
"metadata": {},
"source": [
"## Check GPU Availability"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3da24503",
"metadata": {},
"outputs": [],
"source": [
"!nvidia_smi"
]
},
{
"cell_type": "markdown",
"id": "6f14a538",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c8bc525",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import wandb\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments\n",
"from datasets import load_dataset\n",
"\n",
"from adapter import DistillationTrainer"
]
},
{
"cell_type": "markdown",
"id": "1ca4233c",
"metadata": {},
"source": [
"## Wandb setting"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5f56aa06",
"metadata": {},
"outputs": [],
"source": [
"PROJECT_NAME = \"restoration_training\"\n",
"RUN_NAME = \"Qwen3_8B_adapter_restoration\"\n",
"\n",
"# WandB Initialization\n",
"wandb.init(project=PROJECT_NAME, name=RUN_NAME)"
]
},
{
"cell_type": "markdown",
"id": "d08a6c97",
"metadata": {},
"source": [
"## Define Dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "625d6659",
"metadata": {},
"outputs": [],
"source": [
"dataset_path = \"retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl\""
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "1b335067",
"metadata": {},
"outputs": [],
"source": [
"dataset = load_dataset(\"json\", data_file=dataset_path)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "39876999",
"metadata": {},
"outputs": [],
"source": [
"dataset['train'][0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "959bfd2a",
"metadata": {},
"outputs": [],
"source": [
"dataset['train'][0].keys()"
]
},
{
"cell_type": "markdown",
"id": "31ab1e69",
"metadata": {},
"source": [
"## Load Model"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "2a1e4445",
"metadata": {},
"outputs": [],
"source": [
"torch_dtype = torch.bfloat16\n",
"\n",
"base_model_id = \"Qwen/Qwen3-8B\"\n",
"base_model = AutoModelForCausalLM.from_pretrained(\n",
" base_model_id,\n",
" torch_dtype=torch_dtype,\n",
")\n",
"\n",
"adapt_model_id = \"pretrained/Qwen/Qwen3-8B_adapter\" # 예시 경로\n",
"adapt_model = AutoModelForCausalLM.from_pretrained(\n",
" adapt_model_id,\n",
" torch_dtype=torch_dtype,\n",
")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(base_model_id)\n",
"if tokenizer.pad_token is None:\n",
" tokenizer.pad_token = tokenizer.eos_token"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "214f91eb",
"metadata": {},
"outputs": [],
"source": [
"def tokenize_function(examples):\n",
" formatted_text = [\n",
" f\"질문: {inp}\\n\\n답변: {out}{tokenizer.eos_token}\"\n",
" for inp, out in zip(examples[\"input\"], examples[\"target_output\"])\n",
" ]\n",
" tokenized = tokenizer(formatted_text, truncation=True, padding=\"max_length\", max_length=512)\n",
" tokenized[\"labels\"] = tokenized[\"input_ids\"].copy()\n",
" return tokenized\n",
"\n",
"original_columns = dataset['train'].column_names\n",
"tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=original_columns)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "65a921ac",
"metadata": {},
"outputs": [],
"source": [
"split_dataset = tokenized_datasets[\"train\"].train_test_split(test_size=0.1, seed=42)\n",
"\n",
"train_dataset = split_dataset[\"train\"]\n",
"eval_dataset = split_dataset[\"test\"]"
]
},
{
"cell_type": "markdown",
"id": "79c9fd59",
"metadata": {},
"source": [
"## Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "7c7f4cef",
"metadata": {},
"outputs": [],
"source": [
"training_args = TrainingArguments(\n",
" output_dir=\"./qwen3-restored-final\",\n",
" num_train_epochs=10,\n",
" per_device_train_batch_size=8,\n",
" per_device_eval_batch_size=8,\n",
" learning_rate=5e-5,\n",
" bf16=True,\n",
" evaluation_strategy=\"epoch\",\n",
" save_strategy=\"epoch\", # 매 에포크 종료 시 모델 체크포인트 저장\n",
" logging_steps=10, # 10 스텝마다 로그 출력\n",
" load_best_model_at_end=True,\n",
" metric_for_best_model=\"eval_loss\",\n",
" greater_is_better=False,\n",
" report_to=\"wandb\",\n",
")\n",
"\n",
"trainer = DistillationTrainer(\n",
" model=adapt_model,\n",
" teacher_model=base_model,\n",
" args=training_args,\n",
" train_dataset=train_dataset,\n",
" eval_dataset=eval_dataset,\n",
" alpha=1.0,\n",
" temperature=0.3,\n",
")\n",
"\n",
"trainer.train()\n",
"\n",
"wandb.finish()"
]
},
{
"cell_type": "markdown",
"id": "f7f8133d",
"metadata": {},
"source": [
"## Testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "314782cb",
"metadata": {},
"outputs": [],
"source": [
"# 파일명: testing.py\n",
"\n",
"import torch\n",
"from transformers import AutoTokenizer, AutoModelForCausalLM\n",
"from tqdm import tqdm\n",
"import json\n",
"\n",
"TEACHER_MODEL_ID = \"Qwen/Qwen3-8B\"\n",
"STUDENT_MODEL_PATH = \"./qwen3-restored-final\"\n",
"TEST_DATA_PATH = \"retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl\"\n",
"DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
"NUM_TEST_SAMPLES = 50\n",
"MAX_NEW_TOKENS = 128\n",
"\n",
"print(\"INFO: 모델 및 토크나이저 로딩 중...\")\n",
"tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_ID)\n",
"teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)\n",
"student_model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL_PATH, torch_dtype=torch.bfloat16).to(DEVICE)\n",
"\n",
"teacher_model.eval()\n",
"student_model.eval()\n",
"print(\"INFO: 로딩 완료.\")\n",
"\n",
"test_prompts = []\n",
"try:\n",
" with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:\n",
" for i, line in enumerate(f):\n",
" if i >= NUM_TEST_SAMPLES:\n",
" break\n",
" data = json.loads(line)\n",
" prompt = f\"질문: {data['input']}\\n\\n답변:\"\n",
" test_prompts.append(prompt)\n",
"except FileNotFoundError:\n",
" print(f\"ERROR: 테스트 데이터 파일을 찾을 수 없습니다. 경로를 확인하세요: {TEST_DATA_PATH}\")\n",
" exit()\n",
"\n",
"total_accuracy = 0\n",
"total_mismatched = 0\n",
"total_generated = 0\n",
"\n",
"print(f\"\\nINFO: {len(test_prompts)}개의 샘플에 대한 테스트 시작...\")\n",
"for sample_idx, prompt in enumerate(tqdm(test_prompts, desc=\"Testing Progress\")):\n",
" inputs = tokenizer(prompt, return_tensors=\"pt\").to(DEVICE)\n",
"\n",
" with torch.no_grad():\n",
" teacher_output_ids = teacher_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)\n",
" student_output_ids = student_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)\n",
"\n",
" teacher_tokens = teacher_output_ids[0][inputs.input_ids.shape[1]:]\n",
" student_tokens = student_output_ids[0][inputs.input_ids.shape[1]:]\n",
" \n",
" total_tokens_in_sample = len(teacher_tokens)\n",
" if total_tokens_in_sample == 0:\n",
" continue # Teacher가 아무것도 생성하지 않은 경우 해당 샘플은 건너뜀\n",
"\n",
" mismatched_tokens_in_sample = 0\n",
" mismatch_details = []\n",
"\n",
" len_to_compare = min(len(teacher_tokens), len(student_tokens))\n",
" for i in range(len_to_compare):\n",
" if teacher_tokens[i] != student_tokens[i]:\n",
" mismatched_tokens_in_sample += 1\n",
" mismatch_details.append(\n",
" f\" - Position #{i+1}: \"\n",
" f\"Teacher='{tokenizer.decode(teacher_tokens[i])}' (ID: {teacher_tokens[i]}) vs \"\n",
" f\"Student='{tokenizer.decode(student_tokens[i])}' (ID: {student_tokens[i]})\"\n",
" )\n",
" \n",
" len_diff = abs(len(teacher_tokens) - len(student_tokens))\n",
" if len_diff > 0:\n",
" mismatched_tokens_in_sample += len_diff\n",
" stop_reason = \"Teacher가 먼저 중단\" if len(teacher_tokens) < len(student_tokens) else \"Student가 먼저 중단\"\n",
" mismatch_details.append(\n",
" f\" - 생성 길이 불일치: Teacher={len(teacher_tokens)}개, Student={len(student_tokens)}개 ({stop_reason})\"\n",
" )\n",
"\n",
" if mismatch_details:\n",
" print(f\"\\n\\n--- [sample #{sample_idx+1}] doesn't match---\")\n",
" print(f\"Prompt: \\\"{prompt[:80]}...\\\"\")\n",
" for detail in mismatch_details:\n",
" print(detail)\n",
" print(f\"Teacher output: \\\"...{tokenizer.decode(teacher_tokens, skip_special_tokens=True)}\\\"\")\n",
" print(f\"Student output: \\\"...{tokenizer.decode(student_tokens, skip_special_tokens=True)}\\\"\")\n",
" print(\"-\" * 25)\n",
"\n",
" accuracy = ((total_tokens_in_sample - mismatched_tokens_in_sample) / total_tokens_in_sample) * 100\n",
" total_accuracy += accuracy\n",
" total_mismatched += mismatched_tokens_in_sample\n",
" total_generated += total_tokens_in_sample\n",
"\n",
"if len(test_prompts) > 0:\n",
" avg_accuracy = total_accuracy / len(test_prompts)\n",
" print(\"\\n\\n\" + \"=\"*30)\n",
" print(\"--- 최종 평가 결과 요약 ---\")\n",
" print(f\"평균 토큰 일치율 (정확도): {avg_accuracy:.2f}%\")\n",
" print(f\"총 생성 토큰 (Teacher 기준): {total_generated}\")\n",
" print(f\"총 불일치 토큰: {total_mismatched}\")\n",
" print(\"=\"*30)\n",
"else:\n",
" print(\"테스트할 데이터를 찾지 못해 평가를 진행할 수 없습니다.\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading