v0.11.0

Latest

@ValerianRey ValerianRey released this 18 May 13:58
· 16 commits to main since this release
af321be

🍓 CR-MOGM, DualConeProjector, setters

This release introduces CRMOGMWeighting, a wrapper that stabilizes the weights of any weighting by taking an exponential moving average of the weights returned by the wrapped weighting. This technique was originally introduced in On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond. Many thanks to @KhusPatel4450 for contributing this!
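
To make the smoothing idea concrete, here is a minimal sketch in plain Python. It is not torchjd's CRMOGMWeighting: the class name EMAWeightingSketch, the alpha parameter, and the convention that alpha multiplies the previous smoothed weights are all assumptions made for illustration, and the real class may use a different signature and operate on tensors.

```python
from typing import Callable, List, Optional, Sequence


class EMAWeightingSketch:
    """Hypothetical wrapper that smooths the output of another weighting.

    Assumption: smoothed = alpha * previous + (1 - alpha) * current.
    """

    def __init__(
        self,
        weighting: Callable[[Sequence[float]], List[float]],
        alpha: float = 0.9,
    ) -> None:
        if not 0.0 <= alpha < 1.0:
            raise ValueError("alpha must be in [0, 1)")
        self.weighting = weighting
        self.alpha = alpha
        self._previous: Optional[List[float]] = None

    def __call__(self, stats: Sequence[float]) -> List[float]:
        # Ask the wrapped weighting for its raw weights at this call.
        current = self.weighting(stats)
        if self._previous is None:
            # First call: nothing to average with yet.
            smoothed = list(current)
        else:
            smoothed = [
                self.alpha * prev + (1.0 - self.alpha) * cur
                for prev, cur in zip(self._previous, current)
            ]
        # Remember the smoothed weights for the next call.
        self._previous = smoothed
        return smoothed


# Toy wrapped weighting (hypothetical): returns its input unchanged.
ema = EMAWeightingSketch(lambda stats: list(stats), alpha=0.5)
first = ema([1.0, 0.0])
second = ema([0.0, 1.0])
```

With alpha=0.5, the second call returns the midpoint between the first weights and the new raw weights, so an abrupt flip in the wrapped weighting is damped rather than passed through.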

The interface is also improved, with:

  • A new abstraction, DualConeProjector, and its concrete QuadprogProjector implementation, which project the gradients onto the dual cone, as required by UPGrad and DualProj. We plan to add more projectors in the future. Many thanks to @PierreQuinton for this!
  • Setters for the attributes of all existing aggregators and weightings. Many thanks to @mattbuot for implementing them!
  • Some classes of torchjd becoming public: Matrix, PSDMatrix, WeightedAggregator and GramianWeightedAggregator.

Changelog

Changed

  • BREAKING: Removed norm_eps, reg_eps and solver parameters from the __init__ of
    UPGrad, UPGradWeighting, DualProj and DualProjWeighting in favor of a projector
    parameter of type DualConeProjector. To update:
    # Before
    from torchjd.aggregation import UPGrad
    aggregator = UPGrad(norm_eps=1e-6, reg_eps=1e-6, solver="quadprog")
    
    # After
    from torchjd.aggregation import UPGrad
    from torchjd.linalg import QuadprogProjector
    aggregator = UPGrad(projector=QuadprogProjector(norm_eps=1e-6, reg_eps=1e-6))

    If you used the default norm_eps, reg_eps and solver, you don't have to change anything and
    you will get the same results.
  • CAGrad, CAGradWeighting, and NashMTL are now always importable from torchjd.aggregation,
    even when their optional dependencies are not installed. Attempting to instantiate them without the
    required dependencies now raises an ImportError with installation instructions, instead of
    raising an ImportError at import time.
  • Non-differentiable aggregators and weightings (UPGrad, DualProj, PCGrad, GradVac, IMTLG,
    GradDrop, ConFIG, CAGrad, NashMTL) no longer build a computation graph when called on tensors
    that require gradients. Their forward pass is now wrapped in torch.no_grad(), so it is no longer
    possible to differentiate through them (previously, attempting to do so raised a NonDifferentiableError).
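
The change to CAGrad, CAGradWeighting and NashMTL moves the dependency check from import time to instantiation time. A minimal sketch of that general pattern follows. It is not torchjd's actual implementation: the class name, the module name and the install hint are all hypothetical.

```python
import importlib


class OptionalDependencySketch:
    """Hypothetical class whose optional dependency is only checked on instantiation."""

    # Hypothetical module name, chosen so that it is not installed.
    _required_module = "a_package_that_is_not_installed"
    # Hypothetical installation instructions.
    _install_hint = "pip install a_package_that_is_not_installed"

    def __init__(self) -> None:
        try:
            # The import happens here, not at module import time.
            self._backend = importlib.import_module(self._required_module)
        except ImportError as error:
            raise ImportError(
                f"{type(self).__name__} requires '{self._required_module}'. "
                f"Install it with: {self._install_hint}"
            ) from error


# Defining (or importing) the class succeeds without the dependency.
# Only instantiation fails, with a message that says how to fix it.
try:
    OptionalDependencySketch()
except ImportError as error:
    message = str(error)
```

The benefit of this design is that a single `from package import A, B, C` line keeps working for users who only need the classes whose dependencies they have installed.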

Added

  • Added CRMOGMWeighting from On the Convergence of Stochastic Multi-Objective Gradient
    Manipulation and Beyond (NeurIPS 2022). It wraps an existing Weighting and stabilizes its
    weights with an exponential moving average across calls.
  • Added a new abstraction: the DualConeProjector abstract base class and its concrete
    QuadprogProjector implementation, which project the gradients onto the dual cone, as
    required by UPGrad and DualProj. These classes can be found in torchjd.linalg.
  • Made WeightedAggregator and GramianWeightedAggregator public. These abstract base classes are
    now importable from torchjd.aggregation and documented. They can be extended to easily implement
    custom Aggregators.
  • Made Matrix and PSDMatrix public. These type annotation classes are now importable from
    torchjd.linalg and documented. Users can now subclass Weighting[Matrix] or
    Weighting[PSDMatrix] to implement custom Weightings.
  • Added getters and setters for the constructor parameters of all aggregators and weightings, so
    that they can be changed after initialization. This includes: pref_vector,
    norm_eps and reg_eps in UPGrad, UPGradWeighting, DualProj and DualProjWeighting;
    pref_vector and scale_mode in AlignedMTL and AlignedMTLWeighting; c and norm_eps in
    CAGrad and CAGradWeighting; pref_vector in ConFIG; leak in GradDrop; n_byzantine and
    n_selected in Krum and KrumWeighting; epsilon and max_iters in MGDA and
    MGDAWeighting; n_tasks, max_norm, update_weights_every and optim_niter in NashMTL;
    trim_number in TrimmedMean. Setters validate their inputs with the same checks as the
    constructors. Note that setters for GradVac and GradVacWeighting already existed.
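
The "setters share the constructor checks" behaviour described in the last item can be pictured with a Python property. The sketch below uses a hypothetical TrimmedMeanSketch class. The validation rule (a non-negative integer) and the error message are assumptions for illustration, not torchjd's actual code.

```python
class TrimmedMeanSketch:
    """Hypothetical stand-in showing a validated getter/setter pair."""

    def __init__(self, trim_number: int) -> None:
        # Assigning through the property reuses the setter's validation,
        # so the constructor and the setter cannot drift apart.
        self.trim_number = trim_number

    @property
    def trim_number(self) -> int:
        return self._trim_number

    @trim_number.setter
    def trim_number(self, value: int) -> None:
        # Assumed check: trim_number must be a non-negative integer.
        if not isinstance(value, int) or value < 0:
            raise ValueError(
                f"trim_number should be a non-negative integer, got {value!r}"
            )
        self._trim_number = value


aggregator = TrimmedMeanSketch(trim_number=1)
aggregator.trim_number = 2  # changed after initialization, still validated
```

Routing the constructor through the setter is one simple way to guarantee that both paths apply the same checks.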