-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
66 lines (48 loc) · 1.87 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import signal
import algorithm
import argparse
import flask
import cherrypy
# Command-line interface: one positional "mode" argument that selects either
# offline training or the HTTP prediction server.
_MODE_HELP = (
    "train: The algorithm will be run in train mode. "
    "Trained models will be stored for later use. "
    "serve: a web server will start on port 8080 "
    "to listen for prediction requests."
)
parser = argparse.ArgumentParser()
parser.add_argument(
    "mode",
    help=_MODE_HELP,
    choices=['train', 'serve'],
)
def _build_app():
    """Create the Flask app exposing the inference HTTP endpoints.

    Routes:
      GET  /ping        -- health check, always answers HTTP 200 with body 'OK'.
      POST /invocations -- runs ``algorithm.Predict`` on the posted JSON body and
                           returns its result as JSON.

    Expects ``algorithm.Init('predict')`` to have been called already (``main``
    does this before building the app).
    """
    import traceback

    app = flask.Flask(__name__)

    # Health check polled by AWS SageMaker; it must answer HTTP 200.
    # NOTE(review): SageMaker's container contract describes an empty body,
    # while this returns 'OK' -- confirm before changing the payload.
    @app.route('/ping')
    def ping():
        return 'OK'

    @app.route('/invocations', methods=['POST'])
    def prediction():
        # silent=True makes get_json() return None instead of raising when the
        # body is not valid JSON or the Content-Type is not JSON, so bad client
        # input is reported as 400 rather than as a misleading 500.
        features = flask.request.get_json(silent=True)
        # Falsy payloads (None, {}, [], 0, "") are rejected as well.
        if not features:
            return 'must post json payload with content-type header', 400
        try:
            result = algorithm.Predict(features)
        except Exception as e:
            # Top-level request boundary: hide details from the client, but
            # keep the full traceback in the server logs for debugging.
            print('exception', e)
            traceback.print_exc()
            return 'internal server error', 500
        return flask.jsonify(result)

    return app


def _serve(app):
    """Serve *app* with CherryPy on 0.0.0.0:8080, blocking until SIGINT/SIGTERM."""
    def shutdown(signum, frame):
        # engine.exit() stops the HTTP server and moves the bus to EXITING,
        # which is the state that releases engine.block() below.
        cherrypy.engine.exit()

    cherrypy.tree.graft(app, '/')
    cherrypy.config.update({
        'server.socket_host': '0.0.0.0',
        'server.socket_port': 8080,
        'engine.autoreload.on': False,
    })
    # Install the handlers before starting so that a signal arriving during
    # start-up already goes through the engine's shutdown path.
    signal.signal(signal.SIGINT, shutdown)
    signal.signal(signal.SIGTERM, shutdown)
    cherrypy.engine.start()
    # Keep the main thread alive (and able to run the signal handlers) until
    # the engine has exited and its worker threads have finished.
    cherrypy.engine.block()


def main(args):
    """Run the algorithm in the mode selected on the command line.

    Args:
        args: parsed CLI namespace whose ``mode`` is ``'train'`` or ``'serve'``.

    ``'train'`` initialises the algorithm in train mode and runs training.
    ``'serve'`` initialises it in predict mode, then serves HTTP prediction
    requests on port 8080 until the process receives SIGINT or SIGTERM.
    """
    print(args.mode)
    if args.mode == 'train':
        algorithm.Init('train')
        algorithm.Train()
    elif args.mode == 'serve':
        algorithm.Init('predict')
        _serve(_build_app())
# Script entry point: parse the command line, then dispatch on the chosen mode.
if __name__ == "__main__":
    cli_args = parser.parse_args()
    main(cli_args)