| author | Mohamed Bassem <me@mbassem.com> | 2024-10-26 20:07:16 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2024-10-26 20:07:16 +0000 |
| commit | 3e727f7ba3ad157ca1ccc6100711266cae1bde23 (patch) | |
| tree | 767639f897f258886921162eb5bb1c73f318e61e /packages/shared/inference.ts | |
| parent | db45aaf1f61f57287bd2d98e73ec0a071b6caf88 (diff) | |
| download | karakeep-3e727f7ba3ad157ca1ccc6100711266cae1bde23.tar.zst | |
refactor: Move inference to the shared package
Diffstat (limited to 'packages/shared/inference.ts')
| -rw-r--r-- | packages/shared/inference.ts | 155 |
1 file changed, 155 insertions, 0 deletions
```diff
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
new file mode 100644
index 00000000..f34c2880
--- /dev/null
+++ b/packages/shared/inference.ts
@@ -0,0 +1,155 @@
+import { Ollama } from "ollama";
+import OpenAI from "openai";
+
+import serverConfig from "./config";
+import logger from "./logger";
+
+export interface InferenceResponse {
+  response: string;
+  totalTokens: number | undefined;
+}
+
+export interface InferenceClient {
+  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse>;
+}
+
+export class InferenceClientFactory {
+  static build(): InferenceClient | null {
+    if (serverConfig.inference.openAIApiKey) {
+      return new OpenAIInferenceClient();
+    }
+
+    if (serverConfig.inference.ollamaBaseUrl) {
+      return new OllamaInferenceClient();
+    }
+    return null;
+  }
+}
+
+class OpenAIInferenceClient implements InferenceClient {
+  openAI: OpenAI;
+
+  constructor() {
+    this.openAI = new OpenAI({
+      apiKey: serverConfig.inference.openAIApiKey,
+      baseURL: serverConfig.inference.openAIBaseUrl,
+    });
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      messages: [{ role: "user", content: prompt }],
+      model: serverConfig.inference.textModel,
+      response_format: { type: "json_object" },
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+
+  async inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      model: serverConfig.inference.imageModel,
+      response_format: { type: "json_object" },
+      messages: [
+        {
+          role: "user",
+          content: [
+            { type: "text", text: prompt },
+            {
+              type: "image_url",
+              image_url: {
+                url: `data:${contentType};base64,${image}`,
+                detail: "low",
+              },
+            },
+          ],
+        },
+      ],
+      max_tokens: 2000,
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+}
+
+class OllamaInferenceClient implements InferenceClient {
+  ollama: Ollama;
+
+  constructor() {
+    this.ollama = new Ollama({
+      host: serverConfig.inference.ollamaBaseUrl,
+    });
+  }
+
+  async runModel(model: string, prompt: string, image?: string) {
+    const chatCompletion = await this.ollama.chat({
+      model: model,
+      format: "json",
+      stream: true,
+      keep_alive: serverConfig.inference.ollamaKeepAlive,
+      options: {
+        num_ctx: serverConfig.inference.contextLength,
+      },
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
+    });
+
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seems to be a bug in ollama where you can get a successful response, but it still throws an error.
+      // Using stream + accumulating the response so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(
+        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
+      );
+    }
+
+    return { response, totalTokens };
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
+  }
+
+  async inferFromImage(
+    prompt: string,
+    _contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    return await this.runModel(
+      serverConfig.inference.imageModel,
+      prompt,
+      image,
+    );
+  }
+}
```
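For context, here is a minimal sketch of how a consumer might call the relocated client. Only `InferenceClientFactory.build()`, `inferFromText()`, and the `InferenceResponse` shape come from the diff above; the import path `@hoarder/shared/inference` and the `summarize` helper are assumptions for illustration. Note that both backends request JSON output (`response_format: { type: "json_object" }` for OpenAI, `format: "json"` for Ollama), so prompts should ask for a JSON object.

```typescript
// Hypothetical consumer of the shared inference client.
// The import path below is an assumption based on the new file location.
import { InferenceClientFactory } from "@hoarder/shared/inference";

async function summarize(text: string): Promise<string | null> {
  // build() returns null when neither OpenAI nor Ollama is configured.
  const client = InferenceClientFactory.build();
  if (!client) {
    return null;
  }

  // Both clients are configured to return JSON, so ask for a JSON object.
  const { response, totalTokens } = await client.inferFromText(
    `Summarize the following text and reply as a JSON object: ${text}`,
  );
  console.log(`Inference used ${totalTokens ?? "unknown"} tokens`);
  return response;
}
```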
