|  |  |  |
|---|---|---|
| author | Mohammed Farghal <mohamed@farghal.com> | 2024-12-29 19:20:59 +0100 |
| committer | Mohamed Bassem <me@mbassem.com> | 2024-12-29 18:27:17 +0000 |
| commit | c89b0c54418d6b739c765162ca180c8d154a6af8 (patch) | |
| tree | df2d52535b9f146bda73aa45e09c6e245909136c /packages/shared/inference.ts | |
| parent | 225d855e6c239249b7e6ea3131d704642699142f (diff) | |
| download | karakeep-c89b0c54418d6b739c765162ca180c8d154a6af8.tar.zst | |
feat: Add support for embeddings in the inference interface (#403)
* support embeddings generation in inference.ts
(cherry picked from commit 9ae8773ad13ed87af8f72f167bdd56e02ea66f15)
* make AI worker generate embeddings for text bookmark
* fix unintentional change -- inference image model
* support embeddings for PDF bookmarks
* Upgrade drizzle-kit
Existing version is not working with the upgraded version of drizzle-orm.
I removed the "driver" to match the new schema of the Config.
Quoting from their Config:
* `driver` - optional param that is responsible for explicitly providing a driver to use when accessing a database
* *Possible values*: `aws-data-api`, `d1-http`, `expo`, `turso`, `pglite`
* If you don't use AWS Data API, D1, Turso or Expo, you don't need this driver. You can check a driver strategy choice here: https://orm. (see the config sketch below the commit message)
* fix formatting and lint
* add comments about truncate content
* Revert "Upgrade drizzle-kit"
This reverts commit 08a02c8df4ea403de65986ed1265940c6c994a20.
* revert keep alive field in Ollama
* change the interface to accept multiple inputs
* docs
---------
Co-authored-by: Mohamed Bassem <me@mbassem.com>
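To make the drizzle-kit note above concrete (that upgrade was later reverted in this same PR), here is a minimal `drizzle.config.ts` sketch under assumed file paths, showing the newer Config shape in which a `dialect` field suffices and no `driver` entry is needed for a plain SQLite database:

```typescript
// Sketch only -- the schema/out/url paths are assumptions, not taken from this repo.
import { defineConfig } from "drizzle-kit";

export default defineConfig({
  // No `driver` needed unless using aws-data-api, d1-http, expo, turso, or pglite.
  dialect: "sqlite",
  schema: "./schema.ts",
  out: "./drizzle",
  dbCredentials: {
    url: "./db.db",
  },
});
```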
Diffstat (limited to 'packages/shared/inference.ts')
| -rw-r--r-- | packages/shared/inference.ts | 32 |
1 file changed, 32 insertions, 0 deletions
```diff
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 7cb88819..1573382f 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,6 +9,10 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface EmbeddingResponse {
+  embeddings: number[][];
+}
+
 export interface InferenceOptions {
   json: boolean;
 }
@@ -28,6 +32,7 @@ export interface InferenceClient {
     image: string,
     opts: InferenceOptions,
   ): Promise<InferenceResponse>;
+  generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
 
 export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
     }
     return { response, totalTokens: chatCompletion.usage?.total_tokens };
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const model = serverConfig.embedding.textModel;
+    const embedResponse = await this.openAI.embeddings.create({
+      model: model,
+      input: inputs,
+    });
+    const embedding2D: number[][] = embedResponse.data.map(
+      (embedding: OpenAI.Embedding) => embedding.embedding,
+    );
+    return { embeddings: embedding2D };
+  }
 }
 
 class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
       opts,
     );
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const embedding = await this.ollama.embed({
+      model: serverConfig.embedding.textModel,
+      input: inputs,
+      // Truncate the input to fit into the model's max token limit,
+      // in the future we want to add a way to split the input into multiple parts.
+      truncate: true,
+    });
+    return { embeddings: embedding.embeddings };
+  }
 }
```
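The new `generateEmbeddingFromText` method takes an array of inputs and returns one vector per input, so callers can embed a batch of bookmark chunks in a single round trip. A minimal caller sketch, assuming the `@karakeep/shared/inference` import path and that `InferenceClientFactory.build()` returns a configured client or null (the factory's methods sit outside this hunk):

```typescript
// Hypothetical usage -- not part of the diff. The import path and the
// build() signature are assumptions; only the class declaration and the
// two generateEmbeddingFromText implementations appear in the hunks above.
import { InferenceClientFactory } from "@karakeep/shared/inference";

async function embedBookmarkChunks(chunks: string[]): Promise<number[][]> {
  const client = InferenceClientFactory.build();
  if (!client) {
    throw new Error("No inference client is configured");
  }
  // One call per batch: the response carries one vector per input string.
  const { embeddings } = await client.generateEmbeddingFromText(chunks);
  return embeddings;
}
```

Accepting `string[]` and returning `number[][]` keeps the interface batch-first, which matches both backends: OpenAI's `embeddings.create` and Ollama's `embed` each accept an array of inputs and return one embedding per element.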
