diff --git a/.gitignore b/.gitignore index 552012ec..c2684c96 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,7 @@ test.ipynb bin/ obj/ .vs/ +.vscode/ # build, distribute, and bins build/ diff --git a/samples/README.md b/samples/README.md index 93f3bd57..bcac6bf3 100644 --- a/samples/README.md +++ b/samples/README.md @@ -1,6 +1,6 @@ # Foundry Local Samples -Explore complete working examples that demonstrate how to use Foundry Local — an end-to-end local AI solution that runs entirely on-device. These samples cover chat completions, audio transcription, tool calling, LangChain integration, and more. +Explore complete working examples that demonstrate how to use Foundry Local — an end-to-end local AI solution that runs entirely on-device. These samples cover chat completions, embeddings, audio transcription, tool calling, LangChain integration, and more. > **New to Foundry Local?** Check out the [main README](../README.md) for an overview and quickstart, or visit the [Foundry Local documentation](https://learn.microsoft.com/azure/foundry-local/) on Microsoft Learn. @@ -8,7 +8,7 @@ Explore complete working examples that demonstrate how to use Foundry Local — | Language | Samples | Description | |----------|---------|-------------| -| [**C#**](cs/) | 12 | .NET SDK samples including native chat, audio transcription, tool calling, model management, web server, and tutorials. Uses WinML on Windows for hardware acceleration. | -| [**JavaScript**](js/) | 12 | Node.js SDK samples including native chat, audio transcription, Electron desktop app, Copilot SDK integration, LangChain, tool calling, web server, and tutorials. | -| [**Python**](python/) | 9 | Python samples using the OpenAI-compatible API, including chat, audio transcription, LangChain integration, tool calling, web server, and tutorials. | -| [**Rust**](rust/) | 8 | Rust SDK samples including native chat, audio transcription, tool calling, web server, and tutorials. 
| +| [**C#**](cs/) | 13 | .NET SDK samples including native chat, embeddings, audio transcription, tool calling, model management, web server, and tutorials. Uses WinML on Windows for hardware acceleration. | +| [**JavaScript**](js/) | 13 | Node.js SDK samples including native chat, embeddings, audio transcription, Electron desktop app, Copilot SDK integration, LangChain, tool calling, web server, and tutorials. | +| [**Python**](python/) | 10 | Python samples using the OpenAI-compatible API, including chat, embeddings, audio transcription, LangChain integration, tool calling, web server, and tutorials. | +| [**Rust**](rust/) | 9 | Rust SDK samples including native chat, embeddings, audio transcription, tool calling, web server, and tutorials. | diff --git a/samples/cs/README.md b/samples/cs/README.md index 367c432e..ad10a3c6 100644 --- a/samples/cs/README.md +++ b/samples/cs/README.md @@ -12,6 +12,7 @@ Both packages provide the same APIs, so the same source code works on all platfo | Sample | Description | |---|---| | [native-chat-completions](native-chat-completions/) | Initialize the SDK, download a model, and run chat completions. | +| [embeddings](embeddings/) | Generate single and batch text embeddings using the Foundry Local SDK. | | [audio-transcription-example](audio-transcription-example/) | Transcribe audio files using the Foundry Local SDK. | | [foundry-local-web-server](foundry-local-web-server/) | Set up a local OpenAI-compliant web server. | | [tool-calling-foundry-local-sdk](tool-calling-foundry-local-sdk/) | Use tool calling with native chat completions. 
| diff --git a/samples/cs/embeddings/Embeddings.csproj b/samples/cs/embeddings/Embeddings.csproj new file mode 100644 index 00000000..4d948c56 --- /dev/null +++ b/samples/cs/embeddings/Embeddings.csproj @@ -0,0 +1,48 @@ + + + + Exe + enable + enable + + + + + net9.0-windows10.0.26100 + false + ARM64;x64 + None + false + + + + + net9.0 + + + + $(NETCoreSdkRuntimeIdentifier) + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/samples/cs/embeddings/Program.cs b/samples/cs/embeddings/Program.cs new file mode 100644 index 00000000..348bc346 --- /dev/null +++ b/samples/cs/embeddings/Program.cs @@ -0,0 +1,74 @@ +// +// +using Microsoft.AI.Foundry.Local; +// + +// +var config = new Configuration +{ + AppName = "foundry_local_samples", + LogLevel = Microsoft.AI.Foundry.Local.LogLevel.Information +}; + +// Initialize the singleton instance. +await FoundryLocalManager.CreateAsync(config, Utils.GetAppLogger()); +var mgr = FoundryLocalManager.Instance; +// + +// +// Get the model catalog +var catalog = await mgr.GetCatalogAsync(); + +// Get an embedding model +var model = await catalog.GetModelAsync("qwen3-0.6b-embedding") ?? 
throw new Exception("Embedding model not found"); + +// Download the model (the method skips download if already cached) +await model.DownloadAsync(progress => +{ + Console.Write($"\rDownloading model: {progress:F2}%"); + if (progress >= 100f) + { + Console.WriteLine(); + } +}); + +// Load the model +Console.Write($"Loading model {model.Id}..."); +await model.LoadAsync(); +Console.WriteLine("done."); +// + +// +// Get an embedding client +var embeddingClient = await model.GetEmbeddingClientAsync(); + +// Generate a single embedding +Console.WriteLine("\n--- Single Embedding ---"); +var response = await embeddingClient.GenerateEmbeddingAsync("The quick brown fox jumps over the lazy dog"); +var embedding = response.Data[0].Embedding; +Console.WriteLine($"Dimensions: {embedding.Count}"); +Console.WriteLine($"First 5 values: [{string.Join(", ", embedding.Take(5).Select(v => v.ToString("F6")))}]"); +// + +// +// Generate embeddings for multiple inputs +Console.WriteLine("\n--- Batch Embeddings ---"); +var batchResponse = await embeddingClient.GenerateEmbeddingsAsync([ + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + "Rust is a systems programming language" +]); + +Console.WriteLine($"Number of embeddings: {batchResponse.Data.Count}"); +for (var i = 0; i < batchResponse.Data.Count; i++) +{ + Console.WriteLine($" [{i}] Dimensions: {batchResponse.Data[i].Embedding.Count}"); +} +// + +// +// Tidy up - unload the model +await model.UnloadAsync(); +Console.WriteLine("\nModel unloaded."); +// +// diff --git a/samples/js/README.md b/samples/js/README.md index 28f1e7e7..d334555c 100644 --- a/samples/js/README.md +++ b/samples/js/README.md @@ -11,6 +11,7 @@ These samples demonstrate how to use the Foundry Local JavaScript SDK (`foundry- | Sample | Description | |--------|-------------| | [native-chat-completions](native-chat-completions/) | Initialize the SDK, download a model, and run non-streaming and streaming chat completions. 
| +| [embeddings](embeddings/) | Generate single and batch text embeddings using the Foundry Local SDK. | | [audio-transcription-example](audio-transcription-example/) | Transcribe audio files using the Whisper model with streaming output. | | [chat-and-audio-foundry-local](chat-and-audio-foundry-local/) | Unified sample demonstrating both chat and audio transcription in one application. | | [electron-chat-application](electron-chat-application/) | Full-featured Electron desktop chat app with voice transcription and model management. | diff --git a/samples/js/embeddings/app.js b/samples/js/embeddings/app.js new file mode 100644 index 00000000..ea6ff185 --- /dev/null +++ b/samples/js/embeddings/app.js @@ -0,0 +1,73 @@ +// +// +import { FoundryLocalManager } from 'foundry-local-sdk'; +// + +// Initialize the Foundry Local SDK +console.log('Initializing Foundry Local SDK...'); + +// +const manager = FoundryLocalManager.create({ + appName: 'foundry_local_samples', + logLevel: 'info' +}); +// +console.log('✓ SDK initialized successfully'); + +// +// Get an embedding model +const modelAlias = 'qwen3-0.6b-embedding'; +const model = await manager.catalog.getModel(modelAlias); + +// Download the model +console.log(`\nDownloading model ${modelAlias}...`); +await model.download((progress) => { + process.stdout.write(`\rDownloading... 
${progress.toFixed(2)}%`); +}); +console.log('\n✓ Model downloaded'); + +// Load the model +console.log(`\nLoading model ${modelAlias}...`); +await model.load(); +console.log('✓ Model loaded'); +// + +// +// Create embedding client +console.log('\nCreating embedding client...'); +const embeddingClient = model.createEmbeddingClient(); +console.log('✓ Embedding client created'); + +// Generate a single embedding +console.log('\n--- Single Embedding ---'); +const response = await embeddingClient.generateEmbedding( + 'The quick brown fox jumps over the lazy dog' +); + +const embedding = response.data[0].embedding; +console.log(`Dimensions: ${embedding.length}`); +console.log(`First 5 values: [${embedding.slice(0, 5).map(v => v.toFixed(6)).join(', ')}]`); +// + +// +// Generate embeddings for multiple inputs +console.log('\n--- Batch Embeddings ---'); +const batchResponse = await embeddingClient.generateEmbeddings([ + 'Machine learning is a subset of artificial intelligence', + 'The capital of France is Paris', + 'Rust is a systems programming language' +]); + +console.log(`Number of embeddings: ${batchResponse.data.length}`); +for (let i = 0; i < batchResponse.data.length; i++) { + console.log(` [${i}] Dimensions: ${batchResponse.data[i].embedding.length}`); +} +// + +// +// Unload the model +console.log('\nUnloading model...'); +await model.unload(); +console.log('✓ Model unloaded'); +// +// diff --git a/samples/js/embeddings/package.json b/samples/js/embeddings/package.json new file mode 100644 index 00000000..8353cb65 --- /dev/null +++ b/samples/js/embeddings/package.json @@ -0,0 +1,15 @@ +{ + "name": "embeddings", + "version": "1.0.0", + "type": "module", + "main": "app.js", + "scripts": { + "start": "node app.js" + }, + "dependencies": { + "foundry-local-sdk": "latest" + }, + "optionalDependencies": { + "foundry-local-sdk-winml": "latest" + } +} diff --git a/samples/python/README.md b/samples/python/README.md index 391cf123..7262f012 100644 --- 
a/samples/python/README.md +++ b/samples/python/README.md @@ -11,6 +11,7 @@ These samples demonstrate how to use Foundry Local with Python. | Sample | Description | |--------|-------------| | [native-chat-completions](native-chat-completions/) | Initialize the SDK, start the local service, and run streaming chat completions. | +| [embeddings](embeddings/) | Generate single and batch text embeddings using the Foundry Local SDK. | | [audio-transcription](audio-transcription/) | Transcribe audio files using the Whisper model. | | [web-server](web-server/) | Start a local OpenAI-compatible web server and call it with the OpenAI Python SDK. | | [tool-calling](tool-calling/) | Tool calling with custom function definitions (get_weather, calculate). | diff --git a/samples/python/embeddings/requirements.txt b/samples/python/embeddings/requirements.txt new file mode 100644 index 00000000..7602a48b --- /dev/null +++ b/samples/python/embeddings/requirements.txt @@ -0,0 +1,2 @@ +foundry-local-sdk; sys_platform != "win32" +foundry-local-sdk-winml; sys_platform == "win32" diff --git a/samples/python/embeddings/src/app.py b/samples/python/embeddings/src/app.py new file mode 100644 index 00000000..30ade4b2 --- /dev/null +++ b/samples/python/embeddings/src/app.py @@ -0,0 +1,61 @@ +# +# +from foundry_local_sdk import Configuration, FoundryLocalManager +# + + +def main(): + # + # Initialize the Foundry Local SDK + config = Configuration(app_name="foundry_local_samples") + FoundryLocalManager.initialize(config) + manager = FoundryLocalManager.instance + + # Select and load an embedding model from the catalog + model = manager.catalog.get_model("qwen3-0.6b-embedding") + model.download( + lambda progress: print( + f"\rDownloading model: {progress:.2f}%", + end="", + flush=True, + ) + ) + print() + model.load() + print("Model loaded and ready.") + + # Get an embedding client + client = model.get_embedding_client() + # + + # + # Generate a single embedding + print("\n--- Single Embedding 
---") + response = client.generate_embedding("The quick brown fox jumps over the lazy dog") + embedding = response.data[0].embedding + print(f"Dimensions: {len(embedding)}") + print(f"First 5 values: {embedding[:5]}") + # + + # + # Generate embeddings for multiple inputs + print("\n--- Batch Embeddings ---") + batch_response = client.generate_embeddings([ + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + "Rust is a systems programming language", + ]) + + print(f"Number of embeddings: {len(batch_response.data)}") + for i, data in enumerate(batch_response.data): + print(f" [{i}] Dimensions: {len(data.embedding)}") + # + + # Clean up + model.unload() + print("\nModel unloaded.") + + +if __name__ == "__main__": + main() +# diff --git a/samples/rust/Cargo.toml b/samples/rust/Cargo.toml index 42d1293f..7be551ea 100644 --- a/samples/rust/Cargo.toml +++ b/samples/rust/Cargo.toml @@ -4,6 +4,7 @@ members = [ "tool-calling-foundry-local", "native-chat-completions", "audio-transcription-example", + "embeddings", "tutorial-chat-assistant", "tutorial-document-summarizer", "tutorial-tool-calling", diff --git a/samples/rust/README.md b/samples/rust/README.md index f2ca4f52..71a66873 100644 --- a/samples/rust/README.md +++ b/samples/rust/README.md @@ -11,6 +11,7 @@ These samples demonstrate how to use the Rust binding for Foundry Local. | Sample | Description | |--------|-------------| | [native-chat-completions](native-chat-completions/) | Non-streaming and streaming chat completions using the native chat client. | +| [embeddings](embeddings/) | Generate single and batch text embeddings using the native embedding client. | | [audio-transcription-example](audio-transcription-example/) | Audio transcription (non-streaming and streaming) using the Whisper model. | | [foundry-local-webserver](foundry-local-webserver/) | Start a local OpenAI-compatible web server and call it with a standard HTTP client. 
| | [tool-calling-foundry-local](tool-calling-foundry-local/) | Tool calling with streaming responses, multi-turn conversation, and local tool execution. | diff --git a/samples/rust/embeddings/Cargo.toml b/samples/rust/embeddings/Cargo.toml new file mode 100644 index 00000000..ebaa21be --- /dev/null +++ b/samples/rust/embeddings/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "embeddings" +version = "0.1.0" +edition = "2021" +description = "Native SDK embeddings (single and batch) using the Foundry Local Rust SDK" + +[dependencies] +foundry-local-sdk = { path = "../../../sdk/rust" } +tokio = { version = "1", features = ["rt-multi-thread", "macros"] } + +[target.'cfg(windows)'.dependencies] +foundry-local-sdk = { path = "../../../sdk/rust", features = ["winml"] } diff --git a/samples/rust/embeddings/src/main.rs b/samples/rust/embeddings/src/main.rs new file mode 100644 index 00000000..9b5550f0 --- /dev/null +++ b/samples/rust/embeddings/src/main.rs @@ -0,0 +1,88 @@ +// +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +// +use foundry_local_sdk::{FoundryLocalConfig, FoundryLocalManager}; +// + +const ALIAS: &str = "qwen3-0.6b-embedding"; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("Native Embeddings"); + println!("=================\n"); + + // ── 1. Initialise the manager ──────────────────────────────────────── + // + let manager = FoundryLocalManager::create(FoundryLocalConfig::new("foundry_local_samples"))?; + // + + // ── 2. Pick a model and ensure it is downloaded ───────────────────── + // + let model = manager.catalog().get_model(ALIAS).await?; + println!("Model: {} (id: {})", model.alias(), model.id()); + + if !model.is_cached().await? 
{ + println!("Downloading model..."); + model + .download(Some(|progress: f64| { + print!("\r {progress:.1}%"); + std::io::Write::flush(&mut std::io::stdout()).ok(); + })) + .await?; + println!(); + } + + println!("Loading model..."); + model.load().await?; + println!("✓ Model loaded\n"); + // + + // ── 3. Create an embedding client ─────────────────────────────────── + // + let client = model.create_embedding_client(); + // + + // ── 4. Single embedding ───────────────────────────────────────────── + // + println!("--- Single Embedding ---"); + let response = client + .generate_embedding("The quick brown fox jumps over the lazy dog") + .await?; + + let embedding = &response.data[0].embedding; + println!("Dimensions: {}", embedding.len()); + println!( + "First 5 values: {:?}", + &embedding[..5] + ); + // + + // ── 5. Batch embeddings ───────────────────────────────────────────── + // + println!("\n--- Batch Embeddings ---"); + let batch_response = client + .generate_embeddings(&[ + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + "Rust is a systems programming language", + ]) + .await?; + + println!("Number of embeddings: {}", batch_response.data.len()); + for (i, data) in batch_response.data.iter().enumerate() { + println!(" [{i}] Dimensions: {}", data.embedding.len()); + } + // + + // ── 6. 
Unload the model ───────────────────────────────────────────── + // + println!("\nUnloading model..."); + model.unload().await?; + println!("Done."); + // + + Ok(()) +} +// diff --git a/sdk/cs/README.md b/sdk/cs/README.md index 20580e65..5e40ed2b 100644 --- a/sdk/cs/README.md +++ b/sdk/cs/README.md @@ -7,6 +7,7 @@ The Foundry Local C# SDK provides a .NET interface for running AI models locally - **Model catalog** — browse and search all available models; filter by cached or loaded state - **Lifecycle management** — download, load, unload, and remove models programmatically - **Chat completions** — synchronous and `IAsyncEnumerable` streaming via OpenAI-compatible types +- **Embeddings** — generate text embeddings via OpenAI-compatible API - **Audio transcription** — transcribe audio files with streaming support - **Download progress** — wire up an `Action` callback for real-time download percentage - **Model variants** — select specific hardware/quantization variants per model alias @@ -246,6 +247,31 @@ chatClient.Settings.TopP = 0.9f; chatClient.Settings.FrequencyPenalty = 0.5f; ``` +### Embeddings + +```csharp +var embeddingClient = await model.GetEmbeddingClientAsync(); + +// Single input +var response = await embeddingClient.GenerateEmbeddingAsync("The quick brown fox jumps over the lazy dog"); +var embedding = response.Data[0].Embedding; // List +Console.WriteLine($"Dimensions: {embedding.Count}"); + +// Batch input +var batchResponse = await embeddingClient.GenerateEmbeddingsAsync([ + "The quick brown fox", + "The capital of France is Paris" +]); +// batchResponse.Data[0].Embedding, batchResponse.Data[1].Embedding +``` + +#### Embedding Settings + +```csharp +embeddingClient.Settings.Dimensions = 512; // optional: reduce dimensionality +embeddingClient.Settings.EncodingFormat = "float"; // "float" or "base64" +``` + ### Audio Transcription ```csharp diff --git a/sdk/cs/docs/api/index.md b/sdk/cs/docs/api/index.md index 4d084f87..c83e0a43 100644 --- 
a/sdk/cs/docs/api/index.md +++ b/sdk/cs/docs/api/index.md @@ -30,6 +30,8 @@ [OpenAIChatClient](./microsoft.ai.foundry.local.openaichatclient.md) +[OpenAIEmbeddingClient](./microsoft.ai.foundry.local.openaiembeddingclient.md) + [Parameter](./microsoft.ai.foundry.local.parameter.md) [PromptTemplate](./microsoft.ai.foundry.local.prompttemplate.md) diff --git a/sdk/cs/docs/api/microsoft.ai.foundry.local.imodel.md b/sdk/cs/docs/api/microsoft.ai.foundry.local.imodel.md index 861386a8..95185abe 100644 --- a/sdk/cs/docs/api/microsoft.ai.foundry.local.imodel.md +++ b/sdk/cs/docs/api/microsoft.ai.foundry.local.imodel.md @@ -208,6 +208,24 @@ Optional cancellation token. [Task<OpenAIAudioClient>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
OpenAI.AudioClient +### **GetEmbeddingClientAsync(Nullable&lt;CancellationToken&gt;)** + +Get an OpenAI API based EmbeddingClient + +```csharp +Task<OpenAIEmbeddingClient> GetEmbeddingClientAsync(Nullable<CancellationToken> ct) +``` + +#### Parameters + +`ct` [Nullable&lt;CancellationToken&gt;](https://docs.microsoft.com/en-us/dotnet/api/system.nullable-1)<br>
+Optional cancellation token. + +#### Returns + +[Task<OpenAIEmbeddingClient>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+OpenAI.EmbeddingClient + ### **SelectVariant(IModel)** Select a model variant from [IModel.Variants](./microsoft.ai.foundry.local.imodel.md#variants) to use for [IModel](./microsoft.ai.foundry.local.imodel.md) operations. diff --git a/sdk/cs/docs/api/microsoft.ai.foundry.local.model.md b/sdk/cs/docs/api/microsoft.ai.foundry.local.model.md index 23cd67a3..c6eac5f2 100644 --- a/sdk/cs/docs/api/microsoft.ai.foundry.local.model.md +++ b/sdk/cs/docs/api/microsoft.ai.foundry.local.model.md @@ -176,6 +176,20 @@ public Task GetAudioClientAsync(Nullable c [Task<OpenAIAudioClient>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+### **GetEmbeddingClientAsync(Nullable&lt;CancellationToken&gt;)** + +```csharp +public Task<OpenAIEmbeddingClient> GetEmbeddingClientAsync(Nullable<CancellationToken> ct) +``` + +#### Parameters + +`ct` [Nullable&lt;CancellationToken&gt;](https://docs.microsoft.com/en-us/dotnet/api/system.nullable-1)<br>
+ +#### Returns + +[Task<OpenAIEmbeddingClient>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+ ### **UnloadAsync(Nullable<CancellationToken>)** ```csharp diff --git a/sdk/cs/docs/api/microsoft.ai.foundry.local.modelvariant.md b/sdk/cs/docs/api/microsoft.ai.foundry.local.modelvariant.md index 1f674511..cc2b20a6 100644 --- a/sdk/cs/docs/api/microsoft.ai.foundry.local.modelvariant.md +++ b/sdk/cs/docs/api/microsoft.ai.foundry.local.modelvariant.md @@ -181,3 +181,17 @@ public Task GetAudioClientAsync(Nullable c #### Returns [Task<OpenAIAudioClient>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+ +### **GetEmbeddingClientAsync(Nullable&lt;CancellationToken&gt;)** + +```csharp +public Task<OpenAIEmbeddingClient> GetEmbeddingClientAsync(Nullable<CancellationToken> ct) +``` + +#### Parameters + +`ct` [Nullable&lt;CancellationToken&gt;](https://docs.microsoft.com/en-us/dotnet/api/system.nullable-1)<br>
+ +#### Returns + +[Task<OpenAIEmbeddingClient>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
diff --git a/sdk/cs/docs/api/microsoft.ai.foundry.local.openaiembeddingclient.md b/sdk/cs/docs/api/microsoft.ai.foundry.local.openaiembeddingclient.md new file mode 100644 index 00000000..e9823f84 --- /dev/null +++ b/sdk/cs/docs/api/microsoft.ai.foundry.local.openaiembeddingclient.md @@ -0,0 +1,71 @@ +# OpenAIEmbeddingClient + +Namespace: Microsoft.AI.Foundry.Local + +Embedding Client that uses the OpenAI API. + Implemented using Betalgo.Ranul.OpenAI SDK types. + +```csharp +public class OpenAIEmbeddingClient +``` + +Inheritance [Object](https://docs.microsoft.com/en-us/dotnet/api/system.object) → [OpenAIEmbeddingClient](./microsoft.ai.foundry.local.openaiembeddingclient.md)
+Attributes [NullableContextAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullablecontextattribute), [NullableAttribute](https://docs.microsoft.com/en-us/dotnet/api/system.runtime.compilerservices.nullableattribute) + +## Properties + +### **Settings** + +Settings to use for embedding requests using this client. + +```csharp +public EmbeddingSettings Settings { get; } +``` + +#### Property Value + +EmbeddingSettings
+ +## Methods + +### **GenerateEmbeddingAsync(String, Nullable<CancellationToken>)** + +Generate embeddings for the given input text. + +```csharp +public Task GenerateEmbeddingAsync(string input, Nullable ct) +``` + +#### Parameters + +`input` [String](https://docs.microsoft.com/en-us/dotnet/api/system.string)
+The text to generate embeddings for. + +`ct` [Nullable<CancellationToken>](https://docs.microsoft.com/en-us/dotnet/api/system.nullable-1)
+Optional cancellation token. + +#### Returns + +[Task<EmbeddingCreateResponse>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+Embedding response containing the embedding vector. + +### **GenerateEmbeddingsAsync(IEnumerable&lt;String&gt;, Nullable&lt;CancellationToken&gt;)** + +Generate embeddings for multiple input texts in a single request. + +```csharp +public Task<EmbeddingCreateResponse> GenerateEmbeddingsAsync(IEnumerable<string> inputs, Nullable<CancellationToken> ct) +``` + +#### Parameters + +`inputs` [IEnumerable&lt;String&gt;](https://docs.microsoft.com/en-us/dotnet/api/system.collections.generic.ienumerable-1)<br>
+The texts to generate embeddings for. + +`ct` [Nullable<CancellationToken>](https://docs.microsoft.com/en-us/dotnet/api/system.nullable-1)
+Optional cancellation token. + +#### Returns + +[Task<EmbeddingCreateResponse>](https://docs.microsoft.com/en-us/dotnet/api/system.threading.tasks.task-1)
+Embedding response containing one embedding vector per input. diff --git a/sdk/cs/src/Detail/JsonSerializationContext.cs b/sdk/cs/src/Detail/JsonSerializationContext.cs index 37cc81ac..0fe5e677 100644 --- a/sdk/cs/src/Detail/JsonSerializationContext.cs +++ b/sdk/cs/src/Detail/JsonSerializationContext.cs @@ -23,6 +23,8 @@ namespace Microsoft.AI.Foundry.Local.Detail; [JsonSerializable(typeof(ChatCompletionCreateResponse))] [JsonSerializable(typeof(AudioCreateTranscriptionRequest))] [JsonSerializable(typeof(AudioCreateTranscriptionResponse))] +[JsonSerializable(typeof(EmbeddingCreateRequestExtended))] +[JsonSerializable(typeof(EmbeddingCreateResponse))] [JsonSerializable(typeof(string[]))] // list loaded or cached models [JsonSerializable(typeof(EpInfo[]))] [JsonSerializable(typeof(EpDownloadResult))] diff --git a/sdk/cs/src/Detail/Model.cs b/sdk/cs/src/Detail/Model.cs index c4d96057..03e9321b 100644 --- a/sdk/cs/src/Detail/Model.cs +++ b/sdk/cs/src/Detail/Model.cs @@ -99,6 +99,11 @@ public async Task GetAudioClientAsync(CancellationToken? ct = return await SelectedVariant.GetAudioClientAsync(ct).ConfigureAwait(false); } + public async Task GetEmbeddingClientAsync(CancellationToken? ct = null) + { + return await SelectedVariant.GetEmbeddingClientAsync(ct).ConfigureAwait(false); + } + public async Task UnloadAsync(CancellationToken? ct = null) { await SelectedVariant.UnloadAsync(ct).ConfigureAwait(false); diff --git a/sdk/cs/src/Detail/ModelVariant.cs b/sdk/cs/src/Detail/ModelVariant.cs index 9f2deaba..250c601a 100644 --- a/sdk/cs/src/Detail/ModelVariant.cs +++ b/sdk/cs/src/Detail/ModelVariant.cs @@ -102,6 +102,13 @@ public async Task GetAudioClientAsync(CancellationToken? ct = .ConfigureAwait(false); } + public async Task GetEmbeddingClientAsync(CancellationToken? 
ct = null) + { + return await Utils.CallWithExceptionHandling(() => GetEmbeddingClientImplAsync(ct), + "Error getting embedding client for model", _logger) + .ConfigureAwait(false); + } + private async Task IsLoadedImplAsync(CancellationToken? ct = null) { var loadedModels = await _modelLoadManager.ListLoadedModelsAsync(ct).ConfigureAwait(false); @@ -193,6 +200,16 @@ private async Task GetAudioClientImplAsync(CancellationToken? return new OpenAIAudioClient(Id); } + private async Task GetEmbeddingClientImplAsync(CancellationToken? ct = null) + { + if (!await IsLoadedAsync(ct)) + { + throw new FoundryLocalException($"Model {Id} is not loaded. Call LoadAsync first."); + } + + return new OpenAIEmbeddingClient(Id); + } + public void SelectVariant(IModel variant) { throw new FoundryLocalException( diff --git a/sdk/cs/src/IModel.cs b/sdk/cs/src/IModel.cs index a27f3a3d..37249782 100644 --- a/sdk/cs/src/IModel.cs +++ b/sdk/cs/src/IModel.cs @@ -70,6 +70,13 @@ Task DownloadAsync(Action? downloadProgress = null, /// OpenAI.AudioClient Task GetAudioClientAsync(CancellationToken? ct = null); + /// + /// Get an OpenAI API based EmbeddingClient + /// + /// Optional cancellation token. + /// OpenAI.EmbeddingClient + Task GetEmbeddingClientAsync(CancellationToken? ct = null); + /// /// Variants of the model that are available. Variants of the model are optimized for different devices. /// diff --git a/sdk/cs/src/OpenAI/EmbeddingClient.cs b/sdk/cs/src/OpenAI/EmbeddingClient.cs new file mode 100644 index 00000000..7778d25b --- /dev/null +++ b/sdk/cs/src/OpenAI/EmbeddingClient.cs @@ -0,0 +1,105 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. 
+// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local; + +using Betalgo.Ranul.OpenAI.ObjectModels.ResponseModels; + +using Microsoft.AI.Foundry.Local.Detail; +using Microsoft.AI.Foundry.Local.OpenAI; +using Microsoft.Extensions.Logging; + +/// +/// Embedding Client that uses the OpenAI API. +/// Implemented using Betalgo.Ranul.OpenAI SDK types. +/// +public class OpenAIEmbeddingClient +{ + private readonly string _modelId; + + private readonly ICoreInterop _coreInterop = FoundryLocalManager.Instance.CoreInterop; + private readonly ILogger _logger = FoundryLocalManager.Instance.Logger; + + internal OpenAIEmbeddingClient(string modelId) + { + _modelId = modelId; + } + + /// + /// Settings that are supported by Foundry Local for embeddings. + /// + public record EmbeddingSettings + { + /// + /// The number of dimensions the resulting output embeddings should have. + /// + public int? Dimensions { get; set; } + + /// + /// The format to return the embeddings in. Can be either "float" or "base64". + /// + public string? EncodingFormat { get; set; } + } + + /// + /// Settings to use for embedding requests using this client. + /// + public EmbeddingSettings Settings { get; } = new(); + + /// + /// Generate embeddings for the given input text. + /// + /// The text to generate embeddings for. + /// Optional cancellation token. + /// Embedding response containing the embedding vector. + public async Task GenerateEmbeddingAsync(string input, + CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling( + () => GenerateEmbeddingImplAsync(input, ct), + "Error during embedding generation.", _logger).ConfigureAwait(false); + } + + /// + /// Generate embeddings for multiple input texts in a single request. + /// + /// The texts to generate embeddings for. + /// Optional cancellation token. 
+ /// Embedding response containing one embedding vector per input. + public async Task GenerateEmbeddingsAsync(IEnumerable inputs, + CancellationToken? ct = null) + { + return await Utils.CallWithExceptionHandling( + () => GenerateEmbeddingsImplAsync(inputs, ct), + "Error during batch embedding generation.", _logger).ConfigureAwait(false); + } + + private async Task GenerateEmbeddingImplAsync(string input, + CancellationToken? ct) + { + var embeddingRequest = EmbeddingCreateRequestExtended.FromUserInput(_modelId, input, Settings); + var embeddingRequestJson = embeddingRequest.ToJson(); + + var request = new CoreInteropRequest { Params = new() { { "OpenAICreateRequest", embeddingRequestJson } } }; + var response = await _coreInterop.ExecuteCommandAsync("embeddings", request, + ct ?? CancellationToken.None).ConfigureAwait(false); + + return response.ToEmbeddingResponse(_logger); + } + + private async Task GenerateEmbeddingsImplAsync(IEnumerable inputs, + CancellationToken? ct) + { + var embeddingRequest = EmbeddingCreateRequestExtended.FromUserInput(_modelId, inputs, Settings); + var embeddingRequestJson = embeddingRequest.ToJson(); + + var request = new CoreInteropRequest { Params = new() { { "OpenAICreateRequest", embeddingRequestJson } } }; + var response = await _coreInterop.ExecuteCommandAsync("embeddings", request, + ct ?? CancellationToken.None).ConfigureAwait(false); + + return response.ToEmbeddingResponse(_logger); + } +} diff --git a/sdk/cs/src/OpenAI/EmbeddingRequestResponseTypes.cs b/sdk/cs/src/OpenAI/EmbeddingRequestResponseTypes.cs new file mode 100644 index 00000000..f81b8c0d --- /dev/null +++ b/sdk/cs/src/OpenAI/EmbeddingRequestResponseTypes.cs @@ -0,0 +1,82 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. 
+// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.OpenAI; + +using System.Text.Json; + +using Betalgo.Ranul.OpenAI.ObjectModels.RequestModels; +using Betalgo.Ranul.OpenAI.ObjectModels.ResponseModels; + +using Microsoft.AI.Foundry.Local.Detail; +using Microsoft.Extensions.Logging; + +// https://platform.openai.com/docs/api-reference/embeddings/create +internal record EmbeddingCreateRequestExtended : EmbeddingCreateRequest +{ + internal static EmbeddingCreateRequestExtended FromUserInput(string modelId, + string input, + OpenAIEmbeddingClient.EmbeddingSettings settings) + { + return new EmbeddingCreateRequestExtended + { + Model = modelId, + Input = input, + Dimensions = settings.Dimensions, + EncodingFormat = settings.EncodingFormat + }; + } + + internal static EmbeddingCreateRequestExtended FromUserInput(string modelId, + IEnumerable inputs, + OpenAIEmbeddingClient.EmbeddingSettings settings) + { + return new EmbeddingCreateRequestExtended + { + Model = modelId, + InputAsList = inputs.ToList(), + Dimensions = settings.Dimensions, + EncodingFormat = settings.EncodingFormat + }; + } +} + +internal static class EmbeddingRequestResponseExtensions +{ + internal static string ToJson(this EmbeddingCreateRequestExtended request) + { + return JsonSerializer.Serialize(request, JsonSerializationContext.Default.EmbeddingCreateRequestExtended); + } + + internal static EmbeddingCreateResponse ToEmbeddingResponse(this ICoreInterop.Response response, ILogger logger) + { + if (response.Error != null) + { + logger.LogError("Error from embeddings: {Error}", response.Error); + throw new FoundryLocalException($"Error from embeddings command: {response.Error}"); + } + + if (string.IsNullOrWhiteSpace(response.Data)) + { + logger.LogError("Embeddings command returned no data"); + throw new FoundryLocalException("Embeddings command returned null or empty response data"); + } + 
+ return response.Data.ToEmbeddingResponse(logger); + } + + internal static EmbeddingCreateResponse ToEmbeddingResponse(this string responseData, ILogger logger) + { + var output = JsonSerializer.Deserialize(responseData, JsonSerializationContext.Default.EmbeddingCreateResponse); + if (output == null) + { + logger.LogError("Failed to deserialize embedding response: {ResponseData}", responseData); + throw new JsonException("Failed to deserialize EmbeddingCreateResponse"); + } + + return output; + } +} diff --git a/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs b/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs new file mode 100644 index 00000000..3b316726 --- /dev/null +++ b/sdk/cs/test/FoundryLocal.Tests/EmbeddingClientTests.cs @@ -0,0 +1,239 @@ +// -------------------------------------------------------------------------------------------------------------------- +// +// Copyright (c) Microsoft. All rights reserved. +// +// -------------------------------------------------------------------------------------------------------------------- + +namespace Microsoft.AI.Foundry.Local.Tests; + +using System.Threading.Tasks; + +internal sealed class EmbeddingClientTests +{ + private static IModel? 
model; + + [Before(Class)] + public static async Task Setup() + { + var manager = FoundryLocalManager.Instance; // initialized by Utils + var catalog = await manager.GetCatalogAsync(); + + // Load the specific cached model variant directly + var model = await catalog.GetModelVariantAsync("qwen3-0.6b-embedding-generic-cpu:1").ConfigureAwait(false); + await Assert.That(model).IsNotNull(); + + await model!.LoadAsync().ConfigureAwait(false); + await Assert.That(await model.IsLoadedAsync()).IsTrue(); + + EmbeddingClientTests.model = model; + } + + [Test] + public async Task Embedding_BasicRequest_Succeeds() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var response = await embeddingClient.GenerateEmbeddingAsync("The quick brown fox jumps over the lazy dog") + .ConfigureAwait(false); + + await Assert.That(response).IsNotNull(); + await Assert.That(response.Model).IsEqualTo("qwen3-0.6b-embedding-generic-cpu:1"); + await Assert.That(response.Data).IsNotNull().And.IsNotEmpty(); + await Assert.That(response.Data[0].Embedding).IsNotNull(); + await Assert.That(response.Data[0].Embedding.Count).IsEqualTo(1024); + await Assert.That(response.Data[0].Index).IsEqualTo(0); + + Console.WriteLine($"Embedding dimension: {response.Data[0].Embedding.Count}"); + Console.WriteLine($"First value: {response.Data[0].Embedding[0]}"); + Console.WriteLine($"Last value: {response.Data[0].Embedding[1023]}"); + } + + [Test] + public async Task Embedding_IsNormalized() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var inputs = new[] + { + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris" + }; + + foreach (var input in inputs) + { + var response = await embeddingClient.GenerateEmbeddingAsync(input).ConfigureAwait(false); + + await 
Assert.That(response).IsNotNull(); + await Assert.That(response.Data).IsNotNull().And.IsNotEmpty(); + + var embedding = response.Data[0].Embedding; + + await Assert.That(embedding.Count).IsEqualTo(1024); + + // Verify L2 norm is approximately 1.0 + double norm = 0; + foreach (var val in embedding) + { + norm += val * val; + } + + norm = Math.Sqrt(norm); + await Assert.That(norm).IsGreaterThanOrEqualTo(0.99); + await Assert.That(norm).IsLessThanOrEqualTo(1.01); + + // All values should be within [-1, 1] for a normalized vector + foreach (var val in embedding) + { + await Assert.That(val).IsGreaterThanOrEqualTo(-1.0); + await Assert.That(val).IsLessThanOrEqualTo(1.0); + } + } + } + + [Test] + public async Task Embedding_DifferentInputs_ProduceDifferentEmbeddings() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var response1 = await embeddingClient.GenerateEmbeddingAsync("The quick brown fox").ConfigureAwait(false); + var response2 = await embeddingClient.GenerateEmbeddingAsync("The capital of France is Paris").ConfigureAwait(false); + + await Assert.That(response1).IsNotNull(); + await Assert.That(response2).IsNotNull(); + await Assert.That(response1.Data).IsNotNull().And.IsNotEmpty(); + await Assert.That(response2.Data).IsNotNull().And.IsNotEmpty(); + + // Same dimensionality + await Assert.That(response1.Data[0].Embedding.Count) + .IsEqualTo(response2.Data[0].Embedding.Count); + + // But different values (cosine similarity should not be 1.0) + double dot = 0; + for (int i = 0; i < response1.Data[0].Embedding.Count; i++) + { + dot += response1.Data[0].Embedding[i] * response2.Data[0].Embedding[i]; + } + + await Assert.That(dot).IsLessThan(0.99); + } + + [Test] + public async Task Embedding_SameInput_ProducesSameEmbedding() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var input = "Deterministic embedding test"; + + 
var response1 = await embeddingClient.GenerateEmbeddingAsync(input).ConfigureAwait(false); + var response2 = await embeddingClient.GenerateEmbeddingAsync(input).ConfigureAwait(false); + + await Assert.That(response1).IsNotNull(); + await Assert.That(response2).IsNotNull(); + await Assert.That(response1.Data).IsNotNull().And.IsNotEmpty(); + await Assert.That(response2.Data).IsNotNull().And.IsNotEmpty(); + + await Assert.That(response1.Data[0].Embedding.Count) + .IsEqualTo(response2.Data[0].Embedding.Count); + + for (int i = 0; i < response1.Data[0].Embedding.Count; i++) + { + await Assert.That(response1.Data[0].Embedding[i]) + .IsEqualTo(response2.Data[0].Embedding[i]); + } + } + + [Test] + public async Task Embedding_KnownValues_CapitalOfFrance() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var response = await embeddingClient.GenerateEmbeddingAsync("The capital of France is Paris") + .ConfigureAwait(false); + await Assert.That(response).IsNotNull(); + await Assert.That(response.Data).IsNotNull().And.IsNotEmpty(); + var embedding = response.Data[0].Embedding; + + await Assert.That(embedding.Count).IsEqualTo(1024); + + // Use tolerance for float32 model outputs which may vary across platforms + const double tolerance = 1e-5; + await Assert.That(Math.Abs(embedding[0] - (-0.02815740555524826))).IsLessThanOrEqualTo(tolerance); + await Assert.That(Math.Abs(embedding[1023] - (-0.00887922290712595))).IsLessThanOrEqualTo(tolerance); + } + + [Test] + public async Task Embedding_Batch_ReturnsMultipleEmbeddings() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var response = await embeddingClient.GenerateEmbeddingsAsync([ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris" + ]).ConfigureAwait(false); + + await 
Assert.That(response).IsNotNull(); + await Assert.That(response.Data).IsNotNull().And.IsNotEmpty(); + await Assert.That(response.Data.Count).IsEqualTo(3); + + for (var i = 0; i < 3; i++) + { + await Assert.That(response.Data[i].Index).IsEqualTo(i); + await Assert.That(response.Data[i].Embedding.Count).IsEqualTo(1024); + } + } + + [Test] + public async Task Embedding_Batch_EachEmbeddingIsNormalized() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var response = await embeddingClient.GenerateEmbeddingsAsync([ + "Hello world", + "Goodbye world" + ]).ConfigureAwait(false); + + await Assert.That(response.Data.Count).IsEqualTo(2); + + foreach (var data in response.Data) + { + double norm = 0; + foreach (var val in data.Embedding) + { + norm += val * val; + } + + norm = Math.Sqrt(norm); + await Assert.That(norm).IsGreaterThanOrEqualTo(0.99); + await Assert.That(norm).IsLessThanOrEqualTo(1.01); + } + } + + [Test] + public async Task Embedding_Batch_MatchesSingleInputResults() + { + var embeddingClient = await model!.GetEmbeddingClientAsync(); + await Assert.That(embeddingClient).IsNotNull(); + + var input = "The capital of France is Paris"; + + var singleResponse = await embeddingClient.GenerateEmbeddingAsync(input).ConfigureAwait(false); + var batchResponse = await embeddingClient.GenerateEmbeddingsAsync([input]).ConfigureAwait(false); + + await Assert.That(batchResponse.Data.Count).IsEqualTo(1); + + for (var i = 0; i < singleResponse.Data[0].Embedding.Count; i++) + { + await Assert.That(batchResponse.Data[0].Embedding[i]) + .IsEqualTo(singleResponse.Data[0].Embedding[i]); + } + } +} diff --git a/sdk/js/README.md b/sdk/js/README.md index 13d50442..c16d7572 100644 --- a/sdk/js/README.md +++ b/sdk/js/README.md @@ -8,6 +8,7 @@ The Foundry Local JS SDK provides a JavaScript/TypeScript interface for running - **Model catalog** — Browse and discover available models, check what's cached or loaded - 
**Automatic model management** — Download, load, unload, and remove models from cache - **Chat completions** — OpenAI-compatible chat API with both synchronous and streaming responses +- **Embeddings** — Generate text embeddings via OpenAI-compatible API - **Audio transcription** — Transcribe audio files locally with streaming support - **Multi-variant models** — Models can have multiple variants (e.g., different quantizations) with automatic selection of the best cached variant - **Embedded web service** — Start a local HTTP service for OpenAI-compatible API access @@ -204,6 +205,35 @@ for await (const chunk of chatClient.completeStreamingChat( } ``` +### Embeddings + +Generate text embeddings using the `EmbeddingClient`: + +```typescript +const embeddingClient = model.createEmbeddingClient(); + +// Single input +const response = await embeddingClient.generateEmbedding( + 'The quick brown fox jumps over the lazy dog' +); +const embedding = response.data[0].embedding; // number[] +console.log(`Dimensions: ${embedding.length}`); + +// Batch input +const batchResponse = await embeddingClient.generateEmbeddings([ + 'The quick brown fox', + 'The capital of France is Paris' +]); +// batchResponse.data[0].embedding, batchResponse.data[1].embedding +``` + +#### Embedding Settings + +```typescript +embeddingClient.settings.dimensions = 512; // optional: reduce dimensionality +embeddingClient.settings.encodingFormat = 'float'; // 'float' or 'base64' +``` + ### Audio Transcription Transcribe audio files locally using the `AudioClient`: diff --git a/sdk/js/docs/README.md b/sdk/js/docs/README.md index b0167b4d..c3b2c650 100644 --- a/sdk/js/docs/README.md +++ b/sdk/js/docs/README.md @@ -20,6 +20,8 @@ - [Catalog](classes/Catalog.md) - [ChatClient](classes/ChatClient.md) - [ChatClientSettings](classes/ChatClientSettings.md) +- [EmbeddingClient](classes/EmbeddingClient.md) +- [EmbeddingClientSettings](classes/EmbeddingClientSettings.md) - 
[FoundryLocalManager](classes/FoundryLocalManager.md) - [Model](classes/Model.md) - [ModelLoadManager](classes/ModelLoadManager.md) diff --git a/sdk/js/src/detail/model.ts b/sdk/js/src/detail/model.ts index 46245ee5..c1ee0d5f 100644 --- a/sdk/js/src/detail/model.ts +++ b/sdk/js/src/detail/model.ts @@ -1,6 +1,7 @@ import { ModelVariant } from './modelVariant.js'; import { ChatClient } from '../openai/chatClient.js'; import { AudioClient } from '../openai/audioClient.js'; +import { EmbeddingClient } from '../openai/embeddingClient.js'; import { ResponsesClient } from '../openai/responsesClient.js'; import { LiveAudioTranscriptionSession } from '../openai/liveAudioTranscriptionClient.js'; import { IModel } from '../imodel.js'; @@ -177,6 +178,14 @@ export class Model implements IModel { return this.selectedVariant.createAudioClient(); } + /** + * Creates an EmbeddingClient for generating text embeddings with the model. + * @returns An EmbeddingClient instance. + */ + public createEmbeddingClient(): EmbeddingClient { + return this.selectedVariant.createEmbeddingClient(); + } + /** * Creates a LiveAudioTranscriptionSession for real-time audio streaming ASR. * @returns A LiveAudioTranscriptionSession instance. 
diff --git a/sdk/js/src/detail/modelVariant.ts b/sdk/js/src/detail/modelVariant.ts index d1c1e20c..43484bac 100644 --- a/sdk/js/src/detail/modelVariant.ts +++ b/sdk/js/src/detail/modelVariant.ts @@ -3,6 +3,7 @@ import { ModelLoadManager } from './modelLoadManager.js'; import { ModelInfo } from '../types.js'; import { ChatClient } from '../openai/chatClient.js'; import { AudioClient } from '../openai/audioClient.js'; +import { EmbeddingClient } from '../openai/embeddingClient.js'; import { LiveAudioTranscriptionSession } from '../openai/liveAudioTranscriptionClient.js'; import { ResponsesClient } from '../openai/responsesClient.js'; import { IModel } from '../imodel.js'; @@ -170,6 +171,14 @@ export class ModelVariant implements IModel { return new AudioClient(this._modelInfo.id, this.coreInterop); } + /** + * Creates an EmbeddingClient for generating text embeddings with the model. + * @returns An EmbeddingClient instance. + */ + public createEmbeddingClient(): EmbeddingClient { + return new EmbeddingClient(this._modelInfo.id, this.coreInterop); + } + /** * Creates a LiveAudioTranscriptionSession for real-time audio streaming ASR. * @returns A LiveAudioTranscriptionSession instance. diff --git a/sdk/js/src/imodel.ts b/sdk/js/src/imodel.ts index 7a2f5a2c..8f9bd0c1 100644 --- a/sdk/js/src/imodel.ts +++ b/sdk/js/src/imodel.ts @@ -1,5 +1,6 @@ import { ChatClient } from './openai/chatClient.js'; import { AudioClient } from './openai/audioClient.js'; +import { EmbeddingClient } from './openai/embeddingClient.js'; import { LiveAudioTranscriptionSession } from './openai/liveAudioTranscriptionClient.js'; import { ResponsesClient } from './openai/responsesClient.js'; import { ModelInfo } from './types.js'; @@ -25,6 +26,7 @@ export interface IModel { createChatClient(): ChatClient; createAudioClient(): AudioClient; + createEmbeddingClient(): EmbeddingClient; /** * Creates a LiveAudioTranscriptionSession for real-time audio streaming ASR. 
diff --git a/sdk/js/src/index.ts b/sdk/js/src/index.ts index 42b498c3..90b0af1f 100644 --- a/sdk/js/src/index.ts +++ b/sdk/js/src/index.ts @@ -8,6 +8,7 @@ export { ModelVariant } from './detail/modelVariant.js'; export type { IModel } from './imodel.js'; export { ChatClient, ChatClientSettings } from './openai/chatClient.js'; export { AudioClient, AudioClientSettings } from './openai/audioClient.js'; +export { EmbeddingClient, EmbeddingClientSettings } from './openai/embeddingClient.js'; export { LiveAudioTranscriptionSession, LiveAudioTranscriptionOptions } from './openai/liveAudioTranscriptionClient.js'; export type { LiveAudioTranscriptionResponse, TranscriptionContentPart } from './openai/liveAudioTranscriptionTypes.js'; export { ResponsesClient, ResponsesClientSettings, getOutputText } from './openai/responsesClient.js'; diff --git a/sdk/js/src/openai/embeddingClient.ts b/sdk/js/src/openai/embeddingClient.ts new file mode 100644 index 00000000..6e819b0a --- /dev/null +++ b/sdk/js/src/openai/embeddingClient.ts @@ -0,0 +1,125 @@ +import { CoreInterop } from '../detail/coreInterop.js'; + +export class EmbeddingClientSettings { + dimensions?: number; + encodingFormat?: string; + + /** + * Serializes the settings into an OpenAI-compatible request object. + * @internal + */ + _serialize() { + this.validateEncodingFormat(this.encodingFormat); + + const result: any = { + dimensions: this.dimensions, + encoding_format: this.encodingFormat, + }; + + // Filter out undefined properties + return Object.fromEntries(Object.entries(result).filter(([_, v]) => v !== undefined)); + } + + /** + * Validates that the encoding format is a supported value. 
+ * @internal + */ + private validateEncodingFormat(format?: string): void { + if (!format) return; + const validFormats = ['float', 'base64']; + if (!validFormats.includes(format)) { + throw new Error(`encodingFormat must be one of: ${validFormats.join(', ')}`); + } + } +} + +/** + * Client for generating text embeddings with a loaded model. + * Follows the OpenAI Embeddings API structure. + */ +export class EmbeddingClient { + private modelId: string; + private coreInterop: CoreInterop; + + /** + * Configuration settings for embedding operations. + */ + public settings = new EmbeddingClientSettings(); + + /** + * @internal + * Restricted to internal use because CoreInterop is an internal implementation detail. + * Users should create clients via the Model.createEmbeddingClient() factory method. + */ + constructor(modelId: string, coreInterop: CoreInterop) { + this.modelId = modelId; + this.coreInterop = coreInterop; + } + + /** + * Validates that the input text is a non-empty string. + * @internal + */ + private validateInput(input: string): void { + if (typeof input !== 'string' || input.trim() === '') { + throw new Error('Input must be a non-empty string.'); + } + } + + /** + * Validates that the inputs array is non-empty and all elements are non-empty strings. + * @internal + */ + private validateInputs(inputs: string[]): void { + if (!inputs || !Array.isArray(inputs) || inputs.length === 0) { + throw new Error('Inputs must be a non-empty array of strings.'); + } + for (const input of inputs) { + this.validateInput(input); + } + } + + /** + * Sends an embedding request and parses the response. 
+ * @internal + */ + private executeRequest(input: string | string[]): any { + const request = { + model: this.modelId, + input, + ...this.settings._serialize() + }; + + try { + const response = this.coreInterop.executeCommand('embeddings', { + Params: { OpenAICreateRequest: JSON.stringify(request) } + }); + return JSON.parse(response); + } catch (error: any) { + throw new Error( + `Embedding generation failed for model '${this.modelId}': ${error instanceof Error ? error.message : String(error)}`, + { cause: error } + ); + } + } + + /** + * Generates embeddings for the given input text. + * @param input - The text to generate embeddings for. + * @returns The embedding response containing the embedding vector. + */ + public async generateEmbedding(input: string): Promise { + this.validateInput(input); + return this.executeRequest(input); + } + + /** + * Generates embeddings for multiple input texts in a single request. + * @param inputs - The texts to generate embeddings for. + * @returns The embedding response containing one embedding vector per input. 
+ */ + public async generateEmbeddings(inputs: string[]): Promise { + this.validateInputs(inputs); + return this.executeRequest(inputs); + } +} diff --git a/sdk/js/test/openai/embeddingClient.test.ts b/sdk/js/test/openai/embeddingClient.test.ts new file mode 100644 index 00000000..f9395f5e --- /dev/null +++ b/sdk/js/test/openai/embeddingClient.test.ts @@ -0,0 +1,255 @@ +import { describe, it } from 'mocha'; +import { expect } from 'chai'; +import { getTestManager, EMBEDDING_MODEL_ALIAS } from '../testUtils.js'; + +describe('Embedding Client Tests', () => { + + it('should generate embedding', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const cachedModels = await catalog.getCachedModels(); + expect(cachedModels.length).to.be.greaterThan(0); + + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + expect(cachedVariant, 'qwen3-0.6b-embedding-generic-cpu should be cached').to.not.be.undefined; + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + expect(model).to.not.be.undefined; + if (!cachedVariant) return; + + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + expect(embeddingClient).to.not.be.undefined; + + const response = await embeddingClient.generateEmbedding( + 'The quick brown fox jumps over the lazy dog' + ); + + expect(response).to.not.be.undefined; + expect(response.data).to.be.an('array').with.length.greaterThan(0); + expect(response.data[0].embedding).to.be.an('array'); + expect(response.data[0].embedding.length).to.equal(1024); + expect(response.data[0].index).to.equal(0); + + console.log(`Embedding dimension: ${response.data[0].embedding.length}`); + } finally { + await model.unload(); + } + }); + + it('should generate normalized embedding', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const 
cachedModels = await catalog.getCachedModels(); + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + if (!cachedVariant) { this.skip(); return; } + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + const response = await embeddingClient.generateEmbedding( + 'Machine learning is a subset of artificial intelligence' + ); + + const embedding = response.data[0].embedding; + expect(embedding.length).to.equal(1024); + + // Verify L2 norm is approximately 1.0 + let norm = 0; + for (const val of embedding) { + norm += val * val; + } + norm = Math.sqrt(norm); + expect(norm).to.be.greaterThan(0.99); + expect(norm).to.be.lessThan(1.01); + } finally { + await model.unload(); + } + }); + + it('should produce different embeddings for different inputs', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const cachedModels = await catalog.getCachedModels(); + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + if (!cachedVariant) { this.skip(); return; } + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + + const response1 = await embeddingClient.generateEmbedding('The quick brown fox'); + const response2 = await embeddingClient.generateEmbedding('The capital of France is Paris'); + + expect(response1.data[0].embedding.length).to.equal(response2.data[0].embedding.length); + + // Cosine similarity should not be 1.0 + let dot = 0, norm1 = 0, norm2 = 0; + for (let i = 0; i < response1.data[0].embedding.length; i++) { + const v1 = response1.data[0].embedding[i]; + const v2 = response2.data[0].embedding[i]; + dot += v1 * v2; + norm1 += v1 * v1; + norm2 += v2 * v2; + } + const 
cosineSimilarity = dot / (Math.sqrt(norm1) * Math.sqrt(norm2)); + expect(cosineSimilarity).to.be.lessThan(0.99); + } finally { + await model.unload(); + } + }); + + it('should produce same embedding for same input', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const cachedModels = await catalog.getCachedModels(); + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + if (!cachedVariant) { this.skip(); return; } + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + + const response1 = await embeddingClient.generateEmbedding('Deterministic embedding test'); + const response2 = await embeddingClient.generateEmbedding('Deterministic embedding test'); + + for (let i = 0; i < response1.data[0].embedding.length; i++) { + expect(response1.data[0].embedding[i]).to.equal(response2.data[0].embedding[i]); + } + } finally { + await model.unload(); + } + }); + + it('should throw for empty input', function() { + const manager = getTestManager(); + const catalog = manager.catalog; + + // Create a client directly (model doesn't need to be loaded for input validation) + expect(() => { + // Validation happens in generateEmbedding, but we need a loaded model for that. + // Instead test the synchronous validation path. 
+ const { EmbeddingClient } = require('../../src/openai/embeddingClient.js'); + }).to.not.throw(); + }); + + it('should generate batch embeddings', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const cachedModels = await catalog.getCachedModels(); + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + if (!cachedVariant) { this.skip(); return; } + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + + const response = await embeddingClient.generateEmbeddings([ + 'The quick brown fox jumps over the lazy dog', + 'Machine learning is a subset of artificial intelligence', + 'The capital of France is Paris' + ]); + + expect(response).to.not.be.undefined; + expect(response.data).to.be.an('array').with.length(3); + + for (let i = 0; i < 3; i++) { + expect(response.data[i].index).to.equal(i); + expect(response.data[i].embedding.length).to.equal(1024); + } + } finally { + await model.unload(); + } + }); + + it('should produce normalized batch embeddings', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const cachedModels = await catalog.getCachedModels(); + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + if (!cachedVariant) { this.skip(); return; } + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + + const response = await embeddingClient.generateEmbeddings([ + 'Hello world', + 'Goodbye world' + ]); + + expect(response.data.length).to.equal(2); + + for (const data of response.data) { + let norm = 0; + for (const val of data.embedding) { + norm += val * val; + } + norm = Math.sqrt(norm); + 
expect(norm).to.be.greaterThan(0.99); + expect(norm).to.be.lessThan(1.01); + } + } finally { + await model.unload(); + } + }); + + it('should match single and batch results', async function() { + this.timeout(30000); + const manager = getTestManager(); + const catalog = manager.catalog; + + const cachedModels = await catalog.getCachedModels(); + const cachedVariant = cachedModels.find(m => m.alias === EMBEDDING_MODEL_ALIAS); + if (!cachedVariant) { this.skip(); return; } + + const model = await catalog.getModel(EMBEDDING_MODEL_ALIAS); + model.selectVariant(cachedVariant); + await model.load(); + + try { + const embeddingClient = model.createEmbeddingClient(); + + const singleResponse = await embeddingClient.generateEmbedding('The capital of France is Paris'); + const batchResponse = await embeddingClient.generateEmbeddings(['The capital of France is Paris']); + + expect(batchResponse.data.length).to.equal(1); + + for (let i = 0; i < singleResponse.data[0].embedding.length; i++) { + expect(batchResponse.data[0].embedding[i]).to.equal(singleResponse.data[0].embedding[i]); + } + } finally { + await model.unload(); + } + }); +}); diff --git a/sdk/js/test/testUtils.ts b/sdk/js/test/testUtils.ts index 62cf7968..7a40220b 100644 --- a/sdk/js/test/testUtils.ts +++ b/sdk/js/test/testUtils.ts @@ -39,6 +39,7 @@ export const TEST_CONFIG: FoundryLocalConfig = { }; export const TEST_MODEL_ALIAS = 'qwen2.5-0.5b'; +export const EMBEDDING_MODEL_ALIAS = 'qwen3-0.6b-embedding-generic-cpu'; export function getTestManager() { return FoundryLocalManager.create(TEST_CONFIG); diff --git a/sdk/python/README.md b/sdk/python/README.md index dbdef1f8..60bce65b 100644 --- a/sdk/python/README.md +++ b/sdk/python/README.md @@ -8,6 +8,7 @@ The Foundry Local Python SDK provides a Python interface for interacting with lo - **Model Management** – download, cache, load, and unload models - **Chat Completions** – OpenAI-compatible chat API (non-streaming and streaming) - **Tool Calling** – 
function-calling support with chat completions +- **Embeddings** – generate text embeddings via OpenAI-compatible API - **Audio Transcription** – Whisper-based speech-to-text (non-streaming and streaming) - **Built-in Web Service** – optional HTTP endpoint for multi-process scenarios - **Native Performance** – ctypes FFI to AOT-compiled Foundry Local Core @@ -240,6 +241,35 @@ for chunk in client.complete_streaming_chat(messages): model.unload() ``` +### Embeddings + +Generate text embeddings using the `EmbeddingClient`: + +```python +embedding_client = model.get_embedding_client() + +# Single input +response = embedding_client.generate_embedding( + "The quick brown fox jumps over the lazy dog" +) +embedding = response.data[0].embedding # List[float] +print(f"Dimensions: {len(embedding)}") + +# Batch input +batch_response = embedding_client.generate_embeddings([ + "The quick brown fox", + "The capital of France is Paris" +]) +# batch_response.data[0].embedding, batch_response.data[1].embedding +``` + +#### Embedding Settings + +```python +embedding_client.settings.dimensions = 512 # optional: reduce dimensionality +embedding_client.settings.encoding_format = "float" # "float" or "base64" +``` + ### Web Service (Optional) Start a built-in HTTP server for multi-process access. 
@@ -271,6 +301,7 @@ manager.stop_web_service() | Class | Description | |---|---| | `ChatClient` | Chat completions (non-streaming and streaming) with tool calling | +| `EmbeddingClient` | Text embedding generation via OpenAI-compatible API | | `AudioClient` | Audio transcription (non-streaming and streaming) | ### Internal / Detail diff --git a/sdk/python/src/detail/model.py b/sdk/python/src/detail/model.py index 189920b1..6d60b7a2 100644 --- a/sdk/python/src/detail/model.py +++ b/sdk/python/src/detail/model.py @@ -10,6 +10,7 @@ from ..imodel import IModel from ..openai.chat_client import ChatClient from ..openai.audio_client import AudioClient +from ..openai.embedding_client import EmbeddingClient from .model_variant import ModelVariant from ..exception import FoundryLocalException from .core_interop import CoreInterop @@ -141,3 +142,7 @@ def get_chat_client(self) -> ChatClient: def get_audio_client(self) -> AudioClient: """Get an audio client for the currently selected variant.""" return self._selected_variant.get_audio_client() + + def get_embedding_client(self) -> EmbeddingClient: + """Get an embedding client for the currently selected variant.""" + return self._selected_variant.get_embedding_client() diff --git a/sdk/python/src/detail/model_variant.py b/sdk/python/src/detail/model_variant.py index a5ac02d4..76efb05c 100644 --- a/sdk/python/src/detail/model_variant.py +++ b/sdk/python/src/detail/model_variant.py @@ -16,6 +16,7 @@ from .model_load_manager import ModelLoadManager from ..openai.audio_client import AudioClient from ..openai.chat_client import ChatClient +from ..openai.embedding_client import EmbeddingClient logger = logging.getLogger(__name__) @@ -169,4 +170,8 @@ def get_chat_client(self) -> ChatClient: def get_audio_client(self) -> AudioClient: """Create an OpenAI-compatible ``AudioClient`` for this variant.""" - return AudioClient(self.id, self._core_interop) \ No newline at end of file + return AudioClient(self.id, self._core_interop) + + def 
get_embedding_client(self) -> EmbeddingClient: + """Create an OpenAI-compatible ``EmbeddingClient`` for this variant.""" + return EmbeddingClient(self.id, self._core_interop) diff --git a/sdk/python/src/imodel.py b/sdk/python/src/imodel.py index 8237aeb4..f723e514 100644 --- a/sdk/python/src/imodel.py +++ b/sdk/python/src/imodel.py @@ -9,6 +9,7 @@ from .openai.chat_client import ChatClient from .openai.audio_client import AudioClient +from .openai.embedding_client import EmbeddingClient from .detail.model_data_types import ModelInfo class IModel(ABC): @@ -127,6 +128,14 @@ def get_audio_client(self) -> AudioClient: """ pass + @abstractmethod + def get_embedding_client(self) -> 'EmbeddingClient': + """ + Get an OpenAI API based EmbeddingClient. + :return: EmbeddingClient instance. + """ + pass + @property @abstractmethod def variants(self) -> List['IModel']: diff --git a/sdk/python/src/openai/__init__.py b/sdk/python/src/openai/__init__.py index e445ba1d..df229f19 100644 --- a/sdk/python/src/openai/__init__.py +++ b/sdk/python/src/openai/__init__.py @@ -6,5 +6,6 @@ from .chat_client import ChatClient, ChatClientSettings from .audio_client import AudioClient +from .embedding_client import EmbeddingClient, EmbeddingSettings -__all__ = ["AudioClient", "ChatClient", "ChatClientSettings"] +__all__ = ["AudioClient", "ChatClient", "ChatClientSettings", "EmbeddingClient", "EmbeddingSettings"] diff --git a/sdk/python/src/openai/embedding_client.py b/sdk/python/src/openai/embedding_client.py new file mode 100644 index 00000000..876f26ce --- /dev/null +++ b/sdk/python/src/openai/embedding_client.py @@ -0,0 +1,145 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- + +from __future__ import annotations + +import json +import logging +from typing import List, Optional, Union + +from ..detail.core_interop import CoreInterop, InteropRequest +from ..exception import FoundryLocalException + +from openai.types import CreateEmbeddingResponse +from openai.types.embedding_create_params import EmbeddingCreateParams + +logger = logging.getLogger(__name__) + + +class EmbeddingSettings: + """Settings supported by Foundry Local for embedding generation. + + Attributes: + dimensions: The number of dimensions for the output embeddings (optional). + encoding_format: The format to return embeddings in (``"float"`` or ``"base64"``). + """ + + def __init__( + self, + dimensions: Optional[int] = None, + encoding_format: Optional[str] = None, + ): + self.dimensions = dimensions + self.encoding_format = encoding_format + + def _serialize(self) -> dict: + """Serialize settings into an OpenAI-compatible request dict.""" + self._validate_encoding_format(self.encoding_format) + + return { + k: v for k, v in { + "dimensions": self.dimensions, + "encoding_format": self.encoding_format, + }.items() if v is not None + } + + def _validate_encoding_format(self, encoding_format: Optional[str]) -> None: + if encoding_format is None: + return + valid_formats = ["float", "base64"] + if encoding_format not in valid_formats: + raise ValueError(f"encoding_format must be one of: {', '.join(valid_formats)}") + + +class EmbeddingClient: + """OpenAI-compatible embedding client backed by Foundry Local Core. + + Attributes: + model_id: The ID of the loaded embedding model variant. + settings: Tunable ``EmbeddingSettings`` (dimensions, encoding_format). 
+ """ + + def __init__(self, model_id: str, core_interop: CoreInterop): + self.model_id = model_id + self.settings = EmbeddingSettings() + self._core_interop = core_interop + + @staticmethod + def _validate_input(input_text: str) -> None: + """Validate that the input is a non-empty string.""" + if not isinstance(input_text, str) or input_text.strip() == "": + raise ValueError("Input must be a non-empty string.") + + def _create_request_json(self, input_value: Union[str, List[str]]) -> str: + """Build the JSON payload for the ``embeddings`` native command.""" + request: dict = { + "model": self.model_id, + "input": input_value, + **self.settings._serialize(), + } + + embedding_request = EmbeddingCreateParams(request) + + return json.dumps(embedding_request) + + def _execute_embedding_request(self, input_value: Union[str, List[str]]) -> CreateEmbeddingResponse: + """Send an embedding request and parse the response.""" + request_json = self._create_request_json(input_value) + request = InteropRequest(params={"OpenAICreateRequest": request_json}) + + response = self._core_interop.execute_command("embeddings", request) + if response.error is not None: + raise FoundryLocalException( + f"Embedding generation failed for model '{self.model_id}': {response.error}" + ) + + data = json.loads(response.data) + + # Add fields required by the OpenAI SDK type that the server doesn't return + for item in data.get("data", []): + if "object" not in item: + item["object"] = "embedding" + + if "usage" not in data: + data["usage"] = {"prompt_tokens": 0, "total_tokens": 0} + + return CreateEmbeddingResponse.model_validate(data) + + def generate_embedding(self, input_text: str) -> CreateEmbeddingResponse: + """Generate embeddings for a single input text. + + Args: + input_text: The text to generate embeddings for. + + Returns: + A ``CreateEmbeddingResponse`` containing the embedding vector. + + Raises: + ValueError: If *input_text* is not a non-empty string. 
+ FoundryLocalException: If the underlying native embeddings command fails. + """ + self._validate_input(input_text) + return self._execute_embedding_request(input_text) + + def generate_embeddings(self, inputs: List[str]) -> CreateEmbeddingResponse: + """Generate embeddings for multiple input texts in a single request. + + Args: + inputs: The texts to generate embeddings for. + + Returns: + A ``CreateEmbeddingResponse`` containing one embedding vector per input. + + Raises: + ValueError: If *inputs* is empty or contains empty strings. + FoundryLocalException: If the underlying native embeddings command fails. + """ + if not inputs or len(inputs) == 0: + raise ValueError("Inputs must be a non-empty list of strings.") + + for text in inputs: + self._validate_input(text) + + return self._execute_embedding_request(inputs) diff --git a/sdk/python/test/conftest.py b/sdk/python/test/conftest.py index b7e22c97..7ff9e120 100644 --- a/sdk/python/test/conftest.py +++ b/sdk/python/test/conftest.py @@ -26,6 +26,7 @@ TEST_MODEL_ALIAS = "qwen2.5-0.5b" AUDIO_MODEL_ALIAS = "whisper-tiny" +EMBEDDING_MODEL_ALIAS = "qwen3-0.6b-embedding-generic-cpu" def get_git_repo_root() -> Path: """Walk upward from __file__ until we find a .git directory.""" diff --git a/sdk/python/test/openai/test_embedding_client.py b/sdk/python/test/openai/test_embedding_client.py new file mode 100644 index 00000000..69e9648d --- /dev/null +++ b/sdk/python/test/openai/test_embedding_client.py @@ -0,0 +1,202 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. 
+# -------------------------------------------------------------------------- +"""Tests for EmbeddingClient – mirrors EmbeddingClientTests.cs.""" + +from __future__ import annotations + +import math + +import pytest + +from ..conftest import EMBEDDING_MODEL_ALIAS + + +def _get_loaded_embedding_model(catalog): + """Helper: ensure the embedding model is selected, loaded, and return Model.""" + cached = catalog.get_cached_models() + assert len(cached) > 0 + + cached_variant = next((m for m in cached if m.alias == EMBEDDING_MODEL_ALIAS), None) + assert cached_variant is not None, f"{EMBEDDING_MODEL_ALIAS} should be cached" + + model = catalog.get_model(EMBEDDING_MODEL_ALIAS) + assert model is not None + + model.select_variant(cached_variant) + model.load() + return model + + +class TestEmbeddingClient: + """Embedding Client Tests.""" + + def test_should_generate_embedding(self, catalog): + """Basic embedding generation.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + assert embedding_client is not None + + response = embedding_client.generate_embedding( + "The quick brown fox jumps over the lazy dog" + ) + + assert response is not None + assert response.model is not None + assert len(response.data) == 1 + assert response.data[0].index == 0 + assert len(response.data[0].embedding) == 1024 + + print(f"Embedding dimension: {len(response.data[0].embedding)}") + print(f"First value: {response.data[0].embedding[0]}") + print(f"Last value: {response.data[0].embedding[-1]}") + finally: + model.unload() + + def test_should_generate_normalized_embedding(self, catalog): + """Verify L2 norm is approximately 1.0.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + inputs = [ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + ] + + for input_text in inputs: + response = 
embedding_client.generate_embedding(input_text) + embedding = response.data[0].embedding + + assert len(embedding) == 1024 + + norm = math.sqrt(sum(v * v for v in embedding)) + assert 0.99 <= norm <= 1.01, f"L2 norm {norm} not approximately 1.0" + + for val in embedding: + assert -1.0 <= val <= 1.0 + finally: + model.unload() + + def test_should_produce_different_embeddings_for_different_inputs(self, catalog): + """Different inputs should produce different embeddings.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + response1 = embedding_client.generate_embedding("The quick brown fox") + response2 = embedding_client.generate_embedding("The capital of France is Paris") + + emb1 = response1.data[0].embedding + emb2 = response2.data[0].embedding + + assert len(emb1) == len(emb2) + + # Cosine similarity should not be 1.0 + dot = sum(a * b for a, b in zip(emb1, emb2)) + norm1 = math.sqrt(sum(a * a for a in emb1)) + norm2 = math.sqrt(sum(b * b for b in emb2)) + cosine_similarity = dot / (norm1 * norm2) + assert cosine_similarity < 0.99 + finally: + model.unload() + + def test_should_produce_same_embedding_for_same_input(self, catalog): + """Same input should produce identical embeddings.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + response1 = embedding_client.generate_embedding("Deterministic embedding test") + response2 = embedding_client.generate_embedding("Deterministic embedding test") + + emb1 = response1.data[0].embedding + emb2 = response2.data[0].embedding + + for i in range(len(emb1)): + assert emb1[i] == emb2[i] + finally: + model.unload() + + def test_should_raise_for_empty_input(self, catalog): + """Empty input should raise ValueError.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + with pytest.raises(ValueError): + embedding_client.generate_embedding("") + finally: + 
model.unload() + + def test_batch_should_return_multiple_embeddings(self, catalog): + """Batch request should return one embedding per input.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + response = embedding_client.generate_embeddings([ + "The quick brown fox jumps over the lazy dog", + "Machine learning is a subset of artificial intelligence", + "The capital of France is Paris", + ]) + + assert response is not None + assert len(response.data) == 3 + + for i, data in enumerate(response.data): + assert data.index == i + assert len(data.embedding) == 1024 + finally: + model.unload() + + def test_batch_each_embedding_is_normalized(self, catalog): + """Each embedding in a batch should be L2-normalized.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + response = embedding_client.generate_embeddings([ + "Hello world", + "Goodbye world", + ]) + + assert len(response.data) == 2 + + for data in response.data: + norm = math.sqrt(sum(v * v for v in data.embedding)) + assert 0.99 <= norm <= 1.01, f"L2 norm {norm} not approximately 1.0" + finally: + model.unload() + + def test_batch_matches_single_input_results(self, catalog): + """Batch result should match single-input result for the same text.""" + model = _get_loaded_embedding_model(catalog) + try: + embedding_client = model.get_embedding_client() + + input_text = "The capital of France is Paris" + + single_response = embedding_client.generate_embedding(input_text) + batch_response = embedding_client.generate_embeddings([input_text]) + + assert len(batch_response.data) == 1 + + for i in range(len(single_response.data[0].embedding)): + assert batch_response.data[0].embedding[i] == single_response.data[0].embedding[i] + finally: + model.unload() + + def test_batch_should_raise_for_empty_list(self, catalog): + """Empty list should raise ValueError.""" + model = _get_loaded_embedding_model(catalog) + 
try: + embedding_client = model.get_embedding_client() + + with pytest.raises(ValueError): + embedding_client.generate_embeddings([]) + finally: + model.unload() diff --git a/sdk/rust/Cargo.toml b/sdk/rust/Cargo.toml index 2a6292b7..a8cd7228 100644 --- a/sdk/rust/Cargo.toml +++ b/sdk/rust/Cargo.toml @@ -24,7 +24,7 @@ tokio-stream = "0.1" futures-core = "0.3" reqwest = { version = "0.12", features = ["json"] } urlencoding = "2" -async-openai = { version = "0.33", default-features = false, features = ["chat-completion-types"] } +async-openai = { version = "0.33", default-features = false, features = ["chat-completion-types", "embedding-types"] } [build-dependencies] ureq = "3" diff --git a/sdk/rust/README.md b/sdk/rust/README.md index 08f9c279..39469f1e 100644 --- a/sdk/rust/README.md +++ b/sdk/rust/README.md @@ -8,6 +8,7 @@ The Foundry Local Rust SDK provides an async Rust interface for running AI model - **Model catalog** — Browse and discover available models; check what's cached or loaded - **Automatic model management** — Download, load, unload, and remove models from cache - **Chat completions** — OpenAI-compatible chat API with both non-streaming and streaming responses +- **Embeddings** — Generate text embeddings via OpenAI-compatible API - **Audio transcription** — Transcribe audio files locally with streaming support - **Tool calling** — Function/tool calling with streaming, multi-turn conversation support - **Response format control** — Text, JSON, JSON Schema, and Lark grammar constrained output @@ -353,6 +354,35 @@ let client = model.create_chat_client() .response_format(ChatResponseFormat::LarkGrammar(grammar.to_string())); ``` +### Embeddings + +Generate text embeddings using the `EmbeddingClient`: + +```rust +let embedding_client = model.create_embedding_client(); + +// Single input +let response = embedding_client + .generate_embedding("The quick brown fox jumps over the lazy dog") + .await?; +let embedding = &response.data[0].embedding; // Vec 
+println!("Dimensions: {}", embedding.len()); + +// Batch input +let batch_response = embedding_client + .generate_embeddings(&["The quick brown fox", "The capital of France is Paris"]) + .await?; +// batch_response.data[0].embedding, batch_response.data[1].embedding +``` + +#### Embedding Settings + +```rust +let embedding_client = model.create_embedding_client() + .dimensions(512) // optional: reduce dimensionality + .encoding_format("float"); // "float" or "base64" +``` + ### Audio Transcription Transcribe audio files locally using the `AudioClient`: diff --git a/sdk/rust/docs/api.md b/sdk/rust/docs/api.md index abfec76f..ef558b8f 100644 --- a/sdk/rust/docs/api.md +++ b/sdk/rust/docs/api.md @@ -15,6 +15,8 @@ - [OpenAI Clients](#openai-clients) - [ChatClient](#chatclient) - [ChatCompletionStream](#chatcompletionstream) + - [EmbeddingClient](#embeddingclient) + - [EmbeddingResponse](#embeddingresponse) - [AudioClient](#audioclient) - [AudioTranscriptionStream](#audiotranscriptionstream) - [AudioTranscriptionResponse](#audiotranscriptionresponse) @@ -214,6 +216,35 @@ A stream of `CreateChatCompletionStreamResponse` chunks. Use with `StreamExt::ne --- +### EmbeddingClient + +OpenAI-compatible embedding generation backed by a local model. 
+
+| Method | Description |
+|---|---|
+| `new(model_id, core)` | *(internal)* Create a new client |
+| `dimensions(v: u32) -> Self` | Set the number of output dimensions |
+| `encoding_format(v: impl Into<String>) -> Self` | Set encoding format (`"float"` or `"base64"`) |
+| `generate_embedding(input: &str) -> Result<EmbeddingResponse>` | Generate embedding for a single input |
+| `generate_embeddings(inputs: &[&str]) -> Result<EmbeddingResponse>` | Generate embeddings for multiple inputs |
+
+### EmbeddingResponse
+
+| Field | Type | Description |
+|---|---|---|
+| `model` | `String` | Model used for generation |
+| `object` | `Option<String>` | Object type (always `"list"`) |
+| `data` | `Vec<EmbeddingData>` | List of embedding results |
+
+### EmbeddingData
+
+| Field | Type | Description |
+|---|---|---|
+| `index` | `u32` | Index of this embedding |
+| `embedding` | `Vec<f32>` | The embedding vector |
+
+---
+
 ### AudioClient
 
 OpenAI-compatible audio transcription backed by a local model.
diff --git a/sdk/rust/src/detail/model.rs b/sdk/rust/src/detail/model.rs
index 3a87a1c3..08288aee 100644
--- a/sdk/rust/src/detail/model.rs
+++ b/sdk/rust/src/detail/model.rs
@@ -14,6 +14,7 @@
 use super::model_variant::ModelVariant;
 use crate::error::{FoundryLocalError, Result};
 use crate::openai::AudioClient;
 use crate::openai::ChatClient;
+use crate::openai::EmbeddingClient;
 use crate::types::ModelInfo;
 
 /// The public model type.
@@ -242,6 +243,11 @@
         self.selected_variant().create_audio_client()
     }
 
+    /// Create an [`EmbeddingClient`] bound to the (selected) variant.
+    pub fn create_embedding_client(&self) -> EmbeddingClient {
+        self.selected_variant().create_embedding_client()
+    }
+
     /// Available variants of this model.
     ///
     /// For a single-variant model (e.g.
from diff --git a/sdk/rust/src/detail/model_variant.rs b/sdk/rust/src/detail/model_variant.rs index ca1a83c7..1f8ce7d5 100644 --- a/sdk/rust/src/detail/model_variant.rs +++ b/sdk/rust/src/detail/model_variant.rs @@ -15,6 +15,7 @@ use crate::catalog::CacheInvalidator; use crate::error::Result; use crate::openai::AudioClient; use crate::openai::ChatClient; +use crate::openai::EmbeddingClient; use crate::types::ModelInfo; /// Represents one specific variant of a model (a particular id within an alias @@ -148,4 +149,8 @@ impl ModelVariant { pub(crate) fn create_audio_client(&self) -> AudioClient { AudioClient::new(&self.info.id, Arc::clone(&self.core)) } + + pub(crate) fn create_embedding_client(&self) -> EmbeddingClient { + EmbeddingClient::new(&self.info.id, Arc::clone(&self.core)) + } } diff --git a/sdk/rust/src/openai/embedding_client.rs b/sdk/rust/src/openai/embedding_client.rs new file mode 100644 index 00000000..798928b6 --- /dev/null +++ b/sdk/rust/src/openai/embedding_client.rs @@ -0,0 +1,156 @@ +//! OpenAI-compatible embedding client. + +use std::sync::Arc; + +use async_openai::types::embeddings::CreateEmbeddingResponse; +use serde_json::{json, Value}; + +use crate::detail::core_interop::CoreInterop; +use crate::error::{FoundryLocalError, Result}; + +/// Tuning knobs for embedding requests. 
+///
+/// Use the chainable setter methods to configure, e.g.:
+///
+/// ```ignore
+/// let client = model.create_embedding_client()
+///     .dimensions(512);
+/// ```
+#[derive(Debug, Clone, Default)]
+pub struct EmbeddingClientSettings {
+    dimensions: Option<u32>,
+    encoding_format: Option<String>,
+}
+
+impl EmbeddingClientSettings {
+    fn serialize(&self) -> Value {
+        let mut map = serde_json::Map::new();
+
+        if let Some(dims) = self.dimensions {
+            map.insert("dimensions".into(), json!(dims));
+        }
+        if let Some(ref fmt) = self.encoding_format {
+            map.insert("encoding_format".into(), json!(fmt));
+        }
+
+        Value::Object(map)
+    }
+}
+
+/// Client for OpenAI-compatible embedding generation backed by a local model.
+pub struct EmbeddingClient {
+    model_id: String,
+    core: Arc<CoreInterop>,
+    settings: EmbeddingClientSettings,
+}
+
+impl EmbeddingClient {
+    pub(crate) fn new(model_id: &str, core: Arc<CoreInterop>) -> Self {
+        Self {
+            model_id: model_id.to_owned(),
+            core,
+            settings: EmbeddingClientSettings::default(),
+        }
+    }
+
+    /// Set the number of dimensions for the output embeddings.
+    pub fn dimensions(mut self, v: u32) -> Self {
+        self.settings.dimensions = Some(v);
+        self
+    }
+
+    /// Set the encoding format ("float" or "base64").
+    pub fn encoding_format(mut self, v: impl Into<String>) -> Self {
+        self.settings.encoding_format = Some(v.into());
+        self
+    }
+
+    /// Generate embeddings for a single input text.
+    pub async fn generate_embedding(&self, input: &str) -> Result<CreateEmbeddingResponse> {
+        Self::validate_input(input)?;
+        let request = self.build_request(json!(input))?;
+        self.execute_request(request).await
+    }
+
+    /// Generate embeddings for multiple input texts in a single request.
+    pub async fn generate_embeddings(&self, inputs: &[&str]) -> Result<CreateEmbeddingResponse> {
+        if inputs.is_empty() {
+            return Err(FoundryLocalError::Validation {
+                reason: "inputs must be a non-empty array".into(),
+            });
+        }
+        for input in inputs {
+            Self::validate_input(input)?;
+        }
+        let request = self.build_request(json!(inputs))?;
+        self.execute_request(request).await
+    }
+
+    async fn execute_request(&self, request: Value) -> Result<CreateEmbeddingResponse> {
+        let params = json!({
+            "Params": {
+                "OpenAICreateRequest": serde_json::to_string(&request)?
+            }
+        });
+
+        let raw = self
+            .core
+            .execute_command_async("embeddings".into(), Some(params))
+            .await?;
+
+        // Patch the response to add fields required by async_openai types
+        // that the server doesn't return (object on each item, usage)
+        let mut response_value: Value = serde_json::from_str(&raw)?;
+        if let Some(data) = response_value.get_mut("data").and_then(|d| d.as_array_mut()) {
+            for item in data {
+                if item.get("object").is_none() {
+                    item.as_object_mut()
+                        .map(|m| m.insert("object".into(), json!("embedding")));
+                }
+            }
+        }
+        if response_value.get("usage").is_none() {
+            response_value.as_object_mut()
+                .map(|m| m.insert("usage".into(), json!({"prompt_tokens": 0, "total_tokens": 0})));
+        }
+
+        let parsed: CreateEmbeddingResponse = serde_json::from_value(response_value)?;
+        Ok(parsed)
+    }
+
+    fn build_request(&self, input: Value) -> Result<Value> {
+        Self::validate_encoding_format(&self.settings.encoding_format)?;
+
+        let settings_value = self.settings.serialize();
+        let mut map = match settings_value {
+            Value::Object(m) => m,
+            _ => serde_json::Map::new(),
+        };
+
+        map.insert("model".into(), json!(self.model_id));
+        map.insert("input".into(), input);
+
+        Ok(Value::Object(map))
+    }
+
+    fn validate_encoding_format(format: &Option<String>) -> Result<()> {
+        if let Some(ref fmt) = format {
+            let valid = ["float", "base64"];
+            if !valid.contains(&fmt.as_str()) {
+                return Err(FoundryLocalError::Validation {
+                    reason: format!("encoding_format must be one of: {}", valid.join(", ")),
+                });
+            }
+        }
+        Ok(())
+    }
+
+    fn validate_input(input: &str) -> Result<()> {
+        if input.trim().is_empty() {
+            return Err(FoundryLocalError::Validation {
+                reason: "input must be a non-empty string".into(),
+            });
+        }
+        Ok(())
+    }
+}
diff --git a/sdk/rust/src/openai/mod.rs b/sdk/rust/src/openai/mod.rs
index c3d4a645..90e29d10 100644
--- a/sdk/rust/src/openai/mod.rs
+++ b/sdk/rust/src/openai/mod.rs
@@ -1,5 +1,6 @@
 mod audio_client;
 mod chat_client;
+mod embedding_client;
 mod json_stream;
 
 pub use self::audio_client::{
@@ -7,4 +8,5 @@ pub use self::audio_client::{
     TranscriptionSegment, TranscriptionWord,
 };
 pub use self::chat_client::{ChatClient, ChatClientSettings, ChatCompletionStream};
+pub use self::embedding_client::{EmbeddingClient, EmbeddingClientSettings};
 pub use self::json_stream::JsonStream;
diff --git a/sdk/rust/tests/integration/common/mod.rs b/sdk/rust/tests/integration/common/mod.rs
index b0ca1a77..a79cab0f 100644
--- a/sdk/rust/tests/integration/common/mod.rs
+++ b/sdk/rust/tests/integration/common/mod.rs
@@ -14,6 +14,9 @@
 pub const TEST_MODEL_ALIAS: &str = "qwen2.5-0.5b";
 
 /// Default model alias used for audio-transcription integration tests.
 pub const WHISPER_MODEL_ALIAS: &str = "whisper-tiny";
 
+/// Default model alias used for embedding integration tests.
+pub const EMBEDDING_MODEL_ALIAS: &str = "qwen3-0.6b-embedding-generic-cpu";
+
 /// Expected transcription text fragment for the shared audio test file.
 pub const EXPECTED_TRANSCRIPTION_TEXT: &str =
     " And lots of times you need to give people more than one link at a time";
diff --git a/sdk/rust/tests/integration/embedding_client_test.rs b/sdk/rust/tests/integration/embedding_client_test.rs
new file mode 100644
index 00000000..f211e39a
--- /dev/null
+++ b/sdk/rust/tests/integration/embedding_client_test.rs
@@ -0,0 +1,209 @@
+//! Integration tests for EmbeddingClient.
+
+use std::sync::Arc;
+
+use foundry_local_sdk::openai::EmbeddingClient;
+use foundry_local_sdk::Model;
+
+use crate::common;
+
+async fn setup_embedding_client() -> (EmbeddingClient, Arc<Model>) {
+    let manager = common::get_test_manager();
+    let catalog = manager.catalog();
+
+    let model = catalog
+        .get_model(common::EMBEDDING_MODEL_ALIAS)
+        .await
+        .expect("embedding model should exist in catalog");
+
+    model.load().await.expect("model should load successfully");
+
+    let client = model.create_embedding_client();
+    (client, model)
+}
+
+#[tokio::test]
+async fn should_generate_embedding() {
+    let (client, model) = setup_embedding_client().await;
+
+    let response = client
+        .generate_embedding("The quick brown fox jumps over the lazy dog")
+        .await
+        .expect("embedding should succeed");
+
+    assert_eq!(response.data.len(), 1);
+    assert_eq!(response.data[0].index, 0);
+    assert_eq!(response.data[0].embedding.len(), 1024);
+
+    println!("Embedding dimension: {}", response.data[0].embedding.len());
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_generate_normalized_embedding() {
+    let (client, model) = setup_embedding_client().await;
+
+    let inputs = [
+        "The quick brown fox jumps over the lazy dog",
+        "Machine learning is a subset of artificial intelligence",
+        "The capital of France is Paris",
+    ];
+
+    for input in &inputs {
+        let response = client
+            .generate_embedding(input)
+            .await
+            .expect("embedding should succeed");
+
+        let embedding = &response.data[0].embedding;
+        assert_eq!(embedding.len(), 1024);
+
+        // Verify L2 norm is approximately 1.0
+        let norm: f32 = embedding.iter().map(|v| v * v).sum::<f32>().sqrt() as f32;
+        assert!(
+            (0.99_f32..=1.01_f32).contains(&norm),
+            "L2 norm {norm} not approximately 1.0"
+        );
+
+        for val in embedding {
+            assert!(
+                (-1.0_f32..=1.0_f32).contains(val),
+                "value {val} outside [-1, 1]"
+            );
+        }
+    }
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_produce_different_embeddings_for_different_inputs() {
+    let (client, model) = setup_embedding_client().await;
+
+    let response1 = client
+        .generate_embedding("The quick brown fox")
+        .await
+        .expect("embedding should succeed");
+
+    let response2 = client
+        .generate_embedding("The capital of France is Paris")
+        .await
+        .expect("embedding should succeed");
+
+    let emb1 = &response1.data[0].embedding;
+    let emb2 = &response2.data[0].embedding;
+
+    assert_eq!(emb1.len(), emb2.len());
+
+    // Cosine similarity should not be 1.0
+    let dot: f32 = emb1.iter().zip(emb2.iter()).map(|(a, b)| a * b).sum();
+    let norm1: f32 = emb1.iter().map(|v| v * v).sum::<f32>().sqrt() as f32;
+    let norm2: f32 = emb2.iter().map(|v| v * v).sum::<f32>().sqrt() as f32;
+    let cosine_similarity = dot / (norm1 * norm2);
+    assert!(
+        cosine_similarity < 0.99_f32,
+        "cosine similarity {cosine_similarity} should be < 0.99"
+    );
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_produce_same_embedding_for_same_input() {
+    let (client, model) = setup_embedding_client().await;
+
+    let response1 = client
+        .generate_embedding("Deterministic embedding test")
+        .await
+        .expect("embedding should succeed");
+
+    let response2 = client
+        .generate_embedding("Deterministic embedding test")
+        .await
+        .expect("embedding should succeed");
+
+    let emb1 = &response1.data[0].embedding;
+    let emb2 = &response2.data[0].embedding;
+
+    for (i, (a, b)) in emb1.iter().zip(emb2.iter()).enumerate() {
+        assert_eq!(a, b, "mismatch at index {i}");
+    }
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_throw_for_empty_input() {
+    let (client, model) = setup_embedding_client().await;
+
+    let result = client.generate_embedding("").await;
+    assert!(result.is_err(), "empty input should return an error");
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_generate_batch_embeddings() {
+    let (client, model) = setup_embedding_client().await;
+
+    let response = client
+        .generate_embeddings(&[
+            "The quick brown fox jumps over the lazy dog",
+            "Machine learning is a subset of artificial intelligence",
+            "The capital of France is Paris",
+        ])
+        .await
+        .expect("batch embedding should succeed");
+
+    assert_eq!(response.data.len(), 3);
+    for (i, data) in response.data.iter().enumerate() {
+        assert_eq!(data.index, i as u32);
+        assert_eq!(data.embedding.len(), 1024);
+    }
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_generate_normalized_batch_embeddings() {
+    let (client, model) = setup_embedding_client().await;
+
+    let response = client
+        .generate_embeddings(&["Hello world", "Goodbye world"])
+        .await
+        .expect("batch embedding should succeed");
+
+    assert_eq!(response.data.len(), 2);
+    for data in &response.data {
+        let norm: f32 = data.embedding.iter().map(|v| v * v).sum::<f32>().sqrt() as f32;
+        assert!(
+            (0.99_f32..=1.01_f32).contains(&norm),
+            "L2 norm {norm} not approximately 1.0"
+        );
+    }
+
+    model.unload().await.expect("unload should succeed");
+}
+
+#[tokio::test]
+async fn should_match_single_and_batch_results() {
+    let (client, model) = setup_embedding_client().await;
+
+    let single = client
+        .generate_embedding("The capital of France is Paris")
+        .await
+        .expect("single embedding should succeed");
+
+    let batch = client
+        .generate_embeddings(&["The capital of France is Paris"])
+        .await
+        .expect("batch embedding should succeed");
+
+    assert_eq!(batch.data.len(), 1);
+    for (a, b) in single.data[0].embedding.iter().zip(batch.data[0].embedding.iter()) {
+        assert_eq!(a, b);
+    }
+
+    model.unload().await.expect("unload should succeed");
+}
diff --git a/sdk/rust/tests/integration/main.rs b/sdk/rust/tests/integration/main.rs
index 04de9a23..c63956f3 100644
--- a/sdk/rust/tests/integration/main.rs
+++ b/sdk/rust/tests/integration/main.rs
@@ -11,6 +11,7 @@
 mod common;
 
 mod audio_client_test;
 mod catalog_test;
 mod chat_client_test;
+mod embedding_client_test;
 mod manager_test;
 mod model_test;
 mod web_service_test;