diff --git a/macros/Cargo.toml b/macros/Cargo.toml index 12fa2bb3..c8627251 100644 --- a/macros/Cargo.toml +++ b/macros/Cargo.toml @@ -16,3 +16,4 @@ proc-macro = true [dependencies] syn = "2.0" quote = "1.0" +proc-macro2 = "1.0.93" diff --git a/macros/src/buffer.rs b/macros/src/buffer.rs new file mode 100644 index 00000000..6ab201dc --- /dev/null +++ b/macros/src/buffer.rs @@ -0,0 +1,266 @@ +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::{parse_quote, Field, Generics, Ident, ItemStruct, Type, TypePath}; + +use crate::Result; + +pub(crate) fn impl_joined_value(input_struct: &ItemStruct) -> Result { + let struct_ident = &input_struct.ident; + let (impl_generics, ty_generics, where_clause) = input_struct.generics.split_for_impl(); + let StructConfig { + buffer_struct_name: buffer_struct_ident, + } = StructConfig::from_data_struct(&input_struct); + let buffer_struct_vis = &input_struct.vis; + + let (field_ident, _, field_config) = get_fields_map(&input_struct.fields)?; + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let noncopy = field_config.iter().any(|config| config.noncopy); + + let buffer_struct: ItemStruct = parse_quote! { + #[allow(non_camel_case_types, unused)] + #buffer_struct_vis struct #buffer_struct_ident #impl_generics #where_clause { + #( + #buffer_struct_vis #field_ident: #buffer, + )* + } + }; + + let buffer_clone_impl = if noncopy { + // Clone impl for structs with a buffer that is not copyable + quote! { + impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause { + fn clone(&self) -> Self { + Self { + #( + #field_ident: self.#field_ident.clone(), + )* + } + } + } + } + } else { + // Clone and copy impl for structs with buffers that are all copyable + quote! { + impl #impl_generics ::std::clone::Clone for #buffer_struct_ident #ty_generics #where_clause { + fn clone(&self) -> Self { + *self + } + } + + impl #impl_generics ::std::marker::Copy for #buffer_struct_ident #ty_generics #where_clause {} + } + }; + + let impl_buffer_map_layout = impl_buffer_map_layout(&buffer_struct, &input_struct)?; + let impl_joined = impl_joined(&buffer_struct, &input_struct)?; + + let gen = quote! { + impl #impl_generics ::bevy_impulse::JoinedValue for #struct_ident #ty_generics #where_clause { + type Buffers = #buffer_struct_ident #ty_generics; + } + + #buffer_struct + + #buffer_clone_impl + + impl #impl_generics #struct_ident #ty_generics #where_clause { + fn select_buffers( + #( + #field_ident: #buffer, + )* + ) -> #buffer_struct_ident #ty_generics { + #buffer_struct_ident { + #( + #field_ident, + )* + } + } + } + + #impl_buffer_map_layout + + #impl_joined + }; + + Ok(gen.into()) +} + +/// Code that are currently unused but could be used in the future, move them out of this mod if +/// they are ever used. +#[allow(unused)] +mod _unused { + use super::*; + + /// Converts a list of generics to a [`PhantomData`] TypePath. + /// e.g. `::std::marker::PhantomData` + fn to_phantom_data(generics: &Generics) -> TypePath { + let lifetimes: Vec = generics + .lifetimes() + .map(|lt| { + let lt = <.lifetime; + let ty: Type = parse_quote! { & #lt () }; + ty + }) + .collect(); + let ty_params: Vec<&Ident> = generics.type_params().map(|ty| &ty.ident).collect(); + parse_quote! { ::std::marker::PhantomData } + } +} + +struct StructConfig { + buffer_struct_name: Ident, +} + +impl StructConfig { + fn from_data_struct(data_struct: &ItemStruct) -> Self { + let mut config = Self { + buffer_struct_name: format_ident!("__bevy_impulse_{}_Buffers", data_struct.ident), + }; + + let attr = data_struct + .attrs + .iter() + .find(|attr| attr.path().is_ident("joined")); + + if let Some(attr) = attr { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("buffers_struct_name") { + config.buffer_struct_name = meta.value()?.parse()?; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. + .unwrap(); + } + + config + } +} + +struct FieldConfig { + buffer: Type, + noncopy: bool, +} + +impl FieldConfig { + fn from_field(field: &Field) -> Self { + let ty = &field.ty; + let mut config = Self { + buffer: parse_quote! { ::bevy_impulse::Buffer<#ty> }, + noncopy: false, + }; + + for attr in field + .attrs + .iter() + .filter(|attr| attr.path().is_ident("joined")) + { + attr.parse_nested_meta(|meta| { + if meta.path.is_ident("buffer") { + config.buffer = meta.value()?.parse()?; + } + if meta.path.is_ident("noncopy_buffer") { + config.noncopy = true; + } + Ok(()) + }) + // panic if attribute is malformed, this will result in a compile error which is intended. + .unwrap(); + } + + config + } +} + +fn get_fields_map(fields: &syn::Fields) -> Result<(Vec<&Ident>, Vec<&Type>, Vec)> { + match fields { + syn::Fields::Named(data) => { + let mut idents = Vec::new(); + let mut types = Vec::new(); + let mut configs = Vec::new(); + for field in &data.named { + let ident = field + .ident + .as_ref() + .ok_or("expected named fields".to_string())?; + idents.push(ident); + types.push(&field.ty); + configs.push(FieldConfig::from_field(field)); + } + Ok((idents, types, configs)) + } + _ => return Err("expected named fields".to_string()), + } +} + +/// Params: +/// buffer_struct: The struct to implement `BufferMapLayout`. +/// item_struct: The struct which `buffer_struct` is derived from. +fn impl_buffer_map_layout( + buffer_struct: &ItemStruct, + item_struct: &ItemStruct, +) -> Result { + let struct_ident = &buffer_struct.ident; + let (impl_generics, ty_generics, where_clause) = buffer_struct.generics.split_for_impl(); + let (field_ident, _, field_config) = get_fields_map(&item_struct.fields)?; + let buffer: Vec<&Type> = field_config.iter().map(|config| &config.buffer).collect(); + let map_key: Vec = field_ident.iter().map(|v| v.to_string()).collect(); + + Ok(quote! { + impl #impl_generics ::bevy_impulse::BufferMapLayout for #struct_ident #ty_generics #where_clause { + fn buffer_list(&self) -> ::smallvec::SmallVec<[AnyBuffer; 8]> { + use smallvec::smallvec; + smallvec![#( + self.#field_ident.as_any_buffer(), + )*] + } + + fn try_from_buffer_map(buffers: &::bevy_impulse::BufferMap) -> Result { + let mut compatibility = ::bevy_impulse::IncompatibleLayout::default(); + #( + let #field_ident = if let Ok(buffer) = compatibility.require_buffer_type::<#buffer>(#map_key, buffers) { + buffer + } else { + return Err(compatibility); + }; + )* + + Ok(Self { + #( + #field_ident, + )* + }) + } + } + } + .into()) +} + +/// Params: +/// joined_struct: The struct to implement `Joined`. +/// item_struct: The associated `Item` type to use for the `Joined` implementation. +fn impl_joined( + joined_struct: &ItemStruct, + item_struct: &ItemStruct, +) -> Result { + let struct_ident = &joined_struct.ident; + let item_struct_ident = &item_struct.ident; + let (impl_generics, ty_generics, where_clause) = item_struct.generics.split_for_impl(); + let (field_ident, _, _) = get_fields_map(&item_struct.fields)?; + + Ok(quote! { + impl #impl_generics ::bevy_impulse::Joined for #struct_ident #ty_generics #where_clause { + type Item = #item_struct_ident #ty_generics; + + fn pull(&self, session: ::bevy_ecs::prelude::Entity, world: &mut ::bevy_ecs::prelude::World) -> Result { + #( + let #field_ident = self.#field_ident.pull(session, world)?; + )* + + Ok(Self::Item {#( + #field_ident, + )*}) + } + } + }.into()) +} diff --git a/macros/src/lib.rs b/macros/src/lib.rs index d40c9309..df58fdc6 100644 --- a/macros/src/lib.rs +++ b/macros/src/lib.rs @@ -15,9 +15,12 @@ * */ +mod buffer; +use buffer::impl_joined_value; + use proc_macro::TokenStream; use quote::quote; -use syn::DeriveInput; +use syn::{parse_macro_input, DeriveInput, ItemStruct}; #[proc_macro_derive(Stream)] pub fn simple_stream_macro(item: TokenStream) -> TokenStream { @@ -58,3 +61,18 @@ pub fn delivery_label_macro(item: TokenStream) -> TokenStream { } .into() } + +/// The result error is the compiler error message to be displayed. +type Result = std::result::Result; + +#[proc_macro_derive(JoinedValue, attributes(joined))] +pub fn derive_joined_value(input: TokenStream) -> TokenStream { + let input = parse_macro_input!(input as ItemStruct); + match impl_joined_value(&input) { + Ok(tokens) => tokens.into(), + Err(msg) => quote! { + compile_error!(#msg); + } + .into(), + } +} diff --git a/src/buffer/any_buffer.rs b/src/buffer/any_buffer.rs index 32529079..b4315fe8 100644 --- a/src/buffer/any_buffer.rs +++ b/src/buffer/any_buffer.rs @@ -121,6 +121,10 @@ impl AnyBuffer { .ok() .map(|x| *x) } + + pub fn as_any_buffer(&self) -> Self { + self.clone().into() + } } impl From> for AnyBuffer { @@ -857,6 +861,17 @@ impl AnyBufferAccessImpl { })), ); + // Allow downcasting back to the original Buffer + buffer_downcasts.insert( + TypeId::of::>(), + Box::leak(Box::new(|location| -> Box { + Box::new(Buffer:: { + location, + _ignore: Default::default(), + }) + })), + ); + let mut key_downcasts: HashMap<_, KeyDowncastRef> = HashMap::new(); // Automatically register a downcast to AnyBufferKey diff --git a/src/buffer/buffer_map.rs b/src/buffer/buffer_map.rs index 24569d44..c522dbfe 100644 --- a/src/buffer/buffer_map.rs +++ b/src/buffer/buffer_map.rs @@ -29,6 +29,8 @@ use crate::{ OperationError, OperationResult, OperationRoster, Output, UnusedTarget, }; +pub use bevy_impulse_derive::JoinedValue; + #[derive(Clone, Default)] pub struct BufferMap { inner: HashMap, AnyBuffer>, @@ -311,76 +313,16 @@ impl Accessed for BufferMap { #[cfg(test)] mod tests { - use crate::{prelude::*, testing::*, BufferMap, OperationError}; - - use bevy_ecs::prelude::World; + use crate::{prelude::*, testing::*, BufferMap}; - #[derive(Clone)] - struct TestJoinedValue { + #[derive(JoinedValue)] + struct TestJoinedValue { integer: i64, float: f64, string: String, - } - - impl JoinedValue for TestJoinedValue { - type Buffers = TestJoinedValueBuffers; - } - - #[derive(Clone)] - struct TestJoinedValueBuffers { - integer: Buffer, - float: Buffer, - string: Buffer, - } - - impl BufferMapLayout for TestJoinedValueBuffers { - fn buffer_list(&self) -> smallvec::SmallVec<[AnyBuffer; 8]> { - use smallvec::smallvec; - smallvec![ - self.integer.as_any_buffer(), - self.float.as_any_buffer(), - self.string.as_any_buffer(), - ] - } - - fn try_from_buffer_map(buffers: &BufferMap) -> Result { - let mut compatibility = IncompatibleLayout::default(); - let integer = compatibility.require_message_type::("integer", buffers); - let float = compatibility.require_message_type::("float", buffers); - let string = compatibility.require_message_type::("string", buffers); - - let Ok(integer) = integer else { - return Err(compatibility); - }; - let Ok(float) = float else { - return Err(compatibility); - }; - let Ok(string) = string else { - return Err(compatibility); - }; - - Ok(Self { - integer, - float, - string, - }) - } - } - - impl crate::Joined for TestJoinedValueBuffers { - type Item = TestJoinedValue; - - fn pull(&self, session: Entity, world: &mut World) -> Result { - let integer = self.integer.pull(session, world)?; - let float = self.float.pull(session, world)?; - let string = self.string.pull(session, world)?; - - Ok(TestJoinedValue { - integer, - float, - string, - }) - } + generic: T, + #[joined(buffer = AnyBuffer)] + any: AnyMessageBox, } #[test] @@ -391,16 +333,22 @@ mod tests { let buffer_i64 = builder.create_buffer(BufferSettings::default()); let buffer_f64 = builder.create_buffer(BufferSettings::default()); let buffer_string = builder.create_buffer(BufferSettings::default()); + let buffer_generic = builder.create_buffer(BufferSettings::default()); + let buffer_any = builder.create_buffer(BufferSettings::default()); let mut buffers = BufferMap::default(); buffers.insert("integer", buffer_i64); buffers.insert("float", buffer_f64); buffers.insert("string", buffer_string); + buffers.insert("generic", buffer_generic); + buffers.insert("any", buffer_any); scope.input.chain(builder).fork_unzip(( |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), )); builder.try_join(&buffers).unwrap().connect(scope.terminate); @@ -408,15 +356,20 @@ mod tests { let mut promise = context.command(|commands| { commands - .request((5_i64, 3.14_f64, "hello".to_string()), workflow) + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) .take_response() }); context.run_with_conditions(&mut promise, Duration::from_secs(2)); - let value: TestJoinedValue = promise.take().available().unwrap(); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); assert_eq!(value.integer, 5); assert_eq!(value.float, 3.14); assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); assert!(context.no_unhandled_errors()); } @@ -425,32 +378,129 @@ mod tests { let mut context = TestingContext::minimal_plugins(); let workflow = context.spawn_io_workflow(|scope, builder| { - let buffers = TestJoinedValueBuffers { - integer: builder.create_buffer(BufferSettings::default()), - float: builder.create_buffer(BufferSettings::default()), - string: builder.create_buffer(BufferSettings::default()), - }; + let buffer_i64 = builder.create_buffer(BufferSettings::default()); + let buffer_f64 = builder.create_buffer(BufferSettings::default()); + let buffer_string = builder.create_buffer(BufferSettings::default()); + let buffer_generic = builder.create_buffer(BufferSettings::default()); + let buffer_any = builder.create_buffer::(BufferSettings::default()); scope.input.chain(builder).fork_unzip(( - |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), - |chain: Chain<_>| chain.connect(buffers.float.input_slot()), - |chain: Chain<_>| chain.connect(buffers.string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_i64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_f64.input_slot()), + |chain: Chain<_>| chain.connect(buffer_string.input_slot()), + |chain: Chain<_>| chain.connect(buffer_generic.input_slot()), + |chain: Chain<_>| chain.connect(buffer_any.input_slot()), )); + let buffers = TestJoinedValue::select_buffers( + buffer_i64, + buffer_f64, + buffer_string, + buffer_generic, + buffer_any.into(), + ); + builder.join(buffers).connect(scope.terminate); }); let mut promise = context.command(|commands| { commands - .request((5_i64, 3.14_f64, "hello".to_string()), workflow) + .request( + (5_i64, 3.14_f64, "hello".to_string(), "world", 42_i64), + workflow, + ) .take_response() }); context.run_with_conditions(&mut promise, Duration::from_secs(2)); - let value: TestJoinedValue = promise.take().available().unwrap(); + let value: TestJoinedValue<&'static str> = promise.take().available().unwrap(); assert_eq!(value.integer, 5); assert_eq!(value.float, 3.14); assert_eq!(value.string, "hello"); + assert_eq!(value.generic, "world"); + assert_eq!(*value.any.downcast::().unwrap(), 42); + assert!(context.no_unhandled_errors()); + } + + #[derive(Clone, JoinedValue)] + #[joined(buffers_struct_name = FooBuffers)] + struct TestDeriveWithConfig {} + + #[test] + fn test_derive_with_config() { + // a compile test to check that the name of the generated struct is correct + fn _check_buffer_struct_name(_: FooBuffers) {} + } + + struct MultiGenericValue { + t: T, + u: U, + } + + #[derive(JoinedValue)] + #[joined(buffers_struct_name = MultiGenericBuffers)] + struct JoinedMultiGenericValue { + #[joined(buffer = Buffer>)] + a: MultiGenericValue, + b: String, + } + + #[test] + fn test_multi_generic_joined_value() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow( + |scope: Scope<(i32, String), JoinedMultiGenericValue>, builder| { + let multi_generic_buffers = MultiGenericBuffers:: { + a: builder.create_buffer(BufferSettings::default()), + b: builder.create_buffer(BufferSettings::default()), + }; + + let copy = multi_generic_buffers; + + scope + .input + .chain(builder) + .map_block(|(integer, string)| { + ( + MultiGenericValue { + t: integer, + u: string.clone(), + }, + string, + ) + }) + .fork_unzip(( + |a: Chain<_>| a.connect(multi_generic_buffers.a.input_slot()), + |b: Chain<_>| b.connect(multi_generic_buffers.b.input_slot()), + )); + + multi_generic_buffers.join(builder).connect(scope.terminate); + copy.join(builder).connect(scope.terminate); + }, + ); + + let mut promise = context.command(|commands| { + commands + .request((5, "hello".to_string()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value = promise.take().available().unwrap(); + assert_eq!(value.a.t, 5); + assert_eq!(value.a.u, "hello"); + assert_eq!(value.b, "hello"); assert!(context.no_unhandled_errors()); } + + /// We create this struct just to verify that it is able to compile despite + /// NonCopyBuffer not being copyable. + #[derive(JoinedValue)] + #[allow(unused)] + struct JoinedValueForNonCopyBuffer { + #[joined(buffer = NonCopyBuffer, noncopy_buffer)] + _a: String, + _b: u32, + } } diff --git a/src/buffer/json_buffer.rs b/src/buffer/json_buffer.rs index 181fb726..d470da01 100644 --- a/src/buffer/json_buffer.rs +++ b/src/buffer/json_buffer.rs @@ -1353,78 +1353,15 @@ mod tests { assert!(context.no_unhandled_errors()); } - #[derive(Clone)] + #[derive(Clone, JoinedValue)] + #[joined(buffers_struct_name = TestJoinedValueJsonBuffers)] struct TestJoinedValueJson { integer: i64, float: f64, + #[joined(buffer = JsonBuffer)] json: JsonMessage, } - #[derive(Clone)] - struct TestJoinedValueJsonBuffers { - integer: Buffer, - float: Buffer, - json: JsonBuffer, - } - - impl JoinedValue for TestJoinedValueJson { - type Buffers = TestJoinedValueJsonBuffers; - } - - impl BufferMapLayout for TestJoinedValueJsonBuffers { - fn buffer_list(&self) -> smallvec::SmallVec<[AnyBuffer; 8]> { - use smallvec::smallvec; - smallvec![ - self.integer.as_any_buffer(), - self.float.as_any_buffer(), - self.json.as_any_buffer(), - ] - } - - fn try_from_buffer_map(buffers: &BufferMap) -> Result { - let mut compatibility = IncompatibleLayout::default(); - let integer = compatibility.require_message_type::("integer", buffers); - let float = compatibility.require_message_type::("float", buffers); - let json = compatibility.require_buffer_type::("json", buffers); - - let Ok(integer) = integer else { - return Err(compatibility); - }; - let Ok(float) = float else { - return Err(compatibility); - }; - let Ok(json) = json else { - return Err(compatibility); - }; - - Ok(Self { - integer, - float, - json, - }) - } - } - - impl crate::Joined for TestJoinedValueJsonBuffers { - type Item = TestJoinedValueJson; - - fn pull( - &self, - session: Entity, - world: &mut World, - ) -> Result { - let integer = self.integer.pull(session, world)?; - let float = self.float.pull(session, world)?; - let json = self.json.pull(session, world)?; - - Ok(TestJoinedValueJson { - integer, - float, - json, - }) - } - } - #[test] fn test_try_join_json() { let mut context = TestingContext::minimal_plugins(); @@ -1502,4 +1439,43 @@ mod tests { let expected_json = TestMessage::new(); assert_eq!(deserialized_json, expected_json); } + + #[test] + fn test_select_buffers_json() { + let mut context = TestingContext::minimal_plugins(); + + let workflow = context.spawn_io_workflow(|scope, builder| { + let buffer_integer = builder.create_buffer::(BufferSettings::default()); + let buffer_float = builder.create_buffer::(BufferSettings::default()); + let buffer_json = + JsonBuffer::from(builder.create_buffer::(BufferSettings::default())); + + let buffers = + TestJoinedValueJson::select_buffers(buffer_integer, buffer_float, buffer_json); + + scope.input.chain(builder).fork_unzip(( + |chain: Chain<_>| chain.connect(buffers.integer.input_slot()), + |chain: Chain<_>| chain.connect(buffers.float.input_slot()), + |chain: Chain<_>| { + chain.connect(buffers.json.downcast_for_message().unwrap().input_slot()) + }, + )); + + builder.join(buffers).connect(scope.terminate); + }); + + let mut promise = context.command(|commands| { + commands + .request((5_i64, 3.14_f64, TestMessage::new()), workflow) + .take_response() + }); + + context.run_with_conditions(&mut promise, Duration::from_secs(2)); + let value: TestJoinedValueJson = promise.take().available().unwrap(); + assert_eq!(value.integer, 5); + assert_eq!(value.float, 3.14); + let deserialized_json: TestMessage = serde_json::from_value(value.json).unwrap(); + let expected_json = TestMessage::new(); + assert_eq!(deserialized_json, expected_json); + } } diff --git a/src/lib.rs b/src/lib.rs index 417bd59b..4853bd93 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -148,6 +148,8 @@ pub use trim::*; use bevy_app::prelude::{App, Plugin, Update}; use bevy_ecs::prelude::{Entity, In}; +extern crate self as bevy_impulse; + /// Use `BlockingService` to indicate that your system is a blocking [`Service`]. /// /// A blocking service will have exclusive world access while it runs, which diff --git a/src/testing.rs b/src/testing.rs index 83d21dd9..a5ad4ec1 100644 --- a/src/testing.rs +++ b/src/testing.rs @@ -19,7 +19,7 @@ use bevy_app::ScheduleRunnerPlugin; pub use bevy_app::{App, Update}; use bevy_core::{FrameCountPlugin, TaskPoolPlugin, TypeRegistrationPlugin}; pub use bevy_ecs::{ - prelude::{Commands, Component, Entity, In, Local, Query, ResMut, Resource}, + prelude::{Commands, Component, Entity, In, Local, Query, ResMut, Resource, World}, system::{CommandQueue, IntoSystem}, }; use bevy_time::TimePlugin; @@ -32,10 +32,11 @@ pub use std::time::{Duration, Instant}; use smallvec::SmallVec; use crate::{ - flush_impulses, AddContinuousServicesExt, AsyncServiceInput, BlockingMap, BlockingServiceInput, - Builder, ContinuousQuery, ContinuousQueueView, ContinuousService, FlushParameters, - GetBufferedSessionsFn, Promise, RunCommandsOnWorldExt, Scope, Service, SpawnWorkflowExt, - StreamOf, StreamPack, UnhandledErrors, WorkflowSettings, + flush_impulses, Accessed, AddContinuousServicesExt, AnyBuffer, AsyncServiceInput, BlockingMap, + BlockingServiceInput, Buffer, BufferKey, Bufferable, Buffered, Builder, ContinuousQuery, + ContinuousQueueView, ContinuousService, FlushParameters, GetBufferedSessionsFn, Joined, + OperationError, OperationResult, OperationRoster, Promise, RunCommandsOnWorldExt, Scope, + Service, SpawnWorkflowExt, StreamOf, StreamPack, UnhandledErrors, WorkflowSettings, }; pub struct TestingContext { @@ -478,3 +479,104 @@ pub struct TestComponent; pub struct Integer { pub value: i32, } + +/// This is an ordinary buffer newtype whose only purpose is to test the +/// #[joined(noncopy_buffer)] feature. We intentionally do not implement +/// the Copy trait for it. +pub struct NonCopyBuffer { + inner: Buffer, +} + +impl NonCopyBuffer { + pub fn register_downcast() { + let any_interface = AnyBuffer::interface_for::(); + any_interface.register_buffer_downcast( + std::any::TypeId::of::>(), + Box::new(|location| { + Box::new(NonCopyBuffer:: { + inner: Buffer { + location, + _ignore: Default::default(), + }, + }) + }), + ); + } +} + +impl NonCopyBuffer { + pub fn as_any_buffer(&self) -> AnyBuffer { + self.inner.as_any_buffer() + } +} + +impl Clone for NonCopyBuffer { + fn clone(&self) -> Self { + Self { inner: self.inner } + } +} + +impl Bufferable for NonCopyBuffer { + type BufferType = Self; + fn into_buffer(self, _builder: &mut Builder) -> Self::BufferType { + self + } +} + +impl Buffered for NonCopyBuffer { + fn add_listener(&self, listener: Entity, world: &mut World) -> OperationResult { + self.inner.add_listener(listener, world) + } + + fn as_input(&self) -> smallvec::SmallVec<[Entity; 8]> { + self.inner.as_input() + } + + fn buffered_count(&self, session: Entity, world: &World) -> Result { + self.inner.buffered_count(session, world) + } + + fn ensure_active_session(&self, session: Entity, world: &mut World) -> OperationResult { + self.inner.ensure_active_session(session, world) + } + + fn gate_action( + &self, + session: Entity, + action: crate::Gate, + world: &mut World, + roster: &mut OperationRoster, + ) -> OperationResult { + self.inner.gate_action(session, action, world, roster) + } + + fn verify_scope(&self, scope: Entity) { + self.inner.verify_scope(scope); + } +} + +impl Joined for NonCopyBuffer { + type Item = T; + fn pull(&self, session: Entity, world: &mut World) -> Result { + self.inner.pull(session, world) + } +} + +impl Accessed for NonCopyBuffer { + type Key = BufferKey; + fn add_accessor(&self, accessor: Entity, world: &mut World) -> OperationResult { + self.inner.add_accessor(accessor, world) + } + + fn create_key(&self, builder: &crate::BufferKeyBuilder) -> Self::Key { + self.inner.create_key(builder) + } + + fn deep_clone_key(key: &Self::Key) -> Self::Key { + key.deep_clone() + } + + fn is_key_in_use(key: &Self::Key) -> bool { + key.is_in_use() + } +}