diff --git a/workspaces/lightspeed/.changeset/many-toys-sing.md b/workspaces/lightspeed/.changeset/many-toys-sing.md new file mode 100644 index 0000000000..f9c953c8fd --- /dev/null +++ b/workspaces/lightspeed/.changeset/many-toys-sing.md @@ -0,0 +1,6 @@ +--- +'@red-hat-developer-hub/backstage-plugin-lightspeed-backend': patch +'@red-hat-developer-hub/backstage-plugin-lightspeed': patch +--- + +Add stop button to interrupt a streaming conversation diff --git a/workspaces/lightspeed/plugins/lightspeed-backend/__fixtures__/lcsHandlers.ts b/workspaces/lightspeed/plugins/lightspeed-backend/__fixtures__/lcsHandlers.ts index e07f70273b..2bec5a88a3 100644 --- a/workspaces/lightspeed/plugins/lightspeed-backend/__fixtures__/lcsHandlers.ts +++ b/workspaces/lightspeed/plugins/lightspeed-backend/__fixtures__/lcsHandlers.ts @@ -126,6 +126,10 @@ export const lcsHandlers: HttpHandler[] = [ return HttpResponse.json(response); }), + http.post(`${LOCAL_LCS_ADDR}/v1/streaming_query/interrupt`, () => { + return HttpResponse.json({ success: true }); + }), + http.get( `${LOCAL_LCS_ADDR}/v2/conversations/:conversation_id`, ({ params }) => { diff --git a/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.test.ts b/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.test.ts index 522e1cdee8..6df5eff2a3 100644 --- a/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.test.ts +++ b/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.test.ts @@ -705,4 +705,30 @@ describe('lightspeed router tests', () => { expect(response.statusCode).toEqual(500); }); }); + + describe('POST /v1/query/interrupt', () => { + it('returns success when interrupt succeeds', async () => { + const backendServer = await startBackendServer(); + + const response = await request(backendServer) + .post('/api/lightspeed/v1/query/interrupt') + .send({ request_id: 'req-123' }); + + expect(response.statusCode).toEqual(200); + expect(response.body).toEqual({ success: true }); + }); + + it('returns 403 when user lacks permission', async () => { + const backendServer = await startBackendServer( + undefined, + AuthorizeResult.DENY, + ); + + const response = await request(backendServer) + .post('/api/lightspeed/v1/query/interrupt') + .send({ request_id: 'req-123' }); + + expect(response.statusCode).toEqual(403); + }); + }); }); diff --git a/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.ts b/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.ts index 1e8e9746df..6c4e51770c 100644 --- a/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.ts +++ b/workspaces/lightspeed/plugins/lightspeed-backend/src/service/router.ts @@ -390,7 +390,11 @@ export async function createRouter( // ─── Proxy Middleware (existing) ──────────────────────────────────── router.use('/', async (req, res, next) => { - const passthroughPaths = ['/v1/query', '/v1/feedback']; + const passthroughPaths = [ + '/v1/query', + '/v1/query/interrupt', + '/v1/feedback', + ]; // Skip middleware for ai-notebooks routes and specific paths if ( req.path.startsWith('/ai-notebooks') || @@ -512,6 +516,47 @@ export async function createRouter( } } }); + + router.post('/v1/query/interrupt', async (request, response) => { + try { + const credentials = await httpAuth.credentials(request); + const userEntity = await userInfo.getUserInfo(credentials); + const user_id = userEntity.userEntityRef; + await authorizer.authorizeUser( + lightspeedChatCreatePermission, + credentials, + ); + const userQueryParam = `user_id=${encodeURIComponent(user_id)}`; + const requestBody = JSON.stringify(request.body); + const fetchResponse = await fetch( + `http://0.0.0.0:${port}/v1/streaming_query/interrupt?${userQueryParam}`, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: requestBody, + }, + ); + if (!fetchResponse.ok) { + const errorBody = await fetchResponse.json(); + const errormsg = `Error from lightspeed-core server: ${errorBody.error?.message || errorBody?.detail?.cause || 'Unknown error'}`; + logger.error(errormsg); + response.status(500).json({ error: errormsg }); + return; + } + response.status(fetchResponse.status).json(await fetchResponse.json()); + } catch (error) { + const errormsg = `Error while interrupting query: ${error}`; + logger.error(errormsg); + if (error instanceof NotAllowedError) { + response.status(403).json({ error: error.message }); + } else { + response.status(500).json({ error: error }); + } + } + }); + router.post( '/v1/query', validateCompletionsRequest, diff --git a/workspaces/lightspeed/plugins/lightspeed/report-alpha.api.md b/workspaces/lightspeed/plugins/lightspeed/report-alpha.api.md index 0f4d2401e4..e396dd9c80 100644 --- a/workspaces/lightspeed/plugins/lightspeed/report-alpha.api.md +++ b/workspaces/lightspeed/plugins/lightspeed/report-alpha.api.md @@ -108,6 +108,7 @@ readonly "conversation.rename": string; readonly "conversation.addToPinnedChats": string; readonly "conversation.removeFromPinnedChats": string; readonly "conversation.announcement.userMessage": string; +readonly "conversation.announcement.responseStopped": string; readonly "user.guest": string; readonly "user.loading": string; readonly "tooltip.attach": string; diff --git a/workspaces/lightspeed/plugins/lightspeed/src/api/LightspeedApiClient.ts b/workspaces/lightspeed/plugins/lightspeed/src/api/LightspeedApiClient.ts index 06080af4fe..a07229f302 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/api/LightspeedApiClient.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/api/LightspeedApiClient.ts @@ -183,6 +183,25 @@ export class LightspeedApiClient implements LightspeedAPI { return response.conversations ?? []; } + async stopMessage(requestId: string): Promise<{ success: boolean }> { + const baseUrl = await this.getBaseUrl(); + const response = await this.fetchApi.fetch( + `${baseUrl}/v1/query/interrupt`, + { + method: 'POST', + headers: { + 'Content-Type': 'application/json', + }, + body: JSON.stringify({ request_id: requestId }), + }, + ); + if (!response.ok) { + throw new Error( + `failed to stop message, status ${response.status}: ${response.statusText}`, + ); + } + return await response.json(); + } async deleteConversation(conversation_id: string) { const baseUrl = await this.getBaseUrl(); diff --git a/workspaces/lightspeed/plugins/lightspeed/src/api/__tests__/LightspeedApiClient.test.ts b/workspaces/lightspeed/plugins/lightspeed/src/api/__tests__/LightspeedApiClient.test.ts index d84a04d6ef..26567e6789 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/api/__tests__/LightspeedApiClient.test.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/api/__tests__/LightspeedApiClient.test.ts @@ -399,6 +399,38 @@ describe('LightspeedApiClient', () => { }); }); + describe('stopMessage', () => { + it('should return success when stop succeeds', async () => { + mockFetchApi.fetch.mockResolvedValue({ + ok: true, + json: jest.fn().mockResolvedValue({ success: true }), + } as unknown as Response); + + const result = await client.stopMessage('req-123'); + + expect(result).toEqual({ success: true }); + expect(mockFetchApi.fetch).toHaveBeenCalledWith( + 'http://localhost:7007/api/lightspeed/v1/query/interrupt', + expect.objectContaining({ + method: 'POST', + body: JSON.stringify({ request_id: 'req-123' }), + }), + ); + }); + + it('should throw error when stop fails', async () => { + mockFetchApi.fetch.mockResolvedValue({ + ok: false, + status: 500, + statusText: 'Internal Server Error', + } as unknown as Response); + + await expect(client.stopMessage('req-123')).rejects.toThrow( + 'failed to stop message, status 500: Internal Server Error', + ); + }); + }); + describe('createMessage', () => { it('should return readable stream reader when message is created', async () => { const mockReader = { diff --git a/workspaces/lightspeed/plugins/lightspeed/src/api/api.ts b/workspaces/lightspeed/plugins/lightspeed/src/api/api.ts index bf88232182..f1f659ad34 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/api/api.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/api/api.ts @@ -50,6 +50,7 @@ export type LightspeedAPI = { getFeedbackStatus: () => Promise; captureFeedback: (payload: CaptureFeedback) => Promise<{ response: string }>; isTopicRestrictionEnabled: () => Promise; + stopMessage: (requestId: string) => Promise<{ success: boolean }>; }; /** diff --git a/workspaces/lightspeed/plugins/lightspeed/src/components/LightSpeedChat.tsx b/workspaces/lightspeed/plugins/lightspeed/src/components/LightSpeedChat.tsx index c6564a85e5..b63210a772 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/components/LightSpeedChat.tsx +++ b/workspaces/lightspeed/plugins/lightspeed/src/components/LightSpeedChat.tsx @@ -83,6 +83,7 @@ import { useNotebookSessions, usePinnedChatsSettings, useSortSettings, + useStopConversation, } from '../hooks'; import { useLightspeedDrawerContext } from '../hooks/useLightspeedDrawerContext'; import { useLightspeedUpdatePermission } from '../hooks/useLightspeedUpdatePermission'; @@ -370,6 +371,7 @@ export const LightspeedChat = ({ [], ); const [conversationId, setConversationId] = useState(''); + const [requestId, setRequestId] = useState(''); const [newChatCreated, setNewChatCreated] = useState(false); const [isSendButtonDisabled, setIsSendButtonDisabled] = useState(false); @@ -379,6 +381,8 @@ export const LightspeedChat = ({ const [isSortSelectOpen, setIsSortSelectOpen] = useState(false); const contentScrollRef = useRef(null); const bottomSentinelRef = useRef(null); + const [messageBarKey, setMessageBarKey] = useState(0); + const wasStoppedByUserRef = useRef(false); const { isReady, lastOpenedId, setLastOpenedId, clearLastOpenedId } = useLastOpenedConversation(user); const { @@ -528,9 +532,16 @@ export const LightspeedChat = ({ setCurrentConversationId(conv_id); }; + const onRequestIdReady = (request_id: string) => { + setRequestId(request_id); + }; + const onComplete = (message: string) => { setIsSendButtonDisabled(false); - setAnnouncement(`Message from Bot: ${message}`); + if (!wasStoppedByUserRef.current) { + setAnnouncement(`Message from Bot: ${message}`); + } + wasStoppedByUserRef.current = false; queryClient.invalidateQueries({ queryKey: ['conversations'], }); @@ -549,12 +560,16 @@ export const LightspeedChat = ({ avatar, onComplete, onStart, + onRequestIdReady, ); const [messages, setMessages] = useState(conversationMessages); const sendMessage = (message: string | number) => { + if (!message.toString().trim()) return; + + wasStoppedByUserRef.current = false; if (conversationId !== TEMP_CONVERSATION_ID) { setNewChatCreated(false); } @@ -693,7 +708,7 @@ export const LightspeedChat = ({ const filteredConversations = Object.entries(categorizedMessages).reduce( (acc, [key, items]) => { const filteredItems = items.filter(item => - item.text + (item.text ?? '') .toLocaleLowerCase('en-US') .includes(targetValue.toLocaleLowerCase('en-US')), ); @@ -952,6 +967,26 @@ export const LightspeedChat = ({ handleFileUpload(data); }; + const { mutate: stopConversation } = useStopConversation(); + + const handleStopButton = () => { + wasStoppedByUserRef.current = true; + if (requestId) { + stopConversation(requestId); + setRequestId(''); + } + setIsSendButtonDisabled(false); + setAnnouncement(t('conversation.announcement.responseStopped')); + const lastUserMessage = [...conversationMessages] + .reverse() + .find((m: { role?: string }) => m.role === 'user'); + const restoredPrompt = (lastUserMessage?.content as string) ?? ''; + setDraftMessage(restoredPrompt.trim()); + if (restoredPrompt) setMessageBarKey(k => k + 1); + setFileContents([]); + setUploadError({ message: null }); + }; + const handleDraftMessage = ( _e: ChangeEvent, value: string | number, @@ -1186,6 +1221,7 @@ export const LightspeedChat = ({ c.conversation_id === conversationId, ); if (conversation) { - setChatName(conversation.topic_summary); - setOriginalChatName(conversation.topic_summary); + setChatName(conversation.topic_summary ?? ''); + setOriginalChatName(conversation.topic_summary ?? ''); } else { setChatName(''); setOriginalChatName(''); diff --git a/workspaces/lightspeed/plugins/lightspeed/src/components/__tests__/LightspeedChat.test.tsx b/workspaces/lightspeed/plugins/lightspeed/src/components/__tests__/LightspeedChat.test.tsx index b2716bf231..1af218e157 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/components/__tests__/LightspeedChat.test.tsx +++ b/workspaces/lightspeed/plugins/lightspeed/src/components/__tests__/LightspeedChat.test.tsx @@ -148,6 +148,7 @@ const mockLightspeedApi = { getFeedbackStatus: jest.fn().mockResolvedValue(false), captureFeedback: jest.fn().mockResolvedValue({ response: 'success' }), isTopicRestrictionEnabled: jest.fn().mockResolvedValue(false), + stopMessage: jest.fn().mockResolvedValue({ success: true }), }; const setupLightspeedChat = () => ( diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversationMessages.test.tsx b/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversationMessages.test.tsx index 6211584b4b..68c5ad6a93 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversationMessages.test.tsx +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversationMessages.test.tsx @@ -19,6 +19,7 @@ import { useApi } from '@backstage/core-plugin-api'; import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; import { act, renderHook, waitFor } from '@testing-library/react'; +import { TEMP_CONVERSATION_ID } from '../../const'; import { getTimestamp } from '../../utils/lightspeed-chatbox-utils'; import { useConversationMessages, @@ -636,4 +637,68 @@ data: {"event": "token", "data": {"id": 2, "token": ""}}\n ); }); }); + + it('should handle interrupted event and migrate temp conversation to new id', async () => { + const newConversationId = 'interrupted-conv-123'; + const onComplete = jest.fn(); + const onStart = jest.fn(); + + mockLightspeedApi.getConversationMessages.mockResolvedValue([]); + + const interruptedStream = createSSEStream([ + { event: 'start', data: { conversation_id: newConversationId } }, + { event: 'token', data: { id: 0, token: 'Partial ', role: 'inference' } }, + { + event: 'interrupted', + data: { conversation_id: newConversationId }, + }, + ]); + + const lightSpeedApi = { + ...mockLightspeedApi, + createMessage: jest.fn().mockResolvedValue({ + read: jest + .fn() + .mockResolvedValueOnce({ + done: false, + value: new TextEncoder().encode(interruptedStream), + }) + .mockResolvedValueOnce({ done: true, value: null }), + }), + }; + + (useApi as jest.Mock).mockReturnValue(lightSpeedApi); + (getTimestamp as jest.Mock).mockReturnValue('01/01/2024, 10:00:00'); + + const { result } = renderHook( + () => + useConversationMessages( + TEMP_CONVERSATION_ID, + 'test-user', + 'gpt-3', + 'openai', + 'user.png', + onComplete, + onStart, + ), + { wrapper }, + ); + + await act(async () => { + await result.current.handleInputPrompt('Hello'); + }); + + await waitFor(() => { + expect(onStart).toHaveBeenCalledWith(newConversationId); + }); + expect(onComplete).toHaveBeenCalledWith('Partial '); + expect(result.current.conversations[newConversationId]).toBeDefined(); + expect(result.current.conversations[newConversationId]).toHaveLength(2); + // Temp removal is deferred via setTimeout(0); wait for it + await waitFor(() => { + expect( + result.current.conversations[TEMP_CONVERSATION_ID], + ).toBeUndefined(); + }); + }); }); diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversations.test.tsx b/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversations.test.tsx new file mode 100644 index 0000000000..5e31ebc433 --- /dev/null +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useConversations.test.tsx @@ -0,0 +1,152 @@ +/* + * Copyright Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { useApi } from '@backstage/core-plugin-api'; + +import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; +import { act, renderHook, waitFor } from '@testing-library/react'; + +import { useConversations } from '../useConversations'; + +jest.mock('@backstage/core-plugin-api', () => ({ + ...jest.requireActual('@backstage/core-plugin-api'), + useApi: jest.fn(), +})); + +const mockGetConversations = jest.fn(); + +const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + }, +}); + +const wrapper = ({ children }: { children?: React.ReactNode }): any => ( + {children} +); + +describe('useConversations', () => { + beforeEach(() => { + jest.clearAllMocks(); + queryClient.clear(); + }); + + it('should fetch conversations successfully', async () => { + const mockData = [ + { + conversation_id: 'conv-1', + last_message_timestamp: 1234567890, + topic_summary: 'Test conversation', + }, + ]; + mockGetConversations.mockResolvedValue(mockData); + + (useApi as jest.Mock).mockReturnValue({ + getConversations: mockGetConversations, + }); + + const { result } = renderHook(() => useConversations(), { wrapper }); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + expect(result.current.data).toEqual(mockData); + expect(mockGetConversations).toHaveBeenCalledTimes(1); + }); + + it('should refetch when topic_summary is null', async () => { + const conversationsWithNullSummary = [ + { + conversation_id: 'conv-1', + last_message_timestamp: 1234567890, + topic_summary: null, + }, + ]; + const conversationsWithSummary = [ + { + conversation_id: 'conv-1', + last_message_timestamp: 1234567890, + topic_summary: 'Generated summary', + }, + ]; + + mockGetConversations + .mockResolvedValueOnce(conversationsWithNullSummary) + .mockResolvedValueOnce(conversationsWithSummary); + + (useApi as jest.Mock).mockReturnValue({ + getConversations: mockGetConversations, + }); + + jest.useFakeTimers(); + + const { result } = renderHook(() => useConversations(), { wrapper }); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + expect(result.current.data).toEqual(conversationsWithNullSummary); + expect(mockGetConversations).toHaveBeenCalledTimes(1); + + await act(async () => { + jest.advanceTimersByTime(2000); + }); + + await waitFor( + () => { + expect(mockGetConversations).toHaveBeenCalledTimes(2); + }, + { timeout: 3000 }, + ); + + await waitFor( + () => { + expect(result.current.data?.[0]?.topic_summary).toBe( + 'Generated summary', + ); + }, + { timeout: 3000 }, + ); + + jest.useRealTimers(); + }); + + it('should not refetch when all topic_summary are set', async () => { + const mockData = [ + { + conversation_id: 'conv-1', + last_message_timestamp: 1234567890, + topic_summary: 'Has summary', + }, + ]; + mockGetConversations.mockResolvedValue(mockData); + + (useApi as jest.Mock).mockReturnValue({ + getConversations: mockGetConversations, + }); + + jest.useFakeTimers(); + + const { result } = renderHook(() => useConversations(), { wrapper }); + + await waitFor(() => expect(result.current.isSuccess).toBe(true)); + expect(mockGetConversations).toHaveBeenCalledTimes(1); + + act(() => { + jest.advanceTimersByTime(5000); + }); + + expect(mockGetConversations).toHaveBeenCalledTimes(1); + + jest.useRealTimers(); + }); +}); diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useStopConversation.test.tsx b/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useStopConversation.test.tsx new file mode 100644 index 0000000000..55502fd54e --- /dev/null +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/__tests__/useStopConversation.test.tsx @@ -0,0 +1,92 @@ +/* + * Copyright Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { useApi } from '@backstage/core-plugin-api'; + +import { QueryClient, QueryClientProvider } from '@tanstack/react-query'; +import { act, renderHook } from '@testing-library/react'; + +import { useStopConversation } from '../useStopConversation'; + +jest.mock('@backstage/core-plugin-api', () => ({ + ...jest.requireActual('@backstage/core-plugin-api'), + useApi: jest.fn(), +})); + +const mockStopMessage = jest.fn(); + +const queryClient = new QueryClient({ + defaultOptions: { + queries: { retry: false }, + }, +}); + +const wrapper = ({ children }: { children?: React.ReactNode }): any => ( + {children} +); + +describe('useStopConversation', () => { + beforeEach(() => { + mockStopMessage.mockResolvedValue({ success: true }); + (useApi as jest.Mock).mockReturnValue({ + stopMessage: mockStopMessage, + }); + }); + + afterEach(() => { + jest.clearAllMocks(); + }); + + it('calls stopMessage API with requestId', async () => { + const requestId = 'req-123'; + + const { result } = renderHook(() => useStopConversation(), { wrapper }); + + await act(async () => { + result.current.mutate(requestId); + }); + + expect(mockStopMessage).toHaveBeenCalledWith(requestId); + }); + + it('returns success when stop succeeds', async () => { + mockStopMessage.mockResolvedValue({ success: true }); + + const { result } = renderHook(() => useStopConversation(), { wrapper }); + + let mutateResult: { success: boolean } | undefined; + await act(async () => { + mutateResult = await result.current.mutateAsync('req-456'); + }); + + expect(mutateResult).toEqual({ success: true }); + }); + + it('handles API errors', async () => { + mockStopMessage.mockRejectedValue(new Error('Stop failed')); + + const consoleSpy = jest.spyOn(console, 'warn').mockImplementation(); + + const { result } = renderHook(() => useStopConversation(), { wrapper }); + + await act(async () => { + result.current.mutate('req-789'); + }); + + expect(consoleSpy).toHaveBeenCalled(); + consoleSpy.mockRestore(); + }); +}); diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/index.ts b/workspaces/lightspeed/plugins/lightspeed/src/hooks/index.ts index f129c777a7..474e11d3d8 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/hooks/index.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/index.ts @@ -29,4 +29,5 @@ export * from './useNotebookSessions'; export * from './usePinnedChatsSettings'; export * from './useRenameNotebook'; export * from './useSortSettings'; +export * from './useStopConversation'; export * from './useTranslation'; diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversationMessages.ts b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversationMessages.ts index ca60832fa0..36e403e5f0 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversationMessages.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversationMessages.ts @@ -103,6 +103,7 @@ export const useConversationMessages = ( avatar: string = userAvatar, onComplete?: (message: string) => void, onStart?: (conversation_id: string) => void, + onRequestIdReady?: (request_id: string) => void, ): UseConversationMessagesReturn => { const { mutateAsync: createMessage } = useCreateConversationMessage(); const scrollToBottomRef = useRef(null); @@ -144,11 +145,7 @@ export const useConversationMessages = ( useFetchConversationMessages(currentConversation); useEffect(() => { - if ( - !Array.isArray(conversationsData) || - (conversationsData.length === 0 && - conversationId !== TEMP_CONVERSATION_ID) - ) + if (!Array.isArray(conversationsData) || conversationsData.length === 0) return; const newConvoIndex: number[] = []; @@ -218,6 +215,7 @@ export const useConversationMessages = ( const handleInputPrompt = useCallback( async (prompt: string, attachments: Attachment[] = []) => { let newConversationId = ''; + let requestId = ''; const conversationTuple = [ createUserMessage({ @@ -266,11 +264,14 @@ export const useConversationMessages = ( }); const decoder = new TextDecoder('utf-8'); - const keepGoing = true; + let streamEnded = false; - while (keepGoing) { + while (!streamEnded) { const { value, done } = await reader.read(); - if (done) break; + if (done) { + streamEnded = true; + break; + } buffer += decoder.decode(value, { stream: true }); @@ -289,6 +290,9 @@ export const useConversationMessages = ( try { const { event, data } = JSON.parse(jsonString); if (event === 'start') { + requestId = data?.request_id; + onRequestIdReady?.(requestId); + if (currentConversation === TEMP_CONVERSATION_ID) { // If the conversation is temp, we need to set the new conversation id newConversationId = data?.conversation_id; @@ -493,6 +497,38 @@ export const useConversationMessages = ( }); } + if (event === 'interrupted') { + if ( + currentConversation === TEMP_CONVERSATION_ID && + data?.conversation_id + ) { + newConversationId = data.conversation_id; + } + setConversations(prevConversations => { + const conversation = + prevConversations[currentConversation] ?? []; + const lastMessageIndex = conversation.length - 1; + const lastMessage = + conversation.length === 0 + ? createBotMessage({ + content: '', + isLoading: false, + timestamp: getTimestamp(Date.now()), + }) + : { ...conversation[lastMessageIndex], isLoading: false }; + const updatedConversation = [ + ...conversation.slice(0, lastMessageIndex), + lastMessage, + ]; + return { + ...prevConversations, + [currentConversation]: updatedConversation, + }; + }); + streamEnded = true; + break; + } + if (event === 'end') { const documents = data?.referenced_documents || []; @@ -539,6 +575,7 @@ export const useConversationMessages = ( } } } + if (streamEnded) break; } } catch (e) { setConversations(prevConversations => { @@ -602,10 +639,13 @@ export const useConversationMessages = ( onStart?.(newConversationId); - setConversations(prev => { - const { temp, ...rest } = prev; - return rest; - }); + // Defer removal so it runs after the sync useEffect updates currentConversation. + setTimeout(() => { + setConversations(prev => { + const { [TEMP_CONVERSATION_ID]: _, ...rest } = prev; + return rest; + }); + }, 0); } }, @@ -614,6 +654,7 @@ export const useConversationMessages = ( userName, onComplete, onStart, + onRequestIdReady, selectedModel, selectedProvider, createMessage, diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversations.ts b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversations.ts index a9df40352b..eeace3584d 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversations.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useConversations.ts @@ -30,6 +30,12 @@ export const useConversations = (): UseQueryResult => { const response = await lightspeedApi.getConversations(); return response; }, + refetchInterval: query => { + const data = query.state.data; + if (!data?.length) return false; + const hasNullSummary = data.some(c => !c.topic_summary); + return hasNullSummary ? 2000 : false; + }, staleTime: 1000 * 60 * 5, // 5 minutes }); }; diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useCreateCoversationMessage.ts b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useCreateCoversationMessage.ts index f73590a8f6..12a3147383 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useCreateCoversationMessage.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useCreateCoversationMessage.ts @@ -43,13 +43,7 @@ export const useCreateConversationMessage = (): UseMutationResult< selectedProvider, currentConversation, attachments, - }: { - prompt: string; - selectedModel: string; - selectedProvider: string; - currentConversation: string; - attachments: Attachment[]; - }) => { + }: CreateMessageVariables) => { if (!currentConversation) { throw new Error('Failed to generate AI response'); } diff --git a/workspaces/lightspeed/plugins/lightspeed/src/hooks/useStopConversation.ts b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useStopConversation.ts new file mode 100644 index 0000000000..98ae2a7797 --- /dev/null +++ b/workspaces/lightspeed/plugins/lightspeed/src/hooks/useStopConversation.ts @@ -0,0 +1,39 @@ +/* + * Copyright Red Hat, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import { useApi } from '@backstage/core-plugin-api'; + +import { useMutation, type UseMutationResult } from '@tanstack/react-query'; + +import { lightspeedApiRef } from '../api/api'; + +export const useStopConversation = (): UseMutationResult< + { success: boolean }, + Error, + string +> => { + const lightspeedApi = useApi(lightspeedApiRef); + + return useMutation({ + mutationFn: async (requestId: string) => { + return await lightspeedApi.stopMessage(requestId); + }, + onError: error => { + // eslint-disable-next-line no-console + console.warn(error); + }, + }); +}; diff --git a/workspaces/lightspeed/plugins/lightspeed/src/translations/de.ts b/workspaces/lightspeed/plugins/lightspeed/src/translations/de.ts index 0b4534c0d8..1ad0074f5e 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/translations/de.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/translations/de.ts @@ -151,6 +151,7 @@ const lightspeedTranslationDe = createTranslationMessages({ 'conversation.announcement.userMessage': 'Nachricht vom Benutzer: {{prompt}}. Nachricht vom Bot wird geladen.', 'user.guest': 'Gast', + 'conversation.announcement.responseStopped': 'Antwort angehalten.', 'user.loading': '...', 'tooltip.attach': 'Anhängen', 'tooltip.send': 'Senden', diff --git a/workspaces/lightspeed/plugins/lightspeed/src/translations/es.ts b/workspaces/lightspeed/plugins/lightspeed/src/translations/es.ts index 68da8ee8c7..14a541b667 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/translations/es.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/translations/es.ts @@ -151,6 +151,7 @@ const lightspeedTranslationEs = createTranslationMessages({ 'conversation.announcement.userMessage': 'Mensaje del usuario: {{prompt}}. El mensaje del bot se está cargando.', 'user.guest': 'Invitado', + 'conversation.announcement.responseStopped': 'Respuesta detenida.', 'user.loading': '...', 'tooltip.attach': 'Adjuntar', 'tooltip.send': 'Enviar', diff --git a/workspaces/lightspeed/plugins/lightspeed/src/translations/fr.ts b/workspaces/lightspeed/plugins/lightspeed/src/translations/fr.ts index 41c56381ec..99c555bfc4 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/translations/fr.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/translations/fr.ts @@ -149,6 +149,7 @@ const lightspeedTranslationFr = createTranslationMessages({ 'conversation.removeFromPinnedChats': 'Détacher', 'conversation.announcement.userMessage': 'Message en provenance de l’utilisateur: {{prompt}}. Message en provenance du Bot en cours de chargement.', + 'conversation.announcement.responseStopped': 'Réponse arrêtée.', 'user.guest': 'Invité', 'user.loading': '...', 'tooltip.attach': 'Attacher', diff --git a/workspaces/lightspeed/plugins/lightspeed/src/translations/it.ts b/workspaces/lightspeed/plugins/lightspeed/src/translations/it.ts index 62701ca383..0b01a2fe3f 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/translations/it.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/translations/it.ts @@ -150,6 +150,7 @@ const lightspeedTranslationIt = createTranslationMessages({ 'conversation.removeFromPinnedChats': 'Sblocca', 'conversation.announcement.userMessage': "Messaggio dall'utente: {{prompt}}. Caricamento in corso del messaggio del bot.", + 'conversation.announcement.responseStopped': 'Risposta interrotta.', 'user.guest': 'Ospite', 'user.loading': '...', 'tooltip.attach': 'Allega', diff --git a/workspaces/lightspeed/plugins/lightspeed/src/translations/ja.ts b/workspaces/lightspeed/plugins/lightspeed/src/translations/ja.ts index e8e866cc0b..448841a4c0 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/translations/ja.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/translations/ja.ts @@ -148,6 +148,7 @@ const lightspeedTranslationJa = createTranslationMessages({ 'conversation.removeFromPinnedChats': '固定解除', 'conversation.announcement.userMessage': 'ユーザーからのメッセージ: {{prompt}}。ボットからのメッセージを読み込んでいます。', + 'conversation.announcement.responseStopped': '応答を停止しました。', 'user.guest': 'ゲスト', 'user.loading': '...', 'tooltip.attach': '割り当て', diff --git a/workspaces/lightspeed/plugins/lightspeed/src/translations/ref.ts b/workspaces/lightspeed/plugins/lightspeed/src/translations/ref.ts index 393e55fd95..e18336c134 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/translations/ref.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/translations/ref.ts @@ -168,6 +168,7 @@ export const lightspeedMessages = { 'conversation.removeFromPinnedChats': 'Unpin', 'conversation.announcement.userMessage': 'Message from User: {{prompt}}. Message from Bot is loading.', + 'conversation.announcement.responseStopped': 'Response stopped.', // User states 'user.guest': 'Guest', diff --git a/workspaces/lightspeed/plugins/lightspeed/src/types.ts b/workspaces/lightspeed/plugins/lightspeed/src/types.ts index aacd36dc03..84ffbea287 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/types.ts +++ b/workspaces/lightspeed/plugins/lightspeed/src/types.ts @@ -112,7 +112,7 @@ export interface BaseMessage { export type ConversationSummary = { conversation_id: string; last_message_timestamp: number; - topic_summary: string; + topic_summary: string | null; }; export enum SupportedFileType { diff --git a/workspaces/lightspeed/plugins/lightspeed/src/utils/lightspeed-chatbox-utils.tsx b/workspaces/lightspeed/plugins/lightspeed/src/utils/lightspeed-chatbox-utils.tsx index aeef450a89..8b5629d24f 100644 --- a/workspaces/lightspeed/plugins/lightspeed/src/utils/lightspeed-chatbox-utils.tsx +++ b/workspaces/lightspeed/plugins/lightspeed/src/utils/lightspeed-chatbox-utils.tsx @@ -15,6 +15,7 @@ */ import PushPinIcon from '@mui/icons-material/PushPin'; import { Conversation, SourcesCardProps } from '@patternfly/chatbot'; +import { Spinner } from '@patternfly/react-core'; import { BaseMessage, @@ -179,15 +180,17 @@ const sortConversations = ( sortOption: SortOption, ): ConversationList => { return [...messages].sort((a, b) => { + const aTopicSummary = a.topic_summary || ''; + const bTopicSummary = b.topic_summary || ''; switch (sortOption) { case 'oldest': return a.last_message_timestamp - b.last_message_timestamp; case 'alphabeticalAsc': - return a.topic_summary.localeCompare(b.topic_summary, undefined, { + return aTopicSummary.localeCompare(bTopicSummary, undefined, { sensitivity: 'base', }); case 'alphabeticalDesc': - return b.topic_summary.localeCompare(a.topic_summary, undefined, { + return bTopicSummary.localeCompare(aTopicSummary, undefined, { sensitivity: 'base', }); case 'newest': @@ -214,7 +217,8 @@ export const getCategorizeMessages = ( sortedMessages.forEach(c => { const message: Conversation = { id: c.conversation_id, - text: c.topic_summary, + text: c.topic_summary ?? '', + icon: c.topic_summary ? undefined : , label: t?.('message.options.label') || 'Options', additionalProps: { 'aria-label': t?.('aria.options.label') || 'Options',