From f84db0a94d78e595332b792446d305717f83b47a Mon Sep 17 00:00:00 2001 From: pradipthaadhi Date: Tue, 31 Mar 2026 11:36:35 +0700 Subject: [PATCH] feat: enhance memory retrieval API with authentication and room validation - Added header validation and room ownership checks to the GET /api/memories/get endpoint. - Implemented error handling for missing room ID, unauthorized access, room not found, and forbidden access. - Created unit tests for various scenarios including authentication, room existence, and memory retrieval success. - Updated useMessageLoader and related hooks to include access token for fetching messages. --- app/api/memories/get/__tests__/route.test.ts | 135 +++++++++++++++++++ app/api/memories/get/route.ts | 20 ++- hooks/useMessageLoader.ts | 13 +- hooks/useVercelChat.ts | 1 + lib/supabase/getClientMessages.tsx | 8 +- 5 files changed, 171 insertions(+), 6 deletions(-) create mode 100644 app/api/memories/get/__tests__/route.test.ts diff --git a/app/api/memories/get/__tests__/route.test.ts b/app/api/memories/get/__tests__/route.test.ts new file mode 100644 index 000000000..bf4cb9780 --- /dev/null +++ b/app/api/memories/get/__tests__/route.test.ts @@ -0,0 +1,135 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { NextRequest } from "next/server"; +import { GET } from "../route"; + +const mockValidateHeaders = vi.fn(); +const mockGetRoom = vi.fn(); +const mockQueryMemories = vi.fn(); + +vi.mock("@/lib/chat/validateHeaders", () => ({ + validateHeaders: (...args: unknown[]) => mockValidateHeaders(...args), +})); + +vi.mock("@/lib/supabase/getRoom", () => ({ + default: (...args: unknown[]) => mockGetRoom(...args), +})); + +vi.mock("@/lib/supabase/queryMemories", () => ({ + default: (...args: unknown[]) => mockQueryMemories(...args), +})); + +describe("GET /api/memories/get", () => { + const roomId = "11111111-1111-1111-1111-111111111111"; + const accountId = "aaaaaaaa-aaaa-aaaa-aaaa-aaaaaaaaaaaa"; + + beforeEach(() => { + vi.clearAllMocks(); + }); + + it("returns 400 when roomId is missing", async () => { + const req = new NextRequest("https://example.com/api/memories/get"); + const res = await GET(req); + expect(res.status).toBe(400); + const body = await res.json(); + expect(body.error).toBe("Room ID is required"); + }); + + it("returns 401 when not authenticated", async () => { + mockValidateHeaders.mockResolvedValueOnce({}); + + const req = new NextRequest( + `https://example.com/api/memories/get?roomId=${roomId}`, + ); + const res = await GET(req); + expect(res.status).toBe(401); + const body = await res.json(); + expect(body.error).toBe("Unauthorized"); + }); + + it("forwards validateHeaders error Response", async () => { + const errorRes = new Response(JSON.stringify({ status: "error" }), { + status: 401, + }); + mockValidateHeaders.mockResolvedValueOnce(errorRes); + + const req = new NextRequest( + `https://example.com/api/memories/get?roomId=${roomId}`, + { headers: { Authorization: "Bearer bad" } }, + ); + const res = await GET(req); + expect(res.status).toBe(401); + }); + + it("returns 404 when room does not exist", async () => { + mockValidateHeaders.mockResolvedValueOnce({ accountId }); + mockGetRoom.mockResolvedValueOnce(null); + + const req = new NextRequest( + `https://example.com/api/memories/get?roomId=${roomId}`, + { headers: { Authorization: "Bearer token" } }, + ); + const res = await GET(req); + expect(res.status).toBe(404); + const body = await res.json(); + expect(body.error).toBe("Room not found"); + }); + + it("returns 403 when room belongs to another account", async () => { + mockValidateHeaders.mockResolvedValueOnce({ accountId }); + mockGetRoom.mockResolvedValueOnce({ + id: roomId, + account_id: "bbbbbbbb-bbbb-bbbb-bbbb-bbbbbbbbbbbb", + }); + + const req = new NextRequest( + `https://example.com/api/memories/get?roomId=${roomId}`, + { headers: { Authorization: "Bearer token" } }, + ); + const res = await GET(req); + expect(res.status).toBe(403); + const body = await res.json(); + expect(body.error).toBe("Forbidden"); + }); + + it("returns 200 with memories when caller owns the room", async () => { + const memories = [{ id: "m1", room_id: roomId, content: {}, updated_at: "" }]; + mockValidateHeaders.mockResolvedValueOnce({ accountId }); + mockGetRoom.mockResolvedValueOnce({ + id: roomId, + account_id: accountId, + }); + mockQueryMemories.mockResolvedValueOnce({ + data: memories, + error: null, + }); + + const req = new NextRequest( + `https://example.com/api/memories/get?roomId=${roomId}`, + { headers: { Authorization: "Bearer token" } }, + ); + const res = await GET(req); + expect(res.status).toBe(200); + const body = await res.json(); + expect(body.data).toEqual(memories); + expect(mockQueryMemories).toHaveBeenCalledWith(roomId, { ascending: true }); + }); + + it("returns 400 when queryMemories fails", async () => { + mockValidateHeaders.mockResolvedValueOnce({ accountId }); + mockGetRoom.mockResolvedValueOnce({ + id: roomId, + account_id: accountId, + }); + mockQueryMemories.mockResolvedValueOnce({ + data: null, + error: { message: "db error" }, + }); + + const req = new NextRequest( + `https://example.com/api/memories/get?roomId=${roomId}`, + { headers: { Authorization: "Bearer token" } }, + ); + const res = await GET(req); + expect(res.status).toBe(400); + }); +}); diff --git a/app/api/memories/get/route.ts b/app/api/memories/get/route.ts index b3cb3cbeb..96fed6454 100644 --- a/app/api/memories/get/route.ts +++ b/app/api/memories/get/route.ts @@ -1,5 +1,7 @@ import { NextRequest } from "next/server"; import queryMemories from "@/lib/supabase/queryMemories"; +import { validateHeaders } from "@/lib/chat/validateHeaders"; +import getRoom from "@/lib/supabase/getRoom"; export async function GET(req: NextRequest) { const roomId = req.nextUrl.searchParams.get("roomId"); @@ -8,9 +10,25 @@ export async function GET(req: NextRequest) { return Response.json({ error: "Room ID is required" }, { status: 400 }); } + const authResult = await validateHeaders(req); + if (authResult instanceof Response) { + return authResult; + } + if (!authResult.accountId) { + return Response.json({ error: "Unauthorized" }, { status: 401 }); + } + + const room = await getRoom(roomId); + if (!room) { + return Response.json({ error: "Room not found" }, { status: 404 }); + } + if (room.account_id !== authResult.accountId) { + return Response.json({ error: "Forbidden" }, { status: 403 }); + } + try { const { data, error } = await queryMemories(roomId, { ascending: true }); - + if (error) { throw error; } diff --git a/hooks/useMessageLoader.ts b/hooks/useMessageLoader.ts index fd98abc01..6d2e8a3e9 100644 --- a/hooks/useMessageLoader.ts +++ b/hooks/useMessageLoader.ts @@ -7,12 +7,14 @@ import getClientMessages from "@/lib/supabase/getClientMessages"; * @param roomId - The room ID to load messages from (undefined to skip loading) * @param userId - The current user ID (messages won't load if user is not authenticated) * @param setMessages - Callback function to set the loaded messages + * @param accessToken - Privy access token for /api/memories/get (required for auth) * @returns Loading state and error information */ export function useMessageLoader( roomId: string | undefined, userId: string | undefined, - setMessages: (messages: UIMessage[]) => void + setMessages: (messages: UIMessage[]) => void, + accessToken: string | null, ) { const [isLoading, setIsLoading] = useState(!!roomId); const [error, setError] = useState(null); @@ -28,12 +30,17 @@ export function useMessageLoader( return; } + if (!accessToken) { + setIsLoading(true); + return; + } + const loadMessages = async () => { setIsLoading(true); setError(null); try { - const initialMessages = await getClientMessages(roomId); + const initialMessages = await getClientMessages(roomId, accessToken); if (initialMessages.length > 0) { setMessages(initialMessages as UIMessage[]); } @@ -48,7 +55,7 @@ export function useMessageLoader( }; loadMessages(); - }, [userId, roomId]); + }, [userId, roomId, accessToken]); return { isLoading, diff --git a/hooks/useVercelChat.ts b/hooks/useVercelChat.ts index e50ad1297..c4b7c502c 100644 --- a/hooks/useVercelChat.ts +++ b/hooks/useVercelChat.ts @@ -249,6 +249,7 @@ export function useVercelChat({ messages.length === 0 ? id : undefined, userId, setMessages, + accessToken, ); // Only show loading state if: diff --git a/lib/supabase/getClientMessages.tsx b/lib/supabase/getClientMessages.tsx index da6d95684..cd6ebfb6b 100644 --- a/lib/supabase/getClientMessages.tsx +++ b/lib/supabase/getClientMessages.tsx @@ -1,6 +1,10 @@ -const getClientMessages = async (chatId: string) => { +const getClientMessages = async (chatId: string, accessToken: string) => { try { - const response = await fetch(`/api/memories/get?roomId=${chatId}`); + const response = await fetch(`/api/memories/get?roomId=${chatId}`, { + headers: { + Authorization: `Bearer ${accessToken}`, + }, + }); const data = await response.json(); const memories = data?.data || [];