-
Notifications
You must be signed in to change notification settings - Fork 390
/
Copy pathmain.rs
85 lines (70 loc) · 2.69 KB
/
main.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
use std::collections::HashMap;
use anyhow::Result;
use mistralrs::{
Function, IsqType, RequestBuilder, TextMessageRole, TextModelBuilder, Tool, ToolChoice,
ToolType,
};
use serde_json::{json, Value};
/// Arguments supplied by the model when it invokes the `get_weather` tool.
///
/// Parsed from the JSON `arguments` string of the tool call; the single field
/// corresponds to the `place` property declared in the tool's JSON schema.
#[derive(Debug, Clone, serde::Deserialize)]
struct GetWeatherInput {
    /// Location the caller wants the weather for.
    place: String,
}
/// Mock implementation of the `get_weather` tool.
///
/// Returns a canned forecast that embeds `input.place`. No weather service is
/// queried, so every place receives the same fixed readings.
fn get_weather(input: GetWeatherInput) -> String {
    let place = input.place;
    // The trailing `\` continues the literal and strips the next line's leading
    // whitespace; the space before it keeps "10C. Precipitation" separated.
    format!(
        "Weather in {place}: Temperature: 25C. Wind: calm. Dew point: 10C. \
         Precipitation: 5cm of rain expected."
    )
}
#[tokio::main]
async fn main() -> Result<()> {
let model = TextModelBuilder::new("meta-llama/Meta-Llama-3.1-8B-Instruct")
.with_logging()
.with_isq(IsqType::Q8_0)
.build()
.await?;
let parameters: HashMap<String, Value> = serde_json::from_value(json!({
"type": "object",
"properties": {
"place": {
"type": "string",
"description": "The place to get the weather for.",
},
},
"required": ["place"],
}))?;
let tools = vec![Tool {
tp: ToolType::Function,
function: Function {
description: Some("Get the weather for a certain city.".to_string()),
name: "get_weather".to_string(),
parameters: Some(parameters),
},
}];
// We will keep all the messages here
let mut messages = RequestBuilder::new()
.add_message(TextMessageRole::User, "What is the weather in Boston?")
.set_tools(tools)
.set_tool_choice(ToolChoice::Auto);
let response = model.send_chat_request(messages.clone()).await?;
let message = &response.choices[0].message;
if let Some(tool_calls) = &message.tool_calls {
let called = &tool_calls[0];
if called.function.name == "get_weather" {
let input: GetWeatherInput = serde_json::from_str(&called.function.arguments)?;
println!("Called tool `get_weather` with arguments {input:?}");
let result = get_weather(input);
println!("Output of tool call: {result}");
// Add tool call message from assistant so it knows what it called
// Then, add message from the tool
messages = messages
.add_message_with_tool_call(
TextMessageRole::Assistant,
String::new(),
vec![called.clone()],
)
.add_tool_message(result, called.id.clone())
.set_tool_choice(ToolChoice::None);
let response = model.send_chat_request(messages.clone()).await?;
let message = &response.choices[0].message;
println!("Output of model: {:?}", message.content);
}
}
Ok(())
}