Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 46 additions & 22 deletions internal/record/recording_https_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ import (

type RecordingHTTPSProxy struct {
prevRequestSHA string
seenFiles map[string]struct{}
config *config.EndpointConfig
recordingDir string
redactor *redact.Redact
Expand All @@ -41,6 +42,7 @@ type RecordingHTTPSProxy struct {
func NewRecordingHTTPSProxy(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *RecordingHTTPSProxy {
return &RecordingHTTPSProxy{
prevRequestSHA: store.HeadSHA,
seenFiles: make(map[string]struct{}),
config: cfg,
recordingDir: recordingDir,
redactor: redactor,
Expand Down Expand Up @@ -70,7 +72,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
}
fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String())

recReq, err := r.recordRequest(req)
recReq, err := r.redactRequest(req)
if err != nil {
fmt.Printf("Error recording request: %v\n", err)
http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError)
Expand All @@ -82,6 +84,10 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
return
}
if _, ok := r.seenFiles[fileName]; !ok {
// Reset to HeadSHA when first time seen a request from the given file.
recReq.PreviousRequest=store.HeadSHA
}

if req.Header.Get("Upgrade") == "websocket" {
fmt.Printf("Upgrading connection to websocket...\n")
Expand All @@ -95,17 +101,20 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
http.Error(w, fmt.Sprintf("Error proxying request: %v", err), http.StatusInternalServerError)
return
}

err = r.recordResponse(resp, fileName, respBody)

shaSum := recReq.ComputeSum()
err = r.recordResponse(recReq, resp, fileName, shaSum, respBody)
if err != nil {
fmt.Printf("Error recording response: %v\n", err)
http.Error(w, fmt.Sprintf("Error recording response: %v", err), http.StatusInternalServerError)
return
}
if (fileName != shaSum) {
r.prevRequestSHA = shaSum
}
r.seenFiles[fileName] = struct{}{}
}

func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedRequest, error) {
func (r *RecordingHTTPSProxy) redactRequest(req *http.Request) (*store.RecordedRequest, error) {
recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config)
if err != nil {
return recordedRequest, err
Expand All @@ -117,17 +126,6 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedR
r.redactor.Headers(recordedRequest.Header)
recordedRequest.Request = r.redactor.String(recordedRequest.Request)
recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body)

fileName, err := recordedRequest.GetRecordingFileName()
if err != nil {
fmt.Printf("Invalid recording file name: %v\n", err)
return recordedRequest, err
}
recordPath := filepath.Join(r.recordingDir, fileName+".req")
err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644)
if err != nil {
return recordedRequest, err
}
return recordedRequest, nil
}

Expand Down Expand Up @@ -178,21 +176,47 @@ func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Requ
return resp, respBodyBytes, nil
}

func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, fileName string, body []byte) error {
func (r *RecordingHTTPSProxy) recordResponse(recReq *store.RecordedRequest, resp *http.Response, fileName string, shaSum string, body []byte) error {
recordedResponse, err := store.NewRecordedResponse(resp, body)
if err != nil {
return err
}
recordPath := filepath.Join(r.recordingDir, fileName+".http.log")

recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body)
// Default to overwriting the file, assuming it's a new file to record.
fileMode := os.O_TRUNC
// If we've seen requests with the same file name before, change the mode to append.
if _, ok := r.seenFiles[fileName]; ok {
fileMode = os.O_APPEND
}
file, err := os.OpenFile(recordPath, fileMode|os.O_CREATE|os.O_WRONLY , 0644)
if err != nil {
return err
}
defer file.Close()

recordPath := filepath.Join(r.recordingDir, fileName+".resp")
fmt.Printf("Writing response to: %s\n", recordPath)
err = os.WriteFile(recordPath, []byte(recordedResponse.Serialize()), 0644)
fmt.Printf("Writing request to: %s\n", recordPath)
serializedReq := recReq.Serialize()
_, err = file.WriteString(fmt.Sprintf("%s.req %d\n", shaSum, len(serializedReq)))
if err != nil {
return err
}
_, err = file.WriteString(serializedReq)
if err != nil {
return err
}

fmt.Printf("Writing response to: %s\n", recordPath)
recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body)
serializedResp := recordedResponse.Serialize()
_, err = file.WriteString(fmt.Sprintf("\n%s.resp %d\n", shaSum, len(serializedResp)))
if err != nil {
return err
}
_, err = file.WriteString(serializedResp)
if err != nil {
return err
}
return nil
}

Expand Down Expand Up @@ -230,7 +254,7 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re
go pumpWebsocket(clientConn, conn, c, quit, ">")
go pumpWebsocket(conn, clientConn, c, quit, "<")

recordPath := filepath.Join(r.recordingDir, fileName+".websocket")
recordPath := filepath.Join(r.recordingDir, fileName+".websocket.log")
f, err := os.Create(recordPath)
if err != nil {
fmt.Printf("Error creating websocket recording file: %v\n", err)
Expand Down
76 changes: 67 additions & 9 deletions internal/replay/replay_http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ import (
"strconv"
"strings"
"unicode"
"bufio"
"io"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
Expand All @@ -33,6 +35,7 @@ import (

type ReplayHTTPServer struct {
prevRequestSHA string
seenFiles map[string]struct{}
config *config.EndpointConfig
recordingDir string
redactor *redact.Redact
Expand All @@ -41,6 +44,7 @@ type ReplayHTTPServer struct {
func NewReplayHTTPServer(cfg *config.EndpointConfig, recordingDir string, redactor *redact.Redact) *ReplayHTTPServer {
return &ReplayHTTPServer{
prevRequestSHA: store.HeadSHA,
seenFiles: make(map[string]struct{}),
config: cfg,
recordingDir: recordingDir,
redactor: redactor,
Expand Down Expand Up @@ -78,6 +82,10 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
return
}
if _, ok := r.seenFiles[fileName]; !ok {
// Reset to HeadSHA when first time seen request from the given file.
redactedReq.PreviousRequest=store.HeadSHA
}
if req.Header.Get("Upgrade") == "websocket" {
fmt.Printf("Upgrading connection to websocket...\n")

Expand All @@ -92,7 +100,8 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
return
}
fmt.Printf("Replaying http request: %s\n", redactedReq.Request)
resp, err := r.loadResponse(fileName)
shaSum := redactedReq.ComputeSum()
resp, err := r.loadResponse(fileName, shaSum)
if err != nil {
fmt.Printf("Error loading response: %v\n", err)
http.Error(w, fmt.Sprintf("Error loading response: %v", err), http.StatusInternalServerError)
Expand All @@ -104,6 +113,10 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
fmt.Printf("Error writing response: %v\n", err)
panic(err)
}
if (fileName != shaSum) {
r.prevRequestSHA = shaSum
}
r.seenFiles[fileName] = struct{}{}
}

func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.RecordedRequest, error) {
Expand All @@ -122,14 +135,59 @@ func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.Reco
return recordedRequest, nil
}

func (r *ReplayHTTPServer) loadResponse(fileName string) (*store.RecordedResponse, error) {
responseFile := filepath.Join(r.recordingDir, fileName+".resp")
fmt.Printf("loading response from : %s\n", responseFile)
responseData, err := os.ReadFile(responseFile)
func (r *ReplayHTTPServer) loadResponse(fileName string, shaSum string) (*store.RecordedResponse, error) {
// 1. Open the replay log file for reading.
filePath := filepath.Join(r.recordingDir, fileName+".http.log")
fmt.Printf("loading response from : %s with shaSum: %s\n", filePath, shaSum)
file, err := os.Open(filePath)
if err != nil {
return nil, err
return nil, fmt.Errorf("could not open file %s: %w", filePath, err)
}
defer file.Close()

reader := bufio.NewReader(file)
expectedKey := shaSum + ".resp"
// 2. Scan the file line by line using the reader directly.
for {
// Read one line, including the newline character.
line, err := reader.ReadString('\n')
if err != nil {
if err == io.EOF {
return nil, fmt.Errorf("response with shaSum %s not found in file", shaSum)
}
return nil, fmt.Errorf("error while reading file: %w", err)
}
trimmedLine := strings.TrimSpace(line)
parts := strings.Fields(trimmedLine)
if len(parts) != 2 {
continue
}

fileKey := parts[0]
sizeStr := parts[1]

size, err := strconv.Atoi(sizeStr)
if err != nil {
return nil, fmt.Errorf("invalid size format on delimiter line: '%s'", trimmedLine)
}
fmt.Printf("Bytes to load: %d\n", size)
if size < 0 {
return nil, fmt.Errorf("invalid negative size on delimiter line: '%s'", trimmedLine)
}

// 3. Read the exact number of bytes for the payload.
data := make([]byte, size)
if _, err := io.ReadFull(reader, data); err != nil {
return nil, fmt.Errorf("failed to read %d bytes after delimiter: %w", size, err)
}

// 4. Return the response when it matches our target shaSum.
if fileKey == expectedKey {
return store.DeserializeResponse(data)
} else {
continue
}
}
return store.DeserializeResponse(responseData)
}

func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.RecordedResponse) error {
Expand Down Expand Up @@ -175,8 +233,8 @@ func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Reque
replayWebsocket(clientConn, chunks)
}

func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) {
responseFile := filepath.Join(r.recordingDir, sha+".websocket")
func (r *ReplayHTTPServer) loadWebsocketChunks(fileName string) ([]string, error) {
responseFile := filepath.Join(r.recordingDir, fileName+".websocket.log")
fmt.Printf("loading websocket response from : %s\n", responseFile)
bytes, err := os.ReadFile(responseFile)
var chunks = make([]string, 0)
Expand Down
4 changes: 2 additions & 2 deletions internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@ import (
"github.com/google/test-server/internal/config"
)

// A sha of an invalid RecordRequest to be used as the head of all chains.
const HeadSHA = "b4d6e60a9b97e7b98c63df9308728c5c88c0b40c398046772c63447b94608b4d"

type RecordedRequest struct {
Expand Down Expand Up @@ -102,7 +101,8 @@ func (r *RecordedRequest) GetRecordingFileName() (string, error) {
return "", fmt.Errorf("test name: %s contains illegal sequence '../'", testName)
}
if testName != "" {
return testName, nil
fileName := strings.ReplaceAll(testName, " ", "_")
return fileName, nil
}
return r.ComputeSum(), nil
}
Expand Down
2 changes: 1 addition & 1 deletion internal/store/store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ func TestRecordedRequest_GetRecordFileName(t *testing.T) {
Port: 0,
Protocol: "",
},
expected: "random test name",
expected: "random_test_name",
expectedErr: false,
},
{
Expand Down