Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

factored out make_rest_server_debug/prod #4268

Merged
merged 2 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ examples/models/autoscaling/autoscaling_example.py

# vscode local history
.history/
.vscode/

# python venv
venv/
168 changes: 107 additions & 61 deletions python/seldon_core/microservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import time
from distutils.util import strtobool
from functools import partial
from typing import Callable, Dict
from typing import Any, Callable, Dict, List

from seldon_core import __version__
from seldon_core import wrapper as seldon_microservice
Expand Down Expand Up @@ -353,6 +353,102 @@ def parse_args():
return parser.parse_known_args()


def _make_rest_server_debug(
    user_object: Any,
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
    *,
    jaeger_extra_tags: List[str],
) -> Callable[[], None]:
    """Makes a function that creates a REST debugging server.

    Nothing is started by this call itself: the REST app is only built and
    run when the returned callable is invoked.

    Args:
        user_object: an instance of user-defined class, inherited from
            user_model.SeldonComponent.
        seldon_metrics: a SeldonMetrics instance.
        args: parsed args from commandline. The server reads ``tracing``,
            ``interface_name``, ``http_port`` and ``single_threaded``.
        jaeger_extra_tags: extra tags handed to ``FlaskTracing`` when
            ``args.tracing`` is enabled.

    Returns:
        A zero-argument callable that loads the user model and runs the
        app's built-in development server.
    """

    def server():
        app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)

        # Best effort: a user object that lacks load() (AttributeError) or
        # does not implement it (NotImplementedError) is still served.
        # NOTE(review): this also hides an AttributeError raised *inside* a
        # user's load() -- confirm that is intended.
        try:
            user_object.load()
        except (NotImplementedError, AttributeError):
            pass

        if args.tracing:
            logger.info("Tracing branch is active")
            # Imported lazily so the dependency is only needed when tracing
            # is switched on.
            from flask_opentracing import FlaskTracing

            tracer = setup_tracing(args.interface_name)

            logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
            FlaskTracing(tracer, True, app, jaeger_extra_tags)

        # Timeout not supported in flask development server
        app.run(
            host="0.0.0.0",
            port=args.http_port,
            threaded=not args.single_threaded,
        )

    return server


def _make_rest_server_prod(
    user_object: Any,
    seldon_metrics: SeldonMetrics,
    args: argparse.Namespace,
    *,
    jaeger_extra_tags: List[str],
    annotations: Dict[str, str],
) -> Callable[[], None]:
    """Makes a function that creates a REST production server.

    Nothing is started by this call itself: the Gunicorn options and the
    REST app are only built when the returned callable is invoked.

    Args:
        user_object: an instance of user-defined class, inherited from
            user_model.SeldonComponent.
        seldon_metrics: a SeldonMetrics instance.
        args: parsed args from commandline. The server reads ``http_port``,
            ``access_log``, ``log_level``, ``threads``, ``single_threaded``,
            ``workers``, ``max_requests``, ``max_requests_jitter``,
            ``keepalive``, ``pidfile``, ``tracing`` and ``interface_name``.
        jaeger_extra_tags: extra tags passed through to
            ``UserModelApplication`` together with ``args.tracing``.
        annotations: annotation key/value pairs; only
            ``ANNOTATION_REST_TIMEOUT`` (a timeout in milliseconds) is
            consulted here.

    Returns:
        A zero-argument callable that configures and runs
        ``UserModelApplication``.
    """

    def server():
        rest_timeout = DEFAULT_ANNOTATION_REST_TIMEOUT
        if ANNOTATION_REST_TIMEOUT in annotations:
            # Gunicorn timeout is in seconds so convert as annotation is in milliseconds
            rest_timeout = int(annotations[ANNOTATION_REST_TIMEOUT]) / 1000
            # Converting timeout from float to int and set to 1 if is 0
            rest_timeout = int(rest_timeout) or 1

        options = {
            "bind": f"0.0.0.0:{args.http_port}",
            "accesslog": accesslog(args.access_log),
            "loglevel": args.log_level.lower(),
            "timeout": rest_timeout,
            "threads": threads(args.threads, args.single_threaded),
            "workers": args.workers,
            "max_requests": args.max_requests,
            "max_requests_jitter": args.max_requests_jitter,
            "post_worker_init": post_worker_init,
            "worker_exit": partial(worker_exit, seldon_metrics=seldon_metrics),
            "keepalive": args.keepalive,
        }
        # Add the optional pidfile before logging so the logged configuration
        # is the one that is actually handed to the application.
        if args.pidfile is not None:
            options["pidfile"] = args.pidfile
        logger.info("Gunicorn Config: %s", options)

        app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)

        UserModelApplication(
            app,
            user_object,
            args.tracing,
            jaeger_extra_tags,
            args.interface_name,
            options=options,
        ).run()

    return server


def main():
LOG_FORMAT = (
"%(asctime)s - %(name)s:%(funcName)s:%(lineno)s - %(levelname)s: %(message)s"
Expand Down Expand Up @@ -412,74 +508,24 @@ def main():
# )
if args.debug:
# Start Flask debug server
def rest_prediction_server():
app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)
try:
user_object.load()
except (NotImplementedError, AttributeError):
pass
if args.tracing:
logger.info("Tracing branch is active")
from flask_opentracing import FlaskTracing

tracer = setup_tracing(args.interface_name)

logger.info("Set JAEGER_EXTRA_TAGS %s", jaeger_extra_tags)
FlaskTracing(tracer, True, app, jaeger_extra_tags)

# Timeout not supported in flask development server
app.run(
host="0.0.0.0",
port=http_port,
threaded=False if args.single_threaded else True,
)

logger.info(
"REST microservice running on port %i single-threaded=%s",
http_port,
args.single_threaded,
)
server_rest_func = rest_prediction_server
server_rest_func = _make_rest_server_debug(
user_object, seldon_metrics, args, jaeger_extra_tags=jaeger_extra_tags
)
else:
# Start production server
def rest_prediction_server():
rest_timeout = DEFAULT_ANNOTATION_REST_TIMEOUT
if ANNOTATION_REST_TIMEOUT in annotations:
# Gunicorn timeout is in seconds so convert as annotation is in miliseconds
rest_timeout = int(annotations[ANNOTATION_REST_TIMEOUT]) / 1000
# Converting timeout from float to int and set to 1 if is 0
rest_timeout = int(rest_timeout) or 1

options = {
"bind": "%s:%s" % ("0.0.0.0", http_port),
"accesslog": accesslog(args.access_log),
"loglevel": args.log_level.lower(),
"timeout": rest_timeout,
"threads": threads(args.threads, args.single_threaded),
"workers": args.workers,
"max_requests": args.max_requests,
"max_requests_jitter": args.max_requests_jitter,
"post_worker_init": post_worker_init,
"worker_exit": partial(worker_exit, seldon_metrics=seldon_metrics),
"keepalive": args.keepalive,
}
logger.info(f"Gunicorn Config: {options}")

if args.pidfile is not None:
options["pidfile"] = args.pidfile
app = seldon_microservice.get_rest_microservice(user_object, seldon_metrics)

UserModelApplication(
app,
user_object,
args.tracing,
jaeger_extra_tags,
args.interface_name,
options=options,
).run()

logger.info("REST gunicorn microservice running on port %i", http_port)
server_rest_func = rest_prediction_server
server_rest_func = _make_rest_server_prod(
user_object,
seldon_metrics,
args,
jaeger_extra_tags=jaeger_extra_tags,
annotations=annotations,
)

def _wait_forever(server):
try:
Expand Down