Skip to content

Commit e0ff06b

Browse files
authored
Merge branch 'main' into add-autojac-jac
2 parents 25342e2 + 81648e5 commit e0ff06b

File tree

10 files changed

+26
-17
lines changed

10 files changed

+26
-17
lines changed

.github/actions/install-deps/action.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ inputs:
1616
runs:
1717
using: composite
1818
steps:
19+
- name: Create virtual environment
20+
shell: bash
21+
run: uv venv
22+
1923
- name: Install dependencies (options=[${{ inputs.options }}], groups=[${{ inputs.groups }}])
2024
shell: bash
2125
run: |

.github/workflows/build-deploy-docs.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,10 @@ jobs:
1818
contents: write
1919
steps:
2020
- name: Checkout repository
21-
uses: actions/checkout@v4
21+
uses: actions/checkout@v6
2222

2323
- name: Set up uv
24-
uses: astral-sh/setup-uv@v5
24+
uses: astral-sh/setup-uv@v7
2525
with:
2626
python-version: '3.14'
2727

.github/workflows/check-todos.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ jobs:
88
runs-on: ubuntu-latest
99
steps:
1010
- name: Checkout code
11-
uses: actions/checkout@v4
11+
uses: actions/checkout@v6
1212

1313
- name: Scan for TODO strings
1414
run: |

.github/workflows/claude.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ jobs:
2626
actions: read # Required for Claude to read CI results on PRs
2727
steps:
2828
- name: Checkout repository
29-
uses: actions/checkout@v4
29+
uses: actions/checkout@v6
3030
with:
3131
fetch-depth: 1
3232

.github/workflows/release.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ jobs:
1414
id-token: write
1515
steps:
1616
- name: Checkout repository
17-
uses: actions/checkout@v4
17+
uses: actions/checkout@v6
1818

1919
- name: Set up uv
20-
uses: astral-sh/setup-uv@v5
20+
uses: astral-sh/setup-uv@v7
2121
with:
2222
python-version: '3.14'
2323

.github/workflows/tests.yml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,10 @@ jobs:
3535

3636
steps:
3737
- name: Checkout repository
38-
uses: actions/checkout@v4
38+
uses: actions/checkout@v6
3939

4040
- name: Set up uv
41-
uses: astral-sh/setup-uv@v5
41+
uses: astral-sh/setup-uv@v7
4242
with:
4343
python-version: ${{ matrix.python-version || '3.14' }}
4444

@@ -62,10 +62,10 @@ jobs:
6262
runs-on: ubuntu-latest
6363
steps:
6464
- name: Checkout repository
65-
uses: actions/checkout@v4
65+
uses: actions/checkout@v6
6666

6767
- name: Set up uv
68-
uses: astral-sh/setup-uv@v5
68+
uses: astral-sh/setup-uv@v7
6969
with:
7070
python-version: '3.14'
7171

@@ -82,10 +82,10 @@ jobs:
8282
runs-on: ubuntu-latest
8383
steps:
8484
- name: Checkout repository
85-
uses: actions/checkout@v4
85+
uses: actions/checkout@v6
8686

8787
- name: Set up uv
88-
uses: astral-sh/setup-uv@v5
88+
uses: astral-sh/setup-uv@v7
8989
with:
9090
python-version: '3.14'
9191

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@ changelog does not include internal changes that do not affect the user.
1515
Its interface is analog to that of `torch.autograd.grad`.
1616
- Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose
1717
between `"min"`, `"median"`, and `"rmse"` scaling.
18+
- Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`.
19+
Usage is still the same, `aggregator.gramian_weighting` is just an alias for the (quite confusing)
20+
`aggregator.weighting.weighting` field.
1821

1922
### Changed
2023

docs/source/examples/monitoring.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ they have a negative inner product).
4949
optimizer = SGD(params, lr=0.1)
5050
aggregator = UPGrad()
5151
52-
aggregator.weighting.weighting.register_forward_hook(print_weights)
52+
aggregator.gramian_weighting.register_forward_hook(print_weights)
5353
aggregator.register_forward_hook(print_gd_similarity)
5454
5555
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10

src/torchjd/aggregation/_aggregator_bases.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,10 @@ class GramianWeightedAggregator(WeightedAggregator):
7373
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
7474
Weighting to it.
7575
76-
:param weighting: The object responsible for extracting the vector of weights from the gramian.
76+
:param gramian_weighting: The object responsible for extracting the vector of weights from the
77+
gramian.
7778
"""
7879

79-
def __init__(self, weighting: Weighting[PSDMatrix]):
80-
super().__init__(weighting << compute_gramian)
80+
def __init__(self, gramian_weighting: Weighting[PSDMatrix]):
81+
super().__init__(gramian_weighting << compute_gramian)
82+
self.gramian_weighting = gramian_weighting

tests/doc/test_rst.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch.
308308
optimizer = SGD(params, lr=0.1)
309309
aggregator = UPGrad()
310310

311-
aggregator.weighting.weighting.register_forward_hook(print_weights)
311+
aggregator.gramian_weighting.register_forward_hook(print_weights)
312312
aggregator.register_forward_hook(print_gd_similarity)
313313

314314
inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10

0 commit comments

Comments
 (0)