From 1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0 Mon Sep 17 00:00:00 2001 From: Mohamed Bassem Date: Sun, 13 Apr 2025 17:03:58 +0000 Subject: fix: Allow using JSON mode for ollama users. Fixes #1160 --- packages/shared/inference.ts | 55 ++++++++++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 15 deletions(-) (limited to 'packages/shared/inference.ts') diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts index 43a14410..e1f21dae 100644 --- a/packages/shared/inference.ts +++ b/packages/shared/inference.ts @@ -41,6 +41,16 @@ export interface InferenceClient { generateEmbeddingFromText(inputs: string[]): Promise; } +const mapInferenceOutputSchema = < + T, + S extends typeof serverConfig.inference.outputSchema, +>( + opts: Record, + type: S, +): T => { + return opts[type]; +}; + export class InferenceClientFactory { static build(): InferenceClient | null { if (serverConfig.inference.openAIApiKey) { @@ -76,11 +86,16 @@ class OpenAIInferenceClient implements InferenceClient { { messages: [{ role: "user", content: prompt }], model: serverConfig.inference.textModel, - response_format: - optsWithDefaults.schema && - serverConfig.inference.supportsStructuredOutput - ? zodResponseFormat(optsWithDefaults.schema, "schema") - : undefined, + response_format: mapInferenceOutputSchema( + { + structured: optsWithDefaults.schema + ? zodResponseFormat(optsWithDefaults.schema, "schema") + : undefined, + json: { type: "json_object" }, + plain: undefined, + }, + serverConfig.inference.outputSchema, + ), }, { signal: optsWithDefaults.abortSignal, @@ -107,11 +122,16 @@ class OpenAIInferenceClient implements InferenceClient { const chatCompletion = await this.openAI.chat.completions.create( { model: serverConfig.inference.imageModel, - response_format: - optsWithDefaults.schema && - serverConfig.inference.supportsStructuredOutput - ? zodResponseFormat(optsWithDefaults.schema, "schema") - : undefined, + response_format: mapInferenceOutputSchema( + { + structured: optsWithDefaults.schema + ? zodResponseFormat(optsWithDefaults.schema, "schema") + : undefined, + json: { type: "json_object" }, + plain: undefined, + }, + serverConfig.inference.outputSchema, + ), messages: [ { role: "user", @@ -186,11 +206,16 @@ class OllamaInferenceClient implements InferenceClient { } const chatCompletion = await this.ollama.chat({ model: model, - format: - optsWithDefaults.schema && - serverConfig.inference.supportsStructuredOutput - ? zodToJsonSchema(optsWithDefaults.schema) - : undefined, + format: mapInferenceOutputSchema( + { + structured: optsWithDefaults.schema + ? zodToJsonSchema(optsWithDefaults.schema) + : undefined, + json: "json", + plain: undefined, + }, + serverConfig.inference.outputSchema, + ), stream: true, keep_alive: serverConfig.inference.ollamaKeepAlive, options: { -- cgit v1.2.3-70-g09d2