|
1 |
| -use std::sync::Arc; |
| 1 | +//cargo test --test test_tool_macros --features "client server" |
2 | 2 |
|
3 |
| -use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool}; |
| 3 | +use rmcp::{ |
| 4 | + ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt, |
| 5 | + model::{CallToolRequestParam, ClientInfo}, |
| 6 | +}; |
| 7 | +use rmcp::{handler::server::tool::ToolCallContext, tool}; |
4 | 8 | use schemars::JsonSchema;
|
5 | 9 | use serde::{Deserialize, Serialize};
|
| 10 | +use serde_json; |
| 11 | +use std::sync::Arc; |
6 | 12 |
|
7 | 13 | #[derive(Serialize, Deserialize, JsonSchema)]
|
8 | 14 | pub struct GetWeatherRequest {
|
@@ -100,3 +106,210 @@ async fn test_tool_macros_with_generics() {
|
100 | 106 | }
|
101 | 107 |
|
102 | 108 | impl GetWeatherRequest {}
|
| 109 | + |
| 110 | +// Struct defined for testing optional field schema generation |
| 111 | +#[derive(Debug, Deserialize, Serialize, JsonSchema)] |
| 112 | +pub struct OptionalFieldTestSchema { |
| 113 | + #[schemars(description = "An optional description field")] |
| 114 | + pub description: Option<String>, |
| 115 | +} |
| 116 | + |
| 117 | +// Struct defined for testing optional i64 field schema generation and null handling |
| 118 | +#[derive(Debug, Deserialize, Serialize, JsonSchema)] |
| 119 | +pub struct OptionalI64TestSchema { |
| 120 | + #[schemars(description = "An optional i64 field")] |
| 121 | + pub count: Option<i64>, |
| 122 | + pub mandatory_field: String, // Added to ensure non-empty object schema |
| 123 | +} |
| 124 | + |
| 125 | +// Dummy struct to host the test tool method |
| 126 | +#[derive(Debug, Clone, Default)] |
| 127 | +pub struct OptionalSchemaTester {} |
| 128 | + |
| 129 | +impl OptionalSchemaTester { |
| 130 | + // Dummy tool function using the test schema as an aggregated parameter |
| 131 | + #[tool(description = "A tool to test optional schema generation")] |
| 132 | + async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) { |
| 133 | + // Implementation doesn't matter for schema testing |
| 134 | + // Return type changed to () to satisfy IntoCallToolResult |
| 135 | + } |
| 136 | + |
| 137 | + // Tool function to test optional i64 handling |
| 138 | + #[tool(description = "A tool to test optional i64 schema generation")] |
| 139 | + async fn test_optional_i64_aggr(&self, #[tool(aggr)] req: OptionalI64TestSchema) -> String { |
| 140 | + match req.count { |
| 141 | + Some(c) => format!("Received count: {}", c), |
| 142 | + None => "Received null count".to_string(), |
| 143 | + } |
| 144 | + } |
| 145 | +} |
| 146 | + |
| 147 | +// Implement ServerHandler to route tool calls for OptionalSchemaTester |
| 148 | +impl ServerHandler for OptionalSchemaTester { |
| 149 | + async fn call_tool( |
| 150 | + &self, |
| 151 | + request: rmcp::model::CallToolRequestParam, |
| 152 | + context: rmcp::service::RequestContext<rmcp::RoleServer>, |
| 153 | + ) -> Result<rmcp::model::CallToolResult, rmcp::Error> { |
| 154 | + let tcc = ToolCallContext::new(self, request, context); |
| 155 | + match tcc.name() { |
| 156 | + "test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await, |
| 157 | + "test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await, |
| 158 | + _ => Err(rmcp::Error::invalid_params("method not found", None)), |
| 159 | + } |
| 160 | + } |
| 161 | +} |
| 162 | + |
| 163 | +#[test] |
| 164 | +fn test_optional_field_schema_generation_via_macro() { |
| 165 | + // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135 |
| 166 | + |
| 167 | + // Get the attributes generated by the #[tool] macro helper |
| 168 | + let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr(); |
| 169 | + |
| 170 | + // Print the actual generated schema for debugging |
| 171 | + println!( |
| 172 | + "Actual input schema generated by macro: {:#?}", |
| 173 | + tool_attr.input_schema |
| 174 | + ); |
| 175 | + |
| 176 | + // Verify the schema generated for the aggregated OptionalFieldTestSchema |
| 177 | + // by the macro infrastructure (which should now use OpenAPI 3 settings) |
| 178 | + let input_schema_map = &*tool_attr.input_schema; // Dereference Arc<JsonObject> |
| 179 | + |
| 180 | + // Check the schema for the 'description' property within the input schema |
| 181 | + let properties = input_schema_map |
| 182 | + .get("properties") |
| 183 | + .expect("Schema should have properties") |
| 184 | + .as_object() |
| 185 | + .unwrap(); |
| 186 | + let description_schema = properties |
| 187 | + .get("description") |
| 188 | + .expect("Properties should include description") |
| 189 | + .as_object() |
| 190 | + .unwrap(); |
| 191 | + |
| 192 | + // Assert that the format is now `type: "string", nullable: true` |
| 193 | + assert_eq!( |
| 194 | + description_schema.get("type").map(|v| v.as_str().unwrap()), |
| 195 | + Some("string"), |
| 196 | + "Schema for Option<String> generated by macro should be type: \"string\"" |
| 197 | + ); |
| 198 | + assert_eq!( |
| 199 | + description_schema |
| 200 | + .get("nullable") |
| 201 | + .map(|v| v.as_bool().unwrap()), |
| 202 | + Some(true), |
| 203 | + "Schema for Option<String> generated by macro should have nullable: true" |
| 204 | + ); |
| 205 | + // We still check the description is correct |
| 206 | + assert_eq!( |
| 207 | + description_schema |
| 208 | + .get("description") |
| 209 | + .map(|v| v.as_str().unwrap()), |
| 210 | + Some("An optional description field") |
| 211 | + ); |
| 212 | + |
| 213 | + // Ensure the old 'type: [T, null]' format is NOT used |
| 214 | + let type_value = description_schema.get("type").unwrap(); |
| 215 | + assert!( |
| 216 | + !type_value.is_array(), |
| 217 | + "Schema type should not be an array [T, null]" |
| 218 | + ); |
| 219 | +} |
| 220 | + |
| 221 | +// Define a dummy client handler |
| 222 | +#[derive(Debug, Clone, Default)] |
| 223 | +struct DummyClientHandler { |
| 224 | + peer: Option<Peer<RoleClient>>, |
| 225 | +} |
| 226 | + |
| 227 | +impl ClientHandler for DummyClientHandler { |
| 228 | + fn get_info(&self) -> ClientInfo { |
| 229 | + ClientInfo::default() |
| 230 | + } |
| 231 | + |
| 232 | + fn set_peer(&mut self, peer: Peer<RoleClient>) { |
| 233 | + self.peer = Some(peer); |
| 234 | + } |
| 235 | + |
| 236 | + fn get_peer(&self) -> Option<Peer<RoleClient>> { |
| 237 | + self.peer.clone() |
| 238 | + } |
| 239 | +} |
| 240 | + |
| 241 | +#[tokio::test] |
| 242 | +async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> { |
| 243 | + let (server_transport, client_transport) = tokio::io::duplex(4096); |
| 244 | + |
| 245 | + // Server setup |
| 246 | + let server = OptionalSchemaTester::default(); |
| 247 | + let server_handle = tokio::spawn(async move { |
| 248 | + server.serve(server_transport).await?.waiting().await?; |
| 249 | + anyhow::Ok(()) |
| 250 | + }); |
| 251 | + |
| 252 | + // Create a simple client handler that just forwards tool calls |
| 253 | + let client_handler = DummyClientHandler::default(); |
| 254 | + let client = client_handler.serve(client_transport).await?; |
| 255 | + |
| 256 | + // Test null case |
| 257 | + let result = client |
| 258 | + .call_tool(CallToolRequestParam { |
| 259 | + name: "test_optional_i64_aggr".into(), |
| 260 | + arguments: Some( |
| 261 | + serde_json::json!({ |
| 262 | + "count": null, |
| 263 | + "mandatory_field": "test_null" |
| 264 | + }) |
| 265 | + .as_object() |
| 266 | + .unwrap() |
| 267 | + .clone(), |
| 268 | + ), |
| 269 | + }) |
| 270 | + .await?; |
| 271 | + |
| 272 | + let result_text = result |
| 273 | + .content |
| 274 | + .first() |
| 275 | + .and_then(|content| content.raw.as_text()) |
| 276 | + .map(|text| text.text.as_str()) |
| 277 | + .expect("Expected text content"); |
| 278 | + |
| 279 | + assert_eq!( |
| 280 | + result_text, "Received null count", |
| 281 | + "Null case should return expected message" |
| 282 | + ); |
| 283 | + |
| 284 | + // Test Some case |
| 285 | + let some_result = client |
| 286 | + .call_tool(CallToolRequestParam { |
| 287 | + name: "test_optional_i64_aggr".into(), |
| 288 | + arguments: Some( |
| 289 | + serde_json::json!({ |
| 290 | + "count": 42, |
| 291 | + "mandatory_field": "test_some" |
| 292 | + }) |
| 293 | + .as_object() |
| 294 | + .unwrap() |
| 295 | + .clone(), |
| 296 | + ), |
| 297 | + }) |
| 298 | + .await?; |
| 299 | + |
| 300 | + let some_result_text = some_result |
| 301 | + .content |
| 302 | + .first() |
| 303 | + .and_then(|content| content.raw.as_text()) |
| 304 | + .map(|text| text.text.as_str()) |
| 305 | + .expect("Expected text content"); |
| 306 | + |
| 307 | + assert_eq!( |
| 308 | + some_result_text, "Received count: 42", |
| 309 | + "Some case should return expected message" |
| 310 | + ); |
| 311 | + |
| 312 | + client.cancel().await?; |
| 313 | + server_handle.await??; |
| 314 | + Ok(()) |
| 315 | +} |
0 commit comments