resources: aws: ec2: Use custom listener for snsServer

This patch replaces the call to Server.ListenAndServe() with
Server.Serve(listener) in order to make sure the listener is up
and running before we subscribe to the topic in a future patch.
This commit is contained in:
Jonathan Gold
2017-11-14 17:10:43 -05:00
committed by James Shubin
parent 12fce52cd7
commit 966172eac6

View File

@@ -23,6 +23,7 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"net/http"
"strconv"
"sync"
@@ -510,9 +511,17 @@ func (obj *AwsEc2Res) snsWatch() error {
var exit *error
defer obj.wg.Wait()
defer close(obj.closeChan)
// set up the sns endpoint
snsServer := obj.snsServer()
// shutdown the sns endpoint when we're done
// create the sns listener
// closing is handled by http.Server.Shutdown in the defer func below
listener, err := obj.snsListener(obj.WatchListenAddr)
if err != nil {
return errwrap.Wrapf(err, "error creating listener")
}
// set up the sns server
snsServer := &http.Server{
Handler: http.HandlerFunc(obj.snsPostHandler),
}
// close the listener and shutdown the sns server when we're done
defer func() {
ctx, cancel := context.WithTimeout(context.TODO(), SnsServerShutdownTimeout*time.Second)
defer cancel()
@@ -525,11 +534,11 @@ func (obj *AwsEc2Res) snsWatch() error {
}
}()
obj.wg.Add(1)
// start the endpoint
// start the sns server
go func() {
defer obj.wg.Done()
defer close(obj.awsChan)
if err := snsServer.ListenAndServe(); err != nil {
if err := snsServer.Serve(listener); err != nil {
// when we shut down
if err == http.ErrServerClosed {
log.Printf("%s: Stopped SNS Endpoint", obj)
@@ -802,18 +811,18 @@ func (obj *AwsEc2Res) prependName() string {
return AwsPrefix + obj.GetName()
}
// snsServer returns an http server used to listen for sns messages.
func (obj *AwsEc2Res) snsServer() *http.Server {
addr := obj.WatchListenAddr
// if addr is a port
if _, err := strconv.Atoi(obj.WatchListenAddr); err == nil {
addr = fmt.Sprintf(":%s", obj.WatchListenAddr)
// snsListener returns a listener bound to listenAddr.
func (obj *AwsEc2Res) snsListener(listenAddr string) (net.Listener, error) {
addr := listenAddr
// if listenAddr is a port
if _, err := strconv.Atoi(listenAddr); err == nil {
addr = fmt.Sprintf(":%s", listenAddr)
}
handler := http.HandlerFunc(obj.snsPostHandler)
return &http.Server{
Addr: addr,
Handler: handler,
listener, err := net.Listen("tcp", addr)
if err != nil {
return nil, err
}
return listener, nil
}
// snsPostHandler listens for posts on the SNS Endpoint.