💉 Stateful Aggregators, Gradient Vaccine and Grouping
TorchJD has so far been limited to stateless aggregators, but we are now opening up the implementation to stateful methods as well.
This release introduces the Stateful mixin for aggregators that retain state from their previous aggregations. As a first example, it adds the GradVac aggregator from Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models. Note that, because of its implementation, NashMTL is now also considered stateful.
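To illustrate what "stateful" means here, below is a minimal numpy sketch of the gradient-vaccine idea for two tasks. It is a hypothetical illustration, not TorchJD's implementation: the class name, the `beta` parameter, and the two-gradient interface are all assumptions. The key point is that the target cosine similarity is an exponential moving average kept across calls, which is exactly the kind of state the new mixin is meant to hold.

```python
import numpy as np

class GradVacSketch:
    """Hypothetical sketch of a stateful gradient-vaccine step for two tasks."""

    def __init__(self, beta=0.1):
        self.beta = beta   # EMA decay rate for the target similarity
        self.target = 0.0  # state retained across aggregations

    def __call__(self, g1, g2):
        cos = g1 @ g2 / (np.linalg.norm(g1) * np.linalg.norm(g2))
        if cos < self.target:
            # Rotate g1 toward g2 so that their cosine similarity
            # is raised to the stored target value.
            coef = (
                np.linalg.norm(g1)
                * (self.target * np.sqrt(1 - cos**2) - cos * np.sqrt(1 - self.target**2))
                / (np.linalg.norm(g2) * np.sqrt(1 - self.target**2))
            )
            g1 = g1 + coef * g2
        # Update the stored target: this is what makes the aggregator stateful.
        self.target = (1 - self.beta) * self.target + self.beta * cos
        return g1 + g2  # combine the (possibly adjusted) gradients
```

Calling the same instance repeatedly changes its behavior over time, which is why such aggregators cannot fit the old stateless interface.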
This release also introduces a new Grouping usage example, explaining how to partition the Jacobian into several groups of parameters and aggregate each group independently.
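The idea behind the grouping example can be sketched as follows. This is a simplified stand-alone illustration in numpy, not the code from the example itself: the function name and the row-mean "aggregator" are placeholders for any real TorchJD aggregator. Each column-group of the Jacobian (one group per set of parameters) is aggregated on its own, and the per-group results are concatenated back into a single update vector.

```python
import numpy as np

def aggregate_by_group(jacobian, group_sizes, aggregator=lambda J: J.mean(axis=0)):
    """Aggregate each column-group of the Jacobian independently.

    `jacobian` has one row per objective and one column per parameter;
    `group_sizes` gives the number of parameters in each group. The
    default row-mean aggregator is a stand-in for a real aggregator.
    """
    splits = np.cumsum(group_sizes)[:-1]          # column indices where groups end
    groups = np.split(jacobian, splits, axis=1)   # one sub-Jacobian per group
    return np.concatenate([aggregator(g) for g in groups])

# Two objectives, 3 parameters split into groups of sizes 2 and 1.
J = np.array([[1., 2., 3.],
              [3., 4., 5.]])
print(aggregate_by_group(J, [2, 1]))  # [2. 3. 4.]
```

With a non-linear aggregator, aggregating per group can genuinely differ from aggregating the full Jacobian at once, which is what makes the partition a meaningful design choice.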
Lastly, we have updated the contribution guidelines to allow adding aggregators that depend on the loss values or on any information other than the Jacobian. We hope this will enable us to support a few more aggregators soon. Feel free to open issues or pull requests if you have ideas in mind!
Many thanks to @rkhosrowshahi, @PierreQuinton and @ValerianRey for the contributions!
Changelog
Added
- Added `GradVac` and `GradVacWeighting` from Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models.
- Documented per-parameter-group aggregation (GradVac-style grouping) in a new Grouping example.
Fixed
- Added a fallback for when the inner optimization of `NashMTL` fails (which can happen, for example, on the matrix `[[0., 0.], [0., 1.]]`).
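The fallback pattern for a failing inner solver can be sketched roughly as follows. This is a hypothetical illustration, not NashMTL's actual solver or TorchJD's fix: the rank check, the `RuntimeError`, and the equal-weight fallback are all assumptions used to show the general try/except shape. On a degenerate Jacobian such as `[[0., 0.], [0., 1.]]`, the Gram matrix is singular, the inner solve is abandoned, and equal weights are used instead.

```python
import numpy as np

def solve_weights(gramian):
    """Hypothetical inner solver; fails on degenerate (singular) inputs."""
    if np.linalg.matrix_rank(gramian) < gramian.shape[0]:
        raise RuntimeError("inner optimization failed")
    return np.linalg.solve(gramian, np.ones(gramian.shape[0]))

def aggregate_with_fallback(jacobian):
    gramian = jacobian @ jacobian.T  # pairwise inner products of the rows
    try:
        weights = solve_weights(gramian)
    except RuntimeError:
        # Fallback: equal weights, i.e. averaging the objectives.
        weights = np.ones(jacobian.shape[0]) / jacobian.shape[0]
    return weights @ jacobian

# The degenerate matrix from the changelog triggers the fallback path.
J = np.array([[0., 0.], [0., 1.]])
print(aggregate_with_fallback(J))  # [0.  0.5]
```

Keeping the fallback inside the aggregator means training can continue through the occasional pathological batch instead of crashing.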