From 57171d8ae60bff59cf11ec9dc8249e9a1b437a29 Mon Sep 17 00:00:00 2001 From: thde Date: Fri, 13 Mar 2026 14:29:05 +0100 Subject: [PATCH] feat(exec): add exec support for on-demand services --- exec/cmd.go | 398 ++++++++++++++++++++++++++++++++ exec/cmd_test.go | 75 ++++++ exec/exec.go | 10 +- exec/keyvaluestore.go | 94 ++++++++ exec/keyvaluestore_test.go | 144 ++++++++++++ exec/mysql.go | 101 ++++++++ exec/mysql_test.go | 134 +++++++++++ exec/mysqldatabase.go | 64 +++++ exec/mysqldatabase_test.go | 100 ++++++++ exec/postgres.go | 106 +++++++++ exec/postgres_test.go | 146 ++++++++++++ exec/postgresdatabase.go | 60 +++++ exec/postgresdatabase_test.go | 100 ++++++++ get/apiserviceaccount.go | 2 +- get/bucketuser.go | 2 +- get/database.go | 4 +- get/get.go | 12 +- get/keyvaluestore.go | 2 +- get/opensearch.go | 2 +- get/postgres.go | 2 +- get/postgresdatabase.go | 20 +- internal/ipcheck/client.go | 152 ++++++++++++ internal/ipcheck/client_test.go | 72 ++++++ main.go | 3 +- 24 files changed, 1780 insertions(+), 25 deletions(-) create mode 100644 exec/cmd.go create mode 100644 exec/cmd_test.go create mode 100644 exec/keyvaluestore.go create mode 100644 exec/keyvaluestore_test.go create mode 100644 exec/mysql.go create mode 100644 exec/mysql_test.go create mode 100644 exec/mysqldatabase.go create mode 100644 exec/mysqldatabase_test.go create mode 100644 exec/postgres.go create mode 100644 exec/postgres_test.go create mode 100644 exec/postgresdatabase.go create mode 100644 exec/postgresdatabase_test.go create mode 100644 internal/ipcheck/client.go create mode 100644 internal/ipcheck/client_test.go diff --git a/exec/cmd.go b/exec/cmd.go new file mode 100644 index 00000000..bd4dc50f --- /dev/null +++ b/exec/cmd.go @@ -0,0 +1,398 @@ +package exec + +import ( + "context" + "errors" + "fmt" + "io" + "net" + "net/netip" + "os" + "os/exec" + "time" + + "github.com/crossplane/crossplane-runtime/pkg/resource" + "github.com/mattn/go-isatty" + meta "github.com/ninech/apis/meta/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/get" + "github.com/ninech/nctl/internal/cli" + "github.com/ninech/nctl/internal/format" + "github.com/ninech/nctl/internal/ipcheck" +) + +// cmdExecutor encapsulates resource-specific logic for connecting via an external CLI. +type cmdExecutor[T resource.Managed] interface { + // Command returns the CLI binary name (e.g. "psql", "mysql", "redis-cli"). + Command() string + + // Endpoint returns "host:port" for the TCP connectivity check. + Endpoint(res T) string + + // Args builds CLI arguments from the resource and credentials. + // The returned cleanup func removes any temp files created (e.g. CA cert). + Args(res T, user, pw string) (args []string, cleanup func(), err error) +} + +// accessManager extends cmdExecutor for resources that have access restrictions. +type accessManager[T resource.Managed] interface { + // AllowedCIDRs returns the current list of allowed CIDRs for the resource. + AllowedCIDRs(res T) []meta.IPv4CIDR + + // Update patches the resource to allow the given CIDRs. + Update(ctx context.Context, client *api.Client, res T, cidrs []meta.IPv4CIDR) error +} + +// serviceCmd is the shared base for all database exec sub-commands. +type serviceCmd struct { + resourceCmd + format.Writer `kong:"-"` + format.Reader `kong:"-"` + AllowedCidrs *[]meta.IPv4CIDR `placeholder:"203.0.113.1/32" help:"Specifies the IP addresses allowed to connect to the instance. Overrides auto-detected public IP."` + WaitTimeout time.Duration `default:"3m" help:"Timeout waiting for connectivity."` + ExtraArgs []string `arg:"" optional:"" passthrough:"" help:"Additional flags passed to the CLI (after --)."` + + // Internal dependencies — nil means use production default. + runCommand func(ctx context.Context, name string, args []string) error `kong:"-"` + lookPath func(file string) (string, error) `kong:"-"` + waitForConnectivity func(ctx context.Context, writer format.Writer, endpoint string, timeout time.Duration) error `kong:"-"` + openTTYForConfirm func() (io.ReadCloser, error) `kong:"-"` +} + +// BeforeApply initializes Writer and Reader from Kong's bound io.Writer and io.Reader. +func (cmd *serviceCmd) BeforeApply(writer io.Writer, reader io.Reader) error { + return errors.Join( + cmd.Writer.BeforeApply(writer), + cmd.Reader.BeforeApply(reader), + ) +} + +func (cmd serviceCmd) getRunCommand() func(context.Context, string, []string) error { + if cmd.runCommand != nil { + return cmd.runCommand + } + + return func(ctx context.Context, name string, args []string) error { + cmd := exec.CommandContext(ctx, name, args...) + cmd.Stdin = os.Stdin + cmd.Stdout = os.Stdout + cmd.Stderr = os.Stderr + return cmd.Run() + } +} + +func (cmd serviceCmd) getLookPath() func(string) (string, error) { + if cmd.lookPath != nil { + return cmd.lookPath + } + + return exec.LookPath +} + +func (cmd serviceCmd) connectivityCheck() func(context.Context, format.Writer, string, time.Duration) error { + if cmd.waitForConnectivity != nil { + return cmd.waitForConnectivity + } + + return waitForConnectivity +} + +// openTTY returns the openTTY function to use for confirming prompts. +func (cmd serviceCmd) openTTY() func() (io.ReadCloser, error) { + if cmd.openTTYForConfirm != nil { + return cmd.openTTYForConfirm + } + + return func() (io.ReadCloser, error) { + return os.Open("/dev/tty") + } +} + +// connectAndExec is the main orchestration function for exec commands. +// It handles path checking, connectivity waiting, and credential retrieval. +func connectAndExec[T resource.Managed]( + ctx context.Context, + client *api.Client, + res T, + connector cmdExecutor[T], + opts serviceCmd, +) error { + if err := opts.checkPath(connector.Command()); err != nil { + return err + } + + endpoint := connector.Endpoint(res) + if endpoint == "" { + return fmt.Errorf("resource %q is not ready yet (no endpoint available)", res.GetName()) + } + + if !quickDial(ctx, endpoint) { + if am, ok := connector.(accessManager[T]); ok { + if err := ensureAccess(ctx, client, am, res, opts); err != nil { + return err + } + } + + if err := opts.connectivityCheck()(ctx, opts.Writer, endpoint, opts.WaitTimeout); err != nil { + return err + } + } + + user, pw, err := getCredentials(ctx, client, res) + if err != nil { + return err + } + + args, cleanup, err := connector.Args(res, user, pw) + if err != nil { + return fmt.Errorf("building CLI arguments: %w", err) + } + defer cleanup() + + args = append(args, opts.ExtraArgs...) + + if err := opts.getRunCommand()(ctx, connector.Command(), args); err != nil { + if exitErr, ok := errors.AsType[*exec.ExitError](err); ok { + return cli.ErrorWithContext(err).WithExitCode(exitErr.ExitCode()) + } + return err + } + + return nil +} + +// ensureAccess detects the caller's public IP (or uses the overridden list), +// checks whether it is already permitted, and if not prompts the user before +// calling connector.Update. +func ensureAccess[T resource.Managed]( + ctx context.Context, + client *api.Client, + connector accessManager[T], + res T, + cmd serviceCmd, +) error { + var toAdd []meta.IPv4CIDR + + if cmd.AllowedCidrs != nil { + toAdd = *cmd.AllowedCidrs + + if cidrsPresent(connector.AllowedCIDRs(res), toAdd) { + cmd.Infof("✅", "specified CIDRs are already allowed") + return nil + } + } else { + ip, err := ipcheck.New(ipcheck.WithUserAgent(cli.Name)).PublicIP(ctx) + if err != nil { + return cli.ErrorWithContext(fmt.Errorf("detecting public IP address: %w", err)). + WithSuggestions("Are you connected to the internet?") + } + if ip.Blocked { + return cli.ErrorWithContext(fmt.Errorf("public IP seems to be blocked")). + WithContext("IP", ip.RemoteAddr.String()). + WithSuggestions("Reach out to support@nine.ch.") + } + cmd.Infof("🌐", "detected public IP: %s", ip.RemoteAddr) + + if cidr := ipCoveredByCIDRs(ip.RemoteAddr, connector.AllowedCIDRs(res)); cidr != nil { + cmd.Infof("✅", "IP %s is already allowedby %s", ip.RemoteAddr, cidr.String()) + return nil + } + + toAdd = []meta.IPv4CIDR{meta.IPv4CIDR(netip.PrefixFrom(ip.RemoteAddr, 32).String())} + } + + msg := fmt.Sprintf("Add %v to the allowed CIDRs of %q?", toAdd, res.GetName()) + ok, err := cmd.confirm(msg) + if err != nil { + return err + } + if !ok { + return fmt.Errorf("CIDR addition canceled") + } + + // Merge with existing CIDRs. + merged := appendMissing(connector.AllowedCIDRs(res), toAdd) + if err := connector.Update(ctx, client, res, merged); err != nil { + return fmt.Errorf("updating allowed CIDRs: %w", err) + } + + cmd.Infof("ℹ️", "to remove this CIDR later: nctl update %s %s", res.GetObjectKind().GroupVersionKind().Kind, res.GetName()) + + return nil +} + +// waitForConnectivity dials endpoint in a retry loop until it succeeds or timeout expires. +func waitForConnectivity(ctx context.Context, writer format.Writer, endpoint string, timeout time.Duration) error { + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + spinner, err := writer.Spinner( + format.Progressf("⏳", "waiting for connectivity to %s", endpoint), + format.Progressf("✅", "connected to %s", endpoint), + ) + if err != nil { + return err + } + + _ = spinner.Start() + defer func() { _ = spinner.Stop() }() + + ticker := time.NewTicker(2 * time.Second) + defer ticker.Stop() + + for { + attemptCtx, attemptCancel := context.WithTimeout(ctx, 3*time.Second) + dialErr := dialTCP(attemptCtx, endpoint) + attemptCancel() + if dialErr == nil { + _ = spinner.Stop() + return nil + } + + select { + case <-ctx.Done(): + switch ctx.Err() { + case context.DeadlineExceeded: + msg := "timeout waiting for connectivity to %s" + spinner.StopFailMessage(format.Progressf("", msg, endpoint)) + _ = spinner.StopFail() + return fmt.Errorf(msg, endpoint) + default: + _ = spinner.StopFail() + return nil + } + case <-ticker.C: + } + } +} + +// checkPath verifies that the named CLI binary is installed and on PATH. +func (cmd serviceCmd) checkPath(name string) error { + if _, err := cmd.getLookPath()(name); err != nil { + return cli.ErrorWithContext(fmt.Errorf("%q CLI not found", name)). + WithSuggestions( + fmt.Sprintf("Install %q and ensure it is available in your PATH.", name), + ) + } + return nil +} + +// confirm prints a confirmation prompt. When stdin is not a TTY it opens /dev/tty +// so that piped input (e.g. SQL dumps) does not consume the prompt, mirroring +// the pattern used by git and ssh. +func (cmd serviceCmd) confirm(msg string) (bool, error) { + if !isatty.IsTerminal(os.Stdin.Fd()) { + tty, err := cmd.openTTY()() + if err == nil { + defer tty.Close() + return cmd.Confirm(format.NewReader(tty), msg) + } + } + return cmd.Confirm(cmd.Reader, msg) +} + +// getCredentials fetches the connection secret for the given resource and +// returns the first username/password pair found. +func getCredentials(ctx context.Context, client *api.Client, mg resource.Managed) (string, string, error) { + secret, err := get.ConnectionSecretMap(ctx, client, mg) + if err != nil { + return "", "", fmt.Errorf("getting connection secret: %w", err) + } + + for user, pw := range secret { + return user, string(pw), nil + } + + return "", "", fmt.Errorf("connection secret %q contains no credentials", mg.GetWriteConnectionSecretToReference().Name) +} + +// ipCoveredByCIDRs reports whether ip is contained in any of the given CIDRs. +func ipCoveredByCIDRs(ip netip.Addr, cidrs []meta.IPv4CIDR) *netip.Prefix { + for _, cidr := range cidrs { + p, err := netip.ParsePrefix(string(cidr)) + if err != nil { + continue + } + if p.Contains(ip) { + return &p + } + } + + return nil +} + +// cidrsPresent reports whether all of want are present in current. +func cidrsPresent(current []meta.IPv4CIDR, want []meta.IPv4CIDR) bool { + set := make(map[meta.IPv4CIDR]struct{}, len(current)) + for _, c := range current { + set[c] = struct{}{} + } + for _, w := range want { + if _, ok := set[w]; !ok { + return false + } + } + return true +} + +// appendMissing appends any CIDRs from add that are not already in current. +func appendMissing(current []meta.IPv4CIDR, add []meta.IPv4CIDR) []meta.IPv4CIDR { + set := make(map[meta.IPv4CIDR]struct{}, len(current)) + for _, c := range current { + set[c] = struct{}{} + } + result := append([]meta.IPv4CIDR(nil), current...) + for _, a := range add { + if _, ok := set[a]; !ok { + result = append(result, a) + } + } + return result +} + +// dialTCP opens a single TCP connection to endpoint, respecting ctx for +// cancellation and deadline. +func dialTCP(ctx context.Context, endpoint string) error { + var d net.Dialer + conn, err := d.DialContext(ctx, "tcp", endpoint) + if err != nil { + return err + } + _ = conn.Close() + return nil +} + +// quickDial attempts a single TCP connection with a short timeout. +// Returns true when the endpoint is immediately reachable. +func quickDial(ctx context.Context, endpoint string) bool { + ctx, cancel := context.WithTimeout(ctx, 2*time.Second) + defer cancel() + return dialTCP(ctx, endpoint) == nil +} + +// writeCACert decodes a base64-encoded PEM CA certificate and writes it to a +// temporary file, returning the file path along with a cleanup function. +func writeCACert(caCert string) (path string, cleanup func(), err error) { + if caCert == "" { + return "", func() {}, nil + } + + f, err := os.CreateTemp("", "nctl-ca-*.pem") + if err != nil { + return "", func() {}, fmt.Errorf("creating CA cert temp file: %w", err) + } + + if err := get.WriteBase64(f, caCert); err != nil { + _ = f.Close() + _ = os.Remove(f.Name()) + return "", func() {}, fmt.Errorf("writing CA cert: %w", err) + } + + if err := f.Close(); err != nil { + _ = os.Remove(f.Name()) + return "", func() {}, fmt.Errorf("closing CA cert temp file: %w", err) + } + + path = f.Name() + return path, func() { _ = os.Remove(path) }, nil +} diff --git a/exec/cmd_test.go b/exec/cmd_test.go new file mode 100644 index 00000000..ced9d7d6 --- /dev/null +++ b/exec/cmd_test.go @@ -0,0 +1,75 @@ +package exec + +import ( + "bytes" + "context" + "fmt" + "io" + "strings" + "time" + + meta "github.com/ninech/apis/meta/v1alpha1" + "github.com/ninech/nctl/internal/format" + corev1 "k8s.io/api/core/v1" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +// capturingCmd records the CLI name and args passed to runCommand. +type capturingCmd struct { + name string + args []string +} + +// testSecret creates a corev1.Secret with a single username→password entry. +func testSecret(name, namespace, user, password string) *corev1.Secret { + return &corev1.Secret{ + ObjectMeta: metav1.ObjectMeta{ + Name: name, + Namespace: namespace, + }, + Data: map[string][]byte{ + user: []byte(password), + }, + } +} + +// testDatabaseCmd returns a capturingCmd and a databaseCmd wired with no-op +// writer/reader and test-friendly function fields. +// When cidrs is non-nil those CIDRs are used; when nil the IP detection is +// triggered only for instance resources (which is safe to use in tests if the +// connector returns nil from AllowedCIDRs). +func testDatabaseCmd(name string, cidrs *[]meta.IPv4CIDR) (*capturingCmd, serviceCmd) { + return testDatabaseCmdConfirmed(name, cidrs, false) +} + +// testDatabaseCmdConfirmed is like testDatabaseCmd but pre-seeds the reader +// with "y\n" so that confirmation prompts are auto-accepted. +func testDatabaseCmdConfirmed(name string, cidrs *[]meta.IPv4CIDR, confirmed bool) (*capturingCmd, serviceCmd) { + var reader io.Reader = &bytes.Buffer{} + if confirmed { + reader = strings.NewReader("y\n") + } + cap := &capturingCmd{} + cmd := serviceCmd{ + resourceCmd: resourceCmd{Name: name}, + Writer: format.NewWriter(&bytes.Buffer{}), + Reader: format.NewReader(reader), + AllowedCidrs: cidrs, + WaitTimeout: 0, + runCommand: func(_ context.Context, n string, args []string) error { + cap.name = n + cap.args = args + return nil + }, + lookPath: func(file string) (string, error) { + return "/usr/bin/" + file, nil + }, + waitForConnectivity: func(_ context.Context, _ format.Writer, _ string, _ time.Duration) error { + return nil + }, + openTTYForConfirm: func() (io.ReadCloser, error) { + return nil, fmt.Errorf("no tty in tests") + }, + } + return cap, cmd +} diff --git a/exec/exec.go b/exec/exec.go index f37d6a41..b1d82df2 100644 --- a/exec/exec.go +++ b/exec/exec.go @@ -1,10 +1,16 @@ // Package exec provides the implementation for the exec command. package exec +// Cmd holds all exec sub-commands. type Cmd struct { - Application applicationCmd `cmd:"" group:"deplo.io" aliases:"app,application" name:"application" help:"Execute a command or shell in a deplo.io application."` + Application applicationCmd `cmd:"" group:"deplo.io" aliases:"app,application" name:"application" help:"Execute a command or shell in a deplo.io application."` + Postgres postgresCmd `cmd:"" group:"storage.nine.ch" name:"postgres" help:"Connect to a PostgreSQL instance."` + PostgresDatabase postgresDatabaseCmd `cmd:"" group:"storage.nine.ch" name:"postgresdatabase" help:"Connect to a PostgreSQL database."` + MySQL mysqlCmd `cmd:"" group:"storage.nine.ch" name:"mysql" help:"Connect to a MySQL instance."` + MySQLDatabase mysqlDatabaseCmd `cmd:"" group:"storage.nine.ch" name:"mysqldatabase" help:"Connect to a MySQL database."` + KeyValueStore kvsCmd `cmd:"" group:"storage.nine.ch" name:"keyvaluestore" aliases:"kvs" help:"Connect to a KeyValueStore instance."` } type resourceCmd struct { - Name string `arg:"" completion-predictor:"resource_name" help:"Name of the application to exec command/shell in." required:""` + Name string `arg:"" completion-predictor:"resource_name" help:"Name of the resource." required:""` } diff --git a/exec/keyvaluestore.go b/exec/keyvaluestore.go new file mode 100644 index 00000000..c3b0dbfa --- /dev/null +++ b/exec/keyvaluestore.go @@ -0,0 +1,94 @@ +package exec + +import ( + "context" + "fmt" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/cli" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const kvsPort = "6380" + +type kvsCmd struct { + serviceCmd +} + +// Help displays usage examples for the keyvaluestore exec command. +func (cmd kvsCmd) Help() string { + return `Examples: + # Connect to a KeyValueStore instance interactively + nctl exec keyvaluestore mykvs + + # Pass extra flags to redis-cli (after --) + nctl exec keyvaluestore mykvs -- --no-auth-warning +` +} + +func (cmd *kvsCmd) Run(ctx context.Context, client *api.Client) error { + kvs := &storage.KeyValueStore{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), kvs); err != nil { + return fmt.Errorf("getting keyvaluestore %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, kvs, kvsConnector{}, cmd.serviceCmd) +} + +// kvsConnector implements ServiceConnector for storage.KeyValueStore instances. +type kvsConnector struct{} + +func (kvsConnector) Command() string { return "redis-cli" } + +func (kvsConnector) Endpoint(kvs *storage.KeyValueStore) string { + if kvs.Status.AtProvider.FQDN == "" { + return "" + } + return kvs.Status.AtProvider.FQDN + ":" + kvsPort +} + +func (kvsConnector) AllowedCIDRs(kvs *storage.KeyValueStore) []meta.IPv4CIDR { + return kvs.Spec.ForProvider.AllowedCIDRs +} + +func (kvsConnector) Update(ctx context.Context, client *api.Client, kvs *storage.KeyValueStore, cidrs []meta.IPv4CIDR) error { + current := &storage.KeyValueStore{} + if err := client.Get(ctx, api.ObjectName(kvs), current); err != nil { + return err + } + + if current.Spec.ForProvider.PublicNetworkingEnabled != nil && !*current.Spec.ForProvider.PublicNetworkingEnabled { + return cli.ErrorWithContext(fmt.Errorf("public networking is disabled for keyvaluestore %q", kvs.GetName())). + WithSuggestions( + fmt.Sprintf("Enable it with: nctl update keyvaluestore %s --public-networking", kvs.GetName()), + ) + } + + current.Spec.ForProvider.AllowedCIDRs = cidrs + return client.Update(ctx, current) +} + +func (kvsConnector) Args(kvs *storage.KeyValueStore, user, pw string) ([]string, func(), error) { + caPath, cleanup, err := writeCACert(kvs.Status.AtProvider.CACert) + if err != nil { + return nil, func() {}, err + } + + args := []string{ + "-h", kvs.Status.AtProvider.FQDN, + "-p", kvsPort, + "--tls", + "-a", pw, + } + if caPath != "" { + args = append(args, "--cacert", caPath) + } + + return args, cleanup, nil +} diff --git a/exec/keyvaluestore_test.go b/exec/keyvaluestore_test.go new file mode 100644 index 00000000..ce64762a --- /dev/null +++ b/exec/keyvaluestore_test.go @@ -0,0 +1,144 @@ +package exec + +import ( + "context" + "strings" + "testing" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestKVSCmd(t *testing.T) { + t.Parallel() + + const ( + kvsName = "mykvs" + kvsFQDN = "mykvs.example.com" + kvsToken = "supersecrettoken" + ) + + cidr := []meta.IPv4CIDR{"203.0.113.5/32"} + pubNet := true + + ready := test.KeyValueStore(kvsName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = kvsFQDN + ready.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + ready.Spec.ForProvider.PublicNetworkingEnabled = &pubNet + + pubNetFalse := false + pubNetDisabled := test.KeyValueStore("no-public", test.DefaultProject, "nine-es34") + pubNetDisabled.Status.AtProvider.FQDN = "no-public.example.com" + pubNetDisabled.Spec.ForProvider.PublicNetworkingEnabled = &pubNetFalse + pubNetDisabled.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{} + + notReady := test.KeyValueStore("notready", test.DefaultProject, "nine-es34") + + // KVS secret: single key with auth token as value. + secret := testSecret(kvsName, test.DefaultProject, "token", kvsToken) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", &cidr) + _, notReadyCmd := testDatabaseCmd("notready", &cidr) + alreadyCap, alreadyPresentCmd := testDatabaseCmd(kvsName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + _, newCidrCmd := testDatabaseCmdConfirmed(kvsName, &cidr, true) + _, pubNetDisabledCmd := testDatabaseCmdConfirmed("no-public", &cidr, true) + tokenCap, tokenCmd := testDatabaseCmd(kvsName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + + tests := []struct { + name string + cmd kvsCmd + cap *capturingCmd + wantErr bool + errContains string + wantUpdate bool + checkArgs func(t *testing.T, args []string) + }{ + { + name: "resource not found", + cmd: kvsCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: kvsCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "cidr already present skips update", + cmd: kvsCmd{serviceCmd: alreadyPresentCmd}, + cap: alreadyCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + if !strings.Contains(strings.Join(args, " "), kvsFQDN) { + t.Errorf("expected FQDN %q in args %v", kvsFQDN, args) + } + }, + }, + { + name: "new cidr triggers update", + cmd: kvsCmd{serviceCmd: newCidrCmd}, + wantUpdate: true, + }, + { + name: "public networking disabled returns error", + cmd: kvsCmd{serviceCmd: pubNetDisabledCmd}, + wantErr: true, + errContains: "networking is disabled", + }, + { + name: "token appears in args", + cmd: kvsCmd{serviceCmd: tokenCmd}, + cap: tokenCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + if !strings.Contains(strings.Join(args, " "), kvsToken) { + t.Errorf("expected token in args %v", args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, pubNetDisabled, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if tc.wantUpdate && !updateCalled { + t.Error("expected Update to be called for CIDR addition") + } + if !tc.wantErr && tc.checkArgs != nil { + tc.checkArgs(t, tc.cap.args) + } + if tc.wantUpdate { + kvs := &storage.KeyValueStore{} + if err := apiClient.Get(t.Context(), api.ObjectName(ready), kvs); err != nil { + t.Fatalf("getting kvs: %v", err) + } + if !cidrsPresent(kvs.Spec.ForProvider.AllowedCIDRs, cidr) { + t.Errorf("expected CIDR %v to be added, got %v", cidr, kvs.Spec.ForProvider.AllowedCIDRs) + } + } + }) + } +} diff --git a/exec/mysql.go b/exec/mysql.go new file mode 100644 index 00000000..42ab1482 --- /dev/null +++ b/exec/mysql.go @@ -0,0 +1,101 @@ +package exec + +import ( + "context" + "fmt" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + mysqlPort = "3306" + mysqlCommand = "mysql" +) + +type mysqlCmd struct { + serviceCmd +} + +// Help displays usage examples for the mysql exec command. +func (cmd mysqlCmd) Help() string { + return `Examples: + # Connect to a MySQL instance interactively + nctl exec mysql myinstance + + # Import a SQL dump via pipe + cat dump.sql | nctl exec mysql myinstance + + # Pass extra flags to mysql (after --) + nctl exec mysql myinstance -- --batch +` +} + +func (cmd *mysqlCmd) Run(ctx context.Context, client *api.Client) error { + my := &storage.MySQL{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), my); err != nil { + return fmt.Errorf("getting mysql %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, my, mysqlConnector{}, cmd.serviceCmd) +} + +// mysqlConnector implements cmdExecutor for storage.MySQL instances. +type mysqlConnector struct{} + +func (mysqlConnector) Command() string { return mysqlCommand } + +func (mysqlConnector) Endpoint(my *storage.MySQL) string { + if my.Status.AtProvider.FQDN == "" { + return "" + } + return my.Status.AtProvider.FQDN + ":" + mysqlPort +} + +func (mysqlConnector) AllowedCIDRs(my *storage.MySQL) []meta.IPv4CIDR { + return my.Spec.ForProvider.AllowedCIDRs +} + +func (mysqlConnector) Update(ctx context.Context, client *api.Client, my *storage.MySQL, cidrs []meta.IPv4CIDR) error { + current := &storage.MySQL{} + if err := client.Get(ctx, api.ObjectName(my), current); err != nil { + return err + } + current.Spec.ForProvider.AllowedCIDRs = cidrs + return client.Update(ctx, current) +} + +func (mysqlConnector) Args(my *storage.MySQL, user, pw string) ([]string, func(), error) { + return mysqlArgs(my.Status.AtProvider.FQDN, "", my.Status.AtProvider.CACert, user, pw) +} + +// mysqlArgs returns the mysql CLI arguments for connecting to a MySQL instance. +// dbName is appended as a positional argument when non-empty. +func mysqlArgs(fqdn, dbName, caCertBase64, user, pw string) ([]string, func(), error) { + caPath, cleanup, err := writeCACert(caCertBase64) + if err != nil { + return nil, func() {}, err + } + + args := []string{ + "-h", fqdn, + "-P", mysqlPort, + "-u", user, + "-p" + pw, + "--ssl-mode=REQUIRED", + } + if caPath != "" { + args = append(args, "--ssl-ca="+caPath) + } + if dbName != "" { + args = append(args, dbName) + } + + return args, cleanup, nil +} diff --git a/exec/mysql_test.go b/exec/mysql_test.go new file mode 100644 index 00000000..38ca3cfb --- /dev/null +++ b/exec/mysql_test.go @@ -0,0 +1,134 @@ +package exec + +import ( + "context" + "strings" + "testing" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestMySQLCmd(t *testing.T) { + t.Parallel() + + const ( + myName = "mymy" + myFQDN = "mymy.example.com" + myUser = "root" + myPass = "rootpass" + ) + + cidr := []meta.IPv4CIDR{"203.0.113.5/32"} + + ready := test.MySQL(myName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = myFQDN + ready.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + + notReady := test.MySQL("notready", test.DefaultProject, "nine-es34") + + secret := testSecret(myName, test.DefaultProject, myUser, myPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", &cidr) + _, notReadyCmd := testDatabaseCmd("notready", &cidr) + alreadyCap, alreadyPresentCmd := testDatabaseCmd(myName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + _, newCidrCmd := testDatabaseCmdConfirmed(myName, &cidr, true) + credsCap, credsCmd := testDatabaseCmd(myName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + + tests := []struct { + name string + cmd mysqlCmd + cap *capturingCmd + wantErr bool + errContains string + wantUpdate bool + checkArgs func(t *testing.T, args []string) + }{ + { + name: "resource not found", + cmd: mysqlCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: mysqlCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "cidr already present skips update", + cmd: mysqlCmd{serviceCmd: alreadyPresentCmd}, + cap: alreadyCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + if !strings.Contains(strings.Join(args, " "), myFQDN) { + t.Errorf("expected FQDN %q in args %v", myFQDN, args) + } + }, + }, + { + name: "new cidr triggers update", + cmd: mysqlCmd{serviceCmd: newCidrCmd}, + wantUpdate: true, + }, + { + name: "credentials appear in args", + cmd: mysqlCmd{serviceCmd: credsCmd}, + cap: credsCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + joined := strings.Join(args, " ") + if !strings.Contains(joined, myUser) { + t.Errorf("expected user %q in args %v", myUser, args) + } + if !strings.Contains(joined, myPass) { + t.Errorf("expected password in args %v", args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if tc.wantUpdate && !updateCalled { + t.Error("expected Update to be called for CIDR addition") + } + if !tc.wantErr && tc.checkArgs != nil { + tc.checkArgs(t, tc.cap.args) + } + if tc.wantUpdate { + my := &storage.MySQL{} + if err := apiClient.Get(t.Context(), api.ObjectName(ready), my); err != nil { + t.Fatalf("getting mysql: %v", err) + } + if !cidrsPresent(my.Spec.ForProvider.AllowedCIDRs, cidr) { + t.Errorf("expected CIDR %v to be added, got %v", cidr, my.Spec.ForProvider.AllowedCIDRs) + } + } + }) + } +} + diff --git a/exec/mysqldatabase.go b/exec/mysqldatabase.go new file mode 100644 index 00000000..76e84e3d --- /dev/null +++ b/exec/mysqldatabase.go @@ -0,0 +1,64 @@ +package exec + +import ( + "context" + "fmt" + + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type mysqlDatabaseCmd struct { + serviceCmd +} + +// Help displays usage examples for the mysqldatabase exec command. +func (cmd mysqlDatabaseCmd) Help() string { + return `Examples: + # Connect to a MySQL database interactively + nctl exec mysqldatabase mydb + + # Import a SQL dump via pipe + cat dump.sql | nctl exec mysqldatabase mydb +` +} + +// Run connects to the named MySQLDatabase resource. +func (cmd *mysqlDatabaseCmd) Run(ctx context.Context, client *api.Client) error { + db := &storage.MySQLDatabase{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), db); err != nil { + return fmt.Errorf("getting mysqldatabase %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, db, mysqlDatabaseConnector{}, cmd.serviceCmd) +} + +// mysqlDatabaseConnector implements cmdExecutor for storage.MySQLDatabase resources. +// It does not implement accessManager because the parent MySQL instance manages CIDRs. +type mysqlDatabaseConnector struct{} + +// Command returns the CLI binary name for connecting to a MySQL database. +func (mysqlDatabaseConnector) Command() string { return mysqlCommand } + +// Endpoint returns the host:port for the TCP connectivity check. +func (mysqlDatabaseConnector) Endpoint(db *storage.MySQLDatabase) string { + if db.Status.AtProvider.FQDN == "" { + return "" + } + return db.Status.AtProvider.FQDN + ":" + mysqlPort +} + +// Args returns the mysql CLI arguments for connecting to a MySQLDatabase. +// dbName is appended as a positional argument when non-empty. +func (mysqlDatabaseConnector) Args(db *storage.MySQLDatabase, user, pw string) ([]string, func(), error) { + dbName := db.Status.AtProvider.Name + if dbName == "" { + dbName = user + } + return mysqlArgs(db.Status.AtProvider.FQDN, dbName, db.Status.AtProvider.CACert, user, pw) +} diff --git a/exec/mysqldatabase_test.go b/exec/mysqldatabase_test.go new file mode 100644 index 00000000..ede065ed --- /dev/null +++ b/exec/mysqldatabase_test.go @@ -0,0 +1,100 @@ +package exec + +import ( + "context" + "strings" + "testing" + + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestMySQLDatabaseCmd(t *testing.T) { + t.Parallel() + + const ( + myDBName = "mydb" + myDBFQDN = "mydb.example.com" + myDBUser = "mydb" + myDBPass = "mydbpass" + ) + + ready := test.MySQLDatabase(myDBName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = myDBFQDN + ready.Status.AtProvider.Name = myDBName + + notReady := test.MySQLDatabase("notready", test.DefaultProject, "nine-es34") + + secret := testSecret(myDBName, test.DefaultProject, myDBUser, myDBPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", nil) + _, notReadyCmd := testDatabaseCmd("notready", nil) + connectCap, connectCmd := testDatabaseCmd(myDBName, nil) + + tests := []struct { + name string + cmd mysqlDatabaseCmd + cap *capturingCmd + wantErr bool + errContains string + checkArgs func(t *testing.T, args []string) + }{ + { + name: "resource not found", + cmd: mysqlDatabaseCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: mysqlDatabaseCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "connects without cidr management", + cmd: mysqlDatabaseCmd{serviceCmd: connectCmd}, + cap: connectCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + joined := strings.Join(args, " ") + if !strings.Contains(joined, myDBFQDN) { + t.Errorf("expected FQDN %q in args %v", myDBFQDN, args) + } + if !strings.Contains(joined, myDBName) { + t.Errorf("expected dbname %q in args %v", myDBName, args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if updateCalled { + t.Error("Update must not be called for child database resources") + } + if !tc.wantErr && tc.checkArgs != nil { + tc.checkArgs(t, tc.cap.args) + } + }) + } +} diff --git a/exec/postgres.go b/exec/postgres.go new file mode 100644 index 00000000..883e3139 --- /dev/null +++ b/exec/postgres.go @@ -0,0 +1,106 @@ +package exec + +import ( + "context" + "fmt" + "net" + "net/url" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/get" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +const ( + postgresPort = "5432" + postgresCommand = "psql" +) + +type postgresCmd struct { + serviceCmd +} + +// Help displays usage examples for the postgres exec command. +func (cmd postgresCmd) Help() string { + return `Examples: + # Connect to a PostgreSQL instance interactively + nctl exec postgres myinstance + + # Import a SQL dump via pipe + cat dump.sql | nctl exec postgres myinstance + + # Pass extra flags to psql (after --) + nctl exec postgres myinstance -- --no-pager +` +} + +func (cmd *postgresCmd) Run(ctx context.Context, client *api.Client) error { + pg := &storage.Postgres{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), pg); err != nil { + return fmt.Errorf("getting postgres %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, pg, postgresConnector{}, cmd.serviceCmd) +} + +// postgresConnector implements cmdExecutor for storage.Postgres instances. +type postgresConnector struct{} + +func (postgresConnector) Command() string { return postgresCommand } + +func (postgresConnector) Endpoint(pg *storage.Postgres) string { + if pg.Status.AtProvider.FQDN == "" { + return "" + } + return net.JoinHostPort(pg.Status.AtProvider.FQDN, postgresPort) +} + +func (postgresConnector) AllowedCIDRs(pg *storage.Postgres) []meta.IPv4CIDR { + return pg.Spec.ForProvider.AllowedCIDRs +} + +func (postgresConnector) Update(ctx context.Context, client *api.Client, pg *storage.Postgres, cidrs []meta.IPv4CIDR) error { + current := &storage.Postgres{} + if err := client.Get(ctx, api.ObjectName(pg), current); err != nil { + return err + } + current.Spec.ForProvider.AllowedCIDRs = cidrs + return client.Update(ctx, current) +} + +func (postgresConnector) Args(pg *storage.Postgres, user, pw string) ([]string, func(), error) { + return psqlArgs(pg.Status.AtProvider.FQDN, "postgres", pg.Status.AtProvider.CACert, user, pw) +} + +// psqlArgs returns the psql arguments for a given database. +func psqlArgs(fqdn, name, caCertBase64, user, pw string) ([]string, func(), error) { + caPath, cleanup, err := writeCACert(caCertBase64) + if err != nil { + return nil, func() {}, err + } + + dbName := name + if dbName == "" { + dbName = user + } + + conn := postgresConnectionStringCA(fqdn, user, dbName, []byte(pw), caPath) + return []string{conn.String()}, cleanup, nil +} + +// postgresConnectionStringCA returns a PostgreSQL connection string with CA certificate verification enabled. +func postgresConnectionStringCA(fqdn string, user string, db string, pw []byte, caPath string) *url.URL { + conn := get.PostgresConnectionString(fqdn, user, db, pw) + q := conn.Query() + q.Set("sslrootcert", caPath) + q.Set("sslmode", "verify-ca") + conn.RawQuery = q.Encode() + + return conn +} diff --git a/exec/postgres_test.go b/exec/postgres_test.go new file mode 100644 index 00000000..6629da3b --- /dev/null +++ b/exec/postgres_test.go @@ -0,0 +1,146 @@ +package exec + +import ( + "context" + "strings" + "testing" + + meta "github.com/ninech/apis/meta/v1alpha1" + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestPostgresCmd(t *testing.T) { + t.Parallel() + + const ( + pgName = "mypg" + location = "nine-es34" + fqdn = "mypg.example.com" + pgUser = "admin" + pgPass = "secret" + ) + + cidr := []meta.IPv4CIDR{"203.0.113.5/32"} + + ready := test.Postgres(pgName, test.DefaultProject, location) + ready.Status.AtProvider.FQDN = fqdn + ready.Spec.ForProvider.AllowedCIDRs = []meta.IPv4CIDR{"10.0.0.1/32"} + + notReady := test.Postgres("notready", test.DefaultProject, location) + + secret := testSecret(pgName, test.DefaultProject, pgUser, pgPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", &cidr) + _, notReadyCmd := testDatabaseCmd("notready", &cidr) + alreadyCap, alreadyPresentCmd := testDatabaseCmd(pgName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + newCidrCap, newCidrCmd := testDatabaseCmdConfirmed(pgName, &cidr, true) + credsCap, credsCmd := testDatabaseCmd(pgName, &[]meta.IPv4CIDR{"10.0.0.1/32"}) + + tests := []struct { + name string + cmd postgresCmd + cap *capturingCmd + wantErr bool + errContains string + wantUpdate bool + checkArgs func(t *testing.T, args []string) + }{ + { + name: "resource not found", + cmd: postgresCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: postgresCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "cidr already present skips update", + cmd: postgresCmd{serviceCmd: alreadyPresentCmd}, + cap: alreadyCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + if !strings.Contains(strings.Join(args, " "), fqdn) { + t.Errorf("expected FQDN %q in args %v", fqdn, args) + } + }, + }, + { + name: "new cidr triggers update", + cmd: postgresCmd{serviceCmd: newCidrCmd}, + cap: newCidrCap, + wantUpdate: true, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + if !strings.Contains(strings.Join(args, " "), fqdn) { + t.Errorf("expected FQDN %q in args %v", fqdn, args) + } + }, + }, + { + name: "credentials appear in args", + cmd: postgresCmd{serviceCmd: credsCmd}, + cap: credsCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + joined := strings.Join(args, " ") + if !strings.Contains(joined, pgUser) { + t.Errorf("expected user %q in args %v", pgUser, args) + } + if !strings.Contains(joined, pgPass) { + t.Errorf("expected password in args %v", args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if tc.wantUpdate && !updateCalled { + t.Error("expected Update to be called for CIDR addition") + } + if !tc.wantUpdate && !tc.wantErr && updateCalled { + t.Error("unexpected Update call when CIDR already present") + } + if !tc.wantErr && tc.checkArgs != nil { + tc.checkArgs(t, tc.cap.args) + } + + if tc.wantUpdate { + pg := &storage.Postgres{} + if err := apiClient.Get(t.Context(), api.ObjectName(ready), pg); err != nil { + t.Fatalf("getting postgres: %v", err) + } + if !cidrsPresent(pg.Spec.ForProvider.AllowedCIDRs, cidr) { + t.Errorf("expected CIDR %v to be added, got %v", cidr, pg.Spec.ForProvider.AllowedCIDRs) + } + } + }) + } +} + diff --git a/exec/postgresdatabase.go b/exec/postgresdatabase.go new file mode 100644 index 00000000..272df306 --- /dev/null +++ b/exec/postgresdatabase.go @@ -0,0 +1,60 @@ +package exec + +import ( + "context" + "fmt" + "net" + + storage "github.com/ninech/apis/storage/v1alpha1" + "github.com/ninech/nctl/api" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" +) + +type postgresDatabaseCmd struct { + serviceCmd +} + +// Help displays usage examples for the postgresdatabase exec command. +func (cmd postgresDatabaseCmd) Help() string { + return `Examples: + # Connect to a PostgreSQL database interactively + nctl exec postgresdatabase mydb + + # Import a SQL dump via pipe + cat dump.sql | nctl exec postgresdatabase mydb +` +} + +// Run connects to the named PostgresDatabase resource. +func (cmd *postgresDatabaseCmd) Run(ctx context.Context, client *api.Client) error { + db := &storage.PostgresDatabase{ + ObjectMeta: metav1.ObjectMeta{ + Name: cmd.Name, + Namespace: client.Project, + }, + } + if err := client.Get(ctx, client.Name(cmd.Name), db); err != nil { + return fmt.Errorf("getting postgresdatabase %q: %w", cmd.Name, err) + } + return connectAndExec(ctx, client, db, postgresDatabaseConnector{}, cmd.serviceCmd) +} + +// postgresDatabaseConnector implements cmdExecutor for storage.PostgresDatabase resources. +// It does not implement accessManager because the parent Postgres instance manages CIDRs. +type postgresDatabaseConnector struct{} + +// Command returns the CLI binary name for connecting to a PostgreSQL database. +func (postgresDatabaseConnector) Command() string { return postgresCommand } + +// Endpoint returns the host:port for the TCP connectivity check. +func (postgresDatabaseConnector) Endpoint(db *storage.PostgresDatabase) string { + if db.Status.AtProvider.FQDN == "" { + return "" + } + return net.JoinHostPort(db.Status.AtProvider.FQDN, postgresPort) +} + +// Args returns the psql CLI arguments for connecting to a PostgresDatabase. +func (postgresDatabaseConnector) Args(db *storage.PostgresDatabase, user, pw string) ([]string, func(), error) { + return psqlArgs(db.Status.AtProvider.FQDN, db.Status.AtProvider.Name, db.Status.AtProvider.CACert, user, pw) +} diff --git a/exec/postgresdatabase_test.go b/exec/postgresdatabase_test.go new file mode 100644 index 00000000..c6e352ad --- /dev/null +++ b/exec/postgresdatabase_test.go @@ -0,0 +1,100 @@ +package exec + +import ( + "context" + "strings" + "testing" + + "github.com/ninech/nctl/internal/test" + runtimeclient "sigs.k8s.io/controller-runtime/pkg/client" + "sigs.k8s.io/controller-runtime/pkg/client/interceptor" +) + +func TestPostgresDatabaseCmd(t *testing.T) { + t.Parallel() + + const ( + dbName = "mydb" + dbFQDN = "mydb.example.com" + dbUser = "mydb" + dbPass = "dbsecret" + ) + + ready := test.PostgresDatabase(dbName, test.DefaultProject, "nine-es34") + ready.Status.AtProvider.FQDN = dbFQDN + ready.Status.AtProvider.Name = dbName + + notReady := test.PostgresDatabase("notready", test.DefaultProject, "nine-es34") + + secret := testSecret(dbName, test.DefaultProject, dbUser, dbPass) + + _, notFoundCmd := testDatabaseCmd("doesnotexist", nil) + _, notReadyCmd := testDatabaseCmd("notready", nil) + connectCap, connectCmd := testDatabaseCmd(dbName, nil) + + tests := []struct { + name string + cmd postgresDatabaseCmd + cap *capturingCmd + wantErr bool + errContains string + checkArgs func(t *testing.T, args []string) + }{ + { + name: "resource not found", + cmd: postgresDatabaseCmd{serviceCmd: notFoundCmd}, + wantErr: true, + }, + { + name: "resource not ready", + cmd: postgresDatabaseCmd{serviceCmd: notReadyCmd}, + wantErr: true, + errContains: "not ready", + }, + { + name: "connects without cidr management", + cmd: postgresDatabaseCmd{serviceCmd: connectCmd}, + cap: connectCap, + checkArgs: func(t *testing.T, args []string) { + t.Helper() + joined := strings.Join(args, " ") + if !strings.Contains(joined, dbFQDN) { + t.Errorf("expected FQDN %q in args %v", dbFQDN, args) + } + if !strings.Contains(joined, dbName) { + t.Errorf("expected dbname %q in args %v", dbName, args) + } + }, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + updateCalled := false + apiClient := test.SetupClient(t, + test.WithObjects(ready, notReady, secret), + test.WithInterceptorFuncs(interceptor.Funcs{ + Update: func(ctx context.Context, c runtimeclient.WithWatch, obj runtimeclient.Object, opts ...runtimeclient.UpdateOption) error { + updateCalled = true + return c.Update(ctx, obj, opts...) + }, + }), + ) + + err := tc.cmd.Run(t.Context(), apiClient) + + if (err != nil) != tc.wantErr { + t.Fatalf("Run() error = %v, wantErr %v", err, tc.wantErr) + } + if tc.errContains != "" && (err == nil || !strings.Contains(err.Error(), tc.errContains)) { + t.Errorf("expected error containing %q, got %v", tc.errContains, err) + } + if updateCalled { + t.Error("Update must not be called for child database resources") + } + if !tc.wantErr && tc.checkArgs != nil { + tc.checkArgs(t, tc.cap.args) + } + }) + } +} diff --git a/get/apiserviceaccount.go b/get/apiserviceaccount.go index e107c287..7673a0df 100644 --- a/get/apiserviceaccount.go +++ b/get/apiserviceaccount.go @@ -144,7 +144,7 @@ func (cmd *apiServiceAccountsCmd) printSecret( key string, out *output, ) error { - data, err := getConnectionSecret(ctx, client, key, sa) + data, err := connectionSecret(ctx, client, key, sa) if err != nil { return err } diff --git a/get/bucketuser.go b/get/bucketuser.go index 681fbda7..79b49975 100644 --- a/get/bucketuser.go +++ b/get/bucketuser.go @@ -107,7 +107,7 @@ func (cmd *bucketUserCmd) printSecret( key string, out *output, ) error { - data, err := getConnectionSecret(ctx, client, key, user) + data, err := connectionSecret(ctx, client, key, user) if err != nil { return err } diff --git a/get/database.go b/get/database.go index 67d0fd6c..30cb35a6 100644 --- a/get/database.go +++ b/get/database.go @@ -47,7 +47,7 @@ func (cmd *databaseCmd) run(ctx context.Context, client *api.Client, get *Cmd, } if cmd.Name != "" && cmd.PrintConnectionString { - secrets, err := getConnectionSecretMap(ctx, client, databaseResources.GetItems()[0]) + secrets, err := ConnectionSecretMap(ctx, client, databaseResources.GetItems()[0]) if err != nil { return err } @@ -66,7 +66,7 @@ func (cmd *databaseCmd) run(ctx context.Context, client *api.Client, get *Cmd, if err != nil { return err } - return printBase64(&get.Writer, ca) + return WriteBase64(&get.Writer, ca) } switch get.Format { diff --git a/get/get.go b/get/get.go index 2bf8d37e..ee3edc95 100644 --- a/get/get.go +++ b/get/get.go @@ -178,7 +178,7 @@ func (out *output) notFound(kind, project string) error { return err } -func getConnectionSecretMap(ctx context.Context, client *api.Client, mg resource.Managed) (map[string][]byte, error) { +func ConnectionSecretMap(ctx context.Context, client *api.Client, mg resource.Managed) (map[string][]byte, error) { secret, err := client.GetConnectionSecret(ctx, mg) if err != nil { return nil, err @@ -187,8 +187,8 @@ func getConnectionSecretMap(ctx context.Context, client *api.Client, mg resource return secret.Data, nil } -func getConnectionSecret(ctx context.Context, client *api.Client, key string, mg resource.Managed) (string, error) { - secrets, err := getConnectionSecretMap(ctx, client, mg) +func connectionSecret(ctx context.Context, client *api.Client, key string, mg resource.Managed) (string, error) { + secrets, err := ConnectionSecretMap(ctx, client, mg) if err != nil { return "", fmt.Errorf("unable to get connection secret: %w", err) } @@ -208,7 +208,7 @@ func (cmd *resourceCmd) printSecret( out *output, field func(string, string) string, ) error { - secrets, err := getConnectionSecretMap(ctx, client, mg) + secrets, err := ConnectionSecretMap(ctx, client, mg) if err != nil { return err } @@ -227,7 +227,7 @@ func (cmd *resourceCmd) printCredentials( out *output, filter func(key string) bool, ) error { - data, err := getConnectionSecretMap(ctx, client, mg) + data, err := ConnectionSecretMap(ctx, client, mg) if err != nil { return err } @@ -264,7 +264,7 @@ func (cmd *resourceCmd) printCredentials( return nil } -func printBase64(out io.Writer, s string) error { +func WriteBase64(out io.Writer, s string) error { s = strings.TrimSpace(s) if s == "" { return nil diff --git a/get/keyvaluestore.go b/get/keyvaluestore.go index a541a154..59bd1f22 100644 --- a/get/keyvaluestore.go +++ b/get/keyvaluestore.go @@ -37,7 +37,7 @@ func (cmd *keyValueStoreCmd) print(ctx context.Context, client *api.Client, list return cmd.printSecret(ctx, client, &keyValueStoreList.Items[0], out, func(_, pw string) string { return pw }) } if cmd.Name != "" && cmd.PrintCACert { - return printBase64(&out.Writer, keyValueStoreList.Items[0].Status.AtProvider.CACert) + return WriteBase64(&out.Writer, keyValueStoreList.Items[0].Status.AtProvider.CACert) } switch out.Format { diff --git a/get/opensearch.go b/get/opensearch.go index 273a3aa1..1faa6772 100644 --- a/get/opensearch.go +++ b/get/opensearch.go @@ -63,7 +63,7 @@ func (cmd *openSearchCmd) print( } if cmd.Name != "" && cmd.PrintCACert { - return printBase64(&out.Writer, openSearchList.Items[0].Status.AtProvider.CACert) + return WriteBase64(&out.Writer, openSearchList.Items[0].Status.AtProvider.CACert) } if cmd.Name != "" && cmd.PrintSnapshotBucket { diff --git a/get/postgres.go b/get/postgres.go index 7ebe0885..bea1dbab 100644 --- a/get/postgres.go +++ b/get/postgres.go @@ -65,7 +65,7 @@ func (cmd *postgresCmd) connectionString(mg resource.Managed, secrets map[string } for user, pw := range secrets { - return postgresConnectionString(my.Status.AtProvider.FQDN, user, "postgres", pw), nil + return PostgresConnectionString(my.Status.AtProvider.FQDN, user, "postgres", pw).String(), nil } return "", nil diff --git a/get/postgresdatabase.go b/get/postgresdatabase.go index 04c22099..cd0dbc9e 100644 --- a/get/postgresdatabase.go +++ b/get/postgresdatabase.go @@ -66,21 +66,25 @@ func (cmd *postgresDatabaseCmd) connectionString(mg resource.Managed, secrets ma } for user, pw := range secrets { - return postgresConnectionString(my.Status.AtProvider.FQDN, user, user, pw), nil + return PostgresConnectionString(my.Status.AtProvider.FQDN, user, user, pw).String(), nil } return "", nil } -// postgresConnectionString according to the PostgreSQL documentation: +// PostgresConnectionString according to the PostgreSQL documentation: // https://www.postgresql.org/docs/current/libpq-connect.html#LIBPQ-CONNSTRING -func postgresConnectionString(fqdn, user, db string, pw []byte) string { +func PostgresConnectionString(fqdn, user, db string, pw []byte) *url.URL { + q := url.Values{} + q.Set("sslmode", "require") + u := &url.URL{ - Scheme: "postgres", - Host: fqdn, - User: url.UserPassword(user, string(pw)), - Path: db, + Scheme: "postgres", + Host: fqdn, + User: url.UserPassword(user, string(pw)), + Path: db, + RawQuery: q.Encode(), } - return u.String() + return u } diff --git a/internal/ipcheck/client.go b/internal/ipcheck/client.go new file mode 100644 index 00000000..6f92cd17 --- /dev/null +++ b/internal/ipcheck/client.go @@ -0,0 +1,152 @@ +// Package ipcheck provides a client for detecting the caller's public IP address. +package ipcheck + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "net/http" + "net/netip" + "net/url" + "sync" + "time" +) + +// ErrStatus represents a non success status code error. +var ErrStatus = errors.New("status code error") + +const ( + // defaultTimeout is the default HTTP timeout. + defaultTimeout = 5 * time.Second + // defaultURL is the default endpoint to query. + defaultURL = "https://ip-ban-check.nine.ch/" +) + +// defaultClient returns the default Client instance. +var defaultClient = sync.OnceValue(func() *Client { + return New() +}) + +// PublicIP returns the caller's public IP address as reported by the endpoint. +func PublicIP(ctx context.Context) (*Response, error) { + return defaultClient().PublicIP(ctx) +} + +// Client fetches the caller's public IP address from Nine's IP check endpoint. +type Client struct { + // httpClient is the HTTP client to use. If nil, a default client with a 5s timeout is used. + httpClient *http.Client + // userAgent is the value to set in the User-Agent header. + userAgent string + // url is the endpoint to query. Defaults to https://ip-ban-check.nine.ch/. + url *url.URL +} + +// Response is the JSON response from the IP check endpoint. +type Response struct { + Blocked bool `json:"blocked"` + RemoteAddr netip.Addr `json:"remoteAddr"` +} + +// Option is a function that configures a Client. +type Option func(*Client) + +// WithHTTPClient configures the HTTP client to use. +func WithHTTPClient(client *http.Client) Option { + return func(c *Client) { + c.httpClient = client + } +} + +// WithUserAgent configures the User-Agent header to use. +func WithUserAgent(userAgent string) Option { + return func(c *Client) { + c.userAgent = userAgent + } +} + +// WithURL configures the endpoint URL to query. +func WithURL(url *url.URL) Option { + return func(c *Client) { + c.url = url + } +} + +// New creates a new Client with the given options. +func New(options ...Option) *Client { + u, _ := url.Parse(defaultURL) + c := &Client{ + url: u, + httpClient: &http.Client{Timeout: defaultTimeout}, + } + + for _, opt := range options { + opt(c) + } + + return c +} + +// PublicIP returns the caller's public IP address as reported by the endpoint. +func (c *Client) PublicIP(ctx context.Context) (*Response, error) { + req, err := c.newRequest(ctx, http.MethodGet) + if err != nil { + return nil, fmt.Errorf("creating request: %w", err) + } + + result := Response{} + if _, err := c.doJSON(req, &result); err != nil { + return nil, fmt.Errorf("decoding IP check response: %w", err) + } + + return &result, nil +} + +// newRequest creates a new HTTP request with the given method and URL. +func (c *Client) newRequest(ctx context.Context, method string) (*http.Request, error) { + req, err := http.NewRequestWithContext(ctx, method, c.url.String(), nil) + if err != nil { + return nil, err + } + + req.Header.Set("Accept", "application/json") + if c.userAgent != "" { + req.Header.Set("User-Agent", c.userAgent) + } + return req, nil +} + +// doJSON sends the given request and decodes the response into v. +func (c *Client) doJSON(req *http.Request, v any) (*http.Response, error) { + resp, err := c.do(req) + if err != nil { + return resp, err + } + if resp != nil && resp.Body != nil { + defer resp.Body.Close() + } + + err = json.NewDecoder(resp.Body).Decode(&v) + + return resp, err +} + +// do sends the given request and returns the response. +func (c *Client) do(req *http.Request) (*http.Response, error) { + resp, err := c.httpClient.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode != http.StatusOK { + return resp, fmt.Errorf( + "%s: %d, %w", + http.StatusText(resp.StatusCode), + resp.StatusCode, + ErrStatus, + ) + } + + return resp, err +} diff --git a/internal/ipcheck/client_test.go b/internal/ipcheck/client_test.go new file mode 100644 index 00000000..aeac645a --- /dev/null +++ b/internal/ipcheck/client_test.go @@ -0,0 +1,72 @@ +package ipcheck_test + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "net/netip" + "net/url" + "testing" + + "github.com/ninech/nctl/internal/ipcheck" +) + +func TestClient_PublicIP(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + response ipcheck.Response + statusCode int + wantIP netip.Addr + wantErr bool + }{ + { + name: "returns remote addr", + response: ipcheck.Response{Blocked: false, RemoteAddr: netip.MustParseAddr("203.0.113.1")}, + statusCode: http.StatusOK, + wantIP: netip.MustParseAddr("203.0.113.1"), + }, + { + name: "server error", + statusCode: http.StatusInternalServerError, + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Accept") != "application/json" { + t.Errorf("expected Accept: application/json, got %q", r.Header.Get("Accept")) + } + w.WriteHeader(tc.statusCode) + if tc.statusCode == http.StatusOK { + _ = json.NewEncoder(w).Encode(tc.response) + } + })) + defer srv.Close() + + srvURL, err := url.Parse(srv.URL) + if err != nil { + t.Fatalf("parsing URL %q: %v", srv.URL, err) + } + + c := ipcheck.New( + ipcheck.WithURL(srvURL), + ipcheck.WithHTTPClient(srv.Client()), + ipcheck.WithUserAgent("nctl-test"), + ) + + got, err := c.PublicIP(t.Context()) + if (err != nil) != tc.wantErr { + t.Fatalf("PublicIP() error = %v, wantErr %v", err, tc.wantErr) + } + if !tc.wantErr && got.RemoteAddr.Compare(tc.wantIP) != 0 { + t.Errorf("PublicIP() = %q, want %q", got.RemoteAddr.String(), tc.wantIP) + } + }) + } +} diff --git a/main.go b/main.go index a9c63b42..2d9362bd 100644 --- a/main.go +++ b/main.go @@ -162,8 +162,7 @@ func main() { } } - var cliErr *cli.Error - if errors.As(err, &cliErr) { + if cliErr, ok := errors.AsType[*cli.Error](err); ok { fmt.Fprintln(writer, err.Error()) kongCtx.Exit(cliErr.ExitCode()) return