diff --git a/scripts/vm/hypervisor/kvm/patchviasocket.py b/scripts/vm/hypervisor/kvm/patchviasocket.py index d9616c9e8b9..c971d5dcc58 100755 --- a/scripts/vm/hypervisor/kvm/patchviasocket.py +++ b/scripts/vm/hypervisor/kvm/patchviasocket.py @@ -31,22 +31,18 @@ PUB_KEY_FILE = "/root/.ssh/id_rsa.pub.cloud" MESSAGE = "pubkey:{key}\ncmdline:{cmdline}\n" -def read_pub_key(key_file): - try: - if os.path.isfile(key_file): - with open(key_file, "r") as f: - return f.read() - except IOError: - return None - - def send_to_socket(sock_file, key_file, cmdline): - pub_key = read_pub_key(key_file) - - if not pub_key: + if not os.path.exists(key_file): print("ERROR: ssh public key not found on host at {0}".format(key_file)) return 1 + try: + with open(key_file, "r") as f: + pub_key = f.read() + except IOError as e: + print("ERROR: unable to open {0} - {1}".format(key_file, e.strerror)) + return 1 + # Keep old substitution from perl code: cmdline = cmdline.replace("%", " ") diff --git a/scripts/vm/hypervisor/kvm/test_patchviasocket.py b/scripts/vm/hypervisor/kvm/test_patchviasocket.py index 074b159a7a6..6b411d32246 100755 --- a/scripts/vm/hypervisor/kvm/test_patchviasocket.py +++ b/scripts/vm/hypervisor/kvm/test_patchviasocket.py @@ -32,7 +32,7 @@ NON_EXISTING_FILE = "must-not-exist" def write_key_file(): - tmpfile = tempfile.mktemp(".sck") + _, tmpfile = tempfile.mkstemp(".sck") with open(tmpfile, "w") as f: f.write(KEY_DATA) return tmpfile @@ -42,7 +42,8 @@ class SocketThread(threading.Thread): def __init__(self): super(SocketThread, self).__init__() self._data = "" - self._file = tempfile.mktemp(".sck") + self._folder = tempfile.mkdtemp(".sck") + self._file = os.path.join(self._folder, "socket") self._ready = False def data(self): @@ -60,18 +61,21 @@ class SocketThread(threading.Thread): MAX_SIZE = 10 * 1024 s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s.bind(self._file) - s.listen(1) - s.settimeout(TIMEOUT) try: - self._ready = True - client, address = s.accept() - self._data = client.recv(MAX_SIZE) - client.close() - except socket.timeout: - pass - s.close() - os.remove(self._file) + s.bind(self._file) + s.listen(1) + s.settimeout(TIMEOUT) + try: + self._ready = True + client, address = s.accept() + self._data = client.recv(MAX_SIZE) + client.close() + except socket.timeout: + pass + finally: + s.close() + os.remove(self._file) + os.rmdir(self._folder) class TestPatchViaSocket(unittest.TestCase): @@ -88,15 +92,6 @@ class TestPatchViaSocket(unittest.TestCase): os.remove(self._key_file) os.remove(self._unreadable) - def test_read_file(self): - pub_key = patchviasocket.read_pub_key(self._key_file) - self.assertEqual(KEY_DATA, pub_key) - - def test_read_file_error(self): - self.assertIsNone(patchviasocket.read_pub_key(NON_EXISTING_FILE)) - self.assertIsNone(patchviasocket.read_pub_key(self._unreadable)) - self.assertIsNone(patchviasocket.read_pub_key("/tmp")) # folder is not a file - def test_write_to_socket(self): reader = SocketThread() reader.start() @@ -116,6 +111,13 @@ class TestPatchViaSocket(unittest.TestCase): self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA)) reader.join() # timeout + def test_host_key_access_denied(self): + reader = SocketThread() + reader.start() + reader.wait_until_ready() + self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), self._unreadable, CMD_DATA)) + reader.join() # timeout + def test_nonexistant_socket_error(self): reader = SocketThread() reader.start()