From 9d5a2bb0e883ff754d5c17d52eda58372eaa8b87 Mon Sep 17 00:00:00 2001 From: 3GID Date: Tue, 1 Jul 2025 14:31:29 +0900 Subject: [PATCH 01/13] Feat: Add training adapter and dataloader for OASST dataset Introduces a Trainer utility class for model training and evaluation, and a dataloader for the OpenAssistant/oasst1 dataset with conversation extraction and tokenization. Updates RetentionEngine to support adaptation via training on OASST data using the new utilities. --- retentionengine/adapters/engine.py | 44 +++++++++- retentionengine/data/dataloader.py | 132 +++++++++++++++++++++++++++++ retentionengine/utils/adapter.py | 57 +++++++++++++ 3 files changed, 230 insertions(+), 3 deletions(-) create mode 100644 retentionengine/data/dataloader.py create mode 100644 retentionengine/utils/adapter.py diff --git a/retentionengine/adapters/engine.py b/retentionengine/adapters/engine.py index 32f5ccb..7a18ba9 100644 --- a/retentionengine/adapters/engine.py +++ b/retentionengine/adapters/engine.py @@ -1,7 +1,11 @@ from thelethe.titans import PretrainedTitansConfig, PreTrainedTitansModel -from transformers import PreTrainedModel +from transformers import PreTrainedModel, AutoTokenizer +from torch.optim import AdamW from torch import nn +from ..utils import Trainer +from ..data import get_oasst_dataloader + class RetentionEngine(nn.Module): def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig): @@ -23,8 +27,42 @@ def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig): def forward(self, *args, **kwargs): return self.module(*args, **kwargs) - def adapt(self, *args, **kwargs): - pass + def adapt( + self, + tokenizer: AutoTokenizer, + epochs: int = 1, + batch_size: int = 4, + learning_rate: float = 1e-5, + train_lang: str = None, + eval_lang: str = None, + device: str = "cuda" + ): + + train_dataloader = get_oasst_dataloader(tokenizer, batch_size, split='train', lang=train_lang) + eval_dataloader = None + if eval_lang: + eval_dataloader = get_oasst_dataloader(tokenizer, batch_size, split='validation', lang=eval_lang) + + optimizer = AdamW(self.parameters(), lr=learning_rate) + loss_fn = nn.CrossEntropyLoss() + + trainer = Trainer( + model=self, + optimizer=optimizer, + loss_fn=loss_fn, + device=device + ) + + for epoch in range(epochs): + print(f"\n--- Epoch {epoch + 1}/{epochs} ---") + + avg_train_loss = trainer.train_epoch(train_dataloader) + print(f"Average Training Loss: {avg_train_loss:.4f}") + + if eval_dataloader: + avg_eval_loss = trainer.eval_epoch(eval_dataloader) + print(f"Average Validation Loss: {avg_eval_loss:.4f}") + @classmethod def from_pretrained( diff --git a/retentionengine/data/dataloader.py b/retentionengine/data/dataloader.py new file mode 100644 index 0000000..d5a8e87 --- /dev/null +++ b/retentionengine/data/dataloader.py @@ -0,0 +1,132 @@ +# data/dataloader.py + +import torch +from torch.utils.data import Dataset, DataLoader +from datasets import load_dataset +from collections import defaultdict + + +def _extract_conversations(dataset): + """Extracts conversation threads from the oasst1 dataset.""" + message_map = {msg['message_id']: msg for msg in dataset} + + threads = defaultdict(list) + for msg in dataset: + threads[msg['parent_id']].append(msg) + + conversations = [] + for root in threads.get(None, []): + # Build a thread from the root message + thread = [] + queue = [root] + visited = set() + + while queue: + message = queue.pop(0) + if message['message_id'] in visited: + continue + visited.add(message['message_id']) + + 
thread.append({ + "role": message['role'], + "text": message['text'], + "rank": message['rank'] + }) + + children = sorted(threads.get(message['message_id'], []), key=lambda x: x['rank']) + queue.extend(children) + + # Only include threads with at least one assistant response + if any(msg['role'] == 'assistant' for msg in thread): + conversations.append(thread) + + return conversations + + +class DialogueDataset(Dataset): + """Custom Dataset for formatting and tokenizing dialogue conversations.""" + + def __init__(self, conversations, tokenizer, max_length=1024): + self.tokenizer = tokenizer + self.conversations = conversations + self.max_length = max_length + self.role_tokens = {"prompter": "<|prompter|>", "assistant": "<|assistant|>"} + + # Add special role tokens to the tokenizer if they don't exist + self.tokenizer.add_special_tokens( + {"additional_special_tokens": list(self.role_tokens.values())} + ) + + def __len__(self): + return len(self.conversations) + + def __getitem__(self, idx): + conversation = self.conversations[idx] + + input_ids = [] + labels = [] + + # Start with the beginning-of-sequence token + input_ids.append(self.tokenizer.bos_token_id) + labels.append(-100) # Do not compute loss on the BOS token + + for utterance in conversation: + role_token = self.role_tokens[utterance['role']] + message_tokens = self.tokenizer( + f"{role_token}{utterance['text']}{self.tokenizer.eos_token}", + add_special_tokens=False + ).input_ids + + input_ids.extend(message_tokens) + + if utterance['role'] == 'prompter': + # Mask out prompter's messages and role token in labels + labels.extend([-100] * len(message_tokens)) + else: # assistant + # Only compute loss on the assistant's messages + labels.extend(message_tokens) + + # Truncate to max_length + input_ids = input_ids[:self.max_length] + labels = labels[:self.max_length] + + # Pad to max_length + padding_length = self.max_length - len(input_ids) + input_ids.extend([self.tokenizer.pad_token_id] * padding_length) + labels.extend([-100] * padding_length) + + return { + "input_ids": torch.tensor(input_ids, dtype=torch.long), + "labels": torch.tensor(labels, dtype=torch.long) + } + + +def get_oasst_dataloader(tokenizer, batch_size, split='train', lang=None): + """ + Creates a DataLoader for the OpenAssistant/oasst1 dataset. + + Args: + tokenizer: The tokenizer to use. + batch_size: The batch size for the DataLoader. + split: The dataset split to use ('train' or 'validation'). + lang: The language code to filter by (e.g., 'en', 'ko'). + If None, all languages are used. 
+    """
+    dataset = load_dataset("OpenAssistant/oasst1", split=split)
+
+    if lang:
+        dataset = dataset.filter(lambda example: example['lang'] == lang)
+
+    # Filter for high-quality, ready-to-use data
+    dataset = dataset.filter(lambda x: x["rank"] is not None and x["rank"] == 0)
+
+    conversations = _extract_conversations(dataset)
+    dialogue_dataset = DialogueDataset(conversations, tokenizer)
+
+    dataloader = DataLoader(
+        dialogue_dataset,
+        batch_size=batch_size,
+        shuffle=(split == 'train')
+    )
+
+    return dataloader
\ No newline at end of file
diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py
new file mode 100644
index 0000000..820c3b4
--- /dev/null
+++ b/retentionengine/utils/adapter.py
@@ -0,0 +1,57 @@
+import torch
+from torch import nn
+from torch.utils.data import DataLoader
+
+
+class Trainer:
+    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
+                 loss_fn: nn.Module = nn.CrossEntropyLoss(), device: str = "cuda"):
+        self.model = model
+        self.optimizer = optimizer
+        self.loss_fn = loss_fn
+        self.device = device
+        self.model.to(self.device)
+
+    def train_step(self, batch: tuple) -> float:
+        self.model.train()
+        inputs, targets = batch
+
+        inputs = inputs[:, :-1].to(self.device)
+        targets = targets[:, 1:].to(self.device)
+
+        self.optimizer.zero_grad()
+        outputs = self.model(inputs)
+
+        loss = self.loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
+
+        loss.backward()
+        self.optimizer.step()
+
+        return loss.item()
+
+    def eval_step(self, batch: tuple) -> float:
+        self.model.eval()
+        with torch.no_grad():
+            inputs, targets = batch
+
+            inputs = inputs[:, :-1].to(self.device)
+            targets = targets[:, 1:].to(self.device)
+
+            outputs = self.model(inputs)
+            loss = self.loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
+
+        return loss.item()
+
+    def train_epoch(self, dataloader: DataLoader) -> float:
+        total_loss = 0
+        for batch in dataloader:
+            loss = self.train_step(batch)
+            total_loss += loss
+        return total_loss / len(dataloader)
+
+    def eval_epoch(self, dataloader: DataLoader) -> float:
+        total_loss = 0
+        for batch in dataloader:
+            loss = self.eval_step(batch)
+            total_loss += loss
+        return total_loss / len(dataloader)
\ No newline at end of file

From 7e280191a76674176ca3ba789490f9b2eef0f0c8 Mon Sep 17 00:00:00 2001
From: 3GID
Date: Tue, 1 Jul 2025 14:32:04 +0900
Subject: [PATCH 02/13] Init: Init methods

---
 retentionengine/data/__init__.py  | 1 +
 retentionengine/utils/__init__.py | 1 +
 2 files changed, 2 insertions(+)
 create mode 100644 retentionengine/data/__init__.py

diff --git a/retentionengine/data/__init__.py b/retentionengine/data/__init__.py
new file mode 100644
index 0000000..c599c46
--- /dev/null
+++ b/retentionengine/data/__init__.py
@@ -0,0 +1 @@
+from .dataloader import get_oasst_dataloader
\ No newline at end of file
diff --git a/retentionengine/utils/__init__.py b/retentionengine/utils/__init__.py
index e69de29..10a4fbb 100644
--- a/retentionengine/utils/__init__.py
+++ b/retentionengine/utils/__init__.py
@@ -0,0 +1 @@
+from .adapter import Trainer

From 91b9bd63c94167eaa3968c476007913198a4c4ac Mon Sep 17 00:00:00 2001
From: 3GID
Date: Tue, 8 Jul 2025 16:39:32 +0900
Subject: [PATCH 03/13] Feat: Fix training adapter and change datasets

Replaces the OASST dataloader with a new PG19 dataloader under
retentionengine/datasets, removing the old dataloader.
Updates RetentionEngine and Adapter to use the new dataloader and support
distillation training with PG19, as sketched below.
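
For reference, the combined objective the new Adapter optimizes is, in
essence, the following (a minimal sketch with toy tensors; the names,
shapes, and the alpha/temperature values here are illustrative only, not
the module's real API):

    import torch
    import torch.nn.functional as F
    from torch import nn

    T, alpha, vocab = 2.0, 0.3, 32              # temperature and hard/soft weighting
    student_logits = torch.randn(4, 10, vocab)  # stand-in for the adapt model's output
    teacher_logits = torch.randn(4, 10, vocab)  # stand-in for the frozen base model's output
    targets = torch.randint(0, vocab, (4, 10))  # stand-in next-token labels

    # Hard loss: cross-entropy against the ground-truth tokens
    loss_hard = nn.CrossEntropyLoss()(
        student_logits.reshape(-1, vocab), targets.reshape(-1)
    )

    # Soft loss: KL divergence between temperature-softened distributions,
    # rescaled by T**2 so gradient magnitudes stay comparable
    loss_soft = nn.KLDivLoss(reduction='batchmean')(
        F.log_softmax(student_logits / T, dim=-1).reshape(-1, vocab),
        F.softmax(teacher_logits / T, dim=-1).reshape(-1, vocab)
    ) * (T ** 2)

    total_loss = alpha * loss_hard + (1 - alpha) * loss_soft
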
Refactors training logic to use the new Adapter class and updates training
parameters for long document handling.
---
 retentionengine/adapters/engine.py     |  64 +++++++-----
 retentionengine/data/dataloader.py     | 132 -------------------------
 retentionengine/datasets/dataloader.py |  66 ++++++++++++
 retentionengine/utils/adapter.py       |  68 +++++++++----
 4 files changed, 153 insertions(+), 177 deletions(-)
 delete mode 100644 retentionengine/data/dataloader.py
 create mode 100644 retentionengine/datasets/dataloader.py

diff --git a/retentionengine/adapters/engine.py b/retentionengine/adapters/engine.py
index 7a18ba9..3baa41c 100644
--- a/retentionengine/adapters/engine.py
+++ b/retentionengine/adapters/engine.py
@@ -1,10 +1,10 @@
 from thelethe.titans import PretrainedTitansConfig, PreTrainedTitansModel
 from transformers import PreTrainedModel, AutoTokenizer
-from torch.optim import AdamW
 from torch import nn
+from torch.optim import AdamW
 
-from ..utils import Trainer
-from ..data import get_oasst_dataloader
+from ..utils import Adapter
+from ..datasets import get_pg19_dataloader
 
 
 class RetentionEngine(nn.Module):
@@ -27,41 +27,51 @@ def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig):
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
 
-    def adapt(
-        self,
+    def adapt(self,
         tokenizer: AutoTokenizer,
-        epochs: int = 1,
-        batch_size: int = 4,
-        learning_rate: float = 1e-5,
-        train_lang: str = None,
-        eval_lang: str = None,
-        device: str = "cuda"
+        epochs: int = 3,
+        batch_size: int = 1,
+        max_length: int = 8192,
+        learning_rate: float = 2e-5,
+
+        alpha: float = 0.3,
+        temperature: float = 2.0
    ):
 
-        train_dataloader = get_oasst_dataloader(tokenizer, batch_size, split='train', lang=train_lang)
-        eval_dataloader = None
-        if eval_lang:
-            eval_dataloader = get_oasst_dataloader(tokenizer, batch_size, split='validation', lang=eval_lang)
+        device = self.module.device
+        # 2. Define and configure the model roles
+        adapt_model = self.module
+        base_model = self.module.model
+        base_model.eval()
+        for param in base_model.parameters():
+            param.requires_grad = False
+        for param in adapt_model.attention_module.parameters():
+            param.requires_grad = False  # Do not train the attention module
+        adapt_model.train()
 
-        optimizer = AdamW(self.parameters(), lr=learning_rate)
-        loss_fn = nn.CrossEntropyLoss()
+        train_dataloader = get_pg19_dataloader(tokenizer, batch_size, 'train', max_length)
+        eval_dataloader = get_pg19_dataloader(tokenizer, batch_size, 'validation', max_length)
 
-        trainer = Trainer(
-            model=self,
+        # 4. Configure the optimizer
+        optimizer = AdamW(adapt_model.parameters(), lr=learning_rate)
+
+        # 5. Configure the distillation trainer
+        trainer = Adapter(
+            adapt_model=adapt_model,
+            base_model=base_model,
             optimizer=optimizer,
-            loss_fn=loss_fn,
-            device=device
+            device=device,
+            alpha=alpha,
+            temperature=temperature
         )
 
         for epoch in range(epochs):
             print(f"\n--- Epoch {epoch + 1}/{epochs} ---")
+            train_loss = trainer.train_epoch(train_dataloader)
+            eval_loss = trainer.eval_epoch(eval_dataloader)
+            print(f"Epoch {epoch + 1} | Train Loss: {train_loss:.4f} | Eval Loss: {eval_loss:.4f}")
+
 
-            avg_train_loss = trainer.train_epoch(train_dataloader)
-            print(f"Average Training Loss: {avg_train_loss:.4f}")
-
-            if eval_dataloader:
-                avg_eval_loss = trainer.eval_epoch(eval_dataloader)
-                print(f"Average Validation Loss: {avg_eval_loss:.4f}")
 
 
     @classmethod
diff --git a/retentionengine/data/dataloader.py b/retentionengine/data/dataloader.py
deleted file mode 100644
index d5a8e87..0000000
--- a/retentionengine/data/dataloader.py
+++ /dev/null
@@ -1,132 +0,0 @@
-# data/dataloader.py
-
-import torch
-from torch.utils.data import Dataset, DataLoader
-from datasets import load_dataset
-from collections import defaultdict
-
-
-def _extract_conversations(dataset):
-    """Extracts conversation threads from the oasst1 dataset."""
-    message_map = {msg['message_id']: msg for msg in dataset}
-
-    threads = defaultdict(list)
-    for msg in dataset:
-        threads[msg['parent_id']].append(msg)
-
-    conversations = []
-    for root in threads.get(None, []):
-        # Build a thread from the root message
-        thread = []
-        queue = [root]
-        visited = set()
-
-        while queue:
-            message = queue.pop(0)
-            if message['message_id'] in visited:
-                continue
-            visited.add(message['message_id'])
-
-            thread.append({
-                "role": message['role'],
-                "text": message['text'],
-                "rank": message['rank']
-            })
-
-            children = sorted(threads.get(message['message_id'], []), key=lambda x: x['rank'])
-            queue.extend(children)
-
-        # Only include threads with at least one assistant response
-        if any(msg['role'] == 'assistant' for msg in thread):
-            conversations.append(thread)
-
-    return conversations
-
-
-class DialogueDataset(Dataset):
-    """Custom Dataset for formatting and tokenizing dialogue conversations."""
-
-    def __init__(self, conversations, tokenizer, max_length=1024):
-        self.tokenizer = tokenizer
-        self.conversations = conversations
-        self.max_length = max_length
-        self.role_tokens = {"prompter": "<|prompter|>", "assistant": "<|assistant|>"}
-
-        # Add special role tokens to the tokenizer if they don't exist
-        self.tokenizer.add_special_tokens(
-            {"additional_special_tokens": list(self.role_tokens.values())}
-        )
-
-    def __len__(self):
-        return len(self.conversations)
-
-    def __getitem__(self, idx):
-        conversation = self.conversations[idx]
-
-        input_ids = []
-        labels = []
-
-        # Start with the beginning-of-sequence token
-        input_ids.append(self.tokenizer.bos_token_id)
-        labels.append(-100)  # Do not compute loss on the BOS token
-
-        for utterance in conversation:
-            role_token = self.role_tokens[utterance['role']]
-            message_tokens = self.tokenizer(
-                f"{role_token}{utterance['text']}{self.tokenizer.eos_token}",
-                add_special_tokens=False
-            ).input_ids
-
-            input_ids.extend(message_tokens)
-
-            if utterance['role'] == 'prompter':
-                # Mask out prompter's messages and role token in labels
-                labels.extend([-100] * len(message_tokens))
-            else:  # assistant
-                # Only compute loss on the assistant's messages
-                labels.extend(message_tokens)
-
-        # Truncate to max_length
-        input_ids = input_ids[:self.max_length]
-        labels = labels[:self.max_length]
-
-        # Pad to max_length
-        padding_length = self.max_length - len(input_ids)
-        input_ids.extend([self.tokenizer.pad_token_id] * padding_length)
-        labels.extend([-100] * padding_length)
-
-        return {
-            "input_ids": torch.tensor(input_ids, dtype=torch.long),
-            "labels": torch.tensor(labels, dtype=torch.long)
-        }
-
-
-def get_oasst_dataloader(tokenizer, batch_size, split='train', lang=None):
-    """
-    Creates a DataLoader for the OpenAssistant/oasst1 dataset.
-
-    Args:
-        tokenizer: The tokenizer to use.
-        batch_size: The batch size for the DataLoader.
-        split: The dataset split to use ('train' or 'validation').
-        lang: The language code to filter by (e.g., 'en', 'ko').
-              If None, all languages are used.
-    """
-    dataset = load_dataset("OpenAssistant/oasst1", split=split)
-
-    if lang:
-        dataset = dataset.filter(lambda example: example['lang'] == lang)
-
-    # Filter for high-quality, ready-to-use data
-    dataset = dataset.filter(lambda x: x["rank"] is not None and x["rank"] == 0)
-
-    conversations = _extract_conversations(dataset)
-    dialogue_dataset = DialogueDataset(conversations, tokenizer)
-
-    dataloader = DataLoader(
-        dialogue_dataset,
-        batch_size=batch_size,
-        shuffle=(split == 'train')
-    )
-
-    return dataloader
\ No newline at end of file
diff --git a/retentionengine/datasets/dataloader.py b/retentionengine/datasets/dataloader.py
new file mode 100644
index 0000000..9f18f4d
--- /dev/null
+++ b/retentionengine/datasets/dataloader.py
@@ -0,0 +1,66 @@
+import torch
+from torch.utils.data import Dataset, DataLoader
+from datasets import load_dataset
+
+
+class LongTextDataset(Dataset):
+    """Custom Dataset for formatting and tokenizing long text documents."""
+
+    def __init__(self, documents, tokenizer, max_length=8192):  # max_length tuned for the Gemma model
+        self.tokenizer = tokenizer
+        self.documents = documents
+        self.max_length = max_length
+
+    def __len__(self):
+        return len(self.documents)
+
+    def __getitem__(self, idx):
+        # Fetch one text document from the PG19 dataset
+        text = self.documents[idx]['text']
+
+        # Tokenize the text
+        tokenized_output = self.tokenizer(
+            text,
+            add_special_tokens=True,  # add the BOS and EOS tokens
+            truncation=True,          # truncate to max_length
+            max_length=self.max_length,
+            padding='max_length'      # pad to max_length
+        )
+
+        input_ids = tokenized_output['input_ids']
+
+        # Create labels by copying input_ids for language-model training
+        # Padding tokens are generally excluded from the loss computation
+        labels = torch.tensor(input_ids)
+        labels[labels == self.tokenizer.pad_token_id] = -100  # ignore padded positions in the loss
+
+        return {
+            "input_ids": torch.tensor(input_ids, dtype=torch.long),
+            "labels": labels.to(torch.long)
+        }
+
+
+def get_pg19_dataloader(tokenizer, batch_size, split='train', max_length=8192):
+    """
+    Creates a DataLoader for the PG19 dataset.
+
+    Args:
+        tokenizer: The tokenizer to use.
+        batch_size: The batch size for the DataLoader.
+        split: The dataset split to use ('train', 'validation', or 'test').
+        max_length: The maximum tokenized sequence length per document.
+    """
+    # Load the PG19 dataset
+    dataset = load_dataset("pg19", split=split)
+
+    # Create the Dataset object
+    longtext_dataset = LongTextDataset(dataset, tokenizer, max_length)
+
+    # Create the DataLoader
+    dataloader = DataLoader(
+        longtext_dataset,
+        batch_size=batch_size,
+        shuffle=(split == 'train')
+    )
+
+    return dataloader
\ No newline at end of file
diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py
index 820c3b4..625a736 100644
--- a/retentionengine/utils/adapter.py
+++ b/retentionengine/utils/adapter.py
@@ -1,57 +1,89 @@
 import torch
 from torch import nn
 from torch.utils.data import DataLoader
+import torch.nn.functional as F
+from tqdm.auto import tqdm
 
 
-class Trainer:
-    def __init__(self, model: nn.Module, optimizer: torch.optim.Optimizer,
-                 loss_fn: nn.Module = nn.CrossEntropyLoss(), device: str = "cuda"):
-        self.model = model
+class Adapter:
+    def __init__(self,
+                 adapt_model: nn.Module,
+                 base_model: nn.Module,
+                 optimizer: torch.optim.Optimizer,
+                 hard_loss_fn: nn.Module = nn.CrossEntropyLoss(),  # loss against the ground-truth labels
+                 distillation_loss_fn: nn.Module = nn.KLDivLoss(reduction='batchmean'),  # distillation loss against the base model
+                 temperature: float = 2.0,
+                 alpha: float = 0.5,  # weighting between the hard and soft losses
+                 device: str = "cuda"):
+
+        self.adapt_model = adapt_model.to(device)
+        self.base_model = base_model.to(device)
         self.optimizer = optimizer
-        self.loss_fn = loss_fn
+        self.hard_loss_fn = hard_loss_fn
+        self.distillation_loss_fn = distillation_loss_fn
+        self.temperature = temperature
+        self.alpha = alpha
         self.device = device
-        self.model.to(self.device)
+
+        self.base_model.eval()
+        for param in self.base_model.parameters():
+            param.requires_grad = False  # do not train the base model's parameters
 
     def train_step(self, batch: tuple) -> float:
-        self.model.train()
+        self.adapt_model.train()
         inputs, targets = batch
 
         inputs = inputs[:, :-1].to(self.device)
         targets = targets[:, 1:].to(self.device)
 
-        self.optimizer.zero_grad()
-        outputs = self.model(inputs)
+        with torch.no_grad():
+            base_outputs = self.base_model(inputs)
+
+        adapt_outputs = self.adapt_model(inputs)
+
+        loss_hard = self.hard_loss_fn(
+            adapt_outputs.reshape(-1, adapt_outputs.size(-1)),
+            targets.reshape(-1)
+        )
 
-        loss = self.loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
+        loss_soft = self.distillation_loss_fn(
+            F.log_softmax(adapt_outputs / self.temperature, dim=-1).reshape(-1, adapt_outputs.size(-1)),
+            F.softmax(base_outputs / self.temperature, dim=-1).reshape(-1, base_outputs.size(-1))
+        )
+        loss_soft = loss_soft * (self.temperature ** 2)
 
-        loss.backward()
+        total_loss = self.alpha * loss_hard + (1 - self.alpha) * loss_soft
+
+        self.optimizer.zero_grad()
+        total_loss.backward()
         self.optimizer.step()
 
-        return loss.item()
+        return total_loss.item()
 
     def eval_step(self, batch: tuple) -> float:
-        self.model.eval()
+        self.adapt_model.eval()
         with torch.no_grad():
             inputs, targets = batch
-
             inputs = inputs[:, :-1].to(self.device)
             targets = targets[:, 1:].to(self.device)
 
-            outputs = self.model(inputs)
-            loss = self.loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
+            outputs = self.adapt_model(inputs)
+            loss = self.hard_loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1))
 
         return loss.item()
 
     def train_epoch(self, dataloader: DataLoader) -> float:
         total_loss = 0
-        for batch in dataloader:
+
+        for batch in tqdm(dataloader, desc="Training"):
             loss = self.train_step(batch)
             total_loss += loss
         return total_loss / len(dataloader)
 
     def eval_epoch(self, dataloader: DataLoader) -> float:
         total_loss = 0
-        for batch in dataloader:
+
+        for batch in tqdm(dataloader, desc="Evaluating"):
             loss = self.eval_step(batch)
             total_loss += loss
         return total_loss / len(dataloader)
\ No newline at end of file

From b9812de38bcfb39d88745de59afc16652e9cef33 Mon Sep 17 00:00:00 2001
From: 3GID
Date: Tue, 8 Jul 2025 16:40:22 +0900
Subject: [PATCH 04/13] Init: Init Change

---
 retentionengine/data/__init__.py     | 1 -
 retentionengine/datasets/__init__.py | 1 +
 retentionengine/utils/__init__.py    | 2 +-
 3 files changed, 2 insertions(+), 2 deletions(-)
 delete mode 100644 retentionengine/data/__init__.py
 create mode 100644 retentionengine/datasets/__init__.py

diff --git a/retentionengine/data/__init__.py b/retentionengine/data/__init__.py
deleted file mode 100644
index c599c46..0000000
--- a/retentionengine/data/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .dataloader import get_oasst_dataloader
\ No newline at end of file
diff --git a/retentionengine/datasets/__init__.py b/retentionengine/datasets/__init__.py
new file mode 100644
index 0000000..15a5eb2
--- /dev/null
+++ b/retentionengine/datasets/__init__.py
@@ -0,0 +1 @@
+from .dataloader import get_pg19_dataloader
\ No newline at end of file
diff --git a/retentionengine/utils/__init__.py b/retentionengine/utils/__init__.py
index 10a4fbb..a7dde9b 100644
--- a/retentionengine/utils/__init__.py
+++ b/retentionengine/utils/__init__.py
@@ -1 +1 @@
-from .adapter import Trainer
+from .adapter import Adapter

From dcd8582b9bbc51e1dd83c12029317ee8b8409d3f Mon Sep 17 00:00:00 2001
From: njhvrta
Date: Tue, 22 Jul 2025 19:29:11 +0900
Subject: [PATCH 05/13] Feat: Optimize knowledge distillation training
 pipeline

- Fix redundant forward passes in epoch methods
- Unify data processing to eliminate code duplication

Breaking change: Step methods now return (loss, predictions, targets)

---
 retentionengine/utils/adapter.py | 170 +++++++++++++++++++++++--------
 1 file changed, 129 insertions(+), 41 deletions(-)

diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py
index 625a736..7da0a1c 100644
--- a/retentionengine/utils/adapter.py
+++ b/retentionengine/utils/adapter.py
@@ -1,89 +1,177 @@
 import torch
-from torch import nn
+from torch import nn, optim
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
-from tqdm.auto import tqdm
+from tqdm.auto import tqdm
 
 
 class Adapter:
-    def __init__(self,
-                 adapt_model: nn.Module,
-                 base_model: nn.Module,
-                 optimizer: torch.optim.Optimizer,
-                 hard_loss_fn: nn.Module = nn.CrossEntropyLoss(),  # loss against the ground-truth labels
-                 distillation_loss_fn: nn.Module = nn.KLDivLoss(reduction='batchmean'),  # distillation loss against the base model
-                 temperature: float = 2.0,
-                 alpha: float = 0.5,  # weighting between the hard and soft losses
-                 device: str = "cuda"):
-
+    def __init__(
+        self,
+        adapt_model: nn.Module,
+        base_model: nn.Module,
+        optimizer: optim.Optimizer,
+        distillation_loss_fn: nn.Module = nn.KLDivLoss(reduction='batchmean'),
+        # distillation loss with base model
+        temperature: float = 1.0,
+        device: str = "cuda"
+    ):
         self.adapt_model = adapt_model.to(device)
         self.base_model = base_model.to(device)
         self.optimizer = optimizer
-        self.hard_loss_fn = hard_loss_fn
         self.distillation_loss_fn = distillation_loss_fn
         self.temperature = temperature
-        self.alpha = alpha
         self.device = device
 
+        # Freeze base model parameters
         self.base_model.eval()
         for param in self.base_model.parameters():
-            param.requires_grad = False  # do not train the base model's parameters
+            param.requires_grad = False
+
+    def train_step(self, batch: tuple) -> tuple[float, torch.Tensor, torch.Tensor]:
+        """
+        Perform 
a single training step on a batch. - def train_step(self, batch: tuple) -> float: + Returns: + loss: batch loss value + predictions: model predictions + targets: actual targets (for accuracy calculation) + """ self.adapt_model.train() inputs, targets = batch + # Process inputs and targets once inputs = inputs[:, :-1].to(self.device) targets = targets[:, 1:].to(self.device) + # Get base model outputs (frozen) with torch.no_grad(): base_outputs = self.base_model(inputs) + # Get adapt model outputs adapt_outputs = self.adapt_model(inputs) - loss_hard = self.hard_loss_fn( - adapt_outputs.reshape(-1, adapt_outputs.size(-1)), - targets.reshape(-1) - ) - + # Calculate distillation loss loss_soft = self.distillation_loss_fn( F.log_softmax(adapt_outputs / self.temperature, dim=-1).reshape(-1, adapt_outputs.size(-1)), F.softmax(base_outputs / self.temperature, dim=-1).reshape(-1, base_outputs.size(-1)) ) - loss_soft = loss_soft * (self.temperature ** 2) - - total_loss = self.alpha * loss_hard + (1 - self.alpha) * loss_soft + total_loss = loss_soft * (self.temperature ** 2) + # Backward pass and optimization self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() - return total_loss.item() + # Calculate predictions (detach to save memory) + with torch.no_grad(): + predictions = adapt_outputs.argmax(dim=-1).detach() + + return total_loss.item(), predictions, targets + + def eval_step(self, batch: tuple) -> tuple[float, torch.Tensor, torch.Tensor]: + """ + Perform a single evaluation step on a batch. - def eval_step(self, batch: tuple) -> float: + Returns: + loss: batch loss value + predictions: model predictions + targets: actual targets + """ self.adapt_model.eval() + with torch.no_grad(): inputs, targets = batch inputs = inputs[:, :-1].to(self.device) targets = targets[:, 1:].to(self.device) - outputs = self.adapt_model(inputs) - loss = self.hard_loss_fn(outputs.reshape(-1, outputs.size(-1)), targets.reshape(-1)) + # Forward pass for both models + adapt_outputs = self.adapt_model(inputs) + base_outputs = self.base_model(inputs) - return loss.item() + # Calculate loss + loss = self.distillation_loss_fn( + F.log_softmax(adapt_outputs / self.temperature, dim=-1).reshape(-1, adapt_outputs.size(-1)), + F.softmax(base_outputs / self.temperature, dim=-1).reshape(-1, base_outputs.size(-1)) + ) + loss = loss * (self.temperature ** 2) - def train_epoch(self, dataloader: DataLoader) -> float: - total_loss = 0 + predictions = adapt_outputs.argmax(dim=-1) - for batch in tqdm(dataloader, desc="Training"): - loss = self.train_step(batch) - total_loss += loss - return total_loss / len(dataloader) + return loss.item(), predictions, targets - def eval_epoch(self, dataloader: DataLoader) -> float: - total_loss = 0 + def train_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int) -> float: + """ + Train for one complete epoch. 
- for batch in tqdm(dataloader, desc="Evaluating"): - loss = self.eval_step(batch) - total_loss += loss - return total_loss / len(dataloader) \ No newline at end of file + Args: + dataloader: training data loader + epoch: current epoch number + num_epochs: total number of epochs + + Returns: + average loss for the epoch + """ + total_loss = 0 + total_steps = 0 + correct = 0 + total = 0 + + with tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", unit="batch") as pbar: + for batch in pbar: + # Single step call to get all information + loss, predictions, targets = self.train_step(batch) + + total_loss += loss + total_steps += 1 + + # Calculate accuracy without redundant computation + correct += predictions.eq(targets).sum().item() + total += targets.numel() + + # Update progress bar with real-time metrics + pbar.set_postfix({ + 'epoch': epoch, + 'loss': f'{loss:.4f}', + 'acc': f'{(correct / total) * 100:.2f}%' + }) + + return total_loss / total_steps + + def eval_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int) -> float: + """ + Evaluate for one complete epoch. + + Args: + dataloader: evaluation data loader + epoch: current epoch number + num_epochs: total number of epochs + + Returns: + average loss for the epoch + """ + total_loss = 0 + total_steps = 0 + correct = 0 + total = 0 + + with tqdm(dataloader, desc=f"Eval {epoch}/{num_epochs}", unit="batch") as pbar: + for batch in pbar: + # Single step call to get all information + loss, predictions, targets = self.eval_step(batch) + + total_loss += loss + total_steps += 1 + + # Calculate accuracy without redundant computation + correct += predictions.eq(targets).sum().item() + total += targets.numel() + + # Update progress bar with real-time metrics + pbar.set_postfix({ + 'epoch': epoch, + 'loss': f'{loss:.4f}', + 'acc': f'{(correct / total) * 100:.2f}%' + }) + + return total_loss / total_steps \ No newline at end of file From f583a98e8cc0c16888f438cb62cb0adc4945b4bf Mon Sep 17 00:00:00 2001 From: njhvrta Date: Sun, 31 Aug 2025 23:25:10 +0900 Subject: [PATCH 06/13] feat: Add restoration training and update engine restoration_training.ipynb: Implement restoration training notebook for adapter model. usage.ipynb: Review and verify adapter usage logic. adapter.py: Refactor and enhance Adapter class for stability. 
---
 restoration_training.ipynb         | 416 +++++++++++++++++++++++++++++
 retentionengine/adapters/engine.py |   6 +-
 retentionengine/utils/adapter.py   | 257 ++++++++----------
 usage.ipynb                        |  88 +++---
 4 files changed, 572 insertions(+), 195 deletions(-)
 create mode 100644 restoration_training.ipynb

diff --git a/restoration_training.ipynb b/restoration_training.ipynb
new file mode 100644
index 0000000..d6879cb
--- /dev/null
+++ b/restoration_training.ipynb
@@ -0,0 +1,416 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "c775af9f",
+   "metadata": {},
+   "source": [
+    "## Check GPU Availability"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3da24503",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "!nvidia-smi"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c36b79fd",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Set CUDA Device Number\n",
+    "DEVICE_NUM = 0\n",
+    "ADDITIONAL_GPU = 3\n",
+    "\n",
+    "from os import environ\n",
+    "environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([f\"{i+DEVICE_NUM}\" for i in range(ADDITIONAL_GPU + 1)])\n",
+    "environ[\"CUDA_VISIBLE_DEVICES\"]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "6f14a538",
+   "metadata": {},
+   "source": [
+    "## Imports"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5c8bc525",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import torch\n",
+    "import torch.nn as nn\n",
+    "import random\n",
+    "\n",
+    "from transformers import AdamW, get_linear_schedule_with_warmup, Qwen3ForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM\n",
+    "\n",
+    "from torch.utils.data import DataLoader\n",
+    "from datasets import load_dataset\n",
+    "\n",
+    "from tqdm.auto import tqdm\n",
+    "import wandb\n",
+    "\n",
+    "from utils.adapter import Adapter"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "197bbe62",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "if torch.cuda.is_available():\n",
+    "    if ADDITIONAL_GPU:\n",
+    "        device = torch.device(\"cuda\")\n",
+    "    else:\n",
+    "        device = torch.device(\"cuda\")\n",
+    "else:\n",
+    "    device = torch.device(\"cpu\")\n",
+    "    DEVICE_NUM = -1\n",
+    "    \n",
+    "print(f\"INFO: Using device - {device}\" + (f\":{DEVICE_NUM}\" if ADDITIONAL_GPU else \"\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "1ca4233c",
+   "metadata": {},
+   "source": [
+    "## Wandb setting"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "5f56aa06",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "PROJECT_NAME = \"restoration_training\"\n",
+    "RUN_NAME = \"Qwen3_8B_adapter_restoration\"\n",
+    "\n",
+    "# WandB Initialization\n",
+    "wandb.init(project=PROJECT_NAME, name=RUN_NAME)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "d08a6c97",
+   "metadata": {},
+   "source": [
+    "## Define Dataset"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "625d6659",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset_id = \"\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "1b335067",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset = load_dataset(dataset_id)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "39876999",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset['train'][0]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "959bfd2a",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "dataset['train'][0].keys()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "31ab1e69",
+   "metadata": {},
+ "source": [ + "## Load Model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "130802a8", + "metadata": {}, + "outputs": [], + "source": [ + "torch_dtype = torch.bfloat16" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2387f7e6", + "metadata": {}, + "outputs": [], + "source": [ + "base_model_id = \"Qwen/Qwen3-8B\"\n", + "base_model = Qwen3ForCausalLM.from_pretrained(\n", + " base_model_id,\n", + " torch_dtype=torch_dtype,\n", + " devicep_map=\"auto\"\n", + " )\n", + "tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "20171342", + "metadata": {}, + "outputs": [], + "source": [ + "adapt_model_id = \"pretrained/Qwen/Qwen3-8B_adapter\"\n", + "adapt_model = Qwen3ForCausalLM.from_pretrained(\n", + " adapt_model_id,\n", + " torch_dtype=torch_dtype,\n", + " device_map=\"auto\"\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "88a1a9e3", + "metadata": {}, + "outputs": [], + "source": [ + "if torch.cuda.device_count() > 1:\n", + " print(f\"Using {torch.cuda.device_count()} GPUs for training.\")\n", + " base_model = nn.DataParallel(base_model)\n", + " adapt_model = nn.DataParallel(adapt_model)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "40dcb3c8", + "metadata": {}, + "outputs": [], + "source": [ + "base_model.eval()\n", + "for param in base_model.parameters():\n", + " param.requires_grad = False\n", + " \n", + "adapt_model.train()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "214f91eb", + "metadata": {}, + "outputs": [], + "source": [ + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token\n", + "\n", + "def tokenize_function(examples):\n", + " return tokenizer(examples[\"Text\"], truncation=True, padding=\"max_length\", max_length=512)\n", + "\n", + "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n", + "tokenized_datasets = tokenized_datasets.map(lambda x: {**x, 'labels': x['input_ids']})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "65a921ac", + "metadata": {}, + "outputs": [], + "source": [ + "split_dataset = tokenized_datasets[\"train\"].train_test_split(test_size=0.1, seed=42)\n", + "\n", + "train_dataset = split_dataset[\"train\"]\n", + "eval_dataset = split_dataset[\"test\"]\n", + "\n", + "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)\n", + "eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=8)" + ] + }, + { + "cell_type": "markdown", + "id": "79c9fd59", + "metadata": {}, + "source": [ + "## Training" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7c7f4cef", + "metadata": {}, + "outputs": [], + "source": [ + "# Hyperparameters\n", + "EPOCHS = 10\n", + "LEARNING_RATE = 5e-5\n", + "ALPHA = 0.5\n", + "T = 0.3\n", + "NUM_TRAINING_STEPS = EPOCHS * len(train_dataloader)\n", + "\n", + "# optimizer and Scheduler\n", + "optimizer = AdamW(adapt_model.parameters(), lr=LEARNING_RATE)\n", + "lr_scheduler = get_linear_schedule_with_warmup(\n", + " optimizer,\n", + " num_warmup_steps=0,\n", + " num_training_steps=NUM_TRAINING_STEPS\n", + ")\n", + "\n", + "# Loss Functions\n", + "distillation_loss_fn = nn.KLDivLoss(reduction=\"batchmean\")\n", + "laungage_modeling_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n", + "\n", + "# Initialize Adapter\n", + "adapter_trainer = Adapter(\n", + " adapt_model=adapt_model,\n", + 
" base_model=base_model,\n", + " optimizer=optimizer,\n", + " lr_scheduler=lr_scheduler,\n", + " distillation_loss_fn=distillation_loss_fn,\n", + " lm_loss_fn=laungage_modeling_loss_fn,\n", + " temperature=T,\n", + " alpha=ALPHA,\n", + " device=device,\n", + " use_wandb=True\n", + ")\n", + "\n", + "# Run\n", + "adapter_trainer.run_training(\n", + " train_dataloader,\n", + " num_epochs=EPOCHS,\n", + " eval_dataloader=eval_dataloader\n", + ")\n", + "\n", + "wandb.finish()" + ] + }, + { + "cell_type": "markdown", + "id": "f7f8133d", + "metadata": {}, + "source": [ + "## Testing" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "314782cb", + "metadata": {}, + "outputs": [], + "source": [ + "output_dir = \"./adapt_model_distilled\"\n", + "model_to_save = adapt_model.module if isinstance(adapt_model, nn.DataParallel) else adapt_model\n", + "model_to_save.save_pretrained(output_dir)\n", + "tokenizer.save_pretrained(output_dir)\n", + "print(f\"Adapted model saved successfully to {output_dir}\")\n", + "\n", + "base_model.eval()\n", + "adapt_model.eval()\n", + "\n", + "user_question = \"Explain the concept of knowledge distillation.\"\n", + "input_text = f\"Question: {user_question} Answer:\"\n", + "encoded_input = tokenizer(input_text, return_tensors=\"pt\").to(device)\n", + "input_ids = encoded_input.input_ids\n", + "\n", + "print(f\"Original Text: {input_text}\\n\")\n", + "print(\"=== Base Model vs Adapted Model Independent Generation Comparison ===\\n\")\n", + "\n", + "base_output_ids = input_ids.clone()\n", + "adapt_output_ids = input_ids.clone()\n", + "\n", + "max_new_tokens = 100\n", + "mismatched_tokens = 0\n", + "total_tokens = 0\n", + "\n", + "base_gen_tokens = []\n", + "adapt_gen_tokens = []\n", + "\n", + "for _ in range(max_new_tokens):\n", + " total_tokens += 1\n", + " \n", + " with torch.no_grad():\n", + " base_outputs = base_model(input_ids=base_output_ids)\n", + " base_logits = base_outputs.logits[:, -1, :]\n", + " base_predicted_id = torch.argmax(base_logits, dim=-1).unsqueeze(-1)\n", + "\n", + " adapt_outputs = adapt_model(input_ids=adapt_output_ids)\n", + " adapt_logits = adapt_outputs.logits[:, -1, :]\n", + " adapt_predicted_id = torch.argmax(adapt_logits, dim=-1).unsqueeze(-1)\n", + "\n", + " base_output_ids = torch.cat([base_output_ids, base_predicted_id], dim=-1)\n", + " adapt_output_ids = torch.cat([adapt_output_ids, adapt_predicted_id], dim=-1)\n", + "\n", + " # comparing and recording generated tokens (for metrics calculation)\n", + " if adapt_predicted_id.item() != base_predicted_id.item():\n", + " mismatched_tokens += 1\n", + " \n", + " base_token_str = tokenizer.decode(base_predicted_id.squeeze())\n", + " adapt_token_str = tokenizer.decode(adapt_predicted_id.squeeze())\n", + " base_gen_tokens.append(base_token_str)\n", + " adapt_gen_tokens.append(adapt_token_str)\n", + "\n", + " # 5. if either model generates EOS token, stop generation\n", + " if base_predicted_id.item() == tokenizer.eos_token_id or adapt_predicted_id.item() == tokenizer.eos_token_id:\n", + " print(\"INFO: EOS token generated. 
Halting generation.\")\n",
+    "        break\n",
+    "\n",
+    "print(\"\\n--- Generation Results ---\")\n",
+    "print(\"Base Model Output:\", \"\".join(base_gen_tokens))\n",
+    "print(\"Adapted Model Output:\", \"\".join(adapt_gen_tokens))\n",
+    "print(\"\\n--- Evaluation Metrics ---\")\n",
+    "token_accuracy = ((total_tokens - mismatched_tokens) / total_tokens) * 100 if total_tokens > 0 else 0\n",
+    "print(f\"Token Match Rate (Accuracy): {token_accuracy:.2f}%\")\n",
+    "print(f\"Total Generated Tokens: {total_tokens}, Mismatched Tokens: {mismatched_tokens}\")\n"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python",
+   "version": "3.11.8"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/retentionengine/adapters/engine.py b/retentionengine/adapters/engine.py
index 3baa41c..9bdf824 100644
--- a/retentionengine/adapters/engine.py
+++ b/retentionengine/adapters/engine.py
@@ -39,23 +39,20 @@ def adapt(self,
     ):
 
         device = self.module.device
-        # 2. Define and configure the model roles
         adapt_model = self.module
         base_model = self.module.model
         base_model.eval()
         for param in base_model.parameters():
             param.requires_grad = False
         for param in adapt_model.attention_module.parameters():
-            param.requires_grad = False  # Do not train the attention module
+            param.requires_grad = False  # don't train the attention module
         adapt_model.train()
 
         train_dataloader = get_pg19_dataloader(tokenizer, batch_size, 'train', max_length)
         eval_dataloader = get_pg19_dataloader(tokenizer, batch_size, 'validation', max_length)
 
-        # 4. Configure the optimizer
         optimizer = AdamW(adapt_model.parameters(), lr=learning_rate)
 
-        # 5. Configure the distillation trainer
         trainer = Adapter(
             adapt_model=adapt_model,
             base_model=base_model,
@@ -96,3 +93,4 @@ def save_pretrained(self, save_directory: str):
         """
         self.module.save_pretrained(save_directory)
         self.config.save_pretrained(save_directory)
+    
\ No newline at end of file
diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py
index 7da0a1c..482277b 100644
--- a/retentionengine/utils/adapter.py
+++ b/retentionengine/utils/adapter.py
@@ -1,9 +1,14 @@
+# adapter.py
+
 import torch
 from torch import nn, optim
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 from tqdm.auto import tqdm
+import wandb
 
+# Data type recommended for FP8 training (E4M3: higher precision, which helps training stability)
+FP8_DTYPE = torch.float8_e4m3fn
 
 class Adapter:
     def __init__(
@@ -11,167 +16,125 @@ def __init__(
         adapt_model: nn.Module,
         base_model: nn.Module,
         optimizer: optim.Optimizer,
-        distillation_loss_fn: nn.Module = nn.KLDivLoss(reduction='batchmean'),
-        # distillation loss with base model
-        temperature: float = 1.0,
-        device: str = "cuda"
+        lr_scheduler,
+        distillation_loss_fn: nn.Module,
+        lm_loss_fn: nn.Module,
+        temperature: float = 2.0,
+        alpha: float = 0.5,
+        use_wandb: bool = True
     ):
-        self.adapt_model = adapt_model.to(device)
-        self.base_model = base_model.to(device)
+        self.adapt_model = adapt_model
+        self.base_model = base_model
         self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
         self.distillation_loss_fn = distillation_loss_fn
+        self.lm_loss_fn = lm_loss_fn
         self.temperature = temperature
-        self.device = device
+        self.alpha = alpha
+        self.use_wandb = use_wandb
+        self.train_dataloader_len = 0
 
-        # Freeze base model parameters
         self.base_model.eval()
         for param in self.base_model.parameters():
             param.requires_grad = False
 
-    def train_step(self, batch: tuple) -> tuple[float, torch.Tensor, torch.Tensor]:
-        """
-        Perform a single training step on a batch.
- - Returns: - loss: batch loss value - predictions: model predictions - targets: actual targets (for accuracy calculation) - """ + def train_step(self, batch: dict) -> tuple: + primary_device = next(self.adapt_model.parameters()).device + batch = {k: v.to(primary_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} self.adapt_model.train() - inputs, targets = batch - - # Process inputs and targets once - inputs = inputs[:, :-1].to(self.device) - targets = targets[:, 1:].to(self.device) - - # Get base model outputs (frozen) - with torch.no_grad(): - base_outputs = self.base_model(inputs) - - # Get adapt model outputs - adapt_outputs = self.adapt_model(inputs) - # Calculate distillation loss - loss_soft = self.distillation_loss_fn( - F.log_softmax(adapt_outputs / self.temperature, dim=-1).reshape(-1, adapt_outputs.size(-1)), - F.softmax(base_outputs / self.temperature, dim=-1).reshape(-1, base_outputs.size(-1)) - ) - total_loss = loss_soft * (self.temperature ** 2) + with torch.autocast(device_type='cuda', dtype=FP8_DTYPE): + with torch.no_grad(): + base_outputs = self.base_model(**batch) + base_logits = base_outputs.logits + + adapt_outputs = self.adapt_model(**batch) + adapt_logits = adapt_outputs.logits + + soft_targets = F.softmax(base_logits / self.temperature, dim=-1) + soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1) + distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2) + + lm_loss = self.lm_loss_fn(adapt_logits.view(-1, adapt_logits.size(-1)), batch['labels'].view(-1)) + + total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss - # Backward pass and optimization self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() - - # Calculate predictions (detach to save memory) - with torch.no_grad(): - predictions = adapt_outputs.argmax(dim=-1).detach() - - return total_loss.item(), predictions, targets - - def eval_step(self, batch: tuple) -> tuple[float, torch.Tensor, torch.Tensor]: - """ - Perform a single evaluation step on a batch. 
- - Returns: - loss: batch loss value - predictions: model predictions - targets: actual targets - """ + if self.lr_scheduler: + self.lr_scheduler.step() + + return total_loss.item(), distillation_loss.item(), lm_loss.item() + + def train_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int): + total_epoch_loss = 0 + progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs} [T]", unit="batch") + + for step, batch in enumerate(progress_bar): + loss, dist_loss, lm_loss = self.train_step(batch) + total_epoch_loss += loss + + progress_bar.set_postfix({ + 'loss': f'{loss:.4f}', + 'dist_loss': f'{dist_loss:.4f}', + 'lm_loss': f'{lm_loss:.4f}' + }) + + if self.use_wandb: + global_step = (epoch - 1) * len(dataloader) + step + wandb.log({ + "train/step_loss": loss, + "train/distillation_loss": dist_loss, + "train/lm_loss": lm_loss, + "train/learning_rate": self.lr_scheduler.get_last_lr()[0] + }, step=global_step) + + avg_epoch_loss = total_epoch_loss / len(dataloader) + print(f"Epoch [{epoch}/{num_epochs}] Train Avg Loss: {avg_epoch_loss:.4f}") + + if self.use_wandb: + global_step = epoch * len(dataloader) + wandb.log({"train/epoch_loss": avg_epoch_loss}, step=global_step) + + def evaluate_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int): + print("INFO: Starting evaluation...") self.adapt_model.eval() - + total_eval_loss = 0 + primary_device = next(self.adapt_model.parameters()).device + + progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs} [E]", unit="batch") + with torch.no_grad(): - inputs, targets = batch - inputs = inputs[:, :-1].to(self.device) - targets = targets[:, 1:].to(self.device) - - # Forward pass for both models - adapt_outputs = self.adapt_model(inputs) - base_outputs = self.base_model(inputs) - - # Calculate loss - loss = self.distillation_loss_fn( - F.log_softmax(adapt_outputs / self.temperature, dim=-1).reshape(-1, adapt_outputs.size(-1)), - F.softmax(base_outputs / self.temperature, dim=-1).reshape(-1, base_outputs.size(-1)) - ) - loss = loss * (self.temperature ** 2) - - predictions = adapt_outputs.argmax(dim=-1) - - return loss.item(), predictions, targets - - def train_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int) -> float: - """ - Train for one complete epoch. - - Args: - dataloader: training data loader - epoch: current epoch number - num_epochs: total number of epochs - - Returns: - average loss for the epoch - """ - total_loss = 0 - total_steps = 0 - correct = 0 - total = 0 - - with tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs}", unit="batch") as pbar: - for batch in pbar: - # Single step call to get all information - loss, predictions, targets = self.train_step(batch) - - total_loss += loss - total_steps += 1 - - # Calculate accuracy without redundant computation - correct += predictions.eq(targets).sum().item() - total += targets.numel() - - # Update progress bar with real-time metrics - pbar.set_postfix({ - 'epoch': epoch, - 'loss': f'{loss:.4f}', - 'acc': f'{(correct / total) * 100:.2f}%' - }) - - return total_loss / total_steps - - def eval_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int) -> float: - """ - Evaluate for one complete epoch. 
- - Args: - dataloader: evaluation data loader - epoch: current epoch number - num_epochs: total number of epochs - - Returns: - average loss for the epoch - """ - total_loss = 0 - total_steps = 0 - correct = 0 - total = 0 - - with tqdm(dataloader, desc=f"Eval {epoch}/{num_epochs}", unit="batch") as pbar: - for batch in pbar: - # Single step call to get all information - loss, predictions, targets = self.eval_step(batch) - - total_loss += loss - total_steps += 1 - - # Calculate accuracy without redundant computation - correct += predictions.eq(targets).sum().item() - total += targets.numel() - - # Update progress bar with real-time metrics - pbar.set_postfix({ - 'epoch': epoch, - 'loss': f'{loss:.4f}', - 'acc': f'{(correct / total) * 100:.2f}%' - }) - - return total_loss / total_steps \ No newline at end of file + for batch in progress_bar: + batch = {k: v.to(primary_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} + + with torch.autocast(device_type='cuda', dtype=FP8_DTYPE): + base_outputs = self.base_model(**batch) + base_logits = base_outputs.logits + + adapt_outputs = self.adapt_model(**batch) + adapt_logits = adapt_outputs.logits + + soft_targets = F.softmax(base_logits / self.temperature, dim=-1) + soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1) + distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2) + + lm_loss = self.lm_loss_fn(adapt_logits.view(-1, adapt_logits.size(-1)), batch['labels'].view(-1)) + total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss + + total_eval_loss += total_loss.item() + + avg_eval_loss = total_eval_loss / len(dataloader) + print(f"Epoch [{epoch}/{num_epochs}] Eval Avg Loss: {avg_eval_loss:.4f}") + + if self.use_wandb: + global_step = epoch * self.train_dataloader_len + wandb.log({"eval/epoch_loss": avg_eval_loss}, step=global_step) + + def run_training(self, train_dataloader: DataLoader, num_epochs: int, eval_dataloader: DataLoader = None): + self.train_dataloader_len = len(train_dataloader) + for epoch in range(1, num_epochs + 1): + self.train_epoch(train_dataloader, epoch, num_epochs) + if eval_dataloader: + self.evaluate_epoch(eval_dataloader, epoch, num_epochs) \ No newline at end of file diff --git a/usage.ipynb b/usage.ipynb index 8ae046a..811af45 100644 --- a/usage.ipynb +++ b/usage.ipynb @@ -2,86 +2,82 @@ "cells": [ { "cell_type": "code", + "execution_count": 1, "id": "initial_id", "metadata": { - "collapsed": true, "ExecuteTime": { "end_time": "2025-05-05T23:45:48.721890Z", "start_time": "2025-05-05T23:45:34.721988Z" - } + }, + "collapsed": true }, + "outputs": [], "source": [ "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "from retentionengine import RetentionEngine\n", "from thelethe.titans import CronosConfig" - ], - "outputs": [], - "execution_count": 1 + ] }, { + "cell_type": "code", + "execution_count": null, + "id": "f20a5a772cb31cca", "metadata": { "ExecuteTime": { "end_time": "2025-05-05T23:49:38.482121Z", "start_time": "2025-05-05T23:45:48.934067Z" } }, - "cell_type": "code", - "source": [ - "# Load Gemma3 4b model and tokenizer\n", - "basemodel_name = \"google/gemma-3-4b-it\"\n", - "tokenizer = AutoTokenizer.from_pretrained(basemodel_name)\n", - "basemodel = AutoModelForCausalLM.from_pretrained(basemodel_name)" - ], - "id": "f20a5a772cb31cca", "outputs": [ { "data": { - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00 Date: Mon, 1 Sep 2025 02:05:46 +0900 Subject: [PATCH 07/13] feat(data): add dataset 
generation script and update training logic

- Add `dataset_generator.py` to create a `.jsonl` dataset for memory training.
- Update `restoration_training.ipynb` to load and use the generated dataset,
  modifying the data loader and tokenizer.
---
 restoration_training.ipynb                    |  80 +++---
 retentionengine/datasets/dataset_generator.py | 245 ++++++++++++++++++
 2 files changed, 280 insertions(+), 45 deletions(-)
 create mode 100644 retentionengine/datasets/dataset_generator.py

diff --git a/restoration_training.ipynb b/restoration_training.ipynb
index d6879cb..e0777c5 100644
--- a/restoration_training.ipynb
+++ b/restoration_training.ipynb
@@ -120,7 +120,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataset_id = \"\""
+    "dataset_path = \"retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl\""
    ]
   },
   {
@@ -130,7 +130,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "dataset = load_dataset(dataset_id)"
+    "dataset = load_dataset(\"json\", data_files=dataset_path)"
    ]
   },
   {
@@ -182,7 +182,7 @@
    "base_model = Qwen3ForCausalLM.from_pretrained(\n",
    "    base_model_id,\n",
    "    torch_dtype=torch_dtype,\n",
-    "    devicep_map=\"auto\"\n",
+    "    device_map=\"auto\"\n",
    "    )\n",
    "tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)"
    ]
@@ -202,19 +202,6 @@
     "    )"
    ]
   },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "id": "88a1a9e3",
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "if torch.cuda.device_count() > 1:\n",
-    "    print(f\"Using {torch.cuda.device_count()} GPUs for training.\")\n",
-    "    base_model = nn.DataParallel(base_model)\n",
-    "    adapt_model = nn.DataParallel(adapt_model)"
-   ]
-  },
   {
   "cell_type": "code",
   "execution_count": null,
@@ -240,10 +227,17 @@
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "\n",
    "def tokenize_function(examples):\n",
-    "    return tokenizer(examples[\"Text\"], truncation=True, padding=\"max_length\", max_length=512)\n",
+    "    formatted_text = [\n",
+    "        f\"질문: {inp}\\n\\n답변: {out}{tokenizer.eos_token}\"\n",
+    "        for inp, out in zip(examples[\"input\"], examples[\"target_output\"])\n",
+    "    ]\n",
    "\n",
+    "    tokenized = tokenizer(formatted_text, truncation=True, padding=\"max_length\", max_length=512)\n",
+    "    tokenized[\"labels\"] = tokenized[\"input_ids\"].copy()\n",
+    "    return tokenized\n",
+    "\n",
+    "original_columns = dataset['train'].column_names\n",
-    "tokenized_datasets = dataset.map(tokenize_function, batched=True)\n",
-    "tokenized_datasets = tokenized_datasets.map(lambda x: {**x, 'labels': x['input_ids']})"
+    "tokenized_datasets = dataset.map(tokenize_function, batched=True, remove_columns=original_columns)"
    ]
@@ -294,7 +288,7 @@
    "\n",
    "# Loss Functions\n",
    "distillation_loss_fn = nn.KLDivLoss(reduction=\"batchmean\")\n",
-    "laungage_modeling_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n",
+    "language_modeling_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n",
    "\n",
    "# Initialize Adapter\n",
    "adapter_trainer = Adapter(\n",
@@ -335,21 +329,20 @@
    "outputs": [],
    "source": [
    "output_dir = \"./adapt_model_distilled\"\n",
-    "model_to_save = adapt_model.module if isinstance(adapt_model, nn.DataParallel) else adapt_model\n",
-    "model_to_save.save_pretrained(output_dir)\n",
+    "adapt_model.save_pretrained(output_dir)\n",
    "tokenizer.save_pretrained(output_dir)\n",
    "print(f\"Adapted model saved successfully to {output_dir}\")\n",
    "\n",
    "base_model.eval()\n",
    "adapt_model.eval()\n",
    "\n",
-    "user_question = \"Explain the concept of knowledge distillation.\"\n",
-    "input_text = f\"Question: {user_question} 
Answer:\"\n", - "encoded_input = tokenizer(input_text, return_tensors=\"pt\").to(device)\n", - "input_ids = encoded_input.input_ids\n", + "user_question = \"Explain the concept of knowledge distillation in machine learning.\"\n", + "prompt = f\"question: {user_question}\\n\\nanswer:\"\n", + "\n", + "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(next(adapt_model.parameters()).device)\n", "\n", - "print(f\"Original Text: {input_text}\\n\")\n", - "print(\"=== Base Model vs Adapted Model Independent Generation Comparison ===\\n\")\n", + "print(f\"Prompt: {prompt}\\n\")\n", + "print(\"=== Token-by-Token Generation Comparison ===\\n\")\n", "\n", "base_output_ids = input_ids.clone()\n", "adapt_output_ids = input_ids.clone()\n", @@ -359,10 +352,7 @@ "mismatched_tokens = 0\n", "total_tokens = 0\n", "\n", - "base_gen_tokens = []\n", - "adapt_gen_tokens = []\n", - "\n", - "for _ in range(max_new_tokens):\n", + "for step in range(max_new_tokens):\n", " total_tokens += 1\n", " \n", " with torch.no_grad():\n", @@ -377,27 +367,27 @@ " base_output_ids = torch.cat([base_output_ids, base_predicted_id], dim=-1)\n", " adapt_output_ids = torch.cat([adapt_output_ids, adapt_predicted_id], dim=-1)\n", "\n", - " # comparing and recording generated tokens (for metrics calculation)\n", " if adapt_predicted_id.item() != base_predicted_id.item():\n", " mismatched_tokens += 1\n", - " \n", - " base_token_str = tokenizer.decode(base_predicted_id.squeeze())\n", - " adapt_token_str = tokenizer.decode(adapt_predicted_id.squeeze())\n", - " base_gen_tokens.append(base_token_str)\n", - " adapt_gen_tokens.append(adapt_token_str)\n", + " base_token_str = tokenizer.decode(base_predicted_id.squeeze())\n", + " adapt_token_str = tokenizer.decode(adapt_predicted_id.squeeze())\n", + " print(f\"Step {step+1}: Mismatch! Base='{base_token_str}' vs Adapt='{adapt_token_str}'\")\n", "\n", - " # 5. if either model generates EOS token, stop generation\n", " if base_predicted_id.item() == tokenizer.eos_token_id or adapt_predicted_id.item() == tokenizer.eos_token_id:\n", - " print(\"INFO: EOS token generated. Halting generation.\")\n", + " print(\"\\nINFO: EOS token generated. 
Halting generation.\")\n", " break\n", "\n", - "print(\"\\n--- Generation Results ---\")\n", - "print(\"Base Model Output:\", \"\".join(base_gen_tokens))\n", - "print(\"Adapted Model Output:\", \"\".join(adapt_gen_tokens))\n", - "print(\"\\n--- Evaluation Metrics ---\")\n", + "base_text = tokenizer.decode(base_output_ids[0], skip_special_tokens=True)\n", + "adapt_text = tokenizer.decode(adapt_output_ids[0], skip_special_tokens=True)\n", + "\n", + "print(\"\\n--- Full Generation Results ---\")\n", + "print(\"Base Model Output:\\n\", base_text)\n", + "print(\"\\nAdapted Model Output:\\n\", adapt_text)\n", + "\n", "token_accuracy = ((total_tokens - mismatched_tokens) / total_tokens) * 100 if total_tokens > 0 else 0\n", + "print(\"\\n--- Evaluation Metrics ---\")\n", "print(f\"Token Match Rate (Accuracy): {token_accuracy:.2f}%\")\n", - "print(f\"Total Generated Tokens: {total_tokens}, Mismatched Tokens: {mismatched_tokens}\")\n" + "print(f\"Total Generated Tokens: {total_tokens}, Mismatched Tokens: {mismatched_tokens}\")" ] } ], diff --git a/retentionengine/datasets/dataset_generator.py b/retentionengine/datasets/dataset_generator.py new file mode 100644 index 0000000..dddc181 --- /dev/null +++ b/retentionengine/datasets/dataset_generator.py @@ -0,0 +1,245 @@ +# -*- coding: utf-8 -*- +""" +Qwen3 기반 메모리 모듈 학습용 데이터셋 생성기 (안정화 버전) +------------------------------------------------ +생성된 데이터: RETENTIONENGINE/retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl +→ 2/3 한국어, 1/3 영어 샘플 포함 +""" + +import os +import json +import re +import time +import random +import uuid +from typing import List, Dict +from transformers import AutoTokenizer, AutoModelForCausalLM +import torch +import json_repair + + +def make_seed_prompt(batch_size: int = 100) -> str: + """ + 2/3 확률로 한국어, 1/3 확률로 영어 프롬프트 반환 + """ + if random.random() < 0.33: # 약 1/3 확률로 영어 + return make_seed_prompt_english(batch_size) + else: + return make_seed_prompt_korean(batch_size) + + +def make_seed_prompt_korean(batch_size: int = 5) -> str: + """ + 한국어 프롬프트 + """ + domains = [ + ("math", 20), ("cs", 20), ("physics", 15), ("literature", 15), + ("philosophy", 10), ("history", 10), ("biology", 5), ("economics", 5) + ] + domain_str = ", ".join([f"{d}({p}%)" for d, p in domains]) + + prompt = f""" +다음 JSON 스키마를 따르는 데이터를 정확히 {batch_size}개 포함하는 단일 JSON 배열(array)을 생성해줘. + +{{ + "id": "sample_001", + "domain": "예: math, physics, literature, cs, philosophy 등", + "input": "장문 또는 단문 입력. 다양한 도메인 혼합 가능", + "target_output": "Qwen3이 생성할 것으로 예상되는 출력", + "reasoning_trace": [ + {{ + "step": 1, + "thought": "내부 사고 과정", + "memory_read": ["참조한 메모리 블록 이름"], + "inference": "추론 내용" + }} + ], + "memory_access_pattern": {{ + "reads": ["메모리_블록_A", "메모리_블록_B"], + "writes": ["결과_요약_01"] + }}, + "metadata": {{ + "context_length": 입력 토큰 수, + "reasoning_depth": "trivial / intermediate / advanced" + }} +}} + +### 생성 규칙: +- 30%: 짧은 질문 (100~500 토큰), 출력 분포 정렬용 +- 50%: 중간 길이 CoT (1K~5K 토큰), 메모리 트레이스 포함 +- 20%: 장문 크로스 도메인 (8K~30K 토큰), multi-hop 추론 +- 도메인 분포: {domain_str} +- reasoning_trace는 2~5단계 포함 +- memory_read/write는 실제 메모리 블록 이름 사용 (예: physics_laws_v1, literature_themes_02) +- 오직 하나의 완성된 JSON 배열만 출력해줘. 설명, 마크다운, ```json``` 태그는 절대 포함하지 마. +- 예시: `[ {{...}}, {{...}}, ... 
]` +""" + return prompt.strip() + + +def make_seed_prompt_english(batch_size: int = 100) -> str: + """ + 영어 프롬프트 (1/3 확률로 호출) + """ + domains = [ + ("math", 20), ("cs", 20), ("physics", 15), ("literature", 15), + ("philosophy", 10), ("history", 10), ("biology", 5), ("economics", 5) + ] + domain_str = ", ".join([f"{d}({p}%)" for d, p in domains]) + + prompt = f""" +Generate a single JSON array containing exactly {batch_size} samples that follow the JSON schema below. + +{{ + "id": "sample_001", + "domain": "e.g., math, physics, literature, cs, philosophy", + "input": "Long or short input. Mix of domains allowed", + "target_output": "Output expected from Qwen3", + "reasoning_trace": [ + {{ + "step": 1, + "thought": "Internal reasoning process", + "memory_read": ["memory_block_name"], + "inference": "Inference content" + }} + ], + "memory_access_pattern": {{ + "reads": ["memory_block_A", "memory_block_B"], + "writes": ["summary_01"] + }}, + "metadata": {{ + "context_length": "number of input tokens", + "reasoning_depth": "trivial / intermediate / advanced" + }} +}} + +### Rules: +- 30%: Short questions (100-500 tokens), for output distribution alignment +- 50%: Medium-length CoT (1K-5K tokens), with memory trace +- 20%: Long cross-domain (8K-30K tokens), multi-hop reasoning +- Domain distribution: {domain_str} +- reasoning_trace must have 2-5 steps +- Use realistic memory block names (e.g., physics_laws_v1, literature_themes_02) +- Output only a single, complete JSON array. Do not include any explanation, markdown, or ```json``` tags. +- Example: `[ {{...}}, {{...}}, ... ]` +""" + return prompt.strip() + + +# [수정] JSON 배열을 안정적으로 파싱하는 함수 +def parse_json_array(text: str) -> List[Dict]: + """ + 생성된 텍스트에서 유효한 JSON 배열을 추출 + """ + match = re.search(r'\[.*\]', text, re.DOTALL) + if not match: + print("⚠️ 경고: 응답에서 JSON 배열을 찾지 못했습니다.") + return [] + + json_array_str = match.group(0) + try: + data = json.loads(json_array_str) + except json.JSONDecodeError: + data = json_repair.loads(json_array_str) + + if isinstance(data, list): + valid_items = [ + item for item in data + if isinstance(item, dict) and "input" in item and "target_output" in item + ] + return valid_items + else: + print("⚠️ 경고: 파싱된 데이터가 리스트가 아닙니다.") + return [] + + +def generate_dataset( + target_count: int = 10000, + batch_size: int = 100 +): + """ + Qwen3을 사용해 데이터셋을 생성하고, 지정된 경로에 저장 + """ + output_dir = os.path.join("retentionengine", "datasets") + os.makedirs(output_dir, exist_ok=True) + output_file = os.path.join(output_dir, "qwen3_memory_aligned_dataset.jsonl") + + print("🚀 모델을 로딩 중입니다... (Qwen/Qwen2-7B-Instruct)") + model_name = "Qwen/Qwen2-7B-Instruct" + + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) + + print(f"✅ 모델 로딩 완료. 
저장 위치: {os.path.abspath(output_file)}") + + generated_samples = [] + + with open(output_file, "w", encoding="utf-8") as f_out: + while len(generated_samples) < target_count: + remaining = target_count - len(generated_samples) + current_batch_size = min(batch_size, remaining) + print(f"🔄 생성 중: {len(generated_samples)} / {target_count} | 이번 배치: {current_batch_size}") + + prompt = make_seed_prompt(current_batch_size) + + messages = [ + {"role": "system", "content": "You are a helpful assistant that generates high-quality synthetic data."}, + {"role": "user", "content": prompt} + ] + text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + inputs = tokenizer(text, return_tensors="pt").to(model.device) + + # 안전한 max_new_tokens 계산 + input_len = inputs["input_ids"].shape[1] + max_new = min(16000, model.config.max_position_embeddings - input_len) + + with torch.no_grad(): + outputs = model.generate( + **inputs, + max_new_tokens=max_new, + temperature=0.8, + do_sample=True, + top_p=0.9, + eos_token_id=tokenizer.eos_token_id, + pad_token_id=tokenizer.eos_token_id + ) + + input_token_length = inputs["input_ids"].shape[1] + generated_tokens = outputs[0][input_token_length:] + content = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() + + new_samples = parse_json_array(content) + print(f"✅ 이번 배치에서 {len(new_samples)}개의 유효한 샘플 추출") + + if not new_samples: + print("⚠️ 이번 배치에서 유효한 샘플을 생성하지 못했습니다. 재시도합니다.") + time.sleep(5) + continue + + for sample in new_samples: + if len(generated_samples) >= target_count: + break + sample["id"] = str(uuid.uuid4()) # uuid 기반 ID + f_out.write(json.dumps(sample, ensure_ascii=False) + "\n") + generated_samples.append(sample) + + f_out.flush() + print(f"📊 누적 생성: {len(generated_samples)} / {target_count}") + + time.sleep(2) + + print(f"\n🎉 성공적으로 완료!") + print(f"📁 저장 위치: {os.path.abspath(output_file)}") + print(f"📊 총 생성된 샘플 수: {len(generated_samples)}") + return generated_samples + + +if __name__ == "__main__": + samples = generate_dataset( + target_count=5000, + batch_size=25 + ) From a8966c6bb4a86a4463e1e55ee4bba7a1d7e17d45 Mon Sep 17 00:00:00 2001 From: njhvrta Date: Tue, 2 Sep 2025 01:46:26 +0900 Subject: [PATCH 08/13] feat: Implement FP8 distillation using Transformer Engine adapter.py: - Replace `torch.autocast` with `transformer_engine` to correctly apply FP8 to the teacher model. - Add `convert_to_fp8_layers` helper and convert teacher model layers in `__init__`. restoration_training.ipynb: - Import the `transformer_engine` library. 
- Update the testing loop to use `te.fp8_autocast` for the teacher model's inference to ensure consistency with the training environment --- restoration_training.ipynb | 10 +- retentionengine/datasets/dataset_generator.py | 63 +++++--- retentionengine/utils/adapter.py | 140 +++++++++++++----- 3 files changed, 152 insertions(+), 61 deletions(-) diff --git a/restoration_training.ipynb b/restoration_training.ipynb index e0777c5..43c60e7 100644 --- a/restoration_training.ipynb +++ b/restoration_training.ipynb @@ -60,6 +60,7 @@ "\n", "from tqdm.auto import tqdm\n", "import wandb\n", + "import transformer_engine.pytorch as te\n", "\n", "from utils.adapter import Adapter" ] @@ -297,10 +298,9 @@ " optimizer=optimizer,\n", " lr_scheduler=lr_scheduler,\n", " distillation_loss_fn=distillation_loss_fn,\n", - " lm_loss_fn=laungage_modeling_loss_fn,\n", + " lm_loss_fn=language_modeling_loss_fn,\n", " temperature=T,\n", " alpha=ALPHA,\n", - " device=device,\n", " use_wandb=True\n", ")\n", "\n", @@ -334,7 +334,8 @@ "tokenizer.save_pretrained(output_dir)\n", "print(f\"Adapted model saved successfully to {output_dir}\")\n", "\n", - "base_model.eval()\n", + "test_base_model = adapter_trainer.base_model \n", + "test_base_model.eval()\n", "adapt_model.eval()\n", "\n", "user_question = \"Explain the concept of knowledge distillation in machine learning.\"\n", @@ -356,7 +357,8 @@ " total_tokens += 1\n", " \n", " with torch.no_grad():\n", - " base_outputs = base_model(input_ids=base_output_ids)\n", + " with te.fp8_autocast(enabled=True):\n", + " base_outputs = test_base_model(input_ids=base_output_ids)\n", " base_logits = base_outputs.logits[:, -1, :]\n", " base_predicted_id = torch.argmax(base_logits, dim=-1).unsqueeze(-1)\n", "\n", diff --git a/retentionengine/datasets/dataset_generator.py b/retentionengine/datasets/dataset_generator.py index dddc181..2af3b8f 100644 --- a/retentionengine/datasets/dataset_generator.py +++ b/retentionengine/datasets/dataset_generator.py @@ -15,7 +15,7 @@ from typing import List, Dict from transformers import AutoTokenizer, AutoModelForCausalLM import torch -import json_repair +import json_repair # 🟢 라이브러리 다시 임포트 def make_seed_prompt(batch_size: int = 100) -> str: @@ -71,8 +71,21 @@ def make_seed_prompt_korean(batch_size: int = 5) -> str: - 도메인 분포: {domain_str} - reasoning_trace는 2~5단계 포함 - memory_read/write는 실제 메모리 블록 이름 사용 (예: physics_laws_v1, literature_themes_02) +- 문자열 내부에 따옴표(\")가 있다면 반드시 백슬래시(\\)로 이스케이프 처리해줘 (예: `\"`). - 오직 하나의 완성된 JSON 배열만 출력해줘. 설명, 마크다운, ```json``` 태그는 절대 포함하지 마. -- 예시: `[ {{...}}, {{...}}, ... ]` + +### 완벽한 출력 예시: +[ + {{ + "id": "temp_001", + "domain": "literature", + "input": "소설 '데미안'의 주제는 무엇인가?", + "target_output": "헤르만 헤세의 소설 '데미안'의 핵심 주제는 개인의 자아 발견과 성장 과정입니다. 주인공 싱클레어가 알을 깨고 나오듯, 기존의 세계를 벗어나 자기 자신만의 길을 찾아가는 과정을 그립니다.", + "reasoning_trace": [], + "memory_access_pattern": {{"reads": [], "writes": []}}, + "metadata": {{"context_length": 50, "reasoning_depth": "trivial"}} + }} +] """ return prompt.strip() @@ -121,12 +134,11 @@ def make_seed_prompt_english(batch_size: int = 100) -> str: - reasoning_trace must have 2-5 steps - Use realistic memory block names (e.g., physics_laws_v1, literature_themes_02) - Output only a single, complete JSON array. Do not include any explanation, markdown, or ```json``` tags. -- Example: `[ {{...}}, {{...}}, ... 
]` +- Example: `[ {{"id": "...", ...}} ]` """ return prompt.strip() -# [수정] JSON 배열을 안정적으로 파싱하는 함수 def parse_json_array(text: str) -> List[Dict]: """ 생성된 텍스트에서 유효한 JSON 배열을 추출 @@ -140,7 +152,14 @@ def parse_json_array(text: str) -> List[Dict]: try: data = json.loads(json_array_str) except json.JSONDecodeError: - data = json_repair.loads(json_array_str) + # 🟢 JSON 디코딩 실패 시 json_repair로 복구 시도 + print("⚠️ 경고: JSON 디코딩 실패. json_repair로 복구를 시도합니다.") + try: + data = json_repair.loads(json_array_str) + except Exception as e: + print(f"🆘 json_repair로도 복구 실패: {e}") + return [] + if isinstance(data, list): valid_items = [ @@ -160,12 +179,10 @@ def generate_dataset( """ Qwen3을 사용해 데이터셋을 생성하고, 지정된 경로에 저장 """ - output_dir = os.path.join("retentionengine", "datasets") - os.makedirs(output_dir, exist_ok=True) - output_file = os.path.join(output_dir, "qwen3_memory_aligned_dataset.jsonl") + output_file = "qwen3_memory_aligned_dataset.jsonl" print("🚀 모델을 로딩 중입니다... (Qwen/Qwen2-7B-Instruct)") - model_name = "Qwen/Qwen2-7B-Instruct" + model_name = "Qwen/Qwen2-7B-Instruct" tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( @@ -174,7 +191,7 @@ def generate_dataset( torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 ) - print(f"✅ 모델 로딩 완료. 저장 위치: {os.path.abspath(output_file)}") + print(f"✅ 모델 로딩 완료. 저장 위치: {output_file}") generated_samples = [] @@ -186,22 +203,22 @@ def generate_dataset( prompt = make_seed_prompt(current_batch_size) - messages = [ - {"role": "system", "content": "You are a helpful assistant that generates high-quality synthetic data."}, - {"role": "user", "content": prompt} - ] - text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) - inputs = tokenizer(text, return_tensors="pt").to(model.device) - - # 안전한 max_new_tokens 계산 + system_prompt = "<|im_start|>system\nYou are a helpful assistant that generates high-quality synthetic data.<|im_end|>" + user_prompt = f"<|im_start|>user\n{prompt}<|im_end|>" + assistant_prompt = "<|im_start|>assistant" + + final_text = f"{system_prompt}\n{user_prompt}\n{assistant_prompt}" + + inputs = tokenizer(final_text, return_tensors="pt").to(model.device) + input_len = inputs["input_ids"].shape[1] - max_new = min(16000, model.config.max_position_embeddings - input_len) + max_new = min(16000, 32768 - input_len - 10) with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=max_new, - temperature=0.8, + temperature=0.5, do_sample=True, top_p=0.9, eos_token_id=tokenizer.eos_token_id, @@ -223,7 +240,7 @@ def generate_dataset( for sample in new_samples: if len(generated_samples) >= target_count: break - sample["id"] = str(uuid.uuid4()) # uuid 기반 ID + sample["id"] = str(uuid.uuid4()) f_out.write(json.dumps(sample, ensure_ascii=False) + "\n") generated_samples.append(sample) @@ -233,7 +250,7 @@ def generate_dataset( time.sleep(2) print(f"\n🎉 성공적으로 완료!") - print(f"📁 저장 위치: {os.path.abspath(output_file)}") + print(f"📁 저장 위치: {output_file}") print(f"📊 총 생성된 샘플 수: {len(generated_samples)}") return generated_samples @@ -242,4 +259,4 @@ def generate_dataset( samples = generate_dataset( target_count=5000, batch_size=25 - ) + ) \ No newline at end of file diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py index 482277b..6d17c83 100644 --- a/retentionengine/utils/adapter.py +++ b/retentionengine/utils/adapter.py @@ -1,14 +1,32 @@ -# adapter.py - import torch from torch import nn, optim from torch.utils.data import DataLoader 
import torch.nn.functional as F from tqdm.auto import tqdm import wandb +import os +import transformer_engine.pytorch as te + -# FP8 학습에 권장되는 데이터 타입 (E4M3: 정밀도가 더 높아 학습 안정성에 유리) -FP8_DTYPE = torch.float8_e4m3fn +def convert_to_fp8_layers(module: nn.Module) -> nn.Module: + """ + Replaces nn.Linear and nn.LayerNorm with Transformer Engine's FP8-supported layers. + """ + for name, child in module.named_children(): + if isinstance(child, nn.Linear): + setattr(module, name, te.Linear( + child.in_features, + child.out_features, + bias=(child.bias is not None), + )) + elif isinstance(child, nn.LayerNorm): + setattr(module, name, te.LayerNorm( + child.normalized_shape, + eps=child.eps, + )) + else: + convert_to_fp8_layers(child) + return module class Adapter: def __init__( @@ -21,10 +39,16 @@ def __init__( lm_loss_fn: nn.Module, temperature: float = 2.0, alpha: float = 0.5, - use_wandb: bool = True + use_wandb: bool = True, + checkpoint_dir: str = "./checkpoints", ): self.adapt_model = adapt_model - self.base_model = base_model + + # Convert the teacher model's layers to FP8-supported layers. + print("INFO: Converting base_model to support FP8 with Transformer Engine...") + self.base_model = convert_to_fp8_layers(base_model) + print("INFO: Conversion complete.") + self.optimizer = optimizer self.lr_scheduler = lr_scheduler self.distillation_loss_fn = distillation_loss_fn @@ -32,34 +56,71 @@ def __init__( self.temperature = temperature self.alpha = alpha self.use_wandb = use_wandb - self.train_dataloader_len = 0 - + self.base_model.eval() for param in self.base_model.parameters(): param.requires_grad = False + self.checkpoint_dir = checkpoint_dir + if not os.path.exists(self.checkpoint_dir): + os.makedirs(self.checkpoint_dir, exist_ok=True) + self.best_eval_loss = float('inf') + self.train_dataloader_len = 0 + + def save_checkpoint(self, epoch: int, eval_loss: float = None): + checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}_checkpoint.pth") + + torch.save({ + 'epoch': epoch, + 'model_state_dict': self.adapt_model.state_dict(), + 'optimizer_state_dict': self.optimizer.state_dict(), + 'lr_scheduler_state_dict': self.lr_scheduler.state_dict() if self.lr_scheduler else None, + 'eval_loss': eval_loss, + 'best_eval_loss': self.best_eval_loss + }, checkpoint_path) + print(f"INFO: Checkpoint saved to {checkpoint_path}") + + def save_best_model(self, eval_loss: float, epoch: int): + if eval_loss < self.best_eval_loss: + print(f"INFO: New best model found! Eval Loss: {eval_loss:.4f} (Previous: {self.best_eval_loss:.4f})") + self.best_eval_loss = eval_loss + best_model_path = os.path.join(self.checkpoint_dir, "best_model.pth") + torch.save({ + 'epoch': epoch, + 'model_state_dict': self.adapt_model.state_dict(), + 'eval_loss': eval_loss + }, best_model_path) + print(f"INFO: Best model saved to {best_model_path}") + def train_step(self, batch: dict) -> tuple: primary_device = next(self.adapt_model.parameters()).device batch = {k: v.to(primary_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} + self.adapt_model.train() - - with torch.autocast(device_type='cuda', dtype=FP8_DTYPE): - with torch.no_grad(): + self.optimizer.zero_grad() + + # Apply te.fp8_autocast only to the teacher model. 
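+        # (Note: te.fp8_autocast only affects Transformer Engine modules, so this
+        # assumes convert_to_fp8_layers() has already swapped the teacher's
+        # nn.Linear / nn.LayerNorm for te.Linear / te.LayerNorm in __init__;
+        # plain torch modules inside the region keep their original precision.)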
+ with torch.no_grad(): + with te.fp8_autocast(enabled=True): base_outputs = self.base_model(**batch) base_logits = base_outputs.logits - - adapt_outputs = self.adapt_model(**batch) - adapt_logits = adapt_outputs.logits - - soft_targets = F.softmax(base_logits / self.temperature, dim=-1) - soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1) - distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2) - - lm_loss = self.lm_loss_fn(adapt_logits.view(-1, adapt_logits.size(-1)), batch['labels'].view(-1)) - - total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss + + # The student model operates in its original precision (e.g., bfloat16). + adapt_outputs = self.adapt_model(**batch) + adapt_logits = adapt_outputs.logits + + # Loss calculation + soft_targets = F.softmax(base_logits / self.temperature, dim=-1) + soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1) + distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2) + + lm_loss = self.lm_loss_fn( + adapt_logits.view(-1, adapt_logits.size(-1)), + batch['labels'].view(-1) + ) + + total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss - self.optimizer.zero_grad() total_loss.backward() self.optimizer.step() if self.lr_scheduler: @@ -108,20 +169,25 @@ def evaluate_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int): with torch.no_grad(): for batch in progress_bar: batch = {k: v.to(primary_device) for k, v in batch.items() if isinstance(v, torch.Tensor)} - - with torch.autocast(device_type='cuda', dtype=FP8_DTYPE): + + # Also apply te.fp8_autocast only to the teacher model during evaluation. + with te.fp8_autocast(enabled=True): base_outputs = self.base_model(**batch) base_logits = base_outputs.logits - adapt_outputs = self.adapt_model(**batch) - adapt_logits = adapt_outputs.logits + adapt_outputs = self.adapt_model(**batch) + adapt_logits = adapt_outputs.logits - soft_targets = F.softmax(base_logits / self.temperature, dim=-1) - soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1) - distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2) - - lm_loss = self.lm_loss_fn(adapt_logits.view(-1, adapt_logits.size(-1)), batch['labels'].view(-1)) - total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss + soft_targets = F.softmax(base_logits / self.temperature, dim=-1) + soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1) + distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2) + + lm_loss = self.lm_loss_fn( + adapt_logits.view(-1, adapt_logits.size(-1)), + batch['labels'].view(-1) + ) + + total_loss = self.alpha * distillation_loss + (1. 
- self.alpha) * lm_loss
 
                 total_eval_loss += total_loss.item()
 
@@ -131,10 +197,16 @@ def evaluate_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int):
         if self.use_wandb:
             global_step = epoch * self.train_dataloader_len
             wandb.log({"eval/epoch_loss": avg_eval_loss}, step=global_step)
+
+        return avg_eval_loss
 
     def run_training(self, train_dataloader: DataLoader, num_epochs: int, eval_dataloader: DataLoader = None):
         self.train_dataloader_len = len(train_dataloader)
         for epoch in range(1, num_epochs + 1):
             self.train_epoch(train_dataloader, epoch, num_epochs)
             if eval_dataloader:
-                self.evaluate_epoch(eval_dataloader, epoch, num_epochs)
\ No newline at end of file
+                avg_eval_loss = self.evaluate_epoch(eval_dataloader, epoch, num_epochs)
+                self.save_best_model(avg_eval_loss, epoch)
+                self.save_checkpoint(epoch, avg_eval_loss)
+            else:
+                self.save_checkpoint(epoch)
\ No newline at end of file

From 763c95f42f95660909829c17e68e60fb5d7611a7 Mon Sep 17 00:00:00 2001
From: njhvrta 
Date: Tue, 2 Sep 2025 01:55:29 +0900
Subject: [PATCH 09/13] refactor: translate dataset_generator to English and parameterize model

---
 retentionengine/datasets/dataset_generator.py | 69 ++++++++++---------
 1 file changed, 36 insertions(+), 33 deletions(-)

diff --git a/retentionengine/datasets/dataset_generator.py b/retentionengine/datasets/dataset_generator.py
index 2af3b8f..ebf7074 100644
--- a/retentionengine/datasets/dataset_generator.py
+++ b/retentionengine/datasets/dataset_generator.py
@@ -1,12 +1,11 @@
 # -*- coding: utf-8 -*-
 """
-Qwen3 기반 메모리 모듈 학습용 데이터셋 생성기 (안정화 버전)
-------------------------------------------------
-생성된 데이터: RETENTIONENGINE/retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl
-→ 2/3 한국어, 1/3 영어 샘플 포함
+Qwen3-based Memory Module Training Dataset Generator (Stabilized Version)
+----------------------------------------------------------------------
+Generated data: RETENTIONENGINE/retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl
+→ Includes 2/3 Korean and 1/3 English samples.
 """
 
-import os
 import json
 import re
 import time
@@ -15,14 +14,13 @@
 from typing import List, Dict
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-import json_repair # 🟢 라이브러리 다시 임포트
-
+import json_repair  # Library for repairing malformed JSON
 
 def make_seed_prompt(batch_size: int = 100) -> str:
     """
-    2/3 확률로 한국어, 1/3 확률로 영어 프롬프트 반환
+    Return a Korean prompt with 2/3 probability, and an English prompt with 1/3 probability.
     """
-    if random.random() < 0.33: # 약 1/3 확률로 영어
+    if random.random() < 0.33:  # Approx. 1/3 probability for English
         return make_seed_prompt_english(batch_size)
     else:
         return make_seed_prompt_korean(batch_size)
@@ -30,7 +28,7 @@
 
 def make_seed_prompt_korean(batch_size: int = 5) -> str:
     """
-    한국어 프롬프트
+    Korean prompt
     """
     domains = [
         ("math", 20), ("cs", 20), ("physics", 15), ("literature", 15),
@@ -92,7 +90,7 @@
 
 def make_seed_prompt_english(batch_size: int = 100) -> str:
     """
-    영어 프롬프트 (1/3 확률로 호출)
+    English prompt (called with 1/3 probability)
     """
     domains = [
         ("math", 20), ("cs", 20), ("physics", 15), ("literature", 15),
@@ -141,49 +139,49 @@ def make_seed_prompt_english(batch_size: int = 100) -> str:
 
 def parse_json_array(text: str) -> List[Dict]:
     """
-    생성된 텍스트에서 유효한 JSON 배열을 추출
+    Extract a valid JSON array from the generated text. 
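+    Falls back to json_repair when strict json.loads parsing fails, and keeps
+    only dict items that contain both "input" and "target_output" keys.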
""" + # Find the first '[' and the last ']' to capture the array match = re.search(r'\[.*\]', text, re.DOTALL) if not match: - print("⚠️ 경고: 응답에서 JSON 배열을 찾지 못했습니다.") + print("⚠️ WARNING: Could not find a JSON array in the response.") return [] json_array_str = match.group(0) try: data = json.loads(json_array_str) except json.JSONDecodeError: - # 🟢 JSON 디코딩 실패 시 json_repair로 복구 시도 - print("⚠️ 경고: JSON 디코딩 실패. json_repair로 복구를 시도합니다.") + # If standard JSON decoding fails, attempt to repair with json_repair + print("⚠️ WARNING: JSON decoding failed. Attempting to repair with json_repair.") try: data = json_repair.loads(json_array_str) except Exception as e: - print(f"🆘 json_repair로도 복구 실패: {e}") + print(f"🆘 CRITICAL: Repair with json_repair also failed: {e}") return [] - if isinstance(data, list): + # Ensure all items in the list are dicts with the required keys valid_items = [ item for item in data if isinstance(item, dict) and "input" in item and "target_output" in item ] return valid_items else: - print("⚠️ 경고: 파싱된 데이터가 리스트가 아닙니다.") + print("⚠️ WARNING: Parsed data is not a list.") return [] def generate_dataset( + model_name: str, target_count: int = 10000, batch_size: int = 100 ): """ - Qwen3을 사용해 데이터셋을 생성하고, 지정된 경로에 저장 + Generate a dataset using the specified model and save it to a JSONL file. """ output_file = "qwen3_memory_aligned_dataset.jsonl" - print("🚀 모델을 로딩 중입니다... (Qwen/Qwen2-7B-Instruct)") - model_name = "Qwen/Qwen2-7B-Instruct" - + print(f"🚀 Loading model: {model_name}...") tokenizer = AutoTokenizer.from_pretrained(model_name) model = AutoModelForCausalLM.from_pretrained( model_name, @@ -191,7 +189,7 @@ def generate_dataset( torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 ) - print(f"✅ 모델 로딩 완료. 저장 위치: {output_file}") + print(f"✅ Model loaded successfully. Output file: {output_file}") generated_samples = [] @@ -199,10 +197,11 @@ def generate_dataset( while len(generated_samples) < target_count: remaining = target_count - len(generated_samples) current_batch_size = min(batch_size, remaining) - print(f"🔄 생성 중: {len(generated_samples)} / {target_count} | 이번 배치: {current_batch_size}") + print(f"🔄 Generating: {len(generated_samples)} / {target_count} | Current batch size: {current_batch_size}") prompt = make_seed_prompt(current_batch_size) + # Apply the Qwen2 chat template system_prompt = "<|im_start|>system\nYou are a helpful assistant that generates high-quality synthetic data.<|im_end|>" user_prompt = f"<|im_start|>user\n{prompt}<|im_end|>" assistant_prompt = "<|im_start|>assistant" @@ -211,7 +210,9 @@ def generate_dataset( inputs = tokenizer(final_text, return_tensors="pt").to(model.device) + # Dynamically calculate max_new_tokens to avoid exceeding context window input_len = inputs["input_ids"].shape[1] + # Set a reasonable generation cap and respect the model's context length (32768 for Qwen2) max_new = min(16000, 32768 - input_len - 10) with torch.no_grad(): @@ -225,38 +226,40 @@ def generate_dataset( pad_token_id=tokenizer.eos_token_id ) + # Decode only the newly generated tokens input_token_length = inputs["input_ids"].shape[1] generated_tokens = outputs[0][input_token_length:] content = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() new_samples = parse_json_array(content) - print(f"✅ 이번 배치에서 {len(new_samples)}개의 유효한 샘플 추출") + print(f"✅ Extracted {len(new_samples)} valid samples from this batch.") if not new_samples: - print("⚠️ 이번 배치에서 유효한 샘플을 생성하지 못했습니다. 
재시도합니다.")
+            print("⚠️ Failed to generate valid samples in this batch. Retrying after a delay.")
                 time.sleep(5)
                 continue
 
             for sample in new_samples:
                 if len(generated_samples) >= target_count:
                     break
-                sample["id"] = str(uuid.uuid4()) # uuid 기반 ID
+                sample["id"] = str(uuid.uuid4())  # Assign a unique ID
                 f_out.write(json.dumps(sample, ensure_ascii=False) + "\n")
                 generated_samples.append(sample)
 
-            f_out.flush()
-            print(f"📊 누적 생성: {len(generated_samples)} / {target_count}")
+            f_out.flush()  # Write data to disk periodically
+            print(f"📊 Total generated: {len(generated_samples)} / {target_count}")
 
-            time.sleep(2)
+            time.sleep(2)  # Small delay between API calls
 
-    print(f"\n🎉 성공적으로 완료!")
-    print(f"📁 저장 위치: {output_file}")
-    print(f"📊 총 생성된 샘플 수: {len(generated_samples)}")
+    print(f"\n🎉 Generation successfully completed!")
+    print(f"📁 Data saved to: {output_file}")
+    print(f"📊 Total samples generated: {len(generated_samples)}")
     return generated_samples
 
 
 if __name__ == "__main__":
     samples = generate_dataset(
+        model_name="Qwen/Qwen3-8B-Instruct",
         target_count=5000,
         batch_size=25
-    )
+    )
\ No newline at end of file

From dc291a046fee65e3c075ce3c03bdc8e0ea7e69c2 Mon Sep 17 00:00:00 2001
From: njhvrta 
Date: Tue, 2 Sep 2025 01:56:32 +0900
Subject: [PATCH 10/13] fix: harden dataset_generator and switch to Qwen2-7B-Instruct

---
 retentionengine/datasets/dataset_generator.py | 122 +++++++++---------
 1 file changed, 64 insertions(+), 58 deletions(-)

diff --git a/retentionengine/datasets/dataset_generator.py b/retentionengine/datasets/dataset_generator.py
index ebf7074..85c30f1 100644
--- a/retentionengine/datasets/dataset_generator.py
+++ b/retentionengine/datasets/dataset_generator.py
@@ -1,8 +1,8 @@
 # -*- coding: utf-8 -*-
 """
-Qwen3-based Memory Module Training Dataset Generator (Stabilized Version)
+Qwen2-based Memory Module Training Dataset Generator (Stabilized Version)
 ----------------------------------------------------------------------
-Generated data: RETENTIONENGINE/retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl
+Generated data: RETENTIONENGINE/retentionengine/datasets/qwen2_memory_aligned_dataset.jsonl
 → Includes 2/3 Korean and 1/3 English samples.
 """
 
@@ -12,11 +12,12 @@
 import random
 import uuid
 from typing import List, Dict
+import os  # Added for directory creation
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
-import json_repair  # Library for repairing malformed JSON
+import json_repair  # Library for repairing malformed JSON; install via: pip install json-repair
 
-def make_seed_prompt(batch_size: int = 100) -> str:
+def make_seed_prompt(batch_size: int = 25) -> str:  # Reduced default for faster testing
     """
     Return a Korean prompt with 2/3 probability, and an English prompt with 1/3 probability.
     """
@@ -25,8 +26,7 @@
         return make_seed_prompt_korean(batch_size)
 
-
-def make_seed_prompt_korean(batch_size: int = 5) -> str:
+def make_seed_prompt_korean(batch_size: int = 25) -> str:
     """
     Korean prompt
     """
@@ -43,7 +43,7 @@
   "id": "sample_001",
   "domain": "예: math, physics, literature, cs, philosophy 등",
   "input": "장문 또는 단문 입력. 
다양한 도메인 혼합 가능", - "target_output": "Qwen3이 생성할 것으로 예상되는 출력", + "target_output": "Qwen2이 생성할 것으로 예상되는 출력", # Updated to Qwen2 "reasoning_trace": [ {{ "step": 1, @@ -87,8 +87,7 @@ def make_seed_prompt_korean(batch_size: int = 5) -> str: """ return prompt.strip() - -def make_seed_prompt_english(batch_size: int = 100) -> str: +def make_seed_prompt_english(batch_size: int = 25) -> str: """ English prompt (called with 1/3 probability) """ @@ -105,7 +104,7 @@ def make_seed_prompt_english(batch_size: int = 100) -> str: "id": "sample_001", "domain": "e.g., math, physics, literature, cs, philosophy", "input": "Long or short input. Mix of domains allowed", - "target_output": "Output expected from Qwen3", + "target_output": "Output expected from Qwen2", # Updated to Qwen2 "reasoning_trace": [ {{ "step": 1, @@ -136,7 +135,6 @@ def make_seed_prompt_english(batch_size: int = 100) -> str: """ return prompt.strip() - def parse_json_array(text: str) -> List[Dict]: """ Extract a valid JSON array from the generated text. @@ -170,26 +168,30 @@ def parse_json_array(text: str) -> List[Dict]: print("⚠️ WARNING: Parsed data is not a list.") return [] - def generate_dataset( model_name: str, target_count: int = 10000, - batch_size: int = 100 + batch_size: int = 25 ): """ Generate a dataset using the specified model and save it to a JSONL file. """ - output_file = "qwen3_memory_aligned_dataset.jsonl" + output_dir = "RETENTIONENGINE/retentionengine/datasets" + output_file = os.path.join(output_dir, "qwen2_memory_aligned_dataset.jsonl") # Updated name and path + os.makedirs(output_dir, exist_ok=True) # Create dirs if needed print(f"🚀 Loading model: {model_name}...") - tokenizer = AutoTokenizer.from_pretrained(model_name) - model = AutoModelForCausalLM.from_pretrained( - model_name, - device_map="auto", - torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 - ) - - print(f"✅ Model loaded successfully. Output file: {output_file}") + try: + tokenizer = AutoTokenizer.from_pretrained(model_name) + model = AutoModelForCausalLM.from_pretrained( + model_name, + device_map="auto", + torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 + ) + print(f"✅ Model loaded successfully. Output file: {output_file}") + except Exception as e: + print(f"🆘 ERROR: Failed to load model: {e}") + return [] # Early exit on load failure generated_samples = [] @@ -215,51 +217,55 @@ def generate_dataset( # Set a reasonable generation cap and respect the model's context length (32768 for Qwen2) max_new = min(16000, 32768 - input_len - 10) - with torch.no_grad(): - outputs = model.generate( - **inputs, - max_new_tokens=max_new, - temperature=0.5, - do_sample=True, - top_p=0.9, - eos_token_id=tokenizer.eos_token_id, - pad_token_id=tokenizer.eos_token_id - ) - - # Decode only the newly generated tokens - input_token_length = inputs["input_ids"].shape[1] - generated_tokens = outputs[0][input_token_length:] - content = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip() - - new_samples = parse_json_array(content) - print(f"✅ Extracted {len(new_samples)} valid samples from this batch.") - - if not new_samples: - print("⚠️ Failed to generate valid samples in this batch. 
Retrying after a delay.")
+                    time.sleep(5)
+                    continue
+
+                for sample in new_samples:
+                    if len(generated_samples) >= target_count:
+                        break
+                    sample["id"] = str(uuid.uuid4())  # Assign a unique ID
+                    f_out.write(json.dumps(sample, ensure_ascii=False) + "\n")
+                    generated_samples.append(sample)
+
+                f_out.flush()  # Write data to disk periodically
+                print(f"📊 Total generated: {len(generated_samples)} / {target_count}")
+
+                time.sleep(2)  # Small delay between API calls
+            except Exception as e:
+                print(f"🆘 ERROR during generation: {e}. Retrying after delay.")
                 time.sleep(5)
                 continue
 
-            for sample in new_samples:
-                if len(generated_samples) >= target_count:
-                    break
-                sample["id"] = str(uuid.uuid4())  # Assign a unique ID
-                f_out.write(json.dumps(sample, ensure_ascii=False) + "\n")
-                generated_samples.append(sample)
-
-            f_out.flush()  # Write data to disk periodically
-            print(f"📊 Total generated: {len(generated_samples)} / {target_count}")
-
-            time.sleep(2)  # Small delay between API calls
-
     print(f"\n🎉 Generation successfully completed!")
     print(f"📁 Data saved to: {output_file}")
     print(f"📊 Total samples generated: {len(generated_samples)}")
     return generated_samples
 
 if __name__ == "__main__":
-    samples = generate_dataset(
-        model_name="Qwen/Qwen3-8B-Instruct",
+    generate_dataset(
+        model_name="Qwen/Qwen2-7B-Instruct",  # Fixed to a real model; adjust size if needed (e.g., -1.5B for smaller)
         target_count=5000,
         batch_size=25
     )
\ No newline at end of file

From 99962bbb1e372ff2a1bbfe234a13ec93a33f3304 Mon Sep 17 00:00:00 2001
From: njhvrta 
Date: Tue, 2 Sep 2025 02:35:27 +0900
Subject: [PATCH 11/13] fix: import AdamW from torch.optim in restoration_training

---
 restoration_training.ipynb | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/restoration_training.ipynb b/restoration_training.ipynb
index 43c60e7..4191fe3 100644
--- a/restoration_training.ipynb
+++ b/restoration_training.ipynb
@@ -53,10 +53,11 @@
    "import torch.nn as nn\n",
    "import random\n",
    "\n",
-   "from transformers import AdamW, get_linear_schedule_with_warmup, Qwen3ForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM\n",
+   "from transformers import get_linear_schedule_with_warmup, Qwen3ForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM\n",
    "\n",
    "from torch.utils.data import DataLoader\n",
    "from datasets import load_dataset\n",
+   "from torch.optim import AdamW\n",
    "\n",
    "from tqdm.auto import tqdm\n",
    "import wandb\n",

From f4b45c11d4b7d6bcf8e4df82dd2d6ba4c0a8aedb Mon Sep 17 00:00:00 2001
From: njhvrta 
Date: Fri, 5 Sep 2025 01:48:14 +0900
Subject: [PATCH 12/13] refactor: rewrite adapter as a Hugging Face Trainer subclass

---
 restoration_training.ipynb       | 306 +++++++++++++------------------
 retentionengine/utils/adapter.py | 216 ++++------------------
 2 files changed, 165 insertions(+), 357 deletions(-)

diff --git a/restoration_training.ipynb 
b/restoration_training.ipynb index 4191fe3..41e8978 100644 --- a/restoration_training.ipynb +++ b/restoration_training.ipynb @@ -18,22 +18,6 @@ "!nvidia_smi" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "c36b79fd", - "metadata": {}, - "outputs": [], - "source": [ - "# Set CUDA Device Number\n", - "DEVICE_NUM = 0\n", - "ADDITIONAL_GPU = 3\n", - "\n", - "from os import environ\n", - "environ[\"CUDA_VISIBLE_DEVICES\"] = \",\".join([f\"{i+DEVICE_NUM}\" for i in range(ADDITIONAL_GPU + 1)])\n", - "environ[\"CUDA_VISIBLE_DEVICES\"]" - ] - }, { "cell_type": "markdown", "id": "6f14a538", @@ -50,39 +34,11 @@ "outputs": [], "source": [ "import torch\n", - "import torch.nn as nn\n", - "import random\n", - "\n", - "from transformers import get_linear_schedule_with_warmup, Qwen3ForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM\n", - "\n", - "from torch.utils.data import DataLoader\n", - "from datasets import load_dataset\n", - "from torch.optim import AdamW\n", - "\n", - "from tqdm.auto import tqdm\n", "import wandb\n", - "import transformer_engine.pytorch as te\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments\n", + "from datasets import load_dataset\n", "\n", - "from utils.adapter import Adapter" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "197bbe62", - "metadata": {}, - "outputs": [], - "source": [ - "if torch.cuda.is_available():\n", - " if ADDITIONAL_GPU:\n", - " device = torch.device(\"cuda\")\n", - " else:\n", - " device = torch.device(f\"cuda\")\n", - "else:\n", - " device = torch.device(\"cpu\")\n", - " DEVICE_NUM = -1\n", - " \n", - "print(f\"INFO: Using device - {device}\" + (f\":{DEVICE_NUM}\" if ADDITIONAL_GPU else \"\"))" + "from adapter import DistillationTrainer" ] }, { @@ -166,56 +122,27 @@ { "cell_type": "code", "execution_count": null, - "id": "130802a8", - "metadata": {}, - "outputs": [], - "source": [ - "torch_dtype = torch.bfloat16" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "2387f7e6", + "id": "2a1e4445", "metadata": {}, "outputs": [], "source": [ + "torch_dtype = torch.bfloat16\n", + "\n", "base_model_id = \"Qwen/Qwen3-8B\"\n", - "base_model = Qwen3ForCausalLM.from_pretrained(\n", + "base_model = AutoModelForCausalLM.from_pretrained(\n", " base_model_id,\n", " torch_dtype=torch_dtype,\n", - " device_map=\"auto\"\n", - " )\n", - "tokenizer = AutoTokenizer.from_pretrained(base_model_id, use_fast=False)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "20171342", - "metadata": {}, - "outputs": [], - "source": [ - "adapt_model_id = \"pretrained/Qwen/Qwen3-8B_adapter\"\n", - "adapt_model = Qwen3ForCausalLM.from_pretrained(\n", + ")\n", + "\n", + "adapt_model_id = \"pretrained/Qwen/Qwen3-8B_adapter\" # 예시 경로\n", + "adapt_model = AutoModelForCausalLM.from_pretrained(\n", " adapt_model_id,\n", " torch_dtype=torch_dtype,\n", - " device_map=\"auto\"\n", - " )" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "40dcb3c8", - "metadata": {}, - "outputs": [], - "source": [ - "base_model.eval()\n", - "for param in base_model.parameters():\n", - " param.requires_grad = False\n", - " \n", - "adapt_model.train()" + ")\n", + "\n", + "tokenizer = AutoTokenizer.from_pretrained(base_model_id)\n", + "if tokenizer.pad_token is None:\n", + " tokenizer.pad_token = tokenizer.eos_token" ] }, { @@ -225,15 +152,11 @@ "metadata": {}, "outputs": [], "source": [ - "if tokenizer.pad_token is None:\n", - " tokenizer.pad_token = 
tokenizer.eos_token\n", - "\n", "def tokenize_function(examples):\n", " formatted_text = [\n", - " f\"질문: {inp}\\\\n\\\\n답변: {out}{tokenizer.eos_token}\" \n", + " f\"질문: {inp}\\n\\n답변: {out}{tokenizer.eos_token}\"\n", " for inp, out in zip(examples[\"input\"], examples[\"target_output\"])\n", " ]\n", - "\n", " tokenized = tokenizer(formatted_text, truncation=True, padding=\"max_length\", max_length=512)\n", " tokenized[\"labels\"] = tokenized[\"input_ids\"].copy()\n", " return tokenized\n", @@ -252,10 +175,7 @@ "split_dataset = tokenized_datasets[\"train\"].train_test_split(test_size=0.1, seed=42)\n", "\n", "train_dataset = split_dataset[\"train\"]\n", - "eval_dataset = split_dataset[\"test\"]\n", - "\n", - "train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=8)\n", - "eval_dataloader = DataLoader(eval_dataset, shuffle=False, batch_size=8)" + "eval_dataset = split_dataset[\"test\"]" ] }, { @@ -273,44 +193,33 @@ "metadata": {}, "outputs": [], "source": [ - "# Hyperparameters\n", - "EPOCHS = 10\n", - "LEARNING_RATE = 5e-5\n", - "ALPHA = 0.5\n", - "T = 0.3\n", - "NUM_TRAINING_STEPS = EPOCHS * len(train_dataloader)\n", - "\n", - "# optimizer and Scheduler\n", - "optimizer = AdamW(adapt_model.parameters(), lr=LEARNING_RATE)\n", - "lr_scheduler = get_linear_schedule_with_warmup(\n", - " optimizer,\n", - " num_warmup_steps=0,\n", - " num_training_steps=NUM_TRAINING_STEPS\n", + "training_args = TrainingArguments(\n", + " output_dir=\"./qwen3-restored-final\",\n", + " num_train_epochs=10,\n", + " per_device_train_batch_size=8,\n", + " per_device_eval_batch_size=8,\n", + " learning_rate=5e-5,\n", + " bf16=True,\n", + " evaluation_strategy=\"epoch\",\n", + " save_strategy=\"epoch\", # 매 에포크 종료 시 모델 체크포인트 저장\n", + " logging_steps=10, # 10 스텝마다 로그 출력\n", + " load_best_model_at_end=True,\n", + " metric_for_best_model=\"eval_loss\",\n", + " greater_is_better=False,\n", + " report_to=\"wandb\",\n", ")\n", "\n", - "# Loss Functions\n", - "distillation_loss_fn = nn.KLDivLoss(reduction=\"batchmean\")\n", - "language_modeling_loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)\n", - "\n", - "# Initialize Adapter\n", - "adapter_trainer = Adapter(\n", - " adapt_model=adapt_model,\n", - " base_model=base_model,\n", - " optimizer=optimizer,\n", - " lr_scheduler=lr_scheduler,\n", - " distillation_loss_fn=distillation_loss_fn,\n", - " lm_loss_fn=language_modeling_loss_fn,\n", - " temperature=T,\n", - " alpha=ALPHA,\n", - " use_wandb=True\n", + "trainer = DistillationTrainer(\n", + " model=adapt_model,\n", + " teacher_model=base_model,\n", + " args=training_args,\n", + " train_dataset=train_dataset,\n", + " eval_dataset=eval_dataset,\n", + " alpha=1.0,\n", + " temperature=0.3,\n", ")\n", "\n", - "# Run\n", - "adapter_trainer.run_training(\n", - " train_dataloader,\n", - " num_epochs=EPOCHS,\n", - " eval_dataloader=eval_dataloader\n", - ")\n", + "trainer.train()\n", "\n", "wandb.finish()" ] @@ -330,67 +239,106 @@ "metadata": {}, "outputs": [], "source": [ - "output_dir = \"./adapt_model_distilled\"\n", - "adapt_model.save_pretrained(output_dir)\n", - "tokenizer.save_pretrained(output_dir)\n", - "print(f\"Adapted model saved successfully to {output_dir}\")\n", + "# 파일명: testing.py\n", "\n", - "test_base_model = adapter_trainer.base_model \n", - "test_base_model.eval()\n", - "adapt_model.eval()\n", + "import torch\n", + "from transformers import AutoTokenizer, AutoModelForCausalLM\n", + "from tqdm import tqdm\n", + "import json\n", "\n", - "user_question = \"Explain the concept of 
knowledge distillation in machine learning.\"\n", - "prompt = f\"question: {user_question}\\n\\nanswer:\"\n", + "TEACHER_MODEL_ID = \"Qwen/Qwen3-8B\"\n", + "STUDENT_MODEL_PATH = \"./qwen3-restored-final\"\n", + "TEST_DATA_PATH = \"retentionengine/datasets/qwen3_memory_aligned_dataset.jsonl\"\n", + "DEVICE = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n", + "NUM_TEST_SAMPLES = 50\n", + "MAX_NEW_TOKENS = 128\n", "\n", - "input_ids = tokenizer(prompt, return_tensors=\"pt\").input_ids.to(next(adapt_model.parameters()).device)\n", + "print(\"INFO: 모델 및 토크나이저 로딩 중...\")\n", + "tokenizer = AutoTokenizer.from_pretrained(TEACHER_MODEL_ID)\n", + "teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_MODEL_ID, torch_dtype=torch.bfloat16).to(DEVICE)\n", + "student_model = AutoModelForCausalLM.from_pretrained(STUDENT_MODEL_PATH, torch_dtype=torch.bfloat16).to(DEVICE)\n", "\n", - "print(f\"Prompt: {prompt}\\n\")\n", - "print(\"=== Token-by-Token Generation Comparison ===\\n\")\n", + "teacher_model.eval()\n", + "student_model.eval()\n", + "print(\"INFO: 로딩 완료.\")\n", "\n", - "base_output_ids = input_ids.clone()\n", - "adapt_output_ids = input_ids.clone()\n", + "test_prompts = []\n", + "try:\n", + " with open(TEST_DATA_PATH, 'r', encoding='utf-8') as f:\n", + " for i, line in enumerate(f):\n", + " if i >= NUM_TEST_SAMPLES:\n", + " break\n", + " data = json.loads(line)\n", + " prompt = f\"질문: {data['input']}\\n\\n답변:\"\n", + " test_prompts.append(prompt)\n", + "except FileNotFoundError:\n", + " print(f\"ERROR: 테스트 데이터 파일을 찾을 수 없습니다. 경로를 확인하세요: {TEST_DATA_PATH}\")\n", + " exit()\n", "\n", - "max_new_tokens = 100\n", - "mismatched_tokens = 0\n", - "total_tokens = 0\n", + "total_accuracy = 0\n", + "total_mismatched = 0\n", + "total_generated = 0\n", "\n", - "for step in range(max_new_tokens):\n", - " total_tokens += 1\n", - " \n", - " with torch.no_grad():\n", - " with te.fp8_autocast(enabled=True):\n", - " base_outputs = test_base_model(input_ids=base_output_ids)\n", - " base_logits = base_outputs.logits[:, -1, :]\n", - " base_predicted_id = torch.argmax(base_logits, dim=-1).unsqueeze(-1)\n", + "print(f\"\\nINFO: {len(test_prompts)}개의 샘플에 대한 테스트 시작...\")\n", + "for sample_idx, prompt in enumerate(tqdm(test_prompts, desc=\"Testing Progress\")):\n", + " inputs = tokenizer(prompt, return_tensors=\"pt\").to(DEVICE)\n", "\n", - " adapt_outputs = adapt_model(input_ids=adapt_output_ids)\n", - " adapt_logits = adapt_outputs.logits[:, -1, :]\n", - " adapt_predicted_id = torch.argmax(adapt_logits, dim=-1).unsqueeze(-1)\n", + " with torch.no_grad():\n", + " teacher_output_ids = teacher_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)\n", + " student_output_ids = student_model.generate(**inputs, max_new_tokens=MAX_NEW_TOKENS, do_sample=False)\n", "\n", - " base_output_ids = torch.cat([base_output_ids, base_predicted_id], dim=-1)\n", - " adapt_output_ids = torch.cat([adapt_output_ids, adapt_predicted_id], dim=-1)\n", + " teacher_tokens = teacher_output_ids[0][inputs.input_ids.shape[1]:]\n", + " student_tokens = student_output_ids[0][inputs.input_ids.shape[1]:]\n", + " \n", + " total_tokens_in_sample = len(teacher_tokens)\n", + " if total_tokens_in_sample == 0:\n", + " continue # Teacher가 아무것도 생성하지 않은 경우 해당 샘플은 건너뜀\n", "\n", - " if adapt_predicted_id.item() != base_predicted_id.item():\n", - " mismatched_tokens += 1\n", - " base_token_str = tokenizer.decode(base_predicted_id.squeeze())\n", - " adapt_token_str = tokenizer.decode(adapt_predicted_id.squeeze())\n", - " print(f\"Step 
{step+1}: Mismatch! Base='{base_token_str}' vs Adapt='{adapt_token_str}'\")\n", + " mismatched_tokens_in_sample = 0\n", + " mismatch_details = []\n", "\n", - " if base_predicted_id.item() == tokenizer.eos_token_id or adapt_predicted_id.item() == tokenizer.eos_token_id:\n", - " print(\"\\nINFO: EOS token generated. Halting generation.\")\n", - " break\n", + " len_to_compare = min(len(teacher_tokens), len(student_tokens))\n", + " for i in range(len_to_compare):\n", + " if teacher_tokens[i] != student_tokens[i]:\n", + " mismatched_tokens_in_sample += 1\n", + " mismatch_details.append(\n", + " f\" - Position #{i+1}: \"\n", + " f\"Teacher='{tokenizer.decode(teacher_tokens[i])}' (ID: {teacher_tokens[i]}) vs \"\n", + " f\"Student='{tokenizer.decode(student_tokens[i])}' (ID: {student_tokens[i]})\"\n", + " )\n", + " \n", + " len_diff = abs(len(teacher_tokens) - len(student_tokens))\n", + " if len_diff > 0:\n", + " mismatched_tokens_in_sample += len_diff\n", + " stop_reason = \"Teacher가 먼저 중단\" if len(teacher_tokens) < len(student_tokens) else \"Student가 먼저 중단\"\n", + " mismatch_details.append(\n", + " f\" - 생성 길이 불일치: Teacher={len(teacher_tokens)}개, Student={len(student_tokens)}개 ({stop_reason})\"\n", + " )\n", "\n", - "base_text = tokenizer.decode(base_output_ids[0], skip_special_tokens=True)\n", - "adapt_text = tokenizer.decode(adapt_output_ids[0], skip_special_tokens=True)\n", + " if mismatch_details:\n", + " print(f\"\\n\\n--- [sample #{sample_idx+1}] doesn't match---\")\n", + " print(f\"Prompt: \\\"{prompt[:80]}...\\\"\")\n", + " for detail in mismatch_details:\n", + " print(detail)\n", + " print(f\"Teacher output: \\\"...{tokenizer.decode(teacher_tokens, skip_special_tokens=True)}\\\"\")\n", + " print(f\"Student output: \\\"...{tokenizer.decode(student_tokens, skip_special_tokens=True)}\\\"\")\n", + " print(\"-\" * 25)\n", "\n", - "print(\"\\n--- Full Generation Results ---\")\n", - "print(\"Base Model Output:\\n\", base_text)\n", - "print(\"\\nAdapted Model Output:\\n\", adapt_text)\n", + " accuracy = ((total_tokens_in_sample - mismatched_tokens_in_sample) / total_tokens_in_sample) * 100\n", + " total_accuracy += accuracy\n", + " total_mismatched += mismatched_tokens_in_sample\n", + " total_generated += total_tokens_in_sample\n", "\n", - "token_accuracy = ((total_tokens - mismatched_tokens) / total_tokens) * 100 if total_tokens > 0 else 0\n", - "print(\"\\n--- Evaluation Metrics ---\")\n", - "print(f\"Token Match Rate (Accuracy): {token_accuracy:.2f}%\")\n", - "print(f\"Total Generated Tokens: {total_tokens}, Mismatched Tokens: {mismatched_tokens}\")" + "if len(test_prompts) > 0:\n", + " avg_accuracy = total_accuracy / len(test_prompts)\n", + " print(\"\\n\\n\" + \"=\"*30)\n", + " print(\"--- 최종 평가 결과 요약 ---\")\n", + " print(f\"평균 토큰 일치율 (정확도): {avg_accuracy:.2f}%\")\n", + " print(f\"총 생성 토큰 (Teacher 기준): {total_generated}\")\n", + " print(f\"총 불일치 토큰: {total_mismatched}\")\n", + " print(\"=\"*30)\n", + "else:\n", + " print(\"테스트할 데이터를 찾지 못해 평가를 진행할 수 없습니다.\")" ] } ], diff --git a/retentionengine/utils/adapter.py b/retentionengine/utils/adapter.py index 6d17c83..b918f68 100644 --- a/retentionengine/utils/adapter.py +++ b/retentionengine/utils/adapter.py @@ -1,17 +1,13 @@ +# 파일명: huggingface_adapter.py + import torch -from torch import nn, optim -from torch.utils.data import DataLoader +import torch.nn as nn import torch.nn.functional as F -from tqdm.auto import tqdm -import wandb -import os +from transformers import Trainer import transformer_engine.pytorch as te - def 
convert_to_fp8_layers(module: nn.Module) -> nn.Module: - """ - Replaces nn.Linear and nn.LayerNorm with Transformer Engine's FP8-supported layers. - """ + """nn.Linear와 nn.LayerNorm을 Transformer Engine의 FP8 지원 레이어로 변환합니다.""" for name, child in module.named_children(): if isinstance(child, nn.Linear): setattr(module, name, te.Linear( @@ -28,185 +24,49 @@ def convert_to_fp8_layers(module: nn.Module) -> nn.Module: convert_to_fp8_layers(child) return module -class Adapter: - def __init__( - self, - adapt_model: nn.Module, - base_model: nn.Module, - optimizer: optim.Optimizer, - lr_scheduler, - distillation_loss_fn: nn.Module, - lm_loss_fn: nn.Module, - temperature: float = 2.0, - alpha: float = 0.5, - use_wandb: bool = True, - checkpoint_dir: str = "./checkpoints", - ): - self.adapt_model = adapt_model - - # Convert the teacher model's layers to FP8-supported layers. - print("INFO: Converting base_model to support FP8 with Transformer Engine...") - self.base_model = convert_to_fp8_layers(base_model) - print("INFO: Conversion complete.") +class DistillationTrainer(Trainer): + """ + KL Divergence Loss를 사용한 지식 증류(복원 학습)를 위한 커스텀 Trainer. + """ + def __init__(self, *args, teacher_model, alpha=1.0, temperature=2.0, **kwargs): + super().__init__(*args, **kwargs) - self.optimizer = optimizer - self.lr_scheduler = lr_scheduler - self.distillation_loss_fn = distillation_loss_fn - self.lm_loss_fn = lm_loss_fn - self.temperature = temperature + print("INFO: Teacher 모델을 FP8 지원 레이어로 변환합니다...") + self.teacher_model = convert_to_fp8_layers(teacher_model) + print("INFO: 변환 완료.") + self.alpha = alpha - self.use_wandb = use_wandb + self.temperature = temperature - self.base_model.eval() - for param in self.base_model.parameters(): + self.teacher_model.to(self.model.device) + self.teacher_model.eval() + for param in self.teacher_model.parameters(): param.requires_grad = False - self.checkpoint_dir = checkpoint_dir - if not os.path.exists(self.checkpoint_dir): - os.makedirs(self.checkpoint_dir, exist_ok=True) - self.best_eval_loss = float('inf') - self.train_dataloader_len = 0 - - def save_checkpoint(self, epoch: int, eval_loss: float = None): - checkpoint_path = os.path.join(self.checkpoint_dir, f"epoch_{epoch}_checkpoint.pth") - - torch.save({ - 'epoch': epoch, - 'model_state_dict': self.adapt_model.state_dict(), - 'optimizer_state_dict': self.optimizer.state_dict(), - 'lr_scheduler_state_dict': self.lr_scheduler.state_dict() if self.lr_scheduler else None, - 'eval_loss': eval_loss, - 'best_eval_loss': self.best_eval_loss - }, checkpoint_path) - print(f"INFO: Checkpoint saved to {checkpoint_path}") - - def save_best_model(self, eval_loss: float, epoch: int): - if eval_loss < self.best_eval_loss: - print(f"INFO: New best model found! Eval Loss: {eval_loss:.4f} (Previous: {self.best_eval_loss:.4f})") - self.best_eval_loss = eval_loss - best_model_path = os.path.join(self.checkpoint_dir, "best_model.pth") - torch.save({ - 'epoch': epoch, - 'model_state_dict': self.adapt_model.state_dict(), - 'eval_loss': eval_loss - }, best_model_path) - print(f"INFO: Best model saved to {best_model_path}") + def compute_loss(self, model, inputs, return_outputs=False): + """ + 손실 계산 로직: 학생 모델의 LM Loss와 선생님-학생 간 Distillation Loss를 조합. 
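+        (In English: combine the student's LM loss with the teacher-student
+        KL distillation loss as alpha * distill + (1 - alpha) * lm.)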
+        # Compute the student model's outputs and its base LM loss
+        student_outputs = model(**inputs)
+        loss_lm = student_outputs.loss
+        logits_student = student_outputs.logits
 
-    def train_step(self, batch: dict) -> tuple:
-        primary_device = next(self.adapt_model.parameters()).device
-        batch = {k: v.to(primary_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
-
-        self.adapt_model.train()
-        self.optimizer.zero_grad()
-
-        # Apply te.fp8_autocast only to the teacher model.
+        # Compute the teacher model's outputs quickly in FP8
         with torch.no_grad():
             with te.fp8_autocast(enabled=True):
-                base_outputs = self.base_model(**batch)
-                base_logits = base_outputs.logits
-
-        # The student model operates in its original precision (e.g., bfloat16).
-        adapt_outputs = self.adapt_model(**batch)
-        adapt_logits = adapt_outputs.logits
-
-        # Loss calculation
-        soft_targets = F.softmax(base_logits / self.temperature, dim=-1)
-        soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1)
-        distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2)
-
-        lm_loss = self.lm_loss_fn(
-            adapt_logits.view(-1, adapt_logits.size(-1)),
-            batch['labels'].view(-1)
-        )
-
-        total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss
-
-        total_loss.backward()
-        self.optimizer.step()
-        if self.lr_scheduler:
-            self.lr_scheduler.step()
-
-        return total_loss.item(), distillation_loss.item(), lm_loss.item()
-
-    def train_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int):
-        total_epoch_loss = 0
-        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs} [T]", unit="batch")
-
-        for step, batch in enumerate(progress_bar):
-            loss, dist_loss, lm_loss = self.train_step(batch)
-            total_epoch_loss += loss
-
-            progress_bar.set_postfix({
-                'loss': f'{loss:.4f}',
-                'dist_loss': f'{dist_loss:.4f}',
-                'lm_loss': f'{lm_loss:.4f}'
-            })
-
-            if self.use_wandb:
-                global_step = (epoch - 1) * len(dataloader) + step
-                wandb.log({
-                    "train/step_loss": loss,
-                    "train/distillation_loss": dist_loss,
-                    "train/lm_loss": lm_loss,
-                    "train/learning_rate": self.lr_scheduler.get_last_lr()[0]
-                }, step=global_step)
-
-        avg_epoch_loss = total_epoch_loss / len(dataloader)
-        print(f"Epoch [{epoch}/{num_epochs}] Train Avg Loss: {avg_epoch_loss:.4f}")
-
-        if self.use_wandb:
-            global_step = epoch * len(dataloader)
-            wandb.log({"train/epoch_loss": avg_epoch_loss}, step=global_step)
+            teacher_outputs = self.teacher_model(**inputs)
+            logits_teacher = teacher_outputs.logits
 
-    def evaluate_epoch(self, dataloader: DataLoader, epoch: int, num_epochs: int):
-        print("INFO: Starting evaluation...")
-        self.adapt_model.eval()
-        total_eval_loss = 0
-        primary_device = next(self.adapt_model.parameters()).device
-
-        progress_bar = tqdm(dataloader, desc=f"Epoch {epoch}/{num_epochs} [E]", unit="batch")
+        # Build the distillation loss using KL divergence
+        distillation_loss_fn = nn.KLDivLoss(reduction="batchmean")
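+        # (Note: nn.KLDivLoss expects log-probabilities as input and
+        # probabilities as target, hence log_softmax for the student and
+        # softmax for the teacher below.)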
 
-        with torch.no_grad():
-            for batch in progress_bar:
-                batch = {k: v.to(primary_device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
-
-                # Also apply te.fp8_autocast only to the teacher model during evaluation.
-                with te.fp8_autocast(enabled=True):
-                    base_outputs = self.base_model(**batch)
-                    base_logits = base_outputs.logits
-
-                adapt_outputs = self.adapt_model(**batch)
-                adapt_logits = adapt_outputs.logits
-
-                soft_targets = F.softmax(base_logits / self.temperature, dim=-1)
-                soft_prob = F.log_softmax(adapt_logits / self.temperature, dim=-1)
-                distillation_loss = self.distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2)
-
-                lm_loss = self.lm_loss_fn(
-                    adapt_logits.view(-1, adapt_logits.size(-1)),
-                    batch['labels'].view(-1)
-                )
-
-                total_loss = self.alpha * distillation_loss + (1. - self.alpha) * lm_loss
-
-                total_eval_loss += total_loss.item()
+        soft_targets = F.softmax(logits_teacher / self.temperature, dim=-1)
+        soft_prob = F.log_softmax(logits_student / self.temperature, dim=-1)
 
-        avg_eval_loss = total_eval_loss / len(dataloader)
-        print(f"Epoch [{epoch}/{num_epochs}] Eval Avg Loss: {avg_eval_loss:.4f}")
+        loss_distill = distillation_loss_fn(soft_prob, soft_targets) * (self.temperature ** 2)
 
-        if self.use_wandb:
-            global_step = epoch * self.train_dataloader_len
-            wandb.log({"eval/epoch_loss": avg_eval_loss}, step=global_step)
+        # Final loss: with alpha=1.0, training focuses solely on making the two models match
+        loss = (1. - self.alpha) * loss_lm + self.alpha * loss_distill
 
-        return avg_eval_loss
-
-    def run_training(self, train_dataloader: DataLoader, num_epochs: int, eval_dataloader: DataLoader = None):
-        self.train_dataloader_len = len(train_dataloader)
-        for epoch in range(1, num_epochs + 1):
-            self.train_epoch(train_dataloader, epoch, num_epochs)
-            if eval_dataloader:
-                avg_eval_loss = self.evaluate_epoch(eval_dataloader, epoch, num_epochs)
-                self.save_best_model(avg_eval_loss, epoch)
-                self.save_checkpoint(epoch, avg_eval_loss)
-            else:
-                self.save_checkpoint(epoch)
\ No newline at end of file
+        return (loss, student_outputs) if return_outputs else loss
\ No newline at end of file

From 204078b0575e2e56fc551dd8158bd174d65e2a16 Mon Sep 17 00:00:00 2001
From: njhvrta
Date: Sat, 6 Sep 2025 03:05:15 +0900
Subject: [PATCH 13/13] feat(engine): support embedding reuse, layer trimming,
 and config merge

- attach base model input embeddings to Titans module for consistency
- add support for trimming leading transformer layers based on size limit
- refine config merge logic to replace with resized model config
---
 retentionengine/adapters/engine.py | 362 ++++++++++++++++++++++++-----
 1 file changed, 298 insertions(+), 64 deletions(-)

diff --git a/retentionengine/adapters/engine.py b/retentionengine/adapters/engine.py
index 9bdf824..d1e7291 100644
--- a/retentionengine/adapters/engine.py
+++ b/retentionengine/adapters/engine.py
@@ -1,18 +1,126 @@
 from thelethe.titans import PretrainedTitansConfig, PreTrainedTitansModel
-from transformers import PreTrainedModel, AutoTokenizer
+from transformers import PreTrainedModel, AutoTokenizer, AutoConfig
 from torch import nn
 from torch.optim import AdamW
+import torch
+from typing import Optional, Dict, Any
+import math
 
-from ..utils import Adapter
-from ..datasets import get_pg19_dataloader
+
+class RetentionEngineConfig(PretrainedTitansConfig):
+    """Extended configuration with size management"""
+
+    def __init__(
+        self,
+        max_total_params: float = 8e9,  # 8B
+        base_model_name: str = None,
+        base_model_params: float = None,
+        layers_to_remove: int = 0,
+        original_num_layers: int = None,
+        module_params: float = None,
+        **kwargs
+    ):
+        super().__init__(**kwargs)
+        self.max_total_params = max_total_params
+        self.base_model_name = base_model_name
+        self.base_model_params = base_model_params
+        self.layers_to_remove = layers_to_remove
+        self.original_num_layers = original_num_layers
+        self.module_params = module_params
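+        # (Note: max_total_params is a parameter count, not a byte size;
+        # an 8e9-parameter budget corresponds to roughly 16 GB of weights
+        # in bf16.)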
+
+
+class ModelSizeManager:
+    """Utility class for model size calculation and management"""
+
+    @staticmethod
+    def count_parameters(model: nn.Module, exclude_embeddings: bool = False) -> int:
+        """Count total parameters in a model"""
+        total = 0
+        for name, param in model.named_parameters():
+            if exclude_embeddings and ('embed' in name.lower() or 'lm_head' in name.lower()):
+                continue
+            total += param.numel()
+        return total
+
+    @staticmethod
+    def estimate_layer_params(model: PreTrainedModel) -> int:
+        """Estimate parameters per transformer layer"""
+        config = model.config
+
+        # Get model architecture details
+        hidden_size = getattr(config, 'hidden_size', 4096)
+        intermediate_size = getattr(config, 'intermediate_size', hidden_size * 4)
+        num_attention_heads = getattr(config, 'num_attention_heads', 32)
+        num_key_value_heads = getattr(config, 'num_key_value_heads', num_attention_heads)
+
+        # Estimate parameters per layer
+        # Self-attention: Q, K, V, O projections
+        attn_params = (
+            hidden_size * hidden_size +  # Q
+            hidden_size * (hidden_size * num_key_value_heads // num_attention_heads) * 2 +  # K, V
+            hidden_size * hidden_size  # O
+        )
+
+        # MLP: gate, up, down projections
+        mlp_params = (
+            hidden_size * intermediate_size * 2 +  # gate, up
+            intermediate_size * hidden_size  # down
+        )
+
+        # Layer norms
+        norm_params = hidden_size * 2 * 2  # 2 layer norms, weight + bias
+
+        return attn_params + mlp_params + norm_params
+
+    @staticmethod
+    def calculate_layers_to_remove(
+        base_model: PreTrainedModel,
+        module_params: int,
+        max_total_params: float = 8e9
+    ) -> int:
+        """Calculate how many layers to remove to stay under the size limit"""
+
+        total_base_params = ModelSizeManager.count_parameters(base_model)
+
+        if total_base_params + module_params <= max_total_params:
+            return 0
+
+        # Need to remove layers
+        excess_params = (total_base_params + module_params) - max_total_params
+        params_per_layer = ModelSizeManager.estimate_layer_params(base_model)
+
+        layers_to_remove = math.ceil(excess_params / params_per_layer)
+
+        # Ensure we don't remove too many layers
+        num_layers = getattr(base_model.config, 'num_hidden_layers', 32)
+        max_removable = max(0, num_layers - 4)  # Keep at least 4 layers
+
+        return min(layers_to_remove, max_removable)
 
 
 class RetentionEngine(nn.Module):
-    def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig):
+    def __init__(
+        self,
+        basemodel: PreTrainedModel,
+        config: RetentionEngineConfig,
+        auto_resize: bool = True
+    ):
         super().__init__()
+
+        self.original_basemodel = basemodel
+        self.auto_resize = auto_resize
+
         self.module = PreTrainedTitansModel(config)
+
+        module_params = ModelSizeManager.count_parameters(self.module)
+        config.module_params = module_params
+
+        if auto_resize:
+            basemodel = self._resize_basemodel(basemodel, config, module_params)
+
         self.module.model = basemodel
-        
+        self._attach_embeddings()
+
         # Combine the configurations
         self.config = config
         if hasattr(basemodel.config, 'to_dict'):
@@ -23,74 +131,200 @@ def __init__(self, basemodel: PreTrainedModel, config: PretrainedTitansConfig):
             for key, value in basemodel.config.__dict__.items():
                 if not key.startswith('_'):  # exclude private attributes
                     self.config.__dict__[key] = value
-        
+
+    def _resize_basemodel(
+        self,
+        basemodel: PreTrainedModel,
+        config: RetentionEngineConfig,
+        module_params: int
+    ) -> PreTrainedModel:
+        """Resize base model by removing layers if necessary"""
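+        # Illustrative arithmetic (hypothetical numbers): a 13.0B base model
+        # plus a 0.7B Titans module against an 8e9 cap leaves 5.7e9 excess
+        # parameters; at ~0.2B parameters per layer, ceil(5.7e9 / 0.2e9) = 29
+        # leading layers would be trimmed, subject to the 4-layer floor.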
+
+        # Calculate base model size
+        base_params = ModelSizeManager.count_parameters(basemodel)
+        config.base_model_params = base_params
+        config.original_num_layers = getattr(basemodel.config, 'num_hidden_layers', None)
+
+        # Check if resizing is needed
+        total_params = base_params + module_params
+
+        print(f"Base model params: {base_params/1e9:.2f}B")
+        print(f"Module params: {module_params/1e9:.2f}B")
+        print(f"Total params: {total_params/1e9:.2f}B")
+
+        if total_params <= config.max_total_params:
+            print(f"Model size is within limit ({config.max_total_params/1e9:.1f}B). No resizing needed.")
+            config.layers_to_remove = 0
+            return basemodel
+
+        # Calculate layers to remove
+        layers_to_remove = ModelSizeManager.calculate_layers_to_remove(
+            basemodel,
+            module_params,
+            config.max_total_params
+        )
+
+        if layers_to_remove == 0:
+            print("Cannot reduce model size further while maintaining minimum layers.")
+            return basemodel
+
+        print(f"Removing {layers_to_remove} layers to meet size constraint...")
+        config.layers_to_remove = layers_to_remove
+
+        # Remove layers from model
+        resized_model = self._remove_layers(basemodel, layers_to_remove)
+
+        # Recalculate size
+        new_base_params = ModelSizeManager.count_parameters(resized_model)
+        new_total = new_base_params + module_params
+        print(f"Resized base model: {new_base_params/1e9:.2f}B")
+        print(f"New total size: {new_total/1e9:.2f}B")
+
+        return resized_model
+
+    def _remove_layers(self, model: PreTrainedModel, num_layers: int) -> PreTrainedModel:
+        """Remove specified number of layers from the front of the model"""
+
+        # Access the transformer layers
+        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+            layers = model.model.layers
+        else:
+            print("Warning: Could not identify layer structure. Returning original model.")
+            return model
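+        # (Note: the model.model.layers path matches Llama-style causal LM
+        # wrappers; other layouts, e.g. GPT-2's transformer.h, would need
+        # their own branch here.)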
+
+        total_layers = len(layers)
+        if num_layers >= total_layers - 4:
+            print(f"Warning: Cannot remove {num_layers} layers. Keeping the minimum of 4 layers.")
+            num_layers = max(0, total_layers - 4)
+            if num_layers == 0:
+                return model
+
+        new_layers = nn.ModuleList(layers[num_layers:])
+
+        # Replace layers in model
+        if hasattr(model, 'model') and hasattr(model.model, 'layers'):
+            model.model.layers = new_layers
+
+        # Update config
+        if hasattr(model.config, 'num_hidden_layers'):
+            model.config.num_hidden_layers = len(new_layers)
+        elif hasattr(model.config, 'n_layer'):
+            model.config.n_layer = len(new_layers)
+
+        print(f"Layers reduced from {total_layers} to {len(new_layers)}")
+
+        return model
+
+    def _merge_configs(self, base_config):
+        """Merge base model configuration with module configuration"""
+        if hasattr(base_config, 'to_dict'):
+            config_dict = base_config.to_dict()
+            for key, value in config_dict.items():
+                if key not in ['num_hidden_layers', 'n_layer']:  # Don't override layer count
+                    self.config.__dict__[key] = value
+        elif hasattr(base_config, '__dict__'):
+            for key, value in base_config.__dict__.items():
+                if not key.startswith('_') and key not in ['num_hidden_layers', 'n_layer']:
+                    self.config.__dict__[key] = value
+
+    def _attach_embeddings(self):
+        """Attach the input embeddings from the base model to the adapter model"""
+        if self.module.model is not None and hasattr(self.module.model, 'get_input_embeddings'):
+            input_embeddings = self.module.model.get_input_embeddings()
+            self.module.set_input_embeddings(input_embeddings)
+
     def forward(self, *args, **kwargs):
         return self.module(*args, **kwargs)
-    
-    def adapt(self, 
-        tokenizer: AutoTokenizer,
-        epochs: int = 3,
-        batch_size: int = 1,
-        max_length: int = 8192,
-        learning_rate: float = 2e-5,
-        
-        alpha: float = 0.3,
-        temperature: float = 2.0
-    ):
-        
-        device = self.module.device
-        adapt_model = self.module
-        base_model = self.module.model
-        base_model.eval()
-        for param in base_model.parameters():
-            param.requires_grad = False
-        for param in adapt_model.attention_module.parameters():
-            param.requires_grad = False # dont trian attention module
-        adapt_model.train()
-        
-        train_dataloader = get_pg19_dataloader(tokenizer, batch_size, 'train', max_length)
-        eval_dataloader = get_pg19_dataloader(tokenizer, batch_size, 'validation', max_length)
-        
-        optimizer = AdamW(adapt_model.parameters(), lr=learning_rate)
-        
-        trainer = Adapter(
-            adapt_model=adapt_model,
-            base_model=base_model,
-            optimizer=optimizer,
-            device=device,
-            alpha=alpha,
-            temperature=temperature
-        )
-        
-        for epoch in range(epochs):
-            print(f"\n--- Epoch {epoch + 1}/{epochs} ---")
-            train_loss = trainer.train_epoch(train_dataloader)
-            eval_loss = trainer.eval_epoch(eval_dataloader)
-            print(f"Epoch {epoch + 1} | Train Loss: {train_loss:.4f} | Eval Loss: {eval_loss:.4f}")
+
+    def get_model_size_info(self) -> Dict[str, Any]:
+        """Get detailed size information about the model"""
+        base_params = ModelSizeManager.count_parameters(self.module.model) if self.module.model else 0
+        # The base model is registered as a submodule of self.module, so
+        # subtract it here to avoid counting its parameters twice.
+        module_params = ModelSizeManager.count_parameters(self.module) - base_params
+        total_params = base_params + module_params
 
-        
-        
+        return {
+            'base_model_params': base_params,
+            'base_model_params_b': base_params / 1e9,
+            'module_params': module_params,
+            'module_params_b': module_params / 1e9,
+            'total_params': total_params,
+            'total_params_b': total_params / 1e9,
+            'layers_removed': self.config.layers_to_remove,
+            'original_layers': self.config.original_num_layers,
+            'current_layers': getattr(self.module.model.config, 'num_hidden_layers', None) if self.module.model else None
+        }
+
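+    # Usage sketch (illustrative):
+    #   info = engine.get_model_size_info()
+    #   print(f"Total: {info['total_params'] / 1e9:.2f}B parameters")
+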
     @classmethod
     def from_pretrained(
-        cls, 
-        model_id: str, 
-        basemodel: PreTrainedModel
+        cls,
+        model_id: str,
+        basemodel: PreTrainedModel,
+        auto_resize: bool = True,
+        max_total_params: float = 8e9
     ) -> "RetentionEngine":
-        """
-        Load a RetentionEngine from a pretrained model and configuration.
-        """
-        config = PretrainedTitansConfig.from_pretrained(model_id)
+        """Load a RetentionEngine from a pretrained model and configuration"""
+
+        # Load config with size constraints
+        config = RetentionEngineConfig.from_pretrained(model_id)
+        config.max_total_params = max_total_params
+        config.base_model_name = getattr(basemodel.config, '_name_or_path', 'unknown')
+
+        # Create engine with auto-resizing
+        engine = cls(basemodel, config, auto_resize=auto_resize)
+
+        # Load pretrained module weights and reattach the (possibly resized) base model
         module = PreTrainedTitansModel.from_pretrained(model_id)
-        module.model = basemodel
-        engine = cls(basemodel, config)
+        module.model = engine.module.model  # use the resized model, not the original
        engine.module = module
+
         return engine
-    
+
     def save_pretrained(self, save_directory: str):
-        """
-        Save the RetentionEngine to a directory.
-        """
+        """Save the RetentionEngine to a directory"""
+        # Save size information in config
+        size_info = self.get_model_size_info()
+        for key, value in size_info.items():
+            setattr(self.config, f'size_{key}', value)
+
         self.module.save_pretrained(save_directory)
         self.config.save_pretrained(save_directory)
-        
\ No newline at end of file
+
+        # Save size report
+        import json
+        with open(f"{save_directory}/size_report.json", 'w') as f:
+            json.dump(size_info, f, indent=2)
+
+
+# Example usage
+if __name__ == "__main__":
+    from transformers import AutoModelForCausalLM
+
+    # Load a base model (e.g., 13B model)
+    base_model = AutoModelForCausalLM.from_pretrained(
+        "meta-llama/Llama-2-13b-hf",
+        torch_dtype=torch.bfloat16,
+        device_map="auto"
+    )
+
+    # Create config with 8B size limit
+    config = RetentionEngineConfig(
+        max_total_params=8e9,  # 8B max
+        base_model_name="meta-llama/Llama-2-13b-hf",
+        # Add other Titans config parameters here
+    )
+
+    # Create engine with automatic resizing
+    engine = RetentionEngine(
+        basemodel=base_model,
+        config=config,
+        auto_resize=True  # Automatically remove layers if needed
+    )
+
+    # Check final size
+    size_info = engine.get_model_size_info()
+    print("\nFinal Model Size Information:")
+    for key, value in size_info.items():
+        print(f"  {key}: {value}")
+
+    # Save the resized model
+    engine.save_pretrained("./retention_engine_8b")
\ No newline at end of file