Skip to content

Commit cbdc98d

Browse files
feat(backend): adding Azure AI Foundry OpenAPI support (#3683)
* feat(chat): add support for Azure API with versioning and refactor model handling
* feat(embedding): add AzureEmbeddingEngine for Azure API integration
* fix(azure): update default API version for Azure integration
* [autofix.ci] apply automated fixes
* feat(embedding): enhance AzureEmbeddingEngine with detailed documentation and API version support
* feat(embedding): standardize Azure API version usage across embedding and chat modules

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
1 parent 58717da commit cbdc98d

File tree

5 files changed

+190
-26
lines changed

5 files changed

+190
-26
lines changed

Diff for: crates/http-api-bindings/src/chat/mod.rs

+38-26
Original file line numberDiff line numberDiff line change
@@ -5,38 +5,50 @@ use tabby_common::config::HttpModelConfig;
55
use tabby_inference::{ChatCompletionStream, ExtendedOpenAIConfig};
66

77
use super::rate_limit;
8-
use crate::create_reqwest_client;
8+
use crate::{create_reqwest_client, AZURE_API_VERSION};
99

1010
pub async fn create(model: &HttpModelConfig) -> Arc<dyn ChatCompletionStream> {
1111
let api_endpoint = model
1212
.api_endpoint
1313
.as_deref()
1414
.expect("api_endpoint is required");
15-
let config = OpenAIConfig::default()
16-
.with_api_base(api_endpoint)
17-
.with_api_key(model.api_key.clone().unwrap_or_default());
18-
19-
let mut builder = ExtendedOpenAIConfig::builder();
20-
21-
builder
22-
.base(config)
23-
.supported_models(model.supported_models.clone())
24-
.model_name(model.model_name.as_deref().expect("Model name is required"));
25-
26-
if model.kind == "openai/chat" {
27-
// Do nothing
28-
} else if model.kind == "mistral/chat" {
29-
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
30-
} else {
31-
panic!("Unsupported model kind: {}", model.kind);
32-
}
33-
34-
let config = builder.build().expect("Failed to build config");
35-
36-
let engine = Box::new(
37-
async_openai_alt::Client::with_config(config)
38-
.with_http_client(create_reqwest_client(api_endpoint)),
39-
);
15+
16+
let engine: Box<dyn ChatCompletionStream> = match model.kind.as_str() {
17+
"azure/chat" => {
18+
let config = async_openai_alt::config::AzureConfig::new()
19+
.with_api_base(api_endpoint)
20+
.with_api_key(model.api_key.clone().unwrap_or_default())
21+
.with_api_version(AZURE_API_VERSION.to_string())
22+
.with_deployment_id(model.model_name.as_deref().expect("Model name is required"));
23+
Box::new(
24+
async_openai_alt::Client::with_config(config)
25+
.with_http_client(create_reqwest_client(api_endpoint)),
26+
)
27+
}
28+
"openai/chat" | "mistral/chat" => {
29+
let config = OpenAIConfig::default()
30+
.with_api_base(api_endpoint)
31+
.with_api_key(model.api_key.clone().unwrap_or_default());
32+
33+
let mut builder = ExtendedOpenAIConfig::builder();
34+
builder
35+
.base(config)
36+
.supported_models(model.supported_models.clone())
37+
.model_name(model.model_name.as_deref().expect("Model name is required"));
38+
39+
if model.kind == "mistral/chat" {
40+
builder.fields_to_remove(ExtendedOpenAIConfig::mistral_fields_to_remove());
41+
}
42+
43+
Box::new(
44+
async_openai_alt::Client::with_config(
45+
builder.build().expect("Failed to build config"),
46+
)
47+
.with_http_client(create_reqwest_client(api_endpoint)),
48+
)
49+
}
50+
_ => panic!("Unsupported model kind: {}", model.kind),
51+
};
4052

4153
Arc::new(rate_limit::new_chat(
4254
engine,

Diff for: crates/http-api-bindings/src/embedding/azure.rs

+123
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
use std::sync::Arc;
2+
3+
use anyhow::Result;
4+
use async_trait::async_trait;
5+
use reqwest::Client;
6+
use serde::{Deserialize, Serialize};
7+
use tabby_inference::Embedding;
8+
9+
use crate::AZURE_API_VERSION;
10+
11+
/// `AzureEmbeddingEngine` is responsible for interacting with Azure's Embedding API.
///
/// **Note**: Currently, this implementation only supports the OpenAI API and specific API versions.
#[derive(Clone)]
pub struct AzureEmbeddingEngine {
    // HTTP client shared by all clones of this engine.
    client: Arc<Client>,
    // Fully-resolved embeddings URL, built in `create` as
    // "{base}/openai/deployments/{model}/embeddings".
    api_endpoint: String,
    // Sent on every request via the `api-key` header; empty string when no
    // key was configured.
    api_key: String,
}
20+
21+
/// Structure representing the request body for embedding.
#[derive(Debug, Serialize)]
struct EmbeddingRequest {
    // Text to embed; this client sends a single string per request
    // (no batching).
    input: String,
}
26+
27+
/// Structure representing the response from the embedding API.
///
/// Only the fields this crate consumes are modeled; any other fields in the
/// JSON payload are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct EmbeddingResponse {
    // One entry per input; `embed` only reads the first element.
    data: Vec<Data>,
}
32+
33+
/// Structure representing individual embedding data.
#[derive(Debug, Deserialize)]
struct Data {
    // The embedding vector returned by the API.
    embedding: Vec<f32>,
}
38+
39+
impl AzureEmbeddingEngine {
    /// Creates a new instance of `AzureEmbeddingEngine`.
    ///
    /// **Note**: Currently, this implementation only supports the OpenAI API and specific API versions.
    ///
    /// # Parameters
    ///
    /// - `api_endpoint`: The base URL of the Azure Embedding API.
    /// - `model_name`: The name of the deployed model, used to construct the deployment ID.
    /// - `api_key`: Optional API key for authentication; an empty string is
    ///   used when absent.
    ///
    /// The API version is not configurable here: requests always use the
    /// crate-wide `AZURE_API_VERSION` constant, appended as a query
    /// parameter at request time (see `embed`).
    ///
    /// # Returns
    ///
    /// A boxed instance that implements the `Embedding` trait.
    pub fn create(
        api_endpoint: &str,
        model_name: &str,
        api_key: Option<&str>,
    ) -> Box<dyn Embedding> {
        let client = Client::new();
        let deployment_id = model_name;
        // Construct the full endpoint URL for the Azure Embedding API;
        // trimming the trailing '/' avoids a double slash in the joined URL.
        let azure_endpoint = format!(
            "{}/openai/deployments/{}/embeddings",
            api_endpoint.trim_end_matches('/'),
            deployment_id
        );

        Box::new(Self {
            client: Arc::new(client),
            api_endpoint: azure_endpoint,
            api_key: api_key.unwrap_or_default().to_owned(),
        })
    }
}
75+
76+
#[async_trait]
77+
impl Embedding for AzureEmbeddingEngine {
78+
/// Generates an embedding vector for the given prompt.
79+
///
80+
/// **Note**: Currently, this implementation only supports the OpenAI API and specific API versions.
81+
///
82+
/// # Parameters
83+
///
84+
/// - `prompt`: The input text to generate embeddings for.
85+
///
86+
/// # Returns
87+
///
88+
/// A `Result` containing the embedding vector or an error.
89+
async fn embed(&self, prompt: &str) -> Result<Vec<f32>> {
90+
// Clone all necessary fields to ensure thread safety across await points
91+
let api_endpoint = self.api_endpoint.clone();
92+
let api_key = self.api_key.clone();
93+
let api_version = AZURE_API_VERSION.to_string();
94+
let request = EmbeddingRequest {
95+
input: prompt.to_owned(),
96+
};
97+
98+
// Send a POST request to the Azure Embedding API
99+
let response = self
100+
.client
101+
.post(&api_endpoint)
102+
.query(&[("api-version", &api_version)])
103+
.header("api-key", &api_key)
104+
.header("Content-Type", "application/json")
105+
.json(&request)
106+
.send()
107+
.await?;
108+
109+
// Check if the response status indicates success
110+
if !response.status().is_success() {
111+
let error_text = response.text().await?;
112+
anyhow::bail!("Azure API error: {}", error_text);
113+
}
114+
115+
// Deserialize the response body into `EmbeddingResponse`
116+
let embedding_response: EmbeddingResponse = response.json().await?;
117+
embedding_response
118+
.data
119+
.first()
120+
.map(|data| data.embedding.clone())
121+
.ok_or_else(|| anyhow::anyhow!("No embedding data received"))
122+
}
123+
}

Diff for: crates/http-api-bindings/src/embedding/mod.rs

+10
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
mod azure;
12
mod llama;
23
mod openai;
34

45
use core::panic;
56
use std::sync::Arc;
67

8+
use azure::AzureEmbeddingEngine;
79
use llama::LlamaCppEngine;
810
use openai::OpenAIEmbeddingEngine;
911
use tabby_common::config::HttpModelConfig;
@@ -40,6 +42,14 @@ pub async fn create(config: &HttpModelConfig) -> Arc<dyn Embedding> {
4042
.expect("model_name must be set for voyage/embedding"),
4143
config.api_key.as_deref(),
4244
),
45+
"azure/embedding" => AzureEmbeddingEngine::create(
46+
config
47+
.api_endpoint
48+
.as_deref()
49+
.expect("api_endpoint is required for azure/embedding"),
50+
config.model_name.as_deref().unwrap_or_default(), // Provide a default if model_name is optional
51+
config.api_key.as_deref(),
52+
),
4353
unsupported_kind => panic!(
4454
"Unsupported kind for http embedding model: {}",
4555
unsupported_kind

Diff for: crates/http-api-bindings/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ fn create_reqwest_client(api_endpoint: &str) -> reqwest::Client {
2020

2121
builder.build().unwrap()
2222
}
23+
24+
/// Azure OpenAI REST API version used by the azure chat and embedding
/// backends (sent as the `api-version` query parameter / client config).
// `const` is the idiomatic choice for a simple immutable value; `static` is
// only needed when a single memory address matters.
const AZURE_API_VERSION: &str = "2024-02-01";

Diff for: crates/tabby-inference/src/chat.rs

+17
Original file line numberDiff line numberDiff line change
@@ -125,3 +125,20 @@ impl ChatCompletionStream for async_openai_alt::Client<ExtendedOpenAIConfig> {
125125
self.chat().create_stream(request).await
126126
}
127127
}
128+
129+
#[async_trait]
impl ChatCompletionStream for async_openai_alt::Client<async_openai_alt::config::AzureConfig> {
    // Non-streaming chat completion. `self.chat()` resolves to the client's
    // inherent chat-API accessor (not this trait method), which performs the
    // actual request.
    async fn chat(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<CreateChatCompletionResponse, OpenAIError> {
        self.chat().create(request).await
    }

    // Streaming chat completion; delegates to the same inherent chat-API
    // accessor and returns the provider's response stream.
    async fn chat_stream(
        &self,
        request: CreateChatCompletionRequest,
    ) -> Result<ChatCompletionResponseStream, OpenAIError> {
        self.chat().create_stream(request).await
    }
}

0 commit comments

Comments
 (0)