diff --git a/resources/aws_ec2.go b/resources/aws_ec2.go index 07c7d9b6..dc8c50ac 100644 --- a/resources/aws_ec2.go +++ b/resources/aws_ec2.go @@ -466,46 +466,25 @@ func (obj *AwsEc2Res) longpollWatch() error { }, } log.Printf("%s: Watching: %s", obj, *diOutput.Reservations[0].Instances[0].InstanceId) - if err := obj.client.WaitUntilInstanceStoppedWithContext(ctx, waitInput); err != nil { - if aerr, ok := err.(awserr.Error); ok { - if aerr.Code() == request.CanceledErrorCode { - log.Printf("%s: Request cancelled", obj) - } - if aerr.Code() == request.WaiterResourceNotReadyErrorCode { - continue - } - } - select { - case obj.awsChan <- &chanStruct{ - err: errwrap.Wrapf(err, "unknown error waiting for instance to stop"), - }: - case <-obj.closeChan: - } - return - } - stateOutput, err := obj.client.DescribeInstances(diInput) + // Wait for instance to stop or + // terminate concurrently. + event, err := longpollRunningWaiter(ctx, waitInput, obj.client) if err != nil { select { case obj.awsChan <- &chanStruct{ - err: errwrap.Wrapf(err, "error describing instances"), + err: errwrap.Wrapf(err, "internal waiter error"), }: case <-obj.closeChan: } return } - var stateName string - if len(stateOutput.Reservations) == 1 { - stateName = *stateOutput.Reservations[0].Instances[0].State.Name - } - if len(stateOutput.Reservations) == 0 || (len(stateOutput.Reservations) == 1 && stateName != "running") { - select { - case obj.awsChan <- &chanStruct{ - event: awsEc2EventInstanceStopped, - }: - case <-obj.closeChan: - return - } + select { + case obj.awsChan <- &chanStruct{ + event: event, + }: + case <-obj.closeChan: } + return } } if obj.State == "stopped" { @@ -641,11 +620,14 @@ func (obj *AwsEc2Res) longpollWatch() error { } case msg, ok := <-obj.awsChan: if !ok { - return *exit + return fmt.Errorf("channel closed unexpectedly") } if err := msg.err; err != nil { return err } + if msg.event == awsEc2EventNone { + continue + } log.Printf("%s: State: %v", obj, msg.event) obj.StateOK(false) send = true @@ -1014,6 +996,101 @@ func (obj *AwsEc2Res) prependName() string { return AwsPrefix + obj.GetName() } +// longpollRunningWaiter waits for the instance to stop and waits for it to +// terminate. If either waiter returns, the instance state is checked, and an +// awsEc2Event is returned. +func longpollRunningWaiter(ctx context.Context, waitInput *ec2.DescribeInstancesInput, c *ec2.EC2) (awsEc2Event, error) { + if err := waitUntilInstanceStoppedOrTerminatedWithContext(ctx, waitInput, c); err != nil { + return awsEc2EventNone, errwrap.Wrapf(err, "error waiting for instance to stop or terminate") + } + // Check the instance state, and return the appropriate event. + stateOutput, err := c.DescribeInstances(waitInput) + if err != nil { + return awsEc2EventNone, errwrap.Wrapf(err, "error describing instances") + } + if len(stateOutput.Reservations) == 1 { + switch *stateOutput.Reservations[0].Instances[0].State.Name { + case "stopped": + return awsEc2EventInstanceStopped, nil + case "terminated": + return awsEc2EventInstanceTerminated, nil + } + } + return awsEc2EventNone, nil +} + +// waitUntilInstanceStoppedOrTerminatedWithContext combines the two waiters +// required to trigger events if a running instance is stopped or terminated. +// This function is needed, because the AWS api only provides waiters for one +// or the other. +func waitUntilInstanceStoppedOrTerminatedWithContext(ctx context.Context, waitInput *ec2.DescribeInstancesInput, c *ec2.EC2) error { + errChan := make(chan error) + defer close(errChan) // unnecessary, but nice to have + wg := sync.WaitGroup{} + defer wg.Wait() + + closeChan := make(chan struct{}) + innerCtx, cancel := context.WithCancel(context.TODO()) // uncoupled!! + + once := &sync.Once{} + closer := func() { + cancel() + close(closeChan) + } + defer once.Do(closer) // needed if we exit below due to ctx being cancelled + + wg.Add(1) + go func() { + defer wg.Done() + defer once.Do(closer) + + err := c.WaitUntilInstanceStoppedWithContext(innerCtx, waitInput) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + if aerr.Code() == request.CanceledErrorCode || aerr.Code() == request.WaiterResourceNotReadyErrorCode { + err = nil // we want to ignore these kinds of errors + } + } + } + + select { + case errChan <- errwrap.Wrapf(err, "unknown error waiting for instance to stop"): + case <-closeChan: + } + }() + + wg.Add(1) + go func() { + defer wg.Done() + defer once.Do(closer) + + err := c.WaitUntilInstanceTerminatedWithContext(innerCtx, waitInput) + if err != nil { + if aerr, ok := err.(awserr.Error); ok { + if aerr.Code() == request.CanceledErrorCode || aerr.Code() == request.WaiterResourceNotReadyErrorCode { + err = nil + } + } + } + + select { + case errChan <- errwrap.Wrapf(err, "unknown error waiting for instance to terminate"): + case <-closeChan: + } + }() + + select { + case err, ok := <-errChan: + if !ok { + return fmt.Errorf("channel closed unexpectedly") + } + return err // return either nil or an error + case <-ctx.Done(): // if ctx is canceled, we need to transmit that error + // TODO: should we instead use the aws context copy and request.CanceledErrorCode ? + return ctx.Err() + } +} + // snsListener returns a listener bound to listenAddr. func (obj *AwsEc2Res) snsListener(listenAddr string) (net.Listener, error) { addr := listenAddr