🍓 CR-MOGM, DualConeProjector, setters
This PR introduces `CRMOGMWeighting`, a wrapper that can be applied to any weighting to stabilize its weights: it returns the exponential moving average, across calls, of the weights produced by the wrapped weighting. This method was originally introduced in *On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond* (NeurIPS 2022). Many thanks to @KhusPatel4450 for contributing this!
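To make the smoothing concrete, here is a minimal pure-Python sketch of an exponential-moving-average wrapper around a weighting. It is not the implementation of `CRMOGMWeighting`: the class name `EMAWeightingSketch`, the coefficient `beta`, the update rule `w_t = beta * w_{t-1} + (1 - beta) * w_hat_t`, and the choice to start the average from the first raw weights are all assumptions made for illustration. Plain lists stand in for tensors.

```python
from typing import Callable, List, Optional


class EMAWeightingSketch:
    """Hypothetical wrapper smoothing the weights of an inner weighting with an EMA.

    Assumed update rule: w_t = beta * w_{t-1} + (1 - beta) * w_hat_t, where w_hat_t
    is what the wrapped weighting returns at call t.
    """

    def __init__(self, inner: Callable[[object], List[float]], beta: float = 0.9) -> None:
        if not 0.0 <= beta < 1.0:
            raise ValueError(f"beta must be in [0, 1), got {beta}")
        self.inner = inner
        self.beta = beta
        self._average: Optional[List[float]] = None  # no history before the first call

    def __call__(self, stat: object) -> List[float]:
        raw = self.inner(stat)  # weights proposed by the wrapped weighting
        if self._average is None:
            # First call: nothing to average with yet, so start from the raw weights.
            self._average = list(raw)
        else:
            self._average = [
                self.beta * old + (1.0 - self.beta) * new
                for old, new in zip(self._average, raw)
            ]
        return list(self._average)


# The wrapped weighting jumps from [1, 0] to [0, 1]; the smoothed weights move gradually.
raw_outputs = iter([[1.0, 0.0], [0.0, 1.0]])
smoothed = EMAWeightingSketch(lambda _: next(raw_outputs), beta=0.5)
print(smoothed(None))  # [1.0, 0.0]
print(smoothed(None))  # [0.5, 0.5]
```

With `beta` close to 1, the returned weights change slowly even when the wrapped weighting's output varies a lot between calls, which is the stabilizing effect this wrapper is meant to provide.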
The interface is also improved, with:
- A new abstraction, `DualConeProjector`, and its concrete `QuadprogProjector` implementation, which project the gradients onto the dual cone, as required in `UPGrad` and `DualProj`. We plan to add more projectors in the future. Many thanks to @PierreQuinton for this!
- Setters for the attributes of all existing aggregators and weightings. Many thanks to @mattbuot for implementing them!
- Some classes of torchjd made public: `Matrix`, `PSDMatrix`, `WeightedAggregator` and `GramianWeightedAggregator`.
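For intuition about what a dual cone projector computes, the sketch below projects `x = J^T u` onto the dual cone `{y : <g_i, y> >= 0 for every row g_i of J}`. It solves the standard Lagrangian dual of that problem, `min over v >= 0 of 0.5 * v^T G v + v^T G u` with `G = J J^T`, and returns the weights `w = u + v`, so that the projection is `J^T w`. Note that only the Gramian `G` is needed. This is neither `QuadprogProjector` nor the `DualConeProjector` interface: the function name `project_weights_onto_dual_cone`, the projected-gradient-descent solver (swapped in for quadprog) with step `1 / trace(G)`, and the fixed iteration count are hypothetical choices, and nested lists stand in for tensors.

```python
from typing import List

NestedMatrix = List[List[float]]  # nested lists stand in for tensors in this sketch


def _matvec(a: NestedMatrix, x: List[float]) -> List[float]:
    return [sum(a_ij * x_j for a_ij, x_j in zip(row, x)) for row in a]


def project_weights_onto_dual_cone(
    gramian: NestedMatrix, u: List[float], n_iters: int = 1000
) -> List[float]:
    """Hypothetical solver: return w such that J^T w is the projection of J^T u
    onto the dual cone of the rows of J, using only G = J J^T.

    Runs projected gradient descent on the dual problem
        min_{v >= 0} 0.5 * v^T G v + v^T G u,   then returns w = u + v.
    """
    m = len(u)
    trace = sum(gramian[i][i] for i in range(m))
    if trace == 0.0:
        return list(u)  # G is PSD, so a zero trace means every gradient is zero
    step = 1.0 / trace  # trace(G) bounds the largest eigenvalue of a PSD matrix
    v = [0.0] * m
    for _ in range(n_iters):
        grad = _matvec(gramian, [vi + ui for vi, ui in zip(v, u)])  # G (v + u)
        v = [max(0.0, vi - step * gi) for vi, gi in zip(v, grad)]  # keep v >= 0
    return [ui + vi for ui, vi in zip(u, v)]


# Two conflicting gradients g1 = (1, 0) and g2 = (-1, 1): project g1 (u = e1).
jacobian = [[1.0, 0.0], [-1.0, 1.0]]
gramian = [[sum(a * b for a, b in zip(r, s)) for s in jacobian] for r in jacobian]
w = project_weights_onto_dual_cone(gramian, u=[1.0, 0.0])
y = [sum(w[i] * jacobian[i][k] for i in range(2)) for k in range(2)]
print(w, y)  # w ≈ [1.0, 0.5], y ≈ [0.5, 0.5]
```

A fixed-step loop like this can converge slowly when `G` is ill-conditioned, which is one reason to delegate the problem to a dedicated QP solver; a projector abstraction lets such solvers be swapped without touching the aggregators.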
Changelog
Changed
- BREAKING: Removed the `norm_eps`, `reg_eps` and `solver` parameters from the `__init__` of `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`, in favor of a `projector` parameter of type `DualConeProjector`. To update:
  - If you used the default `norm_eps`, `reg_eps` and `solver`, you don't have to change anything and you will get the same results.
  - Otherwise, pass `norm_eps` and `reg_eps` to a `QuadprogProjector` instead:
    ```python
    # Before
    from torchjd.aggregation import UPGrad

    aggregator = UPGrad(norm_eps=1e-6, reg_eps=1e-6, solver="quadprog")

    # After
    from torchjd.aggregation import UPGrad
    from torchjd.linalg import QuadprogProjector

    aggregator = UPGrad(projector=QuadprogProjector(norm_eps=1e-6, reg_eps=1e-6))
    ```
- `CAGrad`, `CAGradWeighting` and `NashMTL` are now always importable from `torchjd.aggregation`, even when their optional dependencies are not installed. Attempting to instantiate them without the required dependencies now raises an `ImportError` with installation instructions, instead of raising an `ImportError` at import time.
- Non-differentiable aggregators and weightings (`UPGrad`, `DualProj`, `PCGrad`, `GradVac`, `IMTLG`, `GradDrop`, `ConFIG`, `CAGrad`, `NashMTL`) no longer build a computation graph when called on tensors that require gradients. Their forward pass is now wrapped in `torch.no_grad()`, so it is no longer possible to differentiate through them (previously, attempting to do so raised a `NonDifferentiableError`).
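The deferred-import behavior described for `CAGrad`, `CAGradWeighting` and `NashMTL` can be sketched with the standard library alone. This is a hypothetical illustration of the pattern, not torchjd's code: the class `OptionalDependencyAggregatorSketch`, the placeholder package `some_optional_solver` and its installation hint are made up. The class can always be imported; the optional package is only looked up in `__init__`.

```python
import importlib
from types import ModuleType


class OptionalDependencyAggregatorSketch:
    """Hypothetical aggregator whose optional dependency is resolved at instantiation."""

    # Placeholder names: a real class would point to its actual optional package.
    dependency_name = "some_optional_solver"
    install_hint = "pip install some_optional_solver"

    def __init__(self) -> None:
        try:
            # Deferred import: a missing package only matters once an instance is created.
            self.backend: ModuleType = importlib.import_module(self.dependency_name)
        except ImportError as error:
            raise ImportError(
                f"{type(self).__name__} requires the optional dependency "
                f"'{self.dependency_name}'. Install it with: {self.install_hint}"
            ) from error


# Referencing the class always works; instantiating it fails with an
# actionable message when the dependency is absent.
try:
    OptionalDependencyAggregatorSketch()
except ImportError as error:
    print(error)
```

Chaining with `from error` keeps the original `ModuleNotFoundError` in the traceback, while the new message tells the user how to fix the problem.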
Added
- Added `CRMOGMWeighting` from *On the Convergence of Stochastic Multi-Objective Gradient Manipulation and Beyond* (NeurIPS 2022). It wraps an existing `Weighting` and stabilizes its weights with an exponential moving average across calls.
- Added a new abstraction: the `DualConeProjector` abstract base class and its concrete `QuadprogProjector` implementation, which project the gradients onto the dual cone, as required in `UPGrad` and `DualProj`. These classes can be found in `torchjd.linalg`.
- Made `WeightedAggregator` and `GramianWeightedAggregator` public. These abstract base classes are now importable from `torchjd.aggregation` and documented. They can be extended to easily implement custom `Aggregator`s.
- Made `Matrix` and `PSDMatrix` public. These type annotation classes are now importable from `torchjd.linalg` and documented. Users can now subclass `Weighting[Matrix]` or `Weighting[PSDMatrix]` to implement custom `Weighting`s.
- Added getters and setters for the constructor parameters of all aggregators and weightings, so
that they can be changed after initialization. This includes:
  - `pref_vector`, `norm_eps` and `reg_eps` in `UPGrad`, `UPGradWeighting`, `DualProj` and `DualProjWeighting`;
  - `pref_vector` and `scale_mode` in `AlignedMTL` and `AlignedMTLWeighting`;
  - `c` and `norm_eps` in `CAGrad` and `CAGradWeighting`;
  - `pref_vector` in `ConFIG`;
  - `leak` in `GradDrop`;
  - `n_byzantine` and `n_selected` in `Krum` and `KrumWeighting`;
  - `epsilon` and `max_iters` in `MGDA` and `MGDAWeighting`;
  - `n_tasks`, `max_norm`, `update_weights_every` and `optim_niter` in `NashMTL`;
  - `trim_number` in `TrimmedMean`.

  Setters validate their inputs with the same checks as the constructors. Note that setters for `GradVac` and `GradVacWeighting` already existed.
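One way to get setters that validate their inputs exactly like the constructor is to have `__init__` assign through a property, so both paths share a single check. Below is a hypothetical sketch of that pattern for a `TrimmedMean`-like class; the class name `TrimmedMeanSketch` and the validation rule (a non-negative `int`) are assumptions, not torchjd's implementation.

```python
class TrimmedMeanSketch:
    """Hypothetical stand-in: the constructor assigns through the property,
    so __init__ and later updates run the same validation."""

    def __init__(self, trim_number: int) -> None:
        self.trim_number = trim_number  # goes through the setter below

    @property
    def trim_number(self) -> int:
        return self._trim_number

    @trim_number.setter
    def trim_number(self, value: int) -> None:
        # Assumed check: a non-negative int (bool is rejected even though it subclasses int).
        if not isinstance(value, int) or isinstance(value, bool) or value < 0:
            raise ValueError(f"trim_number must be a non-negative int, got {value!r}")
        self._trim_number = value


aggregator = TrimmedMeanSketch(trim_number=1)
aggregator.trim_number = 2  # changed after initialization, validated like in __init__
try:
    aggregator.trim_number = -1
except ValueError as error:
    print(error)  # the invalid value is rejected and trim_number stays 2
```

Because the setter raises before assigning, a rejected value never leaves the object in an inconsistent state.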