-
Notifications
You must be signed in to change notification settings - Fork 281
Add private model catalog SDK support (AddCatalog, SelectCatalog, GetCatalogNames) #601
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
282f0a9
ba4b050
f2725a4
b3ed6db
6c4c8da
62c21fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -190,10 +190,10 @@ private async Task<IModel> GetLatestVersionImplAsync(IModel modelOrModelVariant, | |
| return latest.Id == modelOrModelVariant.Id ? modelOrModelVariant : latest; | ||
| } | ||
|
|
||
| private async Task UpdateModels(CancellationToken? ct) | ||
| private async Task UpdateModels(CancellationToken? ct, bool forceRefresh = false) | ||
| { | ||
| // TODO: make this configurable | ||
| if (DateTime.Now - _lastFetch < TimeSpan.FromHours(6)) | ||
| if (!forceRefresh && DateTime.Now - _lastFetch < TimeSpan.FromHours(6)) | ||
| { | ||
| return; | ||
| } | ||
|
|
@@ -249,4 +249,104 @@ public void Dispose() | |
| { | ||
| _lock.Dispose(); | ||
| } | ||
|
|
||
| public async Task AddCatalogAsync(string name, Uri uri, string? clientId = null, | ||
| string? clientSecret = null, string? bearerToken = null, | ||
| string? tokenEndpoint = null, string? audience = null, | ||
|
Comment on lines
+254
to
+255
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we keep only the essential arguments in the function and move all the optional arguments in a map argument. |
||
| CancellationToken? ct = null) | ||
| { | ||
| ArgumentException.ThrowIfNullOrWhiteSpace(name); | ||
| ArgumentNullException.ThrowIfNull(uri); | ||
|
|
||
| if (uri.Scheme != "https" && uri.Scheme != "http") | ||
| { | ||
| throw new ArgumentException($"Catalog URI must use http or https scheme, got '{uri.Scheme}'.", nameof(uri)); | ||
| } | ||
|
|
||
| if (tokenEndpoint != null) | ||
| { | ||
| if (!Uri.TryCreate(tokenEndpoint, UriKind.Absolute, out var parsedEndpoint)) | ||
| { | ||
| throw new ArgumentException($"Token endpoint is not a valid URL: '{tokenEndpoint}'.", nameof(tokenEndpoint)); | ||
| } | ||
| if (parsedEndpoint.Scheme != "https" && parsedEndpoint.Scheme != "http") | ||
| { | ||
| throw new ArgumentException($"Token endpoint must use http or https scheme, got '{parsedEndpoint.Scheme}'.", nameof(tokenEndpoint)); | ||
| } | ||
| } | ||
|
|
||
| await Utils.CallWithExceptionHandling(async () => | ||
| { | ||
| var request = new CoreInteropRequest | ||
| { | ||
| Params = new Dictionary<string, string> | ||
| { | ||
| ["Name"] = name, | ||
| ["Uri"] = uri.ToString(), | ||
| ["ClientId"] = clientId ?? "", | ||
| ["ClientSecret"] = clientSecret ?? "", | ||
| ["BearerToken"] = bearerToken ?? "", | ||
| ["TokenEndpoint"] = tokenEndpoint ?? "", | ||
| ["Audience"] = audience ?? "" | ||
| } | ||
| }; | ||
|
Comment on lines
+261
to
+292
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can all of this logic be moved to Core so each sdk need not have this logic? |
||
|
|
||
| var result = await _coreInterop.ExecuteCommandAsync("add_catalog", request, ct) | ||
| .ConfigureAwait(false); | ||
| if (result.Error != null) | ||
| { | ||
| throw new FoundryLocalException($"Error adding catalog '{name}': {result.Error}", _logger); | ||
| } | ||
|
|
||
| // Force model list refresh to pick up new catalog's models | ||
| await UpdateModels(ct, forceRefresh: true).ConfigureAwait(false); | ||
| }, $"Error adding catalog '{name}'.", _logger).ConfigureAwait(false); | ||
| } | ||
|
|
||
| public async Task SelectCatalogAsync(string? catalogName, CancellationToken? ct = null) | ||
| { | ||
| if (catalogName != null) | ||
| { | ||
| ArgumentException.ThrowIfNullOrWhiteSpace(catalogName); | ||
| } | ||
|
|
||
| await Utils.CallWithExceptionHandling(async () => | ||
| { | ||
| var request = new CoreInteropRequest | ||
| { | ||
| Params = new Dictionary<string, string> | ||
| { | ||
| ["Name"] = catalogName ?? "" | ||
| } | ||
| }; | ||
|
|
||
| var result = await _coreInterop.ExecuteCommandAsync("select_catalog", request, ct) | ||
| .ConfigureAwait(false); | ||
| if (result.Error != null) | ||
| { | ||
| throw new FoundryLocalException($"Error selecting catalog: {result.Error}", _logger); | ||
| } | ||
|
|
||
| // Force model list refresh so the managed-side maps reflect the filter. | ||
| // The native core already has models cached; this just re-fetches the | ||
| // (now-filtered) list into _modelAliasToModel / _modelIdToModelVariant. | ||
| await UpdateModels(ct, forceRefresh: true).ConfigureAwait(false); | ||
| }, "Error selecting catalog.", _logger).ConfigureAwait(false); | ||
| } | ||
|
|
||
| public async Task<List<string>> GetCatalogNamesAsync(CancellationToken? ct = null) | ||
| { | ||
| return await Utils.CallWithExceptionHandling(async () => | ||
| { | ||
| CoreInteropRequest? input = null; | ||
| var result = await _coreInterop.ExecuteCommandAsync("get_catalog_names", input, ct) | ||
| .ConfigureAwait(false); | ||
| if (result.Error != null) | ||
| { | ||
| throw new FoundryLocalException($"Error getting catalog names: {result.Error}", _logger); | ||
| } | ||
|
|
||
| return JsonSerializer.Deserialize(result.Data ?? "[]", JsonSerializationContext.Default.ListString) ?? []; | ||
| }, "Error getting catalog names.", _logger).ConfigureAwait(false); | ||
kobby-kobbs marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,61 @@ | ||
| // -------------------------------------------------------------------------------------------------------------------- | ||
| // <copyright company="Microsoft"> | ||
| // Copyright (c) Microsoft. All rights reserved. | ||
| // </copyright> | ||
| // -------------------------------------------------------------------------------------------------------------------- | ||
|
|
||
| namespace Microsoft.AI.Foundry.Local.Tests; | ||
|
|
||
| using System.Text.Json; | ||
| using Microsoft.AI.Foundry.Local.Detail; | ||
| using Moq; | ||
|
|
||
| public class CatalogManagementTests | ||
| { | ||
| private static async Task<Catalog> CreateCatalogWithIntercepts( | ||
| List<Utils.InteropCommandInterceptInfo> extra) | ||
| { | ||
| var logger = Utils.CreateCapturingLoggerMock([]); | ||
| var lm = new Mock<IModelLoadManager>(); | ||
| lm.Setup(m => m.ListLoadedModelsAsync(It.IsAny<CancellationToken?>())).ReturnsAsync(Array.Empty<string>()); | ||
|
|
||
| List<Utils.InteropCommandInterceptInfo> intercepts = | ||
| [ | ||
| new() { CommandName = "get_catalog_name", ResponseData = "Test" }, | ||
| new() { CommandName = "get_model_list", | ||
| ResponseData = JsonSerializer.Serialize(Utils.TestCatalog.TestCatalog, | ||
| JsonSerializationContext.Default.ListModelInfo) }, | ||
| new() { CommandName = "get_cached_models", ResponseData = "[]" }, | ||
| .. extra | ||
| ]; | ||
|
|
||
| var ci = Utils.CreateCoreInteropWithIntercept(Utils.CoreInterop, intercepts); | ||
| return await Catalog.CreateAsync(lm.Object, ci.Object, logger.Object); | ||
| } | ||
|
|
||
| [Test] | ||
| public async Task Test_AddAndSelectCatalog() | ||
| { | ||
| using var catalog = await CreateCatalogWithIntercepts( | ||
| [ | ||
| new() { CommandName = "add_catalog", ResponseData = "OK" }, | ||
| new() { CommandName = "select_catalog", ResponseData = "OK" } | ||
| ]); | ||
|
|
||
| await catalog.AddCatalogAsync("priv", new Uri("https://mds.example.com"), "id", "secret"); | ||
| await catalog.SelectCatalogAsync("priv"); | ||
| await catalog.SelectCatalogAsync(null); | ||
| await Assert.That(catalog).IsNotNull(); | ||
| } | ||
|
|
||
| [Test] | ||
| public async Task Test_GetCatalogNames() | ||
| { | ||
| using var catalog = await CreateCatalogWithIntercepts( | ||
| [new() { CommandName = "get_catalog_names", ResponseData = "[\"public\",\"private\"]" }]); | ||
|
|
||
| var names = await catalog.GetCatalogNamesAsync(); | ||
| await Assert.That(names.Count).IsEqualTo(2); | ||
| await Assert.That(names).Contains("private"); | ||
| } | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There is a function for invalidating the cache:
InvalidateCache. Could you use that instead of adding forceRefresh?