diff options
| author | Mohamed Bassem <me@mbassem.com> | 2025-04-13 17:03:58 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2025-04-13 17:03:58 +0000 |
| commit | 1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0 (patch) | |
| tree | eb88bb3c6f04d8d4dea1be889cb8a8e552ca91ba /packages/shared/inference.ts | |
| parent | f3c525b7f7dd360f654d8621bbf64e31ad5ff48e (diff) | |
| download | karakeep-1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0.tar.zst | |
fix: Allow using JSON mode for ollama users. Fixes #1160
Diffstat (limited to 'packages/shared/inference.ts')
| -rw-r--r-- | packages/shared/inference.ts | 55 |
1 file changed, 40 insertions, 15 deletions
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts index 43a14410..e1f21dae 100644 --- a/packages/shared/inference.ts +++ b/packages/shared/inference.ts @@ -41,6 +41,16 @@ export interface InferenceClient { generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>; } +const mapInferenceOutputSchema = < + T, + S extends typeof serverConfig.inference.outputSchema, +>( + opts: Record<S, T>, + type: S, +): T => { + return opts[type]; +}; + export class InferenceClientFactory { static build(): InferenceClient | null { if (serverConfig.inference.openAIApiKey) { @@ -76,11 +86,16 @@ class OpenAIInferenceClient implements InferenceClient { { messages: [{ role: "user", content: prompt }], model: serverConfig.inference.textModel, - response_format: - optsWithDefaults.schema && - serverConfig.inference.supportsStructuredOutput - ? zodResponseFormat(optsWithDefaults.schema, "schema") - : undefined, + response_format: mapInferenceOutputSchema( + { + structured: optsWithDefaults.schema + ? zodResponseFormat(optsWithDefaults.schema, "schema") + : undefined, + json: { type: "json_object" }, + plain: undefined, + }, + serverConfig.inference.outputSchema, + ), }, { signal: optsWithDefaults.abortSignal, @@ -107,11 +122,16 @@ class OpenAIInferenceClient implements InferenceClient { const chatCompletion = await this.openAI.chat.completions.create( { model: serverConfig.inference.imageModel, - response_format: - optsWithDefaults.schema && - serverConfig.inference.supportsStructuredOutput - ? zodResponseFormat(optsWithDefaults.schema, "schema") - : undefined, + response_format: mapInferenceOutputSchema( + { + structured: optsWithDefaults.schema + ? 
zodResponseFormat(optsWithDefaults.schema, "schema") + : undefined, + json: { type: "json_object" }, + plain: undefined, + }, + serverConfig.inference.outputSchema, + ), messages: [ { role: "user", @@ -186,11 +206,16 @@ class OllamaInferenceClient implements InferenceClient { } const chatCompletion = await this.ollama.chat({ model: model, - format: - optsWithDefaults.schema && - serverConfig.inference.supportsStructuredOutput - ? zodToJsonSchema(optsWithDefaults.schema) - : undefined, + format: mapInferenceOutputSchema( + { + structured: optsWithDefaults.schema + ? zodToJsonSchema(optsWithDefaults.schema) + : undefined, + json: "json", + plain: undefined, + }, + serverConfig.inference.outputSchema, + ), stream: true, keep_alive: serverConfig.inference.ollamaKeepAlive, options: { |
