diff --git a/internal/record/recording_https_proxy.go b/internal/record/recording_https_proxy.go index 6ec51c8..b0c180a 100644 --- a/internal/record/recording_https_proxy.go +++ b/internal/record/recording_https_proxy.go @@ -33,6 +33,7 @@ import ( type RecordingHTTPSProxy struct { prevRequestSHA string + seenFiles map[string]struct{} config *config.EndpointConfig recordingDir string redactor *redact.Redact @@ -41,6 +42,7 @@ type RecordingHTTPSProxy struct { func NewRecordingHTTPSProxy(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *RecordingHTTPSProxy { return &RecordingHTTPSProxy{ prevRequestSHA: store.HeadSHA, + seenFiles: make(map[string]struct{}), config: cfg, recordingDir: recordingDir, redactor: redactor, @@ -70,7 +72,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req } fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String()) - recReq, err := r.recordRequest(req) + recReq, err := r.redactRequest(req) if err != nil { fmt.Printf("Error recording request: %v\n", err) http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError) @@ -82,6 +84,10 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError) return } + if _, ok := r.seenFiles[fileName]; !ok { + // Reset to HeadSHA when first time seen a request from the given file. + recReq.PreviousRequest=store.HeadSHA + } if req.Header.Get("Upgrade") == "websocket" { fmt.Printf("Upgrading connection to websocket...\n") @@ -95,17 +101,20 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req http.Error(w, fmt.Sprintf("Error proxying request: %v", err), http.StatusInternalServerError) return } - - err = r.recordResponse(resp, fileName, respBody) - + shaSum := recReq.ComputeSum() + err = r.recordResponse(recReq, resp, fileName, shaSum, respBody) if err != nil { fmt.Printf("Error recording response: %v\n", err) http.Error(w, fmt.Sprintf("Error recording response: %v", err), http.StatusInternalServerError) return } + if (fileName != shaSum) { + r.prevRequestSHA = shaSum + } + r.seenFiles[fileName] = struct{}{} } -func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedRequest, error) { +func (r *RecordingHTTPSProxy) redactRequest(req *http.Request) (*store.RecordedRequest, error) { recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config) if err != nil { return recordedRequest, err @@ -117,17 +126,6 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedR r.redactor.Headers(recordedRequest.Header) recordedRequest.Request = r.redactor.String(recordedRequest.Request) recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body) - - fileName, err := recordedRequest.GetRecordingFileName() - if err != nil { - fmt.Printf("Invalid recording file name: %v\n", err) - return recordedRequest, err - } - recordPath := filepath.Join(r.recordingDir, fileName+".req") - err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644) - if err != nil { - return recordedRequest, err - } return recordedRequest, nil } @@ -178,21 +176,47 @@ func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Requ return resp, respBodyBytes, nil } -func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, fileName string, body []byte) error { +func (r *RecordingHTTPSProxy) recordResponse(recReq *store.RecordedRequest, resp *http.Response, fileName string, shaSum string, body []byte) error { recordedResponse, err := store.NewRecordedResponse(resp, body) if err != nil { return err } + recordPath := filepath.Join(r.recordingDir, fileName+".http.log") - recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body) + // Default to overwriting the file, assuming it's a new file to record. + fileMode := os.O_TRUNC + // If we've seen requests with the same file name before, change the mode to append. + if _, ok := r.seenFiles[fileName]; ok { + fileMode = os.O_APPEND + } + file, err := os.OpenFile(recordPath, fileMode|os.O_CREATE|os.O_WRONLY , 0644) + if err != nil { + return err + } + defer file.Close() - recordPath := filepath.Join(r.recordingDir, fileName+".resp") - fmt.Printf("Writing response to: %s\n", recordPath) - err = os.WriteFile(recordPath, []byte(recordedResponse.Serialize()), 0644) + fmt.Printf("Writing request to: %s\n", recordPath) + serializedReq := recReq.Serialize() + _, err = file.WriteString(fmt.Sprintf("%s.req %d\n", shaSum, len(serializedReq))) + if err != nil { + return err + } + _, err = file.WriteString(serializedReq) if err != nil { return err } + fmt.Printf("Writing response to: %s\n", recordPath) + recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body) + serializedResp := recordedResponse.Serialize() + _, err = file.WriteString(fmt.Sprintf("\n%s.resp %d\n", shaSum, len(serializedResp))) + if err != nil { + return err + } + _, err = file.WriteString(serializedResp) + if err != nil { + return err + } return nil } @@ -230,7 +254,7 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re go pumpWebsocket(clientConn, conn, c, quit, ">") go pumpWebsocket(conn, clientConn, c, quit, "<") - recordPath := filepath.Join(r.recordingDir, fileName+".websocket") + recordPath := filepath.Join(r.recordingDir, fileName+".websocket.log") f, err := os.Create(recordPath) if err != nil { fmt.Printf("Error creating websocket recording file: %v\n", err) diff --git a/internal/replay/replay_http_server.go b/internal/replay/replay_http_server.go index a3d735f..f799915 100644 --- a/internal/replay/replay_http_server.go +++ b/internal/replay/replay_http_server.go @@ -24,6 +24,8 @@ import ( "strconv" "strings" "unicode" + "bufio" + "io" "github.com/google/test-server/internal/config" "github.com/google/test-server/internal/redact" @@ -33,6 +35,7 @@ import ( type ReplayHTTPServer struct { prevRequestSHA string + seenFiles map[string]struct{} config *config.EndpointConfig recordingDir string redactor *redact.Redact @@ -41,6 +44,7 @@ type ReplayHTTPServer struct { func NewReplayHTTPServer(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *ReplayHTTPServer { return &ReplayHTTPServer{ prevRequestSHA: store.HeadSHA, + seenFiles: make(map[string]struct{}), config: cfg, recordingDir: recordingDir, redactor: redactor, @@ -78,6 +82,10 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError) return } + if _, ok := r.seenFiles[fileName]; !ok { + // Reset to HeadSHA when first time seen request from the given file. + redactedReq.PreviousRequest=store.HeadSHA + } if req.Header.Get("Upgrade") == "websocket" { fmt.Printf("Upgrading connection to websocket...\n") @@ -92,7 +100,8 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques return } fmt.Printf("Replaying http request: %s\n", redactedReq.Request) - resp, err := r.loadResponse(fileName) + shaSum := redactedReq.ComputeSum() + resp, err := r.loadResponse(fileName, shaSum) if err != nil { fmt.Printf("Error loading response: %v\n", err) http.Error(w, fmt.Sprintf("Error loading response: %v", err), http.StatusInternalServerError) @@ -104,6 +113,10 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques fmt.Printf("Error writing response: %v\n", err) panic(err) } + if (fileName != shaSum) { + r.prevRequestSHA = shaSum + } + r.seenFiles[fileName] = struct{}{} } func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.RecordedRequest, error) { @@ -122,14 +135,59 @@ func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.Reco return recordedRequest, nil } -func (r *ReplayHTTPServer) loadResponse(fileName string) (*store.RecordedResponse, error) { - responseFile := filepath.Join(r.recordingDir, fileName+".resp") - fmt.Printf("loading response from : %s\n", responseFile) - responseData, err := os.ReadFile(responseFile) +func (r *ReplayHTTPServer) loadResponse(fileName string, shaSum string) (*store.RecordedResponse, error) { + // 1. Open the replay log file for reading. + filePath := filepath.Join(r.recordingDir, fileName+".http.log") + fmt.Printf("loading response from : %s with shaSum: %s\n", filePath, shaSum) + file, err := os.Open(filePath) if err != nil { - return nil, err + return nil, fmt.Errorf("could not open file %s: %w", filePath, err) + } + defer file.Close() + + reader := bufio.NewReader(file) + expectedKey := shaSum + ".resp" + // 2. Scan the file line by line using the reader directly. + for { + // Read one line, including the newline character. + line, err := reader.ReadString('\n') + if err != nil { + if err == io.EOF { + return nil, fmt.Errorf("response with shaSum %s not found in file", shaSum) + } + return nil, fmt.Errorf("error while reading file: %w", err) + } + trimmedLine := strings.TrimSpace(line) + parts := strings.Fields(trimmedLine) + if len(parts) != 2 { + continue + } + + fileKey := parts[0] + sizeStr := parts[1] + + size, err := strconv.Atoi(sizeStr) + if err != nil { + return nil, fmt.Errorf("invalid size format on delimiter line: '%s'", trimmedLine) + } + fmt.Printf("Bytes to load: %d\n", size) + if size < 0 { + return nil, fmt.Errorf("invalid negative size on delimiter line: '%s'", trimmedLine) + } + + // 3. Read the exact number of bytes for the payload. + data := make([]byte, size) + if _, err := io.ReadFull(reader, data); err != nil { + return nil, fmt.Errorf("failed to read %d bytes after delimiter: %w", size, err) + } + + // 4. Return the response when it matches our target shaSum. + if fileKey == expectedKey { + return store.DeserializeResponse(data) + } else { + continue + } } - return store.DeserializeResponse(responseData) } func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.RecordedResponse) error { @@ -175,8 +233,8 @@ func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Reque replayWebsocket(clientConn, chunks) } -func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) { - responseFile := filepath.Join(r.recordingDir, sha+".websocket") +func (r *ReplayHTTPServer) loadWebsocketChunks(fileName string) ([]string, error) { + responseFile := filepath.Join(r.recordingDir, fileName+".websocket.log") fmt.Printf("loading websocket response from : %s\n", responseFile) bytes, err := os.ReadFile(responseFile) var chunks = make([]string, 0) diff --git a/internal/store/store.go b/internal/store/store.go index 25b9090..60e5997 100644 --- a/internal/store/store.go +++ b/internal/store/store.go @@ -30,7 +30,6 @@ import ( "github.com/google/test-server/internal/config" ) -// A sha of an invalid RecordRequest to be used as the head of all chains. const HeadSHA = "b4d6e60a9b97e7b98c63df9308728c5c88c0b40c398046772c63447b94608b4d" type RecordedRequest struct { @@ -102,7 +101,8 @@ func (r *RecordedRequest) GetRecordingFileName() (string, error) { return "", fmt.Errorf("test name: %s contains illegal sequence '../'", testName) } if testName != "" { - return testName, nil + fileName := strings.ReplaceAll(testName, " ", "_") + return fileName, nil } return r.ComputeSum(), nil } diff --git a/internal/store/store_test.go b/internal/store/store_test.go index b37ff6e..0552807 100644 --- a/internal/store/store_test.go +++ b/internal/store/store_test.go @@ -349,7 +349,7 @@ func TestRecordedRequest_GetRecordFileName(t *testing.T) { Port: 0, Protocol: "", }, - expected: "random test name", + expected: "random_test_name", expectedErr: false, }, {