diff --git a/AGENTS.md b/AGENTS.md index 09a1b45..544261c 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -72,6 +72,8 @@ prettier -w . 5. HTTP client executes request 6. Response formatted based on Content-Type and output to stdout (optionally via pager) +Retryable requests replay bodies by calling `req.GetBody` when available, reopening file-backed bodies directly when possible, and only spooling the original body to a temp file as a final fallback for one-shot streams. This avoids holding large uploads in memory and keeps retries working for closable bodies like `*os.File`. + ### Content Type Detection `internal/fetch/fetch.go:getContentType()` maps MIME types to formatters. Supported types include JSON, XML, YAML, HTML, CSS, CSV, msgpack, protobuf, gRPC, SSE, NDJSON, and images. diff --git a/internal/fetch/fetch.go b/internal/fetch/fetch.go index 4e440ee..f86c52f 100644 --- a/internal/fetch/fetch.go +++ b/internal/fetch/fetch.go @@ -32,6 +32,14 @@ import ( // formatting a response body or copying it to the clipboard. const maxBodyBytes = 1 << 20 // 1MiB +func setReplayableBody(req *http.Request, data []byte) { + req.Body = io.NopCloser(bytes.NewReader(data)) + req.ContentLength = int64(len(data)) + req.GetBody = func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil + } +} + type Request struct { AWSSigv4 *aws.Config Basic *core.KeyVal[string] @@ -213,13 +221,13 @@ func fetch(ctx context.Context, r *Request) (int, error) { if err != nil { return 0, err } - req.Body = io.NopCloser(converted) + setReplayableBody(req, converted) } framed, err := frameGRPCRequest(req.Body) if err != nil { return 0, err } - req.Body = io.NopCloser(framed) + setReplayableBody(req, framed) } } diff --git a/internal/fetch/proto.go b/internal/fetch/proto.go index b5666dd..5db3d25 100644 --- a/internal/fetch/proto.go +++ b/internal/fetch/proto.go @@ -1,7 +1,6 @@ package fetch import ( - "bytes" "encoding/json" "fmt" "io" @@ -107,7 +106,7 @@ func setupGRPC(r *Request, schema *proto.Schema) (protoreflect.MessageDescriptor } // convertJSONToProtobuf converts JSON body to protobuf. -func convertJSONToProtobuf(data io.Reader, desc protoreflect.MessageDescriptor) (io.Reader, error) { +func convertJSONToProtobuf(data io.Reader, desc protoreflect.MessageDescriptor) ([]byte, error) { // Read all the JSON data. jsonData, err := io.ReadAll(data) if err != nil { @@ -120,12 +119,12 @@ func convertJSONToProtobuf(data io.Reader, desc protoreflect.MessageDescriptor) return nil, fmt.Errorf("failed to convert JSON to protobuf: %w", err) } - return bytes.NewReader(protoData), nil + return protoData, nil } // frameGRPCRequest wraps data in gRPC framing. // Handles nil/empty body by sending an empty framed message. -func frameGRPCRequest(data io.Reader) (io.Reader, error) { +func frameGRPCRequest(data io.Reader) ([]byte, error) { var rawData []byte if data != nil && data != http.NoBody { var err error @@ -137,7 +136,7 @@ func frameGRPCRequest(data io.Reader) (io.Reader, error) { // Frame with gRPC format (works for empty data too). framedData := fetchgrpc.Frame(rawData, false) - return bytes.NewReader(framedData), nil + return framedData, nil } // streamGRPCRequest reads JSON objects from data, converts each to protobuf, diff --git a/internal/fetch/retry.go b/internal/fetch/retry.go index 1c16f55..32dda99 100644 --- a/internal/fetch/retry.go +++ b/internal/fetch/retry.go @@ -1,7 +1,6 @@ package fetch import ( - "bytes" "context" "errors" "fmt" @@ -11,6 +10,7 @@ import ( "net/http" "net/http/httptrace" "net/url" + "os" "strconv" "time" @@ -31,6 +31,9 @@ func retryableRequest(ctx context.Context, r *Request, c *client.Client, req *ht if err != nil { return 0, err } + if replayer != nil { + defer replayer.close() + } } var hadRedirects bool @@ -246,10 +249,11 @@ func sleepWithContext(ctx context.Context, d time.Duration) error { } } -// replayableBody allows a request body to be replayed across retry attempts. +// replayableBody reopens a request body for each retry attempt. type replayableBody struct { - seeker io.ReadSeeker - data []byte + open func() (io.ReadCloser, error) + cleanup func() error + tempPath string } // newReplayableBody creates a replayableBody from the request's current body. @@ -259,31 +263,113 @@ func newReplayableBody(req *http.Request) (*replayableBody, error) { return nil, nil } - // If the body supports seeking, use it directly. - if rs, ok := req.Body.(io.ReadSeeker); ok { - return &replayableBody{seeker: rs}, nil + if req.GetBody != nil { + if err := req.Body.Close(); err != nil { + return nil, err + } + return &replayableBody{open: req.GetBody}, nil + } + + if f, ok := req.Body.(*os.File); ok && f != os.Stdin { + offset, err := f.Seek(0, io.SeekCurrent) + if err != nil { + return nil, err + } + path := f.Name() + if err := f.Close(); err != nil { + return nil, err + } + return &replayableBody{ + open: func() (io.ReadCloser, error) { + reopened, err := os.Open(path) + if err != nil { + return nil, err + } + if offset != 0 { + if _, err := reopened.Seek(offset, io.SeekStart); err != nil { + reopened.Close() + return nil, err + } + } + return reopened, nil + }, + }, nil + } + + if rs, ok := req.Body.(io.ReadSeeker); ok && req.Body != os.Stdin { + var cleanup func() error + if closer, ok := req.Body.(io.Closer); ok { + cleanup = closer.Close + } + return &replayableBody{ + open: func() (io.ReadCloser, error) { + if _, err := rs.Seek(0, io.SeekStart); err != nil { + return nil, err + } + return nopReadCloser{Reader: rs}, nil + }, + cleanup: cleanup, + }, nil } - // Otherwise, read the entire body into memory. - data, err := io.ReadAll(req.Body) + tmp, err := os.CreateTemp("", "fetch-retry-body-*") if err != nil { return nil, err } - req.Body.Close() - return &replayableBody{data: data}, nil + tmpPath := tmp.Name() + cleanup := func() error { + return os.Remove(tmpPath) + } + + _, copyErr := io.Copy(tmp, req.Body) + closeErr := req.Body.Close() + if copyErr != nil { + tmp.Close() + cleanup() + return nil, copyErr + } + if closeErr != nil { + tmp.Close() + cleanup() + return nil, closeErr + } + if err := tmp.Close(); err != nil { + cleanup() + return nil, err + } + + return &replayableBody{ + open: func() (io.ReadCloser, error) { + return os.Open(tmpPath) + }, + cleanup: cleanup, + tempPath: tmpPath, + }, nil } // reset returns a fresh io.ReadCloser for the next attempt. func (rb *replayableBody) reset() (io.ReadCloser, error) { - if rb.seeker != nil { - if _, err := rb.seeker.Seek(0, io.SeekStart); err != nil { - return nil, err - } - return io.NopCloser(rb.seeker), nil + if rb == nil { + return nil, nil } - return io.NopCloser(bytes.NewReader(rb.data)), nil + return rb.open() } +func (rb *replayableBody) close() error { + if rb == nil || rb.cleanup == nil { + return nil + } + err := rb.cleanup() + rb.cleanup = nil + return err +} + +type nopReadCloser struct { + io.Reader +} + +func (nopReadCloser) Close() error { return nil } + // retryReason returns a human-readable reason for the retry. func retryReason(resp *http.Response, err error) string { if err != nil { diff --git a/internal/fetch/retry_test.go b/internal/fetch/retry_test.go index b5aed1b..3a9eca6 100644 --- a/internal/fetch/retry_test.go +++ b/internal/fetch/retry_test.go @@ -9,6 +9,7 @@ import ( "net" "net/http" "net/url" + "os" "strings" "testing" "time" @@ -288,15 +289,19 @@ func TestSleepWithContext(t *testing.T) { } func TestReplayableBody(t *testing.T) { - t.Run("seekable body", func(t *testing.T) { - body := &readSeekCloser{Reader: bytes.NewReader([]byte("hello"))} - req := &http.Request{Body: body} + t.Run("getbody body", func(t *testing.T) { + req := &http.Request{ + Body: io.NopCloser(bytes.NewReader([]byte("hello"))), + GetBody: func() (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader([]byte("hello"))), nil + }, + } rb, err := newReplayableBody(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if rb.seeker == nil { - t.Fatal("expected seeker path to be used for ReadSeeker body") + if rb.tempPath != "" { + t.Fatal("expected GetBody path to avoid temp spool") } for range 3 { @@ -311,20 +316,33 @@ func TestReplayableBody(t *testing.T) { if string(data) != "hello" { t.Errorf("expected 'hello', got '%s'", data) } + rc.Close() } }) - t.Run("buffered body", func(t *testing.T) { - body := bytes.NewReader([]byte("hello")) - req := &http.Request{Body: io.NopCloser(body)} - // bytes.Reader wrapped in NopCloser is not a ReadSeeker, - // so it will be read into memory. + t.Run("closable seekable body", func(t *testing.T) { + f, err := os.CreateTemp(t.TempDir(), "body-*") + if err != nil { + t.Fatalf("create temp file: %v", err) + } + if _, err := f.WriteString("hello"); err != nil { + t.Fatalf("write temp file: %v", err) + } + if _, err := f.Seek(0, io.SeekStart); err != nil { + t.Fatalf("seek temp file: %v", err) + } + + req := &http.Request{Body: f} rb, err := newReplayableBody(req) if err != nil { t.Fatalf("unexpected error: %v", err) } - if rb.seeker != nil { - t.Fatal("expected buffered path to be used for non-ReadSeeker body") + defer rb.close() + if rb.tempPath != "" { + t.Fatal("expected file-backed body to replay without temp spool") + } + if _, err := f.Read(make([]byte, 1)); !isClosedFileErr(err) { + t.Fatalf("expected original file to be closed, got %v", err) } for range 3 { @@ -339,28 +357,46 @@ func TestReplayableBody(t *testing.T) { if string(data) != "hello" { t.Errorf("expected 'hello', got '%s'", data) } + if err := rc.Close(); err != nil { + t.Fatalf("close error: %v", err) + } } }) - t.Run("non-seekable body", func(t *testing.T) { - body := io.NopCloser(strings.NewReader("world")) + t.Run("large streamed body", func(t *testing.T) { + const size = 8 << 20 + body := &streamingReadCloser{remaining: size, fill: 'x'} req := &http.Request{Body: body} rb, err := newReplayableBody(req) if err != nil { t.Fatalf("unexpected error: %v", err) } + defer rb.close() + if rb.tempPath == "" { + t.Fatal("expected temp spool for streamed body") + } + info, err := os.Stat(rb.tempPath) + if err != nil { + t.Fatalf("stat temp spool: %v", err) + } + if info.Size() != size { + t.Fatalf("expected temp spool size %d, got %d", size, info.Size()) + } for range 3 { rc, err := rb.reset() if err != nil { t.Fatalf("reset error: %v", err) } - data, err := io.ReadAll(rc) + n, err := io.Copy(io.Discard, rc) if err != nil { t.Fatalf("read error: %v", err) } - if string(data) != "world" { - t.Errorf("expected 'world', got '%s'", data) + if n != size { + t.Fatalf("expected %d bytes, got %d", size, n) + } + if err := rc.Close(); err != nil { + t.Fatalf("close error: %v", err) } } }) @@ -388,9 +424,34 @@ func TestReplayableBody(t *testing.T) { }) } -// readSeekCloser wraps a bytes.Reader to implement io.ReadSeeker and io.ReadCloser. -type readSeekCloser struct { - *bytes.Reader +func isClosedFileErr(err error) bool { + return err != nil && strings.Contains(err.Error(), "file already closed") } -func (r *readSeekCloser) Close() error { return nil } +type streamingReadCloser struct { + remaining int64 + fill byte + closed bool +} + +func (r *streamingReadCloser) Read(p []byte) (int, error) { + if r.closed { + return 0, os.ErrClosed + } + if r.remaining == 0 { + return 0, io.EOF + } + if int64(len(p)) > r.remaining { + p = p[:r.remaining] + } + for i := range p { + p[i] = r.fill + } + r.remaining -= int64(len(p)) + return len(p), nil +} + +func (r *streamingReadCloser) Close() error { + r.closed = true + return nil +}