| author | Mohamed Bassem <me@mbassem.com> | 2024-10-27 00:12:11 +0000 |
|---|---|---|
| committer | Mohamed Bassem <me@mbassem.com> | 2024-10-27 00:12:11 +0000 |
| commit | 731d2dfbea39aa140ccb6d2d2cabd49186320299 (patch) | |
| tree | 2311d04b5dc61102c63d4e4ec9c7c97b359faad6 /packages/shared/inference.ts | |
| parent | 3e727f7ba3ad157ca1ccc6100711266cae1bde23 (diff) | |
| download | karakeep-731d2dfbea39aa140ccb6d2d2cabd49186320299.tar.zst | |
feature: Add a summarize with AI button for links
Diffstat (limited to 'packages/shared/inference.ts')
| -rw-r--r-- | packages/shared/inference.ts | 47 |
1 file changed, 39 insertions, 8 deletions
```diff
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index f34c2880..e09076db 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,12 +9,24 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface InferenceOptions {
+  json: boolean;
+}
+
+const defaultInferenceOptions: InferenceOptions = {
+  json: true,
+};
+
 export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromText(
+    prompt: string,
+    opts: InferenceOptions,
+  ): Promise<InferenceResponse>;
   inferFromImage(
     prompt: string,
     contentType: string,
     image: string,
+    opts: InferenceOptions,
   ): Promise<InferenceResponse>;
 }
 
@@ -41,11 +53,14 @@ class OpenAIInferenceClient implements InferenceClient {
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async inferFromText(
+    prompt: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ): Promise<InferenceResponse> {
     const chatCompletion = await this.openAI.chat.completions.create({
       messages: [{ role: "user", content: prompt }],
       model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
+      response_format: opts.json ? { type: "json_object" } : undefined,
     });
 
     const response = chatCompletion.choices[0].message.content;
@@ -59,10 +74,11 @@ class OpenAIInferenceClient implements InferenceClient {
     prompt: string,
     contentType: string,
     image: string,
+    opts: InferenceOptions = defaultInferenceOptions,
   ): Promise<InferenceResponse> {
     const chatCompletion = await this.openAI.chat.completions.create({
       model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
+      response_format: opts.json ? { type: "json_object" } : undefined,
       messages: [
         {
           role: "user",
@@ -98,10 +114,15 @@ class OllamaInferenceClient implements InferenceClient {
     });
   }
 
-  async runModel(model: string, prompt: string, image?: string) {
+  async runModel(
+    model: string,
+    prompt: string,
+    image?: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ) {
     const chatCompletion = await this.ollama.chat({
       model: model,
-      format: "json",
+      format: opts.json ? "json" : undefined,
       stream: true,
       keep_alive: serverConfig.inference.ollamaKeepAlive,
       options: {
@@ -137,19 +158,29 @@ class OllamaInferenceClient implements InferenceClient {
     return { response, totalTokens };
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
+  async inferFromText(
+    prompt: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ): Promise<InferenceResponse> {
+    return await this.runModel(
+      serverConfig.inference.textModel,
+      prompt,
+      undefined,
+      opts,
+    );
   }
 
   async inferFromImage(
     prompt: string,
     _contentType: string,
     image: string,
+    opts: InferenceOptions = defaultInferenceOptions,
  ): Promise<InferenceResponse> {
     return await this.runModel(
       serverConfig.inference.imageModel,
       prompt,
       image,
+      opts,
     );
   }
 }
```
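A minimal sketch of how a caller could use the new `opts` parameter, for example the summarize-with-AI feature this commit introduces. The `summarizeLink` helper, the prompt text, and the import path are illustrative and not part of this diff; passing `{ json: false }` disables JSON mode in both the OpenAI and Ollama clients so the model can return free-form prose, while `{ json: true }` (the default in the implementations) keeps the existing tag-inference behavior.

```typescript
// Illustrative usage only; InferenceClient lives in packages/shared/inference.ts
// (the relative import path here is an assumption).
import type { InferenceClient } from "./inference";

// Hypothetical helper: a summary is free-form prose, so JSON mode is turned off.
async function summarizeLink(
  client: InferenceClient,
  pageText: string,
): Promise<string> {
  const prompt = `Summarize the following page in a few sentences:\n\n${pageText}`;

  // { json: false } makes the OpenAI client omit response_format and the
  // Ollama client omit format: "json", so the response is plain text.
  const result = await client.inferFromText(prompt, { json: false });

  // Assumes InferenceResponse.response is a string, as returned by both clients.
  return result.response;
}
```

Note that the interface declares `opts` as required while the concrete classes default it to `defaultInferenceOptions`, so call sites typed against `InferenceClient` must pass an options object explicitly.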
