diff --git a/submissions/self_tuning/lion/__init__.py b/submissions/self_tuning/lion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/submissions/self_tuning/lion/submission.py b/submissions/self_tuning/lion/submission.py new file mode 100644 index 00000000..e86b4f6b --- /dev/null +++ b/submissions/self_tuning/lion/submission.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +import collections +from typing import Any, Dict, Iterator, List, Optional, Tuple + +import torch +import torch.distributed.nn as dist_nn +from absl import logging +from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR +from torch.optim.optimizer import Optimizer + +from algoperf import spec +from algoperf.pytorch_utils import pytorch_setup + +USE_PYTORCH_DDP = pytorch_setup()[0] + +# optimal parameters across all workloads, excluding ogbg +HPARAMS = { + 'dropout_rate': 0.1, + 'learning_rate': 2e-4, + 'one_minus_beta1': 0.05, + 'beta2': 0.98, + 'weight_decay': 0.5, + 'warmup_factor': 0.02, +} +HPARAMS = collections.namedtuple('Hyperparameters', HPARAMS.keys())(**HPARAMS) + + +# Modified from https://github.com/google/automl/blob/master/lion/lion_pytorch.py. +class Lion(Optimizer): + def __init__( + self, + params, + lr: float = 1e-4, + betas: Tuple[float, float] = (0.9, 0.99), + weight_decay: float = 0.0, + ): + if not 0.0 <= lr: + raise ValueError('Invalid learning rate: {}'.format(lr)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError('Invalid beta parameter at index 0: {}'.format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError('Invalid beta parameter at index 1: {}'.format(betas[1])) + defaults = dict(lr=lr, betas=betas, weight_decay=weight_decay) + super().__init__(params, defaults) + + @torch.no_grad() + def step(self, closure=None): + """Performs a single optimization step. + + Args: + closure (callable, optional): A closure that reevaluates the model + and returns the loss. + + Returns: + the loss. + """ + loss = None + if closure is not None: + with torch.enable_grad(): + loss = closure() + + for group in self.param_groups: + for p in group['params']: + if p.grad is None: + continue + + # Perform stepweight decay + p.data.mul_(1 - group['lr'] * group['weight_decay']) + + grad = p.grad + state = self.state[p] + # State initialization + if len(state) == 0: + # Exponential moving average of gradient values + state['exp_avg'] = torch.zeros_like(p) + + exp_avg = state['exp_avg'] + beta1, beta2 = group['betas'] + + # Weight update + update = exp_avg * beta1 + grad * (1 - beta1) + + p.add_(update.sign_(), alpha=-group['lr']) + + # Decay the momentum running average coefficient + exp_avg.mul_(beta2).add_(grad, alpha=1 - beta2) + + return loss + + +def init_optimizer_state( + workload: spec.Workload, + model_params: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + rng: spec.RandomState, +) -> spec.OptimizerState: + """Creates a Lion optimizer and a learning rate schedule.""" + del model_state + del rng + del hyperparameters + + hyperparameters = HPARAMS + + optimizer_state = { + 'optimizer': Lion( + model_params.parameters(), + lr=HPARAMS.learning_rate, + betas=(1.0 - HPARAMS.one_minus_beta1, HPARAMS.beta2), + weight_decay=HPARAMS.weight_decay, + ) + } + + def pytorch_cosine_warmup(step_hint: int, hyperparameters, optimizer): + warmup_steps = int(hyperparameters.warmup_factor * step_hint) + warmup = LinearLR( + optimizer, start_factor=1e-10, end_factor=1.0, total_iters=warmup_steps + ) + cosine_steps = max(step_hint - warmup_steps, 1) + cosine_decay = CosineAnnealingLR(optimizer, T_max=cosine_steps) + return SequentialLR( + optimizer, schedulers=[warmup, cosine_decay], milestones=[warmup_steps] + ) + + optimizer_state['scheduler'] = pytorch_cosine_warmup( + workload.step_hint, HPARAMS, optimizer_state['optimizer'] + ) + optimizer_state['hyperparameters'] = hyperparameters + + return optimizer_state + + +def update_params( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + batch: Dict[str, spec.Tensor], + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, + train_state: Optional[Dict[str, Any]] = None, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params, updated_model_state).""" + del current_params_types + del loss_type + del train_state + del eval_results + del hyperparameters + + hyperparameters = HPARAMS + + current_model = current_param_container + current_model.train() + optimizer_state['optimizer'].zero_grad() + + logits_batch, new_model_state = workload.model_fn( + params=current_model, + augmented_and_preprocessed_input_batch=batch, + model_state=model_state, + mode=spec.ForwardPassMode.TRAIN, + rng=rng, + update_batch_norm=True, + ) + + label_smoothing = ( + hyperparameters.label_smoothing + if hasattr(HPARAMS, 'label_smoothing') + else 0.0 + ) + if hasattr(hyperparameters, 'grad_clip'): + grad_clip = hyperparameters.grad_clip + else: + grad_clip = None + + loss_dict = workload.loss_fn( + label_batch=batch['targets'], + logits_batch=logits_batch, + mask_batch=batch.get('weights'), + label_smoothing=label_smoothing, + ) + summed_loss = loss_dict['summed'] + n_valid_examples = loss_dict['n_valid_examples'] + if USE_PYTORCH_DDP: + # Use dist_nn.all_reduce to ensure correct loss and gradient scaling. + summed_loss = dist_nn.all_reduce(summed_loss) + n_valid_examples = dist_nn.all_reduce(n_valid_examples) + loss = summed_loss / n_valid_examples + + loss.backward() + + if grad_clip is not None: + torch.nn.utils.clip_grad_norm_( + current_model.parameters(), max_norm=grad_clip + ) + optimizer_state['optimizer'].step() + optimizer_state['scheduler'].step() + + # Log training metrics - loss, grad_norm, batch_size. + if global_step <= 100 or global_step % 500 == 0: + with torch.no_grad(): + parameters = [p for p in current_model.parameters() if p.grad is not None] + grad_norm = torch.norm( + torch.stack([torch.norm(p.grad.detach(), 2) for p in parameters]), 2 + ) + if workload.metrics_logger is not None: + workload.metrics_logger.append_scalar_metrics( + { + 'loss': loss.item(), + 'grad_norm': grad_norm.item(), + }, + global_step, + ) + logging.info( + '%d) loss = %0.3f, grad_norm = %0.3f', + global_step, + loss.item(), + grad_norm.item(), + ) + + return (optimizer_state, current_param_container, new_model_state) + + +def prepare_for_eval( + workload: spec.Workload, + current_param_container: spec.ParameterContainer, + current_params_types: spec.ParameterTypeTree, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + loss_type: spec.LossType, + optimizer_state: spec.OptimizerState, + eval_results: List[Tuple[int, float]], + global_step: int, + rng: spec.RandomState, +) -> spec.UpdateReturn: + """Return (updated_optimizer_state, updated_params).""" + del workload + del hyperparameters + del current_params_types + del loss_type + del eval_results + del global_step + del rng + return (optimizer_state, current_param_container, model_state) + + +def get_batch_size(workload_name): + # Return the global batch size. + if hasattr(HPARAMS, 'batch_size'): + return HPARAMS.batch_size + if workload_name == 'criteo1tb': + return 262_144 + elif workload_name == 'fastmri': + return 32 + elif workload_name == 'imagenet_resnet': + return 1024 + elif workload_name == 'imagenet_resnet_silu': + return 512 + elif workload_name == 'imagenet_resnet_gelu': + return 512 + elif workload_name == 'imagenet_vit': + return 1024 + elif workload_name == 'librispeech_conformer': + return 256 + elif workload_name == 'librispeech_deepspeech': + return 256 + elif workload_name == 'ogbg': + return 512 + elif workload_name == 'wmt': + return 128 + elif workload_name == 'mnist': + return 16 + else: + raise ValueError(f'Unsupported workload name: {workload_name}.') + + +def data_selection( + workload: spec.Workload, + input_queue: Iterator[Dict[str, spec.Tensor]], + optimizer_state: spec.OptimizerState, + current_param_container: spec.ParameterContainer, + model_state: spec.ModelAuxiliaryState, + hyperparameters: spec.Hyperparameters, + global_step: int, + rng: spec.RandomState, +) -> Dict[str, spec.Tensor]: + """Select data from the infinitely repeating, pre-shuffled input queue. + Each element of the queue is a batch of training examples and labels. + """ + del workload + del optimizer_state + del current_param_container + del model_state + del hyperparameters + del global_step + del rng + batch = next(input_queue) + return batch diff --git a/submissions/self_tuning/lion/submission_info.yml b/submissions/self_tuning/lion/submission_info.yml new file mode 100644 index 00000000..f228ca19 --- /dev/null +++ b/submissions/self_tuning/lion/submission_info.yml @@ -0,0 +1,155 @@ +name: algoclean +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - bzip2=1.0.8=h5eee18b_6 + - ca-certificates=2025.2.25=h06a4308_0 + - expat=2.7.1=h6a678d5_0 + - ld_impl_linux-64=2.40=h12ee557_0 + - libffi=3.4.4=h6a678d5_1 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - libuuid=1.41.5=h5eee18b_0 + - libxcb=1.17.0=h9b100fa_0 + - ncurses=6.5=h7934f7d_0 + - openssl=3.0.17=h5eee18b_0 + - pthread-stubs=0.3=h0ce48e5_1 + - python=3.11.13=h1a3bd86_0 + - readline=8.2=h5eee18b_0 + - setuptools=78.1.1=py311h06a4308_0 + - sqlite=3.50.2=hb25bd0a_1 + - tk=8.6.14=h993c535_1 + - wheel=0.45.1=py311h06a4308_0 + - xorg-libx11=1.8.12=h9b100fa_1 + - xorg-libxau=1.0.12=h9b100fa_0 + - xorg-libxdmcp=1.1.5=h9b100fa_0 + - xorg-xorgproto=2024.1=h5eee18b_1 + - xz=5.6.4=h5eee18b_1 + - zlib=1.2.13=h5eee18b_1 + - pip: + - absl-py==2.1.0 + - algoperf==0.5.1.dev494+g22b07c82f + - array-record==0.7.2 + - astunparse==1.6.3 + - attrs==25.3.0 + - certifi==2025.8.3 + - charset-normalizer==3.4.2 + - chex==0.1.86 + - click==8.2.1 + - cloudpickle==3.1.1 + - clu==0.0.12 + - contourpy==1.3.3 + - cycler==0.12.1 + - decorator==5.2.1 + - dm-tree==0.1.9 + - docker==7.1.0 + - docstring-parser==0.17.0 + - einops==0.8.1 + - etils==1.13.0 + - filelock==3.18.0 + - flatbuffers==25.2.10 + - flax==0.8.4 + - fonttools==4.59.0 + - fsspec==2025.7.0 + - gast==0.6.0 + - google-pasta==0.2.0 + - googleapis-common-protos==1.70.0 + - gputil==1.4.0 + - grpcio==1.74.0 + - h5py==3.12.0 + - humanize==4.12.3 + - idna==3.10 + - imageio==2.37.0 + - immutabledict==4.2.1 + - importlib-resources==6.5.2 + - jax==0.4.28 + - jaxlib==0.4.28 + - jinja2==3.1.6 + - joblib==1.5.1 + - jraph==0.0.6.dev0 + - keras==3.11.1 + - kiwisolver==1.4.8 + - lazy-loader==0.4 + - libclang==18.1.1 + - markdown==3.8.2 + - markdown-it-py==3.0.0 + - markupsafe==3.0.2 + - matplotlib==3.10.5 + - mdurl==0.1.2 + - ml-collections==1.1.0 + - ml-dtypes==0.4.1 + - mpmath==1.3.0 + - msgpack==1.1.1 + - namex==0.1.0 + - nest-asyncio==1.6.0 + - networkx==3.2.1 + - numpy==2.0.2 + - nvidia-cublas-cu12==12.4.5.8 + - nvidia-cuda-cupti-cu12==12.4.127 + - nvidia-cuda-nvrtc-cu12==12.4.127 + - nvidia-cuda-runtime-cu12==12.4.127 + - nvidia-cudnn-cu12==9.1.0.70 + - nvidia-cufft-cu12==11.2.1.3 + - nvidia-curand-cu12==10.3.5.147 + - nvidia-cusolver-cu12==11.6.1.9 + - nvidia-cusparse-cu12==12.3.1.170 + - nvidia-nccl-cu12==2.21.5 + - nvidia-nvjitlink-cu12==12.4.127 + - nvidia-nvtx-cu12==12.4.127 + - opt-einsum==3.4.0 + - optax==0.2.2 + - optree==0.17.0 + - orbax-checkpoint==0.6.4 + - packaging==25.0 + - pandas==2.3.1 + - pillow==11.3.0 + - pip==25.2 + - promise==2.3 + - protobuf==4.25.5 + - psutil==6.1.0 + - pyarrow==21.0.0 + - pydub==0.25.1 + - pygments==2.19.2 + - pyparsing==3.2.3 + - python-dateutil==2.9.0.post0 + - pytz==2025.2 + - pyyaml==6.0.2 + - requests==2.32.4 + - rich==14.1.0 + - ruff==0.12.7 + - scikit-image==0.24.0 + - scikit-learn==1.5.2 + - scipy==1.16.1 + - sentencepiece==0.2.0 + - simple-parsing==0.1.7 + - six==1.17.0 + - sympy==1.13.1 + - tabulate==0.9.0 + - tensorboard==2.18.0 + - tensorboard-data-server==0.7.2 + - tensorflow==2.18.0 + - tensorflow-datasets==4.9.7 + - tensorflow-io-gcs-filesystem==0.37.1 + - tensorflow-metadata==1.17.2 + - tensorflow-probability==0.20.0 + - tensorflow-text==2.18.0 + - tensorstore==0.1.74 + - termcolor==3.1.0 + - threadpoolctl==3.6.0 + - tifffile==2025.6.11 + - toml==0.10.2 + - toolz==1.0.0 + - torch==2.5.1 + - torchvision==0.20.1 + - tqdm==4.67.1 + - triton==3.1.0 + - typing-extensions==4.14.1 + - tzdata==2025.2 + - urllib3==2.5.0 + - werkzeug==3.1.3 + - wrapt==1.17.2 + - zipp==3.23.0 +prefix: /private/home/axyang/miniconda3/envs/algoclean