From 751d3552dc3e3514c51fb9038ab91a625470212f Mon Sep 17 00:00:00 2001 From: Sverrir Berg Date: Tue, 17 May 2016 15:06:35 +0000 Subject: [PATCH] patchviasocket improve error handling more detailed error if host file not found or cannot be opened using mkstemp and mkdtemp for improved security improve resource cleanup in error conditions in unit test --- scripts/vm/hypervisor/kvm/patchviasocket.py | 20 ++++---- .../vm/hypervisor/kvm/test_patchviasocket.py | 46 ++++++++++--------- 2 files changed, 32 insertions(+), 34 deletions(-) diff --git a/scripts/vm/hypervisor/kvm/patchviasocket.py b/scripts/vm/hypervisor/kvm/patchviasocket.py index d9616c9e8b9..c971d5dcc58 100755 --- a/scripts/vm/hypervisor/kvm/patchviasocket.py +++ b/scripts/vm/hypervisor/kvm/patchviasocket.py @@ -31,22 +31,18 @@ PUB_KEY_FILE = "/root/.ssh/id_rsa.pub.cloud" MESSAGE = "pubkey:{key}\ncmdline:{cmdline}\n" -def read_pub_key(key_file): - try: - if os.path.isfile(key_file): - with open(key_file, "r") as f: - return f.read() - except IOError: - return None - - def send_to_socket(sock_file, key_file, cmdline): - pub_key = read_pub_key(key_file) - - if not pub_key: + if not os.path.exists(key_file): print("ERROR: ssh public key not found on host at {0}".format(key_file)) return 1 + try: + with open(key_file, "r") as f: + pub_key = f.read() + except IOError as e: + print("ERROR: unable to open {0} - {1}".format(key_file, e.strerror)) + return 1 + # Keep old substitution from perl code: cmdline = cmdline.replace("%", " ") diff --git a/scripts/vm/hypervisor/kvm/test_patchviasocket.py b/scripts/vm/hypervisor/kvm/test_patchviasocket.py index 074b159a7a6..6b411d32246 100755 --- a/scripts/vm/hypervisor/kvm/test_patchviasocket.py +++ b/scripts/vm/hypervisor/kvm/test_patchviasocket.py @@ -32,7 +32,7 @@ NON_EXISTING_FILE = "must-not-exist" def write_key_file(): - tmpfile = tempfile.mktemp(".sck") + _, tmpfile = tempfile.mkstemp(".sck") with open(tmpfile, "w") as f: f.write(KEY_DATA) return tmpfile @@ -42,7 +42,8 @@ class SocketThread(threading.Thread): def __init__(self): super(SocketThread, self).__init__() self._data = "" - self._file = tempfile.mktemp(".sck") + self._folder = tempfile.mkdtemp(".sck") + self._file = os.path.join(self._folder, "socket") self._ready = False def data(self): @@ -60,18 +61,21 @@ class SocketThread(threading.Thread): MAX_SIZE = 10 * 1024 s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - s.bind(self._file) - s.listen(1) - s.settimeout(TIMEOUT) try: - self._ready = True - client, address = s.accept() - self._data = client.recv(MAX_SIZE) - client.close() - except socket.timeout: - pass - s.close() - os.remove(self._file) + s.bind(self._file) + s.listen(1) + s.settimeout(TIMEOUT) + try: + self._ready = True + client, address = s.accept() + self._data = client.recv(MAX_SIZE) + client.close() + except socket.timeout: + pass + finally: + s.close() + os.remove(self._file) + os.rmdir(self._folder) class TestPatchViaSocket(unittest.TestCase): @@ -88,15 +92,6 @@ class TestPatchViaSocket(unittest.TestCase): os.remove(self._key_file) os.remove(self._unreadable) - def test_read_file(self): - pub_key = patchviasocket.read_pub_key(self._key_file) - self.assertEqual(KEY_DATA, pub_key) - - def test_read_file_error(self): - self.assertIsNone(patchviasocket.read_pub_key(NON_EXISTING_FILE)) - self.assertIsNone(patchviasocket.read_pub_key(self._unreadable)) - self.assertIsNone(patchviasocket.read_pub_key("/tmp")) # folder is not a file - def test_write_to_socket(self): reader = SocketThread() reader.start() @@ -116,6 +111,13 @@ class TestPatchViaSocket(unittest.TestCase): self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA)) reader.join() # timeout + def test_host_key_access_denied(self): + reader = SocketThread() + reader.start() + reader.wait_until_ready() + self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), self._unreadable, CMD_DATA)) + reader.join() # timeout + def test_nonexistant_socket_error(self): reader = SocketThread() reader.start()