Diffstat (limited to 'apps/workers')
 apps/workers/network.ts               | 419
 apps/workers/package.json             |   2
 apps/workers/utils.ts                 |  61
 apps/workers/workers/crawlerWorker.ts |  81
 apps/workers/workers/feedWorker.ts    |   2
 apps/workers/workers/videoWorker.ts   |  36
 apps/workers/workers/webhookWorker.ts |   4
 7 files changed, 507 insertions(+), 98 deletions(-)
diff --git a/apps/workers/network.ts b/apps/workers/network.ts
new file mode 100644
index 00000000..acfd2439
--- /dev/null
+++ b/apps/workers/network.ts
@@ -0,0 +1,419 @@
+import dns from "node:dns/promises";
+import type { HeadersInit, RequestInit, Response } from "node-fetch";
+import { HttpProxyAgent } from "http-proxy-agent";
+import { HttpsProxyAgent } from "https-proxy-agent";
+import ipaddr from "ipaddr.js";
+import { LRUCache } from "lru-cache";
+import fetch, { Headers } from "node-fetch";
+
+import serverConfig from "@karakeep/shared/config";
+
+const DISALLOWED_IP_RANGES = new Set([
+  // IPv4 ranges
+  "unspecified",
+  "broadcast",
+  "multicast",
+  "linkLocal",
+  "loopback",
+  "private",
+  "reserved",
+  "carrierGradeNat",
+  // IPv6 ranges
+  "uniqueLocal",
+  "6to4", // RFC 3056 - IPv6 transition mechanism
+  "teredo", // RFC 4380 - IPv6 tunneling
+  "benchmarking", // RFC 5180 - benchmarking addresses
+  "deprecated", // RFC 3879 - deprecated IPv6 addresses
+  "discard", // RFC 6666 - discard-only prefix
+]);
+
+// DNS cache with 5 minute TTL and max 1000 entries
+const dnsCache = new LRUCache<string, string[]>({
+  max: 1000,
+  ttl: 5 * 60 * 1000, // 5 minutes in milliseconds
+});
+
+async function resolveHostAddresses(hostname: string): Promise<string[]> {
+  const resolver = new dns.Resolver({
+    timeout: serverConfig.crawler.ipValidation.dnsResolverTimeoutSec * 1000,
+  });
+
+  const results = await Promise.allSettled([
+    resolver.resolve4(hostname),
+    resolver.resolve6(hostname),
+  ]);
+
+  const addresses: string[] = [];
+  const errors: string[] = [];
+
+  for (const result of results) {
+    if (result.status === "fulfilled") {
+      addresses.push(...result.value);
+    } else {
+      const reason = result.reason;
+      if (reason instanceof Error) {
+        errors.push(reason.message);
+      } else {
+        errors.push(String(reason));
+      }
+    }
+  }
+
+  if (addresses.length > 0) {
+    return addresses;
+  }
+
+  const errorMessage =
+    errors.length > 0
+      ? errors.join("; ")
+      : "DNS lookup did not return any A or AAAA records";
+  throw new Error(errorMessage);
+}
+
+function isAddressForbidden(address: string): boolean {
+  if (!ipaddr.isValid(address)) {
+    return true;
+  }
+  const parsed = ipaddr.parse(address);
+  if (
+    parsed.kind() === "ipv6" &&
+    (parsed as ipaddr.IPv6).isIPv4MappedAddress()
+  ) {
+    const mapped = (parsed as ipaddr.IPv6).toIPv4Address();
+    return DISALLOWED_IP_RANGES.has(mapped.range());
+  }
+  return DISALLOWED_IP_RANGES.has(parsed.range());
+}
+
+export type UrlValidationResult =
+  | { ok: true; url: URL }
+  | { ok: false; reason: string };
+
+function hostnameMatchesAnyPattern(
+  hostname: string,
+  patterns: string[],
+): boolean {
+  function hostnameMatchesPattern(hostname: string, pattern: string): boolean {
+    return (
+      pattern === hostname ||
+      (pattern.startsWith(".") && hostname.endsWith(pattern)) ||
+      hostname.endsWith("." + pattern)
+    );
+  }
+
+  for (const pattern of patterns) {
+    if (hostnameMatchesPattern(hostname, pattern)) {
+      return true;
+    }
+  }
+  return false;
+}
+
+function isHostnameAllowedForInternalAccess(hostname: string): boolean {
+  if (!serverConfig.allowedInternalHostnames) {
+    return false;
+  }
+  return hostnameMatchesAnyPattern(
+    hostname,
+    serverConfig.allowedInternalHostnames,
+  );
+}
+
+export async function validateUrl(
+  urlCandidate: string,
+  runningInProxyContext: boolean,
+): Promise<UrlValidationResult> {
+  let parsedUrl: URL;
+  try {
+    parsedUrl = new URL(urlCandidate);
+  } catch (error) {
+    return {
+      ok: false,
+      reason: `Invalid URL "${urlCandidate}": ${
+        error instanceof Error ? error.message : String(error)
+      }`,
+    } as const;
+  }
+
+  if (parsedUrl.protocol !== "http:" && parsedUrl.protocol !== "https:") {
+    return {
+      ok: false,
+      reason: `Unsupported protocol for URL: ${parsedUrl.toString()}`,
+    } as const;
+  }
+
+  const hostname = parsedUrl.hostname;
+  if (!hostname) {
+    return {
+      ok: false,
+      reason: `URL ${parsedUrl.toString()} must include a hostname`,
+    } as const;
+  }
+
+  if (isHostnameAllowedForInternalAccess(hostname)) {
+    return { ok: true, url: parsedUrl } as const;
+  }
+
+  if (ipaddr.isValid(hostname)) {
+    if (isAddressForbidden(hostname)) {
+      return {
+        ok: false,
+        reason: `Refusing to access disallowed IP address ${hostname} (requested via ${parsedUrl.toString()})`,
+      } as const;
+    }
+    return { ok: true, url: parsedUrl } as const;
+  }
+
+  if (runningInProxyContext) {
+    // If we're running in a proxy context, we must skip DNS resolution
+    // as the DNS resolution will be handled by the proxy
+    return { ok: true, url: parsedUrl } as const;
+  }
+
+  // Check cache first
+  let records = dnsCache.get(hostname);
+
+  if (!records) {
+    // Cache miss or expired - perform DNS resolution
+    try {
+      records = await resolveHostAddresses(hostname);
+      dnsCache.set(hostname, records);
+    } catch (error) {
+      return {
+        ok: false,
+        reason: `Failed to resolve hostname ${hostname}: ${
+          error instanceof Error ? error.message : String(error)
+        }`,
+      } as const;
+    }
+  }
+
+  if (!records || records.length === 0) {
+    return {
+      ok: false,
+      reason: `DNS lookup for ${hostname} did not return any addresses (requested via ${parsedUrl.toString()})`,
+    } as const;
+  }
+
+  for (const record of records) {
+    if (isAddressForbidden(record)) {
+      return {
+        ok: false,
+        reason: `Refusing to access disallowed resolved address ${record} for host ${hostname}`,
+      } as const;
+    }
+  }
+
+  return { ok: true, url: parsedUrl } as const;
+}
+
+export function getRandomProxy(proxyList: string[]): string {
+  return proxyList[Math.floor(Math.random() * proxyList.length)].trim();
+}
+
+export function matchesNoProxy(url: string, noProxy: string[]) {
+  const urlObj = new URL(url);
+  const hostname = urlObj.hostname;
+  return hostnameMatchesAnyPattern(hostname, noProxy);
+}
+
+export function getProxyAgent(url: string) {
+  const { proxy } = serverConfig;
+
+  if (!proxy.httpProxy && !proxy.httpsProxy) {
+    return undefined;
+  }
+
+  const urlObj = new URL(url);
+  const protocol = urlObj.protocol;
+
+  // Check if URL should bypass proxy
+  if (proxy.noProxy && matchesNoProxy(url, proxy.noProxy)) {
+    return undefined;
+  }
+
+  if (protocol === "https:" && proxy.httpsProxy) {
+    const selectedProxy = getRandomProxy(proxy.httpsProxy);
+    return new HttpsProxyAgent(selectedProxy);
+  } else if (protocol === "http:" && proxy.httpProxy) {
+    const selectedProxy = getRandomProxy(proxy.httpProxy);
+    return new HttpProxyAgent(selectedProxy);
+  } else if (proxy.httpProxy) {
+    const selectedProxy = getRandomProxy(proxy.httpProxy);
+    return new HttpProxyAgent(selectedProxy);
+  }
+
+  return undefined;
+}
+
+function cloneHeaders(init?: HeadersInit): Headers {
+  const headers = new Headers();
+  if (!init) {
+    return headers;
+  }
+  if (init instanceof Headers) {
+    init.forEach((value, key) => {
+      headers.set(key, value);
+    });
+    return headers;
+  }
+
+  if (Array.isArray(init)) {
+    for (const [key, value] of init) {
+      headers.append(key, value);
+    }
+    return headers;
+  }
+
+  for (const [key, value] of Object.entries(init)) {
+    if (Array.isArray(value)) {
+      headers.set(key, value.join(", "));
+    } else if (value !== undefined) {
+      headers.set(key, value);
+    }
+  }
+
+  return headers;
+}
+
+function isRedirectResponse(response: Response): boolean {
+  return (
+    response.status === 301 ||
+    response.status === 302 ||
+    response.status === 303 ||
+    response.status === 307 ||
+    response.status === 308
+  );
+}
+
+export type FetchWithProxyOptions = Omit<
+  RequestInit & {
+    maxRedirects?: number;
+  },
+  "agent"
+>;
+
+interface PreparedFetchOptions {
+  maxRedirects: number;
+  baseHeaders: Headers;
+  method: string;
+  body?: RequestInit["body"];
+  baseOptions: RequestInit;
+}
+
+export function prepareFetchOptions(
+  options: FetchWithProxyOptions = {},
+): PreparedFetchOptions {
+  const {
+    maxRedirects = 5,
+    headers: initHeaders,
+    method: initMethod,
+    body: initBody,
+    redirect: _ignoredRedirect,
+    ...restOptions
+  } = options;
+
+  const baseOptions = restOptions as RequestInit;
+
+  return {
+    maxRedirects,
+    baseHeaders: cloneHeaders(initHeaders),
+    method: initMethod?.toUpperCase?.() ?? "GET",
+    body: initBody,
+    baseOptions,
+  };
+}
+
+interface BuildFetchOptionsInput {
+  method: string;
+  body?: RequestInit["body"];
+  headers: Headers;
+  agent?: RequestInit["agent"];
+  baseOptions: RequestInit;
+}
+
+export function buildFetchOptions({
+  method,
+  body,
+  headers,
+  agent,
+  baseOptions,
+}: BuildFetchOptionsInput): RequestInit {
+  return {
+    ...baseOptions,
+    method,
+    body,
+    headers,
+    agent,
+    redirect: "manual",
+  };
+}
+
+export const fetchWithProxy = async (
+  url: string,
+  options: FetchWithProxyOptions = {},
+) => {
+  const {
+    maxRedirects,
+    baseHeaders,
+    method: preparedMethod,
+    body: preparedBody,
+    baseOptions,
+  } = prepareFetchOptions(options);
+
+  let redirectsRemaining = maxRedirects;
+  let currentUrl = url;
+  let currentMethod = preparedMethod;
+  let currentBody = preparedBody;
+
+  while (true) {
+    const agent = getProxyAgent(currentUrl);
+
+    const validation = await validateUrl(currentUrl, !!agent);
+    if (!validation.ok) {
+      throw new Error(validation.reason);
+    }
+    const requestUrl = validation.url;
+    currentUrl = requestUrl.toString();
+
+    const response = await fetch(
+      currentUrl,
+      buildFetchOptions({
+        method: currentMethod,
+        body: currentBody,
+        headers: baseHeaders,
+        agent,
+        baseOptions,
+      }),
+    );
+
+    if (!isRedirectResponse(response)) {
+      return response;
+    }
+
+    const locationHeader = response.headers.get("location");
+    if (!locationHeader) {
+      return response;
+    }
+
+    if (redirectsRemaining <= 0) {
+      throw new Error(`Too many redirects while fetching ${url}`);
+    }
+
+    const nextUrl = new URL(locationHeader, currentUrl);
+
+    if (
+      response.status === 303 ||
+      ((response.status === 301 || response.status === 302) &&
+        currentMethod !== "GET" &&
+        currentMethod !== "HEAD")
+    ) {
+      currentMethod = "GET";
+      currentBody = undefined;
+      baseHeaders.delete("content-length");
+    }
+
+    currentUrl = nextUrl.toString();
+    redirectsRemaining -= 1;
+  }
+};
diff --git a/apps/workers/package.json b/apps/workers/package.json
index b02c3bc9..f35a52f4 100644
--- a/apps/workers/package.json
+++ b/apps/workers/package.json
@@ -23,8 +23,10 @@
     "hono": "^4.7.10",
     "http-proxy-agent": "^7.0.2",
     "https-proxy-agent": "^7.0.6",
+    "ipaddr.js": "^2.2.0",
     "jsdom": "^24.0.0",
     "liteque": "^0.6.2",
+    "lru-cache": "^11.2.2",
     "metascraper": "^5.49.5",
     "metascraper-amazon": "^5.49.5",
     "metascraper-author": "^5.49.5",
diff --git a/apps/workers/utils.ts b/apps/workers/utils.ts
index a82dd12d..2f56d3f0 100644
--- a/apps/workers/utils.ts
+++ b/apps/workers/utils.ts
@@ -1,9 +1,3 @@
-import { HttpProxyAgent } from "http-proxy-agent";
-import { HttpsProxyAgent } from "https-proxy-agent";
-import fetch from "node-fetch";
-
-import serverConfig from "@karakeep/shared/config";
-
 export function withTimeout<T, Ret>(
   func: (param: T) => Promise<Ret>,
   timeoutSec: number,
@@ -20,58 +14,3 @@ export function withTimeout<T, Ret>(
     ]);
   };
 }
-
-export function getRandomProxy(proxyList: string[]): string {
-  return proxyList[Math.floor(Math.random() * proxyList.length)].trim();
-}
-
-function getProxyAgent(url: string) {
-  const { proxy } = serverConfig;
-
-  if (!proxy.httpProxy && !proxy.httpsProxy) {
-    return undefined;
-  }
-
-  const urlObj = new URL(url);
-  const protocol = urlObj.protocol;
-
-  // Check if URL should bypass proxy
-  if (proxy.noProxy) {
-    const noProxyList = proxy.noProxy.split(",").map((host) => host.trim());
-    const hostname = urlObj.hostname;
-
-    for (const noProxyHost of noProxyList) {
-      if (
-        noProxyHost === hostname ||
-        (noProxyHost.startsWith(".") && hostname.endsWith(noProxyHost)) ||
-        hostname.endsWith("." + noProxyHost)
-      ) {
-        return undefined;
-      }
-    }
-  }
-
-  if (protocol === "https:" && proxy.httpsProxy) {
-    const selectedProxy = getRandomProxy(proxy.httpsProxy);
-    return new HttpsProxyAgent(selectedProxy);
-  } else if (protocol === "http:" && proxy.httpProxy) {
-    const selectedProxy = getRandomProxy(proxy.httpProxy);
-    return new HttpProxyAgent(selectedProxy);
-  } else if (proxy.httpProxy) {
-    const selectedProxy = getRandomProxy(proxy.httpProxy);
-    return new HttpProxyAgent(selectedProxy);
-  }
-
-  return undefined;
-}
-
-export const fetchWithProxy = (
-  url: string,
-  options: Record<string, unknown> = {},
-) => {
-  const agent = getProxyAgent(url);
-  if (agent) {
-    options.agent = agent;
-  }
-  return fetch(url, options);
-};
diff --git a/apps/workers/workers/crawlerWorker.ts b/apps/workers/workers/crawlerWorker.ts
index 33ff2851..70b2e644 100644
--- a/apps/workers/workers/crawlerWorker.ts
+++ b/apps/workers/workers/crawlerWorker.ts
@@ -25,10 +25,15 @@
 import metascraperTitle from "metascraper-title";
 import metascraperTwitter from "metascraper-twitter";
 import metascraperUrl from "metascraper-url";
 import { workerStatsCounter } from "metrics";
+import {
+  fetchWithProxy,
+  getRandomProxy,
+  matchesNoProxy,
+  validateUrl,
+} from "network";
 import { Browser, BrowserContextOptions } from "playwright";
 import { chromium } from "playwright-extra";
 import StealthPlugin from "puppeteer-extra-plugin-stealth";
-import { fetchWithProxy, getRandomProxy } from "utils";
 import { getBookmarkDetails, updateAsset } from "workerUtils";
 import { z } from "zod";
@@ -173,7 +178,7 @@ function getPlaywrightProxyConfig(): BrowserContextOptions["proxy"] {
     server: proxyUrl,
     username: parsed.username,
     password: parsed.password,
-    bypass: proxy.noProxy,
+    bypass: proxy.noProxy?.join(","),
   };
 }
 
@@ -355,22 +360,6 @@ async function changeBookmarkStatus(
     .where(eq(bookmarkLinks.id, bookmarkId));
 }
 
-/**
- * This provides some "basic" protection from malicious URLs. However, all of those
- * can be easily circumvented by pointing dns of origin to localhost, or with
- * redirects.
- */
-function validateUrl(url: string) {
-  const urlParsed = new URL(url);
-  if (urlParsed.protocol != "http:" && urlParsed.protocol != "https:") {
-    throw new Error(`Unsupported URL protocol: ${urlParsed.protocol}`);
-  }
-
-  if (["localhost", "127.0.0.1", "0.0.0.0"].includes(urlParsed.hostname)) {
-    throw new Error(`Link hostname rejected: ${urlParsed.hostname}`);
-  }
-}
-
 async function browserlessCrawlPage(
   jobId: string,
   url: string,
@@ -430,11 +419,15 @@ async function crawlPage(
     return browserlessCrawlPage(jobId, url, abortSignal);
   }
 
+  const proxyConfig = getPlaywrightProxyConfig();
+  const isRunningInProxyContext =
+    proxyConfig !== undefined &&
+    !matchesNoProxy(url, proxyConfig.bypass?.split(",") ?? []);
   const context = await browser.newContext({
     viewport: { width: 1440, height: 900 },
     userAgent:
       "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36",
-    proxy: getPlaywrightProxyConfig(),
+    proxy: proxyConfig,
   });
 
   try {
@@ -453,8 +446,12 @@ async function crawlPage(
       await globalBlocker.enableBlockingInPage(page);
     }
 
-    // Block audio/video resources
-    await page.route("**/*", (route) => {
+    // Block audio/video resources and disallowed sub-requests
+    await page.route("**/*", async (route) => {
+      if (abortSignal.aborted) {
+        await route.abort("aborted");
+        return;
+      }
       const request = route.request();
       const resourceType = request.resourceType();
 
@@ -464,18 +461,49 @@ async function crawlPage(
         request.headers()["content-type"]?.includes("video/") ||
         request.headers()["content-type"]?.includes("audio/")
       ) {
-        route.abort();
+        await route.abort("aborted");
        return;
       }
 
+      const requestUrl = request.url();
+      const requestIsRunningInProxyContext =
+        proxyConfig !== undefined &&
+        !matchesNoProxy(requestUrl, proxyConfig.bypass?.split(",") ?? []);
+      if (
+        requestUrl.startsWith("http://") ||
+        requestUrl.startsWith("https://")
+      ) {
+        const validation = await validateUrl(
+          requestUrl,
+          requestIsRunningInProxyContext,
+        );
+        if (!validation.ok) {
+          logger.warn(
+            `[Crawler][${jobId}] Blocking sub-request to disallowed URL "${requestUrl}": ${validation.reason}`,
+          );
+          await route.abort("blockedbyclient");
+          return;
+        }
+      }
+
       // Continue with other requests
-      route.continue();
+      await route.continue();
     });
 
     // Navigate to the target URL
-    logger.info(`[Crawler][${jobId}] Navigating to "${url}"`);
+    const navigationValidation = await validateUrl(
+      url,
+      isRunningInProxyContext,
+    );
+    if (!navigationValidation.ok) {
+      throw new Error(
+        `Disallowed navigation target "${url}": ${navigationValidation.reason}`,
+      );
+    }
+    const targetUrl = navigationValidation.url.toString();
+    logger.info(`[Crawler][${jobId}] Navigating to "${targetUrl}"`);
     const response = await Promise.race([
-      page.goto(url, {
+      page.goto(targetUrl, {
         timeout: serverConfig.crawler.navigateTimeoutSec * 1000,
         waitUntil: "domcontentloaded",
       }),
@@ -483,7 +511,7 @@
     ]);
 
     logger.info(
-      `[Crawler][${jobId}] Successfully navigated to "${url}". Waiting for the page to load ...`,
+      `[Crawler][${jobId}] Successfully navigated to "${targetUrl}". Waiting for the page to load ...`,
     );
 
     // Wait until network is relatively idle or timeout after 5 seconds
@@ -1231,7 +1259,6 @@ async function runCrawler(job: DequeuedJob<ZCrawlLinkRequest>) {
   logger.info(
     `[Crawler][${jobId}] Will crawl "${url}" for link with id "${bookmarkId}"`,
   );
-  validateUrl(url);
 
   const contentType = await getContentType(url, jobId, job.abortSignal);
   job.abortSignal.throwIfAborted();
diff --git a/apps/workers/workers/feedWorker.ts b/apps/workers/workers/feedWorker.ts
index 38b06c47..f86e7424 100644
--- a/apps/workers/workers/feedWorker.ts
+++ b/apps/workers/workers/feedWorker.ts
@@ -1,9 +1,9 @@
 import { and, eq, inArray } from "drizzle-orm";
 import { workerStatsCounter } from "metrics";
+import { fetchWithProxy } from "network";
 import cron from "node-cron";
 import Parser from "rss-parser";
 import { buildImpersonatingTRPCClient } from "trpc";
-import { fetchWithProxy } from "utils";
 import { z } from "zod";
 
 import type { ZFeedRequestSchema } from "@karakeep/shared-server";
diff --git a/apps/workers/workers/videoWorker.ts b/apps/workers/workers/videoWorker.ts
index a41eb069..8d3ac666 100644
--- a/apps/workers/workers/videoWorker.ts
+++ b/apps/workers/workers/videoWorker.ts
@@ -3,6 +3,7 @@ import * as os from "os";
 import path from "path";
import { execa } from "execa";
import { workerStatsCounter } from "metrics";
+import { getProxyAgent, validateUrl } from "network";
import { db } from "@karakeep/db";
import { AssetTypes } from "@karakeep/db/schema";
@@ -62,7 +63,11 @@ export class VideoWorker {
   }
 }
 
-function prepareYtDlpArguments(url: string, assetPath: string) {
+function prepareYtDlpArguments(
+ url: string,
+ proxy: string | undefined,
+ assetPath: string,
+) {
const ytDlpArguments = [url];
if (serverConfig.crawler.maxVideoDownloadSize > 0) {
ytDlpArguments.push(
@@ -74,6 +79,9 @@ function prepareYtDlpArguments(url: string, assetPath: string) {
   ytDlpArguments.push(...serverConfig.crawler.ytDlpArguments);
ytDlpArguments.push("-o", assetPath);
ytDlpArguments.push("--no-playlist");
+ if (proxy) {
+ ytDlpArguments.push("--proxy", proxy);
+ }
return ytDlpArguments;
}
@@ -94,15 +102,29 @@ async function runWorker(job: DequeuedJob<ZVideoRequest>) {
     return;
}
+ const proxy = getProxyAgent(url);
+ const validation = await validateUrl(url, !!proxy);
+ if (!validation.ok) {
+ logger.warn(
+ `[VideoCrawler][${jobId}] Skipping video download to disallowed URL "${url}": ${validation.reason}`,
+ );
+ return;
+ }
+ const normalizedUrl = validation.url.toString();
+
const videoAssetId = newAssetId();
let assetPath = `${TMP_FOLDER}/${videoAssetId}`;
await fs.promises.mkdir(TMP_FOLDER, { recursive: true });
- const ytDlpArguments = prepareYtDlpArguments(url, assetPath);
+ const ytDlpArguments = prepareYtDlpArguments(
+ normalizedUrl,
+ proxy?.proxy.toString(),
+ assetPath,
+ );
try {
logger.info(
- `[VideoCrawler][${jobId}] Attempting to download a file from "${url}" to "${assetPath}" using the following arguments: "${ytDlpArguments}"`,
+ `[VideoCrawler][${jobId}] Attempting to download a file from "${normalizedUrl}" to "${assetPath}" using the following arguments: "${ytDlpArguments}"`,
);
await execa("yt-dlp", ytDlpArguments, {
@@ -123,11 +145,11 @@ async function runWorker(job: DequeuedJob<ZVideoRequest>) {
       err.message.includes("No media found")
) {
logger.info(
- `[VideoCrawler][${jobId}] Skipping video download from "${url}", because it's not one of the supported yt-dlp URLs`,
+ `[VideoCrawler][${jobId}] Skipping video download from "${normalizedUrl}", because it's not one of the supported yt-dlp URLs`,
);
return;
}
- const genericError = `[VideoCrawler][${jobId}] Failed to download a file from "${url}" to "${assetPath}"`;
+ const genericError = `[VideoCrawler][${jobId}] Failed to download a file from "${normalizedUrl}" to "${assetPath}"`;
if ("stderr" in err) {
logger.error(`${genericError}: ${err.stderr}`);
} else {
@@ -138,7 +160,7 @@ async function runWorker(job: DequeuedJob<ZVideoRequest>) {
    }
logger.info(
- `[VideoCrawler][${jobId}] Finished downloading a file from "${url}" to "${assetPath}"`,
+ `[VideoCrawler][${jobId}] Finished downloading a file from "${normalizedUrl}" to "${assetPath}"`,
);
// Get file size and check quota before saving
@@ -177,7 +199,7 @@ async function runWorker(job: DequeuedJob<ZVideoRequest>) {
   await silentDeleteAsset(userId, oldVideoAssetId);
logger.info(
- `[VideoCrawler][${jobId}] Finished downloading video from "${url}" and adding it to the database`,
+ `[VideoCrawler][${jobId}] Finished downloading video from "${normalizedUrl}" and adding it to the database`,
);
} catch (error) {
if (error instanceof StorageQuotaError) {
diff --git a/apps/workers/workers/webhookWorker.ts b/apps/workers/workers/webhookWorker.ts
index 2bbef160..472a27ed 100644
--- a/apps/workers/workers/webhookWorker.ts
+++ b/apps/workers/workers/webhookWorker.ts
@@ -1,6 +1,6 @@
 import { eq } from "drizzle-orm";
 import { workerStatsCounter } from "metrics";
-import fetch from "node-fetch";
+import { fetchWithProxy } from "network";
 
 import { db } from "@karakeep/db";
 import { bookmarks, webhooksTable } from "@karakeep/db/schema";
@@ -102,7 +102,7 @@ async function runWebhook(job: DequeuedJob<ZWebhookRequest>) {
   while (attempt < maxRetries && !success) {
     try {
-      const response = await fetch(url, {
+      const response = await fetchWithProxy(url, {
         method: "POST",
         headers: {
           "Content-Type": "application/json",
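
A minimal usage sketch of the new network.ts API, assuming a caller inside apps/workers (the bare "network" import matches the workers above); the feed URL and the function name safeFetchExample are placeholders, while validateUrl and fetchWithProxy and their signatures come from network.ts in this commit:

import { fetchWithProxy, validateUrl } from "network";

async function safeFetchExample(): Promise<void> {
  // validateUrl() parses the URL, rejects non-HTTP(S) schemes, and, unless a
  // proxy handles resolution (second argument), resolves DNS through a cached
  // resolver and refuses private, loopback, and other disallowed ranges.
  const validation = await validateUrl("https://example.com/feed.xml", false);
  if (!validation.ok) {
    throw new Error(validation.reason);
  }

  // fetchWithProxy() sends the request with redirect: "manual" and follows up
  // to maxRedirects hops itself, re-validating the target on every hop so a
  // redirect cannot smuggle the crawler onto an internal address.
  const response = await fetchWithProxy(validation.url.toString(), {
    maxRedirects: 5,
  });
  console.log(response.status);
}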
