use futures_util::ready;
use hyper::service::HttpService;
use std::future::Future;
use std::marker::PhantomPinned;
use std::mem::MaybeUninit;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{error::Error as StdError, io, marker::Unpin, time::Duration};
use bytes::Bytes;
use http::{Request, Response};
use http_body::Body;
use hyper::{
body::Incoming,
rt::{Read, ReadBuf, Timer, Write},
service::Service,
};
#[cfg(feature = "http1")]
use hyper::server::conn::http1;
#[cfg(feature = "http2")]
use hyper::{rt::bounds::Http2ServerConnExec, server::conn::http2};
#[cfg(any(not(feature = "http2"), not(feature = "http1")))]
use std::marker::PhantomData;
use pin_project_lite::pin_project;
use crate::common::rewind::Rewind;
/// Boxed, thread-safe error type produced by the connection futures in this module.
type Error = Box<dyn StdError + Send + Sync>;
/// Shorthand `Result` whose error type is the boxed [`Error`] above.
type Result<T> = std::result::Result<T, Error>;
/// HTTP/2 client connection preface. A connection whose first bytes match this
/// sequence is served as HTTP/2; any divergence (or early EOF) selects HTTP/1.
const H2_PREFACE: &[u8] = b"PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n";
/// Executor bound required by [`Builder`] to serve connections.
///
/// With the `http2` feature enabled this is a thin alias for hyper's
/// [`Http2ServerConnExec`], implemented for every type that implements it.
#[cfg(feature = "http2")]
pub trait HttpServerConnExec<A, B: Body>: Http2ServerConnExec<A, B> {}

#[cfg(feature = "http2")]
impl<A, B, T> HttpServerConnExec<A, B> for T
where
    B: Body,
    T: Http2ServerConnExec<A, B>,
{
}

/// Executor bound required by [`Builder`] to serve connections.
///
/// Without the `http2` feature no executor capability is needed, so this is
/// implemented for every type.
#[cfg(not(feature = "http2"))]
pub trait HttpServerConnExec<A, B: Body> {}

#[cfg(not(feature = "http2"))]
impl<A, B, T> HttpServerConnExec<A, B> for T where B: Body {}
/// Builder for server connections that pick HTTP/1 or HTTP/2 per connection,
/// based on the first bytes the peer sends.
///
/// It wraps one hyper connection builder per enabled protocol feature. With
/// `http2` enabled the executor is handed to the HTTP/2 builder; otherwise it
/// is stored in `_executor`, which nothing in this file reads.
#[derive(Clone, Debug)]
pub struct Builder<E> {
    #[cfg(feature = "http1")]
    http1: http1::Builder,
    #[cfg(feature = "http2")]
    http2: http2::Builder<E>,
    #[cfg(not(feature = "http2"))]
    _executor: E,
}

impl<E> Builder<E> {
    /// Creates a builder with hyper's default settings for each enabled protocol.
    ///
    /// `executor` is passed to the HTTP/2 builder when the `http2` feature is on.
    pub fn new(executor: E) -> Self {
        Self {
            #[cfg(feature = "http1")]
            http1: http1::Builder::new(),
            #[cfg(feature = "http2")]
            http2: http2::Builder::new(executor),
            #[cfg(not(feature = "http2"))]
            _executor: executor,
        }
    }

    /// Returns a handle for configuring HTTP/1 options.
    #[cfg(feature = "http1")]
    pub fn http1(&mut self) -> Http1Builder<'_, E> {
        Http1Builder { inner: self }
    }

    /// Returns a handle for configuring HTTP/2 options.
    #[cfg(feature = "http2")]
    pub fn http2(&mut self) -> Http2Builder<'_, E> {
        Http2Builder { inner: self }
    }

    /// Binds `io` and `service` into a [`Connection`] future.
    ///
    /// Nothing is read from `io` until the future is polled; the first bytes
    /// received then decide whether HTTP/1 or HTTP/2 is served.
    pub fn serve_connection<I, S, B>(&self, io: I, service: S) -> Connection<'_, I, S, E>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let detecting = read_version(io);
        let state = ConnState::ReadVersion {
            read_version: detecting,
            builder: self,
            service: Some(service),
        };
        Connection { state }
    }

    /// Like [`Builder::serve_connection`], but an HTTP/1 connection is served
    /// with upgrade support (`with_upgrades`). Requires `I: Send`.
    pub fn serve_connection_with_upgrades<I, S, B>(
        &self,
        io: I,
        service: S,
    ) -> UpgradeableConnection<'_, I, S, E>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + Send + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let detecting = read_version(io);
        let state = UpgradeableConnState::ReadVersion {
            read_version: detecting,
            builder: self,
            service: Some(service),
        };
        UpgradeableConnection { state }
    }
}
#[derive(Copy, Clone)]
enum Version {
H1,
H2,
}
impl Version {
#[must_use]
#[cfg(any(not(feature = "http2"), not(feature = "http1")))]
pub fn unsupported(self) -> Error {
match self {
Version::H1 => Error::from("HTTP/1 is not supported"),
Version::H2 => Error::from("HTTP/2 is not supported"),
}
}
}
/// Returns a future that inspects the first bytes read from `io` to decide
/// which HTTP version the peer speaks.
///
/// The future starts out assuming HTTP/2. It switches to HTTP/1 as soon as the
/// bytes read diverge from [`H2_PREFACE`], or the peer hits EOF before sending
/// the whole preface.
fn read_version<I>(io: I) -> ReadVersion<I>
where
    I: Read + Unpin,
{
    ReadVersion {
        io: Some(io),
        buf: [MaybeUninit::uninit(); 24],
        filled: 0,
        version: Version::H2,
        cancelled: false,
        _pin: PhantomPinned,
    }
}

pin_project! {
    /// Future produced by [`read_version`].
    ///
    /// Resolves to the detected [`Version`] together with the IO wrapped in a
    /// [`Rewind`] that replays every byte consumed during detection.
    struct ReadVersion<I> {
        // `None` once the IO has been handed back in the `Ready` result.
        io: Option<I>,
        // Bytes read so far; sized to match `H2_PREFACE`.
        buf: [MaybeUninit<u8>; 24],
        // Number of leading bytes of `buf` filled by previous polls.
        filled: usize,
        version: Version,
        // Set by `cancel` when a graceful shutdown arrives before detection ends.
        cancelled: bool,
        // Makes this future `!Unpin`.
        #[pin]
        _pin: PhantomPinned,
    }
}

impl<I> ReadVersion<I> {
    /// Cancels version detection.
    ///
    /// The next poll then resolves to an [`io::ErrorKind::Interrupted`] error
    /// instead of waiting for more bytes from the peer.
    fn cancel(self: Pin<&mut Self>) {
        *self.project().cancelled = true;
    }
}

impl<I> Future for ReadVersion<I>
where
    I: Read + Unpin,
{
    type Output = io::Result<(Version, Rewind<I>)>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // A graceful shutdown requested before the version is known must not
        // keep waiting on a peer that may never send anything.
        if *this.cancelled {
            return Poll::Ready(Err(io::Error::new(
                io::ErrorKind::Interrupted,
                "connection cancelled before the HTTP version was detected",
            )));
        }

        let mut buf = ReadBuf::uninit(&mut *this.buf);
        // SAFETY: the first `*this.filled` bytes of `this.buf` were written by
        // `poll_read` during earlier calls to `poll`, so they are initialized.
        unsafe {
            buf.unfilled().advance(*this.filled);
        }

        while buf.filled().len() < H2_PREFACE.len() {
            let len = buf.filled().len();
            ready!(Pin::new(this.io.as_mut().unwrap()).poll_read(cx, buf.unfilled()))?;
            // Persist progress so that a later `Pending` resumes from here.
            *this.filled = buf.filled().len();

            // A zero-byte read (EOF) or a byte that differs from the preface
            // rules out HTTP/2, so fall back to HTTP/1.
            if buf.filled().len() == len
                || buf.filled()[len..] != H2_PREFACE[len..buf.filled().len()]
            {
                *this.version = Version::H1;
                break;
            }
        }

        let io = this.io.take().unwrap();
        let buf = buf.filled().to_vec();
        Poll::Ready(Ok((
            *this.version,
            Rewind::new_buffered(io, Bytes::from(buf)),
        )))
    }
}

pin_project! {
    /// Future serving a single connection over HTTP/1 or HTTP/2, chosen from
    /// the first bytes the peer sends.
    ///
    /// Returned by [`Builder::serve_connection`]. It makes no progress unless polled.
    pub struct Connection<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        #[pin]
        state: ConnState<'a, I, S, E>,
    }
}

// Protocol-specific connection types. Each falls back to a `PhantomData`
// placeholder when its feature is disabled, so the state enums still type-check.
#[cfg(feature = "http1")]
type Http1Connection<I, S> = hyper::server::conn::http1::Connection<Rewind<I>, S>;
#[cfg(not(feature = "http1"))]
type Http1Connection<I, S> = (PhantomData<I>, PhantomData<S>);
#[cfg(feature = "http2")]
type Http2Connection<I, S, E> = hyper::server::conn::http2::Connection<Rewind<I>, S, E>;
#[cfg(not(feature = "http2"))]
type Http2Connection<I, S, E> = (PhantomData<I>, PhantomData<S>, PhantomData<E>);

// State machine behind `Connection`:
// - `ReadVersion` sniffs the protocol and holds the builder and service until
//   the version is known.
// - `H1` / `H2` drive the corresponding hyper connection.
pin_project! {
    #[project = ConnStateProj]
    enum ConnState<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        ReadVersion {
            #[pin]
            read_version: ReadVersion<I>,
            builder: &'a Builder<E>,
            service: Option<S>,
        },
        H1 {
            #[pin]
            conn: Http1Connection<I, S>,
        },
        H2 {
            #[pin]
            conn: Http2Connection<I, S, E>,
        },
    }
}

impl<I, S, E, B> Connection<'_, I, S, E>
where
    S: HttpService<Incoming, ResBody = B>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    E: HttpServerConnExec<S::Future, B>,
{
    /// Starts a graceful shutdown of this connection.
    ///
    /// If the HTTP version is still being detected, detection is cancelled. The
    /// next poll then resolves to an error without serving any request.
    /// Otherwise this forwards to `graceful_shutdown` on the underlying hyper
    /// HTTP/1 or HTTP/2 connection.
    ///
    /// In every case the connection must keep being polled for the shutdown
    /// to complete.
    pub fn graceful_shutdown(self: Pin<&mut Self>) {
        match self.project().state.project() {
            ConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
            #[cfg(feature = "http1")]
            ConnStateProj::H1 { conn } => conn.graceful_shutdown(),
            #[cfg(feature = "http2")]
            ConnStateProj::H2 { conn } => conn.graceful_shutdown(),
            #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
            _ => unreachable!(),
        }
    }
}

impl<I, S, E, B> Future for Connection<'_, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin + 'static,
    E: HttpServerConnExec<S::Future, B>,
{
    type Output = Result<()>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Loop so that, after leaving `ReadVersion`, the newly installed
        // protocol state is polled within the same call.
        loop {
            let mut this = self.as_mut().project();
            match this.state.as_mut().project() {
                ConnStateProj::ReadVersion {
                    read_version,
                    builder,
                    service,
                } => {
                    let (version, io) = ready!(read_version.poll(cx))?;
                    let service = service.take().unwrap();
                    match version {
                        #[cfg(feature = "http1")]
                        Version::H1 => {
                            let conn = builder.http1.serve_connection(io, service);
                            this.state.set(ConnState::H1 { conn });
                        }
                        #[cfg(feature = "http2")]
                        Version::H2 => {
                            let conn = builder.http2.serve_connection(io, service);
                            this.state.set(ConnState::H2 { conn });
                        }
                        #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                        _ => return Poll::Ready(Err(version.unsupported())),
                    }
                }
                #[cfg(feature = "http1")]
                ConnStateProj::H1 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(feature = "http2")]
                ConnStateProj::H2 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                _ => unreachable!(),
            }
        }
    }
}

pin_project! {
    /// Like [`Connection`], but an HTTP/1 connection is served with upgrade support.
    ///
    /// Returned by [`Builder::serve_connection_with_upgrades`]. It makes no
    /// progress unless polled.
    pub struct UpgradeableConnection<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        #[pin]
        state: UpgradeableConnState<'a, I, S, E>,
    }
}

// HTTP/1 connection with upgrades; `PhantomData` placeholder without `http1`.
#[cfg(feature = "http1")]
type Http1UpgradeableConnection<I, S> = hyper::server::conn::http1::UpgradeableConnection<I, S>;
#[cfg(not(feature = "http1"))]
type Http1UpgradeableConnection<I, S> = (PhantomData<I>, PhantomData<S>);

// State machine behind `UpgradeableConnection`. It mirrors `ConnState`, except
// that the HTTP/1 state uses the upgradeable hyper connection.
pin_project! {
    #[project = UpgradeableConnStateProj]
    enum UpgradeableConnState<'a, I, S, E>
    where
        S: HttpService<Incoming>,
    {
        ReadVersion {
            #[pin]
            read_version: ReadVersion<I>,
            builder: &'a Builder<E>,
            service: Option<S>,
        },
        H1 {
            #[pin]
            conn: Http1UpgradeableConnection<Rewind<I>, S>,
        },
        H2 {
            #[pin]
            conn: Http2Connection<I, S, E>,
        },
    }
}

impl<I, S, E, B> UpgradeableConnection<'_, I, S, E>
where
    S: HttpService<Incoming, ResBody = B>,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    E: HttpServerConnExec<S::Future, B>,
{
    /// Starts a graceful shutdown of this connection.
    ///
    /// If the HTTP version is still being detected, detection is cancelled. The
    /// next poll then resolves to an error without serving any request.
    /// Otherwise this forwards to `graceful_shutdown` on the underlying hyper
    /// HTTP/1 (upgradeable) or HTTP/2 connection.
    ///
    /// In every case the connection must keep being polled for the shutdown
    /// to complete.
    pub fn graceful_shutdown(self: Pin<&mut Self>) {
        match self.project().state.project() {
            UpgradeableConnStateProj::ReadVersion { read_version, .. } => read_version.cancel(),
            #[cfg(feature = "http1")]
            UpgradeableConnStateProj::H1 { conn } => conn.graceful_shutdown(),
            #[cfg(feature = "http2")]
            UpgradeableConnStateProj::H2 { conn } => conn.graceful_shutdown(),
            #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
            _ => unreachable!(),
        }
    }
}
// Drives an `UpgradeableConnection`: first detect the HTTP version, then hand
// the rewound IO to hyper's HTTP/1 (with upgrades) or HTTP/2 connection.
impl<I, S, E, B> Future for UpgradeableConnection<'_, I, S, E>
where
    S: Service<Request<Incoming>, Response = Response<B>>,
    S::Future: 'static,
    S::Error: Into<Box<dyn StdError + Send + Sync>>,
    B: Body + 'static,
    B::Error: Into<Box<dyn StdError + Send + Sync>>,
    I: Read + Write + Unpin + Send + 'static,
    E: HttpServerConnExec<S::Future, B>,
{
    type Output = Result<()>;
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Loop so that, after switching away from `ReadVersion`, the newly
        // installed protocol state is polled within the same call.
        loop {
            let mut this = self.as_mut().project();
            match this.state.as_mut().project() {
                UpgradeableConnStateProj::ReadVersion {
                    read_version,
                    builder,
                    service,
                } => {
                    // Wait for version detection; `io` replays the sniffed bytes.
                    let (version, io) = ready!(read_version.poll(cx))?;
                    // The service is consumed exactly once, on the transition out
                    // of `ReadVersion`.
                    let service = service.take().unwrap();
                    match version {
                        #[cfg(feature = "http1")]
                        Version::H1 => {
                            let conn = builder.http1.serve_connection(io, service).with_upgrades();
                            this.state.set(UpgradeableConnState::H1 { conn });
                        }
                        #[cfg(feature = "http2")]
                        Version::H2 => {
                            let conn = builder.http2.serve_connection(io, service);
                            this.state.set(UpgradeableConnState::H2 { conn });
                        }
                        // Detected protocol's feature is compiled out.
                        #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                        _ => return Poll::Ready(Err(version.unsupported())),
                    }
                }
                #[cfg(feature = "http1")]
                UpgradeableConnStateProj::H1 { conn } => {
                    // Convert hyper's error into this module's boxed `Error`.
                    return conn.poll(cx).map_err(Into::into);
                }
                #[cfg(feature = "http2")]
                UpgradeableConnStateProj::H2 { conn } => {
                    return conn.poll(cx).map_err(Into::into);
                }
                // Only reachable arms are the placeholder states of disabled features,
                // which are never constructed.
                #[cfg(any(not(feature = "http1"), not(feature = "http2")))]
                _ => unreachable!(),
            }
        }
    }
}
/// Handle for configuring the HTTP/1 settings of a [`Builder`].
///
/// Each setter forwards to the wrapped hyper `http1::Builder` and returns
/// `&mut Self`, so calls can be chained.
#[cfg(feature = "http1")]
pub struct Http1Builder<'a, E> {
    inner: &'a mut Builder<E>,
}

#[cfg(feature = "http1")]
impl<E> Http1Builder<'_, E> {
    /// Applies `configure` to the wrapped hyper HTTP/1 builder and returns
    /// `self` for chaining.
    fn with_http1(&mut self, configure: impl FnOnce(&mut http1::Builder)) -> &mut Self {
        configure(&mut self.inner.http1);
        self
    }

    /// Switches to configuring the HTTP/2 settings of the same [`Builder`].
    #[cfg(feature = "http2")]
    pub fn http2(&mut self) -> Http2Builder<'_, E> {
        Http2Builder { inner: self.inner }
    }

    /// Forwards to hyper's `http1::Builder::half_close`.
    pub fn half_close(&mut self, val: bool) -> &mut Self {
        self.with_http1(|h1| {
            h1.half_close(val);
        })
    }

    /// Forwards to hyper's `http1::Builder::keep_alive`.
    pub fn keep_alive(&mut self, val: bool) -> &mut Self {
        self.with_http1(|h1| {
            h1.keep_alive(val);
        })
    }

    /// Forwards to hyper's `http1::Builder::title_case_headers`.
    pub fn title_case_headers(&mut self, enabled: bool) -> &mut Self {
        self.with_http1(|h1| {
            h1.title_case_headers(enabled);
        })
    }

    /// Forwards to hyper's `http1::Builder::preserve_header_case`.
    pub fn preserve_header_case(&mut self, enabled: bool) -> &mut Self {
        self.with_http1(|h1| {
            h1.preserve_header_case(enabled);
        })
    }

    /// Forwards to hyper's `http1::Builder::header_read_timeout`.
    pub fn header_read_timeout(&mut self, read_timeout: Duration) -> &mut Self {
        self.with_http1(|h1| {
            h1.header_read_timeout(read_timeout);
        })
    }

    /// Forwards to hyper's `http1::Builder::writev`.
    pub fn writev(&mut self, val: bool) -> &mut Self {
        self.with_http1(|h1| {
            h1.writev(val);
        })
    }

    /// Forwards to hyper's `http1::Builder::max_buf_size`.
    pub fn max_buf_size(&mut self, max: usize) -> &mut Self {
        self.with_http1(|h1| {
            h1.max_buf_size(max);
        })
    }

    /// Forwards to hyper's `http1::Builder::pipeline_flush`.
    pub fn pipeline_flush(&mut self, enabled: bool) -> &mut Self {
        self.with_http1(|h1| {
            h1.pipeline_flush(enabled);
        })
    }

    /// Forwards to hyper's `http1::Builder::timer`.
    pub fn timer<M>(&mut self, timer: M) -> &mut Self
    where
        M: Timer + Send + Sync + 'static,
    {
        self.with_http1(|h1| {
            h1.timer(timer);
        })
    }

    /// Serves `io` with `service` using the parent [`Builder`]'s settings and
    /// awaits the connection's completion.
    ///
    /// A single definition serves every feature combination. Without `http2`,
    /// `HttpServerConnExec` is implemented for every type, so the `E` bound is
    /// always satisfied.
    pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        self.inner.serve_connection(io, service).await
    }
}
/// Handle for configuring the HTTP/2 settings of a [`Builder`].
///
/// Each setter forwards to the wrapped hyper `http2::Builder` and returns
/// `&mut Self`, so calls can be chained.
#[cfg(feature = "http2")]
pub struct Http2Builder<'a, E> {
    inner: &'a mut Builder<E>,
}

#[cfg(feature = "http2")]
impl<E> Http2Builder<'_, E> {
    /// Applies `configure` to the wrapped hyper HTTP/2 builder and returns
    /// `self` for chaining.
    fn with_http2(&mut self, configure: impl FnOnce(&mut http2::Builder<E>)) -> &mut Self {
        configure(&mut self.inner.http2);
        self
    }

    /// Switches to configuring the HTTP/1 settings of the same [`Builder`].
    #[cfg(feature = "http1")]
    pub fn http1(&mut self) -> Http1Builder<'_, E> {
        Http1Builder { inner: self.inner }
    }

    /// Forwards to hyper's `http2::Builder::initial_stream_window_size`.
    pub fn initial_stream_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.with_http2(|h2| {
            h2.initial_stream_window_size(sz);
        })
    }

    /// Forwards to hyper's `http2::Builder::initial_connection_window_size`.
    pub fn initial_connection_window_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.with_http2(|h2| {
            h2.initial_connection_window_size(sz);
        })
    }

    /// Forwards to hyper's `http2::Builder::adaptive_window`.
    pub fn adaptive_window(&mut self, enabled: bool) -> &mut Self {
        self.with_http2(|h2| {
            h2.adaptive_window(enabled);
        })
    }

    /// Forwards to hyper's `http2::Builder::max_frame_size`.
    pub fn max_frame_size(&mut self, sz: impl Into<Option<u32>>) -> &mut Self {
        self.with_http2(|h2| {
            h2.max_frame_size(sz);
        })
    }

    /// Forwards to hyper's `http2::Builder::max_concurrent_streams`.
    pub fn max_concurrent_streams(&mut self, max: impl Into<Option<u32>>) -> &mut Self {
        self.with_http2(|h2| {
            h2.max_concurrent_streams(max);
        })
    }

    /// Forwards to hyper's `http2::Builder::keep_alive_interval`.
    pub fn keep_alive_interval(&mut self, interval: impl Into<Option<Duration>>) -> &mut Self {
        self.with_http2(|h2| {
            h2.keep_alive_interval(interval);
        })
    }

    /// Forwards to hyper's `http2::Builder::keep_alive_timeout`.
    pub fn keep_alive_timeout(&mut self, timeout: Duration) -> &mut Self {
        self.with_http2(|h2| {
            h2.keep_alive_timeout(timeout);
        })
    }

    /// Forwards to hyper's `http2::Builder::max_send_buf_size`.
    pub fn max_send_buf_size(&mut self, max: usize) -> &mut Self {
        self.with_http2(|h2| {
            h2.max_send_buf_size(max);
        })
    }

    /// Forwards to hyper's `http2::Builder::enable_connect_protocol`.
    pub fn enable_connect_protocol(&mut self) -> &mut Self {
        self.with_http2(|h2| {
            h2.enable_connect_protocol();
        })
    }

    /// Forwards to hyper's `http2::Builder::max_header_list_size`.
    pub fn max_header_list_size(&mut self, max: u32) -> &mut Self {
        self.with_http2(|h2| {
            h2.max_header_list_size(max);
        })
    }

    /// Forwards to hyper's `http2::Builder::timer`.
    pub fn timer<M>(&mut self, timer: M) -> &mut Self
    where
        M: Timer + Send + Sync + 'static,
    {
        self.with_http2(|h2| {
            h2.timer(timer);
        })
    }

    /// Serves `io` with `service` using the parent [`Builder`]'s settings and
    /// awaits the connection's completion.
    pub async fn serve_connection<I, S, B>(&self, io: I, service: S) -> Result<()>
    where
        S: Service<Request<Incoming>, Response = Response<B>>,
        S::Future: 'static,
        S::Error: Into<Box<dyn StdError + Send + Sync>>,
        B: Body + 'static,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
        I: Read + Write + Unpin + 'static,
        E: HttpServerConnExec<S::Future, B>,
    {
        let connection = self.inner.serve_connection(io, service);
        connection.await
    }
}
#[cfg(test)]
mod tests {
    use crate::{
        rt::{TokioExecutor, TokioIo},
        server::conn::auto,
    };
    use http::{Request, Response};
    use http_body::Body;
    use http_body_util::{BodyExt, Empty, Full};
    use hyper::{body, body::Bytes, client, service::service_fn};
    use std::{convert::Infallible, error::Error as StdError, net::SocketAddr};
    use tokio::net::{TcpListener, TcpStream};

    /// Payload carried by every response from the test server.
    const BODY: &[u8] = b"Hello, world!";

    /// The protocol sub-builders can be chained or used one statement at a time.
    #[test]
    fn configuration() {
        // Chained: hop from the HTTP/1 handle straight to the HTTP/2 handle.
        auto::Builder::new(TokioExecutor::new())
            .http1()
            .keep_alive(true)
            .http2()
            .keep_alive_interval(None);

        // Separate statements on a named builder.
        let mut named = auto::Builder::new(TokioExecutor::new());
        named.http1().keep_alive(true);
        named.http2().keep_alive_interval(None);
    }

    /// An HTTP/1 client is detected and served.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http1() {
        let server = start_server().await;
        let mut client = connect_h1(server).await;
        let request = Request::new(Empty::<Bytes>::new());
        let response = client.send_request(request).await.unwrap();
        let collected = response.into_body().collect().await.unwrap();
        assert_eq!(collected.to_bytes(), BODY);
    }

    /// An HTTP/2 prior-knowledge client is detected and served.
    #[cfg(not(miri))]
    #[tokio::test]
    async fn http2() {
        let server = start_server().await;
        let mut client = connect_h2(server).await;
        let request = Request::new(Empty::<Bytes>::new());
        let response = client.send_request(request).await.unwrap();
        let collected = response.into_body().collect().await.unwrap();
        assert_eq!(collected.to_bytes(), BODY);
    }

    /// Opens an HTTP/1 client connection to `addr` and drives it in the background.
    async fn connect_h1<B>(addr: SocketAddr) -> client::conn::http1::SendRequest<B>
    where
        B: Body + Send + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let (sender, driver) = client::conn::http1::handshake(TokioIo::new(tcp))
            .await
            .unwrap();
        tokio::spawn(driver);
        sender
    }

    /// Opens an HTTP/2 client connection to `addr` and drives it in the background.
    async fn connect_h2<B>(addr: SocketAddr) -> client::conn::http2::SendRequest<B>
    where
        B: Body + Unpin + Send + 'static,
        B::Data: Send,
        B::Error: Into<Box<dyn StdError + Send + Sync>>,
    {
        let tcp = TcpStream::connect(addr).await.unwrap();
        let handshake = client::conn::http2::Builder::new(TokioExecutor::new())
            .handshake(TokioIo::new(tcp));
        let (sender, driver) = handshake.await.unwrap();
        tokio::spawn(driver);
        sender
    }

    /// Starts an auto-detecting server on an ephemeral loopback port and
    /// returns its address.
    async fn start_server() -> SocketAddr {
        let bind_to = SocketAddr::from(([127, 0, 0, 1], 0));
        let listener = TcpListener::bind(bind_to).await.unwrap();
        let bound = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (tcp, _peer) = listener.accept().await.unwrap();
                let io = TokioIo::new(tcp);
                tokio::task::spawn(async move {
                    let _ = auto::Builder::new(TokioExecutor::new())
                        .serve_connection(io, service_fn(hello))
                        .await;
                });
            }
        });
        bound
    }

    /// Service answering every request with [`BODY`].
    async fn hello(_req: Request<body::Incoming>) -> Result<Response<Full<Bytes>>, Infallible> {
        let payload = Full::new(Bytes::from(BODY));
        Ok(Response::new(payload))
    }
}