Skip to content

Commit 79aa174

Browse files
zou3519pytorchmergebot
authored andcommitted
[dynamo] ctx_manager.py: replace unimplemented with unimplemented_v2 (#148570)
Pull Request resolved: #148570 Approved by: https://github.com/williamwen42 ghstack dependencies: #148454
1 parent e7bc1d1 commit 79aa174

File tree

1 file changed

+48
-41
lines changed

1 file changed

+48
-41
lines changed

torch/_dynamo/variables/ctx_manager.py

Lines changed: 48 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,14 @@
2929
import torch._C
3030
from torch._guards import Guard
3131

32-
from .. import variables
32+
from .. import graph_break_hints, variables
3333
from ..bytecode_transformation import (
3434
create_call_function,
3535
create_instruction,
3636
create_setup_with,
3737
)
3838
from ..device_interface import get_interface_for_device
39-
from ..exc import unimplemented, Unsupported
39+
from ..exc import unimplemented_v2
4040
from ..guards import GuardBuilder, install_guard
4141
from ..source import AttrSource, GlobalStateSource
4242
from .base import VariableTracker
@@ -173,40 +173,27 @@ def fn_name(self):
173173

174174
def enter(self, tx):
175175
source = None if self.source is None else AttrSource(self.source, "__enter__")
176-
try:
177-
return variables.UserMethodVariable(
178-
self.cm_obj.__enter__.__func__,
179-
self,
180-
source=source,
181-
).call_function(tx, [], {})
182-
except Unsupported as e:
183-
unimplemented(
184-
f"Unsupported context manager {self.cm_obj}'s __enter__ function",
185-
from_exc=e,
186-
)
176+
return variables.UserMethodVariable(
177+
self.cm_obj.__enter__.__func__,
178+
self,
179+
source=source,
180+
).call_function(tx, [], {})
187181

188182
def exit(self, tx: "InstructionTranslator", *args):
189183
source = None if self.source is None else AttrSource(self.source, "__exit__")
190-
try:
191-
x = variables.UserMethodVariable(
192-
self.cm_obj.__exit__.__func__,
193-
self,
194-
source=source,
195-
).call_function(
196-
tx,
197-
[
198-
variables.ConstantVariable.create(None),
199-
variables.ConstantVariable.create(None),
200-
variables.ConstantVariable.create(None),
201-
],
202-
{},
203-
)
204-
except Unsupported as e:
205-
unimplemented(
206-
f"Unsupported context manager {self.cm_obj}'s __exit__ function",
207-
from_exc=e,
208-
)
209-
184+
x = variables.UserMethodVariable(
185+
self.cm_obj.__exit__.__func__,
186+
self,
187+
source=source,
188+
).call_function(
189+
tx,
190+
[
191+
variables.ConstantVariable.create(None),
192+
variables.ConstantVariable.create(None),
193+
variables.ConstantVariable.create(None),
194+
],
195+
{},
196+
)
210197
tx.active_generic_context_managers.pop()
211198
return x
212199

@@ -921,11 +908,13 @@ def fn_name(self):
921908
return "nullcontext"
922909

923910
def reconstruct(self, cg):
924-
unimplemented(
925-
"""
926-
Dynamo doesn't support compiling a region that leaks torch profiler context
927-
objects which will be used outside the region
928-
"""
911+
unimplemented_v2(
912+
gb_type="torch.profiler object escaped from compiled region",
913+
context=str(self),
914+
explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.",
915+
hints=[
916+
*graph_break_hints.SUPPORTABLE,
917+
],
929918
)
930919

931920

@@ -1043,8 +1032,16 @@ def exit(self, tx: "InstructionTranslator", *args):
10431032
).call_function(tx, [self.tensors, self.prev_versions], {})
10441033

10451034
def reconstruct(self, codegen):
1046-
unimplemented(
1047-
"torch.autograd._unsafe_preserve_version_counter with graph break"
1035+
unimplemented_v2(
1036+
gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region",
1037+
context=str(self),
1038+
explanation=(
1039+
"Dynamo doesn't support compiling a region that returns "
1040+
"a torch.autograd._unsafe_preserve_version_counter context manager."
1041+
),
1042+
hints=[
1043+
*graph_break_hints.SUPPORTABLE,
1044+
],
10481045
)
10491046

10501047

@@ -1292,7 +1289,17 @@ def call_method(
12921289
),
12931290
)
12941291
else:
1295-
unimplemented(f"event method {name} unsupported")
1292+
unimplemented_v2(
1293+
gb_type="Unsupported torch.cuda.Event method",
1294+
context=str(name),
1295+
explanation=(
1296+
f"Dynamo doesn't support tracing the torch.cuda.Event.{name} method. "
1297+
f"We currently support wait, record, synchronize, and query.",
1298+
),
1299+
hints=[
1300+
*graph_break_hints.SUPPORTABLE,
1301+
],
1302+
)
12961303

12971304
def as_proxy(self):
12981305
return self.proxy

0 commit comments

Comments
 (0)