author     Mohamed Bassem <me@mbassem.com>  2025-04-13 17:03:58 +0000
committer  Mohamed Bassem <me@mbassem.com>  2025-04-13 17:03:58 +0000
commit     1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0 (patch)
tree       eb88bb3c6f04d8d4dea1be889cb8a8e552ca91ba /packages/shared
parent     f3c525b7f7dd360f654d8621bbf64e31ad5ff48e (diff)
fix: Allow using JSON mode for ollama users. Fixes #1160
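
This commit replaces the boolean INFERENCE_SUPPORTS_STRUCTURED_OUTPUT flag with a three-way INFERENCE_OUTPUT_SCHEMA setting ("structured", "json", or "plain"), so Ollama users whose models cannot follow a full JSON schema can still request plain JSON mode by setting INFERENCE_OUTPUT_SCHEMA=json in their environment. The old boolean is kept as an optional override for backward compatibility, as the config.ts hunks below show.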
Diffstat (limited to 'packages/shared')
-rw-r--r--  packages/shared/config.ts     19
-rw-r--r--  packages/shared/inference.ts  55
2 files changed, 57 insertions(+), 17 deletions(-)
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index fd224ea7..8abb5902 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -8,6 +8,13 @@ const stringBool = (defaultValue: string) =>
.refine((s) => s === "true" || s === "false")
.transform((s) => s === "true");
+const optionalStringBool = () =>
+ z
+ .string()
+ .refine((s) => s === "true" || s === "false")
+ .transform((s) => s === "true")
+ .optional();
+
const allEnv = z.object({
API_URL: z.string().url().default("http://localhost:3000"),
DISABLE_SIGNUPS: stringBool("false"),
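
The new optionalStringBool helper parses like the existing stringBool but carries no default, so an unset variable stays undefined instead of being coerced to a boolean. A minimal sketch of the expected behavior (standard zod semantics, shown for illustration):

    const flag = optionalStringBool();
    flag.parse("true");     // true
    flag.parse("false");    // false
    flag.parse(undefined);  // undefined -- distinguishes "unset" from "false"

That undefined case is what lets the transform further down detect whether the legacy flag was explicitly set.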
@@ -29,7 +36,10 @@ const allEnv = z.object({
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
- INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: stringBool("true"),
+ INFERENCE_SUPPORTS_STRUCTURED_OUTPUT: optionalStringBool(),
+ INFERENCE_OUTPUT_SCHEMA: z
+ .enum(["structured", "json", "plain"])
+ .default("structured"),
OCR_CACHE_DIR: z.string().optional(),
OCR_LANGS: z
.string()
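
Because the enum defaults to "structured", existing deployments keep their current behavior; only users who want "json" or "plain" need to set the variable. For instance (plain zod behavior, not from the diff):

    z.enum(["structured", "json", "plain"]).default("structured").parse(undefined);
    // => "structured"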
@@ -104,7 +114,12 @@ const serverConfigSchema = allEnv.transform((val) => {
imageModel: val.INFERENCE_IMAGE_MODEL,
inferredTagLang: val.INFERENCE_LANG,
contextLength: val.INFERENCE_CONTEXT_LENGTH,
- supportsStructuredOutput: val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT,
+ outputSchema:
+ val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT !== undefined
+ ? val.INFERENCE_SUPPORTS_STRUCTURED_OUTPUT
+ ? ("structured" as const)
+ : ("plain" as const)
+ : val.INFERENCE_OUTPUT_SCHEMA,
},
embedding: {
textModel: val.EMBEDDING_TEXT_MODEL,
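
The nested ternary above gives the legacy flag precedence whenever it is set; the new enum only applies when the flag is absent. Summarized (values as they appear in the hunk):

    INFERENCE_SUPPORTS_STRUCTURED_OUTPUT=true   -> outputSchema = "structured"
    INFERENCE_SUPPORTS_STRUCTURED_OUTPUT=false  -> outputSchema = "plain"
    unset                                       -> outputSchema = INFERENCE_OUTPUT_SCHEMA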
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 43a14410..e1f21dae 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -41,6 +41,16 @@ export interface InferenceClient {
generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}
+const mapInferenceOutputSchema = <
+ T,
+ S extends typeof serverConfig.inference.outputSchema,
+>(
+ opts: Record<S, T>,
+ type: S,
+): T => {
+ return opts[type];
+};
+
export class InferenceClientFactory {
static build(): InferenceClient | null {
if (serverConfig.inference.openAIApiKey) {
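
mapInferenceOutputSchema is a small exhaustive lookup: the Record<S, T> parameter forces each call site to supply a value for every possible mode, so adding a fourth enum member would fail to compile until every client handles it. A hypothetical call (the string values here are only illustrative):

    const label = mapInferenceOutputSchema(
      { structured: "schema-constrained", json: "json-mode", plain: "free-text" },
      serverConfig.inference.outputSchema,
    ); // inferred as string; picks the entry matching the configured mode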
@@ -76,11 +86,16 @@ class OpenAIInferenceClient implements InferenceClient {
{
messages: [{ role: "user", content: prompt }],
model: serverConfig.inference.textModel,
- response_format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodResponseFormat(optsWithDefaults.schema, "schema")
- : undefined,
+ response_format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodResponseFormat(optsWithDefaults.schema, "schema")
+ : undefined,
+ json: { type: "json_object" },
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
},
{
signal: optsWithDefaults.abortSignal,
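
For the OpenAI client, the three modes map onto the chat completions response_format parameter: zodResponseFormat produces a json_schema response format (structured outputs), { type: "json_object" } enables OpenAI's JSON mode, and undefined leaves the output unconstrained. Note that "structured" still falls back to undefined when the caller passes no schema.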
@@ -107,11 +122,16 @@ class OpenAIInferenceClient implements InferenceClient {
const chatCompletion = await this.openAI.chat.completions.create(
{
model: serverConfig.inference.imageModel,
- response_format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodResponseFormat(optsWithDefaults.schema, "schema")
- : undefined,
+ response_format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodResponseFormat(optsWithDefaults.schema, "schema")
+ : undefined,
+ json: { type: "json_object" },
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
messages: [
{
role: "user",
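
The image path below receives the identical response_format mapping, so text and vision requests honor the same configured mode.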
@@ -186,11 +206,16 @@ class OllamaInferenceClient implements InferenceClient {
}
const chatCompletion = await this.ollama.chat({
model: model,
- format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodToJsonSchema(optsWithDefaults.schema)
- : undefined,
+ format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodToJsonSchema(optsWithDefaults.schema)
+ : undefined,
+ json: "json",
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {
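
On the Ollama side, the equivalent knob is the chat request's format field: a JSON schema object (from zodToJsonSchema) for structured outputs, the string "json" for Ollama's JSON mode (the case #1160 asked for), or undefined for plain text.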