From 075fadcfa28a9aa92bf4e56fc0b399c3e8dd6e82 Mon Sep 17 00:00:00 2001 From: Kaituo Huang Date: Wed, 25 Jun 2025 09:48:48 -0700 Subject: [PATCH 1/6] feat: Support record and replay websocket request 1. Add a new GetRecordFileName method, prefers the test name from header if present 2. Add magic words used to split websocket client and server message chunks See example replay file: https://gist.github.com/hkt74/2430f265644dc5d5b62a7fd7ad97f1a6 --- internal/record/recording_https_proxy.go | 40 ++++----- internal/replay/replay_http_server.go | 109 ++++++++++++++++++++--- internal/store/store.go | 15 +++- 3 files changed, 130 insertions(+), 34 deletions(-) diff --git a/internal/record/recording_https_proxy.go b/internal/record/recording_https_proxy.go index 796cd84..eca80e4 100644 --- a/internal/record/recording_https_proxy.go +++ b/internal/record/recording_https_proxy.go @@ -70,7 +70,8 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String()) - reqHash, err := r.recordRequest(req) + recReq, err := r.recordRequest(req) + fileName := recReq.GetRecordFileName() if err != nil { fmt.Printf("Error recording request: %v\n", err) http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError) @@ -79,7 +80,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req if req.Header.Get("Upgrade") == "websocket" { fmt.Printf("Upgrading connection to websocket...\n") - r.proxyWebsocket(w, req, reqHash) + r.proxyWebsocket(w, req, fileName) return } @@ -90,7 +91,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req return } - err = r.recordResponse(resp, reqHash, respBody) + err = r.recordResponse(resp, fileName, respBody) if err != nil { fmt.Printf("Error recording response: %v\n", err) @@ -99,10 +100,10 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } } -func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (string, error) { +func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedRequest, error) { recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config) if err != nil { - return "", err + return recordedRequest, err } // Redact headers by key @@ -112,17 +113,13 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (string, error) { recordedRequest.Request = r.redactor.String(recordedRequest.Request) recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body) - reqHash, err := recordedRequest.ComputeSum() - if err != nil { - return "", err - } - - recordPath := filepath.Join(r.recordingDir, reqHash+".req") + fileName := recordedRequest.GetRecordFileName() + recordPath := filepath.Join(r.recordingDir, fileName+".req") err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644) if err != nil { - return "", err + return recordedRequest, err } - return reqHash, nil + return recordedRequest, nil } func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Request) (*http.Response, []byte, error) { @@ -172,7 +169,7 @@ func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Requ return resp, respBodyBytes, nil } -func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, reqHash string, body []byte) error { +func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, fileName string, body []byte) error { recordedResponse, err := store.NewRecordedResponse(resp, body) if err != nil { return err @@ -180,7 +177,7 @@ func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, reqHash string recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body) - recordPath := filepath.Join(r.recordingDir, reqHash+".resp") + recordPath := filepath.Join(r.recordingDir, fileName+".resp") fmt.Printf("Writing response to: %s\n", recordPath) err = os.WriteFile(recordPath, []byte(recordedResponse.Serialize()), 0644) if err != nil { @@ -209,7 +206,7 @@ func replaceRegex(s, regex, replacement string) string { return re.ReplaceAllString(s, replacement) } -func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Request, reqHash string) { +func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Request, fileName string) { conn, clientConn, err := r.upgradeConnectionToWebsocket(w, req) if err != nil { http.Error(w, fmt.Sprintf("Error proxying websocket: %v", err), http.StatusInternalServerError) @@ -221,10 +218,10 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re c := make(chan []byte) quit := make(chan int) - go pumpWebsocket(clientConn, conn, c, quit, ">") - go pumpWebsocket(conn, clientConn, c, quit, "<") + go pumpWebsocket(clientConn, conn, c, quit, "[WS_MSG][C2S]") + go pumpWebsocket(conn, clientConn, c, quit, "[WS_MSG][S2C]") - recordPath := filepath.Join(r.recordingDir, reqHash+".websocket") + recordPath := filepath.Join(r.recordingDir, fileName+".websocket") f, err := os.Create(recordPath) if err != nil { fmt.Printf("Error creating websocket recording file: %v\n", err) @@ -262,8 +259,8 @@ func pumpWebsocket(src, dst *websocket.Conn, c chan []byte, quit chan int, prepe quit <- 1 return } - prefix := fmt.Sprintf("%s%d", prepend, cap(buf)) - c <- append([]byte(prefix), buf...) + buf = append(buf, '\n') + c <- append([]byte(prepend), buf...) err = dst.WriteMessage(msgType, buf) if err != nil { fmt.Printf("Error writing to websocket: %v\n", err) @@ -286,6 +283,7 @@ func (r *RecordingHTTPSProxy) upgradeConnectionToWebsocket(w http.ResponseWriter "Sec-Websocket-Extensions": true, "Connection": true, "Upgrade": true, + "Test-Name": true, } for k, v := range req.Header { if _, ok := excludedHeaders[k]; ok { diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index bf4bedd..5131d02 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -21,10 +21,12 @@ import ( "net/http" "os" "path/filepath" + "strings" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" "github.com/google/test-server/internal/store" + "github.com/gorilla/websocket" ) type ReplayHTTPServer struct { @@ -68,22 +70,29 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques return } fmt.Printf("Replaying request: %ss\n", redactedReq.Request) - - reqHash, err := redactedReq.ComputeSum() - if err != nil { - fmt.Printf("Error computing request sum: %v\n", err) - http.Error(w, fmt.Sprintf("Error computing request sum: %v", err), http.StatusInternalServerError) + fileName := redactedReq.GetRecordFileName() + if req.Header.Get("Upgrade") == "websocket" { + fmt.Printf("Upgrading connection to websocket...\n") + + chunks, err := r.loadWebsocketChunks(fileName) + if err != nil { + fmt.Printf("Error loading websocket response: %v\n", err) + http.Error(w, fmt.Sprintf("Error loading websocket response: %v", err), http.StatusInternalServerError) + return + } + fmt.Printf("Replaying websocket: %s\n", fileName) + r.proxyWebsocket(w, req, chunks) return } - - resp, err := r.loadResponse(reqHash) + fmt.Printf("Replaying http request: %s\n", redactedReq.Request) + resp, err := r.loadResponse(fileName) if err != nil { fmt.Printf("Error loading response: %v\n", err) http.Error(w, fmt.Sprintf("Error loading response: %v", err), http.StatusInternalServerError) return } - err = r.writedResponse(w, resp) + err = r.writeResponse(w, resp) if err != nil { fmt.Printf("Error writing response: %v\n", err) panic(err) @@ -106,8 +115,8 @@ func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.Reco return recordedRequest, nil } -func (r *ReplayHTTPServer) loadResponse(sha string) (*store.RecordedResponse, error) { - responseFile := filepath.Join(r.recordingDir, sha+".resp") +func (r *ReplayHTTPServer) loadResponse(fileName string) (*store.RecordedResponse, error) { + responseFile := filepath.Join(r.recordingDir, fileName+".resp") fmt.Printf("loading response from : %s\n", responseFile) responseData, err := os.ReadFile(responseFile) if err != nil { @@ -116,7 +125,7 @@ func (r *ReplayHTTPServer) loadResponse(sha string) (*store.RecordedResponse, er return store.DeserializeResponse(responseData) } -func (r *ReplayHTTPServer) writedResponse(w http.ResponseWriter, resp *store.RecordedResponse) error { +func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.RecordedResponse) error { for key, values := range resp.Header { for _, value := range values { if key == "Content-Length" || key == "Content-Encoding" { @@ -131,3 +140,81 @@ func (r *ReplayHTTPServer) writedResponse(w http.ResponseWriter, resp *store.Rec _, err := w.Write(resp.Body) return err } + +func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Request, chunks []string) { + clientConn, err := r.upgradeConnectionToWebsocket(w, req) + if err != nil { + http.Error(w, fmt.Sprintf("Error proxying websocket: %v", err), http.StatusInternalServerError) + return + } + defer clientConn.Close() + replayWebsocket(clientConn, chunks) +} + +func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) { + responseFile := filepath.Join(r.recordingDir, sha+".websocket") + fmt.Printf("loading websocket response from : %s\n", responseFile) + responseData, err := os.ReadFile(responseFile) + var chunks []string + if err != nil { + fmt.Printf("Error loading websocket response: %v\n", err) + return chunks, err + } + chunks = strings.Split(string(responseData), "[WS_MSG]") + var cleanChunks []string + for _, chunk := range chunks { + trimmedChunk := strings.TrimSpace(chunk) + if trimmedChunk != "" { + cleanChunks = append(cleanChunks, trimmedChunk) + } + } + return cleanChunks, nil +} + +func replayWebsocket(conn *websocket.Conn, chunks []string) { + for _, chunk := range chunks { + if strings.HasPrefix(chunk, "[C2S]") { + _, buf, err := conn.ReadMessage() + reqChunk := string(buf) + if err != nil { + fmt.Printf("Error reading from websocket: %v\n", err) + return + } + + runes := []rune(chunk) + recChunk := string(runes[5:]) + if reqChunk != recChunk { + fmt.Printf("input chunk mismatch\n Input chunk: %s\n Recorded chunk: %s\n", reqChunk, recChunk) + return + } + } else if strings.HasPrefix(chunk, "[S2C]") { + runes := []rune(chunk) + recChunk := string(runes[5:]) + // Write binary message. (messageType=2) + err := conn.WriteMessage(2, []byte(recChunk)) + if err != nil { + fmt.Printf("Error writing to websocket: %v\n", err) + return + } + } else { + fmt.Printf("Unreconginized chunk: %s", chunk) + return + } + } +} + +func (r *ReplayHTTPServer) upgradeConnectionToWebsocket(w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins + }, + } + + clientConn, err := upgrader.Upgrade(w, req, nil) + if err != nil { + return nil, err + } + return clientConn, err +} diff --git a/internal/store/store.go b/internal/store/store.go index 9f58197..8fe41cd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -85,11 +85,22 @@ func readBody(req *http.Request) ([]byte, error) { } // ComputeSum computes the SHA256 sum of a RecordedRequest. -func (r *RecordedRequest) ComputeSum() (string, error) { +func (r *RecordedRequest) ComputeSum() string { serialized := r.Serialize() hash := sha256.Sum256([]byte(serialized)) hashHex := hex.EncodeToString(hash[:]) - return hashHex, nil + return hashHex +} + +// GetRecordFileName returns the record file name. +// It prefers the value from the TEST_NAME header. +// If the TEST_NAME header is not present or its value is empty, it falls back to computed SHA256 sum. +func (r *RecordedRequest) GetRecordFileName() string { + testName := r.Header.Get("Test-Name") + if testName != "" { + return testName + } + return r.ComputeSum() } // Serialize the request. From a9d51b5422d1f478d4ce90f3b27193a5bf2f21a5 Mon Sep 17 00:00:00 2001 From: Kaituo Huang Date: Wed, 25 Jun 2025 09:48:48 -0700 Subject: [PATCH 2/6] feat: Support record and replay websocket request 1. Add a new GetRecordFileName method, prefers the test name from header if present 2. Add magic words used to split websocket client and server message chunks See example replay file: https://gist.github.com/hkt74/2430f265644dc5d5b62a7fd7ad97f1a6 --- internal/store/store.go | 11 ------- internal/store/store_test.go | 62 ++++++++++++++++++++++++++++++++++++ 2 files changed, 62 insertions(+), 11 deletions(-) diff --git a/internal/store/store.go b/internal/store/store.go index 8fe41cd..9a53915 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -92,17 +92,6 @@ func (r *RecordedRequest) ComputeSum() string { return hashHex } -// GetRecordFileName returns the record file name. -// It prefers the value from the TEST_NAME header. -// If the TEST_NAME header is not present or its value is empty, it falls back to computed SHA256 sum. -func (r *RecordedRequest) GetRecordFileName() string { - testName := r.Header.Get("Test-Name") - if testName != "" { - return testName - } - return r.ComputeSum() -} - // Serialize the request. // // The serialization format is as follows: diff --git a/internal/store/store_test.go b/internal/store/store_test.go index 6566dd8..f378506 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -329,6 +329,68 @@ func TestRecordedRequest_Deserialize(t *testing.T) { } } +func TestRecordedRequest_GetRecordFileName(t *testing.T) { + testCases := []struct { + name string + request RecordedRequest + expected string + }{ + { + name: "Request with test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Test-Name": []string{"random test name"}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "random test name", + }, + { + name: "Request with empty test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Test-Name": []string{""}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "f824dd099907ed4549822de827b075a7578baadebf08c5bc7303ead90a8f9ff7", + }, + { + name: "Request without test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Accept": []string{"application/xml"}, + "Content-Type": []string{"application/json"}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "fc060aea9a2bf35da16ed18c6be577ca64d0f91d681d5db385082df61ecf4ccf", + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + actual := tc.request.GetRecordFileName() + require.Equal(t, tc.expected, actual, "GetRecordFileName() result mismatch") + }) + } +} + type errorReader struct{} func (e *errorReader) Read(p []byte) (n int, err error) { From 0bfb5b4f8660f6461f05368c0fa01b56cfc4dea2 Mon Sep 17 00:00:00 2001 From: Kaituo Huang Date: Wed, 25 Jun 2025 10:22:49 -0700 Subject: [PATCH 3/6] add GetRecordFileName --- internal/store/store.go | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/internal/store/store.go b/internal/store/store.go index 9a53915..8fe41cd 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -92,6 +92,17 @@ func (r *RecordedRequest) ComputeSum() string { return hashHex } +// GetRecordFileName returns the record file name. +// It prefers the value from the TEST_NAME header. +// If the TEST_NAME header is not present or its value is empty, it falls back to computed SHA256 sum. +func (r *RecordedRequest) GetRecordFileName() string { + testName := r.Header.Get("Test-Name") + if testName != "" { + return testName + } + return r.ComputeSum() +} + // Serialize the request. // // The serialization format is as follows: From fa51811815dfbf5d1c1e24ca9f4dd1c6e24f4258 Mon Sep 17 00:00:00 2001 From: Kaituo Huang Date: Wed, 25 Jun 2025 14:30:40 -0700 Subject: [PATCH 4/6] resolve review comments --- internal/record/recording_https_proxy.go | 28 +++++++---- internal/replay/replay_http_server.go | 63 ++++++++++++++++++------ internal/store/store.go | 14 ++++-- internal/store/store_test.go | 38 +++++++++++--- 4 files changed, 108 insertions(+), 35 deletions(-) diff --git a/internal/record/recording_https_proxy.go b/internal/record/recording_https_proxy.go index eca80e4..aeec50b 100644 --- a/internal/record/recording_https_proxy.go +++ b/internal/record/recording_https_proxy.go @@ -70,11 +70,16 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String()) - recReq, err := r.recordRequest(req) - fileName := recReq.GetRecordFileName() - if err != nil { - fmt.Printf("Error recording request: %v\n", err) - http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError) + recReq, recErr := r.recordRequest(req) + if recErr != nil { + fmt.Printf("Error recording request: %v\n", recErr) + http.Error(w, fmt.Sprintf("Error recording request: %v", recErr), http.StatusInternalServerError) + return + } + fileName, fileNameErr := recReq.GetRecordingFileName() + if fileNameErr != nil { + fmt.Printf("Invalid recording file name: %v\n", recErr) + http.Error(w, fmt.Sprintf("Invalid recording file name: %v", recErr), http.StatusInternalServerError) return } @@ -113,7 +118,11 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedR recordedRequest.Request = r.redactor.String(recordedRequest.Request) recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body) - fileName := recordedRequest.GetRecordFileName() + fileName, err := recordedRequest.GetRecordingFileName() + if err != nil { + fmt.Printf("Invalid recording file name: %v\n", err) + return recordedRequest, err + } recordPath := filepath.Join(r.recordingDir, fileName+".req") err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644) if err != nil { @@ -218,8 +227,8 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re c := make(chan []byte) quit := make(chan int) - go pumpWebsocket(clientConn, conn, c, quit, "[WS_MSG][C2S]") - go pumpWebsocket(conn, clientConn, c, quit, "[WS_MSG][S2C]") + go pumpWebsocket(clientConn, conn, c, quit, ">") + go pumpWebsocket(conn, clientConn, c, quit, "<") recordPath := filepath.Join(r.recordingDir, fileName+".websocket") f, err := os.Create(recordPath) @@ -260,7 +269,8 @@ func pumpWebsocket(src, dst *websocket.Conn, c chan []byte, quit chan int, prepe return } buf = append(buf, '\n') - c <- append([]byte(prepend), buf...) + prefix := fmt.Sprintf("%s%d", prepend, len(buf)) + c <- append([]byte(prefix), buf...) err = dst.WriteMessage(msgType, buf) if err != nil { fmt.Printf("Error writing to websocket: %v\n", err) diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index 5131d02..0cca9bb 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -21,7 +21,9 @@ import ( "net/http" "os" "path/filepath" + "strconv" "strings" + "unicode" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" @@ -70,7 +72,12 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques return } fmt.Printf("Replaying request: %ss\n", redactedReq.Request) - fileName := redactedReq.GetRecordFileName() + fileName, err := redactedReq.GetRecordingFileName() + if err != nil { + fmt.Printf("Invalid recording file name: %v\n", err) + http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError) + return + } if req.Header.Get("Upgrade") == "websocket" { fmt.Printf("Upgrading connection to websocket...\n") @@ -154,26 +161,54 @@ func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Reque func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) { responseFile := filepath.Join(r.recordingDir, sha+".websocket") fmt.Printf("loading websocket response from : %s\n", responseFile) - responseData, err := os.ReadFile(responseFile) - var chunks []string + bytes, err := os.ReadFile(responseFile) + var chunks = make([]string, 0) if err != nil { fmt.Printf("Error loading websocket response: %v\n", err) return chunks, err } - chunks = strings.Split(string(responseData), "[WS_MSG]") - var cleanChunks []string - for _, chunk := range chunks { - trimmedChunk := strings.TrimSpace(chunk) - if trimmedChunk != "" { - cleanChunks = append(cleanChunks, trimmedChunk) + + i := 0 + response := string(bytes) + for i < len(response) { + // Extracts prefix + prefix := response[i] + if prefix != '>' && prefix != '<' { + return nil, fmt.Errorf("invalid message prefix at position %d: expected '>' or '<', got '%c'", i, prefix) + } + i++ // Move cursor past prefix. + + // Extracts chunk length + numStart := i + for i < len(response) && unicode.IsDigit(rune(response[i])) { + i++ + } + numEnd := i + if numStart == numEnd { + return nil, fmt.Errorf("missing chunk length after prefix at position %d", numStart-1) + } + numStr := response[numStart:numEnd] + num, err := strconv.Atoi(numStr) + if err != nil { + return nil, fmt.Errorf("invalid chunk length '%s': %w", numStr, err) + } + + // Extracts chunk + chunkStart := numEnd + chunkEnd := chunkStart + num + if chunkEnd > len(response) { + return nil, fmt.Errorf("chunk length %d at position %d exceeds response bounds", chunkEnd, chunkStart) } + chunk := response[chunkStart : chunkEnd-1] // Remove the \n appended at the end of the chunk + chunks = append(chunks, string(prefix)+chunk) + i = chunkEnd } - return cleanChunks, nil + return chunks, nil } func replayWebsocket(conn *websocket.Conn, chunks []string) { for _, chunk := range chunks { - if strings.HasPrefix(chunk, "[C2S]") { + if strings.HasPrefix(chunk, ">") { _, buf, err := conn.ReadMessage() reqChunk := string(buf) if err != nil { @@ -182,14 +217,14 @@ func replayWebsocket(conn *websocket.Conn, chunks []string) { } runes := []rune(chunk) - recChunk := string(runes[5:]) + recChunk := string(runes[1:]) if reqChunk != recChunk { fmt.Printf("input chunk mismatch\n Input chunk: %s\n Recorded chunk: %s\n", reqChunk, recChunk) return } - } else if strings.HasPrefix(chunk, "[S2C]") { + } else if strings.HasPrefix(chunk, "<") { runes := []rune(chunk) - recChunk := string(runes[5:]) + recChunk := string(runes[1:]) // Write binary message. (messageType=2) err := conn.WriteMessage(2, []byte(recChunk)) if err != nil { diff --git a/internal/store/store.go b/internal/store/store.go index 8fe41cd..25b9090 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -92,15 +92,19 @@ func (r *RecordedRequest) ComputeSum() string { return hashHex } -// GetRecordFileName returns the record file name. +// GetRecordingFileName returns the recording file name. // It prefers the value from the TEST_NAME header. -// If the TEST_NAME header is not present or its value is empty, it falls back to computed SHA256 sum. -func (r *RecordedRequest) GetRecordFileName() string { +// It returns error when test name contains illegal sequence. +// If the TEST_NAME header is not present, it falls back to computed SHA256 sum. +func (r *RecordedRequest) GetRecordingFileName() (string, error) { testName := r.Header.Get("Test-Name") + if strings.Contains(testName, "../") { + return "", fmt.Errorf("test name: %s contains illegal sequence '../'", testName) + } if testName != "" { - return testName + return testName, nil } - return r.ComputeSum() + return r.ComputeSum(), nil } // Serialize the request. diff --git a/internal/store/store_test.go b/internal/store/store_test.go index f378506..b37ff6e 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -331,9 +331,10 @@ func TestRecordedRequest_Deserialize(t *testing.T) { func TestRecordedRequest_GetRecordFileName(t *testing.T) { testCases := []struct { - name string - request RecordedRequest - expected string + name string + request RecordedRequest + expected string + expectedErr bool }{ { name: "Request with test name header", @@ -348,7 +349,8 @@ func TestRecordedRequest_GetRecordFileName(t *testing.T) { Port: 0, Protocol: "", }, - expected: "random test name", + expected: "random test name", + expectedErr: false, }, { name: "Request with empty test name header", @@ -363,7 +365,24 @@ func TestRecordedRequest_GetRecordFileName(t *testing.T) { Port: 0, Protocol: "", }, - expected: "f824dd099907ed4549822de827b075a7578baadebf08c5bc7303ead90a8f9ff7", + expected: "f824dd099907ed4549822de827b075a7578baadebf08c5bc7303ead90a8f9ff7", + expectedErr: false, + }, + { + name: "Request with invalid test name header", + request: RecordedRequest{ + Request: "GET / HTTP/1.1", + Header: http.Header{ + "Test-Name": []string{"../invalid_name"}, + }, + Body: []byte{}, + PreviousRequest: HeadSHA, + ServerAddress: "", + Port: 0, + Protocol: "", + }, + expected: "", + expectedErr: true, }, { name: "Request without test name header", @@ -379,13 +398,18 @@ func TestRecordedRequest_GetRecordFileName(t *testing.T) { Port: 0, Protocol: "", }, - expected: "fc060aea9a2bf35da16ed18c6be577ca64d0f91d681d5db385082df61ecf4ccf", + expected: "fc060aea9a2bf35da16ed18c6be577ca64d0f91d681d5db385082df61ecf4ccf", + expectedErr: false, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { - actual := tc.request.GetRecordFileName() + actual, err := tc.request.GetRecordingFileName() + if tc.expectedErr { + require.Error(t, err) + return + } require.Equal(t, tc.expected, actual, "GetRecordFileName() result mismatch") }) } From 67c33e756f99cfcbe9752c03002f1842032f4b5d Mon Sep 17 00:00:00 2001 From: Kaituo Huang Date: Wed, 25 Jun 2025 20:33:07 -0700 Subject: [PATCH 5/6] resolve review comments --- internal/record/recording_https_proxy.go | 18 ++++++------ internal/replay/replay_http_server.go | 35 +++++++++++++++--------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/internal/record/recording_https_proxy.go b/internal/record/recording_https_proxy.go index aeec50b..6ec51c8 100644 --- a/internal/record/recording_https_proxy.go +++ b/internal/record/recording_https_proxy.go @@ -70,16 +70,16 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String()) - recReq, recErr := r.recordRequest(req) - if recErr != nil { - fmt.Printf("Error recording request: %v\n", recErr) - http.Error(w, fmt.Sprintf("Error recording request: %v", recErr), http.StatusInternalServerError) + recReq, err := r.recordRequest(req) + if err != nil { + fmt.Printf("Error recording request: %v\n", err) + http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError) return } - fileName, fileNameErr := recReq.GetRecordingFileName() - if fileNameErr != nil { - fmt.Printf("Invalid recording file name: %v\n", recErr) - http.Error(w, fmt.Sprintf("Invalid recording file name: %v", recErr), http.StatusInternalServerError) + fileName, err := recReq.GetRecordingFileName() + if err != nil { + fmt.Printf("Invalid recording file name: %v\n", err) + http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError) return } @@ -269,7 +269,7 @@ func pumpWebsocket(src, dst *websocket.Conn, c chan []byte, quit chan int, prepe return } buf = append(buf, '\n') - prefix := fmt.Sprintf("%s%d", prepend, len(buf)) + prefix := fmt.Sprintf("%s%d ", prepend, len(buf)) c <- append([]byte(prefix), buf...) err = dst.WriteMessage(msgType, buf) if err != nil { diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index 0cca9bb..8edf7be 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -148,6 +148,23 @@ func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.Reco return err } +func extractNumber(i *int, content string) (int, error) { + numStart := *i + for *i < len(content) && unicode.IsDigit(rune(content[*i])) { + *i++ + } + numEnd := *i + if numStart == numEnd { + return 0, fmt.Errorf("missing chunk length after prefix at position %d", numStart-1) + } + numStr := content[numStart:numEnd] + num, err := strconv.Atoi(numStr) + if err != nil { + return 0, fmt.Errorf("invalid chunk length '%s': %w", numStr, err) + } + return num, nil +} + func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Request, chunks []string) { clientConn, err := r.upgradeConnectionToWebsocket(w, req) if err != nil { @@ -178,23 +195,15 @@ func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) { } i++ // Move cursor past prefix. - // Extracts chunk length - numStart := i - for i < len(response) && unicode.IsDigit(rune(response[i])) { - i++ - } - numEnd := i - if numStart == numEnd { - return nil, fmt.Errorf("missing chunk length after prefix at position %d", numStart-1) - } - numStr := response[numStart:numEnd] - num, err := strconv.Atoi(numStr) + // Extracts chunk length number + num, err := extractNumber(&i, response) + i++ // Move cursor to skip the whitespace between the number and the actual chunk. if err != nil { - return nil, fmt.Errorf("invalid chunk length '%s': %w", numStr, err) + return nil, fmt.Errorf("failed to extract number %v", err) } // Extracts chunk - chunkStart := numEnd + chunkStart := i chunkEnd := chunkStart + num if chunkEnd > len(response) { return nil, fmt.Errorf("chunk length %d at position %d exceeds response bounds", chunkEnd, chunkStart) From 00b0ac4d0d3bb3ad66d169e68a00d095ee58fef5 Mon Sep 17 00:00:00 2001 From: Kaituo Huang Date: Thu, 26 Jun 2025 15:07:23 -0700 Subject: [PATCH 6/6] write connection close --- internal/replay/replay_http_server.go | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index 8edf7be..a3d735f 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -229,6 +229,7 @@ func replayWebsocket(conn *websocket.Conn, chunks []string) { recChunk := string(runes[1:]) if reqChunk != recChunk { fmt.Printf("input chunk mismatch\n Input chunk: %s\n Recorded chunk: %s\n", reqChunk, recChunk) + writeError(conn, "input chunk mismatch") return } } else if strings.HasPrefix(chunk, "<") { @@ -247,6 +248,17 @@ func replayWebsocket(conn *websocket.Conn, chunks []string) { } } +func writeError(conn *websocket.Conn, errMsg string) { + closeMessage := websocket.FormatCloseMessage( + websocket.CloseInternalServerErr, + errMsg, + ) + err := conn.WriteMessage(websocket.CloseMessage, closeMessage) + if err != nil { + fmt.Printf("Failed to write error: %v\n", err) + } +} + func (r *ReplayHTTPServer) upgradeConnectionToWebsocket(w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) { upgrader := websocket.Upgrader{ ReadBufferSize: 1024,