about summary refs log tree commit diff stats
path: root/tools/compare-models/src/inferenceClient.ts
diff options
context:
space:
mode:
author Mohamed Bassem <me@mbassem.com> 2025-12-26 11:14:17 +0000
committer Mohamed Bassem <me@mbassem.com> 2025-12-26 11:14:17 +0000
commit 1dfa5d12f6af6ca964bdfa911809a061ffdf36c2 (patch)
tree 87c734eaa5395051a0a46972ca575f2866c73dd5 /tools/compare-models/src/inferenceClient.ts
parent ecb7a710ca7ec22aa3304b8d1f6b603bb60874bc (diff)
download karakeep-1dfa5d12f6af6ca964bdfa911809a061ffdf36c2.tar.zst
chore: add a tool for comparing perf of different models
Diffstat (limited to 'tools/compare-models/src/inferenceClient.ts')
-rw-r--r-- tools/compare-models/src/inferenceClient.ts 128
1 files changed, 128 insertions, 0 deletions
diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts
new file mode 100644
index 00000000..33617318
--- /dev/null
+++ b/tools/compare-models/src/inferenceClient.ts
@@ -0,0 +1,128 @@
+import OpenAI from "openai";
+import { zodResponseFormat } from "openai/helpers/zod";
+import { z } from "zod";
+
/**
 * Options accepted by a single inference request.
 */
export interface InferenceOptions {
  /**
   * Zod schema describing the JSON shape the model should answer with,
   * or `null` when no schema is supplied.
   */
  schema: null | z.ZodSchema<any>;
}
+
/**
 * Outcome of a single chat-completion call.
 */
export interface InferenceResponse {
  /** Raw message content returned by the model. */
  response: string;
  /** Total token count taken from the completion's `usage`, when present. */
  totalTokens: undefined | number;
}
+
/**
 * Small wrapper around the OpenAI chat-completions client that asks a model
 * to suggest tags for a piece of text and parses the JSON answer.
 */
export class InferenceClient {
  /** Matches a Markdown code fence (optionally labelled `json`) wrapping one JSON object. */
  private static readonly FENCED_JSON_OBJECT =
    /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i;

  /** Matches everything from the first `{` to the last `}` in the text. */
  private static readonly GREEDY_JSON_OBJECT = /\{[\s\S]*\}/;

  /**
   * Upper bound on how many `{` start positions the balanced-brace fallback
   * will try, so a pathological response cannot trigger quadratic scanning.
   */
  private static readonly MAX_OBJECT_CANDIDATES = 32;

  private readonly client: OpenAI;

  /**
   * @param apiKey - API key handed to the OpenAI SDK.
   * @param baseUrl - Optional endpoint override, forwarded as `baseURL`.
   */
  constructor(apiKey: string, baseUrl?: string) {
    this.client = new OpenAI({
      apiKey,
      baseURL: baseUrl,
      // Sent with every request issued through this client.
      defaultHeaders: {
        "X-Title": "Karakeep Model Comparison",
      },
    });
  }

  /**
   * Asks `model` to suggest tags for `content`.
   *
   * @param content - Text to be tagged; embedded verbatim in the prompt.
   * @param model - Model identifier forwarded to the chat-completions API.
   * @param lang - Language the tags must be written in.
   * @param customPrompts - Extra rules appended to the built-in rule list.
   * @returns The `tags` array extracted from the model's JSON answer.
   * @throws Error when the model returns no content or the parsed JSON does
   *   not match `{ tags: string[] }`.
   * @throws SyntaxError when no JSON can be extracted from the answer.
   */
  async inferTags(
    content: string,
    model: string,
    lang: string = "english",
    customPrompts: string[] = [],
  ): Promise<string[]> {
    const tagsSchema = z.object({
      tags: z.array(z.string()),
    });

    const response = await this.inferFromText(
      this.buildPrompt(content, lang, customPrompts),
      model,
      { schema: tagsSchema },
    );

    // Validate the parsed JSON again locally: the answer is only trusted once
    // it matches the schema.
    const parsed = tagsSchema.safeParse(
      this.parseJsonFromResponse(response.response),
    );
    if (!parsed.success) {
      throw new Error(
        `Failed to parse model response: ${parsed.error.message}`,
      );
    }

    return parsed.data.tags;
  }

  /**
   * Sends `prompt` as a single user message and returns the first choice.
   *
   * @param prompt - Full prompt text.
   * @param model - Model identifier.
   * @param opts - When `opts.schema` is set the request uses a response
   *   format derived from it via `zodResponseFormat`; otherwise plain
   *   `json_object` mode is requested.
   * @returns The message content plus the reported total token count.
   * @throws Error when the completion carries no (or empty) message content.
   */
  private async inferFromText(
    prompt: string,
    model: string,
    opts: InferenceOptions,
  ): Promise<InferenceResponse> {
    const chatCompletion = await this.client.chat.completions.create({
      messages: [{ role: "user", content: prompt }],
      model,
      response_format: opts.schema
        ? zodResponseFormat(opts.schema, "schema")
        : { type: "json_object" },
    });

    // Guard against a missing or empty `choices` array so the caller gets the
    // descriptive error below instead of a TypeError on `undefined`.
    const response = chatCompletion.choices?.[0]?.message?.content;
    if (!response) {
      throw new Error("Got no message content from model");
    }

    return {
      response,
      totalTokens: chatCompletion.usage?.total_tokens,
    };
  }

  /**
   * Builds the tagging prompt.
   *
   * @param content - Text placed between the `<TEXT_CONTENT>` markers.
   * @param lang - Language name interpolated into the "tags must be in" rule.
   * @param customPrompts - Each entry becomes one additional `- ` rule line.
   * @returns The complete prompt string.
   */
  private buildPrompt(
    content: string,
    lang: string,
    customPrompts: string[],
  ): string {
    return `
You are an expert whose responsibility is to help with automatic tagging for a read-it-later app.
Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are:
- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres.
- The tags must be in ${lang}.
- If tag is not generic enough, don't include it.
- The content can include text for cookie consent and privacy policy, ignore those while tagging.
- Aim for 3-5 tags.
- If there are no good tags, leave the array empty.
${customPrompts.map((p) => `- ${p}`).join("\n")}

<TEXT_CONTENT>
${content}
</TEXT_CONTENT>
You must respond in JSON with key "tags" and the value is an array of string tags.`;
  }

  /**
   * Extracts a JSON value from a model answer, tolerating common wrappers.
   *
   * Attempts, in order:
   * 1. the whole (trimmed) answer;
   * 2. a JSON object inside a Markdown code fence;
   * 3. the span from the first `{` to the last `}`;
   * 4. the first brace-balanced `{...}` span that parses.
   *
   * @param response - Raw message content from the model.
   * @returns The parsed JSON value.
   * @throws SyntaxError the error from parsing the whole answer, when every
   *   attempt fails.
   */
  private parseJsonFromResponse(response: string): unknown {
    const trimmedResponse = response.trim();

    let wholeTextError: unknown;
    try {
      return JSON.parse(trimmedResponse);
    } catch (err: unknown) {
      // Keep the original failure so it can be rethrown if no fallback works.
      wholeTextError = err;
    }

    const candidates = [
      trimmedResponse.match(InferenceClient.FENCED_JSON_OBJECT)?.[1],
      trimmedResponse.match(InferenceClient.GREEDY_JSON_OBJECT)?.[0],
    ];
    for (const candidate of candidates) {
      if (candidate === undefined) {
        continue;
      }
      const parsed = this.tryParseJson(candidate);
      if (parsed !== undefined) {
        return parsed;
      }
    }

    // Last resort: prose around the JSON may itself contain braces, which
    // defeats the greedy span above.
    const balanced = this.findFirstParsableObject(trimmedResponse);
    if (balanced !== undefined) {
      return balanced;
    }

    throw wholeTextError;
  }

  /**
   * Best-effort parse: returns the parsed value, or `undefined` when `text`
   * is not valid JSON. `JSON.parse` never yields `undefined`, so that value
   * is an unambiguous failure marker.
   */
  private tryParseJson(text: string): unknown {
    try {
      return JSON.parse(text);
    } catch {
      return undefined;
    }
  }

  /**
   * Tries successive `{` positions and returns the first brace-balanced span
   * that parses as JSON, or `undefined` when none does.
   */
  private findFirstParsableObject(text: string): unknown {
    let attempts = 0;
    for (
      let start = text.indexOf("{");
      start !== -1 && attempts < InferenceClient.MAX_OBJECT_CANDIDATES;
      start = text.indexOf("{", start + 1), attempts++
    ) {
      const end = this.findMatchingBrace(text, start);
      if (end === -1) {
        continue;
      }
      const parsed = this.tryParseJson(text.slice(start, end + 1));
      if (parsed !== undefined) {
        return parsed;
      }
    }
    return undefined;
  }

  /**
   * Returns the index of the `}` closing the `{` at `start`, ignoring braces
   * inside double-quoted strings, or `-1` when the brace is never closed.
   */
  private findMatchingBrace(text: string, start: number): number {
    let depth = 0;
    let inString = false;
    let escaped = false;

    for (let i = start; i < text.length; i++) {
      const ch = text[i];

      if (inString) {
        if (escaped) {
          escaped = false;
        } else if (ch === "\\") {
          escaped = true;
        } else if (ch === '"') {
          inString = false;
        }
        continue;
      }

      if (ch === '"') {
        inString = true;
      } else if (ch === "{") {
        depth++;
      } else if (ch === "}") {
        depth--;
        if (depth === 0) {
          return i;
        }
      }
    }

    return -1;
  }
}