diff options
Diffstat (limited to 'tools/compare-models/src')
| -rw-r--r-- | tools/compare-models/src/apiClient.ts | 65 | ||||
| -rw-r--r-- | tools/compare-models/src/bookmarkProcessor.ts | 56 | ||||
| -rw-r--r-- | tools/compare-models/src/config.ts | 16 | ||||
| -rw-r--r-- | tools/compare-models/src/index.ts | 171 | ||||
| -rw-r--r-- | tools/compare-models/src/inferenceClient.ts | 128 | ||||
| -rw-r--r-- | tools/compare-models/src/interactive.ts | 128 | ||||
| -rw-r--r-- | tools/compare-models/src/types.ts | 37 |
7 files changed, 601 insertions, 0 deletions
diff --git a/tools/compare-models/src/apiClient.ts b/tools/compare-models/src/apiClient.ts new file mode 100644 index 00000000..f3a960cb --- /dev/null +++ b/tools/compare-models/src/apiClient.ts @@ -0,0 +1,65 @@ +import { createKarakeepClient } from "@karakeep/sdk"; + +import type { Bookmark } from "./types"; +import { config } from "./config"; + +export class KarakeepAPIClient { + private readonly client: ReturnType<typeof createKarakeepClient>; + + constructor() { + this.client = createKarakeepClient({ + baseUrl: `${config.KARAKEEP_SERVER_ADDR}/api/v1/`, + headers: { + "Content-Type": "application/json", + authorization: `Bearer ${config.KARAKEEP_API_KEY}`, + }, + }); + } + + async fetchBookmarks(limit: number): Promise<Bookmark[]> { + const bookmarks: Bookmark[] = []; + let cursor: string | null = null; + let hasMore = true; + + while (hasMore && bookmarks.length < limit) { + const params: { + limit: number; + includeContent: true; + archived?: boolean; + cursor?: string; + } = { + limit: Math.min(limit - bookmarks.length, 50), + includeContent: true, + archived: false, + }; + + if (cursor) { + params.cursor = cursor; + } + + const { data, response, error } = await this.client.GET("/bookmarks", { + params: { + query: params, + }, + }); + + if (error) { + throw new Error(`Failed to fetch bookmarks: ${String(error)}`); + } + + if (!response.ok) { + throw new Error(`Failed to fetch bookmarks: ${response.status}`); + } + + const batchBookmarks = (data?.bookmarks || []) + .filter((b) => b.content?.type === "link") + .map((b) => b as Bookmark); + + bookmarks.push(...batchBookmarks); + cursor = data?.nextCursor || null; + hasMore = !!cursor; + } + + return bookmarks.slice(0, limit); + } +} diff --git a/tools/compare-models/src/bookmarkProcessor.ts b/tools/compare-models/src/bookmarkProcessor.ts new file mode 100644 index 00000000..910957fe --- /dev/null +++ b/tools/compare-models/src/bookmarkProcessor.ts @@ -0,0 +1,56 @@ +import type { InferenceClient } from "./inferenceClient"; +import type { Bookmark } from "./types"; + +export async function extractBookmarkContent( + bookmark: Bookmark, +): Promise<string> { + if (bookmark.content.type === "link") { + const parts = []; + + if (bookmark.content.url) { + parts.push(`URL: ${bookmark.content.url}`); + } + + if (bookmark.title) { + parts.push(`Title: ${bookmark.title}`); + } + + if (bookmark.content.description) { + parts.push(`Description: ${bookmark.content.description}`); + } + + if (bookmark.content.htmlContent) { + parts.push(`Content: ${bookmark.content.htmlContent}`); + } + + return parts.join("\n"); + } + + if (bookmark.content.type === "text" && bookmark.content.text) { + return bookmark.content.text; + } + + return ""; +} + +export async function runTaggingForModel( + bookmark: Bookmark, + model: string, + inferenceClient: InferenceClient, + lang: string = "english", +): Promise<string[]> { + const content = await extractBookmarkContent(bookmark); + + if (!content) { + return []; + } + + try { + const tags = await inferenceClient.inferTags(content, model, lang, []); + return tags; + } catch (error) { + throw new Error( + `Failed to generate tags with ${model}: ${error instanceof Error ? error.message : String(error)}`, + ); + } +} diff --git a/tools/compare-models/src/config.ts b/tools/compare-models/src/config.ts new file mode 100644 index 00000000..9c32610d --- /dev/null +++ b/tools/compare-models/src/config.ts @@ -0,0 +1,16 @@ +import { z } from "zod"; + +const envSchema = z.object({ + KARAKEEP_API_KEY: z.string().min(1), + KARAKEEP_SERVER_ADDR: z.string().url(), + MODEL1_NAME: z.string().min(1), + MODEL2_NAME: z.string().min(1), + OPENAI_API_KEY: z.string().min(1), + OPENAI_BASE_URL: z.string().url().optional(), + COMPARE_LIMIT: z + .string() + .optional() + .transform((val) => (val ? parseInt(val, 10) : 10)), +}); + +export const config = envSchema.parse(process.env); diff --git a/tools/compare-models/src/index.ts b/tools/compare-models/src/index.ts new file mode 100644 index 00000000..c1a80ab5 --- /dev/null +++ b/tools/compare-models/src/index.ts @@ -0,0 +1,171 @@ +import chalk from "chalk"; + +import type { ComparisonResult } from "./types"; +import { KarakeepAPIClient } from "./apiClient"; +import { runTaggingForModel } from "./bookmarkProcessor"; +import { config } from "./config"; +import { InferenceClient } from "./inferenceClient"; +import { + askQuestion, + clearProgress, + close, + displayComparison, + displayError, + displayFinalResults, + displayProgress, +} from "./interactive"; + +interface VoteCounters { + model1Votes: number; + model2Votes: number; + skipped: number; + errors: number; + total: number; +} + +interface ShuffleResult { + modelA: string; + modelB: string; + modelAIsModel1: boolean; +} + +async function main() { + console.log(chalk.cyan("\nš Karakeep Model Comparison Tool\n")); + + const inferenceClient = new InferenceClient( + config.OPENAI_API_KEY, + config.OPENAI_BASE_URL, + ); + + const apiClient = new KarakeepAPIClient(); + + displayProgress("Fetching bookmarks from Karakeep..."); + const bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT); + clearProgress(); + + console.log(chalk.green(`ā Fetched ${bookmarks.length} link bookmarks\n`)); + + const counters: VoteCounters = { + model1Votes: 0, + model2Votes: 0, + skipped: 0, + errors: 0, + total: bookmarks.length, + }; + + const detailedResults: ComparisonResult[] = []; + + for (let i = 0; i < bookmarks.length; i++) { + const bookmark = bookmarks[i]; + + displayProgress( + `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || "Untitled"}`, + ); + + let model1Tags: string[] = []; + let model2Tags: string[] = []; + + try { + model1Tags = await runTaggingForModel( + bookmark, + config.MODEL1_NAME, + inferenceClient, + ); + } catch (error) { + clearProgress(); + displayError( + `${config.MODEL1_NAME} failed: ${error instanceof Error ? error.message : String(error)}`, + ); + counters.errors++; + continue; + } + + try { + model2Tags = await runTaggingForModel( + bookmark, + config.MODEL2_NAME, + inferenceClient, + ); + } catch (error) { + clearProgress(); + displayError( + `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`, + ); + counters.errors++; + continue; + } + + clearProgress(); + + const shuffleResult: ShuffleResult = { + modelA: config.MODEL1_NAME, + modelB: config.MODEL2_NAME, + modelAIsModel1: Math.random() < 0.5, + }; + + if (!shuffleResult.modelAIsModel1) { + shuffleResult.modelA = config.MODEL2_NAME; + shuffleResult.modelB = config.MODEL1_NAME; + } + + const comparison: ComparisonResult = { + bookmark, + modelA: shuffleResult.modelA, + modelATags: shuffleResult.modelAIsModel1 ? model1Tags : model2Tags, + modelB: shuffleResult.modelB, + modelBTags: shuffleResult.modelAIsModel1 ? model2Tags : model1Tags, + }; + + displayComparison(i + 1, bookmarks.length, comparison, true); + + const answer = await askQuestion( + "Which tags do you prefer? [1=Model A, 2=Model B, s=skip, q=quit] > ", + ); + + const normalizedAnswer = answer.toLowerCase(); + + if (normalizedAnswer === "q" || normalizedAnswer === "quit") { + console.log(chalk.yellow("\nāø Quitting early...\n")); + break; + } + + if (normalizedAnswer === "1") { + comparison.winner = "modelA"; + if (shuffleResult.modelAIsModel1) { + counters.model1Votes++; + } else { + counters.model2Votes++; + } + detailedResults.push(comparison); + } else if (normalizedAnswer === "2") { + comparison.winner = "modelB"; + if (shuffleResult.modelAIsModel1) { + counters.model2Votes++; + } else { + counters.model1Votes++; + } + detailedResults.push(comparison); + } else { + comparison.winner = "skip"; + counters.skipped++; + detailedResults.push(comparison); + } + } + + close(); + + displayFinalResults({ + model1Name: config.MODEL1_NAME, + model2Name: config.MODEL2_NAME, + model1Votes: counters.model1Votes, + model2Votes: counters.model2Votes, + skipped: counters.skipped, + errors: counters.errors, + total: counters.total, + }); +} + +main().catch((error) => { + console.error(chalk.red(`\nā Fatal error: ${error}\n`)); + process.exit(1); +}); diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts new file mode 100644 index 00000000..33617318 --- /dev/null +++ b/tools/compare-models/src/inferenceClient.ts @@ -0,0 +1,128 @@ +import OpenAI from "openai"; +import { zodResponseFormat } from "openai/helpers/zod"; +import { z } from "zod"; + +export interface InferenceOptions { + schema: z.ZodSchema<any> | null; +} + +export interface InferenceResponse { + response: string; + totalTokens: number | undefined; +} + +export class InferenceClient { + private client: OpenAI; + + constructor(apiKey: string, baseUrl?: string) { + this.client = new OpenAI({ + apiKey, + baseURL: baseUrl, + defaultHeaders: { + "X-Title": "Karakeep Model Comparison", + }, + }); + } + + async inferTags( + content: string, + model: string, + lang: string = "english", + customPrompts: string[] = [], + ): Promise<string[]> { + const tagsSchema = z.object({ + tags: z.array(z.string()), + }); + + const response = await this.inferFromText( + this.buildPrompt(content, lang, customPrompts), + model, + { schema: tagsSchema }, + ); + + const parsed = tagsSchema.safeParse( + this.parseJsonFromResponse(response.response), + ); + if (!parsed.success) { + throw new Error( + `Failed to parse model response: ${parsed.error.message}`, + ); + } + + return parsed.data.tags; + } + + private async inferFromText( + prompt: string, + model: string, + opts: InferenceOptions, + ): Promise<InferenceResponse> { + const chatCompletion = await this.client.chat.completions.create({ + messages: [{ role: "user", content: prompt }], + model: model, + response_format: opts.schema + ? zodResponseFormat(opts.schema, "schema") + : { type: "json_object" }, + }); + + const response = chatCompletion.choices[0].message.content; + if (!response) { + throw new Error("Got no message content from model"); + } + + return { + response, + totalTokens: chatCompletion.usage?.total_tokens, + }; + } + + private buildPrompt( + content: string, + lang: string, + customPrompts: string[], + ): string { + return ` +You are an expert whose responsibility is to help with automatic tagging for a read-it-later app. +Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are: +- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres. +- The tags must be in ${lang}. +- If tag is not generic enough, don't include it. +- The content can include text for cookie consent and privacy policy, ignore those while tagging. +- Aim for 3-5 tags. +- If there are no good tags, leave the array empty. +${customPrompts.map((p) => `- ${p}`).join("\n")} + +<TEXT_CONTENT> +${content} +</TEXT_CONTENT> +You must respond in JSON with key "tags" and the value is an array of string tags.`; + } + + private parseJsonFromResponse(response: string): unknown { + const trimmedResponse = response.trim(); + + try { + return JSON.parse(trimmedResponse); + } catch { + const jsonBlockRegex = /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i; + const match = trimmedResponse.match(jsonBlockRegex); + + if (match) { + try { + return JSON.parse(match[1]); + } catch {} + } + + const jsonObjectRegex = /\{[\s\S]*\}/; + const objectMatch = trimmedResponse.match(jsonObjectRegex); + + if (objectMatch) { + try { + return JSON.parse(objectMatch[0]); + } catch {} + } + + return JSON.parse(trimmedResponse); + } + } +} diff --git a/tools/compare-models/src/interactive.ts b/tools/compare-models/src/interactive.ts new file mode 100644 index 00000000..b93fc1d7 --- /dev/null +++ b/tools/compare-models/src/interactive.ts @@ -0,0 +1,128 @@ +import * as readline from "node:readline"; +import chalk from "chalk"; + +import type { ComparisonResult } from "./types"; + +const rl = readline.createInterface({ + input: process.stdin, + output: process.stdout, +}); + +export async function askQuestion(question: string): Promise<string> { + return new Promise((resolve) => { + rl.question(question, (answer) => { + resolve(answer.trim()); + }); + }); +} + +export function displayComparison( + index: number, + total: number, + result: ComparisonResult, + blind: boolean = true, +): void { + const divider = chalk.gray("ā".repeat(80)); + const header = chalk.bold.cyan(`\n=== Bookmark ${index}/${total} ===`); + const title = chalk.bold.white(result.bookmark.title || "Untitled"); + const url = result.bookmark.content.url + ? chalk.gray(result.bookmark.content.url) + : ""; + const content = chalk.gray( + result.bookmark.content.description + ? result.bookmark.content.description.substring(0, 200) + "..." + : "", + ); + + const modelAName = blind ? "Model A" : result.modelA; + const modelBName = blind ? "Model B" : result.modelB; + + const modelATags = result.modelATags + .map((tag) => chalk.green(` ⢠${tag}`)) + .join("\n"); + const modelBTags = result.modelBTags + .map((tag) => chalk.yellow(` ⢠${tag}`)) + .join("\n"); + + console.log(header); + console.log(title); + if (url) console.log(url); + if (content) console.log(content); + console.log(divider); + console.log(); + console.log(chalk.green(`${modelAName}:`)); + if (modelATags) { + console.log(modelATags); + } else { + console.log(chalk.gray(" (no tags)")); + } + console.log(); + console.log(chalk.yellow(`${modelBName}:`)); + if (modelBTags) { + console.log(modelBTags); + } else { + console.log(chalk.gray(" (no tags)")); + } + console.log(); +} + +export function displayError(message: string): void { + console.log(chalk.red(`\nā Error: ${message}\n`)); +} + +export function displayProgress(message: string): void { + process.stdout.write(chalk.gray(message)); +} + +export function clearProgress(): void { + process.stdout.write("\r\x1b[K"); +} + +export function close(): void { + rl.close(); +} + +export function displayFinalResults(results: { + model1Name: string; + model2Name: string; + model1Votes: number; + model2Votes: number; + skipped: number; + errors: number; + total: number; +}): void { + const winner = + results.model1Votes > results.model2Votes + ? results.model1Name + : results.model2Votes > results.model1Votes + ? results.model2Name + : "TIE"; + + const divider = chalk.gray("ā".repeat(80)); + const header = chalk.bold.cyan("\n=== FINAL RESULTS ==="); + const model1Line = chalk.green( + `${results.model1Name}: ${results.model1Votes} votes`, + ); + const model2Line = chalk.yellow( + `${results.model2Name}: ${results.model2Votes} votes`, + ); + const skippedLine = chalk.gray(`Skipped: ${results.skipped}`); + const errorsLine = chalk.red(`Errors: ${results.errors}`); + const totalLine = chalk.bold(`Total bookmarks tested: ${results.total}`); + const winnerLine = + winner === "TIE" + ? chalk.bold.cyan(`\nš RESULT: TIE`) + : chalk.bold.green(`\nš WINNER: ${winner}`); + + console.log(divider); + console.log(header); + console.log(divider); + console.log(model1Line); + console.log(model2Line); + console.log(skippedLine); + console.log(errorsLine); + console.log(divider); + console.log(totalLine); + console.log(winnerLine); + console.log(divider); +} diff --git a/tools/compare-models/src/types.ts b/tools/compare-models/src/types.ts new file mode 100644 index 00000000..b8bdc024 --- /dev/null +++ b/tools/compare-models/src/types.ts @@ -0,0 +1,37 @@ +export interface Bookmark { + id: string; + title: string | null; + content: { + type: string; + url?: string; + text?: string; + htmlContent?: string; + description?: string; + }; + tags: Array<{ name: string }>; +} + +export interface ModelConfig { + name: string; + apiKey: string; + baseUrl?: string; +} + +export interface ComparisonResult { + bookmark: Bookmark; + modelA: string; + modelATags: string[]; + modelB: string; + modelBTags: string[]; + winner?: "modelA" | "modelB" | "skip"; +} + +export interface FinalResults { + model1Name: string; + model2Name: string; + model1Votes: number; + model2Votes: number; + skipped: number; + errors: number; + total: number; +} |
