From 65a09aad6930fc8302a6133b9fc52830861e2818 Mon Sep 17 00:00:00 2001 From: Dylan Hillier Date: Mon, 19 Aug 2024 17:13:25 +0800 Subject: [PATCH 1/3] fixes the initialize --- models/build_models.py | 191 ++++++++++++------ models/cast_configs.py | 30 +++ models/components/layers/activations.py | 104 +++++++--- models/components/layers/attention.py | 82 +++++--- models/components/layers/feedforward.py | 100 +++++++-- models/components/layers/normalization.py | 80 ++++++-- .../components/layers/transformer_blocks.py | 29 +-- models/components/positional_encoding.py | 62 ++++-- models/components/tokenizers/base_class.py | 10 +- models/components/tokenizers/bpe.py | 3 +- models/core_models.py | 51 ++++- models/embedding_models.py | 52 ++++- .../byte_level/byte_model_shell.py | 34 +++- .../byte_level/embedding_model.py | 120 ++++++----- models/experimental/byte_level/layers.py | 69 ++++--- models/experimental/byte_level/model_heads.py | 65 +++--- models/experimental/hugging_face.py | 85 +++++--- .../experimental/next_thought/core_models.py | 38 ++-- .../next_thought/embedding_models.py | 85 +++++--- models/experimental/next_thought/layers.py | 72 ++++--- .../experimental/next_thought/model_heads.py | 35 ++-- models/generator.py | 12 +- models/model_heads.py | 77 +++++-- models/model_shell.py | 35 +++- models/utils.py | 18 +- 25 files changed, 1042 insertions(+), 497 deletions(-) create mode 100644 models/cast_configs.py diff --git a/models/build_models.py b/models/build_models.py index 36236091..8d39f772 100644 --- a/models/build_models.py +++ b/models/build_models.py @@ -3,17 +3,24 @@ core model, lm head and the model shell. """ +from models import model_shell +from models.cast_configs import ModelShellConfigMap from models.core_models import GenericFFNSharedTransfomer, GenericTransformer from models.embedding_models import GenericEmbedder +from models.experimental.byte_level.byte_model_shell import ( + ByteModelShell, + ByteShellConfig, +) from models.experimental.byte_level.embedding_model import ByteLevelEmbedder from models.experimental.byte_level.model_heads import ByteLevelDecoder -from models.experimental.byte_level.byte_model_shell import ByteModelShell from models.experimental.hugging_face import HFEmbedder, HFLMHead, HFTransformerCore +from models.experimental.next_thought.core_models import ( + BaselineCoreModel, + Conv1dCoreModel, +) from models.experimental.next_thought.embedding_models import HierarchicalEncoder from models.experimental.next_thought.model_heads import VariableLengthLatentDecoder -from models.experimental.next_thought.core_models import BaselineCoreModel, Conv1dCoreModel from models.model_heads import AutoregressiveLMHead -from models.model_shell import ModelShell def build_model(model_cfg=None, checkpoint=None): @@ -44,15 +51,17 @@ def build_model(model_cfg=None, checkpoint=None): return model -EMBEDDING_MODEL_DICT = { - "generic": GenericEmbedder, - "byte_level": ByteLevelEmbedder, +EMBEDDER_DICT = { + "generic": GenericEmbedder, "hf_embedder": HFEmbedder, - "hierarchical": HierarchicalEncoder, - } + "nt_embedder": HierarchicalEncoder, + "byte_embedder": ByteLevelEmbedder, +} -def build_embedding_model(model_cfg): +def build_embedding_model( + model_cfg: ModelShellConfigMap | ByteModelShell, +) -> GenericEmbedder: """ Given the embedding model config, build it. Args: @@ -60,9 +69,34 @@ def build_embedding_model(model_cfg): Returns: embedding_model: embedding_model_instance """ - return EMBEDDING_MODEL_DICT[model_cfg["embedder"]["embedding_model_type"]]( - model_cfg=model_cfg - ) + embedder_cfg = model_cfg.embedding_model + embedder_type = model_cfg.embedding_model.embedding_model_type + match embedder_type: + case "byte_embedder": + return ByteLevelEmbedder( + embedder_cfg=embedder_cfg, + byte_cfg=model_cfg, + hidden_dim=model_cfg.hidden_dim, + vocab_size=model_cfg.vocab_size, + ) + case "hf_embedder": + return HFEmbedder(model_cfg=embedder_cfg) + case "nt_embedder": + return HierarchicalEncoder( + embedder_cfg=embedder_cfg, + vocab_size=model_cfg.vocab_size, + hidden_dim=model_cfg.hidden_dim, + context_window=model_cfg.context_window, + positional_encoding_type=model_cfg.positional_encoding_type, + ) + case "generic": + return GenericEmbedder( + embedder_cfg=embedder_cfg, + vocab_size=model_cfg.vocab_size, + hidden_dim=model_cfg.hidden_dim, + context_window=model_cfg.context_window, + positional_encoding_type=model_cfg.positional_encoding_type, + ) CORE_MODEL_DICT = { @@ -70,11 +104,13 @@ def build_embedding_model(model_cfg): "generic_ffn_sharing": GenericFFNSharedTransfomer, "hf_core": HFTransformerCore, "next_thought_baseline": BaselineCoreModel, - "conv": Conv1dCoreModel + "conv": Conv1dCoreModel, } -def build_core_model(model_cfg): +def build_core_model( + model_cfg: model_shell.ModelShellConfig | ByteModelShell, +) -> GenericTransformer: """ Given the core model config, build it. Args: @@ -82,43 +118,69 @@ def build_core_model(model_cfg): Returns: core_model: core_model_instance """ - return CORE_MODEL_DICT[model_cfg["core_model"]["core_model_type"]]( - model_cfg=model_cfg - ) + core_model_cfg = model_cfg.core_model + core_model_type = core_model_cfg.core_model_type + match core_model_type: + case "generic": + return GenericTransformer( + hidden_dim=model_cfg.hidden_dim, + context_window=model_cfg.context_window, + core_model_cfg=core_model_cfg, + ) + case "generic_ffn_sharing": + return GenericFFNSharedTransfomer( + hidden_dim=model_cfg.hidden_dim, + context_window=model_cfg.context_window, + core_model_cfg=core_model_cfg, + ) + case "hf_core": + return HFTransformerCore(model_cfg=core_model_cfg) + case "next_thought_baseline": + return BaselineCoreModel(model_cfg=core_model_cfg) + case "conv": + return Conv1dCoreModel() MODEL_HEAD_DICT = { - "generic": lambda model_cfg, embedding_model: AutoregressiveLMHead(model_cfg=model_cfg), - "byte_level": lambda model_cfg, embedding_model: ByteLevelDecoder(model_cfg=model_cfg), - "hf_head": lambda model_cfg, embedding_model: HFLMHead(model_cfg=model_cfg), - "latent_2_seq": lambda model_cfg, embedding_model: VariableLengthLatentDecoder( - model_cfg=model_cfg, - embedding_model=embedding_model - ), - } + "hf_lm_head": HFLMHead, + "nt_lm_head": VariableLengthLatentDecoder, + "byte_lm_head": ByteLevelDecoder, + "generic": AutoregressiveLMHead, +} -def build_model_head(model_cfg, embedding_model=None): +def build_model_head(model_cfg: ModelShellConfigMap, embedding_model: GenericEmbedder): """ - Given the lm head config, build it. + Given the model head config, build it. Args: model_cfg: model_cfg + embedding_model: embedding_model_instance Returns: model_head: model_head_instance """ - return MODEL_HEAD_DICT[model_cfg["lm_head"]["lm_head_type"]]( - model_cfg=model_cfg, - embedding_model=embedding_model - ) - - -MODEL_SHELL_DICT = { - "standard": ModelShell, - "byte_shell": ByteModelShell -} - - -def build_model_shell(model_cfg, embedding_model, core_model, model_head): + model_head_cfg = model_cfg.model_head + model_head_type = model_head_cfg.model_head_type + match model_head_type: + case "hf_lm_head": + return HFLMHead(model_cfg=model_head_cfg) + case "nt_lm_head": + return VariableLengthLatentDecoder( + model_cfg=model_head_cfg, embedding_model=embedding_model + ) + case "generic": + return AutoregressiveLMHead( + hidden_dim=model_cfg.hidden_dim, + vocab_size=model_cfg.vocab_size, + lm_head_cfg=model_head_cfg, + ) + + +MODEL_SHELL_DICT = {"standard": model_shell.ModelShell, "byte_shell": ByteModelShell} + + +def build_model_shell( + model_cfg: model_shell.ModelShellConfig | ByteShellConfig, +): """ Given the model shell config, build it. Args: @@ -126,12 +188,31 @@ def build_model_shell(model_cfg, embedding_model, core_model, model_head): Returns: model_shell: model_shell_instance """ - return MODEL_SHELL_DICT[model_cfg["model_shell_type"]]( - embedding_model=embedding_model, core_model=core_model, model_head=model_head - ) + model_shell_type = model_cfg.model_shell_type + # build the embedding model + embedding_model = build_embedding_model(model_cfg=model_cfg) + # build the core model + core_model = build_core_model(model_cfg=model_cfg) -def initialize_model(model_cfg): + # build the model head + model_head = build_model_head(model_cfg=model_cfg, embedding_model=embedding_model) + match model_shell_type: + case "standard": + return model_shell.ModelShell( + embedding_model=embedding_model, + core_model=core_model, + model_head=model_head, + ) + case "byte_shell": + return ByteModelShell( + embedding_model=embedding_model, + core_model=core_model, + model_head=model_head, + ) + + +def initialize_model(model_dict: dict): """ Initialize the model given the configuration. Args: @@ -139,30 +220,16 @@ def initialize_model(model_cfg): Returns: model: model_instance """ - # build the embedding model - embedding_model = build_embedding_model(model_cfg=model_cfg) - - # build the core model - core_model = build_core_model(model_cfg=model_cfg) - - # build the model head - model_head = build_model_head( + model_cfg = ModelShellConfigMap(**model_dict) + model = build_model_shell( model_cfg=model_cfg, - embedding_model=embedding_model ) - # check if embedding model weights are to be shared with the model head - if model_cfg["embedding_weight_tying"]: + if model_cfg.embedding_weight_tying: # share the weights between the token embeddings and the final # logit layer, following: https://paperswithcode.com/method/weight-tying - embedding_model.token_embedder.weight = model_head.linear.weight + model.embedding_model.token_embedder.weight = model.model_head.linear.weight # build the model shell - model = build_model_shell( - model_cfg=model_cfg, - embedding_model=embedding_model, - core_model=core_model, - model_head=model_head, - ) return model diff --git a/models/cast_configs.py b/models/cast_configs.py new file mode 100644 index 00000000..54ee3e3e --- /dev/null +++ b/models/cast_configs.py @@ -0,0 +1,30 @@ +"""Pseudo Configs for Model Building/Casting""" + +from typing import Literal + +from models import core_models, embedding_models, model_heads, model_shell +from models.experimental import hugging_face +from models.experimental.next_thought import core_models as nt_core_models +from models.experimental.next_thought import embedding_models as nt_embedding_models + + +class ModelShellConfigMap(model_shell.ModelShellConfig): + """Config for the standard model shell""" + + model_shell_type: Literal["standard"] + core_model: ( + hugging_face.CoreModelConfig + | nt_core_models.CoreModelConfig + | core_models.CoreModelConfig + ) + embedding_model: ( + nt_embedding_models.HierarchicalEncoderConfig + | hugging_face.HFEmbedderConfig + | embedding_models.GenericEmbedderConfig + ) + model_head: hugging_face.HFLMHeadConfig | model_heads.LMHeadConfig + hidden_dim: int + context_window: int + vocab_size: int + embedding_weight_tying: bool + positional_encoding_type: str diff --git a/models/components/layers/activations.py b/models/components/layers/activations.py index e9b89868..2fb719ee 100644 --- a/models/components/layers/activations.py +++ b/models/components/layers/activations.py @@ -2,8 +2,11 @@ A collection of common activation functions. """ +import enum + import torch + class LearnedActivation(torch.nn.Module): def __init__(self, hidden_size=10): super(LearnedActivation, self).__init__() @@ -13,20 +16,52 @@ def __init__(self, hidden_size=10): def initialize_weights(self): # Initialize weights to the learned parameters - self.fc1.weight.data = torch.tensor([[-0.3478], - [-0.3444], - [-0.9863], - [-0.8657], - [-0.0148], - [ 0.1085], - [-0.5282], - [-0.1138], - [-1.1070], - [-0.1035]]) - self.fc1.bias.data = torch.tensor([ 1.4480, 1.4610, -0.8526, 0.0151, -0.1249, -0.7658, 2.2386, -0.8884, 1.0032, -0.6235]) - self.fc2.weight.data = torch.tensor([[-0.4762, -1.2194, 0.4155, 0.3927, -0.2778, 0.0986, -0.9284, 0.2070, 0.3586, -0.2143]]) + self.fc1.weight.data = torch.tensor( + [ + [-0.3478], + [-0.3444], + [-0.9863], + [-0.8657], + [-0.0148], + [0.1085], + [-0.5282], + [-0.1138], + [-1.1070], + [-0.1035], + ] + ) + self.fc1.bias.data = torch.tensor( + [ + 1.4480, + 1.4610, + -0.8526, + 0.0151, + -0.1249, + -0.7658, + 2.2386, + -0.8884, + 1.0032, + -0.6235, + ] + ) + self.fc2.weight.data = torch.tensor( + [ + [ + -0.4762, + -1.2194, + 0.4155, + 0.3927, + -0.2778, + 0.0986, + -0.9284, + 0.2070, + 0.3586, + -0.2143, + ] + ] + ) self.fc2.bias.data = torch.tensor([4.1740]) - + def forward(self, x): # Flatten the input to apply the learned activation element-wise orig_shape = x.shape @@ -36,16 +71,19 @@ def forward(self, x): return x.view(orig_shape) # Reshape back to original shape -ACTIVATIONS_DICT = { - "gelu": torch.nn.GELU(), - "relu": torch.nn.ReLU(), - "leakyrelu": torch.nn.LeakyReLU(), - "tanh": torch.nn.Tanh(), - "sigmoid": torch.nn.Sigmoid(), - "silu": torch.nn.SiLU(), - "learned": LearnedActivation(hidden_size=10), - "none": torch.nn.Identity(), -} +class ActivationType(enum.Enum): + """ + Enum for the different types of activations + """ + + GELU = "gelu" + RELU = "relu" + LEAKYRELU = "leakyrelu" + TANH = "tanh" + SIGMOID = "sigmoid" + SILU = "silu" + LEARNED = "learned" + NONE = "none" def build_activation(activation_name: str): @@ -57,4 +95,22 @@ def build_activation(activation_name: str): Returns: activation: torch.nn.Module """ - return ACTIVATIONS_DICT[activation_name.lower()] + match activation_name: + case "gelu": + return torch.nn.GELU() + case "relu": + return torch.nn.ReLU() + case "leakyrelu": + return torch.nn.LeakyReLU() + case "tanh": + return torch.nn.Tanh() + case "sigmoid": + return torch.nn.Sigmoid() + case "silu": + return torch.nn.SiLU() + case "learned": + return LearnedActivation(hidden_size=10) + case "none": + return torch.nn.Identity() + case _: + raise ValueError("Invalid activation function") diff --git a/models/components/layers/attention.py b/models/components/layers/attention.py index 41b06b37..bc48a16e 100644 --- a/models/components/layers/attention.py +++ b/models/components/layers/attention.py @@ -2,8 +2,27 @@ A collection of attention layers. """ +import enum + +import pydantic import torch +from models.components.layers import normalization + + +class AttentionConfig(pydantic.BaseModel): + """ + Attention configuration + """ + + attn_type = "generic" + num_heads: int + bias: bool + use_rope: bool + is_causal: bool + group_size: int + normalization: str + class Attention(torch.nn.Module): """ @@ -12,15 +31,14 @@ class Attention(torch.nn.Module): def __init__( self, - hidden_dim, - num_heads, - bias, - use_rope, - context_window, - is_causal, - group_size, + attn_config: AttentionConfig, + hidden_dim: int, + context_window: int, ): super().__init__() + bias = attn_config.bias + num_heads = attn_config.num_heads + group_size = attn_config.group_size assert hidden_dim % num_heads == 0, "Hidden dim must be divisible by num heads" # key, query, value projections for all heads @@ -34,17 +52,23 @@ def __init__( # attention dropout self.attn_dropout = torch.nn.Dropout() - self.num_heads = num_heads - self.group_size = group_size - self.is_causal = is_causal + self.num_heads = attn_config.num_heads + self.group_size = attn_config.group_size + self.is_causal = attn_config.is_causal # rope - self.use_rope = use_rope + self.use_rope = attn_config.use_rope + if self.use_rope: assert context_window % 2 == 0 self.freqs_cis = compute_freqs_cis( seq_len=context_window, head_dim=hidden_dim // num_heads ) + self.normalization = normalization.build_normalization( + normalization_name=attn_config.normalization, + dim=hidden_dim, + bias=attn_config.bias, + ) def forward(self, x, attention_mask=None): """ @@ -91,7 +115,7 @@ def forward(self, x, attention_mask=None): # output projection y = self.attn_dropout(self.c_proj(y)) # is this really necessary? - + y = self.normalization(y) return y @@ -126,20 +150,15 @@ def compute_freqs_cis(seq_len, head_dim): return freqs_cis -ATTENTION_DICT = { - "generic": lambda hidden_dim, context_window, use_rope, attn_cfg: Attention( - hidden_dim=hidden_dim, - num_heads=attn_cfg["num_heads"], - bias=attn_cfg["bias"], - use_rope=use_rope, - context_window=context_window, - is_causal=attn_cfg["is_causal"], - group_size=attn_cfg["group_size"], - ) -} +class AttentionMechanisms(enum.Enum): + """ + Enum for the different attention mechanisms + """ + + GENERIC = "generic" -def build_attention(hidden_dim, context_window, use_rope, attn_cfg): +def build_attention(hidden_dim: int, context_window: int, attn_cfg: AttentionConfig): """ Build an attention layer @@ -149,9 +168,12 @@ def build_attention(hidden_dim, context_window, use_rope, attn_cfg): use_rope: whether to use rope attn_cfg: attention config """ - return ATTENTION_DICT[attn_cfg["attn_type"]]( - hidden_dim=hidden_dim, - context_window=context_window, - use_rope=use_rope, - attn_cfg=attn_cfg, - ) + match attn_cfg.attn_type: + case AttentionMechanisms.GENERIC: + return Attention( + attn_config=attn_cfg, + hidden_dim=hidden_dim, + context_window=context_window, + ) + case _: + raise ValueError("Invalid attention mechanism") diff --git a/models/components/layers/feedforward.py b/models/components/layers/feedforward.py index 580f5deb..c430c117 100644 --- a/models/components/layers/feedforward.py +++ b/models/components/layers/feedforward.py @@ -2,10 +2,51 @@ A collection of FFN blocks """ +import enum + +import pydantic import torch import torch.nn.functional as F from models.components.layers.activations import build_activation +from models.components.layers.normalization import build_normalization + + +class FFNTypes(str, enum.Enum): + """ + Types of FFNs + """ + + GENERIC = "generic" + SWIGLU = "swiglu" + + +class FFNConfig(pydantic.BaseModel): + """ + Feedforward network configuration + """ + + ffn_dim: int + bias: bool + ffn_type: FFNTypes + normalization: str + + +class GenericFFNConfig(FFNConfig): + """ + Feedforward network configuration + """ + + ffn_type: FFNTypes.GENERIC + ffn_activation: str + + +class SwiGLUFFNConfig(FFNConfig): + """ + Feedforward network configuration + """ + + ffn_type: FFNTypes.SWIGLU class GenericFFN(torch.nn.Module): @@ -16,22 +57,30 @@ class GenericFFN(torch.nn.Module): def __init__( self, hidden_dim, - ffn_dim, - bias, - ffn_activation, + ffn_config: GenericFFNConfig, ): super().__init__() # build the ffn block - self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias) + self.linear_1 = torch.nn.Linear( + hidden_dim, ffn_config.ffn_dim, bias=ffn_config.bias + ) - self.activation = build_activation(activation_name=ffn_activation) + self.activation = build_activation(activation_name=ffn_config.ffn_activation) - self.linear_2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=bias) + self.linear_2 = torch.nn.Linear( + ffn_config.ffn_dim, hidden_dim, bias=ffn_config.bias + ) + self.normalization = build_normalization( + normalization_name=ffn_config.normalization, + dim=hidden_dim, + bias=ffn_config.bias, + ) def forward(self, x): """ A simple forward pass through the FFN """ + x = self.normalization(x) x = self.linear_1(x) x = self.activation(x) x = self.linear_2(x) @@ -50,41 +99,48 @@ class SwiGLUFFN(torch.nn.Module): def __init__( self, hidden_dim, - ffn_dim, - bias, + ffn_config: FFNConfig, ): super().__init__() # build the linear functions + ffn_dim, bias = ffn_config.ffn_dim, ffn_config.bias self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias) self.linear_2 = torch.nn.Linear(ffn_dim, hidden_dim, bias=bias) self.linear_3 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias) + self.normalization = build_normalization( + normalization_name=ffn_config.normalization, + dim=hidden_dim, + bias=ffn_config.bias, + ) def forward(self, x): """ A simple forward pass through the FFN """ + x = self.normalization(x) return self.linear_2(F.silu(self.linear_1(x)) * self.linear_3(x)) -FFN_DICT = { - "generic": lambda hidden_dim, ffn_cfg: GenericFFN( - hidden_dim=hidden_dim, - ffn_dim=ffn_cfg["ffn_dim"], - bias=ffn_cfg["bias"], - ffn_activation=ffn_cfg["activation"], - ), - "swiglu": lambda hidden_dim, ffn_cfg: SwiGLUFFN( - hidden_dim=hidden_dim, - ffn_dim=ffn_cfg["ffn_dim"], - bias=ffn_cfg["bias"], - ), -} +def build_ffn_config(ffn_cfg) -> FFNConfig: + """ + Build the FFN config + """ + match ffn_cfg["ffn_type"]: + case FFNTypes.GENERIC: + return GenericFFNConfig(**ffn_cfg) + case FFNTypes.SWIGLU: + return SwiGLUFFNConfig(**ffn_cfg) def build_ffn(hidden_dim, ffn_cfg): """ Build a feedforward network """ - return FFN_DICT[ffn_cfg["ffn_type"]](hidden_dim=hidden_dim, ffn_cfg=ffn_cfg) + ffn_config = build_ffn_config(ffn_cfg) + match ffn_config.ffn_type: + case FFNTypes.GENERIC: + return GenericFFN(hidden_dim=hidden_dim, ffn_config=ffn_config) + case FFNTypes.SWIGLU: + return SwiGLUFFN(hidden_dim=hidden_dim, ffn_config=ffn_config) diff --git a/models/components/layers/normalization.py b/models/components/layers/normalization.py index 69b87605..06223538 100644 --- a/models/components/layers/normalization.py +++ b/models/components/layers/normalization.py @@ -2,34 +2,81 @@ A collection of normalization layers. """ +import enum + +import pydantic import torch from torch.nn import functional as F +EPSILON = 1e-6 + + +class NormalizationTypes(str, enum.Enum): + """ + Types of normalization + """ + + LAYERNORM = "layernorm" + RMSNORM = "rmsnorm" + + +class NormConfig(pydantic.BaseModel): + """ + Normalization configuration + """ + + normalization: NormalizationTypes + dim: pydantic.PositiveInt + + +class LayerNormConfig(NormConfig): + """ + Layer normalization configuration + """ + + normalization: NormalizationTypes.LAYERNORM + bias: bool = True + class LayerNorm(torch.nn.Module): """LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False""" # taken from nanoGPT - def __init__(self, dim, bias): + def __init__(self, norm_config: LayerNormConfig): super().__init__() - self.weight = torch.nn.Parameter(torch.ones(dim)) - self.bias = torch.nn.Parameter(torch.zeros(dim)) if bias else None + self.weight = torch.nn.Parameter(torch.ones(norm_config.dim)) + self.bias = ( + torch.nn.Parameter(torch.zeros(norm_config.dim)) + if norm_config.bias + else None + ) def forward(self, x): """Apply Layer Norm""" return F.layer_norm(x, self.weight.shape, self.weight, self.bias, 1e-5) +class RMSNormConfig(NormConfig): + """ + RMSNorm configuration + + eps is the epsilon value to prevent division by zero + """ + + normalization: NormalizationTypes.RMSNORM + eps: pydantic.PositiveFloat = EPSILON + + class RMSNorm(torch.nn.Module): """ RMSNorm (https://arxiv.org/abs/1910.07467), implementation from https://github.com/meta-llama/llama3/blob/main/llama/model.py """ - def __init__(self, dim: int, eps: float = 1e-6): + def __init__(self, norm_config: RMSNormConfig): super().__init__() - self.eps = eps - self.weight = torch.nn.Parameter(torch.ones(dim)) + self.eps = norm_config.eps + self.weight = torch.nn.Parameter(torch.ones(norm_config.dim)) def _norm(self, x): return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) @@ -40,17 +87,20 @@ def forward(self, x): return output * self.weight -NORMALIZATION_DICT = { - "rms_norm": lambda dim, bias: RMSNorm(dim=dim), - "layer_norm": lambda dim, bias: LayerNorm(dim=dim, bias=bias), - "none": lambda dim, bias: torch.nn.Identity(), -} - - -def build_normalization(normalization_name, dim, bias=None): +def build_normalization( + normalization_name: NormalizationTypes, dim: int, bias: bool = True +): """ Build the normalization layer Available options: rmsnorm, layernorm - Bias is ignored for RMSNorm """ - return NORMALIZATION_DICT[normalization_name](dim=dim, bias=bias) + match normalization_name: + case NormalizationTypes.LAYERNORM: + return LayerNorm( + LayerNormConfig(normalization=normalization_name, dim=dim, bias=bias) + ) + case NormalizationTypes.RMSNORM: + return RMSNorm(RMSNormConfig(normalization=normalization_name, dim=dim)) + case _: + raise ValueError(f"Unknown normalization type: {normalization_name}") diff --git a/models/components/layers/transformer_blocks.py b/models/components/layers/transformer_blocks.py index 879ff744..0aa8bfaa 100644 --- a/models/components/layers/transformer_blocks.py +++ b/models/components/layers/transformer_blocks.py @@ -5,9 +5,8 @@ import torch -from models.components.layers.attention import build_attention -from models.components.layers.feedforward import build_ffn -from models.components.layers.normalization import build_normalization +from models.components.layers.attention import AttentionConfig, build_attention +from models.components.layers.feedforward import FFNConfig, build_ffn class GenericTransformerBlock(torch.nn.Module): @@ -16,31 +15,19 @@ class GenericTransformerBlock(torch.nn.Module): FFN, Attn and normalization. """ - def __init__(self, hidden_dim, context_window, use_rope, ffn_cfg, attn_cfg): + def __init__( + self, hidden_dim, context_window, ffn_cfg: FFNConfig, attn_cfg: AttentionConfig + ): super().__init__() - - # build the attn norm - self.attn_norm = build_normalization( - normalization_name=attn_cfg["normalization"], - dim=hidden_dim, - bias=attn_cfg["bias"], - ) + attn_cfg = AttentionConfig(**attn_cfg) # build the attention self.attn = build_attention( hidden_dim=hidden_dim, context_window=context_window, - use_rope=use_rope, attn_cfg=attn_cfg, ) - # build the ffn norm - self.ffn_norm = build_normalization( - normalization_name=ffn_cfg["normalization"], - dim=hidden_dim, - bias=ffn_cfg["bias"], - ) - # build the ffn block self.ffn = build_ffn( hidden_dim=hidden_dim, @@ -57,6 +44,6 @@ def forward(self, x, attention_mask=None): Returns: x: the output tensor (b, s, h) """ - x = x + self.attn(self.attn_norm(x), attention_mask) - x = x + self.ffn(self.ffn_norm(x)) + x = x + self.attn(x, attention_mask) + x = x + self.ffn(x) return x diff --git a/models/components/positional_encoding.py b/models/components/positional_encoding.py index 603bc099..c8bb8249 100644 --- a/models/components/positional_encoding.py +++ b/models/components/positional_encoding.py @@ -2,9 +2,11 @@ A collection of positional encoding modules. """ -import torch +import enum import math +import torch + class LearnedPosEncoding(torch.nn.Module): """ @@ -45,44 +47,48 @@ def forward(self, x): """ return x -class SinCosPosEncoding( - torch.nn.Module -): + +class SinCosPosEncoding(torch.nn.Module): """SinCos encoding taken from: \\url{https://github.com/pytorch/examples/blob/main/word_language_model/model.py#L65} As used in the Vaiswani et al. paper...""" + def __init__(self, hidden_dim, context_window): """Set up the pe buffer etc.""" super().__init__() pe = torch.zeros(context_window, hidden_dim) position = torch.arange(0, context_window, dtype=torch.float).unsqueeze(1) - div_term = torch.exp(torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim)) + div_term = torch.exp( + torch.arange(0, hidden_dim, 2).float() * (-math.log(10000.0) / hidden_dim) + ) pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) - pe = pe.unsqueeze(0) # pe has shape (1, S, H) + pe = pe.unsqueeze(0) # pe has shape (1, S, H) - self.pe = torch.nn.Parameter(pe) # hack for distributed data parallel + self.pe = torch.nn.Parameter(pe) # hack for distributed data parallel self.pe.requires_grad = False def forward(self, x): """Add the pe to the input tensor.""" # x of shape (B, S, H) - return x + self.pe[:, :x.size(1)] + return x + self.pe[:, : x.size(1)] + + +class PosEncodingType(enum.Enum): + """ + Enum for the different types of positional encodings + """ + LEARNED = "learned" + ROPE = "rope" + NONE = "none" + SINCOS = "sincos" -POS_ENCODING_DICT = { - "learned": lambda dim, size, **_: LearnedPosEncoding( - hidden_dim=dim, context_window=size - ), - "rope": lambda **_: IdentityEncoding(), - "none": lambda **_: IdentityEncoding(), - "sincos": lambda dim, size, **_: SinCosPosEncoding( - hidden_dim=dim, context_window=size - ), -} -def build_positional_encodings(model_cfg): +def build_positional_encodings( + positional_encoding_type: PosEncodingType, hidden_dim, context_window +): """ Given the positional encoding config, build it. Args: @@ -90,6 +96,18 @@ def build_positional_encodings(model_cfg): Returns: positional_encodings: positional_encodings_instance """ - return POS_ENCODING_DICT[model_cfg["positional_encoding_type"]]( - dim=model_cfg["hidden_dim"], size=model_cfg["context_window"] - ) + match positional_encoding_type: + case PosEncodingType.LEARNED: + return LearnedPosEncoding( + hidden_dim=hidden_dim, context_window=context_window + ) + case PosEncodingType.ROPE: + return IdentityEncoding() + case PosEncodingType.NONE: + return IdentityEncoding() + case PosEncodingType.SINCOS: + return SinCosPosEncoding( + hidden_dim=hidden_dim, context_window=context_window + ) + case _: + raise ValueError("Invalid positional encoding type") diff --git a/models/components/tokenizers/base_class.py b/models/components/tokenizers/base_class.py index 673f7280..bfe644c4 100644 --- a/models/components/tokenizers/base_class.py +++ b/models/components/tokenizers/base_class.py @@ -25,7 +25,7 @@ def encode_batch(self, texts): def pad_batch(self, token_lists, direction="right"): """Pad a list of token lists to the same length, and return the padded tensor, and mask tensor. - + Direction can be 'right' or 'left' to specify the padding direction. """ max_len = max(len(tokens) for tokens in token_lists) @@ -33,10 +33,14 @@ def pad_batch(self, token_lists, direction="right"): mask = [] for tokens in token_lists: if direction == "right": - padded_tokens.append(tokens + [self.pad_token] * (max_len - len(tokens))) + padded_tokens.append( + tokens + [self.pad_token] * (max_len - len(tokens)) + ) mask.append([1] * len(tokens) + [0] * (max_len - len(tokens))) elif direction == "left": - padded_tokens.append([self.pad_token] * (max_len - len(tokens)) + tokens) + padded_tokens.append( + [self.pad_token] * (max_len - len(tokens)) + tokens + ) mask.append([0] * (max_len - len(tokens)) + [1] * len(tokens)) return torch.tensor(padded_tokens), torch.tensor(mask) diff --git a/models/components/tokenizers/bpe.py b/models/components/tokenizers/bpe.py index c39cea20..ad65bed0 100644 --- a/models/components/tokenizers/bpe.py +++ b/models/components/tokenizers/bpe.py @@ -5,10 +5,11 @@ Original Paper: https://arxiv.org/abs/1508.07909v5 """ -import torch + import os from heapq import nlargest +import torch from tqdm import tqdm from models.components.tokenizers import utils diff --git a/models/core_models.py b/models/core_models.py index 96ec2b34..f5d65b54 100644 --- a/models/core_models.py +++ b/models/core_models.py @@ -2,18 +2,43 @@ Simple, flexible core models. """ +import pydantic import torch from models.components.layers.transformer_blocks import GenericTransformerBlock +class CoreModelConfig(pydantic.BaseModel): + """ + Core Model configuration + """ + + core_model_type: str + + +class GenericCoreModelConfig(CoreModelConfig): + """ + Generic Core Model configuration + """ + + positional_encoding_type: str + ffn: dict + attn: dict + num_layers: int + + class GenericTransformer(torch.nn.Module): """ Generic Transformer Class intended to be used for as broad a range of transformer models as possible. """ - def __init__(self, model_cfg): + def __init__( + self, + hidden_dim, + context_window, + core_model_cfg: GenericCoreModelConfig, + ): super().__init__() # build the transformer @@ -23,13 +48,12 @@ def __init__(self, model_cfg): "h": torch.nn.ModuleList( [ GenericTransformerBlock( - hidden_dim=model_cfg["hidden_dim"], - context_window=model_cfg["context_window"], - use_rope=model_cfg["positional_encoding_type"] == "rope", - ffn_cfg=model_cfg["core_model"]["ffn"], - attn_cfg=model_cfg["core_model"]["attn"], + hidden_dim=hidden_dim, + context_window=context_window, + ffn_cfg=core_model_cfg.ffn, + attn_cfg=core_model_cfg.attn, ) - for _ in range(model_cfg["core_model"]["num_layers"]) + for _ in range(core_model_cfg.num_layers) ] ), } @@ -61,8 +85,17 @@ class GenericFFNSharedTransfomer(GenericTransformer): https://arxiv.org/abs/2402.16840). """ - def __init__(self, model_cfg): - super().__init__(model_cfg=model_cfg) + def __init__( + self, + hidden_dim, + context_window, + core_model_cfg: CoreModelConfig, + ): + super().__init__( + hidden_dim=hidden_dim, + context_window=context_window, + core_model_cfg=core_model_cfg, + ) # share the weights between transformer blocks ffn_0 = self.transformer.h[0].ffn diff --git a/models/embedding_models.py b/models/embedding_models.py index c7892203..df3f3d12 100644 --- a/models/embedding_models.py +++ b/models/embedding_models.py @@ -4,12 +4,33 @@ (if necessary). """ +from typing import Literal + +import pydantic import torch from models.components.positional_encoding import build_positional_encodings from models.components.tokenizers import build_tokenizer +class EmbedderConfig(pydantic.BaseModel): + """ + Embedder configuration + """ + + embedding_model_type: str + + +class GenericEmbedderConfig(EmbedderConfig): + """ + Embedder configuration + """ + + embedding_model_type: Literal["generic"] + tokenizer_type: str + dataset_name: str + + class EmbedderInterface(torch.nn.Module): """Interface for the embedder component of the model.""" @@ -47,8 +68,8 @@ def decode(self, tokens: torch.LongTensor): def inference(self, input_string: str, add_eot=False): """This function should map string to embeddings.""" token_ids = self.tokenize_input(input_string, truncate=True, add_eot=add_eot) - token_ids = torch.tensor(token_ids).unsqueeze(0).to( - next(self.parameters()).device + token_ids = ( + torch.tensor(token_ids).unsqueeze(0).to(next(self.parameters()).device) ) return self.forward(token_ids) @@ -95,25 +116,34 @@ class GenericEmbedder(EmbedderInterface): All embedders should inherit from this class. """ - def __init__(self, model_cfg): + def __init__( + self, + embedder_cfg: EmbedderConfig, + vocab_size: int, + hidden_dim: int, + context_window: int, + positional_encoding_type: str, + ): super().__init__() # build the tokenizer self.tokenizer = build_tokenizer( - tokenizer_type=model_cfg["embedder"]["tokenizer_type"], - vocab_size=model_cfg["vocab_size"], - dataset_name=model_cfg["embedder"]["dataset_name"], + tokenizer_type=embedder_cfg.tokenizer_type, + vocab_size=vocab_size, + dataset_name=embedder_cfg.dataset_name, ) # build the token embeddings self.token_embedder = torch.nn.Embedding( - num_embeddings=model_cfg["vocab_size"], - embedding_dim=model_cfg["hidden_dim"], + num_embeddings=vocab_size, + embedding_dim=hidden_dim, ) # build the positional encodings - self.positional_encodings = build_positional_encodings(model_cfg=model_cfg) + self.positional_encodings = build_positional_encodings( + positional_encoding_type, hidden_dim, context_window + ) self.eot_token = self.tokenizer.eot_token - self.model_cfg = model_cfg + self.context_window = context_window def forward(self, token_ids): """ @@ -157,7 +187,7 @@ def pad_batch(self, token_lists, direction="right"): def truncate(self, token_lists): # get model max length - max_length = self.model_cfg["context_window"] + max_length = self.context_window return [token_seq[-max_length:] for token_seq in token_lists] def decode(self, tokens): diff --git a/models/experimental/byte_level/byte_model_shell.py b/models/experimental/byte_level/byte_model_shell.py index b392eb39..12f1513e 100644 --- a/models/experimental/byte_level/byte_model_shell.py +++ b/models/experimental/byte_level/byte_model_shell.py @@ -3,11 +3,32 @@ core model and LM head. """ +from typing import Literal + import torch from models import core_models, embedding_models, model_heads -from models.model_shell import ModelShell +from models.components.layers import attention +from models.experimental.byte_level.embedding_model import ByteLevelEmbedderConfig +from models.experimental.byte_level.layers import ByteTransformerBlockConfig +from models.model_shell import ModelShell, ModelShellConfig + +class ByteShellConfig(ModelShellConfig): + """ + Byte Model Shell configuration + """ + + model_shell_type: Literal["byte_shell"] + byte_vocab_size: int = 256 + byte_context_window: int = 8 + byte_embedding_dim: int = 64 + embedding_model: ( + ByteLevelEmbedderConfig # Not needed since we are using ByteLevelEmbedder + ) + model_head: model_heads.LMHeadConfig + attn_cfg: attention.AttentionConfig + ffn_cfg: ByteTransformerBlockConfig class ByteModelShell(ModelShell): @@ -15,6 +36,7 @@ class ByteModelShell(ModelShell): Slight deviation from the standard Model Shell to allow for a re-constructive auxiliary loss to the input. """ + def __init__( self, embedding_model: embedding_models.EmbedderInterface, @@ -37,12 +59,10 @@ def forward(self, token_ids): # to get B, S, H (with pos encoding if necessary) x = self.embedding_model(token_ids) - # calculate the reconstruction loss + # calculate the reconstruction loss logits = self.model_head(x)[0] loss = torch.nn.functional.cross_entropy( - logits.view(-1, logits.size(-1)), - token_ids.view(-1), - ignore_index=257 + logits.view(-1, logits.size(-1)), token_ids.view(-1), ignore_index=257 ) # pass the embeddings through the core model @@ -51,6 +71,4 @@ def forward(self, token_ids): # pass the core model output through the model head x = self.model_head(x)[0] - return x, loss - - + return x, loss diff --git a/models/experimental/byte_level/embedding_model.py b/models/experimental/byte_level/embedding_model.py index f0c12cc2..347158bf 100644 --- a/models/experimental/byte_level/embedding_model.py +++ b/models/experimental/byte_level/embedding_model.py @@ -4,14 +4,26 @@ (if necessary). """ +from typing import Literal + import torch from models.components.positional_encoding import LearnedPosEncoding from models.components.tokenizers import build_tokenizer -from models.embedding_models import EmbedderInterface +from models.embedding_models import EmbedderInterface, GenericEmbedderConfig +from models.experimental.byte_level.byte_model_shell import ByteShellConfig from models.experimental.byte_level.layers import ByteLevelTransformerBlock +class ByteLevelEmbedderConfig(GenericEmbedderConfig): + """ + Byte Level configuration + """ + + embedder_type: Literal["byte_level"] + byte_tokenizer_type: str + + class ByteLevelEmbedder(EmbedderInterface): """ Takes byte level encodings, processes them via @@ -23,67 +35,78 @@ class ByteLevelEmbedder(EmbedderInterface): """ # pylint: disable=super-init-not-called - def __init__(self, model_cfg): + def __init__( + self, + vocab_size, + hidden_dim, + byte_cfg: ByteShellConfig, + embedder_cfg: ByteLevelEmbedderConfig, + ): super().__init__() - self.model_cfg = model_cfg # build the tokenizers self.byte_tokenizer = build_tokenizer( - tokenizer_type=model_cfg["embedder"]["byte_tokenizer_type"], - vocab_size=model_cfg["byte_vocab_size"], - dataset_name=model_cfg["embedder"]["dataset_name"], + tokenizer_type=embedder_cfg.byte_tokenizer_type, + vocab_size=embedder_cfg, + dataset_name=embedder_cfg.dataset_name, ) self.pooling_tokenizer = build_tokenizer( - tokenizer_type=model_cfg["embedder"]["tokenizer_type"], - vocab_size=model_cfg["vocab_size"], - dataset_name=model_cfg["embedder"]["dataset_name"], + tokenizer_type=embedder_cfg.tokenizer_type, + vocab_size=vocab_size, + dataset_name=embedder_cfg.dataset_name, ) # positional encodings self.pos_encoder = LearnedPosEncoding( - hidden_dim=model_cfg["byte_embedding_dim"], - context_window=model_cfg["byte_context_window"], + hidden_dim=hidden_dim, + context_window=byte_cfg.byte_context_window, ) # build the token embeddings self.byte_token_embedder = torch.nn.Embedding( - num_embeddings=model_cfg["byte_vocab_size"], - embedding_dim=model_cfg["byte_embedding_dim"], + num_embeddings=byte_cfg.byte_vocab_size, + embedding_dim=byte_cfg.byte_embedding_dim, ) # build the transformer blocks self.transformer = torch.nn.ModuleList( [ ByteLevelTransformerBlock( - input_dim=model_cfg["byte_embedding_dim"], - output_dim=model_cfg["byte_embedding_dim"] * 2, - ffn_dim=model_cfg["byte_embedding_dim"] * 4, - context_window=model_cfg["byte_context_window"], - use_rope=False, + input_dim=byte_cfg.byte_embedding_dim, + output_dim=byte_cfg.byte_embedding_dim * 2, + ffn_dim=byte_cfg.byte_embedding_dim * 4, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=embedder_cfg.ffn_cfg, + attn_config=embedder_cfg.attn_cfg, ), ByteLevelTransformerBlock( - input_dim=model_cfg["byte_embedding_dim"]*2, - output_dim=model_cfg["byte_embedding_dim"] * 2, - ffn_dim=model_cfg["byte_embedding_dim"] * 8, - context_window=model_cfg["byte_context_window"], - use_rope=False, + input_dim=byte_cfg.byte_embedding_dim * 2, + output_dim=byte_cfg.byte_embedding_dim * 2, + ffn_dim=byte_cfg.byte_embedding_dim * 8, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=embedder_cfg.ffn_cfg, + attn_config=embedder_cfg.attn_cfg, ), ByteLevelTransformerBlock( - input_dim=model_cfg["byte_embedding_dim"]*2, - output_dim=model_cfg["byte_embedding_dim"] * 2, - ffn_dim=model_cfg["byte_embedding_dim"] * 8, - context_window=model_cfg["byte_context_window"], - use_rope=False, + input_dim=byte_cfg.byte_embedding_dim * 2, + output_dim=byte_cfg.byte_embedding_dim * 2, + ffn_dim=byte_cfg.byte_embedding_dim * 8, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=embedder_cfg.ffn_cfg, + attn_config=embedder_cfg.attn_cfg, ), ByteLevelTransformerBlock( - input_dim=model_cfg["byte_embedding_dim"] * 2, - output_dim=model_cfg["hidden_dim"], - ffn_dim=model_cfg["byte_embedding_dim"] * 8, - context_window=model_cfg["byte_context_window"], - use_rope=False, + input_dim=byte_cfg.byte_embedding_dim * 2, + output_dim=hidden_dim, + ffn_dim=byte_cfg.byte_embedding_dim * 8, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=embedder_cfg.ffn_cfg, + attn_config=embedder_cfg.attn_cfg, ), ] ) + self.byte_cfg = byte_cfg + self.embedder_cfg = embedder_cfg def tokenize_input(self, input_string: str, truncate=False, add_eot=True): """Tokenize an input string. @@ -102,13 +125,13 @@ def tokenize_input(self, input_string: str, truncate=False, add_eot=True): ] # truncate bytes tokens = [ - token_seq[: self.model_cfg["byte_context_window"]] for token_seq in tokens + token_seq[: self.byte_cfg.byte_context_window] for token_seq in tokens ] # pad bytes tokens = [ token_seq + [self.byte_tokenizer.pad_token] - * (self.model_cfg["byte_context_window"] - len(token_seq)) + * (self.byte_cfg.byte_context_window - len(token_seq)) for token_seq in tokens ] return tokens @@ -127,43 +150,38 @@ def pad_batch(self, token_lists, direction="right"): max_len = max([len(token_list) for token_list in token_lists]) padded_token_lists = [] mask = [] - byte_context_window = self.model_cfg["byte_context_window"] + byte_context_window = self.byte_cfg.byte_context_window for token_list in token_lists: if direction == "right": padded_token_list = token_list + [ - [self.byte_tokenizer.pad_token] - * byte_context_window + [self.byte_tokenizer.pad_token] * byte_context_window ] * (max_len - len(token_list)) padded_token_lists.append(padded_token_list) - mask.append( - [1] * len(token_list) - + [0] * (max_len - len(token_list)) - ) + mask.append([1] * len(token_list) + [0] * (max_len - len(token_list))) else: padded_token_list = token_list + [ - [self.byte_tokenizer.pad_token] - * byte_context_window + [self.byte_tokenizer.pad_token] * byte_context_window ] * (max_len - len(token_list)) padded_token_lists.append(padded_token_list) - mask.append( - [0] * (max_len - len(token_list)) - + [1] * len(token_list) - ) + mask.append([0] * (max_len - len(token_list)) + [1] * len(token_list)) # expand the mask to include the byte context window mask[-1] = [[it] * byte_context_window for it in mask[-1]] return torch.tensor(padded_token_lists), torch.tensor(mask) def truncate(self, token_lists): # get model max length - max_length = self.model_cfg["context_window"] + max_length = self.byte_cfg.context_window return [token_seq[-max_length:] for token_seq in token_lists] - def decode(self, list_of_token_idss): + def decode(self, tokens): """ Decode the token ids. + Tokens is a second level list, where the first level + is the batch dimension, and the second level is the + list of token ids for a given sequence. """ return_strings = [] - for list_of_token_ids in list_of_token_idss: + for list_of_token_ids in tokens: return_string = "" for token_ids in list_of_token_ids: token_ids = [ diff --git a/models/experimental/byte_level/layers.py b/models/experimental/byte_level/layers.py index bc2e700b..c3483173 100644 --- a/models/experimental/byte_level/layers.py +++ b/models/experimental/byte_level/layers.py @@ -2,13 +2,24 @@ Shared components of the byte level models. """ +import pydantic import torch from models.components.layers.activations import build_activation -from models.components.layers.attention import Attention +from models.components.layers.attention import AttentionConfig, build_attention from models.components.layers.normalization import build_normalization +class ByteTransformerBlockConfig(pydantic.BaseModel): + """ + Feedforward network configuration + """ + + bias: bool + ffn_activation: str + ffn_normalization: str = "rmsprop" + + class ProjectingFFN(torch.nn.Module): """ A simple feedforward network @@ -16,24 +27,27 @@ class ProjectingFFN(torch.nn.Module): def __init__( self, - hidden_dim, - output_dim, + input_dim, ffn_dim, - bias, - ffn_activation, + output_dim, + block_cfg: ByteTransformerBlockConfig, ): super().__init__() # build the ffn block - self.linear_1 = torch.nn.Linear(hidden_dim, ffn_dim, bias=bias) + self.linear_1 = torch.nn.Linear(input_dim, ffn_dim, bias=block_cfg.bias) - self.activation = build_activation(activation_name=ffn_activation) + self.activation = build_activation(activation_name=block_cfg.ffn_activation) - self.linear_2 = torch.nn.Linear(ffn_dim, output_dim, bias=bias) + self.linear_2 = torch.nn.Linear(ffn_dim, output_dim, bias=block_cfg.bias) + self.normalization = build_normalization( + normalization_name="rmsprop", dim=input_dim, bias=block_cfg.bias + ) def forward(self, x): """ A simple forward pass through the FFN """ + x = self.normalization(x) x = self.linear_1(x) x = self.activation(x) x = self.linear_2(x) @@ -46,37 +60,28 @@ class ByteLevelTransformerBlock(torch.nn.Module): FFN, Attn and normalization. """ - def __init__(self, input_dim, output_dim, ffn_dim, context_window, use_rope=False): + def __init__( + self, + input_dim, + output_dim, + ffn_dim, + context_window, + byte_transformer_block_cfg: ByteTransformerBlockConfig, + attn_config: AttentionConfig, + ): super().__init__() - # build the attn norm - self.attn_norm = build_normalization( - normalization_name="rms_norm", dim=input_dim, bias=False - ) - # build the attention - self.attn = Attention( - hidden_dim=input_dim, - num_heads=8, - bias=False, - use_rope=use_rope, - context_window=context_window, - is_causal=False, - group_size=1, - ) - - # build the ffn norm - self.ffn_norm = build_normalization( - normalization_name="rms_norm", dim=input_dim, bias=False + self.attn = build_attention( + hidden_dim=input_dim, context_window=context_window, attn_cfg=attn_config ) # build the ffn block self.ffn = ProjectingFFN( - hidden_dim=input_dim, + input_dim=input_dim, ffn_dim=ffn_dim, output_dim=output_dim, - bias=False, - ffn_activation="gelu", + block_cfg=byte_transformer_block_cfg, ) def forward(self, x, attention_mask=None): @@ -89,6 +94,6 @@ def forward(self, x, attention_mask=None): Returns: x: the output tensor (b, s, h) """ - x = x + self.attn(self.attn_norm(x), attention_mask) - x = self.ffn(self.ffn_norm(x)) + x = x + self.attn(x, attention_mask) + x = self.ffn(x) return x diff --git a/models/experimental/byte_level/model_heads.py b/models/experimental/byte_level/model_heads.py index ed079aae..8359ed0d 100644 --- a/models/experimental/byte_level/model_heads.py +++ b/models/experimental/byte_level/model_heads.py @@ -5,10 +5,12 @@ import torch from models.components.positional_encoding import LearnedPosEncoding +from models.experimental.byte_level.byte_model_shell import ByteShellConfig from models.experimental.byte_level.layers import ByteLevelTransformerBlock +from models.model_heads import HeadInterface -class ByteLevelDecoder(torch.nn.Module): +class ByteLevelDecoder(HeadInterface): """ Use multiple learned heads to decode into by hidden size, pre-append to the byte embeddings of the answers and @@ -17,21 +19,18 @@ class ByteLevelDecoder(torch.nn.Module): the latent ecoded ones. """ - def __init__(self, model_cfg): + def __init__(self, byte_cfg: ByteShellConfig): super().__init__() - self.hidden_dim = model_cfg["hidden_dim"] - self.embedding_dim = model_cfg["byte_embedding_dim"] - self.byte_vocab_size = model_cfg["byte_vocab_size"] - self.byte_context_window = model_cfg["byte_context_window"] - + self.hidden_dim = byte_cfg.hidden_dim + self.embedding_dim = byte_cfg.byte_embedding_dim self.projection = torch.nn.Linear( in_features=self.hidden_dim, - out_features=self.byte_context_window * self.embedding_dim, + out_features=byte_cfg.byte_context_window * self.embedding_dim, bias=False, ) self.pos_encoder = LearnedPosEncoding( - hidden_dim=self.embedding_dim, context_window=self.byte_context_window + hidden_dim=self.embedding_dim, context_window=byte_cfg.byte_context_window ) # build transformer block @@ -41,59 +40,67 @@ def __init__(self, model_cfg): input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ByteLevelTransformerBlock( input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ByteLevelTransformerBlock( input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ByteLevelTransformerBlock( input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ByteLevelTransformerBlock( input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ByteLevelTransformerBlock( input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ByteLevelTransformerBlock( input_dim=self.embedding_dim, output_dim=self.embedding_dim, ffn_dim=self.embedding_dim * 4, - context_window=self.byte_context_window, - use_rope=False, + context_window=byte_cfg.byte_context_window, + byte_transformer_block_cfg=byte_cfg.ffn_cfg, + attn_config=byte_cfg.attn_cfg, ), ] ) self.lm_head = torch.nn.Linear( in_features=self.embedding_dim, - out_features=self.byte_vocab_size, + out_features=byte_cfg.byte_vocab_size, bias=False, ) + self.byte_cfg = byte_cfg def forward(self, x): """ @@ -102,11 +109,13 @@ def forward(self, x): # project the latent embeddings x = self.projection(x) - x = x.view(x.size(0), x.size(1), self.byte_context_window, self.embedding_dim) + x = x.view( + x.size(0), x.size(1), self.byte_cfg.byte_context_window, self.embedding_dim + ) # pass through model and deocde B, S, _, _ = x.size() - x = x.view(B * S, self.byte_context_window, self.embedding_dim) + x = x.view(B * S, self.byte_cfg.byte_context_window, self.embedding_dim) # positional encoding x = x + self.pos_encoder(x) @@ -119,7 +128,9 @@ def forward(self, x): x = self.lm_head(x) # reshape and return - x = x.view(B, S, self.byte_context_window, self.byte_vocab_size) + x = x.view( + B, S, self.byte_cfg.byte_context_window, self.byte_cfg.byte_vocab_size + ) return x, None diff --git a/models/experimental/hugging_face.py b/models/experimental/hugging_face.py index b6b4b70a..ed42cb53 100644 --- a/models/experimental/hugging_face.py +++ b/models/experimental/hugging_face.py @@ -1,18 +1,54 @@ """An interface for loading in models from the Hugging Face model hub This can be used for finetuning or training from scratch.""" +from typing import Literal + import torch from transformers import AutoModelForCausalLM, AutoTokenizer from models.components.tokenizers.base_class import Tokenizer -from models.embedding_models import EmbedderInterface -from models.model_shell import ModelShell +from models.core_models import CoreModelConfig +from models.embedding_models import EmbedderConfig, EmbedderInterface +from models.model_heads import HeadInterface, LMHeadConfig from trainers.base_trainer import BaseTrainer + +class HFEmbedderConfig(EmbedderConfig): + """ + Configuration for the Hugging Face model. + """ + + model_string: str + flash_attention: bool = False + embedding_model_type: Literal["hf_embedder"] + tokenizer_type: str = "dummy" + dataset_name: str = "dummy" + + +class HFCoreModelConfig(CoreModelConfig): + """ + Configuration for the Hugging Face transformer core. + """ + + core_model_type: Literal["hf_core"] + model_string: str + flash_attention: bool = False + + +class HFLMHeadConfig(LMHeadConfig): + """ + Configuration for the Hugging Face language model head. + """ + + lm_head_type: Literal["hf_lm_head"] + model_string: str + flash_attention: bool = False + + def build_model(model_cfg): - ''' + """ Helper function to build a model from the huggingface model hub. - ''' + """ ## get the model string model_str = model_cfg["model_string"] @@ -35,7 +71,10 @@ def build_model(model_cfg): class HFTokenizerWrapper(Tokenizer): + """Wrapper for the Hugging Face tokenizer""" + def __init__(self, hf_tokenizer_name): + super().__init__() self.hf_tokenizer = AutoTokenizer.from_pretrained(hf_tokenizer_name) self.eot_token = self.hf_tokenizer.eos_token_id self.pad_token = self.hf_tokenizer.pad_token_id @@ -73,18 +112,18 @@ class HFEmbedder(EmbedderInterface): A class for loading in models from the Hugging Face model hub """ - def __init__(self, model_cfg): + def __init__(self, model_cfg: HFEmbedderConfig): super().__init__() self.model_cfg = model_cfg model_string = model_cfg["model_string"] self.tokenizer = HFTokenizerWrapper(model_string) self.embeddings = build_model(model_cfg).get_input_embeddings() - def decode(self, token_ids): + def decode(self, tokens): """ Decode the token ids """ - return self.tokenizer.decode_batch(token_ids) + return self.tokenizer.decode_batch(tokens) def forward(self, token_ids): """ @@ -127,7 +166,7 @@ class HFTransformerCore(torch.nn.Module): def __init__(self, model_cfg): super().__init__() - self.model = build_model(model_cfg = model_cfg) + self.model = build_model(model_cfg=model_cfg) ## freeze the parameters print("Note: Freezing the parameters of the hf_core model.") @@ -139,23 +178,24 @@ def forward(self, x): Calls the huggingface model in question, and returns the last hidden state. """ ## get the hidden states - hidden_states = self.model(inputs_embeds = x, output_hidden_states = True).hidden_states + hidden_states = self.model( + inputs_embeds=x, output_hidden_states=True + ).hidden_states ## return the last hidden state if isinstance(hidden_states, tuple): return hidden_states[-1] - -class HFLMHead(torch.nn.Module): +class HFLMHead(HeadInterface): """ Takes the language model head of a Hugging Face transformer class. """ def __init__(self, model_cfg): super().__init__() - self.lm_head = build_model(model_cfg = model_cfg).get_output_embeddings() - + self.lm_head = build_model(model_cfg=model_cfg).get_output_embeddings() + def forward(self, x): """ Passes the input through the language model head to get logits. @@ -171,24 +211,9 @@ def forward(self, x): class MockTrainer(BaseTrainer): """A trainer that skips the training step, but runs e.g. logging""" - def __init__( - self, - cfg, - model: ModelShell, - optimizer, - dataloader, - loss_fn, - lr_scheduler=None, - dropout_scheduler=None, - ) -> None: - """Just forward the arguments to the parent class""" - super().__init__( - cfg, model, optimizer, dataloader, loss_fn, lr_scheduler, dropout_scheduler - ) - - def _run_step(self, *args, **kwargs): + def _run_step(self, *_, **__): + """We don't want to run the training step in this case...""" return torch.tensor(0.0) def _save_model(self, iter_num=0): """We don't want to save the model in this case...""" - pass diff --git a/models/experimental/next_thought/core_models.py b/models/experimental/next_thought/core_models.py index de760aa6..077dfb99 100644 --- a/models/experimental/next_thought/core_models.py +++ b/models/experimental/next_thought/core_models.py @@ -1,27 +1,41 @@ """ The core next-thought model. """ -import torch +from typing import Literal + +import torch + +from models.core_models import CoreModelConfig + + +class NextThoughtConfig(CoreModelConfig): + """ + Next Thought configuration + """ + + core_model_type: Literal["next_thought"] + latent_dim: int class BaselineCoreModel(torch.nn.Module): """ - An extremely simplistic core model for + An extremely simplistic core model for next thought prediction. """ - def __init__(self, model_cfg): + + def __init__(self, model_cfg: NextThoughtConfig): super().__init__() self.model = torch.nn.Sequential( torch.nn.Linear( - in_features=model_cfg["latent_dim"], - out_features=model_cfg["latent_dim"], + in_features=model_cfg.latent_dim, + out_features=model_cfg.latent_dim, ), torch.nn.ReLU(), torch.nn.Linear( - in_features=model_cfg["latent_dim"], - out_features=model_cfg["latent_dim"], + in_features=model_cfg.latent_dim, + out_features=model_cfg.latent_dim, ), ) @@ -34,13 +48,14 @@ def forward(self, x): x: torch.tensor(B, S, H) """ return self.model(x) - + class Conv1dCoreModel(torch.nn.Module): """ A core model for next thought prediction using Conv1d layers. """ - def __init__(self, model_cfg): + + def __init__(self): super().__init__() # 4800 @@ -48,9 +63,6 @@ def __init__(self, model_cfg): self.conv2 = torch.nn.Linear(300, 300) self.conv3 = torch.nn.Linear(3, 3) - - - def forward(self, x): """ Pass an input through the model @@ -73,4 +85,4 @@ def forward(self, x): x = x.view(x.size(0), 1600, 3) x = self.conv3(x) x = x.view(x.size(0), 4800) - return x \ No newline at end of file + return x diff --git a/models/experimental/next_thought/embedding_models.py b/models/experimental/next_thought/embedding_models.py index c958fd54..de5a2f3a 100644 --- a/models/experimental/next_thought/embedding_models.py +++ b/models/experimental/next_thought/embedding_models.py @@ -1,80 +1,105 @@ """ The Embedding model for a VAE style sequence to sequence model. """ -import torch -from models.embedding_models import GenericEmbedder -from models.components.layers.transformer_blocks import GenericTransformerBlock +from typing import Literal + +import torch +from models.components.layers import attention, feedforward +from models.components.layers.transformer_blocks import GenericTransformerBlock from models.components.positional_encoding import build_positional_encodings from models.components.tokenizers import build_tokenizer - +from models.embedding_models import GenericEmbedder, GenericEmbedderConfig # import local components from models.experimental.next_thought.layers import AttentionPoolingRemoval +class HierarchicalEncoderConfig(GenericEmbedderConfig): + """ + Hierarchical Encoder Configuration + """ + + embedder_type: Literal["hierarchical"] + pooling_dims: list = [768, 256, 128] + pct_pool_per_layer: list = [0.2, 0.2] + ffn: feedforward.FFNConfig + attn: attention.AttentionConfig + class HierarchicalEncoder(GenericEmbedder): """ Accepts an arbitrary length sequence as input, uses the QK^T matrix to, at every layer, - pick the top n-percent of nodes to pool into - a single token (the one paying most attention + pick the top n-percent of nodes to pool into + a single token (the one paying most attention to the other should be pooled into the other token). """ - def __init__(self, model_cfg): - super().__init__(model_cfg=model_cfg) + + def __init__( + self, + embedder_cfg: HierarchicalEncoderConfig, + vocab_size: int, + hidden_dim: int, + context_window: int, + positional_encoding_type: str, + ): + super().__init__( + embedder_cfg=embedder_cfg, + vocab_size=vocab_size, + hidden_dim=hidden_dim, + context_window=context_window, + positional_encoding_type=positional_encoding_type, + ) # build the tokenizer self.tokenizer = build_tokenizer( - tokenizer_type=model_cfg["embedder"]["tokenizer_type"], - vocab_size=model_cfg["vocab_size"], - dataset_name=model_cfg["embedder"]["dataset_name"], + tokenizer_type=embedder_cfg.tokenizer_type, + vocab_size=vocab_size, + dataset_name=embedder_cfg.dataset_name, ) # build the token embeddings self.token_embedder = torch.nn.Embedding( - num_embeddings=model_cfg["vocab_size"], - embedding_dim=model_cfg["embedder"]["pooling_dims"][0], + num_embeddings=vocab_size, + embedding_dim=embedder_cfg.pooling_dims[0], ) # build the positional encodings - self.positional_encodings = build_positional_encodings(model_cfg=model_cfg) - + self.positional_encodings = build_positional_encodings( + positional_encoding_type, hidden_dim, context_window + ) self.standard_transformer = torch.nn.ModuleList( [ GenericTransformerBlock( - hidden_dim=model_cfg["embedder"]["pooling_dims"][0], - context_window=model_cfg["embedder"]["context_window"], - use_rope=False, - ffn_cfg=model_cfg["embedder"]["standard_ffn_block"], - attn_cfg=model_cfg["embedder"]["standard_attn_block"], + hidden_dim=embedder_cfg.pooling_dims[0], + context_window=context_window, + ffn_cfg=embedder_cfg.ffn, + attn_cfg=embedder_cfg.attn, ) ] ) self.pooling_transformer = torch.nn.ModuleList( [ - AttentionPoolingRemoval( - hidden_size_in=model_cfg["embedder"]["pooling_dims"][i], - hidden_size_out=model_cfg["embedder"]["pooling_dims"][i+1], + hidden_size_in=embedder_cfg.pooling_dims[i], + hidden_size_out=embedder_cfg.pooling_dims[i + 1], num_attention_heads=12, - pct_pool_per_layer=model_cfg["embedder"]["pct_pool_per_layer"][i], - ) for i in range(len(model_cfg["embedder"]["pooling_dims"]) - 1) + pct_pool_per_layer=embedder_cfg.pct_pool_per_layer[i], + ) + for i in range(len(embedder_cfg.pooling_dims) - 1) ] ) - def forward(self, token_ids): - # embed the input + # embed the input x = self.embedding(token_ids) - # apply positional encoding + # apply positional encoding x = x + self.positional_encoding(x) - # first pass through normal attention blocks for layer in self.standard: x = layer(x) @@ -84,4 +109,4 @@ def forward(self, token_ids): x = layer(x) # mean pool final representation x = x.mean(dim=-2) - return x \ No newline at end of file + return x diff --git a/models/experimental/next_thought/layers.py b/models/experimental/next_thought/layers.py index 9dd84fa2..65405c04 100644 --- a/models/experimental/next_thought/layers.py +++ b/models/experimental/next_thought/layers.py @@ -1,8 +1,10 @@ """ Layers that are specific to the next thought models """ -import torch -import math + +import math + +import torch class AttentionPoolingRemoval(torch.nn.Module): @@ -10,7 +12,10 @@ class AttentionPoolingRemoval(torch.nn.Module): Transformer block that removes the top-k least paid-attention to tokens. """ - def __init__(self, hidden_size_in, hidden_size_out, num_attention_heads, pct_pool_per_layer): + + def __init__( + self, hidden_size_in, hidden_size_out, num_attention_heads, pct_pool_per_layer + ): super().__init__() self.pct_pool = pct_pool_per_layer self.hidden_size_in = hidden_size_in @@ -26,7 +31,7 @@ def __init__(self, hidden_size_in, hidden_size_out, num_attention_heads, pct_poo self.norm1 = torch.nn.LayerNorm(hidden_size_in) self.norm2 = torch.nn.LayerNorm(hidden_size_out) - + def forward(self, x): # Apply multi-head attention attn_output, attn_output_weights = self.attention(x, x, x) @@ -37,7 +42,6 @@ def forward(self, x): # find how much each token was attended to on average attn_output_weights = attn_output_weights.mean(dim=-2) - # Normalize and add residual connection x = self.norm1(x + attn_output) @@ -59,28 +63,33 @@ def forward(self, x): return reduced_x + # Scaled Dot-Product Attention -def scaled_dot_product_attention(query, key, value, mask=None): +def scaled_dot_product_attention(query, key, value, _=None): """ Compute scaled dot-product attention. """ # Q * K^T - scores = torch.matmul(query, key.transpose(-2, -1)) # (batch_size, num_heads, seq_len, seq_len) - + scores = torch.matmul( + query, key.transpose(-2, -1) + ) # (batch_size, num_heads, seq_len, seq_len) + # Scale by the square root of the key dimension d_k = query.size(-1) scores = scores / math.sqrt(d_k) # Apply mask if provided (optional, for example, in Transformer Decoders) - #if mask is not None: + # if mask is not None: # scores = scores.masked_fill(mask == 0, -1e9) # Softmax to get attention weights attention_weights = torch.nn.functional.softmax(scores, dim=-1) # Multiply by the value to get the final attention output - output = torch.matmul(attention_weights, value) # (batch_size, num_heads, seq_len, depth_per_head) - #input(attention_weights.size()) + output = torch.matmul( + attention_weights, value + ) # (batch_size, num_heads, seq_len, depth_per_head) + # input(attention_weights.size()) return output, attention_weights @@ -89,10 +98,13 @@ class CustomMultiHeadAttention(torch.nn.Module): """ Custom implementation of multi-head attention from scratch. """ + def __init__(self, hidden_size, num_heads): super().__init__() - assert hidden_size % num_heads == 0, "Hidden size must be evenly divisible by number of heads." - + assert ( + hidden_size % num_heads == 0 + ), "Hidden size must be evenly divisible by number of heads." + self.hidden_size = hidden_size self.num_heads = num_heads self.depth_per_head = hidden_size // num_heads @@ -128,10 +140,14 @@ def forward(self, q, k, v): value = self.split_into_heads(self.value_proj(v)) # Apply scaled dot-product attention - attention_output, attention_weights = scaled_dot_product_attention(query, key, value) + attention_output, attention_weights = scaled_dot_product_attention( + query, key, value + ) # Concatenate the heads - attention_output = attention_output.permute(0, 2, 1, 3).reshape(q.size(0), q.size(1), self.hidden_size) + attention_output = attention_output.permute(0, 2, 1, 3).reshape( + q.size(0), q.size(1), self.hidden_size + ) # Final projection to maintain consistent output output = self.out_proj(attention_output) @@ -139,13 +155,13 @@ def forward(self, q, k, v): return output, attention_weights - class LatentSpaceDecoder(torch.nn.Module): """ - Uses a fixed number of heads to decode - the latent space into the same hidden dim + Uses a fixed number of heads to decode + the latent space into the same hidden dim as the sequence """ + def __init__(self, hidden_dim, decoding_length, latent_dim): super().__init__() self.hidden_dim = hidden_dim @@ -153,8 +169,7 @@ def __init__(self, hidden_dim, decoding_length, latent_dim): self.latent_dim = latent_dim self.decoding_layer = torch.nn.Linear( - in_features=latent_dim, - out_features=hidden_dim*decoding_length + in_features=latent_dim, out_features=hidden_dim * decoding_length ) def forward(self, x): @@ -169,11 +184,13 @@ def forward(self, x): x = x.view(batch_size, self.decoding_length, self.hidden_dim) return x - + + class LatentSpaceQuery(torch.nn.Module): """ Lets the decoder query the latent space """ + def __init__(self, hidden_dim, latent_decoded_length, latent_dim): super().__init__() self.hidden_dim = hidden_dim @@ -182,10 +199,7 @@ def __init__(self, hidden_dim, latent_decoded_length, latent_dim): # k,v come from latent space # q comes from the sequence - self.attention = CustomMultiHeadAttention( - hidden_size=hidden_dim, - num_heads=12 - ) + self.attention = CustomMultiHeadAttention(hidden_size=hidden_dim, num_heads=12) def forward(self, x, latent_space): """ @@ -194,10 +208,6 @@ def forward(self, x, latent_space): """ # Query the latent space - x, _ = self.attention( - q=x, - k=latent_space, - v=latent_space - ) + x, _ = self.attention(q=x, k=latent_space, v=latent_space) - return x \ No newline at end of file + return x diff --git a/models/experimental/next_thought/model_heads.py b/models/experimental/next_thought/model_heads.py index 00238115..00552f94 100644 --- a/models/experimental/next_thought/model_heads.py +++ b/models/experimental/next_thought/model_heads.py @@ -1,33 +1,29 @@ """ The latent to variable length sequence decoder. """ -import torch -from models.experimental.next_thought.layers import ( - LatentSpaceDecoder, - LatentSpaceQuery -) +import torch -from models.embedding_models import GenericEmbedder from models.components.layers.transformer_blocks import GenericTransformerBlock -from models.components.positional_encoding import build_positional_encodings +from models.model_heads import HeadInterface - -class VariableLengthLatentDecoder(torch.nn.Module): +class VariableLengthLatentDecoder(HeadInterface): """ Given a latent space representation, decode it into a sequence. This should be similar to how VLMs work (i.e. have an encoder for the latent space and query it at each step to generate the next token). """ + def __init__(self, model_cfg, embedding_model): super().__init__() self.model_cfg = model_cfg self.latent_decoder = torch.nn.Linear( in_features=model_cfg["latent_dim"], - out_features=model_cfg["embedding_dim"] * model_cfg["lm_head"]["latent_decoded_into"], - bias=False + out_features=model_cfg["embedding_dim"] + * model_cfg["lm_head"]["latent_decoded_into"], + bias=False, ) self.token_embedder = embedding_model.token_embedder @@ -38,20 +34,19 @@ def __init__(self, model_cfg, embedding_model): GenericTransformerBlock( hidden_dim=model_cfg["embedding_dim"], context_window=model_cfg["context_window"], - use_rope=False, ffn_cfg=model_cfg["lm_head"]["standard_ffn_block"], attn_cfg=model_cfg["lm_head"]["standard_attn_block"], - ) for _ in range(model_cfg["lm_head"]["num_layers"]) + ) + for _ in range(model_cfg["lm_head"]["num_layers"]) ] ) self.lm_head = torch.nn.Linear( in_features=model_cfg["embedding_dim"], out_features=model_cfg["vocab_size"], - bias=False + bias=False, ) - def forward(self, x, x_raw=None): """ forward @@ -59,7 +54,11 @@ def forward(self, x, x_raw=None): # decode latent into tokens x = self.latent_decoder(x) # reshape - x = x.view(x.size(0), self.model_cfg["lm_head"]["latent_decoded_into"], self.model_cfg["embedding_dim"]) + x = x.view( + x.size(0), + self.model_cfg["lm_head"]["latent_decoded_into"], + self.model_cfg["embedding_dim"], + ) # encode the target tokens with the embedder (w/o gradient) y = self.token_embedder(x_raw) @@ -75,6 +74,6 @@ def forward(self, x, x_raw=None): x = layer(x) # pass through lm head - x = self.lm_head(x[:, self.model_cfg["lm_head"]["latent_decoded_into"]:]) + x = self.lm_head(x[:, self.model_cfg["lm_head"]["latent_decoded_into"] :]) - return x, None \ No newline at end of file + return x, None diff --git a/models/generator.py b/models/generator.py index aed6c5bf..eafbc108 100644 --- a/models/generator.py +++ b/models/generator.py @@ -33,9 +33,9 @@ def generate(self, input_text, max_new_tokens, temperature=1.0, top_k=None): the sequence max_new_tokens times, feeding the predictions back into the model each time. Most likely you'll want to make sure to be in model.eval() mode of operation for this. """ - idx = self.model.embedding_model.tokenize_input(input_string=input_text, - add_eot=False, - truncate=True) + idx = self.model.embedding_model.tokenize_input( + input_string=input_text, add_eot=False, truncate=True + ) # push to device idx = torch.tensor(idx).unsqueeze(0).to(torch.device("cuda")) for _ in range(max_new_tokens): @@ -55,16 +55,14 @@ def generate(self, input_text, max_new_tokens, temperature=1.0, top_k=None): # apply softmax to convert logits to (normalized) probabilities probs = torch.nn.functional.softmax(logits, dim=-1) # sample from the distribution - # check if byte-level and if so, flatten + # check if byte-level and if so, flatten if len(probs.size()) == 4: B, S, S_c, H = probs.size() - probs = probs.view(B* S * S_c, H) + probs = probs.view(B * S * S_c, H) flattened = True else: flattened = False - - idx_next = torch.multinomial(probs, num_samples=1) # check if byte-level and if so, unflatten diff --git a/models/model_heads.py b/models/model_heads.py index ab1eb107..67e2f905 100644 --- a/models/model_heads.py +++ b/models/model_heads.py @@ -2,27 +2,79 @@ A collection of different model heads. """ +from typing import Literal + +import pydantic import torch from models.components.layers.normalization import build_normalization +class LMHeadConfig(pydantic.BaseModel): + """ + Head configuration + """ + + lm_head_type: str + + +class GenericLMHeadConfig(LMHeadConfig): + """ + Language Model Head configuration + """ + + lm_head_type: Literal["generic"] + normalization: str + bias: bool + + +class HeadInterface(torch.nn.Module): + """ + Interface for the head component of the model. + """ + + def forward(self, x) -> tuple[torch.Tensor, torch.Tensor | None]: + """ + This function should take the input tensor x as input, + and return the output tensor. + """ + raise NotImplementedError + + def inference(self, x): + """ + Pass the input through the model, then + Return the final token logits + Args: + x: torch.tensor(B, S, H) + Returns: + x: torch.tensor(B, V) + """ + return self.forward(x)[0][:, -1, :] + + class AutoregressiveLMHead(torch.nn.Module): """ Generic autoregressive language model head. """ - def __init__(self, model_cfg): + def __init__(self, hidden_dim, vocab_size, lm_head_cfg: GenericLMHeadConfig): + """ + Initialize the model. + Args: + hidden_dim: int + vocab_size: int + lm_head_cfg: LMHeadConfig + """ super().__init__() self.layer_norm = build_normalization( - normalization_name=model_cfg["lm_head"]["normalization"], - dim=model_cfg["hidden_dim"], - bias=model_cfg["lm_head"]["bias"], + normalization_name=lm_head_cfg.normalization, + dim=hidden_dim, + bias=lm_head_cfg.bias, ) self.linear = torch.nn.Linear( - in_features=model_cfg["hidden_dim"], - out_features=model_cfg["vocab_size"], - bias=model_cfg["lm_head"]["bias"], + in_features=hidden_dim, + out_features=vocab_size, + bias=lm_head_cfg.bias, ) def forward(self, x): @@ -41,14 +93,3 @@ def forward(self, x): x = self.linear(x) return x, None - - def inference(self, x): - """ - Pass the input through the model, then - Return the final token logits - Args: - x: torch.tensor(B, S, H) - Returns: - x: torch.tensor(B, V) - """ - return self.forward(x)[0][:, -1, :] diff --git a/models/model_shell.py b/models/model_shell.py index 330f6de3..4c6cc781 100644 --- a/models/model_shell.py +++ b/models/model_shell.py @@ -3,11 +3,27 @@ core model and LM head. """ +from typing import Literal + +import pydantic import torch from models import core_models, embedding_models, model_heads +class ModelShellConfig(pydantic.BaseModel): + """Config for the standard model shell""" + + model_shell_type: Literal["standard"] + core_model: core_models.CoreModelConfig + embedding_model: embedding_models.EmbedderConfig + model_head: model_heads.LMHeadConfig + hidden_dim: int + context_window: int + vocab_size: int + embedding_weight_tying: bool + positional_encoding_type: str + class ModelShell(torch.nn.Module): """ @@ -71,7 +87,9 @@ def inference(self, model_input): # check if input is string if isinstance(model_input, str): # use inference function of the embedding model - model_input = self.embedding_model.tokenize_input(model_input, truncate=True, add_eot=False) + model_input = self.embedding_model.tokenize_input( + model_input, truncate=True, add_eot=False + ) x = torch.tensor(model_input, device=self.device, dtype=torch.long).unsqueeze(0) x = self.embedding_model(model_input) @@ -94,9 +112,16 @@ def loglikelihood(self, prefixes, continuations): Returns: ll: torch.tensor(B) """ - total_strings = [f"{prefix} {cont}" for prefix, cont in zip(prefixes, continuations)] - input_tokens = [self.embedding_model.tokenize_input(string, truncate=True) for string in total_strings] - padded_batch, mask = self.embedding_model.pad_batch(input_tokens, direction="right") + total_strings = [ + f"{prefix} {cont}" for prefix, cont in zip(prefixes, continuations) + ] + input_tokens = [ + self.embedding_model.tokenize_input(string, truncate=True) + for string in total_strings + ] + padded_batch, mask = self.embedding_model.pad_batch( + input_tokens, direction="right" + ) input_tensor = torch.tensor(padded_batch, device=self.device, dtype=torch.long) logits, _ = self.forward(input_tensor) logits = logits[:, :-1].reshape(-1, logits.size(-1)) @@ -105,4 +130,4 @@ def loglikelihood(self, prefixes, continuations): mask = mask[:, 1:].reshape(-1).to(ll.device) ll = ll * mask ll = ll.view(input_tensor.size(0), -1).sum(dim=1) - return -ll \ No newline at end of file + return -ll diff --git a/models/utils.py b/models/utils.py index 57cbf3dd..c458102a 100644 --- a/models/utils.py +++ b/models/utils.py @@ -3,29 +3,31 @@ """ import pandas as pd + from models.model_shell import ModelShell + def analyze_shared_parameters(model1, model2): shared_params = 0 total_params1 = 0 total_params2 = 0 - + # Create dictionaries of all parameters for each model params1 = {id(p): p for p in model1.parameters()} params2 = {id(p): p for p in model2.parameters()} - + # Find shared parameters shared_ids = set(params1.keys()) & set(params2.keys()) - + # Count parameters for pid in params1: total_params1 += params1[pid].numel() if pid in shared_ids: shared_params += params1[pid].numel() - + for pid in params2: total_params2 += params2[pid].numel() - + return shared_params, (total_params1 + total_params2 - shared_params) @@ -37,8 +39,10 @@ def print_model_stats(model: ModelShell): total_params = sum(p.numel() for p in model.parameters()) # Check if the parameters are shared - - _, lm_head_and_embeddings_params = analyze_shared_parameters(model.embedding_model, model.model_head) + + _, lm_head_and_embeddings_params = analyze_shared_parameters( + model.embedding_model, model.model_head + ) core_model_params = total_params - lm_head_and_embeddings_params # Format the numbers for better readability From f49d43a9cdb083ac86173f977a7179627ba9a3e9 Mon Sep 17 00:00:00 2001 From: Dylan Hillier Date: Mon, 19 Aug 2024 18:43:09 +0800 Subject: [PATCH 2/3] resolving pre-commit changes --- models/embedding_models.py | 2 +- models/experimental/hugging_face.py | 4 +- models/model_shell.py | 2 +- train.py | 32 ++-- trainers/base_trainer.py | 205 +++++++++++++---------- trainers/build_trainers.py | 127 ++++++-------- trainers/config.py | 35 ++++ trainers/datasets.py | 172 ++++++++++++++----- trainers/evaluation.py | 67 ++++++++ trainers/evaluator.py | 13 -- trainers/loss_fn.py | 36 ++-- trainers/{optimizer.py => optimizers.py} | 49 +++++- trainers/prepare.py | 117 +++++++------ trainers/samplers.py | 30 +++- trainers/scheduler.py | 131 --------------- trainers/schedulers.py | 196 ++++++++++++++++++++++ trainers/utils.py | 165 +++++++++++------- 17 files changed, 873 insertions(+), 510 deletions(-) create mode 100644 trainers/config.py create mode 100644 trainers/evaluation.py delete mode 100644 trainers/evaluator.py rename trainers/{optimizer.py => optimizers.py} (57%) delete mode 100644 trainers/scheduler.py create mode 100644 trainers/schedulers.py diff --git a/models/embedding_models.py b/models/embedding_models.py index df3f3d12..8118ec09 100644 --- a/models/embedding_models.py +++ b/models/embedding_models.py @@ -118,7 +118,7 @@ class GenericEmbedder(EmbedderInterface): def __init__( self, - embedder_cfg: EmbedderConfig, + embedder_cfg: GenericEmbedderConfig, vocab_size: int, hidden_dim: int, context_window: int, diff --git a/models/experimental/hugging_face.py b/models/experimental/hugging_face.py index ed42cb53..ca6de3fa 100644 --- a/models/experimental/hugging_face.py +++ b/models/experimental/hugging_face.py @@ -8,12 +8,12 @@ from models.components.tokenizers.base_class import Tokenizer from models.core_models import CoreModelConfig -from models.embedding_models import EmbedderConfig, EmbedderInterface +from models.embedding_models import EmbedderInterface, GenericEmbedderConfig from models.model_heads import HeadInterface, LMHeadConfig from trainers.base_trainer import BaseTrainer -class HFEmbedderConfig(EmbedderConfig): +class HFEmbedderConfig(GenericEmbedderConfig): """ Configuration for the Hugging Face model. """ diff --git a/models/model_shell.py b/models/model_shell.py index 4c6cc781..e94f7603 100644 --- a/models/model_shell.py +++ b/models/model_shell.py @@ -16,7 +16,7 @@ class ModelShellConfig(pydantic.BaseModel): model_shell_type: Literal["standard"] core_model: core_models.CoreModelConfig - embedding_model: embedding_models.EmbedderConfig + embedding_model: embedding_models.GenericEmbedderConfig model_head: model_heads.LMHeadConfig hidden_dim: int context_window: int diff --git a/train.py b/train.py index f41567d9..09494e7f 100644 --- a/train.py +++ b/train.py @@ -1,20 +1,25 @@ """ The main training code """ + import os import hydra +import torch +import torch.multiprocessing as mp +from torch.distributed import destroy_process_group from models.build_models import build_model -from trainers.build_trainers import build_trainer, ddp_setup -from trainers import base_trainer -from trainers.utils import create_folder_structure, init_print_override, restore_print_override from models.utils import print_model_stats - -import torch -from torch.distributed import destroy_process_group -import torch.multiprocessing as mp +from trainers import base_trainer +from trainers.build_trainers import build_trainer, ddp_setup from trainers.prepare import prepare_data +from trainers.utils import ( + create_folder_structure, + init_print_override, + restore_print_override, +) + def ddp_main(rank, world_size, cfg): """ @@ -35,14 +40,12 @@ def ddp_main(rank, world_size, cfg): print_model_stats(model) # load the relevant trainer trainer: base_trainer.BaseTrainer = build_trainer( - cfg=cfg, - model=model, - gpu_id=rank + cfg=cfg, model=model, gpu_id=rank ) print(f"Rank{rank} Trainer built") # train the model trainer.train() - + finally: # clean up destroy_process_group() @@ -54,19 +57,18 @@ def ddp_main(rank, world_size, cfg): @hydra.main(config_path="configs", config_name="train") def main(cfg): world_size = torch.cuda.device_count() - + if "full_configs" in cfg: cfg = cfg["full_configs"] cfg["general"]["paths"]["data_dir"] = hydra.utils.to_absolute_path( cfg["general"]["paths"]["data_dir"] - ) # must be done before multiprocessing or else the path is wrong? + ) # must be done before multiprocessing or else the path is wrong? create_folder_structure(path_config=cfg["general"]["paths"]) - # process data + # process data prepare_data(cfg) - mp.spawn( ddp_main, args=(world_size, cfg), diff --git a/trainers/base_trainer.py b/trainers/base_trainer.py index b32019dd..e5e72364 100644 --- a/trainers/base_trainer.py +++ b/trainers/base_trainer.py @@ -1,25 +1,20 @@ """Trainer class for training models with Next Token Prediction""" import time +from contextlib import nullcontext +import numpy as np import torch import wandb from omegaconf import OmegaConf +from torch.nn.parallel import DistributedDataParallel as DDP from torch.profiler import ProfilerActivity, profile, record_function -from copy import deepcopy -from contextlib import nullcontext +from high_level_configs import GeneralConfig from models import model_shell -from trainers import datasets as train_dataloader -from trainers import utils - -from trainers.evaluator import train_eval - -import numpy as np -from itertools import islice -from torch.nn.parallel import DistributedDataParallel as DDP -from torch.utils.data.distributed import DistributedSampler -from torch.utils.data import SequentialSampler +from trainers import evaluation, utils +from trainers.config import TrainerConfig +from trainers.optimizers import OptimizerConfig, configure_nanoGPT_optimizer from trainers.utils import aggregate_value, print_evaluation_results @@ -32,64 +27,78 @@ class BaseTrainer: def __init__( self, - cfg, + original_cfg: OmegaConf, + general_cfg: GeneralConfig, + training_cfg: TrainerConfig, + model_cfg: model_shell.ModelShellConfig, model: model_shell.ModelShell, - optimizer, + optimizer_cfg: OptimizerConfig, + evaluation_cfg: evaluation.EvaluationConfig, train_dataloader, val_dataloader, loss_fn, - gpu_id=None, + gpu_id=None, lr_scheduler=None, dropout_scheduler=None, ) -> None: + self.original_cfg = original_cfg + self.general_cfg = general_cfg + self.training_cfg = training_cfg + self.model_cfg = model_cfg + self.optimizer_cfg = optimizer_cfg + self.evaluation_cfg = evaluation_cfg self.model = model - if gpu_id is not None: # using ddp + if gpu_id is not None: # using ddp self.dist = True - self.DDP_model = DDP(self.model, device_ids=[gpu_id]) + self.ddp_model = DDP(self.model, device_ids=[gpu_id]) else: self.dist = False - self.DDP_model = model - self.gpu_id = gpu_id - self.optimizer = optimizer + self.ddp_model = model + self.gpu_id = gpu_id + self.optimizer = configure_nanoGPT_optimizer(self.model, optimizer_cfg) self.lr_scheduler = lr_scheduler self.dropout_scheduler = dropout_scheduler self.train_dataloader = train_dataloader self.val_dataloader = val_dataloader self.loss_fn = loss_fn - self.cfg = cfg - #assert self.cfg["trainer"]["training"]["gradient_accumulation_steps"] % torch.cuda.device_count() == 0, "Gradient Accumulation Steps must be divisible by the number of GPUs" - self.gradient_accumulation_steps = cfg["trainer"]["training"][ - "gradient_accumulation_steps" - ] #// torch.cuda.device_count() ## divide by number of GPUs to maximise throughput + self.gradient_accumulation_steps = ( + training_cfg.gradient_accumulation_steps + ) # // torch.cuda.device_count() ## divide by number of GPUs to maximise throughput self.scaler = None - self.use_wandb = cfg["general"]["logging"]["wandb_log"] - self.checkpoint_dir = cfg["general"]["paths"]["checkpoint_dir"] + self.use_wandb = general_cfg.logging.wandb_log + self.checkpoint_dir = general_cfg.paths.checkpoint_dir self.cached_sets = {"train": {}, "val": {}} - self.batch_size = cfg["trainer"]["training"]["batch_size"] ## new + self.batch_size = training_cfg.batch_size ## new # For training, always force the device to be cuda - #assert torch.cuda.is_available(), "CUDA must be available for training" + # assert torch.cuda.is_available(), "CUDA must be available for training" self.ctx = self._setup_ctx() - if self.use_wandb and (self.gpu_id == 0 or not self.dist): ## ensures that only the first GPU logs to wandb + if self.use_wandb and ( + self.gpu_id == 0 or not self.dist + ): ## ensures that only the first GPU logs to wandb self._setup_logging() - if cfg.trainer.training.run_profiler and (self.gpu_id == 0 or not self.dist): ## ensures that only the first GPU runs the profiler + if training_cfg.run_profiler and ( + self.gpu_id == 0 or not self.dist + ): ## ensures that only the first GPU runs the profiler self.run_profile() raise SystemExit def _setup_logging(self): # set run name run_name = ( - f"{self.cfg.model['model_shell_type']}" - f"_{self.cfg.model['core_model']['core_model_type']}" - f"_{self.cfg.trainer['dataset']}_{self.cfg.model['embedder']['embedding_model_type']}" - f"_{self.cfg.model['vocab_size']}" + f"{self.model_cfg.model_shell_type}" + f"_{self.model_cfg.core_model.core_model_type}" + f"_{self.training_cfg.dataset}_{self.model_cfg.embedder.embedding_model_type}" + f"_{self.model_cfg.vocab_size}" ) wandb.init( - project=self.cfg.general.logging.wandb_project, - config=OmegaConf.to_container(self.cfg), + project=self.general_cfg.logging.wandb_project, + config=OmegaConf.to_container(self.original_cfg), name=run_name, ) - wandb.init(project=self.cfg.general.logging.wandb_project) + wandb.init( + project=self.general_cfg.logging.wandb_project + ) ## why does this happen twice??? print("wand_b_initted") def _setup_ctx(self): @@ -109,16 +118,15 @@ def _setup_scaler(self, dtype=torch.float16): """Setup the scaler""" self.scaler = torch.cuda.amp.GradScaler(enabled=dtype == torch.float16) - @torch.no_grad() def estimate_performance(self, eval_iters=None): """Estimate the loss""" if eval_iters is None: - eval_iters = self.cfg.trainer.training.eval_iters + eval_iters = self.evaluation_cfg.eval_iters eval_results = {} self.model.eval() - # eval on val set + # eval on val set losses = [] perplexities = [] for i, (x, y) in enumerate(self.val_dataloader): @@ -132,36 +140,37 @@ def estimate_performance(self, eval_iters=None): losses.append(loss.item()) # compute perplexity - perplexity = torch.exp(loss) # since seq len is always the same during training anyway + perplexity = torch.exp( + loss + ) # since seq len is always the same during training anyway perplexities.append(perplexity.item()) - - if i >= eval_iters: break - - avg_loss = aggregate_value(np.mean(losses), self.cfg.general.device) + + avg_loss = aggregate_value(np.mean(losses), self.general_cfg.device) eval_results["Loss"] = avg_loss - avg_perplexity = aggregate_value(np.mean(perplexities), self.cfg.general.device) + avg_perplexity = aggregate_value(np.mean(perplexities), self.general_cfg.device) eval_results["Perplexity"] = avg_perplexity - evaluator_results = {} - for evaluator in self.cfg.trainer["eval"]: - evaluator_results[evaluator["evaluator"]] = train_eval(evaluator, self.model) + for evaluator_dict in self.evaluation_cfg.evaluators: + evaluator = evaluation.get_evaluator_config(evaluator_dict) + evaluator_results[evaluator.evaluator] = evaluation.train_eval( + evaluator, self.model + ) # recurse over metrics to prepend the evaluator name as a prefix relabeled_results = {} - for metric in evaluator_results[evaluator["evaluator"]]: - relabeled_results[f"{evaluator['evaluator']}/{metric}"] = evaluator_results[evaluator["evaluator"]][metric] - evaluator_results[evaluator["evaluator"]] = relabeled_results + for metric in evaluator_results[evaluator.evaluator]: + relabeled_results[f"{evaluator.evaluator}/{metric}"] = ( + evaluator_results[evaluator.evaluator][metric] + ) + evaluator_results[evaluator.evaluator] = relabeled_results self.model.train() return eval_results, evaluator_results - - - - def _run_step(self, epoch=0): + def _run_step(self): """Run a single step of training with gradient accumulation.""" self.optimizer.zero_grad() # Clear gradients at the start of accumulation @@ -170,15 +179,21 @@ def _run_step(self, epoch=0): y = y.to(self.gpu_id if self.gpu_id is not None else self.model.device) # Enable or disable gradient synchronization based on the need for accumulation - if self.dist and hasattr(self.DDP_model, 'no_sync'): - context_manager = self.DDP_model.no_sync() if i != self.gradient_accumulation_steps - 1 else nullcontext() + if self.dist and hasattr(self.ddp_model, "no_sync"): + context_manager = ( + self.ddp_model.no_sync() + if i != self.gradient_accumulation_steps - 1 + else nullcontext() + ) else: context_manager = nullcontext() with context_manager: with self.ctx: # Assuming self.ctx is something like torch.cuda.amp.autocast - output, aux_loss = self.DDP_model(x) - loss = self.loss_fn(output, y) #+ (aux_loss if aux_loss is not None else 0) + output, aux_loss = self.ddp_model(x) + loss = self.loss_fn( + output, y + ) # + (aux_loss if aux_loss is not None else 0) if aux_loss is not None: loss += aux_loss # Scale loss to simulate larger effective batch size @@ -186,13 +201,17 @@ def _run_step(self, epoch=0): self.scaler.scale(loss).backward() # Step and update only after accumulating enough gradients - if (i + 1) % self.gradient_accumulation_steps == 0 or (i + 1) == len(self.train_dataloader): - if self.cfg.trainer.optimizer.grad_clip > 0: + if (i + 1) % self.gradient_accumulation_steps == 0 or (i + 1) == len( + self.train_dataloader + ): + if self.optimizer_cfg.grad_clip > 0: # Unscale the gradients of the optimizer's assigned params in-place self.scaler.unscale_(self.optimizer) # Clip the gradients with normalization - torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.trainer.optimizer.grad_clip) - + torch.nn.utils.clip_grad_norm_( + self.model.parameters(), self.optimizer_cfg.grad_clip + ) + # Perform a single optimization step self.scaler.step(self.optimizer) self.scaler.update() @@ -214,10 +233,10 @@ def run_profile(self): ) as prof: for i in range(10): if i <= 3: - self._run_step() ## set the 'epoch' to ensure shuffle + self._run_step() ## set the 'epoch' to ensure shuffle else: with record_function("_run_step"): - self._run_step() ## set the 'epoch' to ensure shuffle + self._run_step() ## set the 'epoch' to ensure shuffle # place profile in dictionary backwards_prof = prof.key_averages().table(sort_by="self_cpu_time_total") print(backwards_prof) @@ -245,7 +264,7 @@ def _save_model(self, iter_num=0): "model": self.model.state_dict(), "optimizer": self.optimizer.state_dict(), "iter_num": iter_num, - "config": self.cfg, + "config": self.original_cfg, } checkpoint_path = f"{self.checkpoint_dir}/ckpt_{iter_num}.pt" print(f"saving checkpoint to {checkpoint_path}") @@ -253,7 +272,7 @@ def _save_model(self, iter_num=0): def run_training_loop(self): """Run the training loop""" - for iter_num in range(self.cfg.trainer.training.max_iters): + for iter_num in range(self.training_cfg.max_iters): start_time = time.time() if self.lr_scheduler is not None: lr = self.lr_scheduler.step(self.optimizer, iter_num) @@ -262,51 +281,53 @@ def run_training_loop(self): dropout = self.dropout_scheduler.step(self.model, iter_num) # estimate the loss on the train/val sets if ( - not iter_num % self.cfg.trainer.training.eval_interval - ): # run on first iter to prevent bugs causing it to crash + not iter_num % self.evaluation_cfg.eval_interval + ): # run on first iter to prevent bugs causing it to crash eval_results, benchmark_results = self.estimate_performance() # print the evals as table # evals format is d1: type d2: train/val print_evaluation_results( - iter_num=iter_num, - eval_results=eval_results, - benchmark_results=benchmark_results + iter_num=iter_num, + eval_results=eval_results, + benchmark_results=benchmark_results, ) # Log to wandb - if (self.gpu_id == 0 or self.gpu_id is None) and self.use_wandb: # ensure only the first GPU logs + if ( + self.gpu_id == 0 or self.gpu_id is None + ) and self.use_wandb: # ensure only the first GPU logs log_dict = {"iter": iter_num, "lr": lr, "dropout": dropout} - log_dict.update(eval_results) # Directly add evals to the log dictionary - log_dict.update({k:v for k,v in benchmark_results.items()}) # Add benchmark results to the log dictionary + log_dict.update( + eval_results + ) # Directly add evals to the log dictionary + log_dict.update( + {k: v for k, v in benchmark_results.items()} + ) # Add benchmark results to the log dictionary wandb.log(log_dict) # save checkpoints if ( - not iter_num % self.cfg.trainer.training.checkpoint_interval + not iter_num % self.training_cfg.checkpoint_interval and iter_num > 0 - and ( - self.gpu_id == 0 - or self.gpu_id == None - ) ## ensure only the first GPU prints + and (self.gpu_id in (0, None)) ## ensure only the first GPU prints ): self._save_model(iter_num) - - loss = self._run_step() ## set the 'epoch' to ensure shuffle + loss = self._run_step() ## set the 'epoch' to ensure shuffle end_time = time.time() - if not iter_num % self.cfg.trainer.training.log_interval and iter_num > 0: + if not iter_num % self.training_cfg.log_interval and iter_num > 0: lossf = loss * self.gradient_accumulation_steps - ## uncomment the following line to print the loss on all GPUs - # print(f"GPU {self.gpu_id}: step {iter_num}: loss {lossf:.4f}, lr {lr:.1e}, dt {end_time-start_time:.1f}s") - ## aggregate the loss across all GPUs - lossf = aggregate_value(lossf, self.cfg.general.device) + lossf = aggregate_value(lossf, self.general_cfg.device) ## print and log the result only on the first GPU after aggregation - print(f"All GPU(s): step {iter_num}: loss {lossf:.4f}, lr {lr:.1e}, dt {end_time-start_time:.1f}s") + print( + f"All GPU(s): step {iter_num}: loss {lossf:.4f}," + f" lr {lr:.1e}, dt {end_time-start_time:.1f}s" + ) if (self.gpu_id == 0 or self.gpu_id is None) and self.use_wandb: wandb.log( { @@ -317,7 +338,9 @@ def run_training_loop(self): } ) # save the final model - if self.gpu_id == 0 or self.gpu_id is None: ## ensure only the first GPU saves the model + if ( + self.gpu_id == 0 or self.gpu_id is None + ): ## ensure only the first GPU saves the model self._save_model(iter_num) def train(self, seed=42): diff --git a/trainers/build_trainers.py b/trainers/build_trainers.py index d5c45db4..445bdbfc 100644 --- a/trainers/build_trainers.py +++ b/trainers/build_trainers.py @@ -9,6 +9,7 @@ from torch.distributed import init_process_group from models.experimental.hugging_face import MockTrainer +from trainers import config, optimizers, schedulers from trainers.base_trainer import BaseTrainer from trainers.datasets import ( BaseDataset, @@ -21,15 +22,7 @@ masked_cross_entropy_loss_fn, next_token_mlm_loss_fn, ) -from trainers.optimizer import configure_nanoGPT_optimizer from trainers.samplers import BaseSampler -from trainers.scheduler import ( - CosineLRScheduler, - DropoutScheduler, - LinearDropoutScheduler, - LRScheduler, - TriangleDropoutScheduler, -) def ddp_setup(rank, world_size): @@ -49,74 +42,56 @@ def ddp_setup(rank, world_size): torch.cuda.set_device(rank) -OPTIMIZER_DICT = { - "nanoGPTadamW": lambda model, trainer_cfg: configure_nanoGPT_optimizer( - model=model, - weight_decay=trainer_cfg["weight_decay"], - learning_rate=trainer_cfg["lr"], - betas=(trainer_cfg["beta1"], trainer_cfg["beta2"]), - ), - "adamW": lambda model, trainer_cfg: torch.optim.AdamW( - model.parameters(), - lr=trainer_cfg["lr"], - betas=(trainer_cfg["beta1"], trainer_cfg["beta2"]), - weight_decay=trainer_cfg["weight_decay"], - ), -} - - -def build_optimizer(model, optimizer_config): +def build_optimizer(model, optimizer_config: optimizers.OptimizerConfig): """ Given the optimizer config, build the optimizer """ - return OPTIMIZER_DICT[optimizer_config["name"]]( - model=model, trainer_cfg=optimizer_config - ) - - -SCHEDULER_DICT = { - "cosine": lambda trainer_cfg: CosineLRScheduler( - warmup_iters=trainer_cfg["training"]["warmup_iters"], - decay_iters=trainer_cfg["training"]["lr_decay_iters"], - lr=trainer_cfg["optimizer"]["lr"], - min_lr=trainer_cfg["optimizer"]["min_lr"], - ), - "constant": lambda trainer_cfg: LRScheduler( - lr=trainer_cfg["optimizer"]["lr"], - ), -} - - -def build_lr_scheduler(trainer_cfg): + match optimizer_config.name: + case optimizers.OptimizerTypeNames.NANOGPT_ADAMW: + optimizer_config: optimizers.NanoGPTAdamWConfig = optimizer_config + return optimizers.configure_nanoGPT_optimizer( + model=model, + optimizer_cfg=optimizer_config, + ) + case optimizers.OptimizerTypeNames.ADAMW: + optimizer_config: optimizers.AdamWConfig = optimizer_config + return torch.optim.AdamW( + model.parameters(), + lr=optimizer_config.lr, + betas=(optimizer_config.beta1, optimizer_config.beta2), + weight_decay=optimizer_config.weight_decay, + ) + + +def build_lr_scheduler(scheduler_cfg: schedulers.LRSchedulerConfig): """ Given the trainer config, build the LR scheduler.build_model """ - return SCHEDULER_DICT[trainer_cfg["lr_scheduler"]["name"]](trainer_cfg=trainer_cfg) + match scheduler_cfg.lr_scheduler_type: + case schedulers.LRSchedulerNames.CONSTANT: + return schedulers.LRScheduler(lr_scheduler_cfg=scheduler_cfg) + case schedulers.LRSchedulerNames.COSINE: + return schedulers.CosineLRScheduler(lr_scheduler_cfg=scheduler_cfg) -def build_dropout_scheduler(trainer_cfg): +def build_dropout_scheduler(scheduler_cfg: schedulers.DropoutSchedulerConfig): """ Given the trainer config, build the dropout scheduler. """ - if trainer_cfg["dropout_scheduler"]["dropout_type"] == "constant": - return DropoutScheduler(trainer_cfg["dropout_scheduler"]["dropout"]) - if trainer_cfg["dropout_scheduler"]["dropout_type"] == "linear": - return LinearDropoutScheduler( - start_dropout_p=trainer_cfg["dropout_scheduler"]["start_dropout_p"], - end_dropout_p=trainer_cfg["dropout_scheduler"]["end_dropout_p"], - start_iter=trainer_cfg["dropout_scheduler"]["start_iter"], - end_iter=trainer_cfg["dropout_scheduler"]["end_iter"], - ) - if trainer_cfg["dropout_scheduler"]["dropout_type"] == "triangle": - return TriangleDropoutScheduler( - dropout_trough=trainer_cfg["dropout_scheduler"]["dropout_trough"], - dropout_peak=trainer_cfg["dropout_scheduler"]["dropout_peak"], - num_iterations=trainer_cfg["dropout_scheduler"]["num_iterations"], - num_cycles=trainer_cfg["dropout_scheduler"]["num_cycles"], - ) - raise NotImplementedError( - f"dropout scheduler {trainer_cfg['dropout_scheduler']['dropout_type']} not implemented." - ) + match scheduler_cfg.dropout_type: + case "constant": + scheduler_cfg: schedulers.DropoutSchedulerConfig = scheduler_cfg + return schedulers.DropoutScheduler(dropout_cfg=scheduler_cfg) + case "linear": + scheduler_cfg: schedulers.LinearDropoutSchedulerConfig = scheduler_cfg + return schedulers.LinearDropoutScheduler( + dropout_cfg=scheduler_cfg, + ) + case "triangle": + scheduler_cfg: schedulers.TriangleDropoutSchedulerConfig = scheduler_cfg + return schedulers.TriangleDropoutScheduler( + dropout_cfg=scheduler_cfg, + ) DATASET_DICT: dict[str, DatasetInterface] = { @@ -126,7 +101,7 @@ def build_dropout_scheduler(trainer_cfg): } -def build_dataset(cfg, split): +def build_dataset(cfg, split) -> BaseDataset: """ Given the config, build the dataloader """ @@ -136,7 +111,7 @@ def build_dataset(cfg, split): DATASAMPLER_DICT = {"standard": BaseSampler} -def build_datasampler(dataset, sampling, batch_size): +def build_datasampler(dataset, sampling, batch_size) -> BaseSampler: """ Given the dataset and the sampling method, build the dataloader """ @@ -166,20 +141,20 @@ def build_loss_fn(loss_fn_name): } -def build_trainer(cfg, model, gpu_id): +def build_trainer(cfg: config.TrainConfig, model, gpu_id): """ Given a config, this function builds a trainer and all relevant components of it. """ # build optimizer - optimizer = build_optimizer(model=model, optimizer_config=cfg.trainer["optimizer"]) + optimizer = build_optimizer(model=model, optimizer_config=cfg.optimizer) # build LR scheduler - lr_scheduler = build_lr_scheduler(trainer_cfg=cfg.trainer) + lr_scheduler = build_lr_scheduler(scheduler_cfg=cfg.lr_scheduler) # build dropout scheduler - dropout_scheduler = build_dropout_scheduler(trainer_cfg=cfg.trainer) + dropout_scheduler = build_dropout_scheduler(scheduler_cfg=cfg.dropout_scheduler) # build dataloder train_dataset = build_dataset(cfg=cfg, split="train") @@ -202,23 +177,23 @@ def build_trainer(cfg, model, gpu_id): # wrap in dataloaders train_dataloader = torch.utils.data.DataLoader( dataset=train_dataset, - batch_size=cfg["trainer"]["training"]["batch_size"], + batch_size=cfg.training.batch_size, sampler=train_data_sampler, num_workers=1, ) val_dataloader = torch.utils.data.DataLoader( dataset=val_dataset, - batch_size=cfg["trainer"]["training"]["batch_size"], + batch_size=cfg.training.batch_size, sampler=val_data_sampler, num_workers=1, ) # build loss function - loss_fn = build_loss_fn(loss_fn_name=cfg.trainer["loss_fn"]["name"]) + loss_fn = build_loss_fn(loss_fn_name=cfg.loss_fn.loss_fn_type) # build the trainer - print(cfg.trainer["training"]["trainer_type"]) - trainer = TRAINER_DICT[cfg.trainer["training"]["trainer_type"]]( + print(cfg.training.trainer_type) + trainer = TRAINER_DICT[cfg.training.trainer_type]( cfg=cfg, model=model, optimizer=optimizer, diff --git a/trainers/config.py b/trainers/config.py new file mode 100644 index 00000000..3b32102f --- /dev/null +++ b/trainers/config.py @@ -0,0 +1,35 @@ +import pydantic + +from models import model_shell +from trainers.datasets import DatasetConfig +from trainers.evaluation import EvaluationConfig +from trainers.loss_fn import LossConfig +from trainers.optimizers import OptimizerConfig +from trainers.samplers import SamplerConfig +from trainers.schedulers import DropoutSchedulerConfig, LRSchedulerConfig + + +class TrainerConfig(pydantic.BaseModel): + """Base Trainer Configuration""" + + trainer_type: str + dataset: str = "openwebtext" + batch_size: int = 24 + gradient_accumulation_steps: int = 20 + max_iters: int = 30000 + log_interval: int = 10 + lr_decay_iters: int = 30000 + checkpoint_interval: int = 5000 + run_profiler: bool = False + + +class TrainConfig(pydantic.BaseModel): + model: model_shell.ModelShellConfig + training: TrainerConfig + eval: EvaluationConfig + optimizer: OptimizerConfig + lr_scheduler: LRSchedulerConfig + dropout_scheduler: DropoutSchedulerConfig + dataset: DatasetConfig + sampler: SamplerConfig + loss_fn: LossConfig diff --git a/trainers/datasets.py b/trainers/datasets.py index 875942ce..1fe244c4 100644 --- a/trainers/datasets.py +++ b/trainers/datasets.py @@ -2,48 +2,97 @@ A collection of dataloaders """ +import enum import os import numpy as np +import pydantic import torch -from tqdm import tqdm -import random -from models.embedding_models import GenericEmbedder -from trainers.utils import load_data +from high_level_configs import GeneralConfig +from models.experimental.byte_level import byte_model_shell +from models.model_shell import ModelShellConfig +from trainers.config import TrainerConfig +from trainers.utils import DatasetEnum +class DatasetTypeNames(str, enum.Enum): + """Possible dataset types""" + + STANDARD = "standard" + BYTE_POOLED = "byte_pooling" + DUAL_BYTE_POOLED = "dual_byte_pooling" + + +class DatasetConfig(pydantic.BaseModel): + """Configuration for the dataset""" + + dataset_type: DatasetTypeNames + dataset: DatasetEnum + + +class BaseDatasetConfig(pydantic.BaseModel): + """Configuration for dataset specifying the type""" + + dataset: DatasetEnum + dataset_type: DatasetTypeNames.STANDARD + + +class BytePoolingDatasetConfig(pydantic.BaseModel): + """Configuration for BytePooling dataset""" + + dataset: DatasetEnum + dataset_type: DatasetTypeNames.BYTE_POOLED + + +class DualBytePoolingDatasetConfig(pydantic.BaseModel): + """Configuration for DualBytePooling dataset""" + + dataset: DatasetEnum + dataset_type: DatasetTypeNames.DUAL_BYTE_POOLED + class DatasetInterface(torch.utils.data.Dataset): """ A basic interface to be used by the remaining datasets """ - def __init__(self, split, cfg): + + def __init__( + self, + split, + dataset_cfg: DatasetConfig, + model_cfg: ModelShellConfig, + trainer_cfg: TrainerConfig, + general_cfg: GeneralConfig, + ): """ Arguments: cfg: the train script cfg """ super().__init__() - self.cfg = cfg - self.dataset_name = self.cfg["trainer"]["dataset"] - self.context_window = self.cfg["model"]["context_window"] + self.dataset_name = dataset_cfg.dataset + self.context_window = model_cfg.context_window self.data_path = os.path.join( - self.cfg["general"]["paths"]["data_dir"], + general_cfg.paths.data_dir, self.dataset_name, - f'{self.cfg["model"]["embedder"]["tokenizer_type"]}-{self.cfg["model"]["vocab_size"]}-{self.cfg["trainer"]["dataloader"]["name"]}', - f"{split}.bin" + ( + f"{model_cfg.embedding_model.tokenizer_type}" + f"-{model_cfg.vocab_size}-{trainer_cfg.dataset}" + ), + f"{split}.bin", ) self._load_data() self.dataset_len = len(self.data) - self.context_window - def _load_data(self): """ Get data """ if not os.path.exists(self.data_path): - raise FileNotFoundError(f"{self.data_path} does not exist, preprocess the data first") + raise FileNotFoundError( + f"{self.data_path} does not exist, preprocess the data first" + ) self.data = np.memmap( self.data_path, dtype=np.uint16, @@ -55,24 +104,26 @@ def __len__(self): Return dataset length """ return self.dataset_len - + def __getitem__(self, idx): raise NotImplementedError - + + class BaseDataset(DatasetInterface): """ Simple base dataloader for standard gpt-2'esk architectures and training. """ - def __init__(self, split, cfg): - super().__init__(split, cfg) - def __getitem__(self, idx): """ Get a batch of data """ - x = torch.from_numpy((self.data[idx: idx + self.context_window]).astype(np.int64)) - y = torch.from_numpy((self.data[idx + 1: idx + 1 + self.context_window]).astype(np.int64)) + x = torch.from_numpy( + (self.data[idx : idx + self.context_window]).astype(np.int64) + ) + y = torch.from_numpy( + (self.data[idx + 1 : idx + 1 + self.context_window]).astype(np.int64) + ) return x, y @@ -80,11 +131,19 @@ class BytePoolingDataset(DatasetInterface): """ Simple byte-level dataset """ - def __init__(self, split, cfg): + + def __init__( + self, + split, + dataset_cfg: DatasetConfig, + model_cfg: byte_model_shell.ByteShellConfig, + training_cfg: TrainerConfig, + general_cfg: GeneralConfig, + ): + super().__init__(split, dataset_cfg, model_cfg, training_cfg, general_cfg) self.loading_shape = None - super().__init__(split, cfg) - # force parent init self._load_data() + self.model_cfg = model_cfg def _load_data(self): """ @@ -96,7 +155,10 @@ def _load_data(self): dtype=np.uint16, mode="r", ) - self.loading_shape = (len(data)// self.cfg["model"]["embedder"]["byte_context_window"], self.cfg["model"]["embedder"]["byte_context_window"]) + self.loading_shape = ( + len(data) // self.model_cfg.byte_context_window, + self.model_cfg.byte_context_window, + ) data = None self.data = np.memmap( self.data_path, @@ -104,34 +166,49 @@ def _load_data(self): mode="r", shape=self.loading_shape, ) - + def __getitem__(self, idx): """ Get a batch of data """ - x = torch.from_numpy((self.data[idx: idx + self.context_window]).astype(np.int64)) - y = torch.from_numpy((self.data[idx + 1: idx + 1 + self.context_window]).astype(np.int64)) + x = torch.from_numpy( + (self.data[idx : idx + self.context_window]).astype(np.int64) + ) + y = torch.from_numpy( + (self.data[idx + 1 : idx + 1 + self.context_window]).astype(np.int64) + ) return x, y - + class DualBytePooling(DatasetInterface): """ Dataset for both byte-level and higher token level tokens simultaneously """ - def __init__(self, split, cfg): + + def __init__( + self, + split, + dataset_cfg: DatasetConfig, + model_cfg: byte_model_shell.ByteShellConfig, + training_cfg: TrainerConfig, + general_cfg: GeneralConfig, + ): self.loading_shape = None # overwrite datapath data_folder = os.path.join( - cfg["general"]["paths"]["data_dir"], - cfg["trainer"]["dataset"], - f'{cfg["model"]["embedder"]["tokenizer_type"]}-{cfg["model"]["vocab_size"]}-{cfg["trainer"]["dataloader"]["name"]}', + general_cfg.paths.data_dir, + dataset_cfg.dataset, + ( + f"{model_cfg.embedding_model.tokenizer_type}" + f"-{model_cfg.vocab_size}-{training_cfg.dataset}" + ), ) self.data_path_byte = os.path.join(data_folder, f"{split}_byte.bin") self.data_path_token = os.path.join(data_folder, f"{split}_token.bin") - super().__init__(split, cfg) - + super().__init__(split, dataset_cfg, model_cfg, training_cfg, general_cfg) # force parent init self._load_data() + self.model_cfg = model_cfg def _load_data(self): """ @@ -143,7 +220,10 @@ def _load_data(self): dtype=np.uint16, mode="r", ) - self.loading_shape = (len(data)// self.cfg["model"]["embedder"]["byte_context_window"], self.cfg["model"]["embedder"]["byte_context_window"]) + self.loading_shape = ( + len(data) // self.model_cfg.byte_context_window, + self.model_cfg.byte_context_window, + ) data = None self.data_byte = np.memmap( self.data_path_byte, @@ -156,17 +236,25 @@ def _load_data(self): dtype=np.uint16, mode="r", ) - + def __getitem__(self, idx): """ Get a batch of data from both the byte and higher token level """ # get byte level batch - x_byte = torch.from_numpy((self.data_byte[idx: idx + self.context_window]).astype(np.int64)) - #y_byte = torch.from_numpy((self.data_byte[idx + 1: idx + 1 + self.context_window]).astype(np.int64)) + x_byte = torch.from_numpy( + (self.data_byte[idx : idx + self.context_window]).astype(np.int64) + ) + # y_byte = torch.from_numpy( + # (self.data_byte[idx + 1: idx + 1 + self.context_window]).astype(np.int64) + # ) # get token level batch - #x_token = torch.from_numpy((self.data_token[idx: idx + self.context_window]).astype(np.int64)) - y_token = torch.from_numpy((self.data[idx + 1: idx + 1 + self.context_window]).astype(np.int64)) - return x_byte, y_token - + # x_token = torch.from_numpy( + # ( + # self.data_token[idx: idx + self.context_window] + # ).astype(np.int64)) + y_token = torch.from_numpy( + (self.data[idx + 1 : idx + 1 + self.context_window]).astype(np.int64) + ) + return x_byte, y_token diff --git a/trainers/evaluation.py b/trainers/evaluation.py new file mode 100644 index 00000000..ef9ac6ff --- /dev/null +++ b/trainers/evaluation.py @@ -0,0 +1,67 @@ +"""Code for running samples from the evaluation benchmarks""" + +import pydantic + +from evals.load_evaluators import load_evaluator + + +class EvaluationConfig(pydantic.BaseModel): + """Configuration for Evaluation during training""" + + eval_interval: int = 2000 + eval_iters: int = 500 + evaluators: list[dict] + + +class EvaluatorConfig(pydantic.BaseModel): + """Configuration for Evaluation during training""" + + evaluator: str + + +class MCQEvaluatorConfig(EvaluatorConfig): + """Configuration for Multiple Choice Question Evaluation""" + + evaluator: str = "mcq" + benchmarks: list[str] = ["winograd", "hellaswag", "arc", "mmlu", "blimp"] + num_samples: int = 1000 + + +class PROGEvaluatorConfig(EvaluatorConfig): + """Configuration for PROG Evaluation""" + + evaluator: str = "prog" + + +def get_evaluator_config(evaluator_dict): + """Get the evaluator config""" + evaluator_name = evaluator_dict["evaluator"] + if evaluator_name == "mcq": + return MCQEvaluatorConfig(**evaluator_dict) + elif evaluator_name == "prog": + return PROGEvaluatorConfig(**evaluator_dict) + else: + raise ValueError(f"Unknown evaluator: {evaluator_name}") + + +def train_eval(eval_cfg: EvaluatorConfig, model): + """Train the model""" + evaluator_name = eval_cfg.evaluator + + if evaluator_name == "mcq": + eval_cfg: MCQEvaluatorConfig = eval_cfg + mcq_evaluator = load_evaluator( + evaluator_name, + model, + benchmarks=eval_cfg.benchmarks, + num_samples=eval_cfg.num_samples, + ) + results = mcq_evaluator.evaluate() + return results + elif evaluator_name == "prog": + eval_cfg: PROGEvaluatorConfig = eval_cfg + prog_evaluator = load_evaluator(evaluator_name, model) + results = prog_evaluator.evaluate() + return results + else: + raise ValueError(f"Unknown evaluator: {evaluator_name}") diff --git a/trainers/evaluator.py b/trainers/evaluator.py deleted file mode 100644 index 510a3671..00000000 --- a/trainers/evaluator.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Code for running samples from the evaluation benchmarks""" - -from evals.load_evaluators import load_evaluator - -def train_eval(eval_cfg, model): - """Train the model""" - evaluator_name = eval_cfg["evaluator"] - kwargs = { - key: value for key, value in eval_cfg.items() if key != "evaluator" - } - evaluator = load_evaluator(evaluator_name, model, **kwargs) - results = evaluator.evaluate() - return results diff --git a/trainers/loss_fn.py b/trainers/loss_fn.py index 694b61de..8f281a49 100644 --- a/trainers/loss_fn.py +++ b/trainers/loss_fn.py @@ -3,21 +3,37 @@ Each loss function should take in output of a model and the target labels and return the loss value. This need not be the logits.""" -import time +from enum import Enum +import pydantic import torch -def masked_cross_entropy_loss_fn(logits, y, mask=None): +class LossFNName(str, Enum): + """Enum over possible loss functions""" + + CROSS_ENTROPY = "cross_entropy" + + +class LossConfig(pydantic.BaseModel): + """The type of loss function""" + + loss_fn_type: LossFNName + + +def masked_cross_entropy_loss_fn(logits, y, _=None): """Cross Entropy Loss Function""" - # mask the pad token from y - pad_token_id = 257 + # mask the pad token from y + pad_token_id = ( + 257 # TODO: Make this not dumb... no guarantee this will be the mask token??? + ) logits = logits.view(-1, logits.size(-1)) y = y.view(-1) - #return torch.nn.functional.cross_entropy(logits, y, weight=mask, ignore_index=-1) + # return torch.nn.functional.cross_entropy(logits, y, weight=mask, ignore_index=-1) return torch.nn.functional.cross_entropy(logits, y, ignore_index=pad_token_id) -def cross_entropy_loss_fn(logits, y, mask=None): + +def cross_entropy_loss_fn(logits, y, _=None): """Cross Entropy Loss Function""" logits = logits.view(-1, logits.size(-1)) y = y.view(-1) @@ -79,8 +95,8 @@ def compute_perplexity(logits, y, char_lengths, mask=None): return (torch.exp(loss)).mean().item() -def build_loss_fn(loss_fn_type: str): +def build_loss_fn(loss_config: LossConfig): """Build the loss function""" - if loss_fn_type == "cross_entropy": - return cross_entropy_loss_fn - raise ValueError(f"Loss function {loss_fn_type} not supported.") + match loss_config.loss_fn_type: + case LossFNName.CROSS_ENTROPY: + return cross_entropy_loss_fn diff --git a/trainers/optimizer.py b/trainers/optimizers.py similarity index 57% rename from trainers/optimizer.py rename to trainers/optimizers.py index b89d74a9..e698f8c8 100644 --- a/trainers/optimizer.py +++ b/trainers/optimizers.py @@ -2,13 +2,53 @@ A collection of optimizers used for training. """ +import enum import inspect +import pydantic import torch +class OptimizerTypeNames(str, enum.Enum): + """Possible types of Optimizers""" + + ADAMW = "AdamW" + NANOGPT_ADAMW = "nanoGPTadamW" + + +class OptimizerConfig(pydantic.BaseModel): + """ + Optimizer configuration + """ + + name: OptimizerTypeNames + lr: float = 0.0006 + min_lr: float = 6.0e-05 + decay_lr: bool = True + weight_decay: float | None = 0.1 + warmup_iters: int = 5000 + optimizer_type: str + grad_clip: float = 1.0 + + +class AdamWConfig(OptimizerConfig): + """The nano gpt optimizer configuration""" + + name: OptimizerTypeNames = OptimizerTypeNames.ADAMW + beta1: float = 0.9 + beta2: float = 0.95 + + +class NanoGPTAdamWConfig(OptimizerConfig): + """The nano gpt optimizer configuration""" + + name: OptimizerTypeNames = OptimizerTypeNames.NANOGPT_ADAMW + beta1: float = 0.9 + beta2: float = 0.95 + + # pylint: disable=invalid-name -def configure_nanoGPT_optimizer(model, weight_decay, learning_rate, betas): +def configure_nanoGPT_optimizer(model, optimizer_cfg: AdamWConfig): """Configure the optimizer for NanoGPT""" # start with all of the candidate parameters param_dict = {pn: p for pn, p in model.named_parameters()} @@ -19,7 +59,7 @@ def configure_nanoGPT_optimizer(model, weight_decay, learning_rate, betas): decay_params = [p for _, p in param_dict.items() if p.dim() >= 2] nodecay_params = [p for _, p in param_dict.items() if p.dim() < 2] optim_groups = [ - {"params": decay_params, "weight_decay": weight_decay}, + {"params": decay_params, "weight_decay": optimizer_cfg.weight_decay}, {"params": nodecay_params, "weight_decay": 0.0}, ] num_decay_params = sum(p.numel() for p in decay_params) @@ -37,7 +77,10 @@ def configure_nanoGPT_optimizer(model, weight_decay, learning_rate, betas): use_fused = fused_available extra_args = {"fused": True} if use_fused else {} optimizer = torch.optim.AdamW( - optim_groups, lr=learning_rate, betas=betas, **extra_args + optim_groups, + lr=optimizer_cfg.lr, + betas=(optimizer_cfg.beta1, optimizer_cfg.beta2), + **extra_args, ) print(f"using fused AdamW: {use_fused}") diff --git a/trainers/prepare.py b/trainers/prepare.py index 9f1df57a..fe33c602 100644 --- a/trainers/prepare.py +++ b/trainers/prepare.py @@ -1,26 +1,33 @@ """ Necessary to be run before training to make sure all of the data is preprcessed etc. """ -import os -import torch -import numpy as np -from tqdm import tqdm + +import os + +import numpy as np +from tqdm import tqdm + +from models.build_models import build_embedding_model +from models.embedding_models import EmbedderInterface from trainers.utils import load_data -from models.build_models import build_embedding_model +NUM_PROCESSING_BATCHES = 1024 +TokenizedDataType = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) class StandardProcessor: """ A standard processor that tokenizes the text """ - def __init__(self, embedder): + + def __init__(self, embedder: EmbedderInterface): self.embedder = embedder def process(self, example): + """Given a sample, apply the tokenization""" ids = self.embedder.tokenize_input(example["text"]) return {"ids": ids, "len": len(ids)} - + def write_tokenized_data(self, tokenized, tokenized_data_folder): """ Write the tokenized data to a file @@ -28,52 +35,46 @@ def write_tokenized_data(self, tokenized, tokenized_data_folder): for split, dset in tokenized.items(): arr_len = np.sum(dset["len"], dtype=np.uint64) filename = os.path.join(tokenized_data_folder, f"{split}.bin") - dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) arr = np.memmap( - filename, - dtype=dtype, - mode="w+", - shape=(arr_len,) + filename, dtype=TokenizedDataType, mode="w+", shape=(arr_len,) ) - total_batches = 1024 - idx = 0 - for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + for batch_idx in tqdm( + range(NUM_PROCESSING_BATCHES), desc=f"writing {filename}" + ): # Batch together samples for faster write batch = dset.shard( - num_shards=total_batches, index=batch_idx, contiguous=True + num_shards=NUM_PROCESSING_BATCHES, index=batch_idx, contiguous=True ).with_format("numpy") arr_batch = np.concatenate(batch["ids"]) # Write into mmap arr[idx : idx + len(arr_batch)] = arr_batch idx += len(arr_batch) arr.flush() - + + class ByteLevelProcessor(StandardProcessor): """ A byte-level processor that tokenizes the text """ - def __init__(self, embedder): - super().__init__(embedder) def write_tokenized_data(self, tokenized, tokenized_data_folder): for split, dset in tokenized.items(): arr_len = np.sum(dset["len"], dtype=np.uint64) filename = os.path.join(tokenized_data_folder, f"{split}.bin") - dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) arr = np.memmap( filename, - dtype=dtype, + dtype=TokenizedDataType, mode="w+", - shape=(arr_len, 12), #TODO remove hardcoding + shape=(arr_len, 12), # TODO remove hardcoding ) - total_batches = 1024 - idx = 0 - for batch_idx in tqdm(range(total_batches), desc=f"writing {filename}"): + for batch_idx in tqdm( + range(NUM_PROCESSING_BATCHES), desc=f"writing {filename}" + ): # Batch together samples for faster write batch = dset.shard( - num_shards=total_batches, index=batch_idx, contiguous=True + num_shards=NUM_PROCESSING_BATCHES, index=batch_idx, contiguous=True ).with_format("numpy") arr_batch = np.concatenate(batch["ids"]) # Write into mmap @@ -81,49 +82,47 @@ def write_tokenized_data(self, tokenized, tokenized_data_folder): idx += len(arr_batch) arr.flush() + class DualByteLevelProcessor(StandardProcessor): """ - This preprocessor stores both the byte level structure and + This preprocessor stores both the byte level structure and the standard structure to enable the training of architectures with byte-level input, but standard token output. """ - def __init__(self, embedder): - super().__init__(embedder) def process(self, example): - byte_ids, token_ids = self.embedder.tokenize_input(example["text"], return_high_level=True) + byte_ids, token_ids = self.embedder.tokenize_input( + example["text"], return_high_level=True + ) return {"byte_ids": byte_ids, "token_ids": token_ids, "len": len(token_ids)} - + def write_tokenized_data(self, tokenized, tokenized_data_folder): for split, dset in tokenized.items(): arr_len = np.sum(dset["len"], dtype=np.uint64) filename_byte = os.path.join(tokenized_data_folder, f"{split}_byte.bin") filename_token = os.path.join(tokenized_data_folder, f"{split}_token.bin") - - dtype = np.uint16 # (can do since enc.max_token_value == 50256 is < 2**16) - arr_byte = np.memmap( filename_byte, - dtype=dtype, + dtype=TokenizedDataType, mode="w+", - shape=(arr_len, 12), #TODO remove hardcoding + shape=(arr_len, 12), # TODO remove hardcoding ) arr_token = np.memmap( filename_token, - dtype=dtype, + dtype=TokenizedDataType, mode="w+", shape=(arr_len,), ) - - total_batches = 1024 - idx = 0 - for batch_idx in tqdm(range(total_batches), desc=f"writing {filename_byte} and {filename_token}"): + for batch_idx in tqdm( + range(NUM_PROCESSING_BATCHES), + desc=f"writing {filename_byte} and {filename_token}", + ): # Batch together samples for faster write batch = dset.shard( - num_shards=total_batches, index=batch_idx, contiguous=True + num_shards=NUM_PROCESSING_BATCHES, index=batch_idx, contiguous=True ).with_format("numpy") arr_batch_byte = np.concatenate(batch["byte_ids"]) arr_batch_token = np.concatenate(batch["token_ids"]) @@ -137,20 +136,16 @@ def write_tokenized_data(self, tokenized, tokenized_data_folder): arr_token.flush() - DATALOADER_PROCESSORS = { "standard": StandardProcessor, "byte_pooling": ByteLevelProcessor, - "dual_byte_pooling": DualByteLevelProcessor + "dual_byte_pooling": DualByteLevelProcessor, } - - - def prepare_data(cfg): """ - Split the data, process & tokenize it, and store + Split the data, process & tokenize it, and store it as memmap bin files """ # check if the data is already preprocessed @@ -159,20 +154,25 @@ def prepare_data(cfg): tokenized_data_folder = os.path.join( cfg["general"]["paths"]["data_dir"], dataset_name, - f'{cfg["model"]["embedder"]["tokenizer_type"]}-{cfg["model"]["vocab_size"]}-{cfg["trainer"]["dataloader"]["name"]}', + ( + f'{cfg["model"]["embedder"]["tokenizer_type"]}-{cfg["model"]["vocab_size"]}' + f'-{cfg["trainer"]["dataloader"]["name"]}' + ), ) # check if already exists (check len because some datasets use differen filenames # (i.e. dual byte level) - if os.path.exists(tokenized_data_folder) and len(os.listdir(tokenized_data_folder))!=0: + if ( + os.path.exists(tokenized_data_folder) + and len(os.listdir(tokenized_data_folder)) != 0 + ): print("Tokenized data already exists") return else: - # create the folder if it doesn't exist + # create the folder if it doesn't exist if not os.path.exists(tokenized_data_folder): os.makedirs(tokenized_data_folder) - # load embedder embedder = build_embedding_model(cfg["model"]) @@ -181,16 +181,14 @@ def prepare_data(cfg): dataset_name=dataset_name, ) - processor_object = DATALOADER_PROCESSORS[dataloader_name]( - embedder=embedder - ) + processor_object = DATALOADER_PROCESSORS[dataloader_name](embedder=embedder) # wrap in try such that half-complete files can be deleted on error try: # Get the maximum number of processors max_procs = os.cpu_count() # cap at 12 to reduce memory usage - max_procs = 1 #min(max_procs, 12) # TODO properly fix this + max_procs = 1 # min(max_procs, 12) # TODO properly fix this print(f"Using {max_procs} processors") # tokenize the dataset @@ -198,13 +196,12 @@ def prepare_data(cfg): processor_object.process, remove_columns=["text"], desc="Tokenizing dataset", - num_proc=max_procs + num_proc=max_procs, ) # concatenate all the ids in each dataset processor_object.write_tokenized_data( - tokenized=tokenized, - tokenized_data_folder=tokenized_data_folder + tokenized=tokenized, tokenized_data_folder=tokenized_data_folder ) except Exception as exc: @@ -212,5 +209,3 @@ def prepare_data(cfg): for file in os.listdir(tokenized_data_folder): os.remove(os.path.join(tokenized_data_folder, file)) raise RuntimeError("Failed to process and write data") from exc - - diff --git a/trainers/samplers.py b/trainers/samplers.py index a25c39e1..c094025c 100644 --- a/trainers/samplers.py +++ b/trainers/samplers.py @@ -1,8 +1,21 @@ """ A collection of different datasamplers. """ -import torch -from typing import Iterator, Optional, Sized + +from typing import Iterator + +import pydantic +import torch +import torch.utils.data +from pydantic import PositiveInt + + +class SamplerConfig(pydantic.BaseModel): + """Config for building the sampler + The data source should be""" + + data_source: torch.utils.data.Dataset + batch_size: PositiveInt class BaseSampler(torch.utils.data.Sampler[int]): @@ -27,13 +40,14 @@ def __iter__(self) -> Iterator[int]: Get a batch worth of random indicies """ # Generate random indices each time __iter__ is called - return iter(torch.randint( - high=self.num_samples, - size=(self.batch_size,), - dtype=torch.int64, - generator=self.generator).tolist() + return iter( + torch.randint( + high=self.num_samples, + size=(self.batch_size,), + dtype=torch.int64, + generator=self.generator, + ).tolist() ) def __len__(self) -> int: return self.num_samples - diff --git a/trainers/scheduler.py b/trainers/scheduler.py deleted file mode 100644 index 431cfca2..00000000 --- a/trainers/scheduler.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Various Scheduler""" - -import math - -import torch.nn as nn - - -class LRScheduler: - """Constant LR scheduler""" - - def __init__(self, lr): - self.lr = lr - - def get_lr(self, _): - """Return Constant LR""" - return self.lr - - def step(self, optimizer, iter_num): - """Step the scheduler""" - lr = self.get_lr(iter_num) - self.apply_lr(optimizer, lr) - return lr - - def apply_lr(self, optimizer, lr): - """Apply the learning rate to the optimizer""" - for param_group in optimizer.param_groups: - param_group["lr"] = lr - - -class CosineLRScheduler(LRScheduler): - """Basic Cosine LR scheduler with warmup and decay.""" - - def __init__(self, warmup_iters, decay_iters, lr, min_lr): - """Initialize the scheduler""" - super().__init__(lr) - self.warmup_iters = warmup_iters - self.decay_iters = decay_iters - self.lr = lr - self.min_lr = min_lr - - def get_lr(self, iter_num): - """Get the learning rate for the iteration number""" - if iter_num < self.warmup_iters: - return self.lr * iter_num / self.warmup_iters - return self.min_lr + 0.5 * (self.lr - self.min_lr) * ( - 1 + math.cos((iter_num - self.warmup_iters) / self.decay_iters * math.pi) - ) - - -class DropoutScheduler: - """Constant Dropout Scheduler""" - - def __init__(self, dropout_p=0.1): - self.dropout_p = dropout_p - - def get_dropout(self, _): - """Return Constant Dropout""" - return self.dropout_p - - def set_dropout(self, model, dropout_p): - """Set the dropout probability for the model""" - for module in model.modules(): - if isinstance(module, nn.Dropout): - module.p = dropout_p - - def step(self, model, iter_num): - """Step the scheduler""" - dropout_p = self.get_dropout(iter_num) - self.set_dropout(model, dropout_p) - return dropout_p - - -class LinearDropoutScheduler(DropoutScheduler): - """Dropout Scheduler""" - - def __init__(self, start_iter, end_iter, start_dropout_p, end_dropout_p): - """Initialize the dropout schedule""" - super().__init__(start_dropout_p) - self.start_iter = start_iter - self.end_iter = end_iter - self.start_dropout_p = start_dropout_p - self.end_dropout_p = end_dropout_p - - def get_dropout(self, iter_num): - """Return Constant Dropout""" - if iter_num < self.start_iter: - return self.start_dropout_p - if iter_num >= self.end_iter: - return self.end_dropout_p - return self.start_dropout_p + (iter_num - self.start_iter) * ( - self.end_dropout_p - self.start_dropout_p - ) / (self.end_iter - self.start_iter) - - -class TriangleDropoutScheduler(DropoutScheduler): - """Triangle Dropout Scheduler. Ref: https://arxiv.org/pdf/1506.01186""" - - def __init__( - self, - dropout_trough, - dropout_peak, - num_iterations, - num_cycles=4, - ): - """Initialize the dropout schedule - Args: - dropout_trough: The minimum dropout probability - dropout_peak: The maximum dropout probability - num_iterations: The total number of iterations - num_cycles: The number of cycles""" - super().__init__(dropout_trough) - self.dropout_trough = dropout_trough - self.dropout_peak = dropout_peak - self.total_iterations = num_iterations - self.cycle_length = self.total_iterations // num_cycles - - def get_dropout(self, iter_num): - cycle_position = iter_num % self.cycle_length - half_cycle = self.cycle_length / 2 - if cycle_position < half_cycle: - return self.dropout_trough + (self.dropout_peak - self.dropout_trough) * ( - cycle_position / half_cycle - ) - return self.dropout_peak - (self.dropout_peak - self.dropout_trough) * ( - (cycle_position - half_cycle) / half_cycle - ) - - def step(self, model, iter_num): - dropout_p = self.get_dropout(iter_num) - self.set_dropout(model, dropout_p) - return dropout_p diff --git a/trainers/schedulers.py b/trainers/schedulers.py new file mode 100644 index 00000000..7552432a --- /dev/null +++ b/trainers/schedulers.py @@ -0,0 +1,196 @@ +"""Various Schedulers for Learning Rate and Dropout""" + +import enum +import math + +import pydantic +import torch.nn as nn +from pydantic import NonNegativeInt, PositiveFloat +from typing_extensions import Annotated + + +class LRSchedulerNames(str, enum.Enum): + """Possible types of LR Scheduler""" + + CONSTANT = "constant" + COSINE = "cosine" + + +class LRSchedulerConfig(pydantic.BaseModel): + """Learning Rate Scheduler Configuration""" + + lr_scheduler_type: LRSchedulerNames.CONSTANT + lr: PositiveFloat + + +class LRScheduler: + """Constant LR scheduler""" + + def __init__(self, lr_scheduler_cfg: LRSchedulerConfig): + self.lr = lr_scheduler_cfg.lr + + def get_lr(self, _): + """Return Constant LR""" + return self.lr + + def step(self, optimizer, iter_num): + """Step the scheduler""" + lr = self.get_lr(iter_num) + self.apply_lr(optimizer, lr) + return lr + + def apply_lr(self, optimizer, lr): + """Apply the learning rate to the optimizer""" + for param_group in optimizer.param_groups: + param_group["lr"] = lr + + +class CosineLRSchedulerConfig(LRSchedulerConfig): + """Cosine LR Scheduler Configuration""" + + lr_scheduler_type: LRSchedulerNames.COSINE + warmup_iters: NonNegativeInt + decay_iters: NonNegativeInt + lr: PositiveFloat + min_lr: PositiveFloat + + +class CosineLRScheduler(LRScheduler): + """Basic Cosine LR scheduler with warmup and decay.""" + + def __init__(self, lr_scheduler_cfg: CosineLRSchedulerConfig): + """Initialize the scheduler""" + super().__init__(lr_scheduler_cfg) + self.warmup_iters = lr_scheduler_cfg.warmup_iters + self.decay_iters = lr_scheduler_cfg.decay_iters + self.lr = lr_scheduler_cfg.lr + self.min_lr = lr_scheduler_cfg.min_lr + + def get_lr(self, iter_num): + """Get the learning rate for the iteration number""" + if iter_num < self.warmup_iters: + return self.lr * iter_num / self.warmup_iters + return self.min_lr + 0.5 * (self.lr - self.min_lr) * ( + 1 + math.cos((iter_num - self.warmup_iters) / self.decay_iters * math.pi) + ) + + +class DropoutSchedulerNames(str, enum.Enum): + """Possible types of dropout scheduler. See indidivual + Types for documentation""" + + CONSTANT = "constant" + LINEAR = "linear" + TRIANGLE = "triangle" + + +ProbabilityFloat = Annotated[float, pydantic.Field(ge=0, lt=1)] + + +class DropoutSchedulerConfig(pydantic.BaseModel): + """Dropout Scheduler Configuration""" + + dropout_type: DropoutSchedulerNames + dropout: ProbabilityFloat + + +class DropoutScheduler: + """Constant Dropout Scheduler""" + + def __init__(self, dropout_cfg: DropoutSchedulerConfig): + self.dropout_p = dropout_cfg.dropout + + def get_dropout(self, _): + """Return Constant Dropout""" + return self.dropout_p + + def set_dropout(self, model, dropout_p): + """Set the dropout probability for the model""" + for module in model.modules(): + if isinstance(module, nn.Dropout): + module.p = dropout_p + + def step(self, model, iter_num): + """Step the scheduler""" + dropout_p = self.get_dropout(iter_num) + self.set_dropout(model, dropout_p) + return dropout_p + + +class LinearDropoutSchedulerConfig(DropoutSchedulerConfig): + """Linear Dropout Scheduler Configuration + Linearly moves between start_dropout_p and end_dropout_p between + start_iter and end_iter. Has the value of start_dropout_p before + and end_dropout_p after""" + + dropout_type: DropoutSchedulerNames.LINEAR + start_iter: NonNegativeInt + end_iter: NonNegativeInt + start_dropout_p: ProbabilityFloat + end_dropout_p: ProbabilityFloat + + +class LinearDropoutScheduler(DropoutScheduler): + """Dropout Scheduler""" + + def __init__(self, dropout_cfg: LinearDropoutSchedulerConfig): + """Initialize the dropout schedule""" + super().__init__(dropout_cfg) + self.start_iter = dropout_cfg.start_iter + self.end_iter = dropout_cfg.end_iter + self.start_dropout_p = dropout_cfg.start_dropout_p + self.end_dropout_p = dropout_cfg.end_dropout_p + + def get_dropout(self, iter_num): + """Return Constant Dropout""" + if iter_num < self.start_iter: + return self.start_dropout_p + if iter_num >= self.end_iter: + return self.end_dropout_p + return self.start_dropout_p + (iter_num - self.start_iter) * ( + self.end_dropout_p - self.start_dropout_p + ) / (self.end_iter - self.start_iter) + + +class TriangleDropoutSchedulerConfig(DropoutSchedulerConfig): + """ + Args: + dropout_trough: The minimum dropout probability + dropout_peak: The maximum dropout probability + num_iterations: The total number of iterations + num_cycles: The number of cycles + """ + + dropout_type: DropoutSchedulerNames.TRIANGLE + dropout_trough: ProbabilityFloat + dropout_peak: ProbabilityFloat + num_iterations: NonNegativeInt = 30000 + num_cycles: NonNegativeInt = 3 + + +class TriangleDropoutScheduler(DropoutScheduler): + """Triangle Dropout Scheduler. Ref: https://arxiv.org/pdf/1506.01186""" + + def __init__(self, dropout_cfg: TriangleDropoutSchedulerConfig): + """Initialize the dropout schedule""" + super().__init__(dropout_cfg) + self.dropout_trough = dropout_cfg.dropout_trough + self.dropout_peak = dropout_cfg.dropout_peak + self.total_iterations = dropout_cfg.num_iterations + self.cycle_length = self.total_iterations // dropout_cfg.num_cycles + + def get_dropout(self, iter_num): + cycle_position = iter_num % self.cycle_length + half_cycle = self.cycle_length / 2 + if cycle_position < half_cycle: + return self.dropout_trough + (self.dropout_peak - self.dropout_trough) * ( + cycle_position / half_cycle + ) + return self.dropout_peak - (self.dropout_peak - self.dropout_trough) * ( + (cycle_position - half_cycle) / half_cycle + ) + + def step(self, model, iter_num): + dropout_p = self.get_dropout(iter_num) + self.set_dropout(model, dropout_p) + return dropout_p diff --git a/trainers/utils.py b/trainers/utils.py index e9983802..3e435c5b 100644 --- a/trainers/utils.py +++ b/trainers/utils.py @@ -1,16 +1,17 @@ """Utilities for the trainer""" +import enum import importlib -from prettytable import PrettyTable import inspect import os import pkgutil import numpy as np import torch -from datasets import load_dataset, DatasetDict, concatenate_datasets - import torch.distributed as dist +from datasets import DatasetDict, concatenate_datasets, load_dataset +from prettytable import PrettyTable + def set_seed(seed): """Setup the trainer""" @@ -29,6 +30,7 @@ def create_folder_structure(path_config): if not os.path.exists(path_config["checkpoint_dir"]): os.makedirs(path_config["checkpoint_dir"]) + def create_stlm_data_mix(): """ A small custom datamix for STLM models containing: @@ -44,41 +46,59 @@ def create_stlm_data_mix(): # Load Python code from DeepMind Code Contests code_dataset = load_dataset("jtatman/python-code-dataset-500k")["train"] - code_dataset = code_dataset.map(lambda x: {"text": f"Instruction: {x['instruction']}\nOutput: {x['output']}"}) - + code_dataset = code_dataset.map( + lambda x: {"text": f"Instruction: {x['instruction']}\nOutput: {x['output']}"} + ) # Load technical QA style data from StackExchange openhermes = load_dataset("teknium/OpenHermes-2.5")["train"] # Transform to have a "text" column with both question and answers - openhermes = openhermes.map(lambda x: {"text": f"Question: {x['conversations'][0]['value']}\nAnswers: {x['conversations'][1]['value']}"}) + openhermes = openhermes.map( + lambda x: { + "text": ( + f"Question: {x['conversations'][0]['value']}" + f"\nAnswers: {x['conversations'][1]['value']}" + ) + } + ) # Add tiny stories tiny_stories = load_dataset("roneneldan/TinyStories")["train"] - # Calculate and print the distribution of string lengths def calculate_length_distribution(dataset): lengths = [len(item["text"]) for item in dataset] return sum(lengths), lengths - wiki_length, wiki_lengths = calculate_length_distribution(wiki) - python3_code_length, python3_code_lengths = calculate_length_distribution(code_dataset) - openhermes_length, openhermes_lengths = calculate_length_distribution(openhermes) - tiny_stories_length, tiny_stories_lengths = calculate_length_distribution(tiny_stories) + wiki_length, _ = calculate_length_distribution(wiki) + python3_code_length, _ = calculate_length_distribution(code_dataset) + openhermes_length, _ = calculate_length_distribution(openhermes) + tiny_stories_length, _ = calculate_length_distribution(tiny_stories) - total_length = wiki_length + python3_code_length + openhermes_length + tiny_stories_length + total_length = ( + wiki_length + python3_code_length + openhermes_length + tiny_stories_length + ) print(f"Wiki Text Length: {wiki_length} ({wiki_length/total_length*100:.2f}%)") - print(f"Python Code Text Length: {python3_code_length} ({python3_code_length/total_length*100:.2f}%)") - print(f"openhermes Text Length: {openhermes_length} ({openhermes_length/total_length*100:.2f}%)") + print( + f"Python Code Text Length: {python3_code_length}" + f" ({python3_code_length/total_length*100:.2f}%)" + ) + print( + f"openhermes Text Length: {openhermes_length} ({openhermes_length/total_length*100:.2f}%)" + ) # Concatenate datasets - combined_dataset = concatenate_datasets([wiki, code_dataset, openhermes, tiny_stories]) + combined_dataset = concatenate_datasets( + [wiki, code_dataset, openhermes, tiny_stories] + ) - combined_dataset = DatasetDict({ - "train": combined_dataset, - }) + combined_dataset = DatasetDict( + { + "train": combined_dataset, + } + ) return combined_dataset @@ -88,50 +108,72 @@ def load_github_code_dataset(): load and re-format the github code dataset https://huggingface.co/datasets/codeparrot/github-code """ - dataset = load_dataset("codeparrot/github-code") + dataset = load_dataset("codeparrot/github-code") # rename "code" column to "text" column dataset = dataset.map(lambda x: {"text": x["code"]})["train"] - #dataset = DatasetDict({ + # dataset = DatasetDict({ # "train": dataset, - #}) - + # }) return dataset + def load_competition_math_dataset(): """ load and re-format the competition math dataset https://huggingface.co/datasets/hendrycks/competition_math """ - dataset = load_dataset("hendrycks/competition_math") + dataset = load_dataset("hendrycks/competition_math") # format the problem and solution into a single "text" column - dataset = dataset.map(lambda x: {"text": f"Problem: {x['problem']}\nSolution: {x['solution']}"}) + dataset = dataset.map( + lambda x: {"text": f"Problem: {x['problem']}\nSolution: {x['solution']}"} + ) - dataset = DatasetDict({ - "train": dataset, - }) + dataset = DatasetDict( + { + "train": dataset, + } + ) return dataset - DATASET_DICT = { "debug": lambda: load_dataset("wikimedia/wikipedia", "20231101.simple"), "en_wiki": lambda: load_dataset("wikimedia/wikipedia", "20231101.en"), "simple_en_wiki": lambda: load_dataset("wikimedia/wikipedia", "20231101.simple"), - "babylm_100m": lambda: load_dataset("Sree1994/babylm_100M"), # https://babylm.github.io/ - "tinystories": lambda: load_dataset("roneneldan/TinyStories"), # https://huggingface.co/datasets/roneneldan/TinyStories + "babylm_100m": lambda: load_dataset( + "Sree1994/babylm_100M" + ), # https://babylm.github.io/ + "tinystories": lambda: load_dataset( + "roneneldan/TinyStories" + ), # https://huggingface.co/datasets/roneneldan/TinyStories "stlm": create_stlm_data_mix, "openhermes-2.5": lambda: load_dataset("teknium/OpenHermes-2.5"), "openwebtext": lambda: load_dataset("Skylion007/openwebtext"), - "github-code": lambda: load_github_code_dataset(), - "competition_math": lambda: load_competition_math_dataset(), + "github-code": load_github_code_dataset, + "competition_math": load_competition_math_dataset, } +class DatasetEnum(str, enum.Enum): + """All the possible dataset mixes we support""" + + DEBUG = "debug" + EN_WIKI = "en_wiki" + SIMPLE_EN_WIKI = "simple_en_wiki" + BABYLM_100M = "babylm_100m" + TINYSTORIES = "tinystories" + STLM = "stlm" + OPENHERMES_2_5 = "openhermes-2.5" + OPENWEBTEXT = "openwebtext" + GITHUB_CODE = "github-code" + COMPETITION_MATH = "competition_math" + + def load_data(dataset_name, shuffle=True): """Load the data""" assert dataset_name in DATASET_DICT, f"Dataset {dataset_name} not found!" @@ -244,16 +286,18 @@ def forward_wrapper(*args, **kwargs): model.forward = forward_wrapper + def is_dist(): """ Check if the current process is distributed. """ return dist.is_initialized() -def aggregate_value(value, device = torch.device("cuda")): + +def aggregate_value(value, device=torch.device("cuda")): """ - Since using DDP, calculation of metrics happen across all GPUs. - This function aggregate the loss across all GPUs. + Since using DDP, calculation of metrics happen across all GPUs. + This function aggregate the loss across all GPUs. """ if not is_dist(): return value @@ -262,36 +306,44 @@ def aggregate_value(value, device = torch.device("cuda")): return all_loss.item() / dist.get_world_size() # return value + def init_print_override(): - ''' - Overriding the print function is useful when running DDP. + """ + Overriding the print function is useful when running DDP. This way, only rank 0 prints to the console. - ''' + """ + # pylint: disable=redefined-builtin + # this is literally the point of this function lol + # pylint: disable=import-outside-toplevel import builtins as __builtin__ - + + # pylint: disable=import-outside-toplevel original_print = __builtin__.print def print(*args, **kwargs): - if os.getenv('GLOBAL_RANK') == '0': + if os.getenv("GLOBAL_RANK") == "0": original_print(*args, **kwargs) __builtin__.print = print - + # pylint: enable=redefined-builtin return original_print + def restore_print_override(original_print): - ''' + """ Restore the original print function. - ''' + """ + # pylint: disable=import-outside-toplevel import builtins as __builtin__ - __builtin__.print = original_print + # pylint: enable=import-outside-toplevel + __builtin__.print = original_print -# Function to print evaluation results and benchmark results def print_evaluation_results(iter_num, eval_results, benchmark_results): - headers = ['Metric', 'Value'] + """Function to print evaluation results and benchmark results""" + headers = ["Metric", "Value"] table = PrettyTable(headers) # Adding eval_results rows @@ -302,20 +354,21 @@ def print_evaluation_results(iter_num, eval_results, benchmark_results): print(f"Iteration {iter_num}") print(table) - - benchmark_table = PrettyTable(['Benchmark', 'Accuracy', "Path Conf.", "Ground Conf."]) + benchmark_table = PrettyTable( + ["Benchmark", "Accuracy", "Path Conf.", "Ground Conf."] + ) for eval_method in benchmark_results.keys(): if eval_method == "ft_qa": continue for benchmark, value in benchmark_results[eval_method].items(): - benchmark_table.add_row([ - f"{benchmark}", - value['accuracy'], - value['path_confidence'], - value['ground_confidence'] - ]) + benchmark_table.add_row( + [ + f"{benchmark}", + value["accuracy"], + value["path_confidence"], + value["ground_confidence"], + ] + ) print("Benchmark Results") print(benchmark_table) - - From 48d84759e8e348901e7c9b01ce1618486817f4a2 Mon Sep 17 00:00:00 2001 From: Dylan Hillier Date: Mon, 19 Aug 2024 18:49:17 +0800 Subject: [PATCH 3/3] updates sth --- configs/full_configs/baseline.yaml | 25 +++++++++++++------------ models/components/layers/attention.py | 2 +- 2 files changed, 14 insertions(+), 13 deletions(-) diff --git a/configs/full_configs/baseline.yaml b/configs/full_configs/baseline.yaml index 85b5e908..e7ba43e6 100644 --- a/configs/full_configs/baseline.yaml +++ b/configs/full_configs/baseline.yaml @@ -32,29 +32,30 @@ trainer: dropout_scheduler: dropout_type: constant dropout: 0.1 - dataset: openwebtext training: + dataset: openwebtext trainer_type: base_trainer batch_size: 24 gradient_accumulation_steps: 20 max_iters: 30000 lr_decay_iters: 30000 warmup_iters: 5000 - eval_interval: 2000 log_interval: 10 - eval_iters: 500 checkpoint_interval: 1000000000.0 run_profiler: false eval: - - benchmarks: - - "winograd" - - "hellaswag" - - "arc" - - "mmlu" - - "blimp" - num_samples: 1000 - evaluator: "mcq" - - evaluator: "prog" + eval_iters: 500 + eval_interval: 2000 + evaluators: + - benchmarks: + - "winograd" + - "hellaswag" + - "arc" + - "mmlu" + - "blimp" + num_samples: 1000 + evaluator: "mcq" + - evaluator: "prog" optimizer: name: nanoGPTadamW lr: 0.0006 diff --git a/models/components/layers/attention.py b/models/components/layers/attention.py index bc48a16e..0a616a89 100644 --- a/models/components/layers/attention.py +++ b/models/components/layers/attention.py @@ -15,7 +15,7 @@ class AttentionConfig(pydantic.BaseModel): Attention configuration """ - attn_type = "generic" + attn_type: str = "generic" num_heads: int bias: bool use_rope: bool