diff --git a/restoration_training.ipynb b/restoration_training.ipynb
new file mode 100644
index 0000000..41e8978
--- /dev/null
+++ b/restoration_training.ipynb
@@ -0,0 +1,358 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "c775af9f",
+   "metadata": {},
+   "source": [
+    "## Check GPU Availability"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3da24503",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!nvidia-smi"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6f14a538",
+   "metadata": {},
+   "source": [
+    "## Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c8bc525",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import wandb\n",
+    "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments\n",
+    "from datasets import load_dataset\n",
+    "\n",
+    "from adapter import DistillationTrainer"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ca4233c",
+   "metadata": {},
+   "source": [
+    "## WandB Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5f56aa06",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PROJECT_NAME = \"restoration_training\"\n",
+    "RUN_NAME = \"Qwen3_8B_adapter_restoration\"\n",
+    "\n",
+    "# WandB Initialization\n",
+    "wandb.init(project=PROJECT_NAME, name=RUN_NAME)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d08a6c97",
+   "metadata": {},
+   "source": [
+    "## Define Dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "625d6659",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_path = \"retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b335067",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_dataset(\"json\", data_files=dataset_path)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "39876999",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset['train'][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "959bfd2a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset['train'][0].keys()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "31ab1e69",
+   "metadata": {},
+   "source": [
+    "## Load Model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "2a1e4445",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "torch_dtype = torch.bfloat16\n",
+    "\n",
+    "base_model_id = \"Qwen/Qwen3-8B\"\n",
+    "base_model = AutoModelForCausalLM.from_pretrained(\n",
+    "    base_model_id,\n",
+    "    torch_dtype=torch_dtype,\n",
+    ")\n",
+    "\n",
+    "adapt_model_id = \"pretrained/Qwen/Qwen3-8B_adapter\"  # example path\n",
+    "adapt_model = AutoModelForCausalLM.from_pretrained(\n",
+    "    adapt_model_id,\n",
+    "    torch_dtype=torch_dtype,\n",
+    ")\n",
+    "\n",
+    "tokenizer = AutoTokenizer.from_pretrained(base_model_id)\n",
+    "if tokenizer.pad_token is None:\n",
+    "    tokenizer.pad_token = tokenizer.eos_token"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "214f91eb",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def tokenize_function(examples):\n",
+    "    formatted_text = [\n",
+    "        f\"질문: {inp}\\n\\n답변: {out}{tokenizer.eos_token}\"\n",
+    "        for inp, out in zip(examples[\"input\"], examples[\"target_output\"])\n",
+    "    ]\n",
+    "    tokenized = tokenizer(formatted_text, truncation=True, padding=\"max_length\", max_length=512)\n",
+    "    tokenized[\"labels\"] = tokenized[\"input_ids\"].copy()\n",
+    "    return tokenized\n",
+    "\n",
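+    "# Note: tokenize_function copies input_ids directly into labels, so padding positions are not\n",
+    "# masked to -100; with alpha=1.0 in the DistillationTrainer below, only the KL term drives training.\n",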
"original_columns = dataset['train'].column_names\n", + "tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=original_columns)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65a921ac", + "metadata": {}, + "outputs": [], + "source": [ + "split_dataset = tokenized_datasets[\"train\"].train_test_split(test_size=0.1, seed=42)\n", + "\n", + "train_dataset = split_dataset[\"train\"]\n", + "eval_dataset = split_dataset[\"test\"]" + ] + }, + { + "cell_type": "markdown", + "id": "79c9fd59", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7f4cef", + "metadata": {}, + "outputs": [], + "source": [ + "training_args = TrainingArguments(\n", + " output_dir=\"./qwen3-restored-final\",\n", + " num_train_epochs=10,\n", + " per_device_train_batch_size=8,\n", + " per_device_eval_batch_size=8,\n", + " learning_rate=5e-5,\n", + " bf16=True,\n", + " evaluation_strategy=\"epoch\",\n", + " save_strategy=\"epoch\", # 매 에포크 종료 시 모델 체크포인트 저장\n", + " logging_steps=10, # 10 스텝마다 로그 출력\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"eval_loss\",\n", + " greater_is_better=False,\n", + " report_to=\"wandb\",\n", + ")\n", + "\n", + "trainer = DistillationTrainer(\n", + " model=adapt_model,\n", + " teacher_model=base_model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " alpha=1.0,\n", + " temperature=0.3,\n", + ")\n", + "\n", + "trainer.train()\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "id": "f7f8133d", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "314782cb", + "metadata": {}, + "outputs": [], + "source": [ + "# 파일명: testing.py\n", + "\n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from tqdm import tqdm\n", + "import json\n", + "\n", + "TEACHER_MODEL_ID = \"Qwen/Qwen3-8B\"\n", + "STUDENT_MODEL_PATH = \"./qwen3-restored-final\"\n", + "TEST_DATA_PATH = \"retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl\"\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "NUM_TEST_SAMPLES = 50\n", + "MAX_NEW_TOKENS = 128\n", + "\n", + "print(\"INFO: 모델 및 토크나이저 로딩 중...\")\n", + "tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_ID)\n", + "teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)\n", + "student_model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL_PATH, torch_dtype=torch.bfloat16).to(DEVICE)\n", + "\n", + "teacher_model.eval()\n", + "student_model.eval()\n", + "print(\"INFO: 로딩 완료.\")\n", + "\n", + "test_prompts = []\n", + "try:\n", + " with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:\n", + " for i, line in enumerate(f):\n", + " if i >= NUM_TEST_SAMPLES:\n", + " break\n", + " data = json.loads(line)\n", + " prompt = f\"질문: {data['input']}\\n\\n답변:\"\n", + " test_prompts.append(prompt)\n", + "except FileNotFoundError:\n", + " print(f\"ERROR: 테스트 데이터 파일을 찾을 수 없습니다. 
Check the path: {TEST_DATA_PATH}\")\n",
+    "    exit()\n",
+    "\n",
+    "total_accuracy = 0\n",
+    "total_mismatched = 0\n",
+    "total_generated = 0\n",
+    "\n",
+    "print(f\"\\nINFO: Starting test on {len(test_prompts)} samples...\")\n",
+    "for sample_idx, prompt in enumerate(tqdm(test_prompts, desc=\"Testing Progress\")):\n",
+    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(DEVICE)\n",
+    "\n",
+    "    with torch.no_grad():\n",
+    "        teacher_output_ids = teacher_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)\n",
+    "        student_output_ids = student_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)\n",
+    "\n",
+    "    teacher_tokens = teacher_output_ids[0][inputs.input_ids.shape[1]:]\n",
+    "    student_tokens = student_output_ids[0][inputs.input_ids.shape[1]:]\n",
+    "    \n",
+    "    total_tokens_in_sample = len(teacher_tokens)\n",
+    "    if total_tokens_in_sample == 0:\n",
+    "        continue  # skip this sample if the teacher generated nothing\n",
+    "\n",
+    "    mismatched_tokens_in_sample = 0\n",
+    "    mismatch_details = []\n",
+    "\n",
+    "    len_to_compare = min(len(teacher_tokens), len(student_tokens))\n",
+    "    for i in range(len_to_compare):\n",
+    "        if teacher_tokens[i] != student_tokens[i]:\n",
+    "            mismatched_tokens_in_sample += 1\n",
+    "            mismatch_details.append(\n",
+    "                f\"  - Position #{i+1}: \"\n",
+    "                f\"Teacher='{tokenizer.decode(teacher_tokens[i])}' (ID: {teacher_tokens[i]}) vs \"\n",
+    "                f\"Student='{tokenizer.decode(student_tokens[i])}' (ID: {student_tokens[i]})\"\n",
+    "            )\n",
+    "    \n",
+    "    len_diff = abs(len(teacher_tokens) - len(student_tokens))\n",
+    "    if len_diff > 0:\n",
+    "        mismatched_tokens_in_sample += len_diff\n",
+    "        stop_reason = \"Teacher stopped first\" if len(teacher_tokens) < len(student_tokens) else \"Student stopped first\"\n",
+    "        mismatch_details.append(\n",
+    "            f\"  - Generation length mismatch: Teacher={len(teacher_tokens)} tokens, Student={len(student_tokens)} tokens ({stop_reason})\"\n",
+    "        )\n",
+    "\n",
+    "    if mismatch_details:\n",
+    "        print(f\"\\n\\n--- [Sample #{sample_idx+1}] mismatch ---\")\n",
+    "        print(f\"Prompt: \\\"{prompt[:80]}...\\\"\")\n",
+    "        for detail in mismatch_details:\n",
+    "            print(detail)\n",
+    "        print(f\"Teacher output: \\\"...{tokenizer.decode(teacher_tokens, skip_special_tokens=True)}\\\"\")\n",
+    "        print(f\"Student output: \\\"...{tokenizer.decode(student_tokens, skip_special_tokens=True)}\\\"\")\n",
+    "        print(\"-\" * 25)\n",
+    "\n",
+    "    accuracy = ((total_tokens_in_sample - mismatched_tokens_in_sample) / total_tokens_in_sample) * 100\n",
+    "    total_accuracy += accuracy\n",
+    "    total_mismatched += mismatched_tokens_in_sample\n",
+    "    total_generated += total_tokens_in_sample\n",
+    "\n",
+    "if len(test_prompts) > 0:\n",
+    "    avg_accuracy = total_accuracy / len(test_prompts)\n",
+    "    print(\"\\n\\n\" + \"=\"*30)\n",
+    "    print(\"--- Final Evaluation Summary ---\")\n",
+    "    print(f\"Average token match rate (accuracy): {avg_accuracy:.2f}%\")\n",
+    "    print(f\"Total generated tokens (teacher reference): {total_generated}\")\n",
+    "    print(f\"Total mismatched tokens: {total_mismatched}\")\n",
+    "    print(\"=\"*30)\n",
+    "else:\n",
+    "    print(\"No test data was found, so evaluation cannot proceed.\")"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.11.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/retentionengine/adapters/engine.py b/retentionengine/adapters/engine.py
index 32f5ccb..d1e7291 100644
--- a/retentionengine/adapters/engine.py
+++ b/retentionengine/adapters/engine.py
@@ -1,14 +1,126 @@
 from thelethe.titans import PretrainedTitansConfig,
PreTrainedTitansModel -from transformers import PreTrainedModel +from transformers import PreTrainedModel, AutoTokenizer, AutoConfig from torch import nn +from torch.optim import AdamW +import torch +from typing import Optional, Dict, Any +import math + + +class RetentionEngineConfig(PretrainedTitansConfig): + """Extended configuration with size management""" + + def __init__( + self, + max_total_params: float = 8e9, # 8B + base_model_name: str = None, + base_model_params: float = None, + layers_to_remove: int = 0, + original_num_layers: int = None, + module_params: float = None, + **kwargs + ): + super().__init__(**kwargs) + self.max_total_params = max_total_params + self.base_model_name = base_model_name + self.base_model_params = base_model_params + self.layers_to_remove = layers_to_remove + self.original_num_layers = original_num_layers + self.module_params = module_params + + +class ModelSizeManager: + """Utility class for model size calculation and management""" + + @staticmethod + def count_parameters(model: nn.Module, exclude_embeddings: bool = False) -> int: + """Count total parameters in a model""" + total = 0 + for name, param in model.named_parameters(): + if exclude_embeddings and ('embed' in name.lower() or 'lm_head' in name.lower()): + continue + total += param.numel() + return total + + @staticmethod + def estimate_layer_params(model: PreTrainedModel) -> int: + """Estimate parameters per transformer layer""" + config = model.config + + # Get model architecture details + hidden_size = getattr(config, 'hidden_size', 4096) + intermediate_size = getattr(config, 'intermediate_size', hidden_size * 4) + num_attention_heads = getattr(config, 'num_attention_heads', 32) + num_key_value_heads = getattr(config, 'num_key_value_heads', num_attention_heads) + + # Estimate parameters per layer + # Self-attention: Q, K, V, O projections + attn_params = ( + hidden_size * hidden_size + # Q + hidden_size * (hidden_size * num_key_value_heads // num_attention_heads) * 2 + # K, V + hidden_size * hidden_size # O + ) + + # MLP: gate, up, down projections + mlp_params = ( + hidden_size * intermediate_size * 2 + # gate, up + intermediate_size * hidden_size # down + ) + + # Layer norms + norm_params = hidden_size * 2 * 2 # 2 layer norms, weight + bias + + return attn_params + mlp_params + norm_params + + @staticmethod + def calculate_layers_to_remove( + base_model: PreTrainedModel, + module_params: int, + max_total_params: float = 8e9 + ) -> int: + """Calculate how many layers to remove to stay under size limit""" + + total_base_params = ModelSizeManager.count_parameters(base_model) + + if total_base_params + module_params <= max_total_params: + return 0 + + # Need to remove layers + excess_params = (total_base_params + module_params) - max_total_params + params_per_layer = ModelSizeManager.estimate_layer_params(base_model) + + layers_to_remove = math.ceil(excess_params / params_per_layer) + + # Ensure we don't remove too many layers + num_layers = getattr(base_model.config, 'num_hidden_layers', 32) + max_removable = max(0, num_layers - 4) # Keep at least 4 layers + + return min(layers_to_remove, max_removable) class RetentionEngine(nn.Module): - def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig): + def __init__( + self, + basemodel: PreTrainedModel, + config: RetentionEngineConfig, + auto_resize: bool = True + ): super().__init__() + + self.original_basemodel = basemodel + self.auto_resize = auto_resize + self.module = PreTrainedTitansModel(config) + + module_params = 
ModelSizeManager.count_parameters(self.module) + config.module_params = module_params + + if auto_resize: + basemodel = self._resize_basemodel(basemodel, config, module_params) + self.module.model = basemodel - + self._attach_embeddings() + # Combine the configurations self.config = config if hasattr(basemodel.config, 'to_dict'): @@ -19,32 +131,200 @@ def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig): for key, value in basemodel.config.__dict__.items(): if not key.startswith('_'): # exclude private attributes self.config.__dict__[key] = value - + + def _resize_basemodel( + self, + basemodel: PreTrainedModel, + config: RetentionEngineConfig, + module_params: int + ) -> PreTrainedModel: + """Resize base model by removing layers if necessary""" + + # Calculate base model size + base_params = ModelSizeManager.count_parameters(basemodel) + config.base_model_params = base_params + config.original_num_layers = getattr(basemodel.config, 'num_hidden_layers', None) + + # Check if resizing is needed + total_params = base_params + module_params + + print(f"Base model params: {base_params/1e9:.2f}B") + print(f"Module params: {module_params/1e9:.2f}B") + print(f"Total params: {total_params/1e9:.2f}B") + + if total_params <= config.max_total_params: + print(f"Model size is within limit ({config.max_total_params/1e9:.1f}B). No resizing needed.") + config.layers_to_remove = 0 + return basemodel + + # Calculate layers to remove + layers_to_remove = ModelSizeManager.calculate_layers_to_remove( + basemodel, + module_params, + config.max_total_params + ) + + if layers_to_remove == 0: + print("Cannot reduce model size further while maintaining minimum layers.") + return basemodel + + print(f"Removing {layers_to_remove} layers to meet size constraint...") + config.layers_to_remove = layers_to_remove + + # Remove layers from model + resized_model = self._remove_layers(basemodel, layers_to_remove) + + # Recalculate size + new_base_params = ModelSizeManager.count_parameters(resized_model) + new_total = new_base_params + module_params + print(f"Resized base model: {new_base_params/1e9:.2f}B") + print(f"New total size: {new_total/1e9:.2f}B") + + return resized_model + + def _remove_layers(self, model: PreTrainedModel, num_layers: int) -> PreTrainedModel: + """Remove specified number of layers from the front of the model""" + + # Access the transformer layers + if hasattr(model, 'model') and hasattr(model.model, 'layers'): + layers = model.model.layers + else: + print("Warning: Could not identify layer structure. Returning original model.") + return model + + total_layers = len(layers) + if num_layers >= total_layers - 4: + print(f"Warning: Cannot remove {num_layers} layers. 
Keeping minimum 4 layers.")
+            num_layers = max(0, total_layers - 4)
+            if num_layers == 0:
+                return model
+
+        new_layers = nn.ModuleList(layers[num_layers:])
+
+        # Replace layers in model
+        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+            model.model.layers = new_layers
+
+        # Update config
+        if hasattr(model.config, 'num_hidden_layers'):
+            model.config.num_hidden_layers = len(new_layers)
+        elif hasattr(model.config, 'n_layer'):
+            model.config.n_layer = len(new_layers)
+
+        print(f"Layers reduced from {total_layers} to {len(new_layers)}")
+
+        return model
+
+    def _merge_configs(self, base_config):
+        """Merge base model configuration with module configuration"""
+        if hasattr(base_config, 'to_dict'):
+            config_dict = base_config.to_dict()
+            for key, value in config_dict.items():
+                if key not in ['num_hidden_layers', 'n_layer']:  # Don't override layer count
+                    self.config.__dict__[key] = value
+        elif hasattr(base_config, '__dict__'):
+            for key, value in base_config.__dict__.items():
+                if not key.startswith('_') and key not in ['num_hidden_layers', 'n_layer']:
+                    self.config.__dict__[key] = value
+
+    def _attach_embeddings(self):
+        """Attach the input embeddings from the base model to the adapter model"""
+        if self.module.model is not None and hasattr(self.module.model, 'get_input_embeddings'):
+            input_embeddings = self.module.model.get_input_embeddings()
+            self.module.set_input_embeddings(input_embeddings)
+
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
-
-    def adapt(self, *args, **kwargs):
-        pass
-
+
+    def get_model_size_info(self) -> Dict[str, Any]:
+        """Get detailed size information about the model (parameter counts in billions)"""
+        base_params = ModelSizeManager.count_parameters(self.module.model) if self.module.model else 0
+        module_params = ModelSizeManager.count_parameters(self.module, exclude_embeddings=True)
+        total_params = base_params + module_params
+
+        return {
+            'base_model_params': base_params,
+            'base_model_params_billions': base_params / 1e9,
+            'module_params': module_params,
+            'module_params_billions': module_params / 1e9,
+            'total_params': total_params,
+            'total_params_billions': total_params / 1e9,
+            'layers_removed': self.config.layers_to_remove,
+            'original_layers': self.config.original_num_layers,
+            'current_layers': getattr(self.module.model.config, 'num_hidden_layers', None) if self.module.model else None
+        }
+
     @classmethod
     def from_pretrained(
-        cls,
-        model_id: str,
-        basemodel: PreTrainedModel
+        cls,
+        model_id: str,
+        basemodel: PreTrainedModel,
+        auto_resize: bool = True,
+        max_total_params: float = 8e9
     ) -> "RetentionEngine":
-        """
-        Load a RetentionEngine from a pretrained model and configuration.
-        """
-        config = PretrainedTitansConfig.from_pretrained(model_id)
+        """Load a RetentionEngine from a pretrained model and configuration"""
+
+        # Load config with size constraints
+        config = RetentionEngineConfig.from_pretrained(model_id)
+        config.max_total_params = max_total_params
+        config.base_model_name = getattr(basemodel.config, '_name_or_path', 'unknown')
+
+        # Create engine with auto-resizing
+        engine = cls(basemodel, config, auto_resize=auto_resize)
+
+        # Load pretrained module weights and attach the (possibly resized) base model to it
         module = PreTrainedTitansModel.from_pretrained(model_id)
-        module.model = basemodel
-        engine = cls(basemodel, config)
+        module.model = engine.module.model
         engine.module = module
+
         return engine
-
+
     def save_pretrained(self, save_directory: str):
-        """
-        Save the RetentionEngine to a directory.
- """ + """Save the RetentionEngine to a directory""" + # Save size information in config + size_info = self.get_model_size_info() + for key, value in size_info.items(): + setattr(self.config, f'size_{key}', value) + self.module.save_pretrained(save_directory) self.config.save_pretrained(save_directory) + + # Save size report + import json + with open(f"{save_directory}/size_report.json", 'w') as f: + json.dump(size_info, f, indent=2) + + +# Example usage +if __name__ == "__main__": + from transformers import AutoModelForCausalLM + + # Load a base model (e.g., 13B model) + base_model = AutoModelForCausalLM.from_pretrained( + "meta-llama/Llama-2-13b-hf", + torch_dtype=torch.bfloat16, + device_map="auto" + ) + + # Create config with 8B size limit + config = RetentionEngineConfig( + max_total_params=8e9, # 8B max + base_model_name="meta-llama/Llama-2-13b-hf", + # Add other Titans config parameters here + ) + + # Create engine with automatic resizing + engine = RetentionEngine( + basemodel=base_model, + config=config, + auto_resize=True # Automatically remove layers if needed + ) + + # Check final size + size_info = engine.get_model_size_info() + print("\nFinal Model Size Information:") + for key, value in size_info.items(): + print(f" {key}: {value}") + + # Save the resized model + engine.save_pretrained("./retention_engine_8b") \ No newline at end of file diff --git a/retentionengine/datasets/__init__.py b/retentionengine/datasets/__init__.py new file mode 100644 index 0000000..15a5eb2 --- /dev/null +++ b/retentionengine/datasets/__init__.py @@ -0,0 +1 @@ +from .dataloader import get_pg19_dataloader \ No newline at end of file diff --git a/retentionengine/datasets/dataloader.py b/retentionengine/datasets/dataloader.py new file mode 100644 index 0000000..9f18f4d --- /dev/null +++ b/retentionengine/datasets/dataloader.py @@ -0,0 +1,65 @@ +import torch +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset + + +class LongTextDataset(Dataset): + """Custom Dataset for formatting and tokenizing long text documents.""" + + def __init__(self, documents, tokenizer, max_length=8192): # Gemma 모델에 맞춰 max_length 조절 + self.tokenizer = tokenizer + self.documents = documents + self.max_length = max_length + + def __len__(self): + return len(self.documents) + + def __getitem__(self, idx): + # PG19 데이터셋에서 텍스트 문서 하나를 가져옴 + text = self.documents[idx]['text'] + + # 텍스트를 토크나이징 + tokenized_output = self.tokenizer( + text, + add_special_tokens=True, # 문장의 시작(BOS)과 끝(EOS) 토큰 추가 + truncation=True, # max_length에 맞춰 자르기 + max_length=self.max_length, + padding='max_length' # max_length에 맞춰 패딩 + ) + + input_ids = tokenized_output['input_ids'] + + # 언어 모델 학습을 위해 labels를 input_ids를 복사하여 생성 + # 일반적으로 패딩 토큰은 손실 계산에서 제외 + labels = torch.tensor(input_ids) + labels[labels == self.tokenizer.pad_token_id] = -100 # 손실 계산에서 패딩 부분은 무시 + + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "labels": labels.to(torch.long) + } + + +def get_pg19_dataloader(tokenizer, batch_size, split='train'): + """ + Creates a DataLoader for the PG19 dataset. + + Args: + tokenizer: The tokenizer to use. + batch_size: The batch size for the DataLoader. + split: The dataset split to use ('train', 'validation', or 'test'). 
+ """ + # PG19 데이터셋 로드 + dataset = load_dataset("pg19", split=split) + + # Dataset 객체 생성 + longtext_dataset = LongTextDataset(dataset, tokenizer) + + # DataLoader 생성 + dataloader = DataLoader( + longtext_dataset, + batch_size=batch_size, + shuffle=(split == 'train') + ) + + return dataloader \ No newline at end of file diff --git a/retentionengine/datasets/dataset_generator.py b/retentionengine/datasets/dataset_generator.py new file mode 100644 index 0000000..85c30f1 --- /dev/null +++ b/retentionengine/datasets/dataset_generator.py @@ -0,0 +1,271 @@ +# -*- coding: utf-8 -*- +""" +Qwen2-based Memory Module Training Dataset Generator (Stabilized Version) +---------------------------------------------------------------------- +Generated data: RETENTIONENGINE/retentionengine/datasets/qwen2_memory_aligned_dataset.jsonl +→ Includes 2/3 Korean and 1/3 English samples. +""" + +import json +import re +import time +import random +import uuid +from typing import List, Dict +import os # Added for directory creation +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import json_repair # Library for repairing malformed JSON; install via: pip install json-repair + +def make_seed_prompt(batch_size: int = 25) -> str: # Reduced default for faster testing + """ + Return a Korean prompt with 2/3 probability, and an English prompt with 1/3 probability. + """ + if random.random() < 0.33: # Approx. 1/3 probability for English + return make_seed_prompt_english(batch_size) + else: + return make_seed_prompt_korean(batch_size) + +def make_seed_prompt_korean(batch_size: int = 25) -> str: + """ + Korean prompt + """ + domains = [ + ("math", 20), ("cs", 20), ("physics", 15), ("literature", 15), + ("philosophy", 10), ("history", 10), ("biology", 5), ("economics", 5) + ] + domain_str = ", ".join([f"{d}({p}%)" for d, p in domains]) + + prompt = f""" +다음 JSON 스키마를 따르는 데이터를 정확히 {batch_size}개 포함하는 단일 JSON 배열(array)을 생성해줘. + +{{ + "id": "sample_001", + "domain": "예: math, physics, literature, cs, philosophy 등", + "input": "장문 또는 단문 입력. 다양한 도메인 혼합 가능", + "target_output": "Qwen2이 생성할 것으로 예상되는 출력", # Updated to Qwen2 + "reasoning_trace": [ + {{ + "step": 1, + "thought": "내부 사고 과정", + "memory_read": ["참조한 메모리 블록 이름"], + "inference": "추론 내용" + }} + ], + "memory_access_pattern": {{ + "reads": ["메모리_블록_A", "메모리_블록_B"], + "writes": ["결과_요약_01"] + }}, + "metadata": {{ + "context_length": 입력 토큰 수, + "reasoning_depth": "trivial / intermediate / advanced" + }} +}} + +### 생성 규칙: +- 30%: 짧은 질문 (100~500 토큰), 출력 분포 정렬용 +- 50%: 중간 길이 CoT (1K~5K 토큰), 메모리 트레이스 포함 +- 20%: 장문 크로스 도메인 (8K~30K 토큰), multi-hop 추론 +- 도메인 분포: {domain_str} +- reasoning_trace는 2~5단계 포함 +- memory_read/write는 실제 메모리 블록 이름 사용 (예: physics_laws_v1, literature_themes_02) +- 문자열 내부에 따옴표(\")가 있다면 반드시 백슬래시(\\)로 이스케이프 처리해줘 (예: `\"`). +- 오직 하나의 완성된 JSON 배열만 출력해줘. 설명, 마크다운, ```json``` 태그는 절대 포함하지 마. + +### 완벽한 출력 예시: +[ + {{ + "id": "temp_001", + "domain": "literature", + "input": "소설 '데미안'의 주제는 무엇인가?", + "target_output": "헤르만 헤세의 소설 '데미안'의 핵심 주제는 개인의 자아 발견과 성장 과정입니다. 
주인공 싱클레어가 알을 깨고 나오듯, 기존의 세계를 벗어나 자기 자신만의 길을 찾아가는 과정을 그립니다.", + "reasoning_trace": [], + "memory_access_pattern": {{"reads": [], "writes": []}}, + "metadata": {{"context_length": 50, "reasoning_depth": "trivial"}} + }} +] +""" + return prompt.strip() + +def make_seed_prompt_english(batch_size: int = 25) -> str: + """ + English prompt (called with 1/3 probability) + """ + domains = [ + ("math", 20), ("cs", 20), ("physics", 15), ("literature", 15), + ("philosophy", 10), ("history", 10), ("biology", 5), ("economics", 5) + ] + domain_str = ", ".join([f"{d}({p}%)" for d, p in domains]) + + prompt = f""" +Generate a single JSON array containing exactly {batch_size} samples that follow the JSON schema below. + +{{ + "id": "sample_001", + "domain": "e.g., math, physics, literature, cs, philosophy", + "input": "Long or short input. Mix of domains allowed", + "target_output": "Output expected from Qwen2", # Updated to Qwen2 + "reasoning_trace": [ + {{ + "step": 1, + "thought": "Internal reasoning process", + "memory_read": ["memory_block_name"], + "inference": "Inference content" + }} + ], + "memory_access_pattern": {{ + "reads": ["memory_block_A", "memory_block_B"], + "writes": ["summary_01"] + }}, + "metadata": {{ + "context_length": "number of input tokens", + "reasoning_depth": "trivial / intermediate / advanced" + }} +}} + +### Rules: +- 30%: Short questions (100-500 tokens), for output distribution alignment +- 50%: Medium-length CoT (1K-5K tokens), with memory trace +- 20%: Long cross-domain (8K-30K tokens), multi-hop reasoning +- Domain distribution: {domain_str} +- reasoning_trace must have 2-5 steps +- Use realistic memory block names (e.g., physics_laws_v1, literature_themes_02) +- Output only a single, complete JSON array. Do not include any explanation, markdown, or ```json``` tags. +- Example: `[ {{"id": "...", ...}} ]` +""" + return prompt.strip() + +def parse_json_array(text: str) -> List[Dict]: + """ + Extract a valid JSON array from the generated text. + """ + # Find the first '[' and the last ']' to capture the array + match = re.search(r'\[.*\]', text, re.DOTALL) + if not match: + print("⚠️ WARNING: Could not find a JSON array in the response.") + return [] + + json_array_str = match.group(0) + try: + data = json.loads(json_array_str) + except json.JSONDecodeError: + # If standard JSON decoding fails, attempt to repair with json_repair + print("⚠️ WARNING: JSON decoding failed. Attempting to repair with json_repair.") + try: + data = json_repair.loads(json_array_str) + except Exception as e: + print(f"🆘 CRITICAL: Repair with json_repair also failed: {e}") + return [] + + if isinstance(data, list): + # Ensure all items in the list are dicts with the required keys + valid_items = [ + item for item in data + if isinstance(item, dict) and "input" in item and "target_output" in item + ] + return valid_items + else: + print("⚠️ WARNING: Parsed data is not a list.") + return [] + +def generate_dataset( + model_name: str, + target_count: int = 10000, + batch_size: int = 25 +): + """ + Generate a dataset using the specified model and save it to a JSONL file. 
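+
+    Args:
+        model_name: Hugging Face model ID used to generate the synthetic samples.
+        target_count: Total number of samples to generate.
+        batch_size: Number of samples requested from the model per generation call.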
+ """ + output_dir = "RETENTIONENGINE/retentionengine/datasets" + output_file = os.path.join(output_dir, "qwen2_memory_aligned_dataset.jsonl") # Updated name and path + os.makedirs(output_dir, exist_ok=True) # Create dirs if needed + + print(f"🚀 Loading model: {model_name}...") + try: + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) + print(f"✅ Model loaded successfully. Output file: {output_file}") + except Exception as e: + print(f"🆘 ERROR: Failed to load model: {e}") + return [] # Early exit on load failure + + generated_samples = [] + + with open(output_file, "w", encoding="utf-8") as f_out: + while len(generated_samples) < target_count: + remaining = target_count - len(generated_samples) + current_batch_size = min(batch_size, remaining) + print(f"🔄 Generating: {len(generated_samples)} / {target_count} | Current batch size: {current_batch_size}") + + prompt = make_seed_prompt(current_batch_size) + + # Apply the Qwen2 chat template + system_prompt = "<|im_start|>system\nYou are a helpful assistant that generates high-quality synthetic data.<|im_end|>" + user_prompt = f"<|im_start|>user\n{prompt}<|im_end|>" + assistant_prompt = "<|im_start|>assistant" + + final_text = f"{system_prompt}\n{user_prompt}\n{assistant_prompt}" + + inputs = tokenizer(final_text, return_tensors="pt").to(model.device) + + # Dynamically calculate max_new_tokens to avoid exceeding context window + input_len = inputs["input_ids"].shape[1] + # Set a reasonable generation cap and respect the model's context length (32768 for Qwen2) + max_new = min(16000, 32768 - input_len - 10) + + try: + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_new, + temperature=0.5, + do_sample=True, + top_p=0.9, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.eos_token_id + ) + + # Decode only the newly generated tokens + input_token_length = inputs["input_ids"].shape[1] + generated_tokens = outputs[0][input_token_length:] + content = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() + + new_samples = parse_json_array(content) + print(f"✅ Extracted {len(new_samples)} valid samples from this batch.") + + if not new_samples: + print("⚠️ Failed to generate valid samples in this batch. Retrying after a delay.") + time.sleep(5) + continue + + for sample in new_samples: + if len(generated_samples) >= target_count: + break + sample["id"] = str(uuid.uuid4()) # Assign a unique ID + f_out.write(json.dumps(sample, ensure_ascii=False) + "\n") + generated_samples.append(sample) + + f_out.flush() # Write data to disk periodically + print(f"📊 Total generated: {len(generated_samples)} / {target_count}") + + time.sleep(2) # Small delay between API calls + except Exception as e: + print(f"🆘 ERROR during generation: {e}. 
Retrying after delay.")
+                time.sleep(5)
+                continue
+
+    print(f"\n🎉 Generation successfully completed!")
+    print(f"📁 Data saved to: {output_file}")
+    print(f"📊 Total samples generated: {len(generated_samples)}")
+    return generated_samples
+
+if __name__ == "__main__":
+    generate_dataset(
+        model_name="Qwen/Qwen2-7B-Instruct",  # Fixed to a real model; adjust size if needed (e.g., -1.5B for smaller)
+        target_count=5000,
+        batch_size=25
+    )
\ No newline at end of file
diff --git a/retentionengine/utils/__init__.py b/retentionengine/utils/__init__.py
index e69de29..a7dde9b 100644
--- a/retentionengine/utils/__init__.py
+++ b/retentionengine/utils/__init__.py
@@ -0,0 +1 @@
+from .adapter import DistillationTrainer, convert_to_fp8_layers
diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py
new file mode 100644
index 0000000..b918f68
--- /dev/null
+++ b/retentionengine/utils/adapter.py
@@ -0,0 +1,72 @@
+# Filename: huggingface_adapter.py
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import Trainer
+import transformer_engine.pytorch as te
+
+def convert_to_fp8_layers(module: nn.Module) -> nn.Module:
+    """Recursively replace nn.Linear and nn.LayerNorm with Transformer Engine FP8-capable layers,
+    copying the original weights so the converted model keeps the same parameters."""
+    for name, child in module.named_children():
+        if isinstance(child, nn.Linear):
+            new_layer = te.Linear(
+                child.in_features,
+                child.out_features,
+                bias=(child.bias is not None),
+            )
+            # Copy parameters from the original layer (assumes te.Linear exposes .weight/.bias
+            # like its torch counterpart); without this the teacher would use random weights.
+            with torch.no_grad():
+                new_layer.weight.copy_(child.weight)
+                if child.bias is not None:
+                    new_layer.bias.copy_(child.bias)
+            setattr(module, name, new_layer)
+        elif isinstance(child, nn.LayerNorm):
+            new_layer = te.LayerNorm(
+                child.normalized_shape,
+                eps=child.eps,
+            )
+            with torch.no_grad():
+                new_layer.weight.copy_(child.weight)
+                new_layer.bias.copy_(child.bias)
+            setattr(module, name, new_layer)
+        else:
+            convert_to_fp8_layers(child)
+    return module
+
+class DistillationTrainer(Trainer):
+    """
+    Custom Trainer for knowledge distillation (restoration training) using a KL-divergence loss.
+    """
+    def __init__(self, *args, teacher_model, alpha=1.0, temperature=2.0, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        print("INFO: Converting the teacher model to FP8-capable layers...")
+        self.teacher_model = convert_to_fp8_layers(teacher_model)
+        print("INFO: Conversion complete.")
+
+        self.alpha = alpha
+        self.temperature = temperature
+
+        self.teacher_model.to(self.model.device)
+        self.teacher_model.eval()
+        for param in self.teacher_model.parameters():
+            param.requires_grad = False
+
+    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
+        """
+        Loss computation: combines the student's LM loss with the teacher-student distillation loss.
+        (num_items_in_batch is accepted for compatibility with newer Trainer versions; unused here.)
+        """
+        # Student model outputs and base LM loss
+        student_outputs = model(**inputs)
+        loss_lm = student_outputs.loss
+        logits_student = student_outputs.logits
+
+        # Teacher model outputs, computed under FP8 autocast
+        with torch.no_grad():
+            with te.fp8_autocast(enabled=True):
+                teacher_outputs = self.teacher_model(**inputs)
+                logits_teacher = teacher_outputs.logits
+
+        # Distillation loss via KL divergence
+        distillation_loss_fn = nn.KLDivLoss(reduction="batchmean")
+
+        soft_targets = F.softmax(logits_teacher / self.temperature, dim=-1)
+        soft_prob = F.log_softmax(logits_student / self.temperature, dim=-1)
+
+        loss_distill = distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2)
+
+        # Final loss: with alpha=1.0, training focuses entirely on matching the teacher's distribution
+        loss = (1.
- self.alpha) * loss_lm + self.alpha * loss_distill + + return (loss, student_outputs) if return_outputs else loss \ No newline at end of file diff --git a/usage.ipynb b/usage.ipynb index 8ae046a..811af45 100644 --- a/usage.ipynb +++ b/usage.ipynb @@ -2,86 +2,82 @@ "cells": [ { "cell_type": "code", + "execution_count": 1, "id": "initial_id", "metadata": { - "collapsed": true, "ExecuteTime": { "end_time": "2025-05-05T23:45:48.721890Z", "start_time": "2025-05-05T23:45:34.721988Z" - } + }, + "collapsed": true }, + "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "from retentionengine import RetentionEngine\n", "from thelethe.titans import CronosConfig" - ], - "outputs": [], - "execution_count": 1 + ] }, { + "cell_type": "code", + "execution_count": null, + "id": "f20a5a772cb31cca", "metadata": { "ExecuteTime": { "end_time": "2025-05-05T23:49:38.482121Z", "start_time": "2025-05-05T23:45:48.934067Z" } }, - "cell_type": "code", - "source": [ - "# Load Gemma3 4b model and tokenizer\n", - "basemodel_name = \"google/gemma-3-4b-it\"\n", - "tokenizer = AutoTokenizer.from_pretrained(basemodel_name)\n", - "basemodel = AutoModelForCausalLM.from_pretrained(basemodel_name)" - ], - "id": "f20a5a772cb31cca", "outputs": [ { "data": { - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00