aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--docs/docs/03-configuration.md3
-rw-r--r--packages/shared/config.ts19
-rw-r--r--packages/shared/inference.ts55
3 files changed, 59 insertions, 18 deletions
diff --git a/docs/docs/03-configuration.md b/docs/docs/03-configuration.md
index 51ee23a5..3790d289 100644
--- a/docs/docs/03-configuration.md
+++ b/docs/docs/03-configuration.md
@@ -63,7 +63,8 @@ Either `OPENAI_API_KEY` or `OLLAMA_BASE_URL` need to be set for automatic taggin
| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |
| INFERENCE_FETCH_TIMEOUT_SEC | No | 300 | \[Ollama Only\] The timeout of the fetch request to the ollama server. If your inference requests take longer than the default 5mins, you might want to increase this timeout. |
-| INFERENCE_SUPPORTS_STRUCTURED_OUTPUT | No | true | Whether the inference model supports structured output or not. |
+| INFERENCE_SUPPORTS_STRUCTURED_OUTPUT | No | Not set | \[DEPRECATED\] Whether the inference model supports structured output or not. Use INFERENCE_OUTPUT_SCHEMA instead. Setting this to true translates to INFERENCE_OUTPUT_SCHEMA=structured, and to false translates to INFERENCE_OUTPUT_SCHEMA=plain. |
+| INFERENCE_OUTPUT_SCHEMA | No | structured | Possible values are "structured", "json", "plain". Structured is the preferred option, but if your model doesn't support it, you can use "json" if your model supports JSON mode, otherwise "plain" which should be supported by all the models but the model might not output the data in the correct format. |
:::info
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index fd224ea7..8abb5902 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -8,6 +8,13 @@ const stringBool = (defaultValue: string) =>
.refine((s) => s === "true" || s === "false")
.transform((s) => s === "true");
+const optionalStringBool = () =>
+ z
+ .string()
+ .refine((s) => s === "true" || s === "false")
+ .transform((s) => s === "true")
+ .optional();
+
const allEnv = z.object({
API_URL: z.string().url().default("http://localhost:3000"),
DISABLE_SIGNUPS: stringBool("false"),
@@ -29,7 +36,10 @@ const allEnv = z.object({
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
- INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: stringBool("true"),
+ INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: optionalStringBool(),
+ INFERENCE_OUTPUT_SCHEMA: z
+ .enum(["structured", "json", "plain"])
+ .default("structured"),
OCR_CACHE_DIR: z.string().optional(),
OCR_LANGS: z
.string()
@@ -104,7 +114,12 @@ const serverConfigSchema = allEnv.transform((val) => {
imageModel: val.INFERENCE_IMAGE_MODEL,
inferredTagLang: val.INFERENCE_LANG,
contextLength: val.INFERENCE_CONTEXT_LENGTH,
- supportsStructuredOutput: val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT,
+ outputSchema:
+ val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT !== undefined
+ ? val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT
+ ? ("structured" as const)
+ : ("plain" as const)
+ : val.INFERENCE_OUTPUT_SCHEMA,
},
embedding: {
textModel: val.EMBEDDING_TEXT_MODEL,
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 43a14410..e1f21dae 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -41,6 +41,16 @@ export interface InferenceClient {
generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}
+const mapInferenceOutputSchema = <
+ T,
+ S extends typeof serverConfig.inference.outputSchema,
+>(
+ opts: Record<S, T>,
+ type: S,
+): T => {
+ return opts[type];
+};
+
export class InferenceClientFactory {
static build(): InferenceClient | null {
if (serverConfig.inference.openAIApiKey) {
@@ -76,11 +86,16 @@ class OpenAIInferenceClient implements InferenceClient {
{
messages: [{ role: "user", content: prompt }],
model: serverConfig.inference.textModel,
- response_format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodResponseFormat(optsWithDefaults.schema, "schema")
- : undefined,
+ response_format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodResponseFormat(optsWithDefaults.schema, "schema")
+ : undefined,
+ json: { type: "json_object" },
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
},
{
signal: optsWithDefaults.abortSignal,
@@ -107,11 +122,16 @@ class OpenAIInferenceClient implements InferenceClient {
const chatCompletion = await this.openAI.chat.completions.create(
{
model: serverConfig.inference.imageModel,
- response_format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodResponseFormat(optsWithDefaults.schema, "schema")
- : undefined,
+ response_format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodResponseFormat(optsWithDefaults.schema, "schema")
+ : undefined,
+ json: { type: "json_object" },
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
messages: [
{
role: "user",
@@ -186,11 +206,16 @@ class OllamaInferenceClient implements InferenceClient {
}
const chatCompletion = await this.ollama.chat({
model: model,
- format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodToJsonSchema(optsWithDefaults.schema)
- : undefined,
+ format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodToJsonSchema(optsWithDefaults.schema)
+ : undefined,
+ json: "json",
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {