diff --git a/CHANGELOG.md b/CHANGELOG.md index fe52e1fec..fe6f140bf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,8 @@ changelog does not include internal changes that do not affect the user. ### Added +- Made `WeightedAggregator`, `GramianWeightedAggregator`, `MatrixWeighting`, and `GramianWeighting` + public. These abstract base classes are now importable from `torchjd.aggregation` and documented. - Added getters and setters for the constructor parameters of all aggregators and weightings, so that they can be changed after initialization. This includes: `pref_vector`, `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`; diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 4d62f820c..04a8de666 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -10,9 +10,21 @@ Abstract base classes .. autoclass:: torchjd.aggregation.Aggregator :members: __call__ +.. autoclass:: torchjd.aggregation.WeightedAggregator + :members: __call__ + +.. autoclass:: torchjd.aggregation.GramianWeightedAggregator + :members: __call__ + .. autoclass:: torchjd.aggregation.Weighting :members: __call__ +.. autoclass:: torchjd.aggregation.MatrixWeighting + :members: __call__ + +.. autoclass:: torchjd.aggregation.GramianWeighting + :members: __call__ + .. autoclass:: torchjd.aggregation.GeneralizedWeighting :members: __call__ diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index 400cfe270..ec871e899 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -1,23 +1,23 @@ r""" When doing Jacobian descent, the Jacobian matrix has to be aggregated into a vector to store in the ``.grad`` fields of the model parameters. The -:class:`~torchjd.aggregation._aggregator_bases.Aggregator` is responsible for these aggregations. 
+:class:`~torchjd.aggregation.Aggregator` is responsible for these aggregations. When using the :doc:`autogram <../autogram/index>` engine, we rather need to extract a vector of weights from the Gramian of the Jacobian. The -:class:`~torchjd.aggregation._weighting_bases.Weighting` is responsible for this. +:class:`~torchjd.aggregation.Weighting` is responsible for this. .. note:: Most aggregators rely on computing the Gramian of the Jacobian, extracting a vector of weights - from this Gramian using a :class:`~torchjd.aggregation._weighting_bases.Weighting`, and then - combining the rows of the Jacobian using these weights. For all of them, we provide both the - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` interface (to be used in autojac) and - the :class:`~torchjd.aggregation._weighting_bases.Weighting` interface (to be used in autogram). - For the rest, we only provide the :class:`~torchjd.aggregation._aggregator_bases.Aggregator` + from this Gramian using a :class:`~torchjd.aggregation.GramianWeighting`, and then combining the + rows of the Jacobian using these weights. For all of them, we provide both the + :class:`~torchjd.aggregation.Aggregator` interface (to be used in autojac) and the + :class:`~torchjd.aggregation.Weighting` interface (to be used in autogram). + For the rest, we only provide the :class:`~torchjd.aggregation.Aggregator` interface -- they are not compatible with autogram. -:class:`Aggregators <torchjd.aggregation._aggregator_bases.Aggregator>` and :class:`Weightings -<torchjd.aggregation._weighting_bases.Weighting>` are callables that take a Jacobian matrix or a +:class:`Aggregators <torchjd.aggregation.Aggregator>` and +:class:`Weightings <torchjd.aggregation.Weighting>` are callables that take a Jacobian matrix or a Gramian matrix as inputs, respectively. The following example shows how to use UPGrad to either aggregate a Jacobian (of shape ``[m, n]``, where ``m`` is the number of objectives and ``n`` is the number of parameters), or obtain the weights from the Gramian of the Jacobian (of shape ``[m, m]``). 
@@ -39,10 +39,10 @@ When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a -:class:`GeneralizedWeighting<torchjd.aggregation._weighting_bases.GeneralizedWeighting>` to extract +:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` to extract a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest -:class:`GeneralizedWeighting<torchjd.aggregation._weighting_bases.GeneralizedWeighting>` is -:class:`Flattening<torchjd.aggregation._flattening.Flattening>`: it simply "flattens" the +:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` is +:class:`Flattening<torchjd.aggregation.Flattening>`: it simply "flattens" the generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``), applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of weights. @@ -59,7 +59,7 @@ [0.1667, 0.1667, 0.1667]]) """ -from ._aggregator_bases import Aggregator +from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator from ._aligned_mtl import AlignedMTL, AlignedMTLWeighting from ._config import ConFIG from ._constant import Constant, ConstantWeighting @@ -80,7 +80,7 @@ from ._utils.check_dependencies import ( OptionalDepsNotInstalledError as _OptionalDepsNotInstalledError, ) -from ._weighting_bases import GeneralizedWeighting, Weighting +from ._weighting_bases import GeneralizedWeighting, GramianWeighting, MatrixWeighting, Weighting __all__ = [ "Aggregator", @@ -96,10 +96,13 @@ "GradDrop", "GradVac", "GradVacWeighting", + "GramianWeightedAggregator", + "GramianWeighting", "IMTLG", "IMTLGWeighting", "Krum", "KrumWeighting", + "MatrixWeighting", "Mean", "MeanWeighting", "MGDA", @@ -114,6 +117,7 @@ "TrimmedMean", "UPGrad", "UPGradWeighting", + "WeightedAggregator", "Weighting", ] diff --git a/src/torchjd/aggregation/_aggregator_bases.py b/src/torchjd/aggregation/_aggregator_bases.py index 2ed1505ef..7372b0027 100644 --- 
a/src/torchjd/aggregation/_aggregator_bases.py +++ b/src/torchjd/aggregation/_aggregator_bases.py @@ -1,10 +1,11 @@ from abc import ABC, abstractmethod +from typing import cast from torch import Tensor, nn -from torchjd._linalg import Matrix, PSDMatrix, compute_gramian, is_matrix +from torchjd._linalg import Matrix, compute_gramian, is_matrix -from ._weighting_bases import Weighting +from ._weighting_bases import GramianWeighting, MatrixWeighting class Aggregator(nn.Module, ABC): @@ -46,18 +47,18 @@ def __str__(self) -> str: class WeightedAggregator(Aggregator): """ - Aggregator that combines the rows of the input jacobian matrix with weights given by applying a - Weighting to it. + Aggregator that combines the rows of the input Jacobian matrix with weights given by applying a + :class:`~torchjd.aggregation.MatrixWeighting` to it. :param weighting: The object responsible for extracting the vector of weights from the matrix. """ - def __init__(self, weighting: Weighting[Matrix]) -> None: + def __init__(self, weighting: MatrixWeighting) -> None: super().__init__() self.weighting = weighting @staticmethod - def combine(matrix: Matrix, weights: Tensor) -> Tensor: + def _combine(matrix: Matrix, weights: Tensor) -> Tensor: """ Aggregates a matrix by making a linear combination of its rows, using the provided vector of weights. @@ -68,19 +69,19 @@ def combine(matrix: Matrix, weights: Tensor) -> Tensor: def forward(self, matrix: Matrix, /) -> Tensor: weights = self.weighting(matrix) - vector = self.combine(matrix, weights) + vector = self._combine(matrix, weights) return vector class GramianWeightedAggregator(WeightedAggregator): """ - WeightedAggregator that computes the gramian of the input jacobian matrix before applying a - Weighting to it. + :class:`~torchjd.aggregation.WeightedAggregator` that computes the gramian of the input + Jacobian matrix before applying a :class:`~torchjd.aggregation.GramianWeighting` to it. 
:param gramian_weighting: The object responsible for extracting the vector of weights from the gramian. """ - def __init__(self, gramian_weighting: Weighting[PSDMatrix]) -> None: - super().__init__(gramian_weighting << compute_gramian) + def __init__(self, gramian_weighting: GramianWeighting) -> None: + super().__init__(cast(MatrixWeighting, gramian_weighting << compute_gramian)) self.gramian_weighting = gramian_weighting diff --git a/src/torchjd/aggregation/_aligned_mtl.py b/src/torchjd/aggregation/_aligned_mtl.py index b0fbe3b42..5ff9e5ed5 100644 --- a/src/torchjd/aggregation/_aligned_mtl.py +++ b/src/torchjd/aggregation/_aligned_mtl.py @@ -19,7 +19,7 @@ class AlignedMTLWeighting(GramianWeighting): r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.AlignedMTL`. :param pref_vector: The preference vector to use. If not provided, defaults to @@ -89,7 +89,7 @@ def _compute_balance_transformation( class AlignedMTL(GramianWeightedAggregator): r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of + :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of `Independent Component Alignment for Multi-Task Learning `_. diff --git a/src/torchjd/aggregation/_cagrad.py b/src/torchjd/aggregation/_cagrad.py index 40ce85157..a0b09cb1c 100644 --- a/src/torchjd/aggregation/_cagrad.py +++ b/src/torchjd/aggregation/_cagrad.py @@ -20,7 +20,7 @@ class CAGradWeighting(GramianWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.CAGrad`. :param c: The scale of the radius of the ball constraint. 
@@ -94,7 +94,7 @@ def norm_eps(self, value: float) -> None: class CAGrad(GramianWeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Algorithm 1 of + :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of `Conflict-Averse Gradient Descent for Multi-task Learning `_. diff --git a/src/torchjd/aggregation/_config.py b/src/torchjd/aggregation/_config.py index f19c4023a..10fbd9986 100644 --- a/src/torchjd/aggregation/_config.py +++ b/src/torchjd/aggregation/_config.py @@ -15,7 +15,7 @@ class ConFIG(Aggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in Equation 2 of `ConFIG: + :class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG: Towards Conflict-free Training of Physics Informed Neural Networks `_. diff --git a/src/torchjd/aggregation/_constant.py b/src/torchjd/aggregation/_constant.py index 8b0f73075..91a639ba6 100644 --- a/src/torchjd/aggregation/_constant.py +++ b/src/torchjd/aggregation/_constant.py @@ -7,7 +7,7 @@ class ConstantWeighting(MatrixWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` that returns constant, pre-determined + :class:`~torchjd.aggregation.MatrixWeighting` that returns constant, pre-determined weights. :param weights: The weights to return at each call. @@ -37,7 +37,7 @@ def _check_matrix_shape(self, matrix: Tensor) -> None: class Constant(WeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that makes a linear combination of + :class:`~torchjd.aggregation.WeightedAggregator` that makes a linear combination of the rows of the provided matrix, with constant, pre-determined weights. :param weights: The weights associated to the rows of the input matrices. 
diff --git a/src/torchjd/aggregation/_dualproj.py b/src/torchjd/aggregation/_dualproj.py index acb87d2fb..a9167648c 100644 --- a/src/torchjd/aggregation/_dualproj.py +++ b/src/torchjd/aggregation/_dualproj.py @@ -12,7 +12,7 @@ class DualProjWeighting(GramianWeighting): r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.DualProj`. :param pref_vector: The preference vector to use. If not provided, defaults to @@ -78,7 +78,7 @@ def reg_eps(self, value: float) -> None: class DualProj(GramianWeightedAggregator): r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input + :class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds to the solution to Equation 11 of `Gradient Episodic Memory for Continual Learning `_. diff --git a/src/torchjd/aggregation/_graddrop.py b/src/torchjd/aggregation/_graddrop.py index fc67810da..81ebf8176 100644 --- a/src/torchjd/aggregation/_graddrop.py +++ b/src/torchjd/aggregation/_graddrop.py @@ -15,7 +15,7 @@ def _identity(P: Tensor) -> Tensor: class GradDrop(Aggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that applies the gradient combination + :class:`~torchjd.aggregation.Aggregator` that applies the gradient combination steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign: Optimizing Deep Multitask Models with Gradient Sign Dropout `_. 
diff --git a/src/torchjd/aggregation/_gradvac.py b/src/torchjd/aggregation/_gradvac.py index 26a075b90..b791f6808 100644 --- a/src/torchjd/aggregation/_gradvac.py +++ b/src/torchjd/aggregation/_gradvac.py @@ -16,7 +16,7 @@ class GradVacWeighting(GramianWeighting, Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.GradVac`. All required quantities (gradient norms, cosine similarities, and their updates after the @@ -131,7 +131,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None: class GradVac(GramianWeightedAggregator, Stateful): r""" :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` implementing the aggregation step of + :class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of Gradient Vaccine (GradVac) from `Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models (ICLR 2021 Spotlight) `_. diff --git a/src/torchjd/aggregation/_imtl_g.py b/src/torchjd/aggregation/_imtl_g.py index 672f4d515..a53085be7 100644 --- a/src/torchjd/aggregation/_imtl_g.py +++ b/src/torchjd/aggregation/_imtl_g.py @@ -10,7 +10,7 @@ class IMTLGWeighting(GramianWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.IMTLG`. """ @@ -26,7 +26,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: class IMTLG(GramianWeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` generalizing the method described in + :class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in `Towards Impartial Multi-task Learning `_. 
This generalization, defined formally in `Jacobian Descent For Multi-Objective Optimization `_, supports matrices with some linearly dependant rows. diff --git a/src/torchjd/aggregation/_krum.py b/src/torchjd/aggregation/_krum.py index 046c0a461..db268e0c7 100644 --- a/src/torchjd/aggregation/_krum.py +++ b/src/torchjd/aggregation/_krum.py @@ -10,7 +10,7 @@ class KrumWeighting(GramianWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.Krum`. :param n_byzantine: The number of rows of the input matrix that can come from an adversarial @@ -81,7 +81,7 @@ def _check_matrix_shape(self, gramian: PSDMatrix) -> None: class Krum(GramianWeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, + :class:`~torchjd.aggregation.GramianWeightedAggregator` for adversarial federated learning, as defined in `Machine Learning with Adversaries: Byzantine Tolerant Gradient Descent `_. diff --git a/src/torchjd/aggregation/_mean.py b/src/torchjd/aggregation/_mean.py index 13a72649a..e8b75e7cc 100644 --- a/src/torchjd/aggregation/_mean.py +++ b/src/torchjd/aggregation/_mean.py @@ -7,7 +7,7 @@ class MeanWeighting(MatrixWeighting): r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights + :class:`~torchjd.aggregation.MatrixWeighting` that gives the weights :math:`\begin{bmatrix} \frac{1}{m} & \dots & \frac{1}{m} \end{bmatrix}^T \in \mathbb{R}^m`. """ @@ -22,7 +22,7 @@ def forward(self, matrix: Tensor, /) -> Tensor: class Mean(WeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that averages the rows of the input + :class:`~torchjd.aggregation.WeightedAggregator` that averages the rows of the input matrices. 
""" diff --git a/src/torchjd/aggregation/_mgda.py b/src/torchjd/aggregation/_mgda.py index 33c727aa1..ec9d0afc3 100644 --- a/src/torchjd/aggregation/_mgda.py +++ b/src/torchjd/aggregation/_mgda.py @@ -9,7 +9,7 @@ class MGDAWeighting(GramianWeighting): r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.MGDA`. :param epsilon: The value of :math:`\hat{\gamma}` below which we stop the optimization. @@ -74,7 +74,7 @@ def max_iters(self, value: int) -> None: class MGDA(GramianWeightedAggregator): r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` performing the gradient aggregation + :class:`~torchjd.aggregation.GramianWeightedAggregator` performing the gradient aggregation step of `Multiple-gradient descent algorithm (MGDA) for multiobjective optimization `_. The implementation is based on Algorithm 2 of `Multi-Task Learning as Multi-Objective diff --git a/src/torchjd/aggregation/_nash_mtl.py b/src/torchjd/aggregation/_nash_mtl.py index 65604d358..5be55afd5 100644 --- a/src/torchjd/aggregation/_nash_mtl.py +++ b/src/torchjd/aggregation/_nash_mtl.py @@ -20,7 +20,7 @@ class _NashMTLWeighting(MatrixWeighting, Stateful): """ - :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.Weighting` that + :class:`~torchjd.aggregation._mixins.Stateful` :class:`~torchjd.aggregation.MatrixWeighting` that extracts weights using the step decision of Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. @@ -201,7 +201,7 @@ def reset(self) -> None: class NashMTL(WeightedAggregator, Stateful): """ :class:`~torchjd.aggregation._mixins.Stateful` - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as proposed in Algorithm 1 of + :class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of `Multi-Task Learning as a Bargaining Game `_. 
:param n_tasks: The number of tasks, corresponding to the number of rows in the provided diff --git a/src/torchjd/aggregation/_pcgrad.py b/src/torchjd/aggregation/_pcgrad.py index dd965ff77..25e244522 100644 --- a/src/torchjd/aggregation/_pcgrad.py +++ b/src/torchjd/aggregation/_pcgrad.py @@ -12,7 +12,7 @@ class PCGradWeighting(GramianWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.PCGrad`. """ @@ -48,7 +48,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor: class PCGrad(GramianWeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` as defined in algorithm 1 of + :class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in algorithm 1 of `Gradient Surgery for Multi-Task Learning `_. """ diff --git a/src/torchjd/aggregation/_random.py b/src/torchjd/aggregation/_random.py index ca32d601a..d20d54db1 100644 --- a/src/torchjd/aggregation/_random.py +++ b/src/torchjd/aggregation/_random.py @@ -8,7 +8,7 @@ class RandomWeighting(MatrixWeighting): """ - :class:`~torchjd.aggregation._weighting_bases.Weighting` that generates positive random weights + :class:`~torchjd.aggregation.MatrixWeighting` that generates positive random weights at each call. """ @@ -20,7 +20,7 @@ def forward(self, matrix: Tensor, /) -> Tensor: class Random(WeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that computes a random combination of + :class:`~torchjd.aggregation.WeightedAggregator` that computes a random combination of the rows of the provided matrices, as defined in algorithm 2 of `Reasonable Effectiveness of Random Weighting: A Litmus Test for Multi-Task Learning `_. 
diff --git a/src/torchjd/aggregation/_sum.py b/src/torchjd/aggregation/_sum.py index 7e6beb55c..1a48ef41f 100644 --- a/src/torchjd/aggregation/_sum.py +++ b/src/torchjd/aggregation/_sum.py @@ -7,7 +7,7 @@ class SumWeighting(MatrixWeighting): r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` that gives the weights + :class:`~torchjd.aggregation.MatrixWeighting` that gives the weights :math:`\begin{bmatrix} 1 & \dots & 1 \end{bmatrix}^T \in \mathbb{R}^m`. """ @@ -20,7 +20,7 @@ def forward(self, matrix: Tensor, /) -> Tensor: class Sum(WeightedAggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that sums of the rows of the input + :class:`~torchjd.aggregation.WeightedAggregator` that sums the rows of the input matrices. """ diff --git a/src/torchjd/aggregation/_trimmed_mean.py b/src/torchjd/aggregation/_trimmed_mean.py index 013e8e579..6a3edb848 100644 --- a/src/torchjd/aggregation/_trimmed_mean.py +++ b/src/torchjd/aggregation/_trimmed_mean.py @@ -6,7 +6,7 @@ class TrimmedMean(Aggregator): """ - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` for adversarial federated learning, + :class:`~torchjd.aggregation.Aggregator` for adversarial federated learning, that trims the most extreme values of the input matrix, before averaging its rows, as defined in `Byzantine-Robust Distributed Learning: Towards Optimal Statistical Rates `_. diff --git a/src/torchjd/aggregation/_upgrad.py b/src/torchjd/aggregation/_upgrad.py index 686898297..731986843 100644 --- a/src/torchjd/aggregation/_upgrad.py +++ b/src/torchjd/aggregation/_upgrad.py @@ -13,7 +13,7 @@ class UPGradWeighting(GramianWeighting): r""" - :class:`~torchjd.aggregation._weighting_bases.Weighting` giving the weights of + :class:`~torchjd.aggregation.GramianWeighting` giving the weights of :class:`~torchjd.aggregation.UPGrad`. :param pref_vector: The preference vector to use. 
If not provided, defaults to @@ -81,7 +81,7 @@ def reg_eps(self, value: float) -> None: class UPGrad(GramianWeightedAggregator): r""" - :class:`~torchjd.aggregation._aggregator_bases.Aggregator` that projects each row of the input + :class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed in `Jacobian Descent For Multi-Objective Optimization `_. diff --git a/src/torchjd/autojac/_jac_to_grad.py b/src/torchjd/autojac/_jac_to_grad.py index 03ab3b606..bcebcb39d 100644 --- a/src/torchjd/autojac/_jac_to_grad.py +++ b/src/torchjd/autojac/_jac_to_grad.py @@ -6,8 +6,12 @@ from torch import Tensor, nn from torchjd._linalg import Matrix, PSDMatrix, compute_gramian -from torchjd.aggregation import Aggregator, Weighting -from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator +from torchjd.aggregation import ( + Aggregator, + GramianWeightedAggregator, + WeightedAggregator, + Weighting, +) from ._accumulation import TensorWithJac, accumulate_grads, is_tensor_with_jac from ._utils import check_consistent_first_dimension @@ -59,8 +63,8 @@ def jac_to_grad( :param tensors: The tensors whose ``.jac`` fields should be aggregated. All Jacobians must have the same first dimension (e.g. number of losses). :param aggregator: The aggregator used to reduce the Jacobians into gradients. If it uses a - :class:`Weighting <torchjd.aggregation._weighting_bases.Weighting>` to combine the rows of - the Jacobians, ``jac_to_grad`` will also return the computed weights. + :class:`~torchjd.aggregation.Weighting` to combine the rows of the Jacobians, + ``jac_to_grad`` will also return the computed weights. :param retain_jac: Whether to preserve the ``.jac`` fields of the tensors after they have been used. Defaults to ``False``. 
:param optimize_gramian_computation: When the ``aggregator`` computes weights based on the diff --git a/tests/unit/autojac/test_jac_to_grad.py b/tests/unit/autojac/test_jac_to_grad.py index 91278f07d..b4ca0e0be 100644 --- a/tests/unit/autojac/test_jac_to_grad.py +++ b/tests/unit/autojac/test_jac_to_grad.py @@ -15,6 +15,7 @@ Constant, DualProj, GradDrop, + GramianWeightedAggregator, Krum, Mean, PCGrad, @@ -22,8 +23,8 @@ Sum, TrimmedMean, UPGrad, + WeightedAggregator, ) -from torchjd.aggregation._aggregator_bases import GramianWeightedAggregator, WeightedAggregator from torchjd.autojac._jac_to_grad import ( _can_skip_jacobian_combination, _has_forward_hook,