fix(aggregation): Fix __call__ docs #693
…t in MRO
Before this change, every non-differentiable aggregator/weighting had `_NonDifferentiable` listed first in its base-class tuple, e.g. `class PCGradWeighting(_NonDifferentiable, _GramianWeighting)`. Because Python's MRO resolves `__call__` to the first class that defines it, Sphinx documented the method with `_NonDifferentiable.__call__`'s generic `(*args, **kwargs)` signature instead of the more specific one from `_GramianWeighting` or `Aggregator`.
The fix is to list `_NonDifferentiable` after the primary base class. The cooperative `super().__call__()` chain then becomes:
`_GramianWeighting.__call__(gramian)` → `Weighting.__call__(gramian)` → `_NonDifferentiable.__call__(*args)` [applies `no_grad`] → `nn.Module.__call__(...)`
The `no_grad` wrapping is fully preserved because every class in the chain calls `super().__call__()`, so `_NonDifferentiable` is still reached, just later in the chain.
The old warning in `_NonDifferentiable` said it must appear "before any nn.Module base class", which was imprecise: what actually matters is that it appears before `nn.Module` *itself* in the resolved MRO, which C3 linearization guarantees as long as `super()` chains are cooperative.
All 2982 unit tests pass. Generated docs now show the correct parameter names (`gramian` / `matrix`) for every affected `__call__`.
Co-Authored-By: Claude Sonnet 4.6 <[email protected]>
The above is Claude-generated. Not completely wrong, but a bit off. The reality is that we only need
`class A(nn.Module, _NonDifferentiable):`
then Python would raise:
So, in fact, there's no way this can even fail.
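A minimal sketch of why that ordering cannot even be written, assuming `_NonDifferentiable` is itself a subclass of `nn.Module` (this thread implies it but does not show it). A plain `Module` class stands in for `torch.nn.Module` so the snippet runs without torch; C3 linearization refuses any base list that places a class before one of its own subclasses:

```python
class Module:
    """Stand-in for torch.nn.Module (hypothetical; avoids a torch dependency)."""


class _NonDifferentiable(Module):
    """Stand-in mixin, assumed here to subclass Module."""


def define_bad_order():
    """Try to list Module before its own subclass; return the error text, if any."""
    try:
        class A(Module, _NonDifferentiable):
            pass
    except TypeError as exc:
        return str(exc)
    return None


error = define_bad_order()
print(error)
```

The exact wording of the error varies between Python versions, but it is always a `TypeError` that mentions the MRO.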
Even worse, actually: before this, the chain was `CAGradWeighting.__call__` => `_NonDifferentiable.__call__` => `nn.Module.__call__` => `CAGradWeighting.forward`, i.e. we never called `_GramianWeighting.__call__` or `Weighting.__call__` (which just do a `super().__call__()`, but it's still dangerous to skip calling them).
@PierreQuinton Please take a look at this. I'll merge because the current state has broken documentation and I'd like it to be fixed at least on the


Problem
Every non-differentiable aggregator and weighting had `_NonDifferentiable` listed first in its base-class tuple, e.g. `class PCGradWeighting(_NonDifferentiable, _GramianWeighting)`.
Python's MRO resolves `__call__` to the first class in the MRO that defines it. Since `_NonDifferentiable.__call__` has the generic signature `(*args, **kwargs)`, Sphinx documented every affected method with that unhelpful signature instead of the more specific one from `_GramianWeighting.__call__(gramian: Tensor, /)` or `Aggregator.__call__(matrix: Tensor, /)`.
Fix
Swap the order so `_NonDifferentiable` comes after the primary base class. Example: `class PCGradWeighting(_GramianWeighting, _NonDifferentiable)`.
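The reordered chain can be sketched with hypothetical stand-ins: `Module`, `no_grad` and a `GRAD_ENABLED` flag replace `torch.nn.Module`, `torch.no_grad()` and torch's grad mode, and the class bodies are assumptions rather than torchjd's actual code. The sketch records which `__call__` runs and whether grad mode was off inside `forward`:

```python
from contextlib import contextmanager

GRAD_ENABLED = True  # stand-in for torch's global grad mode
trace = []  # names of the __call__ implementations, in call order


@contextmanager
def no_grad():
    """Stand-in for torch.no_grad(): turn grad mode off, then restore it."""
    global GRAD_ENABLED
    previous = GRAD_ENABLED
    GRAD_ENABLED = False
    try:
        yield
    finally:
        GRAD_ENABLED = previous


class Module:
    """Stand-in for nn.Module: __call__ dispatches to forward."""

    def __call__(self, *args, **kwargs):
        trace.append("Module")
        return self.forward(*args, **kwargs)


class _NonDifferentiable(Module):
    def __call__(self, *args, **kwargs):
        trace.append("_NonDifferentiable")
        with no_grad():
            return super().__call__(*args, **kwargs)


class Weighting(Module):
    def __call__(self, gramian):
        trace.append("Weighting")
        return super().__call__(gramian)


class _GramianWeighting(Weighting):
    def __call__(self, gramian, /):
        trace.append("_GramianWeighting")
        return super().__call__(gramian)


class PCGradWeighting(_GramianWeighting, _NonDifferentiable):  # new order
    def forward(self, gramian):
        # Report whether grad mode was off while forward ran.
        return GRAD_ENABLED


result = PCGradWeighting()("gramian placeholder")
print([cls.__name__ for cls in PCGradWeighting.__mro__])
print(trace, result, GRAD_ENABLED)
```

With the primary base listed first, `_NonDifferentiable` still sits before `Module` in the MRO, so the `no_grad` wrapper runs, just after `_GramianWeighting.__call__` and `Weighting.__call__`, and grad mode is restored once the call returns.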
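The `no_grad` wrapping is fully preserved. The cooperative `super().__call__()` chain now runs:
`_GramianWeighting.__call__(gramian)` → `Weighting.__call__(gramian)` → `_NonDifferentiable.__call__(*args)` [applies `no_grad`] → `nn.Module.__call__(...)`
`_NonDifferentiable` is still reached, just later in the chain.
Misunderstanding about the MRO requirement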
The old docstring warning in `_NonDifferentiable` said it must appear "before any `nn.Module` base class". This was imprecise. What actually matters is that `_NonDifferentiable` appears before `nn.Module` itself in the resolved MRO, not necessarily before every `nn.Module` subclass in the inheritance list. C3 linearization guarantees the former as long as every class in the chain calls `super().__call__()`. The warning has been updated to reflect this.
Scope
All 11 affected files in `src/torchjd/aggregation/` were updated: `_cagrad`, `_config`, `_dualproj`, `_fairgrad`, `_graddrop`, `_gradvac`, `_imtl_g`, `_mixins`, `_nash_mtl`, `_pcgrad`, `_upgrad`.
No other mixins (`_WithOptionalDeps`, `Stateful`) define `__call__`, so they are unaffected.
Verification
Generated docs now show `__call__(gramian, /)` / `__call__(matrix, /)` instead of `__call__(*args, **kwargs)` for every affected class.