Skip to content

feat(aggregation): Add CRMOGMWeighting#669

Merged
ValerianRey merged 19 commits into
SimplexLab:mainfrom
KhusPatel4450:feat/cr-mogm-weighting
May 18, 2026
Merged

feat(aggregation): Add CRMOGMWeighting#669
ValerianRey merged 19 commits into
SimplexLab:mainfrom
KhusPatel4450:feat/cr-mogm-weighting

Conversation

@KhusPatel4450
Copy link
Copy Markdown
Contributor

  • Adds CRMOGMWeighting, a stateful Weighting modifier from Conflict-Reduction Multi-Objective Gradient Methods (NeurIPS 2022)
  • Wraps any existing Weighting and smooths its output with an EMA: λk = α·λ{k-1} + (1−α)·λ̂_k
  • Generic over the input type so it composes correctly with both WeightedAggregator and GramianWeightedAggregator
  • Stateful via the Stateful mixin; reset() restores uniform initial weights

Tests:

  • uv run pytest tests/unit/aggregation/test_cr_mogm.py -v, 92 tests covering EMA recurrence, alpha boundaries, reset, structural checks on both aggregator paths
  • uv run pytest tests/unit -q, full regression (2889 passed)
  • uv run ty check src/torchjd/aggregation/_cr_mogm.py, passes
  • Sphinx doctest, 94 tests, 0 failures

@KhusPatel4450 KhusPatel4450 changed the title Add CRMOGMWeighting from NeurIPS 2022 (Aggregation Feature) feat(aggregation): Add CRMOGMWeighting from NeurIPS 2022 May 7, 2026
@ValerianRey ValerianRey changed the title feat(aggregation): Add CRMOGMWeighting from NeurIPS 2022 feat(aggregation): Add CRMOGMWeighting May 7, 2026
@ValerianRey ValerianRey added cc: feat Conventional commit type for new features. package: aggregation labels May 7, 2026
@ValerianRey
Copy link
Copy Markdown
Contributor

Thanks a lot for the PR! I'm gonna review soon! In the meantime, you can try to get the CI the pass

@ValerianRey ValerianRey mentioned this pull request May 7, 2026
@KhusPatel4450
Copy link
Copy Markdown
Contributor Author

Hello,

Happy to say, all checks have been passed! Glad to have got my first PR as well. Looking forward to feedback

Comment thread docs/source/docs/aggregation/cr_mogm.rst
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread tests/unit/aggregation/test_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread tests/unit/aggregation/test_cr_mogm.py
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread CHANGELOG.md Outdated
@KhusPatel4450
Copy link
Copy Markdown
Contributor Author

All the things addressed:

  • Reset propagation: reset() now calls self.weighting.reset() if the wrapped weighting is Stateful.

  • device/dtype/m from weighting output: forward() now calls self.weighting(stat) first and reads everything from lambda_hat, not from stat.

  • Removed repr and the corresponding test_representations test.

  • Removed _state_key: _ensure_state now checks shape/dtype/device directly off _lambda.

  • Added test for reset() propagation using GradVacWeighting as the inner stateful weighting.

Still open:

  • Initial weight strategy (uniform 1/m vs first weighting output)

  • Type checking failure

Comment thread src/torchjd/aggregation/_cr_mogm.py
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread CHANGELOG.md Outdated
@KhusPatel4450
Copy link
Copy Markdown
Contributor Author

Hello I updated the code with the changes that were requested with these two commits, its just that 2nd commit has the similified version and the raise on shape change in CRMOGMWeighting._ensure_state

@ValerianRey

This comment was marked as resolved.

ValerianRey added a commit that referenced this pull request May 9, 2026
- Introduces a new public `torchjd.linalg` package exposing `Matrix` and
`PSDMatrix` (the rest of `_linalg` stays protected)
- Makes `MatrixWeighting` and `GramianWeighting` protected. These
classes are still used to specify the docstring of the `__call__`
methods of the aggregators, but the user only sees those aggregators as
`Weighting[Matrix]` and `Weighting[PSDMatrix]`, respectively. The
`MatrixWeighting` and `GramianWeighting` classes really just bring
updated docstrings, that's all.
- Makes the public type of the gramian_weighting of
GramianWeightedAggregator be Weighting[PSDMatrix] instead of
GramianWeighting, so that #669 can work. Similar with weighting of
WeightedAggregator being Weighting[Matrix].
- Expands docstrings on `Matrix` and `PSDMatrix` with Jacobian and
Gramian examples; adds Sphinx documentation under a new **linalg**
section in the API Reference
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
@ValerianRey
Copy link
Copy Markdown
Contributor

/opencode:Plan Please review this with a lot of precision.

@opencode-agent

This comment was marked as resolved.

Comment thread src/torchjd/aggregation/_cr_mogm.py Outdated
ValerianRey and others added 7 commits May 18, 2026 14:12
Replace the hardcoded uniform λ₀ = 1/m with an optional
`initial_weights` parameter. When `None` (default), λ₀ is set to
λ̂₁ on the first forward call so the first smoothed output always
equals the wrapped weighting's output regardless of α. Users who
want uniform initialisation can still pass the tensor explicitly.

Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
@ValerianRey
Copy link
Copy Markdown
Contributor

Ready to merge IMO @KhusPatel4450 @PierreQuinton

To summarize, I:

  • Changed the default alpha to 0.9
  • Added a initial_weights optional param for lambda_0, and changed the default lambda_0 to be lambda_hat_1 if this parameter is None
  • Improved the docstrings and made them doctested
  • Removed a remaining .detach() (that would prevent differentiation if user wants it). Note that if the wrapped weighting is NonDifferentiable, the lambda_hat will already be detached anyway, and if the user aggregates a matrix that is detached (as in 99.9% of the cases) this will have no effect. So really this change will not make any difference except in the niche case where a user really wants to differentiate through CR-MOGM.

Copy link
Copy Markdown
Contributor

@PierreQuinton PierreQuinton left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM.

@ValerianRey ValerianRey merged commit 3d0d0a5 into SimplexLab:main May 18, 2026
16 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cc: feat Conventional commit type for new features. package: aggregation

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants