From ffcb2e16664543df34d6b9f159747bb7a63dd53b Mon Sep 17 00:00:00 2001 From: Abdul Rehman-z Date: Sun, 31 Aug 2025 22:48:06 +0500 Subject: [PATCH] feat: improve request handling by encapsulating body management and by smart buffer management plus updating header retrieval --- internal/headers/headers.go | 28 ++++++++++++-------- internal/request/request.go | 44 ++++++++++++-------------------- internal/request/request_test.go | 26 +++++++++---------- 3 files changed, 48 insertions(+), 50 deletions(-) diff --git a/internal/headers/headers.go b/internal/headers/headers.go index 2b47a64..dc595e2 100644 --- a/internal/headers/headers.go +++ b/internal/headers/headers.go @@ -3,8 +3,8 @@ package headers import ( "bytes" "fmt" - "log/slog" "sort" + "strconv" "strings" ) @@ -40,6 +40,12 @@ type Headers struct { headers map[string]string } +func NewHeaders() *Headers { + return &Headers{ + headers: map[string]string{}, + } +} + func (r *Headers) ForEach(cb func(n, v string)) { keys := make([]string, 0, len(r.headers)) for k := range r.headers { @@ -52,10 +58,18 @@ func (r *Headers) ForEach(cb func(n, v string)) { } } -func NewHeaders() *Headers { - return &Headers{ - headers: map[string]string{}, +func (h *Headers) GetInt(name string, defaultValue int) int { + value, exists := h.Get(name) + if !exists { + return defaultValue + } + + v, err := strconv.Atoi(value) + if err != nil { + return defaultValue } + + return v } func (h *Headers) Set(name, value string) { @@ -73,13 +87,10 @@ func (h *Headers) Get(name string) (string, bool) { return "", false } - // slog.Info("getHeader", "name", name, "value", v) - return v, ok } func (h *Headers) parseHeader(fieldLine []byte) (string, string, error) { - slog.Info("parseHeader", "fieldLine", string(fieldLine)) parts := bytes.SplitN(fieldLine, []byte(":"), 2) if len(parts) != 2 { return "", "", ERR_BAD_HEADER @@ -91,7 +102,6 @@ func (h *Headers) parseHeader(fieldLine []byte) (string, string, error) { return "", "", ERR_BAD_HEADER } - // slog.Info("header", "name", string(fieldName), "value", string(fieldValue)) return string(fieldName), string(fieldValue), nil } @@ -105,8 +115,6 @@ func (h *Headers) Parse(data []byte) (int, bool, error) { break } - slog.Info("parse header", "read", read) - // Empty header if idx == 0 { done = true diff --git a/internal/request/request.go b/internal/request/request.go index 8940451..675cd8d 100644 --- a/internal/request/request.go +++ b/internal/request/request.go @@ -4,8 +4,6 @@ import ( "bytes" "fmt" "io" - "log/slog" - "strconv" "github.com/merge/handly/internal/headers" ) @@ -33,7 +31,9 @@ type RequestLine struct { type Request struct { RequestLine RequestLine Headers *headers.Headers - Body string + body string + bodyBuffer []byte + bodyPos int state parsetState } @@ -41,27 +41,13 @@ func newRequest() *Request { return &Request{ Headers: headers.NewHeaders(), state: StateInit, - Body: "", } } -func getInt(h *headers.Headers, name string, defaultValue int) int { - value, exists := h.Get(name) - if !exists { - return defaultValue - } - - v, err := strconv.Atoi(value) - if err != nil { - return defaultValue - } - - return v -} - func (r *Request) hasBody() bool { - length := getInt(r.Headers, "content-length", 0) + length := r.Headers.GetInt("content-length", 0) return length > 0 + } func (r *Request) Parse(data []byte) (int, error) { @@ -105,28 +91,34 @@ outer: } read += n - slog.Info("StateHeader", "read", read) if done { if r.hasBody() { + length := r.Headers.GetInt("content-length", 0) + + r.bodyBuffer = make([]byte, length) + r.bodyPos = 0 r.state = StateBody } else { r.state = StateDone } } + case StateBody: - length := getInt(r.Headers, "content-length", 0) + length := r.Headers.GetInt("content-length", 0) if length == 0 { r.state = StateDone break outer } - slog.Info("StateBody", "length-leb(r.body)", length-len(r.Body), "length currentData", len(currentData), "currentData", currentData, "read", read) - remaining := min(length-len(r.Body), len(currentData)) - r.Body += string(currentData[:remaining]) + remaining := min(length-r.bodyPos, len(currentData)) + copy(r.bodyBuffer[r.bodyPos:], currentData[:remaining]) + + r.bodyPos += remaining read += remaining - if len(r.Body) == length { + if r.bodyPos == length { + r.body = string(r.bodyBuffer) // Single conversion at end r.state = StateDone } case StateDone: @@ -147,7 +139,6 @@ func parseRequestLine(data []byte) (*RequestLine, int, error) { return nil, 0, nil } - slog.Info("parseRequestLine", "data", string(data)) startOfLine := data[:idx] read := idx + len(SEPARATOR) @@ -181,7 +172,6 @@ func RequestFromReader(r io.Reader) (*Request, error) { bufLen += n readN, err := request.Parse(buf[:bufLen]) - slog.Info("RequestFromHeader", "readN", readN, "bufLen", bufLen) if err != nil { return nil, err } diff --git a/internal/request/request_test.go b/internal/request/request_test.go index 82b6842..7bb3e87 100644 --- a/internal/request/request_test.go +++ b/internal/request/request_test.go @@ -111,17 +111,17 @@ func TestParseBody(t *testing.T) { require.NoError(t, err) require.NotNil(t, r) - assert.Equal(t, "hello world!\n", string(r.Body)) - - // // Test: Body shorter than reported content length - // reader = &ChunkReader{ - // data: "POST /submit HTTP/1.1\r\n" + - // "Host: localhost:3000\r\n" + - // "Content-Length: 20\r\n" + - // "\r\n" + - // "partial content", - // numOfBytesPerRead: 3, - // } - // r, err = RequestFromReader(reader) - // require.Error(t, err) + assert.Equal(t, "hello world!\n", string(r.body)) + + // Test: Body shorter than reported content length + reader = &ChunkReader{ + data: "POST /submit HTTP/1.1\r\n" + + "Host: localhost:3000\r\n" + + "Content-Length: 20\r\n" + + "\r\n" + + "partial content", + numOfBytesPerRead: 3, + } + r, err = RequestFromReader(reader) + require.Error(t, err) }