resources: aws: ec2: Verify SNS message signatures

This commit is contained in:
Jonathan Gold
2017-11-23 15:37:49 -05:00
committed by James Shubin
parent 388a08e13a
commit e330ebc8c9

View File

@@ -19,12 +19,16 @@ package resources
import ( import (
"context" "context"
"crypto/x509"
"encoding/base64" "encoding/base64"
"encoding/json" "encoding/json"
"encoding/pem"
"fmt" "fmt"
"io/ioutil"
"log" "log"
"net" "net"
"net/http" "net/http"
"regexp"
"strconv" "strconv"
"sync" "sync"
"time" "time"
@@ -69,6 +73,11 @@ const (
SnsPolicyService = "events.amazonaws.com" SnsPolicyService = "events.amazonaws.com"
// SnsPolicyAction is the specific permission we are granting in the policy. // SnsPolicyAction is the specific permission we are granting in the policy.
SnsPolicyAction = "SNS:Publish" SnsPolicyAction = "SNS:Publish"
// SnsCertURLRegex is used to make sure we only download certificates
// from amazon. This regex will match "https://sns.***.amazonaws.com/"
// where *** represents any combination of words and hyphens, and will
// match any aws region name, eg: ca-central-1.
SnsCertURLRegex = `(^https:\/\/sns\.([\w\-])+\.amazonaws.com\/)`
// CwePrefix gets prepended onto the cloudwatch rule name. // CwePrefix gets prepended onto the cloudwatch rule name.
CwePrefix = Ec2Prefix + "cw-" CwePrefix = Ec2Prefix + "cw-"
// CweRuleName is the name of the rule created by makeCloudWatchRule. // CweRuleName is the name of the rule created by makeCloudWatchRule.
@@ -133,6 +142,12 @@ type AwsEc2Res struct {
WatchEndpoint string `yaml:"watchendpoint"` // the public url of the sns endpoint, eg: http://server:12345/ WatchEndpoint string `yaml:"watchendpoint"` // the public url of the sns endpoint, eg: http://server:12345/
WatchListenAddr string `yaml:"watchlistenaddr"` // the local address or port that the sns listens on, eg: 10.0.0.0:23456 or 23456 WatchListenAddr string `yaml:"watchlistenaddr"` // the local address or port that the sns listens on, eg: 10.0.0.0:23456 or 23456
// ErrorOnMalformedPost controls whether or not malformed HTTP post
// requests, that cause JSON decoder errors, will also make the engine
// shut down. If ErrorOnMalformedPost set to true and an error occurs,
// Watch() will return the error and the engine will shut down.
ErrorOnMalformedPost bool `yaml:"erroronmalformedpost"`
// UserData is used to run bash and cloud-init commands on first launch. // UserData is used to run bash and cloud-init commands on first launch.
// See http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html // See http://docs.aws.amazon.com/AWSEC2/latest/UserGuide/user-data.html
// for documantation and examples. // for documantation and examples.
@@ -199,9 +214,16 @@ type ruleDetail struct {
// postData is the format of the messages received and decoded by snsPostHandler(). // postData is the format of the messages received and decoded by snsPostHandler().
type postData struct { type postData struct {
Type string `json:"Type"` Type string `json:"Type"`
Token string `json:"Token"` MessageID string `json:"MessageId"`
Message string `json:"Message"` Token string `json:"Token"`
TopicArn string `json:"TopicArn"`
Message string `json:"Message"`
SubscribeURL string `json:"SubscribeURL"`
Timestamp string `json:"Timestamp"`
SignatureVersion string `json:"SignatureVersion"`
Signature string `json:"Signature"`
SigningCertURL string `json:"SigningCertURL"`
} }
// postMsg is used to unmarshal the postData message if it's an event notification. // postMsg is used to unmarshal the postData message if it's an event notification.
@@ -920,15 +942,18 @@ func (obj *AwsEc2Res) Compare(r Res) bool {
if obj.ImageID != res.ImageID { if obj.ImageID != res.ImageID {
return false return false
} }
if obj.UserData != res.UserData {
return false
}
if obj.WatchEndpoint != res.WatchEndpoint { if obj.WatchEndpoint != res.WatchEndpoint {
return false return false
} }
if obj.WatchListenAddr != res.WatchListenAddr { if obj.WatchListenAddr != res.WatchListenAddr {
return false return false
} }
if obj.ErrorOnMalformedPost != res.ErrorOnMalformedPost {
return false
}
if obj.UserData != res.UserData {
return false
}
return true return true
} }
@@ -971,25 +996,37 @@ func (obj *AwsEc2Res) snsListener(listenAddr string) (net.Listener, error) {
} }
// snsPostHandler listens for posts on the SNS Endpoint. // snsPostHandler listens for posts on the SNS Endpoint.
// TODO: download pem and check message against signature
func (obj *AwsEc2Res) snsPostHandler(w http.ResponseWriter, req *http.Request) { func (obj *AwsEc2Res) snsPostHandler(w http.ResponseWriter, req *http.Request) {
if req.Method != "POST" { if req.Method != "POST" {
http.Error(w, "Invalid request method", http.StatusMethodNotAllowed) http.Error(w, "Invalid request method", http.StatusMethodNotAllowed)
return return
} }
// decode json // Decode the post. If an error is produced we either ignore the post,
// or if ErrorOnMalformedPost is true, send the error through awsChan so
// Watch() can return the error and the engine can shut down.
decoder := json.NewDecoder(req.Body) decoder := json.NewDecoder(req.Body)
var post postData var post postData
if err := decoder.Decode(&post); err != nil { if err := decoder.Decode(&post); err != nil {
log.Printf("%s: error decoding post: %s", obj, err)
http.Error(w, "Bad request", http.StatusBadRequest) http.Error(w, "Bad request", http.StatusBadRequest)
select { if obj.ErrorOnMalformedPost {
case obj.awsChan <- &chanStruct{ select {
err: errwrap.Wrapf(err, "error decoding incoming POST, check struct formatting"), case obj.awsChan <- &chanStruct{
}: err: errwrap.Wrapf(err, "error decoding incoming POST, check struct formatting"),
case <-obj.closeChan: }:
case <-obj.closeChan:
}
} }
return return
} }
// Verify the x509 signature. If there is an error verifying the
// signature, we print the error, ignore the event and return.
if err := obj.snsVerifySignature(post); err != nil {
log.Printf("%s: error verifying signature: %s", obj, err)
http.Error(w, "Bad request", http.StatusBadRequest)
return
}
// confirm the subscription
if post.Type == "SubscriptionConfirmation" { if post.Type == "SubscriptionConfirmation" {
if err := obj.snsConfirmSubscription(obj.snsTopicArn, post.Token); err != nil { if err := obj.snsConfirmSubscription(obj.snsTopicArn, post.Token); err != nil {
select { select {
@@ -1020,7 +1057,7 @@ func (obj *AwsEc2Res) snsPostHandler(w http.ResponseWriter, req *http.Request) {
if err != nil { if err != nil {
select { select {
case obj.awsChan <- &chanStruct{ case obj.awsChan <- &chanStruct{
err: errwrap.Wrapf(err, "error confirming subscription"), err: errwrap.Wrapf(err, "error processing event"),
}: }:
case <-obj.closeChan: case <-obj.closeChan:
} }
@@ -1038,6 +1075,92 @@ func (obj *AwsEc2Res) snsPostHandler(w http.ResponseWriter, req *http.Request) {
} }
} }
// snsVerifySignature verifies that the post messages are genuine and originate
// from amazon by checking if the signature is valid for the provided key and
// message contents.
func (obj *AwsEc2Res) snsVerifySignature(post postData) error {
// download and parse the signing certificate
cert, err := obj.snsGetCert(post.SigningCertURL)
if err != nil {
return errwrap.Wrapf(err, "error getting certificate")
}
// convert the message to canonical form
message := obj.snsCanonicalFormat(post)
// decode the message signature from base64
signature, err := base64.StdEncoding.DecodeString(post.Signature)
if err != nil {
return errwrap.Wrapf(err, "error decoding string")
}
// check the signature against the message
if err := cert.CheckSignature(x509.SHA1WithRSA, message, signature); err != nil {
return errwrap.Wrapf(err, "error checking signature")
}
return nil
}
// snsGetCert downloads and parses the signing certificate from the provided
// URL for message verification.
func (obj *AwsEc2Res) snsGetCert(url string) (*x509.Certificate, error) {
// only download valid certificates from amazon
matchURL, err := regexp.MatchString(SnsCertURLRegex, url)
if err != nil {
return nil, errwrap.Wrapf(err, "error matching regex")
}
if !matchURL {
return nil, fmt.Errorf("invalid certificate url: %s", url)
}
// download the signing certificate
resp, err := http.Get(url)
if err != nil {
return nil, errwrap.Wrapf(err, "http get error")
}
defer resp.Body.Close()
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errwrap.Wrapf(err, "error reading post body")
}
// Decode the certificate and discard the second argument, which
// contains any additional data in the response following the pem
// block, if present.
decodedCert, _ := pem.Decode(body)
if decodedCert == nil {
return nil, fmt.Errorf("certificate is nil")
}
// parse the certificate
parsedCert, err := x509.ParseCertificate(decodedCert.Bytes)
if err != nil {
return nil, errwrap.Wrapf(err, "error parsing certificate")
}
return parsedCert, nil
}
// snsCanonicalFormat formats post messages as required for signature
// verification. For more information about this requirement see:
// http://docs.aws.amazon.com/sns/latest/dg/SendMessageToHttp.verify.signature.html
func (obj *AwsEc2Res) snsCanonicalFormat(post postData) []byte {
var str string
str += "Message\n"
str += post.Message + "\n"
str += "MessageId\n"
str += post.MessageID + "\n"
if post.SubscribeURL != "" {
str += "SubscribeURL\n"
str += post.SubscribeURL + "\n"
}
str += "Timestamp\n"
str += post.Timestamp + "\n"
if post.Token != "" {
str += "Token\n"
str += post.Token + "\n"
}
str += "TopicArn\n"
str += post.TopicArn + "\n"
str += "Type\n"
str += post.Type + "\n"
return []byte(str)
}
// snsMakeTopic creates a topic on aws sns. // snsMakeTopic creates a topic on aws sns.
func (obj *AwsEc2Res) snsMakeTopic() (string, error) { func (obj *AwsEc2Res) snsMakeTopic() (string, error) {
// make topic // make topic