author     Mohamed Bassem <me@mbassem.com>  2025-12-29 23:35:28 +0000
committer  Mohamed Bassem <me@mbassem.com>  2025-12-29 23:38:21 +0000
commit     f00287ede0675521c783c1199675538571f977d6 (patch)
tree       2d04b983fa514f4c62a3695c0a521fb50de24eef /tools/compare-models/src/index.ts
parent     ba8d84a555f9e6cf209c826b97a124f0539739eb (diff)
download   karakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst
refactor: reduce duplication in compare-models tool
Diffstat (limited to 'tools/compare-models/src/index.ts')
-rw-r--r--  tools/compare-models/src/index.ts  110
1 file changed, 84 insertions, 26 deletions
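Note on the change: the tool no longer builds one shared InferenceClient from the OpenAI settings; each side of the comparison now gets its own client from a createInferenceClient(modelName) factory, and a new "model-vs-existing" mode judges MODEL1 against the AI tags already attached to the bookmarks. The factory itself lives in inferenceClient.ts, which is not part of this patch, so the sketch below is only a guess at its shape, inferred from the old constructor arguments and the new call sites (the use of the openai package and the InferenceClient return type are assumptions).

// Hypothetical sketch of createInferenceClient, inferred from this diff only.
// The real implementation is in tools/compare-models/src/inferenceClient.ts and
// may differ; using the `openai` package is an assumption based on the
// OPENAI_API_KEY / OPENAI_BASE_URL values the old constructor received.
import OpenAI from "openai";

import { config } from "./config";

export interface InferenceClient {
  /** Model this client is bound to (e.g. config.MODEL1_NAME). */
  model: string;
  /** Underlying chat-completions client. */
  api: OpenAI;
}

export function createInferenceClient(model: string): InferenceClient {
  // Binding the model name at construction time is what lets the call sites
  // for model 1 and model 2 become symmetric, which is the duplication this
  // commit removes.
  return {
    model,
    api: new OpenAI({
      apiKey: config.OPENAI_API_KEY,
      baseURL: config.OPENAI_BASE_URL,
    }),
  };
}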
diff --git a/tools/compare-models/src/index.ts b/tools/compare-models/src/index.ts
index c1a80ab5..88fc9249 100644
--- a/tools/compare-models/src/index.ts
+++ b/tools/compare-models/src/index.ts
@@ -4,7 +4,7 @@ import type { ComparisonResult } from "./types";
import { KarakeepAPIClient } from "./apiClient";
import { runTaggingForModel } from "./bookmarkProcessor";
import { config } from "./config";
-import { InferenceClient } from "./inferenceClient";
+import { createInferenceClient } from "./inferenceClient";
import {
askQuestion,
clearProgress,
@@ -32,18 +32,58 @@ interface ShuffleResult {
async function main() {
console.log(chalk.cyan("\nšŸš€ Karakeep Model Comparison Tool\n"));
- const inferenceClient = new InferenceClient(
- config.OPENAI_API_KEY,
- config.OPENAI_BASE_URL,
- );
+ const isExistingMode = config.COMPARISON_MODE === "model-vs-existing";
+
+ if (isExistingMode) {
+ console.log(
+ chalk.yellow(
+ `Mode: Comparing ${config.MODEL1_NAME} against existing AI tags\n`,
+ ),
+ );
+ } else {
+ if (!config.MODEL2_NAME) {
+ console.log(
+ chalk.red(
+        "\nāœ— Error: MODEL2_NAME is required for model-vs-model comparison mode\n",
+ ),
+ );
+ process.exit(1);
+ }
+ console.log(
+ chalk.yellow(
+ `Mode: Comparing ${config.MODEL1_NAME} vs ${config.MODEL2_NAME}\n`,
+ ),
+ );
+ }
const apiClient = new KarakeepAPIClient();
displayProgress("Fetching bookmarks from Karakeep...");
- const bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT);
+ let bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT);
clearProgress();
-  console.log(chalk.green(`āœ“ Fetched ${bookmarks.length} link bookmarks\n`));
+ // Filter bookmarks with AI tags if in existing mode
+ if (isExistingMode) {
+ bookmarks = bookmarks.filter(
+ (b) => b.tags.some((t) => t.attachedBy === "ai"),
+ );
+ console.log(
+ chalk.green(
+        `āœ“ Fetched ${bookmarks.length} link bookmarks with existing AI tags\n`,
+ ),
+ );
+ } else {
+    console.log(chalk.green(`āœ“ Fetched ${bookmarks.length} link bookmarks\n`));
+ }
+
+ if (bookmarks.length === 0) {
+ console.log(
+ chalk.yellow(
+        "\n⚠ No bookmarks found with AI tags. Please add some bookmarks with AI tags first.\n",
+ ),
+ );
+ return;
+ }
const counters: VoteCounters = {
model1Votes: 0,
@@ -59,17 +99,20 @@ async function main() {
const bookmark = bookmarks[i];
displayProgress(
- `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || "Untitled"}`,
+ `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || bookmark.content.title || "Untitled"}`,
);
let model1Tags: string[] = [];
let model2Tags: string[] = [];
+ // Get tags for model 1 (new model)
try {
+ const model1Client = createInferenceClient(config.MODEL1_NAME);
model1Tags = await runTaggingForModel(
bookmark,
- config.MODEL1_NAME,
- inferenceClient,
+ model1Client,
+ "english",
+ config.INFERENCE_CONTEXT_LENGTH,
);
} catch (error) {
clearProgress();
@@ -80,31 +123,46 @@ async function main() {
continue;
}
- try {
- model2Tags = await runTaggingForModel(
- bookmark,
- config.MODEL2_NAME,
- inferenceClient,
- );
- } catch (error) {
- clearProgress();
- displayError(
- `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`,
- );
- counters.errors++;
- continue;
+ // Get tags for model 2 or existing AI tags
+ if (isExistingMode) {
+ // Use existing AI tags from the bookmark
+ model2Tags = bookmark.tags
+ .filter((t) => t.attachedBy === "ai")
+ .map((t) => t.name);
+ } else {
+ // Run inference with model 2
+ try {
+ const model2Client = createInferenceClient(config.MODEL2_NAME!);
+ model2Tags = await runTaggingForModel(
+ bookmark,
+ model2Client,
+ "english",
+ config.INFERENCE_CONTEXT_LENGTH,
+ );
+ } catch (error) {
+ clearProgress();
+ displayError(
+ `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`,
+ );
+ counters.errors++;
+ continue;
+ }
}
clearProgress();
+ const model2Label = isExistingMode
+ ? "Existing AI Tags"
+ : config.MODEL2_NAME!;
+
const shuffleResult: ShuffleResult = {
modelA: config.MODEL1_NAME,
- modelB: config.MODEL2_NAME,
+ modelB: model2Label,
modelAIsModel1: Math.random() < 0.5,
};
if (!shuffleResult.modelAIsModel1) {
- shuffleResult.modelA = config.MODEL2_NAME;
+ shuffleResult.modelA = model2Label;
shuffleResult.modelB = config.MODEL1_NAME;
}
@@ -156,7 +214,7 @@ async function main() {
displayFinalResults({
model1Name: config.MODEL1_NAME,
- model2Name: config.MODEL2_NAME,
+ model2Name: isExistingMode ? "Existing AI Tags" : config.MODEL2_NAME!,
model1Votes: counters.model1Votes,
model2Votes: counters.model2Votes,
skipped: counters.skipped,
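For orientation, these are the config fields the refactored code leans on. All of the names appear in the diff above, but config.ts is not part of this change, so the field types and the exact value used for the model-vs-model mode are assumptions; the snippet is a sketch, not the project's schema.

// Hypothetical shape of the compare-models config referenced in the diff above.
// The real schema lives in tools/compare-models/src/config.ts; types and the
// exact string for the model-vs-model mode value are assumptions.
export type ComparisonMode = "model-vs-model" | "model-vs-existing";

export interface CompareModelsConfig {
  OPENAI_API_KEY: string;
  OPENAI_BASE_URL?: string;
  /** New: picks whether MODEL1 is judged against MODEL2 or against existing AI tags. */
  COMPARISON_MODE: ComparisonMode;
  MODEL1_NAME: string;
  /** Now effectively optional: only required when comparing model vs model. */
  MODEL2_NAME?: string;
  /** How many bookmarks to fetch for the comparison run. */
  COMPARE_LIMIT: number;
  /** Passed through to runTaggingForModel alongside the per-model client. */
  INFERENCE_CONTEXT_LENGTH: number;
}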