Skip to content

Commit

Permalink
feature/add cache to reduce disk reading frequency (#169)
Browse files Browse the repository at this point in the history
  • Loading branch information
Superjomn authored Jan 16, 2018
1 parent 4f41b19 commit 6dffdeb
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 27 deletions.
19 changes: 19 additions & 0 deletions demo/vdl_scratch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@
import random
import subprocess


import numpy as np
from PIL import Image
from scipy.stats import norm
from visualdl import ROOT, LogWriter
from visualdl.server.log import logger as log

logdir = './scratch_log'

Expand Down Expand Up @@ -92,3 +94,20 @@
data = np.random.random(shape).flatten()
image0.add_sample(shape, list(data))
image0.finish_sampling()

def download_graph_image():
'''
This is a scratch demo, it do not generate a ONNX proto, but just download an image
that generated before to show how the graph frontend works.
For real cases, just refer to README.
'''
import urllib
image_url = "https://github.com/PaddlePaddle/VisualDL/blob/develop/demo/mxnet/super_resolution_graph.png?raw=true"
log.warning('download graph demo from {}'.format(image_url))
graph_image = urllib.urlopen(image_url).read()
with open(os.path.join(logdir, 'graph.jpg'), 'wb') as f:
f.write(graph_image)
log.warning('graph ready!')

download_graph_image()
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def readlines(name):
VERSION_NUMBER = read('VERSION_NUMBER')
LICENSE = readlines('LICENSE')[0].strip()

# use memcache to reduce disk read frequency.
install_requires = ['Flask', 'numpy', 'Pillow', 'protobuf', 'scipy']
execute_requires = ['npm', 'node', 'bash']

Expand Down
1 change: 1 addition & 0 deletions visualdl/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ function(py_test TARGET_NAME)
endfunction()

py_test(test_summary SRCS test_storage.py)
py_test(test_cache SRCS cache.py)
54 changes: 54 additions & 0 deletions visualdl/python/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import time


class MemCache(object):
class Record:
def __init__(self, value):
self.time = time.time()
self.value = value

def clear(self):
self.value = None

def expired(self, timeout):
return timeout > 0 and time.time() - self.time >= timeout
'''
A global dict to help cache some temporary data.
'''
def __init__(self, timeout=-1):
self._timeout = timeout
self._data = {}

def set(self, key, value):
self._data[key] = MemCache.Record(value)

def get(self, key):
rcd = self._data.get(key, None)
if not rcd: return None
# do not delete the key to accelerate speed
if rcd.expired(self._timeout):
rcd.clear()
return None
return rcd.value

if __name__ == '__main__':
import unittest

class TestMemCacheTest(unittest.TestCase):
def setUp(self):
self.cache = MemCache(timeout=1)

def expire(self):
self.cache.set("message", "hello")
self.assertFalse(self.cache.expired(1))
time.sleep(4)
self.assertTrue(self.cache.expired(1))

def test_have_key(self):
self.cache.set('message', 'hello')
self.assertTrue(self.cache.get('message'))
time.sleep(1.1)
self.assertFalse(self.cache.get('message'))
self.assertTrue(self.cache.get("message") is None)

unittest.main()
2 changes: 1 addition & 1 deletion visualdl/python/test_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@

import numpy as np
from PIL import Image
from visualdl import LogReader, LogWriter

pprint.pprint(sys.path)

from visualdl import LogWriter, LogReader


class StorageTest(unittest.TestCase):
Expand Down
14 changes: 12 additions & 2 deletions visualdl/server/lib.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pprint
import re
import sys
import time
Expand All @@ -7,6 +6,7 @@

import numpy as np
from PIL import Image

from log import logger


Expand Down Expand Up @@ -90,7 +90,6 @@ def get_image_tags(storage):


def get_image_tag_steps(storage, mode, tag):
print 'image_tag_steps,mode,tag:', mode, tag
# remove suffix '/x'
res = re.search(r".*/([0-9]+$)", tag)
sample_index = 0
Expand Down Expand Up @@ -211,3 +210,14 @@ def retry(ntimes, function, time2sleep, *args, **kwargs):
error_info = '\n'.join(map(str, sys.exc_info()))
logger.error("Unexpected error: %s" % error_info)
time.sleep(time2sleep)

def cache_get(cache):
def _handler(key, func, *args, **kwargs):
data = cache.get(key)
if data is None:
logger.warning('update cache %s' % key)
data = func(*args, **kwargs)
cache.set(key, data)
return data
return data
return _handler
64 changes: 40 additions & 24 deletions visualdl/server/visualDL
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ from visualdl.server import lib
from visualdl.server.log import logger
from visualdl.server.mock import data as mock_data
from visualdl.server.mock import data as mock_tags
from visualdl.python.cache import MemCache
from visualdl.python.storage import (LogWriter, LogReader)

app = Flask(__name__, static_url_path="")
Expand All @@ -33,7 +34,7 @@ def try_call(function, *args, **kwargs):
res = lib.retry(error_retry_times, function, error_sleep_time, *args,
**kwargs)
if not res:
raise exceptions.IOError("server IO error, will retry latter.")
logger.error("server temporary error, will retry latter.")
return res


Expand Down Expand Up @@ -70,6 +71,14 @@ def parse_args():
action="store",
dest="logdir",
help="log file directory")
parser.add_argument(
"--cache_timeout",
action="store",
dest="cache_timeout",
type=float,
default=20,
help="memory cache timeout duration in seconds, default 20",
)
args = parser.parse_args()
if not args.logdir:
parser.print_help()
Expand All @@ -86,8 +95,11 @@ log_reader = LogReader(args.logdir)

# mannully put graph's image on this path also works.
graph_image_path = os.path.join(args.logdir, 'graph.jpg')
# use a memory cache to reduce disk reading frequency.
CACHE = MemCache(timeout=args.cache_timeout)
cache_get = lib.cache_get(CACHE)


# return data
# status, msg, data
def gen_result(status, msg, data):
"""
Expand Down Expand Up @@ -126,52 +138,54 @@ def logdir():

@app.route('/data/runs')
def runs():
result = gen_result(0, "", lib.get_modes(log_reader))
data = cache_get('/data/runs', lib.get_modes, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route("/data/plugin/scalars/tags")
def scalar_tags():
mode = request.args.get('mode')
is_debug = bool(request.args.get('debug'))
result = try_call(lib.get_scalar_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/scalars/tags", try_call,
lib.get_scalar_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route("/data/plugin/images/tags")
def image_tags():
mode = request.args.get('run')
result = try_call(lib.get_image_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/images/tags", try_call, lib.get_image_tags,
log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route("/data/plugin/histograms/tags")
def histogram_tags():
mode = request.args.get('run')
# hack to avlid IO conflicts
result = try_call(lib.get_histogram_tags, log_reader)
result = gen_result(0, "", result)
data = cache_get("/data/plugin/histograms/tags", try_call,
lib.get_histogram_tags, log_reader)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route('/data/plugin/scalars/scalars')
def scalars():
run = request.args.get('run')
tag = request.args.get('tag')
result = try_call(lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", result)
key = os.path.join('/data/plugin/scalars/scalars', run, tag)
data = cache_get(key, try_call, lib.get_scalar, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


@app.route('/data/plugin/images/images')
def images():
mode = request.args.get('run')
tag = request.args.get('tag')
key = os.path.join('/data/plugin/images/images', mode, tag)

result = try_call(lib.get_image_tag_steps, log_reader, mode, tag)
result = gen_result(0, "", result)
data = cache_get(key, try_call, lib.get_image_tag_steps, log_reader, mode,
tag)
result = gen_result(0, "", data)

return Response(json.dumps(result), mimetype='application/json')

Expand All @@ -181,21 +195,23 @@ def individual_image():
mode = request.args.get('run')
tag = request.args.get('tag') # include a index
step_index = int(request.args.get('index')) # index of step
offset = 0

imagefile = try_call(lib.get_invididual_image, log_reader, mode, tag,
step_index)
key = os.path.join('/data/plugin/images/individualImage', mode, tag,
str(step_index))
data = cache_get(key, try_call, lib.get_invididual_image, log_reader, mode,
tag, step_index)
response = send_file(
imagefile, as_attachment=True, attachment_filename='img.png')
data, as_attachment=True, attachment_filename='img.png')
return response


@app.route('/data/plugin/histograms/histograms')
def histogram():
run = request.args.get('run')
tag = request.args.get('tag')
result = try_call(lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", result)
key = os.path.join('/data/plugin/histograms/histograms', run, tag)
data = cache_get(key, try_call, lib.get_histogram, log_reader, run, tag)
result = gen_result(0, "", data)
return Response(json.dumps(result), mimetype='application/json')


Expand Down

0 comments on commit 6dffdeb

Please sign in to comment.