@@ -710,6 +710,22 @@ void extra_files_to_python(const ExtraFilesMap& m, const py::dict& pydict) {
710710 }
711711}
712712
713+ void pyCompilationUnitDefine (
714+ CompilationUnit& cu,
715+ const std::string& src,
716+ const ResolutionCallback* rcb,
717+ const uint32_t _frames_up) {
718+ if (rcb && *rcb) {
719+ cu.define (c10::nullopt , src, pythonResolver (*rcb), nullptr );
720+ } else {
721+ py::object py_default_rcb =
722+ py::module::import (" torch._jit_internal" )
723+ .attr (" createResolutionCallbackFromFrame" )(_frames_up);
724+ auto default_rcb = py_default_rcb.cast <ResolutionCallback>();
725+ cu.define (c10::nullopt , src, pythonResolver (default_rcb), nullptr );
726+ }
727+ }
728+
713729void initJitScriptBindings (PyObject* module ) {
714730 auto m = py::handle (module ).cast <py::module >();
715731
@@ -1114,21 +1130,72 @@ void initJitScriptBindings(PyObject* module) {
11141130
11151131 py::class_<CompilationUnit, std::shared_ptr<CompilationUnit>>(
11161132 m, " CompilationUnit" )
1117- .def (py::init<>())
1133+ .def (
1134+ py::init ([](const std::string& lang, const uint32_t _frames_up) {
1135+ auto cu = std::make_shared<CompilationUnit>();
1136+ if (lang.size () > 0 ) {
1137+ pyCompilationUnitDefine (*cu, lang, nullptr , _frames_up);
1138+ }
1139+ return cu;
1140+ }),
1141+ py::arg (" lang" ) = " " ,
1142+ py::arg (" _frames_up" ) = 0 )
1143+
11181144 .def (
11191145 " find_function" ,
11201146 [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1121- auto & fn = self->get_function (QualifiedName (name));
1122- return StrongFunctionPtr (std::move (self), &fn);
1147+ auto fn = self->find_function (QualifiedName (name));
1148+ if (fn) {
1149+ return c10::optional<StrongFunctionPtr>(
1150+ StrongFunctionPtr (std::move (self), fn));
1151+ } else {
1152+ return c10::optional<StrongFunctionPtr>(c10::nullopt );
1153+ }
1154+ })
1155+ .def (
1156+ " __getattr__" ,
1157+ [](std::shared_ptr<CompilationUnit> self, const std::string& name) {
1158+ auto fn = self->find_function (QualifiedName (name));
1159+ if (fn) {
1160+ return StrongFunctionPtr (std::move (self), fn);
1161+ } else {
1162+ throw AttributeError (
1163+ " 'CompilationUnit' has no attribute '%s'" , name.c_str ());
1164+ }
1165+ })
1166+ .def (
1167+ " get_functions" ,
1168+ [](const std::shared_ptr<CompilationUnit>& self) {
1169+ auto raw_functions = self->get_functions ();
1170+ std::vector<StrongFunctionPtr> functions;
1171+ functions.reserve (raw_functions.size ());
1172+ for (auto fn : raw_functions) {
1173+ if (fn) {
1174+ functions.emplace_back (self, fn);
1175+ }
1176+ }
1177+ return functions;
11231178 })
11241179 .def (" set_optimized" , &CompilationUnit::set_optimized)
11251180 .def (
11261181 " define" ,
1127- [](CompilationUnit& cu,
1128- const std::string& src,
1129- const ResolutionCallback& rcb) {
1130- cu.define (c10::nullopt , src, pythonResolver (rcb), nullptr );
1131- })
1182+ pyCompilationUnitDefine,
1183+ py::arg (" src" ),
1184+ py::arg (" rcb" ) = nullptr ,
1185+ py::arg (" _frames_up" ) = 0 )
1186+ .def (
1187+ " create_function" ,
1188+ [](std::shared_ptr<CompilationUnit>& self,
1189+ const std::string& qualified_name,
1190+ std::shared_ptr<Graph> graph,
1191+ bool should_mangle) {
1192+ Function* fn = self->create_function (
1193+ qualified_name, std::move (graph), should_mangle);
1194+ return StrongFunctionPtr (std::move (self), fn);
1195+ },
1196+ py::arg (" qualified_name" ),
1197+ py::arg (" graph" ),
1198+ py::arg (" should_mangle" ) = false )
11321199 .def (
11331200 " get_interface" ,
11341201 [](const std::shared_ptr<CompilationUnit>& self,
0 commit comments