|
29 | 29 | import torch._C
|
30 | 30 | from torch._guards import Guard
|
31 | 31 |
|
32 |
| -from .. import variables |
| 32 | +from .. import graph_break_hints, variables |
33 | 33 | from ..bytecode_transformation import (
|
34 | 34 | create_call_function,
|
35 | 35 | create_instruction,
|
36 | 36 | create_setup_with,
|
37 | 37 | )
|
38 | 38 | from ..device_interface import get_interface_for_device
|
39 |
| -from ..exc import unimplemented, Unsupported |
| 39 | +from ..exc import unimplemented_v2 |
40 | 40 | from ..guards import GuardBuilder, install_guard
|
41 | 41 | from ..source import AttrSource, GlobalStateSource
|
42 | 42 | from .base import VariableTracker
|
@@ -173,40 +173,27 @@ def fn_name(self):
|
173 | 173 |
|
174 | 174 | def enter(self, tx):
|
175 | 175 | source = None if self.source is None else AttrSource(self.source, "__enter__")
|
176 |
| - try: |
177 |
| - return variables.UserMethodVariable( |
178 |
| - self.cm_obj.__enter__.__func__, |
179 |
| - self, |
180 |
| - source=source, |
181 |
| - ).call_function(tx, [], {}) |
182 |
| - except Unsupported as e: |
183 |
| - unimplemented( |
184 |
| - f"Unsupported context manager {self.cm_obj}'s __enter__ function", |
185 |
| - from_exc=e, |
186 |
| - ) |
| 176 | + return variables.UserMethodVariable( |
| 177 | + self.cm_obj.__enter__.__func__, |
| 178 | + self, |
| 179 | + source=source, |
| 180 | + ).call_function(tx, [], {}) |
187 | 181 |
|
188 | 182 | def exit(self, tx: "InstructionTranslator", *args):
|
189 | 183 | source = None if self.source is None else AttrSource(self.source, "__exit__")
|
190 |
| - try: |
191 |
| - x = variables.UserMethodVariable( |
192 |
| - self.cm_obj.__exit__.__func__, |
193 |
| - self, |
194 |
| - source=source, |
195 |
| - ).call_function( |
196 |
| - tx, |
197 |
| - [ |
198 |
| - variables.ConstantVariable.create(None), |
199 |
| - variables.ConstantVariable.create(None), |
200 |
| - variables.ConstantVariable.create(None), |
201 |
| - ], |
202 |
| - {}, |
203 |
| - ) |
204 |
| - except Unsupported as e: |
205 |
| - unimplemented( |
206 |
| - f"Unsupported context manager {self.cm_obj}'s __exit__ function", |
207 |
| - from_exc=e, |
208 |
| - ) |
209 |
| - |
| 184 | + x = variables.UserMethodVariable( |
| 185 | + self.cm_obj.__exit__.__func__, |
| 186 | + self, |
| 187 | + source=source, |
| 188 | + ).call_function( |
| 189 | + tx, |
| 190 | + [ |
| 191 | + variables.ConstantVariable.create(None), |
| 192 | + variables.ConstantVariable.create(None), |
| 193 | + variables.ConstantVariable.create(None), |
| 194 | + ], |
| 195 | + {}, |
| 196 | + ) |
210 | 197 | tx.active_generic_context_managers.pop()
|
211 | 198 | return x
|
212 | 199 |
|
@@ -921,11 +908,13 @@ def fn_name(self):
|
921 | 908 | return "nullcontext"
|
922 | 909 |
|
923 | 910 | def reconstruct(self, cg):
|
924 |
| - unimplemented( |
925 |
| - """ |
926 |
| -Dynamo doesn't support compiling a region that leaks torch profiler context |
927 |
| -objects which will be used outside the region |
928 |
| -""" |
| 911 | + unimplemented_v2( |
| 912 | + gb_type="torch.profiler object escaped from compiled region", |
| 913 | + context=str(self), |
| 914 | + explanation="Dynamo doesn't support compiling a region that returns a torch.profiler context manager.", |
| 915 | + hints=[ |
| 916 | + *graph_break_hints.SUPPORTABLE, |
| 917 | + ], |
929 | 918 | )
|
930 | 919 |
|
931 | 920 |
|
@@ -1043,8 +1032,16 @@ def exit(self, tx: "InstructionTranslator", *args):
|
1043 | 1032 | ).call_function(tx, [self.tensors, self.prev_versions], {})
|
1044 | 1033 |
|
1045 | 1034 | def reconstruct(self, codegen):
|
1046 |
| - unimplemented( |
1047 |
| - "torch.autograd._unsafe_preserve_version_counter with graph break" |
| 1035 | + unimplemented_v2( |
| 1036 | + gb_type="torch.autograd._unsafe_preserve_version_counter escaped from compiled region", |
| 1037 | + context=str(self), |
| 1038 | + explanation=( |
| 1039 | + "Dynamo doesn't support compiling a region that returns " |
| 1040 | + "a torch.autograd._unsafe_preserve_version_counter context manager." |
| 1041 | + ), |
| 1042 | + hints=[ |
| 1043 | + *graph_break_hints.SUPPORTABLE, |
| 1044 | + ], |
1048 | 1045 | )
|
1049 | 1046 |
|
1050 | 1047 |
|
@@ -1292,7 +1289,17 @@ def call_method(
|
1292 | 1289 | ),
|
1293 | 1290 | )
|
1294 | 1291 | else:
|
1295 |
| - unimplemented(f"event method {name} unsupported") |
| 1292 | + unimplemented_v2( |
| 1293 | + gb_type="Unsupported torch.cuda.Event method", |
| 1294 | + context=str(name), |
| 1295 | + explanation=( |
| 1296 | + f"Dynamo doesn't support tracing the torch.cuda.Event.{name} method. " |
| 1297 | + f"We currently support wait, record, synchronize, and query.", |
| 1298 | + ), |
| 1299 | + hints=[ |
| 1300 | + *graph_break_hints.SUPPORTABLE, |
| 1301 | + ], |
| 1302 | + ) |
1296 | 1303 |
|
1297 | 1304 | def as_proxy(self):
|
1298 | 1305 | return self.proxy
|
|
0 commit comments