diff --git a/Cargo.lock b/Cargo.lock index 0c347cf8f7..94c8b76e18 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2253,6 +2253,7 @@ dependencies = [ "image", "interpreted-executor", "log", + "preprocessor", "serde", "serde_json", "tokio", @@ -2385,6 +2386,7 @@ dependencies = [ "log", "num_enum", "once_cell", + "preprocessor", "ron", "serde", "serde_json", @@ -4610,6 +4612,24 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "preprocessor" +version = "0.1.0" +dependencies = [ + "base64 0.22.1", + "dyn-any", + "futures", + "glam", + "graph-craft", + "graphene-core", + "graphene-std", + "interpreted-executor", + "log", + "serde", + "serde_json", + "tokio", +] + [[package]] name = "presser" version = "0.3.1" diff --git a/Cargo.toml b/Cargo.toml index 1c50e83ee1..d291bc9485 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,12 +14,13 @@ members = [ "node-graph/compilation-client", "node-graph/wgpu-executor", "node-graph/gpu-executor", + "node-graph/preprocessor", "node-graph/gpu-compiler/gpu-compiler-bin-wrapper", "libraries/dyn-any", "libraries/path-bool", "libraries/bezier-rs", "libraries/math-parser", - "website/other/bezier-rs-demos/wasm", + "website/other/bezier-rs-demos/wasm", "node-graph/preprocessor", ] exclude = ["node-graph/gpu-compiler"] default-members = [ @@ -38,7 +39,9 @@ resolver = "2" # Local dependencies dyn-any = { path = "libraries/dyn-any", features = ["derive", "glam", "reqwest"] } graphene-core = { path = "node-graph/gcore" } +graphene-std = { path = "node-graph/gstd" } graph-craft = { path = "node-graph/graph-craft", features = ["serde"] } +preprocessor = { path = "node-graph/preprocessor"} wgpu-executor = { path = "node-graph/wgpu-executor" } bezier-rs = { path = "libraries/bezier-rs", features = ["dyn-any"] } path-bool = { path = "libraries/path-bool", default-features = false } diff --git a/editor/Cargo.toml b/editor/Cargo.toml index 2191161bcb..df3fe4826a 100644 --- a/editor/Cargo.toml +++ b/editor/Cargo.toml @@ -30,6 +30,7 @@ ron = ["dep:ron"] # Local dependencies graphite-proc-macros = { path = "../proc-macros" } graph-craft = { path = "../node-graph/graph-craft" } +preprocessor = { path = "../node-graph/preprocessor" } interpreted-executor = { path = "../node-graph/interpreted-executor", features = [ "serde", ] } diff --git a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs index 86ece1bb54..4a6e6cf720 100644 --- a/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs +++ b/editor/src/messages/portfolio/document/node_graph/document_node_definitions.rs @@ -26,6 +26,8 @@ use std::collections::{HashMap, HashSet, VecDeque}; #[cfg(feature = "gpu")] use wgpu_executor::{Bindgroup, CommandBuffer, PipelineLayout, ShaderHandle, ShaderInputFrame, WgpuShaderInput}; +mod document_node_derive; + pub struct NodePropertiesContext<'a> { pub persistent_data: &'a PersistentData, pub responses: &'a mut VecDeque, @@ -93,7 +95,7 @@ static DOCUMENT_NODE_TYPES: once_cell::sync::Lazy> = /// Defines the "signature" or "header file"-like metadata for the document nodes, but not the implementation (which is defined in the node registry). /// The [`DocumentNode`] is the instance while these [`DocumentNodeDefinition`]s are the "classes" or "blueprints" from which the instances are built. fn static_nodes() -> Vec { - let mut custom = vec![ + let custom = vec![ // TODO: Auto-generate this from its proto node macro DocumentNodeDefinition { identifier: "Identity", @@ -237,15 +239,8 @@ fn static_nodes() -> Vec { node_template: NodeTemplate { document_node: DocumentNode { implementation: DocumentNodeImplementation::Network(NodeNetwork { - exports: vec![NodeInput::node(NodeId(3), 0)], + exports: vec![NodeInput::node(NodeId(2), 0)], nodes: [ - // Secondary (left) input type coercion - DocumentNode { - inputs: vec![NodeInput::network(generic!(T), 1)], - implementation: DocumentNodeImplementation::proto("graphene_core::graphic_element::ToElementNode"), - manual_composition: Some(generic!(T)), - ..Default::default() - }, // Primary (bottom) input type coercion DocumentNode { inputs: vec![NodeInput::network(generic!(T), 0)], @@ -255,7 +250,7 @@ fn static_nodes() -> Vec { }, // The monitor node is used to display a thumbnail in the UI DocumentNode { - inputs: vec![NodeInput::node(NodeId(0), 0)], + inputs: vec![NodeInput::network(generic!(T), 1)], implementation: DocumentNodeImplementation::proto("graphene_core::memo::MonitorNode"), manual_composition: Some(generic!(T)), skip_deduplication: true, @@ -264,8 +259,8 @@ fn static_nodes() -> Vec { DocumentNode { manual_composition: Some(generic!(T)), inputs: vec![ + NodeInput::node(NodeId(0), 0), NodeInput::node(NodeId(1), 0), - NodeInput::node(NodeId(2), 0), NodeInput::Reflection(graph_craft::document::DocumentNodeMetadata::DocumentNodePath), ], implementation: DocumentNodeImplementation::proto("graphene_core::graphic_element::LayerNode"), @@ -2533,109 +2528,7 @@ fn static_nodes() -> Vec { }, ]; - // Remove struct generics - for DocumentNodeDefinition { node_template, .. } in custom.iter_mut() { - let NodeTemplate { - document_node: DocumentNode { implementation, .. }, - .. - } = node_template; - if let DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) = implementation { - if let Some((new_name, _suffix)) = name.rsplit_once("<") { - *name = Cow::Owned(new_name.to_string()) - } - }; - } - let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); - 'outer: for (id, metadata) in graphene_core::registry::NODE_METADATA.lock().unwrap().iter() { - use graphene_core::registry::*; - let id = id.clone(); - - for node in custom.iter() { - let DocumentNodeDefinition { - node_template: NodeTemplate { - document_node: DocumentNode { implementation, .. }, - .. - }, - .. - } = node; - match implementation { - DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == &id => continue 'outer, - _ => (), - } - } - - let NodeMetadata { - display_name, - category, - fields, - description, - properties, - } = metadata; - let Some(implementations) = &node_registry.get(&id) else { continue }; - let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); - let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); - let mut input_type = &first_node_io.call_argument; - if valid_inputs.len() > 1 { - input_type = &const { generic!(D) }; - } - let output_type = &first_node_io.return_value; - - let inputs = fields - .iter() - .zip(first_node_io.inputs.iter()) - .enumerate() - .map(|(index, (field, node_io_ty))| { - let ty = field.default_type.as_ref().unwrap_or(node_io_ty); - let exposed = if index == 0 { *ty != fn_type_fut!(Context, ()) } else { field.exposed }; - - match field.value_source { - RegistryValueSource::None => {} - RegistryValueSource::Default(data) => return NodeInput::value(TaggedValue::from_primitive_string(data, ty).unwrap_or(TaggedValue::None), exposed), - RegistryValueSource::Scope(data) => return NodeInput::scope(Cow::Borrowed(data)), - }; - - if let Some(type_default) = TaggedValue::from_type(ty) { - return NodeInput::value(type_default, exposed); - } - NodeInput::value(TaggedValue::None, true) - }) - .collect(); - - let node = DocumentNodeDefinition { - identifier: display_name, - node_template: NodeTemplate { - document_node: DocumentNode { - inputs, - manual_composition: Some(input_type.clone()), - implementation: DocumentNodeImplementation::ProtoNode(id.clone().into()), - visible: true, - skip_deduplication: false, - ..Default::default() - }, - persistent_node_metadata: DocumentNodePersistentMetadata { - // TODO: Store information for input overrides in the node macro - input_properties: fields - .iter() - .map(|f| match f.widget_override { - RegistryWidgetOverride::None => (f.name, f.description).into(), - RegistryWidgetOverride::Hidden => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Hidden), - RegistryWidgetOverride::String(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::String(str.to_string())), - RegistryWidgetOverride::Custom(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Custom(str.to_string())), - }) - .collect(), - output_names: vec![output_type.to_string()], - has_primary_output: true, - locked: false, - ..Default::default() - }, - }, - category: category.unwrap_or("UNCATEGORIZED"), - description: Cow::Borrowed(description), - properties: *properties, - }; - custom.push(node); - } - custom + document_node_derive::post_process_nodes(custom) } // pub static IMAGINATE_NODE: Lazy = Lazy::new(|| DocumentNodeDefinition { diff --git a/editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs b/editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs new file mode 100644 index 0000000000..868ed19bf5 --- /dev/null +++ b/editor/src/messages/portfolio/document/node_graph/document_node_definitions/document_node_derive.rs @@ -0,0 +1,92 @@ +use super::DocumentNodeDefinition; +use crate::messages::portfolio::document::utility_types::network_interface::{DocumentNodePersistentMetadata, NodeTemplate, PropertiesRow, WidgetOverride}; +use graph_craft::ProtoNodeIdentifier; +use graph_craft::document::*; +use graphene_core::*; +use std::collections::HashSet; + +pub(super) fn post_process_nodes(mut custom: Vec) -> Vec { + // Remove struct generics + for DocumentNodeDefinition { node_template, .. } in custom.iter_mut() { + let NodeTemplate { + document_node: DocumentNode { implementation, .. }, + .. + } = node_template; + if let DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) = implementation { + if let Some((new_name, _suffix)) = name.rsplit_once("<") { + *name = Cow::Owned(new_name.to_string()) + } + }; + } + let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); + 'outer: for (id, metadata) in graphene_core::registry::NODE_METADATA.lock().unwrap().iter() { + use graphene_core::registry::*; + let id = id.clone(); + + for node in custom.iter() { + let DocumentNodeDefinition { + node_template: NodeTemplate { + document_node: DocumentNode { implementation, .. }, + .. + }, + .. + } = node; + match implementation { + DocumentNodeImplementation::ProtoNode(ProtoNodeIdentifier { name }) if name == &id => continue 'outer, + _ => (), + } + } + + let NodeMetadata { + display_name, + category, + fields, + description, + properties, + } = metadata; + let Some(implementations) = &node_registry.get(&id) else { continue }; + let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); + let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); + let mut input_type = &first_node_io.call_argument; + if valid_inputs.len() > 1 { + input_type = &const { generic!(D) }; + } + let output_type = &first_node_io.return_value; + + let inputs = preprocessor::node_inputs(fields, first_node_io); + let node = DocumentNodeDefinition { + identifier: display_name, + node_template: NodeTemplate { + document_node: DocumentNode { + inputs, + manual_composition: Some(input_type.clone()), + implementation: DocumentNodeImplementation::ProtoNode(id.clone().into()), + visible: true, + skip_deduplication: false, + ..Default::default() + }, + persistent_node_metadata: DocumentNodePersistentMetadata { + // TODO: Store information for input overrides in the node macro + input_properties: fields + .iter() + .map(|f| match f.widget_override { + RegistryWidgetOverride::None => (f.name, f.description).into(), + RegistryWidgetOverride::Hidden => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Hidden), + RegistryWidgetOverride::String(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::String(str.to_string())), + RegistryWidgetOverride::Custom(str) => PropertiesRow::with_override(f.name, f.description, WidgetOverride::Custom(str.to_string())), + }) + .collect(), + output_names: vec![output_type.to_string()], + has_primary_output: true, + locked: false, + ..Default::default() + }, + }, + category: category.unwrap_or("UNCATEGORIZED"), + description: Cow::Borrowed(description), + properties: *properties, + }; + custom.push(node); + } + custom +} diff --git a/editor/src/messages/portfolio/document/utility_types/network_interface.rs b/editor/src/messages/portfolio/document/utility_types/network_interface.rs index 6d65bf48e8..f58afd49bd 100644 --- a/editor/src/messages/portfolio/document/utility_types/network_interface.rs +++ b/editor/src/messages/portfolio/document/utility_types/network_interface.rs @@ -6488,6 +6488,12 @@ pub struct NodePersistentMetadata { position: NodePosition, } +impl NodePersistentMetadata { + pub fn new(position: NodePosition) -> Self { + Self { position } + } +} + /// A layer can either be position as Absolute or in a Stack #[derive(Debug, Clone, PartialEq, serde::Serialize, serde::Deserialize)] pub enum LayerPosition { diff --git a/editor/src/node_graph_executor/runtime.rs b/editor/src/node_graph_executor/runtime.rs index 703e57a397..9644eca9c3 100644 --- a/editor/src/node_graph_executor/runtime.rs +++ b/editor/src/node_graph_executor/runtime.rs @@ -44,6 +44,8 @@ pub struct NodeRuntime { /// Which node is inspected and which monitor node is used (if any) for the current execution inspect_state: Option, + substitutions: HashMap, + // TODO: Remove, it doesn't need to be persisted anymore /// The current renders of the thumbnails for layer nodes. thumbnail_renders: HashMap>, @@ -119,6 +121,8 @@ impl NodeRuntime { node_graph_errors: Vec::new(), monitor_nodes: Vec::new(), + substitutions: preprocessor::generate_node_substitutions(), + thumbnail_renders: Default::default(), vector_modify: Default::default(), inspect_state: None, @@ -220,11 +224,13 @@ impl NodeRuntime { } } - async fn update_network(&mut self, graph: NodeNetwork) -> Result { + async fn update_network(&mut self, mut graph: NodeNetwork) -> Result { + preprocessor::expand_network(&mut graph, &self.substitutions); let scoped_network = wrap_network_in_scope(graph, self.editor_api.clone()); // We assume only one output assert_eq!(scoped_network.exports.len(), 1, "Graph with multiple outputs not yet handled"); + let c = Compiler {}; let proto_network = match c.compile_single(scoped_network) { Ok(network) => network, diff --git a/node-graph/gcore/src/ops.rs b/node-graph/gcore/src/ops.rs index 14298a5394..15eb8ada08 100644 --- a/node-graph/gcore/src/ops.rs +++ b/node-graph/gcore/src/ops.rs @@ -608,6 +608,76 @@ where Box::pin(async move { input.into() }) } } +pub trait Convert: Sized { + /// Converts this type into the (usually inferred) input type. + #[must_use] + fn convert(self) -> T; +} + +macro_rules! impl_convert { + ($from:ty,$to:ty) => { + impl Convert<$to> for $from { + fn convert(self) -> $to { + self as $to + } + } + }; + ($to:ty) => { + impl_convert!(i8, $to); + impl_convert!(u8, $to); + impl_convert!(u16, $to); + impl_convert!(i16, $to); + impl_convert!(i32, $to); + impl_convert!(u32, $to); + impl_convert!(i64, $to); + impl_convert!(u64, $to); + impl_convert!(isize, $to); + impl_convert!(usize, $to); + impl_convert!(i128, $to); + impl_convert!(u128, $to); + impl_convert!(f32, $to); + impl_convert!(f64, $to); + }; +} +impl_convert!(i8); +impl_convert!(u8); +impl_convert!(u16); +impl_convert!(i16); +impl_convert!(i32); +impl_convert!(u32); +impl_convert!(i64); +impl_convert!(u64); +impl_convert!(isize); +impl_convert!(usize); +impl_convert!(i128); +impl_convert!(u128); +impl_convert!(f32); +impl_convert!(f64); + +// Convert +pub struct ConvertNode(PhantomData); +impl<_O> ConvertNode<_O> { + #[cfg(feature = "alloc")] + pub const fn new() -> Self { + Self(core::marker::PhantomData) + } +} +impl<_O> Default for ConvertNode<_O> { + fn default() -> Self { + Self::new() + } +} +impl<'input, I: 'input, _O: 'input> Node<'input, I> for ConvertNode<_O> +where + I: Convert<_O> + Sync + Send, +{ + type Output = ::dyn_any::DynFuture<'input, _O>; + + #[inline] + fn eval(&'input self, input: I) -> Self::Output { + Box::pin(async move { input.convert() }) + } +} #[cfg(test)] mod test { diff --git a/node-graph/graphene-cli/Cargo.toml b/node-graph/graphene-cli/Cargo.toml index 40f52e7071..d7f0d08c25 100644 --- a/node-graph/graphene-cli/Cargo.toml +++ b/node-graph/graphene-cli/Cargo.toml @@ -24,6 +24,7 @@ gpu = [ # Local dependencies graphene-std = { path = "../gstd", features = ["serde"] } interpreted-executor = { path = "../interpreted-executor" } +preprocessor = { path = "../preprocessor" } # Workspace dependencies log = { workspace = true } diff --git a/node-graph/graphene-cli/src/main.rs b/node-graph/graphene-cli/src/main.rs index 9a932b4162..5a40d5e0c3 100644 --- a/node-graph/graphene-cli/src/main.rs +++ b/node-graph/graphene-cli/src/main.rs @@ -183,8 +183,11 @@ fn fix_nodes(network: &mut NodeNetwork) { fn compile_graph(document_string: String, editor_api: Arc) -> Result> { let mut network = load_network(&document_string); fix_nodes(&mut network); + let substitutions = preprocessor::generate_node_substitutions(); + preprocessor::expand_network(&mut network, &substitutions); let wrapped_network = wrap_network_in_scope(network.clone(), editor_api); + let compiler = Compiler {}; compiler.compile_single(wrapped_network).map_err(|x| x.into()) } diff --git a/node-graph/interpreted-executor/src/node_registry.rs b/node-graph/interpreted-executor/src/node_registry.rs index 61f3e1d84a..85d2bdc8b7 100644 --- a/node-graph/interpreted-executor/src/node_registry.rs +++ b/node-graph/interpreted-executor/src/node_registry.rs @@ -15,7 +15,7 @@ use graphene_std::GraphicElement; use graphene_std::any::{ComposeTypeErased, DowncastBothNode, DynAnyNode, IntoTypeErasedNode}; use graphene_std::application_io::{ImageTexture, TextureFrameTable}; use graphene_std::wasm_application_io::*; -use node_registry_macros::{async_node, into_node}; +use node_registry_macros::{async_node, convert_node, into_node}; use once_cell::sync::Lazy; use std::collections::HashMap; use std::sync::Arc; @@ -25,10 +25,7 @@ use wgpu_executor::{WgpuExecutor, WgpuSurface, WindowHandle}; // TODO: turn into hashmap fn node_registry() -> HashMap> { - let node_types: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ - into_node!(from: f64, to: f64), - into_node!(from: u32, to: f64), - into_node!(from: u8, to: u32), + let mut node_types: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ into_node!(from: VectorDataTable, to: VectorDataTable), into_node!(from: VectorDataTable, to: GraphicElement), into_node!(from: VectorDataTable, to: GraphicGroupTable), @@ -137,6 +134,26 @@ fn node_registry() -> HashMap> = HashMap::new(); @@ -152,7 +169,7 @@ fn node_registry() -> HashMap { ( ProtoNodeIdentifier::new(concat!["graphene_core::ops::IntoNode<", stringify!($to), ">"]), - |mut args| { + |_| { Box::pin(async move { - args.reverse(); let node = graphene_core::ops::IntoNode::<$to>::new(); let any: DynAnyNode<$from, _, _> = graphene_std::any::DynAnyNode::new(node); Box::new(any) as TypeErasedBox @@ -220,7 +236,47 @@ mod node_registry_macros { ) }; } + macro_rules! convert_node { + (from: $from:ty, to: numbers) => {{ + let x: Vec<(ProtoNodeIdentifier, NodeConstructor, NodeIOTypes)> = vec![ + convert_node!(from: $from, to: i8), + convert_node!(from: $from, to: u8), + convert_node!(from: $from, to: u16), + convert_node!(from: $from, to: i16), + convert_node!(from: $from, to: i32), + convert_node!(from: $from, to: u32), + convert_node!(from: $from, to: i64), + convert_node!(from: $from, to: u64), + convert_node!(from: $from, to: isize), + convert_node!(from: $from, to: usize), + convert_node!(from: $from, to: i128), + convert_node!(from: $from, to: u128), + convert_node!(from: $from, to: f32), + convert_node!(from: $from, to: f64), + ]; + x + }}; + (from: $from:ty, to: $to:ty) => { + ( + ProtoNodeIdentifier::new(concat!["graphene_core::ops::ConvertNode<", stringify!($to), ">"]), + |_| { + Box::pin(async move { + let node = graphene_core::ops::ConvertNode::<$to>::new(); + let any: DynAnyNode<$from, _, _> = graphene_std::any::DynAnyNode::new(node); + Box::new(any) as TypeErasedBox + }) + }, + { + let node = graphene_core::ops::ConvertNode::<$to>::new(); + let mut node_io = NodeIO::<'_, $from>::to_async_node_io(&node, vec![]); + node_io.call_argument = future!(<$from as StaticType>::Static); + node_io + }, + ) + }; + } pub(crate) use async_node; + pub(crate) use convert_node; pub(crate) use into_node; } diff --git a/node-graph/preprocessor/Cargo.toml b/node-graph/preprocessor/Cargo.toml new file mode 100644 index 0000000000..2bd0bfca38 --- /dev/null +++ b/node-graph/preprocessor/Cargo.toml @@ -0,0 +1,31 @@ +[package] +name = "preprocessor" +version = "0.1.0" +edition = "2024" +license = "MIT OR Apache-2.0" + +[features] + +[dependencies] +# Local dependencies +dyn-any = { path = "../../libraries/dyn-any", features = [ + "log-bad-types", + "rc", + "glam", +] } + +# Workspace dependencies +graphene-core = { workspace = true, features = ["std"] } +graphene-std = { workspace = true, features = ["gpu"] } +graph-craft = { workspace = true} +interpreted-executor = { path = "../interpreted-executor" } +log = { workspace = true } +futures = { workspace = true } +glam = { workspace = true } +base64 = { workspace = true } + +# Optional workspace dependencies +serde = { workspace = true, optional = true } +tokio = { workspace = true, optional = true } +serde_json = { workspace = true, optional = true } + diff --git a/node-graph/preprocessor/src/lib.rs b/node-graph/preprocessor/src/lib.rs new file mode 100644 index 0000000000..80306e6975 --- /dev/null +++ b/node-graph/preprocessor/src/lib.rs @@ -0,0 +1,154 @@ +use graph_craft::ProtoNodeIdentifier; +use graph_craft::concrete; +use graph_craft::document::value::*; +use graph_craft::document::*; +use graph_craft::proto::RegistryValueSource; +use graphene_core::*; +use std::collections::{HashMap, HashSet}; + +pub fn expand_network(network: &mut NodeNetwork, substitutions: &HashMap) { + for node in network.nodes.values_mut() { + match &mut node.implementation { + DocumentNodeImplementation::Network(node_network) => expand_network(node_network, substitutions), + DocumentNodeImplementation::ProtoNode(proto_node_identifier) => { + if let Some(new_node) = substitutions.get(proto_node_identifier.name.as_ref()) { + node.implementation = new_node.implementation.clone(); + } + } + DocumentNodeImplementation::Extract => (), + } + } +} + +pub fn generate_node_substitutions() -> HashMap { + let mut custom = HashMap::new(); + let node_registry = graphene_core::registry::NODE_REGISTRY.lock().unwrap(); + for (id, metadata) in graphene_core::registry::NODE_METADATA.lock().unwrap().iter() { + use graphene_core::registry::*; + let id = id.clone(); + + let NodeMetadata { fields, .. } = metadata; + let Some(implementations) = &node_registry.get(&id) else { continue }; + let valid_inputs: HashSet<_> = implementations.iter().map(|(_, node_io)| node_io.call_argument.clone()).collect(); + let first_node_io = implementations.first().map(|(_, node_io)| node_io).unwrap_or(const { &NodeIOTypes::empty() }); + let mut node_io_types = vec![HashSet::new(); fields.len()]; + for (_, node_io) in implementations.iter() { + for (i, ty) in node_io.inputs.iter().enumerate() { + node_io_types[i].insert(ty.clone()); + } + } + let mut input_type = &first_node_io.call_argument; + if valid_inputs.len() > 1 { + input_type = &const { generic!(D) }; + } + + let inputs: Vec<_> = node_inputs(fields, first_node_io); + let input_count = inputs.len(); + let network_inputs = (0..input_count).map(|i| NodeInput::node(NodeId(i as u64), 0)).collect(); + + let identity_node = ProtoNodeIdentifier::new("graphene_core::ops::IdentityNode"); + + let into_node_registry = &interpreted_executor::node_registry::NODE_REGISTRY; + + let mut nodes: HashMap<_, _, _> = node_io_types + .iter() + .enumerate() + .map(|(i, inputs)| { + ( + NodeId(i as u64), + match inputs.len() { + 1 => { + let input = inputs.iter().next().unwrap(); + let input_ty = input.nested_type(); + let into_node_identifier = ProtoNodeIdentifier { + name: format!("graphene_core::ops::IntoNode<{}>", input_ty.clone()).into(), + }; + let convert_node_identifier = ProtoNodeIdentifier { + name: format!("graphene_core::ops::ConvertNode<{}>", input_ty.clone()).into(), + }; + let proto_node = if into_node_registry.iter().any(|(ident, _)| { + let ident = ident.name.as_ref(); + ident == into_node_identifier.name.as_ref() + }) { + into_node_identifier + } else if into_node_registry.iter().any(|(ident, _)| { + let ident = ident.name.as_ref(); + ident == convert_node_identifier.name.as_ref() + }) { + convert_node_identifier + } else { + identity_node.clone() + }; + DocumentNode { + inputs: vec![NodeInput::network(input.clone(), i)], + // manual_composition: Some(fn_input.clone()), + implementation: DocumentNodeImplementation::ProtoNode(proto_node), + visible: true, + ..Default::default() + } + } + _ => DocumentNode { + inputs: vec![NodeInput::network(generic!(X), i)], + implementation: DocumentNodeImplementation::ProtoNode(identity_node.clone()), + visible: false, + ..Default::default() + }, + }, + ) + }) + .collect(); + + let document_node = DocumentNode { + inputs: network_inputs, + manual_composition: Some(input_type.clone()), + implementation: DocumentNodeImplementation::ProtoNode(id.clone().into()), + visible: true, + skip_deduplication: false, + ..Default::default() + }; + nodes.insert(NodeId(input_count as u64), document_node); + + let node = DocumentNode { + inputs, + manual_composition: Some(input_type.clone()), + implementation: DocumentNodeImplementation::Network(NodeNetwork { + exports: vec![NodeInput::Node { + node_id: NodeId(input_count as u64), + output_index: 0, + lambda: false, + }], + nodes, + scope_injections: Default::default(), + }), + visible: true, + skip_deduplication: false, + ..Default::default() + }; + + custom.insert(id.clone(), node); + } + custom +} + +pub fn node_inputs(fields: &[registry::FieldMetadata], first_node_io: &NodeIOTypes) -> Vec { + fields + .iter() + .zip(first_node_io.inputs.iter()) + .enumerate() + .map(|(index, (field, node_io_ty))| { + let ty = field.default_type.as_ref().unwrap_or(node_io_ty); + let exposed = if index == 0 { *ty != fn_type_fut!(Context, ()) } else { field.exposed }; + + match field.value_source { + RegistryValueSource::None => {} + RegistryValueSource::Default(data) => return NodeInput::value(TaggedValue::from_primitive_string(data, ty).unwrap_or(TaggedValue::None), exposed), + RegistryValueSource::Scope(data) => return NodeInput::scope(Cow::Borrowed(data)), + }; + + if let Some(type_default) = TaggedValue::from_type(ty) { + return NodeInput::value(type_default, exposed); + } + NodeInput::value(TaggedValue::None, true) + }) + .collect() +}