diff --git a/ecs/pkg/amazon/down_test.go b/ecs/pkg/amazon/down_test.go index 49e5fe2a..bf98a6b9 100644 --- a/ecs/pkg/amazon/down_test.go +++ b/ecs/pkg/amazon/down_test.go @@ -19,8 +19,10 @@ func TestDownDontDeleteCluster(t *testing.T) { } ctx := context.TODO() recorder := m.EXPECT() - recorder.DeleteStack(ctx, "test_project").Return(nil).Times(1) - recorder.WaitStackComplete(ctx, "test_project", gomock.Any()).Return(nil).Times(1) + recorder.DeleteStack(ctx, "test_project").Return(nil) + recorder.GetStackID(ctx, "test_project").Return("stack-123", nil) + recorder.WaitStackComplete(ctx, "stack-123", StackDelete).Return(nil) + recorder.DescribeStackEvents(ctx, "stack-123").Return(nil, nil) c.ComposeDown(ctx, "test_project", false) } @@ -37,9 +39,11 @@ func TestDownDeleteCluster(t *testing.T) { ctx := context.TODO() recorder := m.EXPECT() - recorder.DeleteStack(ctx, "test_project").Return(nil).Times(1) - recorder.WaitStackComplete(ctx, "test_project", gomock.Any()).Return(nil).Times(1) - recorder.DeleteCluster(ctx, "test_cluster").Return(nil).Times(1) + recorder.DeleteStack(ctx, "test_project").Return(nil) + recorder.GetStackID(ctx, "test_project").Return("stack-123", nil) + recorder.WaitStackComplete(ctx, "stack-123", StackDelete).Return(nil) + recorder.DescribeStackEvents(ctx, "stack-123").Return(nil, nil) + recorder.DeleteCluster(ctx, "test_cluster").Return(nil) c.ComposeDown(ctx, "test_project", true) } diff --git a/ecs/pkg/amazon/wait.go b/ecs/pkg/amazon/wait.go index 095fcc6b..58ae93d7 100644 --- a/ecs/pkg/amazon/wait.go +++ b/ecs/pkg/amazon/wait.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "sort" + "strings" "time" "github.com/aws/aws-sdk-go/aws" @@ -22,21 +23,20 @@ func (c *client) WaitStackCompletion(ctx context.Context, name string, operation } ticker := time.NewTicker(1 * time.Second) - done := make(chan error) + done := make(chan bool) go func() { - err := c.api.WaitStackComplete(ctx, name, operation) + c.api.WaitStackComplete(ctx, stackID, operation) //nolint:errcheck ticker.Stop() - done <- err + done <- true }() var completed bool - var waitErr error + var stackErr error for !completed { select { - case err := <-done: + case <-done: completed = true - waitErr = err case <-ticker.C: } events, err := c.api.DescribeStackEvents(ctx, stackID) @@ -55,10 +55,15 @@ func (c *client) WaitStackCompletion(ctx context.Context, name string, operation knownEvents[*event.EventId] = struct{}{} resource := fmt.Sprintf("%s %q", aws.StringValue(event.ResourceType), aws.StringValue(event.LogicalResourceId)) - w.ResourceEvent(resource, aws.StringValue(event.ResourceStatus), aws.StringValue(event.ResourceStatusReason)) + reason := aws.StringValue(event.ResourceStatusReason) + status := aws.StringValue(event.ResourceStatus) + w.ResourceEvent(resource, status, reason) + if stackErr == nil && strings.HasSuffix(status, "_FAILED") { + stackErr = fmt.Errorf(reason) + } } } - return waitErr + return stackErr } type waitAPI interface {