mirror of https://github.com/apache/cloudstack.git
143 lines
4.5 KiB
Python
Executable File
143 lines
4.5 KiB
Python
Executable File
#!/usr/bin/env python
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import patchviasocket
|
|
|
|
import getpass
|
|
import os
|
|
import socket
|
|
import tempfile
|
|
import time
|
|
import threading
|
|
import unittest
|
|
|
|
# Fixture values shared by the tests below.

# Path that setUp() asserts does not exist; used to provoke "file not found"
# errors from the functions under test.
NON_EXISTING_FILE = "must-not-exist"

# Key payload written into the temporary key files. test_write_to_socket checks
# that it is forwarded verbatim (and, e.g., not upper-cased to "LUV").
KEY_DATA = "I luv\nCloudStack\n"

# Command sent over the socket; test_write_to_socket expects every "%" in it to
# arrive as a space on the receiving side.
CMD_DATA = "/run/this-for-me --please=TRUE! very%quickly"
|
|
|
|
|
|
def write_key_file(data=None):
    """Create a temporary ``.sck`` file holding key text and return its path.

    The file is created atomically with ``tempfile.mkstemp`` rather than the
    race-prone ``tempfile.mktemp`` + ``open`` pair, so no other process can
    claim the path between name generation and file creation. The caller owns
    the returned file and is responsible for removing it.

    :param data: text to write into the file; defaults to ``KEY_DATA``.
    :return: absolute path of the newly created file.
    """
    if data is None:
        data = KEY_DATA
    fd, tmpfile = tempfile.mkstemp(suffix=".sck")
    # Wrap the already-open descriptor so it is closed by the with-block.
    with os.fdopen(fd, "w") as f:
        f.write(data)
    return tmpfile
|
|
|
|
|
|
class SocketThread(threading.Thread):
    """Background thread that serves one connection on a temporary Unix socket.

    ``run()`` binds a Unix-domain stream socket at ``file()``, waits a short
    timeout for a single client, stores what that client sends (decoded to
    text) so ``data()`` can return it, and finally removes the socket file.
    """

    def __init__(self):
        super(SocketThread, self).__init__()
        self._data = ""
        # mktemp() only generates a name: bind() must create the socket file
        # itself and fails if the path already exists.
        self._file = tempfile.mktemp(".sck")
        # Set once the socket is listening, and also when run() exits, so a
        # waiter can never block forever if bind()/listen() fails.
        self._ready = threading.Event()

    def data(self):
        """Return the text received from the client, or "" if nothing arrived."""
        return self._data

    def file(self):
        """Return the filesystem path of the listening Unix socket."""
        return self._file

    def wait_until_ready(self, timeout=None):
        """Block until run() is accepting connections (or has given up).

        :param timeout: maximum seconds to wait; ``None`` waits indefinitely.
        """
        self._ready.wait(timeout)

    def run(self):
        TIMEOUT = 0.314  # Very short time for tests that don't write to socket.
        MAX_SIZE = 10 * 1024

        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        bound = False  # only remove the socket path if we created it
        try:
            s.bind(self._file)
            bound = True
            s.listen(1)
            s.settimeout(TIMEOUT)
            # Readiness is signalled only after listen(), so a client that
            # connects right after wait_until_ready() is not refused.
            self._ready.set()
            try:
                client, _ = s.accept()
                try:
                    # recv() returns bytes on Python 3; expose text to callers.
                    self._data = client.recv(MAX_SIZE).decode("utf-8", "replace")
                finally:
                    client.close()
            except socket.timeout:
                pass  # No client in time: expected by the error-path tests.
        finally:
            self._ready.set()
            s.close()
            if bound:
                os.remove(self._file)
|
|
|
|
|
|
class TestPatchViaSocket(unittest.TestCase):
    """Tests for ``patchviasocket.read_pub_key`` and ``send_to_socket``."""

    def setUp(self):
        # Check preconditions before creating any files: when setUp() fails,
        # tearDown() is not run, so files created earlier would be leaked.
        self.assertFalse(os.path.exists(NON_EXISTING_FILE))
        self.assertNotEqual("root", getpass.getuser(), "must be non-root user (to test access denied errors)")

        self._key_file = write_key_file()

        # Key file with mode 000, so a non-root user cannot read it.
        self._unreadable = write_key_file()
        os.chmod(self._unreadable, 0)

    def tearDown(self):
        os.remove(self._key_file)
        os.remove(self._unreadable)

    def _start_reader(self):
        """Start a SocketThread and return it once it is accepting connections."""
        reader = SocketThread()
        reader.start()
        reader.wait_until_ready()
        return reader

    def test_read_file(self):
        pub_key = patchviasocket.read_pub_key(self._key_file)
        self.assertEqual(KEY_DATA, pub_key)

    def test_read_file_error(self):
        self.assertIsNone(patchviasocket.read_pub_key(NON_EXISTING_FILE))
        self.assertIsNone(patchviasocket.read_pub_key(self._unreadable))
        # A directory is not a readable file.
        self.assertIsNone(patchviasocket.read_pub_key(tempfile.gettempdir()))

    def test_write_to_socket(self):
        reader = self._start_reader()
        self.assertEqual(0, patchviasocket.send_to_socket(reader.file(), self._key_file, CMD_DATA))
        reader.join()
        data = reader.data()
        if isinstance(data, bytes):
            # A reader built on raw recv() yields bytes on Python 3; compare as text.
            data = data.decode("utf-8")
        self.assertIn(KEY_DATA, data)
        self.assertIn(CMD_DATA.replace("%", " "), data)
        self.assertNotIn("LUV", data)
        self.assertNotIn("very%quickly", data)  # Testing substitution

    def test_host_key_error(self):
        reader = self._start_reader()
        self.assertEqual(1, patchviasocket.send_to_socket(reader.file(), NON_EXISTING_FILE, CMD_DATA))
        reader.join()  # timeout

    def test_nonexistant_socket_error(self):
        reader = self._start_reader()
        self.assertEqual(1, patchviasocket.send_to_socket(NON_EXISTING_FILE, self._key_file, CMD_DATA))
        reader.join()  # timeout

    def test_invalid_socket_error(self):
        reader = self._start_reader()
        # A regular file is passed where a socket path is expected.
        self.assertEqual(1, patchviasocket.send_to_socket(self._key_file, self._key_file, CMD_DATA))
        reader.join()  # timeout

    def test_access_denied_socket_error(self):
        reader = self._start_reader()
        self.assertEqual(1, patchviasocket.send_to_socket(self._unreadable, self._key_file, CMD_DATA))
        reader.join()  # timeout
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported): run every test case in this module.
    unittest.main()
|