diff --git a/examples/multiple_physics_pretraining/LICENSE b/examples/multiple_physics_pretraining/LICENSE new file mode 100644 index 0000000..706fe7b --- /dev/null +++ b/examples/multiple_physics_pretraining/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/examples/multiple_physics_pretraining/README.md b/examples/multiple_physics_pretraining/README.md new file mode 100644 index 0000000..3fbc301 --- /dev/null +++ b/examples/multiple_physics_pretraining/README.md @@ -0,0 +1,136 @@ +# Multiple Physics Pretraining (MPP) + +This example integrates the [MPP](https://openreview.net/forum?id=DKSI3bULiZ) (Multiple Physics Pretraining) model into PaddleCFD. + +Multiple Physics Pretraining is a pretraining strategy in which multiple sets of dynamics are jointly normalized and embedded into a single space for prediction. 
It uses an **AViT** (Axial Vision Transformer) architecture that learns multiple physics simultaneously through pretraining, enabling strong finetuning performance even across different physics domains. + +Paper: "Multiple Physics Pretraining for Spatiotemporal Surrogate Models" (NeurIPS 2024) + +Below are quick instructions on paddle, full readme please visit https://github.com/PolymathicAI/multiple_physics_pretraining + +## Installation + +```bash +pip install ppcfd +pip install wandb # optional +``` + +## Quick Start + +### Import the model + +```python +from ppcfd.models.multiple_physics_pretraining import AViT, build_avit +``` + +### Train (single device) + +```bash +python train_basic.py --run_name my_experiment --config basic_config --yaml_config config/mpp_avit_ti_config.yaml +``` + +### Finetune from pretrained weights + +Original PyTorch pretrained weights are available at: +https://drive.google.com/drive/folders/1Qaqa-RnzUDOO8-Gi4zlf4BE53SfWqDwx + +To use them in PaddleCFD, first convert to PaddlePaddle format: + +```bash +python convert_torch_weights.py \ + --yaml_config config/mpp_avit_s_config.yaml \ + --config basic_config \ + --weights path/to/MPP_AViT_S.tar \ + --output models_paddle/MPP_AViT_S.pdparams +``` + +Then finetune: + +```bash +python train_basic.py --run_name my_finetune --config finetune --yaml_config config/mpp_avit_s_config.yaml +``` + +### Inference + +If needed, use follow code to generate test input. 
+ +```bash +python multiple_physics_pretraining/generate_forward_case.py --output /tmp/case.npz --labels 0,1,2 --bcs 0,0 --output ./forward_case.npz +``` + +Then run a test forward case: + +```bash +python forward_pretrained.py \ + --yaml_config config/mpp_avit_s_config.yaml \ + --config basic_config \ + --weights path/to/checkpoint.pdparams \ + --case_npz path/to/input.npz \ + --output path/to/output.npz +``` + +## Model Variants + +| Variant | embed_dim | num_heads | processor_blocks | +| --------- | --------- | --------- | ---------------- | +| Ti (Tiny) | 192 | 3 | 12 | +| S (Small) | 384 | 6 | 12 | +| B (Base) | 768 | 12 | 12 | +| L (Large) | 1024 | 16 | 24 | + +Config files are provided in `config/` for each variant. Use the `basic_config` namespace for pretraining and `finetune` for finetuning. + +## Directory Structure + +``` +examples/multiple_physics_pretraining/ +├── config/ # YAML configuration files (Ti/S/B/L) +├── train_basic.py # Training script +├── forward_pretrained.py # Inference script +├── convert_torch_weights.py # PyTorch -> PaddlePaddle weight conversion +├── requirements.txt # Additional dependencies +├── LICENSE # MIT License +└── README.md # This file + +ppcfd/models/multiple_physics_pretraining/ +├── avit.py # AViT model definition +├── shared_modules.py # MLP, Attention, PositionBias +├── spatial_modules.py # AxialAttention, hMLP stem/output +├── time_modules.py # Temporal attention block +├── mixed_modules.py # SpaceTimeBlock combiner +├── DropPath_util.py # Stochastic depth +├── paddle_utils.py # PaddlePaddle utilities +├── utils/ # Training utilities +│ ├── YParams.py # YAML config parser +│ ├── logging_utils.py # Logging +│ ├── schedulers.py # LR scheduler +│ ├── adan_paddle.py # Adan optimizer +│ ├── dadapt_adam_paddle.py # DAdaptAdam optimizer +│ ├── dadapt_adan_paddle.py # DAdaptAdan optimizer +│ └── custom_optimizer_base.py # Optimizer base class +└── data_utils/ # Data loading + ├── datasets.py # MixedDataset, dataset registry 
+ ├── hdf5_datasets.py # HDF5 dataset classes (SWE, NS, etc.) + └── mixed_dset_sampler.py # Multi-dataset sampler +``` + +## Adding Datasets + +Datasets must return data in `(Batch, Time, Channel, H, W)` format and extend `BaseHDF5DirectoryDataset`. See `data_utils/hdf5_datasets.py` for examples. + +1. Define your dataset class in `ppcfd/models/multiple_physics_pretraining/data_utils/hdf5_datasets.py` +2. Register it in `DSET_NAME_TO_OBJECT` in `datasets.py` +3. Add data paths to the config YAML file + +## Citing + +```bibtex +@inproceedings{ + mccabe2024multiple, + title={Multiple Physics Pretraining for Spatiotemporal Surrogate Models}, + author={Michael McCabe and Bruno R{\'e}galdo-Saint Blancard and others}, + booktitle={NeurIPS}, + year={2024}, + url={https://openreview.net/forum?id=DKSI3bULiZ} +} +``` diff --git a/examples/multiple_physics_pretraining/config/mpp_avit_L_config.yaml b/examples/multiple_physics_pretraining/config/mpp_avit_L_config.yaml new file mode 100644 index 0000000..d163e4c --- /dev/null +++ b/examples/multiple_physics_pretraining/config/mpp_avit_L_config.yaml @@ -0,0 +1,99 @@ +basic_config: &basic_config # Run settings + log_to_wandb: !!bool True # Use wandb integration + log_to_screen: !!bool True # Log progress to screen. 
+ save_checkpoint: !!bool True # Save checkpoints + checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss + debug_grad: !!bool True # Compute gradient/step_sizes/ect for debugging + true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs + num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio + enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now + compile: !!bool False # Compile model - Does not currently work + gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory + exp_dir: "./" # Output path, modify as needed + log_interval: 1 # How often to log - Don't think this is actually implemented + pretrained: !!bool False # Whether to load a pretrained model + # wandb settings + project: "project" + group: "debugging" + entity: "entity" + # Training settings + drop_path: 0.1 + batch_size: 1 + max_epochs: 500 + scheduler_epochs: -1 + epoch_size: 2000 # Artificial epoch size + rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1 + optimizer: "adan" # adam, adan, whatever else i end up adding - adan did better on HP sweep + scheduler: "cosine" # Only cosine implemented + warmup_steps: 1000 # Warmup when not using DAdapt + learning_rate: -1 # -1 means use DAdapt + weight_decay: 1e-3 + n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused + state_names: ["Pressure", "Vx", "Vy", "Density", "Vx", "Vy", "Density", "Pressure"] # Should be sorted + dt: 1 # Striding of data - Not currently implemented > 1 + n_steps: 16 # Length of history to include in input + enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception. 
+ accum_grad: 5 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum + # Model settings + model_type: "avit" # Only option so far + block_type: "axial" # Which type of block to use - if axial, next two fields must be set to define axial ops + time_type: "attention" # Conditional on block type + space_type: "axial_attention" # Conditional on block type + tie_fields: !!bool False # Whether to use 1 embedding per field per data + embed_dim: 1024 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L + num_heads: 16 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L + processor_blocks: 24 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L + patch_size: [16, 16] # Actually currently hardcoded at 16 + bias_type: "rel" # Options rel, continuous, none + # Data settings + train_val_test: [.8, .1, .1] + augmentation: !!bool False # Augmentation not implemented + use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets + tie_batches: !!bool False # Force everything in batch to come from one dset + extended_names: !!bool False # Whether to use extended names - not currently implemented + embedding_offset: 0 # Use when adding extra finetuning fields + train_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + valid_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + append_datasets: [] # List of datasets to append to the 
input/output projections for finetuning + +finetune: &finetune + <<: *basic_config + max_epochs: 500 + train_val_test: [.8, .1, .1] + accum_grad: 1 + pretrained: !!bool True + group: "debugging" + pretrained_ckpt_path: "/B16-noNS/training_checkpoints/ckpt.tar" + train_data_paths: [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + valid_data_paths: # These are the same for all configs - uses split according to train_val_test + [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + embedding_offset: 0 # Number of fields in original model - FT fields start after this + freeze_middle: !!bool False # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + append_datasets: [] # List of datasets to append to the input/output projections for finetuning + +frozen: &frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + +less_frozen: &less_frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool True diff --git a/examples/multiple_physics_pretraining/config/mpp_avit_b_config.yaml b/examples/multiple_physics_pretraining/config/mpp_avit_b_config.yaml new file mode 100644 index 0000000..c093e8d --- /dev/null +++ b/examples/multiple_physics_pretraining/config/mpp_avit_b_config.yaml @@ -0,0 +1,99 @@ +basic_config: &basic_config # Run settings + log_to_wandb: !!bool True # Use wandb integration + log_to_screen: !!bool True # Log progress to screen. 
+ save_checkpoint: !!bool True # Save checkpoints + checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss + debug_grad: !!bool True # Compute gradient/step_sizes/ect for debugging + true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs + num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio + enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now + compile: !!bool False # Compile model - Does not currently work + gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory + exp_dir: "./" # Output path, modify as needed + log_interval: 1 # How often to log - Don't think this is actually implemented + pretrained: !!bool False # Whether to load a pretrained model + # wandb settings + project: "project" + group: "debugging" + entity: "entity" + # Training settings + drop_path: 0.1 + batch_size: 1 + max_epochs: 500 + scheduler_epochs: -1 + epoch_size: 2000 # Artificial epoch size + rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1 + optimizer: "adan" # adam, adan, whatever else i end up adding - adan did better on HP sweep + scheduler: "cosine" # Only cosine implemented + warmup_steps: 1000 # Warmup when not using DAdapt + learning_rate: -1 # -1 means use DAdapt + weight_decay: 1e-3 + n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused + state_names: ["Pressure", "Vx", "Vy", "Density", "Vx", "Vy", "Density", "Pressure"] # Should be sorted + dt: 1 # Striding of data - Not currently implemented > 1 + n_steps: 16 # Length of history to include in input + enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception. 
+ accum_grad: 5 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum + # Model settings + model_type: "avit" # Only option so far + block_type: "axial" # Which type of block to use - if axial, next two fields must be set to define axial ops + time_type: "attention" # Conditional on block type + space_type: "axial_attention" # Conditional on block type + tie_fields: !!bool False # Whether to use 1 embedding per field per data + embed_dim: 768 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L + num_heads: 12 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L + processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L + patch_size: [16, 16] # Actually currently hardcoded at 16 + bias_type: "rel" # Options rel, continuous, none + # Data settings + train_val_test: [.8, .1, .1] + augmentation: !!bool False # Augmentation not implemented + use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets + tie_batches: !!bool False # Force everything in batch to come from one dset + extended_names: !!bool False # Whether to use extended names - not currently implemented + embedding_offset: 0 # Use when adding extra finetuning fields + train_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + valid_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + append_datasets: [] # List of datasets to append to the 
input/output projections for finetuning + +finetune: &finetune + <<: *basic_config + max_epochs: 500 + train_val_test: [.8, .1, .1] + accum_grad: 1 + pretrained: !!bool True + group: "debugging" + pretrained_ckpt_path: "/B16-noNS/training_checkpoints/ckpt.tar" + train_data_paths: [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + valid_data_paths: # These are the same for all configs - uses split according to train_val_test + [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + embedding_offset: 0 # Number of fields in original model - FT fields start after this + freeze_middle: !!bool False # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + append_datasets: [] # List of datasets to append to the input/output projections for finetuning + +frozen: &frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + +less_frozen: &less_frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool True diff --git a/examples/multiple_physics_pretraining/config/mpp_avit_s_config.yaml b/examples/multiple_physics_pretraining/config/mpp_avit_s_config.yaml new file mode 100644 index 0000000..806bf7a --- /dev/null +++ b/examples/multiple_physics_pretraining/config/mpp_avit_s_config.yaml @@ -0,0 +1,99 @@ +basic_config: &basic_config # Run settings + log_to_wandb: !!bool True # Use wandb integration + log_to_screen: !!bool True # Log progress to screen. 
+ save_checkpoint: !!bool True # Save checkpoints + checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss + debug_grad: !!bool True # Compute gradient/step_sizes/ect for debugging + true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs + num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio + enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now + compile: !!bool False # Compile model - Does not currently work + gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory + exp_dir: "./" # Output path, modify as needed + log_interval: 1 # How often to log - Don't think this is actually implemented + pretrained: !!bool False # Whether to load a pretrained model + # wandb settings + project: "project" + group: "debugging" + entity: "entity" + # Training settings + drop_path: 0.1 + batch_size: 1 + max_epochs: 500 + scheduler_epochs: -1 + epoch_size: 2000 # Artificial epoch size + rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1 + optimizer: "adan" # adam, adan, whatever else i end up adding - adan did better on HP sweep + scheduler: "cosine" # Only cosine implemented + warmup_steps: 1000 # Warmup when not using DAdapt + learning_rate: -1 # -1 means use DAdapt + weight_decay: 1e-3 + n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused + state_names: ["Pressure", "Vx", "Vy", "Density", "Vx", "Vy", "Density", "Pressure"] # Should be sorted + dt: 1 # Striding of data - Not currently implemented > 1 + n_steps: 16 # Length of history to include in input + enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception. 
+ accum_grad: 5 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum + # Model settings + model_type: "avit" # Only option so far + block_type: "axial" # Which type of block to use - if axial, next two fields must be set to define axial ops + time_type: "attention" # Conditional on block type + space_type: "axial_attention" # Conditional on block type + tie_fields: !!bool False # Whether to use 1 embedding per field per data + embed_dim: 384 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L + num_heads: 6 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L + processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L + patch_size: [16, 16] # Actually currently hardcoded at 16 + bias_type: "rel" # Options rel, continuous, none + # Data settings + train_val_test: [.8, .1, .1] + augmentation: !!bool False # Augmentation not implemented + use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets + tie_batches: !!bool False # Force everything in batch to come from one dset + extended_names: !!bool False # Whether to use extended names - not currently implemented + embedding_offset: 0 # Use when adding extra finetuning fields + train_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + valid_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + append_datasets: [] # List of datasets to append to the 
input/output projections for finetuning + +finetune: &finetune + <<: *basic_config + max_epochs: 500 + train_val_test: [.8, .1, .1] + accum_grad: 1 + pretrained: !!bool True + group: "debugging" + pretrained_ckpt_path: "/B16-noNS/training_checkpoints/ckpt.tar" + train_data_paths: [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + valid_data_paths: # These are the same for all configs - uses split according to train_val_test + [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + embedding_offset: 0 # Number of fields in original model - FT fields start after this + freeze_middle: !!bool False # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + append_datasets: [] # List of datasets to append to the input/output projections for finetuning + +frozen: &frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + +less_frozen: &less_frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool True diff --git a/examples/multiple_physics_pretraining/config/mpp_avit_ti_config.yaml b/examples/multiple_physics_pretraining/config/mpp_avit_ti_config.yaml new file mode 100644 index 0000000..8641cbb --- /dev/null +++ b/examples/multiple_physics_pretraining/config/mpp_avit_ti_config.yaml @@ -0,0 +1,99 @@ +basic_config: &basic_config # Run settings + log_to_wandb: !!bool False # Use wandb integration + log_to_screen: !!bool True # Log progress to screen. 
+ save_checkpoint: !!bool True # Save checkpoints + checkpoint_save_interval: 10 # Save every # epochs - also saves "best" according to val loss + debug_grad: !!bool True # Compute gradient/step_sizes/ect for debugging + true_time: !!bool False # Debugging setting - sets num workers to zero and activates syncs + num_data_workers: 6 # Generally pulling 8 cpu per process, so using 6 for DL - not sure if best ratio + enable_amp: !!bool False # Use automatic mixed precision - blows up with low variance fields right now + compile: !!bool False # Compile model - Does not currently work + gradient_checkpointing: !!bool False # Whether to use gradient checkpointing - Slow, but lower memory + exp_dir: "~/MPP" # Output path + log_interval: 1 # How often to log - Don't think this is actually implemented + pretrained: !!bool False # Whether to load a pretrained model + # wandb settings + project: "project" + group: "debugging" + entity: "entity" + # Training settings + drop_path: 0.1 + batch_size: 1 + max_epochs: 500 + scheduler_epochs: -1 + epoch_size: 2000 # Artificial epoch size + rescale_gradients: !!bool False # Activate hook that scales block gradients to norm 1 + optimizer: "adan" # adam, adan, whatever else i end up adding - adan did better on HP sweep + scheduler: "cosine" # Only cosine implemented + warmup_steps: 1000 # Warmup when not using DAdapt + learning_rate: -1 # -1 means use DAdapt + weight_decay: 1e-3 + n_states: 12 # Number of state variables across the datasets - Can be larger than real number and things will just go unused + state_names: ["Pressure", "Vx", "Vy", "Density", "Vx", "Vy", "Density", "Pressure"] # Should be sorted + dt: 1 # Striding of data - Not currently implemented > 1 + n_steps: 16 # Length of history to include in input + enforce_max_steps: !!bool False # If false and n_steps > dataset steps, use dataset steps. Otherwise, raise Exception. 
+ accum_grad: 5 # Real batch size is accum * batch_size, real steps/"epoch" is epoch_size / accum + # Model settings + model_type: "avit" # Only option so far + block_type: "axial" # Which type of block to use - if axial, next two fields must be set to define axial ops + time_type: "attention" # Conditional on block type + space_type: "axial_attention" # Conditional on block type + tie_fields: !!bool False # Whether to use 1 embedding per field per data + embed_dim: 192 # Dimension of internal representation - 192/384/768/1024 for Ti/S/B/L + num_heads: 3 # Number of heads for attention - 3/6/12/16 for Ti/S/B/L + processor_blocks: 12 # Number of transformer blocks in the backbone - 12/12/12/24 for Ti/S/B/L + patch_size: [16, 16] # Actually currently hardcoded at 16 + bias_type: "rel" # Options rel, continuous, none + # Data settings + train_val_test: [.8, .1, .1] + augmentation: !!bool False # Augmentation not implemented + use_all_fields: !!bool True # Prepopulate the field metadata dictionary from dictionary in datasets + tie_batches: !!bool False # Force everything in batch to come from one dset + extended_names: !!bool False # Whether to use extended names - not currently implemented + embedding_offset: 0 # Use when adding extra finetuning fields + train_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + valid_data_paths: + [ + ["~/PDEBench/2D/shallow-water", "swe", ""], + ["~/PDEBench/2D/NS_incom", "incompNS", ""], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "128"], + ["~/PDEBench/2D/CFD/2D_Train_Rand", compNS, "512"], + ["~/PDEBench/2D/CFD/2D_Train_Turb", compNS, ""], + ["~/PDEBench/2D/diffusion-reaction", "diffre2d", ""], + ] + append_datasets: [] # List of datasets to append to the 
input/output projections for finetuning + +finetune: &finetune + <<: *basic_config + max_epochs: 500 + train_val_test: [.8, .1, .1] + accum_grad: 1 + pretrained: !!bool True + group: "debugging" + pretrained_ckpt_path: "/B16-noNS/training_checkpoints/ckpt.tar" + train_data_paths: [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + valid_data_paths: # These are the same for all configs - uses split according to train_val_test + [["/PDEBench/2D/CFD/2D_Train_Turb", "compNS", "M1.0"]] + embedding_offset: 0 # Number of fields in original model - FT fields start after this + freeze_middle: !!bool False # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + append_datasets: [] # List of datasets to append to the input/output projections for finetuning + +frozen: &frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool False + +less_frozen: &less_frozen + <<: *finetune + freeze_middle: !!bool True # Whether to freeze the middle layers of the model + freeze_processor: !!bool True diff --git a/examples/multiple_physics_pretraining/convert_torch_weights.py b/examples/multiple_physics_pretraining/convert_torch_weights.py new file mode 100644 index 0000000..8b1b381 --- /dev/null +++ b/examples/multiple_physics_pretraining/convert_torch_weights.py @@ -0,0 +1,180 @@ +import argparse +from pathlib import Path + +import numpy as np +import yaml + + +REPO_ROOT = Path(__file__).resolve().parents[1] + + +def load_yaml_config(yaml_path, config_name): + with open(yaml_path, "r", encoding="utf-8") as handle: + payload = yaml.safe_load(handle) + if config_name not in payload: + raise KeyError(f"Config '{config_name}' not found in {yaml_path}") + return argparse.Namespace(**payload[config_name]) + + +def extract_model_state_dict(payload): + if not isinstance(payload, dict): + raise TypeError("Checkpoint payload must be a mapping") + if "model_state" in payload: + return 
payload["model_state"] + return payload + + +def strip_module_prefix(state_dict): + keys = list(state_dict.keys()) + if keys and all(key.startswith("module.") for key in keys): + return {key[7:]: value for key, value in state_dict.items()} + return state_dict + + +def build_default_output_path(torch_weights_path): + source_path = Path(torch_weights_path) + return REPO_ROOT / "models_paddle" / f"{source_path.stem}.pdparams" + + +def probe_paddle_load(weight_path): + import paddle + + try: + paddle.load(str(weight_path)) + return True, "" + except Exception as exc: + return False, str(exc) + + +def to_numpy_array(value): + if isinstance(value, np.ndarray): + return value + if hasattr(value, "detach"): + value = value.detach() + if hasattr(value, "cpu"): + value = value.cpu() + if hasattr(value, "numpy"): + return value.numpy() + return np.asarray(value) + + +def convert_array_for_target(source_array, target_shape, target_dtype, key): + source_array = np.asarray(source_array) + if tuple(source_array.shape) == tuple(target_shape): + return source_array.astype(target_dtype, copy=False) + if source_array.ndim == 2 and tuple(source_array.T.shape) == tuple(target_shape): + return source_array.T.astype(target_dtype, copy=False) + raise ValueError( + f"shape mismatch for {key}: source {tuple(source_array.shape)} vs target {tuple(target_shape)}" + ) + + +def get_source_value_for_target_key(source_state_dict, target_key): + if target_key in source_state_dict: + return source_state_dict[target_key] + if target_key.endswith(".scale"): + fallback_key = f"{target_key[:-6]}.weight" + if fallback_key in source_state_dict: + return source_state_dict[fallback_key] + raise KeyError(target_key) + + +def convert_state_dict(source_state_dict, target_metadata): + source_keys = set(source_state_dict.keys()) + missing_keys = [] + used_source_keys = set() + + for target_key in target_metadata: + try: + source_value = get_source_value_for_target_key(source_state_dict, target_key) + for 
candidate_key, candidate_value in source_state_dict.items(): + if candidate_value is source_value: + used_source_keys.add(candidate_key) + break + except KeyError: + missing_keys.append(target_key) + + extra_keys = sorted(source_keys - used_source_keys) + + if missing_keys: + raise ValueError(f"Missing keys: {missing_keys}") + if extra_keys: + raise ValueError(f"Unexpected keys: {extra_keys}") + + converted = {} + for key, meta in target_metadata.items(): + converted[key] = convert_array_for_target( + source_array=to_numpy_array( + get_source_value_for_target_key(source_state_dict, key) + ), + target_shape=meta["shape"], + target_dtype=meta["dtype"], + key=key, + ) + return converted + + +def collect_target_metadata(state_dict): + metadata = {} + for key, value in state_dict.items(): + array = to_numpy_array(value) + metadata[key] = {"shape": tuple(array.shape), "dtype": str(array.dtype)} + return metadata + + +def build_paddle_model(params): + from ppcfd.models.multiple_physics_pretraining.avit import build_avit + + return build_avit(params) + + +def convert_weights(yaml_config, config_name, torch_weights, output_path): + import paddle + import torch + + is_paddle_file, _ = probe_paddle_load(torch_weights) + if is_paddle_file: + raise ValueError(f"{torch_weights} is already a Paddle-serializable file") + + params = load_yaml_config(yaml_config, config_name) + paddle_model = build_paddle_model(params) + target_metadata = collect_target_metadata(paddle_model.state_dict()) + + checkpoint = torch.load(torch_weights, map_location="cpu") + source_state = strip_module_prefix(extract_model_state_dict(checkpoint)) + converted_state = convert_state_dict(source_state, target_metadata) + paddle_state = { + key: paddle.to_tensor(value, dtype=target_metadata[key]["dtype"]) + for key, value in converted_state.items() + } + + output_path = Path(output_path) + output_path.parent.mkdir(parents=True, exist_ok=True) + paddle.save(paddle_state, str(output_path)) + return output_path + 
def build_arg_parser():
    """CLI for the Torch-to-Paddle weight conversion script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml_config", required=True, type=Path)
    parser.add_argument("--config", default="basic_config", type=str)
    parser.add_argument("--torch_weights", required=True, type=Path)
    parser.add_argument("--output", default=None, type=Path)
    return parser


def main():
    """Entry point: convert a Torch checkpoint and report where it was saved."""
    parser = build_arg_parser()
    args = parser.parse_args()
    # Default output path is derived from the input checkpoint name.
    output_path = args.output or build_default_output_path(args.torch_weights)
    converted_path = convert_weights(
        yaml_config=args.yaml_config,
        config_name=args.config,
        torch_weights=args.torch_weights,
        output_path=output_path,
    )
    print(f"Saved Paddle weights to {converted_path}")


if __name__ == "__main__":
    main()


# ===========================================================================
# examples/multiple_physics_pretraining/forward_pretrained.py
# ===========================================================================
import argparse
from pathlib import Path

import numpy as np


def load_yaml_config(yaml_path, config_name):
    """Load one named config section from a YAML file as a Namespace.

    PyYAML is imported lazily so the numpy-only helpers in this module stay
    importable in environments without it.

    Raises:
        KeyError: if ``config_name`` is not a top-level key in the YAML file.
    """
    import yaml  # deferred: only this function needs PyYAML

    with open(yaml_path, "r", encoding="utf-8") as handle:
        payload = yaml.safe_load(handle)
    if config_name not in payload:
        raise KeyError(f"Config '{config_name}' not found in {yaml_path}")
    return argparse.Namespace(**payload[config_name])


def extract_model_state_dict(payload):
    """Return the model-weights mapping from a checkpoint payload.

    Accepts either a full training checkpoint (carrying a ``model_state``
    key) or a bare state dict.

    Raises:
        TypeError: if ``payload`` is not a mapping.
    """
    if not isinstance(payload, dict):
        raise TypeError("Checkpoint payload must be a mapping")
    if "model_state" in payload:
        return payload["model_state"]
    return payload


# Prefix added to every parameter name by (Data)DataParallel wrappers.
_DDP_PREFIX = "module."


def strip_module_prefix(state_dict):
    """Drop the DataParallel ``module.`` prefix when *every* key carries it."""
    keys = list(state_dict.keys())
    if keys and all(key.startswith(_DDP_PREFIX) for key in keys):
        return {key[len(_DDP_PREFIX):]: value for key, value in state_dict.items()}
    return state_dict


def normalize_case_arrays(case_payload):
    """Coerce a saved forward case into model-ready numpy arrays.

    ``x`` is expected as (T, B, C, H, W); 1-D label / boundary-condition
    vectors are tiled to one row per batch element.
    """
    x = np.asarray(case_payload["x"], dtype=np.float32)
    state_labels = np.asarray(case_payload["state_labels"], dtype=np.int64)
    bcs = np.asarray(case_payload["bcs"], dtype=np.int64)
    batch_size = x.shape[1]  # axis 1 is the batch dimension
    if state_labels.ndim == 1:
        state_labels = np.tile(state_labels, (batch_size, 1))
    if bcs.ndim == 1:
        bcs = np.tile(bcs, (batch_size, 1))
    return x, state_labels, bcs


def summarize_array(array):
    """Return basic statistics of ``array`` as plain Python scalars.

    Raises:
        ValueError: if ``array`` is empty (np.min/np.max would otherwise
            raise an opaque reduction error).
    """
    if array.size == 0:
        raise ValueError("Cannot summarize an empty array")
    return {
        "shape": tuple(array.shape),
        "dtype": str(array.dtype),
        "min": float(np.min(array)),
        "max": float(np.max(array)),
        "mean": float(np.mean(array)),
        "finite": bool(np.isfinite(array).all()),
    }


def apply_model_state(model, state_dict):
    """Load weights via whichever API this Paddle version exposes."""
    if hasattr(model, "set_state_dict"):
        model.set_state_dict(state_dict)
    else:
        model.load_dict(state_dict)


def run_forward(yaml_config, config_name, weights_path, case_npz, output_path):
    """Run one forward pass of a pretrained AViT model on a saved case.

    Builds the model from ``yaml_config``/``config_name``, loads Paddle
    weights, evaluates on the arrays in ``case_npz``, saves the output to
    ``output_path`` (.npz with key ``output``) and returns its summary stats.
    """
    import paddle  # deferred: heavy dependency, only needed here

    from ppcfd.models.multiple_physics_pretraining.avit import build_avit

    params = load_yaml_config(yaml_config, config_name)
    model = build_avit(params)
    checkpoint = paddle.load(str(weights_path))
    model_state = strip_module_prefix(extract_model_state_dict(checkpoint))
    apply_model_state(model, model_state)
    model.eval()
    with np.load(case_npz) as case_payload:
        x, state_labels, bcs = normalize_case_arrays(case_payload)
    output = model(
        paddle.to_tensor(x),
        paddle.to_tensor(state_labels),
        paddle.to_tensor(bcs),
    )
    output_array = output.numpy()
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(output_path, output=output_array)
    return summarize_array(output_array)


def build_arg_parser():
    """CLI for the pretrained-forward script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--yaml_config", required=True, type=Path)
    parser.add_argument("--config", default="basic_config", type=str)
    parser.add_argument("--weights", required=True, type=Path)
    parser.add_argument("--case_npz", required=True, type=Path)
    parser.add_argument("--output", required=True, type=Path)
    return parser


def main():
    """Entry point: run the forward pass and print its summary statistics."""
    parser = build_arg_parser()
    args = parser.parse_args()
    summary = run_forward(
        yaml_config=args.yaml_config,
        config_name=args.config,
        weights_path=args.weights,
        case_npz=args.case_npz,
        output_path=args.output,
    )
    print(summary)


if __name__ == "__main__":
    main()


# ===========================================================================
# examples/multiple_physics_pretraining/generate_forward_case.py
# ===========================================================================
import argparse
from pathlib import Path

import numpy as np


def parse_csv_ints(raw_value):
    """Parse a comma-separated string into a list of ints; None/'' yield []."""
    if raw_value is None or raw_value == "":
        return []
    return [int(part.strip()) for part in raw_value.split(",") if part.strip()]


def normalize_shape(shape):
    """Validate and coerce a 5-element (T, B, C, H, W) shape to an int tuple.

    Raises:
        ValueError: if ``shape`` does not have exactly five elements.
    """
    if len(shape) != 5:
        raise ValueError(f"Expected shape (T,B,C,H,W), got {shape}")
    return tuple(int(part) for part in shape)


def build_case_payload(shape, labels, bcs, seed):
    """Create a random forward case: input tensor plus per-batch labels/BCs.

    Raises:
        ValueError: on a malformed shape, empty ``labels``, or ``bcs`` not
            of length 2.
    """
    t, batch_size, channels, height, width = normalize_shape(shape)
    if not labels:
        raise ValueError("labels must not be empty")
    if len(bcs) != 2:
        raise ValueError("bcs must contain exactly 2 integers")
    rng = np.random.default_rng(seed)  # seeded for reproducible cases
    x = rng.standard_normal((t, batch_size, channels, height, width), dtype=np.float32)
    state_labels = np.tile(np.asarray(labels, dtype=np.int64), (batch_size, 1))
    bcs_array = np.tile(np.asarray(bcs, dtype=np.int64), (batch_size, 1))
    return {
        "x": x.astype(np.float32, copy=False),
        "state_labels": state_labels,
        "bcs": bcs_array,
    }


def save_case_payload(output_path, payload):
    """Save the case arrays as an .npz file, creating parent directories."""
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    np.savez(output_path, **payload)


def build_arg_parser():
    """CLI for the forward-case generator script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--output", required=True, type=Path)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--shape", default="4,1,3,64,64", type=str)
    parser.add_argument("--labels", required=True, type=str)
    parser.add_argument("--bcs", default="0,0", type=str)
    return parser
parser.add_argument("--labels", required=True, type=str) + parser.add_argument("--bcs", default="0,0", type=str) + return parser + + +def main(): + parser = build_arg_parser() + args = parser.parse_args() + payload = build_case_payload( + shape=parse_csv_ints(args.shape), + labels=parse_csv_ints(args.labels), + bcs=parse_csv_ints(args.bcs), + seed=args.seed, + ) + save_case_payload(args.output, payload) + print(f"Saved forward case to {args.output}") + + +if __name__ == "__main__": + main() diff --git a/examples/multiple_physics_pretraining/train_basic.py b/examples/multiple_physics_pretraining/train_basic.py new file mode 100644 index 0000000..7ccd280 --- /dev/null +++ b/examples/multiple_physics_pretraining/train_basic.py @@ -0,0 +1,733 @@ +import argparse +import gc +import os +import pickle as pkl +import time +from collections import OrderedDict, defaultdict +from contextlib import nullcontext + +# # 打开组合算子 +# export FLAGS_prim_enable_dynamic=true && export FLAGS_prim_all=true + +# # 打开 CINN 编译器 +# export FLAGS_use_cinn=true + +# # 是否打印 Program IR 信息 (用于调试) +# export FLAGS_print_ir=false + +# OPEN CINN +# os.environ["FLAGS_prim_enable_dynamic"] = "true" +# os.environ["FLAGS_prim_all"] = "true" +# os.environ["FLAGS_use_cinn"] = "true" +# os.environ["FLAGS_print_ir"] = "false" + +import einops +import numpy as np +import paddle +import wandb + + +# from adan_pytorch import Adan +# from dadaptation import DAdaptAdam, DAdaptAdan +from ppcfd.models.multiple_physics_pretraining.paddle_utils import * +from ruamel.yaml import YAML +from ruamel.yaml.comments import CommentedMap as ruamelDict + +from ppcfd.models.multiple_physics_pretraining.utils.adan_paddle import Adan +from ppcfd.models.multiple_physics_pretraining.utils.dadapt_adam_paddle import ( + DAdaptAdam, +) +from ppcfd.models.multiple_physics_pretraining.utils.dadapt_adan_paddle import ( + DAdaptAdan, +) + + +# from torchinfo import summary + +from ppcfd.models.multiple_physics_pretraining.data_utils.datasets 
import ( + DSET_NAME_TO_OBJECT, + get_data_loader, +) +from ppcfd.models.multiple_physics_pretraining.avit import build_avit +from ppcfd.models.multiple_physics_pretraining.utils import logging_utils +from ppcfd.models.multiple_physics_pretraining.utils.schedulers import ( + SimpleSequentialScheduler, +) +from ppcfd.models.multiple_physics_pretraining.utils.YParams import YParams + + +def add_weight_decay(model, weight_decay=1e-05, inner_lr=0.001, skip_list=()): + """From Ross Wightman at: + https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/3 + + Goes through the parameter list and if the squeeze dim is 1 or 0 (usually means bias or scale) + then don't apply weight decay. + """ + decay = [] + no_decay = [] + for name, param in model.named_parameters(): + if param.stop_gradient: + continue + if len(param.squeeze().shape) <= 1 or name in skip_list: + no_decay.append(param) + else: + decay.append(param) + return [ + {"params": no_decay, "weight_decay": 0.0}, + {"params": decay, "weight_decay": weight_decay}, + ] + + +class Trainer: + def __init__(self, params, global_rank, local_rank, device, sweep_id=None): + self.device = device + self.params = params + self.global_rank = global_rank + self.local_rank = local_rank + self.world_size = int(paddle.distributed.get_world_size()) + self.sweep_id = sweep_id + self.log_to_screen = params.log_to_screen + self.train_loss = paddle.nn.MSELoss() + self.startEpoch = 0 + self.epoch = 0 + has_cuda_device = ( + paddle.device.is_compiled_with_cuda() + and paddle.device.cuda.device_count() > 0 + ) + self.amp_enabled = has_cuda_device and params.enable_amp + self.mp_type = ( + "bfloat16" + if has_cuda_device and paddle.amp.is_bfloat16_supported() + else "float16" + ) + self.iters = 0 + self.initialize_data(self.params) + print(f"Initializing model on rank {self.global_rank}") + self.initialize_model(self.params) + self.initialize_optimizer(self.params) + if params.resuming: + 
print("Loading checkpoint %s" % params.checkpoint_path) + self.restore_checkpoint(params.checkpoint_path) + if params.resuming == False and params.pretrained: + print("Starting from pretrained model at %s" % params.pretrained_ckpt_path) + self.restore_checkpoint(params.pretrained_ckpt_path) + self.iters = 0 + self.startEpoch = 0 + self.initialize_scheduler(self.params) + + def single_print(self, *text): + if self.global_rank == 0 and self.log_to_screen: + print(" ".join([str(t) for t in text])) + + def initialize_data(self, params): + if params.tie_batches: + in_rank = 0 + else: + in_rank = self.global_rank + if self.log_to_screen: + print(f"Initializing data on rank {self.global_rank}") + ( + self.train_data_loader, + self.train_dataset, + self.train_sampler, + ) = get_data_loader( + params, + params.train_data_paths, + paddle.distributed.is_initialized(), + split="train", + rank=in_rank, + train_offset=self.params.embedding_offset, + ) + self.valid_data_loader, self.valid_dataset, _ = get_data_loader( + params, + params.valid_data_paths, + paddle.distributed.is_initialized(), + split="val", + rank=in_rank, + ) + if paddle.distributed.is_initialized(): + self.train_sampler.set_epoch(0) + + def initialize_model(self, params): + if self.params.model_type == "avit": + self.model = build_avit(params).to(self.device) + """ + # there's no match api for torch.compile in paddle. 
+ if self.params.compile: + print( + "WARNING: BFLOAT NOT SUPPORTED IN SOME COMPILE OPS SO SWITCHING TO FLOAT16" + ) + self.mp_type = torch.half + self.model = torch.compile(self.model) + """ + if paddle.distributed.is_initialized(): + """paddle.DataParallel无device_ids和output_device""" + self.model = paddle.DataParallel( + layers=self.model, find_unused_parameters=True + ) + self.single_print( + f"Model parameter count: {sum([p.size for p in self.model.parameters()])}" + ) + + def initialize_optimizer(self, params): + parameters = add_weight_decay(self.model, self.params.weight_decay) + if params.optimizer == "adam": + if self.params.learning_rate < 0: + self.optimizer = DAdaptAdam( + parameters, lr=1.0, growth_rate=1.05, log_every=100, decouple=True + ) + else: + self.optimizer = paddle.optimizer.AdamW( + parameters=parameters, + learning_rate=params.learning_rate, + weight_decay=0.0, + ) + elif params.optimizer == "adan": + if self.params.learning_rate < 0: + self.optimizer = DAdaptAdan( + parameters, lr=1.0, growth_rate=1.05, log_every=100 + ) + else: + self.optimizer = Adan(parameters, lr=params.learning_rate) + elif params.optimizer == "sgd": + """there's no param "momentum" in paddle.optimizer.SGD + and the param "lr" is named as "learning_rate", "params" is "parameters" in paddle. 
+ """ + self.optimizer = paddle.optimizer.SGD( + parameters=self.model.parameters(), + learning_rate=params.learning_rate, + weight_decay=0.0, + ) + else: + raise ValueError(f"Optimizer {params.optimizer} not supported") + self.gscaler = paddle.amp.GradScaler( + enable=self.amp_enabled and self.mp_type == "float16", + incr_every_n_steps=2000, + init_loss_scaling=65536.0, + ) + + def initialize_scheduler(self, params): + if params.scheduler_epochs > 0: + sched_epochs = params.scheduler_epochs + else: + sched_epochs = params.max_epochs + if params.scheduler == "cosine": + if self.params.learning_rate < 0: + tmp_lr = paddle.optimizer.lr.CosineAnnealingDecay( + last_epoch=self.startEpoch * params.epoch_size - 1, + T_max=sched_epochs * params.epoch_size, + eta_min=params.learning_rate / 100, + learning_rate=self.optimizer.get_lr(), + ) + self.optimizer.set_lr_scheduler(tmp_lr) + self.scheduler = tmp_lr + else: + k = params.warmup_steps + last_step = self.startEpoch * params.epoch_size - 1 + tmp_lr = paddle.optimizer.lr.LinearLR( + start_factor=0.01, + end_factor=1.0, + total_steps=k, + last_epoch=last_step, + learning_rate=self.optimizer.get_lr(), + ) + warmup = tmp_lr + tmp_lr = paddle.optimizer.lr.CosineAnnealingDecay( + eta_min=params.learning_rate / 100, + T_max=sched_epochs * params.epoch_size - k, + last_epoch=last_step - k if last_step >= k else -1, + learning_rate=self.optimizer.get_lr(), + ) + decay = tmp_lr + self.scheduler = SimpleSequentialScheduler( + self.optimizer, + [warmup, decay], + [k], + last_epoch=params.epoch_size * self.startEpoch - 1, + ) + else: + self.scheduler = None + + def save_checkpoint(self, checkpoint_path, model=None): + """Save model and optimizer to checkpoint""" + if not model: + model = self.model + paddle.save( + obj={ + "iters": self.epoch * self.params.epoch_size, + "epoch": self.epoch, + "model_state": model.state_dict(), + "optimizer_state_dict": self.optimizer.state_dict(), + }, + path=checkpoint_path, + ) + + def 
restore_checkpoint(self, checkpoint_path): + """Load model/opt from path""" + checkpoint = paddle.load(path=str(checkpoint_path)) + if "model_state" in checkpoint: + model_state = checkpoint["model_state"] + else: + model_state = checkpoint + try: + self.model.load_state_dict(model_state) + except: + if hasattr(self.model, "module"): + self.model.module.load_state_dict(model_state) + else: + new_state_dict = OrderedDict() + for key, val in model_state.items(): + name = key[7:] + new_state_dict[name] = val + self.model.load_state_dict(new_state_dict) + if self.params.resuming: + self.iters = checkpoint["iters"] + self.optimizer.load_state_dict(checkpoint["optimizer_state_dict"]) + self.startEpoch = checkpoint["epoch"] + self.epoch = self.startEpoch + else: + self.iters = 0 + if self.params.pretrained: + if self.params.freeze_middle: + self.model.module.freeze_middle() + elif self.params.freeze_processor: + self.model.module.freeze_processor() + else: + self.model.module.unfreeze() + exp_proj = 0 + for add_on in self.params.append_datasets: + exp_proj += len(DSET_NAME_TO_OBJECT[add_on]._specifics()[2]) + self.model.module.expand_projections(exp_proj) + checkpoint = None + self.model = self.model.to(self.device) + + def train_one_epoch(self): + self.model.train() + self.epoch += 1 + tr_time = 0 + data_time = 0 + data_start = time.time() + self.model.train() + logs = { + "train_rmse": paddle.zeros(1).to(self.device), + "train_nrmse": paddle.zeros(1).to(self.device), + "train_l1": paddle.zeros(1).to(self.device), + } + steps = 0 + last_grads = [paddle.zeros_like(p) for p in self.model.parameters()] + grad_logs = defaultdict(lambda: paddle.zeros(1).to(self.device)) + grad_counts = defaultdict(lambda: paddle.zeros(1).to(self.device)) + loss_logs = defaultdict(lambda: paddle.zeros(1).to(self.device)) + loss_counts = defaultdict(lambda: paddle.zeros(1).to(self.device)) + self.single_print( + "train_loader_size", len(self.train_data_loader), len(self.train_dataset) + ) + for 
batch_idx, data in enumerate(self.train_data_loader): + steps += 1 + inp, file_index, field_labels, bcs, tar = map( + lambda x: x.to(self.device), data + ) + dset_type = self.train_dataset.sub_dsets[file_index[0]].type + loss_counts[dset_type] += 1 + inp = einops.rearrange(inp, "b t c h w -> t b c h w") + data_time += time.time() - data_start + dtime = time.time() - data_start + should_sync = (1 + batch_idx) % self.params.accum_grad == 0 + sync_context = nullcontext() + if paddle.distributed.is_initialized() and not should_sync: + sync_context = self.model.no_sync() + with sync_context: + with paddle.amp.auto_cast(enable=self.amp_enabled, dtype=self.mp_type): + model_start = time.time() + output = self.model(inp, field_labels, bcs) + spatial_dims = tuple(range(output.ndim))[2:] + residuals = output - tar + tar_norm = 1e-07 + tar.pow(2).mean(spatial_dims, keepdim=True) + raw_loss = residuals.pow(2).mean( + spatial_dims, keepdim=True + ) / tar_norm + loss = raw_loss.mean() / self.params.accum_grad + forward_end = time.time() + forward_time = forward_end - model_start + with paddle.no_grad(): + logs["train_l1"] += paddle.nn.functional.l1_loss( + input=output, label=tar + ) + log_nrmse = raw_loss.sqrt().mean() + logs["train_nrmse"] += log_nrmse + loss_logs[dset_type] += loss.item() + logs["train_rmse"] += ( + residuals.pow(2).mean(spatial_dims).sqrt().mean() + ) + self.gscaler.scale(loss).backward() + backward_end = time.time() + backward_time = backward_end - forward_end + optimizer_step = 0 + if should_sync: + self.gscaler.unscale_(self.optimizer) + paddle.nn.utils.clip_grad_norm_( + parameters=self.model.parameters(), max_norm=1 + ) + self.gscaler.step(self.optimizer) + self.gscaler.update() + self.optimizer.clear_grad() + if self.scheduler is not None: + self.scheduler.step() + optimizer_step = time.time() - backward_end + tr_time += time.time() - model_start + if ( + self.log_to_screen + and batch_idx % self.params.log_interval == 0 + and self.global_rank == 0 + 
): + print( + f"Epoch {self.epoch} Batch {batch_idx} Train Loss {log_nrmse.item()}" + ) + if self.log_to_screen: + print( + "Total Times. Batch: {}, Rank: {}, Data Shape: {}, Data time: {}, Forward: {}, Backward: {}, Optimizer: {}".format( + batch_idx, + self.global_rank, + inp.shape, + dtime, + forward_time, + backward_time, + optimizer_step, + ) + ) + data_start = time.time() + logs = {k: (v / steps) for k, v in logs.items()} + if paddle.distributed.is_initialized(): + for key in sorted(logs.keys()): + paddle.distributed.all_reduce(tensor=logs[key].detach()) + logs[key] = float(logs[key] / paddle.distributed.get_world_size()) + for key in sorted(loss_logs.keys()): + paddle.distributed.all_reduce(tensor=loss_logs[key].detach()) + for key in sorted(grad_logs.keys()): + paddle.distributed.all_reduce(tensor=grad_logs[key].detach()) + for key in sorted(loss_counts.keys()): + paddle.distributed.all_reduce(tensor=loss_counts[key].detach()) + for key in sorted(grad_counts.keys()): + paddle.distributed.all_reduce(tensor=grad_counts[key].detach()) + for key in loss_logs.keys(): + logs[f"{key}/train_nrmse"] = loss_logs[key] / loss_counts[key] + self.iters += steps + if self.global_rank == 0: + logs["iters"] = self.iters + self.single_print("all reduces executed!") + return tr_time, data_time, logs + + def validate_one_epoch(self, full=False): + """ + Validates - for each batch just use a small subset to make it easier. + + Note: need to split datasets for meaningful metrics, but TBD. 
+ """ + if self.params.use_ddp and paddle.distributed.is_initialized(): + paddle.distributed.barrier() + if self.global_rank != 0: + paddle.distributed.barrier() + return {} + self.model.eval() + if full: + cutoff = 999999999999 + else: + cutoff = 40 + self.single_print("STARTING VALIDATION!!!") + with paddle.no_grad(): + with paddle.amp.auto_cast(enable=False, dtype=self.mp_type): + field_labels = self.valid_dataset.get_state_names() + distinct_dsets = list( + set( + [ + dset.title + for dset_group in self.valid_dataset.sub_dsets + for dset in dset_group.get_per_file_dsets() + ] + ) + ) + counts = {dset: (0) for dset in distinct_dsets} + logs = {} + for subset_group in self.valid_dataset.sub_dsets: + for subset in subset_group.get_per_file_dsets(): + dset_type = subset.title + self.single_print("VALIDATING ON", dset_type) + temp_loader = paddle.io.DataLoader( + dataset=subset, + batch_size=self.params.batch_size, + num_workers=self.params.num_data_workers, + shuffle=not self.params.use_ddp, + drop_last=True, + ) + count = 0 + for batch_idx, data in enumerate(temp_loader): + if count > cutoff: + del temp_loader + break + count += 1 + counts[dset_type] += 1 + inp, bcs, tar = map(lambda x: x.to(self.device), data) + labels = ( + paddle.to_tensor( + self.train_dataset.subset_dict.get( + subset.get_name(), + [-1] + * len( + self.valid_dataset.subset_dict[ + subset.get_name() + ] + ), + ), + ) + .to(self.device) + .unsqueeze(0) + .expand(tar.shape[0], -1) + ) + inp = einops.rearrange(inp, "b t c h w -> t b c h w") + output = self.model(inp, labels, bcs) + spatial_dims = tuple(range(output.ndim))[2:] + residuals = output - tar + nmse = ( + residuals.pow(2).mean(spatial_dims, keepdim=True) + / (1e-07 + tar.pow(2).mean(spatial_dims, keepdim=True)) + ).sqrt() + logs[f"{dset_type}/valid_nrmse"] = ( + logs.get(f"{dset_type}/valid_nrmse", 0) + nmse.mean() + ) + logs[f"{dset_type}/valid_rmse"] = ( + logs.get(f"{dset_type}/valid_mse", 0) + + 
residuals.pow(2).mean(spatial_dims).sqrt().mean() + ) + logs[f"{dset_type}/valid_l1"] = ( + logs.get(f"{dset_type}/valid_l1", 0) + + residuals.abs().mean() + ) + for i, field in enumerate( + self.valid_dataset.subset_dict[subset.type] + ): + field_name = field_labels[field] + logs[f"{dset_type}/{field_name}_valid_nrmse"] = ( + logs.get(f"{dset_type}/{field_name}_valid_nrmse", 0) + + nmse[:, i].mean() + ) + logs[f"{dset_type}/{field_name}_valid_rmse"] = ( + logs.get(f"{dset_type}/{field_name}_valid_rmse", 0) + + residuals[:, i : i + 1] + .pow(2) + .mean(spatial_dims) + .sqrt() + .mean() + ) + logs[f"{dset_type}/{field_name}_valid_l1"] = ( + logs.get(f"{dset_type}/{field_name}_valid_l1", 0) + + residuals[:, i].abs().mean() + ) + else: + del temp_loader + self.single_print("DONE VALIDATING - NOW SYNCING") + for k, v in logs.items(): + dset_type = k.split("/")[0] + logs[k] = v / counts[dset_type] + logs["valid_nrmse"] = 0 + for dset_type in distinct_dsets: + logs["valid_nrmse"] += logs[f"{dset_type}/valid_nrmse"] / len( + distinct_dsets + ) + if paddle.distributed.is_initialized(): + for key in sorted(logs.keys()): + paddle.distributed.all_reduce(tensor=logs[key].detach()) + logs[key] = float( + logs[key].item() / paddle.distributed.get_world_size() + ) + if "rmse" in key: + logs[key] = logs[key] + self.single_print("DONE SYNCING - NOW LOGGING") + if self.params.use_ddp and paddle.distributed.is_initialized(): + paddle.distributed.barrier() + return logs + + def train(self): + if self.params.log_to_wandb: + if self.sweep_id: + wandb.init(dir=self.params.experiment_dir) + hpo_config = wandb.config.as_dict() + self.params.update_params(hpo_config) + params = self.params + else: + wandb.init( + dir=self.params.experiment_dir, + config=self.params, + name=self.params.name, + group=self.params.group, + project=self.params.project, + entity=self.params.entity, + resume=True, + ) + if self.sweep_id and paddle.distributed.is_initialized(): + param_file = 
f"temp_hpo_config_{os.environ['SLURM_JOBID']}.pkl" + if self.global_rank == 0: + with open(param_file, "wb") as f: + pkl.dump(hpo_config, f) + paddle.distributed.barrier() + if self.global_rank != 0: + with open(param_file, "rb") as f: + hpo_config = pkl.load(f) + paddle.distributed.barrier() + if self.global_rank == 0: + os.remove(param_file) + if "batch_size" in hpo_config: + hpo_config["batch_size"] = int( + hpo_config["batch_size"] // self.world_size + ) + self.params.update_params(hpo_config) + params = self.params + self.initialize_data(self.params) + self.initialize_model(self.params) + self.initialize_optimizer(self.params) + self.initialize_scheduler(self.params) + # if self.global_rank == 0: + # summary(self.model) + if self.params.log_to_wandb: + wandb.watch(self.model) + self.single_print("Starting Training Loop...") + best_valid_loss = 1000000.0 + for epoch in range(self.startEpoch, self.params.max_epochs): + if paddle.distributed.is_initialized(): + self.train_sampler.set_epoch(epoch) + start = time.time() + tr_time, data_time, train_logs = self.train_one_epoch() + valid_start = time.time() + if epoch == self.params.max_epochs - 1: + valid_logs = self.validate_one_epoch(True) + else: + valid_logs = self.validate_one_epoch() + post_start = time.time() + train_logs.update(valid_logs) + train_logs["time/train_time"] = valid_start - start + train_logs["time/train_data_time"] = data_time + train_logs["time/train_compute_time"] = tr_time + train_logs["time/valid_time"] = post_start - valid_start + if self.params.log_to_wandb: + wandb.log(train_logs) + gc.collect() + paddle.device.cuda.empty_cache() + if self.global_rank == 0: + if self.params.save_checkpoint: + self.save_checkpoint(self.params.checkpoint_path) + if epoch % self.params.checkpoint_save_interval == 0: + self.save_checkpoint(self.params.checkpoint_path + f"_epoch{epoch}") + if valid_logs["valid_nrmse"] <= best_valid_loss: + self.save_checkpoint(self.params.best_checkpoint_path) + 
best_valid_loss = valid_logs["valid_nrmse"] + cur_time = time.time() + self.single_print( + f"Time for train {valid_start - start}. For valid: {post_start - valid_start}. For postprocessing:{cur_time - post_start}" + ) + self.single_print( + "Time taken for epoch {} is {} sec".format( + epoch + 1, time.time() - start + ) + ) + self.single_print( + "Train loss: {}. Valid loss: {}".format( + train_logs["train_nrmse"], valid_logs["valid_nrmse"] + ) + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--run_name", default="00", type=str) + parser.add_argument( + "--use_ddp", action="store_true", help="Use distributed data parallel" + ) + parser.add_argument("--yaml_config", default="./config/multi_ds.yaml", type=str) + parser.add_argument("--config", default="basic_config", type=str) + parser.add_argument( + "--sweep_id", + default=None, + type=str, + help="sweep config from ./configs/sweeps.yaml", + ) + args = parser.parse_args() + params = YParams(os.path.abspath(args.yaml_config), args.config) + params.use_ddp = args.use_ddp + has_cuda_device = ( + paddle.device.is_compiled_with_cuda() and paddle.device.cuda.device_count() > 0 + ) + global_rank = 0 + local_rank = 0 + world_size = 1 + if args.use_ddp: + paddle.distributed.init_parallel_env() + parallel_env = paddle.distributed.ParallelEnv() + global_rank = int(paddle.distributed.get_rank()) + local_rank = int(parallel_env.local_rank) + world_size = int(paddle.distributed.get_world_size()) + if has_cuda_device: + paddle.device.set_device(f"gpu:{local_rank}") + elif has_cuda_device: + paddle.device.set_device("gpu:0") + device = f"gpu:{local_rank}" if has_cuda_device else "cpu" + if params.batch_size % world_size != 0: + raise ValueError( + f"Global batch_size ({params.batch_size}) must be divisible by world_size ({world_size})." 
+ ) + params["batch_size"] = int(params.batch_size // world_size) + if params.batch_size < 1: + raise ValueError( + "Per-rank batch_size became 0 after sharding. Increase the configured global batch_size." + ) + params["startEpoch"] = 0 + if args.sweep_id: + jid = os.environ["SLURM_JOBID"] + expDir = os.path.join( + params.exp_dir, args.sweep_id, args.config, str(args.run_name), jid + ) + else: + expDir = os.path.join(params.exp_dir, args.config, str(args.run_name)) + params["old_exp_dir"] = expDir + params["experiment_dir"] = os.path.abspath(expDir) + params["checkpoint_path"] = os.path.join(expDir, "training_checkpoints/ckpt.tar") + params["best_checkpoint_path"] = os.path.join( + expDir, "training_checkpoints/best_ckpt.tar" + ) + params["old_checkpoint_path"] = os.path.join( + params.old_exp_dir, "training_checkpoints/best_ckpt.tar" + ) + if global_rank == 0: + if not os.path.isdir(expDir): + os.makedirs(expDir) + os.makedirs(os.path.join(expDir, "training_checkpoints/")) + params["resuming"] = True if os.path.isfile(params.checkpoint_path) else False + params["name"] = str(args.run_name) + if global_rank == 0: + logging_utils.log_to_file( + logger_name=None, log_filename=os.path.join(expDir, "out.log") + ) + logging_utils.log_versions() + params.log() + params["log_to_wandb"] = global_rank == 0 and params["log_to_wandb"] + params["log_to_screen"] = global_rank == 0 and params["log_to_screen"] + PaddleFlag.cudnn_benchmark = False + if global_rank == 0: + hparams = ruamelDict() + yaml = YAML() + for key, value in params.params.items(): + hparams[str(key)] = str(value) + with open(os.path.join(expDir, "hyperparams.yaml"), "w") as hpfile: + yaml.dump(hparams, hpfile) + trainer = Trainer(params, global_rank, local_rank, device, sweep_id=args.sweep_id) + if args.sweep_id and trainer.global_rank == 0: + print(args.sweep_id, trainer.params.entity, trainer.params.project) + wandb.agent( + args.sweep_id, + function=trainer.train, + count=1, + 
entity=trainer.params.entity, + project=trainer.params.project, + ) + else: + trainer.train() + if params.log_to_screen: + print("DONE ---- rank %d" % global_rank) diff --git a/ppcfd/models/__init__.py b/ppcfd/models/__init__.py index 13b523a..24c3f35 100755 --- a/ppcfd/models/__init__.py +++ b/ppcfd/models/__init__.py @@ -72,3 +72,11 @@ __all__.append("symbolic_gn") except ImportError: pass # Optional dependency + +# Multiple Physics Pretraining (MPP) - AViT for spatiotemporal surrogate modeling +try: + from ppcfd.models import multiple_physics_pretraining + + __all__.append("multiple_physics_pretraining") +except ImportError: + pass # Optional dependency diff --git a/ppcfd/models/multiple_physics_pretraining/DropPath_util.py b/ppcfd/models/multiple_physics_pretraining/DropPath_util.py new file mode 100644 index 0000000..16a467d --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/DropPath_util.py @@ -0,0 +1,38 @@ +import paddle + + +def drop_path( + x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True +): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. 
+ + """ + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) + random_tensor = paddle.bernoulli(paddle.full(shape, keep_prob, dtype=x.dtype)) + if keep_prob > 0.0 and scale_by_keep: + random_tensor = random_tensor / keep_prob + return x * random_tensor + + +class DropPath(paddle.nn.Layer): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True): + super().__init__() + self.drop_prob = drop_prob + self.scale_by_keep = scale_by_keep + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training, self.scale_by_keep) + + def extra_repr(self): + return f"drop_prob={round(self.drop_prob, 3):0.3f}" diff --git a/ppcfd/models/multiple_physics_pretraining/__init__.py b/ppcfd/models/multiple_physics_pretraining/__init__.py new file mode 100644 index 0000000..3a7e6fe --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/__init__.py @@ -0,0 +1,30 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Multiple Physics Pretraining (MPP) models for PaddleCFD. + +This module provides the AViT (Axial Vision Transformer) architecture for +spatiotemporal surrogate modeling with multiple physics pretraining. 
+ +Reference: + "Multiple Physics Pretraining for Spatiotemporal Surrogate Models" + Michael McCabe et al., NeurIPS 2024. +""" + +from .avit import AViT, build_avit + +__all__ = [ + "AViT", + "build_avit", +] diff --git a/ppcfd/models/multiple_physics_pretraining/avit.py b/ppcfd/models/multiple_physics_pretraining/avit.py new file mode 100644 index 0000000..7eb94c4 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/avit.py @@ -0,0 +1,153 @@ +from functools import partial + +import einops +import numpy as np +import paddle +from .paddle_utils import * + +from .mixed_modules import SpaceTimeBlock, build_spacetime_block +from .spatial_modules import SubsampledLinear, hMLP_output, hMLP_stem + + +def build_avit(params): + """Builds model from parameter file. + + General recipe is to build the spatial and temporal modules separately and then + combine them in a model. Eventually the "stem" and "destem" should + also be parameterized. + """ + space_time_block = build_spacetime_block(params) + model = AViT( + patch_size=params.patch_size, + embed_dim=params.embed_dim, + processor_blocks=params.processor_blocks, + n_states=params.n_states, + override_block=space_time_block, + ) + return model + + +class AViT(paddle.nn.Layer): + """ + Naive model that interweaves spatial and temporal attention blocks. Temporal attention + acts only on the time dimension. + + Args: + patch_size (tuple): Size of the input patch + embed_dim (int): Dimension of the embedding + processor_blocks (int): Number of blocks (consisting of spatial mixing - temporal attention) + n_states (int): Number of input state variables. 
+ """ + + def __init__( + self, + patch_size=(16, 16), + embed_dim=768, + processor_blocks=8, + n_states=6, + override_block=None, + drop_path=0.2, + ): + super().__init__() + self.drop_path = drop_path + self.dp = np.linspace(0, drop_path, processor_blocks) + self.space_bag = SubsampledLinear(n_states, embed_dim // 4) + self.embed = hMLP_stem( + patch_size=patch_size, in_chans=embed_dim // 4, embed_dim=embed_dim + ) + if override_block is not None: + inner_block = override_block + else: + inner_block = partial(SpaceTimeBlock, hidden_dim=embed_dim) + self.blocks = paddle.nn.LayerList( + [inner_block(drop_path=self.dp[i]) for i in range(processor_blocks)] + ) + self.debed = hMLP_output( + patch_size=patch_size, embed_dim=embed_dim, out_chans=n_states + ) + + def expand_projections(self, expansion_amount): + """Appends addition embeddings for finetuning on new data""" + with paddle.no_grad(): + temp_space_bag = SubsampledLinear( + dim_in=self.space_bag.dim_in + expansion_amount, + dim_out=self.space_bag.dim_out, + ) + temp_space_bag.weight[:, : self.space_bag.dim_in] = self.space_bag.weight + temp_space_bag.bias[:] = self.space_bag.bias[:] + self.space_bag = temp_space_bag + out_head = paddle.nn.Conv2DTranspose( + in_channels=self.debed.embed_dim // 4, + out_channels=self.debed.out_chans + expansion_amount, + kernel_size=4, + stride=4, + ) + temp_out_kernel = out_head.weight + temp_out_bias = out_head.bias + temp_out_kernel[:, : self.debed.out_chans, :, :] = self.debed.out_kernel + temp_out_bias[: self.debed.out_chans] = self.debed.out_bias + self.debed.out_kernel = temp_out_kernel + self.debed.out_bias = temp_out_bias + + def freeze_middle(self): + for param in self.parameters(): + param.stop_gradient = not False + for param in self.space_bag.parameters(): + param.stop_gradient = not True + self.debed.out_kernel.stop_gradient = not True + self.debed.out_bias.stop_gradient = not True + + def freeze_processor(self): + for param in self.parameters(): + 
param.stop_gradient = not False + for param in self.space_bag.parameters(): + param.stop_gradient = not True + for param in self.debed.parameters(): + param.stop_gradient = not True + for param in self.embed.parameters(): + param.stop_gradient = not True + + def unfreeze(self): + for param in self.parameters(): + param.stop_gradient = not True + + def forward(self, x, state_labels, bcs): + T, B, C = x.shape[:3] + with paddle.no_grad(): + """data_std, data_mean = torch.std_mean(x, dim=(0, -2, -1), keepdims=True) + paddle has no std_mean api""" + data_std = paddle.std(x=x, axis=(0, -2, -1), keepdim=True) + data_mean = paddle.mean(x=x, axis=(0, -2, -1), keepdim=True) + data_std = data_std + 1e-07 + x = (x - data_mean) / data_std + x = einops.rearrange(x, "t b c h w -> t b h w c") + x = self.space_bag(x, state_labels) + x = einops.rearrange(x, "t b h w c -> (t b) c h w") + x = self.embed(x) + x = einops.rearrange(x, "(t b) c h w -> t b c h w", t=T) + for blk in self.blocks: + x = blk(x, bcs) + x = einops.rearrange(x, "t b c h w -> (t b) c h w") + x = self.debed(x, state_labels[0]) + x = einops.rearrange(x, "(t b) c h w -> t b c h w", t=T) + x = x * data_std + data_mean + return x[-1] + + +if __name__ == "__main__": + print(paddle.device.is_compiled_with_cuda()) + model = AViT().cuda() + for n, p in model.debed.named_parameters(): + print(n, p.shape) + model.expand_projections(2) + for n, p in model.debed.named_parameters(): + print(n, p.shape) + T = 10 + bs = 4 + nx = 128 + ny = 128 + x = paddle.randn(T, bs, 2, nx, ny).cuda() + print("xshape", x.shape) + labels = [0, 1] + y = model(x, labels) + print("yshape", y.shape) diff --git a/ppcfd/models/multiple_physics_pretraining/data_utils/__init__.py b/ppcfd/models/multiple_physics_pretraining/data_utils/__init__.py new file mode 100644 index 0000000..4d7a209 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/data_utils/__init__.py @@ -0,0 +1 @@ +"""Dataset utilities for the Paddle implementation.""" diff --git 
a/ppcfd/models/multiple_physics_pretraining/data_utils/datasets.py b/ppcfd/models/multiple_physics_pretraining/data_utils/datasets.py new file mode 100644 index 0000000..df05d6d --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/data_utils/datasets.py @@ -0,0 +1,168 @@ +import paddle + +""" +Remember to parameterize the file paths eventually +""" +import numpy as np + +from .hdf5_datasets import * +from .mixed_dset_sampler import MultisetSampler + +import glob + +broken_paths = [] +DSET_NAME_TO_OBJECT = { + "swe": SWEDataset, + "incompNS": IncompNSDataset, + "diffre2d": DiffRe2DDataset, + "compNS": CompNSDataset, +} + + +def get_data_loader(params, paths, distributed, split="train", rank=0, train_offset=0): + dataset = MixedDataset( + paths, + n_steps=params.n_steps, + train_val_test=params.train_val_test, + split=split, + tie_fields=params.tie_fields, + use_all_fields=params.use_all_fields, + enforce_max_steps=params.enforce_max_steps, + train_offset=train_offset, + ) + sampler = MultisetSampler( + dataset, + params.batch_size, + shuffle=(split == "train"), + distributed=distributed, + max_samples=params.epoch_size, + rank=rank, + world_size=paddle.distributed.get_world_size() if distributed else 1, + ) + batch_sampler = paddle.io.BatchSampler( + sampler=sampler, + batch_size=int(params.batch_size), + drop_last=True, + ) + dataloader = paddle.io.DataLoader( + dataset=dataset, + batch_sampler=batch_sampler, + num_workers=params.num_data_workers, + return_list=True, + ) + return dataloader, dataset, sampler + + +class MixedDataset(paddle.io.Dataset): + def __init__( + self, + path_list=[], + n_steps=1, + dt=1, + train_val_test=(0.8, 0.1, 0.1), + split="train", + tie_fields=True, + use_all_fields=True, + extended_names=False, + enforce_max_steps=False, + train_offset=0, + ): + super().__init__() + self.train_offset = train_offset + self.path_list, self.type_list, self.include_string = zip(*path_list) + self.tie_fields = tie_fields + self.extended_names = 
extended_names + self.split = split + self.sub_dsets = [] + self.offsets = [0] + self.train_val_test = train_val_test + self.use_all_fields = use_all_fields + for dset, path, include_string in zip( + self.type_list, self.path_list, self.include_string + ): + subdset = DSET_NAME_TO_OBJECT[dset]( + path, + include_string, + n_steps=n_steps, + dt=dt, + train_val_test=train_val_test, + split=split, + ) + try: + len(subdset) + except ValueError: + raise ValueError( + f"Dataset {path} is empty. Check that n_steps < trajectory_length in file." + ) + self.sub_dsets.append(subdset) + self.offsets.append(self.offsets[-1] + len(self.sub_dsets[-1])) + self.offsets[0] = -1 + self.subset_dict = self._build_subset_dict() + + def get_state_names(self): + name_list = [] + if self.use_all_fields: + for name, dset in DSET_NAME_TO_OBJECT.items(): + field_names = dset._specifics()[2] + name_list += field_names + return name_list + else: + visited = set() + for dset in self.sub_dsets: + name = dset.get_name() + if not name in visited: + visited.add(name) + name_list.append(dset.field_names) + return [f for fl in name_list for f in fl] + + def _build_subset_dict(self): + if self.tie_fields: + subset_dict = { + "swe": [3], + "incompNS": [0, 1, 2], + "compNS": [0, 1, 2, 3], + "diffre2d": [4, 5], + } + elif self.use_all_fields: + cur_max = 0 + subset_dict = {} + for name, dset in DSET_NAME_TO_OBJECT.items(): + field_names = dset._specifics()[2] + subset_dict[name] = list(range(cur_max, cur_max + len(field_names))) + cur_max += len(field_names) + else: + subset_dict = {} + cur_max = self.train_offset + for dset in self.sub_dsets: + name = dset.get_name(self.extended_names) + if not name in subset_dict: + subset_dict[name] = list( + range(cur_max, cur_max + len(dset.field_names)) + ) + cur_max += len(dset.field_names) + return subset_dict + + def __getitem__(self, index): + file_idx = np.searchsorted(self.offsets, index, side="right") - 1 + local_idx = index - max(self.offsets[file_idx], 0) + 
try: + x, bcs, y = self.sub_dsets[file_idx][local_idx] + except Exception as err: + current_rank = ( + int(paddle.distributed.get_rank()) + if paddle.distributed.is_initialized() + else 0 + ) + raise RuntimeError( + f"FAILED AT file_idx={file_idx} local_idx={local_idx} index={index} rank={current_rank}" + ) from err + return ( + x, + file_idx, + paddle.to_tensor(self.subset_dict[self.sub_dsets[file_idx].get_name()]), + bcs, + y, + ) + + def __len__(self): + return sum([len(dset) for dset in self.sub_dsets]) diff --git a/ppcfd/models/multiple_physics_pretraining/data_utils/hdf5_datasets.py b/ppcfd/models/multiple_physics_pretraining/data_utils/hdf5_datasets.py new file mode 100644 index 0000000..a393cf2 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/data_utils/hdf5_datasets.py @@ -0,0 +1,472 @@ +import os + +import paddle + +""" +Remember to parameterize the file paths eventually +""" +import glob + +import h5py +import numpy as np + +broken_paths = [""] + + +class BaseHDF5DirectoryDataset(paddle.io.Dataset): + """ + Base class for data loaders. Returns data in T x B x C x H x W format. + + Note - doesn't currently normalize because the data is on wildly different + scales but probably should. + + Split is provided so I can be lazy and not separate out HDF5 files. + + Takes in path to directory of HDF5 files to construct dset. 
+ + Args: + path (str): Path to directory of HDF5 files + include_string (str): Only include files with this string in name + n_steps (int): Number of steps to include in each sample + dt (int): Time step between samples + split (str): train/val/test split + train_val_test (tuple): Percent of data to use for train/val/test + subname (str): Name to use for dataset + split_level (str): 'sample' or 'file' - whether to split by samples within a file + (useful for data segmented by parameters) or file (mostly INS right now) + """ + + def __init__( + self, + path, + include_string="", + n_steps=1, + dt=1, + split="train", + train_val_test=None, + subname=None, + extra_specific=False, + ): + super().__init__() + self.path = path + self.split = split + self.extra_specific = extra_specific + if subname is None: + self.subname = path.split("/")[-1] + else: + self.subname = subname + self.dt = 1 + self.n_steps = n_steps + self.include_string = include_string + self.train_val_test = train_val_test + self.partition = {"train": 0, "val": 1, "test": 2}[split] + ( + self.time_index, + self.sample_index, + self.field_names, + self.type, + self.split_level, + ) = self._specifics() + self._get_directory_stats(path) + if self.extra_specific: + self.title = self.more_specific_title(self.type, path, include_string) + else: + self.title = self.type + + def get_name(self, full_name=False): + if full_name: + return self.subname + "_" + self.type + else: + return self.type + + def more_specific_title(self, type, path, include_string): + """ + Override this to add more info to the dataset name + """ + return type + + @staticmethod + def _specifics(): + raise NotImplementedError + + def get_per_file_dsets(self): + if self.split_level == "file" or len(self.files_paths) == 1: + return [self] + else: + sub_dsets = [] + for file in self.files_paths: + subd = self.__class__( + self.path, + file, + n_steps=self.n_steps, + dt=self.dt, + split=self.split, + train_val_test=self.train_val_test, + 
subname=self.subname, + extra_specific=True, + ) + sub_dsets.append(subd) + return sub_dsets + + def _get_specific_stats(self, f): + raise NotImplementedError + + def _get_specific_bcs(self, f): + raise NotImplementedError + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + raise NotImplementedError + + def _get_directory_stats(self, path): + self.files_paths = glob.glob(path + "/*.h5") + glob.glob(path + "/*.hdf5") + self.files_paths.sort() + self.n_files = len(self.files_paths) + self.file_steps = [] + self.file_nsteps = [] + self.file_samples = [] + self.split_offsets = [] + self.offsets = [0] + file_paths = [] + for file in self.files_paths: + if len(self.include_string) > 0 and self.include_string not in file: + continue + elif file in broken_paths: + continue + else: + file_paths.append(file) + try: + with h5py.File(file, "r") as _f: + samples, steps = self._get_specific_stats(_f) + if steps - self.n_steps - (self.dt - 1) < 1: + print( + "WARNING: File {} has {} steps, but n_steps is {}. Setting file steps = max allowable.".format( + file, steps, self.n_steps + ) + ) + file_nsteps = steps - self.dt + else: + file_nsteps = self.n_steps + self.file_nsteps.append(file_nsteps) + self.file_steps.append(steps - file_nsteps - (self.dt - 1)) + if self.split_level == "sample": + partition = self.partition + sample_per_part = np.ceil( + np.array(self.train_val_test) * samples + ).astype(int) + sample_per_part[2] = max( + samples - sample_per_part[0] - sample_per_part[1], 0 + ) + self.split_offsets.append( + self.file_steps[-1] * sum(sample_per_part[:partition]) + ) + split_samples = sample_per_part[partition] + else: + split_samples = samples + self.file_samples.append(split_samples) + self.offsets.append( + self.offsets[-1] + + (steps - file_nsteps - (self.dt - 1)) * split_samples + ) + except: + print( + "WARNING: Failed to open file {}. 
Continuing without it.".format( + file + ) + ) + raise RuntimeError("Failed to open file {}".format(file)) + self.files_paths = file_paths + self.offsets[0] = -1 + self.files = [None for _ in self.files_paths] + self.len = self.offsets[-1] + if self.split_level == "file": + if self.train_val_test is None: + print( + "WARNING: No train/val/test split specified. Using all data for training." + ) + self.split_offset = 0 + self.len = self.offsets[-1] + else: + print("Using train/val/test split: {}".format(self.train_val_test)) + total_samples = sum(self.file_samples) + ideal_split_offsets = [ + int(self.train_val_test[i] * total_samples) for i in range(3) + ] + end_ind = 0 + for i in range(self.partition + 1): + run_sum = 0 + start_ind = end_ind + for samples, steps in zip(self.file_samples, self.file_steps): + run_sum += samples + if run_sum <= ideal_split_offsets[i]: + end_ind += samples * steps + if run_sum == ideal_split_offsets[i]: + break + else: + end_ind += ( + np.abs(run_sum - samples - ideal_split_offsets[i]) + * steps + ) + break + self.split_offset = start_ind + self.len = end_ind - start_ind + + def _open_file(self, file_ind): + _file = h5py.File(self.files_paths[file_ind], "r") + self.files[file_ind] = _file + + def __getitem__(self, index): + if self.split_level == "file": + index = index + self.split_offset + file_idx = int(np.searchsorted(self.offsets, index, side="right") - 1) + nsteps = self.file_nsteps[file_idx] + local_idx = index - max(self.offsets[file_idx], 0) + if self.split_level == "sample": + sample_idx = (local_idx + self.split_offsets[file_idx]) // self.file_steps[ + file_idx + ] + else: + sample_idx = local_idx // self.file_steps[file_idx] + time_idx = local_idx % self.file_steps[file_idx] + if self.files[file_idx] is None: + self._open_file(file_idx) + time_idx = ( + time_idx - self.dt if time_idx >= self.file_steps[file_idx] else time_idx + ) + time_idx += nsteps + try: + trajectory = self._reconstruct_sample( + self.files[file_idx], 
sample_idx, time_idx, nsteps + ) + bcs = self._get_specific_bcs(self.files[file_idx]) + except: + raise RuntimeError( + f"Failed to reconstruct sample for file {self.files_paths[file_idx]} sample {sample_idx} time {time_idx}" + ) + return trajectory[:-1], paddle.to_tensor(bcs), trajectory[-1] + + def __len__(self): + return self.len + + +class SWEDataset(BaseHDF5DirectoryDataset): + @staticmethod + def _specifics(): + time_index = 0 + sample_index = None + field_names = ["h"] + type = "swe" + split_level = "sample" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = list(f.keys()) + steps = f[samples[0]]["data"].shape[0] + return len(samples), steps + + def _get_specific_bcs(self, f): + return [0, 0] + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + samples = list(file.keys()) + return file[samples[sample_idx]]["data"][ + time_idx - n_steps * self.dt : time_idx + self.dt + ].transpose(0, 3, 1, 2) + + +class DiffRe2DDataset(BaseHDF5DirectoryDataset): + @staticmethod + def _specifics(): + time_index = 0 + sample_index = None + field_names = ["activator", "inhibitor"] + type = "diffre2d" + split_level = "sample" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = list(f.keys()) + steps = f[samples[0]]["data"].shape[0] + return len(samples), steps + + def _get_specific_bcs(self, f): + return [0, 0] + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + samples = list(file.keys()) + return file[samples[sample_idx]]["data"][ + time_idx - n_steps * self.dt : time_idx + self.dt + ].transpose(0, 3, 1, 2) + + +class IncompNSDataset(BaseHDF5DirectoryDataset): + """ + Order Vx, Vy, "particles" + """ + + @staticmethod + def _specifics(): + time_index = 1 + sample_index = 0 + field_names = ["Vx", "Vy", "particles"] + type = "incompNS" + split_level = "file" + return time_index, sample_index, field_names, type, 
split_level + + def _get_specific_stats(self, f): + samples = f["velocity"].shape[0] + steps = f["velocity"].shape[1] + return samples, steps + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + velocity = file["velocity"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + particles = file["particles"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + comb = np.concatenate([velocity, particles], -1) + return comb.transpose((0, 3, 1, 2)) + + def _get_specific_bcs(self, f): + return [0, 0] + + +class PDEArenaINS(BaseHDF5DirectoryDataset): + """ + Order Vx, Vy, density, pressure + """ + + @staticmethod + def _specifics(): + time_index = 1 + sample_index = 0 + field_names = ["Vx", "Vy", "u"] + type = "pa_ins" + split_level = "sample" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = f["Vx"].shape[0] + steps = f["Vx"].shape[1] + return samples, steps + + def more_specific_title(self, type, path, include_string): + """ + Override this to add more info to the dataset name + """ + split_path = self.include_string.split("/")[-1].split("_") + buoy = split_path[-3] + nu = split_path[-2] + return f"{type}_buoy{buoy}_nu{nu}" + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + vx = file["Vx"][sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt] + vy = file["Vy"][sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt] + density = file["u"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + comb = np.stack([vx, vy, density], 1) + return comb + + def _get_specific_bcs(self, f): + return [0, 0] + + +class CompNSDataset(BaseHDF5DirectoryDataset): + """ + Order Vx, Vy, density, pressure + """ + + @staticmethod + def _specifics(): + time_index = 1 + sample_index = 0 + field_names = ["Vx", "Vy", "density", "pressure"] + type = "compNS" + split_level = "sample" + return time_index, sample_index, 
field_names, type, split_level + + def _get_specific_stats(self, f): + samples = f["Vx"].shape[0] + steps = f["Vx"].shape[1] + return samples, steps + + def more_specific_title(self, type, path, include_string): + """ + Override this to add more info to the dataset name + """ + cns_path = self.include_string.split("/")[-1].split("_") + ic = cns_path[2] + m = cns_path[3] + res = cns_path[-2] + return f"{type}_{ic}_{m}_res{res}" + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + vx = file["Vx"][sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt] + vy = file["Vy"][sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt] + density = file["density"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + p = file["pressure"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + comb = np.stack([vx, vy, density, p], 1) + return comb + + def _get_specific_bcs(self, f): + return [1, 1] + + +class BurgersDataset(BaseHDF5DirectoryDataset): + """ + Order Vx, Vy, density, pressure + """ + + @staticmethod + def _specifics(): + time_index = 1 + sample_index = 0 + field_names = ["Vx"] + type = "burgers" + split_level = "sample" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = f["tensor"].shape[0] + steps = f["tensor"].shape[1] + return samples, steps + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + vx = file["tensor"][ + sample_idx, time_idx - n_steps * self.dt : time_idx + self.dt + ] + vx = vx[:, None, :, None] + return vx + + def _get_specific_bcs(self, f): + return [1, 1] + + +class DiffSorb1DDataset(BaseHDF5DirectoryDataset): + @staticmethod + def _specifics(): + time_index = 0 + sample_index = None + field_names = ["u"] + type = "diffsorb" + split_level = "sample" + return time_index, sample_index, field_names, type, split_level + + def _get_specific_stats(self, f): + samples = list(f.keys()) + steps = 
f[samples[0]]["data"].shape[0] + return len(samples), steps + + def _get_specific_bcs(self, f): + return [0, 0] + + def _reconstruct_sample(self, file, sample_idx, time_idx, n_steps): + samples = list(file.keys()) + return file[samples[sample_idx]]["data"][ + time_idx - n_steps * self.dt : time_idx + self.dt + ].transpose(0, 2, 1)[:, :, :, None] diff --git a/ppcfd/models/multiple_physics_pretraining/data_utils/mixed_dset_sampler.py b/ppcfd/models/multiple_physics_pretraining/data_utils/mixed_dset_sampler.py new file mode 100644 index 0000000..41f13a7 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/data_utils/mixed_dset_sampler.py @@ -0,0 +1,97 @@ +from typing import Iterator + +import numpy as np +import paddle + +__all__ = ["MultisetSampler"] + + +class MultisetSampler(paddle.io.Sampler): + """Sampler that restricts data loading to a subset of the dataset.""" + + def __init__( + self, + dataset: paddle.io.Dataset, + batch_size: int, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = True, + max_samples=10, + rank=0, + world_size=1, + distributed=True, + ) -> None: + self.batch_size = batch_size + self.sub_dsets = dataset.sub_dsets + self.dataset = dataset + self.epoch = 0 + self.drop_last = drop_last + self.shuffle = shuffle + self.seed = seed + self.max_samples = max_samples + self.rank = rank + self.world_size = world_size + self.distributed = distributed + + def _build_subdataset_indices(self): + """Build per-rank indices for each sub-dataset.""" + subdataset_indices = [] + for subdataset_idx, subdataset in enumerate(self.sub_dsets): + indices = np.arange(len(subdataset), dtype=np.int64) + if self.shuffle: + rng = np.random.default_rng( + 1000 * self.epoch + 100 * self.seed + 10 * self.rank + subdataset_idx + ) + rng.shuffle(indices) + if self.distributed and self.world_size > 1: + indices = indices[self.rank :: self.world_size] + usable_size = (len(indices) // self.batch_size) * self.batch_size + 
subdataset_indices.append(indices[:usable_size].tolist()) + return subdataset_indices + + def __iter__(self) -> Iterator[int]: + subdataset_indices = self._build_subdataset_indices() + samplers = [iter(indices) for indices in subdataset_indices] + sampler_choices = [ + idx for idx, indices in enumerate(subdataset_indices) if len(indices) > 0 + ] + rng = np.random.default_rng(100 * self.epoch + 10 * self.seed + self.rank) + count = 0 + while len(sampler_choices) > 0: + count += 1 + index_sampled = int(rng.integers(low=0, high=len(sampler_choices))) + dset_sampled = sampler_choices[index_sampled] + offset = max(0, self.dataset.offsets[dset_sampled]) + try: + queue = [] + for i in range(self.batch_size): + queue.append(next(samplers[dset_sampled]) + offset) + if len(queue) == self.batch_size: + for d in queue: + yield d + except Exception as err: + print("ERRRR", err) + sampler_choices.pop(index_sampled) + print( + f"Note: dset {dset_sampled} fully used. Dsets remaining: {len(sampler_choices)}" + ) + continue + if count >= self.max_samples: + break + + def __len__(self) -> int: + available_batches = sum( + len(indices) // self.batch_size for indices in self._build_subdataset_indices() + ) + return min(self.max_samples, available_batches) * self.batch_size + + def set_epoch(self, epoch: int) -> None: + """ + Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. 
+ """ + self.epoch = epoch diff --git a/ppcfd/models/multiple_physics_pretraining/mixed_modules.py b/ppcfd/models/multiple_physics_pretraining/mixed_modules.py new file mode 100644 index 0000000..e137406 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/mixed_modules.py @@ -0,0 +1,75 @@ +import math +from functools import partial + +import einops +import numpy as np +import paddle + +from .spatial_modules import build_space_block +from .time_modules import AttentionBlock, build_time_block + + +def build_spacetime_block(params): + """ + Builds a spacetime block from the parameter file. + """ + if params.block_type == "axial": + space_block = build_space_block(params) + time_block = build_time_block(params) + return partial( + SpaceTimeBlock, + params.embed_dim, + params.num_heads, + space_override=space_block, + time_override=time_block, + gradient_checkpointing=params.gradient_checkpointing, + ) + else: + raise NotImplementedError + + +class SpaceTimeBlock(paddle.nn.Layer): + """ + Alternates spatial and temporal processing. Current code base uses + 1D attention over each axis. Spatial axes share weights. + + Note: MLP is in spatial block. 
+ """ + + def __init__( + self, + hidden_dim=768, + num_heads=12, + drop_path=0.0, + space_override=None, + time_override=None, + gradient_checkpointing=False, + ): + super().__init__() + self.gradient_checkpointing = gradient_checkpointing + if space_override is not None: + self.spatial = space_override(drop_path=drop_path) + if time_override is not None: + self.temporal = time_override(drop_path=drop_path) + else: + self.temporal = AttentionBlock(hidden_dim, num_heads, drop_path=drop_path) + + def forward(self, x, bcs): + T, B, C, H, W = x.shape + if self.gradient_checkpointing: + wrapped_temporal = partial(self.temporal) + x = paddle.distributed.fleet.utils.recompute( + wrapped_temporal, x, use_reentrant=False + ) + else: + x = self.temporal(x) + x = einops.rearrange(x, "t b c h w -> (t b) c h w") + if self.gradient_checkpointing: + wrapped_spatial = partial(self.spatial) + x = paddle.distributed.fleet.utils.recompute( + wrapped_spatial, x, bcs, use_reentrant=False + ) + else: + x = self.spatial(x, bcs) + x = einops.rearrange(x, "(t b) c h w -> t b c h w", t=T) + return x diff --git a/ppcfd/models/multiple_physics_pretraining/paddle_utils.py b/ppcfd/models/multiple_physics_pretraining/paddle_utils.py new file mode 100644 index 0000000..d97534e --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/paddle_utils.py @@ -0,0 +1,20 @@ + +import paddle + +############################## 相关utils函数,如下 ############################## +############################ PaConvert 自动生成的代码 ########################### + +def device2int(device): + if isinstance(device, str): + device = device.replace('cuda', 'gpu') + device = device.replace('gpu:', '') + return int(device) + +class PaddleFlag: + cudnn_enabled = True + cudnn_benchmark = False + matmul_allow_tf32 = False + cudnn_allow_tf32 = True + cudnn_deterministic = False +############################## 相关utils函数,如上 ############################## + diff --git a/ppcfd/models/multiple_physics_pretraining/shared_modules.py 
b/ppcfd/models/multiple_physics_pretraining/shared_modules.py new file mode 100644 index 0000000..45ccd98 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/shared_modules.py @@ -0,0 +1,175 @@ +import math +from functools import partial + +import einops +import numpy as np +import paddle + + +class ContinuousPositionBias1D(paddle.nn.Layer): + def __init__(self, n_heads): + super().__init__() + self.num_heads = n_heads + self.cpb_mlp = paddle.nn.Sequential( + paddle.nn.Linear(1, 512, bias_attr=True), + paddle.nn.ReLU(), + paddle.nn.Linear(512, n_heads, bias_attr=False), + ) + + def forward(self, h, h2, bc=0): + dtype = self.cpb_mlp[0].weight.dtype + if bc == 0: + relative_coords = paddle.arange(-(h - 1), h, dtype=dtype) / (h - 1) + elif bc == 1: + relative_coords = paddle.cat( + [ + paddle.arange(1, h // 2 + 1, dtype=dtype), + paddle.arange(-(h // 2 - 1), h // 2 + 1, dtype=dtype), + paddle.arange(-(h // 2 - 1), 0, dtype=dtype), + ] + ) / (h - 1) + coords = paddle.arange(h, dtype=paddle.float32) + coords = coords[None, :] - coords[:, None] + coords = coords + (h - 1) + rel_pos_model = 16 * paddle.sigmoid( + self.cpb_mlp(relative_coords[:, None]).squeeze() + ) + biases = rel_pos_model[coords.astype("int64")] + return biases.transpose([2, 0, 1]).unsqueeze(0) + + +class RelativePositionBias(paddle.nn.Layer): + """ + From https://gist.github.com/huchenxucs/c65524185e8e35c4bcfae4059f896c16 + + Implementation of T5 relative position bias - can probably do better, but starting with something known. 
+ """ + + def __init__(self, bidirectional=True, num_buckets=32, max_distance=128, n_heads=2): + super(RelativePositionBias, self).__init__() + self.bidirectional = bidirectional + self.num_buckets = num_buckets + self.max_distance = max_distance + self.n_heads = n_heads + self.relative_attention_bias = paddle.nn.Embedding( + self.num_buckets, self.n_heads + ) + + @staticmethod + def _relative_position_bucket( + relative_position, bidirectional=True, num_buckets=32, max_distance=32 + ): + """ + Adapted from Mesh Tensorflow: + https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593 + Translate relative position to a bucket number for relative attention. + The relative position is defined as memory_position - query_position, i.e. + the distance in tokens from the attending position to the attended-to + position. If bidirectional=False, then positive relative positions are + invalid. + We use smaller buckets for small absolute relative_position and larger buckets + for larger absolute relative_positions. All relative positions >=max_distance + map to the same bucket. All relative positions <=-max_distance map to the + same bucket. This should allow for more graceful generalization to longer + sequences than the model has been trained on. 
+ Args: + relative_position: an int32 Tensor + bidirectional: a boolean - whether the attention is bidirectional + num_buckets: an integer + max_distance: an integer + Returns: + a Tensor with the same shape as relative_position, containing int32 + values in the range [0, num_buckets) + """ + ret = 0 + n = -relative_position + if bidirectional: + num_buckets //= 2 + ret += (n < 0).astype("int64") * num_buckets + n = paddle.abs(n) + else: + n = paddle.maximum(n, paddle.zeros_like(n)) + max_exact = num_buckets // 2 + is_small = n < max_exact + val_if_large = max_exact + ( + paddle.log(n.astype("float32") / max_exact) + / math.log(max_distance / max_exact) + * (num_buckets - max_exact) + ).astype("int64") + val_if_large = paddle.minimum( + val_if_large, paddle.full_like(val_if_large, num_buckets - 1) + ) + ret += paddle.where(is_small, n, val_if_large) + return ret + + def compute_bias(self, qlen, klen, bc=0): + """Compute binned relative position bias""" + context_position = paddle.arange(qlen, dtype="int64")[:, None] + memory_position = paddle.arange(klen, dtype="int64")[None, :] + relative_position = memory_position - context_position + """ + k + 0 1 2 3 + q -1 0 1 2 + -2 -1 0 1 + -3 -2 -1 0 + """ + if bc == 1: + thresh = klen // 2 + relative_position[relative_position < -thresh] = ( + relative_position[relative_position < -thresh] % thresh + ) + relative_position[relative_position > thresh] = ( + relative_position[relative_position > thresh] % -thresh + ) + rp_bucket = self._relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + ) + values = self.relative_attention_bias(rp_bucket) + values = values.transpose([2, 0, 1]).unsqueeze(0) + return values + + def forward(self, qlen, klen, bc=0): + return self.compute_bias(qlen, klen, bc) + + +class MLP(paddle.nn.Layer): + def __init__(self, hidden_dim, exp_factor=4.0): + super().__init__() + self.fc1 = paddle.nn.Linear(hidden_dim, int(hidden_dim * exp_factor)) 
+ self.fc2 = paddle.nn.Linear(int(hidden_dim * exp_factor), hidden_dim) + self.act = paddle.nn.GELU() + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class AbsolutePositionBias(paddle.nn.Layer): + """ + From https://gist.github.com/huchenxucs/c65524185e8e35c4bcfae4059f896c16 + + Implementation of T5 relative position bias - can probably do better, but starting with something known. + """ + + def __init__(self, hidden_dim, n_tokens): + super(AbsolutePositionBias, self).__init__() + self.bias = self.create_parameter( + shape=[1, n_tokens, hidden_dim], + default_initializer=paddle.nn.initializer.Normal(std=0.02), + ) + + def forward(self): + return self.bias + + +class MLP(paddle.nn.Layer): + def __init__(self, hidden_dim, exp_factor=4.0): + super().__init__() + self.fc1 = paddle.nn.Linear(hidden_dim, int(hidden_dim * exp_factor)) + self.fc2 = paddle.nn.Linear(int(hidden_dim * exp_factor), hidden_dim) + self.act = paddle.nn.GELU() + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) diff --git a/ppcfd/models/multiple_physics_pretraining/spatial_modules.py b/ppcfd/models/multiple_physics_pretraining/spatial_modules.py new file mode 100644 index 0000000..43c3fb9 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/spatial_modules.py @@ -0,0 +1,270 @@ +import math +from functools import partial + +import einops +import numpy as np +import paddle + +from .DropPath_util import DropPath +from .shared_modules import (MLP, ContinuousPositionBias1D, + RelativePositionBias) + + +def scaled_dot_product_attention(q, k, v, attn_mask=None): + scale = q.shape[-1] ** -0.5 + scores = paddle.matmul(q * scale, k, transpose_y=True) + if attn_mask is not None: + scores = scores + attn_mask + attn = paddle.nn.functional.softmax(scores, axis=-1) + return paddle.matmul(attn, v) + + +def build_space_block(params): + if params.space_type == "axial_attention": + return partial( + AxialAttentionBlock, + params.embed_dim, + params.num_heads, + 
            bias_type=params.bias_type,
        )
    else:
        raise NotImplementedError


class RMSInstanceNorm2d(paddle.nn.Layer):
    # Instance norm variant that rescales by the per-(sample, channel)
    # spatial std only — the mean is NOT subtracted — then applies a
    # learnable per-channel weight.
    def __init__(self, dim, affine=True, eps=1e-08):
        super().__init__()
        self.eps = eps
        self.affine = affine
        if affine:
            self.weight = self.create_parameter(
                shape=[dim],
                default_initializer=paddle.nn.initializer.Constant(value=1.0),
            )
            # NOTE(review): `bias` is created but never applied in forward();
            # confirm against the upstream MPP implementation before using it.
            self.bias = self.create_parameter(
                shape=[dim],
                default_initializer=paddle.nn.initializer.Constant(value=0.0),
            )

    def forward(self, x):
        """std, mean = torch.std_mean(x, dim=(-2, -1), keepdims=True)
        paddle has no std_mean"""
        # Divide by spatial std per (sample, channel); no mean subtraction.
        std = paddle.std(x=x, axis=(-2, -1), keepdim=True)
        x = x / (std + self.eps)
        if self.affine:
            x = x * self.weight[None, :, None, None]
        return x


class SubsampledLinear(paddle.nn.Layer):
    """
    Cross between a linear layer and EmbeddingBag - takes in input
    and list of indices denoting which state variables from the state
    vocab are present and only performs the linear layer on rows/cols relevant
    to those state variables

    Assumes (... C) input
    """

    def __init__(self, dim_in, dim_out, subsample_in=True):
        super().__init__()
        self.subsample_in = subsample_in
        self.dim_in = dim_in
        self.dim_out = dim_out
        # Instantiate a Linear only to reuse its default initialization,
        # then keep the raw weight/bias so forward() can index them.
        temp_linear = paddle.nn.Linear(dim_in, dim_out)
        self.weight = temp_linear.weight
        self.bias = temp_linear.bias

    def forward(self, x, labels):
        # labels: per-sample index lists; only labels[0] is used, i.e. every
        # sample in the batch is assumed to share the same variable set.
        labels = labels[0]
        label_size = len(labels)
        if self.subsample_in:
            # Rescale so activation magnitude is independent of how many
            # input variables are present.
            scale = (self.dim_in / label_size) ** 0.5
            # paddle.nn.Linear stores weight as (in_features, out_features),
            # so selecting rows subsamples the INPUT dimension.
            x = scale * paddle.nn.functional.linear(
                x, self.weight[labels, :], self.bias
            )
        else:
            # Subsample the OUTPUT dimension: weight columns and bias entries.
            x = paddle.nn.functional.linear(
                x, self.weight[:, labels], self.bias[labels]
            )
        return x


class hMLP_stem(paddle.nn.Layer):
    """Image to Patch Embedding"""
    # Hierarchical stem: 4x4, 2x2, 2x2 strided convs compose an overall
    # 16x16 patchification with normalization + GELU between stages.

    def __init__(self, patch_size=(16, 16), in_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.in_proj = paddle.nn.Sequential(
            *[
                paddle.nn.Conv2D(
                    in_chans,
                    embed_dim // 4,
                    kernel_size=4,
                    stride=4,
                    bias_attr=False,
                ),
                RMSInstanceNorm2d(embed_dim // 4, affine=True),
                paddle.nn.GELU(),
                paddle.nn.Conv2D(
                    embed_dim // 4,
                    embed_dim // 4,
                    kernel_size=2,
                    stride=2,
                    bias_attr=False,
                ),
                RMSInstanceNorm2d(embed_dim // 4, affine=True),
                paddle.nn.GELU(),
                paddle.nn.Conv2D(
                    embed_dim // 4,
                    embed_dim,
                    kernel_size=2,
                    stride=2,
                    bias_attr=False,
                ),
                RMSInstanceNorm2d(embed_dim, affine=True),
            ]
        )

    def forward(self, x):
        x = self.in_proj(x)
        return x


class hMLP_output(paddle.nn.Layer):
    """Patch to Image De-bedding"""
    # Mirror of hMLP_stem: transposed convs decode patches back to pixels.

    def __init__(self, patch_size=(16, 16), out_chans=3, embed_dim=768):
        super().__init__()
        self.patch_size = patch_size
        self.out_chans = out_chans
        self.embed_dim = embed_dim
        self.out_proj = paddle.nn.Sequential(
            *[
                paddle.nn.Conv2DTranspose(
                    in_channels=embed_dim,
                    out_channels=embed_dim // 4,
                    kernel_size=2,
                    stride=2,
                    bias_attr=False,
                ),
                RMSInstanceNorm2d(embed_dim // 4, affine=True),
paddle.nn.GELU(), + paddle.nn.Conv2DTranspose( + in_channels=embed_dim // 4, + out_channels=embed_dim // 4, + kernel_size=2, + stride=2, + bias_attr=False, + ), + RMSInstanceNorm2d(embed_dim // 4, affine=True), + paddle.nn.GELU(), + ] + ) + out_head = paddle.nn.Conv2DTranspose( + in_channels=embed_dim // 4, out_channels=out_chans, kernel_size=4, stride=4 + ) + self.out_kernel = out_head.weight + self.out_bias = out_head.bias + + def forward(self, x, state_labels): + x = self.out_proj(x) + x = paddle.nn.functional.conv2d_transpose( + x=x, + weight=self.out_kernel[:, state_labels], + bias=self.out_bias[state_labels], + stride=4, + ) + return x + + +class AxialAttentionBlock(paddle.nn.Layer): + def __init__( + self, + hidden_dim=768, + num_heads=12, + drop_path=0, + layer_scale_init_value=1e-06, + bias_type="rel", + ): + super().__init__() + self.num_heads = num_heads + self.norm1 = RMSInstanceNorm2d(hidden_dim, affine=True) + self.norm2 = RMSInstanceNorm2d(hidden_dim, affine=True) + self.gamma_att = ( + self.create_parameter( + shape=[hidden_dim], + default_initializer=paddle.nn.initializer.Constant( + value=layer_scale_init_value + ), + ) + if layer_scale_init_value > 0 + else None + ) + self.gamma_mlp = ( + self.create_parameter( + shape=[hidden_dim], + default_initializer=paddle.nn.initializer.Constant( + value=layer_scale_init_value + ), + ) + if layer_scale_init_value > 0 + else None + ) + self.input_head = paddle.nn.Conv2D(hidden_dim, 3 * hidden_dim, 1) + self.output_head = paddle.nn.Conv2D(hidden_dim, hidden_dim, 1) + self.qnorm = paddle.nn.LayerNorm(hidden_dim // num_heads) + self.knorm = paddle.nn.LayerNorm(hidden_dim // num_heads) + if bias_type == "none": + self.rel_pos_bias = lambda x, y: None + elif bias_type == "continuous": + self.rel_pos_bias = ContinuousPositionBias1D(n_heads=num_heads) + else: + self.rel_pos_bias = RelativePositionBias(n_heads=num_heads) + """""" + self.drop_path = ( + DropPath(drop_path) if drop_path > 0.0 else paddle.nn.Identity() 
+ ) + self.mlp = MLP(hidden_dim) + self.mlp_norm = RMSInstanceNorm2d(hidden_dim, affine=True) + + def forward(self, x, bcs): + B, C, H, W = x.shape + input = x.clone() + x = self.norm1(x) + x = self.input_head(x) + x = einops.rearrange(x, "b (he c) h w -> b he h w c", he=self.num_heads) + q, k, v = paddle.split(x, num_or_sections=3, axis=-1) + q, k = self.qnorm(q), self.knorm(k) + qx, kx, vx = map( + lambda x: einops.rearrange(x, "b he h w c -> (b h) he w c"), [q, k, v] + ) + rel_pos_bias_x = self.rel_pos_bias(W, W, bcs[0, 0]) + if rel_pos_bias_x is not None: + xx = scaled_dot_product_attention(qx, kx, vx, attn_mask=rel_pos_bias_x) + else: + xx = scaled_dot_product_attention(qx, kx, vx) + xx = einops.rearrange(xx, "(b h) he w c -> b (he c) h w", h=H) + qy, ky, vy = map( + lambda x: einops.rearrange(x, "b he h w c -> (b w) he h c"), [q, k, v] + ) + rel_pos_bias_y = self.rel_pos_bias(H, H, bcs[0, 1]) + if rel_pos_bias_y is not None: + xy = scaled_dot_product_attention(qy, ky, vy, attn_mask=rel_pos_bias_y) + else: + xy = scaled_dot_product_attention(qy, ky, vy) + xy = einops.rearrange(xy, "(b w) he h c -> b (he c) h w", w=W) + x = (xx + xy) / 2 + x = self.norm2(x) + x = self.output_head(x) + x = self.drop_path(x * self.gamma_att[None, :, None, None]) + input + input = x.clone() + x = einops.rearrange(x, "b c h w -> b h w c") + x = self.mlp(x) + x = einops.rearrange(x, "b h w c -> b c h w") + x = self.mlp_norm(x) + output = input + self.drop_path(self.gamma_mlp[None, :, None, None] * x) + return output diff --git a/ppcfd/models/multiple_physics_pretraining/time_modules.py b/ppcfd/models/multiple_physics_pretraining/time_modules.py new file mode 100644 index 0000000..23effc5 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/time_modules.py @@ -0,0 +1,99 @@ +import math +from functools import partial + +import einops +import numpy as np +import paddle + +from .DropPath_util import DropPath +from .shared_modules import (MLP, ContinuousPositionBias1D, + 
                             RelativePositionBias)


def scaled_dot_product_attention(q, k, v, attn_mask=None):
    # Plain softmax attention; `attn_mask` (the relative position bias) is
    # added to the scaled logits.
    scale = q.shape[-1] ** -0.5
    scores = paddle.matmul(q * scale, k, transpose_y=True)
    if attn_mask is not None:
        scores = scores + attn_mask
    attn = paddle.nn.functional.softmax(scores, axis=-1)
    return paddle.matmul(attn, v)


def build_time_block(params):
    """
    Builds a time block from the parameter file.
    """
    if params.time_type == "attention":
        return partial(
            AttentionBlock,
            params.embed_dim,
            params.num_heads,
            bias_type=params.bias_type,
        )
    else:
        raise NotImplementedError


class AttentionBlock(paddle.nn.Layer):
    # Temporal self-attention over the T axis, applied independently at each
    # spatial location (pixels become part of the batch dimension).
    def __init__(
        self,
        hidden_dim=768,
        num_heads=12,
        drop_path=0,
        layer_scale_init_value=1e-06,
        bias_type="rel",
    ):
        super().__init__()
        self.num_heads = num_heads
        self.norm1 = paddle.nn.InstanceNorm2D(
            num_features=hidden_dim, weight_attr=True, bias_attr=True
        )
        self.norm2 = paddle.nn.InstanceNorm2D(
            num_features=hidden_dim, weight_attr=True, bias_attr=True
        )
        # Layer-scale parameter. NOTE(review): forward() multiplies by
        # self.gamma unconditionally, so layer_scale_init_value <= 0 (gamma
        # None) would crash — confirm configs always use a positive value.
        self.gamma = (
            self.create_parameter(
                shape=[hidden_dim],
                default_initializer=paddle.nn.initializer.Constant(
                    value=layer_scale_init_value
                ),
            )
            if layer_scale_init_value > 0
            else None
        )
        self.input_head = paddle.nn.Conv2D(hidden_dim, 3 * hidden_dim, 1)
        self.output_head = paddle.nn.Conv2D(hidden_dim, hidden_dim, 1)
        self.qnorm = paddle.nn.LayerNorm(hidden_dim // num_heads)
        self.knorm = paddle.nn.LayerNorm(hidden_dim // num_heads)
        # Unlike the spatial block, forward() calls rel_pos_bias with two
        # arguments, so the two-parameter lambda is consistent here.
        if bias_type == "none":
            self.rel_pos_bias = lambda x, y: None
        elif bias_type == "continuous":
            self.rel_pos_bias = ContinuousPositionBias1D(n_heads=num_heads)
        else:
            self.rel_pos_bias = RelativePositionBias(n_heads=num_heads)
        self.drop_path = (
            DropPath(drop_path) if drop_path > 0.0 else paddle.nn.Identity()
        )

    def forward(self, x):
        # x: (T, B, C, H, W); attention mixes information along T only.
        T, B, C, H, W = x.shape
        input = x.clone()
        x = einops.rearrange(x, "t b c h w -> (t b) c h w")
        x = self.norm1(x)
        x = self.input_head(x)
        # (B*H*W) independent sequences of length T, split per head.
        x = einops.rearrange(
            x, "(t b) (he c) h w -> (b h w) he t c", t=T, he=self.num_heads
        )
        q, k, v = paddle.split(x, num_or_sections=3, axis=-1)
        q, k = self.qnorm(q), self.knorm(k)
        rel_pos_bias = self.rel_pos_bias(T, T)
        if rel_pos_bias is not None:
            x = scaled_dot_product_attention(q, k, v, attn_mask=rel_pos_bias)
        else:
            x = scaled_dot_product_attention(q, k, v)
        x = einops.rearrange(x, "(b h w) he t c -> (t b) (he c) h w", h=H, w=W)
        x = self.norm2(x)
        x = self.output_head(x)
        x = einops.rearrange(x, "(t b) c h w -> t b c h w", t=T)
        # Layer-scaled residual with stochastic depth.
        output = self.drop_path(x * self.gamma[None, None, :, None, None]) + input
        return output


# ==== ppcfd/models/multiple_physics_pretraining/utils/YParams.py ====
import logging

from ruamel.yaml import YAML


class YParams:
    """Yaml file parser"""
    # Loads one named section from a YAML file and exposes each key both as
    # an attribute and via dict-style access.

    def __init__(self, yaml_filename, config_name, print_params=False):
        self._yaml_filename = yaml_filename
        self._config_name = config_name
        self.params = {}
        if print_params:
            print("------------------ Configuration ------------------")
        with open(yaml_filename) as _file:
            for key, val in YAML().load(_file)[config_name].items():
                if print_params:
                    print(key, val)
                # The literal string "None" in YAML is mapped to Python None.
                if val == "None":
                    val = None
                self.params[key] = val
                self.__setattr__(key, val)
        if print_params:
            print("---------------------------------------------------")

    def __getitem__(self, key):
        return self.params[key]

    def __setitem__(self, key, val):
        # Keep the dict and attribute views in sync.
        self.params[key] = val
        self.__setattr__(key, val)

    def __contains__(self, key):
        return key in self.params

    def update_params(self, config):
        for key, val in config.items():
            self.params[key] = val
            self.__setattr__(key, val)

    def log(self):
        logging.info("------------------ Configuration ------------------")
logging.info("Configuration file: " + str(self._yaml_filename)) + logging.info("Configuration name: " + str(self._config_name)) + for key, val in self.params.items(): + logging.info(str(key) + " " + str(val)) + logging.info("---------------------------------------------------") diff --git a/ppcfd/models/multiple_physics_pretraining/utils/__init__.py b/ppcfd/models/multiple_physics_pretraining/utils/__init__.py new file mode 100644 index 0000000..db4ed6f --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/utils/__init__.py @@ -0,0 +1 @@ +"""Shared utilities for the Paddle implementation.""" diff --git a/ppcfd/models/multiple_physics_pretraining/utils/adan_paddle.py b/ppcfd/models/multiple_physics_pretraining/utils/adan_paddle.py new file mode 100644 index 0000000..f2e6c72 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/utils/adan_paddle.py @@ -0,0 +1,95 @@ +import math + +import paddle + +from .custom_optimizer_base import ( + TorchStylePaddleOptimizer, + divide_inplace, + scale_inplace, +) + + +def exists(val): + return val is not None + + +class Adan(TorchStylePaddleOptimizer): + def __init__( + self, + params, + lr=0.001, + betas=(0.02, 0.08, 0.01), + eps=1e-08, + weight_decay=0, + restart_cond: callable = None, + ): + assert len(betas) == 3 + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + restart_cond=restart_cond, + ) + super().__init__(params, lr, defaults) + + def step(self, closure=None): + self._sync_group_lr() + loss = None + if exists(closure): + loss = closure() + for group in self.param_groups: + lr = group["lr"] + beta1, beta2, beta3 = group["betas"] + weight_decay = group["weight_decay"] + eps = group["eps"] + restart_cond = group["restart_cond"] + for p in group["params"]: + if not exists(p.grad): + continue + data, grad = p.data, p.grad.data + assert not grad.is_sparse() + state = self.state[p] + if len(state) == 0: + state["step"] = 0 + state["prev_grad"] = paddle.zeros_like(grad) + 
state["m"] = paddle.zeros_like(grad) + state["v"] = paddle.zeros_like(grad) + state["n"] = paddle.zeros_like(grad) + step, m, v, n, prev_grad = ( + state["step"], + state["m"], + state["v"], + state["n"], + state["prev_grad"], + ) + if step > 0: + prev_grad = state["prev_grad"] + scale_inplace(m, 1 - beta1).add_(grad, alpha=beta1) + grad_diff = grad - prev_grad + scale_inplace(v, 1 - beta2).add_(grad_diff, alpha=beta2) + next_n = (grad + (1 - beta2) * grad_diff) ** 2 + scale_inplace(n, 1 - beta3).add_(next_n, alpha=beta3) + step += 1 + correct_m, correct_v, correct_n = map( + lambda n: 1 / (1 - (1 - n) ** step), (beta1, beta2, beta3) + ) + + def grad_step_(data, m, v, n): + weighted_step_size = lr / ((n * correct_n).sqrt() + eps) + denom = 1 + weight_decay * lr + update = weighted_step_size * ( + m * correct_m + (1 - beta2) * v * correct_v + ) + data.add_(update, alpha=-1.0) + divide_inplace(data, denom) + + grad_step_(data, m, v, n) + if exists(restart_cond) and restart_cond(state): + m.data.copy_(grad) + v.zero_() + n.data.copy_(grad**2) + grad_step_(data, m, v, n) + prev_grad.copy_(grad) + state["step"] = step + return loss diff --git a/ppcfd/models/multiple_physics_pretraining/utils/custom_optimizer_base.py b/ppcfd/models/multiple_physics_pretraining/utils/custom_optimizer_base.py new file mode 100644 index 0000000..ca0d690 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/utils/custom_optimizer_base.py @@ -0,0 +1,131 @@ +from collections import defaultdict + +import paddle + + +def assign_inplace(tensor, value): + tensor.copy_(value) + return tensor + + +def scale_inplace(tensor, scalar): + return assign_inplace(tensor, tensor * scalar) + + +def divide_inplace(tensor, value): + return assign_inplace(tensor, tensor / value) + + +class TorchStylePaddleOptimizer(paddle.optimizer.Optimizer): + """Minimal compatibility layer for Torch-style custom optimizers.""" + + def __init__(self, params, lr, defaults): + super().__init__( + learning_rate=float(lr), 
            parameters=params,
            weight_decay=defaults.get("weight_decay", 0.0),
        )
        self.defaults = defaults.copy()
        # Accept either torch-style param-group dicts or a flat param list.
        if self._param_groups and isinstance(self._param_groups[0], dict):
            self.param_groups = self._param_groups
        else:
            self.param_groups = [{"params": list(self._param_groups or [])}]
        # Per-parameter optimizer state, keyed by the parameter object.
        self.state = defaultdict(dict)
        self._init_param_groups()

    def _init_param_groups(self):
        # Propagate defaults into each group (lr is managed separately).
        self._sync_group_lr()
        for group in self.param_groups:
            for key, value in self.defaults.items():
                if key == "lr":
                    continue
                group.setdefault(key, value)

    def _sync_group_lr(self):
        # Refresh each group's lr from the (possibly scheduled) base LR,
        # scaled by the group's optional "learning_rate" multiplier.
        base_lr = float(self.get_lr())
        for group in self.param_groups:
            group["lr"] = base_lr * float(group.get("learning_rate", 1.0))

    @staticmethod
    def _clone_value(value):
        # Deep-copy tensors; pass scalars/other values through unchanged.
        if isinstance(value, paddle.Tensor):
            return value.clone()
        return value

    @paddle.base.framework.dygraph_only
    def state_dict(self):
        # Serialize per-parameter state keyed by parameter name, plus each
        # param group's hyperparameters and any LR-scheduler state.
        state = {}
        for group in self.param_groups:
            for param in group["params"]:
                param_state = self.state.get(param, {})
                if not param_state:
                    continue
                state[param.name] = {
                    key: self._clone_value(value)
                    for key, value in param_state.items()
                }

        group_state = []
        for group in self.param_groups:
            group_state.append(
                {
                    key: (
                        [param.name for param in value]
                        if key == "params"
                        else self._clone_value(value)
                    )
                    for key, value in group.items()
                }
            )

        payload = {"state": state, "param_groups": group_state}
        if isinstance(self._learning_rate, paddle.optimizer.lr.LRScheduler):
            payload["LR_Scheduler"] = self._learning_rate.state_dict()
        return payload

    @paddle.base.framework.dygraph_only
    def set_state_dict(self, state_dict):
        # Restore state produced by state_dict() above; validates that the
        # current parameter layout matches the saved one.
        if "state" not in state_dict or "param_groups" not in state_dict:
            raise ValueError(
                "Unsupported optimizer state format for custom Paddle optimizer."
            )

        if (
            isinstance(self._learning_rate, paddle.optimizer.lr.LRScheduler)
            and "LR_Scheduler" in state_dict
        ):
            self._learning_rate.set_state_dict(state_dict["LR_Scheduler"])

        name_to_param = {
            param.name: param
            for group in self.param_groups
            for param in group["params"]
        }
        self.state = defaultdict(dict)
        for param_name, saved_state in state_dict["state"].items():
            if param_name not in name_to_param:
                raise ValueError(
                    f"Optimizer state contains unknown parameter: {param_name}"
                )
            param = name_to_param[param_name]
            self.state[param] = {
                key: self._clone_value(value)
                for key, value in saved_state.items()
            }

        saved_groups = state_dict["param_groups"]
        if len(saved_groups) != len(self.param_groups):
            raise ValueError("Optimizer param group count mismatch during restore.")

        for group, saved_group in zip(self.param_groups, saved_groups):
            current_names = [param.name for param in group["params"]]
            if current_names != saved_group["params"]:
                raise ValueError("Optimizer param group layout mismatch during restore.")
            for key, value in saved_group.items():
                if key == "params":
                    continue
                group[key] = self._clone_value(value)

        self._sync_group_lr()

    # Torch-compatible alias.
    load_state_dict = set_state_dict


# ==== ppcfd/models/multiple_physics_pretraining/utils/dadapt_adam_paddle.py ====
import logging
import math
import os
import pdb
from typing import TYPE_CHECKING, Any, Callable, Optional

import paddle

from .custom_optimizer_base import (
    TorchStylePaddleOptimizer,
    scale_inplace,
)

# NOTE(review): os/pdb/Callable/Optional are imported but unused here.
############################## helper utilities, begin ##############################

def device2int(device):
    # Normalize "cuda:0"/"gpu:0"-style device strings to an integer index;
    # non-string inputs are passed straight to int().
    if isinstance(device, str):
        device = device.replace('cuda', 'gpu')
        device = device.replace('gpu:', '')
    return int(device)
############################## helper utilities, end ##############################


if TYPE_CHECKING:
    pass
else:
    _params_t = Any


class DAdaptAdam(TorchStylePaddleOptimizer):
    """
    Implements Adam with D-Adaptation automatic step-sizes.
    Leave LR set to 1 unless you encounter instability.

    To scale the learning rate differently for each layer, set the 'layer_scale'
    for each parameter group. Increase (or decrease) from its default value of 1.0
    to increase (or decrease) the learning rate for that layer relative to the
    other layers.

    Arguments:
        params (iterable):
            Iterable of parameters to optimize or dicts defining parameter groups.
        lr (float):
            Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate.
        betas (Tuple[float, float], optional): coefficients used for computing
            running averages of gradient and its square (default: (0.9, 0.999))
        eps (float):
            Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8).
        weight_decay (float):
            Weight decay, i.e. a L2 penalty (default: 0).
        log_every (int):
            Log using print every k steps, default 0 (no logging).
        decouple (boolean):
            Use AdamW style decoupled weight decay
        use_bias_correction (boolean):
            Turn on Adam's bias correction. Off by default.
        d0 (float):
            Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing.
        growth_rate (float):
            prevent the D estimate from growing faster than this multiplicative rate.
            Default is inf, for unrestricted. Values like 1.02 give a kind of learning
            rate warmup effect.
        fsdp_in_use (bool):
            If you're using sharded parameters, this should be set to True. The optimizer
            will attempt to auto-detect this, but if you're using an implementation other
            than PyTorch's builtin version, the auto-detection won't work.
    """

    def __init__(
        self,
        params,
        lr=1.0,
        betas=(0.9, 0.999),
        eps=1e-08,
        weight_decay=0,
        log_every=0,
        decouple=False,
        use_bias_correction=False,
        d0=1e-06,
        growth_rate=float("inf"),
        fsdp_in_use=False,
    ):
        # Validate hyperparameters up front (torch-style error messages).
        if not 0.0 < d0:
            raise ValueError("Invalid d0 value: {}".format(d0))
        if not 0.0 < lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 < eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        if decouple:
            print(f"Using decoupled weight decay")
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            d=d0,
            k=0,
            layer_scale=1.0,
            numerator_weighted=0.0,
            log_every=log_every,
            growth_rate=growth_rate,
            use_bias_correction=use_bias_correction,
            decouple=decouple,
            fsdp_in_use=fsdp_in_use,
        )
        self.d0 = d0
        super().__init__(params, lr, defaults)

    @property
    def supports_memory_efficient_fp16(self):
        return False

    @property
    def supports_flat_params(self):
        return True

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
            and returns the loss.
+ """ + self._sync_group_lr() + loss = None + if closure is not None: + loss = closure() + sk_l1 = 0.0 + group = self.param_groups[0] + use_bias_correction = group["use_bias_correction"] + numerator_weighted = group["numerator_weighted"] + beta1, beta2 = group["betas"] + k = group["k"] + d = group["d"] + lr = max(group["lr"] for group in self.param_groups) + if use_bias_correction: + bias_correction = (1 - beta2 ** (k + 1)) ** 0.5 / (1 - beta1 ** (k + 1)) + else: + bias_correction = 1 + dlr = d * lr * bias_correction + growth_rate = group["growth_rate"] + decouple = group["decouple"] + log_every = group["log_every"] + fsdp_in_use = group["fsdp_in_use"] + sqrt_beta2 = beta2**0.5 + numerator_acum = 0.0 + for group in self.param_groups: + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + group_lr = group["lr"] + r = group["layer_scale"] + if group_lr not in [lr, 0.0]: + raise RuntimeError( + f"Setting different lr values in different parameter groups is only supported for values of 0. To scale the learning rate differently for each layer, set the 'layer_scale' value instead." 
+ ) + for p in group["params"]: + if p.grad is None: + continue + if hasattr(p, "_fsdp_flattened"): + fsdp_in_use = True + grad = p.grad.data + if decay != 0 and not decouple: + grad.add_(p.data, alpha=decay) + state = self.state[p] + if "step" not in state: + state["step"] = 0 + state["s"] = paddle.zeros_like(p.data).detach() + state["exp_avg"] = paddle.zeros_like(p.data).detach() + state["exp_avg_sq"] = paddle.zeros_like(p.data).detach() + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + s = state["s"] + if group_lr > 0.0: + denom = exp_avg_sq.sqrt() + eps + numerator_acum += ( + r + * dlr + * paddle.dot(grad.flatten(), s.div(denom).flatten()).item() + ) + scale_inplace(exp_avg, beta1).add_( + grad, alpha=r * dlr * (1 - beta1) + ) + scale_inplace(exp_avg_sq, beta2).add_((1 - beta2) * grad * grad) + scale_inplace(s, sqrt_beta2).add_( + grad, alpha=dlr * (1 - sqrt_beta2) + ) + sk_l1 += r * s.abs().sum().item() + d_hat = d + if sk_l1 == 0: + return loss + if fsdp_in_use: + dist_tensor = paddle.zeros(2).cuda() + dist_tensor[0] = numerator_acum + dist_tensor[1] = sk_l1 + paddle.distributed.all_reduce( + tensor=dist_tensor, op=paddle.distributed.ReduceOp.SUM + ) + global_numerator_weighted = ( + sqrt_beta2 * numerator_weighted + (1 - sqrt_beta2) * dist_tensor[0] + ) + global_sk_l1 = dist_tensor[1] + else: + global_numerator_weighted = ( + sqrt_beta2 * numerator_weighted + (1 - sqrt_beta2) * numerator_acum + ) + global_sk_l1 = sk_l1 + if lr > 0.0: + d_hat = global_numerator_weighted / ((1 - sqrt_beta2) * global_sk_l1) + d = max(d, min(d_hat, d * growth_rate)) + if log_every > 0 and k % log_every == 0: + logging.info( + f"lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. 
sk_l1={global_sk_l1:1.1e} numerator_weighted={global_numerator_weighted:1.1e}" + ) + for group in self.param_groups: + group["numerator_weighted"] = global_numerator_weighted + group["d"] = d + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + exp_avg, exp_avg_sq = state["exp_avg"], state["exp_avg_sq"] + state["step"] += 1 + denom = exp_avg_sq.sqrt() + eps + if decay != 0 and decouple: + p.data.add_(p.data, alpha=-decay * dlr) + p.data.add_(exp_avg / denom, alpha=-1.0) + group["k"] = k + 1 + return loss diff --git a/ppcfd/models/multiple_physics_pretraining/utils/dadapt_adan_paddle.py b/ppcfd/models/multiple_physics_pretraining/utils/dadapt_adan_paddle.py new file mode 100644 index 0000000..6a2ffd2 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/utils/dadapt_adan_paddle.py @@ -0,0 +1,216 @@ +from typing import TYPE_CHECKING, Any + +import paddle + +from .custom_optimizer_base import ( + TorchStylePaddleOptimizer, + divide_inplace, + scale_inplace, +) + +if TYPE_CHECKING: + pass +else: + _params_t = Any + + +def to_real(x): + if paddle.is_complex(x): + return x.real() + else: + return x + + +class DAdaptAdan(TorchStylePaddleOptimizer): + """ + Implements Adan with D-Adaptation automatic step-sizes. + Has not been as heavily tested as DAdaptAdam and should be considered experimental. + + Leave LR set to 1 unless you encounter instability. + Adan was proposed in + Adan: Adaptive Nesterov Momentum Algorithm for Faster Optimizing Deep Models[J]. arXiv preprint arXiv:2208.06677, 2022. + https://arxiv.org/abs/2208.06677 + + Arguments: + params (iterable): + Iterable of parameters to optimize or dicts defining parameter groups. + lr (float): + Learning rate adjustment parameter. Increases or decreases the D-adapted learning rate. 
+ betas (Tuple[float, float, flot], optional): coefficients used for computing + running averages of gradient and its norm. (default: (0.98, 0.92, 0.99)) + eps (float): + Term added to the denominator outside of the root operation to improve numerical stability. (default: 1e-8). + weight_decay (float): + Weight decay, i.e. a L2 penalty (default: 0.02). + no_prox (boolean): + how to perform the decoupled weight decay (default: False) + log_every (int): + Log using print every k steps, default 0 (no logging). + d0 (float): + Initial D estimate for D-adaptation (default 1e-6). Rarely needs changing. + growth_rate (float): + prevent the D estimate from growing faster than this multiplicative rate. + Default is inf, for unrestricted. Values like 1.02 give a kind of learning + rate warmup effect. + """ + + def __init__( + self, + params, + lr=1.0, + betas=(0.98, 0.92, 0.99), + eps=1e-08, + weight_decay=0.02, + no_prox=False, + log_every=0, + d0=1e-06, + growth_rate=float("inf"), + ): + if not 0.0 < d0: + raise ValueError("Invalid d0 value: {}".format(d0)) + if not 0.0 < lr: + raise ValueError("Invalid learning rate: {}".format(lr)) + if not 0.0 < eps: + raise ValueError("Invalid epsilon value: {}".format(eps)) + if not 0.0 <= betas[0] < 1.0: + raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0])) + if not 0.0 <= betas[1] < 1.0: + raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1])) + if not 0.0 <= betas[2] < 1.0: + raise ValueError("Invalid beta parameter at index 2: {}".format(betas[2])) + defaults = dict( + lr=lr, + betas=betas, + eps=eps, + weight_decay=weight_decay, + no_prox=no_prox, + d=d0, + k=0, + gsq_weighted=0.0, + log_every=log_every, + growth_rate=growth_rate, + ) + super().__init__(params, lr, defaults) + + @property + def supports_memory_efficient_fp16(self): + return False + + @property + def supports_flat_params(self): + return True + + @paddle.no_grad() + def restart_opt(self): + for group in self.param_groups: 
+            group["gsq_weighted"] = 0.0
+            for p in group["params"]:
+                if not p.stop_gradient:
+                    state = self.state[p]
+                    state["step"] = 0
+                    state["s"] = paddle.zeros_like(p.data).detach()
+                    state["exp_avg"] = paddle.zeros_like(p.data).detach()
+                    state["exp_avg_diff"] = paddle.zeros_like(p.data).detach()
+                    state["exp_avg_sq"] = paddle.zeros_like(to_real(p.data)).detach()
+
+    @paddle.no_grad()
+    def step(self, closure=None):
+        """Performs a single optimization step.
+        Arguments:
+            closure (callable, optional): A closure that reevaluates the model
+                and returns the loss.
+        """
+        self._sync_group_lr()
+        loss = None
+        if closure is not None:
+            loss = closure()
+        g_sq = 0.0
+        sksq_weighted = 0.0
+        sk_l1 = 0.0
+        ngroups = len(self.param_groups)
+        group = self.param_groups[0]
+        gsq_weighted = group["gsq_weighted"]
+        d = group["d"]
+        lr = group["lr"]
+        dlr = d * lr
+        no_prox = group["no_prox"]
+        growth_rate = group["growth_rate"]
+        log_every = group["log_every"]
+        beta1, beta2, beta3 = group["betas"]
+        for group in self.param_groups:
+            decay = group["weight_decay"]
+            k = group["k"]
+            eps = group["eps"]
+            for p in group["params"]:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                state = self.state[p]
+                if "step" not in state:
+                    state["step"] = 0
+                    state["s"] = paddle.zeros_like(p.data).detach()
+                    state["exp_avg"] = paddle.zeros_like(p.data).detach()
+                    state["exp_avg_diff"] = paddle.zeros_like(p.data).detach()
+                    state["exp_avg_sq"] = paddle.zeros_like(to_real(p.data)).detach()
+                if state["step"] == 0:
+                    state["pre_grad"] = grad.clone()
+                exp_avg, exp_avg_diff, exp_avg_sq = (
+                    state["exp_avg"],
+                    state["exp_avg_diff"],
+                    state["exp_avg_sq"],
+                )
+                grad_diff = grad - state["pre_grad"]
+                grad_grad = to_real(grad * grad.conj())
+                update = grad + beta2 * grad_diff
+                update_update = to_real(update * update.conj())
+                scale_inplace(exp_avg, beta1).add_(grad, alpha=dlr * (1.0 - beta1))
+                scale_inplace(exp_avg_diff, beta2).add_(
+                    grad_diff, alpha=dlr * (1.0 - beta2)
+                )
+                
scale_inplace(exp_avg_sq, beta3).add_(update_update, alpha=1.0 - beta3) + denom = exp_avg_sq.sqrt() + eps + g_sq += grad_grad.div_(denom).sum().item() + s = state["s"] + scale_inplace(s, beta3).add_(grad, alpha=dlr * (1.0 - beta3)) + sksq_weighted += to_real(s * s.conj()).div_(denom).sum().item() + sk_l1 += s.abs().sum().item() + gsq_weighted = beta3 * gsq_weighted + g_sq * dlr**2 * (1 - beta3) + d_hat = d + if sk_l1 == 0: + return loss + if lr > 0.0: + d_hat = (sksq_weighted / (1 - beta3) - gsq_weighted) / sk_l1 + d = max(d, min(d_hat, d * growth_rate)) + if log_every > 0 and k % log_every == 0: + print( + f"ng: {ngroups} lr: {lr} dlr: {dlr} d_hat: {d_hat}, d: {d}. sksq_weighted={sksq_weighted:1.1e} sk_l1={sk_l1:1.1e} gsq_weighted={gsq_weighted:1.1e}" + ) + for group in self.param_groups: + group["gsq_weighted"] = gsq_weighted + group["d"] = d + decay = group["weight_decay"] + k = group["k"] + eps = group["eps"] + for p in group["params"]: + if p.grad is None: + continue + grad = p.grad.data + state = self.state[p] + exp_avg, exp_avg_diff, exp_avg_sq = ( + state["exp_avg"], + state["exp_avg_diff"], + state["exp_avg_sq"], + ) + state["step"] += 1 + denom = exp_avg_sq.sqrt() + eps + denom = denom.astype(p.dtype) + update = (exp_avg + beta2 * exp_avg_diff).div_(denom) + if no_prox: + scale_inplace(p.data, 1 - dlr * decay) + p.add_(update, alpha=-1) + else: + p.add_(update, alpha=-1) + divide_inplace(p.data, 1 + dlr * decay) + state["pre_grad"].copy_(grad) + group["k"] = k + 1 + return loss diff --git a/ppcfd/models/multiple_physics_pretraining/utils/logging_utils.py b/ppcfd/models/multiple_physics_pretraining/utils/logging_utils.py new file mode 100644 index 0000000..810b7b1 --- /dev/null +++ b/ppcfd/models/multiple_physics_pretraining/utils/logging_utils.py @@ -0,0 +1,40 @@ +import logging +import os + +import paddle + +_format = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + + +def config_logger(log_level=logging.INFO): + 
logging.basicConfig(format=_format, level=log_level)
+
+
+def log_to_file(
+    logger_name=None, log_level=logging.INFO, log_filename="tensorflow.log"
+):
+    if os.path.dirname(log_filename) and not os.path.exists(os.path.dirname(log_filename)):
+        os.makedirs(os.path.dirname(log_filename))
+    if logger_name is not None:
+        log = logging.getLogger(logger_name)
+    else:
+        log = logging.getLogger()
+    fh = logging.FileHandler(log_filename)
+    fh.setLevel(log_level)
+    fh.setFormatter(logging.Formatter(_format))
+    log.addHandler(fh)
+
+
+def log_versions():
+    import subprocess
+
+    logging.info("--------------- Versions ---------------")
+    logging.info(
+        "git branch: " + str(subprocess.check_output(["git", "branch"]).strip())
+    )
+    logging.info(
+        "git hash: "
+        + str(subprocess.check_output(["git", "rev-parse", "HEAD"]).strip())
+    )
+    logging.info("Paddle: " + str(paddle.__version__))
+    logging.info("----------------------------------------")
diff --git a/ppcfd/models/multiple_physics_pretraining/utils/schedulers.py b/ppcfd/models/multiple_physics_pretraining/utils/schedulers.py
new file mode 100644
index 0000000..6dbea32
--- /dev/null
+++ b/ppcfd/models/multiple_physics_pretraining/utils/schedulers.py
@@ -0,0 +1,103 @@
+"""
+自定义学习率调度器,用于支持 PaddlePaddle 迁移
+
+本模块提供了与 PyTorch SequentialLR 兼容的最小化实现。
+"""
+
+
+class SimpleSequentialScheduler:
+    """
+    最小化的 SequentialLR 替代品,完全模拟原始行为
+
+    该调度器在指定的里程碑步数切换内部调度器,完全依赖内部调度器的实现。
+
+    参数:
+        optimizer: 优化器实例
+        schedulers: 调度器列表,例如 [warmup_scheduler, decay_scheduler]
+        milestones: 切换点列表,例如 [warmup_steps]
+            当 _step_count >= milestones[i] 时,切换到 schedulers[i+1]
+        last_epoch: 用于恢复训练的步数,默认为 -1 (从头开始)
+
+    示例:
+        >>> import torch
+        >>> optimizer = torch.optim.Adam(model.parameters(), lr=0.1)
+        >>>
+        >>> # 创建 warmup 调度器
+        >>> warmup = torch.optim.lr_scheduler.LinearLR(
+        ...     optimizer, start_factor=0.01, end_factor=1.0, total_iters=1000
+        ... )
+        >>>
+        >>> # 创建 decay 调度器
+        >>> decay = torch.optim.lr_scheduler.CosineAnnealingLR(
+        ...     optimizer, eta_min=0.001, T_max=9000
+        ... 
)
+        >>>
+        >>> # 组合调度器
+        >>> scheduler = SimpleSequentialScheduler(
+        ...     optimizer, [warmup, decay], [1000], last_epoch=-1
+        ... )
+        >>>
+        >>> # 训练循环
+        >>> for epoch in range(epochs):
+        ...     for batch in dataloader:
+        ...         optimizer.zero_grad()
+        ...         loss.backward()
+        ...         optimizer.step()
+        ...         scheduler.step()
+
+    注意:
+        - 该调度器完全依赖内部调度器的实现,不重新实现数学公式
+        - 不实现 state_dict()/load_state_dict(),完全依赖 last_epoch 机制
+        - 可以轻松适配到 PaddlePaddle(只需修改内部调度器的创建方式)
+    """
+
+    def __init__(self, optimizer, schedulers, milestones, last_epoch=-1):
+        """
+        初始化调度器切换器
+
+        参数:
+            optimizer: 优化器实例
+            schedulers: 调度器列表
+            milestones: 切换点列表(步数)
+            last_epoch: 用于恢复训练的步数,默认为 -1
+        """
+        self.optimizer = optimizer
+        self.schedulers = schedulers
+        self.milestones = milestones
+        self._step_count = 0
+        self._last_lr = [self.optimizer.get_lr()]
+        if last_epoch >= 0:
+            self._step_count = last_epoch + 1
+            scheduler_idx = self._get_scheduler_index()
+            self._last_lr = [self.schedulers[scheduler_idx].get_lr()]
+            self.optimizer.set_lr(self._last_lr[0])
+
+    def _get_scheduler_index(self):
+        """确定当前应该使用哪个调度器"""
+        scheduler_idx = 0
+        for i, milestone in enumerate(self.milestones):
+            if self._step_count >= milestone:
+                scheduler_idx = i + 1
+        return scheduler_idx
+
+    def step(self):
+        """
+        执行一步学习率更新
+
+        根据当前步数确定应该使用哪个调度器,然后调用该调度器的 step() 方法。
+        """
+        scheduler_idx = self._get_scheduler_index()
+        self.schedulers[scheduler_idx].step()
+        self._step_count += 1
+        current_lr = self.schedulers[scheduler_idx].get_lr()
+        self.optimizer.set_lr(current_lr)
+        self._last_lr = [current_lr]
+
+    def get_last_lr(self):
+        """
+        返回当前学习率(用于日志记录)
+
+        返回:
+            包含当前学习率的列表
+        """
+        return self._last_lr