Commit 523ef31

Merge pull request oneapi-src#1809 from asirvaiy/2023.2_AIKit

Intel PyTorch GPU sample with AMP

2 parents a82dd9a + 118ee65

8 files changed, +1402 −0 lines changed
AI-and-Analytics/Features-and-Functionality/IntelPyTorch_GPU_InferenceOptimization_with_AMP/IntelPyTorch_GPU_InferenceOptimization_with_AMP.ipynb

645 additions & 0 deletions (large diff not rendered by default)
@@ -0,0 +1,300 @@
import os
from time import time
from tqdm import tqdm
import numpy as np
import torch
import torchvision
import intel_extension_for_pytorch as ipex

import matplotlib.pyplot as plt


# Hyperparameters and constants
LR = 0.01
MOMENTUM = 0.9
DATA = 'datasets/cifar10/'
epochs = 1
batch_size = 128

# Check whether IPEX XPU is correctly installed and can be used for the PyTorch model
try:
    device = "xpu" if torch.xpu.is_available() else "cpu"
except AttributeError:  # torch.xpu is absent when the IPEX XPU build is not installed
    device = "cpu"

if device == "xpu":  # XPU is the device name for an Intel discrete GPU
    print("IPEX_XPU is present and Intel GPU is available to use for PyTorch")
    device = "gpu"  # the rest of the script uses "gpu" as its device flag
else:
    print("using CPU device for PyTorch")
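As a quick sanity check (an illustrative sketch, not part of the sample; it assumes an IPEX XPU build where the torch.xpu namespace mirrors the torch.cuda device API), the detected GPU(s) can be inspected before training:

# Illustrative only: list the XPU devices IPEX has discovered
if torch.xpu.is_available():
    for idx in range(torch.xpu.device_count()):
        print(f"xpu:{idx} -> {torch.xpu.get_device_name(idx)}")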
"""
Function to train the model
"""
def trainModel(train_loader, modelName="myModel", device="cpu", dataType="fp32"):
    """
    Input parameters
        train_loader: a torch DataLoader object containing the training data with images and labels
        modelName: a string representing the name of the model
        device: the device to use - cpu or gpu
        dataType: the data type for model parameters, supported values - fp32, bf16
    Return value
        None; a checkpoint of the trained model is saved as checkpoint_<modelName>.pth
    """

    # Initialize the model: a pretrained ResNet-50 whose classifier head is
    # replaced by a 10-class linear layer followed by Softmax (CIFAR-10 has 10 classes)
    model = torchvision.models.resnet50(pretrained=True)
    model.fc = torch.nn.Linear(2048, 10)
    lin_layer = model.fc
    new_layer = torch.nn.Sequential(
        lin_layer,
        torch.nn.Softmax(dim=1)
    )
    model.fc = new_layer

    # Define the loss function and optimization methodology
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
    model.train()

    # Move the model and criterion to the XPU device (GPU-specific code)
    if device == "gpu":
        model = model.to("xpu:0")
        criterion = criterion.to("xpu:0")

    # Optimize with BF16 or FP32 (default); BF16-specific code
    if dataType == "bf16":
        model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.bfloat16)
    else:
        model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=torch.float32)

    # Train the model
    num_batches = len(train_loader) * epochs

    for i in range(epochs):
        running_loss = 0.0

        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            # Move the batch to the XPU device (GPU-specific code)
            if device == "gpu":
                data = data.to("xpu:0")
                target = target.to("xpu:0")

            # Apply auto-mixed precision (BF16)
            if dataType == "bf16":
                with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                    output = model(data)
                    loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            else:
                output = model(data)
                loss = criterion(output, target)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()

            # Show the average loss every 50 batches
            if (batch_idx + 1) % 50 == 0:
                print("Batch %d/%d complete" % (batch_idx + 1, num_batches))
                print(f' average loss: {running_loss / 50:.3f}')
                running_loss = 0.0

    # Save a checkpoint of the trained model
    torch.save({
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, 'checkpoint_%s.pth' % modelName)

    return None
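Since the checkpoint stores both the model and the optimizer state, training could later be resumed from it. A minimal sketch (the resume_training helper is illustrative, not part of the sample; it rebuilds the same architecture the checkpoint was saved from):

def resume_training(cp_file='checkpoint_gpu_rn50.pth'):
    # Rebuild the architecture used in trainModel above
    model = torchvision.models.resnet50(pretrained=True)
    model.fc = torch.nn.Sequential(torch.nn.Linear(2048, 10), torch.nn.Softmax(dim=1))
    optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=MOMENTUM)
    checkpoint = torch.load(cp_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return model, optimizer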
# Dataloader operations
transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(
    root=DATA,
    train=True,
    transform=transform,
    download=True,
)
train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size
)

test_dataset = torchvision.datasets.CIFAR10(root=DATA, train=False,
                                            download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size)
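For orientation (an illustrative check, not part of the sample), one batch from train_loader after the Resize transform has these shapes:

images, labels = next(iter(train_loader))
print(images.shape)  # torch.Size([128, 3, 224, 224])
print(labels.shape)  # torch.Size([128])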
# Model Training

if device == 'gpu':
    print("Training model with FP32 on GPU, will be saved as checkpoint_gpu_rn50.pth")
    trainModel(train_loader, modelName="gpu_rn50", device="gpu", dataType="fp32")
else:
    print("Training model with FP32 on CPU, will be saved as checkpoint_cpu_rn50.pth")
    trainModel(train_loader, modelName="cpu_rn50", device="cpu", dataType="fp32")
# Model Evaluation

# Load the model from the saved checkpoint file
def load_model(cp_file='checkpoint_cpu_rn50.pth'):
    # Rebuild the same architecture used in trainModel
    model = torchvision.models.resnet50()
    model.fc = torch.nn.Linear(2048, 10)
    lin_layer = model.fc
    new_layer = torch.nn.Sequential(
        lin_layer,
        torch.nn.Softmax(dim=1)
    )
    model.fc = new_layer

    checkpoint = torch.load(cp_file)
    model.load_state_dict(checkpoint['model_state_dict'])
    return model
# Apply TorchScript and IPEX optimizations (optional)
def ipex_jit_optimize(model, dataType="fp32", device="cpu"):
    model.eval()

    if device == "gpu":  # move the model to the XPU device
        model = model.to("xpu:0")

    if dataType == "bf16":  # for bfloat16
        model = ipex.optimize(model, dtype=torch.bfloat16)
    else:
        model = ipex.optimize(model, dtype=torch.float32)

    with torch.no_grad():
        d = torch.rand(1, 3, 224, 224)  # dummy input used for tracing
        if device == "gpu":
            d = d.to("xpu:0")

        # Export the model to TorchScript
        if dataType == "bf16":
            with torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
                jit_model = torch.jit.trace(model, d)    # JIT trace the optimized model
                jit_model = torch.jit.freeze(jit_model)  # JIT freeze the traced model
        else:
            jit_model = torch.jit.trace(model, d)        # JIT trace the optimized model
            jit_model = torch.jit.freeze(jit_model)      # JIT freeze the traced model
    return jit_model
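For reference, a single prediction with the frozen BF16 model might look like the sketch below (it assumes a GPU checkpoint exists and wraps inference in the same autocast context used for tracing; the random tensor is a stand-in for a real preprocessed test image):

jit_model = ipex_jit_optimize(load_model('checkpoint_gpu_rn50.pth'), dataType="bf16", device="gpu")
sample = torch.rand(1, 3, 224, 224).to("xpu:0")  # stand-in for a preprocessed CIFAR-10 image
with torch.no_grad(), torch.xpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
    probs = jit_model(sample)
print(probs.argmax(dim=1))  # predicted CIFAR-10 class index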
def inferModel(model, test_loader, device="cpu", dataType='fp32'):
    correct = 0
    total = 0
    if device == "gpu":
        model = model.to("xpu:0")
    infer_time = 0

    with torch.no_grad():
        # The first 3 batches are treated as warm-up and the last few batches
        # are skipped when accumulating the inference time
        num_batches = len(test_loader)
        batches = 0

        for i, data in tqdm(enumerate(test_loader)):

            # Record the time before batch inference
            if device == 'gpu':
                torch.xpu.synchronize()
            start_time = time()
            images, labels = data
            if device == "gpu":
                images = images.to("xpu:0")

            outputs = model(images)
            outputs = outputs.to("cpu")
            _, predicted = torch.max(outputs.data, 1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Record the time after finishing batch inference
            if device == 'gpu':
                torch.xpu.synchronize()
            end_time = time()

            if i >= 3 and i <= num_batches - 3:
                infer_time += (end_time - start_time)
                batches += 1
            # Skip the last few batches
            if i == num_batches - 3:
                break

    accuracy = 100 * correct / total
    return accuracy, infer_time * 1000 / (batches * batch_size)  # latency in ms per sample
# Evaluation of the different models
def Eval_model(cp_file='checkpoint_model.pth', dataType="fp32", device="gpu"):
    model = load_model(cp_file)
    model = ipex_jit_optimize(model, dataType, device)
    accuracy, bt = inferModel(model, test_loader, device, dataType)
    print(f' Model accuracy: {accuracy}% and average inference latency: {bt:.3f} ms per sample\n')
    return accuracy, bt
# Accuracy and inference time check

if device == 'cpu':  # FP32 model on CPU
    print("Model evaluation with FP32 on CPU")
    Eval_model(cp_file='checkpoint_cpu_rn50.pth', dataType="fp32", device=device)
else:
    # FP32 model on GPU
    print("Model evaluation with FP32 on GPU")
    acc_fp32, fp32_avg_latency = Eval_model(cp_file='checkpoint_gpu_rn50.pth', dataType="fp32", device=device)

    # BF16 model on GPU
    print("Model evaluation with BF16 on GPU")
    acc_bf16, bf16_avg_latency = Eval_model(cp_file='checkpoint_gpu_rn50.pth', dataType="bf16", device=device)

    # Summary
    print("Summary")
    print(f'Inference average latency for FP32 on GPU is: {fp32_avg_latency} ')
    print(f'Inference average latency for AMP BF16 on GPU is: {bf16_avg_latency} ')

    speedup_from_amp_bf16 = fp32_avg_latency / bf16_avg_latency
    print("Inference with BF16 is %.2fX faster than FP32 on GPU" % speedup_from_amp_bf16)
    plt.figure()
    plt.title("ResNet50 Inference Latency Comparison")
    plt.xlabel("Test Case")
    plt.ylabel("Inference Latency per sample (ms)")
    plt.bar(["FP32 on GPU", "AMP BF16 on GPU"], [fp32_avg_latency, bf16_avg_latency])
    plt.savefig('./bf16speedup.png')

    plt.figure()
    plt.title("Accuracy Comparison")
    plt.xlabel("Test Case")
    plt.ylabel("Accuracy (%)")
    plt.bar(["FP32 on GPU", "AMP BF16 on GPU"], [acc_fp32, acc_bf16])
    print(f'Accuracy drop with AMP BF16 is: {acc_fp32-acc_bf16}')
    plt.savefig('./accuracy.png')

print('[CODE_SAMPLE_COMPLETED_SUCCESFULLY]')
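To make the summary arithmetic concrete (hypothetical numbers, not measured results): if fp32_avg_latency were 4.0 ms per sample and bf16_avg_latency were 2.5 ms, then speedup_from_amp_bf16 = 4.0 / 2.5 = 1.6, and the script would report "Inference with BF16 is 1.60X faster than FP32 on GPU".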
@@ -0,0 +1,7 @@
Copyright Intel Corporation

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
