Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
0aa1c8b
Add CRMOGMWeighting from NeurIPS 2022
KhusPatel4450 May 7, 2026
53f3eb3
fix(aggregation): Fix Sphinx cross-reference warnings in CRMOGMWeight…
KhusPatel4450 May 7, 2026
23c0f62
fix(aggregation): Remove broken NeurIPS URL from CRMOGMWeighting
KhusPatel4450 May 7, 2026
daf59f9
test(aggregation): Cover zero-row branch in CRMOGMWeighting
KhusPatel4450 May 7, 2026
f3c21a0
Merge branch 'main' into feat/cr-mogm-weighting
ValerianRey May 8, 2026
65df561
Merge branch 'main' into feat/cr-mogm-weighting
ValerianRey May 8, 2026
e846a9c
refactor(aggregation): Address review feedback on CRMOGMWeighting
KhusPatel4450 May 8, 2026
e16cf48
refactor(aggregation): Simplify CRMOGMWeighting state logic and impro…
KhusPatel4450 May 8, 2026
1b74974
refactor(aggregation): Raise on shape change in CRMOGMWeighting._ensu…
KhusPatel4450 May 8, 2026
1cb9953
test(aggregation): Fix reset propagation test and cover shape-change …
KhusPatel4450 May 8, 2026
7e02ef8
Merge branch 'main' into feat/cr-mogm-weighting
ValerianRey May 9, 2026
d186b7c
Merge branch 'main' into feat/cr-mogm-weighting
ValerianRey May 10, 2026
fbec036
Improve docstring
ValerianRey May 18, 2026
8047115
Revert changelog change
ValerianRey May 18, 2026
22f9df4
Merge branch 'main' into feat/cr-mogm-weighting
ValerianRey May 18, 2026
129c644
Set alpha default to 0.9
ValerianRey May 18, 2026
f848dd0
feat: Add `initial_weights` param to `CRMOGMWeighting`
ValerianRey May 18, 2026
b2a191d
Remove .detach()
ValerianRey May 18, 2026
4d027c1
Simplify forward
ValerianRey May 18, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ changelog does not include internal changes that do not affect the user.

### Added

- Added `CRMOGMWeighting` from [On the Convergence of Stochastic Multi-Objective Gradient
Manipulation and Beyond](https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf)
(NeurIPS 2022). It wraps an existing `Weighting` and stabilises its weights with an exponential
moving average across calls.
- Added a new abstraction: the `DualConeProjector` abstract base class and its concrete
`QuadprogProjector` implementation, to do the projection of the gradients onto the dual cone, as
required in `UPGrad`, and `DualProj`. These classes can be found in `torchjd.linalg`.
Expand Down
15 changes: 15 additions & 0 deletions docs/source/docs/aggregation/cr_mogm.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
:hide-toc:

CR-MOGM
=======

.. autoclass:: torchjd.aggregation.CRMOGMWeighting
:members: __call__, reset

.. note::
The usage example in the docstring above imports
``WeightedAggregator`` / ``GramianWeightedAggregator`` from
``torchjd.aggregation._aggregator_bases``, which is a private module. These two
aggregator base classes are not currently part of the public ``torchjd.aggregation``
namespace, so this private-module import is the only path that works today. Promoting
them to the public namespace is a separate decision left to the maintainers.
Comment thread
ValerianRey marked this conversation as resolved.
1 change: 1 addition & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ Abstract base classes
cagrad.rst
config.rst
constant.rst
cr_mogm.rst
dualproj.rst
flattening.rst
graddrop.rst
Expand Down
2 changes: 2 additions & 0 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from ._cagrad import CAGrad, CAGradWeighting
from ._config import ConFIG
from ._constant import Constant, ConstantWeighting
from ._cr_mogm import CRMOGMWeighting
from ._dualproj import DualProj, DualProjWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
Expand All @@ -91,6 +92,7 @@
"ConFIG",
"Constant",
"ConstantWeighting",
"CRMOGMWeighting",
"DualProj",
"DualProjWeighting",
"Flattening",
Expand Down
155 changes: 155 additions & 0 deletions src/torchjd/aggregation/_cr_mogm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
from __future__ import annotations

from typing import TypeVar

from torch import Tensor

from torchjd.aggregation._mixins import Stateful

from ._weighting_bases import Weighting

_T = TypeVar("_T", contravariant=True, bound=Tensor)


class CRMOGMWeighting(Weighting[_T], Stateful):
Comment thread
ValerianRey marked this conversation as resolved.
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._weighting_bases.Weighting` that wraps another
:class:`~torchjd.aggregation._weighting_bases.Weighting` and stabilises the weights it
produces with an exponential moving average (EMA) across calls. This is the weight-smoothing
modifier from `On the Convergence of Stochastic Multi-Objective Gradient Manipulation and
Beyond <https://proceedings.neurips.cc/paper_files/paper/2022/file/f91bd64a3620aad8e70a27ad9cb3ca57-Paper-Conference.pdf>`_
(NeurIPS 2022).

Let :math:`\hat{\lambda}_k` be the weights returned by the wrapped weighting at step
:math:`k`. The smoothed weights returned by ``CRMOGMWeighting`` are:

.. math::

\lambda_k = \alpha \, \lambda_{k-1} + (1 - \alpha) \, \hat{\lambda}_k

where :math:`\lambda_0` is ``initial_weights`` if provided, otherwise
:math:`\lambda_0 = \hat{\lambda}_1` (so that the first smoothed output equals
:math:`\hat{\lambda}_1` regardless of :math:`\alpha`).

Creating the corresponding :class:`~torchjd.aggregation.Aggregator` from a wrapped weighting can
be done by composing it with the appropriate aggregator subclass
(:class:`~torchjd.aggregation.WeightedAggregator` or
:class:`~torchjd.aggregation.GramianWeightedAggregator`)

The following example shows how to instantiate a Gramian-based weighted aggregator whose
Gramian weighting is wrapped by CR-MOGM.

.. testcode:: python

from torchjd.aggregation import CRMOGMWeighting, GramianWeightedAggregator, UPGradWeighting

aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

The following example shows how to instantiate a Matrix-based weighted aggregator whose
weighting is wrapped by CR-MOGM.

.. testcode:: python

from torchjd.aggregation import CRMOGMWeighting, MeanWeighting, WeightedAggregator

aggregator = WeightedAggregator(CRMOGMWeighting(MeanWeighting()))

Note that here, :class:`~torchjd.aggregation.MeanWeighting` is used just for the sake of the
example: the exponential moving average of constant weights will always be equal to the weights
themselves, so wrapping by ``CRMOGMWeighting`` will have no effect.

Comment thread
ValerianRey marked this conversation as resolved.
This weighting is stateful: it keeps :math:`\lambda_{k-1}` across calls. Use :meth:`reset`
to restart the smoothing from the initial state. Note that calling :meth:`reset` will also
reset the wrapped weighting if it is :class:`~torchjd.aggregation.Stateful`.

:param weighting: The wrapped weighting whose output is smoothed.
:param alpha: EMA coefficient on the previous weights. ``alpha=0`` disables smoothing
(``CRMOGMWeighting`` returns ``weighting``'s output verbatim) and ``alpha=1`` freezes
the weights at their initial value. The default of ``0.9`` follows the usual EMA
convention (analogous to Adam's :math:`\beta_1`).
:param initial_weights: Optional tensor to use as :math:`\lambda_0`. If ``None`` (default),
:math:`\lambda_0` is set to :math:`\hat{\lambda}_1` on the first forward call, making
the first smoothed output equal to :math:`\hat{\lambda}_1`.

.. note::
``alpha`` is a fixed ``float`` for simplicity. Corollary 1 of the paper recommends a
schedule where :math:`\alpha_k` starts near 0 and increases toward 1 as the learning
rate decays. Update ``alpha`` between forward calls via the setter.

The following example shows how to update alpha with the suggested scheme from the paper,
when the aggregator is a Gramian-based weighted aggregator whose Gramian weighting is
wrapped by CR-MOGM:

.. testcode:: python

from torchjd.aggregation import (
CRMOGMWeighting,
GramianWeightedAggregator,
UPGradWeighting,
)

aggregator = GramianWeightedAggregator(CRMOGMWeighting(UPGradWeighting()))

initial_lr = 0.1
current_lr = 0.05 # e.g. obtained from lr_scheduler.get_lr()[0]

cr_mogm = aggregator.gramian_weighting
cr_mogm.alpha = 1 - current_lr / initial_lr
"""

def __init__(
self, weighting: Weighting[_T], alpha: float = 0.9, initial_weights: Tensor | None = None
) -> None:
super().__init__()
self.weighting = weighting
self.alpha = alpha
self._initial_weights = initial_weights
self._lambda: Tensor | None = None

@property
def alpha(self) -> float:
return self._alpha

@alpha.setter
def alpha(self, value: float) -> None:
if not (0.0 <= value <= 1.0):
raise ValueError(f"Attribute `alpha` must be in [0, 1]. Found alpha={value!r}.")
self._alpha = value

def reset(self) -> None:
r"""
Clears the EMA state so the next forward restarts from the initial state. Also resets the
wrapped weighting if it is :class:`~torchjd.aggregation._mixins.Stateful`.
"""

if isinstance(self.weighting, Stateful):
self.weighting.reset()
self._lambda = None

def forward(self, stat: _T, /) -> Tensor:
lambda_hat = self.weighting(stat)
lambda_prev = self._ensure_state(lambda_hat)
self._lambda = self._alpha * lambda_prev + (1.0 - self._alpha) * lambda_hat
return self._lambda

def _ensure_state(self, lambda_hat: Tensor) -> Tensor:
m = lambda_hat.shape[0]
if self._lambda is None:
if self._initial_weights is not None:
if self._initial_weights.shape != (m,):
raise ValueError(
f"`initial_weights` has shape {tuple(self._initial_weights.shape)}, "
f"expected ({m},)."
)
self._lambda = self._initial_weights.to(
dtype=lambda_hat.dtype, device=lambda_hat.device
)
else:
self._lambda = lambda_hat
elif self._lambda.shape[0] != m:
raise ValueError(
f"The number of objectives changed from {self._lambda.shape[0]} to {m}. Call "
f"`reset()` before changing the number of objectives."
)
return self._lambda
Loading
Loading