Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_cagrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class CAGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
class CAGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable):
_REQUIRED_DEPS = ["numpy", "cvxpy", "clarabel"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[cagrad]"'
"""
Expand Down Expand Up @@ -94,7 +94,7 @@ def norm_eps(self, value: float) -> None:
self._norm_eps = value


class CAGrad(_NonDifferentiable, GramianWeightedAggregator):
class CAGrad(GramianWeightedAggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Conflict-Averse Gradient Descent for Multi-task Learning
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


# Non-differentiable: the pseudoinverse and the normalization are not differentiable in this context.
class ConFIG(_NonDifferentiable, Aggregator):
class ConFIG(Aggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Aggregator` as defined in Equation 2 of `ConFIG:
Towards Conflict-free Training of Physics Informed Neural Networks
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_dualproj.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
class DualProjWeighting(_NonDifferentiable, _GramianWeighting):
class DualProjWeighting(_GramianWeighting, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.DualProj`.
Expand Down Expand Up @@ -53,7 +53,7 @@ def projector(self, value: DualConeProjector | None) -> None:
self._projector = projector_or_default(value)


class DualProj(_NonDifferentiable, GramianWeightedAggregator):
class DualProj(GramianWeightedAggregator, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that averages the rows of the input
matrix, and projects the result onto the dual cone of the rows of the matrix. This corresponds
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_fairgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


# Non-differentiable: the scipy solver operates on numpy arrays, breaking the autograd graph.
class FairGradWeighting(_WithOptionalDeps, _NonDifferentiable, _GramianWeighting):
class FairGradWeighting(_WithOptionalDeps, _GramianWeighting, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`] giving the
weights of :class:`~torchjd.aggregation.FairGrad`, as defined in Equation 4 of `Fair Resource
Expand Down Expand Up @@ -78,7 +78,7 @@ def alpha(self, value: float) -> None:
self._alpha = value


class FairGrad(_NonDifferentiable, GramianWeightedAggregator):
class FairGrad(GramianWeightedAggregator, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` using the step decision of Algorithm 1
of `Fair Resource Allocation in Multi-Task Learning
Expand Down
2 changes: 1 addition & 1 deletion src/torchjd/aggregation/_graddrop.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ def _identity(P: Tensor) -> Tensor:


# Non-differentiable: the sign-based random masking is not differentiable.
class GradDrop(_NonDifferentiable, Aggregator):
class GradDrop(Aggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Aggregator` that applies the gradient combination
steps from GradDrop, as defined in lines 10 to 15 of Algorithm 1 of `Just Pick a Sign:
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_gradvac.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@


# Non-differentiable: weights are modified in-place during the gradient correction loop.
class GradVacWeighting(_NonDifferentiable, Stateful, _GramianWeighting):
class GradVacWeighting(_GramianWeighting, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
Expand Down Expand Up @@ -128,7 +128,7 @@ def _ensure_state(self, m: int, dtype: torch.dtype) -> None:
self._state_key = key


class GradVac(_NonDifferentiable, Stateful, GramianWeightedAggregator):
class GradVac(GramianWeightedAggregator, Stateful, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.GramianWeightedAggregator` implementing the aggregation step of
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_imtl_g.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@


# Non-differentiable: differentiating through pinv(gramian) would give incorrect gradients.
class IMTLGWeighting(_NonDifferentiable, _GramianWeighting):
class IMTLGWeighting(_GramianWeighting, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.IMTLG`.
Expand All @@ -25,7 +25,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights


class IMTLG(_NonDifferentiable, GramianWeightedAggregator):
class IMTLG(GramianWeightedAggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` generalizing the method described in
`Towards Impartial Multi-task Learning <https://discovery.ucl.ac.uk/id/eprint/10120667/>`_.
Expand Down
5 changes: 2 additions & 3 deletions src/torchjd/aggregation/_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,8 @@ class _NonDifferentiable(nn.Module):
the call in :func:`torch.no_grad`.

.. warning::
This mixin must appear **before** any :class:`torch.nn.Module` base class in the inheritance
list. Placing it after will silently have no effect, because :meth:`__call__` would be
resolved to :class:`torch.nn.Module` before reaching this mixin.
Place this mixin *after* the primary base class in the inheritance list. Placing it *before*
the primary base causes the mixin's :meth:`__call__` signature to shadow the primary class's
:meth:`__call__` signature in generated documentation.
"""

def __call__(self, *args: Any, **kwargs: Any) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_nash_mtl.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@


# Non-differentiable: the cvxpy solver operates on numpy arrays, breaking the autograd graph.
class _NashMTLWeighting(_WithOptionalDeps, _NonDifferentiable, Stateful, _MatrixWeighting):
class _NashMTLWeighting(_WithOptionalDeps, _MatrixWeighting, Stateful, _NonDifferentiable):
_REQUIRED_DEPS = ["numpy", "cvxpy", "ecos"]
_INSTALL_HINT = 'Install them with: pip install "torchjd[nash_mtl]"'
"""
Expand Down Expand Up @@ -204,7 +204,7 @@ def reset(self) -> None:
self.prvs_alpha = np.ones(self.n_tasks, dtype=np.float32)


class NashMTL(_NonDifferentiable, Stateful, WeightedAggregator):
class NashMTL(WeightedAggregator, Stateful, _NonDifferentiable):
"""
:class:`~torchjd.aggregation._mixins.Stateful`
:class:`~torchjd.aggregation.WeightedAggregator` as proposed in Algorithm 1 of
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_pcgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@


# Non-differentiable: weights are modified in-place during the gradient projection loop.
class PCGradWeighting(_NonDifferentiable, _GramianWeighting):
class PCGradWeighting(_GramianWeighting, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.PCGrad`.
Expand Down Expand Up @@ -47,7 +47,7 @@ def forward(self, gramian: PSDMatrix, /) -> Tensor:
return weights.to(device)


class PCGrad(_NonDifferentiable, GramianWeightedAggregator):
class PCGrad(GramianWeightedAggregator, _NonDifferentiable):
"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` as defined in Algorithm 1 of
`Gradient Surgery for Multi-Task Learning <https://arxiv.org/pdf/2001.06782.pdf>`_.
Expand Down
4 changes: 2 additions & 2 deletions src/torchjd/aggregation/_upgrad.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@


# Non-differentiable: the QP solver operates on numpy arrays, breaking the autograd graph.
class UPGradWeighting(_NonDifferentiable, _GramianWeighting):
class UPGradWeighting(_GramianWeighting, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.Weighting` [:class:`~torchjd.linalg.PSDMatrix`]
giving the weights of :class:`~torchjd.aggregation.UPGrad`.
Expand Down Expand Up @@ -54,7 +54,7 @@ def projector(self, value: DualConeProjector | None) -> None:
self._projector = projector_or_default(value)


class UPGrad(_NonDifferentiable, GramianWeightedAggregator):
class UPGrad(GramianWeightedAggregator, _NonDifferentiable):
r"""
:class:`~torchjd.aggregation.GramianWeightedAggregator` that projects each row of the input
matrix onto the dual cone of all rows of this matrix, and that combines the result, as proposed
Expand Down
Loading