aboutsummaryrefslogtreecommitdiffstats
path: root/tools/compare-models/src
diff options
context:
space:
mode:
authorMohamed Bassem <me@mbassem.com>2025-12-29 23:35:28 +0000
committerMohamed Bassem <me@mbassem.com>2025-12-29 23:38:21 +0000
commitf00287ede0675521c783c1199675538571f977d6 (patch)
tree2d04b983fa514f4c62a3695c0a521fb50de24eef /tools/compare-models/src
parentba8d84a555f9e6cf209c826b97a124f0539739eb (diff)
downloadkarakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst
refactor: reduce duplication in compare-models tool
Diffstat (limited to 'tools/compare-models/src')
-rw-r--r--tools/compare-models/src/apiClient.ts8
-rw-r--r--tools/compare-models/src/bookmarkProcessor.ts20
-rw-r--r--tools/compare-models/src/config.ts19
-rw-r--r--tools/compare-models/src/index.ts110
-rw-r--r--tools/compare-models/src/inferenceClient.ts157
-rw-r--r--tools/compare-models/src/types.ts3
6 files changed, 164 insertions, 153 deletions
diff --git a/tools/compare-models/src/apiClient.ts b/tools/compare-models/src/apiClient.ts
index f3a960cb..1d9f799d 100644
--- a/tools/compare-models/src/apiClient.ts
+++ b/tools/compare-models/src/apiClient.ts
@@ -53,7 +53,13 @@ export class KarakeepAPIClient {
const batchBookmarks = (data?.bookmarks || [])
.filter((b) => b.content?.type === "link")
- .map((b) => b as Bookmark);
+ .map((b) => ({
+ ...b,
+ tags: (b.tags || []).map((tag) => ({
+ name: tag.name,
+ attachedBy: tag.attachedBy,
+ })),
+ })) as Bookmark[];
bookmarks.push(...batchBookmarks);
cursor = data?.nextCursor || null;
diff --git a/tools/compare-models/src/bookmarkProcessor.ts b/tools/compare-models/src/bookmarkProcessor.ts
index 910957fe..21280b97 100644
--- a/tools/compare-models/src/bookmarkProcessor.ts
+++ b/tools/compare-models/src/bookmarkProcessor.ts
@@ -1,4 +1,7 @@
-import type { InferenceClient } from "./inferenceClient";
+import type { InferenceClient } from "@karakeep/shared/inference";
+import { buildTextPrompt } from "@karakeep/shared/prompts";
+
+import { inferTags } from "./inferenceClient";
import type { Bookmark } from "./types";
export async function extractBookmarkContent(
@@ -35,9 +38,9 @@ export async function extractBookmarkContent(
export async function runTaggingForModel(
bookmark: Bookmark,
- model: string,
inferenceClient: InferenceClient,
lang: string = "english",
+ contextLength: number = 8000,
): Promise<string[]> {
const content = await extractBookmarkContent(bookmark);
@@ -46,11 +49,20 @@ export async function runTaggingForModel(
}
try {
- const tags = await inferenceClient.inferTags(content, model, lang, []);
+ // Use the shared prompt builder with empty custom prompts and default tag style
+ const prompt = await buildTextPrompt(
+ lang,
+ [], // No custom prompts for comparison tool
+ content,
+ contextLength,
+ "as-generated", // Use tags as generated by the model
+ );
+
+ const tags = await inferTags(inferenceClient, prompt);
return tags;
} catch (error) {
throw new Error(
- `Failed to generate tags with ${model}: ${error instanceof Error ? error.message : String(error)}`,
+ `Failed to generate tags: ${error instanceof Error ? error.message : String(error)}`,
);
}
}
diff --git a/tools/compare-models/src/config.ts b/tools/compare-models/src/config.ts
index 9c32610d..0b5d217f 100644
--- a/tools/compare-models/src/config.ts
+++ b/tools/compare-models/src/config.ts
@@ -1,16 +1,33 @@
import { z } from "zod";
+// Local config schema for compare-models tool
const envSchema = z.object({
KARAKEEP_API_KEY: z.string().min(1),
KARAKEEP_SERVER_ADDR: z.string().url(),
MODEL1_NAME: z.string().min(1),
- MODEL2_NAME: z.string().min(1),
+ MODEL2_NAME: z.string().min(1).optional(),
OPENAI_API_KEY: z.string().min(1),
OPENAI_BASE_URL: z.string().url().optional(),
+ COMPARISON_MODE: z
+ .enum(["model-vs-model", "model-vs-existing"])
+ .default("model-vs-model"),
COMPARE_LIMIT: z
.string()
.optional()
.transform((val) => (val ? parseInt(val, 10) : 10)),
+ INFERENCE_CONTEXT_LENGTH: z
+ .string()
+ .optional()
+ .transform((val) => (val ? parseInt(val, 10) : 8000)),
+ INFERENCE_MAX_OUTPUT_TOKENS: z
+ .string()
+ .optional()
+ .transform((val) => (val ? parseInt(val, 10) : 2048)),
+ INFERENCE_USE_MAX_COMPLETION_TOKENS: z
+ .string()
+ .optional()
+ .transform((val) => val === "true")
+ .default("false"),
});
export const config = envSchema.parse(process.env);
diff --git a/tools/compare-models/src/index.ts b/tools/compare-models/src/index.ts
index c1a80ab5..88fc9249 100644
--- a/tools/compare-models/src/index.ts
+++ b/tools/compare-models/src/index.ts
@@ -4,7 +4,7 @@ import type { ComparisonResult } from "./types";
import { KarakeepAPIClient } from "./apiClient";
import { runTaggingForModel } from "./bookmarkProcessor";
import { config } from "./config";
-import { InferenceClient } from "./inferenceClient";
+import { createInferenceClient } from "./inferenceClient";
import {
askQuestion,
clearProgress,
@@ -32,18 +32,58 @@ interface ShuffleResult {
async function main() {
console.log(chalk.cyan("\nšŸš€ Karakeep Model Comparison Tool\n"));
- const inferenceClient = new InferenceClient(
- config.OPENAI_API_KEY,
- config.OPENAI_BASE_URL,
- );
+ const isExistingMode = config.COMPARISON_MODE === "model-vs-existing";
+
+ if (isExistingMode) {
+ console.log(
+ chalk.yellow(
+ `Mode: Comparing ${config.MODEL1_NAME} against existing AI tags\n`,
+ ),
+ );
+ } else {
+ if (!config.MODEL2_NAME) {
+ console.log(
+ chalk.red(
+ "\nāœ— Error: MODEL2_NAME is required for model-vs-model comparison mode\n",
+ ),
+ );
+ process.exit(1);
+ }
+ console.log(
+ chalk.yellow(
+ `Mode: Comparing ${config.MODEL1_NAME} vs ${config.MODEL2_NAME}\n`,
+ ),
+ );
+ }
const apiClient = new KarakeepAPIClient();
displayProgress("Fetching bookmarks from Karakeep...");
- const bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT);
+ let bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT);
clearProgress();
- console.log(chalk.green(`āœ“ Fetched ${bookmarks.length} link bookmarks\n`));
+ // Filter bookmarks with AI tags if in existing mode
+ if (isExistingMode) {
+ bookmarks = bookmarks.filter(
+ (b) => b.tags.some((t) => t.attachedBy === "ai"),
+ );
+ console.log(
+ chalk.green(
+ `āœ“ Fetched ${bookmarks.length} link bookmarks with existing AI tags\n`,
+ ),
+ );
+ } else {
+ console.log(chalk.green(`āœ“ Fetched ${bookmarks.length} link bookmarks\n`));
+ }
+
+ if (bookmarks.length === 0) {
+ console.log(
+ chalk.yellow(
+ "\n⚠ No bookmarks found with AI tags. Please add some bookmarks with AI tags first.\n",
+ ),
+ );
+ return;
+ }
const counters: VoteCounters = {
model1Votes: 0,
@@ -59,17 +99,20 @@ async function main() {
const bookmark = bookmarks[i];
displayProgress(
- `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || "Untitled"}`,
+ `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || bookmark.content.title || "Untitled"}`,
);
let model1Tags: string[] = [];
let model2Tags: string[] = [];
+ // Get tags for model 1 (new model)
try {
+ const model1Client = createInferenceClient(config.MODEL1_NAME);
model1Tags = await runTaggingForModel(
bookmark,
- config.MODEL1_NAME,
- inferenceClient,
+ model1Client,
+ "english",
+ config.INFERENCE_CONTEXT_LENGTH,
);
} catch (error) {
clearProgress();
@@ -80,31 +123,46 @@ async function main() {
continue;
}
- try {
- model2Tags = await runTaggingForModel(
- bookmark,
- config.MODEL2_NAME,
- inferenceClient,
- );
- } catch (error) {
- clearProgress();
- displayError(
- `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`,
- );
- counters.errors++;
- continue;
+ // Get tags for model 2 or existing AI tags
+ if (isExistingMode) {
+ // Use existing AI tags from the bookmark
+ model2Tags = bookmark.tags
+ .filter((t) => t.attachedBy === "ai")
+ .map((t) => t.name);
+ } else {
+ // Run inference with model 2
+ try {
+ const model2Client = createInferenceClient(config.MODEL2_NAME!);
+ model2Tags = await runTaggingForModel(
+ bookmark,
+ model2Client,
+ "english",
+ config.INFERENCE_CONTEXT_LENGTH,
+ );
+ } catch (error) {
+ clearProgress();
+ displayError(
+ `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`,
+ );
+ counters.errors++;
+ continue;
+ }
}
clearProgress();
+ const model2Label = isExistingMode
+ ? "Existing AI Tags"
+ : config.MODEL2_NAME!;
+
const shuffleResult: ShuffleResult = {
modelA: config.MODEL1_NAME,
- modelB: config.MODEL2_NAME,
+ modelB: model2Label,
modelAIsModel1: Math.random() < 0.5,
};
if (!shuffleResult.modelAIsModel1) {
- shuffleResult.modelA = config.MODEL2_NAME;
+ shuffleResult.modelA = model2Label;
shuffleResult.modelB = config.MODEL1_NAME;
}
@@ -156,7 +214,7 @@ async function main() {
displayFinalResults({
model1Name: config.MODEL1_NAME,
- model2Name: config.MODEL2_NAME,
+ model2Name: isExistingMode ? "Existing AI Tags" : config.MODEL2_NAME!,
model1Votes: counters.model1Votes,
model2Votes: counters.model2Votes,
skipped: counters.skipped,
diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts
index 33617318..8649f715 100644
--- a/tools/compare-models/src/inferenceClient.ts
+++ b/tools/compare-models/src/inferenceClient.ts
@@ -1,128 +1,45 @@
-import OpenAI from "openai";
-import { zodResponseFormat } from "openai/helpers/zod";
+import type { InferenceClient } from "@karakeep/shared/inference";
+import {
+ OpenAIInferenceClient,
+ type OpenAIInferenceConfig,
+} from "@karakeep/shared/inference";
import { z } from "zod";
-export interface InferenceOptions {
- schema: z.ZodSchema<any> | null;
+import { config } from "./config";
+
+export function createInferenceClient(modelName: string): InferenceClient {
+ const inferenceConfig: OpenAIInferenceConfig = {
+ apiKey: config.OPENAI_API_KEY,
+ baseURL: config.OPENAI_BASE_URL,
+ textModel: modelName,
+ imageModel: modelName, // Use same model for images if needed
+ contextLength: config.INFERENCE_CONTEXT_LENGTH,
+ maxOutputTokens: config.INFERENCE_MAX_OUTPUT_TOKENS,
+ useMaxCompletionTokens: config.INFERENCE_USE_MAX_COMPLETION_TOKENS,
+ outputSchema: "structured",
+ };
+
+ return new OpenAIInferenceClient(inferenceConfig);
}
-export interface InferenceResponse {
- response: string;
- totalTokens: number | undefined;
-}
-
-export class InferenceClient {
- private client: OpenAI;
-
- constructor(apiKey: string, baseUrl?: string) {
- this.client = new OpenAI({
- apiKey,
- baseURL: baseUrl,
- defaultHeaders: {
- "X-Title": "Karakeep Model Comparison",
- },
- });
- }
-
- async inferTags(
- content: string,
- model: string,
- lang: string = "english",
- customPrompts: string[] = [],
- ): Promise<string[]> {
- const tagsSchema = z.object({
- tags: z.array(z.string()),
- });
-
- const response = await this.inferFromText(
- this.buildPrompt(content, lang, customPrompts),
- model,
- { schema: tagsSchema },
+export async function inferTags(
+ inferenceClient: InferenceClient,
+ prompt: string,
+): Promise<string[]> {
+ const tagsSchema = z.object({
+ tags: z.array(z.string()),
+ });
+
+ const response = await inferenceClient.inferFromText(prompt, {
+ schema: tagsSchema,
+ });
+
+ const parsed = tagsSchema.safeParse(JSON.parse(response.response));
+ if (!parsed.success) {
+ throw new Error(
+ `Failed to parse model response: ${parsed.error.message}`,
);
-
- const parsed = tagsSchema.safeParse(
- this.parseJsonFromResponse(response.response),
- );
- if (!parsed.success) {
- throw new Error(
- `Failed to parse model response: ${parsed.error.message}`,
- );
- }
-
- return parsed.data.tags;
- }
-
- private async inferFromText(
- prompt: string,
- model: string,
- opts: InferenceOptions,
- ): Promise<InferenceResponse> {
- const chatCompletion = await this.client.chat.completions.create({
- messages: [{ role: "user", content: prompt }],
- model: model,
- response_format: opts.schema
- ? zodResponseFormat(opts.schema, "schema")
- : { type: "json_object" },
- });
-
- const response = chatCompletion.choices[0].message.content;
- if (!response) {
- throw new Error("Got no message content from model");
- }
-
- return {
- response,
- totalTokens: chatCompletion.usage?.total_tokens,
- };
- }
-
- private buildPrompt(
- content: string,
- lang: string,
- customPrompts: string[],
- ): string {
- return `
-You are an expert whose responsibility is to help with automatic tagging for a read-it-later app.
-Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are:
-- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres.
-- The tags must be in ${lang}.
-- If tag is not generic enough, don't include it.
-- The content can include text for cookie consent and privacy policy, ignore those while tagging.
-- Aim for 3-5 tags.
-- If there are no good tags, leave the array empty.
-${customPrompts.map((p) => `- ${p}`).join("\n")}
-
-<TEXT_CONTENT>
-${content}
-</TEXT_CONTENT>
-You must respond in JSON with key "tags" and the value is an array of string tags.`;
}
- private parseJsonFromResponse(response: string): unknown {
- const trimmedResponse = response.trim();
-
- try {
- return JSON.parse(trimmedResponse);
- } catch {
- const jsonBlockRegex = /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i;
- const match = trimmedResponse.match(jsonBlockRegex);
-
- if (match) {
- try {
- return JSON.parse(match[1]);
- } catch {}
- }
-
- const jsonObjectRegex = /\{[\s\S]*\}/;
- const objectMatch = trimmedResponse.match(jsonObjectRegex);
-
- if (objectMatch) {
- try {
- return JSON.parse(objectMatch[0]);
- } catch {}
- }
-
- return JSON.parse(trimmedResponse);
- }
- }
+ return parsed.data.tags;
}
diff --git a/tools/compare-models/src/types.ts b/tools/compare-models/src/types.ts
index b8bdc024..35a677ae 100644
--- a/tools/compare-models/src/types.ts
+++ b/tools/compare-models/src/types.ts
@@ -3,12 +3,13 @@ export interface Bookmark {
title: string | null;
content: {
type: string;
+ title: string;
url?: string;
text?: string;
htmlContent?: string;
description?: string;
};
- tags: Array<{ name: string }>;
+ tags: Array<{ name: string; attachedBy?: "ai" | "human" }>;
}
export interface ModelConfig {