From e921ec6ec79c50096d58264d60c15091969ff888 Mon Sep 17 00:00:00 2001 From: Gaurav Aradhye Date: Tue, 23 Sep 2014 14:18:18 +0530 Subject: [PATCH] CLOUDSTACK-7408: Fixed - Private key of the ssh keypair was getting corrupted Signed-off-by: SrikanteswaraRao Talluri --- tools/marvin/marvin/lib/base.py | 8 +++++--- tools/marvin/marvin/lib/utils.py | 7 +++++-- tools/marvin/marvin/sshClient.py | 20 +++++++++++++++++--- 3 files changed, 27 insertions(+), 8 deletions(-) diff --git a/tools/marvin/marvin/lib/base.py b/tools/marvin/marvin/lib/base.py index 5bd89318523..d6233862fd3 100755 --- a/tools/marvin/marvin/lib/base.py +++ b/tools/marvin/marvin/lib/base.py @@ -552,7 +552,7 @@ class VirtualMachine: def get_ssh_client( self, ipaddress=None, reconnect=False, port=None, - keyPairFileLocation=None): + keyPairFileLocation=None, knownHostsFilePath=None): """Get SSH object of VM""" # If NAT Rules are not created while VM deployment in Advanced mode @@ -571,14 +571,16 @@ class VirtualMachine: self.ssh_port, self.username, self.password, - keyPairFileLocation=keyPairFileLocation + keyPairFileLocation=keyPairFileLocation, + knownHostsFilePath=knownHostsFilePath ) self.ssh_client = self.ssh_client or is_server_ssh_ready( self.ssh_ip, self.ssh_port, self.username, self.password, - keyPairFileLocation=keyPairFileLocation + keyPairFileLocation=keyPairFileLocation, + knownHostsFilePath=knownHostsFilePath ) return self.ssh_client diff --git a/tools/marvin/marvin/lib/utils.py b/tools/marvin/marvin/lib/utils.py index 8788b3b736f..b58b59dccab 100644 --- a/tools/marvin/marvin/lib/utils.py +++ b/tools/marvin/marvin/lib/utils.py @@ -121,7 +121,9 @@ def cleanup_resources(api_client, resources): obj.delete(api_client) -def is_server_ssh_ready(ipaddress, port, username, password, retries=20, retryinterv=30, timeout=10.0, keyPairFileLocation=None): +def is_server_ssh_ready(ipaddress, port, username, password, retries=20, + retryinterv=30, timeout=10.0, keyPairFileLocation=None, + knownHostsFilePath=None): ''' @Name: is_server_ssh_ready @Input: timeout: tcp connection timeout flag, @@ -140,7 +142,8 @@ def is_server_ssh_ready(ipaddress, port, username, password, retries=20, retryin keyPairFiles=keyPairFileLocation, retries=retries, delay=retryinterv, - timeout=timeout) + timeout=timeout, + knownHostsFilePath=knownHostsFilePath) except Exception, e: raise Exception("SSH connection has Failed. Waited %ss. Error is %s" % (retries * retryinterv, str(e))) else: diff --git a/tools/marvin/marvin/sshClient.py b/tools/marvin/marvin/sshClient.py index df2eeee1d4c..f027890522f 100644 --- a/tools/marvin/marvin/sshClient.py +++ b/tools/marvin/marvin/sshClient.py @@ -24,6 +24,7 @@ from paramiko import (BadHostKeyException, SFTPClient) import socket import time +import os from marvin.cloudstackException import ( internalError, GetDetailExceptionInfo @@ -49,7 +50,8 @@ class SshClient(object): ''' def __init__(self, host, port, user, passwd, retries=60, delay=10, - log_lvl=logging.DEBUG, keyPairFiles=None, timeout=10.0): + log_lvl=logging.DEBUG, keyPairFiles=None, timeout=10.0, + knownHostsFilePath=None): self.host = None self.port = 22 self.user = user @@ -77,6 +79,18 @@ class SshClient(object): self.timeout = timeout if port is not None and port >= 0: self.port = port + + # If the known_hosts file is not at default location, + # then its location can be passed, or else the default + # path will be considered (which is ~/.ssh/known_hosts) + if knownHostsFilePath: + self.knownHostsFilePath = knownHostsFilePath + else: + self.knownHostsFilePath = os.path.expanduser( + os.path.join( + "~", + ".ssh", + "known_hosts")) if self.createConnection() == FAILED: raise internalError("SSH Connection Failed") @@ -120,14 +134,14 @@ class SshClient(object): password=self.passwd, timeout=self.timeout) else: - self.ssh.load_host_keys(self.keyPairFiles) + self.ssh.load_host_keys(self.knownHostsFilePath) self.ssh.connect(hostname=self.host, port=self.port, username=self.user, password=self.passwd, key_filename=self.keyPairFiles, timeout=self.timeout, - look_for_keys=True + look_for_keys=False ) self.logger.debug("===SSH to Host %s port : %s SUCCESSFUL===" % (str(self.host), str(self.port)))