diff --git a/README.md b/README.md index ee4f51e..f14132e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ +
json, query, header, and form tags route input automatically.>,
},
{
icon: ,
diff --git a/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx b/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx
index d7bb648..604a6b5 100644
--- a/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx
+++ b/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx
@@ -96,7 +96,7 @@ Error types merge across all three levels. In the example above, `GET /api/v1/us
## Parse errors
-When the framework cannot parse the request (malformed JSON, invalid query parameter types, invalid form data), it returns a `400 Bad Request`. By default the response body is `{"message": "bad request"}`.
+When the framework cannot parse the request (malformed JSON, invalid query parameter types, invalid header values, invalid form data), it returns a `400 Bad Request`. By default the response body is `{"message": "bad request"}`.
### Customizing the 400 response
diff --git a/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx b/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx
index e14fa0f..af164bb 100644
--- a/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx
+++ b/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx
@@ -28,8 +28,8 @@ func handler(r *http.Request, in InputType) (OutputType, error) {
}
```
-- **`r *http.Request`** — the standard HTTP request, useful for headers, cookies, path parameters, etc.
-- **`in InputType`** — automatically decoded from JSON body, query parameters, or multipart form data
+- **`r *http.Request`** — the standard HTTP request, useful for cookies, path parameters, and other request metadata
+- **`in InputType`** — automatically decoded from JSON body, query parameters, HTTP headers, or multipart form data
- **`OutputType`** — serialized as JSON in the response (can be a pointer or value type)
- **`error`** — return an error to send an error response
@@ -50,17 +50,21 @@ Fields are decoded based on struct tags:
| `path` | URL path parameters | `path:"id"` |
| `json` | Request body (JSON) | `json:"name"` |
| `query` | Query parameters | `query:"page"` |
+| `header` | HTTP headers | `header:"Authorization"` |
| `form` | Multipart form data | `form:"file"` |
-You can combine tags in a single struct:
+You can combine tags in a single struct — `path`, `query`, `header`, and `json` fields can be freely mixed. The only restriction is that `form` and `json` tags cannot coexist on the same struct.
```go
type UpdateUser struct {
- ID int `path:"id"`
- Name string `json:"name"`
+ ID int `path:"id"`
+ Token string `header:"Authorization"`
+ Name string `json:"name"`
}
```
+Header and query fields support `string`, `bool`, `int*`, `uint*`, `float*` scalars and `*T` pointers for optional values. Query fields also support `[]T` slices for repeated params. Slices are not supported for headers. Parse errors return `400`; validation failures return `422`.
+
## Path parameters
Use `path` tags to declare typed path parameters. The field name in the tag must match a `{param}` in the route pattern:
diff --git a/apps/landing/src/content/docs/docs/core-concepts/validation.mdx b/apps/landing/src/content/docs/docs/core-concepts/validation.mdx
index fd130bd..8022c39 100644
--- a/apps/landing/src/content/docs/docs/core-concepts/validation.mdx
+++ b/apps/landing/src/content/docs/docs/core-concepts/validation.mdx
@@ -74,6 +74,18 @@ type UserResponse struct {
}
```
+## Validation on query and header fields
+
+Validation rules work on all input sources — `json`, `query`, `header`, and `form` fields. For example, `validate:"required"` on a `header`-tagged field returns `422` if the header is present but empty, and `validate:"oneof=json xml"` constrains accepted values. Parse errors (e.g. sending `"abc"` for an `int` header) return `400`; validation failures return `422`.
+
+```go
+type AuthSearch struct {
+ Token string `header:"Authorization" validate:"required"`
+ Format string `header:"Accept" validate:"oneof=json xml"`
+ Q string `query:"q" validate:"required"`
+}
+```
+
## Custom validator
You can provide your own validator instance with custom rules:
diff --git a/apps/landing/src/content/docs/docs/getting-started/introduction.mdx b/apps/landing/src/content/docs/docs/getting-started/introduction.mdx
index 1b98438..5257920 100644
--- a/apps/landing/src/content/docs/docs/getting-started/introduction.mdx
+++ b/apps/landing/src/content/docs/docs/getting-started/introduction.mdx
@@ -19,6 +19,7 @@ ShiftAPI is a Go framework that generates an OpenAPI 3.1 spec from your handler
- **Type-safe errors** — declare custom error types at the API, group, or route level with `WithError[T](status)` — schemas appear in the OpenAPI spec and errors are matched at runtime via `errors.As`
- **Middleware** — apply standard `func(http.Handler) http.Handler` middleware at the API, group, or route level with `WithMiddleware` — middleware resolves from API (outermost) → Group → Route (innermost)
- **Composable options** — bundle middleware, errors, and other options into reusable `Option` values with `ComposeOptions`
+- **Typed HTTP headers** — parse, validate, and document HTTP headers with `header` struct tags
- **File uploads** — declare uploads with `form` tags, get correct `multipart/form-data` types
- **Interactive docs** — Scalar API reference at `/docs` and OpenAPI spec at `/openapi.json`
- **Just `net/http`** — implements `http.Handler`, works with any middleware or router
diff --git a/form.go b/form.go
index cd99f84..2e6245b 100644
--- a/form.go
+++ b/form.go
@@ -154,3 +154,5 @@ func (e *formParseError) Error() string {
}
return fmt.Sprintf("invalid form field %q: %v", e.Field, e.Err)
}
+
+func (e *formParseError) Unwrap() error { return e.Err }
diff --git a/handler.go b/handler.go
index f93c4fd..537516d 100644
--- a/handler.go
+++ b/handler.go
@@ -15,17 +15,18 @@ import (
// The In struct's fields are discriminated by struct tags:
// - path:"name" — parsed from URL path parameters (e.g. /users/{id})
// - query:"name" — parsed from URL query parameters
+// - header:"name" — parsed from HTTP headers
// - json:"name" — parsed from the JSON request body (default for POST/PUT/PATCH)
// - form:"name" — parsed from multipart/form-data (for file uploads)
//
// Use struct{} as In for routes that take no input, or as Resp for routes
// that return no body (e.g. health checks that only need a status code).
//
-// The [*http.Request] parameter gives access to headers, cookies, path
+// The [*http.Request] parameter gives access to cookies, path
// parameters, and other request metadata.
type HandlerFunc[In, Resp any] func(r *http.Request, in In) (Resp, error)
-func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasPath, hasQuery, hasBody, hasForm bool, maxUploadSize int64, errLookup errorLookup, badRequestFn, internalServerFn func(error) any) http.HandlerFunc {
+func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasPath, hasQuery, hasHeader, hasBody, hasForm bool, maxUploadSize int64, errLookup errorLookup, badRequestFn, internalServerFn func(error) any) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var in In
rv := reflect.ValueOf(&in).Elem()
@@ -56,6 +57,12 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any
}
}
+ // Reset any header-tagged fields that body decode may have
+ // inadvertently set, so they only come from HTTP headers.
+ if hasBody && hasHeader {
+ resetHeaderFields(rv)
+ }
+
// Parse query params if there are query fields
if hasQuery {
if err := parseQueryInto(rv, r.URL.Query()); err != nil {
@@ -72,6 +79,14 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any
}
}
+ // Parse headers if there are header fields
+ if hasHeader {
+ if err := parseHeadersInto(rv, r.Header); err != nil {
+ writeJSON(w, http.StatusBadRequest, badRequestFn(err))
+ return
+ }
+ }
+
if err := validate(in); err != nil {
handleError(w, internalServerFn, err, errLookup)
return
diff --git a/handlerFuncs.go b/handlerFuncs.go
index be96a29..1ed2b84 100644
--- a/handlerFuncs.go
+++ b/handlerFuncs.go
@@ -28,7 +28,7 @@ func registerRoute[In, Resp any](
rawInType = rawInType.Elem()
}
- hasPath, hasQuery, hasBody, hasForm := partitionFields(rawInType)
+ hasPath, hasQuery, hasHeader, hasBody, hasForm := partitionFields(rawInType)
// Validate path-tagged fields match the route pattern.
if hasPath {
@@ -44,6 +44,10 @@ func registerRoute[In, Resp any](
if hasQuery {
queryType = rawInType
}
+ var headerType reflect.Type
+ if hasHeader {
+ headerType = rawInType
+ }
// POST/PUT/PATCH conventionally carry a request body, so always attempt
// body decode for these methods — even when the input is struct{}.
// This means Post(api, path, func(r, _ struct{}) ...) requires at least "{}".
@@ -71,14 +75,14 @@ func registerRoute[In, Resp any](
pathType = rawInType
}
- if err := api.updateSchema(method, fullPath, pathType, queryType, bodyType, outType, hasForm, rawInType, cfg.info, cfg.status, allErrors); err != nil {
+ if err := api.updateSchema(method, fullPath, pathType, queryType, headerType, bodyType, outType, hasForm, rawInType, cfg.info, cfg.status, allErrors); err != nil {
panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, fullPath, err))
}
errLookup := buildErrorLookup(allErrors)
pattern := fmt.Sprintf("%s %s", method, fullPath)
- var h http.Handler = adapt(fn, cfg.status, api.validateBody, hasPath, hasQuery, decodeBody, hasForm, api.maxUploadSize, errLookup, api.badRequestFn, api.internalServerFn)
+ var h http.Handler = adapt(fn, cfg.status, api.validateBody, hasPath, hasQuery, hasHeader, decodeBody, hasForm, api.maxUploadSize, errLookup, api.badRequestFn, api.internalServerFn)
// Apply route-level middleware (innermost), then group-level (outermost).
// Reverse order so the first middleware in the slice wraps outermost.
for i := len(cfg.middleware) - 1; i >= 0; i-- {
diff --git a/header.go b/header.go
new file mode 100644
index 0000000..1f6aac3
--- /dev/null
+++ b/header.go
@@ -0,0 +1,106 @@
+package shiftapi
+
+import (
+ "fmt"
+ "net/http"
+ "reflect"
+ "strings"
+)
+
+// hasHeaderTag returns true if the struct field has a `header` tag.
+func hasHeaderTag(f reflect.StructField) bool {
+ return f.Tag.Get("header") != ""
+}
+
+// headerFieldName returns the header name for a struct field.
+func headerFieldName(f reflect.StructField) string {
+ name, _, _ := strings.Cut(f.Tag.Get("header"), ",")
+ if name == "" {
+ return f.Name
+ }
+ return name
+}
+
+// resetHeaderFields zeros out any header-tagged fields on a struct value.
+// This is called after body decode so that header-tagged fields are only
+// populated by parseHeadersInto, not by JSON keys that happen to match.
+func resetHeaderFields(rv reflect.Value) {
+ for rv.Kind() == reflect.Pointer {
+ rv = rv.Elem()
+ }
+ if rv.Kind() != reflect.Struct {
+ return
+ }
+ rt := rv.Type()
+ for i := range rt.NumField() {
+ f := rt.Field(i)
+ if f.IsExported() && hasHeaderTag(f) {
+ rv.Field(i).SetZero()
+ }
+ }
+}
+
+// parseHeadersInto populates header-tagged fields on an existing struct value
+// from HTTP headers. Non-header fields are left untouched.
+// Only scalar types and pointer-to-scalar types are supported (no slices).
+func parseHeadersInto(rv reflect.Value, header http.Header) error {
+ for rv.Kind() == reflect.Pointer {
+ if rv.IsNil() {
+ rv.Set(reflect.New(rv.Type().Elem()))
+ }
+ rv = rv.Elem()
+ }
+
+ rt := rv.Type()
+ if rt.Kind() != reflect.Struct {
+ return fmt.Errorf("header type must be a struct, got %s", rt.Kind())
+ }
+
+ for i := range rt.NumField() {
+ field := rt.Field(i)
+ if !field.IsExported() || !hasHeaderTag(field) {
+ continue
+ }
+
+ name := headerFieldName(field)
+ fv := rv.Field(i)
+ ft := field.Type
+
+ // Handle pointer fields (optional headers)
+ if ft.Kind() == reflect.Pointer {
+ raw := header.Get(name)
+ if raw == "" {
+ continue
+ }
+ ptr := reflect.New(ft.Elem())
+ if err := setScalarValue(ptr.Elem(), raw); err != nil {
+ return &headerParseError{Field: name, Err: err}
+ }
+ fv.Set(ptr)
+ continue
+ }
+
+ // Handle scalar fields
+ raw := header.Get(name)
+ if raw == "" {
+ continue
+ }
+ if err := setScalarValue(fv, raw); err != nil {
+ return &headerParseError{Field: name, Err: err}
+ }
+ }
+
+ return nil
+}
+
+// headerParseError is returned when a header value cannot be parsed.
+type headerParseError struct {
+ Field string
+ Err error
+}
+
+func (e *headerParseError) Error() string {
+ return fmt.Sprintf("invalid header %q: %v", e.Field, e.Err)
+}
+
+func (e *headerParseError) Unwrap() error { return e.Err }
diff --git a/path.go b/path.go
index c7b72e1..92896ea 100644
--- a/path.go
+++ b/path.go
@@ -109,3 +109,5 @@ type pathParseError struct {
func (e *pathParseError) Error() string {
return fmt.Sprintf("invalid path parameter %q: %v", e.Field, e.Err)
}
+
+func (e *pathParseError) Unwrap() error { return e.Err }
diff --git a/query.go b/query.go
index a6dd3ed..860baf7 100644
--- a/query.go
+++ b/query.go
@@ -14,14 +14,15 @@ func hasQueryTag(f reflect.StructField) bool {
}
// partitionFields inspects a struct type and reports whether it contains
-// path-tagged, query-tagged, body (json-tagged or untagged non-path/query) fields,
-// and/or form-tagged fields. It panics if both body and form fields are present.
-func partitionFields(t reflect.Type) (hasPath, hasQuery, hasBody, hasForm bool) {
+// path-tagged, query-tagged, header-tagged, body (json-tagged or untagged
+// non-path/query/header) fields, and/or form-tagged fields. It panics if both
+// body and form fields are present.
+func partitionFields(t reflect.Type) (hasPath, hasQuery, hasHeader, hasBody, hasForm bool) {
for t.Kind() == reflect.Pointer {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
- return false, false, false, false
+ return false, false, false, false, false
}
for f := range t.Fields() {
if !f.IsExported() {
@@ -31,10 +32,12 @@ func partitionFields(t reflect.Type) (hasPath, hasQuery, hasBody, hasForm bool)
hasPath = true
} else if hasQueryTag(f) {
hasQuery = true
+ } else if hasHeaderTag(f) {
+ hasHeader = true
} else if hasFormTag(f) {
hasForm = true
} else {
- // Any exported field without a path, query, or form tag is a body field
+ // Any exported field without a path, query, header, or form tag is a body field
jsonTag := f.Tag.Get("json")
if jsonTag == "-" {
continue
@@ -178,7 +181,7 @@ func setScalarValue(v reflect.Value, raw string) error {
}
v.SetFloat(n)
default:
- return fmt.Errorf("unsupported query parameter type %s", v.Kind())
+ return fmt.Errorf("unsupported parameter type %s", v.Kind())
}
return nil
}
@@ -192,3 +195,5 @@ type queryParseError struct {
func (e *queryParseError) Error() string {
return fmt.Sprintf("invalid query parameter %q: %v", e.Field, e.Err)
}
+
+func (e *queryParseError) Unwrap() error { return e.Err }
diff --git a/schema.go b/schema.go
index 800aea5..c8a93ad 100644
--- a/schema.go
+++ b/schema.go
@@ -12,7 +12,7 @@ import (
var pathParamRe = regexp.MustCompile(`\{([^}]+)\}`)
-func (a *API) updateSchema(method, path string, pathType, queryType, inType, outType reflect.Type, hasForm bool, formType reflect.Type, info *RouteInfo, status int, errors []errorEntry) error {
+func (a *API) updateSchema(method, path string, pathType, queryType, headerType, inType, outType reflect.Type, hasForm bool, formType reflect.Type, info *RouteInfo, status int, errors []errorEntry) error {
op := &openapi3.Operation{
OperationID: operationID(method, path),
Responses: openapi3.NewResponses(),
@@ -67,6 +67,15 @@ func (a *API) updateSchema(method, path string, pathType, queryType, inType, out
op.Parameters = append(op.Parameters, queryParams...)
}
+ // Header parameters
+ if headerType != nil {
+ headerParams, err := a.generateHeaderParams(headerType)
+ if err != nil {
+ return err
+ }
+ op.Parameters = append(op.Parameters, headerParams...)
+ }
+
// Response schema
statusStr := fmt.Sprintf("%d", status)
outSchema, err := a.generateSchemaRef(outType)
@@ -140,8 +149,9 @@ func (a *API) updateSchema(method, path string, pathType, queryType, inType, out
return err
}
if inSchema != nil {
- // Strip query-tagged and path-tagged fields from the body schema
+ // Strip query-tagged, header-tagged, and path-tagged fields from the body schema
stripQueryFields(inType, inSchema.Value)
+ stripHeaderFields(inType, inSchema.Value)
stripPathFields(inType, inSchema.Value)
if len(inSchema.Value.Properties) > 0 {
@@ -457,6 +467,73 @@ func stripQueryFields(t reflect.Type, schema *openapi3.Schema) {
}
}
+// generateHeaderParams produces OpenAPI parameter definitions for a header struct type.
+// Only fields with `header` tags are included. Slices are not supported for headers.
+func (a *API) generateHeaderParams(t reflect.Type) ([]*openapi3.ParameterRef, error) {
+ for t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ }
+ if t.Kind() != reflect.Struct {
+ return nil, fmt.Errorf("header type must be a struct, got %s", t.Kind())
+ }
+
+ var params []*openapi3.ParameterRef
+ for field := range t.Fields() {
+ if !field.IsExported() {
+ continue
+ }
+ if !hasHeaderTag(field) {
+ continue
+ }
+ name := http.CanonicalHeaderKey(headerFieldName(field))
+ schema := scalarToOpenAPISchema(field.Type)
+
+ // Apply validation constraints
+ if err := validateSchemaCustomizer(name, field.Type, field.Tag, schema.Value); err != nil {
+ return nil, err
+ }
+
+ required := hasRule(field.Tag.Get("validate"), "required")
+
+ params = append(params, &openapi3.ParameterRef{
+ Value: &openapi3.Parameter{
+ Name: name,
+ In: "header",
+ Required: required,
+ Schema: schema,
+ },
+ })
+ }
+ return params, nil
+}
+
+// stripHeaderFields removes header-tagged fields from a body schema's Properties and Required.
+func stripHeaderFields(t reflect.Type, schema *openapi3.Schema) {
+ for t.Kind() == reflect.Pointer {
+ t = t.Elem()
+ }
+ if t.Kind() != reflect.Struct || schema == nil {
+ return
+ }
+ for f := range t.Fields() {
+ if !f.IsExported() || !hasHeaderTag(f) {
+ continue
+ }
+ jname := jsonFieldName(f)
+ if jname == "" || jname == "-" {
+ continue
+ }
+ delete(schema.Properties, jname)
+ // Remove from Required slice
+ for j, req := range schema.Required {
+ if req == jname {
+ schema.Required = append(schema.Required[:j], schema.Required[j+1:]...)
+ break
+ }
+ }
+ }
+}
+
// stripPathFields removes path-tagged fields from a body schema's Properties and Required.
func stripPathFields(t reflect.Type, schema *openapi3.Schema) {
for t.Kind() == reflect.Pointer {
diff --git a/shiftapi_test.go b/shiftapi_test.go
index b70748e..ec1bc33 100644
--- a/shiftapi_test.go
+++ b/shiftapi_test.go
@@ -4111,3 +4111,497 @@ func TestSpecPathParamExcludedFromBody(t *testing.T) {
t.Error("path field 'ID' should not appear in body schema")
}
}
+
+// --- Header parameter test types ---
+
+type AuthHeader struct {
+ Token string `header:"Authorization" validate:"required"`
+}
+
+type AuthResult struct {
+ Token string `json:"token"`
+}
+
+type OptionalHeader struct {
+ Debug *bool `header:"X-Debug"`
+ Limit *int `header:"X-Limit"`
+}
+
+type OptionalHeaderResult struct {
+ HasDebug bool `json:"has_debug"`
+ Debug bool `json:"debug"`
+ HasLimit bool `json:"has_limit"`
+ Limit int `json:"limit"`
+}
+
+type HeaderAndBody struct {
+ Token string `header:"Authorization" validate:"required"`
+ Name string `json:"name" validate:"required"`
+}
+
+type HeaderAndBodyResult struct {
+ Token string `json:"token"`
+ Name string `json:"name"`
+}
+
+type HeaderAndQuery struct {
+ Token string `header:"Authorization" validate:"required"`
+ Q string `query:"q"`
+}
+
+type HeaderAndQueryResult struct {
+ Token string `json:"token"`
+ Q string `json:"q"`
+}
+
+type ScalarHeaders struct {
+ Flag bool `header:"X-Flag"`
+ Count uint `header:"X-Count"`
+ Score float64 `header:"X-Score"`
+}
+
+type ScalarHeaderResult struct {
+ Flag bool `json:"flag"`
+ Count uint `json:"count"`
+ Score float64 `json:"score"`
+}
+
+// --- Header parameter test helpers ---
+
+func doRequestWithHeaders(t *testing.T, api http.Handler, method, path, body string, headers map[string]string) *http.Response {
+ t.Helper()
+ var bodyReader io.Reader
+ if body != "" {
+ bodyReader = strings.NewReader(body)
+ }
+ req := httptest.NewRequest(method, path, bodyReader)
+ if body != "" {
+ req.Header.Set("Content-Type", "application/json")
+ }
+ for k, v := range headers {
+ req.Header.Set(k, v)
+ }
+ rec := httptest.NewRecorder()
+ api.ServeHTTP(rec, req)
+ return rec.Result()
+}
+
+// --- Header parameter runtime tests ---
+
+func TestGetWithHeaderBasic(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) {
+ return &AuthResult{Token: in.Token}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/auth", "", map[string]string{
+ "Authorization": "Bearer abc123",
+ })
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+ result := decodeJSON[AuthResult](t, resp)
+ if result.Token != "Bearer abc123" {
+ t.Errorf("expected Token=%q, got %q", "Bearer abc123", result.Token)
+ }
+}
+
+func TestGetWithHeaderMissingRequired(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) {
+ return &AuthResult{Token: in.Token}, nil
+ })
+
+ // Missing required "Authorization" header
+ resp := doRequest(t, api, http.MethodGet, "/auth", "")
+ if resp.StatusCode != http.StatusUnprocessableEntity {
+ t.Fatalf("expected 422, got %d", resp.StatusCode)
+ }
+}
+
+func TestGetWithHeaderInvalidType(t *testing.T) {
+ api := newTestAPI(t)
+ type IntHeader struct {
+ Count int `header:"X-Count" validate:"required"`
+ }
+ shiftapi.Get(api, "/count", func(r *http.Request, in IntHeader) (*Status, error) {
+ return &Status{OK: true}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/count", "", map[string]string{
+ "X-Count": "notanumber",
+ })
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Fatalf("expected 400, got %d", resp.StatusCode)
+ }
+}
+
+func TestGetWithHeaderOptionalPointers(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalHeader) (*OptionalHeaderResult, error) {
+ result := &OptionalHeaderResult{}
+ if in.Debug != nil {
+ result.HasDebug = true
+ result.Debug = *in.Debug
+ }
+ if in.Limit != nil {
+ result.HasLimit = true
+ result.Limit = *in.Limit
+ }
+ return result, nil
+ })
+
+ // With both headers
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/optional", "", map[string]string{
+ "X-Debug": "true",
+ "X-Limit": "50",
+ })
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+ result := decodeJSON[OptionalHeaderResult](t, resp)
+ if !result.HasDebug || !result.Debug {
+ t.Error("expected Debug to be true")
+ }
+ if !result.HasLimit || result.Limit != 50 {
+ t.Errorf("expected Limit=50, got %d", result.Limit)
+ }
+
+ // Without optional headers
+ resp2 := doRequest(t, api, http.MethodGet, "/optional", "")
+ if resp2.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp2.StatusCode)
+ }
+ result2 := decodeJSON[OptionalHeaderResult](t, resp2)
+ if result2.HasDebug {
+ t.Error("expected HasDebug=false when header absent")
+ }
+ if result2.HasLimit {
+ t.Error("expected HasLimit=false when header absent")
+ }
+}
+
+func TestPostWithHeaderAndBody(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) {
+ return &HeaderAndBodyResult{Token: in.Token, Name: in.Name}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodPost, "/items", `{"name":"widget"}`, map[string]string{
+ "Authorization": "Bearer xyz",
+ })
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+ result := decodeJSON[HeaderAndBodyResult](t, resp)
+ if result.Token != "Bearer xyz" {
+ t.Errorf("expected Token=%q, got %q", "Bearer xyz", result.Token)
+ }
+ if result.Name != "widget" {
+ t.Errorf("expected Name=%q, got %q", "widget", result.Name)
+ }
+}
+
+func TestHeaderFieldNotSetByBodyDecode(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) {
+ return &HeaderAndBodyResult{Token: in.Token, Name: in.Name}, nil
+ })
+
+ // Body includes "Token" key that matches the header field name — it should be ignored
+ resp := doRequestWithHeaders(t, api, http.MethodPost, "/items", `{"name":"widget","Token":"body-token"}`, map[string]string{
+ "Authorization": "Bearer real",
+ })
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+ result := decodeJSON[HeaderAndBodyResult](t, resp)
+ if result.Token != "Bearer real" {
+ t.Errorf("expected Token=%q from header, got %q", "Bearer real", result.Token)
+ }
+}
+
+func TestGetWithHeaderAndQuery(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/search", func(r *http.Request, in HeaderAndQuery) (*HeaderAndQueryResult, error) {
+ return &HeaderAndQueryResult{Token: in.Token, Q: in.Q}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/search?q=hello", "", map[string]string{
+ "Authorization": "Bearer abc",
+ })
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+ result := decodeJSON[HeaderAndQueryResult](t, resp)
+ if result.Token != "Bearer abc" {
+ t.Errorf("expected Token=%q, got %q", "Bearer abc", result.Token)
+ }
+ if result.Q != "hello" {
+ t.Errorf("expected Q=%q, got %q", "hello", result.Q)
+ }
+}
+
+func TestGetWithHeaderScalars(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/scalars", func(r *http.Request, in ScalarHeaders) (*ScalarHeaderResult, error) {
+ return &ScalarHeaderResult{Flag: in.Flag, Count: in.Count, Score: in.Score}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/scalars", "", map[string]string{
+ "X-Flag": "true",
+ "X-Count": "42",
+ "X-Score": "3.14",
+ })
+ if resp.StatusCode != http.StatusOK {
+ t.Fatalf("expected 200, got %d", resp.StatusCode)
+ }
+ result := decodeJSON[ScalarHeaderResult](t, resp)
+ if !result.Flag {
+ t.Error("expected Flag=true")
+ }
+ if result.Count != 42 {
+ t.Errorf("expected Count=42, got %d", result.Count)
+ }
+ if result.Score != 3.14 {
+ t.Errorf("expected Score=3.14, got %f", result.Score)
+ }
+}
+
+func TestGetWithHeaderInvalidBool(t *testing.T) {
+ api := newTestAPI(t)
+ type BoolHeader struct {
+ Flag bool `header:"X-Flag"`
+ }
+ shiftapi.Get(api, "/test", func(r *http.Request, in BoolHeader) (*Status, error) {
+ return &Status{OK: true}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{
+ "X-Flag": "notabool",
+ })
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Fatalf("expected 400, got %d", resp.StatusCode)
+ }
+}
+
+func TestGetWithHeaderInvalidUint(t *testing.T) {
+ api := newTestAPI(t)
+ type UintHeader struct {
+ Count uint `header:"X-Count"`
+ }
+ shiftapi.Get(api, "/test", func(r *http.Request, in UintHeader) (*Status, error) {
+ return &Status{OK: true}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{
+ "X-Count": "-1",
+ })
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Fatalf("expected 400, got %d", resp.StatusCode)
+ }
+}
+
+func TestGetWithHeaderInvalidFloat(t *testing.T) {
+ api := newTestAPI(t)
+ type FloatHeader struct {
+ Score float64 `header:"X-Score"`
+ }
+ shiftapi.Get(api, "/test", func(r *http.Request, in FloatHeader) (*Status, error) {
+ return &Status{OK: true}, nil
+ })
+
+ resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{
+ "X-Score": "abc",
+ })
+ if resp.StatusCode != http.StatusBadRequest {
+ t.Fatalf("expected 400, got %d", resp.StatusCode)
+ }
+}
+
+// --- Header parameter OpenAPI spec tests ---
+
+func TestSpecHeaderParamsDocumented(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) {
+ return &AuthResult{}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/auth").Get
+
+ var found bool
+ for _, p := range op.Parameters {
+ if p.Value.Name == "Authorization" && p.Value.In == "header" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Error("expected Authorization header parameter documented in spec")
+ }
+}
+
+func TestSpecHeaderParamTypes(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/scalars", func(r *http.Request, in ScalarHeaders) (*ScalarHeaderResult, error) {
+ return &ScalarHeaderResult{}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/scalars").Get
+
+ expected := map[string]string{
+ "X-Flag": "boolean",
+ "X-Count": "integer",
+ "X-Score": "number",
+ }
+ for _, p := range op.Parameters {
+ if p.Value.In != "header" {
+ continue
+ }
+ want, ok := expected[p.Value.Name]
+ if !ok {
+ t.Errorf("unexpected header param %q", p.Value.Name)
+ continue
+ }
+ got := p.Value.Schema.Value.Type.Slice()[0]
+ if got != want {
+ t.Errorf("header %q: expected type %q, got %q", p.Value.Name, want, got)
+ }
+ }
+}
+
+func TestSpecHeaderParamRequired(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) {
+ return &AuthResult{}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/auth").Get
+
+ for _, p := range op.Parameters {
+ if p.Value.Name == "Authorization" && p.Value.In == "header" {
+ if !p.Value.Required {
+ t.Error("expected Authorization header to be required")
+ }
+ return
+ }
+ }
+ t.Error("Authorization header param not found")
+}
+
+func TestSpecHeaderParamOptionalPointerNotRequired(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalHeader) (*OptionalHeaderResult, error) {
+ return &OptionalHeaderResult{}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/optional").Get
+
+ for _, p := range op.Parameters {
+ if p.Value.In == "header" && p.Value.Required {
+ t.Errorf("optional header %q should not be required", p.Value.Name)
+ }
+ }
+}
+
+func TestSpecHeaderParamValidationConstraints(t *testing.T) {
+ api := newTestAPI(t)
+ type ConstrainedHeader struct {
+ Count int `header:"X-Count" validate:"min=1,max=100"`
+ }
+ shiftapi.Get(api, "/constrained", func(r *http.Request, in ConstrainedHeader) (*Status, error) {
+ return &Status{OK: true}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/constrained").Get
+
+ for _, p := range op.Parameters {
+ if p.Value.Name == "X-Count" && p.Value.In == "header" {
+ s := p.Value.Schema.Value
+ if s.Min == nil || *s.Min != 1 {
+ t.Error("expected Min=1 on X-Count header param")
+ }
+ if s.Max == nil || *s.Max != 100 {
+ t.Error("expected Max=100 on X-Count header param")
+ }
+ return
+ }
+ }
+ t.Error("X-Count header param not found")
+}
+
+func TestSpecBodySchemaExcludesHeaderFields(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) {
+ return &HeaderAndBodyResult{}, nil
+ })
+
+ spec := api.Spec()
+ // Find the body schema in components
+ for name, schemaRef := range spec.Components.Schemas {
+ if name == "HeaderAndBody" {
+ if _, has := schemaRef.Value.Properties["Token"]; has {
+ t.Error("body schema should not contain header field 'Token'")
+ }
+ if _, has := schemaRef.Value.Properties["name"]; !has {
+ t.Error("body schema should contain body field 'name'")
+ }
+ return
+ }
+ }
+ t.Error("HeaderAndBody schema not found in components")
+}
+
+func TestSpecHeaderParamsCombinedWithQueryParams(t *testing.T) {
+ api := newTestAPI(t)
+ shiftapi.Get(api, "/search", func(r *http.Request, in HeaderAndQuery) (*HeaderAndQueryResult, error) {
+ return &HeaderAndQueryResult{}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/search").Get
+
+ var headerParams, queryParams int
+ for _, p := range op.Parameters {
+ switch p.Value.In {
+ case "header":
+ headerParams++
+ case "query":
+ queryParams++
+ }
+ }
+ if headerParams != 1 {
+ t.Errorf("expected 1 header param, got %d", headerParams)
+ }
+ if queryParams != 1 {
+ t.Errorf("expected 1 query param, got %d", queryParams)
+ }
+}
+
+func TestSpecHeaderParamEnum(t *testing.T) {
+ api := newTestAPI(t)
+ type EnumHeader struct {
+ Format string `header:"Accept" validate:"oneof=json xml csv"`
+ }
+ shiftapi.Get(api, "/data", func(r *http.Request, in EnumHeader) (*Status, error) {
+ return &Status{OK: true}, nil
+ })
+
+ spec := api.Spec()
+ op := spec.Paths.Find("/data").Get
+
+ for _, p := range op.Parameters {
+ if p.Value.Name == "Accept" && p.Value.In == "header" {
+ if len(p.Value.Schema.Value.Enum) != 3 {
+ t.Errorf("expected 3 enum values, got %d", len(p.Value.Schema.Value.Enum))
+ }
+ return
+ }
+ }
+ t.Error("Accept header param not found")
+}