diff --git a/CHANGELOG.md b/CHANGELOG.md index e2bd218f..f092ca02 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -36,6 +36,10 @@ changelog does not include internal changes that do not affect the user. ### Added +- Added `CRMOGMWeighting` from [On the Convergence of Stochastic Multi-Objective Gradient + Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf) + (NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential + moving average across calls. - Added a new abstraction: the `DualConeProjector` abstract base class and its concrete `QuadprogProjector` implementation, to do the projection of the gradients onto the dual cone, as required in `UPGrad`, and `DualProj`. These classes can be found in `torchjd.linalg`. diff --git a/docs/source/docs/aggregation/cr_mogm.rst b/docs/source/docs/aggregation/cr_mogm.rst new file mode 100644 index 00000000..47e70f49 --- /dev/null +++ b/docs/source/docs/aggregation/cr_mogm.rst @@ -0,0 +1,15 @@ +:hide-toc: + +CR-MOGM +======= + +.. autoclass:: torchjd.aggregation.CRMOGMWeighting + :members: __call__, reset + +.. note:: + The usage example in the docstring above imports + ``WeightedAggregator`` / ``GramianWeightedAggregator`` from + ``torchjd.aggregation._aggregator_bases``, which is a private module. These two + aggregator base classes are not currently part of the public ``torchjd.aggregation`` + namespace, so this private-module import is the only path that works today. Promoting + them to the public namespace is a separate decision left to the maintainers. 
diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index ff6e1811..3c0516b4 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -35,6 +35,7 @@ Abstract base classes cagrad.rst config.rst constant.rst + cr_mogm.rst dualproj.rst flattening.rst graddrop.rst diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 0299bfc3..6358ec0e 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -65,6 +65,7 @@ from ._cagrad import CAGrad, CAGradWeighting from ._config import ConFIG from ._constant import Constant, ConstantWeighting +from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting from ._flattening import Flattening from ._graddrop import GradDrop @@ -91,6 +92,7 @@ "ConFIG", "Constant", "ConstantWeighting", + "CRMOGMWeighting", "DualProj", "DualProjWeighting", "Flattening", diff --git a/src/torchjd/aggregation/_cr_mogm.py b/src/torchjd/aggregation/_cr_mogm.py new file mode 100644 index 00000000..d056c5ec --- /dev/null +++ b/src/torchjd/aggregation/_cr_mogm.py @@ -0,0 +1,155 @@ +from __future__ import annotations + +from typing import TypeVar + +from torch import Tensor + +from torchjd.aggregation._mixins import Stateful + +from ._weighting_bases import Weighting + +_T = TypeVar("_T", contravariant=True, bound=Tensor) + + +class CRMOGMWeighting(Weighting[_T], Stateful): + r""" + :class:`~torchjd.aggregation._mixins.Stateful` + :class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another + :class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it + produces with an exponential moving average (EMA) across calls. This is the weight-smoothing + modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and + Beyond <https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf>`_ + (NeurIPS 2022).
+ + Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step + :math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are: + + .. math:: + + \lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k + + where :math:`\lambda_0` is ``initial_weights`` if provided, otherwise + :math:`\lambda_0 = \hat{\lambda}_1` (so that the first smoothed output equals + :math:`\hat{\lambda}_1` regardless of :math:`\alpha`). + + Creating the corresponding :class:`~torchjd.aggregation.Aggregator` from a wrapped weighting can + be done by composing it with the appropriate aggregator subclass + (:class:`~torchjd.aggregation.WeightedAggregator` or + :class:`~torchjd.aggregation.GramianWeightedAggregator`). + + The following example shows how to instantiate a Gramian-based weighted aggregator whose + Gramian weighting is wrapped by CR-MOGM. + + .. testcode:: python + + from torchjd.aggregation import CRMOGMWeighting, GramianWeightedAggregator, UPGradWeighting + + aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) + + The following example shows how to instantiate a Matrix-based weighted aggregator whose + weighting is wrapped by CR-MOGM. + + .. testcode:: python + + from torchjd.aggregation import CRMOGMWeighting, MeanWeighting, WeightedAggregator + + aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + + Note that here, :class:`~torchjd.aggregation.MeanWeighting` is used just for the sake of the + example: the exponential moving average of constant weights will always be equal to the weights + themselves, so wrapping by ``CRMOGMWeighting`` will have no effect. + + This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset` + to restart the smoothing from the initial state. Note that calling :meth:`reset` will also + reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`. + + :param weighting: The wrapped weighting whose output is smoothed.
+ :param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing + (``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes + the weights at their initial value. The default of ``0.9`` follows the usual EMA + convention (analogous to Adam's :math:`\beta_1`). + :param initial_weights: Optional tensor to use as :math:`\lambda_0`. If ``None`` (default), + :math:`\lambda_0` is set to :math:`\hat{\lambda}_1` on the first forward call, making + the first smoothed output equal to :math:`\hat{\lambda}_1`. + + .. note:: + ``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a + schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning + rate decays. Update ``alpha`` between forward calls via the setter. + + The following example shows how to update alpha with the suggested scheme from the paper, + when the aggregator is a Gramian-based weighted aggregator whose Gramian weighting is + wrapped by CR-MOGM: + + .. testcode:: python + + from torchjd.aggregation import ( + CRMOGMWeighting, + GramianWeightedAggregator, + UPGradWeighting, + ) + + aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())) + + initial_lr = 0.1 + current_lr = 0.05 # e.g. obtained from lr_scheduler.get_last_lr()[0] + + cr_mogm = aggregator.gramian_weighting + cr_mogm.alpha = 1 - current_lr / initial_lr + """ + + def __init__( + self, weighting: Weighting[_T], alpha: float = 0.9, initial_weights: Tensor | None = None + ) -> None: + super().__init__() + self.weighting = weighting + self.alpha = alpha + self._initial_weights = initial_weights + self._lambda: Tensor | None = None + + @property + def alpha(self) -> float: + return self._alpha + + @alpha.setter + def alpha(self, value: float) -> None: + if not (0.0 <= value <= 1.0): + raise ValueError(f"Attribute `alpha` must be in [0, 1]. 
Found alpha={value!r}.") + self._alpha = value + + def reset(self) -> None: + r""" + Clears the EMA state so the next forward restarts from the initial state. Also resets the + wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`. + """ + + if isinstance(self.weighting, Stateful): + self.weighting.reset() + self._lambda = None + + def forward(self, stat: _T, /) -> Tensor: + lambda_hat = self.weighting(stat) + lambda_prev = self._ensure_state(lambda_hat) + self._lambda = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat + return self._lambda + + def _ensure_state(self, lambda_hat: Tensor) -> Tensor: + m = lambda_hat.shape[0] + if self._lambda is None: + if self._initial_weights is not None: + if self._initial_weights.shape != (m,): + raise ValueError( + f"`initial_weights` has shape {tuple(self._initial_weights.shape)}, " + f"expected ({m},)." + ) + self._lambda = self._initial_weights.to( + dtype=lambda_hat.dtype, device=lambda_hat.device + ) + else: + self._lambda = lambda_hat + elif self._lambda.shape[0] != m: + raise ValueError( + f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call " + f"`reset()` before changing the number of objectives." 
+ ) + return self._lambda diff --git a/tests/unit/aggregation/test_cr_mogm.py b/tests/unit/aggregation/test_cr_mogm.py new file mode 100644 index 00000000..3f00145a --- /dev/null +++ b/tests/unit/aggregation/test_cr_mogm.py @@ -0,0 +1,204 @@ +from pytest import mark, raises +from torch import Tensor +from torch.testing import assert_close +from utils.tensors import randn_, tensor_ + +from torchjd.aggregation import GradVacWeighting, MeanWeighting, UPGradWeighting +from torchjd.aggregation._aggregator_bases import ( + GramianWeightedAggregator, + WeightedAggregator, +) +from torchjd.aggregation._cr_mogm import CRMOGMWeighting + +from ._asserts import assert_expected_structure +from ._inputs import scaled_matrices, typical_matrices + +# UPGradWeighting uses a QP solver that can fail on the extreme scales (0.0, 1e15) found in +# scaled_matrices, so the gramian-path structural test only uses typical_matrices. +matrix_pairs = [ + (WeightedAggregator(CRMOGMWeighting(MeanWeighting())), m) + for m in typical_matrices + scaled_matrices +] +gramian_pairs = [ + (GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting())), m) for m in typical_matrices +] + + +@mark.parametrize(["aggregator", "matrix"], matrix_pairs) +def test_expected_structure_matrix_weighting( + aggregator: WeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +@mark.parametrize(["aggregator", "matrix"], gramian_pairs) +def test_expected_structure_gramian_weighting( + aggregator: GramianWeightedAggregator, matrix: Tensor +) -> None: + assert_expected_structure(aggregator, matrix) + + +def test_reset_restores_first_step_behavior() -> None: + """ + Use ``UPGradWeighting`` so the weights actually depend on the input — with + ``MeanWeighting`` the EMA would be a fixed point at the uniform weights and the test would + be trivial. 
+ """ + + J = randn_((3, 8)) + G = J @ J.T + W = CRMOGMWeighting(UPGradWeighting(), alpha=0.5) + first = W(G) + W(G) + W.reset() + assert_close(first, W(G)) + + +def test_reset_propagates_to_stateful_weighting() -> None: + """ + Verify that ``reset()`` calls the wrapped weighting's ``reset()`` when it is + :class:`~torchjd.aggregation.Stateful`. Checks that ``GradVacWeighting``'s internal + state is cleared after ``reset()``. + """ + + inner = GradVacWeighting() + W = CRMOGMWeighting(inner, alpha=0.5) + J = randn_((3, 8)) + W(J @ J.T) + assert inner._phi_t is not None + W.reset() + assert inner._phi_t is None + + +def test_changing_m_raises() -> None: + """Verify that changing the number of objectives after the first call raises a ValueError.""" + + W = CRMOGMWeighting(MeanWeighting()) + W(randn_((3, 8)) @ randn_((3, 8)).T) + with raises(ValueError, match="number of objectives"): + W(randn_((2, 8)) @ randn_((2, 8)).T) + + +def test_alpha_setter_accepts_valid() -> None: + W = CRMOGMWeighting(MeanWeighting()) + W.alpha = 0.0 + assert W.alpha == 0.0 + W.alpha = 0.5 + assert W.alpha == 0.5 + W.alpha = 1.0 + assert W.alpha == 1.0 + + +def test_alpha_setter_rejects_out_of_range() -> None: + W = CRMOGMWeighting(MeanWeighting()) + with raises(ValueError, match="alpha"): + W.alpha = -0.1 + with raises(ValueError, match="alpha"): + W.alpha = 1.1 + + +def test_alpha_zero_reduces_to_bare_weighting() -> None: + """ + With ``alpha=0`` the previous state is always multiplied by zero, so the smoothed weights + equal the bare weighting's output on every call — not just the first. + """ + + J = randn_((3, 8)) + G = J @ J.T + bare = UPGradWeighting() + smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=0.0) + + expected = bare(G) + assert_close(smoothed(G), expected) + assert_close(smoothed(G), expected) + + +def test_alpha_one_freezes_weights() -> None: + """ + With ``alpha=1`` the fresh weights are multiplied by zero, so the smoothed weights stay at + their initial value forever. 
When ``initial_weights`` is ``None``, the initial value is + :math:`\\hat{\\lambda}_1`, so the output is frozen at the first step's bare weights. + """ + + J = randn_((3, 8)) + G = J @ J.T + W = CRMOGMWeighting(UPGradWeighting(), alpha=1.0) + first = W(G) + + assert_close(W(G), first) + assert_close(W(G), first) + + +def test_ema_is_applied() -> None: + """Run two steps with ``alpha=0.9`` and check the EMA recurrence by hand.""" + + alpha = 0.9 + J1 = randn_((3, 8)) + J2 = randn_((3, 8)) + G1 = J1 @ J1.T + G2 = J2 @ J2.T + + bare = UPGradWeighting() + smoothed = CRMOGMWeighting(UPGradWeighting(), alpha=alpha) + + lambda_hat_1 = bare(G1) + lambda_hat_2 = bare(G2) + + # lambda_0 = lambda_hat_1, so lambda_1 = lambda_hat_1 regardless of alpha + expected_1 = lambda_hat_1 + expected_2 = alpha * lambda_hat_1 + (1.0 - alpha) * lambda_hat_2 + + assert_close(smoothed(G1), expected_1) + assert_close(smoothed(G2), expected_2) + + +def test_initial_weights_used_as_lambda_0() -> None: + """Verify that when ``initial_weights`` is provided it acts as :math:`\\lambda_0`.""" + + alpha = 0.5 + J = randn_((3, 8)) + G = J @ J.T + initial = tensor_([0.5, 0.3, 0.2]) + + bare = UPGradWeighting() + W = CRMOGMWeighting(UPGradWeighting(), alpha=alpha, initial_weights=initial) + + lambda_hat_1 = bare(G) + expected_1 = alpha * initial + (1.0 - alpha) * lambda_hat_1 + + assert_close(W(G), expected_1) + + +def test_reset_restores_initial_weights() -> None: + """Verify that ``reset()`` restores the user-provided ``initial_weights`` as :math:`\\lambda_0`.""" + + alpha = 0.5 + J = randn_((3, 8)) + G = J @ J.T + initial = tensor_([0.5, 0.3, 0.2]) + + W = CRMOGMWeighting(UPGradWeighting(), alpha=alpha, initial_weights=initial) + first = W(G) + W(G) + W.reset() + assert_close(W(G), first) + + +def test_initial_weights_shape_mismatch_raises() -> None: + """Verify that mismatched ``initial_weights`` shape raises a ``ValueError``.""" + + W = CRMOGMWeighting(MeanWeighting(), initial_weights=tensor_([0.5, 
0.5])) + with raises(ValueError, match="initial_weights"): + W(randn_((3, 8)) @ randn_((3, 8)).T) + + +def test_zero_columns() -> None: + """ + A ``(2, 0)`` matrix has no columns to combine, so the aggregation must be empty. Zero-row + inputs are intentionally not tested: ``MeanWeighting`` does ``1/m`` in Python and would + raise ``ZeroDivisionError`` at ``m=0``, which is the wrapped weighting's responsibility. + """ + + aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting())) + out = aggregator(tensor_([]).reshape(2, 0)) + assert out.shape == (0,)