From b2f015c418edae3ed12fb3737ee24beb92ec02d4 Mon Sep 17 00:00:00 2001 From: Ilya Tsupryk Date: Thu, 19 Mar 2026 17:46:14 +0000 Subject: [PATCH 1/5] Cover exec_resource.go by unit tests --- config/exec_resource_test.go | 575 ++++++++++++++++++++++++++++++----- 1 file changed, 491 insertions(+), 84 deletions(-) diff --git a/config/exec_resource_test.go b/config/exec_resource_test.go index 51538effb..10cea8d5a 100644 --- a/config/exec_resource_test.go +++ b/config/exec_resource_test.go @@ -16,19 +16,32 @@ package config import ( "context" - "io/ioutil" - "math/rand" + "crypto/rand" + "errors" + "fmt" "os" - "path" + "os/exec" "path/filepath" "reflect" + "runtime" "testing" "cloud.google.com/go/osconfig/agentendpoint/apiv1/agentendpointpb" + + utilmocks "github.com/GoogleCloudPlatform/osconfig/util/mocks" + "github.com/golang/mock/gomock" ) +// TestExecResourceDownload verifies downloading and temp file creation for exec resources. func TestExecResourceDownload(t *testing.T) { ctx := context.Background() + preserveGlobalState(t) + + tmpDir := t.TempDir() + localScriptPath := filepath.Join(tmpDir, "my_local_script") + if err := os.WriteFile(localScriptPath, []byte("local validate"), 0755); err != nil { + t.Fatal(err) + } var tests = []struct { name string @@ -38,10 +51,11 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath string wantEnforceContents string goos string + wantErr string }{ { - "Script NONE Linux", - &agentendpointpb.OSPolicy_Resource_ExecResource{ + name: "Script NONE Linux", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Script{Script: "validate"}, Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, @@ -51,15 +65,16 @@ func TestExecResourceDownload(t *testing.T) { Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, }, }, - "script", - "validate", - "script", - "enforce", - "linux", + wantValidatePath: "script", + wantValidateContents: "validate", + wantEnforcePath: "script", + wantEnforceContents: "enforce", + goos: "linux", + wantErr: "", }, { - "Script NONE Windows", - &agentendpointpb.OSPolicy_Resource_ExecResource{ + name: "Script NONE Windows", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Script{Script: "validate"}, Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, @@ -69,15 +84,16 @@ func TestExecResourceDownload(t *testing.T) { Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, }, }, - "script.cmd", - "validate", - "script.cmd", - "enforce", - "windows", + wantValidatePath: "script.cmd", + wantValidateContents: "validate", + wantEnforcePath: "script.cmd", + wantEnforceContents: "enforce", + goos: "windows", + wantErr: "", }, { - "Script SHELL Linux", - &agentendpointpb.OSPolicy_Resource_ExecResource{ + name: "Script SHELL Linux", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Script{Script: "validate"}, Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL, @@ -87,15 +103,16 @@ func TestExecResourceDownload(t *testing.T) { Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL, }, }, - "script.sh", - "validate", - "script.sh", - "enforce", - "linux", + wantValidatePath: "script.sh", + wantValidateContents: "validate", + wantEnforcePath: "script.sh", + wantEnforceContents: "enforce", + goos: "linux", + wantErr: "", }, { - "Script SHELL Windows", - &agentendpointpb.OSPolicy_Resource_ExecResource{ + name: "Script SHELL Windows", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Script{Script: "validate"}, Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL, @@ -105,15 +122,16 @@ func TestExecResourceDownload(t *testing.T) { Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL, }, }, - "script.cmd", - "validate", - "script.cmd", - "enforce", - "windows", + wantValidatePath: "script.cmd", + wantValidateContents: "validate", + wantEnforcePath: "script.cmd", + wantEnforceContents: "enforce", + goos: "windows", + wantErr: "", }, { - "Script POWERSHELL Windows", - &agentendpointpb.OSPolicy_Resource_ExecResource{ + name: "Script POWERSHELL Windows", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Script{Script: "validate"}, Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_POWERSHELL, @@ -123,13 +141,80 @@ func TestExecResourceDownload(t *testing.T) { Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_POWERSHELL, }, }, - "script.ps1", - "validate", - "script.ps1", - "enforce", - "windows", + wantValidatePath: "script.ps1", + wantValidateContents: "validate", + wantEnforcePath: "script.ps1", + wantEnforceContents: "enforce", + goos: "windows", + wantErr: "", + }, + { + name: "Unsupported Interpreter", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Script{Script: "validate"}, + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Interpreter(99), + }, + }, + wantValidatePath: "", + wantValidateContents: "", + wantEnforcePath: "", + wantEnforceContents: "", + goos: "linux", + wantErr: `unsupported interpreter "99"`, + }, + { + name: "Unrecognized Source Type", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, + }, + }, + wantValidatePath: "", + wantValidateContents: "", + wantEnforcePath: "", + wantEnforceContents: "", + goos: "linux", + wantErr: `unrecognized Source type for ExecResource: %!q()`, + }, + { + name: "Unsupported File", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_File{ + File: &agentendpointpb.OSPolicy_Resource_File{}, + }, + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, + }, + }, + wantValidatePath: "", + wantValidateContents: "", + wantEnforcePath: "", + wantEnforceContents: "", + goos: "linux", + wantErr: `unsupported File `, + }, + { + name: "LocalPath File", + erpb: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Source: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec_File{ + File: &agentendpointpb.OSPolicy_Resource_File{ + Type: &agentendpointpb.OSPolicy_Resource_File_LocalPath{LocalPath: localScriptPath}, + }, + }, + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, + }, + }, + wantValidatePath: "my_local_script", + wantValidateContents: "local validate", + wantEnforcePath: "", + wantEnforceContents: "", + goos: "linux", + wantErr: "", }, } + for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { goos = tt.goos @@ -142,47 +227,217 @@ func TestExecResourceDownload(t *testing.T) { } defer pr.Cleanup(ctx) - if err := pr.Validate(ctx); err != nil { - t.Fatalf("Unexpected error: %v", err) + err := pr.Validate(ctx) + if !matchError(t, err, tt.wantErr) || tt.wantErr != "" { + return } + resource := pr.resource.(*execResource) - if tt.wantValidatePath != path.Base(resource.validatePath) { - t.Errorf("unexpected validate path: %q", resource.validatePath) + if tt.wantValidatePath != "" { + if tt.wantValidatePath != filepath.Base(resource.validatePath) { + t.Errorf("unexpected validate path: got %q, want %q", filepath.Base(resource.validatePath), tt.wantValidatePath) + } + assertFileContents(t, resource.validatePath, tt.wantValidateContents) + } + + if tt.wantEnforcePath != "" { + if tt.wantEnforcePath != filepath.Base(resource.enforcePath) { + t.Errorf("unexpected enforce path: got %q, want %q", filepath.Base(resource.enforcePath), tt.wantEnforcePath) + } + assertFileContents(t, resource.enforcePath, tt.wantEnforceContents) + } + }) + } +} + +// TestExecResourceRun verifies command construction and execution. +func TestExecResourceRun(t *testing.T) { + ctx := context.Background() + preserveGlobalState(t) + + var tests = []struct { + name string + goos string + execR *agentendpointpb.OSPolicy_Resource_ExecResource_Exec + expectedCmd *exec.Cmd + wantErr string + }{ + { + name: "NONE interpreter", + goos: "linux", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE}, + expectedCmd: exec.Command("test_script"), + wantErr: "", + }, + { + name: "SHELL Linux", + goos: "linux", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL}, + expectedCmd: exec.Command("/bin/sh", "test_script"), + wantErr: "", + }, + { + name: "SHELL with args", + goos: "linux", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL, Args: []string{"arg1", "arg2"}}, + expectedCmd: exec.Command("/bin/sh", "test_script", "arg1", "arg2"), + wantErr: "", + }, + { + name: "SHELL Windows", + goos: "windows", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL}, + expectedCmd: exec.Command("test_script"), + wantErr: "", + }, + { + name: "POWERSHELL Windows", + goos: "windows", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_POWERSHELL}, + expectedCmd: exec.Command("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\PowerShell.exe", "-File", "test_script"), + wantErr: "", + }, + { + name: "POWERSHELL Linux error", + goos: "linux", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_POWERSHELL}, + expectedCmd: nil, + wantErr: `interpreter "POWERSHELL" can only be used on Windows systems`, + }, + { + name: "Unsupported interpreter", + goos: "linux", + execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Interpreter(99)}, + expectedCmd: nil, + wantErr: `unsupported interpreter "99"`, + }, + { + name: "Nil Exec", + goos: "linux", + execR: nil, + expectedCmd: nil, + wantErr: `ExecResource Exec cannot be nil`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + goos = tt.goos + e := &execResource{} + mockCommandRunner := setupMockRunner(t) + if tt.expectedCmd != nil { + mockCommandRunner.EXPECT().Run(ctx, utilmocks.EqCmd(tt.expectedCmd)).Return([]byte("stdout"), []byte("stderr"), nil).Times(1) } - data, err := ioutil.ReadFile(resource.validatePath) - if err != nil { - t.Fatal(err) + + _, _, _, err := e.run(ctx, "test_script", tt.execR) + matchError(t, err, tt.wantErr) + }) + } +} + +// TestExecResourceCheckState verifies validation phase exit code mapping. +func TestExecResourceCheckState(t *testing.T) { + ctx := context.Background() + preserveGlobalState(t) + + var tests = []struct { + name string + exitCode int + wantInDesiredState bool + wantErr string + }{ + {name: "Code 100", exitCode: 100, wantInDesiredState: true, wantErr: ""}, + {name: "Code 101", exitCode: 101, wantInDesiredState: false, wantErr: ""}, + {name: "Code 0", exitCode: 0, wantInDesiredState: false, wantErr: "unexpected return code from validate: 0, stdout: stdout, stderr: stderr"}, + {name: "Code -1", exitCode: -1, wantInDesiredState: false, wantErr: "some error"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockCommandRunner := setupMockRunner(t) + mockRunnerExpectation(ctx, mockCommandRunner, tt.exitCode) + e := &execResource{ + validatePath: "test_script", + OSPolicy_Resource_ExecResource: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, + }, + }, } - if tt.wantValidateContents != string(data) { - t.Errorf("unexpected validate contents: %q", data) + + inDesiredState, retErr := e.checkState(ctx) + if matchError(t, retErr, tt.wantErr) { + if inDesiredState != tt.wantInDesiredState { + t.Errorf("checkState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) + } } + }) + } +} + +// TestExecResourceEnforceState verifies enforcement execution and output capturing. +func TestExecResourceEnforceState(t *testing.T) { + ctx := context.Background() + preserveGlobalState(t) + + tmpDir := t.TempDir() + outputFile := filepath.Join(tmpDir, "output.txt") - if tt.wantEnforcePath != path.Base(resource.enforcePath) { - t.Errorf("unexpected enforce path: %q", resource.enforcePath) + var tests = []struct { + name string + exitCode int + outputFilePath string + wantInDesiredState bool + wantErr string + wantOutput string + }{ + {name: "Code 100 without output", exitCode: 100, outputFilePath: "", wantInDesiredState: true, wantErr: "", wantOutput: ""}, + {name: "Code 100 with output", exitCode: 100, outputFilePath: outputFile, wantInDesiredState: true, wantErr: "", wantOutput: "my enforce output"}, + {name: "Code 0", exitCode: 0, outputFilePath: "", wantInDesiredState: false, wantErr: "unexpected return code from enforce: 0, stdout: stdout, stderr: stderr", wantOutput: ""}, + {name: "Code -1", exitCode: -1, outputFilePath: "", wantInDesiredState: false, wantErr: "some error", wantOutput: ""}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + mockCommandRunner := setupMockRunner(t) + mockRunnerExpectation(ctx, mockCommandRunner, tt.exitCode) + if tt.outputFilePath != "" { + if err := os.WriteFile(tt.outputFilePath, []byte(tt.wantOutput), 0644); err != nil { + t.Fatal(err) + } } - data, err = ioutil.ReadFile(resource.enforcePath) - if err != nil { - t.Fatal(err) + e := &execResource{ + enforcePath: "test_script", + OSPolicy_Resource_ExecResource: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Enforce: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, + OutputFilePath: tt.outputFilePath, + }, + }, } - if tt.wantEnforceContents != string(data) { - t.Errorf("unexpected enforce contents: %q", data) + + inDesiredState, retErr := e.enforceState(ctx) + if matchError(t, retErr, tt.wantErr) { + if inDesiredState != tt.wantInDesiredState { + t.Errorf("enforceState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) + } + if tt.wantOutput != "" && string(e.enforceOutput) != tt.wantOutput { + t.Errorf("enforceState() output = %q, want %q", string(e.enforceOutput), tt.wantOutput) + } } }) } } +// TestExecOutput verifies file reading and truncation for enforce output. func TestExecOutput(t *testing.T) { ctx := context.Background() - tmpDir, err := ioutil.TempDir("", "") - if err != nil { - t.Fatal(err) - } - defer os.RemoveAll(tmpDir) + tmpDir := t.TempDir() fileA := filepath.Join(tmpDir, "fileA") contentsA := []byte("here is some text\nand some more\n") - if err := ioutil.WriteFile(fileA, contentsA, 0600); err != nil { + if err := os.WriteFile(fileA, contentsA, 0600); err != nil { t.Fatal(err) } @@ -191,54 +446,206 @@ func TestExecOutput(t *testing.T) { if _, err := rand.Read(contentsB); err != nil { t.Fatal(err) } - if err := ioutil.WriteFile(fileB, contentsB, 0600); err != nil { + if err := os.WriteFile(fileB, contentsB, 0600); err != nil { t.Fatal(err) } + _, errDNE := os.Open("DNE") + expectedDNEErr := fmt.Sprintf("error opening OutputFilePath: %v", errDNE) + var tests = []struct { name string filePath string want []byte - wantErr bool + wantErr string }{ { - "empty path", - "", - nil, - false, + name: "empty path", + filePath: "", + want: nil, + wantErr: "", }, { - "path DNE", - "DNE", - nil, - true, + name: "path DNE", + filePath: "DNE", + want: nil, + wantErr: expectedDNEErr, }, { - "normal case", - fileA, - contentsA, - false, + name: "normal case", + filePath: fileA, + want: contentsA, + wantErr: "", }, { - "file to large case", - fileB, - contentsB[:maxExecOutputSize], - true, + name: "file to large case", + filePath: fileB, + want: contentsB[:maxExecOutputSize], + wantErr: fmt.Sprintf("contents of OutputFilePath greater than %dK", maxExecOutputSize/1024), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := execOutput(ctx, tt.filePath) - if err != nil && !tt.wantErr { - t.Errorf("Unexpected error from execOutput: %v", err) + if matchError(t, err, tt.wantErr) { + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got != want, string(got) = %q string(want) = %q", got, tt.want) + } + } + }) + } +} + +// TestExecResourcePopulateOutput verifies protobuf output assignment. +func TestExecResourcePopulateOutput(t *testing.T) { + tests := []struct { + name string + outputData []byte + wantOutput string + }{ + { + name: "With output data", + outputData: []byte("test output data"), + wantOutput: "test output data", + }, + { + name: "Nil output data", + outputData: nil, + wantOutput: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &execResource{ + enforceOutput: tt.outputData, + } + rCompliance := &agentendpointpb.OSPolicyResourceCompliance{} + e.populateOutput(rCompliance) + + var got string + if rCompliance.GetExecResourceOutput() != nil { + got = string(rCompliance.GetExecResourceOutput().GetEnforcementOutput()) } - if err == nil && tt.wantErr { - t.Error("Did not get expected error from execOutput") + + if got != tt.wantOutput { + t.Errorf("populateOutput() output = %q, want %q", got, tt.wantOutput) } + }) + } +} + +// TestExecResourceCleanup verifies cleanup of temporary directories. +func TestExecResourceCleanup(t *testing.T) { + ctx := context.Background() + + tests := []struct { + name string + setupTempDir func() string + wantErr string + }{ + { + name: "Empty temp directory", + setupTempDir: func() string { return "" }, + wantErr: "", + }, + { + name: "Valid temp directory", + setupTempDir: func() string { + return t.TempDir() + }, + wantErr: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir := tt.setupTempDir() + e := &execResource{tempDir: tmpDir} + + err := e.cleanup(ctx) + matchError(t, err, tt.wantErr) - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("got != want, string(got) = %q string(want) = %q", got, tt.want) + if tmpDir != "" { + if _, err := os.Stat(tmpDir); !os.IsNotExist(err) { + t.Errorf("cleanup() failed to remove temp directory %q", tmpDir) + } } }) } } + +// mockRunnerExpectation configures a mock command runner for a single execution. +func mockRunnerExpectation(ctx context.Context, mockCommandRunner *utilmocks.MockCommandRunner, exitCode int) { + var err error + if exitCode == -1 { + err = errors.New("some error") + } else if exitCode != 0 { + var cmd *exec.Cmd + if runtime.GOOS == "windows" { + cmd = exec.Command("cmd", "/c", fmt.Sprintf("exit %d", exitCode)) + } else { + cmd = exec.Command("sh", "-c", fmt.Sprintf("exit %d", exitCode)) + } + err = cmd.Run() + } + mockCommandRunner.EXPECT().Run(ctx, gomock.Any()).Return([]byte("stdout"), []byte("stderr"), err).Times(1) +} + +// setupMockRunner initializes a gomock controller, creates a mock command runner, +// and injects it into the global runner variable. It also registers a cleanup +// function to finish the mock controller when the test ends. +func setupMockRunner(t *testing.T) *utilmocks.MockCommandRunner { + t.Helper() + mockCtrl := gomock.NewController(t) + t.Cleanup(func() { mockCtrl.Finish() }) + + mockCommandRunner := utilmocks.NewMockCommandRunner(mockCtrl) + runner = mockCommandRunner + return mockCommandRunner +} + +// preserveGlobalState saves the current values of global variables (goos, runner) +// and registers a cleanup function to restore them after the test completes, +// preventing state pollution between tests. +func preserveGlobalState(t *testing.T) { + t.Helper() + origGoos := goos + origRunner := runner + + t.Cleanup(func() { + goos = origGoos + runner = origRunner + }) +} + +// matchError asserts if the error matches the expected error message. +// It returns true if we should continue testing (i.e. no error occurred and none was expected). +func matchError(t *testing.T, err error, wantErr string) bool { + t.Helper() + if err != nil { + if wantErr == "" { + t.Errorf("Unexpected error: %v", err) + } else if err.Error() != wantErr { + t.Errorf("error = %q, wantErr %q", err.Error(), wantErr) + } + return false + } + if wantErr != "" { + t.Errorf("Expected error %q but got nil", wantErr) + return false + } + return true +} + +// assertFileContents verifies that the file at filePath matches the expected contents. +func assertFileContents(t *testing.T, filePath string, wantContents string) { + t.Helper() + data, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read file %q: %v", filePath, err) + } + if string(data) != wantContents { + t.Errorf("File contents = %q, want %q", string(data), wantContents) + } +} From faf920a2c04c3f1d4fa8cc8b1916fca6911566d3 Mon Sep 17 00:00:00 2001 From: Ilya Tsupryk Date: Tue, 24 Mar 2026 10:29:53 +0000 Subject: [PATCH 2/5] Make changes according to review commits --- config/exec_resource_test.go | 197 +++++++++++++---------------------- testutil/testutil.go | 75 +++++++++++++ 2 files changed, 150 insertions(+), 122 deletions(-) create mode 100644 testutil/testutil.go diff --git a/config/exec_resource_test.go b/config/exec_resource_test.go index 10cea8d5a..7e7396064 100644 --- a/config/exec_resource_test.go +++ b/config/exec_resource_test.go @@ -22,11 +22,11 @@ import ( "os" "os/exec" "path/filepath" - "reflect" "runtime" "testing" "cloud.google.com/go/osconfig/agentendpoint/apiv1/agentendpointpb" + "github.com/GoogleCloudPlatform/osconfig/testutil" utilmocks "github.com/GoogleCloudPlatform/osconfig/util/mocks" "github.com/golang/mock/gomock" @@ -51,7 +51,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath string wantEnforceContents string goos string - wantErr string + wantErr error }{ { name: "Script NONE Linux", @@ -70,7 +70,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "script", wantEnforceContents: "enforce", goos: "linux", - wantErr: "", + wantErr: nil, }, { name: "Script NONE Windows", @@ -89,7 +89,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "script.cmd", wantEnforceContents: "enforce", goos: "windows", - wantErr: "", + wantErr: nil, }, { name: "Script SHELL Linux", @@ -108,7 +108,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "script.sh", wantEnforceContents: "enforce", goos: "linux", - wantErr: "", + wantErr: nil, }, { name: "Script SHELL Windows", @@ -127,7 +127,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "script.cmd", wantEnforceContents: "enforce", goos: "windows", - wantErr: "", + wantErr: nil, }, { name: "Script POWERSHELL Windows", @@ -146,7 +146,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "script.ps1", wantEnforceContents: "enforce", goos: "windows", - wantErr: "", + wantErr: nil, }, { name: "Unsupported Interpreter", @@ -161,7 +161,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "", wantEnforceContents: "", goos: "linux", - wantErr: `unsupported interpreter "99"`, + wantErr: errors.New(`unsupported interpreter "99"`), }, { name: "Unrecognized Source Type", @@ -175,7 +175,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "", wantEnforceContents: "", goos: "linux", - wantErr: `unrecognized Source type for ExecResource: %!q()`, + wantErr: errors.New(`unrecognized Source type for ExecResource: %!q()`), }, { name: "Unsupported File", @@ -192,7 +192,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "", wantEnforceContents: "", goos: "linux", - wantErr: `unsupported File `, + wantErr: errors.New(`unsupported File `), }, { name: "LocalPath File", @@ -211,7 +211,7 @@ func TestExecResourceDownload(t *testing.T) { wantEnforcePath: "", wantEnforceContents: "", goos: "linux", - wantErr: "", + wantErr: nil, }, } @@ -228,25 +228,15 @@ func TestExecResourceDownload(t *testing.T) { defer pr.Cleanup(ctx) err := pr.Validate(ctx) - if !matchError(t, err, tt.wantErr) || tt.wantErr != "" { - return - } + testutil.AssertErrorMatch(t, err, tt.wantErr) resource := pr.resource.(*execResource) - if tt.wantValidatePath != "" { - if tt.wantValidatePath != filepath.Base(resource.validatePath) { - t.Errorf("unexpected validate path: got %q, want %q", filepath.Base(resource.validatePath), tt.wantValidatePath) - } - assertFileContents(t, resource.validatePath, tt.wantValidateContents) - } + testutil.AssertFilePath(t, "validate", resource.validatePath, tt.wantValidatePath) + testutil.AssertFileContents(t, resource.validatePath, tt.wantValidateContents) - if tt.wantEnforcePath != "" { - if tt.wantEnforcePath != filepath.Base(resource.enforcePath) { - t.Errorf("unexpected enforce path: got %q, want %q", filepath.Base(resource.enforcePath), tt.wantEnforcePath) - } - assertFileContents(t, resource.enforcePath, tt.wantEnforceContents) - } + testutil.AssertFilePath(t, "enforce", resource.enforcePath, tt.wantEnforcePath) + testutil.AssertFileContents(t, resource.enforcePath, tt.wantEnforceContents) }) } } @@ -261,77 +251,77 @@ func TestExecResourceRun(t *testing.T) { goos string execR *agentendpointpb.OSPolicy_Resource_ExecResource_Exec expectedCmd *exec.Cmd - wantErr string + wantErr error }{ { name: "NONE interpreter", goos: "linux", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE}, expectedCmd: exec.Command("test_script"), - wantErr: "", + wantErr: nil, }, { name: "SHELL Linux", goos: "linux", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL}, expectedCmd: exec.Command("/bin/sh", "test_script"), - wantErr: "", + wantErr: nil, }, { name: "SHELL with args", goos: "linux", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL, Args: []string{"arg1", "arg2"}}, expectedCmd: exec.Command("/bin/sh", "test_script", "arg1", "arg2"), - wantErr: "", + wantErr: nil, }, { name: "SHELL Windows", goos: "windows", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_SHELL}, expectedCmd: exec.Command("test_script"), - wantErr: "", + wantErr: nil, }, { name: "POWERSHELL Windows", goos: "windows", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_POWERSHELL}, expectedCmd: exec.Command("C:\\Windows\\System32\\WindowsPowerShell\\v1.0\\PowerShell.exe", "-File", "test_script"), - wantErr: "", + wantErr: nil, }, { name: "POWERSHELL Linux error", goos: "linux", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_POWERSHELL}, expectedCmd: nil, - wantErr: `interpreter "POWERSHELL" can only be used on Windows systems`, + wantErr: errors.New(`interpreter "POWERSHELL" can only be used on Windows systems`), }, { name: "Unsupported interpreter", goos: "linux", execR: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_Interpreter(99)}, expectedCmd: nil, - wantErr: `unsupported interpreter "99"`, + wantErr: errors.New(`unsupported interpreter "99"`), }, { name: "Nil Exec", goos: "linux", execR: nil, expectedCmd: nil, - wantErr: `ExecResource Exec cannot be nil`, + wantErr: errors.New(`ExecResource Exec cannot be nil`), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { goos = tt.goos - e := &execResource{} + execRes := &execResource{} mockCommandRunner := setupMockRunner(t) if tt.expectedCmd != nil { mockCommandRunner.EXPECT().Run(ctx, utilmocks.EqCmd(tt.expectedCmd)).Return([]byte("stdout"), []byte("stderr"), nil).Times(1) } - _, _, _, err := e.run(ctx, "test_script", tt.execR) - matchError(t, err, tt.wantErr) + _, _, _, err := execRes.run(ctx, "test_script", tt.execR) + testutil.AssertErrorMatch(t, err, tt.wantErr) }) } } @@ -340,37 +330,36 @@ func TestExecResourceRun(t *testing.T) { func TestExecResourceCheckState(t *testing.T) { ctx := context.Background() preserveGlobalState(t) + execRes := &execResource{ + validatePath: "test_script", + OSPolicy_Resource_ExecResource: &agentendpointpb.OSPolicy_Resource_ExecResource{ + Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ + Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, + }, + }, + } var tests = []struct { name string exitCode int wantInDesiredState bool - wantErr string + wantErr error }{ - {name: "Code 100", exitCode: 100, wantInDesiredState: true, wantErr: ""}, - {name: "Code 101", exitCode: 101, wantInDesiredState: false, wantErr: ""}, - {name: "Code 0", exitCode: 0, wantInDesiredState: false, wantErr: "unexpected return code from validate: 0, stdout: stdout, stderr: stderr"}, - {name: "Code -1", exitCode: -1, wantInDesiredState: false, wantErr: "some error"}, + {name: "Code 100", exitCode: 100, wantInDesiredState: true, wantErr: nil}, + {name: "Code 101", exitCode: 101, wantInDesiredState: false, wantErr: nil}, + {name: "Code 0", exitCode: 0, wantInDesiredState: false, wantErr: errors.New("unexpected return code from validate: 0, stdout: stdout, stderr: stderr")}, + {name: "Code -1", exitCode: -1, wantInDesiredState: false, wantErr: errors.New("some error")}, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { mockCommandRunner := setupMockRunner(t) mockRunnerExpectation(ctx, mockCommandRunner, tt.exitCode) - e := &execResource{ - validatePath: "test_script", - OSPolicy_Resource_ExecResource: &agentendpointpb.OSPolicy_Resource_ExecResource{ - Validate: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ - Interpreter: agentendpointpb.OSPolicy_Resource_ExecResource_Exec_NONE, - }, - }, - } - inDesiredState, retErr := e.checkState(ctx) - if matchError(t, retErr, tt.wantErr) { - if inDesiredState != tt.wantInDesiredState { - t.Errorf("checkState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) - } + inDesiredState, err := execRes.checkState(ctx) + testutil.AssertErrorMatch(t, err, tt.wantErr) + if inDesiredState != tt.wantInDesiredState { + t.Errorf("checkState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) } }) } @@ -389,13 +378,13 @@ func TestExecResourceEnforceState(t *testing.T) { exitCode int outputFilePath string wantInDesiredState bool - wantErr string + wantErr error wantOutput string }{ - {name: "Code 100 without output", exitCode: 100, outputFilePath: "", wantInDesiredState: true, wantErr: "", wantOutput: ""}, - {name: "Code 100 with output", exitCode: 100, outputFilePath: outputFile, wantInDesiredState: true, wantErr: "", wantOutput: "my enforce output"}, - {name: "Code 0", exitCode: 0, outputFilePath: "", wantInDesiredState: false, wantErr: "unexpected return code from enforce: 0, stdout: stdout, stderr: stderr", wantOutput: ""}, - {name: "Code -1", exitCode: -1, outputFilePath: "", wantInDesiredState: false, wantErr: "some error", wantOutput: ""}, + {name: "Code 100 without output", exitCode: 100, outputFilePath: "", wantInDesiredState: true, wantErr: nil, wantOutput: ""}, + {name: "Code 100 with output", exitCode: 100, outputFilePath: outputFile, wantInDesiredState: true, wantErr: nil, wantOutput: "my enforce output"}, + {name: "Code 0", exitCode: 0, outputFilePath: "", wantInDesiredState: false, wantErr: errors.New("unexpected return code from enforce: 0, stdout: stdout, stderr: stderr"), wantOutput: ""}, + {name: "Code -1", exitCode: -1, outputFilePath: "", wantInDesiredState: false, wantErr: errors.New("some error"), wantOutput: ""}, } for _, tt := range tests { @@ -407,7 +396,7 @@ func TestExecResourceEnforceState(t *testing.T) { t.Fatal(err) } } - e := &execResource{ + execRes := &execResource{ enforcePath: "test_script", OSPolicy_Resource_ExecResource: &agentendpointpb.OSPolicy_Resource_ExecResource{ Enforce: &agentendpointpb.OSPolicy_Resource_ExecResource_Exec{ @@ -417,14 +406,13 @@ func TestExecResourceEnforceState(t *testing.T) { }, } - inDesiredState, retErr := e.enforceState(ctx) - if matchError(t, retErr, tt.wantErr) { - if inDesiredState != tt.wantInDesiredState { - t.Errorf("enforceState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) - } - if tt.wantOutput != "" && string(e.enforceOutput) != tt.wantOutput { - t.Errorf("enforceState() output = %q, want %q", string(e.enforceOutput), tt.wantOutput) - } + inDesiredState, err := execRes.enforceState(ctx) + testutil.AssertErrorMatch(t, err, tt.wantErr) + if inDesiredState != tt.wantInDesiredState { + t.Errorf("enforceState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) + } + if tt.wantOutput != "" && string(execRes.enforceOutput) != tt.wantOutput { + t.Errorf("enforceState() output = %q, want %q", string(execRes.enforceOutput), tt.wantOutput) } }) } @@ -450,48 +438,45 @@ func TestExecOutput(t *testing.T) { t.Fatal(err) } - _, errDNE := os.Open("DNE") - expectedDNEErr := fmt.Sprintf("error opening OutputFilePath: %v", errDNE) + _, err := os.Open("DNE") + wantedDoNotExistErrorMessage := fmt.Sprintf("error opening OutputFilePath: %v", err) var tests = []struct { name string filePath string want []byte - wantErr string + wantErr error }{ { name: "empty path", filePath: "", want: nil, - wantErr: "", + wantErr: nil, }, { name: "path DNE", filePath: "DNE", want: nil, - wantErr: expectedDNEErr, + wantErr: errors.New(wantedDoNotExistErrorMessage), }, { name: "normal case", filePath: fileA, want: contentsA, - wantErr: "", + wantErr: nil, }, { name: "file to large case", filePath: fileB, want: contentsB[:maxExecOutputSize], - wantErr: fmt.Sprintf("contents of OutputFilePath greater than %dK", maxExecOutputSize/1024), + wantErr: fmt.Errorf("contents of OutputFilePath greater than %dK", maxExecOutputSize/1024), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := execOutput(ctx, tt.filePath) - if matchError(t, err, tt.wantErr) { - if !reflect.DeepEqual(got, tt.want) { - t.Errorf("got != want, string(got) = %q string(want) = %q", got, tt.want) - } - } + testutil.AssertErrorMatch(t, err, tt.wantErr) + testutil.EnsureEquals(t, got, tt.want) }) } } @@ -517,11 +502,11 @@ func TestExecResourcePopulateOutput(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - e := &execResource{ + execRes := &execResource{ enforceOutput: tt.outputData, } rCompliance := &agentendpointpb.OSPolicyResourceCompliance{} - e.populateOutput(rCompliance) + execRes.populateOutput(rCompliance) var got string if rCompliance.GetExecResourceOutput() != nil { @@ -542,30 +527,29 @@ func TestExecResourceCleanup(t *testing.T) { tests := []struct { name string setupTempDir func() string - wantErr string + wantErr error }{ { name: "Empty temp directory", setupTempDir: func() string { return "" }, - wantErr: "", + wantErr: nil, }, { name: "Valid temp directory", setupTempDir: func() string { return t.TempDir() }, - wantErr: "", + wantErr: nil, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { tmpDir := tt.setupTempDir() - e := &execResource{tempDir: tmpDir} - - err := e.cleanup(ctx) - matchError(t, err, tt.wantErr) + execRes := &execResource{tempDir: tmpDir} + err := execRes.cleanup(ctx) + testutil.AssertErrorMatch(t, err, tt.wantErr) if tmpDir != "" { if _, err := os.Stat(tmpDir); !os.IsNotExist(err) { t.Errorf("cleanup() failed to remove temp directory %q", tmpDir) @@ -618,34 +602,3 @@ func preserveGlobalState(t *testing.T) { runner = origRunner }) } - -// matchError asserts if the error matches the expected error message. -// It returns true if we should continue testing (i.e. no error occurred and none was expected). -func matchError(t *testing.T, err error, wantErr string) bool { - t.Helper() - if err != nil { - if wantErr == "" { - t.Errorf("Unexpected error: %v", err) - } else if err.Error() != wantErr { - t.Errorf("error = %q, wantErr %q", err.Error(), wantErr) - } - return false - } - if wantErr != "" { - t.Errorf("Expected error %q but got nil", wantErr) - return false - } - return true -} - -// assertFileContents verifies that the file at filePath matches the expected contents. -func assertFileContents(t *testing.T, filePath string, wantContents string) { - t.Helper() - data, err := os.ReadFile(filePath) - if err != nil { - t.Fatalf("Failed to read file %q: %v", filePath, err) - } - if string(data) != wantContents { - t.Errorf("File contents = %q, want %q", string(data), wantContents) - } -} diff --git a/testutil/testutil.go b/testutil/testutil.go new file mode 100644 index 000000000..ecddff564 --- /dev/null +++ b/testutil/testutil.go @@ -0,0 +1,75 @@ +// Copyright 2026 Google Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package testutil provides common testing utility functions for the osconfig agent. +package testutil + +import ( + "os" + "path/filepath" + "reflect" + "testing" +) + +// EnsureEquals checks if got and want are deeply equal. If not, it fails the test. +func EnsureEquals(t *testing.T, got interface{}, want interface{}) { + t.Helper() + if !reflect.DeepEqual(got, want) { + t.Errorf("got != want, got = %q want = %q", got, want) + } +} + +// AssertErrorMatch verifies that the gotErr matches the wantErr type and message. +func AssertErrorMatch(t *testing.T, gotErr, wantErr error) { + t.Helper() + if gotErr == nil && wantErr == nil { + return + } + if gotErr == nil || wantErr == nil { + t.Errorf("Errors mismatch, want %v, got %v", wantErr, gotErr) + return + } + if reflect.TypeOf(gotErr) != reflect.TypeOf(wantErr) || gotErr.Error() != wantErr.Error() { + t.Errorf("Unexpected error, want %v, got %v", wantErr, gotErr) + } +} + +// AssertFilePath verifies that the file path base matches the expected path base. +func AssertFilePath(t *testing.T, pathType string, gotPath string, wantPath string) { + t.Helper() + if wantPath == "" { + if gotPath != "" { + t.Errorf("unexpected %s path: got %q, want empty", pathType, gotPath) + } + return + } + if wantPath != filepath.Base(gotPath) { + t.Errorf("unexpected %s path: got %q, want %q", pathType, filepath.Base(gotPath), wantPath) + } +} + +// AssertFileContents verifies that the file at filePath matches the expected contents. +func AssertFileContents(t *testing.T, filePath string, wantContents string) { + t.Helper() + if filePath == "" { + return + } + data, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read file %q: %v", filePath, err) + } + if string(data) != wantContents { + t.Errorf("File contents = %q, want %q", string(data), wantContents) + } +} From c22fde34497aaa3625049d6b9e7eb16e15af1279 Mon Sep 17 00:00:00 2001 From: Ilya Tsupryk Date: Tue, 24 Mar 2026 14:37:02 +0000 Subject: [PATCH 3/5] Fix review comments --- config/exec_resource_test.go | 31 +++++++-------- testutil/testutil.go | 75 ------------------------------------ util/utiltest/utiltest.go | 54 ++++++++++++++++++++++++++ 3 files changed, 68 insertions(+), 92 deletions(-) delete mode 100644 testutil/testutil.go diff --git a/config/exec_resource_test.go b/config/exec_resource_test.go index 7e7396064..5a74d73ce 100644 --- a/config/exec_resource_test.go +++ b/config/exec_resource_test.go @@ -26,7 +26,7 @@ import ( "testing" "cloud.google.com/go/osconfig/agentendpoint/apiv1/agentendpointpb" - "github.com/GoogleCloudPlatform/osconfig/testutil" + "github.com/GoogleCloudPlatform/osconfig/util/utiltest" utilmocks "github.com/GoogleCloudPlatform/osconfig/util/mocks" "github.com/golang/mock/gomock" @@ -228,15 +228,15 @@ func TestExecResourceDownload(t *testing.T) { defer pr.Cleanup(ctx) err := pr.Validate(ctx) - testutil.AssertErrorMatch(t, err, tt.wantErr) + utiltest.AssertErrorMatch(t, err, tt.wantErr) resource := pr.resource.(*execResource) - testutil.AssertFilePath(t, "validate", resource.validatePath, tt.wantValidatePath) - testutil.AssertFileContents(t, resource.validatePath, tt.wantValidateContents) + utiltest.AssertFilePath(t, resource.validatePath, tt.wantValidatePath) + utiltest.AssertFileContents(t, resource.validatePath, tt.wantValidateContents) - testutil.AssertFilePath(t, "enforce", resource.enforcePath, tt.wantEnforcePath) - testutil.AssertFileContents(t, resource.enforcePath, tt.wantEnforceContents) + utiltest.AssertFilePath(t, resource.enforcePath, tt.wantEnforcePath) + utiltest.AssertFileContents(t, resource.enforcePath, tt.wantEnforceContents) }) } } @@ -321,7 +321,7 @@ func TestExecResourceRun(t *testing.T) { } _, _, _, err := execRes.run(ctx, "test_script", tt.execR) - testutil.AssertErrorMatch(t, err, tt.wantErr) + utiltest.AssertErrorMatch(t, err, tt.wantErr) }) } } @@ -357,7 +357,7 @@ func TestExecResourceCheckState(t *testing.T) { mockRunnerExpectation(ctx, mockCommandRunner, tt.exitCode) inDesiredState, err := execRes.checkState(ctx) - testutil.AssertErrorMatch(t, err, tt.wantErr) + utiltest.AssertErrorMatch(t, err, tt.wantErr) if inDesiredState != tt.wantInDesiredState { t.Errorf("checkState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) } @@ -407,11 +407,11 @@ func TestExecResourceEnforceState(t *testing.T) { } inDesiredState, err := execRes.enforceState(ctx) - testutil.AssertErrorMatch(t, err, tt.wantErr) + utiltest.AssertErrorMatch(t, err, tt.wantErr) if inDesiredState != tt.wantInDesiredState { t.Errorf("enforceState() inDesiredState = %v, want %v", inDesiredState, tt.wantInDesiredState) } - if tt.wantOutput != "" && string(execRes.enforceOutput) != tt.wantOutput { + if string(execRes.enforceOutput) != tt.wantOutput { t.Errorf("enforceState() output = %q, want %q", string(execRes.enforceOutput), tt.wantOutput) } }) @@ -475,8 +475,8 @@ func TestExecOutput(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got, err := execOutput(ctx, tt.filePath) - testutil.AssertErrorMatch(t, err, tt.wantErr) - testutil.EnsureEquals(t, got, tt.want) + utiltest.AssertErrorMatch(t, err, tt.wantErr) + utiltest.EnsureEquals(t, got, tt.want) }) } } @@ -512,10 +512,7 @@ func TestExecResourcePopulateOutput(t *testing.T) { if rCompliance.GetExecResourceOutput() != nil { got = string(rCompliance.GetExecResourceOutput().GetEnforcementOutput()) } - - if got != tt.wantOutput { - t.Errorf("populateOutput() output = %q, want %q", got, tt.wantOutput) - } + utiltest.EnsureEquals(t, got, tt.wantOutput) }) } } @@ -549,7 +546,7 @@ func TestExecResourceCleanup(t *testing.T) { execRes := &execResource{tempDir: tmpDir} err := execRes.cleanup(ctx) - testutil.AssertErrorMatch(t, err, tt.wantErr) + utiltest.AssertErrorMatch(t, err, tt.wantErr) if tmpDir != "" { if _, err := os.Stat(tmpDir); !os.IsNotExist(err) { t.Errorf("cleanup() failed to remove temp directory %q", tmpDir) diff --git a/testutil/testutil.go b/testutil/testutil.go deleted file mode 100644 index ecddff564..000000000 --- a/testutil/testutil.go +++ /dev/null @@ -1,75 +0,0 @@ -// Copyright 2026 Google Inc. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package testutil provides common testing utility functions for the osconfig agent. -package testutil - -import ( - "os" - "path/filepath" - "reflect" - "testing" -) - -// EnsureEquals checks if got and want are deeply equal. If not, it fails the test. -func EnsureEquals(t *testing.T, got interface{}, want interface{}) { - t.Helper() - if !reflect.DeepEqual(got, want) { - t.Errorf("got != want, got = %q want = %q", got, want) - } -} - -// AssertErrorMatch verifies that the gotErr matches the wantErr type and message. -func AssertErrorMatch(t *testing.T, gotErr, wantErr error) { - t.Helper() - if gotErr == nil && wantErr == nil { - return - } - if gotErr == nil || wantErr == nil { - t.Errorf("Errors mismatch, want %v, got %v", wantErr, gotErr) - return - } - if reflect.TypeOf(gotErr) != reflect.TypeOf(wantErr) || gotErr.Error() != wantErr.Error() { - t.Errorf("Unexpected error, want %v, got %v", wantErr, gotErr) - } -} - -// AssertFilePath verifies that the file path base matches the expected path base. -func AssertFilePath(t *testing.T, pathType string, gotPath string, wantPath string) { - t.Helper() - if wantPath == "" { - if gotPath != "" { - t.Errorf("unexpected %s path: got %q, want empty", pathType, gotPath) - } - return - } - if wantPath != filepath.Base(gotPath) { - t.Errorf("unexpected %s path: got %q, want %q", pathType, filepath.Base(gotPath), wantPath) - } -} - -// AssertFileContents verifies that the file at filePath matches the expected contents. -func AssertFileContents(t *testing.T, filePath string, wantContents string) { - t.Helper() - if filePath == "" { - return - } - data, err := os.ReadFile(filePath) - if err != nil { - t.Fatalf("Failed to read file %q: %v", filePath, err) - } - if string(data) != wantContents { - t.Errorf("File contents = %q, want %q", string(data), wantContents) - } -} diff --git a/util/utiltest/utiltest.go b/util/utiltest/utiltest.go index 930f8018a..78a9e6b2b 100644 --- a/util/utiltest/utiltest.go +++ b/util/utiltest/utiltest.go @@ -3,6 +3,8 @@ package utiltest import ( "errors" "os" + "path/filepath" + "reflect" "testing" "github.com/google/go-cmp/cmp" @@ -81,3 +83,55 @@ func MatchSnapshot(t testReporter, actual any, snapshotFilepath string) { removeSnapshotDraft(snapshotFilepath) } } + +// EnsureEquals checks if got and want are deeply equal. If not, it fails the test. +func EnsureEquals(t *testing.T, got interface{}, want interface{}) { + t.Helper() + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("got != want (-want +got):\n%s", diff) + } +} + +// AssertErrorMatch verifies that the gotErr matches the wantErr type and message. +func AssertErrorMatch(t *testing.T, gotErr, wantErr error) { + t.Helper() + if gotErr == nil && wantErr == nil { + return + } + if gotErr == nil || wantErr == nil { + t.Errorf("Errors mismatch, want %v, got %v", wantErr, gotErr) + return + } + if reflect.TypeOf(gotErr) != reflect.TypeOf(wantErr) || gotErr.Error() != wantErr.Error() { + t.Errorf("Unexpected error, want %v, got %v", wantErr, gotErr) + } +} + +// AssertFilePath verifies that the file path base matches the expected path base. +func AssertFilePath(t *testing.T, gotPath string, wantPath string) { + t.Helper() + if wantPath == "" { + if gotPath != "" { + t.Errorf("unexpected path: got %q, want empty", gotPath) + } + return + } + if diff := cmp.Diff(wantPath, filepath.Base(gotPath)); diff != "" { + t.Errorf("unexpected path (-want +got):\n%s", diff) + } +} + +// AssertFileContents verifies that the file at filePath matches the expected contents. +func AssertFileContents(t *testing.T, filePath string, wantContents string) { + t.Helper() + if filePath == "" { + return + } + data, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read file %q: %v", filePath, err) + } + if diff := cmp.Diff(wantContents, string(data)); diff != "" { + t.Errorf("File contents mismatch (-want +got):\n%s", diff) + } +} From 66da7de9ea9803be13bf1236ef3e7f12ccdc4203 Mon Sep 17 00:00:00 2001 From: Ilya Tsupryk Date: Thu, 12 Mar 2026 18:32:12 +0000 Subject: [PATCH 4/5] Cover agentconfig functionality by unit tests --- agentconfig/agentconfig.go | 15 +- agentconfig/agentconfig_test.go | 1236 +++++++++++++++++++++++++++++-- 2 files changed, 1169 insertions(+), 82 deletions(-) diff --git a/agentconfig/agentconfig.go b/agentconfig/agentconfig.go index 99c4adfb1..a2fbaf42a 100644 --- a/agentconfig/agentconfig.go +++ b/agentconfig/agentconfig.go @@ -104,6 +104,7 @@ var ( capabilities = []string{"PATCH_GA", "GUEST_POLICY_BETA", "CONFIG_V1"} osConfigWatchConfigTimeout = 10 * time.Minute + watchConfigRetryInterval = 5 * time.Second defaultClient = &http.Client{ Transport: &http.Transport{ @@ -116,6 +117,8 @@ var ( freeOSMemory = strings.ToLower(os.Getenv("OSCONFIG_FREE_OS_MEMORY")) disableInventoryWrite = strings.ToLower(os.Getenv("OSCONFIG_DISABLE_INVENTORY_WRITE")) + + goos = runtime.GOOS ) type config struct { @@ -487,7 +490,7 @@ func WatchConfig(ctx context.Context) error { // Max watch time, after this WatchConfig will return. timeout := time.After(osConfigWatchConfigTimeout) // Min watch loop time. - loopTicker := time.NewTicker(5 * time.Second) + loopTicker := time.NewTicker(watchConfigRetryInterval) defer loopTicker.Stop() eTag := lEtag.get() webErrorCount := 0 @@ -561,7 +564,7 @@ func SvcPollInterval() time.Duration { // SerialLogPort is the serial port to log to. func SerialLogPort() string { - if runtime.GOOS == "windows" { + if goos == "windows" { return "COM1" } // Don't write directly to the serial port on Linux as syslog already writes there. @@ -767,7 +770,7 @@ func Capabilities() []string { // TaskStateFile is the location of the task state file. func TaskStateFile() string { - if runtime.GOOS == "windows" { + if goos == "windows" { return filepath.Join(GetCacheDirWindows(), "osconfig_task.state") } @@ -776,7 +779,7 @@ func TaskStateFile() string { // OldTaskStateFile is the location of the task state file. func OldTaskStateFile() string { - if runtime.GOOS == "windows" { + if goos == "windows" { return oldTaskStateFileWindows } return oldTaskStateFileLinux @@ -784,7 +787,7 @@ func OldTaskStateFile() string { // RestartFile is the location of the restart required file. func RestartFile() string { - if runtime.GOOS == "windows" { + if goos == "windows" { return filepath.Join( GetCacheDirWindows(), "osconfig_agent_restart_required") } @@ -799,7 +802,7 @@ func OldRestartFile() string { // CacheDir is the location of the cache directory. func CacheDir() string { - if runtime.GOOS == "windows" { + if goos == "windows" { return GetCacheDirWindows() } diff --git a/agentconfig/agentconfig_test.go b/agentconfig/agentconfig_test.go index 16a0d95c1..5fdd6dcf3 100644 --- a/agentconfig/agentconfig_test.go +++ b/agentconfig/agentconfig_test.go @@ -16,27 +16,42 @@ package agentconfig import ( "context" + "encoding/base64" + "encoding/json" + "errors" "fmt" + "io/ioutil" + "net" "net/http" "net/http/httptest" + "net/url" "os" "path/filepath" "reflect" "runtime" "strings" + "sync" "testing" "time" ) -func TestWatchConfig(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - fmt.Fprintln(w, `{"project":{"numericProjectID":12345,"projectId":"projectId","attributes":{"osconfig-endpoint":"bad!!1","enable-os-inventory":"false"}},"instance":{"id":12345,"name":"name","zone":"zone","attributes":{"osconfig-endpoint":"SvcEndpoint","enable-os-inventory":"1","enable-os-config-debug":"true","osconfig-enabled-prerelease-features":"ospackage,ospatch", "osconfig-poll-interval":"3"}}}`) - })) - defer ts.Close() +// setupMockMetadataServer starts an httptest.Server with the provided handler and overrides the GCE_METADATA_HOST environment variable. +// It also registers cleanup functions to close the server and restore the environment variable. +func setupMockMetadataServer(t *testing.T, handler http.HandlerFunc) *httptest.Server { + t.Helper() + ts := httptest.NewServer(handler) + t.Cleanup(ts.Close) - if err := os.Setenv("GCE_METADATA_HOST", strings.Trim(ts.URL, "http://")); err != nil { - t.Fatalf("Error running os.Setenv: %v", err) - } + rollback := OverrideEnv(t, "GCE_METADATA_HOST", strings.TrimPrefix(ts.URL, "http://")) + t.Cleanup(rollback) + + return ts +} + +func TestWatchConfig(t *testing.T) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, `{"project":{"numericProjectID":12345,"projectId":"projectId","attributes":{"osconfig-endpoint":"bad!!1","enable-os-inventory":"false"}},"instance":{"id":12345,"name":"name","zone":"zone","attributes":{"osconfig-endpoint":"SvcEndpoint","enable-os-inventory":"1","enable-os-config-debug":"true","osconfig-enabled-prerelease-features":"ospackage,ospatch", "osconfig-poll-interval":"3", "enable-scalibr-linux":"true", "trace-get-inventory":"true", "enable-guest-attributes":"true"}}}`) + }) if err := WatchConfig(context.Background()); err != nil { t.Fatalf("Error running WatchConfig: %v", err) @@ -47,12 +62,12 @@ func TestWatchConfig(t *testing.T) { op func() string want string }{ - {"SvcEndpoint", SvcEndpoint, "SvcEndpoint"}, - {"Instance", Instance, "zone/instances/name"}, - {"ID", ID, "12345"}, - {"ProjectID", ProjectID, "projectId"}, - {"Zone", Zone, "zone"}, - {"Name", Name, "name"}, + {desc: "SvcEndpoint", op: SvcEndpoint, want: "SvcEndpoint"}, + {desc: "Instance", op: Instance, want: "zone/instances/name"}, + {desc: "ID", op: ID, want: "12345"}, + {desc: "ProjectID", op: ProjectID, want: "projectId"}, + {desc: "Zone", op: Zone, want: "zone"}, + {desc: "Name", op: Name, want: "name"}, } for _, tt := range testsString { if tt.op() != tt.want { @@ -65,10 +80,13 @@ func TestWatchConfig(t *testing.T) { op func() bool want bool }{ - {"osinventory should be enabled (proj disabled, inst enabled)", OSInventoryEnabled, true}, - {"taskNotification should be enabled (inst enabled)", TaskNotificationEnabled, true}, - {"guestpolicies should be enabled (proj enabled)", GuestPoliciesEnabled, true}, - {"debugenabled should be true (proj disabled, inst enabled)", Debug, true}, + {desc: "osinventory should be enabled (proj disabled, inst enabled)", op: OSInventoryEnabled, want: true}, + {desc: "taskNotification should be enabled (inst enabled)", op: TaskNotificationEnabled, want: true}, + {desc: "guestpolicies should be enabled (proj enabled)", op: GuestPoliciesEnabled, want: true}, + {desc: "debugenabled should be true (proj disabled, inst enabled)", op: Debug, want: true}, + {desc: "scalibrLinuxEnabled should be true", op: ScalibrLinuxEnabled, want: true}, + {desc: "traceGetInventory should be true", op: TraceGetInventory, want: true}, + {desc: "guestAttributesEnabled should be true", op: GuestAttributesEnabled, want: true}, } for _, tt := range testsBool { if tt.op() != tt.want { @@ -90,7 +108,7 @@ func TestWatchConfig(t *testing.T) { func TestSetConfigEnabled(t *testing.T) { var request int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { switch request { case 0: w.Header().Set("Etag", "etag-0") @@ -105,12 +123,7 @@ func TestSetConfigEnabled(t *testing.T) { w.Header().Set("Etag", "etag-3") fmt.Fprintln(w, `{"project":{"attributes":{"enable-osconfig":"true","osconfig-disabled-features":"osinventory"}}}`) } - })) - defer ts.Close() - - if err := os.Setenv("GCE_METADATA_HOST", strings.Trim(ts.URL, "http://")); err != nil { - t.Fatalf("Error running os.Setenv: %v", err) - } + }) for i, want := range []bool{false, true, false} { request = i @@ -122,9 +135,9 @@ func TestSetConfigEnabled(t *testing.T) { desc string op func() bool }{ - {"OSInventoryEnabled", OSInventoryEnabled}, - {"TaskNotificationEnabled", TaskNotificationEnabled}, - {"GuestPoliciesEnabled", GuestPoliciesEnabled}, + {desc: "OSInventoryEnabled", op: OSInventoryEnabled}, + {desc: "TaskNotificationEnabled", op: TaskNotificationEnabled}, + {desc: "GuestPoliciesEnabled", op: GuestPoliciesEnabled}, } for _, tt := range testsBool { if tt.op() != want { @@ -143,9 +156,9 @@ func TestSetConfigEnabled(t *testing.T) { op func() bool want bool }{ - {"OSInventoryEnabled", OSInventoryEnabled, false}, - {"TaskNotificationEnabled", TaskNotificationEnabled, true}, - {"GuestPoliciesEnabled", GuestPoliciesEnabled, true}, + {desc: "OSInventoryEnabled", op: OSInventoryEnabled, want: false}, + {desc: "TaskNotificationEnabled", op: TaskNotificationEnabled, want: true}, + {desc: "GuestPoliciesEnabled", op: GuestPoliciesEnabled, want: true}, } for _, tt := range testsBool { if tt.op() != tt.want { @@ -155,16 +168,11 @@ func TestSetConfigEnabled(t *testing.T) { } func TestSetConfigDefaultValues(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Etag", "sample-etag") // we always get zone value in instance metadata. fmt.Fprintln(w, `{"instance": {"zone": "fake-zone"}}`) - })) - defer ts.Close() - - if err := os.Setenv("GCE_METADATA_HOST", strings.Trim(ts.URL, "http://")); err != nil { - t.Fatalf("Error running os.Setenv: %v", err) - } + }) if err := WatchConfig(context.Background()); err != nil { t.Fatalf("Error running SetConfig: %v", err) @@ -174,10 +182,19 @@ func TestSetConfigDefaultValues(t *testing.T) { op func() string want string }{ - {AptRepoFilePath, aptRepoFilePath}, - {YumRepoFilePath, yumRepoFilePath}, - {ZypperRepoFilePath, zypperRepoFilePath}, - {GooGetRepoFilePath, googetRepoFilePath}, + {op: AptRepoFilePath, want: aptRepoFilePath}, + {op: YumRepoFilePath, want: yumRepoFilePath}, + {op: ZypperRepoFilePath, want: zypperRepoFilePath}, + {op: GooGetRepoFilePath, want: googetRepoFilePath}, + {op: ZypperRepoDir, want: zypperRepoDir}, + {op: ZypperRepoFormat, want: filepath.Join(zypperRepoDir, "osconfig_managed_%s.repo")}, + {op: YumRepoDir, want: yumRepoDir}, + {op: YumRepoFormat, want: filepath.Join(yumRepoDir, "osconfig_managed_%s.repo")}, + {op: AptRepoDir, want: aptRepoDir}, + {op: AptRepoFormat, want: filepath.Join(aptRepoDir, "osconfig_managed_%s.list")}, + {op: GooGetRepoDir, want: googetRepoDir}, + {op: GooGetRepoFormat, want: filepath.Join(googetRepoDir, "osconfig_managed_%s.repo")}, + {op: UniverseDomain, want: universeDomainDefault}, } for _, tt := range testsString { if tt.op() != tt.want { @@ -190,10 +207,10 @@ func TestSetConfigDefaultValues(t *testing.T) { op func() bool want bool }{ - {OSInventoryEnabled, osInventoryEnabledDefault}, - {TaskNotificationEnabled, taskNotificationEnabledDefault}, - {GuestPoliciesEnabled, guestPoliciesEnabledDefault}, - {Debug, debugEnabledDefault}, + {op: OSInventoryEnabled, want: osInventoryEnabledDefault}, + {op: TaskNotificationEnabled, want: taskNotificationEnabledDefault}, + {op: GuestPoliciesEnabled, want: guestPoliciesEnabledDefault}, + {op: Debug, want: debugEnabledDefault}, } for _, tt := range testsBool { if tt.op() != tt.want { @@ -212,6 +229,129 @@ func TestSetConfigDefaultValues(t *testing.T) { } } +// TestWatchConfigUnchangedConfigTimeout tests how the agent behaves when it receives +// updates from the metadata server, but the actual configuration data hasn't changed. +// +// The agent checks the SHA256 hash of the new data. If the hash is identical to +// the current configuration, it knows the update is superficial. Instead of +// applying the configuration and exiting, the agent should ignore the update and +// keep polling for real changes. This test verifies that the agent correctly +// continues to wait until its internal timeout runs out, and then exits normally. +func TestWatchConfigUnchangedConfigTimeout(t *testing.T) { + defer OverrideWatchConfigTimeouts(1*time.Millisecond, 10*time.Millisecond)() + + var count int + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { + count++ + w.Header().Set("Etag", fmt.Sprintf("etag-%d", count)) + w.Header().Set("Metadata-Flavor", "Google") + // Return exactly the same config on every request so asSha256() matches + fmt.Fprint(w, `{}`) + }) + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + err := WatchConfig(ctx) + if err != nil { + t.Errorf("Expected nil error on timeout, got: %v", err) + } + if ctx.Err() != nil { + t.Errorf("Test context timed out before internal timeout fired: %v", ctx.Err()) + } +} + +// TestWatchConfigWebErrorLimit tests how WatchConfig handles network errors when it +// can't reach the metadata server. The test creates a situation where the agent +// can't connect to the server and checks that the agent retries the connection +// up to a limit of 12 times before giving up and reporting an error. +func TestWatchConfigWebErrorLimit(t *testing.T) { + lEtag.set("0") + defer OverrideWatchConfigTimeouts(1*time.Millisecond, 1*time.Second)() + defer OverrideEnv(t, "GCE_METADATA_HOST", "mock-host")() + + mockNetErr := &net.OpError{ + Op: "dial", + Net: "tcp", + Err: errors.New("connection refused"), + } + defer MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { + return nil, mockNetErr + })() + + err := WatchConfig(context.Background()) + if err == nil { + t.Fatal("Expected network error, got nil") + } + + expectedBaseErr := &url.Error{ + Op: "Get", + URL: "http://mock-host/computeMetadata/v1/?recursive=true&alt=json&wait_for_change=true&last_etag=0&timeout_sec=60", + Err: mockNetErr, + } + expectedErr := fmt.Errorf("network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w", expectedBaseErr) + if err.Error() != expectedErr.Error() { + t.Errorf("Expected exact error:\n%q\nGot:\n%q", expectedErr.Error(), err.Error()) + } +} + +// TestWatchConfigUnmarshalErrorLimit tests how WatchConfig handles bad or incomplete +// data from the metadata server. The test gives the agent a broken configuration +// response and verifies that the agent tries to read it again up to a limit of 3 +// times before it stops and reports an error. +func TestWatchConfigUnmarshalErrorLimit(t *testing.T) { + defer OverrideWatchConfigTimeouts(1*time.Millisecond, 1*time.Second)() + + badJSON := []byte(`{"bad json"`) + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Etag", fmt.Sprintf("unmarshal-error-etag-%d", time.Now().UnixNano())) + w.Header().Set("Metadata-Flavor", "Google") + w.Write(badJSON) + }) + + err := WatchConfig(context.Background()) + if err == nil { + t.Fatal("Expected unmarshal error, got nil") + } + + var dummy metadataJSON + expectedErr := json.Unmarshal(badJSON, &dummy) + if err.Error() != expectedErr.Error() { + t.Errorf("Expected exact error:\n%q\nGot:\n%q", expectedErr.Error(), err.Error()) + } +} + +// TestWatchConfigContextCancel tests that the WatchConfig function can be stopped +// correctly. It checks that if another part of the program tells WatchConfig to +// cancel, it stops immediately without waiting for a timeout or retrying failed +// requests. +func TestWatchConfigContextCancel(t *testing.T) { + defer OverrideWatchConfigTimeouts(1*time.Minute, 1*time.Minute)() + + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Etag", fmt.Sprintf("cancel-etag-%d", time.Now().UnixNano())) + w.Header().Set("Metadata-Flavor", "Google") + fmt.Fprint(w, `{"bad json"`) // Trigger unmarshal error loop which checks context + }) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() // Cancel the context immediately prior to passing it in + + if err := WatchConfig(ctx); err != nil { + t.Errorf("Expected nil error on context cancellation, got: %v", err) + } +} + +func TestSetConfigError(t *testing.T) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) {}) + + osConfigWatchConfigTimeout = 1 * time.Millisecond + + if err := WatchConfig(context.Background()); err == nil || !strings.Contains(err.Error(), "unexpected end of JSON input") { + t.Errorf("Unexpected output %+v", err) + } +} + func TestVersion(t *testing.T) { if Version() != "" { t.Errorf("Unexpected version %q, want \"\"", Version()) @@ -223,9 +363,590 @@ func TestVersion(t *testing.T) { } } +// TestLoggingFlags tests logging setting accessors against command-line flags. +func TestLoggingFlags(t *testing.T) { + origStdout := *stdout + origDisableLocalLogging := *disableLocalLogging + defer func() { + *stdout = origStdout + *disableLocalLogging = origDisableLocalLogging + }() + + *stdout = true + *disableLocalLogging = true + if !Stdout() { + t.Errorf("Stdout() = false, want true") + } + if !DisableLocalLogging() { + t.Errorf("DisableLocalLogging() = false, want true") + } + + *stdout = false + *disableLocalLogging = false + if Stdout() { + t.Errorf("Stdout() = true, want false") + } + if DisableLocalLogging() { + t.Errorf("DisableLocalLogging() = true, want false") + } +} + +// TestLogFeatures tests that feature status logging executes without panicking. +func TestLogFeatures(t *testing.T) { + LogFeatures(context.Background()) +} + +// TestIDToken tests getting and understanding the instance identity token from the +// metadata server. It checks valid tokens, caching behavior, and error handling +// (e.g. HTTP 500 or malformed tokens). +func TestIDToken(t *testing.T) { + // Create a valid dummy JWS token + // Header: {"alg":"RS256","typ":"JWT"} -> eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9 + // Payload: {"exp": 4102444800} (January 1, 2100 00:00:00 UTC) -> eyJleHAiOiA0MTAyNDQ0ODAwfQ + // Signature: dummy -> ZHVtbXk + validToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJleHAiOiA0MTAyNDQ0ODAwfQ.ZHVtbXk" + + // Create a token that expires in 5 minutes to test caching fallback. + // The agent re-requests the token if the expiry is within 10 minutes. + expTime := time.Now().Add(5 * time.Minute).Unix() + payload := fmt.Sprintf(`{"exp": %d}`, expTime) + payloadB64 := base64.RawURLEncoding.EncodeToString([]byte(payload)) + expiringToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9." + payloadB64 + ".ZHVtbXk" + + tests := []struct { + name string + handler http.HandlerFunc + numCalls int + wantToken string + wantErr bool + wantRequests int + }{ + { + name: "Valid token with caching", + handler: func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/computeMetadata/v1/instance/service-accounts/default/identity") { + w.Header().Set("Metadata-Flavor", "Google") + fmt.Fprint(w, validToken) + return + } + http.NotFound(w, r) + }, + numCalls: 2, + wantToken: validToken, + wantErr: false, + wantRequests: 1, // Only 1 request should be made due to caching + }, + { + name: "Expiring token forces re-fetch", + handler: func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/computeMetadata/v1/instance/service-accounts/default/identity") { + w.Header().Set("Metadata-Flavor", "Google") + fmt.Fprint(w, expiringToken) + return + } + http.NotFound(w, r) + }, + numCalls: 2, + wantToken: expiringToken, + wantErr: false, + wantRequests: 2, // Token is within 10m of expiry, should trigger a fetch on every call + }, + { + name: "HTTP 500 error", + handler: func(w http.ResponseWriter, r *http.Request) { + http.Error(w, "internal error", http.StatusInternalServerError) + }, + numCalls: 1, + wantErr: true, + // The compute/metadata client library automatically retries on 500 errors (1 initial + 5 retries). + wantRequests: 6, + }, + { + name: "Malformed token", + handler: func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Metadata-Flavor", "Google") + fmt.Fprint(w, "not.a.valid.token") + }, + numCalls: 1, + wantErr: true, + wantRequests: 1, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var requests int + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { + requests++ + tt.handler(w, r) + }) + + identity = idToken{} + + var token string + var err error + for i := 0; i < tt.numCalls; i++ { + token, err = IDToken() + } + + if (err != nil) != tt.wantErr { + t.Fatalf("IDToken() error = %v, wantErr %v", err, tt.wantErr) + } + if err == nil && token != tt.wantToken { + t.Errorf("IDToken() = %q, want %q", token, tt.wantToken) + } + if requests != tt.wantRequests { + t.Errorf("Expected %d HTTP requests, got %d", tt.wantRequests, requests) + } + }) + } +} + +// TestFormatMetadataError verifies that network and DNS errors are wrapped with helpful context. +func TestFormatMetadataError(t *testing.T) { + errStandard := fmt.Errorf("standard error") + errDNS := &url.Error{Err: &net.DNSError{Err: "no such host"}} + errNet := &url.Error{Err: &net.OpError{Op: "dial", Net: "tcp"}} + + tests := []struct { + name string + inputErr error + wantExact error + wantContain string + }{ + { + name: "standard error", + inputErr: errStandard, + wantExact: errStandard, + }, + { + name: "DNS error", + inputErr: errDNS, + wantContain: "DNS error when requesting metadata", + }, + { + name: "network error", + inputErr: errNet, + wantContain: "network error when requesting metadata", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := formatMetadataError(tt.inputErr) + if tt.wantExact != nil && got != tt.wantExact { + t.Errorf("formatMetadataError() = %v, want exact %v", got, tt.wantExact) + } + if tt.wantContain != "" && !strings.Contains(got.Error(), tt.wantContain) { + t.Errorf("formatMetadataError() = %v, want to contain %q", got, tt.wantContain) + } + }) + } +} + +// TestGetMetadata verifies successful and error responses from the metadata server. +func TestGetMetadata(t *testing.T) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/computeMetadata/v1/test-success" { + w.Header().Set("Etag", "test-etag") + fmt.Fprint(w, "success") + return + } + if r.URL.Path == "/computeMetadata/v1/test-404" { + http.NotFound(w, r) + return + } + http.Error(w, "internal error", http.StatusInternalServerError) + }) + + tests := []struct { + name string + suffix string + wantBody string + wantEtag string + wantNil bool + }{ + { + name: "success", + suffix: "test-success", + wantBody: "success", + wantEtag: "test-etag", + }, + { + name: "404 not found", + suffix: "test-404", + wantNil: true, + }, + { + name: "500 internal server error", + suffix: "test-500", + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + body, etag, err := getMetadata(tt.suffix) + if err != nil { + t.Errorf("getMetadata(%q) error: %v", tt.suffix, err) + } + if tt.wantNil { + if body != nil || etag != "" { + t.Errorf("getMetadata(%q) expected nil body and empty etag, got %q, %q", tt.suffix, body, etag) + } + } else { + if string(body) != tt.wantBody { + t.Errorf("getMetadata(%q) body = %q, want %q", tt.suffix, body, tt.wantBody) + } + if etag != tt.wantEtag { + t.Errorf("getMetadata(%q) etag = %q, want %q", tt.suffix, etag, tt.wantEtag) + } + } + }) + } +} + +// TestGetMetadataFallback verifies fallback to the default metadata IP address. +func TestGetMetadataFallback(t *testing.T) { + defer UnsetEnv(t, metadataHostEnv)() + + var requestedURL string + defer MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { + requestedURL = req.URL.String() + return &http.Response{StatusCode: 200, Body: ioutil.NopCloser(strings.NewReader("mock response"))}, nil + })() + + _, _, err := getMetadata("test-suffix") + if err != nil { + t.Fatalf("getMetadata error: %v", err) + } + + expected := "http://" + metadataIP + "/computeMetadata/v1/test-suffix" + if requestedURL != expected { + t.Errorf("getMetadata requested %q, want %q", requestedURL, expected) + } +} + +// TestGetMetadataErrors verifies request and network error handling in getMetadata. +func TestGetMetadataErrors(t *testing.T) { + tests := []struct { + name string + suffix string + mockTransport func(t *testing.T) (rollback func()) + wantErrContain string + }{ + { + name: "http.NewRequest error (bad control char in URL)", + suffix: "suffix\x7f", + wantErrContain: "invalid control character in URL", + }, + { + name: "client.Do error", + suffix: "test-suffix", + mockTransport: func(t *testing.T) func() { + return MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { + return nil, fmt.Errorf("mock dial error") + }) + }, + wantErrContain: "mock dial error", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.mockTransport != nil { + t.Cleanup(tt.mockTransport(t)) + } + _, _, err := getMetadata(tt.suffix) + if err == nil || !strings.Contains(err.Error(), tt.wantErrContain) { + t.Errorf("getMetadata() error = %v, want error containing %q", err, tt.wantErrContain) + } + }) + } +} + +// TestConfigSha256 verifies that equivalent configurations produce the same SHA256 signature. +func TestConfigSha256(t *testing.T) { + c1 := &config{projectID: "test-project", osInventoryEnabled: true} + c2 := &config{projectID: "test-project", osInventoryEnabled: true} + c3 := &config{projectID: "test-project", osInventoryEnabled: false} + + if c1.asSha256() != c2.asSha256() { + t.Errorf("Expected identical configs to have same SHA256") + } + if c1.asSha256() == c3.asSha256() { + t.Errorf("Expected different configs to have different SHA256") + } +} + +// TestLastEtag tests concurrent read and write access to the lastEtag tracker. +func TestLastEtag(t *testing.T) { + le := &lastEtag{Etag: "initial"} + var wg sync.WaitGroup + + // Run concurrent gets and sets to ensure no race conditions + for i := 0; i < 100; i++ { + wg.Add(1) + go func(val int) { + defer wg.Done() + le.set(fmt.Sprintf("etag-%d", val)) + _ = le.get() + }(i) + } + wg.Wait() + + if le.get() == "" { + t.Errorf("Expected non-empty etag") + } +} + +// TestSystemPaths verifies OS-specific system path generation. +func TestSystemPaths(t *testing.T) { + origGOOS := goos + defer func() { goos = origGOOS }() + + tests := []struct { + name string + op func() string + want map[string]string + }{ + { + name: "TaskStateFile", + op: TaskStateFile, + want: map[string]string{"windows": filepath.Join(GetCacheDirWindows(), "osconfig_task.state"), "linux": taskStateFileLinux}, + }, + { + name: "OldTaskStateFile", + op: OldTaskStateFile, + want: map[string]string{"windows": oldTaskStateFileWindows, "linux": oldTaskStateFileLinux}, + }, + { + name: "RestartFile", + op: RestartFile, + want: map[string]string{"windows": filepath.Join(GetCacheDirWindows(), "osconfig_agent_restart_required"), "linux": restartFileLinux}, + }, + { + name: "OldRestartFile", + op: OldRestartFile, + want: map[string]string{"windows": oldRestartFileLinux, "linux": oldRestartFileLinux}, + }, + { + name: "CacheDir", + op: CacheDir, + want: map[string]string{"windows": GetCacheDirWindows(), "linux": cacheDirLinux}, + }, + { + name: "SerialLogPort", + op: SerialLogPort, + want: map[string]string{"windows": "COM1", "linux": ""}, + }, + } + + for _, tt := range tests { + for _, testOS := range []string{"windows", "linux"} { + t.Run(fmt.Sprintf("%s_%s", tt.name, testOS), func(t *testing.T) { + goos = testOS + if got := tt.op(); got != tt.want[testOS] { + t.Errorf("%s() on %s = %v, want %v", tt.name, testOS, got, tt.want[testOS]) + } + }) + } + } +} + +// TestMiscGetters verifies static getter function outputs. +func TestMiscGetters(t *testing.T) { + SetVersion("1.2.3") + + tests := []struct { + name string + got interface{} + want interface{} + }{ + {name: "Capabilities", got: Capabilities(), want: []string{"PATCH_GA", "GUEST_POLICY_BETA", "CONFIG_V1"}}, + {name: "UserAgent", got: UserAgent(), want: "google-osconfig-agent/1.2.3"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if !reflect.DeepEqual(tt.got, tt.want) { + t.Errorf("%s() = %v, want %v", tt.name, tt.got, tt.want) + } + }) + } +} + +// TestCreateConfigFromMetadata tests that the agent's configuration is correctly +// built from the data it gets from the metadata server. The test checks how +// various settings are read, how instance-level settings take priority over +// project-level ones, and how command-line flags can override any metadata setting. +func TestCreateConfigFromMetadata(t *testing.T) { + // Reset the global agent config to avoid test cross-contamination + agentConfigMx.Lock() + agentConfig = &config{} + agentConfigMx.Unlock() + + pollInt15 := json.Number("15") + pollInt20 := json.Number("20") + id98765 := json.Number("98765") + + tests := []struct { + name string + md metadataJSON + setDebugFlag bool + want *config + }{ + { + name: "default values", + md: metadataJSON{}, + want: &config{ + osInventoryEnabled: osInventoryEnabledDefault, + guestPoliciesEnabled: guestPoliciesEnabledDefault, + taskNotificationEnabled: taskNotificationEnabledDefault, + debugEnabled: debugEnabledDefault, + svcEndpoint: strings.ReplaceAll(prodEndpoint, "{zone}", ""), + osConfigPollInterval: osConfigPollIntervalDefault, + googetRepoFilePath: googetRepoFilePath, + zypperRepoFilePath: zypperRepoFilePath, + yumRepoFilePath: yumRepoFilePath, + aptRepoFilePath: aptRepoFilePath, + universeDomain: universeDomainDefault, + }, + }, + { + name: "project level debug and numeric poll interval", + md: metadataJSON{ + Project: projectJSON{ + ProjectID: "proj-1", + Attributes: attributesJSON{ + LogLevel: "debug", + PollInterval: &pollInt15, + OSConfigEnabled: "true", + }, + }, + }, + want: &config{ + projectID: "proj-1", + osInventoryEnabled: true, + guestPoliciesEnabled: true, + taskNotificationEnabled: true, + debugEnabled: true, + svcEndpoint: strings.ReplaceAll(prodEndpoint, "{zone}", ""), + osConfigPollInterval: 15, + googetRepoFilePath: googetRepoFilePath, + zypperRepoFilePath: zypperRepoFilePath, + yumRepoFilePath: yumRepoFilePath, + aptRepoFilePath: aptRepoFilePath, + universeDomain: universeDomainDefault, + }, + }, + { + name: "instance level overrides project level", + md: metadataJSON{ + Project: projectJSON{ + ProjectID: "proj-1", + Attributes: attributesJSON{ + LogLevel: "info", + PollInterval: &pollInt15, + OSConfigEnabled: "true", + }, + }, + Instance: instanceJSON{ + Attributes: attributesJSON{ + LogLevel: "debug", + PollInterval: &pollInt20, + OSConfigEnabled: "false", + }, + }, + }, + want: &config{ + projectID: "proj-1", + osInventoryEnabled: false, + guestPoliciesEnabled: false, + taskNotificationEnabled: false, + debugEnabled: true, + svcEndpoint: strings.ReplaceAll(prodEndpoint, "{zone}", ""), + osConfigPollInterval: 20, + googetRepoFilePath: googetRepoFilePath, + zypperRepoFilePath: zypperRepoFilePath, + yumRepoFilePath: yumRepoFilePath, + aptRepoFilePath: aptRepoFilePath, + universeDomain: universeDomainDefault, + }, + }, + { + name: "legacy poll interval and disabled features", + md: metadataJSON{ + Project: projectJSON{ + Attributes: attributesJSON{ + PollIntervalOld: &pollInt15, + }, + }, + Instance: instanceJSON{ + ID: &id98765, + Attributes: attributesJSON{ + OSConfigEnabled: "true", + DisabledFeatures: "osinventory, guestpolicies", + }, + }, + }, + want: &config{ + instanceID: "98765", + osInventoryEnabled: false, + guestPoliciesEnabled: false, + taskNotificationEnabled: true, + debugEnabled: debugEnabledDefault, + svcEndpoint: strings.ReplaceAll(prodEndpoint, "{zone}", ""), + osConfigPollInterval: 15, + googetRepoFilePath: googetRepoFilePath, + zypperRepoFilePath: zypperRepoFilePath, + yumRepoFilePath: yumRepoFilePath, + aptRepoFilePath: aptRepoFilePath, + universeDomain: universeDomainDefault, + }, + }, + { + name: "debug flag overrides metadata", + md: metadataJSON{ + Project: projectJSON{ + Attributes: attributesJSON{ + LogLevel: "info", + }, + }, + }, + setDebugFlag: true, + want: &config{ + osInventoryEnabled: osInventoryEnabledDefault, + guestPoliciesEnabled: guestPoliciesEnabledDefault, + taskNotificationEnabled: taskNotificationEnabledDefault, + debugEnabled: true, + svcEndpoint: strings.ReplaceAll(prodEndpoint, "{zone}", ""), + osConfigPollInterval: osConfigPollIntervalDefault, + googetRepoFilePath: googetRepoFilePath, + zypperRepoFilePath: zypperRepoFilePath, + yumRepoFilePath: yumRepoFilePath, + aptRepoFilePath: aptRepoFilePath, + universeDomain: universeDomainDefault, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + origDebug := *debug + *debug = tt.setDebugFlag + got := createConfigFromMetadata(tt.md) + *debug = origDebug + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("createConfigFromMetadata() = %+v, want %+v", got, tt.want) + } + }) + } +} + func TestSvcEndpoint(t *testing.T) { var request int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { switch request { case 0: w.Header().Set("Etag", "etag-0") @@ -235,12 +956,7 @@ func TestSvcEndpoint(t *testing.T) { w.Header().Set("Etag", "etag-1") fmt.Fprintln(w, `{"universe": {"universeDomain": "domain.com"}, "instance": {"id": 12345,"name": "name","zone": "fakezone","attributes": {"osconfig-endpoint": "{zone}-dev.osconfig.googleapis.com"}}}`) } - })) - defer ts.Close() - - if err := os.Setenv("GCE_METADATA_HOST", strings.Trim(ts.URL, "http://")); err != nil { - t.Fatalf("Error running os.Setenv: %v", err) - } + }) for i, expectedSvcEndpoint := range []string{"fakezone-dev.osconfig.googleapis.com", "fakezone-dev.osconfig.domain.com"} { request = i @@ -255,25 +971,9 @@ func TestSvcEndpoint(t *testing.T) { } -func TestSetConfigError(t *testing.T) { - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - })) - defer ts.Close() - - if err := os.Setenv("GCE_METADATA_HOST", strings.Trim(ts.URL, "http://")); err != nil { - t.Fatalf("Error running os.Setenv: %v", err) - } - - osConfigWatchConfigTimeout = 1 * time.Millisecond - - if err := WatchConfig(context.Background()); err == nil || !strings.Contains(err.Error(), "unexpected end of JSON input") { - t.Errorf("Unexpected output %+v", err) - } -} - func TestDisableCloudLogging(t *testing.T) { var request int - ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { switch request { case 0: w.Header().Set("Etag", "etag-0") @@ -282,12 +982,7 @@ func TestDisableCloudLogging(t *testing.T) { w.Header().Set("Etag", "etag-1") fmt.Fprintln(w, `{"instance": {"zone": "fake-zone"}}`) } - })) - defer ts.Close() - - if err := os.Setenv("GCE_METADATA_HOST", strings.Trim(ts.URL, "http://")); err != nil { - t.Fatalf("Error running os.Setenv: %v", err) - } + }) for i, expectedDisableCloudLoggingValue := range []bool{true, false} { request = i @@ -301,3 +996,392 @@ func TestDisableCloudLogging(t *testing.T) { } } + +// TestSetScalibrEnablement tests Scalibr enablement flag extraction from metadata. +func TestSetScalibrEnablement(t *testing.T) { + tests := []struct { + name string + projVal string + instVal string + want bool + }{ + {name: "Both empty", projVal: "", instVal: "", want: false}, + {name: "Project true", projVal: "true", instVal: "", want: true}, + {name: "Project false", projVal: "false", instVal: "", want: false}, + {name: "Instance true", projVal: "", instVal: "true", want: true}, + {name: "Instance overrides project", projVal: "false", instVal: "true", want: true}, + {name: "Instance overrides project (false)", projVal: "true", instVal: "false", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &config{} + md := metadataJSON{ + Project: projectJSON{Attributes: attributesJSON{ScalibrLinuxEnabled: tt.projVal}}, + Instance: instanceJSON{Attributes: attributesJSON{ScalibrLinuxEnabled: tt.instVal}}, + } + setScalibrEnablement(md, c) + if c.scalibrLinuxEnabled != tt.want { + t.Errorf("setScalibrEnablement() = %v, want %v", c.scalibrLinuxEnabled, tt.want) + } + }) + } +} + +// TestSetTraceGetInventory tests the inventory tracing flag extraction from metadata. +func TestSetTraceGetInventory(t *testing.T) { + tests := []struct { + name string + projVal string + instVal string + want bool + }{ + {name: "Both empty", projVal: "", instVal: "", want: false}, + {name: "Project true", projVal: "true", instVal: "", want: true}, + {name: "Project false", projVal: "false", instVal: "", want: false}, + {name: "Instance true", projVal: "", instVal: "true", want: true}, + {name: "Instance overrides project", projVal: "false", instVal: "true", want: true}, + {name: "Instance overrides project (false)", projVal: "true", instVal: "false", want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := &config{} + md := metadataJSON{ + Project: projectJSON{Attributes: attributesJSON{TraceGetInventory: tt.projVal}}, + Instance: instanceJSON{Attributes: attributesJSON{TraceGetInventory: tt.instVal}}, + } + setTraceGetInventory(md, c) + if c.traceGetInventory != tt.want { + t.Errorf("setTraceGetInventory() = %v, want %v", c.traceGetInventory, tt.want) + } + }) + } +} + +// TestSetSVCEndpoint tests the logic for figuring out which OS Config service +// endpoint to use. It checks that command-line flags have the highest priority, +// that placeholders like `{zone}` are filled in correctly, and that the endpoint +// is adjusted for different universe domains. +func TestSetSVCEndpoint(t *testing.T) { + origEndpoint := *endpoint + defer func() { *endpoint = origEndpoint }() + + tests := []struct { + name string + flag string + instNew string + instOld string + projNew string + projOld string + universe string + instanceZone string + want string + }{ + { + name: "Default (all empty)", + flag: prodEndpoint, + instanceZone: "projects/123/zones/us-west1-a", + want: "us-west1-a-osconfig.googleapis.com.:443", + }, + { + name: "Flag overrides all", + flag: "custom-endpoint", + instNew: "inst-new", + want: "custom-endpoint", + }, + { + name: "Instance New", + flag: prodEndpoint, + instNew: "inst-new-{zone}", + instanceZone: "projects/123/zones/us-west1-a", + want: "inst-new-us-west1-a", + }, + { + name: "Instance Old fallback", + flag: prodEndpoint, + instOld: "inst-old-{zone}", + instanceZone: "projects/123/zones/us-west1-a", + want: "inst-old-us-west1-a", + }, + { + name: "Project New fallback", + flag: prodEndpoint, + projNew: "proj-new-{zone}", + instanceZone: "projects/123/zones/us-west1-a", + want: "proj-new-us-west1-a", + }, + { + name: "Project Old fallback", + flag: prodEndpoint, + projOld: "proj-old-{zone}", + instanceZone: "projects/123/zones/us-west1-a", + want: "proj-old-us-west1-a", + }, + { + name: "Universe Domain replacement", + flag: prodEndpoint, + instNew: "test-osconfig.googleapis.com", + universe: "my-universe.com", + want: "test-osconfig.my-universe.com", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + *endpoint = tt.flag + c := &config{ + instanceZone: tt.instanceZone, + svcEndpoint: prodEndpoint, + } + if tt.universe != "" { + c.universeDomain = tt.universe + } else { + c.universeDomain = universeDomainDefault + } + + md := metadataJSON{ + Project: projectJSON{ + Attributes: attributesJSON{ + OSConfigEndpoint: tt.projNew, + OSConfigEndpointOld: tt.projOld, + }, + }, + Instance: instanceJSON{ + Attributes: attributesJSON{ + OSConfigEndpoint: tt.instNew, + OSConfigEndpointOld: tt.instOld, + }, + }, + } + + setSVCEndpoint(md, c) + if c.svcEndpoint != tt.want { + t.Errorf("setSVCEndpoint() = %v, want %v", c.svcEndpoint, tt.want) + } + }) + } +} + +// TestGetCacheDirWindows verifies primary and fallback cache directory resolution on Windows. +func TestGetCacheDirWindows(t *testing.T) { + tests := []struct { + name string + setup func(t *testing.T) + want string + checkSuffix bool + }{ + { + name: "Standard call", + setup: func(t *testing.T) { /* no-op */ }, + want: windowsCacheDir, + checkSuffix: true, + }, + { + name: "Fallback to TempDir", + setup: func(t *testing.T) { + // Test fallback by unsetting the HOME, AppData, and XDG environment variables + // that os.UserCacheDir relies on to generate paths. + envs := []string{"HOME", "LocalAppData", "XDG_CACHE_HOME"} + for _, env := range envs { + t.Cleanup(UnsetEnv(t, env)) + } + }, + want: filepath.Join(os.TempDir(), windowsCacheDir), + checkSuffix: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.setup(t) + got := GetCacheDirWindows() + if tt.checkSuffix { + if !strings.HasSuffix(got, tt.want) { + t.Errorf("GetCacheDirWindows() = %q, want suffix %q", got, tt.want) + } + } else if got != tt.want { + t.Errorf("GetCacheDirWindows() = %q, want %q", got, tt.want) + } + }) + } +} + +// TestFlagsAndEnvVars verifies parsing of environment variables. +func TestFlagsAndEnvVars(t *testing.T) { + origFreeOSMemory := freeOSMemory + origDisableInventoryWrite := disableInventoryWrite + defer func() { + freeOSMemory = origFreeOSMemory + disableInventoryWrite = origDisableInventoryWrite + }() + + tests := []struct { + name string + freeOSMemoryVal string + disableInventoryWrite string + wantFreeOS bool + wantDisableInv bool + }{ + {name: "Both True", freeOSMemoryVal: "true", disableInventoryWrite: "1", wantFreeOS: true, wantDisableInv: true}, + {name: "Both False", freeOSMemoryVal: "false", disableInventoryWrite: "0", wantFreeOS: false, wantDisableInv: false}, + {name: "Empty", freeOSMemoryVal: "", disableInventoryWrite: "", wantFreeOS: false, wantDisableInv: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + freeOSMemory = tt.freeOSMemoryVal + disableInventoryWrite = tt.disableInventoryWrite + + if got := FreeOSMemory(); got != tt.wantFreeOS { + t.Errorf("FreeOSMemory() = %v, want %v", got, tt.wantFreeOS) + } + if got := DisableInventoryWrite(); got != tt.wantDisableInv { + t.Errorf("DisableInventoryWrite() = %v, want %v", got, tt.wantDisableInv) + } + }) + } +} + +// TestParseBool verifies string-to-boolean conversion logic. +func TestParseBool(t *testing.T) { + tests := []struct { + input string + want bool + }{ + {input: "true", want: true}, + {input: "1", want: true}, + {input: "false", want: false}, + {input: "0", want: false}, + {input: "invalid", want: false}, + } + + for _, tt := range tests { + if got := parseBool(tt.input); got != tt.want { + t.Errorf("parseBool(%q) = %v, want %v", tt.input, got, tt.want) + } + } +} + +// TestParseFeatures verifies comma-separated feature flag parsing. +func TestParseFeatures(t *testing.T) { + tests := []struct { + name string + initial config + features string + enabled bool + want config + }{ + { + name: "enabling features", + initial: config{}, + features: "tasks, ospackage, osinventory, unknown", + enabled: true, + want: config{ + taskNotificationEnabled: true, + guestPoliciesEnabled: true, + osInventoryEnabled: true, + }, + }, + { + name: "disabling features (using legacy names as well)", + initial: config{ + taskNotificationEnabled: true, + guestPoliciesEnabled: true, + osInventoryEnabled: true, + }, + features: "ospatch, guestpolicies", + enabled: false, + want: config{ + taskNotificationEnabled: false, + guestPoliciesEnabled: false, + osInventoryEnabled: true, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := tt.initial + c.parseFeatures(tt.features, tt.enabled) + + if c != tt.want { + t.Errorf("parseFeatures() state = %+v, want %+v", c, tt.want) + } + }) + } +} + +type roundTripperFunc func(*http.Request) (*http.Response, error) + +func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return f(req) +} + +// OverrideEnv sets an environment variable for the duration of a test and returns a rollback function to restore its original state. +func OverrideEnv(t *testing.T, env, value string) (rollback func()) { + orig, ok := os.LookupEnv(env) + rollback = func() { + if ok { + if err := os.Setenv(env, orig); err != nil { + t.Fatalf("Failed to restore environment variable %s: %v", env, err) + } + } else { + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } + } + } + + if err := os.Setenv(env, value); err != nil { + t.Fatalf("Failed to set environment variable %s: %v", env, err) + } + + return rollback +} + +// UnsetEnv unsets an environment variable for the duration of a test and returns a rollback function to restore its original state. +func UnsetEnv(t *testing.T, env string) (rollback func()) { + orig, ok := os.LookupEnv(env) + rollback = func() { + if ok { + if err := os.Setenv(env, orig); err != nil { + t.Fatalf("Failed to restore environment variable %s: %v", env, err) + } + } else { + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } + } + } + + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } + + return rollback +} + +// OverrideWatchConfigTimeouts temporarily overwrites the timeout and retry intervals for WatchConfig. +func OverrideWatchConfigTimeouts(interval, timeout time.Duration) (rollback func()) { + origInterval := watchConfigRetryInterval + origTimeout := osConfigWatchConfigTimeout + + watchConfigRetryInterval = interval + osConfigWatchConfigTimeout = timeout + return func() { + watchConfigRetryInterval = origInterval + osConfigWatchConfigTimeout = origTimeout + } +} + +// MockDefaultClientTransport temporarily replaces the defaultClient's transport with a custom round tripper. +func MockDefaultClientTransport(t *testing.T, roundTrip func(*http.Request) (*http.Response, error)) (rollback func()) { + origClient := defaultClient + defaultClient = &http.Client{ + Transport: roundTripperFunc(roundTrip), + } + return func() { + defaultClient = origClient + } +} From 6abe8ada089569128e9a8c53cbea50f287eed9a4 Mon Sep 17 00:00:00 2001 From: Ilya Tsupryk Date: Wed, 25 Mar 2026 17:26:57 +0000 Subject: [PATCH 5/5] Use utiltest helpers in new tests --- agentconfig/agentconfig_test.go | 155 +++++++++++--------------------- util/utiltest/utiltest.go | 48 +++++++++- 2 files changed, 95 insertions(+), 108 deletions(-) diff --git a/agentconfig/agentconfig_test.go b/agentconfig/agentconfig_test.go index 5fdd6dcf3..c09f31c45 100644 --- a/agentconfig/agentconfig_test.go +++ b/agentconfig/agentconfig_test.go @@ -33,6 +33,8 @@ import ( "sync" "testing" "time" + + "github.com/GoogleCloudPlatform/osconfig/util/utiltest" ) // setupMockMetadataServer starts an httptest.Server with the provided handler and overrides the GCE_METADATA_HOST environment variable. @@ -42,8 +44,7 @@ func setupMockMetadataServer(t *testing.T, handler http.HandlerFunc) *httptest.S ts := httptest.NewServer(handler) t.Cleanup(ts.Close) - rollback := OverrideEnv(t, "GCE_METADATA_HOST", strings.TrimPrefix(ts.URL, "http://")) - t.Cleanup(rollback) + utiltest.OverrideEnv(t, "GCE_METADATA_HOST", strings.TrimPrefix(ts.URL, "http://")) return ts } @@ -238,7 +239,7 @@ func TestSetConfigDefaultValues(t *testing.T) { // keep polling for real changes. This test verifies that the agent correctly // continues to wait until its internal timeout runs out, and then exits normally. func TestWatchConfigUnchangedConfigTimeout(t *testing.T) { - defer OverrideWatchConfigTimeouts(1*time.Millisecond, 10*time.Millisecond)() + OverrideWatchConfigTimeouts(t, 1*time.Millisecond, 10*time.Millisecond) var count int setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -267,17 +268,17 @@ func TestWatchConfigUnchangedConfigTimeout(t *testing.T) { // up to a limit of 12 times before giving up and reporting an error. func TestWatchConfigWebErrorLimit(t *testing.T) { lEtag.set("0") - defer OverrideWatchConfigTimeouts(1*time.Millisecond, 1*time.Second)() - defer OverrideEnv(t, "GCE_METADATA_HOST", "mock-host")() + OverrideWatchConfigTimeouts(t, 1*time.Millisecond, 1*time.Second) + utiltest.OverrideEnv(t, "GCE_METADATA_HOST", "mock-host") mockNetErr := &net.OpError{ Op: "dial", Net: "tcp", Err: errors.New("connection refused"), } - defer MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { + MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { return nil, mockNetErr - })() + }) err := WatchConfig(context.Background()) if err == nil { @@ -290,9 +291,7 @@ func TestWatchConfigWebErrorLimit(t *testing.T) { Err: mockNetErr, } expectedErr := fmt.Errorf("network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w", expectedBaseErr) - if err.Error() != expectedErr.Error() { - t.Errorf("Expected exact error:\n%q\nGot:\n%q", expectedErr.Error(), err.Error()) - } + utiltest.AssertErrorMatch(t, err, expectedErr) } // TestWatchConfigUnmarshalErrorLimit tests how WatchConfig handles bad or incomplete @@ -300,7 +299,7 @@ func TestWatchConfigWebErrorLimit(t *testing.T) { // response and verifies that the agent tries to read it again up to a limit of 3 // times before it stops and reports an error. func TestWatchConfigUnmarshalErrorLimit(t *testing.T) { - defer OverrideWatchConfigTimeouts(1*time.Millisecond, 1*time.Second)() + OverrideWatchConfigTimeouts(t, 1*time.Millisecond, 1*time.Second) badJSON := []byte(`{"bad json"`) setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { @@ -316,9 +315,7 @@ func TestWatchConfigUnmarshalErrorLimit(t *testing.T) { var dummy metadataJSON expectedErr := json.Unmarshal(badJSON, &dummy) - if err.Error() != expectedErr.Error() { - t.Errorf("Expected exact error:\n%q\nGot:\n%q", expectedErr.Error(), err.Error()) - } + utiltest.AssertErrorMatch(t, err, expectedErr) } // TestWatchConfigContextCancel tests that the WatchConfig function can be stopped @@ -326,7 +323,7 @@ func TestWatchConfigUnmarshalErrorLimit(t *testing.T) { // cancel, it stops immediately without waiting for a timeout or retrying failed // requests. func TestWatchConfigContextCancel(t *testing.T) { - defer OverrideWatchConfigTimeouts(1*time.Minute, 1*time.Minute)() + OverrideWatchConfigTimeouts(t, 1*time.Minute, 1*time.Minute) setupMockMetadataServer(t, func(w http.ResponseWriter, r *http.Request) { w.Header().Set("Etag", fmt.Sprintf("cancel-etag-%d", time.Now().UnixNano())) @@ -418,7 +415,7 @@ func TestIDToken(t *testing.T) { handler http.HandlerFunc numCalls int wantToken string - wantErr bool + wantErr error wantRequests int }{ { @@ -433,7 +430,7 @@ func TestIDToken(t *testing.T) { }, numCalls: 2, wantToken: validToken, - wantErr: false, + wantErr: nil, wantRequests: 1, // Only 1 request should be made due to caching }, { @@ -448,7 +445,7 @@ func TestIDToken(t *testing.T) { }, numCalls: 2, wantToken: expiringToken, - wantErr: false, + wantErr: nil, wantRequests: 2, // Token is within 10m of expiry, should trigger a fetch on every call }, { @@ -457,7 +454,7 @@ func TestIDToken(t *testing.T) { http.Error(w, "internal error", http.StatusInternalServerError) }, numCalls: 1, - wantErr: true, + wantErr: fmt.Errorf("error getting token from metadata: %w", errors.New("compute: Received 500 `internal error\n`")), // The compute/metadata client library automatically retries on 500 errors (1 initial + 5 retries). wantRequests: 6, }, @@ -468,7 +465,7 @@ func TestIDToken(t *testing.T) { fmt.Fprint(w, "not.a.valid.token") }, numCalls: 1, - wantErr: true, + wantErr: errors.New("jws: invalid token received"), wantRequests: 1, }, } @@ -488,11 +485,8 @@ func TestIDToken(t *testing.T) { for i := 0; i < tt.numCalls; i++ { token, err = IDToken() } - - if (err != nil) != tt.wantErr { - t.Fatalf("IDToken() error = %v, wantErr %v", err, tt.wantErr) - } - if err == nil && token != tt.wantToken { + utiltest.AssertErrorMatch(t, err, tt.wantErr) + if token != tt.wantToken { t.Errorf("IDToken() = %q, want %q", token, tt.wantToken) } if requests != tt.wantRequests { @@ -504,42 +498,36 @@ func TestIDToken(t *testing.T) { // TestFormatMetadataError verifies that network and DNS errors are wrapped with helpful context. func TestFormatMetadataError(t *testing.T) { - errStandard := fmt.Errorf("standard error") - errDNS := &url.Error{Err: &net.DNSError{Err: "no such host"}} - errNet := &url.Error{Err: &net.OpError{Op: "dial", Net: "tcp"}} + dnsErr := &url.Error{Err: &net.DNSError{Err: "no such host"}} + netErr := &url.Error{Err: &net.OpError{Op: "dial", Net: "tcp", Err: errors.New("connection refused")}} tests := []struct { - name string - inputErr error - wantExact error - wantContain string + name string + inputErr error + wantErr error }{ { - name: "standard error", - inputErr: errStandard, - wantExact: errStandard, + name: "standard error", + inputErr: fmt.Errorf("standard error"), + wantErr: fmt.Errorf("standard error"), }, { - name: "DNS error", - inputErr: errDNS, - wantContain: "DNS error when requesting metadata", + name: "DNS error", + inputErr: dnsErr, + wantErr: fmt.Errorf("DNS error when requesting metadata, check DNS settings and ensure metadata.google.internal is setup in your hosts file: %w", dnsErr), }, { - name: "network error", - inputErr: errNet, - wantContain: "network error when requesting metadata", + name: "network error", + inputErr: netErr, + wantErr: fmt.Errorf("network error when requesting metadata, make sure your instance has an active network and can reach the metadata server: %w", netErr), }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { got := formatMetadataError(tt.inputErr) - if tt.wantExact != nil && got != tt.wantExact { - t.Errorf("formatMetadataError() = %v, want exact %v", got, tt.wantExact) - } - if tt.wantContain != "" && !strings.Contains(got.Error(), tt.wantContain) { - t.Errorf("formatMetadataError() = %v, want to contain %q", got, tt.wantContain) - } + + utiltest.AssertErrorMatch(t, got, tt.wantErr) }) } } @@ -608,13 +596,13 @@ func TestGetMetadata(t *testing.T) { // TestGetMetadataFallback verifies fallback to the default metadata IP address. func TestGetMetadataFallback(t *testing.T) { - defer UnsetEnv(t, metadataHostEnv)() + utiltest.UnsetEnv(t, metadataHostEnv) var requestedURL string - defer MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { + MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { requestedURL = req.URL.String() return &http.Response{StatusCode: 200, Body: ioutil.NopCloser(strings.NewReader("mock response"))}, nil - })() + }) _, _, err := getMetadata("test-suffix") if err != nil { @@ -632,7 +620,7 @@ func TestGetMetadataErrors(t *testing.T) { tests := []struct { name string suffix string - mockTransport func(t *testing.T) (rollback func()) + mockTransport func(t *testing.T) wantErrContain string }{ { @@ -643,8 +631,8 @@ func TestGetMetadataErrors(t *testing.T) { { name: "client.Do error", suffix: "test-suffix", - mockTransport: func(t *testing.T) func() { - return MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { + mockTransport: func(t *testing.T) { + MockDefaultClientTransport(t, func(req *http.Request) (*http.Response, error) { return nil, fmt.Errorf("mock dial error") }) }, @@ -655,7 +643,7 @@ func TestGetMetadataErrors(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.mockTransport != nil { - t.Cleanup(tt.mockTransport(t)) + tt.mockTransport(t) } _, _, err := getMetadata(tt.suffix) if err == nil || !strings.Contains(err.Error(), tt.wantErrContain) { @@ -1184,7 +1172,7 @@ func TestGetCacheDirWindows(t *testing.T) { // that os.UserCacheDir relies on to generate paths. envs := []string{"HOME", "LocalAppData", "XDG_CACHE_HOME"} for _, env := range envs { - t.Cleanup(UnsetEnv(t, env)) + utiltest.UnsetEnv(t, env) } }, want: filepath.Join(os.TempDir(), windowsCacheDir), @@ -1318,70 +1306,27 @@ func (f roundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { return f(req) } -// OverrideEnv sets an environment variable for the duration of a test and returns a rollback function to restore its original state. -func OverrideEnv(t *testing.T, env, value string) (rollback func()) { - orig, ok := os.LookupEnv(env) - rollback = func() { - if ok { - if err := os.Setenv(env, orig); err != nil { - t.Fatalf("Failed to restore environment variable %s: %v", env, err) - } - } else { - if err := os.Unsetenv(env); err != nil { - t.Fatalf("Failed to unset environment variable %s: %v", env, err) - } - } - } - - if err := os.Setenv(env, value); err != nil { - t.Fatalf("Failed to set environment variable %s: %v", env, err) - } - - return rollback -} - -// UnsetEnv unsets an environment variable for the duration of a test and returns a rollback function to restore its original state. -func UnsetEnv(t *testing.T, env string) (rollback func()) { - orig, ok := os.LookupEnv(env) - rollback = func() { - if ok { - if err := os.Setenv(env, orig); err != nil { - t.Fatalf("Failed to restore environment variable %s: %v", env, err) - } - } else { - if err := os.Unsetenv(env); err != nil { - t.Fatalf("Failed to unset environment variable %s: %v", env, err) - } - } - } - - if err := os.Unsetenv(env); err != nil { - t.Fatalf("Failed to unset environment variable %s: %v", env, err) - } - - return rollback -} - // OverrideWatchConfigTimeouts temporarily overwrites the timeout and retry intervals for WatchConfig. -func OverrideWatchConfigTimeouts(interval, timeout time.Duration) (rollback func()) { +func OverrideWatchConfigTimeouts(t *testing.T, interval, timeout time.Duration) { + t.Helper() origInterval := watchConfigRetryInterval origTimeout := osConfigWatchConfigTimeout watchConfigRetryInterval = interval osConfigWatchConfigTimeout = timeout - return func() { + t.Cleanup(func() { watchConfigRetryInterval = origInterval osConfigWatchConfigTimeout = origTimeout - } + }) } // MockDefaultClientTransport temporarily replaces the defaultClient's transport with a custom round tripper. -func MockDefaultClientTransport(t *testing.T, roundTrip func(*http.Request) (*http.Response, error)) (rollback func()) { +func MockDefaultClientTransport(t *testing.T, roundTrip func(*http.Request) (*http.Response, error)) { origClient := defaultClient defaultClient = &http.Client{ Transport: roundTripperFunc(roundTrip), } - return func() { + t.Cleanup(func() { defaultClient = origClient - } + }) } diff --git a/util/utiltest/utiltest.go b/util/utiltest/utiltest.go index 78a9e6b2b..3b3e70ccb 100644 --- a/util/utiltest/utiltest.go +++ b/util/utiltest/utiltest.go @@ -98,12 +98,12 @@ func AssertErrorMatch(t *testing.T, gotErr, wantErr error) { if gotErr == nil && wantErr == nil { return } - if gotErr == nil || wantErr == nil { + if gotErr == nil || wantErr == nil || reflect.TypeOf(gotErr) != reflect.TypeOf(wantErr) { t.Errorf("Errors mismatch, want %v, got %v", wantErr, gotErr) return } - if reflect.TypeOf(gotErr) != reflect.TypeOf(wantErr) || gotErr.Error() != wantErr.Error() { - t.Errorf("Unexpected error, want %v, got %v", wantErr, gotErr) + if diff := cmp.Diff(wantErr.Error(), gotErr.Error()); diff != "" { + t.Errorf("Unexpected error, got != want (-want +got):\n%s", diff) } } @@ -135,3 +135,45 @@ func AssertFileContents(t *testing.T, filePath string, wantContents string) { t.Errorf("File contents mismatch (-want +got):\n%s", diff) } } + +// OverrideEnv sets an environment variable for the duration of a test and restores its original state on cleanup. +func OverrideEnv(t *testing.T, env, value string) { + t.Helper() + orig, ok := os.LookupEnv(env) + t.Cleanup(func() { + if ok { + if err := os.Setenv(env, orig); err != nil { + t.Fatalf("Failed to restore environment variable %s: %v", env, err) + } + } else { + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } + } + }) + + if err := os.Setenv(env, value); err != nil { + t.Fatalf("Failed to set environment variable %s: %v", env, err) + } +} + +// UnsetEnv unsets an environment variable for the duration of a test and restores its original state on cleanup. +func UnsetEnv(t *testing.T, env string) { + t.Helper() + orig, ok := os.LookupEnv(env) + t.Cleanup(func() { + if ok { + if err := os.Setenv(env, orig); err != nil { + t.Fatalf("Failed to restore environment variable %s: %v", env, err) + } + } else { + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } + } + }) + + if err := os.Unsetenv(env); err != nil { + t.Fatalf("Failed to unset environment variable %s: %v", env, err) + } +}