@@ -92,21 +92,21 @@ class definitions.
92
92
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
93
93
94
94
We can implement part one as a pure python function as below. Notice, to
95
- trace this function we add the ``@ torch.jit.trace`` decorator. Since the
96
- trace requires a dummy input of the expected runtime type and shape, we
97
- also include the ``torch.rand`` to generate a single valued torch
98
- tensor.
95
+ trace this function we call `` torch.jit.trace`` and pass in the function
96
+ to be traced. Since the trace requires a dummy input of the expected
97
+ runtime type and shape, we also include the ``torch.rand`` to generate a
98
+ single valued torch tensor.
99
99
100
100
"""
101
101
102
102
import torch
103
103
104
- # This is how you define a traced function
105
- # Pass in an example input to this decorator and then apply it to the function
106
- @torch .jit .trace (torch .rand (()))
107
- def traced_fn (x ):
104
+ def fn (x ):
108
105
return torch .abs (2 * x )
109
106
107
+ # This is how you define a traced function
108
+ # Pass in both the function to be traced and an example input to ``torch.jit.trace``
109
+ traced_fn = torch .jit .trace (fn , torch .rand (()))
110
110
111
111
######################################################################
112
112
# Part 2 - Scripting a pure python function
@@ -124,7 +124,7 @@ def traced_fn(x):
124
124
@torch .jit .script
125
125
def script_fn (x ):
126
126
z = torch .ones ([1 ], dtype = torch .int64 )
127
- for i in range (x ):
127
+ for i in range (int ( x ) ):
128
128
z = z * (i + 1 )
129
129
return z
130
130
@@ -163,7 +163,7 @@ class ScriptModule(torch.jit.ScriptModule):
163
163
@torch .jit .script_method
164
164
def forward (self , x ):
165
165
r = - x
166
- if torch .fmod (x , 2.0 ) == 0.0 :
166
+ if int ( torch .fmod (x , 2.0 ) ) == 0.0 :
167
167
r = x / 2.0
168
168
return r
169
169
@@ -201,7 +201,7 @@ def __init__(self):
201
201
# Modules must be attributes on the Module because if you want to trace
202
202
# or script this Module, we must be able to inherit the submodules'
203
203
# params.
204
- self .traced_module = torch .jit .trace (torch .rand (()))( TracedModule ( ))
204
+ self .traced_module = torch .jit .trace (TracedModule (), torch .rand (()))
205
205
self .script_module = ScriptModule ()
206
206
207
207
print ('traced_fn graph' , traced_fn .graph )
@@ -244,8 +244,6 @@ def forward(self, x):
244
244
# Tracing the Top-Level Model
245
245
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
246
246
#
247
- # **NOTE:** Open issue https://github.com/pytorch/pytorch/issues/8755
248
- #
249
247
# The last part of the example is to trace the top-level module, ``Net``.
250
248
# As mentioned previously, since the traced/scripted modules are
251
249
# attributes of Net, we are able to trace ``Net`` as it inherits the
@@ -254,11 +252,9 @@ def forward(self, x):
254
252
# Also, check out the graph that is created.
255
253
#
256
254
257
- # TODO: this fails with some weird bug https://github.com/pytorch/pytorch/issues/8755
258
- #n_traced = torch.jit.trace(torch.tensor([5]))(n)
259
- #print(n_traced(torch.tensor([5])))
260
-
261
- # TODO: print the graph of the traced module
255
+ n_traced = torch .jit .trace (n , torch .tensor ([5 ]))
256
+ print (n_traced (torch .tensor ([5 ])))
257
+ print ('n_traced graph' , n_traced .graph )
262
258
263
259
264
260
######################################################################
0 commit comments