Diffstat (limited to 'packages/shared')
-rw-r--r--  packages/shared/inference.ts  49
-rw-r--r--  packages/shared/package.json   2
2 files changed, 38 insertions, 13 deletions
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 1573382f..f1bea9ff 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -15,6 +15,7 @@ export interface EmbeddingResponse {
export interface InferenceOptions {
json: boolean;
+ abortSignal?: AbortSignal;
}
const defaultInferenceOptions: InferenceOptions = {
@@ -24,13 +25,13 @@ const defaultInferenceOptions: InferenceOptions = {
export interface InferenceClient {
inferFromText(
prompt: string,
- opts: InferenceOptions,
+ opts: Partial<InferenceOptions>,
): Promise<InferenceResponse>;
inferFromImage(
prompt: string,
contentType: string,
image: string,
- opts: InferenceOptions,
+ opts: Partial<InferenceOptions>,
): Promise<InferenceResponse>;
generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}
@@ -60,12 +61,18 @@ class OpenAIInferenceClient implements InferenceClient {
async inferFromText(
prompt: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
const chatCompletion = await this.openAI.chat.completions.create({
messages: [{ role: "user", content: prompt }],
model: serverConfig.inference.textModel,
- response_format: opts.json ? { type: "json_object" } : undefined,
+ response_format: optsWithDefaults.json
+ ? { type: "json_object" }
+ : undefined,
});
const response = chatCompletion.choices[0].message.content;
@@ -79,11 +86,17 @@ class OpenAIInferenceClient implements InferenceClient {
prompt: string,
contentType: string,
image: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
const chatCompletion = await this.openAI.chat.completions.create({
model: serverConfig.inference.imageModel,
- response_format: opts.json ? { type: "json_object" } : undefined,
+ response_format: optsWithDefaults.json
+ ? { type: "json_object" }
+ : undefined,
messages: [
{
role: "user",
@@ -136,12 +149,16 @@ class OllamaInferenceClient implements InferenceClient {
async runModel(
model: string,
prompt: string,
+ _opts: InferenceOptions,
image?: string,
- opts: InferenceOptions = defaultInferenceOptions,
) {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
const chatCompletion = await this.ollama.chat({
model: model,
- format: opts.json ? "json" : undefined,
+ format: optsWithDefaults.json ? "json" : undefined,
stream: true,
keep_alive: serverConfig.inference.ollamaKeepAlive,
options: {
@@ -179,13 +196,17 @@ class OllamaInferenceClient implements InferenceClient {
async inferFromText(
prompt: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
return await this.runModel(
serverConfig.inference.textModel,
prompt,
+ optsWithDefaults,
undefined,
- opts,
);
}
@@ -193,13 +214,17 @@ class OllamaInferenceClient implements InferenceClient {
prompt: string,
_contentType: string,
image: string,
- opts: InferenceOptions = defaultInferenceOptions,
+ _opts: Partial<InferenceOptions>,
): Promise<InferenceResponse> {
+ const optsWithDefaults: InferenceOptions = {
+ ...defaultInferenceOptions,
+ ..._opts,
+ };
return await this.runModel(
serverConfig.inference.imageModel,
prompt,
+ optsWithDefaults,
image,
- opts,
);
}
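Note: the hunks above add abortSignal to InferenceOptions but do not yet forward it to the underlying SDKs. As a minimal sketch only (not part of this commit), the OpenAI client could consume the merged option via the openai v4 Node SDK's per-request options argument, which accepts an AbortSignal:

// Sketch, assuming the optsWithDefaults merge introduced in this diff:
const chatCompletion = await this.openAI.chat.completions.create(
  {
    messages: [{ role: "user", content: prompt }],
    model: serverConfig.inference.textModel,
    response_format: optsWithDefaults.json ? { type: "json_object" } : undefined,
  },
  // Per-request options: the signal aborts the in-flight HTTP request.
  { signal: optsWithDefaults.abortSignal },
);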
diff --git a/packages/shared/package.json b/packages/shared/package.json
index 93d5495a..ecb16013 100644
--- a/packages/shared/package.json
+++ b/packages/shared/package.json
@@ -6,7 +6,7 @@
"type": "module",
"dependencies": {
"glob": "^11.0.0",
- "liteque": "^0.3.0",
+ "liteque": "^0.3.2",
"meilisearch": "^0.37.0",
"ollama": "^0.5.9",
"openai": "^4.67.1",