diff --git a/build.xml b/build.xml
index 9753213c23b..1a6ae190fb9 100755
--- a/build.xml
+++ b/build.xml
@@ -23,6 +23,7 @@
+
diff --git a/build/build-marvin.xml b/build/build-marvin.xml
new file mode 100644
index 00000000000..ed5cf565e26
--- /dev/null
+++ b/build/build-marvin.xml
@@ -0,0 +1,45 @@
+
+
+
+
+
+ This build file contains simple targets that
+ - build
+ - package
+ - distribute
+ the Marvin test client written in python
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
diff --git a/tools/marvin/CHANGES.txt b/tools/marvin/CHANGES.txt
new file mode 100644
index 00000000000..dc207fb992a
--- /dev/null
+++ b/tools/marvin/CHANGES.txt
@@ -0,0 +1 @@
+v0.1.0 Tuesday, April 10 2012 -- Packaging Marvin
diff --git a/tools/marvin/LICENSE.txt b/tools/marvin/LICENSE.txt
new file mode 100644
index 00000000000..fa64b4a85d0
--- /dev/null
+++ b/tools/marvin/LICENSE.txt
@@ -0,0 +1,9 @@
+Copyright 2012 Citrix Systems, Inc. Licensed under the Apache License, Version
+2.0 (the "License"); you may not use this file except in compliance with the
+License. Citrix Systems, Inc. reserves all rights not expressly granted by the
+License. You may obtain a copy of the License at
+http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or
+agreed to in writing, software distributed under the License is distributed on
+an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+or implied. See the License for the specific language governing permissions and
+limitations under the License.
diff --git a/tools/marvin/MANIFEST.in b/tools/marvin/MANIFEST.in
new file mode 100644
index 00000000000..92036977a8c
--- /dev/null
+++ b/tools/marvin/MANIFEST.in
@@ -0,0 +1,2 @@
+include *.txt
+recursive-include docs *.txt
\ No newline at end of file
diff --git a/tools/marvin/README b/tools/marvin/README
new file mode 100644
index 00000000000..b225b6e93fc
--- /dev/null
+++ b/tools/marvin/README
@@ -0,0 +1,15 @@
+Marvin is the testing framework for CloudStack, written in Python. Writing
+unit tests and functional tests with Marvin makes testing CloudStack easier.
+
+1. INSTALL
+ untar Marvin-0.1.0.tar.gz
+ cd Marvin-0.1.0
+ python setup.py install
+
+2. Facilities it provides:
+ 1. a very handy CloudStack API Python wrapper
+ 2. support for executing async jobs in parallel
+ 3. remote SSH login / command execution
+ 4. MySQL queries
+
+3. Sample code can be found under the sandbox directory
diff --git a/tools/marvin/docs/tutorial.txt b/tools/marvin/docs/tutorial.txt
new file mode 100644
index 00000000000..4da4b1b1ef1
--- /dev/null
+++ b/tools/marvin/docs/tutorial.txt
@@ -0,0 +1 @@
+Can be found at : http://wiki.cloudstack.org/display/QA/Testing+with+python
diff --git a/tools/marvin/marvin/NoseTestExecutionEngine.py b/tools/marvin/marvin/NoseTestExecutionEngine.py
new file mode 100644
index 00000000000..c64e0c1a5f7
--- /dev/null
+++ b/tools/marvin/marvin/NoseTestExecutionEngine.py
@@ -0,0 +1,34 @@
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from functools import partial
+import nose
+import nose.config
+import nose.core
+import os
+import sys
+import logging
+
+module_logger = "testclient.nose"
+
+def testCaseLogger(message, logger=None):
+ if logger is not None:
+ logger.debug(message)
+
+class TestCaseExecuteEngine(object):
+ def __init__(self, testclient, testCaseFolder, testcaseLogFile=None, testResultLogFile=None):
+ self.testclient = testclient
+ self.debuglog = testcaseLogFile
+ self.testCaseFolder = testCaseFolder
+ self.testResultLogFile = testResultLogFile
+ self.cfg = nose.config.Config()
+ self.cfg.configureWhere(self.testCaseFolder)
+ self.cfg.configureLogging()
+
+ def run(self):
+ self.args = ["--debug-log="+self.debuglog]
+ suite = nose.core.TestProgram(argv = self.args, config = self.cfg)
+ result = suite.runTests()
+ print result
diff --git a/tools/marvin/marvin/TestCaseExecuteEngine.py b/tools/marvin/marvin/TestCaseExecuteEngine.py
new file mode 100644
index 00000000000..cceed9ca8bd
--- /dev/null
+++ b/tools/marvin/marvin/TestCaseExecuteEngine.py
@@ -0,0 +1,51 @@
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+from functools import partial
+import os
+import sys
+import logging
+
+module_logger = "testclient.testcase"
+
+def testCaseLogger(message, logger=None):
+ if logger is not None:
+ logger.debug(message)
+
+class TestCaseExecuteEngine(object):
+ def __init__(self, testclient, testCaseFolder, testcaseLogFile=None, testResultLogFile=None):
+ self.testclient = testclient
+ self.testCaseFolder = testCaseFolder
+ self.logger = None
+ if testcaseLogFile is not None:
+ logger = logging.getLogger("testclient.testcase.TestCaseExecuteEngine")
+ fh = logging.FileHandler(testcaseLogFile)
+ fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
+ logger.addHandler(fh)
+ logger.setLevel(logging.DEBUG)
+ self.logger = logger
+ if testResultLogFile is not None:
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.ERROR)
+ ch.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
+ self.logger.addHandler(ch)
+ fp = open(testResultLogFile, "w")
+ self.testResultLogFile = fp
+ else:
+ self.testResultLogFile = sys.stdout
+
+ def injectTestCase(self, testSuites):
+ for test in testSuites:
+ if isinstance(test, unittest.BaseTestSuite):
+ self.injectTestCase(test)
+ else:
+ setattr(test, "testClient", self.testclient)
+ setattr(test, "debug", partial(testCaseLogger, logger=self.logger))
+ def run(self):
+ loader = unittest.loader.TestLoader()
+ suite = loader.discover(self.testCaseFolder)
+ self.injectTestCase(suite)
+
+ unittest.TextTestRunner(stream=self.testResultLogFile, verbosity=2).run(suite)
diff --git a/tools/marvin/marvin/__init__.py b/tools/marvin/marvin/__init__.py
new file mode 100644
index 00000000000..a3ded0a92d3
--- /dev/null
+++ b/tools/marvin/marvin/__init__.py
@@ -0,0 +1 @@
+#Marvin - The cloudstack test client
\ No newline at end of file
diff --git a/tools/marvin/marvin/asyncJobMgr.py b/tools/marvin/marvin/asyncJobMgr.py
new file mode 100644
index 00000000000..d843d05ec6e
--- /dev/null
+++ b/tools/marvin/marvin/asyncJobMgr.py
@@ -0,0 +1,216 @@
+import threading
+import cloudstackException
+import time
+import Queue
+import copy
+import sys
+import jsonHelper
+import datetime
+
+class job(object):
+ def __init__(self):
+ self.id = None
+ self.cmd = None
+class jobStatus(object):
+ def __init__(self):
+ self.result = None
+ self.status = None
+ self.startTime = None
+ self.endTime = None
+ self.duration = None
+ self.jobId = None
+ self.responsecls = None
+class workThread(threading.Thread):
+ def __init__(self, in_queue, outqueue, apiClient, db=None, lock=None):
+ threading.Thread.__init__(self)
+ self.inqueue = in_queue
+ self.output = outqueue
+ self.connection = apiClient.connection
+ self.db = None
+ self.lock = lock
+
+ def queryAsynJob(self, job):
+ if job.jobId is None:
+ return job
+
+ try:
+ self.lock.acquire()
+ result = self.connection.pollAsyncJob(job.jobId, job.responsecls).jobresult
+ except cloudstackException.cloudstackAPIException, e:
+ result = str(e)
+ finally:
+ self.lock.release()
+
+ job.result = result
+ return job
+
+ def executeCmd(self, job):
+ cmd = job.cmd
+
+ jobstatus = jobStatus()
+ jobId = None
+ try:
+ self.lock.acquire()
+
+ if cmd.isAsync == "false":
+ jobstatus.startTime = datetime.datetime.now()
+
+ result = self.connection.make_request(cmd)
+ jobstatus.result = result
+ jobstatus.endTime = datetime.datetime.now()
+ jobstatus.duration = time.mktime(jobstatus.endTime.timetuple()) - time.mktime(jobstatus.startTime.timetuple())
+ else:
+ result = self.connection.make_request(cmd, None, True)
+ if result is None:
+ jobstatus.status = False
+ else:
+ jobId = result.jobid
+ jobstatus.jobId = jobId
+ try:
+ responseName = cmd.__class__.__name__.replace("Cmd", "Response")
+ jobstatus.responsecls = jsonHelper.getclassFromName(cmd, responseName)
+ except:
+ pass
+ jobstatus.status = True
+ except cloudstackException.cloudstackAPIException, e:
+ jobstatus.result = str(e)
+ jobstatus.status = False
+ except:
+ jobstatus.status = False
+ jobstatus.result = sys.exc_info()
+ finally:
+ self.lock.release()
+
+ return jobstatus
+
+ def run(self):
+ while self.inqueue.qsize() > 0:
+ job = self.inqueue.get()
+ if isinstance(job, jobStatus):
+ jobstatus = self.queryAsynJob(job)
+ else:
+ jobstatus = self.executeCmd(job)
+
+ self.output.put(jobstatus)
+ self.inqueue.task_done()
+
+ '''release the resource'''
+ self.connection.close()
+
+class jobThread(threading.Thread):
+ def __init__(self, inqueue, interval):
+ threading.Thread.__init__(self)
+ self.inqueue = inqueue
+ self.interval = interval
+ def run(self):
+ while self.inqueue.qsize() > 0:
+ job = self.inqueue.get()
+ try:
+ job.run()
+ '''release the api connection'''
+ job.apiClient.connection.close()
+ except:
+ pass
+
+ self.inqueue.task_done()
+ time.sleep(self.interval)
+
+class outputDict(object):
+ def __init__(self):
+ self.lock = threading.Condition()
+ self.dict = {}
+
+class asyncJobMgr(object):
+ def __init__(self, apiClient, db):
+ self.inqueue = Queue.Queue()
+ self.output = outputDict()
+ self.outqueue = Queue.Queue()
+ self.apiClient = apiClient
+ self.db = db
+
+ def submitCmds(self, cmds):
+ if not self.inqueue.empty():
+ return False
+ id = 0
+ ids = []
+ for cmd in cmds:
+ asyncjob = job()
+ asyncjob.id = id
+ asyncjob.cmd = cmd
+ self.inqueue.put(asyncjob)
+ id += 1
+ ids.append(id)
+ return ids
+
+ def updateTimeStamp(self, jobstatus):
+ jobId = jobstatus.jobId
+ if jobId is not None and self.db is not None:
+ result = self.db.execute("select job_status, created, last_updated from async_job where id=%s"%jobId)
+ if result is not None and len(result) > 0:
+ if result[0][0] == 1:
+ jobstatus.status = True
+ else:
+ jobstatus.status = False
+ jobstatus.startTime = result[0][1]
+ jobstatus.endTime = result[0][2]
+ delta = jobstatus.endTime - jobstatus.startTime
+ jobstatus.duration = delta.total_seconds()
+
+ def waitForComplete(self, workers=10):
+ self.inqueue.join()
+ lock = threading.Lock()
+ resultQueue = Queue.Queue()
+ '''intermediate result is stored in self.outqueue'''
+ for i in range(workers):
+ worker = workThread(self.outqueue, resultQueue, self.apiClient, self.db, lock)
+ worker.start()
+
+ self.outqueue.join()
+
+ asyncJobResult = []
+ while resultQueue.qsize() > 0:
+ jobstatus = resultQueue.get()
+ self.updateTimeStamp(jobstatus)
+ asyncJobResult.append(jobstatus)
+
+ return asyncJobResult
+
+ '''Put commands into a queue first, then start `workers` threads to execute these commands'''
+ def submitCmdsAndWait(self, cmds, workers=10):
+ self.submitCmds(cmds)
+ lock = threading.Lock()
+ for i in range(workers):
+ worker = workThread(self.inqueue, self.outqueue, self.apiClient, self.db, lock)
+ worker.start()
+
+ return self.waitForComplete(workers)
+
+ '''Submit one job and execute it ntimes times, using nums_threads threads'''
+ def submitJobExecuteNtimes(self, job, ntimes=1, nums_threads=1, interval=1):
+ inqueue1 = Queue.Queue()
+ lock = threading.Condition()
+ for i in range(ntimes):
+ newjob = copy.copy(job)
+ setattr(newjob, "apiClient", copy.copy(self.apiClient))
+ setattr(newjob, "lock", lock)
+ inqueue1.put(newjob)
+
+ for i in range(nums_threads):
+ work = jobThread(inqueue1, interval)
+ work.start()
+ inqueue1.join()
+
+ '''submit n jobs, execute them with nums_threads of threads'''
+ def submitJobs(self, jobs, nums_threads=1, interval=1):
+ inqueue1 = Queue.Queue()
+ lock = threading.Condition()
+
+ for job in jobs:
+ setattr(job, "apiClient", copy.copy(self.apiClient))
+ setattr(job, "lock", lock)
+ inqueue1.put(job)
+
+ for i in range(nums_threads):
+ work = jobThread(inqueue1, interval)
+ work.start()
+ inqueue1.join()
\ No newline at end of file
diff --git a/tools/marvin/marvin/cloudstackConnection.py b/tools/marvin/marvin/cloudstackConnection.py
new file mode 100644
index 00000000000..d3a57263df7
--- /dev/null
+++ b/tools/marvin/marvin/cloudstackConnection.py
@@ -0,0 +1,163 @@
+import urllib2
+import urllib
+import base64
+import copy
+import hmac
+import hashlib
+import json
+import xml.dom.minidom
+import types
+import time
+import inspect
+import cloudstackException
+from cloudstackAPI import *
+import jsonHelper
+
+class cloudConnection(object):
+ def __init__(self, mgtSvr, port=8096, apiKey = None, securityKey = None, asyncTimeout=3600, logging=None):
+ self.apiKey = apiKey
+ self.securityKey = securityKey
+ self.mgtSvr = mgtSvr
+ self.port = port
+ self.logging = logging
+ if port == 8096:
+ self.auth = False
+ else:
+ self.auth = True
+
+ self.retries = 5
+ self.asyncTimeout = asyncTimeout
+
+ def close(self):
+ try:
+ self.connection.close()
+ except:
+ pass
+
+ def __copy__(self):
+ return cloudConnection(self.mgtSvr, self.port, self.apiKey, self.securityKey, self.asyncTimeout, self.logging)
+
+ def make_request_with_auth(self, command, requests={}):
+ requests["command"] = command
+ requests["apiKey"] = self.apiKey
+ requests["response"] = "json"
+ request = zip(requests.keys(), requests.values())
+ request.sort(key=lambda x: str.lower(x[0]))
+
+ requestUrl = "&".join(["=".join([r[0], urllib.quote_plus(str(r[1]))]) for r in request])
+ hashStr = "&".join(["=".join([str.lower(r[0]), str.lower(urllib.quote_plus(str(r[1]))).replace("+", "%20")]) for r in request])
+
+ sig = urllib.quote_plus(base64.encodestring(hmac.new(self.securityKey, hashStr, hashlib.sha1).digest()).strip())
+ requestUrl += "&signature=%s"%sig
+
+ try:
+ self.connection = urllib2.urlopen("http://%s:%d/client/api?%s"%(self.mgtSvr, self.port, requestUrl))
+ self.logging.debug("sending request: %s"%requestUrl)
+ response = self.connection.read()
+ self.logging.debug("got response: %s"%response)
+ except IOError, e:
+ if hasattr(e, 'reason'):
+ self.logging.debug("failed to reach %s because of %s"%(self.mgtSvr, e.reason))
+ elif hasattr(e, 'code'):
+ self.logging.debug("server returned %d error code"%e.code)
+ except HTTPException, h:
+ self.logging.debug("encountered http Exception %s"%h.args)
+ if self.retries > 0:
+ self.retries = self.retries - 1
+ self.make_request_with_auth(command, requests)
+ else:
+ self.retries = 5
+ raise h
+ else:
+ return response
+
+ def make_request_without_auth(self, command, requests={}):
+ requests["command"] = command
+ requests["response"] = "json"
+ requests = zip(requests.keys(), requests.values())
+ requestUrl = "&".join(["=".join([request[0], urllib.quote_plus(str(request[1]))]) for request in requests])
+
+ self.connection = urllib2.urlopen("http://%s:%d/client/api?%s"%(self.mgtSvr, self.port, requestUrl))
+ self.logging.debug("sending request without auth: %s"%requestUrl)
+ response = self.connection.read()
+ self.logging.debug("got response: %s"%response)
+ return response
+
+ def pollAsyncJob(self, jobId, response):
+ cmd = queryAsyncJobResult.queryAsyncJobResultCmd()
+ cmd.jobid = jobId
+
+ while self.asyncTimeout > 0:
+ asyncResonse = self.make_request(cmd, response, True)
+
+ if asyncResonse.jobstatus == 2:
+ raise cloudstackException.cloudstackAPIException("asyncquery", asyncResonse.jobresult)
+ elif asyncResonse.jobstatus == 1:
+ return asyncResonse
+
+ time.sleep(5)
+ self.asyncTimeout = self.asyncTimeout - 5
+
+ raise cloudstackException.cloudstackAPIException("asyncquery", "Async job timeout")
+
+ def make_request(self, cmd, response = None, raw=False):
+ commandName = cmd.__class__.__name__.replace("Cmd", "")
+ isAsync = "false"
+ requests = {}
+ required = []
+ for attribute in dir(cmd):
+ if attribute != "__doc__" and attribute != "__init__" and attribute != "__module__":
+ if attribute == "isAsync":
+ isAsync = getattr(cmd, attribute)
+ elif attribute == "required":
+ required = getattr(cmd, attribute)
+ else:
+ requests[attribute] = getattr(cmd, attribute)
+
+ for requiredPara in required:
+ if requests[requiredPara] is None:
+ raise cloudstackException.cloudstackAPIException(commandName, "%s is required"%requiredPara)
+ '''remove none value'''
+ for param, value in requests.items():
+ if value is None:
+ requests.pop(param)
+ elif isinstance(value, list):
+ if len(value) == 0:
+ requests.pop(param)
+ else:
+ if not isinstance(value[0], dict):
+ requests[param] = ",".join(value)
+ else:
+ requests.pop(param)
+ i = 0
+ for v in value:
+ for key, val in v.iteritems():
+ requests["%s[%d].%s"%(param,i,key)] = val
+ i = i + 1
+
+ if self.logging is not None:
+ self.logging.debug("sending command: %s %s"%(commandName, str(requests)))
+ result = None
+ if self.auth:
+ result = self.make_request_with_auth(commandName, requests)
+ else:
+ result = self.make_request_without_auth(commandName, requests)
+
+ if result is None:
+ return None
+ if self.logging is not None:
+ self.logging.debug("got result: " + result)
+
+ result = jsonHelper.getResultObj(result, response)
+ if raw or isAsync == "false":
+ return result
+ else:
+ asynJobId = result.jobid
+ result = self.pollAsyncJob(asynJobId, response)
+ return result.jobresult
+
+if __name__ == '__main__':
+ xml = '407i-1-407-RS3i-1-407-RS3system1ROOT2011-07-30T14:45:19-0700Runningfalse1CA13kvm-50-2054CentOS 5.5(64-bit) no GUI (KVM)CentOS 5.5(64-bit) no GUI (KVM)false1Small Instance15005121120NetworkFilesystem380203255.255.255.065.19.181.165.19.181.110vlan://65vlan://65GuestDirecttrue06:52:da:00:00:08KVM'
+ conn = cloudConnection(None)
+
+ print conn.paraseReturnXML(xml, deployVirtualMachine.deployVirtualMachineResponse())
diff --git a/tools/marvin/marvin/cloudstackException.py b/tools/marvin/marvin/cloudstackException.py
new file mode 100644
index 00000000000..f731be383c7
--- /dev/null
+++ b/tools/marvin/marvin/cloudstackException.py
@@ -0,0 +1,24 @@
+
+class cloudstackAPIException(Exception):
+ def __init__(self, cmd = "", result = ""):
+ self.errorMsg = "Execute cmd: %s failed, due to: %s"%(cmd, result)
+ def __str__(self):
+ return self.errorMsg
+
+class InvalidParameterException(Exception):
+ def __init__(self, msg=''):
+ self.errorMsg = msg
+ def __str__(self):
+ return self.errorMsg
+
+class dbException(Exception):
+ def __init__(self, msg=''):
+ self.errorMsg = msg
+ def __str__(self):
+ return self.errorMsg
+
+class internalError(Exception):
+ def __init__(self, msg=''):
+ self.errorMsg = msg
+ def __str__(self):
+ return self.errorMsg
\ No newline at end of file
diff --git a/tools/marvin/marvin/cloudstackTestCase.py b/tools/marvin/marvin/cloudstackTestCase.py
new file mode 100644
index 00000000000..4595aefaa42
--- /dev/null
+++ b/tools/marvin/marvin/cloudstackTestCase.py
@@ -0,0 +1,11 @@
+from cloudstackAPI import *
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+import cloudstackTestClient
+
+class cloudstackTestCase(unittest.case.TestCase):
+ def __init__(self, args):
+ unittest.case.TestCase.__init__(self, args)
+ self.testClient = cloudstackTestClient.cloudstackTestClient()
diff --git a/tools/marvin/marvin/cloudstackTestClient.py b/tools/marvin/marvin/cloudstackTestClient.py
new file mode 100644
index 00000000000..0d8f8108ea9
--- /dev/null
+++ b/tools/marvin/marvin/cloudstackTestClient.py
@@ -0,0 +1,58 @@
+import cloudstackConnection
+import asyncJobMgr
+import dbConnection
+from cloudstackAPI import *
+
+class cloudstackTestClient(object):
+ def __init__(self, mgtSvr=None, port=8096, apiKey = None, securityKey = None, asyncTimeout=3600, defaultWorkerThreads=10, logging=None):
+ self.connection = cloudstackConnection.cloudConnection(mgtSvr, port, apiKey, securityKey, asyncTimeout, logging)
+ self.apiClient = cloudstackAPIClient.CloudStackAPIClient(self.connection)
+ self.dbConnection = None
+ self.asyncJobMgr = None
+ self.ssh = None
+ self.defaultWorkerThreads = defaultWorkerThreads
+
+
+ def dbConfigure(self, host="localhost", port=3306, user='cloud', passwd='cloud', db='cloud'):
+ self.dbConnection = dbConnection.dbConnection(host, port, user, passwd, db)
+
+ def close(self):
+ if self.connection is not None:
+ self.connection.close()
+ if self.dbConnection is not None:
+ self.dbConnection.close()
+
+ def getDbConnection(self):
+ return self.dbConnection
+
+ def executeSql(self, sql=None):
+ if sql is None or self.dbConnection is None:
+ return None
+
+ return self.dbConnection.execute()
+
+ def executeSqlFromFile(self, sqlFile=None):
+ if sqlFile is None or self.dbConnection is None:
+ return None
+ return self.dbConnection.executeSqlFromFile(sqlFile)
+
+ def getApiClient(self):
+ return self.apiClient
+
+ '''FIXME: httplib has an issue when more than one thread submits requests'''
+ def submitCmdsAndWait(self, cmds, workers=1):
+ if self.asyncJobMgr is None:
+ self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection)
+ return self.asyncJobMgr.submitCmdsAndWait(cmds, workers)
+
+ '''submit one job and execute the same job ntimes, with nums_threads of threads'''
+ def submitJob(self, job, ntimes=1, nums_threads=10, interval=1):
+ if self.asyncJobMgr is None:
+ self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection)
+ self.asyncJobMgr.submitJobExecuteNtimes(job, ntimes, nums_threads, interval)
+
+ '''submit n jobs, execute them with nums_threads of threads'''
+ def submitJobs(self, jobs, nums_threads=10, interval=1):
+ if self.asyncJobMgr is None:
+ self.asyncJobMgr = asyncJobMgr.asyncJobMgr(self.apiClient, self.dbConnection)
+ self.asyncJobMgr.submitJobs(jobs, nums_threads, interval)
\ No newline at end of file
diff --git a/tools/marvin/marvin/codegenerator.py b/tools/marvin/marvin/codegenerator.py
new file mode 100644
index 00000000000..aea53b8854b
--- /dev/null
+++ b/tools/marvin/marvin/codegenerator.py
@@ -0,0 +1,277 @@
+import xml.dom.minidom
+from optparse import OptionParser
+import os
+import sys
+class cmdParameterProperty(object):
+ def __init__(self):
+ self.name = None
+ self.required = False
+ self.desc = ""
+ self.type = "planObject"
+ self.subProperties = []
+
+class cloudStackCmd:
+ def __init__(self):
+ self.name = ""
+ self.desc = ""
+ self.async = "false"
+ self.request = []
+ self.response = []
+
+class codeGenerator:
+ space = " "
+
+ cmdsName = []
+
+ def __init__(self, outputFolder, apiSpecFile):
+ self.cmd = None
+ self.code = ""
+ self.required = []
+ self.subclass = []
+ self.outputFolder = outputFolder
+ self.apiSpecFile = apiSpecFile
+
+ def addAttribute(self, attr, pro):
+ value = pro.value
+ if pro.required:
+ self.required.append(attr)
+ desc = pro.desc
+ if desc is not None:
+ self.code += self.space
+ self.code += "''' " + pro.desc + " '''"
+ self.code += "\n"
+
+ self.code += self.space
+ self.code += attr + " = " + str(value)
+ self.code += "\n"
+
+ def generateSubClass(self, name, properties):
+ '''generate code for sub list'''
+ subclass = 'class %s:\n'%name
+ subclass += self.space + "def __init__(self):\n"
+ for pro in properties:
+ if pro.desc is not None:
+ subclass += self.space + self.space + '""""%s"""\n'%pro.desc
+ if len (pro.subProperties) > 0:
+ subclass += self.space + self.space + 'self.%s = []\n'%pro.name
+ self.generateSubClass(pro.name, pro.subProperties)
+ else:
+ subclass += self.space + self.space + 'self.%s = None\n'%pro.name
+
+ self.subclass.append(subclass)
+ def generate(self, cmd):
+
+ self.cmd = cmd
+ self.cmdsName.append(self.cmd.name)
+ self.code += "\n"
+ self.code += '"""%s"""\n'%self.cmd.desc
+ self.code += 'from baseCmd import *\n'
+ self.code += 'from baseResponse import *\n'
+ self.code += "class %sCmd (baseCmd):\n"%self.cmd.name
+ self.code += self.space + "def __init__(self):\n"
+
+ self.code += self.space + self.space + 'self.isAsync = "%s"\n' %self.cmd.async
+
+ for req in self.cmd.request:
+ if req.desc is not None:
+ self.code += self.space + self.space + '"""%s"""\n'%req.desc
+ if req.required == "true":
+ self.code += self.space + self.space + '"""Required"""\n'
+
+ value = "None"
+ if req.type == "list" or req.type == "map":
+ value = "[]"
+
+ self.code += self.space + self.space + 'self.%s = %s\n'%(req.name,value)
+ if req.required == "true":
+ self.required.append(req.name)
+
+ self.code += self.space + self.space + "self.required = ["
+ for require in self.required:
+ self.code += '"' + require + '",'
+ self.code += "]\n"
+ self.required = []
+
+
+ """generate response code"""
+ subItems = {}
+ self.code += "\n"
+ self.code += 'class %sResponse (baseResponse):\n'%self.cmd.name
+ self.code += self.space + "def __init__(self):\n"
+ if len(self.cmd.response) == 0:
+ self.code += self.space + self.space + "pass"
+ else:
+ for res in self.cmd.response:
+ if res.desc is not None:
+ self.code += self.space + self.space + '"""%s"""\n'%res.desc
+
+ if len(res.subProperties) > 0:
+ self.code += self.space + self.space + 'self.%s = []\n'%res.name
+ self.generateSubClass(res.name, res.subProperties)
+ else:
+ self.code += self.space + self.space + 'self.%s = None\n'%res.name
+ self.code += '\n'
+
+ for subclass in self.subclass:
+ self.code += subclass + "\n"
+
+ fp = open(self.outputFolder + "/cloudstackAPI/%s.py"%self.cmd.name, "w")
+ fp.write(self.code)
+ fp.close()
+ self.code = ""
+ self.subclass = []
+
+
+ def finalize(self):
+ '''generate an api call'''
+
+ header = '"""Test Client for CloudStack API"""\n'
+ imports = "import copy\n"
+ initCmdsList = '__all__ = ['
+ body = ''
+ body += "class CloudStackAPIClient:\n"
+ body += self.space + 'def __init__(self, connection):\n'
+ body += self.space + self.space + 'self.connection = connection\n'
+ body += "\n"
+
+ body += self.space + 'def __copy__(self):\n'
+ body += self.space + self.space + 'return CloudStackAPIClient(copy.copy(self.connection))\n'
+ body += "\n"
+
+ for cmdName in self.cmdsName:
+ body += self.space + 'def %s(self,command):\n'%cmdName
+ body += self.space + self.space + 'response = %sResponse()\n'%cmdName
+ body += self.space + self.space + 'response = self.connection.make_request(command, response)\n'
+ body += self.space + self.space + 'return response\n'
+ body += '\n'
+
+ imports += 'from %s import %sResponse\n'%(cmdName, cmdName)
+ initCmdsList += '"%s",'%cmdName
+
+ fp = open(self.outputFolder + '/cloudstackAPI/cloudstackAPIClient.py', 'w')
+ for item in [header, imports, body]:
+ fp.write(item)
+ fp.close()
+
+ '''generate __init__.py'''
+ initCmdsList += '"cloudstackAPIClient"]'
+ fp = open(self.outputFolder + '/cloudstackAPI/__init__.py', 'w')
+ fp.write(initCmdsList)
+ fp.close()
+
+ fp = open(self.outputFolder + '/cloudstackAPI/baseCmd.py', 'w')
+ basecmd = '"""Base Command"""\n'
+ basecmd += 'class baseCmd:\n'
+ basecmd += self.space + 'pass\n'
+ fp.write(basecmd)
+ fp.close()
+
+ fp = open(self.outputFolder + '/cloudstackAPI/baseResponse.py', 'w')
+ basecmd = '"""Base class for response"""\n'
+ basecmd += 'class baseResponse:\n'
+ basecmd += self.space + 'pass\n'
+ fp.write(basecmd)
+ fp.close()
+
+
+ def constructResponse(self, response):
+ paramProperty = cmdParameterProperty()
+ paramProperty.name = getText(response.getElementsByTagName('name'))
+ paramProperty.desc = getText(response.getElementsByTagName('description'))
+ if paramProperty.name.find('(*)') != -1:
+ '''This is a list'''
+ paramProperty.name = paramProperty.name.split('(*)')[0]
+ for subresponse in response.getElementsByTagName('arguments')[0].getElementsByTagName('arg'):
+ subProperty = self.constructResponse(subresponse)
+ paramProperty.subProperties.append(subProperty)
+ return paramProperty
+
+ def loadCmdFromXML(self):
+ dom = xml.dom.minidom.parse(self.apiSpecFile)
+ cmds = []
+ for cmd in dom.getElementsByTagName("command"):
+ csCmd = cloudStackCmd()
+ csCmd.name = getText(cmd.getElementsByTagName('name'))
+ assert csCmd.name
+
+ desc = getText(cmd.getElementsByTagName('description'))
+ if desc:
+ csCmd.desc = desc
+
+ async = getText(cmd.getElementsByTagName('isAsync'))
+ if async:
+ csCmd.async = async
+
+ for param in cmd.getElementsByTagName("request")[0].getElementsByTagName("arg"):
+ paramProperty = cmdParameterProperty()
+
+ paramProperty.name = getText(param.getElementsByTagName('name'))
+ assert paramProperty.name
+
+ required = param.getElementsByTagName('required')
+ if required:
+ paramProperty.required = getText(required)
+
+ requestDescription = param.getElementsByTagName('description')
+ if requestDescription:
+ paramProperty.desc = getText(requestDescription)
+
+ type = param.getElementsByTagName("type")
+ if type:
+ paramProperty.type = getText(type)
+
+ csCmd.request.append(paramProperty)
+
+ responseEle = cmd.getElementsByTagName("response")[0]
+ for response in responseEle.getElementsByTagName("arg"):
+ if response.parentNode != responseEle:
+ continue
+
+ paramProperty = self.constructResponse(response)
+ csCmd.response.append(paramProperty)
+
+ cmds.append(csCmd)
+ return cmds
+
+ def generateCode(self):
+ cmds = self.loadCmdFromXML()
+ for cmd in cmds:
+ self.generate(cmd)
+ self.finalize()
+
+def getText(elements):
+ return elements[0].childNodes[0].nodeValue.strip()
+
+if __name__ == "__main__":
+ parser = OptionParser()
+
+ parser.add_option("-o", "--output", dest="output", help="the root path where code is generated, default is .")
+ parser.add_option("-s", "--specfile", dest="spec", help="the path and name of the api spec xml file, default is /etc/cloud/cli/commands.xml")
+
+ (options, args) = parser.parse_args()
+
+ apiSpecFile = "/etc/cloud/cli/commands.xml"
+ if options.spec is not None:
+ apiSpecFile = options.spec
+
+ if not os.path.exists(apiSpecFile):
+ print "the spec file %s does not exist"%apiSpecFile
+ print parser.print_help()
+ exit(1)
+
+
+ folder = "."
+ if options.output is not None:
+ folder = options.output
+ apiModule=folder + "/cloudstackAPI"
+ if not os.path.exists(apiModule):
+ try:
+ os.mkdir(apiModule)
+ except:
+ print "Failed to create folder %s, due to %s"%(apiModule,sys.exc_info())
+ print parser.print_help()
+ exit(2)
+
+ cg = codeGenerator(folder, apiSpecFile)
+ cg.generateCode()
+
diff --git a/tools/marvin/marvin/configGenerator.py b/tools/marvin/marvin/configGenerator.py
new file mode 100644
index 00000000000..3baa3edcf52
--- /dev/null
+++ b/tools/marvin/marvin/configGenerator.py
@@ -0,0 +1,378 @@
+import json
+import os
+from optparse import OptionParser
+import jsonHelper
+
class managementServer():
    '''Connection settings for a CloudStack management server.'''
    def __init__(self):
        # authentication keys; None until registered via the API
        self.apiKey = None
        self.securityKey = None
        # server address and integration (unauthenticated) API port
        self.mgtSvrIp = None
        self.port = 8096
+
class dbServer():
    '''MySQL connection settings for the cloud database.'''
    def __init__(self):
        # server address and port
        self.dbSvr = None
        self.port = 3306
        # default CloudStack credentials and schema
        self.user = "cloud"
        self.passwd = "cloud"
        self.db = "cloud"
+
class configuration():
    '''A single global-configuration name/value pair.'''
    def __init__(self):
        self.name = None
        self.value = None
+
class logger():
    '''Log sink definition; name is "TestCase" or "TestClient".'''
    def __init__(self):
        self.name = None    # which component logs here
        self.file = None    # path of the log file
+
class cloudstackConfiguration():
    '''Top-level setup description: zones, management servers, the database
    server, global settings and logger definitions.'''
    def __init__(self):
        self.zones = []
        self.mgtSvr = []
        self.dbSvr = None
        self.globalConfig = []
        self.logger = []
+
class zone():
    '''Zone description: dns servers, network type and child containers.'''
    def __init__(self):
        # identity / dns
        self.name = None
        self.dns1 = None
        self.dns2 = None
        self.internaldns1 = None
        self.internaldns2 = None
        # network settings; networktype is "Basic" or "Advanced"
        self.networktype = None
        self.guestcidraddress = None
        self.securitygroupenabled = None
        self.vlan = None
        # default public network ranges (Advanced mode)
        self.ipranges = []
        # tagged guest networks (Advanced mode)
        self.networks = []
        # child containers
        self.pods = []
        self.secondaryStorages = []
+
class pod():
    '''Pod description: management ip range plus child clusters.'''
    def __init__(self):
        self.name = None
        self.zoneid = None
        # pod management network range
        self.gateway = None
        self.netmask = None
        self.startip = None
        self.endip = None
        # child clusters
        self.clusters = []
        # guest ip ranges (Basic network mode only)
        self.guestIpRanges = []
+
class cluster():
    '''Cluster description: identity, access details, hosts and storage.'''
    def __init__(self):
        # identity
        self.clustername = None
        self.clustertype = None
        self.hypervisor = None
        # parent references
        self.zoneid = None
        self.podid = None
        # access details
        self.url = None
        self.username = None
        self.password = None
        # children
        self.hosts = []
        self.primaryStorages = []
+
class host():
    '''Host description: access details plus optional capacity settings.'''
    def __init__(self):
        # access
        self.hypervisor = None
        self.url = None
        self.username = None
        self.password = None
        # placement
        self.zoneid = None
        self.podid = None
        self.clusterid = None
        self.clustername = None
        # capacity / metadata
        self.cpunumber = None
        self.cpuspeed = None
        self.memory = None
        self.hostmac = None
        self.hosttags = None
+
class network():
    '''Guest network description together with its ip ranges.'''
    def __init__(self):
        self.name = None
        self.displaytext = None
        self.zoneid = None
        # ownership
        self.account = None
        self.domainid = None
        # flags / offering
        self.isdefault = None
        self.isshared = None
        self.networkdomain = None
        self.networkofferingid = None
        # ip ranges carved out of this network
        self.ipranges = []
+
class iprange():
    '''A contiguous ip range (tagged via vlan, or untagged).'''
    def __init__(self):
        # the range itself
        self.startip = None
        self.endip = None
        self.gateway = None
        self.netmask = None
        self.vlan = None
        # set only for account-specific ranges
        self.account = None
        self.domain = None
+
class primaryStorage():
    '''Primary storage pool: a display name plus its url.'''
    def __init__(self):
        self.name = None    # pool name
        self.url = None     # e.g. nfs://host/path
+
class secondaryStorage():
    '''Secondary storage: only the url is needed.'''
    def __init__(self):
        self.url = None     # e.g. nfs://host/path
+
+'''sample code to generate setup configuration file'''
def describe_setup_in_basic_mode():
    '''Build an in-memory cloudstackConfiguration describing a large
    Basic-zone simulator setup: 1 zone, 300 pods, 10 clusters per pod,
    1 simulator host per cluster, 5 secondary storages, plus management
    server, database, global-config and logger stanzas.'''
    zs = cloudstackConfiguration()

    for l in range(1):
        z = zone()
        z.dns1 = "8.8.8.8"
        z.dns2 = "4.4.4.4"
        z.internaldns1 = "192.168.110.254"
        z.internaldns2 = "192.168.110.253"
        z.name = "test"+str(l)
        z.networktype = 'Basic'

        '''create 10 pods'''
        # NOTE(review): despite the string above, this creates 300 pods
        for i in range(300):
            p = pod()
            p.name = "test" +str(l) + str(i)
            # 192.(168+i/255).(i%255).x addressing; i/255 is Python-2
            # integer division
            p.gateway = "192.%d.%d.1"%((i/255)+168,i%255)
            p.netmask = "255.255.255.0"

            p.startip = "192.%d.%d.150"%((i/255)+168,i%255)
            p.endip = "192.%d.%d.220"%((i/255)+168,i%255)

            '''add two pod guest ip ranges'''
            # NOTE(review): range(1) adds a single guest ip range per pod
            for j in range(1):
                ip = iprange()
                ip.gateway = p.gateway
                ip.netmask = p.netmask
                ip.startip = "192.%d.%d.%d"%(((i/255)+168), i%255,j*20)
                ip.endip = "192.%d.%d.%d"%((i/255)+168,i%255,j*20+10)

                p.guestIpRanges.append(ip)

            '''add 10 clusters'''
            for j in range(10):
                c = cluster()
                c.clustername = "test"+str(l)+str(i) + str(j)
                c.clustertype = "CloudManaged"
                c.hypervisor = "Simulator"

                '''add 10 hosts'''
                # NOTE(review): range(1) adds a single simulator host
                for k in range(1):
                    h = host()
                    h.username = "root"
                    h.password = "password"
                    memory = 8*1024*1024*1024
                    localstorage=1*1024*1024*1024*1024
                    #h.url = "http://Sim/%d%d%d%d/cpucore=1&cpuspeed=8000&memory=%d&localstorage=%d"%(l,i,j,k,memory,localstorage)
                    h.url = "http://Sim/%d%d%d%d"%(l,i,j,k)
                    c.hosts.append(h)

                '''add 2 primary storages'''
                # primary-storage creation is disabled (commented out) here
                '''
                for m in range(2):
                    primary = primaryStorage()
                    size=1*1024*1024*1024*1024
                    primary.name = "primary"+str(l) + str(i) + str(j) + str(m)
                    #primary.url = "nfs://localhost/path%s/size=%d"%(str(l) + str(i) + str(j) + str(m), size)
                    primary.url = "nfs://localhost/path%s"%(str(l) + str(i) + str(j) + str(m))
                    c.primaryStorages.append(primary)
                '''

                p.clusters.append(c)

            z.pods.append(p)

        '''add two secondary'''
        # NOTE(review): actually adds 5 secondary storages
        for i in range(5):
            secondary = secondaryStorage()
            secondary.url = "nfs://localhost/path"+str(l) + str(i)
            z.secondaryStorages.append(secondary)

        zs.zones.append(z)

    '''Add one mgt server'''
    mgt = managementServer()
    mgt.mgtSvrIp = "localhost"
    zs.mgtSvr.append(mgt)

    '''Add a database'''
    db = dbServer()
    db.dbSvr = "localhost"

    zs.dbSvr = db

    '''add global configuration'''
    global_settings = {'expunge.delay': '60',
                       'expunge.interval': '60',
                       'expunge.workers': '3',
                      }
    for k,v in global_settings.iteritems():
        cfg = configuration()
        cfg.name = k
        cfg.value = v
        zs.globalConfig.append(cfg)

    ''''add loggers'''
    testClientLogger = logger()
    testClientLogger.name = "TestClient"
    testClientLogger.file = "/tmp/testclient.log"

    testCaseLogger = logger()
    testCaseLogger.name = "TestCase"
    testCaseLogger.file = "/tmp/testcase.log"

    zs.logger.append(testClientLogger)
    zs.logger.append(testCaseLogger)

    return zs
+
+'''sample code to generate setup configuration file'''
def describe_setup_in_advanced_mode():
    '''Build an in-memory cloudstackConfiguration for a small Advanced-zone
    simulator setup: 1 zone, 2 pods, 2 clusters per pod, 2 hosts and 2
    primary storage pools per cluster, 5 secondary storages, one public
    VLAN ip range, plus management server, database, global-config and
    logger stanzas.'''
    zs = cloudstackConfiguration()

    for l in range(1):
        z = zone()
        z.dns1 = "8.8.8.8"
        z.dns2 = "4.4.4.4"
        z.internaldns1 = "192.168.110.254"
        z.internaldns2 = "192.168.110.253"
        z.name = "test"+str(l)
        z.networktype = 'Advanced'
        z.guestcidraddress = "10.1.1.0/24"
        z.vlan = "100-2000"

        '''create 10 pods'''
        # NOTE(review): actually creates 2 pods
        for i in range(2):
            p = pod()
            p.name = "test" +str(l) + str(i)
            p.gateway = "192.168.%d.1"%i
            p.netmask = "255.255.255.0"
            p.startip = "192.168.%d.200"%i
            p.endip = "192.168.%d.220"%i

            '''add 10 clusters'''
            # NOTE(review): actually 2 clusters
            for j in range(2):
                c = cluster()
                c.clustername = "test"+str(l)+str(i) + str(j)
                c.clustertype = "CloudManaged"
                c.hypervisor = "Simulator"

                '''add 10 hosts'''
                # NOTE(review): actually 2 hosts
                for k in range(2):
                    h = host()
                    h.username = "root"
                    h.password = "password"
                    memory = 8*1024*1024*1024
                    localstorage=1*1024*1024*1024*1024
                    #h.url = "http://Sim/%d%d%d%d/cpucore=1&cpuspeed=8000&memory=%d&localstorage=%d"%(l,i,j,k,memory,localstorage)
                    h.url = "http://Sim/%d%d%d%d"%(l,i,j,k)
                    c.hosts.append(h)

                '''add 2 primary storages'''
                for m in range(2):
                    primary = primaryStorage()
                    size=1*1024*1024*1024*1024
                    primary.name = "primary"+str(l) + str(i) + str(j) + str(m)
                    #primary.url = "nfs://localhost/path%s/size=%d"%(str(l) + str(i) + str(j) + str(m), size)
                    primary.url = "nfs://localhost/path%s"%(str(l) + str(i) + str(j) + str(m))
                    c.primaryStorages.append(primary)

                p.clusters.append(c)

            z.pods.append(p)

        '''add two secondary'''
        # NOTE(review): actually adds 5 secondary storages
        for i in range(5):
            secondary = secondaryStorage()
            secondary.url = "nfs://localhost/path"+str(l) + str(i)
            z.secondaryStorages.append(secondary)

        '''add default public network'''
        ips = iprange()
        ips.vlan = "26"
        ips.startip = "172.16.26.2"
        ips.endip = "172.16.26.100"
        ips.gateway = "172.16.26.1"
        ips.netmask = "255.255.255.0"
        z.ipranges.append(ips)


        zs.zones.append(z)

    '''Add one mgt server'''
    mgt = managementServer()
    mgt.mgtSvrIp = "localhost"
    zs.mgtSvr.append(mgt)

    '''Add a database'''
    db = dbServer()
    db.dbSvr = "localhost"

    zs.dbSvr = db

    '''add global configuration'''
    global_settings = {'expunge.delay': '60',
                       'expunge.interval': '60',
                       'expunge.workers': '3',
                      }
    for k,v in global_settings.iteritems():
        cfg = configuration()
        cfg.name = k
        cfg.value = v
        zs.globalConfig.append(cfg)

    ''''add loggers'''
    testClientLogger = logger()
    testClientLogger.name = "TestClient"
    testClientLogger.file = "/tmp/testclient.log"

    testCaseLogger = logger()
    testCaseLogger.name = "TestCase"
    testCaseLogger.file = "/tmp/testcase.log"

    zs.logger.append(testClientLogger)
    zs.logger.append(testCaseLogger)

    return zs
+
def generate_setup_config(config, file=None):
    '''Serialize *config* to json: return the string when *file* is None,
    otherwise write it (indented) to that path.

    Note: the parameter name "file" shadows the builtin; kept for
    backward compatibility with existing callers.'''
    describe = config
    if file is None:
        return json.dumps(jsonHelper.jsonDump.dump(describe))
    else:
        # Fixed: close the handle even if json.dump raises
        fp = open(file, 'w')
        try:
            json.dump(jsonHelper.jsonDump.dump(describe), fp, indent=4)
        finally:
            fp.close()
+
+
def get_setup_config(file):
    '''Load a json setup config from *file* and wrap it in a jsonLoader.

    Returns None when the file does not exist (callers must handle that).'''
    if not os.path.exists(file):
        return None
    # Fixed: the old code leaked the file handle and also created a throwaway
    # cloudstackConfiguration() that was immediately overwritten.
    fp = open(file, 'r')
    try:
        config = json.load(fp)
    finally:
        fp.close()
    return jsonHelper.jsonLoader(config)
+
+if __name__ == "__main__":
+ parser = OptionParser()
+
+ parser.add_option("-o", "--output", action="store", default="./datacenterCfg", dest="output", help="the path where the json config file generated, by default is ./datacenterCfg")
+
+ (options, args) = parser.parse_args()
+ config = describe_setup_in_basic_mode()
+ generate_setup_config(config, options.output)
diff --git a/tools/marvin/marvin/dbConnection.py b/tools/marvin/marvin/dbConnection.py
new file mode 100644
index 00000000000..e6135edf00f
--- /dev/null
+++ b/tools/marvin/marvin/dbConnection.py
@@ -0,0 +1,78 @@
+import pymysql
+import cloudstackException
+import sys
+import os
+import traceback
+class dbConnection(object):
+ def __init__(self, host="localhost", port=3306, user='cloud', passwd='cloud', db='cloud'):
+ self.host = host
+ self.port = port
+ self.user = user
+ self.passwd = passwd
+ self.database = db
+
+ try:
+ self.db = pymysql.Connect(host=host, port=port, user=user, passwd=passwd, db=db)
+ except:
+ traceback.print_exc()
+ raise cloudstackException.InvalidParameterException(sys.exc_info())
+
+ def __copy__(self):
+ return dbConnection(self.host, self.port, self.user, self.passwd, self.database)
+
+ def close(self):
+ try:
+ self.db.close()
+ except:
+ pass
+
+ def execute(self, sql=None):
+ if sql is None:
+ return None
+
+ resultRow = []
+ cursor = None
+ try:
+ cursor = self.db.cursor()
+ cursor.execute(sql)
+
+ result = cursor.fetchall()
+ if result is not None:
+ for r in result:
+ resultRow.append(r)
+ return resultRow
+ except pymysql.MySQLError, e:
+ raise cloudstackException.dbException("db Exception:%s"%e[1])
+ except:
+ raise cloudstackException.internalError(sys.exc_info())
+ finally:
+ if cursor is not None:
+ cursor.close()
+
+ def executeSqlFromFile(self, fileName=None):
+ if fileName is None:
+ raise cloudstackException.InvalidParameterException("file can't not none")
+
+ if not os.path.exists(fileName):
+ raise cloudstackException.InvalidParameterException("%s not exists"%fileName)
+
+ sqls = open(fileName, "r").read()
+ return self.execute(sqls)
+
+if __name__ == "__main__":
+ db = dbConnection()
+ '''
+ try:
+
+ result = db.executeSqlFromFile("/tmp/server-setup.sql")
+ if result is not None:
+ for r in result:
+ print r[0], r[1]
+ except cloudstackException.dbException, e:
+ print e
+ '''
+ print db.execute("update vm_template set name='fjkd' where id=200")
+ for i in range(10):
+ result = db.execute("select job_status, created, last_updated from async_job where id=%d"%i)
+ print result
+
diff --git a/tools/marvin/marvin/deployAndRun.py b/tools/marvin/marvin/deployAndRun.py
new file mode 100644
index 00000000000..084eb995418
--- /dev/null
+++ b/tools/marvin/marvin/deployAndRun.py
@@ -0,0 +1,32 @@
+import deployDataCenter
+import TestCaseExecuteEngine
+from optparse import OptionParser
+import os
+if __name__ == "__main__":
+ parser = OptionParser()
+
+ parser.add_option("-c", "--config", action="store", default="./datacenterCfg", dest="config", help="the path where the json config file generated, by default is ./datacenterCfg")
+ parser.add_option("-d", "--directory", dest="testCaseFolder", help="the test case directory")
+ parser.add_option("-r", "--result", dest="result", help="test result log file")
+ parser.add_option("-t", dest="testcaselog", help="test case log file")
+ parser.add_option("-l", "--load", dest="load", action="store_true", help="only load config, do not deploy, it will only run testcase")
+ (options, args) = parser.parse_args()
+ if options.testCaseFolder is None:
+ parser.print_usage()
+ exit(1)
+
+ testResultLogFile = None
+ if options.result is not None:
+ testResultLogFile = options.result
+
+ testCaseLogFile = None
+ if options.testcaselog is not None:
+ testCaseLogFile = options.testcaselog
+ deploy = deployDataCenter.deployDataCenters(options.config)
+ if options.load:
+ deploy.loadCfg()
+ else:
+ deploy.deploy()
+
+ testcaseEngine = TestCaseExecuteEngine.TestCaseExecuteEngine(deploy.testClient, options.testCaseFolder, testCaseLogFile, testResultLogFile)
+ testcaseEngine.run()
\ No newline at end of file
diff --git a/tools/marvin/marvin/deployDataCenter.py b/tools/marvin/marvin/deployDataCenter.py
new file mode 100644
index 00000000000..e47dc726676
--- /dev/null
+++ b/tools/marvin/marvin/deployDataCenter.py
@@ -0,0 +1,274 @@
+'''Deploy datacenters according to a json configuration file'''
+import configGenerator
+import cloudstackException
+import cloudstackTestClient
+import sys
+import logging
+from cloudstackAPI import *
+from optparse import OptionParser
+
+module_logger = "testclient.deploy"
+
+
class deployDataCenters():
    '''Drives the CloudStack API to build zones, pods, clusters, hosts,
    storage and networks as described by a json configuration file
    (see configGenerator for the schema).'''
    def __init__(self, cfgFile):
        self.configFile = cfgFile

    def addHosts(self, hosts, zoneId, podId, clusterId, hypervisor):
        '''Register every host in *hosts* into the given cluster.'''
        if hosts is None:
            return
        for host in hosts:
            hostcmd = addHost.addHostCmd()
            hostcmd.clusterid = clusterId
            # Fixed: the host config class defines "cpunumber"; the old
            # "host.cpunumer" was a typo and always yielded None.
            hostcmd.cpunumber = host.cpunumber
            hostcmd.cpuspeed = host.cpuspeed
            hostcmd.hostmac = host.hostmac
            hostcmd.hosttags = host.hosttags
            hostcmd.hypervisor = host.hypervisor
            hostcmd.memory = host.memory
            hostcmd.password = host.password
            hostcmd.podid = podId
            hostcmd.url = host.url
            hostcmd.username = host.username
            hostcmd.zoneid = zoneId
            # the cluster-level hypervisor deliberately overrides the host's
            hostcmd.hypervisor = hypervisor
            self.apiClient.addHost(hostcmd)

    def createClusters(self, clusters, zoneId, podId):
        '''Create each cluster, then its hosts and primary storage pools.'''
        if clusters is None:
            return

        for cluster in clusters:
            clustercmd = addCluster.addClusterCmd()
            clustercmd.clustername = cluster.clustername
            clustercmd.clustertype = cluster.clustertype
            clustercmd.hypervisor = cluster.hypervisor
            clustercmd.password = cluster.password
            clustercmd.podid = podId
            clustercmd.url = cluster.url
            clustercmd.username = cluster.username
            clustercmd.zoneid = zoneId
            clusterresponse = self.apiClient.addCluster(clustercmd)
            clusterId = clusterresponse[0].id

            self.addHosts(cluster.hosts, zoneId, podId, clusterId, cluster.hypervisor)
            self.createPrimaryStorages(cluster.primaryStorages, zoneId, podId, clusterId)

    def createPrimaryStorages(self, primaryStorages, zoneId, podId, clusterId):
        '''Create the primary storage pools of one cluster.'''
        if primaryStorages is None:
            return
        for primary in primaryStorages:
            primarycmd = createStoragePool.createStoragePoolCmd()
            primarycmd.details = primary.details
            primarycmd.name = primary.name
            primarycmd.podid = podId
            primarycmd.tags = primary.tags
            primarycmd.url = primary.url
            primarycmd.zoneid = zoneId
            primarycmd.clusterid = clusterId
            self.apiClient.createStoragePool(primarycmd)

    def createpods(self, pods, zone, zoneId):
        '''Create pods; for Basic zones also their guest ip ranges.'''
        if pods is None:
            return
        for pod in pods:
            createpod = createPod.createPodCmd()
            createpod.name = pod.name
            createpod.gateway = pod.gateway
            createpod.netmask = pod.netmask
            createpod.startip = pod.startip
            createpod.endip = pod.endip
            createpod.zoneid = zoneId
            createpodResponse = self.apiClient.createPod(createpod)
            podId = createpodResponse.id

            if pod.guestIpRanges is not None:
                self.createVlanIpRanges("Basic", pod.guestIpRanges, zoneId, podId)

            self.createClusters(pod.clusters, zoneId, podId)


    def createVlanIpRanges(self, mode, ipranges, zoneId, podId=None, networkId=None):
        '''Create vlan ip ranges; Basic-mode ranges are direct (non-virtual).'''
        if ipranges is None:
            return
        for iprange in ipranges:
            vlanipcmd = createVlanIpRange.createVlanIpRangeCmd()
            vlanipcmd.account = iprange.account
            # NOTE(review): the iprange config class defines "domain", not
            # "domainid", so this is always None -- confirm intent.
            vlanipcmd.domainid = iprange.domainid
            vlanipcmd.endip = iprange.endip
            vlanipcmd.gateway = iprange.gateway
            vlanipcmd.netmask = iprange.netmask
            vlanipcmd.networkid = networkId
            vlanipcmd.podid = podId
            vlanipcmd.startip = iprange.startip
            vlanipcmd.zoneid = zoneId
            vlanipcmd.vlan = iprange.vlan
            if mode == "Basic":
                vlanipcmd.forvirtualnetwork = "false"
            else:
                vlanipcmd.forvirtualnetwork = "true"

            self.apiClient.createVlanIpRange(vlanipcmd)

    def createSecondaryStorages(self, secondaryStorages, zoneId):
        '''Register the zone's secondary storages.'''
        if secondaryStorages is None:
            return
        for secondary in secondaryStorages:
            secondarycmd = addSecondaryStorage.addSecondaryStorageCmd()
            secondarycmd.url = secondary.url
            secondarycmd.zoneid = zoneId
            self.apiClient.addSecondaryStorage(secondarycmd)

    def createnetworks(self, networks, zoneId):
        '''Create guest networks; the first popped ip range seeds the
        network, remaining ranges become extra vlan ip ranges on it.'''
        if networks is None:
            return
        for network in networks:
            ipranges = network.ipranges
            if ipranges is None:
                continue
            iprange = ipranges.pop()
            networkcmd = createNetwork.createNetworkCmd()
            networkcmd.account = network.account
            networkcmd.displaytext = network.displaytext
            networkcmd.domainid = network.domainid
            networkcmd.endip = iprange.endip
            networkcmd.gateway = iprange.gateway
            networkcmd.isdefault = network.isdefault
            networkcmd.isshared = network.isshared
            networkcmd.name = network.name
            networkcmd.netmask = iprange.netmask
            networkcmd.networkdomain = network.networkdomain
            networkcmd.networkofferingid = network.networkofferingid
            networkcmdresponse = self.apiClient.createNetwork(networkcmd)
            networkId = networkcmdresponse.id
            self.createVlanIpRanges("Advanced", ipranges, zoneId, networkId=networkId)

    def createZones(self, zones):
        '''Create each zone with its pods, public ranges, guest networks
        and secondary storage.'''
        for zone in zones:
            '''create a zone'''
            createzone = createZone.createZoneCmd()
            createzone.guestcidraddress = zone.guestcidraddress
            createzone.dns1 = zone.dns1
            createzone.dns2 = zone.dns2
            createzone.internaldns1 = zone.internaldns1
            createzone.internaldns2 = zone.internaldns2
            createzone.name = zone.name
            createzone.securitygroupenabled = zone.securitygroupenabled
            createzone.networktype = zone.networktype
            createzone.vlan = zone.vlan

            zoneresponse = self.apiClient.createZone(createzone)
            zoneId = zoneresponse.id

            '''create pods'''
            self.createpods(zone.pods, zone, zoneId)

            if zone.networktype == "Advanced":
                '''create pubic network'''
                self.createVlanIpRanges(zone.networktype, zone.ipranges, zoneId)

                self.createnetworks(zone.networks, zoneId)
            '''create secondary storage'''
            self.createSecondaryStorages(zone.secondaryStorages, zoneId)
        return

    def registerApiKey(self):
        '''Ensure the admin user has api/secret keys (generating them if
        missing) and record them, plus the authenticated port, in config.'''
        listuser = listUsers.listUsersCmd()
        listuser.account = "admin"
        listuserRes = self.testClient.getApiClient().listUsers(listuser)
        userId = listuserRes[0].id
        apiKey = listuserRes[0].apikey
        securityKey = listuserRes[0].secretkey
        if apiKey is None:
            registerUser = registerUserKeys.registerUserKeysCmd()
            registerUser.id = userId
            registerUserRes = self.testClient.getApiClient().registerUserKeys(registerUser)
            apiKey = registerUserRes.apikey
            securityKey = registerUserRes.secretkey

        self.config.mgtSvr[0].port = 8080
        self.config.mgtSvr[0].apiKey = apiKey
        self.config.mgtSvr[0].securityKey = securityKey
        return apiKey, securityKey

    def loadCfg(self):
        '''Load the json config, set up loggers, build an authenticated
        test client and configure its database connection.'''
        try:
            self.config = configGenerator.get_setup_config(self.configFile)
        except:
            # Fixed: the old code concatenated a str with the exc_info tuple,
            # which itself raised TypeError and masked the real error.
            raise cloudstackException.InvalidParameterException( \
                "Failed to load config: %s" % str(sys.exc_info()))

        mgt = self.config.mgtSvr[0]

        loggers = self.config.logger
        testClientLogFile = None
        self.testCaseLogFile = None
        self.testResultLogFile = None
        if loggers is not None and len(loggers) > 0:
            for log in loggers:
                if log.name == "TestClient":
                    testClientLogFile = log.file
                elif log.name == "TestCase":
                    self.testCaseLogFile = log.file
                elif log.name == "TestResult":
                    self.testResultLogFile = log.file

        testClientLogger = None
        if testClientLogFile is not None:
            testClientLogger = logging.getLogger("testclient.deploy.deployDataCenters")
            fh = logging.FileHandler(testClientLogFile)
            fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(name)s - %(message)s"))
            testClientLogger.addHandler(fh)
            testClientLogger.setLevel(logging.DEBUG)
        self.testClientLogger = testClientLogger

        self.testClient = cloudstackTestClient.cloudstackTestClient(mgt.mgtSvrIp, mgt.port, mgt.apiKey, mgt.securityKey, logging=self.testClientLogger)
        if mgt.apiKey is None:
            # bootstrap over the integration port, then reconnect with keys
            apiKey, securityKey = self.registerApiKey()
            self.testClient.close()
            self.testClient = cloudstackTestClient.cloudstackTestClient(mgt.mgtSvrIp, 8080, apiKey, securityKey, logging=self.testClientLogger)

        '''config database'''
        dbSvr = self.config.dbSvr
        self.testClient.dbConfigure(dbSvr.dbSvr, dbSvr.port, dbSvr.user, dbSvr.passwd, dbSvr.db)
        self.apiClient = self.testClient.getApiClient()

    def updateConfiguration(self, globalCfg):
        '''Apply global configuration settings through the API.'''
        if globalCfg is None:
            return None

        for config in globalCfg:
            updateCfg = updateConfiguration.updateConfigurationCmd()
            updateCfg.name = config.name
            updateCfg.value = config.value
            self.apiClient.updateConfiguration(updateCfg)

    def deploy(self):
        '''Load the configuration, build all zones, then apply globals.'''
        self.loadCfg()
        self.createZones(self.config.zones)
        self.updateConfiguration(self.config.globalConfig)
+
+
+if __name__ == "__main__":
+
+ parser = OptionParser()
+
+ parser.add_option("-i", "--intput", action="store", default="./datacenterCfg", dest="input", help="the path where the json config file generated, by default is ./datacenterCfg")
+
+ (options, args) = parser.parse_args()
+
+ deploy = deployDataCenters(options.input)
+ deploy.deploy()
+
+ '''
+ create = createStoragePool.createStoragePoolCmd()
+ create.clusterid = 1
+ create.podid = 2
+ create.name = "fdffdf"
+ create.url = "nfs://jfkdjf/fdkjfkd"
+ create.zoneid = 2
+
+ deploy = deployDataCenters("./datacenterCfg")
+ deploy.loadCfg()
+ deploy.apiClient.createStoragePool(create)
+ '''
diff --git a/tools/marvin/marvin/jsonHelper.py b/tools/marvin/marvin/jsonHelper.py
new file mode 100644
index 00000000000..6bf6d056f1f
--- /dev/null
+++ b/tools/marvin/marvin/jsonHelper.py
@@ -0,0 +1,174 @@
+import cloudstackException
+import json
+import inspect
+from cloudstackAPI import *
+import pdb
+
class jsonLoader:
    '''Wraps a decoded json dict so its entries read as object attributes.

    Nested dicts become nested jsonLoader instances, lists/tuples are
    converted element-wise, and missing attributes resolve to None.'''
    def __init__(self, obj):
        for key in obj:
            value = obj[key]
            if isinstance(value, dict):
                setattr(self, key, jsonLoader(value))
            elif isinstance(value, (list, tuple)):
                setattr(self, key, [jsonLoader(item) for item in value])
            else:
                setattr(self, key, value)

    def __getattr__(self, val):
        # only invoked for names not found normally; absent keys -> None
        return self.__dict__.get(val)

    def __repr__(self):
        pairs = ', '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.iteritems())
        return '{%s}' % str(pairs)

    def __str__(self):
        pairs = ', '.join('%s : %s' % (k, repr(v)) for (k, v) in self.__dict__.iteritems())
        return '{%s}' % str(pairs)
+
+
class jsonDump:
    '''Converts arbitrary (nested) objects into json-serializable
    primitives; plain objects are serialized via their __dict__.'''
    @staticmethod
    def __serialize(obj):
        """Recursively walk object's hierarchy."""
        # NOTE: long/basestring make this module Python-2 specific
        if isinstance(obj, (bool, int, long, float, basestring)):
            return obj
        elif isinstance(obj, dict):
            obj = obj.copy()
            newobj = {}
            for key in obj:
                # drop None values and empty lists to keep output compact
                if obj[key] is not None:
                    if (isinstance(obj[key], list) and len(obj[key]) == 0):
                        continue
                    newobj[key] = jsonDump.__serialize(obj[key])

            return newobj
        elif isinstance(obj, list):
            return [jsonDump.__serialize(item) for item in obj]
        elif isinstance(obj, tuple):
            return tuple(jsonDump.__serialize([item for item in obj]))
        elif hasattr(obj, '__dict__'):
            return jsonDump.__serialize(obj.__dict__)
        else:
            return repr(obj) # Don't know how to handle, convert to string

    @staticmethod
    def dump(obj):
        '''Public entry point: return a json-serializable copy of *obj*.'''
        return jsonDump.__serialize(obj)
+
def getclassFromName(cmd, name):
    '''Instantiate class *name* from the module that defines *cmd*.

    Note: returns a new instance (trailing "()"), not the class object;
    raises if *cmd* has no discoverable module or the name is absent.'''
    module = inspect.getmodule(cmd)
    return getattr(module, name)()
+
def finalizeResultObj(result, responseName, responsecls):
    '''Post-process a jsonLoader *result*: unwrap nested payloads so the
    caller receives the object matching *responsecls* where possible.'''
    if responsecls is None and responseName.endswith("response") and responseName != "queryasyncjobresultresponse":
        '''infer the response class from the name'''
        moduleName = responseName.replace("response", "")
        try:
            # NOTE(review): getclassFromName expects an object whose module
            # is discoverable via inspect.getmodule; passing the bare string
            # moduleName likely yields None and always falls into the
            # except below -- confirm intended behavior.
            responsecls = getclassFromName(moduleName, responseName)
        except:
            pass

    if responseName is not None and responseName == "queryasyncjobresultresponse" and responsecls is not None and result.jobresult is not None:
        # async job: recursively finalize the embedded job result payload
        result.jobresult = finalizeResultObj(result.jobresult, None, responsecls)
        return result
    elif responsecls is not None:
        # if the result already exposes an attribute of the response class,
        # it is the final object
        for k,v in result.__dict__.iteritems():
            if k in responsecls.__dict__:
                return result

        attr = result.__dict__.keys()[0]

        value = getattr(result, attr)
        if not isinstance(value, jsonLoader):
            return result

        # otherwise unwrap one level if the nested object matches
        findObj = False
        for k,v in value.__dict__.iteritems():
            if k in responsecls.__dict__:
                findObj = True
                break
        if findObj:
            return value
        else:
            return result
    else:
        return result
+
+
+
def getResultObj(returnObj, responsecls=None):
    '''Parse a raw json API response string into result objects.

    Returns None for empty responses, raises cloudstackAPIException when
    the payload carries an errorcode, unwraps "count"-style list
    responses, and otherwise defers to finalizeResultObj.'''
    returnObj = json.loads(returnObj)

    if len(returnObj) == 0:
        return None
    # responses have a single top-level key, e.g. "listzonesresponse"
    responseName = returnObj.keys()[0]

    response = returnObj[responseName]
    if len(response) == 0:
        return None

    result = jsonLoader(response)
    if result.errorcode is not None:
        errMsg = "errorCode: %s, errorText:%s"%(result.errorcode, result.errortext)
        raise cloudstackException.cloudstackAPIException(responseName.replace("response", ""), errMsg)

    if result.count is not None:
        # list-style response: return the list attribute, skipping "count"
        for key in result.__dict__.iterkeys():
            if key == "count":
                continue
            else:
                return getattr(result, key)
    else:
        return finalizeResultObj(result, responseName, responsecls)
+
+if __name__ == "__main__":
+
+ result = '{ "listnetworkserviceprovidersresponse" : { "count":1 ,"networkserviceprovider" : [ {"name":"VirtualRouter","physicalnetworkid":"ad2948fc-1054-46c7-b1c7-61d990b86710","destinationphysicalnetworkid":"0","state":"Disabled","id":"d827cae4-4998-4037-95a2-55b92b6318b1","servicelist":["Vpn","Dhcp","Dns","Gateway","Firewall","Lb","SourceNat","StaticNat","PortForwarding","UserData"]} ] } }'
+ nsp = getResultObj(result)
+
+ result = '{ "listzonesresponse" : { "count":1 ,"zone" : [ {"id":1,"name":"test0","dns1":"8.8.8.8","dns2":"4.4.4.4","internaldns1":"192.168.110.254","internaldns2":"192.168.110.253","networktype":"Basic","securitygroupsenabled":true,"allocationstate":"Enabled","zonetoken":"5e818a11-6b00-3429-9a07-e27511d3169a","dhcpprovider":"DhcpServer"} ] } }'
+ zones = getResultObj(result)
+ print zones[0].id
+ res = authorizeSecurityGroupIngress.authorizeSecurityGroupIngressResponse()
+ result = '{ "queryasyncjobresultresponse" : {"jobid":10,"jobstatus":1,"jobprocstatus":0,"jobresultcode":0,"jobresulttype":"object","jobresult":{"securitygroup":{"id":1,"name":"default","description":"Default Security Group","account":"admin","domainid":1,"domain":"ROOT","ingressrule":[{"ruleid":1,"protocol":"tcp","startport":22,"endport":22,"securitygroupname":"default","account":"a"},{"ruleid":2,"protocol":"tcp","startport":22,"endport":22,"securitygroupname":"default","account":"b"}]}}} }'
+ asynJob = getResultObj(result, res)
+ print asynJob.jobid, repr(asynJob.jobresult)
+ print asynJob.jobresult.ingressrule[0].account
+
+ result = '{ "queryasyncjobresultresponse" : {"errorcode" : 431, "errortext" : "Unable to execute API command queryasyncjobresultresponse due to missing parameter jobid"} }'
+ try:
+ asynJob = getResultObj(result)
+ except cloudstackException.cloudstackAPIException, e:
+ print e
+
+ result = '{ "queryasyncjobresultresponse" : {} }'
+ asynJob = getResultObj(result)
+ print asynJob
+
+ result = '{}'
+ asynJob = getResultObj(result)
+ print asynJob
+
+ result = '{ "createzoneresponse" : { "zone" : {"id":1,"name":"test0","dns1":"8.8.8.8","dns2":"4.4.4.4","internaldns1":"192.168.110.254","internaldns2":"192.168.110.253","networktype":"Basic","securitygroupsenabled":true,"allocationstate":"Enabled","zonetoken":"3442f287-e932-3111-960b-514d1f9c4610","dhcpprovider":"DhcpServer"} } }'
+ res = createZone.createZoneResponse()
+ zone = getResultObj(result, res)
+ print zone.id
+
+ result = '{ "attachvolumeresponse" : {"jobid":24} }'
+ res = attachVolume.attachVolumeResponse()
+ res = getResultObj(result, res)
+ print res
+
+ result = '{ "listtemplatesresponse" : { } }'
+ print getResultObj(result, listTemplates.listTemplatesResponse())
+
+ result = '{ "queryasyncjobresultresponse" : {"jobid":34,"jobstatus":2,"jobprocstatus":0,"jobresultcode":530,"jobresulttype":"object","jobresult":{"errorcode":431,"errortext":"Please provide either a volume id, or a tuple(device id, instance id)"}} }'
+ print getResultObj(result, listTemplates.listTemplatesResponse())
+ result = '{ "queryasyncjobresultresponse" : {"jobid":41,"jobstatus":1,"jobprocstatus":0,"jobresultcode":0,"jobresulttype":"object","jobresult":{"virtualmachine":{"id":37,"name":"i-2-37-TEST","displayname":"i-2-37-TEST","account":"admin","domainid":1,"domain":"ROOT","created":"2011-08-25T11:13:42-0700","state":"Running","haenable":false,"zoneid":1,"zonename":"test0","hostid":5,"hostname":"SimulatedAgent.1e629060-f547-40dd-b792-57cdc4b7d611","templateid":10,"templatename":"CentOS 5.3(64-bit) no GUI (Simulator)","templatedisplaytext":"CentOS 5.3(64-bit) no GUI (Simulator)","passwordenabled":false,"serviceofferingid":7,"serviceofferingname":"Small Instance","cpunumber":1,"cpuspeed":500,"memory":512,"guestosid":11,"rootdeviceid":0,"rootdevicetype":"NetworkFilesystem","securitygroup":[{"id":1,"name":"default","description":"Default Security Group"}],"nic":[{"id":43,"networkid":204,"netmask":"255.255.255.0","gateway":"192.168.1.1","ipaddress":"192.168.1.27","isolationuri":"ec2://untagged","broadcasturi":"vlan://untagged","traffictype":"Guest","type":"Direct","isdefault":true,"macaddress":"06:56:b8:00:00:53"}],"hypervisor":"Simulator"}}} }'
+ vm = getResultObj(result, deployVirtualMachine.deployVirtualMachineResponse())
+ print vm.jobresult.id
+
+ cmd = deployVirtualMachine.deployVirtualMachineCmd()
+ responsename = cmd.__class__.__name__.replace("Cmd", "Response")
+ response = getclassFromName(cmd, responsename)
+ print response.id
diff --git a/tools/marvin/marvin/pymysql/__init__.py b/tools/marvin/marvin/pymysql/__init__.py
new file mode 100644
index 00000000000..903107e539a
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/__init__.py
@@ -0,0 +1,131 @@
+'''
+PyMySQL: A pure-Python drop-in replacement for MySQLdb.
+
+Copyright (c) 2010 PyMySQL contributors
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+'''
+
+VERSION = (0, 4, None)
+
+from constants import FIELD_TYPE
+from converters import escape_dict, escape_sequence, escape_string
+from err import Warning, Error, InterfaceError, DataError, \
+ DatabaseError, OperationalError, IntegrityError, InternalError, \
+ NotSupportedError, ProgrammingError, MySQLError
+from times import Date, Time, Timestamp, \
+ DateFromTicks, TimeFromTicks, TimestampFromTicks
+
+import sys
+
+try:
+ frozenset
+except NameError:
+ from sets import ImmutableSet as frozenset
+ try:
+ from sets import BaseSet as set
+ except ImportError:
+ from sets import Set as set
+
+threadsafety = 1
+apilevel = "2.0"
+paramstyle = "format"
+
class DBAPISet(frozenset):
    """A set of MySQL field-type codes that also compares equal to any
    single member, so MySQLdb-style ``cursor.description[i][1] == STRING``
    checks work."""

    def __ne__(self, other):
        # BUG FIX: the original called super(DBAPISet, self).__ne__(self, other),
        # passing `self` twice to an already-bound method, which raises
        # TypeError for any set != set comparison. Delegate explicitly instead.
        if isinstance(other, set):
            return frozenset.__ne__(self, other)
        else:
            return other not in self

    def __eq__(self, other):
        # Set-vs-set uses ordinary set equality; anything else is treated
        # as a membership test (type_code == STRING).
        if isinstance(other, frozenset):
            return frozenset.__eq__(self, other)
        else:
            return other in self

    def __hash__(self):
        # Defining __eq__ would otherwise suppress the inherited hash.
        return frozenset.__hash__(self)
+
+
+STRING = DBAPISet([FIELD_TYPE.ENUM, FIELD_TYPE.STRING,
+ FIELD_TYPE.VAR_STRING])
+BINARY = DBAPISet([FIELD_TYPE.BLOB, FIELD_TYPE.LONG_BLOB,
+ FIELD_TYPE.MEDIUM_BLOB, FIELD_TYPE.TINY_BLOB])
+NUMBER = DBAPISet([FIELD_TYPE.DECIMAL, FIELD_TYPE.DOUBLE, FIELD_TYPE.FLOAT,
+ FIELD_TYPE.INT24, FIELD_TYPE.LONG, FIELD_TYPE.LONGLONG,
+ FIELD_TYPE.TINY, FIELD_TYPE.YEAR])
+DATE = DBAPISet([FIELD_TYPE.DATE, FIELD_TYPE.NEWDATE])
+TIME = DBAPISet([FIELD_TYPE.TIME])
+TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
+DATETIME = TIMESTAMP
+ROWID = DBAPISet()
+
def Binary(x):
    """Coerce *x* to this driver's binary representation (a string)."""
    result = str(x)
    return result
+
def Connect(*args, **kwargs):
    """
    Connect to the database; see connections.Connection.__init__() for
    more information.
    """
    # NOTE(review): imported lazily, presumably to avoid pulling the socket
    # machinery (and possible import cycles) in at package import time.
    from connections import Connection
    return Connection(*args, **kwargs)
+
def get_client_info(): # for MySQLdb compatibility
    # With VERSION == (0, 4, None) this renders as '0.4.None'; kept as-is
    # because __version__ is derived from it below.
    return '%s.%s.%s' % VERSION
+
+connect = Connection = Connect
+
+# we include a doctored version_info here for MySQLdb compatibility
+version_info = (1,2,2,"final",0)
+
+NULL = "NULL"
+
+__version__ = get_client_info()
+
def thread_safe():
    """Mirror MySQLdb.thread_safe(); this driver is thread-safe."""
    return True
+
def install_as_MySQLdb():
    """
    After this function is called, any application that imports MySQLdb or
    _mysql will unwittingly actually use pymysql instead.
    """
    # Aliases this already-imported package under both legacy module names.
    sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
+
+__all__ = [
+ 'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
+ 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
+ 'DataError', 'DatabaseError', 'Error', 'FIELD_TYPE', 'IntegrityError',
+ 'InterfaceError', 'InternalError', 'MySQLError', 'NULL', 'NUMBER',
+ 'NotSupportedError', 'DBAPISet', 'OperationalError', 'ProgrammingError',
+ 'ROWID', 'STRING', 'TIME', 'TIMESTAMP', 'Warning', 'apilevel', 'connect',
+ 'connections', 'constants', 'converters', 'cursors',
+ 'escape_dict', 'escape_sequence', 'escape_string', 'get_client_info',
+ 'paramstyle', 'threadsafety', 'version_info',
+
+ "install_as_MySQLdb",
+
+ "NULL","__version__",
+ ]
diff --git a/tools/marvin/marvin/pymysql/charset.py b/tools/marvin/marvin/pymysql/charset.py
new file mode 100644
index 00000000000..10a91bd19f2
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/charset.py
@@ -0,0 +1,174 @@
+MBLENGTH = {
+ 8:1,
+ 33:3,
+ 88:2,
+ 91:2
+ }
+
class Charset:
    """One row of MySQL's collation table: numeric id, character-set name,
    collation name, and whether this collation is the set's default."""

    def __init__(self, id, name, collation, is_default):
        self.id = id
        self.name = name
        self.collation = collation
        # MySQL reports the default flag as the string 'Yes' (empty otherwise).
        self.is_default = is_default == 'Yes'
+
class Charsets:
    """Registry of Charset objects, indexed by numeric collation id."""

    def __init__(self):
        self._by_id = {}

    def add(self, c):
        """Register charset *c* under its id (a later add with the same id wins)."""
        self._by_id[c.id] = c

    def by_id(self, id):
        """Return the charset for collation id *id* (KeyError if unknown)."""
        return self._by_id[id]

    def by_name(self, name):
        """Return the default charset named *name*, or None if there is none."""
        for entry in self._by_id.values():
            if entry.name == name and entry.is_default:
                return entry
        return None
+
+_charsets = Charsets()
+"""
+Generated with:
+
+mysql -N -s -e "select id, character_set_name, collation_name, is_default
+from information_schema.collations order by id;" | python -c "import sys
+for l in sys.stdin.readlines():
+ id, name, collation, is_default = l.split(chr(9))
+ print '_charsets.add(Charset(%s, \'%s\', \'%s\', \'%s\'))' \
+ % (id, name, collation, is_default.strip())
+"
+
+"""
+_charsets.add(Charset(1, 'big5', 'big5_chinese_ci', 'Yes'))
+_charsets.add(Charset(2, 'latin2', 'latin2_czech_cs', ''))
+_charsets.add(Charset(3, 'dec8', 'dec8_swedish_ci', 'Yes'))
+_charsets.add(Charset(4, 'cp850', 'cp850_general_ci', 'Yes'))
+_charsets.add(Charset(5, 'latin1', 'latin1_german1_ci', ''))
+_charsets.add(Charset(6, 'hp8', 'hp8_english_ci', 'Yes'))
+_charsets.add(Charset(7, 'koi8r', 'koi8r_general_ci', 'Yes'))
+_charsets.add(Charset(8, 'latin1', 'latin1_swedish_ci', 'Yes'))
+_charsets.add(Charset(9, 'latin2', 'latin2_general_ci', 'Yes'))
+_charsets.add(Charset(10, 'swe7', 'swe7_swedish_ci', 'Yes'))
+_charsets.add(Charset(11, 'ascii', 'ascii_general_ci', 'Yes'))
+_charsets.add(Charset(12, 'ujis', 'ujis_japanese_ci', 'Yes'))
+_charsets.add(Charset(13, 'sjis', 'sjis_japanese_ci', 'Yes'))
+_charsets.add(Charset(14, 'cp1251', 'cp1251_bulgarian_ci', ''))
+_charsets.add(Charset(15, 'latin1', 'latin1_danish_ci', ''))
+_charsets.add(Charset(16, 'hebrew', 'hebrew_general_ci', 'Yes'))
+_charsets.add(Charset(18, 'tis620', 'tis620_thai_ci', 'Yes'))
+_charsets.add(Charset(19, 'euckr', 'euckr_korean_ci', 'Yes'))
+_charsets.add(Charset(20, 'latin7', 'latin7_estonian_cs', ''))
+_charsets.add(Charset(21, 'latin2', 'latin2_hungarian_ci', ''))
+_charsets.add(Charset(22, 'koi8u', 'koi8u_general_ci', 'Yes'))
+_charsets.add(Charset(23, 'cp1251', 'cp1251_ukrainian_ci', ''))
+_charsets.add(Charset(24, 'gb2312', 'gb2312_chinese_ci', 'Yes'))
+_charsets.add(Charset(25, 'greek', 'greek_general_ci', 'Yes'))
+_charsets.add(Charset(26, 'cp1250', 'cp1250_general_ci', 'Yes'))
+_charsets.add(Charset(27, 'latin2', 'latin2_croatian_ci', ''))
+_charsets.add(Charset(28, 'gbk', 'gbk_chinese_ci', 'Yes'))
+_charsets.add(Charset(29, 'cp1257', 'cp1257_lithuanian_ci', ''))
+_charsets.add(Charset(30, 'latin5', 'latin5_turkish_ci', 'Yes'))
+_charsets.add(Charset(31, 'latin1', 'latin1_german2_ci', ''))
+_charsets.add(Charset(32, 'armscii8', 'armscii8_general_ci', 'Yes'))
+_charsets.add(Charset(33, 'utf8', 'utf8_general_ci', 'Yes'))
+_charsets.add(Charset(34, 'cp1250', 'cp1250_czech_cs', ''))
+_charsets.add(Charset(35, 'ucs2', 'ucs2_general_ci', 'Yes'))
+_charsets.add(Charset(36, 'cp866', 'cp866_general_ci', 'Yes'))
+_charsets.add(Charset(37, 'keybcs2', 'keybcs2_general_ci', 'Yes'))
+_charsets.add(Charset(38, 'macce', 'macce_general_ci', 'Yes'))
+_charsets.add(Charset(39, 'macroman', 'macroman_general_ci', 'Yes'))
+_charsets.add(Charset(40, 'cp852', 'cp852_general_ci', 'Yes'))
+_charsets.add(Charset(41, 'latin7', 'latin7_general_ci', 'Yes'))
+_charsets.add(Charset(42, 'latin7', 'latin7_general_cs', ''))
+_charsets.add(Charset(43, 'macce', 'macce_bin', ''))
+_charsets.add(Charset(44, 'cp1250', 'cp1250_croatian_ci', ''))
+_charsets.add(Charset(47, 'latin1', 'latin1_bin', ''))
+_charsets.add(Charset(48, 'latin1', 'latin1_general_ci', ''))
+_charsets.add(Charset(49, 'latin1', 'latin1_general_cs', ''))
+_charsets.add(Charset(50, 'cp1251', 'cp1251_bin', ''))
+_charsets.add(Charset(51, 'cp1251', 'cp1251_general_ci', 'Yes'))
+_charsets.add(Charset(52, 'cp1251', 'cp1251_general_cs', ''))
+_charsets.add(Charset(53, 'macroman', 'macroman_bin', ''))
+_charsets.add(Charset(57, 'cp1256', 'cp1256_general_ci', 'Yes'))
+_charsets.add(Charset(58, 'cp1257', 'cp1257_bin', ''))
+_charsets.add(Charset(59, 'cp1257', 'cp1257_general_ci', 'Yes'))
+_charsets.add(Charset(63, 'binary', 'binary', 'Yes'))
+_charsets.add(Charset(64, 'armscii8', 'armscii8_bin', ''))
+_charsets.add(Charset(65, 'ascii', 'ascii_bin', ''))
+_charsets.add(Charset(66, 'cp1250', 'cp1250_bin', ''))
+_charsets.add(Charset(67, 'cp1256', 'cp1256_bin', ''))
+_charsets.add(Charset(68, 'cp866', 'cp866_bin', ''))
+_charsets.add(Charset(69, 'dec8', 'dec8_bin', ''))
+_charsets.add(Charset(70, 'greek', 'greek_bin', ''))
+_charsets.add(Charset(71, 'hebrew', 'hebrew_bin', ''))
+_charsets.add(Charset(72, 'hp8', 'hp8_bin', ''))
+_charsets.add(Charset(73, 'keybcs2', 'keybcs2_bin', ''))
+_charsets.add(Charset(74, 'koi8r', 'koi8r_bin', ''))
+_charsets.add(Charset(75, 'koi8u', 'koi8u_bin', ''))
+_charsets.add(Charset(77, 'latin2', 'latin2_bin', ''))
+_charsets.add(Charset(78, 'latin5', 'latin5_bin', ''))
+_charsets.add(Charset(79, 'latin7', 'latin7_bin', ''))
+_charsets.add(Charset(80, 'cp850', 'cp850_bin', ''))
+_charsets.add(Charset(81, 'cp852', 'cp852_bin', ''))
+_charsets.add(Charset(82, 'swe7', 'swe7_bin', ''))
+_charsets.add(Charset(83, 'utf8', 'utf8_bin', ''))
+_charsets.add(Charset(84, 'big5', 'big5_bin', ''))
+_charsets.add(Charset(85, 'euckr', 'euckr_bin', ''))
+_charsets.add(Charset(86, 'gb2312', 'gb2312_bin', ''))
+_charsets.add(Charset(87, 'gbk', 'gbk_bin', ''))
+_charsets.add(Charset(88, 'sjis', 'sjis_bin', ''))
+_charsets.add(Charset(89, 'tis620', 'tis620_bin', ''))
+_charsets.add(Charset(90, 'ucs2', 'ucs2_bin', ''))
+_charsets.add(Charset(91, 'ujis', 'ujis_bin', ''))
+_charsets.add(Charset(92, 'geostd8', 'geostd8_general_ci', 'Yes'))
+_charsets.add(Charset(93, 'geostd8', 'geostd8_bin', ''))
+_charsets.add(Charset(94, 'latin1', 'latin1_spanish_ci', ''))
+_charsets.add(Charset(95, 'cp932', 'cp932_japanese_ci', 'Yes'))
+_charsets.add(Charset(96, 'cp932', 'cp932_bin', ''))
+_charsets.add(Charset(97, 'eucjpms', 'eucjpms_japanese_ci', 'Yes'))
+_charsets.add(Charset(98, 'eucjpms', 'eucjpms_bin', ''))
+_charsets.add(Charset(99, 'cp1250', 'cp1250_polish_ci', ''))
+_charsets.add(Charset(128, 'ucs2', 'ucs2_unicode_ci', ''))
+_charsets.add(Charset(129, 'ucs2', 'ucs2_icelandic_ci', ''))
+_charsets.add(Charset(130, 'ucs2', 'ucs2_latvian_ci', ''))
+_charsets.add(Charset(131, 'ucs2', 'ucs2_romanian_ci', ''))
+_charsets.add(Charset(132, 'ucs2', 'ucs2_slovenian_ci', ''))
+_charsets.add(Charset(133, 'ucs2', 'ucs2_polish_ci', ''))
+_charsets.add(Charset(134, 'ucs2', 'ucs2_estonian_ci', ''))
+_charsets.add(Charset(135, 'ucs2', 'ucs2_spanish_ci', ''))
+_charsets.add(Charset(136, 'ucs2', 'ucs2_swedish_ci', ''))
+_charsets.add(Charset(137, 'ucs2', 'ucs2_turkish_ci', ''))
+_charsets.add(Charset(138, 'ucs2', 'ucs2_czech_ci', ''))
+_charsets.add(Charset(139, 'ucs2', 'ucs2_danish_ci', ''))
+_charsets.add(Charset(140, 'ucs2', 'ucs2_lithuanian_ci', ''))
+_charsets.add(Charset(141, 'ucs2', 'ucs2_slovak_ci', ''))
+_charsets.add(Charset(142, 'ucs2', 'ucs2_spanish2_ci', ''))
+_charsets.add(Charset(143, 'ucs2', 'ucs2_roman_ci', ''))
+_charsets.add(Charset(144, 'ucs2', 'ucs2_persian_ci', ''))
+_charsets.add(Charset(145, 'ucs2', 'ucs2_esperanto_ci', ''))
+_charsets.add(Charset(146, 'ucs2', 'ucs2_hungarian_ci', ''))
+_charsets.add(Charset(192, 'utf8', 'utf8_unicode_ci', ''))
+_charsets.add(Charset(193, 'utf8', 'utf8_icelandic_ci', ''))
+_charsets.add(Charset(194, 'utf8', 'utf8_latvian_ci', ''))
+_charsets.add(Charset(195, 'utf8', 'utf8_romanian_ci', ''))
+_charsets.add(Charset(196, 'utf8', 'utf8_slovenian_ci', ''))
+_charsets.add(Charset(197, 'utf8', 'utf8_polish_ci', ''))
+_charsets.add(Charset(198, 'utf8', 'utf8_estonian_ci', ''))
+_charsets.add(Charset(199, 'utf8', 'utf8_spanish_ci', ''))
+_charsets.add(Charset(200, 'utf8', 'utf8_swedish_ci', ''))
+_charsets.add(Charset(201, 'utf8', 'utf8_turkish_ci', ''))
+_charsets.add(Charset(202, 'utf8', 'utf8_czech_ci', ''))
+_charsets.add(Charset(203, 'utf8', 'utf8_danish_ci', ''))
+_charsets.add(Charset(204, 'utf8', 'utf8_lithuanian_ci', ''))
+_charsets.add(Charset(205, 'utf8', 'utf8_slovak_ci', ''))
+_charsets.add(Charset(206, 'utf8', 'utf8_spanish2_ci', ''))
+_charsets.add(Charset(207, 'utf8', 'utf8_roman_ci', ''))
+_charsets.add(Charset(208, 'utf8', 'utf8_persian_ci', ''))
+_charsets.add(Charset(209, 'utf8', 'utf8_esperanto_ci', ''))
+_charsets.add(Charset(210, 'utf8', 'utf8_hungarian_ci', ''))
+
def charset_by_name(name):
    # Default collation entry for the named character set (None if unknown).
    return _charsets.by_name(name)

def charset_by_id(id):
    # Charset entry for a numeric collation id (KeyError if unknown).
    return _charsets.by_id(id)
+
diff --git a/tools/marvin/marvin/pymysql/connections.py b/tools/marvin/marvin/pymysql/connections.py
new file mode 100644
index 00000000000..8897644ab09
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/connections.py
@@ -0,0 +1,928 @@
+# Python implementation of the MySQL client-server protocol
+# http://forge.mysql.com/wiki/MySQL_Internals_ClientServer_Protocol
+
+try:
+ import hashlib
+ sha_new = lambda *args, **kwargs: hashlib.new("sha1", *args, **kwargs)
+except ImportError:
+ import sha
+ sha_new = sha.new
+
+import socket
+try:
+ import ssl
+ SSL_ENABLED = True
+except ImportError:
+ SSL_ENABLED = False
+
+import struct
+import sys
+import os
+import ConfigParser
+
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+
+from charset import MBLENGTH, charset_by_name, charset_by_id
+from cursors import Cursor
+from constants import FIELD_TYPE, FLAG
+from constants import SERVER_STATUS
+from constants.CLIENT import *
+from constants.COMMAND import *
+from util import join_bytes, byte2int, int2byte
+from converters import escape_item, encoders, decoders
+from err import raise_mysql_exception, Warning, Error, \
+ InterfaceError, DataError, DatabaseError, OperationalError, \
+ IntegrityError, InternalError, NotSupportedError, ProgrammingError
+
+DEBUG = False
+
+NULL_COLUMN = 251
+UNSIGNED_CHAR_COLUMN = 251
+UNSIGNED_SHORT_COLUMN = 252
+UNSIGNED_INT24_COLUMN = 253
+UNSIGNED_INT64_COLUMN = 254
+UNSIGNED_CHAR_LENGTH = 1
+UNSIGNED_SHORT_LENGTH = 2
+UNSIGNED_INT24_LENGTH = 3
+UNSIGNED_INT64_LENGTH = 8
+
+DEFAULT_CHARSET = 'latin1'
+
+
def dump_packet(data):
    """Print a hex/ASCII dump of a wire packet plus the calling stack (DEBUG aid)."""

    def is_ascii(data):
        # Render bytes with codes 65..122 as themselves, everything else as '.'.
        if byte2int(data) >= 65 and byte2int(data) <= 122: #data.isalnum():
            return data
        return '.'
    print "packet length %d" % len(data)
    print "method call[1]: %s" % sys._getframe(1).f_code.co_name
    print "method call[2]: %s" % sys._getframe(2).f_code.co_name
    print "method call[3]: %s" % sys._getframe(3).f_code.co_name
    print "method call[4]: %s" % sys._getframe(4).f_code.co_name
    print "method call[5]: %s" % sys._getframe(5).f_code.co_name
    print "-" * 88
    # 16-byte rows: hex column, padding out to 16 slots, then ASCII rendering.
    dump_data = [data[i:i+16] for i in xrange(len(data)) if i%16 == 0]
    for d in dump_data:
        print ' '.join(map(lambda x:"%02X" % byte2int(x), d)) + \
            ' ' * (16 - len(d)) + ' ' * 2 + \
            ' '.join(map(lambda x:"%s" % is_ascii(x), d))
    print "-" * 88
    print ""
+
def _scramble(password, message):
    """MySQL 4.1+ auth response: SHA1-scramble *password* with the server
    salt *message*; empty/None password sends a single NUL byte."""
    if password == None or len(password) == 0:
        return int2byte(0)
    if DEBUG: print 'password=' + password
    stage1 = sha_new(password).digest()
    stage2 = sha_new(stage1).digest()
    s = sha_new()
    s.update(message)
    s.update(stage2)
    result = s.digest()
    # XOR the digest against stage1 so the server can verify without the
    # cleartext password.
    return _my_crypt(result, stage1)
+
def _my_crypt(message1, message2):
    """XOR two equal-length byte strings, prefixed with a length byte."""
    length = len(message1)
    result = struct.pack('B', length)
    for i in xrange(length):
        # Unpack one byte from each message and XOR them together.
        x = (struct.unpack('B', message1[i:i+1])[0] ^ \
             struct.unpack('B', message2[i:i+1])[0])
        result += struct.pack('B', x)
    return result
+
+# old_passwords support ported from libmysql/password.c
+SCRAMBLE_LENGTH_323 = 8
+
class RandStruct_323(object):
    # Linear congruential PRNG matching libmysql's randominit()/my_rnd(),
    # used only by the pre-4.1 ("old_passwords") scramble below.
    def __init__(self, seed1, seed2):
        self.max_value = 0x3FFFFFFFL
        self.seed1 = seed1 % self.max_value
        self.seed2 = seed2 % self.max_value

    def my_rnd(self):
        # Returns a float in [0, 1); both seeds advance on every call.
        self.seed1 = (self.seed1 * 3L + self.seed2) % self.max_value
        self.seed2 = (self.seed1 + self.seed2 + 33L) % self.max_value
        return float(self.seed1) / float(self.max_value)
+
def _scramble_323(password, message):
    """Pre-4.1 ("old password") scramble: seed a shared PRNG from the hashed
    password and server challenge, then emit obfuscated printable bytes."""
    hash_pass = _hash_password_323(password)
    hash_message = _hash_password_323(message[:SCRAMBLE_LENGTH_323])
    hash_pass_n = struct.unpack(">LL", hash_pass)
    hash_message_n = struct.unpack(">LL", hash_message)

    rand_st = RandStruct_323(hash_pass_n[0] ^ hash_message_n[0],
                             hash_pass_n[1] ^ hash_message_n[1])
    outbuf = StringIO.StringIO()
    # +64 keeps every output byte in the printable range, per the protocol.
    for _ in xrange(min(SCRAMBLE_LENGTH_323, len(message))):
        outbuf.write(int2byte(int(rand_st.my_rnd() * 31) + 64))
    extra = int2byte(int(rand_st.my_rnd() * 31))
    out = outbuf.getvalue()
    outbuf = StringIO.StringIO()
    # Final pass XORs every byte with one extra PRNG byte.
    for c in out:
        outbuf.write(int2byte(byte2int(c) ^ byte2int(extra)))
    return outbuf.getvalue()
+
def _hash_password_323(password):
    """libmysql hash_password(): fold *password* (spaces/tabs ignored) into
    two 31-bit words, packed big-endian."""
    nr = 1345345333L
    add = 7L
    nr2 = 0x12345671L

    # Spaces and tabs are skipped, matching the server's implementation.
    for c in [byte2int(x) for x in password if x not in (' ', '\t')]:
        nr^= (((nr & 63)+add)*c)+ (nr << 8) & 0xFFFFFFFF
        nr2= (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
        add= (add + c) & 0xFFFFFFFF

    r1 = nr & ((1L << 31) - 1L) # kill sign bits
    r2 = nr2 & ((1L << 31) - 1L)

    # pack
    return struct.pack(">LL", r1, r2)
+
def pack_int24(n):
    """Pack *n* as a little-endian 3-byte (24-bit) integer; bits above 24
    are silently discarded."""
    low = n & 0xFF
    mid = (n >> 8) & 0xFF
    high = (n >> 16) & 0xFF
    return struct.pack('BBB', low, mid, high)
+
+def unpack_uint16(n):
+ return struct.unpack(' len(self.__data):
+ raise Exception('Invalid advance amount (%s) for cursor. '
+ 'Position=%s' % (length, new_position))
+ self.__position = new_position
+
    def rewind(self, position=0):
        """Set the position of the data buffer cursor to 'position'."""
        # Valid positions run from 0 through len(data) inclusive (== at end).
        if position < 0 or position > len(self.__data):
            raise Exception("Invalid position to rewind cursor to: %s." % position)
        self.__position = position

    def peek(self, size):
        """Look at the first 'size' bytes in packet without moving cursor."""
        result = self.__data[self.__position:(self.__position+size)]
        if len(result) != size:
            # A short read means the caller asked for more bytes than remain.
            error = ('Result length not requested length:\n'
                     'Expected=%s. Actual=%s. Position: %s. Data Length: %s'
                     % (size, len(result), self.__position, len(self.__data)))
            if DEBUG:
                print error
                self.dump()
            raise AssertionError(error)
        return result

    def get_bytes(self, position, length=1):
        """Get 'length' bytes starting at 'position'.

        Position is start of payload (first four packet header bytes are not
        included) starting at index '0'.

        No error checking is done.  If requesting outside end of buffer
        an empty string (or string shorter than 'length') may be returned!
        """
        return self.__data[position:(position+length)]

    def read_length_coded_binary(self):
        """Read a 'Length Coded Binary' number from the data buffer.

        Length coded numbers can be anywhere from 1 to 9 bytes depending
        on the value of the first byte.
        """
        c = byte2int(self.read(1))
        if c == NULL_COLUMN:
            return None
        if c < UNSIGNED_CHAR_COLUMN:
            return c
        elif c == UNSIGNED_SHORT_COLUMN:
            return unpack_uint16(self.read(UNSIGNED_SHORT_LENGTH))
        elif c == UNSIGNED_INT24_COLUMN:
            return unpack_int24(self.read(UNSIGNED_INT24_LENGTH))
        elif c == UNSIGNED_INT64_COLUMN:
            # TODO: what was 'longlong'? confirm it wasn't used?
            return unpack_int64(self.read(UNSIGNED_INT64_LENGTH))
        # NOTE(review): c == 255 (the error-packet marker) falls through and
        # returns None implicitly.

    def read_length_coded_string(self):
        """Read a 'Length Coded String' from the data buffer.

        A 'Length Coded String' consists first of a length coded
        (unsigned, positive) integer represented in 1-9 bytes followed by
        that many bytes of binary data.  (For example "cat" would be "3cat".)
        """
        length = self.read_length_coded_binary()
        if length is None:
            # A NULL column yields a None string, not an empty one.
            return None
        return self.read(length)

    def is_ok_packet(self):
        # OK packets announce themselves with a 0x00 field-count byte.
        return byte2int(self.get_bytes(0)) == 0

    def is_eof_packet(self):
        return byte2int(self.get_bytes(0)) == 254 # 'fe'

    def is_resultset_packet(self):
        # 1..250 in the first byte is the column count of a result set header.
        field_count = byte2int(self.get_bytes(0))
        return field_count >= 1 and field_count <= 250

    def is_error_packet(self):
        # Error packets start with 0xff.
        return byte2int(self.get_bytes(0)) == 255

    def check_error(self):
        """Raise the mapped pymysql exception if this is an error packet."""
        if self.is_error_packet():
            self.rewind()
            self.advance(1) # field_count == error (we already know that)
            errno = unpack_uint16(self.read(2))
            if DEBUG: print "errno = %d" % errno
            raise_mysql_exception(self.__data)

    def dump(self):
        # Hex-dump the raw payload via the module-level helper.
        dump_packet(self.__data)
+
+
+class FieldDescriptorPacket(MysqlPacket):
+ """A MysqlPacket that represents a specific column's metadata in the result.
+
+ Parsing is automatically done and the results are exported via public
+ attributes on the class such as: db, table_name, name, length, type_code.
+ """
+
    def __init__(self, *args):
        # Parse eagerly so the metadata attributes (db, table_name, name,
        # type_code, ...) exist as soon as the packet object is constructed.
        MysqlPacket.__init__(self, *args)
        self.__parse_field_descriptor()
+
+ def __parse_field_descriptor(self):
+ """Parse the 'Field Descriptor' (Metadata) packet.
+
+ This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
+ """
+ self.catalog = self.read_length_coded_string()
+ self.db = self.read_length_coded_string()
+ self.table_name = self.read_length_coded_string()
+ self.org_table = self.read_length_coded_string()
+ self.name = self.read_length_coded_string().decode(self.connection.charset)
+ self.org_name = self.read_length_coded_string()
+ self.advance(1) # non-null filler
+ self.charsetnr = struct.unpack(' 2:
+ use_unicode = True
+
+ if compress or named_pipe:
+ raise NotImplementedError, "compress and named_pipe arguments are not supported"
+
+ if ssl and (ssl.has_key('capath') or ssl.has_key('cipher')):
+ raise NotImplementedError, 'ssl options capath and cipher are not supported'
+
+ self.ssl = False
+ if ssl:
+ if not SSL_ENABLED:
+ raise NotImplementedError, "ssl module not found"
+ self.ssl = True
+ client_flag |= SSL
+ for k in ('key', 'cert', 'ca'):
+ v = None
+ if ssl.has_key(k):
+ v = ssl[k]
+ setattr(self, k, v)
+
+ if read_default_group and not read_default_file:
+ if sys.platform.startswith("win"):
+ read_default_file = "c:\\my.ini"
+ else:
+ read_default_file = "/etc/my.cnf"
+
+ if read_default_file:
+ if not read_default_group:
+ read_default_group = "client"
+
+ cfg = ConfigParser.RawConfigParser()
+ cfg.read(os.path.expanduser(read_default_file))
+
+ def _config(key, default):
+ try:
+ return cfg.get(read_default_group,key)
+ except:
+ return default
+
+ user = _config("user",user)
+ passwd = _config("password",passwd)
+ host = _config("host", host)
+ db = _config("db",db)
+ unix_socket = _config("socket",unix_socket)
+ port = _config("port", port)
+ charset = _config("default-character-set", charset)
+
+ self.host = host
+ self.port = port
+ self.user = user
+ self.password = passwd
+ self.db = db
+ self.unix_socket = unix_socket
+ if charset:
+ self.charset = charset
+ self.use_unicode = True
+ else:
+ self.charset = DEFAULT_CHARSET
+ self.use_unicode = False
+
+ if use_unicode is not None:
+ self.use_unicode = use_unicode
+
+ client_flag |= CAPABILITIES
+ client_flag |= MULTI_STATEMENTS
+ if self.db:
+ client_flag |= CONNECT_WITH_DB
+ self.client_flag = client_flag
+
+ self.cursorclass = cursorclass
+ self.connect_timeout = connect_timeout
+
+ self._connect()
+
+ self.messages = []
+ self.set_charset(charset)
+ self.encoders = encoders
+ self.decoders = conv
+
+ self._result = None
+ self._affected_rows = 0
+ self.host_info = "Not connected"
+
+ self.autocommit(False)
+
+ if sql_mode is not None:
+ c = self.cursor()
+ c.execute("SET sql_mode=%s", (sql_mode,))
+
+ self.commit()
+
+ if init_command is not None:
+ c = self.cursor()
+ c.execute(init_command)
+
+ self.commit()
+
+
+ def close(self):
+ ''' Send the quit message and close the socket '''
+ if self.socket is None:
+ raise Error("Already closed")
+ send_data = struct.pack('= i + 1:
+ i += 1
+
+ self.server_capabilities = struct.unpack('= i+12-1:
+ rest_salt = data[i:i+12]
+ self.salt += rest_salt
+
    def get_server_info(self):
        # Version string from the server -- presumably captured from the
        # initial greeting in _get_server_information(); confirm upstream.
        return self.server_version
+
+ Warning = Warning
+ Error = Error
+ InterfaceError = InterfaceError
+ DatabaseError = DatabaseError
+ DataError = DataError
+ OperationalError = OperationalError
+ IntegrityError = IntegrityError
+ InternalError = InternalError
+ ProgrammingError = ProgrammingError
+ NotSupportedError = NotSupportedError
+
+# TODO: move OK and EOF packet parsing/logic into a proper subclass
+# of MysqlPacket like has been done with FieldDescriptorPacket.
+class MySQLResult(object):
+
    def __init__(self, connection):
        """Holds the parsed outcome of one statement (OK packet or result set)."""
        # weakref.proxy avoids a Connection <-> MySQLResult reference cycle.
        from weakref import proxy
        self.connection = proxy(connection)
        self.affected_rows = None   # from the OK packet
        self.insert_id = None       # last AUTO_INCREMENT id, from the OK packet
        self.server_status = 0
        self.warning_count = 0
        self.message = None
        self.field_count = 0        # number of columns when a result set follows
        self.description = None     # DB-API 7-tuples, filled by result parsing
        self.rows = None
        self.has_next = None        # more result sets pending (multi-statement)
+
    def read(self):
        """Read the server's first response packet and dispatch on its type:
        an OK packet carries counts/status, anything else starts a result set."""
        self.first_packet = self.connection.read_packet()

        # TODO: use classes for different packet types?
        if self.first_packet.is_ok_packet():
            self._read_ok_packet()
        else:
            self._read_result_packet()
+
+ def _read_ok_packet(self):
+ self.first_packet.advance(1) # field_count (always '0')
+ self.affected_rows = self.first_packet.read_length_coded_binary()
+ self.insert_id = self.first_packet.read_length_coded_binary()
+ self.server_status = struct.unpack(' 2
+
+try:
+ set
+except NameError:
+ try:
+ from sets import BaseSet as set
+ except ImportError:
+ from sets import Set as set
+
+ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
+ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
+ '\'': '\\\'', '"': '\\"', '\\': '\\\\'}
+
def escape_item(val, charset):
    """Convert one Python value to its SQL literal text, encoded in *charset*
    when the encoder produced a unicode result."""
    # Containers recurse; everything else goes through the encoders table.
    if type(val) in [tuple, list, set]:
        return escape_sequence(val, charset)
    if type(val) is dict:
        return escape_dict(val, charset)
    if PYTHON3 and hasattr(val, "decode") and not isinstance(val, unicode):
        # deal with py3k bytes
        val = val.decode(charset)
    encoder = encoders[type(val)]   # KeyError here means an unsupported type
    val = encoder(val)
    if type(val) is str:
        return val
    val = val.encode(charset)
    return val
+
def escape_dict(val, charset):
    """Escape every value of *val*, returning a new dict with the same keys."""
    return dict((key, escape_item(item, charset))
                for key, item in val.items())
+
def escape_sequence(val, charset):
    """Render a tuple/list/set as a parenthesised, comma-joined SQL value list."""
    escaped = [escape_item(item, charset) for item in val]
    return "(" + ",".join(escaped) + ")"
+
def escape_set(val, charset):
    """Render a SET value as comma-joined escaped items (no parentheses)."""
    return ','.join([escape_item(item, charset) for item in val])
+
def escape_bool(value):
    """SQL has no boolean literal here; booleans become the integers 1/0."""
    return "%d" % value
+
def escape_object(value):
    # Generic fallback: the value's str() form is used verbatim as the
    # literal (ints and longs, via the aliases below).
    return str(value)

escape_int = escape_long = escape_object
+
def escape_float(value):
    """Format with 15 significant digits -- enough to round-trip a double."""
    return format(value, '.15g')
+
def escape_string(value):
    """Return *value* single-quoted, with NUL/newline/CR/^Z/quotes/backslash
    replaced by their backslash escapes (see ESCAPE_MAP)."""
    escaped = ESCAPE_REGEX.sub(
        lambda match: ESCAPE_MAP.get(match.group(0)), value)
    return "'%s'" % escaped
+
def escape_unicode(value):
    # Unicode escapes identically to str; encoding happens later in escape_item.
    return escape_string(value)
+
+def escape_None(value):
+ return 'NULL'
+
def escape_timedelta(obj):
    """Render a timedelta as an 'HH:MM:SS' TIME literal (days are folded
    into the hours field; sub-second precision is dropped)."""
    seconds = int(obj.seconds) % 60
    minutes = int(obj.seconds // 60) % 60
    hours = int(obj.seconds // 3600) % 24 + int(obj.days) * 24
    return escape_string('%02d:%02d:%02d' % (hours, minutes, seconds))
+
def escape_time(obj):
    """Render a datetime.time as an SQL TIME literal, keeping microseconds.

    BUG FIX: the original appended ".%f" % obj.microsecond, which formats the
    integer microsecond count as a float -- e.g. microsecond=500000 produced
    ".500000.000000".  Zero-pad the integer to six digits instead, matching
    MySQL's fractional-seconds syntax.
    """
    s = "%02d:%02d:%02d" % (int(obj.hour), int(obj.minute),
                            int(obj.second))
    if obj.microsecond:
        s += ".%06d" % obj.microsecond

    return escape_string(s)
+
def escape_datetime(obj):
    # DATETIME literal, second precision (microseconds are dropped).
    return escape_string(obj.strftime("%Y-%m-%d %H:%M:%S"))

def escape_date(obj):
    # DATE literal.
    return escape_string(obj.strftime("%Y-%m-%d"))

def escape_struct_time(obj):
    # Route time.struct_time through the datetime path (first six fields only).
    return escape_datetime(datetime.datetime(*obj[:6]))
+
def convert_datetime(connection, field, obj):
    """Returns a DATETIME or TIMESTAMP column value as a datetime object:

    >>> datetime_or_None('2007-02-25 23:06:20')
    datetime.datetime(2007, 2, 25, 23, 6, 20)
    >>> datetime_or_None('2007-02-25T23:06:20')
    datetime.datetime(2007, 2, 25, 23, 6, 20)

    Illegal values are returned as None:

    >>> datetime_or_None('2007-02-31T23:06:20') is None
    True
    >>> datetime_or_None('0000-00-00 00:00:00') is None
    True

    """
    # Python 2 only: `unicode` is undefined on Python 3.
    if not isinstance(obj, unicode):
        obj = obj.decode(connection.charset)
    if ' ' in obj:
        sep = ' '
    elif 'T' in obj:
        sep = 'T'
    else:
        # No time separator at all -- treat the value as a bare DATE.
        return convert_date(connection, field, obj)

    try:
        ymd, hms = obj.split(sep, 1)
        return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':') ])
    except ValueError:
        # Out-of-range components fall back to date parsing (yields None too).
        return convert_date(connection, field, obj)
+
def convert_timedelta(connection, field, obj):
    """Returns a TIME column as a timedelta object:

    >>> timedelta_or_None('25:06:17')
    datetime.timedelta(1, 3977)
    >>> timedelta_or_None('-25:06:17')
    datetime.timedelta(-2, 83177)

    Illegal values are returned as None:

    >>> timedelta_or_None('random crap') is None
    True

    Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
    can accept values as (+|-)DD HH:MM:SS. The latter format will not
    be parsed correctly by this function.
    """
    from math import modf
    try:
        if not isinstance(obj, unicode):   # Python 2 only branch
            obj = obj.decode(connection.charset)
        hours, minutes, seconds = tuple([int(x) for x in obj.split(':')])
        tdelta = datetime.timedelta(
            hours = int(hours),
            minutes = int(minutes),
            seconds = int(seconds),
            # NOTE(review): `seconds` is already an int here, so this modf()
            # term is always 0 -- fractional input raises in int() above.
            microseconds = int(modf(float(seconds))[0]*1000000),
            )
        return tdelta
    except ValueError:
        return None
+
def convert_time(connection, field, obj):
    """Returns a TIME column as a time object:

    >>> time_or_None('15:06:17')
    datetime.time(15, 6, 17)

    Illegal values are returned as None:

    >>> time_or_None('-25:06:17') is None
    True
    >>> time_or_None('random crap') is None
    True

    Note that MySQL always returns TIME columns as (+|-)HH:MM:SS, but
    can accept values as (+|-)DD HH:MM:SS. The latter format will not
    be parsed correctly by this function.

    Also note that MySQL's TIME column corresponds more closely to
    Python's timedelta and not time. However if you want TIME columns
    to be treated as time-of-day and not a time offset, then you can
    use set this function as the converter for FIELD_TYPE.TIME.

    BUG FIX: the original did int(second) before the microsecond
    computation, so any fractional-seconds value ('15:06:17.5') raised
    ValueError and came back as None, and the modf() term was dead code.
    Parse the seconds field as a float and split it instead; whole-second
    inputs behave exactly as before.
    """
    from math import modf
    try:
        hour, minute, second = obj.split(':')
        whole_seconds = float(second)
        return datetime.time(hour=int(hour), minute=int(minute),
                             second=int(whole_seconds),
                             microsecond=int(modf(whole_seconds)[0]*1000000))
    except ValueError:
        # Wrong field count, non-numeric parts, or out-of-range values
        # (e.g. negative hours) all map to None, as before.
        return None
+
def convert_date(connection, field, obj):
    """Returns a DATE column as a date object:

    >>> date_or_None('2007-02-26')
    datetime.date(2007, 2, 26)

    Illegal values are returned as None:

    >>> date_or_None('2007-02-31') is None
    True
    >>> date_or_None('0000-00-00') is None
    True

    """
    try:
        if not isinstance(obj, unicode):   # Python 2 only branch
            obj = obj.decode(connection.charset)
        # Out-of-range components make datetime.date raise -> None below.
        return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
    except ValueError:
        return None
+
def convert_mysql_timestamp(connection, field, timestamp):
    """Convert a MySQL TIMESTAMP to a Timestamp object.

    MySQL >= 4.1 returns TIMESTAMP in the same format as DATETIME:

    >>> mysql_timestamp_converter('2007-02-25 22:32:17')
    datetime.datetime(2007, 2, 25, 22, 32, 17)

    MySQL < 4.1 uses a big string of numbers:

    >>> mysql_timestamp_converter('20070225223217')
    datetime.datetime(2007, 2, 25, 22, 32, 17)

    Illegal values are returned as None:

    >>> mysql_timestamp_converter('2007-02-31 22:32:17') is None
    True
    >>> mysql_timestamp_converter('00000000000000') is None
    True

    """
    if not isinstance(timestamp, unicode):   # Python 2 only branch
        timestamp = timestamp.decode(connection.charset)

    # A dash in position 4 means the modern DATETIME-style format.
    if timestamp[4] == '-':
        return convert_datetime(connection, field, timestamp)
    timestamp += "0"*(14-len(timestamp)) # padding
    year, month, day, hour, minute, second = \
        int(timestamp[:4]), int(timestamp[4:6]), int(timestamp[6:8]), \
        int(timestamp[8:10]), int(timestamp[10:12]), int(timestamp[12:14])
    try:
        return datetime.datetime(year, month, day, hour, minute, second)
    except ValueError:
        # All-zero or out-of-range timestamps map to None.
        return None
+
def convert_set(s):
    """Split a SET column's comma-separated string into a Python set."""
    members = s.split(",")
    return set(members)
+
def convert_bit(connection, field, b):
    """Return BIT column bytes untouched.

    Unpacking to an int (zero-pad to 8 bytes, then struct ">Q") would be
    more useful, but MySQLdb passes the raw bytes through, and this driver
    mirrors its behaviour.
    """
    return b
+
def convert_characters(connection, field, data):
    """Decode a text-ish column using the column's own charset, honouring
    SET and BINARY flags and the connection's unicode preference."""
    field_charset = charset_by_id(field.charsetnr).name
    if field.flags & FLAG.SET:
        return convert_set(data.decode(field_charset))
    if field.flags & FLAG.BINARY:
        # BLOB-like columns come back as raw bytes, never decoded.
        return data

    if connection.use_unicode:
        data = data.decode(field_charset)
    elif connection.charset != field_charset:
        # Transcode byte strings into the connection's charset.
        data = data.decode(field_charset)
        data = data.encode(connection.charset)
    return data
+
def convert_int(connection, field, data):
    """Decode an integer column from its textual wire form."""
    return int(data)
+
def convert_long(connection, field, data):
    # Python 2 long for LONG/LONGLONG columns (`long` does not exist on Py3).
    return long(data)
+
def convert_float(connection, field, data):
    """Decode a FLOAT/DOUBLE column from its textual wire form."""
    return float(data)
+
+encoders = {
+ bool: escape_bool,
+ int: escape_int,
+ long: escape_long,
+ float: escape_float,
+ str: escape_string,
+ unicode: escape_unicode,
+ tuple: escape_sequence,
+ list:escape_sequence,
+ set:escape_sequence,
+ dict:escape_dict,
+ type(None):escape_None,
+ datetime.date: escape_date,
+ datetime.datetime : escape_datetime,
+ datetime.timedelta : escape_timedelta,
+ datetime.time : escape_time,
+ time.struct_time : escape_struct_time,
+ }
+
+decoders = {
+ FIELD_TYPE.BIT: convert_bit,
+ FIELD_TYPE.TINY: convert_int,
+ FIELD_TYPE.SHORT: convert_int,
+ FIELD_TYPE.LONG: convert_long,
+ FIELD_TYPE.FLOAT: convert_float,
+ FIELD_TYPE.DOUBLE: convert_float,
+ FIELD_TYPE.DECIMAL: convert_float,
+ FIELD_TYPE.NEWDECIMAL: convert_float,
+ FIELD_TYPE.LONGLONG: convert_long,
+ FIELD_TYPE.INT24: convert_int,
+ FIELD_TYPE.YEAR: convert_int,
+ FIELD_TYPE.TIMESTAMP: convert_mysql_timestamp,
+ FIELD_TYPE.DATETIME: convert_datetime,
+ FIELD_TYPE.TIME: convert_timedelta,
+ FIELD_TYPE.DATE: convert_date,
+ FIELD_TYPE.SET: convert_set,
+ FIELD_TYPE.BLOB: convert_characters,
+ FIELD_TYPE.TINY_BLOB: convert_characters,
+ FIELD_TYPE.MEDIUM_BLOB: convert_characters,
+ FIELD_TYPE.LONG_BLOB: convert_characters,
+ FIELD_TYPE.STRING: convert_characters,
+ FIELD_TYPE.VAR_STRING: convert_characters,
+ FIELD_TYPE.VARCHAR: convert_characters,
+ #FIELD_TYPE.BLOB: str,
+ #FIELD_TYPE.STRING: str,
+ #FIELD_TYPE.VAR_STRING: str,
+ #FIELD_TYPE.VARCHAR: str
+ }
+conversions = decoders # for MySQLdb compatibility
+
+try:
+ # python version > 2.3
+ from decimal import Decimal
+ def convert_decimal(connection, field, data):
+ data = data.decode(connection.charset)
+ return Decimal(data)
+ decoders[FIELD_TYPE.DECIMAL] = convert_decimal
+ decoders[FIELD_TYPE.NEWDECIMAL] = convert_decimal
+
+ def escape_decimal(obj):
+ return unicode(obj)
+ encoders[Decimal] = escape_decimal
+
+except ImportError:
+ pass
diff --git a/tools/marvin/marvin/pymysql/cursors.py b/tools/marvin/marvin/pymysql/cursors.py
new file mode 100644
index 00000000000..4e10f83f4fa
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/cursors.py
@@ -0,0 +1,297 @@
+# -*- coding: utf-8 -*-
+import struct
+import re
+
+try:
+ import cStringIO as StringIO
+except ImportError:
+ import StringIO
+
+from err import Warning, Error, InterfaceError, DataError, \
+ DatabaseError, OperationalError, IntegrityError, InternalError, \
+ NotSupportedError, ProgrammingError
+
# Matches the "VALUES (...)" tail of an INSERT statement (note the mandatory
# whitespace before "values"). Appears unused in this module — presumably
# kept for executemany-style rewriting as in MySQLdb; TODO confirm.
insert_values = re.compile(r'\svalues\s*(\(.+\))', re.IGNORECASE)
+
class Cursor(object):
    '''
    This is the object you use to interact with the database.
    '''
    def __init__(self, connection):
        '''
        Do not create an instance of a Cursor yourself. Call
        connections.Connection.cursor().
        '''
        from weakref import proxy
        # A weak proxy avoids a reference cycle between the connection and
        # its cursors, so either side can be garbage collected on its own.
        self.connection = proxy(connection)
        self.description = None
        self.rownumber = 0
        self.rowcount = -1
        self.arraysize = 1
        self._executed = None
        self.messages = []
        self.errorhandler = connection.errorhandler
        self._has_next = None
        self._rows = ()

    def __del__(self):
        '''
        When this gets GC'd close it.
        '''
        self.close()

    def close(self):
        '''
        Closing a cursor just exhausts all remaining data.
        '''
        if not self.connection:
            return
        try:
            # Drain any pending result sets so the underlying connection
            # is left in a usable state for the next cursor.
            while self.nextset():
                pass
        except:
            pass

        self.connection = None

    def _get_db(self):
        """Return the live connection, reporting a closed cursor otherwise."""
        if not self.connection:
            self.errorhandler(self, ProgrammingError, "cursor closed")
        return self.connection

    def _check_executed(self):
        """Report (via errorhandler) if no query has been run yet."""
        if not self._executed:
            self.errorhandler(self, ProgrammingError, "execute() first")

    def setinputsizes(self, *args):
        """Does nothing, required by DB API."""

    def setoutputsizes(self, *args):
        """Does nothing, required by DB API."""
        # NOTE(review): DB-API 2.0 names this hook setoutputsize(); the
        # plural spelling is kept here for backward compatibility.

    def nextset(self):
        ''' Get the next query set '''
        if self._executed:
            self.fetchall()
        del self.messages[:]

        if not self._has_next:
            return None
        connection = self._get_db()
        connection.next_result()
        self._do_get_result()
        return True

    def execute(self, query, args=None):
        ''' Execute a query '''
        from sys import exc_info

        conn = self._get_db()
        charset = conn.charset
        del self.messages[:]

        # TODO: make sure that conn.escape is correct

        if args is not None:
            if isinstance(args, tuple) or isinstance(args, list):
                escaped_args = tuple(conn.escape(arg) for arg in args)
            elif isinstance(args, dict):
                escaped_args = dict((key, conn.escape(val)) for (key, val) in args.items())
            else:
                #If it's not a dictionary let's try escaping it anyways.
                #Worst case it will throw a Value error
                escaped_args = conn.escape(args)

            # Interpolate parameters before encoding so unicode args survive.
            query = query % escaped_args

        if isinstance(query, unicode):
            query = query.encode(charset)

        result = 0
        try:
            result = self._query(query)
        except:
            exc, value, tb = exc_info()
            del tb
            self.messages.append((exc,value))
            self.errorhandler(self, exc, value)

        self._executed = query
        return result

    def executemany(self, query, args):
        ''' Run several data against one query '''
        del self.messages[:]
        #conn = self._get_db()
        if not args:
            return
        #charset = conn.charset
        #if isinstance(query, unicode):
        #    query = query.encode(charset)

        # rowcount accumulates over the individual execute() calls.
        self.rowcount = sum([ self.execute(query, arg) for arg in args ])
        return self.rowcount


    def callproc(self, procname, args=()):
        """Execute stored procedure procname with args

        procname -- string, name of procedure to execute on server

        args -- Sequence of parameters to use with procedure

        Returns the original args.

        Compatibility warning: PEP-249 specifies that any modified
        parameters must be returned. This is currently impossible
        as they are only available by storing them in a server
        variable and then retrieved by a query. Since stored
        procedures return zero or more result sets, there is no
        reliable way to get at OUT or INOUT parameters via callproc.
        The server variables are named @_procname_n, where procname
        is the parameter above and n is the position of the parameter
        (from zero). Once all result sets generated by the procedure
        have been fetched, you can issue a SELECT @_procname_0, ...
        query using .execute() to get any OUT or INOUT values.

        Compatibility warning: The act of calling a stored procedure
        itself creates an empty result set. This appears after any
        result sets generated by the procedure. This is non-standard
        behavior with respect to the DB-API. Be sure to use nextset()
        to advance through all result sets; otherwise you may get
        disconnected.
        """
        conn = self._get_db()
        for index, arg in enumerate(args):
            q = "SET @_%s_%d=%s" % (procname, index, conn.escape(arg))
            if isinstance(q, unicode):
                q = q.encode(conn.charset)
            self._query(q)
            self.nextset()

        q = "CALL %s(%s)" % (procname,
                             ','.join(['@_%s_%d' % (procname, i)
                                       for i in range(len(args))]))
        if isinstance(q, unicode):
            q = q.encode(conn.charset)
        self._query(q)
        self._executed = q

        return args

    def fetchone(self):
        ''' Fetch the next row '''
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None
        result = self._rows[self.rownumber]
        self.rownumber += 1
        return result

    def fetchmany(self, size=None):
        ''' Fetch several rows '''
        self._check_executed()
        # Check for an absent result set *before* slicing: slicing None
        # raised TypeError here instead of returning None.
        if self._rows is None:
            return None
        end = self.rownumber + (size or self.arraysize)
        result = self._rows[self.rownumber:end]
        self.rownumber = min(end, len(self._rows))
        return result

    def fetchall(self):
        ''' Fetch all the rows '''
        self._check_executed()
        if self._rows is None:
            return None
        if self.rownumber:
            result = self._rows[self.rownumber:]
        else:
            result = self._rows
        self.rownumber = len(self._rows)
        return result

    def scroll(self, value, mode='relative'):
        '''Move the row pointer; *mode* is 'relative' or 'absolute'.'''
        self._check_executed()
        if mode == 'relative':
            r = self.rownumber + value
        elif mode == 'absolute':
            r = value
        else:
            self.errorhandler(self, ProgrammingError,
                    "unknown scroll mode %s" % mode)
            # A permissive errorhandler returns instead of raising; bail
            # out here so we never reference the unbound target position.
            return

        if r < 0 or r >= len(self._rows):
            self.errorhandler(self, IndexError, "out of range")
        self.rownumber = r

    def _query(self, q):
        """Send *q* to the server and load the resulting cursor state."""
        conn = self._get_db()
        self._last_executed = q
        conn.query(q)
        self._do_get_result()
        return self.rowcount

    def _do_get_result(self):
        """Copy the connection's current result set into this cursor."""
        conn = self._get_db()
        self.rowcount = conn._result.affected_rows

        self.rownumber = 0
        self.description = conn._result.description
        self.lastrowid = conn._result.insert_id
        self._rows = conn._result.rows
        self._has_next = conn._result.has_next

    def __iter__(self):
        # Iterating a cursor yields rows until fetchone() returns None.
        return iter(self.fetchone, None)

    # DB-API 2.0: make the exception classes reachable from the cursor.
    Warning = Warning
    Error = Error
    InterfaceError = InterfaceError
    DatabaseError = DatabaseError
    DataError = DataError
    OperationalError = OperationalError
    IntegrityError = IntegrityError
    InternalError = InternalError
    ProgrammingError = ProgrammingError
    NotSupportedError = NotSupportedError
+
class DictCursor(Cursor):
    """A cursor which returns results as a dictionary"""

    def execute(self, query, args=None):
        """Run the query, then cache the column names for building rows."""
        result = super(DictCursor, self).execute(query, args)
        if self.description:
            self._fields = [field[0] for field in self.description]
        return result

    def fetchone(self):
        ''' Fetch the next row '''
        self._check_executed()
        rows = self._rows
        if rows is None or self.rownumber >= len(rows):
            return None
        row = dict(zip(self._fields, rows[self.rownumber]))
        self.rownumber += 1
        return row

    def fetchmany(self, size=None):
        ''' Fetch several rows '''
        self._check_executed()
        if self._rows is None:
            return None
        stop = self.rownumber + (size or self.arraysize)
        chunk = self._rows[self.rownumber:stop]
        self.rownumber = min(stop, len(self._rows))
        return tuple(dict(zip(self._fields, row)) for row in chunk)

    def fetchall(self):
        ''' Fetch all the rows '''
        self._check_executed()
        if self._rows is None:
            return None
        if self.rownumber:
            remaining = self._rows[self.rownumber:]
        else:
            remaining = self._rows
        self.rownumber = len(self._rows)
        return tuple(dict(zip(self._fields, row)) for row in remaining)
+
diff --git a/tools/marvin/marvin/pymysql/err.py b/tools/marvin/marvin/pymysql/err.py
new file mode 100644
index 00000000000..b4322c63354
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/err.py
@@ -0,0 +1,147 @@
+import struct
+
+
try:
    StandardError, Warning
except NameError:
    # Referencing a missing builtin raises NameError, *not* ImportError
    # (which the original caught, leaving this fallback dead code and
    # crashing the import on Python 3, where StandardError was removed).
    try:
        from exceptions import StandardError, Warning
    except ImportError:
        # Python 3: Exception is the closest surviving base class.
        # Warning is still a builtin there and needs no shim.
        StandardError = Exception
+
+from constants import ER
+import sys
+
# Root of this module's DB-API 2.0 exception hierarchy.
class MySQLError(StandardError):

    """Exception related to operation with MySQL."""
+
+
# Deliberately shadows the builtin Warning so pymysql.Warning is both a
# standard Warning and a MySQLError, as DB-API 2.0 lays out.
class Warning(Warning, MySQLError):

    """Exception raised for important warnings like data truncations
    while inserting, etc."""
+
# Base for every error class below; catch this to handle any driver error.
class Error(MySQLError):

    """Exception that is the base class of all other error exceptions
    (not Warning)."""
+
+
class InterfaceError(Error):

    """Exception raised for errors that are related to the database
    interface rather than the database itself."""
    # e.g. misuse of the driver API, as opposed to a server-side failure.
+
+
# Base for the server-related errors below; error_map points at subclasses.
class DatabaseError(Error):

    """Exception raised for errors that are related to the
    database."""
+
+
# Mapped from server errnos such as WARN_DATA_TRUNCATED (see error_map below).
class DataError(DatabaseError):

    """Exception raised for errors that are due to problems with the
    processed data like division by zero, numeric value out of range,
    etc."""
+
+
# Mapped from access-denied server errnos (see error_map below).
class OperationalError(DatabaseError):

    """Exception raised for errors that are related to the database's
    operation and not necessarily under the control of the programmer,
    e.g. an unexpected disconnect occurs, the data source name is not
    found, a transaction could not be processed, a memory allocation
    error occurred during processing, etc."""
+
+
# Mapped from server errnos such as DUP_ENTRY and the foreign-key failures
# (see error_map below).
class IntegrityError(DatabaseError):

    """Exception raised when the relational integrity of the database
    is affected, e.g. a foreign key check fails, duplicate key,
    etc."""
+
+
class InternalError(DatabaseError):

    """Exception raised when the database encounters an internal
    error, e.g. the cursor is not valid anymore, the transaction is
    out of sync, etc."""
+
+
# Mapped from server errnos such as SYNTAX_ERROR and NO_SUCH_TABLE
# (see error_map below); also raised locally by cursors for API misuse.
class ProgrammingError(DatabaseError):

    """Exception raised for programming errors, e.g. table not found
    or already exists, syntax error in the SQL statement, wrong number
    of parameters specified, etc."""
+
+
# Mapped from server errnos such as NOT_SUPPORTED_YET (see error_map below).
class NotSupportedError(DatabaseError):

    """Exception raised in case a method or database API was used
    which is not supported by the database, e.g. requesting a
    .rollback() on a connection that does not support transaction or
    has transactions turned off."""
+
+
# Maps MySQL server error numbers to the DB-API exception class that
# should be raised for them.
error_map = {}

def _map_error(exc, *errors):
    # Register every server errno in *errors* under exception class *exc*.
    for error in errors:
        error_map[error] = exc

_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
           ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
           ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,
           ER.INVALID_GROUP_FUNC_USE, ER.UNSUPPORTED_EXTENSION,
           ER.TABLE_MUST_HAVE_COLUMNS, ER.CANT_DO_THIS_DURING_AN_TRANSACTION)
_map_error(DataError, ER.WARN_DATA_TRUNCATED, ER.WARN_NULL_TO_NOTNULL,
           ER.WARN_DATA_OUT_OF_RANGE, ER.NO_DEFAULT, ER.PRIMARY_CANT_HAVE_NULL,
           ER.DATA_TOO_LONG, ER.DATETIME_FUNCTION_OVERFLOW)
_map_error(IntegrityError, ER.DUP_ENTRY, ER.NO_REFERENCED_ROW,
           ER.NO_REFERENCED_ROW_2, ER.ROW_IS_REFERENCED, ER.ROW_IS_REFERENCED_2,
           ER.CANNOT_ADD_FOREIGN)
_map_error(NotSupportedError, ER.WARNING_NOT_COMPLETE_ROLLBACK,
           ER.NOT_SUPPORTED_YET, ER.FEATURE_DISABLED, ER.UNKNOWN_STORAGE_ENGINE)
_map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
           ER.TABLEACCESS_DENIED_ERROR, ER.COLUMNACCESS_DENIED_ERROR)

# The helper and the ER constants are only needed at import time.
del _map_error, ER
+
+
+def _get_error_info(data):
+ errno = struct.unpack(' tuple)
+ c.execute("SELECT * from dictcursor where name='bob'")
+ r = c.fetchall()
+ self.assertEqual((bob,),r,"fetch a 1 row result via fetchall failed via DictCursor")
+ # same test again but iterate over the
+ c.execute("SELECT * from dictcursor where name='bob'")
+ for r in c:
+ self.assertEqual(bob, r,"fetch a 1 row result via iteration failed via DictCursor")
+ # get all 3 row via fetchall
+ c.execute("SELECT * from dictcursor")
+ r = c.fetchall()
+ self.assertEqual((bob,jim,fred), r, "fetchall failed via DictCursor")
+ #same test again but do a list comprehension
+ c.execute("SELECT * from dictcursor")
+ r = [x for x in c]
+ self.assertEqual([bob,jim,fred], r, "list comprehension failed via DictCursor")
+ # get all 2 row via fetchmany
+ c.execute("SELECT * from dictcursor")
+ r = c.fetchmany(2)
+ self.assertEqual((bob,jim), r, "fetchmany failed via DictCursor")
+ finally:
+ c.execute("drop table dictcursor")
+
__all__ = ["TestDictCursor"]

# Allow running this test module directly.
if __name__ == "__main__":
    import unittest
    unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/test_basic.py b/tools/marvin/marvin/pymysql/tests/test_basic.py
new file mode 100644
index 00000000000..c8fdd297f44
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/test_basic.py
@@ -0,0 +1,193 @@
+from pymysql.tests import base
+from pymysql import util
+
+import time
+import datetime
+
class TestConversion(base.PyMySQLTestCase):
    """Round-trip the encoder/decoder tables through a live MySQL server."""

    def test_datatypes(self):
        """ test every data type """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table test_datatypes (b bit, i int, l bigint, f real, s varchar(32), u varchar(32), bb blob, d date, dt datetime, ts timestamp, td time, t time, st datetime)")
        try:
            # insert values
            v = (True, -3, 123456789012, 5.7, "hello'\" world", u"Espa\xc3\xb1ol", "binary\x00data".encode(conn.charset), datetime.date(1988,2,2), datetime.datetime.now(), datetime.timedelta(5,6), datetime.time(16,32), time.localtime())
            c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", v)
            c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
            r = c.fetchone()
            # BIT comes back as a single byte with value 1.
            self.assertEqual(util.int2byte(1), r[0])
            self.assertEqual(v[1:8], r[1:8])
            # mysql throws away microseconds so we need to check datetimes
            # specially. additionally times are turned into timedeltas.
            self.assertEqual(datetime.datetime(*v[8].timetuple()[:6]), r[8])
            self.assertEqual(v[9], r[9]) # just timedeltas
            self.assertEqual(datetime.timedelta(0, 60 * (v[10].hour * 60 + v[10].minute)), r[10])
            self.assertEqual(datetime.datetime(*v[-1][:6]), r[-1])

            c.execute("delete from test_datatypes")

            # check nulls
            c.execute("insert into test_datatypes (b,i,l,f,s,u,bb,d,dt,td,t,st) values (%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s,%s)", [None] * 12)
            c.execute("select b,i,l,f,s,u,bb,d,dt,td,t,st from test_datatypes")
            r = c.fetchone()
            self.assertEqual(tuple([None] * 12), r)

            c.execute("delete from test_datatypes")

            # check sequence type
            c.execute("insert into test_datatypes (i, l) values (2,4), (6,8), (10,12)")
            c.execute("select l from test_datatypes where i in %s order by i", ((2,6),))
            r = c.fetchall()
            self.assertEqual(((4,),(8,)), r)
        finally:
            c.execute("drop table test_datatypes")

    def test_dict(self):
        """ test dict escaping """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table test_dict (a integer, b integer, c integer)")
        try:
            c.execute("insert into test_dict (a,b,c) values (%(a)s, %(b)s, %(c)s)", {"a":1,"b":2,"c":3})
            c.execute("select a,b,c from test_dict")
            self.assertEqual((1,2,3), c.fetchone())
        finally:
            c.execute("drop table test_dict")

    def test_string(self):
        # NOTE(review): reuses the table name "test_dict" from the test
        # above; safe only because each test drops the table in finally.
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table test_dict (a text)")
        test_value = "I am a test string"
        try:
            c.execute("insert into test_dict (a) values (%s)", test_value)
            c.execute("select a from test_dict")
            self.assertEqual((test_value,), c.fetchone())
        finally:
            c.execute("drop table test_dict")

    def test_integer(self):
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table test_dict (a integer)")
        test_value = 12345
        try:
            c.execute("insert into test_dict (a) values (%s)", test_value)
            c.execute("select a from test_dict")
            self.assertEqual((test_value,), c.fetchone())
        finally:
            c.execute("drop table test_dict")


    def test_big_blob(self):
        """ test tons of data """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table test_big_blob (b blob)")
        try:
            data = "pymysql" * 1024
            c.execute("insert into test_big_blob (b) values (%s)", (data,))
            c.execute("select b from test_big_blob")
            self.assertEqual(data.encode(conn.charset), c.fetchone()[0])
        finally:
            c.execute("drop table test_big_blob")
+
class TestCursor(base.PyMySQLTestCase):
    """Cursor behaviour tests that need a live MySQL server."""
    # this test case does not work quite right yet, however,
    # we substitute in None for the erroneous field which is
    # compatible with the DB-API 2.0 spec and has not broken
    # any unit tests for anything we've tried.

    #def test_description(self):
    #    """ test description attribute """
    #    # result is from MySQLdb module
    #    r = (('Host', 254, 11, 60, 60, 0, 0),
    #         ('User', 254, 16, 16, 16, 0, 0),
    #         ('Password', 254, 41, 41, 41, 0, 0),
    #         ('Select_priv', 254, 1, 1, 1, 0, 0),
    #         ('Insert_priv', 254, 1, 1, 1, 0, 0),
    #         ('Update_priv', 254, 1, 1, 1, 0, 0),
    #         ('Delete_priv', 254, 1, 1, 1, 0, 0),
    #         ('Create_priv', 254, 1, 1, 1, 0, 0),
    #         ('Drop_priv', 254, 1, 1, 1, 0, 0),
    #         ('Reload_priv', 254, 1, 1, 1, 0, 0),
    #         ('Shutdown_priv', 254, 1, 1, 1, 0, 0),
    #         ('Process_priv', 254, 1, 1, 1, 0, 0),
    #         ('File_priv', 254, 1, 1, 1, 0, 0),
    #         ('Grant_priv', 254, 1, 1, 1, 0, 0),
    #         ('References_priv', 254, 1, 1, 1, 0, 0),
    #         ('Index_priv', 254, 1, 1, 1, 0, 0),
    #         ('Alter_priv', 254, 1, 1, 1, 0, 0),
    #         ('Show_db_priv', 254, 1, 1, 1, 0, 0),
    #         ('Super_priv', 254, 1, 1, 1, 0, 0),
    #         ('Create_tmp_table_priv', 254, 1, 1, 1, 0, 0),
    #         ('Lock_tables_priv', 254, 1, 1, 1, 0, 0),
    #         ('Execute_priv', 254, 1, 1, 1, 0, 0),
    #         ('Repl_slave_priv', 254, 1, 1, 1, 0, 0),
    #         ('Repl_client_priv', 254, 1, 1, 1, 0, 0),
    #         ('Create_view_priv', 254, 1, 1, 1, 0, 0),
    #         ('Show_view_priv', 254, 1, 1, 1, 0, 0),
    #         ('Create_routine_priv', 254, 1, 1, 1, 0, 0),
    #         ('Alter_routine_priv', 254, 1, 1, 1, 0, 0),
    #         ('Create_user_priv', 254, 1, 1, 1, 0, 0),
    #         ('Event_priv', 254, 1, 1, 1, 0, 0),
    #         ('Trigger_priv', 254, 1, 1, 1, 0, 0),
    #         ('ssl_type', 254, 0, 9, 9, 0, 0),
    #         ('ssl_cipher', 252, 0, 65535, 65535, 0, 0),
    #         ('x509_issuer', 252, 0, 65535, 65535, 0, 0),
    #         ('x509_subject', 252, 0, 65535, 65535, 0, 0),
    #         ('max_questions', 3, 1, 11, 11, 0, 0),
    #         ('max_updates', 3, 1, 11, 11, 0, 0),
    #         ('max_connections', 3, 1, 11, 11, 0, 0),
    #         ('max_user_connections', 3, 1, 11, 11, 0, 0))
    #    conn = self.connections[0]
    #    c = conn.cursor()
    #    c.execute("select * from mysql.user")
    #
    #    self.assertEqual(r, c.description)

    def test_fetch_no_result(self):
        """ test a fetchone() with no rows """
        # An INSERT produces no result rows, so fetchone() must give None.
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table test_nr (b varchar(32))")
        try:
            data = "pymysql"
            c.execute("insert into test_nr (b) values (%s)", (data,))
            self.assertEqual(None, c.fetchone())
        finally:
            c.execute("drop table test_nr")

    def test_aggregates(self):
        """ test aggregate functions """
        conn = self.connections[0]
        c = conn.cursor()
        try:
            c.execute('create table test_aggregates (i integer)')
            for i in xrange(0, 10):
                c.execute('insert into test_aggregates (i) values (%s)', (i,))
            c.execute('select sum(i) from test_aggregates')
            r, = c.fetchone()
            self.assertEqual(sum(range(0,10)), r)
        finally:
            c.execute('drop table test_aggregates')

    def test_single_tuple(self):
        """ test a single tuple """
        # Exercises escaping of a one-element tuple inside an IN clause.
        conn = self.connections[0]
        c = conn.cursor()
        try:
            c.execute("create table mystuff (id integer primary key)")
            c.execute("insert into mystuff (id) values (1)")
            c.execute("insert into mystuff (id) values (2)")
            c.execute("select id from mystuff where id in %s", ((1,),))
            self.assertEqual([(1,)], list(c.fetchall()))
        finally:
            c.execute("drop table mystuff")
+
__all__ = ["TestConversion","TestCursor"]

# Allow running this test module directly.
if __name__ == "__main__":
    import unittest
    unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/test_example.py b/tools/marvin/marvin/pymysql/tests/test_example.py
new file mode 100644
index 00000000000..2da05db31c6
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/test_example.py
@@ -0,0 +1,32 @@
+import pymysql
+from pymysql.tests import base
+
class TestExample(base.PyMySQLTestCase):
    """Smoke test: connect, query mysql.user, expect our user to appear."""

    def test_example(self):
        # NOTE(review): uses hard-coded connection parameters instead of
        # the self.databases config the other tests use — verify intent.
        conn = pymysql.connect(host='127.0.0.1', port=3306, user='root', passwd='', db='mysql')


        cur = conn.cursor()

        cur.execute("SELECT Host,User FROM user")

        # print cur.description

        # r = cur.fetchall()
        # print r
        # ...or...
        u = False

        # Flag whether the connected user shows up in any (Host, User) row.
        for r in cur.fetchall():
            u = u or conn.user in r

        self.assertTrue(u)

        cur.close()
        conn.close()
+
__all__ = ["TestExample"]

# Allow running this test module directly.
if __name__ == "__main__":
    import unittest
    unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/test_issues.py b/tools/marvin/marvin/pymysql/tests/test_issues.py
new file mode 100644
index 00000000000..38d71639c90
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/test_issues.py
@@ -0,0 +1,268 @@
+import pymysql
+from pymysql.tests import base
+
+import sys
+
try:
    import imp
    # Python 3 moved reload() into imp; on Python 2 imp has no reload
    # attribute (AttributeError) and the builtin reload is used instead.
    reload = imp.reload
except AttributeError:
    pass
+
+import datetime
+
class TestOldIssues(base.PyMySQLTestCase):
    """Regression tests for early pymysql tracker issues (live server)."""

    def test_issue_3(self):
        """ undefined methods datetime_or_None, date_or_None """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table issue3 (d date, t time, dt datetime, ts timestamp)")
        try:
            c.execute("insert into issue3 (d, t, dt, ts) values (%s,%s,%s,%s)", (None, None, None, None))
            c.execute("select d from issue3")
            self.assertEqual(None, c.fetchone()[0])
            c.execute("select t from issue3")
            self.assertEqual(None, c.fetchone()[0])
            c.execute("select dt from issue3")
            self.assertEqual(None, c.fetchone()[0])
            c.execute("select ts from issue3")
            # A NULL timestamp column is auto-filled by the server here.
            self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
        finally:
            c.execute("drop table issue3")

    def test_issue_4(self):
        """ can't retrieve TIMESTAMP fields """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table issue4 (ts timestamp)")
        try:
            c.execute("insert into issue4 (ts) values (now())")
            c.execute("select ts from issue4")
            self.assertTrue(isinstance(c.fetchone()[0], datetime.datetime))
        finally:
            c.execute("drop table issue4")

    def test_issue_5(self):
        """ query on information_schema.tables fails """
        con = self.connections[0]
        cur = con.cursor()
        cur.execute("select * from information_schema.tables")

    def test_issue_6(self):
        """ exception: TypeError: ord() expected a character, but string of length 0 found """
        conn = pymysql.connect(host="localhost",user="root",passwd="",db="mysql")
        c = conn.cursor()
        c.execute("select * from user")
        conn.close()

    def test_issue_8(self):
        """ Primary Key and Index error when selecting data """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("""CREATE TABLE `test` (`station` int(10) NOT NULL DEFAULT '0', `dh`
datetime NOT NULL DEFAULT '0000-00-00 00:00:00', `echeance` int(1) NOT NULL
DEFAULT '0', `me` double DEFAULT NULL, `mo` double DEFAULT NULL, PRIMARY
KEY (`station`,`dh`,`echeance`)) ENGINE=MyISAM DEFAULT CHARSET=latin1;""")
        try:
            self.assertEqual(0, c.execute("SELECT * FROM test"))
            c.execute("ALTER TABLE `test` ADD INDEX `idx_station` (`station`)")
            self.assertEqual(0, c.execute("SELECT * FROM test"))
        finally:
            c.execute("drop table test")

    def test_issue_9(self):
        """ sets DeprecationWarning in Python 2.6 """
        try:
            reload(pymysql)
        except DeprecationWarning:
            self.fail()

    def test_issue_10(self):
        """ Allocate a variable to return when the exception handler is permissive """
        # NOTE(review): the duplicate CREATE relies on the no-op errorhandler
        # swallowing the error, and table `t` is never dropped afterwards.
        conn = self.connections[0]
        conn.errorhandler = lambda cursor, errorclass, errorvalue: None
        cur = conn.cursor()
        cur.execute( "create table t( n int )" )
        cur.execute( "create table t( n int )" )

    def test_issue_13(self):
        """ can't handle large result fields """
        conn = self.connections[0]
        cur = conn.cursor()
        try:
            cur.execute("create table issue13 (t text)")
            # ticket says 18k
            size = 18*1024
            cur.execute("insert into issue13 (t) values (%s)", ("x" * size,))
            cur.execute("select t from issue13")
            # use assert_ so that obscenely huge error messages don't print
            r = cur.fetchone()[0]
            self.assert_("x" * size == r)
        finally:
            cur.execute("drop table issue13")

    def test_issue_14(self):
        """ typo in converters.py """
        # Python 2 only: the `1L` long literals are a SyntaxError on py3.
        self.assertEqual('1', pymysql.converters.escape_item(1, "utf8"))
        self.assertEqual('1', pymysql.converters.escape_item(1L, "utf8"))

        self.assertEqual('1', pymysql.converters.escape_object(1))
        self.assertEqual('1', pymysql.converters.escape_object(1L))

    def test_issue_15(self):
        """ query should be expanded before perform character encoding """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table issue15 (t varchar(32))")
        try:
            c.execute("insert into issue15 (t) values (%s)", (u'\xe4\xf6\xfc',))
            c.execute("select t from issue15")
            self.assertEqual(u'\xe4\xf6\xfc', c.fetchone()[0])
        finally:
            c.execute("drop table issue15")

    def test_issue_16(self):
        """ Patch for string and tuple escaping """
        conn = self.connections[0]
        c = conn.cursor()
        c.execute("create table issue16 (name varchar(32) primary key, email varchar(32))")
        try:
            c.execute("insert into issue16 (name, email) values ('pete', 'floydophone')")
            c.execute("select email from issue16 where name=%s", ("pete",))
            self.assertEqual("floydophone", c.fetchone()[0])
        finally:
            c.execute("drop table issue16")

    def test_issue_17(self):
        """ could not connect mysql use passwod """
        conn = self.connections[0]
        host = self.databases[0]["host"]
        db = self.databases[0]["db"]
        c = conn.cursor()
        # grant access to a table to a user with a password
        try:
            c.execute("create table issue17 (x varchar(32) primary key)")
            c.execute("insert into issue17 (x) values ('hello, world!')")
            c.execute("grant all privileges on %s.issue17 to 'issue17user'@'%%' identified by '1234'" % db)
            conn.commit()

            conn2 = pymysql.connect(host=host, user="issue17user", passwd="1234", db=db)
            c2 = conn2.cursor()
            c2.execute("select x from issue17")
            self.assertEqual("hello, world!", c2.fetchone()[0])
        finally:
            c.execute("drop table issue17")
+
def _uni(s, e):
    """Return *s* decoded to unicode text via encoding *e*.

    Works on either major version (hack for py3: round-trip through bytes).
    """
    if sys.version_info[0] <= 2:
        return unicode(s, e)
    return unicode(bytes(s, sys.getdefaultencoding()), e)
+
class TestNewIssues(base.PyMySQLTestCase):
    """More tracker regressions; Python 2 only (`except X, e` syntax)."""

    def test_issue_34(self):
        # 2003 = CR_CONN_HOST_ERROR: nothing should be listening on 1237.
        try:
            pymysql.connect(host="localhost", port=1237, user="root")
            self.fail()
        except pymysql.OperationalError, e:
            self.assertEqual(2003, e.args[0])
        except:
            self.fail()

    def test_issue_33(self):
        # Non-ASCII table and value round-trip under an explicit utf8 charset.
        conn = pymysql.connect(host="localhost", user="root", db=self.databases[0]["db"], charset="utf8")
        c = conn.cursor()
        try:
            c.execute(_uni("create table hei\xc3\x9fe (name varchar(32))", "utf8"))
            c.execute(_uni("insert into hei\xc3\x9fe (name) values ('Pi\xc3\xb1ata')", "utf8"))
            c.execute(_uni("select name from hei\xc3\x9fe", "utf8"))
            self.assertEqual(_uni("Pi\xc3\xb1ata","utf8"), c.fetchone()[0])
        finally:
            c.execute(_uni("drop table hei\xc3\x9fe", "utf8"))

    # Will fail without manual intervention:
    #def test_issue_35(self):
    #
    #    conn = self.connections[0]
    #    c = conn.cursor()
    #    print "sudo killall -9 mysqld within the next 10 seconds"
    #    try:
    #        c.execute("select sleep(10)")
    #        self.fail()
    #    except pymysql.OperationalError, e:
    #        self.assertEqual(2013, e.args[0])

    def test_issue_36(self):
        conn = self.connections[0]
        c = conn.cursor()
        # kill connections[0]
        original_count = c.execute("show processlist")
        kill_id = None
        for id,user,host,db,command,time,state,info in c.fetchall():
            if info == "show processlist":
                kill_id = id
                break
        # now nuke the connection
        conn.kill(kill_id)
        # make sure this connection has broken
        try:
            c.execute("show tables")
            self.fail()
        except:
            pass
        # check the process list from the other connection
        self.assertEqual(original_count - 1, self.connections[1].cursor().execute("show processlist"))
        del self.connections[0]

    def test_issue_37(self):
        conn = self.connections[0]
        c = conn.cursor()
        self.assertEqual(1, c.execute("SELECT @foo"))
        self.assertEqual((None,), c.fetchone())
        self.assertEqual(0, c.execute("SET @foo = 'bar'"))
        c.execute("set @foo = 'bar'")

    def test_issue_38(self):
        conn = self.connections[0]
        c = conn.cursor()
        datum = "a" * 1024 * 1023 # reduced size for most default mysql installs

        try:
            c.execute("create table issue38 (id integer, data mediumblob)")
            c.execute("insert into issue38 values (1, %s)", (datum,))
        finally:
            c.execute("drop table issue38")

    def disabled_test_issue_54(self):
        # Disabled by naming: the huge WHERE clause is very slow to build/run.
        conn = self.connections[0]
        c = conn.cursor()
        big_sql = "select * from issue54 where "
        big_sql += " and ".join("%d=%d" % (i,i) for i in xrange(0, 100000))

        try:
            c.execute("create table issue54 (id integer primary key)")
            c.execute("insert into issue54 (id) values (7)")
            c.execute(big_sql)
            self.assertEquals(7, c.fetchone()[0])
        finally:
            c.execute("drop table issue54")
+
class TestGitHubIssues(base.PyMySQLTestCase):
    """Regression tests for issues reported on the GitHub tracker."""

    def test_issue_66(self):
        """ insert_id() should track the latest auto_increment value """
        conn = self.connections[0]
        c = conn.cursor()
        # assertEquals is a deprecated alias; use assertEqual.
        self.assertEqual(0, conn.insert_id())
        try:
            c.execute("create table issue66 (id integer primary key auto_increment, x integer)")
            c.execute("insert into issue66 (x) values (1)")
            c.execute("insert into issue66 (x) values (1)")
            self.assertEqual(2, conn.insert_id())
        finally:
            c.execute("drop table issue66")
+
__all__ = ["TestOldIssues", "TestNewIssues", "TestGitHubIssues"]

# Allow running this test module directly.
if __name__ == "__main__":
    import unittest
    unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py b/tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py
new file mode 100644
index 00000000000..bfcc075fc4b
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/__init__.py
@@ -0,0 +1,5 @@
+from test_MySQLdb import *
+
# Allow running the re-exported third-party suites directly.
if __name__ == "__main__":
    import unittest
    unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
new file mode 100644
index 00000000000..b64f273cf08
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/__init__.py
@@ -0,0 +1,7 @@
+from test_MySQLdb_capabilities import test_MySQLdb as test_capabilities
+from test_MySQLdb_nonstandard import *
+from test_MySQLdb_dbapi20 import test_MySQLdb as test_dbapi2
+
# Allow running the MySQLdb-compatibility suites directly.
if __name__ == "__main__":
    import unittest
    unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
new file mode 100644
index 00000000000..ddd012330e5
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/capabilities.py
@@ -0,0 +1,292 @@
+#!/usr/bin/env python -O
+""" Script to test database capabilities and the DB-API interface
+ for functionality and memory leaks.
+
+ Adapted from a script by M-A Lemburg.
+
+"""
+from time import time
+import array
+import unittest
+
+
+class DatabaseTest(unittest.TestCase):
+
+ db_module = None
+ connect_args = ()
+ connect_kwargs = dict(use_unicode=True, charset="utf8")
+ create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
+ rows = 10
+ debug = False
+
+ def setUp(self):
+ import gc
+ db = self.db_module.connect(*self.connect_args, **self.connect_kwargs)
+ self.connection = db
+ self.cursor = db.cursor()
+ self.BLOBText = ''.join([chr(i) for i in range(256)] * 100);
+ self.BLOBUText = u''.join([unichr(i) for i in range(16834)])
+ self.BLOBBinary = self.db_module.Binary(''.join([chr(i) for i in range(256)] * 16))
+
+ leak_test = True
+
+ def tearDown(self):
+ if self.leak_test:
+ import gc
+ del self.cursor
+ orphans = gc.collect()
+ self.assertFalse(orphans, "%d orphaned objects found after deleting cursor" % orphans)
+
+ del self.connection
+ orphans = gc.collect()
+ self.assertFalse(orphans, "%d orphaned objects found after deleting connection" % orphans)
+
+ def table_exists(self, name):
+ try:
+ self.cursor.execute('select * from %s where 1=0' % name)
+ except:
+ return False
+ else:
+ return True
+
+ def quote_identifier(self, ident):
+ return '"%s"' % ident
+
+ def new_table_name(self):
+ i = id(self.cursor)
+ while True:
+ name = self.quote_identifier('tb%08x' % i)
+ if not self.table_exists(name):
+ return name
+ i = i + 1
+
+ def create_table(self, columndefs):
+
+ """ Create a table using a list of column definitions given in
+ columndefs.
+
+ generator must be a function taking arguments (row_number,
+ col_number) returning a suitable data object for insertion
+ into the table.
+
+ """
+ self.table = self.new_table_name()
+ self.cursor.execute('CREATE TABLE %s (%s) %s' %
+ (self.table,
+ ',\n'.join(columndefs),
+ self.create_table_extra))
+
+ def check_data_integrity(self, columndefs, generator):
+ # insert
+ self.create_table(columndefs)
+ insert_statement = ('INSERT INTO %s VALUES (%s)' %
+ (self.table,
+ ','.join(['%s'] * len(columndefs))))
+ data = [ [ generator(i,j) for j in range(len(columndefs)) ]
+ for i in range(self.rows) ]
+ if self.debug:
+ print data
+ self.cursor.executemany(insert_statement, data)
+ self.connection.commit()
+ # verify
+ self.cursor.execute('select * from %s' % self.table)
+ l = self.cursor.fetchall()
+ if self.debug:
+ print l
+ self.assertEquals(len(l), self.rows)
+ try:
+ for i in range(self.rows):
+ for j in range(len(columndefs)):
+ self.assertEquals(l[i][j], generator(i,j))
+ finally:
+ if not self.debug:
+ self.cursor.execute('drop table %s' % (self.table))
+
+ def test_transactions(self):
+ columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
+ def generator(row, col):
+ if col == 0: return row
+ else: return ('%i' % (row%10))*255
+ self.create_table(columndefs)
+ insert_statement = ('INSERT INTO %s VALUES (%s)' %
+ (self.table,
+ ','.join(['%s'] * len(columndefs))))
+ data = [ [ generator(i,j) for j in range(len(columndefs)) ]
+ for i in range(self.rows) ]
+ self.cursor.executemany(insert_statement, data)
+ # verify
+ self.connection.commit()
+ self.cursor.execute('select * from %s' % self.table)
+ l = self.cursor.fetchall()
+ self.assertEquals(len(l), self.rows)
+ for i in range(self.rows):
+ for j in range(len(columndefs)):
+ self.assertEquals(l[i][j], generator(i,j))
+ delete_statement = 'delete from %s where col1=%%s' % self.table
+ self.cursor.execute(delete_statement, (0,))
+ self.cursor.execute('select col1 from %s where col1=%s' % \
+ (self.table, 0))
+ l = self.cursor.fetchall()
+ self.assertFalse(l, "DELETE didn't work")
+ self.connection.rollback()
+ self.cursor.execute('select col1 from %s where col1=%s' % \
+ (self.table, 0))
+ l = self.cursor.fetchall()
+ self.assertTrue(len(l) == 1, "ROLLBACK didn't work")
+ self.cursor.execute('drop table %s' % (self.table))
+
+ def test_truncation(self):
+ columndefs = ( 'col1 INT', 'col2 VARCHAR(255)')
+ def generator(row, col):
+ if col == 0: return row
+ else: return ('%i' % (row%10))*((255-self.rows/2)+row)
+ self.create_table(columndefs)
+ insert_statement = ('INSERT INTO %s VALUES (%s)' %
+ (self.table,
+ ','.join(['%s'] * len(columndefs))))
+
+ try:
+ self.cursor.execute(insert_statement, (0, '0'*256))
+ except Warning:
+ if self.debug: print self.cursor.messages
+ except self.connection.DataError:
+ pass
+ else:
+ self.fail("Over-long column did not generate warnings/exception with single insert")
+
+ self.connection.rollback()
+
+ try:
+ for i in range(self.rows):
+ data = []
+ for j in range(len(columndefs)):
+ data.append(generator(i,j))
+ self.cursor.execute(insert_statement,tuple(data))
+ except Warning:
+ if self.debug: print self.cursor.messages
+ except self.connection.DataError:
+ pass
+ else:
+ self.fail("Over-long columns did not generate warnings/exception with execute()")
+
+ self.connection.rollback()
+
+ try:
+ data = [ [ generator(i,j) for j in range(len(columndefs)) ]
+ for i in range(self.rows) ]
+ self.cursor.executemany(insert_statement, data)
+ except Warning:
+ if self.debug: print self.cursor.messages
+ except self.connection.DataError:
+ pass
+ else:
+ self.fail("Over-long columns did not generate warnings/exception with executemany()")
+
+ self.connection.rollback()
+ self.cursor.execute('drop table %s' % (self.table))
+
+ def test_CHAR(self):
+ # Character data
+ def generator(row,col):
+ return ('%i' % ((row+col) % 10)) * 255
+ self.check_data_integrity(
+ ('col1 char(255)','col2 char(255)'),
+ generator)
+
+ def test_INT(self):
+ # Number data
+ def generator(row,col):
+ return row*row
+ self.check_data_integrity(
+ ('col1 INT',),
+ generator)
+
+ def test_DECIMAL(self):
+ # DECIMAL
+ def generator(row,col):
+ from decimal import Decimal
+ return Decimal("%d.%02d" % (row, col))
+ self.check_data_integrity(
+ ('col1 DECIMAL(5,2)',),
+ generator)
+
+ def test_DATE(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.DateFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 DATE',),
+ generator)
+
+ def test_TIME(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimeFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 TIME',),
+ generator)
+
+ def test_DATETIME(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 DATETIME',),
+ generator)
+
+ def test_TIMESTAMP(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313)
+ self.check_data_integrity(
+ ('col1 TIMESTAMP',),
+ generator)
+
+ def test_fractional_TIMESTAMP(self):
+ ticks = time()
+ def generator(row,col):
+ return self.db_module.TimestampFromTicks(ticks+row*86400-col*1313+row*0.7*col/3.0)
+ self.check_data_integrity(
+ ('col1 TIMESTAMP',),
+ generator)
+
+ def test_LONG(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBUText # 'BLOB Text ' * 1024
+ self.check_data_integrity(
+ ('col1 INT', 'col2 LONG'),
+ generator)
+
+ def test_TEXT(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBUText[:5192] # 'BLOB Text ' * 1024
+ self.check_data_integrity(
+ ('col1 INT', 'col2 TEXT'),
+ generator)
+
+ def test_LONG_BYTE(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBBinary # 'BLOB\000Binary ' * 1024
+ self.check_data_integrity(
+ ('col1 INT','col2 LONG BYTE'),
+ generator)
+
+ def test_BLOB(self):
+ def generator(row,col):
+ if col == 0:
+ return row
+ else:
+ return self.BLOBBinary # 'BLOB\000Binary ' * 1024
+ self.check_data_integrity(
+ ('col1 INT','col2 BLOB'),
+ generator)
+
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
new file mode 100644
index 00000000000..a419e34a46c
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/dbapi20.py
@@ -0,0 +1,853 @@
+#!/usr/bin/env python
+''' Python DB API 2.0 driver compliance unit test suite.
+
+ This software is Public Domain and may be used without restrictions.
+
+ "Now we have booze and barflies entering the discussion, plus rumours of
+ DBAs on drugs... and I won't tell you what flashes through my mind each
+ time I read the subject line with 'Anal Compliance' in it. All around
+ this is turning out to be a thoroughly unwholesome unit test."
+
+ -- Ian Bicking
+'''
+
+__rcs_id__ = '$Id$'
+__version__ = '$Revision$'[11:-2]
+__author__ = 'Stuart Bishop '
+
+import unittest
+import time
+
+# $Log$
+# Revision 1.1.2.1 2006/02/25 03:44:32 adustman
+# Generic DB-API unit test module
+#
+# Revision 1.10 2003/10/09 03:14:14 zenzen
+# Add test for DB API 2.0 optional extension, where database exceptions
+# are exposed as attributes on the Connection object.
+#
+# Revision 1.9 2003/08/13 01:16:36 zenzen
+# Minor tweak from Stefan Fleiter
+#
+# Revision 1.8 2003/04/10 00:13:25 zenzen
+# Changes, as per suggestions by M.-A. Lemburg
+# - Add a table prefix, to ensure namespace collisions can always be avoided
+#
+# Revision 1.7 2003/02/26 23:33:37 zenzen
+# Break out DDL into helper functions, as per request by David Rushby
+#
+# Revision 1.6 2003/02/21 03:04:33 zenzen
+# Stuff from Henrik Ekelund:
+# added test_None
+# added test_nextset & hooks
+#
+# Revision 1.5 2003/02/17 22:08:43 zenzen
+# Implement suggestions and code from Henrik Eklund - test that cursor.arraysize
+# defaults to 1 & generic cursor.callproc test added
+#
+# Revision 1.4 2003/02/15 00:16:33 zenzen
+# Changes, as per suggestions and bug reports by M.-A. Lemburg,
+# Matthew T. Kromer, Federico Di Gregorio and Daniel Dittmar
+# - Class renamed
+# - Now a subclass of TestCase, to avoid requiring the driver stub
+# to use multiple inheritance
+# - Reversed the polarity of buggy test in test_description
+# - Test exception hierarchy correctly
+# - self.populate is now self._populate(), so if a driver stub
+#   overrides self.ddl1 this change propagates
+# - VARCHAR columns now have a width, which will hopefully make the
+#   DDL even more portable (this will be reversed if it causes more problems)
+# - cursor.rowcount being checked after various execute and fetchXXX methods
+# - Check for fetchall and fetchmany returning empty lists after results
+# are exhausted (already checking for empty lists if select retrieved
+# nothing
+# - Fix bugs in test_setoutputsize_basic and test_setinputsizes
+#
+
+class DatabaseAPI20Test(unittest.TestCase):
+ ''' Test a database self.driver for DB API 2.0 compatibility.
+ This implementation tests Gadfly, but the TestCase
+ is structured so that other self.drivers can subclass this
+        test case to ensure compliance with the DB-API. It is
+ expected that this TestCase may be expanded in the future
+ if ambiguities or edge conditions are discovered.
+
+ The 'Optional Extensions' are not yet being tested.
+
+ self.drivers should subclass this test, overriding setUp, tearDown,
+ self.driver, connect_args and connect_kw_args. Class specification
+ should be as follows:
+
+ import dbapi20
+ class mytest(dbapi20.DatabaseAPI20Test):
+ [...]
+
+ Don't 'import DatabaseAPI20Test from dbapi20', or you will
+ confuse the unit tester - just 'import dbapi20'.
+ '''
+
+ # The self.driver module. This should be the module where the 'connect'
+ # method is to be found
+ driver = None
+ connect_args = () # List of arguments to pass to connect
+ connect_kw_args = {} # Keyword arguments for connect
+ table_prefix = 'dbapi20test_' # If you need to specify a prefix for tables
+
+ ddl1 = 'create table %sbooze (name varchar(20))' % table_prefix
+ ddl2 = 'create table %sbarflys (name varchar(20))' % table_prefix
+ xddl1 = 'drop table %sbooze' % table_prefix
+ xddl2 = 'drop table %sbarflys' % table_prefix
+
+ lowerfunc = 'lower' # Name of stored procedure to convert string->lowercase
+
+ # Some drivers may need to override these helpers, for example adding
+ # a 'commit' after the execute.
+ def executeDDL1(self,cursor):
+ cursor.execute(self.ddl1)
+
+ def executeDDL2(self,cursor):
+ cursor.execute(self.ddl2)
+
+ def setUp(self):
+ ''' self.drivers should override this method to perform required setup
+ if any is necessary, such as creating the database.
+ '''
+ pass
+
+ def tearDown(self):
+ ''' self.drivers should override this method to perform required cleanup
+ if any is necessary, such as deleting the test database.
+ The default drops the tables that may be created.
+ '''
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ for ddl in (self.xddl1,self.xddl2):
+ try:
+ cur.execute(ddl)
+ con.commit()
+ except self.driver.Error:
+ # Assume table didn't exist. Other tests will check if
+ # execute is busted.
+ pass
+ finally:
+ con.close()
+
+ def _connect(self):
+ try:
+ return self.driver.connect(
+ *self.connect_args,**self.connect_kw_args
+ )
+ except AttributeError:
+ self.fail("No connect method found in self.driver module")
+
+ def test_connect(self):
+ con = self._connect()
+ con.close()
+
+ def test_apilevel(self):
+ try:
+ # Must exist
+ apilevel = self.driver.apilevel
+ # Must equal 2.0
+ self.assertEqual(apilevel,'2.0')
+ except AttributeError:
+ self.fail("Driver doesn't define apilevel")
+
+ def test_threadsafety(self):
+ try:
+ # Must exist
+ threadsafety = self.driver.threadsafety
+ # Must be a valid value
+ self.assertTrue(threadsafety in (0,1,2,3))
+ except AttributeError:
+ self.fail("Driver doesn't define threadsafety")
+
+ def test_paramstyle(self):
+ try:
+ # Must exist
+ paramstyle = self.driver.paramstyle
+ # Must be a valid value
+ self.assertTrue(paramstyle in (
+ 'qmark','numeric','named','format','pyformat'
+ ))
+ except AttributeError:
+ self.fail("Driver doesn't define paramstyle")
+
+ def test_Exceptions(self):
+ # Make sure required exceptions exist, and are in the
+        # defined hierarchy.
+ self.assertTrue(issubclass(self.driver.Warning,StandardError))
+ self.assertTrue(issubclass(self.driver.Error,StandardError))
+ self.assertTrue(
+ issubclass(self.driver.InterfaceError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.DatabaseError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.OperationalError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.IntegrityError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.InternalError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.ProgrammingError,self.driver.Error)
+ )
+ self.assertTrue(
+ issubclass(self.driver.NotSupportedError,self.driver.Error)
+ )
+
+ def test_ExceptionsAsConnectionAttributes(self):
+ # OPTIONAL EXTENSION
+ # Test for the optional DB API 2.0 extension, where the exceptions
+ # are exposed as attributes on the Connection object
+ # I figure this optional extension will be implemented by any
+ # driver author who is using this test suite, so it is enabled
+ # by default.
+ con = self._connect()
+ drv = self.driver
+ self.assertTrue(con.Warning is drv.Warning)
+ self.assertTrue(con.Error is drv.Error)
+ self.assertTrue(con.InterfaceError is drv.InterfaceError)
+ self.assertTrue(con.DatabaseError is drv.DatabaseError)
+ self.assertTrue(con.OperationalError is drv.OperationalError)
+ self.assertTrue(con.IntegrityError is drv.IntegrityError)
+ self.assertTrue(con.InternalError is drv.InternalError)
+ self.assertTrue(con.ProgrammingError is drv.ProgrammingError)
+ self.assertTrue(con.NotSupportedError is drv.NotSupportedError)
+
+
+ def test_commit(self):
+ con = self._connect()
+ try:
+ # Commit must work, even if it doesn't do anything
+ con.commit()
+ finally:
+ con.close()
+
+ def test_rollback(self):
+ con = self._connect()
+ # If rollback is defined, it should either work or throw
+ # the documented exception
+ if hasattr(con,'rollback'):
+ try:
+ con.rollback()
+ except self.driver.NotSupportedError:
+ pass
+
+ def test_cursor(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ def test_cursor_isolation(self):
+ con = self._connect()
+ try:
+ # Make sure cursors created from the same connection have
+ # the documented transaction isolation level
+ cur1 = con.cursor()
+ cur2 = con.cursor()
+ self.executeDDL1(cur1)
+ cur1.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ cur2.execute("select name from %sbooze" % self.table_prefix)
+ booze = cur2.fetchall()
+ self.assertEqual(len(booze),1)
+ self.assertEqual(len(booze[0]),1)
+ self.assertEqual(booze[0][0],'Victoria Bitter')
+ finally:
+ con.close()
+
+ def test_description(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description should be none after executing a '
+ 'statement that can return no rows (such as DDL)'
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(len(cur.description),1,
+ 'cursor.description describes too many columns'
+ )
+ self.assertEqual(len(cur.description[0]),7,
+ 'cursor.description[x] tuples must have 7 elements'
+ )
+ self.assertEqual(cur.description[0][0].lower(),'name',
+ 'cursor.description[x][0] must return column name'
+ )
+ self.assertEqual(cur.description[0][1],self.driver.STRING,
+ 'cursor.description[x][1] must return column type. Got %r'
+ % cur.description[0][1]
+ )
+
+ # Make sure self.description gets reset
+ self.executeDDL2(cur)
+ self.assertEqual(cur.description,None,
+ 'cursor.description not being set to None when executing '
+ 'no-result statements (eg. DDL)'
+ )
+ finally:
+ con.close()
+
+ def test_rowcount(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount should be -1 after executing no-result '
+ 'statements'
+ )
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertTrue(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number or rows inserted, or '
+ 'set to -1 after executing an insert statement'
+ )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertTrue(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number of rows returned, or '
+ 'set to -1 after executing a select statement'
+ )
+ self.executeDDL2(cur)
+ self.assertEqual(cur.rowcount,-1,
+ 'cursor.rowcount not being reset to -1 after executing '
+ 'no-result statements'
+ )
+ finally:
+ con.close()
+
+ lower_func = 'lower'
+ def test_callproc(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if self.lower_func and hasattr(cur,'callproc'):
+ r = cur.callproc(self.lower_func,('FOO',))
+ self.assertEqual(len(r),1)
+ self.assertEqual(r[0],'FOO')
+ r = cur.fetchall()
+ self.assertEqual(len(r),1,'callproc produced no result set')
+ self.assertEqual(len(r[0]),1,
+ 'callproc produced invalid result set'
+ )
+ self.assertEqual(r[0][0],'foo',
+ 'callproc produced invalid results'
+ )
+ finally:
+ con.close()
+
+ def test_close(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ finally:
+ con.close()
+
+ # cursor.execute should raise an Error if called after connection
+ # closed
+ self.assertRaises(self.driver.Error,self.executeDDL1,cur)
+
+ # connection.commit should raise an Error if called after connection'
+ # closed.'
+ self.assertRaises(self.driver.Error,con.commit)
+
+ # connection.close should raise an Error if called more than once
+ self.assertRaises(self.driver.Error,con.close)
+
+ def test_execute(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self._paraminsert(cur)
+ finally:
+ con.close()
+
+ def _paraminsert(self,cur):
+ self.executeDDL1(cur)
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertTrue(cur.rowcount in (-1,1))
+
+ if self.driver.paramstyle == 'qmark':
+ cur.execute(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.execute(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.execute(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.execute(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ ("Cooper's",)
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.execute(
+ 'insert into %sbooze values (%%(beer)s)' % self.table_prefix,
+ {'beer':"Cooper's"}
+ )
+ else:
+ self.fail('Invalid paramstyle')
+ self.assertTrue(cur.rowcount in (-1,1))
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,'cursor.fetchall returned too few rows')
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Cooper's",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+ self.assertEqual(beers[1],"Victoria Bitter",
+ 'cursor.fetchall retrieved incorrect data, or data inserted '
+ 'incorrectly'
+ )
+
+ def test_executemany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ largs = [ ("Cooper's",) , ("Boag's",) ]
+ margs = [ {'beer': "Cooper's"}, {'beer': "Boag's"} ]
+ if self.driver.paramstyle == 'qmark':
+ cur.executemany(
+ 'insert into %sbooze values (?)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'numeric':
+ cur.executemany(
+ 'insert into %sbooze values (:1)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'named':
+ cur.executemany(
+ 'insert into %sbooze values (:beer)' % self.table_prefix,
+ margs
+ )
+ elif self.driver.paramstyle == 'format':
+ cur.executemany(
+ 'insert into %sbooze values (%%s)' % self.table_prefix,
+ largs
+ )
+ elif self.driver.paramstyle == 'pyformat':
+ cur.executemany(
+ 'insert into %sbooze values (%%(beer)s)' % (
+ self.table_prefix
+ ),
+ margs
+ )
+ else:
+ self.fail('Unknown paramstyle')
+ self.assertTrue(cur.rowcount in (-1,2),
+ 'insert using cursor.executemany set cursor.rowcount to '
+ 'incorrect value %r' % cur.rowcount
+ )
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ res = cur.fetchall()
+ self.assertEqual(len(res),2,
+ 'cursor.fetchall retrieved incorrect number of rows'
+ )
+ beers = [res[0][0],res[1][0]]
+ beers.sort()
+ self.assertEqual(beers[0],"Boag's",'incorrect data retrieved')
+ self.assertEqual(beers[1],"Cooper's",'incorrect data retrieved')
+ finally:
+ con.close()
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+            # executing a query that cannot return rows
+ self.executeDDL1(cur)
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+            # executing a query that cannot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if no more rows available'
+ )
+ self.assertTrue(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ samples = [
+ 'Carlton Cold',
+ 'Carlton Draft',
+ 'Mountain Goat',
+ 'Redback',
+ 'Victoria Bitter',
+ 'XXXX'
+ ]
+
+ def _populate(self):
+ ''' Return a list of sql commands to setup the DB for the fetch
+ tests.
+ '''
+ populate = [
+ "insert into %sbooze values ('%s')" % (self.table_prefix,s)
+ for s in self.samples
+ ]
+ return populate
+
+ def test_fetchmany(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchmany should raise an Error if called without
+ #issuing a query
+ self.assertRaises(self.driver.Error,cur.fetchmany,4)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchmany retrieved incorrect number of rows, '
+ 'default of arraysize is one.'
+ )
+ cur.arraysize=10
+ r = cur.fetchmany(3) # Should get 3 rows
+ self.assertEqual(len(r),3,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should get 2 more
+ self.assertEqual(len(r),2,
+ 'cursor.fetchmany retrieved incorrect number of rows'
+ )
+ r = cur.fetchmany(4) # Should be an empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence after '
+ 'results are exhausted'
+ )
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ # Same as above, using cursor.arraysize
+ cur.arraysize=4
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchmany() # Should get 4 rows
+ self.assertEqual(len(r),4,
+ 'cursor.arraysize not being honoured by fetchmany'
+ )
+ r = cur.fetchmany() # Should get 2 more
+ self.assertEqual(len(r),2)
+ r = cur.fetchmany() # Should be an empty sequence
+ self.assertEqual(len(r),0)
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ cur.arraysize=6
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchmany() # Should get all rows
+ self.assertTrue(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows),6)
+ self.assertEqual(len(rows),6)
+ rows = [r[0] for r in rows]
+ rows.sort()
+
+ # Make sure we get the right data back out
+ for i in range(0,6):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved by cursor.fetchmany'
+ )
+
+ rows = cur.fetchmany() # Should return an empty list
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'called after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,6))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ r = cur.fetchmany() # Should get empty sequence
+ self.assertEqual(len(r),0,
+ 'cursor.fetchmany should return an empty sequence if '
+ 'query retrieved no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ finally:
+ con.close()
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+            # after executing a statement that cannot return rows
+ self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+
+ finally:
+ con.close()
+
+ def test_mixedfetch(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows1 = cur.fetchone()
+ rows23 = cur.fetchmany(2)
+ rows4 = cur.fetchone()
+ rows56 = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,6))
+ self.assertEqual(len(rows23),2,
+ 'fetchmany returned incorrect number of rows'
+ )
+ self.assertEqual(len(rows56),2,
+ 'fetchall returned incorrect number of rows'
+ )
+
+ rows = [rows1[0]]
+ rows.extend([rows23[0][0],rows23[1][0]])
+ rows.append(rows4[0])
+ rows.extend([rows56[0][0],rows56[1][0]])
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'incorrect data retrieved or inserted'
+ )
+ finally:
+ con.close()
+
+ def help_nextset_setUp(self,cur):
+ ''' Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ '''
+ raise NotImplementedError,'Helper not implemented'
+ #sql="""
+ # create procedure deleteme as
+ # begin
+ # select count(*) from booze
+ # select name from booze
+ # end
+ #"""
+ #cur.execute(sql)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ raise NotImplementedError,'Helper not implemented'
+ #cur.execute("drop procedure deleteme")
+
+ def test_nextset(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ sql=self._populate()
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ cur.callproc('deleteme')
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ assert s == None,'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+
+ finally:
+ con.close()
+
+ def test_nextset(self):
+ raise NotImplementedError,'Drivers need to override this test'
+
+ def test_arraysize(self):
+ # Not much here - rest of the tests for this are in test_fetchmany
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.assertTrue(hasattr(cur,'arraysize'),
+ 'cursor.arraysize must be defined'
+ )
+ finally:
+ con.close()
+
+ def test_setinputsizes(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setinputsizes( (25,) )
+ self._paraminsert(cur) # Make sure cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize_basic(self):
+ # Basic test is to make sure setoutputsize doesn't blow up
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ cur.setoutputsize(1000)
+ cur.setoutputsize(2000,0)
+ self._paraminsert(cur) # Make sure the cursor still works
+ finally:
+ con.close()
+
+ def test_setoutputsize(self):
+        # Real test for setoutputsize is driver dependent
+ raise NotImplementedError,'Driver need to override this test'
+
+ def test_None(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+ cur.execute('insert into %sbooze values (NULL)' % self.table_prefix)
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchall()
+ self.assertEqual(len(r),1)
+ self.assertEqual(len(r[0]),1)
+ self.assertEqual(r[0][0],None,'NULL value not returned as None')
+ finally:
+ con.close()
+
+ def test_Date(self):
+ d1 = self.driver.Date(2002,12,25)
+ d2 = self.driver.DateFromTicks(time.mktime((2002,12,25,0,0,0,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(d1),str(d2))
+
+ def test_Time(self):
+ t1 = self.driver.Time(13,45,30)
+ t2 = self.driver.TimeFromTicks(time.mktime((2001,1,1,13,45,30,0,0,0)))
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Timestamp(self):
+ t1 = self.driver.Timestamp(2002,12,25,13,45,30)
+ t2 = self.driver.TimestampFromTicks(
+ time.mktime((2002,12,25,13,45,30,0,0,0))
+ )
+ # Can we assume this? API doesn't specify, but it seems implied
+ # self.assertEqual(str(t1),str(t2))
+
+ def test_Binary(self):
+ b = self.driver.Binary('Something')
+ b = self.driver.Binary('')
+
+ def test_STRING(self):
+ self.assertTrue(hasattr(self.driver,'STRING'),
+ 'module.STRING must be defined'
+ )
+
+ def test_BINARY(self):
+ self.assertTrue(hasattr(self.driver,'BINARY'),
+ 'module.BINARY must be defined.'
+ )
+
+ def test_NUMBER(self):
+ self.assertTrue(hasattr(self.driver,'NUMBER'),
+ 'module.NUMBER must be defined.'
+ )
+
+ def test_DATETIME(self):
+ self.assertTrue(hasattr(self.driver,'DATETIME'),
+ 'module.DATETIME must be defined.'
+ )
+
+ def test_ROWID(self):
+ self.assertTrue(hasattr(self.driver,'ROWID'),
+ 'module.ROWID must be defined.'
+ )
+
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
new file mode 100644
index 00000000000..e0bc93439c2
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_capabilities.py
@@ -0,0 +1,115 @@
+#!/usr/bin/env python
+import capabilities
+import unittest
+import pymysql
+from pymysql.tests import base
+import warnings
+
+warnings.filterwarnings('error')
+
+class test_MySQLdb(capabilities.DatabaseTest):
+
+ db_module = pymysql
+ connect_args = ()
+ connect_kwargs = base.PyMySQLTestCase.databases[0].copy()
+ connect_kwargs.update(dict(read_default_file='~/.my.cnf',
+ use_unicode=True,
+ charset='utf8', sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
+
+ create_table_extra = "ENGINE=INNODB CHARACTER SET UTF8"
+ leak_test = False
+
+ def quote_identifier(self, ident):
+ return "`%s`" % ident
+
+ def test_TIME(self):
+ from datetime import timedelta
+ def generator(row,col):
+ return timedelta(0, row*8000)
+ self.check_data_integrity(
+ ('col1 TIME',),
+ generator)
+
+ def test_TINYINT(self):
+ # Number data
+ def generator(row,col):
+ v = (row*row) % 256
+ if v > 127:
+ v = v-256
+ return v
+ self.check_data_integrity(
+ ('col1 TINYINT',),
+ generator)
+
+ def test_stored_procedures(self):
+ db = self.connection
+ c = self.cursor
+ try:
+ self.create_table(('pos INT', 'tree CHAR(20)'))
+ c.executemany("INSERT INTO %s (pos,tree) VALUES (%%s,%%s)" % self.table,
+ list(enumerate('ash birch cedar larch pine'.split())))
+ db.commit()
+
+ c.execute("""
+ CREATE PROCEDURE test_sp(IN t VARCHAR(255))
+ BEGIN
+ SELECT pos FROM %s WHERE tree = t;
+ END
+ """ % self.table)
+ db.commit()
+
+ c.callproc('test_sp', ('larch',))
+ rows = c.fetchall()
+ self.assertEquals(len(rows), 1)
+ self.assertEquals(rows[0][0], 3)
+ c.nextset()
+ finally:
+ c.execute("DROP PROCEDURE IF EXISTS test_sp")
+ c.execute('drop table %s' % (self.table))
+
+ def test_small_CHAR(self):
+ # Character data
+ def generator(row,col):
+ i = ((row+1)*(col+1)+62)%256
+ if i == 62: return ''
+ if i == 63: return None
+ return chr(i)
+ self.check_data_integrity(
+ ('col1 char(1)','col2 char(1)'),
+ generator)
+
+ def test_bug_2671682(self):
+ from pymysql.constants import ER
+ try:
+ self.cursor.execute("describe some_non_existent_table");
+ except self.connection.ProgrammingError, msg:
+ self.assertTrue(msg.args[0] == ER.NO_SUCH_TABLE)
+
+ def test_insert_values(self):
+ from pymysql.cursors import insert_values
+ query = """INSERT FOO (a, b, c) VALUES (a, b, c)"""
+ matched = insert_values.search(query)
+ self.assertTrue(matched)
+ values = matched.group(1)
+ self.assertTrue(values == "(a, b, c)")
+
+ def test_ping(self):
+ self.connection.ping()
+
+ def test_literal_int(self):
+ self.assertTrue("2" == self.connection.literal(2))
+
+ def test_literal_float(self):
+ self.assertTrue("3.1415" == self.connection.literal(3.1415))
+
+ def test_literal_string(self):
+ self.assertTrue("'foo'" == self.connection.literal("foo"))
+
+
+if __name__ == '__main__':
+ if test_MySQLdb.leak_test:
+ import gc
+ gc.enable()
+ gc.set_debug(gc.DEBUG_LEAK)
+ unittest.main()
+
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
new file mode 100644
index 00000000000..83c002fdf39
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_dbapi20.py
@@ -0,0 +1,205 @@
+#!/usr/bin/env python
+import dbapi20
+import unittest
+import pymysql
+from pymysql.tests import base
+
+class test_MySQLdb(dbapi20.DatabaseAPI20Test):
+ driver = pymysql
+ connect_args = ()
+ connect_kw_args = base.PyMySQLTestCase.databases[0].copy()
+ connect_kw_args.update(dict(read_default_file='~/.my.cnf',
+ charset='utf8',
+ sql_mode="ANSI,STRICT_TRANS_TABLES,TRADITIONAL"))
+
+    # Disable dbapi20 tests that do not apply to MySQL.
+    def test_setoutputsize(self): pass
+    def test_setoutputsize_basic(self): pass
+    # NOTE(review): this stub is shadowed by the later test_nextset
+    # definition in this class, so the nextset test still runs.
+    def test_nextset(self): pass
+
+ """The tests on fetchone and fetchall and rowcount bogusly
+ test for an exception if the statement cannot return a
+ result set. MySQL always returns a result set; it's just that
+ some things return empty result sets."""
+
+ def test_fetchall(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ # cursor.fetchall should raise an Error if called
+ # without executing a query that may return rows (such
+ # as a select)
+ self.assertRaises(self.driver.Error, cur.fetchall)
+
+ self.executeDDL1(cur)
+ for sql in self._populate():
+ cur.execute(sql)
+
+ # cursor.fetchall should raise an Error if called
+ # after executing a a statement that cannot return rows
+## self.assertRaises(self.driver.Error,cur.fetchall)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+ self.assertEqual(len(rows),len(self.samples),
+ 'cursor.fetchall did not retrieve all rows'
+ )
+ rows = [r[0] for r in rows]
+ rows.sort()
+ for i in range(0,len(self.samples)):
+ self.assertEqual(rows[i],self.samples[i],
+ 'cursor.fetchall retrieved incorrect rows'
+ )
+ rows = cur.fetchall()
+ self.assertEqual(
+ len(rows),0,
+ 'cursor.fetchall should return an empty list if called '
+ 'after the whole result set has been fetched'
+ )
+ self.assertTrue(cur.rowcount in (-1,len(self.samples)))
+
+ self.executeDDL2(cur)
+ cur.execute('select name from %sbarflys' % self.table_prefix)
+ rows = cur.fetchall()
+ self.assertTrue(cur.rowcount in (-1,0))
+ self.assertEqual(len(rows),0,
+ 'cursor.fetchall should return an empty list if '
+ 'a select query returns no rows'
+ )
+
+ finally:
+ con.close()
+
+ def test_fetchone(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+
+ # cursor.fetchone should raise an Error if called before
+ # executing a select-type query
+ self.assertRaises(self.driver.Error,cur.fetchone)
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannnot return rows
+ self.executeDDL1(cur)
+## self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ self.assertEqual(cur.fetchone(),None,
+ 'cursor.fetchone should return None if a query retrieves '
+ 'no rows'
+ )
+ self.assertTrue(cur.rowcount in (-1,0))
+
+ # cursor.fetchone should raise an Error if called after
+ # executing a query that cannnot return rows
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+## self.assertRaises(self.driver.Error,cur.fetchone)
+
+ cur.execute('select name from %sbooze' % self.table_prefix)
+ r = cur.fetchone()
+ self.assertEqual(len(r),1,
+ 'cursor.fetchone should have retrieved a single row'
+ )
+ self.assertEqual(r[0],'Victoria Bitter',
+ 'cursor.fetchone retrieved incorrect data'
+ )
+## self.assertEqual(cur.fetchone(),None,
+## 'cursor.fetchone should return None if no more rows available'
+## )
+ self.assertTrue(cur.rowcount in (-1,1))
+ finally:
+ con.close()
+
+ # Same complaint as for fetchall and fetchone
+ def test_rowcount(self):
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ self.executeDDL1(cur)
+## self.assertEqual(cur.rowcount,-1,
+## 'cursor.rowcount should be -1 after executing no-result '
+## 'statements'
+## )
+ cur.execute("insert into %sbooze values ('Victoria Bitter')" % (
+ self.table_prefix
+ ))
+## self.assertTrue(cur.rowcount in (-1,1),
+## 'cursor.rowcount should == number or rows inserted, or '
+## 'set to -1 after executing an insert statement'
+## )
+ cur.execute("select name from %sbooze" % self.table_prefix)
+ self.assertTrue(cur.rowcount in (-1,1),
+ 'cursor.rowcount should == number of rows returned, or '
+ 'set to -1 after executing a select statement'
+ )
+ self.executeDDL2(cur)
+## self.assertEqual(cur.rowcount,-1,
+## 'cursor.rowcount not being reset to -1 after executing '
+## 'no-result statements'
+## )
+ finally:
+ con.close()
+
+ def test_callproc(self):
+ pass # performed in test_MySQL_capabilities
+
+ def help_nextset_setUp(self,cur):
+ ''' Should create a procedure called deleteme
+ that returns two result sets, first the
+ number of rows in booze then "name from booze"
+ '''
+ sql="""
+ create procedure deleteme()
+ begin
+ select count(*) from %(tp)sbooze;
+ select name from %(tp)sbooze;
+ end
+ """ % dict(tp=self.table_prefix)
+ cur.execute(sql)
+
+ def help_nextset_tearDown(self,cur):
+ 'If cleaning up is needed after nextSetTest'
+ cur.execute("drop procedure deleteme")
+
+ def test_nextset(self):
+ from warnings import warn
+ con = self._connect()
+ try:
+ cur = con.cursor()
+ if not hasattr(cur,'nextset'):
+ return
+
+ try:
+ self.executeDDL1(cur)
+ sql=self._populate()
+ for sql in self._populate():
+ cur.execute(sql)
+
+ self.help_nextset_setUp(cur)
+
+ cur.callproc('deleteme')
+ numberofrows=cur.fetchone()
+ assert numberofrows[0]== len(self.samples)
+ assert cur.nextset()
+ names=cur.fetchall()
+ assert len(names) == len(self.samples)
+ s=cur.nextset()
+ if s:
+ empty = cur.fetchall()
+ self.assertEquals(len(empty), 0,
+ "non-empty result set after other result sets")
+ #warn("Incompatibility: MySQL returns an empty result set for the CALL itself",
+ # Warning)
+ #assert s == None,'No more return sets, should return None'
+ finally:
+ self.help_nextset_tearDown(cur)
+
+ finally:
+ con.close()
+
+
+if __name__ == '__main__':
+ unittest.main()
diff --git a/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
new file mode 100644
index 00000000000..f49369cb4f7
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/tests/thirdparty/test_MySQLdb/test_MySQLdb_nonstandard.py
@@ -0,0 +1,90 @@
+import unittest
+
+import pymysql
+_mysql = pymysql
+from pymysql.constants import FIELD_TYPE
+from pymysql.tests import base
+
+
+class TestDBAPISet(unittest.TestCase):
+ def test_set_equality(self):
+ self.assertTrue(pymysql.STRING == pymysql.STRING)
+
+ def test_set_inequality(self):
+ self.assertTrue(pymysql.STRING != pymysql.NUMBER)
+
+ def test_set_equality_membership(self):
+ self.assertTrue(FIELD_TYPE.VAR_STRING == pymysql.STRING)
+
+ def test_set_inequality_membership(self):
+ self.assertTrue(FIELD_TYPE.DATE != pymysql.STRING)
+
+
+class CoreModule(unittest.TestCase):
+ """Core _mysql module features."""
+
+ def test_NULL(self):
+ """Should have a NULL constant."""
+ self.assertEqual(_mysql.NULL, 'NULL')
+
+ def test_version(self):
+ """Version information sanity."""
+ self.assertTrue(isinstance(_mysql.__version__, str))
+
+ self.assertTrue(isinstance(_mysql.version_info, tuple))
+ self.assertEqual(len(_mysql.version_info), 5)
+
+ def test_client_info(self):
+ self.assertTrue(isinstance(_mysql.get_client_info(), str))
+
+ def test_thread_safe(self):
+ self.assertTrue(isinstance(_mysql.thread_safe(), int))
+
+
+class CoreAPI(unittest.TestCase):
+ """Test _mysql interaction internals."""
+
+ def setUp(self):
+ kwargs = base.PyMySQLTestCase.databases[0].copy()
+ kwargs["read_default_file"] = "~/.my.cnf"
+ self.conn = _mysql.connect(**kwargs)
+
+ def tearDown(self):
+ self.conn.close()
+
+ def test_thread_id(self):
+ tid = self.conn.thread_id()
+ self.assertTrue(isinstance(tid, int),
+ "thread_id didn't return an int.")
+
+ self.assertRaises(TypeError, self.conn.thread_id, ('evil',),
+ "thread_id shouldn't accept arguments.")
+
+ def test_affected_rows(self):
+ self.assertEquals(self.conn.affected_rows(), 0,
+ "Should return 0 before we do anything.")
+
+
+ #def test_debug(self):
+ ## FIXME Only actually tests if you lack SUPER
+ #self.assertRaises(pymysql.OperationalError,
+ #self.conn.dump_debug_info)
+
+ def test_charset_name(self):
+ self.assertTrue(isinstance(self.conn.character_set_name(), str),
+ "Should return a string.")
+
+ def test_host_info(self):
+ self.assertTrue(isinstance(self.conn.get_host_info(), str),
+ "Should return a string.")
+
+ def test_proto_info(self):
+ self.assertTrue(isinstance(self.conn.get_proto_info(), int),
+ "Should return an int.")
+
+ def test_server_info(self):
+ self.assertTrue(isinstance(self.conn.get_server_info(), basestring),
+ "Should return an str.")
+
+if __name__ == "__main__":
+ unittest.main()
diff --git a/tools/marvin/marvin/pymysql/times.py b/tools/marvin/marvin/pymysql/times.py
new file mode 100644
index 00000000000..c47db09eb9c
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/times.py
@@ -0,0 +1,16 @@
+from time import localtime
+from datetime import date, datetime, time, timedelta
+
+Date = date
+Time = time
+TimeDelta = timedelta
+Timestamp = datetime
+
+def DateFromTicks(ticks):
+ return date(*localtime(ticks)[:3])
+
+def TimeFromTicks(ticks):
+ return time(*localtime(ticks)[3:6])
+
+def TimestampFromTicks(ticks):
+ return datetime(*localtime(ticks)[:6])
diff --git a/tools/marvin/marvin/pymysql/util.py b/tools/marvin/marvin/pymysql/util.py
new file mode 100644
index 00000000000..cc622e57b74
--- /dev/null
+++ b/tools/marvin/marvin/pymysql/util.py
@@ -0,0 +1,19 @@
+import struct
+
+def byte2int(b):
+    # Accept either an int (already a code point) or a one-byte string;
+    # unpack the latter as a network-order unsigned char.
+    if isinstance(b, int):
+        return b
+    else:
+        return struct.unpack("!B", b)[0]
+
+def int2byte(i):
+    # Pack an integer 0-255 into a single byte (network-order unsigned char).
+    return struct.pack("!B", i)
+
+def join_bytes(bs):
+    # Concatenate a sequence of byte strings; '' for an empty sequence.
+    if len(bs) == 0:
+        return ""
+    else:
+        rv = bs[0]
+        for b in bs[1:]:
+            rv += b
+        return rv
diff --git a/tools/marvin/marvin/remoteSSHClient.py b/tools/marvin/marvin/remoteSSHClient.py
new file mode 100644
index 00000000000..806de7a9eb8
--- /dev/null
+++ b/tools/marvin/marvin/remoteSSHClient.py
@@ -0,0 +1,36 @@
+import paramiko
+import cloudstackException
+class remoteSSHClient(object):
+    """Thin wrapper over paramiko.SSHClient for running remote commands.
+
+    Connects on construction; a connect failure is surfaced as
+    cloudstackException.InvalidParameterException.
+    """
+    def __init__(self, host, port, user, passwd, timeout=120):
+        # Keep connection parameters for reference/debugging.
+        self.host = host
+        self.port = port
+        self.user = user
+        self.passwd = passwd
+        self.ssh = paramiko.SSHClient()
+        # Auto-accept unknown host keys (acceptable in test environments).
+        self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+        try:
+            self.ssh.connect(str(host),int(port), user, passwd, timeout=timeout)
+        except paramiko.SSHException, sshex:
+            raise cloudstackException.InvalidParameterException(repr(sshex))
+
+    def execute(self, command):
+        """Run *command* remotely and return its stdout lines, rstripped.
+
+        If the command produced no stdout, the stderr lines are returned
+        instead so that callers still see the failure text.
+        """
+        stdin, stdout, stderr = self.ssh.exec_command(command)
+        output = stdout.readlines()
+        errors = stderr.readlines()
+        results = []
+        if output is not None and len(output) == 0:
+            if errors is not None and len(errors) > 0:
+                for error in errors:
+                    results.append(error.rstrip())
+
+        else:
+            for strOut in output:
+                results.append(strOut.rstrip())
+
+        return results
+
+
+if __name__ == "__main__":
+ ssh = remoteSSHClient("192.168.137.2", 22, "root", "password")
+ print ssh.execute("ls -l")
+ print ssh.execute("rm x")
diff --git a/tools/marvin/marvin/sandbox/README.txt b/tools/marvin/marvin/sandbox/README.txt
new file mode 100644
index 00000000000..7efc190baf6
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/README.txt
@@ -0,0 +1,19 @@
+Welcome to the marvin sandbox
+----------------------------------
+
+In here you should find a few common deployment models of CloudStack that you
+can configure with properties files to suit your own deployment. One deployment
+model for each of - advanced zone, basic zone and a demo are given.
+
+$ ls
+basic/
+advanced/
+demo/
+
+Each property file is divided into logical sections and should be familiar to
+those who have deployed CloudStack before. Once you have your properties file
+you will have to create a JSON configuration of your deployment using the
+python script provided in the respective folder.
+
+The demo files are from the tutorial on testing with Python, which can be
+found at wiki.cloudstack.org
diff --git a/tools/marvin/marvin/sandbox/advanced/advanced_env.py b/tools/marvin/marvin/sandbox/advanced/advanced_env.py
new file mode 100644
index 00000000000..25e83dbb0f4
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/advanced/advanced_env.py
@@ -0,0 +1,117 @@
+#!/usr/bin/env python
+
+'''
+############################################################
+# Experimental state of scripts
+# * Need to be reviewed
+# * Only a sandbox
+############################################################
+'''
+
+from ConfigParser import SafeConfigParser
+from optparse import OptionParser
+from configGenerator import *
+import random
+
+
+def getGlobalSettings(config):
+ for k, v in dict(config.items('globals')).iteritems():
+ cfg = configuration()
+ cfg.name = k
+ cfg.value = v
+ yield cfg
+
+
+def describeResources(config):
+ zs = cloudstackConfiguration()
+
+ z = zone()
+ z.dns1 = config.get('environment', 'dns')
+ z.internaldns1 = config.get('environment', 'dns')
+ z.name = 'Sandbox-%s'%(config.get('environment', 'hypervisor'))
+ z.networktype = 'Advanced'
+ z.guestcidraddress = '10.1.1.0/24'
+
+ p = pod()
+ p.name = 'POD0'
+ p.gateway = config.get('cloudstack', 'private.gateway')
+ p.startip = config.get('cloudstack', 'private.pod.startip')
+ p.endip = config.get('cloudstack', 'private.pod.endip')
+ p.netmask = '255.255.255.0'
+
+ v = iprange()
+ v.gateway = config.get('cloudstack', 'public.gateway')
+ v.startip = config.get('cloudstack', 'public.vlan.startip')
+ v.endip = config.get('cloudstack', 'public.vlan.endip')
+ v.netmask = '255.255.255.0'
+ v.vlan = config.get('cloudstack', 'public.vlan')
+ z.ipranges.append(v)
+
+ c = cluster()
+ c.clustername = 'C0'
+ c.hypervisor = config.get('environment', 'hypervisor')
+ c.clustertype = 'CloudManaged'
+
+ h = host()
+ h.username = 'root'
+ h.password = 'password'
+ h.url = 'http://%s'%(config.get('cloudstack', 'host'))
+ c.hosts.append(h)
+
+ ps = primaryStorage()
+ ps.name = 'PS0'
+ ps.url = config.get('cloudstack', 'pool')
+ c.primaryStorages.append(ps)
+
+ p.clusters.append(c)
+ z.pods.append(p)
+
+ secondary = secondaryStorage()
+ secondary.url = config.get('cloudstack', 'secondary')
+ z.secondaryStorages.append(secondary)
+
+ '''Add zone'''
+ zs.zones.append(z)
+
+ '''Add mgt server'''
+ mgt = managementServer()
+ mgt.mgtSvrIp = config.get('environment', 'mshost')
+ zs.mgtSvr.append(mgt)
+
+ '''Add a database'''
+ db = dbServer()
+ db.dbSvr = config.get('environment', 'database')
+ zs.dbSvr = db
+
+ '''Add some configuration'''
+ [zs.globalConfig.append(cfg) for cfg in getGlobalSettings(config)]
+
+ ''''add loggers'''
+ testClientLogger = logger()
+ testClientLogger.name = 'TestClient'
+ testClientLogger.file = '/var/log/testclient.log'
+
+ testCaseLogger = logger()
+ testCaseLogger.name = 'TestCase'
+ testCaseLogger.file = '/var/log/testcase.log'
+
+ zs.logger.append(testClientLogger)
+ zs.logger.append(testCaseLogger)
+ return zs
+
+
+if __name__ == '__main__':
+ parser = OptionParser()
+ parser.add_option('-i', '--input', action='store', default='setup.properties', \
+ dest='input', help='file containing environment setup information')
+ parser.add_option('-o', '--output', action='store', default='./sandbox.cfg', \
+ dest='output', help='path where environment json will be generated')
+
+
+ (opts, args) = parser.parse_args()
+
+ cfg_parser = SafeConfigParser()
+ cfg_parser.read(opts.input)
+
+ cfg = describeResources(cfg_parser)
+ generate_setup_config(cfg, opts.output)
diff --git a/tools/marvin/marvin/sandbox/advanced/setup.properties b/tools/marvin/marvin/sandbox/advanced/setup.properties
new file mode 100644
index 00000000000..48b082e4f37
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/advanced/setup.properties
@@ -0,0 +1,36 @@
+[globals]
+expunge.delay=60
+expunge.interval=60
+storage.cleanup.interval=300
+account.cleanup.interval=600
+expunge.workers=3
+workers=10
+use.user.concentrated.pod.allocation=false
+vm.allocation.algorithm=random
+vm.op.wait.interval=5
+guest.domain.suffix=sandbox.kvm
+instance.name=QA
+direct.agent.load.size=1000
+default.page.size=10000
+check.pod.cidrs=true
+secstorage.allowed.internal.sites=10.147.28.0/24
+[environment]
+dns=10.147.28.6
+mshost=10.147.29.111
+database=10.147.29.111
+[cloudstack]
+zone.vlan=675-679
+#pod configuration
+private.gateway=10.147.29.1
+private.pod.startip=10.147.29.150
+private.pod.endip=10.147.29.159
+#public vlan range
+public.gateway=10.147.31.1
+public.vlan=31
+public.vlan.startip=10.147.31.150
+public.vlan.endip=10.147.31.159
+#hosts
+host=10.147.29.58
+#pools
+pool=nfs://10.147.28.6:/export/home/sandbox/kamakura
+secondary=nfs://10.147.28.6:/export/home/sandbox/sstor
diff --git a/tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py b/tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py
new file mode 100644
index 00000000000..bae181ca693
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/advanced/tests/test_scenarios.py
@@ -0,0 +1,126 @@
+#!/usr/bin/env python
+try:
+ import unittest2 as unittest
+except ImportError:
+ import unittest
+
+import random
+import hashlib
+from cloudstackTestCase import *
+import remoteSSHClient
+
+class SampleScenarios(cloudstackTestCase):
+ '''
+ '''
+ def setUp(self):
+ pass
+
+
+ def tearDown(self):
+ pass
+
+
+ def test_1_createAccounts(self, numberOfAccounts=2):
+ '''
+ Create a bunch of user accounts
+ '''
+ mdf = hashlib.md5()
+ mdf.update('password')
+ mdf_pass = mdf.hexdigest()
+ api = self.testClient.getApiClient()
+ for i in range(1, numberOfAccounts + 1):
+ acct = createAccount.createAccountCmd()
+ acct.accounttype = 0
+ acct.firstname = 'user' + str(i)
+ acct.lastname = 'user' + str(i)
+ acct.password = mdf_pass
+ acct.username = 'user' + str(i)
+ acct.email = 'user@example.com'
+ acct.account = 'user' + str(i)
+ acct.domainid = 1
+ acctResponse = api.createAccount(acct)
+ self.debug("successfully created account: %s, user: %s, id: %s"%(acctResponse.account, acctResponse.username, acctResponse.id))
+
+
+ def test_2_createServiceOffering(self):
+ apiClient = self.testClient.getApiClient()
+ createSOcmd=createServiceOffering.createServiceOfferingCmd()
+ createSOcmd.name='Sample SO'
+ createSOcmd.displaytext='Sample SO'
+ createSOcmd.storagetype='shared'
+ createSOcmd.cpunumber=1
+ createSOcmd.cpuspeed=100
+ createSOcmd.memory=128
+ createSOcmd.offerha='false'
+ createSOresponse = apiClient.createServiceOffering(createSOcmd)
+ return createSOresponse.id
+
+ def deployCmd(self, account, service):
+ deployVmCmd = deployVirtualMachine.deployVirtualMachineCmd()
+ deployVmCmd.zoneid = 1
+ deployVmCmd.account=account
+ deployVmCmd.domainid=1
+ deployVmCmd.templateid=2
+ deployVmCmd.serviceofferingid=service
+ return deployVmCmd
+
+ def listVmsInAccountCmd(self, acct):
+ api = self.testClient.getApiClient()
+ listVmCmd = listVirtualMachines.listVirtualMachinesCmd()
+ listVmCmd.account = acct
+ listVmCmd.zoneid = 1
+ listVmCmd.domainid = 1
+ listVmResponse = api.listVirtualMachines(listVmCmd)
+ return listVmResponse
+
+
+ def destroyVmCmd(self, key):
+ api = self.testClient.getApiClient()
+ destroyVmCmd = destroyVirtualMachine.destroyVirtualMachineCmd()
+ destroyVmCmd.id = key
+ api.destroyVirtualMachine(destroyVmCmd)
+
+
+ def test_3_stressDeploy(self):
+ '''
+ Deploy 5 Vms in each account
+ '''
+ service_id = self.test_2_createServiceOffering()
+ api = self.testClient.getApiClient()
+ for acct in range(1, 5):
+ [api.deployVirtualMachine(self.deployCmd('user'+str(acct), service_id)) for x in range(0,5)]
+
+ @unittest.skip("skipping destroys")
+ def test_4_stressDestroy(self):
+ '''
+ Cleanup all Vms in every account
+ '''
+ api = self.testClient.getApiClient()
+ for acct in range(1, 6):
+ for vm in self.listVmsInAccountCmd('user'+str(acct)):
+ if vm is not None:
+ self.destroyVmCmd(vm.id)
+
+ @unittest.skip("skipping destroys")
+ def test_5_combineStress(self):
+ for i in range(0, 5):
+ self.test_3_stressDeploy()
+ self.test_4_stressDestroy()
+
+ def deployN(self,nargs=300,batchsize=0):
+ '''
+ Deploy Nargs number of VMs concurrently in batches of size {batchsize}.
+ When batchsize is 0 all Vms are deployed in one batch
+ VMs will be deployed in 5:2:6 ratio
+ '''
+ cmds = []
+
+ if batchsize == 0:
+ self.testClient.submitCmdsAndWait(cmds)
+ else:
+ while len(z) > 0:
+ try:
+ newbatch = [cmds.pop() for b in range(batchsize)] #pop batchsize items
+ self.testClient.submitCmdsAndWait(newbatch)
+ except IndexError:
+ break
diff --git a/tools/marvin/marvin/sandbox/basic/basic_env.py b/tools/marvin/marvin/sandbox/basic/basic_env.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/tools/marvin/marvin/sandbox/demo/README b/tools/marvin/marvin/sandbox/demo/README
new file mode 100644
index 00000000000..650ea05df1c
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/demo/README
@@ -0,0 +1,4 @@
+Demo files for use with the tutorial on "Testing with Python".
+
+testDeployVM.py - to be run against a 2.2.y installation of management server
+testSshDeployVM.py - to be run against a 3.0.x installation of management server
diff --git a/tools/marvin/marvin/sandbox/demo/testDeployVM.py b/tools/marvin/marvin/sandbox/demo/testDeployVM.py
new file mode 100644
index 00000000000..4afaee346ad
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/demo/testDeployVM.py
@@ -0,0 +1,98 @@
+#!/usr/bin/env python
+# Copyright 2012 Citrix Systems, Inc. Licensed under the
+# Apache License, Version 2.0 (the "License"); you may not use this
+# file except in compliance with the License. Citrix Systems, Inc.
+# reserves all rights not expressly granted by the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Automatically generated by addcopyright.py at 04/03/2012
+
+from cloudstackTestCase import *
+
+import unittest
+import hashlib
+import random
+
+class TestDeployVm(cloudstackTestCase):
+ """
+ This test deploys a virtual machine into a user account
+ using the small service offering and builtin template
+ """
+ def setUp(self):
+ """
+ CloudStack internally saves its passwords in md5 form and that is how we
+ specify it in the API. Python's hashlib library helps us to quickly hash
+ strings as follows
+ """
+ mdf = hashlib.md5()
+ mdf.update('password')
+ mdf_pass = mdf.hexdigest()
+
+ self.apiClient = self.testClient.getApiClient() #Get ourselves an API client
+
+ self.acct = createAccount.createAccountCmd() #The createAccount command
+ self.acct.accounttype = 0 #We need a regular user. admins have accounttype=1
+ self.acct.firstname = 'bugs'
+ self.acct.lastname = 'bunny' #What's up doc?
+ self.acct.password = mdf_pass #The md5 hashed password string
+ self.acct.username = 'bugs'
+ self.acct.email = 'bugs@rabbithole.com'
+ self.acct.account = 'bugs'
+ self.acct.domainid = 1 #The default ROOT domain
+ self.acctResponse = self.apiClient.createAccount(self.acct)
+ #And upon successful creation we'll log a helpful message in our logs
+ self.debug("successfully created account: %s, user: %s, id: \
+ %s"%(self.acctResponse.account.account, \
+ self.acctResponse.account.username, \
+ self.acctResponse.account.id))
+
+ def test_DeployVm(self):
+ """
+ Let's start by defining the attributes of our VM that we will be
+ deploying on CloudStack. We will be assuming a single zone is available
+ and is configured and all templates are Ready
+
+ The hardcoded values are used only for brevity.
+ """
+ deployVmCmd = deployVirtualMachine.deployVirtualMachineCmd()
+ deployVmCmd.zoneid = 1
+ deployVmCmd.account = self.acct.account
+ deployVmCmd.domainid = self.acct.domainid
+ deployVmCmd.templateid = 2
+ deployVmCmd.serviceofferingid = 1
+
+ deployVmResponse = self.apiClient.deployVirtualMachine(deployVmCmd)
+ self.debug("VM %s was deployed in the job %s"%(deployVmResponse.id, deployVmResponse.jobid))
+
+ # At this point our VM is expected to be Running. Let's find out what
+ # listVirtualMachines tells us about VMs in this account
+
+ listVmCmd = listVirtualMachines.listVirtualMachinesCmd()
+ listVmCmd.id = deployVmResponse.id
+ listVmResponse = self.apiClient.listVirtualMachines(listVmCmd)
+
+ self.assertNotEqual(len(listVmResponse), 0, "Check if the list API \
+ returns a non-empty response")
+
+ vm = listVmResponse[0]
+
+ self.assertEqual(vm.id, deployVmResponse.id, "Check if the VM returned \
+ is the same as the one we deployed")
+
+
+ self.assertEqual(vm.state, "Running", "Check if VM has reached \
+ a state of running")
+
+ def tearDown(self):
+ """
+ And finally let us cleanup the resources we created by deleting the
+ account. All good unittests are atomic and rerunnable this way
+ """
+ deleteAcct = deleteAccount.deleteAccountCmd()
+ deleteAcct.id = self.acctResponse.account.id
+ self.apiClient.deleteAccount(deleteAcct)
diff --git a/tools/marvin/marvin/sandbox/demo/testSshDeployVM.py b/tools/marvin/marvin/sandbox/demo/testSshDeployVM.py
new file mode 100644
index 00000000000..106f693fda1
--- /dev/null
+++ b/tools/marvin/marvin/sandbox/demo/testSshDeployVM.py
@@ -0,0 +1,143 @@
+#!/usr/bin/env python
+# Copyright 2012 Citrix Systems, Inc. Licensed under the
+# Apache License, Version 2.0 (the "License"); you may not use this
+# file except in compliance with the License. Citrix Systems, Inc.
+# reserves all rights not expressly granted by the License.
+# You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# Automatically generated by addcopyright.py at 04/03/2012
+
+
+
+from cloudstackTestCase import *
+from remoteSSHClient import remoteSSHClient
+
+import unittest
+import hashlib
+import random
+import string
+
+class TestDeployVm(cloudstackTestCase):
+ """
+ This test deploys a virtual machine into a user account
+ using the small service offering and builtin template
+ """
+ @classmethod
+ def setUpClass(cls):
+ """
+ CloudStack internally saves its passwords in md5 form and that is how we
+ specify it in the API. Python's hashlib library helps us to quickly hash
+ strings as follows
+ """
+ mdf = hashlib.md5()
+ mdf.update('password')
+ mdf_pass = mdf.hexdigest()
+ acctName = 'bugs-'+''.join(random.choice(string.ascii_uppercase + string.digits) for x in range(6)) #randomly generated account
+
+ cls.apiClient = super(TestDeployVm, cls).getClsTestClient().getApiClient()
+ cls.acct = createAccount.createAccountCmd() #The createAccount command
+ cls.acct.accounttype = 0 #We need a regular user. admins have accounttype=1
+ cls.acct.firstname = 'bugs'
+ cls.acct.lastname = 'bunny' #What's up doc?
+ cls.acct.password = mdf_pass #The md5 hashed password string
+ cls.acct.username = acctName
+ cls.acct.email = 'bugs@rabbithole.com'
+ cls.acct.account = acctName
+ cls.acct.domainid = 1 #The default ROOT domain
+ cls.acctResponse = cls.apiClient.createAccount(cls.acct)
+
+ def setUpNAT(self, virtualmachineid):
+ listSourceNat = listPublicIpAddresses.listPublicIpAddressesCmd()
+ listSourceNat.account = self.acct.account
+ listSourceNat.domainid = self.acct.domainid
+ listSourceNat.issourcenat = True
+
+ listsnatresponse = self.apiClient.listPublicIpAddresses(listSourceNat)
+ self.assertNotEqual(len(listsnatresponse), 0, "Found a source NAT for the acct %s"%self.acct.account)
+
+ snatid = listsnatresponse[0].id
+ snatip = listsnatresponse[0].ipaddress
+
+ try:
+ createFwRule = createFirewallRule.createFirewallRuleCmd()
+ createFwRule.cidrlist = "0.0.0.0/0"
+ createFwRule.startport = 22
+ createFwRule.endport = 22
+ createFwRule.ipaddressid = snatid
+ createFwRule.protocol = "tcp"
+ createfwresponse = self.apiClient.createFirewallRule(createFwRule)
+
+ createPfRule = createPortForwardingRule.createPortForwardingRuleCmd()
+ createPfRule.privateport = 22
+ createPfRule.publicport = 22
+ createPfRule.virtualmachineid = virtualmachineid
+ createPfRule.ipaddressid = snatid
+ createPfRule.protocol = "tcp"
+
+ createpfresponse = self.apiClient.createPortForwardingRule(createPfRule)
+ except e:
+ self.debug("Failed to create PF rule in account %s due to %s"%(self.acct.account, e))
+ raise e
+ finally:
+ return snatip
+
+ def test_DeployVm(self):
+ """
+ Let's start by defining the attributes of our VM that we will be
+ deploying on CloudStack. We will be assuming a single zone is available
+ and is configured and all templates are Ready
+
+ The hardcoded values are used only for brevity.
+ """
+ deployVmCmd = deployVirtualMachine.deployVirtualMachineCmd()
+ deployVmCmd.zoneid = 1
+ deployVmCmd.account = self.acct.account
+ deployVmCmd.domainid = self.acct.domainid
+ deployVmCmd.templateid = 5 #CentOS 5.6 builtin
+ deployVmCmd.serviceofferingid = 1
+
+ deployVmResponse = self.apiClient.deployVirtualMachine(deployVmCmd)
+ self.debug("VM %s was deployed in the job %s"%(deployVmResponse.id, deployVmResponse.jobid))
+
+ # At this point our VM is expected to be Running. Let's find out what
+ # listVirtualMachines tells us about VMs in this account
+
+ listVmCmd = listVirtualMachines.listVirtualMachinesCmd()
+ listVmCmd.id = deployVmResponse.id
+ listVmResponse = self.apiClient.listVirtualMachines(listVmCmd)
+
+ self.assertNotEqual(len(listVmResponse), 0, "Check if the list API \
+ returns a non-empty response")
+
+ vm = listVmResponse[0]
+ hostname = vm.name
+ nattedip = self.setUpNAT(vm.id)
+
+ self.assertEqual(vm.id, deployVmResponse.id, "Check if the VM returned \
+ is the same as the one we deployed")
+
+
+ self.assertEqual(vm.state, "Running", "Check if VM has reached \
+ a state of running")
+
+ # SSH login and compare hostname
+ ssh_client = remoteSSHClient(nattedip, 22, "root", "password")
+ stdout = ssh_client.execute("hostname")
+
+ self.assertEqual(hostname, stdout[0], "cloudstack VM name and hostname match")
+
+
+ @classmethod
+ def tearDownClass(cls):
+ """
+ And finally let us cleanup the resources we created by deleting the
+ account. All good unittests are atomic and rerunnable this way
+ """
+ deleteAcct = deleteAccount.deleteAccountCmd()
+ deleteAcct.id = cls.acctResponse.account.id
+ cls.apiClient.deleteAccount(deleteAcct)
diff --git a/tools/marvin/setup.py b/tools/marvin/setup.py
new file mode 100644
index 00000000000..ce28b15b365
--- /dev/null
+++ b/tools/marvin/setup.py
@@ -0,0 +1,34 @@
+#!/usr/bin/env python
+# Copyright 2012 Citrix Systems, Inc. Licensed under the
+# Apache License, Version 2.0 (the "License"); you may not use this
+# file except in compliance with the License. Citrix Systems, Inc.
+
+from distutils.core import setup
+from sys import version
+
+if version < "2.7":
+ print "Marvin needs at least python 2.7, found : \n%s"%version
+else:
+ try:
+ import paramiko
+ except ImportError:
+ print "Marvin requires paramiko to be installed"
+ raise
+
+ setup(name="Marvin",
+ version="0.1.0",
+ description="Marvin - Python client for testing cloudstack",
+ author="Edison Su",
+ author_email="Edison.Su@citrix.com",
+ maintainer="Prasanna Santhanam",
+ maintainer_email="Prasanna.Santhanam@citrix.com",
+ long_description="Marvin is the cloudstack testclient written around the python unittest framework",
+ platforms=("Any",),
+ url="http://jenkins.cloudstack.org:8080/job/marvin",
+ packages=["marvin", "marvin.cloudstackAPI", "marvin.sandbox.tests", "marvin.pymysql", "marvin.pymysql.constants", "marvin.pymysql.tests"],
+ license="LICENSE.txt",
+ requires=[
+ "paramiko (>1.4)",
+ "Python (>=2.7)"
+ ]
+ )