Feat multi decoder segmenter #19
…rch-utils into feat-multi-decoder-segmenter
carshadi left a comment
This looks great @camilolaiton! I just had a few comments.
Also, run_proteomics_example.py has a couple of required dependencies that are not packaged here:
https://github.com/AllenNeuralDynamics/aind-lightsheet-mae
https://github.com/AllenNeuralDynamics/aind-proteomics-image-translator/tree/dev
I was thinking this could be resolved in two ways:
- add documentation to the examples on installing the necessary deps and how to run the example, and/or
- add the required deps as an optional dependency group to the pyproject.toml, e.g., pip install .[proteomics]. This would be useful if we plan on using this library for proteomics in general, though maybe it could be deferred until the related codebases are released.
Additionally, the aind-proteomics-image-translator pins numpy<2, but this library requires numpy>2 since it uses the copy argument. I had no problems running the example after upgrading numpy to 2.4.6, so maybe that upper bound isn't necessary? https://github.com/AllenNeuralDynamics/aind-proteomics-image-translator/blob/2791cc4c5c81c5ef5a275a49b5bc96f2ebbfa28a/pyproject.toml#L22
# Ensure the tensor has a channel dimension that matches writers
if out_np.ndim == 4:
    # legacy single-output (B, pz, py, px): add channel dim
    out_np = out_np[:, np.newaxis]
This will silently drop extra model output channels when fewer output stores are provided. A (B, 2, Z, Y, X) output with one store writes channel 0 and discards channel 1. We should add an explicit out_np.shape[1] == len(self.writers) validation and a test for mismatched counts.
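A minimal sketch of the kind of check meant here, written against plain shape tuples so it can be tested without a model. The function name `check_output_channels` is hypothetical and not part of this PR; in the worker, the arguments would presumably be `out_np.shape` and `len(self.writers)`.

```python
from typing import Tuple


def check_output_channels(shape: Tuple[int, ...], n_writers: int) -> Tuple[int, ...]:
    """Return the 5D (B, N, Z, Y, X) shape, or raise if N != n_writers.

    Hypothetical helper: `shape` stands in for out_np.shape and
    `n_writers` for len(self.writers).
    """
    if len(shape) == 4:
        # Legacy single-output (B, pz, py, px): treat as one channel.
        shape = (shape[0], 1) + tuple(shape[1:])
    elif len(shape) != 5:
        raise ValueError(f"Expected a 4D or 5D model output, got shape {shape}")

    n_channels = shape[1]
    if n_channels != n_writers:
        # Fail loudly instead of silently dropping or missing channels.
        raise ValueError(
            f"Model produced {n_channels} output channel(s) but "
            f"{n_writers} output store(s) were provided"
        )
    return tuple(shape)
```

A mismatched-count test would then only need to assert that a (B, 2, Z, Y, X) shape with one writer raises.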
Good catch!! thank you!!
)  # floor to avoid near-zero weights

# Output
output_denormalize: bool = Field(
This value is not exposed in the run.py CLI. CLI users running probability/segmentation outputs with normalization enabled will get values rescaled back into input intensity space.
You're right. I'll also add a validation step to make sure this stays deactivated when running segmentation.
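One way this could look, as a rough sketch only. The flag names `--output-denormalize` and `--task`, and the default of True, are assumptions for illustration; the real run.py arguments may differ.

```python
import argparse
from typing import List, Optional


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Sketch of the CLI exposure")
    # Hypothetical flag describing what the model outputs.
    parser.add_argument(
        "--task", choices=["denoise", "segmentation"], default="denoise"
    )
    # BooleanOptionalAction (Python 3.9+) provides both
    # --output-denormalize and --no-output-denormalize.
    parser.add_argument(
        "--output-denormalize",
        action=argparse.BooleanOptionalAction,
        default=True,  # assumed default, not confirmed by the PR
    )
    return parser


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.task == "segmentation" and args.output_denormalize:
        # Probabilities should not be rescaled into input intensity space.
        parser.error("--task segmentation requires --no-output-denormalize")
    return args
```

With this shape, `parse_args(["--task", "segmentation"])` exits with a usage error, while adding `--no-output-denormalize` passes validation.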
    devices = args.devices
else:
    n = torch.cuda.device_count()
    devices = [f"cuda:{i}" for i in range(n)] if n > 0 else ["cpu"]
This falls back to ["cpu"] when there is no CUDA device, but GpuWorker is CUDA-only. On a no-GPU machine this path will crash; better to fail early with a clear GPU-required message.
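A sketch of the fail-early behavior, assuming a hypothetical helper `resolve_devices` that takes the device count as an argument so it can be tested without a GPU. In run.py the count would presumably come from `torch.cuda.device_count()`.

```python
from typing import List, Optional


def resolve_devices(
    requested: Optional[List[str]], cuda_device_count: int
) -> List[str]:
    """Pick CUDA devices, raising early when none are usable.

    Hypothetical helper: `requested` stands in for args.devices.
    """
    if requested:
        for name in requested:
            if not name.startswith("cuda"):
                # GpuWorker is CUDA-only, so reject e.g. "cpu" up front.
                raise ValueError(
                    f"Unsupported device {name!r}: only CUDA devices are supported"
                )
        return list(requested)

    if cuda_device_count <= 0:
        raise RuntimeError(
            "No CUDA device found. This pipeline requires at least one GPU."
        )
    return [f"cuda:{i}" for i in range(cuda_device_count)]
```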
s3_path = f"{args.out_prefix.rstrip('/')}/{name}.zarr/"
logger.info(f"Output store: s3://{args.out_bucket}/{s3_path}")
out_stores.append(
    _open_or_create_s3_zarr(args.out_bucket, s3_path, vol_shape, out_chunks)
Output stores are created with the full input shape (T, C, Z, Y, X) but only channel c_idx / timepoint t_idx is ever written, while the OME-NGFF metadata advertises all C channels. A C=3 input yields output zarrs where 2 of 3 channels are all-zero, while write_ome_ngff_metadata describes 3 populated channels. I would limit the output shape to (1, 1, Z, Y, X) just to be safe.
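For illustration, the shape reduction could be as small as the following. `single_volume_shape` is a hypothetical name, and the sketch assumes `vol_shape` is always ordered (T, C, Z, Y, X).

```python
from typing import Sequence, Tuple


def single_volume_shape(vol_shape: Sequence[int]) -> Tuple[int, int, int, int, int]:
    """Collapse (T, C, Z, Y, X) to (1, 1, Z, Y, X) for a per-output store."""
    if len(vol_shape) != 5:
        raise ValueError(f"Expected a (T, C, Z, Y, X) shape, got {tuple(vol_shape)}")
    _, _, z, y, x = vol_shape
    # Only one timepoint and one channel are ever written per store.
    return (1, 1, z, y, x)
```

The result would replace `vol_shape` in the store creation call, and the OME-NGFF metadata would then need to describe a single channel as well.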
from aind_torch_utils.model_registry import ModelRegistry


class SharedEncoderModel(nn.Module):
It looks like this class isn't used anywhere outside of the test. The ProteinSharedModel example it was built for duplicates its forward logic verbatim instead of reusing it. Does it make sense to remove the SharedEncoderModel and add the ProteinSharedModel here? Or if you envisioned having this as an extensible class, we could do something like this:
class SharedEncoderModel(nn.Module):
    def __init__(self, encoder, decoders, apply_sigmoid=False):
        super().__init__()
        self.encoder = encoder
        self.decoders = nn.ModuleList(decoders)
        self.apply_sigmoid = apply_sigmoid

    def encode(self, x):
        return self.encoder(x)

    def decode(self, decoder, encoded, x):
        return decoder(encoded)

    def forward(self, x):
        encoded = self.encode(x)
        outputs = [self.decode(dec, encoded, x) for dec in self.decoders]
        out = torch.stack(outputs, dim=1).squeeze(2)
        if self.apply_sigmoid:
            out = torch.clamp(torch.sigmoid(out), min=0.01, max=0.99)
        return out
and then have the protein-specific model only override the necessary parts, e.g.
class ProteinSharedModel(SharedEncoderModel):
    def __init__(self, encoder, decoders, recover_layers=(2, 5), apply_sigmoid=True):
        super().__init__(encoder, decoders, apply_sigmoid=apply_sigmoid)
        self.recover_layers = recover_layers

    def encode(self, x):
        latent, _, _, _, feature_maps, _ = self.encoder(
            x=x,
            mask_ratio=0.0,
            recover_layers=self.recover_layers,
        )
        return latent, feature_maps

    def decode(self, decoder, encoded, x):
        latent, feature_maps = encoded
        return decoder(x=x, latent=latent, hidden_states_out=feature_maps)
Or if this model class is already defined in a different library, we could just import it as a dependency and register the model similarly to the denoising Unet. But then we would need to change the model loading logic to support multiple weights files and model-specific kwargs, which would require some refactoring and potentially a more config-based approach instead of just passing --weights. However, that would allow it to be run under the current CLI setup and would likely be more flexible overall. We could worry about that later.
What do you think?
Yeah, I was actually thinking of it as a parent class, but after reflecting on it, we should just keep the protein one. Regarding the model loading logic, that's a tricky one because of the refactoring.
But in the protein prediction scenario, it'd be better to refactor it toward a more config-based approach, so that we could support either a single weights file or this shared-encoder approach. Would that be okay with how you want to set up your package?
I am also okay with handling this directly in a Python script specific to proteomics, since it's an optimization scenario for large datasets. I do not think it's something many people will commonly do.
Yeah I think it would be useful to move towards the config approach since that is how most ML libraries do it, but that can be a milestone and not necessary for this PR. Maybe we just keep the protein model for now?
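To make the config idea concrete, here is a rough sketch of what a model config could carry. Everything in it (the `ModelConfig` fields, the JSON keys, the example paths) is hypothetical and only illustrates how a single weights file and a multi-file shared-encoder setup could share one schema.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class ModelConfig:
    """Hypothetical schema: registry name, named weights files, model kwargs."""

    name: str
    weights: Dict[str, str]
    kwargs: Dict[str, Any] = field(default_factory=dict)


def parse_model_config(text: str) -> ModelConfig:
    raw = json.loads(text)
    weights = raw["weights"]
    if isinstance(weights, str):
        # A bare string keeps the single-weights case as short as --weights.
        weights = {"model": weights}
    return ModelConfig(
        name=raw["name"],
        weights=dict(weights),
        kwargs=dict(raw.get("kwargs", {})),
    )


# Example: one encoder plus two decoders, with model-specific kwargs.
example = """
{
  "name": "protein_shared",
  "weights": {
    "encoder": "weights/encoder.pt",
    "decoder_0": "weights/decoder_0.pt",
    "decoder_1": "weights/decoder_1.pt"
  },
  "kwargs": {"recover_layers": [2, 5], "apply_sigmoid": true}
}
"""
config = parse_model_config(example)
```

The loader would then look up `config.name` in the model registry and pass `config.kwargs` to the constructor, which is the part that needs the refactoring discussed above.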
def forward(
    self,
    x: torch.Tensor,
    apply_sigmoid: Optional[bool] = False,
apply_sigmoid is unreachable in the pipeline: GpuWorker calls self.model(dev_in) with no kwargs, so the default of False is always used.
Thank you for taking a look @carshadi, great comments, I will address them! As for the required deps, I haven't released the codebase yet but will do this soon. I can update the numpy dependency directly on aind-proteomics-image-translator to avoid incompatibilities.
This pull request introduces multi-output model support to the inference pipeline, allowing models to return multiple outputs (e.g., via multiple decoders) and writing each output to a separate store. It also adds a new SharedEncoderModel class for shared-encoder/multi-decoder architectures, updates the worker and runner logic to handle multi-output models, and improves normalization and memory handling for more robust and efficient inference.
Multi-output model support and pipeline changes:
- Updated the pipeline (run.py, workers.py) to support models that return multiple outputs (shape (B, N, Z, Y, X)), including changes to accept a list of output stores and to spawn a writer worker per output channel. The CLI now requires one --out-spec per output channel.
- WriterWorker now supports both single- and multi-output models, managing a list of accumulators and writing to multiple output stores as needed.
Model architecture and utilities:
- Added SharedEncoderModel to models.py for architectures with a shared encoder and multiple independent decoders, stacking outputs along a new dimension.
Inference configuration and normalization:
- Added output_denormalize option to InferenceConfig to control whether outputs are inverse-normalized before writing, supporting models that output values in a different space than the input.
- Improved normalization logic to clip to percentile bounds before normalizing.
Robustness and documentation:
- Improved memory handling for device-to-host transfers, ensuring pinned memory buffers match the actual model output shape and dtype.
- Added or clarified module-level docstrings for better documentation and maintainability.
New utility functions:
- Added load_json utility in examples/example_utils/utils.py for loading JSON files from local paths or S3, with error handling.
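As an illustration of the normalization change summarized above (clip to percentile bounds, then normalize, with an optional inverse), here is a pure-Python sketch. The function names, the `eps` floor, and the list-based interface are assumptions; the actual implementation presumably operates on arrays and may differ in detail.

```python
from typing import List, Sequence


def normalize(
    values: Sequence[float], p_low: float, p_high: float, eps: float = 1e-6
) -> List[float]:
    """Clip to [p_low, p_high], then rescale to [0, 1]."""
    scale = max(p_high - p_low, eps)  # guard against a degenerate range
    return [(min(max(v, p_low), p_high) - p_low) / scale for v in values]


def denormalize(
    values: Sequence[float], p_low: float, p_high: float, eps: float = 1e-6
) -> List[float]:
    """Inverse of the rescaling step; skipped when output_denormalize is off."""
    scale = max(p_high - p_low, eps)
    return [v * scale + p_low for v in values]
```

Clipping first means values outside the percentile bounds map to exactly 0 or 1 instead of falling outside the normalized range. The inverse only makes sense for outputs in the input's intensity space, which is why probability outputs would skip it.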