diff --git a/CHANGELOG.md b/CHANGELOG.md index aa14a0a7..20ada67c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,9 @@ changelog does not include internal changes that do not affect the user. - Added a `scale_mode` parameter to `AlignedMTL` and `AlignedMTLWeighting`, allowing to choose between `"min"`, `"median"`, and `"rmse"` scaling. +- Added an attribute `gramian_weighting` to all aggregators that use a gramian-based `Weighting`. + Usage is still the same, `aggregator.gramian_weighting` is just an alias for the (quite confusing) + `aggregator.weighting.weighting` field. ### Changed diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 69cc0e1b..c17f96f8 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -49,7 +49,7 @@ they have a negative inner product). optimizer = SGD(params, lr=0.1) aggregator = UPGrad() - aggregator.weighting.weighting.register_forward_hook(print_weights) + aggregator.gramian_weighting.register_forward_hook(print_weights) aggregator.register_forward_hook(print_gd_similarity) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 959b1485..6935199b 100644 --- a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -73,8 +73,10 @@ class GramianWeightedAggregator(WeightedAggregator): WeightedAggregator that computes the gramian of the input jacobian matrix before applying a Weighting to it. - :param weighting: The object responsible for extracting the vector of weights from the gramian. + :param gramian_weighting: The object responsible for extracting the vector of weights from the + gramian. """ - def __init__(self, weighting: Weighting[PSDMatrix]): - super().__init__(weighting << compute_gramian) + def __init__(self, gramian_weighting: Weighting[PSDMatrix]): + super().__init__(gramian_weighting << compute_gramian) + self.gramian_weighting = gramian_weighting diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index cf1846d7..6513d41c 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -308,7 +308,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. optimizer = SGD(params, lr=0.1) aggregator = UPGrad() - aggregator.weighting.weighting.register_forward_hook(print_weights) + aggregator.gramian_weighting.register_forward_hook(print_weights) aggregator.register_forward_hook(print_gd_similarity) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10