diff --git a/internal/record/recording_https_proxy.go b/internal/record/recording_https_proxy.go index 796cd84..6ec51c8 100644 --- a/internal/record/recording_https_proxy.go +++ b/internal/record/recording_https_proxy.go @@ -70,16 +70,22 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String()) - reqHash, err := r.recordRequest(req) + recReq, err := r.recordRequest(req) if err != nil { fmt.Printf("Error recording request: %v\n", err) http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError) return } + fileName, err := recReq.GetRecordingFileName() + if err != nil { + fmt.Printf("Invalid recording file name: %v\n", err) + http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError) + return + } if req.Header.Get("Upgrade") == "websocket" { fmt.Printf("Upgrading connection to websocket...\n") - r.proxyWebsocket(w, req, reqHash) + r.proxyWebsocket(w, req, fileName) return } @@ -90,7 +96,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req return } - err = r.recordResponse(resp, reqHash, respBody) + err = r.recordResponse(resp, fileName, respBody) if err != nil { fmt.Printf("Error recording response: %v\n", err) @@ -99,10 +105,10 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } } -func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (string, error) { +func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedRequest, error) { recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config) if err != nil { - return "", err + return recordedRequest, err } // Redact headers by key @@ -112,17 +118,17 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (string, error) { recordedRequest.Request = r.redactor.String(recordedRequest.Request) recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body) - reqHash, err := recordedRequest.ComputeSum() + fileName, err := recordedRequest.GetRecordingFileName() if err != nil { - return "", err + fmt.Printf("Invalid recording file name: %v\n", err) + return recordedRequest, err } - - recordPath := filepath.Join(r.recordingDir, reqHash+".req") + recordPath := filepath.Join(r.recordingDir, fileName+".req") err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644) if err != nil { - return "", err + return recordedRequest, err } - return reqHash, nil + return recordedRequest, nil } func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Request) (*http.Response, []byte, error) { @@ -172,7 +178,7 @@ func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Requ return resp, respBodyBytes, nil } -func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, reqHash string, body []byte) error { +func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, fileName string, body []byte) error { recordedResponse, err := store.NewRecordedResponse(resp, body) if err != nil { return err @@ -180,7 +186,7 @@ func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, reqHash string recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body) - recordPath := filepath.Join(r.recordingDir, reqHash+".resp") + recordPath := filepath.Join(r.recordingDir, fileName+".resp") fmt.Printf("Writing response to: %s\n", recordPath) err = os.WriteFile(recordPath, []byte(recordedResponse.Serialize()), 0644) if err != nil { @@ -209,7 +215,7 @@ func replaceRegex(s, regex, replacement string) string { return re.ReplaceAllString(s, replacement) } -func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Request, reqHash string) { +func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Request, fileName string) { conn, clientConn, err := r.upgradeConnectionToWebsocket(w, req) if err != nil { http.Error(w, fmt.Sprintf("Error proxying websocket: %v", err), http.StatusInternalServerError) @@ -224,7 +230,7 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re go pumpWebsocket(clientConn, conn, c, quit, ">") go pumpWebsocket(conn, clientConn, c, quit, "<") - recordPath := filepath.Join(r.recordingDir, reqHash+".websocket") + recordPath := filepath.Join(r.recordingDir, fileName+".websocket") f, err := os.Create(recordPath) if err != nil { fmt.Printf("Error creating websocket recording file: %v\n", err) @@ -262,7 +268,8 @@ func pumpWebsocket(src, dst *websocket.Conn, c chan []byte, quit chan int, prepe quit <- 1 return } - prefix := fmt.Sprintf("%s%d", prepend, cap(buf)) + buf = append(buf, '\n') + prefix := fmt.Sprintf("%s%d ", prepend, len(buf)) c <- append([]byte(prefix), buf...) err = dst.WriteMessage(msgType, buf) if err != nil { @@ -286,6 +293,7 @@ func (r *RecordingHTTPSProxy) upgradeConnectionToWebsocket(w http.ResponseWriter "Sec-Websocket-Extensions": true, "Connection": true, "Upgrade": true, + "Test-Name": true, } for k, v := range req.Header { if _, ok := excludedHeaders[k]; ok { diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index bf4bedd..a3d735f 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -21,10 +21,14 @@ import ( "net/http" "os" "path/filepath" + "strconv" + "strings" + "unicode" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" "github.com/google/test-server/internal/store" + "github.com/gorilla/websocket" ) type ReplayHTTPServer struct { @@ -68,22 +72,34 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques return } fmt.Printf("Replaying request: %ss\n", redactedReq.Request) - - reqHash, err := redactedReq.ComputeSum() + fileName, err := redactedReq.GetRecordingFileName() if err != nil { - fmt.Printf("Error computing request sum: %v\n", err) - http.Error(w, fmt.Sprintf("Error computing request sum: %v", err), http.StatusInternalServerError) + fmt.Printf("Invalid recording file name: %v\n", err) + http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError) return } + if req.Header.Get("Upgrade") == "websocket" { + fmt.Printf("Upgrading connection to websocket...\n") - resp, err := r.loadResponse(reqHash) + chunks, err := r.loadWebsocketChunks(fileName) + if err != nil { + fmt.Printf("Error loading websocket response: %v\n", err) + http.Error(w, fmt.Sprintf("Error loading websocket response: %v", err), http.StatusInternalServerError) + return + } + fmt.Printf("Replaying websocket: %s\n", fileName) + r.proxyWebsocket(w, req, chunks) + return + } + fmt.Printf("Replaying http request: %s\n", redactedReq.Request) + resp, err := r.loadResponse(fileName) if err != nil { fmt.Printf("Error loading response: %v\n", err) http.Error(w, fmt.Sprintf("Error loading response: %v", err), http.StatusInternalServerError) return } - err = r.writedResponse(w, resp) + err = r.writeResponse(w, resp) if err != nil { fmt.Printf("Error writing response: %v\n", err) panic(err) @@ -106,8 +122,8 @@ func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.Reco return recordedRequest, nil } -func (r *ReplayHTTPServer) loadResponse(sha string) (*store.RecordedResponse, error) { - responseFile := filepath.Join(r.recordingDir, sha+".resp") +func (r *ReplayHTTPServer) loadResponse(fileName string) (*store.RecordedResponse, error) { + responseFile := filepath.Join(r.recordingDir, fileName+".resp") fmt.Printf("loading response from : %s\n", responseFile) responseData, err := os.ReadFile(responseFile) if err != nil { @@ -116,7 +132,7 @@ func (r *ReplayHTTPServer) loadResponse(sha string) (*store.RecordedResponse, er return store.DeserializeResponse(responseData) } -func (r *ReplayHTTPServer) writedResponse(w http.ResponseWriter, resp *store.RecordedResponse) error { +func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.RecordedResponse) error { for key, values := range resp.Header { for _, value := range values { if key == "Content-Length" || key == "Content-Encoding" { @@ -131,3 +147,130 @@ func (r *ReplayHTTPServer) writedResponse(w http.ResponseWriter, resp *store.Rec _, err := w.Write(resp.Body) return err } + +func extractNumber(i *int, content string) (int, error) { + numStart := *i + for *i < len(content) && unicode.IsDigit(rune(content[*i])) { + *i++ + } + numEnd := *i + if numStart == numEnd { + return 0, fmt.Errorf("missing chunk length after prefix at position %d", numStart-1) + } + numStr := content[numStart:numEnd] + num, err := strconv.Atoi(numStr) + if err != nil { + return 0, fmt.Errorf("invalid chunk length '%s': %w", numStr, err) + } + return num, nil +} + +func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Request, chunks []string) { + clientConn, err := r.upgradeConnectionToWebsocket(w, req) + if err != nil { + http.Error(w, fmt.Sprintf("Error proxying websocket: %v", err), http.StatusInternalServerError) + return + } + defer clientConn.Close() + replayWebsocket(clientConn, chunks) +} + +func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) { + responseFile := filepath.Join(r.recordingDir, sha+".websocket") + fmt.Printf("loading websocket response from : %s\n", responseFile) + bytes, err := os.ReadFile(responseFile) + var chunks = make([]string, 0) + if err != nil { + fmt.Printf("Error loading websocket response: %v\n", err) + return chunks, err + } + + i := 0 + response := string(bytes) + for i < len(response) { + // Extracts prefix + prefix := response[i] + if prefix != '>' && prefix != '<' { + return nil, fmt.Errorf("invalid message prefix at position %d: expected '>' or '<', got '%c'", i, prefix) + } + i++ // Move cursor past prefix. + + // Extracts chunk length number + num, err := extractNumber(&i, response) + i++ // Move cursor to skip the whitespace between the number and the actual chunk. + if err != nil { + return nil, fmt.Errorf("failed to extract number %v", err) + } + + // Extracts chunk + chunkStart := i + chunkEnd := chunkStart + num + if chunkEnd > len(response) { + return nil, fmt.Errorf("chunk length %d at position %d exceeds response bounds", chunkEnd, chunkStart) + } + chunk := response[chunkStart : chunkEnd-1] // Remove the \n appended at the end of the chunk + chunks = append(chunks, string(prefix)+chunk) + i = chunkEnd + } + return chunks, nil +} + +func replayWebsocket(conn *websocket.Conn, chunks []string) { + for _, chunk := range chunks { + if strings.HasPrefix(chunk, ">") { + _, buf, err := conn.ReadMessage() + reqChunk := string(buf) + if err != nil { + fmt.Printf("Error reading from websocket: %v\n", err) + return + } + + runes := []rune(chunk) + recChunk := string(runes[1:]) + if reqChunk != recChunk { + fmt.Printf("input chunk mismatch\n Input chunk: %s\n Recorded chunk: %s\n", reqChunk, recChunk) + writeError(conn, "input chunk mismatch") + return + } + } else if strings.HasPrefix(chunk, "<") { + runes := []rune(chunk) + recChunk := string(runes[1:]) + // Write binary message. (messageType=2) + err := conn.WriteMessage(2, []byte(recChunk)) + if err != nil { + fmt.Printf("Error writing to websocket: %v\n", err) + return + } + } else { + fmt.Printf("Unreconginized chunk: %s", chunk) + return + } + } +} + +func writeError(conn *websocket.Conn, errMsg string) { + closeMessage := websocket.FormatCloseMessage( + websocket.CloseInternalServerErr, + errMsg, + ) + err := conn.WriteMessage(websocket.CloseMessage, closeMessage) + if err != nil { + fmt.Printf("Failed to write error: %v\n", err) + } +} + +func (r *ReplayHTTPServer) upgradeConnectionToWebsocket(w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins + }, + } + + clientConn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + return nil, err + } + return clientConn, err +} diff --git a/internal/store/store.go b/internal/store/store.go index 9f58197..25b9090 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -85,11 +85,26 @@ func readBody(req *http.Request) ([]byte, error) { } // ComputeSum computes the SHA256 sum of a RecordedRequest. -func (r *RecordedRequest) ComputeSum() (string, error) { +func (r *RecordedRequest) ComputeSum() string { serialized := r.Serialize() hash := sha256.Sum256([]byte(serialized)) hashHex := hex.EncodeToString(hash[:]) - return hashHex, nil + return hashHex +} + +// GetRecordingFileName returns the recording file name. +// It prefers the value from the TEST_NAME header. +// It returns error when test name contains illegal sequence. +// If the TEST_NAME header is not present, it falls back to computed SHA256 sum. +func (r *RecordedRequest) GetRecordingFileName() (string, error) { + testName := r.Header.Get("Test-Name") + if strings.Contains(testName, "../") { + return "", fmt.Errorf("test name: %s contains illegal sequence '../'", testName) + } + if testName != "" { + return testName, nil + } + return r.ComputeSum(), nil } // Serialize the request. diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 6566dd8..b37ff6e 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -329,6 +329,92 @@ func TestRecordedRequest_Deserialize(t *testing.T) { } } +func TestRecordedRequest_GetRecordFileName(t *testing.T) { + testCases := []struct { + name string + request RecordedRequest + expected string + expectedErr bool + }{ + { + name: "Request with test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Test-Name": []string{"random test name"}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "random test name", + expectedErr: false, + }, + { + name: "Request with empty test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Test-Name": []string{""}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "f824dd099907ed4549822de827b075a7578baadebf08c5bc7303ead90a8f9ff7", + expectedErr: false, + }, + { + name: "Request with invalid test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Test-Name": []string{"../invalid_name"}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "", + expectedErr: true, + }, + { + name: "Request without test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Accept": []string{"application/xml"}, + "Content-Type": []string{"application/json"}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "fc060aea9a2bf35da16ed18c6be577ca64d0f91d681d5db385082df61ecf4ccf", + expectedErr: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual, err := tc.request.GetRecordingFileName() + if tc.expectedErr { + require.Error(t, err) + return + } + require.Equal(t, tc.expected, actual, "GetRecordFileName() result mismatch") + }) + } +} + type errorReader struct{} func (e *errorReader) Read(p []byte) (n int, err error) {