
Commit bdff933

docs: Add Grouping example (#648)
* Add Grouping example
* Add link to grouping example in examples/index.html and in GradVac's docstring
* Add changelog entry
1 parent 477cfbd commit bdff933

File tree

4 files changed: 186 additions, 0 deletions


CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@ changelog does not include internal changes that do not affect the user.

  - Added `GradVac` and `GradVacWeighting` from
    [Gradient Vaccine: Investigating and Improving Multi-task Optimization in Massively Multilingual Models](https://arxiv.org/pdf/2010.05874).
+ - Documented per-parameter-group aggregation (GradVac-style grouping) in a new Grouping example.

  ### Fixed

docs/source/examples/grouping.rst

Lines changed: 177 additions & 0 deletions
@@ -0,0 +1,177 @@

Grouping
========

The aggregation can be performed independently on groups of parameters, at different granularities.
The `Gradient Vaccine paper <https://arxiv.org/pdf/2010.05874>`_ introduces four strategies to
partition the parameters:

1. **Together** (baseline): one group covering all parameters. Corresponds to the `whole_model`
   strategy in the paper.

2. **Per network**: one group per top-level sub-network (e.g. encoder and decoder separately).
   Corresponds to the `enc_dec` strategy in the paper.

3. **Per layer**: one group per leaf module of the network. Corresponds to the `all_layer` strategy
   in the paper.

4. **Per tensor**: one group per individual parameter tensor. Corresponds to the `all_matrix`
   strategy in the paper.
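For a toy encoder-decoder model, the four partitions can be enumerated in plain PyTorch, without any TorchJD machinery. This is an illustrative sketch: the module shapes are arbitrary, and only the group structure matters.

```python
import torch
from torch.nn import Linear, ReLU, Sequential

# Illustrative toy model: an encoder and a decoder (shapes are arbitrary).
encoder = Sequential(Linear(10, 5), ReLU())
decoder = Sequential(Linear(5, 3), ReLU())
networks = [encoder, decoder]

# 1. Together: a single group covering every parameter.
together = [[p for net in networks for p in net.parameters()]]

# 2. Per network: one group per top-level sub-network.
per_network = [list(net.parameters()) for net in networks]

# 3. Per layer: one group per leaf module that owns parameters
# (the two Linear layers; ReLU has no parameters).
per_layer = [
    list(m.parameters())
    for net in networks
    for m in net.modules()
    if list(m.parameters()) and not list(m.children())
]

# 4. Per tensor: one group per individual parameter tensor
# (each Linear contributes a weight and a bias).
per_tensor = [[p] for net in networks for p in net.parameters()]

print(len(together), len(per_network), len(per_layer), len(per_tensor))
# prints: 1 2 2 4
```

Each group then gets its own aggregator instance, as the examples below show.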
In TorchJD, grouping is achieved by calling :func:`~torchjd.autojac.jac_to_grad` once per group
after :func:`~torchjd.autojac.backward` or :func:`~torchjd.autojac.mtl_backward`, with a dedicated
aggregator instance per group. For :class:`~torchjd.aggregation.Stateful` aggregators, each
instance independently maintains its own state (e.g. the EMA :math:`\hat{\phi}` state in
:class:`~torchjd.aggregation.GradVac`, matching the per-block targets from the original paper).

.. note::
    Grouping is orthogonal to the choice between :func:`~torchjd.autojac.backward` and
    :func:`~torchjd.autojac.mtl_backward`. Those functions determine *which* parameters receive
    Jacobians; grouping then determines *how* those Jacobians are partitioned for aggregation.

.. note::
    The examples below use :class:`~torchjd.aggregation.GradVac`, but the same pattern applies to
    any :class:`~torchjd.aggregation.Aggregator`.

1. Together
-----------

A single :class:`~torchjd.aggregation.Aggregator` instance aggregates all shared parameters
together. Cosine similarities are computed between the full task gradient vectors.

.. testcode::
    :emphasize-lines: 14, 21

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    aggregator = GradVac()

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        loss1 = loss_fn(task1_head(features), y1)
        loss2 = loss_fn(task2_head(features), y2)
        mtl_backward([loss1, loss2], features=features)
        jac_to_grad(encoder.parameters(), aggregator)
        optimizer.step()
        optimizer.zero_grad()

2. Per network
--------------

One :class:`~torchjd.aggregation.Aggregator` instance per top-level sub-network. Here the model
is split into an encoder and a decoder; cosine similarities are computed separately within each.
Passing ``features=dec_out`` to :func:`~torchjd.autojac.mtl_backward` causes both sub-networks
to receive Jacobians, which are then aggregated independently.

.. testcode::
    :emphasize-lines: 8-9, 15-16, 24-25

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU())
    decoder = Sequential(Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *decoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    encoder_aggregator = GradVac()
    decoder_aggregator = GradVac()

    for x, y1, y2 in zip(inputs, t1, t2):
        enc_out = encoder(x)
        dec_out = decoder(enc_out)
        loss1 = loss_fn(task1_head(dec_out), y1)
        loss2 = loss_fn(task2_head(dec_out), y2)
        mtl_backward([loss1, loss2], features=dec_out)
        jac_to_grad(encoder.parameters(), encoder_aggregator)
        jac_to_grad(decoder.parameters(), decoder_aggregator)
        optimizer.step()
        optimizer.zero_grad()

3. Per layer
------------

One :class:`~torchjd.aggregation.Aggregator` instance per leaf module. Cosine similarities are
computed per layer between the task gradients.

.. testcode::
    :emphasize-lines: 14-15, 22-23

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]
    aggregators = [GradVac() for _ in leaf_layers]

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        loss1 = loss_fn(task1_head(features), y1)
        loss2 = loss_fn(task2_head(features), y2)
        mtl_backward([loss1, loss2], features=features)
        for layer, aggregator in zip(leaf_layers, aggregators):
            jac_to_grad(layer.parameters(), aggregator)
        optimizer.step()
        optimizer.zero_grad()
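The leaf-module selection used in the per-layer example can be sanity-checked in plain PyTorch, independently of TorchJD (a sketch; the encoder shapes match the example above).

```python
from torch.nn import Linear, ReLU, Sequential

encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())

# Keep modules that own parameters and have no children: this excludes the
# Sequential container itself (it has children) and the ReLU layers (no
# parameters), leaving exactly the two Linear layers.
leaf_layers = [m for m in encoder.modules() if list(m.parameters()) and not list(m.children())]

print([type(m).__name__ for m in leaf_layers])
# prints: ['Linear', 'Linear']
```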
4. Per parameter
----------------

One :class:`~torchjd.aggregation.Aggregator` instance per individual parameter tensor. Cosine
similarities are computed per tensor between the task gradients (e.g. weights and biases of each
layer are treated as separate groups).

.. testcode::
    :emphasize-lines: 14-15, 22-23

    import torch
    from torch.nn import Linear, MSELoss, ReLU, Sequential
    from torch.optim import SGD

    from torchjd.aggregation import GradVac
    from torchjd.autojac import jac_to_grad, mtl_backward

    encoder = Sequential(Linear(10, 5), ReLU(), Linear(5, 3), ReLU())
    task1_head, task2_head = Linear(3, 1), Linear(3, 1)
    optimizer = SGD([*encoder.parameters(), *task1_head.parameters(), *task2_head.parameters()], lr=0.1)
    loss_fn = MSELoss()
    inputs, t1, t2 = torch.randn(8, 16, 10), torch.randn(8, 16, 1), torch.randn(8, 16, 1)

    shared_params = list(encoder.parameters())
    aggregators = [GradVac() for _ in shared_params]

    for x, y1, y2 in zip(inputs, t1, t2):
        features = encoder(x)
        loss1 = loss_fn(task1_head(features), y1)
        loss2 = loss_fn(task2_head(features), y2)
        mtl_backward([loss1, loss2], features=features)
        for param, aggregator in zip(shared_params, aggregators):
            jac_to_grad([param], aggregator)
        optimizer.step()
        optimizer.zero_grad()

docs/source/examples/index.rst

Lines changed: 4 additions & 0 deletions
@@ -29,6 +29,9 @@ This section contains some usage examples for TorchJD.

  - :doc:`PyTorch Lightning Integration <lightning_integration>` showcases how to combine
    TorchJD with PyTorch Lightning, by providing an example implementation of a multi-task
    ``LightningModule`` optimized by Jacobian descent.
+ - :doc:`Grouping <grouping>` shows how to apply an aggregator independently per parameter group
+   (e.g. per layer), so that conflict resolution happens at a finer granularity than the full
+   parameter vector.
  - :doc:`Automatic Mixed Precision <amp>` shows how to combine mixed precision training with TorchJD.

  .. toctree::

@@ -43,3 +46,4 @@ This section contains some usage examples for TorchJD.

    monitoring.rst
    lightning_integration.rst
    amp.rst
+   grouping.rst

src/torchjd/aggregation/_gradvac.py

Lines changed: 4 additions & 0 deletions
@@ -40,6 +40,10 @@ class GradVac(GramianWeightedAggregator, Stateful):

      For each task :math:`i`, the order of other tasks :math:`j` is shuffled independently
      using the global PyTorch RNG (``torch.randperm``). Seed it with ``torch.manual_seed`` if
      you need reproducibility.
+
+     .. note::
+         To apply GradVac with the `whole_model`, `enc_dec`, `all_layer` or `all_matrix` grouping
+         strategy, please refer to the :doc:`Grouping </examples/grouping>` examples.
      """

      def __init__(self, beta: float = 0.5, eps: float = 1e-8) -> None:
