diff --git a/CHANGELOG.md b/CHANGELOG.md index 2e160ad1..a41691d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -20,6 +20,29 @@ changelog does not include internal changes that do not affect the user. (installed with `pip install torchjd`) much lighter, but it means that users of `UPGrad` and `DualProj` now have to install the new optional dependency group `quadprog_projector` explicitly (with e.g. `pip install "torchjd[quadprog_projector]"`). +- **BREAKING**: Removed entirely the concept of generalized Gramians. The `Engine.compute_gramian` + method now always returns a square matrix of shape `[m, m]`, where `m` is the total number of + elements of the ``output`` tensor (treating all dimensions uniformly). Previously, an output of + shape `[m1, m2]` would return a 4D generalized Gramian of shape `[m1, m2, m2, m1]`; it now + returns a `[m1 * m2, m1 * m2]` matrix. + This also removes `GeneralizedWeighting` and `Flattening`. + To update, replace `Flattening(weighting)` with a standard `Weighting` and reshape the resulting + weight vector yourself: + ```python + # Before + from torchjd.aggregation import Flattening, UPGradWeighting + weighting = Flattening(UPGradWeighting()) + gramian = engine.compute_gramian(losses) # shape: [m1, m2, m2, m1] + weights = weighting(gramian) # shape: [m1, m2] + losses.backward(weights) + + # After + from torchjd.aggregation import UPGradWeighting + weighting = UPGradWeighting() + gramian = engine.compute_gramian(losses) # shape: [m1 * m2, m1 * m2] + weights = weighting(gramian).reshape(losses.shape) # shape: [m1, m2] + losses.backward(weights) + ``` ## [0.11.0] - 2026-05-18 diff --git a/docs/source/docs/aggregation/flattening.rst b/docs/source/docs/aggregation/flattening.rst deleted file mode 100644 index b6d7f698..00000000 --- a/docs/source/docs/aggregation/flattening.rst +++ /dev/null @@ -1,7 +0,0 @@ -:hide-toc: - -Flattening -========== - -.. autoclass:: torchjd.aggregation.Flattening - :members: __call__ diff --git a/docs/source/docs/aggregation/index.rst b/docs/source/docs/aggregation/index.rst index 0b149e5d..13e405cb 100644 --- a/docs/source/docs/aggregation/index.rst +++ b/docs/source/docs/aggregation/index.rst @@ -19,9 +19,6 @@ Abstract base classes .. autoclass:: torchjd.aggregation.Weighting :members: __call__ -.. autoclass:: torchjd.aggregation.GeneralizedWeighting - :members: __call__ - .. autoclass:: torchjd.aggregation.Stateful :members: reset @@ -38,7 +35,6 @@ Abstract base classes cr_mogm.rst dualproj.rst fairgrad.rst - flattening.rst graddrop.rst gradvac.rst imtl_g.rst diff --git a/docs/source/examples/iwmtl.rst b/docs/source/examples/iwmtl.rst index 3235e0ee..dd76f3db 100644 --- a/docs/source/examples/iwmtl.rst +++ b/docs/source/examples/iwmtl.rst @@ -16,7 +16,7 @@ The following example shows how to do that. from torch.nn import Linear, MSELoss, ReLU, Sequential from torch.optim import SGD - from torchjd.aggregation import Flattening, UPGradWeighting + from torchjd.aggregation import UPGradWeighting from torchjd.autogram import Engine shared_module = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU()) @@ -30,7 +30,7 @@ The following example shows how to do that. optimizer = SGD(params, lr=0.1) mse = MSELoss(reduction="none") - weighting = Flattening(UPGradWeighting()) + weighting = UPGradWeighting() engine = Engine(shared_module, batch_dim=0) inputs = torch.randn(8, 16, 10) # 8 batches of 16 random input vectors of length 10 @@ -46,20 +46,19 @@ The following example shows how to do that. losses = torch.stack([mse(out1, target1), mse(out2, target2)], dim=1) # shape: [16, 2] # Compute the gramian (inner products between pairs of gradients of the losses) - gramian = engine.compute_gramian(losses) # shape: [16, 2, 2, 16] + gramian = engine.compute_gramian(losses) # shape: [32, 32] # Obtain the weights that lead to no conflict between reweighted gradients - weights = weighting(gramian) # shape: [16, 2] + weights = weighting(gramian) # shape: [32] # Do the standard backward pass, but weighted using the obtained weights - losses.backward(weights) + losses.backward(weights.reshape(losses.shape)) optimizer.step() optimizer.zero_grad() .. note:: - In this example, the tensor of losses is a matrix rather than a vector. The gramian is thus a - 4D tensor rather than a matrix, and a - :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting`, such as - :class:`~torchjd.aggregation._flattening.Flattening`, has to be used to extract a matrix of - weights from it. More information about ``GeneralizedWeighting`` can be found in the - :doc:`../../docs/aggregation/index` page. + In this example, the tensor of losses is a matrix of shape ``[16, 2]`` (16 samples, 2 tasks). + The autogram engine flattens this into a vector of ``m = 16 × 2 = 32`` objectives, so the + Gramian has shape ``[32, 32]``. A standard :class:`~torchjd.aggregation.Weighting` is then used + to extract a vector of 32 weights, which is reshaped back to ``[16, 2]`` before being passed to + :meth:`~torch.Tensor.backward`. diff --git a/src/torchjd/aggregation/__init__.py b/src/torchjd/aggregation/__init__.py index a1b67e39..1814d320 100644 --- a/src/torchjd/aggregation/__init__.py +++ b/src/torchjd/aggregation/__init__.py @@ -36,28 +36,6 @@ >>> weights = weighting(gramian) >>> weights tensor([1.1109, 0.7894]) - -When dealing with a more general tensor of objectives, of shape ``[m_1, ..., m_k]`` (i.e. not -necessarily a simple vector), the Jacobian will be of shape ``[m_1, ..., m_k, n]``, and its Gramian -will be called a `generalized Gramian`, of shape ``[m_1, ..., m_k, m_k, ..., m_1]``. One can use a -:class:`GeneralizedWeighting` to extract -a tensor of weights (of shape ``[m_1, ..., m_k]``) from such a generalized Gramian. The simplest -:class:`GeneralizedWeighting` is -:class:`Flattening`: it simply "flattens" the -generalized Gramian into a square Gramian matrix (of shape ``[m_1 * ... * m_k, m_1 * ... * m_k]``), -applies a normal weighting to it to obtain a vector of weights, and returns the reshaped tensor of -weights. - ->>> from torch import ones ->>> from torchjd.aggregation import Flattening, UPGradWeighting ->>> ->>> weighting = Flattening(UPGradWeighting()) ->>> # Generate a generalized Gramian filled with ones, for the sake of the example ->>> generalized_gramian = ones((2, 3, 3, 2)) ->>> weights = weighting(generalized_gramian) ->>> weights -tensor([[0.1667, 0.1667, 0.1667], - [0.1667, 0.1667, 0.1667]]) """ from ._aggregator_bases import Aggregator, GramianWeightedAggregator, WeightedAggregator @@ -68,7 +46,6 @@ from ._cr_mogm import CRMOGMWeighting from ._dualproj import DualProj, DualProjWeighting from ._fairgrad import FairGrad, FairGradWeighting -from ._flattening import Flattening from ._graddrop import GradDrop from ._gradvac import GradVac, GradVacWeighting from ._imtl_g import IMTLG, IMTLGWeighting @@ -82,7 +59,7 @@ from ._sum import Sum, SumWeighting from ._trimmed_mean import TrimmedMean from ._upgrad import UPGrad, UPGradWeighting -from ._weighting_bases import GeneralizedWeighting, Weighting +from ._weighting_bases import Weighting __all__ = [ "Aggregator", @@ -98,8 +75,6 @@ "DualProjWeighting", "FairGrad", "FairGradWeighting", - "Flattening", - "GeneralizedWeighting", "GradDrop", "GradVac", "GradVacWeighting", diff --git a/src/torchjd/aggregation/_flattening.py b/src/torchjd/aggregation/_flattening.py deleted file mode 100644 index 4d4a02af..00000000 --- a/src/torchjd/aggregation/_flattening.py +++ /dev/null @@ -1,32 +0,0 @@ -from torch import Tensor - -from torchjd._linalg import PSDTensor, flatten -from torchjd.aggregation._weighting_bases import GeneralizedWeighting, Weighting - - -class Flattening(GeneralizedWeighting): - """ - :class:`~torchjd.aggregation._weighting_bases.GeneralizedWeighting` flattening the generalized - Gramian into a square matrix, extracting a vector of weights from it using a - :class:`~torchjd.aggregation._weighting_bases.Weighting`, and returning the reshaped tensor of - weights. - - For instance, when applied to a generalized Gramian of shape ``[2, 3, 3, 2]``, it would flatten - it into a square Gramian matrix of shape ``[6, 6]``, apply the weighting on it to get a vector - of weights of shape ``[6]``, and then return this vector reshaped into a matrix of shape - ``[2, 3]``. - - :param weighting: The weighting to apply to the Gramian matrix. - """ - - def __init__(self, weighting: Weighting) -> None: - super().__init__() - self.weighting = weighting - - def forward(self, generalized_gramian: PSDTensor, /) -> Tensor: - k = generalized_gramian.ndim // 2 - shape = generalized_gramian.shape[:k] - square_gramian = flatten(generalized_gramian) - weights_vector = self.weighting(square_gramian) - weights = weights_vector.reshape(shape) - return weights diff --git a/src/torchjd/aggregation/_weighting_bases.py b/src/torchjd/aggregation/_weighting_bases.py index 00eea54a..cdffdd3d 100644 --- a/src/torchjd/aggregation/_weighting_bases.py +++ b/src/torchjd/aggregation/_weighting_bases.py @@ -6,7 +6,6 @@ from torch import Tensor, nn -from torchjd._linalg import PSDTensor, is_psd_tensor from torchjd.linalg import Matrix, PSDMatrix _T = TypeVar("_T", contravariant=True, bound=Tensor) @@ -76,30 +75,3 @@ def __call__(self, gramian: Tensor, /) -> Tensor: :param gramian: The Gramian from which the weights must be extracted. """ return super().__call__(gramian) - - -class GeneralizedWeighting(nn.Module, ABC): - r""" - Abstract base class for all weightings that operate on generalized Gramians. It has the role of - extracting a tensor of weights of dimension :math:`m_1 \times \dots \times m_k` from a - generalized Gramian of dimension - :math:`m_1 \times \dots \times m_k \times m_k \times \dots \times m_1`. - """ - - def __init__(self) -> None: - super().__init__() - - @abstractmethod - def forward(self, generalized_gramian: PSDTensor, /) -> Tensor: - """Computes the vector of weights from the input generalized Gramian.""" - - def __call__(self, generalized_gramian: Tensor, /) -> Tensor: - """ - Computes the tensor of weights from the input generalized Gramian and applies all registered - hooks. - - :param generalized_gramian: The tensor from which the weights must be extracted. - """ - - assert is_psd_tensor(generalized_gramian) - return super().__call__(generalized_gramian) diff --git a/src/torchjd/autogram/_engine.py b/src/torchjd/autogram/_engine.py index 72112e80..8a8633df 100644 --- a/src/torchjd/autogram/_engine.py +++ b/src/torchjd/autogram/_engine.py @@ -4,7 +4,7 @@ from torch import Tensor, nn, vmap from torch.autograd.graph import get_gradient_edge -from torchjd._linalg import movedim, reshape +from torchjd._linalg import flatten, movedim, reshape from torchjd.linalg import PSDMatrix from ._edge_registry import EdgeRegistry @@ -246,28 +246,18 @@ def compute_gramian(self, output: Tensor, /) -> Tensor: Computes the Gramian of the Jacobian of ``output`` with respect to the direct parameters of all ``modules``. - :param output: The tensor of arbitrary shape to differentiate. The shape of the returned - Gramian depends on the shape of this output. - - .. note:: - This function doesn't require ``output`` to be a vector. For example, if ``output`` is - a matrix of shape :math:`[m_1, m_2]`, its Jacobian :math:`J` with respect to the - parameters will be of shape :math:`[m_1, m_2, n]`, where :math:`n` is the number of - parameters in the model. This is what we call a `generalized Jacobian`. The - corresponding Gramian :math:`G = J J^\top` will be of shape - :math:`[m_1, m_2, m_2, m_1]`. This is what we call a `generalized Gramian`. The number - of dimensions of the returned generalized Gramian will always be twice that of the - ``output``. + :param output: The tensor to differentiate. Its elements are treated as a flat vector of + :math:`m` objectives (where :math:`m` is the total number of elements of ``output``), + so the returned Gramian always has shape :math:`[m, m]`. A few examples: - - 0D (scalar) ``output``: 0D Gramian (this can be used to efficiently compute the - squared norm of the gradient of ``output``). - - 1D (vector) ``output``: 2D Gramian (this is the standard setting of Jacobian - descent). - - 2D (matrix) ``output``: 4D Gramian (this can be used for :doc:`Instance-Wise - Multi-Task Learning (IWMTL) <../../examples/iwmtl>`, as each sample in the batch - has one loss per task). - - etc. + - Scalar ``output``: :math:`1\times 1` Gramian (this can be used to efficiently + compute the squared norm of the gradient of ``output``). + - Vector ``output`` of dimension :math:`m`: :math:`m \times m` Gramian (this is the + standard setting of Jacobian descent). + - Matrix ``output`` of dimension :math:`m_1\times m_2`: :math:`m_1 m_2 \times m_1 m_2` + Gramian (this can be used for :doc:`Instance-Wise Multi-Task Learning (IWMTL) + <../../examples/iwmtl>`, as each sample in the batch has one loss per task). """ if self._batch_dim is not None: @@ -305,12 +295,11 @@ def compute_gramian(self, output: Tensor, /) -> Tensor: for gramian_computer in self._gramian_computers.values(): gramian_computer.reset() - unordered_gramian = reshape(square_gramian, ordered_shape) - if self._batch_dim is not None: - gramian = movedim(unordered_gramian, [-1], [self._batch_dim]) + unordered_gramian = reshape(square_gramian, ordered_shape) + gramian = flatten(movedim(unordered_gramian, [-1], [self._batch_dim])) else: - gramian = unordered_gramian + gramian = square_gramian return gramian diff --git a/tests/unit/aggregation/test_flattening.py b/tests/unit/aggregation/test_flattening.py deleted file mode 100644 index e1f7fc26..00000000 --- a/tests/unit/aggregation/test_flattening.py +++ /dev/null @@ -1,36 +0,0 @@ -from pytest import mark -from torch.testing import assert_close -from utils.optional_deps import base_weighting -from utils.tensors import randn_ - -from torchjd._linalg import PSDMatrix, compute_gramian, flatten -from torchjd.aggregation import Flattening, MeanWeighting, SumWeighting, Weighting - - -@mark.parametrize( - "half_shape", - [ - [1], - [12], - [4, 3], - [2, 3, 2], - ], -) -@mark.parametrize( - "weighting", - [ - SumWeighting(), - MeanWeighting(), - base_weighting(), - ], -) -def test_flattening(half_shape: list[int], weighting: Weighting[PSDMatrix]) -> None: - matrix = randn_([*half_shape, 2]) - generalized_gramian = compute_gramian(matrix, 1) - gramian = flatten(generalized_gramian) - - flattening = Flattening(weighting) - weights = flattening(generalized_gramian) - - expected_weights = weighting(gramian).reshape(half_shape) - assert_close(weights, expected_weights) diff --git a/tests/unit/autogram/test_engine.py b/tests/unit/autogram/test_engine.py index 0ca97f9c..84fb1f5f 100644 --- a/tests/unit/autogram/test_engine.py +++ b/tests/unit/autogram/test_engine.py @@ -1,7 +1,6 @@ from collections.abc import Callable from itertools import combinations from math import prod -from typing import cast import pytest import torch @@ -80,7 +79,7 @@ from utils.optional_deps import base_weighting from utils.tensors import make_inputs_and_targets, ones_, randn_, zeros_ -from torchjd._linalg import PSDMatrix, compute_gramian, movedim, reshape +from torchjd._linalg import compute_gramian from torchjd.autogram._engine import Engine PARAMETRIZATIONS = [ @@ -287,10 +286,8 @@ def test_compute_gramian_various_output_shapes( losses, params = _get_losses_and_params(model_autograd, inputs, loss_fn, reduction) reshaped_losses = torch.movedim(losses, movedim_source, movedim_destination) - # Go back to a vector so that compute_gramian_with_autograd works loss_vector = reshaped_losses.reshape([-1]) - autograd_gramian = compute_gramian_with_autograd(loss_vector, params) - expected_gramian = reshape(autograd_gramian, list(reshaped_losses.shape)) + expected_gramian = compute_gramian_with_autograd(loss_vector, params) engine = Engine(model_autogram, batch_dim=batch_dim) losses = forward_pass(model_autogram, inputs, loss_fn, reduction) @@ -455,11 +452,10 @@ def test_compute_gramian_manual() -> None: [1], ], ) -def test_reshape_equivariance(shape: list[int]) -> None: +def test_reshape_invariance(shape: list[int]) -> None: """ - Test equivariance of `compute_gramian` under reshape operation. More precisely, if we reshape - the `output` to some `shape`, then the result is the same as reshaping the Gramian to the - corresponding shape. + Test that compute_gramian returns the same flat [m, m] gramian regardless of how the output is + shaped. """ input_size = shape[0] @@ -470,52 +466,45 @@ def test_reshape_equivariance(shape: list[int]) -> None: engine1 = Engine(model1, batch_dim=None) output = model1(input) - gramian = cast(PSDMatrix, engine1.compute_gramian(output)) - expected_reshaped_gramian = reshape(gramian, shape[1:]) + gramian = engine1.compute_gramian(output) engine2 = Engine(model2, batch_dim=None) reshaped_output = model2(input).reshape(shape[1:]) reshaped_gramian = engine2.compute_gramian(reshaped_output) - assert_close(reshaped_gramian, expected_reshaped_gramian) + assert_close(reshaped_gramian, gramian) @mark.parametrize( - ["shape", "source", "destination"], + "shape", [ - ([50, 2, 2, 3], [0, 2], [1, 0]), - ([60, 3, 2, 5], [1], [2]), - ([30, 6, 7], [0, 1], [1, 0]), - ([3, 2], [0], [0]), - ([3], [], []), - ([3, 2, 1], [1, 0], [0, 1]), - ([4, 3, 2], [], []), - ([1, 1, 1], [1, 0], [0, 1]), + [50, 2, 2, 3], + [60, 3, 2, 5], + [30, 6, 7], + [3, 2], + [3], + [3, 2, 1], + [4, 3, 2], + [1, 1, 1], ], ) -def test_movedim_equivariance(shape: list[int], source: list[int], destination: list[int]) -> None: +def test_gramian_has_correct_shape(shape: list[int]) -> None: """ - Test equivariance of `compute_gramian` under movedim operation. More precisely, if we movedim - the `output` on some dimensions, then the result is the same as movedim on the Gramian with the - corresponding dimensions. + Test that compute_gramian always returns a [m, m] matrix where m is the total number of + elements of the output tensor, regardless of how the output is shaped. """ input_size = shape[0] output_size = prod(shape[1:]) factory = ModuleFactory(Linear, input_size, output_size) - model1, model2 = factory(), factory() + model = factory() input = randn_([input_size]) - engine1 = Engine(model1, batch_dim=None) - output = model1(input).reshape(shape[1:]) - gramian = cast(PSDMatrix, engine1.compute_gramian(output)) - expected_moved_gramian = movedim(gramian, source, destination) - - engine2 = Engine(model2, batch_dim=None) - moved_output = model2(input).reshape(shape[1:]).movedim(source, destination) - moved_gramian = engine2.compute_gramian(moved_output) + engine = Engine(model, batch_dim=None) + output = model(input).reshape(shape[1:]) + gramian = engine.compute_gramian(output) - assert_close(moved_gramian, expected_moved_gramian) + assert gramian.shape == (output_size, output_size) @mark.parametrize(