diff --git a/packages/opencode/src/config/config.ts b/packages/opencode/src/config/config.ts index 47afdfd7d0f..bfab45eeee6 100644 --- a/packages/opencode/src/config/config.ts +++ b/packages/opencode/src/config/config.ts @@ -979,6 +979,10 @@ export namespace Config { .extend({ whitelist: z.array(z.string()).optional(), blacklist: z.array(z.string()).optional(), + auth_provider: z + .string() + .optional() + .describe("Provider to inherit SDK and model loading behavior from"), models: z .record( z.string(), diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 6ab45d028b9..12334b76a68 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -55,6 +55,22 @@ import { ModelID, ProviderID } from "./schema" export namespace Provider { const log = Log.create({ service: "provider" }) + function driver(cfg: Config.Info, id: string): string { + const seen = new Set() + let current = id + while (true) { + if (seen.has(current)) return current + seen.add(current) + const next = cfg.provider?.[current]?.auth_provider + if (!next || next === current) return current + current = next + } + } + + function aliases(cfg: Config.Info, source: string): string[] { + return Object.keys(cfg.provider ?? {}).filter((item) => item !== source && driver(cfg, item) === source) + } + function shouldUseCopilotResponsesApi(modelID: string): boolean { const match = /^gpt-(\d+)/.exec(modelID) if (!match) return false @@ -1069,19 +1085,27 @@ export namespace Provider { for (const plugin of await Plugin.list()) { if (!plugin.auth) continue - const providerID = ProviderID.make(plugin.auth.provider) - if (disabled.has(providerID)) continue - - const auth = await Auth.get(providerID) - if (!auth) continue - if (!plugin.auth.loader) continue - - if (auth) { - const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider]) + const loader = plugin.auth.loader + const source = plugin.auth.provider + + const load = async (id: string) => { + const providerID = ProviderID.make(id) + if (disabled.has(providerID)) return + const auth = await Auth.get(providerID) + if (!auth) return + if (!loader) return + const info = database[id] + if (!info) return + const options = await loader(() => Auth.get(providerID) as any, info) const opts = options ?? {} const patch: Partial = providers[providerID] ? { options: opts } : { source: "custom", options: opts } mergeProvider(providerID, patch) } + + await load(source) + for (const id of aliases(config, source)) { + await load(id) + } } for (const [id, fn] of Object.entries(CUSTOM_LOADERS)) { @@ -1103,6 +1127,15 @@ export namespace Provider { } } + for (const [id] of configProviders) { + const providerID = ProviderID.make(id) + if (disabled.has(providerID)) continue + const sourceID = ProviderID.make(driver(config, id)) + if (sourceID === providerID) continue + if (modelLoaders[sourceID]) modelLoaders[providerID] = modelLoaders[sourceID] + if (varsLoaders[sourceID]) varsLoaders[providerID] = varsLoaders[sourceID] + } + // load config for (const [id, provider] of configProviders) { const providerID = ProviderID.make(id)