Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: '1.22'
go-version: '1.26'

- name: Build
run: go build ./...
Expand Down
51 changes: 43 additions & 8 deletions auth.go
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
package fsa

import (
"fmt"
"math/rand"
"strconv"
"time"

"bytes"
"context"
"github.com/google/uuid"
"fmt"
"html/template"
"math/rand"
"net/http"
"net/url"
"path/filepath"
"runtime"
"strconv"
"time"

"github.com/golang-jwt/jwt/v4"
"github.com/google/uuid"
)

type IAuthDb interface {
Expand Down Expand Up @@ -44,6 +45,18 @@ const ClaimsKey Key = "claims"
const UserEmailKey Key = "userEmail"
const UserIdKey Key = "userId"

type CookieConfig struct {
Domain string // For cross-subdomain auth
Secure bool // HTTPS only (default: true)
SameSite http.SameSite // Default: SameSiteStrictMode
}

type CSRFConfig struct {
CookieName string // Default: "csrf_token"
HeaderName string // Default: "X-CSRF-Token"
TokenLength int // Default: 32
}

type Config struct {
AppName string
Logo string
Expand All @@ -62,6 +75,9 @@ type Config struct {
ReturnUrls []string

UseIdentity bool

CookieConfig *CookieConfig // Optional, uses secure defaults if nil
CSRFConfig *CSRFConfig // Optional, uses defaults if nil
}

type Token struct {
Expand Down Expand Up @@ -116,6 +132,25 @@ func NewWithMemDbAndDefaultTemplate(sender ICodeSender, uc IUserCreator, cfg *Co
}

func (a *Auth) LoginStep1SendVerificationCode(ctx context.Context, email, returnUrl string) error {
// Validate returnUrl
if returnUrl == "" {
if len(a.Cfg.ReturnUrls) > 0 {
returnUrl = a.Cfg.ReturnUrls[0]
} else {
return fmt.Errorf("returnUrl required and no defaults configured")
}
}

validReturnUrl := false
for _, u := range a.Cfg.ReturnUrls {
if u == returnUrl {
validReturnUrl = true
break
}
}
if !validReturnUrl {
return fmt.Errorf("invalid return url: %s", returnUrl)
}

validEmail := a.Ev.Validate(email)
if !validEmail {
Expand All @@ -137,8 +172,8 @@ func (a *Auth) LoginStep1SendVerificationCode(ctx context.Context, email, return
return err
}

// send the code
link := fmt.Sprintf("%s?code=%s&email=%s", returnUrl, code, email)
// send the code with URL-encoded parameters
link := fmt.Sprintf("%s?code=%s&email=%s", returnUrl, url.QueryEscape(code), url.QueryEscape(email))
body := a.ParseTemplate(link, code)
err = a.Sender.Send(email, "Login Verification Code", body)
if err != nil {
Expand Down
164 changes: 164 additions & 0 deletions auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
package fsa

import (
"context"
"strings"
"testing"
"time"
)

// Phase 1: URL Encoding Tests

func TestLoginStep1_URLEncodesEmailWithPlusSign(t *testing.T) {
mockSender := &MockSender{}
auth := createTestAuthWithMockSender(mockSender)

err := auth.LoginStep1SendVerificationCode(context.Background(), "test+user@example.com", "https://app.com/login")

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(mockSender.LastBody, "email=test%2Buser%40example.com") {
t.Errorf("expected URL-encoded email with plus sign, got body: %s", mockSender.LastBody)
}
if strings.Contains(mockSender.LastBody, "email=test+user@") {
t.Error("email should be URL-encoded, not raw")
}
}

func TestLoginStep1_URLEncodesEmailWithAmpersand(t *testing.T) {
mockSender := &MockSender{}
auth := createTestAuthWithMockSender(mockSender)

err := auth.LoginStep1SendVerificationCode(context.Background(), "test&user@example.com", "https://app.com/login")

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(mockSender.LastBody, "email=test%26user%40example.com") {
t.Errorf("expected URL-encoded email with ampersand, got body: %s", mockSender.LastBody)
}
}

func TestLoginStep1_URLEncodesSpecialCharacters(t *testing.T) {
testCases := []struct {
email string
expected string
}{
{"test+user@example.com", "test%2Buser%40example.com"},
{"test&user@example.com", "test%26user%40example.com"},
{"test=user@example.com", "test%3Duser%40example.com"},
}

for _, tc := range testCases {
t.Run(tc.email, func(t *testing.T) {
mockSender := &MockSender{}
auth := createTestAuthWithMockSender(mockSender)

err := auth.LoginStep1SendVerificationCode(context.Background(), tc.email, "https://app.com/login")

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(mockSender.LastBody, "email="+tc.expected) {
t.Errorf("expected email=%s in body, got: %s", tc.expected, mockSender.LastBody)
}
})
}
}

// Phase 1: Input Validation Tests

func TestLoginStep1_DefaultsEmptyReturnUrl(t *testing.T) {
mockSender := &MockSender{}
auth := New(NewMemDb(), mockSender, &MockUserCreator{}, NewEmailValidator(), nil, nil, &Config{
AppName: "TestApp",
ReturnUrls: []string{"https://default.com/login"},
AccessTokenSecret: "secret",
RefreshTokenSecret: "secret",
CodeValidityPeriod: 5 * time.Minute,
AccessTokenValidityPeriod: 1 * time.Hour,
RefreshTokenValidityPeriod: 24 * time.Hour,
})

err := auth.LoginStep1SendVerificationCode(context.Background(), "test@example.com", "")

if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if !strings.Contains(mockSender.LastBody, "https://default.com/login") {
t.Errorf("expected default return URL in body, got: %s", mockSender.LastBody)
}
}

func TestLoginStep1_RejectsInvalidReturnUrl(t *testing.T) {
mockSender := &MockSender{}
auth := New(NewMemDb(), mockSender, &MockUserCreator{}, NewEmailValidator(), nil, nil, &Config{
AppName: "TestApp",
ReturnUrls: []string{"https://allowed.com/login"},
AccessTokenSecret: "secret",
RefreshTokenSecret: "secret",
CodeValidityPeriod: 5 * time.Minute,
AccessTokenValidityPeriod: 1 * time.Hour,
RefreshTokenValidityPeriod: 24 * time.Hour,
})

err := auth.LoginStep1SendVerificationCode(context.Background(), "test@example.com", "https://evil.com/login")

if err == nil {
t.Fatal("expected error for invalid return URL")
}
if !strings.Contains(err.Error(), "invalid return url") {
t.Errorf("expected 'invalid return url' error, got: %v", err)
}
if mockSender.LastBody != "" {
t.Error("email should not be sent for invalid return URL")
}
}

func TestLoginStep1_ErrorsWhenNoReturnUrlsConfigured(t *testing.T) {
mockSender := &MockSender{}
auth := New(NewMemDb(), mockSender, &MockUserCreator{}, NewEmailValidator(), nil, nil, &Config{
AppName: "TestApp",
ReturnUrls: []string{},
AccessTokenSecret: "secret",
RefreshTokenSecret: "secret",
CodeValidityPeriod: 5 * time.Minute,
AccessTokenValidityPeriod: 1 * time.Hour,
RefreshTokenValidityPeriod: 24 * time.Hour,
})

err := auth.LoginStep1SendVerificationCode(context.Background(), "test@example.com", "")

if err == nil {
t.Fatal("expected error when no return URLs configured")
}
if !strings.Contains(err.Error(), "no defaults configured") {
t.Errorf("expected 'no defaults configured' error, got: %v", err)
}
}

func TestLoginStep1_AcceptsValidReturnUrl(t *testing.T) {
mockSender := &MockSender{}
auth := New(NewMemDb(), mockSender, &MockUserCreator{}, NewEmailValidator(), nil, nil, &Config{
AppName: "TestApp",
ReturnUrls: []string{"https://app1.com/login", "https://app2.com/login"},
AccessTokenSecret: "secret",
RefreshTokenSecret: "secret",
CodeValidityPeriod: 5 * time.Minute,
AccessTokenValidityPeriod: 1 * time.Hour,
RefreshTokenValidityPeriod: 24 * time.Hour,
})

// Test first allowed URL
err := auth.LoginStep1SendVerificationCode(context.Background(), "test@example.com", "https://app1.com/login")
if err != nil {
t.Fatalf("unexpected error for first allowed URL: %v", err)
}

// Test second allowed URL
err = auth.LoginStep1SendVerificationCode(context.Background(), "test2@example.com", "https://app2.com/login")
if err != nil {
t.Fatalf("unexpected error for second allowed URL: %v", err)
}
}
24 changes: 17 additions & 7 deletions chi_middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,29 @@ func NewChiMiddleware(cfg *Config) *AuthMiddleware {

func (am *AuthMiddleware) VerifyAuthenticationToken(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
var jwtToken string

// Try Authorization header first
authHeader := r.Header.Get("Authorization")
if authHeader == "" {
http.Error(w, ErrorAuthHeaderMissing.Error(), http.StatusUnauthorized)
return
if authHeader != "" {
splitToken := strings.Split(authHeader, "Bearer ")
if len(splitToken) == 2 {
jwtToken = splitToken[1]
}
}

splitToken := strings.Split(authHeader, "Bearer ")
if len(splitToken) != 2 {
http.Error(w, ErrorInvalidAuthHeader.Error(), http.StatusUnauthorized)
// Fall back to cookie if header not present
if jwtToken == "" {
if cookie, err := r.Cookie("access_token"); err == nil {
jwtToken = cookie.Value
}
}

if jwtToken == "" {
http.Error(w, ErrorAuthHeaderMissing.Error(), http.StatusUnauthorized)
return
}

jwtToken := splitToken[1]
claims, err := parseTokenString(jwtToken, am.Cfg.AccessTokenSecret)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
Expand Down
Loading