Add caching for remote execution
This speeds up copying of the binary for slow connections. It also finally adds a universal directory prefix for mgmt!
This commit is contained in:
8
main.go
8
main.go
@@ -18,6 +18,7 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
etcdtypes "github.com/coreos/etcd/pkg/types"
|
||||
"github.com/coreos/pkg/capnslog"
|
||||
"github.com/urfave/cli"
|
||||
@@ -33,6 +34,7 @@ import (
|
||||
var (
|
||||
program string
|
||||
version string
|
||||
prefix = fmt.Sprintf("/var/lib/%s/", program)
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -268,6 +270,8 @@ func run(c *cli.Context) error {
|
||||
cConns,
|
||||
c.Bool("allow-interactive"),
|
||||
c.String("ssh-priv-id-rsa"),
|
||||
!c.Bool("no-caching"),
|
||||
prefix,
|
||||
)
|
||||
|
||||
// TODO: is there any benefit to running the remotes above in the loop?
|
||||
@@ -451,6 +455,10 @@ func main() {
|
||||
Usage: "number of maximum concurrent remote ssh connections to run, 0 for unlimited",
|
||||
EnvVar: "MGMT_CCONNS",
|
||||
},
|
||||
cli.BoolFlag{
|
||||
Name: "no-caching",
|
||||
Usage: "don't allow remote caching of remote execution binary",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
114
remote.go
114
remote.go
@@ -43,6 +43,8 @@ package main // TODO: make this a separate ssh package
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"github.com/howeyc/gopass"
|
||||
"github.com/kardianos/osext"
|
||||
@@ -86,6 +88,9 @@ type SSH struct {
|
||||
remoteURLs []string // list of urls where the remote server connects to
|
||||
noop bool // whether to run the remote process with --noop
|
||||
|
||||
caching bool // whether to try and cache the copy of the binary
|
||||
prefix string // location we're allowed to put data on the remote server
|
||||
|
||||
client *ssh.Client // client object
|
||||
sftp *sftp.Client // sftp object
|
||||
listener net.Listener // remote listener
|
||||
@@ -152,13 +157,42 @@ func (obj *SSH) Sftp() error {
|
||||
}
|
||||
|
||||
// TODO: make the path configurable to deal with /tmp/ mounted noexec?
|
||||
tmpdir := func() string {
|
||||
return fmt.Sprintf(formatPattern, fmtUUID(10)) // eg: /tmp/mgmt.abcdefghij/
|
||||
}
|
||||
var ready bool
|
||||
obj.remotewd = ""
|
||||
if obj.caching && obj.prefix != "" {
|
||||
// try and make the parent dir, just in case...
|
||||
obj.sftp.Mkdir(obj.prefix) // ignore any errors
|
||||
obj.remotewd = path.Join(obj.prefix, "remote") // eg: /var/lib/mgmt/remote/
|
||||
if fileinfo, err := obj.sftp.Stat(obj.remotewd); err == nil {
|
||||
if fileinfo.IsDir() {
|
||||
ready = true
|
||||
}
|
||||
}
|
||||
} else {
|
||||
obj.remotewd = tmpdir()
|
||||
}
|
||||
|
||||
for i := 0; true; {
|
||||
// NOTE: since fmtUUID is deterministic, if we don't clean up
|
||||
// previous runs, we may get the same paths generated, and here
|
||||
// they will conflict.
|
||||
obj.remotewd = fmt.Sprintf(formatPattern, fmtUUID(10)) // eg: /tmp/mgmt.abcdefghij/
|
||||
if err := obj.sftp.Mkdir(obj.remotewd); err != nil {
|
||||
// TODO: if we could determine if this was a "file
|
||||
// already exists" error, we could break now!
|
||||
// https://github.com/pkg/sftp/issues/131
|
||||
//if status, ok := err.(*sftp.StatusError); ok {
|
||||
// log.Printf("Code: %v, %v", status.Code, status.Error())
|
||||
// if status.Code == ??? && obj.caching {
|
||||
// break
|
||||
// }
|
||||
//}
|
||||
if ready { // dir already exists
|
||||
break
|
||||
}
|
||||
|
||||
i++ // count number of times we've tried
|
||||
e := fmt.Errorf("Can't make tmp directory: %s", err)
|
||||
log.Println(e)
|
||||
@@ -166,32 +200,38 @@ func (obj *SSH) Sftp() error {
|
||||
log.Printf("Remote: Please clean up the remote dir: %s", obj.remotewd)
|
||||
return e
|
||||
}
|
||||
if obj.caching { // maybe /var/lib/mgmt/ is read-only.
|
||||
obj.remotewd = tmpdir()
|
||||
}
|
||||
continue // try again, unlucky conflict!
|
||||
}
|
||||
log.Printf("Remote: Remotely created: %s", obj.remotewd)
|
||||
break
|
||||
}
|
||||
|
||||
// FIXME: consider running a hashing function to check if the remote file
|
||||
// is valid before copying it over again... this would need a deterministic
|
||||
// temp directory location first... this actually happens with fmtUUID!
|
||||
// future patch!
|
||||
|
||||
obj.execpath = path.Join(obj.remotewd, program) // program is a compile time string from main.go
|
||||
log.Printf("Remote: Remote path is: %s", obj.execpath)
|
||||
|
||||
var same bool
|
||||
if obj.caching {
|
||||
same, _ = obj.SftpHash(selfpath, obj.execpath) // ignore errors
|
||||
}
|
||||
if same {
|
||||
log.Println("Remote: Skipping binary copy, file was cached.")
|
||||
} else {
|
||||
log.Println("Remote: Copying binary, please be patient...")
|
||||
_, err = obj.SftpCopy(selfpath, obj.execpath)
|
||||
if err != nil {
|
||||
// TODO: cleanup
|
||||
return fmt.Errorf("Error copying binary: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if obj.exitCheck() {
|
||||
return nil
|
||||
}
|
||||
|
||||
// make file executable
|
||||
// make file executable; don't cache this in case it didn't ever happen
|
||||
// TODO: do we want the group or other bits set?
|
||||
if err := obj.sftp.Chmod(obj.execpath, 0770); err != nil {
|
||||
return fmt.Errorf("Can't set file mode bits!")
|
||||
@@ -250,6 +290,39 @@ func (obj *SSH) SftpCopy(src, dst string) (int64, error) {
|
||||
return n, nil
|
||||
}
|
||||
|
||||
// SftpHash hashes a local file, and compares that hash to the result of a
|
||||
// remote hashing command run on the second file path.
|
||||
func (obj *SSH) SftpHash(local, remote string) (bool, error) {
|
||||
// TODO: we could run both hash operations in parallel! :)
|
||||
hash := sha256.New()
|
||||
f, err := os.Open(local)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
if _, err := io.Copy(hash, f); err != nil {
|
||||
return false, err
|
||||
}
|
||||
sha256sum := hex.EncodeToString(hash.Sum(nil))
|
||||
//log.Printf("sha256sum: %s", sha256sum)
|
||||
|
||||
// We run a remote hashing command, instead of reading the file in over
|
||||
// the wire and hashing it ourselves, because assuming symmetric
|
||||
// bandwidth, that would defeat the point of caching it altogether!
|
||||
cmd := fmt.Sprintf("sha256sum '%s'", remote)
|
||||
out, err := obj.simpleRun(cmd)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
s := strings.Split(out, " ") // sha256sum returns: hash + filename
|
||||
if s[0] == sha256sum {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil // files were different
|
||||
}
|
||||
|
||||
// SftpClean cleans up the mess and closes the connection from the sftp work.
|
||||
func (obj *SSH) SftpClean() error {
|
||||
if obj.sftp == nil {
|
||||
@@ -265,13 +338,15 @@ func (obj *SSH) SftpClean() error {
|
||||
// clean up the graph definition in obj.remotewd
|
||||
err := obj.sftp.Remove(obj.filepath)
|
||||
|
||||
// TODO: add binary caching
|
||||
// if we're not caching+sha1sum-ing, then also remove the rest
|
||||
if !obj.caching {
|
||||
if e := obj.sftp.Remove(obj.execpath); e != nil {
|
||||
err = e
|
||||
}
|
||||
if e := obj.sftp.Remove(obj.remotewd); e != nil {
|
||||
err = e
|
||||
}
|
||||
}
|
||||
|
||||
if e := obj.sftp.Close(); e != nil {
|
||||
err = e
|
||||
@@ -429,16 +504,17 @@ func (obj *SSH) Exec() error {
|
||||
}
|
||||
|
||||
// simpleRun is a simple helper for running commands in new sessions.
|
||||
func (obj *SSH) simpleRun(cmd string) error {
|
||||
func (obj *SSH) simpleRun(cmd string) (string, error) {
|
||||
session, err := obj.client.NewSession() // not the main session!
|
||||
if err != nil {
|
||||
return fmt.Errorf("Failed to create session: %s", err.Error())
|
||||
return "", fmt.Errorf("Failed to create session: %s", err.Error())
|
||||
}
|
||||
defer session.Close()
|
||||
if err := session.Run(cmd); err != nil {
|
||||
return fmt.Errorf("Error running command: %s", err)
|
||||
var out []byte
|
||||
if out, err = session.CombinedOutput(cmd); err != nil {
|
||||
return string(out), fmt.Errorf("Error running command: %s", err)
|
||||
}
|
||||
return nil
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
// ExecExit sends a SIGINT (^C) signal to the remote process, and waits for the
|
||||
@@ -455,7 +531,7 @@ func (obj *SSH) ExecExit() error {
|
||||
}
|
||||
|
||||
// FIXME: workaround: force a signal!
|
||||
if err := obj.simpleRun(fmt.Sprintf("killall -SIGINT %s", program)); err != nil { // FIXME: low specificity
|
||||
if _, err := obj.simpleRun(fmt.Sprintf("killall -SIGINT %s", program)); err != nil { // FIXME: low specificity
|
||||
log.Printf("Remote: Failed to send SIGINT: %s", err.Error())
|
||||
}
|
||||
|
||||
@@ -470,7 +546,7 @@ func (obj *SSH) ExecExit() error {
|
||||
|
||||
// FIXME: workaround: wait (spin lock) until process quits cleanly...
|
||||
cmd := fmt.Sprintf("while killall -0 %s 2> /dev/null; do sleep 1s; done", program) // FIXME: low specificity
|
||||
if err := obj.simpleRun(cmd); err != nil {
|
||||
if _, err := obj.simpleRun(cmd); err != nil {
|
||||
return fmt.Errorf("Error waiting: %s", err)
|
||||
}
|
||||
|
||||
@@ -575,6 +651,8 @@ type Remotes struct {
|
||||
cConns uint16 // number of concurrent ssh connections, zero means unlimited
|
||||
interactive bool // allow interactive prompting
|
||||
sshPrivIdRsa string // path to ~/.ssh/id_rsa
|
||||
caching bool // whether to try and cache the copy of the binary
|
||||
prefix string // folder prefix to use for misc storage
|
||||
|
||||
wg sync.WaitGroup // keep track of each running SSH connection
|
||||
lock sync.Mutex // mutex for access to sshmap
|
||||
@@ -584,7 +662,7 @@ type Remotes struct {
|
||||
}
|
||||
|
||||
// The NewRemotes function builds a Remotes struct.
|
||||
func NewRemotes(clientURLs, remoteURLs []string, noop bool, remotes []string, cConns uint16, interactive bool, sshPrivIdRsa string) *Remotes {
|
||||
func NewRemotes(clientURLs, remoteURLs []string, noop bool, remotes []string, cConns uint16, interactive bool, sshPrivIdRsa string, caching bool, prefix string) *Remotes {
|
||||
return &Remotes{
|
||||
clientURLs: clientURLs,
|
||||
remoteURLs: remoteURLs,
|
||||
@@ -593,6 +671,8 @@ func NewRemotes(clientURLs, remoteURLs []string, noop bool, remotes []string, cC
|
||||
cConns: cConns,
|
||||
interactive: interactive,
|
||||
sshPrivIdRsa: sshPrivIdRsa,
|
||||
caching: caching,
|
||||
prefix: prefix,
|
||||
sshmap: make(map[string]*SSH),
|
||||
semaphore: NewSemaphore(int(cConns)),
|
||||
}
|
||||
@@ -668,6 +748,8 @@ func (obj *Remotes) NewSSH(file string) (*SSH, error) {
|
||||
clientURLs: obj.clientURLs,
|
||||
remoteURLs: obj.remoteURLs,
|
||||
noop: obj.noop,
|
||||
caching: obj.caching,
|
||||
prefix: obj.prefix,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user