From 136f126296af65f50da598d084d1485c0e40437a Mon Sep 17 00:00:00 2001 From: Mohamed Bassem Date: Sun, 27 Apr 2025 00:02:20 +0100 Subject: feat: Implement generic rule engine (#1318) * Add schema for the new rule engine * Add rule engine backend logic * Implement the worker logic and event firing * Implement the UI changesfor the rule engine * Ensure that when a referenced list or tag are deleted, the corresponding event/action is * Dont show smart lists in rule engine events * Add privacy validations for attached tag and list ids * Move the rules logic into a models --- packages/trpc/lib/__tests__/ruleEngine.test.ts | 664 +++++++++++++++++++++++++ packages/trpc/lib/ruleEngine.ts | 231 +++++++++ packages/trpc/models/lists.ts | 19 +- packages/trpc/models/rules.ts | 233 +++++++++ packages/trpc/package.json | 2 + packages/trpc/routers/_app.ts | 2 + packages/trpc/routers/bookmarks.ts | 27 + packages/trpc/routers/lists.ts | 3 +- packages/trpc/routers/rules.test.ts | 379 ++++++++++++++ packages/trpc/routers/rules.ts | 120 +++++ packages/trpc/routers/tags.ts | 2 +- 11 files changed, 1678 insertions(+), 4 deletions(-) create mode 100644 packages/trpc/lib/__tests__/ruleEngine.test.ts create mode 100644 packages/trpc/lib/ruleEngine.ts create mode 100644 packages/trpc/models/rules.ts create mode 100644 packages/trpc/routers/rules.test.ts create mode 100644 packages/trpc/routers/rules.ts (limited to 'packages/trpc') diff --git a/packages/trpc/lib/__tests__/ruleEngine.test.ts b/packages/trpc/lib/__tests__/ruleEngine.test.ts new file mode 100644 index 00000000..cbb4b978 --- /dev/null +++ b/packages/trpc/lib/__tests__/ruleEngine.test.ts @@ -0,0 +1,664 @@ +import { and, eq } from "drizzle-orm"; +import { afterEach, beforeEach, describe, expect, it, vi } from "vitest"; + +import { getInMemoryDB } from "@karakeep/db/drizzle"; +import { + bookmarkLinks, + bookmarkLists, + bookmarks, + bookmarksInLists, + bookmarkTags, + rssFeedImportsTable, + rssFeedsTable, + ruleEngineActionsTable as ruleActions, + ruleEngineRulesTable as rules, + tagsOnBookmarks, + users, +} from "@karakeep/db/schema"; +import { LinkCrawlerQueue } from "@karakeep/shared/queues"; +import { BookmarkTypes } from "@karakeep/shared/types/bookmarks"; +import { + RuleEngineAction, + RuleEngineCondition, + RuleEngineEvent, + RuleEngineRule, +} from "@karakeep/shared/types/rules"; + +import { AuthedContext } from "../.."; +import { TestDB } from "../../testUtils"; +import { RuleEngine } from "../ruleEngine"; + +// Mock the queue +vi.mock("@karakeep/shared/queues", () => ({ + LinkCrawlerQueue: { + enqueue: vi.fn(), + }, + triggerRuleEngineOnEvent: vi.fn(), +})); + +describe("RuleEngine", () => { + let db: TestDB; + let ctx: AuthedContext; + let userId: string; + let bookmarkId: string; + let linkBookmarkId: string; + let _textBookmarkId: string; + let tagId1: string; + let tagId2: string; + let feedId1: string; + let listId1: string; + + // Helper to seed a rule + const seedRule = async ( + ruleData: Omit & { userId: string }, + ): Promise => { + const [insertedRule] = await db + .insert(rules) + .values({ + userId: ruleData.userId, + name: ruleData.name, + description: ruleData.description, + enabled: ruleData.enabled, + event: JSON.stringify(ruleData.event), + condition: JSON.stringify(ruleData.condition), + }) + .returning({ id: rules.id }); + + await db.insert(ruleActions).values( + ruleData.actions.map((action) => ({ + ruleId: insertedRule.id, + action: JSON.stringify(action), + userId: ruleData.userId, + })), + ); + return insertedRule.id; + }; + + beforeEach(async () => { + vi.resetAllMocks(); + db = getInMemoryDB(/* runMigrations */ true); + + // Seed User + [userId] = ( + await db + .insert(users) + .values({ name: "Test User", email: "test@test.com" }) + .returning({ id: users.id }) + ).map((u) => u.id); + + ctx = { + user: { id: userId, role: "user" }, + db: db, // Cast needed because TestDB might have extra test methods + req: { ip: null }, + }; + + // Seed Tags + [tagId1, tagId2] = ( + await db + .insert(bookmarkTags) + .values([ + { name: "Tag1", userId }, + { name: "Tag2", userId }, + ]) + .returning({ id: bookmarkTags.id }) + ).map((t) => t.id); + + // Seed Feed + [feedId1] = ( + await db + .insert(rssFeedsTable) + .values({ name: "Feed1", userId, url: "https://example.com/feed1" }) + .returning({ id: rssFeedsTable.id }) + ).map((f) => f.id); + + // Seed List + [listId1] = ( + await db + .insert(bookmarkLists) + .values({ name: "List1", userId, type: "manual", icon: "📚" }) + .returning({ id: bookmarkLists.id }) + ).map((l) => l.id); + + // Seed Bookmarks + [linkBookmarkId] = ( + await db + .insert(bookmarks) + .values({ + userId, + type: BookmarkTypes.LINK, + favourited: false, + archived: false, + }) + .returning({ id: bookmarks.id }) + ).map((b) => b.id); + await db.insert(bookmarkLinks).values({ + id: linkBookmarkId, + url: "https://example.com/test", + }); + await db.insert(tagsOnBookmarks).values({ + bookmarkId: linkBookmarkId, + tagId: tagId1, + attachedBy: "human", + }); + await db.insert(rssFeedImportsTable).values({ + bookmarkId: linkBookmarkId, + rssFeedId: feedId1, + entryId: "entry-id", + }); + + [_textBookmarkId] = ( + await db + .insert(bookmarks) + .values({ + userId, + type: BookmarkTypes.TEXT, + favourited: true, + archived: false, + }) + .returning({ id: bookmarks.id }) + ).map((b) => b.id); + + bookmarkId = linkBookmarkId; // Default bookmark for most tests + }); + + afterEach(() => { + vi.restoreAllMocks(); + }); + + describe("RuleEngine.forBookmark static method", () => { + it("should initialize RuleEngine successfully for an existing bookmark", async () => { + const engine = await RuleEngine.forBookmark(ctx, bookmarkId); + expect(engine).toBeInstanceOf(RuleEngine); + }); + + it("should throw an error if bookmark is not found", async () => { + await expect( + RuleEngine.forBookmark(ctx, "nonexistent-bookmark"), + ).rejects.toThrow("Bookmark nonexistent-bookmark not found"); + }); + + it("should load rules associated with the bookmark's user", async () => { + const ruleId = await seedRule({ + userId, + name: "Test Rule", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "urlContains", str: "example" }, + actions: [{ type: "addTag", tagId: tagId2 }], + }); + + const engine = await RuleEngine.forBookmark(ctx, bookmarkId); + // @ts-expect-error Accessing private property for test verification + expect(engine.rules).toHaveLength(1); + // @ts-expect-error Accessing private property for test verification + expect(engine.rules[0].id).toBe(ruleId); + }); + }); + + describe("doesBookmarkMatchConditions", () => { + let engine: RuleEngine; + + beforeEach(async () => { + engine = await RuleEngine.forBookmark(ctx, bookmarkId); + }); + + it("should return true for urlContains condition", () => { + const condition: RuleEngineCondition = { + type: "urlContains", + str: "example.com", + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should return false for urlContains condition mismatch", () => { + const condition: RuleEngineCondition = { + type: "urlContains", + str: "nonexistent", + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should return true for importedFromFeed condition", () => { + const condition: RuleEngineCondition = { + type: "importedFromFeed", + feedId: feedId1, + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should return false for importedFromFeed condition mismatch", () => { + const condition: RuleEngineCondition = { + type: "importedFromFeed", + feedId: "other-feed", + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should return true for bookmarkTypeIs condition (link)", () => { + const condition: RuleEngineCondition = { + type: "bookmarkTypeIs", + bookmarkType: BookmarkTypes.LINK, + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should return false for bookmarkTypeIs condition mismatch", () => { + const condition: RuleEngineCondition = { + type: "bookmarkTypeIs", + bookmarkType: BookmarkTypes.TEXT, + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should return true for hasTag condition", () => { + const condition: RuleEngineCondition = { type: "hasTag", tagId: tagId1 }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should return false for hasTag condition mismatch", () => { + const condition: RuleEngineCondition = { type: "hasTag", tagId: tagId2 }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should return false for isFavourited condition (default)", () => { + const condition: RuleEngineCondition = { type: "isFavourited" }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should return true for isFavourited condition when favourited", async () => { + await db + .update(bookmarks) + .set({ favourited: true }) + .where(eq(bookmarks.id, bookmarkId)); + const updatedEngine = await RuleEngine.forBookmark(ctx, bookmarkId); + const condition: RuleEngineCondition = { type: "isFavourited" }; + expect(updatedEngine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should return false for isArchived condition (default)", () => { + const condition: RuleEngineCondition = { type: "isArchived" }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should return true for isArchived condition when archived", async () => { + await db + .update(bookmarks) + .set({ archived: true }) + .where(eq(bookmarks.id, bookmarkId)); + const updatedEngine = await RuleEngine.forBookmark(ctx, bookmarkId); + const condition: RuleEngineCondition = { type: "isArchived" }; + expect(updatedEngine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should handle and condition (true)", () => { + const condition: RuleEngineCondition = { + type: "and", + conditions: [ + { type: "urlContains", str: "example" }, + { type: "hasTag", tagId: tagId1 }, + ], + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should handle and condition (false)", () => { + const condition: RuleEngineCondition = { + type: "and", + conditions: [ + { type: "urlContains", str: "example" }, + { type: "hasTag", tagId: tagId2 }, // This one is false + ], + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + + it("should handle or condition (true)", () => { + const condition: RuleEngineCondition = { + type: "or", + conditions: [ + { type: "urlContains", str: "nonexistent" }, // false + { type: "hasTag", tagId: tagId1 }, // true + ], + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(true); + }); + + it("should handle or condition (false)", () => { + const condition: RuleEngineCondition = { + type: "or", + conditions: [ + { type: "urlContains", str: "nonexistent" }, // false + { type: "hasTag", tagId: tagId2 }, // false + ], + }; + expect(engine.doesBookmarkMatchConditions(condition)).toBe(false); + }); + }); + + describe("evaluateRule", () => { + let ruleId: string; + let engine: RuleEngine; + let testRule: RuleEngineRule; + + beforeEach(async () => { + const tmp = { + id: "", // Will be set after seeding + userId, + name: "Evaluate Rule Test", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "urlContains", str: "example" }, + actions: [{ type: "addTag", tagId: tagId2 }], + } as Omit & { userId: string }; + ruleId = await seedRule(tmp); + testRule = { ...tmp, id: ruleId }; + engine = await RuleEngine.forBookmark(ctx, bookmarkId); + }); + + it("should evaluate rule successfully when event and conditions match", async () => { + const event: RuleEngineEvent = { type: "bookmarkAdded" }; + const results = await engine.evaluateRule(testRule, event); + expect(results).toEqual([ + { type: "success", ruleId: ruleId, message: `Added tag ${tagId2}` }, + ]); + // Verify action was performed + const tags = await db.query.tagsOnBookmarks.findMany({ + where: eq(tagsOnBookmarks.bookmarkId, bookmarkId), + }); + expect(tags.map((t) => t.tagId)).toContain(tagId2); + }); + + it("should return empty array if rule is disabled", async () => { + await db + .update(rules) + .set({ enabled: false }) + .where(eq(rules.id, ruleId)); + const disabledRule = { ...testRule, enabled: false }; + const event: RuleEngineEvent = { type: "bookmarkAdded" }; + const results = await engine.evaluateRule(disabledRule, event); + expect(results).toEqual([]); + }); + + it("should return empty array if event does not match", async () => { + const event: RuleEngineEvent = { type: "favourited" }; + const results = await engine.evaluateRule(testRule, event); + expect(results).toEqual([]); + }); + + it("should return empty array if condition does not match", async () => { + const nonMatchingRule: RuleEngineRule = { + ...testRule, + condition: { type: "urlContains", str: "nonexistent" }, + }; + await db + .update(rules) + .set({ condition: JSON.stringify(nonMatchingRule.condition) }) + .where(eq(rules.id, ruleId)); + + const event: RuleEngineEvent = { type: "bookmarkAdded" }; + const results = await engine.evaluateRule(nonMatchingRule, event); + expect(results).toEqual([]); + }); + + it("should return failure result if action fails", async () => { + // Mock addBookmark to throw an error + const listAddBookmarkSpy = vi + .spyOn(RuleEngine.prototype, "executeAction") + .mockImplementation(async (action: RuleEngineAction) => { + if (action.type === "addToList") { + throw new Error("Failed to add to list"); + } + // Call original for other actions if needed, though not strictly necessary here + return Promise.resolve(`Action ${action.type} executed`); + }); + + const ruleWithFailingAction = { + ...testRule, + actions: [{ type: "addToList", listId: "invalid-list" } as const], + }; + await db.delete(ruleActions).where(eq(ruleActions.ruleId, ruleId)); // Clear old actions + await db.insert(ruleActions).values({ + ruleId: ruleId, + action: JSON.stringify(ruleWithFailingAction.actions[0]), + userId, + }); + + const event: RuleEngineEvent = { type: "bookmarkAdded" }; + const results = await engine.evaluateRule(ruleWithFailingAction, event); + + expect(results).toEqual([ + { + type: "failure", + ruleId: ruleId, + message: "Failed to add to list", + }, + ]); + listAddBookmarkSpy.mockRestore(); + }); + }); + + describe("executeAction", () => { + let engine: RuleEngine; + + beforeEach(async () => { + engine = await RuleEngine.forBookmark(ctx, bookmarkId); + }); + + it("should execute addTag action", async () => { + const action: RuleEngineAction = { type: "addTag", tagId: tagId2 }; + const result = await engine.executeAction(action); + expect(result).toBe(`Added tag ${tagId2}`); + const tagLink = await db.query.tagsOnBookmarks.findFirst({ + where: and( + eq(tagsOnBookmarks.bookmarkId, bookmarkId), + eq(tagsOnBookmarks.tagId, tagId2), + ), + }); + expect(tagLink).toBeDefined(); + }); + + it("should execute removeTag action", async () => { + // Ensure tag exists first + expect( + await db.query.tagsOnBookmarks.findFirst({ + where: and( + eq(tagsOnBookmarks.bookmarkId, bookmarkId), + eq(tagsOnBookmarks.tagId, tagId1), + ), + }), + ).toBeDefined(); + + const action: RuleEngineAction = { type: "removeTag", tagId: tagId1 }; + const result = await engine.executeAction(action); + expect(result).toBe(`Removed tag ${tagId1}`); + const tagLink = await db.query.tagsOnBookmarks.findFirst({ + where: and( + eq(tagsOnBookmarks.bookmarkId, bookmarkId), + eq(tagsOnBookmarks.tagId, tagId1), + ), + }); + expect(tagLink).toBeUndefined(); + }); + + it("should execute addToList action", async () => { + const action: RuleEngineAction = { type: "addToList", listId: listId1 }; + const result = await engine.executeAction(action); + expect(result).toBe(`Added to list ${listId1}`); + const listLink = await db.query.bookmarksInLists.findFirst({ + where: and( + eq(bookmarksInLists.bookmarkId, bookmarkId), + eq(bookmarksInLists.listId, listId1), + ), + }); + expect(listLink).toBeDefined(); + }); + + it("should execute removeFromList action", async () => { + // Add to list first + await db + .insert(bookmarksInLists) + .values({ bookmarkId: bookmarkId, listId: listId1 }); + expect( + await db.query.bookmarksInLists.findFirst({ + where: and( + eq(bookmarksInLists.bookmarkId, bookmarkId), + eq(bookmarksInLists.listId, listId1), + ), + }), + ).toBeDefined(); + + const action: RuleEngineAction = { + type: "removeFromList", + listId: listId1, + }; + const result = await engine.executeAction(action); + expect(result).toBe(`Removed from list ${listId1}`); + const listLink = await db.query.bookmarksInLists.findFirst({ + where: and( + eq(bookmarksInLists.bookmarkId, bookmarkId), + eq(bookmarksInLists.listId, listId1), + ), + }); + expect(listLink).toBeUndefined(); + }); + + it("should execute downloadFullPageArchive action", async () => { + const action: RuleEngineAction = { type: "downloadFullPageArchive" }; + const result = await engine.executeAction(action); + expect(result).toBe(`Enqueued full page archive`); + expect(LinkCrawlerQueue.enqueue).toHaveBeenCalledWith({ + bookmarkId: bookmarkId, + archiveFullPage: true, + runInference: false, + }); + }); + + it("should execute favouriteBookmark action", async () => { + let bm = await db.query.bookmarks.findFirst({ + where: eq(bookmarks.id, bookmarkId), + }); + expect(bm?.favourited).toBe(false); + + const action: RuleEngineAction = { type: "favouriteBookmark" }; + const result = await engine.executeAction(action); + expect(result).toBe(`Marked as favourited`); + + bm = await db.query.bookmarks.findFirst({ + where: eq(bookmarks.id, bookmarkId), + }); + expect(bm?.favourited).toBe(true); + }); + + it("should execute archiveBookmark action", async () => { + let bm = await db.query.bookmarks.findFirst({ + where: eq(bookmarks.id, bookmarkId), + }); + expect(bm?.archived).toBe(false); + + const action: RuleEngineAction = { type: "archiveBookmark" }; + const result = await engine.executeAction(action); + expect(result).toBe(`Marked as archived`); + + bm = await db.query.bookmarks.findFirst({ + where: eq(bookmarks.id, bookmarkId), + }); + expect(bm?.archived).toBe(true); + }); + }); + + describe("onEvent", () => { + let ruleMatchId: string; + let _ruleNoMatchConditionId: string; + let _ruleNoMatchEventId: string; + let _ruleDisabledId: string; + let engine: RuleEngine; + + beforeEach(async () => { + // Rule that should match and execute + ruleMatchId = await seedRule({ + userId, + name: "Match Rule", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "urlContains", str: "example" }, + actions: [{ type: "addTag", tagId: tagId2 }], + }); + + // Rule with non-matching condition + _ruleNoMatchConditionId = await seedRule({ + userId, + name: "No Match Condition Rule", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "urlContains", str: "nonexistent" }, + actions: [{ type: "favouriteBookmark" }], + }); + + // Rule with non-matching event + _ruleNoMatchEventId = await seedRule({ + userId, + name: "No Match Event Rule", + description: "", + enabled: true, + event: { type: "favourited" }, // Must match rule event type + condition: { type: "urlContains", str: "example" }, + actions: [{ type: "archiveBookmark" }], + }); + + // Disabled rule + _ruleDisabledId = await seedRule({ + userId, + name: "Disabled Rule", + description: "", + enabled: false, // Disabled + event: { type: "bookmarkAdded" }, + condition: { type: "urlContains", str: "example" }, + actions: [{ type: "addToList", listId: listId1 }], + }); + + engine = await RuleEngine.forBookmark(ctx, bookmarkId); + }); + + it("should process event and return only results for matching, enabled rules", async () => { + const event: RuleEngineEvent = { type: "bookmarkAdded" }; + const results = await engine.onEvent(event); + + expect(results).toHaveLength(1); // Only ruleMatchId should produce a result + expect(results[0]).toEqual({ + type: "success", + ruleId: ruleMatchId, + message: `Added tag ${tagId2}`, + }); + + // Verify only the action from the matching rule was executed + const tags = await db.query.tagsOnBookmarks.findMany({ + where: eq(tagsOnBookmarks.bookmarkId, bookmarkId), + }); + expect(tags.map((t) => t.tagId)).toContain(tagId2); // Tag added by ruleMatchId + + const bm = await db.query.bookmarks.findFirst({ + where: eq(bookmarks.id, bookmarkId), + }); + expect(bm?.favourited).toBe(false); // Action from ruleNoMatchConditionId not executed + expect(bm?.archived).toBe(false); // Action from ruleNoMatchEventId not executed + + const listLink = await db.query.bookmarksInLists.findFirst({ + where: and( + eq(bookmarksInLists.bookmarkId, bookmarkId), + eq(bookmarksInLists.listId, listId1), + ), + }); + expect(listLink).toBeUndefined(); // Action from ruleDisabledId not executed + }); + + it("should return empty array if no rules match the event", async () => { + const event: RuleEngineEvent = { type: "tagAdded", tagId: "some-tag" }; // Event that matches no rules + const results = await engine.onEvent(event); + expect(results).toEqual([]); + }); + }); +}); diff --git a/packages/trpc/lib/ruleEngine.ts b/packages/trpc/lib/ruleEngine.ts new file mode 100644 index 00000000..0bef8cdc --- /dev/null +++ b/packages/trpc/lib/ruleEngine.ts @@ -0,0 +1,231 @@ +import deepEql from "deep-equal"; +import { and, eq } from "drizzle-orm"; + +import { bookmarks, tagsOnBookmarks } from "@karakeep/db/schema"; +import { LinkCrawlerQueue } from "@karakeep/shared/queues"; +import { + RuleEngineAction, + RuleEngineCondition, + RuleEngineEvent, + RuleEngineRule, +} from "@karakeep/shared/types/rules"; + +import { AuthedContext } from ".."; +import { List } from "../models/lists"; +import { RuleEngineRuleModel } from "../models/rules"; + +async function fetchBookmark(db: AuthedContext["db"], bookmarkId: string) { + return await db.query.bookmarks.findFirst({ + where: eq(bookmarks.id, bookmarkId), + with: { + link: { + columns: { + url: true, + }, + }, + text: true, + asset: true, + tagsOnBookmarks: true, + rssFeeds: { + columns: { + rssFeedId: true, + }, + }, + user: { + columns: {}, + with: { + rules: { + with: { + actions: true, + }, + }, + }, + }, + }, + }); +} + +type ReturnedBookmark = NonNullable>>; + +export interface RuleEngineEvaluationResult { + type: "success" | "failure"; + ruleId: string; + message: string; +} + +export class RuleEngine { + private constructor( + private ctx: AuthedContext, + private bookmark: Omit, + private rules: RuleEngineRule[], + ) {} + + static async forBookmark(ctx: AuthedContext, bookmarkId: string) { + const [bookmark, rules] = await Promise.all([ + fetchBookmark(ctx.db, bookmarkId), + RuleEngineRuleModel.getAll(ctx), + ]); + if (!bookmark) { + throw new Error(`Bookmark ${bookmarkId} not found`); + } + return new RuleEngine( + ctx, + bookmark, + rules.map((r) => r.rule), + ); + } + + doesBookmarkMatchConditions(condition: RuleEngineCondition): boolean { + switch (condition.type) { + case "alwaysTrue": { + return true; + } + case "urlContains": { + return (this.bookmark.link?.url ?? "").includes(condition.str); + } + case "importedFromFeed": { + return this.bookmark.rssFeeds.some( + (f) => f.rssFeedId === condition.feedId, + ); + } + case "bookmarkTypeIs": { + return this.bookmark.type === condition.bookmarkType; + } + case "hasTag": { + return this.bookmark.tagsOnBookmarks.some( + (t) => t.tagId === condition.tagId, + ); + } + case "isFavourited": { + return this.bookmark.favourited; + } + case "isArchived": { + return this.bookmark.archived; + } + case "and": { + return condition.conditions.every((c) => + this.doesBookmarkMatchConditions(c), + ); + } + case "or": { + return condition.conditions.some((c) => + this.doesBookmarkMatchConditions(c), + ); + } + default: { + const _exhaustiveCheck: never = condition; + return false; + } + } + } + + async evaluateRule( + rule: RuleEngineRule, + event: RuleEngineEvent, + ): Promise { + if (!rule.enabled) { + return []; + } + if (!deepEql(rule.event, event, { strict: true })) { + return []; + } + if (!this.doesBookmarkMatchConditions(rule.condition)) { + return []; + } + const results = await Promise.allSettled( + rule.actions.map((action) => this.executeAction(action)), + ); + return results.map((result) => { + if (result.status === "fulfilled") { + return { + type: "success", + ruleId: rule.id, + message: result.value, + }; + } else { + return { + type: "failure", + ruleId: rule.id, + message: (result.reason as Error).message, + }; + } + }); + } + + async executeAction(action: RuleEngineAction): Promise { + switch (action.type) { + case "addTag": { + await this.ctx.db + .insert(tagsOnBookmarks) + .values([ + { + attachedBy: "human", + bookmarkId: this.bookmark.id, + tagId: action.tagId, + }, + ]) + .onConflictDoNothing(); + return `Added tag ${action.tagId}`; + } + case "removeTag": { + await this.ctx.db + .delete(tagsOnBookmarks) + .where( + and( + eq(tagsOnBookmarks.tagId, action.tagId), + eq(tagsOnBookmarks.bookmarkId, this.bookmark.id), + ), + ); + return `Removed tag ${action.tagId}`; + } + case "addToList": { + const list = await List.fromId(this.ctx, action.listId); + await list.addBookmark(this.bookmark.id); + return `Added to list ${action.listId}`; + } + case "removeFromList": { + const list = await List.fromId(this.ctx, action.listId); + await list.removeBookmark(this.bookmark.id); + return `Removed from list ${action.listId}`; + } + case "downloadFullPageArchive": { + await LinkCrawlerQueue.enqueue({ + bookmarkId: this.bookmark.id, + archiveFullPage: true, + runInference: false, + }); + return `Enqueued full page archive`; + } + case "favouriteBookmark": { + await this.ctx.db + .update(bookmarks) + .set({ + favourited: true, + }) + .where(eq(bookmarks.id, this.bookmark.id)); + return `Marked as favourited`; + } + case "archiveBookmark": { + await this.ctx.db + .update(bookmarks) + .set({ + archived: true, + }) + .where(eq(bookmarks.id, this.bookmark.id)); + return `Marked as archived`; + } + default: { + const _exhaustiveCheck: never = action; + return ""; + } + } + } + + async onEvent(event: RuleEngineEvent): Promise { + const results = await Promise.all( + this.rules.map((rule) => this.evaluateRule(rule, event)), + ); + + return results.flat(); + } +} diff --git a/packages/trpc/models/lists.ts b/packages/trpc/models/lists.ts index 8072060f..4da127d2 100644 --- a/packages/trpc/models/lists.ts +++ b/packages/trpc/models/lists.ts @@ -5,6 +5,7 @@ import { z } from "zod"; import { SqliteError } from "@karakeep/db"; import { bookmarkLists, bookmarksInLists } from "@karakeep/db/schema"; +import { triggerRuleEngineOnEvent } from "@karakeep/shared/queues"; import { parseSearchQuery } from "@karakeep/shared/searchQueryParser"; import { ZBookmarkList, @@ -117,7 +118,9 @@ export abstract class List implements PrivacyAware { } } - async update(input: z.infer) { + async update( + input: z.infer, + ): Promise { const result = await this.ctx.db .update(bookmarkLists) .set({ @@ -137,7 +140,7 @@ export abstract class List implements PrivacyAware { if (result.length == 0) { throw new TRPCError({ code: "NOT_FOUND" }); } - return result[0]; + this.list = result[0]; } abstract get type(): "manual" | "smart"; @@ -248,6 +251,12 @@ export class ManualList extends List { listId: this.list.id, bookmarkId, }); + await triggerRuleEngineOnEvent(bookmarkId, [ + { + type: "addedToList", + listId: this.list.id, + }, + ]); } catch (e) { if (e instanceof SqliteError) { if (e.code == "SQLITE_CONSTRAINT_PRIMARYKEY") { @@ -279,6 +288,12 @@ export class ManualList extends List { message: `Bookmark ${bookmarkId} is already not in list ${this.list.id}`, }); } + await triggerRuleEngineOnEvent(bookmarkId, [ + { + type: "removedFromList", + listId: this.list.id, + }, + ]); } async update(input: z.infer) { diff --git a/packages/trpc/models/rules.ts b/packages/trpc/models/rules.ts new file mode 100644 index 00000000..7b17fd8a --- /dev/null +++ b/packages/trpc/models/rules.ts @@ -0,0 +1,233 @@ +import { TRPCError } from "@trpc/server"; +import { and, eq } from "drizzle-orm"; +import { z } from "zod"; + +import { db as DONT_USE_DB } from "@karakeep/db"; +import { + ruleEngineActionsTable, + ruleEngineRulesTable, +} from "@karakeep/db/schema"; +import { + RuleEngineRule, + zNewRuleEngineRuleSchema, + zRuleEngineActionSchema, + zRuleEngineConditionSchema, + zRuleEngineEventSchema, + zUpdateRuleEngineRuleSchema, +} from "@karakeep/shared/types/rules"; + +import { AuthedContext } from ".."; +import { PrivacyAware } from "./privacy"; + +function dummy_fetchRule(ctx: AuthedContext, id: string) { + return DONT_USE_DB.query.ruleEngineRulesTable.findFirst({ + where: and( + eq(ruleEngineRulesTable.id, id), + eq(ruleEngineRulesTable.userId, ctx.user.id), + ), + with: { + actions: true, // Assuming actions are related; adjust if needed + }, + }); +} + +type FetchedRuleType = NonNullable>>; + +export class RuleEngineRuleModel implements PrivacyAware { + protected constructor( + protected ctx: AuthedContext, + public rule: RuleEngineRule & { userId: string }, + ) {} + + private static fromData( + ctx: AuthedContext, + ruleData: FetchedRuleType, + ): RuleEngineRuleModel { + return new RuleEngineRuleModel(ctx, { + id: ruleData.id, + userId: ruleData.userId, + name: ruleData.name, + description: ruleData.description, + enabled: ruleData.enabled, + event: zRuleEngineEventSchema.parse(JSON.parse(ruleData.event)), + condition: zRuleEngineConditionSchema.parse( + JSON.parse(ruleData.condition), + ), + actions: ruleData.actions.map((a) => + zRuleEngineActionSchema.parse(JSON.parse(a.action)), + ), + }); + } + + static async fromId( + ctx: AuthedContext, + id: string, + ): Promise { + const ruleData = await ctx.db.query.ruleEngineRulesTable.findFirst({ + where: and( + eq(ruleEngineRulesTable.id, id), + eq(ruleEngineRulesTable.userId, ctx.user.id), + ), + with: { + actions: true, // Assuming actions are related; adjust if needed + }, + }); + + if (!ruleData) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "Rule not found", + }); + } + + return this.fromData(ctx, ruleData); + } + + ensureCanAccess(ctx: AuthedContext): void { + if (this.rule.userId != ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + } + + static async create( + ctx: AuthedContext, + input: z.infer, + ): Promise { + // Similar to lists create, but for rules + const insertedRule = await ctx.db.transaction(async (tx) => { + const [newRule] = await tx + .insert(ruleEngineRulesTable) + .values({ + name: input.name, + description: input.description, + enabled: input.enabled, + event: JSON.stringify(input.event), + condition: JSON.stringify(input.condition), + userId: ctx.user.id, + listId: + input.event.type === "addedToList" || + input.event.type === "removedFromList" + ? input.event.listId + : null, + tagId: + input.event.type === "tagAdded" || input.event.type === "tagRemoved" + ? input.event.tagId + : null, + }) + .returning(); + + if (input.actions.length > 0) { + await tx.insert(ruleEngineActionsTable).values( + input.actions.map((action) => ({ + ruleId: newRule.id, + userId: ctx.user.id, + action: JSON.stringify(action), + listId: + action.type === "addToList" || action.type === "removeFromList" + ? action.listId + : null, + tagId: + action.type === "addTag" || action.type === "removeTag" + ? action.tagId + : null, + })), + ); + } + return newRule; + }); + + // Fetch the full rule after insertion + return await RuleEngineRuleModel.fromId(ctx, insertedRule.id); + } + + async update( + input: z.infer, + ): Promise { + if (this.rule.id !== input.id) { + throw new TRPCError({ code: "BAD_REQUEST", message: "ID mismatch" }); + } + + await this.ctx.db.transaction(async (tx) => { + const result = await tx + .update(ruleEngineRulesTable) + .set({ + name: input.name, + description: input.description, + enabled: input.enabled, + event: JSON.stringify(input.event), + condition: JSON.stringify(input.condition), + listId: + input.event.type === "addedToList" || + input.event.type === "removedFromList" + ? input.event.listId + : null, + tagId: + input.event.type === "tagAdded" || input.event.type === "tagRemoved" + ? input.event.tagId + : null, + }) + .where( + and( + eq(ruleEngineRulesTable.id, input.id), + eq(ruleEngineRulesTable.userId, this.ctx.user.id), + ), + ); + + if (result.changes === 0) { + throw new TRPCError({ code: "NOT_FOUND", message: "Rule not found" }); + } + + if (input.actions.length > 0) { + await tx + .delete(ruleEngineActionsTable) + .where(eq(ruleEngineActionsTable.ruleId, input.id)); + await tx.insert(ruleEngineActionsTable).values( + input.actions.map((action) => ({ + ruleId: input.id, + userId: this.ctx.user.id, + action: JSON.stringify(action), + listId: + action.type === "addToList" || action.type === "removeFromList" + ? action.listId + : null, + tagId: + action.type === "addTag" || action.type === "removeTag" + ? action.tagId + : null, + })), + ); + } + }); + + this.rule = await RuleEngineRuleModel.fromId(this.ctx, this.rule.id).then( + (r) => r.rule, + ); + } + + async delete(): Promise { + const result = await this.ctx.db + .delete(ruleEngineRulesTable) + .where( + and( + eq(ruleEngineRulesTable.id, this.rule.id), + eq(ruleEngineRulesTable.userId, this.ctx.user.id), + ), + ); + + if (result.changes === 0) { + throw new TRPCError({ code: "NOT_FOUND", message: "Rule not found" }); + } + } + + static async getAll(ctx: AuthedContext): Promise { + const rulesData = await ctx.db.query.ruleEngineRulesTable.findMany({ + where: eq(ruleEngineRulesTable.userId, ctx.user.id), + with: { actions: true }, + }); + + return rulesData.map((r) => this.fromData(ctx, r)); + } +} diff --git a/packages/trpc/package.json b/packages/trpc/package.json index 94fdee1b..5b5bad86 100644 --- a/packages/trpc/package.json +++ b/packages/trpc/package.json @@ -17,6 +17,7 @@ "@karakeep/shared": "workspace:*", "@trpc/server": "11.0.0", "bcryptjs": "^2.4.3", + "deep-equal": "^2.2.3", "drizzle-orm": "^0.38.3", "superjson": "^2.2.1", "tiny-invariant": "^1.3.3", @@ -27,6 +28,7 @@ "@karakeep/prettier-config": "workspace:^0.1.0", "@karakeep/tsconfig": "workspace:^0.1.0", "@types/bcryptjs": "^2.4.6", + "@types/deep-equal": "^1.0.4", "vite-tsconfig-paths": "^4.3.1", "vitest": "^1.6.1" }, diff --git a/packages/trpc/routers/_app.ts b/packages/trpc/routers/_app.ts index 7af19884..394e95e7 100644 --- a/packages/trpc/routers/_app.ts +++ b/packages/trpc/routers/_app.ts @@ -7,6 +7,7 @@ import { feedsAppRouter } from "./feeds"; import { highlightsAppRouter } from "./highlights"; import { listsAppRouter } from "./lists"; import { promptsAppRouter } from "./prompts"; +import { rulesAppRouter } from "./rules"; import { tagsAppRouter } from "./tags"; import { usersAppRouter } from "./users"; import { webhooksAppRouter } from "./webhooks"; @@ -23,6 +24,7 @@ export const appRouter = router({ highlights: highlightsAppRouter, webhooks: webhooksAppRouter, assets: assetsAppRouter, + rules: rulesAppRouter, }); // export type definition of API export type AppRouter = typeof appRouter; diff --git a/packages/trpc/routers/bookmarks.ts b/packages/trpc/routers/bookmarks.ts index 9a1b6b0b..b9a21400 100644 --- a/packages/trpc/routers/bookmarks.ts +++ b/packages/trpc/routers/bookmarks.ts @@ -45,6 +45,7 @@ import { AssetPreprocessingQueue, LinkCrawlerQueue, OpenAIQueue, + triggerRuleEngineOnEvent, triggerSearchDeletion, triggerSearchReindex, triggerWebhook, @@ -430,6 +431,11 @@ export const bookmarksAppRouter = router({ break; } } + await triggerRuleEngineOnEvent(bookmark.id, [ + { + type: "bookmarkAdded", + }, + ]); await triggerSearchReindex(bookmark.id); await triggerWebhook(bookmark.id, "created"); return bookmark; @@ -573,6 +579,17 @@ export const bookmarksAppRouter = router({ /* includeContent: */ false, ); + if (input.favourited === true || input.archived === true) { + await triggerRuleEngineOnEvent( + input.bookmarkId, + [ + ...(input.favourited === true ? ["favourited" as const] : []), + ...(input.archived === true ? ["archived" as const] : []), + ].map((t) => ({ + type: t, + })), + ); + } // Trigger re-indexing and webhooks await triggerSearchReindex(input.bookmarkId); await triggerWebhook(input.bookmarkId, "edited"); @@ -1141,6 +1158,16 @@ export const bookmarksAppRouter = router({ ), ); + await triggerRuleEngineOnEvent(input.bookmarkId, [ + ...idsToRemove.map((t) => ({ + type: "tagRemoved" as const, + tagId: t, + })), + ...allIds.map((t) => ({ + type: "tagAdded" as const, + tagId: t, + })), + ]); await triggerSearchReindex(input.bookmarkId); await triggerWebhook(input.bookmarkId, "edited"); return { diff --git a/packages/trpc/routers/lists.ts b/packages/trpc/routers/lists.ts index 12960316..65cffd2d 100644 --- a/packages/trpc/routers/lists.ts +++ b/packages/trpc/routers/lists.ts @@ -38,7 +38,8 @@ export const listsAppRouter = router({ .output(zBookmarkListSchema) .use(ensureListOwnership) .mutation(async ({ input, ctx }) => { - return await ctx.list.update(input); + await ctx.list.update(input); + return ctx.list.list; }), merge: authedProcedure .input(zMergeListSchema) diff --git a/packages/trpc/routers/rules.test.ts b/packages/trpc/routers/rules.test.ts new file mode 100644 index 00000000..6bbbcd84 --- /dev/null +++ b/packages/trpc/routers/rules.test.ts @@ -0,0 +1,379 @@ +import { beforeEach, describe, expect, test } from "vitest"; + +import { RuleEngineRule } from "@karakeep/shared/types/rules"; + +import type { CustomTestContext } from "../testUtils"; +import { defaultBeforeEach } from "../testUtils"; + +describe("Rules Routes", () => { + let tagId1: string; + let tagId2: string; + let otherUserTagId: string; + + let listId: string; + let otherUserListId: string; + + beforeEach(async (ctx) => { + await defaultBeforeEach(true)(ctx); + + tagId1 = ( + await ctx.apiCallers[0].tags.create({ + name: "Tag 1", + }) + ).id; + + tagId2 = ( + await ctx.apiCallers[0].tags.create({ + name: "Tag 2", + }) + ).id; + + otherUserTagId = ( + await ctx.apiCallers[1].tags.create({ + name: "Tag 1", + }) + ).id; + + listId = ( + await ctx.apiCallers[0].lists.create({ + name: "List 1", + icon: "😘", + }) + ).id; + + otherUserListId = ( + await ctx.apiCallers[1].lists.create({ + name: "List 1", + icon: "😘", + }) + ).id; + }); + + test("create rule with valid data", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; + + const validRuleInput: Omit = { + name: "Valid Rule", + description: "A test rule", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [ + { type: "addTag", tagId: tagId1 }, + { type: "addToList", listId: listId }, + ], + }; + + const createdRule = await api.create(validRuleInput); + expect(createdRule).toMatchObject({ + name: "Valid Rule", + description: "A test rule", + enabled: true, + event: validRuleInput.event, + condition: validRuleInput.condition, + actions: validRuleInput.actions, + }); + }); + + test("create rule fails with invalid data (no actions)", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Missing actions", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [], // Empty actions array - should fail validation + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /You must specify at least one action/, + ); + }); + + test("create rule fails with invalid event (empty tagId)", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Invalid event", + enabled: true, + event: { type: "tagAdded", tagId: "" }, // Empty tagId - should fail + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /You must specify a tag for this event type/, + ); + }); + + test("create rule fails with invalid condition (empty tagId in hasTag)", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Invalid condition", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "hasTag", tagId: "" }, // Empty tagId - should fail + actions: [{ type: "addTag", tagId: tagId1 }], + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /You must specify a tag for this condition type/, + ); + }); + + test("update rule with valid data", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; + + // First, create a rule + const createdRule = await api.create({ + name: "Original Rule", + description: "Original desc", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }); + + const validUpdateInput: RuleEngineRule = { + id: createdRule.id, + name: "Updated Rule", + description: "Updated desc", + enabled: false, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "removeTag", tagId: tagId2 }], + }; + + const updatedRule = await api.update(validUpdateInput); + expect(updatedRule).toMatchObject({ + id: createdRule.id, + name: "Updated Rule", + description: "Updated desc", + enabled: false, + event: validUpdateInput.event, + condition: validUpdateInput.condition, + actions: validUpdateInput.actions, + }); + }); + + test("update rule fails with invalid data (empty action tagId)", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; + + // First, create a rule + const createdRule = await api.create({ + name: "Original Rule", + description: "Original desc", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }); + + const invalidUpdateInput: RuleEngineRule = { + id: createdRule.id, + name: "Updated Rule", + description: "Updated desc", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "removeTag", tagId: "" }], // Empty tagId - should fail + }; + + await expect(() => api.update(invalidUpdateInput)).rejects.toThrow( + /You must specify a tag for this action type/, + ); + }); + + test("delete rule", async ({ apiCallers }) => { + const api = apiCallers[0].rules; + + const createdRule = await api.create({ + name: "Rule to Delete", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }); + + await api.delete({ id: createdRule.id }); + + // Attempt to fetch the rule should fail + await expect(() => + api.update({ ...createdRule, name: "Updated" }), + ).rejects.toThrow(/Rule not found/); + }); + + test("list rules", async ({ apiCallers }) => { + const api = apiCallers[0].rules; + + await api.create({ + name: "Rule 1", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }); + + await api.create({ + name: "Rule 2", + description: "", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId2 }], + }); + + const rulesList = await api.list(); + expect(rulesList.rules.length).toBeGreaterThanOrEqual(2); + expect(rulesList.rules.some((rule) => rule.name === "Rule 1")).toBeTruthy(); + expect(rulesList.rules.some((rule) => rule.name === "Rule 2")).toBeTruthy(); + }); + + describe("privacy checks", () => { + test("cannot access or manipulate another user's rule", async ({ + apiCallers, + }) => { + const apiUserA = apiCallers[0].rules; // First user + const apiUserB = apiCallers[1].rules; // Second user + + // User A creates a rule + const createdRule = await apiUserA.create({ + name: "User A's Rule", + description: "A rule for User A", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }); + + // User B tries to update User A's rule + const updateInput: RuleEngineRule = { + id: createdRule.id, + name: "Trying to Update", + description: "Unauthorized update", + enabled: true, + event: createdRule.event, + condition: createdRule.condition, + actions: createdRule.actions, + }; + + await expect(() => apiUserB.update(updateInput)).rejects.toThrow( + /Rule not found/, + ); + }); + + test("cannot create rule with event on another user's tag", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; // First user trying to use second user's tag + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Event with other user's tag", + enabled: true, + event: { type: "tagAdded", tagId: otherUserTagId }, // Other user's tag + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /Tag not found/, // Expect an error indicating lack of ownership + ); + }); + + test("cannot create rule with condition on another user's tag", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; // First user trying to use second user's tag + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Condition with other user's tag", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "hasTag", tagId: otherUserTagId }, // Other user's tag + actions: [{ type: "addTag", tagId: tagId1 }], + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /Tag not found/, + ); + }); + + test("cannot create rule with action on another user's tag", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; // First user trying to use second user's tag + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Action with other user's tag", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: otherUserTagId }], // Other user's tag + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /Tag not found/, + ); + }); + + test("cannot create rule with event on another user's list", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; // First user trying to use second user's list + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Event with other user's list", + enabled: true, + event: { type: "addedToList", listId: otherUserListId }, // Other user's list + condition: { type: "alwaysTrue" }, + actions: [{ type: "addTag", tagId: tagId1 }], + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /List not found/, + ); + }); + + test("cannot create rule with action on another user's list", async ({ + apiCallers, + }) => { + const api = apiCallers[0].rules; // First user trying to use second user's list + + const invalidRuleInput: Omit = { + name: "Invalid Rule", + description: "Action with other user's list", + enabled: true, + event: { type: "bookmarkAdded" }, + condition: { type: "alwaysTrue" }, + actions: [{ type: "addToList", listId: otherUserListId }], // Other user's list + }; + + await expect(() => api.create(invalidRuleInput)).rejects.toThrow( + /List not found/, + ); + }); + }); +}); diff --git a/packages/trpc/routers/rules.ts b/packages/trpc/routers/rules.ts new file mode 100644 index 00000000..5def8003 --- /dev/null +++ b/packages/trpc/routers/rules.ts @@ -0,0 +1,120 @@ +import { experimental_trpcMiddleware, TRPCError } from "@trpc/server"; +import { and, eq, inArray } from "drizzle-orm"; +import { z } from "zod"; + +import { bookmarkTags } from "@karakeep/db/schema"; +import { + RuleEngineRule, + zNewRuleEngineRuleSchema, + zRuleEngineRuleSchema, + zUpdateRuleEngineRuleSchema, +} from "@karakeep/shared/types/rules"; + +import { AuthedContext, authedProcedure, router } from "../index"; +import { List } from "../models/lists"; +import { RuleEngineRuleModel } from "../models/rules"; + +const ensureRuleOwnership = experimental_trpcMiddleware<{ + ctx: AuthedContext; + input: { id: string }; +}>().create(async (opts) => { + const rule = await RuleEngineRuleModel.fromId(opts.ctx, opts.input.id); + return opts.next({ + ctx: { + ...opts.ctx, + rule, + }, + }); +}); + +const ensureTagListOwnership = experimental_trpcMiddleware<{ + ctx: AuthedContext; + input: Omit; +}>().create(async (opts) => { + const tagIds = [ + ...(opts.input.event.type === "tagAdded" || + opts.input.event.type === "tagRemoved" + ? [opts.input.event.tagId] + : []), + ...(opts.input.condition.type === "hasTag" + ? [opts.input.condition.tagId] + : []), + ...opts.input.actions.flatMap((a) => + a.type == "addTag" || a.type == "removeTag" ? [a.tagId] : [], + ), + ]; + + const validateTags = async () => { + if (tagIds.length == 0) { + return; + } + const userTags = await opts.ctx.db.query.bookmarkTags.findMany({ + where: and( + eq(bookmarkTags.userId, opts.ctx.user.id), + inArray(bookmarkTags.id, tagIds), + ), + columns: { + id: true, + }, + }); + if (tagIds.some((t) => userTags.find((u) => u.id == t) == null)) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "Tag not found", + }); + } + }; + + const listIds = [ + ...(opts.input.event.type === "addedToList" || + opts.input.event.type === "removedFromList" + ? [opts.input.event.listId] + : []), + ...opts.input.actions.flatMap((a) => + a.type == "addToList" || a.type == "removeFromList" ? [a.listId] : [], + ), + ]; + + const [_tags, _lists] = await Promise.all([ + validateTags(), + Promise.all(listIds.map((l) => List.fromId(opts.ctx, l))), + ]); + return opts.next(); +}); + +export const rulesAppRouter = router({ + create: authedProcedure + .input(zNewRuleEngineRuleSchema) + .output(zRuleEngineRuleSchema) + .use(ensureTagListOwnership) + .mutation(async ({ input, ctx }) => { + const newRule = await RuleEngineRuleModel.create(ctx, input); + return newRule.rule; + }), + update: authedProcedure + .input(zUpdateRuleEngineRuleSchema) + .output(zRuleEngineRuleSchema) + .use(ensureRuleOwnership) + .use(ensureTagListOwnership) + .mutation(async ({ ctx, input }) => { + await ctx.rule.update(input); + return ctx.rule.rule; + }), + delete: authedProcedure + .input(z.object({ id: z.string() })) + .use(ensureRuleOwnership) + .mutation(async ({ ctx }) => { + await ctx.rule.delete(); + }), + list: authedProcedure + .output( + z.object({ + rules: z.array(zRuleEngineRuleSchema), + }), + ) + .query(async ({ ctx }) => { + return { + rules: (await RuleEngineRuleModel.getAll(ctx)).map((r) => r.rule), + }; + }), +}); diff --git a/packages/trpc/routers/tags.ts b/packages/trpc/routers/tags.ts index cdf47f4f..7f75c16e 100644 --- a/packages/trpc/routers/tags.ts +++ b/packages/trpc/routers/tags.ts @@ -18,7 +18,7 @@ function conditionFromInput(input: { tagId: string }, userId: string) { return and(eq(bookmarkTags.id, input.tagId), eq(bookmarkTags.userId, userId)); } -const ensureTagOwnership = experimental_trpcMiddleware<{ +export const ensureTagOwnership = experimental_trpcMiddleware<{ ctx: Context; input: { tagId: string }; }>().create(async (opts) => { -- cgit v1.2.3-70-g09d2