Feat multi decoder segmenter #19

Open
camilolaiton wants to merge 11 commits into main from feat-multi-decoder-segmenter

@camilolaiton

Collaborator

This pull request introduces multi-output model support to the inference pipeline, allowing models to return multiple outputs (e.g., via multiple decoders) and writing each output to a separate store. It also adds a new SharedEncoderModel class for shared-encoder/multi-decoder architectures, updates the worker and runner logic to handle multi-output models, and improves normalization and memory handling for more robust and efficient inference.

Multi-output model support and pipeline changes:

  • Updated the pipeline (run.py, workers.py) to support models that return multiple outputs (shape (B, N, Z, Y, X)), including changes to accept a list of output stores and to spawn a writer worker per output channel. The CLI now requires one --out-spec per output channel.

  • WriterWorker now supports both single- and multi-output models, managing a list of accumulators and writing to multiple output stores as needed.

Model architecture and utilities:

  • Added SharedEncoderModel to models.py for architectures with a shared encoder and multiple independent decoders, stacking outputs along a new dimension.

Inference configuration and normalization:

  • Added output_denormalize option to InferenceConfig to control whether outputs are inverse-normalized before writing, supporting models that output values in a different space than the input.

  • Improved normalization logic to clip to percentile bounds before normalizing.
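
As an illustration of the clip-then-normalize order described above, here is a minimal NumPy sketch. It assumes min-max scaling to [0, 1] using percentile bounds; the function names, the default percentiles, and the `eps` guard are hypothetical and are not taken from the PR.

```python
import numpy as np


def percentile_normalize(x, p_low=1.0, p_high=99.0, eps=1e-6):
    """Clip x to its [p_low, p_high] percentile range, then scale to [0, 1]."""
    lo, hi = np.percentile(x, [p_low, p_high])
    clipped = np.clip(x, lo, hi)  # clip first so outliers cannot leave [0, 1]
    scale = max(float(hi - lo), eps)  # guard against constant-valued blocks
    return (clipped - lo) / scale, (float(lo), float(hi))


def denormalize(y, lo, hi):
    """Inverse of the scaling step: map [0, 1] back to input intensity space."""
    return y * (hi - lo) + lo


# Example: a ramp with a single hot pixel that would otherwise dominate the range.
block = np.arange(101, dtype=np.float32)
block[-1] = 10_000.0
normed, (lo, hi) = percentile_normalize(block)
```

A step like `denormalize` is only meaningful when the model predicts values in the same intensity space as its input, which is presumably the situation the `output_denormalize` option is meant to control.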

Robustness and documentation:

  • Improved memory handling for device-to-host transfers, ensuring pinned memory buffers match the actual model output shape and dtype.

  • Added or clarified module-level docstrings for better documentation and maintainability.

New utility functions:

  • Added load_json utility in examples/example_utils/utils.py for loading JSON files from local paths or S3, with error handling.

@camilolaiton requested a review from @carshadi on June 10, 2026, 21:35

@carshadi left a comment

Member

This looks great @camilolaiton! I just had a few comments.

Also, for the run_proteomics_example.py, there are a couple required dependencies that are not packaged here:
https://github.com/AllenNeuralDynamics/aind-lightsheet-mae
https://github.com/AllenNeuralDynamics/aind-proteomics-image-translator/tree/dev
I was thinking this could be resolved in two ways:

  1. add documentation to the examples on installing the necessary deps and how to run the example, and/or
  2. add the required deps as an optional dependency group to the pyproject.toml, e.g., pip install .[proteomics] - this would be useful if we plan on using this library for proteomics in general. Though maybe this could be deferred to when the related codebases are released.

Additionally, the aind-proteomics-image-translator pins numpy<2, but this library requires numpy>=2 since it uses the copy argument. I had no problems running the example after upgrading numpy to 2.4.6, so maybe that upper bound isn't necessary? https://github.com/AllenNeuralDynamics/aind-proteomics-image-translator/blob/2791cc4c5c81c5ef5a275a49b5bc96f2ebbfa28a/pyproject.toml#L22

Comment thread src/aind_torch_utils/workers.py
# Ensure the tensor has a channel dimension that matches writers
if out_np.ndim == 4:
    # legacy single-output (B, pz, py, px): add channel dim
    out_np = out_np[:, np.newaxis]

Member

This will silently drop extra model output channels when fewer output stores are provided. A (B, 2, Z, Y, X) output with one store writes channel 0 and discards channel 1. We should add an explicit out_np.shape[1] == len(self.writers) validation and a test for mismatched counts.

Collaborator Author

Good catch!! thank you!!
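
A minimal sketch of the validation discussed in this thread, assuming the worker holds a NumPy batch and knows how many output stores were configured. The helper name `split_output_channels` and its signature are hypothetical and are not the PR's code.

```python
import numpy as np


def split_output_channels(out_np, n_writers):
    """Return one (B, Z, Y, X) array per output channel, or raise on mismatch."""
    if out_np.ndim == 4:
        # Legacy single-output (B, Z, Y, X): add the channel dimension.
        out_np = out_np[:, np.newaxis]
    if out_np.ndim != 5:
        raise ValueError(f"expected a 4D or 5D model output, got {out_np.ndim}D")
    n_channels = out_np.shape[1]
    if n_channels != n_writers:
        # Fail loudly instead of silently dropping or missing channels.
        raise ValueError(
            f"model returned {n_channels} output channel(s) "
            f"but {n_writers} output store(s) were provided"
        )
    return [out_np[:, c] for c in range(n_channels)]
```

A test for mismatched counts could then assert that a (B, 2, Z, Y, X) output with a single store raises `ValueError`.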

) # floor to avoid near-zero weights

# Output
output_denormalize: bool = Field(

Member

This value is not exposed in the run.py CLI. CLI users running probability/segmentation outputs with normalization enabled will get values rescaled back into input intensity space.

@camilolaiton commented on Jun 11, 2026

Collaborator Author

You're right. I'll also add a validation step to make sure this stays deactivated when running segmentation.
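
One way this could look on the CLI side, as a sketch only. The flag names `--output-denormalize` and `--task`, the default of `True`, and the way segmentation is detected are all assumptions made for illustration; the actual run.py options may differ.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="inference CLI (sketch)")
    # BooleanOptionalAction (Python 3.9+) creates both --output-denormalize
    # and --no-output-denormalize.
    parser.add_argument(
        "--output-denormalize",
        action=argparse.BooleanOptionalAction,
        default=True,  # assumed default
        help="inverse-normalize model outputs before writing",
    )
    # Hypothetical switch used here to tell the two output types apart.
    parser.add_argument(
        "--task", choices=["regression", "segmentation"], default="regression"
    )
    return parser


def parse_args(argv=None):
    parser = build_parser()
    args = parser.parse_args(argv)
    if args.task == "segmentation" and args.output_denormalize:
        # Probabilities or labels are not in input intensity space.
        parser.error("--task segmentation requires --no-output-denormalize")
    return args
```

Rejecting the combination at parse time keeps the failure close to the user's command line rather than surfacing later as rescaled probability maps.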

Comment thread examples/run_proteomics_example.py
    devices = args.devices
else:
    n = torch.cuda.device_count()
    devices = [f"cuda:{i}" for i in range(n)] if n > 0 else ["cpu"]

Member

This falls back to ["cpu"] when no CUDA device is available, but GpuWorker is CUDA-only. On a machine without a GPU this path will crash, so it would be better to fail early with a clear GPU-required message.
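
A sketch of the fail-early behavior suggested here. The helper `resolve_devices` is hypothetical; it takes the CUDA device count as a plain integer so the logic can be exercised without a GPU.

```python
from typing import List, Optional


def resolve_devices(requested: Optional[List[str]], cuda_device_count: int) -> List[str]:
    """Return the CUDA devices to use, or raise if none are usable."""
    if requested:
        non_cuda = [d for d in requested if not d.startswith("cuda")]
        if non_cuda:
            raise RuntimeError(
                f"unsupported device(s) {non_cuda}: this pipeline is CUDA-only"
            )
        return list(requested)
    if cuda_device_count <= 0:
        # No CPU fallback: fail before any worker is spawned.
        raise RuntimeError("no CUDA device found; a GPU is required to run this example")
    return [f"cuda:{i}" for i in range(cuda_device_count)]
```

In the example script this would presumably be called as `resolve_devices(args.devices, torch.cuda.device_count())`.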

s3_path = f"{args.out_prefix.rstrip('/')}/{name}.zarr/"
logger.info(f"Output store: s3://{args.out_bucket}/{s3_path}")
out_stores.append(
    _open_or_create_s3_zarr(args.out_bucket, s3_path, vol_shape, out_chunks)

Member

Output stores are created with the full input shape (T, C, Z, Y, X) but only channel c_idx / timepoint t_idx is ever written, while the OME-NGFF metadata advertises all C channels. A C=3 input yields output zarrs where 2 of 3 channels are all-zero, while write_ome_ngff_metadata describes 3 populated channels. I would limit the output shape to (1, 1, Z, Y, X) just to be safe.
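
The shape change proposed here could be sketched as follows. The helper name `single_volume_shape` is hypothetical, and the example assumes the input shape is always ordered (T, C, Z, Y, X).

```python
from typing import Sequence, Tuple


def single_volume_shape(vol_shape: Sequence[int]) -> Tuple[int, int, int, int, int]:
    """Collapse the T and C axes of a (T, C, Z, Y, X) shape to size 1."""
    if len(vol_shape) != 5:
        raise ValueError(f"expected a 5D (T, C, Z, Y, X) shape, got {tuple(vol_shape)}")
    _, _, z, y, x = vol_shape
    return (1, 1, int(z), int(y), int(x))


# A C=3 input now yields an output store that holds only the processed channel.
out_shape = single_volume_shape((1, 3, 64, 256, 256))
```

With this shape, writes would target index 0 on the T and C axes instead of t_idx and c_idx, and the OME-NGFF metadata would describe a single channel.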

Comment thread examples/run_proteomics_example.py
from aind_torch_utils.model_registry import ModelRegistry


class SharedEncoderModel(nn.Module):

Member

It looks like this class isn't used anywhere outside of the test. The ProteinSharedModel example it was built for duplicates its forward logic verbatim instead of reusing it. Does it make sense to remove the SharedEncoderModel and add the ProteinSharedModel here? Or if you envisioned having this as an extensible class, we could do something like this:

class SharedEncoderModel(nn.Module):
    def __init__(self, encoder, decoders, apply_sigmoid=False):
        super().__init__()
        self.encoder = encoder
        self.decoders = nn.ModuleList(decoders)
        self.apply_sigmoid = apply_sigmoid

    def encode(self, x):
        return self.encoder(x)

    def decode(self, decoder, encoded, x):
        return decoder(encoded)

    def forward(self, x):
        encoded = self.encode(x)
        outputs = [self.decode(dec, encoded, x) for dec in self.decoders]
        out = torch.stack(outputs, dim=1).squeeze(2)
        if self.apply_sigmoid:
            out = torch.clamp(torch.sigmoid(out), min=0.01, max=0.99)
        return out

and then have the protein-specific model only override the necessary parts, e.g.

class ProteinSharedModel(SharedEncoderModel):
    def __init__(self, encoder, decoders, recover_layers=(2, 5), apply_sigmoid=True):
        super().__init__(encoder, decoders, apply_sigmoid=apply_sigmoid)
        self.recover_layers = recover_layers

    def encode(self, x):
        latent, _, _, _, feature_maps, _ = self.encoder(
            x=x,
            mask_ratio=0.0,
            recover_layers=self.recover_layers,
        )
        return latent, feature_maps

    def decode(self, decoder, encoded, x):
        latent, feature_maps = encoded
        return decoder(x=x, latent=latent, hidden_states_out=feature_maps)

Or if this model class is already defined in a different library, we could just import it as a dependency and register the model similarly to the denoising Unet. But then we would need to change the model loading logic to support multiple weights files and model-specific kwargs, which would require some refactoring and potentially a more config-based approach instead of just passing --weights. However, that would allow it to be run under the current CLI setup and would likely be more flexible overall. We could worry about that later.

What do you think?
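
To make the shape handling in the suggested forward explicit, here is a small NumPy stand-in for the `torch.stack(...).squeeze(2)` step. It assumes each decoder returns a tensor with a singleton channel axis, (B, 1, Z, Y, X), which is inferred from the `squeeze(2)` call rather than stated in the PR.

```python
import numpy as np

batch, depth, height, width = 2, 4, 8, 8
n_decoders = 3

# One (B, 1, Z, Y, X) array per decoder (assumed decoder output shape).
decoder_outputs = [
    np.zeros((batch, 1, depth, height, width), dtype=np.float32)
    for _ in range(n_decoders)
]

stacked = np.stack(decoder_outputs, axis=1)  # (B, N, 1, Z, Y, X)
out = np.squeeze(stacked, axis=2)            # (B, N, Z, Y, X)
```

One difference worth keeping in mind: `np.squeeze` raises if the given axis is not of size 1, whereas `torch.squeeze` with a `dim` argument leaves the tensor unchanged in that case, so a decoder that returns more than one channel would pass through the PyTorch version as a 6D tensor without an error.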

Collaborator Author

Yeah, I was actually thinking of it as a parent class, but after reflecting on it, we should just keep the protein one. Regarding the model loading logic, that's a tricky one because of the refactoring.

But in the protein prediction scenario, it'd be better to refactor it toward a more config-based approach, so that we could support either a single set of model weights or this shared-encoder approach. Would that be okay with how you want to set up your package?

I am also okay with handling this directly in a Python script specific to proteomics, since it's an optimization scenario for large datasets. I do not think it's something a lot of people will commonly do.

Member

Yeah I think it would be useful to move towards the config approach since that is how most ML libraries do it, but that can be a milestone and not necessary for this PR. Maybe we just keep the protein model for now?

def forward(
    self,
    x: torch.Tensor,
    apply_sigmoid: Optional[bool] = False,

Member

apply_sigmoid is unreachable in the pipeline: GpuWorker calls self.model(dev_in) with no kwargs.
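
One possible way to make such a flag reachable without changing the worker is to bind it when the model is constructed. The sketch below uses plain Python callables to stay self-contained; `FixedKwargsModel` and `toy_model` are hypothetical names, and in the real pipeline the wrapper would presumably subclass `nn.Module` so that calls like `.to(device)` and `.eval()` keep working.

```python
import math


class FixedKwargsModel:
    """Bind forward-time keyword arguments once, at construction."""

    def __init__(self, model, **forward_kwargs):
        self.model = model
        self.forward_kwargs = forward_kwargs

    def __call__(self, x):
        # The caller only passes the input, so the bound kwargs are added here.
        return self.model(x, **self.forward_kwargs)


def toy_model(x, apply_sigmoid=False):
    """Stand-in for a model whose forward accepts apply_sigmoid."""
    return 1.0 / (1.0 + math.exp(-x)) if apply_sigmoid else x


wrapped = FixedKwargsModel(toy_model, apply_sigmoid=True)
```

Moving `apply_sigmoid` into the model's `__init__` has the same effect and avoids the extra wrapper.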

@camilolaiton

Collaborator Author

Thank you for taking a look, @carshadi. Great comments; I will address them! As for the required deps, I haven't released the codebase yet but will do so soon. I can update the numpy dependency directly in aind-proteomics-image-translator to avoid incompatibilities.
