Skip to content

Commit 6527e83

Browse files
committed
test(openapi): add OpenAPI v3 compatibility and test for nullable field schema workaround (#135)
1 parent 18346b9 commit 6527e83

File tree

2 files changed

+218
-3
lines changed

2 files changed

+218
-3
lines changed

crates/rmcp/src/handler/server/tool.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use crate::{
1414
};
1515
/// A shortcut for generating a JSON schema for a type.
1616
pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
17-
let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<T>();
17+
let settings = schemars::r#gen::SchemaSettings::openapi3();
18+
let generator = settings.into_generator();
19+
let schema = generator.into_root_schema_for::<T>();
1820
let object = serde_json::to_value(schema).expect("failed to serialize schema");
1921
match object {
2022
serde_json::Value::Object(object) => object,

crates/rmcp/tests/test_tool_macros.rs

+215-2
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,14 @@
1-
use std::sync::Arc;
1+
//cargo test --test test_tool_macros --features "client server"
22

3-
use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool};
3+
use rmcp::{
4+
ClientHandler, Peer, RoleClient, ServerHandler, ServiceExt,
5+
model::{CallToolRequestParam, ClientInfo},
6+
};
7+
use rmcp::{handler::server::tool::ToolCallContext, tool};
48
use schemars::JsonSchema;
59
use serde::{Deserialize, Serialize};
10+
use serde_json;
11+
use std::sync::Arc;
612

713
#[derive(Serialize, Deserialize, JsonSchema)]
814
pub struct GetWeatherRequest {
@@ -100,3 +106,210 @@ async fn test_tool_macros_with_generics() {
100106
}
101107

102108
impl GetWeatherRequest {}
109+
110+
// Struct defined for testing optional field schema generation
111+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
112+
pub struct OptionalFieldTestSchema {
113+
#[schemars(description = "An optional description field")]
114+
pub description: Option<String>,
115+
}
116+
117+
// Struct defined for testing optional i64 field schema generation and null handling
118+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
119+
pub struct OptionalI64TestSchema {
120+
#[schemars(description = "An optional i64 field")]
121+
pub count: Option<i64>,
122+
pub mandatory_field: String, // Added to ensure non-empty object schema
123+
}
124+
125+
// Dummy struct to host the test tool method
126+
#[derive(Debug, Clone, Default)]
127+
pub struct OptionalSchemaTester {}
128+
129+
impl OptionalSchemaTester {
130+
// Dummy tool function using the test schema as an aggregated parameter
131+
#[tool(description = "A tool to test optional schema generation")]
132+
async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) {
133+
// Implementation doesn't matter for schema testing
134+
// Return type changed to () to satisfy IntoCallToolResult
135+
}
136+
137+
// Tool function to test optional i64 handling
138+
#[tool(description = "A tool to test optional i64 schema generation")]
139+
async fn test_optional_i64_aggr(&self, #[tool(aggr)] req: OptionalI64TestSchema) -> String {
140+
match req.count {
141+
Some(c) => format!("Received count: {}", c),
142+
None => "Received null count".to_string(),
143+
}
144+
}
145+
}
146+
147+
// Implement ServerHandler to route tool calls for OptionalSchemaTester
148+
impl ServerHandler for OptionalSchemaTester {
149+
async fn call_tool(
150+
&self,
151+
request: rmcp::model::CallToolRequestParam,
152+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
153+
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
154+
let tcc = ToolCallContext::new(self, request, context);
155+
match tcc.name() {
156+
"test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await,
157+
"test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await,
158+
_ => Err(rmcp::Error::invalid_params("method not found", None)),
159+
}
160+
}
161+
}
162+
163+
#[test]
164+
fn test_optional_field_schema_generation_via_macro() {
165+
// tests https://github.com/modelcontextprotocol/rust-sdk/issues/135
166+
167+
// Get the attributes generated by the #[tool] macro helper
168+
let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr();
169+
170+
// Print the actual generated schema for debugging
171+
println!(
172+
"Actual input schema generated by macro: {:#?}",
173+
tool_attr.input_schema
174+
);
175+
176+
// Verify the schema generated for the aggregated OptionalFieldTestSchema
177+
// by the macro infrastructure (which should now use OpenAPI 3 settings)
178+
let input_schema_map = &*tool_attr.input_schema; // Dereference Arc<JsonObject>
179+
180+
// Check the schema for the 'description' property within the input schema
181+
let properties = input_schema_map
182+
.get("properties")
183+
.expect("Schema should have properties")
184+
.as_object()
185+
.unwrap();
186+
let description_schema = properties
187+
.get("description")
188+
.expect("Properties should include description")
189+
.as_object()
190+
.unwrap();
191+
192+
// Assert that the format is now `type: "string", nullable: true`
193+
assert_eq!(
194+
description_schema.get("type").map(|v| v.as_str().unwrap()),
195+
Some("string"),
196+
"Schema for Option<String> generated by macro should be type: \"string\""
197+
);
198+
assert_eq!(
199+
description_schema
200+
.get("nullable")
201+
.map(|v| v.as_bool().unwrap()),
202+
Some(true),
203+
"Schema for Option<String> generated by macro should have nullable: true"
204+
);
205+
// We still check the description is correct
206+
assert_eq!(
207+
description_schema
208+
.get("description")
209+
.map(|v| v.as_str().unwrap()),
210+
Some("An optional description field")
211+
);
212+
213+
// Ensure the old 'type: [T, null]' format is NOT used
214+
let type_value = description_schema.get("type").unwrap();
215+
assert!(
216+
!type_value.is_array(),
217+
"Schema type should not be an array [T, null]"
218+
);
219+
}
220+
221+
// Define a dummy client handler
222+
#[derive(Debug, Clone, Default)]
223+
struct DummyClientHandler {
224+
peer: Option<Peer<RoleClient>>,
225+
}
226+
227+
impl ClientHandler for DummyClientHandler {
228+
fn get_info(&self) -> ClientInfo {
229+
ClientInfo::default()
230+
}
231+
232+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
233+
self.peer = Some(peer);
234+
}
235+
236+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
237+
self.peer.clone()
238+
}
239+
}
240+
241+
#[tokio::test]
242+
async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> {
243+
let (server_transport, client_transport) = tokio::io::duplex(4096);
244+
245+
// Server setup
246+
let server = OptionalSchemaTester::default();
247+
let server_handle = tokio::spawn(async move {
248+
server.serve(server_transport).await?.waiting().await?;
249+
anyhow::Ok(())
250+
});
251+
252+
// Create a simple client handler that just forwards tool calls
253+
let client_handler = DummyClientHandler::default();
254+
let client = client_handler.serve(client_transport).await?;
255+
256+
// Test null case
257+
let result = client
258+
.call_tool(CallToolRequestParam {
259+
name: "test_optional_i64_aggr".into(),
260+
arguments: Some(
261+
serde_json::json!({
262+
"count": null,
263+
"mandatory_field": "test_null"
264+
})
265+
.as_object()
266+
.unwrap()
267+
.clone(),
268+
),
269+
})
270+
.await?;
271+
272+
let result_text = result
273+
.content
274+
.first()
275+
.and_then(|content| content.raw.as_text())
276+
.map(|text| text.text.as_str())
277+
.expect("Expected text content");
278+
279+
assert_eq!(
280+
result_text, "Received null count",
281+
"Null case should return expected message"
282+
);
283+
284+
// Test Some case
285+
let some_result = client
286+
.call_tool(CallToolRequestParam {
287+
name: "test_optional_i64_aggr".into(),
288+
arguments: Some(
289+
serde_json::json!({
290+
"count": 42,
291+
"mandatory_field": "test_some"
292+
})
293+
.as_object()
294+
.unwrap()
295+
.clone(),
296+
),
297+
})
298+
.await?;
299+
300+
let some_result_text = some_result
301+
.content
302+
.first()
303+
.and_then(|content| content.raw.as_text())
304+
.map(|text| text.text.as_str())
305+
.expect("Expected text content");
306+
307+
assert_eq!(
308+
some_result_text, "Received count: 42",
309+
"Some case should return expected message"
310+
);
311+
312+
client.cancel().await?;
313+
server_handle.await??;
314+
Ok(())
315+
}

0 commit comments

Comments
 (0)