
Commit 9b855a7

Alex Sablayrolles authored and facebook-github-bot committed
No-Op GradSampleModule (#492)
Summary:
TL;DR: Adding a No-Op GradSampleModule in case the grad samples are computed by functorch. The CIFAR10 example has been updated to show a typical use case. The neat thing about functorch is that it directly gives the per-sample gradients with a couple of lines of code. These per-sample gradients are then manually assigned to `p.grad_sample` by the end user.

Pull Request resolved: #492
Reviewed By: ffuuugor
Differential Revision: D39204008
Pulled By: alexandresablayrolles
fbshipit-source-id: 22036e6c941522bba7749ef46f97d54f6ee8c551
1 parent 38b24dc commit 9b855a7
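
For context, here is a minimal, self-contained sketch of the workflow this commit enables: functorch computes the per-sample gradients, and the caller hands them to Opacus by assigning `p.grad_sample`, which is all the new "no_op" mode expects. The model, criterion, and tensors below are illustrative stand-ins, not part of the commit; the real usage lives in examples/cifar10.py further down.

import torch
import torch.nn as nn
from functorch import grad_and_value, make_functional, vmap

# Illustrative model and batch (stand-ins; see examples/cifar10.py for the real usage)
model = nn.Linear(10, 2)
criterion = nn.CrossEntropyLoss()
images = torch.randn(8, 10)
target = torch.randint(0, 2, (8,))

# Make the model functional so we can differentiate w.r.t. an explicit parameter list
fmodel, _fparams = make_functional(model)

def compute_loss(params, sample, target):
    # Treat a single sample as a batch of one so the model's shapes still work
    batch = sample.unsqueeze(0)
    targets = target.unsqueeze(0)
    return criterion(fmodel(params, batch), targets)

# grad_and_value returns (per-sample gradient, per-sample loss); vmap maps it over the
# batch, sharing params (in_dims=None) and splitting images/targets along dim 0
ft_compute_sample_grad = vmap(grad_and_value(compute_loss), in_dims=(None, 0, 0))

params = list(model.parameters())
per_sample_grads, per_sample_losses = ft_compute_sample_grad(params, images, target)

# Hand the per-sample gradients to Opacus: DPOptimizer reads p.grad_sample for
# clipping and noising, while the No-Op GradSampleModule itself computes nothing
for p, g in zip(params, per_sample_grads):
    p.grad_sample = g.detach()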

File tree

5 files changed: +92 -10 lines changed


.circleci/config.yml

Lines changed: 2 additions & 0 deletions
@@ -162,6 +162,8 @@ commands:
             pip install tensorboard
             python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <<parameters.device>>
             python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
+            python examples/cifar10.py --lr 0.1 --sigma 1.5 -c 10 --batch-size 2000 --epochs 10 --data-root runs/cifar10/data --log-dir runs/cifar10/logs --device <<parameters.device>> --grad_sample_mode no_op
+            python -c "import torch; model = torch.load('model_best.pth.tar'); exit(0) if (model['best_acc1']>0.4 and model['best_acc1']<0.49) else exit(1)"
           when: always
       - store_test_results:
           path: runs/cifar10/test-reports

examples/cifar10.py

Lines changed: 41 additions & 9 deletions
@@ -138,25 +138,55 @@ def train(args, model, train_loader, optimizer, privacy_engine, epoch, device):
     losses = []
     top1_acc = []
 
+    if args.grad_sample_mode == "no_op":
+        from functorch import grad_and_value, make_functional, vmap
+
+        # Functorch prepare
+        fmodel, _fparams = make_functional(model)
+
+        def compute_loss_stateless_model(params, sample, target):
+            batch = sample.unsqueeze(0)
+            targets = target.unsqueeze(0)
+
+            predictions = fmodel(params, batch)
+            loss = criterion(predictions, targets)
+            return loss
+
+        ft_compute_grad = grad_and_value(compute_loss_stateless_model)
+        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, 0, 0))
+        # Using model.parameters() instead of fparams
+        # as fparams seems to not point to the dynamically updated parameters
+        params = list(model.parameters())
+
     for i, (images, target) in enumerate(tqdm(train_loader)):
 
         images = images.to(device)
         target = target.to(device)
 
         # compute output
         output = model(images)
-        loss = criterion(output, target)
-        preds = np.argmax(output.detach().cpu().numpy(), axis=1)
-        labels = target.detach().cpu().numpy()
 
-        # measure accuracy and record loss
-        acc1 = accuracy(preds, labels)
+        if args.grad_sample_mode == "no_op":
+            per_sample_grads, per_sample_losses = ft_compute_sample_grad(
+                params, images, target
+            )
+            per_sample_grads = [g.detach() for g in per_sample_grads]
+            loss = torch.mean(per_sample_losses)
+            for (p, g) in zip(params, per_sample_grads):
+                p.grad_sample = g
+        else:
+            loss = criterion(output, target)
+            preds = np.argmax(output.detach().cpu().numpy(), axis=1)
+            labels = target.detach().cpu().numpy()
 
-        losses.append(loss.item())
-        top1_acc.append(acc1)
+            # measure accuracy and record loss
+            acc1 = accuracy(preds, labels)
+            top1_acc.append(acc1)
+
+            # compute gradient and do SGD step
+            loss.backward()
 
-        # compute gradient and do SGD step
-        loss.backward()
+        losses.append(loss.item())
 
         # make sure we take a step after processing the last mini-batch in the
         # epoch to ensure we start the next epoch with a clean state
@@ -331,6 +361,7 @@ def main():
            noise_multiplier=args.sigma,
            max_grad_norm=max_grad_norm,
            clipping=clipping,
+           grad_sample_mode=args.grad_sample_mode,
        )
 
        # Store some logs
@@ -388,6 +419,7 @@ def main():
 
 def parse_args():
     parser = argparse.ArgumentParser(description="PyTorch CIFAR10 DP Training")
+    parser.add_argument("--grad_sample_mode", type=str, default="hooks")
     parser.add_argument(
         "-j",
         "--workers",

opacus/grad_sample/gsm_no_op.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+from opacus.grad_sample.gsm_base import AbstractGradSampleModule
+
+
+class GradSampleModuleNoOp(AbstractGradSampleModule):
+    """
+    NoOp GradSampleModule.
+    Only wraps the module. The main goal of this class is to provide the same API for all methods.
+    See README.md for more details
+    """
+
+    def __init__(
+        self,
+        m: nn.Module,
+        *,
+        batch_first=True,
+        loss_reduction="mean",
+    ):
+        if not batch_first:
+            raise NotImplementedError
+
+        super().__init__(
+            m,
+            batch_first=batch_first,
+            loss_reduction=loss_reduction,
+        )
+
+    def forward(self, x: torch.Tensor, *args, **kwargs):
+        return self._module.forward(x, *args, **kwargs)
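
As a quick illustration of the wrapper's behavior (a sketch only; in typical use the class is selected for you by passing grad_sample_mode="no_op" through the privacy engine, as the CIFAR10 example above does), the wrapped module simply forwards calls and computes no grad samples:

import torch
import torch.nn as nn
from opacus.grad_sample.gsm_no_op import GradSampleModuleNoOp

# batch_first=True is required; batch_first=False raises NotImplementedError
wrapped = GradSampleModuleNoOp(nn.Linear(10, 2))
out = wrapped(torch.randn(4, 10))  # identical to calling the underlying module directly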

opacus/grad_sample/utils.py

Lines changed: 3 additions & 0 deletions
@@ -20,6 +20,7 @@
 from .grad_sample_module import GradSampleModule
 from .gsm_base import AbstractGradSampleModule
 from .gsm_exp_weights import GradSampleModuleExpandedWeights
+from .gsm_no_op import GradSampleModuleNoOp
 
 
 def register_grad_sampler(
@@ -69,6 +70,8 @@ def get_gsm_class(grad_sample_mode: str) -> Type[AbstractGradSampleModule]:
         return GradSampleModule
     elif grad_sample_mode == "ew":
         return GradSampleModuleExpandedWeights
+    elif grad_sample_mode == "no_op":
+        return GradSampleModuleNoOp
     else:
         raise ValueError(
             f"Unexpected grad_sample_mode: {grad_sample_mode}. "

opacus/optimizers/optimizer.py

Lines changed: 1 addition & 1 deletion
@@ -395,7 +395,7 @@ def clip_and_accumulate(self):
         """
 
         per_param_norms = [
-            g.norm(2, dim=tuple(range(1, g.ndim))) for g in self.grad_samples
+            g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
         ]
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
         per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
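
The rewritten norm computation is numerically equivalent to the old one: flattening every non-batch dimension and taking a single 2-norm is the same as taking the 2-norm over all non-batch dimensions. A quick check (the shape here is arbitrary):

import torch

g = torch.randn(16, 3, 5, 5)  # per-sample gradients, batch dimension first
old = g.norm(2, dim=tuple(range(1, g.ndim)))
new = g.reshape(len(g), -1).norm(2, dim=-1)
assert torch.allclose(old, new)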
