From 5c37d144cc6c590f38769aaa570d14226136d171 Mon Sep 17 00:00:00 2001 From: Prasanna Santhanam Date: Wed, 11 Apr 2012 16:43:33 +0530 Subject: [PATCH] Package management for the python testclient christened Marvin $ant package-marvin will create a packaged source tarball of the testclient that is redistributable and decoupled from the rest of the cloudstack build reviewed-by: unittest --- build.xml | 1 + build/build-marvin.xml | 45 + tools/marvin/CHANGES.txt | 1 + tools/marvin/LICENSE.txt | 9 + tools/marvin/MANIFEST.in | 2 + tools/marvin/README | 15 + tools/marvin/docs/tutorial.txt | 1 + .../marvin/marvin/NoseTestExecutionEngine.py | 34 + tools/marvin/marvin/TestCaseExecuteEngine.py | 51 + tools/marvin/marvin/__init__.py | 1 + tools/marvin/marvin/asyncJobMgr.py | 216 ++++ tools/marvin/marvin/cloudstackConnection.py | 163 +++ tools/marvin/marvin/cloudstackException.py | 24 + tools/marvin/marvin/cloudstackTestCase.py | 11 + tools/marvin/marvin/cloudstackTestClient.py | 58 ++ tools/marvin/marvin/codegenerator.py | 277 ++++++ tools/marvin/marvin/configGenerator.py | 378 +++++++ tools/marvin/marvin/dbConnection.py | 78 ++ tools/marvin/marvin/deployAndRun.py | 32 + tools/marvin/marvin/deployDataCenter.py | 274 ++++++ tools/marvin/marvin/jsonHelper.py | 174 ++++ tools/marvin/marvin/pymysql/__init__.py | 131 +++ tools/marvin/marvin/pymysql/charset.py | 174 ++++ tools/marvin/marvin/pymysql/connections.py | 928 ++++++++++++++++++ .../marvin/marvin/pymysql/constants/CLIENT.py | 20 + .../marvin/pymysql/constants/COMMAND.py | 23 + tools/marvin/marvin/pymysql/constants/ER.py | 472 +++++++++ .../marvin/pymysql/constants/FIELD_TYPE.py | 32 + tools/marvin/marvin/pymysql/constants/FLAG.py | 15 + .../marvin/pymysql/constants/SERVER_STATUS.py | 12 + .../marvin/pymysql/constants/__init__.py | 0 tools/marvin/marvin/pymysql/converters.py | 348 +++++++ tools/marvin/marvin/pymysql/cursors.py | 297 ++++++ tools/marvin/marvin/pymysql/err.py | 147 +++ tools/marvin/marvin/pymysql/tests/__init__.py | 13 + tools/marvin/marvin/pymysql/tests/base.py | 20 + .../marvin/pymysql/tests/test_DictCursor.py | 56 ++ .../marvin/marvin/pymysql/tests/test_basic.py | 193 ++++ .../marvin/pymysql/tests/test_example.py | 32 + .../marvin/pymysql/tests/test_issues.py | 268 +++++ .../pymysql/tests/thirdparty/__init__.py | 5 + .../tests/thirdparty/test_MySQLdb/__init__.py | 7 + .../thirdparty/test_MySQLdb/capabilities.py | 292 ++++++ .../tests/thirdparty/test_MySQLdb/dbapi20.py | 853 ++++++++++++++++ .../test_MySQLdb/test_MySQLdb_capabilities.py | 115 +++ .../test_MySQLdb/test_MySQLdb_dbapi20.py | 205 ++++ .../test_MySQLdb/test_MySQLdb_nonstandard.py | 90 ++ tools/marvin/marvin/pymysql/times.py | 16 + tools/marvin/marvin/pymysql/util.py | 19 + tools/marvin/marvin/remoteSSHClient.py | 36 + tools/marvin/marvin/sandbox/README.txt | 19 + .../marvin/sandbox/advanced/advanced_env.py | 117 +++ .../marvin/sandbox/advanced/setup.properties | 36 + .../sandbox/advanced/tests/test_scenarios.py | 126 +++ .../marvin/marvin/sandbox/basic/basic_env.py | 0 tools/marvin/marvin/sandbox/demo/README | 4 + .../marvin/sandbox/demo/testDeployVM.py | 98 ++ .../marvin/sandbox/demo/testSshDeployVM.py | 143 +++ tools/marvin/setup.py | 34 + 59 files changed, 7241 insertions(+) create mode 100644 build/build-marvin.xml create mode 100644 tools/marvin/CHANGES.txt create mode 100644 tools/marvin/LICENSE.txt create mode 100644 tools/marvin/MANIFEST.in create mode 100644 tools/marvin/README create mode 100644 tools/marvin/docs/tutorial.txt create mode 100644 tools/marvin/marvin/NoseTestExecutionEngine.py create mode 100644 tools/marvin/marvin/TestCaseExecuteEngine.py create mode 100644 tools/marvin/marvin/__init__.py create mode 100644 tools/marvin/marvin/asyncJobMgr.py create mode 100644 tools/marvin/marvin/cloudstackConnection.py create mode 100644 tools/marvin/marvin/cloudstackException.py create mode 100644 tools/marvin/marvin/cloudstackTestCase.py create mode 100644 tools/marvin/marvin/cloudstackTestClient.py create mode 100644 tools/marvin/marvin/codegenerator.py create mode 100644 tools/marvin/marvin/configGenerator.py create mode 100644 tools/marvin/marvin/dbConnection.py create mode 100644 tools/marvin/marvin/deployAndRun.py create mode 100644 tools/marvin/marvin/deployDataCenter.py create mode 100644 tools/marvin/marvin/jsonHelper.py create mode 100644 tools/marvin/marvin/pymysql/__init__.py create mode 100644 tools/marvin/marvin/pymysql/charset.py create mode 100644 tools/marvin/marvin/pymysql/connections.py create mode 100644 tools/marvin/marvin/pymysql/constants/CLIENT.py create mode 100644 tools/marvin/marvin/pymysql/constants/COMMAND.py create mode 100644 tools/marvin/marvin/pymysql/constants/ER.py create mode 100644 tools/marvin/marvin/pymysql/constants/FIELD_TYPE.py create mode 100644 tools/marvin/marvin/pymysql/constants/FLAG.py create mode 100644 tools/marvin/marvin/pymysql/constants/SERVER_STATUS.py create mode 100644 tools/marvin/marvin/pymysql/constants/__init__.py create mode 100644 tools/marvin/marvin/pymysql/converters.py create mode 100644 tools/marvin/marvin/pymysql/cursors.py create mode 100644 tools/marvin/marvin/pymysql/err.py create mode 100644 tools/marvin/marvin/pymysql/tests/__init__.py create mode 100644 tools/marvin/marvin/pymysql/tests/base.py create mode 100644 tools/marvin/marvin/pymysql/tests/test_DictCursor.py create mode 100644 tools/marvin/marvin/pymysql/tests/test_basic.py create mode 100644 tools/marvin/marvin/pymysql/tests/test_example.py create mode 100644 tools/marvin/marvin/pymysql/tests/test_issues.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py create mode 100644 tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py create mode 100644 tools/marvin/marvin/pymysql/times.py create mode 100644 tools/marvin/marvin/pymysql/util.py create mode 100644 tools/marvin/marvin/remoteSSHClient.py create mode 100644 tools/marvin/marvin/sandbox/README.txt create mode 100644 tools/marvin/marvin/sandbox/advanced/advanced_env.py create mode 100644 tools/marvin/marvin/sandbox/advanced/setup.properties create mode 100644 tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py create mode 100644 tools/marvin/marvin/sandbox/basic/basic_env.py create mode 100644 tools/marvin/marvin/sandbox/demo/README create mode 100644 tools/marvin/marvin/sandbox/demo/testDeployVM.py create mode 100644 tools/marvin/marvin/sandbox/demo/testSshDeployVM.py create mode 100644 tools/marvin/setup.py diff --git a/build.xml b/build.xml index 9753213c23b..1a6ae190fb9 100755 --- a/build.xml +++ b/build.xml @@ -23,6 +23,7 @@ + diff --git a/build/build-marvin.xml b/build/build-marvin.xml new file mode 100644 index 00000000000..ed5cf565e26 --- /dev/null +++ b/build/build-marvin.xml @@ -0,0 +1,45 @@ + + + + + + This build file contains simple targets that + - build + - package + - distribute + the Marvin test client written in python + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tools/marvin/CHANGES.txt b/tools/marvin/CHANGES.txt new file mode 100644 index 00000000000..dc207fb992a --- /dev/null +++ b/tools/marvin/CHANGES.txt @@ -0,0 +1 @@ +v0.1.0 Tuesday, April 10 2012 -- Packaging Marvin diff --git a/tools/marvin/LICENSE.txt b/tools/marvin/LICENSE.txt new file mode 100644 index 00000000000..fa64b4a85d0 --- /dev/null +++ b/tools/marvin/LICENSE.txt @@ -0,0 +1,9 @@ +Copyright 2012 Citrix Systems, Inc. Licensed under the Apache License, Version +2.0 (the "License"); you may not use this file except in compliance with the +License. Citrix Systems, Inc. reserves all rights not expressly granted by the +License. You may obtain a copy of the License at +http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or +agreed to in writing, software distributed under the License is distributed on +an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +or implied. See the License for the specific language governing permissions and +limitations under the License. diff --git a/tools/marvin/MANIFEST.in b/tools/marvin/MANIFEST.in new file mode 100644 index 00000000000..92036977a8c --- /dev/null +++ b/tools/marvin/MANIFEST.in @@ -0,0 +1,2 @@ +include *.txt +recursive-include docs *.txt \ No newline at end of file diff --git a/tools/marvin/README b/tools/marvin/README new file mode 100644 index 00000000000..b225b6e93fc --- /dev/null +++ b/tools/marvin/README @@ -0,0 +1,15 @@ +Marvin is the testing framework for CloudStack written in python. Writing of +unittests and functional tests with Marvin makes testing with cloudstack easier + +1. INSTALL + untar Marvin-0.1.0.tar.gz + cd Marvin-0.1.0 + python setup.py install + +2. Facility it provides: + 1. very handy cloudstack API python wrapper + 2. support async job executing in parallel + 3. remote ssh login/execute command + 4. mysql query + +3. sample code is under sandbox diff --git a/tools/marvin/docs/tutorial.txt b/tools/marvin/docs/tutorial.txt new file mode 100644 index 00000000000..4da4b1b1ef1 --- /dev/null +++ b/tools/marvin/docs/tutorial.txt @@ -0,0 +1 @@ +Can be found at : http://wiki.cloudstack.org/display/QA/Testing+with+python diff --git a/tools/marvin/marvin/NoseTestExecutionEngine.py b/tools/marvin/marvin/NoseTestExecutionEngine.py new file mode 100644 index 00000000000..c64e0c1a5f7 --- /dev/null +++ b/tools/marvin/marvin/NoseTestExecutionEngine.py @@ -0,0 +1,34 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest + +from functools import partial +import nose +import nose.config +import nose.core +import os +import sys +import logging + +module_logger = "testclient.nose" + +def testCaseLogger(message, logger=None): + if logger is not None: + logger.debug(message) + +class TestCaseExecuteEngine(object): + def __init__(self, testclient, testCaseFolder, testcaseLogFile=None, testResultLogFile=None): + self.testclient = testclient + self.debuglog = testcaseLogFile + self.testCaseFolder = testCaseFolder + self.testResultLogFile = testResultLogFile + self.cfg = nose.config.Config() + self.cfg.configureWhere(self.testCaseFolder) + self.cfg.configureLogging() + + def run(self): + self.args = ["--debug-log="+self.debuglog] + suite = nose.core.TestProgram(argv = self.args, config = self.cfg) + result = suite.runTests() + print result diff --git a/tools/marvin/marvin/TestCaseExecuteEngine.py b/tools/marvin/marvin/TestCaseExecuteEngine.py new file mode 100644 index 00000000000..cceed9ca8bd --- /dev/null +++ b/tools/marvin/marvin/TestCaseExecuteEngine.py @@ -0,0 +1,51 @@ +try: + import unittest2 as unittest +except ImportError: + import unittest + +from functools import partial +import os +import sys +import logging + +module_logger = "testclient.testcase" + +def testCaseLogger(message, logger=None): + if logger is not None: + logger.debug(message) + +class TestCaseExecuteEngine(object): + def __init__(self, testclient, testCaseFolder, testcaseLogFile=None, testResultLogFile=None): + self.testclient = testclient + self.testCaseFolder = testCaseFolder + self.logger = None + if testcaseLogFile is not None: + logger = logging.getLogger("testclient.testcase.TestCaseExecuteEngine") + fh = logging.FileHandler(testcaseLogFile) + fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + logger.addHandler(fh) + logger.setLevel(logging.DEBUG) + self.logger = logger + if testResultLogFile is not None: + ch = logging.StreamHandler() + ch.setLevel(logging.ERROR) + ch.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + self.logger.addHandler(ch) + fp = open(testResultLogFile, "w") + self.testResultLogFile = fp + else: + self.testResultLogFile = sys.stdout + + def injectTestCase(self, testSuites): + for test in testSuites: + if isinstance(test, unittest.BaseTestSuite): + self.injectTestCase(test) + else: + setattr(test, "testClient", self.testclient) + setattr(test, "debug", partial(testCaseLogger, logger=self.logger)) + def run(self): + loader = unittest.loader.TestLoader() + suite = loader.discover(self.testCaseFolder) + self.injectTestCase(suite) + + unittest.TextTestRunner(stream=self.testResultLogFile, verbosity=2).run(suite) diff --git a/tools/marvin/marvin/__init__.py b/tools/marvin/marvin/__init__.py new file mode 100644 index 00000000000..a3ded0a92d3 --- /dev/null +++ b/tools/marvin/marvin/__init__.py @@ -0,0 +1 @@ +#Marvin - The cloudstack test client \ No newline at end of file diff --git a/tools/marvin/marvin/asyncJobMgr.py b/tools/marvin/marvin/asyncJobMgr.py new file mode 100644 index 00000000000..d843d05ec6e --- /dev/null +++ b/tools/marvin/marvin/asyncJobMgr.py @@ -0,0 +1,216 @@ +import threading +import cloudstackException +import time +import Queue +import copy +import sys +import jsonHelper +import datetime + +class job(object): + def __init__(self): + self.id = None + self.cmd = None +class jobStatus(object): + def __init__(self): + self.result = None + self.status = None + self.startTime = None + self.endTime = None + self.duration = None + self.jobId = None + self.responsecls = None +class workThread(threading.Thread): + def __init__(self, in_queue, outqueue, apiClient, db=None, lock=None): + threading.Thread.__init__(self) + self.inqueue = in_queue + self.output = outqueue + self.connection = apiClient.connection + self.db = None + self.lock = lock + + def queryAsynJob(self, job): + if job.jobId is None: + return job + + try: + self.lock.acquire() + result = self.connection.pollAsyncJob(job.jobId, job.responsecls).jobresult + except cloudstackException.cloudstackAPIException, e: + result = str(e) + finally: + self.lock.release() + + job.result = result + return job + + def executeCmd(self, job): + cmd = job.cmd + + jobstatus = jobStatus() + jobId = None + try: + self.lock.acquire() + + if cmd.isAsync == "false": + jobstatus.startTime = datetime.datetime.now() + + result = self.connection.make_request(cmd) + jobstatus.result = result + jobstatus.endTime = datetime.datetime.now() + jobstatus.duration = time.mktime(jobstatus.endTime.timetuple()) - time.mktime(jobstatus.startTime.timetuple()) + else: + result = self.connection.make_request(cmd, None, True) + if result is None: + jobstatus.status = False + else: + jobId = result.jobid + jobstatus.jobId = jobId + try: + responseName = cmd.__class__.__name__.replace("Cmd", "Response") + jobstatus.responsecls = jsonHelper.getclassFromName(cmd, responseName) + except: + pass + jobstatus.status = True + except cloudstackException.cloudstackAPIException, e: + jobstatus.result = str(e) + jobstatus.status = False + except: + jobstatus.status = False + jobstatus.result = sys.exc_info() + finally: + self.lock.release() + + return jobstatus + + def run(self): + while self.inqueue.qsize() > 0: + job = self.inqueue.get() + if isinstance(job, jobStatus): + jobstatus = self.queryAsynJob(job) + else: + jobstatus = self.executeCmd(job) + + self.output.put(jobstatus) + self.inqueue.task_done() + + '''release the resource''' + self.connection.close() + +class jobThread(threading.Thread): + def __init__(self, inqueue, interval): + threading.Thread.__init__(self) + self.inqueue = inqueue + self.interval = interval + def run(self): + while self.inqueue.qsize() > 0: + job = self.inqueue.get() + try: + job.run() + '''release the api connection''' + job.apiClient.connection.close() + except: + pass + + self.inqueue.task_done() + time.sleep(self.interval) + +class outputDict(object): + def __init__(self): + self.lock = threading.Condition() + self.dict = {} + +class asyncJobMgr(object): + def __init__(self, apiClient, db): + self.inqueue = Queue.Queue() + self.output = outputDict() + self.outqueue = Queue.Queue() + self.apiClient = apiClient + self.db = db + + def submitCmds(self, cmds): + if not self.inqueue.empty(): + return False + id = 0 + ids = [] + for cmd in cmds: + asyncjob = job() + asyncjob.id = id + asyncjob.cmd = cmd + self.inqueue.put(asyncjob) + id += 1 + ids.append(id) + return ids + + def updateTimeStamp(self, jobstatus): + jobId = jobstatus.jobId + if jobId is not None and self.db is not None: + result = self.db.execute("select job_status, created, last_updated from async_job where id=%s"%jobId) + if result is not None and len(result) > 0: + if result[0][0] == 1: + jobstatus.status = True + else: + jobstatus.status = False + jobstatus.startTime = result[0][1] + jobstatus.endTime = result[0][2] + delta = jobstatus.endTime - jobstatus.startTime + jobstatus.duration = delta.total_seconds() + + def waitForComplete(self, workers=10): + self.inqueue.join() + lock = threading.Lock() + resultQueue = Queue.Queue() + '''intermediate result is stored in self.outqueue''' + for i in range(workers): + worker = workThread(self.outqueue, resultQueue, self.apiClient, self.db, lock) + worker.start() + + self.outqueue.join() + + asyncJobResult = [] + while resultQueue.qsize() > 0: + jobstatus = resultQueue.get() + self.updateTimeStamp(jobstatus) + asyncJobResult.append(jobstatus) + + return asyncJobResult + + '''put commands into a queue at first, then start workers numbers threads to execute this commands''' + def submitCmdsAndWait(self, cmds, workers=10): + self.submitCmds(cmds) + lock = threading.Lock() + for i in range(workers): + worker = workThread(self.inqueue, self.outqueue, self.apiClient, self.db, lock) + worker.start() + + return self.waitForComplete(workers) + + '''submit one job and execute the same job ntimes, with nums_threads of threads''' + def submitJobExecuteNtimes(self, job, ntimes=1, nums_threads=1, interval=1): + inqueue1 = Queue.Queue() + lock = threading.Condition() + for i in range(ntimes): + newjob = copy.copy(job) + setattr(newjob, "apiClient", copy.copy(self.apiClient)) + setattr(newjob, "lock", lock) + inqueue1.put(newjob) + + for i in range(nums_threads): + work = jobThread(inqueue1, interval) + work.start() + inqueue1.join() + + '''submit n jobs, execute them with nums_threads of threads''' + def submitJobs(self, jobs, nums_threads=1, interval=1): + inqueue1 = Queue.Queue() + lock = threading.Condition() + + for job in jobs: + setattr(job, "apiClient", copy.copy(self.apiClient)) + setattr(job, "lock", lock) + inqueue1.put(job) + + for i in range(nums_threads): + work = jobThread(inqueue1, interval) + work.start() + inqueue1.join() \ No newline at end of file diff --git a/tools/marvin/marvin/cloudstackConnection.py b/tools/marvin/marvin/cloudstackConnection.py new file mode 100644 index 00000000000..d3a57263df7 --- /dev/null +++ b/tools/marvin/marvin/cloudstackConnection.py @@ -0,0 +1,163 @@ +import urllib2 +import urllib +import base64 +import copy +import hmac +import hashlib +import json +import xml.dom.minidom +import types +import time +import inspect +import cloudstackException +from cloudstackAPI import * +import jsonHelper + +class cloudConnection(object): + def __init__(self, mgtSvr, port=8096, apiKey = None, securityKey = None, asyncTimeout=3600, logging=None): + self.apiKey = apiKey + self.securityKey = securityKey + self.mgtSvr = mgtSvr + self.port = port + self.logging = logging + if port == 8096: + self.auth = False + else: + self.auth = True + + self.retries = 5 + self.asyncTimeout = asyncTimeout + + def close(self): + try: + self.connection.close() + except: + pass + + def __copy__(self): + return cloudConnection(self.mgtSvr, self.port, self.apiKey, self.securityKey, self.asyncTimeout, self.logging) + + def make_request_with_auth(self, command, requests={}): + requests["command"] = command + requests["apiKey"] = self.apiKey + requests["response"] = "json" + request = zip(requests.keys(), requests.values()) + request.sort(key=lambda x: str.lower(x[0])) + + requestUrl = "&".join(["=".join([r[0], urllib.quote_plus(str(r[1]))]) for r in request]) + hashStr = "&".join(["=".join([str.lower(r[0]), str.lower(urllib.quote_plus(str(r[1]))).replace("+", "%20")]) for r in request]) + + sig = urllib.quote_plus(base64.encodestring(hmac.new(self.securityKey, hashStr, hashlib.sha1).digest()).strip()) + requestUrl += "&signature=%s"%sig + + try: + self.connection = urllib2.urlopen("http://%s:%d/client/api?%s"%(self.mgtSvr, self.port, requestUrl)) + self.logging.debug("sending request: %s"%requestUrl) + response = self.connection.read() + self.logging.debug("got response: %s"%response) + except IOError, e: + if hasattr(e, 'reason'): + self.logging.debug("failed to reach %s because of %s"%(self.mgtSvr, e.reason)) + elif hasattr(e, 'code'): + self.logging.debug("server returned %d error code"%e.code) + except HTTPException, h: + self.logging.debug("encountered http Exception %s"%h.args) + if self.retries > 0: + self.retries = self.retries - 1 + self.make_request_with_auth(command, requests) + else: + self.retries = 5 + raise h + else: + return response + + def make_request_without_auth(self, command, requests={}): + requests["command"] = command + requests["response"] = "json" + requests = zip(requests.keys(), requests.values()) + requestUrl = "&".join(["=".join([request[0], urllib.quote_plus(str(request[1]))]) for request in requests]) + + self.connection = urllib2.urlopen("http://%s:%d/client/api?%s"%(self.mgtSvr, self.port, requestUrl)) + self.logging.debug("sending request without auth: %s"%requestUrl) + response = self.connection.read() + self.logging.debug("got response: %s"%response) + return response + + def pollAsyncJob(self, jobId, response): + cmd = queryAsyncJobResult.queryAsyncJobResultCmd() + cmd.jobid = jobId + + while self.asyncTimeout > 0: + asyncResonse = self.make_request(cmd, response, True) + + if asyncResonse.jobstatus == 2: + raise cloudstackException.cloudstackAPIException("asyncquery", asyncResonse.jobresult) + elif asyncResonse.jobstatus == 1: + return asyncResonse + + time.sleep(5) + self.asyncTimeout = self.asyncTimeout - 5 + + raise cloudstackException.cloudstackAPIException("asyncquery", "Async job timeout") + + def make_request(self, cmd, response = None, raw=False): + commandName = cmd.__class__.__name__.replace("Cmd", "") + isAsync = "false" + requests = {} + required = [] + for attribute in dir(cmd): + if attribute != "__doc__" and attribute != "__init__" and attribute != "__module__": + if attribute == "isAsync": + isAsync = getattr(cmd, attribute) + elif attribute == "required": + required = getattr(cmd, attribute) + else: + requests[attribute] = getattr(cmd, attribute) + + for requiredPara in required: + if requests[requiredPara] is None: + raise cloudstackException.cloudstackAPIException(commandName, "%s is required"%requiredPara) + '''remove none value''' + for param, value in requests.items(): + if value is None: + requests.pop(param) + elif isinstance(value, list): + if len(value) == 0: + requests.pop(param) + else: + if not isinstance(value[0], dict): + requests[param] = ",".join(value) + else: + requests.pop(param) + i = 0 + for v in value: + for key, val in v.iteritems(): + requests["%s[%d].%s"%(param,i,key)] = val + i = i + 1 + + if self.logging is not None: + self.logging.debug("sending command: %s %s"%(commandName, str(requests))) + result = None + if self.auth: + result = self.make_request_with_auth(commandName, requests) + else: + result = self.make_request_without_auth(commandName, requests) + + if result is None: + return None + if self.logging is not None: + self.logging.debug("got result: " + result) + + result = jsonHelper.getResultObj(result, response) + if raw or isAsync == "false": + return result + else: + asynJobId = result.jobid + result = self.pollAsyncJob(asynJobId, response) + return result.jobresult + +if __name__ == '__main__': + xml = '407i-1-407-RS3i-1-407-RS3system1ROOT2011-07-30T14:45:19-0700Runningfalse1CA13kvm-50-2054CentOS 5.5(64-bit) no GUI (KVM)CentOS 5.5(64-bit) no GUI (KVM)false1Small Instance15005121120NetworkFilesystem380203255.255.255.065.19.181.165.19.181.110vlan://65vlan://65GuestDirecttrue06:52:da:00:00:08KVM' + conn = cloudConnection(None) + + print conn.paraseReturnXML(xml, deployVirtualMachine.deployVirtualMachineResponse()) diff --git a/tools/marvin/marvin/cloudstackException.py b/tools/marvin/marvin/cloudstackException.py new file mode 100644 index 00000000000..f731be383c7 --- /dev/null +++ b/tools/marvin/marvin/cloudstackException.py @@ -0,0 +1,24 @@ + +class cloudstackAPIException(Exception): + def __init__(self, cmd = "", result = ""): + self.errorMsg = "Execute cmd: %s failed, due to: %s"%(cmd, result) + def __str__(self): + return self.errorMsg + +class InvalidParameterException(Exception): + def __init__(self, msg=''): + self.errorMsg = msg + def __str__(self): + return self.errorMsg + +class dbException(Exception): + def __init__(self, msg=''): + self.errorMsg = msg + def __str__(self): + return self.errorMsg + +class internalError(Exception): + def __init__(self, msg=''): + self.errorMsg = msg + def __str__(self): + return self.errorMsg \ No newline at end of file diff --git a/tools/marvin/marvin/cloudstackTestCase.py b/tools/marvin/marvin/cloudstackTestCase.py new file mode 100644 index 00000000000..4595aefaa42 --- /dev/null +++ b/tools/marvin/marvin/cloudstackTestCase.py @@ -0,0 +1,11 @@ +from cloudstackAPI import * +try: + import unittest2 as unittest +except ImportError: + import unittest +import cloudstackTestClient + +class cloudstackTestCase(unittest.case.TestCase): + def __init__(self, args): + unittest.case.TestCase.__init__(self, args) + self.testClient = cloudstackTestClient.cloudstackTestClient() diff --git a/tools/marvin/marvin/cloudstackTestClient.py b/tools/marvin/marvin/cloudstackTestClient.py new file mode 100644 index 00000000000..0d8f8108ea9 --- /dev/null +++ b/tools/marvin/marvin/cloudstackTestClient.py @@ -0,0 +1,58 @@ +import cloudstackConnection +import asyncJobMgr +import dbConnection +from cloudstackAPI import * + +class cloudstackTestClient(object): + def __init__(self, mgtSvr=None, port=8096, apiKey = None, securityKey = None, asyncTimeout=3600, defaultWorkerThreads=10, logging=None): + self.connection = cloudstackConnection.cloudConnection(mgtSvr, port, apiKey, securityKey, asyncTimeout, logging) + self.apiClient = cloudstackAPIClient.CloudStackAPIClient(self.connection) + self.dbConnection = None + self.asyncJobMgr = None + self.ssh = None + self.defaultWorkerThreads = defaultWorkerThreads + + + def dbConfigure(self, host="localhost", port=3306, user='cloud', passwd='cloud', db='cloud'): + self.dbConnection = dbConnection.dbConnection(host, port, user, passwd, db) + + def close(self): + if self.connection is not None: + self.connection.close() + if self.dbConnection is not None: + self.dbConnection.close() + + def getDbConnection(self): + return self.dbConnection + + def executeSql(self, sql=None): + if sql is None or self.dbConnection is None: + return None + + return self.dbConnection.execute() + + def executeSqlFromFile(self, sqlFile=None): + if sqlFile is None or self.dbConnection is None: + return None + return self.dbConnection.executeSqlFromFile(sqlFile) + + def getApiClient(self): + return self.apiClient + + '''FixME, httplib has issue if more than one thread submitted''' + def submitCmdsAndWait(self, cmds, workers=1): + if self.asyncJobMgr is None: + self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection) + return self.asyncJobMgr.submitCmdsAndWait(cmds, workers) + + '''submit one job and execute the same job ntimes, with nums_threads of threads''' + def submitJob(self, job, ntimes=1, nums_threads=10, interval=1): + if self.asyncJobMgr is None: + self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection) + self.asyncJobMgr.submitJobExecuteNtimes(job, ntimes, nums_threads, interval) + + '''submit n jobs, execute them with nums_threads of threads''' + def submitJobs(self, jobs, nums_threads=10, interval=1): + if self.asyncJobMgr is None: + self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection) + self.asyncJobMgr.submitJobs(jobs, nums_threads, interval) \ No newline at end of file diff --git a/tools/marvin/marvin/codegenerator.py b/tools/marvin/marvin/codegenerator.py new file mode 100644 index 00000000000..aea53b8854b --- /dev/null +++ b/tools/marvin/marvin/codegenerator.py @@ -0,0 +1,277 @@ +import xml.dom.minidom +from optparse import OptionParser +import os +import sys +class cmdParameterProperty(object): + def __init__(self): + self.name = None + self.required = False + self.desc = "" + self.type = "planObject" + self.subProperties = [] + +class cloudStackCmd: + def __init__(self): + self.name = "" + self.desc = "" + self.async = "false" + self.request = [] + self.response = [] + +class codeGenerator: + space = " " + + cmdsName = [] + + def __init__(self, outputFolder, apiSpecFile): + self.cmd = None + self.code = "" + self.required = [] + self.subclass = [] + self.outputFolder = outputFolder + self.apiSpecFile = apiSpecFile + + def addAttribute(self, attr, pro): + value = pro.value + if pro.required: + self.required.append(attr) + desc = pro.desc + if desc is not None: + self.code += self.space + self.code += "''' " + pro.desc + " '''" + self.code += "\n" + + self.code += self.space + self.code += attr + " = " + str(value) + self.code += "\n" + + def generateSubClass(self, name, properties): + '''generate code for sub list''' + subclass = 'class %s:\n'%name + subclass += self.space + "def __init__(self):\n" + for pro in properties: + if pro.desc is not None: + subclass += self.space + self.space + '""""%s"""\n'%pro.desc + if len (pro.subProperties) > 0: + subclass += self.space + self.space + 'self.%s = []\n'%pro.name + self.generateSubClass(pro.name, pro.subProperties) + else: + subclass += self.space + self.space + 'self.%s = None\n'%pro.name + + self.subclass.append(subclass) + def generate(self, cmd): + + self.cmd = cmd + self.cmdsName.append(self.cmd.name) + self.code += "\n" + self.code += '"""%s"""\n'%self.cmd.desc + self.code += 'from baseCmd import *\n' + self.code += 'from baseResponse import *\n' + self.code += "class %sCmd (baseCmd):\n"%self.cmd.name + self.code += self.space + "def __init__(self):\n" + + self.code += self.space + self.space + 'self.isAsync = "%s"\n' %self.cmd.async + + for req in self.cmd.request: + if req.desc is not None: + self.code += self.space + self.space + '"""%s"""\n'%req.desc + if req.required == "true": + self.code += self.space + self.space + '"""Required"""\n' + + value = "None" + if req.type == "list" or req.type == "map": + value = "[]" + + self.code += self.space + self.space + 'self.%s = %s\n'%(req.name,value) + if req.required == "true": + self.required.append(req.name) + + self.code += self.space + self.space + "self.required = [" + for require in self.required: + self.code += '"' + require + '",' + self.code += "]\n" + self.required = [] + + + """generate response code""" + subItems = {} + self.code += "\n" + self.code += 'class %sResponse (baseResponse):\n'%self.cmd.name + self.code += self.space + "def __init__(self):\n" + if len(self.cmd.response) == 0: + self.code += self.space + self.space + "pass" + else: + for res in self.cmd.response: + if res.desc is not None: + self.code += self.space + self.space + '"""%s"""\n'%res.desc + + if len(res.subProperties) > 0: + self.code += self.space + self.space + 'self.%s = []\n'%res.name + self.generateSubClass(res.name, res.subProperties) + else: + self.code += self.space + self.space + 'self.%s = None\n'%res.name + self.code += '\n' + + for subclass in self.subclass: + self.code += subclass + "\n" + + fp = open(self.outputFolder + "/cloudstackAPI/%s.py"%self.cmd.name, "w") + fp.write(self.code) + fp.close() + self.code = "" + self.subclass = [] + + + def finalize(self): + '''generate an api call''' + + header = '"""Test Client for CloudStack API"""\n' + imports = "import copy\n" + initCmdsList = '__all__ = [' + body = '' + body += "class CloudStackAPIClient:\n" + body += self.space + 'def __init__(self, connection):\n' + body += self.space + self.space + 'self.connection = connection\n' + body += "\n" + + body += self.space + 'def __copy__(self):\n' + body += self.space + self.space + 'return CloudStackAPIClient(copy.copy(self.connection))\n' + body += "\n" + + for cmdName in self.cmdsName: + body += self.space + 'def %s(self,command):\n'%cmdName + body += self.space + self.space + 'response = %sResponse()\n'%cmdName + body += self.space + self.space + 'response = self.connection.make_request(command, response)\n' + body += self.space + self.space + 'return response\n' + body += '\n' + + imports += 'from %s import %sResponse\n'%(cmdName, cmdName) + initCmdsList += '"%s",'%cmdName + + fp = open(self.outputFolder + '/cloudstackAPI/cloudstackAPIClient.py', 'w') + for item in [header, imports, body]: + fp.write(item) + fp.close() + + '''generate __init__.py''' + initCmdsList += '"cloudstackAPIClient"]' + fp = open(self.outputFolder + '/cloudstackAPI/__init__.py', 'w') + fp.write(initCmdsList) + fp.close() + + fp = open(self.outputFolder + '/cloudstackAPI/baseCmd.py', 'w') + basecmd = '"""Base Command"""\n' + basecmd += 'class baseCmd:\n' + basecmd += self.space + 'pass\n' + fp.write(basecmd) + fp.close() + + fp = open(self.outputFolder + '/cloudstackAPI/baseResponse.py', 'w') + basecmd = '"""Base class for response"""\n' + basecmd += 'class baseResponse:\n' + basecmd += self.space + 'pass\n' + fp.write(basecmd) + fp.close() + + + def constructResponse(self, response): + paramProperty = cmdParameterProperty() + paramProperty.name = getText(response.getElementsByTagName('name')) + paramProperty.desc = getText(response.getElementsByTagName('description')) + if paramProperty.name.find('(*)') != -1: + '''This is a list''' + paramProperty.name = paramProperty.name.split('(*)')[0] + for subresponse in response.getElementsByTagName('arguments')[0].getElementsByTagName('arg'): + subProperty = self.constructResponse(subresponse) + paramProperty.subProperties.append(subProperty) + return paramProperty + + def loadCmdFromXML(self): + dom = xml.dom.minidom.parse(self.apiSpecFile) + cmds = [] + for cmd in dom.getElementsByTagName("command"): + csCmd = cloudStackCmd() + csCmd.name = getText(cmd.getElementsByTagName('name')) + assert csCmd.name + + desc = getText(cmd.getElementsByTagName('description')) + if desc: + csCmd.desc = desc + + async = getText(cmd.getElementsByTagName('isAsync')) + if async: + csCmd.async = async + + for param in cmd.getElementsByTagName("request")[0].getElementsByTagName("arg"): + paramProperty = cmdParameterProperty() + + paramProperty.name = getText(param.getElementsByTagName('name')) + assert paramProperty.name + + required = param.getElementsByTagName('required') + if required: + paramProperty.required = getText(required) + + requestDescription = param.getElementsByTagName('description') + if requestDescription: + paramProperty.desc = getText(requestDescription) + + type = param.getElementsByTagName("type") + if type: + paramProperty.type = getText(type) + + csCmd.request.append(paramProperty) + + responseEle = cmd.getElementsByTagName("response")[0] + for response in responseEle.getElementsByTagName("arg"): + if response.parentNode != responseEle: + continue + + paramProperty = self.constructResponse(response) + csCmd.response.append(paramProperty) + + cmds.append(csCmd) + return cmds + + def generateCode(self): + cmds = self.loadCmdFromXML() + for cmd in cmds: + self.generate(cmd) + self.finalize() + +def getText(elements): + return elements[0].childNodes[0].nodeValue.strip() + +if __name__ == "__main__": + parser = OptionParser() + + parser.add_option("-o", "--output", dest="output", help="the root path where code genereted, default is .") + parser.add_option("-s", "--specfile", dest="spec", help="the path and name of the api spec xml file, default is /etc/cloud/cli/commands.xml") + + (options, args) = parser.parse_args() + + apiSpecFile = "/etc/cloud/cli/commands.xml" + if options.spec is not None: + apiSpecFile = options.spec + + if not os.path.exists(apiSpecFile): + print "the spec file %s does not exists"%apiSpecFile + print parser.print_help() + exit(1) + + + folder = "." + if options.output is not None: + folder = options.output + apiModule=folder + "/cloudstackAPI" + if not os.path.exists(apiModule): + try: + os.mkdir(apiModule) + except: + print "Failed to create folder %s, due to %s"%(apiModule,sys.exc_info()) + print parser.print_help() + exit(2) + + cg = codeGenerator(folder, apiSpecFile) + cg.generateCode() + diff --git a/tools/marvin/marvin/configGenerator.py b/tools/marvin/marvin/configGenerator.py new file mode 100644 index 00000000000..3baa3edcf52 --- /dev/null +++ b/tools/marvin/marvin/configGenerator.py @@ -0,0 +1,378 @@ +import json +import os +from optparse import OptionParser +import jsonHelper + +class managementServer(): + def __init__(self): + self.mgtSvrIp = None + self.port = 8096 + self.apiKey = None + self.securityKey = None + +class dbServer(): + def __init__(self): + self.dbSvr = None + self.port = 3306 + self.user = "cloud" + self.passwd = "cloud" + self.db = "cloud" + +class configuration(): + def __init__(self): + self.name = None + self.value = None + +class logger(): + def __init__(self): + '''TestCase/TestClient''' + self.name = None + self.file = None + +class cloudstackConfiguration(): + def __init__(self): + self.zones = [] + self.mgtSvr = [] + self.dbSvr = None + self.globalConfig = [] + self.logger = [] + +class zone(): + def __init__(self): + self.dns1 = None + self.internaldns1 = None + self.name = None + '''Basic or Advanced''' + self.networktype = None + self.dns2 = None + self.guestcidraddress = None + self.internaldns2 = None + self.securitygroupenabled = None + self.vlan = None + '''default public network, in advanced mode''' + self.ipranges = [] + '''tagged network, in advanced mode''' + self.networks = [] + self.pods = [] + self.secondaryStorages = [] + +class pod(): + def __init__(self): + self.gateway = None + self.name = None + self.netmask = None + self.startip = None + self.endip = None + self.zoneid = None + self.clusters = [] + '''Used in basic network mode''' + self.guestIpRanges = [] + +class cluster(): + def __init__(self): + self.clustername = None + self.clustertype = None + self.hypervisor = None + self.zoneid = None + self.podid = None + self.password = None + self.url = None + self.username = None + self.hosts = [] + self.primaryStorages = [] + +class host(): + def __init__(self): + self.hypervisor = None + self.password = None + self.url = None + self.username = None + self.zoneid = None + self.podid = None + self.clusterid = None + self.clustername = None + self.cpunumber = None + self.cpuspeed = None + self.hostmac = None + self.hosttags = None + self.memory = None + +class network(): + def __init__(self): + self.displaytext = None + self.name = None + self.zoneid = None + self.account = None + self.domainid = None + self.isdefault = None + self.isshared = None + self.networkdomain = None + self.networkofferingid = None + self.ipranges = [] + +class iprange(): + def __init__(self): + '''tagged/untagged''' + self.gateway = None + self.netmask = None + self.startip = None + self.endip = None + self.vlan = None + '''for account specific ''' + self.account = None + self.domain = None + +class primaryStorage(): + def __init__(self): + self.name = None + self.url = None + +class secondaryStorage(): + def __init__(self): + self.url = None + +'''sample code to generate setup configuration file''' +def describe_setup_in_basic_mode(): + zs = cloudstackConfiguration() + + for l in range(1): + z = zone() + z.dns1 = "8.8.8.8" + z.dns2 = "4.4.4.4" + z.internaldns1 = "192.168.110.254" + z.internaldns2 = "192.168.110.253" + z.name = "test"+str(l) + z.networktype = 'Basic' + + '''create 10 pods''' + for i in range(300): + p = pod() + p.name = "test" +str(l) + str(i) + p.gateway = "192.%d.%d.1"%((i/255)+168,i%255) + p.netmask = "255.255.255.0" + + p.startip = "192.%d.%d.150"%((i/255)+168,i%255) + p.endip = "192.%d.%d.220"%((i/255)+168,i%255) + + '''add two pod guest ip ranges''' + for j in range(1): + ip = iprange() + ip.gateway = p.gateway + ip.netmask = p.netmask + ip.startip = "192.%d.%d.%d"%(((i/255)+168), i%255,j*20) + ip.endip = "192.%d.%d.%d"%((i/255)+168,i%255,j*20+10) + + p.guestIpRanges.append(ip) + + '''add 10 clusters''' + for j in range(10): + c = cluster() + c.clustername = "test"+str(l)+str(i) + str(j) + c.clustertype = "CloudManaged" + c.hypervisor = "Simulator" + + '''add 10 hosts''' + for k in range(1): + h = host() + h.username = "root" + h.password = "password" + memory = 8*1024*1024*1024 + localstorage=1*1024*1024*1024*1024 + #h.url = "http://Sim/%d%d%d%d/cpucore=1&cpuspeed=8000&memory=%d&localstorage=%d"%(l,i,j,k,memory,localstorage) + h.url = "http://Sim/%d%d%d%d"%(l,i,j,k) + c.hosts.append(h) + + '''add 2 primary storages''' + ''' + for m in range(2): + primary = primaryStorage() + size=1*1024*1024*1024*1024 + primary.name = "primary"+str(l) + str(i) + str(j) + str(m) + #primary.url = "nfs://localhost/path%s/size=%d"%(str(l) + str(i) + str(j) + str(m), size) + primary.url = "nfs://localhost/path%s"%(str(l) + str(i) + str(j) + str(m)) + c.primaryStorages.append(primary) + ''' + + p.clusters.append(c) + + z.pods.append(p) + + '''add two secondary''' + for i in range(5): + secondary = secondaryStorage() + secondary.url = "nfs://localhost/path"+str(l) + str(i) + z.secondaryStorages.append(secondary) + + zs.zones.append(z) + + '''Add one mgt server''' + mgt = managementServer() + mgt.mgtSvrIp = "localhost" + zs.mgtSvr.append(mgt) + + '''Add a database''' + db = dbServer() + db.dbSvr = "localhost" + + zs.dbSvr = db + + '''add global configuration''' + global_settings = {'expunge.delay': '60', + 'expunge.interval': '60', + 'expunge.workers': '3', + } + for k,v in global_settings.iteritems(): + cfg = configuration() + cfg.name = k + cfg.value = v + zs.globalConfig.append(cfg) + + ''''add loggers''' + testClientLogger = logger() + testClientLogger.name = "TestClient" + testClientLogger.file = "/tmp/testclient.log" + + testCaseLogger = logger() + testCaseLogger.name = "TestCase" + testCaseLogger.file = "/tmp/testcase.log" + + zs.logger.append(testClientLogger) + zs.logger.append(testCaseLogger) + + return zs + +'''sample code to generate setup configuration file''' +def describe_setup_in_advanced_mode(): + zs = cloudstackConfiguration() + + for l in range(1): + z = zone() + z.dns1 = "8.8.8.8" + z.dns2 = "4.4.4.4" + z.internaldns1 = "192.168.110.254" + z.internaldns2 = "192.168.110.253" + z.name = "test"+str(l) + z.networktype = 'Advanced' + z.guestcidraddress = "10.1.1.0/24" + z.vlan = "100-2000" + + '''create 10 pods''' + for i in range(2): + p = pod() + p.name = "test" +str(l) + str(i) + p.gateway = "192.168.%d.1"%i + p.netmask = "255.255.255.0" + p.startip = "192.168.%d.200"%i + p.endip = "192.168.%d.220"%i + + '''add 10 clusters''' + for j in range(2): + c = cluster() + c.clustername = "test"+str(l)+str(i) + str(j) + c.clustertype = "CloudManaged" + c.hypervisor = "Simulator" + + '''add 10 hosts''' + for k in range(2): + h = host() + h.username = "root" + h.password = "password" + memory = 8*1024*1024*1024 + localstorage=1*1024*1024*1024*1024 + #h.url = "http://Sim/%d%d%d%d/cpucore=1&cpuspeed=8000&memory=%d&localstorage=%d"%(l,i,j,k,memory,localstorage) + h.url = "http://Sim/%d%d%d%d"%(l,i,j,k) + c.hosts.append(h) + + '''add 2 primary storages''' + for m in range(2): + primary = primaryStorage() + size=1*1024*1024*1024*1024 + primary.name = "primary"+str(l) + str(i) + str(j) + str(m) + #primary.url = "nfs://localhost/path%s/size=%d"%(str(l) + str(i) + str(j) + str(m), size) + primary.url = "nfs://localhost/path%s"%(str(l) + str(i) + str(j) + str(m)) + c.primaryStorages.append(primary) + + p.clusters.append(c) + + z.pods.append(p) + + '''add two secondary''' + for i in range(5): + secondary = secondaryStorage() + secondary.url = "nfs://localhost/path"+str(l) + str(i) + z.secondaryStorages.append(secondary) + + '''add default public network''' + ips = iprange() + ips.vlan = "26" + ips.startip = "172.16.26.2" + ips.endip = "172.16.26.100" + ips.gateway = "172.16.26.1" + ips.netmask = "255.255.255.0" + z.ipranges.append(ips) + + + zs.zones.append(z) + + '''Add one mgt server''' + mgt = managementServer() + mgt.mgtSvrIp = "localhost" + zs.mgtSvr.append(mgt) + + '''Add a database''' + db = dbServer() + db.dbSvr = "localhost" + + zs.dbSvr = db + + '''add global configuration''' + global_settings = {'expunge.delay': '60', + 'expunge.interval': '60', + 'expunge.workers': '3', + } + for k,v in global_settings.iteritems(): + cfg = configuration() + cfg.name = k + cfg.value = v + zs.globalConfig.append(cfg) + + ''''add loggers''' + testClientLogger = logger() + testClientLogger.name = "TestClient" + testClientLogger.file = "/tmp/testclient.log" + + testCaseLogger = logger() + testCaseLogger.name = "TestCase" + testCaseLogger.file = "/tmp/testcase.log" + + zs.logger.append(testClientLogger) + zs.logger.append(testCaseLogger) + + return zs + +def generate_setup_config(config, file=None): + describe = config + if file is None: + return json.dumps(jsonHelper.jsonDump.dump(describe)) + else: + fp = open(file, 'w') + json.dump(jsonHelper.jsonDump.dump(describe), fp, indent=4) + fp.close() + + +def get_setup_config(file): + if not os.path.exists(file): + return None + config = cloudstackConfiguration() + fp = open(file, 'r') + config = json.load(fp) + return jsonHelper.jsonLoader(config) + +if __name__ == "__main__": + parser = OptionParser() + + parser.add_option("-o", "--output", action="store", default="./datacenterCfg", dest="output", help="the path where the json config file generated, by default is ./datacenterCfg") + + (options, args) = parser.parse_args() + config = describe_setup_in_basic_mode() + generate_setup_config(config, options.output) diff --git a/tools/marvin/marvin/dbConnection.py b/tools/marvin/marvin/dbConnection.py new file mode 100644 index 00000000000..e6135edf00f --- /dev/null +++ b/tools/marvin/marvin/dbConnection.py @@ -0,0 +1,78 @@ +import pymysql +import cloudstackException +import sys +import os +import traceback +class dbConnection(object): + def __init__(self, host="localhost", port=3306, user='cloud', passwd='cloud', db='cloud'): + self.host = host + self.port = port + self.user = user + self.passwd = passwd + self.database = db + + try: + self.db = pymysql.Connect(host=host, port=port, user=user, passwd=passwd, db=db) + except: + traceback.print_exc() + raise cloudstackException.InvalidParameterException(sys.exc_info()) + + def __copy__(self): + return dbConnection(self.host, self.port, self.user, self.passwd, self.database) + + def close(self): + try: + self.db.close() + except: + pass + + def execute(self, sql=None): + if sql is None: + return None + + resultRow = [] + cursor = None + try: + cursor = self.db.cursor() + cursor.execute(sql) + + result = cursor.fetchall() + if result is not None: + for r in result: + resultRow.append(r) + return resultRow + except pymysql.MySQLError, e: + raise cloudstackException.dbException("db Exception:%s"%e[1]) + except: + raise cloudstackException.internalError(sys.exc_info()) + finally: + if cursor is not None: + cursor.close() + + def executeSqlFromFile(self, fileName=None): + if fileName is None: + raise cloudstackException.InvalidParameterException("file can't not none") + + if not os.path.exists(fileName): + raise cloudstackException.InvalidParameterException("%s not exists"%fileName) + + sqls = open(fileName, "r").read() + return self.execute(sqls) + +if __name__ == "__main__": + db = dbConnection() + ''' + try: + + result = db.executeSqlFromFile("/tmp/server-setup.sql") + if result is not None: + for r in result: + print r[0], r[1] + except cloudstackException.dbException, e: + print e + ''' + print db.execute("update vm_template set name='fjkd' where id=200") + for i in range(10): + result = db.execute("select job_status, created, last_updated from async_job where id=%d"%i) + print result + diff --git a/tools/marvin/marvin/deployAndRun.py b/tools/marvin/marvin/deployAndRun.py new file mode 100644 index 00000000000..084eb995418 --- /dev/null +++ b/tools/marvin/marvin/deployAndRun.py @@ -0,0 +1,32 @@ +import deployDataCenter +import TestCaseExecuteEngine +from optparse import OptionParser +import os +if __name__ == "__main__": + parser = OptionParser() + + parser.add_option("-c", "--config", action="store", default="./datacenterCfg", dest="config", help="the path where the json config file generated, by default is ./datacenterCfg") + parser.add_option("-d", "--directory", dest="testCaseFolder", help="the test case directory") + parser.add_option("-r", "--result", dest="result", help="test result log file") + parser.add_option("-t", dest="testcaselog", help="test case log file") + parser.add_option("-l", "--load", dest="load", action="store_true", help="only load config, do not deploy, it will only run testcase") + (options, args) = parser.parse_args() + if options.testCaseFolder is None: + parser.print_usage() + exit(1) + + testResultLogFile = None + if options.result is not None: + testResultLogFile = options.result + + testCaseLogFile = None + if options.testcaselog is not None: + testCaseLogFile = options.testcaselog + deploy = deployDataCenter.deployDataCenters(options.config) + if options.load: + deploy.loadCfg() + else: + deploy.deploy() + + testcaseEngine = TestCaseExecuteEngine.TestCaseExecuteEngine(deploy.testClient, options.testCaseFolder, testCaseLogFile, testResultLogFile) + testcaseEngine.run() \ No newline at end of file diff --git a/tools/marvin/marvin/deployDataCenter.py b/tools/marvin/marvin/deployDataCenter.py new file mode 100644 index 00000000000..e47dc726676 --- /dev/null +++ b/tools/marvin/marvin/deployDataCenter.py @@ -0,0 +1,274 @@ +'''Deploy datacenters according to a json configuration file''' +import configGenerator +import cloudstackException +import cloudstackTestClient +import sys +import logging +from cloudstackAPI import * +from optparse import OptionParser + +module_logger = "testclient.deploy" + + +class deployDataCenters(): + def __init__(self, cfgFile): + self.configFile = cfgFile + + def addHosts(self, hosts, zoneId, podId, clusterId, hypervisor): + if hosts is None: + return + for host in hosts: + hostcmd = addHost.addHostCmd() + hostcmd.clusterid = clusterId + hostcmd.cpunumber = host.cpunumer + hostcmd.cpuspeed = host.cpuspeed + hostcmd.hostmac = host.hostmac + hostcmd.hosttags = host.hosttags + hostcmd.hypervisor = host.hypervisor + hostcmd.memory = host.memory + hostcmd.password = host.password + hostcmd.podid = podId + hostcmd.url = host.url + hostcmd.username = host.username + hostcmd.zoneid = zoneId + hostcmd.hypervisor = hypervisor + self.apiClient.addHost(hostcmd) + + def createClusters(self, clusters, zoneId, podId): + if clusters is None: + return + + for cluster in clusters: + clustercmd = addCluster.addClusterCmd() + clustercmd.clustername = cluster.clustername + clustercmd.clustertype = cluster.clustertype + clustercmd.hypervisor = cluster.hypervisor + clustercmd.password = cluster.password + clustercmd.podid = podId + clustercmd.url = cluster.url + clustercmd.username = cluster.username + clustercmd.zoneid = zoneId + clusterresponse = self.apiClient.addCluster(clustercmd) + clusterId = clusterresponse[0].id + + self.addHosts(cluster.hosts, zoneId, podId, clusterId, cluster.hypervisor) + self.createPrimaryStorages(cluster.primaryStorages, zoneId, podId, clusterId) + + def createPrimaryStorages(self, primaryStorages, zoneId, podId, clusterId): + if primaryStorages is None: + return + for primary in primaryStorages: + primarycmd = createStoragePool.createStoragePoolCmd() + primarycmd.details = primary.details + primarycmd.name = primary.name + primarycmd.podid = podId + primarycmd.tags = primary.tags + primarycmd.url = primary.url + primarycmd.zoneid = zoneId + primarycmd.clusterid = clusterId + self.apiClient.createStoragePool(primarycmd) + + def createpods(self, pods, zone, zoneId): + if pods is None: + return + for pod in pods: + createpod = createPod.createPodCmd() + createpod.name = pod.name + createpod.gateway = pod.gateway + createpod.netmask = pod.netmask + createpod.startip = pod.startip + createpod.endip = pod.endip + createpod.zoneid = zoneId + createpodResponse = self.apiClient.createPod(createpod) + podId = createpodResponse.id + + if pod.guestIpRanges is not None: + self.createVlanIpRanges("Basic", pod.guestIpRanges, zoneId, podId) + + self.createClusters(pod.clusters, zoneId, podId) + + + def createVlanIpRanges(self, mode, ipranges, zoneId, podId=None, networkId=None): + if ipranges is None: + return + for iprange in ipranges: + vlanipcmd = createVlanIpRange.createVlanIpRangeCmd() + vlanipcmd.account = iprange.account + vlanipcmd.domainid = iprange.domainid + vlanipcmd.endip = iprange.endip + vlanipcmd.gateway = iprange.gateway + vlanipcmd.netmask = iprange.netmask + vlanipcmd.networkid = networkId + vlanipcmd.podid = podId + vlanipcmd.startip = iprange.startip + vlanipcmd.zoneid = zoneId + vlanipcmd.vlan = iprange.vlan + if mode == "Basic": + vlanipcmd.forvirtualnetwork = "false" + else: + vlanipcmd.forvirtualnetwork = "true" + + self.apiClient.createVlanIpRange(vlanipcmd) + + def createSecondaryStorages(self, secondaryStorages, zoneId): + if secondaryStorages is None: + return + for secondary in secondaryStorages: + secondarycmd = addSecondaryStorage.addSecondaryStorageCmd() + secondarycmd.url = secondary.url + secondarycmd.zoneid = zoneId + self.apiClient.addSecondaryStorage(secondarycmd) + + def createnetworks(self, networks, zoneId): + if networks is None: + return + for network in networks: + ipranges = network.ipranges + if ipranges is None: + continue + iprange = ipranges.pop() + networkcmd = createNetwork.createNetworkCmd() + networkcmd.account = network.account + networkcmd.displaytext = network.displaytext + networkcmd.domainid = network.domainid + networkcmd.endip = iprange.endip + networkcmd.gateway = iprange.gateway + networkcmd.isdefault = network.isdefault + networkcmd.isshared = network.isshared + networkcmd.name = network.name + networkcmd.netmask = iprange.netmask + networkcmd.networkdomain = network.networkdomain + networkcmd.networkofferingid = network.networkofferingid + networkcmdresponse = self.apiClient.createNetwork(networkcmd) + networkId = networkcmdresponse.id + self.createVlanIpRanges("Advanced", ipranges, zoneId, networkId=networkId) + + def createZones(self, zones): + for zone in zones: + '''create a zone''' + createzone = createZone.createZoneCmd() + createzone.guestcidraddress = zone.guestcidraddress + createzone.dns1 = zone.dns1 + createzone.dns2 = zone.dns2 + createzone.internaldns1 = zone.internaldns1 + createzone.internaldns2 = zone.internaldns2 + createzone.name = zone.name + createzone.securitygroupenabled = zone.securitygroupenabled + createzone.networktype = zone.networktype + createzone.vlan = zone.vlan + + zoneresponse = self.apiClient.createZone(createzone) + zoneId = zoneresponse.id + + '''create pods''' + self.createpods(zone.pods, zone, zoneId) + + if zone.networktype == "Advanced": + '''create pubic network''' + self.createVlanIpRanges(zone.networktype, zone.ipranges, zoneId) + + self.createnetworks(zone.networks, zoneId) + '''create secondary storage''' + self.createSecondaryStorages(zone.secondaryStorages, zoneId) + return + + def registerApiKey(self): + listuser = listUsers.listUsersCmd() + listuser.account = "admin" + listuserRes = self.testClient.getApiClient().listUsers(listuser) + userId = listuserRes[0].id + apiKey = listuserRes[0].apikey + securityKey = listuserRes[0].secretkey + if apiKey is None: + registerUser = registerUserKeys.registerUserKeysCmd() + registerUser.id = userId + registerUserRes = self.testClient.getApiClient().registerUserKeys(registerUser) + apiKey = registerUserRes.apikey + securityKey = registerUserRes.secretkey + + self.config.mgtSvr[0].port = 8080 + self.config.mgtSvr[0].apiKey = apiKey + self.config.mgtSvr[0].securityKey = securityKey + return apiKey, securityKey + + def loadCfg(self): + try: + self.config = configGenerator.get_setup_config(self.configFile) + except: + raise cloudstackException.InvalidParameterException( \ + "Failed to load config" + sys.exc_info()) + + mgt = self.config.mgtSvr[0] + + loggers = self.config.logger + testClientLogFile = None + self.testCaseLogFile = None + self.testResultLogFile = None + if loggers is not None and len(loggers) > 0: + for log in loggers: + if log.name == "TestClient": + testClientLogFile = log.file + elif log.name == "TestCase": + self.testCaseLogFile = log.file + elif log.name == "TestResult": + self.testResultLogFile = log.file + + testClientLogger = None + if testClientLogFile is not None: + testClientLogger = logging.getLogger("testclient.deploy.deployDataCenters") + fh = logging.FileHandler(testClientLogFile) + fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s")) + testClientLogger.addHandler(fh) + testClientLogger.setLevel(logging.DEBUG) + self.testClientLogger = testClientLogger + + self.testClient = cloudstackTestClient.cloudstackTestClient(mgt.mgtSvrIp, mgt.port, mgt.apiKey, mgt.securityKey, logging=self.testClientLogger) + if mgt.apiKey is None: + apiKey, securityKey = self.registerApiKey() + self.testClient.close() + self.testClient = cloudstackTestClient.cloudstackTestClient(mgt.mgtSvrIp, 8080, apiKey, securityKey, logging=self.testClientLogger) + + '''config database''' + dbSvr = self.config.dbSvr + self.testClient.dbConfigure(dbSvr.dbSvr, dbSvr.port, dbSvr.user, dbSvr.passwd, dbSvr.db) + self.apiClient = self.testClient.getApiClient() + + def updateConfiguration(self, globalCfg): + if globalCfg is None: + return None + + for config in globalCfg: + updateCfg = updateConfiguration.updateConfigurationCmd() + updateCfg.name = config.name + updateCfg.value = config.value + self.apiClient.updateConfiguration(updateCfg) + + def deploy(self): + self.loadCfg() + self.createZones(self.config.zones) + self.updateConfiguration(self.config.globalConfig) + + +if __name__ == "__main__": + + parser = OptionParser() + + parser.add_option("-i", "--intput", action="store", default="./datacenterCfg", dest="input", help="the path where the json config file generated, by default is ./datacenterCfg") + + (options, args) = parser.parse_args() + + deploy = deployDataCenters(options.input) + deploy.deploy() + + ''' + create = createStoragePool.createStoragePoolCmd() + create.clusterid = 1 + create.podid = 2 + create.name = "fdffdf" + create.url = "nfs://jfkdjf/fdkjfkd" + create.zoneid = 2 + + deploy = deployDataCenters("./datacenterCfg") + deploy.loadCfg() + deploy.apiClient.createStoragePool(create) + ''' diff --git a/tools/marvin/marvin/jsonHelper.py b/tools/marvin/marvin/jsonHelper.py new file mode 100644 index 00000000000..6bf6d056f1f --- /dev/null +++ b/tools/marvin/marvin/jsonHelper.py @@ -0,0 +1,174 @@ +import cloudstackException +import json +import inspect +from cloudstackAPI import * +import pdb + +class jsonLoader: + '''The recursive class for building and representing objects with.''' + def __init__(self, obj): + for k in obj: + v = obj[k] + if isinstance(v, dict): + setattr(self, k, jsonLoader(v)) + elif isinstance(v, (list, tuple)): + setattr(self, k, [jsonLoader(elem) for elem in v]) + else: + setattr(self,k,v) + def __getattr__(self, val): + if val in self.__dict__: + return self.__dict__[val] + else: + return None + def __repr__(self): + return '{%s}' % str(', '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.iteritems())) + def __str__(self): + return '{%s}' % str(', '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.iteritems())) + + +class jsonDump: + @staticmethod + def __serialize(obj): + """Recursively walk object's hierarchy.""" + if isinstance(obj, (bool, int, long, float, basestring)): + return obj + elif isinstance(obj, dict): + obj = obj.copy() + newobj = {} + for key in obj: + if obj[key] is not None: + if (isinstance(obj[key], list) and len(obj[key]) == 0): + continue + newobj[key] = jsonDump.__serialize(obj[key]) + + return newobj + elif isinstance(obj, list): + return [jsonDump.__serialize(item) for item in obj] + elif isinstance(obj, tuple): + return tuple(jsonDump.__serialize([item for item in obj])) + elif hasattr(obj, '__dict__'): + return jsonDump.__serialize(obj.__dict__) + else: + return repr(obj) # Don't know how to handle, convert to string + + @staticmethod + def dump(obj): + return jsonDump.__serialize(obj) + +def getclassFromName(cmd, name): + module = inspect.getmodule(cmd) + return getattr(module, name)() + +def finalizeResultObj(result, responseName, responsecls): + if responsecls is None and responseName.endswith("response") and responseName != "queryasyncjobresultresponse": + '''infer the response class from the name''' + moduleName = responseName.replace("response", "") + try: + responsecls = getclassFromName(moduleName, responseName) + except: + pass + + if responseName is not None and responseName == "queryasyncjobresultresponse" and responsecls is not None and result.jobresult is not None: + result.jobresult = finalizeResultObj(result.jobresult, None, responsecls) + return result + elif responsecls is not None: + for k,v in result.__dict__.iteritems(): + if k in responsecls.__dict__: + return result + + attr = result.__dict__.keys()[0] + + value = getattr(result, attr) + if not isinstance(value, jsonLoader): + return result + + findObj = False + for k,v in value.__dict__.iteritems(): + if k in responsecls.__dict__: + findObj = True + break + if findObj: + return value + else: + return result + else: + return result + + + +def getResultObj(returnObj, responsecls=None): + returnObj = json.loads(returnObj) + + if len(returnObj) == 0: + return None + responseName = returnObj.keys()[0] + + response = returnObj[responseName] + if len(response) == 0: + return None + + result = jsonLoader(response) + if result.errorcode is not None: + errMsg = "errorCode: %s, errorText:%s"%(result.errorcode, result.errortext) + raise cloudstackException.cloudstackAPIException(responseName.replace("response", ""), errMsg) + + if result.count is not None: + for key in result.__dict__.iterkeys(): + if key == "count": + continue + else: + return getattr(result, key) + else: + return finalizeResultObj(result, responseName, responsecls) + +if __name__ == "__main__": + + result = '{ "listnetworkserviceprovidersresponse" : { "count":1 ,"networkserviceprovider" : [ {"name":"VirtualRouter","physicalnetworkid":"ad2948fc-1054-46c7-b1c7-61d990b86710","destinationphysicalnetworkid":"0","state":"Disabled","id":"d827cae4-4998-4037-95a2-55b92b6318b1","servicelist":["Vpn","Dhcp","Dns","Gateway","Firewall","Lb","SourceNat","StaticNat","PortForwarding","UserData"]} ] } }' + nsp = getResultObj(result) + + result = '{ "listzonesresponse" : { "count":1 ,"zone" : [ {"id":1,"name":"test0","dns1":"8.8.8.8","dns2":"4.4.4.4","internaldns1":"192.168.110.254","internaldns2":"192.168.110.253","networktype":"Basic","securitygroupsenabled":true,"allocationstate":"Enabled","zonetoken":"5e818a11-6b00-3429-9a07-e27511d3169a","dhcpprovider":"DhcpServer"} ] } }' + zones = getResultObj(result) + print zones[0].id + res = authorizeSecurityGroupIngress.authorizeSecurityGroupIngressResponse() + result = '{ "queryasyncjobresultresponse" : {"jobid":10,"jobstatus":1,"jobprocstatus":0,"jobresultcode":0,"jobresulttype":"object","jobresult":{"securitygroup":{"id":1,"name":"default","description":"Default Security Group","account":"admin","domainid":1,"domain":"ROOT","ingressrule":[{"ruleid":1,"protocol":"tcp","startport":22,"endport":22,"securitygroupname":"default","account":"a"},{"ruleid":2,"protocol":"tcp","startport":22,"endport":22,"securitygroupname":"default","account":"b"}]}}} }' + asynJob = getResultObj(result, res) + print asynJob.jobid, repr(asynJob.jobresult) + print asynJob.jobresult.ingressrule[0].account + + result = '{ "queryasyncjobresultresponse" : {"errorcode" : 431, "errortext" : "Unable to execute API command queryasyncjobresultresponse due to missing parameter jobid"} }' + try: + asynJob = getResultObj(result) + except cloudstackException.cloudstackAPIException, e: + print e + + result = '{ "queryasyncjobresultresponse" : {} }' + asynJob = getResultObj(result) + print asynJob + + result = '{}' + asynJob = getResultObj(result) + print asynJob + + result = '{ "createzoneresponse" : { "zone" : {"id":1,"name":"test0","dns1":"8.8.8.8","dns2":"4.4.4.4","internaldns1":"192.168.110.254","internaldns2":"192.168.110.253","networktype":"Basic","securitygroupsenabled":true,"allocationstate":"Enabled","zonetoken":"3442f287-e932-3111-960b-514d1f9c4610","dhcpprovider":"DhcpServer"} } }' + res = createZone.createZoneResponse() + zone = getResultObj(result, res) + print zone.id + + result = '{ "attachvolumeresponse" : {"jobid":24} }' + res = attachVolume.attachVolumeResponse() + res = getResultObj(result, res) + print res + + result = '{ "listtemplatesresponse" : { } }' + print getResultObj(result, listTemplates.listTemplatesResponse()) + + result = '{ "queryasyncjobresultresponse" : {"jobid":34,"jobstatus":2,"jobprocstatus":0,"jobresultcode":530,"jobresulttype":"object","jobresult":{"errorcode":431,"errortext":"Please provide either a volume id, or a tuple(device id, instance id)"}} }' + print getResultObj(result, listTemplates.listTemplatesResponse()) + result = '{ "queryasyncjobresultresponse" : {"jobid":41,"jobstatus":1,"jobprocstatus":0,"jobresultcode":0,"jobresulttype":"object","jobresult":{"virtualmachine":{"id":37,"name":"i-2-37-TEST","displayname":"i-2-37-TEST","account":"admin","domainid":1,"domain":"ROOT","created":"2011-08-25T11:13:42-0700","state":"Running","haenable":false,"zoneid":1,"zonename":"test0","hostid":5,"hostname":"SimulatedAgent.1e629060-f547-40dd-b792-57cdc4b7d611","templateid":10,"templatename":"CentOS 5.3(64-bit) no GUI (Simulator)","templatedisplaytext":"CentOS 5.3(64-bit) no GUI (Simulator)","passwordenabled":false,"serviceofferingid":7,"serviceofferingname":"Small Instance","cpunumber":1,"cpuspeed":500,"memory":512,"guestosid":11,"rootdeviceid":0,"rootdevicetype":"NetworkFilesystem","securitygroup":[{"id":1,"name":"default","description":"Default Security Group"}],"nic":[{"id":43,"networkid":204,"netmask":"255.255.255.0","gateway":"192.168.1.1","ipaddress":"192.168.1.27","isolationuri":"ec2://untagged","broadcasturi":"vlan://untagged","traffictype":"Guest","type":"Direct","isdefault":true,"macaddress":"06:56:b8:00:00:53"}],"hypervisor":"Simulator"}}} }' + vm = getResultObj(result, deployVirtualMachine.deployVirtualMachineResponse()) + print vm.jobresult.id + + cmd = deployVirtualMachine.deployVirtualMachineCmd() + responsename = cmd.__class__.__name__.replace("Cmd", "Response") + response = getclassFromName(cmd, responsename) + print response.id diff --git a/tools/marvin/marvin/pymysql/__init__.py b/tools/marvin/marvin/pymysql/__init__.py new file mode 100644 index 00000000000..903107e539a --- /dev/null +++ b/tools/marvin/marvin/pymysql/__init__.py @@ -0,0 +1,131 @@ +''' +PyMySQL: A pure-Python drop-in replacement for MySQLdb. + +Copyright (c) 2010 PyMySQL contributors + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. + +''' + +VERSION = (0, 4, None) + +from constants import FIELD_TYPE +from converters import escape_dict, escape_sequence, escape_string +from err import Warning, Error, InterfaceError, DataError, \ + DatabaseError, OperationalError, IntegrityError, InternalError, \ + NotSupportedError, ProgrammingError, MySQLError +from times import Date, Time, Timestamp, \ + DateFromTicks, TimeFromTicks, TimestampFromTicks + +import sys + +try: + frozenset +except NameError: + from sets import ImmutableSet as frozenset + try: + from sets import BaseSet as set + except ImportError: + from sets import Set as set + +threadsafety = 1 +apilevel = "2.0" +paramstyle = "format" + +class DBAPISet(frozenset): + + + def __ne__(self, other): + if isinstance(other, set): + return super(DBAPISet, self).__ne__(self, other) + else: + return other not in self + + def __eq__(self, other): + if isinstance(other, frozenset): + return frozenset.__eq__(self, other) + else: + return other in self + + def __hash__(self): + return frozenset.__hash__(self) + + +STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING, + FIELD_TYPE.VAR_STRING]) +BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB, + FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB]) +NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT, + FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG, + FIELD_TYPE.TINY, FIELD_TYPE.YEAR]) +DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE]) +TIME = DBAPISet([FIELD_TYPE.TIME]) +TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME]) +DATETIME = TIMESTAMP +ROWID = DBAPISet() + +def Binary(x): + """Return x as a binary type.""" + return str(x) + +def Connect(*args, **kwargs): + """ + Connect to the database; see connections.Connection.__init__() for + more information. + """ + from connections import Connection + return Connection(*args, **kwargs) + +def get_client_info(): # for MySQLdb compatibility + return '%s.%s.%s' % VERSION + +connect = Connection = Connect + +# we include a doctored version_info here for MySQLdb compatibility +version_info = (1,2,2,"final",0) + +NULL = "NULL" + +__version__ = get_client_info() + +def thread_safe(): + return True # match MySQLdb.thread_safe() + +def install_as_MySQLdb(): + """ + After this function is called, any application that imports MySQLdb or + _mysql will unwittingly actually use + """ + sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"] + +__all__ = [ + 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date', + 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks', + 'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError', + 'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER', + 'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError', + 'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect', + 'connections', 'constants', 'converters', 'cursors', + 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info', + 'paramstyle', 'threadsafety', 'version_info', + + "install_as_MySQLdb", + + "NULL","__version__", + ] diff --git a/tools/marvin/marvin/pymysql/charset.py b/tools/marvin/marvin/pymysql/charset.py new file mode 100644 index 00000000000..10a91bd19f2 --- /dev/null +++ b/tools/marvin/marvin/pymysql/charset.py @@ -0,0 +1,174 @@ +MBLENGTH = { + 8:1, + 33:3, + 88:2, + 91:2 + } + +class Charset: + def __init__(self, id, name, collation, is_default): + self.id, self.name, self.collation = id, name, collation + self.is_default = is_default == 'Yes' + +class Charsets: + def __init__(self): + self._by_id = {} + + def add(self, c): + self._by_id[c.id] = c + + def by_id(self, id): + return self._by_id[id] + + def by_name(self, name): + for c in self._by_id.values(): + if c.name == name and c.is_default: + return c + +_charsets = Charsets() +""" +Generated with: + +mysql -N -s -e "select id, character_set_name, collation_name, is_default +from information_schema.collations order by id;" | python -c "import sys +for l in sys.stdin.readlines(): + id, name, collation, is_default = l.split(chr(9)) + print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \ + % (id, name, collation, is_default.strip()) +" + +""" +_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes')) +_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', '')) +_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes')) +_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes')) +_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', '')) +_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes')) +_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes')) +_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes')) +_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes')) +_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes')) +_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes')) +_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes')) +_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes')) +_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', '')) +_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', '')) +_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes')) +_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes')) +_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes')) +_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', '')) +_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', '')) +_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes')) +_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', '')) +_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes')) +_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes')) +_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes')) +_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', '')) +_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes')) +_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', '')) +_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes')) +_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', '')) +_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes')) +_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes')) +_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', '')) +_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes')) +_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes')) +_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes')) +_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes')) +_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes')) +_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes')) +_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes')) +_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', '')) +_charsets.add(Charset(43, 'macce', 'macce_bin', '')) +_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', '')) +_charsets.add(Charset(47, 'latin1', 'latin1_bin', '')) +_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', '')) +_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', '')) +_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', '')) +_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes')) +_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', '')) +_charsets.add(Charset(53, 'macroman', 'macroman_bin', '')) +_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes')) +_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', '')) +_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes')) +_charsets.add(Charset(63, 'binary', 'binary', 'Yes')) +_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', '')) +_charsets.add(Charset(65, 'ascii', 'ascii_bin', '')) +_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', '')) +_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', '')) +_charsets.add(Charset(68, 'cp866', 'cp866_bin', '')) +_charsets.add(Charset(69, 'dec8', 'dec8_bin', '')) +_charsets.add(Charset(70, 'greek', 'greek_bin', '')) +_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', '')) +_charsets.add(Charset(72, 'hp8', 'hp8_bin', '')) +_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', '')) +_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', '')) +_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', '')) +_charsets.add(Charset(77, 'latin2', 'latin2_bin', '')) +_charsets.add(Charset(78, 'latin5', 'latin5_bin', '')) +_charsets.add(Charset(79, 'latin7', 'latin7_bin', '')) +_charsets.add(Charset(80, 'cp850', 'cp850_bin', '')) +_charsets.add(Charset(81, 'cp852', 'cp852_bin', '')) +_charsets.add(Charset(82, 'swe7', 'swe7_bin', '')) +_charsets.add(Charset(83, 'utf8', 'utf8_bin', '')) +_charsets.add(Charset(84, 'big5', 'big5_bin', '')) +_charsets.add(Charset(85, 'euckr', 'euckr_bin', '')) +_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', '')) +_charsets.add(Charset(87, 'gbk', 'gbk_bin', '')) +_charsets.add(Charset(88, 'sjis', 'sjis_bin', '')) +_charsets.add(Charset(89, 'tis620', 'tis620_bin', '')) +_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', '')) +_charsets.add(Charset(91, 'ujis', 'ujis_bin', '')) +_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes')) +_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', '')) +_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', '')) +_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes')) +_charsets.add(Charset(96, 'cp932', 'cp932_bin', '')) +_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes')) +_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', '')) +_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', '')) +_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', '')) +_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', '')) +_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', '')) +_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', '')) +_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', '')) +_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', '')) +_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', '')) +_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', '')) +_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', '')) +_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', '')) +_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', '')) +_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', '')) +_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', '')) +_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', '')) +_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', '')) +_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', '')) +_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', '')) +_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', '')) +_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', '')) +_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', '')) +_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', '')) +_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', '')) +_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', '')) +_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', '')) +_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', '')) +_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', '')) +_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', '')) +_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', '')) +_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', '')) +_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', '')) +_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', '')) +_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', '')) +_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', '')) +_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', '')) +_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', '')) +_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', '')) +_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', '')) +_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', '')) + +def charset_by_name(name): + return _charsets.by_name(name) + +def charset_by_id(id): + return _charsets.by_id(id) + diff --git a/tools/marvin/marvin/pymysql/connections.py b/tools/marvin/marvin/pymysql/connections.py new file mode 100644 index 00000000000..8897644ab09 --- /dev/null +++ b/tools/marvin/marvin/pymysql/connections.py @@ -0,0 +1,928 @@ +# Python implementation of the MySQL client-server protocol +# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol + +try: + import hashlib + sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs) +except ImportError: + import sha + sha_new = sha.new + +import socket +try: + import ssl + SSL_ENABLED = True +except ImportError: + SSL_ENABLED = False + +import struct +import sys +import os +import ConfigParser + +try: + import cStringIO as StringIO +except ImportError: + import StringIO + +from charset import MBLENGTH, charset_by_name, charset_by_id +from cursors import Cursor +from constants import FIELD_TYPE, FLAG +from constants import SERVER_STATUS +from constants.CLIENT import * +from constants.COMMAND import * +from util import join_bytes, byte2int, int2byte +from converters import escape_item, encoders, decoders +from err import raise_mysql_exception, Warning, Error, \ + InterfaceError, DataError, DatabaseError, OperationalError, \ + IntegrityError, InternalError, NotSupportedError, ProgrammingError + +DEBUG = False + +NULL_COLUMN = 251 +UNSIGNED_CHAR_COLUMN = 251 +UNSIGNED_SHORT_COLUMN = 252 +UNSIGNED_INT24_COLUMN = 253 +UNSIGNED_INT64_COLUMN = 254 +UNSIGNED_CHAR_LENGTH = 1 +UNSIGNED_SHORT_LENGTH = 2 +UNSIGNED_INT24_LENGTH = 3 +UNSIGNED_INT64_LENGTH = 8 + +DEFAULT_CHARSET = 'latin1' + + +def dump_packet(data): + + def is_ascii(data): + if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum(): + return data + return '.' + print "packet length %d" % len(data) + print "method call[1]: %s" % sys._getframe(1).f_code.co_name + print "method call[2]: %s" % sys._getframe(2).f_code.co_name + print "method call[3]: %s" % sys._getframe(3).f_code.co_name + print "method call[4]: %s" % sys._getframe(4).f_code.co_name + print "method call[5]: %s" % sys._getframe(5).f_code.co_name + print "-" * 88 + dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0] + for d in dump_data: + print ' '.join(map(lambda x:"%02X" % byte2int(x), d)) + \ + ' ' * (16 - len(d)) + ' ' * 2 + \ + ' '.join(map(lambda x:"%s" % is_ascii(x), d)) + print "-" * 88 + print "" + +def _scramble(password, message): + if password == None or len(password) == 0: + return int2byte(0) + if DEBUG: print 'password=' + password + stage1 = sha_new(password).digest() + stage2 = sha_new(stage1).digest() + s = sha_new() + s.update(message) + s.update(stage2) + result = s.digest() + return _my_crypt(result, stage1) + +def _my_crypt(message1, message2): + length = len(message1) + result = struct.pack('B', length) + for i in xrange(length): + x = (struct.unpack('B', message1[i:i+1])[0] ^ \ + struct.unpack('B', message2[i:i+1])[0]) + result += struct.pack('B', x) + return result + +# old_passwords support ported from libmysql/password.c +SCRAMBLE_LENGTH_323 = 8 + +class RandStruct_323(object): + def __init__(self, seed1, seed2): + self.max_value = 0x3FFFFFFFL + self.seed1 = seed1 % self.max_value + self.seed2 = seed2 % self.max_value + + def my_rnd(self): + self.seed1 = (self.seed1 * 3L + self.seed2) % self.max_value + self.seed2 = (self.seed1 + self.seed2 + 33L) % self.max_value + return float(self.seed1) / float(self.max_value) + +def _scramble_323(password, message): + hash_pass = _hash_password_323(password) + hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323]) + hash_pass_n = struct.unpack(">LL", hash_pass) + hash_message_n = struct.unpack(">LL", hash_message) + + rand_st = RandStruct_323(hash_pass_n[0] ^ hash_message_n[0], + hash_pass_n[1] ^ hash_message_n[1]) + outbuf = StringIO.StringIO() + for _ in xrange(min(SCRAMBLE_LENGTH_323, len(message))): + outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64)) + extra = int2byte(int(rand_st.my_rnd() * 31)) + out = outbuf.getvalue() + outbuf = StringIO.StringIO() + for c in out: + outbuf.write(int2byte(byte2int(c) ^ byte2int(extra))) + return outbuf.getvalue() + +def _hash_password_323(password): + nr = 1345345333L + add = 7L + nr2 = 0x12345671L + + for c in [byte2int(x) for x in password if x not in (' ', '\t')]: + nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF + nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF + add= (add + c) & 0xFFFFFFFF + + r1 = nr & ((1L << 31) - 1L) # kill sign bits + r2 = nr2 & ((1L << 31) - 1L) + + # pack + return struct.pack(">LL", r1, r2) + +def pack_int24(n): + return struct.pack('BBB', n&0xFF, (n>>8)&0xFF, (n>>16)&0xFF) + +def unpack_uint16(n): + return struct.unpack(' len(self.__data): + raise Exception('Invalid advance amount (%s) for cursor. ' + 'Position=%s' % (length, new_position)) + self.__position = new_position + + def rewind(self, position=0): + """Set the position of the data buffer cursor to 'position'.""" + if position < 0 or position > len(self.__data): + raise Exception("Invalid position to rewind cursor to: %s." % position) + self.__position = position + + def peek(self, size): + """Look at the first 'size' bytes in packet without moving cursor.""" + result = self.__data[self.__position:(self.__position+size)] + if len(result) != size: + error = ('Result length not requested length:\n' + 'Expected=%s. Actual=%s. Position: %s. Data Length: %s' + % (size, len(result), self.__position, len(self.__data))) + if DEBUG: + print error + self.dump() + raise AssertionError(error) + return result + + def get_bytes(self, position, length=1): + """Get 'length' bytes starting at 'position'. + + Position is start of payload (first four packet header bytes are not + included) starting at index '0'. + + No error checking is done. If requesting outside end of buffer + an empty string (or string shorter than 'length') may be returned! + """ + return self.__data[position:(position+length)] + + def read_length_coded_binary(self): + """Read a 'Length Coded Binary' number from the data buffer. + + Length coded numbers can be anywhere from 1 to 9 bytes depending + on the value of the first byte. + """ + c = byte2int(self.read(1)) + if c == NULL_COLUMN: + return None + if c < UNSIGNED_CHAR_COLUMN: + return c + elif c == UNSIGNED_SHORT_COLUMN: + return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH)) + elif c == UNSIGNED_INT24_COLUMN: + return unpack_int24(self.read(UNSIGNED_INT24_LENGTH)) + elif c == UNSIGNED_INT64_COLUMN: + # TODO: what was 'longlong'? confirm it wasn't used? + return unpack_int64(self.read(UNSIGNED_INT64_LENGTH)) + + def read_length_coded_string(self): + """Read a 'Length Coded String' from the data buffer. + + A 'Length Coded String' consists first of a length coded + (unsigned, positive) integer represented in 1-9 bytes followed by + that many bytes of binary data. (For example "cat" would be "3cat".) + """ + length = self.read_length_coded_binary() + if length is None: + return None + return self.read(length) + + def is_ok_packet(self): + return byte2int(self.get_bytes(0)) == 0 + + def is_eof_packet(self): + return byte2int(self.get_bytes(0)) == 254 # 'fe' + + def is_resultset_packet(self): + field_count = byte2int(self.get_bytes(0)) + return field_count >= 1 and field_count <= 250 + + def is_error_packet(self): + return byte2int(self.get_bytes(0)) == 255 + + def check_error(self): + if self.is_error_packet(): + self.rewind() + self.advance(1) # field_count == error (we already know that) + errno = unpack_uint16(self.read(2)) + if DEBUG: print "errno = %d" % errno + raise_mysql_exception(self.__data) + + def dump(self): + dump_packet(self.__data) + + +class FieldDescriptorPacket(MysqlPacket): + """A MysqlPacket that represents a specific column's metadata in the result. + + Parsing is automatically done and the results are exported via public + attributes on the class such as: db, table_name, name, length, type_code. + """ + + def __init__(self, *args): + MysqlPacket.__init__(self, *args) + self.__parse_field_descriptor() + + def __parse_field_descriptor(self): + """Parse the 'Field Descriptor' (Metadata) packet. + + This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0). + """ + self.catalog = self.read_length_coded_string() + self.db = self.read_length_coded_string() + self.table_name = self.read_length_coded_string() + self.org_table = self.read_length_coded_string() + self.name = self.read_length_coded_string().decode(self.connection.charset) + self.org_name = self.read_length_coded_string() + self.advance(1) # non-null filler + self.charsetnr = struct.unpack(' 2: + use_unicode = True + + if compress or named_pipe: + raise NotImplementedError, "compress and named_pipe arguments are not supported" + + if ssl and (ssl.has_key('capath') or ssl.has_key('cipher')): + raise NotImplementedError, 'ssl options capath and cipher are not supported' + + self.ssl = False + if ssl: + if not SSL_ENABLED: + raise NotImplementedError, "ssl module not found" + self.ssl = True + client_flag |= SSL + for k in ('key', 'cert', 'ca'): + v = None + if ssl.has_key(k): + v = ssl[k] + setattr(self, k, v) + + if read_default_group and not read_default_file: + if sys.platform.startswith("win"): + read_default_file = "c:\\my.ini" + else: + read_default_file = "/etc/my.cnf" + + if read_default_file: + if not read_default_group: + read_default_group = "client" + + cfg = ConfigParser.RawConfigParser() + cfg.read(os.path.expanduser(read_default_file)) + + def _config(key, default): + try: + return cfg.get(read_default_group,key) + except: + return default + + user = _config("user",user) + passwd = _config("password",passwd) + host = _config("host", host) + db = _config("db",db) + unix_socket = _config("socket",unix_socket) + port = _config("port", port) + charset = _config("default-character-set", charset) + + self.host = host + self.port = port + self.user = user + self.password = passwd + self.db = db + self.unix_socket = unix_socket + if charset: + self.charset = charset + self.use_unicode = True + else: + self.charset = DEFAULT_CHARSET + self.use_unicode = False + + if use_unicode is not None: + self.use_unicode = use_unicode + + client_flag |= CAPABILITIES + client_flag |= MULTI_STATEMENTS + if self.db: + client_flag |= CONNECT_WITH_DB + self.client_flag = client_flag + + self.cursorclass = cursorclass + self.connect_timeout = connect_timeout + + self._connect() + + self.messages = [] + self.set_charset(charset) + self.encoders = encoders + self.decoders = conv + + self._result = None + self._affected_rows = 0 + self.host_info = "Not connected" + + self.autocommit(False) + + if sql_mode is not None: + c = self.cursor() + c.execute("SET sql_mode=%s", (sql_mode,)) + + self.commit() + + if init_command is not None: + c = self.cursor() + c.execute(init_command) + + self.commit() + + + def close(self): + ''' Send the quit message and close the socket ''' + if self.socket is None: + raise Error("Already closed") + send_data = struct.pack('= i + 1: + i += 1 + + self.server_capabilities = struct.unpack('= i+12-1: + rest_salt = data[i:i+12] + self.salt += rest_salt + + def get_server_info(self): + return self.server_version + + Warning = Warning + Error = Error + InterfaceError = InterfaceError + DatabaseError = DatabaseError + DataError = DataError + OperationalError = OperationalError + IntegrityError = IntegrityError + InternalError = InternalError + ProgrammingError = ProgrammingError + NotSupportedError = NotSupportedError + +# TODO: move OK and EOF packet parsing/logic into a proper subclass +# of MysqlPacket like has been done with FieldDescriptorPacket. +class MySQLResult(object): + + def __init__(self, connection): + from weakref import proxy + self.connection = proxy(connection) + self.affected_rows = None + self.insert_id = None + self.server_status = 0 + self.warning_count = 0 + self.message = None + self.field_count = 0 + self.description = None + self.rows = None + self.has_next = None + + def read(self): + self.first_packet = self.connection.read_packet() + + # TODO: use classes for different packet types? + if self.first_packet.is_ok_packet(): + self._read_ok_packet() + else: + self._read_result_packet() + + def _read_ok_packet(self): + self.first_packet.advance(1) # field_count (always '0') + self.affected_rows = self.first_packet.read_length_coded_binary() + self.insert_id = self.first_packet.read_length_coded_binary() + self.server_status = struct.unpack(' 2 + +try: + set +except NameError: + try: + from sets import BaseSet as set + except ImportError: + from sets import Set as set + +ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]") +ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z', + '\'': '\\\'', '"': '\\"', '\\': '\\\\'} + +def escape_item(val, charset): + if type(val) in [tuple, list, set]: + return escape_sequence(val, charset) + if type(val) is dict: + return escape_dict(val, charset) + if PYTHON3 and hasattr(val, "decode") and not isinstance(val, unicode): + # deal with py3k bytes + val = val.decode(charset) + encoder = encoders[type(val)] + val = encoder(val) + if type(val) is str: + return val + val = val.encode(charset) + return val + +def escape_dict(val, charset): + n = {} + for k, v in val.items(): + quoted = escape_item(v, charset) + n[k] = quoted + return n + +def escape_sequence(val, charset): + n = [] + for item in val: + quoted = escape_item(item, charset) + n.append(quoted) + return "(" + ",".join(n) + ")" + +def escape_set(val, charset): + val = map(lambda x: escape_item(x, charset), val) + return ','.join(val) + +def escape_bool(value): + return str(int(value)) + +def escape_object(value): + return str(value) + +escape_int = escape_long = escape_object + +def escape_float(value): + return ('%.15g' % value) + +def escape_string(value): + return ("'%s'" % ESCAPE_REGEX.sub( + lambda match: ESCAPE_MAP.get(match.group(0)), value)) + +def escape_unicode(value): + return escape_string(value) + +def escape_None(value): + return 'NULL' + +def escape_timedelta(obj): + seconds = int(obj.seconds) % 60 + minutes = int(obj.seconds // 60) % 60 + hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24 + return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds)) + +def escape_time(obj): + s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute), + int(obj.second)) + if obj.microsecond: + s += ".%f" % obj.microsecond + + return escape_string(s) + +def escape_datetime(obj): + return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S")) + +def escape_date(obj): + return escape_string(obj.strftime("%Y-%m-%d")) + +def escape_struct_time(obj): + return escape_datetime(datetime.datetime(*obj[:6])) + +def convert_datetime(connection, field, obj): + """Returns a DATETIME or TIMESTAMP column value as a datetime object: + + >>> datetime_or_None('2007-02-25 23:06:20') + datetime.datetime(2007, 2, 25, 23, 6, 20) + >>> datetime_or_None('2007-02-25T23:06:20') + datetime.datetime(2007, 2, 25, 23, 6, 20) + + Illegal values are returned as None: + + >>> datetime_or_None('2007-02-31T23:06:20') is None + True + >>> datetime_or_None('0000-00-00 00:00:00') is None + True + + """ + if not isinstance(obj, unicode): + obj = obj.decode(connection.charset) + if ' ' in obj: + sep = ' ' + elif 'T' in obj: + sep = 'T' + else: + return convert_date(connection, field, obj) + + try: + ymd, hms = obj.split(sep, 1) + return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ]) + except ValueError: + return convert_date(connection, field, obj) + +def convert_timedelta(connection, field, obj): + """Returns a TIME column as a timedelta object: + + >>> timedelta_or_None('25:06:17') + datetime.timedelta(1, 3977) + >>> timedelta_or_None('-25:06:17') + datetime.timedelta(-2, 83177) + + Illegal values are returned as None: + + >>> timedelta_or_None('random crap') is None + True + + Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but + can accept values as (+|-)DD HH:MM:SS. The latter format will not + be parsed correctly by this function. + """ + from math import modf + try: + if not isinstance(obj, unicode): + obj = obj.decode(connection.charset) + hours, minutes, seconds = tuple([int(x) for x in obj.split(':')]) + tdelta = datetime.timedelta( + hours = int(hours), + minutes = int(minutes), + seconds = int(seconds), + microseconds = int(modf(float(seconds))[0]*1000000), + ) + return tdelta + except ValueError: + return None + +def convert_time(connection, field, obj): + """Returns a TIME column as a time object: + + >>> time_or_None('15:06:17') + datetime.time(15, 6, 17) + + Illegal values are returned as None: + + >>> time_or_None('-25:06:17') is None + True + >>> time_or_None('random crap') is None + True + + Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but + can accept values as (+|-)DD HH:MM:SS. The latter format will not + be parsed correctly by this function. + + Also note that MySQL's TIME column corresponds more closely to + Python's timedelta and not time. However if you want TIME columns + to be treated as time-of-day and not a time offset, then you can + use set this function as the converter for FIELD_TYPE.TIME. + """ + from math import modf + try: + hour, minute, second = obj.split(':') + return datetime.time(hour=int(hour), minute=int(minute), + second=int(second), + microsecond=int(modf(float(second))[0]*1000000)) + except ValueError: + return None + +def convert_date(connection, field, obj): + """Returns a DATE column as a date object: + + >>> date_or_None('2007-02-26') + datetime.date(2007, 2, 26) + + Illegal values are returned as None: + + >>> date_or_None('2007-02-31') is None + True + >>> date_or_None('0000-00-00') is None + True + + """ + try: + if not isinstance(obj, unicode): + obj = obj.decode(connection.charset) + return datetime.date(*[ int(x) for x in obj.split('-', 2) ]) + except ValueError: + return None + +def convert_mysql_timestamp(connection, field, timestamp): + """Convert a MySQL TIMESTAMP to a Timestamp object. + + MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME: + + >>> mysql_timestamp_converter('2007-02-25 22:32:17') + datetime.datetime(2007, 2, 25, 22, 32, 17) + + MySQL < 4.1 uses a big string of numbers: + + >>> mysql_timestamp_converter('20070225223217') + datetime.datetime(2007, 2, 25, 22, 32, 17) + + Illegal values are returned as None: + + >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None + True + >>> mysql_timestamp_converter('00000000000000') is None + True + + """ + if not isinstance(timestamp, unicode): + timestamp = timestamp.decode(connection.charset) + + if timestamp[4] == '-': + return convert_datetime(connection, field, timestamp) + timestamp += "0"*(14-len(timestamp)) # padding + year, month, day, hour, minute, second = \ + int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \ + int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14]) + try: + return datetime.datetime(year, month, day, hour, minute, second) + except ValueError: + return None + +def convert_set(s): + return set(s.split(",")) + +def convert_bit(connection, field, b): + #b = "\x00" * (8 - len(b)) + b # pad w/ zeroes + #return struct.unpack(">Q", b)[0] + # + # the snippet above is right, but MySQLdb doesn't process bits, + # so we shouldn't either + return b + +def convert_characters(connection, field, data): + field_charset = charset_by_id(field.charsetnr).name + if field.flags & FLAG.SET: + return convert_set(data.decode(field_charset)) + if field.flags & FLAG.BINARY: + return data + + if connection.use_unicode: + data = data.decode(field_charset) + elif connection.charset != field_charset: + data = data.decode(field_charset) + data = data.encode(connection.charset) + return data + +def convert_int(connection, field, data): + return int(data) + +def convert_long(connection, field, data): + return long(data) + +def convert_float(connection, field, data): + return float(data) + +encoders = { + bool: escape_bool, + int: escape_int, + long: escape_long, + float: escape_float, + str: escape_string, + unicode: escape_unicode, + tuple: escape_sequence, + list:escape_sequence, + set:escape_sequence, + dict:escape_dict, + type(None):escape_None, + datetime.date: escape_date, + datetime.datetime : escape_datetime, + datetime.timedelta : escape_timedelta, + datetime.time : escape_time, + time.struct_time : escape_struct_time, + } + +decoders = { + FIELD_TYPE.BIT: convert_bit, + FIELD_TYPE.TINY: convert_int, + FIELD_TYPE.SHORT: convert_int, + FIELD_TYPE.LONG: convert_long, + FIELD_TYPE.FLOAT: convert_float, + FIELD_TYPE.DOUBLE: convert_float, + FIELD_TYPE.DECIMAL: convert_float, + FIELD_TYPE.NEWDECIMAL: convert_float, + FIELD_TYPE.LONGLONG: convert_long, + FIELD_TYPE.INT24: convert_int, + FIELD_TYPE.YEAR: convert_int, + FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp, + FIELD_TYPE.DATETIME: convert_datetime, + FIELD_TYPE.TIME: convert_timedelta, + FIELD_TYPE.DATE: convert_date, + FIELD_TYPE.SET: convert_set, + FIELD_TYPE.BLOB: convert_characters, + FIELD_TYPE.TINY_BLOB: convert_characters, + FIELD_TYPE.MEDIUM_BLOB: convert_characters, + FIELD_TYPE.LONG_BLOB: convert_characters, + FIELD_TYPE.STRING: convert_characters, + FIELD_TYPE.VAR_STRING: convert_characters, + FIELD_TYPE.VARCHAR: convert_characters, + #FIELD_TYPE.BLOB: str, + #FIELD_TYPE.STRING: str, + #FIELD_TYPE.VAR_STRING: str, + #FIELD_TYPE.VARCHAR: str + } +conversions = decoders # for MySQLdb compatibility + +try: + # python version > 2.3 + from decimal import Decimal + def convert_decimal(connection, field, data): + data = data.decode(connection.charset) + return Decimal(data) + decoders[FIELD_TYPE.DECIMAL] = convert_decimal + decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal + + def escape_decimal(obj): + return unicode(obj) + encoders[Decimal] = escape_decimal + +except ImportError: + pass diff --git a/tools/marvin/marvin/pymysql/cursors.py b/tools/marvin/marvin/pymysql/cursors.py new file mode 100644 index 00000000000..4e10f83f4fa --- /dev/null +++ b/tools/marvin/marvin/pymysql/cursors.py @@ -0,0 +1,297 @@ +# -*- coding: utf-8 -*- +import struct +import re + +try: + import cStringIO as StringIO +except ImportError: + import StringIO + +from err import Warning, Error, InterfaceError, DataError, \ + DatabaseError, OperationalError, IntegrityError, InternalError, \ + NotSupportedError, ProgrammingError + +insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE) + +class Cursor(object): + ''' + This is the object you use to interact with the database. + ''' + def __init__(self, connection): + ''' + Do not create an instance of a Cursor yourself. Call + connections.Connection.cursor(). + ''' + from weakref import proxy + self.connection = proxy(connection) + self.description = None + self.rownumber = 0 + self.rowcount = -1 + self.arraysize = 1 + self._executed = None + self.messages = [] + self.errorhandler = connection.errorhandler + self._has_next = None + self._rows = () + + def __del__(self): + ''' + When this gets GC'd close it. + ''' + self.close() + + def close(self): + ''' + Closing a cursor just exhausts all remaining data. + ''' + if not self.connection: + return + try: + while self.nextset(): + pass + except: + pass + + self.connection = None + + def _get_db(self): + if not self.connection: + self.errorhandler(self, ProgrammingError, "cursor closed") + return self.connection + + def _check_executed(self): + if not self._executed: + self.errorhandler(self, ProgrammingError, "execute() first") + + def setinputsizes(self, *args): + """Does nothing, required by DB API.""" + + def setoutputsizes(self, *args): + """Does nothing, required by DB API.""" + + def nextset(self): + ''' Get the next query set ''' + if self._executed: + self.fetchall() + del self.messages[:] + + if not self._has_next: + return None + connection = self._get_db() + connection.next_result() + self._do_get_result() + return True + + def execute(self, query, args=None): + ''' Execute a query ''' + from sys import exc_info + + conn = self._get_db() + charset = conn.charset + del self.messages[:] + + # TODO: make sure that conn.escape is correct + + if args is not None: + if isinstance(args, tuple) or isinstance(args, list): + escaped_args = tuple(conn.escape(arg) for arg in args) + elif isinstance(args, dict): + escaped_args = dict((key, conn.escape(val)) for (key, val) in args.items()) + else: + #If it's not a dictionary let's try escaping it anyways. + #Worst case it will throw a Value error + escaped_args = conn.escape(args) + + query = query % escaped_args + + if isinstance(query, unicode): + query = query.encode(charset) + + result = 0 + try: + result = self._query(query) + except: + exc, value, tb = exc_info() + del tb + self.messages.append((exc,value)) + self.errorhandler(self, exc, value) + + self._executed = query + return result + + def executemany(self, query, args): + ''' Run several data against one query ''' + del self.messages[:] + #conn = self._get_db() + if not args: + return + #charset = conn.charset + #if isinstance(query, unicode): + # query = query.encode(charset) + + self.rowcount = sum([ self.execute(query, arg) for arg in args ]) + return self.rowcount + + + def callproc(self, procname, args=()): + """Execute stored procedure procname with args + + procname -- string, name of procedure to execute on server + + args -- Sequence of parameters to use with procedure + + Returns the original args. + + Compatibility warning: PEP-249 specifies that any modified + parameters must be returned. This is currently impossible + as they are only available by storing them in a server + variable and then retrieved by a query. Since stored + procedures return zero or more result sets, there is no + reliable way to get at OUT or INOUT parameters via callproc. + The server variables are named @_procname_n, where procname + is the parameter above and n is the position of the parameter + (from zero). Once all result sets generated by the procedure + have been fetched, you can issue a SELECT @_procname_0, ... + query using .execute() to get any OUT or INOUT values. + + Compatibility warning: The act of calling a stored procedure + itself creates an empty result set. This appears after any + result sets generated by the procedure. This is non-standard + behavior with respect to the DB-API. Be sure to use nextset() + to advance through all result sets; otherwise you may get + disconnected. + """ + conn = self._get_db() + for index, arg in enumerate(args): + q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg)) + if isinstance(q, unicode): + q = q.encode(conn.charset) + self._query(q) + self.nextset() + + q = "CALL %s(%s)" % (procname, + ','.join(['@_%s_%d' % (procname, i) + for i in range(len(args))])) + if isinstance(q, unicode): + q = q.encode(conn.charset) + self._query(q) + self._executed = q + + return args + + def fetchone(self): + ''' Fetch the next row ''' + self._check_executed() + if self._rows is None or self.rownumber >= len(self._rows): + return None + result = self._rows[self.rownumber] + self.rownumber += 1 + return result + + def fetchmany(self, size=None): + ''' Fetch several rows ''' + self._check_executed() + end = self.rownumber + (size or self.arraysize) + result = self._rows[self.rownumber:end] + if self._rows is None: + return None + self.rownumber = min(end, len(self._rows)) + return result + + def fetchall(self): + ''' Fetch all the rows ''' + self._check_executed() + if self._rows is None: + return None + if self.rownumber: + result = self._rows[self.rownumber:] + else: + result = self._rows + self.rownumber = len(self._rows) + return result + + def scroll(self, value, mode='relative'): + self._check_executed() + if mode == 'relative': + r = self.rownumber + value + elif mode == 'absolute': + r = value + else: + self.errorhandler(self, ProgrammingError, + "unknown scroll mode %s" % mode) + + if r < 0 or r >= len(self._rows): + self.errorhandler(self, IndexError, "out of range") + self.rownumber = r + + def _query(self, q): + conn = self._get_db() + self._last_executed = q + conn.query(q) + self._do_get_result() + return self.rowcount + + def _do_get_result(self): + conn = self._get_db() + self.rowcount = conn._result.affected_rows + + self.rownumber = 0 + self.description = conn._result.description + self.lastrowid = conn._result.insert_id + self._rows = conn._result.rows + self._has_next = conn._result.has_next + + def __iter__(self): + return iter(self.fetchone, None) + + Warning = Warning + Error = Error + InterfaceError = InterfaceError + DatabaseError = DatabaseError + DataError = DataError + OperationalError = OperationalError + IntegrityError = IntegrityError + InternalError = InternalError + ProgrammingError = ProgrammingError + NotSupportedError = NotSupportedError + +class DictCursor(Cursor): + """A cursor which returns results as a dictionary""" + + def execute(self, query, args=None): + result = super(DictCursor, self).execute(query, args) + if self.description: + self._fields = [ field[0] for field in self.description ] + return result + + def fetchone(self): + ''' Fetch the next row ''' + self._check_executed() + if self._rows is None or self.rownumber >= len(self._rows): + return None + result = dict(zip(self._fields, self._rows[self.rownumber])) + self.rownumber += 1 + return result + + def fetchmany(self, size=None): + ''' Fetch several rows ''' + self._check_executed() + if self._rows is None: + return None + end = self.rownumber + (size or self.arraysize) + result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:end] ] + self.rownumber = min(end, len(self._rows)) + return tuple(result) + + def fetchall(self): + ''' Fetch all the rows ''' + self._check_executed() + if self._rows is None: + return None + if self.rownumber: + result = [ dict(zip(self._fields, r)) for r in self._rows[self.rownumber:] ] + else: + result = [ dict(zip(self._fields, r)) for r in self._rows ] + self.rownumber = len(self._rows) + return tuple(result) + diff --git a/tools/marvin/marvin/pymysql/err.py b/tools/marvin/marvin/pymysql/err.py new file mode 100644 index 00000000000..b4322c63354 --- /dev/null +++ b/tools/marvin/marvin/pymysql/err.py @@ -0,0 +1,147 @@ +import struct + + +try: + StandardError, Warning +except ImportError: + try: + from exceptions import StandardError, Warning + except ImportError: + import sys + e = sys.modules['exceptions'] + StandardError = e.StandardError + Warning = e.Warning + +from constants import ER +import sys + +class MySQLError(StandardError): + + """Exception related to operation with MySQL.""" + + +class Warning(Warning, MySQLError): + + """Exception raised for important warnings like data truncations + while inserting, etc.""" + +class Error(MySQLError): + + """Exception that is the base class of all other error exceptions + (not Warning).""" + + +class InterfaceError(Error): + + """Exception raised for errors that are related to the database + interface rather than the database itself.""" + + +class DatabaseError(Error): + + """Exception raised for errors that are related to the + database.""" + + +class DataError(DatabaseError): + + """Exception raised for errors that are due to problems with the + processed data like division by zero, numeric value out of range, + etc.""" + + +class OperationalError(DatabaseError): + + """Exception raised for errors that are related to the database's + operation and not necessarily under the control of the programmer, + e.g. an unexpected disconnect occurs, the data source name is not + found, a transaction could not be processed, a memory allocation + error occurred during processing, etc.""" + + +class IntegrityError(DatabaseError): + + """Exception raised when the relational integrity of the database + is affected, e.g. a foreign key check fails, duplicate key, + etc.""" + + +class InternalError(DatabaseError): + + """Exception raised when the database encounters an internal + error, e.g. the cursor is not valid anymore, the transaction is + out of sync, etc.""" + + +class ProgrammingError(DatabaseError): + + """Exception raised for programming errors, e.g. table not found + or already exists, syntax error in the SQL statement, wrong number + of parameters specified, etc.""" + + +class NotSupportedError(DatabaseError): + + """Exception raised in case a method or database API was used + which is not supported by the database, e.g. requesting a + .rollback() on a connection that does not support transaction or + has transactions turned off.""" + + +error_map = {} + +def _map_error(exc, *errors): + for error in errors: + error_map[error] = exc + +_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR, + ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME, + ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE, + ER.INVALID_GROUP_FUNC_USE, ER.UNSUPPORTED_EXTENSION, + ER.TABLE_MUST_HAVE_COLUMNS, ER.CANT_DO_THIS_DURING_AN_TRANSACTION) +_map_error(DataError, ER.WARN_DATA_TRUNCATED, ER.WARN_NULL_TO_NOTNULL, + ER.WARN_DATA_OUT_OF_RANGE, ER.NO_DEFAULT, ER.PRIMARY_CANT_HAVE_NULL, + ER.DATA_TOO_LONG, ER.DATETIME_FUNCTION_OVERFLOW) +_map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW, + ER.NO_REFERENCED_ROW_2, ER.ROW_IS_REFERENCED, ER.ROW_IS_REFERENCED_2, + ER.CANNOT_ADD_FOREIGN) +_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK, + ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE) +_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR, + ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR) + +del _map_error, ER + + +def _get_error_info(data): + errno = struct.unpack(' tuple) + c.execute("SELECT * from dictcursor where name='bob'") + r = c.fetchall() + self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor") + # same test again but iterate over the + c.execute("SELECT * from dictcursor where name='bob'") + for r in c: + self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor") + # get all 3 row via fetchall + c.execute("SELECT * from dictcursor") + r = c.fetchall() + self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor") + #same test again but do a list comprehension + c.execute("SELECT * from dictcursor") + r = [x for x in c] + self.assertEqual([bob,jim,fred], r, "list comprehension failed via DictCursor") + # get all 2 row via fetchmany + c.execute("SELECT * from dictcursor") + r = c.fetchmany(2) + self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor") + finally: + c.execute("drop table dictcursor") + +__all__ = ["TestDictCursor"] + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/test_basic.py b/tools/marvin/marvin/pymysql/tests/test_basic.py new file mode 100644 index 00000000000..c8fdd297f44 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/test_basic.py @@ -0,0 +1,193 @@ +from pymysql.tests import base +from pymysql import util + +import time +import datetime + +class TestConversion(base.PyMySQLTestCase): + def test_datatypes(self): + """ test every data type """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)") + try: + # insert values + v = (True, -3, 123456789012, 5.7, "hello'\" world", u"Espa\xc3\xb1ol", "binary\x00data".encode(conn.charset), datetime.date(1988,2,2), datetime.datetime.now(), datetime.timedelta(5,6), datetime.time(16,32), time.localtime()) + c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v) + c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") + r = c.fetchone() + self.assertEqual(util.int2byte(1), r[0]) + self.assertEqual(v[1:8], r[1:8]) + # mysql throws away microseconds so we need to check datetimes + # specially. additionally times are turned into timedeltas. + self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8]) + self.assertEqual(v[9], r[9]) # just timedeltas + self.assertEqual(datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10]) + self.assertEqual(datetime.datetime(*v[-1][:6]), r[-1]) + + c.execute("delete from test_datatypes") + + # check nulls + c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", [None] * 12) + c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes") + r = c.fetchone() + self.assertEqual(tuple([None] * 12), r) + + c.execute("delete from test_datatypes") + + # check sequence type + c.execute("insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)") + c.execute("select l from test_datatypes where i in %s order by i", ((2,6),)) + r = c.fetchall() + self.assertEqual(((4,),(8,)), r) + finally: + c.execute("drop table test_datatypes") + + def test_dict(self): + """ test dict escaping """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table test_dict (a integer, b integer, c integer)") + try: + c.execute("insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", {"a":1,"b":2,"c":3}) + c.execute("select a,b,c from test_dict") + self.assertEqual((1,2,3), c.fetchone()) + finally: + c.execute("drop table test_dict") + + def test_string(self): + conn = self.connections[0] + c = conn.cursor() + c.execute("create table test_dict (a text)") + test_value = "I am a test string" + try: + c.execute("insert into test_dict (a) values (%s)", test_value) + c.execute("select a from test_dict") + self.assertEqual((test_value,), c.fetchone()) + finally: + c.execute("drop table test_dict") + + def test_integer(self): + conn = self.connections[0] + c = conn.cursor() + c.execute("create table test_dict (a integer)") + test_value = 12345 + try: + c.execute("insert into test_dict (a) values (%s)", test_value) + c.execute("select a from test_dict") + self.assertEqual((test_value,), c.fetchone()) + finally: + c.execute("drop table test_dict") + + + def test_big_blob(self): + """ test tons of data """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table test_big_blob (b blob)") + try: + data = "pymysql" * 1024 + c.execute("insert into test_big_blob (b) values (%s)", (data,)) + c.execute("select b from test_big_blob") + self.assertEqual(data.encode(conn.charset), c.fetchone()[0]) + finally: + c.execute("drop table test_big_blob") + +class TestCursor(base.PyMySQLTestCase): + # this test case does not work quite right yet, however, + # we substitute in None for the erroneous field which is + # compatible with the DB-API 2.0 spec and has not broken + # any unit tests for anything we've tried. + + #def test_description(self): + # """ test description attribute """ + # # result is from MySQLdb module + # r = (('Host', 254, 11, 60, 60, 0, 0), + # ('User', 254, 16, 16, 16, 0, 0), + # ('Password', 254, 41, 41, 41, 0, 0), + # ('Select_priv', 254, 1, 1, 1, 0, 0), + # ('Insert_priv', 254, 1, 1, 1, 0, 0), + # ('Update_priv', 254, 1, 1, 1, 0, 0), + # ('Delete_priv', 254, 1, 1, 1, 0, 0), + # ('Create_priv', 254, 1, 1, 1, 0, 0), + # ('Drop_priv', 254, 1, 1, 1, 0, 0), + # ('Reload_priv', 254, 1, 1, 1, 0, 0), + # ('Shutdown_priv', 254, 1, 1, 1, 0, 0), + # ('Process_priv', 254, 1, 1, 1, 0, 0), + # ('File_priv', 254, 1, 1, 1, 0, 0), + # ('Grant_priv', 254, 1, 1, 1, 0, 0), + # ('References_priv', 254, 1, 1, 1, 0, 0), + # ('Index_priv', 254, 1, 1, 1, 0, 0), + # ('Alter_priv', 254, 1, 1, 1, 0, 0), + # ('Show_db_priv', 254, 1, 1, 1, 0, 0), + # ('Super_priv', 254, 1, 1, 1, 0, 0), + # ('Create_tmp_table_priv', 254, 1, 1, 1, 0, 0), + # ('Lock_tables_priv', 254, 1, 1, 1, 0, 0), + # ('Execute_priv', 254, 1, 1, 1, 0, 0), + # ('Repl_slave_priv', 254, 1, 1, 1, 0, 0), + # ('Repl_client_priv', 254, 1, 1, 1, 0, 0), + # ('Create_view_priv', 254, 1, 1, 1, 0, 0), + # ('Show_view_priv', 254, 1, 1, 1, 0, 0), + # ('Create_routine_priv', 254, 1, 1, 1, 0, 0), + # ('Alter_routine_priv', 254, 1, 1, 1, 0, 0), + # ('Create_user_priv', 254, 1, 1, 1, 0, 0), + # ('Event_priv', 254, 1, 1, 1, 0, 0), + # ('Trigger_priv', 254, 1, 1, 1, 0, 0), + # ('ssl_type', 254, 0, 9, 9, 0, 0), + # ('ssl_cipher', 252, 0, 65535, 65535, 0, 0), + # ('x509_issuer', 252, 0, 65535, 65535, 0, 0), + # ('x509_subject', 252, 0, 65535, 65535, 0, 0), + # ('max_questions', 3, 1, 11, 11, 0, 0), + # ('max_updates', 3, 1, 11, 11, 0, 0), + # ('max_connections', 3, 1, 11, 11, 0, 0), + # ('max_user_connections', 3, 1, 11, 11, 0, 0)) + # conn = self.connections[0] + # c = conn.cursor() + # c.execute("select * from mysql.user") + # + # self.assertEqual(r, c.description) + + def test_fetch_no_result(self): + """ test a fetchone() with no rows """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table test_nr (b varchar(32))") + try: + data = "pymysql" + c.execute("insert into test_nr (b) values (%s)", (data,)) + self.assertEqual(None, c.fetchone()) + finally: + c.execute("drop table test_nr") + + def test_aggregates(self): + """ test aggregate functions """ + conn = self.connections[0] + c = conn.cursor() + try: + c.execute('create table test_aggregates (i integer)') + for i in xrange(0, 10): + c.execute('insert into test_aggregates (i) values (%s)', (i,)) + c.execute('select sum(i) from test_aggregates') + r, = c.fetchone() + self.assertEqual(sum(range(0,10)), r) + finally: + c.execute('drop table test_aggregates') + + def test_single_tuple(self): + """ test a single tuple """ + conn = self.connections[0] + c = conn.cursor() + try: + c.execute("create table mystuff (id integer primary key)") + c.execute("insert into mystuff (id) values (1)") + c.execute("insert into mystuff (id) values (2)") + c.execute("select id from mystuff where id in %s", ((1,),)) + self.assertEqual([(1,)], list(c.fetchall())) + finally: + c.execute("drop table mystuff") + +__all__ = ["TestConversion","TestCursor"] + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/test_example.py b/tools/marvin/marvin/pymysql/tests/test_example.py new file mode 100644 index 00000000000..2da05db31c6 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/test_example.py @@ -0,0 +1,32 @@ +import pymysql +from pymysql.tests import base + +class TestExample(base.PyMySQLTestCase): + def test_example(self): + conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='mysql') + + + cur = conn.cursor() + + cur.execute("SELECT Host,User FROM user") + + # print cur.description + + # r = cur.fetchall() + # print r + # ...or... + u = False + + for r in cur.fetchall(): + u = u or conn.user in r + + self.assertTrue(u) + + cur.close() + conn.close() + +__all__ = ["TestExample"] + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/test_issues.py b/tools/marvin/marvin/pymysql/tests/test_issues.py new file mode 100644 index 00000000000..38d71639c90 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/test_issues.py @@ -0,0 +1,268 @@ +import pymysql +from pymysql.tests import base + +import sys + +try: + import imp + reload = imp.reload +except AttributeError: + pass + +import datetime + +class TestOldIssues(base.PyMySQLTestCase): + def test_issue_3(self): + """ undefined methods datetime_or_None, date_or_None """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)") + try: + c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None)) + c.execute("select d from issue3") + self.assertEqual(None, c.fetchone()[0]) + c.execute("select t from issue3") + self.assertEqual(None, c.fetchone()[0]) + c.execute("select dt from issue3") + self.assertEqual(None, c.fetchone()[0]) + c.execute("select ts from issue3") + self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime)) + finally: + c.execute("drop table issue3") + + def test_issue_4(self): + """ can't retrieve TIMESTAMP fields """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table issue4 (ts timestamp)") + try: + c.execute("insert into issue4 (ts) values (now())") + c.execute("select ts from issue4") + self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime)) + finally: + c.execute("drop table issue4") + + def test_issue_5(self): + """ query on information_schema.tables fails """ + con = self.connections[0] + cur = con.cursor() + cur.execute("select * from information_schema.tables") + + def test_issue_6(self): + """ exception: TypeError: ord() expected a character, but string of length 0 found """ + conn = pymysql.connect(host="localhost",user="root",passwd="",db="mysql") + c = conn.cursor() + c.execute("select * from user") + conn.close() + + def test_issue_8(self): + """ Primary Key and Index error when selecting data """ + conn = self.connections[0] + c = conn.cursor() + c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh` +datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL +DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY +KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""") + try: + self.assertEqual(0, c.execute("SELECT * FROM test")) + c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)") + self.assertEqual(0, c.execute("SELECT * FROM test")) + finally: + c.execute("drop table test") + + def test_issue_9(self): + """ sets DeprecationWarning in Python 2.6 """ + try: + reload(pymysql) + except DeprecationWarning: + self.fail() + + def test_issue_10(self): + """ Allocate a variable to return when the exception handler is permissive """ + conn = self.connections[0] + conn.errorhandler = lambda cursor, errorclass, errorvalue: None + cur = conn.cursor() + cur.execute( "create table t( n int )" ) + cur.execute( "create table t( n int )" ) + + def test_issue_13(self): + """ can't handle large result fields """ + conn = self.connections[0] + cur = conn.cursor() + try: + cur.execute("create table issue13 (t text)") + # ticket says 18k + size = 18*1024 + cur.execute("insert into issue13 (t) values (%s)", ("x" * size,)) + cur.execute("select t from issue13") + # use assert_ so that obscenely huge error messages don't print + r = cur.fetchone()[0] + self.assert_("x" * size == r) + finally: + cur.execute("drop table issue13") + + def test_issue_14(self): + """ typo in converters.py """ + self.assertEqual('1', pymysql.converters.escape_item(1, "utf8")) + self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8")) + + self.assertEqual('1', pymysql.converters.escape_object(1)) + self.assertEqual('1', pymysql.converters.escape_object(1L)) + + def test_issue_15(self): + """ query should be expanded before perform character encoding """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table issue15 (t varchar(32))") + try: + c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',)) + c.execute("select t from issue15") + self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0]) + finally: + c.execute("drop table issue15") + + def test_issue_16(self): + """ Patch for string and tuple escaping """ + conn = self.connections[0] + c = conn.cursor() + c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))") + try: + c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')") + c.execute("select email from issue16 where name=%s", ("pete",)) + self.assertEqual("floydophone", c.fetchone()[0]) + finally: + c.execute("drop table issue16") + + def test_issue_17(self): + """ could not connect mysql use passwod """ + conn = self.connections[0] + host = self.databases[0]["host"] + db = self.databases[0]["db"] + c = conn.cursor() + # grant access to a table to a user with a password + try: + c.execute("create table issue17 (x varchar(32) primary key)") + c.execute("insert into issue17 (x) values ('hello, world!')") + c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db) + conn.commit() + + conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db) + c2 = conn2.cursor() + c2.execute("select x from issue17") + self.assertEqual("hello, world!", c2.fetchone()[0]) + finally: + c.execute("drop table issue17") + +def _uni(s, e): + # hack for py3 + if sys.version_info[0] > 2: + return unicode(bytes(s, sys.getdefaultencoding()), e) + else: + return unicode(s, e) + +class TestNewIssues(base.PyMySQLTestCase): + def test_issue_34(self): + try: + pymysql.connect(host="localhost", port=1237, user="root") + self.fail() + except pymysql.OperationalError, e: + self.assertEqual(2003, e.args[0]) + except: + self.fail() + + def test_issue_33(self): + conn = pymysql.connect(host="localhost", user="root", db=self.databases[0]["db"], charset="utf8") + c = conn.cursor() + try: + c.execute(_uni("create table hei\xc3\x9fe (name varchar(32))", "utf8")) + c.execute(_uni("insert into hei\xc3\x9fe (name) values ('Pi\xc3\xb1ata')", "utf8")) + c.execute(_uni("select name from hei\xc3\x9fe", "utf8")) + self.assertEqual(_uni("Pi\xc3\xb1ata","utf8"), c.fetchone()[0]) + finally: + c.execute(_uni("drop table hei\xc3\x9fe", "utf8")) + + # Will fail without manual intervention: + #def test_issue_35(self): + # + # conn = self.connections[0] + # c = conn.cursor() + # print "sudo killall -9 mysqld within the next 10 seconds" + # try: + # c.execute("select sleep(10)") + # self.fail() + # except pymysql.OperationalError, e: + # self.assertEqual(2013, e.args[0]) + + def test_issue_36(self): + conn = self.connections[0] + c = conn.cursor() + # kill connections[0] + original_count = c.execute("show processlist") + kill_id = None + for id,user,host,db,command,time,state,info in c.fetchall(): + if info == "show processlist": + kill_id = id + break + # now nuke the connection + conn.kill(kill_id) + # make sure this connection has broken + try: + c.execute("show tables") + self.fail() + except: + pass + # check the process list from the other connection + self.assertEqual(original_count - 1, self.connections[1].cursor().execute("show processlist")) + del self.connections[0] + + def test_issue_37(self): + conn = self.connections[0] + c = conn.cursor() + self.assertEqual(1, c.execute("SELECT @foo")) + self.assertEqual((None,), c.fetchone()) + self.assertEqual(0, c.execute("SET @foo = 'bar'")) + c.execute("set @foo = 'bar'") + + def test_issue_38(self): + conn = self.connections[0] + c = conn.cursor() + datum = "a" * 1024 * 1023 # reduced size for most default mysql installs + + try: + c.execute("create table issue38 (id integer, data mediumblob)") + c.execute("insert into issue38 values (1, %s)", (datum,)) + finally: + c.execute("drop table issue38") + + def disabled_test_issue_54(self): + conn = self.connections[0] + c = conn.cursor() + big_sql = "select * from issue54 where " + big_sql += " and ".join("%d=%d" % (i,i) for i in xrange(0, 100000)) + + try: + c.execute("create table issue54 (id integer primary key)") + c.execute("insert into issue54 (id) values (7)") + c.execute(big_sql) + self.assertEquals(7, c.fetchone()[0]) + finally: + c.execute("drop table issue54") + +class TestGitHubIssues(base.PyMySQLTestCase): + def test_issue_66(self): + conn = self.connections[0] + c = conn.cursor() + self.assertEquals(0, conn.insert_id()) + try: + c.execute("create table issue66 (id integer primary key auto_increment, x integer)") + c.execute("insert into issue66 (x) values (1)") + c.execute("insert into issue66 (x) values (1)") + self.assertEquals(2, conn.insert_id()) + finally: + c.execute("drop table issue66") + +__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"] + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py b/tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py new file mode 100644 index 00000000000..bfcc075fc4b --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py @@ -0,0 +1,5 @@ +from test_MySQLdb import * + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py new file mode 100644 index 00000000000..b64f273cf08 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py @@ -0,0 +1,7 @@ +from test_MySQLdb_capabilities import test_MySQLdb as test_capabilities +from test_MySQLdb_nonstandard import * +from test_MySQLdb_dbapi20 import test_MySQLdb as test_dbapi2 + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py new file mode 100644 index 00000000000..ddd012330e5 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py @@ -0,0 +1,292 @@ +#!/usr/bin/env python -O +""" Script to test database capabilities and the DB-API interface + for functionality and memory leaks. + + Adapted from a script by M-A Lemburg. + +""" +from time import time +import array +import unittest + + +class DatabaseTest(unittest.TestCase): + + db_module = None + connect_args = () + connect_kwargs = dict(use_unicode=True, charset="utf8") + create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" + rows = 10 + debug = False + + def setUp(self): + import gc + db = self.db_module.connect(*self.connect_args, **self.connect_kwargs) + self.connection = db + self.cursor = db.cursor() + self.BLOBText = ''.join([chr(i) for i in range(256)] * 100); + self.BLOBUText = u''.join([unichr(i) for i in range(16834)]) + self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16)) + + leak_test = True + + def tearDown(self): + if self.leak_test: + import gc + del self.cursor + orphans = gc.collect() + self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans) + + del self.connection + orphans = gc.collect() + self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans) + + def table_exists(self, name): + try: + self.cursor.execute('select * from %s where 1=0' % name) + except: + return False + else: + return True + + def quote_identifier(self, ident): + return '"%s"' % ident + + def new_table_name(self): + i = id(self.cursor) + while True: + name = self.quote_identifier('tb%08x' % i) + if not self.table_exists(name): + return name + i = i + 1 + + def create_table(self, columndefs): + + """ Create a table using a list of column definitions given in + columndefs. + + generator must be a function taking arguments (row_number, + col_number) returning a suitable data object for insertion + into the table. + + """ + self.table = self.new_table_name() + self.cursor.execute('CREATE TABLE %s (%s) %s' % + (self.table, + ',\n'.join(columndefs), + self.create_table_extra)) + + def check_data_integrity(self, columndefs, generator): + # insert + self.create_table(columndefs) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) + data = [ [ generator(i,j) for j in range(len(columndefs)) ] + for i in range(self.rows) ] + if self.debug: + print data + self.cursor.executemany(insert_statement, data) + self.connection.commit() + # verify + self.cursor.execute('select * from %s' % self.table) + l = self.cursor.fetchall() + if self.debug: + print l + self.assertEquals(len(l), self.rows) + try: + for i in range(self.rows): + for j in range(len(columndefs)): + self.assertEquals(l[i][j], generator(i,j)) + finally: + if not self.debug: + self.cursor.execute('drop table %s' % (self.table)) + + def test_transactions(self): + columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + def generator(row, col): + if col == 0: return row + else: return ('%i' % (row%10))*255 + self.create_table(columndefs) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) + data = [ [ generator(i,j) for j in range(len(columndefs)) ] + for i in range(self.rows) ] + self.cursor.executemany(insert_statement, data) + # verify + self.connection.commit() + self.cursor.execute('select * from %s' % self.table) + l = self.cursor.fetchall() + self.assertEquals(len(l), self.rows) + for i in range(self.rows): + for j in range(len(columndefs)): + self.assertEquals(l[i][j], generator(i,j)) + delete_statement = 'delete from %s where col1=%%s' % self.table + self.cursor.execute(delete_statement, (0,)) + self.cursor.execute('select col1 from %s where col1=%s' % \ + (self.table, 0)) + l = self.cursor.fetchall() + self.assertFalse(l, "DELETE didn't work") + self.connection.rollback() + self.cursor.execute('select col1 from %s where col1=%s' % \ + (self.table, 0)) + l = self.cursor.fetchall() + self.assertTrue(len(l) == 1, "ROLLBACK didn't work") + self.cursor.execute('drop table %s' % (self.table)) + + def test_truncation(self): + columndefs = ( 'col1 INT', 'col2 VARCHAR(255)') + def generator(row, col): + if col == 0: return row + else: return ('%i' % (row%10))*((255-self.rows/2)+row) + self.create_table(columndefs) + insert_statement = ('INSERT INTO %s VALUES (%s)' % + (self.table, + ','.join(['%s'] * len(columndefs)))) + + try: + self.cursor.execute(insert_statement, (0, '0'*256)) + except Warning: + if self.debug: print self.cursor.messages + except self.connection.DataError: + pass + else: + self.fail("Over-long column did not generate warnings/exception with single insert") + + self.connection.rollback() + + try: + for i in range(self.rows): + data = [] + for j in range(len(columndefs)): + data.append(generator(i,j)) + self.cursor.execute(insert_statement,tuple(data)) + except Warning: + if self.debug: print self.cursor.messages + except self.connection.DataError: + pass + else: + self.fail("Over-long columns did not generate warnings/exception with execute()") + + self.connection.rollback() + + try: + data = [ [ generator(i,j) for j in range(len(columndefs)) ] + for i in range(self.rows) ] + self.cursor.executemany(insert_statement, data) + except Warning: + if self.debug: print self.cursor.messages + except self.connection.DataError: + pass + else: + self.fail("Over-long columns did not generate warnings/exception with executemany()") + + self.connection.rollback() + self.cursor.execute('drop table %s' % (self.table)) + + def test_CHAR(self): + # Character data + def generator(row,col): + return ('%i' % ((row+col) % 10)) * 255 + self.check_data_integrity( + ('col1 char(255)','col2 char(255)'), + generator) + + def test_INT(self): + # Number data + def generator(row,col): + return row*row + self.check_data_integrity( + ('col1 INT',), + generator) + + def test_DECIMAL(self): + # DECIMAL + def generator(row,col): + from decimal import Decimal + return Decimal("%d.%02d" % (row, col)) + self.check_data_integrity( + ('col1 DECIMAL(5,2)',), + generator) + + def test_DATE(self): + ticks = time() + def generator(row,col): + return self.db_module.DateFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 DATE',), + generator) + + def test_TIME(self): + ticks = time() + def generator(row,col): + return self.db_module.TimeFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 TIME',), + generator) + + def test_DATETIME(self): + ticks = time() + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 DATETIME',), + generator) + + def test_TIMESTAMP(self): + ticks = time() + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313) + self.check_data_integrity( + ('col1 TIMESTAMP',), + generator) + + def test_fractional_TIMESTAMP(self): + ticks = time() + def generator(row,col): + return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0) + self.check_data_integrity( + ('col1 TIMESTAMP',), + generator) + + def test_LONG(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBUText # 'BLOB Text ' * 1024 + self.check_data_integrity( + ('col1 INT', 'col2 LONG'), + generator) + + def test_TEXT(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBUText[:5192] # 'BLOB Text ' * 1024 + self.check_data_integrity( + ('col1 INT', 'col2 TEXT'), + generator) + + def test_LONG_BYTE(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + self.check_data_integrity( + ('col1 INT','col2 LONG BYTE'), + generator) + + def test_BLOB(self): + def generator(row,col): + if col == 0: + return row + else: + return self.BLOBBinary # 'BLOB\000Binary ' * 1024 + self.check_data_integrity( + ('col1 INT','col2 BLOB'), + generator) + diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py new file mode 100644 index 00000000000..a419e34a46c --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py @@ -0,0 +1,853 @@ +#!/usr/bin/env python +''' Python DB API 2.0 driver compliance unit test suite. + + This software is Public Domain and may be used without restrictions. + + "Now we have booze and barflies entering the discussion, plus rumours of + DBAs on drugs... and I won't tell you what flashes through my mind each + time I read the subject line with 'Anal Compliance' in it. All around + this is turning out to be a thoroughly unwholesome unit test." + + -- Ian Bicking +''' + +__rcs_id__ = '$Id$' +__version__ = '$Revision$'[11:-2] +__author__ = 'Stuart Bishop ' + +import unittest +import time + +# $Log$ +# Revision 1.1.2.1 2006/02/25 03:44:32 adustman +# Generic DB-API unit test module +# +# Revision 1.10 2003/10/09 03:14:14 zenzen +# Add test for DB API 2.0 optional extension, where database exceptions +# are exposed as attributes on the Connection object. +# +# Revision 1.9 2003/08/13 01:16:36 zenzen +# Minor tweak from Stefan Fleiter +# +# Revision 1.8 2003/04/10 00:13:25 zenzen +# Changes, as per suggestions by M.-A. Lemburg +# - Add a table prefix, to ensure namespace collisions can always be avoided +# +# Revision 1.7 2003/02/26 23:33:37 zenzen +# Break out DDL into helper functions, as per request by David Rushby +# +# Revision 1.6 2003/02/21 03:04:33 zenzen +# Stuff from Henrik Ekelund: +# added test_None +# added test_nextset & hooks +# +# Revision 1.5 2003/02/17 22:08:43 zenzen +# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize +# defaults to 1 & generic cursor.callproc test added +# +# Revision 1.4 2003/02/15 00:16:33 zenzen +# Changes, as per suggestions and bug reports by M.-A. Lemburg, +# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar +# - Class renamed +# - Now a subclass of TestCase, to avoid requiring the driver stub +# to use multiple inheritance +# - Reversed the polarity of buggy test in test_description +# - Test exception heirarchy correctly +# - self.populate is now self._populate(), so if a driver stub +# overrides self.ddl1 this change propogates +# - VARCHAR columns now have a width, which will hopefully make the +# DDL even more portible (this will be reversed if it causes more problems) +# - cursor.rowcount being checked after various execute and fetchXXX methods +# - Check for fetchall and fetchmany returning empty lists after results +# are exhausted (already checking for empty lists if select retrieved +# nothing +# - Fix bugs in test_setoutputsize_basic and test_setinputsizes +# + +class DatabaseAPI20Test(unittest.TestCase): + ''' Test a database self.driver for DB API 2.0 compatibility. + This implementation tests Gadfly, but the TestCase + is structured so that other self.drivers can subclass this + test case to ensure compiliance with the DB-API. It is + expected that this TestCase may be expanded in the future + if ambiguities or edge conditions are discovered. + + The 'Optional Extensions' are not yet being tested. + + self.drivers should subclass this test, overriding setUp, tearDown, + self.driver, connect_args and connect_kw_args. Class specification + should be as follows: + + import dbapi20 + class mytest(dbapi20.DatabaseAPI20Test): + [...] + + Don't 'import DatabaseAPI20Test from dbapi20', or you will + confuse the unit tester - just 'import dbapi20'. + ''' + + # The self.driver module. This should be the module where the 'connect' + # method is to be found + driver = None + connect_args = () # List of arguments to pass to connect + connect_kw_args = {} # Keyword arguments for connect + table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables + + ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix + ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix + xddl1 = 'drop table %sbooze' % table_prefix + xddl2 = 'drop table %sbarflys' % table_prefix + + lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase + + # Some drivers may need to override these helpers, for example adding + # a 'commit' after the execute. + def executeDDL1(self,cursor): + cursor.execute(self.ddl1) + + def executeDDL2(self,cursor): + cursor.execute(self.ddl2) + + def setUp(self): + ''' self.drivers should override this method to perform required setup + if any is necessary, such as creating the database. + ''' + pass + + def tearDown(self): + ''' self.drivers should override this method to perform required cleanup + if any is necessary, such as deleting the test database. + The default drops the tables that may be created. + ''' + con = self._connect() + try: + cur = con.cursor() + for ddl in (self.xddl1,self.xddl2): + try: + cur.execute(ddl) + con.commit() + except self.driver.Error: + # Assume table didn't exist. Other tests will check if + # execute is busted. + pass + finally: + con.close() + + def _connect(self): + try: + return self.driver.connect( + *self.connect_args,**self.connect_kw_args + ) + except AttributeError: + self.fail("No connect method found in self.driver module") + + def test_connect(self): + con = self._connect() + con.close() + + def test_apilevel(self): + try: + # Must exist + apilevel = self.driver.apilevel + # Must equal 2.0 + self.assertEqual(apilevel,'2.0') + except AttributeError: + self.fail("Driver doesn't define apilevel") + + def test_threadsafety(self): + try: + # Must exist + threadsafety = self.driver.threadsafety + # Must be a valid value + self.assertTrue(threadsafety in (0,1,2,3)) + except AttributeError: + self.fail("Driver doesn't define threadsafety") + + def test_paramstyle(self): + try: + # Must exist + paramstyle = self.driver.paramstyle + # Must be a valid value + self.assertTrue(paramstyle in ( + 'qmark','numeric','named','format','pyformat' + )) + except AttributeError: + self.fail("Driver doesn't define paramstyle") + + def test_Exceptions(self): + # Make sure required exceptions exist, and are in the + # defined heirarchy. + self.assertTrue(issubclass(self.driver.Warning,StandardError)) + self.assertTrue(issubclass(self.driver.Error,StandardError)) + self.assertTrue( + issubclass(self.driver.InterfaceError,self.driver.Error) + ) + self.assertTrue( + issubclass(self.driver.DatabaseError,self.driver.Error) + ) + self.assertTrue( + issubclass(self.driver.OperationalError,self.driver.Error) + ) + self.assertTrue( + issubclass(self.driver.IntegrityError,self.driver.Error) + ) + self.assertTrue( + issubclass(self.driver.InternalError,self.driver.Error) + ) + self.assertTrue( + issubclass(self.driver.ProgrammingError,self.driver.Error) + ) + self.assertTrue( + issubclass(self.driver.NotSupportedError,self.driver.Error) + ) + + def test_ExceptionsAsConnectionAttributes(self): + # OPTIONAL EXTENSION + # Test for the optional DB API 2.0 extension, where the exceptions + # are exposed as attributes on the Connection object + # I figure this optional extension will be implemented by any + # driver author who is using this test suite, so it is enabled + # by default. + con = self._connect() + drv = self.driver + self.assertTrue(con.Warning is drv.Warning) + self.assertTrue(con.Error is drv.Error) + self.assertTrue(con.InterfaceError is drv.InterfaceError) + self.assertTrue(con.DatabaseError is drv.DatabaseError) + self.assertTrue(con.OperationalError is drv.OperationalError) + self.assertTrue(con.IntegrityError is drv.IntegrityError) + self.assertTrue(con.InternalError is drv.InternalError) + self.assertTrue(con.ProgrammingError is drv.ProgrammingError) + self.assertTrue(con.NotSupportedError is drv.NotSupportedError) + + + def test_commit(self): + con = self._connect() + try: + # Commit must work, even if it doesn't do anything + con.commit() + finally: + con.close() + + def test_rollback(self): + con = self._connect() + # If rollback is defined, it should either work or throw + # the documented exception + if hasattr(con,'rollback'): + try: + con.rollback() + except self.driver.NotSupportedError: + pass + + def test_cursor(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + def test_cursor_isolation(self): + con = self._connect() + try: + # Make sure cursors created from the same connection have + # the documented transaction isolation level + cur1 = con.cursor() + cur2 = con.cursor() + self.executeDDL1(cur1) + cur1.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + cur2.execute("select name from %sbooze" % self.table_prefix) + booze = cur2.fetchall() + self.assertEqual(len(booze),1) + self.assertEqual(len(booze[0]),1) + self.assertEqual(booze[0][0],'Victoria Bitter') + finally: + con.close() + + def test_description(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.description,None, + 'cursor.description should be none after executing a ' + 'statement that can return no rows (such as DDL)' + ) + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(len(cur.description),1, + 'cursor.description describes too many columns' + ) + self.assertEqual(len(cur.description[0]),7, + 'cursor.description[x] tuples must have 7 elements' + ) + self.assertEqual(cur.description[0][0].lower(),'name', + 'cursor.description[x][0] must return column name' + ) + self.assertEqual(cur.description[0][1],self.driver.STRING, + 'cursor.description[x][1] must return column type. Got %r' + % cur.description[0][1] + ) + + # Make sure self.description gets reset + self.executeDDL2(cur) + self.assertEqual(cur.description,None, + 'cursor.description not being set to None when executing ' + 'no-result statements (eg. DDL)' + ) + finally: + con.close() + + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount should be -1 after executing no-result ' + 'statements' + ) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertTrue(cur.rowcount in (-1,1), + 'cursor.rowcount should == number or rows inserted, or ' + 'set to -1 after executing an insert statement' + ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertTrue(cur.rowcount in (-1,1), + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) + self.executeDDL2(cur) + self.assertEqual(cur.rowcount,-1, + 'cursor.rowcount not being reset to -1 after executing ' + 'no-result statements' + ) + finally: + con.close() + + lower_func = 'lower' + def test_callproc(self): + con = self._connect() + try: + cur = con.cursor() + if self.lower_func and hasattr(cur,'callproc'): + r = cur.callproc(self.lower_func,('FOO',)) + self.assertEqual(len(r),1) + self.assertEqual(r[0],'FOO') + r = cur.fetchall() + self.assertEqual(len(r),1,'callproc produced no result set') + self.assertEqual(len(r[0]),1, + 'callproc produced invalid result set' + ) + self.assertEqual(r[0][0],'foo', + 'callproc produced invalid results' + ) + finally: + con.close() + + def test_close(self): + con = self._connect() + try: + cur = con.cursor() + finally: + con.close() + + # cursor.execute should raise an Error if called after connection + # closed + self.assertRaises(self.driver.Error,self.executeDDL1,cur) + + # connection.commit should raise an Error if called after connection' + # closed.' + self.assertRaises(self.driver.Error,con.commit) + + # connection.close should raise an Error if called more than once + self.assertRaises(self.driver.Error,con.close) + + def test_execute(self): + con = self._connect() + try: + cur = con.cursor() + self._paraminsert(cur) + finally: + con.close() + + def _paraminsert(self,cur): + self.executeDDL1(cur) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertTrue(cur.rowcount in (-1,1)) + + if self.driver.paramstyle == 'qmark': + cur.execute( + 'insert into %sbooze values (?)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'numeric': + cur.execute( + 'insert into %sbooze values (:1)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'named': + cur.execute( + 'insert into %sbooze values (:beer)' % self.table_prefix, + {'beer':"Cooper's"} + ) + elif self.driver.paramstyle == 'format': + cur.execute( + 'insert into %sbooze values (%%s)' % self.table_prefix, + ("Cooper's",) + ) + elif self.driver.paramstyle == 'pyformat': + cur.execute( + 'insert into %sbooze values (%%(beer)s)' % self.table_prefix, + {'beer':"Cooper's"} + ) + else: + self.fail('Invalid paramstyle') + self.assertTrue(cur.rowcount in (-1,1)) + + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2,'cursor.fetchall returned too few rows') + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Cooper's", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + self.assertEqual(beers[1],"Victoria Bitter", + 'cursor.fetchall retrieved incorrect data, or data inserted ' + 'incorrectly' + ) + + def test_executemany(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + largs = [ ("Cooper's",) , ("Boag's",) ] + margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ] + if self.driver.paramstyle == 'qmark': + cur.executemany( + 'insert into %sbooze values (?)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'numeric': + cur.executemany( + 'insert into %sbooze values (:1)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'named': + cur.executemany( + 'insert into %sbooze values (:beer)' % self.table_prefix, + margs + ) + elif self.driver.paramstyle == 'format': + cur.executemany( + 'insert into %sbooze values (%%s)' % self.table_prefix, + largs + ) + elif self.driver.paramstyle == 'pyformat': + cur.executemany( + 'insert into %sbooze values (%%(beer)s)' % ( + self.table_prefix + ), + margs + ) + else: + self.fail('Unknown paramstyle') + self.assertTrue(cur.rowcount in (-1,2), + 'insert using cursor.executemany set cursor.rowcount to ' + 'incorrect value %r' % cur.rowcount + ) + cur.execute('select name from %sbooze' % self.table_prefix) + res = cur.fetchall() + self.assertEqual(len(res),2, + 'cursor.fetchall retrieved incorrect number of rows' + ) + beers = [res[0][0],res[1][0]] + beers.sort() + self.assertEqual(beers[0],"Boag's",'incorrect data retrieved') + self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved') + finally: + con.close() + + def test_fetchone(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error,cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannnot return rows + self.executeDDL1(cur) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if a query retrieves ' + 'no rows' + ) + self.assertTrue(cur.rowcount in (-1,0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannnot return rows + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) + self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchone() + self.assertEqual(len(r),1, + 'cursor.fetchone should have retrieved a single row' + ) + self.assertEqual(r[0],'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data' + ) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if no more rows available' + ) + self.assertTrue(cur.rowcount in (-1,1)) + finally: + con.close() + + samples = [ + 'Carlton Cold', + 'Carlton Draft', + 'Mountain Goat', + 'Redback', + 'Victoria Bitter', + 'XXXX' + ] + + def _populate(self): + ''' Return a list of sql commands to setup the DB for the fetch + tests. + ''' + populate = [ + "insert into %sbooze values ('%s')" % (self.table_prefix,s) + for s in self.samples + ] + return populate + + def test_fetchmany(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchmany should raise an Error if called without + #issuing a query + self.assertRaises(self.driver.Error,cur.fetchmany,4) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() + self.assertEqual(len(r),1, + 'cursor.fetchmany retrieved incorrect number of rows, ' + 'default of arraysize is one.' + ) + cur.arraysize=10 + r = cur.fetchmany(3) # Should get 3 rows + self.assertEqual(len(r),3, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should get 2 more + self.assertEqual(len(r),2, + 'cursor.fetchmany retrieved incorrect number of rows' + ) + r = cur.fetchmany(4) # Should be an empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence after ' + 'results are exhausted' + ) + self.assertTrue(cur.rowcount in (-1,6)) + + # Same as above, using cursor.arraysize + cur.arraysize=4 + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchmany() # Should get 4 rows + self.assertEqual(len(r),4, + 'cursor.arraysize not being honoured by fetchmany' + ) + r = cur.fetchmany() # Should get 2 more + self.assertEqual(len(r),2) + r = cur.fetchmany() # Should be an empty sequence + self.assertEqual(len(r),0) + self.assertTrue(cur.rowcount in (-1,6)) + + cur.arraysize=6 + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchmany() # Should get all rows + self.assertTrue(cur.rowcount in (-1,6)) + self.assertEqual(len(rows),6) + self.assertEqual(len(rows),6) + rows = [r[0] for r in rows] + rows.sort() + + # Make sure we get the right data back out + for i in range(0,6): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved by cursor.fetchmany' + ) + + rows = cur.fetchmany() # Should return an empty list + self.assertEqual(len(rows),0, + 'cursor.fetchmany should return an empty sequence if ' + 'called after the whole result set has been fetched' + ) + self.assertTrue(cur.rowcount in (-1,6)) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + r = cur.fetchmany() # Should get empty sequence + self.assertEqual(len(r),0, + 'cursor.fetchmany should return an empty sequence if ' + 'query retrieved no rows' + ) + self.assertTrue(cur.rowcount in (-1,0)) + + finally: + con.close() + + def test_fetchall(self): + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows + self.assertRaises(self.driver.Error,cur.fetchall) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchall() + self.assertTrue(cur.rowcount in (-1,len(self.samples))) + self.assertEqual(len(rows),len(self.samples), + 'cursor.fetchall did not retrieve all rows' + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'cursor.fetchall retrieved incorrect rows' + ) + rows = cur.fetchall() + self.assertEqual( + len(rows),0, + 'cursor.fetchall should return an empty list if called ' + 'after the whole result set has been fetched' + ) + self.assertTrue(cur.rowcount in (-1,len(self.samples))) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + rows = cur.fetchall() + self.assertTrue(cur.rowcount in (-1,0)) + self.assertEqual(len(rows),0, + 'cursor.fetchall should return an empty list if ' + 'a select query returns no rows' + ) + + finally: + con.close() + + def test_mixedfetch(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows1 = cur.fetchone() + rows23 = cur.fetchmany(2) + rows4 = cur.fetchone() + rows56 = cur.fetchall() + self.assertTrue(cur.rowcount in (-1,6)) + self.assertEqual(len(rows23),2, + 'fetchmany returned incorrect number of rows' + ) + self.assertEqual(len(rows56),2, + 'fetchall returned incorrect number of rows' + ) + + rows = [rows1[0]] + rows.extend([rows23[0][0],rows23[1][0]]) + rows.append(rows4[0]) + rows.extend([rows56[0][0],rows56[1][0]]) + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'incorrect data retrieved or inserted' + ) + finally: + con.close() + + def help_nextset_setUp(self,cur): + ''' Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + ''' + raise NotImplementedError,'Helper not implemented' + #sql=""" + # create procedure deleteme as + # begin + # select count(*) from booze + # select name from booze + # end + #""" + #cur.execute(sql) + + def help_nextset_tearDown(self,cur): + 'If cleaning up is needed after nextSetTest' + raise NotImplementedError,'Helper not implemented' + #cur.execute("drop procedure deleteme") + + def test_nextset(self): + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur,'nextset'): + return + + try: + self.executeDDL1(cur) + sql=self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc('deleteme') + numberofrows=cur.fetchone() + assert numberofrows[0]== len(self.samples) + assert cur.nextset() + names=cur.fetchall() + assert len(names) == len(self.samples) + s=cur.nextset() + assert s == None,'No more return sets, should return None' + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + def test_nextset(self): + raise NotImplementedError,'Drivers need to override this test' + + def test_arraysize(self): + # Not much here - rest of the tests for this are in test_fetchmany + con = self._connect() + try: + cur = con.cursor() + self.assertTrue(hasattr(cur,'arraysize'), + 'cursor.arraysize must be defined' + ) + finally: + con.close() + + def test_setinputsizes(self): + con = self._connect() + try: + cur = con.cursor() + cur.setinputsizes( (25,) ) + self._paraminsert(cur) # Make sure cursor still works + finally: + con.close() + + def test_setoutputsize_basic(self): + # Basic test is to make sure setoutputsize doesn't blow up + con = self._connect() + try: + cur = con.cursor() + cur.setoutputsize(1000) + cur.setoutputsize(2000,0) + self._paraminsert(cur) # Make sure the cursor still works + finally: + con.close() + + def test_setoutputsize(self): + # Real test for setoutputsize is driver dependant + raise NotImplementedError,'Driver need to override this test' + + def test_None(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) + cur.execute('insert into %sbooze values (NULL)' % self.table_prefix) + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchall() + self.assertEqual(len(r),1) + self.assertEqual(len(r[0]),1) + self.assertEqual(r[0][0],None,'NULL value not returned as None') + finally: + con.close() + + def test_Date(self): + d1 = self.driver.Date(2002,12,25) + d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(d1),str(d2)) + + def test_Time(self): + t1 = self.driver.Time(13,45,30) + t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0))) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Timestamp(self): + t1 = self.driver.Timestamp(2002,12,25,13,45,30) + t2 = self.driver.TimestampFromTicks( + time.mktime((2002,12,25,13,45,30,0,0,0)) + ) + # Can we assume this? API doesn't specify, but it seems implied + # self.assertEqual(str(t1),str(t2)) + + def test_Binary(self): + b = self.driver.Binary('Something') + b = self.driver.Binary('') + + def test_STRING(self): + self.assertTrue(hasattr(self.driver,'STRING'), + 'module.STRING must be defined' + ) + + def test_BINARY(self): + self.assertTrue(hasattr(self.driver,'BINARY'), + 'module.BINARY must be defined.' + ) + + def test_NUMBER(self): + self.assertTrue(hasattr(self.driver,'NUMBER'), + 'module.NUMBER must be defined.' + ) + + def test_DATETIME(self): + self.assertTrue(hasattr(self.driver,'DATETIME'), + 'module.DATETIME must be defined.' + ) + + def test_ROWID(self): + self.assertTrue(hasattr(self.driver,'ROWID'), + 'module.ROWID must be defined.' + ) + diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py new file mode 100644 index 00000000000..e0bc93439c2 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py @@ -0,0 +1,115 @@ +#!/usr/bin/env python +import capabilities +import unittest +import pymysql +from pymysql.tests import base +import warnings + +warnings.filterwarnings('error') + +class test_MySQLdb(capabilities.DatabaseTest): + + db_module = pymysql + connect_args = () + connect_kwargs = base.PyMySQLTestCase.databases[0].copy() + connect_kwargs.update(dict(read_default_file='~/.my.cnf', + use_unicode=True, + charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + + create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8" + leak_test = False + + def quote_identifier(self, ident): + return "`%s`" % ident + + def test_TIME(self): + from datetime import timedelta + def generator(row,col): + return timedelta(0, row*8000) + self.check_data_integrity( + ('col1 TIME',), + generator) + + def test_TINYINT(self): + # Number data + def generator(row,col): + v = (row*row) % 256 + if v > 127: + v = v-256 + return v + self.check_data_integrity( + ('col1 TINYINT',), + generator) + + def test_stored_procedures(self): + db = self.connection + c = self.cursor + try: + self.create_table(('pos INT', 'tree CHAR(20)')) + c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table, + list(enumerate('ash birch cedar larch pine'.split()))) + db.commit() + + c.execute(""" + CREATE PROCEDURE test_sp(IN t VARCHAR(255)) + BEGIN + SELECT pos FROM %s WHERE tree = t; + END + """ % self.table) + db.commit() + + c.callproc('test_sp', ('larch',)) + rows = c.fetchall() + self.assertEquals(len(rows), 1) + self.assertEquals(rows[0][0], 3) + c.nextset() + finally: + c.execute("DROP PROCEDURE IF EXISTS test_sp") + c.execute('drop table %s' % (self.table)) + + def test_small_CHAR(self): + # Character data + def generator(row,col): + i = ((row+1)*(col+1)+62)%256 + if i == 62: return '' + if i == 63: return None + return chr(i) + self.check_data_integrity( + ('col1 char(1)','col2 char(1)'), + generator) + + def test_bug_2671682(self): + from pymysql.constants import ER + try: + self.cursor.execute("describe some_non_existent_table"); + except self.connection.ProgrammingError, msg: + self.assertTrue(msg.args[0] == ER.NO_SUCH_TABLE) + + def test_insert_values(self): + from pymysql.cursors import insert_values + query = """INSERT FOO (a, b, c) VALUES (a, b, c)""" + matched = insert_values.search(query) + self.assertTrue(matched) + values = matched.group(1) + self.assertTrue(values == "(a, b, c)") + + def test_ping(self): + self.connection.ping() + + def test_literal_int(self): + self.assertTrue("2" == self.connection.literal(2)) + + def test_literal_float(self): + self.assertTrue("3.1415" == self.connection.literal(3.1415)) + + def test_literal_string(self): + self.assertTrue("'foo'" == self.connection.literal("foo")) + + +if __name__ == '__main__': + if test_MySQLdb.leak_test: + import gc + gc.enable() + gc.set_debug(gc.DEBUG_LEAK) + unittest.main() + diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py new file mode 100644 index 00000000000..83c002fdf39 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py @@ -0,0 +1,205 @@ +#!/usr/bin/env python +import dbapi20 +import unittest +import pymysql +from pymysql.tests import base + +class test_MySQLdb(dbapi20.DatabaseAPI20Test): + driver = pymysql + connect_args = () + connect_kw_args = base.PyMySQLTestCase.databases[0].copy() + connect_kw_args.update(dict(read_default_file='~/.my.cnf', + charset='utf8', + sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL")) + + def test_setoutputsize(self): pass + def test_setoutputsize_basic(self): pass + def test_nextset(self): pass + + """The tests on fetchone and fetchall and rowcount bogusly + test for an exception if the statement cannot return a + result set. MySQL always returns a result set; it's just that + some things return empty result sets.""" + + def test_fetchall(self): + con = self._connect() + try: + cur = con.cursor() + # cursor.fetchall should raise an Error if called + # without executing a query that may return rows (such + # as a select) + self.assertRaises(self.driver.Error, cur.fetchall) + + self.executeDDL1(cur) + for sql in self._populate(): + cur.execute(sql) + + # cursor.fetchall should raise an Error if called + # after executing a a statement that cannot return rows +## self.assertRaises(self.driver.Error,cur.fetchall) + + cur.execute('select name from %sbooze' % self.table_prefix) + rows = cur.fetchall() + self.assertTrue(cur.rowcount in (-1,len(self.samples))) + self.assertEqual(len(rows),len(self.samples), + 'cursor.fetchall did not retrieve all rows' + ) + rows = [r[0] for r in rows] + rows.sort() + for i in range(0,len(self.samples)): + self.assertEqual(rows[i],self.samples[i], + 'cursor.fetchall retrieved incorrect rows' + ) + rows = cur.fetchall() + self.assertEqual( + len(rows),0, + 'cursor.fetchall should return an empty list if called ' + 'after the whole result set has been fetched' + ) + self.assertTrue(cur.rowcount in (-1,len(self.samples))) + + self.executeDDL2(cur) + cur.execute('select name from %sbarflys' % self.table_prefix) + rows = cur.fetchall() + self.assertTrue(cur.rowcount in (-1,0)) + self.assertEqual(len(rows),0, + 'cursor.fetchall should return an empty list if ' + 'a select query returns no rows' + ) + + finally: + con.close() + + def test_fetchone(self): + con = self._connect() + try: + cur = con.cursor() + + # cursor.fetchone should raise an Error if called before + # executing a select-type query + self.assertRaises(self.driver.Error,cur.fetchone) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannnot return rows + self.executeDDL1(cur) +## self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + self.assertEqual(cur.fetchone(),None, + 'cursor.fetchone should return None if a query retrieves ' + 'no rows' + ) + self.assertTrue(cur.rowcount in (-1,0)) + + # cursor.fetchone should raise an Error if called after + # executing a query that cannnot return rows + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) +## self.assertRaises(self.driver.Error,cur.fetchone) + + cur.execute('select name from %sbooze' % self.table_prefix) + r = cur.fetchone() + self.assertEqual(len(r),1, + 'cursor.fetchone should have retrieved a single row' + ) + self.assertEqual(r[0],'Victoria Bitter', + 'cursor.fetchone retrieved incorrect data' + ) +## self.assertEqual(cur.fetchone(),None, +## 'cursor.fetchone should return None if no more rows available' +## ) + self.assertTrue(cur.rowcount in (-1,1)) + finally: + con.close() + + # Same complaint as for fetchall and fetchone + def test_rowcount(self): + con = self._connect() + try: + cur = con.cursor() + self.executeDDL1(cur) +## self.assertEqual(cur.rowcount,-1, +## 'cursor.rowcount should be -1 after executing no-result ' +## 'statements' +## ) + cur.execute("insert into %sbooze values ('Victoria Bitter')" % ( + self.table_prefix + )) +## self.assertTrue(cur.rowcount in (-1,1), +## 'cursor.rowcount should == number or rows inserted, or ' +## 'set to -1 after executing an insert statement' +## ) + cur.execute("select name from %sbooze" % self.table_prefix) + self.assertTrue(cur.rowcount in (-1,1), + 'cursor.rowcount should == number of rows returned, or ' + 'set to -1 after executing a select statement' + ) + self.executeDDL2(cur) +## self.assertEqual(cur.rowcount,-1, +## 'cursor.rowcount not being reset to -1 after executing ' +## 'no-result statements' +## ) + finally: + con.close() + + def test_callproc(self): + pass # performed in test_MySQL_capabilities + + def help_nextset_setUp(self,cur): + ''' Should create a procedure called deleteme + that returns two result sets, first the + number of rows in booze then "name from booze" + ''' + sql=""" + create procedure deleteme() + begin + select count(*) from %(tp)sbooze; + select name from %(tp)sbooze; + end + """ % dict(tp=self.table_prefix) + cur.execute(sql) + + def help_nextset_tearDown(self,cur): + 'If cleaning up is needed after nextSetTest' + cur.execute("drop procedure deleteme") + + def test_nextset(self): + from warnings import warn + con = self._connect() + try: + cur = con.cursor() + if not hasattr(cur,'nextset'): + return + + try: + self.executeDDL1(cur) + sql=self._populate() + for sql in self._populate(): + cur.execute(sql) + + self.help_nextset_setUp(cur) + + cur.callproc('deleteme') + numberofrows=cur.fetchone() + assert numberofrows[0]== len(self.samples) + assert cur.nextset() + names=cur.fetchall() + assert len(names) == len(self.samples) + s=cur.nextset() + if s: + empty = cur.fetchall() + self.assertEquals(len(empty), 0, + "non-empty result set after other result sets") + #warn("Incompatibility: MySQL returns an empty result set for the CALL itself", + # Warning) + #assert s == None,'No more return sets, should return None' + finally: + self.help_nextset_tearDown(cur) + + finally: + con.close() + + +if __name__ == '__main__': + unittest.main() diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py new file mode 100644 index 00000000000..f49369cb4f7 --- /dev/null +++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py @@ -0,0 +1,90 @@ +import unittest + +import pymysql +_mysql = pymysql +from pymysql.constants import FIELD_TYPE +from pymysql.tests import base + + +class TestDBAPISet(unittest.TestCase): + def test_set_equality(self): + self.assertTrue(pymysql.STRING == pymysql.STRING) + + def test_set_inequality(self): + self.assertTrue(pymysql.STRING != pymysql.NUMBER) + + def test_set_equality_membership(self): + self.assertTrue(FIELD_TYPE.VAR_STRING == pymysql.STRING) + + def test_set_inequality_membership(self): + self.assertTrue(FIELD_TYPE.DATE != pymysql.STRING) + + +class CoreModule(unittest.TestCase): + """Core _mysql module features.""" + + def test_NULL(self): + """Should have a NULL constant.""" + self.assertEqual(_mysql.NULL, 'NULL') + + def test_version(self): + """Version information sanity.""" + self.assertTrue(isinstance(_mysql.__version__, str)) + + self.assertTrue(isinstance(_mysql.version_info, tuple)) + self.assertEqual(len(_mysql.version_info), 5) + + def test_client_info(self): + self.assertTrue(isinstance(_mysql.get_client_info(), str)) + + def test_thread_safe(self): + self.assertTrue(isinstance(_mysql.thread_safe(), int)) + + +class CoreAPI(unittest.TestCase): + """Test _mysql interaction internals.""" + + def setUp(self): + kwargs = base.PyMySQLTestCase.databases[0].copy() + kwargs["read_default_file"] = "~/.my.cnf" + self.conn = _mysql.connect(**kwargs) + + def tearDown(self): + self.conn.close() + + def test_thread_id(self): + tid = self.conn.thread_id() + self.assertTrue(isinstance(tid, int), + "thread_id didn't return an int.") + + self.assertRaises(TypeError, self.conn.thread_id, ('evil',), + "thread_id shouldn't accept arguments.") + + def test_affected_rows(self): + self.assertEquals(self.conn.affected_rows(), 0, + "Should return 0 before we do anything.") + + + #def test_debug(self): + ## FIXME Only actually tests if you lack SUPER + #self.assertRaises(pymysql.OperationalError, + #self.conn.dump_debug_info) + + def test_charset_name(self): + self.assertTrue(isinstance(self.conn.character_set_name(), str), + "Should return a string.") + + def test_host_info(self): + self.assertTrue(isinstance(self.conn.get_host_info(), str), + "Should return a string.") + + def test_proto_info(self): + self.assertTrue(isinstance(self.conn.get_proto_info(), int), + "Should return an int.") + + def test_server_info(self): + self.assertTrue(isinstance(self.conn.get_server_info(), basestring), + "Should return an str.") + +if __name__ == "__main__": + unittest.main() diff --git a/tools/marvin/marvin/pymysql/times.py b/tools/marvin/marvin/pymysql/times.py new file mode 100644 index 00000000000..c47db09eb9c --- /dev/null +++ b/tools/marvin/marvin/pymysql/times.py @@ -0,0 +1,16 @@ +from time import localtime +from datetime import date, datetime, time, timedelta + +Date = date +Time = time +TimeDelta = timedelta +Timestamp = datetime + +def DateFromTicks(ticks): + return date(*localtime(ticks)[:3]) + +def TimeFromTicks(ticks): + return time(*localtime(ticks)[3:6]) + +def TimestampFromTicks(ticks): + return datetime(*localtime(ticks)[:6]) diff --git a/tools/marvin/marvin/pymysql/util.py b/tools/marvin/marvin/pymysql/util.py new file mode 100644 index 00000000000..cc622e57b74 --- /dev/null +++ b/tools/marvin/marvin/pymysql/util.py @@ -0,0 +1,19 @@ +import struct + +def byte2int(b): + if isinstance(b, int): + return b + else: + return struct.unpack("!B", b)[0] + +def int2byte(i): + return struct.pack("!B", i) + +def join_bytes(bs): + if len(bs) == 0: + return "" + else: + rv = bs[0] + for b in bs[1:]: + rv += b + return rv diff --git a/tools/marvin/marvin/remoteSSHClient.py b/tools/marvin/marvin/remoteSSHClient.py new file mode 100644 index 00000000000..806de7a9eb8 --- /dev/null +++ b/tools/marvin/marvin/remoteSSHClient.py @@ -0,0 +1,36 @@ +import paramiko +import cloudstackException +class remoteSSHClient(object): + def __init__(self, host, port, user, passwd, timeout=120): + self.host = host + self.port = port + self.user = user + self.passwd = passwd + self.ssh = paramiko.SSHClient() + self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + try: + self.ssh.connect(str(host),int(port), user, passwd, timeout=timeout) + except paramiko.SSHException, sshex: + raise cloudstackException.InvalidParameterException(repr(sshex)) + + def execute(self, command): + stdin, stdout, stderr = self.ssh.exec_command(command) + output = stdout.readlines() + errors = stderr.readlines() + results = [] + if output is not None and len(output) == 0: + if errors is not None and len(errors) > 0: + for error in errors: + results.append(error.rstrip()) + + else: + for strOut in output: + results.append(strOut.rstrip()) + + return results + + +if __name__ == "__main__": + ssh = remoteSSHClient("192.168.137.2", 22, "root", "password") + print ssh.execute("ls -l") + print ssh.execute("rm x") diff --git a/tools/marvin/marvin/sandbox/README.txt b/tools/marvin/marvin/sandbox/README.txt new file mode 100644 index 00000000000..7efc190baf6 --- /dev/null +++ b/tools/marvin/marvin/sandbox/README.txt @@ -0,0 +1,19 @@ +Welcome to the marvin sandbox +---------------------------------- + +In here you should find a few common deployment models of CloudStack that you +can configure with properties files to suit your own deployment. One deployment +model for each of - advanced zone, basic zone and a demo are given. + +$ ls - +basic/ +advanced/ +demo/ + +Each property file is divided into logical sections and should be familiar to +those who have deployed CloudStack before. Once you have your properties file +you will have to create a JSON configuration of your deployment using the +python script provided in the respective folder. + +The demo files are from the tutorial for testing with python that can be found +on the wiki.cloudstack.org diff --git a/tools/marvin/marvin/sandbox/advanced/advanced_env.py b/tools/marvin/marvin/sandbox/advanced/advanced_env.py new file mode 100644 index 00000000000..25e83dbb0f4 --- /dev/null +++ b/tools/marvin/marvin/sandbox/advanced/advanced_env.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python + +''' +############################################################ +# Experimental state of scripts +# * Need to be reviewed +# * Only a sandbox +############################################################ +''' + +from ConfigParser import SafeConfigParser +from optparse import OptionParser +from configGenerator import * +import random + + +def getGlobalSettings(config): + for k, v in dict(config.items('globals')).iteritems(): + cfg = configuration() + cfg.name = k + cfg.value = v + yield cfg + + +def describeResources(config): + zs = cloudstackConfiguration() + + z = zone() + z.dns1 = config.get('environment', 'dns') + z.internaldns1 = config.get('environment', 'dns') + z.name = 'Sandbox-%s'%(config.get('environment', 'hypervisor')) + z.networktype = 'Advanced' + z.guestcidraddress = '10.1.1.0/24' + + p = pod() + p.name = 'POD0' + p.gateway = config.get('cloudstack', 'private.gateway') + p.startip = config.get('cloudstack', 'private.pod.startip') + p.endip = config.get('cloudstack', 'private.pod.endip') + p.netmask = '255.255.255.0' + + v = iprange() + v.gateway = config.get('cloudstack', 'public.gateway') + v.startip = config.get('cloudstack', 'public.vlan.startip') + v.endip = config.get('cloudstack', 'public.vlan.endip') + v.netmask = '255.255.255.0' + v.vlan = config.get('cloudstack', 'public.vlan') + z.ipranges.append(v) + + c = cluster() + c.clustername = 'C0' + c.hypervisor = config.get('environment', 'hypervisor') + c.clustertype = 'CloudManaged' + + h = host() + h.username = 'root' + h.password = 'password' + h.url = 'http://%s'%(config.get('cloudstack', 'host')) + c.hosts.append(h) + + ps = primaryStorage() + ps.name = 'PS0' + ps.url = config.get('cloudstack', 'pool') + c.primaryStorages.append(ps) + + p.clusters.append(c) + z.pods.append(p) + + secondary = secondaryStorage() + secondary.url = config.get('cloudstack', 'secondary') + z.secondaryStorages.append(secondary) + + '''Add zone''' + zs.zones.append(z) + + '''Add mgt server''' + mgt = managementServer() + mgt.mgtSvrIp = config.get('environment', 'mshost') + zs.mgtSvr.append(mgt) + + '''Add a database''' + db = dbServer() + db.dbSvr = config.get('environment', 'database') + zs.dbSvr = db + + '''Add some configuration''' + [zs.globalConfig.append(cfg) for cfg in getGlobalSettings(config)] + + ''''add loggers''' + testClientLogger = logger() + testClientLogger.name = 'TestClient' + testClientLogger.file = '/var/log/testclient.log' + + testCaseLogger = logger() + testCaseLogger.name = 'TestCase' + testCaseLogger.file = '/var/log/testcase.log' + + zs.logger.append(testClientLogger) + zs.logger.append(testCaseLogger) + return zs + + +if __name__ == '__main__': + parser = OptionParser() + parser.add_option('-i', '--input', action='store', default='setup.properties', \ + dest='input', help='file containing environment setup information') + parser.add_option('-o', '--output', action='store', default='./sandbox.cfg', \ + dest='output', help='path where environment json will be generated') + + + (opts, args) = parser.parse_args() + + cfg_parser = SafeConfigParser() + cfg_parser.read(opts.input) + + cfg = describeResources(cfg_parser) + generate_setup_config(cfg, opts.output) diff --git a/tools/marvin/marvin/sandbox/advanced/setup.properties b/tools/marvin/marvin/sandbox/advanced/setup.properties new file mode 100644 index 00000000000..48b082e4f37 --- /dev/null +++ b/tools/marvin/marvin/sandbox/advanced/setup.properties @@ -0,0 +1,36 @@ +[globals] +expunge.delay=60 +expunge.interval=60 +storage.cleanup.interval=300 +account.cleanup.interval=600 +expunge.workers=3 +workers=10 +use.user.concentrated.pod.allocation=false +vm.allocation.algorithm=random +vm.op.wait.interval=5 +guest.domain.suffix=sandbox.kvm +instance.name=QA +direct.agent.load.size=1000 +default.page.size=10000 +check.pod.cidrs=true +secstorage.allowed.internal.sites=10.147.28.0/24 +[environment] +dns=10.147.28.6 +mshost=10.147.29.111 +database=10.147.29.111 +[cloudstack] +zone.vlan=675-679 +#pod configuration +private.gateway=10.147.29.1 +private.pod.startip=10.147.29.150 +private.pod.endip=10.147.29.159 +#public vlan range +public.gateway=10.147.31.1 +public.vlan=31 +public.vlan.startip=10.147.31.150 +public.vlan.endip=10.147.31.159 +#hosts +host=10.147.29.58 +#pools +pool=nfs://10.147.28.6:/export/home/sandbox/kamakura +secondary=nfs://10.147.28.6:/export/home/sandbox/sstor diff --git a/tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py b/tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py new file mode 100644 index 00000000000..bae181ca693 --- /dev/null +++ b/tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python +try: + import unittest2 as unittest +except ImportError: + import unittest + +import random +import hashlib +from cloudstackTestCase import * +import remoteSSHClient + +class SampleScenarios(cloudstackTestCase): + ''' + ''' + def setUp(self): + pass + + + def tearDown(self): + pass + + + def test_1_createAccounts(self, numberOfAccounts=2): + ''' + Create a bunch of user accounts + ''' + mdf = hashlib.md5() + mdf.update('password') + mdf_pass = mdf.hexdigest() + api = self.testClient.getApiClient() + for i in range(1, numberOfAccounts + 1): + acct = createAccount.createAccountCmd() + acct.accounttype = 0 + acct.firstname = 'user' + str(i) + acct.lastname = 'user' + str(i) + acct.password = mdf_pass + acct.username = 'user' + str(i) + acct.email = 'user@example.com' + acct.account = 'user' + str(i) + acct.domainid = 1 + acctResponse = api.createAccount(acct) + self.debug("successfully created account: %s, user: %s, id: %s"%(acctResponse.account, acctResponse.username, acctResponse.id)) + + + def test_2_createServiceOffering(self): + apiClient = self.testClient.getApiClient() + createSOcmd=createServiceOffering.createServiceOfferingCmd() + createSOcmd.name='Sample SO' + createSOcmd.displaytext='Sample SO' + createSOcmd.storagetype='shared' + createSOcmd.cpunumber=1 + createSOcmd.cpuspeed=100 + createSOcmd.memory=128 + createSOcmd.offerha='false' + createSOresponse = apiClient.createServiceOffering(createSOcmd) + return createSOresponse.id + + def deployCmd(self, account, service): + deployVmCmd = deployVirtualMachine.deployVirtualMachineCmd() + deployVmCmd.zoneid = 1 + deployVmCmd.account=account + deployVmCmd.domainid=1 + deployVmCmd.templateid=2 + deployVmCmd.serviceofferingid=service + return deployVmCmd + + def listVmsInAccountCmd(self, acct): + api = self.testClient.getApiClient() + listVmCmd = listVirtualMachines.listVirtualMachinesCmd() + listVmCmd.account = acct + listVmCmd.zoneid = 1 + listVmCmd.domainid = 1 + listVmResponse = api.listVirtualMachines(listVmCmd) + return listVmResponse + + + def destroyVmCmd(self, key): + api = self.testClient.getApiClient() + destroyVmCmd = destroyVirtualMachine.destroyVirtualMachineCmd() + destroyVmCmd.id = key + api.destroyVirtualMachine(destroyVmCmd) + + + def test_3_stressDeploy(self): + ''' + Deploy 5 Vms in each account + ''' + service_id = self.test_2_createServiceOffering() + api = self.testClient.getApiClient() + for acct in range(1, 5): + [api.deployVirtualMachine(self.deployCmd('user'+str(acct), service_id)) for x in range(0,5)] + + @unittest.skip("skipping destroys") + def test_4_stressDestroy(self): + ''' + Cleanup all Vms in every account + ''' + api = self.testClient.getApiClient() + for acct in range(1, 6): + for vm in self.listVmsInAccountCmd('user'+str(acct)): + if vm is not None: + self.destroyVmCmd(vm.id) + + @unittest.skip("skipping destroys") + def test_5_combineStress(self): + for i in range(0, 5): + self.test_3_stressDeploy() + self.test_4_stressDestroy() + + def deployN(self,nargs=300,batchsize=0): + ''' + Deploy Nargs number of VMs concurrently in batches of size {batchsize}. + When batchsize is 0 all Vms are deployed in one batch + VMs will be deployed in 5:2:6 ratio + ''' + cmds = [] + + if batchsize == 0: + self.testClient.submitCmdsAndWait(cmds) + else: + while len(z) > 0: + try: + newbatch = [cmds.pop() for b in range(batchsize)] #pop batchsize items + self.testClient.submitCmdsAndWait(newbatch) + except IndexError: + break diff --git a/tools/marvin/marvin/sandbox/basic/basic_env.py b/tools/marvin/marvin/sandbox/basic/basic_env.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/tools/marvin/marvin/sandbox/demo/README b/tools/marvin/marvin/sandbox/demo/README new file mode 100644 index 00000000000..650ea05df1c --- /dev/null +++ b/tools/marvin/marvin/sandbox/demo/README @@ -0,0 +1,4 @@ +Demo files for use with the tutorial on "Testing with Python". + +testDeployVM.py - to be run against a 2.2.y installation of management server +testSshDeployVM.py - to be run against a 3.0.x installation of management server diff --git a/tools/marvin/marvin/sandbox/demo/testDeployVM.py b/tools/marvin/marvin/sandbox/demo/testDeployVM.py new file mode 100644 index 00000000000..4afaee346ad --- /dev/null +++ b/tools/marvin/marvin/sandbox/demo/testDeployVM.py @@ -0,0 +1,98 @@ +#!/usr/bin/env python +# Copyright 2012 Citrix Systems, Inc. Licensed under the +# Apache License, Version 2.0 (the "License"); you may not use this +# file except in compliance with the License. Citrix Systems, Inc. +# reserves all rights not expressly granted by the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Automatically generated by addcopyright.py at 04/03/2012 + +from cloudstackTestCase import * + +import unittest +import hashlib +import random + +class TestDeployVm(cloudstackTestCase): + """ + This test deploys a virtual machine into a user account + using the small service offering and builtin template + """ + def setUp(self): + """ + CloudStack internally saves its passwords in md5 form and that is how we + specify it in the API. Python's hashlib library helps us to quickly hash + strings as follows + """ + mdf = hashlib.md5() + mdf.update('password') + mdf_pass = mdf.hexdigest() + + self.apiClient = self.testClient.getApiClient() #Get ourselves an API client + + self.acct = createAccount.createAccountCmd() #The createAccount command + self.acct.accounttype = 0 #We need a regular user. admins have accounttype=1 + self.acct.firstname = 'bugs' + self.acct.lastname = 'bunny' #What's up doc? + self.acct.password = mdf_pass #The md5 hashed password string + self.acct.username = 'bugs' + self.acct.email = 'bugs@rabbithole.com' + self.acct.account = 'bugs' + self.acct.domainid = 1 #The default ROOT domain + self.acctResponse = self.apiClient.createAccount(self.acct) + #And upon successful creation we'll log a helpful message in our logs + self.debug("successfully created account: %s, user: %s, id: \ + %s"%(self.acctResponse.account.account, \ + self.acctResponse.account.username, \ + self.acctResponse.account.id)) + + def test_DeployVm(self): + """ + Let's start by defining the attributes of our VM that we will be + deploying on CloudStack. We will be assuming a single zone is available + and is configured and all templates are Ready + + The hardcoded values are used only for brevity. + """ + deployVmCmd = deployVirtualMachine.deployVirtualMachineCmd() + deployVmCmd.zoneid = 1 + deployVmCmd.account = self.acct.account + deployVmCmd.domainid = self.acct.domainid + deployVmCmd.templateid = 2 + deployVmCmd.serviceofferingid = 1 + + deployVmResponse = self.apiClient.deployVirtualMachine(deployVmCmd) + self.debug("VM %s was deployed in the job %s"%(deployVmResponse.id, deployVmResponse.jobid)) + + # At this point our VM is expected to be Running. Let's find out what + # listVirtualMachines tells us about VMs in this account + + listVmCmd = listVirtualMachines.listVirtualMachinesCmd() + listVmCmd.id = deployVmResponse.id + listVmResponse = self.apiClient.listVirtualMachines(listVmCmd) + + self.assertNotEqual(len(listVmResponse), 0, "Check if the list API \ + returns a non-empty response") + + vm = listVmResponse[0] + + self.assertEqual(vm.id, deployVmResponse.id, "Check if the VM returned \ + is the same as the one we deployed") + + + self.assertEqual(vm.state, "Running", "Check if VM has reached \ + a state of running") + + def tearDown(self): + """ + And finally let us cleanup the resources we created by deleting the + account. All good unittests are atomic and rerunnable this way + """ + deleteAcct = deleteAccount.deleteAccountCmd() + deleteAcct.id = self.acctResponse.account.id + self.apiClient.deleteAccount(deleteAcct) diff --git a/tools/marvin/marvin/sandbox/demo/testSshDeployVM.py b/tools/marvin/marvin/sandbox/demo/testSshDeployVM.py new file mode 100644 index 00000000000..106f693fda1 --- /dev/null +++ b/tools/marvin/marvin/sandbox/demo/testSshDeployVM.py @@ -0,0 +1,143 @@ +#!/usr/bin/env python +# Copyright 2012 Citrix Systems, Inc. Licensed under the +# Apache License, Version 2.0 (the "License"); you may not use this +# file except in compliance with the License. Citrix Systems, Inc. +# reserves all rights not expressly granted by the License. +# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# Automatically generated by addcopyright.py at 04/03/2012 + + + +from cloudstackTestCase import * +from remoteSSHClient import remoteSSHClient + +import unittest +import hashlib +import random +import string + +class TestDeployVm(cloudstackTestCase): + """ + This test deploys a virtual machine into a user account + using the small service offering and builtin template + """ + @classmethod + def setUpClass(cls): + """ + CloudStack internally saves its passwords in md5 form and that is how we + specify it in the API. Python's hashlib library helps us to quickly hash + strings as follows + """ + mdf = hashlib.md5() + mdf.update('password') + mdf_pass = mdf.hexdigest() + acctName = 'bugs-'+''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(6)) #randomly generated account + + cls.apiClient = super(TestDeployVm, cls).getClsTestClient().getApiClient() + cls.acct = createAccount.createAccountCmd() #The createAccount command + cls.acct.accounttype = 0 #We need a regular user. admins have accounttype=1 + cls.acct.firstname = 'bugs' + cls.acct.lastname = 'bunny' #What's up doc? + cls.acct.password = mdf_pass #The md5 hashed password string + cls.acct.username = acctName + cls.acct.email = 'bugs@rabbithole.com' + cls.acct.account = acctName + cls.acct.domainid = 1 #The default ROOT domain + cls.acctResponse = cls.apiClient.createAccount(cls.acct) + + def setUpNAT(self, virtualmachineid): + listSourceNat = listPublicIpAddresses.listPublicIpAddressesCmd() + listSourceNat.account = self.acct.account + listSourceNat.domainid = self.acct.domainid + listSourceNat.issourcenat = True + + listsnatresponse = self.apiClient.listPublicIpAddresses(listSourceNat) + self.assertNotEqual(len(listsnatresponse), 0, "Found a source NAT for the acct %s"%self.acct.account) + + snatid = listsnatresponse[0].id + snatip = listsnatresponse[0].ipaddress + + try: + createFwRule = createFirewallRule.createFirewallRuleCmd() + createFwRule.cidrlist = "0.0.0.0/0" + createFwRule.startport = 22 + createFwRule.endport = 22 + createFwRule.ipaddressid = snatid + createFwRule.protocol = "tcp" + createfwresponse = self.apiClient.createFirewallRule(createFwRule) + + createPfRule = createPortForwardingRule.createPortForwardingRuleCmd() + createPfRule.privateport = 22 + createPfRule.publicport = 22 + createPfRule.virtualmachineid = virtualmachineid + createPfRule.ipaddressid = snatid + createPfRule.protocol = "tcp" + + createpfresponse = self.apiClient.createPortForwardingRule(createPfRule) + except e: + self.debug("Failed to create PF rule in account %s due to %s"%(self.acct.account, e)) + raise e + finally: + return snatip + + def test_DeployVm(self): + """ + Let's start by defining the attributes of our VM that we will be + deploying on CloudStack. We will be assuming a single zone is available + and is configured and all templates are Ready + + The hardcoded values are used only for brevity. + """ + deployVmCmd = deployVirtualMachine.deployVirtualMachineCmd() + deployVmCmd.zoneid = 1 + deployVmCmd.account = self.acct.account + deployVmCmd.domainid = self.acct.domainid + deployVmCmd.templateid = 5 #CentOS 5.6 builtin + deployVmCmd.serviceofferingid = 1 + + deployVmResponse = self.apiClient.deployVirtualMachine(deployVmCmd) + self.debug("VM %s was deployed in the job %s"%(deployVmResponse.id, deployVmResponse.jobid)) + + # At this point our VM is expected to be Running. Let's find out what + # listVirtualMachines tells us about VMs in this account + + listVmCmd = listVirtualMachines.listVirtualMachinesCmd() + listVmCmd.id = deployVmResponse.id + listVmResponse = self.apiClient.listVirtualMachines(listVmCmd) + + self.assertNotEqual(len(listVmResponse), 0, "Check if the list API \ + returns a non-empty response") + + vm = listVmResponse[0] + hostname = vm.name + nattedip = self.setUpNAT(vm.id) + + self.assertEqual(vm.id, deployVmResponse.id, "Check if the VM returned \ + is the same as the one we deployed") + + + self.assertEqual(vm.state, "Running", "Check if VM has reached \ + a state of running") + + # SSH login and compare hostname + ssh_client = remoteSSHClient(nattedip, 22, "root", "password") + stdout = ssh_client.execute("hostname") + + self.assertEqual(hostname, stdout[0], "cloudstack VM name and hostname match") + + + @classmethod + def tearDownClass(cls): + """ + And finally let us cleanup the resources we created by deleting the + account. All good unittests are atomic and rerunnable this way + """ + deleteAcct = deleteAccount.deleteAccountCmd() + deleteAcct.id = cls.acctResponse.account.id + cls.apiClient.deleteAccount(deleteAcct) diff --git a/tools/marvin/setup.py b/tools/marvin/setup.py new file mode 100644 index 00000000000..ce28b15b365 --- /dev/null +++ b/tools/marvin/setup.py @@ -0,0 +1,34 @@ +#!/usr/bin/env python +# Copyright 2012 Citrix Systems, Inc. Licensed under the +# Apache License, Version 2.0 (the "License"); you may not use this +# file except in compliance with the License. Citrix Systems, Inc. + +from distutils.core import setup +from sys import version + +if version < "2.7": + print "Marvin needs at least python 2.7, found : \n%s"%version +else: + try: + import paramiko + except ImportError: + print "Marvin requires paramiko to be installed" + raise + + setup(name="Marvin", + version="0.1.0", + description="Marvin - Python client for testing cloudstack", + author="Edison Su", + author_email="Edison.Su@citrix.com", + maintainer="Prasanna Santhanam", + maintainer_email="Prasanna.Santhanam@citrix.com", + long_description="Marvin is the cloudstack testclient written around the python unittest framework", + platforms=("Any",), + url="http://jenkins.cloudstack.org:8080/job/marvin", + packages=["marvin", "marvin.cloudstackAPI", "marvin.sandbox.tests", "marvin.pymysql", "marvin.pymysql.constants", "marvin.pymysql.tests"], + license="LICENSE.txt", + requires=[ + "paramiko (>1.4)", + "Python (>=2.7)" + ] + )