Skip to content

Commit fa93653

Browse files
apaszkesoumith
authored andcommitted
Improve handling of graph roots in autograd engine (pytorch#1635)
1 parent ba56de1 commit fa93653

File tree

4 files changed

+43
-50
lines changed

4 files changed

+43
-50
lines changed

test/test_autograd.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -362,7 +362,7 @@ def hook(grad_input, grad_output):
362362
sum(fn(x, y)).sum().backward()
363363
self.assertTrue(was_called[0])
364364

365-
def _test_backward(self):
365+
def test_backward(self):
366366
v_t = torch.randn(5, 5)
367367
x_t = torch.randn(5, 5)
368368
y_t = torch.rand(5, 5) + 0.1
@@ -385,9 +385,6 @@ def _test_backward(self):
385385
self.assertEqual(y.grad.data, y_grad * grad_output)
386386
self.assertEqual(z.grad.data, z_grad * grad_output)
387387

388-
def test_backward(self):
389-
self._test_backward()
390-
391388
def test_sparse_backward(self):
392389
class FixedGradientFunction(Function):
393390

@@ -431,11 +428,6 @@ def backward(self, grad_x):
431428
(sparse_fn1(x) + sparse_fn2(x)).sum().backward()
432429
self.assertEqual(x.grad.data, sparse_grad1 + sparse_grad2)
433430

434-
@unittest.skip("BasicEngine is out of date")
435-
def test_backward_basic_engine(self):
436-
with backward_engine(torch.autograd.engine.BasicEngine):
437-
self._test_backward()
438-
439431
def test_multi_backward(self):
440432
x = Variable(torch.randn(5, 5), requires_grad=True)
441433
y = Variable(torch.randn(5, 5), requires_grad=True)
@@ -478,6 +470,18 @@ def test_multi_backward_no_grad(self):
478470
torch.autograd.backward([z, q], [torch.ones(5, 5), torch.ones(5, 5)])
479471
self.assertEqual(x.grad.data, torch.ones(5, 5))
480472

473+
def test_dependent_backward(self):
474+
x = Variable(torch.randn(10), requires_grad=True)
475+
y = x ** 2
476+
z = y ** 3
477+
478+
go_y = torch.randn(10)
479+
go_z = torch.randn(10)
480+
torch.autograd.backward([y, z], [go_y, go_z])
481+
482+
xd = x.data
483+
self.assertEqual(x.grad.data, 2 * xd * go_y + 6 * xd.pow(5) * go_z)
484+
481485
def test_volatile(self):
482486
x = Variable(torch.ones(5, 5), requires_grad=True)
483487
y = Variable(torch.ones(5, 5) * 4, volatile=True)

torch/csrc/autograd/engine.cpp

Lines changed: 15 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "torch/csrc/autograd/engine.h"
2+
#include "torch/csrc/autograd/functions/basic_ops.h"
23

34
#include <atomic>
45
#include <condition_variable>
@@ -226,9 +227,9 @@ auto Engine::evaluate_function(FunctionTask& task) -> void {
226227
}
227228

228229
/** Finds all stochastic functions and appends them to the queue */
229-
auto Engine::find_stochastic_functions(function_queue& queue, GraphTask& task) -> void {
230-
std::unordered_set<Function*> seen;
231-
function_queue search_queue(queue);
230+
auto Engine::find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task) -> void {
231+
std::unordered_set<Function*> seen {graph_root};
232+
function_queue search_queue {graph_root};
232233
while (search_queue.size() > 0) {
233234
auto fn = search_queue.back(); search_queue.pop_back();
234235
for (auto& next_fn_pair : fn->next_functions) {
@@ -258,8 +259,6 @@ auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void
258259
auto& dependencies = task.dependencies;
259260
while (queue.size() > 0) {
260261
auto fn = std::move(queue.back()); queue.pop_back();
261-
// This is needed only to filter out roots that aren't executable
262-
if (!fn->is_executable) continue;
263262
for (auto& next_fn_pair : fn->next_functions) {
264263
Function* next_ptr = next_fn_pair.first.get();
265264
if (!next_ptr) continue;
@@ -274,38 +273,6 @@ auto Engine::compute_dependencies(function_queue queue, GraphTask& task) -> void
274273
}
275274
}
276275

277-
auto Engine::find_roots(const function_list& input_roots,
278-
variable_list& inputs,
279-
GraphTask& task) -> function_queue {
280-
std::unordered_map<std::shared_ptr<Function>, std::unique_ptr<InputBuffer>> root_value;
281-
int num_inputs = input_roots.size();
282-
for (int i = 0; i < num_inputs; ++i) {
283-
auto& input = inputs[i];
284-
auto& root_info = input_roots[i];
285-
auto root = root_info.first;
286-
int input_nr = root_info.second;
287-
auto& buf = root_value[root];
288-
if (root->is_executable) {
289-
if (!buf) buf.reset(new InputBuffer(root->num_inputs));
290-
buf->add(input_nr, std::shared_ptr<Variable>(input));
291-
}
292-
}
293-
294-
function_queue roots;
295-
for (auto& entry: root_value) {
296-
const auto& root = entry.first;
297-
roots.push_back(root.get());
298-
// no need to enqueue tasks for non-executable functions
299-
if (!root->is_executable) continue;
300-
auto& input_buf = entry.second;
301-
auto& queue = ready_queue(input_buf->device());
302-
queue.push_front(FunctionTask(&task, root, std::move(*input_buf)));
303-
task.has_any_work = true;
304-
}
305-
306-
return roots;
307-
}
308-
309276
auto Engine::execute(const function_list& input_roots,
310277
variable_list& inputs,
311278
bool keep_graph,
@@ -315,11 +282,19 @@ auto Engine::execute(const function_list& input_roots,
315282
GraphTask graph_task(keep_graph, callbacks);
316283
std::unique_lock<std::mutex> lock(graph_task.mutex);
317284

318-
// Find the unique roots and backprop into variables.
319-
function_queue roots = find_roots(input_roots, inputs, graph_task);
285+
auto graph_root = std::make_shared<GraphRoot>(input_roots, inputs);
286+
function_queue roots;
287+
for (auto entry : input_roots) {
288+
if (entry.first->is_executable) {
289+
graph_task.has_any_work = true;
290+
roots.push_back(graph_root.get());
291+
ready_queue(-1).push_front(FunctionTask(&graph_task, graph_root, InputBuffer(0)));
292+
break;
293+
}
294+
}
320295

321296
// Search the graph and find all stochastic functions. Append them to the queue.
322-
find_stochastic_functions(roots, graph_task);
297+
find_stochastic_functions(roots, graph_root.get(), graph_task);
323298

324299
if (!graph_task.has_any_work) {
325300
throw std::runtime_error(

torch/csrc/autograd/engine.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ struct Engine {
4444
const function_list& roots,
4545
variable_list& inputs,
4646
GraphTask& task);
47-
void find_stochastic_functions(function_queue& queue, GraphTask& task);
47+
void find_stochastic_functions(function_queue& queue, Function* graph_root, GraphTask& task);
4848
void compute_dependencies(function_queue queue, GraphTask& task);
4949
void evaluate_function(FunctionTask& task);
5050
ReadyQueue& ready_queue(int device);

torch/csrc/autograd/functions/basic_ops.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,20 @@ struct DelayedError : public Function {
3232
std::string msg;
3333
};
3434

35+
struct GraphRoot : public Function {
36+
GraphRoot(function_list functions, variable_list inputs)
37+
: outputs(std::move(inputs)) {
38+
next_functions = std::move(functions);
39+
is_executable = true;
40+
};
41+
42+
virtual variable_list apply(const variable_list& inputs) {
43+
return outputs;
44+
}
45+
46+
variable_list outputs;
47+
};
48+
3549
struct Add : public Function {
3650
Add() {}
3751

0 commit comments

Comments
 (0)