mirror of https://github.com/apache/cloudstack.git
patchviasocket improve error handling
more detailed error if host file not found or cannot be opened using mkstemp and mkdtemp for improved security improve resource cleanup in error conditions in unit test
This commit is contained in:
parent
0acd3c12a2
commit
751d3552dc
|
|
@ -31,22 +31,18 @@ PUB_KEY_FILE = "/root/.ssh/id_rsa.pub.cloud"
|
|||
MESSAGE = "pubkey:{key}\ncmdline:{cmdline}\n"
|
||||
|
||||
|
||||
def read_pub_key(key_file):
|
||||
try:
|
||||
if os.path.isfile(key_file):
|
||||
with open(key_file, "r") as f:
|
||||
return f.read()
|
||||
except IOError:
|
||||
return None
|
||||
|
||||
|
||||
def send_to_socket(sock_file, key_file, cmdline):
|
||||
pub_key = read_pub_key(key_file)
|
||||
|
||||
if not pub_key:
|
||||
if not os.path.exists(key_file):
|
||||
print("ERROR: ssh public key not found on host at {0}".format(key_file))
|
||||
return 1
|
||||
|
||||
try:
|
||||
with open(key_file, "r") as f:
|
||||
pub_key = f.read()
|
||||
except IOError as e:
|
||||
print("ERROR: unable to open {0} - {1}".format(key_file, e.strerror))
|
||||
return 1
|
||||
|
||||
# Keep old substitution from perl code:
|
||||
cmdline = cmdline.replace("%", " ")
|
||||
|
||||
|
|
|
|||
|
|
@ -32,7 +32,7 @@ NON_EXISTING_FILE = "must-not-exist"
|
|||
|
||||
|
||||
def write_key_file():
|
||||
tmpfile = tempfile.mktemp(".sck")
|
||||
_, tmpfile = tempfile.mkstemp(".sck")
|
||||
with open(tmpfile, "w") as f:
|
||||
f.write(KEY_DATA)
|
||||
return tmpfile
|
||||
|
|
@ -42,7 +42,8 @@ class SocketThread(threading.Thread):
|
|||
def __init__(self):
|
||||
super(SocketThread, self).__init__()
|
||||
self._data = ""
|
||||
self._file = tempfile.mktemp(".sck")
|
||||
self._folder = tempfile.mkdtemp(".sck")
|
||||
self._file = os.path.join(self._folder, "socket")
|
||||
self._ready = False
|
||||
|
||||
def data(self):
|
||||
|
|
@ -60,18 +61,21 @@ class SocketThread(threading.Thread):
|
|||
MAX_SIZE = 10 * 1024
|
||||
|
||||
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
|
||||
s.bind(self._file)
|
||||
s.listen(1)
|
||||
s.settimeout(TIMEOUT)
|
||||
try:
|
||||
self._ready = True
|
||||
client, address = s.accept()
|
||||
self._data = client.recv(MAX_SIZE)
|
||||
client.close()
|
||||
except socket.timeout:
|
||||
pass
|
||||
s.close()
|
||||
os.remove(self._file)
|
||||
s.bind(self._file)
|
||||
s.listen(1)
|
||||
s.settimeout(TIMEOUT)
|
||||
try:
|
||||
self._ready = True
|
||||
client, address = s.accept()
|
||||
self._data = client.recv(MAX_SIZE)
|
||||
client.close()
|
||||
except socket.timeout:
|
||||
pass
|
||||
finally:
|
||||
s.close()
|
||||
os.remove(self._file)
|
||||
os.rmdir(self._folder)
|
||||
|
||||
|
||||
class TestPatchViaSocket(unittest.TestCase):
|
||||
|
|
@ -88,15 +92,6 @@ class TestPatchViaSocket(unittest.TestCase):
|
|||
os.remove(self._key_file)
|
||||
os.remove(self._unreadable)
|
||||
|
||||
def test_read_file(self):
|
||||
pub_key = patchviasocket.read_pub_key(self._key_file)
|
||||
self.assertEqual(KEY_DATA, pub_key)
|
||||
|
||||
def test_read_file_error(self):
|
||||
self.assertIsNone(patchviasocket.read_pub_key(NON_EXISTING_FILE))
|
||||
self.assertIsNone(patchviasocket.read_pub_key(self._unreadable))
|
||||
self.assertIsNone(patchviasocket.read_pub_key("/tmp")) # folder is not a file
|
||||
|
||||
def test_write_to_socket(self):
|
||||
reader = SocketThread()
|
||||
reader.start()
|
||||
|
|
@ -116,6 +111,13 @@ class TestPatchViaSocket(unittest.TestCase):
|
|||
self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA))
|
||||
reader.join() # timeout
|
||||
|
||||
def test_host_key_access_denied(self):
|
||||
reader = SocketThread()
|
||||
reader.start()
|
||||
reader.wait_until_ready()
|
||||
self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), self._unreadable, CMD_DATA))
|
||||
reader.join() # timeout
|
||||
|
||||
def test_nonexistant_socket_error(self):
|
||||
reader = SocketThread()
|
||||
reader.start()
|
||||
|
|
|
|||
Loading…
Reference in New Issue