author    Mohamed Bassem <me@mbassem.com>  2024-10-26 20:07:16 +0000
committer Mohamed Bassem <me@mbassem.com>  2024-10-26 20:07:16 +0000
commit    3e727f7ba3ad157ca1ccc6100711266cae1bde23 (patch)
tree      767639f897f258886921162eb5bb1c73f318e61e /packages/shared
parent    db45aaf1f61f57287bd2d98e73ec0a071b6caf88 (diff)
download  karakeep-3e727f7ba3ad157ca1ccc6100711266cae1bde23.tar.zst
refactor: Move inference to the shared package
Diffstat (limited to 'packages/shared')
-rw-r--r--  packages/shared/inference.ts   155
-rw-r--r--  packages/shared/package.json     2
2 files changed, 157 insertions, 0 deletions
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
new file mode 100644
index 00000000..f34c2880
--- /dev/null
+++ b/packages/shared/inference.ts
@@ -0,0 +1,155 @@
+import { Ollama } from "ollama";
+import OpenAI from "openai";
+
+import serverConfig from "./config";
+import logger from "./logger";
+
+export interface InferenceResponse {
+  response: string;
+  totalTokens: number | undefined;
+}
+
+export interface InferenceClient {
+  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse>;
+}
+
+export class InferenceClientFactory {
+  static build(): InferenceClient | null {
+    if (serverConfig.inference.openAIApiKey) {
+      return new OpenAIInferenceClient();
+    }
+
+    if (serverConfig.inference.ollamaBaseUrl) {
+      return new OllamaInferenceClient();
+    }
+    return null;
+  }
+}
+
+class OpenAIInferenceClient implements InferenceClient {
+  openAI: OpenAI;
+
+  constructor() {
+    this.openAI = new OpenAI({
+      apiKey: serverConfig.inference.openAIApiKey,
+      baseURL: serverConfig.inference.openAIBaseUrl,
+    });
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      messages: [{ role: "user", content: prompt }],
+      model: serverConfig.inference.textModel,
+      response_format: { type: "json_object" },
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+
+  async inferFromImage(
+    prompt: string,
+    contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    const chatCompletion = await this.openAI.chat.completions.create({
+      model: serverConfig.inference.imageModel,
+      response_format: { type: "json_object" },
+      messages: [
+        {
+          role: "user",
+          content: [
+            { type: "text", text: prompt },
+            {
+              type: "image_url",
+              image_url: {
+                url: `data:${contentType};base64,${image}`,
+                detail: "low",
+              },
+            },
+          ],
+        },
+      ],
+      max_tokens: 2000,
+    });
+
+    const response = chatCompletion.choices[0].message.content;
+    if (!response) {
+      throw new Error(`Got no message content from OpenAI`);
+    }
+    return { response, totalTokens: chatCompletion.usage?.total_tokens };
+  }
+}
+
+class OllamaInferenceClient implements InferenceClient {
+  ollama: Ollama;
+
+  constructor() {
+    this.ollama = new Ollama({
+      host: serverConfig.inference.ollamaBaseUrl,
+    });
+  }
+
+  async runModel(model: string, prompt: string, image?: string) {
+    const chatCompletion = await this.ollama.chat({
+      model: model,
+      format: "json",
+      stream: true,
+      keep_alive: serverConfig.inference.ollamaKeepAlive,
+      options: {
+        num_ctx: serverConfig.inference.contextLength,
+      },
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
+    });
+
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seems to be a bug in ollama where you can get a successful response but still have an error thrown.
+      // Using stream + accumulating the response so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(
+        `Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
+      );
+    }
+
+    return { response, totalTokens };
+  }
+
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
+  }
+
+  async inferFromImage(
+    prompt: string,
+    _contentType: string,
+    image: string,
+  ): Promise<InferenceResponse> {
+    return await this.runModel(
+      serverConfig.inference.imageModel,
+      prompt,
+      image,
+    );
+  }
+}
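
For context on how this shared module is meant to be consumed: callers obtain a client via InferenceClientFactory.build(), which returns null when neither an OpenAI API key nor an Ollama base URL is configured, and both backends request JSON output. A minimal usage sketch follows; the import path, function name, and prompt are illustrative and not part of this commit:

import { InferenceClientFactory } from "@hoarder/shared/inference"; // assumed export path

async function summarizeToJson(text: string) {
  const client = InferenceClientFactory.build();
  if (!client) {
    // Neither OpenAI nor Ollama is configured; inference is disabled.
    return null;
  }
  // Both clients force JSON output, so the prompt should ask for JSON.
  const { response, totalTokens } = await client.inferFromText(
    `Reply with a JSON object {"summary": string} for the following text:\n${text}`,
  );
  return { summary: JSON.parse(response) as { summary: string }, totalTokens };
}
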
diff --git a/packages/shared/package.json b/packages/shared/package.json
index 69d93075..f6774263 100644
--- a/packages/shared/package.json
+++ b/packages/shared/package.json
@@ -8,6 +8,8 @@
"@hoarder/queue": "workspace:^0.1.0",
"glob": "^11.0.0",
"meilisearch": "^0.37.0",
+ "ollama": "^0.5.9",
+ "openai": "^4.67.1",
"winston": "^3.11.0",
"zod": "^3.22.4"
},
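
For reference, this is the config surface the module reads from serverConfig.inference, reconstructed purely from the fields referenced in the diff above; the authoritative definition lives in packages/shared/config.ts and its exact types and defaults may differ:

// Hypothetical shape, inferred from usage in inference.ts above.
interface InferenceConfig {
  openAIApiKey?: string; // selects the OpenAI client when set
  openAIBaseUrl?: string; // optional override for OpenAI-compatible endpoints
  ollamaBaseUrl?: string; // selects the Ollama client when no OpenAI key is set
  ollamaKeepAlive?: string | number; // passed to ollama.chat as keep_alive
  textModel: string; // model used by inferFromText
  imageModel: string; // model used by inferFromImage
  contextLength: number; // Ollama num_ctx
}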