author    Mohamed Bassem <me@mbassem.com>  2025-02-01 18:53:43 +0000
committer Mohamed Bassem <me@mbassem.com>  2025-02-01 18:53:43 +0000
commit    a698aeaaebbaca14202670dd1efbbf666e360b8a (patch)
tree      13749bceb0f8995b06e2b0bdbe52a8cc875a1d5a /packages/shared/inference.ts
parent    fd7011aff5dd8ffde0fb10990da238f7baf9a814 (diff)
download  karakeep-a698aeaaebbaca14202670dd1efbbf666e360b8a.tar.zst
fix: Fix missing handling for AbortSignal in inference client
Diffstat (limited to 'packages/shared/inference.ts')
-rw-r--r--  packages/shared/inference.ts | 81
1 file changed, 53 insertions(+), 28 deletions(-)
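
The inference clients accept an optional abortSignal in their inference options (the optsWithDefaults.abortSignal reads below), but before this commit neither client forwarded it to the underlying SDK, so aborting from the caller's side had no effect on the in-flight request. A minimal caller-side sketch of what the fix enables; the method name inferFromText and the client variable are assumptions, since the hunks below do not show the method signatures:

    // Hypothetical caller: cancel an in-flight inference request.
    const controller = new AbortController();
    const pending = client.inferFromText("Summarize this page", {
      abortSignal: controller.signal, // option name taken from the diff below
    });
    setTimeout(() => controller.abort(), 30_000); // give up after 30 seconds
    const result = await pending;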
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index f1bea9ff..e5ddf5ca 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -67,13 +67,18 @@ class OpenAIInferenceClient implements InferenceClient {
...defaultInferenceOptions,
..._opts,
};
- const chatCompletion = await this.openAI.chat.completions.create({
- messages: [{ role: "user", content: prompt }],
- model: serverConfig.inference.textModel,
- response_format: optsWithDefaults.json
- ? { type: "json_object" }
- : undefined,
- });
+ const chatCompletion = await this.openAI.chat.completions.create(
+ {
+ messages: [{ role: "user", content: prompt }],
+ model: serverConfig.inference.textModel,
+ response_format: optsWithDefaults.json
+ ? { type: "json_object" }
+ : undefined,
+ },
+ {
+ signal: optsWithDefaults.abortSignal,
+ },
+ );

const response = chatCompletion.choices[0].message.content;
if (!response) {
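
The OpenAI SDK takes per-request options as a second argument to create(): the request body stays in the first argument, and transport-level settings such as the AbortSignal go in the second, which is why this change re-indents the body rather than adding a field to it. A standalone sketch of the same two-argument call, with a placeholder model name:

    import OpenAI from "openai";

    const openAI = new OpenAI(); // reads OPENAI_API_KEY from the environment
    const controller = new AbortController();

    // Request body first, per-request options (including the signal) second.
    const chatCompletion = await openAI.chat.completions.create(
      { model: "gpt-4o-mini", messages: [{ role: "user", content: "Hello" }] },
      { signal: controller.signal },
    );

When the signal fires, the SDK rejects the pending promise instead of leaving the HTTP request running.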
@@ -92,28 +97,33 @@ class OpenAIInferenceClient implements InferenceClient {
...defaultInferenceOptions,
..._opts,
};
- const chatCompletion = await this.openAI.chat.completions.create({
- model: serverConfig.inference.imageModel,
- response_format: optsWithDefaults.json
- ? { type: "json_object" }
- : undefined,
- messages: [
- {
- role: "user",
- content: [
- { type: "text", text: prompt },
- {
- type: "image_url",
- image_url: {
- url: `data:${contentType};base64,${image}`,
- detail: "low",
+ const chatCompletion = await this.openAI.chat.completions.create(
+ {
+ model: serverConfig.inference.imageModel,
+ response_format: optsWithDefaults.json
+ ? { type: "json_object" }
+ : undefined,
+ messages: [
+ {
+ role: "user",
+ content: [
+ { type: "text", text: prompt },
+ {
+ type: "image_url",
+ image_url: {
+ url: `data:${contentType};base64,${image}`,
+ detail: "low",
+ },
+              },
-            },
- },
- ],
- },
- ],
- max_tokens: 2000,
- });
+ ],
+ },
+ ],
+ max_tokens: 2000,
+ },
+ {
+ signal: optsWithDefaults.abortSignal,
+ },
+ );

const response = chatCompletion.choices[0].message.content;
if (!response) {
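
The image path gets the same treatment: only the indentation of the body changes, and the signal rides in the new second argument. A caller can also derive a signal instead of wiring up a controller; this sketch assumes an inferFromImage(prompt, contentType, image, opts) method, since the hunk does not show the signature, and uses AbortSignal.timeout() (Node 17.3+):

    // Hypothetical caller: abort automatically after 30 seconds.
    const description = await client.inferFromImage(
      "Describe this image",
      "image/png",
      base64Png, // base64-encoded image data, as the data: URL above expects
      { abortSignal: AbortSignal.timeout(30_000) },
    );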
@@ -156,6 +166,14 @@ class OllamaInferenceClient implements InferenceClient {
...defaultInferenceOptions,
..._opts,
};
+
+ let newAbortSignal = undefined;
+ if (optsWithDefaults.abortSignal) {
+ newAbortSignal = AbortSignal.any([optsWithDefaults.abortSignal]);
+ newAbortSignal.onabort = () => {
+ this.ollama.abort();
+ };
+ }
const chatCompletion = await this.ollama.chat({
model: model,
format: optsWithDefaults.json ? "json" : undefined,
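
The Ollama client needs a different approach, since ollama-js does not accept a per-request signal; instead, this.ollama.abort() cancels the requests in flight on that client instance. Wrapping the caller's signal in AbortSignal.any() (Node 20.3+) yields a private dependent signal that aborts whenever the original does, so the onabort handler can be attached here and detached in the finally block further down without ever mutating the caller's signal. The same wrap-then-detach pattern in isolation, as a hypothetical helper that is not part of this codebase:

    // Bridge an external signal to a cleanup action, detachably.
    function bridgeAbort(signal: AbortSignal, onAbort: () => void): () => void {
      const derived = AbortSignal.any([signal]); // private dependent signal
      derived.onabort = onAbort; // safe: nothing else sees `derived`
      return () => {
        derived.onabort = null; // detach without touching `signal`
      };
    }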
@@ -182,6 +200,9 @@ class OllamaInferenceClient implements InferenceClient {
}
}
} catch (e) {
+ if (e instanceof Error && e.name === "AbortError") {
+ throw e;
+ }
// There seems to be a bug in ollama where you can get a successful response but still throw an error.
// Using stream + accumulating the response so far is a workaround.
// https://github.com/ollama/ollama-js/issues/72
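
This hunk re-throws a cancellation before the workaround below swallows it: an aborted request surfaces as an Error whose name is "AbortError", and without the early re-throw an abort would be treated as a partially successful response. A hypothetical caller-side sketch of telling cancellation apart from failure:

    // `client`, `prompt`, and `controller` as in the earlier sketches.
    try {
      await client.inferFromText(prompt, { abortSignal: controller.signal });
    } catch (e) {
      if (e instanceof Error && e.name === "AbortError") {
        // Expected after controller.abort(); not an inference failure.
      } else {
        throw e; // a real error from the model or the transport
      }
    }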
@@ -189,6 +210,10 @@ class OllamaInferenceClient implements InferenceClient {
logger.warn(
`Got an exception from ollama, will still attempt to deserialize the response we got so far: ${e}`,
);
+ } finally {
+ if (newAbortSignal) {
+ newAbortSignal.onabort = null;
+ }
}
return { response, totalTokens };
}
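
Clearing onabort in the finally block prevents a stale handler from firing after the request has settled: because this.ollama is shared by the client instance, a caller's signal that aborts later would otherwise still invoke this.ollama.abort() and cancel unrelated in-flight requests.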