aboutsummaryrefslogtreecommitdiffstats
path: root/packages/trpc
diff options
context:
space:
mode:
authorMohamedBassem <me@mbassem.com>2024-09-29 01:24:04 +0000
committerMohamedBassem <me@mbassem.com>2024-09-29 01:24:04 +0000
commit36fb5a4c63aada8e8107b8e9d97a6ba128d13494 (patch)
treec1bb803b55a751cf97766a0be08691c5589aef55 /packages/trpc
parent57f5faa7b5ba7a43bb09555741a207c0113e9d62 (diff)
downloadkarakeep-36fb5a4c63aada8e8107b8e9d97a6ba128d13494.tar.zst
feature(web): Add the ability to customize the inference prompts. Fixes #170
Diffstat (limited to 'packages/trpc')
-rw-r--r--packages/trpc/routers/_app.ts2
-rw-r--r--packages/trpc/routers/prompts.ts114
2 files changed, 116 insertions, 0 deletions
diff --git a/packages/trpc/routers/_app.ts b/packages/trpc/routers/_app.ts
index 577b523e..01c92e6a 100644
--- a/packages/trpc/routers/_app.ts
+++ b/packages/trpc/routers/_app.ts
@@ -3,6 +3,7 @@ import { adminAppRouter } from "./admin";
import { apiKeysAppRouter } from "./apiKeys";
import { bookmarksAppRouter } from "./bookmarks";
import { listsAppRouter } from "./lists";
+import { promptsAppRouter } from "./prompts";
import { tagsAppRouter } from "./tags";
import { usersAppRouter } from "./users";
@@ -12,6 +13,7 @@ export const appRouter = router({
users: usersAppRouter,
lists: listsAppRouter,
tags: tagsAppRouter,
+ prompts: promptsAppRouter,
admin: adminAppRouter,
});
// export type definition of API
diff --git a/packages/trpc/routers/prompts.ts b/packages/trpc/routers/prompts.ts
new file mode 100644
index 00000000..629d5829
--- /dev/null
+++ b/packages/trpc/routers/prompts.ts
@@ -0,0 +1,114 @@
+import { experimental_trpcMiddleware, TRPCError } from "@trpc/server";
+import { and, eq } from "drizzle-orm";
+import { z } from "zod";
+
+import { customPrompts } from "@hoarder/db/schema";
+import {
+ zNewPromptSchema,
+ zPromptSchema,
+ zUpdatePromptSchema,
+} from "@hoarder/shared/types/prompts";
+
+import { authedProcedure, Context, router } from "../index";
+
+export const ensurePromptOwnership = experimental_trpcMiddleware<{
+ ctx: Context;
+ input: { promptId: string };
+}>().create(async (opts) => {
+ const prompt = await opts.ctx.db.query.customPrompts.findFirst({
+ where: eq(customPrompts.id, opts.input.promptId),
+ columns: {
+ userId: true,
+ },
+ });
+ if (!opts.ctx.user) {
+ throw new TRPCError({
+ code: "UNAUTHORIZED",
+ message: "User is not authorized",
+ });
+ }
+ if (!prompt) {
+ throw new TRPCError({
+ code: "NOT_FOUND",
+ message: "Prompt not found",
+ });
+ }
+ if (prompt.userId != opts.ctx.user.id) {
+ throw new TRPCError({
+ code: "FORBIDDEN",
+ message: "User is not allowed to access resource",
+ });
+ }
+
+ return opts.next();
+});
+
+export const promptsAppRouter = router({
+ create: authedProcedure
+ .input(zNewPromptSchema)
+ .output(zPromptSchema)
+ .mutation(async ({ input, ctx }) => {
+ const [prompt] = await ctx.db
+ .insert(customPrompts)
+ .values({
+ text: input.text,
+ appliesTo: input.appliesTo,
+ userId: ctx.user.id,
+ enabled: true,
+ })
+ .returning();
+ return prompt;
+ }),
+ update: authedProcedure
+ .input(zUpdatePromptSchema)
+ .output(zPromptSchema)
+ .use(ensurePromptOwnership)
+ .mutation(async ({ input, ctx }) => {
+ const res = await ctx.db
+ .update(customPrompts)
+ .set({
+ text: input.text,
+ appliesTo: input.appliesTo,
+ enabled: input.enabled,
+ })
+ .where(
+ and(
+ eq(customPrompts.userId, ctx.user.id),
+ eq(customPrompts.id, input.promptId),
+ ),
+ )
+ .returning();
+ if (res.length == 0) {
+ throw new TRPCError({ code: "NOT_FOUND" });
+ }
+ return res[0];
+ }),
+ list: authedProcedure
+ .output(z.array(zPromptSchema))
+ .query(async ({ ctx }) => {
+ const prompts = await ctx.db.query.customPrompts.findMany({
+ where: eq(customPrompts.userId, ctx.user.id),
+ });
+ return prompts;
+ }),
+ delete: authedProcedure
+ .input(
+ z.object({
+ promptId: z.string(),
+ }),
+ )
+ .use(ensurePromptOwnership)
+ .mutation(async ({ input, ctx }) => {
+ const res = await ctx.db
+ .delete(customPrompts)
+ .where(
+ and(
+ eq(customPrompts.userId, ctx.user.id),
+ eq(customPrompts.id, input.promptId),
+ ),
+ );
+ if (res.changes == 0) {
+ throw new TRPCError({ code: "NOT_FOUND" });
+ }
+ }),
+});