Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 24 additions & 16 deletions internal/record/recording_https_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,22 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
}
fmt.Printf("Recording request: %s %s\n", req.Method, req.URL.String())

reqHash, err := r.recordRequest(req)
recReq, err := r.recordRequest(req)
if err != nil {
fmt.Printf("Error recording request: %v\n", err)
http.Error(w, fmt.Sprintf("Error recording request: %v", err), http.StatusInternalServerError)
return
}
fileName, err := recReq.GetRecordingFileName()
if err != nil {
fmt.Printf("Invalid recording file name: %v\n", err)
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
return
}

if req.Header.Get("Upgrade") == "websocket" {
fmt.Printf("Upgrading connection to websocket...\n")
r.proxyWebsocket(w, req, reqHash)
r.proxyWebsocket(w, req, fileName)
return
}

Expand All @@ -90,7 +96,7 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
return
}

err = r.recordResponse(resp, reqHash, respBody)
err = r.recordResponse(resp, fileName, respBody)

if err != nil {
fmt.Printf("Error recording response: %v\n", err)
Expand All @@ -99,10 +105,10 @@ func (r *RecordingHTTPSProxy) handleRequest(w http.ResponseWriter, req *http.Req
}
}

func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (string, error) {
func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (*store.RecordedRequest, error) {
recordedRequest, err := store.NewRecordedRequest(req, r.prevRequestSHA, *r.config)
if err != nil {
return "", err
return recordedRequest, err
}

// Redact headers by key
Expand All @@ -112,17 +118,17 @@ func (r *RecordingHTTPSProxy) recordRequest(req *http.Request) (string, error) {
recordedRequest.Request = r.redactor.String(recordedRequest.Request)
recordedRequest.Body = r.redactor.Bytes(recordedRequest.Body)

reqHash, err := recordedRequest.ComputeSum()
fileName, err := recordedRequest.GetRecordingFileName()
if err != nil {
return "", err
fmt.Printf("Invalid recording file name: %v\n", err)
return recordedRequest, err
}

recordPath := filepath.Join(r.recordingDir, reqHash+".req")
recordPath := filepath.Join(r.recordingDir, fileName+".req")
err = os.WriteFile(recordPath, []byte(recordedRequest.Serialize()), 0644)
if err != nil {
return "", err
return recordedRequest, err
}
return reqHash, nil
return recordedRequest, nil
}

func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Request) (*http.Response, []byte, error) {
Expand Down Expand Up @@ -172,15 +178,15 @@ func (r *RecordingHTTPSProxy) proxyRequest(w http.ResponseWriter, req *http.Requ
return resp, respBodyBytes, nil
}

func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, reqHash string, body []byte) error {
func (r *RecordingHTTPSProxy) recordResponse(resp *http.Response, fileName string, body []byte) error {
recordedResponse, err := store.NewRecordedResponse(resp, body)
if err != nil {
return err
}

recordedResponse.Body = r.redactor.Bytes(recordedResponse.Body)

recordPath := filepath.Join(r.recordingDir, reqHash+".resp")
recordPath := filepath.Join(r.recordingDir, fileName+".resp")
fmt.Printf("Writing response to: %s\n", recordPath)
err = os.WriteFile(recordPath, []byte(recordedResponse.Serialize()), 0644)
if err != nil {
Expand Down Expand Up @@ -209,7 +215,7 @@ func replaceRegex(s, regex, replacement string) string {
return re.ReplaceAllString(s, replacement)
}

func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Request, reqHash string) {
func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Request, fileName string) {
conn, clientConn, err := r.upgradeConnectionToWebsocket(w, req)
if err != nil {
http.Error(w, fmt.Sprintf("Error proxying websocket: %v", err), http.StatusInternalServerError)
Expand All @@ -224,7 +230,7 @@ func (r *RecordingHTTPSProxy) proxyWebsocket(w http.ResponseWriter, req *http.Re
go pumpWebsocket(clientConn, conn, c, quit, ">")
go pumpWebsocket(conn, clientConn, c, quit, "<")

recordPath := filepath.Join(r.recordingDir, reqHash+".websocket")
recordPath := filepath.Join(r.recordingDir, fileName+".websocket")
f, err := os.Create(recordPath)
if err != nil {
fmt.Printf("Error creating websocket recording file: %v\n", err)
Expand Down Expand Up @@ -262,7 +268,8 @@ func pumpWebsocket(src, dst *websocket.Conn, c chan []byte, quit chan int, prepe
quit <- 1
return
}
prefix := fmt.Sprintf("%s%d", prepend, cap(buf))
buf = append(buf, '\n')
prefix := fmt.Sprintf("%s%d ", prepend, len(buf))
c <- append([]byte(prefix), buf...)
err = dst.WriteMessage(msgType, buf)
if err != nil {
Expand All @@ -286,6 +293,7 @@ func (r *RecordingHTTPSProxy) upgradeConnectionToWebsocket(w http.ResponseWriter
"Sec-Websocket-Extensions": true,
"Connection": true,
"Upgrade": true,
"Test-Name": true,
}
for k, v := range req.Header {
if _, ok := excludedHeaders[k]; ok {
Expand Down
161 changes: 152 additions & 9 deletions internal/replay/replay_http_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,14 @@ import (
"net/http"
"os"
"path/filepath"
"strconv"
"strings"
"unicode"

"github.com/google/test-server/internal/config"
"github.com/google/test-server/internal/redact"
"github.com/google/test-server/internal/store"
"github.com/gorilla/websocket"
)

type ReplayHTTPServer struct {
Expand Down Expand Up @@ -68,22 +72,34 @@ func (r *ReplayHTTPServer) handleRequest(w http.ResponseWriter, req *http.Reques
return
}
fmt.Printf("Replaying request: %ss\n", redactedReq.Request)

reqHash, err := redactedReq.ComputeSum()
fileName, err := redactedReq.GetRecordingFileName()
if err != nil {
fmt.Printf("Error computing request sum: %v\n", err)
http.Error(w, fmt.Sprintf("Error computing request sum: %v", err), http.StatusInternalServerError)
fmt.Printf("Invalid recording file name: %v\n", err)
http.Error(w, fmt.Sprintf("Invalid recording file name: %v", err), http.StatusInternalServerError)
return
}
if req.Header.Get("Upgrade") == "websocket" {
fmt.Printf("Upgrading connection to websocket...\n")

resp, err := r.loadResponse(reqHash)
chunks, err := r.loadWebsocketChunks(fileName)
if err != nil {
fmt.Printf("Error loading websocket response: %v\n", err)
http.Error(w, fmt.Sprintf("Error loading websocket response: %v", err), http.StatusInternalServerError)
return
}
fmt.Printf("Replaying websocket: %s\n", fileName)
r.proxyWebsocket(w, req, chunks)
return
}
fmt.Printf("Replaying http request: %s\n", redactedReq.Request)
resp, err := r.loadResponse(fileName)
if err != nil {
fmt.Printf("Error loading response: %v\n", err)
http.Error(w, fmt.Sprintf("Error loading response: %v", err), http.StatusInternalServerError)
return
}

err = r.writedResponse(w, resp)
err = r.writeResponse(w, resp)
if err != nil {
fmt.Printf("Error writing response: %v\n", err)
panic(err)
Expand All @@ -106,8 +122,8 @@ func (r *ReplayHTTPServer) createRedactedRequest(req *http.Request) (*store.Reco
return recordedRequest, nil
}

func (r *ReplayHTTPServer) loadResponse(sha string) (*store.RecordedResponse, error) {
responseFile := filepath.Join(r.recordingDir, sha+".resp")
func (r *ReplayHTTPServer) loadResponse(fileName string) (*store.RecordedResponse, error) {
responseFile := filepath.Join(r.recordingDir, fileName+".resp")
fmt.Printf("loading response from : %s\n", responseFile)
responseData, err := os.ReadFile(responseFile)
if err != nil {
Expand All @@ -116,7 +132,7 @@ func (r *ReplayHTTPServer) loadResponse(sha string) (*store.RecordedResponse, er
return store.DeserializeResponse(responseData)
}

func (r *ReplayHTTPServer) writedResponse(w http.ResponseWriter, resp *store.RecordedResponse) error {
func (r *ReplayHTTPServer) writeResponse(w http.ResponseWriter, resp *store.RecordedResponse) error {
for key, values := range resp.Header {
for _, value := range values {
if key == "Content-Length" || key == "Content-Encoding" {
Expand All @@ -131,3 +147,130 @@ func (r *ReplayHTTPServer) writedResponse(w http.ResponseWriter, resp *store.Rec
_, err := w.Write(resp.Body)
return err
}

func extractNumber(i *int, content string) (int, error) {
numStart := *i
for *i < len(content) && unicode.IsDigit(rune(content[*i])) {
*i++
}
numEnd := *i
if numStart == numEnd {
return 0, fmt.Errorf("missing chunk length after prefix at position %d", numStart-1)
}
numStr := content[numStart:numEnd]
num, err := strconv.Atoi(numStr)
if err != nil {
return 0, fmt.Errorf("invalid chunk length '%s': %w", numStr, err)
}
return num, nil
}

func (r *ReplayHTTPServer) proxyWebsocket(w http.ResponseWriter, req *http.Request, chunks []string) {
clientConn, err := r.upgradeConnectionToWebsocket(w, req)
if err != nil {
http.Error(w, fmt.Sprintf("Error proxying websocket: %v", err), http.StatusInternalServerError)
return
}
defer clientConn.Close()
replayWebsocket(clientConn, chunks)
}

func (r *ReplayHTTPServer) loadWebsocketChunks(sha string) ([]string, error) {
responseFile := filepath.Join(r.recordingDir, sha+".websocket")
fmt.Printf("loading websocket response from : %s\n", responseFile)
bytes, err := os.ReadFile(responseFile)
var chunks = make([]string, 0)
if err != nil {
fmt.Printf("Error loading websocket response: %v\n", err)
return chunks, err
}

i := 0
response := string(bytes)
for i < len(response) {
// Extracts prefix
prefix := response[i]
if prefix != '>' && prefix != '<' {
return nil, fmt.Errorf("invalid message prefix at position %d: expected '>' or '<', got '%c'", i, prefix)
}
i++ // Move cursor past prefix.

// Extracts chunk length number
num, err := extractNumber(&i, response)
i++ // Move cursor to skip the whitespace between the number and the actual chunk.
if err != nil {
return nil, fmt.Errorf("failed to extract number %v", err)
}

// Extracts chunk
chunkStart := i
chunkEnd := chunkStart + num
if chunkEnd > len(response) {
return nil, fmt.Errorf("chunk length %d at position %d exceeds response bounds", chunkEnd, chunkStart)
}
chunk := response[chunkStart : chunkEnd-1] // Remove the \n appended at the end of the chunk
chunks = append(chunks, string(prefix)+chunk)
i = chunkEnd
}
return chunks, nil
}

func replayWebsocket(conn *websocket.Conn, chunks []string) {
for _, chunk := range chunks {
if strings.HasPrefix(chunk, ">") {
_, buf, err := conn.ReadMessage()
reqChunk := string(buf)
if err != nil {
fmt.Printf("Error reading from websocket: %v\n", err)
return
}

runes := []rune(chunk)
recChunk := string(runes[1:])
if reqChunk != recChunk {
fmt.Printf("input chunk mismatch\n Input chunk: %s\n Recorded chunk: %s\n", reqChunk, recChunk)
writeError(conn, "input chunk mismatch")
return
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to return an error in this case? how would we know to fail the test?

Copy link
Copy Markdown
Collaborator Author

@hkt74 hkt74 Jun 26, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good call, the current behavior test will fail after timeout (5 seconds), the test is awaiting message from the server, while the request mismatch would skip the subsequent message write

To make it fail fast, we probably could write an error message in the session, but it require the test setup to listen to the error message and fail the test.

what do you think I track this as a follow up action? need more time to explore the possible jasmine setup.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it possible to close the connection with an error?

}
} else if strings.HasPrefix(chunk, "<") {
runes := []rune(chunk)
recChunk := string(runes[1:])
// Write binary message. (messageType=2)
err := conn.WriteMessage(2, []byte(recChunk))
if err != nil {
fmt.Printf("Error writing to websocket: %v\n", err)
return
}
} else {
fmt.Printf("Unreconginized chunk: %s", chunk)
return
}
}
}

func writeError(conn *websocket.Conn, errMsg string) {
closeMessage := websocket.FormatCloseMessage(
websocket.CloseInternalServerErr,
errMsg,
)
err := conn.WriteMessage(websocket.CloseMessage, closeMessage)
if err != nil {
fmt.Printf("Failed to write error: %v\n", err)
}
}

func (r *ReplayHTTPServer) upgradeConnectionToWebsocket(w http.ResponseWriter, req *http.Request) (*websocket.Conn, error) {
upgrader := websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true // Allow all origins
},
}

clientConn, err := upgrader.Upgrade(w, req, nil)
if err != nil {
return nil, err
}
return clientConn, err
}
19 changes: 17 additions & 2 deletions internal/store/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,26 @@ func readBody(req *http.Request) ([]byte, error) {
}

// ComputeSum computes the SHA256 sum of a RecordedRequest.
func (r *RecordedRequest) ComputeSum() (string, error) {
func (r *RecordedRequest) ComputeSum() string {
serialized := r.Serialize()
hash := sha256.Sum256([]byte(serialized))
hashHex := hex.EncodeToString(hash[:])
return hashHex, nil
return hashHex
}

// GetRecordingFileName returns the recording file name.
// It prefers the value from the TEST_NAME header.
// It returns error when test name contains illegal sequence.
// If the TEST_NAME header is not present, it falls back to computed SHA256 sum.
func (r *RecordedRequest) GetRecordingFileName() (string, error) {
testName := r.Header.Get("Test-Name")
if strings.Contains(testName, "../") {
return "", fmt.Errorf("test name: %s contains illegal sequence '../'", testName)
}
if testName != "" {
return testName, nil
}
return r.ComputeSum(), nil
}

// Serialize the request.
Expand Down
Loading
Loading