import { Ollama } from "ollama";
import OpenAI from "openai";

import serverConfig from "@hoarder/shared/config";

export interface InferenceResponse {
  response: string;
  totalTokens: number | undefined;
}

export interface InferenceClient {
  inferFromText(prompt: string): Promise<InferenceResponse>;
  inferFromImage(
    prompt: string,
    contentType: string,
    image: string,
  ): Promise<InferenceResponse>;
}

export class InferenceClientFactory {
  // Pick the backend from server config: OpenAI takes precedence when an API
  // key is set, otherwise fall back to Ollama; null means inference is disabled.
  static build(): InferenceClient | null {
    if (serverConfig.inference.openAIApiKey) {
      return new OpenAIInferenceClient();
    }
    if (serverConfig.inference.ollamaBaseUrl) {
      return new OllamaInferenceClient();
    }
    return null;
  }
}
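
// Minimal usage sketch (hypothetical caller; the prompt variable and error
// handling are assumptions, not part of this module):
//
//   const client = InferenceClientFactory.build();
//   if (!client) {
//     return; // neither OpenAI nor Ollama is configured
//   }
//   const { response, totalTokens } = await client.inferFromText(prompt);
//   const parsed = JSON.parse(response); // both backends are asked for JSON output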

class OpenAIInferenceClient implements InferenceClient {
  openAI: OpenAI;

  constructor() {
    this.openAI = new OpenAI({
      apiKey: serverConfig.inference.openAIApiKey,
      baseURL: serverConfig.inference.openAIBaseUrl,
    });
  }

  async inferFromText(prompt: string): Promise<InferenceResponse> {
    const chatCompletion = await this.openAI.chat.completions.create({
      messages: [{ role: "system", content: prompt }],
      model: serverConfig.inference.textModel,
      // Ask for a JSON object so callers can parse the response directly.
      response_format: { type: "json_object" },
    });

    const response = chatCompletion.choices[0].message.content;
    if (!response) {
      throw new Error(`Got no message content from OpenAI`);
    }
    return { response, totalTokens: chatCompletion.usage?.total_tokens };
  }

  async inferFromImage(
    prompt: string,
    contentType: string,
    image: string,
  ): Promise<InferenceResponse> {
    const chatCompletion = await this.openAI.chat.completions.create({
      model: serverConfig.inference.imageModel,
      messages: [
        {
          role: "user",
          content: [
            { type: "text", text: prompt },
            {
              type: "image_url",
              image_url: {
                // Inline the image as a base64 data URL; "low" detail keeps
                // token usage down.
                url: `data:${contentType};base64,${image}`,
                detail: "low",
              },
            },
          ],
        },
      ],
      max_tokens: 2000,
    });

    const response = chatCompletion.choices[0].message.content;
    if (!response) {
      throw new Error(`Got no message content from OpenAI`);
    }
    return { response, totalTokens: chatCompletion.usage?.total_tokens };
  }
}

class OllamaInferenceClient implements InferenceClient {
  ollama: Ollama;

  constructor() {
    this.ollama = new Ollama({
      host: serverConfig.inference.ollamaBaseUrl,
    });
  }

  async inferFromText(prompt: string): Promise<InferenceResponse> {
    const chatCompletion = await this.ollama.chat({
      model: serverConfig.inference.textModel,
      format: "json",
      messages: [{ role: "system", content: prompt }],
    });

    const response = chatCompletion.message.content;
    // eval_count is Ollama's count of generated tokens.
    return { response, totalTokens: chatCompletion.eval_count };
  }

  async inferFromImage(
    prompt: string,
    // Ollama accepts raw base64 images, so the content type is not needed.
    _contentType: string,
    image: string,
  ): Promise<InferenceResponse> {
    const chatCompletion = await this.ollama.chat({
      model: serverConfig.inference.imageModel,
      format: "json",
      messages: [{ role: "user", content: prompt, images: [`${image}`] }],
    });

    const response = chatCompletion.message.content;
    return { response, totalTokens: chatCompletion.eval_count };
  }
}