Diffstat (limited to 'packages/shared/inference.ts')
| -rw-r--r-- | packages/shared/inference.ts | 47 |
1 file changed, 39 insertions(+), 8 deletions(-)
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index f34c2880..e09076db 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,12 +9,24 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface InferenceOptions {
+  json: boolean;
+}
+
+const defaultInferenceOptions: InferenceOptions = {
+  json: true,
+};
+
 export interface InferenceClient {
-  inferFromText(prompt: string): Promise<InferenceResponse>;
+  inferFromText(
+    prompt: string,
+    opts: InferenceOptions,
+  ): Promise<InferenceResponse>;
   inferFromImage(
     prompt: string,
     contentType: string,
     image: string,
+    opts: InferenceOptions,
   ): Promise<InferenceResponse>;
 }
 
@@ -41,11 +53,14 @@ class OpenAIInferenceClient implements InferenceClient {
     });
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
+  async inferFromText(
+    prompt: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ): Promise<InferenceResponse> {
     const chatCompletion = await this.openAI.chat.completions.create({
       messages: [{ role: "user", content: prompt }],
       model: serverConfig.inference.textModel,
-      response_format: { type: "json_object" },
+      response_format: opts.json ? { type: "json_object" } : undefined,
     });
 
     const response = chatCompletion.choices[0].message.content;
@@ -59,10 +74,11 @@ class OpenAIInferenceClient implements InferenceClient {
     prompt: string,
     contentType: string,
     image: string,
+    opts: InferenceOptions = defaultInferenceOptions,
   ): Promise<InferenceResponse> {
     const chatCompletion = await this.openAI.chat.completions.create({
       model: serverConfig.inference.imageModel,
-      response_format: { type: "json_object" },
+      response_format: opts.json ? { type: "json_object" } : undefined,
       messages: [
         {
           role: "user",
@@ -98,10 +114,15 @@ class OllamaInferenceClient implements InferenceClient {
     });
   }
 
-  async runModel(model: string, prompt: string, image?: string) {
+  async runModel(
+    model: string,
+    prompt: string,
+    image?: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ) {
     const chatCompletion = await this.ollama.chat({
       model: model,
-      format: "json",
+      format: opts.json ? "json" : undefined,
       stream: true,
       keep_alive: serverConfig.inference.ollamaKeepAlive,
       options: {
@@ -137,19 +158,29 @@ class OllamaInferenceClient implements InferenceClient {
     return { response, totalTokens };
   }
 
-  async inferFromText(prompt: string): Promise<InferenceResponse> {
-    return await this.runModel(serverConfig.inference.textModel, prompt);
+  async inferFromText(
+    prompt: string,
+    opts: InferenceOptions = defaultInferenceOptions,
+  ): Promise<InferenceResponse> {
+    return await this.runModel(
+      serverConfig.inference.textModel,
+      prompt,
+      undefined,
+      opts,
+    );
   }
 
   async inferFromImage(
     prompt: string,
     _contentType: string,
     image: string,
+    opts: InferenceOptions = defaultInferenceOptions,
   ): Promise<InferenceResponse> {
     return await this.runModel(
       serverConfig.inference.imageModel,
       prompt,
       image,
+      opts,
     );
   }
 }
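The change threads an InferenceOptions object (currently just a json flag) through both the OpenAI and Ollama clients, so callers can opt out of forced JSON output. A minimal sketch of how a caller might use this, assuming a client instance is obtained elsewhere; the import path and the getClient helper below are illustrative, not part of this diff:

```ts
import { InferenceClient, InferenceOptions } from "packages/shared/inference";

// Hypothetical accessor: stands in for however the rest of the codebase
// constructs an InferenceClient (OpenAI- or Ollama-backed).
declare function getClient(): InferenceClient;

async function summarize(text: string): Promise<string> {
  const client = getClient();

  // On the InferenceClient interface the opts argument is required; only the
  // concrete client classes default it to { json: true }.
  const opts: InferenceOptions = { json: false };

  // With json disabled, no response_format (OpenAI) or format (Ollama)
  // constraint is sent, so the model may return plain prose.
  const res = await client.inferFromText(`Summarize:\n${text}`, opts);
  return res.response;
}
```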
