Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ changelog does not include internal changes that do not affect the user.

### Added

- Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting`
public. These abstract base classes are now importable from `torchjd.aggregation` and documented.
- Added getters and setters for the constructor parameters of all aggregators and weightings, so
that they can be changed after initialization. This includes: `pref_vector`,
`norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
Expand Down
12 changes: 12 additions & 0 deletions docs/source/docs/aggregation/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,21 @@ Abstract base classes
.. autoclass:: torchjd.aggregation.Aggregator
:members: __call__

.. autoclass:: torchjd.aggregation.WeightedAggregator
:members: __call__

.. autoclass:: torchjd.aggregation.GramianWeightedAggregator
:members: __call__

.. autoclass:: torchjd.aggregation.Weighting
:members: __call__

.. autoclass:: torchjd.aggregation.MatrixWeighting
:members: __call__

.. autoclass:: torchjd.aggregation.GramianWeighting
:members: __call__

.. autoclass:: torchjd.aggregation.GeneralizedWeighting
:members: __call__

Expand Down
32 changes: 18 additions & 14 deletions src/torchjd/aggregation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,23 @@
r"""
When doing Jacobian descent, the Jacobian matrix has to be aggregated into a vector to store in the
``.grad`` fields of the model parameters. The
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` is responsible for these aggregations.
:class:`~torchjd.aggregation.Aggregator` is responsible for these aggregations.

When using the :doc:`autogram <../autogram/index>` engine, we rather need to extract a vector
of weights from the Gramian of the Jacobian. The
:class:`~torchjd.aggregation._weighting_bases.Weighting` is responsible for this.
:class:`~torchjd.aggregation.Weighting` is responsible for this.

.. note::
Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights
from this Gramian using a :class:`~torchjd.aggregation._weighting_bases.Weighting`, and then
combining the rows of the Jacobian using these weights. For all of them, we provide both the
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` interface (to be used in autojac) and
the :class:`~torchjd.aggregation._weighting_bases.Weighting` interface (to be used in autogram).
For the rest, we only provide the :class:`~torchjd.aggregation._aggregator_bases.Aggregator`
from this Gramian using a :class:`~torchjd.aggregation.GramianWeighting`, and then combining the
rows of the Jacobian using these weights. For all of them, we provide both the
:class:`~torchjd.aggregation.Aggregator` interface (to be used in autojac) and the
:class:`~torchjd.aggregation.Weighting` interface (to be used in autogram).
For the rest, we only provide the :class:`~torchjd.aggregation.Aggregator`
interface -- they are not compatible with autogram.

:class:`Aggregators <torchjd.aggregation._aggregator_bases.Aggregator>` and :class:`Weightings
<torchjd.aggregation._weighting_bases.Weighting>` are callables that take a Jacobian matrix or a
:class:`Aggregators <torchjd.aggregation.Aggregator>` and
:class:`Weightings <torchjd.aggregation.Weighting>` are callables that take a Jacobian matrix or a
Gramian matrix as inputs, respectively. The following example shows how to use UPGrad to either
aggregate a Jacobian (of shape ``[m, n]``, where ``m`` is the number of objectives and ``n`` is the
number of parameters), or obtain the weights from the Gramian of the Jacobian (of shape ``[m, m]``).
Expand All @@ -39,10 +39,10 @@
When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not
necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian
will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a
:class:`GeneralizedWeighting<torchjd.aggregation._weighting_bases.GeneralizedWeighting>` to extract
:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` to extract
a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest
:class:`GeneralizedWeighting<torchjd.aggregation._weighting_bases.GeneralizedWeighting>` is
:class:`Flattening<torchjd.aggregation._flattening.Flattening>`: it simply "flattens" the
:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` is
:class:`Flattening<torchjd.aggregation.Flattening>`: it simply "flattens" the
generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``),
applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of
weights.
Expand All @@ -59,7 +59,7 @@
[0.1667, 0.1667, 0.1667]])
"""

from ._aggregator_bases import Aggregator
from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator
from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting
from ._config import ConFIG
from ._constant import Constant, ConstantWeighting
Expand All @@ -80,7 +80,7 @@
from ._utils.check_dependencies import (
OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError,
)
from ._weighting_bases import GeneralizedWeighting, Weighting
from ._weighting_bases import GeneralizedWeighting, GramianWeighting, MatrixWeighting, Weighting

__all__ = [
"Aggregator",
Expand All @@ -96,10 +96,13 @@
"GradDrop",
"GradVac",
"GradVacWeighting",
"GramianWeightedAggregator",
"GramianWeighting",
"IMTLG",
"IMTLGWeighting",
"Krum",
"KrumWeighting",
"MatrixWeighting",
"Mean",
"MeanWeighting",
"MGDA",
Expand All @@ -114,6 +117,7 @@
"TrimmedMean",
"UPGrad",
"UPGradWeighting",
"WeightedAggregator",
"Weighting",
]

Expand Down
23 changes: 12 additions & 11 deletions src/torchjd/aggregation/_aggregator_bases.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from abc import ABC, abstractmethod
from typing import cast

from torch import Tensor, nn

from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix
from torchjd._linalg import Matrix, compute_gramian, is_matrix

from ._weighting_bases import Weighting
from ._weighting_bases import GramianWeighting, MatrixWeighting


class Aggregator(nn.Module, ABC):
Expand Down Expand Up @@ -46,18 +47,18 @@ def __str__(self) -> str:

class WeightedAggregator(Aggregator):
"""
Aggregator that combines the rows of the input jacobian matrix with weights given by applying a
Weighting to it.
Aggregator that combines the rows of the input Jacobian matrix with weights given by applying a
:class:`~torchjd.aggregation.MatrixWeighting` to it.

:param weighting: The object responsible for extracting the vector of weights from the matrix.
"""

def __init__(self, weighting: Weighting[Matrix]) -> None:
def __init__(self, weighting: MatrixWeighting) -> None:
super().__init__()
self.weighting = weighting

@staticmethod
def combine(matrix: Matrix, weights: Tensor) -> Tensor:
def _combine(matrix: Matrix, weights: Tensor) -> Tensor:
"""
Aggregates a matrix by making a linear combination of its rows, using the provided vector of
weights.
Expand All @@ -68,19 +69,19 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor:

def forward(self, matrix: Matrix, /) -> Tensor:
weights = self.weighting(matrix)
vector = self.combine(matrix, weights)
vector = self._combine(matrix, weights)
return vector


class GramianWeightedAggregator(WeightedAggregator):
"""
WeightedAggregator that computes the gramian of the input jacobian matrix before applying a
Weighting to it.
:class:`~torchjd.aggregation.WeightedAggregator` that computes the gramian of the input
Jacobian matrix before applying a :class:`~torchjd.aggregation.GramianWeighting` to it.

:param gramian_weighting: The object responsible for extracting the vector of weights from the
gramian.
"""

def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None:
super().__init__(gramian_weighting << compute_gramian)
def __init__(self, gramian_weighting: GramianWeighting) -> None:
super().__init__(cast(MatrixWeighting, gramian_weighting << compute_gramian))
self.gramian_weighting = gramian_weighting
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_aligned_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class AlignedMTLWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.AlignedMTL`.

:param pref_vector: The preference vector to use. If not provided, defaults to
Expand Down Expand Up @@ -89,7 +89,7 @@ def _compute_balance_transformation(

class AlignedMTL(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Independent Component Alignment for Multi-Task Learning
<https://openaccess.thecvf.com/content/CVPR2023/papers/Senushkin_Independent_Component_Alignment_for_Multi-Task_Learning_CVPR_2023_paper.pdf>`_.

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class CAGradWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.CAGrad`.

:param c: The scale of the radius of the ball constraint.
Expand Down Expand Up @@ -94,7 +94,7 @@ def norm_eps(self, value: float) -> None:

class CAGrad(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
<https://arxiv.org/pdf/2110.14048.pdf>`_.

Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class ConFIG(Aggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG:
:class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG:
Towards Conflict-free Training of Physics Informed Neural Networks
<https://arxiv.org/pdf/2408.11104>`_.

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class ConstantWeighting(MatrixWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined
:class:`~torchjd.aggregation.MatrixWeighting` that returns constant, pre-determined
weights.

:param weights: The weights to return at each call.
Expand Down Expand Up @@ -37,7 +37,7 @@ def _check_matrix_shape(self, matrix: Tensor) -> None:

class Constant(WeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of
:class:`~torchjd.aggregation.WeightedAggregator` that makes a linear combination of
the rows of the provided matrix, with constant, pre-determined weights.

:param weights: The weights associated to the rows of the input matrices.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class DualProjWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.DualProj`.

:param pref_vector: The preference vector to use. If not provided, defaults to
Expand Down Expand Up @@ -78,7 +78,7 @@ def reg_eps(self, value: float) -> None:

class DualProj(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
:class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
to the solution to Equation 11 of `Gradient Episodic Memory for Continual Learning
<https://proceedings.neurips.cc/paper/2017/file/f87522788a2be2d171666752f97ddebb-Paper.pdf>`_.
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _identity(P: Tensor) -> Tensor:

class GradDrop(Aggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination
:class:`~torchjd.aggregation.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
Optimizing Deep Multitask Models with Gradient Sign Dropout
<https://arxiv.org/pdf/2010.06808.pdf>`_.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
class GradVacWeighting(GramianWeighting, Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.GradVac`.

All required quantities (gradient norms, cosine similarities, and their updates after the
Expand Down Expand Up @@ -131,7 +131,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
class GradVac(GramianWeightedAggregator, Stateful):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task
Optimization in Massively Multilingual Models (ICLR 2021 Spotlight)
<https://openreview.net/forum?id=F1vEjWK-lH_>`_.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class IMTLGWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.IMTLG`.
"""

Expand All @@ -26,7 +26,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:

class IMTLG(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in
:class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
`Towards Impartial Multi-task Learning <https://discovery.ucl.ac.uk/id/eprint/10120667/>`_.
This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization
<https://arxiv.org/pdf/2406.16232>`_, supports matrices with some linearly dependant rows.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_krum.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

class KrumWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.Krum`.

:param n_byzantine: The number of rows of the input matrix that can come from an adversarial
Expand Down Expand Up @@ -81,7 +81,7 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None:

class Krum(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning,
:class:`~torchjd.aggregation.GramianWeightedAggregator` for adversarial federated learning,
as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent
<https://proceedings.neurips.cc/paper/2017/file/f4b9ec30ad9f68f89b29639786cb62ef-Paper.pdf>`_.

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

class MeanWeighting(MatrixWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights
:class:`~torchjd.aggregation.MatrixWeighting` that gives the weights
:math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in
\mathbb{R}^m`.
"""
Expand All @@ -22,7 +22,7 @@ def forward(self, matrix: Tensor, /) -> Tensor:

class Mean(WeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input
:class:`~torchjd.aggregation.WeightedAggregator` that averages the rows of the input
matrices.
"""

Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_mgda.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class MGDAWeighting(GramianWeighting):
r"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.MGDA`.

:param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization.
Expand Down Expand Up @@ -74,7 +74,7 @@ def max_iters(self, value: int) -> None:

class MGDA(GramianWeightedAggregator):
r"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation
:class:`~torchjd.aggregation.GramianWeightedAggregator` performing the gradient aggregation
step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization
<https://comptes-rendus.academie-sciences.fr/mathematique/articles/10.1016/j.crma.2012.03.014/>`_.
The implementation is based on Algorithm 2 of `Multi-Task Learning as Multi-Objective
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

class _NashMTLWeighting(MatrixWeighting, Stateful):
"""
:class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that
:class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.MatrixWeighting` that
extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining
Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

Expand Down Expand Up @@ -201,7 +201,7 @@ def reset(self) -> None:
class NashMTL(WeightedAggregator, Stateful):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
`Multi-Task Learning as a Bargaining Game <https://arxiv.org/pdf/2202.01017.pdf>`_.

:param n_tasks: The number of tasks, corresponding to the number of rows in the provided
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

class PCGradWeighting(GramianWeighting):
"""
:class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of
:class:`~torchjd.aggregation.GramianWeighting` giving the weights of
:class:`~torchjd.aggregation.PCGrad`.
"""

Expand Down Expand Up @@ -48,7 +48,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:

class PCGrad(GramianWeightedAggregator):
"""
:class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in algorithm 1 of
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
"""

Expand Down
Loading
Loading