Diffstat (limited to 'packages/shared')
-rw-r--r--  packages/shared/config.ts     4
-rw-r--r--  packages/shared/inference.ts  32
2 files changed, 36 insertions, 0 deletions
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index aec88096..7b74fc21 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -24,6 +24,7 @@ const allEnv = z.object({
   INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
   INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
   INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
+  EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
   INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
   OCR_CACHE_DIR: z.string().optional(),
   OCR_LANGS: z
@@ -90,6 +91,9 @@ const serverConfigSchema = allEnv.transform((val) => {
       inferredTagLang: val.INFERENCE_LANG,
       contextLength: val.INFERENCE_CONTEXT_LENGTH,
     },
+    embedding: {
+      textModel: val.EMBEDDING_TEXT_MODEL,
+    },
     crawler: {
       numWorkers: val.CRAWLER_NUM_WORKERS,
       headlessBrowser: val.CRAWLER_HEADLESS_BROWSER,
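
For illustration only (not part of the diff): a minimal, self-contained sketch of how the new EMBEDDING_TEXT_MODEL variable flows through the zod schema above into serverConfig.embedding.textModel. All unrelated keys are elided; the schema/transform shape mirrors the two hunks shown.

import { z } from "zod";

// Trimmed-down allEnv: only the key added in this change.
const allEnv = z.object({
  EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
});

// Mirrors the serverConfigSchema transform: expose the env var under an
// `embedding` section of the parsed config object.
const serverConfigSchema = allEnv.transform((val) => ({
  embedding: {
    textModel: val.EMBEDDING_TEXT_MODEL,
  },
}));

// With no EMBEDDING_TEXT_MODEL in the environment, the default applies:
// prints "text-embedding-3-small".
console.log(serverConfigSchema.parse(process.env).embedding.textModel);
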
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 7cb88819..1573382f 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,6 +9,10 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface EmbeddingResponse {
+  embeddings: number[][];
+}
+
 export interface InferenceOptions {
   json: boolean;
 }
@@ -28,6 +32,7 @@ export interface InferenceClient {
     image: string,
     opts: InferenceOptions,
   ): Promise<InferenceResponse>;
+  generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
 
 export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
     }
     return { response, totalTokens: chatCompletion.usage?.total_tokens };
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const model = serverConfig.embedding.textModel;
+    const embedResponse = await this.openAI.embeddings.create({
+      model: model,
+      input: inputs,
+    });
+    const embedding2D: number[][] = embedResponse.data.map(
+      (embedding: OpenAI.Embedding) => embedding.embedding,
+    );
+    return { embeddings: embedding2D };
+  }
 }
 
 class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
       opts,
     );
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const embedding = await this.ollama.embed({
+      model: serverConfig.embedding.textModel,
+      input: inputs,
+      // Truncate each input to fit the model's max token limit. In the
+      // future we may want to split the input into multiple parts instead.
+      truncate: true,
+    });
+    return { embeddings: embedding.embeddings };
+  }
 }
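
For illustration only (not part of the diff): a hypothetical caller exercising the new interface method. The helper name embedAll is made up, and the import path assumes this file's location; how a concrete client is obtained (presumably via InferenceClientFactory) is not shown in this diff.

import { InferenceClient } from "./inference";

// Hypothetical helper: embed a batch of texts with whichever client
// (OpenAI- or Ollama-backed) the application constructed.
async function embedAll(client: InferenceClient, texts: string[]) {
  const { embeddings } = await client.generateEmbeddingFromText(texts);
  // One vector per input, in input order; dimensionality depends on the
  // configured model (1536 for OpenAI's text-embedding-3-small).
  console.log(`${embeddings.length} vectors of length ${embeddings[0].length}`);
  return embeddings;
}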