core: convert Model type to Zod schema for better type safety and validation
parent
48cf07d32a
commit
a844eb2429
|
|
@ -1,7 +1,7 @@
|
|||
import z from "zod"
|
||||
import fuzzysort from "fuzzysort"
|
||||
import { Config } from "../config/config"
|
||||
import { mergeDeep, sortBy } from "remeda"
|
||||
import { entries, mapValues, mergeDeep, pipe, sortBy } from "remeda"
|
||||
import { NoSuchModelError, type LanguageModel, type Provider as SDK } from "ai"
|
||||
import { Log } from "../util/log"
|
||||
import { BunProc } from "../bun"
|
||||
|
|
@ -43,7 +43,7 @@ export namespace Provider {
|
|||
"@ai-sdk/github-copilot": createGitHubCopilotOpenAICompatible,
|
||||
}
|
||||
|
||||
type CustomLoader = (provider: ModelsDev.Provider) => Promise<{
|
||||
type CustomLoader = (provider: Info) => Promise<{
|
||||
autoload: boolean
|
||||
getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
|
||||
options?: Record<string, any>
|
||||
|
|
@ -299,18 +299,155 @@ export namespace Provider {
|
|||
},
|
||||
}
|
||||
|
||||
export type Model = {
|
||||
providerID: string
|
||||
modelID: string
|
||||
export const Model = z
|
||||
.object({
|
||||
id: z.string(),
|
||||
providerID: z.string(),
|
||||
api: z.object({
|
||||
id: z.string(),
|
||||
url: z.string(),
|
||||
npm: z.string(),
|
||||
}),
|
||||
name: z.string(),
|
||||
capabilities: z.object({
|
||||
temperature: z.boolean(),
|
||||
reasoning: z.boolean(),
|
||||
attachment: z.boolean(),
|
||||
toolcall: z.boolean(),
|
||||
input: {
|
||||
text: z.boolean(),
|
||||
audio: z.boolean(),
|
||||
image: z.boolean(),
|
||||
video: z.boolean(),
|
||||
pdf: z.boolean(),
|
||||
},
|
||||
output: {
|
||||
text: z.boolean(),
|
||||
audio: z.boolean(),
|
||||
image: z.boolean(),
|
||||
video: z.boolean(),
|
||||
pdf: z.boolean(),
|
||||
},
|
||||
}),
|
||||
cost: z.object({
|
||||
input: z.number(),
|
||||
output: z.number(),
|
||||
cache: z.object({
|
||||
read: z.number(),
|
||||
write: z.number(),
|
||||
}),
|
||||
experimentalOver200K: z
|
||||
.object({
|
||||
input: z.number(),
|
||||
output: z.number(),
|
||||
cache: z.object({
|
||||
read: z.number(),
|
||||
write: z.number(),
|
||||
}),
|
||||
})
|
||||
.optional(),
|
||||
}),
|
||||
limit: z.object({
|
||||
context: z.number(),
|
||||
output: z.number(),
|
||||
}),
|
||||
status: z.enum(["alpha", "beta", "deprecated", "active"]),
|
||||
options: z.record(z.string(), z.any()),
|
||||
headers: z.record(z.string(), z.string()),
|
||||
})
|
||||
.meta({
|
||||
ref: "Model",
|
||||
})
|
||||
export type Model = z.infer<typeof Model>
|
||||
|
||||
export const Info = z.object({
|
||||
id: z.string(),
|
||||
name: z.string(),
|
||||
source: z.enum(["env", "config", "custom", "api"]),
|
||||
env: z.string().array(),
|
||||
key: z.string().optional(),
|
||||
options: z.record(z.string(), z.any()),
|
||||
models: z.record(z.string(), Model),
|
||||
})
|
||||
export type Info = z.infer<typeof Info>
|
||||
|
||||
function fromModelsDevModel(provider: ModelsDev.Provider, model: ModelsDev.Model): Model {
|
||||
return {
|
||||
id: model.id,
|
||||
name: model.name,
|
||||
api: {
|
||||
id: model.id,
|
||||
url: provider.api!,
|
||||
npm: model.provider?.npm ?? provider.npm ?? provider.id,
|
||||
},
|
||||
status: model.status ?? "active",
|
||||
headers: model.headers ?? {},
|
||||
options: model.options ?? {},
|
||||
cost: {
|
||||
input: model.cost.input,
|
||||
output: model.cost.output,
|
||||
cache: {
|
||||
read: model.cost.cache_read ?? 0,
|
||||
write: model.cost.cache_write ?? 0,
|
||||
},
|
||||
experimentalOver200K: model.cost.context_over_200k
|
||||
? {
|
||||
cache: {
|
||||
read: model.cost.context_over_200k.cache_read ?? 0,
|
||||
write: model.cost.context_over_200k.cache_write ?? 0,
|
||||
},
|
||||
input: model.cost.context_over_200k.input,
|
||||
output: model.cost.context_over_200k.output,
|
||||
}
|
||||
: undefined,
|
||||
},
|
||||
limit: {
|
||||
context: model.limit.context,
|
||||
output: model.limit.output,
|
||||
},
|
||||
capabilities: {
|
||||
temperature: model.temperature,
|
||||
reasoning: model.reasoning,
|
||||
attachment: model.attachment,
|
||||
toolcall: model.tool_call,
|
||||
input: {
|
||||
text: model.modalities?.input?.includes("text") ?? false,
|
||||
audio: model.modalities?.input?.includes("audio") ?? false,
|
||||
image: model.modalities?.input?.includes("image") ?? false,
|
||||
video: model.modalities?.input?.includes("video") ?? false,
|
||||
pdf: model.modalities?.input?.includes("pdf") ?? false,
|
||||
},
|
||||
output: {
|
||||
text: model.modalities?.output?.includes("text") ?? false,
|
||||
audio: model.modalities?.output?.includes("audio") ?? false,
|
||||
image: model.modalities?.output?.includes("image") ?? false,
|
||||
video: model.modalities?.output?.includes("video") ?? false,
|
||||
pdf: model.modalities?.output?.includes("pdf") ?? false,
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
function fromModelsDevProvider(provider: ModelsDev.Provider): Info {
|
||||
return {
|
||||
id: provider.id,
|
||||
source: "custom",
|
||||
name: provider.name,
|
||||
env: provider.env ?? [],
|
||||
options: {},
|
||||
models: mapValues(provider.models, (model) => fromModelsDevModel(provider, model)),
|
||||
}
|
||||
}
|
||||
|
||||
export type ModelWithStuff = {
|
||||
language: LanguageModel
|
||||
info: ModelsDev.Model
|
||||
npm: string
|
||||
info: Model
|
||||
}
|
||||
|
||||
const state = Instance.state(async () => {
|
||||
using _ = log.time("state")
|
||||
const config = await Config.get()
|
||||
const database = await ModelsDev.get()
|
||||
const database = mapValues(await ModelsDev.get(), fromModelsDevProvider)
|
||||
|
||||
const disabled = new Set(config.disabled_providers ?? [])
|
||||
const enabled = config.enabled_providers ? new Set(config.enabled_providers) : null
|
||||
|
|
@ -321,43 +458,12 @@ export namespace Provider {
|
|||
return true
|
||||
}
|
||||
|
||||
const providers: {
|
||||
[providerID: string]: {
|
||||
source: Source
|
||||
info: ModelsDev.Provider
|
||||
getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>
|
||||
options: Record<string, any>
|
||||
}
|
||||
} = {}
|
||||
const models = new Map<string, Model>()
|
||||
const providers: { [providerID: string]: Info } = {}
|
||||
const models = new Map<string, ModelWithStuff>()
|
||||
const sdk = new Map<number, SDK>()
|
||||
|
||||
log.info("init")
|
||||
|
||||
function mergeProvider(
|
||||
id: string,
|
||||
options: Record<string, any>,
|
||||
source: Source,
|
||||
getModel?: (sdk: any, modelID: string, options?: Record<string, any>) => Promise<any>,
|
||||
) {
|
||||
const provider = providers[id]
|
||||
if (!provider) {
|
||||
const info = database[id]
|
||||
if (!info) return
|
||||
if (info.api && !options["baseURL"]) options["baseURL"] = info.api
|
||||
providers[id] = {
|
||||
source,
|
||||
info,
|
||||
options,
|
||||
getModel,
|
||||
}
|
||||
return
|
||||
}
|
||||
provider.options = mergeDeep(provider.options, options)
|
||||
provider.source = source
|
||||
provider.getModel = getModel ?? provider.getModel
|
||||
}
|
||||
|
||||
const configProviders = Object.entries(config.provider ?? {})
|
||||
|
||||
// Add GitHub Copilot Enterprise provider that inherits from GitHub Copilot
|
||||
|
|
@ -367,11 +473,17 @@ export namespace Provider {
|
|||
...githubCopilot,
|
||||
id: "github-copilot-enterprise",
|
||||
name: "GitHub Copilot Enterprise",
|
||||
// Enterprise uses a different API endpoint - will be set dynamically based on auth
|
||||
api: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
function mergeProvider(providerID: string, provider: Partial<Info>) {
|
||||
const match = database[providerID]
|
||||
if (!match) return
|
||||
// @ts-expect-error
|
||||
providers[providerID] = mergeDeep(match, provider)
|
||||
}
|
||||
|
||||
// TODO: load config
|
||||
for (const [providerID, provider] of configProviders) {
|
||||
const existing = database[providerID]
|
||||
const parsed: ModelsDev.Provider = {
|
||||
|
|
@ -390,29 +502,27 @@ export namespace Provider {
|
|||
if (model.id && model.id !== modelID) return modelID
|
||||
return existing?.name ?? modelID
|
||||
})
|
||||
const parsedModel: ModelsDev.Model = {
|
||||
const parsedModel: Model = {
|
||||
id: modelID,
|
||||
target: model.target ?? existing?.target ?? modelID,
|
||||
apiID: model.target ?? existing?.target ?? modelID,
|
||||
status: model.status ?? existing?.status ?? "alpha",
|
||||
name,
|
||||
release_date: model.release_date ?? existing?.release_date,
|
||||
attachment: model.attachment ?? existing?.attachment ?? false,
|
||||
reasoning: model.reasoning ?? existing?.reasoning ?? false,
|
||||
temperature: model.temperature ?? existing?.temperature ?? false,
|
||||
tool_call: model.tool_call ?? existing?.tool_call ?? true,
|
||||
cost:
|
||||
!model.cost && !existing?.cost
|
||||
? {
|
||||
input: 0,
|
||||
output: 0,
|
||||
cache_read: 0,
|
||||
cache_write: 0,
|
||||
}
|
||||
: {
|
||||
cache_read: 0,
|
||||
cache_write: 0,
|
||||
...existing?.cost,
|
||||
...model.cost,
|
||||
},
|
||||
providerID,
|
||||
npm: model.provider?.npm ?? existing?.provider?.npm ?? provider.npm ?? providerID,
|
||||
support: {
|
||||
temperature: model.temperature ?? existing?.temperature ?? false,
|
||||
reasoning: model.reasoning ?? existing?.reasoning ?? false,
|
||||
attachment: model.attachment ?? existing?.attachment ?? false,
|
||||
toolcall: model.tool_call ?? existing?.tool_call ?? true,
|
||||
},
|
||||
cost: {
|
||||
input: model?.cost?.input ?? existing?.cost?.input ?? 0,
|
||||
output: model?.cost?.output ?? existing?.cost?.output ?? 0,
|
||||
cache: {
|
||||
read: model?.cost?.cache_read ?? existing?.cost?.cache_read ?? 0,
|
||||
write: model?.cost?.cache_write ?? existing?.cost?.cache_write ?? 0,
|
||||
},
|
||||
},
|
||||
options: {
|
||||
...existing?.options,
|
||||
...model.options,
|
||||
|
|
@ -427,8 +537,7 @@ export namespace Provider {
|
|||
input: ["text"],
|
||||
output: ["text"],
|
||||
},
|
||||
headers: model.headers,
|
||||
provider: model.provider ?? existing?.provider,
|
||||
headers: model.headers ?? {},
|
||||
}
|
||||
parsed.models[modelID] = parsedModel
|
||||
}
|
||||
|
|
@ -442,19 +551,20 @@ export namespace Provider {
|
|||
if (disabled.has(providerID)) continue
|
||||
const apiKey = provider.env.map((item) => env[item]).find(Boolean)
|
||||
if (!apiKey) continue
|
||||
mergeProvider(
|
||||
providerID,
|
||||
// only include apiKey if there's only one potential option
|
||||
provider.env.length === 1 ? { apiKey } : {},
|
||||
"env",
|
||||
)
|
||||
mergeProvider(providerID, {
|
||||
source: "env",
|
||||
key: provider.env.length === 1 ? apiKey : undefined,
|
||||
})
|
||||
}
|
||||
|
||||
// load apikeys
|
||||
for (const [providerID, provider] of Object.entries(await Auth.all())) {
|
||||
if (disabled.has(providerID)) continue
|
||||
if (provider.type === "api") {
|
||||
mergeProvider(providerID, { apiKey: provider.key }, "api")
|
||||
mergeProvider(providerID, {
|
||||
source: "api",
|
||||
key: provider.key,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -480,7 +590,10 @@ export namespace Provider {
|
|||
// Load for the main provider if auth exists
|
||||
if (auth) {
|
||||
const options = await plugin.auth.loader(() => Auth.get(providerID) as any, database[plugin.auth.provider])
|
||||
mergeProvider(plugin.auth.provider, options ?? {}, "custom")
|
||||
mergeProvider(plugin.auth.provider, {
|
||||
source: "custom",
|
||||
options: options,
|
||||
})
|
||||
}
|
||||
|
||||
// If this is github-copilot plugin, also register for github-copilot-enterprise if auth exists
|
||||
|
|
@ -493,7 +606,10 @@ export namespace Provider {
|
|||
() => Auth.get(enterpriseProviderID) as any,
|
||||
database[enterpriseProviderID],
|
||||
)
|
||||
mergeProvider(enterpriseProviderID, enterpriseOptions ?? {}, "custom")
|
||||
mergeProvider(enterpriseProviderID, {
|
||||
source: "custom",
|
||||
options: enterpriseOptions,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -503,13 +619,22 @@ export namespace Provider {
|
|||
if (disabled.has(providerID)) continue
|
||||
const result = await fn(database[providerID])
|
||||
if (result && (result.autoload || providers[providerID])) {
|
||||
mergeProvider(providerID, result.options ?? {}, "custom", result.getModel)
|
||||
mergeProvider(providerID, {
|
||||
source: "custom",
|
||||
options: result.options,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// load config
|
||||
for (const [providerID, provider] of configProviders) {
|
||||
mergeProvider(providerID, provider.options ?? {}, "config")
|
||||
mergeProvider(providerID, {
|
||||
source: "config",
|
||||
env: provider.env,
|
||||
name: provider.name,
|
||||
options: provider.options,
|
||||
// TODO: merge models
|
||||
})
|
||||
}
|
||||
|
||||
for (const [providerID, provider] of Object.entries(providers)) {
|
||||
|
|
@ -519,33 +644,36 @@ export namespace Provider {
|
|||
}
|
||||
|
||||
if (providerID === "github-copilot" || providerID === "github-copilot-enterprise") {
|
||||
provider.info.npm = "@ai-sdk/github-copilot"
|
||||
provider.models = mapValues(provider.models, (model) => ({
|
||||
...model,
|
||||
api: {
|
||||
...model.api,
|
||||
npm: "@ai-sdk/github-copilot",
|
||||
},
|
||||
}))
|
||||
}
|
||||
|
||||
const configProvider = config.provider?.[providerID]
|
||||
|
||||
for (const [modelID, model] of Object.entries(provider.info.models)) {
|
||||
model.target = model.target ?? model.id ?? modelID
|
||||
for (const [modelID, model] of Object.entries(provider.models)) {
|
||||
model.api.id = model.api.id ?? model.id ?? modelID
|
||||
if (modelID === "gpt-5-chat-latest" || (providerID === "openrouter" && modelID === "openai/gpt-5-chat"))
|
||||
delete provider.info.models[modelID]
|
||||
if (
|
||||
((model.status === "alpha" || model.experimental) && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) ||
|
||||
model.status === "deprecated"
|
||||
)
|
||||
delete provider.info.models[modelID]
|
||||
delete provider.models[modelID]
|
||||
if ((model.status === "alpha" && !Flag.OPENCODE_ENABLE_EXPERIMENTAL_MODELS) || model.status === "deprecated")
|
||||
delete provider.models[modelID]
|
||||
if (
|
||||
(configProvider?.blacklist && configProvider.blacklist.includes(modelID)) ||
|
||||
(configProvider?.whitelist && !configProvider.whitelist.includes(modelID))
|
||||
)
|
||||
delete provider.info.models[modelID]
|
||||
delete provider.models[modelID]
|
||||
}
|
||||
|
||||
if (Object.keys(provider.info.models).length === 0) {
|
||||
if (Object.keys(provider.models).length === 0) {
|
||||
delete providers[providerID]
|
||||
continue
|
||||
}
|
||||
|
||||
log.info("found", { providerID, npm: provider.info.npm })
|
||||
log.info("found", { providerID })
|
||||
}
|
||||
|
||||
return {
|
||||
|
|
@ -559,18 +687,28 @@ export namespace Provider {
|
|||
return state().then((state) => state.providers)
|
||||
}
|
||||
|
||||
async function getSDK(npm: string, providerID: string) {
|
||||
async function getSDK(model: Model) {
|
||||
try {
|
||||
using _ = log.time("getSDK", {
|
||||
providerID,
|
||||
providerID: model.providerID,
|
||||
})
|
||||
const s = await state()
|
||||
const options = { ...s.providers[providerID]?.options }
|
||||
if (npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) {
|
||||
const provider = s.providers[model.providerID]
|
||||
const options = { ...provider.options }
|
||||
|
||||
if (model.api.npm.includes("@ai-sdk/openai-compatible") && options["includeUsage"] !== false) {
|
||||
options["includeUsage"] = true
|
||||
}
|
||||
|
||||
const key = Bun.hash.xxHash32(JSON.stringify({ pkg: npm, options }))
|
||||
if (!options["baseURL"]) options["baseURL"] = model.api.url
|
||||
if (!options["apiKey"]) options["apiKey"] = provider.key
|
||||
if (model.headers)
|
||||
options["headers"] = {
|
||||
...options["headers"],
|
||||
...model.headers,
|
||||
}
|
||||
|
||||
const key = Bun.hash.xxHash32(JSON.stringify({ npm: model.api.npm, options }))
|
||||
const existing = s.sdk.get(key)
|
||||
if (existing) return existing
|
||||
|
||||
|
|
@ -599,12 +737,13 @@ export namespace Provider {
|
|||
}
|
||||
|
||||
// Special case: google-vertex-anthropic uses a subpath import
|
||||
const bundledKey = providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : npm
|
||||
const bundledKey =
|
||||
model.providerID === "google-vertex-anthropic" ? "@ai-sdk/google-vertex/anthropic" : model.api.npm
|
||||
const bundledFn = BUNDLED_PROVIDERS[bundledKey]
|
||||
if (bundledFn) {
|
||||
log.info("using bundled provider", { providerID, pkg: bundledKey })
|
||||
log.info("using bundled provider", { providerID: model.providerID, pkg: bundledKey })
|
||||
const loaded = bundledFn({
|
||||
name: providerID,
|
||||
name: model.providerID,
|
||||
...options,
|
||||
})
|
||||
s.sdk.set(key, loaded)
|
||||
|
|
@ -612,24 +751,24 @@ export namespace Provider {
|
|||
}
|
||||
|
||||
let installedPath: string
|
||||
if (!npm.startsWith("file://")) {
|
||||
installedPath = await BunProc.install(npm, "latest")
|
||||
if (!model.api.npm.startsWith("file://")) {
|
||||
installedPath = await BunProc.install(model.api.npm, "latest")
|
||||
} else {
|
||||
log.info("loading local provider", { pkg: npm })
|
||||
installedPath = npm
|
||||
log.info("loading local provider", { pkg: model.api.npm })
|
||||
installedPath = model.api.npm
|
||||
}
|
||||
|
||||
const mod = await import(installedPath)
|
||||
|
||||
const fn = mod[Object.keys(mod).find((key) => key.startsWith("create"))!]
|
||||
const loaded = fn({
|
||||
name: providerID,
|
||||
name: model.providerID,
|
||||
...options,
|
||||
})
|
||||
s.sdk.set(key, loaded)
|
||||
return loaded as SDK
|
||||
} catch (e) {
|
||||
throw new InitError({ providerID }, { cause: e })
|
||||
throw new InitError({ providerID: model.providerID }, { cause: e })
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -655,22 +794,25 @@ export namespace Provider {
|
|||
throw new ModelNotFoundError({ providerID, modelID, suggestions })
|
||||
}
|
||||
|
||||
const info = provider.info.models[modelID]
|
||||
const info = provider.models[modelID]
|
||||
if (!info) {
|
||||
const availableModels = Object.keys(provider.info.models)
|
||||
const availableModels = Object.keys(provider.models)
|
||||
const matches = fuzzysort.go(modelID, availableModels, { limit: 3, threshold: -10000 })
|
||||
const suggestions = matches.map((m) => m.target)
|
||||
throw new ModelNotFoundError({ providerID, modelID, suggestions })
|
||||
}
|
||||
|
||||
const npm = info.provider?.npm ?? provider.info.npm ?? info.id
|
||||
const sdk = await getSDK(npm, providerID)
|
||||
const sdk = await getSDK(info)
|
||||
|
||||
try {
|
||||
const language = provider.getModel
|
||||
? await provider.getModel(sdk, info.target, provider.options)
|
||||
: sdk.languageModel(info.target)
|
||||
? await provider.getModel(sdk, info.api.id, provider.options)
|
||||
: sdk.languageModel(info.api.id)
|
||||
log.info("found", { providerID, modelID })
|
||||
const cached: ModelWithStuff = {
|
||||
info,
|
||||
language,
|
||||
}
|
||||
s.models.set(key, {
|
||||
providerID,
|
||||
modelID,
|
||||
|
|
@ -755,7 +897,7 @@ export namespace Provider {
|
|||
}
|
||||
|
||||
const priority = ["gpt-5", "claude-sonnet-4", "big-pickle", "gemini-3-pro"]
|
||||
export function sort(models: ModelsDev.Model[]) {
|
||||
export function sort(models: Model[]) {
|
||||
return sortBy(
|
||||
models,
|
||||
[(model) => priority.findIndex((filter) => model.id.includes(filter)), "desc"],
|
||||
|
|
|
|||
Loading…
Reference in New Issue