diff --git a/cmd/apiserver/main.go b/cmd/apiserver/main.go index ba666f73..7b8d04ea 100644 --- a/cmd/apiserver/main.go +++ b/cmd/apiserver/main.go @@ -10,6 +10,7 @@ import ( _ "net/http/pprof" "net/netip" "os" + "strings" "time" "github.com/kelseyhightower/envconfig" @@ -349,6 +350,21 @@ func run(log *logrus.Entry, cfg config.Config) error { return fmt.Errorf("read device with external_id=%v: %w", event.GetExternalID(), err) } + kolideDevice, err := kolideClient.GetDevice(ctx, device.ExternalID) + if err != nil { + return fmt.Errorf("get kolide device %v: %w", device.ExternalID, err) + } + + if !strings.EqualFold(kolideDevice.Owner.Email, device.Username) { + log.WithFields(logrus.Fields{ + "device_serial": device.Serial, + "device_platform": device.Platform, + "device_username": device.Username, + "kolide_owner_email": kolideDevice.Owner.Email, + "kolide_owner_id": kolideDevice.OwnerRef.Identifier, + }).Warn("kolide device owner email does not match enrolled username") + } + failures, err := kolideClient.GetDeviceIssues(ctx, device.ExternalID) if err != nil { return err diff --git a/internal/apiserver/api/device.go b/internal/apiserver/api/device.go index bae85634..aacc47cf 100644 --- a/internal/apiserver/api/device.go +++ b/internal/apiserver/api/device.go @@ -5,10 +5,14 @@ import ( "database/sql" "errors" "fmt" + "slices" + "strings" "time" + "github.com/nais/device/internal/apiserver/kolide" "github.com/nais/device/internal/apiserver/metrics" "github.com/nais/device/pkg/pb" + "github.com/sirupsen/logrus" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -231,6 +235,31 @@ func (s *grpcServer) UpdateAllDevices(ctx context.Context) error { } } + // Log devices that don't have a matching Kolide device (by serial, platform, and owner email) + for _, device := range devices { + idx := slices.IndexFunc(kolideDevices, func(kd *kolide.Device) bool { + return kd.Serial == device.Serial && kd.Platform == device.Platform + }) + if idx == -1 { + s.log.WithFields(logrus.Fields{ + "device_serial": device.Serial, + "device_platform": device.Platform, + "device_username": device.Username, + }).Info("no matching kolide device found for serial+platform") + continue + } + kd := kolideDevices[idx] + if !strings.EqualFold(kd.Owner.Email, device.Username) { + s.log.WithFields(logrus.Fields{ + "device_serial": device.Serial, + "device_platform": device.Platform, + "device_username": device.Username, + "kolide_owner_email": kd.Owner.Email, + "kolide_owner_id": kd.OwnerRef.Identifier, + }).Warn("kolide device owner email does not match enrolled username") + } + } + issues, err := s.kolideClient.GetIssues(ctx) if err != nil { return fmt.Errorf("getting kolide issues: %w", err) diff --git a/internal/apiserver/api/gateway_test.go b/internal/apiserver/api/gateway_test.go index 911b93df..8b1b8910 100644 --- a/internal/apiserver/api/gateway_test.go +++ b/internal/apiserver/api/gateway_test.go @@ -19,7 +19,7 @@ import ( ) func Test_MakeGatewayConfiguration(t *testing.T) { - ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() // hash generated with `controlplane-cli passhash --password hunter2` @@ -45,17 +45,7 @@ func Test_MakeGatewayConfiguration(t *testing.T) { AccessGroupIDs: []string{"groupId"}, } - db := database.NewMockDatabase(t) - db.EXPECT().ReadGateway(mock.Anything, "gateway").Return(mockGateway, nil).Times(2) - db.EXPECT().ReadGateway(mock.Anything, "privilegedGateway").Return(mockPrivilegedGateway, nil).Times(2) - db.EXPECT().GetAcceptances(mock.Anything).Return(map[string]struct{}{ - "sessionUserId": {}, - "sessionUserIdWithPrivileged": {}, - }, nil).Times(2) - db.EXPECT().UsersWithAccessToPrivilegedGateway(mock.Anything, "privilegedGateway").Return([]string{"sessionUserIdWithPrivileged"}, nil).Once() - - sessionStore := auth.NewMockSessionStore(t) - sessionStore.EXPECT().All().Return([]*pb.Session{ + sessions := []*pb.Session{ { Device: &pb.Device{PublicKey: "devicePublicKey1"}, ObjectID: "sessionUserId", @@ -68,7 +58,22 @@ func Test_MakeGatewayConfiguration(t *testing.T) { Expiry: timestamppb.New(time.Now().Add(24 * time.Hour)), Groups: []string{"groupId"}, }, - }) + } + + // The server loop runs makeGatewayConfiguration repeatedly (on ticker and + // triggers), so we allow any number of calls to the DB rather than fixing + // exact counts, which would cause flakiness based on goroutine scheduling. + db := database.NewMockDatabase(t) + db.On("ReadGateway", mock.Anything, "gateway").Return(mockGateway, nil).Maybe() + db.On("ReadGateway", mock.Anything, "privilegedGateway").Return(mockPrivilegedGateway, nil).Maybe() + db.On("GetAcceptances", mock.Anything).Return(map[string]struct{}{ + "sessionUserId": {}, + "sessionUserIdWithPrivileged": {}, + }, nil).Maybe() + db.On("UsersWithAccessToPrivilegedGateway", mock.Anything, "privilegedGateway").Return([]string{"sessionUserIdWithPrivileged"}, nil).Maybe() + + sessionStore := auth.NewMockSessionStore(t) + sessionStore.On("All").Return(sessions).Maybe() gatewayAuthenticator := auth.NewGatewayAuthenticator(db) @@ -93,9 +98,12 @@ func Test_MakeGatewayConfiguration(t *testing.T) { client := pb.NewAPIServerClient(conn) - // Test authenticated call with correct password + // Test first stream: normal gateway. + // Use a per-stream context so we can cancel it before opening the second + // stream, preventing the server-side loop from making extra DB calls. + stream1Ctx, stream1Cancel := context.WithCancel(ctx) stream, err := client.GetGatewayConfiguration( - ctx, + stream1Ctx, &pb.GetGatewayConfigurationRequest{ Gateway: "gateway", Password: "hunter2", @@ -112,8 +120,21 @@ func Test_MakeGatewayConfiguration(t *testing.T) { assert.Equal(t, "devicePublicKey1", resp.GetDevices()[0].PublicKey) assert.Equal(t, "devicePublicKey2", resp.GetDevices()[1].PublicKey) + // Cancel the first stream and drain it before opening the second, to ensure + // the server-side goroutine has exited and won't race with the second stream. + stream1Cancel() + for { + _, err := stream.Recv() + if err != nil { + break + } + } + + // Test second stream: privileged gateway. + stream2Ctx, stream2Cancel := context.WithCancel(ctx) + defer stream2Cancel() stream, err = client.GetGatewayConfiguration( - ctx, + stream2Ctx, &pb.GetGatewayConfigurationRequest{ Gateway: "privilegedGateway", Password: "hunter2", @@ -128,4 +149,6 @@ func Test_MakeGatewayConfiguration(t *testing.T) { assert.Len(t, resp.GetDevices(), 1) assert.Equal(t, "devicePublicKey2", resp.GetDevices()[0].PublicKey) + + stream2Cancel() } diff --git a/internal/apiserver/api/update_all_devices_test.go b/internal/apiserver/api/update_all_devices_test.go index ec7c6e85..533ce26f 100644 --- a/internal/apiserver/api/update_all_devices_test.go +++ b/internal/apiserver/api/update_all_devices_test.go @@ -2,6 +2,7 @@ package api_test import ( "context" + "fmt" "testing" "time" @@ -467,3 +468,16 @@ func (m *mockKolideClient) GetChecks(ctx context.Context) ([]*kolide.Check, erro func (m *mockKolideClient) GetDeviceIssues(ctx context.Context, deviceID string) ([]*kolide.Issue, error) { return m.issues, nil } + +func (m *mockKolideClient) GetDevice(ctx context.Context, deviceID string) (*kolide.Device, error) { + for _, d := range m.devices { + if d.ID == deviceID { + return d, nil + } + } + return nil, fmt.Errorf("device %v not found", deviceID) +} + +func (m *mockKolideClient) GetPeople(ctx context.Context) (map[string]*kolide.Person, error) { + return map[string]*kolide.Person{}, nil +} diff --git a/internal/apiserver/kolide/client.go b/internal/apiserver/kolide/client.go index 476f4aaf..2f3db51e 100644 --- a/internal/apiserver/kolide/client.go +++ b/internal/apiserver/kolide/client.go @@ -16,7 +16,9 @@ type Client interface { GetIssues(ctx context.Context) ([]*Issue, error) GetDeviceIssues(ctx context.Context, deviceID string) ([]*Issue, error) GetChecks(ctx context.Context) ([]*Check, error) + GetDevice(ctx context.Context, deviceID string) (*Device, error) GetDevices(ctx context.Context) ([]*Device, error) + GetPeople(ctx context.Context) (map[string]*Person, error) } type client struct { @@ -137,6 +139,24 @@ func (kc *client) getPaginated(ctx context.Context, initialURL string) ([]json.R } } +func (kc *client) GetPeople(ctx context.Context) (map[string]*Person, error) { + rawPeople, err := kc.getPaginated(ctx, kc.baseURL+"/people") + if err != nil { + return nil, fmt.Errorf("getting people: %w", err) + } + + people := make(map[string]*Person, len(rawPeople)) + for _, rawPerson := range rawPeople { + person := &Person{} + err := json.Unmarshal(rawPerson, person) + if err != nil { + return nil, fmt.Errorf("unmarshal person: %w", err) + } + people[person.ID] = person + } + return people, nil +} + func (kc *client) GetDevices(ctx context.Context) ([]*Device, error) { kc.log.Debug("getting all devices...") url := kc.baseURL + "/devices" @@ -145,6 +165,11 @@ func (kc *client) GetDevices(ctx context.Context) ([]*Device, error) { return nil, err } + people, err := kc.GetPeople(ctx) + if err != nil { + return nil, fmt.Errorf("getting people for device owner resolution: %w", err) + } + devices := make([]*Device, len(rawDevices)) for i, rawDevice := range rawDevices { device := &Device{} @@ -154,12 +179,46 @@ func (kc *client) GetDevices(ctx context.Context) ([]*Device, error) { } device.Platform = convertPlatform(device.Platform) + if person, ok := people[device.OwnerRef.Identifier]; ok { + device.Owner = *person + } devices[i] = device } return devices, nil } +func (kc *client) GetDevice(ctx context.Context, deviceID string) (*Device, error) { + resp, err := kc.get(ctx, fmt.Sprintf(kc.baseURL+"/devices/%v", deviceID)) + if err != nil { + return nil, fmt.Errorf("getting device %v: %w", deviceID, err) + } + defer ioconvenience.CloseWithLog(resp.Body, kc.log) + + device := &Device{} + if err := json.NewDecoder(resp.Body).Decode(device); err != nil { + return nil, fmt.Errorf("unmarshal device %v: %w", deviceID, err) + } + + device.Platform = convertPlatform(device.Platform) + + if device.OwnerRef.Identifier != "" { + resp, err := kc.get(ctx, fmt.Sprintf(kc.baseURL+"/people/%v", device.OwnerRef.Identifier)) + if err != nil { + return nil, fmt.Errorf("getting person %v for device %v: %w", device.OwnerRef.Identifier, deviceID, err) + } + defer ioconvenience.CloseWithLog(resp.Body, kc.log) + + person := &Person{} + if err := json.NewDecoder(resp.Body).Decode(person); err != nil { + return nil, fmt.Errorf("unmarshal person %v: %w", device.OwnerRef.Identifier, err) + } + device.Owner = *person + } + + return device, nil +} + func (kc *client) GetDeviceIssues(ctx context.Context, deviceID string) ([]*Issue, error) { url := fmt.Sprintf(kc.baseURL+"/devices/%v/open_issues", deviceID) rawIssues, err := kc.getPaginated(ctx, url) diff --git a/internal/apiserver/kolide/client_test.go b/internal/apiserver/kolide/client_test.go index a119c0ce..9c2dac57 100644 --- a/internal/apiserver/kolide/client_test.go +++ b/internal/apiserver/kolide/client_test.go @@ -62,6 +62,8 @@ func TestClient_GetDevices(t *testing.T) { assert.Equal(t, "darwin", d.Platform) // converted from "Mac" assert.Equal(t, "XXXX1111AAAA", d.Serial) assert.Equal(t, "44200", d.OwnerRef.Identifier) + assert.Equal(t, "john.doe@example.com", d.Owner.Email) + assert.Equal(t, "44200", d.Owner.ID) }) t.Run("Windows device platform converted", func(t *testing.T) { @@ -74,6 +76,33 @@ func TestClient_GetDevices(t *testing.T) { d := devices[3] assert.Equal(t, "10004", d.ID) assert.Equal(t, "", d.OwnerRef.Identifier) + assert.Equal(t, "", d.Owner.Email) + }) +} + +func TestClient_GetPeople(t *testing.T) { + ctx := context.Background() + server, _ := setupTestServer(t) + defer server.Close() + + client := kolide.New("token", logrus.New(), kolide.WithBaseURL(server.URL)) + + people, err := client.GetPeople(ctx) + require.NoError(t, err) + require.Len(t, people, 4) + + t.Run("person keyed by id", func(t *testing.T) { + p, ok := people["44200"] + require.True(t, ok) + assert.Equal(t, "44200", p.ID) + assert.Equal(t, "john.doe@example.com", p.Email) + }) + + t.Run("all expected people present", func(t *testing.T) { + assert.Contains(t, people, "44200") + assert.Contains(t, people, "44794") + assert.Contains(t, people, "45037") + assert.Contains(t, people, "47221") }) } @@ -193,6 +222,12 @@ func TestClient_Pagination(t *testing.T) { server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { callCount++ + + if r.URL.Path == "/people" { + w.Write([]byte(`{"data": [], "pagination": {"count": 0}}`)) + return + } + cursor := r.URL.Query().Get("cursor") switch cursor { @@ -215,7 +250,70 @@ func TestClient_Pagination(t *testing.T) { devices, err := client.GetDevices(ctx) require.NoError(t, err) assert.Len(t, devices, 2) - assert.Equal(t, 2, callCount) + assert.Equal(t, 3, callCount) // 2 pages of devices + 1 people call assert.Equal(t, "1", devices[0].ID) assert.Equal(t, "2", devices[1].ID) } + +func TestClient_GetDevice(t *testing.T) { + ctx := context.Background() + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/devices/10001": + w.Write([]byte(`{ + "id": "10001", + "name": "MacBook-Pro-001", + "device_type": "Mac", + "serial": "XXXX1111AAAA", + "registered_owner_info": {"identifier": "44200"} + }`)) + case "/people/44200": + w.Write([]byte(`{"id": "44200", "email": "john.doe@example.com"}`)) + default: + t.Errorf("unexpected request to %v", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer server.Close() + + client := kolide.New("token", logrus.New(), kolide.WithBaseURL(server.URL)) + + t.Run("device with owner resolved", func(t *testing.T) { + d, err := client.GetDevice(ctx, "10001") + require.NoError(t, err) + assert.Equal(t, "10001", d.ID) + assert.Equal(t, "MacBook-Pro-001", d.Name) + assert.Equal(t, "darwin", d.Platform) // converted from "Mac" + assert.Equal(t, "XXXX1111AAAA", d.Serial) + assert.Equal(t, "44200", d.OwnerRef.Identifier) + assert.Equal(t, "44200", d.Owner.ID) + assert.Equal(t, "john.doe@example.com", d.Owner.Email) + }) + + t.Run("device with no owner has empty Owner", func(t *testing.T) { + noOwnerServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/devices/10004": + w.Write([]byte(`{ + "id": "10004", + "name": "MacBook-Pro-003", + "device_type": "Mac", + "serial": "XXXX4444DDDD", + "registered_owner_info": {"identifier": ""} + }`)) + default: + t.Errorf("unexpected request to %v", r.URL.Path) + w.WriteHeader(http.StatusNotFound) + } + })) + defer noOwnerServer.Close() + + noOwnerClient := kolide.New("token", logrus.New(), kolide.WithBaseURL(noOwnerServer.URL)) + d, err := noOwnerClient.GetDevice(ctx, "10004") + require.NoError(t, err) + assert.Equal(t, "10004", d.ID) + assert.Equal(t, "", d.OwnerRef.Identifier) + assert.Equal(t, "", d.Owner.Email) + }) +} diff --git a/internal/apiserver/kolide/fakeclient.go b/internal/apiserver/kolide/fakeclient.go index 390533b1..82a8f282 100644 --- a/internal/apiserver/kolide/fakeclient.go +++ b/internal/apiserver/kolide/fakeclient.go @@ -13,6 +13,11 @@ func (f *FakeClient) GetChecks(ctx context.Context) ([]*Check, error) { panic("unimplemented") } +// GetDevice implements Client. +func (f *FakeClient) GetDevice(ctx context.Context, deviceID string) (*Device, error) { + panic("unimplemented") +} + // GetDeviceIssues implements Client. func (f *FakeClient) GetDeviceIssues(ctx context.Context, deviceID string) ([]*Issue, error) { panic("unimplemented") @@ -28,6 +33,11 @@ func (f *FakeClient) GetIssues(ctx context.Context) ([]*Issue, error) { panic("unimplemented") } +// GetPeople implements Client. +func (f *FakeClient) GetPeople(ctx context.Context) (map[string]*Person, error) { + panic("unimplemented") +} + var _ Client = &FakeClient{} func (f *FakeClient) Build() Client {