diff --git a/Makefile b/Makefile
index 68b72008..92c1ca9b 100644
--- a/Makefile
+++ b/Makefile
@@ -13,7 +13,7 @@ dev:
cp ./target/debug/mosec mosec/bin/
pip install -e .
-test:
+test: dev
pytest tests -vv -s
RUST_BACKTRACE=1 cargo test -vv
@@ -44,5 +44,6 @@ lint:
flake8 ${PY_SOURCE_FILES} --count --show-source --statistics
mypy --install-types --non-interactive ${PY_SOURCE_FILES}
cargo +nightly fmt -- --check
+ cargo clippy
.PHONY: test doc
diff --git a/examples/echo.py b/examples/echo.py
index bc4feb55..f83f6b36 100644
--- a/examples/echo.py
+++ b/examples/echo.py
@@ -21,7 +21,7 @@ def forward(self, data: dict) -> float:
try:
time = float(data["time"])
except KeyError as err:
- raise ValidationError(err)
+ raise ValidationError(f"cannot find key {err}")
return time
diff --git a/examples/resnet50_server_pytorch.py b/examples/resnet50_server_pytorch.py
index e299b3fe..d7f272f9 100644
--- a/examples/resnet50_server_pytorch.py
+++ b/examples/resnet50_server_pytorch.py
@@ -32,7 +32,7 @@ def forward(self, req: dict) -> np.ndarray:
im = np.frombuffer(base64.b64decode(image), np.uint8)
im = cv2.imdecode(im, cv2.IMREAD_COLOR)[:, :, ::-1] # bgr -> rgb
except KeyError as err:
- raise ValidationError(f"bad request: {err}")
+ raise ValidationError(f"cannot find key {err}")
except Exception as err:
raise ValidationError(f"cannot decode as image data: {err}")
diff --git a/mosec/coordinator.py b/mosec/coordinator.py
index 757e12dd..03aef9af 100644
--- a/mosec/coordinator.py
+++ b/mosec/coordinator.py
@@ -8,7 +8,7 @@
from multiprocessing.synchronize import Event
from typing import Callable, Type
-from .errors import ValidationError
+from .errors import DecodingError, ValidationError
from .protocol import Protocol
from .worker import Worker
@@ -160,15 +160,24 @@ def coordinate(self):
"returned data doesn't match the input data:"
f"input({len(data)})!=output({len(payloads)})"
)
- except ValidationError as err:
+ except DecodingError as err:
err_msg = str(err).replace("\n", " - ")
+ err_msg = (
+ err_msg if len(err_msg) else "cannot deserialize request bytes"
+ )
+ logger.info(f"{self.name} decoding error: {err_msg}")
+ status = self.protocol.FLAG_BAD_REQUEST
+ payloads = (f"decoding error: {err_msg}".encode(),)
+ except ValidationError as err:
+ err_msg = str(err)
+ err_msg = err_msg if len(err_msg) else "invalid data format"
logger.info(f"{self.name} validation error: {err_msg}")
status = self.protocol.FLAG_VALIDATION_ERROR
- payloads = (f"Validation Error: {err_msg}".encode(),)
+ payloads = (f"validation error: {err_msg}".encode(),)
except Exception:
logger.warning(traceback.format_exc().replace("\n", " "))
status = self.protocol.FLAG_INTERNAL_ERROR
- payloads = ("Internal Error".encode(),)
+ payloads = ("inference internal error".encode(),)
try:
self.protocol.send(status, ids, payloads)
diff --git a/mosec/errors.py b/mosec/errors.py
index 77e2557c..a9c71c15 100644
--- a/mosec/errors.py
+++ b/mosec/errors.py
@@ -1,7 +1,31 @@
+"""
+Suppose the input dataflow of our model server is as follows:
+
+**bytes** --- *deserialize*(decoding) ---> **data**
+--- *parse*(validation) ---> **valid data**
+
+If the raw bytes cannot be successfully deserialized, the `DecodingError`
+is raised; if the decoded data cannot pass the validation check (usually
+implemented by users), the `ValidationError` should be raised.
+"""
+
+
+class DecodingError(Exception):
+ """
+ The `DecodingError` should be raised in user-implemented code
+ when the de-serialization of the request bytes fails. This error
+ will set the status code to
+ [HTTP 400](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400)
+ in the response.
+ """
+
+
class ValidationError(Exception):
"""
The `ValidationError` should be raised in user-implemented codes,
- where the validation for the input data fails. Usually it can be put
- after the data deserialization, which converts the raw bytes into
- structured data.
+ where the validation for the input data fails. Usually, it should be
+ put after the data de-serialization, which converts the raw bytes
+ into structured data. This error will set the status code to
+ [HTTP 422](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/422)
+ in the response.
"""
diff --git a/mosec/worker.py b/mosec/worker.py
index 664d84b1..cd7b4c03 100644
--- a/mosec/worker.py
+++ b/mosec/worker.py
@@ -3,7 +3,7 @@
import pickle
from typing import Any
-from .errors import ValidationError
+from .errors import DecodingError
logger = logging.getLogger(__name__)
@@ -89,7 +89,7 @@ def deserialize(self, data: bytes) -> Any:
try:
data_json = json.loads(data) if data else {}
except Exception as err:
- raise ValidationError(err)
+ raise DecodingError(err)
return data_json
def forward(self, data: Any) -> Any:
diff --git a/src/coordinator.rs b/src/coordinator.rs
index a404fcdd..68604329 100644
--- a/src/coordinator.rs
+++ b/src/coordinator.rs
@@ -27,7 +27,7 @@ impl Coordinator {
let (sender, receiver) = bounded(opts.capacity);
let timeout = Duration::from_millis(opts.timeout);
let wait_time = Duration::from_millis(opts.wait);
- let path = if opts.path.len() > 0 {
+ let path = if !opts.path.is_empty() {
opts.path.to_string()
} else {
// default IPC path
@@ -44,7 +44,7 @@ impl Coordinator {
Self {
capacity: opts.capacity,
- path: path,
+ path,
batches: opts.batches.clone(),
wait_time,
timeout,
diff --git a/src/errors.rs b/src/errors.rs
index 2de109b5..7ee8f6c5 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -5,21 +5,9 @@ pub(crate) enum ServiceError {
#[display(fmt = "inference timeout")]
Timeout,
- #[display(fmt = "bad request")]
- BadRequestError,
-
- #[display(fmt = "bad request: validation error")]
- ValidationError,
-
- #[display(fmt = "inference internal error")]
- InternalError,
-
- #[display(fmt = "too many request: channel is full")]
+ #[display(fmt = "too many request: task queue is full")]
TooManyRequests,
- #[display(fmt = "cannot accept new request during the graceful shutdown")]
- GracefulShutdown,
-
#[display(fmt = "mosec unknown error")]
UnknownError,
}
diff --git a/src/main.rs b/src/main.rs
index 08d91ce2..b279d0ab 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -7,6 +7,7 @@ mod tasks;
use std::net::SocketAddr;
+use bytes::Bytes;
use clap::Clap;
use hyper::service::{make_service_fn, service_fn};
use hyper::{body::to_bytes, header::HeaderValue, Body, Method, Request, Response, StatusCode};
@@ -22,97 +23,102 @@ use crate::metrics::Metrics;
use crate::tasks::{TaskCode, TaskManager};
const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
-const NOT_FOUND: &[u8] = b"Not Found";
+const RESPONSE_DEFAULT: &[u8] = b"MOSEC service";
+const RESPONSE_NOT_FOUND: &[u8] = b"not found";
+const RESPONSE_EMPTY: &[u8] = b"no data provided";
+const RESPONSE_SHUTDOWN: &[u8] = b"gracefully shutting down";
-async fn index(_: Request
) -> Result, ServiceError> {
+async fn index(_: Request) -> Response {
let task_manager = TaskManager::global();
if task_manager.is_shutdown() {
- return Err(ServiceError::GracefulShutdown);
+ build_response(
+ StatusCode::SERVICE_UNAVAILABLE,
+ Bytes::from_static(RESPONSE_SHUTDOWN),
+ )
+ } else {
+ build_response(StatusCode::OK, Bytes::from_static(RESPONSE_DEFAULT))
}
- Ok(Response::new(Body::from("MOSEC service")))
}
-async fn metrics(_: Request) -> Result, ServiceError> {
+async fn metrics(_: Request) -> Response {
let encoder = TextEncoder::new();
let metrics = prometheus::gather();
let mut buffer = vec![];
encoder.encode(&metrics, &mut buffer).unwrap();
- Ok(Response::new(Body::from(buffer)))
+ build_response(StatusCode::OK, Bytes::from(buffer))
}
-async fn inference(req: Request) -> Result, ServiceError> {
+async fn inference(req: Request) -> Response {
let task_manager = TaskManager::global();
let data = to_bytes(req.into_body()).await.unwrap();
let metrics = Metrics::global();
+ if task_manager.is_shutdown() {
+ return build_response(
+ StatusCode::SERVICE_UNAVAILABLE,
+ Bytes::from_static(RESPONSE_SHUTDOWN),
+ );
+ }
+
if data.is_empty() {
- return Ok(Response::new(Body::from("No data provided")));
+ return build_response(StatusCode::OK, Bytes::from_static(RESPONSE_EMPTY));
}
+ let (status, content);
metrics.remaining_task.inc();
- let task = task_manager.submit_task(data).await?;
- match task.code {
- TaskCode::Normal => {
- metrics.remaining_task.dec();
- metrics
- .duration
- .with_label_values(&["total", "total"])
- .observe(task.create_at.elapsed().as_secs_f64());
- metrics
- .throughput
- .with_label_values(&[StatusCode::OK.as_str()])
- .inc();
- Ok(Response::new(Body::from(task.data)))
+ match task_manager.submit_task(data).await {
+ Ok(task) => {
+ content = task.data;
+ status = match task.code {
+ TaskCode::Normal => {
+ // Record latency only for successful tasks
+ metrics
+ .duration
+ .with_label_values(&["total", "total"])
+ .observe(task.create_at.elapsed().as_secs_f64());
+ StatusCode::OK
+ }
+ TaskCode::BadRequestError => StatusCode::BAD_REQUEST,
+ TaskCode::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
+ TaskCode::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
+ }
+ }
+ Err(err) => {
+ // Handle errors for which tasks cannot be retrieved
+ content = Bytes::from(err.to_string());
+ status = match err {
+ ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
+ ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
+ ServiceError::UnknownError => StatusCode::INTERNAL_SERVER_ERROR,
+ };
}
- TaskCode::BadRequestError => Err(ServiceError::BadRequestError),
- TaskCode::ValidationError => Err(ServiceError::ValidationError),
- TaskCode::InternalError => Err(ServiceError::InternalError),
- TaskCode::UnknownError => Err(ServiceError::UnknownError),
}
-}
-
-fn error_handler(err: ServiceError) -> Response {
- let status = match err {
- ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
- ServiceError::BadRequestError => StatusCode::BAD_REQUEST,
- ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
- ServiceError::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
- ServiceError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
- ServiceError::GracefulShutdown => StatusCode::SERVICE_UNAVAILABLE,
- ServiceError::UnknownError => StatusCode::NOT_IMPLEMENTED,
- };
- let metrics = Metrics::global();
-
metrics.remaining_task.dec();
metrics
.throughput
.with_label_values(&[status.as_str()])
.inc();
+ build_response(status, content)
+}
+
+fn build_response(status: StatusCode, content: Bytes) -> Response {
Response::builder()
.status(status)
.header("server", HeaderValue::from_static(SERVER_INFO))
- .body(Body::from(err.to_string()))
+ .body(Body::from(content))
.unwrap()
}
async fn service_func(req: Request) -> Result, hyper::Error> {
- let res = match (req.method(), req.uri().path()) {
- (&Method::GET, "/") => index(req).await,
- (&Method::GET, "/metrics") => metrics(req).await,
- (&Method::POST, "/inference") => inference(req).await,
- _ => Ok(Response::builder()
- .status(StatusCode::NOT_FOUND)
- .body(NOT_FOUND.into())
- .unwrap()),
- };
- match res {
- Ok(mut resp) => {
- resp.headers_mut()
- .insert("server", HeaderValue::from_static(SERVER_INFO));
- Ok(resp)
- }
- Err(err) => Ok(error_handler(err)),
+ match (req.method(), req.uri().path()) {
+ (&Method::GET, "/") => Ok(index(req).await),
+ (&Method::GET, "/metrics") => Ok(metrics(req).await),
+ (&Method::POST, "/inference") => Ok(inference(req).await),
+ _ => Ok(build_response(
+ StatusCode::NOT_FOUND,
+ Bytes::from(RESPONSE_NOT_FOUND),
+ )),
}
}
diff --git a/src/protocol.rs b/src/protocol.rs
index c71ac35e..49847eda 100644
--- a/src/protocol.rs
+++ b/src/protocol.rs
@@ -18,7 +18,7 @@ const LENGTH_U8_SIZE: usize = 4;
const BIT_STATUS_OK: u16 = 0b1;
const BIT_STATUS_BAD_REQ: u16 = 0b10;
const BIT_STATUS_VALIDATION_ERR: u16 = 0b100;
-const BIT_STATUS_INTERNAL_ERR: u16 = 0b1000;
+// Any flag without one of the bits above set is treated as an internal error
pub(crate) async fn communicate(
path: PathBuf,
@@ -42,7 +42,7 @@ pub(crate) async fn communicate(
Ok((mut stream, addr)) => {
info!(?addr, "accepted connection from");
tokio::spawn(async move {
- let mut code: TaskCode = TaskCode::UnknownError;
+ let mut code: TaskCode = TaskCode::InternalError;
let mut ids: Vec = Vec::with_capacity(batch_size);
let mut data: Vec = Vec::with_capacity(batch_size);
let task_manager = TaskManager::global();
@@ -134,10 +134,8 @@ async fn read_message(
TaskCode::BadRequestError
} else if flag & BIT_STATUS_VALIDATION_ERR > 0 {
TaskCode::ValidationError
- } else if flag & BIT_STATUS_INTERNAL_ERR > 0 {
- TaskCode::InternalError
} else {
- TaskCode::UnknownError
+ TaskCode::InternalError
};
let mut id_buf = [0u8; TASK_ID_U8_SIZE];
@@ -271,7 +269,7 @@ mod tests {
let mut stream = UnixStream::connect(&path).await.unwrap();
let mut recv_ids = Vec::new();
let mut recv_data = Vec::new();
- let mut code = TaskCode::UnknownError;
+ let mut code = TaskCode::InternalError;
read_message(&mut stream, &mut code, &mut recv_ids, &mut recv_data)
.await
.expect("read message error");
diff --git a/src/tasks.rs b/src/tasks.rs
index 35e74545..de552a98 100644
--- a/src/tasks.rs
+++ b/src/tasks.rs
@@ -13,7 +13,6 @@ use crate::errors::ServiceError;
#[derive(Debug, Clone, Copy)]
pub(crate) enum TaskCode {
- UnknownError,
Normal,
BadRequestError,
ValidationError,
@@ -30,7 +29,7 @@ pub(crate) struct Task {
impl Task {
fn new(data: Bytes) -> Self {
Self {
- code: TaskCode::UnknownError,
+ code: TaskCode::InternalError,
data,
create_at: Instant::now(),
}
@@ -87,7 +86,7 @@ impl TaskManager {
}
pub(crate) async fn submit_task(&self, data: Bytes) -> Result {
- let (id, rx) = self.add_new_task(data).await?;
+ let (id, rx) = self.add_new_task(data)?;
if let Err(err) = time::timeout(self.timeout, rx).await {
error!(%id, %err, "task timeout");
let mut table = self.table.write();
@@ -110,13 +109,7 @@ impl TaskManager {
self.shutdown.load(Ordering::Acquire)
}
- async fn add_new_task(
- &self,
- data: Bytes,
- ) -> Result<(u32, oneshot::Receiver<()>), ServiceError> {
- if self.is_shutdown() {
- return Err(ServiceError::GracefulShutdown);
- }
+ fn add_new_task(&self, data: Bytes) -> Result<(u32, oneshot::Receiver<()>), ServiceError> {
let (tx, rx) = oneshot::channel();
let id: u32;
{
@@ -196,7 +189,7 @@ mod tests {
let mut task = Task::new(Bytes::from_static(b"hello"));
assert!(task.create_at > now);
assert!(task.create_at < Instant::now());
- assert!(matches!(task.code, TaskCode::UnknownError));
+ assert!(matches!(task.code, TaskCode::InternalError));
assert_eq!(task.data, Bytes::from_static(b"hello"));
task.update(TaskCode::Normal, &Bytes::from_static(b"world"));
@@ -210,7 +203,6 @@ mod tests {
let task_manager = TaskManager::new(Duration::from_secs(1), tx);
let (id, _rx) = task_manager
.add_new_task(Bytes::from_static(b"hello"))
- .await
.unwrap();
assert_eq!(id, 0);
{
@@ -224,7 +216,6 @@ mod tests {
// add a new task
let (id, _rx) = task_manager
.add_new_task(Bytes::from_static(b"world"))
- .await
.unwrap();
assert_eq!(id, 1);
{
diff --git a/tests/test_service.py b/tests/test_service.py
index 1bab983e..3baf327c 100644
--- a/tests/test_service.py
+++ b/tests/test_service.py
@@ -45,7 +45,7 @@ def test_square_service(mosec_service, http_client):
assert resp.status_code == 422
resp = http_client.post(f"{URI}/inference", content=b"bad-binary-request")
- assert resp.status_code == 422
+ assert resp.status_code == 400
validate_square_service(http_client, 2)
@@ -62,6 +62,7 @@ def test_square_service_mp(mosec_service, http_client):
for t in threads:
t.join()
assert_batch_larger_than_one(http_client)
+ assert_empty_queue(http_client)
def validate_square_service(http_client, x):
@@ -74,3 +75,9 @@ def assert_batch_larger_than_one(http_client):
bs = re.findall(r"batch_size_bucket.+", metrics)
get_bs_int = lambda x: int(x.split(" ")[-1]) # noqa
assert get_bs_int(bs[-1]) > get_bs_int(bs[0])
+
+
+def assert_empty_queue(http_client):
+ metrics = http_client.get(f"{URI}/metrics").content.decode()
+ remain = re.findall(r"mosec_service_remaining_task \d+", metrics)[0]
+ assert int(remain.split(" ")[-1]) == 0