diff --git a/etcd/ssh/ssh.go b/etcd/ssh/ssh.go index 78380c80..322a96e1 100644 --- a/etcd/ssh/ssh.go +++ b/etcd/ssh/ssh.go @@ -241,7 +241,7 @@ func (obj *World) Connect(ctx context.Context, init *engine.WorldInit) error { } obj.init.Logf("ssh: %s@%s", user, addr) - obj.sshClient, err = ssh.Dial("tcp", addr, sshConfig) + obj.sshClient, err = dialSSHWithContext(ctx, "tcp", addr, sshConfig) if err != nil { return err } @@ -341,3 +341,20 @@ func (obj *World) cleanup() error { func (obj *World) Cleanup() error { return obj.cleanup() } + +// dialSSHWithContext wraps ssh.Dial so that we can have a context to cancel. +func dialSSHWithContext(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*ssh.Client, error) { + dialer := &net.Dialer{} + conn, err := dialer.DialContext(ctx, network, addr) + if err != nil { + return nil, err + } + + c, chans, reqs, err := ssh.NewClientConn(conn, addr, config) + if err != nil { + conn.Close() + return nil, err + } + + return ssh.NewClient(c, chans, reqs), nil +}