feature/add cache to reduce disk reading frequency #169

Merged
19 changes: 19 additions & 0 deletions demo/vdl_scratch.py
@@ -4,10 +4,12 @@
import random
import subprocess


import numpy as np
from PIL import Image
from scipy.stats import norm
from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log

logdir = './scratch_log'

@@ -92,3 +94,20 @@
data = np.random.random(shape).flatten()
image0.add_sample(shape, list(data))
image0.finish_sampling()

def download_graph_image():
    '''
    This is a scratch demo; it does not generate an ONNX proto, but just downloads a
    previously generated image to show how the graph frontend works.
    For real use cases, refer to the README.
    '''
    import urllib
    image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
    log.warning('download graph demo from {}'.format(image_url))
    graph_image = urllib.urlopen(image_url).read()
    with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
        f.write(graph_image)
    log.warning('graph ready!')


download_graph_image()
1 change: 1 addition & 0 deletions setup.py
@@ -26,6 +26,7 @@ def readlines(name):
VERSION_NUMBER = read('VERSION_NUMBER')
LICENSE = readlines('LICENSE')[0].strip()

# use an in-memory cache (MemCache) to reduce disk read frequency.
install_requires = ['Flask', 'numpy', 'Pillow', 'protobuf', 'scipy']
execute_requires = ['npm', 'node', 'bash']

1 change: 1 addition & 0 deletions visualdl/python/CMakeLists.txt
@@ -25,3 +25,4 @@ function(py_test TARGET_NAME)
endfunction()

py_test(test_summary SRCS test_storage.py)
py_test(test_cache SRCS cache.py)
54 changes: 54 additions & 0 deletions visualdl/python/cache.py
@@ -0,0 +1,54 @@
import time


class MemCache(object):
    '''
    A global dict to help cache some temporary data.
    '''

    class Record(object):
        def __init__(self, value):
            self.time = time.time()
            self.value = value

        def clear(self):
            self.value = None

        def expired(self, timeout):
            return timeout > 0 and time.time() - self.time >= timeout

    def __init__(self, timeout=-1):
        self._timeout = timeout
        self._data = {}

    def set(self, key, value):
        self._data[key] = MemCache.Record(value)

    def get(self, key):
        rcd = self._data.get(key, None)
        if not rcd:
            return None
        # do not delete the key on expiry, to keep lookups fast
        if rcd.expired(self._timeout):
            rcd.clear()
            return None
        return rcd.value


if __name__ == '__main__':
    import unittest

    class TestMemCacheTest(unittest.TestCase):
        def setUp(self):
            self.cache = MemCache(timeout=1)

        def test_expire(self):
            self.cache.set("message", "hello")
            record = self.cache._data["message"]
            self.assertFalse(record.expired(self.cache._timeout))
            time.sleep(1.1)
            self.assertTrue(record.expired(self.cache._timeout))

        def test_have_key(self):
            self.cache.set('message', 'hello')
            self.assertTrue(self.cache.get('message'))
            time.sleep(1.1)
            self.assertFalse(self.cache.get('message'))
            self.assertTrue(self.cache.get("message") is None)

    unittest.main()
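For reference, a minimal usage sketch of the new MemCache (the key name, the 5-second timeout, and the sleep are illustrative, not part of the change): entries read back normally until the timeout passes, after which get() returns None and the caller is expected to recompute and set() the value again.

import time

from visualdl.python.cache import MemCache

cache = MemCache(timeout=5)            # entries older than 5 seconds read as missing
cache.set('modes', ['train', 'test'])
print(cache.get('modes'))              # ['train', 'test'] while the entry is fresh
time.sleep(6)
print(cache.get('modes'))              # None: the record stays in the dict but is lazily cleared
cache.set('modes', ['train', 'test'])  # on a miss, the caller refreshes the entry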
2 changes: 1 addition & 1 deletion visualdl/python/test_storage.py
@@ -4,10 +4,10 @@

import numpy as np
from PIL import Image
from visualdl import LogReader, LogWriter

pprint.pprint(sys.path)

from visualdl import LogWriter, LogReader


class StorageTest(unittest.TestCase):
14 changes: 12 additions & 2 deletions visualdl/server/lib.py
@@ -1,4 +1,3 @@
import pprint
import re
import sys
import time
@@ -7,6 +6,7 @@

import numpy as np
from PIL import Image

from log import logger


@@ -90,7 +90,6 @@ def get_image_tags(storage):


def get_image_tag_steps(storage, mode, tag):
print 'image_tag_steps,mode,tag:', mode, tag
# remove suffix '/x'
res = re.search(r".*/([0-9]+$)", tag)
sample_index = 0
@@ -211,3 +210,14 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep)

def cache_get(cache):
    def _handler(key, func, *args, **kwargs):
        data = cache.get(key)
        if data is None:
            logger.warning('update cache %s' % key)
            data = func(*args, **kwargs)
            cache.set(key, data)
        return data

    return _handler
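A minimal sketch of how the handler returned by cache_get is meant to be wired up (load_runs below is invented for illustration; the real call sites in visualdl/server/visualDL pass the route path as the key and functions such as lib.get_modes):

from visualdl.python.cache import MemCache
from visualdl.server import lib


def load_runs():
    # stand-in for an expensive disk read such as lib.get_modes(log_reader)
    return ['train', 'test']


cache = MemCache(timeout=20)
get_cached = lib.cache_get(cache)

# the first call misses, invokes load_runs() and stores the result under the key;
# repeated calls within 20 seconds return the cached value without touching disk
runs = get_cached('/data/runs', load_runs)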
64 changes: 40 additions & 24 deletions visualdl/server/visualDL
@@ -17,6 +17,7 @@ from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader)

app = Flask(__name__, static_url_path="")
@@ -33,7 +34,7 @@ def try_call(function, *args, **kwargs):
res = lib.retry(error_retry_times, function, error_sleep_time, *args,
**kwargs)
if not res:
raise exceptions.IOError("server IO error, will retry latter.")
logger.error("server temporary error, will retry latter.")
return res


@@ -70,6 +71,14 @@ def parse_args():
action="store",
dest="logdir",
help="log file directory")
parser.add_argument(
"--cache_timeout",
action="store",
dest="cache_timeout",
type=float,
default=20,
help="memory cache timeout duration in seconds, default 20",
)
args = parser.parse_args()
if not args.logdir:
parser.print_help()
@@ -86,8 +95,11 @@ log_reader = LogReader(args.logdir)

# manually putting the graph's image on this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg')
# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)


# return data
# status, msg, data
def gen_result(status, msg, data):
"""
@@ -126,52 +138,54 @@ def logdir():

@app.route('/data/runs')
def runs():
result = gen_result(0, "", lib.get_modes(log_reader))
data = cache_get('/data/runs', lib.get_modes, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route("/data/plugin/scalars/tags")
def scalar_tags():
mode = request.args.get('mode')
is_debug = bool(request.args.get('debug'))
result = try_call(lib.get_scalar_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/scalars/tags", try_call,
lib.get_scalar_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route("/data/plugin/images/tags")
def image_tags():
mode = request.args.get('run')
result = try_call(lib.get_image_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route("/data/plugin/histograms/tags")
def histogram_tags():
mode = request.args.get('run')
# hack to avoid IO conflicts
result = try_call(lib.get_histogram_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/histograms/tags", try_call,
lib.get_histogram_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route('/data/plugin/scalars/scalars')
def scalars():
run = request.args.get('run')
tag = request.args.get('tag')
result = try_call(lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", result)
key = os.path.join('/data/plugin/scalars/scalars', run, tag)
data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route('/data/plugin/images/images')
def images():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/images/images', mode, tag)

result = try_call(lib.get_image_tag_steps, log_reader, mode, tag)
result = gen_result(0, "", result)
data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)

return Response(json.dumps(result), mimetype='application/json')

@@ -181,21 +195,23 @@ def individual_image():
mode = request.args.get('run')
tag = request.args.get('tag')  # includes an index
step_index = int(request.args.get('index')) # index of step
offset = 0

imagefile = try_call(lib.get_invididual_image, log_reader, mode, tag,
step_index)
key = os.path.join('/data/plugin/images/individualImage', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
tag, step_index)
response = send_file(
imagefile, as_attachment=True, attachment_filename='img.png')
data, as_attachment=True, attachment_filename='img.png')
return response


@app.route('/data/plugin/histograms/histograms')
def histogram():
run = request.args.get('run')
tag = request.args.get('tag')
result = try_call(lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", result)
key = os.path.join('/data/plugin/histograms/histograms', run, tag)
data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


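Note on the new option: --cache_timeout sets how long, in seconds, a MemCache entry is served before the server re-reads the underlying log data; it defaults to 20. A hypothetical invocation with a custom timeout (the log directory path is illustrative): visualDL --logdir ./scratch_log --cache_timeout 30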