Sophiex/dev ssl/main#2511
Conversation
Fix was the dataset length computation in the workers
| } | ||
|
|
||
| if channels is not None: | ||
| # Respect the order given in the stream config so the channel layout is identical |
There was a problem hiding this comment.
Do you think we need this? Are we unsure at the moment whether this is necessary?
There was a problem hiding this comment.
We are unsure if it is necessary
| output = self.predict_decoders(model_params, step, tokens, batch, output) | ||
| # latent predictions (raw and with SSL heads) | ||
| output = self.predict_latent(model_params, step, tokens, batch, output, intermediates) | ||
| if "masking" in self.cf.training_config.training_mode: |
|
|
||
| for module in model.encoder.ae_local_engine.ae_local_blocks.modules(): | ||
| if isinstance(module, modules_to_shard): | ||
| if isinstance(module, modules_to_shard) and _has_trainable_params(module): |
There was a problem hiding this comment.
This is the fsdp fix from this morning?
| continue # set disabled, e.g. by a train_continue override | ||
| stage_label = f"val_{name}" | ||
| extra_cfg = get_active_stage_config(self.validation_cfg, overrides, cfg_keys_to_filter) | ||
| # extra sets never write sample output files (would collide with primary val output) |
There was a problem hiding this comment.
So we only look at these in plot_train? And then if we run inference (using test that I suppose inherits from validation) we only ever write out one of the validation periods? Not important, just want to check the mechanics here.
There was a problem hiding this comment.
not sure, didn't check to be honest
| self.validate(0, self.test_cfg, self.batch_size_test_per_gpu) | ||
| logger.info(f"Finished inference run with id: {cf.general.run_id}") | ||
|
|
||
| def _check_channel_order_consistency( |
There was a problem hiding this comment.
Can we keep this guard and remove the reordering you did above if we don't think we need it maybe?
| self.dataset = MultiStreamDataSampler(cf, self.training_cfg, stage=TRAIN) | ||
| self.dataset_val = MultiStreamDataSampler(cf, self.validation_cfg, stage=VAL) | ||
|
|
||
| if run_id_contd is not None: |
Description
This has several small fixes, including:
./scripts/actions.sh lint./scripts/actions.sh unit-test./scripts/actions.sh integration-testlaunch-slurm.py --time 60