Skip to content

Commit c47c4d9

Browse files
authored
Re-enable learning_hybrid_frontend_through_example_tutorial.py (pytorch#377)
1 parent c64413c commit c47c4d9

File tree

2 files changed

+14
-21
lines changed

2 files changed

+14
-21
lines changed

.jenkins/build.sh

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,9 +51,6 @@ pushd audio
5151
python setup.py install
5252
popd
5353

54-
# We will fix the hybrid frontend tutorials when the API is stable
55-
rm beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py || true
56-
5754
aws configure set default.s3.multipart_threshold 5120MB
5855

5956
# Decide whether to parallelize tutorial builds, based on $JOB_BASE_NAME

beginner_source/hybrid_frontend/learning_hybrid_frontend_through_example_tutorial.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -92,21 +92,21 @@ class definitions.
9292
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
9393
9494
We can implement part one as a pure python function as below. Notice, to
95-
trace this function we add the ``@torch.jit.trace`` decorator. Since the
96-
trace requires a dummy input of the expected runtime type and shape, we
97-
also include the ``torch.rand`` to generate a single valued torch
98-
tensor.
95+
trace this function we call ``torch.jit.trace`` and pass in the function
96+
to be traced. Since the trace requires a dummy input of the expected
97+
runtime type and shape, we also include the ``torch.rand`` to generate a
98+
single valued torch tensor.
9999
100100
"""
101101

102102
import torch
103103

104-
# This is how you define a traced function
105-
# Pass in an example input to this decorator and then apply it to the function
106-
@torch.jit.trace(torch.rand(()))
107-
def traced_fn(x):
104+
def fn(x):
108105
return torch.abs(2*x)
109106

107+
# This is how you define a traced function
108+
# Pass in both the function to be traced and an example input to ``torch.jit.trace``
109+
traced_fn = torch.jit.trace(fn, torch.rand(()))
110110

111111
######################################################################
112112
# Part 2 - Scripting a pure python function
@@ -124,7 +124,7 @@ def traced_fn(x):
124124
@torch.jit.script
125125
def script_fn(x):
126126
z = torch.ones([1], dtype=torch.int64)
127-
for i in range(x):
127+
for i in range(int(x)):
128128
z = z * (i + 1)
129129
return z
130130

@@ -163,7 +163,7 @@ class ScriptModule(torch.jit.ScriptModule):
163163
@torch.jit.script_method
164164
def forward(self, x):
165165
r = -x
166-
if torch.fmod(x, 2.0) == 0.0:
166+
if int(torch.fmod(x, 2.0)) == 0.0:
167167
r = x / 2.0
168168
return r
169169

@@ -201,7 +201,7 @@ def __init__(self):
201201
# Modules must be attributes on the Module because if you want to trace
202202
# or script this Module, we must be able to inherit the submodules'
203203
# params.
204-
self.traced_module = torch.jit.trace(torch.rand(()))(TracedModule())
204+
self.traced_module = torch.jit.trace(TracedModule(), torch.rand(()))
205205
self.script_module = ScriptModule()
206206

207207
print('traced_fn graph', traced_fn.graph)
@@ -244,8 +244,6 @@ def forward(self, x):
244244
# Tracing the Top-Level Model
245245
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
246246
#
247-
# **NOTE:** Open issue https://github.com/pytorch/pytorch/issues/8755
248-
#
249247
# The last part of the example is to trace the top-level module, ``Net``.
250248
# As mentioned previously, since the traced/scripted modules are
251249
# attributes of Net, we are able to trace ``Net`` as it inherits the
@@ -254,11 +252,9 @@ def forward(self, x):
254252
# Also, check out the graph that is created.
255253
#
256254

257-
# TODO: this fails with some weird bug https://github.com/pytorch/pytorch/issues/8755
258-
#n_traced = torch.jit.trace(torch.tensor([5]))(n)
259-
#print(n_traced(torch.tensor([5])))
260-
261-
# TODO: print the graph of the traced module
255+
n_traced = torch.jit.trace(n, torch.tensor([5]))
256+
print(n_traced(torch.tensor([5])))
257+
print('n_traced graph', n_traced.graph)
262258

263259

264260
######################################################################

0 commit comments

Comments
 (0)