diff --git a/resources/aws_ec2.go b/resources/aws_ec2.go index 7441e3ae..275e1bc7 100644 --- a/resources/aws_ec2.go +++ b/resources/aws_ec2.go @@ -23,6 +23,7 @@ import ( "fmt" "io/ioutil" "log" + "net" "net/http" "strconv" "sync" @@ -510,9 +511,17 @@ func (obj *AwsEc2Res) snsWatch() error { var exit *error defer obj.wg.Wait() defer close(obj.closeChan) - // set up the sns endpoint - snsServer := obj.snsServer() - // shutdown the sns endpoint when we're done + // create the sns listener + // closing is handled by http.Server.Shutdown in the defer func below + listener, err := obj.snsListener(obj.WatchListenAddr) + if err != nil { + return errwrap.Wrapf(err, "error creating listener") + } + // set up the sns server + snsServer := &http.Server{ + Handler: http.HandlerFunc(obj.snsPostHandler), + } + // close the listener and shutdown the sns server when we're done defer func() { ctx, cancel := context.WithTimeout(context.TODO(), SnsServerShutdownTimeout*time.Second) defer cancel() @@ -525,11 +534,11 @@ func (obj *AwsEc2Res) snsWatch() error { } }() obj.wg.Add(1) - // start the endpoint + // start the sns server go func() { defer obj.wg.Done() defer close(obj.awsChan) - if err := snsServer.ListenAndServe(); err != nil { + if err := snsServer.Serve(listener); err != nil { // when we shut down if err == http.ErrServerClosed { log.Printf("%s: Stopped SNS Endpoint", obj) @@ -802,18 +811,18 @@ func (obj *AwsEc2Res) prependName() string { return AwsPrefix + obj.GetName() } -// snsServer returns an http server used to listen for sns messages. -func (obj *AwsEc2Res) snsServer() *http.Server { - addr := obj.WatchListenAddr - // if addr is a port - if _, err := strconv.Atoi(obj.WatchListenAddr); err == nil { - addr = fmt.Sprintf(":%s", obj.WatchListenAddr) +// snsListener returns a listener bound to listenAddr. +func (obj *AwsEc2Res) snsListener(listenAddr string) (net.Listener, error) { + addr := listenAddr + // if listenAddr is a port + if _, err := strconv.Atoi(listenAddr); err == nil { + addr = fmt.Sprintf(":%s", listenAddr) } - handler := http.HandlerFunc(obj.snsPostHandler) - return &http.Server{ - Addr: addr, - Handler: handler, + listener, err := net.Listen("tcp", addr) + if err != nil { + return nil, err } + return listener, nil } // snsPostHandler listens for posts on the SNS Endpoint.