author     Mohamed Bassem <me@mbassem.com>    2025-12-29 23:35:28 +0000
committer  Mohamed Bassem <me@mbassem.com>    2025-12-29 23:38:21 +0000
commit     f00287ede0675521c783c1199675538571f977d6 (patch)
tree       2d04b983fa514f4c62a3695c0a521fb50de24eef /packages/shared
parent     ba8d84a555f9e6cf209c826b97a124f0539739eb (diff)
download   karakeep-f00287ede0675521c783c1199675538571f977d6.tar.zst
refactor: reduce duplication in compare-models tool
Diffstat (limited to 'packages/shared')
-rw-r--r--  packages/shared/inference.ts  107
1 file changed, 79 insertions, 28 deletions
diff --git a/packages/shared/inference.ts b/packages/shared/inference.ts
index fb9fce09..d6a9aa10 100644
--- a/packages/shared/inference.ts
+++ b/packages/shared/inference.ts
@@ -52,34 +52,47 @@ const mapInferenceOutputSchema = <
return opts[type];
};
+export interface OpenAIInferenceConfig {
+ apiKey: string;
+ baseURL?: string;
+ proxyUrl?: string;
+ textModel: string;
+ imageModel: string;
+ contextLength: number;
+ maxOutputTokens: number;
+ useMaxCompletionTokens: boolean;
+ outputSchema: "structured" | "json" | "plain";
+}
+
export class InferenceClientFactory {
static build(): InferenceClient | null {
if (serverConfig.inference.openAIApiKey) {
- return new OpenAIInferenceClient();
+ return OpenAIInferenceClient.fromConfig();
}
if (serverConfig.inference.ollamaBaseUrl) {
- return new OllamaInferenceClient();
+ return OllamaInferenceClient.fromConfig();
}
return null;
}
}
-class OpenAIInferenceClient implements InferenceClient {
+export class OpenAIInferenceClient implements InferenceClient {
openAI: OpenAI;
+ private config: OpenAIInferenceConfig;
- constructor() {
- const fetchOptions = serverConfig.inference.openAIProxyUrl
+ constructor(config: OpenAIInferenceConfig) {
+ this.config = config;
+
+ const fetchOptions = config.proxyUrl
? {
- dispatcher: new undici.ProxyAgent(
- serverConfig.inference.openAIProxyUrl,
- ),
+ dispatcher: new undici.ProxyAgent(config.proxyUrl),
}
: undefined;
this.openAI = new OpenAI({
- apiKey: serverConfig.inference.openAIApiKey,
- baseURL: serverConfig.inference.openAIBaseUrl,
+ apiKey: config.apiKey,
+ baseURL: config.baseURL,
...(fetchOptions ? { fetchOptions } : {}),
defaultHeaders: {
"X-Title": "Karakeep",
@@ -88,6 +101,20 @@ class OpenAIInferenceClient implements InferenceClient {
});
}
+ static fromConfig(): OpenAIInferenceClient {
+ return new OpenAIInferenceClient({
+ apiKey: serverConfig.inference.openAIApiKey!,
+ baseURL: serverConfig.inference.openAIBaseUrl,
+ proxyUrl: serverConfig.inference.openAIProxyUrl,
+ textModel: serverConfig.inference.textModel,
+ imageModel: serverConfig.inference.imageModel,
+ contextLength: serverConfig.inference.contextLength,
+ maxOutputTokens: serverConfig.inference.maxOutputTokens,
+ useMaxCompletionTokens: serverConfig.inference.useMaxCompletionTokens,
+ outputSchema: serverConfig.inference.outputSchema,
+ });
+ }
+
async inferFromText(
prompt: string,
_opts: Partial<InferenceOptions>,
@@ -99,10 +126,10 @@ class OpenAIInferenceClient implements InferenceClient {
const chatCompletion = await this.openAI.chat.completions.create(
{
messages: [{ role: "user", content: prompt }],
- model: serverConfig.inference.textModel,
- ...(serverConfig.inference.useMaxCompletionTokens
- ? { max_completion_tokens: serverConfig.inference.maxOutputTokens }
- : { max_tokens: serverConfig.inference.maxOutputTokens }),
+ model: this.config.textModel,
+ ...(this.config.useMaxCompletionTokens
+ ? { max_completion_tokens: this.config.maxOutputTokens }
+ : { max_tokens: this.config.maxOutputTokens }),
response_format: mapInferenceOutputSchema(
{
structured: optsWithDefaults.schema
@@ -111,7 +138,7 @@ class OpenAIInferenceClient implements InferenceClient {
json: { type: "json_object" },
plain: undefined,
},
- serverConfig.inference.outputSchema,
+ this.config.outputSchema,
),
},
{
@@ -138,10 +165,10 @@ class OpenAIInferenceClient implements InferenceClient {
};
const chatCompletion = await this.openAI.chat.completions.create(
{
- model: serverConfig.inference.imageModel,
- ...(serverConfig.inference.useMaxCompletionTokens
- ? { max_completion_tokens: serverConfig.inference.maxOutputTokens }
- : { max_tokens: serverConfig.inference.maxOutputTokens }),
+ model: this.config.imageModel,
+ ...(this.config.useMaxCompletionTokens
+ ? { max_completion_tokens: this.config.maxOutputTokens }
+ : { max_tokens: this.config.maxOutputTokens }),
response_format: mapInferenceOutputSchema(
{
structured: optsWithDefaults.schema
@@ -150,7 +177,7 @@ class OpenAIInferenceClient implements InferenceClient {
json: { type: "json_object" },
plain: undefined,
},
- serverConfig.inference.outputSchema,
+ this.config.outputSchema,
),
messages: [
{
@@ -195,16 +222,40 @@ class OpenAIInferenceClient implements InferenceClient {
}
}
+export interface OllamaInferenceConfig {
+ baseUrl: string;
+ textModel: string;
+ imageModel: string;
+ contextLength: number;
+ maxOutputTokens: number;
+ keepAlive?: string;
+ outputSchema: "structured" | "json" | "plain";
+}
+
class OllamaInferenceClient implements InferenceClient {
ollama: Ollama;
+ private config: OllamaInferenceConfig;
- constructor() {
+ constructor(config: OllamaInferenceConfig) {
+ this.config = config;
this.ollama = new Ollama({
- host: serverConfig.inference.ollamaBaseUrl,
+ host: config.baseUrl,
fetch: customFetch, // Use the custom fetch with configurable timeout
});
}
+ static fromConfig(): OllamaInferenceClient {
+ return new OllamaInferenceClient({
+ baseUrl: serverConfig.inference.ollamaBaseUrl!,
+ textModel: serverConfig.inference.textModel,
+ imageModel: serverConfig.inference.imageModel,
+ contextLength: serverConfig.inference.contextLength,
+ maxOutputTokens: serverConfig.inference.maxOutputTokens,
+ keepAlive: serverConfig.inference.ollamaKeepAlive,
+ outputSchema: serverConfig.inference.outputSchema,
+ });
+ }
+
async runModel(
model: string,
prompt: string,
@@ -233,13 +284,13 @@ class OllamaInferenceClient implements InferenceClient {
json: "json",
plain: undefined,
},
- serverConfig.inference.outputSchema,
+ this.config.outputSchema,
),
stream: true,
- keep_alive: serverConfig.inference.ollamaKeepAlive,
+ keep_alive: this.config.keepAlive,
options: {
- num_ctx: serverConfig.inference.contextLength,
- num_predict: serverConfig.inference.maxOutputTokens,
+ num_ctx: this.config.contextLength,
+ num_predict: this.config.maxOutputTokens,
},
messages: [
{ role: "user", content: prompt, images: image ? [image] : undefined },
@@ -287,7 +338,7 @@ class OllamaInferenceClient implements InferenceClient {
..._opts,
};
return await this.runModel(
- serverConfig.inference.textModel,
+ this.config.textModel,
prompt,
optsWithDefaults,
undefined,
@@ -305,7 +356,7 @@ class OllamaInferenceClient implements InferenceClient {
..._opts,
};
return await this.runModel(
- serverConfig.inference.imageModel,
+ this.config.imageModel,
prompt,
optsWithDefaults,
image,
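
For reference, a minimal sketch of how the extracted OpenAIInferenceConfig might be used by a compare-models style caller: instead of reading serverConfig.inference via the fromConfig() factory, each candidate model gets its own explicitly constructed client. The import specifier, model names, token limits, and the empty options object below are assumptions for illustration, not part of this change.

// Hypothetical usage sketch (not part of this diff).
import {
  OpenAIInferenceClient,
  type OpenAIInferenceConfig,
} from "@karakeep/shared/inference"; // assumed import specifier

const base: Omit<OpenAIInferenceConfig, "textModel" | "imageModel"> = {
  apiKey: process.env.OPENAI_API_KEY!,
  contextLength: 128_000,
  maxOutputTokens: 2_048,
  useMaxCompletionTokens: false,
  outputSchema: "json",
};

// One client per candidate model; everything else is shared.
const clients = ["gpt-4o-mini", "gpt-4o"].map(
  (model) =>
    new OpenAIInferenceClient({ ...base, textModel: model, imageModel: model }),
);

for (const client of clients) {
  // inferFromText(prompt, opts) is the interface method shown in the diff;
  // passing an empty options object relies on its defaults.
  const res = await client.inferFromText("Summarize: example prompt", {});
  console.log(res);
}

// The server path is unchanged: environment-driven construction still goes
// through the static factory.
// const serverClient = OpenAIInferenceClient.fromConfig();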
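
The Ollama side mirrors the same pattern. A sketch of the OllamaInferenceConfig shape under the same assumptions (endpoint, model names, and limits are illustrative; note that OllamaInferenceClient itself is not exported in this diff, so external callers would still reach it via InferenceClientFactory.build()):

// Hypothetical config values, shown only to illustrate the exported interface.
import type { OllamaInferenceConfig } from "@karakeep/shared/inference"; // assumed import specifier

const ollamaConfig: OllamaInferenceConfig = {
  baseUrl: "http://localhost:11434", // assumed local Ollama endpoint
  textModel: "llama3.1",
  imageModel: "llava",
  contextLength: 8_192,
  maxOutputTokens: 1_024,
  keepAlive: "5m",
  outputSchema: "json",
};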