From 4bb425ab5fd09ec5b6b712b69cbafaea600c2139 Mon Sep 17 00:00:00 2001 From: Matthieu Buot de l'Epine Date: Sat, 20 Dec 2025 12:39:16 +0100 Subject: [PATCH] doc: Fix weighting hook registration --- docs/source/examples/monitoring.rst | 2 +- tests/doc/test_rst.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/source/examples/monitoring.rst b/docs/source/examples/monitoring.rst index 44426667..8ec675aa 100644 --- a/docs/source/examples/monitoring.rst +++ b/docs/source/examples/monitoring.rst @@ -49,7 +49,7 @@ they have a negative inner product). optimizer = SGD(params, lr=0.1) aggregator = UPGrad() - aggregator.weighting.register_forward_hook(print_weights) + aggregator.weighting.weighting.register_forward_hook(print_weights) aggregator.register_forward_hook(print_gd_similarity) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 diff --git a/tests/doc/test_rst.py b/tests/doc/test_rst.py index 53d92ed2..b64b504c 100644 --- a/tests/doc/test_rst.py +++ b/tests/doc/test_rst.py @@ -300,7 +300,7 @@ def print_gd_similarity(_, inputs: tuple[torch.Tensor, ...], aggregation: torch. optimizer = SGD(params, lr=0.1) aggregator = UPGrad() - aggregator.weighting.register_forward_hook(print_weights) + aggregator.weighting.weighting.register_forward_hook(print_weights) aggregator.register_forward_hook(print_gd_similarity) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10