Skip to content

feat(irpc-iroh): make it easy to do a manual connection loop with authentication #14

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 282 additions & 0 deletions irpc-iroh/examples/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,282 @@
//! This example demonstrates a few things:
//! * Using irpc with a cloneable server struct instead of with an actor loop
//! * Manually implementing the connection loop
//! * Authenticating peers

use anyhow::Result;
use iroh::{protocol::Router, Endpoint};

use self::storage::{StorageClient, StorageServer};

#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
println!("Remote use");
remote().await?;
Ok(())
}

async fn remote() -> Result<()> {
let (server_router, server_addr) = {
let endpoint = Endpoint::builder().discovery_n0().bind().await?;
let server = StorageServer::new("secret".to_string());
let router = Router::builder(endpoint.clone())
.accept(StorageServer::ALPN, server.clone())
.spawn()
.await?;
let addr = endpoint.node_addr().await?;
(router, addr)
};

let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone());
api.auth("secret").await?;
api.set("hello".to_string(), "world".to_string()).await?;
api.set("goodbye".to_string(), "world".to_string()).await?;
let value = api.get("hello".to_string()).await?;
println!("value = {:?}", value);
let mut list = api.list().await?;
while let Some(value) = list.recv().await? {
println!("list value = {:?}", value);
}

let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr.clone());
assert!(api.auth("bad").await.is_err());
assert!(api.get("hello".to_string()).await.is_err());

let client_endpoint = Endpoint::builder().bind().await?;
let api = StorageClient::connect(client_endpoint, server_addr);
assert!(api.get("hello".to_string()).await.is_err());

drop(server_router);
Ok(())
}

mod storage {
//! Implementation of our storage service.
//!
//! The only `pub` item is [`StorageApi`], everything else is private.

use std::{
collections::BTreeMap,
sync::{Arc, Mutex},
};

use anyhow::Result;
use iroh::{endpoint::Connection, protocol::ProtocolHandler, Endpoint};
use irpc::{
channel::{oneshot, spsc},
Client, Service, WithChannels,
};
// Import the macro
use irpc_derive::rpc_requests;
use irpc_iroh::{read_request, IrohRemoteConnection};
use serde::{Deserialize, Serialize};
use tracing::info;

const ALPN: &[u8] = b"storage-api/0";

/// A simple storage service, just to try it out
#[derive(Debug, Clone, Copy)]
struct StorageService;

impl Service for StorageService {}

#[derive(Debug, Serialize, Deserialize)]
struct Auth {
token: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct Get {
key: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct List;

#[derive(Debug, Serialize, Deserialize)]
struct Set {
key: String,
value: String,
}

#[derive(Debug, Serialize, Deserialize)]
struct SetMany;

// Use the macro to generate both the StorageProtocol and StorageMessage enums
// plus implement Channels for each type
#[rpc_requests(StorageService, message = StorageMessage)]
#[derive(Serialize, Deserialize)]
enum StorageProtocol {
#[rpc(tx=oneshot::Sender<Result<(), String>>)]
Auth(Auth),
#[rpc(tx=oneshot::Sender<Option<String>>)]
Get(Get),
#[rpc(tx=oneshot::Sender<()>)]
Set(Set),
#[rpc(tx=oneshot::Sender<u64>, rx=spsc::Receiver<(String, String)>)]
SetMany(SetMany),
#[rpc(tx=spsc::Sender<String>)]
List(List),
}

#[derive(Debug, Clone)]
pub struct StorageServer {
state: Arc<Mutex<BTreeMap<String, String>>>,
auth_token: String,
}

#[derive(Default)]
struct PeerState {
authed: bool,
}

impl ProtocolHandler for StorageServer {
fn accept(&self, conn: Connection) -> n0_future::future::Boxed<Result<()>> {
let this = self.clone();
Box::pin(async move {
let mut peer_state = PeerState::default();
while let Some((msg, rx, tx)) = read_request(&conn).await? {
// Upcast the send/receive streams to the channel types each message needs.
let msg: StorageMessage = match msg {
StorageProtocol::Auth(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::Get(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::Set(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::SetMany(msg) => WithChannels::from((msg, tx, rx)).into(),
StorageProtocol::List(msg) => WithChannels::from((msg, tx, rx)).into(),
};

// Handle the message
if let Err(err) = this.handle(&mut peer_state, msg).await {
match err {
Error::Unauthorized => conn.close(401u32.into(), b"unauthorized"),
Error::InvalidMessage => conn.close(400u32.into(), b"invalid message"),
}
break;
}
}
conn.closed().await;
Ok(())
})
}
}

enum Error {
Unauthorized,
InvalidMessage,
}

impl StorageServer {
pub const ALPN: &[u8] = ALPN;

pub fn new(auth_token: String) -> Self {
Self {
state: Default::default(),
auth_token,
}
}

async fn handle(
&self,
peer_state: &mut PeerState,
msg: StorageMessage,
) -> Result<(), Error> {
if !peer_state.authed && !matches!(msg, StorageMessage::Auth(_)) {
return Err(Error::InvalidMessage);
}
match msg {
StorageMessage::Auth(auth) => {
let WithChannels { tx, inner, .. } = auth;
if peer_state.authed {
return Err(Error::InvalidMessage);
} else if inner.token != self.auth_token {
return Err(Error::Unauthorized);
} else {
peer_state.authed = true;
tx.send(Ok(())).await.ok();
}
}
StorageMessage::Get(get) => {
info!("get {:?}", get);
let WithChannels { tx, inner, .. } = get;
let res = self.state.lock().unwrap().get(&inner.key).cloned();
tx.send(res).await.ok();
}
StorageMessage::Set(set) => {
info!("set {:?}", set);
let WithChannels { tx, inner, .. } = set;
self.state.lock().unwrap().insert(inner.key, inner.value);
tx.send(()).await.ok();
}
StorageMessage::SetMany(list) => {
let WithChannels { tx, mut rx, .. } = list;
let mut i = 0;
while let Ok(Some((key, value))) = rx.recv().await {
let mut state = self.state.lock().unwrap();
state.insert(key, value);
i += 1;
}
tx.send(i).await.ok();
}
StorageMessage::List(list) => {
info!("list {:?}", list);
let WithChannels { mut tx, .. } = list;
let values = {
let state = self.state.lock().unwrap();
// TODO: use async lock to not clone here.
let values: Vec<_> = state
.iter()
.map(|(key, value)| format!("{key}={value}"))
.collect();
values
};
for value in values {
if tx.send(value).await.is_err() {
break;
}
}
}
}
Ok(())
}
}

pub struct StorageClient {
inner: Client<StorageMessage, StorageProtocol, StorageService>,
}

impl StorageClient {
pub const ALPN: &[u8] = ALPN;

pub fn connect(endpoint: Endpoint, addr: impl Into<iroh::NodeAddr>) -> StorageClient {
let conn = IrohRemoteConnection::new(endpoint, addr.into(), Self::ALPN.to_vec());
StorageClient {
inner: Client::boxed(conn),
}
}

pub async fn auth(&self, token: &str) -> Result<(), anyhow::Error> {
self.inner
.rpc(Auth {
token: token.to_string(),
})
.await?
.map_err(|err| anyhow::anyhow!(err))
}

pub async fn get(&self, key: String) -> Result<Option<String>, irpc::Error> {
self.inner.rpc(Get { key }).await
}

pub async fn list(&self) -> Result<spsc::Receiver<String>, irpc::Error> {
self.inner.server_streaming(List, 10).await
}

pub async fn set(&self, key: String, value: String) -> Result<(), irpc::Error> {
let msg = Set { key, value };
self.inner.rpc(msg).await
}
}
}
55 changes: 31 additions & 24 deletions irpc-iroh/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -128,35 +128,42 @@ pub async fn handle_connection<R: DeserializeOwned + 'static>(
handler: Handler<R>,
) -> io::Result<()> {
loop {
let (send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Err(ConnectionError::ApplicationClosed(cause))
if cause.error_code.into_inner() == 0 =>
{
trace!("remote side closed connection {cause:?}");
return Ok(());
}
Err(cause) => {
warn!("failed to accept bi stream {cause:?}");
return Err(cause.into());
}
let Some((msg, rx, tx)) = read_request(&connection).await? else {
return Ok(());
};
let size = recv
.read_varint_u64()
.await?
.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
.await
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
let msg: R = postcard::from_bytes(&buf)
.map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let rx = recv;
let tx = send;
handler(msg, rx, tx).await?;
}
}

pub async fn read_request<R: DeserializeOwned + 'static>(
connection: &Connection,
) -> std::io::Result<Option<(R, RecvStream, SendStream)>> {
let (send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Err(ConnectionError::ApplicationClosed(cause)) if cause.error_code.into_inner() == 0 => {
trace!("remote side closed connection {cause:?}");
return Ok(None);
}
Err(cause) => {
warn!("failed to accept bi stream {cause:?}");
return Err(cause.into());
}
};
let size = recv
.read_varint_u64()
.await?
.ok_or_else(|| io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size"))?;
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
.await
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
let msg: R =
postcard::from_bytes(&buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
let rx = recv;
let tx = send;
Ok(Some((msg, rx, tx)))
}

/// Utility function to listen for incoming connections and handle them with the provided handler
pub async fn listen<R: DeserializeOwned + 'static>(endpoint: iroh::Endpoint, handler: Handler<R>) {
let mut request_id = 0u64;
Expand Down