
attention block for obs assimilation #70

Open
bnb32 wants to merge 29 commits into main from bnb/attn_block

Conversation

Collaborator

bnb32 commented Mar 5, 2026

No description provided.


codecov-commenter commented Mar 5, 2026

Codecov Report

❌ Patch coverage is 77.11111% with 103 lines in your changes missing coverage. Please review.
✅ Project coverage is 87.08%. Comparing base (87be189) to head (8aa248b).

Files with missing lines | Patch % | Lines
phygnn/layers/custom_layers.py | 64.10% | 98 Missing ⚠️
tests/test_layers.py | 98.12% | 3 Missing ⚠️
phygnn/base.py | 87.50% | 2 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             main      #70      +/-   ##
==========================================
- Coverage   88.36%   87.08%   -1.28%     
==========================================
  Files          24       24              
  Lines        3583     3980     +397     
==========================================
+ Hits         3166     3466     +300     
- Misses        417      514      +97     
Flag | Coverage Δ
unittests | 87.08% <77.11%> (-1.28%) ⬇️

Flags with carried forward coverage won't be shown.

☔ View full report in Codecov by Sentry.

bnb32 requested a review from grantbuster March 5, 2026 20:54
Member

grantbuster left a comment


Generally I think it looks good, but I kind of want to hear your plan for training positions vs. inference global/chunk positions. This is also pretty experimental and I just want to see how it goes. I've been having a tough time getting this to work haha.

Comment thread phygnn/layers/custom_layers.py Outdated
else tf.keras.layers.AveragePooling2D(**kwargs)
)

pos = tf.range(x.shape[dim_index]) / x.shape[dim_index]
Member

I'd prefer linspace for closed boundaries, but this is fine too.

So we're going with relative positional encodings, right? Have you thought about position 0-1 in each small training example vs. 0-1 in the inference space vs. 0-1 in each fwp chunk? Even if you have a plan, we'll have to verify this works. I think there could be issues where your encoding frequencies work well with the training sample size in 0-1 but poorly when you have a much larger inference size also in 0-1, because the spatial resolution the transformer effectively sees changes. Not confident about this, but I think it's something we might have to experiment with.

Collaborator Author

Agreed, linspace is better.

Good points. I hadn't thought about this potential issue with effective resolution. I'd expected that as long as fwp chunk shapes are a similar size to batch shapes this would be fine, but maybe lat / lon / time features should actually be required inputs for this.
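For reference, a minimal sketch of the two normalization options discussed above (the variable names and sizes are illustrative, not from the PR):

```python
import tensorflow as tf

n = 8  # number of positions along one spatial dimension (illustrative)

# range-based approach: open interval, the last position never reaches 1.0
pos_range = tf.range(n, dtype=tf.float32) / n   # 0, 1/8, ..., 7/8

# linspace alternative: closed boundaries at exactly 0.0 and 1.0
pos_linspace = tf.linspace(0.0, 1.0, n)         # 0, 1/7, ..., 1.0
```

Either way, a larger inference domain still maps onto the same 0-1 range, which is the effective-resolution concern raised above.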

Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
Dropout rate for attention weights. Default is 0 (no dropout).
patch_size : int
Height, width, and depth of patches. Default is 1 for pixel-wise
tokenization.
Member

Patch size of the Sup3r hypercube, right? Not of the key/value obs?

Collaborator Author

If the key data weren't sparse, the patch size would be used for the key data as well. When the key data is sparse, the patch size is only used for the query data and patch_size = 1 is used for the key data.
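To illustrate the distinction, a rough sketch only; `patch_tokens` and the shapes are made up for this example, not the PR's actual tokenizer:

```python
import tensorflow as tf

def patch_tokens(x, patch_size):
    """Tokenize a dense (batch, height, width, channels) tensor into
    non-overlapping patch_size x patch_size patches."""
    b, h, w, c = x.shape
    x = tf.reshape(x, (b, h // patch_size, patch_size, w // patch_size, patch_size, c))
    x = tf.transpose(x, (0, 1, 3, 2, 4, 5))
    return tf.reshape(
        x, (b, (h // patch_size) * (w // patch_size), patch_size * patch_size * c)
    )

query = tf.random.normal((2, 16, 16, 4))  # dense Sup3r hypercube
obs = tf.random.normal((2, 16, 16, 1))    # sparse obs (would carry NaNs in practice)

q_tok = patch_tokens(query, patch_size=4)  # (2, 16, 64): patch-wise query tokens
k_tok = patch_tokens(obs, patch_size=1)    # (2, 256, 1): pixel-wise key tokens
```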

Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
)
q_enc += self._pos_encoding(
x, patch_size=self.patch_size, dim_index=2, embed_dim=embed_dim
)
Member

I think traditionally the row/column positional encodings are each computed over embed_dim//2 channels, then concatenated and added to the tensor x. Yes, just confirmed this is at least how SatMAE does it.

Collaborator Author

Hmm, then how would that change with additional depth/time encoding? Seems simplest to use embed_dim channels for each encoding.
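For comparison, a hedged sketch of the SatMAE-style convention described above, where the row and column encodings each get embed_dim // 2 channels and are concatenated (sinusoidal_1d and all sizes here are illustrative, not the PR's code):

```python
import tensorflow as tf

def sinusoidal_1d(positions, channels):
    """Standard sin/cos encoding of a 1D position vector (channels must be even)."""
    i = tf.range(channels // 2, dtype=tf.float32)
    freqs = 1.0 / tf.pow(10000.0, 2.0 * i / channels)
    theta = positions[:, None] * freqs[None, :]
    return tf.concat([tf.sin(theta), tf.cos(theta)], axis=-1)

embed_dim, h, w = 64, 8, 8
row_enc = sinusoidal_1d(tf.linspace(0.0, 1.0, h), embed_dim // 2)  # (h, 32)
col_enc = sinusoidal_1d(tf.linspace(0.0, 1.0, w), embed_dim // 2)  # (w, 32)

# Broadcast each axis encoding to the grid and concatenate along channels
# -> (h, w, embed_dim)
grid_enc = tf.concat(
    [tf.tile(row_enc[:, None, :], (1, w, 1)),
     tf.tile(col_enc[None, :, :], (h, 1, 1))],
    axis=-1,
)
```

With a third (depth/time) axis, embed_dim would have to be split three ways, which is presumably the complication being raised here.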

Comment thread phygnn/layers/custom_layers.py Outdated
self.key_dim = key_dim
self.patch_size = patch_size
self.embed_dim = embed_dim
self.dropout = tf.keras.layers.Dropout(dropout)
Member

I think standard attention mechanisms, including tf.keras.layers.MultiHeadAttention, might include the "output" projections and dropout as part of the MHA class. You can confirm this by checking the MHA attribute "output_dense".

Collaborator Author

Good call on dropout. I included the extra projection since you can change the attention output shape with the output_shape arg, but maybe that flexibility isn't worth it.
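For reference, a minimal sketch of the built-in arguments being discussed (the shapes here are illustrative):

```python
import tensorflow as tf

# MultiHeadAttention already owns the attention-weight dropout and the final
# output projection, so a separate Dropout/Dense after the layer may be redundant.
mha = tf.keras.layers.MultiHeadAttention(
    num_heads=4,
    key_dim=16,
    dropout=0.1,      # applied to the attention weights
    output_shape=32,  # size of the built-in output projection
)

query = tf.random.normal((2, 64, 32))  # (batch, query tokens, features)
obs = tf.random.normal((2, 10, 8))     # (batch, obs tokens, features)
out = mha(query=query, value=obs, key=obs, training=True)  # (2, 64, 32)
```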

Comment thread phygnn/layers/custom_layers.py Outdated
Collaborator Author

bnb32 commented Mar 11, 2026

Generally I think it looks good, but I kind of want to hear your plan for training positions vs. inference global/chunk positions. This is also pretty experimental and I just want to see how it goes. I've been having a tough time getting this to work haha.

Yeah, I thought about the inference issue a bit before opting for the relative encoding. I decided that adding obs doesn't change the currently operative assumption that a fwp chunk includes everything that significantly influences the output (if not, then we should increase padding).

…size parameters and refactor encoding methods
bnb32 requested review from Copilot and grantbuster March 16, 2026 20:39

Copilot AI left a comment


Pull request overview

This PR introduces a new cross-attention layer intended for observation assimilation, adds unit tests around tokenization/positional encoding behavior, and modernizes linting by switching from Flake8/super-linter to Ruff.

Changes:

  • Added TokenizeEncodeBase + Sup3rCrossAttention layers for tokenization + positional encoding + cross-attention.
  • Expanded tests to cover cross-attention (2D/3D), patch sizes > 1, and positional encoding pooling correctness.
  • Updated dev tooling/CI linting to Ruff and adjusted Pixi workspace config.

Reviewed changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 12 comments.

Show a summary per file
File | Description
phygnn/layers/custom_layers.py | Adds tokenization/positional encoding base + new cross-attention layer; removes older attention layers.
tests/test_layers.py | Adds tests for cross-attention behavior and positional encoding pooling with patch sizes > 1.
.github/workflows/linter.yml | Switches linting to Ruff action and runs on PRs.
pyproject.toml | Replaces Flake8 with Ruff in dev deps; updates Pixi config section name.
pixi.lock | Updates locked deps to reflect tooling changes.
phygnn/model_interfaces/base_model.py | Minor docstring capitalization + small style tweak.
phygnn/base.py | Formatting-only changes.
.github/linters/.python-lint / .github/linters/.flake8 | Removes legacy linter config files.


Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
Comment on lines +136 to +139
i = tf.range(d // 2, dtype=tf.float32)
theta = k / tf.pow(omega, (2 * i / d))
enc[..., ::2] = tf.math.sin(theta)
enc[..., 1::2] = tf.math.cos(theta)
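Plain tf.Tensor objects don't support item assignment, so the strided writes above would fail at runtime. One way to get the same interleaved sin/cos layout, as a sketch assuming theta has shape (..., d // 2), with illustrative values for k, omega, and d:

```python
import tensorflow as tf

d = 8
omega = 10000.0
k = tf.linspace(0.0, 1.0, 16)[:, None]      # example positions, shape (16, 1)
i = tf.range(d // 2, dtype=tf.float32)
theta = k / tf.pow(omega, 2.0 * i / d)      # (16, d // 2)

# Stack sin/cos on a new trailing axis, then flatten so the channels read
# sin, cos, sin, cos, ... exactly like enc[..., ::2] / enc[..., 1::2] would.
enc = tf.reshape(tf.stack([tf.sin(theta), tf.cos(theta)], axis=-1), (-1, d))
```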
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
Comment on lines +318 to +321
slices = []
for pad in pads:
slices.append(slice(pad[0], -pad[1] if pad[1] > 0 else None))
return out[tuple(slices)]
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
x_enc += self._pos_encode(x_pad, dim_index=2)
if self.rank == 5:
x_enc += self._pos_encode(x_pad, dim_index=3)
return x_tok, x_enc, x_pad.shape
Comment thread phygnn/layers/custom_layers.py Outdated
…ic padding; enhance positional encoding methods

Copilot AI left a comment


Pull request overview

This PR introduces a cross-attention block intended for observation (sparse NaN) assimilation by adding tokenization + positional encoding utilities, and updates the repo’s linting/tooling to use Ruff.

Changes:

  • Added TokenizeEncodeBase and Sup3rCrossAttention layers (including patch-based tokenization and positional encoding) and removed older attention layer implementations.
  • Expanded layer tests to cover cross-attention behavior (2D/3D, patch sizes, and positional encoding pooling).
  • Migrated dev lint tooling from Flake8/super-linter to Ruff (pyproject + GitHub Actions), and adjusted Pixi config section naming.

Reviewed changes

Copilot reviewed 8 out of 9 changed files in this pull request and generated 8 comments.

Show a summary per file
File | Description
phygnn/layers/custom_layers.py | Adds tokenization/positional-encoding base and new cross-attention layer; removes prior attention classes.
tests/test_layers.py | Adds cross-attention and positional-encoding tests.
.github/workflows/linter.yml | Switches CI linting to Ruff GitHub Action and adds PR trigger.
pyproject.toml | Replaces Flake8 with Ruff in dev deps; updates Pixi section header.
pixi.lock | Lockfile updates reflecting dependency/tooling changes.
phygnn/model_interfaces/base_model.py | Minor docstring and small style change (in {1, 2}).
phygnn/base.py | Formatting-only changes.
.github/linters/.python-lint / .github/linters/.flake8 | Removes legacy linter configuration files.


Comment thread .github/workflows/linter.yml
Comment thread tests/test_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py
Comment thread phygnn/layers/custom_layers.py Outdated
Comment thread phygnn/layers/custom_layers.py
Comment thread phygnn/layers/custom_layers.py Outdated
bnb32 added 10 commits March 21, 2026 10:58
- Introduced MultiHeadAttention class for custom multi-head attention with bias input.
- Added Sup3rCrossAlibi for cross attention with distance bias instead of absolute position encoding.
… parameter; update tests for PositionEncoder and Tokenizer integration
…per position of attention inputs and pos encoding. Shaw and Alibi layers use alternative approaches.
…nd add utility to compute day of year and second of year from timestamps
