From f00287ede0675521c783c1199675538571f977d6 Mon Sep 17 00:00:00 2001
From: Mohamed Bassem
Date: Mon, 29 Dec 2025 23:35:28 +0000
Subject: refactor: reduce duplication in compare-models tool

---
 tools/compare-models/src/inferenceClient.ts | 157 +++++++---------------------
 1 file changed, 37 insertions(+), 120 deletions(-)

(limited to 'tools/compare-models/src/inferenceClient.ts')

diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts
index 33617318..8649f715 100644
--- a/tools/compare-models/src/inferenceClient.ts
+++ b/tools/compare-models/src/inferenceClient.ts
@@ -1,128 +1,45 @@
-import OpenAI from "openai";
-import { zodResponseFormat } from "openai/helpers/zod";
+import type { InferenceClient } from "@karakeep/shared/inference";
+import {
+  OpenAIInferenceClient,
+  type OpenAIInferenceConfig,
+} from "@karakeep/shared/inference";
 import { z } from "zod";
 
-export interface InferenceOptions {
-  schema: z.ZodSchema | null;
+import { config } from "./config";
+
+export function createInferenceClient(modelName: string): InferenceClient {
+  const inferenceConfig: OpenAIInferenceConfig = {
+    apiKey: config.OPENAI_API_KEY,
+    baseURL: config.OPENAI_BASE_URL,
+    textModel: modelName,
+    imageModel: modelName, // Use same model for images if needed
+    contextLength: config.INFERENCE_CONTEXT_LENGTH,
+    maxOutputTokens: config.INFERENCE_MAX_OUTPUT_TOKENS,
+    useMaxCompletionTokens: config.INFERENCE_USE_MAX_COMPLETION_TOKENS,
+    outputSchema: "structured",
+  };
+
+  return new OpenAIInferenceClient(inferenceConfig);
 }
 
-export interface InferenceResponse {
-  response: string;
-  totalTokens: number | undefined;
-}
-
-export class InferenceClient {
-  private client: OpenAI;
-
-  constructor(apiKey: string, baseUrl?: string) {
-    this.client = new OpenAI({
-      apiKey,
-      baseURL: baseUrl,
-      defaultHeaders: {
-        "X-Title": "Karakeep Model Comparison",
-      },
-    });
-  }
-
-  async inferTags(
-    content: string,
-    model: string,
-    lang: string = "english",
-    customPrompts: string[] = [],
-  ): Promise<string[]> {
-    const tagsSchema = z.object({
-      tags: z.array(z.string()),
-    });
-
-    const response = await this.inferFromText(
-      this.buildPrompt(content, lang, customPrompts),
-      model,
-      { schema: tagsSchema },
+export async function inferTags(
+  inferenceClient: InferenceClient,
+  prompt: string,
+): Promise<string[]> {
+  const tagsSchema = z.object({
+    tags: z.array(z.string()),
+  });
+
+  const response = await inferenceClient.inferFromText(prompt, {
+    schema: tagsSchema,
+  });
+
+  const parsed = tagsSchema.safeParse(JSON.parse(response.response));
+  if (!parsed.success) {
+    throw new Error(
+      `Failed to parse model response: ${parsed.error.message}`,
     );
-
-    const parsed = tagsSchema.safeParse(
-      this.parseJsonFromResponse(response.response),
-    );
-    if (!parsed.success) {
-      throw new Error(
-        `Failed to parse model response: ${parsed.error.message}`,
-      );
-    }
-
-    return parsed.data.tags;
-  }
-
-  private async inferFromText(
-    prompt: string,
-    model: string,
-    opts: InferenceOptions,
-  ): Promise<InferenceResponse> {
-    const chatCompletion = await this.client.chat.completions.create({
-      messages: [{ role: "user", content: prompt }],
-      model: model,
-      response_format: opts.schema
-        ? zodResponseFormat(opts.schema, "schema")
-        : { type: "json_object" },
-    });
-
-    const response = chatCompletion.choices[0].message.content;
-    if (!response) {
-      throw new Error("Got no message content from model");
-    }
-
-    return {
-      response,
-      totalTokens: chatCompletion.usage?.total_tokens,
-    };
-  }
-
-  private buildPrompt(
-    content: string,
-    lang: string,
-    customPrompts: string[],
-  ): string {
-    return `
-You are an expert whose responsibility is to help with automatic tagging for a read-it-later app.
-Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are:
-- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres.
-- The tags must be in ${lang}.
-- If tag is not generic enough, don't include it.
-- The content can include text for cookie consent and privacy policy, ignore those while tagging.
-- Aim for 3-5 tags.
-- If there are no good tags, leave the array empty.
-${customPrompts.map((p) => `- ${p}`).join("\n")}
-
-<TEXT_CONTENT>
-${content}
-</TEXT_CONTENT>
-You must respond in JSON with key "tags" and the value is an array of string tags.`;
   }
 
-  private parseJsonFromResponse(response: string): unknown {
-    const trimmedResponse = response.trim();
-
-    try {
-      return JSON.parse(trimmedResponse);
-    } catch {
-      const jsonBlockRegex = /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i;
-      const match = trimmedResponse.match(jsonBlockRegex);
-
-      if (match) {
-        try {
-          return JSON.parse(match[1]);
-        } catch {}
-      }
-
-      const jsonObjectRegex = /\{[\s\S]*\}/;
-      const objectMatch = trimmedResponse.match(jsonObjectRegex);
-
-      if (objectMatch) {
-        try {
-          return JSON.parse(objectMatch[0]);
-        } catch {}
-      }
-
-      return JSON.parse(trimmedResponse);
-    }
-  }
+  return parsed.data.tags;
 }
-- 
cgit v1.2.3-70-g09d2
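
A minimal usage sketch (not part of the patch): assuming a sibling module under tools/compare-models/src imports the two exported helpers, and that the tagging prompt is now built by the caller (the old buildPrompt/parseJsonFromResponse logic no longer lives in this file), the refactored API would be wired up roughly like this:

// Illustrative sketch only; the import path and the function name
// tagWithModel below are assumptions, not part of the commit.
import { createInferenceClient, inferTags } from "./inferenceClient";

// Tag one piece of content with one of the models under comparison.
async function tagWithModel(
  modelName: string,
  prompt: string,
): Promise<string[]> {
  // All OpenAI plumbing (API key, base URL, context/output token limits)
  // now comes from @karakeep/shared/inference, configured via the tool's
  // config module.
  const client = createInferenceClient(modelName);

  // inferTags requests structured output against the { tags: string[] }
  // schema and returns the parsed tag list, throwing if parsing fails.
  return inferTags(client, prompt);
}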