about summary refs log tree commit diff stats
path: root/tools/compare-models/src/inferenceClient.ts
diff options
context:
space:
mode:
author: Mohamed Bassem <me@mbassem.com> 2025-12-29 23:35:28 +0000
committer: Mohamed Bassem <me@mbassem.com> 2025-12-29 23:38:21 +0000
commit: f00287ede0675521c783c1199675538571f977d6 (patch)
tree: 2d04b983fa514f4c62a3695c0a521fb50de24eef /tools/compare-models/src/inferenceClient.ts
parent: ba8d84a555f9e6cf209c826b97a124f0539739eb (diff)
download: karakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst
refactor: reduce duplication in compare-models tool
Diffstat (limited to 'tools/compare-models/src/inferenceClient.ts')
-rw-r--r-- tools/compare-models/src/inferenceClient.ts | 157
1 file changed, 37 insertions(+), 120 deletions(-)
diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts
index 33617318..8649f715 100644
--- a/tools/compare-models/src/inferenceClient.ts
+++ b/tools/compare-models/src/inferenceClient.ts
@@ -1,128 +1,45 @@
-import OpenAI from "openai";
-import { zodResponseFormat } from "openai/helpers/zod";
+import type { InferenceClient } from "@karakeep/shared/inference";
+import {
+ OpenAIInferenceClient,
+ type OpenAIInferenceConfig,
+} from "@karakeep/shared/inference";
import { z } from "zod";
-export interface InferenceOptions {
- schema: z.ZodSchema<any> | null;
+import { config } from "./config";
+
+export function createInferenceClient(modelName: string): InferenceClient {
+ const inferenceConfig: OpenAIInferenceConfig = {
+ apiKey: config.OPENAI_API_KEY,
+ baseURL: config.OPENAI_BASE_URL,
+ textModel: modelName,
+ imageModel: modelName, // Use same model for images if needed
+ contextLength: config.INFERENCE_CONTEXT_LENGTH,
+ maxOutputTokens: config.INFERENCE_MAX_OUTPUT_TOKENS,
+ useMaxCompletionTokens: config.INFERENCE_USE_MAX_COMPLETION_TOKENS,
+ outputSchema: "structured",
+ };
+
+ return new OpenAIInferenceClient(inferenceConfig);
}
-export interface InferenceResponse {
- response: string;
- totalTokens: number | undefined;
-}
-
-export class InferenceClient {
- private client: OpenAI;
-
- constructor(apiKey: string, baseUrl?: string) {
- this.client = new OpenAI({
- apiKey,
- baseURL: baseUrl,
- defaultHeaders: {
- "X-Title": "Karakeep Model Comparison",
- },
- });
- }
-
- async inferTags(
- content: string,
- model: string,
- lang: string = "english",
- customPrompts: string[] = [],
- ): Promise<string[]> {
- const tagsSchema = z.object({
- tags: z.array(z.string()),
- });
-
- const response = await this.inferFromText(
- this.buildPrompt(content, lang, customPrompts),
- model,
- { schema: tagsSchema },
+export async function inferTags(
+ inferenceClient: InferenceClient,
+ prompt: string,
+): Promise<string[]> {
+ const tagsSchema = z.object({
+ tags: z.array(z.string()),
+ });
+
+ const response = await inferenceClient.inferFromText(prompt, {
+ schema: tagsSchema,
+ });
+
+ const parsed = tagsSchema.safeParse(JSON.parse(response.response));
+ if (!parsed.success) {
+ throw new Error(
+ `Failed to parse model response: ${parsed.error.message}`,
);
-
- const parsed = tagsSchema.safeParse(
- this.parseJsonFromResponse(response.response),
- );
- if (!parsed.success) {
- throw new Error(
- `Failed to parse model response: ${parsed.error.message}`,
- );
- }
-
- return parsed.data.tags;
- }
-
- private async inferFromText(
- prompt: string,
- model: string,
- opts: InferenceOptions,
- ): Promise<InferenceResponse> {
- const chatCompletion = await this.client.chat.completions.create({
- messages: [{ role: "user", content: prompt }],
- model: model,
- response_format: opts.schema
- ? zodResponseFormat(opts.schema, "schema")
- : { type: "json_object" },
- });
-
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error("Got no message content from model");
- }
-
- return {
- response,
- totalTokens: chatCompletion.usage?.total_tokens,
- };
- }
-
- private buildPrompt(
- content: string,
- lang: string,
- customPrompts: string[],
- ): string {
- return `
-You are an expert whose responsibility is to help with automatic tagging for a read-it-later app.
-Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are:
-- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres.
-- The tags must be in ${lang}.
-- If tag is not generic enough, don't include it.
-- The content can include text for cookie consent and privacy policy, ignore those while tagging.
-- Aim for 3-5 tags.
-- If there are no good tags, leave the array empty.
-${customPrompts.map((p) => `- ${p}`).join("\n")}
-
-<TEXT_CONTENT>
-${content}
-</TEXT_CONTENT>
-You must respond in JSON with key "tags" and the value is an array of string tags.`;
}
- private parseJsonFromResponse(response: string): unknown {
- const trimmedResponse = response.trim();
-
- try {
- return JSON.parse(trimmedResponse);
- } catch {
- const jsonBlockRegex = /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i;
- const match = trimmedResponse.match(jsonBlockRegex);
-
- if (match) {
- try {
- return JSON.parse(match[1]);
- } catch {}
- }
-
- const jsonObjectRegex = /\{[\s\S]*\}/;
- const objectMatch = trimmedResponse.match(jsonObjectRegex);
-
- if (objectMatch) {
- try {
- return JSON.parse(objectMatch[0]);
- } catch {}
- }
-
- return JSON.parse(trimmedResponse);
- }
- }
+ return parsed.data.tags;
}