
Commit 20723e2

TensorFlow: Merge changes from internal

Author: Vijay Vasudevan
Parents: 4c717c6 + 2d11635
30 files changed: +1300 −165 lines

RELEASE.md

Lines changed: 27 additions & 0 deletions
@@ -1,3 +1,30 @@
+# Release 0.6.0
+
+## Major Features and Improvements
+
+* Python 3.3+ support via changes to python codebase and ability
+  to specify python version via ./configure.
+
+* Some improvements to GPU performance and memory usage:
+  [convnet benchmarks](https://github.com/soumith/convnet-benchmarks/issues/66)
+  roughly equivalent with native cudnn v2 performance. Improvements mostly due
+  to moving to 32-bit indices, faster shuffling kernels. More improvements to
+  come in later releases.
+
+
+## Bug fixes
+
+* Lots of fixes to documentation and tutorials, many contributed
+  by the public.
+
+* 271 closed issues on github issues.
+
+## Backwards-incompatible changes
+
+* tf.nn.fixed_unigram_candidate_sampler changed its default 'distortion'
+  attribute from 0.0 to 1.0. This was a bug in the original release
+  that is now fixed.
+
 # Release 0.5.0
 
 Initial release of TensorFlow.
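
The distortion change is the one item here that can silently alter results: sampling probability is proportional to count raised to the distortion power, so 0.0 flattens the distribution toward uniform while 1.0 uses the true unigram counts. A minimal sketch of pinning the old 0.5.0 behavior explicitly; the argument names follow the candidate-sampling Python API, and the toy vocabulary counts are illustrative:

    import tensorflow as tf

    # Hypothetical true labels for a batch of 4 examples, vocabulary of 10000.
    true_classes = tf.constant([[1], [7], [42], [999]], dtype=tf.int64)

    # 0.6.0 defaults to distortion=1.0 (the true unigram distribution).
    # Passing distortion=0.0 reproduces the accidental 0.5.0 behavior,
    # i.e. a uniform distribution over the vocabulary.
    sampled, true_expected, sampled_expected = (
        tf.nn.fixed_unigram_candidate_sampler(
            true_classes=true_classes,
            num_true=1,
            num_sampled=64,
            unique=True,
            range_max=10000,
            unigrams=[1.0] * 10000,  # illustrative unigram counts
            distortion=0.0))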

tensorflow/core/kernels/conv_grad_ops.cc

Lines changed: 18 additions & 18 deletions
@@ -35,6 +35,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
 #endif  // GOOGLE_CUDA
 
 namespace tensorflow {

@@ -756,17 +757,6 @@ REGISTER_KERNEL_BUILDER(Name("Conv2DBackpropFilter")
 
 // GPU definitions of both ops.
 #if GOOGLE_CUDA
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
-                                                    uint64 size) {
-  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
-                                                size * sizeof(T));
-  perftools::gputools::DeviceMemory<T> typed(wrapped);
-  return typed;
-}
-}  // namespace
-
 // The slow version (but compiles for GPU)
 
 // Backprop for input.

@@ -929,10 +919,15 @@ class Conv2DSlowBackpropInputOp : public OpKernel {
         AsDeviceMemory(pre_transformed_in_backprop.template flat<T>().data(),
                        pre_transformed_in_backprop.template flat<T>().size());
 
+    static int64 ConvolveBackwardDataScratchSize = GetCudnnWorkspaceLimit(
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30  // 1GB by default
+    );
+    CudnnScratchAllocator scratch_allocator(ConvolveBackwardDataScratchSize,
+                                            context);
     bool cudnn_launch_status =
-        stream->ThenConvolveBackwardData(filter_desc, filter_ptr, output_desc,
-                                         out_backprop_ptr, conv_desc,
-                                         input_desc, &in_backprop_ptr)
+        stream->ThenConvolveBackwardDataWithScratch(
+                  filter_desc, filter_ptr, output_desc, out_backprop_ptr,
+                  conv_desc, input_desc, &in_backprop_ptr, &scratch_allocator)
             .ok();
 
     if (!cudnn_launch_status) {

@@ -1185,7 +1180,6 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
         context->eigen_device<Device>(),
         const_cast<const Tensor&>(compatible_input).tensor<T, 4>(),
         transformed_input.tensor<T, 4>());
-
     auto out_backprop_ptr =
         AsDeviceMemory(transformed_out_backprop.template flat<T>().data(),
                        transformed_out_backprop.template flat<T>().size());

@@ -1196,10 +1190,16 @@ class Conv2DSlowBackpropFilterOp : public OpKernel {
         AsDeviceMemory(transformed_input.template flat<T>().data(),
                        transformed_input.template flat<T>().size());
 
+    static int64 ConvolveBackwardFilterScratchSize = GetCudnnWorkspaceLimit(
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30  // 1GB by default
+    );
+    CudnnScratchAllocator scratch_allocator(ConvolveBackwardFilterScratchSize,
+                                            context);
     bool cudnn_launch_status =
-        stream->ThenConvolveBackwardFilter(input_desc, input_ptr, output_desc,
-                                           out_backprop_ptr, conv_desc,
-                                           filter_desc, &filter_backprop_ptr)
+        stream->ThenConvolveBackwardFilterWithScratch(
+                  input_desc, input_ptr, output_desc, out_backprop_ptr,
+                  conv_desc, filter_desc, &filter_backprop_ptr,
+                  &scratch_allocator)
             .ok();
 
     if (!cudnn_launch_status) {
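
Both backward passes now draw their cuDNN scratch space through a CudnnScratchAllocator, capped by TF_CUDNN_WORKSPACE_LIMIT_IN_MB. A minimal sketch of overriding the 1 GB default from Python; since the kernels cache the limit in a function-local static, it has to be set before the first convolution runs in the process:

    import os

    # Cap cuDNN temporary workspace at 256 MB (the variable is in megabytes).
    # This must happen before TensorFlow launches its first convolution,
    # because the C++ kernels read the value into a static on first use.
    os.environ["TF_CUDNN_WORKSPACE_LIMIT_IN_MB"] = "256"

    import tensorflow as tf  # import after setting the variable, to be safe
    # ... build and run convolution ops as usual ...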

tensorflow/core/kernels/conv_ops.cc

Lines changed: 63 additions & 19 deletions
@@ -25,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/kernels/ops_util.h"
 #include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/gtl/array_slice.h"
+#include "tensorflow/core/lib/strings/numbers.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/public/tensor.h"
 #include "tensorflow/core/public/tensor_shape.h"

@@ -34,6 +35,7 @@ limitations under the License.
 #if GOOGLE_CUDA
 #include "tensorflow/stream_executor/stream.h"
 #include "tensorflow/core/common_runtime/gpu_device_context.h"
+#include "tensorflow/core/kernels/conv_ops_gpu.h"
 #endif  // GOOGLE_CUDA
 
 namespace tensorflow {

@@ -206,16 +208,22 @@ REGISTER_KERNEL_BUILDER(Name("Conv2D")
 
 #if GOOGLE_CUDA
 
-namespace {
-template <typename T>
-perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
-                                                    uint64 size) {
-  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
-                                                size * sizeof(T));
-  perftools::gputools::DeviceMemory<T> typed(wrapped);
-  return typed;
+int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
+                             int64 default_value_in_bytes) {
+  const char* workspace_limit_in_mb_str = getenv(envvar_in_mb.c_str());
+  if (workspace_limit_in_mb_str != nullptr &&
+      strcmp(workspace_limit_in_mb_str, "") != 0) {
+    int64 scratch_limit_in_mb = -1;
+    if (strings::safe_strto64(workspace_limit_in_mb_str,
+                              &scratch_limit_in_mb)) {
+      return scratch_limit_in_mb * (1 << 20);
+    } else {
+      LOG(WARNING) << "Invalid value for env-var " << envvar_in_mb << ": "
+                   << workspace_limit_in_mb_str;
+    }
+  }
+  return default_value_in_bytes;
 }
-}  // namespace
 
 template <typename T>
 struct LaunchConvOp<GPUDevice, T> {

@@ -287,18 +295,34 @@ struct LaunchConvOp<GPUDevice, T> {
       input = transformed_input;
     }
 
+    {
+      // Convert the input tensor from NHWC to NCHW.
+      Tensor transformed_input;
+      OP_REQUIRES_OK(ctx,
+                     ctx->allocate_temp(
+                         DataTypeToEnum<T>::value,
+                         TensorShape({input.dim_size(0), input.dim_size(3),
+                                      input.dim_size(1), input.dim_size(2)}),
+                         &transformed_input));
+      functor::NHWCToNCHW<GPUDevice, T>()(
+          ctx->eigen_device<GPUDevice>(),
+          const_cast<const Tensor&>(input).tensor<T, 4>(),
+          transformed_input.tensor<T, 4>());
+      input = transformed_input;
+    }
+
     perftools::gputools::dnn::BatchDescriptor input_desc;
     input_desc.set_count(input.dim_size(0))
-        .set_height(input.dim_size(1))
-        .set_width(input.dim_size(2))
-        .set_feature_map_count(input.dim_size(3))
-        .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+        .set_feature_map_count(input.dim_size(1))
+        .set_height(input.dim_size(2))
+        .set_width(input.dim_size(3))
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
     perftools::gputools::dnn::BatchDescriptor output_desc;
     output_desc.set_count(output->dim_size(0))
         .set_height(output->dim_size(1))
         .set_width(output->dim_size(2))
         .set_feature_map_count(output->dim_size(3))
-        .set_layout(perftools::gputools::dnn::DataLayout::kBatchYXDepth);
+        .set_layout(perftools::gputools::dnn::DataLayout::kBatchDepthYX);
     perftools::gputools::dnn::FilterDescriptor filter_desc;
     filter_desc.set_input_filter_height(filter.dim_size(0))
         .set_input_filter_width(filter.dim_size(1))

@@ -320,24 +344,44 @@ struct LaunchConvOp<GPUDevice, T> {
         ctx->eigen_device<GPUDevice>(), To32Bit(filter.tensor<T, 4>()),
         To32Bit(transformed_filter.tensor<T, 4>()));
 
+    Tensor transformed_output;
+    OP_REQUIRES_OK(
+        ctx, ctx->allocate_temp(
+                 DataTypeToEnum<T>::value,
+                 TensorShape({output->dim_size(0), output->dim_size(3),
+                              output->dim_size(1), output->dim_size(2)}),
+                 &transformed_output));
+
     auto input_ptr = AsDeviceMemory(input.template flat<T>().data(),
                                     input.template flat<T>().size());
     auto filter_ptr =
         AsDeviceMemory(transformed_filter.template flat<T>().data(),
                        transformed_filter.template flat<T>().size());
-    auto output_ptr = AsDeviceMemory(output->template flat<T>().data(),
-                                     output->template flat<T>().size());
-
+    auto output_ptr =
+        AsDeviceMemory(transformed_output.template flat<T>().data(),
+                       transformed_output.template flat<T>().size());
+
+    static int64 ConvolveScratchSize = GetCudnnWorkspaceLimit(
+        "TF_CUDNN_WORKSPACE_LIMIT_IN_MB", 1LL << 30  // 1GB by default
+    );
+    CudnnScratchAllocator scratch_allocator(ConvolveScratchSize, ctx);
     bool cudnn_launch_status =
-        stream->ThenConvolve(input_desc, input_ptr, filter_desc, filter_ptr,
-                             conv_desc, output_desc, &output_ptr)
+        stream->ThenConvolveWithScratch(input_desc, input_ptr, filter_desc,
+                                        filter_ptr, conv_desc, output_desc,
+                                        &output_ptr, &scratch_allocator)
            .ok();
 
     if (!cudnn_launch_status) {
       ctx->SetStatus(errors::Internal(
           "cuDNN launch failure : input shape(", input.shape().DebugString(),
           ") filter shape(", filter.shape().DebugString(), ")"));
     }
+
+    // Convert the output tensor back from NCHW to NHWC.
+    functor::NCHWToNHWC<GPUDevice, T>()(
+        ctx->eigen_device<GPUDevice>(),
+        const_cast<const Tensor&>(transformed_output).tensor<T, 4>(),
+        output->tensor<T, 4>());
   } else {
     LaunchGeneric<GPUDevice, T>::launch(ctx, input_param, filter, stride,
                                         padding, output);
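
The descriptor changes switch the convolution from kBatchYXDepth (NHWC) to kBatchDepthYX (NCHW), which is why the new blocks transpose the input before the cuDNN call and the output after it. A numpy sketch of the same index permutations; numpy stands in for the Eigen shuffle here, and the shapes are illustrative:

    import numpy as np

    # A batch of 2 images, 5x5 spatial, 3 channels, in TensorFlow's native NHWC.
    nhwc = np.random.rand(2, 5, 5, 3).astype(np.float32)

    # NHWC -> NCHW: the channel axis moves ahead of the spatial axes, matching
    # TensorShape({dim_size(0), dim_size(3), dim_size(1), dim_size(2)}).
    nchw = nhwc.transpose(0, 3, 1, 2)
    assert nchw.shape == (2, 3, 5, 5)

    # NCHW -> NHWC: the inverse permutation, matching NCHWToNHWC on the output.
    assert np.array_equal(nchw.transpose(0, 2, 3, 1), nhwc)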
tensorflow/core/kernels/conv_ops_gpu.h (new file)

Lines changed: 84 additions & 0 deletions

@@ -0,0 +1,84 @@
+/* Copyright 2015 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
+#define THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
+
+#if GOOGLE_CUDA
+
+#include "tensorflow/stream_executor/scratch_allocator.h"
+#include "tensorflow/core/common_runtime/gpu_device_context.h"
+
+namespace tensorflow {
+
+// TODO(zhengxq): move this to gpu_util.h. The use of such wrappers is
+// widespread.
+template <typename T>
+perftools::gputools::DeviceMemory<T> AsDeviceMemory(const T* cuda_memory,
+                                                    uint64 size) {
+  perftools::gputools::DeviceMemoryBase wrapped(const_cast<T*>(cuda_memory),
+                                                size * sizeof(T));
+  perftools::gputools::DeviceMemory<T> typed(wrapped);
+  return typed;
+}
+
+// Get the Cudnn workspace limit from the environment variable, which is in MB.
+// Return the workspace memory limit in bytes. If no value is set, return the
+// default value.
+int64 GetCudnnWorkspaceLimit(const string& envvar_in_mb,
+                             int64 default_value_in_bytes);
+
+// A class to provide a scratch-space allocator for the Stream-Executor Cudnn
+// callback. TensorFlow is responsible for releasing the temporary buffers
+// after the kernel finishes.
+class CudnnScratchAllocator : public perftools::gputools::ScratchAllocator {
+ public:
+  virtual ~CudnnScratchAllocator() {}
+  CudnnScratchAllocator(int64 memory_limit, OpKernelContext* context)
+      : memory_limit_(memory_limit), context_(context) {}
+  virtual int64 GetMemoryLimitInBytes(
+      perftools::gputools::Stream* stream) override {
+    return memory_limit_;
+  }
+  virtual perftools::gputools::port::StatusOr<
+      perftools::gputools::DeviceMemory<uint8>>
+  AllocateBytes(perftools::gputools::Stream* stream,
+                int64 byte_size) override {
+    Tensor temporary_memory;
+
+    Status allocation_status(context_->allocate_temp(
+        DT_UINT8, TensorShape({byte_size}), &temporary_memory));
+    if (!allocation_status.ok()) {
+      LOG(WARNING) << allocation_status;
+      context_->SetStatus(allocation_status);
+      return perftools::gputools::port::StatusOr<
+          perftools::gputools::DeviceMemory<uint8>>();
+    }
+
+    return perftools::gputools::port::StatusOr<
+        perftools::gputools::DeviceMemory<uint8>>(
+        AsDeviceMemory(temporary_memory.flat<uint8>().data(),
+                       temporary_memory.flat<uint8>().size()));
+  }
+
+ private:
+  int64 memory_limit_;
+  OpKernelContext* context_;
+};
+
+}  // namespace tensorflow
+
+#endif  // GOOGLE_CUDA
+
+#endif  // THIRD_PARTY_TENSORFLOW_CORE_KERNELS_CONV_OPS_GPU_H_
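
The unit handling in GetCudnnWorkspaceLimit is easy to misread in diff form: the environment variable is interpreted in megabytes, the return value is in bytes, and an unparsable value falls back to the default. A Python rendering of the same logic, purely for illustration (the helper name mirrors the C++ function and is not part of any API; MY_FAKE_LIMIT_VAR is made up):

    import os

    def get_cudnn_workspace_limit(envvar_in_mb, default_value_in_bytes):
        # Mirrors GetCudnnWorkspaceLimit: env var in MB, result in bytes.
        raw = os.environ.get(envvar_in_mb, "")
        if raw:
            try:
                return int(raw) * (1 << 20)  # MB -> bytes
            except ValueError:
                print("Invalid value for env-var %s: %s" % (envvar_in_mb, raw))
        return default_value_in_bytes

    # Unset variable -> the 1 GB default comes back unchanged.
    os.environ.pop("MY_FAKE_LIMIT_VAR", None)
    assert get_cudnn_workspace_limit("MY_FAKE_LIMIT_VAR", 1 << 30) == 1 << 30
    # "512" -> 512 MB expressed in bytes.
    os.environ["MY_FAKE_LIMIT_VAR"] = "512"
    assert get_cudnn_workspace_limit("MY_FAKE_LIMIT_VAR", 1 << 30) == 512 << 20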

tensorflow/python/client/graph_util.py

Lines changed: 62 additions & 0 deletions
@@ -19,6 +19,7 @@
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
+import copy
 
 import tensorflow.python.platform
 

@@ -155,3 +156,64 @@ def pin_to_cpu(op):
   logging.info("Operation %s has been assigned to a non-CPU (%s), so "
                "it will not be pinned to the CPU.", op.name, dev.device_type)
   return device
+
+
+def _node_name(n):
+  if n.startswith("^"):
+    return n[1:]
+  else:
+    return n.split(":")[0]
+
+
+def extract_sub_graph(graph_def, dest_nodes):
+  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.
+
+  Args:
+    graph_def: A graph_pb2.GraphDef proto.
+    dest_nodes: A list of strings specifying the destination node names.
+  Returns:
+    The GraphDef of the sub-graph.
+
+  Raises:
+    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
+  """
+
+  if not isinstance(graph_def, graph_pb2.GraphDef):
+    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")
+
+  edges = {}  # Keyed by the dest node name.
+  name_to_node_map = {}  # Keyed by node name.
+
+  # Keeps track of node sequences. It is important to still output the
+  # operations in the original order.
+  node_seq = {}  # Keyed by node name.
+  seq = 0
+  for node in graph_def.node:
+    n = _node_name(node.name)
+    name_to_node_map[n] = node
+    edges[n] = [_node_name(x) for x in node.input]
+    node_seq[n] = seq
+    seq += 1
+
+  for d in dest_nodes:
+    assert d in name_to_node_map, "%s is not in graph" % d
+
+  nodes_to_keep = set()
+  # Breadth-first search to find all the nodes that we should keep.
+  next_to_visit = dest_nodes[:]
+  while next_to_visit:
+    n = next_to_visit[0]
+    del next_to_visit[0]
+    if n in nodes_to_keep:
+      # Already visited this node.
+      continue
+    nodes_to_keep.add(n)
+    next_to_visit += edges[n]
+
+  nodes_to_keep_list = sorted(list(nodes_to_keep), key=lambda n: node_seq[n])
+  # Now construct the output GraphDef.
+  out = graph_pb2.GraphDef()
+  for n in nodes_to_keep_list:
+    out.node.extend([copy.deepcopy(name_to_node_map[n])])
+
+  return out
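
A minimal usage sketch for extract_sub_graph: build a small graph, then keep only one destination node and its ancestors. The op names are made up for the example, and tf.mul is the multiplication op under the 0.x API:

    import tensorflow as tf
    from tensorflow.python.client import graph_util

    g = tf.Graph()
    with g.as_default():
      a = tf.constant(1.0, name="a")
      b = tf.constant(2.0, name="b")
      c = tf.add(a, b, name="c")
      tf.mul(c, c, name="unused")  # nothing downstream of "c" is kept

    # Keep "c" and everything it depends on; "unused" is pruned.
    sub = graph_util.extract_sub_graph(g.as_graph_def(), ["c"])
    print([n.name for n in sub.node])  # -> ['a', 'b', 'c']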
