refactor(aggregation)!: Remove generalized gramians #692
Simplify `Engine.compute_gramian` to always return a flat `[m, m]` PSD matrix, where `m = output.numel()`, removing the concepts of generalized Gramians, `PSDTensor`, `GeneralizedWeighting`, and `Flattening`.
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
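To make the new contract concrete, here is a minimal sketch in plain PyTorch (not torchjd's `Engine`) of what a flat `[m, m]` Gramian is: the Jacobian of the flattened output is stacked row by row with `torch.autograd.grad`, and the Gramian is `J @ J.T`. The helper `flat_gramian` and the toy linear model are hypothetical, chosen only for illustration; the sketch shows the shape and meaning of the result, not how `Engine` computes it.

```python
import torch


def flat_gramian(output: torch.Tensor, params: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical helper: return the [m, m] Gramian J @ J.T of `output`
    w.r.t. `params`, where m = output.numel() and J is the Jacobian of the
    flattened output."""
    flat_output = output.reshape(-1)
    rows = []
    for i in range(flat_output.numel()):
        # Gradient of the i-th output element w.r.t. every parameter.
        grads = torch.autograd.grad(flat_output[i], params, retain_graph=True)
        rows.append(torch.cat([g.reshape(-1) for g in grads]))
    jacobian = torch.stack(rows)  # shape [m, total number of parameters]
    return jacobian @ jacobian.T  # shape [m, m], symmetric PSD


torch.manual_seed(0)
weight = torch.randn(3, 4, requires_grad=True)
x = torch.randn(2, 4)
output = x @ weight.T  # output shape [2, 3], so m = 6
gramian = flat_gramian(output, [weight])
print(gramian.shape)  # torch.Size([6, 6])
```

Whatever the output shape, the result is indexed by the row-major position of each output element, so a `[2, 3]` output yields a `[6, 6]` matrix.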
ce15075 to 5ad48d2
PierreQuinton left a comment
I think I like it, but let me think about it. Also, I don't think we want to erase all the private `generalized_gramian` utilities and move their implementation into autogram.
@PierreQuinton this is now ready for review. I reverted to (almost) the original implementation of autogram: compute the square gramian, reshape, movedim. The extra flatten is there because we now want the returned gramian to be flat. In a future PR we should be able to let …
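The final flatten step can be illustrated with a small sketch, assuming the square gramian has already been arranged so that the output dims appear in the same order on both sides (shape `[*shape, *shape]`). The Jacobian here is a random hypothetical tensor, and the reshape/movedim details of the actual autogram implementation are not reproduced; the point is only that collapsing each side to `m = output.numel()` gives the same flat `[m, m]` matrix as `J @ J.T` on the flattened Jacobian.

```python
import math

import torch

torch.manual_seed(0)
shape = (2, 3)          # hypothetical output shape
num_params = 5          # hypothetical number of parameters
m = math.prod(shape)    # m = output.numel() = 6

# Hypothetical Jacobian of an output of shape `shape` w.r.t. `num_params` parameters.
jacobian = torch.randn(*shape, num_params)

# Square layout with one copy of the output dims per side: shape [2, 3, 2, 3].
generalized = torch.einsum("abp,cdp->abcd", jacobian, jacobian)

# The extra flatten: collapse each copy of the output dims into one, giving [m, m].
flat = generalized.reshape(m, m)

# Same result as building the Gramian directly from the flattened Jacobian.
flat_jacobian = jacobian.reshape(m, num_params)
assert torch.allclose(flat, flat_jacobian @ flat_jacobian.T, atol=1e-4)
print(flat.shape)  # torch.Size([6, 6])
```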
That would be nice, but I think it's not trivial with the batched dims (I guess all of them could have one, possibly a different one? In principle I think that would work, but it's not trivial).
Co-authored-by: Pierre Quinton <pierre.quinton@epfl.ch>
Closes #690
- `Engine.compute_gramian` now always returns a flat `[m, m]` gramian, regardless of the output shape
- … `GeneralizedWeighting`, and `Flattening`: they are no longer needed
- … `UPGradWeighting` directly and reshape the weights before calling `backward`
- … `CHANGELOG.md`
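The migration note about reshaping the weights can be sketched as follows. `hypothetical_weighting` is a stand-in for a weighting such as `UPGradWeighting` (its real call signature is not assumed here; this one just returns uniform weights), and the flat Gramian is built with plain autograd rather than torchjd. Since `Tensor.backward` expects a gradient with the same shape as `output`, the length-`m` weight vector is reshaped to `output.shape` first.

```python
import torch


def hypothetical_weighting(gramian: torch.Tensor) -> torch.Tensor:
    """Stand-in for a weighting such as UPGradWeighting: maps an [m, m] Gramian
    to a weight vector of length m. Here it simply returns uniform weights."""
    m = gramian.shape[0]
    return torch.full((m,), 1.0 / m, dtype=gramian.dtype)


torch.manual_seed(0)
weight = torch.randn(3, 4, requires_grad=True)
x = torch.randn(2, 4)
output = x @ weight.T  # shape [2, 3], so m = 6

# Flat [m, m] Gramian of the flattened output, computed here with plain autograd.
jacobian = torch.stack([
    torch.autograd.grad(o, weight, retain_graph=True)[0].reshape(-1)
    for o in output.reshape(-1)
])
gramian = jacobian @ jacobian.T

weights = hypothetical_weighting(gramian)  # shape [m]
# Reshape the flat weights to the output's shape before calling backward.
output.backward(weights.reshape(output.shape))
print(weight.grad.shape)  # torch.Size([3, 4])
```

Calling `backward` this way accumulates the weighted sum of the per-element gradients, i.e. `weights @ J` reshaped to the parameter's shape.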