Diffstat (limited to 'packages/shared/inference.ts')
 packages/shared/inference.ts | 47 +++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 39 insertions(+), 8 deletions(-)
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index f34c2880..e09076db 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,12 +9,24 @@ export interface InferenceResponse {
totalTokens: number | undefined;
}
+export interface InferenceOptions {
+ json: boolean;
+}
+
+const defaultInferenceOptions: InferenceOptions = {
+ json: true,
+};
+
export interface InferenceClient {
- inferFromText(prompt: string): Promise<InferenceResponse>;
+ inferFromText(
+ prompt: string,
+ opts: InferenceOptions,
+ ): Promise<InferenceResponse>;
inferFromImage(
prompt: string,
contentType: string,
image: string,
+ opts: InferenceOptions,
): Promise<InferenceResponse>;
}
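
The first hunk threads a required opts: InferenceOptions parameter through the InferenceClient interface, so callers can opt out of JSON-constrained responses. A minimal sketch of a caller typed against that interface (the client variable and the prompt are illustrative, not part of this change):

  // Hypothetical caller: the interface makes `opts` required, so the
  // options object is passed explicitly rather than relying on a default.
  const summary = await client.inferFromText(
    "Summarize this bookmark in one sentence.",
    { json: false }, // free-form text; no JSON response_format/format constraint
  );
  console.log(summary.response, summary.totalTokens);
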
@@ -41,11 +53,14 @@ class OpenAIInferenceClient implements InferenceClient {
});
}
- async inferFromText(prompt: string): Promise<InferenceResponse> {
+ async inferFromText(
+ prompt: string,
+ opts: InferenceOptions = defaultInferenceOptions,
+ ): Promise<InferenceResponse> {
const chatCompletion = await this.openAI.chat.completions.create({
messages: [{ role: "user", content: prompt }],
model: serverConfig.inference.textModel,
- response_format: { type: "json_object" },
+ response_format: opts.json ? { type: "json_object" } : undefined,
});
const response = chatCompletion.choices[0].message.content;
@@ -59,10 +74,11 @@ class OpenAIInferenceClient implements InferenceClient {
prompt: string,
contentType: string,
image: string,
+ opts: InferenceOptions = defaultInferenceOptions,
): Promise<InferenceResponse> {
const chatCompletion = await this.openAI.chat.completions.create({
model: serverConfig.inference.imageModel,
- response_format: { type: "json_object" },
+ response_format: opts.json ? { type: "json_object" } : undefined,
messages: [
{
role: "user",
@@ -98,10 +114,15 @@ class OllamaInferenceClient implements InferenceClient {
});
}
- async runModel(model: string, prompt: string, image?: string) {
+ async runModel(
+ model: string,
+ prompt: string,
+ image?: string,
+ opts: InferenceOptions = defaultInferenceOptions,
+ ) {
const chatCompletion = await this.ollama.chat({
model: model,
- format: "json",
+ format: opts.json ? "json" : undefined,
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {
@@ -137,19 +158,29 @@ class OllamaInferenceClient implements InferenceClient {
return { response, totalTokens };
}
- async inferFromText(prompt: string): Promise<InferenceResponse> {
- return await this.runModel(serverConfig.inference.textModel, prompt);
+ async inferFromText(
+ prompt: string,
+ opts: InferenceOptions = defaultInferenceOptions,
+ ): Promise<InferenceResponse> {
+ return await this.runModel(
+ serverConfig.inference.textModel,
+ prompt,
+ undefined,
+ opts,
+ );
}
async inferFromImage(
prompt: string,
_contentType: string,
image: string,
+ opts: InferenceOptions = defaultInferenceOptions,
): Promise<InferenceResponse> {
return await this.runModel(
serverConfig.inference.imageModel,
prompt,
image,
+ opts,
);
}
}
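
Both implementations default opts to defaultInferenceOptions (json: true), so direct calls that omit the argument keep the previous JSON-only behavior, while code typed against the InferenceClient interface now passes the options explicitly. A hedged usage sketch of the image path (the base64Image string and prompts are placeholders, not part of this change):

  // Tag extraction keeps the JSON object response format.
  const tags = await client.inferFromImage(
    "Return the tags for this image as a JSON object.",
    "image/jpeg",
    base64Image,
    { json: true },
  );

  // A caption request opts out and gets free-form text back.
  const caption = await client.inferFromImage(
    "Describe this image in one sentence.",
    "image/jpeg",
    base64Image,
    { json: false },
  );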