Diffstat (limited to 'packages/shared/inference.ts')
 packages/shared/inference.ts | 32 ++++++++++++++++++++++++++++++++
 1 file changed, 32 insertions(+), 0 deletions(-)
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 7cb88819..1573382f 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,6 +9,10 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface EmbeddingResponse {
+  embeddings: number[][];
+}
+
 export interface InferenceOptions {
   json: boolean;
 }
@@ -28,6 +32,7 @@ export interface InferenceClient {
     image: string,
     opts: InferenceOptions,
   ): Promise<InferenceResponse>;
+  generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
 
 export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
     }
     return { response, totalTokens: chatCompletion.usage?.total_tokens };
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const model = serverConfig.embedding.textModel;
+    const embedResponse = await this.openAI.embeddings.create({
+      model: model,
+      input: inputs,
+    });
+    const embedding2D: number[][] = embedResponse.data.map(
+      (embedding: OpenAI.Embedding) => embedding.embedding,
+    );
+    return { embeddings: embedding2D };
+  }
 }
 
 class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
       opts,
     );
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const embedding = await this.ollama.embed({
+      model: serverConfig.embedding.textModel,
+      input: inputs,
+      // Truncate the input to fit into the model's max token limit,
+      // in the future we want to add a way to split the input into multiple parts.
+      truncate: true,
+    });
+    return { embeddings: embedding.embeddings };
+  }
 }
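
For context, a minimal sketch of how a caller might use the new generateEmbeddingFromText method. This is illustrative only and not part of the diff: it assumes InferenceClientFactory exposes a build() helper that returns a configured InferenceClient (or null when no provider is set up), and the import path is hypothetical.

// Illustrative usage sketch (not part of this diff). The import path and the
// InferenceClientFactory.build() helper are assumptions about the surrounding code.
import { InferenceClientFactory } from "@hoarder/shared/inference";

async function embedChunks(chunks: string[]): Promise<number[][]> {
  const client = InferenceClientFactory.build();
  if (!client) {
    throw new Error("No inference client is configured");
  }
  // EmbeddingResponse.embeddings holds one vector per input string, in order.
  const { embeddings } = await client.generateEmbeddingFromText(chunks);
  return embeddings;
}

Both implementations read the model name from serverConfig.embedding.textModel, so callers only pass the raw input strings; the Ollama client additionally truncates inputs that exceed the model's token limit, as noted in its inline comment.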