Skip to content

Commit

Permalink
Custom err msg (#76)
Browse files Browse the repository at this point in the history
* support custom 422 msg by raising ValidationError

* fix grammar

* clean rs err handling logic

* minor naming change

Co-authored-by: zclzc <[email protected]>
  • Loading branch information
lkevinzc and zclzc authored Oct 11, 2021
1 parent e01fcaf commit 1627a9b
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 103 deletions.
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ dev:
cp ./target/debug/mosec mosec/bin/
pip install -e .

test:
test: dev
pytest tests -vv -s
RUST_BACKTRACE=1 cargo test -vv

Expand Down Expand Up @@ -44,5 +44,6 @@ lint:
flake8 ${PY_SOURCE_FILES} --count --show-source --statistics
mypy --install-types --non-interactive ${PY_SOURCE_FILES}
cargo +nightly fmt -- --check
cargo clippy

.PHONY: test doc
2 changes: 1 addition & 1 deletion examples/echo.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def forward(self, data: dict) -> float:
try:
time = float(data["time"])
except KeyError as err:
raise ValidationError(err)
raise ValidationError(f"cannot find key {err}")
return time


Expand Down
2 changes: 1 addition & 1 deletion examples/resnet50_server_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def forward(self, req: dict) -> np.ndarray:
im = np.frombuffer(base64.b64decode(image), np.uint8)
im = cv2.imdecode(im, cv2.IMREAD_COLOR)[:, :, ::-1] # bgr -> rgb
except KeyError as err:
raise ValidationError(f"bad request: {err}")
raise ValidationError(f"cannot find key {err}")
except Exception as err:
raise ValidationError(f"cannot decode as image data: {err}")

Expand Down
17 changes: 13 additions & 4 deletions mosec/coordinator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from multiprocessing.synchronize import Event
from typing import Callable, Type

from .errors import ValidationError
from .errors import DecodingError, ValidationError
from .protocol import Protocol
from .worker import Worker

Expand Down Expand Up @@ -160,15 +160,24 @@ def coordinate(self):
"returned data doesn't match the input data:"
f"input({len(data)})!=output({len(payloads)})"
)
except ValidationError as err:
except DecodingError as err:
err_msg = str(err).replace("\n", " - ")
err_msg = (
err_msg if len(err_msg) else "cannot deserialize request bytes"
)
logger.info(f"{self.name} decoding error: {err_msg}")
status = self.protocol.FLAG_BAD_REQUEST
payloads = (f"decoding error: {err_msg}".encode(),)
except ValidationError as err:
err_msg = str(err)
err_msg = err_msg if len(err_msg) else "invalid data format"
logger.info(f"{self.name} validation error: {err_msg}")
status = self.protocol.FLAG_VALIDATION_ERROR
payloads = (f"Validation Error: {err_msg}".encode(),)
payloads = (f"validation error: {err_msg}".encode(),)
except Exception:
logger.warning(traceback.format_exc().replace("\n", " "))
status = self.protocol.FLAG_INTERNAL_ERROR
payloads = ("Internal Error".encode(),)
payloads = ("inference internal error".encode(),)

try:
self.protocol.send(status, ids, payloads)
Expand Down
30 changes: 27 additions & 3 deletions mosec/errors.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,31 @@
"""
Suppose the input dataflow of our model server is as follows:
**bytes** --- *deserialize*<sup>(decoding)</sup> ---> **data**
--- *parse*<sup>(validation)</sup> ---> **valid data**
If the raw bytes cannot be successfully deserialized, the `DecodingError`
is raised; if the decoded data cannot pass the validation check (usually
implemented by users), the `ValidationError` should be raised.
"""


class DecodingError(Exception):
"""
The `DecodingError` should be raised in user-implemented codes
when the de-serialization for the request bytes fails. This error
will set the status code to
[HTTP 400]("https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/400)
in the response.
"""


class ValidationError(Exception):
"""
The `ValidationError` should be raised in user-implemented codes,
where the validation for the input data fails. Usually it can be put
after the data deserialization, which converts the raw bytes into
structured data.
where the validation for the input data fails. Usually, it should be
put after the data de-serialization, which converts the raw bytes
into structured data. This error will set the status code to
[HTTP 422](https://developer.mozilla.org/en-US/docs/Web/HTTP/Status/422)
in the response.
"""
4 changes: 2 additions & 2 deletions mosec/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pickle
from typing import Any

from .errors import ValidationError
from .errors import DecodingError

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -89,7 +89,7 @@ def deserialize(self, data: bytes) -> Any:
try:
data_json = json.loads(data) if data else {}
except Exception as err:
raise ValidationError(err)
raise DecodingError(err)
return data_json

def forward(self, data: Any) -> Any:
Expand Down
4 changes: 2 additions & 2 deletions src/coordinator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ impl Coordinator {
let (sender, receiver) = bounded(opts.capacity);
let timeout = Duration::from_millis(opts.timeout);
let wait_time = Duration::from_millis(opts.wait);
let path = if opts.path.len() > 0 {
let path = if !opts.path.is_empty() {
opts.path.to_string()
} else {
// default IPC path
Expand All @@ -44,7 +44,7 @@ impl Coordinator {

Self {
capacity: opts.capacity,
path: path,
path,
batches: opts.batches.clone(),
wait_time,
timeout,
Expand Down
14 changes: 1 addition & 13 deletions src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,9 @@ pub(crate) enum ServiceError {
#[display(fmt = "inference timeout")]
Timeout,

#[display(fmt = "bad request")]
BadRequestError,

#[display(fmt = "bad request: validation error")]
ValidationError,

#[display(fmt = "inference internal error")]
InternalError,

#[display(fmt = "too many request: channel is full")]
#[display(fmt = "too many request: task queue is full")]
TooManyRequests,

#[display(fmt = "cannot accept new request during the graceful shutdown")]
GracefulShutdown,

#[display(fmt = "mosec unknown error")]
UnknownError,
}
118 changes: 62 additions & 56 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod tasks;

use std::net::SocketAddr;

use bytes::Bytes;
use clap::Clap;
use hyper::service::{make_service_fn, service_fn};
use hyper::{body::to_bytes, header::HeaderValue, Body, Method, Request, Response, StatusCode};
Expand All @@ -22,97 +23,102 @@ use crate::metrics::Metrics;
use crate::tasks::{TaskCode, TaskManager};

const SERVER_INFO: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
const NOT_FOUND: &[u8] = b"Not Found";
const RESPONSE_DEFAULT: &[u8] = b"MOSEC service";
const RESPONSE_NOT_FOUND: &[u8] = b"not found";
const RESPONSE_EMPTY: &[u8] = b"no data provided";
const RESPONSE_SHUTDOWN: &[u8] = b"gracefully shutting down";

async fn index(_: Request<Body>) -> Result<Response<Body>, ServiceError> {
async fn index(_: Request<Body>) -> Response<Body> {
let task_manager = TaskManager::global();
if task_manager.is_shutdown() {
return Err(ServiceError::GracefulShutdown);
build_response(
StatusCode::SERVICE_UNAVAILABLE,
Bytes::from_static(RESPONSE_SHUTDOWN),
)
} else {
build_response(StatusCode::OK, Bytes::from_static(RESPONSE_DEFAULT))
}
Ok(Response::new(Body::from("MOSEC service")))
}

async fn metrics(_: Request<Body>) -> Result<Response<Body>, ServiceError> {
async fn metrics(_: Request<Body>) -> Response<Body> {
let encoder = TextEncoder::new();
let metrics = prometheus::gather();
let mut buffer = vec![];
encoder.encode(&metrics, &mut buffer).unwrap();
Ok(Response::new(Body::from(buffer)))
build_response(StatusCode::OK, Bytes::from(buffer))
}

async fn inference(req: Request<Body>) -> Result<Response<Body>, ServiceError> {
async fn inference(req: Request<Body>) -> Response<Body> {
let task_manager = TaskManager::global();
let data = to_bytes(req.into_body()).await.unwrap();
let metrics = Metrics::global();

if task_manager.is_shutdown() {
return build_response(
StatusCode::SERVICE_UNAVAILABLE,
Bytes::from_static(RESPONSE_SHUTDOWN),
);
}

if data.is_empty() {
return Ok(Response::new(Body::from("No data provided")));
return build_response(StatusCode::OK, Bytes::from_static(RESPONSE_EMPTY));
}

let (status, content);
metrics.remaining_task.inc();
let task = task_manager.submit_task(data).await?;
match task.code {
TaskCode::Normal => {
metrics.remaining_task.dec();
metrics
.duration
.with_label_values(&["total", "total"])
.observe(task.create_at.elapsed().as_secs_f64());
metrics
.throughput
.with_label_values(&[StatusCode::OK.as_str()])
.inc();
Ok(Response::new(Body::from(task.data)))
match task_manager.submit_task(data).await {
Ok(task) => {
content = task.data;
status = match task.code {
TaskCode::Normal => {
// Record latency only for successful tasks
metrics
.duration
.with_label_values(&["total", "total"])
.observe(task.create_at.elapsed().as_secs_f64());
StatusCode::OK
}
TaskCode::BadRequestError => StatusCode::BAD_REQUEST,
TaskCode::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
TaskCode::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
}
}
Err(err) => {
// Handle errors for which tasks cannot be retrieved
content = Bytes::from(err.to_string());
status = match err {
ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
ServiceError::UnknownError => StatusCode::INTERNAL_SERVER_ERROR,
};
}
TaskCode::BadRequestError => Err(ServiceError::BadRequestError),
TaskCode::ValidationError => Err(ServiceError::ValidationError),
TaskCode::InternalError => Err(ServiceError::InternalError),
TaskCode::UnknownError => Err(ServiceError::UnknownError),
}
}

fn error_handler(err: ServiceError) -> Response<Body> {
let status = match err {
ServiceError::Timeout => StatusCode::REQUEST_TIMEOUT,
ServiceError::BadRequestError => StatusCode::BAD_REQUEST,
ServiceError::TooManyRequests => StatusCode::TOO_MANY_REQUESTS,
ServiceError::ValidationError => StatusCode::UNPROCESSABLE_ENTITY,
ServiceError::InternalError => StatusCode::INTERNAL_SERVER_ERROR,
ServiceError::GracefulShutdown => StatusCode::SERVICE_UNAVAILABLE,
ServiceError::UnknownError => StatusCode::NOT_IMPLEMENTED,
};
let metrics = Metrics::global();

metrics.remaining_task.dec();
metrics
.throughput
.with_label_values(&[status.as_str()])
.inc();

build_response(status, content)
}

fn build_response(status: StatusCode, content: Bytes) -> Response<Body> {
Response::builder()
.status(status)
.header("server", HeaderValue::from_static(SERVER_INFO))
.body(Body::from(err.to_string()))
.body(Body::from(content))
.unwrap()
}

async fn service_func(req: Request<Body>) -> Result<Response<Body>, hyper::Error> {
let res = match (req.method(), req.uri().path()) {
(&Method::GET, "/") => index(req).await,
(&Method::GET, "/metrics") => metrics(req).await,
(&Method::POST, "/inference") => inference(req).await,
_ => Ok(Response::builder()
.status(StatusCode::NOT_FOUND)
.body(NOT_FOUND.into())
.unwrap()),
};
match res {
Ok(mut resp) => {
resp.headers_mut()
.insert("server", HeaderValue::from_static(SERVER_INFO));
Ok(resp)
}
Err(err) => Ok(error_handler(err)),
match (req.method(), req.uri().path()) {
(&Method::GET, "/") => Ok(index(req).await),
(&Method::GET, "/metrics") => Ok(metrics(req).await),
(&Method::POST, "/inference") => Ok(inference(req).await),
_ => Ok(build_response(
StatusCode::NOT_FOUND,
Bytes::from(RESPONSE_NOT_FOUND),
)),
}
}

Expand Down
10 changes: 4 additions & 6 deletions src/protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ const LENGTH_U8_SIZE: usize = 4;
const BIT_STATUS_OK: u16 = 0b1;
const BIT_STATUS_BAD_REQ: u16 = 0b10;
const BIT_STATUS_VALIDATION_ERR: u16 = 0b100;
const BIT_STATUS_INTERNAL_ERR: u16 = 0b1000;
// Others are treated as Internal Error

pub(crate) async fn communicate(
path: PathBuf,
Expand All @@ -42,7 +42,7 @@ pub(crate) async fn communicate(
Ok((mut stream, addr)) => {
info!(?addr, "accepted connection from");
tokio::spawn(async move {
let mut code: TaskCode = TaskCode::UnknownError;
let mut code: TaskCode = TaskCode::InternalError;
let mut ids: Vec<u32> = Vec::with_capacity(batch_size);
let mut data: Vec<Bytes> = Vec::with_capacity(batch_size);
let task_manager = TaskManager::global();
Expand Down Expand Up @@ -134,10 +134,8 @@ async fn read_message(
TaskCode::BadRequestError
} else if flag & BIT_STATUS_VALIDATION_ERR > 0 {
TaskCode::ValidationError
} else if flag & BIT_STATUS_INTERNAL_ERR > 0 {
TaskCode::InternalError
} else {
TaskCode::UnknownError
TaskCode::InternalError
};

let mut id_buf = [0u8; TASK_ID_U8_SIZE];
Expand Down Expand Up @@ -271,7 +269,7 @@ mod tests {
let mut stream = UnixStream::connect(&path).await.unwrap();
let mut recv_ids = Vec::new();
let mut recv_data = Vec::new();
let mut code = TaskCode::UnknownError;
let mut code = TaskCode::InternalError;
read_message(&mut stream, &mut code, &mut recv_ids, &mut recv_data)
.await
.expect("read message error");
Expand Down
Loading

0 comments on commit 1627a9b

Please sign in to comment.