author     Mohammed Farghal <mohamed@farghal.com>    2024-12-29 19:20:59 +0100
committer  Mohamed Bassem <me@mbassem.com>           2024-12-29 18:27:17 +0000
commit     c89b0c54418d6b739c765162ca180c8d154a6af8 (patch)
tree       df2d52535b9f146bda73aa45e09c6e245909136c
parent     225d855e6c239249b7e6ea3131d704642699142f (diff)
download   karakeep-c89b0c54418d6b739c765162ca180c8d154a6af8.tar.zst
feat: Add support for embeddings in the inference interface (#403)
* support embeddings generation in inference.ts (cherry picked from commit 9ae8773ad13ed87af8f72f167bdd56e02ea66f15)

* make AI worker generate embeddings for text bookmark

* make AI worker generate embeddings for text bookmark

* fix unintentional change -- inference image model

* support embeddings for PDF bookmarks

* Upgrade drizzle-kit

  The existing version does not work with the upgraded version of drizzle-orm. I removed the "driver" key to match the new schema of the Config. Quoting from their Config:

  * `driver` - optional param that is responsible for explicitly providing a driver to use when accessing a database
  * *Possible values*: `aws-data-api`, `d1-http`, `expo`, `turso`, `pglite`
  * If you don't use AWS Data API, D1, Turso or Expo - you don't need this driver. You can check a driver strategy choice here: https://orm.

* fix formatting and lint

* add comments about truncate content

* Revert "Upgrade drizzle-kit"

  This reverts commit 08a02c8df4ea403de65986ed1265940c6c994a20.

* revert keep alive field in Ollama

* change the interface to accept multiple inputs

* docs

---------

Co-authored-by: Mohamed Bassem <me@mbassem.com>
-rw-r--r--  docs/docs/03-configuration.md   23
-rw-r--r--  packages/shared/config.ts        4
-rw-r--r--  packages/shared/inference.ts    32
3 files changed, 48 insertions, 11 deletions
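As a reading aid for the diff below, here is a minimal, hypothetical sketch of how a worker could call the new embedding method for a text bookmark. It assumes the shared package exposes `InferenceClientFactory` with a static `build()` that returns a configured client or `null` (as in the existing inference code); the helper name, import path, and character budget are illustrative only and are not part of this commit.

```ts
// Import path is an assumption; adjust to wherever the shared package is exposed.
import {
  InferenceClientFactory,
  type EmbeddingResponse,
} from "@karakeep/shared/inference";

// Hypothetical worker step: embed the text content of a single bookmark.
async function embedBookmarkText(content: string): Promise<number[] | null> {
  const client = InferenceClientFactory.build();
  if (!client) {
    // Neither OpenAI nor Ollama is configured, so skip embedding.
    return null;
  }
  // Rough character budget for illustration; the worker would normally respect
  // the configured context length instead of a hard-coded limit.
  const truncated = content.slice(0, 8000);
  const resp: EmbeddingResponse = await client.generateEmbeddingFromText([truncated]);
  // One input in, one vector out.
  return resp.embeddings[0];
}
```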
diff --git a/docs/docs/03-configuration.md b/docs/docs/03-configuration.md
index 47c3227f..82438dbf 100644
--- a/docs/docs/03-configuration.md
+++ b/docs/docs/03-configuration.md
@@ -48,17 +48,18 @@ Either `OPENAI_API_KEY` or `OLLAMA_BASE_URL` need to be set for automatic taggin
- You might want to tune the `INFERENCE_CONTEXT_LENGTH` as the default is quite small. The larger the value, the better the quality of the tags, but the more expensive the inference will be (money-wise on OpenAI and resource-wise on ollama).
:::
-| Name | Required | Default | Description |
-| ------------------------- | -------- | ----------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
-| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
-| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
-| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
-| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). |
-| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
-| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
-| INFERENCE_CONTEXT_LENGTH | No | 2048 | The max number of tokens that we'll pass to the inference model. If your content is larger than this size, it'll be truncated to fit. The larger this value, the more of the content will be used in tag inference, but the more expensive the inference will be (money-wise on openAI and resource-wise on ollama). Check the model you're using for its max supported content size. |
-| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
-| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |
+| Name | Required | Default | Description |
+| ------------------------- | -------- | ---------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| OPENAI_API_KEY | No | Not set | The OpenAI key used for automatic tagging. More on that in [here](/openai). |
+| OPENAI_BASE_URL | No | Not set | If you just want to use OpenAI you don't need to pass this variable. If, however, you want to use some other openai compatible API (e.g. azure openai service), set this to the url of the API. |
+| OLLAMA_BASE_URL | No | Not set | If you want to use ollama for local inference, set the address of ollama API here. |
+| OLLAMA_KEEP_ALIVE | No | Not set | Controls how long the model will stay loaded into memory following the request (example value: "5m"). |
+| INFERENCE_TEXT_MODEL | No | gpt-4o-mini | The model to use for text inference. You'll need to change this to some other model if you're using ollama. |
+| INFERENCE_IMAGE_MODEL | No | gpt-4o-mini | The model to use for image inference. You'll need to change this to some other model if you're using ollama and that model needs to support vision APIs (e.g. llava). |
+| EMBEDDING_TEXT_MODEL | No | text-embedding-3-small | The model to be used for generating embeddings for the text. |
+| INFERENCE_CONTEXT_LENGTH | No | 2048 | The max number of tokens that we'll pass to the inference model. If your content is larger than this size, it'll be truncated to fit. The larger this value, the more of the content will be used in tag inference, but the more expensive the inference will be (money-wise on openAI and resource-wise on ollama). Check the model you're using for its max supported content size. |
+| INFERENCE_LANG | No | english | The language in which the tags will be generated. |
+| INFERENCE_JOB_TIMEOUT_SEC | No | 30 | How long to wait for the inference job to finish before timing out. If you're running ollama without powerful GPUs, you might want to increase the timeout a bit. |
:::info
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index aec88096..7b74fc21 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -24,6 +24,7 @@ const allEnv = z.object({
INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
+ EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
OCR_CACHE_DIR: z.string().optional(),
OCR_LANGS: z
@@ -90,6 +91,9 @@ const serverConfigSchema = allEnv.transform((val) => {
inferredTagLang: val.INFERENCE_LANG,
contextLength: val.INFERENCE_CONTEXT_LENGTH,
},
+ embedding: {
+ textModel: val.EMBEDDING_TEXT_MODEL,
+ },
crawler: {
numWorkers: val.CRAWLER_NUM_WORKERS,
headlessBrowser: val.CRAWLER_HEADLESS_BROWSER,
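The new `EMBEDDING_TEXT_MODEL` variable follows the same pattern as the other inference settings: a zod default in `allEnv`, surfaced under `serverConfig.embedding` by the transform above. A self-contained sketch of that pattern (a minimal reproduction, not the real config module):

```ts
import { z } from "zod";

// An env schema with a default, transformed into the nested shape that the
// rest of the code reads as serverConfig.embedding.textModel.
const envSchema = z
  .object({
    EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
  })
  .transform((val) => ({
    embedding: {
      textModel: val.EMBEDDING_TEXT_MODEL,
    },
  }));

const serverConfig = envSchema.parse(process.env);
// With EMBEDDING_TEXT_MODEL unset, this prints "text-embedding-3-small".
console.log(serverConfig.embedding.textModel);
```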
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 7cb88819..1573382f 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,6 +9,10 @@ export interface InferenceResponse {
totalTokens: number | undefined;
}
+export interface EmbeddingResponse {
+ embeddings: number[][];
+}
+
export interface InferenceOptions {
json: boolean;
}
@@ -28,6 +32,7 @@ export interface InferenceClient {
image: string,
opts: InferenceOptions,
): Promise<InferenceResponse>;
+ generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
}
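The new method takes a batch of strings and returns one vector per input, aligned by index (`embeddings[i]` belongs to `inputs[i]`). A hedged usage sketch; the relative import path and the cosine-similarity helper are illustrative, and `client` can be any `InferenceClient` implementation:

```ts
// Relative import path is an assumption for illustration.
import type { InferenceClient } from "./inference";

// Illustrative helper: cosine similarity between two embedding vectors.
function cosineSimilarity(a: number[], b: number[]): number {
  let dot = 0;
  let normA = 0;
  let normB = 0;
  for (let i = 0; i < a.length; i++) {
    dot += a[i] * b[i];
    normA += a[i] * a[i];
    normB += b[i] * b[i];
  }
  return dot / (Math.sqrt(normA) * Math.sqrt(normB));
}

// Embed two texts in one call and compare them; embeddings[i] maps to inputs[i].
async function compareTexts(client: InferenceClient): Promise<number> {
  const { embeddings } = await client.generateEmbeddingFromText([
    "An article about self-hosted bookmark managers",
    "Notes on running Ollama locally",
  ]);
  return cosineSimilarity(embeddings[0], embeddings[1]);
}
```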
export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
}
return { response, totalTokens: chatCompletion.usage?.total_tokens };
}
+
+ async generateEmbeddingFromText(
+ inputs: string[],
+ ): Promise<EmbeddingResponse> {
+ const model = serverConfig.embedding.textModel;
+ const embedResponse = await this.openAI.embeddings.create({
+ model: model,
+ input: inputs,
+ });
+ const embedding2D: number[][] = embedResponse.data.map(
+ (embedding: OpenAI.Embedding) => embedding.embedding,
+ );
+ return { embeddings: embedding2D };
+ }
}
class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
opts,
);
}
+
+ async generateEmbeddingFromText(
+ inputs: string[],
+ ): Promise<EmbeddingResponse> {
+ const embedding = await this.ollama.embed({
+ model: serverConfig.embedding.textModel,
+ input: inputs,
+ // Truncate the input to fit into the model's max token limit,
+ // in the future we want to add a way to split the input into multiple parts.
+ truncate: true,
+ });
+ return { embeddings: embedding.embeddings };
+ }
}