diff --git a/main.go b/main.go index 9185c8f0..4abef2e0 100644 --- a/main.go +++ b/main.go @@ -18,6 +18,7 @@ package main import ( + "fmt" etcdtypes "github.com/coreos/etcd/pkg/types" "github.com/coreos/pkg/capnslog" "github.com/urfave/cli" @@ -33,6 +34,7 @@ import ( var ( program string version string + prefix = fmt.Sprintf("/var/lib/%s/", program) ) const ( @@ -268,6 +270,8 @@ func run(c *cli.Context) error { cConns, c.Bool("allow-interactive"), c.String("ssh-priv-id-rsa"), + !c.Bool("no-caching"), + prefix, ) // TODO: is there any benefit to running the remotes above in the loop? @@ -451,6 +455,10 @@ func main() { Usage: "number of maximum concurrent remote ssh connections to run, 0 for unlimited", EnvVar: "MGMT_CCONNS", }, + cli.BoolFlag{ + Name: "no-caching", + Usage: "don't allow remote caching of remote execution binary", + }, }, }, } diff --git a/remote.go b/remote.go index d3cc5767..27e153c5 100644 --- a/remote.go +++ b/remote.go @@ -43,6 +43,8 @@ package main // TODO: make this a separate ssh package import ( "bytes" + "crypto/sha256" + "encoding/hex" "fmt" "github.com/howeyc/gopass" "github.com/kardianos/osext" @@ -86,6 +88,9 @@ type SSH struct { remoteURLs []string // list of urls where the remote server connects to noop bool // whether to run the remote process with --noop + caching bool // whether to try and cache the copy of the binary + prefix string // location we're allowed to put data on the remote server + client *ssh.Client // client object sftp *sftp.Client // sftp object listener net.Listener // remote listener @@ -152,13 +157,42 @@ func (obj *SSH) Sftp() error { } // TODO: make the path configurable to deal with /tmp/ mounted noexec? + tmpdir := func() string { + return fmt.Sprintf(formatPattern, fmtUUID(10)) // eg: /tmp/mgmt.abcdefghij/ + } + var ready bool obj.remotewd = "" + if obj.caching && obj.prefix != "" { + // try and make the parent dir, just in case... + obj.sftp.Mkdir(obj.prefix) // ignore any errors + obj.remotewd = path.Join(obj.prefix, "remote") // eg: /var/lib/mgmt/remote/ + if fileinfo, err := obj.sftp.Stat(obj.remotewd); err == nil { + if fileinfo.IsDir() { + ready = true + } + } + } else { + obj.remotewd = tmpdir() + } + for i := 0; true; { // NOTE: since fmtUUID is deterministic, if we don't clean up // previous runs, we may get the same paths generated, and here // they will conflict. - obj.remotewd = fmt.Sprintf(formatPattern, fmtUUID(10)) // eg: /tmp/mgmt.abcdefghij/ if err := obj.sftp.Mkdir(obj.remotewd); err != nil { + // TODO: if we could determine if this was a "file + // already exists" error, we could break now! + // https://github.com/pkg/sftp/issues/131 + //if status, ok := err.(*sftp.StatusError); ok { + // log.Printf("Code: %v, %v", status.Code, status.Error()) + // if status.Code == ??? && obj.caching { + // break + // } + //} + if ready { // dir already exists + break + } + i++ // count number of times we've tried e := fmt.Errorf("Can't make tmp directory: %s", err) log.Println(e) @@ -166,32 +200,38 @@ func (obj *SSH) Sftp() error { log.Printf("Remote: Please clean up the remote dir: %s", obj.remotewd) return e } + if obj.caching { // maybe /var/lib/mgmt/ is read-only. + obj.remotewd = tmpdir() + } continue // try again, unlucky conflict! } log.Printf("Remote: Remotely created: %s", obj.remotewd) break } - // FIXME: consider running a hashing function to check if the remote file - // is valid before copying it over again... this would need a deterministic - // temp directory location first... this actually happens with fmtUUID! - // future patch! - obj.execpath = path.Join(obj.remotewd, program) // program is a compile time string from main.go log.Printf("Remote: Remote path is: %s", obj.execpath) - log.Println("Remote: Copying binary, please be patient...") - _, err = obj.SftpCopy(selfpath, obj.execpath) - if err != nil { - // TODO: cleanup - return fmt.Errorf("Error copying binary: %s", err) + var same bool + if obj.caching { + same, _ = obj.SftpHash(selfpath, obj.execpath) // ignore errors + } + if same { + log.Println("Remote: Skipping binary copy, file was cached.") + } else { + log.Println("Remote: Copying binary, please be patient...") + _, err = obj.SftpCopy(selfpath, obj.execpath) + if err != nil { + // TODO: cleanup + return fmt.Errorf("Error copying binary: %s", err) + } } if obj.exitCheck() { return nil } - // make file executable + // make file executable; don't cache this in case it didn't ever happen // TODO: do we want the group or other bits set? if err := obj.sftp.Chmod(obj.execpath, 0770); err != nil { return fmt.Errorf("Can't set file mode bits!") @@ -250,6 +290,39 @@ func (obj *SSH) SftpCopy(src, dst string) (int64, error) { return n, nil } +// SftpHash hashes a local file, and compares that hash to the result of a +// remote hashing command run on the second file path. +func (obj *SSH) SftpHash(local, remote string) (bool, error) { + // TODO: we could run both hash operations in parallel! :) + hash := sha256.New() + f, err := os.Open(local) + if err != nil { + return false, err + } + defer f.Close() + + if _, err := io.Copy(hash, f); err != nil { + return false, err + } + sha256sum := hex.EncodeToString(hash.Sum(nil)) + //log.Printf("sha256sum: %s", sha256sum) + + // We run a remote hashing command, instead of reading the file in over + // the wire and hashing it ourselves, because assuming symmetric + // bandwidth, that would defeat the point of caching it altogether! + cmd := fmt.Sprintf("sha256sum '%s'", remote) + out, err := obj.simpleRun(cmd) + if err != nil { + return false, err + } + + s := strings.Split(out, " ") // sha256sum returns: hash + filename + if s[0] == sha256sum { + return true, nil + } + return false, nil // files were different +} + // SftpClean cleans up the mess and closes the connection from the sftp work. func (obj *SSH) SftpClean() error { if obj.sftp == nil { @@ -265,12 +338,14 @@ func (obj *SSH) SftpClean() error { // clean up the graph definition in obj.remotewd err := obj.sftp.Remove(obj.filepath) - // TODO: add binary caching - if e := obj.sftp.Remove(obj.execpath); e != nil { - err = e - } - if e := obj.sftp.Remove(obj.remotewd); e != nil { - err = e + // if we're not caching+sha1sum-ing, then also remove the rest + if !obj.caching { + if e := obj.sftp.Remove(obj.execpath); e != nil { + err = e + } + if e := obj.sftp.Remove(obj.remotewd); e != nil { + err = e + } } if e := obj.sftp.Close(); e != nil { @@ -429,16 +504,17 @@ func (obj *SSH) Exec() error { } // simpleRun is a simple helper for running commands in new sessions. -func (obj *SSH) simpleRun(cmd string) error { +func (obj *SSH) simpleRun(cmd string) (string, error) { session, err := obj.client.NewSession() // not the main session! if err != nil { - return fmt.Errorf("Failed to create session: %s", err.Error()) + return "", fmt.Errorf("Failed to create session: %s", err.Error()) } defer session.Close() - if err := session.Run(cmd); err != nil { - return fmt.Errorf("Error running command: %s", err) + var out []byte + if out, err = session.CombinedOutput(cmd); err != nil { + return string(out), fmt.Errorf("Error running command: %s", err) } - return nil + return string(out), nil } // ExecExit sends a SIGINT (^C) signal to the remote process, and waits for the @@ -455,7 +531,7 @@ func (obj *SSH) ExecExit() error { } // FIXME: workaround: force a signal! - if err := obj.simpleRun(fmt.Sprintf("killall -SIGINT %s", program)); err != nil { // FIXME: low specificity + if _, err := obj.simpleRun(fmt.Sprintf("killall -SIGINT %s", program)); err != nil { // FIXME: low specificity log.Printf("Remote: Failed to send SIGINT: %s", err.Error()) } @@ -470,7 +546,7 @@ func (obj *SSH) ExecExit() error { // FIXME: workaround: wait (spin lock) until process quits cleanly... cmd := fmt.Sprintf("while killall -0 %s 2> /dev/null; do sleep 1s; done", program) // FIXME: low specificity - if err := obj.simpleRun(cmd); err != nil { + if _, err := obj.simpleRun(cmd); err != nil { return fmt.Errorf("Error waiting: %s", err) } @@ -575,6 +651,8 @@ type Remotes struct { cConns uint16 // number of concurrent ssh connections, zero means unlimited interactive bool // allow interactive prompting sshPrivIdRsa string // path to ~/.ssh/id_rsa + caching bool // whether to try and cache the copy of the binary + prefix string // folder prefix to use for misc storage wg sync.WaitGroup // keep track of each running SSH connection lock sync.Mutex // mutex for access to sshmap @@ -584,7 +662,7 @@ type Remotes struct { } // The NewRemotes function builds a Remotes struct. -func NewRemotes(clientURLs, remoteURLs []string, noop bool, remotes []string, cConns uint16, interactive bool, sshPrivIdRsa string) *Remotes { +func NewRemotes(clientURLs, remoteURLs []string, noop bool, remotes []string, cConns uint16, interactive bool, sshPrivIdRsa string, caching bool, prefix string) *Remotes { return &Remotes{ clientURLs: clientURLs, remoteURLs: remoteURLs, @@ -593,6 +671,8 @@ func NewRemotes(clientURLs, remoteURLs []string, noop bool, remotes []string, cC cConns: cConns, interactive: interactive, sshPrivIdRsa: sshPrivIdRsa, + caching: caching, + prefix: prefix, sshmap: make(map[string]*SSH), semaphore: NewSemaphore(int(cConns)), } @@ -668,6 +748,8 @@ func (obj *Remotes) NewSSH(file string) (*SSH, error) { clientURLs: obj.clientURLs, remoteURLs: obj.remoteURLs, noop: obj.noop, + caching: obj.caching, + prefix: obj.prefix, }, nil }