Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 102 additions & 2 deletions sdk/cs/src/Catalog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,10 @@ private async Task<IModel> GetLatestVersionImplAsync(IModel modelOrModelVariant,
return latest.Id == modelOrModelVariant.Id ? modelOrModelVariant : latest;
}

private async Task UpdateModels(CancellationToken? ct)
private async Task UpdateModels(CancellationToken? ct, bool forceRefresh = false)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is a function for invalidating the cache: InvalidateCache. Could you use that instead of adding forceRefresh?

{
// TODO: make this configurable
if (DateTime.Now - _lastFetch < TimeSpan.FromHours(6))
if (!forceRefresh && DateTime.Now - _lastFetch < TimeSpan.FromHours(6))
{
return;
}
Expand Down Expand Up @@ -249,4 +249,104 @@ public void Dispose()
{
_lock.Dispose();
}

public async Task AddCatalogAsync(string name, Uri uri, string? clientId = null,
string? clientSecret = null, string? bearerToken = null,
string? tokenEndpoint = null, string? audience = null,
Comment on lines +254 to +255
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we keep only the essential arguments in the function and move all the optional arguments in a map argument.
Different catalogs may have different arguments that may be needed. So, the basic ones can be explicit arguments and the optional ones can be in a map?

CancellationToken? ct = null)
{
ArgumentException.ThrowIfNullOrWhiteSpace(name);
ArgumentNullException.ThrowIfNull(uri);

if (uri.Scheme != "https" && uri.Scheme != "http")
{
throw new ArgumentException($"Catalog URI must use http or https scheme, got '{uri.Scheme}'.", nameof(uri));
}

if (tokenEndpoint != null)
{
if (!Uri.TryCreate(tokenEndpoint, UriKind.Absolute, out var parsedEndpoint))
{
throw new ArgumentException($"Token endpoint is not a valid URL: '{tokenEndpoint}'.", nameof(tokenEndpoint));
}
if (parsedEndpoint.Scheme != "https" && parsedEndpoint.Scheme != "http")
{
throw new ArgumentException($"Token endpoint must use http or https scheme, got '{parsedEndpoint.Scheme}'.", nameof(tokenEndpoint));
}
}

await Utils.CallWithExceptionHandling(async () =>
{
var request = new CoreInteropRequest
{
Params = new Dictionary<string, string>
{
["Name"] = name,
["Uri"] = uri.ToString(),
["ClientId"] = clientId ?? "",
["ClientSecret"] = clientSecret ?? "",
["BearerToken"] = bearerToken ?? "",
["TokenEndpoint"] = tokenEndpoint ?? "",
["Audience"] = audience ?? ""
}
};
Comment on lines +261 to +292
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can all of this logic be moved to Core so each sdk need not have this logic?


var result = await _coreInterop.ExecuteCommandAsync("add_catalog", request, ct)
.ConfigureAwait(false);
if (result.Error != null)
{
throw new FoundryLocalException($"Error adding catalog '{name}': {result.Error}", _logger);
}

// Force model list refresh to pick up new catalog's models
await UpdateModels(ct, forceRefresh: true).ConfigureAwait(false);
}, $"Error adding catalog '{name}'.", _logger).ConfigureAwait(false);
}

public async Task SelectCatalogAsync(string? catalogName, CancellationToken? ct = null)
{
if (catalogName != null)
{
ArgumentException.ThrowIfNullOrWhiteSpace(catalogName);
}

await Utils.CallWithExceptionHandling(async () =>
{
var request = new CoreInteropRequest
{
Params = new Dictionary<string, string>
{
["Name"] = catalogName ?? ""
}
};

var result = await _coreInterop.ExecuteCommandAsync("select_catalog", request, ct)
.ConfigureAwait(false);
if (result.Error != null)
{
throw new FoundryLocalException($"Error selecting catalog: {result.Error}", _logger);
}

// Force model list refresh so the managed-side maps reflect the filter.
// The native core already has models cached; this just re-fetches the
// (now-filtered) list into _modelAliasToModel / _modelIdToModelVariant.
await UpdateModels(ct, forceRefresh: true).ConfigureAwait(false);
}, "Error selecting catalog.", _logger).ConfigureAwait(false);
}

public async Task<List<string>> GetCatalogNamesAsync(CancellationToken? ct = null)
{
return await Utils.CallWithExceptionHandling(async () =>
{
CoreInteropRequest? input = null;
var result = await _coreInterop.ExecuteCommandAsync("get_catalog_names", input, ct)
.ConfigureAwait(false);
if (result.Error != null)
{
throw new FoundryLocalException($"Error getting catalog names: {result.Error}", _logger);
}

return JsonSerializer.Deserialize(result.Data ?? "[]", JsonSerializationContext.Default.ListString) ?? [];
}, "Error getting catalog names.", _logger).ConfigureAwait(false);
}
}
3 changes: 1 addition & 2 deletions sdk/cs/src/Detail/CoreInterop.cs
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput,
if (response.Error != IntPtr.Zero && response.ErrorLength > 0)
{
result.Error = Marshal.PtrToStringUTF8(response.Error, response.ErrorLength)!;
_logger.LogDebug($"Input:{commandInput ?? "null"}");
_logger.LogDebug($"Command: {commandName} Error: {result.Error}");
}

Expand All @@ -342,7 +341,7 @@ public Response ExecuteCommandImpl(string commandName, string? commandInput,
}
catch (Exception ex) when (ex is not OperationCanceledException)
{
var msg = $"Error executing command '{commandName}' with input {commandInput ?? "null"}";
var msg = $"Error executing command '{commandName}'";
throw new FoundryLocalException(msg, ex, _logger);
}
}
Expand Down
3 changes: 2 additions & 1 deletion sdk/cs/src/Detail/JsonSerializationContext.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// --------------------------------------------------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------
// <copyright company="Microsoft">
// Copyright (c) Microsoft. All rights reserved.
// </copyright>
Expand Down Expand Up @@ -39,6 +39,7 @@ namespace Microsoft.AI.Foundry.Local.Detail;
// which has AOT-incompatible JsonConverters, so we only register the raw deserialization type) ---
[JsonSerializable(typeof(LiveAudioTranscriptionRaw))]
[JsonSerializable(typeof(CoreErrorResponse))]
[JsonSerializable(typeof(List<string>))] // catalog names
[JsonSourceGenerationOptions(DefaultIgnoreCondition = JsonIgnoreCondition.WhenWritingNull,
WriteIndented = false)]
internal partial class JsonSerializationContext : JsonSerializerContext
Expand Down
33 changes: 32 additions & 1 deletion sdk/cs/src/ICatalog.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// --------------------------------------------------------------------------------------------------------------------
// --------------------------------------------------------------------------------------------------------------------
// <copyright company="Microsoft">
// Copyright (c) Microsoft. All rights reserved.
// </copyright>
Expand Down Expand Up @@ -61,4 +61,35 @@ public interface ICatalog
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>The latest version of the model. Will match the input if it is the latest version.</returns>
Task<IModel> GetLatestVersionAsync(IModel model, CancellationToken? ct = null);

/// <summary>
/// Add a private model catalog. The model list is refreshed automatically,
/// so models from the new catalog are available as soon as this call returns.
/// </summary>
/// <param name="name">Display name for the catalog (e.g. "my-private-catalog").</param>
/// <param name="uri">Base URL of the private catalog service.</param>
/// <param name="clientId">Optional OAuth2 client credentials ID.</param>
/// <param name="clientSecret">Optional OAuth2 client credentials secret, or API key for legacy auth.</param>
/// <param name="bearerToken">Optional pre-obtained bearer token (for testing/self-service auth).</param>
/// <param name="tokenEndpoint">Optional OAuth2 token endpoint URL (e.g. "https://idp.example.com/oauth/token").</param>
/// <param name="audience">Optional OAuth2 audience parameter (e.g. "model-distribution-service").</param>
/// <param name="ct">Optional CancellationToken.</param>
Task AddCatalogAsync(string name, Uri uri, string? clientId = null, string? clientSecret = null,
string? bearerToken = null, string? tokenEndpoint = null, string? audience = null,
CancellationToken? ct = null);

/// <summary>
/// Filter the catalog to only return models from the named catalog.
/// Pass null to reset and show models from all catalogs.
/// </summary>
/// <param name="catalogName">Catalog name to filter to, or null to show all.</param>
/// <param name="ct">Optional CancellationToken.</param>
Task SelectCatalogAsync(string? catalogName, CancellationToken? ct = null);

/// <summary>
/// Get the names of all registered catalogs.
/// </summary>
/// <param name="ct">Optional CancellationToken.</param>
/// <returns>List of catalog name strings.</returns>
Task<List<string>> GetCatalogNamesAsync(CancellationToken? ct = null);
}
61 changes: 61 additions & 0 deletions sdk/cs/test/FoundryLocal.Tests/CatalogManagementTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
// --------------------------------------------------------------------------------------------------------------------
// <copyright company="Microsoft">
// Copyright (c) Microsoft. All rights reserved.
// </copyright>
// --------------------------------------------------------------------------------------------------------------------

namespace Microsoft.AI.Foundry.Local.Tests;

using System.Text.Json;
using Microsoft.AI.Foundry.Local.Detail;
using Moq;

public class CatalogManagementTests
{
private static async Task<Catalog> CreateCatalogWithIntercepts(
List<Utils.InteropCommandInterceptInfo> extra)
{
var logger = Utils.CreateCapturingLoggerMock([]);
var lm = new Mock<IModelLoadManager>();
lm.Setup(m => m.ListLoadedModelsAsync(It.IsAny<CancellationToken?>())).ReturnsAsync(Array.Empty<string>());

List<Utils.InteropCommandInterceptInfo> intercepts =
[
new() { CommandName = "get_catalog_name", ResponseData = "Test" },
new() { CommandName = "get_model_list",
ResponseData = JsonSerializer.Serialize(Utils.TestCatalog.TestCatalog,
JsonSerializationContext.Default.ListModelInfo) },
new() { CommandName = "get_cached_models", ResponseData = "[]" },
.. extra
];

var ci = Utils.CreateCoreInteropWithIntercept(Utils.CoreInterop, intercepts);
return await Catalog.CreateAsync(lm.Object, ci.Object, logger.Object);
}

[Test]
public async Task Test_AddAndSelectCatalog()
{
using var catalog = await CreateCatalogWithIntercepts(
[
new() { CommandName = "add_catalog", ResponseData = "OK" },
new() { CommandName = "select_catalog", ResponseData = "OK" }
]);

await catalog.AddCatalogAsync("priv", new Uri("https://mds.example.com"), "id", "secret");
await catalog.SelectCatalogAsync("priv");
await catalog.SelectCatalogAsync(null);
await Assert.That(catalog).IsNotNull();
}

[Test]
public async Task Test_GetCatalogNames()
{
using var catalog = await CreateCatalogWithIntercepts(
[new() { CommandName = "get_catalog_names", ResponseData = "[\"public\",\"private\"]" }]);

var names = await catalog.GetCatalogNamesAsync();
await Assert.That(names.Count).IsEqualTo(2);
await Assert.That(names).Contains("private");
}
}
Loading