author     Mohamed Bassem <me@mbassem.com>    2025-04-13 17:03:58 +0000
committer  Mohamed Bassem <me@mbassem.com>    2025-04-13 17:03:58 +0000
commit     1373a7b21d7b04f0fe5ea2a008c88b6a85665fe0 (patch)
tree       eb88bb3c6f04d8d4dea1be889cb8a8e552ca91ba /packages/shared/inference.ts
parent     f3c525b7f7dd360f654d8621bbf64e31ad5ff48e (diff)
fix: Allow using JSON mode for ollama users. Fixes #1160
Diffstat (limited to 'packages/shared/inference.ts')
-rw-r--r--  packages/shared/inference.ts  55
1 file changed, 40 insertions(+), 15 deletions(-)
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 43a14410..e1f21dae 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -41,6 +41,16 @@ export interface InferenceClient {
generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}

+const mapInferenceOutputSchema = <
+ T,
+ S extends typeof serverConfig.inference.outputSchema,
+>(
+ opts: Record<S, T>,
+ type: S,
+): T => {
+ return opts[type];
+};
+
export class InferenceClientFactory {
static build(): InferenceClient | null {
if (serverConfig.inference.openAIApiKey) {
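
The new mapInferenceOutputSchema helper is an exhaustive lookup keyed by the configured output mode. A minimal sketch of its behavior, assuming serverConfig.inference.outputSchema is a "structured" | "json" | "plain" union (inferred from the keys used throughout this diff; the standalone type and values below are illustrative, not taken from the repo):

// Illustrative stand-in for typeof serverConfig.inference.outputSchema.
type OutputSchema = "structured" | "json" | "plain";

const mapInferenceOutputSchema = <T, S extends OutputSchema>(
  opts: Record<S, T>,
  type: S,
): T => opts[type];

// Selects the entry matching the configured mode; the other entries are ignored.
const format = mapInferenceOutputSchema(
  { structured: "schema-object", json: "json", plain: undefined },
  "json" as OutputSchema,
);
// format === "json"
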
@@ -76,11 +86,16 @@ class OpenAIInferenceClient implements InferenceClient {
{
messages: [{ role: "user", content: prompt }],
model: serverConfig.inference.textModel,
- response_format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodResponseFormat(optsWithDefaults.schema, "schema")
- : undefined,
+ response_format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodResponseFormat(optsWithDefaults.schema, "schema")
+ : undefined,
+ json: { type: "json_object" },
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
},
{
signal: optsWithDefaults.abortSignal,
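
With the mapping in place, an outputSchema of "json" sends OpenAI's JSON mode ({ type: "json_object" }) instead of requiring structured-output support. A hedged sketch of what that branch looks like in isolation (model name and prompt are placeholders, not values from this repo):

import OpenAI from "openai";

const openAI = new OpenAI();

async function tagAsJson(prompt: string) {
  // JSON mode: the model must return syntactically valid JSON, but no
  // particular schema is enforced (unlike the structured branch).
  return openAI.chat.completions.create({
    model: "gpt-4o-mini",
    messages: [{ role: "user", content: `${prompt}\nRespond in JSON.` }],
    response_format: { type: "json_object" },
  });
}
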
@@ -107,11 +122,16 @@ class OpenAIInferenceClient implements InferenceClient {
const chatCompletion = await this.openAI.chat.completions.create(
{
model: serverConfig.inference.imageModel,
- response_format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodResponseFormat(optsWithDefaults.schema, "schema")
- : undefined,
+ response_format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodResponseFormat(optsWithDefaults.schema, "schema")
+ : undefined,
+ json: { type: "json_object" },
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
messages: [
{
role: "user",
@@ -186,11 +206,16 @@ class OllamaInferenceClient implements InferenceClient {
}
const chatCompletion = await this.ollama.chat({
model: model,
- format:
- optsWithDefaults.schema &&
- serverConfig.inference.supportsStructuredOutput
- ? zodToJsonSchema(optsWithDefaults.schema)
- : undefined,
+ format: mapInferenceOutputSchema(
+ {
+ structured: optsWithDefaults.schema
+ ? zodToJsonSchema(optsWithDefaults.schema)
+ : undefined,
+ json: "json",
+ plain: undefined,
+ },
+ serverConfig.inference.outputSchema,
+ ),
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {
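
On the Ollama side, format accepts either the literal "json" (JSON mode) or a JSON schema object (structured outputs), which is what lets the same three-way mapping work there. A sketch under those assumptions; the host, model name, and schema below are placeholders:

import { Ollama } from "ollama";
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

const ollama = new Ollama({ host: "http://localhost:11434" });

async function tagText(text: string) {
  const schema = z.object({ tags: z.array(z.string()) });
  return ollama.chat({
    model: "llama3.1",
    messages: [{ role: "user", content: `Extract tags as JSON: ${text}` }],
    // "json" for plain JSON mode; a JSON schema object for structured output.
    format: zodToJsonSchema(schema),
    stream: false,
  });
}
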