Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions cmd/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
_ "net/http/pprof"
"net/netip"
"os"
"strings"
"time"

"github.com/kelseyhightower/envconfig"
Expand Down Expand Up @@ -349,6 +350,21 @@ func run(log *logrus.Entry, cfg config.Config) error {
return fmt.Errorf("read device with external_id=%v: %w", event.GetExternalID(), err)
}

kolideDevice, err := kolideClient.GetDevice(ctx, device.ExternalID)
if err != nil {
return fmt.Errorf("get kolide device %v: %w", device.ExternalID, err)
}

if !strings.EqualFold(kolideDevice.Owner.Email, device.Username) {
log.WithFields(logrus.Fields{
"device_serial": device.Serial,
"device_platform": device.Platform,
"device_username": device.Username,
"kolide_owner_email": kolideDevice.Owner.Email,
"kolide_owner_id": kolideDevice.OwnerRef.Identifier,
}).Warn("kolide device owner email does not match enrolled username")
}

failures, err := kolideClient.GetDeviceIssues(ctx, device.ExternalID)
if err != nil {
return err
Expand Down
29 changes: 29 additions & 0 deletions internal/apiserver/api/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,14 @@ import (
"database/sql"
"errors"
"fmt"
"slices"
"strings"
"time"

"github.com/nais/device/internal/apiserver/kolide"
"github.com/nais/device/internal/apiserver/metrics"
"github.com/nais/device/pkg/pb"
"github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/timestamppb"
Expand Down Expand Up @@ -231,6 +235,31 @@ func (s *grpcServer) UpdateAllDevices(ctx context.Context) error {
}
}

// Log devices that don't have a matching Kolide device (by serial, platform, and owner email)
for _, device := range devices {
idx := slices.IndexFunc(kolideDevices, func(kd *kolide.Device) bool {
return kd.Serial == device.Serial && kd.Platform == device.Platform
})
if idx == -1 {
s.log.WithFields(logrus.Fields{
"device_serial": device.Serial,
"device_platform": device.Platform,
"device_username": device.Username,
}).Info("no matching kolide device found for serial+platform")
continue
}
kd := kolideDevices[idx]
if !strings.EqualFold(kd.Owner.Email, device.Username) {
s.log.WithFields(logrus.Fields{
"device_serial": device.Serial,
"device_platform": device.Platform,
"device_username": device.Username,
"kolide_owner_email": kd.Owner.Email,
"kolide_owner_id": kd.OwnerRef.Identifier,
}).Warn("kolide device owner email does not match enrolled username")
}
}

issues, err := s.kolideClient.GetIssues(ctx)
if err != nil {
return fmt.Errorf("getting kolide issues: %w", err)
Expand Down
55 changes: 39 additions & 16 deletions internal/apiserver/api/gateway_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
)

func Test_MakeGatewayConfiguration(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()

// hash generated with `controlplane-cli passhash --password hunter2`
Expand All @@ -45,17 +45,7 @@ func Test_MakeGatewayConfiguration(t *testing.T) {
AccessGroupIDs: []string{"groupId"},
}

db := database.NewMockDatabase(t)
db.EXPECT().ReadGateway(mock.Anything, "gateway").Return(mockGateway, nil).Times(2)
db.EXPECT().ReadGateway(mock.Anything, "privilegedGateway").Return(mockPrivilegedGateway, nil).Times(2)
db.EXPECT().GetAcceptances(mock.Anything).Return(map[string]struct{}{
"sessionUserId": {},
"sessionUserIdWithPrivileged": {},
}, nil).Times(2)
db.EXPECT().UsersWithAccessToPrivilegedGateway(mock.Anything, "privilegedGateway").Return([]string{"sessionUserIdWithPrivileged"}, nil).Once()

sessionStore := auth.NewMockSessionStore(t)
sessionStore.EXPECT().All().Return([]*pb.Session{
sessions := []*pb.Session{
{
Device: &pb.Device{PublicKey: "devicePublicKey1"},
ObjectID: "sessionUserId",
Expand All @@ -68,7 +58,22 @@ func Test_MakeGatewayConfiguration(t *testing.T) {
Expiry: timestamppb.New(time.Now().Add(24 * time.Hour)),
Groups: []string{"groupId"},
},
})
}

// The server loop runs makeGatewayConfiguration repeatedly (on ticker and
// triggers), so we allow any number of calls to the DB rather than fixing
// exact counts, which would cause flakiness based on goroutine scheduling.
db := database.NewMockDatabase(t)
db.On("ReadGateway", mock.Anything, "gateway").Return(mockGateway, nil).Maybe()
db.On("ReadGateway", mock.Anything, "privilegedGateway").Return(mockPrivilegedGateway, nil).Maybe()
db.On("GetAcceptances", mock.Anything).Return(map[string]struct{}{
"sessionUserId": {},
"sessionUserIdWithPrivileged": {},
}, nil).Maybe()
db.On("UsersWithAccessToPrivilegedGateway", mock.Anything, "privilegedGateway").Return([]string{"sessionUserIdWithPrivileged"}, nil).Maybe()

sessionStore := auth.NewMockSessionStore(t)
sessionStore.On("All").Return(sessions).Maybe()

gatewayAuthenticator := auth.NewGatewayAuthenticator(db)

Expand All @@ -93,9 +98,12 @@ func Test_MakeGatewayConfiguration(t *testing.T) {

client := pb.NewAPIServerClient(conn)

// Test authenticated call with correct password
// Test first stream: normal gateway.
// Use a per-stream context so we can cancel it before opening the second
// stream, preventing the server-side loop from making extra DB calls.
stream1Ctx, stream1Cancel := context.WithCancel(ctx)
stream, err := client.GetGatewayConfiguration(
ctx,
stream1Ctx,
&pb.GetGatewayConfigurationRequest{
Gateway: "gateway",
Password: "hunter2",
Expand All @@ -112,8 +120,21 @@ func Test_MakeGatewayConfiguration(t *testing.T) {
assert.Equal(t, "devicePublicKey1", resp.GetDevices()[0].PublicKey)
assert.Equal(t, "devicePublicKey2", resp.GetDevices()[1].PublicKey)

// Cancel the first stream and drain it before opening the second, to ensure
// the server-side goroutine has exited and won't race with the second stream.
stream1Cancel()
for {
_, err := stream.Recv()
if err != nil {
break
}
}

// Test second stream: privileged gateway.
stream2Ctx, stream2Cancel := context.WithCancel(ctx)
defer stream2Cancel()
stream, err = client.GetGatewayConfiguration(
ctx,
stream2Ctx,
&pb.GetGatewayConfigurationRequest{
Gateway: "privilegedGateway",
Password: "hunter2",
Expand All @@ -128,4 +149,6 @@ func Test_MakeGatewayConfiguration(t *testing.T) {

assert.Len(t, resp.GetDevices(), 1)
assert.Equal(t, "devicePublicKey2", resp.GetDevices()[0].PublicKey)

stream2Cancel()
}
14 changes: 14 additions & 0 deletions internal/apiserver/api/update_all_devices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package api_test

import (
"context"
"fmt"
"testing"
"time"

Expand Down Expand Up @@ -467,3 +468,16 @@ func (m *mockKolideClient) GetChecks(ctx context.Context) ([]*kolide.Check, erro
func (m *mockKolideClient) GetDeviceIssues(ctx context.Context, deviceID string) ([]*kolide.Issue, error) {
return m.issues, nil
}

func (m *mockKolideClient) GetDevice(ctx context.Context, deviceID string) (*kolide.Device, error) {
for _, d := range m.devices {
if d.ID == deviceID {
return d, nil
}
}
return nil, fmt.Errorf("device %v not found", deviceID)
}

func (m *mockKolideClient) GetPeople(ctx context.Context) (map[string]*kolide.Person, error) {
return map[string]*kolide.Person{}, nil
}
59 changes: 59 additions & 0 deletions internal/apiserver/kolide/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ type Client interface {
GetIssues(ctx context.Context) ([]*Issue, error)
GetDeviceIssues(ctx context.Context, deviceID string) ([]*Issue, error)
GetChecks(ctx context.Context) ([]*Check, error)
GetDevice(ctx context.Context, deviceID string) (*Device, error)
GetDevices(ctx context.Context) ([]*Device, error)
GetPeople(ctx context.Context) (map[string]*Person, error)
}

type client struct {
Expand Down Expand Up @@ -137,6 +139,24 @@ func (kc *client) getPaginated(ctx context.Context, initialURL string) ([]json.R
}
}

func (kc *client) GetPeople(ctx context.Context) (map[string]*Person, error) {
rawPeople, err := kc.getPaginated(ctx, kc.baseURL+"/people")
if err != nil {
return nil, fmt.Errorf("getting people: %w", err)
}

people := make(map[string]*Person, len(rawPeople))
for _, rawPerson := range rawPeople {
person := &Person{}
err := json.Unmarshal(rawPerson, person)
if err != nil {
return nil, fmt.Errorf("unmarshal person: %w", err)
}
people[person.ID] = person
}
return people, nil
}

func (kc *client) GetDevices(ctx context.Context) ([]*Device, error) {
kc.log.Debug("getting all devices...")
url := kc.baseURL + "/devices"
Expand All @@ -145,6 +165,11 @@ func (kc *client) GetDevices(ctx context.Context) ([]*Device, error) {
return nil, err
}

people, err := kc.GetPeople(ctx)
if err != nil {
return nil, fmt.Errorf("getting people for device owner resolution: %w", err)
}

devices := make([]*Device, len(rawDevices))
for i, rawDevice := range rawDevices {
device := &Device{}
Expand All @@ -154,12 +179,46 @@ func (kc *client) GetDevices(ctx context.Context) ([]*Device, error) {
}

device.Platform = convertPlatform(device.Platform)
if person, ok := people[device.OwnerRef.Identifier]; ok {
device.Owner = *person
}
devices[i] = device
}

return devices, nil
}

func (kc *client) GetDevice(ctx context.Context, deviceID string) (*Device, error) {
resp, err := kc.get(ctx, fmt.Sprintf(kc.baseURL+"/devices/%v", deviceID))
if err != nil {
return nil, fmt.Errorf("getting device %v: %w", deviceID, err)
}
defer ioconvenience.CloseWithLog(resp.Body, kc.log)

device := &Device{}
if err := json.NewDecoder(resp.Body).Decode(device); err != nil {
return nil, fmt.Errorf("unmarshal device %v: %w", deviceID, err)
}

device.Platform = convertPlatform(device.Platform)

if device.OwnerRef.Identifier != "" {
resp, err := kc.get(ctx, fmt.Sprintf(kc.baseURL+"/people/%v", device.OwnerRef.Identifier))
if err != nil {
return nil, fmt.Errorf("getting person %v for device %v: %w", device.OwnerRef.Identifier, deviceID, err)
}
defer ioconvenience.CloseWithLog(resp.Body, kc.log)

person := &Person{}
if err := json.NewDecoder(resp.Body).Decode(person); err != nil {
return nil, fmt.Errorf("unmarshal person %v: %w", device.OwnerRef.Identifier, err)
}
device.Owner = *person
}

return device, nil
}

func (kc *client) GetDeviceIssues(ctx context.Context, deviceID string) ([]*Issue, error) {
url := fmt.Sprintf(kc.baseURL+"/devices/%v/open_issues", deviceID)
rawIssues, err := kc.getPaginated(ctx, url)
Expand Down
Loading