Skip to content

Commit cb75add

Browse files
zdevito authored and facebook-github-bot committed
torch.package - a way to package models and code (pytorch#45015)
Summary:
Pull Request resolved: pytorch#45015

torch.package allows you to write packages of code, pickled Python data, and arbitrary binary and text resources into a self-contained package. torch.package.PackageExporter writes the packages and torch.package.PackageImporter reads them.

The importers can load this code in a hermetic way, such that code is loaded from the package rather than the normal Python import system. This allows for the packaging of PyTorch model code and data so that it can be run on a server or used in the future for transfer learning.

The code contained in packages is copied file-by-file from the original source when it is created, and the file format is a specially organized zip file. Future users of the package can unzip the package and edit the code in order to perform custom modifications to it.

The importer for packages ensures that code in the module can only be loaded from within the package, except for modules explicitly listed as external using :meth:`extern_module`. The file `extern_modules` in the zip archive lists all the modules that a package externally depends on. This prevents "implicit" dependencies where the package runs locally because it is importing a locally-installed package, but then fails when the package is copied to another machine.

Test Plan: Imported from OSS
Reviewed By: SplitInfinity
Differential Revision: D23824337
Pulled By: zdevito
fbshipit-source-id: 1247c34ba9b656f9db68a83e31f2a0fbe3bea6bd
1 parent d4a634c commit cb75add

File tree

15 files changed

+1439
-3
lines changed

15 files changed

+1439
-3
lines changed

test/module_a.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# Marker value; test_package.py reads it back to confirm this module was the one loaded.
result = 'module_a'

test/namespace_b/subpackage.py

Whitespace-only changes.

test/package_a/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Marker value; the packaging tests read it back to confirm which module was loaded.
result = 'package_a'


class PackageAObject:
    """Slotted holder that stores a single value in its ``obj`` attribute."""

    # Restrict instances to the one attribute (no per-instance __dict__).
    __slots__ = ['obj']

    def __init__(self, obj):
        """Store ``obj`` on the instance."""
        self.obj = obj

test/package_a/subpackage.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# Marker value; the packaging tests read it back to confirm which module was loaded.
result = 'package_a.subpackage'


class PackageASubpackageObject:
    """Empty placeholder class with no attributes or behavior of its own."""

test/run_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,8 @@
8989
'test_determination',
9090
'test_futures',
9191
'test_fx',
92-
'test_functional_autograd_benchmark'
92+
'test_functional_autograd_benchmark',
93+
'test_package',
9394
]
9495

9596
WINDOWS_BLOCKLIST = [

test/test_package.py

Lines changed: 309 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,309 @@
1+
from unittest import main, skipIf
2+
from torch.testing._internal.common_utils import TestCase, IS_WINDOWS
3+
from tempfile import NamedTemporaryFile
4+
from torch.package import PackageExporter, PackageImporter
5+
from pathlib import Path
6+
from tempfile import TemporaryDirectory
7+
import torch
8+
from sys import version_info
9+
10+
# torchvision is optional: probe for it once at import time and expose a
# decorator that skips the tests which need resnet18 when it is missing.
try:
    from torchvision.models import resnet18
except ImportError:
    HAS_TORCHVISION = False
else:
    HAS_TORCHVISION = True

skipIfNoTorchVision = skipIf(not HAS_TORCHVISION, "no torchvision")
16+
17+
18+
19+
# Directory containing this test file; the fixture sources saved by
# test_saving_source (module_a.py, package_a/) are resolved relative to it.
packaging_directory = Path(__file__).parent
20+
21+
class PackagingTest(TestCase):
    """Round-trip tests for ``PackageExporter`` / ``PackageImporter``.

    Each test writes a package to a temporary file with ``PackageExporter``,
    reopens it with ``PackageImporter`` and checks what can be imported or
    unpickled from it.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Open NamedTemporaryFile handles that must stay alive until tearDown.
        self._temporary_files = []

    def temp(self):
        """Return the path of a fresh temporary file for a package archive.

        On Windows the handle is closed right away (an open file cannot be
        read there); elsewhere the handle is kept open and closed in
        ``tearDown``.
        """
        t = NamedTemporaryFile()
        name = t.name
        if IS_WINDOWS:
            t.close()  # can't read an open file in windows
        else:
            self._temporary_files.append(t)
        return name

    def tearDown(self):
        """Close every temporary file handed out by ``temp``."""
        for t in self._temporary_files:
            t.close()
        self._temporary_files = []
        super().tearDown()

    def test_saving_source(self):
        """Source saved from a file and from a directory can be imported back."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_source_file('foo', str(packaging_directory / 'module_a.py'))
            he.save_source_file('foodir', str(packaging_directory / 'package_a'))
        hi = PackageImporter(filename)
        foo = hi.import_module('foo')
        s = hi.import_module('foodir.subpackage')
        self.assertEqual(foo.result, 'module_a')
        self.assertEqual(s.result, 'package_a.subpackage')

    def test_saving_string(self):
        """Source saved from a string imports, and ``math`` is the real module."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            src = """\
import math
the_math = math
"""
            he.save_source_string('my_mod', src)
        hi = PackageImporter(filename)
        m = hi.import_module('math')
        import math
        self.assertIs(m, math)
        my_mod = hi.import_module('my_mod')
        self.assertIs(my_mod.math, math)

    def test_save_module(self):
        """Modules saved by name are re-imported as distinct module objects."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            import module_a
            import package_a
            he.save_module(module_a.__name__)
            he.save_module(package_a.__name__)
        hi = PackageImporter(filename)
        module_a_i = hi.import_module('module_a')
        self.assertEqual(module_a_i.result, 'module_a')
        self.assertIsNot(module_a, module_a_i)
        package_a_i = hi.import_module('package_a')
        self.assertEqual(package_a_i.result, 'package_a')
        self.assertIsNot(package_a_i, package_a)

    def test_pickle(self):
        """Pickling an object also packages the code its classes come from."""
        import package_a.subpackage
        obj = package_a.subpackage.PackageASubpackageObject()
        obj2 = package_a.PackageAObject(obj)

        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_pickle('obj', 'obj.pkl', obj2)
        hi = PackageImporter(filename)

        # check we got dependencies
        sp = hi.import_module('package_a.subpackage')
        # check we didn't get other stuff
        with self.assertRaises(ImportError):
            hi.import_module('module_a')

        obj_loaded = hi.load_pickle('obj', 'obj.pkl')
        self.assertIsNot(obj2, obj_loaded)
        self.assertIsInstance(obj_loaded.obj, sp.PackageASubpackageObject)
        self.assertIsNot(package_a.subpackage.PackageASubpackageObject, sp.PackageASubpackageObject)

    def test_resources(self):
        """Text and binary resources can be read back by code inside the package."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.save_text('main', 'main', "my string")
            he.save_binary('main', 'main_binary', "my string".encode('utf-8'))
            src = """\
import resources
t = resources.load_text('main', 'main')
b = resources.load_binary('main', 'main_binary')
"""
            he.save_source_string('main', src, is_package=True)
        hi = PackageImporter(filename)
        m = hi.import_module('main')
        self.assertEqual(m.t, "my string")
        self.assertEqual(m.b, "my string".encode('utf-8'))

    def test_extern(self):
        """Modules marked extern resolve to the regularly imported modules."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.extern_modules(['package_a.subpackage', 'module_a'])
            he.save_module('package_a')
        hi = PackageImporter(filename)
        import package_a.subpackage
        import module_a

        module_a_im = hi.import_module('module_a')
        hi.import_module('package_a.subpackage')
        package_a_im = hi.import_module('package_a')

        self.assertIs(module_a, module_a_im)
        self.assertIsNot(package_a, package_a_im)
        self.assertIs(package_a.subpackage, package_a_im.subpackage)

    # Tuple comparison is the reliable way to test "older than 3.7"; comparing
    # major and minor separately misfires for any major version above 3.
    @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature')
    def test_mock(self):
        """Attributes of mocked modules raise when they are actually used."""
        filename = self.temp()
        with PackageExporter(filename, verbose=False) as he:
            he.mock_modules(['package_a.subpackage', 'module_a'])
            he.save_module('package_a')
        hi = PackageImporter(filename)
        import package_a.subpackage
        _ = package_a.subpackage
        import module_a
        _ = module_a

        m = hi.import_module('package_a.subpackage')
        r = m.result
        with self.assertRaisesRegex(NotImplementedError, 'was mocked out'):
            r()

    @skipIf(version_info < (3, 7), 'mock uses __getattr__ a 3.7 feature')
    def test_custom_requires(self):
        """A ``PackageExporter`` subclass can decide how each dependency is handled."""
        filename = self.temp()

        class Custom(PackageExporter):
            def require_module(self, name, dependencies):
                if name == 'module_a':
                    self.mock_module('module_a')
                elif name == 'package_a':
                    self.save_source_string('package_a', 'import module_a\nresult = 5\n')
                else:
                    raise NotImplementedError('wat')

        with Custom(filename, verbose=False) as he:
            he.save_source_string('main', 'import package_a\n')

        hi = PackageImporter(filename)
        # Any attribute lookup should succeed on a mocked module; the value
        # itself is irrelevant, only that the access does not raise.
        _ = hi.import_module('module_a').should_be_mocked
        bar = hi.import_module('package_a')
        self.assertEqual(bar.result, 5)

    @skipIfNoTorchVision
    def test_resnet(self):
        """A pickled resnet18 survives save, re-save, and load-from-directory."""
        resnet = resnet18()

        f1 = self.temp()

        # create a package that will save it along with its code
        with PackageExporter(f1, verbose=False) as e:
            # put the pickled resnet in the package, by default
            # this will also save all the code files referenced by
            # the objects in the pickle
            e.save_pickle('model', 'model.pkl', resnet)

        # we can now load the saved model
        i = PackageImporter(f1)
        r2 = i.load_pickle('model', 'model.pkl')

        # test that it works
        example_input = torch.rand(1, 3, 224, 224)
        ref = resnet(example_input)
        self.assertTrue(torch.allclose(r2(example_input), ref))

        # functions exist also to get at the private modules in each package
        i.import_module('torchvision')

        f2 = self.temp()
        # if we are doing transfer learning we might want to re-save
        # things that were loaded from a package
        with PackageExporter(f2, verbose=False) as e:
            # We need to tell the exporter about any modules that
            # came from imported packages so that it can resolve
            # class names like torchvision.models.resnet.ResNet
            # to their source code.

            e.importers.insert(0, i.import_module)

            # e.importers is a list of module importing functions
            # that by default contains importlib.import_module.
            # it is searched in order until the first success and
            # that module is taken to be what torchvision.models.resnet
            # should be in this code package. In the case of name collisions,
            # such as trying to save a ResNet from two different packages,
            # we take the first thing found in the path, so only ResNet objects from
            # one importer will work. This avoids a bunch of name mangling in
            # the source code. If you need to actually mix ResNet objects,
            # we suggest reconstructing the model objects using code from a single package
            # using functions like save_state_dict and load_state_dict to transfer state
            # to the correct code objects.
            e.save_pickle('model', 'model.pkl', r2)

        i2 = PackageImporter(f2)
        r3 = i2.load_pickle('model', 'model.pkl')
        self.assertTrue(torch.allclose(r3(example_input), ref))

        # test we can load from a directory
        import zipfile

        with TemporaryDirectory() as td:
            # Close the archive as soon as it is extracted so the handle
            # is not leaked for the rest of the test run.
            with zipfile.ZipFile(f1, 'r') as zf:
                zf.extractall(path=td)
            iz = PackageImporter(str(Path(td) / Path(f1).name))
            r4 = iz.load_pickle('model', 'model.pkl')
            self.assertTrue(torch.allclose(r4(example_input), ref))

    @skipIfNoTorchVision
    def test_model_save(self):
        """Two packaging strategies expose the same ``model.load()`` entry point."""

        # This example shows how you might package a model
        # so that the creator of the model has flexibility about
        # how they want to save it but the 'server' can always
        # use the same API to load the package.

        # The convention is for each model to provide a
        # 'model' package with a 'load' function that actually
        # reads the model out of the archive.

        # How the load function is implemented is up to
        # the packager.

        # get our normal torchvision resnet
        resnet = resnet18()

        f1 = self.temp()
        # Option 1: save by pickling the whole model
        # + single-line, similar to torch.jit.save
        # - more difficult to edit the code after the model is created
        with PackageExporter(f1, verbose=False) as e:
            e.save_pickle('model', 'pickled', resnet)
            # note that this source is the same for all models in this approach
            # so it can be made part of an API that just takes the model and
            # packages it with this source.
            src = """\
import resources # gives you access to the importer from within the package

# server knows to call model.load() to get the model,
# maybe in the future it passes options as arguments by convension
def load():
    return resources.load_pickle('model', 'pickled')
"""
            e.save_source_string('model', src, is_package=True)

        f2 = self.temp()
        # Option 2: save with state dict
        # - more code to write to save/load the model
        # + but this code can be edited later to adapt the model
        with PackageExporter(f2, verbose=False) as e:
            e.save_pickle('model', 'state_dict', resnet.state_dict())
            src = """\
import resources # gives you access to the importer from within the package
from torchvision.models.resnet import resnet18
def load():
    # if you want, you can later edit how resnet is constructed here
    # to edit the model in the package, while still loading the original
    # state dict weights
    r = resnet18()
    state_dict = resources.load_pickle('model', 'state_dict')
    r.load_state_dict(state_dict)
    return r
"""
            e.save_source_string('model', src, is_package=True)

        # regardless of how we chose to package, we can now use the model in a server in the same way
        example_input = torch.rand(1, 3, 224, 224)
        results = []
        for m in [f1, f2]:
            importer = PackageImporter(m)
            the_model = importer.import_module('model').load()
            r = the_model(example_input)
            results.append(r)

        self.assertTrue(torch.allclose(*results))
307+
308+
# Run the tests through unittest's `main` when this file is executed directly.
if __name__ == '__main__':
    main()

torch/package/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
from .importer import PackageImporter
2+
from .exporter import PackageExporter

0 commit comments

Comments
 (0)