author     Mohamed Bassem <me@mbassem.com>    2025-02-01 18:16:25 +0000
committer  Mohamed Bassem <me@mbassem.com>    2025-02-01 18:16:25 +0000
commit     fd7011aff5dd8ffde0fb10990da238f7baf9a814 (patch)
tree       99df3086a838ee33c40722d803c05c45a3a22ae3 /packages/shared/inference.ts
parent     0893446bed6cca753549ee8e3cf090f2fcf11d9d (diff)
download   karakeep-fd7011aff5dd8ffde0fb10990da238f7baf9a814.tar.zst
fix: Abort all IO when workers timeout instead of detaching. Fixes #742
Diffstat (limited to 'packages/shared/inference.ts')
-rw-r--r--  packages/shared/inference.ts  49
1 file changed, 37 insertions, 12 deletions
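
The substance of the change: InferenceOptions gains an optional abortSignal, every public inference method now takes a Partial<InferenceOptions>, and each implementation merges the caller's options over the defaults before use. A minimal sketch of that merge pattern follows, using the same names as the diff; the json default value and the standalone mergeInferenceOptions helper are illustrative assumptions, not part of the commit.

// Sketch of the defaults-merging pattern this commit applies in each method.
export interface InferenceOptions {
  json: boolean;
  abortSignal?: AbortSignal; // new in this commit
}

// Default value assumed for illustration; the real defaults live in inference.ts.
const defaultInferenceOptions: InferenceOptions = {
  json: true,
};

// Hypothetical helper: later spreads win, so caller-supplied fields
// (e.g. abortSignal) override the defaults without restating json.
function mergeInferenceOptions(
  opts: Partial<InferenceOptions>,
): InferenceOptions {
  return { ...defaultInferenceOptions, ...opts };
}
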
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 1573382f..f1bea9ff 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -15,6 +15,7 @@ export interface EmbeddingResponse {
export interface InferenceOptions {
json: boolean;
+ abortSignal?: AbortSignal;
}
const defaultInferenceOptions: InferenceOptions = {
@@ -24,13 +25,13 @@ const defaultInferenceOptions: InferenceOptions = {
export interface InferenceClient {
inferFromText(
prompt: string,
- opts: InferenceOptions,
+ opts: Partial<InferenceOptions>,
): Promise<InferenceResponse>;
inferFromImage(
prompt: string,
contentType: string,
image: string,
- opts: InferenceOptions,
+ opts: Partial<InferenceOptions>,
): Promise<InferenceResponse>;
generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}
@@ -60,12 +61,18 @@ class OpenAIInferenceClient implements InferenceClient {
async inferFromText(
prompt: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
const chatCompletion = await this.openAI.chat.completions.create({
messages: [{ role: "user", content: prompt }],
model: serverConfig.inference.textModel,
- response_format: opts.json ? { type: "json_object" } : undefined,
+ response_format: optsWithDefaults.json
+ ? { type: "json_object" }
+ : undefined,
});
const response = chatCompletion.choices[0].message.content;
@@ -79,11 +86,17 @@ class OpenAIInferenceClient implements InferenceClient {
prompt: string,
contentType: string,
image: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
const chatCompletion = await this.openAI.chat.completions.create({
model: serverConfig.inference.imageModel,
- response_format: opts.json ? { type: "json_object" } : undefined,
+ response_format: optsWithDefaults.json
+ ? { type: "json_object" }
+ : undefined,
messages: [
{
role: "user",
@@ -136,12 +149,16 @@ class OllamaInferenceClient implements InferenceClient {
async runModel(
model: string,
prompt: string,
+ _opts: InferenceOptions,
image?: string,
- opts: InferenceOptions = defaultInferenceOptions,
) {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
const chatCompletion = await this.ollama.chat({
model: model,
- format: opts.json ? "json" : undefined,
+ format: optsWithDefaults.json ? "json" : undefined,
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {
@@ -179,13 +196,17 @@ class OllamaInferenceClient implements InferenceClient {
async inferFromText(
prompt: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
return await this.runModel(
serverConfig.inference.textModel,
prompt,
+ optsWithDefaults,
undefined,
- opts,
);
}
@@ -193,13 +214,17 @@ class OllamaInferenceClient implements InferenceClient {
prompt: string,
_contentType: string,
image: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
return await this.runModel(
serverConfig.inference.imageModel,
prompt,
+ optsWithDefaults,
image,
- opts,
);
}
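
With abortSignal exposed on the options, a worker can bind its timeout directly to the inference request instead of detaching from it, which is what the commit message targets. A hypothetical caller-side sketch under that assumption follows; summarizeWithTimeout, the import path, and the timeout budget are illustrative, and forwarding the signal to the underlying HTTP clients is the responsibility of the implementations above.

// Hypothetical worker-side usage: abort the inference call when the
// worker's time budget elapses rather than letting the request run on.
// The import path is assumed; adjust to however the package is exported.
import type { InferenceClient, InferenceResponse } from "./inference";

async function summarizeWithTimeout(
  client: InferenceClient,
  prompt: string,
  timeoutMs: number,
): Promise<InferenceResponse> {
  const controller = new AbortController();
  const timer = setTimeout(() => controller.abort(), timeoutMs);
  try {
    return await client.inferFromText(prompt, {
      json: true,
      abortSignal: controller.signal,
    });
  } finally {
    // Always clear the timer so the process can exit cleanly.
    clearTimeout(timer);
  }
}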