From 04e1eff3dbce2425cb4ed2787587d37935d71c16 Mon Sep 17 00:00:00 2001 From: Dietrich Beck Date: Fri, 15 Dec 2023 17:21:53 +0100 Subject: [PATCH] added async server example and constructor with rx_remainder for framer_async --- examples/server_async.rs | 274 +++++++++++++++++++++++++++++++++++++++ src/framer_async.rs | 103 ++++++++++----- 2 files changed, 344 insertions(+), 33 deletions(-) create mode 100644 examples/server_async.rs diff --git a/examples/server_async.rs b/examples/server_async.rs new file mode 100644 index 0000000..62048fd --- /dev/null +++ b/examples/server_async.rs @@ -0,0 +1,274 @@ +use core::ops::Deref; +use futures::{Sink, SinkExt, Stream, StreamExt}; +use std::io; + +use bytes::{BufMut, BytesMut}; +use tokio::net::{TcpListener, TcpStream}; +use tokio_util::codec::{Decoder, Encoder, Framed}; + +use embedded_websocket::{ + framer_async::{Framer, FramerError, ReadResult}, + read_http_header, WebSocketContext, WebSocketSendMessageType, WebSocketServer, +}; + +struct MyCodec {} + +impl MyCodec { + fn new() -> Self { + MyCodec {} + } +} + +impl Decoder for MyCodec { + type Item = BytesMut; + type Error = io::Error; + + fn decode(&mut self, buf: &mut BytesMut) -> Result, io::Error> { + if !buf.is_empty() { + let len = buf.len(); + Ok(Some(buf.split_to(len))) + } else { + Ok(None) + } + } +} + +impl Encoder<&[u8]> for MyCodec { + type Error = io::Error; + + fn encode(&mut self, data: &[u8], buf: &mut BytesMut) -> Result<(), io::Error> { + buf.reserve(data.len()); + buf.put(data); + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), FramerError> { + let addr = "127.0.0.1:1337"; + let listener = TcpListener::bind(addr).await.map_err(FramerError::Io)?; + println!("Listening on: {}", addr); + + // accept connections and process them in parallel + loop { + match listener.accept().await { + Ok((stream, _)) => { + tokio::spawn(async move { + match handle_client(stream).await { + Ok(()) => println!("Connection closed"), + Err(e) => println!("Error: {:?}", e), + }; + }); + } + Err(e) => println!("Failed to establish a connection: {}", e), + } + } +} + +async fn handle_client(stream: TcpStream) -> Result<(), FramerError> { + println!( + "Client connected {}", + stream.peer_addr().map_err(FramerError::Io)? + ); + + let mut buffer = [0u8; 4000]; + let codec = MyCodec::new(); + let mut stream = Framed::new(stream, codec); + + if let Some((websocket_context, rx_remainder_len)) = + read_header(&mut stream, &mut buffer).await? + { + // this is a websocket upgrade HTTP request + let websocket = WebSocketServer::new_server(); + let mut framer = Framer::new_with_rx(websocket, rx_remainder_len); + + // complete the opening handshake with the client + framer + .accept(&mut stream, &mut buffer, &websocket_context) + .await?; + println!("Websocket connection opened"); + + // read websocket frames + while let Some(read_result) = framer.read(&mut stream, &mut buffer).await { + if let ReadResult::Text(text) = read_result? { + println!("Received: {}", text); + + // copy text to satisfy borrow checker + let text = Vec::from(text.as_bytes()); + + // send the text back to the client + framer + .write( + &mut stream, + &mut buffer, + WebSocketSendMessageType::Text, + true, + &text, + ) + .await? + } + } + + println!("Closing websocket connection"); + + Ok(()) + } else { + Ok(()) + } +} + +async fn read_header<'a, B: Deref, E>( + stream: &mut (impl Stream> + Sink<&'a [u8], Error = E> + Unpin), + buffer: &'a mut [u8], +) -> Result, FramerError> { + let mut read_cursor = 0usize; + + loop { + let mut headers = [httparse::EMPTY_HEADER; 16]; + let mut request = httparse::Request::new(&mut headers); + + match stream.next().await { + Some(Ok(input)) => { + if buffer.len() < read_cursor + input.len() { + return Err(FramerError::RxBufferTooSmall(read_cursor + input.len())); + } + + // copy to start of buffer (unlike Framer::read()) + buffer[read_cursor..read_cursor + input.len()].copy_from_slice(&input); + read_cursor += input.len(); + + if let httparse::Status::Complete(len) = request + .parse(&buffer[0..read_cursor]) + .map_err(FramerError::HttpHeader)? + { + // if we read exactly the right amount of bytes for the HTTP header then read_cursor would be 0 + let headers = request.headers.iter().map(|f| (f.name, f.value)); + match read_http_header(headers).map_err(FramerError::WebSocket)? { + Some(websocket_context) => match request.path { + Some("/chat") => { + let remaining_len = read_cursor - len; + for i in 0..remaining_len { + buffer[buffer.len() - remaining_len + i] = buffer[len + i] + } + return Ok(Some((websocket_context, remaining_len))); + } + _ => return_404_not_found(stream, request.path).await?, + }, + None => { + handle_non_websocket_http_request(stream, request.path).await?; + } + } + return Ok(None); + } + } + Some(Err(e)) => { + return Err(FramerError::Io(e)); + } + None => return Ok(None), + } + } +} + +async fn handle_non_websocket_http_request<'a, B, E>( + stream: &mut (impl Stream> + Sink<&'a [u8], Error = E> + Unpin), + path: Option<&str>, +) -> Result<(), FramerError> { + println!("Received file request: {:?}", path); + + match path { + Some("/") => { + stream + .send(ROOT_HTML.as_bytes()) + .await + .map_err(FramerError::Io)?; + stream.flush().await.map_err(FramerError::Io)?; + } + unknown_path => { + return_404_not_found(stream, unknown_path).await?; + } + }; + + Ok(()) +} + +async fn return_404_not_found<'a, B, E>( + stream: &mut (impl Stream> + Sink<&'a [u8], Error = E> + Unpin), + unknown_path: Option<&str>, +) -> Result<(), FramerError> { + println!("Unknown path: {:?}", unknown_path); + let html = "HTTP/1.1 404 Not Found\r\nContent-Length: 0\r\nConnection: close\r\n\r\n"; + stream + .send(html.as_bytes()) + .await + .map_err(FramerError::Io)?; + stream.flush().await.map_err(FramerError::Io)?; + Ok(()) +} + +const ROOT_HTML : &str = "HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=UTF-8\r\nContent-Length: 2590\r\nConnection: close\r\n\r\n + + + + + + + + Web Socket Demo + + + +
    +
    + +
    + + + +"; diff --git a/src/framer_async.rs b/src/framer_async.rs index 20a0063..14f0b3b 100644 --- a/src/framer_async.rs +++ b/src/framer_async.rs @@ -4,8 +4,8 @@ use futures::{Sink, SinkExt, Stream, StreamExt}; use rand_core::RngCore; use crate::{ - WebSocket, WebSocketCloseStatusCode, WebSocketOptions, WebSocketReceiveMessageType, - WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType, + WebSocket, WebSocketCloseStatusCode, WebSocketContext, WebSocketOptions, + WebSocketReceiveMessageType, WebSocketSendMessageType, WebSocketSubProtocol, WebSocketType, }; pub struct CloseMessage<'a> { @@ -40,6 +40,10 @@ pub enum FramerError { RxBufferTooSmall(usize), } +/// NOTE: expected buffer layout for read and connect +/// [0 .. frame_cursor]: already decoded by websocket +/// [frame_cursor .. len - rx_remainder_len]: free +/// [len - rx_remainder_len .. len]: raw received bytes (not decoded by websocket yet) pub struct Framer where TRng: RngCore, @@ -72,45 +76,67 @@ where stream.send(tx_buf).await.map_err(FramerError::Io)?; stream.flush().await.map_err(FramerError::Io)?; - loop { - match stream.next().await { - Some(buf) => { - let buf = buf.map_err(FramerError::Io)?; - let buf = buf.as_ref(); - - match self.websocket.client_accept(&web_socket_key, buf) { - Ok((len, sub_protocol)) => { - // "consume" the HTTP header that we have read from the stream - // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else - - // copy the remaining bytes to the end of the rx_buf (which is also the end of the buffer) because they are the contents of the next websocket frame(s) - let from = len; - let to = buf.len(); - let remaining_len = to - from; - - if remaining_len > 0 { - let rx_start = rx_buf.len() - remaining_len; - rx_buf[rx_start..].copy_from_slice(&buf[from..to]); - self.rx_remainder_len = remaining_len; - } - - return Ok(sub_protocol); - } - Err(crate::Error::HttpHeaderIncomplete) => { - // TODO: continue reading HTTP header in loop - panic!("oh no"); - } - Err(e) => { - return Err(FramerError::WebSocket(e)); + match stream.next().await { + Some(buf) => { + let buf = buf.map_err(FramerError::Io)?; + let buf = buf.as_ref(); + + match self.websocket.client_accept(&web_socket_key, buf) { + Ok((len, sub_protocol)) => { + // "consume" the HTTP header that we have read from the stream + // read_cursor would be 0 if we exactly read the HTTP header from the stream and nothing else + + // copy the remaining bytes to the end of the rx_buf (which is also the end of the buffer) because they are the contents of the next websocket frame(s) + let from = len; + let to = buf.len(); + let remaining_len = to - from; + + if remaining_len > 0 { + let rx_start = rx_buf.len() - remaining_len; + rx_buf[rx_start..].copy_from_slice(&buf[from..to]); + self.rx_remainder_len = remaining_len; } + + Ok(sub_protocol) } + Err(crate::Error::HttpHeaderIncomplete) => { + // TODO: continue reading HTTP header in loop + panic!("oh no"); + } + Err(e) => Err(FramerError::WebSocket(e)), } - None => return Err(FramerError::Disconnected), } + None => Err(FramerError::Disconnected), } } } +impl Framer +where + TRng: RngCore, +{ + pub async fn accept<'a, B, E>( + &mut self, + stream: &mut (impl Stream> + Sink<&'a [u8], Error = E> + Unpin), + buffer: &'a mut [u8], + websocket_context: &WebSocketContext, + ) -> Result, FramerError> { + let sec_websocket_protocol = None; + let len = self + .websocket + .server_accept( + &websocket_context.sec_websocket_key, + sec_websocket_protocol.as_ref(), + buffer, + ) + .map_err(FramerError::WebSocket)?; + + stream.send(&buffer[..len]).await.map_err(FramerError::Io)?; + stream.flush().await.map_err(FramerError::Io)?; + Ok(sec_websocket_protocol) + } +} + impl Framer where TRng: RngCore, @@ -124,6 +150,17 @@ where } } + pub fn new_with_rx( + websocket: WebSocket, + rx_remainder_len: usize, + ) -> Self { + Self { + websocket, + frame_cursor: 0, + rx_remainder_len, + } + } + pub fn encode( &mut self, message_type: WebSocketSendMessageType,