| author | Mohamed Bassem <me@mbassem.com> | 2025-02-01 18:16:25 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2025-02-01 18:16:25 +0000 |
| commit | fd7011aff5dd8ffde0fb10990da238f7baf9a814 (patch) | |
| tree | 99df3086a838ee33c40722d803c05c45a3a22ae3 /packages/shared/inference.ts | |
| parent | 0893446bed6cca753549ee8e3cf090f2fcf11d9d (diff) | |
| download | karakeep-fd7011aff5dd8ffde0fb10990da238f7baf9a814.tar.zst | |
fix: Abort all IO when workers timeout instead of detaching. Fixes #742
Diffstat (limited to 'packages/shared/inference.ts')
| -rw-r--r-- | packages/shared/inference.ts | 49 |
1 file changed, 37 insertions(+), 12 deletions(-)
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 1573382f..f1bea9ff 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -15,6 +15,7 @@ export interface EmbeddingResponse {
 
 export interface InferenceOptions {
   json: boolean;
+  abortSignal?: AbortSignal;
 }
 
 const defaultInferenceOptions: InferenceOptions = {
@@ -24,13 +25,13 @@ const defaultInferenceOptions: InferenceOptions = {
 export interface InferenceClient {
   inferFromText(
     prompt: string,
-    opts: InferenceOptions,
+    opts: Partial<InferenceOptions>,
   ): Promise<InferenceResponse>;
   inferFromImage(
     prompt: string,
     contentType: string,
     image: string,
-    opts: InferenceOptions,
+    opts: Partial<InferenceOptions>,
   ): Promise<InferenceResponse>;
   generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
@@ -60,12 +61,18 @@ class OpenAIInferenceClient implements InferenceClient {
 
   async inferFromText(
     prompt: string,
-    opts: InferenceOptions = defaultInferenceOptions,
+    _opts: Partial<InferenceOptions>,
   ): Promise<InferenceResponse> {
+    const optsWithDefaults: InferenceOptions = {
+      ...defaultInferenceOptions,
+      ..._opts,
+    };
     const chatCompletion = await this.openAI.chat.completions.create({
       messages: [{ role: "user", content: prompt }],
       model: serverConfig.inference.textModel,
-      response_format: opts.json ? { type: "json_object" } : undefined,
+      response_format: optsWithDefaults.json
+        ? { type: "json_object" }
+        : undefined,
     });
 
     const response = chatCompletion.choices[0].message.content;
@@ -79,11 +86,17 @@ class OpenAIInferenceClient implements InferenceClient {
     prompt: string,
     contentType: string,
     image: string,
-    opts: InferenceOptions = defaultInferenceOptions,
+    _opts: Partial<InferenceOptions>,
   ): Promise<InferenceResponse> {
+    const optsWithDefaults: InferenceOptions = {
+      ...defaultInferenceOptions,
+      ..._opts,
+    };
     const chatCompletion = await this.openAI.chat.completions.create({
       model: serverConfig.inference.imageModel,
-      response_format: opts.json ? { type: "json_object" } : undefined,
+      response_format: optsWithDefaults.json
+        ? { type: "json_object" }
+        : undefined,
       messages: [
         {
           role: "user",
@@ -136,12 +149,16 @@ class OllamaInferenceClient implements InferenceClient {
   async runModel(
     model: string,
     prompt: string,
+    _opts: InferenceOptions,
     image?: string,
-    opts: InferenceOptions = defaultInferenceOptions,
   ) {
+    const optsWithDefaults: InferenceOptions = {
+      ...defaultInferenceOptions,
+      ..._opts,
+    };
     const chatCompletion = await this.ollama.chat({
       model: model,
-      format: opts.json ? "json" : undefined,
+      format: optsWithDefaults.json ? "json" : undefined,
       stream: true,
       keep_alive: serverConfig.inference.ollamaKeepAlive,
       options: {
@@ -179,13 +196,17 @@ class OllamaInferenceClient implements InferenceClient {
 
   async inferFromText(
     prompt: string,
-    opts: InferenceOptions = defaultInferenceOptions,
+    _opts: Partial<InferenceOptions>,
   ): Promise<InferenceResponse> {
+    const optsWithDefaults: InferenceOptions = {
+      ...defaultInferenceOptions,
+      ..._opts,
+    };
     return await this.runModel(
       serverConfig.inference.textModel,
       prompt,
+      optsWithDefaults,
       undefined,
-      opts,
     );
   }
 
@@ -193,13 +214,17 @@ class OllamaInferenceClient implements InferenceClient {
     prompt: string,
     _contentType: string,
     image: string,
-    opts: InferenceOptions = defaultInferenceOptions,
+    _opts: Partial<InferenceOptions>,
   ): Promise<InferenceResponse> {
+    const optsWithDefaults: InferenceOptions = {
+      ...defaultInferenceOptions,
+      ..._opts,
+    };
     return await this.runModel(
       serverConfig.inference.imageModel,
       prompt,
+      optsWithDefaults,
       image,
-      opts,
    );
  }
