From 9986746aa890f2490ff18fd4fc79be4de0e4dbe2 Mon Sep 17 00:00:00 2001
From: MohamedBassem
Date: Wed, 27 Mar 2024 16:30:27 +0000
Subject: fix: Attempt to increase the reliability of the ollama inference

---
 apps/workers/inference.ts    | 45 +++++++++++++++++++++++++++++++-------------
 apps/workers/openaiWorker.ts | 11 ++++++++---
 2 files changed, 40 insertions(+), 16 deletions(-)

diff --git a/apps/workers/inference.ts b/apps/workers/inference.ts
index c622dd54..3b0b5943 100644
--- a/apps/workers/inference.ts
+++ b/apps/workers/inference.ts
@@ -2,6 +2,7 @@ import { Ollama } from "ollama";
 import OpenAI from "openai";
 
 import serverConfig from "@hoarder/shared/config";
+import logger from "@hoarder/shared/logger";
 
 export interface InferenceResponse {
   response: string;
@@ -96,16 +97,41 @@
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async runModel(model: string, prompt: string, image?: string) {
     const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.textModel,
+      model: model,
       format: "json",
-      messages: [{ role: "system", content: prompt }],
+      stream: true,
+      messages: [
+        { role: "user", content: prompt, images: image ? [image] : undefined },
+      ],
     });
 
-    const response = chatCompletion.message.content;
+    let totalTokens = 0;
+    let response = "";
+    try {
+      for await (const part of chatCompletion) {
+        response += part.message.content;
+        if (!isNaN(part.eval_count)) {
+          totalTokens += part.eval_count;
+        }
+        if (!isNaN(part.prompt_eval_count)) {
+          totalTokens += part.prompt_eval_count;
+        }
+      }
+    } catch (e) {
+      // There seems to be a bug in ollama where you can get a successful response but still have an error thrown.
+      // Streaming and accumulating the response received so far is a workaround.
+      // https://github.com/ollama/ollama-js/issues/72
+      totalTokens = NaN;
+      logger.warn(`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`);
+    }
+
+    return { response, totalTokens };
+  }
 
-    return { response, totalTokens: chatCompletion.eval_count };
+  async inferFromText(prompt: string): Promise<InferenceResponse> {
+    return await this.runModel(serverConfig.inference.textModel, prompt);
   }
 
   async inferFromImage(
@@ -113,13 +139,6 @@
     _contentType: string,
     image: string,
   ): Promise<InferenceResponse> {
-    const chatCompletion = await this.ollama.chat({
-      model: serverConfig.inference.imageModel,
-      format: "json",
-      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
-    });
-
-    const response = chatCompletion.message.content;
-    return { response, totalTokens: chatCompletion.eval_count };
+    return await this.runModel(serverConfig.inference.imageModel, prompt, image);
   }
 }
diff --git a/apps/workers/openaiWorker.ts b/apps/workers/openaiWorker.ts
index b706fb90..9b2934e3 100644
--- a/apps/workers/openaiWorker.ts
+++ b/apps/workers/openaiWorker.ts
@@ -14,7 +14,7 @@ import {
   zOpenAIRequestSchema,
 } from "@hoarder/shared/queues";
 
-import { InferenceClientFactory, InferenceClient } from "./inference";
+import { InferenceClient, InferenceClientFactory } from "./inference";
 
 const openAIResponseSchema = z.object({
   tags: z.array(z.string()),
@@ -36,7 +36,7 @@ async function attemptMarkTaggingStatus(
       })
       .where(eq(bookmarks.id, request.bookmarkId));
   } catch (e) {
-    console.log(`Something went wrong when marking the tagging status: ${e}`);
+    logger.error(`Something went wrong when marking the tagging status: ${e}`);
   }
 }
 
@@ -196,8 +196,9 @@ async function inferTags(
     return tags;
   } catch (e) {
+    const responseSneak = response.response.substr(0, 20);
     throw new Error(
-      `[inference][${jobId}] Failed to parse JSON response from inference client: ${e}`,
+      `[inference][${jobId}] The model ignored our prompt and didn't respond with the expected JSON: ${JSON.stringify(e)}. Here's a sneak peek from the response: ${responseSneak}`,
     );
   }
 }
 
@@ -285,6 +286,10 @@ async function runOpenAI(job: Job) {
     );
   }
 
+  logger.info(
+    `[inference][${jobId}] Starting an inference job for bookmark with id "${bookmark.id}"`,
+  );
+
   const tags = await inferTags(jobId, bookmark, inferenceClient);
 
   await connectTags(bookmarkId, tags, bookmark.userId);
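
For readers who want the core of this change in isolation, the sketch below shows the stream-and-accumulate workaround as a standalone function, assuming only the `ollama` npm package. The host, model name, and function name are illustrative placeholders, not part of this patch:

import { Ollama } from "ollama";

// Minimal sketch of the workaround applied above: stream the chat completion
// and accumulate the partial chunks, so that when ollama-js throws after
// already producing useful output (https://github.com/ollama/ollama-js/issues/72)
// the content received so far is kept instead of discarded.
async function chatWithPartialRecovery(prompt: string): Promise<string> {
  // Host and model are illustrative placeholders.
  const ollama = new Ollama({ host: "http://localhost:11434" });
  const stream = await ollama.chat({
    model: "llama2",
    format: "json",
    stream: true,
    messages: [{ role: "user", content: prompt }],
  });

  let response = "";
  try {
    for await (const part of stream) {
      response += part.message.content;
    }
  } catch (e) {
    // Keep whatever was streamed before the error instead of failing outright.
    console.warn(`ollama stream errored; keeping partial response: ${e}`);
  }
  return response;
}

The trade-off mirrors the patch itself: token counts from a failed stream are unreliable (the patch sets totalTokens to NaN in that case), and the accumulated text may still be truncated JSON, so the caller must still handle a parse failure.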