23 changes: 23 additions & 0 deletions CHANGELOG.md
Expand Up @@ -20,6 +20,29 @@ changelog does not include internal changes that do not affect the user.
(installed with `pip install torchjd`) much lighter, but it means that users of `UPGrad` and
`DualProj` now have to install the new optional dependency group `quadprog_projector` explicitly
(with e.g. `pip install "torchjd[quadprog_projector]"`).
- **BREAKING**: Removed the concept of generalized Gramians entirely. The `Engine.compute_gramian`
method now always returns a square matrix of shape `[m, m]`, where `m` is the total number of
elements of the `output` tensor (all dimensions are treated uniformly). Previously, an output of
shape `[m_1, m_2]` produced a 4D generalized Gramian of shape `[m_1, m_2, m_2, m_1]`; it now
produces a matrix of shape `[m_1 * m_2, m_1 * m_2]`.
This also removes `GeneralizedWeighting` and `Flattening`.
To update, replace `Flattening(weighting)` with `weighting` itself (a standard `Weighting`) and
reshape the resulting weight vector yourself:
```python
# Before
from torchjd.aggregation import Flattening, UPGradWeighting
weighting = Flattening(UPGradWeighting())
gramian = engine.compute_gramian(losses) # shape: [m1, m2, m2, m1]
weights = weighting(gramian) # shape: [m1, m2]
losses.backward(weights)

# After
from torchjd.aggregation import UPGradWeighting
weighting = UPGradWeighting()
gramian = engine.compute_gramian(losses) # shape: [m1 * m2, m1 * m2]
weights = weighting(gramian).reshape(losses.shape) # shape: [m1, m2]
losses.backward(weights)
```
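The relationship between the two layouts can be sketched in plain PyTorch, without `torchjd`. This is only an illustration under stated assumptions, not the library's implementation: the `jacobian` tensor is a made-up stand-in, and the old 4D layout is assumed to store the inner product of `jacobian[i, j]` and `jacobian[l, k]` at index `[i, j, k, l]`, which matches the `[m_1, m_2, m_2, m_1]` shape described above.

```python
import torch

torch.manual_seed(0)
m1, m2, n = 2, 3, 5
jacobian = torch.randn(m1, m2, n)  # hypothetical Jacobian of an [m1, m2] output

# Old layout (assumed): generalized[i, j, k, l] = <jacobian[i, j], jacobian[l, k]>
generalized = torch.einsum("ijn,lkn->ijkl", jacobian, jacobian)  # [m1, m2, m2, m1]

# New layout: flatten the output dimensions, then take an ordinary Gramian.
flat = jacobian.reshape(m1 * m2, n)
gramian = flat @ flat.T  # [m1 * m2, m1 * m2]

# Both tensors hold the same inner products, only arranged differently.
recovered = generalized.permute(0, 1, 3, 2).reshape(m1 * m2, m1 * m2)
assert torch.allclose(recovered, gramian, atol=1e-5)
```

Under that assumption, no information is lost by the new layout: the square matrix is a rearrangement of the former 4D tensor.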

## [0.11.0] - 2026-05-18

7 changes: 0 additions & 7 deletions docs/source/docs/aggregation/flattening.rst

This file was deleted.

4 changes: 0 additions & 4 deletions docs/source/docs/aggregation/index.rst
Expand Up @@ -19,9 +19,6 @@ Abstract base classes
.. autoclass:: torchjd.aggregation.Weighting
:members: __call__

.. autoclass:: torchjd.aggregation.GeneralizedWeighting
:members: __call__

.. autoclass:: torchjd.aggregation.Stateful
:members: reset

Expand All @@ -38,7 +35,6 @@ Abstract base classes
cr_mogm.rst
dualproj.rst
fairgrad.rst
flattening.rst
graddrop.rst
gradvac.rst
imtl_g.rst
21 changes: 10 additions & 11 deletions docs/source/examples/iwmtl.rst
Expand Up @@ -16,7 +16,7 @@ The following example shows how to do that.
from torch.nn import Linear, MSELoss, ReLU, Sequential
from torch.optim import SGD

from torchjd.aggregation import Flattening, UPGradWeighting
from torchjd.aggregation import UPGradWeighting
from torchjd.autogram import Engine

shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
Expand All @@ -30,7 +30,7 @@ The following example shows how to do that.

optimizer = SGD(params, lr=0.1)
mse = MSELoss(reduction="none")
weighting = Flattening(UPGradWeighting())
weighting = UPGradWeighting()
engine = Engine(shared_module, batch_dim=0)

inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10
Expand All @@ -46,20 +46,19 @@ The following example shows how to do that.
losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2]

# Compute the gramian (inner products between pairs of gradients of the losses)
gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16]
gramian = engine.compute_gramian(losses) # shape: [32, 32]

# Obtain the weights that lead to no conflict between reweighted gradients
weights = weighting(gramian) # shape: [16, 2]
weights = weighting(gramian) # shape: [32]

# Do the standard backward pass, but weighted using the obtained weights
losses.backward(weights)
losses.backward(weights.reshape(losses.shape))
optimizer.step()
optimizer.zero_grad()

.. note::
In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a
4D tensor rather than a matrix, and a
:class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as
:class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of
weights from it. More information about ``GeneralizedWeighting`` can be found in the
:doc:`../../docs/aggregation/index` page.
In this example, the tensor of losses is a matrix of shape ``[16, 2]`` (16 samples, 2 tasks).
The autogram engine flattens this into a vector of ``m = 16 × 2 = 32`` objectives, so the
Gramian has shape ``[32, 32]``. A standard :class:`~torchjd.aggregation.Weighting` is then used
to extract a vector of 32 weights, which is reshaped back to ``[16, 2]`` before being passed to
:meth:`~torch.Tensor.backward`.
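
The final reshape step can be checked in isolation with plain PyTorch. The sketch below is not part of the example above: `params`, `inputs`, and the random `weights` vector are hypothetical stand-ins (in particular, `weights` replaces the output of `weighting(gramian)`). It shows that passing the reshaped weights to `backward` gives the same gradient as differentiating the weighted sum of the losses.

```python
import torch

torch.manual_seed(0)
params = torch.randn(3, requires_grad=True)  # hypothetical model parameters
inputs = torch.randn(16, 2, 3)               # 16 samples, 2 tasks, 3 features
weights = torch.rand(32)                     # stand-in for weighting(gramian)

# Weighted backward pass on a [16, 2] matrix of losses.
losses = (inputs @ params) ** 2              # shape: [16, 2]
losses.backward(weights.reshape(losses.shape))
grad_from_backward = params.grad.clone()

# Reference: gradient of the explicitly weighted sum of the same losses.
params.grad = None
losses = (inputs @ params) ** 2              # recompute, the graph was freed
(weights.reshape(16, 2) * losses).sum().backward()
grad_reference = params.grad.clone()

assert torch.allclose(grad_from_backward, grad_reference, rtol=1e-4, atol=1e-5)
```

The reshape is needed because `backward` expects its argument to have the same shape as the tensor it is called on.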
8 changes: 1 addition & 7 deletions src/torchjd/_linalg/__init__.py
@@ -1,21 +1,15 @@
from ._dual_cone import DualConeProjector, QuadprogProjector, projector_or_default
from ._generalized_gramian import flatten, movedim, reshape
from ._gramian import compute_gramian, normalize, regularize
from ._matrix import Matrix, PSDMatrix, PSDTensor, is_matrix, is_psd_matrix, is_psd_tensor
from ._matrix import Matrix, PSDMatrix, is_matrix, is_psd_matrix

__all__ = [
"compute_gramian",
"normalize",
"regularize",
"Matrix",
"PSDMatrix",
"PSDTensor",
"is_matrix",
"is_psd_matrix",
"is_psd_tensor",
"flatten",
"reshape",
"movedim",
"DualConeProjector",
"QuadprogProjector",
"projector_or_default",
87 changes: 0 additions & 87 deletions src/torchjd/_linalg/_generalized_gramian.py

This file was deleted.

18 changes: 7 additions & 11 deletions src/torchjd/_linalg/_gramian.py
Expand Up @@ -3,30 +3,26 @@
import torch
from torch import Tensor

from ._matrix import Matrix, PSDMatrix, PSDTensor
from ._matrix import Matrix, PSDMatrix


@overload
def compute_gramian(t: Tensor) -> PSDMatrix:
pass
def compute_gramian(t: Tensor) -> PSDMatrix: ...


@overload
def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix:
pass
def compute_gramian(t: Tensor, contracted_dims: Literal[-1]) -> PSDMatrix: ...


@overload
def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix:
pass
def compute_gramian(t: Matrix, contracted_dims: Literal[1]) -> PSDMatrix: ...

Comment on lines +10 to 22
Contributor


It removed all the `pass` statements here.


@overload
def compute_gramian(t: Tensor, contracted_dims: int) -> PSDTensor:
pass
def compute_gramian(t: Tensor, contracted_dims: int) -> Tensor: ...


def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
def compute_gramian(t: Tensor, contracted_dims: int = -1) -> Tensor:
"""
Computes the `Gramian matrix <https://en.wikipedia.org/wiki/Gram_matrix>`_ of the input.

Expand All @@ -49,7 +45,7 @@ def compute_gramian(t: Tensor, contracted_dims: int = -1) -> PSDTensor:
transposed = t.movedim(indices_source, indices_dest)
gramian = torch.tensordot(t, transposed, dims=contracted_dims)

return cast(PSDTensor, gramian)
return gramian
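
For the plain matrix case, the `tensordot` contraction visible in this hunk can be illustrated with a standalone sketch. It does not reproduce the elided part of the function (how `indices_source` and `indices_dest` are built is not shown in the diff); it only assumes that, for a 2D input, the transposed operand is the ordinary transpose. The tensor `t` is a made-up example.

```python
import torch

torch.manual_seed(0)
t = torch.randn(4, 7)  # hypothetical input: 4 row vectors of length 7

# Contract the last dim of t with the first dim of its transpose.
gramian = torch.tensordot(t, t.T, dims=1)  # shape: [4, 4]

# For a matrix this is the usual Gramian t @ t.T ...
assert torch.allclose(gramian, t @ t.T, atol=1e-5)

# ... which is symmetric with non-negative eigenvalues (up to float error).
assert torch.allclose(gramian, gramian.T, atol=1e-5)
eigenvalues = torch.linalg.eigvalsh(gramian)
assert eigenvalues.min() >= -1e-4
```

An eigenvalue check like the last one is a way for tests to verify positive semi-definiteness, which the shape-only `TypeGuard` helpers do not check.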


def normalize(gramian: PSDMatrix, eps: float) -> PSDMatrix:
17 changes: 1 addition & 16 deletions src/torchjd/_linalg/_matrix.py
Expand Up @@ -20,15 +20,7 @@ class Matrix(Tensor):
"""


class PSDTensor(Tensor):
"""
Tensor representing a quadratic form. The first half of its dimensions matches the reversed
second half of its dimensions (e.g. shape=[4, 3, 3, 4]), and its reshaping into a matrix should
be positive semi-definite.
"""


class PSDMatrix(PSDTensor, Matrix):
class PSDMatrix(Matrix):
"""
Positive semi-definite matrix.

Expand All @@ -44,13 +36,6 @@ def is_matrix(t: Tensor) -> TypeGuard[Matrix]:
return t.ndim == 2


def is_psd_tensor(t: Tensor) -> TypeGuard[PSDTensor]:
half_dim = t.ndim // 2
return t.ndim % 2 == 0 and t.shape[:half_dim] == t.shape[: half_dim - 1 : -1]
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
# every function that uses this TypeGuard by using `assert_is_psd_tensor`.


def is_psd_matrix(t: Tensor) -> TypeGuard[PSDMatrix]:
return t.ndim == 2 and t.shape[0] == t.shape[1]
# We do not check that t is PSD as it is expensive, but this must be checked in the tests of
27 changes: 1 addition & 26 deletions src/torchjd/aggregation/__init__.py
Expand Up @@ -36,28 +36,6 @@
>>> weights = weighting(gramian)
>>> weights
tensor([1.1109, 0.7894])

When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not
necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian
will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a
:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` to extract
a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest
:class:`GeneralizedWeighting<torchjd.aggregation.GeneralizedWeighting>` is
:class:`Flattening<torchjd.aggregation.Flattening>`: it simply "flattens" the
generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``),
applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of
weights.

>>> from torch import ones
>>> from torchjd.aggregation import Flattening, UPGradWeighting
>>>
>>> weighting = Flattening(UPGradWeighting())
>>> # Generate a generalized Gramian filled with ones, for the sake of the example
>>> generalized_gramian = ones((2, 3, 3, 2))
>>> weights = weighting(generalized_gramian)
>>> weights
tensor([[0.1667, 0.1667, 0.1667],
[0.1667, 0.1667, 0.1667]])
"""

from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator
Expand All @@ -68,7 +46,6 @@
from ._cr_mogm import CRMOGMWeighting
from ._dualproj import DualProj, DualProjWeighting
from ._fairgrad import FairGrad, FairGradWeighting
from ._flattening import Flattening
from ._graddrop import GradDrop
from ._gradvac import GradVac, GradVacWeighting
from ._imtl_g import IMTLG, IMTLGWeighting
Expand All @@ -82,7 +59,7 @@
from ._sum import Sum, SumWeighting
from ._trimmed_mean import TrimmedMean
from ._upgrad import UPGrad, UPGradWeighting
from ._weighting_bases import GeneralizedWeighting, Weighting
from ._weighting_bases import Weighting

__all__ = [
"Aggregator",
Expand All @@ -98,8 +75,6 @@
"DualProjWeighting",
"FairGrad",
"FairGradWeighting",
"Flattening",
"GeneralizedWeighting",
"GradDrop",
"GradVac",
"GradVacWeighting",
32 changes: 0 additions & 32 deletions src/torchjd/aggregation/_flattening.py

This file was deleted.
