diff options
| author | Mohamed Bassem <me@mbassem.com> | 2025-12-26 11:14:17 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2025-12-26 11:14:17 +0000 |
| commit | 1dfa5d12f6af6ca964bdfa911809a061ffdf36c2 (patch) | |
| tree | 87c734eaa5395051a0a46972ca575f2866c73dd5 /tools/compare-models/src/index.ts | |
| parent | ecb7a710ca7ec22aa3304b8d1f6b603bb60874bc (diff) | |
| download | karakeep-1dfa5d12f6af6ca964bdfa911809a061ffdf36c2.tar.zst | |
chore: add a tool for comparing perf of different models
Diffstat (limited to 'tools/compare-models/src/index.ts')
| -rw-r--r-- | tools/compare-models/src/index.ts | 171 |
1 files changed, 171 insertions, 0 deletions
diff --git a/tools/compare-models/src/index.ts b/tools/compare-models/src/index.ts new file mode 100644 index 00000000..c1a80ab5 --- /dev/null +++ b/tools/compare-models/src/index.ts @@ -0,0 +1,171 @@ +import chalk from "chalk"; + +import type { ComparisonResult } from "./types"; +import { KarakeepAPIClient } from "./apiClient"; +import { runTaggingForModel } from "./bookmarkProcessor"; +import { config } from "./config"; +import { InferenceClient } from "./inferenceClient"; +import { + askQuestion, + clearProgress, + close, + displayComparison, + displayError, + displayFinalResults, + displayProgress, +} from "./interactive"; + +interface VoteCounters { + model1Votes: number; + model2Votes: number; + skipped: number; + errors: number; + total: number; +} + +interface ShuffleResult { + modelA: string; + modelB: string; + modelAIsModel1: boolean; +} + +async function main() { + console.log(chalk.cyan("\nđ Karakeep Model Comparison Tool\n")); + + const inferenceClient = new InferenceClient( + config.OPENAI_API_KEY, + config.OPENAI_BASE_URL, + ); + + const apiClient = new KarakeepAPIClient(); + + displayProgress("Fetching bookmarks from Karakeep..."); + const bookmarks = await apiClient.fetchBookmarks(config.COMPARE_LIMIT); + clearProgress(); + + console.log(chalk.green(`â Fetched ${bookmarks.length} link bookmarks\n`)); + + const counters: VoteCounters = { + model1Votes: 0, + model2Votes: 0, + skipped: 0, + errors: 0, + total: bookmarks.length, + }; + + const detailedResults: ComparisonResult[] = []; + + for (let i = 0; i < bookmarks.length; i++) { + const bookmark = bookmarks[i]; + + displayProgress( + `[${i + 1}/${bookmarks.length}] Running inference on: ${bookmark.title || "Untitled"}`, + ); + + let model1Tags: string[] = []; + let model2Tags: string[] = []; + + try { + model1Tags = await runTaggingForModel( + bookmark, + config.MODEL1_NAME, + inferenceClient, + ); + } catch (error) { + clearProgress(); + displayError( + `${config.MODEL1_NAME} failed: ${error instanceof Error ? error.message : String(error)}`, + ); + counters.errors++; + continue; + } + + try { + model2Tags = await runTaggingForModel( + bookmark, + config.MODEL2_NAME, + inferenceClient, + ); + } catch (error) { + clearProgress(); + displayError( + `${config.MODEL2_NAME} failed: ${error instanceof Error ? error.message : String(error)}`, + ); + counters.errors++; + continue; + } + + clearProgress(); + + const shuffleResult: ShuffleResult = { + modelA: config.MODEL1_NAME, + modelB: config.MODEL2_NAME, + modelAIsModel1: Math.random() < 0.5, + }; + + if (!shuffleResult.modelAIsModel1) { + shuffleResult.modelA = config.MODEL2_NAME; + shuffleResult.modelB = config.MODEL1_NAME; + } + + const comparison: ComparisonResult = { + bookmark, + modelA: shuffleResult.modelA, + modelATags: shuffleResult.modelAIsModel1 ? model1Tags : model2Tags, + modelB: shuffleResult.modelB, + modelBTags: shuffleResult.modelAIsModel1 ? model2Tags : model1Tags, + }; + + displayComparison(i + 1, bookmarks.length, comparison, true); + + const answer = await askQuestion( + "Which tags do you prefer? [1=Model A, 2=Model B, s=skip, q=quit] > ", + ); + + const normalizedAnswer = answer.toLowerCase(); + + if (normalizedAnswer === "q" || normalizedAnswer === "quit") { + console.log(chalk.yellow("\n⸠Quitting early...\n")); + break; + } + + if (normalizedAnswer === "1") { + comparison.winner = "modelA"; + if (shuffleResult.modelAIsModel1) { + counters.model1Votes++; + } else { + counters.model2Votes++; + } + detailedResults.push(comparison); + } else if (normalizedAnswer === "2") { + comparison.winner = "modelB"; + if (shuffleResult.modelAIsModel1) { + counters.model2Votes++; + } else { + counters.model1Votes++; + } + detailedResults.push(comparison); + } else { + comparison.winner = "skip"; + counters.skipped++; + detailedResults.push(comparison); + } + } + + close(); + + displayFinalResults({ + model1Name: config.MODEL1_NAME, + model2Name: config.MODEL2_NAME, + model1Votes: counters.model1Votes, + model2Votes: counters.model2Votes, + skipped: counters.skipped, + errors: counters.errors, + total: counters.total, + }); +} + +main().catch((error) => { + console.error(chalk.red(`\nâ Fatal error: ${error}\n`)); + process.exit(1); +}); |
