diff --git a/Makefile b/Makefile index 68b72008..92c1ca9b 100644 --- a/Makefile +++ b/Makefile @@ -13,7 +13,7 @@ dev: cp ./target/debug/mosec mosec/bin/ pip install -e . -test: +test: dev pytest tests -vv -s RUST_BACKTRACE=1 cargo test -vv @@ -44,5 +44,6 @@ lint: flake8 ${PY_SOURCE_FILES} --count --show-source --statistics mypy --install-types --non-interactive ${PY_SOURCE_FILES} cargo +nightly fmt -- --check + cargo clippy .PHONY: test doc diff --git a/examples/echo.py b/examples/echo.py index bc4feb55..f83f6b36 100644 --- a/examples/echo.py +++ b/examples/echo.py @@ -21,7 +21,7 @@ def forward(self, data: dict) -> float: try: time = float(data["time"]) except KeyError as err: - raise ValidationError(err) + raise ValidationError(f"cannot find key {err}") return time diff --git a/examples/resnet50_server_pytorch.py b/examples/resnet50_server_pytorch.py index e299b3fe..d7f272f9 100644 --- a/examples/resnet50_server_pytorch.py +++ b/examples/resnet50_server_pytorch.py @@ -32,7 +32,7 @@ def forward(self, req: dict) -> np.ndarray: im = np.frombuffer(base64.b64decode(image), np.uint8) im = cv2.imdecode(im, cv2.IMREAD_COLOR)[:, :, ::-1] # bgr -> rgb except KeyError as err: - raise ValidationError(f"bad request: {err}") + raise ValidationError(f"cannot find key {err}") except Exception as err: raise ValidationError(f"cannot decode as image data: {err}") diff --git a/mosec/coordinator.py b/mosec/coordinator.py index 757e12dd..03aef9af 100644 --- a/mosec/coordinator.py +++ b/mosec/coordinator.py @@ -8,7 +8,7 @@ from multiprocessing.synchronize import Event from typing import Callable, Type -from .errors import ValidationError +from .errors import DecodingError, ValidationError from .protocol import Protocol from .worker import Worker @@ -160,15 +160,24 @@ def coordinate(self): "returned data doesn't match the input data:" f"input({len(data)})!=output({len(payloads)})" ) - except ValidationError as err: + except DecodingError as err: err_msg = str(err).replace("\n", " - ") + err_msg = ( + err_msg if len(err_msg) else "cannot deserialize request bytes" + ) + logger.info(f"{self.name} decoding error: {err_msg}") + status = self.protocol.FLAG_BAD_REQUEST + payloads = (f"decoding error: {err_msg}".encode(),) + except ValidationError as err: + err_msg = str(err) + err_msg = err_msg if len(err_msg) else "invalid data format" logger.info(f"{self.name} validation error: {err_msg}") status = self.protocol.FLAG_VALIDATION_ERROR - payloads = (f"Validation Error: {err_msg}".encode(),) + payloads = (f"validation error: {err_msg}".encode(),) except Exception: logger.warning(traceback.format_exc().replace("\n", " ")) status = self.protocol.FLAG_INTERNAL_ERROR - payloads = ("Internal Error".encode(),) + payloads = ("inference internal error".encode(),) try: self.protocol.send(status, ids, payloads) diff --git a/mosec/errors.py b/mosec/errors.py index 77e2557c..a9c71c15 100644 --- a/mosec/errors.py +++ b/mosec/errors.py @@ -1,7 +1,31 @@ +""" +Suppose the input dataflow of our model server is as follows: + +**bytes** --- *deserialize*(decoding) ---> **data** +--- *parse*(validation) ---> **valid data** + +If the raw bytes cannot be successfully deserialized, the `DecodingError` +is raised; if the decoded data cannot pass the validation check (usually +implemented by users), the `ValidationError` should be raised. +""" + + +class DecodingError(Exception): + """ + The `DecodingError` should be raised in user-implemented codes + when the de-serialization for the request bytes fails. This error + will set the status code to + [HTTP 400]("https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400) + in the response. + """ + + class ValidationError(Exception): """ The `ValidationError` should be raised in user-implemented codes, - where the validation for the input data fails. Usually it can be put - after the data deserialization, which converts the raw bytes into - structured data. + where the validation for the input data fails. Usually, it should be + put after the data de-serialization, which converts the raw bytes + into structured data. This error will set the status code to + [HTTP 422](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/422) + in the response. """ diff --git a/mosec/worker.py b/mosec/worker.py index 664d84b1..cd7b4c03 100644 --- a/mosec/worker.py +++ b/mosec/worker.py @@ -3,7 +3,7 @@ import pickle from typing import Any -from .errors import ValidationError +from .errors import DecodingError logger = logging.getLogger(__name__) @@ -89,7 +89,7 @@ def deserialize(self, data: bytes) -> Any: try: data_json = json.loads(data) if data else {} except Exception as err: - raise ValidationError(err) + raise DecodingError(err) return data_json def forward(self, data: Any) -> Any: diff --git a/src/coordinator.rs b/src/coordinator.rs index a404fcdd..68604329 100644 --- a/src/coordinator.rs +++ b/src/coordinator.rs @@ -27,7 +27,7 @@ impl Coordinator { let (sender, receiver) = bounded(opts.capacity); let timeout = Duration::from_millis(opts.timeout); let wait_time = Duration::from_millis(opts.wait); - let path = if opts.path.len() > 0 { + let path = if !opts.path.is_empty() { opts.path.to_string() } else { // default IPC path @@ -44,7 +44,7 @@ impl Coordinator { Self { capacity: opts.capacity, - path: path, + path, batches: opts.batches.clone(), wait_time, timeout, diff --git a/src/errors.rs b/src/errors.rs index 2de109b5..7ee8f6c5 100644 --- a/src/errors.rs +++ b/src/errors.rs @@ -5,21 +5,9 @@ pub(crate) enum ServiceError { #[display(fmt = "inference timeout")] Timeout, - #[display(fmt = "bad request")] - BadRequestError, - - #[display(fmt = "bad request: validation error")] - ValidationError, - - #[display(fmt = "inference internal error")] - InternalError, - - #[display(fmt = "too many request: channel is full")] + #[display(fmt = "too many request: task queue is full")] TooManyRequests, - #[display(fmt = "cannot accept new request during the graceful shutdown")] - GracefulShutdown, - #[display(fmt = "mosec unknown error")] UnknownError, } diff --git a/src/main.rs b/src/main.rs index 08d91ce2..b279d0ab 100644 --- a/src/main.rs +++ b/src/main.rs @@ -7,6 +7,7 @@ mod tasks; use std::net::SocketAddr; +use bytes::Bytes; use clap::Clap; use hyper::service::{make_service_fn, service_fn}; use hyper::{body::to_bytes, header::HeaderValue, Body, Method, Request, Response, StatusCode}; @@ -22,97 +23,102 @@ use crate::metrics::Metrics; use crate::tasks::{TaskCode, TaskManager}; const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")); -const NOT_FOUND: &[u8] = b"Not Found"; +const RESPONSE_DEFAULT: &[u8] = b"MOSEC service"; +const RESPONSE_NOT_FOUND: &[u8] = b"not found"; +const RESPONSE_EMPTY: &[u8] = b"no data provided"; +const RESPONSE_SHUTDOWN: &[u8] = b"gracefully shutting down"; -async fn index(_: Request) -> Result, ServiceError> { +async fn index(_: Request) -> Response { let task_manager = TaskManager::global(); if task_manager.is_shutdown() { - return Err(ServiceError::GracefulShutdown); + build_response( + StatusCode::SERVICE_UNAVAILABLE, + Bytes::from_static(RESPONSE_SHUTDOWN), + ) + } else { + build_response(StatusCode::OK, Bytes::from_static(RESPONSE_DEFAULT)) } - Ok(Response::new(Body::from("MOSEC service"))) } -async fn metrics(_: Request) -> Result, ServiceError> { +async fn metrics(_: Request) -> Response { let encoder = TextEncoder::new(); let metrics = prometheus::gather(); let mut buffer = vec![]; encoder.encode(&metrics, &mut buffer).unwrap(); - Ok(Response::new(Body::from(buffer))) + build_response(StatusCode::OK, Bytes::from(buffer)) } -async fn inference(req: Request) -> Result, ServiceError> { +async fn inference(req: Request) -> Response { let task_manager = TaskManager::global(); let data = to_bytes(req.into_body()).await.unwrap(); let metrics = Metrics::global(); + if task_manager.is_shutdown() { + return build_response( + StatusCode::SERVICE_UNAVAILABLE, + Bytes::from_static(RESPONSE_SHUTDOWN), + ); + } + if data.is_empty() { - return Ok(Response::new(Body::from("No data provided"))); + return build_response(StatusCode::OK, Bytes::from_static(RESPONSE_EMPTY)); } + let (status, content); metrics.remaining_task.inc(); - let task = task_manager.submit_task(data).await?; - match task.code { - TaskCode::Normal => { - metrics.remaining_task.dec(); - metrics - .duration - .with_label_values(&["total", "total"]) - .observe(task.create_at.elapsed().as_secs_f64()); - metrics - .throughput - .with_label_values(&[StatusCode::OK.as_str()]) - .inc(); - Ok(Response::new(Body::from(task.data))) + match task_manager.submit_task(data).await { + Ok(task) => { + content = task.data; + status = match task.code { + TaskCode::Normal => { + // Record latency only for successful tasks + metrics + .duration + .with_label_values(&["total", "total"]) + .observe(task.create_at.elapsed().as_secs_f64()); + StatusCode::OK + } + TaskCode::BadRequestError => StatusCode::BAD_REQUEST, + TaskCode::ValidationError => StatusCode::UNPROCESSABLE_ENTITY, + TaskCode::InternalError => StatusCode::INTERNAL_SERVER_ERROR, + } + } + Err(err) => { + // Handle errors for which tasks cannot be retrieved + content = Bytes::from(err.to_string()); + status = match err { + ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS, + ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT, + ServiceError::UnknownError => StatusCode::INTERNAL_SERVER_ERROR, + }; } - TaskCode::BadRequestError => Err(ServiceError::BadRequestError), - TaskCode::ValidationError => Err(ServiceError::ValidationError), - TaskCode::InternalError => Err(ServiceError::InternalError), - TaskCode::UnknownError => Err(ServiceError::UnknownError), } -} - -fn error_handler(err: ServiceError) -> Response { - let status = match err { - ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT, - ServiceError::BadRequestError => StatusCode::BAD_REQUEST, - ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS, - ServiceError::ValidationError => StatusCode::UNPROCESSABLE_ENTITY, - ServiceError::InternalError => StatusCode::INTERNAL_SERVER_ERROR, - ServiceError::GracefulShutdown => StatusCode::SERVICE_UNAVAILABLE, - ServiceError::UnknownError => StatusCode::NOT_IMPLEMENTED, - }; - let metrics = Metrics::global(); - metrics.remaining_task.dec(); metrics .throughput .with_label_values(&[status.as_str()]) .inc(); + build_response(status, content) +} + +fn build_response(status: StatusCode, content: Bytes) -> Response { Response::builder() .status(status) .header("server", HeaderValue::from_static(SERVER_INFO)) - .body(Body::from(err.to_string())) + .body(Body::from(content)) .unwrap() } async fn service_func(req: Request) -> Result, hyper::Error> { - let res = match (req.method(), req.uri().path()) { - (&Method::GET, "/") => index(req).await, - (&Method::GET, "/metrics") => metrics(req).await, - (&Method::POST, "/inference") => inference(req).await, - _ => Ok(Response::builder() - .status(StatusCode::NOT_FOUND) - .body(NOT_FOUND.into()) - .unwrap()), - }; - match res { - Ok(mut resp) => { - resp.headers_mut() - .insert("server", HeaderValue::from_static(SERVER_INFO)); - Ok(resp) - } - Err(err) => Ok(error_handler(err)), + match (req.method(), req.uri().path()) { + (&Method::GET, "/") => Ok(index(req).await), + (&Method::GET, "/metrics") => Ok(metrics(req).await), + (&Method::POST, "/inference") => Ok(inference(req).await), + _ => Ok(build_response( + StatusCode::NOT_FOUND, + Bytes::from(RESPONSE_NOT_FOUND), + )), } } diff --git a/src/protocol.rs b/src/protocol.rs index c71ac35e..49847eda 100644 --- a/src/protocol.rs +++ b/src/protocol.rs @@ -18,7 +18,7 @@ const LENGTH_U8_SIZE: usize = 4; const BIT_STATUS_OK: u16 = 0b1; const BIT_STATUS_BAD_REQ: u16 = 0b10; const BIT_STATUS_VALIDATION_ERR: u16 = 0b100; -const BIT_STATUS_INTERNAL_ERR: u16 = 0b1000; +// Others are treated as Internal Error pub(crate) async fn communicate( path: PathBuf, @@ -42,7 +42,7 @@ pub(crate) async fn communicate( Ok((mut stream, addr)) => { info!(?addr, "accepted connection from"); tokio::spawn(async move { - let mut code: TaskCode = TaskCode::UnknownError; + let mut code: TaskCode = TaskCode::InternalError; let mut ids: Vec = Vec::with_capacity(batch_size); let mut data: Vec = Vec::with_capacity(batch_size); let task_manager = TaskManager::global(); @@ -134,10 +134,8 @@ async fn read_message( TaskCode::BadRequestError } else if flag & BIT_STATUS_VALIDATION_ERR > 0 { TaskCode::ValidationError - } else if flag & BIT_STATUS_INTERNAL_ERR > 0 { - TaskCode::InternalError } else { - TaskCode::UnknownError + TaskCode::InternalError }; let mut id_buf = [0u8; TASK_ID_U8_SIZE]; @@ -271,7 +269,7 @@ mod tests { let mut stream = UnixStream::connect(&path).await.unwrap(); let mut recv_ids = Vec::new(); let mut recv_data = Vec::new(); - let mut code = TaskCode::UnknownError; + let mut code = TaskCode::InternalError; read_message(&mut stream, &mut code, &mut recv_ids, &mut recv_data) .await .expect("read message error"); diff --git a/src/tasks.rs b/src/tasks.rs index 35e74545..de552a98 100644 --- a/src/tasks.rs +++ b/src/tasks.rs @@ -13,7 +13,6 @@ use crate::errors::ServiceError; #[derive(Debug, Clone, Copy)] pub(crate) enum TaskCode { - UnknownError, Normal, BadRequestError, ValidationError, @@ -30,7 +29,7 @@ pub(crate) struct Task { impl Task { fn new(data: Bytes) -> Self { Self { - code: TaskCode::UnknownError, + code: TaskCode::InternalError, data, create_at: Instant::now(), } @@ -87,7 +86,7 @@ impl TaskManager { } pub(crate) async fn submit_task(&self, data: Bytes) -> Result { - let (id, rx) = self.add_new_task(data).await?; + let (id, rx) = self.add_new_task(data)?; if let Err(err) = time::timeout(self.timeout, rx).await { error!(%id, %err, "task timeout"); let mut table = self.table.write(); @@ -110,13 +109,7 @@ impl TaskManager { self.shutdown.load(Ordering::Acquire) } - async fn add_new_task( - &self, - data: Bytes, - ) -> Result<(u32, oneshot::Receiver<()>), ServiceError> { - if self.is_shutdown() { - return Err(ServiceError::GracefulShutdown); - } + fn add_new_task(&self, data: Bytes) -> Result<(u32, oneshot::Receiver<()>), ServiceError> { let (tx, rx) = oneshot::channel(); let id: u32; { @@ -196,7 +189,7 @@ mod tests { let mut task = Task::new(Bytes::from_static(b"hello")); assert!(task.create_at > now); assert!(task.create_at < Instant::now()); - assert!(matches!(task.code, TaskCode::UnknownError)); + assert!(matches!(task.code, TaskCode::InternalError)); assert_eq!(task.data, Bytes::from_static(b"hello")); task.update(TaskCode::Normal, &Bytes::from_static(b"world")); @@ -210,7 +203,6 @@ mod tests { let task_manager = TaskManager::new(Duration::from_secs(1), tx); let (id, _rx) = task_manager .add_new_task(Bytes::from_static(b"hello")) - .await .unwrap(); assert_eq!(id, 0); { @@ -224,7 +216,6 @@ mod tests { // add a new task let (id, _rx) = task_manager .add_new_task(Bytes::from_static(b"world")) - .await .unwrap(); assert_eq!(id, 1); { diff --git a/tests/test_service.py b/tests/test_service.py index 1bab983e..3baf327c 100644 --- a/tests/test_service.py +++ b/tests/test_service.py @@ -45,7 +45,7 @@ def test_square_service(mosec_service, http_client): assert resp.status_code == 422 resp = http_client.post(f"{URI}/inference", content=b"bad-binary-request") - assert resp.status_code == 422 + assert resp.status_code == 400 validate_square_service(http_client, 2) @@ -62,6 +62,7 @@ def test_square_service_mp(mosec_service, http_client): for t in threads: t.join() assert_batch_larger_than_one(http_client) + assert_empty_queue(http_client) def validate_square_service(http_client, x): @@ -74,3 +75,9 @@ def assert_batch_larger_than_one(http_client): bs = re.findall(r"batch_size_bucket.+", metrics) get_bs_int = lambda x: int(x.split(" ")[-1]) # noqa assert get_bs_int(bs[-1]) > get_bs_int(bs[0]) + + +def assert_empty_queue(http_client): + metrics = http_client.get(f"{URI}/metrics").content.decode() + remain = re.findall(r"mosec_service_remaining_task \d+", metrics)[0] + assert int(remain.split(" ")[-1]) == 0