| author | Mohamed Bassem <me@mbassem.com> | 2025-12-29 23:35:28 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2025-12-29 23:38:21 +0000 |
| commit | f00287ede0675521c783c1199675538571f977d6 (patch) | |
| tree | 2d04b983fa514f4c62a3695c0a521fb50de24eef | /packages/shared |
| parent | ba8d84a555f9e6cf209c826b97a124f0539739eb (diff) | |
| download | karakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst | |
refactor: reduce duplication in compare-models tool
Diffstat (limited to 'packages/shared')
| -rw-r--r-- | packages/shared/inference.ts | 107 |
1 file changed, 79 insertions, 28 deletions
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index fb9fce09..d6a9aa10 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -52,34 +52,47 @@ const mapInferenceOutputSchema = <
   return opts[type];
 };
 
+export interface OpenAIInferenceConfig {
+  apiKey: string;
+  baseURL?: string;
+  proxyUrl?: string;
+  textModel: string;
+  imageModel: string;
+  contextLength: number;
+  maxOutputTokens: number;
+  useMaxCompletionTokens: boolean;
+  outputSchema: "structured" | "json" | "plain";
+}
+
 export class InferenceClientFactory {
   static build(): InferenceClient | null {
     if (serverConfig.inference.openAIApiKey) {
-      return new OpenAIInferenceClient();
+      return OpenAIInferenceClient.fromConfig();
     }
 
     if (serverConfig.inference.ollamaBaseUrl) {
-      return new OllamaInferenceClient();
+      return OllamaInferenceClient.fromConfig();
     }
 
     return null;
   }
 }
 
-class OpenAIInferenceClient implements InferenceClient {
+export class OpenAIInferenceClient implements InferenceClient {
   openAI: OpenAI;
+  private config: OpenAIInferenceConfig;
 
-  constructor() {
-    const fetchOptions = serverConfig.inference.openAIProxyUrl
+  constructor(config: OpenAIInferenceConfig) {
+    this.config = config;
+
+    const fetchOptions = config.proxyUrl
       ? {
-          dispatcher: new undici.ProxyAgent(
-            serverConfig.inference.openAIProxyUrl,
-          ),
+          dispatcher: new undici.ProxyAgent(config.proxyUrl),
         }
       : undefined;
 
     this.openAI = new OpenAI({
-      apiKey: serverConfig.inference.openAIApiKey,
-      baseURL: serverConfig.inference.openAIBaseUrl,
+      apiKey: config.apiKey,
+      baseURL: config.baseURL,
       ...(fetchOptions ? { fetchOptions } : {}),
       defaultHeaders: {
         "X-Title": "Karakeep",
@@ -88,6 +101,20 @@ class OpenAIInferenceClient implements InferenceClient {
     });
   }
 
+  static fromConfig(): OpenAIInferenceClient {
+    return new OpenAIInferenceClient({
+      apiKey: serverConfig.inference.openAIApiKey!,
+      baseURL: serverConfig.inference.openAIBaseUrl,
+      proxyUrl: serverConfig.inference.openAIProxyUrl,
+      textModel: serverConfig.inference.textModel,
+      imageModel: serverConfig.inference.imageModel,
+      contextLength: serverConfig.inference.contextLength,
+      maxOutputTokens: serverConfig.inference.maxOutputTokens,
+      useMaxCompletionTokens: serverConfig.inference.useMaxCompletionTokens,
+      outputSchema: serverConfig.inference.outputSchema,
+    });
+  }
+
   async inferFromText(
     prompt: string,
     _opts: Partial<InferenceOptions>,
@@ -99,10 +126,10 @@ class OpenAIInferenceClient implements InferenceClient {
     const chatCompletion = await this.openAI.chat.completions.create(
       {
         messages: [{ role: "user", content: prompt }],
-        model: serverConfig.inference.textModel,
-        ...(serverConfig.inference.useMaxCompletionTokens
-          ? { max_completion_tokens: serverConfig.inference.maxOutputTokens }
-          : { max_tokens: serverConfig.inference.maxOutputTokens }),
+        model: this.config.textModel,
+        ...(this.config.useMaxCompletionTokens
+          ? { max_completion_tokens: this.config.maxOutputTokens }
+          : { max_tokens: this.config.maxOutputTokens }),
         response_format: mapInferenceOutputSchema(
           {
             structured: optsWithDefaults.schema
@@ -111,7 +138,7 @@ class OpenAIInferenceClient implements InferenceClient {
             json: { type: "json_object" },
             plain: undefined,
           },
-          serverConfig.inference.outputSchema,
+          this.config.outputSchema,
         ),
       },
       {
@@ -138,10 +165,10 @@ class OpenAIInferenceClient implements InferenceClient {
     };
     const chatCompletion = await this.openAI.chat.completions.create(
       {
-        model: serverConfig.inference.imageModel,
-        ...(serverConfig.inference.useMaxCompletionTokens
-          ? { max_completion_tokens: serverConfig.inference.maxOutputTokens }
-          : { max_tokens: serverConfig.inference.maxOutputTokens }),
+        model: this.config.imageModel,
+        ...(this.config.useMaxCompletionTokens
+          ? { max_completion_tokens: this.config.maxOutputTokens }
+          : { max_tokens: this.config.maxOutputTokens }),
         response_format: mapInferenceOutputSchema(
           {
             structured: optsWithDefaults.schema
@@ -150,7 +177,7 @@ class OpenAIInferenceClient implements InferenceClient {
             json: { type: "json_object" },
             plain: undefined,
           },
-          serverConfig.inference.outputSchema,
+          this.config.outputSchema,
         ),
         messages: [
           {
@@ -195,16 +222,40 @@ class OpenAIInferenceClient implements InferenceClient {
   }
 }
 
+export interface OllamaInferenceConfig {
+  baseUrl: string;
+  textModel: string;
+  imageModel: string;
+  contextLength: number;
+  maxOutputTokens: number;
+  keepAlive?: string;
+  outputSchema: "structured" | "json" | "plain";
+}
+
 class OllamaInferenceClient implements InferenceClient {
   ollama: Ollama;
+  private config: OllamaInferenceConfig;
 
-  constructor() {
+  constructor(config: OllamaInferenceConfig) {
+    this.config = config;
     this.ollama = new Ollama({
-      host: serverConfig.inference.ollamaBaseUrl,
+      host: config.baseUrl,
       fetch: customFetch, // Use the custom fetch with configurable timeout
     });
   }
 
+  static fromConfig(): OllamaInferenceClient {
+    return new OllamaInferenceClient({
+      baseUrl: serverConfig.inference.ollamaBaseUrl!,
+      textModel: serverConfig.inference.textModel,
+      imageModel: serverConfig.inference.imageModel,
+      contextLength: serverConfig.inference.contextLength,
+      maxOutputTokens: serverConfig.inference.maxOutputTokens,
+      keepAlive: serverConfig.inference.ollamaKeepAlive,
+      outputSchema: serverConfig.inference.outputSchema,
+    });
+  }
+
   async runModel(
     model: string,
     prompt: string,
@@ -233,13 +284,13 @@ class OllamaInferenceClient implements InferenceClient {
           json: "json",
           plain: undefined,
         },
-        serverConfig.inference.outputSchema,
+        this.config.outputSchema,
       ),
       stream: true,
-      keep_alive: serverConfig.inference.ollamaKeepAlive,
+      keep_alive: this.config.keepAlive,
      options: {
-        num_ctx: serverConfig.inference.contextLength,
-        num_predict: serverConfig.inference.maxOutputTokens,
+        num_ctx: this.config.contextLength,
+        num_predict: this.config.maxOutputTokens,
       },
       messages: [
         { role: "user", content: prompt, images: image ? [image] : undefined },
@@ -287,7 +338,7 @@ class OllamaInferenceClient implements InferenceClient {
       ..._opts,
     };
     return await this.runModel(
-      serverConfig.inference.textModel,
+      this.config.textModel,
       prompt,
       optsWithDefaults,
       undefined,
@@ -305,7 +356,7 @@ class OllamaInferenceClient implements InferenceClient {
       ..._opts,
     };
     return await this.runModel(
-      serverConfig.inference.imageModel,
+      this.config.imageModel,
       prompt,
       optsWithDefaults,
       image,
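
With the configuration pulled out of the constructors, a script such as the compare-models tool named in the commit message can build clients side by side without going through `serverConfig`. The sketch below is a hypothetical illustration of that pattern: the `OpenAIInferenceConfig` fields and the `inferFromText(prompt, opts)` signature come from the diff above, while the import path, model names, and the comparison harness itself are assumptions, not project code.

```ts
// Hypothetical comparison harness (not part of this commit) showing the new
// explicit-config constructor. Field names match OpenAIInferenceConfig above;
// the import path and model names are assumptions.
import {
  OpenAIInferenceClient,
  type OpenAIInferenceConfig,
} from "./packages/shared/inference"; // adjust to the real export location

function buildClient(model: string): OpenAIInferenceClient {
  const config: OpenAIInferenceConfig = {
    apiKey: process.env.OPENAI_API_KEY ?? "",
    textModel: model,
    imageModel: model,
    contextLength: 128_000,
    maxOutputTokens: 2_000,
    useMaxCompletionTokens: false,
    outputSchema: "json",
  };
  return new OpenAIInferenceClient(config);
}

async function compareModels(prompt: string, modelA: string, modelB: string) {
  // Two clients that differ only in the model they target.
  const [a, b] = [buildClient(modelA), buildClient(modelB)];
  const [resA, resB] = await Promise.all([
    a.inferFromText(prompt, {}),
    b.inferFromText(prompt, {}),
  ]);
  console.log(modelA, resA);
  console.log(modelB, resB);
}
```

The static `fromConfig()` helpers keep the previous behaviour for `InferenceClientFactory.build()`, which still reads the server-wide settings, so regular server usage is unchanged by the refactor.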
