From 0aa1c8bfadcb4ad03ada491a942862c68a9d7af1 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 12:20:15 -0400 Subject: [PATCH 01/14] Add CRMOGMWeighting from NeurIPS 2022 --- CHANGELOG.md | 6 +- docs/source/docs/aggregation/cr_mogm.rst | 15 +++ docs/source/docs/aggregation/index.rst | 1 + src/torchjd/aggregation/__init__.py | 2 + src/torchjd/aggregation/_cr_mogm.py | 123 +++++++++++++++++++ tests/unit/aggregation/test_cr_mogm.py | 148 +++++++++++++++++++++++ 6 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 docs/source/docs/aggregation/cr_mogm.rst create mode 100644 src/torchjd/aggregation/_cr_mogm.py create mode 100644 tests/unit/aggregation/test_cr_mogm.py diff --git a/CHANGELOG.md b/CHANGELOG.md index fe52e1fe..3b20f2db 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `CRMOGMWeighting` from + [Conflict-Reduction Multi-Objective Gradient Methods](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4e91f0648fb6e09f0156a7eaf6c4dfdb-Abstract-Conference.html). + It wraps an existing `Weighting` and stabilises its weights with an exponential moving average + across calls. - Added getters and setters for the constructor parameters of all aggregators and weightings, so that they can be changed after initialization. This includes: `pref_vector`, `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`; @@ -18,7 +22,7 @@ changelog does not include internal changes that do not affect the user. `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`; `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor - checks. Note that setters for `GradVac` and `GradVacWeighting` already existed. + checks. ## [0.10.0] - 2026-04-16 diff --git a/docs/source/docs/aggregation/cr_mogm.rst b/docs/source/docs/aggregation/cr_mogm.rst new file mode 100644 index 00000000..47e70f49 --- /dev/null +++ b/docs/source/docs/aggregation/cr_mogm.rst @@ -0,0 +1,15 @@ +:hide-toc: + +CR-MOGM +======= + +.. autoclass:: torchjd.aggregation.CRMOGMWeighting + :members: __call__, reset + +.. note:: + The usage example in the docstring above imports + ``WeightedAggregator`` / ``GramianWeightedAggregator`` from + ``torchjd.aggregation._aggregator_bases``, which is a private module. These two + aggregator base classes are not currently part of the public ``torchjd.aggregation`` + namespace, so this private-module import is the only path that works today. Promoting + them to the public namespace is a separate decision left to the maintainers. diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 4d62f820..98725ef3 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -29,6 +29,7 @@ Abstract base classes cagrad.rst config.rst constant.rst + cr_mogm.rst dualproj.rst flattening.rst graddrop.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 400cfe27..d36a6d82 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -63,6 +63,7 @@ from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting from ._config import ConFIG from ._constant import Constant, ConstantWeighting +from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop @@ -89,6 +90,7 @@ "ConFIG", "Constant", "ConstantWeighting", + "CRMOGMWeighting", "DualProj", "DualProjWeighting", "Flattening", diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py new file mode 100644 index 00000000..0cb7e4d8 --- /dev/null +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -0,0 +1,123 @@ +from __future__ import annotations + +from typing import TypeVar, cast + +import torch +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful + +from ._weighting_bases import Weighting + +_T = TypeVar("_T", contravariant=True, bound=Tensor) + + +class CRMOGMWeighting(Weighting[_T], Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another + :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it + produces with an exponential moving average (EMA) across calls. This is the weight-smoothing + modifier from `Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022) + `_. + + Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step + :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are: + + .. math:: + + \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k + + with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top + \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first + forward call once :math:`m` is known and is reset automatically when ``m``, ``dtype`` or + ``device`` of the input changes. + + Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a + :class:`~torchjd.aggregation._weighting_bases.MatrixWeighting` or a + :class:`~torchjd.aggregation._weighting_bases.GramianWeighting`. The user composes it with + the appropriate aggregator base: + + .. code-block:: python + + from torchjd.aggregation import MeanWeighting, UPGradWeighting + from torchjd.aggregation._aggregator_bases import ( + GramianWeightedAggregator, WeightedAggregator, + ) + from torchjd.aggregation._cr_mogm import CRMOGMWeighting + + matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) + + This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` + when restarting the smoothing from uniform weights. + + :param weighting: The wrapped weighting whose output is smoothed. + :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing + (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes + the weights at their initial uniform value. The default of ``0.9`` follows the usual + EMA convention (analogous to Adam's :math:`\beta_1`). + + .. note:: + ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a + schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning + rate decays. Update ``alpha`` between forward calls via the public attribute on the + wrapping aggregator: + + .. code-block:: python + + # With WeightedAggregator + aggregator.weighting.alpha = 1 - current_lr / initial_lr + + # With GramianWeightedAggregator + aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr + """ + + def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None: + super().__init__() + self.weighting = weighting + self.alpha = alpha + self._lambda: Tensor | None = None + self._state_key: tuple[int, torch.dtype, torch.device] | None = None + + @property + def alpha(self) -> float: + return self._alpha + + @alpha.setter + def alpha(self, value: float) -> None: + if not (0.0 <= value <= 1.0): + raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.") + self._alpha = value + + def reset(self) -> None: + """Clears the EMA state so the next forward starts from uniform weights.""" + + self._lambda = None + self._state_key = None + + def forward(self, stat: _T, /) -> Tensor: + device = stat.device + dtype = stat.dtype + m = stat.shape[0] + + self._ensure_state(m, dtype, device) + lambda_prev = cast(Tensor, self._lambda) + + lambda_hat = self.weighting(stat) + lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat + + self._lambda = lambda_k.detach() + return lambda_k + + def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None: + key = (m, dtype, device) + if self._state_key != key or self._lambda is None: + if m > 0: + self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) + else: + self._lambda = torch.zeros(0, dtype=dtype, device=device) + self._state_key = key + + def __repr__(self) -> str: + return f"CRMOGMWeighting(weighting={self.weighting!r}, alpha={self.alpha!r})" diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py new file mode 100644 index 00000000..a7f520a3 --- /dev/null +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -0,0 +1,148 @@ +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation import MeanWeighting, UPGradWeighting +from torchjd.aggregation._aggregator_bases import ( + GramianWeightedAggregator, + WeightedAggregator, +) +from torchjd.aggregation._cr_mogm import CRMOGMWeighting + +from ._asserts import assert_expected_structure +from ._inputs import scaled_matrices, typical_matrices + +# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in +# scaled_matrices, so the gramian-path structural test only uses typical_matrices. +matrix_pairs = [ + (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m) + for m in typical_matrices + scaled_matrices +] +gramian_pairs = [ + (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices +] + + +def test_representations() -> None: + W = CRMOGMWeighting(MeanWeighting(), alpha=0.9) + expected = "CRMOGMWeighting(weighting=MeanWeighting(), alpha=0.9)" + # Weighting does not define __str__, so it falls back to __repr__. + assert repr(W) == expected + assert str(W) == expected + + +@mark.parametrize(["aggregator", "matrix"], matrix_pairs) +def test_expected_structure_matrix_weighting( + aggregator: WeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + """ + Use ``UPGradWeighting`` so the weights actually depend on the input — with + ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would + be trivial. + """ + + J = randn_((3, 8)) + G = J @ J.T + W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_alpha_setter_accepts_valid() -> None: + W = CRMOGMWeighting(MeanWeighting()) + W.alpha = 0.0 + assert W.alpha == 0.0 + W.alpha = 0.5 + assert W.alpha == 0.5 + W.alpha = 1.0 + assert W.alpha == 1.0 + + +def test_alpha_setter_rejects_out_of_range() -> None: + W = CRMOGMWeighting(MeanWeighting()) + with raises(ValueError, match="alpha"): + W.alpha = -0.1 + with raises(ValueError, match="alpha"): + W.alpha = 1.1 + + +def test_alpha_zero_reduces_to_bare_weighting() -> None: + """ + With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights + equal the bare weighting's output on every call — not just the first. + """ + + J = randn_((3, 8)) + G = J @ J.T + bare = UPGradWeighting() + smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0) + + expected = bare(G) + assert_close(smoothed(G), expected) + assert_close(smoothed(G), expected) + + +def test_alpha_one_freezes_weights() -> None: + """ + With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at + their initial uniform value forever. Note: the equality with uniform weights is a + consequence of the uniform initialisation, not a general property of CR-MOGM. + """ + + J = randn_((3, 8)) + m = J.shape[0] + W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0) + uniform = tensor_([1.0 / m] * m) + + assert_close(W(J @ J.T), uniform) + assert_close(W(J @ J.T), uniform) + + +def test_ema_is_applied() -> None: + """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand.""" + + alpha = 0.9 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + m = J1.shape[0] + + bare = UPGradWeighting() + smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha) + + lambda_hat_1 = bare(G1) + lambda_hat_2 = bare(G2) + uniform = tensor_([1.0 / m] * m) + + expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1 + expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2 + + assert_close(smoothed(G1), expected_1) + assert_close(smoothed(G2), expected_2) + + +def test_zero_columns() -> None: + """ + A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row + inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would + raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility. + """ + + aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + out = aggregator(tensor_([]).reshape(2, 0)) + assert out.shape == (0,) From 53f3eb30ee2ab427816b3af8381e101a3fa02c36 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 12:54:21 -0400 Subject: [PATCH 02/14] fix(aggregation): Fix Sphinx cross-reference warnings in CRMOGMWeighting docstring --- src/torchjd/aggregation/_cr_mogm.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 0cb7e4d8..0e34faa2 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -34,9 +34,8 @@ class CRMOGMWeighting(Weighting[_T], Stateful): ``device`` of the input changes. Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a - :class:`~torchjd.aggregation._weighting_bases.MatrixWeighting` or a - :class:`~torchjd.aggregation._weighting_bases.GramianWeighting`. The user composes it with - the appropriate aggregator base: + ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate + aggregator base: .. code-block:: python From 23c0f62fc79b531e37540bc77d9d41813d64f4ba Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 13:38:15 -0400 Subject: [PATCH 03/14] fix(aggregation): Remove broken NeurIPS URL from CRMOGMWeighting --- CHANGELOG.md | 3 +-- src/torchjd/aggregation/_cr_mogm.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 3b20f2db..05596289 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,8 +10,7 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `CRMOGMWeighting` from - [Conflict-Reduction Multi-Objective Gradient Methods](https://proceedings.neurips.cc/paper_files/paper/2022/hash/4e91f0648fb6e09f0156a7eaf6c4dfdb-Abstract-Conference.html). +- Added `CRMOGMWeighting` from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential moving average across calls. - Added getters and setters for the constructor parameters of all aggregators and weightings, so diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 0e34faa2..7c3512a8 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -18,8 +18,7 @@ class CRMOGMWeighting(Weighting[_T], Stateful): :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing - modifier from `Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022) - `_. + modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are: From daf59f9aa484e33bf1e12b76e5fdab150f982577 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 13:49:20 -0400 Subject: [PATCH 04/14] test(aggregation): Cover zero-row branch in CRMOGMWeighting --- tests/unit/aggregation/test_cr_mogm.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index a7f520a3..2fcc1cea 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import MeanWeighting, UPGradWeighting +from torchjd.aggregation import MeanWeighting, SumWeighting, UPGradWeighting from torchjd.aggregation._aggregator_bases import ( GramianWeightedAggregator, WeightedAggregator, @@ -146,3 +146,15 @@ def test_zero_columns() -> None: aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) out = aggregator(tensor_([]).reshape(2, 0)) assert out.shape == (0,) + + +def test_zero_rows() -> None: + """ + Exercises the ``m=0`` branch of ``_ensure_state``. ``SumWeighting`` is used because it + handles zero-row matrices cleanly (``torch.ones(0)``), unlike ``MeanWeighting`` which + would raise ``ZeroDivisionError``. + """ + + W = CRMOGMWeighting(SumWeighting()) + weights = W(tensor_([]).reshape(0, 8)) + assert weights.shape == (0,) From e846a9c2b7c5d23ea4289439fcc010ca13df4ae8 Mon Sep 17 00:00:00 2001 From: Khush Date: Thu, 7 May 2026 22:12:13 -0400 Subject: [PATCH 05/14] refactor(aggregation): Address review feedback on CRMOGMWeighting --- CHANGELOG.md | 10 +++++----- src/torchjd/aggregation/_cr_mogm.py | 27 +++++++++++++------------- tests/unit/aggregation/test_cr_mogm.py | 26 ++++++++++++++++--------- 3 files changed, 35 insertions(+), 28 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 24aff524..5c62e64f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,9 +10,10 @@ changelog does not include internal changes that do not affect the user. ### Added -- Added `CRMOGMWeighting` from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). - It wraps an existing `Weighting` and stabilises its weights with an exponential moving average - across calls. +- Added `CRMOGMWeighting` from [On the Convergence of Stochastic Multi-Objective Gradient + Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf) + (NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential + moving average across calls. - Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting` public. These abstract base classes are now importable from `torchjd.aggregation` and documented. They can be extended to easily implement custom `Weighting`s and `Aggregator`s. @@ -23,8 +24,7 @@ changelog does not include internal changes that do not affect the user. `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`, `n_byzantine` and `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`; - `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor - checks. + `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor checks. Note that setters for `GradVac` and `GradVacWeighting` already existed. ## [0.10.0] - 2026-04-16 diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 7c3512a8..6a0e7bc4 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -18,7 +18,9 @@ class CRMOGMWeighting(Weighting[_T], Stateful): :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it produces with an exponential moving average (EMA) across calls. This is the weight-smoothing - modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022). + modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and + Beyond `_ + (NeurIPS 2022). Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are: @@ -76,7 +78,6 @@ def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None: self.weighting = weighting self.alpha = alpha self._lambda: Tensor | None = None - self._state_key: tuple[int, torch.dtype, torch.device] | None = None @property def alpha(self) -> float: @@ -91,31 +92,29 @@ def alpha(self, value: float) -> None: def reset(self) -> None: """Clears the EMA state so the next forward starts from uniform weights.""" + if isinstance(self.weighting, Stateful): + self.weighting.reset() self._lambda = None - self._state_key = None def forward(self, stat: _T, /) -> Tensor: - device = stat.device - dtype = stat.dtype - m = stat.shape[0] + lambda_hat = self.weighting(stat) - self._ensure_state(m, dtype, device) + self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device) lambda_prev = cast(Tensor, self._lambda) - lambda_hat = self.weighting(stat) lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat self._lambda = lambda_k.detach() return lambda_k def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None: - key = (m, dtype, device) - if self._state_key != key or self._lambda is None: + if ( + self._lambda is None + or self._lambda.shape[0] != m + or self._lambda.dtype != dtype + or self._lambda.device != device + ): if m > 0: self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) else: self._lambda = torch.zeros(0, dtype=dtype, device=device) - self._state_key = key - - def __repr__(self) -> str: - return f"CRMOGMWeighting(weighting={self.weighting!r}, alpha={self.alpha!r})" diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index 2fcc1cea..45a6fb82 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import MeanWeighting, SumWeighting, UPGradWeighting +from torchjd.aggregation import GradVacWeighting, MeanWeighting, SumWeighting, UPGradWeighting from torchjd.aggregation._aggregator_bases import ( GramianWeightedAggregator, WeightedAggregator, @@ -24,14 +24,6 @@ ] -def test_representations() -> None: - W = CRMOGMWeighting(MeanWeighting(), alpha=0.9) - expected = "CRMOGMWeighting(weighting=MeanWeighting(), alpha=0.9)" - # Weighting does not define __str__, so it falls back to __repr__. - assert repr(W) == expected - assert str(W) == expected - - @mark.parametrize(["aggregator", "matrix"], matrix_pairs) def test_expected_structure_matrix_weighting( aggregator: WeightedAggregator, matrix: Tensor @@ -62,6 +54,22 @@ def test_reset_restores_first_step_behavior() -> None: assert_close(first, W(G)) +def test_reset_propagates_to_stateful_weighting() -> None: + """ + Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is + :class:`~torchjd.aggregation.Stateful`. Uses ``GradVacWeighting`` as the inner weighting + because it is both stateful and produces weights that depend on its internal state. + """ + + J = randn_((3, 8)) + G = J @ J.T + W = CRMOGMWeighting(GradVacWeighting(), alpha=0.5) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + def test_alpha_setter_accepts_valid() -> None: W = CRMOGMWeighting(MeanWeighting()) W.alpha = 0.0 From e16cf489456341e9b1b4effbe48976603102031d Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 8 May 2026 09:56:16 -0400 Subject: [PATCH 06/14] refactor(aggregation): Simplify CRMOGMWeighting state logic and improve docstring --- src/torchjd/aggregation/_cr_mogm.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 6a0e7bc4..7ee086e5 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -1,6 +1,6 @@ from __future__ import annotations -from typing import TypeVar, cast +from typing import TypeVar import torch from torch import Tensor @@ -31,11 +31,10 @@ class CRMOGMWeighting(Weighting[_T], Stateful): with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first - forward call once :math:`m` is known and is reset automatically when ``m``, ``dtype`` or - ``device`` of the input changes. + forward call once :math:`m` is known and is reset automatically when ``m`` changes. Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a - ``MatrixWeighting`` or a ``GramianWeighting``. The user composes it with the appropriate + ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding :class:`~torchjd.aggregation.Aggregator` can be done by composing it with the appropriate aggregator base: .. code-block:: python @@ -50,7 +49,8 @@ class CRMOGMWeighting(Weighting[_T], Stateful): gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` - when restarting the smoothing from uniform weights. + when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also + reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`. :param weighting: The wrapped weighting whose output is smoothed. :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing @@ -99,22 +99,17 @@ def reset(self) -> None: def forward(self, stat: _T, /) -> Tensor: lambda_hat = self.weighting(stat) - self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device) - lambda_prev = cast(Tensor, self._lambda) + lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device) lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat self._lambda = lambda_k.detach() return lambda_k - def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> None: - if ( - self._lambda is None - or self._lambda.shape[0] != m - or self._lambda.dtype != dtype - or self._lambda.device != device - ): + def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor: + if self._lambda is None or self._lambda.shape[0] != m: if m > 0: self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) else: self._lambda = torch.zeros(0, dtype=dtype, device=device) + return self._lambda From 1b74974331733792928c84dab0b81c628f5724ce Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 8 May 2026 10:00:03 -0400 Subject: [PATCH 07/14] refactor(aggregation): Raise on shape change in CRMOGMWeighting._ensure_state --- src/torchjd/aggregation/_cr_mogm.py | 12 +++++++----- tests/unit/aggregation/test_cr_mogm.py | 14 +------------- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 7ee086e5..a1d8bd8b 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -107,9 +107,11 @@ def forward(self, stat: _T, /) -> Tensor: return lambda_k def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor: - if self._lambda is None or self._lambda.shape[0] != m: - if m > 0: - self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) - else: - self._lambda = torch.zeros(0, dtype=dtype, device=device) + if self._lambda is None: + self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) + elif self._lambda.shape[0] != m: + raise ValueError( + f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call " + f"`reset()` before changing the number of objectives." + ) return self._lambda diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index 45a6fb82..9c402d1b 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -3,7 +3,7 @@ from torch.testing import assert_close from utils.tensors import randn_, tensor_ -from torchjd.aggregation import GradVacWeighting, MeanWeighting, SumWeighting, UPGradWeighting +from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting from torchjd.aggregation._aggregator_bases import ( GramianWeightedAggregator, WeightedAggregator, @@ -154,15 +154,3 @@ def test_zero_columns() -> None: aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) out = aggregator(tensor_([]).reshape(2, 0)) assert out.shape == (0,) - - -def test_zero_rows() -> None: - """ - Exercises the ``m=0`` branch of ``_ensure_state``. ``SumWeighting`` is used because it - handles zero-row matrices cleanly (``torch.ones(0)``), unlike ``MeanWeighting`` which - would raise ``ZeroDivisionError``. - """ - - W = CRMOGMWeighting(SumWeighting()) - weights = W(tensor_([]).reshape(0, 8)) - assert weights.shape == (0,) From 1cb9953dd4d6fd345ac02f6799d43a1917add11d Mon Sep 17 00:00:00 2001 From: Khush Date: Fri, 8 May 2026 11:23:36 -0400 Subject: [PATCH 08/14] test(aggregation): Fix reset propagation test and cover shape-change error --- tests/unit/aggregation/test_cr_mogm.py | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index 9c402d1b..80fadcc0 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -57,17 +57,26 @@ def test_reset_restores_first_step_behavior() -> None: def test_reset_propagates_to_stateful_weighting() -> None: """ Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is - :class:`~torchjd.aggregation.Stateful`. Uses ``GradVacWeighting`` as the inner weighting - because it is both stateful and produces weights that depend on its internal state. + :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal + state is cleared after ``reset()``. """ + inner = GradVacWeighting() + W = CRMOGMWeighting(inner, alpha=0.5) J = randn_((3, 8)) - G = J @ J.T - W = CRMOGMWeighting(GradVacWeighting(), alpha=0.5) - first = W(G) - W(G) + W(J @ J.T) + assert inner._phi_t is not None W.reset() - assert_close(first, W(G)) + assert inner._phi_t is None + + +def test_changing_m_raises() -> None: + """Verify that changing the number of objectives after the first call raises a ValueError.""" + + W = CRMOGMWeighting(MeanWeighting()) + W(randn_((3, 8)) @ randn_((3, 8)).T) + with raises(ValueError, match="number of objectives"): + W(randn_((2, 8)) @ randn_((2, 8)).T) def test_alpha_setter_accepts_valid() -> None: From fbec0369b7e3ad6e4db66ab6cd4158f33cca9caa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 18 May 2026 14:12:24 +0200 Subject: [PATCH 09/14] Improve docstring --- src/torchjd/aggregation/_cr_mogm.py | 69 ++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 22 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index a1d8bd8b..67723e95 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -30,26 +30,37 @@ class CRMOGMWeighting(Weighting[_T], Stateful): \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top - \in \mathbb{R}^m`. The state :math:`\lambda_{k-1}` is initialised lazily on the first - forward call once :math:`m` is known and is reset automatically when ``m`` changes. + \in \mathbb{R}^m`. - Because ``CRMOGMWeighting`` is generic in the input type ``_T``, it can wrap either a - ``MatrixWeighting`` or a ``GramianWeighting``. Creating a corresponding :class:`~torchjd.aggregation.Aggregator` can be done by composing it with the appropriate - aggregator base: + Creating the corresponding :class:`~torchjd.aggregation.Aggregator` from a wrapped weighting can + be done by composing it with the appropriate aggregator subclass + (:class:`~torchjd.aggregation.WeightedAggregator` or + :class:`~torchjd.aggregation.GramianWeightedAggregator`) - .. code-block:: python + The following example shows how to instantiate a Gramian-based weighted aggregator whose + Gramian weighting is wrapped by CR-MOGM. - from torchjd.aggregation import MeanWeighting, UPGradWeighting - from torchjd.aggregation._aggregator_bases import ( - GramianWeightedAggregator, WeightedAggregator, - ) - from torchjd.aggregation._cr_mogm import CRMOGMWeighting + .. testcode:: python - matrix_aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) - gramian_aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) + from torchjd.aggregation import CRMOGMWeighting, GramianWeightedAggregator, UPGradWeighting + + aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) + + The following example shows how to instantiate a Matrix-based weighted aggregator whose + weighting is wrapped by CR-MOGM. + + .. testcode:: python + + from torchjd.aggregation import CRMOGMWeighting, MeanWeighting, WeightedAggregator + + aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + + Note that here, :class:`~torchjd.aggregation.MeanWeighting` is used just for the sake of the + example: the exponential moving average of constant weights will always be equal to the weights + themselves, so wrapping by ``CRMOGMWeighting`` will have no effect. This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` - when restarting the smoothing from uniform weights. Note that calling :meth:`reset` will also + to restart the smoothing from uniform weights. Note that calling :meth:`reset` will also reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`. :param weighting: The wrapped weighting whose output is smoothed. @@ -61,16 +72,27 @@ class CRMOGMWeighting(Weighting[_T], Stateful): .. note:: ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning - rate decays. Update ``alpha`` between forward calls via the public attribute on the - wrapping aggregator: + rate decays. Update ``alpha`` between forward calls via the setter. + + The following example shows how to update alpha with the suggested scheme from the paper, + when the aggregator is a Gramian-based weighted aggregator whose Gramian weighting is + wrapped by CR-MOGM: + + .. testcode:: python + + from torchjd.aggregation import ( + CRMOGMWeighting, + GramianWeightedAggregator, + UPGradWeighting, + ) - .. code-block:: python + aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) - # With WeightedAggregator - aggregator.weighting.alpha = 1 - current_lr / initial_lr + initial_lr = 0.1 + current_lr = 0.05 # e.g. obtained from lr_scheduler.get_lr()[0] - # With GramianWeightedAggregator - aggregator.gramian_weighting.alpha = 1 - current_lr / initial_lr + cr_mogm = aggregator.gramian_weighting + cr_mogm.alpha = 1 - current_lr / initial_lr """ def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None: @@ -90,7 +112,10 @@ def alpha(self, value: float) -> None: self._alpha = value def reset(self) -> None: - """Clears the EMA state so the next forward starts from uniform weights.""" + """ + Clears the EMA state so the next forward starts from uniform weights. Also resets the + wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`. + """ if isinstance(self.weighting, Stateful): self.weighting.reset() From 8047115199918be74bae2492c7797234c653f2a8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 18 May 2026 14:14:44 +0200 Subject: [PATCH 10/14] Revert changelog change --- CHANGELOG.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f74d8aaa..54df7116 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -34,8 +34,8 @@ changelog does not include internal changes that do not affect the user. `CAGrad` and `CAGradWeighting`; `pref_vector` in `ConFIG`; `leak` in `GradDrop`, `n_byzantine` and `n_selected` in `Krum` and `KrumWeighting`; `epsilon` and `max_iters` in `MGDA` and `MGDAWeighting`; `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`; - `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor checks. - Note that setters for `GradVac` and `GradVacWeighting` already existed. + `trim_number` in `TrimmedMean`. Setters validate their inputs matching the existing constructor + checks. Note that setters for `GradVac` and `GradVacWeighting` already existed. ## [0.10.0] - 2026-04-16 From 129c6446b54842231ada6508fddb84ab50d52319 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 18 May 2026 14:17:07 +0200 Subject: [PATCH 11/14] Set alpha default to 0.9 --- src/torchjd/aggregation/_cr_mogm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 67723e95..4cbcfa2c 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -95,7 +95,7 @@ class CRMOGMWeighting(Weighting[_T], Stateful): cr_mogm.alpha = 1 - current_lr / initial_lr """ - def __init__(self, weighting: Weighting[_T], alpha: float = 0.1) -> None: + def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None: super().__init__() self.weighting = weighting self.alpha = alpha From f848dd0af1ae13d2ee110364e871edd89aded9e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 18 May 2026 14:35:14 +0200 Subject: [PATCH 12/14] feat: Add `initial_weights` param to `CRMOGMWeighting` MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Replace the hardcoded uniform λ₀ = 1/m with an optional `initial_weights` parameter. When `None` (default), λ₀ is set to λ̂₁ on the first forward call so the first smoothed output always equals the wrapped weighting's output regardless of α. Users who want uniform initialisation can still pass the tensor explicitly. Co-Authored-By: Claude Sonnet 4.6 --- src/torchjd/aggregation/_cr_mogm.py | 41 ++++++++++++------ tests/unit/aggregation/test_cr_mogm.py | 59 +++++++++++++++++++++----- 2 files changed, 78 insertions(+), 22 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 4cbcfa2c..0eafe8f8 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -2,7 +2,6 @@ from typing import TypeVar -import torch from torch import Tensor from torchjd.aggregation._mixins import Stateful @@ -29,8 +28,9 @@ class CRMOGMWeighting(Weighting[_T], Stateful): \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k - with :math:`\lambda_0 = \begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^\top - \in \mathbb{R}^m`. + where :math:`\lambda_0` is ``initial_weights`` if provided, otherwise + :math:`\lambda_0 = \hat{\lambda}_1` (so that the first smoothed output equals + :math:`\hat{\lambda}_1` regardless of :math:`\alpha`). Creating the corresponding :class:`~torchjd.aggregation.Aggregator` from a wrapped weighting can be done by composing it with the appropriate aggregator subclass @@ -60,14 +60,17 @@ class CRMOGMWeighting(Weighting[_T], Stateful): themselves, so wrapping by ``CRMOGMWeighting`` will have no effect. This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` - to restart the smoothing from uniform weights. Note that calling :meth:`reset` will also + to restart the smoothing from the initial state. Note that calling :meth:`reset` will also reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`. :param weighting: The wrapped weighting whose output is smoothed. :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes - the weights at their initial uniform value. The default of ``0.9`` follows the usual - EMA convention (analogous to Adam's :math:`\beta_1`). + the weights at their initial value. The default of ``0.9`` follows the usual EMA + convention (analogous to Adam's :math:`\beta_1`). + :param initial_weights: Optional tensor to use as :math:`\lambda_0`. If ``None`` (default), + :math:`\lambda_0` is set to :math:`\hat{\lambda}_1` on the first forward call, making + the first smoothed output equal to :math:`\hat{\lambda}_1`. .. note:: ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a @@ -95,10 +98,13 @@ class CRMOGMWeighting(Weighting[_T], Stateful): cr_mogm.alpha = 1 - current_lr / initial_lr """ - def __init__(self, weighting: Weighting[_T], alpha: float = 0.9) -> None: + def __init__( + self, weighting: Weighting[_T], alpha: float = 0.9, initial_weights: Tensor | None = None + ) -> None: super().__init__() self.weighting = weighting self.alpha = alpha + self._initial_weights = initial_weights self._lambda: Tensor | None = None @property @@ -112,8 +118,8 @@ def alpha(self, value: float) -> None: self._alpha = value def reset(self) -> None: - """ - Clears the EMA state so the next forward starts from uniform weights. Also resets the + r""" + Clears the EMA state so the next forward restarts from the initial state. Also resets the wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`. """ @@ -124,16 +130,27 @@ def reset(self) -> None: def forward(self, stat: _T, /) -> Tensor: lambda_hat = self.weighting(stat) - lambda_prev = self._ensure_state(lambda_hat.shape[0], lambda_hat.dtype, lambda_hat.device) + lambda_prev = self._ensure_state(lambda_hat) lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat self._lambda = lambda_k.detach() return lambda_k - def _ensure_state(self, m: int, dtype: torch.dtype, device: torch.device) -> Tensor: + def _ensure_state(self, lambda_hat: Tensor) -> Tensor: + m = lambda_hat.shape[0] if self._lambda is None: - self._lambda = torch.full((m,), 1.0 / m, dtype=dtype, device=device) + if self._initial_weights is not None: + if self._initial_weights.shape != (m,): + raise ValueError( + f"`initial_weights` has shape {tuple(self._initial_weights.shape)}, " + f"expected ({m},)." + ) + self._lambda = self._initial_weights.to( + dtype=lambda_hat.dtype, device=lambda_hat.device + ) + else: + self._lambda = lambda_hat elif self._lambda.shape[0] != m: raise ValueError( f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call " diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py index 80fadcc0..3f00145a 100644 --- a/tests/unit/aggregation/test_cr_mogm.py +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -116,17 +116,17 @@ def test_alpha_zero_reduces_to_bare_weighting() -> None: def test_alpha_one_freezes_weights() -> None: """ With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at - their initial uniform value forever. Note: the equality with uniform weights is a - consequence of the uniform initialisation, not a general property of CR-MOGM. + their initial value forever. When ``initial_weights`` is ``None``, the initial value is + :math:`\\hat{\\lambda}_1`, so the output is frozen at the first step's bare weights. """ J = randn_((3, 8)) - m = J.shape[0] + G = J @ J.T W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0) - uniform = tensor_([1.0 / m] * m) + first = W(G) - assert_close(W(J @ J.T), uniform) - assert_close(W(J @ J.T), uniform) + assert_close(W(G), first) + assert_close(W(G), first) def test_ema_is_applied() -> None: @@ -137,22 +137,61 @@ def test_ema_is_applied() -> None: J2 = randn_((3, 8)) G1 = J1 @ J1.T G2 = J2 @ J2.T - m = J1.shape[0] bare = UPGradWeighting() smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha) lambda_hat_1 = bare(G1) lambda_hat_2 = bare(G2) - uniform = tensor_([1.0 / m] * m) - expected_1 = alpha * uniform + (1.0 - alpha) * lambda_hat_1 - expected_2 = alpha * expected_1 + (1.0 - alpha) * lambda_hat_2 + # lambda_0 = lambda_hat_1, so lambda_1 = lambda_hat_1 regardless of alpha + expected_1 = lambda_hat_1 + expected_2 = alpha * lambda_hat_1 + (1.0 - alpha) * lambda_hat_2 assert_close(smoothed(G1), expected_1) assert_close(smoothed(G2), expected_2) +def test_initial_weights_used_as_lambda_0() -> None: + """Verify that when ``initial_weights`` is provided it acts as :math:`\\lambda_0`.""" + + alpha = 0.5 + J = randn_((3, 8)) + G = J @ J.T + initial = tensor_([0.5, 0.3, 0.2]) + + bare = UPGradWeighting() + W = CRMOGMWeighting(UPGradWeighting(), alpha=alpha, initial_weights=initial) + + lambda_hat_1 = bare(G) + expected_1 = alpha * initial + (1.0 - alpha) * lambda_hat_1 + + assert_close(W(G), expected_1) + + +def test_reset_restores_initial_weights() -> None: + """Verify that ``reset()`` restores the user-provided ``initial_weights`` as :math:`\\lambda_0`.""" + + alpha = 0.5 + J = randn_((3, 8)) + G = J @ J.T + initial = tensor_([0.5, 0.3, 0.2]) + + W = CRMOGMWeighting(UPGradWeighting(), alpha=alpha, initial_weights=initial) + first = W(G) + W(G) + W.reset() + assert_close(W(G), first) + + +def test_initial_weights_shape_mismatch_raises() -> None: + """Verify that mismatched ``initial_weights`` shape raises a ``ValueError``.""" + + W = CRMOGMWeighting(MeanWeighting(), initial_weights=tensor_([0.5, 0.5])) + with raises(ValueError, match="initial_weights"): + W(randn_((3, 8)) @ randn_((3, 8)).T) + + def test_zero_columns() -> None: """ A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row From b2a191dee9bc12c8474657b6ac6dacc06cd6c43a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 18 May 2026 14:36:52 +0200 Subject: [PATCH 13/14] Remove .detach() --- src/torchjd/aggregation/_cr_mogm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 0eafe8f8..021cc579 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -134,7 +134,7 @@ def forward(self, stat: _T, /) -> Tensor: lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat - self._lambda = lambda_k.detach() + self._lambda = lambda_k return lambda_k def _ensure_state(self, lambda_hat: Tensor) -> Tensor: From 4d027c14ac130605e33370f6b26d845bccf7174b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Val=C3=A9rian=20Rey?= Date: Mon, 18 May 2026 14:37:15 +0200 Subject: [PATCH 14/14] Simplify forward --- src/torchjd/aggregation/_cr_mogm.py | 8 ++------ 1 file changed, 2 insertions(+), 6 deletions(-) diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py index 021cc579..d056c5ec 100644 --- a/src/torchjd/aggregation/_cr_mogm.py +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -129,13 +129,9 @@ def reset(self) -> None: def forward(self, stat: _T, /) -> Tensor: lambda_hat = self.weighting(stat) - lambda_prev = self._ensure_state(lambda_hat) - - lambda_k = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat - - self._lambda = lambda_k - return lambda_k + self._lambda = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat + return self._lambda def _ensure_state(self, lambda_hat: Tensor) -> Tensor: m = lambda_hat.shape[0]