Diffstat (limited to 'packages/shared')
 packages/shared/config.ts    |  4 ++++
 packages/shared/inference.ts | 32 ++++++++++++++++++++++++++++++++
 2 files changed, 36 insertions(+), 0 deletions(-)
diff --git a/packages/shared/config.ts b/packages/shared/config.ts
index aec88096..7b74fc21 100644
--- a/packages/shared/config.ts
+++ b/packages/shared/config.ts
@@ -24,6 +24,7 @@ const allEnv = z.object({
   INFERENCE_JOB_TIMEOUT_SEC: z.coerce.number().default(30),
   INFERENCE_TEXT_MODEL: z.string().default("gpt-4o-mini"),
   INFERENCE_IMAGE_MODEL: z.string().default("gpt-4o-mini"),
+  EMBEDDING_TEXT_MODEL: z.string().default("text-embedding-3-small"),
   INFERENCE_CONTEXT_LENGTH: z.coerce.number().default(2048),
   OCR_CACHE_DIR: z.string().optional(),
   OCR_LANGS: z
@@ -90,6 +91,9 @@ const serverConfigSchema = allEnv.transform((val) => {
     inferredTagLang: val.INFERENCE_LANG,
     contextLength: val.INFERENCE_CONTEXT_LENGTH,
   },
+  embedding: {
+    textModel: val.EMBEDDING_TEXT_MODEL,
+  },
   crawler: {
     numWorkers: val.CRAWLER_NUM_WORKERS,
     headlessBrowser: val.CRAWLER_HEADLESS_BROWSER,
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index 7cb88819..1573382f 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -9,6 +9,10 @@ export interface InferenceResponse {
   totalTokens: number | undefined;
 }
 
+export interface EmbeddingResponse {
+  embeddings: number[][];
+}
+
 export interface InferenceOptions {
   json: boolean;
 }
@@ -28,6 +32,7 @@ export interface InferenceClient {
     image: string,
     opts: InferenceOptions,
   ): Promise<InferenceResponse>;
+  generateEmbeddingFromText(inputs: string[]): Promise<EmbeddingResponse>;
 }
 
 export class InferenceClientFactory {
@@ -103,6 +108,20 @@ class OpenAIInferenceClient implements InferenceClient {
     }
     return { response, totalTokens: chatCompletion.usage?.total_tokens };
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const model = serverConfig.embedding.textModel;
+    const embedResponse = await this.openAI.embeddings.create({
+      model: model,
+      input: inputs,
+    });
+    const embedding2D: number[][] = embedResponse.data.map(
+      (embedding: OpenAI.Embedding) => embedding.embedding,
+    );
+    return { embeddings: embedding2D };
+  }
 }
 
 class OllamaInferenceClient implements InferenceClient {
@@ -183,4 +202,17 @@ class OllamaInferenceClient implements InferenceClient {
       opts,
     );
   }
+
+  async generateEmbeddingFromText(
+    inputs: string[],
+  ): Promise<EmbeddingResponse> {
+    const embedding = await this.ollama.embed({
+      model: serverConfig.embedding.textModel,
+      input: inputs,
+      // Truncate the input to fit into the model's max token limit,
+      // in the future we want to add a way to split the input into multiple parts.
+      truncate: true,
+    });
+    return { embeddings: embedding.embeddings };
+  }
 }
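For orientation, a minimal sketch of how a caller might consume the new method. Only the exported InferenceClient and EmbeddingResponse contract comes from this commit; the import path, the embedTexts helper, and its signature are hypothetical.

import type { EmbeddingResponse, InferenceClient } from "./inference";

// Hypothetical helper: embed a batch of texts with whichever client
// (OpenAI or Ollama) was configured. Both implementations in this
// commit return one embedding vector per input string, in input order.
async function embedTexts(
  client: InferenceClient,
  texts: string[],
): Promise<number[][]> {
  const res: EmbeddingResponse = await client.generateEmbeddingFromText(texts);
  return res.embeddings;
}

With the config.ts change, both backends read the model name from the EMBEDDING_TEXT_MODEL environment variable, falling back to "text-embedding-3-small". That default is an OpenAI model, so it would presumably need to be overridden with a local embedding model when running against Ollama.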
