diff options
Diffstat (limited to 'packages')
| -rw-r--r-- | packages/db/drizzle.ts | 1 | ||||
| -rw-r--r-- | packages/db/index.ts | 1 | ||||
| -rw-r--r-- | packages/trpc/models/lists.ts | 277 | ||||
| -rw-r--r-- | packages/trpc/models/privacy.ts | 5 | ||||
| -rw-r--r-- | packages/trpc/routers/bookmarks.ts | 29 | ||||
| -rw-r--r-- | packages/trpc/routers/lists.ts | 189 |
6 files changed, 308 insertions, 194 deletions
diff --git a/packages/db/drizzle.ts b/packages/db/drizzle.ts index 4763d9d7..5c441bec 100644 --- a/packages/db/drizzle.ts +++ b/packages/db/drizzle.ts @@ -9,6 +9,7 @@ import dbConfig from "./drizzle.config"; const sqlite = new Database(dbConfig.dbCredentials.url); export const db = drizzle(sqlite, { schema }); +export type DB = typeof db; export function getInMemoryDB(runMigrations: boolean) { const mem = new Database(":memory:"); diff --git a/packages/db/index.ts b/packages/db/index.ts index b86665d2..8a33d488 100644 --- a/packages/db/index.ts +++ b/packages/db/index.ts @@ -5,6 +5,7 @@ import { SQLiteTransaction } from "drizzle-orm/sqlite-core"; import * as schema from "./schema"; export { db } from "./drizzle"; +export type { DB } from "./drizzle"; export * as schema from "./schema"; export { SqliteError } from "better-sqlite3"; diff --git a/packages/trpc/models/lists.ts b/packages/trpc/models/lists.ts new file mode 100644 index 00000000..7870bf90 --- /dev/null +++ b/packages/trpc/models/lists.ts @@ -0,0 +1,277 @@ +import { TRPCError } from "@trpc/server"; +import { and, count, eq } from "drizzle-orm"; +import invariant from "tiny-invariant"; +import { z } from "zod"; + +import { SqliteError } from "@hoarder/db"; +import { bookmarkLists, bookmarksInLists } from "@hoarder/db/schema"; +import { parseSearchQuery } from "@hoarder/shared/searchQueryParser"; +import { + ZBookmarkList, + zEditBookmarkListSchemaWithValidation, + zNewBookmarkListSchema, +} from "@hoarder/shared/types/lists"; + +import { AuthedContext } from ".."; +import { getBookmarkIdsFromMatcher } from "../lib/search"; +import { PrivacyAware } from "./privacy"; + +export abstract class List implements PrivacyAware { + protected constructor( + protected ctx: AuthedContext, + public list: ZBookmarkList & { userId: string }, + ) {} + + private static fromData( + ctx: AuthedContext, + data: ZBookmarkList & { userId: string }, + ) { + if (data.type === "smart") { + return new SmartList(ctx, data); + } else { + return new ManualList(ctx, data); + } + } + + static async fromId( + ctx: AuthedContext, + id: string, + ): Promise<ManualList | SmartList> { + const list = await ctx.db.query.bookmarkLists.findFirst({ + where: and( + eq(bookmarkLists.id, id), + eq(bookmarkLists.userId, ctx.user.id), + ), + }); + + if (!list) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "List not found", + }); + } + if (list.type === "smart") { + return new SmartList(ctx, list); + } else { + return new ManualList(ctx, list); + } + } + + static async create( + ctx: AuthedContext, + input: z.infer<typeof zNewBookmarkListSchema>, + ): Promise<ManualList | SmartList> { + const [result] = await ctx.db + .insert(bookmarkLists) + .values({ + name: input.name, + icon: input.icon, + userId: ctx.user.id, + parentId: input.parentId, + type: input.type, + query: input.query, + }) + .returning(); + return this.fromData(ctx, result); + } + + static async getAll(ctx: AuthedContext): Promise<(ManualList | SmartList)[]> { + const lists = await ctx.db.query.bookmarkLists.findMany({ + where: and(eq(bookmarkLists.userId, ctx.user.id)), + }); + return lists.map((l) => this.fromData(ctx, l)); + } + + static async forBookmark(ctx: AuthedContext, bookmarkId: string) { + const lists = await ctx.db.query.bookmarksInLists.findMany({ + where: and(eq(bookmarksInLists.bookmarkId, bookmarkId)), + with: { + list: true, + }, + }); + invariant(lists.map((l) => l.list.userId).every((id) => id == ctx.user.id)); + return lists.map((l) => this.fromData(ctx, l.list)); + } + + ensureCanAccess(ctx: AuthedContext): void { + if (this.list.userId != ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + } + + async delete() { + const res = await this.ctx.db + .delete(bookmarkLists) + .where( + and( + eq(bookmarkLists.id, this.list.id), + eq(bookmarkLists.userId, this.ctx.user.id), + ), + ); + if (res.changes == 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + } + + async update(input: z.infer<typeof zEditBookmarkListSchemaWithValidation>) { + const result = await this.ctx.db + .update(bookmarkLists) + .set({ + name: input.name, + icon: input.icon, + parentId: input.parentId, + query: input.query, + }) + .where( + and( + eq(bookmarkLists.id, this.list.id), + eq(bookmarkLists.userId, this.ctx.user.id), + ), + ) + .returning(); + if (result.length == 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + return result[0]; + } + + abstract get type(): "manual" | "smart"; + abstract getBookmarkIds(ctx: AuthedContext): Promise<string[]>; + abstract getSize(ctx: AuthedContext): Promise<number>; + abstract addBookmark(bookmarkId: string): Promise<void>; + abstract removeBookmark(bookmarkId: string): Promise<void>; +} + +export class SmartList extends List { + parsedQuery: ReturnType<typeof parseSearchQuery> | null = null; + + constructor(ctx: AuthedContext, list: ZBookmarkList & { userId: string }) { + super(ctx, list); + } + + get type(): "smart" { + invariant(this.list.type === "smart"); + return this.list.type; + } + + get query() { + invariant(this.list.query); + return this.list.query; + } + + getParsedQuery() { + if (!this.parsedQuery) { + const result = parseSearchQuery(this.query); + if (result.result !== "full") { + throw new Error("Invalid smart list query"); + } + this.parsedQuery = result; + } + return this.parsedQuery; + } + + async getBookmarkIds(): Promise<string[]> { + const parsedQuery = this.getParsedQuery(); + if (!parsedQuery.matcher) { + return []; + } + return await getBookmarkIdsFromMatcher(this.ctx, parsedQuery.matcher); + } + + async getSize(): Promise<number> { + return await this.getBookmarkIds().then((ids) => ids.length); + } + + addBookmark(_bookmarkId: string): Promise<void> { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Smart lists cannot be added to", + }); + } + + removeBookmark(_bookmarkId: string): Promise<void> { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Smart lists cannot be removed from", + }); + } +} + +export class ManualList extends List { + constructor(ctx: AuthedContext, list: ZBookmarkList & { userId: string }) { + super(ctx, list); + } + + get type(): "manual" { + invariant(this.list.type === "manual"); + return this.list.type; + } + + async getBookmarkIds(): Promise<string[]> { + const results = await this.ctx.db + .select({ id: bookmarksInLists.bookmarkId }) + .from(bookmarksInLists) + .where(eq(bookmarksInLists.listId, this.list.id)); + return results.map((r) => r.id); + } + + async getSize(): Promise<number> { + const results = await this.ctx.db + .select({ count: count() }) + .from(bookmarksInLists) + .where(eq(bookmarksInLists.listId, this.list.id)); + return results[0].count; + } + + async addBookmark(bookmarkId: string): Promise<void> { + try { + await this.ctx.db.insert(bookmarksInLists).values({ + listId: this.list.id, + bookmarkId, + }); + } catch (e) { + if (e instanceof SqliteError) { + if (e.code == "SQLITE_CONSTRAINT_PRIMARYKEY") { + throw new TRPCError({ + code: "BAD_REQUEST", + message: `Bookmark ${bookmarkId} is already in the list ${this.list.id}`, + }); + } + } + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Something went wrong", + }); + } + } + + async removeBookmark(bookmarkId: string): Promise<void> { + const deleted = await this.ctx.db + .delete(bookmarksInLists) + .where( + and( + eq(bookmarksInLists.listId, this.list.id), + eq(bookmarksInLists.bookmarkId, bookmarkId), + ), + ); + if (deleted.changes == 0) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: `Bookmark ${bookmarkId} is already not in list ${this.list.id}`, + }); + } + } + + async update(input: z.infer<typeof zEditBookmarkListSchemaWithValidation>) { + if (input.query) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Manual lists cannot have a query", + }); + } + return super.update(input); + } +} diff --git a/packages/trpc/models/privacy.ts b/packages/trpc/models/privacy.ts new file mode 100644 index 00000000..e2235f44 --- /dev/null +++ b/packages/trpc/models/privacy.ts @@ -0,0 +1,5 @@ +import { AuthedContext } from ".."; + +export interface PrivacyAware { + ensureCanAccess(ctx: AuthedContext): void; +} diff --git a/packages/trpc/routers/bookmarks.ts b/packages/trpc/routers/bookmarks.ts index 3b2d23ce..63d20625 100644 --- a/packages/trpc/routers/bookmarks.ts +++ b/packages/trpc/routers/bookmarks.ts @@ -26,7 +26,6 @@ import { AssetTypes, bookmarkAssets, bookmarkLinks, - bookmarkLists, bookmarks, bookmarksInLists, bookmarkTags, @@ -70,6 +69,7 @@ import type { AuthedContext, Context } from "../index"; import { authedProcedure, router } from "../index"; import { mapDBAssetTypeToUserType } from "../lib/attachments"; import { getBookmarkIdsFromMatcher } from "../lib/search"; +import { List } from "../models/lists"; import { ensureAssetOwnership } from "./assets"; export const ensureBookmarkOwnership = experimental_trpcMiddleware<{ @@ -652,31 +652,10 @@ export const bookmarksAppRouter = router({ input.limit = DEFAULT_NUM_BOOKMARKS_PER_PAGE; } if (input.listId) { - const list = await ctx.db.query.bookmarkLists.findFirst({ - where: and( - eq(bookmarkLists.id, input.listId), - eq(bookmarkLists.userId, ctx.user.id), - ), - }); - if (!list) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "List not found", - }); - } + const list = await List.fromId(ctx, input.listId); if (list.type === "smart") { - invariant(list.query); - const query = parseSearchQuery(list.query); - if (query.result !== "full") { - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Found an invalid smart list query", - }); - } - if (query.matcher) { - input.ids = await getBookmarkIdsFromMatcher(ctx, query.matcher); - delete input.listId; - } + input.ids = await list.getBookmarkIds(); + delete input.listId; } } diff --git a/packages/trpc/routers/lists.ts b/packages/trpc/routers/lists.ts index ec7cb10f..59441879 100644 --- a/packages/trpc/routers/lists.ts +++ b/packages/trpc/routers/lists.ts @@ -1,51 +1,28 @@ -import assert from "node:assert"; -import { experimental_trpcMiddleware, TRPCError } from "@trpc/server"; -import { and, eq } from "drizzle-orm"; -import invariant from "tiny-invariant"; +import { experimental_trpcMiddleware } from "@trpc/server"; import { z } from "zod"; -import { SqliteError } from "@hoarder/db"; -import { bookmarkLists, bookmarksInLists } from "@hoarder/db/schema"; import { zBookmarkListSchema, zEditBookmarkListSchemaWithValidation, zNewBookmarkListSchema, } from "@hoarder/shared/types/lists"; -import type { Context } from "../index"; +import type { AuthedContext } from "../index"; import { authedProcedure, router } from "../index"; +import { List } from "../models/lists"; import { ensureBookmarkOwnership } from "./bookmarks"; export const ensureListOwnership = experimental_trpcMiddleware<{ - ctx: Context; + ctx: AuthedContext; input: { listId: string }; }>().create(async (opts) => { - const list = await opts.ctx.db.query.bookmarkLists.findFirst({ - where: eq(bookmarkLists.id, opts.input.listId), - columns: { - userId: true, + const list = await List.fromId(opts.ctx, opts.input.listId); + return opts.next({ + ctx: { + ...opts.ctx, + list, }, }); - if (!opts.ctx.user) { - throw new TRPCError({ - code: "UNAUTHORIZED", - message: "User is not authorized", - }); - } - if (!list) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "List not found", - }); - } - if (list.userId != opts.ctx.user.id) { - throw new TRPCError({ - code: "FORBIDDEN", - message: "User is not allowed to access resource", - }); - } - - return opts.next(); }); export const listsAppRouter = router({ @@ -53,59 +30,14 @@ export const listsAppRouter = router({ .input(zNewBookmarkListSchema) .output(zBookmarkListSchema) .mutation(async ({ input, ctx }) => { - const [result] = await ctx.db - .insert(bookmarkLists) - .values({ - name: input.name, - icon: input.icon, - userId: ctx.user.id, - parentId: input.parentId, - type: input.type, - query: input.query, - }) - .returning(); - return result; + return await List.create(ctx, input).then((l) => l.list); }), edit: authedProcedure .input(zEditBookmarkListSchemaWithValidation) .output(zBookmarkListSchema) .use(ensureListOwnership) .mutation(async ({ input, ctx }) => { - if (input.query) { - const list = await ctx.db.query.bookmarkLists.findFirst({ - where: and( - eq(bookmarkLists.id, input.listId), - eq(bookmarkLists.userId, ctx.user.id), - ), - }); - // List must exist given that we passed the ownership check - invariant(list); - if (list.type !== "smart") { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Manual lists cannot have a query", - }); - } - } - const result = await ctx.db - .update(bookmarkLists) - .set({ - name: input.name, - icon: input.icon, - parentId: input.parentId, - query: input.query, - }) - .where( - and( - eq(bookmarkLists.id, input.listId), - eq(bookmarkLists.userId, ctx.user.id), - ), - ) - .returning(); - if (result.length == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - return result[0]; + return await ctx.list.update(input); }), delete: authedProcedure .input( @@ -114,18 +46,8 @@ export const listsAppRouter = router({ }), ) .use(ensureListOwnership) - .mutation(async ({ input, ctx }) => { - const res = await ctx.db - .delete(bookmarkLists) - .where( - and( - eq(bookmarkLists.id, input.listId), - eq(bookmarkLists.userId, ctx.user.id), - ), - ); - if (res.changes == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } + .mutation(async ({ ctx }) => { + await ctx.list.delete(); }), addToList: authedProcedure .input( @@ -137,38 +59,7 @@ export const listsAppRouter = router({ .use(ensureListOwnership) .use(ensureBookmarkOwnership) .mutation(async ({ input, ctx }) => { - const list = await ctx.db.query.bookmarkLists.findFirst({ - where: and( - eq(bookmarkLists.id, input.listId), - eq(bookmarkLists.userId, ctx.user.id), - ), - }); - invariant(list); - if (list.type === "smart") { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Smart lists cannot be added to", - }); - } - try { - await ctx.db.insert(bookmarksInLists).values({ - listId: input.listId, - bookmarkId: input.bookmarkId, - }); - } catch (e) { - if (e instanceof SqliteError) { - if (e.code == "SQLITE_CONSTRAINT_PRIMARYKEY") { - throw new TRPCError({ - code: "BAD_REQUEST", - message: `Bookmark ${input.bookmarkId} is already in the list ${input.listId}`, - }); - } - } - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Something went wrong", - }); - } + await ctx.list.addBookmark(input.bookmarkId); }), removeFromList: authedProcedure .input( @@ -180,20 +71,7 @@ export const listsAppRouter = router({ .use(ensureListOwnership) .use(ensureBookmarkOwnership) .mutation(async ({ input, ctx }) => { - const deleted = await ctx.db - .delete(bookmarksInLists) - .where( - and( - eq(bookmarksInLists.listId, input.listId), - eq(bookmarksInLists.bookmarkId, input.bookmarkId), - ), - ); - if (deleted.changes == 0) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: `Bookmark ${input.bookmarkId} is already not in list ${input.listId}`, - }); - } + await ctx.list.removeBookmark(input.bookmarkId); }), get: authedProcedure .input( @@ -203,25 +81,8 @@ export const listsAppRouter = router({ ) .output(zBookmarkListSchema) .use(ensureListOwnership) - .query(async ({ input, ctx }) => { - const res = await ctx.db.query.bookmarkLists.findFirst({ - where: and( - eq(bookmarkLists.id, input.listId), - eq(bookmarkLists.userId, ctx.user.id), - ), - }); - if (!res) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - - return { - id: res.id, - name: res.name, - icon: res.icon, - parentId: res.parentId, - type: res.type, - query: res.query, - }; + .query(({ ctx }) => { + return ctx.list.list; }), list: authedProcedure .output( @@ -230,11 +91,8 @@ export const listsAppRouter = router({ }), ) .query(async ({ ctx }) => { - const lists = await ctx.db.query.bookmarkLists.findMany({ - where: and(eq(bookmarkLists.userId, ctx.user.id)), - }); - - return { lists }; + const results = await List.getAll(ctx); + return { lists: results.map((l) => l.list) }; }), getListsOfBookmark: authedProcedure .input(z.object({ bookmarkId: z.string() })) @@ -245,14 +103,7 @@ export const listsAppRouter = router({ ) .use(ensureBookmarkOwnership) .query(async ({ input, ctx }) => { - const lists = await ctx.db.query.bookmarksInLists.findMany({ - where: and(eq(bookmarksInLists.bookmarkId, input.bookmarkId)), - with: { - list: true, - }, - }); - assert(lists.map((l) => l.list.userId).every((id) => id == ctx.user.id)); - + const lists = await List.forBookmark(ctx, input.bookmarkId); return { lists: lists.map((l) => l.list) }; }), }); |
