Skip to content

Commit aebaf31

Browse files
This CL removes the Graph.edge_set_ field. This field stores a set of the Edge* that are in a Graph. However, Graph already stores this information, in Graph.edges_. There's really no good reason to keep both of these collections. To convert everything to use Graph.edges_ instead of Graph.edge_set_, I defined a class which handled excluding nullptr from iteration of the edges_ vector.
This caused changes to non-contractual behavior of the runtime (enumeration order), so the unit tests are updated to reflect this. On a real-world graph used by our team, which contains 13190 nodes and 20796 edges, this change reduced heap allocation from 39.1 MB to 38.0 MB, for a drop of about 3%. Change: 154781831
1 parent 0135602 commit aebaf31

File tree

4 files changed

+75
-15
lines changed

4 files changed

+75
-15
lines changed

tensorflow/core/common_runtime/function.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ Status FunctionLibraryRuntimeImpl::Instantiate(
456456
void DumpGraph(StringPiece label, const Graph* g) {
457457
// TODO(zhifengc): Change Graph to record #nodes.
458458
VLOG(1) << "Graph " << label << " #nodes " << g->num_nodes() << " #edges "
459-
<< g->edges().size();
459+
<< g->num_edges();
460460
if (VLOG_IS_ON(2)) {
461461
for (const auto& line : str_util::Split(DebugString(g), '\n')) {
462462
VLOG(2) << "|| " << line;

tensorflow/core/common_runtime/function_test.cc

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ TEST_F(FunctionLibraryRuntimeTest, ControlDeps) {
424424
n8 = NoOp() @ n4
425425
n9 = Identity[T=float](n3) @ n8
426426
n10 = Identity[T=float](n2) @ n8
427-
n11 = NoOp() @ n10, n9
427+
n11 = NoOp() @ n9, n10
428428
n5 = Mul[T=float](n2, n2) @ n11
429429
n6 = Add[T=float](n4, n5)
430430
}
@@ -500,8 +500,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_XTimesTwo) {
500500
OptimizeGraph(lib_, &g);
501501
const char* e2 = R"P(
502502
(n2:float, n3:float) -> (n9:float) {
503-
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
504503
n10 = Const[dtype=float, value=Tensor<type: float shape: [] values: 2>]()
504+
n11 = Const[dtype=int32, value=Tensor<type: int32 shape: [0] values: >]()
505505
n6 = Shape[T=float, out_type=int32](n2)
506506
n5 = Mul[T=float](n3, n10)
507507
n7 = BroadcastGradientArgs[T=int32](n6, n11)
@@ -614,10 +614,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
614614
n17 = Sum[T=float, Tidx=int32, keep_dims=false](n14, n16)
615615
n19 = SymbolicGradient[Tin={float, int32, float}, Tout={float, int32}, f=Sum[T=float, Tidx=int32, keep_dims=false]](n14, n16, n26)
616616
n21 = SymbolicGradient[Tin={float, float, float}, Tout={float, float}, f=Add[T=float]](n24, n25, n19)
617-
n28 = Identity[T=float](n21:1)
618617
n27 = Identity[T=float](n21)
619-
n6 = Identity[T=float](n28)
618+
n28 = Identity[T=float](n21:1)
620619
n8 = Identity[T=float](n27)
620+
n6 = Identity[T=float](n28)
621621
}
622622
)P";
623623
EXPECT_EQ(e1, DebugString(g.get()));
@@ -626,8 +626,8 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
626626
const char* e2 = R"P(
627627
(n4:float, n3:float) -> (n25:float, n23:float) {
628628
n2 = Const[dtype=float, value=Tensor<type: float shape: [] values: 1>]()
629-
n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
630629
n7 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 1>]()
630+
n8 = Const[dtype=int32, value=Tensor<type: int32 shape: [] values: 0>]()
631631
n19 = Shape[T=float, out_type=int32](n3)
632632
n9 = Add[T=float](n4, n3)
633633
n20 = Shape[T=float, out_type=int32](n4)
@@ -641,10 +641,10 @@ TEST_F(FunctionLibraryRuntimeTest, Gradient_AddSum) {
641641
n16 = Reshape[T=float, Tshape=int32](n2, n15)
642642
n17 = Div[T=int32](n14, n15)
643643
n18 = Tile[T=float, Tmultiples=int32](n16, n17)
644-
n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
645644
n22 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21:1)
646-
n25 = Reshape[T=float, Tshape=int32](n24, n20)
645+
n24 = Sum[T=float, Tidx=int32, keep_dims=false](n18, n21)
647646
n23 = Reshape[T=float, Tshape=int32](n22, n19)
647+
n25 = Reshape[T=float, Tshape=int32](n24, n20)
648648
}
649649
)P";
650650
EXPECT_EQ(e2, DebugString(g.get()));

tensorflow/core/graph/graph.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ const Edge* Graph::AddEdge(Node* source, int x, Node* dest, int y) {
344344
CHECK(source->out_edges_.insert(e).second);
345345
CHECK(dest->in_edges_.insert(e).second);
346346
edges_.push_back(e);
347-
edge_set_.insert(e);
347+
++num_edges_;
348348
return e;
349349
}
350350

@@ -354,8 +354,8 @@ void Graph::RemoveEdge(const Edge* e) {
354354
CHECK_EQ(e->src_->out_edges_.erase(e), size_t{1});
355355
CHECK_EQ(e->dst_->in_edges_.erase(e), size_t{1});
356356
CHECK_EQ(e, edges_[e->id_]);
357+
CHECK_GT(num_edges_, 0);
357358

358-
CHECK_EQ(edge_set_.erase(e), size_t{1});
359359
edges_[e->id_] = nullptr;
360360

361361
Edge* del = const_cast<Edge*>(e);
@@ -365,6 +365,7 @@ void Graph::RemoveEdge(const Edge* e) {
365365
del->src_output_ = kControlSlot - 1;
366366
del->dst_input_ = kControlSlot - 1;
367367
free_edges_.push_back(del);
368+
--num_edges_;
368369
}
369370

370371
Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {

tensorflow/core/graph/graph.h

Lines changed: 64 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,66 @@ class Edge {
268268
int dst_input_;
269269
};
270270

271+
// Allows for iteration of the edges of a Graph, by iterating the underlying
272+
// Graph.edges_ vector while skipping over null entries.
273+
class GraphEdgesIterable {
274+
private:
275+
const std::vector<Edge*>& edges_;
276+
277+
public:
278+
explicit GraphEdgesIterable(const std::vector<Edge*>& edges)
279+
: edges_(edges) {}
280+
281+
typedef Edge* value_type;
282+
283+
class const_iterator {
284+
private:
285+
// The underlying iterator.
286+
std::vector<value_type>::const_iterator iter_;
287+
288+
// The end of the underlying iterator.
289+
std::vector<value_type>::const_iterator end_;
290+
291+
// Advances iter_ until it reaches a non-null item, or reaches the end.
292+
void apply_filter() {
293+
while (iter_ != end_ && *iter_ == nullptr) {
294+
++iter_;
295+
}
296+
}
297+
298+
public:
299+
const_iterator(std::vector<value_type>::const_iterator iter,
300+
std::vector<value_type>::const_iterator end)
301+
: iter_(iter), end_(end) {
302+
apply_filter();
303+
}
304+
305+
bool operator==(const const_iterator& other) const {
306+
return iter_ == other.iter_;
307+
}
308+
309+
bool operator!=(const const_iterator& other) const {
310+
return iter_ != other.iter_;
311+
}
312+
313+
// This is the prefix increment operator (++x), which is the operator
314+
// used by C++ range iteration (for (x : y) ...). We intentionally do not
315+
// provide a postfix increment operator.
316+
const_iterator& operator++() {
317+
++iter_;
318+
apply_filter();
319+
return *this;
320+
}
321+
322+
value_type operator*() { return *iter_; }
323+
};
324+
325+
const_iterator begin() {
326+
return const_iterator(edges_.begin(), edges_.end());
327+
}
328+
const_iterator end() { return const_iterator(edges_.end(), edges_.end()); }
329+
};
330+
271331
// Thread compatible but not thread safe.
272332
class Graph {
273333
public:
@@ -345,7 +405,7 @@ class Graph {
345405
// smaller than num_edge_ids(). If one needs to create an array of
346406
// edges indexed by edge ids, num_edge_ids() should be used as the
347407
// array's size.
348-
int num_edges() const { return edges().size(); }
408+
int num_edges() const { return num_edges_; }
349409

350410
// Serialize the nodes starting at `from_node_id` to a GraphDef.
351411
void ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const;
@@ -381,7 +441,7 @@ class Graph {
381441

382442
// Access to the set of all edges. Example usage:
383443
// for (const Edge* e : graph.edges()) { ... }
384-
const EdgeSet& edges() const { return edge_set_; }
444+
GraphEdgesIterable edges() const { return GraphEdgesIterable(edges_); }
385445

386446
// The pre-defined nodes.
387447
enum { kSourceId = 0, kSinkId = 1 };
@@ -421,9 +481,8 @@ class Graph {
421481
// the edge with that id was removed from the graph.
422482
std::vector<Edge*> edges_;
423483

424-
// For ease of iteration, we currently just keep a set of all live
425-
// edges. May want to optimize by removing this copy.
426-
EdgeSet edge_set_;
484+
// The number of entries in edges_ that are not nullptr.
485+
int num_edges_ = 0;
427486

428487
// Allocated but free nodes and edges.
429488
std::vector<Node*> free_nodes_;

0 commit comments

Comments
 (0)