Skip to content

Commit dc17f76

Browse files
Add slow path to fetch health pills at individual steps.
Tensorboard samples steps, yet users desire health pills at specific steps. This change makes the debugger plugin read directly from disk when the user specifies a specific step. This is much slower (It could take minutes.) than the alternative path of querying the multiplexer for sampled health pills. Change: 150041439
1 parent 8ba27e3 commit dc17f76

File tree

5 files changed

+296
-41
lines changed

5 files changed

+296
-41
lines changed

tensorflow/tensorboard/backend/event_processing/event_accumulator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from tensorflow.core.protobuf import meta_graph_pb2
2929
from tensorflow.core.protobuf.config_pb2 import RunMetadata
3030
from tensorflow.core.util.event_pb2 import SessionLog
31+
from tensorflow.python.framework import tensor_util
3132
from tensorflow.python.platform import tf_logging as logging
3233
from tensorflow.python.util import compat
3334
from tensorflow.tensorboard.backend.event_processing import directory_watcher
@@ -116,7 +117,7 @@
116117
# The tag that values containing health pills have. Health pill data is stored
117118
# in tensors. In order to distinguish health pill values from scalar values, we
118119
# rely on how health pill values have this special tag value.
119-
_HEALTH_PILL_EVENT_TAG = '__health_pill__'
120+
HEALTH_PILL_EVENT_TAG = '__health_pill__'
120121

121122

122123
def IsTensorFlowEventsFile(path):
@@ -318,7 +319,7 @@ def _ProcessEvent(self, event):
318319
self._tagged_metadata[tag] = event.tagged_run_metadata.run_metadata
319320
elif event.HasField('summary'):
320321
for value in event.summary.value:
321-
if value.HasField('tensor') and value.tag == _HEALTH_PILL_EVENT_TAG:
322+
if value.HasField('tensor') and value.tag == HEALTH_PILL_EVENT_TAG:
322323
self._ProcessHealthPillSummary(value, event)
323324
else:
324325
for summary_type, summary_func in SUMMARY_TYPES.items():
@@ -341,7 +342,7 @@ def _ProcessHealthPillSummary(self, value, event):
341342
value: A summary_pb2.Summary.Value with a Tensor field.
342343
event: The event_pb2.Event containing that value.
343344
"""
344-
elements = np.fromstring(value.tensor.tensor_content, dtype=np.float64)
345+
elements = tensor_util.MakeNdarray(value.tensor)
345346

346347
# The node_name property of the value object is actually a watch key: a
347348
# combination of node name, output slot, and a suffix. We capture the

tensorflow/tensorboard/backend/event_processing/event_accumulator_test.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525

2626
from tensorflow.core.framework import graph_pb2
2727
from tensorflow.core.framework import summary_pb2
28+
from tensorflow.core.framework import types_pb2
2829
from tensorflow.core.protobuf import config_pb2
2930
from tensorflow.core.util import event_pb2
3031
from tensorflow.python.framework import constant_op
@@ -70,15 +71,13 @@ def AddScalar(self, tag, wall_time=0, step=0, value=0):
7071
tag=tag, simple_value=value)]))
7172
self.AddEvent(event)
7273

73-
def AddHealthPill(self, wall_time, step, node_name, output_slot, elements):
74-
event = event_pb2.Event()
75-
event.wall_time = wall_time
76-
event.step = step
77-
value = event.summary.value.add()
78-
# The node_name property is actually a watch key.
79-
value.node_name = '%s:%d:DebugNumericSummary' % (node_name, output_slot)
80-
value.tag = '__health_pill__'
81-
value.tensor.tensor_shape.dim.add().size = len(elements)
74+
def AddHealthPill(self, wall_time, step, op_name, output_slot, elements):
75+
event = event_pb2.Event(step=step, wall_time=wall_time)
76+
value = event.summary.value.add(
77+
tag='__health_pill__',
78+
node_name='%s:%d:DebugNumericSummary' % (op_name, output_slot))
79+
value.tensor.tensor_shape.dim.add(size=len(elements))
80+
value.tensor.dtype = types_pb2.DT_DOUBLE
8281
value.tensor.tensor_content = np.array(elements, dtype=np.float64).tobytes()
8382
self.AddEvent(event)
8483

tensorflow/tensorboard/plugins/debugger/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,10 @@ py_library(
1515
srcs = ["debugger_plugin.py"],
1616
srcs_version = "PY2AND3",
1717
deps = [
18+
"//tensorflow/python:framework",
1819
"//tensorflow/python:platform",
20+
"//tensorflow/tensorboard/backend/event_processing:event_accumulator",
21+
"//tensorflow/tensorboard/backend/event_processing:event_file_loader",
1922
"//tensorflow/tensorboard/lib/python:http_util",
2023
"//tensorflow/tensorboard/plugins:base_plugin",
2124
],

tensorflow/tensorboard/plugins/debugger/debugger_plugin.py

Lines changed: 221 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,17 @@
1919
from __future__ import print_function
2020

2121
import collections
22+
import glob
2223
import json
24+
import os
25+
import re
2326

2427
from werkzeug import wrappers
2528

29+
from tensorflow.python.framework import tensor_util
2630
from tensorflow.python.platform import tf_logging as logging
31+
from tensorflow.tensorboard.backend.event_processing import event_accumulator
32+
from tensorflow.tensorboard.backend.event_processing import event_file_loader
2733
from tensorflow.tensorboard.lib.python import http_util
2834
from tensorflow.tensorboard.plugins import base_plugin
2935

@@ -42,6 +48,13 @@
4248
# The default run to retrieve health pills for.
4349
_DEFAULT_RUN = '.'
4450

51+
# The POST key of HEALTH_PILLS_ROUTE for the specific step to retrieve health
52+
# pills for.
53+
_STEP_POST_KEY = 'step'
54+
55+
# A glob pattern for files containing debugger-related events.
56+
_DEBUGGER_EVENTS_GLOB_PATTERN = 'events.debugger*'
57+
4558

4659
class DebuggerPlugin(base_plugin.TBPlugin):
4760
"""TensorFlow Debugger plugin. Receives requests for debugger-related data.
@@ -58,17 +71,18 @@ def __init__(self, event_multiplexer):
5871
"""
5972
self._event_multiplexer = event_multiplexer
6073

61-
def get_plugin_apps(self, unused_run_paths, unused_logdir):
62-
"""Obtains a mapping between routes and handlers.
74+
def get_plugin_apps(self, unused_run_paths, logdir):
75+
"""Obtains a mapping between routes and handlers. Stores the logdir.
6376
6477
Args:
6578
unused_run_paths: A mapping between run paths and handlers.
66-
unused_logdir: The logdir string - the directory of events files.
79+
logdir: The logdir string - the directory of events files.
6780
6881
Returns:
6982
A mapping between routes and handlers (functions that respond to
7083
requests).
7184
"""
85+
self._logdir = logdir
7286
return {
7387
_HEALTH_PILLS_ROUTE: self._serve_health_pills_handler,
7488
}
@@ -77,15 +91,27 @@ def get_plugin_apps(self, unused_run_paths, unused_logdir):
7791
def _serve_health_pills_handler(self, request):
7892
"""A (wrapped) werkzeug handler for serving health pills.
7993
80-
Accepts POST requests and responds with health pills. Specifically, the
81-
handler expects a required "node_names" and an optional "run" POST data key.
82-
The value of the "node_names" key should be a JSON-ified list of node names
83-
for which the client would like to request health pills. The value of the
84-
"run" key (which defaults to ".") should be the run to retrieve health pills
85-
for. This data is sent via POST (not GET) because URL length is limited.
94+
Accepts POST requests and responds with health pills. The request accepts
95+
several POST parameters:
96+
97+
node_names: (required string) A JSON-ified list of node names for which
98+
the client would like to request health pills.
99+
run: (optional string) The run to retrieve health pills for. Defaults to
100+
'.'. This data is sent via POST (not GET) since URL length is limited.
101+
step: (optional integer): The session run step for which to
102+
retrieve health pills. If provided, the handler reads the health pills
103+
of that step from disk (which is slow) and produces a response with
104+
only health pills at that step. If not provided, the handler returns a
105+
response with health pills at all steps sampled by the event
106+
multiplexer (the fast path). The motivation here is that, sometimes,
107+
one desires to examine health pills at a specific step (to say find
108+
the first step that causes a model to blow up with NaNs).
109+
get_plugin_apps must be called before this slower feature is used
110+
because that method passes the logdir (directory path) to this plugin.
86111
87112
This handler responds with a JSON-ified object mapping from node names to a
88-
list of health pill event objects, each of which has these properties.
113+
list (of size 1) of health pill event objects, each of which has these
114+
properties.
89115
90116
{
91117
'wall_time': float,
@@ -112,7 +138,7 @@ def _serve_health_pills_handler(self, request):
112138

113139
if _NODE_NAMES_POST_KEY not in request.form:
114140
logging.error(
115-
'The %s POST key was not found in the request for health pills.',
141+
'The %r POST key was not found in the request for health pills.',
116142
_NODE_NAMES_POST_KEY)
117143
return wrappers.Response(status=400)
118144

@@ -123,30 +149,197 @@ def _serve_health_pills_handler(self, request):
123149
# Different JSON libs raise different exceptions, so we just do a
124150
# catch-all here. This problem is complicated by how Tensorboard might be
125151
# run in many different environments, as it is open-source.
126-
logging.error(
127-
'Could not decode node name JSON string %s: %s',
128-
jsonified_node_names, e)
152+
logging.error('Could not decode node name JSON string %r: %s',
153+
jsonified_node_names, e)
129154
return wrappers.Response(status=400)
130155

131156
if not isinstance(node_names, list):
132-
logging.error(
133-
'%s is not a JSON list of node names:', jsonified_node_names)
157+
logging.error('%r is not a JSON list of node names:',
158+
jsonified_node_names)
134159
return wrappers.Response(status=400)
135160

136-
mapping = collections.defaultdict(list)
137161
run = request.form.get(_RUN_POST_KEY, _DEFAULT_RUN)
162+
step_string = request.form.get(_STEP_POST_KEY, None)
163+
if step_string is None:
164+
# Use all steps sampled by the event multiplexer (Relatively fast).
165+
mapping = self._obtain_sampled_health_pills(run, node_names)
166+
else:
167+
# Read disk to obtain the health pills for that step (Relatively slow).
168+
# Make sure that the directory for the run exists.
169+
# Determine the directory of events file to read.
170+
events_directory = self._logdir
171+
if run != _DEFAULT_RUN:
172+
# Use the directory for the specific run.
173+
events_directory = os.path.join(events_directory, run)
174+
175+
step = int(step_string)
176+
try:
177+
mapping = self._obtain_health_pills_at_step(
178+
events_directory, node_names, step)
179+
except IOError as error:
180+
logging.error(
181+
'Error retrieving health pills for step %d: %s', step, error)
182+
return wrappers.Response(status=404)
183+
184+
# Convert event_accumulator.HealthPillEvents to JSON-able dicts.
185+
jsonable_mapping = {}
186+
for node_name, events in mapping.items():
187+
jsonable_mapping[node_name] = [e._asdict() for e in events]
188+
return http_util.Respond(request, jsonable_mapping, 'application/json')
189+
190+
def _obtain_sampled_health_pills(self, run, node_names):
191+
"""Obtains the health pills for a run sampled by the event multiplexer.
192+
193+
This is much faster than the alternative path of reading health pills from
194+
disk.
195+
196+
Args:
197+
run: The run to fetch health pills for.
198+
node_names: A list of node names for which to retrieve health pills.
199+
200+
Returns:
201+
A dictionary mapping from node name to a list of
202+
event_accumulator.HealthPillEvents.
203+
"""
204+
mapping = {}
138205
for node_name in node_names:
139206
try:
140-
pill_events = self._event_multiplexer.HealthPills(run, node_name)
141-
for pill_event in pill_events:
142-
mapping[node_name].append({
143-
'wall_time': pill_event[0],
144-
'step': pill_event[1],
145-
'node_name': pill_event[2],
146-
'output_slot': pill_event[3],
147-
'value': pill_event[4],
148-
})
207+
mapping[node_name] = self._event_multiplexer.HealthPills(run, node_name)
149208
except KeyError:
150-
logging.info('No health pills found for node %s.', node_name)
209+
logging.info('No health pills found for node %r.', node_name)
210+
continue
211+
212+
return mapping
213+
214+
def _obtain_health_pills_at_step(self, events_directory, node_names, step):
215+
"""Reads disk to obtain the health pills for a run at a specific step.
216+
217+
This could be much slower than the alternative path of just returning all
218+
health pills sampled by the event multiplexer. It could take tens of minutes
219+
to complete this call for large graphs for big step values (in the
220+
thousands).
221+
222+
Args:
223+
events_directory: The directory containing events for the desired run.
224+
node_names: A list of node names for which to retrieve health pills.
225+
step: The step to obtain health pills for.
226+
227+
Returns:
228+
A dictionary mapping from node name to a list of health pill objects (see
229+
docs for _serve_health_pills_handler for properties of those objects).
230+
231+
Raises:
232+
IOError: If no files with health pill events could be found.
233+
"""
234+
# Obtain all files with debugger-related events.
235+
pattern = os.path.join(events_directory, _DEBUGGER_EVENTS_GLOB_PATTERN)
236+
file_paths = glob.glob(pattern)
237+
238+
if not file_paths:
239+
raise IOError(
240+
'No events files found that matches the pattern %r.', pattern)
241+
242+
# Sort by name (and thus by timestamp).
243+
file_paths.sort()
244+
245+
mapping = collections.defaultdict(list)
246+
node_name_set = frozenset(node_names)
247+
248+
for file_path in file_paths:
249+
should_stop = self._process_health_pill_event(
250+
node_name_set, mapping, step, file_path)
251+
if should_stop:
252+
break
253+
254+
return mapping
255+
256+
def _process_health_pill_event(self, node_name_set, mapping, target_step,
257+
file_path):
258+
"""Creates health pills out of data in an event.
259+
260+
Creates health pills out of the event and adds them to the mapping.
261+
262+
Args:
263+
node_name_set: A set of node names that are relevant.
264+
mapping: The mapping from node name to event_accumulator.HealthPillEvents.
265+
This object may be destructively modified.
266+
target_step: The target step at which to obtain health pills.
267+
file_path: The path to the file with health pill events.
268+
269+
Returns:
270+
Whether we should stop reading events because future events are no longer
271+
relevant.
272+
"""
273+
events_loader = event_file_loader.EventFileLoader(file_path)
274+
for event in events_loader.Load():
275+
if not event.HasField('summary'):
276+
logging.warning('An event in a debugger events file lacks a summary.')
277+
continue
278+
279+
if event.step < target_step:
280+
# This event is not of the relevant step. We perform this check
281+
# first because the majority of events will be eliminated from
282+
# consideration by this check.
283+
continue
284+
285+
if event.step > target_step:
286+
# We have passed the relevant step. No need to read more events.
287+
return True
288+
289+
for value in event.summary.value:
290+
# Since we seek health pills for a specific step, this function
291+
# returns 1 health pill per node per step. The wall time is the
292+
# seconds since the epoch.
293+
health_pill = self._process_health_pill_value(
294+
node_name_set, event.wall_time, event.step, value)
295+
if not health_pill:
296+
continue
297+
mapping[health_pill.node_name].append(health_pill)
298+
299+
# Keep reading events.
300+
return False
301+
302+
def _process_health_pill_value(self, node_name_set, wall_time, step, value):
303+
"""Creates a dict containing various properties of a health pill.
304+
305+
Args:
306+
node_name_set: A set of node names that are relevant.
307+
wall_time: The wall time in seconds.
308+
step: The session run step of the event.
309+
value: The health pill value.
310+
311+
Returns:
312+
An event_accumulator.HealthPillEvent. Or None if one could not be created.
313+
"""
314+
if not value.HasField('tensor'):
315+
logging.warning(
316+
'An event in a debugger events file lacks a tensor value.')
317+
return None
318+
319+
if value.tag != event_accumulator.HEALTH_PILL_EVENT_TAG:
320+
logging.warning(
321+
('A debugger-related event lacks the %r tag. It instead has '
322+
'the %r tag.'), event_accumulator.HEALTH_PILL_EVENT_TAG, value.tag)
323+
return None
324+
325+
match = re.match(r'^(.*):(\d+):DebugNumericSummary$', value.node_name)
326+
if not match:
327+
logging.warning(
328+
('A event with a health pill has an invalid watch, (i.e., an '
329+
'unexpected debug op): %r'), value.node_name)
330+
return None
331+
332+
node_name = match.group(1)
333+
if node_name not in node_name_set:
334+
# This event is not relevant.
335+
return None
151336

152-
return http_util.Respond(request, mapping, 'application/json')
337+
# Since we seek health pills for a specific step, this function
338+
# returns 1 health pill per node per step. The wall time is the
339+
# seconds since the epoch.
340+
return event_accumulator.HealthPillEvent(
341+
wall_time=wall_time,
342+
step=step,
343+
node_name=node_name,
344+
output_slot=int(match.group(2)),
345+
value=list(tensor_util.MakeNdarray(value.tensor)))

0 commit comments

Comments
 (0)