diff --git a/cmd/authctl/group/group_test.go b/cmd/authctl/group/group_test.go index 34d3093521..8fb4d7559a 100644 --- a/cmd/authctl/group/group_test.go +++ b/cmd/authctl/group/group_test.go @@ -32,6 +32,7 @@ func TestGroupCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"group"}, tc.args...)...) + cmd.Env = []string{testutils.CoverDirEnv()} testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/group/set-gid_test.go b/cmd/authctl/group/set-gid_test.go index d26382f1df..aa9533d758 100644 --- a/cmd/authctl/group/set-gid_test.go +++ b/cmd/authctl/group/set-gid_test.go @@ -2,12 +2,12 @@ package group_test import ( "math" - "os" "os/exec" "path/filepath" "strconv" "testing" + "github.com/canonical/authd/internal/envutils" "github.com/canonical/authd/internal/testutils" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -26,8 +26,10 @@ func TestSetGIDCommand(t *testing.T) { testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -69,18 +71,17 @@ func TestSetGIDCommand(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // Copy authctlEnv to avoid modifying the original slice. + authctlEnv := append([]string{}, authctlEnv...) if tc.authdUnavailable { - origValue := os.Getenv("AUTHD_SOCKET") - err := os.Setenv("AUTHD_SOCKET", "/non-existent") + var err error + authctlEnv, err = envutils.Setenv(authctlEnv, "AUTHD_SOCKET", "/non-existent") require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") - t.Cleanup(func() { - err := os.Setenv("AUTHD_SOCKET", origValue) - require.NoError(t, err, "Failed to restore AUTHD_SOCKET environment variable") - }) } //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"group"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/internal/client/client.go b/cmd/authctl/internal/client/client.go index fff6efe7ff..1c388e72e0 100644 --- a/cmd/authctl/internal/client/client.go +++ b/cmd/authctl/internal/client/client.go @@ -27,7 +27,7 @@ func NewUserServiceClient() (authd.UserServiceClient, error) { conn, err := grpc.NewClient(authdSocket, grpc.WithTransportCredentials(insecure.NewCredentials())) if err != nil { - return nil, fmt.Errorf("failed to connect to authd: %w", err) + return nil, fmt.Errorf("failed to create gRPC client: %w", err) } client := authd.NewUserServiceClient(conn) diff --git a/cmd/authctl/main_test.go b/cmd/authctl/main_test.go index 52c2a0ee62..278286fac7 100644 --- a/cmd/authctl/main_test.go +++ b/cmd/authctl/main_test.go @@ -33,6 +33,7 @@ func TestRootCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, tc.args...) + cmd.Env = []string{testutils.CoverDirEnv()} testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/lock_test.go b/cmd/authctl/user/lock_test.go index eb89ac349a..0828bf9165 100644 --- a/cmd/authctl/user/lock_test.go +++ b/cmd/authctl/user/lock_test.go @@ -1,13 +1,11 @@ package user_test import ( - "os" "os/exec" "path/filepath" "testing" "github.com/canonical/authd/internal/testutils" - "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" ) @@ -20,8 +18,10 @@ func TestUserLockCommand(t *testing.T) { testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -38,6 +38,7 @@ func TestUserLockCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/set-shell.go b/cmd/authctl/user/set-shell.go new file mode 100644 index 0000000000..0e8e3e8778 --- /dev/null +++ b/cmd/authctl/user/set-shell.go @@ -0,0 +1,52 @@ +package user + +import ( + "context" + + "github.com/canonical/authd/cmd/authctl/internal/client" + "github.com/canonical/authd/cmd/authctl/internal/completion" + "github.com/canonical/authd/cmd/authctl/internal/log" + "github.com/canonical/authd/internal/proto/authd" + "github.com/spf13/cobra" +) + +var setShellCmd = &cobra.Command{ + Use: "set-shell ", + Short: "Set the login shell for a user", + Args: cobra.ExactArgs(2), + ValidArgsFunction: setShellCompletionFunc, + RunE: runSetShell, +} + +func runSetShell(cmd *cobra.Command, args []string) error { + name := args[0] + shell := args[1] + + svc, err := client.NewUserServiceClient() + if err != nil { + return err + } + + resp, err := svc.SetShell(context.Background(), &authd.SetShellRequest{ + Name: name, + Shell: shell, + }) + if resp == nil { + return err + } + + // Print any warnings returned by the server. + for _, warning := range resp.Warnings { + log.Warning(warning) + } + + return err +} + +func setShellCompletionFunc(cmd *cobra.Command, args []string, toComplete string) ([]string, cobra.ShellCompDirective) { + if len(args) == 0 { + return completion.Users(cmd, args, toComplete) + } + + return nil, cobra.ShellCompDirectiveNoFileComp +} diff --git a/cmd/authctl/user/set-shell_test.go b/cmd/authctl/user/set-shell_test.go new file mode 100644 index 0000000000..d6110dd64a --- /dev/null +++ b/cmd/authctl/user/set-shell_test.go @@ -0,0 +1,49 @@ +package user_test + +import ( + "os/exec" + "path/filepath" + "testing" + + "github.com/canonical/authd/internal/testutils" + "google.golang.org/grpc/codes" +) + +func TestSetShellCommand(t *testing.T) { + t.Parallel() + + daemonSocket := testutils.StartAuthd(t, daemonPath, + testutils.WithGroupFile(filepath.Join("testdata", "empty.group")), + testutils.WithPreviousDBState("one_user_and_group"), + testutils.WithCurrentUserAsRoot, + ) + + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } + + tests := map[string]struct { + args []string + + expectedExitCode int + }{ + "Set_shell_success": {args: []string{"set-shell", "user1", "/bin/bash"}, expectedExitCode: 0}, + + "Error_when_user_does_not_exist": { + args: []string{"set-shell", "invaliduser", "/bin/bash"}, + expectedExitCode: int(codes.NotFound), + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + //nolint:gosec // G204 it's safe to use exec.Command with a variable here + cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = authctlEnv + testutils.CheckCommand(t, cmd, tc.expectedExitCode) + }) + } +} diff --git a/cmd/authctl/user/set-uid_test.go b/cmd/authctl/user/set-uid_test.go index e4a2ed0765..6284330d1f 100644 --- a/cmd/authctl/user/set-uid_test.go +++ b/cmd/authctl/user/set-uid_test.go @@ -2,12 +2,12 @@ package user_test import ( "math" - "os" "os/exec" "path/filepath" "strconv" "testing" + "github.com/canonical/authd/internal/envutils" "github.com/canonical/authd/internal/testutils" "github.com/stretchr/testify/require" "google.golang.org/grpc/codes" @@ -19,15 +19,16 @@ func TestSetUIDCommand(t *testing.T) { // which makes userslocking.WriteLock() return an error immediately when the lock // is already held - unlike the normal behavior which tries to acquire the lock // for 15 seconds before returning an error. - daemonSocket := testutils.StartAuthd(t, daemonPath, testutils.WithGroupFile(filepath.Join("testdata", "empty.group")), testutils.WithPreviousDBState("one_user_and_group"), testutils.WithCurrentUserAsRoot, ) - err := os.Setenv("AUTHD_SOCKET", daemonSocket) - require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") + authctlEnv := []string{ + "AUTHD_SOCKET=" + daemonSocket, + testutils.CoverDirEnv(), + } tests := map[string]struct { args []string @@ -69,18 +70,17 @@ func TestSetUIDCommand(t *testing.T) { for name, tc := range tests { t.Run(name, func(t *testing.T) { + // Copy authctlEnv to avoid modifying the original slice. + authctlEnv := append([]string{}, authctlEnv...) if tc.authdUnavailable { - origValue := os.Getenv("AUTHD_SOCKET") - err := os.Setenv("AUTHD_SOCKET", "/non-existent") + var err error + authctlEnv, err = envutils.Setenv(authctlEnv, "AUTHD_SOCKET", "/non-existent") require.NoError(t, err, "Failed to set AUTHD_SOCKET environment variable") - t.Cleanup(func() { - err := os.Setenv("AUTHD_SOCKET", origValue) - require.NoError(t, err, "Failed to restore AUTHD_SOCKET environment variable") - }) } //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = authctlEnv testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist b/cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist new file mode 100644 index 0000000000..93dd7dd5ff --- /dev/null +++ b/cmd/authctl/user/testdata/golden/TestSetShellCommand/Error_when_user_does_not_exist @@ -0,0 +1 @@ +Error: user "invaliduser" not found diff --git a/cmd/authctl/user/testdata/golden/TestSetShellCommand/Set_shell_success b/cmd/authctl/user/testdata/golden/TestSetShellCommand/Set_shell_success new file mode 100644 index 0000000000..e69de29bb2 diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command index a66b03e2c9..b6b2bbed64 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_command @@ -6,6 +6,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag index 3408cec6d8..beeb0304bd 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Error_on_invalid_flag @@ -6,6 +6,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag b/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag index ee4765c1cc..7a303d9166 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Help_flag @@ -8,6 +8,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args b/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args index d83ced5684..c90581f34f 100644 --- a/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args +++ b/cmd/authctl/user/testdata/golden/TestUserCommand/Usage_message_when_no_args @@ -6,6 +6,7 @@ Available Commands: lock Lock (disable) a user managed by authd unlock Unlock (enable) a user managed by authd set-uid Set the UID of a user managed by authd + set-shell Set the login shell for a user Flags: -h, --help help for user diff --git a/cmd/authctl/user/user.go b/cmd/authctl/user/user.go index 109743aa17..e4e218b230 100644 --- a/cmd/authctl/user/user.go +++ b/cmd/authctl/user/user.go @@ -17,4 +17,5 @@ func init() { UserCmd.AddCommand(lockCmd) UserCmd.AddCommand(unlockCmd) UserCmd.AddCommand(setUIDCmd) + UserCmd.AddCommand(setShellCmd) } diff --git a/cmd/authctl/user/user_test.go b/cmd/authctl/user/user_test.go index 5010e8cc74..1523bd2d67 100644 --- a/cmd/authctl/user/user_test.go +++ b/cmd/authctl/user/user_test.go @@ -32,6 +32,7 @@ func TestUserCommand(t *testing.T) { //nolint:gosec // G204 it's safe to use exec.Command with a variable here cmd := exec.Command(authctlPath, append([]string{"user"}, tc.args...)...) + cmd.Env = []string{testutils.CoverDirEnv()} testutils.CheckCommand(t, cmd, tc.expectedExitCode) }) } diff --git a/docs/reference/cli/authctl_user.md b/docs/reference/cli/authctl_user.md index 3a6e5606ae..e6163ff3b2 100644 --- a/docs/reference/cli/authctl_user.md +++ b/docs/reference/cli/authctl_user.md @@ -16,6 +16,7 @@ authctl user [flags] * [authctl](authctl.md) - CLI tool to interact with authd * [authctl user lock](authctl_user_lock.md) - Lock (disable) a user managed by authd +* [authctl user set-shell](authctl_user_set-shell.md) - Set the login shell for a user * [authctl user set-uid](authctl_user_set-uid.md) - Set the UID of a user managed by authd * [authctl user unlock](authctl_user_unlock.md) - Unlock (enable) a user managed by authd diff --git a/docs/reference/cli/authctl_user_set-shell.md b/docs/reference/cli/authctl_user_set-shell.md new file mode 100644 index 0000000000..2668479223 --- /dev/null +++ b/docs/reference/cli/authctl_user_set-shell.md @@ -0,0 +1,18 @@ +## authctl user set-shell + +Set the login shell for a user + +``` +authctl user set-shell [flags] +``` + +### Options + +``` + -h, --help help for set-shell +``` + +### SEE ALSO + +* [authctl user](authctl_user.md) - Commands related to users + diff --git a/docs/reference/cli/index.md b/docs/reference/cli/index.md index c9f8a20962..ff0686ff63 100644 --- a/docs/reference/cli/index.md +++ b/docs/reference/cli/index.md @@ -21,6 +21,7 @@ authctl_user authctl_user_lock authctl_user_unlock authctl_user_set-uid +authctl_user_set-shell ``` ```{toctree} diff --git a/internal/envutils/envutils.go b/internal/envutils/envutils.go new file mode 100644 index 0000000000..9f24c64129 --- /dev/null +++ b/internal/envutils/envutils.go @@ -0,0 +1,46 @@ +// Package envutils provides utilities for manipulating string slices representing environment variables. +package envutils + +import ( + "errors" + "fmt" + "strings" +) + +// Getenv retrieves the value of an environment variable from a slice of strings. +func Getenv(env []string, key string) string { + for _, kv := range env { + if strings.HasPrefix(kv, key+"=") { + return strings.TrimPrefix(kv, key+"=") + } + } + return "" +} + +// Setenv sets an environment variable in a slice of strings. +func Setenv(env []string, key, value string) ([]string, error) { + if len(key) == 0 { + return nil, errors.New("empty key") + } + if strings.ContainsAny(key, "="+"\x00") { + return nil, fmt.Errorf("invalid key: %q", key) + } + if strings.ContainsRune(value, '\x00') { + return nil, fmt.Errorf("invalid value: %q", value) + } + + kv := fmt.Sprintf("%s=%s", key, value) + + // Check if the key is already set + for i, kvPair := range env { + if strings.HasPrefix(kvPair, key+"=") { + // Key exists, update the value + env[i] = kv + return env, nil + } + } + + // Key is not set yet, append it + env = append(env, kv) + return env, nil +} diff --git a/internal/envutils/envutils_test.go b/internal/envutils/envutils_test.go new file mode 100644 index 0000000000..25c7108972 --- /dev/null +++ b/internal/envutils/envutils_test.go @@ -0,0 +1,241 @@ +package envutils_test + +import ( + "testing" + + "github.com/canonical/authd/internal/envutils" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestGetenv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + env []string + key string + want string + }{ + "Get_existing_environment_variable": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "FOO", + want: "bar", + }, + "Get_environment_variable_with_empty_value": { + env: []string{"FOO=bar", "EMPTY=", "BAZ=qux"}, + key: "EMPTY", + want: "", + }, + "Get_environment_variable_with_special_characters": { + env: []string{"PATH=/usr/bin:/usr/local/bin"}, + key: "PATH", + want: "/usr/bin:/usr/local/bin", + }, + "Get_environment_variable_with_spaces": { + env: []string{"MESSAGE=hello world"}, + key: "MESSAGE", + want: "hello world", + }, + "Get_environment_variable_with_equals_sign_in_value": { + env: []string{"EQUATION=x=y+z"}, + key: "EQUATION", + want: "x=y+z", + }, + "Get_first_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "FIRST", + want: "1", + }, + "Get_middle_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "SECOND", + want: "2", + }, + "Get_last_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "THIRD", + want: "3", + }, + "Return_empty_string_when_key_not_found": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "MISSING", + want: "", + }, + "Return_empty_string_when_key_not_found_in_empty_environment": { + env: []string{}, + key: "VAR", + want: "", + }, + "Return_empty_string_when_looking_for_partial_key_match": { + env: []string{"FOOBAR=baz"}, + key: "FOO", + want: "", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got := envutils.Getenv(tc.env, tc.key) + assert.Equal(t, tc.want, got, "Value should match expected") + }) + } +} + +func TestSetenv(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + env []string + key string + value string + want []string + wantErr bool + errContains string + }{ + "Set_new_environment_variable": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "NEW_VAR", + value: "new_value", + want: []string{"FOO=bar", "BAZ=qux", "NEW_VAR=new_value"}, + }, + "Update_existing_environment_variable": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "FOO", + value: "updated", + want: []string{"FOO=updated", "BAZ=qux"}, + }, + "Set_variable_in_empty_environment": { + env: []string{}, + key: "VAR", + value: "value", + want: []string{"VAR=value"}, + }, + "Set_variable_with_empty_value": { + env: []string{"FOO=bar"}, + key: "EMPTY", + value: "", + want: []string{"FOO=bar", "EMPTY="}, + }, + "Update_variable_to_empty_value": { + env: []string{"FOO=bar", "BAZ=qux"}, + key: "FOO", + value: "", + want: []string{"FOO=", "BAZ=qux"}, + }, + "Set_variable_with_special_characters_in_value": { + env: []string{}, + key: "PATH", + value: "/usr/bin:/usr/local/bin", + want: []string{"PATH=/usr/bin:/usr/local/bin"}, + }, + "Set_variable_with_spaces_in_value": { + env: []string{}, + key: "MESSAGE", + value: "hello world", + want: []string{"MESSAGE=hello world"}, + }, + "Update_first_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "FIRST", + value: "updated", + want: []string{"FIRST=updated", "SECOND=2", "THIRD=3"}, + }, + "Update_middle_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "SECOND", + value: "updated", + want: []string{"FIRST=1", "SECOND=updated", "THIRD=3"}, + }, + "Update_last_variable_in_list": { + env: []string{"FIRST=1", "SECOND=2", "THIRD=3"}, + key: "THIRD", + value: "updated", + want: []string{"FIRST=1", "SECOND=2", "THIRD=updated"}, + }, + + // Error cases + "Error_on_empty_key": { + env: []string{"FOO=bar"}, + key: "", + value: "value", + wantErr: true, + errContains: "empty key", + }, + "Error_on_key_with_equals_sign": { + env: []string{"FOO=bar"}, + key: "KEY=VALUE", + value: "value", + wantErr: true, + errContains: "invalid key", + }, + "Error_on_key_with_null_byte": { + env: []string{"FOO=bar"}, + key: "KEY\x00", + value: "value", + wantErr: true, + errContains: "invalid key", + }, + "Error_on_value_with_null_byte": { + env: []string{"FOO=bar"}, + key: "KEY", + value: "value\x00", + wantErr: true, + errContains: "invalid value", + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + got, err := envutils.Setenv(tc.env, tc.key, tc.value) + + if tc.wantErr { + require.Error(t, err, "Setenv should return an error") + assert.Contains(t, err.Error(), tc.errContains, "Error message should contain expected text") + return + } + + require.NoError(t, err, "Setenv should not return an error") + assert.Equal(t, tc.want, got, "Environment slice should match expected") + }) + } +} + +func TestSetenvDoesNotModifyOriginal(t *testing.T) { + t.Parallel() + + original := []string{"FOO=bar", "BAZ=qux"} + originalCopy := make([]string, len(original)) + copy(originalCopy, original) + + result, err := envutils.Setenv(original, "NEW", "value") + require.NoError(t, err) + + // Verify original slice content is unchanged (but may have increased capacity) + assert.Equal(t, originalCopy, original[:len(originalCopy)], "Original slice content should not be modified") + // Verify result contains the new variable + assert.Contains(t, result, "NEW=value", "Result should contain new variable") +} + +func TestSetenvPreservesOrder(t *testing.T) { + t.Parallel() + + // Update a middle variable + env1 := []string{"A=1", "B=2", "C=3", "D=4", "E=5"} + result, err := envutils.Setenv(env1, "C", "updated") + require.NoError(t, err) + + expected := []string{"A=1", "B=2", "C=updated", "D=4", "E=5"} + assert.Equal(t, expected, result, "Order should be preserved when updating") + + // Add a new variable + env2 := []string{"A=1", "B=2", "C=3", "D=4", "E=5"} + result2, err := envutils.Setenv(env2, "F", "6") + require.NoError(t, err) + + expected2 := []string{"A=1", "B=2", "C=3", "D=4", "E=5", "F=6"} + assert.Equal(t, expected2, result2, "New variable should be appended at the end") +} diff --git a/internal/proto/authd/authd.pb.go b/internal/proto/authd/authd.pb.go index 556a17b4ac..5351978afb 100644 --- a/internal/proto/authd/authd.pb.go +++ b/internal/proto/authd/authd.pb.go @@ -1413,6 +1413,102 @@ func (x *SetGroupIDResponse) GetWarnings() []string { return nil } +type SetShellRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` + Shell string `protobuf:"bytes,2,opt,name=shell,proto3" json:"shell,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetShellRequest) Reset() { + *x = SetShellRequest{} + mi := &file_authd_proto_msgTypes[26] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetShellRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetShellRequest) ProtoMessage() {} + +func (x *SetShellRequest) ProtoReflect() protoreflect.Message { + mi := &file_authd_proto_msgTypes[26] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetShellRequest.ProtoReflect.Descriptor instead. +func (*SetShellRequest) Descriptor() ([]byte, []int) { + return file_authd_proto_rawDescGZIP(), []int{26} +} + +func (x *SetShellRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *SetShellRequest) GetShell() string { + if x != nil { + return x.Shell + } + return "" +} + +type SetShellResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Warnings []string `protobuf:"bytes,1,rep,name=warnings,proto3" json:"warnings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *SetShellResponse) Reset() { + *x = SetShellResponse{} + mi := &file_authd_proto_msgTypes[27] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *SetShellResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*SetShellResponse) ProtoMessage() {} + +func (x *SetShellResponse) ProtoReflect() protoreflect.Message { + mi := &file_authd_proto_msgTypes[27] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use SetShellResponse.ProtoReflect.Descriptor instead. +func (*SetShellResponse) Descriptor() ([]byte, []int) { + return file_authd_proto_rawDescGZIP(), []int{27} +} + +func (x *SetShellResponse) GetWarnings() []string { + if x != nil { + return x.Warnings + } + return nil +} + type User struct { state protoimpl.MessageState `protogen:"open.v1"` Name string `protobuf:"bytes,1,opt,name=name,proto3" json:"name,omitempty"` @@ -1427,7 +1523,7 @@ type User struct { func (x *User) Reset() { *x = User{} - mi := &file_authd_proto_msgTypes[26] + mi := &file_authd_proto_msgTypes[28] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1439,7 +1535,7 @@ func (x *User) String() string { func (*User) ProtoMessage() {} func (x *User) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[26] + mi := &file_authd_proto_msgTypes[28] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1452,7 +1548,7 @@ func (x *User) ProtoReflect() protoreflect.Message { // Deprecated: Use User.ProtoReflect.Descriptor instead. func (*User) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{26} + return file_authd_proto_rawDescGZIP(), []int{28} } func (x *User) GetName() string { @@ -1506,7 +1602,7 @@ type Users struct { func (x *Users) Reset() { *x = Users{} - mi := &file_authd_proto_msgTypes[27] + mi := &file_authd_proto_msgTypes[29] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1518,7 +1614,7 @@ func (x *Users) String() string { func (*Users) ProtoMessage() {} func (x *Users) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[27] + mi := &file_authd_proto_msgTypes[29] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1531,7 +1627,7 @@ func (x *Users) ProtoReflect() protoreflect.Message { // Deprecated: Use Users.ProtoReflect.Descriptor instead. func (*Users) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{27} + return file_authd_proto_rawDescGZIP(), []int{29} } func (x *Users) GetUsers() []*User { @@ -1554,7 +1650,7 @@ type Group struct { func (x *Group) Reset() { *x = Group{} - mi := &file_authd_proto_msgTypes[28] + mi := &file_authd_proto_msgTypes[30] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1566,7 +1662,7 @@ func (x *Group) String() string { func (*Group) ProtoMessage() {} func (x *Group) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[28] + mi := &file_authd_proto_msgTypes[30] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1579,7 +1675,7 @@ func (x *Group) ProtoReflect() protoreflect.Message { // Deprecated: Use Group.ProtoReflect.Descriptor instead. func (*Group) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{28} + return file_authd_proto_rawDescGZIP(), []int{30} } func (x *Group) GetName() string { @@ -1619,7 +1715,7 @@ type Groups struct { func (x *Groups) Reset() { *x = Groups{} - mi := &file_authd_proto_msgTypes[29] + mi := &file_authd_proto_msgTypes[31] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1631,7 +1727,7 @@ func (x *Groups) String() string { func (*Groups) ProtoMessage() {} func (x *Groups) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[29] + mi := &file_authd_proto_msgTypes[31] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1644,7 +1740,7 @@ func (x *Groups) ProtoReflect() protoreflect.Message { // Deprecated: Use Groups.ProtoReflect.Descriptor instead. func (*Groups) Descriptor() ([]byte, []int) { - return file_authd_proto_rawDescGZIP(), []int{29} + return file_authd_proto_rawDescGZIP(), []int{31} } func (x *Groups) GetGroups() []*Group { @@ -1665,7 +1761,7 @@ type ABResponse_BrokerInfo struct { func (x *ABResponse_BrokerInfo) Reset() { *x = ABResponse_BrokerInfo{} - mi := &file_authd_proto_msgTypes[30] + mi := &file_authd_proto_msgTypes[32] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1677,7 +1773,7 @@ func (x *ABResponse_BrokerInfo) String() string { func (*ABResponse_BrokerInfo) ProtoMessage() {} func (x *ABResponse_BrokerInfo) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[30] + mi := &file_authd_proto_msgTypes[32] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1724,7 +1820,7 @@ type GAMResponse_AuthenticationMode struct { func (x *GAMResponse_AuthenticationMode) Reset() { *x = GAMResponse_AuthenticationMode{} - mi := &file_authd_proto_msgTypes[31] + mi := &file_authd_proto_msgTypes[33] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1736,7 +1832,7 @@ func (x *GAMResponse_AuthenticationMode) String() string { func (*GAMResponse_AuthenticationMode) ProtoMessage() {} func (x *GAMResponse_AuthenticationMode) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[31] + mi := &file_authd_proto_msgTypes[33] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1781,7 +1877,7 @@ type IARequest_AuthenticationData struct { func (x *IARequest_AuthenticationData) Reset() { *x = IARequest_AuthenticationData{} - mi := &file_authd_proto_msgTypes[32] + mi := &file_authd_proto_msgTypes[34] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -1793,7 +1889,7 @@ func (x *IARequest_AuthenticationData) String() string { func (*IARequest_AuthenticationData) ProtoMessage() {} func (x *IARequest_AuthenticationData) ProtoReflect() protoreflect.Message { - mi := &file_authd_proto_msgTypes[32] + mi := &file_authd_proto_msgTypes[34] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -1998,7 +2094,12 @@ const file_authd_proto_rawDesc = "" + "\n" + "id_changed\x18\x01 \x01(\bR\tidChanged\x123\n" + "\x16home_dir_owner_changed\x18\x02 \x01(\bR\x13homeDirOwnerChanged\x12\x1a\n" + - "\bwarnings\x18\x03 \x03(\tR\bwarnings\"\x84\x01\n" + + "\bwarnings\x18\x03 \x03(\tR\bwarnings\";\n" + + "\x0fSetShellRequest\x12\x12\n" + + "\x04name\x18\x01 \x01(\tR\x04name\x12\x14\n" + + "\x05shell\x18\x02 \x01(\tR\x05shell\".\n" + + "\x10SetShellResponse\x12\x1a\n" + + "\bwarnings\x18\x01 \x03(\tR\bwarnings\"\x84\x01\n" + "\x04User\x12\x12\n" + "\x04name\x18\x01 \x01(\tR\x04name\x12\x10\n" + "\x03uid\x18\x02 \x01(\rR\x03uid\x12\x10\n" + @@ -2028,7 +2129,7 @@ const file_authd_proto_rawDesc = "" + "\x0fIsAuthenticated\x12\x10.authd.IARequest\x1a\x11.authd.IAResponse\x12,\n" + "\n" + "EndSession\x12\x10.authd.ESRequest\x1a\f.authd.Empty\x12<\n" + - "\x17SetDefaultBrokerForUser\x12\x13.authd.SDBFURequest\x1a\f.authd.Empty2\xb6\x04\n" + + "\x17SetDefaultBrokerForUser\x12\x13.authd.SDBFURequest\x1a\f.authd.Empty2\xf3\x04\n" + "\vUserService\x129\n" + "\rGetUserByName\x12\x1b.authd.GetUserByNameRequest\x1a\v.authd.User\x125\n" + "\vGetUserByID\x12\x19.authd.GetUserByIDRequest\x1a\v.authd.User\x12'\n" + @@ -2038,7 +2139,8 @@ const file_authd_proto_rawDesc = "" + "UnlockUser\x12\x18.authd.UnlockUserRequest\x1a\f.authd.Empty\x12>\n" + "\tSetUserID\x12\x17.authd.SetUserIDRequest\x1a\x18.authd.SetUserIDResponse\x12A\n" + "\n" + - "SetGroupID\x12\x18.authd.SetGroupIDRequest\x1a\x19.authd.SetGroupIDResponse\x12<\n" + + "SetGroupID\x12\x18.authd.SetGroupIDRequest\x1a\x19.authd.SetGroupIDResponse\x12;\n" + + "\bSetShell\x12\x16.authd.SetShellRequest\x1a\x17.authd.SetShellResponse\x12<\n" + "\x0eGetGroupByName\x12\x1c.authd.GetGroupByNameRequest\x1a\f.authd.Group\x128\n" + "\fGetGroupByID\x12\x1a.authd.GetGroupByIDRequest\x1a\f.authd.Group\x12)\n" + "\n" + @@ -2057,7 +2159,7 @@ func file_authd_proto_rawDescGZIP() []byte { } var file_authd_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_authd_proto_msgTypes = make([]protoimpl.MessageInfo, 33) +var file_authd_proto_msgTypes = make([]protoimpl.MessageInfo, 35) var file_authd_proto_goTypes = []any{ (SessionMode)(0), // 0: authd.SessionMode (*Empty)(nil), // 1: authd.Empty @@ -2086,23 +2188,25 @@ var file_authd_proto_goTypes = []any{ (*SetUserIDResponse)(nil), // 24: authd.SetUserIDResponse (*SetGroupIDRequest)(nil), // 25: authd.SetGroupIDRequest (*SetGroupIDResponse)(nil), // 26: authd.SetGroupIDResponse - (*User)(nil), // 27: authd.User - (*Users)(nil), // 28: authd.Users - (*Group)(nil), // 29: authd.Group - (*Groups)(nil), // 30: authd.Groups - (*ABResponse_BrokerInfo)(nil), // 31: authd.ABResponse.BrokerInfo - (*GAMResponse_AuthenticationMode)(nil), // 32: authd.GAMResponse.AuthenticationMode - (*IARequest_AuthenticationData)(nil), // 33: authd.IARequest.AuthenticationData + (*SetShellRequest)(nil), // 27: authd.SetShellRequest + (*SetShellResponse)(nil), // 28: authd.SetShellResponse + (*User)(nil), // 29: authd.User + (*Users)(nil), // 30: authd.Users + (*Group)(nil), // 31: authd.Group + (*Groups)(nil), // 32: authd.Groups + (*ABResponse_BrokerInfo)(nil), // 33: authd.ABResponse.BrokerInfo + (*GAMResponse_AuthenticationMode)(nil), // 34: authd.GAMResponse.AuthenticationMode + (*IARequest_AuthenticationData)(nil), // 35: authd.IARequest.AuthenticationData } var file_authd_proto_depIdxs = []int32{ - 31, // 0: authd.ABResponse.brokers_infos:type_name -> authd.ABResponse.BrokerInfo + 33, // 0: authd.ABResponse.brokers_infos:type_name -> authd.ABResponse.BrokerInfo 0, // 1: authd.SBRequest.mode:type_name -> authd.SessionMode 9, // 2: authd.GAMRequest.supported_ui_layouts:type_name -> authd.UILayout - 32, // 3: authd.GAMResponse.authentication_modes:type_name -> authd.GAMResponse.AuthenticationMode + 34, // 3: authd.GAMResponse.authentication_modes:type_name -> authd.GAMResponse.AuthenticationMode 9, // 4: authd.SAMResponse.ui_layout_info:type_name -> authd.UILayout - 33, // 5: authd.IARequest.authentication_data:type_name -> authd.IARequest.AuthenticationData - 27, // 6: authd.Users.users:type_name -> authd.User - 29, // 7: authd.Groups.groups:type_name -> authd.Group + 35, // 5: authd.IARequest.authentication_data:type_name -> authd.IARequest.AuthenticationData + 29, // 6: authd.Users.users:type_name -> authd.User + 31, // 7: authd.Groups.groups:type_name -> authd.Group 1, // 8: authd.PAM.AvailableBrokers:input_type -> authd.Empty 2, // 9: authd.PAM.GetPreviousBroker:input_type -> authd.GPBRequest 6, // 10: authd.PAM.SelectBroker:input_type -> authd.SBRequest @@ -2118,29 +2222,31 @@ var file_authd_proto_depIdxs = []int32{ 20, // 20: authd.UserService.UnlockUser:input_type -> authd.UnlockUserRequest 23, // 21: authd.UserService.SetUserID:input_type -> authd.SetUserIDRequest 25, // 22: authd.UserService.SetGroupID:input_type -> authd.SetGroupIDRequest - 21, // 23: authd.UserService.GetGroupByName:input_type -> authd.GetGroupByNameRequest - 22, // 24: authd.UserService.GetGroupByID:input_type -> authd.GetGroupByIDRequest - 1, // 25: authd.UserService.ListGroups:input_type -> authd.Empty - 4, // 26: authd.PAM.AvailableBrokers:output_type -> authd.ABResponse - 3, // 27: authd.PAM.GetPreviousBroker:output_type -> authd.GPBResponse - 7, // 28: authd.PAM.SelectBroker:output_type -> authd.SBResponse - 10, // 29: authd.PAM.GetAuthenticationModes:output_type -> authd.GAMResponse - 12, // 30: authd.PAM.SelectAuthenticationMode:output_type -> authd.SAMResponse - 14, // 31: authd.PAM.IsAuthenticated:output_type -> authd.IAResponse - 1, // 32: authd.PAM.EndSession:output_type -> authd.Empty - 1, // 33: authd.PAM.SetDefaultBrokerForUser:output_type -> authd.Empty - 27, // 34: authd.UserService.GetUserByName:output_type -> authd.User - 27, // 35: authd.UserService.GetUserByID:output_type -> authd.User - 28, // 36: authd.UserService.ListUsers:output_type -> authd.Users - 1, // 37: authd.UserService.LockUser:output_type -> authd.Empty - 1, // 38: authd.UserService.UnlockUser:output_type -> authd.Empty - 24, // 39: authd.UserService.SetUserID:output_type -> authd.SetUserIDResponse - 26, // 40: authd.UserService.SetGroupID:output_type -> authd.SetGroupIDResponse - 29, // 41: authd.UserService.GetGroupByName:output_type -> authd.Group - 29, // 42: authd.UserService.GetGroupByID:output_type -> authd.Group - 30, // 43: authd.UserService.ListGroups:output_type -> authd.Groups - 26, // [26:44] is the sub-list for method output_type - 8, // [8:26] is the sub-list for method input_type + 27, // 23: authd.UserService.SetShell:input_type -> authd.SetShellRequest + 21, // 24: authd.UserService.GetGroupByName:input_type -> authd.GetGroupByNameRequest + 22, // 25: authd.UserService.GetGroupByID:input_type -> authd.GetGroupByIDRequest + 1, // 26: authd.UserService.ListGroups:input_type -> authd.Empty + 4, // 27: authd.PAM.AvailableBrokers:output_type -> authd.ABResponse + 3, // 28: authd.PAM.GetPreviousBroker:output_type -> authd.GPBResponse + 7, // 29: authd.PAM.SelectBroker:output_type -> authd.SBResponse + 10, // 30: authd.PAM.GetAuthenticationModes:output_type -> authd.GAMResponse + 12, // 31: authd.PAM.SelectAuthenticationMode:output_type -> authd.SAMResponse + 14, // 32: authd.PAM.IsAuthenticated:output_type -> authd.IAResponse + 1, // 33: authd.PAM.EndSession:output_type -> authd.Empty + 1, // 34: authd.PAM.SetDefaultBrokerForUser:output_type -> authd.Empty + 29, // 35: authd.UserService.GetUserByName:output_type -> authd.User + 29, // 36: authd.UserService.GetUserByID:output_type -> authd.User + 30, // 37: authd.UserService.ListUsers:output_type -> authd.Users + 1, // 38: authd.UserService.LockUser:output_type -> authd.Empty + 1, // 39: authd.UserService.UnlockUser:output_type -> authd.Empty + 24, // 40: authd.UserService.SetUserID:output_type -> authd.SetUserIDResponse + 26, // 41: authd.UserService.SetGroupID:output_type -> authd.SetGroupIDResponse + 28, // 42: authd.UserService.SetShell:output_type -> authd.SetShellResponse + 31, // 43: authd.UserService.GetGroupByName:output_type -> authd.Group + 31, // 44: authd.UserService.GetGroupByID:output_type -> authd.Group + 32, // 45: authd.UserService.ListGroups:output_type -> authd.Groups + 27, // [27:46] is the sub-list for method output_type + 8, // [8:27] is the sub-list for method input_type 8, // [8:8] is the sub-list for extension type_name 8, // [8:8] is the sub-list for extension extendee 0, // [0:8] is the sub-list for field type_name @@ -2152,8 +2258,8 @@ func file_authd_proto_init() { return } file_authd_proto_msgTypes[8].OneofWrappers = []any{} - file_authd_proto_msgTypes[30].OneofWrappers = []any{} - file_authd_proto_msgTypes[32].OneofWrappers = []any{ + file_authd_proto_msgTypes[32].OneofWrappers = []any{} + file_authd_proto_msgTypes[34].OneofWrappers = []any{ (*IARequest_AuthenticationData_Secret)(nil), (*IARequest_AuthenticationData_Wait)(nil), (*IARequest_AuthenticationData_Skip)(nil), @@ -2165,7 +2271,7 @@ func file_authd_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_authd_proto_rawDesc), len(file_authd_proto_rawDesc)), NumEnums: 1, - NumMessages: 33, + NumMessages: 35, NumExtensions: 0, NumServices: 2, }, diff --git a/internal/proto/authd/authd.proto b/internal/proto/authd/authd.proto index df108d0f51..9aae541f5f 100644 --- a/internal/proto/authd/authd.proto +++ b/internal/proto/authd/authd.proto @@ -137,6 +137,7 @@ service UserService { rpc UnlockUser(UnlockUserRequest) returns (Empty); rpc SetUserID(SetUserIDRequest) returns (SetUserIDResponse); rpc SetGroupID(SetGroupIDRequest) returns (SetGroupIDResponse); + rpc SetShell(SetShellRequest) returns (SetShellResponse); rpc GetGroupByName(GetGroupByNameRequest) returns (Group); rpc GetGroupByID(GetGroupByIDRequest) returns (Group); @@ -196,6 +197,15 @@ message SetGroupIDResponse { repeated string warnings = 3; } +message SetShellRequest { + string name = 1; + string shell = 2; +} + +message SetShellResponse { + repeated string warnings = 1; +} + message User { string name = 1; uint32 uid = 2; diff --git a/internal/proto/authd/authd_grpc.pb.go b/internal/proto/authd/authd_grpc.pb.go index 8685f6fdca..d6146889df 100644 --- a/internal/proto/authd/authd_grpc.pb.go +++ b/internal/proto/authd/authd_grpc.pb.go @@ -394,6 +394,7 @@ const ( UserService_UnlockUser_FullMethodName = "/authd.UserService/UnlockUser" UserService_SetUserID_FullMethodName = "/authd.UserService/SetUserID" UserService_SetGroupID_FullMethodName = "/authd.UserService/SetGroupID" + UserService_SetShell_FullMethodName = "/authd.UserService/SetShell" UserService_GetGroupByName_FullMethodName = "/authd.UserService/GetGroupByName" UserService_GetGroupByID_FullMethodName = "/authd.UserService/GetGroupByID" UserService_ListGroups_FullMethodName = "/authd.UserService/ListGroups" @@ -410,6 +411,7 @@ type UserServiceClient interface { UnlockUser(ctx context.Context, in *UnlockUserRequest, opts ...grpc.CallOption) (*Empty, error) SetUserID(ctx context.Context, in *SetUserIDRequest, opts ...grpc.CallOption) (*SetUserIDResponse, error) SetGroupID(ctx context.Context, in *SetGroupIDRequest, opts ...grpc.CallOption) (*SetGroupIDResponse, error) + SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*SetShellResponse, error) GetGroupByName(ctx context.Context, in *GetGroupByNameRequest, opts ...grpc.CallOption) (*Group, error) GetGroupByID(ctx context.Context, in *GetGroupByIDRequest, opts ...grpc.CallOption) (*Group, error) ListGroups(ctx context.Context, in *Empty, opts ...grpc.CallOption) (*Groups, error) @@ -493,6 +495,16 @@ func (c *userServiceClient) SetGroupID(ctx context.Context, in *SetGroupIDReques return out, nil } +func (c *userServiceClient) SetShell(ctx context.Context, in *SetShellRequest, opts ...grpc.CallOption) (*SetShellResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(SetShellResponse) + err := c.cc.Invoke(ctx, UserService_SetShell_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *userServiceClient) GetGroupByName(ctx context.Context, in *GetGroupByNameRequest, opts ...grpc.CallOption) (*Group, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(Group) @@ -534,6 +546,7 @@ type UserServiceServer interface { UnlockUser(context.Context, *UnlockUserRequest) (*Empty, error) SetUserID(context.Context, *SetUserIDRequest) (*SetUserIDResponse, error) SetGroupID(context.Context, *SetGroupIDRequest) (*SetGroupIDResponse, error) + SetShell(context.Context, *SetShellRequest) (*SetShellResponse, error) GetGroupByName(context.Context, *GetGroupByNameRequest) (*Group, error) GetGroupByID(context.Context, *GetGroupByIDRequest) (*Group, error) ListGroups(context.Context, *Empty) (*Groups, error) @@ -568,6 +581,9 @@ func (UnimplementedUserServiceServer) SetUserID(context.Context, *SetUserIDReque func (UnimplementedUserServiceServer) SetGroupID(context.Context, *SetGroupIDRequest) (*SetGroupIDResponse, error) { return nil, status.Error(codes.Unimplemented, "method SetGroupID not implemented") } +func (UnimplementedUserServiceServer) SetShell(context.Context, *SetShellRequest) (*SetShellResponse, error) { + return nil, status.Error(codes.Unimplemented, "method SetShell not implemented") +} func (UnimplementedUserServiceServer) GetGroupByName(context.Context, *GetGroupByNameRequest) (*Group, error) { return nil, status.Error(codes.Unimplemented, "method GetGroupByName not implemented") } @@ -724,6 +740,24 @@ func _UserService_SetGroupID_Handler(srv interface{}, ctx context.Context, dec f return interceptor(ctx, in, info, handler) } +func _UserService_SetShell_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(SetShellRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(UserServiceServer).SetShell(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: UserService_SetShell_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(UserServiceServer).SetShell(ctx, req.(*SetShellRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _UserService_GetGroupByName_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(GetGroupByNameRequest) if err := dec(in); err != nil { @@ -813,6 +847,10 @@ var UserService_ServiceDesc = grpc.ServiceDesc{ MethodName: "SetGroupID", Handler: _UserService_SetGroupID_Handler, }, + { + MethodName: "SetShell", + Handler: _UserService_SetShell_Handler, + }, { MethodName: "GetGroupByName", Handler: _UserService_GetGroupByName_Handler, diff --git a/internal/services/permissions/export_test.go b/internal/services/permissions/export_test.go index 728aaa902b..7c43b1aad6 100644 --- a/internal/services/permissions/export_test.go +++ b/internal/services/permissions/export_test.go @@ -1,9 +1,9 @@ package permissions -type PeerCredsInfo = peerCredsInfo +type PeerAuthInfo = peerAuthInfo -func NewTestPeerCredsInfo(uid uint32, pid int32) PeerCredsInfo { - return PeerCredsInfo{uid: uid, pid: pid} +func NewTestPeerAuthInfo(uid uint32, pid int32) PeerAuthInfo { + return PeerAuthInfo{uid: uid, pid: pid} } var ( diff --git a/internal/services/permissions/internal_test.go b/internal/services/permissions/internal_test.go index 8bbdf6752b..ea662527f4 100644 --- a/internal/services/permissions/internal_test.go +++ b/internal/services/permissions/internal_test.go @@ -16,7 +16,7 @@ import ( func TestPeerCredsInfoAuthType(t *testing.T) { t.Parallel() - p := peerCredsInfo{ + p := peerAuthInfo{ uid: 11111, pid: 22222, } diff --git a/internal/services/permissions/permissions.go b/internal/services/permissions/permissions.go index a2c00eb6d7..27c6bb7b0c 100644 --- a/internal/services/permissions/permissions.go +++ b/internal/services/permissions/permissions.go @@ -38,20 +38,52 @@ func New(args ...Option) Manager { } // CheckRequestIsFromRoot checks if the current gRPC request is from a root user and returns an error if not. -// The pid and uid are extracted from peerCredsInfo in the gRPC context. +// The pid and uid are extracted from peerAuthInfo in the gRPC context. func (m Manager) CheckRequestIsFromRoot(ctx context.Context) (err error) { - p, ok := peer.FromContext(ctx) - if !ok { - return errors.New("context request doesn't have gRPC peer information") + isRoot, err := m.isRequestFromRoot(ctx) + if err != nil { + return err } - pci, ok := p.AuthInfo.(peerCredsInfo) - if !ok { - return errors.New("context request doesn't have valid gRPC peer credential information") + if !isRoot { + return errors.New("only root can perform this operation") } + return nil +} - if pci.uid != m.rootUID { - return errors.New("only root can perform this operation") +// CheckRequestIsFromRootOrUID checks if the current gRPC request is from a root +// user or a specified user and returns an error if not. +func (m Manager) CheckRequestIsFromRootOrUID(ctx context.Context, uid uint32) (err error) { + isRoot, err := m.isRequestFromRoot(ctx) + if err != nil { + return err + } + if isRoot { + return nil } + isFromUID, err := m.isRequestFromUID(ctx, uid) + if err != nil { + return err + } + if !isFromUID { + return errors.New("only root or the specified user can perform this operation") + } return nil } + +func (m Manager) isRequestFromRoot(ctx context.Context) (bool, error) { + return m.isRequestFromUID(ctx, m.rootUID) +} + +func (m Manager) isRequestFromUID(ctx context.Context, uid uint32) (bool, error) { + p, ok := peer.FromContext(ctx) + if !ok { + return false, errors.New("context request doesn't have gRPC peer information") + } + pci, ok := p.AuthInfo.(peerAuthInfo) + if !ok { + return false, errors.New("context request doesn't have valid gRPC peer credential information") + } + + return pci.uid == uid, nil +} diff --git a/internal/services/permissions/permissions_test.go b/internal/services/permissions/permissions_test.go index 04605cabb4..fd231a880a 100644 --- a/internal/services/permissions/permissions_test.go +++ b/internal/services/permissions/permissions_test.go @@ -21,44 +21,27 @@ func TestNew(t *testing.T) { require.NotNil(t, pm, "New permission manager is created") } -func TestIsRequestFromRoot(t *testing.T) { +func TestCheckRequestIsFromRoot(t *testing.T) { t.Parallel() tests := map[string]struct { currentUserNotRoot bool - noPeerCredsInfo bool - noAuthInfo bool + noPeerInfo bool + noPeerAuthInfo bool wantErr bool }{ "Granted_if_current_user_considered_as_root": {}, - "Error_as_deny_when_current_user_is_not_root": {currentUserNotRoot: true, wantErr: true}, - "Error_as_deny_when_missing_peer_creds_Info": {noPeerCredsInfo: true, wantErr: true}, - "Error_as_deny_when_missing_auth_info_creds": {noAuthInfo: true, wantErr: true}, + "Error_if_current_user_is_not_root": {currentUserNotRoot: true, wantErr: true}, + "Error_if_missing_peer_info": {noPeerInfo: true, wantErr: true}, + "Error_if_missing_peer_auth_info": {noPeerAuthInfo: true, wantErr: true}, } for name, tc := range tests { t.Run(name, func(t *testing.T) { t.Parallel() - // Setup peer creds info - ctx := context.Background() - if !tc.noPeerCredsInfo { - var authInfo credentials.AuthInfo - if !tc.noAuthInfo { - uid := permissions.CurrentUserUID() - pid := os.Getpid() - if pid > math.MaxInt32 { - t.Fatalf("Setup: pid is too large to be converted to int32: %d", pid) - } - //nolint:gosec // we did check the conversion check beforehand. - authInfo = permissions.NewTestPeerCredsInfo(uid, int32(os.Getpid())) - } - p := peer.Peer{ - AuthInfo: authInfo, - } - ctx = peer.NewContext(ctx, &p) - } + ctx := setupPermissionTestContext(t, tc.noPeerInfo, tc.noPeerAuthInfo) var opts []permissions.Option if !tc.currentUserNotRoot { @@ -77,6 +60,60 @@ func TestIsRequestFromRoot(t *testing.T) { } } +func TestCheckRequestIsFromRootOrUID(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + currentUserNotRoot bool + useCurrentUID bool + useDifferentUID bool + noPeerInfo bool + noPeerAuthInfo bool + + wantErr bool + }{ + "Granted_if_current_user_considered_as_root": {}, + "Granted_if_current_user_matches_target_uid": {currentUserNotRoot: true, useCurrentUID: true}, + + "Error_if_current_user_is_neither_root_nor_target_uid": { + currentUserNotRoot: true, + useDifferentUID: true, + wantErr: true, + }, + "Error_if_missing_peer_info": {noPeerInfo: true, wantErr: true}, + "Error_if_missing_peer_auth_info": {noPeerAuthInfo: true, wantErr: true}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + ctx := setupPermissionTestContext(t, tc.noPeerInfo, tc.noPeerAuthInfo) + + currentUID := permissions.CurrentUserUID() + targetUID := currentUID + + // If we want a different UID, use a different value + if tc.useDifferentUID { + targetUID = currentUID + 1000 + } + + var opts []permissions.Option + if !tc.currentUserNotRoot { + opts = append(opts, permissions.Z_ForTests_WithCurrentUserAsRoot()) + } + pm := permissions.New(opts...) + + err := pm.CheckRequestIsFromRootOrUID(ctx, targetUID) + + if tc.wantErr { + require.Error(t, err, "CheckRequestIsFromRootOrUID should deny access but didn't") + return + } + require.NoError(t, err, "CheckRequestIsFromRootOrUID should allow access but didn't") + }) + } +} + func TestWithUnixPeerCreds(t *testing.T) { t.Parallel() @@ -84,3 +121,28 @@ func TestWithUnixPeerCreds(t *testing.T) { require.NotNil(t, g, "New gRPC with Unix Peer Creds is created") } + +// setupPermissionTestContext creates a context with peer credentials for testing. +func setupPermissionTestContext(t *testing.T, noPeerInfo, noAuthInfo bool) context.Context { + t.Helper() + + ctx := context.Background() + if noPeerInfo { + return ctx + } + + var authInfo credentials.AuthInfo + if !noAuthInfo { + uid := permissions.CurrentUserUID() + pid := os.Getpid() + if pid > math.MaxInt32 { + require.Fail(t, "Setup: pid is too large to be converted to int32: %d", pid) + } + //nolint:gosec // we checked for an integer overflow above. + authInfo = permissions.NewTestPeerAuthInfo(uid, int32(pid)) + } + p := peer.Peer{ + AuthInfo: authInfo, + } + return peer.NewContext(ctx, &p) +} diff --git a/internal/services/permissions/servercreds.go b/internal/services/permissions/servercreds.go index 72b3d71d3a..85f0383fb7 100644 --- a/internal/services/permissions/servercreds.go +++ b/internal/services/permissions/servercreds.go @@ -13,7 +13,7 @@ import ( "google.golang.org/grpc/credentials" ) -// WithUnixPeerCreds returns the credentials of the caller. +// WithUnixPeerCreds returns a ServerOption that sets credentials for server connections. func WithUnixPeerCreds() grpc.ServerOption { return grpc.Creds(serverPeerCreds{}) } @@ -57,7 +57,7 @@ func (serverPeerCreds) ServerHandshake(conn net.Conn) (n net.Conn, c credentials return nil, nil, fmt.Errorf("Control() error: %v", err) } - return conn, peerCredsInfo{uid: cred.Uid, pid: cred.Pid}, nil + return conn, peerAuthInfo{uid: cred.Uid, pid: cred.Pid}, nil } func (serverPeerCreds) ClientHandshake(_ context.Context, _ string, conn net.Conn) (net.Conn, credentials.AuthInfo, error) { return conn, nil, nil @@ -66,12 +66,12 @@ func (serverPeerCreds) Info() credentials.ProtocolInfo { return credent func (serverPeerCreds) Clone() credentials.TransportCredentials { return nil } func (serverPeerCreds) OverrideServerName(_ string) error { return nil } -type peerCredsInfo struct { +type peerAuthInfo struct { uid uint32 pid int32 } // AuthType returns a string containing the uid and pid of caller. -func (p peerCredsInfo) AuthType() string { +func (p peerAuthInfo) AuthType() string { return fmt.Sprintf("uid: %d, pid: %d", p.uid, p.pid) } diff --git a/internal/services/testdata/golden/TestRegisterGRPCServices b/internal/services/testdata/golden/TestRegisterGRPCServices index 871d7f2387..c043749b06 100644 --- a/internal/services/testdata/golden/TestRegisterGRPCServices +++ b/internal/services/testdata/golden/TestRegisterGRPCServices @@ -51,6 +51,9 @@ authd.UserService: - name: SetGroupID isclientstream: false isserverstream: false + - name: SetShell + isclientstream: false + isserverstream: false - name: SetUserID isclientstream: false isserverstream: false diff --git a/internal/services/user/user.go b/internal/services/user/user.go index 4d976d8297..4fe4d1fe3f 100644 --- a/internal/services/user/user.go +++ b/internal/services/user/user.go @@ -270,6 +270,26 @@ func (s Service) SetGroupID(ctx context.Context, req *authd.SetGroupIDRequest) ( }, nil } +// SetShell sets the shell of a user. +func (s Service) SetShell(ctx context.Context, req *authd.SetShellRequest) (*authd.SetShellResponse, error) { + // authd uses lowercase group names. + name := strings.ToLower(req.GetName()) + + if err := s.permissionManager.CheckRequestIsFromRoot(ctx); err != nil { + return nil, status.Error(codes.PermissionDenied, err.Error()) + } + + warnings, err := s.userManager.SetShell(name, req.GetShell()) + if err != nil { + log.Errorf(ctx, "SetShell: %v", err) + return nil, grpcError(err) + } + + return &authd.SetShellResponse{ + Warnings: warnings, + }, nil +} + // userToProtobuf converts a types.UserEntry to authd.User. func userToProtobuf(u types.UserEntry) *authd.User { return &authd.User{ diff --git a/internal/services/user/user_test.go b/internal/services/user/user_test.go index 91afdf1e95..78c7dc1b27 100644 --- a/internal/services/user/user_test.go +++ b/internal/services/user/user_test.go @@ -423,6 +423,43 @@ func TestSetGroupID(t *testing.T) { } } +func TestSetShell(t *testing.T) { + tests := map[string]struct { + sourceDB string + + username string + newShell string + closeDB bool + currentUserNotRoot bool + + wantErr bool + }{ + "Successfully_set_shell": {username: "user1", newShell: "/bin/sh"}, + "Successfully_set_shell_when_username_has_uppercase_char": {username: "USER1", newShell: "/bin/sh"}, + + "Error_when_not_root": {username: "user1", newShell: "/bin/sh", currentUserNotRoot: true, wantErr: true}, + "Error_when_users_manager_fails_to_set_shell": {username: "doesnotexist", newShell: "/bin/sh", wantErr: true}, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + client, m := newUserServiceClient(t, tc.sourceDB, tc.currentUserNotRoot) + + if tc.closeDB { + // Close the database to trigger a database error + err := userstestutils.DBManager(m).Close() + require.NoError(t, err, "Setup: failed to close database") + } + + _, err := client.SetShell(context.Background(), &authd.SetShellRequest{Name: tc.username, Shell: tc.newShell}) + if tc.wantErr { + require.Error(t, err, "SetShell should return an error, but did not") + return + } + require.NoError(t, err, "SetShell should not return an error, but did") + }) + } +} + // newUserServiceClient returns a new gRPC client for the CLI service. func newUserServiceClient(t *testing.T, dbFile string, currentUserNotRoot ...bool) (client authd.UserServiceClient, userManager *users.Manager) { t.Helper() diff --git a/internal/users/db/db_test.go b/internal/users/db/db_test.go index 47cce2060f..47926fc7cc 100644 --- a/internal/users/db/db_test.go +++ b/internal/users/db/db_test.go @@ -1013,6 +1013,45 @@ func TestSetGroupID(t *testing.T) { } } +func TestSetShell(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + nonExistentUser bool + + wantErr bool + }{ + "Update_existing_user_shell": {}, + + "Error_on_nonexistent_user": {nonExistentUser: true, wantErr: true}, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + m := initDB(t, "one_user_and_group") + + username := "user1" + if tc.nonExistentUser { + username = "nonexistent" + } + + err := m.SetShell(username, "/bin/new-shell") + if tc.wantErr { + require.Error(t, err, "SetShell should return an error for case %q", name) + return + } + require.NoError(t, err, "SetShell should not return an error for case %q", name) + + dbContent, err := db.Z_ForTests_DumpNormalizedYAML(m) + require.NoError(t, err) + + golden.CheckOrUpdate(t, dbContent) + }) + } +} + func TestRemoveDb(t *testing.T) { t.Parallel() diff --git a/internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell b/internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell new file mode 100644 index 0000000000..c3b17bbdd4 --- /dev/null +++ b/internal/users/db/testdata/golden/TestSetShell/Update_existing_user_shell @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /bin/new-shell + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/db/update.go b/internal/users/db/update.go index 7c420d09f9..f60d10737e 100644 --- a/internal/users/db/update.go +++ b/internal/users/db/update.go @@ -373,3 +373,20 @@ func (m *Manager) SetGroupID(groupName string, newGID uint32) ([]UserRow, error) return users, nil } + +// SetShell updates the shell of a user. +func (m *Manager) SetShell(username, shell string) error { + query := `UPDATE users SET shell = ? WHERE name = ?` + res, err := m.db.Exec(query, shell, username) + if err != nil { + return fmt.Errorf("failed to update shell for user: %w", err) + } + rowsAffected, err := res.RowsAffected() + if err != nil { + return fmt.Errorf("failed to get rows affected: %w", err) + } + if rowsAffected == 0 { + return NewUserNotFoundError(username) + } + return nil +} diff --git a/internal/users/manager.go b/internal/users/manager.go index 3b978de19b..340da97e20 100644 --- a/internal/users/manager.go +++ b/internal/users/manager.go @@ -439,7 +439,7 @@ func (m *Manager) SetUserID(name string, uid uint32) (resp *SetUserIDResp, err e // Check if the home directory is currently owned by the user. homeUID, _, err := getHomeDirOwner(oldUser.Dir) if err != nil && !errors.Is(err, os.ErrNotExist) { - warning := fmt.Sprintf("Could not get owner of home directory '%s'.", oldUser.Dir) + warning := fmt.Sprintf("Warning: Could not get owner of home directory '%s', not updating ownership.", oldUser.Dir) log.Warningf(context.Background(), "%s: %v", warning, err) resp.Warnings = append(resp.Warnings, warning) return resp, nil @@ -451,7 +451,7 @@ func (m *Manager) SetUserID(name string, uid uint32) (resp *SetUserIDResp, err e } if homeUID != oldUser.UID { - warning := fmt.Sprintf("Not updating ownership of home directory '%s' because it is not owned by UID %d (current owner: %d).", oldUser.Dir, oldUser.UID, homeUID) + warning := fmt.Sprintf("Warning: Not updating ownership of home directory '%s' because it is not owned by UID %d (current owner: %d).", oldUser.Dir, oldUser.UID, homeUID) log.Warning(context.Background(), warning) resp.Warnings = append(resp.Warnings, warning) return resp, nil @@ -549,18 +549,18 @@ func (m *Manager) updateUserHomeDirOwnership(userRow db.UserRow, oldGID uint32, // Check if the home directory is currently owned by the group _, homeGID, err := getHomeDirOwner(userRow.Dir) if err != nil && !errors.Is(err, os.ErrNotExist) { - warning := fmt.Sprintf("Could not get owner of home directory '%s' for user '%s'.", userRow.Dir, userRow.Name) + warning := fmt.Sprintf("Warning: Could not get owner of home directory '%s', not updating ownership.", userRow.Dir) log.Warningf(context.Background(), "%s: %v", warning, err) return false, warning, nil } if errors.Is(err, os.ErrNotExist) { // The home directory does not exist, so we don't need to change the owner. - log.Debugf(context.Background(), "Home directory %q for user %q does not exist, skipping ownership change", userRow.Dir, userRow.Name) + log.Debugf(context.Background(), "Not updating ownership of home directory %q for user %q because it does not exist", userRow.Dir, userRow.Name) return false, "", nil } if homeGID != oldGID { - warning := fmt.Sprintf("Not updating ownership of home directory '%s' because it is not owned by GID %d (current owner: %d).", userRow.Dir, oldGID, homeGID) + warning := fmt.Sprintf("Warning: Not updating ownership of home directory '%s' because it is not owned by GID %d (current owner: %d).", userRow.Dir, oldGID, homeGID) log.Warning(context.Background(), warning) return false, warning, nil } @@ -665,6 +665,36 @@ func checkHomeDirOwner(home string, uid, gid uint32) error { return nil } +// SetShell sets the shell for the given user. +func (m *Manager) SetShell(username, shell string) (warnings []string, err error) { + if username == "" { + return nil, errors.New("empty username") + } + + // Check if the user exists + _, err = m.db.UserByName(username) + if err != nil { + return nil, err + } + + err = checkValidShellPath(shell) + if err != nil { + return nil, err + } + + err = checkValidShell(shell) + if err != nil { + // We allow root to set an invalid shell but print a warning + warnings = append(warnings, fmt.Sprintf("Warning: %s", err.Error())) + } + + if err = m.db.SetShell(username, shell); err != nil { + return warnings, err + } + + return warnings, nil +} + // BrokerForUser returns the broker ID for the given user. func (m *Manager) BrokerForUser(username string) (string, error) { u, err := m.db.UserByName(username) diff --git a/internal/users/manager_bwrap_test.go b/internal/users/manager_bwrap_test.go index c2a4cd3bd8..b05f1ace96 100644 --- a/internal/users/manager_bwrap_test.go +++ b/internal/users/manager_bwrap_test.go @@ -155,12 +155,11 @@ func TestSetUserID(t *testing.T) { // To make the tests deterministic, we replace the temporary home directory path with a placeholder for i, w := range resp.Warnings { - if regexp.MustCompile(`Could not get owner of home directory '([^"]+)'`).MatchString(w) { - resp.Warnings[i] = `Could not get owner of home directory '{{HOME}}'` - } - if regexp.MustCompile(`Not updating ownership of home directory '([^"]+)' because it is not owned by UID \d+ \(current owner: \d+\)`).MatchString(w) { - resp.Warnings[i] = `Not updating ownership of home directory '{{HOME}}' because it is not owned by UID {{UID}} (current owner: {{CURR_UID}})` - } + // Replace home directory path with placeholder + w = regexp.MustCompile(`home directory '([^']+)'`).ReplaceAllString(w, `home directory '{{HOME}}'`) + // Replace UID and current owner UID with placeholders + w = regexp.MustCompile(`UID (\d+) \(current owner: (\d+)\)`).ReplaceAllString(w, `UID {{UID}} (current owner: {{CURR_UID}})`) + resp.Warnings[i] = w } golden.CheckOrUpdateYAML(t, resp, golden.WithPath("response")) diff --git a/internal/users/manager_test.go b/internal/users/manager_test.go index 033dffdea5..379e2ac2ec 100644 --- a/internal/users/manager_test.go +++ b/internal/users/manager_test.go @@ -1301,6 +1301,138 @@ func TestUpdateUserAfterUnlock(t *testing.T) { require.NoError(t, err, "UpdateUser should not fail") } +func TestSetShell(t *testing.T) { + t.Parallel() + + tests := map[string]struct { + nonExistentUser bool + emptyUsername bool + emptyShell bool + shell string + + wantWarnings int + wantErr bool + }{ + "Successfully_set_shell": {}, + + "Warning_if_shell_is_not_in_etc_shells": { + shell: "/bin/ls", + wantWarnings: 1, + }, + "Warning_if_shell_does_not_exist": { + shell: "/doesnotexist", + wantWarnings: 1, + }, + "Warning_if_shell_is_directory": { + shell: "/etc", + wantWarnings: 1, + }, + "Warning_if_shell_is_not_executable": { + shell: "/etc/passwd", + wantWarnings: 1, + }, + + // checkValidPasswdField error cases + "Error_if_shell_is_empty": { + emptyShell: true, + wantErr: true, + }, + "Error_if_shell_contains_invalid_utf8": { + shell: "/bin/\xff\xfeinvalid", + wantErr: true, + }, + "Error_if_shell_contains_colon": { + shell: "/bin/sh:bash", + wantErr: true, + }, + "Error_if_shell_contains_control_characters": { + shell: "/bin/sh\x00", + wantErr: true, + }, + "Error_if_shell_contains_control_character_tab": { + shell: "/bin/sh\t", + wantErr: true, + }, + "Error_if_shell_contains_control_character_newline": { + shell: "/bin/sh\n", + wantErr: true, + }, + "Error_if_shell_contains_control_character_del": { + shell: "/bin/sh\x7f", + wantErr: true, + }, + + // checkValidShellPath error cases + "Error_if_shell_is_not_absolute_path": { + shell: "bin/sh", + wantErr: true, + }, + "Error_if_shell_path_is_not_normalized": { + shell: "/bin/../bin/sh", + wantErr: true, + }, + "Error_if_shell_path_is_not_normalized_with_dot": { + shell: "/bin/./sh", + wantErr: true, + }, + "Error_if_shell_path_is_too_long": { + shell: "/" + strings.Repeat("a", 4096), + wantErr: true, + }, + + // other error cases + "Error_if_user_does_not_exist": { + nonExistentUser: true, + wantErr: true, + }, + "Error_if_username_is_empty": { + emptyUsername: true, + wantErr: true, + }, + } + + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + t.Parallel() + + dbDir := t.TempDir() + err := db.Z_ForTests_CreateDBFromYAML(filepath.Join("testdata", "db", "one_user_and_group.db.yaml"), dbDir) + require.NoError(t, err, "Setup: could not create database from testdata") + + m := newManagerForTests(t, dbDir) + + username := "user1" + if tc.nonExistentUser { + username = "nonexistent" + } else if tc.emptyUsername { + username = "" + } + + shell := "/bin/sh" + if tc.emptyShell { + shell = "" + } else if tc.shell != "" { + shell = tc.shell + } + + warnings, err := m.SetShell(username, shell) + requireErrorAssertions(t, err, nil, tc.wantErr) + + require.Len(t, warnings, tc.wantWarnings, "Number of warnings mismatch") + + if tc.wantErr { + return + } + + yamlData, err := db.Z_ForTests_DumpNormalizedYAML(m.DB()) + require.NoError(t, err) + golden.CheckOrUpdate(t, yamlData, golden.WithPath("db")) + + golden.CheckOrUpdateYAML(t, warnings, golden.WithPath("warnings")) + }) + } +} + func requireErrorAssertions(t *testing.T, gotErr, wantErrType error, wantErr bool) { t.Helper() diff --git a/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/db b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/db new file mode 100644 index 0000000000..7fa3788902 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /bin/sh + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings new file mode 100644 index 0000000000..fe51488c70 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Successfully_set_shell/warnings @@ -0,0 +1 @@ +[] diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db new file mode 100644 index 0000000000..8bea3e2319 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /doesnotexist + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings new file mode 100644 index 0000000000..d8f388a81a --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_does_not_exist/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/doesnotexist'' does not exist' diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db new file mode 100644 index 0000000000..4b8eee1498 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /etc + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings new file mode 100644 index 0000000000..8063733301 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_directory/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/etc'' is not an executable file' diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db new file mode 100644 index 0000000000..4e506fb35a --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /etc/passwd + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings new file mode 100644 index 0000000000..0baecc86c8 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_executable/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/etc/passwd'' is not an executable file' diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db new file mode 100644 index 0000000000..f5dd819874 --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/db @@ -0,0 +1,18 @@ +users: + - name: user1 + uid: 1111 + gid: 11111 + gecos: |- + User1 gecos + On multiple lines + dir: /home/user1 + shell: /bin/ls + broker_id: broker-id +groups: + - name: group1 + gid: 11111 + ugid: "12345678" +users_to_groups: + - uid: 1111 + gid: 11111 +schema_version: 2 diff --git a/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings new file mode 100644 index 0000000000..d3bff1320b --- /dev/null +++ b/internal/users/testdata/golden/TestSetShell/Warning_if_shell_is_not_in_etc_shells/warnings @@ -0,0 +1 @@ +- 'Warning: shell ''/bin/ls'' is not allowed in /etc/shells' diff --git a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response index 5967e5be65..6b8aa56782 100644 --- a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response +++ b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_cannot_be_accessed/response @@ -1,4 +1,4 @@ idchanged: true homedirownerchanged: false warnings: - - Could not get owner of home directory '{{HOME}}' + - 'Warning: Could not get owner of home directory ''{{HOME}}'', not updating ownership.' diff --git a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response index 1927dd6513..0f7edbb746 100644 --- a/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response +++ b/internal/users/testdata/golden/TestSetUserID/Warning_if_home_directory_is_owned_by_other_user/response @@ -1,4 +1,4 @@ idchanged: true homedirownerchanged: false warnings: - - 'Not updating ownership of home directory ''{{HOME}}'' because it is not owned by UID {{UID}} (current owner: {{CURR_UID}})' + - 'Warning: Not updating ownership of home directory ''{{HOME}}'' because it is not owned by UID {{UID}} (current owner: {{CURR_UID}}).' diff --git a/internal/users/userutils.go b/internal/users/userutils.go new file mode 100644 index 0000000000..554187e0d9 --- /dev/null +++ b/internal/users/userutils.go @@ -0,0 +1,87 @@ +package users + +import ( + "errors" + "fmt" + "os" + "path" + "strings" + "unicode/utf8" + + "golang.org/x/sys/unix" +) + +func checkValidPasswdField(value string) (err error) { + if value == "" { + return errors.New("value cannot be empty") + } + + if !utf8.ValidString(value) { + return errors.New("value must be valid UTF-8") + } + + if strings.ContainsRune(value, ':') { + return errors.New("value cannot contain ':' character") + } + + for _, r := range value { + if r < 32 || r == 127 { + return errors.New("value cannot contain control characters") + } + } + + return nil +} + +func checkValidShellPath(shell string) (err error) { + // Do the same checks as systemd-homed in shell_is_ok: + // https://github.com/systemd/systemd/blob/ba67af7efb7b743ba1974ef9ceb53fba0e3f9e21/src/home/homectl.c#L2812 + if err = checkValidPasswdField(shell); err != nil { + return err + } + + if !path.IsAbs(shell) { + return errors.New("shell must be an absolute path") + } + + if shell != path.Clean(shell) { + return errors.New("shell path must be normalized") + } + + // PATH_MAX is counted with the terminating null byte + if unix.PathMax-1 < len(shell) { + return errors.New("shell path is too long") + } + + return nil +} + +func checkValidShell(shell string) (err error) { + // Check if the shell exists and is executable + stat, err := os.Stat(shell) + if errors.Is(err, os.ErrNotExist) { + return fmt.Errorf("shell '%s' does not exist", shell) + } + + if stat.IsDir() || stat.Mode()&0111 == 0 { + return fmt.Errorf("shell '%s' is not an executable file", shell) + } + + // Check if the shell is in the list of allowed shells in /etc/shells + shells, err := os.ReadFile("/etc/shells") + if err != nil { + return err + } + + for _, allowedShell := range strings.Split(string(shells), "\n") { + if len(allowedShell) > 0 && allowedShell[0] == '#' { + // Skip comments + continue + } + if allowedShell == shell { + return nil + } + } + + return fmt.Errorf("shell '%s' is not allowed in /etc/shells", shell) +}