diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..0f5e4a6 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,37 @@ + +name: Go tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + workflow_dispatch: {} + +jobs: + test: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v6 + with: + go-version: '1.24' + + - name: Cache Go modules + uses: actions/cache@v4 + with: + path: | + ~/.cache/go-build + ~/go/pkg/mod + key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }} + restore-keys: | + ${{ runner.os }}-go- + + - name: Install dependencies + run: go mod download + + - name: Run tests + run: go test ./... -v diff --git a/cmd/server/.env.example b/cmd/server/.env.example index 47e3160..36159b7 100644 --- a/cmd/server/.env.example +++ b/cmd/server/.env.example @@ -2,4 +2,7 @@ DB_HOST= DB_USER= DB_PASSWORD= DB_NAME= -DB_PORT= \ No newline at end of file +DB_PORT= +PORT= +GOOGLE_CLIENT_ID= +GOOGLE_CLIENT_SECRET= \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go index 4f3e831..9d6af8d 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,28 +2,54 @@ package main import ( "fmt" + "log" "os" "github.com/gin-gonic/gin" - "github.com/jasutiin/envlink/internal/server/auth" + "github.com/jasutiin/envlink/internal/server/api/auth" + "github.com/jasutiin/envlink/internal/server/api/projects" + "github.com/jasutiin/envlink/internal/server/api/pull" + "github.com/jasutiin/envlink/internal/server/api/push" "github.com/jasutiin/envlink/internal/server/database" - "github.com/jasutiin/envlink/internal/server/projects" - "github.com/jasutiin/envlink/internal/server/pull" - "github.com/jasutiin/envlink/internal/server/push" + "github.com/joho/godotenv" ) func main() { + err := godotenv.Load() + if err != nil { + log.Println("No .env file found, using environment variables directly") + } + + port := os.Getenv("PORT") + if port == "" { + log.Fatalf("Port was not provided!") + } + server := gin.Default() api := server.Group("/api/v1") db := database.CreateDB() database.AutoMigrate(db) // creates tables if they don't exist + // empty RAILWAY_ENVIRONMENT_NAME means dev environment, otherwise production + isProd := os.Getenv("RAILWAY_ENVIRONMENT_NAME") != "" + + key := os.Getenv("COOKIE_SESSION_KEY") + if key == "" { + log.Fatalf("COOKIE_SESSION_KEY is required") + } + + domain := os.Getenv("RAILWAY_PUBLIC_DOMAIN") + + err = auth.NewAuth(port, domain, key, isProd) + if err != nil { + log.Fatalf("Failed to initialize auth: %s", err) + } + auth.AuthRouter(api, db) push.PushRouter(api) pull.PullRouter(api) projects.ProjectsRouter(api) - port := os.Getenv("PORT") fmt.Printf("listening on port %s", port) server.Run("0.0.0.0:" + port) -}; \ No newline at end of file +} diff --git a/go.mod b/go.mod index 2e520a8..98ecc57 100644 --- a/go.mod +++ b/go.mod @@ -4,7 +4,9 @@ go 1.24.4 require ( github.com/gin-gonic/gin v1.11.0 + github.com/gorilla/sessions v1.4.0 github.com/joho/godotenv v1.5.1 + github.com/markbates/goth v1.82.0 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 gorm.io/driver/postgres v1.6.0 @@ -12,18 +14,23 @@ require ( ) require ( + cloud.google.com/go/compute/metadata v0.3.0 // indirect github.com/bytedance/sonic v1.14.0 // indirect github.com/bytedance/sonic/loader v0.3.0 // indirect github.com/cloudwego/base64x v0.1.6 // indirect github.com/fsnotify/fsnotify v1.9.0 // indirect github.com/gabriel-vasile/mimetype v1.4.8 // indirect github.com/gin-contrib/sse v1.1.0 // indirect + github.com/go-chi/chi/v5 v5.2.2 // indirect github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/go-playground/validator/v10 v10.27.0 // indirect github.com/go-viper/mapstructure/v2 v2.4.0 // indirect github.com/goccy/go-json v0.10.2 // indirect github.com/goccy/go-yaml v1.18.0 // indirect + github.com/gorilla/context v1.1.1 // indirect + github.com/gorilla/mux v1.6.2 // indirect + github.com/gorilla/securecookie v1.1.2 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect @@ -54,6 +61,7 @@ require ( golang.org/x/crypto v0.40.0 // indirect golang.org/x/mod v0.26.0 // indirect golang.org/x/net v0.42.0 // indirect + golang.org/x/oauth2 v0.27.0 // indirect golang.org/x/sync v0.16.0 // indirect golang.org/x/sys v0.35.0 // indirect golang.org/x/text v0.28.0 // indirect diff --git a/go.sum b/go.sum index 02ad1ea..c4f80a0 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +cloud.google.com/go/compute/metadata v0.3.0 h1:Tz+eQXMEqDIKRsmY3cHTL6FVaynIjX2QxYC4trgAKZc= +cloud.google.com/go/compute/metadata v0.3.0/go.mod h1:zFmK7XCadkQkj6TtorcaGlCW1hT1fIilQDwofLpJ20k= github.com/bytedance/sonic v1.14.0 h1:/OfKt8HFw0kh2rj8N0F6C/qPGRESq0BbaNZgcNXXzQQ= github.com/bytedance/sonic v1.14.0/go.mod h1:WoEbx8WTcFJfzCe0hbmyTGrfjt8PzNEBdxlNUO24NhA= github.com/bytedance/sonic/loader v0.3.0 h1:dskwH8edlzNMctoruo8FPTJDF3vLtDT0sXZwvZJyqeA= @@ -18,6 +20,8 @@ github.com/gin-contrib/sse v1.1.0 h1:n0w2GMuUpWDVp7qSpvze6fAu9iRxJY4Hmj6AmBOU05w github.com/gin-contrib/sse v1.1.0/go.mod h1:hxRZ5gVpWMT7Z0B0gSNYqqsSCNIJMjzvm6fqCz9vjwM= github.com/gin-gonic/gin v1.11.0 h1:OW/6PLjyusp2PPXtyxKHU0RbX6I/l28FTdDlae5ueWk= github.com/gin-gonic/gin v1.11.0/go.mod h1:+iq/FyxlGzII0KHiBGjuNn4UNENUlKbGlNmc+W50Dls= +github.com/go-chi/chi/v5 v5.2.2 h1:CMwsvRVTbXVytCk1Wd72Zy1LAsAh9GxMmSNWLHCG618= +github.com/go-chi/chi/v5 v5.2.2/go.mod h1:L2yAIGWB3H+phAw1NxKwWM+7eUH/lU8pOMm5hHcoops= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= @@ -35,6 +39,16 @@ github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7Lk github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU= github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/gofuzz v1.2.0 h1:xRy4A+RhZaiKjJ1bPfwQ8sedCA+YS2YcCHW6ec7JMi0= +github.com/google/gofuzz v1.2.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= +github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= +github.com/gorilla/mux v1.6.2 h1:Pgr17XVTNXAk3q/r4CpKzC5xBM/qW1uVLV+IhRZpIIk= +github.com/gorilla/mux v1.6.2/go.mod h1:1lud6UwP+6orDFRuTfBEV8e9/aOM/c4fVVCaMa2zaAs= +github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kXD8ePA= +github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo= +github.com/gorilla/sessions v1.4.0 h1:kpIYOp/oi6MG/p5PgxApU8srsSw9tuFbt46Lt7auzqQ= +github.com/gorilla/sessions v1.4.0/go.mod h1:FLWm50oby91+hl7p/wRxDth9bWSuk0qVL2emc7lT5ik= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= @@ -61,6 +75,8 @@ github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.4.0 h1:WT9HwE9SGECu3lg4d/dIA+jxlljEa1/ffXKmRjqdmIQ= github.com/leodido/go-urn v1.4.0/go.mod h1:bvxc+MVxLKB4z00jd1z+Dvzr47oO32F/QSNjSBOlFxI= +github.com/markbates/goth v1.82.0 h1:8j/c34AjBSTNzO7zTsOyP5IYCQCMBTRBHAbBt/PI0bQ= +github.com/markbates/goth v1.82.0/go.mod h1:/DRlcq0pyqkKToyZjsL2KgiA1zbF1HIjE7u2uC79rUk= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= @@ -75,8 +91,8 @@ github.com/quic-go/qpack v0.5.1 h1:giqksBPnT/HDtZ6VhtFKgoLOWmlyo9Ei6u9PqzIMbhI= github.com/quic-go/qpack v0.5.1/go.mod h1:+PC4XFrEskIVkcLzpEkbLqq1uCoxPhQuvK5rH1ZgaEg= github.com/quic-go/quic-go v0.54.0 h1:6s1YB9QotYI6Ospeiguknbp2Znb/jZYjZLRXn9kMQBg= github.com/quic-go/quic-go v0.54.0/go.mod h1:e68ZEaCdyviluZmy44P6Iey98v/Wfz6HCjQEm+l8zTY= -github.com/rogpeppe/go-internal v1.9.0 h1:73kH8U+JUqXU8lRuOHeVHaa/SZPifC7BkcraZVejAe8= -github.com/rogpeppe/go-internal v1.9.0/go.mod h1:WtVeX8xhTBvf0smdhujwtBcq4Qrzq/fJaraNFVN+nFs= +github.com/rogpeppe/go-internal v1.12.0 h1:exVL4IDcn6na9z1rAb56Vxr+CgyK3nn3O+epU5NdKM8= +github.com/rogpeppe/go-internal v1.12.0/go.mod h1:E+RYuTGaKKdloAfM02xzb0FW3Paa99yedzYV+kq4uf4= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sagikazarmark/locafero v0.11.0 h1:1iurJgmM9G3PA/I+wWYIOw/5SyBtxapeHDcg+AAIFXc= github.com/sagikazarmark/locafero v0.11.0/go.mod h1:nVIGvgyzw595SUSUE6tvCp3YYTeHs15MvlmU87WwIik= @@ -121,6 +137,8 @@ golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= golang.org/x/net v0.42.0 h1:jzkYrhi3YQWD6MLBJcsklgQsoAcw89EcZbJw8Z614hs= golang.org/x/net v0.42.0/go.mod h1:FF1RA5d3u7nAYA4z2TkclSCKh68eSXtiFwcWQpPXdt8= +golang.org/x/oauth2 v0.27.0 h1:da9Vo7/tDv5RH/7nZDz1eMGS/q1Vv1N/7FCrBhI9I3M= +golang.org/x/oauth2 v0.27.0/go.mod h1:onh5ek6nERTohokkhCD/y2cV4Do3fxFHFuAejCkRWT8= golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= diff --git a/internal/cli/login.go b/internal/cli/commands/login.go similarity index 93% rename from internal/cli/login.go rename to internal/cli/commands/login.go index 3cce7e4..71d0fde 100644 --- a/internal/cli/login.go +++ b/internal/cli/commands/login.go @@ -1,4 +1,4 @@ -package cli +package commands import ( "bytes" @@ -10,11 +10,7 @@ import ( "github.com/spf13/cobra" ) -func init() { - rootCmd.AddCommand(loginCmd) -} - -var loginCmd = &cobra.Command{ +var LoginCmd = &cobra.Command{ Use: "login", Short: "Login to envlink.", Long: `Login to envlink.`, diff --git a/internal/cli/projects.go b/internal/cli/commands/projects.go similarity index 72% rename from internal/cli/projects.go rename to internal/cli/commands/projects.go index b1eeadb..ec508e5 100644 --- a/internal/cli/projects.go +++ b/internal/cli/commands/projects.go @@ -1,4 +1,4 @@ -package cli +package commands import ( "fmt" @@ -6,11 +6,7 @@ import ( "github.com/spf13/cobra" ) -func init() { - rootCmd.AddCommand(projectsCmd) -} - -var projectsCmd = &cobra.Command{ +var ProjectsCmd = &cobra.Command{ Use: "projects", Short: "Lists all the .envs you have stored.", Long: `Lists all the .envs you have stored.`, diff --git a/internal/cli/pull.go b/internal/cli/commands/pull.go similarity index 79% rename from internal/cli/pull.go rename to internal/cli/commands/pull.go index e7dff37..25991f7 100644 --- a/internal/cli/pull.go +++ b/internal/cli/commands/pull.go @@ -1,4 +1,4 @@ -package cli +package commands import ( "fmt" @@ -6,11 +6,7 @@ import ( "github.com/spf13/cobra" ) -func init() { - rootCmd.AddCommand(pullCmd) -} - -var pullCmd = &cobra.Command{ +var PullCmd = &cobra.Command{ Use: "pull", Short: "Pulls the project's latest changes to the .env file.", Long: `Pulls the project's latest changes to the .env file. It will update your local .env whether it is new or not.`, diff --git a/internal/cli/push.go b/internal/cli/commands/push.go similarity index 78% rename from internal/cli/push.go rename to internal/cli/commands/push.go index 346eb45..7fe8770 100644 --- a/internal/cli/push.go +++ b/internal/cli/commands/push.go @@ -1,4 +1,4 @@ -package cli +package commands import ( "fmt" @@ -6,11 +6,7 @@ import ( "github.com/spf13/cobra" ) -func init() { - rootCmd.AddCommand(pushCmd) -} - -var pushCmd = &cobra.Command{ +var PushCmd = &cobra.Command{ Use: "push", Short: "Pushes your project's .env to the database.", Long: `Pushes your project's .env to the database. It will update the entry whether there are new changes or not.`, diff --git a/internal/cli/commands/register.go b/internal/cli/commands/register.go new file mode 100644 index 0000000..20ac4e4 --- /dev/null +++ b/internal/cli/commands/register.go @@ -0,0 +1,150 @@ +package commands + +import ( + "bytes" + "fmt" + "log" + "net" + "net/http" + "time" + "encoding/json" + + cliutils "github.com/jasutiin/envlink/internal/cli/utils" + "github.com/spf13/cobra" +) + +var RegisterCmd = &cobra.Command{ + Use: "register", + Short: "Register to envlink.", + Long: `Register to envlink.`, + Run: func(cmd *cobra.Command, args []string) { + register() + }, +} + +func register() { + var choice string + fmt.Println("1) Email/Password") + fmt.Println("2) Google") + + fmt.Print("Which auth provider would you like to use? ") + fmt.Scanln(&choice) + + switch choice { + case "1": + registerUsingEmailPassword() + case "2": + registerUsingGoogle() + default: + fmt.Println("Cancelled.") + } +} + +func registerUsingEmailPassword() { + var email string + var password string + + fmt.Printf("Email: ") + fmt.Scanln(&email) + + fmt.Printf("Password: ") + fmt.Scanln(&password) + + data := map[string]string{"email": email, "password": password} + jsonBytes, err := json.Marshal(data) // escapes special chars to prevent injection + if err != nil { + fmt.Println("error encoding request body") + return + } + payload := bytes.NewBuffer(jsonBytes) + client := &http.Client{Timeout: 10 * time.Second} + req, err := http.NewRequest("POST", "http://localhost:8080/api/v1/auth/register", payload) + + if err != nil { + fmt.Println("error on creating new POST req for register") + return + } + + req.Header.Set("Content-Type", "application/json") + + resp, err := client.Do(req) + if err != nil { + log.Fatalf("Error performing request: %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + fmt.Printf("request successful with status code: %d\n", resp.StatusCode) + } else { + fmt.Printf("request failed with status code: %d\n", resp.StatusCode) + } +} + +func registerUsingGoogle() { + state, err := cliutils.NewCLISessionID() + if err != nil { + fmt.Println("failed to create auth state") + return + } + + // listen on localhost with a port assigned by OS + listener, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + fmt.Println("failed to start local callback listener") + return + } + defer listener.Close() // this will run when registerUsingGoogle() function returns + + // callback path on the CLI's temporary HTTP server + callbackPath := "/oauth/google/callback" + callbackURL := fmt.Sprintf("http://%s%s", listener.Addr().String(), callbackPath) + + // make the actual server url that the CLI will call to + authURL := cliutils.BuildServerGoogleAuthURL(callbackURL, state) + resultChan := make(chan cliutils.CallbackResult, 1) + server := cliutils.CreateLocalServer(callbackPath, state, resultChan) + + // run a local server for the CLI + go func() { + _ = server.Serve(listener) + }() + + fmt.Println("Opening browser to:", authURL) + if err := cliutils.OpenInBrowser(authURL); err != nil { + fmt.Println("failed to open browser automatically. Open this URL manually:") + fmt.Println(authURL) + _ = server.Close() + return + } + + waitTimer := time.NewTimer(2 * time.Minute) + defer waitTimer.Stop() + + var callback cliutils.CallbackResult + + // either wait until the waitTimer runs out, or if a result was returned from the local server + select { + case callback = <-resultChan: + if callback.Err != nil { + fmt.Printf("google auth did not complete: %v\n", callback.Err) + _ = server.Close() + return + } + case <-waitTimer.C: + fmt.Println("timed out waiting for google authentication") + _ = server.Close() + return + } + + token, err := cliutils.ExchangeServerCode(callback.ExchangeCode, callback.State) + if err != nil { + fmt.Printf("token exchange failed: %v\n", err) + _ = server.Close() + return + } + + _ = server.Close() + + fmt.Println("Google authentication successful.") + fmt.Println("Token:", token.AccessToken) +} \ No newline at end of file diff --git a/internal/cli/store.go b/internal/cli/commands/store.go similarity index 75% rename from internal/cli/store.go rename to internal/cli/commands/store.go index 63ba115..bc11513 100644 --- a/internal/cli/store.go +++ b/internal/cli/commands/store.go @@ -1,4 +1,4 @@ -package cli +package commands import ( "fmt" @@ -6,11 +6,7 @@ import ( "github.com/spf13/cobra" ) -func init() { - rootCmd.AddCommand(storeCmd) -} - -var storeCmd = &cobra.Command{ +var StoreCmd = &cobra.Command{ Use: "store", Short: "Store your secret key.", Long: `Store your secret key that was generated when you first registered.`, diff --git a/internal/cli/register.go b/internal/cli/register.go deleted file mode 100644 index 163032b..0000000 --- a/internal/cli/register.go +++ /dev/null @@ -1,70 +0,0 @@ -package cli - -import ( - "bytes" - "fmt" - "log" - "net/http" - "time" - - "github.com/spf13/cobra" -) - -func init() { - rootCmd.AddCommand(registerCmd) -} - -var registerCmd = &cobra.Command{ - Use: "register", - Short: "Register to envlink.", - Long: `Register to envlink.`, - Run: func(cmd *cobra.Command, args []string) { - register() - }, -} - -func register() { - var email string; - var password string; - - fmt.Printf("Email: ") - fmt.Scanln(&email) - - fmt.Printf("Password: ") - fmt.Scanln(&password) - - if email != "" { - fmt.Println("email provided") - } else { - fmt.Println("email not provided") - } - - if password != "" { - fmt.Println("password provided") - } else { - fmt.Println("password not provided") - } - - jsonStr := []byte(fmt.Sprintf(`{"email":"%s","password":"%s"}`, email, password)) - payload := bytes.NewBuffer(jsonStr) - client := &http.Client{Timeout: 10 * time.Second} - req, err := http.NewRequest("POST", "http://localhost:8080/api/v1/auth/register", payload) - - if err != nil { - fmt.Println("error on creating new POST req for register") - } - - req.Header.Set("Content-Type", "application/json") - - resp, err := client.Do(req) - if err != nil { - log.Fatalf("Error performing request: %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode >= 200 && resp.StatusCode < 300 { - fmt.Printf("request successful with status code: %d\n", resp.StatusCode) - } else { - fmt.Printf("request failed with status code: %d\n", resp.StatusCode) - } -} \ No newline at end of file diff --git a/internal/cli/root.go b/internal/cli/root.go index b4b135c..b0ac327 100644 --- a/internal/cli/root.go +++ b/internal/cli/root.go @@ -4,6 +4,7 @@ import ( "fmt" "os" + "github.com/jasutiin/envlink/internal/cli/commands" "github.com/spf13/cobra" "github.com/spf13/viper" ) @@ -38,12 +39,12 @@ func init() { viper.SetDefault("author", "NAME HERE ") viper.SetDefault("license", "apache") - rootCmd.AddCommand(loginCmd) - rootCmd.AddCommand(registerCmd) - rootCmd.AddCommand(pushCmd) - rootCmd.AddCommand(pullCmd) - rootCmd.AddCommand(projectsCmd) - rootCmd.AddCommand(storeCmd) + rootCmd.AddCommand(commands.LoginCmd) + rootCmd.AddCommand(commands.RegisterCmd) + rootCmd.AddCommand(commands.PushCmd) + rootCmd.AddCommand(commands.PullCmd) + rootCmd.AddCommand(commands.ProjectsCmd) + rootCmd.AddCommand(commands.StoreCmd) } func initConfig() { diff --git a/internal/cli/utils/exchange.go b/internal/cli/utils/exchange.go new file mode 100644 index 0000000..0b72fe5 --- /dev/null +++ b/internal/cli/utils/exchange.go @@ -0,0 +1,171 @@ +package cliutils + +import ( + "bytes" + "crypto/rand" + "encoding/hex" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "os/exec" + "strings" + "time" +) + +const googleAuthStateBytes = 24 + +type tokenExchangeRequest struct { + ExchangeCode string `json:"exchange_code"` + State string `json:"state"` +} + +type tokenExchangeResponse struct { + Token string `json:"token"` +} + +type CallbackResult struct { + ExchangeCode string + State string + Err error +} + +type GoogleTokenResult struct { + AccessToken string +} + +// NewCLISessionID creates a random state value for a CLI OAuth session. +// +// The CLI includes this state in the browser auth request and validates that +// the same value is returned on callback. This binds the callback to the +// original login attempt and helps prevent CSRF/callback-injection attacks. +func NewCLISessionID() (string, error) { + b := make([]byte, googleAuthStateBytes) + if _, err := rand.Read(b); err != nil { + return "", err + } + + return hex.EncodeToString(b), nil +} + +// BuildServerGoogleAuthURL builds the API auth endpoint URL for CLI login. +// It attaches the local callback URL and CLI state so the server can redirect +// back to the CLI listener and preserve request integrity across the flow. +func BuildServerGoogleAuthURL(callbackURL, state string) string { + baseURL := "http://localhost:8080/api/v1/auth/google" + values := url.Values{} + values.Set("cli_callback", callbackURL) + values.Set("cli_state", state) + + return baseURL + "?" + values.Encode() +} + +/* +This function creates a local server on the machine. This is used for listening to the browser's callback +function, which will be called if the server returns successfully. +*/ +func CreateLocalServer(callbackPath, expectedState string, resultChan chan<- CallbackResult) *http.Server { + mux := http.NewServeMux() + server := &http.Server{Handler: mux} + + mux.HandleFunc(callbackPath, func(w http.ResponseWriter, r *http.Request) { + query := r.URL.Query() + + // if there is an OAuth error, display "Authentication failed" in the browser + if oauthErr := strings.TrimSpace(query.Get("error")); oauthErr != "" { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("

Authentication failed

Return to your CLI and try again.

")) + resultChan <- CallbackResult{Err: fmt.Errorf("oauth error: %s", oauthErr)} + return + } + + // if invalid state, display "State validation failed." in the browser + returnedState := strings.TrimSpace(query.Get("state")) + exchangeCode := strings.TrimSpace(query.Get("exchange_code")) + if returnedState == "" || exchangeCode == "" || returnedState != expectedState { + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusBadRequest) + _, _ = w.Write([]byte("

Invalid callback

State validation failed.

")) + resultChan <- CallbackResult{Err: fmt.Errorf("invalid oauth callback state")} + return + } + + // if all checks pass, then authentication was successful. show it in browser + w.Header().Set("Content-Type", "text/html; charset=utf-8") + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("

Authentication complete

You can close this window and return to the CLI.

")) + resultChan <- CallbackResult{ExchangeCode: exchangeCode, State: returnedState} + }) + + return server +} + +/* +ExchangeServerCode sends the one-time exchange code from the local OAuth callback to the API. + +The API validates that: +1) the exchange code exists and is still valid, +2) the state value matches what was originally issued, +3) the code has not already been consumed. + +If validation succeeds, the server returns an auth token for the CLI session. +If validation fails (expired/invalid code, state mismatch, or server rejection), +this function returns an error so the user can retry login. +*/ +func ExchangeServerCode(exchangeCode, state string) (*GoogleTokenResult, error) { + // build the payload the API expects for code/state validation + payload := tokenExchangeRequest{ExchangeCode: exchangeCode, State: state} + payloadBytes, err := json.Marshal(payload) + if err != nil { + return nil, err + } + + // create a bounded-time HTTP client to avoid hanging CLI auth + client := &http.Client{Timeout: 15 * time.Second} + req, err := http.NewRequest("POST", "http://localhost:8080/api/v1/auth/cli/exchange", bytes.NewBuffer(payloadBytes)) + if err != nil { + return nil, err + } + + req.Header.Set("Content-Type", "application/json") + // send the exchange request to the local API + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + // read response body for both success and error branches + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + // non-200 means the server rejected or could not validate the exchange + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("server exchange failed with status %d: %s", resp.StatusCode, string(body)) + } + + // parse successful response and ensure token is present + var exchangeResp tokenExchangeResponse + if err := json.Unmarshal(body, &exchangeResp); err != nil { + return nil, err + } + + if strings.TrimSpace(exchangeResp.Token) == "" { + return nil, fmt.Errorf("empty token response") + } + + return &GoogleTokenResult{AccessToken: exchangeResp.Token}, nil +} + +// OpenInBrowser opens a targetURL in a browser. +func OpenInBrowser(targetURL string) error { + if err := exec.Command("rundll32", "url.dll,FileProtocolHandler", targetURL).Start(); err == nil { + return nil + } + + return exec.Command("cmd", "/c", "start", "", fmt.Sprintf("\"%s\"", targetURL)).Start() +} diff --git a/internal/cli/utils/exchange_test.go b/internal/cli/utils/exchange_test.go new file mode 100644 index 0000000..329b6b5 --- /dev/null +++ b/internal/cli/utils/exchange_test.go @@ -0,0 +1,303 @@ +package cliutils + +import ( + "bytes" + "fmt" + "io" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + "time" +) + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// TestNewCLISessionID verifies the generated session ID is a 24-byte hex string. +func TestNewCLISessionID(t *testing.T) { + // generate a new random CLI session id + sessionID, err := NewCLISessionID() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + // 24 random bytes should be hex-encoded into 48 characters + if len(sessionID) != 48 { + t.Fatalf("expected session id length 48, got %d", len(sessionID)) + } + + // ensure every character is valid lowercase hex + for _, char := range sessionID { + if !strings.ContainsRune("0123456789abcdef", char) { + t.Fatalf("expected hex string, got %q", sessionID) + } + } +} + +// TestBuildServerGoogleAuthURL verifies callback and state are encoded into the auth URL. +func TestBuildServerGoogleAuthURL(t *testing.T) { + callbackURL := "http://127.0.0.1:54001/oauth/google/callback" + state := "test-state" + + // build URL and parse it back for query validation + out := BuildServerGoogleAuthURL(callbackURL, state) + parsed, err := url.Parse(out) + if err != nil { + t.Fatalf("failed to parse url: %v", err) + } + + if parsed.Scheme != "http" || parsed.Host != "localhost:8080" { + t.Fatalf("unexpected base URL: %s", out) + } + + query := parsed.Query() + if query.Get("cli_callback") != callbackURL { + t.Fatalf("expected cli_callback %q, got %q", callbackURL, query.Get("cli_callback")) + } + + if query.Get("cli_state") != state { + t.Fatalf("expected cli_state %q, got %q", state, query.Get("cli_state")) + } +} + +// TestCreateLocalServer_OAuthError verifies OAuth error responses are surfaced to the CLI. +func TestCreateLocalServer_OAuthError(t *testing.T) { + // create handler and simulate callback with OAuth error + resultChan := make(chan CallbackResult, 1) + server := CreateLocalServer("/oauth/google/callback", "expected-state", resultChan) + + req := httptest.NewRequest("GET", "http://localhost/oauth/google/callback?error=access_denied", nil) + recorder := httptest.NewRecorder() + server.Handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, recorder.Code) + } + + // handler should report the error to result channel + select { + case result := <-resultChan: + if result.Err == nil { + t.Fatalf("expected callback error") + } + if !strings.Contains(result.Err.Error(), "oauth error") { + t.Fatalf("expected oauth error message, got %v", result.Err) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("timed out waiting for callback result") + } +} + +// TestCreateLocalServer_InvalidState verifies state mismatches are rejected. +func TestCreateLocalServer_InvalidState(t *testing.T) { + // create handler and simulate callback with mismatched state + resultChan := make(chan CallbackResult, 1) + server := CreateLocalServer("/oauth/google/callback", "expected-state", resultChan) + + req := httptest.NewRequest("GET", "http://localhost/oauth/google/callback?exchange_code=abc&state=wrong-state", nil) + recorder := httptest.NewRecorder() + server.Handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusBadRequest { + t.Fatalf("expected status %d, got %d", http.StatusBadRequest, recorder.Code) + } + + // handler should return state validation error through channel + select { + case result := <-resultChan: + if result.Err == nil { + t.Fatalf("expected callback error") + } + if !strings.Contains(result.Err.Error(), "invalid oauth callback state") { + t.Fatalf("expected invalid state error, got %v", result.Err) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("timed out waiting for callback result") + } +} + +// TestCreateLocalServer_Success verifies a valid callback returns exchange code and state. +func TestCreateLocalServer_Success(t *testing.T) { + // create handler and simulate a valid callback payload + resultChan := make(chan CallbackResult, 1) + expectedState := "expected-state" + expectedCode := "exchange-code" + server := CreateLocalServer("/oauth/google/callback", expectedState, resultChan) + + req := httptest.NewRequest("GET", "http://localhost/oauth/google/callback?exchange_code="+expectedCode+"&state="+expectedState, nil) + recorder := httptest.NewRecorder() + server.Handler.ServeHTTP(recorder, req) + + if recorder.Code != http.StatusOK { + t.Fatalf("expected status %d, got %d", http.StatusOK, recorder.Code) + } + + // handler should pass code/state back to the CLI via channel + select { + case result := <-resultChan: + if result.Err != nil { + t.Fatalf("expected no error, got %v", result.Err) + } + if result.ExchangeCode != expectedCode { + t.Fatalf("expected exchange code %q, got %q", expectedCode, result.ExchangeCode) + } + if result.State != expectedState { + t.Fatalf("expected state %q, got %q", expectedState, result.State) + } + case <-time.After(200 * time.Millisecond): + t.Fatalf("timed out waiting for callback result") + } +} + +// TestExchangeServerCode_Success verifies a valid server response returns an access token. +func TestExchangeServerCode_Success(t *testing.T) { + // replace default transport so no real HTTP call is made + originalTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + // validate outgoing request shape before returning mocked response + if req.URL.String() != "http://localhost:8080/api/v1/auth/cli/exchange" { + t.Fatalf("unexpected request URL: %s", req.URL.String()) + } + + if req.Method != http.MethodPost { + t.Fatalf("expected method POST, got %s", req.Method) + } + + if req.Header.Get("Content-Type") != "application/json" { + t.Fatalf("expected application/json content type, got %s", req.Header.Get("Content-Type")) + } + + body, err := io.ReadAll(req.Body) + if err != nil { + t.Fatalf("failed to read request body: %v", err) + } + + if !bytes.Contains(body, []byte(`"exchange_code":"code123"`)) || !bytes.Contains(body, []byte(`"state":"state123"`)) { + t.Fatalf("unexpected request body: %s", string(body)) + } + + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"token":"jwt-token"}`)), + Header: make(http.Header), + }, nil + }) + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + + // function should parse token from mocked 200 response + result, err := ExchangeServerCode("code123", "state123") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + if result == nil { + t.Fatalf("expected token result") + } + + if result.AccessToken != "jwt-token" { + t.Fatalf("expected token jwt-token, got %s", result.AccessToken) + } +} + +// TestExchangeServerCode_NonOKResponse verifies non-200 server responses return an error. +func TestExchangeServerCode_NonOKResponse(t *testing.T) { + // mock unauthorized response from API + originalTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusUnauthorized, + Body: io.NopCloser(strings.NewReader(`{"error":"invalid"}`)), + Header: make(http.Header), + }, nil + }) + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + + // function should surface status code in returned error + _, err := ExchangeServerCode("bad-code", "state123") + if err == nil { + t.Fatalf("expected error for non-200 response") + } + + if !strings.Contains(err.Error(), "server exchange failed with status 401") { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestExchangeServerCode_EmptyToken verifies empty/blank token payloads are rejected. +func TestExchangeServerCode_EmptyToken(t *testing.T) { + // mock success status with blank token payload + originalTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`{"token":" "}`)), + Header: make(http.Header), + }, nil + }) + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + + // function should reject empty token responses + _, err := ExchangeServerCode("code123", "state123") + if err == nil { + t.Fatalf("expected error for empty token") + } + + if !strings.Contains(err.Error(), "empty token response") { + t.Fatalf("unexpected error: %v", err) + } +} + +// TestExchangeServerCode_InvalidJSON verifies malformed JSON responses return an error. +func TestExchangeServerCode_InvalidJSON(t *testing.T) { + // mock malformed JSON body from API + originalTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return &http.Response{ + StatusCode: http.StatusOK, + Body: io.NopCloser(strings.NewReader(`not-json`)), + Header: make(http.Header), + }, nil + }) + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + + // function should return JSON parse error + _, err := ExchangeServerCode("code123", "state123") + if err == nil { + t.Fatalf("expected JSON unmarshal error") + } +} + +// TestExchangeServerCode_DoError verifies transport-level HTTP failures are propagated. +func TestExchangeServerCode_DoError(t *testing.T) { + // mock transport-level failure before any response is received + originalTransport := http.DefaultTransport + http.DefaultTransport = roundTripperFunc(func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("transport failure") + }) + t.Cleanup(func() { + http.DefaultTransport = originalTransport + }) + + // function should propagate the transport error + _, err := ExchangeServerCode("code123", "state123") + if err == nil { + t.Fatalf("expected transport error") + } + + if !strings.Contains(err.Error(), "transport failure") { + t.Fatalf("unexpected error: %v", err) + } +} \ No newline at end of file diff --git a/internal/server/api/auth/auth.go b/internal/server/api/auth/auth.go new file mode 100644 index 0000000..db3bb1c --- /dev/null +++ b/internal/server/api/auth/auth.go @@ -0,0 +1,58 @@ +package auth + +import ( + "errors" + "fmt" + "net/http" + "os" + + "github.com/gorilla/sessions" + "github.com/markbates/goth" + "github.com/markbates/goth/gothic" + "github.com/markbates/goth/providers/google" +) + +/* +This function is called to initialize the gothic package with the external +providers we will be using for OAuth. +*/ +func NewAuth(port string, domain string, key string, isProd bool) error { + // key needs to be 32-bytes as per NewCookieStore + if len(key) < 32 { + return errors.New("session key must be at least 32 bytes") + } + + googleClientId := os.Getenv("GOOGLE_CLIENT_ID") + if googleClientId == "" { + return errors.New("Google Client Id was not provided!") + } + + googleClientSecret := os.Getenv("GOOGLE_CLIENT_SECRET") + if googleClientSecret == "" { + return errors.New("Google Client Secret was not provided!") + } + + store := sessions.NewCookieStore([]byte(key)) + store.Options = &sessions.Options{ + Path: "/", // cookie is valid for all paths on the host + MaxAge: 86400 * 30, + HttpOnly: true, + Secure: isProd, + SameSite: http.SameSiteLaxMode, + } + + var url string + + if domain == "" { + url = fmt.Sprintf("http://localhost:%s/api/v1/auth/google/callback", port) + } else { + url = fmt.Sprintf("https://%s/api/v1/auth/google/callback", domain) + } + + gothic.Store = store + goth.UseProviders( + google.New(googleClientId, googleClientSecret, url), + ) + + return nil +} \ No newline at end of file diff --git a/internal/server/api/auth/controllers.go b/internal/server/api/auth/controllers.go new file mode 100644 index 0000000..8c36e45 --- /dev/null +++ b/internal/server/api/auth/controllers.go @@ -0,0 +1,174 @@ +package auth + +import ( + "context" + "fmt" + "html" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + serverutils "github.com/jasutiin/envlink/internal/server/utils" + "github.com/markbates/goth/gothic" +) + +type authRequestBody struct { + Email string + Password string +} + +type cliTokenExchangeRequest struct { + ExchangeCode string `json:"exchange_code"` + State string `json:"state"` +} + +type AuthController struct { + repo AuthRepository +} + +func NewController(repo AuthRepository) *AuthController { + return &AuthController{repo: repo} +} + +func (controller *AuthController) postLogin(c *gin.Context) { + var requestBody authRequestBody + if err := c.BindJSON(&requestBody); err != nil { + c.IndentedJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + + if _, err := controller.repo.Login(requestBody.Email, requestBody.Password); err != nil { + c.IndentedJSON(http.StatusInternalServerError, gin.H{"error": "failed to login"}) + return + } + + c.IndentedJSON(http.StatusOK, requestBody) +} + +func (controller *AuthController) postRegister(c *gin.Context) { + var requestBody authRequestBody + if err := c.BindJSON(&requestBody); err != nil { + c.IndentedJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + + if _, err := controller.repo.Register(requestBody.Email, requestBody.Password); err != nil { + c.IndentedJSON(http.StatusInternalServerError, gin.H{"error": "failed to register"}) + return + } + + c.IndentedJSON(http.StatusOK, requestBody) +} + +/* +This function is called when we navigate to the '/api/v1/auth/:provider' endpoint on the server, initiating the auth process +for a provider. It takes a state and callback url passed in as query parameters, and stores them on the server so that we +can use it later to navigate the caller browser back to this url. Finally, it takes the user to the provider's login screen. +*/ +func (controller *AuthController) getAuthProvider(c *gin.Context) { + provider := c.Param("provider") + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "provider", provider)) + + cliCallback := strings.TrimSpace(c.Query("cli_callback")) + cliState := strings.TrimSpace(c.Query("cli_state")) + + if cliCallback != "" || cliState != "" { + if cliCallback == "" || cliState == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "cli_callback and cli_state are required together"}) + return + } + + if !serverutils.IsAllowedCLICallback(cliCallback) { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid cli_callback"}) + return + } + + // create httpOnly cookies to use on subsequent requests + serverutils.WriteCLIAuthContext(c, cliCallback, cliState) + } + + gothic.BeginAuthHandler(c.Writer, c.Request) +} + +/* +This function is called when we navigate to the '/api/v1/auth/:provider/callback' endpoint on the server. It is called when the user +successfully authenticates. It takes the user's credentials and validates it. It generates an exchange code, then it takes the +callbackURL and state that we stored on the server and builds the url with the exchange code. It saves all of this information on the +server we can do another validation. +*/ +func (controller *AuthController) getAuthCallbackFunction(c *gin.Context) { + provider := c.Param("provider") + c.Request = c.Request.WithContext(context.WithValue(c.Request.Context(), "provider", provider)) + + user, err := gothic.CompleteUserAuth(c.Writer, c.Request) + if err != nil { + c.Data(http.StatusBadRequest, "text/html; charset=utf-8", []byte("

Authentication failed

Please return to your CLI and try again.

")) + return + } + + if callbackURL, callbackState, found := serverutils.ReadCLIAuthContext(c); found { + exchangeCode, codeErr := serverutils.NewExchangeCode() + if codeErr == nil { + // save all of this information on the server so we can refer back to it later + serverutils.PendingCLIExchanges.Save(exchangeCode, callbackState, user.AccessToken, serverutils.CLIExchangeTTL) + if redirectURL, redirectErr := serverutils.BuildCLIRedirectURL(callbackURL, exchangeCode, callbackState); redirectErr == nil { + serverutils.ClearCLIAuthContext(c) + c.Redirect(http.StatusFound, redirectURL) + return + } + } + } + + serverutils.ClearCLIAuthContext(c) + + html := fmt.Sprintf(` + + + Auth successful + +

Authentication successful

+

Name: %s

+

Email: %s

+

Provider: %s

+

You can close this window and return to the CLI.

+ + + `, html.EscapeString(user.Name), html.EscapeString(user.Email), html.EscapeString(user.Provider)) + + c.Data(http.StatusOK, "text/html; charset=utf-8", []byte(html)) +} + +/* +This function is called when we navigate to the '/api/v1/auth/cli/exchange' endpoint on the server. It is called when the browser +received the exchange code from the server, and sends it back to the server for one last validation check. If the exchange code +that the browser sent matches the one on the server, then we return an authentication token. +*/ +func (controller *AuthController) postCLIExchange(c *gin.Context) { + var requestBody cliTokenExchangeRequest + if err := c.BindJSON(&requestBody); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) + return + } + + exchangeCode := strings.TrimSpace(requestBody.ExchangeCode) + state := strings.TrimSpace(requestBody.State) + if exchangeCode == "" || state == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "exchange_code and state are required"}) + return + } + + // check if the exchange code and state is stored in the server + token, found := serverutils.PendingCLIExchanges.Consume(exchangeCode, state) + if !found { + c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid or expired exchange_code"}) + return + } + + c.JSON(http.StatusOK, gin.H{"token": token}) +} + +func (controller *AuthController) getLogoutProvider(c *gin.Context) { + gothic.Logout(c.Writer, c.Request) + c.Writer.Header().Set("Location", "/") + c.Writer.WriteHeader(http.StatusTemporaryRedirect) +} diff --git a/internal/server/auth/repository.go b/internal/server/api/auth/repository.go similarity index 100% rename from internal/server/auth/repository.go rename to internal/server/api/auth/repository.go diff --git a/internal/server/auth/router.go b/internal/server/api/auth/router.go similarity index 57% rename from internal/server/auth/router.go rename to internal/server/api/auth/router.go index 0194026..94aca5c 100644 --- a/internal/server/auth/router.go +++ b/internal/server/api/auth/router.go @@ -13,5 +13,9 @@ func AuthRouter(router *gin.RouterGroup, db *gorm.DB) { { auth.POST("/login", controller.postLogin) auth.POST("/register", controller.postRegister) + auth.GET("/:provider", controller.getAuthProvider) + auth.GET("/:provider/callback", controller.getAuthCallbackFunction) + auth.POST("/cli/exchange", controller.postCLIExchange) + auth.GET("/:provider/logout", controller.getLogoutProvider) } } \ No newline at end of file diff --git a/internal/server/auth/types.go b/internal/server/api/auth/types.go similarity index 100% rename from internal/server/auth/types.go rename to internal/server/api/auth/types.go diff --git a/internal/server/projects/controllers.go b/internal/server/api/projects/controllers.go similarity index 100% rename from internal/server/projects/controllers.go rename to internal/server/api/projects/controllers.go diff --git a/internal/server/projects/router.go b/internal/server/api/projects/router.go similarity index 100% rename from internal/server/projects/router.go rename to internal/server/api/projects/router.go diff --git a/internal/server/pull/controllers.go b/internal/server/api/pull/controllers.go similarity index 100% rename from internal/server/pull/controllers.go rename to internal/server/api/pull/controllers.go diff --git a/internal/server/pull/router.go b/internal/server/api/pull/router.go similarity index 100% rename from internal/server/pull/router.go rename to internal/server/api/pull/router.go diff --git a/internal/server/push/controllers.go b/internal/server/api/push/controllers.go similarity index 100% rename from internal/server/push/controllers.go rename to internal/server/api/push/controllers.go diff --git a/internal/server/push/router.go b/internal/server/api/push/router.go similarity index 100% rename from internal/server/push/router.go rename to internal/server/api/push/router.go diff --git a/internal/server/auth/controllers.go b/internal/server/auth/controllers.go deleted file mode 100644 index a65e85b..0000000 --- a/internal/server/auth/controllers.go +++ /dev/null @@ -1,55 +0,0 @@ -package auth - -import ( - "fmt" - "net/http" - - "github.com/gin-gonic/gin" -) - -type authRequestBody struct { - Email string - Password string -} - -type AuthController struct { - repo AuthRepository -} - -func NewController(repo AuthRepository) *AuthController { - return &AuthController{repo: repo} -} - -func (controller *AuthController) postLogin(c *gin.Context) { - var requestBody authRequestBody - if err := c.BindJSON(&requestBody); err != nil { - c.IndentedJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) - return - } - - if _, err := controller.repo.Login(requestBody.Email, requestBody.Password); err != nil { - c.IndentedJSON(http.StatusInternalServerError, gin.H{"error": "failed to login"}) - return - } - - fmt.Println(requestBody.Email) - fmt.Println(requestBody.Password) - c.IndentedJSON(http.StatusOK, requestBody) -} - -func (controller *AuthController) postRegister(c *gin.Context) { - var requestBody authRequestBody - if err := c.BindJSON(&requestBody); err != nil { - c.IndentedJSON(http.StatusBadRequest, gin.H{"error": "invalid request body"}) - return - } - - if _, err := controller.repo.Register(requestBody.Email, requestBody.Password); err != nil { - c.IndentedJSON(http.StatusInternalServerError, gin.H{"error": "failed to register"}) - return - } - - fmt.Println(requestBody.Email) - fmt.Println(requestBody.Password) - c.IndentedJSON(http.StatusOK, requestBody) -} \ No newline at end of file diff --git a/internal/server/database/database.go b/internal/server/database/database.go index b15b3a9..cd05903 100644 --- a/internal/server/database/database.go +++ b/internal/server/database/database.go @@ -4,7 +4,7 @@ import ( "fmt" "os" - "github.com/jasutiin/envlink/internal/server/auth" + "github.com/jasutiin/envlink/internal/server/api/auth" "github.com/joho/godotenv" "gorm.io/driver/postgres" "gorm.io/gorm" @@ -28,12 +28,12 @@ func CreateDB() *gorm.DB { if err != nil { fmt.Println("failed to open db") } - + return db } func AutoMigrate(db *gorm.DB) { if err := db.AutoMigrate(&auth.User{}); err != nil { - fmt.Println("migrate failed:", err) + fmt.Println("migrate failed:", err) } -} \ No newline at end of file +} diff --git a/internal/server/utils/exchange.go b/internal/server/utils/exchange.go new file mode 100644 index 0000000..86073e4 --- /dev/null +++ b/internal/server/utils/exchange.go @@ -0,0 +1,173 @@ +package serverutils + +import ( + "crypto/rand" + "encoding/hex" + "net/url" + "strings" + "sync" + "time" + + "github.com/gin-gonic/gin" +) + +const CLIExchangeTTL = 2 * time.Minute + +type cliExchangeStore struct { + mu sync.Mutex + entries map[string]cliExchangeEntry +} + +type cliExchangeEntry struct { + token string + state string + expiresAt time.Time +} + +const ( + cliCallbackCookieName = "envlink_cli_callback" + cliStateCookieName = "envlink_cli_state" + cliCookieTTLSeconds = 300 +) + +/* +newCLIExchangeStore initializes a new store and creates a new cliExchangeEntry map +so it is allocated. It doesn't create a mutex because the mutex's default value of 0 means +that it is unlocked. +*/ +func newCLIExchangeStore() *cliExchangeStore { + return &cliExchangeStore{entries: make(map[string]cliExchangeEntry)} +} + +/* +Save saves a new cliExchangeEntry to the cliExchangeStore. +*/ +func (store *cliExchangeStore) Save(exchangeCode, state, token string, ttl time.Duration) { + if exchangeCode == "" || state == "" || token == "" { + return + } + + store.mu.Lock() + defer store.mu.Unlock() // store.mu.Unlock() is called before Save() returns + + store.entries[exchangeCode] = cliExchangeEntry{ + token: token, + state: state, + expiresAt: time.Now().Add(ttl), + } +} + +/* +Consume consumes a cliExchangeEntry given an exchange code +*/ +func (store *cliExchangeStore) Consume(exchangeCode, state string) (string, bool) { + store.mu.Lock() + defer store.mu.Unlock() // store.mu.Unlock() is called before Consume() returns + + entry, found := store.entries[exchangeCode] + if !found { + return "", false + } + + delete(store.entries, exchangeCode) // deletes the entry with exchangeCode as its key from the store.entries map + + if time.Now().After(entry.expiresAt) { + return "", false + } + + if entry.state != state { + return "", false + } + + return entry.token, true +} + +var PendingCLIExchanges = newCLIExchangeStore() + +/* +isAllowedCLICallback checks if the callback url is something valid that a user +initiated themselves. This prevents the server from returning a different +callback url that the user expects. If we did not check this, the user may +be taken to a malicious website. +*/ +func IsAllowedCLICallback(rawCallbackURL string) bool { + parsedURL, err := url.Parse(rawCallbackURL) + if err != nil { + return false + } + + if parsedURL.Scheme != "http" { + return false + } + + hostName := strings.ToLower(parsedURL.Hostname()) + return hostName == "localhost" || hostName == "127.0.0.1" || hostName == "::1" +} + +/* +writeCLIAuthContext sets httpOnly cookies for the callback url and state separately, both with an expiration time. +It adds it to the Gin context object which adds it to the response the server sends back. From there, +the browser would be sending these cookies to the server upon each subsequent request. +*/ +func WriteCLIAuthContext(c *gin.Context, callbackURL, state string) { + c.SetCookie(cliCallbackCookieName, url.QueryEscape(callbackURL), cliCookieTTLSeconds, "/", "", false, true) + c.SetCookie(cliStateCookieName, state, cliCookieTTLSeconds, "/", "", false, true) +} + +/* +readCLIAuthContext checks if the caller has cookies storing the callback url and state. +*/ +func ReadCLIAuthContext(c *gin.Context) (string, string, bool) { + callbackCookie, callbackErr := c.Cookie(cliCallbackCookieName) + stateCookie, stateErr := c.Cookie(cliStateCookieName) + if callbackErr != nil || stateErr != nil { + return "", "", false + } + + decodedCallback, decodeErr := url.QueryUnescape(callbackCookie) + if decodeErr != nil || !IsAllowedCLICallback(decodedCallback) { + return "", "", false + } + + return decodedCallback, stateCookie, true +} + +/* +clearCLIAuthContext clears cookies from the response, signalling that we have successfully +received the user's credentials. +*/ +func ClearCLIAuthContext(c *gin.Context) { + c.SetCookie(cliCallbackCookieName, "", -1, "/", "", false, true) + c.SetCookie(cliStateCookieName, "", -1, "/", "", false, true) +} + +/* +newExchangeCode creates a new exchange code that the browser will use to verify +against the server. +*/ +func NewExchangeCode() (string, error) { + b := make([]byte, 24) + if _, err := rand.Read(b); err != nil { + return "", err + } + + return hex.EncodeToString(b), nil +} + +/* +buildCLIRedirectURL creates a redirect URL with the callback URL and exchange code that we +received after logging in with an auth provider. +*/ +func BuildCLIRedirectURL(callbackURL, exchangeCode, state string) (string, error) { + parsedURL, err := url.Parse(callbackURL) + if err != nil { + return "", err + } + + queryValues := parsedURL.Query() + queryValues.Set("exchange_code", exchangeCode) + queryValues.Set("state", state) + parsedURL.RawQuery = queryValues.Encode() + + return parsedURL.String(), nil +} diff --git a/internal/server/utils/exchange_test.go b/internal/server/utils/exchange_test.go new file mode 100644 index 0000000..25c76ed --- /dev/null +++ b/internal/server/utils/exchange_test.go @@ -0,0 +1,106 @@ +package serverutils + +import ( + "net/url" + "testing" + "time" +) + +// TestIsAllowedCLICallback tests multiple URLs to see if they are allowed as a callback url. +func TestIsAllowedCLICallback(t *testing.T) { + allowed := []string{ + "http://localhost:8080/cb", + "http://127.0.0.1/cb", + "http://[::1]/cb", + } + + // all of the urls in allowed slice should return true + for _, u := range allowed { + if !IsAllowedCLICallback(u) { + t.Errorf("expected allowed for %s", u) + } + } + + disallowed := []string{ + "https://localhost/cb", + "http://example.com/cb", + "notaurl", + } + + // all of the urls in disallowed slice should return false + for _, u := range disallowed { + if IsAllowedCLICallback(u) { + t.Errorf("expected disallowed for %s", u) + } + } +} + +// TestBuildCLIRedirectURL tests if BuildCLIRedirectURL() correctly appends +// exchange_code and state to the callback's query string. +func TestBuildCLIRedirectURL(t *testing.T) { + callback := "http://localhost:3000/cb?foo=bar" + code := "abc123" + state := "s1" + + out, err := BuildCLIRedirectURL(callback, code, state) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + + parsed, err := url.Parse(out) + if err != nil { + t.Fatalf("failed to parse: %v", err) + } + + q := parsed.Query() + if q.Get("exchange_code") != code { + t.Fatalf("expected exchange_code %s got %s", code, q.Get("exchange_code")) + } + if q.Get("state") != state { + t.Fatalf("expected state %s got %s", state, q.Get("state")) + } +} + +// TestNewExchangeCode tests if NewExchangeCode() returns a non-error hex string +// of 48 characters or 24 encoded bytes. +func TestNewExchangeCode(t *testing.T) { + code, err := NewExchangeCode() + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(code) != 48 { + t.Fatalf("expected code length 48 got %d", len(code)) + } +} + +// TestCLIExchangeStore_SaveConsume tests the cliExchangeStore's Save() and Consume() +// functions to see if they are working properly. +func TestCLIExchangeStore_SaveConsume(t *testing.T) { + store := newCLIExchangeStore() + + store.Save("code1", "state1", "token1", time.Minute) // save a new entry + token, found := store.Consume("code1", "state1") + if !found || token != "token1" { + t.Fatalf("expected to find token, got found=%v tok=%s", found, token) + } + + // second consume should fail because we've already consumed it in the prev call + _, found = store.Consume("code1", "state1") + if found { + t.Fatalf("expected second consume to fail") + } + + // wrong state + store.Save("code2", "state2", "token2", time.Minute) + _, found = store.Consume("code2", "wrong") + if found { + t.Fatalf("expected consume with wrong state to fail") + } + + // expired + store.Save("code3", "state3", "token3", -time.Second) + _, found = store.Consume("code3", "state3") + if found { + t.Fatalf("expected expired consume to fail") + } +}