Diffstat (limited to 'packages/shared/inference.ts')
-rw-r--r--  packages/shared/inference.ts  |  30 +++++++++++++++++++++++---------
1 file changed, 21 insertions(+), 9 deletions(-)
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 92d9dd94..43a14410 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -1,5 +1,8 @@
 import { Ollama } from "ollama";
 import OpenAI from "openai";
+import { zodResponseFormat } from "openai/helpers/zod";
+import { z } from "zod";
+import { zodToJsonSchema } from "zod-to-json-schema";
 import serverConfig from "./config";
 import { customFetch } from "./customFetch";
 
@@ -15,12 +18,13 @@ export interface EmbeddingResponse {
 }
 
 export interface InferenceOptions {
-  json: boolean;
+  // eslint-disable-next-line @typescript-eslint/no-explicit-any
+  schema: z.ZodSchema<any> | null;
   abortSignal?: AbortSignal;
 }
 
 const defaultInferenceOptions: InferenceOptions = {
-  json: true,
+  schema: null,
 };
 
 export interface InferenceClient {
@@ -72,9 +76,11 @@ class OpenAIInferenceClient implements InferenceClient {
       {
         messages: [{ role: "user", content: prompt }],
         model: serverConfig.inference.textModel,
-        response_format: optsWithDefaults.json
-          ? { type: "json_object" }
-          : undefined,
+        response_format:
+          optsWithDefaults.schema &&
+          serverConfig.inference.supportsStructuredOutput
+            ? zodResponseFormat(optsWithDefaults.schema, "schema")
+            : undefined,
       },
       {
         signal: optsWithDefaults.abortSignal,
@@ -101,9 +107,11 @@ class OpenAIInferenceClient implements InferenceClient {
     const chatCompletion = await this.openAI.chat.completions.create(
       {
         model: serverConfig.inference.imageModel,
-        response_format: optsWithDefaults.json
-          ? { type: "json_object" }
-          : undefined,
+        response_format:
+          optsWithDefaults.schema &&
+          serverConfig.inference.supportsStructuredOutput
+            ? zodResponseFormat(optsWithDefaults.schema, "schema")
+            : undefined,
         messages: [
           {
             role: "user",
@@ -178,7 +186,11 @@ class OllamaInferenceClient implements InferenceClient {
     }
     const chatCompletion = await this.ollama.chat({
       model: model,
-      format: optsWithDefaults.json ? "json" : undefined,
+      format:
+        optsWithDefaults.schema &&
+        serverConfig.inference.supportsStructuredOutput
+          ? zodToJsonSchema(optsWithDefaults.schema)
+          : undefined,
       stream: true,
      keep_alive: serverConfig.inference.ollamaKeepAlive,
      options: {
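Both OpenAI paths now build the response format with zodResponseFormat from openai/helpers/zod, which compiles a Zod schema into the SDK's strict json_schema response format, gated behind the new serverConfig.inference.supportsStructuredOutput flag. Below is a minimal sketch of that technique in isolation; the tagsSchema shape, model name, and prompt are illustrative and not taken from this repo:

import OpenAI from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import { z } from "zod";

// Hypothetical schema for illustration; the real schemas live in the callers.
const tagsSchema = z.object({
  tags: z.array(z.string()),
});

const openAI = new OpenAI(); // reads OPENAI_API_KEY from the environment

const chatCompletion = await openAI.chat.completions.create({
  model: "gpt-4o-mini", // stand-in for serverConfig.inference.textModel
  messages: [{ role: "user", content: "Suggest tags for: ..." }],
  // Compiles the Zod schema into a { type: "json_schema", ... } response
  // format; "schema" is the name the patch passes as the second argument.
  response_format: zodResponseFormat(tagsSchema, "schema"),
});

// The model is constrained to the schema, but parsing still validates it.
const tags = tagsSchema.parse(
  JSON.parse(chatCompletion.choices[0].message.content ?? "{}"),
);

Note the default also changes: the old json: true default always requested json_object mode, while the new schema: null default sends no response_format at all, so callers that relied on the implicit JSON mode now get free-form text unless they pass a schema.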

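On the Ollama side, format now carries a JSON schema generated by zod-to-json-schema instead of the literal "json"; Ollama accepts a schema object in that field since its structured-outputs release. A sketch under the same assumptions (hypothetical schema and model; non-streaming for brevity, whereas the patched client streams):

import { Ollama } from "ollama";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

const tagsSchema = z.object({ tags: z.array(z.string()) });

const ollama = new Ollama({ host: "http://localhost:11434" });

const response = await ollama.chat({
  model: "llama3.2", // stand-in for whichever model the client selects
  messages: [{ role: "user", content: "Suggest tags for: ..." }],
  // zodToJsonSchema emits a plain JSON Schema object, which Ollama uses
  // to constrain decoding much like OpenAI's json_schema format does.
  format: zodToJsonSchema(tagsSchema),
  stream: false,
});

const tags = tagsSchema.parse(JSON.parse(response.message.content));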