Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 16 additions & 8 deletions config/config_jepa_multi_data_all_years.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ num_class_tokens: 0
num_register_tokens: 64

# noise before predictor in JEPA
noise_pre_predictor_std: 0
noise_pre_predictor_std: 1e-3

# number of steps offset applied to first target window; if set to zero and forecast_steps=0 then
# one is training an auto-encoder
Expand Down Expand Up @@ -171,7 +171,7 @@ data_loading :
training_config:

# training_mode: "masking", "student_teacher", "latent_loss"
training_mode: ["student_teacher"]
training_mode: ["masking","student_teacher"]

# Deep self-supervision (V-JEPA 2.1 style): compute SSL loss at multiple encoder depths.
# When enabled, intermediate encoder representations are tapped and used as additional
Expand Down Expand Up @@ -209,32 +209,38 @@ training_config:
lr_max: 1e-4
lr_final_decay: 1e-6
lr_final: 0.0
num_steps_warmup: 512
num_steps_cooldown: 1024
num_steps_warmup: 1024
num_steps_cooldown: 16384
policy_warmup: "cosine"
policy_decay: "constant"
policy_cooldown: "linear"
parallel_scaling_policy: "sqrt"

optimizer:
grad_clip: 0.25
weight_decay: 0.05
weight_decay: 0.1
adamw :
# parameters are scaled by number of DDP workers
# NOTE(review): the per-line annotations match 1 - B * (1 - beta) with B = 8; B is presumably
# the number of DDP workers -- confirm
beta1 : 0.9875 # at B=8 beta1 = 0.9
beta2 : 0.994 # at B=8 beta2 approx 0.95
beta2 : 0.99875 # at B=8 beta2 approx 0.99
eps : 2e-08

losses : {
"physical": {
enabled: True,
type: LossPhysical,
weight: 1.0,
loss_fcts: { "mse": { target_source_correspondence: {0 : {0 : "subset"} },}, },
},
"student-teacher": {
enabled: True,
type: LossLatentSSLStudentTeacher,
weight: 1.0,
loss_fcts : {
"JEPA": {
'weight': 4, "loss_extra_args": {}, "out_dim": 2048, "head": transformer,
"num_blocks": 6, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 1024,
"dropout_rate": 0.1,
"num_blocks": 6, "num_heads": 16, "with_qk_lnorm": True, "intermediate_dim": 2048,
"dropout_rate": 0.2,
target_source_correspondence: {0 : {0 : "subset"} },
},
},
Expand Down Expand Up @@ -280,6 +286,8 @@ training_config:
# validation config; full validation config is merge of training and validation config
validation_config:

time_window_step: 06:00:00

samples_per_mini_epoch: 256
shuffle: False

Expand Down
2 changes: 1 addition & 1 deletion config/config_jepa_multi_data_all_years_ft.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

streams_directory: "./config/streams/pretrain_multi_data_od/"
streams_directory: "./config/streams/pretrain_multi_data_all_years/"

general:

Expand Down
33 changes: 31 additions & 2 deletions config/config_jepa_multi_data_ft_forecast.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,40 @@
streams_directory: "./config/streams/jepa_forecast_multi_data_od/"
streams_directory: "./config/streams/jepa_forecast_multi_data_od_ckpt_order/"

freeze_modules: "^(?!.*ERA5)(?=.*(?:encoder|latent_pre_norm|latent_heads)).*$"

general:

# mutable parameters
istep: 0
rank: ???
world_size: ???

training_config:
num_mini_epochs: 32
start_date: 2016-01-01T00:00
# Oct-Dec 2022 (OND-2022) is carved out of training so that the heldout-train-years
# extra validation set below (currently commented out) is genuinely held out;
# end_date was previously 2022-12-31
end_date: 2022-09-30T00:00

learning_rate_scheduling :
lr_start: 1e-6
lr_max: 1e-5
lr_final_decay: 2e-6
lr_final: 0.0
num_steps_warmup: 256
num_steps_cooldown: 32768

num_mini_epochs: 6
samples_per_mini_epoch: 8192

# # extra validation sets, evaluated each mini-epoch and logged as stage "val_<name>";
# # each entry overrides the primary validation_config
# extra_validation_configs:
# # held-out slice inside the training years (excluded from training via the
# # end_date above), season-matched to the OND-2023 primary val window;
# # memorization probe: if its loss tracks val, the train/val gap is memorization,
# # if it stays well below val, the gap is distribution shift
# heldout-train-years:
# start_date: 2022-10-01T00:00
# end_date: 2022-12-31T00:00
# shuffle: True
# samples_per_mini_epoch: 256
34 changes: 24 additions & 10 deletions config/config_jepa_multi_data_ft_forecast_all_years.yml
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ training_config:
enabled: false


num_mini_epochs: 24
num_mini_epochs: 32
samples_per_mini_epoch: 8192
shuffle: True

Expand All @@ -119,11 +119,11 @@ training_config:

learning_rate_scheduling :
lr_start: 1e-6
lr_max: 8e-5
lr_max: 4e-5
lr_final_decay: 2e-6
lr_final: 0.0
lr_final: 2e-6
num_steps_warmup: 256
num_steps_cooldown: 512
num_steps_cooldown: 16384
policy_warmup: "cosine"
policy_decay: "constant"
policy_cooldown: "linear"
Expand All @@ -136,18 +136,18 @@ training_config:
adamw :
# parameters are scaled by number of DDP workers
# NOTE(review): the per-line annotations match 1 - B * (1 - beta) with B = 8; B is presumably
# the number of DDP workers -- confirm
beta1 : 0.9875 # at B=8 beta1 = 0.9
beta2 : 0.994 # at B=8 beta2 approx 0.95
beta2 : 0.99875 # at B=8 beta2 approx 0.99
eps : 1e-08

losses : {
"student-teacher": {
enabled: False,
type: Disabled,
},
# "physical": {
# enabled: False,
# type: Disabled,
# },
"physical": {
enabled: False,
type: Disabled,
},
"forecast": {
type: LossPhysical,
loss_fcts: { "mse": { }, },
Expand All @@ -173,7 +173,7 @@ training_config:

forecast :
time_step: 06:00:00
num_steps: 3
num_steps: 2
offset: 1
policy: "fixed"

Expand All @@ -183,6 +183,7 @@ validation_config:
samples_per_mini_epoch: 256
shuffle: False

# zero-padded HH:MM:SS, matching the other time-step entries; a bare 6:00:00 can be
# resolved by YAML 1.1 loaders as the base-60 integer 21600 instead of a string
time_window_step: 06:00:00
start_date: 2023-10-01T00:00
end_date: 2023-12-31T00:00

Expand All @@ -205,6 +206,19 @@ validation_config:
# run validation before training starts (mainly for model development)
validate_before_training: False

# # extra validation sets, evaluated each mini-epoch and logged as stage "val_<name>";
# # each entry overrides the primary validation_config
# extra_validation_configs:
# # held-out slice inside the training years (excluded from training via the
# # end_date above), season-matched to the OND-2023 primary val window;
# # memorization probe: if its loss tracks val, the train/val gap is memorization,
# # if it stays well below val, the gap is distribution shift
# heldout-train-years:
# start_date: 2022-10-01T00:00
# end_date: 2022-12-31T00:00
# shuffle: True
# samples_per_mini_epoch: 256

# test config; full test config is merge of validation and test config
test_config:

Expand Down
Loading
Loading