patchviasocket improve error handling

more detailed error if host file not found or cannot be opened
using mkstemp and mkdtemp for improved security
improve resource cleanup in error conditions in unit test
This commit is contained in:
Sverrir Berg 2016-05-17 15:06:35 +00:00
parent 0acd3c12a2
commit 751d3552dc
2 changed files with 32 additions and 34 deletions

View File

@ -31,22 +31,18 @@ PUB_KEY_FILE = "/root/.ssh/id_rsa.pub.cloud"
MESSAGE = "pubkey:{key}\ncmdline:{cmdline}\n"
def read_pub_key(key_file):
try:
if os.path.isfile(key_file):
with open(key_file, "r") as f:
return f.read()
except IOError:
return None
def send_to_socket(sock_file, key_file, cmdline):
pub_key = read_pub_key(key_file)
if not pub_key:
if not os.path.exists(key_file):
print("ERROR: ssh public key not found on host at {0}".format(key_file))
return 1
try:
with open(key_file, "r") as f:
pub_key = f.read()
except IOError as e:
print("ERROR: unable to open {0} - {1}".format(key_file, e.strerror))
return 1
# Keep old substitution from perl code:
cmdline = cmdline.replace("%", " ")

View File

@ -32,7 +32,7 @@ NON_EXISTING_FILE = "must-not-exist"
def write_key_file():
tmpfile = tempfile.mktemp(".sck")
_, tmpfile = tempfile.mkstemp(".sck")
with open(tmpfile, "w") as f:
f.write(KEY_DATA)
return tmpfile
@ -42,7 +42,8 @@ class SocketThread(threading.Thread):
def __init__(self):
super(SocketThread, self).__init__()
self._data = ""
self._file = tempfile.mktemp(".sck")
self._folder = tempfile.mkdtemp(".sck")
self._file = os.path.join(self._folder, "socket")
self._ready = False
def data(self):
@ -60,18 +61,21 @@ class SocketThread(threading.Thread):
MAX_SIZE = 10 * 1024
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.bind(self._file)
s.listen(1)
s.settimeout(TIMEOUT)
try:
self._ready = True
client, address = s.accept()
self._data = client.recv(MAX_SIZE)
client.close()
except socket.timeout:
pass
s.close()
os.remove(self._file)
s.bind(self._file)
s.listen(1)
s.settimeout(TIMEOUT)
try:
self._ready = True
client, address = s.accept()
self._data = client.recv(MAX_SIZE)
client.close()
except socket.timeout:
pass
finally:
s.close()
os.remove(self._file)
os.rmdir(self._folder)
class TestPatchViaSocket(unittest.TestCase):
@ -88,15 +92,6 @@ class TestPatchViaSocket(unittest.TestCase):
os.remove(self._key_file)
os.remove(self._unreadable)
def test_read_file(self):
pub_key = patchviasocket.read_pub_key(self._key_file)
self.assertEqual(KEY_DATA, pub_key)
def test_read_file_error(self):
self.assertIsNone(patchviasocket.read_pub_key(NON_EXISTING_FILE))
self.assertIsNone(patchviasocket.read_pub_key(self._unreadable))
self.assertIsNone(patchviasocket.read_pub_key("/tmp")) # folder is not a file
def test_write_to_socket(self):
reader = SocketThread()
reader.start()
@ -116,6 +111,13 @@ class TestPatchViaSocket(unittest.TestCase):
self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA))
reader.join() # timeout
def test_host_key_access_denied(self):
reader = SocketThread()
reader.start()
reader.wait_until_ready()
self.assertEquals(1, patchviasocket.send_to_socket(reader.file(), self._unreadable, CMD_DATA))
reader.join() # timeout
def test_nonexistant_socket_error(self):
reader = SocketThread()
reader.start()