Commit 4cc8a76
Author: min-jean-cho

Merge branch 'launcher' of https://github.com/min-jean-cho/serve into launcher

2 parents 506fc08 + d1e7104

3 files changed: 110 additions, 57 deletions

binaries/conda/build_packages.py (7 additions, 0 deletions)

```diff
@@ -8,6 +8,10 @@
 MINICONDA_DOWNLOAD_URL = "https://repo.anaconda.com/miniconda/Miniconda3-py39_4.9.2-Linux-x86_64.sh"
 CONDA_BINARY = os.popen("which conda").read().strip() if os.system(f"conda --version") == 0 else f"$HOME/miniconda/condabin/conda"
 
+if os.name == "nt":
+    # Assumes miniconda is installed on Windows
+    CONDA_BINARY = "conda"
+
 def install_conda_build():
     """
     Install conda-build, required to create conda packages
@@ -24,6 +28,9 @@ def install_miniconda():
     if exit_code == 0:
         print(f"'conda' already present on the system. Proceeding without a fresh miniconda installation.")
         return
+    if os.name == "nt":
+        print("Identified as Windows system. Please install miniconda using this URL: https://repo.anaconda.com/miniconda/Miniconda3-latest-Windows-x86_64.exe")
+        return
 
     os.system(f"rm -rf $HOME/miniconda")
     exit_code = os.system(f"wget {MINICONDA_DOWNLOAD_URL} -O ~/miniconda.sh")
```
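For orientation, here is a minimal sketch (not part of the commit) of the same platform-dependent lookup using the standard library's `shutil.which` instead of shelling out; under the stated assumption that Windows has conda on PATH, its behavior should match the module-level logic above:

```python
import os
import shutil

def resolve_conda_binary() -> str:
    """Sketch of the lookup above: prefer a conda on PATH, else fall back."""
    if os.name == "nt":
        # Windows: the build script assumes miniconda is installed
        # and that `conda` is already on PATH.
        return "conda"
    found = shutil.which("conda")  # stdlib equivalent of `which conda`
    return found or os.path.expandvars("$HOME/miniconda/condabin/conda")
```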

binaries/upload.py (1 addition, 1 deletion)

```diff
@@ -46,7 +46,7 @@ def upload_conda_packages():
     """
 
     # Identify *.tar.bz2 files to upload
-    anaconda_token = os.environ[CONDA_TOKEN_ENV_VARIABLE]
+    anaconda_token = os.getenv(CONDA_TOKEN_ENV_VARIABLE)
 
     for root, _, files in os.walk(CONDA_PACKAGES_PATH):
         for name in files:
```
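One note on this one-line change: `os.environ[KEY]` raises `KeyError` when the variable is unset, whereas `os.getenv(KEY)` returns `None` (or a supplied default), letting the caller decide how to fail. A minimal sketch of the distinction (the variable name here is illustrative, not the one defined in `upload.py`):

```python
import os

# Dict-style access raises KeyError if the variable is missing:
#   token = os.environ["EXAMPLE_TOKEN"]   # KeyError when unset
# os.getenv returns None (or a default) instead:
token = os.getenv("EXAMPLE_TOKEN")
if token is None:
    raise SystemExit("EXAMPLE_TOKEN is not set; cannot upload packages.")
```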

examples/intel_extension_for_pytorch/README.md (102 additions, 56 deletions)
````diff
@@ -1,13 +1,15 @@
 # TorchServe with Intel® Extension for PyTorch*
 
-TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give performance boost on Intel hardware.
+TorchServe can be used with Intel® Extension for PyTorch* (IPEX) to give performance boost on Intel hardware<sup>1</sup>.
 Here we show how to use TorchServe with IPEX.
 
+<sup>1. While IPEX benefits all platforms, platforms with AVX512 benefit the most.</sup>
+
 ## Contents of this Document
 * [Install Intel Extension for PyTorch](#install-intel-extension-for-pytorch)
 * [Serving model with Intel Extension for PyTorch](#serving-model-with-intel-extension-for-pytorch)
+* [TorchServe with Launcher](#torchserve-with-launcher)
 * [Creating and Exporting INT8 model for IPEX](#creating-and-exporting-int8-model-for-ipex)
-* [Torchserve with Launcher](#torchserve-with-launcher)
 * [Benchmarking with Launcher](#benchmarking-with-launcher)
 
 
````
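As an aside on the new footnote: on Linux one can check whether the CPU reports AVX512 by inspecting the flags in `/proc/cpuinfo`. A rough, Linux-only sketch (not part of this commit):

```python
# Linux-only: CPU feature flags are listed in /proc/cpuinfo.
with open("/proc/cpuinfo") as f:
    has_avx512 = "avx512" in f.read()
print("AVX512 supported:", has_avx512)
```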

````diff
@@ -19,7 +21,50 @@ After installation, all it needs to be done to use TorchServe with IPEX is to en
 ```
 ipex_enable=true
 ```
-Once IPEX is enabled, deploying PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). Torchserve with IPEX can deploy any model and do inference.
+Once IPEX is enabled, deploying PyTorch model follows the same procedure shown [here](https://pytorch.org/serve/use_cases.html). TorchServe with IPEX can deploy any model and do inference.
+
+## TorchServe with Launcher
+Launcher is a script to automate the process of tuning configuration settings on Intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affinity, and the memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for details on performance tuning with launcher.
+
+All that needs to be done to use TorchServe with launcher is to set its configuration in `config.properties`.
+
+Add the following lines in `config.properties` to use launcher with its default configuration.
+```
+ipex_enable=true
+cpu_launcher_enable=true
+```
+
+Launcher by default uses `numactl` if it is installed, to ensure the socket is pinned and thus memory is allocated from the local NUMA node. To use launcher without numactl, add the following lines in `config.properties`.
+```
+ipex_enable=true
+cpu_launcher_enable=true
+cpu_launcher_args=--disable_numactl
+```
+
+Launcher by default uses only non-hyperthreaded cores if hyperthreading is present, to avoid sharing of core compute resources. To use launcher with all cores, both physical and logical, add the following lines in `config.properties`.
+```
+ipex_enable=true
+cpu_launcher_enable=true
+cpu_launcher_args=--use_logical_core
+```
+
+Below is an example of passing multiple args to `cpu_launcher_args`.
+```
+ipex_enable=true
+cpu_launcher_enable=true
+cpu_launcher_args=--use_logical_core --disable_numactl
+```
+
+Some useful `cpu_launcher_args` to note are:
+1. Memory Allocator: [ PTMalloc `--use_default_allocator` | *TCMalloc `--enable_tcmalloc`* | JeMalloc `--enable_jemalloc`]
+   * PyTorch by default uses PTMalloc. TCMalloc/JeMalloc generally give better performance.
+2. OpenMP library: [GNU OpenMP `--disable_iomp` | *Intel OpenMP*]
+   * PyTorch by default uses GNU OpenMP. Launcher by default uses Intel OpenMP, which generally gives better performance.
+3. Socket id: [`--socket_id`]
+   * Launcher by default uses all physical cores. Use `--socket_id` to limit memory access to the local memory of the Nth socket and avoid Non-Uniform Memory Access (NUMA) penalties.
+
+Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for a full list of tunable configurations of launcher.
+
 
 ## Creating and Exporting INT8 model for IPEX
 Intel Extension for PyTorch supports both eager and torchscript mode. In this section, we show how to deploy INT8 model for IPEX.
````
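The launcher flags added in the hunk above can be combined. For example, a plausible `config.properties` that pins the workload to socket 0 and switches to TCMalloc (illustrative only; it assumes TCMalloc is installed, and exact flag syntax is documented in the launch-script guide linked above) might look like:

```
ipex_enable=true
cpu_launcher_enable=true
cpu_launcher_args=--enable_tcmalloc --socket_id 0
```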
````diff
@@ -30,11 +75,10 @@ First create `.pt` serialized file using IPEX INT8 inference. Here we show two e
 #### BERT
 
 ```
+import torch
 import intel_extension_for_pytorch as ipex
-from transformers import AutoModelForSequenceClassification, AutoConfig
 import transformers
-from datasets import load_dataset
-import torch
+from transformers import AutoModelForSequenceClassification, AutoConfig
 
 # load the model
 config = AutoConfig.from_pretrained(
````
````diff
@@ -43,99 +87,101 @@ model = AutoModelForSequenceClassification.from_pretrained(
     "bert-base-uncased", config=config)
 model = model.eval()
 
-max_length = 384
-dummy_tensor = torch.ones((1, max_length), dtype=torch.long)
-jit_inputs = (dummy_tensor, dummy_tensor, dummy_tensor)
-conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
-
+# define dummy input tensor to use for the model's forward call to record operations in the model for tracing
+N, max_length = 1, 384
+dummy_tensor = torch.ones((N, max_length), dtype=torch.long)
 
 # calibration
-n_iter = 100
+# ipex supports two quantization schemes to be used for activation: torch.per_tensor_affine and torch.per_tensor_symmetric
+# default qscheme is torch.per_tensor_affine
+conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_affine)
+n_iter = 100
 with torch.no_grad():
     for i in range(n_iter):
         with ipex.quantization.calibrate(conf):
             model(dummy_tensor, dummy_tensor, dummy_tensor)
 
-# optionally save the configuraiton for later use
-conf.save(‘model_conf.json’, default_recipe=True)
+# optionally save the configuration for later use
+# save:
+# conf.save("model_conf.json")
+# load:
+# conf = ipex.quantization.QuantConf("model_conf.json")
 
 # conversion
+jit_inputs = (dummy_tensor, dummy_tensor, dummy_tensor)
 model = ipex.quantization.convert(model, conf, jit_inputs)
 
+# enable fusion path work (need to run forward propagation twice)
+with torch.no_grad():
+    y = model(dummy_tensor, dummy_tensor, dummy_tensor)
+    y = model(dummy_tensor, dummy_tensor, dummy_tensor)
+
 # save to .pt
 torch.jit.save(model, 'bert_int8_jit.pt')
 ```
 
 #### ResNet50
 
 ```
-import intel_extension_for_pytorch as ipex
-import torchvision.models as models
 import torch
 import torch.fx.experimental.optimization as optimization
-from copy import deepcopy
-
+import intel_extension_for_pytorch as ipex
+import torchvision.models as models
 
+# load the model
 model = models.resnet50(pretrained=True)
 model = model.eval()
+model = optimization.fuse(model)
 
-C, H, W = 3, 224, 224
-dummy_tensor = torch.randn(1, C, H, W).contiguous(memory_format=torch.channels_last)
-jit_inputs = (dummy_tensor)
-conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric)
+# define dummy input tensor to use for the model's forward call to record operations in the model for tracing
+N, C, H, W = 1, 3, 224, 224
+dummy_tensor = torch.randn(N, C, H, W).contiguous(memory_format=torch.channels_last)
 
-n_iter = 100
+# calibration
+# ipex supports two quantization schemes to be used for activation: torch.per_tensor_affine and torch.per_tensor_symmetric
+# default qscheme is torch.per_tensor_affine
+conf = ipex.quantization.QuantConf(qscheme=torch.per_tensor_symmetric)
+n_iter = 100
 with torch.no_grad():
-    for i in range(n_iter):
-        with ipex.quantization.calibrate(conf):
-            model(dummy_tensor)
+    for i in range(n_iter):
+        with ipex.quantization.calibrate(conf):
+            model(dummy_tensor)
+
+# optionally save the configuration for later use
+# save:
+# conf.save("model_conf.json")
+# load:
+# conf = ipex.quantization.QuantConf("model_conf.json")
 
+# conversion
+jit_inputs = (dummy_tensor)
 model = ipex.quantization.convert(model, conf, jit_inputs)
+
+# enable fusion path work (need to run two iterations)
+with torch.no_grad():
+    y = model(dummy_tensor)
+    y = model(dummy_tensor)
+
+# save to .pt
 torch.jit.save(model, 'rn50_int8_jit.pt')
 ```
+
 ### 2. Creating a Model Archive
 Once the serialized file (`.pt`) is created, it can be used with `torch-model-archiver` as usual. Use the following command to package the model.
 ```
 torch-model-archiver --model-name rn50_ipex_int8 --version 1.0 --serialized-file rn50_int8_jit.pt --handler image_classifier
 ```
-### 3. Start Torchserve to serve the model
-Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start Torchserve with IPEX.
+### 3. Start TorchServe to serve the model
+Make sure to set `ipex_enable=true` in `config.properties`. Use the following command to start TorchServe with IPEX.
 ```
 torchserve --start --ncs --model-store model_store --ts-config config.properties
 ```
 
 ### 4. Registering and Deploying model
 Registering and deploying the model follows the same steps shown [here](https://pytorch.org/serve/use_cases.html).
 
-## Torchserve with Launcher
-Launcher is a script to automate the process of tunining configuration setting on intel hardware to boost performance. Tuning configurations such as OMP_NUM_THREADS, thread affininty, memory allocator can have a dramatic effect on performance. Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) and [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/tuning_guide.md) for details on performance tuning with launcher.
-
-All it needs to be done to use Torchserve with launcher is to set its configuration in `config.properties`.
-
-
-Add the following lines in `config.properties` to use launcher with its default configuration.
-```
-ipex_enable=true
-cpu_launcher_enable=true
-```
-
-Launcher by default uses `numactl` if its installed to ensure socket is pinned and thus memory is allocated from local numa node. To use launcher without numactl, add the following lines in `config.properties`.
-```
-ipex_enable=true
-cpu_launcher_enable=true
-cpu_launcher_args=--disable_numactl
-```
-
-Launcher by default uses only non-hyperthreaded cores if hyperthreading is present to avoid core compute resource sharing. To use launcher with all cores, both physical and logical, add the following lines in `config.properties`.
-```
-ipex_enable=true
-cpu_launcher_enable=true
-cpu_launcher_args=--use_logical_core
-```
-Please refer to [here](https://github.com/intel/intel-extension-for-pytorch/blob/master/docs/tutorials/performance_tuning/launch_script.md) for a full list of tunable configuration of launcher.
-
 ## Benchmarking with Launcher
-Launcher can be used with Torchserve official [benchmark](https://github.com/pytorch/serve/tree/master/benchmarks) to launch server and benchmark requests with optimal configuration on Intel hardware.
+Launcher can be used with TorchServe official [benchmark](https://github.com/pytorch/serve/tree/master/benchmarks) to launch server and benchmark requests with optimal configuration on Intel hardware.
 
 In this section we provide examples of benchmarking with launcher with its default configuration.
 
````
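As a quick sanity check before archiving, the saved TorchScript module can be reloaded and run on a dummy input of the same shape used for tracing. A minimal sketch following the ResNet50 example above (importing `intel_extension_for_pytorch` is assumed to be needed so that the IPEX ops referenced by the traced graph are registered):

```python
import torch
import intel_extension_for_pytorch as ipex  # assumed: registers IPEX ops used by the traced graph

# Reload the serialized INT8 TorchScript module produced above.
model = torch.jit.load('rn50_int8_jit.pt')
model.eval()

# Same shape and memory format as the tracing input.
dummy_tensor = torch.randn(1, 3, 224, 224).contiguous(memory_format=torch.channels_last)
with torch.no_grad():
    output = model(dummy_tensor)
print(output.shape)  # expected: torch.Size([1, 1000]), ImageNet logits
```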