From 432f29e2cd754ad6255426cc90c3ee82cd2d81cf Mon Sep 17 00:00:00 2001 From: mozturk Date: Wed, 25 Feb 2026 13:36:50 +0300 Subject: [PATCH] BIF-2484 Adapt poller & processor to new message API, Remove SQS --- queue/jec_message.go | 8 ++ queue/job.go | 33 +---- queue/job_test.go | 94 ++++--------- queue/message.go | 25 ++-- queue/message_test.go | 19 ++- queue/poller.go | 120 +++++++++-------- queue/poller_test.go | 167 ++++++++++++++--------- queue/processor.go | 199 +++++++-------------------- queue/processor_test.go | 292 ++++++---------------------------------- 9 files changed, 314 insertions(+), 643 deletions(-) create mode 100644 queue/jec_message.go diff --git a/queue/jec_message.go b/queue/jec_message.go new file mode 100644 index 0000000..018eb01 --- /dev/null +++ b/queue/jec_message.go @@ -0,0 +1,8 @@ +package queue + +// JECMessage represents a message fetched from the JEC API. +type JECMessage struct { + MessageId string `json:"messageId"` + Body string `json:"body"` + ChannelId string `json:"channelId"` +} diff --git a/queue/job.go b/queue/job.go index d631857..5d49251 100644 --- a/queue/job.go +++ b/queue/job.go @@ -2,7 +2,6 @@ package queue import ( "github.com/atlassian/jec/runbook" - "github.com/aws/aws-sdk-go/service/sqs" "github.com/pkg/errors" "github.com/sirupsen/logrus" "sync" @@ -17,11 +16,9 @@ const ( ) type job struct { - queueProvider SQSProvider messageHandler MessageHandler - message sqs.Message - ownerId string + message JECMessage apiKey string baseUrl string @@ -29,12 +26,10 @@ type job struct { executeMutex *sync.Mutex } -func newJob(queueProvider SQSProvider, messageHandler MessageHandler, message sqs.Message, apiKey, baseUrl, ownerId string) *job { +func newJob(messageHandler MessageHandler, message JECMessage, apiKey, baseUrl string) *job { return &job{ - queueProvider: queueProvider, messageHandler: messageHandler, message: message, - ownerId: ownerId, apiKey: apiKey, baseUrl: baseUrl, state: jobInitial, @@ -43,11 +38,7 @@ func newJob(queueProvider SQSProvider, messageHandler MessageHandler, message sq } func (j *job) Id() string { - return *j.message.MessageId -} - -func (j *job) sqsMessage() sqs.Message { - return j.message + return j.message.MessageId } func (j *job) Execute() error { @@ -60,26 +51,8 @@ func (j *job) Execute() error { } j.state = jobExecuting - region := j.queueProvider.Properties().Region() messageId := j.Id() - err := j.queueProvider.DeleteMessage(&j.message) - if err != nil { - j.state = jobError - return errors.Errorf("Message[%s] could not be deleted from the queue[%s]: %s", messageId, region, err) - } - - logrus.Debugf("Message[%s] is deleted from the queue[%s].", messageId, region) - - messageAttr := j.sqsMessage().MessageAttributes - - if messageAttr == nil || - *messageAttr[ownerId].StringValue != j.ownerId && - *messageAttr[channelId].StringValue != j.ownerId { - j.state = jobError - return errors.Errorf("Message[%s] is invalid, will not be processed.", messageId) - } - result, err := j.messageHandler.Handle(j.message) if result != nil { diff --git a/queue/job_test.go b/queue/job_test.go index c3624ac..788a079 100644 --- a/queue/job_test.go +++ b/queue/job_test.go @@ -3,7 +3,6 @@ package queue import ( "encoding/json" "github.com/atlassian/jec/runbook" - "github.com/aws/aws-sdk-go/service/sqs" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "io/ioutil" @@ -20,27 +19,22 @@ var mockActionResultPayload = &runbook.ActionResultPayload{ func newJobTest() *job { mockMessageHandler := &MockMessageHandler{} - mockMessageHandler.HandleFunc = func(message sqs.Message) (payload *runbook.ActionResultPayload, e error) { + mockMessageHandler.HandleFunc = func(message JECMessage) (payload *runbook.ActionResultPayload, e error) { return mockActionResultPayload, nil } - body := "mockBody" - messageAttr := map[string]*sqs.MessageAttributeValue{ownerId: {StringValue: &mockOwnerId}} - - message := sqs.Message{ - MessageId: &mockMessageId, - Body: &body, - MessageAttributes: messageAttr, + message := JECMessage{ + MessageId: mockMessageId, + Body: "mockBody", + ChannelId: mockChannelId, } return &job{ - queueProvider: NewMockQueueProvider(), messageHandler: mockMessageHandler, message: message, executeMutex: &sync.Mutex{}, apiKey: mockApiKey, baseUrl: mockBaseUrl, - ownerId: mockOwnerId, state: jobInitial, } } @@ -61,17 +55,17 @@ func TestExecute(t *testing.T) { })) defer testServer.Close() - sqsJob := newJobTest() - sqsJob.baseUrl = testServer.URL + jecJob := newJobTest() + jecJob.baseUrl = testServer.URL wg.Add(1) - err := sqsJob.Execute() + err := jecJob.Execute() wg.Wait() assert.Nil(t, err) expectedState := int32(jobFinished) - actualState := sqsJob.state + actualState := jecJob.state assert.Equal(t, expectedState, actualState) } @@ -85,8 +79,8 @@ func TestMultipleExecute(t *testing.T) { })) defer testServer.Close() - sqsJob := newJobTest() - sqsJob.baseUrl = testServer.URL + jecJob := newJobTest() + jecJob.baseUrl = testServer.URL errorResults := make(chan error, 25) @@ -94,16 +88,16 @@ func TestMultipleExecute(t *testing.T) { for i := 0; i < 25; i++ { go func() { defer wg.Done() - err := sqsJob.Execute() + err := jecJob.Execute() if err != nil { - errorResults <- sqsJob.Execute() + errorResults <- jecJob.Execute() } }() } wg.Wait() expectedState := int32(jobFinished) - actualState := sqsJob.state + actualState := jecJob.state assert.Equal(t, expectedState, actualState) // only one execute finished assert.Equal(t, 24, len(errorResults)) // other executes will fail @@ -111,13 +105,13 @@ func TestMultipleExecute(t *testing.T) { func TestExecuteInNotInitialState(t *testing.T) { - sqsJob := newJobTest() - sqsJob.state = jobExecuting + jecJob := newJobTest() + jecJob.state = jobExecuting - err := sqsJob.Execute() + err := jecJob.Execute() assert.NotNil(t, err) - expectedErr := errors.Errorf("Job[%s] is already executing or finished.", sqsJob.Id()) + expectedErr := errors.Errorf("Job[%s] is already executing or finished.", jecJob.Id()) assert.EqualError(t, err, expectedErr.Error()) } @@ -141,64 +135,24 @@ func TestExecuteWithProcessError(t *testing.T) { })) defer testServer.Close() - sqsJob := newJobTest() - sqsJob.baseUrl = testServer.URL + jecJob := newJobTest() + jecJob.baseUrl = testServer.URL - sqsJob.messageHandler.(*MockMessageHandler).HandleFunc = func(message sqs.Message) (payload *runbook.ActionResultPayload, e error) { + jecJob.messageHandler.(*MockMessageHandler).HandleFunc = func(message JECMessage) (payload *runbook.ActionResultPayload, e error) { return errPayload, errors.New("Process Error") } wg.Add(1) - err := sqsJob.Execute() + err := jecJob.Execute() wg.Wait() assert.NotNil(t, err) - expectedErr := errors.Errorf("Message[%s] could not be processed: %s", sqsJob.Id(), "Process Error") - assert.EqualError(t, err, expectedErr.Error()) - - expectedState := int32(jobError) - actualState := sqsJob.state - - assert.Equal(t, expectedState, actualState) -} - -func TestExecuteWithDeleteError(t *testing.T) { - - sqsJob := newJobTest() - - sqsJob.queueProvider.(*MockSQSProvider).DeleteMessageFunc = func(message *sqs.Message) error { - return errors.New("Delete Error") - } - - err := sqsJob.Execute() - assert.NotNil(t, err) - - expectedErr := errors.Errorf("Message[%s] could not be deleted from the queue[%s]: %s", sqsJob.Id(), sqsJob.queueProvider.Properties().Region(), "Delete Error") - assert.EqualError(t, err, expectedErr.Error()) - - expectedState := int32(jobError) - actualState := sqsJob.state - - assert.Equal(t, expectedState, actualState) -} - -func TestExecuteWithInvalidQueueMessage(t *testing.T) { - - sqsJob := newJobTest() - - falseIntegrationId := "falseIntegrationId" - messageAttr := map[string]*sqs.MessageAttributeValue{ownerId: {StringValue: &falseIntegrationId}, channelId: {StringValue: &falseIntegrationId}} - sqsJob.message = sqs.Message{MessageAttributes: messageAttr, MessageId: &mockMessageId} - - err := sqsJob.Execute() - assert.NotNil(t, err) - - expectedErr := errors.Errorf("Message[%s] is invalid, will not be processed.", sqsJob.Id()) + expectedErr := errors.Errorf("Message[%s] could not be processed: %s", jecJob.Id(), "Process Error") assert.EqualError(t, err, expectedErr.Error()) expectedState := int32(jobError) - actualState := sqsJob.state + actualState := jecJob.state assert.Equal(t, expectedState, actualState) } diff --git a/queue/message.go b/queue/message.go index bb189a8..7863aef 100644 --- a/queue/message.go +++ b/queue/message.go @@ -7,7 +7,6 @@ import ( "github.com/atlassian/jec/conf" "github.com/atlassian/jec/git" "github.com/atlassian/jec/runbook" - "github.com/aws/aws-sdk-go/service/sqs" "github.com/pkg/errors" "github.com/sirupsen/logrus" "io" @@ -15,7 +14,7 @@ import ( ) type MessageHandler interface { - Handle(message sqs.Message) (*runbook.ActionResultPayload, error) + Handle(message JECMessage) (*runbook.ActionResultPayload, error) } type messageHandler struct { @@ -32,9 +31,9 @@ func NewMessageHandler(repositories git.Repositories, actionSpecs conf.ActionSpe } } -func (mh *messageHandler) Handle(message sqs.Message) (*runbook.ActionResultPayload, error) { +func (mh *messageHandler) Handle(message JECMessage) (*runbook.ActionResultPayload, error) { queuePayload := payload{} - err := json.Unmarshal([]byte(*message.Body), &queuePayload) + err := json.Unmarshal([]byte(message.Body), &queuePayload) if err != nil { return nil, err } @@ -45,7 +44,7 @@ func (mh *messageHandler) Handle(message sqs.Message) (*runbook.ActionResultPayl action = queuePayload.Action } if action == "" { - return nil, errors.Errorf("SQS message does not contain action property.") + return nil, errors.Errorf("Message does not contain action property.") } result := &runbook.ActionResultPayload{ @@ -73,7 +72,7 @@ func (mh *messageHandler) Handle(message sqs.Message) (*runbook.ActionResultPayl case *runbook.ExecError: result.IsSuccessful = false result.FailureMessage = fmt.Sprintf("Err: %s, Stderr: %s", err.Error(), err.Stderr) - logrus.Debugf("Action[%s] execution of message[%s] failed: %s Stderr: %s", action, *message.MessageId, err.Error(), err.Stderr) + logrus.Debugf("Action[%s] execution of message[%s] failed: %s Stderr: %s", action, message.MessageId, err.Error(), err.Stderr) case nil: result.IsSuccessful = true if !queuePayload.DiscardScriptResponse && queuePayload.ActionType == HttpActionType { @@ -82,13 +81,13 @@ func (mh *messageHandler) Handle(message sqs.Message) (*runbook.ActionResultPayl if err != nil { result.IsSuccessful = false logrus.Debugf("Http Action[%s] execution of message[%s] failed, could not parse http response fields: %s, error: %s", - action, *message.MessageId, executionResult, err.Error()) + action, message.MessageId, executionResult, err.Error()) result.FailureMessage = "Could not parse http response fields: " + executionResult } else { result.HttpResponse = httpResult } } - logrus.Debugf("Action[%s] execution of message[%s] has been completed and it took %f seconds.", action, *message.MessageId, took.Seconds()) + logrus.Debugf("Action[%s] execution of message[%s] has been completed and it took %f seconds.", action, message.MessageId, took.Seconds()) default: return nil, err @@ -103,19 +102,19 @@ func (mh *messageHandler) resolveMappedAction(action string, actionType string) if !ok { failureMessage := fmt.Sprintf("No mapped action is configured for requested action[%s]. "+ "The request will be ignored.", action) - return nil, errors.Errorf(failureMessage) + return nil, errors.New(failureMessage) } if mappedAction.Type != actionType { failureMessage := fmt.Sprintf("The type[%s] of the mapped action[%s] is not compatible with requested type[%s]. "+ "The request will be ignored.", mappedAction.Type, action, actionType) - return nil, errors.Errorf(failureMessage) + return nil, errors.New(failureMessage) } return &mappedAction, nil } -func (mh *messageHandler) execute(mappedAction *conf.MappedAction, message *sqs.Message) (string, string, error) { +func (mh *messageHandler) execute(mappedAction *conf.MappedAction, message *JECMessage) (string, string, error) { sourceType := mappedAction.SourceType switch sourceType { @@ -135,7 +134,7 @@ func (mh *messageHandler) execute(mappedAction *conf.MappedAction, message *sqs. case conf.LocalSourceType: args := append(mh.actionSpecs.GlobalFlags.Args(), mappedAction.Flags.Args()...) - args = append(args, []string{"-payload", *message.Body}...) + args = append(args, []string{"-payload", message.Body}...) args = append(args, mh.actionSpecs.GlobalArgs...) args = append(args, mappedAction.Args...) env := append(mh.actionSpecs.GlobalEnv, mappedAction.Env...) @@ -151,7 +150,7 @@ func (mh *messageHandler) execute(mappedAction *conf.MappedAction, message *sqs. } stderr := mh.actionLoggers[mappedAction.Stderr] - callbackContext, err := runbook.ExecuteFunc(*message.MessageId, mappedAction.Filepath, args, env, stdout, stderr) + callbackContext, err := runbook.ExecuteFunc(message.MessageId, mappedAction.Filepath, args, env, stdout, stderr) return stdoutBuff.String(), callbackContext, err default: return "", "", errors.Errorf("Unknown action sourceType[%s].", sourceType) diff --git a/queue/message_test.go b/queue/message_test.go index 77cd91f..fcfbe01 100644 --- a/queue/message_test.go +++ b/queue/message_test.go @@ -5,7 +5,6 @@ import ( "github.com/atlassian/jec/conf" "github.com/atlassian/jec/git" "github.com/atlassian/jec/runbook" - "github.com/aws/aws-sdk-go/service/sqs" "github.com/pkg/errors" "github.com/stretchr/testify/assert" "io" @@ -81,8 +80,7 @@ func TestProcess(t *testing.T) { func testProcessSuccessfully(t *testing.T) { body := `{"action":"Create", "requestId": "RequestId"}` - id := "MessageId" - message := sqs.Message{Body: &body, MessageId: &id} + message := JECMessage{Body: body, MessageId: "MessageId"} queueMessage := NewMessageHandler(nil, mockActionSpecs, mockActionLoggers) runbook.ExecuteFunc = func(executionId string, executablePath string, args, environmentVars []string, stdout, stderr io.Writer) (string, error) { @@ -105,8 +103,7 @@ func testProcessHttpActionSuccessfully(t *testing.T) { } body := `{"actionType":"http", "action":"Retrieve", "requestId": "RequestId"}` - id := "MessageId" - message := sqs.Message{Body: &body, MessageId: &id} + message := JECMessage{Body: body, MessageId: "MessageId"} queueMessage := NewMessageHandler(nil, mockActionSpecs, mockActionLoggers) result, err := queueMessage.Handle(message) @@ -127,7 +124,7 @@ func testProcessMappedActionNotFound(t *testing.T) { runbook.ExecuteFunc = mockExecute body := `{"actionType":"custom", "action":"Ack", "requestId": "RequestId"}` - message := sqs.Message{Body: &body} + message := JECMessage{Body: body} messageHandler := NewMessageHandler(nil, mockActionSpecs, mockActionLoggers) result, err := messageHandler.Handle(message) @@ -147,7 +144,7 @@ func testProcessActionTypeNotMatched(t *testing.T) { runbook.ExecuteFunc = mockExecute body := `{"actionType":"http", "action":"Close", "requestId": "RequestId"}` - message := sqs.Message{Body: &body} + message := JECMessage{Body: body} messageHandler := NewMessageHandler(nil, mockActionSpecs, mockActionLoggers) result, err := messageHandler.Handle(message) @@ -170,20 +167,20 @@ func testProcessFieldMissing(t *testing.T) { runbook.ExecuteFunc = mockExecute body := `{"alert":{}}` - message := sqs.Message{Body: &body} + message := JECMessage{Body: body} messageHandler := NewMessageHandler(nil, mockActionSpecs, mockActionLoggers) _, err := messageHandler.Handle(message) - expectedErr := errors.New("SQS message does not contain action property.") + expectedErr := errors.New("Message does not contain action property.") assert.EqualError(t, err, expectedErr.Error()) } // Mock Queue Message type MockMessageHandler struct { - HandleFunc func(message sqs.Message) (*runbook.ActionResultPayload, error) + HandleFunc func(message JECMessage) (*runbook.ActionResultPayload, error) } -func (mqm *MockMessageHandler) Handle(message sqs.Message) (*runbook.ActionResultPayload, error) { +func (mqm *MockMessageHandler) Handle(message JECMessage) (*runbook.ActionResultPayload, error) { if mqm.HandleFunc != nil { return mqm.HandleFunc(message) } diff --git a/queue/poller.go b/queue/poller.go index 20e86ea..575dfed 100644 --- a/queue/poller.go +++ b/queue/poller.go @@ -1,13 +1,17 @@ package queue import ( + "encoding/json" + "fmt" "github.com/atlassian/jec/conf" + "github.com/atlassian/jec/retryer" "github.com/atlassian/jec/util" "github.com/atlassian/jec/worker_pool" - "github.com/aws/aws-sdk-go/service/sqs" "github.com/pkg/errors" "github.com/sirupsen/logrus" "gopkg.in/natefinch/lumberjack.v2" + "io/ioutil" + "net/http" "os" "path/filepath" "strconv" @@ -15,18 +19,19 @@ import ( "time" ) +const messagesPath = "/jsm/ops/jec/v1/messages/channels/" + type Poller interface { Processor - RefreshClient(assumeRoleResult AssumeRoleResult) error - QueueProvider() SQSProvider + ChannelId() string } type poller struct { workerPool worker_pool.WorkerPool - queueProvider SQSProvider messageHandler MessageHandler + retryer *retryer.Retryer - ownerId string + channelId string conf *conf.Configuration queueMessageLogrus *logrus.Logger @@ -38,18 +43,17 @@ type poller struct { } func NewPoller(workerPool worker_pool.WorkerPool, - queueProvider SQSProvider, messageHandler MessageHandler, conf *conf.Configuration, - ownerId string) Poller { + channelId string) Poller { return &poller{ workerPool: workerPool, - queueProvider: queueProvider, messageHandler: messageHandler, - ownerId: ownerId, + retryer: &retryer.Retryer{}, + channelId: channelId, conf: conf, - queueMessageLogrus: newQueueMessageLogrus(queueProvider.Properties().Region()), + queueMessageLogrus: newQueueMessageLogrus(channelId), isRunning: false, isRunningWg: &sync.WaitGroup{}, startStopMu: &sync.Mutex{}, @@ -58,12 +62,8 @@ func NewPoller(workerPool worker_pool.WorkerPool, } } -func (p *poller) QueueProvider() SQSProvider { - return p.queueProvider -} - -func (p *poller) RefreshClient(assumeRoleResult AssumeRoleResult) error { - return p.queueProvider.RefreshClient(assumeRoleResult) +func (p *poller) ChannelId() string { + return p.channelId } func (p *poller) Start() error { @@ -99,21 +99,46 @@ func (p *poller) Stop() error { return nil } -func (p *poller) terminateMessageVisibility(messages []*sqs.Message) { +func (p *poller) fetchMessages(maxNumberOfMessages int64) ([]*JECMessage, error) { - region := p.queueProvider.Properties().Region() + url := fmt.Sprintf("%s%s%s", p.conf.BaseUrl, messagesPath, p.channelId) + + request, err := retryer.NewRequest(http.MethodGet, url, nil) + if err != nil { + return nil, err + } - for i := 0; i < len(messages); i++ { - messageId := *messages[i].MessageId + request.Header.Add("Authorization", "GenieKey "+p.conf.ApiKey) + request.Header.Add("X-JEC-Client-Info", UserAgentHeader) - err := p.queueProvider.ChangeMessageVisibility(messages[i], 0) - if err != nil { - logrus.Warnf("Poller[%s] could not terminate visibility of message[%s]: %s.", region, messageId, err.Error()) - continue - } + response, err := p.retryer.Do(request) + if err != nil { + return nil, err + } + defer response.Body.Close() - logrus.Debugf("Poller[%s] terminated visibility of message[%s].", region, messageId) + if response.StatusCode != http.StatusOK { + body, _ := ioutil.ReadAll(response.Body) + return nil, errors.Errorf("Failed to fetch messages from channel[%s], status: %s, message: %s", p.channelId, response.Status, body) } + + body, err := ioutil.ReadAll(response.Body) + if err != nil { + return nil, err + } + + var messages []*JECMessage + err = json.Unmarshal(body, &messages) + if err != nil { + return nil, err + } + + // Limit to maxNumberOfMessages + if int64(len(messages)) > maxNumberOfMessages { + messages = messages[:maxNumberOfMessages] + } + + return messages, nil } func (p *poller) poll() (shouldWait bool) { @@ -123,45 +148,41 @@ func (p *poller) poll() (shouldWait bool) { return true } - region := p.queueProvider.Properties().Region() maxNumberOfMessages := util.Min(p.conf.PollerConf.MaxNumberOfMessages, int64(availableWorkerCount)) - messages, err := p.queueProvider.ReceiveMessage(maxNumberOfMessages, p.conf.PollerConf.VisibilityTimeoutInSeconds) - if err != nil { // todo check wait time according to error / check error - logrus.Errorf("Poller[%s] could not receive message: %s", region, err.Error()) + messages, err := p.fetchMessages(maxNumberOfMessages) + if err != nil { + logrus.Errorf("Poller[%s] could not fetch messages: %s", p.channelId, err.Error()) return true } messageLength := len(messages) if messageLength == 0 { - logrus.Tracef("There is no new message in the queue[%s].", region) + logrus.Tracef("There is no new message in channel[%s].", p.channelId) return true } - logrus.Debugf("Received %d messages from the queue[%s].", messageLength, region) + logrus.Debugf("Received %d messages from channel[%s].", messageLength, p.channelId) for i := 0; i < messageLength; i++ { p.queueMessageLogrus. - WithField("messageId", *messages[i].MessageId). - Info("Message body: ", *messages[i].Body) + WithField("messageId", messages[i].MessageId). + Info("Message body: ", messages[i].Body) job := newJob( - p.queueProvider, p.messageHandler, *messages[i], p.conf.ApiKey, p.conf.BaseUrl, - p.ownerId, ) isSubmitted, err := p.workerPool.Submit(job) if err != nil { - logrus.Debugf("Error occurred while submitting, messages will be terminated: %s.", err.Error()) - p.terminateMessageVisibility(messages[i:]) + logrus.Debugf("Error occurred while submitting: %s.", err.Error()) return true } else if !isSubmitted { - p.terminateMessageVisibility(messages[i : i+1]) + logrus.Debugf("Job[%s] could not be submitted.", messages[i].MessageId) } } return false @@ -169,8 +190,7 @@ func (p *poller) poll() (shouldWait bool) { func (p *poller) wait(pollingWaitInterval time.Duration) { - queueUrl := p.queueProvider.Properties().Url() - logrus.Tracef("Poller[%s] will wait %s before next polling", queueUrl, pollingWaitInterval.String()) + logrus.Tracef("Poller[%s] will wait %s before next polling", p.channelId, pollingWaitInterval.String()) ticker := time.NewTicker(pollingWaitInterval) defer ticker.Stop() @@ -178,7 +198,7 @@ func (p *poller) wait(pollingWaitInterval time.Duration) { for { select { case <-p.wakeUp: - logrus.Debugf("Poller[%s] has been interrupted while waiting for next polling.", queueUrl) + logrus.Debugf("Poller[%s] has been interrupted while waiting for next polling.", p.channelId) return case <-ticker.C: return @@ -188,32 +208,26 @@ func (p *poller) wait(pollingWaitInterval time.Duration) { func (p *poller) run() { - queueUrl := p.queueProvider.Properties().Url() - logrus.Infof("Poller[%s] has started to run.", queueUrl) + logrus.Infof("Poller[%s] has started to run.", p.channelId) pollingWaitInterval := p.conf.PollerConf.PollingWaitIntervalInMillis * time.Millisecond - expiredTokenWaitInterval := errorRefreshPeriod for { select { case <-p.quit: - logrus.Infof("Poller[%s] has stopped to poll.", queueUrl) + logrus.Infof("Poller[%s] has stopped to poll.", p.channelId) p.isRunningWg.Done() return default: - if p.queueProvider.IsTokenExpired() { - region := p.queueProvider.Properties().Region() - logrus.Warnf("Security token is expired, poller[%s] skips to receive message.", region) - p.wait(expiredTokenWaitInterval) - } else if shouldWait := p.poll(); shouldWait { + if shouldWait := p.poll(); shouldWait { p.wait(pollingWaitInterval) } } } } -func newQueueMessageLogrus(region string) *logrus.Logger { - logFilePath := filepath.Join("/var", "log", "jec", "jecQueueMessages-"+region+"-"+strconv.Itoa(os.Getpid())+".log") +func newQueueMessageLogrus(channelId string) *logrus.Logger { + logFilePath := filepath.Join("/var", "log", "jec", "jecQueueMessages-"+channelId+"-"+strconv.Itoa(os.Getpid())+".log") queueMessageLogger := &lumberjack.Logger{ Filename: logFilePath, MaxSize: 3, // MB diff --git a/queue/poller_test.go b/queue/poller_test.go index c284f6a..e0049f8 100644 --- a/queue/poller_test.go +++ b/queue/poller_test.go @@ -1,20 +1,25 @@ package queue import ( + "bytes" + "encoding/json" "github.com/atlassian/jec/conf" + "github.com/atlassian/jec/retryer" "github.com/atlassian/jec/worker_pool" - "github.com/aws/aws-sdk-go/service/sqs" "github.com/pkg/errors" "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" + "io/ioutil" + "net/http" + "strconv" "sync" "testing" ) var mockPollerConf = &conf.PollerConf{ - pollingWaitIntervalInMillis, - visibilityTimeoutInSec, - maxNumberOfMessages, + PollingWaitIntervalInMillis: pollingWaitIntervalInMillis, + VisibilityTimeoutInSeconds: visibilityTimeoutInSec, + MaxNumberOfMessages: maxNumberOfMessages, } func newPollerTest() *poller { @@ -30,14 +35,27 @@ func newPollerTest() *poller { PollerConf: *mockPollerConf, ActionSpecifications: mockActionSpecs, }, - + channelId: mockChannelId, workerPool: NewMockWorkerPool(), - queueProvider: NewMockQueueProvider(), messageHandler: NewMockMessageHandler(), + retryer: &retryer.Retryer{}, queueMessageLogrus: &logrus.Logger{}, } } +func mockFetchMessages(messages []*JECMessage, err error) func(r *retryer.Retryer, request *retryer.Request) (*http.Response, error) { + return func(r *retryer.Retryer, request *retryer.Request) (*http.Response, error) { + if err != nil { + return nil, err + } + body, _ := json.Marshal(messages) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(body)), + }, nil + } +} + func TestStartAndStopPolling(t *testing.T) { poller := newPollerTest() @@ -84,9 +102,7 @@ func TestPollWithReceiveError(t *testing.T) { poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return 1 } - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = func(i int64, i2 int64) ([]*sqs.Message, error) { - return nil, errors.New("") - } + poller.retryer.DoFunc = mockFetchMessages(nil, errors.New("")) shouldWait := poller.poll() assert.True(t, shouldWait) @@ -99,9 +115,7 @@ func TestPollZeroMessage(t *testing.T) { poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return 1 } - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = func(i int64, i2 int64) ([]*sqs.Message, error) { - return []*sqs.Message{}, nil - } + poller.retryer.DoFunc = mockFetchMessages([]*JECMessage{}, nil) logrus.SetLevel(logrus.DebugLevel) shouldWait := poller.poll() @@ -112,54 +126,84 @@ func TestPollMaxMessage(t *testing.T) { poller := newPollerTest() - expected := 4 + expected := int64(4) poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return int32(expected) } - maxNumberOfMessages := 0 - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = func(numOfMessage int64, visibilityTimeout int64) ([]*sqs.Message, error) { - maxNumberOfMessages = int(numOfMessage) - return nil, errors.New("Receive Error") + messages := make([]*JECMessage, expected) + for i := int64(0); i < expected; i++ { + messages[i] = &JECMessage{ + MessageId: strconv.FormatInt(i, 10), + Body: "body", + ChannelId: mockChannelId, + } + } + poller.retryer.DoFunc = mockFetchMessages(messages, nil) + + submitCount := 0 + poller.workerPool.(*MockWorkerPool).SubmitFunc = func(job worker_pool.Job) (bool, error) { + submitCount++ + return true, nil } shouldWait := poller.poll() - assert.True(t, shouldWait) - assert.Equal(t, expected, maxNumberOfMessages) + assert.False(t, shouldWait) + assert.Equal(t, int(expected), submitCount) } func TestPollMaxMessageUpperBound(t *testing.T) { poller := newPollerTest() - availableWorkerCount := 12 + availableWorkerCount := int64(12) poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return int32(availableWorkerCount) } - maxNumberOfMessages := int64(0) - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = func(numOfMessage int64, visibilityTimeout int64) ([]*sqs.Message, error) { - maxNumberOfMessages = numOfMessage - return nil, errors.New("Receive Error") + // API returns more messages than maxNumberOfMessages, poller should cap + messages := make([]*JECMessage, 20) + for i := 0; i < 20; i++ { + messages[i] = &JECMessage{ + MessageId: strconv.FormatInt(int64(i), 10), + Body: "body", + ChannelId: mockChannelId, + } + } + poller.retryer.DoFunc = mockFetchMessages(messages, nil) + + submitCount := 0 + poller.workerPool.(*MockWorkerPool).SubmitFunc = func(job worker_pool.Job) (bool, error) { + submitCount++ + return true, nil } shouldWait := poller.poll() - assert.True(t, shouldWait) - assert.Equal(t, poller.conf.PollerConf.MaxNumberOfMessages, maxNumberOfMessages) + assert.False(t, shouldWait) + assert.Equal(t, int(poller.conf.PollerConf.MaxNumberOfMessages), submitCount) } func TestPollMessageSubmitFail(t *testing.T) { poller := newPollerTest() - expected := 4 + expected := int64(4) poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return int32(expected) } - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = mockSuccessReceiveFunc + + messages := make([]*JECMessage, expected) + for i := int64(0); i < expected; i++ { + messages[i] = &JECMessage{ + MessageId: strconv.FormatInt(i, 10), + Body: "body", + ChannelId: mockChannelId, + } + } + poller.retryer.DoFunc = mockFetchMessages(messages, nil) submitCount := 0 poller.workerPool.(*MockWorkerPool).SubmitFunc = func(job worker_pool.Job) (bool, error) { @@ -167,31 +211,31 @@ func TestPollMessageSubmitFail(t *testing.T) { return false, nil } - releaseCount := 0 - poller.queueProvider.(*MockSQSProvider).ChangeMessageVisibilityFunc = func(message *sqs.Message, visibilityTimeout int64) error { - if visibilityTimeout == 0 { - releaseCount++ - } - return nil - } - shouldWait := poller.poll() assert.False(t, shouldWait) - assert.Equal(t, expected, submitCount) - assert.Equal(t, expected, releaseCount) + assert.Equal(t, int(expected), submitCount) } func TestPollMessageSubmitError(t *testing.T) { poller := newPollerTest() - expected := 5 + expected := int64(5) poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return int32(expected) } - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = mockSuccessReceiveFunc + + messages := make([]*JECMessage, expected) + for i := int64(0); i < expected; i++ { + messages[i] = &JECMessage{ + MessageId: strconv.FormatInt(i, 10), + Body: "body", + ChannelId: mockChannelId, + } + } + poller.retryer.DoFunc = mockFetchMessages(messages, nil) submitCount := 0 poller.workerPool.(*MockWorkerPool).SubmitFunc = func(job worker_pool.Job) (bool, error) { @@ -199,19 +243,10 @@ func TestPollMessageSubmitError(t *testing.T) { return false, errors.New("Submit Error") } - releaseCount := 0 - poller.queueProvider.(*MockSQSProvider).ChangeMessageVisibilityFunc = func(message *sqs.Message, visibilityTimeout int64) error { - if visibilityTimeout == 0 { - releaseCount++ - } - return nil - } - shouldWait := poller.poll() assert.True(t, shouldWait) assert.Equal(t, 1, submitCount) - assert.Equal(t, expected, releaseCount) } func TestPollMessageSubmitSuccess(t *testing.T) { @@ -221,7 +256,16 @@ func TestPollMessageSubmitSuccess(t *testing.T) { poller.workerPool.(*MockWorkerPool).NumberOfAvailableWorkerFunc = func() int32 { return 5 } - poller.queueProvider.(*MockSQSProvider).ReceiveMessageFunc = mockSuccessReceiveFunc + + messages := make([]*JECMessage, 5) + for i := 0; i < 5; i++ { + messages[i] = &JECMessage{ + MessageId: strconv.FormatInt(int64(i), 10), + Body: "body", + ChannelId: mockChannelId, + } + } + poller.retryer.DoFunc = mockFetchMessages(messages, nil) poller.workerPool.(*MockWorkerPool).SubmitFunc = func(job worker_pool.Job) (bool, error) { return true, nil @@ -236,17 +280,15 @@ func TestPollMessageSubmitSuccess(t *testing.T) { type MockPoller struct { StartPollingFunc func() error StopPollingFunc func() error - - RefreshClientFunc func(assumeRoleResult AssumeRoleResult) error - QueueProviderFunc func() SQSProvider + ChannelIdFunc func() string } func NewMockPoller() Poller { return &MockPoller{} } -func NewMockPollerForQueueProcessor(workerPool worker_pool.WorkerPool, queueProvider SQSProvider, - messageHandler MessageHandler, conf *conf.Configuration, ownerId string) Poller { +func NewMockPollerForQueueProcessor(workerPool worker_pool.WorkerPool, + messageHandler MessageHandler, conf *conf.Configuration, channelId string) Poller { return NewMockPoller() } @@ -264,16 +306,9 @@ func (p *MockPoller) Stop() error { return nil } -func (p *MockPoller) RefreshClient(assumeRoleResult AssumeRoleResult) error { - if p.RefreshClientFunc != nil { - return p.RefreshClientFunc(assumeRoleResult) - } - return nil -} - -func (p *MockPoller) QueueProvider() SQSProvider { - if p.QueueProviderFunc != nil { - return p.QueueProviderFunc() +func (p *MockPoller) ChannelId() string { + if p.ChannelIdFunc != nil { + return p.ChannelIdFunc() } - return NewMockQueueProvider() + return mockChannelId } diff --git a/queue/processor.go b/queue/processor.go index 54ac49a..30a1ecb 100644 --- a/queue/processor.go +++ b/queue/processor.go @@ -1,7 +1,6 @@ package queue import ( - "bytes" "encoding/json" "github.com/atlassian/jec/conf" "github.com/atlassian/jec/git" @@ -13,7 +12,6 @@ import ( "io" "io/ioutil" "net/http" - "strconv" "sync" "time" ) @@ -21,17 +19,20 @@ import ( var UserAgentHeader string const ( - pollingWaitIntervalInMillis = 100 + pollingWaitIntervalInMillis = 1000 visibilityTimeoutInSec = 30 maxNumberOfMessages = 10 - successRefreshPeriod = time.Minute - errorRefreshPeriod = time.Minute - repositoryRefreshPeriod = time.Minute ) -const tokenPath = "/jsm/ops/jec/v1/credentials" +const authenticatePath = "/jsm/ops/jec/v1/authenticate" + +type authenticateResponse struct { + ChannelId string `json:"channelId"` + OwnerId string `json:"ownerId"` + OwnerType string `json:"ownerType"` +} var newPollerFunc = NewPoller @@ -42,7 +43,7 @@ type Processor interface { type processor struct { workerPool worker_pool.WorkerPool - pollers map[string]Poller + poller Poller retryer *retryer.Retryer @@ -50,9 +51,6 @@ type processor struct { repositories git.Repositories actionLoggers map[string]io.Writer - successRefreshPeriod time.Duration - errorRefreshPeriod time.Duration - isRunning bool isRunningWg *sync.WaitGroup startStopMu *sync.Mutex @@ -77,18 +75,15 @@ func NewProcessor(conf *conf.Configuration) Processor { } return &processor{ - successRefreshPeriod: successRefreshPeriod, - errorRefreshPeriod: errorRefreshPeriod, - workerPool: worker_pool.New(&conf.PoolConf), - configuration: conf, - repositories: git.NewRepositories(), - actionLoggers: newActionLoggers(conf.ActionMappings), - pollers: make(map[string]Poller), - quit: make(chan struct{}), - isRunning: false, - isRunningWg: &sync.WaitGroup{}, - startStopMu: &sync.Mutex{}, - retryer: &retryer.Retryer{}, + workerPool: worker_pool.New(&conf.PoolConf), + configuration: conf, + repositories: git.NewRepositories(), + actionLoggers: newActionLoggers(conf.ActionMappings), + quit: make(chan struct{}), + isRunning: false, + isRunningWg: &sync.WaitGroup{}, + startStopMu: &sync.Mutex{}, + retryer: &retryer.Retryer{}, } } @@ -101,9 +96,10 @@ func (qp *processor) Start() error { } logrus.Infof("Queue processor is starting.") - token, err := qp.receiveToken() + + authResp, err := qp.authenticate() if err != nil { - logrus.Errorf("Queue processor could not get initial token and will terminate.") + logrus.Errorf("Queue processor could not authenticate and will terminate.") return err } @@ -114,17 +110,30 @@ func (qp *processor) Start() error { } if qp.repositories.NotEmpty() { - qp.isRunningWg.Add(1) // one for pulling repositories + qp.isRunningWg.Add(1) go qp.startPullingRepositories(repositoryRefreshPeriod) conf.AddRepositoryPathToGitActionFilepaths(qp.configuration.ActionMappings, qp.repositories) } + qp.workerPool.Start() - qp.refreshPollers(token) - qp.isRunningWg.Add(1) // one for receiving token - go qp.run() + + messageHandler := &messageHandler{ + repositories: qp.repositories, + actionSpecs: qp.configuration.ActionSpecifications, + actionLoggers: qp.actionLoggers, + } + + qp.poller = newPollerFunc( + qp.workerPool, + messageHandler, + qp.configuration, + authResp.ChannelId, + ) + qp.poller.Start() qp.isRunning = true + logrus.Infof("Queue processor has started.") return nil } @@ -141,6 +150,10 @@ func (qp *processor) Stop() error { close(qp.quit) qp.isRunningWg.Wait() + if qp.poller != nil { + qp.poller.Stop() + } + qp.workerPool.Stop() qp.repositories.RemoveAll() @@ -149,11 +162,11 @@ func (qp *processor) Stop() error { return nil } -func (qp *processor) receiveToken() (*token, error) { +func (qp *processor) authenticate() (*authenticateResponse, error) { - tokenUrl := qp.configuration.BaseUrl + tokenPath + url := qp.configuration.BaseUrl + authenticatePath - request, err := retryer.NewRequest(http.MethodGet, tokenUrl, nil) + request, err := retryer.NewRequest(http.MethodGet, url, nil) if err != nil { return nil, err } @@ -161,16 +174,6 @@ func (qp *processor) receiveToken() (*token, error) { request.Header.Add("Authorization", "GenieKey "+qp.configuration.ApiKey) request.Header.Add("X-JEC-Client-Info", UserAgentHeader) - query := request.URL.Query() - for _, poller := range qp.pollers { - queueProperties := poller.QueueProvider().Properties() - query.Add( - queueProperties.Region(), - strconv.FormatInt(queueProperties.ExpireTimeMillis(), 10), - ) - } - request.URL.RawQuery = query.Encode() - response, err := qp.retryer.Do(request) if err != nil { return nil, err @@ -182,121 +185,19 @@ func (qp *processor) receiveToken() (*token, error) { return nil, errors.Errorf("Token could not be received from Jira Service Management, status: %s, message: %s", response.Status, body) } - responseToken := bytes.NewBufferString(response.Header.Get("Token")) - - token := &token{} - err = json.NewDecoder(responseToken).Decode(&token) + body, err := io.ReadAll(response.Body) if err != nil { return nil, err } - return token, nil -} - -func (qp *processor) addPoller(queueProperties Properties, ownerId string) (Poller, error) { - - queueProvider, err := NewSqsProvider(queueProperties) + authResp := &authenticateResponse{} + err = json.Unmarshal(body, authResp) if err != nil { return nil, err } - messageHandler := &messageHandler{ - repositories: qp.repositories, - actionSpecs: qp.configuration.ActionSpecifications, - actionLoggers: qp.actionLoggers, - } - - poller := newPollerFunc( - qp.workerPool, - queueProvider, - messageHandler, - qp.configuration, - ownerId, - ) - qp.pollers[queueProvider.Properties().Url()] = poller - return poller, nil -} - -func (qp *processor) removePoller(queueUrl string) Poller { - poller := qp.pollers[queueUrl] - delete(qp.pollers, queueUrl) - return poller -} - -func (qp *processor) refreshPollers(token *token) { - pollerKeys := make(map[string]struct{}, len(qp.pollers)) - for key := range qp.pollers { - pollerKeys[key] = struct{}{} - } - - for _, queueProperties := range token.QueuePropertiesList { - queueUrl := queueProperties.Url() - - // refresh existing pollers if there comes new AssumeRoleResult - if poller, contains := qp.pollers[queueUrl]; contains { - isTokenRefreshed := queueProperties.AssumeRoleResult != AssumeRoleResult{} - if isTokenRefreshed { - err := poller.RefreshClient(queueProperties.AssumeRoleResult) - if err != nil { - logrus.Errorf("Client of queue provider[%s] could not be refreshed.", queueUrl) - } - logrus.Infof("Client of queue provider[%s] has refreshed.", queueUrl) - } - delete(pollerKeys, queueUrl) - - // add new pollers - } else { - poller, err := qp.addPoller(queueProperties, token.OwnerId) - if err != nil { - logrus.Errorf("Poller[%s] could not be added: %s.", queueUrl, err) - continue - } - poller.Start() - logrus.Debugf("Poller[%s] is added.", queueUrl) - } - } - - // remove unnecessary pollers - for queueUrl := range pollerKeys { - qp.removePoller(queueUrl).Stop() - logrus.Debugf("Poller[%s] is removed.", queueUrl) - } - - if len(token.QueuePropertiesList) != 0 { // pick first Properties to refresh waitPeriods, can be change for further usage - qp.successRefreshPeriod = time.Second * time.Duration(token.QueuePropertiesList[0].Configuration.SuccessRefreshPeriodInSeconds) - qp.errorRefreshPeriod = time.Second * time.Duration(token.QueuePropertiesList[0].Configuration.ErrorRefreshPeriodInSeconds) - } -} - -func (qp *processor) run() { - - logrus.Infof("Queue processor has started to run. Refresh client period: %s.", qp.successRefreshPeriod.String()) - - ticker := time.NewTicker(qp.successRefreshPeriod) - - for { - select { - case <-qp.quit: - ticker.Stop() - for _, poller := range qp.pollers { - poller.Stop() - } - qp.isRunningWg.Done() - return - case <-ticker.C: - ticker.Stop() - token, err := qp.receiveToken() - if err != nil { - logrus.Warnf("Refresh cycle of queue processor has failed: %s", err) - logrus.Debugf("Will refresh token after %s", qp.errorRefreshPeriod.String()) - ticker = time.NewTicker(qp.errorRefreshPeriod) - break - } - qp.refreshPollers(token) - - ticker = time.NewTicker(qp.successRefreshPeriod) - } - } + logrus.Infof("Successfully authenticated. ChannelId: %s, OwnerId: %s, OwnerType: %s", authResp.ChannelId, authResp.OwnerId, authResp.OwnerType) + return authResp, nil } func (qp *processor) startPullingRepositories(pullPeriod time.Duration) { diff --git a/queue/processor_test.go b/queue/processor_test.go index f4dfdf7..43cffe0 100644 --- a/queue/processor_test.go +++ b/queue/processor_test.go @@ -11,8 +11,6 @@ import ( "github.com/stretchr/testify/assert" "io/ioutil" "net/http" - "strconv" - "strings" "sync" "testing" "time" @@ -32,51 +30,31 @@ var mockPoolConf = &conf.PoolConf{ func newQueueProcessorTest() *processor { return &processor{ - successRefreshPeriod: successRefreshPeriod, - errorRefreshPeriod: errorRefreshPeriod, - workerPool: NewMockWorkerPool(), - configuration: mockConf, - repositories: git.NewRepositories(), - pollers: make(map[string]Poller), - quit: make(chan struct{}), - isRunning: false, - isRunningWg: &sync.WaitGroup{}, - startStopMu: &sync.Mutex{}, - retryer: &retryer.Retryer{}, + workerPool: NewMockWorkerPool(), + configuration: mockConf, + repositories: git.NewRepositories(), + quit: make(chan struct{}), + isRunning: false, + isRunningWg: &sync.WaitGroup{}, + startStopMu: &sync.Mutex{}, + retryer: &retryer.Retryer{}, } } -var mockPollers = map[string]Poller{ - mockQueueUrl1: NewMockPoller(), - mockQueueUrl2: NewMockPoller(), -} - -func mockHttpGetError(retryer *retryer.Retryer, request *retryer.Request) (*http.Response, error) { - return nil, errors.New("Test http error has occurred while getting token.") -} - -func mockHttpGet(retryer *retryer.Retryer, request *retryer.Request) (*http.Response, error) { - - token, _ := json.Marshal(mockToken) - - header := http.Header{} - header.Add("Token", string(token)) - - response := &http.Response{ - StatusCode: 200, - Header: header, - Body: ioutil.NopCloser(nil), +func mockAuthenticateSuccess(r *retryer.Retryer, request *retryer.Request) (*http.Response, error) { + authResp := authenticateResponse{ + ChannelId: mockChannelId, + OwnerId: "mockOwnerId", } - - return response, nil + body, _ := json.Marshal(authResp) + return &http.Response{ + StatusCode: http.StatusOK, + Body: ioutil.NopCloser(bytes.NewReader(body)), + }, nil } -func mockHttpGetInvalidJson(retryer *retryer.Retryer, request *retryer.Request) (*http.Response, error) { - - response := &http.Response{} - response.Body = ioutil.NopCloser(bytes.NewBufferString(`{"Invalid json": }`)) - - return response, nil +func mockAuthenticateError(r *retryer.Retryer, request *retryer.Request) (*http.Response, error) { + return nil, errors.New("Test http error has occurred while authenticating.") } func TestValidateNewQueueProcessor(t *testing.T) { @@ -96,19 +74,20 @@ func TestStartAndStopQueueProcessor(t *testing.T) { processor := newQueueProcessorTest() - processor.retryer.DoFunc = mockHttpGet + processor.retryer.DoFunc = mockAuthenticateSuccess newPollerFunc = NewMockPollerForQueueProcessor err := processor.Start() assert.Nil(t, err) - - assert.Equal(t, 2, len(processor.pollers)) + assert.True(t, processor.isRunning) + assert.NotNil(t, processor.poller) err = processor.Stop() assert.Nil(t, err) + assert.False(t, processor.isRunning) } -func TestStartQueueProcessorAndRefresh(t *testing.T) { +func TestStartQueueProcessorAuthenticationError(t *testing.T) { defer func() { newPollerFunc = NewPoller @@ -116,37 +95,13 @@ func TestStartQueueProcessorAndRefresh(t *testing.T) { processor := newQueueProcessorTest() - processor.retryer.DoFunc = mockHttpGet - processor.successRefreshPeriod = time.Nanosecond - newPollerFunc = NewMockPollerForQueueProcessor - - err := processor.Start() - assert.Nil(t, err) - - time.Sleep(time.Nanosecond * 100) - - assert.Equal(t, 2, len(processor.pollers)) - assert.Equal(t, successRefreshPeriod, processor.successRefreshPeriod) - - err = processor.Stop() - assert.Nil(t, err) -} - -func TestStartQueueProcessorInitialError(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - processor.retryer.DoFunc = mockHttpGetError + processor.retryer.DoFunc = mockAuthenticateError newPollerFunc = NewMockPollerForQueueProcessor err := processor.Start() assert.NotNil(t, err) - assert.Equal(t, "Test http error has occurred while getting token.", err.Error()) + assert.Equal(t, "Test http error has occurred while authenticating.", err.Error()) } func TestStopQueueProcessorWhileNotRunning(t *testing.T) { @@ -159,213 +114,48 @@ func TestStopQueueProcessorWhileNotRunning(t *testing.T) { assert.Equal(t, "Queue processor is not running.", err.Error()) } -func TestReceiveToken(t *testing.T) { +func TestAuthenticate(t *testing.T) { processor := newQueueProcessorTest() - processor.pollers = mockPollers - var actualRequest *http.Request - processor.retryer.DoFunc = func(retryer *retryer.Retryer, request *retryer.Request) (*http.Response, error) { + processor.retryer.DoFunc = func(r *retryer.Retryer, request *retryer.Request) (*http.Response, error) { actualRequest = request.Request - return mockHttpGet(retryer, request) + return mockAuthenticateSuccess(r, request) } - token, err := processor.receiveToken() + authResp, err := processor.authenticate() assert.Nil(t, err) - assert.Equal(t, 2, len(token.QueuePropertiesList)) - assert.Equal(t, "accessKeyId1", token.QueuePropertiesList[0].AssumeRoleResult.Credentials.AccessKeyId) - assert.Equal(t, "accessKeyId2", token.QueuePropertiesList[1].AssumeRoleResult.Credentials.AccessKeyId) - - for _, poller := range processor.pollers { - queueProperties := poller.QueueProvider().Properties() - expectedQuery := queueProperties.Region() + "=" + strconv.FormatInt(queueProperties.ExpireTimeMillis(), 10) - - assert.True(t, strings.Contains(actualRequest.URL.RawQuery, expectedQuery)) - } - - assert.Equal(t, "/jsm/ops/jec/v1/credentials", actualRequest.URL.Path) -} - -func TestReceiveTokenInvalidJson(t *testing.T) { - - processor := newQueueProcessorTest() - processor.retryer.DoFunc = mockHttpGetInvalidJson - - _, err := processor.receiveToken() - - assert.NotNil(t, err) + assert.Equal(t, mockChannelId, authResp.ChannelId) + assert.Contains(t, actualRequest.URL.Path, authenticatePath) } -func TestReceiveTokenGetError(t *testing.T) { +func TestAuthenticateError(t *testing.T) { processor := newQueueProcessorTest() - processor.retryer.DoFunc = mockHttpGetError + processor.retryer.DoFunc = mockAuthenticateError - _, err := processor.receiveToken() + _, err := processor.authenticate() assert.NotNil(t, err) - assert.Equal(t, "Test http error has occurred while getting token.", err.Error()) + assert.Equal(t, "Test http error has occurred while authenticating.", err.Error()) } -func TestReceiveTokenRequestError(t *testing.T) { +func TestAuthenticateRequestError(t *testing.T) { processor := newQueueProcessorTest() - processor.configuration.BaseUrl = "invalid" + processor.configuration = &conf.Configuration{ + BaseUrl: "invalid", + } - _, err := processor.receiveToken() + _, err := processor.authenticate() assert.NotNil(t, err) - assert.Contains(t, err.Error(), "invalid"+tokenPath+"") assert.Contains(t, err.Error(), "unsupported protocol scheme") } -func TestAddTwoDifferentPollersTest(t *testing.T) { - - processor := newQueueProcessorTest() - - p1, _ := processor.addPoller(mockQueueProperties1, mockOwnerId) - poller1 := p1.(*poller) - - mockQueueProvider2 := NewMockQueueProvider().(*MockSQSProvider) - mockQueueProvider2.QueuePropertiesFunc = func() Properties { - return mockQueueProperties2 - } - - processor.addPoller(mockQueueProperties2, mockOwnerId) - - assert.Equal(t, mockQueueProperties1, poller1.QueueProvider().Properties()) - assert.Equal(t, processor.configuration.PollerConf, poller1.conf.PollerConf) - - _, contains := processor.pollers[mockQueueProperties1.Url()] - assert.True(t, contains) - - assert.Equal(t, 2, len(processor.pollers)) -} - -func TestRemovePollerTest(t *testing.T) { - - processor := newQueueProcessorTest() - - processor.pollers = mockPollers - - poller := processor.removePoller(mockQueueUrl1) - processor.removePoller(mockQueueUrl2) - - assert.Equal(t, mockQueueProperties1.Url(), poller.QueueProvider().Properties().Url()) - - assert.Equal(t, 0, len(processor.pollers)) -} - -func TestRefreshPollersRepeat(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - newPollerFunc = NewMockPollerForQueueProcessor - - processor.refreshPollers(&mockToken) - processor.refreshPollers(&mockToken) - processor.refreshPollers(&mockToken) - - assert.Equal(t, 2, len(processor.pollers)) -} - -func TestRefreshPollersAddAndRemove(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - newPollerFunc = NewMockPollerForQueueProcessor - - processor.refreshPollers(&mockToken) - processor.refreshPollers(&mockEmptyToken) - - assert.Equal(t, 0, len(processor.pollers)) -} - -func TestRefreshPollersAdd(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - newPollerFunc = NewMockPollerForQueueProcessor - - processor.refreshPollers(&mockEmptyToken) - processor.refreshPollers(&mockToken) - - assert.Equal(t, 2, len(processor.pollers)) -} - -func TestRefreshPollersWithNotHavingPoller(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - newPollerFunc = NewMockPollerForQueueProcessor - - processor.refreshPollers(&mockToken) - processor.refreshPollers(&mockToken) - processor.refreshPollers(&mockToken) - - assert.Equal(t, 2, len(processor.pollers)) -} - -func TestRefreshOldPollersAlreadyHavingPollers(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - newPollerFunc = NewMockPollerForQueueProcessor - processor.pollers = mockPollers - - processor.refreshPollers(&mockToken) - - assert.Equal(t, 2, len(processor.pollers)) -} - -func TestRefreshPollersWithEmptyAssumeRoleResult(t *testing.T) { - - defer func() { - newPollerFunc = NewPoller - }() - - processor := newQueueProcessorTest() - - newPollerFunc = NewMockPollerForQueueProcessor - processor.pollers = mockPollers - - processor.refreshPollers(&mockTokenWithEmptyAssumeRoleResult) - - assert.Equal(t, 2, len(processor.pollers)) -} - -func TestRefreshPollerWithEmptyToken(t *testing.T) { - - processor := newQueueProcessorTest() - - processor.refreshPollers(&mockEmptyToken) - - assert.Equal(t, 0, len(processor.pollers)) -} - // Mock QueueProcessor type MockQueueProcessor struct { StartProcessingFunc func() error