Skip to content

Sophiex/dev ssl/main#2511

Open
sophie-xhonneux wants to merge 11 commits into
develop-sslfrom
sophiex/dev-ssl/main
Open

Sophiex/dev ssl/main#2511
sophie-xhonneux wants to merge 11 commits into
develop-sslfrom
sophiex/dev-ssl/main

Conversation

@sophie-xhonneux

Copy link
Copy Markdown
Contributor

Description

This has several small fixes, including:

  • small feet at the end of training
image - Updating the configs - Multiple validation periods - Predict latent being called during forecasting ## Checklist before asking for review
  • I have performed a self-review of my code
  • My changes comply with basic sanity checks:
    • I have fixed formatting issues with ./scripts/actions.sh lint
    • I have run unit tests with ./scripts/actions.sh unit-test
    • I have documented my code and I have updated the docstrings.
    • I have added unit tests, if relevant
  • I have tried my changes with data and code:
    • I have run the integration tests with ./scripts/actions.sh integration-test
    • (bigger changes) I have run a full training and I have written in the comment the run_id(s): launch-slurm.py --time 60
    • (bigger changes and experiments) I have shared a hegdedoc in the github issue with all the configurations and runs for this experiments
  • I have informed and aligned with people impacted by my change:
    • for config changes: the MatterMost channels and/or a design doc
    • for changes of dependencies: the MatterMost software development channel

}

if channels is not None:
# Respect the order given in the stream config so the channel layout is identical

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we need this? Are we unsure at the moment whether this is necessary?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are unsure if it is necessary

output = self.predict_decoders(model_params, step, tokens, batch, output)
# latent predictions (raw and with SSL heads)
output = self.predict_latent(model_params, step, tokens, batch, output, intermediates)
if "masking" in self.cf.training_config.training_mode:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good


for module in model.encoder.ae_local_engine.ae_local_blocks.modules():
if isinstance(module, modules_to_shard):
if isinstance(module, modules_to_shard) and _has_trainable_params(module):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the fsdp fix from this morning?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes

continue # set disabled, e.g. by a train_continue override
stage_label = f"val_{name}"
extra_cfg = get_active_stage_config(self.validation_cfg, overrides, cfg_keys_to_filter)
# extra sets never write sample output files (would collide with primary val output)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So we only look at these in plot_train? And then if we run inference (using test that I suppose inherits from validation) we only ever write out one of the validation periods? Not important, just want to check the mechanics here.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure, didn't check to be honest

self.validate(0, self.test_cfg, self.batch_size_test_per_gpu)
logger.info(f"Finished inference run with id: {cf.general.run_id}")

def _check_channel_order_consistency(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep this guard and remove the reordering you did above if we don't think we need it maybe?

self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN)
self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL)

if run_id_contd is not None:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you!!!

@github-actions github-actions Bot added infra Issues related to infrastructure model Related to model training or definition (not generic infra) labels Jun 16, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

infra Issues related to infrastructure model Related to model training or definition (not generic infra)

Projects

Status: No status

Development

Successfully merging this pull request may close these issues.

2 participants