aboutsummaryrefslogtreecommitdiffstats
path: root/tools/compare-models/src/inferenceClient.ts
diff options
context:
space:
mode:
Diffstat (limited to 'tools/compare-models/src/inferenceClient.ts')
-rw-r--r--tools/compare-models/src/inferenceClient.ts46
1 file changed, 46 insertions, 0 deletions
diff --git a/tools/compare-models/src/inferenceClient.ts b/tools/compare-models/src/inferenceClient.ts
new file mode 100644
index 00000000..0a5ed8b5
--- /dev/null
+++ b/tools/compare-models/src/inferenceClient.ts
@@ -0,0 +1,46 @@
+import type { InferenceClient } from "@karakeep/shared/inference";
+import {
+ OpenAIInferenceClient,
+ type OpenAIInferenceConfig,
+} from "@karakeep/shared/inference";
+import { z } from "zod";
+
+import { config } from "./config";
+
+export function createInferenceClient(modelName: string): InferenceClient {
+ const inferenceConfig: OpenAIInferenceConfig = {
+ apiKey: config.OPENAI_API_KEY,
+ baseURL: config.OPENAI_BASE_URL,
+ serviceTier: config.OPENAI_SERVICE_TIER,
+ textModel: modelName,
+ imageModel: modelName, // Use same model for images if needed
+ contextLength: config.INFERENCE_CONTEXT_LENGTH,
+ maxOutputTokens: config.INFERENCE_MAX_OUTPUT_TOKENS,
+ useMaxCompletionTokens: config.INFERENCE_USE_MAX_COMPLETION_TOKENS,
+ outputSchema: "structured",
+ };
+
+ return new OpenAIInferenceClient(inferenceConfig);
+}
+
+/**
+ * Ask the model to produce tags for the given prompt.
+ *
+ * The model reply is expected to be a JSON object of the shape
+ * `{ "tags": string[] }`; it is validated with zod before being returned.
+ *
+ * @param inferenceClient - client created via `createInferenceClient`
+ * @param prompt - full prompt text sent to the model
+ * @returns the list of tags extracted from the validated response
+ * @throws Error if the reply is not valid JSON or does not match the schema
+ */
+export async function inferTags(
+  inferenceClient: InferenceClient,
+  prompt: string,
+): Promise<string[]> {
+  const tagsSchema = z.object({
+    tags: z.array(z.string()),
+  });
+
+  const response = await inferenceClient.inferFromText(prompt, {
+    schema: tagsSchema,
+  });
+
+  // Models occasionally emit non-JSON even when structured output is
+  // requested; surface that with context instead of a bare SyntaxError,
+  // mirroring the explicit error thrown for a schema mismatch below.
+  let raw: unknown;
+  try {
+    raw = JSON.parse(response.response);
+  } catch (e: unknown) {
+    const reason = e instanceof Error ? e.message : String(e);
+    throw new Error(`Model response is not valid JSON: ${reason}`);
+  }
+
+  const parsed = tagsSchema.safeParse(raw);
+  if (!parsed.success) {
+    throw new Error(
+      `Failed to parse model response: ${parsed.error.message}`,
+    );
+  }
+
+  return parsed.data.tags;
+}