Skip to content
Snippets Groups Projects
Commit c1e51ac7 authored by Julia Andrews's avatar Julia Andrews
Browse files

Add more sophisticated "axum-like" routing

Lots of futzing with generics to allow handler functions of many
different signatures, whose arguments cause the API to automatically
extract and/or parse parts of the request, and whose return values cause
the API to automatically encode responses.

All of this still needs documentation and cleaning up/refactoring into
different modules.
parent 401c8be7
No related branches found
No related tags found
1 merge request!9Rust RPC module implementation
...@@ -566,6 +566,12 @@ dependencies = [ ...@@ -566,6 +566,12 @@ dependencies = [
"either", "either",
] ]
[[package]]
name = "itoa"
version = "1.0.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b"
[[package]] [[package]]
name = "js-sys" name = "js-sys"
version = "0.3.69" version = "0.3.69"
...@@ -898,6 +904,12 @@ dependencies = [ ...@@ -898,6 +904,12 @@ dependencies = [
"windows-sys 0.52.0", "windows-sys 0.52.0",
] ]
[[package]]
name = "ryu"
version = "1.0.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.202" version = "1.0.202"
...@@ -918,6 +930,17 @@ dependencies = [ ...@@ -918,6 +930,17 @@ dependencies = [
"syn 2.0.53", "syn 2.0.53",
] ]
[[package]]
name = "serde_json"
version = "1.0.120"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]] [[package]]
name = "sharded-slab" name = "sharded-slab"
version = "0.1.7" version = "0.1.7"
...@@ -1431,6 +1454,7 @@ dependencies = [ ...@@ -1431,6 +1454,7 @@ dependencies = [
"lazy_static", "lazy_static",
"libc", "libc",
"serde", "serde",
"serde_json",
"tokio", "tokio",
"tower", "tower",
"tracing", "tracing",
......
...@@ -5,7 +5,7 @@ use std::sync::Arc; ...@@ -5,7 +5,7 @@ use std::sync::Arc;
use base64::prelude::*; use base64::prelude::*;
use structopt::StructOpt; use structopt::StructOpt;
use xxdk::base::*; use xxdk::base::*;
use xxdk::service::*; use xxdk::service::{CMixServerConfig, SenderId, Utf8Lossy};
const SECRET: &str = "Hello"; const SECRET: &str = "Hello";
const REGISTRATION_CODE: &str = ""; const REGISTRATION_CODE: &str = "";
...@@ -56,17 +56,16 @@ pub async fn run() -> Result<(), String> { ...@@ -56,17 +56,16 @@ pub async fn run() -> Result<(), String> {
private_key: String::from(""), private_key: String::from(""),
}; };
let xx_router = xxdk::service::Router::new(Arc::new(cmix)).route("demo", xx_rpc_handler); let xx_router = xxdk::service::Router::with_state(Arc::new(cmix)).route("demo", xx_rpc_handler);
CMixServer::serve(xx_router, cmix_config).await xxdk::service::serve(xx_router, cmix_config).await
} }
pub async fn xx_rpc_handler(_: Arc<CMix>, request: IncomingRequest) -> Result<Vec<u8>, String> { pub async fn xx_rpc_handler(id: SenderId, req: Utf8Lossy) -> String {
let sender: String = request.sender_id.iter().fold(String::new(), |mut s, b| { let sender: String = id.0.iter().fold(String::new(), |mut s, b| {
write!(s, "{b:02x}").unwrap(); write!(s, "{b:02x}").unwrap();
s s
}); });
tracing::info!(sender, "Received message via cMix",); tracing::info!(sender, "Received message via cMix",);
let text = String::from_utf8_lossy(&request.request); let text = req.0;
format!("Hi from rust rpc example! Echoed message: {text}")
Ok(format!("Hi from rust rpc example! Echoed message: {text}").into_bytes())
} }
...@@ -10,6 +10,7 @@ base64 = "0.22.1" ...@@ -10,6 +10,7 @@ base64 = "0.22.1"
lazy_static = "1.4.0" lazy_static = "1.4.0"
libc = "0.2.153" libc = "0.2.153"
serde = { version = "1.0.202", features = ["derive"] } serde = { version = "1.0.202", features = ["derive"] }
serde_json = "1.0.120"
tokio = { version = "1.37.0", features = ["rt", "fs", "sync", "time"] } tokio = { version = "1.37.0", features = ["rt", "fs", "sync", "time"] }
tower = "0.4.13" tower = "0.4.13"
tracing = "0.1.40" tracing = "0.1.40"
......
use base64::prelude::*; use std::borrow::Cow;
use serde::Deserialize;
use std::collections::HashMap; use std::collections::HashMap;
use std::future::{self, poll_fn, Future}; use std::future::Future;
use std::pin::Pin; use std::pin::Pin;
use std::sync::Arc; use std::sync::Arc;
use std::task::Poll; use std::task::{Context, Poll};
use std::time::Duration; use std::time::Duration;
use tokio::sync::mpsc;
use base64::prelude::*;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use serde_json as json;
use tower::Service; use tower::Service;
use crate::base; use crate::{base, rpc};
use crate::rpc;
pub type PinnedFuture<T> = Pin<Box<dyn Future<Output = T> + Send + 'static>>;
#[derive(Debug, Clone)] #[derive(Debug, Clone)]
pub struct IncomingRequest { pub struct IncomingRequest {
pub sender_id: Vec<u8>, sender_id: Vec<u8>,
pub request: Vec<u8>, request: Vec<u8>,
separator_idx: usize,
} }
#[derive(Debug, Clone)]
struct Response { impl IncomingRequest {
pub text: Vec<u8>, fn new(sender_id: Vec<u8>, request: Vec<u8>) -> Result<Self, String> {
let separator_idx = request
.iter()
.position(|b| *b == b',')
.ok_or_else(|| "no endpoint in request".to_string())?;
std::str::from_utf8(&request[..separator_idx])
.map_err(|e| format!("non-UTF-8 endpoint: {e}"))?;
Ok(Self {
sender_id,
request,
separator_idx,
})
}
pub fn sender_id(&self) -> &[u8] {
&self.sender_id
}
pub fn endpoint(&self) -> &str {
unsafe { std::str::from_utf8_unchecked(&self.request[..self.separator_idx]) }
} }
type HandlerFnInner<T> = dyn Fn( pub fn request(&self) -> &[u8] {
Arc<T>, &self.request[self.separator_idx + 1..]
IncomingRequest, }
) -> Pin<Box<dyn Future<Output = Result<Vec<u8>, String>> + Send + 'static>> }
+ Send
+ Sync
+ 'static;
pub trait HandlerFn<T>: // TODO If we're a bit more careful about it, we can probably get rid of the Sync bound here
Fn(Arc<T>, IncomingRequest) -> Self::Future + Send + Sync + 'static pub trait Handler<T, S, Res>: Clone + Send + Sync + Sized + 'static {
fn call(self, req: IncomingRequest, state: S) -> PinnedFuture<Result<Vec<u8>, String>>;
}
macro_rules! impl_handler {
($($ty:ident),*) => {
impl<F, Fut, S, Res, $($ty),*> Handler<($($ty,)*), S, Res> for F
where
F: FnOnce($($ty),*) -> Fut + Clone + Send + Sync + Sized + 'static,
Fut: Future<Output = Res> + Send + 'static,
S: Send + 'static,
Res: IntoResponse,
$(
$ty: FromRequest<S>,
)*
{ {
type Future: Future<Output = Result<Vec<u8>, String>> + Send + 'static; #[allow(non_snake_case, unused_variables)]
fn call(self, req: IncomingRequest, state: S) -> PinnedFuture<Result<Vec<u8>, String>> {
Box::pin(async move {
$(
let $ty = $ty::extract(&req, &state)?;
)*
self($($ty),*).await.into_response()
})
}
}
};
}
macro_rules! tuples {
($name:ident) => {
$name!();
$name!(T1);
$name!(T1, T2);
$name!(T1, T2, T3);
$name!(T1, T2, T3, T4);
$name!(T1, T2, T3, T4, T5);
$name!(T1, T2, T3, T4, T5, T6);
$name!(T1, T2, T3, T4, T5, T6, T7);
$name!(T1, T2, T3, T4, T5, T6, T7, T8);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15);
$name!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, T14, T15, T16);
};
}
tuples!(impl_handler);
trait ErasedHandler<S>: Send + Sync + 'static {
fn call(&self, req: IncomingRequest, state: S) -> PinnedFuture<Result<Vec<u8>, String>>;
} }
impl<T, F, Fut> HandlerFn<T> for F struct MakeErasedHandler<H, S> {
handler: H,
#[allow(clippy::type_complexity)]
call: fn(H, IncomingRequest, S) -> PinnedFuture<Result<Vec<u8>, String>>,
}
impl<H, S> ErasedHandler<S> for MakeErasedHandler<H, S>
where where
F: Fn(Arc<T>, IncomingRequest) -> Fut + Send + Sync + 'static, H: Clone + Send + Sync + 'static,
Fut: Future<Output = Result<Vec<u8>, String>> + Send + 'static, S: 'static,
{ {
type Future = Fut; fn call(&self, req: IncomingRequest, state: S) -> PinnedFuture<Result<Vec<u8>, String>> {
let h = self.handler.clone();
(self.call)(h, req, state)
}
} }
pub struct Router<T> { impl<H, S> MakeErasedHandler<H, S> {
handlers: HashMap<String, Arc<HandlerFnInner<T>>>, fn make<T, Res>(handler: H) -> Self
state: Arc<T>, where
H: Handler<T, S, Res>,
{
let call = |h: H, req, state| h.call(req, state);
Self { handler, call }
}
} }
// Manual implementation to avoid derive putting a Clone bound on T type BoxedErasedHandler<S> = Arc<dyn ErasedHandler<S>>;
impl<T> Clone for Router<T> {
fn clone(&self) -> Self { #[derive(Clone)]
Self { pub struct Router<S> {
handlers: self.handlers.clone(), inner: Arc<RouterInner<S>>,
state: self.state.clone(),
} }
#[derive(Clone)]
struct RouterInner<S> {
handlers: HashMap<String, BoxedErasedHandler<S>>,
state: S,
}
impl Router<()> {
pub fn without_state() -> Self {
Self::with_state(())
} }
} }
impl<T> Router<T> { impl<S> Router<S>
pub fn new(state: Arc<T>) -> Self { where
Self { S: Send + Clone + 'static,
handlers: HashMap::new(), {
state, pub fn with_state(state: S) -> Self {
let handlers = HashMap::new();
let inner = Arc::new(RouterInner { handlers, state });
Self { inner }
} }
pub fn route<H, T, Res>(self, endpoint: &str, handler: H) -> Self
where
H: Handler<T, S, Res>,
{
let handler = Arc::new(MakeErasedHandler::make(handler));
self.with_inner(|inner| {
inner.handlers.insert(String::from(endpoint), handler);
})
} }
pub fn route<F>(mut self, endpoint: &str, handler: F) -> Self fn with_inner<F>(self, f: F) -> Self
where where
F: HandlerFn<T>, F: FnOnce(&mut RouterInner<S>),
{ {
self.handlers.insert( let mut inner = self.into_inner();
endpoint.to_string(), f(&mut inner);
Arc::new(move |state, req| Box::pin(handler(state, req))), Self {
); inner: Arc::new(inner),
self
} }
} }
impl<T> Service<IncomingRequest> for Router<T> fn into_inner(self) -> RouterInner<S> {
match Arc::try_unwrap(self.inner) {
Ok(inner) => inner,
Err(arc) => RouterInner::clone(&*arc),
}
}
}
impl<S> Service<IncomingRequest> for Router<S>
where where
T: Send + Sync + 'static, S: Clone + Send + 'static,
{ {
type Response = Vec<u8>; type Response = Vec<u8>;
type Error = String; type Error = String;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>; type Future = PinnedFuture<Result<Vec<u8>, String>>;
fn poll_ready(&mut self, _cx: &mut std::task::Context<'_>) -> Poll<Result<(), Self::Error>> { fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(())) Poll::Ready(Ok(()))
} }
fn call(&mut self, req: IncomingRequest) -> Self::Future { fn call(&mut self, req: IncomingRequest) -> Self::Future {
// TODO All this manual matching on Results is gross, see if we can find a cleaner way to let endpoint = req.endpoint();
// do this let handler = match self.inner.handlers.get(req.endpoint()) {
let separator_idx = match req Some(h) => h,
.request None => {
.iter() return Box::pin(std::future::ready(Err(format!(
.position(|b| *b == b',') "unrecognized endpoint `{endpoint}`"
.ok_or_else(|| "no endpoint in request".to_string()) ))))
{ }
Ok(idx) => idx,
Err(e) => return Box::pin(future::ready(Err(e))),
}; };
let (endpoint, request) = req.request.split_at(separator_idx);
let endpoint = let state = self.inner.state.clone();
match std::str::from_utf8(endpoint).map_err(|e| format!("non-UTF-8 endpoint ({e})")) { handler.call(req, state)
Ok(endpoint) => endpoint, }
Err(e) => return Box::pin(future::ready(Err(e))), }
};
if let Some(handler) = self.handlers.get(endpoint) { // TODO We can put a lifetime parameter on this to allow borrowing directly from the request buffer
let (_, request) = request.split_first().unwrap(); pub trait FromRequest<S>: Sized {
let req = IncomingRequest { fn extract(req: &IncomingRequest, state: &S) -> Result<Self, String>;
request: request.into(), }
sender_id: req.sender_id,
};
handler(self.state.clone(), req) pub trait IntoResponse {
} else { fn into_response(self) -> Result<Vec<u8>, String>;
Box::pin(future::ready(Err(format!( }
"unrecognized endpoint `{endpoint}`"
)))) impl<R> IntoResponse for Result<R, String>
where
R: IntoResponse,
{
fn into_response(self) -> Result<Vec<u8>, String> {
self.and_then(|r| r.into_response())
}
}
impl IntoResponse for () {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(Vec::new())
}
}
impl IntoResponse for Vec<u8> {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(self)
}
} }
impl IntoResponse for &[u8] {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(Vec::from(self))
}
}
impl<const N: usize> IntoResponse for [u8; N] {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(Vec::from(&self))
}
}
impl IntoResponse for Cow<'_, [u8]> {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(self.into_owned())
}
}
impl IntoResponse for String {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(self.into_bytes())
}
}
impl IntoResponse for &str {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(Vec::from(self.as_bytes()))
}
}
impl IntoResponse for Cow<'_, str> {
fn into_response(self) -> Result<Vec<u8>, String> {
Ok(self.into_owned().into_bytes())
}
}
#[derive(Debug, Clone, Copy)]
pub struct Json<T>(pub T);
impl<T, S> FromRequest<S> for Json<T>
where
T: DeserializeOwned,
{
fn extract(req: &IncomingRequest, _state: &S) -> Result<Self, String> {
Ok(Self(
json::from_slice(req.request()).map_err(|e| e.to_string())?,
))
}
}
impl<T> IntoResponse for Json<T>
where
T: Serialize,
{
fn into_response(self) -> Result<Vec<u8>, String> {
json::to_vec(&self.0).map_err(|e| e.to_string())
}
}
#[derive(Debug, Clone)]
pub struct SenderId(pub Vec<u8>);
impl<S> FromRequest<S> for SenderId {
fn extract(req: &IncomingRequest, _state: &S) -> Result<Self, String> {
Ok(Self(req.sender_id.clone()))
}
}
#[derive(Debug, Clone)]
pub struct RawRequest(pub Vec<u8>);
impl<S> FromRequest<S> for RawRequest {
fn extract(req: &IncomingRequest, _state: &S) -> Result<Self, String> {
Ok(Self(Vec::from(req.request())))
}
}
#[derive(Debug, Clone)]
pub struct Utf8(pub String);
impl<S> FromRequest<S> for Utf8 {
fn extract(req: &IncomingRequest, _state: &S) -> Result<Self, String> {
Ok(Self(String::from(
std::str::from_utf8(req.request()).map_err(|e| e.to_string())?,
)))
}
}
#[derive(Debug, Clone)]
pub struct Utf8Lossy(pub String);
impl<S> FromRequest<S> for Utf8Lossy {
fn extract(req: &IncomingRequest, _state: &S) -> Result<Self, String> {
Ok(Self(String::from_utf8_lossy(req.request()).into_owned()))
}
}
#[derive(Debug, Clone)]
pub struct State<S>(pub S);
impl<S> FromRequest<S> for State<S>
where
S: Clone,
{
fn extract(_req: &IncomingRequest, state: &S) -> Result<Self, String> {
Ok(Self(state.clone()))
} }
} }
...@@ -138,13 +374,9 @@ pub struct CMixServerConfig { ...@@ -138,13 +374,9 @@ pub struct CMixServerConfig {
pub private_key: String, pub private_key: String,
} }
#[derive(Debug)] pub async fn serve<S>(service: S, config: CMixServerConfig) -> Result<(), String>
pub struct CMixServer;
impl CMixServer {
pub async fn serve<T>(router: Router<T>, config: CMixServerConfig) -> Result<(), String>
where where
T: Send + Sync + 'static, S: Service<IncomingRequest, Response = Vec<u8>, Error = String> + Clone + Send + 'static,
{ {
tracing::info!("Starting cMix server"); tracing::info!("Starting cMix server");
let ndf_contents = tokio::fs::read_to_string(&config.ndf_path) let ndf_contents = tokio::fs::read_to_string(&config.ndf_path)
...@@ -231,16 +463,7 @@ impl CMixServer { ...@@ -231,16 +463,7 @@ impl CMixServer {
cmix.ekv_set("rpc_server_private_key", &private_key)?; cmix.ekv_set("rpc_server_private_key", &private_key)?;
let runtime = tokio::runtime::Handle::current(); let runtime = tokio::runtime::Handle::current();
let (sender, mut response_queue) = mpsc::channel(256); let cbs = CMixServerCallback { service, runtime };
let cbs = CMixServerCallbacks {
router,
runtime,
response_queue: sender,
};
tracing::info!("Spawning RPC server");
base::callbacks::set_rpc_callbacks();
let rpc_server = rpc::new_server(&cmix, cbs, reception_id, private_key)?;
let cmix = Arc::new(cmix); let cmix = Arc::new(cmix);
tokio::task::spawn_blocking({ tokio::task::spawn_blocking({
...@@ -262,8 +485,10 @@ impl CMixServer { ...@@ -262,8 +485,10 @@ impl CMixServer {
tokio::time::sleep(Duration::from_secs(1)).await; tokio::time::sleep(Duration::from_secs(1)).await;
} }
tracing::info!("Spawning RPC server");
base::callbacks::set_rpc_callbacks();
let rpc_server = rpc::new_server(&cmix, cbs, reception_id, private_key)?;
rpc_server.start(); rpc_server.start();
tracing::info!( tracing::info!(
"RPC Server CB PTR: {:#x}", "RPC Server CB PTR: {:#x}",
rpc_server.cb as *const _ as *const libc::c_void as usize rpc_server.cb as *const _ as *const libc::c_void as usize
...@@ -272,61 +497,50 @@ impl CMixServer { ...@@ -272,61 +497,50 @@ impl CMixServer {
tracing::info!("RPC Public Key: {public_key_b64}"); tracing::info!("RPC Public Key: {public_key_b64}");
tracing::info!("RPC Reception ID: {reception_id_b64}"); tracing::info!("RPC Reception ID: {reception_id_b64}");
while let Some(resp) = response_queue.recv().await { // TODO We need a better way to shut down the server. This never actually completes or gets
tokio::spawn(async move { // past this line, it just runs until the process gets a kill signal.
tracing::debug!("request received, sending response"); std::future::pending::<()>().await;
tracing::debug!("{}", String::from_utf8_lossy(&resp.text));
});
}
// rpc_server.stop(); rpc_server.stop();
cmix.stop_network_follower() cmix.stop_network_follower()
} }
}
unsafe impl Send for CMixServer {}
struct CMixServerCallbacks<T> { struct CMixServerCallback<S> {
router: Router<T>, service: S,
runtime: tokio::runtime::Handle, runtime: tokio::runtime::Handle,
response_queue: mpsc::Sender<Response>,
} }
impl<T> rpc::ServerCallback for CMixServerCallbacks<T> impl<S> rpc::ServerCallback for CMixServerCallback<S>
where where
T: Send + Sync + 'static, S: Service<IncomingRequest, Response = Vec<u8>, Error = String> + Clone + Send + 'static,
{ {
#[tracing::instrument(level = "debug", skip(self))]
fn serve_req(&self, sender_id: Vec<u8>, request: Vec<u8>) -> Vec<u8> { fn serve_req(&self, sender_id: Vec<u8>, request: Vec<u8>) -> Vec<u8> {
let mut router = self.router.clone(); let mut service = self.service.clone();
let response_queue = self.response_queue.clone(); let res: Result<Vec<u8>, String> = self.runtime.block_on(async move {
let req = IncomingRequest { sender_id, request }; tracing::debug!("evaluating service on request");
let r = self.runtime.block_on(async { if std::future::poll_fn(|cx| service.poll_ready(cx))
let ret: Vec<u8>;
tracing::debug!("Evaluating router on request");
if poll_fn(|cx| router.poll_ready(cx)).await.is_ok() {
match router.call(req).await {
Err(e) => {
tracing::warn!(error = e, "Error in servicing request");
ret = e.into_bytes();
}
Ok(resp) => {
if response_queue
.send(Response { text: resp.clone() })
.await .await
.is_err() .is_ok()
{ {
tracing::warn!("couldn't send to queue"); let req = IncomingRequest::new(sender_id, request)?;
}; service.call(req).await
ret = resp;
}
}
} else { } else {
ret = String::from("error unable to service request").into_bytes(); Err("unable to service request".to_string())
} }
ret
}); });
tracing::warn!("sending response: {}", String::from_utf8_lossy(&r));
return r; let res = match res {
Ok(bytes) => bytes,
Err(text) => {
tracing::warn!(error = text, "error servicing request");
text.into_bytes()
}
};
tracing::info!(
res = String::from_utf8_lossy(&res).as_ref(),
"sending response"
);
res
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment