1
- use std :: sync :: Arc ;
1
+ //cargo test --test test_tool_macros --features "client server"
2
2
3
- use rmcp:: { ServerHandler , handler:: server:: tool:: ToolCallContext , tool} ;
3
+ use std:: sync:: Arc ;
4
+ use rmcp:: {
5
+ ClientHandler ,
6
+ Peer ,
7
+ RoleClient ,
8
+ ServerHandler ,
9
+ ServiceExt ,
10
+ model:: {
11
+ CallToolRequestParam ,
12
+ ClientInfo ,
13
+ } ,
14
+ } ;
15
+ use rmcp:: { handler:: server:: tool:: ToolCallContext , tool} ;
4
16
use schemars:: JsonSchema ;
5
17
use serde:: { Deserialize , Serialize } ;
18
+ use serde_json;
6
19
7
20
#[ derive( Serialize , Deserialize , JsonSchema ) ]
8
21
pub struct GetWeatherRequest {
@@ -100,3 +113,205 @@ async fn test_tool_macros_with_generics() {
100
113
}
101
114
102
115
impl GetWeatherRequest { }
116
+
117
+ // Struct defined for testing optional field schema generation
118
+ #[ derive( Debug , Deserialize , Serialize , JsonSchema ) ]
119
+ pub struct OptionalFieldTestSchema {
120
+ #[ schemars( description = "An optional description field" ) ]
121
+ pub description : Option < String > ,
122
+ }
123
+
124
+ // Struct defined for testing optional i64 field schema generation and null handling
125
+ #[ derive( Debug , Deserialize , Serialize , JsonSchema ) ]
126
+ pub struct OptionalI64TestSchema {
127
+ #[ schemars( description = "An optional i64 field" ) ]
128
+ pub count : Option < i64 > ,
129
+ pub mandatory_field : String , // Added to ensure non-empty object schema
130
+ }
131
+
132
+ // Dummy struct to host the test tool method
133
+ #[ derive( Debug , Clone , Default ) ]
134
+ pub struct OptionalSchemaTester { }
135
+
136
+ impl OptionalSchemaTester {
137
+ // Dummy tool function using the test schema as an aggregated parameter
138
+ #[ tool( description = "A tool to test optional schema generation" ) ]
139
+ async fn test_optional_aggr ( & self , #[ tool( aggr) ] _req : OptionalFieldTestSchema ) {
140
+ // Implementation doesn't matter for schema testing
141
+ // Return type changed to () to satisfy IntoCallToolResult
142
+ }
143
+
144
+ // Tool function to test optional i64 handling
145
+ #[ tool( description = "A tool to test optional i64 schema generation" ) ]
146
+ async fn test_optional_i64_aggr (
147
+ & self ,
148
+ #[ tool( aggr) ] req : OptionalI64TestSchema ,
149
+ ) -> String {
150
+ match req. count {
151
+ Some ( c) => format ! ( "Received count: {}" , c) ,
152
+ None => "Received null count" . to_string ( ) ,
153
+ }
154
+ }
155
+ }
156
+
157
+ // Implement ServerHandler to route tool calls for OptionalSchemaTester
158
+ impl ServerHandler for OptionalSchemaTester {
159
+ async fn call_tool (
160
+ & self ,
161
+ request : rmcp:: model:: CallToolRequestParam ,
162
+ context : rmcp:: service:: RequestContext < rmcp:: RoleServer > ,
163
+ ) -> Result < rmcp:: model:: CallToolResult , rmcp:: Error > {
164
+ let tcc = ToolCallContext :: new ( self , request, context) ;
165
+ match tcc. name ( ) {
166
+ "test_optional_aggr" => Self :: test_optional_aggr_tool_call ( tcc) . await ,
167
+ "test_optional_i64_aggr" => Self :: test_optional_i64_aggr_tool_call ( tcc) . await ,
168
+ _ => Err ( rmcp:: Error :: invalid_params ( "method not found" , None ) ) ,
169
+ }
170
+ }
171
+ }
172
+
173
+ #[ test]
174
+ fn test_optional_field_schema_generation_via_macro ( ) {
175
+ // tests https://github.com/modelcontextprotocol/rust-sdk/issues/135
176
+
177
+ // Get the attributes generated by the #[tool] macro helper
178
+ let tool_attr = OptionalSchemaTester :: test_optional_aggr_tool_attr ( ) ;
179
+
180
+ // Print the actual generated schema for debugging
181
+ println ! (
182
+ "Actual input schema generated by macro: {:#?}" ,
183
+ tool_attr. input_schema
184
+ ) ;
185
+
186
+ // Verify the schema generated for the aggregated OptionalFieldTestSchema
187
+ // by the macro infrastructure (which should now use OpenAPI 3 settings)
188
+ let input_schema_map = & * tool_attr. input_schema ; // Dereference Arc<JsonObject>
189
+
190
+ // Check the schema for the 'description' property within the input schema
191
+ let properties = input_schema_map
192
+ . get ( "properties" )
193
+ . expect ( "Schema should have properties" )
194
+ . as_object ( )
195
+ . unwrap ( ) ;
196
+ let description_schema = properties
197
+ . get ( "description" )
198
+ . expect ( "Properties should include description" )
199
+ . as_object ( )
200
+ . unwrap ( ) ;
201
+
202
+ // Assert that the format is now `type: "string", nullable: true`
203
+ assert_eq ! (
204
+ description_schema. get( "type" ) . map( |v| v. as_str( ) . unwrap( ) ) ,
205
+ Some ( "string" ) ,
206
+ "Schema for Option<String> generated by macro should be type: \" string\" "
207
+ ) ;
208
+ assert_eq ! (
209
+ description_schema
210
+ . get( "nullable" )
211
+ . map( |v| v. as_bool( ) . unwrap( ) ) ,
212
+ Some ( true ) ,
213
+ "Schema for Option<String> generated by macro should have nullable: true"
214
+ ) ;
215
+ // We still check the description is correct
216
+ assert_eq ! (
217
+ description_schema
218
+ . get( "description" )
219
+ . map( |v| v. as_str( ) . unwrap( ) ) ,
220
+ Some ( "An optional description field" )
221
+ ) ;
222
+
223
+ // Ensure the old 'type: [T, null]' format is NOT used
224
+ let type_value = description_schema. get ( "type" ) . unwrap ( ) ;
225
+ assert ! (
226
+ !type_value. is_array( ) ,
227
+ "Schema type should not be an array [T, null]"
228
+ ) ;
229
+ }
230
+
231
+ // Define a dummy client handler
232
+ #[ derive( Debug , Clone , Default ) ]
233
+ struct DummyClientHandler {
234
+ peer : Option < Peer < RoleClient > > ,
235
+ }
236
+
237
+ impl ClientHandler for DummyClientHandler {
238
+ fn get_info ( & self ) -> ClientInfo {
239
+ ClientInfo :: default ( )
240
+ }
241
+
242
+ fn set_peer ( & mut self , peer : Peer < RoleClient > ) {
243
+ self . peer = Some ( peer) ;
244
+ }
245
+
246
+ fn get_peer ( & self ) -> Option < Peer < RoleClient > > {
247
+ self . peer . clone ( )
248
+ }
249
+ }
250
+
251
+ #[ tokio:: test]
252
+ async fn test_optional_i64_field_with_null_input ( ) -> anyhow:: Result < ( ) > {
253
+ let ( server_transport, client_transport) = tokio:: io:: duplex ( 4096 ) ;
254
+
255
+ // Server setup
256
+ let server = OptionalSchemaTester :: default ( ) ;
257
+ let server_handle = tokio:: spawn ( async move {
258
+ server. serve ( server_transport) . await ?
259
+ . waiting ( )
260
+ . await ?;
261
+ anyhow:: Ok ( ( ) )
262
+ } ) ;
263
+
264
+ // Create a simple client handler that just forwards tool calls
265
+ let client_handler = DummyClientHandler :: default ( ) ;
266
+ let client = client_handler. serve ( client_transport) . await ?;
267
+
268
+ // Test null case
269
+ let result = client. call_tool (
270
+ CallToolRequestParam {
271
+ name : "test_optional_i64_aggr" . into ( ) ,
272
+ arguments : Some ( serde_json:: json!( {
273
+ "count" : null,
274
+ "mandatory_field" : "test_null"
275
+ } ) . as_object ( ) . unwrap ( ) . clone ( ) ) ,
276
+ }
277
+ ) . await ?;
278
+
279
+ let result_text = result. content
280
+ . first ( )
281
+ . and_then ( |content| content. raw . as_text ( ) )
282
+ . map ( |text| text. text . as_str ( ) )
283
+ . expect ( "Expected text content" ) ;
284
+
285
+ assert_eq ! (
286
+ result_text,
287
+ "Received null count" ,
288
+ "Null case should return expected message"
289
+ ) ;
290
+
291
+ // Test Some case
292
+ let some_result = client. call_tool (
293
+ CallToolRequestParam {
294
+ name : "test_optional_i64_aggr" . into ( ) ,
295
+ arguments : Some ( serde_json:: json!( {
296
+ "count" : 42 ,
297
+ "mandatory_field" : "test_some"
298
+ } ) . as_object ( ) . unwrap ( ) . clone ( ) ) ,
299
+ }
300
+ ) . await ?;
301
+
302
+ let some_result_text = some_result. content
303
+ . first ( )
304
+ . and_then ( |content| content. raw . as_text ( ) )
305
+ . map ( |text| text. text . as_str ( ) )
306
+ . expect ( "Expected text content" ) ;
307
+
308
+ assert_eq ! (
309
+ some_result_text,
310
+ "Received count: 42" ,
311
+ "Some case should return expected message"
312
+ ) ;
313
+
314
+ client. cancel ( ) . await ?;
315
+ server_handle. await ??;
316
+ Ok ( ( ) )
317
+ }
0 commit comments