aboutsummaryrefslogtreecommitdiffstats
path: root/packages/trpc/rateLimit.ts
blob: b9aa4aa17b322137e6e5ac095d9c92496408cfd4 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import { TRPCError } from "@trpc/server";

import serverConfig from "@karakeep/shared/config";

import { Context } from ".";

/** Settings for one named rate limiter. */
interface RateLimitConfig {
  /** Identifier mixed into every bucket key so different limiters never share counters. */
  name: string;
  /** Length of one counting window, in milliseconds. */
  windowMs: number;
  /** Number of requests a single bucket may make within one window. */
  maxRequests: number;
}

/** Counter state for a single rate-limit bucket. */
interface RateLimitEntry {
  /** Requests counted so far in the current window. */
  count: number;
  /** `Date.now()` timestamp (ms) after which the window is treated as expired. */
  resetTime: number;
}

/**
 * Module-level, in-memory bucket store (one per process), keyed by the string
 * built in the rate-limit middleware. Counters are therefore not shared between
 * separate server processes.
 */
const rateLimitStore = new Map<string, RateLimitEntry>();

/** How often expired buckets are swept out of {@link rateLimitStore}, in milliseconds. */
const CLEANUP_INTERVAL_MS = 60000;

/**
 * Removes every bucket whose window has expired, so the store does not keep
 * growing with entries for clients that stopped sending requests.
 * Uses the same `now > resetTime` expiry test as the middleware.
 */
function cleanupExpiredEntries() {
  const now = Date.now();
  // Deleting the current key while iterating a Map is well-defined in JS.
  for (const [key, entry] of rateLimitStore.entries()) {
    if (now > entry.resetTime) {
      rateLimitStore.delete(key);
    }
  }
}

const cleanupTimer = setInterval(cleanupExpiredEntries, CLEANUP_INTERVAL_MS);
// unref() so this housekeeping timer alone does not keep the Node.js process
// alive (e.g. scripts or test runners that merely import this module).
cleanupTimer.unref();

/**
 * Creates a middleware that enforces a fixed-window request limit.
 *
 * Requests are bucketed by `config.name`, the caller IP (`ctx.req.ip`) and the
 * procedure path. The first request after a bucket's window has expired (or the
 * first request ever) opens a new window of `config.windowMs` milliseconds.
 * Further requests in that window are counted until the count reaches
 * `config.maxRequests`, after which they are rejected.
 *
 * Limiting is skipped entirely when `serverConfig.rateLimiting.enabled` is
 * false or when the request carries no IP.
 *
 * @param config - Limiter name, window length (ms) and per-window request cap.
 * @returns A middleware that resolves with `opts.next()` for allowed requests.
 * @throws TRPCError with code `TOO_MANY_REQUESTS` when the bucket is exhausted
 *   for the current window.
 */
export function createRateLimitMiddleware<T>(config: RateLimitConfig) {
  return function rateLimitMiddleware(opts: {
    path: string;
    ctx: Context;
    next: () => Promise<T>;
  }) {
    // Feature switch: with rate limiting disabled every request passes through.
    if (!serverConfig.rateLimiting.enabled) {
      return opts.next();
    }

    const clientIp = opts.ctx.req.ip;
    // No IP means there is nothing to key a bucket on, so the request is allowed.
    if (!clientIp) {
      return opts.next();
    }

    // TODO: Better fingerprinting
    const bucketKey = `${config.name}:${clientIp}:${opts.path}`;
    const currentTime = Date.now();
    const existing = rateLimitStore.get(bucketKey);

    // Unknown bucket or expired window: start a fresh window counting this request.
    if (existing === undefined || currentTime > existing.resetTime) {
      const freshEntry: RateLimitEntry = {
        count: 1,
        resetTime: currentTime + config.windowMs,
      };
      rateLimitStore.set(bucketKey, freshEntry);
      return opts.next();
    }

    // Cap reached for this window: reject without counting the rejected call.
    if (existing.count >= config.maxRequests) {
      const secondsUntilReset = Math.ceil(
        (existing.resetTime - currentTime) / 1000,
      );
      throw new TRPCError({
        code: "TOO_MANY_REQUESTS",
        message: `Rate limit exceeded. Try again in ${secondsUntilReset} seconds.`,
      });
    }

    existing.count += 1;
    return opts.next();
  };
}