Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support graceful shutdown on "auto conn" #66

Merged
merged 5 commits into from
Dec 19, 2023
Merged
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
195 changes: 160 additions & 35 deletions src/server/conn/auto.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
//! Http1 or Http2 connection.

use futures_util::ready;
use hyper::service::HttpService;
use std::future::Future;
use std::io::{Error as IoError, ErrorKind, Result as IoResult};
use std::marker::PhantomPinned;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, marker::Unpin, time::Duration};
Expand Down Expand Up @@ -65,7 +67,7 @@ impl<E> Builder<E> {
}

/// Bind a connection together with a [`Service`].
pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
Expand All @@ -75,13 +77,13 @@ impl<E> Builder<E> {
I: Read + Write + Unpin + 'static,
E: Http2ServerConnExec<S::Future, B>,
{
let (version, io) = read_version(io).await?;
match version {
Version::H1 => self.http1.serve_connection(io, service).await?,
Version::H2 => self.http2.serve_connection(io, service).await?,
Connection {
state: ConnFutureState::ReadVersion {
read_version: read_version(io),
builder: self,
service: Some(service),
},
}

Ok(())
}

/// Bind a connection together with a [`Service`], with the ability to
Expand Down Expand Up @@ -116,57 +118,180 @@ enum Version {
H1,
H2,
}
async fn read_version<'a, A>(mut reader: A) -> IoResult<(Version, Rewind<A>)>

fn read_version<I>(io: I) -> ReadVersion<I>
where
A: Read + Unpin,
I: Read + Unpin,
{
use std::mem::MaybeUninit;

let mut buf = [MaybeUninit::uninit(); 24];
let (version, buf) = ReadVersion {
reader: &mut reader,
buf: ReadBuf::uninit(&mut buf),
ReadVersion {
io: Some(io),
buf: [MaybeUninit::uninit(); 24],
filled: 0,
version: Version::H1,
_pin: PhantomPinned,
}
.await?;
Ok((version, Rewind::new_buffered(reader, Bytes::from(buf))))
}

pin_project! {
struct ReadVersion<'a, A: ?Sized> {
reader: &'a mut A,
buf: ReadBuf<'a>,
struct ReadVersion<I> {
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had to remove all references from this type since we can't use async/await anymore because we need a nameable Connection type to add the graceful_shutdown method to.

io: Option<I>,
buf: [MaybeUninit<u8>; 24],
// the amount of `buf` thats been filled
filled: usize,
version: Version,
// Make this future `!Unpin` for compatibility with async trait methods.
#[pin]
_pin: PhantomPinned,
}
}

impl<A> Future for ReadVersion<'_, A>
impl<I> Future for ReadVersion<I>
where
A: Read + Unpin + ?Sized,
I: Read + Unpin,
{
type Output = IoResult<(Version, Vec<u8>)>;
type Output = IoResult<(Version, Rewind<I>)>;

fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<IoResult<(Version, Vec<u8>)>> {
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
let this = self.project();

while this.buf.filled().len() < H2_PREFACE.len() {
if this.buf.filled() != &H2_PREFACE[0..this.buf.filled().len()] {
return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec())));
}
// if our buffer is empty, then we need to read some data to continue.
let len = this.buf.filled().len();
ready!(Pin::new(&mut *this.reader).poll_read(cx, this.buf.unfilled()))?;
if this.buf.filled().len() == len {
return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into();
let mut buf = ReadBuf::uninit(&mut *this.buf);
// SAFETY: `this.filled` tracks how many bytes have been read (and thus initialized) and
// we're only advancing by that many.
unsafe {
buf.unfilled().advance(*this.filled);
};

while buf.filled().len() < H2_PREFACE.len() {
if buf.filled() != &H2_PREFACE[0..buf.filled().len()] {
let io = this.io.take().unwrap();
let buf = buf.filled().to_vec();
return Poll::Ready(Ok((
*this.version,
Rewind::new_buffered(io, Bytes::from(buf)),
)));
} else {
// if our buffer is empty, then we need to read some data to continue.
let len = buf.filled().len();
ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?;
*this.filled = buf.filled().len();
if buf.filled().len() == len {
return Err(IoError::new(ErrorKind::UnexpectedEof, "early eof")).into();
}
}
}
if this.buf.filled() == H2_PREFACE {
if buf.filled() == H2_PREFACE {
*this.version = Version::H2;
}
return Poll::Ready(Ok((*this.version, this.buf.filled().to_vec())));
let io = this.io.take().unwrap();
let buf = buf.filled().to_vec();
Poll::Ready(Ok((
*this.version,
Rewind::new_buffered(io, Bytes::from(buf)),
)))
}
}

pin_project! {
/// TODO
pub struct Connection<'a, I, S, E>
where
S: HttpService<Incoming>,
{
#[pin]
state: ConnFutureState<'a, I, S, E>,
}
}

pin_project! {
#[project = ConnFutureStateProj]
enum ConnFutureState<'a, I, S, E>
where
S: HttpService<Incoming>,
{
ReadVersion {
#[pin]
read_version: ReadVersion<I>,
builder: &'a Builder<E>,
service: Option<S>,
},
H1 {
#[pin]
conn: hyper::server::conn::http1::Connection<Rewind<I>, S>,
},
H1WithUpgrades {
// can't name this type :(
#[pin]
conn: UpgradeableConnection<Rewind<I>, S>,
},
H2 {
#[pin]
conn: hyper::server::conn::http2::Connection<Rewind<I>, S, E>,
},
}
}

impl<I, S, E, B> Connection<'_, I, S, E>
where
S: HttpService<Incoming, ResBody = B>,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
E: Http2ServerConnExec<S::Future, B>,
{
/// TODO
pub fn graceful_shutdown(self: Pin<&mut Self>) {
match self.project().state.project() {
ConnFutureStateProj::ReadVersion { .. } => {}
ConnFutureStateProj::H1 { conn } => conn.graceful_shutdown(),
ConnFutureStateProj::H2 { conn } => conn.graceful_shutdown(),
}
}
}

impl<I, S, E, B> Future for Connection<'_, I, S, E>
where
S: Service<Request<Incoming>, Response = Response<B>>,
S::Future: 'static,
S::Error: Into<Box<dyn StdError + Send + Sync>>,
B: Body + 'static,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
I: Read + Write + Unpin + 'static,
E: Http2ServerConnExec<S::Future, B>,
{
type Output = Result<()>;

fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
loop {
let mut this = self.as_mut().project();

match this.state.as_mut().project() {
ConnFutureStateProj::ReadVersion {
read_version,
builder,
service,
} => {
let (version, io) = ready!(read_version.poll(cx))?;
let service = service.take().unwrap();
match version {
Version::H1 => {
let conn = builder.http1.serve_connection(io, service);
this.state.set(ConnFutureState::H1 { conn });
}
Version::H2 => {
let conn = builder.http2.serve_connection(io, service);
this.state.set(ConnFutureState::H2 { conn });
}
}
}
ConnFutureStateProj::H1 { conn } => {
return conn.poll(cx).map_err(Into::into);
}
ConnFutureStateProj::H2 { conn } => {
return conn.poll(cx).map_err(Into::into);
}
}
}
}
}

Expand Down
Loading