diff options
| author | MohamedBassem <me@mbassem.com> | 2024-09-29 01:24:04 +0000 |
|---|---|---|
| committer | MohamedBassem <me@mbassem.com> | 2024-09-29 01:24:04 +0000 |
| commit | 36fb5a4c63aada8e8107b8e9d97a6ba128d13494 (patch) | |
| tree | c1bb803b55a751cf97766a0be08691c5589aef55 /packages/trpc | |
| parent | 57f5faa7b5ba7a43bb09555741a207c0113e9d62 (diff) | |
| download | karakeep-36fb5a4c63aada8e8107b8e9d97a6ba128d13494.tar.zst | |
feature(web): Add the ability to customize the inference prompts. Fixes #170
Diffstat (limited to 'packages/trpc')
| -rw-r--r-- | packages/trpc/routers/_app.ts | 2 | ||||
| -rw-r--r-- | packages/trpc/routers/prompts.ts | 114 |
2 files changed, 116 insertions, 0 deletions
diff --git a/packages/trpc/routers/_app.ts b/packages/trpc/routers/_app.ts index 577b523e..01c92e6a 100644 --- a/packages/trpc/routers/_app.ts +++ b/packages/trpc/routers/_app.ts @@ -3,6 +3,7 @@ import { adminAppRouter } from "./admin"; import { apiKeysAppRouter } from "./apiKeys"; import { bookmarksAppRouter } from "./bookmarks"; import { listsAppRouter } from "./lists"; +import { promptsAppRouter } from "./prompts"; import { tagsAppRouter } from "./tags"; import { usersAppRouter } from "./users"; @@ -12,6 +13,7 @@ export const appRouter = router({ users: usersAppRouter, lists: listsAppRouter, tags: tagsAppRouter, + prompts: promptsAppRouter, admin: adminAppRouter, }); // export type definition of API diff --git a/packages/trpc/routers/prompts.ts b/packages/trpc/routers/prompts.ts new file mode 100644 index 00000000..629d5829 --- /dev/null +++ b/packages/trpc/routers/prompts.ts @@ -0,0 +1,114 @@ +import { experimental_trpcMiddleware, TRPCError } from "@trpc/server"; +import { and, eq } from "drizzle-orm"; +import { z } from "zod"; + +import { customPrompts } from "@hoarder/db/schema"; +import { + zNewPromptSchema, + zPromptSchema, + zUpdatePromptSchema, +} from "@hoarder/shared/types/prompts"; + +import { authedProcedure, Context, router } from "../index"; + +export const ensurePromptOwnership = experimental_trpcMiddleware<{ + ctx: Context; + input: { promptId: string }; +}>().create(async (opts) => { + const prompt = await opts.ctx.db.query.customPrompts.findFirst({ + where: eq(customPrompts.id, opts.input.promptId), + columns: { + userId: true, + }, + }); + if (!opts.ctx.user) { + throw new TRPCError({ + code: "UNAUTHORIZED", + message: "User is not authorized", + }); + } + if (!prompt) { + throw new TRPCError({ + code: "NOT_FOUND", + message: "Prompt not found", + }); + } + if (prompt.userId != opts.ctx.user.id) { + throw new TRPCError({ + code: "FORBIDDEN", + message: "User is not allowed to access resource", + }); + } + + return opts.next(); +}); + +export const promptsAppRouter = router({ + create: authedProcedure + .input(zNewPromptSchema) + .output(zPromptSchema) + .mutation(async ({ input, ctx }) => { + const [prompt] = await ctx.db + .insert(customPrompts) + .values({ + text: input.text, + appliesTo: input.appliesTo, + userId: ctx.user.id, + enabled: true, + }) + .returning(); + return prompt; + }), + update: authedProcedure + .input(zUpdatePromptSchema) + .output(zPromptSchema) + .use(ensurePromptOwnership) + .mutation(async ({ input, ctx }) => { + const res = await ctx.db + .update(customPrompts) + .set({ + text: input.text, + appliesTo: input.appliesTo, + enabled: input.enabled, + }) + .where( + and( + eq(customPrompts.userId, ctx.user.id), + eq(customPrompts.id, input.promptId), + ), + ) + .returning(); + if (res.length == 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + return res[0]; + }), + list: authedProcedure + .output(z.array(zPromptSchema)) + .query(async ({ ctx }) => { + const prompts = await ctx.db.query.customPrompts.findMany({ + where: eq(customPrompts.userId, ctx.user.id), + }); + return prompts; + }), + delete: authedProcedure + .input( + z.object({ + promptId: z.string(), + }), + ) + .use(ensurePromptOwnership) + .mutation(async ({ input, ctx }) => { + const res = await ctx.db + .delete(customPrompts) + .where( + and( + eq(customPrompts.userId, ctx.user.id), + eq(customPrompts.id, input.promptId), + ), + ); + if (res.changes == 0) { + throw new TRPCError({ code: "NOT_FOUND" }); + } + }), +}); |
