From 99653566f73187631d30cb52a66a982c455c1f9a Mon Sep 17 00:00:00 2001 From: MohamedBassem Date: Sat, 2 Aug 2025 21:41:59 -0700 Subject: refactor: Move webhook, users and tags into models --- packages/trpc/models/tags.ts | 362 +++++++++++++++++ packages/trpc/models/users.ts | 768 ++++++++++++++++++++++++++++++++++++ packages/trpc/models/webhooks.ts | 123 ++++++ packages/trpc/routers/admin.ts | 4 +- packages/trpc/routers/invites.ts | 4 +- packages/trpc/routers/lists.ts | 18 +- packages/trpc/routers/tags.test.ts | 2 +- packages/trpc/routers/tags.ts | 335 ++-------------- packages/trpc/routers/users.test.ts | 262 ++++++++++++ packages/trpc/routers/users.ts | 739 ++-------------------------------- packages/trpc/routers/webhooks.ts | 101 +---- 11 files changed, 1600 insertions(+), 1118 deletions(-) create mode 100644 packages/trpc/models/tags.ts create mode 100644 packages/trpc/models/users.ts create mode 100644 packages/trpc/models/webhooks.ts (limited to 'packages') diff --git a/packages/trpc/models/tags.ts b/packages/trpc/models/tags.ts new file mode 100644 index 00000000..79cd855b --- /dev/null +++ b/packages/trpc/models/tags.ts @@ -0,0 +1,362 @@ +import { TRPCError } from "@trpc/server"; +import { and, eq, inArray, notExists } from "drizzle-orm"; +import { z } from "zod"; + +import type { ZAttachedByEnum } from "@karakeep/shared/types/tags"; +import { SqliteError } from "@karakeep/db"; +import { bookmarkTags, tagsOnBookmarks } from "@karakeep/db/schema"; +import { triggerSearchReindex } from "@karakeep/shared/queues"; +import { + zCreateTagRequestSchema, + zGetTagResponseSchema, + zTagBasicSchema, + zUpdateTagRequestSchema, +} from "@karakeep/shared/types/tags"; + +import { AuthedContext } from ".."; +import { PrivacyAware } from "./privacy"; + +export class Tag implements PrivacyAware { + constructor( + protected ctx: AuthedContext, + public tag: typeof bookmarkTags.$inferSelect, + ) {} + + static async fromId(ctx: AuthedContext, id: string): Promise { + const tag = await ctx.db.query.bookmarkTags.findFirst({ + where: eq(bookmarkTags.id, id), + }); + + if (!tag) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "Tag not found", + }); + } + + // If it exists but belongs to another user, throw forbidden error + if (tag.userId !== ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + + return new Tag(ctx, tag); + } + + static async create( + ctx: AuthedContext, + input: z.infer, + ): Promise { + try { + const [result] = await ctx.db + .insert(bookmarkTags) + .values({ + name: input.name, + userId: ctx.user.id, + }) + .returning(); + + return new Tag(ctx, result); + } catch (e) { + if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Tag name already exists for this user.", + }); + } + throw e; + } + } + + static async getAll(ctx: AuthedContext): Promise { + const tags = await ctx.db.query.bookmarkTags.findMany({ + where: eq(bookmarkTags.userId, ctx.user.id), + }); + + return tags.map((t) => new Tag(ctx, t)); + } + + static async getAllWithStats(ctx: AuthedContext) { + const tags = await ctx.db.query.bookmarkTags.findMany({ + where: eq(bookmarkTags.userId, ctx.user.id), + with: { + tagsOnBookmarks: { + columns: { + attachedBy: true, + }, + }, + }, + }); + + return tags.map(({ tagsOnBookmarks, ...rest }) => ({ + ...rest, + numBookmarks: tagsOnBookmarks.length, + numBookmarksByAttachedType: tagsOnBookmarks.reduce< + Record + >( + (acc, curr) => { + if (curr.attachedBy) { + acc[curr.attachedBy]++; + } + return acc; + }, + { ai: 0, human: 0 }, + ), + })); + } + + static async deleteUnused(ctx: AuthedContext): Promise { + const res = await ctx.db + .delete(bookmarkTags) + .where( + and( + eq(bookmarkTags.userId, ctx.user.id), + notExists( + ctx.db + .select({ id: tagsOnBookmarks.tagId }) + .from(tagsOnBookmarks) + .where(eq(tagsOnBookmarks.tagId, bookmarkTags.id)), + ), + ), + ); + return res.changes; + } + + static async merge( + ctx: AuthedContext, + input: { + intoTagId: string; + fromTagIds: string[]; + }, + ): Promise<{ + mergedIntoTagId: string; + deletedTags: string[]; + }> { + const requestedTags = new Set([input.intoTagId, ...input.fromTagIds]); + if (requestedTags.size === 0) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "No tags provided", + }); + } + if (input.fromTagIds.includes(input.intoTagId)) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Cannot merge tag into itself", + }); + } + + const affectedTags = await ctx.db.query.bookmarkTags.findMany({ + where: and( + eq(bookmarkTags.userId, ctx.user.id), + inArray(bookmarkTags.id, [...requestedTags]), + ), + columns: { + id: true, + userId: true, + }, + }); + + if (affectedTags.some((t) => t.userId !== ctx.user.id)) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + if (affectedTags.length !== requestedTags.size) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "One or more tags not found", + }); + } + + const { deletedTags, affectedBookmarks } = await ctx.db.transaction( + async (trx) => { + const unlinked = await trx + .delete(tagsOnBookmarks) + .where(and(inArray(tagsOnBookmarks.tagId, input.fromTagIds))) + .returning(); + + if (unlinked.length > 0) { + await trx + .insert(tagsOnBookmarks) + .values( + unlinked.map((u) => ({ + ...u, + tagId: input.intoTagId, + })), + ) + .onConflictDoNothing(); + } + + const deletedTags = await trx + .delete(bookmarkTags) + .where( + and( + inArray(bookmarkTags.id, input.fromTagIds), + eq(bookmarkTags.userId, ctx.user.id), + ), + ) + .returning({ id: bookmarkTags.id }); + + return { + deletedTags, + affectedBookmarks: unlinked.map((u) => u.bookmarkId), + }; + }, + ); + + try { + await Promise.all( + affectedBookmarks.map((id) => triggerSearchReindex(id)), + ); + } catch (e) { + console.error("Failed to reindex affected bookmarks", e); + } + + return { + deletedTags: deletedTags.map((t) => t.id), + mergedIntoTagId: input.intoTagId, + }; + } + + ensureCanAccess(ctx: AuthedContext): void { + if (this.tag.userId !== ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + } + + async delete(): Promise { + const affectedBookmarks = await this.ctx.db + .select({ + bookmarkId: tagsOnBookmarks.bookmarkId, + }) + .from(tagsOnBookmarks) + .where(eq(tagsOnBookmarks.tagId, this.tag.id)); + + const res = await this.ctx.db + .delete(bookmarkTags) + .where( + and( + eq(bookmarkTags.id, this.tag.id), + eq(bookmarkTags.userId, this.ctx.user.id), + ), + ); + + if (res.changes === 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + + await Promise.all( + affectedBookmarks.map(({ bookmarkId }) => + triggerSearchReindex(bookmarkId), + ), + ); + } + + async update(input: z.infer): Promise { + try { + const result = await this.ctx.db + .update(bookmarkTags) + .set({ + name: input.name, + }) + .where( + and( + eq(bookmarkTags.id, this.tag.id), + eq(bookmarkTags.userId, this.ctx.user.id), + ), + ) + .returning(); + + if (result.length === 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + + this.tag = result[0]; + + try { + const affectedBookmarks = + await this.ctx.db.query.tagsOnBookmarks.findMany({ + where: eq(tagsOnBookmarks.tagId, this.tag.id), + columns: { + bookmarkId: true, + }, + }); + await Promise.all( + affectedBookmarks + .map((b) => b.bookmarkId) + .map((id) => triggerSearchReindex(id)), + ); + } catch (e) { + console.error("Failed to reindex affected bookmarks", e); + } + } catch (e) { + if (e instanceof SqliteError) { + if (e.code === "SQLITE_CONSTRAINT_UNIQUE") { + throw new TRPCError({ + code: "BAD_REQUEST", + message: + "Tag name already exists. You might want to consider a merge instead.", + }); + } + } + throw e; + } + } + + async getStats(): Promise> { + const res = await this.ctx.db + .select({ + id: bookmarkTags.id, + name: bookmarkTags.name, + attachedBy: tagsOnBookmarks.attachedBy, + }) + .from(bookmarkTags) + .leftJoin(tagsOnBookmarks, eq(bookmarkTags.id, tagsOnBookmarks.tagId)) + .where( + and( + eq(bookmarkTags.id, this.tag.id), + eq(bookmarkTags.userId, this.ctx.user.id), + ), + ); + + if (res.length === 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + + const numBookmarksByAttachedType = res.reduce< + Record + >( + (acc, curr) => { + if (curr.attachedBy) { + acc[curr.attachedBy]++; + } + return acc; + }, + { ai: 0, human: 0 }, + ); + + return { + id: res[0].id, + name: res[0].name, + numBookmarks: Object.values(numBookmarksByAttachedType).reduce( + (s, a) => s + a, + 0, + ), + numBookmarksByAttachedType, + }; + } + + asBasicTag(): z.infer { + return { + id: this.tag.id, + name: this.tag.name, + }; + } +} diff --git a/packages/trpc/models/users.ts b/packages/trpc/models/users.ts new file mode 100644 index 00000000..e6d443a7 --- /dev/null +++ b/packages/trpc/models/users.ts @@ -0,0 +1,768 @@ +import { randomBytes } from "crypto"; +import { TRPCError } from "@trpc/server"; +import { and, count, desc, eq, gte, sql } from "drizzle-orm"; +import invariant from "tiny-invariant"; +import { z } from "zod"; + +import { SqliteError } from "@karakeep/db"; +import { + assets, + bookmarkLinks, + bookmarkLists, + bookmarks, + bookmarkTags, + highlights, + passwordResetTokens, + tagsOnBookmarks, + users, + userSettings, + verificationTokens, +} from "@karakeep/db/schema"; +import { deleteUserAssets } from "@karakeep/shared/assetdb"; +import serverConfig from "@karakeep/shared/config"; +import { + zResetPasswordSchema, + zSignUpSchema, + zUpdateUserSettingsSchema, + zUserSettingsSchema, + zUserStatsResponseSchema, + zWhoAmIResponseSchema, +} from "@karakeep/shared/types/users"; + +import { AuthedContext, Context } from ".."; +import { generatePasswordSalt, hashPassword, validatePassword } from "../auth"; +import { sendPasswordResetEmail, sendVerificationEmail } from "../email"; +import { PrivacyAware } from "./privacy"; + +export class User implements PrivacyAware { + constructor( + protected ctx: AuthedContext, + public user: typeof users.$inferSelect, + ) {} + + static async fromId_DANGEROUS(ctx: AuthedContext, id: string): Promise { + const user = await ctx.db.query.users.findFirst({ + where: eq(users.id, id), + }); + + if (!user) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "User not found", + }); + } + + return new User(ctx, user); + } + + static async fromCtx(ctx: AuthedContext): Promise { + return this.fromId_DANGEROUS(ctx, ctx.user.id); + } + + static async create( + ctx: Context, + input: z.infer, + role?: "user" | "admin", + ) { + const salt = generatePasswordSalt(); + const user = await User.createRaw(ctx.db, { + name: input.name, + email: input.email, + password: await hashPassword(input.password, salt), + salt, + role, + }); + + if (serverConfig.auth.emailVerificationRequired) { + const token = await User.genEmailVerificationToken(ctx.db, input.email); + try { + await sendVerificationEmail(input.email, input.name, token); + } catch (error) { + console.error("Failed to send verification email:", error); + } + } + + return user; + } + + static async createRaw( + db: Context["db"], + input: { + name: string; + email: string; + password?: string; + salt?: string; + role?: "user" | "admin"; + emailVerified?: Date | null; + }, + ) { + return await db.transaction(async (trx) => { + let userRole = input.role; + if (!userRole) { + const [{ count: userCount }] = await trx + .select({ count: count() }) + .from(users); + userRole = userCount === 0 ? "admin" : "user"; + } + + try { + const [result] = await trx + .insert(users) + .values({ + name: input.name, + email: input.email, + password: input.password, + salt: input.salt, + role: userRole, + emailVerified: input.emailVerified, + bookmarkQuota: serverConfig.quotas.free.bookmarkLimit, + storageQuota: serverConfig.quotas.free.assetSizeBytes, + }) + .returning(); + + await trx.insert(userSettings).values({ + userId: result.id, + }); + + return result; + } catch (e) { + if (e instanceof SqliteError) { + if (e.code === "SQLITE_CONSTRAINT_UNIQUE") { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Email is already taken", + }); + } + } + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Something went wrong", + }); + } + }); + } + + static async getAll(ctx: AuthedContext): Promise { + const dbUsers = await ctx.db + .select({ + id: users.id, + name: users.name, + email: users.email, + role: users.role, + password: users.password, + bookmarkQuota: users.bookmarkQuota, + storageQuota: users.storageQuota, + emailVerified: users.emailVerified, + image: users.image, + salt: users.salt, + browserCrawlingEnabled: users.browserCrawlingEnabled, + }) + .from(users); + + return dbUsers.map((u) => new User(ctx, u)); + } + + static async genEmailVerificationToken( + db: Context["db"], + email: string, + ): Promise { + const token = randomBytes(10).toString("hex"); + const expires = new Date(Date.now() + 24 * 60 * 60 * 1000); // 24 hours + + await db.insert(verificationTokens).values({ + identifier: email, + token, + expires, + }); + + return token; + } + + static async verifyEmailToken( + db: Context["db"], + email: string, + token: string, + ): Promise { + const verificationToken = await db.query.verificationTokens.findFirst({ + where: (vt, { and, eq }) => + and(eq(vt.identifier, email), eq(vt.token, token)), + }); + + if (!verificationToken) { + return false; + } + + if (verificationToken.expires < new Date()) { + await db + .delete(verificationTokens) + .where( + and( + eq(verificationTokens.identifier, email), + eq(verificationTokens.token, token), + ), + ); + return false; + } + + await db + .delete(verificationTokens) + .where( + and( + eq(verificationTokens.identifier, email), + eq(verificationTokens.token, token), + ), + ); + + return true; + } + + static async verifyEmail( + ctx: Context, + email: string, + token: string, + ): Promise { + const isValid = await User.verifyEmailToken(ctx.db, email, token); + if (!isValid) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Invalid or expired verification token", + }); + } + + const result = await ctx.db + .update(users) + .set({ emailVerified: new Date() }) + .where(eq(users.email, email)); + + if (result.changes === 0) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "User not found", + }); + } + } + + static async resendVerificationEmail( + ctx: Context, + email: string, + ): Promise { + if ( + !serverConfig.auth.emailVerificationRequired || + !serverConfig.email.smtp + ) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Email verification is not enabled", + }); + } + + const user = await ctx.db.query.users.findFirst({ + where: eq(users.email, email), + }); + + if (!user) { + return; // Don't reveal if user exists or not for security + } + + if (user.emailVerified) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Email is already verified", + }); + } + + const token = await User.genEmailVerificationToken(ctx.db, email); + try { + await sendVerificationEmail(email, user.name, token); + } catch (error) { + console.error("Failed to send verification email:", error); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to send verification email", + }); + } + } + + static async forgotPassword(ctx: Context, email: string): Promise { + if (!serverConfig.email.smtp) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Email service is not configured", + }); + } + + const user = await ctx.db.query.users.findFirst({ + where: eq(users.email, email), + }); + + if (!user || !user.password) { + return; // Don't reveal if user exists or not for security + } + + try { + const token = randomBytes(32).toString("hex"); + const expires = new Date(Date.now() + 60 * 60 * 1000); // 1 hour + + await ctx.db.insert(passwordResetTokens).values({ + userId: user.id, + token, + expires, + }); + + await sendPasswordResetEmail(email, user.name, token); + } catch (error) { + console.error("Failed to send password reset email:", error); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: "Failed to send password reset email", + }); + } + } + + static async resetPassword( + ctx: Context, + input: z.infer, + ): Promise { + const resetToken = await ctx.db.query.passwordResetTokens.findFirst({ + where: eq(passwordResetTokens.token, input.token), + with: { + user: { + columns: { + id: true, + }, + }, + }, + }); + + if (!resetToken) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Invalid or expired reset token", + }); + } + + if (resetToken.expires < new Date()) { + await ctx.db + .delete(passwordResetTokens) + .where(eq(passwordResetTokens.token, input.token)); + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Invalid or expired reset token", + }); + } + + if (!resetToken.user) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "User not found", + }); + } + + const newSalt = generatePasswordSalt(); + const hashedPassword = await hashPassword(input.newPassword, newSalt); + + await ctx.db + .update(users) + .set({ + password: hashedPassword, + salt: newSalt, + }) + .where(eq(users.id, resetToken.user.id)); + + await ctx.db + .delete(passwordResetTokens) + .where(eq(passwordResetTokens.token, input.token)); + } + + ensureCanAccess(ctx: AuthedContext): void { + if (this.user.id !== ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + } + + private static async deleteInternal(db: Context["db"], userId: string) { + const res = await db.delete(users).where(eq(users.id, userId)); + + if (res.changes === 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + + await deleteUserAssets({ userId: userId }); + } + + static async deleteAsAdmin( + adminCtx: AuthedContext, + userId: string, + ): Promise { + invariant(adminCtx.user.role === "admin", "Only admins can delete users"); + await this.deleteInternal(adminCtx.db, userId); + } + + async deleteAccount(password?: string): Promise { + invariant(this.ctx.user.email, "A user always has an email specified"); + + if (this.user.password) { + if (!password) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "Password is required for local accounts", + }); + } + + try { + await validatePassword(this.ctx.user.email, password, this.ctx.db); + } catch { + throw new TRPCError({ + code: "UNAUTHORIZED", + message: "Invalid password", + }); + } + } + + await User.deleteInternal(this.ctx.db, this.user.id); + } + + async changePassword( + currentPassword: string, + newPassword: string, + ): Promise { + invariant(this.ctx.user.email, "A user always has an email specified"); + + try { + const user = await validatePassword( + this.ctx.user.email, + currentPassword, + this.ctx.db, + ); + invariant(user.id === this.ctx.user.id); + } catch { + throw new TRPCError({ code: "UNAUTHORIZED" }); + } + + const newSalt = generatePasswordSalt(); + await this.ctx.db + .update(users) + .set({ + password: await hashPassword(newPassword, newSalt), + salt: newSalt, + }) + .where(eq(users.id, this.user.id)); + } + + async getSettings(): Promise> { + const settings = await this.ctx.db.query.userSettings.findFirst({ + where: eq(userSettings.userId, this.user.id), + }); + + if (!settings) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "User settings not found", + }); + } + + return { + bookmarkClickAction: settings.bookmarkClickAction, + archiveDisplayBehaviour: settings.archiveDisplayBehaviour, + timezone: settings.timezone || "UTC", + }; + } + + async updateSettings( + input: z.infer, + ): Promise { + if (Object.keys(input).length === 0) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: "No settings provided", + }); + } + + await this.ctx.db + .update(userSettings) + .set({ + bookmarkClickAction: input.bookmarkClickAction, + archiveDisplayBehaviour: input.archiveDisplayBehaviour, + timezone: input.timezone, + }) + .where(eq(userSettings.userId, this.user.id)); + } + + async getStats(): Promise> { + const userSet = await this.ctx.db.query.userSettings.findFirst({ + where: eq(userSettings.userId, this.user.id), + }); + const userTimezone = userSet?.timezone || "UTC"; + const now = new Date(); + const weekAgo = new Date(now.getTime() - 7 * 24 * 60 * 60 * 1000); + const monthAgo = new Date(now.getTime() - 30 * 24 * 60 * 60 * 1000); + const yearAgo = new Date(now.getTime() - 365 * 24 * 60 * 60 * 1000); + + const [ + [{ numBookmarks }], + [{ numFavorites }], + [{ numArchived }], + [{ numTags }], + [{ numLists }], + [{ numHighlights }], + bookmarksByType, + topDomains, + [{ totalAssetSize }], + assetsByType, + [{ thisWeek }], + [{ thisMonth }], + [{ thisYear }], + bookmarkTimestamps, + tagUsage, + ] = await Promise.all([ + // Basic counts + this.ctx.db + .select({ numBookmarks: count() }) + .from(bookmarks) + .where(eq(bookmarks.userId, this.user.id)), + this.ctx.db + .select({ numFavorites: count() }) + .from(bookmarks) + .where( + and( + eq(bookmarks.userId, this.user.id), + eq(bookmarks.favourited, true), + ), + ), + this.ctx.db + .select({ numArchived: count() }) + .from(bookmarks) + .where( + and(eq(bookmarks.userId, this.user.id), eq(bookmarks.archived, true)), + ), + this.ctx.db + .select({ numTags: count() }) + .from(bookmarkTags) + .where(eq(bookmarkTags.userId, this.user.id)), + this.ctx.db + .select({ numLists: count() }) + .from(bookmarkLists) + .where(eq(bookmarkLists.userId, this.user.id)), + this.ctx.db + .select({ numHighlights: count() }) + .from(highlights) + .where(eq(highlights.userId, this.user.id)), + + // Bookmarks by type + this.ctx.db + .select({ + type: bookmarks.type, + count: count(), + }) + .from(bookmarks) + .where(eq(bookmarks.userId, this.user.id)) + .groupBy(bookmarks.type), + + // Top domains + this.ctx.db + .select({ + domain: sql`CASE + WHEN ${bookmarkLinks.url} LIKE 'https://%' THEN + CASE + WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') > 0 THEN + SUBSTR(${bookmarkLinks.url}, 9, INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') - 1) + ELSE + SUBSTR(${bookmarkLinks.url}, 9) + END + WHEN ${bookmarkLinks.url} LIKE 'http://%' THEN + CASE + WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') > 0 THEN + SUBSTR(${bookmarkLinks.url}, 8, INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') - 1) + ELSE + SUBSTR(${bookmarkLinks.url}, 8) + END + ELSE + CASE + WHEN INSTR(${bookmarkLinks.url}, '/') > 0 THEN + SUBSTR(${bookmarkLinks.url}, 1, INSTR(${bookmarkLinks.url}, '/') - 1) + ELSE + ${bookmarkLinks.url} + END + END`, + count: count(), + }) + .from(bookmarkLinks) + .innerJoin(bookmarks, eq(bookmarks.id, bookmarkLinks.id)) + .where(eq(bookmarks.userId, this.user.id)) + .groupBy( + sql`CASE + WHEN ${bookmarkLinks.url} LIKE 'https://%' THEN + CASE + WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') > 0 THEN + SUBSTR(${bookmarkLinks.url}, 9, INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') - 1) + ELSE + SUBSTR(${bookmarkLinks.url}, 9) + END + WHEN ${bookmarkLinks.url} LIKE 'http://%' THEN + CASE + WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') > 0 THEN + SUBSTR(${bookmarkLinks.url}, 8, INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') - 1) + ELSE + SUBSTR(${bookmarkLinks.url}, 8) + END + ELSE + CASE + WHEN INSTR(${bookmarkLinks.url}, '/') > 0 THEN + SUBSTR(${bookmarkLinks.url}, 1, INSTR(${bookmarkLinks.url}, '/') - 1) + ELSE + ${bookmarkLinks.url} + END + END`, + ) + .orderBy(desc(count())) + .limit(10), + + // Total asset size + this.ctx.db + .select({ + totalAssetSize: sql`COALESCE(SUM(${assets.size}), 0)`, + }) + .from(assets) + .where(eq(assets.userId, this.user.id)), + + // Assets by type + this.ctx.db + .select({ + type: assets.assetType, + count: count(), + totalSize: sql`COALESCE(SUM(${assets.size}), 0)`, + }) + .from(assets) + .where(eq(assets.userId, this.user.id)) + .groupBy(assets.assetType), + + // Activity stats + this.ctx.db + .select({ thisWeek: count() }) + .from(bookmarks) + .where( + and( + eq(bookmarks.userId, this.user.id), + gte(bookmarks.createdAt, weekAgo), + ), + ), + this.ctx.db + .select({ thisMonth: count() }) + .from(bookmarks) + .where( + and( + eq(bookmarks.userId, this.user.id), + gte(bookmarks.createdAt, monthAgo), + ), + ), + this.ctx.db + .select({ thisYear: count() }) + .from(bookmarks) + .where( + and( + eq(bookmarks.userId, this.user.id), + gte(bookmarks.createdAt, yearAgo), + ), + ), + + // Get all bookmark timestamps for timezone conversion + this.ctx.db + .select({ + createdAt: bookmarks.createdAt, + }) + .from(bookmarks) + .where(eq(bookmarks.userId, this.user.id)), + + // Tag usage + this.ctx.db + .select({ + name: bookmarkTags.name, + count: count(), + }) + .from(bookmarkTags) + .innerJoin(tagsOnBookmarks, eq(tagsOnBookmarks.tagId, bookmarkTags.id)) + .where(eq(bookmarkTags.userId, this.user.id)) + .groupBy(bookmarkTags.name) + .orderBy(desc(count())) + .limit(10), + ]); + + // Process bookmarks by type + const bookmarkTypeMap = { link: 0, text: 0, asset: 0 }; + bookmarksByType.forEach((item) => { + if (item.type in bookmarkTypeMap) { + bookmarkTypeMap[item.type as keyof typeof bookmarkTypeMap] = item.count; + } + }); + + // Process timestamps with user timezone + const hourCounts = Array.from({ length: 24 }, () => 0); + const dayCounts = Array.from({ length: 7 }, () => 0); + + bookmarkTimestamps.forEach(({ createdAt }) => { + if (createdAt) { + const date = new Date(createdAt); + const userDate = new Date( + date.toLocaleString("en-US", { timeZone: userTimezone }), + ); + + const hour = userDate.getHours(); + const day = userDate.getDay(); + + hourCounts[hour]++; + dayCounts[day]++; + } + }); + + const hourlyActivity = Array.from({ length: 24 }, (_, i) => ({ + hour: i, + count: hourCounts[i], + })); + + const dailyActivity = Array.from({ length: 7 }, (_, i) => ({ + day: i, + count: dayCounts[i], + })); + + return { + numBookmarks, + numFavorites, + numArchived, + numTags, + numLists, + numHighlights, + bookmarksByType: bookmarkTypeMap, + topDomains: topDomains.filter((d) => d.domain && d.domain.length > 0), + totalAssetSize: totalAssetSize || 0, + assetsByType, + bookmarkingActivity: { + thisWeek: thisWeek || 0, + thisMonth: thisMonth || 0, + thisYear: thisYear || 0, + byHour: hourlyActivity, + byDayOfWeek: dailyActivity, + }, + tagUsage, + }; + } + + asWhoAmI(): z.infer { + return { + id: this.user.id, + name: this.user.name, + email: this.user.email, + localUser: this.user.password !== null, + }; + } + + asPublicUser() { + const { password, salt: _salt, ...rest } = this.user; + return { + ...rest, + localUser: password !== null, + }; + } +} diff --git a/packages/trpc/models/webhooks.ts b/packages/trpc/models/webhooks.ts new file mode 100644 index 00000000..3a8c7bab --- /dev/null +++ b/packages/trpc/models/webhooks.ts @@ -0,0 +1,123 @@ +import { TRPCError } from "@trpc/server"; +import { and, eq } from "drizzle-orm"; +import { z } from "zod"; + +import { webhooksTable } from "@karakeep/db/schema"; +import { + zNewWebhookSchema, + zUpdateWebhookSchema, + zWebhookSchema, +} from "@karakeep/shared/types/webhooks"; + +import { AuthedContext } from ".."; +import { PrivacyAware } from "./privacy"; + +export class Webhook implements PrivacyAware { + constructor( + protected ctx: AuthedContext, + public webhook: typeof webhooksTable.$inferSelect, + ) {} + + static async fromId(ctx: AuthedContext, id: string): Promise { + const webhook = await ctx.db.query.webhooksTable.findFirst({ + where: eq(webhooksTable.id, id), + }); + + if (!webhook) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "Webhook not found", + }); + } + + // If it exists but belongs to another user, throw forbidden error + if (webhook.userId !== ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + + return new Webhook(ctx, webhook); + } + + static async create( + ctx: AuthedContext, + input: z.infer, + ): Promise { + const [result] = await ctx.db + .insert(webhooksTable) + .values({ + url: input.url, + events: input.events, + token: input.token ?? null, + userId: ctx.user.id, + }) + .returning(); + + return new Webhook(ctx, result); + } + + static async getAll(ctx: AuthedContext): Promise { + const webhooks = await ctx.db.query.webhooksTable.findMany({ + where: eq(webhooksTable.userId, ctx.user.id), + }); + + return webhooks.map((w) => new Webhook(ctx, w)); + } + + ensureCanAccess(ctx: AuthedContext): void { + if (this.webhook.userId !== ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + } + + async delete(): Promise { + const res = await this.ctx.db + .delete(webhooksTable) + .where( + and( + eq(webhooksTable.id, this.webhook.id), + eq(webhooksTable.userId, this.ctx.user.id), + ), + ); + + if (res.changes === 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + } + + async update(input: z.infer): Promise { + const result = await this.ctx.db + .update(webhooksTable) + .set({ + url: input.url, + events: input.events, + token: input.token, + }) + .where( + and( + eq(webhooksTable.id, this.webhook.id), + eq(webhooksTable.userId, this.ctx.user.id), + ), + ) + .returning(); + + if (result.length === 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + + this.webhook = result[0]; + } + + asPublicWebhook(): z.infer { + const { token, ...rest } = this.webhook; + return { + ...rest, + hasToken: token !== null, + }; + } +} diff --git a/packages/trpc/routers/admin.ts b/packages/trpc/routers/admin.ts index 1b069b9e..e005c3dd 100644 --- a/packages/trpc/routers/admin.ts +++ b/packages/trpc/routers/admin.ts @@ -23,7 +23,7 @@ import { import { generatePasswordSalt, hashPassword } from "../auth"; import { adminProcedure, router } from "../index"; -import { createUser } from "./users"; +import { User } from "../models/users"; export const adminAppRouter = router({ stats: adminProcedure @@ -334,7 +334,7 @@ export const adminAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - return createUser(input, ctx, input.role); + return await User.create(ctx, input, input.role); }), updateUser: adminProcedure .input(updateUserSchema) diff --git a/packages/trpc/routers/invites.ts b/packages/trpc/routers/invites.ts index e010fd73..026ee4a2 100644 --- a/packages/trpc/routers/invites.ts +++ b/packages/trpc/routers/invites.ts @@ -13,7 +13,7 @@ import { publicProcedure, router, } from "../index"; -import { createUserRaw } from "./users"; +import { User } from "../models/users"; export const invitesAppRouter = router({ create: adminProcedure @@ -186,7 +186,7 @@ export const invitesAppRouter = router({ } const salt = generatePasswordSalt(); - const user = await createUserRaw(ctx.db, { + const user = await User.createRaw(ctx.db, { name: input.name, email: invite.email, password: await hashPassword(input.password, salt), diff --git a/packages/trpc/routers/lists.ts b/packages/trpc/routers/lists.ts index bb949962..92392448 100644 --- a/packages/trpc/routers/lists.ts +++ b/packages/trpc/routers/lists.ts @@ -144,9 +144,9 @@ export const listsAppRouter = router({ token: z.string(), }), ) - .mutation(async ({ input, ctx }) => { - const list = await List.fromId(ctx, input.listId); - const token = await list.regenRssToken(); + .use(ensureListOwnership) + .mutation(async ({ ctx }) => { + const token = await ctx.list.regenRssToken(); return { token: token! }; }), clearRssToken: authedProcedure @@ -155,9 +155,9 @@ export const listsAppRouter = router({ listId: z.string(), }), ) - .mutation(async ({ input, ctx }) => { - const list = await List.fromId(ctx, input.listId); - await list.clearRssToken(); + .use(ensureListOwnership) + .mutation(async ({ ctx }) => { + await ctx.list.clearRssToken(); }), getRssToken: authedProcedure .input( @@ -170,8 +170,8 @@ export const listsAppRouter = router({ token: z.string().nullable(), }), ) - .query(async ({ input, ctx }) => { - const list = await List.fromId(ctx, input.listId); - return { token: await list.getRssToken() }; + .use(ensureListOwnership) + .query(async ({ ctx }) => { + return { token: await ctx.list.getRssToken() }; }), }); diff --git a/packages/trpc/routers/tags.test.ts b/packages/trpc/routers/tags.test.ts index 1e7118d2..a4d690ee 100644 --- a/packages/trpc/routers/tags.test.ts +++ b/packages/trpc/routers/tags.test.ts @@ -47,7 +47,7 @@ describe("Tags Routes", () => { const api = apiCallers[1].tags; await expect(() => api.delete({ tagId: createdTag.id })).rejects.toThrow( - /Tag not found/, + /User is not allowed to access resource/, ); }); diff --git a/packages/trpc/routers/tags.ts b/packages/trpc/routers/tags.ts index cade4b45..c1217cf9 100644 --- a/packages/trpc/routers/tags.ts +++ b/packages/trpc/routers/tags.ts @@ -1,11 +1,6 @@ -import { experimental_trpcMiddleware, TRPCError } from "@trpc/server"; -import { and, eq, inArray, notExists } from "drizzle-orm"; +import { experimental_trpcMiddleware } from "@trpc/server"; import { z } from "zod"; -import type { ZAttachedByEnum } from "@karakeep/shared/types/tags"; -import { SqliteError } from "@karakeep/db"; -import { bookmarkTags, tagsOnBookmarks } from "@karakeep/db/schema"; -import { triggerSearchReindex } from "@karakeep/shared/queues"; import { zCreateTagRequestSchema, zGetTagResponseSchema, @@ -13,44 +8,21 @@ import { zUpdateTagRequestSchema, } from "@karakeep/shared/types/tags"; -import type { Context } from "../index"; +import type { AuthedContext } from "../index"; import { authedProcedure, router } from "../index"; - -function conditionFromInput(input: { tagId: string }, userId: string) { - return and(eq(bookmarkTags.id, input.tagId), eq(bookmarkTags.userId, userId)); -} +import { Tag } from "../models/tags"; export const ensureTagOwnership = experimental_trpcMiddleware<{ - ctx: Context; + ctx: AuthedContext; input: { tagId: string }; }>().create(async (opts) => { - if (!opts.ctx.user) { - throw new TRPCError({ - code: "UNAUTHORIZED", - message: "User is not authorized", - }); - } - const tag = await opts.ctx.db.query.bookmarkTags.findFirst({ - where: conditionFromInput(opts.input, opts.ctx.user.id), - columns: { - userId: true, + const tag = await Tag.fromId(opts.ctx, opts.input.tagId); + return opts.next({ + ctx: { + ...opts.ctx, + tag, }, }); - - if (!tag) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Tag not found", - }); - } - if (tag.userId != opts.ctx.user.id) { - throw new TRPCError({ - code: "FORBIDDEN", - message: "User is not allowed to access resource", - }); - } - - return opts.next(); }); export const tagsAppRouter = router({ @@ -58,28 +30,8 @@ export const tagsAppRouter = router({ .input(zCreateTagRequestSchema) .output(zTagBasicSchema) .mutation(async ({ input, ctx }) => { - try { - const [newTag] = await ctx.db - .insert(bookmarkTags) - .values({ - name: input.name, - userId: ctx.user.id, - }) - .returning(); - - return { - id: newTag.id, - name: newTag.name, - }; - } catch (e) { - if (e instanceof SqliteError && e.code === "SQLITE_CONSTRAINT_UNIQUE") { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Tag name already exists for this user.", - }); - } - throw e; - } + const tag = await Tag.create(ctx, input); + return tag.asBasicTag(); }), get: authedProcedure @@ -90,47 +42,8 @@ export const tagsAppRouter = router({ ) .output(zGetTagResponseSchema) .use(ensureTagOwnership) - .query(async ({ input, ctx }) => { - const res = await ctx.db - .select({ - id: bookmarkTags.id, - name: bookmarkTags.name, - attachedBy: tagsOnBookmarks.attachedBy, - }) - .from(bookmarkTags) - .leftJoin(tagsOnBookmarks, eq(bookmarkTags.id, tagsOnBookmarks.tagId)) - .where( - and( - conditionFromInput(input, ctx.user.id), - eq(bookmarkTags.userId, ctx.user.id), - ), - ); - - if (res.length == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - - const numBookmarksByAttachedType = res.reduce< - Record - >( - (acc, curr) => { - if (curr.attachedBy) { - acc[curr.attachedBy]++; - } - return acc; - }, - { ai: 0, human: 0 }, - ); - - return { - id: res[0].id, - name: res[0].name, - numBookmarks: Object.values(numBookmarksByAttachedType).reduce( - (s, a) => s + a, - 0, - ), - numBookmarksByAttachedType, - }; + .query(async ({ ctx }) => { + return await ctx.tag.getStats(); }), delete: authedProcedure .input( @@ -139,31 +52,8 @@ export const tagsAppRouter = router({ }), ) .use(ensureTagOwnership) - .mutation(async ({ input, ctx }) => { - const affectedBookmarks = await ctx.db - .select({ - bookmarkId: tagsOnBookmarks.bookmarkId, - }) - .from(tagsOnBookmarks) - .where( - and( - eq(tagsOnBookmarks.tagId, input.tagId), - // Tag ownership is checked in the ensureTagOwnership middleware - // eq(bookmarkTags.userId, ctx.user.id), - ), - ); - - const res = await ctx.db - .delete(bookmarkTags) - .where(conditionFromInput(input, ctx.user.id)); - if (res.changes == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - await Promise.all( - affectedBookmarks.map(({ bookmarkId }) => - triggerSearchReindex(bookmarkId), - ), - ); + .mutation(async ({ ctx }) => { + await ctx.tag.delete(); }), deleteUnused: authedProcedure .output( @@ -172,79 +62,16 @@ export const tagsAppRouter = router({ }), ) .mutation(async ({ ctx }) => { - const res = await ctx.db - .delete(bookmarkTags) - .where( - and( - eq(bookmarkTags.userId, ctx.user.id), - notExists( - ctx.db - .select({ id: tagsOnBookmarks.tagId }) - .from(tagsOnBookmarks) - .where(eq(tagsOnBookmarks.tagId, bookmarkTags.id)), - ), - ), - ); - return { deletedTags: res.changes }; + const deletedCount = await Tag.deleteUnused(ctx); + return { deletedTags: deletedCount }; }), update: authedProcedure .input(zUpdateTagRequestSchema) .output(zTagBasicSchema) .use(ensureTagOwnership) .mutation(async ({ input, ctx }) => { - try { - const res = await ctx.db - .update(bookmarkTags) - .set({ - name: input.name, - }) - .where( - and( - eq(bookmarkTags.id, input.tagId), - eq(bookmarkTags.userId, ctx.user.id), - ), - ) - .returning(); - - if (res.length == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - - try { - const affectedBookmarks = await ctx.db.query.tagsOnBookmarks.findMany( - { - where: eq(tagsOnBookmarks.tagId, input.tagId), - columns: { - bookmarkId: true, - }, - }, - ); - await Promise.all( - affectedBookmarks - .map((b) => b.bookmarkId) - .map((id) => triggerSearchReindex(id)), - ); - } catch (e) { - // Best Effort attempt to reindex affected bookmarks - console.error("Failed to reindex affected bookmarks", e); - } - - return { - id: res[0].id, - name: res[0].name, - }; - } catch (e) { - if (e instanceof SqliteError) { - if (e.code == "SQLITE_CONSTRAINT_UNIQUE") { - throw new TRPCError({ - code: "BAD_REQUEST", - message: - "Tag name already exists. You might want to consider a merge instead.", - }); - } - } - throw e; - } + await ctx.tag.update(input); + return ctx.tag.asBasicTag(); }), merge: authedProcedure .input( @@ -260,99 +87,7 @@ export const tagsAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - const requestedTags = new Set([input.intoTagId, ...input.fromTagIds]); - if (requestedTags.size == 0) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "No tags provided", - }); - } - if (input.fromTagIds.includes(input.intoTagId)) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Cannot merge tag into itself", - }); - } - const affectedTags = await ctx.db.query.bookmarkTags.findMany({ - where: and( - eq(bookmarkTags.userId, ctx.user.id), - inArray(bookmarkTags.id, [...requestedTags]), - ), - columns: { - id: true, - userId: true, - }, - }); - if (affectedTags.some((t) => t.userId != ctx.user.id)) { - throw new TRPCError({ - code: "FORBIDDEN", - message: "User is not allowed to access resource", - }); - } - if (affectedTags.length != requestedTags.size) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "One or more tags not found", - }); - } - - const { deletedTags, affectedBookmarks } = await ctx.db.transaction( - async (trx) => { - // Not entirely sure what happens with a racing transaction that adds a to-be-deleted tag on a bookmark. But it's fine for now. - - // NOTE: You can't really do an update here as you might violate the uniquness constraint if the info tag is already attached to the bookmark. - // There's no OnConflict handling for updates in drizzle. - - // Unlink old tags - const unlinked = await trx - .delete(tagsOnBookmarks) - .where(and(inArray(tagsOnBookmarks.tagId, input.fromTagIds))) - .returning(); - - // Re-attach them to the new tag - if (unlinked.length > 0) { - await trx - .insert(tagsOnBookmarks) - .values( - unlinked.map((u) => ({ - ...u, - tagId: input.intoTagId, - })), - ) - .onConflictDoNothing(); - } - - // Delete the old tags - const deletedTags = await trx - .delete(bookmarkTags) - .where( - and( - inArray(bookmarkTags.id, input.fromTagIds), - eq(bookmarkTags.userId, ctx.user.id), - ), - ) - .returning({ id: bookmarkTags.id }); - - return { - deletedTags, - affectedBookmarks: unlinked.map((u) => u.bookmarkId), - }; - }, - ); - - try { - await Promise.all( - affectedBookmarks.map((id) => triggerSearchReindex(id)), - ); - } catch (e) { - // Best Effort attempt to reindex affected bookmarks - console.error("Failed to reindex affected bookmarks", e); - } - - return { - deletedTags: deletedTags.map((t) => t.id), - mergedIntoTagId: input.intoTagId, - }; + return await Tag.merge(ctx, input); }), list: authedProcedure .output( @@ -361,33 +96,7 @@ export const tagsAppRouter = router({ }), ) .query(async ({ ctx }) => { - const tags = await ctx.db.query.bookmarkTags.findMany({ - where: eq(bookmarkTags.userId, ctx.user.id), - with: { - tagsOnBookmarks: { - columns: { - attachedBy: true, - }, - }, - }, - }); - - const resp = tags.map(({ tagsOnBookmarks, ...rest }) => ({ - ...rest, - numBookmarks: tagsOnBookmarks.length, - numBookmarksByAttachedType: tagsOnBookmarks.reduce< - Record - >( - (acc, curr) => { - if (curr.attachedBy) { - acc[curr.attachedBy]++; - } - return acc; - }, - { ai: 0, human: 0 }, - ), - })); - - return { tags: resp }; + const tags = await Tag.getAllWithStats(ctx); + return { tags }; }), }); diff --git a/packages/trpc/routers/users.test.ts b/packages/trpc/routers/users.test.ts index 1c03f47a..3b16e1a4 100644 --- a/packages/trpc/routers/users.test.ts +++ b/packages/trpc/routers/users.test.ts @@ -21,6 +21,10 @@ vi.mock("@karakeep/shared/config", async (original) => { ...mod, default: { ...mod.default, + auth: { + ...mod.default.auth, + emailVerificationRequired: true, + }, email: { smtp: { host: "test-smtp.example.com", @@ -760,4 +764,262 @@ describe("User Routes", () => { ).rejects.toThrow(/Invalid or expired reset token/); }); }); + + describe("Change Password", () => { + test("changePassword - successful change", async ({ + db, + unauthedAPICaller, + }) => { + const user = await unauthedAPICaller.users.create({ + name: "Test User", + email: "changepass@test.com", + password: "oldpass123", + confirmPassword: "oldpass123", + }); + const caller = getApiCaller(db, user.id, user.email, user.role || "user"); + + await caller.users.changePassword({ + currentPassword: "oldpass123", + newPassword: "newpass456", + }); + + // Password change should succeed without throwing + }); + + test("changePassword - wrong current password", async ({ + db, + unauthedAPICaller, + }) => { + const user = await unauthedAPICaller.users.create({ + name: "Test User", + email: "wrongpass@test.com", + password: "oldpass123", + confirmPassword: "oldpass123", + }); + const caller = getApiCaller(db, user.id, user.email, user.role || "user"); + + await expect(() => + caller.users.changePassword({ + currentPassword: "wrongpassword", + newPassword: "newpass456", + }), + ).rejects.toThrow(); + }); + + test("changePassword - OAuth user (no password)", async ({ + db, + }) => { + // Create OAuth user without password + await db.insert(users).values({ + name: "OAuth User", + email: "oauth@test.com", + password: null, + }); + + const oauthUser = await db + .select() + .from(users) + .where(eq(users.email, "oauth@test.com")) + .then((rows) => rows[0]); + + const caller = getApiCaller(db, oauthUser.id, oauthUser.email, "user"); + + await expect(() => + caller.users.changePassword({ + currentPassword: "anypassword", + newPassword: "newpass456", + }), + ).rejects.toThrow(); + }); + }); + + describe("Delete Account", () => { + test("deleteAccount - with password", async ({ + db, + unauthedAPICaller, + }) => { + const user = await unauthedAPICaller.users.create({ + name: "Test User", + email: "deleteaccount@test.com", + password: "pass1234", + confirmPassword: "pass1234", + }); + const caller = getApiCaller(db, user.id, user.email, user.role || "user"); + + await caller.users.deleteAccount({ + password: "pass1234", + }); + + // Verify user is deleted + const deletedUser = await db + .select() + .from(users) + .where(eq(users.id, user.id)); + expect(deletedUser).toHaveLength(0); + }); + + test("deleteAccount - wrong password", async ({ + db, + unauthedAPICaller, + }) => { + const user = await unauthedAPICaller.users.create({ + name: "Test User", + email: "wrongdeletepass@test.com", + password: "pass1234", + confirmPassword: "pass1234", + }); + const caller = getApiCaller(db, user.id, user.email, user.role || "user"); + + await expect(() => + caller.users.deleteAccount({ + password: "wrongpassword", + }), + ).rejects.toThrow(); + }); + + test("deleteAccount - OAuth user (no password)", async ({ + db, + }) => { + // Create OAuth user without password + await db.insert(users).values({ + name: "OAuth User", + email: "oauthdelete@test.com", + password: null, + }); + + const oauthUser = await db + .select() + .from(users) + .where(eq(users.email, "oauthdelete@test.com")) + .then((rows) => rows[0]); + + const caller = getApiCaller(db, oauthUser.id, oauthUser.email, "user"); + + await caller.users.deleteAccount({}); + + // Verify user is deleted + const deletedUser = await db + .select() + .from(users) + .where(eq(users.id, oauthUser.id)); + expect(deletedUser).toHaveLength(0); + }); + }); + + describe("Who Am I", () => { + test("whoami - returns user info", async ({ + db, + unauthedAPICaller, + }) => { + const user = await unauthedAPICaller.users.create({ + name: "Test User", + email: "whoami@test.com", + password: "pass1234", + confirmPassword: "pass1234", + }); + const caller = getApiCaller(db, user.id, user.email, user.role || "user"); + + const whoami = await caller.users.whoami(); + + expect(whoami.id).toBe(user.id); + expect(whoami.name).toBe("Test User"); + expect(whoami.email).toBe("whoami@test.com"); + expect(whoami.localUser).toBe(true); + }); + + test("whoami - OAuth user", async ({ db }) => { + // Create OAuth user + await db.insert(users).values({ + name: "OAuth User", + email: "oauthwhoami@test.com", + password: null, + }); + + const oauthUser = await db + .select() + .from(users) + .where(eq(users.email, "oauthwhoami@test.com")) + .then((rows) => rows[0]); + + const caller = getApiCaller(db, oauthUser.id, oauthUser.email, "user"); + + const whoami = await caller.users.whoami(); + + expect(whoami.id).toBe(oauthUser.id); + expect(whoami.name).toBe("OAuth User"); + expect(whoami.email).toBe("oauthwhoami@test.com"); + expect(whoami.localUser).toBe(false); + }); + }); + + describe("Email Verification", () => { + test("verifyEmail - invalid token", async ({ + unauthedAPICaller, + }) => { + await expect(() => + unauthedAPICaller.users.verifyEmail({ + email: "nonexistent@test.com", + token: "invalid-token", + }), + ).rejects.toThrow(); + }); + + test("verifyEmail - invalid email format", async ({ + unauthedAPICaller, + }) => { + await expect(() => + unauthedAPICaller.users.verifyEmail({ + email: "invalid-email", + token: "some-token", + }), + ).rejects.toThrow(); + }); + }); + + describe("Resend Verification Email", () => { + test("resendVerificationEmail - existing user", async ({ + unauthedAPICaller, + }) => { + // Create user first + await unauthedAPICaller.users.create({ + name: "Test User", + email: "resend@test.com", + password: "pass1234", + confirmPassword: "pass1234", + }); + + const result = await unauthedAPICaller.users.resendVerificationEmail({ + email: "resend@test.com", + }); + + expect(result.success).toBe(true); + + // Verify that the email function was called + expect(emailModule.sendVerificationEmail).toHaveBeenCalledWith( + "resend@test.com", + "Test User", + expect.any(String), // token + ); + }); + + test("resendVerificationEmail - non-existing user", async ({ + unauthedAPICaller, + }) => { + // Should not reveal if user exists or not + const result = await unauthedAPICaller.users.resendVerificationEmail({ + email: "nonexistent@test.com", + }); + expect(result.success).toBe(true); + }); + + test("resendVerificationEmail - invalid email format", async ({ + unauthedAPICaller, + }) => { + await expect(() => + unauthedAPICaller.users.resendVerificationEmail({ + email: "invalid-email", + }), + ).rejects.toThrow(); + }); + }); }); diff --git a/packages/trpc/routers/users.ts b/packages/trpc/routers/users.ts index 6aa12454..5ce9c67e 100644 --- a/packages/trpc/routers/users.ts +++ b/packages/trpc/routers/users.ts @@ -1,24 +1,6 @@ -import { randomBytes } from "crypto"; import { TRPCError } from "@trpc/server"; -import { and, count, desc, eq, gte, sql } from "drizzle-orm"; -import invariant from "tiny-invariant"; import { z } from "zod"; -import { SqliteError } from "@karakeep/db"; -import { - assets, - bookmarkLinks, - bookmarkLists, - bookmarks, - bookmarkTags, - highlights, - passwordResetTokens, - tagsOnBookmarks, - users, - userSettings, - verificationTokens, -} from "@karakeep/db/schema"; -import { deleteUserAssets } from "@karakeep/shared/assetdb"; import serverConfig from "@karakeep/shared/config"; import { zResetPasswordSchema, @@ -29,160 +11,14 @@ import { zWhoAmIResponseSchema, } from "@karakeep/shared/types/users"; -import { generatePasswordSalt, hashPassword, validatePassword } from "../auth"; -import { sendPasswordResetEmail, sendVerificationEmail } from "../email"; import { adminProcedure, authedProcedure, - Context, createRateLimitMiddleware, publicProcedure, router, } from "../index"; - -async function genEmailVerificationToken(db: Context["db"], email: string) { - const token = randomBytes(10).toString("hex"); - const expires = new Date(Date.now() + 24 * 60 * 60 * 1000); // 24 hours - - // Store verification token - await db.insert(verificationTokens).values({ - identifier: email, - token, - expires, - }); - - return token; -} - -async function verifyEmailToken( - db: Context["db"], - email: string, - token: string, -): Promise { - const verificationToken = await db.query.verificationTokens.findFirst({ - where: (vt, { and, eq }) => - and(eq(vt.identifier, email), eq(vt.token, token)), - }); - - if (!verificationToken) { - return false; - } - - if (verificationToken.expires < new Date()) { - // Clean up expired token - await db - .delete(verificationTokens) - .where( - and( - eq(verificationTokens.identifier, email), - eq(verificationTokens.token, token), - ), - ); - return false; - } - - // Clean up used token - await db - .delete(verificationTokens) - .where( - and( - eq(verificationTokens.identifier, email), - eq(verificationTokens.token, token), - ), - ); - - return true; -} - -export async function createUserRaw( - db: Context["db"], - input: { - name: string; - email: string; - password?: string; - salt?: string; - role?: "user" | "admin"; - emailVerified?: Date | null; - }, -) { - return await db.transaction(async (trx) => { - let userRole = input.role; - if (!userRole) { - const [{ count: userCount }] = await trx - .select({ count: count() }) - .from(users); - userRole = userCount == 0 ? "admin" : "user"; - } - - try { - const [result] = await trx - .insert(users) - .values({ - name: input.name, - email: input.email, - password: input.password, - salt: input.salt, - role: userRole, - emailVerified: input.emailVerified, - bookmarkQuota: serverConfig.quotas.free.bookmarkLimit, - storageQuota: serverConfig.quotas.free.assetSizeBytes, - }) - .returning({ - id: users.id, - name: users.name, - email: users.email, - role: users.role, - emailVerified: users.emailVerified, - }); - - // Insert user settings for the new user - await trx.insert(userSettings).values({ - userId: result.id, - }); - - return result; - } catch (e) { - if (e instanceof SqliteError) { - if (e.code == "SQLITE_CONSTRAINT_UNIQUE") { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Email is already taken", - }); - } - } - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Something went wrong", - }); - } - }); -} - -export async function createUser( - input: z.infer, - ctx: Context, - role?: "user" | "admin", -) { - const salt = generatePasswordSalt(); - let user = await createUserRaw(ctx.db, { - name: input.name, - email: input.email, - password: await hashPassword(input.password, salt), - salt, - role, - }); - // Send verification email if required - if (serverConfig.auth.emailVerificationRequired) { - const token = await genEmailVerificationToken(ctx.db, input.email); - try { - await sendVerificationEmail(input.email, input.name, token); - } catch (error) { - console.error("Failed to send verification email:", error); - // Don't fail user creation if email sending fails - } - } - return user; -} +import { User } from "../models/users"; export const usersAppRouter = router({ create: publicProcedure @@ -215,7 +51,13 @@ export const usersAppRouter = router({ message: errorMessage, }); } - return createUser(input, ctx); + const user = await User.create(ctx, input); + return { + id: user.id, + name: user.name, + email: user.email, + role: user.role, + }; }), list: adminProcedure .output( @@ -234,23 +76,9 @@ export const usersAppRouter = router({ }), ) .query(async ({ ctx }) => { - const dbUsers = await ctx.db - .select({ - id: users.id, - name: users.name, - email: users.email, - role: users.role, - password: users.password, - bookmarkQuota: users.bookmarkQuota, - storageQuota: users.storageQuota, - }) - .from(users); - + const users = await User.getAll(ctx); return { - users: dbUsers.map(({ password, ...user }) => ({ - ...user, - localUser: password !== null, - })), + users: users.map((u) => u.asPublicUser()), }; }), changePassword: authedProcedure @@ -261,26 +89,8 @@ export const usersAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - invariant(ctx.user.email, "A user always has an email specified"); - let user; - try { - user = await validatePassword( - ctx.user.email, - input.currentPassword, - ctx.db, - ); - } catch { - throw new TRPCError({ code: "UNAUTHORIZED" }); - } - invariant(user.id, ctx.user.id); - const newSalt = generatePasswordSalt(); - await ctx.db - .update(users) - .set({ - password: await hashPassword(input.newPassword, newSalt), - salt: newSalt, - }) - .where(eq(users.id, ctx.user.id)); + const user = await User.fromCtx(ctx); + await user.changePassword(input.currentPassword, input.newPassword); }), delete: adminProcedure .input( @@ -289,11 +99,7 @@ export const usersAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - const res = await ctx.db.delete(users).where(eq(users.id, input.userId)); - if (res.changes == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - await deleteUserAssets({ userId: input.userId }); + await User.deleteAsAdmin(ctx, input.userId); }), deleteAccount: authedProcedure .input( @@ -302,367 +108,32 @@ export const usersAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - invariant(ctx.user.email, "A user always has an email specified"); - - // Check if user has a password (local account) - const user = await ctx.db.query.users.findFirst({ - where: eq(users.id, ctx.user.id), - }); - - if (!user) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - - // If user has a password, verify it before allowing account deletion - if (user.password) { - if (!input.password) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Password is required for local accounts", - }); - } - - try { - await validatePassword(ctx.user.email, input.password, ctx.db); - } catch { - throw new TRPCError({ - code: "UNAUTHORIZED", - message: "Invalid password", - }); - } - } - - // Delete the user account - const res = await ctx.db.delete(users).where(eq(users.id, ctx.user.id)); - if (res.changes == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - - // Delete user assets - await deleteUserAssets({ userId: ctx.user.id }); + const user = await User.fromCtx(ctx); + await user.deleteAccount(input.password); }), whoami: authedProcedure .output(zWhoAmIResponseSchema) .query(async ({ ctx }) => { - if (!ctx.user.email) { - throw new TRPCError({ code: "UNAUTHORIZED" }); - } - const userDb = await ctx.db.query.users.findFirst({ - where: and(eq(users.id, ctx.user.id), eq(users.email, ctx.user.email)), - }); - if (!userDb) { - throw new TRPCError({ code: "UNAUTHORIZED" }); - } - return { - id: ctx.user.id, - name: ctx.user.name, - email: ctx.user.email, - localUser: userDb.password !== null, - }; + const user = await User.fromCtx(ctx); + return user.asWhoAmI(); }), stats: authedProcedure .output(zUserStatsResponseSchema) .query(async ({ ctx }) => { - // Get user's timezone - const userSet = await ctx.db.query.userSettings.findFirst({ - where: eq(userSettings.userId, ctx.user.id), - }); - const userTimezone = userSet?.timezone || "UTC"; - const now = new Date(); - const weekAgo = new Date(now.getTime() - 7 * 24 * 60 * 60 * 1000); - const monthAgo = new Date(now.getTime() - 30 * 24 * 60 * 60 * 1000); - const yearAgo = new Date(now.getTime() - 365 * 24 * 60 * 60 * 1000); - - const [ - [{ numBookmarks }], - [{ numFavorites }], - [{ numArchived }], - [{ numTags }], - [{ numLists }], - [{ numHighlights }], - bookmarksByType, - topDomains, - [{ totalAssetSize }], - assetsByType, - [{ thisWeek }], - [{ thisMonth }], - [{ thisYear }], - bookmarkTimestamps, - tagUsage, - ] = await Promise.all([ - // Basic counts - ctx.db - .select({ numBookmarks: count() }) - .from(bookmarks) - .where(eq(bookmarks.userId, ctx.user.id)), - ctx.db - .select({ numFavorites: count() }) - .from(bookmarks) - .where( - and( - eq(bookmarks.userId, ctx.user.id), - eq(bookmarks.favourited, true), - ), - ), - ctx.db - .select({ numArchived: count() }) - .from(bookmarks) - .where( - and( - eq(bookmarks.userId, ctx.user.id), - eq(bookmarks.archived, true), - ), - ), - ctx.db - .select({ numTags: count() }) - .from(bookmarkTags) - .where(eq(bookmarkTags.userId, ctx.user.id)), - ctx.db - .select({ numLists: count() }) - .from(bookmarkLists) - .where(eq(bookmarkLists.userId, ctx.user.id)), - ctx.db - .select({ numHighlights: count() }) - .from(highlights) - .where(eq(highlights.userId, ctx.user.id)), - - // Bookmarks by type - ctx.db - .select({ - type: bookmarks.type, - count: count(), - }) - .from(bookmarks) - .where(eq(bookmarks.userId, ctx.user.id)) - .groupBy(bookmarks.type), - - // Top domains - ctx.db - .select({ - domain: sql`CASE - WHEN ${bookmarkLinks.url} LIKE 'https://%' THEN - CASE - WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') > 0 THEN - SUBSTR(${bookmarkLinks.url}, 9, INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') - 1) - ELSE - SUBSTR(${bookmarkLinks.url}, 9) - END - WHEN ${bookmarkLinks.url} LIKE 'http://%' THEN - CASE - WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') > 0 THEN - SUBSTR(${bookmarkLinks.url}, 8, INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') - 1) - ELSE - SUBSTR(${bookmarkLinks.url}, 8) - END - ELSE - CASE - WHEN INSTR(${bookmarkLinks.url}, '/') > 0 THEN - SUBSTR(${bookmarkLinks.url}, 1, INSTR(${bookmarkLinks.url}, '/') - 1) - ELSE - ${bookmarkLinks.url} - END - END`, - count: count(), - }) - .from(bookmarkLinks) - .innerJoin(bookmarks, eq(bookmarks.id, bookmarkLinks.id)) - .where(eq(bookmarks.userId, ctx.user.id)) - .groupBy( - sql`CASE - WHEN ${bookmarkLinks.url} LIKE 'https://%' THEN - CASE - WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') > 0 THEN - SUBSTR(${bookmarkLinks.url}, 9, INSTR(SUBSTR(${bookmarkLinks.url}, 9), '/') - 1) - ELSE - SUBSTR(${bookmarkLinks.url}, 9) - END - WHEN ${bookmarkLinks.url} LIKE 'http://%' THEN - CASE - WHEN INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') > 0 THEN - SUBSTR(${bookmarkLinks.url}, 8, INSTR(SUBSTR(${bookmarkLinks.url}, 8), '/') - 1) - ELSE - SUBSTR(${bookmarkLinks.url}, 8) - END - ELSE - CASE - WHEN INSTR(${bookmarkLinks.url}, '/') > 0 THEN - SUBSTR(${bookmarkLinks.url}, 1, INSTR(${bookmarkLinks.url}, '/') - 1) - ELSE - ${bookmarkLinks.url} - END - END`, - ) - .orderBy(desc(count())) - .limit(10), - - // Total asset size - ctx.db - .select({ - totalAssetSize: sql`COALESCE(SUM(${assets.size}), 0)`, - }) - .from(assets) - .where(eq(assets.userId, ctx.user.id)), - - // Assets by type - ctx.db - .select({ - type: assets.assetType, - count: count(), - totalSize: sql`COALESCE(SUM(${assets.size}), 0)`, - }) - .from(assets) - .where(eq(assets.userId, ctx.user.id)) - .groupBy(assets.assetType), - - // Activity stats - ctx.db - .select({ thisWeek: count() }) - .from(bookmarks) - .where( - and( - eq(bookmarks.userId, ctx.user.id), - gte(bookmarks.createdAt, weekAgo), - ), - ), - ctx.db - .select({ thisMonth: count() }) - .from(bookmarks) - .where( - and( - eq(bookmarks.userId, ctx.user.id), - gte(bookmarks.createdAt, monthAgo), - ), - ), - ctx.db - .select({ thisYear: count() }) - .from(bookmarks) - .where( - and( - eq(bookmarks.userId, ctx.user.id), - gte(bookmarks.createdAt, yearAgo), - ), - ), - - // Get all bookmark timestamps for timezone conversion - ctx.db - .select({ - createdAt: bookmarks.createdAt, - }) - .from(bookmarks) - .where(eq(bookmarks.userId, ctx.user.id)), - - // Tag usage - ctx.db - .select({ - name: bookmarkTags.name, - count: count(), - }) - .from(bookmarkTags) - .innerJoin( - tagsOnBookmarks, - eq(tagsOnBookmarks.tagId, bookmarkTags.id), - ) - .where(eq(bookmarkTags.userId, ctx.user.id)) - .groupBy(bookmarkTags.name) - .orderBy(desc(count())) - .limit(10), - ]); - - // Process bookmarks by type - const bookmarkTypeMap = { link: 0, text: 0, asset: 0 }; - bookmarksByType.forEach((item) => { - if (item.type in bookmarkTypeMap) { - bookmarkTypeMap[item.type as keyof typeof bookmarkTypeMap] = - item.count; - } - }); - - // Process timestamps with user timezone - const hourCounts = Array.from({ length: 24 }, () => 0); - const dayCounts = Array.from({ length: 7 }, () => 0); - - bookmarkTimestamps.forEach(({ createdAt }) => { - if (createdAt) { - // Convert timestamp to user timezone - const date = new Date(createdAt); - const userDate = new Date( - date.toLocaleString("en-US", { timeZone: userTimezone }), - ); - - const hour = userDate.getHours(); - const day = userDate.getDay(); - - hourCounts[hour]++; - dayCounts[day]++; - } - }); - - const hourlyActivity = Array.from({ length: 24 }, (_, i) => ({ - hour: i, - count: hourCounts[i], - })); - - const dailyActivity = Array.from({ length: 7 }, (_, i) => ({ - day: i, - count: dayCounts[i], - })); - - return { - numBookmarks, - numFavorites, - numArchived, - numTags, - numLists, - numHighlights, - bookmarksByType: bookmarkTypeMap, - topDomains: topDomains.filter((d) => d.domain && d.domain.length > 0), - totalAssetSize: totalAssetSize || 0, - assetsByType, - bookmarkingActivity: { - thisWeek: thisWeek || 0, - thisMonth: thisMonth || 0, - thisYear: thisYear || 0, - byHour: hourlyActivity, - byDayOfWeek: dailyActivity, - }, - tagUsage, - }; + const user = await User.fromCtx(ctx); + return await user.getStats(); }), settings: authedProcedure .output(zUserSettingsSchema) .query(async ({ ctx }) => { - const settings = await ctx.db.query.userSettings.findFirst({ - where: eq(userSettings.userId, ctx.user.id), - }); - if (!settings) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "User settings not found", - }); - } - return { - bookmarkClickAction: settings.bookmarkClickAction, - archiveDisplayBehaviour: settings.archiveDisplayBehaviour, - timezone: settings.timezone || "UTC", - }; + const user = await User.fromCtx(ctx); + return await user.getSettings(); }), updateSettings: authedProcedure .input(zUpdateUserSettingsSchema) .mutation(async ({ input, ctx }) => { - if (Object.keys(input).length === 0) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "No settings provided", - }); - } - await ctx.db - .update(userSettings) - .set({ - bookmarkClickAction: input.bookmarkClickAction, - archiveDisplayBehaviour: input.archiveDisplayBehaviour, - timezone: input.timezone, - }) - .where(eq(userSettings.userId, ctx.user.id)); + const user = await User.fromCtx(ctx); + await user.updateSettings(input); }), verifyEmail: publicProcedure .use( @@ -671,7 +142,7 @@ export const usersAppRouter = router({ windowMs: 5 * 60 * 1000, maxRequests: 10, }), - ) // 10 requests per 5 minutes + ) .input( z.object({ email: z.string().email(), @@ -679,27 +150,7 @@ export const usersAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - const isValid = await verifyEmailToken(ctx.db, input.email, input.token); - if (!isValid) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Invalid or expired verification token", - }); - } - - // Update user's emailVerified status - const result = await ctx.db - .update(users) - .set({ emailVerified: new Date() }) - .where(eq(users.email, input.email)); - - if (result.changes === 0) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "User not found", - }); - } - + await User.verifyEmail(ctx, input.email, input.token); return { success: true }; }), resendVerificationEmail: publicProcedure @@ -709,52 +160,15 @@ export const usersAppRouter = router({ windowMs: 5 * 60 * 1000, maxRequests: 3, }), - ) // 3 requests per 5 minutes + ) .input( z.object({ email: z.string().email(), }), ) .mutation(async ({ input, ctx }) => { - if ( - !serverConfig.auth.emailVerificationRequired || - !serverConfig.email.smtp - ) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Email verification is not enabled", - }); - } - - const user = await ctx.db.query.users.findFirst({ - where: eq(users.email, input.email), - }); - - if (!user) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "User not found", - }); - } - - if (user.emailVerified) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Email is already verified", - }); - } - - const token = await genEmailVerificationToken(ctx.db, input.email); - try { - await sendVerificationEmail(input.email, user.name, token); - return { success: true }; - } catch (error) { - console.error("Failed to send verification email:", error); - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Failed to send verification email", - }); - } + await User.resendVerificationEmail(ctx, input.email); + return { success: true }; }), forgotPassword: publicProcedure .use( @@ -770,47 +184,8 @@ export const usersAppRouter = router({ }), ) .mutation(async ({ input, ctx }) => { - if (!serverConfig.email.smtp) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Email service is not configured", - }); - } - - const user = await ctx.db.query.users.findFirst({ - where: eq(users.email, input.email), - }); - - if (!user) { - // Don't reveal if user exists or not for security - return { success: true }; - } - - // Only send reset email for users with passwords (local accounts) - if (!user.password) { - return { success: true }; - } - - try { - const token = randomBytes(32).toString("hex"); - const expires = new Date(Date.now() + 60 * 60 * 1000); // 1 hour - - // Store password reset token - await ctx.db.insert(passwordResetTokens).values({ - userId: user.id, - token, - expires, - }); - - await sendPasswordResetEmail(input.email, user.name, token); - return { success: true }; - } catch (error) { - console.error("Failed to send password reset email:", error); - throw new TRPCError({ - code: "INTERNAL_SERVER_ERROR", - message: "Failed to send password reset email", - }); - } + await User.forgotPassword(ctx, input.email); + return { success: true }; }), resetPassword: publicProcedure .use( @@ -822,59 +197,7 @@ export const usersAppRouter = router({ ) .input(zResetPasswordSchema) .mutation(async ({ input, ctx }) => { - const token = input.token; - const resetToken = await ctx.db.query.passwordResetTokens.findFirst({ - where: eq(passwordResetTokens.token, token), - with: { - user: { - columns: { - id: true, - }, - }, - }, - }); - - if (!resetToken) { - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Invalid or expired reset token", - }); - } - - if (resetToken.expires < new Date()) { - // Clean up expired token - await ctx.db - .delete(passwordResetTokens) - .where(eq(passwordResetTokens.token, token)); - throw new TRPCError({ - code: "BAD_REQUEST", - message: "Invalid or expired reset token", - }); - } - - if (!resetToken.user) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "User not found", - }); - } - - // Generate new password hash - const newSalt = generatePasswordSalt(); - const hashedPassword = await hashPassword(input.newPassword, newSalt); - - // Update user password - await ctx.db - .update(users) - .set({ - password: hashedPassword, - salt: newSalt, - }) - .where(eq(users.id, resetToken.user.id)); - - await ctx.db - .delete(passwordResetTokens) - .where(eq(passwordResetTokens.token, token)); + await User.resetPassword(ctx, input); return { success: true }; }), }); diff --git a/packages/trpc/routers/webhooks.ts b/packages/trpc/routers/webhooks.ts index ab2a6908..5d30969b 100644 --- a/packages/trpc/routers/webhooks.ts +++ b/packages/trpc/routers/webhooks.ts @@ -1,54 +1,27 @@ -import { experimental_trpcMiddleware, TRPCError } from "@trpc/server"; -import { and, eq } from "drizzle-orm"; +import { experimental_trpcMiddleware } from "@trpc/server"; import { z } from "zod"; -import { webhooksTable } from "@karakeep/db/schema"; import { zNewWebhookSchema, zUpdateWebhookSchema, zWebhookSchema, } from "@karakeep/shared/types/webhooks"; -import { authedProcedure, Context, router } from "../index"; - -function adaptWebhook(webhook: typeof webhooksTable.$inferSelect) { - const { token, ...rest } = webhook; - return { - ...rest, - hasToken: token !== null, - }; -} +import type { AuthedContext } from "../index"; +import { authedProcedure, router } from "../index"; +import { Webhook } from "../models/webhooks"; export const ensureWebhookOwnership = experimental_trpcMiddleware<{ - ctx: Context; + ctx: AuthedContext; input: { webhookId: string }; }>().create(async (opts) => { - const webhook = await opts.ctx.db.query.webhooksTable.findFirst({ - where: eq(webhooksTable.id, opts.input.webhookId), - columns: { - userId: true, + const webhook = await Webhook.fromId(opts.ctx, opts.input.webhookId); + return opts.next({ + ctx: { + ...opts.ctx, + webhook, }, }); - if (!opts.ctx.user) { - throw new TRPCError({ - code: "UNAUTHORIZED", - message: "User is not authorized", - }); - } - if (!webhook) { - throw new TRPCError({ - code: "NOT_FOUND", - message: "Webhook not found", - }); - } - if (webhook.userId != opts.ctx.user.id) { - throw new TRPCError({ - code: "FORBIDDEN", - message: "User is not allowed to access resource", - }); - } - - return opts.next(); }); export const webhooksAppRouter = router({ @@ -56,50 +29,22 @@ export const webhooksAppRouter = router({ .input(zNewWebhookSchema) .output(zWebhookSchema) .mutation(async ({ input, ctx }) => { - const [webhook] = await ctx.db - .insert(webhooksTable) - .values({ - url: input.url, - events: input.events, - token: input.token ?? null, - userId: ctx.user.id, - }) - .returning(); - - return adaptWebhook(webhook); + const webhook = await Webhook.create(ctx, input); + return webhook.asPublicWebhook(); }), update: authedProcedure .input(zUpdateWebhookSchema) .output(zWebhookSchema) .use(ensureWebhookOwnership) .mutation(async ({ input, ctx }) => { - const webhook = await ctx.db - .update(webhooksTable) - .set({ - url: input.url, - events: input.events, - token: input.token, - }) - .where( - and( - eq(webhooksTable.userId, ctx.user.id), - eq(webhooksTable.id, input.webhookId), - ), - ) - .returning(); - if (webhook.length == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } - - return adaptWebhook(webhook[0]); + await ctx.webhook.update(input); + return ctx.webhook.asPublicWebhook(); }), list: authedProcedure .output(z.object({ webhooks: z.array(zWebhookSchema) })) .query(async ({ ctx }) => { - const webhooks = await ctx.db.query.webhooksTable.findMany({ - where: eq(webhooksTable.userId, ctx.user.id), - }); - return { webhooks: webhooks.map(adaptWebhook) }; + const webhooks = await Webhook.getAll(ctx); + return { webhooks: webhooks.map((w) => w.asPublicWebhook()) }; }), delete: authedProcedure .input( @@ -108,17 +53,7 @@ export const webhooksAppRouter = router({ }), ) .use(ensureWebhookOwnership) - .mutation(async ({ input, ctx }) => { - const res = await ctx.db - .delete(webhooksTable) - .where( - and( - eq(webhooksTable.userId, ctx.user.id), - eq(webhooksTable.id, input.webhookId), - ), - ); - if (res.changes == 0) { - throw new TRPCError({ code: "NOT_FOUND" }); - } + .mutation(async ({ ctx }) => { + await ctx.webhook.delete(); }), }); -- cgit v1.2.3-70-g09d2