diff options
Diffstat (limited to 'packages/shared')
-rw-r--r--  packages/shared/config.ts      2
-rw-r--r--  packages/shared/inference.ts  30
-rw-r--r--  packages/shared/package.json   7
3 files changed, 27 insertions, 12 deletions
diff --git a/packages/shared/config.ts b/packages/shared/config.ts index 6e5a4404..1295fdbf 100644 --- a/packages/shared/config.ts +++ b/packages/shared/config.ts @@ -27,6 +27,7 @@ const allEnv = z.object({ INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"), EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"), INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048), + INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: stringBool("true"), OCR_CACHE_DIR: z.string().optional(), OCR_LANGS: z .string() @@ -94,6 +95,7 @@ const serverConfigSchema = allEnv.transform((val) => { imageModel: val.INFERENCE_IMAGE_MODEL, inferredTagLang: val.INFERENCE_LANG, contextLength: val.INFERENCE_CONTEXT_LENGTH, + supportsStructuredOutput: val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT, }, embedding: { textModel: val.EMBEDDING_TEXT_MODEL, diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts index 92d9dd94..43a14410 100644 --- a/packages/shared/inference.ts +++ b/packages/shared/inference.ts @@ -1,5 +1,8 @@ import { Ollama } from "ollama"; import OpenAI from "openai"; +import { zodResponseFormat } from "openai/helpers/zod"; +import { z } from "zod"; +import { zodToJsonSchema } from "zod-to-json-schema"; import serverConfig from "./config"; import { customFetch } from "./customFetch"; @@ -15,12 +18,13 @@ export interface EmbeddingResponse { } export interface InferenceOptions { - json: boolean; + // eslint-disable-next-line @typescript-eslint/no-explicit-any + schema: z.ZodSchema<any> | null; abortSignal?: AbortSignal; } const defaultInferenceOptions: InferenceOptions = { - json: true, + schema: null, }; export interface InferenceClient { @@ -72,9 +76,11 @@ class OpenAIInferenceClient implements InferenceClient { { messages: [{ role: "user", content: prompt }], model: serverConfig.inference.textModel, - response_format: optsWithDefaults.json - ? { type: "json_object" } - : undefined, + response_format: + optsWithDefaults.schema && + serverConfig.inference.supportsStructuredOutput + ? zodResponseFormat(optsWithDefaults.schema, "schema") + : undefined, }, { signal: optsWithDefaults.abortSignal, @@ -101,9 +107,11 @@ class OpenAIInferenceClient implements InferenceClient { const chatCompletion = await this.openAI.chat.completions.create( { model: serverConfig.inference.imageModel, - response_format: optsWithDefaults.json - ? { type: "json_object" } - : undefined, + response_format: + optsWithDefaults.schema && + serverConfig.inference.supportsStructuredOutput + ? zodResponseFormat(optsWithDefaults.schema, "schema") + : undefined, messages: [ { role: "user", @@ -178,7 +186,11 @@ class OllamaInferenceClient implements InferenceClient { } const chatCompletion = await this.ollama.chat({ model: model, - format: optsWithDefaults.json ? "json" : undefined, + format: + optsWithDefaults.schema && + serverConfig.inference.supportsStructuredOutput + ? zodToJsonSchema(optsWithDefaults.schema) + : undefined, stream: true, keep_alive: serverConfig.inference.ollamaKeepAlive, options: { diff --git a/packages/shared/package.json b/packages/shared/package.json index ecb16013..b868f9e3 100644 --- a/packages/shared/package.json +++ b/packages/shared/package.json @@ -8,11 +8,12 @@ "glob": "^11.0.0", "liteque": "^0.3.2", "meilisearch": "^0.37.0", - "ollama": "^0.5.9", - "openai": "^4.67.1", + "ollama": "^0.5.14", + "openai": "^4.86.1", "typescript-parsec": "^0.3.4", "winston": "^3.11.0", - "zod": "^3.22.4" + "zod": "^3.22.4", + "zod-to-json-schema": "^3.24.3" }, "devDependencies": { "@hoarder/eslint-config": "workspace:^0.2.0",
