diff --git a/README.md b/README.md index ee4f51e..f14132e 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ +

@@ -84,7 +85,7 @@ That's it. ShiftAPI reflects your Go types into an OpenAPI 3.1 spec at `/openapi ### Generic type-safe handlers -Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`), body fields (`json:"..."`), and form fields (`form:"..."`). For routes without input, use `_ struct{}`. +Generic free functions capture your request and response types at compile time. Every method uses a single function — struct tags discriminate query params (`query:"..."`), HTTP headers (`header:"..."`), body fields (`json:"..."`), and form fields (`form:"..."`). For routes without input, use `_ struct{}`. ```go // POST with body — input is decoded and passed as *CreateUser @@ -147,6 +148,25 @@ shiftapi.Post(api, "/items", func(r *http.Request, in CreateInput) (*Result, err }) ``` +### Typed HTTP headers + +Define a struct with `header` tags. Headers are parsed, validated, and documented in the OpenAPI spec automatically — just like query params. + +```go +type AuthInput struct { + Token string `header:"Authorization" validate:"required"` + Q string `query:"q"` +} + +shiftapi.Get(api, "/search", func(r *http.Request, in AuthInput) (*Results, error) { + // in.Token parsed from the Authorization header + // in.Q parsed from ?q= query param + return doSearch(in.Token, in.Q), nil +}) +``` + +Supports `string`, `bool`, `int*`, `uint*`, `float*` scalars and `*T` pointers for optional headers. Parse errors return `400`; validation failures return `422`. Header, query, and body fields can be freely combined in one struct. + ### File uploads (`multipart/form-data`) Use `form` tags to declare file upload endpoints. The `form` tag drives OpenAPI spec generation — the generated TypeScript client gets the correct `multipart/form-data` types automatically. At runtime, the request body is parsed via `ParseMultipartForm` and form-tagged fields are populated. @@ -409,6 +429,14 @@ const { data: upload } = await client.POST("/upload", { params: { query: { tags: "important" } }, }); // file uploads are typed as File | Blob | Uint8Array — generated from format: binary in the spec + +const { data: authResults } = await client.GET("/search", { + params: { + query: { q: "hello" }, + header: { Authorization: "Bearer token" }, + }, +}); +// header params are fully typed as well ``` In dev mode the plugins start the Go server, proxy API requests, watch `.go` files, and regenerate types on changes. diff --git a/apps/landing/src/components/Features.tsx b/apps/landing/src/components/Features.tsx index 06cf812..b7d1853 100644 --- a/apps/landing/src/components/Features.tsx +++ b/apps/landing/src/components/Features.tsx @@ -2,7 +2,7 @@ const features = [ { icon: , title: "Type-Safe Handlers", - desc: "Generic Go functions capture request and response types at compile time. No annotations, no magic comments.", + desc: <>Generic Go functions capture request and response types at compile time. json, query, header, and form tags route input automatically., }, { icon: , diff --git a/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx b/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx index d7bb648..604a6b5 100644 --- a/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx +++ b/apps/landing/src/content/docs/docs/core-concepts/error-handling.mdx @@ -96,7 +96,7 @@ Error types merge across all three levels. In the example above, `GET /api/v1/us ## Parse errors -When the framework cannot parse the request (malformed JSON, invalid query parameter types, invalid form data), it returns a `400 Bad Request`. By default the response body is `{"message": "bad request"}`. +When the framework cannot parse the request (malformed JSON, invalid query parameter types, invalid header values, invalid form data), it returns a `400 Bad Request`. By default the response body is `{"message": "bad request"}`. ### Customizing the 400 response diff --git a/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx b/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx index e14fa0f..af164bb 100644 --- a/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx +++ b/apps/landing/src/content/docs/docs/core-concepts/handlers.mdx @@ -28,8 +28,8 @@ func handler(r *http.Request, in InputType) (OutputType, error) { } ``` -- **`r *http.Request`** — the standard HTTP request, useful for headers, cookies, path parameters, etc. -- **`in InputType`** — automatically decoded from JSON body, query parameters, or multipart form data +- **`r *http.Request`** — the standard HTTP request, useful for cookies, path parameters, and other request metadata +- **`in InputType`** — automatically decoded from JSON body, query parameters, HTTP headers, or multipart form data - **`OutputType`** — serialized as JSON in the response (can be a pointer or value type) - **`error`** — return an error to send an error response @@ -50,17 +50,21 @@ Fields are decoded based on struct tags: | `path` | URL path parameters | `path:"id"` | | `json` | Request body (JSON) | `json:"name"` | | `query` | Query parameters | `query:"page"` | +| `header` | HTTP headers | `header:"Authorization"` | | `form` | Multipart form data | `form:"file"` | -You can combine tags in a single struct: +You can combine tags in a single struct — `path`, `query`, `header`, and `json` fields can be freely mixed. The only restriction is that `form` and `json` tags cannot coexist on the same struct. ```go type UpdateUser struct { - ID int `path:"id"` - Name string `json:"name"` + ID int `path:"id"` + Token string `header:"Authorization"` + Name string `json:"name"` } ``` +Header and query fields support `string`, `bool`, `int*`, `uint*`, `float*` scalars and `*T` pointers for optional values. Query fields also support `[]T` slices for repeated params. Slices are not supported for headers. Parse errors return `400`; validation failures return `422`. + ## Path parameters Use `path` tags to declare typed path parameters. The field name in the tag must match a `{param}` in the route pattern: diff --git a/apps/landing/src/content/docs/docs/core-concepts/validation.mdx b/apps/landing/src/content/docs/docs/core-concepts/validation.mdx index fd130bd..8022c39 100644 --- a/apps/landing/src/content/docs/docs/core-concepts/validation.mdx +++ b/apps/landing/src/content/docs/docs/core-concepts/validation.mdx @@ -74,6 +74,18 @@ type UserResponse struct { } ``` +## Validation on query and header fields + +Validation rules work on all input sources — `json`, `query`, `header`, and `form` fields. For example, `validate:"required"` on a `header`-tagged field returns `422` if the header is present but empty, and `validate:"oneof=json xml"` constrains accepted values. Parse errors (e.g. sending `"abc"` for an `int` header) return `400`; validation failures return `422`. + +```go +type AuthSearch struct { + Token string `header:"Authorization" validate:"required"` + Format string `header:"Accept" validate:"oneof=json xml"` + Q string `query:"q" validate:"required"` +} +``` + ## Custom validator You can provide your own validator instance with custom rules: diff --git a/apps/landing/src/content/docs/docs/getting-started/introduction.mdx b/apps/landing/src/content/docs/docs/getting-started/introduction.mdx index 1b98438..5257920 100644 --- a/apps/landing/src/content/docs/docs/getting-started/introduction.mdx +++ b/apps/landing/src/content/docs/docs/getting-started/introduction.mdx @@ -19,6 +19,7 @@ ShiftAPI is a Go framework that generates an OpenAPI 3.1 spec from your handler - **Type-safe errors** — declare custom error types at the API, group, or route level with `WithError[T](status)` — schemas appear in the OpenAPI spec and errors are matched at runtime via `errors.As` - **Middleware** — apply standard `func(http.Handler) http.Handler` middleware at the API, group, or route level with `WithMiddleware` — middleware resolves from API (outermost) → Group → Route (innermost) - **Composable options** — bundle middleware, errors, and other options into reusable `Option` values with `ComposeOptions` +- **Typed HTTP headers** — parse, validate, and document HTTP headers with `header` struct tags - **File uploads** — declare uploads with `form` tags, get correct `multipart/form-data` types - **Interactive docs** — Scalar API reference at `/docs` and OpenAPI spec at `/openapi.json` - **Just `net/http`** — implements `http.Handler`, works with any middleware or router diff --git a/form.go b/form.go index cd99f84..2e6245b 100644 --- a/form.go +++ b/form.go @@ -154,3 +154,5 @@ func (e *formParseError) Error() string { } return fmt.Sprintf("invalid form field %q: %v", e.Field, e.Err) } + +func (e *formParseError) Unwrap() error { return e.Err } diff --git a/handler.go b/handler.go index f93c4fd..537516d 100644 --- a/handler.go +++ b/handler.go @@ -15,17 +15,18 @@ import ( // The In struct's fields are discriminated by struct tags: // - path:"name" — parsed from URL path parameters (e.g. /users/{id}) // - query:"name" — parsed from URL query parameters +// - header:"name" — parsed from HTTP headers // - json:"name" — parsed from the JSON request body (default for POST/PUT/PATCH) // - form:"name" — parsed from multipart/form-data (for file uploads) // // Use struct{} as In for routes that take no input, or as Resp for routes // that return no body (e.g. health checks that only need a status code). // -// The [*http.Request] parameter gives access to headers, cookies, path +// The [*http.Request] parameter gives access to cookies, path // parameters, and other request metadata. type HandlerFunc[In, Resp any] func(r *http.Request, in In) (Resp, error) -func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasPath, hasQuery, hasBody, hasForm bool, maxUploadSize int64, errLookup errorLookup, badRequestFn, internalServerFn func(error) any) http.HandlerFunc { +func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any) error, hasPath, hasQuery, hasHeader, hasBody, hasForm bool, maxUploadSize int64, errLookup errorLookup, badRequestFn, internalServerFn func(error) any) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var in In rv := reflect.ValueOf(&in).Elem() @@ -56,6 +57,12 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any } } + // Reset any header-tagged fields that body decode may have + // inadvertently set, so they only come from HTTP headers. + if hasBody && hasHeader { + resetHeaderFields(rv) + } + // Parse query params if there are query fields if hasQuery { if err := parseQueryInto(rv, r.URL.Query()); err != nil { @@ -72,6 +79,14 @@ func adapt[In, Resp any](fn HandlerFunc[In, Resp], status int, validate func(any } } + // Parse headers if there are header fields + if hasHeader { + if err := parseHeadersInto(rv, r.Header); err != nil { + writeJSON(w, http.StatusBadRequest, badRequestFn(err)) + return + } + } + if err := validate(in); err != nil { handleError(w, internalServerFn, err, errLookup) return diff --git a/handlerFuncs.go b/handlerFuncs.go index be96a29..1ed2b84 100644 --- a/handlerFuncs.go +++ b/handlerFuncs.go @@ -28,7 +28,7 @@ func registerRoute[In, Resp any]( rawInType = rawInType.Elem() } - hasPath, hasQuery, hasBody, hasForm := partitionFields(rawInType) + hasPath, hasQuery, hasHeader, hasBody, hasForm := partitionFields(rawInType) // Validate path-tagged fields match the route pattern. if hasPath { @@ -44,6 +44,10 @@ func registerRoute[In, Resp any]( if hasQuery { queryType = rawInType } + var headerType reflect.Type + if hasHeader { + headerType = rawInType + } // POST/PUT/PATCH conventionally carry a request body, so always attempt // body decode for these methods — even when the input is struct{}. // This means Post(api, path, func(r, _ struct{}) ...) requires at least "{}". @@ -71,14 +75,14 @@ func registerRoute[In, Resp any]( pathType = rawInType } - if err := api.updateSchema(method, fullPath, pathType, queryType, bodyType, outType, hasForm, rawInType, cfg.info, cfg.status, allErrors); err != nil { + if err := api.updateSchema(method, fullPath, pathType, queryType, headerType, bodyType, outType, hasForm, rawInType, cfg.info, cfg.status, allErrors); err != nil { panic(fmt.Sprintf("shiftapi: schema generation failed for %s %s: %v", method, fullPath, err)) } errLookup := buildErrorLookup(allErrors) pattern := fmt.Sprintf("%s %s", method, fullPath) - var h http.Handler = adapt(fn, cfg.status, api.validateBody, hasPath, hasQuery, decodeBody, hasForm, api.maxUploadSize, errLookup, api.badRequestFn, api.internalServerFn) + var h http.Handler = adapt(fn, cfg.status, api.validateBody, hasPath, hasQuery, hasHeader, decodeBody, hasForm, api.maxUploadSize, errLookup, api.badRequestFn, api.internalServerFn) // Apply route-level middleware (innermost), then group-level (outermost). // Reverse order so the first middleware in the slice wraps outermost. for i := len(cfg.middleware) - 1; i >= 0; i-- { diff --git a/header.go b/header.go new file mode 100644 index 0000000..1f6aac3 --- /dev/null +++ b/header.go @@ -0,0 +1,106 @@ +package shiftapi + +import ( + "fmt" + "net/http" + "reflect" + "strings" +) + +// hasHeaderTag returns true if the struct field has a `header` tag. +func hasHeaderTag(f reflect.StructField) bool { + return f.Tag.Get("header") != "" +} + +// headerFieldName returns the header name for a struct field. +func headerFieldName(f reflect.StructField) string { + name, _, _ := strings.Cut(f.Tag.Get("header"), ",") + if name == "" { + return f.Name + } + return name +} + +// resetHeaderFields zeros out any header-tagged fields on a struct value. +// This is called after body decode so that header-tagged fields are only +// populated by parseHeadersInto, not by JSON keys that happen to match. +func resetHeaderFields(rv reflect.Value) { + for rv.Kind() == reflect.Pointer { + rv = rv.Elem() + } + if rv.Kind() != reflect.Struct { + return + } + rt := rv.Type() + for i := range rt.NumField() { + f := rt.Field(i) + if f.IsExported() && hasHeaderTag(f) { + rv.Field(i).SetZero() + } + } +} + +// parseHeadersInto populates header-tagged fields on an existing struct value +// from HTTP headers. Non-header fields are left untouched. +// Only scalar types and pointer-to-scalar types are supported (no slices). +func parseHeadersInto(rv reflect.Value, header http.Header) error { + for rv.Kind() == reflect.Pointer { + if rv.IsNil() { + rv.Set(reflect.New(rv.Type().Elem())) + } + rv = rv.Elem() + } + + rt := rv.Type() + if rt.Kind() != reflect.Struct { + return fmt.Errorf("header type must be a struct, got %s", rt.Kind()) + } + + for i := range rt.NumField() { + field := rt.Field(i) + if !field.IsExported() || !hasHeaderTag(field) { + continue + } + + name := headerFieldName(field) + fv := rv.Field(i) + ft := field.Type + + // Handle pointer fields (optional headers) + if ft.Kind() == reflect.Pointer { + raw := header.Get(name) + if raw == "" { + continue + } + ptr := reflect.New(ft.Elem()) + if err := setScalarValue(ptr.Elem(), raw); err != nil { + return &headerParseError{Field: name, Err: err} + } + fv.Set(ptr) + continue + } + + // Handle scalar fields + raw := header.Get(name) + if raw == "" { + continue + } + if err := setScalarValue(fv, raw); err != nil { + return &headerParseError{Field: name, Err: err} + } + } + + return nil +} + +// headerParseError is returned when a header value cannot be parsed. +type headerParseError struct { + Field string + Err error +} + +func (e *headerParseError) Error() string { + return fmt.Sprintf("invalid header %q: %v", e.Field, e.Err) +} + +func (e *headerParseError) Unwrap() error { return e.Err } diff --git a/path.go b/path.go index c7b72e1..92896ea 100644 --- a/path.go +++ b/path.go @@ -109,3 +109,5 @@ type pathParseError struct { func (e *pathParseError) Error() string { return fmt.Sprintf("invalid path parameter %q: %v", e.Field, e.Err) } + +func (e *pathParseError) Unwrap() error { return e.Err } diff --git a/query.go b/query.go index a6dd3ed..860baf7 100644 --- a/query.go +++ b/query.go @@ -14,14 +14,15 @@ func hasQueryTag(f reflect.StructField) bool { } // partitionFields inspects a struct type and reports whether it contains -// path-tagged, query-tagged, body (json-tagged or untagged non-path/query) fields, -// and/or form-tagged fields. It panics if both body and form fields are present. -func partitionFields(t reflect.Type) (hasPath, hasQuery, hasBody, hasForm bool) { +// path-tagged, query-tagged, header-tagged, body (json-tagged or untagged +// non-path/query/header) fields, and/or form-tagged fields. It panics if both +// body and form fields are present. +func partitionFields(t reflect.Type) (hasPath, hasQuery, hasHeader, hasBody, hasForm bool) { for t.Kind() == reflect.Pointer { t = t.Elem() } if t.Kind() != reflect.Struct { - return false, false, false, false + return false, false, false, false, false } for f := range t.Fields() { if !f.IsExported() { @@ -31,10 +32,12 @@ func partitionFields(t reflect.Type) (hasPath, hasQuery, hasBody, hasForm bool) hasPath = true } else if hasQueryTag(f) { hasQuery = true + } else if hasHeaderTag(f) { + hasHeader = true } else if hasFormTag(f) { hasForm = true } else { - // Any exported field without a path, query, or form tag is a body field + // Any exported field without a path, query, header, or form tag is a body field jsonTag := f.Tag.Get("json") if jsonTag == "-" { continue @@ -178,7 +181,7 @@ func setScalarValue(v reflect.Value, raw string) error { } v.SetFloat(n) default: - return fmt.Errorf("unsupported query parameter type %s", v.Kind()) + return fmt.Errorf("unsupported parameter type %s", v.Kind()) } return nil } @@ -192,3 +195,5 @@ type queryParseError struct { func (e *queryParseError) Error() string { return fmt.Sprintf("invalid query parameter %q: %v", e.Field, e.Err) } + +func (e *queryParseError) Unwrap() error { return e.Err } diff --git a/schema.go b/schema.go index 800aea5..c8a93ad 100644 --- a/schema.go +++ b/schema.go @@ -12,7 +12,7 @@ import ( var pathParamRe = regexp.MustCompile(`\{([^}]+)\}`) -func (a *API) updateSchema(method, path string, pathType, queryType, inType, outType reflect.Type, hasForm bool, formType reflect.Type, info *RouteInfo, status int, errors []errorEntry) error { +func (a *API) updateSchema(method, path string, pathType, queryType, headerType, inType, outType reflect.Type, hasForm bool, formType reflect.Type, info *RouteInfo, status int, errors []errorEntry) error { op := &openapi3.Operation{ OperationID: operationID(method, path), Responses: openapi3.NewResponses(), @@ -67,6 +67,15 @@ func (a *API) updateSchema(method, path string, pathType, queryType, inType, out op.Parameters = append(op.Parameters, queryParams...) } + // Header parameters + if headerType != nil { + headerParams, err := a.generateHeaderParams(headerType) + if err != nil { + return err + } + op.Parameters = append(op.Parameters, headerParams...) + } + // Response schema statusStr := fmt.Sprintf("%d", status) outSchema, err := a.generateSchemaRef(outType) @@ -140,8 +149,9 @@ func (a *API) updateSchema(method, path string, pathType, queryType, inType, out return err } if inSchema != nil { - // Strip query-tagged and path-tagged fields from the body schema + // Strip query-tagged, header-tagged, and path-tagged fields from the body schema stripQueryFields(inType, inSchema.Value) + stripHeaderFields(inType, inSchema.Value) stripPathFields(inType, inSchema.Value) if len(inSchema.Value.Properties) > 0 { @@ -457,6 +467,73 @@ func stripQueryFields(t reflect.Type, schema *openapi3.Schema) { } } +// generateHeaderParams produces OpenAPI parameter definitions for a header struct type. +// Only fields with `header` tags are included. Slices are not supported for headers. +func (a *API) generateHeaderParams(t reflect.Type) ([]*openapi3.ParameterRef, error) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("header type must be a struct, got %s", t.Kind()) + } + + var params []*openapi3.ParameterRef + for field := range t.Fields() { + if !field.IsExported() { + continue + } + if !hasHeaderTag(field) { + continue + } + name := http.CanonicalHeaderKey(headerFieldName(field)) + schema := scalarToOpenAPISchema(field.Type) + + // Apply validation constraints + if err := validateSchemaCustomizer(name, field.Type, field.Tag, schema.Value); err != nil { + return nil, err + } + + required := hasRule(field.Tag.Get("validate"), "required") + + params = append(params, &openapi3.ParameterRef{ + Value: &openapi3.Parameter{ + Name: name, + In: "header", + Required: required, + Schema: schema, + }, + }) + } + return params, nil +} + +// stripHeaderFields removes header-tagged fields from a body schema's Properties and Required. +func stripHeaderFields(t reflect.Type, schema *openapi3.Schema) { + for t.Kind() == reflect.Pointer { + t = t.Elem() + } + if t.Kind() != reflect.Struct || schema == nil { + return + } + for f := range t.Fields() { + if !f.IsExported() || !hasHeaderTag(f) { + continue + } + jname := jsonFieldName(f) + if jname == "" || jname == "-" { + continue + } + delete(schema.Properties, jname) + // Remove from Required slice + for j, req := range schema.Required { + if req == jname { + schema.Required = append(schema.Required[:j], schema.Required[j+1:]...) + break + } + } + } +} + // stripPathFields removes path-tagged fields from a body schema's Properties and Required. func stripPathFields(t reflect.Type, schema *openapi3.Schema) { for t.Kind() == reflect.Pointer { diff --git a/shiftapi_test.go b/shiftapi_test.go index b70748e..ec1bc33 100644 --- a/shiftapi_test.go +++ b/shiftapi_test.go @@ -4111,3 +4111,497 @@ func TestSpecPathParamExcludedFromBody(t *testing.T) { t.Error("path field 'ID' should not appear in body schema") } } + +// --- Header parameter test types --- + +type AuthHeader struct { + Token string `header:"Authorization" validate:"required"` +} + +type AuthResult struct { + Token string `json:"token"` +} + +type OptionalHeader struct { + Debug *bool `header:"X-Debug"` + Limit *int `header:"X-Limit"` +} + +type OptionalHeaderResult struct { + HasDebug bool `json:"has_debug"` + Debug bool `json:"debug"` + HasLimit bool `json:"has_limit"` + Limit int `json:"limit"` +} + +type HeaderAndBody struct { + Token string `header:"Authorization" validate:"required"` + Name string `json:"name" validate:"required"` +} + +type HeaderAndBodyResult struct { + Token string `json:"token"` + Name string `json:"name"` +} + +type HeaderAndQuery struct { + Token string `header:"Authorization" validate:"required"` + Q string `query:"q"` +} + +type HeaderAndQueryResult struct { + Token string `json:"token"` + Q string `json:"q"` +} + +type ScalarHeaders struct { + Flag bool `header:"X-Flag"` + Count uint `header:"X-Count"` + Score float64 `header:"X-Score"` +} + +type ScalarHeaderResult struct { + Flag bool `json:"flag"` + Count uint `json:"count"` + Score float64 `json:"score"` +} + +// --- Header parameter test helpers --- + +func doRequestWithHeaders(t *testing.T, api http.Handler, method, path, body string, headers map[string]string) *http.Response { + t.Helper() + var bodyReader io.Reader + if body != "" { + bodyReader = strings.NewReader(body) + } + req := httptest.NewRequest(method, path, bodyReader) + if body != "" { + req.Header.Set("Content-Type", "application/json") + } + for k, v := range headers { + req.Header.Set(k, v) + } + rec := httptest.NewRecorder() + api.ServeHTTP(rec, req) + return rec.Result() +} + +// --- Header parameter runtime tests --- + +func TestGetWithHeaderBasic(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{Token: in.Token}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/auth", "", map[string]string{ + "Authorization": "Bearer abc123", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[AuthResult](t, resp) + if result.Token != "Bearer abc123" { + t.Errorf("expected Token=%q, got %q", "Bearer abc123", result.Token) + } +} + +func TestGetWithHeaderMissingRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{Token: in.Token}, nil + }) + + // Missing required "Authorization" header + resp := doRequest(t, api, http.MethodGet, "/auth", "") + if resp.StatusCode != http.StatusUnprocessableEntity { + t.Fatalf("expected 422, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderInvalidType(t *testing.T) { + api := newTestAPI(t) + type IntHeader struct { + Count int `header:"X-Count" validate:"required"` + } + shiftapi.Get(api, "/count", func(r *http.Request, in IntHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/count", "", map[string]string{ + "X-Count": "notanumber", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderOptionalPointers(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalHeader) (*OptionalHeaderResult, error) { + result := &OptionalHeaderResult{} + if in.Debug != nil { + result.HasDebug = true + result.Debug = *in.Debug + } + if in.Limit != nil { + result.HasLimit = true + result.Limit = *in.Limit + } + return result, nil + }) + + // With both headers + resp := doRequestWithHeaders(t, api, http.MethodGet, "/optional", "", map[string]string{ + "X-Debug": "true", + "X-Limit": "50", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[OptionalHeaderResult](t, resp) + if !result.HasDebug || !result.Debug { + t.Error("expected Debug to be true") + } + if !result.HasLimit || result.Limit != 50 { + t.Errorf("expected Limit=50, got %d", result.Limit) + } + + // Without optional headers + resp2 := doRequest(t, api, http.MethodGet, "/optional", "") + if resp2.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp2.StatusCode) + } + result2 := decodeJSON[OptionalHeaderResult](t, resp2) + if result2.HasDebug { + t.Error("expected HasDebug=false when header absent") + } + if result2.HasLimit { + t.Error("expected HasLimit=false when header absent") + } +} + +func TestPostWithHeaderAndBody(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) { + return &HeaderAndBodyResult{Token: in.Token, Name: in.Name}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodPost, "/items", `{"name":"widget"}`, map[string]string{ + "Authorization": "Bearer xyz", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[HeaderAndBodyResult](t, resp) + if result.Token != "Bearer xyz" { + t.Errorf("expected Token=%q, got %q", "Bearer xyz", result.Token) + } + if result.Name != "widget" { + t.Errorf("expected Name=%q, got %q", "widget", result.Name) + } +} + +func TestHeaderFieldNotSetByBodyDecode(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) { + return &HeaderAndBodyResult{Token: in.Token, Name: in.Name}, nil + }) + + // Body includes "Token" key that matches the header field name — it should be ignored + resp := doRequestWithHeaders(t, api, http.MethodPost, "/items", `{"name":"widget","Token":"body-token"}`, map[string]string{ + "Authorization": "Bearer real", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[HeaderAndBodyResult](t, resp) + if result.Token != "Bearer real" { + t.Errorf("expected Token=%q from header, got %q", "Bearer real", result.Token) + } +} + +func TestGetWithHeaderAndQuery(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in HeaderAndQuery) (*HeaderAndQueryResult, error) { + return &HeaderAndQueryResult{Token: in.Token, Q: in.Q}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/search?q=hello", "", map[string]string{ + "Authorization": "Bearer abc", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[HeaderAndQueryResult](t, resp) + if result.Token != "Bearer abc" { + t.Errorf("expected Token=%q, got %q", "Bearer abc", result.Token) + } + if result.Q != "hello" { + t.Errorf("expected Q=%q, got %q", "hello", result.Q) + } +} + +func TestGetWithHeaderScalars(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/scalars", func(r *http.Request, in ScalarHeaders) (*ScalarHeaderResult, error) { + return &ScalarHeaderResult{Flag: in.Flag, Count: in.Count, Score: in.Score}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/scalars", "", map[string]string{ + "X-Flag": "true", + "X-Count": "42", + "X-Score": "3.14", + }) + if resp.StatusCode != http.StatusOK { + t.Fatalf("expected 200, got %d", resp.StatusCode) + } + result := decodeJSON[ScalarHeaderResult](t, resp) + if !result.Flag { + t.Error("expected Flag=true") + } + if result.Count != 42 { + t.Errorf("expected Count=42, got %d", result.Count) + } + if result.Score != 3.14 { + t.Errorf("expected Score=3.14, got %f", result.Score) + } +} + +func TestGetWithHeaderInvalidBool(t *testing.T) { + api := newTestAPI(t) + type BoolHeader struct { + Flag bool `header:"X-Flag"` + } + shiftapi.Get(api, "/test", func(r *http.Request, in BoolHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{ + "X-Flag": "notabool", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderInvalidUint(t *testing.T) { + api := newTestAPI(t) + type UintHeader struct { + Count uint `header:"X-Count"` + } + shiftapi.Get(api, "/test", func(r *http.Request, in UintHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{ + "X-Count": "-1", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +func TestGetWithHeaderInvalidFloat(t *testing.T) { + api := newTestAPI(t) + type FloatHeader struct { + Score float64 `header:"X-Score"` + } + shiftapi.Get(api, "/test", func(r *http.Request, in FloatHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + resp := doRequestWithHeaders(t, api, http.MethodGet, "/test", "", map[string]string{ + "X-Score": "abc", + }) + if resp.StatusCode != http.StatusBadRequest { + t.Fatalf("expected 400, got %d", resp.StatusCode) + } +} + +// --- Header parameter OpenAPI spec tests --- + +func TestSpecHeaderParamsDocumented(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/auth").Get + + var found bool + for _, p := range op.Parameters { + if p.Value.Name == "Authorization" && p.Value.In == "header" { + found = true + break + } + } + if !found { + t.Error("expected Authorization header parameter documented in spec") + } +} + +func TestSpecHeaderParamTypes(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/scalars", func(r *http.Request, in ScalarHeaders) (*ScalarHeaderResult, error) { + return &ScalarHeaderResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/scalars").Get + + expected := map[string]string{ + "X-Flag": "boolean", + "X-Count": "integer", + "X-Score": "number", + } + for _, p := range op.Parameters { + if p.Value.In != "header" { + continue + } + want, ok := expected[p.Value.Name] + if !ok { + t.Errorf("unexpected header param %q", p.Value.Name) + continue + } + got := p.Value.Schema.Value.Type.Slice()[0] + if got != want { + t.Errorf("header %q: expected type %q, got %q", p.Value.Name, want, got) + } + } +} + +func TestSpecHeaderParamRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/auth", func(r *http.Request, in AuthHeader) (*AuthResult, error) { + return &AuthResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/auth").Get + + for _, p := range op.Parameters { + if p.Value.Name == "Authorization" && p.Value.In == "header" { + if !p.Value.Required { + t.Error("expected Authorization header to be required") + } + return + } + } + t.Error("Authorization header param not found") +} + +func TestSpecHeaderParamOptionalPointerNotRequired(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/optional", func(r *http.Request, in OptionalHeader) (*OptionalHeaderResult, error) { + return &OptionalHeaderResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/optional").Get + + for _, p := range op.Parameters { + if p.Value.In == "header" && p.Value.Required { + t.Errorf("optional header %q should not be required", p.Value.Name) + } + } +} + +func TestSpecHeaderParamValidationConstraints(t *testing.T) { + api := newTestAPI(t) + type ConstrainedHeader struct { + Count int `header:"X-Count" validate:"min=1,max=100"` + } + shiftapi.Get(api, "/constrained", func(r *http.Request, in ConstrainedHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/constrained").Get + + for _, p := range op.Parameters { + if p.Value.Name == "X-Count" && p.Value.In == "header" { + s := p.Value.Schema.Value + if s.Min == nil || *s.Min != 1 { + t.Error("expected Min=1 on X-Count header param") + } + if s.Max == nil || *s.Max != 100 { + t.Error("expected Max=100 on X-Count header param") + } + return + } + } + t.Error("X-Count header param not found") +} + +func TestSpecBodySchemaExcludesHeaderFields(t *testing.T) { + api := newTestAPI(t) + shiftapi.Post(api, "/items", func(r *http.Request, in HeaderAndBody) (*HeaderAndBodyResult, error) { + return &HeaderAndBodyResult{}, nil + }) + + spec := api.Spec() + // Find the body schema in components + for name, schemaRef := range spec.Components.Schemas { + if name == "HeaderAndBody" { + if _, has := schemaRef.Value.Properties["Token"]; has { + t.Error("body schema should not contain header field 'Token'") + } + if _, has := schemaRef.Value.Properties["name"]; !has { + t.Error("body schema should contain body field 'name'") + } + return + } + } + t.Error("HeaderAndBody schema not found in components") +} + +func TestSpecHeaderParamsCombinedWithQueryParams(t *testing.T) { + api := newTestAPI(t) + shiftapi.Get(api, "/search", func(r *http.Request, in HeaderAndQuery) (*HeaderAndQueryResult, error) { + return &HeaderAndQueryResult{}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/search").Get + + var headerParams, queryParams int + for _, p := range op.Parameters { + switch p.Value.In { + case "header": + headerParams++ + case "query": + queryParams++ + } + } + if headerParams != 1 { + t.Errorf("expected 1 header param, got %d", headerParams) + } + if queryParams != 1 { + t.Errorf("expected 1 query param, got %d", queryParams) + } +} + +func TestSpecHeaderParamEnum(t *testing.T) { + api := newTestAPI(t) + type EnumHeader struct { + Format string `header:"Accept" validate:"oneof=json xml csv"` + } + shiftapi.Get(api, "/data", func(r *http.Request, in EnumHeader) (*Status, error) { + return &Status{OK: true}, nil + }) + + spec := api.Spec() + op := spec.Paths.Find("/data").Get + + for _, p := range op.Parameters { + if p.Value.Name == "Accept" && p.Value.In == "header" { + if len(p.Value.Schema.Value.Enum) != 3 { + t.Errorf("expected 3 enum values, got %d", len(p.Value.Schema.Value.Enum)) + } + return + } + } + t.Error("Accept header param not found") +}