diff --git a/packages/opencode/src/provider/model-detection.ts b/packages/opencode/src/provider/model-detection.ts index 27c63ff5f2..0dc65fd3fc 100644 --- a/packages/opencode/src/provider/model-detection.ts +++ b/packages/opencode/src/provider/model-detection.ts @@ -1,117 +1,31 @@ import z from "zod" import { iife } from "@/util/iife" import { Log } from "@/util/log" -import { Config } from "../config/config" -import { ModelsDev } from "./models" import { Provider } from "./provider" export namespace ProviderModelDetection { - function mergeModel( - detectedModel: Partial, - providerModel: Provider.Model | undefined, - modelID: string, - providerID: string, - providerBaseURL: string, - ): Provider.Model { - return { - id: modelID, - providerID: detectedModel.providerID ?? providerModel?.providerID ?? providerID, - api: { - id: modelID, - url: detectedModel.api?.url ?? providerModel?.api?.url ?? providerBaseURL, - npm: detectedModel.api?.npm ?? providerModel?.api?.npm ?? "@ai-sdk/openai-compatible", - }, - name: detectedModel.name ?? providerModel?.name ?? modelID, - family: detectedModel.family ?? providerModel?.family ?? "", - capabilities: { - temperature: detectedModel.capabilities?.temperature ?? providerModel?.capabilities?.temperature ?? false, - reasoning: detectedModel.capabilities?.reasoning ?? providerModel?.capabilities?.reasoning ?? false, - attachment: detectedModel.capabilities?.attachment ?? providerModel?.capabilities?.attachment ?? false, - toolcall: detectedModel.capabilities?.toolcall ?? providerModel?.capabilities?.toolcall ?? true, - input: { - text: detectedModel.capabilities?.input?.text ?? providerModel?.capabilities?.input?.text ?? true, - audio: detectedModel.capabilities?.input?.audio ?? providerModel?.capabilities?.input?.audio ?? false, - image: detectedModel.capabilities?.input?.image ?? providerModel?.capabilities?.input?.image ?? false, - video: detectedModel.capabilities?.input?.video ?? providerModel?.capabilities?.input?.video ?? false, - pdf: detectedModel.capabilities?.input?.pdf ?? providerModel?.capabilities?.input?.pdf ?? false, - }, - output: { - text: detectedModel.capabilities?.output?.text ?? providerModel?.capabilities?.output?.text ?? true, - audio: detectedModel.capabilities?.output?.audio ?? providerModel?.capabilities?.output?.audio ?? false, - image: detectedModel.capabilities?.output?.image ?? providerModel?.capabilities?.output?.image ?? false, - video: detectedModel.capabilities?.output?.video ?? providerModel?.capabilities?.output?.video ?? false, - pdf: detectedModel.capabilities?.output?.pdf ?? providerModel?.capabilities?.output?.pdf ?? false, - }, - interleaved: detectedModel.capabilities?.interleaved ?? providerModel?.capabilities?.interleaved ?? false, - }, - cost: { - input: detectedModel.cost?.input ?? providerModel?.cost?.input ?? 0, - output: detectedModel.cost?.output ?? providerModel?.cost?.output ?? 0, - cache: { - read: detectedModel.cost?.cache?.read ?? providerModel?.cost?.cache?.read ?? 0, - write: detectedModel.cost?.cache?.write ?? providerModel?.cost?.cache?.write ?? 0, - }, - experimentalOver200K: detectedModel.cost?.experimentalOver200K ?? providerModel?.cost?.experimentalOver200K, - }, - limit: { - context: detectedModel.limit?.context ?? providerModel?.limit?.context ?? 0, - input: detectedModel.limit?.input ?? providerModel?.limit?.input ?? 0, - output: detectedModel.limit?.output ?? providerModel?.limit?.output ?? 0, - }, - status: detectedModel.status ?? providerModel?.status ?? "active", - options: detectedModel.options ?? providerModel?.options ?? {}, - headers: detectedModel.headers ?? providerModel?.headers ?? {}, - release_date: detectedModel.release_date ?? providerModel?.release_date ?? "", - variants: detectedModel.variants ?? providerModel?.variants ?? {}, - } - } - - export async function populateModels( - provider: Provider.Info, - configProvider?: Config.Provider, - modelsDevProvider?: ModelsDev.Provider, - ): Promise { + export async function detect(provider: Provider.Info): Promise { const log = Log.create({ service: "provider.model-detection" }) - const providerNPM = configProvider?.npm ?? modelsDevProvider?.npm ?? "@ai-sdk/openai-compatible" - const providerBaseURL = configProvider?.options?.baseURL ?? configProvider?.api ?? modelsDevProvider?.api ?? "" + const model = Object.values(provider.models)[0] + const providerNPM = model?.api?.npm ?? "@ai-sdk/openai-compatible" + const providerBaseURL = provider.options["baseURL"] ?? model?.api?.url ?? "" const detectedModels = await iife(async () => { - if (provider.id === "opencode") return - try { if (providerNPM === "@ai-sdk/openai-compatible" && providerBaseURL) { log.info("using OpenAI-compatible method", { providerID: provider.id }) return await ProviderModelDetection.OpenAICompatible.listModels(providerBaseURL, provider) } } catch (error) { - log.warn(`failed to populate models\n${error}`, { providerID: provider.id }) + log.warn(`failed to detect models\n${error}`, { providerID: provider.id }) } }) - if (!detectedModels || Object.entries(detectedModels).length === 0) return - // Only keep models detected and models specified in config - const modelIDs = Array.from(new Set([ - ...Object.keys(detectedModels), - ...Object.keys(configProvider?.models ?? {}), - ])) - // Provider models are merged from config and Models.dev, delete models only from Models.dev - for (const [modelID] of Object.entries(provider.models)) { - if (!modelIDs.includes(modelID)) delete provider.models[modelID] - } - // Add detected models, and take precedence over provider models (which are from config and Models.dev) - for (const modelID of modelIDs) { - if (!(modelID in detectedModels)) continue - provider.models[modelID] = mergeModel( - detectedModels[modelID], - provider.models[modelID], - modelID, - provider.id, - providerBaseURL, - ) - } + if (!detectedModels || detectedModels.length === 0) return - log.info("populated models", { providerID: provider.id }) + log.info("detected models", { providerID: provider.id, count: detectedModels.length }) + return detectedModels } } @@ -129,7 +43,7 @@ export namespace ProviderModelDetection.OpenAICompatible { }) type OpenAICompatibleResponse = z.infer - export async function listModels(baseURL: string, provider: Provider.Info): Promise>> { + export async function listModels(baseURL: string, provider: Provider.Info): Promise { const fetchFn = provider.options["fetch"] ?? fetch const apiKey = provider.options["apiKey"] ?? provider.key ?? "" const headers = new Headers() @@ -142,10 +56,8 @@ export namespace ProviderModelDetection.OpenAICompatible { if (!res.ok) throw new Error(`bad http status ${res.status}`) const parsed = OpenAICompatibleResponse.parse(await res.json()) - return Object.fromEntries( - parsed.data - .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed")) - .map((model) => [model.id, {}]) - ) + return parsed.data + .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed")) + .map((model) => model.id) } } diff --git a/packages/opencode/src/provider/provider.ts b/packages/opencode/src/provider/provider.ts index 998b54a45f..476b2f652f 100644 --- a/packages/opencode/src/provider/provider.ts +++ b/packages/opencode/src/provider/provider.ts @@ -673,6 +673,45 @@ export namespace Provider { } } + const ModelsList = z.object({ + object: z.string(), + data: z.array( + z + .object({ + id: z.string(), + object: z.string().optional(), + created: z.number().optional(), + owned_by: z.string().optional(), + }) + .catchall(z.any()), + ), + }) + type ModelsList = z.infer + + async function listModels(provider: Info) { + const baseURL = provider.options["baseURL"] + const fetchFn = (provider.options["fetch"] as typeof fetch) ?? fetch + const apiKey = provider.options["apiKey"] ?? provider.key ?? "" + const headers = new Headers() + if (apiKey) headers.append("Authorization", `Bearer ${apiKey}`) + const models = await fetchFn(`${baseURL}/models`, { + headers, + signal: AbortSignal.timeout(3 * 1000), + }) + .then(async (resp) => { + if (!resp.ok) return + return ModelsList.parse(await resp.json()) + }) + .catch((err) => { + log.error(`Failed to fetch models from: ${baseURL}/models`, { error: err }) + }) + if (!models) return + + return models.data + .filter((model) => model.id && !model.id.includes("embedding") && !model.id.includes("embed")) + .map((model) => model.id) + } + const state = Instance.state(async () => { using _ = log.time("state") const config = await Config.get() @@ -904,11 +943,18 @@ export namespace Provider { mergeProvider(providerID, partial) } - // detect and populate models + // detect models and prune invalid ones await Promise.all( - Object.entries(providers).map(async ([providerID, provider]) => { - await ProviderModelDetection.populateModels(provider, config.provider?.[providerID], modelsDev[providerID]) - }) + Object.values(providers).map(async (provider) => { + const detected = await listModels(provider) + if (!detected) return + const detectedSet = new Set(detected) + for (const modelID of Object.keys(provider.models)) { + if (!detectedSet.has(modelID)) delete provider.models[modelID] + } + // TODO: add detected models not present in config/models.dev + // for (const modelID of detected) {} + }), ) for (const [providerID, provider] of Object.entries(providers)) {