diff options

| field | value | date |
|---|---|---|
| author | Mohamed Bassem <me@mbassem.com> | 2025-12-29 23:35:28 +0000 |
| committer | Mohamed Bassem <me@mbassem.com> | 2025-12-29 23:38:21 +0000 |
| commit | f00287ede0675521c783c1199675538571f977d6 (patch) | |
| tree | 2d04b983fa514f4c62a3695c0a521fb50de24eef /tools/compare-models/src/index.ts | |
| parent | ba8d84a555f9e6cf209c826b97a124f0539739eb (diff) | |
| download | karakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst | |

refactor: reduce duplication in compare-models tool

Diffstat (limited to 'tools/compare-models/src/index.ts')

| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | tools/compare-models/src/index.ts | 110 |

1 file changed, 84 insertions, 26 deletions

```diff
diff --git a/tools/compare-models/src/index.ts b/tools/compare-models/src/index.ts
index c1a80ab5..88fc9249 100644
--- a/tools/compare-models/src/index.ts
+++ b/tools/compare-models/src/index.ts
@@ -4,7 +4,7 @@ import type { ComparisonResult } from "./types";
 import { KarakeepAPIClient } from "./apiClient";
 import { runTaggingForModel } from "./bookmarkProcessor";
 import { config } from "./config";
-import { InferenceClient } from "./inferenceClient";
+import { createInferenceClient } from "./inferenceClient";
 import {
   askQuestion,
   clearProgress,
@@ -32,18 +32,58 @@ interface ShuffleResult {
 async function main() {
   console.log(chalk.cyan("\nš Karakeep Model Comparison Tool\n"));
 
-  const inferenceClient = new InferenceClient(
-    config.OPENAI_API_KEY,
-    config.OPENAI_BASE_URL,
-  );
+  const isExistingMode = config.COMPARISON_MODE === "model-vs-existing";
+
+  if (isExistingMode) {
+    console.log(
+      chalk.yellow(
+        `Mode: Comparing ${config.MODEL1_NAME} against existing AI tags\n`,
+      ),
+    );
+  } else {
+    if (!config.MODEL2_NAME) {
+      console.log(
+        chalk.red(
+          "\nā Error: MODEL2_NAME is required for model-vs-model comparison mode\n",
+        ),
+      );
+      process.exit(1);
+    }
+    console.log(
+      chalk.yellow(
+        `Mode: Comparing ${config.MODEL1_NAME} vs ${config.MODEL2_NAME}\n`,
+      ),
+    );
+  }
 
   const apiClient = new KarakeepAPIClient();
 
   displayProgress("Fetching bookmarks from Karakeep...");
-  const bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT);
+  let bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT);
   clearProgress();
 
-  console.log(chalk.green(`ā Fetched ${bookmarks.length} link bookmarks\n`));
+  // Filter bookmarks with AI tags if in existing mode
+  if (isExistingMode) {
+    bookmarks = bookmarks.filter(
+      (b) => b.tags.some((t) => t.attachedBy === "ai"),
+    );
+    console.log(
+      chalk.green(
+        `ā Fetched ${bookmarks.length} link bookmarks with existing AI tags\n`,
+      ),
+    );
+  } else {
+    console.log(chalk.green(`ā Fetched ${bookmarks.length} link bookmarks\n`));
+  }
+
+  if (bookmarks.length === 0) {
+    console.log(
+      chalk.yellow(
+        "\nā No bookmarks found with AI tags. Please add some bookmarks with AI tags first.\n",
+      ),
+    );
+    return;
+  }
 
   const counters: VoteCounters = {
     model1Votes: 0,
@@ -59,17 +99,20 @@ async function main() {
     const bookmark = bookmarks[i];
 
     displayProgress(
-      `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || "Untitled"}`,
+      `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || bookmark.content.title || "Untitled"}`,
     );
 
     let model1Tags: string[] = [];
     let model2Tags: string[] = [];
 
+    // Get tags for model 1 (new model)
     try {
+      const model1Client = createInferenceClient(config.MODEL1_NAME);
       model1Tags = await runTaggingForModel(
         bookmark,
-        config.MODEL1_NAME,
-        inferenceClient,
+        model1Client,
+        "english",
+        config.INFERENCE_CONTEXT_LENGTH,
       );
     } catch (error) {
       clearProgress();
@@ -80,31 +123,46 @@ async function main() {
       continue;
     }
 
-    try {
-      model2Tags = await runTaggingForModel(
-        bookmark,
-        config.MODEL2_NAME,
-        inferenceClient,
-      );
-    } catch (error) {
-      clearProgress();
-      displayError(
-        `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`,
-      );
-      counters.errors++;
-      continue;
+    // Get tags for model 2 or existing AI tags
+    if (isExistingMode) {
+      // Use existing AI tags from the bookmark
+      model2Tags = bookmark.tags
+        .filter((t) => t.attachedBy === "ai")
+        .map((t) => t.name);
+    } else {
+      // Run inference with model 2
+      try {
+        const model2Client = createInferenceClient(config.MODEL2_NAME!);
+        model2Tags = await runTaggingForModel(
+          bookmark,
+          model2Client,
+          "english",
+          config.INFERENCE_CONTEXT_LENGTH,
+        );
+      } catch (error) {
+        clearProgress();
+        displayError(
+          `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`,
+        );
+        counters.errors++;
+        continue;
+      }
     }
 
     clearProgress();
 
+    const model2Label = isExistingMode
+      ? "Existing AI Tags"
+      : config.MODEL2_NAME!;
+
     const shuffleResult: ShuffleResult = {
       modelA: config.MODEL1_NAME,
-      modelB: config.MODEL2_NAME,
+      modelB: model2Label,
       modelAIsModel1: Math.random() < 0.5,
     };
 
     if (!shuffleResult.modelAIsModel1) {
-      shuffleResult.modelA = config.MODEL2_NAME;
+      shuffleResult.modelA = model2Label;
       shuffleResult.modelB = config.MODEL1_NAME;
     }
 
@@ -156,7 +214,7 @@ async function main() {
 
   displayFinalResults({
     model1Name: config.MODEL1_NAME,
-    model2Name: config.MODEL2_NAME,
+    model2Name: isExistingMode ? "Existing AI Tags" : config.MODEL2_NAME!,
     model1Votes: counters.model1Votes,
     model2Votes: counters.model2Votes,
     skipped: counters.skipped,
```
