aboutsummaryrefslogtreecommitdiffstats
path: root/tools/compare-models/src/inferenceClient.ts
blob: 33617318a6b4961a8d1ffb9b8012baedab38f4bb (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import OpenAI from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import { z } from "zod";

/** Per-request options accepted by the inference helpers in this file. */
export interface InferenceOptions {
  /**
   * Zod schema describing the JSON shape the model is asked to produce.
   * In this file's `InferenceClient`, a non-null schema is converted with
   * `zodResponseFormat` and sent as the request's `response_format`; `null`
   * falls back to the generic `{ type: "json_object" }` format.
   */
  schema: z.ZodSchema<any> | null;
}

/** Result of a single chat-completion call. */
export interface InferenceResponse {
  /** Raw message content returned by the model (not yet parsed as JSON). */
  response: string;
  /**
   * Total token count reported in the completion's `usage` field, or
   * `undefined` when the completion carries no usage information.
   */
  totalTokens: number | undefined;
}

/**
 * Small wrapper around the OpenAI chat-completions client that asks a model
 * for tag suggestions and extracts the JSON payload from its reply.
 */
export class InferenceClient {
  /** Fenced code block (optionally labelled `json`) wrapping a JSON object. */
  private static readonly fencedJsonPattern =
    /```(?:json)?\s*(\{[\s\S]*?\})\s*```/i;

  /** Greedy span from the first `{` to the last `}` in the text. */
  private static readonly bareObjectPattern = /\{[\s\S]*\}/;

  private readonly client: OpenAI;

  /**
   * @param apiKey - API key handed to the OpenAI client.
   * @param baseUrl - Optional base URL override; forwarded as `baseURL`.
   */
  constructor(apiKey: string, baseUrl?: string) {
    this.client = new OpenAI({
      apiKey,
      baseURL: baseUrl,
      defaultHeaders: {
        "X-Title": "Karakeep Model Comparison",
      },
    });
  }

  /**
   * Asks `model` to suggest tags for `content`.
   *
   * @param content - Text to be tagged; embedded verbatim in the prompt.
   * @param model - Model identifier passed to the completions endpoint.
   * @param lang - Language the tags must be written in.
   * @param customPrompts - Extra rules appended to the prompt's rule list.
   * @returns The `tags` array from the model's JSON reply.
   * @throws Error when the reply does not match `{ tags: string[] }`, or
   *   whatever the request / JSON parsing itself throws.
   */
  async inferTags(
    content: string,
    model: string,
    lang: string = "english",
    customPrompts: string[] = [],
  ): Promise<string[]> {
    const tagsSchema = z.object({
      tags: z.array(z.string()),
    });

    const response = await this.inferFromText(
      this.buildPrompt(content, lang, customPrompts),
      model,
      { schema: tagsSchema },
    );

    // Validate the decoded JSON rather than trusting the model's output.
    const parsed = tagsSchema.safeParse(
      this.parseJsonFromResponse(response.response),
    );
    if (!parsed.success) {
      throw new Error(
        `Failed to parse model response: ${parsed.error.message}`,
      );
    }

    return parsed.data.tags;
  }

  /**
   * Sends `prompt` as a single user message and returns the first choice.
   *
   * @param prompt - Full prompt text.
   * @param model - Model identifier.
   * @param opts - A non-null `opts.schema` requests schema-constrained output;
   *   `null` requests the generic `json_object` response format.
   * @returns The message content together with the reported total token count.
   * @throws Error when the completion has no choices or the first choice has
   *   no (non-empty) message content.
   */
  private async inferFromText(
    prompt: string,
    model: string,
    opts: InferenceOptions,
  ): Promise<InferenceResponse> {
    const chatCompletion = await this.client.chat.completions.create({
      messages: [{ role: "user", content: prompt }],
      model: model,
      response_format: opts.schema
        ? zodResponseFormat(opts.schema, "schema")
        : { type: "json_object" },
    });

    // `choices` may be empty; guard the index so that case surfaces as the
    // descriptive error below instead of a TypeError on `undefined`.
    const response = chatCompletion.choices[0]?.message?.content;
    if (!response) {
      throw new Error("Got no message content from model");
    }

    return {
      response,
      totalTokens: chatCompletion.usage?.total_tokens,
    };
  }

  /**
   * Builds the tagging prompt.
   *
   * @param content - Text placed between the `<TEXT_CONTENT>` markers.
   * @param lang - Language the tags must be in.
   * @param customPrompts - Additional rules, each rendered as a `- ` bullet.
   */
  private buildPrompt(
    content: string,
    lang: string,
    customPrompts: string[],
  ): string {
    return `
You are an expert whose responsibility is to help with automatic tagging for a read-it-later app.
Please analyze the TEXT_CONTENT below and suggest relevant tags that describe its key themes, topics, and main ideas. The rules are:
- Aim for a variety of tags, including broad categories, specific keywords, and potential sub-genres.
- The tags must be in ${lang}.
- If tag is not generic enough, don't include it.
- The content can include text for cookie consent and privacy policy, ignore those while tagging.
- Aim for 3-5 tags.
- If there are no good tags, leave the array empty.
${customPrompts.map((p) => `- ${p}`).join("\n")}

<TEXT_CONTENT>
${content}
</TEXT_CONTENT>
You must respond in JSON with key "tags" and the value is an array of string tags.`;
  }

  /**
   * Decodes JSON from a model reply, tolerating surrounding chatter.
   *
   * Attempts, in order: the trimmed reply as-is, the object inside a fenced
   * code block, then the span from the first `{` to the last `}`.
   *
   * @param response - Raw model output.
   * @returns The decoded value (unvalidated).
   * @throws The error from parsing the whole reply when every attempt fails.
   */
  private parseJsonFromResponse(response: string): unknown {
    const trimmedResponse = response.trim();

    let directParseError: unknown;
    try {
      return JSON.parse(trimmedResponse);
    } catch (error: unknown) {
      // Remember the failure: it is what we report if the fallbacks fail too.
      directParseError = error;
    }

    const fenced = trimmedResponse.match(InferenceClient.fencedJsonPattern);
    if (fenced) {
      try {
        return JSON.parse(fenced[1]);
      } catch {
        // Fenced block was not valid JSON; fall through to the next strategy.
      }
    }

    const bare = trimmedResponse.match(InferenceClient.bareObjectPattern);
    if (bare) {
      try {
        return JSON.parse(bare[0]);
      } catch {
        // Brace-delimited span was not valid JSON either.
      }
    }

    throw directParseError;
  }
}