diff --git a/cot-macros/src/lib.rs b/cot-macros/src/lib.rs index a9cd552b..62febdcb 100644 --- a/cot-macros/src/lib.rs +++ b/cot-macros/src/lib.rs @@ -10,7 +10,7 @@ use darling::ast::NestedMeta; use proc_macro::TokenStream; use proc_macro_crate::crate_name; use quote::quote; -use syn::{ItemFn, parse_macro_input}; +use syn::{Data, Field, Fields, ItemFn, parse_macro_input, punctuated::Punctuated}; use crate::admin::impl_admin_model_for_struct; use crate::dbtest::fn_to_dbtest; @@ -192,3 +192,58 @@ pub(crate) fn cot_ident() -> proc_macro2::TokenStream { } } } + +#[proc_macro_derive(FromRequestParts)] +pub fn derive_from_request_parts(input: TokenStream) -> TokenStream { + let ast = parse_macro_input!(input as syn::DeriveInput); + impl_from_request_parts_for_struct(&ast).into() +} + +fn impl_from_request_parts_for_struct(ast: &syn::DeriveInput) -> proc_macro2::TokenStream { + let struct_name = &ast.ident; + let cot = cot_ident(); + + let fields = match &ast.data { + Data::Struct(data_struct) => match &data_struct.fields { + Fields::Named(fields_named) => &fields_named.named, + Fields::Unnamed(_) => { + let err = Error::custom( + "Structs with unnamed fields are not supported for `FromRequestParts`", + ); + return err.write_errors().into(); + } + Fields::Unit => &Punctuated::new(), + }, + _ => { + let err = Error::custom("Only structs can derive `FromRequestParts`"); + return err.write_errors().into(); + } + }; + + let field_initializers = fields.iter().map(|field: &Field| { + let field_name = &field.ident; + let field_type = &field.ty; + + quote! { + #field_name: #cot::axum::extract::Extension::<#field_type>::from_request_parts(parts) + .await + .map(|ext| ext.0) + .map_err(|e| #cot::anyhow::anyhow!(e))?, + } + }); + + let expanded = quote! { + #[::core::automatically_derived] + impl #cot::axum::extract::FromRequestParts<#cot::http::Request, #cot::anyhow::Error> for #struct_name { + async fn from_request_parts( + parts: &mut #cot::axum::extract::RequestParts, + ) -> ::std::result::Result { + Ok(Self { + #(#field_initializers,)* + }) + } + } + }; + + expanded +} diff --git a/cot-macros/tests/ui/from_request_parts_test.rs b/cot-macros/tests/ui/from_request_parts_test.rs new file mode 100644 index 00000000..58b1fc10 --- /dev/null +++ b/cot-macros/tests/ui/from_request_parts_test.rs @@ -0,0 +1,75 @@ +use cot::axum::extract::RequestParts; +use cot::http::Request; +use cot_macros::FromRequestParts; + +#[tokio::test] +async fn test_derive_from_request_parts_success() { + #[derive(FromRequestParts, Debug, PartialEq)] + struct TestContext { + user_id: i32, + username: String, + } + + let mut parts = RequestParts::new(Request::default()); + parts.extensions.insert(10_i32); + parts.extensions.insert("test_user".to_string()); + + let context = TestContext::from_request_parts(&mut parts).await.unwrap(); + + assert_eq!( + context, + TestContext { + user_id: 10, + username: "test_user".to_string() + } + ); +} + +#[tokio::test] +async fn test_derive_from_request_parts_different_types() { + #[derive(FromRequestParts, Debug, PartialEq)] + struct TestContext { + value_i64: i64, + value_bool: bool, + value_string: String, + } + + let mut parts = RequestParts::new(Request::default()); + parts.extensions.insert(12345_i64); + parts.extensions.insert(true); + parts.extensions.insert("another_test".to_string()); + + let context = TestContext::from_request_parts(&mut parts).await.unwrap(); + + assert_eq!( + context, + TestContext { + value_i64: 12345, + value_bool: true, + value_string: "another_test".to_string() + } + ); +} + +#[tokio::test] +async fn test_derive_from_request_parts_missing_extension() { + #[derive(FromRequestParts, Debug)] + struct TestContext { + missing_value: i32, + } + + let mut parts = RequestParts::new(Request::default()); + + let result = TestContext::from_request_parts(&mut parts).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn test_derive_from_request_parts_empty_struct() { + #[derive(FromRequestParts, Debug, PartialEq)] + struct EmptyContext {} + + let mut parts = RequestParts::new(Request::default()); + let context = EmptyContext::from_request_parts(&mut parts).await.unwrap(); + assert_eq!(EmptyContext {}, context); +}