
Commit ac0a3cc

t-vi authored and facebook-github-bot committed
Merge CompilationUnit from torch._C and torch.jit (pytorch#50614)
Summary: This simplifies our handling and allows passing CompilationUnits from Python to C++-defined functions via PyBind easily. Discussed on Slack with SplitInfinity.

Pull Request resolved: pytorch#50614
Reviewed By: anjali411
Differential Revision: D25938005
Pulled By: SplitInfinity
fbshipit-source-id: 94aadf0c063ddfef7ca9ea17bfa998d8e7b367ad
1 parent 5e79b8e commit ac0a3cc
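
As a rough sketch of what the merged class looks like from Python after this change (the TorchScript source and the names double/out/fn below are illustrative, not part of the commit):

import torch

# torch.jit.CompilationUnit is now an alias for the pybind11-bound
# torch._C.CompilationUnit, so the same object can be passed directly to
# C++ functions exposed via pybind11 that accept a CompilationUnit.
cu = torch.jit.CompilationUnit('''
def double(x):
    return 2 * x
''')

out = cu.double(torch.ones(3))        # attribute lookup goes through the new __getattr__ binding
fn = cu.find_function("double")       # returns a ScriptFunction, or None if absent
print(out, fn is not None)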

File tree

5 files changed, +121 -28 lines changed


test/jit/test_python_bindings.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
+import torch
+from torch.testing._internal.jit_utils import JitTestCase
+
+if __name__ == "__main__":
+    raise RuntimeError(
+        "This test file is not meant to be run directly, use:\n\n"
+        "\tpython test/test_jit.py TestPythonBindings\n\n"
+        "instead."
+    )
+
+
+class TestPythonBindings(JitTestCase):
+    def test_cu_get_functions(self):
+        @torch.jit.script
+        def test_get_python_cu_fn(x: torch.Tensor):
+            return 2 * x
+
+        cu = torch.jit._state._python_cu
+        self.assertTrue(
+            "test_get_python_cu_fn" in (str(fn.name) for fn in cu.get_functions())
+        )
+
+    def test_cu_create_function(self):
+        @torch.jit.script
+        def fn(x: torch.Tensor):
+            return 2 * x
+
+        cu = torch._C.CompilationUnit()
+        cu.create_function("test_fn", fn.graph)
+
+        inp = torch.randn(5)
+
+        self.assertEqual(inp * 2, cu.find_function("test_fn")(inp))
+        self.assertEqual(cu.find_function("doesnt_exist"), None)
+        self.assertEqual(inp * 2, cu.test_fn(inp))
+        with self.assertRaises(AttributeError):
+            cu.doesnt_exist(inp)

test/test_jit.py

Lines changed: 1 addition & 0 deletions
@@ -23,6 +23,7 @@
 from jit.test_peephole import TestPeephole  # noqa: F401
 from jit.test_save_load import TestSaveLoad  # noqa: F401
 from jit.test_module_containers import TestModuleContainers  # noqa: F401
+from jit.test_python_bindings import TestPythonBindings  # noqa: F401
 from jit.test_python_ir import TestPythonIr  # noqa: F401
 from jit.test_functional_blocks import TestFunctionalBlocks  # noqa: F401
 from jit.test_remove_mutation import TestRemoveMutation  # noqa: F401

torch/_C/__init__.pyi.in

Lines changed: 6 additions & 2 deletions
@@ -415,10 +415,13 @@ class ErrorReport:
     def call_stack() -> str: ...
 
 class CompilationUnit:
-    def __init__(self) -> None: ...
+    def __init__(self, lang: str=..., _frames_up: _int=...) -> None: ...
     def find_function(self, name: str) -> ScriptFunction: ...
-    def define(self, script: str, rcb: ResolutionCallback): ...
+    def __getattr__(self, name: str) -> ScriptFunction: ...
+    def define(self, script: str, rcb: ResolutionCallback=..., _frames_up: _int=...): ...
     def get_interface(self, name: str) -> InterfaceType: ...
+    def get_functions(self) -> List[ScriptFunction]: ...
+    def create_function(self, name: str, graph: Graph, shouldMangle: _bool=...) -> ScriptFunction: ...
 
 class ScriptModule:
     def setattr(self, name: str, value: Any): ...

@@ -429,6 +432,7 @@ class ScriptFunction:
     def __call__(self, *args, **kwargs) -> Tensor: ...
     def save(self, filename: str, _extra_files: Dict[str, bytes]) -> None: ...
     def save_to_buffer(self, _extra_files: Dict[str, bytes]) -> bytes: ...
+    @property
     def graph(self) -> Graph: ...
     def inlined_graph(self) -> Graph: ...
     def schema(self) -> FunctionSchema: ...
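
The added stub entries mirror the bindings registered below; a minimal sketch of the newly typed methods (the function triple and the names triple_fn/names/missing are illustrative):

import torch

@torch.jit.script
def triple(x: torch.Tensor):
    return 3 * x

cu = torch._C.CompilationUnit()

# create_function wraps an existing Graph in a ScriptFunction owned by this unit
fn = cu.create_function("triple_fn", triple.graph)
out = fn(torch.ones(2))

# get_functions lists every ScriptFunction the unit owns
names = [str(f.name) for f in cu.get_functions()]

# find_function now returns None for unknown names instead of raising
missing = cu.find_function("not_defined")
print(out, names, missing)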

torch/csrc/jit/python/script_init.cpp

Lines changed: 75 additions & 8 deletions
@@ -710,6 +710,22 @@ void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
 }
 }
 
+void pyCompilationUnitDefine(
+    CompilationUnit& cu,
+    const std::string& src,
+    const ResolutionCallback* rcb,
+    const uint32_t _frames_up) {
+  if (rcb && *rcb) {
+    cu.define(c10::nullopt, src, pythonResolver(*rcb), nullptr);
+  } else {
+    py::object py_default_rcb =
+        py::module::import("torch._jit_internal")
+            .attr("createResolutionCallbackFromFrame")(_frames_up);
+    auto default_rcb = py_default_rcb.cast<ResolutionCallback>();
+    cu.define(c10::nullopt, src, pythonResolver(default_rcb), nullptr);
+  }
+}
+
 void initJitScriptBindings(PyObject* module) {
   auto m = py::handle(module).cast<py::module>();
 

@@ -1114,21 +1130,72 @@
 
   py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
       m, "CompilationUnit")
-      .def(py::init<>())
+      .def(
+          py::init([](const std::string& lang, const uint32_t _frames_up) {
+            auto cu = std::make_shared<CompilationUnit>();
+            if (lang.size() > 0) {
+              pyCompilationUnitDefine(*cu, lang, nullptr, _frames_up);
+            }
+            return cu;
+          }),
+          py::arg("lang") = "",
+          py::arg("_frames_up") = 0)
+
       .def(
           "find_function",
           [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
-            auto& fn = self->get_function(QualifiedName(name));
-            return StrongFunctionPtr(std::move(self), &fn);
+            auto fn = self->find_function(QualifiedName(name));
+            if (fn) {
+              return c10::optional<StrongFunctionPtr>(
+                  StrongFunctionPtr(std::move(self), fn));
+            } else {
+              return c10::optional<StrongFunctionPtr>(c10::nullopt);
+            }
+          })
+      .def(
+          "__getattr__",
+          [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
+            auto fn = self->find_function(QualifiedName(name));
+            if (fn) {
+              return StrongFunctionPtr(std::move(self), fn);
+            } else {
+              throw AttributeError(
+                  "'CompilationUnit' has no attribute '%s'", name.c_str());
+            }
+          })
+      .def(
+          "get_functions",
+          [](const std::shared_ptr<CompilationUnit>& self) {
+            auto raw_functions = self->get_functions();
+            std::vector<StrongFunctionPtr> functions;
+            functions.reserve(raw_functions.size());
+            for (auto fn : raw_functions) {
+              if (fn) {
+                functions.emplace_back(self, fn);
+              }
+            }
+            return functions;
           })
       .def("set_optimized", &CompilationUnit::set_optimized)
       .def(
           "define",
-          [](CompilationUnit& cu,
-             const std::string& src,
-             const ResolutionCallback& rcb) {
-            cu.define(c10::nullopt, src, pythonResolver(rcb), nullptr);
-          })
+          pyCompilationUnitDefine,
+          py::arg("src"),
+          py::arg("rcb") = nullptr,
+          py::arg("_frames_up") = 0)
+      .def(
+          "create_function",
+          [](std::shared_ptr<CompilationUnit>& self,
+             const std::string& qualified_name,
+             std::shared_ptr<Graph> graph,
+             bool should_mangle) {
+            Function* fn = self->create_function(
+                qualified_name, std::move(graph), should_mangle);
+            return StrongFunctionPtr(std::move(self), fn);
+          },
+          py::arg("qualified_name"),
+          py::arg("graph"),
+          py::arg("should_mangle") = false)
       .def(
           "get_interface",
          [](const std::shared_ptr<CompilationUnit>& self,
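
Seen from Python, the rebound define and the new lookup methods behave roughly as in this sketch (the source string and the names relu_twice/nope are illustrative); when rcb is omitted, the binding falls back to torch._jit_internal.createResolutionCallbackFromFrame as shown in pyCompilationUnitDefine above:

import torch

cu = torch._C.CompilationUnit()

# No rcb passed: the default resolution callback is built from the caller's
# frame, so globals such as `torch` resolve inside the compiled source.
cu.define('''
def relu_twice(x):
    return torch.relu(torch.relu(x))
''')
print(cu.relu_twice(torch.randn(4)))

# Unknown names: find_function returns None, __getattr__ raises AttributeError.
assert cu.find_function("nope") is None
try:
    cu.nope(torch.randn(1))
except AttributeError:
    pass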

torch/jit/_script.py

Lines changed: 2 additions & 18 deletions
@@ -1095,24 +1095,8 @@ def _recursive_compile_class(obj, loc):
     rcb = _jit_internal.createResolutionCallbackForClassMethods(obj)
     _compile_and_register_class(obj, rcb, _qual_name)
 
-
-class CompilationUnit(object):
-    def __init__(self, lang=None, _frames_up=0):
-        self._c = torch._C.CompilationUnit()
-        if lang is not None:
-            self.define(lang, _frames_up=_frames_up + 1)
-
-    def define(self, lang, rcb=None, _frames_up=0):
-        if not rcb:
-            rcb = _jit_internal.createResolutionCallbackFromFrame(_frames_up + 1)
-        self._c.define(lang, rcb)
-
-    def __getattr__(self, attr):
-        r = self._c.find_function(attr)
-        if r is None:
-            raise AttributeError("'CompilationUnit' has no attribute '{}'".format(attr))
-        return r
-
+CompilationUnit = torch._C.CompilationUnit
+set_module(CompilationUnit, "torch.jit")
 
 def _unwrap_optional(x):
     assert x is not None, "Unwrapping null optional"
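
With the Python wrapper class removed, torch.jit.CompilationUnit now refers to the pybind11 class itself, with its reported module pointed at torch.jit via set_module; a quick check of what that implies (a sketch, not part of the commit):

import torch

# The alias is the C++-backed class itself, not a Python wrapper around it.
assert torch.jit.CompilationUnit is torch._C.CompilationUnit

# set_module makes the class report torch.jit as its home module.
print(torch.jit.CompilationUnit.__module__)  # expected: "torch.jit"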
