Skip to content

Commit a0cb722

Browse files
committed
test(openapi): add OpenAPI v3 compatibility and test for nullable field schema workaround (modelcontextprotocol#135)
1 parent 18346b9 commit a0cb722

File tree

2 files changed

+220
-3
lines changed

2 files changed

+220
-3
lines changed

crates/rmcp/src/handler/server/tool.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,9 @@ use crate::{
1414
};
1515
/// A shortcut for generating a JSON schema for a type.
1616
pub fn schema_for_type<T: JsonSchema>() -> JsonObject {
17-
let schema = schemars::r#gen::SchemaGenerator::default().into_root_schema_for::<T>();
17+
let settings = schemars::r#gen::SchemaSettings::openapi3();
18+
let generator = settings.into_generator();
19+
let schema = generator.into_root_schema_for::<T>();
1820
let object = serde_json::to_value(schema).expect("failed to serialize schema");
1921
match object {
2022
serde_json::Value::Object(object) => object,

crates/rmcp/tests/test_tool_macros.rs

Lines changed: 217 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,21 @@
1-
use std::sync::Arc;
1+
//cargo test --test test_tool_macros --features "client server"
22

3-
use rmcp::{ServerHandler, handler::server::tool::ToolCallContext, tool};
3+
use std::sync::Arc;
4+
use rmcp::{
5+
ClientHandler,
6+
Peer,
7+
RoleClient,
8+
ServerHandler,
9+
ServiceExt,
10+
model::{
11+
CallToolRequestParam,
12+
ClientInfo,
13+
},
14+
};
15+
use rmcp::{handler::server::tool::ToolCallContext, tool};
416
use schemars::JsonSchema;
517
use serde::{Deserialize, Serialize};
18+
use serde_json;
619

720
#[derive(Serialize, Deserialize, JsonSchema)]
821
pub struct GetWeatherRequest {
@@ -100,3 +113,205 @@ async fn test_tool_macros_with_generics() {
100113
}
101114

102115
impl GetWeatherRequest {}
116+
117+
// Struct defined for testing optional field schema generation
118+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
119+
pub struct OptionalFieldTestSchema {
120+
#[schemars(description = "An optional description field")]
121+
pub description: Option<String>,
122+
}
123+
124+
// Struct defined for testing optional i64 field schema generation and null handling
125+
#[derive(Debug, Deserialize, Serialize, JsonSchema)]
126+
pub struct OptionalI64TestSchema {
127+
#[schemars(description = "An optional i64 field")]
128+
pub count: Option<i64>,
129+
pub mandatory_field: String, // Added to ensure non-empty object schema
130+
}
131+
132+
// Dummy struct to host the test tool method
133+
#[derive(Debug, Clone, Default)]
134+
pub struct OptionalSchemaTester {}
135+
136+
impl OptionalSchemaTester {
137+
// Dummy tool function using the test schema as an aggregated parameter
138+
#[tool(description = "A tool to test optional schema generation")]
139+
async fn test_optional_aggr(&self, #[tool(aggr)] _req: OptionalFieldTestSchema) {
140+
// Implementation doesn't matter for schema testing
141+
// Return type changed to () to satisfy IntoCallToolResult
142+
}
143+
144+
// Tool function to test optional i64 handling
145+
#[tool(description = "A tool to test optional i64 schema generation")]
146+
async fn test_optional_i64_aggr(
147+
&self,
148+
#[tool(aggr)] req: OptionalI64TestSchema,
149+
) -> String {
150+
match req.count {
151+
Some(c) => format!("Received count: {}", c),
152+
None => "Received null count".to_string(),
153+
}
154+
}
155+
}
156+
157+
// Implement ServerHandler to route tool calls for OptionalSchemaTester
158+
impl ServerHandler for OptionalSchemaTester {
159+
async fn call_tool(
160+
&self,
161+
request: rmcp::model::CallToolRequestParam,
162+
context: rmcp::service::RequestContext<rmcp::RoleServer>,
163+
) -> Result<rmcp::model::CallToolResult, rmcp::Error> {
164+
let tcc = ToolCallContext::new(self, request, context);
165+
match tcc.name() {
166+
"test_optional_aggr" => Self::test_optional_aggr_tool_call(tcc).await,
167+
"test_optional_i64_aggr" => Self::test_optional_i64_aggr_tool_call(tcc).await,
168+
_ => Err(rmcp::Error::invalid_params("method not found", None)),
169+
}
170+
}
171+
}
172+
173+
#[test]
174+
fn test_optional_field_schema_generation_via_macro() {
175+
// tests https://github.com/modelcontextprotocol/rust-sdk/issues/135
176+
177+
// Get the attributes generated by the #[tool] macro helper
178+
let tool_attr = OptionalSchemaTester::test_optional_aggr_tool_attr();
179+
180+
// Print the actual generated schema for debugging
181+
println!(
182+
"Actual input schema generated by macro: {:#?}",
183+
tool_attr.input_schema
184+
);
185+
186+
// Verify the schema generated for the aggregated OptionalFieldTestSchema
187+
// by the macro infrastructure (which should now use OpenAPI 3 settings)
188+
let input_schema_map = &*tool_attr.input_schema; // Dereference Arc<JsonObject>
189+
190+
// Check the schema for the 'description' property within the input schema
191+
let properties = input_schema_map
192+
.get("properties")
193+
.expect("Schema should have properties")
194+
.as_object()
195+
.unwrap();
196+
let description_schema = properties
197+
.get("description")
198+
.expect("Properties should include description")
199+
.as_object()
200+
.unwrap();
201+
202+
// Assert that the format is now `type: "string", nullable: true`
203+
assert_eq!(
204+
description_schema.get("type").map(|v| v.as_str().unwrap()),
205+
Some("string"),
206+
"Schema for Option<String> generated by macro should be type: \"string\""
207+
);
208+
assert_eq!(
209+
description_schema
210+
.get("nullable")
211+
.map(|v| v.as_bool().unwrap()),
212+
Some(true),
213+
"Schema for Option<String> generated by macro should have nullable: true"
214+
);
215+
// We still check the description is correct
216+
assert_eq!(
217+
description_schema
218+
.get("description")
219+
.map(|v| v.as_str().unwrap()),
220+
Some("An optional description field")
221+
);
222+
223+
// Ensure the old 'type: [T, null]' format is NOT used
224+
let type_value = description_schema.get("type").unwrap();
225+
assert!(
226+
!type_value.is_array(),
227+
"Schema type should not be an array [T, null]"
228+
);
229+
}
230+
231+
// Define a dummy client handler
232+
#[derive(Debug, Clone, Default)]
233+
struct DummyClientHandler {
234+
peer: Option<Peer<RoleClient>>,
235+
}
236+
237+
impl ClientHandler for DummyClientHandler {
238+
fn get_info(&self) -> ClientInfo {
239+
ClientInfo::default()
240+
}
241+
242+
fn set_peer(&mut self, peer: Peer<RoleClient>) {
243+
self.peer = Some(peer);
244+
}
245+
246+
fn get_peer(&self) -> Option<Peer<RoleClient>> {
247+
self.peer.clone()
248+
}
249+
}
250+
251+
#[tokio::test]
252+
async fn test_optional_i64_field_with_null_input() -> anyhow::Result<()> {
253+
let (server_transport, client_transport) = tokio::io::duplex(4096);
254+
255+
// Server setup
256+
let server = OptionalSchemaTester::default();
257+
let server_handle = tokio::spawn(async move {
258+
server.serve(server_transport).await?
259+
.waiting()
260+
.await?;
261+
anyhow::Ok(())
262+
});
263+
264+
// Create a simple client handler that just forwards tool calls
265+
let client_handler = DummyClientHandler::default();
266+
let client = client_handler.serve(client_transport).await?;
267+
268+
// Test null case
269+
let result = client.call_tool(
270+
CallToolRequestParam {
271+
name: "test_optional_i64_aggr".into(),
272+
arguments: Some(serde_json::json!({
273+
"count": null,
274+
"mandatory_field": "test_null"
275+
}).as_object().unwrap().clone()),
276+
}
277+
).await?;
278+
279+
let result_text = result.content
280+
.first()
281+
.and_then(|content| content.raw.as_text())
282+
.map(|text| text.text.as_str())
283+
.expect("Expected text content");
284+
285+
assert_eq!(
286+
result_text,
287+
"Received null count",
288+
"Null case should return expected message"
289+
);
290+
291+
// Test Some case
292+
let some_result = client.call_tool(
293+
CallToolRequestParam {
294+
name: "test_optional_i64_aggr".into(),
295+
arguments: Some(serde_json::json!({
296+
"count": 42,
297+
"mandatory_field": "test_some"
298+
}).as_object().unwrap().clone()),
299+
}
300+
).await?;
301+
302+
let some_result_text = some_result.content
303+
.first()
304+
.and_then(|content| content.raw.as_text())
305+
.map(|text| text.text.as_str())
306+
.expect("Expected text content");
307+
308+
assert_eq!(
309+
some_result_text,
310+
"Received count: 42",
311+
"Some case should return expected message"
312+
);
313+
314+
client.cancel().await?;
315+
server_handle.await??;
316+
Ok(())
317+
}

0 commit comments

Comments
 (0)