| author | Mohamed Bassem <me@mbassem.com> | 2025-04-13 17:03:58 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2025-04-13 17:03:58 +0000 |
| commit | 1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0 (patch) | |
| tree | eb88bb3c6f04d8d4dea1be889cb8a8e552ca91ba /packages | |
| parent | f3c525b7f7dd360f654d8621bbf64e31ad5ff48e (diff) | |
| download | karakeep-1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0.tar.zst | |
fix: Allow using JSON mode for ollama users. Fixes #1160
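In short: `INFERENCE_SUPPORTS_STRUCTURED_OUTPUT` becomes optional and is superseded by a new `INFERENCE_OUTPUT_SCHEMA` enum (`structured`, `json`, or `plain`, defaulting to `structured`), so Ollama users whose models cannot do structured output can set `INFERENCE_OUTPUT_SCHEMA=json` to fall back to plain JSON mode. The sketch below distills the precedence rule from the `serverConfigSchema` transform in the diff; the standalone function and its parameter names are illustrative, not part of Karakeep:

```typescript
// Illustrative only: mirrors the precedence added to config.ts below.
// The legacy boolean, when explicitly set, still wins (true -> "structured",
// false -> "plain"); otherwise the new enum takes effect.
type OutputSchema = "structured" | "json" | "plain";

function resolveOutputSchema(
  legacySupportsStructuredOutput: boolean | undefined, // INFERENCE_SUPPORTS_STRUCTURED_OUTPUT
  outputSchema: OutputSchema, // INFERENCE_OUTPUT_SCHEMA, default "structured"
): OutputSchema {
  if (legacySupportsStructuredOutput !== undefined) {
    return legacySupportsStructuredOutput ? "structured" : "plain";
  }
  return outputSchema;
}
```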
Diffstat (limited to 'packages')
| -rw-r--r-- | packages/shared/config.ts | 19 |
| -rw-r--r-- | packages/shared/inference.ts | 55 |
2 files changed, 57 insertions, 17 deletions
```diff
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index fd224ea7..8abb5902 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -8,6 +8,13 @@ const stringBool = (defaultValue: string) =>
     .refine((s) => s === "true" || s === "false")
     .transform((s) => s === "true");
 
+const optionalStringBool = () =>
+  z
+    .string()
+    .refine((s) => s === "true" || s === "false")
+    .transform((s) => s === "true")
+    .optional();
+
 const allEnv = z.object({
   API_URL: z.string().url().default("http://localhost:3000"),
   DISABLE_SIGNUPS: stringBool("false"),
@@ -29,7 +36,10 @@ const allEnv = z.object({
   INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
   EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
   INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
-  INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: stringBool("true"),
+  INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: optionalStringBool(),
+  INFERENCE_OUTPUT_SCHEMA: z
+    .enum(["structured", "json", "plain"])
+    .default("structured"),
   OCR_CACHE_DIR: z.string().optional(),
   OCR_LANGS: z
     .string()
@@ -104,7 +114,12 @@ const serverConfigSchema = allEnv.transform((val) => {
       imageModel: val.INFERENCE_IMAGE_MODEL,
       inferredTagLang: val.INFERENCE_LANG,
       contextLength: val.INFERENCE_CONTEXT_LENGTH,
-      supportsStructuredOutput: val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT,
+      outputSchema:
+        val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT !== undefined
+          ? val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT
+            ? ("structured" as const)
+            : ("plain" as const)
+          : val.INFERENCE_OUTPUT_SCHEMA,
     },
     embedding: {
       textModel: val.EMBEDDING_TEXT_MODEL,
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 43a14410..e1f21dae 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -41,6 +41,16 @@ export interface InferenceClient {
   generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
 
+const mapInferenceOutputSchema = <
+  T,
+  S extends typeof serverConfig.inference.outputSchema,
+>(
+  opts: Record<S, T>,
+  type: S,
+): T => {
+  return opts[type];
+};
+
 export class InferenceClientFactory {
   static build(): InferenceClient | null {
     if (serverConfig.inference.openAIApiKey) {
@@ -76,11 +86,16 @@ class OpenAIInferenceClient implements InferenceClient {
       {
         messages: [{ role: "user", content: prompt }],
         model: serverConfig.inference.textModel,
-        response_format:
-          optsWithDefaults.schema &&
-          serverConfig.inference.supportsStructuredOutput
-            ? zodResponseFormat(optsWithDefaults.schema, "schema")
-            : undefined,
+        response_format: mapInferenceOutputSchema(
+          {
+            structured: optsWithDefaults.schema
+              ? zodResponseFormat(optsWithDefaults.schema, "schema")
+              : undefined,
+            json: { type: "json_object" },
+            plain: undefined,
+          },
+          serverConfig.inference.outputSchema,
+        ),
       },
       {
         signal: optsWithDefaults.abortSignal,
@@ -107,11 +122,16 @@ class OpenAIInferenceClient implements InferenceClient {
     const chatCompletion = await this.openAI.chat.completions.create(
       {
         model: serverConfig.inference.imageModel,
-        response_format:
-          optsWithDefaults.schema &&
-          serverConfig.inference.supportsStructuredOutput
-            ? zodResponseFormat(optsWithDefaults.schema, "schema")
-            : undefined,
+        response_format: mapInferenceOutputSchema(
+          {
+            structured: optsWithDefaults.schema
+              ? zodResponseFormat(optsWithDefaults.schema, "schema")
+              : undefined,
+            json: { type: "json_object" },
+            plain: undefined,
+          },
+          serverConfig.inference.outputSchema,
+        ),
         messages: [
           {
             role: "user",
@@ -186,11 +206,16 @@ class OllamaInferenceClient implements InferenceClient {
     }
     const chatCompletion = await this.ollama.chat({
       model: model,
-      format:
-        optsWithDefaults.schema &&
-        serverConfig.inference.supportsStructuredOutput
-          ? zodToJsonSchema(optsWithDefaults.schema)
-          : undefined,
+      format: mapInferenceOutputSchema(
+        {
+          structured: optsWithDefaults.schema
+            ? zodToJsonSchema(optsWithDefaults.schema)
+            : undefined,
+          json: "json",
+          plain: undefined,
+        },
+        serverConfig.inference.outputSchema,
+      ),
       stream: true,
       keep_alive: serverConfig.inference.ollamaKeepAlive,
       options: {
```
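The new `mapInferenceOutputSchema` helper is an exhaustive-mapping utility: because `opts` is typed `Record<S, T>`, every member of the output-schema union must be handled at every call site, so adding a fourth enum value later would fail to compile until each client is updated. A standalone sketch of a call site follows; the inlined `OutputSchema` type stands in for `typeof serverConfig.inference.outputSchema`, and the `structured` placeholder value is hypothetical:

```typescript
// Standalone sketch; in inference.ts the union is derived from serverConfig.
type OutputSchema = "structured" | "json" | "plain";

const mapInferenceOutputSchema = <T, S extends OutputSchema>(
  opts: Record<S, T>,
  type: S,
): T => opts[type];

// Mirrors the Ollama call site: "json" selects Ollama's built-in JSON mode,
// while "plain" applies no format constraint at all.
const format = mapInferenceOutputSchema<object | string | undefined, OutputSchema>(
  {
    structured: { type: "object" }, // placeholder for zodToJsonSchema(schema)
    json: "json",
    plain: undefined,
  },
  "json",
);
// format === "json"
```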

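On the OpenAI-compatible path, the `json` branch maps to the API's JSON mode via `response_format: { type: "json_object" }`. Below is a minimal self-contained sketch of that mode, not Karakeep code; it assumes the official `openai` npm package, and the model name and prompt are placeholders:

```typescript
import OpenAI from "openai";

const openAI = new OpenAI({ apiKey: process.env.OPENAI_API_KEY });

const completion = await openAI.chat.completions.create({
  model: "gpt-4o-mini",
  // JSON mode guarantees syntactically valid JSON but, unlike the
  // "structured" branch, does not validate against a zod-derived schema.
  response_format: { type: "json_object" },
  messages: [
    // OpenAI's JSON mode requires the prompt itself to mention JSON.
    {
      role: "user",
      content: "Summarize this bookmark as JSON with a `tags` array.",
    },
  ],
});

console.log(completion.choices[0].message.content);
```

This is exactly the gap the commit closes for Ollama: Ollama's equivalent is `format: "json"`, which the last hunk above wires up for the same `json` setting.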