diff options
| author | Mohamed Bassem <me@mbassem.com> | 2025-12-29 23:35:28 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2025-12-29 23:38:21 +0000 |
| commit | f00287ede0675521c783c1199675538571f977d6 (patch) | |
| tree | 2d04b983fa514f4c62a3695c0a521fb50de24eef /tools/compare-models/src | |
| parent | ba8d84a555f9e6cf209c826b97a124f0539739eb (diff) | |
| download | karakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst | |
refactor: reduce duplication in compare-models tool
Diffstat (limited to 'tools/compare-models/src')
| -rw-r--r-- | tools/compare-models/src/apiClient.ts | 8 | ||||
| -rw-r--r-- | tools/compare-models/src/bookmarkProcessor.ts | 20 | ||||
| -rw-r--r-- | tools/compare-models/src/config.ts | 19 | ||||
| -rw-r--r-- | tools/compare-models/src/index.ts | 110 | ||||
| -rw-r--r-- | tools/compare-models/src/inferenceClient.ts | 157 | ||||
| -rw-r--r-- | tools/compare-models/src/types.ts | 3 |
6 files changed, 164 insertions, 153 deletions
diff --git a/tools/compare-models/src/apiClient.ts b/tools/compare-models/src/apiClient.ts index f3a960cb..1d9f799d 100644 --- a/tools/compare-models/src/apiClient.ts +++ b/tools/compare-models/src/apiClient.ts @@ -53,7 +53,13 @@ export class KarakeepAPIClient { const batchBookmarks = (data?.bookmarks || []) .filter((b) => b.content?.type === "link") - .map((b) => b as Bookmark); + .map((b) => ({ + ...b, + tags: (b.tags || []).map((tag) => ({ + name: tag.name, + attachedBy: tag.attachedBy, + })), + })) as Bookmark[]; bookmarks.push(...batchBookmarks); cursor = data?.nextCursor || null; diff --git a/tools/compare-models/src/bookmarkProcessor.ts b/tools/compare-models/src/bookmarkProcessor.ts index 910957fe..21280b97 100644 --- a/tools/compare-models/src/bookmarkProcessor.ts +++ b/tools/compare-models/src/bookmarkProcessor.ts @@ -1,4 +1,7 @@ -import type { InferenceClient } from "./inferenceClient"; +import type { InferenceClient } from "@karakeep/shared/inference"; +import { buildTextPrompt } from "@karakeep/shared/prompts"; + +import { inferTags } from "./inferenceClient"; import type { Bookmark } from "./types"; export async function extractBookmarkContent( @@ -35,9 +38,9 @@ export async function extractBookmarkContent( export async function runTaggingForModel( bookmark: Bookmark, - model: string, inferenceClient: InferenceClient, lang: string = "english", + contextLength: number = 8000, ): Promise<string[]> { const content = await extractBookmarkContent(bookmark); @@ -46,11 +49,20 @@ export async function runTaggingForModel( } try { - const tags = await inferenceClient.inferTags(content, model, lang, []); + // Use the shared prompt builder with empty custom prompts and default tag style + const prompt = await buildTextPrompt( + lang, + [], // No custom prompts for comparison tool + content, + contextLength, + "as-generated", // Use tags as generated by the model + ); + + const tags = await inferTags(inferenceClient, prompt); return tags; } catch (error) { throw new Error( - `Failed to generate tags with ${model}: ${error instanceof Error ? error.message : String(error)}`, + `Failed to generate tags: ${error instanceof Error ? error.message : String(error)}`, ); } } diff --git a/tools/compare-models/src/config.ts b/tools/compare-models/src/config.ts index 9c32610d..0b5d217f 100644 --- a/tools/compare-models/src/config.ts +++ b/tools/compare-models/src/config.ts @@ -1,16 +1,33 @@ import { z } from "zod"; +// Local config schema for compare-models tool const envSchema = z.object({ KARAKEEP_API_KEY: z.string().min(1), KARAKEEP_SERVER_ADDR: z.string().url(), MODEL1_NAME: z.string().min(1), - MODEL2_NAME: z.string().min(1), + MODEL2_NAME: z.string().min(1).optional(), OPENAI_API_KEY: z.string().min(1), OPENAI_BASE_URL: z.string().url().optional(), + COMPARISON_MODE: z + .enum(["model-vs-model", "model-vs-existing"]) + .default("model-vs-model"), COMPARE_LIMIT: z .string() .optional() .transform((val) => (val ? parseInt(val, 10) : 10)), + INFERENCE_CONTEXT_LENGTH: z + .string() + .optional() + .transform((val) => (val ? parseInt(val, 10) : 8000)), + INFERENCE_MAX_OUTPUT_TOKENS: z + .string() + .optional() + .transform((val) => (val ? parseInt(val, 10) : 2048)), + INFERENCE_USE_MAX_COMPLETION_TOKENS: z + .string() + .optional() + .transform((val) => val === "true") + .default("false"), }); export const config = envSchema.parse(process.env); diff --git a/tools/compare-models/src/index.ts b/tools/compare-models/src/index.ts index c1a80ab5..88fc9249 100644 --- a/tools/compare-models/src/index.ts +++ b/tools/compare-models/src/index.ts @@ -4,7 +4,7 @@ import type { ComparisonResult } from "./types"; import { KarakeepAPIClient } from "./apiClient"; import { runTaggingForModel } from "./bookmarkProcessor"; import { config } from "./config"; -import { InferenceClient } from "./inferenceClient"; +import { createInferenceClient } from "./inferenceClient"; import { askQuestion, clearProgress, @@ -32,18 +32,58 @@ interface ShuffleResult { async function main() { console.log(chalk.cyan("\nš Karakeep Model Comparison Tool\n")); - const inferenceClient = new InferenceClient( - config.OPENAI_API_KEY, - config.OPENAI_BASE_URL, - ); + const isExistingMode = config.COMPARISON_MODE === "model-vs-existing"; + + if (isExistingMode) { + console.log( + chalk.yellow( + `Mode: Comparing ${config.MODEL1_NAME} against existing AI tags\n`, + ), + ); + } else { + if (!config.MODEL2_NAME) { + console.log( + chalk.red( + "\nā Error: MODEL2_NAME is required for model-vs-model comparison mode\n", + ), + ); + process.exit(1); + } + console.log( + chalk.yellow( + `Mode: Comparing ${config.MODEL1_NAME} vs ${config.MODEL2_NAME}\n`, + ), + ); + } const apiClient = new KarakeepAPIClient(); displayProgress("Fetching bookmarks from Karakeep..."); - const bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT); + let bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT); clearProgress(); - console.log(chalk.green(`ā Fetched ${bookmarks.length} link bookmarks\n`)); + // Filter bookmarks with AI tags if in existing mode + if (isExistingMode) { + bookmarks = bookmarks.filter( + (b) => b.tags.some((t) => t.attachedBy === "ai"), + ); + console.log( + chalk.green( + `ā Fetched ${bookmarks.length} link bookmarks with existing AI tags\n`, + ), + ); + } else { + console.log(chalk.green(`ā Fetched ${bookmarks.length} link bookmarks\n`)); + } + + if (bookmarks.length === 0) { + console.log( + chalk.yellow( + "\nā No bookmarks found with AI tags. Please add some bookmarks with AI tags first.\n", + ), + ); + return; + } const counters: VoteCounters = { model1Votes: 0, @@ -59,17 +99,20 @@ async function main() { const bookmark = bookmarks[i]; displayProgress( - `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || "Untitled"}`, + `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || bookmark.content.title || "Untitled"}`, ); let model1Tags: string[] = []; let model2Tags: string[] = []; + // Get tags for model 1 (new model) try { + const model1Client = createInferenceClient(config.MODEL1_NAME); model1Tags = await runTaggingForModel( bookmark, - config.MODEL1_NAME, - inferenceClient, + model1Client, + "english", + config.INFERENCE_CONTEXT_LENGTH, ); } catch (error) { clearProgress(); @@ -80,31 +123,46 @@ async function main() { continue; } - try { - model2Tags = await runTaggingForModel( - bookmark, - config.MODEL2_NAME, - inferenceClient, - ); - } catch (error) { - clearProgress(); - displayError( - `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`, - ); - counters.errors++; - continue; + // Get tags for model 2 or existing AI tags + if (isExistingMode) { + // Use existing AI tags from the bookmark + model2Tags = bookmark.tags + .filter((t) => t.attachedBy === "ai") + .map((t) => t.name); + } else { + // Run inference with model 2 + try { + const model2Client = createInferenceClient(config.MODEL2_NAME!); + model2Tags = await runTaggingForModel( + bookmark, + model2Client, + "english", + config.INFERENCE_CONTEXT_LENGTH, + ); + } catch (error) { + clearProgress(); + displayError( + `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`, + ); + counters.errors++; + continue; + } } clearProgress(); + const model2Label = isExistingMode + ? "Existing AI Tags" + : config.MODEL2_NAME!; + const shuffleResult: ShuffleResult = { modelA: config.MODEL1_NAME, - modelB: config.MODEL2_NAME, + modelB: model2Label, modelAIsModel1: Math.random() < 0.5, }; if (!shuffleResult.modelAIsModel1) { - shuffleResult.modelA = config.MODEL2_NAME; + shuffleResult.modelA = model2Label; shuffleResult.modelB = config.MODEL1_NAME; } @@ -156,7 +214,7 @@ async function main() { displayFinalResults({ model1Name: config.MODEL1_NAME, - model2Name: config.MODEL2_NAME, + model2Name: isExistingMode ? "Existing AI Tags" : config.MODEL2_NAME!, model1Votes: counters.model1Votes, model2Votes: counters.model2Votes, skipped: counters.skipped, diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts index 33617318..8649f715 100644 --- a/tools/compare-models/src/inferenceClient.ts +++ b/tools/compare-models/src/inferenceClient.ts @@ -1,128 +1,45 @@ -import OpenAI from "openai"; -import { zodResponseFormat } from "openai/helpers/zod"; +import type { InferenceClient } from "@karakeep/shared/inference"; +import { + OpenAIInferenceClient, + type OpenAIInferenceConfig, +} from "@karakeep/shared/inference"; import { z } from "zod"; -export interface InferenceOptions { - schema: z.ZodSchema<any> | null; +import { config } from "./config"; + +export function createInferenceClient(modelName: string): InferenceClient { + const inferenceConfig: OpenAIInferenceConfig = { + apiKey: config.OPENAI_API_KEY, + baseURL: config.OPENAI_BASE_URL, + textModel: modelName, + imageModel: modelName, // Use same model for images if needed + contextLength: config.INFERENCE_CONTEXT_LENGTH, + maxOutputTokens: config.INFERENCE_MAX_OUTPUT_TOKENS, + useMaxCompletionTokens: config.INFERENCE_USE_MAX_COMPLETION_TOKENS, + outputSchema: "structured", + }; + + return new OpenAIInferenceClient(inferenceConfig); } -export interface InferenceResponse { - response: string; - totalTokens: number | undefined; -} - -export class InferenceClient { - private client: OpenAI; - - constructor(apiKey: string, baseUrl?: string) { - this.client = new OpenAI({ - apiKey, - baseURL: baseUrl, - defaultHeaders: { - "X-Title": "Karakeep Model Comparison", - }, - }); - } - - async inferTags( - content: string, - model: string, - lang: string = "english", - customPrompts: string[] = [], - ): Promise<string[]> { - const tagsSchema = z.object({ - tags: z.array(z.string()), - }); - - const response = await this.inferFromText( - this.buildPrompt(content, lang, customPrompts), - model, - { schema: tagsSchema }, +export async function inferTags( + inferenceClient: InferenceClient, + prompt: string, +): Promise<string[]> { + const tagsSchema = z.object({ + tags: z.array(z.string()), + }); + + const response = await inferenceClient.inferFromText(prompt, { + schema: tagsSchema, + }); + + const parsed = tagsSchema.safeParse(JSON.parse(response.response)); + if (!parsed.success) { + throw new Error( + `Failed to parse model response: ${parsed.error.message}`, ); - - const parsed = tagsSchema.safeParse( - this.parseJsonFromResponse(response.response), - ); - if (!parsed.success) { - throw new Error( - `Failed to parse model response: ${parsed.error.message}`, - ); - } - - return parsed.data.tags; - } - - private async inferFromText( - prompt: string, - model: string, - opts: InferenceOptions, - ): Promise<InferenceResponse> { - const chatCompletion = await this.client.chat.completions.create({ - messages: [{ role: "user", content: prompt }], - model: model, - response_format: opts.schema - ? zodResponseFormat(opts.schema, "schema") - : { type: "json_object" }, - }); - - const response = chatCompletion.choices[0].message.content; - if (!response) { - throw new Error("Got no message content from model"); - } - - return { - response, - totalTokens: chatCompletion.usage?.total_tokens, - }; - } - - private buildPrompt( - content: string, - lang: string, - customPrompts: string[], - ): string { - return ` -You are an expert whose responsibility is to help with automatic tagging for a read-it-later app. -Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are: -- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. -- The tags must be in ${lang}. -- If tag is not generic enough, don't include it. -- The content can include text for cookie consent and privacy policy, ignore those while tagging. -- Aim for 3-5 tags. -- If there are no good tags, leave the array empty. -${customPrompts.map((p) => `- ${p}`).join("\n")} - -<TEXT_CONTENT> -${content} -</TEXT_CONTENT> -You must respond in JSON with key "tags" and the value is an array of string tags.`; } - private parseJsonFromResponse(response: string): unknown { - const trimmedResponse = response.trim(); - - try { - return JSON.parse(trimmedResponse); - } catch { - const jsonBlockRegex = /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i; - const match = trimmedResponse.match(jsonBlockRegex); - - if (match) { - try { - return JSON.parse(match[1]); - } catch {} - } - - const jsonObjectRegex = /\{[\s\S]*\}/; - const objectMatch = trimmedResponse.match(jsonObjectRegex); - - if (objectMatch) { - try { - return JSON.parse(objectMatch[0]); - } catch {} - } - - return JSON.parse(trimmedResponse); - } - } + return parsed.data.tags; } diff --git a/tools/compare-models/src/types.ts b/tools/compare-models/src/types.ts index b8bdc024..35a677ae 100644 --- a/tools/compare-models/src/types.ts +++ b/tools/compare-models/src/types.ts @@ -3,12 +3,13 @@ export interface Bookmark { title: string | null; content: { type: string; + title: string; url?: string; text?: string; htmlContent?: string; description?: string; }; - tags: Array<{ name: string }>; + tags: Array<{ name: string; attachedBy?: "ai" | "human" }>; } export interface ModelConfig { |
