11#include " torch/csrc/autograd/engine.h"
2+ #include " torch/csrc/autograd/functions/basic_ops.h"
23
34#include < atomic>
45#include < condition_variable>
@@ -226,9 +227,9 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
226227}
227228
228229/* * Finds all stochastic functions and appends them to the queue */
229- auto Engine::find_stochastic_functions (function_queue& queue, GraphTask& task) -> void {
230- std::unordered_set<Function*> seen;
231- function_queue search_queue (queue) ;
230+ auto Engine::find_stochastic_functions (function_queue& queue, Function* graph_root, GraphTask& task) -> void {
231+ std::unordered_set<Function*> seen {graph_root} ;
232+ function_queue search_queue {graph_root} ;
232233 while (search_queue.size () > 0 ) {
233234 auto fn = search_queue.back (); search_queue.pop_back ();
234235 for (auto & next_fn_pair : fn->next_functions ) {
@@ -258,8 +259,6 @@ auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void
258259 auto & dependencies = task.dependencies ;
259260 while (queue.size () > 0 ) {
260261 auto fn = std::move (queue.back ()); queue.pop_back ();
261- // This is needed only to filter out roots that aren't executable
262- if (!fn->is_executable ) continue ;
263262 for (auto & next_fn_pair : fn->next_functions ) {
264263 Function* next_ptr = next_fn_pair.first .get ();
265264 if (!next_ptr) continue ;
@@ -274,38 +273,6 @@ auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void
274273 }
275274}
276275
277- auto Engine::find_roots (const function_list& input_roots,
278- variable_list& inputs,
279- GraphTask& task) -> function_queue {
280- std::unordered_map<std::shared_ptr<Function>, std::unique_ptr<InputBuffer>> root_value;
281- int num_inputs = input_roots.size ();
282- for (int i = 0 ; i < num_inputs; ++i) {
283- auto & input = inputs[i];
284- auto & root_info = input_roots[i];
285- auto root = root_info.first ;
286- int input_nr = root_info.second ;
287- auto & buf = root_value[root];
288- if (root->is_executable ) {
289- if (!buf) buf.reset (new InputBuffer (root->num_inputs ));
290- buf->add (input_nr, std::shared_ptr<Variable>(input));
291- }
292- }
293-
294- function_queue roots;
295- for (auto & entry: root_value) {
296- const auto & root = entry.first ;
297- roots.push_back (root.get ());
298- // no need to enqueue tasks for non-executable functions
299- if (!root->is_executable ) continue ;
300- auto & input_buf = entry.second ;
301- auto & queue = ready_queue (input_buf->device ());
302- queue.push_front (FunctionTask (&task, root, std::move (*input_buf)));
303- task.has_any_work = true ;
304- }
305-
306- return roots;
307- }
308-
309276auto Engine::execute (const function_list& input_roots,
310277 variable_list& inputs,
311278 bool keep_graph,
@@ -315,11 +282,19 @@ auto Engine::execute(const function_list& input_roots,
315282 GraphTask graph_task (keep_graph, callbacks);
316283 std::unique_lock<std::mutex> lock (graph_task.mutex );
317284
318- // Find the unique roots and backprop into variables.
319- function_queue roots = find_roots (input_roots, inputs, graph_task);
285+ auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
286+ function_queue roots;
287+ for (auto entry : input_roots) {
288+ if (entry.first ->is_executable ) {
289+ graph_task.has_any_work = true ;
290+ roots.push_back (graph_root.get ());
291+ ready_queue (-1 ).push_front (FunctionTask (&graph_task, graph_root, InputBuffer (0 )));
292+ break ;
293+ }
294+ }
320295
321296 // Search the graph and find all stochastic functions. Append them to the queue.
322- find_stochastic_functions (roots, graph_task);
297+ find_stochastic_functions (roots, graph_root. get (), graph_task);
323298
324299 if (!graph_task.has_any_work ) {
325300 throw std::runtime_error (
0 commit comments