path: root/packages/shared/inference.ts
author     Mohammed Farghal <mohamed@farghal.com>    2024-12-29 19:20:59 +0100
committer  Mohamed Bassem <me@mbassem.com>           2024-12-29 18:27:17 +0000
commit     c89b0c54418d6b739c765162ca180c8d154a6af8 (patch)
tree       df2d52535b9f146bda73aa45e09c6e245909136c  /packages/shared/inference.ts
parent     225d855e6c239249b7e6ea3131d704642699142f (diff)
download   karakeep-c89b0c54418d6b739c765162ca180c8d154a6af8.tar.zst
feat: Add support for embeddings in the inference interface (#403)
* support embeddings generation in inference.ts (cherry picked from commit 9ae8773ad13ed87af8f72f167bdd56e02ea66f15)
* make AI worker generate embeddings for text bookmarks
* fix unintentional change -- inference image model
* support embeddings for PDF bookmarks
* Upgrade drizzle-kit. The existing version does not work with the upgraded version of drizzle-orm. I removed the "driver" key to match the new schema of the Config. Quoting from their Config:
  * `driver` - optional param that is responsible for explicitly providing a driver to use when accessing a database
  * *Possible values*: `aws-data-api`, `d1-http`, `expo`, `turso`, `pglite`
  * If you don't use AWS Data API, D1, Turso, or Expo, you don't need this driver. You can check a driver strategy choice here: https://orm.
* fix formatting and lint
* add comments about truncating content
* Revert "Upgrade drizzle-kit". This reverts commit 08a02c8df4ea403de65986ed1265940c6c994a20.
* revert keep-alive field in Ollama
* change the interface to accept multiple inputs
* docs

---------

Co-authored-by: Mohamed Bassem <me@mbassem.com>
Diffstat (limited to 'packages/shared/inference.ts')
-rw-r--r--  packages/shared/inference.ts  32
1 file changed, 32 insertions(+), 0 deletions(-)
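For orientation before the diff, here is a minimal usage sketch of the new embeddings API in TypeScript. The relative import and the `build()` factory method are assumptions made for illustration; only `generateEmbeddingFromText` and `EmbeddingResponse` come from this commit.

// Hypothetical caller of the new API; InferenceClientFactory.build() is an
// assumed factory method returning a configured InferenceClient or null.
import { InferenceClientFactory } from "./inference";

async function embedBookmarkChunks(chunks: string[]): Promise<number[][]> {
  const client = InferenceClientFactory.build();
  if (!client) {
    throw new Error("No inference client configured");
  }
  // The interface accepts multiple inputs and returns one vector per input.
  const { embeddings } = await client.generateEmbeddingFromText(chunks);
  return embeddings;
}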
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 7cb88819..1573382f 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,6 +9,10 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface EmbeddingResponse {
+  embeddings: number[][];
+}
+
 export interface InferenceOptions {
   json: boolean;
 }
@@ -28,6 +32,7 @@ export interface InferenceClient {
     image: string,
     opts: InferenceOptions,
   ): Promise<InferenceResponse>;
+  generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
 
 export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
     }
     return { response, totalTokens: chatCompletion.usage?.total_tokens };
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const model = serverConfig.embedding.textModel;
+    const embedResponse = await this.openAI.embeddings.create({
+      model: model,
+      input: inputs,
+    });
+    const embedding2D: number[][] = embedResponse.data.map(
+      (embedding: OpenAI.Embedding) => embedding.embedding,
+    );
+    return { embeddings: embedding2D };
+  }
 }
 
 class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
       opts,
     );
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const embedding = await this.ollama.embed({
+      model: serverConfig.embedding.textModel,
+      input: inputs,
+      // Truncate the input to fit into the model's max token limit;
+      // in the future we want to add a way to split the input into multiple parts.
+      truncate: true,
+    });
+    return { embeddings: embedding.embeddings };
+  }
 }
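A note on the response shape: `embeddings` is parallel to `inputs`, i.e. `embeddings[i]` is the vector for `inputs[i]`. Also, only the Ollama path sets `truncate: true`; over-limit inputs on the OpenAI path are left to the API, which is why the commit message flags input splitting as future work. Below is an illustrative, self-contained helper (not part of the commit) for scoring two returned vectors:

// Illustrative helper, not part of this commit: cosine similarity between two
// embedding vectors of equal length, as returned in EmbeddingResponse.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}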