From 731d2dfbea39aa140ccb6d2d2cabd49186320299 Mon Sep 17 00:00:00 2001
From: Mohamed Bassem
Date: Sun, 27 Oct 2024 00:12:11 +0000
Subject: feature: Add a summarize with AI button for links

---
 packages/shared/inference.ts | 47 ++++++++++++++++++++++++++++++++++++--------
 1 file changed, 39 insertions(+), 8 deletions(-)

(limited to 'packages/shared/inference.ts')

diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index f34c2880..e09076db 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,12 +9,24 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface InferenceOptions {
+  json: boolean;
+}
+
+const defaultInferenceOptions: InferenceOptions = {
+  json: true,
+};
+
 export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromText(
+    prompt: string,
+    opts: InferenceOptions,
+  ): Promise<InferenceResponse>;
   inferFromImage(
     prompt: string,
     contentType: string,
     image: string,
+    opts: InferenceOptions,
   ): Promise<InferenceResponse>;
 }
 
@@ -41,11 +53,14 @@
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async inferFromText(
+    prompt: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ): Promise<InferenceResponse> {
     const chatCompletion = await this.openAI.chat.completions.create({
       messages: [{ role: "user", content: prompt }],
       model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
+      response_format: opts.json ? { type: "json_object" } : undefined,
     });
 
     const response = chatCompletion.choices[0].message.content;
@@ -59,10 +74,11 @@
     prompt: string,
     contentType: string,
     image: string,
+    opts: InferenceOptions = defaultInferenceOptions,
   ): Promise<InferenceResponse> {
     const chatCompletion = await this.openAI.chat.completions.create({
       model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
+      response_format: opts.json ? { type: "json_object" } : undefined,
       messages: [
         {
           role: "user",
@@ -98,10 +114,15 @@ class OllamaInferenceClient implements InferenceClient {
     });
   }
 
-  async runModel(model: string, prompt: string, image?: string) {
+  async runModel(
+    model: string,
+    prompt: string,
+    image?: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ) {
     const chatCompletion = await this.ollama.chat({
       model: model,
-      format: "json",
+      format: opts.json ? "json" : undefined,
       stream: true,
       keep_alive: serverConfig.inference.ollamaKeepAlive,
       options: {
@@ -137,19 +158,29 @@
     return { response, totalTokens };
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
+  async inferFromText(
+    prompt: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ): Promise<InferenceResponse> {
+    return await this.runModel(
+      serverConfig.inference.textModel,
+      prompt,
+      undefined,
+      opts,
+    );
   }
 
   async inferFromImage(
     prompt: string,
     _contentType: string,
     image: string,
+    opts: InferenceOptions = defaultInferenceOptions,
   ): Promise<InferenceResponse> {
     return await this.runModel(
       serverConfig.inference.imageModel,
       prompt,
       image,
+      opts,
     );
   }
 }
--
cgit v1.2.3-70-g09d2
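
A minimal usage sketch, not part of the patch: with JSON mode now opt-out, the new summarize-with-AI caller can request free-form prose by passing { json: false }, while existing tagging call sites keep the old behavior through defaultInferenceOptions. The summarizeLink helper, its prompt, and the @hoarder/shared/inference import path are illustrative assumptions; only InferenceClient, InferenceOptions, and the response field of InferenceResponse come from the patched module.

import { InferenceClient } from "@hoarder/shared/inference";

// Hypothetical caller: request a plain-text summary from whichever
// inference backend (OpenAI or Ollama) the server is configured with.
async function summarizeLink(
  client: InferenceClient,
  pageContent: string,
): Promise<string> {
  const { response } = await client.inferFromText(
    `Summarize the following page in 3-4 sentences:\n\n${pageContent}`,
    // json: false disables response_format / format: "json" above, since a
    // summary is free-form prose rather than the structured output tagging uses.
    { json: false },
  );
  return response;
}

Defaulting opts to { json: true } in every implementation keeps the change backward compatible: callers that never pass an options object still get the JSON-constrained responses they were written against.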