-
Notifications
You must be signed in to change notification settings - Fork 1.7k
/
profile_plugin.py
440 lines (382 loc) · 16.3 KB
/
profile_plugin.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
# -*- coding: utf-8 -*-
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TensorBoard plugin for performance profiling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import threading
import six
import tensorflow as tf
from werkzeug import wrappers
from tensorboard.backend import http_util
from tensorboard.backend.event_processing import plugin_asset_util
from tensorboard.plugins import base_plugin
from tensorboard.plugins.profile import trace_events_json
from tensorboard.plugins.profile import trace_events_pb2
from tensorboard.util import tb_logging
# Module-level logger obtained from tb_logging; used below for the
# warnings emitted when profile assets cannot be found or read.
logger = tb_logging.get_logger()
# Name of this plugin; also used as the name of its plugin asset directory.
PLUGIN_NAME = 'profile'

# HTTP routes served by the plugin.
DATA_ROUTE = '/data'
TOOLS_ROUTE = '/tools'
HOSTS_ROUTE = '/hosts'

_FILE_NAME = 'TOOL_FILE_NAME'

# Maps each available profiling tool to the file-name suffix of its data
# file; files are looked up with the glob pattern '*' + suffix.
TOOLS = {
    'trace_viewer': 'trace',
    # Streaming trace viewer.
    'trace_viewer@': 'tracetable',
    'op_profile': 'op_profile.json',
    'input_pipeline_analyzer': 'input_pipeline.json',
    'overview_page': 'overview_page.json',
    'memory_viewer': 'memory_viewer.json',
    'google_chart_demo': 'google_chart_demo.json',
}

# Tools whose data file is served to the frontend as-is, without any
# server-side processing.
_RAW_DATA_TOOLS = frozenset((
    'input_pipeline_analyzer',
    'op_profile',
    'overview_page',
    'memory_viewer',
    'google_chart_demo',
))
def process_raw_trace(raw_trace):
  """Converts a serialized trace into the string served to the trace viewer.

  Args:
    raw_trace: Serialized bytes that are parsed into a
      trace_events_pb2.Trace message.

  Returns:
    The concatenation of all chunks produced by
    trace_events_json.TraceEventsJsonStream for the parsed trace.
  """
  parsed_trace = trace_events_pb2.Trace()
  parsed_trace.ParseFromString(raw_trace)
  json_chunks = trace_events_json.TraceEventsJsonStream(parsed_trace)
  return ''.join(json_chunks)
class ProfilePluginLoader(base_plugin.TBLoader):
  """Loader for Profile Plugin."""

  def define_flags(self, parser):
    """Adds the profile plugin's command-line flag to `parser`.

    Args:
      parser: The argument parser to extend; the flag is registered in its
        own 'profile plugin' argument group.
    """
    profile_flags = parser.add_argument_group('profile plugin')
    profile_flags.add_argument(
        '--master_tpu_unsecure_channel',
        metavar='ADDR',
        type=str,
        default='',
        help='''\
IP address of "master tpu", used for getting streaming trace data
through tpu profiler analysis grpc. The grpc channel is not secured.\
''')

  def load(self, context):
    """Builds the plugin instance for the given context.

    Args:
      context: The context object handed to the ProfilePlugin constructor.

    Returns:
      A new ProfilePlugin.
    """
    return ProfilePlugin(context)
class ProfilePlugin(base_plugin.TBPlugin):
  """Profile Plugin for TensorBoard.

  Serves three routes (see get_plugin_apps): the mapping of runs to available
  tools, the hosts available for a run/tool pair, and the tool data itself.
  """

  plugin_name = PLUGIN_NAME

  def __init__(self, context):
    """Constructs a profiler plugin for TensorBoard.

    This plugin adds handlers for performance-related frontends.

    Args:
      context: A base_plugin.TBContext instance.
    """
    self.logdir = context.logdir
    self.multiplexer = context.multiplexer
    self.plugin_logdir = plugin_asset_util.PluginDirectory(
        self.logdir, PLUGIN_NAME)
    # gRPC stub for the streaming trace viewer; created lazily by
    # start_grpc_stub_if_necessary().
    self.stub = None
    self.master_tpu_unsecure_channel = context.flags.master_tpu_unsecure_channel

    # Whether the plugin is active. This is an expensive computation, so we
    # compute this asynchronously and cache positive results indefinitely.
    self._is_active = False
    # Lock to ensure at most one thread computes _is_active at a time.
    self._is_active_lock = threading.Lock()

  def is_active(self):
    """Whether this plugin is active and has any profile data to show.

    Detecting profile data is expensive, so this process runs asynchronously
    and the value reported by this method is the cached value and may be stale.

    Returns:
      Whether any run has profile data.
    """
    # If we are already active, we remain active and don't recompute this.
    # Otherwise, try to acquire the lock without blocking; if we get it and
    # we're still not active, launch a thread to check if we're active and
    # release the lock once the computation is finished. Either way, this
    # thread returns the current cached value to avoid blocking.
    if not self._is_active and self._is_active_lock.acquire(False):
      if self._is_active:
        self._is_active_lock.release()
      else:
        def compute_is_active():
          # Release the lock even when the scan raises; otherwise no later
          # call could ever acquire it and the plugin would stay inactive
          # forever.
          try:
            self._is_active = any(self.generate_run_to_tools())
          finally:
            self._is_active_lock.release()

        new_thread = threading.Thread(
            target=compute_is_active,
            name='ProfilePluginIsActiveThread')
        try:
          new_thread.start()
        except Exception:
          # The thread never ran, so it cannot release the lock for us.
          self._is_active_lock.release()
          raise
    return self._is_active

  def start_grpc_stub_if_necessary(self):
    """Creates the gRPC stub for the streaming trace viewer, if enabled.

    Does nothing when the stub already exists or when streaming is not
    enabled (see the conditions below).
    """
    # We will enable streaming trace viewer on two conditions:
    # 1. user specify the flags master_tpu_unsecure_channel to the ip address of
    #    as "master" TPU. grpc will be used to fetch streaming trace data.
    # 2. the logdir is on google cloud storage.
    if self.master_tpu_unsecure_channel and self.logdir.startswith('gs://'):
      if self.stub is None:
        # Imported lazily so these packages are only needed when streaming
        # is actually enabled.
        import grpc
        from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2_grpc # pylint: disable=line-too-long
        # Workaround the grpc's 4MB message limitation.
        gigabyte = 1024 * 1024 * 1024
        options = [('grpc.max_message_length', gigabyte),
                   ('grpc.max_send_message_length', gigabyte),
                   ('grpc.max_receive_message_length', gigabyte)]
        tpu_profiler_port = self.master_tpu_unsecure_channel + ':8466'
        channel = grpc.insecure_channel(tpu_profiler_port, options)
        self.stub = tpu_profiler_analysis_pb2_grpc.TPUProfileAnalysisStub(
            channel)

  def _run_dir(self, run):
    """Helper that maps a frontend run name to a profile "run" directory.

    The frontend run name consists of the TensorBoard run name (aka the relative
    path from the logdir root to the directory containing the data) path-joined
    to the Profile plugin's "run" concept (which is a subdirectory of the
    plugins/profile directory representing an individual run of the tool), with
    the special case that TensorBoard run is the logdir root (which is the run
    named '.') then only the Profile plugin "run" name is used, for backwards
    compatibility.

    To convert back to the actual run directory, we apply the following
    transformation:
    - If the run name doesn't contain '/', prepend './'
    - Split on the rightmost instance of '/'
    - Assume the left side is a TensorBoard run name and map it to a directory
      path using EventMultiplexer.RunPaths(), then map that to the profile
      plugin directory via PluginDirectory()
    - Assume the right side is a Profile plugin "run" and path-join it to
      the preceding path to get the final directory

    Args:
      run: the frontend run name, as described above, e.g. train/run1.

    Returns:
      The resolved directory path, e.g. /logdir/train/plugins/profile/run1.

    Raises:
      RuntimeError: If the TensorBoard run part of `run` maps to no known
        directory.
    """
    run = run.rstrip('/')
    if '/' not in run:
      run = './' + run
    tb_run_name, _, profile_run_name = run.rpartition('/')
    tb_run_directory = self.multiplexer.RunPaths().get(tb_run_name)
    if tb_run_directory is None:
      # Check if logdir is a directory to handle case where it's actually a
      # multipart directory spec, which this plugin does not support.
      if tb_run_name == '.' and tf.io.gfile.isdir(self.logdir):
        tb_run_directory = self.logdir
      else:
        raise RuntimeError("No matching run directory for run %s" % run)
    plugin_directory = plugin_asset_util.PluginDirectory(
        tb_run_directory, PLUGIN_NAME)
    return os.path.join(plugin_directory, profile_run_name)

  def generate_run_to_tools(self):
    """Generator for pairs of "run name" and a list of tools for that run.

    The "run name" here is a "frontend run name" - see _run_dir() for the
    definition of a "frontend run name" and how it maps to a directory of
    profile data for a specific profile "run". The profile plugin concept of
    "run" is different from the normal TensorBoard run; each run in this case
    represents a single instance of profile data collection, more similar to a
    "step" of data in typical TensorBoard semantics. These runs reside in
    subdirectories of the plugins/profile directory within any regular
    TensorBoard run directory (defined as a subdirectory of the logdir that
    contains at least one tfevents file) or within the logdir root directory
    itself (even if it contains no tfevents file and would thus not be
    considered a normal TensorBoard run, for backwards compatibility).

    Within those "profile run directories", there are files in the directory
    that correspond to different profiling tools. The file that contains profile
    for a specific tool "x" will have a suffix name TOOLS["x"].

    Example:
      logs/
        plugins/
          profile/
            run1/
              hostA.trace
        train/
          events.out.tfevents.foo
          plugins/
            profile/
              run1/
                hostA.trace
                hostB.trace
              run2/
                hostA.trace
        validation/
          events.out.tfevents.foo
          plugins/
            profile/
              run1/
                hostA.trace

    Yields:
      A sequence of tuples mapping "frontend run names" to lists of tool names
      available for those runs. For the above example, this would be:

          ("run1", ["trace_viewer"])
          ("train/run1", ["trace_viewer"])
          ("train/run2", ["trace_viewer"])
          ("validation/run1", ["trace_viewer"])
    """
    self.start_grpc_stub_if_necessary()

    # Work on copies because the '.' entry may be added below.
    # NOTE(review): presumably the multiplexer can hand back its own internal
    # dicts, in which case adding to them in place would alter its state --
    # confirm against EventMultiplexer.
    plugin_assets = dict(self.multiplexer.PluginAssets(PLUGIN_NAME))
    tb_run_names_to_dirs = dict(self.multiplexer.RunPaths())

    # Ensure that we also check the root logdir, even if it isn't a recognized
    # TensorBoard run (i.e. has no tfevents file directly under it), to remain
    # backwards compatible with previous profile plugin behavior. Note that we
    # check if logdir is a directory to handle case where it's actually a
    # multipart directory spec, which this plugin does not support.
    if '.' not in plugin_assets and tf.io.gfile.isdir(self.logdir):
      tb_run_names_to_dirs['.'] = self.logdir
      plugin_assets['.'] = plugin_asset_util.ListAssets(
          self.logdir, PLUGIN_NAME)

    for tb_run_name, profile_runs in six.iteritems(plugin_assets):
      tb_run_dir = tb_run_names_to_dirs.get(tb_run_name)
      if tb_run_dir is None:
        # The two multiplexer snapshots disagree about this run; skip it
        # rather than abort the whole listing with a KeyError.
        continue
      tb_plugin_dir = plugin_asset_util.PluginDirectory(
          tb_run_dir, PLUGIN_NAME)
      for profile_run in profile_runs:
        # Remove trailing slash; some filesystem implementations emit this.
        profile_run = profile_run.rstrip('/')
        if tb_run_name == '.':
          frontend_run = profile_run
        else:
          frontend_run = '/'.join([tb_run_name, profile_run])
        profile_run_dir = os.path.join(tb_plugin_dir, profile_run)
        if tf.io.gfile.isdir(profile_run_dir):
          yield frontend_run, self._get_active_tools(profile_run_dir)

  def _get_active_tools(self, profile_run_dir):
    """Lists the tools that have at least one data file in a profile run dir.

    Args:
      profile_run_dir: Directory holding the data files of one profile run.

    Returns:
      The names of the available tools, sorted, with 'overview_page' moved to
      the front when present. Only one of 'trace_viewer' / 'trace_viewer@' is
      kept when streaming trace data exists.
    """
    tools = []
    for tool, suffix in six.iteritems(TOOLS):
      path = os.path.join(profile_run_dir, '*' + suffix)
      try:
        files = tf.io.gfile.glob(path)
      except tf.errors.OpError as e:
        logger.warning("Cannot read asset directory: %s, OpError %s",
                       profile_run_dir, e)
        continue
      if files:
        tools.append(tool)
    if 'trace_viewer@' in tools:
      # streaming trace viewer always override normal trace viewer.
      # the trailing '@' is to inform tf-profile-dashboard.html and
      # tf-trace-viewer.html that stream trace viewer should be used.
      removed_tool = 'trace_viewer@' if self.stub is None else 'trace_viewer'
      if removed_tool in tools:
        tools.remove(removed_tool)
    tools.sort()
    op = 'overview_page'
    if op in tools:
      # keep overview page at the top of the list
      tools.remove(op)
      tools.insert(0, op)
    return tools

  @wrappers.Request.application
  def tools_route(self, request):
    """Responds with a JSON object mapping frontend run names to tool lists."""
    run_to_tools = dict(self.generate_run_to_tools())
    return http_util.Respond(request, run_to_tools, 'application/json')

  def host_impl(self, run, tool):
    """Returns available hosts for the run and tool in the log directory.

    The run is resolved to a profile run directory with _run_dir(). Files in
    that directory contain data for different tools and hosts; the file that
    contains the profile for a specific tool "x" has the suffix TOOLS["x"],
    and whatever precedes that suffix in the file name is reported as the
    host.

    Example:
      For a run directory containing host1.trace and host2.trace, and tool
      'trace_viewer' (suffix 'trace'), the result is ["host1.", "host2."].

    Args:
      run: The frontend run name (see _run_dir()), or None.
      tool: The tool name, expected to be a key of TOOLS.

    Returns:
      A list of host names. The list is empty when the run or tool is missing
      or unknown, or when the run directory cannot be read.
    """
    hosts = []
    if run is None or tool not in TOOLS:
      return hosts
    try:
      run_dir = self._run_dir(run)
    except RuntimeError:
      # _run_dir() signals an unknown run by raising; treat it like a missing
      # directory instead of failing the request.
      run_dir = None
    if not run_dir:
      logger.warning("Cannot find asset directory for: %s", run)
      return hosts
    suffix = TOOLS[tool]
    try:
      files = tf.io.gfile.glob(os.path.join(run_dir, '*' + suffix))
    except tf.errors.OpError as e:
      logger.warning("Cannot read asset directory: %s, OpError %s",
                     run_dir, e)
      return hosts
    for path in files:
      name = os.path.basename(path)
      # Strip only the trailing tool suffix; a host name may itself contain
      # the suffix text.
      if name.endswith(suffix):
        name = name[:-len(suffix)]
      hosts.append(name)
    return hosts

  @wrappers.Request.application
  def hosts_route(self, request):
    """Responds with the JSON list of hosts for the 'run' and 'tag' args."""
    run = request.args.get('run')
    tool = request.args.get('tag')
    hosts = self.host_impl(run, tool)
    return http_util.Respond(request, hosts, 'application/json')

  def data_impl(self, request):
    """Retrieves and processes the tool data for a run and a host.

    Args:
      request: XMLHttpRequest

    Returns:
      A string that can be served to the frontend tool or None if tool,
      run or host is invalid.
    """
    run = request.args.get('run')
    tool = request.args.get('tag')
    host = request.args.get('host')

    # Validate the cheap inputs before touching the filesystem.
    if run is None or tool not in TOOLS:
      return None
    try:
      run_dir = self._run_dir(run)
    except RuntimeError as e:
      logger.warning("Cannot find asset directory for: %s (%s)", run, e)
      return None
    # Profile plugin "run" is the last component of run dir.
    profile_run = os.path.basename(run_dir)

    self.start_grpc_stub_if_necessary()
    if tool == 'trace_viewer@' and self.stub is not None:
      if host is None:
        return None
      from tensorflow.contrib.tpu.profiler import tpu_profiler_analysis_pb2
      grpc_request = tpu_profiler_analysis_pb2.ProfileSessionDataRequest()
      # NOTE(review): repository_root is set to the run directory itself, and
      # session_id is that directory's last component -- confirm that the
      # profiler service does not expect the parent directory here.
      grpc_request.repository_root = run_dir
      # _run_dir() already strips trailing slashes, so only a stray slash is
      # removed here; unconditionally dropping the last character would
      # truncate the session id.
      grpc_request.session_id = profile_run.rstrip('/')
      grpc_request.tool_name = 'trace_viewer'
      # Remove the trailing dot if present
      grpc_request.host_name = host.rstrip('.')

      # Forward only the parameters the frontend actually supplied.
      for param in ('resolution', 'start_time_ms', 'end_time_ms'):
        value = request.args.get(param)
        if value is not None:
          grpc_request.parameters[param] = value
      grpc_response = self.stub.GetSessionToolData(grpc_request)
      return grpc_response.output

    tool_name = str(host) + TOOLS[tool]
    asset_path = os.path.join(run_dir, tool_name)
    raw_data = None
    try:
      with tf.io.gfile.GFile(asset_path, 'rb') as f:
        raw_data = f.read()
    except tf.errors.NotFoundError:
      logger.warning('Asset path %s not found', asset_path)
    except tf.errors.OpError as e:
      logger.warning("Couldn't read asset path: %s, OpError %s", asset_path, e)

    if raw_data is None:
      return None
    if tool == 'trace_viewer':
      return process_raw_trace(raw_data)
    if tool in _RAW_DATA_TOOLS:
      return raw_data
    return None

  @wrappers.Request.application
  def data_route(self, request):
    """Responds with the tool data, or a 404 when data_impl() yields None."""
    # params
    #   request: XMLHTTPRequest.
    data = self.data_impl(request)
    if data is None:
      return http_util.Respond(request, '404 Not Found', 'text/plain', code=404)
    return http_util.Respond(request, data, 'application/json')

  def get_plugin_apps(self):
    """Returns the mapping of this plugin's HTTP routes to their handlers."""
    return {
        TOOLS_ROUTE: self.tools_route,
        HOSTS_ROUTE: self.hosts_route,
        DATA_ROUTE: self.data_route,
    }