diff --git a/tools/migration/10to21Upgrade.txt b/tools/migration/10to21Upgrade.txt new file mode 100644 index 00000000000..a94989434df --- /dev/null +++ b/tools/migration/10to21Upgrade.txt @@ -0,0 +1,63 @@ +CloudStack Migration: 1.0.x to 2.1.x + +How it works: + +There are four major steps to migrating a 1.0.x system's data to version 2.1.x of the CloudStack: + 1. Users in the 1.0.x system need to be re-created in the 2.1.x system + 2. Public IPs that are allocated in the 1.0.x system need to be allocated in the 2.1.x sytem + 3. Port forwarding and load balancer rules that were created in the 1.0.x system need to be re-created in the 2.1.x system + 4. Virtual machines and the data on their root/data disks need to be migrated + +To accomplish steps 1, 2, and 3, the CloudStack Migration tool automatically reads information from 1.0.x system's database and re-creates the data in the 2.1.x system through a combination of API calls and direct SQL inserts. +To accomplish step 4, the tool creates a direct link between 1.0.x storage servers and 2.1.x XenServers, and copies volume data using the XenServer API. + +The overall process should take between 15-30 minutes per VM, depending on the speed of your private network and the size of the volumes involved. + + +What you need: + +1. A running 1.0.x system that has one zone and one pod. +2. Necessary hardware for the 2.1.x system: one or more management servers, and one or more XenServers that are all on the same public network as the 1.0.x system. + * The 2.1.x management server must be able to access the 1.0.x management server's database, as well as the 1.0.x system's storage servers. +3. The 10to21Upgrade.tgz package. + + + +How to run the migration tool: + +1. If you DO NOT have a 2.1.x system installed and running: Do a fresh 2.1.x install (please refer to the admin guide for instructions), taking into account the following special instructions: + * Before you add any XenServer host, add one public IP range into the system with exactly two public IPs; these must be unallocated in the 1.0.x system. + * After adding all of your XenServer hosts in the UI, verify that the secondary storage VM and console proxy VM started. Then, add remainining public IPs as a second IP range. + +2. Register a bootable ISO and note down its database ID (you will need this for step 6). The OS of the ISO doesn't matter. + If you already have a bootable ISO in the 2.1 system, you can use its database ID in step 6. + * If you have no preference about which ISO to use, simply enter the following URL to register an Ubuntu 10.04 ISO: + http://localhost:8096/client/api?command=registerIso&bootable=true&zoneid=1&ispublic=true&name=Ubuntu&displayText=Ubuntu&url=http://ftp.ucsb.edu/pub/mirrors/linux/ubuntu/10.04/ubuntu-10.04.1-desktop-amd64.iso&ostypeid=59 + * Else, use the following API command (replacing variables as necessary): + http://localhost:8096/client/api/?command=registerIso&bootable=true&zoneid=1&ispublic=true&name=ISO_NAME&displayText=ISO_DISPLAY_TEXT&url=ISO_URL&ostypeid=ISO_OS_TYPE_ID + * To determine the ISO_OS_TYPE_ID, run the following API command and find the ID that corresponds to the OS of the ISO: + http://localhost:8096/client/api/?command=listOsTypes + +3. For every service offering in the 1.0.x system: + * Make sure there is a service offering in the 2.1.x system with the same cpu #, cpu speed, and RAM size. + * Make sure there is a disk offering in the 2.1.x system with the same disk size. 1.0.x allowed for creating service offerings with disk sizes that had an arbitrary number of MB. + However, in 2.1.x, disk offerings must be created in multiples of 1 GB. If there is a service offering in the 1.0.x system with a disk size that is not a + multiple of 1 Gb (1024 MB), create a disk offering in the 2.1.x system that is the 1.0.x disk size rounded to the next GB. For example, a disk size of 2000 MB in 1.0.x + will correspond to a disk offering with size 2 GB in the 2.1.x system. + +4. Install Python on the 2.1.x management server, if it isn't already installed. Version 2.4 or above is required. + +5. Download 10to21Upgrade.tgz to any folder on the 2.1.x management server, and uncompress it. + +6. Fill out upgrade.properties. Instructions about various fields are included in the file itself. + +7. If you DO have a 2.1.x system installed and running: + * Add a new public IP range in the 2.1.x system that corresponds to the public IP range in the 1.0.x system. + The public IP ranges that already exist in the 2.1.x system must not overlap with the IP range in the 1.0.x system. + * Run "python upgrade.py publicips". This will immediately allocate the public IPs of all users in the 1.0.x system, so that existing 2.1.x users can't allocate them. + +8. Run "python upgrade.py" on the 2.1.x management server. Status information will be printed out to the console. + * If there is an error, please contact Cloud.com Support and send us the migration log. By default, this file is called "migrationLog" and is in the same directory as upgrade.py. + * After the cause for an error has been resolved, you can run upgrade.py again from the beginning; it will skip over any work that has already been done. + * If you would like to re-enable a user on 1.0.x system, simply stop all of the user's VMs that have been migrated on the 2.1.x system, and start the user's VMs on the 1.0.x system. + diff --git a/tools/migration/XenAPI.py b/tools/migration/XenAPI.py new file mode 100644 index 00000000000..dfa72b7c3bd --- /dev/null +++ b/tools/migration/XenAPI.py @@ -0,0 +1,229 @@ +#============================================================================ +# This library is free software; you can redistribute it and/or +# modify it under the terms of version 2.1 of the GNU Lesser General Public +# License as published by the Free Software Foundation. +# +# This library is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU +# Lesser General Public License for more details. +# +# You should have received a copy of the GNU Lesser General Public +# License along with this library; if not, write to the Free Software +# Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA +#============================================================================ +# Copyright (C) 2006-2007 XenSource Inc. +#============================================================================ +# +# Parts of this file are based upon xmlrpclib.py, the XML-RPC client +# interface included in the Python distribution. +# +# Copyright (c) 1999-2002 by Secret Labs AB +# Copyright (c) 1999-2002 by Fredrik Lundh +# +# By obtaining, using, and/or copying this software and/or its +# associated documentation, you agree that you have read, understood, +# and will comply with the following terms and conditions: +# +# Permission to use, copy, modify, and distribute this software and +# its associated documentation for any purpose and without fee is +# hereby granted, provided that the above copyright notice appears in +# all copies, and that both that copyright notice and this permission +# notice appear in supporting documentation, and that the name of +# Secret Labs AB or the author not be used in advertising or publicity +# pertaining to distribution of the software without specific, written +# prior permission. +# +# SECRET LABS AB AND THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD +# TO THIS SOFTWARE, INCLUDING ALL IMPLIED WARRANTIES OF MERCHANT- +# ABILITY AND FITNESS. IN NO EVENT SHALL SECRET LABS AB OR THE AUTHOR +# BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL DAMAGES OR ANY +# DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, +# WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS +# ACTION, ARISING OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE +# OF THIS SOFTWARE. +# -------------------------------------------------------------------- + +import gettext +import xmlrpclib +import httplib +import socket + +translation = gettext.translation('xen-xm', fallback = True) + +API_VERSION_1_1 = '1.1' +API_VERSION_1_2 = '1.2' + +class Failure(Exception): + def __init__(self, details): + self.details = details + + def __str__(self): + try: + return str(self.details) + except Exception, exn: + import sys + print >>sys.stderr, exn + return "Xen-API failure: %s" % str(self.details) + + def _details_map(self): + return dict([(str(i), self.details[i]) + for i in range(len(self.details))]) + + +_RECONNECT_AND_RETRY = (lambda _ : ()) + +class UDSHTTPConnection(httplib.HTTPConnection): + """HTTPConnection subclass to allow HTTP over Unix domain sockets. """ + def connect(self): + path = self.host.replace("_", "/") + self.sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + self.sock.connect(path) + +class UDSHTTP(httplib.HTTP): + _connection_class = UDSHTTPConnection + +class UDSTransport(xmlrpclib.Transport): + def __init__(self, use_datetime=0): + self._use_datetime = use_datetime + self._extra_headers=[] + def add_extra_header(self, key, value): + self._extra_headers += [ (key,value) ] + def make_connection(self, host): + return UDSHTTP(host) + def send_request(self, connection, handler, request_body): + connection.putrequest("POST", handler) + for key, value in self._extra_headers: + connection.putheader(key, value) + +class Session(xmlrpclib.ServerProxy): + """A server proxy and session manager for communicating with xapi using + the Xen-API. + + Example: + + session = Session('http://localhost/') + session.login_with_password('me', 'mypassword') + session.xenapi.VM.start(vm_uuid) + session.xenapi.session.logout() + """ + + def __init__(self, uri, transport=None, encoding=None, verbose=0, + allow_none=1): + xmlrpclib.ServerProxy.__init__(self, uri, transport, encoding, + verbose, allow_none) + self.transport = transport + self._session = None + self.last_login_method = None + self.last_login_params = None + self.API_version = API_VERSION_1_1 + + + def xenapi_request(self, methodname, params): + if methodname.startswith('login'): + self._login(methodname, params) + return None + elif methodname == 'logout' or methodname == 'session.logout': + self._logout() + return None + else: + retry_count = 0 + while retry_count < 3: + full_params = (self._session,) + params + result = _parse_result(getattr(self, methodname)(*full_params)) + if result == _RECONNECT_AND_RETRY: + retry_count += 1 + if self.last_login_method: + self._login(self.last_login_method, + self.last_login_params) + else: + raise xmlrpclib.Fault(401, 'You must log in') + else: + return result + raise xmlrpclib.Fault( + 500, 'Tried 3 times to get a valid session, but failed') + + + def _login(self, method, params): + result = _parse_result(getattr(self, 'session.%s' % method)(*params)) + if result == _RECONNECT_AND_RETRY: + raise xmlrpclib.Fault( + 500, 'Received SESSION_INVALID when logging in') + self._session = result + self.last_login_method = method + self.last_login_params = params + self.API_version = self._get_api_version() + + def _logout(self): + try: + if self.last_login_method.startswith("slave_local"): + return _parse_result(self.session.local_logout(self._session)) + else: + return _parse_result(self.session.logout(self._session)) + finally: + self._session = None + self.last_login_method = None + self.last_login_params = None + self.API_version = API_VERSION_1_1 + + def _get_api_version(self): + pool = self.xenapi.pool.get_all()[0] + host = self.xenapi.pool.get_master(pool) + major = self.xenapi.host.get_API_version_major(host) + minor = self.xenapi.host.get_API_version_minor(host) + return "%s.%s"%(major,minor) + + def __getattr__(self, name): + if name == 'handle': + return self._session + elif name == 'xenapi': + return _Dispatcher(self.API_version, self.xenapi_request, None) + elif name.startswith('login') or name.startswith('slave_local'): + return lambda *params: self._login(name, params) + else: + return xmlrpclib.ServerProxy.__getattr__(self, name) + +def xapi_local(): + return Session("http://_var_xapi_xapi/", transport=UDSTransport()) + +def _parse_result(result): + if type(result) != dict or 'Status' not in result: + raise xmlrpclib.Fault(500, 'Missing Status in response from server' + result) + if result['Status'] == 'Success': + if 'Value' in result: + return result['Value'] + else: + raise xmlrpclib.Fault(500, + 'Missing Value in response from server') + else: + if 'ErrorDescription' in result: + if result['ErrorDescription'][0] == 'SESSION_INVALID': + return _RECONNECT_AND_RETRY + else: + raise Failure(result['ErrorDescription']) + else: + raise xmlrpclib.Fault( + 500, 'Missing ErrorDescription in response from server') + + +# Based upon _Method from xmlrpclib. +class _Dispatcher: + def __init__(self, API_version, send, name): + self.__API_version = API_version + self.__send = send + self.__name = name + + def __repr__(self): + if self.__name: + return '' % self.__name + else: + return '' + + def __getattr__(self, name): + if self.__name is None: + return _Dispatcher(self.__API_version, self.__send, name) + else: + return _Dispatcher(self.__API_version, self.__send, "%s.%s" % (self.__name, name)) + + def __call__(self, *args): + return self.__send(self.__name, args) diff --git a/tools/migration/paramiko/__init__.py b/tools/migration/paramiko/__init__.py new file mode 100644 index 00000000000..ac0d559b2cd --- /dev/null +++ b/tools/migration/paramiko/__init__.py @@ -0,0 +1,141 @@ +# Copyright (C) 2003-2009 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +I{Paramiko} (a combination of the esperanto words for "paranoid" and "friend") +is a module for python 2.3 or greater that implements the SSH2 protocol for +secure (encrypted and authenticated) connections to remote machines. Unlike +SSL (aka TLS), the SSH2 protocol does not require heirarchical certificates +signed by a powerful central authority. You may know SSH2 as the protocol that +replaced C{telnet} and C{rsh} for secure access to remote shells, but the +protocol also includes the ability to open arbitrary channels to remote +services across an encrypted tunnel. (This is how C{sftp} works, for example.) + +The high-level client API starts with creation of an L{SSHClient} object. +For more direct control, pass a socket (or socket-like object) to a +L{Transport}, and use L{start_server } or +L{start_client } to negoatite +with the remote host as either a server or client. As a client, you are +responsible for authenticating using a password or private key, and checking +the server's host key. I{(Key signature and verification is done by paramiko, +but you will need to provide private keys and check that the content of a +public key matches what you expected to see.)} As a server, you are +responsible for deciding which users, passwords, and keys to allow, and what +kind of channels to allow. + +Once you have finished, either side may request flow-controlled L{Channel}s to +the other side, which are python objects that act like sockets, but send and +receive data over the encrypted session. + +Paramiko is written entirely in python (no C or platform-dependent code) and is +released under the GNU Lesser General Public License (LGPL). + +Website: U{http://www.lag.net/paramiko/} + +@version: 1.7.6 (Fanny) +@author: Robey Pointer +@contact: robeypointer@gmail.com +@license: GNU Lesser General Public License (LGPL) +""" + +import sys + +if sys.version_info < (2, 2): + raise RuntimeError('You need python 2.2 for this module.') + + +__author__ = "Robey Pointer " +__date__ = "1 Nov 2009" +__version__ = "1.7.6 (Fanny)" +__version_info__ = (1, 7, 6) +__license__ = "GNU Lesser General Public License (LGPL)" + + +from transport import randpool, SecurityOptions, Transport +from client import SSHClient, MissingHostKeyPolicy, AutoAddPolicy, RejectPolicy, WarningPolicy +from auth_handler import AuthHandler +from channel import Channel, ChannelFile +from ssh_exception import SSHException, PasswordRequiredException, \ + BadAuthenticationType, ChannelException, BadHostKeyException, \ + AuthenticationException +from server import ServerInterface, SubsystemHandler, InteractiveQuery +from rsakey import RSAKey +from dsskey import DSSKey +from sftp import SFTPError, BaseSFTP +from sftp_client import SFTP, SFTPClient +from sftp_server import SFTPServer +from sftp_attr import SFTPAttributes +from sftp_handle import SFTPHandle +from sftp_si import SFTPServerInterface +from sftp_file import SFTPFile +from message import Message +from packet import Packetizer +from file import BufferedFile +from agent import Agent, AgentKey +from pkey import PKey +from hostkeys import HostKeys +from config import SSHConfig + +# fix module names for epydoc +for c in locals().values(): + if issubclass(type(c), type) or type(c).__name__ == 'classobj': + # classobj for exceptions :/ + c.__module__ = __name__ +del c + +from common import AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED, \ + OPEN_SUCCEEDED, OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, OPEN_FAILED_CONNECT_FAILED, \ + OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, OPEN_FAILED_RESOURCE_SHORTAGE + +from sftp import SFTP_OK, SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, \ + SFTP_BAD_MESSAGE, SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED + +__all__ = [ 'Transport', + 'SSHClient', + 'MissingHostKeyPolicy', + 'AutoAddPolicy', + 'RejectPolicy', + 'WarningPolicy', + 'SecurityOptions', + 'SubsystemHandler', + 'Channel', + 'PKey', + 'RSAKey', + 'DSSKey', + 'Message', + 'SSHException', + 'AuthenticationException', + 'PasswordRequiredException', + 'BadAuthenticationType', + 'ChannelException', + 'BadHostKeyException', + 'SFTP', + 'SFTPFile', + 'SFTPHandle', + 'SFTPClient', + 'SFTPServer', + 'SFTPError', + 'SFTPAttributes', + 'SFTPServerInterface', + 'ServerInterface', + 'BufferedFile', + 'Agent', + 'AgentKey', + 'HostKeys', + 'SSHConfig', + 'util' ] diff --git a/tools/migration/paramiko/agent.py b/tools/migration/paramiko/agent.py new file mode 100644 index 00000000000..71de8b84da5 --- /dev/null +++ b/tools/migration/paramiko/agent.py @@ -0,0 +1,151 @@ +# Copyright (C) 2003-2007 John Rochester +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +SSH Agent interface for Unix clients. +""" + +import os +import socket +import struct +import sys + +from paramiko.ssh_exception import SSHException +from paramiko.message import Message +from paramiko.pkey import PKey + + +SSH2_AGENTC_REQUEST_IDENTITIES, SSH2_AGENT_IDENTITIES_ANSWER, \ + SSH2_AGENTC_SIGN_REQUEST, SSH2_AGENT_SIGN_RESPONSE = range(11, 15) + + +class Agent: + """ + Client interface for using private keys from an SSH agent running on the + local machine. If an SSH agent is running, this class can be used to + connect to it and retreive L{PKey} objects which can be used when + attempting to authenticate to remote SSH servers. + + Because the SSH agent protocol uses environment variables and unix-domain + sockets, this probably doesn't work on Windows. It does work on most + posix platforms though (Linux and MacOS X, for example). + """ + + def __init__(self): + """ + Open a session with the local machine's SSH agent, if one is running. + If no agent is running, initialization will succeed, but L{get_keys} + will return an empty tuple. + + @raise SSHException: if an SSH agent is found, but speaks an + incompatible protocol + """ + self.keys = () + if ('SSH_AUTH_SOCK' in os.environ) and (sys.platform != 'win32'): + conn = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) + try: + conn.connect(os.environ['SSH_AUTH_SOCK']) + except: + # probably a dangling env var: the ssh agent is gone + return + self.conn = conn + elif sys.platform == 'win32': + import win_pageant + if win_pageant.can_talk_to_agent(): + self.conn = win_pageant.PageantConnection() + else: + return + else: + # no agent support + return + + ptype, result = self._send_message(chr(SSH2_AGENTC_REQUEST_IDENTITIES)) + if ptype != SSH2_AGENT_IDENTITIES_ANSWER: + raise SSHException('could not get keys from ssh-agent') + keys = [] + for i in range(result.get_int()): + keys.append(AgentKey(self, result.get_string())) + result.get_string() + self.keys = tuple(keys) + + def close(self): + """ + Close the SSH agent connection. + """ + self.conn.close() + self.conn = None + self.keys = () + + def get_keys(self): + """ + Return the list of keys available through the SSH agent, if any. If + no SSH agent was running (or it couldn't be contacted), an empty list + will be returned. + + @return: a list of keys available on the SSH agent + @rtype: tuple of L{AgentKey} + """ + return self.keys + + def _send_message(self, msg): + msg = str(msg) + self.conn.send(struct.pack('>I', len(msg)) + msg) + l = self._read_all(4) + msg = Message(self._read_all(struct.unpack('>I', l)[0])) + return ord(msg.get_byte()), msg + + def _read_all(self, wanted): + result = self.conn.recv(wanted) + while len(result) < wanted: + if len(result) == 0: + raise SSHException('lost ssh-agent') + extra = self.conn.recv(wanted - len(result)) + if len(extra) == 0: + raise SSHException('lost ssh-agent') + result += extra + return result + + +class AgentKey(PKey): + """ + Private key held in a local SSH agent. This type of key can be used for + authenticating to a remote server (signing). Most other key operations + work as expected. + """ + + def __init__(self, agent, blob): + self.agent = agent + self.blob = blob + self.name = Message(blob).get_string() + + def __str__(self): + return self.blob + + def get_name(self): + return self.name + + def sign_ssh_data(self, randpool, data): + msg = Message() + msg.add_byte(chr(SSH2_AGENTC_SIGN_REQUEST)) + msg.add_string(self.blob) + msg.add_string(data) + msg.add_int(0) + ptype, result = self.agent._send_message(msg) + if ptype != SSH2_AGENT_SIGN_RESPONSE: + raise SSHException('key cannot be used for signing') + return result.get_string() diff --git a/tools/migration/paramiko/auth_handler.py b/tools/migration/paramiko/auth_handler.py new file mode 100644 index 00000000000..0f2e4f66a2f --- /dev/null +++ b/tools/migration/paramiko/auth_handler.py @@ -0,0 +1,426 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{AuthHandler} +""" + +import threading +import weakref + +# this helps freezing utils +import encodings.utf_8 + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException, AuthenticationException, \ + BadAuthenticationType, PartialAuthentication +from paramiko.server import InteractiveQuery + + +class AuthHandler (object): + """ + Internal class to handle the mechanics of authentication. + """ + + def __init__(self, transport): + self.transport = weakref.proxy(transport) + self.username = None + self.authenticated = False + self.auth_event = None + self.auth_method = '' + self.password = None + self.private_key = None + self.interactive_handler = None + self.submethods = None + # for server mode: + self.auth_username = None + self.auth_fail_count = 0 + + def is_authenticated(self): + return self.authenticated + + def get_username(self): + if self.transport.server_mode: + return self.auth_username + else: + return self.username + + def auth_none(self, username, event): + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'none' + self.username = username + self._request_auth() + finally: + self.transport.lock.release() + + def auth_publickey(self, username, key, event): + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'publickey' + self.username = username + self.private_key = key + self._request_auth() + finally: + self.transport.lock.release() + + def auth_password(self, username, password, event): + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'password' + self.username = username + self.password = password + self._request_auth() + finally: + self.transport.lock.release() + + def auth_interactive(self, username, handler, event, submethods=''): + """ + response_list = handler(title, instructions, prompt_list) + """ + self.transport.lock.acquire() + try: + self.auth_event = event + self.auth_method = 'keyboard-interactive' + self.username = username + self.interactive_handler = handler + self.submethods = submethods + self._request_auth() + finally: + self.transport.lock.release() + + def abort(self): + if self.auth_event is not None: + self.auth_event.set() + + + ### internals... + + + def _request_auth(self): + m = Message() + m.add_byte(chr(MSG_SERVICE_REQUEST)) + m.add_string('ssh-userauth') + self.transport._send_message(m) + + def _disconnect_service_not_available(self): + m = Message() + m.add_byte(chr(MSG_DISCONNECT)) + m.add_int(DISCONNECT_SERVICE_NOT_AVAILABLE) + m.add_string('Service not available') + m.add_string('en') + self.transport._send_message(m) + self.transport.close() + + def _disconnect_no_more_auth(self): + m = Message() + m.add_byte(chr(MSG_DISCONNECT)) + m.add_int(DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE) + m.add_string('No more auth methods available') + m.add_string('en') + self.transport._send_message(m) + self.transport.close() + + def _get_session_blob(self, key, service, username): + m = Message() + m.add_string(self.transport.session_id) + m.add_byte(chr(MSG_USERAUTH_REQUEST)) + m.add_string(username) + m.add_string(service) + m.add_string('publickey') + m.add_boolean(1) + m.add_string(key.get_name()) + m.add_string(str(key)) + return str(m) + + def wait_for_response(self, event): + while True: + event.wait(0.1) + if not self.transport.is_active(): + e = self.transport.get_exception() + if (e is None) or issubclass(e.__class__, EOFError): + e = AuthenticationException('Authentication failed.') + raise e + if event.isSet(): + break + if not self.is_authenticated(): + e = self.transport.get_exception() + if e is None: + e = AuthenticationException('Authentication failed.') + # this is horrible. python Exception isn't yet descended from + # object, so type(e) won't work. :( + if issubclass(e.__class__, PartialAuthentication): + return e.allowed_types + raise e + return [] + + def _parse_service_request(self, m): + service = m.get_string() + if self.transport.server_mode and (service == 'ssh-userauth'): + # accepted + m = Message() + m.add_byte(chr(MSG_SERVICE_ACCEPT)) + m.add_string(service) + self.transport._send_message(m) + return + # dunno this one + self._disconnect_service_not_available() + + def _parse_service_accept(self, m): + service = m.get_string() + if service == 'ssh-userauth': + self.transport._log(DEBUG, 'userauth is OK') + m = Message() + m.add_byte(chr(MSG_USERAUTH_REQUEST)) + m.add_string(self.username) + m.add_string('ssh-connection') + m.add_string(self.auth_method) + if self.auth_method == 'password': + m.add_boolean(False) + password = self.password + if isinstance(password, unicode): + password = password.encode('UTF-8') + m.add_string(password) + elif self.auth_method == 'publickey': + m.add_boolean(True) + m.add_string(self.private_key.get_name()) + m.add_string(str(self.private_key)) + blob = self._get_session_blob(self.private_key, 'ssh-connection', self.username) + sig = self.private_key.sign_ssh_data(self.transport.randpool, blob) + m.add_string(str(sig)) + elif self.auth_method == 'keyboard-interactive': + m.add_string('') + m.add_string(self.submethods) + elif self.auth_method == 'none': + pass + else: + raise SSHException('Unknown auth method "%s"' % self.auth_method) + self.transport._send_message(m) + else: + self.transport._log(DEBUG, 'Service request "%s" accepted (?)' % service) + + def _send_auth_result(self, username, method, result): + # okay, send result + m = Message() + if result == AUTH_SUCCESSFUL: + self.transport._log(INFO, 'Auth granted (%s).' % method) + m.add_byte(chr(MSG_USERAUTH_SUCCESS)) + self.authenticated = True + else: + self.transport._log(INFO, 'Auth rejected (%s).' % method) + m.add_byte(chr(MSG_USERAUTH_FAILURE)) + m.add_string(self.transport.server_object.get_allowed_auths(username)) + if result == AUTH_PARTIALLY_SUCCESSFUL: + m.add_boolean(1) + else: + m.add_boolean(0) + self.auth_fail_count += 1 + self.transport._send_message(m) + if self.auth_fail_count >= 10: + self._disconnect_no_more_auth() + if result == AUTH_SUCCESSFUL: + self.transport._auth_trigger() + + def _interactive_query(self, q): + # make interactive query instead of response + m = Message() + m.add_byte(chr(MSG_USERAUTH_INFO_REQUEST)) + m.add_string(q.name) + m.add_string(q.instructions) + m.add_string('') + m.add_int(len(q.prompts)) + for p in q.prompts: + m.add_string(p[0]) + m.add_boolean(p[1]) + self.transport._send_message(m) + + def _parse_userauth_request(self, m): + if not self.transport.server_mode: + # er, uh... what? + m = Message() + m.add_byte(chr(MSG_USERAUTH_FAILURE)) + m.add_string('none') + m.add_boolean(0) + self.transport._send_message(m) + return + if self.authenticated: + # ignore + return + username = m.get_string() + service = m.get_string() + method = m.get_string() + self.transport._log(DEBUG, 'Auth request (type=%s) service=%s, username=%s' % (method, service, username)) + if service != 'ssh-connection': + self._disconnect_service_not_available() + return + if (self.auth_username is not None) and (self.auth_username != username): + self.transport._log(WARNING, 'Auth rejected because the client attempted to change username in mid-flight') + self._disconnect_no_more_auth() + return + self.auth_username = username + + if method == 'none': + result = self.transport.server_object.check_auth_none(username) + elif method == 'password': + changereq = m.get_boolean() + password = m.get_string() + try: + password = password.decode('UTF-8') + except UnicodeError: + # some clients/servers expect non-utf-8 passwords! + # in this case, just return the raw byte string. + pass + if changereq: + # always treated as failure, since we don't support changing passwords, but collect + # the list of valid auth types from the callback anyway + self.transport._log(DEBUG, 'Auth request to change passwords (rejected)') + newpassword = m.get_string() + try: + newpassword = newpassword.decode('UTF-8', 'replace') + except UnicodeError: + pass + result = AUTH_FAILED + else: + result = self.transport.server_object.check_auth_password(username, password) + elif method == 'publickey': + sig_attached = m.get_boolean() + keytype = m.get_string() + keyblob = m.get_string() + try: + key = self.transport._key_info[keytype](Message(keyblob)) + except SSHException, e: + self.transport._log(INFO, 'Auth rejected: public key: %s' % str(e)) + key = None + except: + self.transport._log(INFO, 'Auth rejected: unsupported or mangled public key') + key = None + if key is None: + self._disconnect_no_more_auth() + return + # first check if this key is okay... if not, we can skip the verify + result = self.transport.server_object.check_auth_publickey(username, key) + if result != AUTH_FAILED: + # key is okay, verify it + if not sig_attached: + # client wants to know if this key is acceptable, before it + # signs anything... send special "ok" message + m = Message() + m.add_byte(chr(MSG_USERAUTH_PK_OK)) + m.add_string(keytype) + m.add_string(keyblob) + self.transport._send_message(m) + return + sig = Message(m.get_string()) + blob = self._get_session_blob(key, service, username) + if not key.verify_ssh_sig(blob, sig): + self.transport._log(INFO, 'Auth rejected: invalid signature') + result = AUTH_FAILED + elif method == 'keyboard-interactive': + lang = m.get_string() + submethods = m.get_string() + result = self.transport.server_object.check_auth_interactive(username, submethods) + if isinstance(result, InteractiveQuery): + # make interactive query instead of response + self._interactive_query(result) + return + else: + result = self.transport.server_object.check_auth_none(username) + # okay, send result + self._send_auth_result(username, method, result) + + def _parse_userauth_success(self, m): + self.transport._log(INFO, 'Authentication (%s) successful!' % self.auth_method) + self.authenticated = True + self.transport._auth_trigger() + if self.auth_event != None: + self.auth_event.set() + + def _parse_userauth_failure(self, m): + authlist = m.get_list() + partial = m.get_boolean() + if partial: + self.transport._log(INFO, 'Authentication continues...') + self.transport._log(DEBUG, 'Methods: ' + str(authlist)) + self.transport.saved_exception = PartialAuthentication(authlist) + elif self.auth_method not in authlist: + self.transport._log(DEBUG, 'Authentication type (%s) not permitted.' % self.auth_method) + self.transport._log(DEBUG, 'Allowed methods: ' + str(authlist)) + self.transport.saved_exception = BadAuthenticationType('Bad authentication type', authlist) + else: + self.transport._log(INFO, 'Authentication (%s) failed.' % self.auth_method) + self.authenticated = False + self.username = None + if self.auth_event != None: + self.auth_event.set() + + def _parse_userauth_banner(self, m): + banner = m.get_string() + lang = m.get_string() + self.transport._log(INFO, 'Auth banner: ' + banner) + # who cares. + + def _parse_userauth_info_request(self, m): + if self.auth_method != 'keyboard-interactive': + raise SSHException('Illegal info request from server') + title = m.get_string() + instructions = m.get_string() + m.get_string() # lang + prompts = m.get_int() + prompt_list = [] + for i in range(prompts): + prompt_list.append((m.get_string(), m.get_boolean())) + response_list = self.interactive_handler(title, instructions, prompt_list) + + m = Message() + m.add_byte(chr(MSG_USERAUTH_INFO_RESPONSE)) + m.add_int(len(response_list)) + for r in response_list: + m.add_string(r) + self.transport._send_message(m) + + def _parse_userauth_info_response(self, m): + if not self.transport.server_mode: + raise SSHException('Illegal info response from server') + n = m.get_int() + responses = [] + for i in range(n): + responses.append(m.get_string()) + result = self.transport.server_object.check_auth_interactive_response(responses) + if isinstance(type(result), InteractiveQuery): + # make interactive query instead of response + self._interactive_query(result) + return + self._send_auth_result(self.auth_username, 'keyboard-interactive', result) + + + _handler_table = { + MSG_SERVICE_REQUEST: _parse_service_request, + MSG_SERVICE_ACCEPT: _parse_service_accept, + MSG_USERAUTH_REQUEST: _parse_userauth_request, + MSG_USERAUTH_SUCCESS: _parse_userauth_success, + MSG_USERAUTH_FAILURE: _parse_userauth_failure, + MSG_USERAUTH_BANNER: _parse_userauth_banner, + MSG_USERAUTH_INFO_REQUEST: _parse_userauth_info_request, + MSG_USERAUTH_INFO_RESPONSE: _parse_userauth_info_response, + } + diff --git a/tools/migration/paramiko/ber.py b/tools/migration/paramiko/ber.py new file mode 100644 index 00000000000..19568dd5e94 --- /dev/null +++ b/tools/migration/paramiko/ber.py @@ -0,0 +1,129 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + + +import util + + +class BERException (Exception): + pass + + +class BER(object): + """ + Robey's tiny little attempt at a BER decoder. + """ + + def __init__(self, content=''): + self.content = content + self.idx = 0 + + def __str__(self): + return self.content + + def __repr__(self): + return 'BER(\'' + repr(self.content) + '\')' + + def decode(self): + return self.decode_next() + + def decode_next(self): + if self.idx >= len(self.content): + return None + ident = ord(self.content[self.idx]) + self.idx += 1 + if (ident & 31) == 31: + # identifier > 30 + ident = 0 + while self.idx < len(self.content): + t = ord(self.content[self.idx]) + self.idx += 1 + ident = (ident << 7) | (t & 0x7f) + if not (t & 0x80): + break + if self.idx >= len(self.content): + return None + # now fetch length + size = ord(self.content[self.idx]) + self.idx += 1 + if size & 0x80: + # more complimicated... + # FIXME: theoretically should handle indefinite-length (0x80) + t = size & 0x7f + if self.idx + t > len(self.content): + return None + size = util.inflate_long(self.content[self.idx : self.idx + t], True) + self.idx += t + if self.idx + size > len(self.content): + # can't fit + return None + data = self.content[self.idx : self.idx + size] + self.idx += size + # now switch on id + if ident == 0x30: + # sequence + return self.decode_sequence(data) + elif ident == 2: + # int + return util.inflate_long(data) + else: + # 1: boolean (00 false, otherwise true) + raise BERException('Unknown ber encoding type %d (robey is lazy)' % ident) + + def decode_sequence(data): + out = [] + b = BER(data) + while True: + x = b.decode_next() + if x is None: + break + out.append(x) + return out + decode_sequence = staticmethod(decode_sequence) + + def encode_tlv(self, ident, val): + # no need to support ident > 31 here + self.content += chr(ident) + if len(val) > 0x7f: + lenstr = util.deflate_long(len(val)) + self.content += chr(0x80 + len(lenstr)) + lenstr + else: + self.content += chr(len(val)) + self.content += val + + def encode(self, x): + if type(x) is bool: + if x: + self.encode_tlv(1, '\xff') + else: + self.encode_tlv(1, '\x00') + elif (type(x) is int) or (type(x) is long): + self.encode_tlv(2, util.deflate_long(x)) + elif type(x) is str: + self.encode_tlv(4, x) + elif (type(x) is list) or (type(x) is tuple): + self.encode_tlv(0x30, self.encode_sequence(x)) + else: + raise BERException('Unknown type for encoding: %s' % repr(type(x))) + + def encode_sequence(data): + b = BER() + for item in data: + b.encode(item) + return str(b) + encode_sequence = staticmethod(encode_sequence) diff --git a/tools/migration/paramiko/buffered_pipe.py b/tools/migration/paramiko/buffered_pipe.py new file mode 100644 index 00000000000..b19d74b21c7 --- /dev/null +++ b/tools/migration/paramiko/buffered_pipe.py @@ -0,0 +1,200 @@ +# Copyright (C) 2006-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Attempt to generalize the "feeder" part of a Channel: an object which can be +read from and closed, but is reading from a buffer fed by another thread. The +read operations are blocking and can have a timeout set. +""" + +import array +import threading +import time + + +class PipeTimeout (IOError): + """ + Indicates that a timeout was reached on a read from a L{BufferedPipe}. + """ + pass + + +class BufferedPipe (object): + """ + A buffer that obeys normal read (with timeout) & close semantics for a + file or socket, but is fed data from another thread. This is used by + L{Channel}. + """ + + def __init__(self): + self._lock = threading.Lock() + self._cv = threading.Condition(self._lock) + self._event = None + self._buffer = array.array('B') + self._closed = False + + def set_event(self, event): + """ + Set an event on this buffer. When data is ready to be read (or the + buffer has been closed), the event will be set. When no data is + ready, the event will be cleared. + + @param event: the event to set/clear + @type event: Event + """ + self._event = event + if len(self._buffer) > 0: + event.set() + else: + event.clear() + + def feed(self, data): + """ + Feed new data into this pipe. This method is assumed to be called + from a separate thread, so synchronization is done. + + @param data: the data to add + @type data: str + """ + self._lock.acquire() + try: + if self._event is not None: + self._event.set() + self._buffer.fromstring(data) + self._cv.notifyAll() + finally: + self._lock.release() + + def read_ready(self): + """ + Returns true if data is buffered and ready to be read from this + feeder. A C{False} result does not mean that the feeder has closed; + it means you may need to wait before more data arrives. + + @return: C{True} if a L{read} call would immediately return at least + one byte; C{False} otherwise. + @rtype: bool + """ + self._lock.acquire() + try: + if len(self._buffer) == 0: + return False + return True + finally: + self._lock.release() + + def read(self, nbytes, timeout=None): + """ + Read data from the pipe. The return value is a string representing + the data received. The maximum amount of data to be received at once + is specified by C{nbytes}. If a string of length zero is returned, + the pipe has been closed. + + The optional C{timeout} argument can be a nonnegative float expressing + seconds, or C{None} for no timeout. If a float is given, a + C{PipeTimeout} will be raised if the timeout period value has + elapsed before any data arrives. + + @param nbytes: maximum number of bytes to read + @type nbytes: int + @param timeout: maximum seconds to wait (or C{None}, the default, to + wait forever) + @type timeout: float + @return: data + @rtype: str + + @raise PipeTimeout: if a timeout was specified and no data was ready + before that timeout + """ + out = '' + self._lock.acquire() + try: + if len(self._buffer) == 0: + if self._closed: + return out + # should we block? + if timeout == 0.0: + raise PipeTimeout() + # loop here in case we get woken up but a different thread has + # grabbed everything in the buffer. + while (len(self._buffer) == 0) and not self._closed: + then = time.time() + self._cv.wait(timeout) + if timeout is not None: + timeout -= time.time() - then + if timeout <= 0.0: + raise PipeTimeout() + + # something's in the buffer and we have the lock! + if len(self._buffer) <= nbytes: + out = self._buffer.tostring() + del self._buffer[:] + if (self._event is not None) and not self._closed: + self._event.clear() + else: + out = self._buffer[:nbytes].tostring() + del self._buffer[:nbytes] + finally: + self._lock.release() + + return out + + def empty(self): + """ + Clear out the buffer and return all data that was in it. + + @return: any data that was in the buffer prior to clearing it out + @rtype: str + """ + self._lock.acquire() + try: + out = self._buffer.tostring() + del self._buffer[:] + if (self._event is not None) and not self._closed: + self._event.clear() + return out + finally: + self._lock.release() + + def close(self): + """ + Close this pipe object. Future calls to L{read} after the buffer + has been emptied will return immediately with an empty string. + """ + self._lock.acquire() + try: + self._closed = True + self._cv.notifyAll() + if self._event is not None: + self._event.set() + finally: + self._lock.release() + + def __len__(self): + """ + Return the number of bytes buffered. + + @return: number of bytes bufferes + @rtype: int + """ + self._lock.acquire() + try: + return len(self._buffer) + finally: + self._lock.release() + diff --git a/tools/migration/paramiko/channel.py b/tools/migration/paramiko/channel.py new file mode 100644 index 00000000000..4694eef3644 --- /dev/null +++ b/tools/migration/paramiko/channel.py @@ -0,0 +1,1234 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Abstraction for an SSH2 channel. +""" + +import binascii +import sys +import time +import threading +import socket +import os + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException +from paramiko.file import BufferedFile +from paramiko.buffered_pipe import BufferedPipe, PipeTimeout +from paramiko import pipe + + +# lower bound on the max packet size we'll accept from the remote host +MIN_PACKET_SIZE = 1024 + + +class Channel (object): + """ + A secure tunnel across an SSH L{Transport}. A Channel is meant to behave + like a socket, and has an API that should be indistinguishable from the + python socket API. + + Because SSH2 has a windowing kind of flow control, if you stop reading data + from a Channel and its buffer fills up, the server will be unable to send + you any more data until you read some of it. (This won't affect other + channels on the same transport -- all channels on a single transport are + flow-controlled independently.) Similarly, if the server isn't reading + data you send, calls to L{send} may block, unless you set a timeout. This + is exactly like a normal network socket, so it shouldn't be too surprising. + """ + + def __init__(self, chanid): + """ + Create a new channel. The channel is not associated with any + particular session or L{Transport} until the Transport attaches it. + Normally you would only call this method from the constructor of a + subclass of L{Channel}. + + @param chanid: the ID of this channel, as passed by an existing + L{Transport}. + @type chanid: int + """ + self.chanid = chanid + self.remote_chanid = 0 + self.transport = None + self.active = False + self.eof_received = 0 + self.eof_sent = 0 + self.in_buffer = BufferedPipe() + self.in_stderr_buffer = BufferedPipe() + self.timeout = None + self.closed = False + self.ultra_debug = False + self.lock = threading.Lock() + self.out_buffer_cv = threading.Condition(self.lock) + self.in_window_size = 0 + self.out_window_size = 0 + self.in_max_packet_size = 0 + self.out_max_packet_size = 0 + self.in_window_threshold = 0 + self.in_window_sofar = 0 + self.status_event = threading.Event() + self._name = str(chanid) + self.logger = util.get_logger('paramiko.transport') + self._pipe = None + self.event = threading.Event() + self.event_ready = False + self.combine_stderr = False + self.exit_status = -1 + self.origin_addr = None + + def __del__(self): + try: + self.close() + except: + pass + + def __repr__(self): + """ + Return a string representation of this object, for debugging. + + @rtype: str + """ + out = ' 0: + out += ' in-buffer=%d' % (len(self.in_buffer),) + out += ' -> ' + repr(self.transport) + out += '>' + return out + + def get_pty(self, term='vt100', width=80, height=24): + """ + Request a pseudo-terminal from the server. This is usually used right + after creating a client channel, to ask the server to provide some + basic terminal semantics for a shell invoked with L{invoke_shell}. + It isn't necessary (or desirable) to call this method if you're going + to exectue a single command with L{exec_command}. + + @param term: the terminal type to emulate (for example, C{'vt100'}) + @type term: str + @param width: width (in characters) of the terminal screen + @type width: int + @param height: height (in characters) of the terminal screen + @type height: int + + @raise SSHException: if the request was rejected or the channel was + closed + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('pty-req') + m.add_boolean(True) + m.add_string(term) + m.add_int(width) + m.add_int(height) + # pixel height, width (usually useless) + m.add_int(0).add_int(0) + m.add_string('') + self._event_pending() + self.transport._send_user_message(m) + self._wait_for_event() + + def invoke_shell(self): + """ + Request an interactive shell session on this channel. If the server + allows it, the channel will then be directly connected to the stdin, + stdout, and stderr of the shell. + + Normally you would call L{get_pty} before this, in which case the + shell will operate through the pty, and the channel will be connected + to the stdin and stdout of the pty. + + When the shell exits, the channel will be closed and can't be reused. + You must open a new channel if you wish to open another shell. + + @raise SSHException: if the request was rejected or the channel was + closed + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('shell') + m.add_boolean(1) + self._event_pending() + self.transport._send_user_message(m) + self._wait_for_event() + + def exec_command(self, command): + """ + Execute a command on the server. If the server allows it, the channel + will then be directly connected to the stdin, stdout, and stderr of + the command being executed. + + When the command finishes executing, the channel will be closed and + can't be reused. You must open a new channel if you wish to execute + another command. + + @param command: a shell command to execute. + @type command: str + + @raise SSHException: if the request was rejected or the channel was + closed + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('exec') + m.add_boolean(True) + m.add_string(command) + self._event_pending() + self.transport._send_user_message(m) + self._wait_for_event() + + def invoke_subsystem(self, subsystem): + """ + Request a subsystem on the server (for example, C{sftp}). If the + server allows it, the channel will then be directly connected to the + requested subsystem. + + When the subsystem finishes, the channel will be closed and can't be + reused. + + @param subsystem: name of the subsystem being requested. + @type subsystem: str + + @raise SSHException: if the request was rejected or the channel was + closed + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('subsystem') + m.add_boolean(True) + m.add_string(subsystem) + self._event_pending() + self.transport._send_user_message(m) + self._wait_for_event() + + def resize_pty(self, width=80, height=24): + """ + Resize the pseudo-terminal. This can be used to change the width and + height of the terminal emulation created in a previous L{get_pty} call. + + @param width: new width (in characters) of the terminal screen + @type width: int + @param height: new height (in characters) of the terminal screen + @type height: int + + @raise SSHException: if the request was rejected or the channel was + closed + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('window-change') + m.add_boolean(True) + m.add_int(width) + m.add_int(height) + m.add_int(0).add_int(0) + self._event_pending() + self.transport._send_user_message(m) + self._wait_for_event() + + def exit_status_ready(self): + """ + Return true if the remote process has exited and returned an exit + status. You may use this to poll the process status if you don't + want to block in L{recv_exit_status}. Note that the server may not + return an exit status in some cases (like bad servers). + + @return: True if L{recv_exit_status} will return immediately + @rtype: bool + @since: 1.7.3 + """ + return self.closed or self.status_event.isSet() + + def recv_exit_status(self): + """ + Return the exit status from the process on the server. This is + mostly useful for retrieving the reults of an L{exec_command}. + If the command hasn't finished yet, this method will wait until + it does, or until the channel is closed. If no exit status is + provided by the server, -1 is returned. + + @return: the exit code of the process on the server. + @rtype: int + + @since: 1.2 + """ + self.status_event.wait() + assert self.status_event.isSet() + return self.exit_status + + def send_exit_status(self, status): + """ + Send the exit status of an executed command to the client. (This + really only makes sense in server mode.) Many clients expect to + get some sort of status code back from an executed command after + it completes. + + @param status: the exit code of the process + @type status: int + + @since: 1.2 + """ + # in many cases, the channel will not still be open here. + # that's fine. + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('exit-status') + m.add_boolean(False) + m.add_int(status) + self.transport._send_user_message(m) + + def request_x11(self, screen_number=0, auth_protocol=None, auth_cookie=None, + single_connection=False, handler=None): + """ + Request an x11 session on this channel. If the server allows it, + further x11 requests can be made from the server to the client, + when an x11 application is run in a shell session. + + From RFC4254:: + + It is RECOMMENDED that the 'x11 authentication cookie' that is + sent be a fake, random cookie, and that the cookie be checked and + replaced by the real cookie when a connection request is received. + + If you omit the auth_cookie, a new secure random 128-bit value will be + generated, used, and returned. You will need to use this value to + verify incoming x11 requests and replace them with the actual local + x11 cookie (which requires some knoweldge of the x11 protocol). + + If a handler is passed in, the handler is called from another thread + whenever a new x11 connection arrives. The default handler queues up + incoming x11 connections, which may be retrieved using + L{Transport.accept}. The handler's calling signature is:: + + handler(channel: Channel, (address: str, port: int)) + + @param screen_number: the x11 screen number (0, 10, etc) + @type screen_number: int + @param auth_protocol: the name of the X11 authentication method used; + if none is given, C{"MIT-MAGIC-COOKIE-1"} is used + @type auth_protocol: str + @param auth_cookie: hexadecimal string containing the x11 auth cookie; + if none is given, a secure random 128-bit value is generated + @type auth_cookie: str + @param single_connection: if True, only a single x11 connection will be + forwarded (by default, any number of x11 connections can arrive + over this session) + @type single_connection: bool + @param handler: an optional handler to use for incoming X11 connections + @type handler: function + @return: the auth_cookie used + """ + if self.closed or self.eof_received or self.eof_sent or not self.active: + raise SSHException('Channel is not open') + if auth_protocol is None: + auth_protocol = 'MIT-MAGIC-COOKIE-1' + if auth_cookie is None: + auth_cookie = binascii.hexlify(self.transport.randpool.get_bytes(16)) + + m = Message() + m.add_byte(chr(MSG_CHANNEL_REQUEST)) + m.add_int(self.remote_chanid) + m.add_string('x11-req') + m.add_boolean(True) + m.add_boolean(single_connection) + m.add_string(auth_protocol) + m.add_string(auth_cookie) + m.add_int(screen_number) + self._event_pending() + self.transport._send_user_message(m) + self._wait_for_event() + self.transport._set_x11_handler(handler) + return auth_cookie + + def get_transport(self): + """ + Return the L{Transport} associated with this channel. + + @return: the L{Transport} that was used to create this channel. + @rtype: L{Transport} + """ + return self.transport + + def set_name(self, name): + """ + Set a name for this channel. Currently it's only used to set the name + of the channel in logfile entries. The name can be fetched with the + L{get_name} method. + + @param name: new channel name + @type name: str + """ + self._name = name + + def get_name(self): + """ + Get the name of this channel that was previously set by L{set_name}. + + @return: the name of this channel. + @rtype: str + """ + return self._name + + def get_id(self): + """ + Return the ID # for this channel. The channel ID is unique across + a L{Transport} and usually a small number. It's also the number + passed to L{ServerInterface.check_channel_request} when determining + whether to accept a channel request in server mode. + + @return: the ID of this channel. + @rtype: int + """ + return self.chanid + + def set_combine_stderr(self, combine): + """ + Set whether stderr should be combined into stdout on this channel. + The default is C{False}, but in some cases it may be convenient to + have both streams combined. + + If this is C{False}, and L{exec_command} is called (or C{invoke_shell} + with no pty), output to stderr will not show up through the L{recv} + and L{recv_ready} calls. You will have to use L{recv_stderr} and + L{recv_stderr_ready} to get stderr output. + + If this is C{True}, data will never show up via L{recv_stderr} or + L{recv_stderr_ready}. + + @param combine: C{True} if stderr output should be combined into + stdout on this channel. + @type combine: bool + @return: previous setting. + @rtype: bool + + @since: 1.1 + """ + data = '' + self.lock.acquire() + try: + old = self.combine_stderr + self.combine_stderr = combine + if combine and not old: + # copy old stderr buffer into primary buffer + data = self.in_stderr_buffer.empty() + finally: + self.lock.release() + if len(data) > 0: + self._feed(data) + return old + + + ### socket API + + + def settimeout(self, timeout): + """ + Set a timeout on blocking read/write operations. The C{timeout} + argument can be a nonnegative float expressing seconds, or C{None}. If + a float is given, subsequent channel read/write operations will raise + a timeout exception if the timeout period value has elapsed before the + operation has completed. Setting a timeout of C{None} disables + timeouts on socket operations. + + C{chan.settimeout(0.0)} is equivalent to C{chan.setblocking(0)}; + C{chan.settimeout(None)} is equivalent to C{chan.setblocking(1)}. + + @param timeout: seconds to wait for a pending read/write operation + before raising C{socket.timeout}, or C{None} for no timeout. + @type timeout: float + """ + self.timeout = timeout + + def gettimeout(self): + """ + Returns the timeout in seconds (as a float) associated with socket + operations, or C{None} if no timeout is set. This reflects the last + call to L{setblocking} or L{settimeout}. + + @return: timeout in seconds, or C{None}. + @rtype: float + """ + return self.timeout + + def setblocking(self, blocking): + """ + Set blocking or non-blocking mode of the channel: if C{blocking} is 0, + the channel is set to non-blocking mode; otherwise it's set to blocking + mode. Initially all channels are in blocking mode. + + In non-blocking mode, if a L{recv} call doesn't find any data, or if a + L{send} call can't immediately dispose of the data, an error exception + is raised. In blocking mode, the calls block until they can proceed. An + EOF condition is considered "immediate data" for L{recv}, so if the + channel is closed in the read direction, it will never block. + + C{chan.setblocking(0)} is equivalent to C{chan.settimeout(0)}; + C{chan.setblocking(1)} is equivalent to C{chan.settimeout(None)}. + + @param blocking: 0 to set non-blocking mode; non-0 to set blocking + mode. + @type blocking: int + """ + if blocking: + self.settimeout(None) + else: + self.settimeout(0.0) + + def getpeername(self): + """ + Return the address of the remote side of this Channel, if possible. + This is just a wrapper around C{'getpeername'} on the Transport, used + to provide enough of a socket-like interface to allow asyncore to work. + (asyncore likes to call C{'getpeername'}.) + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + return self.transport.getpeername() + + def close(self): + """ + Close the channel. All future read/write operations on the channel + will fail. The remote end will receive no more data (after queued data + is flushed). Channels are automatically closed when their L{Transport} + is closed or when they are garbage collected. + """ + self.lock.acquire() + try: + # only close the pipe when the user explicitly closes the channel. + # otherwise they will get unpleasant surprises. (and do it before + # checking self.closed, since the remote host may have already + # closed the connection.) + if self._pipe is not None: + self._pipe.close() + self._pipe = None + + if not self.active or self.closed: + return + msgs = self._close_internal() + finally: + self.lock.release() + for m in msgs: + if m is not None: + self.transport._send_user_message(m) + + def recv_ready(self): + """ + Returns true if data is buffered and ready to be read from this + channel. A C{False} result does not mean that the channel has closed; + it means you may need to wait before more data arrives. + + @return: C{True} if a L{recv} call on this channel would immediately + return at least one byte; C{False} otherwise. + @rtype: boolean + """ + return self.in_buffer.read_ready() + + def recv(self, nbytes): + """ + Receive data from the channel. The return value is a string + representing the data received. The maximum amount of data to be + received at once is specified by C{nbytes}. If a string of length zero + is returned, the channel stream has closed. + + @param nbytes: maximum number of bytes to read. + @type nbytes: int + @return: data. + @rtype: str + + @raise socket.timeout: if no data is ready before the timeout set by + L{settimeout}. + """ + try: + out = self.in_buffer.read(nbytes, self.timeout) + except PipeTimeout, e: + raise socket.timeout() + + ack = self._check_add_window(len(out)) + # no need to hold the channel lock when sending this + if ack > 0: + m = Message() + m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m.add_int(self.remote_chanid) + m.add_int(ack) + self.transport._send_user_message(m) + + return out + + def recv_stderr_ready(self): + """ + Returns true if data is buffered and ready to be read from this + channel's stderr stream. Only channels using L{exec_command} or + L{invoke_shell} without a pty will ever have data on the stderr + stream. + + @return: C{True} if a L{recv_stderr} call on this channel would + immediately return at least one byte; C{False} otherwise. + @rtype: boolean + + @since: 1.1 + """ + return self.in_stderr_buffer.read_ready() + + def recv_stderr(self, nbytes): + """ + Receive data from the channel's stderr stream. Only channels using + L{exec_command} or L{invoke_shell} without a pty will ever have data + on the stderr stream. The return value is a string representing the + data received. The maximum amount of data to be received at once is + specified by C{nbytes}. If a string of length zero is returned, the + channel stream has closed. + + @param nbytes: maximum number of bytes to read. + @type nbytes: int + @return: data. + @rtype: str + + @raise socket.timeout: if no data is ready before the timeout set by + L{settimeout}. + + @since: 1.1 + """ + try: + out = self.in_stderr_buffer.read(nbytes, self.timeout) + except PipeTimeout, e: + raise socket.timeout() + + ack = self._check_add_window(len(out)) + # no need to hold the channel lock when sending this + if ack > 0: + m = Message() + m.add_byte(chr(MSG_CHANNEL_WINDOW_ADJUST)) + m.add_int(self.remote_chanid) + m.add_int(ack) + self.transport._send_user_message(m) + + return out + + def send_ready(self): + """ + Returns true if data can be written to this channel without blocking. + This means the channel is either closed (so any write attempt would + return immediately) or there is at least one byte of space in the + outbound buffer. If there is at least one byte of space in the + outbound buffer, a L{send} call will succeed immediately and return + the number of bytes actually written. + + @return: C{True} if a L{send} call on this channel would immediately + succeed or fail + @rtype: boolean + """ + self.lock.acquire() + try: + if self.closed or self.eof_sent: + return True + return self.out_window_size > 0 + finally: + self.lock.release() + + def send(self, s): + """ + Send data to the channel. Returns the number of bytes sent, or 0 if + the channel stream is closed. Applications are responsible for + checking that all data has been sent: if only some of the data was + transmitted, the application needs to attempt delivery of the remaining + data. + + @param s: data to send + @type s: str + @return: number of bytes actually sent + @rtype: int + + @raise socket.timeout: if no data could be sent before the timeout set + by L{settimeout}. + """ + size = len(s) + self.lock.acquire() + try: + size = self._wait_for_send_window(size) + if size == 0: + # eof or similar + return 0 + m = Message() + m.add_byte(chr(MSG_CHANNEL_DATA)) + m.add_int(self.remote_chanid) + m.add_string(s[:size]) + finally: + self.lock.release() + # Note: We release self.lock before calling _send_user_message. + # Otherwise, we can deadlock during re-keying. + self.transport._send_user_message(m) + return size + + def send_stderr(self, s): + """ + Send data to the channel on the "stderr" stream. This is normally + only used by servers to send output from shell commands -- clients + won't use this. Returns the number of bytes sent, or 0 if the channel + stream is closed. Applications are responsible for checking that all + data has been sent: if only some of the data was transmitted, the + application needs to attempt delivery of the remaining data. + + @param s: data to send. + @type s: str + @return: number of bytes actually sent. + @rtype: int + + @raise socket.timeout: if no data could be sent before the timeout set + by L{settimeout}. + + @since: 1.1 + """ + size = len(s) + self.lock.acquire() + try: + size = self._wait_for_send_window(size) + if size == 0: + # eof or similar + return 0 + m = Message() + m.add_byte(chr(MSG_CHANNEL_EXTENDED_DATA)) + m.add_int(self.remote_chanid) + m.add_int(1) + m.add_string(s[:size]) + finally: + self.lock.release() + # Note: We release self.lock before calling _send_user_message. + # Otherwise, we can deadlock during re-keying. + self.transport._send_user_message(m) + return size + + def sendall(self, s): + """ + Send data to the channel, without allowing partial results. Unlike + L{send}, this method continues to send data from the given string until + either all data has been sent or an error occurs. Nothing is returned. + + @param s: data to send. + @type s: str + + @raise socket.timeout: if sending stalled for longer than the timeout + set by L{settimeout}. + @raise socket.error: if an error occured before the entire string was + sent. + + @note: If the channel is closed while only part of the data hase been + sent, there is no way to determine how much data (if any) was sent. + This is irritating, but identically follows python's API. + """ + while s: + if self.closed: + # this doesn't seem useful, but it is the documented behavior of Socket + raise socket.error('Socket is closed') + sent = self.send(s) + s = s[sent:] + return None + + def sendall_stderr(self, s): + """ + Send data to the channel's "stderr" stream, without allowing partial + results. Unlike L{send_stderr}, this method continues to send data + from the given string until all data has been sent or an error occurs. + Nothing is returned. + + @param s: data to send to the client as "stderr" output. + @type s: str + + @raise socket.timeout: if sending stalled for longer than the timeout + set by L{settimeout}. + @raise socket.error: if an error occured before the entire string was + sent. + + @since: 1.1 + """ + while s: + if self.closed: + raise socket.error('Socket is closed') + sent = self.send_stderr(s) + s = s[sent:] + return None + + def makefile(self, *params): + """ + Return a file-like object associated with this channel. The optional + C{mode} and C{bufsize} arguments are interpreted the same way as by + the built-in C{file()} function in python. + + @return: object which can be used for python file I/O. + @rtype: L{ChannelFile} + """ + return ChannelFile(*([self] + list(params))) + + def makefile_stderr(self, *params): + """ + Return a file-like object associated with this channel's stderr + stream. Only channels using L{exec_command} or L{invoke_shell} + without a pty will ever have data on the stderr stream. + + The optional C{mode} and C{bufsize} arguments are interpreted the + same way as by the built-in C{file()} function in python. For a + client, it only makes sense to open this file for reading. For a + server, it only makes sense to open this file for writing. + + @return: object which can be used for python file I/O. + @rtype: L{ChannelFile} + + @since: 1.1 + """ + return ChannelStderrFile(*([self] + list(params))) + + def fileno(self): + """ + Returns an OS-level file descriptor which can be used for polling, but + but I{not} for reading or writing. This is primaily to allow python's + C{select} module to work. + + The first time C{fileno} is called on a channel, a pipe is created to + simulate real OS-level file descriptor (FD) behavior. Because of this, + two OS-level FDs are created, which will use up FDs faster than normal. + (You won't notice this effect unless you have hundreds of channels + open at the same time.) + + @return: an OS-level file descriptor + @rtype: int + + @warning: This method causes channel reads to be slightly less + efficient. + """ + self.lock.acquire() + try: + if self._pipe is not None: + return self._pipe.fileno() + # create the pipe and feed in any existing data + self._pipe = pipe.make_pipe() + p1, p2 = pipe.make_or_pipe(self._pipe) + self.in_buffer.set_event(p1) + self.in_stderr_buffer.set_event(p2) + return self._pipe.fileno() + finally: + self.lock.release() + + def shutdown(self, how): + """ + Shut down one or both halves of the connection. If C{how} is 0, + further receives are disallowed. If C{how} is 1, further sends + are disallowed. If C{how} is 2, further sends and receives are + disallowed. This closes the stream in one or both directions. + + @param how: 0 (stop receiving), 1 (stop sending), or 2 (stop + receiving and sending). + @type how: int + """ + if (how == 0) or (how == 2): + # feign "read" shutdown + self.eof_received = 1 + if (how == 1) or (how == 2): + self.lock.acquire() + try: + m = self._send_eof() + finally: + self.lock.release() + if m is not None: + self.transport._send_user_message(m) + + def shutdown_read(self): + """ + Shutdown the receiving side of this socket, closing the stream in + the incoming direction. After this call, future reads on this + channel will fail instantly. This is a convenience method, equivalent + to C{shutdown(0)}, for people who don't make it a habit to + memorize unix constants from the 1970s. + + @since: 1.2 + """ + self.shutdown(0) + + def shutdown_write(self): + """ + Shutdown the sending side of this socket, closing the stream in + the outgoing direction. After this call, future writes on this + channel will fail instantly. This is a convenience method, equivalent + to C{shutdown(1)}, for people who don't make it a habit to + memorize unix constants from the 1970s. + + @since: 1.2 + """ + self.shutdown(1) + + + ### calls from Transport + + + def _set_transport(self, transport): + self.transport = transport + self.logger = util.get_logger(self.transport.get_log_channel()) + + def _set_window(self, window_size, max_packet_size): + self.in_window_size = window_size + self.in_max_packet_size = max_packet_size + # threshold of bytes we receive before we bother to send a window update + self.in_window_threshold = window_size // 10 + self.in_window_sofar = 0 + self._log(DEBUG, 'Max packet in: %d bytes' % max_packet_size) + + def _set_remote_channel(self, chanid, window_size, max_packet_size): + self.remote_chanid = chanid + self.out_window_size = window_size + self.out_max_packet_size = max(max_packet_size, MIN_PACKET_SIZE) + self.active = 1 + self._log(DEBUG, 'Max packet out: %d bytes' % max_packet_size) + + def _request_success(self, m): + self._log(DEBUG, 'Sesch channel %d request ok' % self.chanid) + self.event_ready = True + self.event.set() + return + + def _request_failed(self, m): + self.lock.acquire() + try: + msgs = self._close_internal() + finally: + self.lock.release() + for m in msgs: + if m is not None: + self.transport._send_user_message(m) + + def _feed(self, m): + if type(m) is str: + # passed from _feed_extended + s = m + else: + s = m.get_string() + self.in_buffer.feed(s) + + def _feed_extended(self, m): + code = m.get_int() + s = m.get_string() + if code != 1: + self._log(ERROR, 'unknown extended_data type %d; discarding' % code) + return + if self.combine_stderr: + self._feed(s) + else: + self.in_stderr_buffer.feed(s) + + def _window_adjust(self, m): + nbytes = m.get_int() + self.lock.acquire() + try: + if self.ultra_debug: + self._log(DEBUG, 'window up %d' % nbytes) + self.out_window_size += nbytes + self.out_buffer_cv.notifyAll() + finally: + self.lock.release() + + def _handle_request(self, m): + key = m.get_string() + want_reply = m.get_boolean() + server = self.transport.server_object + ok = False + if key == 'exit-status': + self.exit_status = m.get_int() + self.status_event.set() + ok = True + elif key == 'xon-xoff': + # ignore + ok = True + elif key == 'pty-req': + term = m.get_string() + width = m.get_int() + height = m.get_int() + pixelwidth = m.get_int() + pixelheight = m.get_int() + modes = m.get_string() + if server is None: + ok = False + else: + ok = server.check_channel_pty_request(self, term, width, height, pixelwidth, + pixelheight, modes) + elif key == 'shell': + if server is None: + ok = False + else: + ok = server.check_channel_shell_request(self) + elif key == 'exec': + cmd = m.get_string() + if server is None: + ok = False + else: + ok = server.check_channel_exec_request(self, cmd) + elif key == 'subsystem': + name = m.get_string() + if server is None: + ok = False + else: + ok = server.check_channel_subsystem_request(self, name) + elif key == 'window-change': + width = m.get_int() + height = m.get_int() + pixelwidth = m.get_int() + pixelheight = m.get_int() + if server is None: + ok = False + else: + ok = server.check_channel_window_change_request(self, width, height, pixelwidth, + pixelheight) + elif key == 'x11-req': + single_connection = m.get_boolean() + auth_proto = m.get_string() + auth_cookie = m.get_string() + screen_number = m.get_int() + if server is None: + ok = False + else: + ok = server.check_channel_x11_request(self, single_connection, + auth_proto, auth_cookie, screen_number) + else: + self._log(DEBUG, 'Unhandled channel request "%s"' % key) + ok = False + if want_reply: + m = Message() + if ok: + m.add_byte(chr(MSG_CHANNEL_SUCCESS)) + else: + m.add_byte(chr(MSG_CHANNEL_FAILURE)) + m.add_int(self.remote_chanid) + self.transport._send_user_message(m) + + def _handle_eof(self, m): + self.lock.acquire() + try: + if not self.eof_received: + self.eof_received = True + self.in_buffer.close() + self.in_stderr_buffer.close() + if self._pipe is not None: + self._pipe.set_forever() + finally: + self.lock.release() + self._log(DEBUG, 'EOF received (%s)', self._name) + + def _handle_close(self, m): + self.lock.acquire() + try: + msgs = self._close_internal() + self.transport._unlink_channel(self.chanid) + finally: + self.lock.release() + for m in msgs: + if m is not None: + self.transport._send_user_message(m) + + + ### internals... + + + def _log(self, level, msg, *args): + self.logger.log(level, "[chan " + self._name + "] " + msg, *args) + + def _event_pending(self): + self.event.clear() + self.event_ready = False + + def _wait_for_event(self): + self.event.wait() + assert self.event.isSet() + if self.event_ready: + return + e = self.transport.get_exception() + if e is None: + e = SSHException('Channel closed.') + raise e + + def _set_closed(self): + # you are holding the lock. + self.closed = True + self.in_buffer.close() + self.in_stderr_buffer.close() + self.out_buffer_cv.notifyAll() + # Notify any waiters that we are closed + self.event.set() + self.status_event.set() + if self._pipe is not None: + self._pipe.set_forever() + + def _send_eof(self): + # you are holding the lock. + if self.eof_sent: + return None + m = Message() + m.add_byte(chr(MSG_CHANNEL_EOF)) + m.add_int(self.remote_chanid) + self.eof_sent = True + self._log(DEBUG, 'EOF sent (%s)', self._name) + return m + + def _close_internal(self): + # you are holding the lock. + if not self.active or self.closed: + return None, None + m1 = self._send_eof() + m2 = Message() + m2.add_byte(chr(MSG_CHANNEL_CLOSE)) + m2.add_int(self.remote_chanid) + self._set_closed() + # can't unlink from the Transport yet -- the remote side may still + # try to send meta-data (exit-status, etc) + return m1, m2 + + def _unlink(self): + # server connection could die before we become active: still signal the close! + if self.closed: + return + self.lock.acquire() + try: + self._set_closed() + self.transport._unlink_channel(self.chanid) + finally: + self.lock.release() + + def _check_add_window(self, n): + self.lock.acquire() + try: + if self.closed or self.eof_received or not self.active: + return 0 + if self.ultra_debug: + self._log(DEBUG, 'addwindow %d' % n) + self.in_window_sofar += n + if self.in_window_sofar <= self.in_window_threshold: + return 0 + if self.ultra_debug: + self._log(DEBUG, 'addwindow send %d' % self.in_window_sofar) + out = self.in_window_sofar + self.in_window_sofar = 0 + return out + finally: + self.lock.release() + + def _wait_for_send_window(self, size): + """ + (You are already holding the lock.) + Wait for the send window to open up, and allocate up to C{size} bytes + for transmission. If no space opens up before the timeout, a timeout + exception is raised. Returns the number of bytes available to send + (may be less than requested). + """ + # you are already holding the lock + if self.closed or self.eof_sent: + return 0 + if self.out_window_size == 0: + # should we block? + if self.timeout == 0.0: + raise socket.timeout() + # loop here in case we get woken up but a different thread has filled the buffer + timeout = self.timeout + while self.out_window_size == 0: + if self.closed or self.eof_sent: + return 0 + then = time.time() + self.out_buffer_cv.wait(timeout) + if timeout != None: + timeout -= time.time() - then + if timeout <= 0.0: + raise socket.timeout() + # we have some window to squeeze into + if self.closed or self.eof_sent: + return 0 + if self.out_window_size < size: + size = self.out_window_size + if self.out_max_packet_size - 64 < size: + size = self.out_max_packet_size - 64 + self.out_window_size -= size + if self.ultra_debug: + self._log(DEBUG, 'window down to %d' % self.out_window_size) + return size + + +class ChannelFile (BufferedFile): + """ + A file-like wrapper around L{Channel}. A ChannelFile is created by calling + L{Channel.makefile}. + + @bug: To correctly emulate the file object created from a socket's + C{makefile} method, a L{Channel} and its C{ChannelFile} should be able + to be closed or garbage-collected independently. Currently, closing + the C{ChannelFile} does nothing but flush the buffer. + """ + + def __init__(self, channel, mode = 'r', bufsize = -1): + self.channel = channel + BufferedFile.__init__(self) + self._set_mode(mode, bufsize) + + def __repr__(self): + """ + Returns a string representation of this object, for debugging. + + @rtype: str + """ + return '' + + def _read(self, size): + return self.channel.recv(size) + + def _write(self, data): + self.channel.sendall(data) + return len(data) + + +class ChannelStderrFile (ChannelFile): + def __init__(self, channel, mode = 'r', bufsize = -1): + ChannelFile.__init__(self, channel, mode, bufsize) + + def _read(self, size): + return self.channel.recv_stderr(size) + + def _write(self, data): + self.channel.sendall_stderr(data) + return len(data) + + +# vim: set shiftwidth=4 expandtab : diff --git a/tools/migration/paramiko/client.py b/tools/migration/paramiko/client.py new file mode 100644 index 00000000000..023b40505fb --- /dev/null +++ b/tools/migration/paramiko/client.py @@ -0,0 +1,486 @@ +# Copyright (C) 2006-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{SSHClient}. +""" + +from binascii import hexlify +import getpass +import os +import socket +import warnings + +from paramiko.agent import Agent +from paramiko.common import * +from paramiko.dsskey import DSSKey +from paramiko.hostkeys import HostKeys +from paramiko.resource import ResourceManager +from paramiko.rsakey import RSAKey +from paramiko.ssh_exception import SSHException, BadHostKeyException +from paramiko.transport import Transport + + +SSH_PORT = 22 + +class MissingHostKeyPolicy (object): + """ + Interface for defining the policy that L{SSHClient} should use when the + SSH server's hostname is not in either the system host keys or the + application's keys. Pre-made classes implement policies for automatically + adding the key to the application's L{HostKeys} object (L{AutoAddPolicy}), + and for automatically rejecting the key (L{RejectPolicy}). + + This function may be used to ask the user to verify the key, for example. + """ + + def missing_host_key(self, client, hostname, key): + """ + Called when an L{SSHClient} receives a server key for a server that + isn't in either the system or local L{HostKeys} object. To accept + the key, simply return. To reject, raised an exception (which will + be passed to the calling application). + """ + pass + + +class AutoAddPolicy (MissingHostKeyPolicy): + """ + Policy for automatically adding the hostname and new host key to the + local L{HostKeys} object, and saving it. This is used by L{SSHClient}. + """ + + def missing_host_key(self, client, hostname, key): + client._host_keys.add(hostname, key.get_name(), key) + if client._host_keys_filename is not None: + client.save_host_keys(client._host_keys_filename) + client._log(DEBUG, 'Adding %s host key for %s: %s' % + (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + + +class RejectPolicy (MissingHostKeyPolicy): + """ + Policy for automatically rejecting the unknown hostname & key. This is + used by L{SSHClient}. + """ + + def missing_host_key(self, client, hostname, key): + client._log(DEBUG, 'Rejecting %s host key for %s: %s' % + (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + raise SSHException('Unknown server %s' % hostname) + + +class WarningPolicy (MissingHostKeyPolicy): + """ + Policy for logging a python-style warning for an unknown host key, but + accepting it. This is used by L{SSHClient}. + """ + def missing_host_key(self, client, hostname, key): + warnings.warn('Unknown %s host key for %s: %s' % + (key.get_name(), hostname, hexlify(key.get_fingerprint()))) + + +class SSHClient (object): + """ + A high-level representation of a session with an SSH server. This class + wraps L{Transport}, L{Channel}, and L{SFTPClient} to take care of most + aspects of authenticating and opening channels. A typical use case is:: + + client = SSHClient() + client.load_system_host_keys() + client.connect('ssh.example.com') + stdin, stdout, stderr = client.exec_command('ls -l') + + You may pass in explicit overrides for authentication and server host key + checking. The default mechanism is to try to use local key files or an + SSH agent (if one is running). + + @since: 1.6 + """ + + def __init__(self): + """ + Create a new SSHClient. + """ + self._system_host_keys = HostKeys() + self._host_keys = HostKeys() + self._host_keys_filename = None + self._log_channel = None + self._policy = RejectPolicy() + self._transport = None + + def load_system_host_keys(self, filename=None): + """ + Load host keys from a system (read-only) file. Host keys read with + this method will not be saved back by L{save_host_keys}. + + This method can be called multiple times. Each new set of host keys + will be merged with the existing set (new replacing old if there are + conflicts). + + If C{filename} is left as C{None}, an attempt will be made to read + keys from the user's local "known hosts" file, as used by OpenSSH, + and no exception will be raised if the file can't be read. This is + probably only useful on posix. + + @param filename: the filename to read, or C{None} + @type filename: str + + @raise IOError: if a filename was provided and the file could not be + read + """ + if filename is None: + # try the user's .ssh key file, and mask exceptions + filename = os.path.expanduser('~/.ssh/known_hosts') + try: + self._system_host_keys.load(filename) + except IOError: + pass + return + self._system_host_keys.load(filename) + + def load_host_keys(self, filename): + """ + Load host keys from a local host-key file. Host keys read with this + method will be checked I{after} keys loaded via L{load_system_host_keys}, + but will be saved back by L{save_host_keys} (so they can be modified). + The missing host key policy L{AutoAddPolicy} adds keys to this set and + saves them, when connecting to a previously-unknown server. + + This method can be called multiple times. Each new set of host keys + will be merged with the existing set (new replacing old if there are + conflicts). When automatically saving, the last hostname is used. + + @param filename: the filename to read + @type filename: str + + @raise IOError: if the filename could not be read + """ + self._host_keys_filename = filename + self._host_keys.load(filename) + + def save_host_keys(self, filename): + """ + Save the host keys back to a file. Only the host keys loaded with + L{load_host_keys} (plus any added directly) will be saved -- not any + host keys loaded with L{load_system_host_keys}. + + @param filename: the filename to save to + @type filename: str + + @raise IOError: if the file could not be written + """ + f = open(filename, 'w') + f.write('# SSH host keys collected by paramiko\n') + for hostname, keys in self._host_keys.iteritems(): + for keytype, key in keys.iteritems(): + f.write('%s %s %s\n' % (hostname, keytype, key.get_base64())) + f.close() + + def get_host_keys(self): + """ + Get the local L{HostKeys} object. This can be used to examine the + local host keys or change them. + + @return: the local host keys + @rtype: L{HostKeys} + """ + return self._host_keys + + def set_log_channel(self, name): + """ + Set the channel for logging. The default is C{"paramiko.transport"} + but it can be set to anything you want. + + @param name: new channel name for logging + @type name: str + """ + self._log_channel = name + + def set_missing_host_key_policy(self, policy): + """ + Set the policy to use when connecting to a server that doesn't have a + host key in either the system or local L{HostKeys} objects. The + default policy is to reject all unknown servers (using L{RejectPolicy}). + You may substitute L{AutoAddPolicy} or write your own policy class. + + @param policy: the policy to use when receiving a host key from a + previously-unknown server + @type policy: L{MissingHostKeyPolicy} + """ + self._policy = policy + + def connect(self, hostname, port=SSH_PORT, username=None, password=None, pkey=None, + key_filename=None, timeout=None, allow_agent=True, look_for_keys=True): + """ + Connect to an SSH server and authenticate to it. The server's host key + is checked against the system host keys (see L{load_system_host_keys}) + and any local host keys (L{load_host_keys}). If the server's hostname + is not found in either set of host keys, the missing host key policy + is used (see L{set_missing_host_key_policy}). The default policy is + to reject the key and raise an L{SSHException}. + + Authentication is attempted in the following order of priority: + + - The C{pkey} or C{key_filename} passed in (if any) + - Any key we can find through an SSH agent + - Any "id_rsa" or "id_dsa" key discoverable in C{~/.ssh/} + - Plain username/password auth, if a password was given + + If a private key requires a password to unlock it, and a password is + passed in, that password will be used to attempt to unlock the key. + + @param hostname: the server to connect to + @type hostname: str + @param port: the server port to connect to + @type port: int + @param username: the username to authenticate as (defaults to the + current local username) + @type username: str + @param password: a password to use for authentication or for unlocking + a private key + @type password: str + @param pkey: an optional private key to use for authentication + @type pkey: L{PKey} + @param key_filename: the filename, or list of filenames, of optional + private key(s) to try for authentication + @type key_filename: str or list(str) + @param timeout: an optional timeout (in seconds) for the TCP connect + @type timeout: float + @param allow_agent: set to False to disable connecting to the SSH agent + @type allow_agent: bool + @param look_for_keys: set to False to disable searching for discoverable + private key files in C{~/.ssh/} + @type look_for_keys: bool + + @raise BadHostKeyException: if the server's host key could not be + verified + @raise AuthenticationException: if authentication failed + @raise SSHException: if there was any other error connecting or + establishing an SSH session + @raise socket.error: if a socket error occurred while connecting + """ + for (family, socktype, proto, canonname, sockaddr) in socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM): + if socktype == socket.SOCK_STREAM: + af = family + addr = sockaddr + break + else: + raise SSHException('No suitable address family for %s' % hostname) + sock = socket.socket(af, socket.SOCK_STREAM) + if timeout is not None: + try: + sock.settimeout(timeout) + except: + pass + sock.connect(addr) + t = self._transport = Transport(sock) + + if self._log_channel is not None: + t.set_log_channel(self._log_channel) + t.start_client() + ResourceManager.register(self, t) + + server_key = t.get_remote_server_key() + keytype = server_key.get_name() + + if port == SSH_PORT: + server_hostkey_name = hostname + else: + server_hostkey_name = "[%s]:%d" % (hostname, port) + our_server_key = self._system_host_keys.get(server_hostkey_name, {}).get(keytype, None) + if our_server_key is None: + our_server_key = self._host_keys.get(server_hostkey_name, {}).get(keytype, None) + if our_server_key is None: + # will raise exception if the key is rejected; let that fall out + self._policy.missing_host_key(self, server_hostkey_name, server_key) + # if the callback returns, assume the key is ok + our_server_key = server_key + + if server_key != our_server_key: + raise BadHostKeyException(hostname, server_key, our_server_key) + + if username is None: + username = getpass.getuser() + + if key_filename is None: + key_filenames = [] + elif isinstance(key_filename, (str, unicode)): + key_filenames = [ key_filename ] + else: + key_filenames = key_filename + self._auth(username, password, pkey, key_filenames, allow_agent, look_for_keys) + + def close(self): + """ + Close this SSHClient and its underlying L{Transport}. + """ + if self._transport is None: + return + self._transport.close() + self._transport = None + + def exec_command(self, command, bufsize=-1): + """ + Execute a command on the SSH server. A new L{Channel} is opened and + the requested command is executed. The command's input and output + streams are returned as python C{file}-like objects representing + stdin, stdout, and stderr. + + @param command: the command to execute + @type command: str + @param bufsize: interpreted the same way as by the built-in C{file()} function in python + @type bufsize: int + @return: the stdin, stdout, and stderr of the executing command + @rtype: tuple(L{ChannelFile}, L{ChannelFile}, L{ChannelFile}) + + @raise SSHException: if the server fails to execute the command + """ + chan = self._transport.open_session() + chan.exec_command(command) + stdin = chan.makefile('wb', bufsize) + stdout = chan.makefile('rb', bufsize) + stderr = chan.makefile_stderr('rb', bufsize) + return stdin, stdout, stderr + + def invoke_shell(self, term='vt100', width=80, height=24): + """ + Start an interactive shell session on the SSH server. A new L{Channel} + is opened and connected to a pseudo-terminal using the requested + terminal type and size. + + @param term: the terminal type to emulate (for example, C{"vt100"}) + @type term: str + @param width: the width (in characters) of the terminal window + @type width: int + @param height: the height (in characters) of the terminal window + @type height: int + @return: a new channel connected to the remote shell + @rtype: L{Channel} + + @raise SSHException: if the server fails to invoke a shell + """ + chan = self._transport.open_session() + chan.get_pty(term, width, height) + chan.invoke_shell() + return chan + + def open_sftp(self): + """ + Open an SFTP session on the SSH server. + + @return: a new SFTP session object + @rtype: L{SFTPClient} + """ + return self._transport.open_sftp_client() + + def get_transport(self): + """ + Return the underlying L{Transport} object for this SSH connection. + This can be used to perform lower-level tasks, like opening specific + kinds of channels. + + @return: the Transport for this connection + @rtype: L{Transport} + """ + return self._transport + + def _auth(self, username, password, pkey, key_filenames, allow_agent, look_for_keys): + """ + Try, in order: + + - The key passed in, if one was passed in. + - Any key we can find through an SSH agent (if allowed). + - Any "id_rsa" or "id_dsa" key discoverable in ~/.ssh/ (if allowed). + - Plain username/password auth, if a password was given. + + (The password might be needed to unlock a private key.) + """ + saved_exception = None + + if pkey is not None: + try: + self._log(DEBUG, 'Trying SSH key %s' % hexlify(pkey.get_fingerprint())) + self._transport.auth_publickey(username, pkey) + return + except SSHException, e: + saved_exception = e + + for key_filename in key_filenames: + for pkey_class in (RSAKey, DSSKey): + try: + key = pkey_class.from_private_key_file(key_filename, password) + self._log(DEBUG, 'Trying key %s from %s' % (hexlify(key.get_fingerprint()), key_filename)) + self._transport.auth_publickey(username, key) + return + except SSHException, e: + saved_exception = e + + if allow_agent: + for key in Agent().get_keys(): + try: + self._log(DEBUG, 'Trying SSH agent key %s' % hexlify(key.get_fingerprint())) + self._transport.auth_publickey(username, key) + return + except SSHException, e: + saved_exception = e + + keyfiles = [] + rsa_key = os.path.expanduser('~/.ssh/id_rsa') + dsa_key = os.path.expanduser('~/.ssh/id_dsa') + if os.path.isfile(rsa_key): + keyfiles.append((RSAKey, rsa_key)) + if os.path.isfile(dsa_key): + keyfiles.append((DSSKey, dsa_key)) + # look in ~/ssh/ for windows users: + rsa_key = os.path.expanduser('~/ssh/id_rsa') + dsa_key = os.path.expanduser('~/ssh/id_dsa') + if os.path.isfile(rsa_key): + keyfiles.append((RSAKey, rsa_key)) + if os.path.isfile(dsa_key): + keyfiles.append((DSSKey, dsa_key)) + + if not look_for_keys: + keyfiles = [] + + for pkey_class, filename in keyfiles: + try: + key = pkey_class.from_private_key_file(filename, password) + self._log(DEBUG, 'Trying discovered key %s in %s' % (hexlify(key.get_fingerprint()), filename)) + self._transport.auth_publickey(username, key) + return + except SSHException, e: + saved_exception = e + except IOError, e: + saved_exception = e + + if password is not None: + try: + self._transport.auth_password(username, password) + return + except SSHException, e: + saved_exception = e + + # if we got an auth-failed exception earlier, re-raise it + if saved_exception is not None: + raise saved_exception + raise SSHException('No authentication methods available') + + def _log(self, level, msg): + self._transport._log(level, msg) + diff --git a/tools/migration/paramiko/common.py b/tools/migration/paramiko/common.py new file mode 100644 index 00000000000..7a374633eb6 --- /dev/null +++ b/tools/migration/paramiko/common.py @@ -0,0 +1,126 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Common constants and global variables. +""" + +MSG_DISCONNECT, MSG_IGNORE, MSG_UNIMPLEMENTED, MSG_DEBUG, MSG_SERVICE_REQUEST, \ + MSG_SERVICE_ACCEPT = range(1, 7) +MSG_KEXINIT, MSG_NEWKEYS = range(20, 22) +MSG_USERAUTH_REQUEST, MSG_USERAUTH_FAILURE, MSG_USERAUTH_SUCCESS, \ + MSG_USERAUTH_BANNER = range(50, 54) +MSG_USERAUTH_PK_OK = 60 +MSG_USERAUTH_INFO_REQUEST, MSG_USERAUTH_INFO_RESPONSE = range(60, 62) +MSG_GLOBAL_REQUEST, MSG_REQUEST_SUCCESS, MSG_REQUEST_FAILURE = range(80, 83) +MSG_CHANNEL_OPEN, MSG_CHANNEL_OPEN_SUCCESS, MSG_CHANNEL_OPEN_FAILURE, \ + MSG_CHANNEL_WINDOW_ADJUST, MSG_CHANNEL_DATA, MSG_CHANNEL_EXTENDED_DATA, \ + MSG_CHANNEL_EOF, MSG_CHANNEL_CLOSE, MSG_CHANNEL_REQUEST, \ + MSG_CHANNEL_SUCCESS, MSG_CHANNEL_FAILURE = range(90, 101) + + +# for debugging: +MSG_NAMES = { + MSG_DISCONNECT: 'disconnect', + MSG_IGNORE: 'ignore', + MSG_UNIMPLEMENTED: 'unimplemented', + MSG_DEBUG: 'debug', + MSG_SERVICE_REQUEST: 'service-request', + MSG_SERVICE_ACCEPT: 'service-accept', + MSG_KEXINIT: 'kexinit', + MSG_NEWKEYS: 'newkeys', + 30: 'kex30', + 31: 'kex31', + 32: 'kex32', + 33: 'kex33', + 34: 'kex34', + MSG_USERAUTH_REQUEST: 'userauth-request', + MSG_USERAUTH_FAILURE: 'userauth-failure', + MSG_USERAUTH_SUCCESS: 'userauth-success', + MSG_USERAUTH_BANNER: 'userauth--banner', + MSG_USERAUTH_PK_OK: 'userauth-60(pk-ok/info-request)', + MSG_USERAUTH_INFO_RESPONSE: 'userauth-info-response', + MSG_GLOBAL_REQUEST: 'global-request', + MSG_REQUEST_SUCCESS: 'request-success', + MSG_REQUEST_FAILURE: 'request-failure', + MSG_CHANNEL_OPEN: 'channel-open', + MSG_CHANNEL_OPEN_SUCCESS: 'channel-open-success', + MSG_CHANNEL_OPEN_FAILURE: 'channel-open-failure', + MSG_CHANNEL_WINDOW_ADJUST: 'channel-window-adjust', + MSG_CHANNEL_DATA: 'channel-data', + MSG_CHANNEL_EXTENDED_DATA: 'channel-extended-data', + MSG_CHANNEL_EOF: 'channel-eof', + MSG_CHANNEL_CLOSE: 'channel-close', + MSG_CHANNEL_REQUEST: 'channel-request', + MSG_CHANNEL_SUCCESS: 'channel-success', + MSG_CHANNEL_FAILURE: 'channel-failure' + } + + +# authentication request return codes: +AUTH_SUCCESSFUL, AUTH_PARTIALLY_SUCCESSFUL, AUTH_FAILED = range(3) + + +# channel request failed reasons: +(OPEN_SUCCEEDED, + OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED, + OPEN_FAILED_CONNECT_FAILED, + OPEN_FAILED_UNKNOWN_CHANNEL_TYPE, + OPEN_FAILED_RESOURCE_SHORTAGE) = range(0, 5) + + +CONNECTION_FAILED_CODE = { + 1: 'Administratively prohibited', + 2: 'Connect failed', + 3: 'Unknown channel type', + 4: 'Resource shortage' +} + + +DISCONNECT_SERVICE_NOT_AVAILABLE, DISCONNECT_AUTH_CANCELLED_BY_USER, \ + DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE = 7, 13, 14 + +from rng import StrongLockingRandomPool + +# keep a crypto-strong PRNG nearby +randpool = StrongLockingRandomPool() + +import sys +if sys.version_info < (2, 3): + try: + import logging + except: + import logging22 as logging + import select + PY22 = True + + import socket + if not hasattr(socket, 'timeout'): + class timeout(socket.error): pass + socket.timeout = timeout + del timeout +else: + import logging + PY22 = False + + +DEBUG = logging.DEBUG +INFO = logging.INFO +WARNING = logging.WARNING +ERROR = logging.ERROR +CRITICAL = logging.CRITICAL diff --git a/tools/migration/paramiko/compress.py b/tools/migration/paramiko/compress.py new file mode 100644 index 00000000000..40b430f9018 --- /dev/null +++ b/tools/migration/paramiko/compress.py @@ -0,0 +1,39 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Compression implementations for a Transport. +""" + +import zlib + + +class ZlibCompressor (object): + def __init__(self): + self.z = zlib.compressobj(9) + + def __call__(self, data): + return self.z.compress(data) + self.z.flush(zlib.Z_FULL_FLUSH) + + +class ZlibDecompressor (object): + def __init__(self): + self.z = zlib.decompressobj() + + def __call__(self, data): + return self.z.decompress(data) diff --git a/tools/migration/paramiko/config.py b/tools/migration/paramiko/config.py new file mode 100644 index 00000000000..2a2cbff3022 --- /dev/null +++ b/tools/migration/paramiko/config.py @@ -0,0 +1,110 @@ +# Copyright (C) 2006-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{SSHConfig}. +""" + +import fnmatch + + +class SSHConfig (object): + """ + Representation of config information as stored in the format used by + OpenSSH. Queries can be made via L{lookup}. The format is described in + OpenSSH's C{ssh_config} man page. This class is provided primarily as a + convenience to posix users (since the OpenSSH format is a de-facto + standard on posix) but should work fine on Windows too. + + @since: 1.6 + """ + + def __init__(self): + """ + Create a new OpenSSH config object. + """ + self._config = [ { 'host': '*' } ] + + def parse(self, file_obj): + """ + Read an OpenSSH config from the given file object. + + @param file_obj: a file-like object to read the config file from + @type file_obj: file + """ + configs = [self._config[0]] + for line in file_obj: + line = line.rstrip('\n').lstrip() + if (line == '') or (line[0] == '#'): + continue + if '=' in line: + key, value = line.split('=', 1) + key = key.strip().lower() + else: + # find first whitespace, and split there + i = 0 + while (i < len(line)) and not line[i].isspace(): + i += 1 + if i == len(line): + raise Exception('Unparsable line: %r' % line) + key = line[:i].lower() + value = line[i:].lstrip() + + if key == 'host': + del configs[:] + # the value may be multiple hosts, space-delimited + for host in value.split(): + # do we have a pre-existing host config to append to? + matches = [c for c in self._config if c['host'] == host] + if len(matches) > 0: + configs.append(matches[0]) + else: + config = { 'host': host } + self._config.append(config) + configs.append(config) + else: + for config in configs: + config[key] = value + + def lookup(self, hostname): + """ + Return a dict of config options for a given hostname. + + The host-matching rules of OpenSSH's C{ssh_config} man page are used, + which means that all configuration options from matching host + specifications are merged, with more specific hostmasks taking + precedence. In other words, if C{"Port"} is set under C{"Host *"} + and also C{"Host *.example.com"}, and the lookup is for + C{"ssh.example.com"}, then the port entry for C{"Host *.example.com"} + will win out. + + The keys in the returned dict are all normalized to lowercase (look for + C{"port"}, not C{"Port"}. No other processing is done to the keys or + values. + + @param hostname: the hostname to lookup + @type hostname: str + """ + matches = [x for x in self._config if fnmatch.fnmatch(hostname, x['host'])] + # sort in order of shortest match (usually '*') to longest + matches.sort(lambda x,y: cmp(len(x['host']), len(y['host']))) + ret = {} + for m in matches: + ret.update(m) + del ret['host'] + return ret diff --git a/tools/migration/paramiko/dsskey.py b/tools/migration/paramiko/dsskey.py new file mode 100644 index 00000000000..eecfa6951c7 --- /dev/null +++ b/tools/migration/paramiko/dsskey.py @@ -0,0 +1,197 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{DSSKey} +""" + +from Crypto.PublicKey import DSA +from Crypto.Hash import SHA + +from paramiko.common import * +from paramiko import util +from paramiko.ssh_exception import SSHException +from paramiko.message import Message +from paramiko.ber import BER, BERException +from paramiko.pkey import PKey + + +class DSSKey (PKey): + """ + Representation of a DSS key which can be used to sign an verify SSH2 + data. + """ + + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None): + self.p = None + self.q = None + self.g = None + self.y = None + self.x = None + if file_obj is not None: + self._from_private_key(file_obj, password) + return + if filename is not None: + self._from_private_key_file(filename, password) + return + if (msg is None) and (data is not None): + msg = Message(data) + if vals is not None: + self.p, self.q, self.g, self.y = vals + else: + if msg is None: + raise SSHException('Key object may not be empty') + if msg.get_string() != 'ssh-dss': + raise SSHException('Invalid key') + self.p = msg.get_mpint() + self.q = msg.get_mpint() + self.g = msg.get_mpint() + self.y = msg.get_mpint() + self.size = util.bit_length(self.p) + + def __str__(self): + m = Message() + m.add_string('ssh-dss') + m.add_mpint(self.p) + m.add_mpint(self.q) + m.add_mpint(self.g) + m.add_mpint(self.y) + return str(m) + + def __hash__(self): + h = hash(self.get_name()) + h = h * 37 + hash(self.p) + h = h * 37 + hash(self.q) + h = h * 37 + hash(self.g) + h = h * 37 + hash(self.y) + # h might be a long by now... + return hash(h) + + def get_name(self): + return 'ssh-dss' + + def get_bits(self): + return self.size + + def can_sign(self): + return self.x is not None + + def sign_ssh_data(self, rpool, data): + digest = SHA.new(data).digest() + dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q), long(self.x))) + # generate a suitable k + qsize = len(util.deflate_long(self.q, 0)) + while True: + k = util.inflate_long(rpool.get_bytes(qsize), 1) + if (k > 2) and (k < self.q): + break + r, s = dss.sign(util.inflate_long(digest, 1), k) + m = Message() + m.add_string('ssh-dss') + # apparently, in rare cases, r or s may be shorter than 20 bytes! + rstr = util.deflate_long(r, 0) + sstr = util.deflate_long(s, 0) + if len(rstr) < 20: + rstr = '\x00' * (20 - len(rstr)) + rstr + if len(sstr) < 20: + sstr = '\x00' * (20 - len(sstr)) + sstr + m.add_string(rstr + sstr) + return m + + def verify_ssh_sig(self, data, msg): + if len(str(msg)) == 40: + # spies.com bug: signature has no header + sig = str(msg) + else: + kind = msg.get_string() + if kind != 'ssh-dss': + return 0 + sig = msg.get_string() + + # pull out (r, s) which are NOT encoded as mpints + sigR = util.inflate_long(sig[:20], 1) + sigS = util.inflate_long(sig[20:], 1) + sigM = util.inflate_long(SHA.new(data).digest(), 1) + + dss = DSA.construct((long(self.y), long(self.g), long(self.p), long(self.q))) + return dss.verify(sigM, (sigR, sigS)) + + def _encode_key(self): + if self.x is None: + raise SSHException('Not enough key information') + keylist = [ 0, self.p, self.q, self.g, self.y, self.x ] + try: + b = BER() + b.encode(keylist) + except BERException: + raise SSHException('Unable to create ber encoding of key') + return str(b) + + def write_private_key_file(self, filename, password=None): + self._write_private_key_file('DSA', filename, self._encode_key(), password) + + def write_private_key(self, file_obj, password=None): + self._write_private_key('DSA', file_obj, self._encode_key(), password) + + def generate(bits=1024, progress_func=None): + """ + Generate a new private DSS key. This factory function can be used to + generate a new host key or authentication key. + + @param bits: number of bits the generated key should be. + @type bits: int + @param progress_func: an optional function to call at key points in + key generation (used by C{pyCrypto.PublicKey}). + @type progress_func: function + @return: new private key + @rtype: L{DSSKey} + """ + randpool.stir() + dsa = DSA.generate(bits, randpool.get_bytes, progress_func) + key = DSSKey(vals=(dsa.p, dsa.q, dsa.g, dsa.y)) + key.x = dsa.x + return key + generate = staticmethod(generate) + + + ### internals... + + + def _from_private_key_file(self, filename, password): + data = self._read_private_key_file('DSA', filename, password) + self._decode_key(data) + + def _from_private_key(self, file_obj, password): + data = self._read_private_key('DSA', file_obj, password) + self._decode_key(data) + + def _decode_key(self, data): + # private key file contains: + # DSAPrivateKey = { version = 0, p, q, g, y, x } + try: + keylist = BER(data).decode() + except BERException, x: + raise SSHException('Unable to parse key file: ' + str(x)) + if (type(keylist) is not list) or (len(keylist) < 6) or (keylist[0] != 0): + raise SSHException('not a valid DSA private key file (bad ber encoding)') + self.p = keylist[1] + self.q = keylist[2] + self.g = keylist[3] + self.y = keylist[4] + self.x = keylist[5] + self.size = util.bit_length(self.p) diff --git a/tools/migration/paramiko/file.py b/tools/migration/paramiko/file.py new file mode 100644 index 00000000000..d4aec8e3c5e --- /dev/null +++ b/tools/migration/paramiko/file.py @@ -0,0 +1,456 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +BufferedFile. +""" + +from cStringIO import StringIO + + +class BufferedFile (object): + """ + Reusable base class to implement python-style file buffering around a + simpler stream. + """ + + _DEFAULT_BUFSIZE = 8192 + + SEEK_SET = 0 + SEEK_CUR = 1 + SEEK_END = 2 + + FLAG_READ = 0x1 + FLAG_WRITE = 0x2 + FLAG_APPEND = 0x4 + FLAG_BINARY = 0x10 + FLAG_BUFFERED = 0x20 + FLAG_LINE_BUFFERED = 0x40 + FLAG_UNIVERSAL_NEWLINE = 0x80 + + def __init__(self): + self.newlines = None + self._flags = 0 + self._bufsize = self._DEFAULT_BUFSIZE + self._wbuffer = StringIO() + self._rbuffer = '' + self._at_trailing_cr = False + self._closed = False + # pos - position within the file, according to the user + # realpos - position according the OS + # (these may be different because we buffer for line reading) + self._pos = self._realpos = 0 + # size only matters for seekable files + self._size = 0 + + def __del__(self): + self.close() + + def __iter__(self): + """ + Returns an iterator that can be used to iterate over the lines in this + file. This iterator happens to return the file itself, since a file is + its own iterator. + + @raise ValueError: if the file is closed. + + @return: an interator. + @rtype: iterator + """ + if self._closed: + raise ValueError('I/O operation on closed file') + return self + + def close(self): + """ + Close the file. Future read and write operations will fail. + """ + self.flush() + self._closed = True + + def flush(self): + """ + Write out any data in the write buffer. This may do nothing if write + buffering is not turned on. + """ + self._write_all(self._wbuffer.getvalue()) + self._wbuffer = StringIO() + return + + def next(self): + """ + Returns the next line from the input, or raises L{StopIteration} when + EOF is hit. Unlike python file objects, it's okay to mix calls to + C{next} and L{readline}. + + @raise StopIteration: when the end of the file is reached. + + @return: a line read from the file. + @rtype: str + """ + line = self.readline() + if not line: + raise StopIteration + return line + + def read(self, size=None): + """ + Read at most C{size} bytes from the file (less if we hit the end of the + file first). If the C{size} argument is negative or omitted, read all + the remaining data in the file. + + @param size: maximum number of bytes to read + @type size: int + @return: data read from the file, or an empty string if EOF was + encountered immediately + @rtype: str + """ + if self._closed: + raise IOError('File is closed') + if not (self._flags & self.FLAG_READ): + raise IOError('File is not open for reading') + if (size is None) or (size < 0): + # go for broke + result = self._rbuffer + self._rbuffer = '' + self._pos += len(result) + while True: + try: + new_data = self._read(self._DEFAULT_BUFSIZE) + except EOFError: + new_data = None + if (new_data is None) or (len(new_data) == 0): + break + result += new_data + self._realpos += len(new_data) + self._pos += len(new_data) + return result + if size <= len(self._rbuffer): + result = self._rbuffer[:size] + self._rbuffer = self._rbuffer[size:] + self._pos += len(result) + return result + while len(self._rbuffer) < size: + read_size = size - len(self._rbuffer) + if self._flags & self.FLAG_BUFFERED: + read_size = max(self._bufsize, read_size) + try: + new_data = self._read(read_size) + except EOFError: + new_data = None + if (new_data is None) or (len(new_data) == 0): + break + self._rbuffer += new_data + self._realpos += len(new_data) + result = self._rbuffer[:size] + self._rbuffer = self._rbuffer[size:] + self._pos += len(result) + return result + + def readline(self, size=None): + """ + Read one entire line from the file. A trailing newline character is + kept in the string (but may be absent when a file ends with an + incomplete line). If the size argument is present and non-negative, it + is a maximum byte count (including the trailing newline) and an + incomplete line may be returned. An empty string is returned only when + EOF is encountered immediately. + + @note: Unlike stdio's C{fgets()}, the returned string contains null + characters (C{'\\0'}) if they occurred in the input. + + @param size: maximum length of returned string. + @type size: int + @return: next line of the file, or an empty string if the end of the + file has been reached. + @rtype: str + """ + # it's almost silly how complex this function is. + if self._closed: + raise IOError('File is closed') + if not (self._flags & self.FLAG_READ): + raise IOError('File not open for reading') + line = self._rbuffer + while True: + if self._at_trailing_cr and (self._flags & self.FLAG_UNIVERSAL_NEWLINE) and (len(line) > 0): + # edge case: the newline may be '\r\n' and we may have read + # only the first '\r' last time. + if line[0] == '\n': + line = line[1:] + self._record_newline('\r\n') + else: + self._record_newline('\r') + self._at_trailing_cr = False + # check size before looking for a linefeed, in case we already have + # enough. + if (size is not None) and (size >= 0): + if len(line) >= size: + # truncate line and return + self._rbuffer = line[size:] + line = line[:size] + self._pos += len(line) + return line + n = size - len(line) + else: + n = self._bufsize + if ('\n' in line) or ((self._flags & self.FLAG_UNIVERSAL_NEWLINE) and ('\r' in line)): + break + try: + new_data = self._read(n) + except EOFError: + new_data = None + if (new_data is None) or (len(new_data) == 0): + self._rbuffer = '' + self._pos += len(line) + return line + line += new_data + self._realpos += len(new_data) + # find the newline + pos = line.find('\n') + if self._flags & self.FLAG_UNIVERSAL_NEWLINE: + rpos = line.find('\r') + if (rpos >= 0) and ((rpos < pos) or (pos < 0)): + pos = rpos + xpos = pos + 1 + if (line[pos] == '\r') and (xpos < len(line)) and (line[xpos] == '\n'): + xpos += 1 + self._rbuffer = line[xpos:] + lf = line[pos:xpos] + line = line[:pos] + '\n' + if (len(self._rbuffer) == 0) and (lf == '\r'): + # we could read the line up to a '\r' and there could still be a + # '\n' following that we read next time. note that and eat it. + self._at_trailing_cr = True + else: + self._record_newline(lf) + self._pos += len(line) + return line + + def readlines(self, sizehint=None): + """ + Read all remaining lines using L{readline} and return them as a list. + If the optional C{sizehint} argument is present, instead of reading up + to EOF, whole lines totalling approximately sizehint bytes (possibly + after rounding up to an internal buffer size) are read. + + @param sizehint: desired maximum number of bytes to read. + @type sizehint: int + @return: list of lines read from the file. + @rtype: list + """ + lines = [] + bytes = 0 + while True: + line = self.readline() + if len(line) == 0: + break + lines.append(line) + bytes += len(line) + if (sizehint is not None) and (bytes >= sizehint): + break + return lines + + def seek(self, offset, whence=0): + """ + Set the file's current position, like stdio's C{fseek}. Not all file + objects support seeking. + + @note: If a file is opened in append mode (C{'a'} or C{'a+'}), any seek + operations will be undone at the next write (as the file position + will move back to the end of the file). + + @param offset: position to move to within the file, relative to + C{whence}. + @type offset: int + @param whence: type of movement: 0 = absolute; 1 = relative to the + current position; 2 = relative to the end of the file. + @type whence: int + + @raise IOError: if the file doesn't support random access. + """ + raise IOError('File does not support seeking.') + + def tell(self): + """ + Return the file's current position. This may not be accurate or + useful if the underlying file doesn't support random access, or was + opened in append mode. + + @return: file position (in bytes). + @rtype: int + """ + return self._pos + + def write(self, data): + """ + Write data to the file. If write buffering is on (C{bufsize} was + specified and non-zero), some or all of the data may not actually be + written yet. (Use L{flush} or L{close} to force buffered data to be + written out.) + + @param data: data to write. + @type data: str + """ + if self._closed: + raise IOError('File is closed') + if not (self._flags & self.FLAG_WRITE): + raise IOError('File not open for writing') + if not (self._flags & self.FLAG_BUFFERED): + self._write_all(data) + return + self._wbuffer.write(data) + if self._flags & self.FLAG_LINE_BUFFERED: + # only scan the new data for linefeed, to avoid wasting time. + last_newline_pos = data.rfind('\n') + if last_newline_pos >= 0: + wbuf = self._wbuffer.getvalue() + last_newline_pos += len(wbuf) - len(data) + self._write_all(wbuf[:last_newline_pos + 1]) + self._wbuffer = StringIO() + self._wbuffer.write(wbuf[last_newline_pos + 1:]) + return + # even if we're line buffering, if the buffer has grown past the + # buffer size, force a flush. + if self._wbuffer.tell() >= self._bufsize: + self.flush() + return + + def writelines(self, sequence): + """ + Write a sequence of strings to the file. The sequence can be any + iterable object producing strings, typically a list of strings. (The + name is intended to match L{readlines}; C{writelines} does not add line + separators.) + + @param sequence: an iterable sequence of strings. + @type sequence: sequence + """ + for line in sequence: + self.write(line) + return + + def xreadlines(self): + """ + Identical to C{iter(f)}. This is a deprecated file interface that + predates python iterator support. + + @return: an iterator. + @rtype: iterator + """ + return self + + + ### overrides... + + + def _read(self, size): + """ + I{(subclass override)} + Read data from the stream. Return C{None} or raise C{EOFError} to + indicate EOF. + """ + raise EOFError() + + def _write(self, data): + """ + I{(subclass override)} + Write data into the stream. + """ + raise IOError('write not implemented') + + def _get_size(self): + """ + I{(subclass override)} + Return the size of the file. This is called from within L{_set_mode} + if the file is opened in append mode, so the file position can be + tracked and L{seek} and L{tell} will work correctly. If the file is + a stream that can't be randomly accessed, you don't need to override + this method, + """ + return 0 + + + ### internals... + + + def _set_mode(self, mode='r', bufsize=-1): + """ + Subclasses call this method to initialize the BufferedFile. + """ + # set bufsize in any event, because it's used for readline(). + self._bufsize = self._DEFAULT_BUFSIZE + if bufsize < 0: + # do no buffering by default, because otherwise writes will get + # buffered in a way that will probably confuse people. + bufsize = 0 + if bufsize == 1: + # apparently, line buffering only affects writes. reads are only + # buffered if you call readline (directly or indirectly: iterating + # over a file will indirectly call readline). + self._flags |= self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED + elif bufsize > 1: + self._bufsize = bufsize + self._flags |= self.FLAG_BUFFERED + self._flags &= ~self.FLAG_LINE_BUFFERED + elif bufsize == 0: + # unbuffered + self._flags &= ~(self.FLAG_BUFFERED | self.FLAG_LINE_BUFFERED) + + if ('r' in mode) or ('+' in mode): + self._flags |= self.FLAG_READ + if ('w' in mode) or ('+' in mode): + self._flags |= self.FLAG_WRITE + if ('a' in mode): + self._flags |= self.FLAG_WRITE | self.FLAG_APPEND + self._size = self._get_size() + self._pos = self._realpos = self._size + if ('b' in mode): + self._flags |= self.FLAG_BINARY + if ('U' in mode): + self._flags |= self.FLAG_UNIVERSAL_NEWLINE + # built-in file objects have this attribute to store which kinds of + # line terminations they've seen: + # + self.newlines = None + + def _write_all(self, data): + # the underlying stream may be something that does partial writes (like + # a socket). + while len(data) > 0: + count = self._write(data) + data = data[count:] + if self._flags & self.FLAG_APPEND: + self._size += count + self._pos = self._realpos = self._size + else: + self._pos += count + self._realpos += count + return None + + def _record_newline(self, newline): + # silliness about tracking what kinds of newlines we've seen. + # i don't understand why it can be None, a string, or a tuple, instead + # of just always being a tuple, but we'll emulate that behavior anyway. + if not (self._flags & self.FLAG_UNIVERSAL_NEWLINE): + return + if self.newlines is None: + self.newlines = newline + elif (type(self.newlines) is str) and (self.newlines != newline): + self.newlines = (self.newlines, newline) + elif newline not in self.newlines: + self.newlines += (newline,) diff --git a/tools/migration/paramiko/hostkeys.py b/tools/migration/paramiko/hostkeys.py new file mode 100644 index 00000000000..9ceef43c1dd --- /dev/null +++ b/tools/migration/paramiko/hostkeys.py @@ -0,0 +1,316 @@ +# Copyright (C) 2006-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{HostKeys} +""" + +import base64 +from Crypto.Hash import SHA, HMAC +import UserDict + +from paramiko.common import * +from paramiko.dsskey import DSSKey +from paramiko.rsakey import RSAKey + + +class HostKeyEntry: + """ + Representation of a line in an OpenSSH-style "known hosts" file. + """ + + def __init__(self, hostnames=None, key=None): + self.valid = (hostnames is not None) and (key is not None) + self.hostnames = hostnames + self.key = key + + def from_line(cls, line): + """ + Parses the given line of text to find the names for the host, + the type of key, and the key data. The line is expected to be in the + format used by the openssh known_hosts file. + + Lines are expected to not have leading or trailing whitespace. + We don't bother to check for comments or empty lines. All of + that should be taken care of before sending the line to us. + + @param line: a line from an OpenSSH known_hosts file + @type line: str + """ + fields = line.split(' ') + if len(fields) < 3: + # Bad number of fields + return None + fields = fields[:3] + + names, keytype, key = fields + names = names.split(',') + + # Decide what kind of key we're looking at and create an object + # to hold it accordingly. + if keytype == 'ssh-rsa': + key = RSAKey(data=base64.decodestring(key)) + elif keytype == 'ssh-dss': + key = DSSKey(data=base64.decodestring(key)) + else: + return None + + return cls(names, key) + from_line = classmethod(from_line) + + def to_line(self): + """ + Returns a string in OpenSSH known_hosts file format, or None if + the object is not in a valid state. A trailing newline is + included. + """ + if self.valid: + return '%s %s %s\n' % (','.join(self.hostnames), self.key.get_name(), + self.key.get_base64()) + return None + + def __repr__(self): + return '' % (self.hostnames, self.key) + + +class HostKeys (UserDict.DictMixin): + """ + Representation of an openssh-style "known hosts" file. Host keys can be + read from one or more files, and then individual hosts can be looked up to + verify server keys during SSH negotiation. + + A HostKeys object can be treated like a dict; any dict lookup is equivalent + to calling L{lookup}. + + @since: 1.5.3 + """ + + def __init__(self, filename=None): + """ + Create a new HostKeys object, optionally loading keys from an openssh + style host-key file. + + @param filename: filename to load host keys from, or C{None} + @type filename: str + """ + # emulate a dict of { hostname: { keytype: PKey } } + self._entries = [] + if filename is not None: + self.load(filename) + + def add(self, hostname, keytype, key): + """ + Add a host key entry to the table. Any existing entry for a + C{(hostname, keytype)} pair will be replaced. + + @param hostname: the hostname (or IP) to add + @type hostname: str + @param keytype: key type (C{"ssh-rsa"} or C{"ssh-dss"}) + @type keytype: str + @param key: the key to add + @type key: L{PKey} + """ + for e in self._entries: + if (hostname in e.hostnames) and (e.key.get_name() == keytype): + e.key = key + return + self._entries.append(HostKeyEntry([hostname], key)) + + def load(self, filename): + """ + Read a file of known SSH host keys, in the format used by openssh. + This type of file unfortunately doesn't exist on Windows, but on + posix, it will usually be stored in + C{os.path.expanduser("~/.ssh/known_hosts")}. + + If this method is called multiple times, the host keys are merged, + not cleared. So multiple calls to C{load} will just call L{add}, + replacing any existing entries and adding new ones. + + @param filename: name of the file to read host keys from + @type filename: str + + @raise IOError: if there was an error reading the file + """ + f = open(filename, 'r') + for line in f: + line = line.strip() + if (len(line) == 0) or (line[0] == '#'): + continue + e = HostKeyEntry.from_line(line) + if e is not None: + self._entries.append(e) + f.close() + + def save(self, filename): + """ + Save host keys into a file, in the format used by openssh. The order of + keys in the file will be preserved when possible (if these keys were + loaded from a file originally). The single exception is that combined + lines will be split into individual key lines, which is arguably a bug. + + @param filename: name of the file to write + @type filename: str + + @raise IOError: if there was an error writing the file + + @since: 1.6.1 + """ + f = open(filename, 'w') + for e in self._entries: + line = e.to_line() + if line: + f.write(line) + f.close() + + def lookup(self, hostname): + """ + Find a hostkey entry for a given hostname or IP. If no entry is found, + C{None} is returned. Otherwise a dictionary of keytype to key is + returned. The keytype will be either C{"ssh-rsa"} or C{"ssh-dss"}. + + @param hostname: the hostname (or IP) to lookup + @type hostname: str + @return: keys associated with this host (or C{None}) + @rtype: dict(str, L{PKey}) + """ + class SubDict (UserDict.DictMixin): + def __init__(self, hostname, entries, hostkeys): + self._hostname = hostname + self._entries = entries + self._hostkeys = hostkeys + + def __getitem__(self, key): + for e in self._entries: + if e.key.get_name() == key: + return e.key + raise KeyError(key) + + def __setitem__(self, key, val): + for e in self._entries: + if e.key is None: + continue + if e.key.get_name() == key: + # replace + e.key = val + break + else: + # add a new one + e = HostKeyEntry([hostname], val) + self._entries.append(e) + self._hostkeys._entries.append(e) + + def keys(self): + return [e.key.get_name() for e in self._entries if e.key is not None] + + entries = [] + for e in self._entries: + for h in e.hostnames: + if (h.startswith('|1|') and (self.hash_host(hostname, h) == h)) or (h == hostname): + entries.append(e) + if len(entries) == 0: + return None + return SubDict(hostname, entries, self) + + def check(self, hostname, key): + """ + Return True if the given key is associated with the given hostname + in this dictionary. + + @param hostname: hostname (or IP) of the SSH server + @type hostname: str + @param key: the key to check + @type key: L{PKey} + @return: C{True} if the key is associated with the hostname; C{False} + if not + @rtype: bool + """ + k = self.lookup(hostname) + if k is None: + return False + host_key = k.get(key.get_name(), None) + if host_key is None: + return False + return str(host_key) == str(key) + + def clear(self): + """ + Remove all host keys from the dictionary. + """ + self._entries = [] + + def __getitem__(self, key): + ret = self.lookup(key) + if ret is None: + raise KeyError(key) + return ret + + def __setitem__(self, hostname, entry): + # don't use this please. + if len(entry) == 0: + self._entries.append(HostKeyEntry([hostname], None)) + return + for key_type in entry.keys(): + found = False + for e in self._entries: + if (hostname in e.hostnames) and (e.key.get_name() == key_type): + # replace + e.key = entry[key_type] + found = True + if not found: + self._entries.append(HostKeyEntry([hostname], entry[key_type])) + + def keys(self): + # python 2.4 sets would be nice here. + ret = [] + for e in self._entries: + for h in e.hostnames: + if h not in ret: + ret.append(h) + return ret + + def values(self): + ret = [] + for k in self.keys(): + ret.append(self.lookup(k)) + return ret + + def hash_host(hostname, salt=None): + """ + Return a "hashed" form of the hostname, as used by openssh when storing + hashed hostnames in the known_hosts file. + + @param hostname: the hostname to hash + @type hostname: str + @param salt: optional salt to use when hashing (must be 20 bytes long) + @type salt: str + @return: the hashed hostname + @rtype: str + """ + if salt is None: + salt = randpool.get_bytes(SHA.digest_size) + else: + if salt.startswith('|1|'): + salt = salt.split('|')[2] + salt = base64.decodestring(salt) + assert len(salt) == SHA.digest_size + hmac = HMAC.HMAC(salt, hostname, SHA).digest() + hostkey = '|1|%s|%s' % (base64.encodestring(salt), base64.encodestring(hmac)) + return hostkey.replace('\n', '') + hash_host = staticmethod(hash_host) + diff --git a/tools/migration/paramiko/kex_gex.py b/tools/migration/paramiko/kex_gex.py new file mode 100644 index 00000000000..c6be638e512 --- /dev/null +++ b/tools/migration/paramiko/kex_gex.py @@ -0,0 +1,244 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Variant on L{KexGroup1 } where the prime "p" and +generator "g" are provided by the server. A bit more work is required on the +client side, and a B{lot} more on the server side. +""" + +from Crypto.Hash import SHA +from Crypto.Util import number + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException + + +_MSG_KEXDH_GEX_REQUEST_OLD, _MSG_KEXDH_GEX_GROUP, _MSG_KEXDH_GEX_INIT, \ + _MSG_KEXDH_GEX_REPLY, _MSG_KEXDH_GEX_REQUEST = range(30, 35) + + +class KexGex (object): + + name = 'diffie-hellman-group-exchange-sha1' + min_bits = 1024 + max_bits = 8192 + preferred_bits = 2048 + + def __init__(self, transport): + self.transport = transport + self.p = None + self.q = None + self.g = None + self.x = None + self.e = None + self.f = None + self.old_style = False + + def start_kex(self, _test_old_style=False): + if self.transport.server_mode: + self.transport._expect_packet(_MSG_KEXDH_GEX_REQUEST, _MSG_KEXDH_GEX_REQUEST_OLD) + return + # request a bit range: we accept (min_bits) to (max_bits), but prefer + # (preferred_bits). according to the spec, we shouldn't pull the + # minimum up above 1024. + m = Message() + if _test_old_style: + # only used for unit tests: we shouldn't ever send this + m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST_OLD)) + m.add_int(self.preferred_bits) + self.old_style = True + else: + m.add_byte(chr(_MSG_KEXDH_GEX_REQUEST)) + m.add_int(self.min_bits) + m.add_int(self.preferred_bits) + m.add_int(self.max_bits) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_GEX_GROUP) + + def parse_next(self, ptype, m): + if ptype == _MSG_KEXDH_GEX_REQUEST: + return self._parse_kexdh_gex_request(m) + elif ptype == _MSG_KEXDH_GEX_GROUP: + return self._parse_kexdh_gex_group(m) + elif ptype == _MSG_KEXDH_GEX_INIT: + return self._parse_kexdh_gex_init(m) + elif ptype == _MSG_KEXDH_GEX_REPLY: + return self._parse_kexdh_gex_reply(m) + elif ptype == _MSG_KEXDH_GEX_REQUEST_OLD: + return self._parse_kexdh_gex_request_old(m) + raise SSHException('KexGex asked to handle packet type %d' % ptype) + + + ### internals... + + + def _generate_x(self): + # generate an "x" (1 < x < (p-1)/2). + q = (self.p - 1) // 2 + qnorm = util.deflate_long(q, 0) + qhbyte = ord(qnorm[0]) + bytes = len(qnorm) + qmask = 0xff + while not (qhbyte & 0x80): + qhbyte <<= 1 + qmask >>= 1 + while True: + self.transport.randpool.stir() + x_bytes = self.transport.randpool.get_bytes(bytes) + x_bytes = chr(ord(x_bytes[0]) & qmask) + x_bytes[1:] + x = util.inflate_long(x_bytes, 1) + if (x > 1) and (x < q): + break + self.x = x + + def _parse_kexdh_gex_request(self, m): + minbits = m.get_int() + preferredbits = m.get_int() + maxbits = m.get_int() + # smoosh the user's preferred size into our own limits + if preferredbits > self.max_bits: + preferredbits = self.max_bits + if preferredbits < self.min_bits: + preferredbits = self.min_bits + # fix min/max if they're inconsistent. technically, we could just pout + # and hang up, but there's no harm in giving them the benefit of the + # doubt and just picking a bitsize for them. + if minbits > preferredbits: + minbits = preferredbits + if maxbits < preferredbits: + maxbits = preferredbits + # now save a copy + self.min_bits = minbits + self.preferred_bits = preferredbits + self.max_bits = maxbits + # generate prime + pack = self.transport._get_modulus_pack() + if pack is None: + raise SSHException('Can\'t do server-side gex with no modulus pack') + self.transport._log(DEBUG, 'Picking p (%d <= %d <= %d bits)' % (minbits, preferredbits, maxbits)) + self.g, self.p = pack.get_modulus(minbits, preferredbits, maxbits) + m = Message() + m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) + m.add_mpint(self.p) + m.add_mpint(self.g) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_GEX_INIT) + + def _parse_kexdh_gex_request_old(self, m): + # same as above, but without min_bits or max_bits (used by older clients like putty) + self.preferred_bits = m.get_int() + # smoosh the user's preferred size into our own limits + if self.preferred_bits > self.max_bits: + self.preferred_bits = self.max_bits + if self.preferred_bits < self.min_bits: + self.preferred_bits = self.min_bits + # generate prime + pack = self.transport._get_modulus_pack() + if pack is None: + raise SSHException('Can\'t do server-side gex with no modulus pack') + self.transport._log(DEBUG, 'Picking p (~ %d bits)' % (self.preferred_bits,)) + self.g, self.p = pack.get_modulus(self.min_bits, self.preferred_bits, self.max_bits) + m = Message() + m.add_byte(chr(_MSG_KEXDH_GEX_GROUP)) + m.add_mpint(self.p) + m.add_mpint(self.g) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_GEX_INIT) + self.old_style = True + + def _parse_kexdh_gex_group(self, m): + self.p = m.get_mpint() + self.g = m.get_mpint() + # reject if p's bit length < 1024 or > 8192 + bitlen = util.bit_length(self.p) + if (bitlen < 1024) or (bitlen > 8192): + raise SSHException('Server-generated gex p (don\'t ask) is out of range (%d bits)' % bitlen) + self.transport._log(DEBUG, 'Got server p (%d bits)' % bitlen) + self._generate_x() + # now compute e = g^x mod p + self.e = pow(self.g, self.x, self.p) + m = Message() + m.add_byte(chr(_MSG_KEXDH_GEX_INIT)) + m.add_mpint(self.e) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_GEX_REPLY) + + def _parse_kexdh_gex_init(self, m): + self.e = m.get_mpint() + if (self.e < 1) or (self.e > self.p - 1): + raise SSHException('Client kex "e" is out of range') + self._generate_x() + self.f = pow(self.g, self.x, self.p) + K = pow(self.e, self.x, self.p) + key = str(self.transport.get_server_key()) + # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) + hm = Message() + hm.add(self.transport.remote_version, self.transport.local_version, + self.transport.remote_kex_init, self.transport.local_kex_init, + key) + if not self.old_style: + hm.add_int(self.min_bits) + hm.add_int(self.preferred_bits) + if not self.old_style: + hm.add_int(self.max_bits) + hm.add_mpint(self.p) + hm.add_mpint(self.g) + hm.add_mpint(self.e) + hm.add_mpint(self.f) + hm.add_mpint(K) + H = SHA.new(str(hm)).digest() + self.transport._set_K_H(K, H) + # sign it + sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) + # send reply + m = Message() + m.add_byte(chr(_MSG_KEXDH_GEX_REPLY)) + m.add_string(key) + m.add_mpint(self.f) + m.add_string(str(sig)) + self.transport._send_message(m) + self.transport._activate_outbound() + + def _parse_kexdh_gex_reply(self, m): + host_key = m.get_string() + self.f = m.get_mpint() + sig = m.get_string() + if (self.f < 1) or (self.f > self.p - 1): + raise SSHException('Server kex "f" is out of range') + K = pow(self.f, self.x, self.p) + # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || min || n || max || p || g || e || f || K) + hm = Message() + hm.add(self.transport.local_version, self.transport.remote_version, + self.transport.local_kex_init, self.transport.remote_kex_init, + host_key) + if not self.old_style: + hm.add_int(self.min_bits) + hm.add_int(self.preferred_bits) + if not self.old_style: + hm.add_int(self.max_bits) + hm.add_mpint(self.p) + hm.add_mpint(self.g) + hm.add_mpint(self.e) + hm.add_mpint(self.f) + hm.add_mpint(K) + self.transport._set_K_H(K, SHA.new(str(hm)).digest()) + self.transport._verify_key(host_key, sig) + self.transport._activate_outbound() diff --git a/tools/migration/paramiko/kex_group1.py b/tools/migration/paramiko/kex_group1.py new file mode 100644 index 00000000000..4228dd9d21b --- /dev/null +++ b/tools/migration/paramiko/kex_group1.py @@ -0,0 +1,136 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Standard SSH key exchange ("kex" if you wanna sound cool). Diffie-Hellman of +1024 bit key halves, using a known "p" prime and "g" generator. +""" + +from Crypto.Hash import SHA + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException + + +_MSG_KEXDH_INIT, _MSG_KEXDH_REPLY = range(30, 32) + +# draft-ietf-secsh-transport-09.txt, page 17 +P = 0xFFFFFFFFFFFFFFFFC90FDAA22168C234C4C6628B80DC1CD129024E088A67CC74020BBEA63B139B22514A08798E3404DDEF9519B3CD3A431B302B0A6DF25F14374FE1356D6D51C245E485B576625E7EC6F44C42E9A637ED6B0BFF5CB6F406B7EDEE386BFB5A899FA5AE9F24117C4B1FE649286651ECE65381FFFFFFFFFFFFFFFFL +G = 2 + + +class KexGroup1(object): + + name = 'diffie-hellman-group1-sha1' + + def __init__(self, transport): + self.transport = transport + self.x = 0L + self.e = 0L + self.f = 0L + + def start_kex(self): + self._generate_x() + if self.transport.server_mode: + # compute f = g^x mod p, but don't send it yet + self.f = pow(G, self.x, P) + self.transport._expect_packet(_MSG_KEXDH_INIT) + return + # compute e = g^x mod p (where g=2), and send it + self.e = pow(G, self.x, P) + m = Message() + m.add_byte(chr(_MSG_KEXDH_INIT)) + m.add_mpint(self.e) + self.transport._send_message(m) + self.transport._expect_packet(_MSG_KEXDH_REPLY) + + def parse_next(self, ptype, m): + if self.transport.server_mode and (ptype == _MSG_KEXDH_INIT): + return self._parse_kexdh_init(m) + elif not self.transport.server_mode and (ptype == _MSG_KEXDH_REPLY): + return self._parse_kexdh_reply(m) + raise SSHException('KexGroup1 asked to handle packet type %d' % ptype) + + + ### internals... + + + def _generate_x(self): + # generate an "x" (1 < x < q), where q is (p-1)/2. + # p is a 128-byte (1024-bit) number, where the first 64 bits are 1. + # therefore q can be approximated as a 2^1023. we drop the subset of + # potential x where the first 63 bits are 1, because some of those will be + # larger than q (but this is a tiny tiny subset of potential x). + while 1: + self.transport.randpool.stir() + x_bytes = self.transport.randpool.get_bytes(128) + x_bytes = chr(ord(x_bytes[0]) & 0x7f) + x_bytes[1:] + if (x_bytes[:8] != '\x7F\xFF\xFF\xFF\xFF\xFF\xFF\xFF') and \ + (x_bytes[:8] != '\x00\x00\x00\x00\x00\x00\x00\x00'): + break + self.x = util.inflate_long(x_bytes) + + def _parse_kexdh_reply(self, m): + # client mode + host_key = m.get_string() + self.f = m.get_mpint() + if (self.f < 1) or (self.f > P - 1): + raise SSHException('Server kex "f" is out of range') + sig = m.get_string() + K = pow(self.f, self.x, P) + # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K) + hm = Message() + hm.add(self.transport.local_version, self.transport.remote_version, + self.transport.local_kex_init, self.transport.remote_kex_init) + hm.add_string(host_key) + hm.add_mpint(self.e) + hm.add_mpint(self.f) + hm.add_mpint(K) + self.transport._set_K_H(K, SHA.new(str(hm)).digest()) + self.transport._verify_key(host_key, sig) + self.transport._activate_outbound() + + def _parse_kexdh_init(self, m): + # server mode + self.e = m.get_mpint() + if (self.e < 1) or (self.e > P - 1): + raise SSHException('Client kex "e" is out of range') + K = pow(self.e, self.x, P) + key = str(self.transport.get_server_key()) + # okay, build up the hash H of (V_C || V_S || I_C || I_S || K_S || e || f || K) + hm = Message() + hm.add(self.transport.remote_version, self.transport.local_version, + self.transport.remote_kex_init, self.transport.local_kex_init) + hm.add_string(key) + hm.add_mpint(self.e) + hm.add_mpint(self.f) + hm.add_mpint(K) + H = SHA.new(str(hm)).digest() + self.transport._set_K_H(K, H) + # sign it + sig = self.transport.get_server_key().sign_ssh_data(self.transport.randpool, H) + # send reply + m = Message() + m.add_byte(chr(_MSG_KEXDH_REPLY)) + m.add_string(key) + m.add_mpint(self.f) + m.add_string(str(sig)) + self.transport._send_message(m) + self.transport._activate_outbound() diff --git a/tools/migration/paramiko/logging22.py b/tools/migration/paramiko/logging22.py new file mode 100644 index 00000000000..ed1d8919815 --- /dev/null +++ b/tools/migration/paramiko/logging22.py @@ -0,0 +1,66 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Stub out logging on python < 2.3. +""" + + +DEBUG = 10 +INFO = 20 +WARNING = 30 +ERROR = 40 +CRITICAL = 50 + + +def getLogger(name): + return _logger + + +class logger (object): + def __init__(self): + self.handlers = [ ] + self.level = ERROR + + def setLevel(self, level): + self.level = level + + def addHandler(self, h): + self.handlers.append(h) + + def addFilter(self, filter): + pass + + def log(self, level, text): + if level >= self.level: + for h in self.handlers: + h.f.write(text + '\n') + h.f.flush() + +class StreamHandler (object): + def __init__(self, f): + self.f = f + + def setFormatter(self, f): + pass + +class Formatter (object): + def __init__(self, x, y): + pass + +_logger = logger() diff --git a/tools/migration/paramiko/message.py b/tools/migration/paramiko/message.py new file mode 100644 index 00000000000..366c43c96c9 --- /dev/null +++ b/tools/migration/paramiko/message.py @@ -0,0 +1,301 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Implementation of an SSH2 "message". +""" + +import struct +import cStringIO + +from paramiko import util + + +class Message (object): + """ + An SSH2 I{Message} is a stream of bytes that encodes some combination of + strings, integers, bools, and infinite-precision integers (known in python + as I{long}s). This class builds or breaks down such a byte stream. + + Normally you don't need to deal with anything this low-level, but it's + exposed for people implementing custom extensions, or features that + paramiko doesn't support yet. + """ + + def __init__(self, content=None): + """ + Create a new SSH2 Message. + + @param content: the byte stream to use as the Message content (passed + in only when decomposing a Message). + @type content: string + """ + if content != None: + self.packet = cStringIO.StringIO(content) + else: + self.packet = cStringIO.StringIO() + + def __str__(self): + """ + Return the byte stream content of this Message, as a string. + + @return: the contents of this Message. + @rtype: string + """ + return self.packet.getvalue() + + def __repr__(self): + """ + Returns a string representation of this object, for debugging. + + @rtype: string + """ + return 'paramiko.Message(' + repr(self.packet.getvalue()) + ')' + + def rewind(self): + """ + Rewind the message to the beginning as if no items had been parsed + out of it yet. + """ + self.packet.seek(0) + + def get_remainder(self): + """ + Return the bytes of this Message that haven't already been parsed and + returned. + + @return: a string of the bytes not parsed yet. + @rtype: string + """ + position = self.packet.tell() + remainder = self.packet.read() + self.packet.seek(position) + return remainder + + def get_so_far(self): + """ + Returns the bytes of this Message that have been parsed and returned. + The string passed into a Message's constructor can be regenerated by + concatenating C{get_so_far} and L{get_remainder}. + + @return: a string of the bytes parsed so far. + @rtype: string + """ + position = self.packet.tell() + self.rewind() + return self.packet.read(position) + + def get_bytes(self, n): + """ + Return the next C{n} bytes of the Message, without decomposing into + an int, string, etc. Just the raw bytes are returned. + + @return: a string of the next C{n} bytes of the Message, or a string + of C{n} zero bytes, if there aren't C{n} bytes remaining. + @rtype: string + """ + b = self.packet.read(n) + if len(b) < n: + return b + '\x00' * (n - len(b)) + return b + + def get_byte(self): + """ + Return the next byte of the Message, without decomposing it. This + is equivalent to L{get_bytes(1)}. + + @return: the next byte of the Message, or C{'\000'} if there aren't + any bytes remaining. + @rtype: string + """ + return self.get_bytes(1) + + def get_boolean(self): + """ + Fetch a boolean from the stream. + + @return: C{True} or C{False} (from the Message). + @rtype: bool + """ + b = self.get_bytes(1) + return b != '\x00' + + def get_int(self): + """ + Fetch an int from the stream. + + @return: a 32-bit unsigned integer. + @rtype: int + """ + return struct.unpack('>I', self.get_bytes(4))[0] + + def get_int64(self): + """ + Fetch a 64-bit int from the stream. + + @return: a 64-bit unsigned integer. + @rtype: long + """ + return struct.unpack('>Q', self.get_bytes(8))[0] + + def get_mpint(self): + """ + Fetch a long int (mpint) from the stream. + + @return: an arbitrary-length integer. + @rtype: long + """ + return util.inflate_long(self.get_string()) + + def get_string(self): + """ + Fetch a string from the stream. This could be a byte string and may + contain unprintable characters. (It's not unheard of for a string to + contain another byte-stream Message.) + + @return: a string. + @rtype: string + """ + return self.get_bytes(self.get_int()) + + def get_list(self): + """ + Fetch a list of strings from the stream. These are trivially encoded + as comma-separated values in a string. + + @return: a list of strings. + @rtype: list of strings + """ + return self.get_string().split(',') + + def add_bytes(self, b): + """ + Write bytes to the stream, without any formatting. + + @param b: bytes to add + @type b: str + """ + self.packet.write(b) + return self + + def add_byte(self, b): + """ + Write a single byte to the stream, without any formatting. + + @param b: byte to add + @type b: str + """ + self.packet.write(b) + return self + + def add_boolean(self, b): + """ + Add a boolean value to the stream. + + @param b: boolean value to add + @type b: bool + """ + if b: + self.add_byte('\x01') + else: + self.add_byte('\x00') + return self + + def add_int(self, n): + """ + Add an integer to the stream. + + @param n: integer to add + @type n: int + """ + self.packet.write(struct.pack('>I', n)) + return self + + def add_int64(self, n): + """ + Add a 64-bit int to the stream. + + @param n: long int to add + @type n: long + """ + self.packet.write(struct.pack('>Q', n)) + return self + + def add_mpint(self, z): + """ + Add a long int to the stream, encoded as an infinite-precision + integer. This method only works on positive numbers. + + @param z: long int to add + @type z: long + """ + self.add_string(util.deflate_long(z)) + return self + + def add_string(self, s): + """ + Add a string to the stream. + + @param s: string to add + @type s: str + """ + self.add_int(len(s)) + self.packet.write(s) + return self + + def add_list(self, l): + """ + Add a list of strings to the stream. They are encoded identically to + a single string of values separated by commas. (Yes, really, that's + how SSH2 does it.) + + @param l: list of strings to add + @type l: list(str) + """ + self.add_string(','.join(l)) + return self + + def _add(self, i): + if type(i) is str: + return self.add_string(i) + elif type(i) is int: + return self.add_int(i) + elif type(i) is long: + if i > 0xffffffffL: + return self.add_mpint(i) + else: + return self.add_int(i) + elif type(i) is bool: + return self.add_boolean(i) + elif type(i) is list: + return self.add_list(i) + else: + raise Exception('Unknown type') + + def add(self, *seq): + """ + Add a sequence of items to the stream. The values are encoded based + on their type: str, int, bool, list, or long. + + @param seq: the sequence of items + @type seq: sequence + + @bug: longs are encoded non-deterministically. Don't use this method. + """ + for item in seq: + self._add(item) diff --git a/tools/migration/paramiko/packet.py b/tools/migration/paramiko/packet.py new file mode 100644 index 00000000000..9072fbedd4a --- /dev/null +++ b/tools/migration/paramiko/packet.py @@ -0,0 +1,488 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Packetizer. +""" + +import errno +import select +import socket +import struct +import threading +import time + +from paramiko.common import * +from paramiko import util +from paramiko.ssh_exception import SSHException +from paramiko.message import Message + + +got_r_hmac = False +try: + import r_hmac + got_r_hmac = True +except ImportError: + pass +def compute_hmac(key, message, digest_class): + if got_r_hmac: + return r_hmac.HMAC(key, message, digest_class).digest() + from Crypto.Hash import HMAC + return HMAC.HMAC(key, message, digest_class).digest() + + +class NeedRekeyException (Exception): + pass + + +class Packetizer (object): + """ + Implementation of the base SSH packet protocol. + """ + + # READ the secsh RFC's before raising these values. if anything, + # they should probably be lower. + REKEY_PACKETS = pow(2, 30) + REKEY_BYTES = pow(2, 30) + + def __init__(self, socket): + self.__socket = socket + self.__logger = None + self.__closed = False + self.__dump_packets = False + self.__need_rekey = False + self.__init_count = 0 + self.__remainder = '' + + # used for noticing when to re-key: + self.__sent_bytes = 0 + self.__sent_packets = 0 + self.__received_bytes = 0 + self.__received_packets = 0 + self.__received_packets_overflow = 0 + + # current inbound/outbound ciphering: + self.__block_size_out = 8 + self.__block_size_in = 8 + self.__mac_size_out = 0 + self.__mac_size_in = 0 + self.__block_engine_out = None + self.__block_engine_in = None + self.__mac_engine_out = None + self.__mac_engine_in = None + self.__mac_key_out = '' + self.__mac_key_in = '' + self.__compress_engine_out = None + self.__compress_engine_in = None + self.__sequence_number_out = 0L + self.__sequence_number_in = 0L + + # lock around outbound writes (packet computation) + self.__write_lock = threading.RLock() + + # keepalives: + self.__keepalive_interval = 0 + self.__keepalive_last = time.time() + self.__keepalive_callback = None + + def set_log(self, log): + """ + Set the python log object to use for logging. + """ + self.__logger = log + + def set_outbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key): + """ + Switch outbound data cipher. + """ + self.__block_engine_out = block_engine + self.__block_size_out = block_size + self.__mac_engine_out = mac_engine + self.__mac_size_out = mac_size + self.__mac_key_out = mac_key + self.__sent_bytes = 0 + self.__sent_packets = 0 + # wait until the reset happens in both directions before clearing rekey flag + self.__init_count |= 1 + if self.__init_count == 3: + self.__init_count = 0 + self.__need_rekey = False + + def set_inbound_cipher(self, block_engine, block_size, mac_engine, mac_size, mac_key): + """ + Switch inbound data cipher. + """ + self.__block_engine_in = block_engine + self.__block_size_in = block_size + self.__mac_engine_in = mac_engine + self.__mac_size_in = mac_size + self.__mac_key_in = mac_key + self.__received_bytes = 0 + self.__received_packets = 0 + self.__received_packets_overflow = 0 + # wait until the reset happens in both directions before clearing rekey flag + self.__init_count |= 2 + if self.__init_count == 3: + self.__init_count = 0 + self.__need_rekey = False + + def set_outbound_compressor(self, compressor): + self.__compress_engine_out = compressor + + def set_inbound_compressor(self, compressor): + self.__compress_engine_in = compressor + + def close(self): + self.__closed = True + self.__socket.close() + + def set_hexdump(self, hexdump): + self.__dump_packets = hexdump + + def get_hexdump(self): + return self.__dump_packets + + def get_mac_size_in(self): + return self.__mac_size_in + + def get_mac_size_out(self): + return self.__mac_size_out + + def need_rekey(self): + """ + Returns C{True} if a new set of keys needs to be negotiated. This + will be triggered during a packet read or write, so it should be + checked after every read or write, or at least after every few. + + @return: C{True} if a new set of keys needs to be negotiated + """ + return self.__need_rekey + + def set_keepalive(self, interval, callback): + """ + Turn on/off the callback keepalive. If C{interval} seconds pass with + no data read from or written to the socket, the callback will be + executed and the timer will be reset. + """ + self.__keepalive_interval = interval + self.__keepalive_callback = callback + self.__keepalive_last = time.time() + + def read_all(self, n, check_rekey=False): + """ + Read as close to N bytes as possible, blocking as long as necessary. + + @param n: number of bytes to read + @type n: int + @return: the data read + @rtype: str + @raise EOFError: if the socket was closed before all the bytes could + be read + """ + out = '' + # handle over-reading from reading the banner line + if len(self.__remainder) > 0: + out = self.__remainder[:n] + self.__remainder = self.__remainder[n:] + n -= len(out) + if PY22: + return self._py22_read_all(n, out) + while n > 0: + got_timeout = False + try: + x = self.__socket.recv(n) + if len(x) == 0: + raise EOFError() + out += x + n -= len(x) + except socket.timeout: + got_timeout = True + except socket.error, e: + # on Linux, sometimes instead of socket.timeout, we get + # EAGAIN. this is a bug in recent (> 2.6.9) kernels but + # we need to work around it. + if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): + got_timeout = True + elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): + # syscall interrupted; try again + pass + elif self.__closed: + raise EOFError() + else: + raise + if got_timeout: + if self.__closed: + raise EOFError() + if check_rekey and (len(out) == 0) and self.__need_rekey: + raise NeedRekeyException() + self._check_keepalive() + return out + + def write_all(self, out): + self.__keepalive_last = time.time() + while len(out) > 0: + got_timeout = False + try: + n = self.__socket.send(out) + except socket.timeout: + got_timeout = True + except socket.error, e: + if (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EAGAIN): + got_timeout = True + elif (type(e.args) is tuple) and (len(e.args) > 0) and (e.args[0] == errno.EINTR): + # syscall interrupted; try again + pass + else: + n = -1 + except Exception: + # could be: (32, 'Broken pipe') + n = -1 + if got_timeout: + n = 0 + if self.__closed: + n = -1 + if n < 0: + raise EOFError() + if n == len(out): + break + out = out[n:] + return + + def readline(self, timeout): + """ + Read a line from the socket. We assume no data is pending after the + line, so it's okay to attempt large reads. + """ + buf = self.__remainder + while not '\n' in buf: + buf += self._read_timeout(timeout) + n = buf.index('\n') + self.__remainder = buf[n+1:] + buf = buf[:n] + if (len(buf) > 0) and (buf[-1] == '\r'): + buf = buf[:-1] + return buf + + def send_message(self, data): + """ + Write a block of data using the current cipher, as an SSH block. + """ + # encrypt this sucka + data = str(data) + cmd = ord(data[0]) + if cmd in MSG_NAMES: + cmd_name = MSG_NAMES[cmd] + else: + cmd_name = '$%x' % cmd + orig_len = len(data) + self.__write_lock.acquire() + try: + if self.__compress_engine_out is not None: + data = self.__compress_engine_out(data) + packet = self._build_packet(data) + if self.__dump_packets: + self._log(DEBUG, 'Write packet <%s>, length %d' % (cmd_name, orig_len)) + self._log(DEBUG, util.format_binary(packet, 'OUT: ')) + if self.__block_engine_out != None: + out = self.__block_engine_out.encrypt(packet) + else: + out = packet + # + mac + if self.__block_engine_out != None: + payload = struct.pack('>I', self.__sequence_number_out) + packet + out += compute_hmac(self.__mac_key_out, payload, self.__mac_engine_out)[:self.__mac_size_out] + self.__sequence_number_out = (self.__sequence_number_out + 1) & 0xffffffffL + self.write_all(out) + + self.__sent_bytes += len(out) + self.__sent_packets += 1 + if (self.__sent_packets % 100) == 0: + # stirring the randpool takes 30ms on my ibook!! + randpool.stir() + if ((self.__sent_packets >= self.REKEY_PACKETS) or (self.__sent_bytes >= self.REKEY_BYTES)) \ + and not self.__need_rekey: + # only ask once for rekeying + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes sent)' % + (self.__sent_packets, self.__sent_bytes)) + self.__received_packets_overflow = 0 + self._trigger_rekey() + finally: + self.__write_lock.release() + + def read_message(self): + """ + Only one thread should ever be in this function (no other locking is + done). + + @raise SSHException: if the packet is mangled + @raise NeedRekeyException: if the transport should rekey + """ + header = self.read_all(self.__block_size_in, check_rekey=True) + if self.__block_engine_in != None: + header = self.__block_engine_in.decrypt(header) + if self.__dump_packets: + self._log(DEBUG, util.format_binary(header, 'IN: ')); + packet_size = struct.unpack('>I', header[:4])[0] + # leftover contains decrypted bytes from the first block (after the length field) + leftover = header[4:] + if (packet_size - len(leftover)) % self.__block_size_in != 0: + raise SSHException('Invalid packet blocking') + buf = self.read_all(packet_size + self.__mac_size_in - len(leftover)) + packet = buf[:packet_size - len(leftover)] + post_packet = buf[packet_size - len(leftover):] + if self.__block_engine_in != None: + packet = self.__block_engine_in.decrypt(packet) + if self.__dump_packets: + self._log(DEBUG, util.format_binary(packet, 'IN: ')); + packet = leftover + packet + + if self.__mac_size_in > 0: + mac = post_packet[:self.__mac_size_in] + mac_payload = struct.pack('>II', self.__sequence_number_in, packet_size) + packet + my_mac = compute_hmac(self.__mac_key_in, mac_payload, self.__mac_engine_in)[:self.__mac_size_in] + if my_mac != mac: + raise SSHException('Mismatched MAC') + padding = ord(packet[0]) + payload = packet[1:packet_size - padding] + randpool.add_event() + if self.__dump_packets: + self._log(DEBUG, 'Got payload (%d bytes, %d padding)' % (packet_size, padding)) + + if self.__compress_engine_in is not None: + payload = self.__compress_engine_in(payload) + + msg = Message(payload[1:]) + msg.seqno = self.__sequence_number_in + self.__sequence_number_in = (self.__sequence_number_in + 1) & 0xffffffffL + + # check for rekey + self.__received_bytes += packet_size + self.__mac_size_in + 4 + self.__received_packets += 1 + if self.__need_rekey: + # we've asked to rekey -- give them 20 packets to comply before + # dropping the connection + self.__received_packets_overflow += 1 + if self.__received_packets_overflow >= 20: + raise SSHException('Remote transport is ignoring rekey requests') + elif (self.__received_packets >= self.REKEY_PACKETS) or \ + (self.__received_bytes >= self.REKEY_BYTES): + # only ask once for rekeying + self._log(DEBUG, 'Rekeying (hit %d packets, %d bytes received)' % + (self.__received_packets, self.__received_bytes)) + self.__received_packets_overflow = 0 + self._trigger_rekey() + + cmd = ord(payload[0]) + if cmd in MSG_NAMES: + cmd_name = MSG_NAMES[cmd] + else: + cmd_name = '$%x' % cmd + if self.__dump_packets: + self._log(DEBUG, 'Read packet <%s>, length %d' % (cmd_name, len(payload))) + return cmd, msg + + + ########## protected + + + def _log(self, level, msg): + if self.__logger is None: + return + if issubclass(type(msg), list): + for m in msg: + self.__logger.log(level, m) + else: + self.__logger.log(level, msg) + + def _check_keepalive(self): + if (not self.__keepalive_interval) or (not self.__block_engine_out) or \ + self.__need_rekey: + # wait till we're encrypting, and not in the middle of rekeying + return + now = time.time() + if now > self.__keepalive_last + self.__keepalive_interval: + self.__keepalive_callback() + self.__keepalive_last = now + + def _py22_read_all(self, n, out): + while n > 0: + r, w, e = select.select([self.__socket], [], [], 0.1) + if self.__socket not in r: + if self.__closed: + raise EOFError() + self._check_keepalive() + else: + x = self.__socket.recv(n) + if len(x) == 0: + raise EOFError() + out += x + n -= len(x) + return out + + def _py22_read_timeout(self, timeout): + start = time.time() + while True: + r, w, e = select.select([self.__socket], [], [], 0.1) + if self.__socket in r: + x = self.__socket.recv(1) + if len(x) == 0: + raise EOFError() + break + if self.__closed: + raise EOFError() + now = time.time() + if now - start >= timeout: + raise socket.timeout() + return x + + def _read_timeout(self, timeout): + if PY22: + return self._py22_read_timeout(timeout) + start = time.time() + while True: + try: + x = self.__socket.recv(128) + if len(x) == 0: + raise EOFError() + break + except socket.timeout: + pass + if self.__closed: + raise EOFError() + now = time.time() + if now - start >= timeout: + raise socket.timeout() + return x + + def _build_packet(self, payload): + # pad up at least 4 bytes, to nearest block-size (usually 8) + bsize = self.__block_size_out + padding = 3 + bsize - ((len(payload) + 8) % bsize) + packet = struct.pack('>IB', len(payload) + padding + 1, padding) + packet += payload + if self.__block_engine_out is not None: + packet += randpool.get_bytes(padding) + else: + # cute trick i caught openssh doing: if we're not encrypting, + # don't waste random bytes for the padding + packet += (chr(0) * padding) + return packet + + def _trigger_rekey(self): + # outside code should check for this flag + self.__need_rekey = True diff --git a/tools/migration/paramiko/pipe.py b/tools/migration/paramiko/pipe.py new file mode 100644 index 00000000000..37191ef9f37 --- /dev/null +++ b/tools/migration/paramiko/pipe.py @@ -0,0 +1,147 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Abstraction of a one-way pipe where the read end can be used in select(). +Normally this is trivial, but Windows makes it nearly impossible. + +The pipe acts like an Event, which can be set or cleared. When set, the pipe +will trigger as readable in select(). +""" + +import sys +import os +import socket + + +def make_pipe (): + if sys.platform[:3] != 'win': + p = PosixPipe() + else: + p = WindowsPipe() + return p + + +class PosixPipe (object): + def __init__ (self): + self._rfd, self._wfd = os.pipe() + self._set = False + self._forever = False + self._closed = False + + def close (self): + os.close(self._rfd) + os.close(self._wfd) + # used for unit tests: + self._closed = True + + def fileno (self): + return self._rfd + + def clear (self): + if not self._set or self._forever: + return + os.read(self._rfd, 1) + self._set = False + + def set (self): + if self._set or self._closed: + return + self._set = True + os.write(self._wfd, '*') + + def set_forever (self): + self._forever = True + self.set() + + +class WindowsPipe (object): + """ + On Windows, only an OS-level "WinSock" may be used in select(), but reads + and writes must be to the actual socket object. + """ + def __init__ (self): + serv = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + serv.bind(('127.0.0.1', 0)) + serv.listen(1) + + # need to save sockets in _rsock/_wsock so they don't get closed + self._rsock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self._rsock.connect(('127.0.0.1', serv.getsockname()[1])) + + self._wsock, addr = serv.accept() + serv.close() + self._set = False + self._forever = False + self._closed = False + + def close (self): + self._rsock.close() + self._wsock.close() + # used for unit tests: + self._closed = True + + def fileno (self): + return self._rsock.fileno() + + def clear (self): + if not self._set or self._forever: + return + self._rsock.recv(1) + self._set = False + + def set (self): + if self._set or self._closed: + return + self._set = True + self._wsock.send('*') + + def set_forever (self): + self._forever = True + self.set() + + +class OrPipe (object): + def __init__(self, pipe): + self._set = False + self._partner = None + self._pipe = pipe + + def set(self): + self._set = True + if not self._partner._set: + self._pipe.set() + + def clear(self): + self._set = False + if not self._partner._set: + self._pipe.clear() + + +def make_or_pipe(pipe): + """ + wraps a pipe into two pipe-like objects which are "or"d together to + affect the real pipe. if either returned pipe is set, the wrapped pipe + is set. when both are cleared, the wrapped pipe is cleared. + """ + p1 = OrPipe(pipe) + p2 = OrPipe(pipe) + p1._partner = p2 + p2._partner = p1 + return p1, p2 + diff --git a/tools/migration/paramiko/pkey.py b/tools/migration/paramiko/pkey.py new file mode 100644 index 00000000000..bb8c83c655d --- /dev/null +++ b/tools/migration/paramiko/pkey.py @@ -0,0 +1,380 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Common API for all public keys. +""" + +import base64 +from binascii import hexlify, unhexlify +import os + +from Crypto.Hash import MD5 +from Crypto.Cipher import DES3 + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ssh_exception import SSHException, PasswordRequiredException + + +class PKey (object): + """ + Base class for public keys. + """ + + # known encryption types for private key files: + _CIPHER_TABLE = { + 'DES-EDE3-CBC': { 'cipher': DES3, 'keysize': 24, 'blocksize': 8, 'mode': DES3.MODE_CBC } + } + + + def __init__(self, msg=None, data=None): + """ + Create a new instance of this public key type. If C{msg} is given, + the key's public part(s) will be filled in from the message. If + C{data} is given, the key's public part(s) will be filled in from + the string. + + @param msg: an optional SSH L{Message} containing a public key of this + type. + @type msg: L{Message} + @param data: an optional string containing a public key of this type + @type data: str + + @raise SSHException: if a key cannot be created from the C{data} or + C{msg} given, or no key was passed in. + """ + pass + + def __str__(self): + """ + Return a string of an SSH L{Message} made up of the public part(s) of + this key. This string is suitable for passing to L{__init__} to + re-create the key object later. + + @return: string representation of an SSH key message. + @rtype: str + """ + return '' + + def __cmp__(self, other): + """ + Compare this key to another. Returns 0 if this key is equivalent to + the given key, or non-0 if they are different. Only the public parts + of the key are compared, so a public key will compare equal to its + corresponding private key. + + @param other: key to compare to. + @type other: L{PKey} + @return: 0 if the two keys are equivalent, non-0 otherwise. + @rtype: int + """ + hs = hash(self) + ho = hash(other) + if hs != ho: + return cmp(hs, ho) + return cmp(str(self), str(other)) + + def get_name(self): + """ + Return the name of this private key implementation. + + @return: name of this private key type, in SSH terminology (for + example, C{"ssh-rsa"}). + @rtype: str + """ + return '' + + def get_bits(self): + """ + Return the number of significant bits in this key. This is useful + for judging the relative security of a key. + + @return: bits in the key. + @rtype: int + """ + return 0 + + def can_sign(self): + """ + Return C{True} if this key has the private part necessary for signing + data. + + @return: C{True} if this is a private key. + @rtype: bool + """ + return False + + def get_fingerprint(self): + """ + Return an MD5 fingerprint of the public part of this key. Nothing + secret is revealed. + + @return: a 16-byte string (binary) of the MD5 fingerprint, in SSH + format. + @rtype: str + """ + return MD5.new(str(self)).digest() + + def get_base64(self): + """ + Return a base64 string containing the public part of this key. Nothing + secret is revealed. This format is compatible with that used to store + public key files or recognized host keys. + + @return: a base64 string containing the public part of the key. + @rtype: str + """ + return base64.encodestring(str(self)).replace('\n', '') + + def sign_ssh_data(self, randpool, data): + """ + Sign a blob of data with this private key, and return a L{Message} + representing an SSH signature message. + + @param randpool: a secure random number generator. + @type randpool: L{Crypto.Util.randpool.RandomPool} + @param data: the data to sign. + @type data: str + @return: an SSH signature message. + @rtype: L{Message} + """ + return '' + + def verify_ssh_sig(self, data, msg): + """ + Given a blob of data, and an SSH message representing a signature of + that data, verify that it was signed with this key. + + @param data: the data that was signed. + @type data: str + @param msg: an SSH signature message + @type msg: L{Message} + @return: C{True} if the signature verifies correctly; C{False} + otherwise. + @rtype: boolean + """ + return False + + def from_private_key_file(cls, filename, password=None): + """ + Create a key object by reading a private key file. If the private + key is encrypted and C{password} is not C{None}, the given password + will be used to decrypt the key (otherwise L{PasswordRequiredException} + is thrown). Through the magic of python, this factory method will + exist in all subclasses of PKey (such as L{RSAKey} or L{DSSKey}), but + is useless on the abstract PKey class. + + @param filename: name of the file to read + @type filename: str + @param password: an optional password to use to decrypt the key file, + if it's encrypted + @type password: str + @return: a new key object based on the given private key + @rtype: L{PKey} + + @raise IOError: if there was an error reading the file + @raise PasswordRequiredException: if the private key file is + encrypted, and C{password} is C{None} + @raise SSHException: if the key file is invalid + """ + key = cls(filename=filename, password=password) + return key + from_private_key_file = classmethod(from_private_key_file) + + def from_private_key(cls, file_obj, password=None): + """ + Create a key object by reading a private key from a file (or file-like) + object. If the private key is encrypted and C{password} is not C{None}, + the given password will be used to decrypt the key (otherwise + L{PasswordRequiredException} is thrown). + + @param file_obj: the file to read from + @type file_obj: file + @param password: an optional password to use to decrypt the key, if it's + encrypted + @type password: str + @return: a new key object based on the given private key + @rtype: L{PKey} + + @raise IOError: if there was an error reading the key + @raise PasswordRequiredException: if the private key file is encrypted, + and C{password} is C{None} + @raise SSHException: if the key file is invalid + """ + key = cls(file_obj=file_obj, password=password) + return key + from_private_key = classmethod(from_private_key) + + def write_private_key_file(self, filename, password=None): + """ + Write private key contents into a file. If the password is not + C{None}, the key is encrypted before writing. + + @param filename: name of the file to write + @type filename: str + @param password: an optional password to use to encrypt the key file + @type password: str + + @raise IOError: if there was an error writing the file + @raise SSHException: if the key is invalid + """ + raise Exception('Not implemented in PKey') + + def write_private_key(self, file_obj, password=None): + """ + Write private key contents into a file (or file-like) object. If the + password is not C{None}, the key is encrypted before writing. + + @param file_obj: the file object to write into + @type file_obj: file + @param password: an optional password to use to encrypt the key + @type password: str + + @raise IOError: if there was an error writing to the file + @raise SSHException: if the key is invalid + """ + raise Exception('Not implemented in PKey') + + def _read_private_key_file(self, tag, filename, password=None): + """ + Read an SSH2-format private key file, looking for a string of the type + C{"BEGIN xxx PRIVATE KEY"} for some C{xxx}, base64-decode the text we + find, and return it as a string. If the private key is encrypted and + C{password} is not C{None}, the given password will be used to decrypt + the key (otherwise L{PasswordRequiredException} is thrown). + + @param tag: C{"RSA"} or C{"DSA"}, the tag used to mark the data block. + @type tag: str + @param filename: name of the file to read. + @type filename: str + @param password: an optional password to use to decrypt the key file, + if it's encrypted. + @type password: str + @return: data blob that makes up the private key. + @rtype: str + + @raise IOError: if there was an error reading the file. + @raise PasswordRequiredException: if the private key file is + encrypted, and C{password} is C{None}. + @raise SSHException: if the key file is invalid. + """ + f = open(filename, 'r') + data = self._read_private_key(tag, f, password) + f.close() + return data + + def _read_private_key(self, tag, f, password=None): + lines = f.readlines() + start = 0 + while (start < len(lines)) and (lines[start].strip() != '-----BEGIN ' + tag + ' PRIVATE KEY-----'): + start += 1 + if start >= len(lines): + raise SSHException('not a valid ' + tag + ' private key file') + # parse any headers first + headers = {} + start += 1 + while start < len(lines): + l = lines[start].split(': ') + if len(l) == 1: + break + headers[l[0].lower()] = l[1].strip() + start += 1 + # find end + end = start + while (lines[end].strip() != '-----END ' + tag + ' PRIVATE KEY-----') and (end < len(lines)): + end += 1 + # if we trudged to the end of the file, just try to cope. + try: + data = base64.decodestring(''.join(lines[start:end])) + except base64.binascii.Error, e: + raise SSHException('base64 decoding error: ' + str(e)) + if 'proc-type' not in headers: + # unencryped: done + return data + # encrypted keyfile: will need a password + if headers['proc-type'] != '4,ENCRYPTED': + raise SSHException('Unknown private key structure "%s"' % headers['proc-type']) + try: + encryption_type, saltstr = headers['dek-info'].split(',') + except: + raise SSHException('Can\'t parse DEK-info in private key file') + if encryption_type not in self._CIPHER_TABLE: + raise SSHException('Unknown private key cipher "%s"' % encryption_type) + # if no password was passed in, raise an exception pointing out that we need one + if password is None: + raise PasswordRequiredException('Private key file is encrypted') + cipher = self._CIPHER_TABLE[encryption_type]['cipher'] + keysize = self._CIPHER_TABLE[encryption_type]['keysize'] + mode = self._CIPHER_TABLE[encryption_type]['mode'] + salt = unhexlify(saltstr) + key = util.generate_key_bytes(MD5, salt, password, keysize) + return cipher.new(key, mode, salt).decrypt(data) + + def _write_private_key_file(self, tag, filename, data, password=None): + """ + Write an SSH2-format private key file in a form that can be read by + paramiko or openssh. If no password is given, the key is written in + a trivially-encoded format (base64) which is completely insecure. If + a password is given, DES-EDE3-CBC is used. + + @param tag: C{"RSA"} or C{"DSA"}, the tag used to mark the data block. + @type tag: str + @param filename: name of the file to write. + @type filename: str + @param data: data blob that makes up the private key. + @type data: str + @param password: an optional password to use to encrypt the file. + @type password: str + + @raise IOError: if there was an error writing the file. + """ + f = open(filename, 'w', 0600) + # grrr... the mode doesn't always take hold + os.chmod(filename, 0600) + self._write_private_key(tag, f, data, password) + f.close() + + def _write_private_key(self, tag, f, data, password=None): + f.write('-----BEGIN %s PRIVATE KEY-----\n' % tag) + if password is not None: + # since we only support one cipher here, use it + cipher_name = self._CIPHER_TABLE.keys()[0] + cipher = self._CIPHER_TABLE[cipher_name]['cipher'] + keysize = self._CIPHER_TABLE[cipher_name]['keysize'] + blocksize = self._CIPHER_TABLE[cipher_name]['blocksize'] + mode = self._CIPHER_TABLE[cipher_name]['mode'] + salt = randpool.get_bytes(8) + key = util.generate_key_bytes(MD5, salt, password, keysize) + if len(data) % blocksize != 0: + n = blocksize - len(data) % blocksize + #data += randpool.get_bytes(n) + # that would make more sense ^, but it confuses openssh. + data += '\0' * n + data = cipher.new(key, mode, salt).encrypt(data) + f.write('Proc-Type: 4,ENCRYPTED\n') + f.write('DEK-Info: %s,%s\n' % (cipher_name, hexlify(salt).upper())) + f.write('\n') + s = base64.encodestring(data) + # re-wrap to 64-char lines + s = ''.join(s.split('\n')) + s = '\n'.join([s[i : i+64] for i in range(0, len(s), 64)]) + f.write(s) + f.write('\n') + f.write('-----END %s PRIVATE KEY-----\n' % tag) diff --git a/tools/migration/paramiko/primes.py b/tools/migration/paramiko/primes.py new file mode 100644 index 00000000000..1cf79058215 --- /dev/null +++ b/tools/migration/paramiko/primes.py @@ -0,0 +1,151 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Utility functions for dealing with primes. +""" + +from Crypto.Util import number + +from paramiko import util +from paramiko.ssh_exception import SSHException + + +def _generate_prime(bits, randpool): + "primtive attempt at prime generation" + hbyte_mask = pow(2, bits % 8) - 1 + while True: + # loop catches the case where we increment n into a higher bit-range + x = randpool.get_bytes((bits+7) // 8) + if hbyte_mask > 0: + x = chr(ord(x[0]) & hbyte_mask) + x[1:] + n = util.inflate_long(x, 1) + n |= 1 + n |= (1 << (bits - 1)) + while not number.isPrime(n): + n += 2 + if util.bit_length(n) == bits: + break + return n + +def _roll_random(rpool, n): + "returns a random # from 0 to N-1" + bits = util.bit_length(n-1) + bytes = (bits + 7) // 8 + hbyte_mask = pow(2, bits % 8) - 1 + + # so here's the plan: + # we fetch as many random bits as we'd need to fit N-1, and if the + # generated number is >= N, we try again. in the worst case (N-1 is a + # power of 2), we have slightly better than 50% odds of getting one that + # fits, so i can't guarantee that this loop will ever finish, but the odds + # of it looping forever should be infinitesimal. + while True: + x = rpool.get_bytes(bytes) + if hbyte_mask > 0: + x = chr(ord(x[0]) & hbyte_mask) + x[1:] + num = util.inflate_long(x, 1) + if num < n: + break + return num + + +class ModulusPack (object): + """ + convenience object for holding the contents of the /etc/ssh/moduli file, + on systems that have such a file. + """ + + def __init__(self, rpool): + # pack is a hash of: bits -> [ (generator, modulus) ... ] + self.pack = {} + self.discarded = [] + self.randpool = rpool + + def _parse_modulus(self, line): + timestamp, mod_type, tests, tries, size, generator, modulus = line.split() + mod_type = int(mod_type) + tests = int(tests) + tries = int(tries) + size = int(size) + generator = int(generator) + modulus = long(modulus, 16) + + # weed out primes that aren't at least: + # type 2 (meets basic structural requirements) + # test 4 (more than just a small-prime sieve) + # tries < 100 if test & 4 (at least 100 tries of miller-rabin) + if (mod_type < 2) or (tests < 4) or ((tests & 4) and (tests < 8) and (tries < 100)): + self.discarded.append((modulus, 'does not meet basic requirements')) + return + if generator == 0: + generator = 2 + + # there's a bug in the ssh "moduli" file (yeah, i know: shock! dismay! + # call cnn!) where it understates the bit lengths of these primes by 1. + # this is okay. + bl = util.bit_length(modulus) + if (bl != size) and (bl != size + 1): + self.discarded.append((modulus, 'incorrectly reported bit length %d' % size)) + return + if bl not in self.pack: + self.pack[bl] = [] + self.pack[bl].append((generator, modulus)) + + def read_file(self, filename): + """ + @raise IOError: passed from any file operations that fail. + """ + self.pack = {} + f = open(filename, 'r') + for line in f: + line = line.strip() + if (len(line) == 0) or (line[0] == '#'): + continue + try: + self._parse_modulus(line) + except: + continue + f.close() + + def get_modulus(self, min, prefer, max): + bitsizes = self.pack.keys() + bitsizes.sort() + if len(bitsizes) == 0: + raise SSHException('no moduli available') + good = -1 + # find nearest bitsize >= preferred + for b in bitsizes: + if (b >= prefer) and (b < max) and ((b < good) or (good == -1)): + good = b + # if that failed, find greatest bitsize >= min + if good == -1: + for b in bitsizes: + if (b >= min) and (b < max) and (b > good): + good = b + if good == -1: + # their entire (min, max) range has no intersection with our range. + # if their range is below ours, pick the smallest. otherwise pick + # the largest. it'll be out of their range requirement either way, + # but we'll be sending them the closest one we have. + good = bitsizes[0] + if min > good: + good = bitsizes[-1] + # now pick a random modulus of this bitsize + n = _roll_random(self.randpool, len(self.pack[good])) + return self.pack[good][n] diff --git a/tools/migration/paramiko/resource.py b/tools/migration/paramiko/resource.py new file mode 100644 index 00000000000..0d5c82fa3ec --- /dev/null +++ b/tools/migration/paramiko/resource.py @@ -0,0 +1,72 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Resource manager. +""" + +import weakref + + +class ResourceManager (object): + """ + A registry of objects and resources that should be closed when those + objects are deleted. + + This is meant to be a safer alternative to python's C{__del__} method, + which can cause reference cycles to never be collected. Objects registered + with the ResourceManager can be collected but still free resources when + they die. + + Resources are registered using L{register}, and when an object is garbage + collected, each registered resource is closed by having its C{close()} + method called. Multiple resources may be registered per object, but a + resource will only be closed once, even if multiple objects register it. + (The last object to register it wins.) + """ + + def __init__(self): + self._table = {} + + def register(self, obj, resource): + """ + Register a resource to be closed with an object is collected. + + When the given C{obj} is garbage-collected by the python interpreter, + the C{resource} will be closed by having its C{close()} method called. + Any exceptions are ignored. + + @param obj: the object to track + @type obj: object + @param resource: the resource to close when the object is collected + @type resource: object + """ + def callback(ref): + try: + resource.close() + except: + pass + del self._table[id(resource)] + + # keep the weakref in a table so it sticks around long enough to get + # its callback called. :) + self._table[id(resource)] = weakref.ref(obj, callback) + + +# singleton +ResourceManager = ResourceManager() diff --git a/tools/migration/paramiko/rng.py b/tools/migration/paramiko/rng.py new file mode 100644 index 00000000000..46329d1edf4 --- /dev/null +++ b/tools/migration/paramiko/rng.py @@ -0,0 +1,112 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +# Copyright (C) 2008 Dwayne C. Litzenberger +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +import sys +import threading +from Crypto.Util.randpool import RandomPool as _RandomPool + +try: + import platform +except ImportError: + platform = None # Not available using Python 2.2 + +def _strxor(a, b): + assert len(a) == len(b) + return "".join(map(lambda x, y: chr(ord(x) ^ ord(y)), a, b)) + +## +## Find a strong random entropy source, depending on the detected platform. +## WARNING TO DEVELOPERS: This will fail on some systems, but do NOT use +## Crypto.Util.randpool.RandomPool as a fall-back. RandomPool will happily run +## with very little entropy, thus _silently_ defeating any security that +## Paramiko attempts to provide. (This is current as of PyCrypto 2.0.1). +## See http://www.lag.net/pipermail/paramiko/2008-January/000599.html +## and http://www.lag.net/pipermail/paramiko/2008-April/000678.html +## + +if ((platform is not None and platform.system().lower() == 'windows') or + sys.platform == 'win32'): + # MS Windows + from paramiko import rng_win32 + rng_device = rng_win32.open_rng_device() +else: + # Assume POSIX (any system where /dev/urandom exists) + from paramiko import rng_posix + rng_device = rng_posix.open_rng_device() + + +class StrongLockingRandomPool(object): + """Wrapper around RandomPool guaranteeing strong random numbers. + + Crypto.Util.randpool.RandomPool will silently operate even if it is seeded + with little or no entropy, and it provides no prediction resistance if its + state is ever compromised throughout its runtime. It is also not thread-safe. + + This wrapper augments RandomPool by XORing its output with random bits from + the operating system, and by controlling access to the underlying + RandomPool using an exclusive lock. + """ + + def __init__(self, instance=None): + if instance is None: + instance = _RandomPool() + self.randpool = instance + self.randpool_lock = threading.Lock() + self.entropy = rng_device + + # Stir 256 bits of entropy from the RNG device into the RandomPool. + self.randpool.stir(self.entropy.read(32)) + self.entropy.randomize() + + def stir(self, s=''): + self.randpool_lock.acquire() + try: + self.randpool.stir(s) + finally: + self.randpool_lock.release() + self.entropy.randomize() + + def randomize(self, N=0): + self.randpool_lock.acquire() + try: + self.randpool.randomize(N) + finally: + self.randpool_lock.release() + self.entropy.randomize() + + def add_event(self, s=''): + self.randpool_lock.acquire() + try: + self.randpool.add_event(s) + finally: + self.randpool_lock.release() + + def get_bytes(self, N): + self.randpool_lock.acquire() + try: + randpool_data = self.randpool.get_bytes(N) + finally: + self.randpool_lock.release() + entropy_data = self.entropy.read(N) + result = _strxor(randpool_data, entropy_data) + assert len(randpool_data) == N and len(entropy_data) == N and len(result) == N + return result + +# vim:set ts=4 sw=4 sts=4 expandtab: diff --git a/tools/migration/paramiko/rng_posix.py b/tools/migration/paramiko/rng_posix.py new file mode 100644 index 00000000000..c4c969111a5 --- /dev/null +++ b/tools/migration/paramiko/rng_posix.py @@ -0,0 +1,97 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +# Copyright (C) 2008 Dwayne C. Litzenberger +# Copyright (C) 2008 Open Systems Canada Limited +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +import os +import stat + +class error(Exception): + pass + +class _RNG(object): + def __init__(self, file): + self.file = file + + def read(self, bytes): + return self.file.read(bytes) + + def close(self): + return self.file.close() + + def randomize(self): + return + +def open_rng_device(device_path=None): + """Open /dev/urandom and perform some sanity checks.""" + + f = None + g = None + + if device_path is None: + device_path = "/dev/urandom" + + try: + # Try to open /dev/urandom now so that paramiko will be able to access + # it even if os.chroot() is invoked later. + try: + f = open(device_path, "rb", 0) + except EnvironmentError: + raise error("Unable to open /dev/urandom") + + # Open a second file descriptor for sanity checking later. + try: + g = open(device_path, "rb", 0) + except EnvironmentError: + raise error("Unable to open /dev/urandom") + + # Check that /dev/urandom is a character special device, not a regular file. + st = os.fstat(f.fileno()) # f + if stat.S_ISREG(st.st_mode) or not stat.S_ISCHR(st.st_mode): + raise error("/dev/urandom is not a character special device") + + st = os.fstat(g.fileno()) # g + if stat.S_ISREG(st.st_mode) or not stat.S_ISCHR(st.st_mode): + raise error("/dev/urandom is not a character special device") + + # Check that /dev/urandom always returns the number of bytes requested + x = f.read(20) + y = g.read(20) + if len(x) != 20 or len(y) != 20: + raise error("Error reading from /dev/urandom: input truncated") + + # Check that different reads return different data + if x == y: + raise error("/dev/urandom is broken; returning identical data: %r == %r" % (x, y)) + + # Close the duplicate file object + g.close() + + # Return the first file object + return _RNG(f) + + except error: + if f is not None: + f.close() + if g is not None: + g.close() + raise + +# vim:set ts=4 sw=4 sts=4 expandtab: + diff --git a/tools/migration/paramiko/rng_win32.py b/tools/migration/paramiko/rng_win32.py new file mode 100644 index 00000000000..3cb8b84cb4f --- /dev/null +++ b/tools/migration/paramiko/rng_win32.py @@ -0,0 +1,121 @@ +#!/usr/bin/python +# -*- coding: ascii -*- +# Copyright (C) 2008 Dwayne C. Litzenberger +# Copyright (C) 2008 Open Systems Canada Limited +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA. + +class error(Exception): + pass + +# Try to import the "winrandom" module +try: + from Crypto.Util import winrandom as _winrandom +except ImportError: + _winrandom = None + +# Try to import the "urandom" module +try: + from os import urandom as _urandom +except ImportError: + _urandom = None + + +class _RNG(object): + def __init__(self, readfunc): + self.read = readfunc + + def randomize(self): + # According to "Cryptanalysis of the Random Number Generator of the + # Windows Operating System", by Leo Dorrendorf and Zvi Gutterman + # and Benny Pinkas , + # CryptGenRandom only updates its internal state using kernel-provided + # random data every 128KiB of output. + self.read(128*1024) # discard 128 KiB of output + +def _open_winrandom(): + if _winrandom is None: + raise error("Crypto.Util.winrandom module not found") + + # Check that we can open the winrandom module + try: + r0 = _winrandom.new() + r1 = _winrandom.new() + except Exception, exc: + raise error("winrandom.new() failed: %s" % str(exc), exc) + + # Check that we can read from the winrandom module + try: + x = r0.get_bytes(20) + y = r1.get_bytes(20) + except Exception, exc: + raise error("winrandom get_bytes failed: %s" % str(exc), exc) + + # Check that the requested number of bytes are returned + if len(x) != 20 or len(y) != 20: + raise error("Error reading from winrandom: input truncated") + + # Check that different reads return different data + if x == y: + raise error("winrandom broken: returning identical data") + + return _RNG(r0.get_bytes) + +def _open_urandom(): + if _urandom is None: + raise error("os.urandom function not found") + + # Check that we can read from os.urandom() + try: + x = _urandom(20) + y = _urandom(20) + except Exception, exc: + raise error("os.urandom failed: %s" % str(exc), exc) + + # Check that the requested number of bytes are returned + if len(x) != 20 or len(y) != 20: + raise error("os.urandom failed: input truncated") + + # Check that different reads return different data + if x == y: + raise error("os.urandom failed: returning identical data") + + return _RNG(_urandom) + +def open_rng_device(): + # Try using the Crypto.Util.winrandom module + try: + return _open_winrandom() + except error: + pass + + # Several versions of PyCrypto do not contain the winrandom module, but + # Python >= 2.4 has os.urandom, so try to use that. + try: + return _open_urandom() + except error: + pass + + # SECURITY NOTE: DO NOT USE Crypto.Util.randpool.RandomPool HERE! + # If we got to this point, RandomPool will silently run with very little + # entropy. (This is current as of PyCrypto 2.0.1). + # See http://www.lag.net/pipermail/paramiko/2008-January/000599.html + # and http://www.lag.net/pipermail/paramiko/2008-April/000678.html + + raise error("Unable to find a strong random entropy source. You cannot run this software securely under the current configuration.") + +# vim:set ts=4 sw=4 sts=4 expandtab: diff --git a/tools/migration/paramiko/rsakey.py b/tools/migration/paramiko/rsakey.py new file mode 100644 index 00000000000..a6652791af7 --- /dev/null +++ b/tools/migration/paramiko/rsakey.py @@ -0,0 +1,186 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{RSAKey} +""" + +from Crypto.PublicKey import RSA +from Crypto.Hash import SHA, MD5 +from Crypto.Cipher import DES3 + +from paramiko.common import * +from paramiko import util +from paramiko.message import Message +from paramiko.ber import BER, BERException +from paramiko.pkey import PKey +from paramiko.ssh_exception import SSHException + + +class RSAKey (PKey): + """ + Representation of an RSA key which can be used to sign and verify SSH2 + data. + """ + + def __init__(self, msg=None, data=None, filename=None, password=None, vals=None, file_obj=None): + self.n = None + self.e = None + self.d = None + self.p = None + self.q = None + if file_obj is not None: + self._from_private_key(file_obj, password) + return + if filename is not None: + self._from_private_key_file(filename, password) + return + if (msg is None) and (data is not None): + msg = Message(data) + if vals is not None: + self.e, self.n = vals + else: + if msg is None: + raise SSHException('Key object may not be empty') + if msg.get_string() != 'ssh-rsa': + raise SSHException('Invalid key') + self.e = msg.get_mpint() + self.n = msg.get_mpint() + self.size = util.bit_length(self.n) + + def __str__(self): + m = Message() + m.add_string('ssh-rsa') + m.add_mpint(self.e) + m.add_mpint(self.n) + return str(m) + + def __hash__(self): + h = hash(self.get_name()) + h = h * 37 + hash(self.e) + h = h * 37 + hash(self.n) + return hash(h) + + def get_name(self): + return 'ssh-rsa' + + def get_bits(self): + return self.size + + def can_sign(self): + return self.d is not None + + def sign_ssh_data(self, rpool, data): + digest = SHA.new(data).digest() + rsa = RSA.construct((long(self.n), long(self.e), long(self.d))) + sig = util.deflate_long(rsa.sign(self._pkcs1imify(digest), '')[0], 0) + m = Message() + m.add_string('ssh-rsa') + m.add_string(sig) + return m + + def verify_ssh_sig(self, data, msg): + if msg.get_string() != 'ssh-rsa': + return False + sig = util.inflate_long(msg.get_string(), True) + # verify the signature by SHA'ing the data and encrypting it using the + # public key. some wackiness ensues where we "pkcs1imify" the 20-byte + # hash into a string as long as the RSA key. + hash_obj = util.inflate_long(self._pkcs1imify(SHA.new(data).digest()), True) + rsa = RSA.construct((long(self.n), long(self.e))) + return rsa.verify(hash_obj, (sig,)) + + def _encode_key(self): + if (self.p is None) or (self.q is None): + raise SSHException('Not enough key info to write private key file') + keylist = [ 0, self.n, self.e, self.d, self.p, self.q, + self.d % (self.p - 1), self.d % (self.q - 1), + util.mod_inverse(self.q, self.p) ] + try: + b = BER() + b.encode(keylist) + except BERException: + raise SSHException('Unable to create ber encoding of key') + return str(b) + + def write_private_key_file(self, filename, password=None): + self._write_private_key_file('RSA', filename, self._encode_key(), password) + + def write_private_key(self, file_obj, password=None): + self._write_private_key('RSA', file_obj, self._encode_key(), password) + + def generate(bits, progress_func=None): + """ + Generate a new private RSA key. This factory function can be used to + generate a new host key or authentication key. + + @param bits: number of bits the generated key should be. + @type bits: int + @param progress_func: an optional function to call at key points in + key generation (used by C{pyCrypto.PublicKey}). + @type progress_func: function + @return: new private key + @rtype: L{RSAKey} + """ + randpool.stir() + rsa = RSA.generate(bits, randpool.get_bytes, progress_func) + key = RSAKey(vals=(rsa.e, rsa.n)) + key.d = rsa.d + key.p = rsa.p + key.q = rsa.q + return key + generate = staticmethod(generate) + + + ### internals... + + + def _pkcs1imify(self, data): + """ + turn a 20-byte SHA1 hash into a blob of data as large as the key's N, + using PKCS1's \"emsa-pkcs1-v1_5\" encoding. totally bizarre. + """ + SHA1_DIGESTINFO = '\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14' + size = len(util.deflate_long(self.n, 0)) + filler = '\xff' * (size - len(SHA1_DIGESTINFO) - len(data) - 3) + return '\x00\x01' + filler + '\x00' + SHA1_DIGESTINFO + data + + def _from_private_key_file(self, filename, password): + data = self._read_private_key_file('RSA', filename, password) + self._decode_key(data) + + def _from_private_key(self, file_obj, password): + data = self._read_private_key('RSA', file_obj, password) + self._decode_key(data) + + def _decode_key(self, data): + # private key file contains: + # RSAPrivateKey = { version = 0, n, e, d, p, q, d mod p-1, d mod q-1, q**-1 mod p } + try: + keylist = BER(data).decode() + except BERException: + raise SSHException('Unable to parse key file') + if (type(keylist) is not list) or (len(keylist) < 4) or (keylist[0] != 0): + raise SSHException('Not a valid RSA private key file (bad ber encoding)') + self.n = keylist[1] + self.e = keylist[2] + self.d = keylist[3] + # not really needed + self.p = keylist[4] + self.q = keylist[5] + self.size = util.bit_length(self.n) diff --git a/tools/migration/paramiko/server.py b/tools/migration/paramiko/server.py new file mode 100644 index 00000000000..6424b63a4ed --- /dev/null +++ b/tools/migration/paramiko/server.py @@ -0,0 +1,632 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{ServerInterface} is an interface to override for server support. +""" + +import threading +from paramiko.common import * +from paramiko import util + + +class InteractiveQuery (object): + """ + A query (set of prompts) for a user during interactive authentication. + """ + + def __init__(self, name='', instructions='', *prompts): + """ + Create a new interactive query to send to the client. The name and + instructions are optional, but are generally displayed to the end + user. A list of prompts may be included, or they may be added via + the L{add_prompt} method. + + @param name: name of this query + @type name: str + @param instructions: user instructions (usually short) about this query + @type instructions: str + @param prompts: one or more authentication prompts + @type prompts: str + """ + self.name = name + self.instructions = instructions + self.prompts = [] + for x in prompts: + if (type(x) is str) or (type(x) is unicode): + self.add_prompt(x) + else: + self.add_prompt(x[0], x[1]) + + def add_prompt(self, prompt, echo=True): + """ + Add a prompt to this query. The prompt should be a (reasonably short) + string. Multiple prompts can be added to the same query. + + @param prompt: the user prompt + @type prompt: str + @param echo: C{True} (default) if the user's response should be echoed; + C{False} if not (for a password or similar) + @type echo: bool + """ + self.prompts.append((prompt, echo)) + + +class ServerInterface (object): + """ + This class defines an interface for controlling the behavior of paramiko + in server mode. + + Methods on this class are called from paramiko's primary thread, so you + shouldn't do too much work in them. (Certainly nothing that blocks or + sleeps.) + """ + + def check_channel_request(self, kind, chanid): + """ + Determine if a channel request of a given type will be granted, and + return C{OPEN_SUCCEEDED} or an error code. This method is + called in server mode when the client requests a channel, after + authentication is complete. + + If you allow channel requests (and an ssh server that didn't would be + useless), you should also override some of the channel request methods + below, which are used to determine which services will be allowed on + a given channel: + - L{check_channel_pty_request} + - L{check_channel_shell_request} + - L{check_channel_subsystem_request} + - L{check_channel_window_change_request} + - L{check_channel_x11_request} + + The C{chanid} parameter is a small number that uniquely identifies the + channel within a L{Transport}. A L{Channel} object is not created + unless this method returns C{OPEN_SUCCEEDED} -- once a + L{Channel} object is created, you can call L{Channel.get_id} to + retrieve the channel ID. + + The return value should either be C{OPEN_SUCCEEDED} (or + C{0}) to allow the channel request, or one of the following error + codes to reject it: + - C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED} + - C{OPEN_FAILED_CONNECT_FAILED} + - C{OPEN_FAILED_UNKNOWN_CHANNEL_TYPE} + - C{OPEN_FAILED_RESOURCE_SHORTAGE} + + The default implementation always returns + C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED}. + + @param kind: the kind of channel the client would like to open + (usually C{"session"}). + @type kind: str + @param chanid: ID of the channel + @type chanid: int + @return: a success or failure code (listed above) + @rtype: int + """ + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + + def get_allowed_auths(self, username): + """ + Return a list of authentication methods supported by the server. + This list is sent to clients attempting to authenticate, to inform them + of authentication methods that might be successful. + + The "list" is actually a string of comma-separated names of types of + authentication. Possible values are C{"password"}, C{"publickey"}, + and C{"none"}. + + The default implementation always returns C{"password"}. + + @param username: the username requesting authentication. + @type username: str + @return: a comma-separated list of authentication types + @rtype: str + """ + return 'password' + + def check_auth_none(self, username): + """ + Determine if a client may open channels with no (further) + authentication. + + Return L{AUTH_FAILED} if the client must authenticate, or + L{AUTH_SUCCESSFUL} if it's okay for the client to not + authenticate. + + The default implementation always returns L{AUTH_FAILED}. + + @param username: the username of the client. + @type username: str + @return: L{AUTH_FAILED} if the authentication fails; + L{AUTH_SUCCESSFUL} if it succeeds. + @rtype: int + """ + return AUTH_FAILED + + def check_auth_password(self, username, password): + """ + Determine if a given username and password supplied by the client is + acceptable for use in authentication. + + Return L{AUTH_FAILED} if the password is not accepted, + L{AUTH_SUCCESSFUL} if the password is accepted and completes + the authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your + authentication is stateful, and this key is accepted for + authentication, but more authentication is required. (In this latter + case, L{get_allowed_auths} will be called to report to the client what + options it has for continuing the authentication.) + + The default implementation always returns L{AUTH_FAILED}. + + @param username: the username of the authenticating client. + @type username: str + @param password: the password given by the client. + @type password: str + @return: L{AUTH_FAILED} if the authentication fails; + L{AUTH_SUCCESSFUL} if it succeeds; + L{AUTH_PARTIALLY_SUCCESSFUL} if the password auth is + successful, but authentication must continue. + @rtype: int + """ + return AUTH_FAILED + + def check_auth_publickey(self, username, key): + """ + Determine if a given key supplied by the client is acceptable for use + in authentication. You should override this method in server mode to + check the username and key and decide if you would accept a signature + made using this key. + + Return L{AUTH_FAILED} if the key is not accepted, + L{AUTH_SUCCESSFUL} if the key is accepted and completes the + authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your + authentication is stateful, and this password is accepted for + authentication, but more authentication is required. (In this latter + case, L{get_allowed_auths} will be called to report to the client what + options it has for continuing the authentication.) + + Note that you don't have to actually verify any key signtature here. + If you're willing to accept the key, paramiko will do the work of + verifying the client's signature. + + The default implementation always returns L{AUTH_FAILED}. + + @param username: the username of the authenticating client + @type username: str + @param key: the key object provided by the client + @type key: L{PKey } + @return: L{AUTH_FAILED} if the client can't authenticate + with this key; L{AUTH_SUCCESSFUL} if it can; + L{AUTH_PARTIALLY_SUCCESSFUL} if it can authenticate with + this key but must continue with authentication + @rtype: int + """ + return AUTH_FAILED + + def check_auth_interactive(self, username, submethods): + """ + Begin an interactive authentication challenge, if supported. You + should override this method in server mode if you want to support the + C{"keyboard-interactive"} auth type, which requires you to send a + series of questions for the client to answer. + + Return L{AUTH_FAILED} if this auth method isn't supported. Otherwise, + you should return an L{InteractiveQuery} object containing the prompts + and instructions for the user. The response will be sent via a call + to L{check_auth_interactive_response}. + + The default implementation always returns L{AUTH_FAILED}. + + @param username: the username of the authenticating client + @type username: str + @param submethods: a comma-separated list of methods preferred by the + client (usually empty) + @type submethods: str + @return: L{AUTH_FAILED} if this auth method isn't supported; otherwise + an object containing queries for the user + @rtype: int or L{InteractiveQuery} + """ + return AUTH_FAILED + + def check_auth_interactive_response(self, responses): + """ + Continue or finish an interactive authentication challenge, if + supported. You should override this method in server mode if you want + to support the C{"keyboard-interactive"} auth type. + + Return L{AUTH_FAILED} if the responses are not accepted, + L{AUTH_SUCCESSFUL} if the responses are accepted and complete + the authentication, or L{AUTH_PARTIALLY_SUCCESSFUL} if your + authentication is stateful, and this set of responses is accepted for + authentication, but more authentication is required. (In this latter + case, L{get_allowed_auths} will be called to report to the client what + options it has for continuing the authentication.) + + If you wish to continue interactive authentication with more questions, + you may return an L{InteractiveQuery} object, which should cause the + client to respond with more answers, calling this method again. This + cycle can continue indefinitely. + + The default implementation always returns L{AUTH_FAILED}. + + @param responses: list of responses from the client + @type responses: list(str) + @return: L{AUTH_FAILED} if the authentication fails; + L{AUTH_SUCCESSFUL} if it succeeds; + L{AUTH_PARTIALLY_SUCCESSFUL} if the interactive auth is + successful, but authentication must continue; otherwise an object + containing queries for the user + @rtype: int or L{InteractiveQuery} + """ + return AUTH_FAILED + + def check_port_forward_request(self, address, port): + """ + Handle a request for port forwarding. The client is asking that + connections to the given address and port be forwarded back across + this ssh connection. An address of C{"0.0.0.0"} indicates a global + address (any address associated with this server) and a port of C{0} + indicates that no specific port is requested (usually the OS will pick + a port). + + The default implementation always returns C{False}, rejecting the + port forwarding request. If the request is accepted, you should return + the port opened for listening. + + @param address: the requested address + @type address: str + @param port: the requested port + @type port: int + @return: the port number that was opened for listening, or C{False} to + reject + @rtype: int + """ + return False + + def cancel_port_forward_request(self, address, port): + """ + The client would like to cancel a previous port-forwarding request. + If the given address and port is being forwarded across this ssh + connection, the port should be closed. + + @param address: the forwarded address + @type address: str + @param port: the forwarded port + @type port: int + """ + pass + + def check_global_request(self, kind, msg): + """ + Handle a global request of the given C{kind}. This method is called + in server mode and client mode, whenever the remote host makes a global + request. If there are any arguments to the request, they will be in + C{msg}. + + There aren't any useful global requests defined, aside from port + forwarding, so usually this type of request is an extension to the + protocol. + + If the request was successful and you would like to return contextual + data to the remote host, return a tuple. Items in the tuple will be + sent back with the successful result. (Note that the items in the + tuple can only be strings, ints, longs, or bools.) + + The default implementation always returns C{False}, indicating that it + does not support any global requests. + + @note: Port forwarding requests are handled separately, in + L{check_port_forward_request}. + + @param kind: the kind of global request being made. + @type kind: str + @param msg: any extra arguments to the request. + @type msg: L{Message} + @return: C{True} or a tuple of data if the request was granted; + C{False} otherwise. + @rtype: bool + """ + return False + + + ### Channel requests + + + def check_channel_pty_request(self, channel, term, width, height, pixelwidth, pixelheight, + modes): + """ + Determine if a pseudo-terminal of the given dimensions (usually + requested for shell access) can be provided on the given channel. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the pty request arrived on. + @type channel: L{Channel} + @param term: type of terminal requested (for example, C{"vt100"}). + @type term: str + @param width: width of screen in characters. + @type width: int + @param height: height of screen in characters. + @type height: int + @param pixelwidth: width of screen in pixels, if known (may be C{0} if + unknown). + @type pixelwidth: int + @param pixelheight: height of screen in pixels, if known (may be C{0} + if unknown). + @type pixelheight: int + @return: C{True} if the psuedo-terminal has been allocated; C{False} + otherwise. + @rtype: bool + """ + return False + + def check_channel_shell_request(self, channel): + """ + Determine if a shell will be provided to the client on the given + channel. If this method returns C{True}, the channel should be + connected to the stdin/stdout of a shell (or something that acts like + a shell). + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the request arrived on. + @type channel: L{Channel} + @return: C{True} if this channel is now hooked up to a shell; C{False} + if a shell can't or won't be provided. + @rtype: bool + """ + return False + + def check_channel_exec_request(self, channel, command): + """ + Determine if a shell command will be executed for the client. If this + method returns C{True}, the channel should be connected to the stdin, + stdout, and stderr of the shell command. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the request arrived on. + @type channel: L{Channel} + @param command: the command to execute. + @type command: str + @return: C{True} if this channel is now hooked up to the stdin, + stdout, and stderr of the executing command; C{False} if the + command will not be executed. + @rtype: bool + + @since: 1.1 + """ + return False + + def check_channel_subsystem_request(self, channel, name): + """ + Determine if a requested subsystem will be provided to the client on + the given channel. If this method returns C{True}, all future I/O + through this channel will be assumed to be connected to the requested + subsystem. An example of a subsystem is C{sftp}. + + The default implementation checks for a subsystem handler assigned via + L{Transport.set_subsystem_handler}. + If one has been set, the handler is invoked and this method returns + C{True}. Otherwise it returns C{False}. + + @note: Because the default implementation uses the L{Transport} to + identify valid subsystems, you probably won't need to override this + method. + + @param channel: the L{Channel} the pty request arrived on. + @type channel: L{Channel} + @param name: name of the requested subsystem. + @type name: str + @return: C{True} if this channel is now hooked up to the requested + subsystem; C{False} if that subsystem can't or won't be provided. + @rtype: bool + """ + handler_class, larg, kwarg = channel.get_transport()._get_subsystem_handler(name) + if handler_class is None: + return False + handler = handler_class(channel, name, self, *larg, **kwarg) + handler.start() + return True + + def check_channel_window_change_request(self, channel, width, height, pixelwidth, pixelheight): + """ + Determine if the pseudo-terminal on the given channel can be resized. + This only makes sense if a pty was previously allocated on it. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the pty request arrived on. + @type channel: L{Channel} + @param width: width of screen in characters. + @type width: int + @param height: height of screen in characters. + @type height: int + @param pixelwidth: width of screen in pixels, if known (may be C{0} if + unknown). + @type pixelwidth: int + @param pixelheight: height of screen in pixels, if known (may be C{0} + if unknown). + @type pixelheight: int + @return: C{True} if the terminal was resized; C{False} if not. + @rtype: bool + """ + return False + + def check_channel_x11_request(self, channel, single_connection, auth_protocol, auth_cookie, screen_number): + """ + Determine if the client will be provided with an X11 session. If this + method returns C{True}, X11 applications should be routed through new + SSH channels, using L{Transport.open_x11_channel}. + + The default implementation always returns C{False}. + + @param channel: the L{Channel} the X11 request arrived on + @type channel: L{Channel} + @param single_connection: C{True} if only a single X11 channel should + be opened + @type single_connection: bool + @param auth_protocol: the protocol used for X11 authentication + @type auth_protocol: str + @param auth_cookie: the cookie used to authenticate to X11 + @type auth_cookie: str + @param screen_number: the number of the X11 screen to connect to + @type screen_number: int + @return: C{True} if the X11 session was opened; C{False} if not + @rtype: bool + """ + return False + + def check_channel_direct_tcpip_request(self, chanid, origin, destination): + """ + Determine if a local port forwarding channel will be granted, and + return C{OPEN_SUCCEEDED} or an error code. This method is + called in server mode when the client requests a channel, after + authentication is complete. + + The C{chanid} parameter is a small number that uniquely identifies the + channel within a L{Transport}. A L{Channel} object is not created + unless this method returns C{OPEN_SUCCEEDED} -- once a + L{Channel} object is created, you can call L{Channel.get_id} to + retrieve the channel ID. + + The origin and destination parameters are (ip_address, port) tuples + that correspond to both ends of the TCP connection in the forwarding + tunnel. + + The return value should either be C{OPEN_SUCCEEDED} (or + C{0}) to allow the channel request, or one of the following error + codes to reject it: + - C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED} + - C{OPEN_FAILED_CONNECT_FAILED} + - C{OPEN_FAILED_UNKNOWN_CHANNEL_TYPE} + - C{OPEN_FAILED_RESOURCE_SHORTAGE} + + The default implementation always returns + C{OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED}. + + @param chanid: ID of the channel + @type chanid: int + @param origin: 2-tuple containing the IP address and port of the + originator (client side) + @type origin: tuple + @param destination: 2-tuple containing the IP address and port of the + destination (server side) + @type destination: tuple + @return: a success or failure code (listed above) + @rtype: int + """ + return OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + + +class SubsystemHandler (threading.Thread): + """ + Handler for a subsytem in server mode. If you create a subclass of this + class and pass it to + L{Transport.set_subsystem_handler}, + an object of this + class will be created for each request for this subsystem. Each new object + will be executed within its own new thread by calling L{start_subsystem}. + When that method completes, the channel is closed. + + For example, if you made a subclass C{MP3Handler} and registered it as the + handler for subsystem C{"mp3"}, then whenever a client has successfully + authenticated and requests subsytem C{"mp3"}, an object of class + C{MP3Handler} will be created, and L{start_subsystem} will be called on + it from a new thread. + """ + def __init__(self, channel, name, server): + """ + Create a new handler for a channel. This is used by L{ServerInterface} + to start up a new handler when a channel requests this subsystem. You + don't need to override this method, but if you do, be sure to pass the + C{channel} and C{name} parameters through to the original C{__init__} + method here. + + @param channel: the channel associated with this subsystem request. + @type channel: L{Channel} + @param name: name of the requested subsystem. + @type name: str + @param server: the server object for the session that started this + subsystem + @type server: L{ServerInterface} + """ + threading.Thread.__init__(self, target=self._run) + self.__channel = channel + self.__transport = channel.get_transport() + self.__name = name + self.__server = server + + def get_server(self): + """ + Return the L{ServerInterface} object associated with this channel and + subsystem. + + @rtype: L{ServerInterface} + """ + return self.__server + + def _run(self): + try: + self.__transport._log(DEBUG, 'Starting handler for subsystem %s' % self.__name) + self.start_subsystem(self.__name, self.__transport, self.__channel) + except Exception, e: + self.__transport._log(ERROR, 'Exception in subsystem handler for "%s": %s' % + (self.__name, str(e))) + self.__transport._log(ERROR, util.tb_strings()) + try: + self.finish_subsystem() + except: + pass + + def start_subsystem(self, name, transport, channel): + """ + Process an ssh subsystem in server mode. This method is called on a + new object (and in a new thread) for each subsystem request. It is + assumed that all subsystem logic will take place here, and when the + subsystem is finished, this method will return. After this method + returns, the channel is closed. + + The combination of C{transport} and C{channel} are unique; this handler + corresponds to exactly one L{Channel} on one L{Transport}. + + @note: It is the responsibility of this method to exit if the + underlying L{Transport} is closed. This can be done by checking + L{Transport.is_active} or noticing an EOF + on the L{Channel}. If this method loops forever without checking + for this case, your python interpreter may refuse to exit because + this thread will still be running. + + @param name: name of the requested subsystem. + @type name: str + @param transport: the server-mode L{Transport}. + @type transport: L{Transport} + @param channel: the channel associated with this subsystem request. + @type channel: L{Channel} + """ + pass + + def finish_subsystem(self): + """ + Perform any cleanup at the end of a subsystem. The default + implementation just closes the channel. + + @since: 1.1 + """ + self.__channel.close() diff --git a/tools/migration/paramiko/sftp.py b/tools/migration/paramiko/sftp.py new file mode 100644 index 00000000000..a0b08e0272c --- /dev/null +++ b/tools/migration/paramiko/sftp.py @@ -0,0 +1,188 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +import select +import socket +import struct + +from paramiko.common import * +from paramiko import util +from paramiko.channel import Channel +from paramiko.message import Message + + +CMD_INIT, CMD_VERSION, CMD_OPEN, CMD_CLOSE, CMD_READ, CMD_WRITE, CMD_LSTAT, CMD_FSTAT, \ + CMD_SETSTAT, CMD_FSETSTAT, CMD_OPENDIR, CMD_READDIR, CMD_REMOVE, CMD_MKDIR, \ + CMD_RMDIR, CMD_REALPATH, CMD_STAT, CMD_RENAME, CMD_READLINK, CMD_SYMLINK \ + = range(1, 21) +CMD_STATUS, CMD_HANDLE, CMD_DATA, CMD_NAME, CMD_ATTRS = range(101, 106) +CMD_EXTENDED, CMD_EXTENDED_REPLY = range(200, 202) + +SFTP_OK = 0 +SFTP_EOF, SFTP_NO_SUCH_FILE, SFTP_PERMISSION_DENIED, SFTP_FAILURE, SFTP_BAD_MESSAGE, \ + SFTP_NO_CONNECTION, SFTP_CONNECTION_LOST, SFTP_OP_UNSUPPORTED = range(1, 9) + +SFTP_DESC = [ 'Success', + 'End of file', + 'No such file', + 'Permission denied', + 'Failure', + 'Bad message', + 'No connection', + 'Connection lost', + 'Operation unsupported' ] + +SFTP_FLAG_READ = 0x1 +SFTP_FLAG_WRITE = 0x2 +SFTP_FLAG_APPEND = 0x4 +SFTP_FLAG_CREATE = 0x8 +SFTP_FLAG_TRUNC = 0x10 +SFTP_FLAG_EXCL = 0x20 + +_VERSION = 3 + + +# for debugging +CMD_NAMES = { + CMD_INIT: 'init', + CMD_VERSION: 'version', + CMD_OPEN: 'open', + CMD_CLOSE: 'close', + CMD_READ: 'read', + CMD_WRITE: 'write', + CMD_LSTAT: 'lstat', + CMD_FSTAT: 'fstat', + CMD_SETSTAT: 'setstat', + CMD_FSETSTAT: 'fsetstat', + CMD_OPENDIR: 'opendir', + CMD_READDIR: 'readdir', + CMD_REMOVE: 'remove', + CMD_MKDIR: 'mkdir', + CMD_RMDIR: 'rmdir', + CMD_REALPATH: 'realpath', + CMD_STAT: 'stat', + CMD_RENAME: 'rename', + CMD_READLINK: 'readlink', + CMD_SYMLINK: 'symlink', + CMD_STATUS: 'status', + CMD_HANDLE: 'handle', + CMD_DATA: 'data', + CMD_NAME: 'name', + CMD_ATTRS: 'attrs', + CMD_EXTENDED: 'extended', + CMD_EXTENDED_REPLY: 'extended_reply' + } + + +class SFTPError (Exception): + pass + + +class BaseSFTP (object): + def __init__(self): + self.logger = util.get_logger('paramiko.sftp') + self.sock = None + self.ultra_debug = False + + + ### internals... + + + def _send_version(self): + self._send_packet(CMD_INIT, struct.pack('>I', _VERSION)) + t, data = self._read_packet() + if t != CMD_VERSION: + raise SFTPError('Incompatible sftp protocol') + version = struct.unpack('>I', data[:4])[0] + # if version != _VERSION: + # raise SFTPError('Incompatible sftp protocol') + return version + + def _send_server_version(self): + # winscp will freak out if the server sends version info before the + # client finishes sending INIT. + t, data = self._read_packet() + if t != CMD_INIT: + raise SFTPError('Incompatible sftp protocol') + version = struct.unpack('>I', data[:4])[0] + # advertise that we support "check-file" + extension_pairs = [ 'check-file', 'md5,sha1' ] + msg = Message() + msg.add_int(_VERSION) + msg.add(*extension_pairs) + self._send_packet(CMD_VERSION, str(msg)) + return version + + def _log(self, level, msg, *args): + self.logger.log(level, msg, *args) + + def _write_all(self, out): + while len(out) > 0: + n = self.sock.send(out) + if n <= 0: + raise EOFError() + if n == len(out): + return + out = out[n:] + return + + def _read_all(self, n): + out = '' + while n > 0: + if isinstance(self.sock, socket.socket): + # sometimes sftp is used directly over a socket instead of + # through a paramiko channel. in this case, check periodically + # if the socket is closed. (for some reason, recv() won't ever + # return or raise an exception, but calling select on a closed + # socket will.) + while True: + read, write, err = select.select([ self.sock ], [], [], 0.1) + if len(read) > 0: + x = self.sock.recv(n) + break + else: + x = self.sock.recv(n) + + if len(x) == 0: + raise EOFError() + out += x + n -= len(x) + return out + + def _send_packet(self, t, packet): + #self._log(DEBUG2, 'write: %s (len=%d)' % (CMD_NAMES.get(t, '0x%02x' % t), len(packet))) + out = struct.pack('>I', len(packet) + 1) + chr(t) + packet + if self.ultra_debug: + self._log(DEBUG, util.format_binary(out, 'OUT: ')) + self._write_all(out) + + def _read_packet(self): + x = self._read_all(4) + # most sftp servers won't accept packets larger than about 32k, so + # anything with the high byte set (> 16MB) is just garbage. + if x[0] != '\x00': + raise SFTPError('Garbage packet received') + size = struct.unpack('>I', x)[0] + data = self._read_all(size) + if self.ultra_debug: + self._log(DEBUG, util.format_binary(data, 'IN: ')); + if size > 0: + t = ord(data[0]) + #self._log(DEBUG2, 'read: %s (len=%d)' % (CMD_NAMES.get(t), '0x%02x' % t, len(data)-1)) + return t, data[1:] + return 0, '' diff --git a/tools/migration/paramiko/sftp_attr.py b/tools/migration/paramiko/sftp_attr.py new file mode 100644 index 00000000000..26290bece47 --- /dev/null +++ b/tools/migration/paramiko/sftp_attr.py @@ -0,0 +1,223 @@ +# Copyright (C) 2003-2006 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +import stat +import time +from paramiko.common import * +from paramiko.sftp import * + + +class SFTPAttributes (object): + """ + Representation of the attributes of a file (or proxied file) for SFTP in + client or server mode. It attemps to mirror the object returned by + C{os.stat} as closely as possible, so it may have the following fields, + with the same meanings as those returned by an C{os.stat} object: + - st_size + - st_uid + - st_gid + - st_mode + - st_atime + - st_mtime + + Because SFTP allows flags to have other arbitrary named attributes, these + are stored in a dict named C{attr}. Occasionally, the filename is also + stored, in C{filename}. + """ + + FLAG_SIZE = 1 + FLAG_UIDGID = 2 + FLAG_PERMISSIONS = 4 + FLAG_AMTIME = 8 + FLAG_EXTENDED = 0x80000000L + + def __init__(self): + """ + Create a new (empty) SFTPAttributes object. All fields will be empty. + """ + self._flags = 0 + self.st_size = None + self.st_uid = None + self.st_gid = None + self.st_mode = None + self.st_atime = None + self.st_mtime = None + self.attr = {} + + def from_stat(cls, obj, filename=None): + """ + Create an SFTPAttributes object from an existing C{stat} object (an + object returned by C{os.stat}). + + @param obj: an object returned by C{os.stat} (or equivalent). + @type obj: object + @param filename: the filename associated with this file. + @type filename: str + @return: new L{SFTPAttributes} object with the same attribute fields. + @rtype: L{SFTPAttributes} + """ + attr = cls() + attr.st_size = obj.st_size + attr.st_uid = obj.st_uid + attr.st_gid = obj.st_gid + attr.st_mode = obj.st_mode + attr.st_atime = obj.st_atime + attr.st_mtime = obj.st_mtime + if filename is not None: + attr.filename = filename + return attr + from_stat = classmethod(from_stat) + + def __repr__(self): + return '' % self._debug_str() + + + ### internals... + + + def _from_msg(cls, msg, filename=None, longname=None): + attr = cls() + attr._unpack(msg) + if filename is not None: + attr.filename = filename + if longname is not None: + attr.longname = longname + return attr + _from_msg = classmethod(_from_msg) + + def _unpack(self, msg): + self._flags = msg.get_int() + if self._flags & self.FLAG_SIZE: + self.st_size = msg.get_int64() + if self._flags & self.FLAG_UIDGID: + self.st_uid = msg.get_int() + self.st_gid = msg.get_int() + if self._flags & self.FLAG_PERMISSIONS: + self.st_mode = msg.get_int() + if self._flags & self.FLAG_AMTIME: + self.st_atime = msg.get_int() + self.st_mtime = msg.get_int() + if self._flags & self.FLAG_EXTENDED: + count = msg.get_int() + for i in range(count): + self.attr[msg.get_string()] = msg.get_string() + + def _pack(self, msg): + self._flags = 0 + if self.st_size is not None: + self._flags |= self.FLAG_SIZE + if (self.st_uid is not None) and (self.st_gid is not None): + self._flags |= self.FLAG_UIDGID + if self.st_mode is not None: + self._flags |= self.FLAG_PERMISSIONS + if (self.st_atime is not None) and (self.st_mtime is not None): + self._flags |= self.FLAG_AMTIME + if len(self.attr) > 0: + self._flags |= self.FLAG_EXTENDED + msg.add_int(self._flags) + if self._flags & self.FLAG_SIZE: + msg.add_int64(self.st_size) + if self._flags & self.FLAG_UIDGID: + msg.add_int(self.st_uid) + msg.add_int(self.st_gid) + if self._flags & self.FLAG_PERMISSIONS: + msg.add_int(self.st_mode) + if self._flags & self.FLAG_AMTIME: + # throw away any fractional seconds + msg.add_int(long(self.st_atime)) + msg.add_int(long(self.st_mtime)) + if self._flags & self.FLAG_EXTENDED: + msg.add_int(len(self.attr)) + for key, val in self.attr.iteritems(): + msg.add_string(key) + msg.add_string(val) + return + + def _debug_str(self): + out = '[ ' + if self.st_size is not None: + out += 'size=%d ' % self.st_size + if (self.st_uid is not None) and (self.st_gid is not None): + out += 'uid=%d gid=%d ' % (self.st_uid, self.st_gid) + if self.st_mode is not None: + out += 'mode=' + oct(self.st_mode) + ' ' + if (self.st_atime is not None) and (self.st_mtime is not None): + out += 'atime=%d mtime=%d ' % (self.st_atime, self.st_mtime) + for k, v in self.attr.iteritems(): + out += '"%s"=%r ' % (str(k), v) + out += ']' + return out + + def _rwx(n, suid, sticky=False): + if suid: + suid = 2 + out = '-r'[n >> 2] + '-w'[(n >> 1) & 1] + if sticky: + out += '-xTt'[suid + (n & 1)] + else: + out += '-xSs'[suid + (n & 1)] + return out + _rwx = staticmethod(_rwx) + + def __str__(self): + "create a unix-style long description of the file (like ls -l)" + if self.st_mode is not None: + kind = stat.S_IFMT(self.st_mode) + if kind == stat.S_IFIFO: + ks = 'p' + elif kind == stat.S_IFCHR: + ks = 'c' + elif kind == stat.S_IFDIR: + ks = 'd' + elif kind == stat.S_IFBLK: + ks = 'b' + elif kind == stat.S_IFREG: + ks = '-' + elif kind == stat.S_IFLNK: + ks = 'l' + elif kind == stat.S_IFSOCK: + ks = 's' + else: + ks = '?' + ks += self._rwx((self.st_mode & 0700) >> 6, self.st_mode & stat.S_ISUID) + ks += self._rwx((self.st_mode & 070) >> 3, self.st_mode & stat.S_ISGID) + ks += self._rwx(self.st_mode & 7, self.st_mode & stat.S_ISVTX, True) + else: + ks = '?---------' + # compute display date + if (self.st_mtime is None) or (self.st_mtime == 0xffffffff): + # shouldn't really happen + datestr = '(unknown date)' + else: + if abs(time.time() - self.st_mtime) > 15552000: + # (15552000 = 6 months) + datestr = time.strftime('%d %b %Y', time.localtime(self.st_mtime)) + else: + datestr = time.strftime('%d %b %H:%M', time.localtime(self.st_mtime)) + filename = getattr(self, 'filename', '?') + + # not all servers support uid/gid + uid = self.st_uid + gid = self.st_gid + if uid is None: + uid = 0 + if gid is None: + gid = 0 + + return '%s 1 %-8d %-8d %8d %-12s %s' % (ks, uid, gid, self.st_size, datestr, filename) + diff --git a/tools/migration/paramiko/sftp_client.py b/tools/migration/paramiko/sftp_client.py new file mode 100644 index 00000000000..1f1107552c6 --- /dev/null +++ b/tools/migration/paramiko/sftp_client.py @@ -0,0 +1,726 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Client-mode SFTP support. +""" + +from binascii import hexlify +import errno +import os +import stat +import threading +import time +import weakref + +from paramiko.sftp import * +from paramiko.sftp_attr import SFTPAttributes +from paramiko.ssh_exception import SSHException +from paramiko.sftp_file import SFTPFile + + +def _to_unicode(s): + """ + decode a string as ascii or utf8 if possible (as required by the sftp + protocol). if neither works, just return a byte string because the server + probably doesn't know the filename's encoding. + """ + try: + return s.encode('ascii') + except UnicodeError: + try: + return s.decode('utf-8') + except UnicodeError: + return s + + +class SFTPClient (BaseSFTP): + """ + SFTP client object. C{SFTPClient} is used to open an sftp session across + an open ssh L{Transport} and do remote file operations. + """ + + def __init__(self, sock): + """ + Create an SFTP client from an existing L{Channel}. The channel + should already have requested the C{"sftp"} subsystem. + + An alternate way to create an SFTP client context is by using + L{from_transport}. + + @param sock: an open L{Channel} using the C{"sftp"} subsystem + @type sock: L{Channel} + + @raise SSHException: if there's an exception while negotiating + sftp + """ + BaseSFTP.__init__(self) + self.sock = sock + self.ultra_debug = False + self.request_number = 1 + # lock for request_number + self._lock = threading.Lock() + self._cwd = None + # request # -> SFTPFile + self._expecting = weakref.WeakValueDictionary() + if type(sock) is Channel: + # override default logger + transport = self.sock.get_transport() + self.logger = util.get_logger(transport.get_log_channel() + '.sftp') + self.ultra_debug = transport.get_hexdump() + try: + server_version = self._send_version() + except EOFError, x: + raise SSHException('EOF during negotiation') + self._log(INFO, 'Opened sftp connection (server version %d)' % server_version) + + def from_transport(cls, t): + """ + Create an SFTP client channel from an open L{Transport}. + + @param t: an open L{Transport} which is already authenticated + @type t: L{Transport} + @return: a new L{SFTPClient} object, referring to an sftp session + (channel) across the transport + @rtype: L{SFTPClient} + """ + chan = t.open_session() + if chan is None: + return None + chan.invoke_subsystem('sftp') + return cls(chan) + from_transport = classmethod(from_transport) + + def _log(self, level, msg, *args): + if isinstance(msg, list): + for m in msg: + super(SFTPClient, self)._log(level, "[chan %s] " + m, *([ self.sock.get_name() ] + list(args))) + else: + super(SFTPClient, self)._log(level, "[chan %s] " + msg, *([ self.sock.get_name() ] + list(args))) + + def close(self): + """ + Close the SFTP session and its underlying channel. + + @since: 1.4 + """ + self._log(INFO, 'sftp session closed.') + self.sock.close() + + def get_channel(self): + """ + Return the underlying L{Channel} object for this SFTP session. This + might be useful for doing things like setting a timeout on the channel. + + @return: the SSH channel + @rtype: L{Channel} + + @since: 1.7.1 + """ + return self.sock + + def listdir(self, path='.'): + """ + Return a list containing the names of the entries in the given C{path}. + The list is in arbitrary order. It does not include the special + entries C{'.'} and C{'..'} even if they are present in the folder. + This method is meant to mirror C{os.listdir} as closely as possible. + For a list of full L{SFTPAttributes} objects, see L{listdir_attr}. + + @param path: path to list (defaults to C{'.'}) + @type path: str + @return: list of filenames + @rtype: list of str + """ + return [f.filename for f in self.listdir_attr(path)] + + def listdir_attr(self, path='.'): + """ + Return a list containing L{SFTPAttributes} objects corresponding to + files in the given C{path}. The list is in arbitrary order. It does + not include the special entries C{'.'} and C{'..'} even if they are + present in the folder. + + The returned L{SFTPAttributes} objects will each have an additional + field: C{longname}, which may contain a formatted string of the file's + attributes, in unix format. The content of this string will probably + depend on the SFTP server implementation. + + @param path: path to list (defaults to C{'.'}) + @type path: str + @return: list of attributes + @rtype: list of L{SFTPAttributes} + + @since: 1.2 + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'listdir(%r)' % path) + t, msg = self._request(CMD_OPENDIR, path) + if t != CMD_HANDLE: + raise SFTPError('Expected handle') + handle = msg.get_string() + filelist = [] + while True: + try: + t, msg = self._request(CMD_READDIR, handle) + except EOFError, e: + # done with handle + break + if t != CMD_NAME: + raise SFTPError('Expected name response') + count = msg.get_int() + for i in range(count): + filename = _to_unicode(msg.get_string()) + longname = _to_unicode(msg.get_string()) + attr = SFTPAttributes._from_msg(msg, filename, longname) + if (filename != '.') and (filename != '..'): + filelist.append(attr) + self._request(CMD_CLOSE, handle) + return filelist + + def open(self, filename, mode='r', bufsize=-1): + """ + Open a file on the remote server. The arguments are the same as for + python's built-in C{file} (aka C{open}). A file-like object is + returned, which closely mimics the behavior of a normal python file + object. + + The mode indicates how the file is to be opened: C{'r'} for reading, + C{'w'} for writing (truncating an existing file), C{'a'} for appending, + C{'r+'} for reading/writing, C{'w+'} for reading/writing (truncating an + existing file), C{'a+'} for reading/appending. The python C{'b'} flag + is ignored, since SSH treats all files as binary. The C{'U'} flag is + supported in a compatible way. + + Since 1.5.2, an C{'x'} flag indicates that the operation should only + succeed if the file was created and did not previously exist. This has + no direct mapping to python's file flags, but is commonly known as the + C{O_EXCL} flag in posix. + + The file will be buffered in standard python style by default, but + can be altered with the C{bufsize} parameter. C{0} turns off + buffering, C{1} uses line buffering, and any number greater than 1 + (C{>1}) uses that specific buffer size. + + @param filename: name of the file to open + @type filename: str + @param mode: mode (python-style) to open in + @type mode: str + @param bufsize: desired buffering (-1 = default buffer size) + @type bufsize: int + @return: a file object representing the open file + @rtype: SFTPFile + + @raise IOError: if the file could not be opened. + """ + filename = self._adjust_cwd(filename) + self._log(DEBUG, 'open(%r, %r)' % (filename, mode)) + imode = 0 + if ('r' in mode) or ('+' in mode): + imode |= SFTP_FLAG_READ + if ('w' in mode) or ('+' in mode) or ('a' in mode): + imode |= SFTP_FLAG_WRITE + if ('w' in mode): + imode |= SFTP_FLAG_CREATE | SFTP_FLAG_TRUNC + if ('a' in mode): + imode |= SFTP_FLAG_CREATE | SFTP_FLAG_APPEND + if ('x' in mode): + imode |= SFTP_FLAG_CREATE | SFTP_FLAG_EXCL + attrblock = SFTPAttributes() + t, msg = self._request(CMD_OPEN, filename, imode, attrblock) + if t != CMD_HANDLE: + raise SFTPError('Expected handle') + handle = msg.get_string() + self._log(DEBUG, 'open(%r, %r) -> %s' % (filename, mode, hexlify(handle))) + return SFTPFile(self, handle, mode, bufsize) + + # python continues to vacillate about "open" vs "file"... + file = open + + def remove(self, path): + """ + Remove the file at the given path. This only works on files; for + removing folders (directories), use L{rmdir}. + + @param path: path (absolute or relative) of the file to remove + @type path: str + + @raise IOError: if the path refers to a folder (directory) + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'remove(%r)' % path) + self._request(CMD_REMOVE, path) + + unlink = remove + + def rename(self, oldpath, newpath): + """ + Rename a file or folder from C{oldpath} to C{newpath}. + + @param oldpath: existing name of the file or folder + @type oldpath: str + @param newpath: new name for the file or folder + @type newpath: str + + @raise IOError: if C{newpath} is a folder, or something else goes + wrong + """ + oldpath = self._adjust_cwd(oldpath) + newpath = self._adjust_cwd(newpath) + self._log(DEBUG, 'rename(%r, %r)' % (oldpath, newpath)) + self._request(CMD_RENAME, oldpath, newpath) + + def mkdir(self, path, mode=0777): + """ + Create a folder (directory) named C{path} with numeric mode C{mode}. + The default mode is 0777 (octal). On some systems, mode is ignored. + Where it is used, the current umask value is first masked out. + + @param path: name of the folder to create + @type path: str + @param mode: permissions (posix-style) for the newly-created folder + @type mode: int + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'mkdir(%r, %r)' % (path, mode)) + attr = SFTPAttributes() + attr.st_mode = mode + self._request(CMD_MKDIR, path, attr) + + def rmdir(self, path): + """ + Remove the folder named C{path}. + + @param path: name of the folder to remove + @type path: str + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'rmdir(%r)' % path) + self._request(CMD_RMDIR, path) + + def stat(self, path): + """ + Retrieve information about a file on the remote system. The return + value is an object whose attributes correspond to the attributes of + python's C{stat} structure as returned by C{os.stat}, except that it + contains fewer fields. An SFTP server may return as much or as little + info as it wants, so the results may vary from server to server. + + Unlike a python C{stat} object, the result may not be accessed as a + tuple. This is mostly due to the author's slack factor. + + The fields supported are: C{st_mode}, C{st_size}, C{st_uid}, C{st_gid}, + C{st_atime}, and C{st_mtime}. + + @param path: the filename to stat + @type path: str + @return: an object containing attributes about the given file + @rtype: SFTPAttributes + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'stat(%r)' % path) + t, msg = self._request(CMD_STAT, path) + if t != CMD_ATTRS: + raise SFTPError('Expected attributes') + return SFTPAttributes._from_msg(msg) + + def lstat(self, path): + """ + Retrieve information about a file on the remote system, without + following symbolic links (shortcuts). This otherwise behaves exactly + the same as L{stat}. + + @param path: the filename to stat + @type path: str + @return: an object containing attributes about the given file + @rtype: SFTPAttributes + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'lstat(%r)' % path) + t, msg = self._request(CMD_LSTAT, path) + if t != CMD_ATTRS: + raise SFTPError('Expected attributes') + return SFTPAttributes._from_msg(msg) + + def symlink(self, source, dest): + """ + Create a symbolic link (shortcut) of the C{source} path at + C{destination}. + + @param source: path of the original file + @type source: str + @param dest: path of the newly created symlink + @type dest: str + """ + dest = self._adjust_cwd(dest) + self._log(DEBUG, 'symlink(%r, %r)' % (source, dest)) + if type(source) is unicode: + source = source.encode('utf-8') + self._request(CMD_SYMLINK, source, dest) + + def chmod(self, path, mode): + """ + Change the mode (permissions) of a file. The permissions are + unix-style and identical to those used by python's C{os.chmod} + function. + + @param path: path of the file to change the permissions of + @type path: str + @param mode: new permissions + @type mode: int + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'chmod(%r, %r)' % (path, mode)) + attr = SFTPAttributes() + attr.st_mode = mode + self._request(CMD_SETSTAT, path, attr) + + def chown(self, path, uid, gid): + """ + Change the owner (C{uid}) and group (C{gid}) of a file. As with + python's C{os.chown} function, you must pass both arguments, so if you + only want to change one, use L{stat} first to retrieve the current + owner and group. + + @param path: path of the file to change the owner and group of + @type path: str + @param uid: new owner's uid + @type uid: int + @param gid: new group id + @type gid: int + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'chown(%r, %r, %r)' % (path, uid, gid)) + attr = SFTPAttributes() + attr.st_uid, attr.st_gid = uid, gid + self._request(CMD_SETSTAT, path, attr) + + def utime(self, path, times): + """ + Set the access and modified times of the file specified by C{path}. If + C{times} is C{None}, then the file's access and modified times are set + to the current time. Otherwise, C{times} must be a 2-tuple of numbers, + of the form C{(atime, mtime)}, which is used to set the access and + modified times, respectively. This bizarre API is mimicked from python + for the sake of consistency -- I apologize. + + @param path: path of the file to modify + @type path: str + @param times: C{None} or a tuple of (access time, modified time) in + standard internet epoch time (seconds since 01 January 1970 GMT) + @type times: tuple(int) + """ + path = self._adjust_cwd(path) + if times is None: + times = (time.time(), time.time()) + self._log(DEBUG, 'utime(%r, %r)' % (path, times)) + attr = SFTPAttributes() + attr.st_atime, attr.st_mtime = times + self._request(CMD_SETSTAT, path, attr) + + def truncate(self, path, size): + """ + Change the size of the file specified by C{path}. This usually extends + or shrinks the size of the file, just like the C{truncate()} method on + python file objects. + + @param path: path of the file to modify + @type path: str + @param size: the new size of the file + @type size: int or long + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'truncate(%r, %r)' % (path, size)) + attr = SFTPAttributes() + attr.st_size = size + self._request(CMD_SETSTAT, path, attr) + + def readlink(self, path): + """ + Return the target of a symbolic link (shortcut). You can use + L{symlink} to create these. The result may be either an absolute or + relative pathname. + + @param path: path of the symbolic link file + @type path: str + @return: target path + @rtype: str + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'readlink(%r)' % path) + t, msg = self._request(CMD_READLINK, path) + if t != CMD_NAME: + raise SFTPError('Expected name response') + count = msg.get_int() + if count == 0: + return None + if count != 1: + raise SFTPError('Readlink returned %d results' % count) + return _to_unicode(msg.get_string()) + + def normalize(self, path): + """ + Return the normalized path (on the server) of a given path. This + can be used to quickly resolve symbolic links or determine what the + server is considering to be the "current folder" (by passing C{'.'} + as C{path}). + + @param path: path to be normalized + @type path: str + @return: normalized form of the given path + @rtype: str + + @raise IOError: if the path can't be resolved on the server + """ + path = self._adjust_cwd(path) + self._log(DEBUG, 'normalize(%r)' % path) + t, msg = self._request(CMD_REALPATH, path) + if t != CMD_NAME: + raise SFTPError('Expected name response') + count = msg.get_int() + if count != 1: + raise SFTPError('Realpath returned %d results' % count) + return _to_unicode(msg.get_string()) + + def chdir(self, path): + """ + Change the "current directory" of this SFTP session. Since SFTP + doesn't really have the concept of a current working directory, this + is emulated by paramiko. Once you use this method to set a working + directory, all operations on this SFTPClient object will be relative + to that path. You can pass in C{None} to stop using a current working + directory. + + @param path: new current working directory + @type path: str + + @raise IOError: if the requested path doesn't exist on the server + + @since: 1.4 + """ + if path is None: + self._cwd = None + return + if not stat.S_ISDIR(self.stat(path).st_mode): + raise SFTPError(errno.ENOTDIR, "%s: %s" % (os.strerror(errno.ENOTDIR), path)) + self._cwd = self.normalize(path).encode('utf-8') + + def getcwd(self): + """ + Return the "current working directory" for this SFTP session, as + emulated by paramiko. If no directory has been set with L{chdir}, + this method will return C{None}. + + @return: the current working directory on the server, or C{None} + @rtype: str + + @since: 1.4 + """ + return self._cwd + + def put(self, localpath, remotepath, callback=None): + """ + Copy a local file (C{localpath}) to the SFTP server as C{remotepath}. + Any exception raised by operations will be passed through. This + method is primarily provided as a convenience. + + The SFTP operations use pipelining for speed. + + @param localpath: the local file to copy + @type localpath: str + @param remotepath: the destination path on the SFTP server + @type remotepath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + (since 1.7.4) + @type callback: function(int, int) + @return: an object containing attributes about the given file + (since 1.7.4) + @rtype: SFTPAttributes + + @since: 1.4 + """ + file_size = os.stat(localpath).st_size + fl = file(localpath, 'rb') + try: + fr = self.file(remotepath, 'wb') + fr.set_pipelined(True) + size = 0 + try: + while True: + data = fl.read(32768) + if len(data) == 0: + break + fr.write(data) + size += len(data) + if callback is not None: + callback(size, file_size) + finally: + fr.close() + finally: + fl.close() + s = self.stat(remotepath) + if s.st_size != size: + raise IOError('size mismatch in put! %d != %d' % (s.st_size, size)) + return s + + def get(self, remotepath, localpath, callback=None): + """ + Copy a remote file (C{remotepath}) from the SFTP server to the local + host as C{localpath}. Any exception raised by operations will be + passed through. This method is primarily provided as a convenience. + + @param remotepath: the remote file to copy + @type remotepath: str + @param localpath: the destination path on the local host + @type localpath: str + @param callback: optional callback function that accepts the bytes + transferred so far and the total bytes to be transferred + (since 1.7.4) + @type callback: function(int, int) + + @since: 1.4 + """ + fr = self.file(remotepath, 'rb') + file_size = self.stat(remotepath).st_size + fr.prefetch() + try: + fl = file(localpath, 'wb') + try: + size = 0 + while True: + data = fr.read(32768) + if len(data) == 0: + break + fl.write(data) + size += len(data) + if callback is not None: + callback(size, file_size) + finally: + fl.close() + finally: + fr.close() + s = os.stat(localpath) + if s.st_size != size: + raise IOError('size mismatch in get! %d != %d' % (s.st_size, size)) + + + ### internals... + + + def _request(self, t, *arg): + num = self._async_request(type(None), t, *arg) + return self._read_response(num) + + def _async_request(self, fileobj, t, *arg): + # this method may be called from other threads (prefetch) + self._lock.acquire() + try: + msg = Message() + msg.add_int(self.request_number) + for item in arg: + if type(item) is int: + msg.add_int(item) + elif type(item) is long: + msg.add_int64(item) + elif type(item) is str: + msg.add_string(item) + elif type(item) is SFTPAttributes: + item._pack(msg) + else: + raise Exception('unknown type for %r type %r' % (item, type(item))) + num = self.request_number + self._expecting[num] = fileobj + self._send_packet(t, str(msg)) + self.request_number += 1 + finally: + self._lock.release() + return num + + def _read_response(self, waitfor=None): + while True: + try: + t, data = self._read_packet() + except EOFError, e: + raise SSHException('Server connection dropped: %s' % (str(e),)) + msg = Message(data) + num = msg.get_int() + if num not in self._expecting: + # might be response for a file that was closed before responses came back + self._log(DEBUG, 'Unexpected response #%d' % (num,)) + if waitfor is None: + # just doing a single check + break + continue + fileobj = self._expecting[num] + del self._expecting[num] + if num == waitfor: + # synchronous + if t == CMD_STATUS: + self._convert_status(msg) + return t, msg + if fileobj is not type(None): + fileobj._async_response(t, msg) + if waitfor is None: + # just doing a single check + break + return (None, None) + + def _finish_responses(self, fileobj): + while fileobj in self._expecting.values(): + self._read_response() + fileobj._check_exception() + + def _convert_status(self, msg): + """ + Raises EOFError or IOError on error status; otherwise does nothing. + """ + code = msg.get_int() + text = msg.get_string() + if code == SFTP_OK: + return + elif code == SFTP_EOF: + raise EOFError(text) + elif code == SFTP_NO_SUCH_FILE: + # clever idea from john a. meinel: map the error codes to errno + raise IOError(errno.ENOENT, text) + elif code == SFTP_PERMISSION_DENIED: + raise IOError(errno.EACCES, text) + else: + raise IOError(text) + + def _adjust_cwd(self, path): + """ + Return an adjusted path if we're emulating a "current working + directory" for the server. + """ + if type(path) is unicode: + path = path.encode('utf-8') + if self._cwd is None: + return path + if (len(path) > 0) and (path[0] == '/'): + # absolute path + return path + if self._cwd == '/': + return self._cwd + path + return self._cwd + '/' + path + + +class SFTP (SFTPClient): + "an alias for L{SFTPClient} for backwards compatability" + pass diff --git a/tools/migration/paramiko/sftp_file.py b/tools/migration/paramiko/sftp_file.py new file mode 100644 index 00000000000..8c5c7aca5eb --- /dev/null +++ b/tools/migration/paramiko/sftp_file.py @@ -0,0 +1,476 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{SFTPFile} +""" + +from binascii import hexlify +import socket +import threading +import time + +from paramiko.common import * +from paramiko.sftp import * +from paramiko.file import BufferedFile +from paramiko.sftp_attr import SFTPAttributes + + +class SFTPFile (BufferedFile): + """ + Proxy object for a file on the remote server, in client mode SFTP. + """ + + # Some sftp servers will choke if you send read/write requests larger than + # this size. + MAX_REQUEST_SIZE = 32768 + + def __init__(self, sftp, handle, mode='r', bufsize=-1): + BufferedFile.__init__(self) + self.sftp = sftp + self.handle = handle + BufferedFile._set_mode(self, mode, bufsize) + self.pipelined = False + self._prefetching = False + self._prefetch_done = False + self._prefetch_data = {} + self._prefetch_reads = [] + self._saved_exception = None + + def __del__(self): + self._close(async=True) + + def close(self): + self._close(async=False) + + def _close(self, async=False): + # We allow double-close without signaling an error, because real + # Python file objects do. However, we must protect against actually + # sending multiple CMD_CLOSE packets, because after we close our + # handle, the same handle may be re-allocated by the server, and we + # may end up mysteriously closing some random other file. (This is + # especially important because we unconditionally call close() from + # __del__.) + if self._closed: + return + self.sftp._log(DEBUG, 'close(%s)' % hexlify(self.handle)) + if self.pipelined: + self.sftp._finish_responses(self) + BufferedFile.close(self) + try: + if async: + # GC'd file handle could be called from an arbitrary thread -- don't wait for a response + self.sftp._async_request(type(None), CMD_CLOSE, self.handle) + else: + self.sftp._request(CMD_CLOSE, self.handle) + except EOFError: + # may have outlived the Transport connection + pass + except (IOError, socket.error): + # may have outlived the Transport connection + pass + + def _data_in_prefetch_requests(self, offset, size): + k = [i for i in self._prefetch_reads if i[0] <= offset] + if len(k) == 0: + return False + k.sort(lambda x, y: cmp(x[0], y[0])) + buf_offset, buf_size = k[-1] + if buf_offset + buf_size <= offset: + # prefetch request ends before this one begins + return False + if buf_offset + buf_size >= offset + size: + # inclusive + return True + # well, we have part of the request. see if another chunk has the rest. + return self._data_in_prefetch_requests(buf_offset + buf_size, offset + size - buf_offset - buf_size) + + def _data_in_prefetch_buffers(self, offset): + """ + if a block of data is present in the prefetch buffers, at the given + offset, return the offset of the relevant prefetch buffer. otherwise, + return None. this guarantees nothing about the number of bytes + collected in the prefetch buffer so far. + """ + k = [i for i in self._prefetch_data.keys() if i <= offset] + if len(k) == 0: + return None + index = max(k) + buf_offset = offset - index + if buf_offset >= len(self._prefetch_data[index]): + # it's not here + return None + return index + + def _read_prefetch(self, size): + """ + read data out of the prefetch buffer, if possible. if the data isn't + in the buffer, return None. otherwise, behaves like a normal read. + """ + # while not closed, and haven't fetched past the current position, and haven't reached EOF... + while True: + offset = self._data_in_prefetch_buffers(self._realpos) + if offset is not None: + break + if self._prefetch_done or self._closed: + break + self.sftp._read_response() + self._check_exception() + if offset is None: + self._prefetching = False + return None + prefetch = self._prefetch_data[offset] + del self._prefetch_data[offset] + + buf_offset = self._realpos - offset + if buf_offset > 0: + self._prefetch_data[offset] = prefetch[:buf_offset] + prefetch = prefetch[buf_offset:] + if size < len(prefetch): + self._prefetch_data[self._realpos + size] = prefetch[size:] + prefetch = prefetch[:size] + return prefetch + + def _read(self, size): + size = min(size, self.MAX_REQUEST_SIZE) + if self._prefetching: + data = self._read_prefetch(size) + if data is not None: + return data + t, msg = self.sftp._request(CMD_READ, self.handle, long(self._realpos), int(size)) + if t != CMD_DATA: + raise SFTPError('Expected data') + return msg.get_string() + + def _write(self, data): + # may write less than requested if it would exceed max packet size + chunk = min(len(data), self.MAX_REQUEST_SIZE) + req = self.sftp._async_request(type(None), CMD_WRITE, self.handle, long(self._realpos), str(data[:chunk])) + if not self.pipelined or self.sftp.sock.recv_ready(): + t, msg = self.sftp._read_response(req) + if t != CMD_STATUS: + raise SFTPError('Expected status') + # convert_status already called + return chunk + + def settimeout(self, timeout): + """ + Set a timeout on read/write operations on the underlying socket or + ssh L{Channel}. + + @see: L{Channel.settimeout} + @param timeout: seconds to wait for a pending read/write operation + before raising C{socket.timeout}, or C{None} for no timeout + @type timeout: float + """ + self.sftp.sock.settimeout(timeout) + + def gettimeout(self): + """ + Returns the timeout in seconds (as a float) associated with the socket + or ssh L{Channel} used for this file. + + @see: L{Channel.gettimeout} + @rtype: float + """ + return self.sftp.sock.gettimeout() + + def setblocking(self, blocking): + """ + Set blocking or non-blocking mode on the underiying socket or ssh + L{Channel}. + + @see: L{Channel.setblocking} + @param blocking: 0 to set non-blocking mode; non-0 to set blocking + mode. + @type blocking: int + """ + self.sftp.sock.setblocking(blocking) + + def seek(self, offset, whence=0): + self.flush() + if whence == self.SEEK_SET: + self._realpos = self._pos = offset + elif whence == self.SEEK_CUR: + self._pos += offset + self._realpos = self._pos + else: + self._realpos = self._pos = self._get_size() + offset + self._rbuffer = '' + + def stat(self): + """ + Retrieve information about this file from the remote system. This is + exactly like L{SFTP.stat}, except that it operates on an already-open + file. + + @return: an object containing attributes about this file. + @rtype: SFTPAttributes + """ + t, msg = self.sftp._request(CMD_FSTAT, self.handle) + if t != CMD_ATTRS: + raise SFTPError('Expected attributes') + return SFTPAttributes._from_msg(msg) + + def chmod(self, mode): + """ + Change the mode (permissions) of this file. The permissions are + unix-style and identical to those used by python's C{os.chmod} + function. + + @param mode: new permissions + @type mode: int + """ + self.sftp._log(DEBUG, 'chmod(%s, %r)' % (hexlify(self.handle), mode)) + attr = SFTPAttributes() + attr.st_mode = mode + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def chown(self, uid, gid): + """ + Change the owner (C{uid}) and group (C{gid}) of this file. As with + python's C{os.chown} function, you must pass both arguments, so if you + only want to change one, use L{stat} first to retrieve the current + owner and group. + + @param uid: new owner's uid + @type uid: int + @param gid: new group id + @type gid: int + """ + self.sftp._log(DEBUG, 'chown(%s, %r, %r)' % (hexlify(self.handle), uid, gid)) + attr = SFTPAttributes() + attr.st_uid, attr.st_gid = uid, gid + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def utime(self, times): + """ + Set the access and modified times of this file. If + C{times} is C{None}, then the file's access and modified times are set + to the current time. Otherwise, C{times} must be a 2-tuple of numbers, + of the form C{(atime, mtime)}, which is used to set the access and + modified times, respectively. This bizarre API is mimicked from python + for the sake of consistency -- I apologize. + + @param times: C{None} or a tuple of (access time, modified time) in + standard internet epoch time (seconds since 01 January 1970 GMT) + @type times: tuple(int) + """ + if times is None: + times = (time.time(), time.time()) + self.sftp._log(DEBUG, 'utime(%s, %r)' % (hexlify(self.handle), times)) + attr = SFTPAttributes() + attr.st_atime, attr.st_mtime = times + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def truncate(self, size): + """ + Change the size of this file. This usually extends + or shrinks the size of the file, just like the C{truncate()} method on + python file objects. + + @param size: the new size of the file + @type size: int or long + """ + self.sftp._log(DEBUG, 'truncate(%s, %r)' % (hexlify(self.handle), size)) + attr = SFTPAttributes() + attr.st_size = size + self.sftp._request(CMD_FSETSTAT, self.handle, attr) + + def check(self, hash_algorithm, offset=0, length=0, block_size=0): + """ + Ask the server for a hash of a section of this file. This can be used + to verify a successful upload or download, or for various rsync-like + operations. + + The file is hashed from C{offset}, for C{length} bytes. If C{length} + is 0, the remainder of the file is hashed. Thus, if both C{offset} + and C{length} are zero, the entire file is hashed. + + Normally, C{block_size} will be 0 (the default), and this method will + return a byte string representing the requested hash (for example, a + string of length 16 for MD5, or 20 for SHA-1). If a non-zero + C{block_size} is given, each chunk of the file (from C{offset} to + C{offset + length}) of C{block_size} bytes is computed as a separate + hash. The hash results are all concatenated and returned as a single + string. + + For example, C{check('sha1', 0, 1024, 512)} will return a string of + length 40. The first 20 bytes will be the SHA-1 of the first 512 bytes + of the file, and the last 20 bytes will be the SHA-1 of the next 512 + bytes. + + @param hash_algorithm: the name of the hash algorithm to use (normally + C{"sha1"} or C{"md5"}) + @type hash_algorithm: str + @param offset: offset into the file to begin hashing (0 means to start + from the beginning) + @type offset: int or long + @param length: number of bytes to hash (0 means continue to the end of + the file) + @type length: int or long + @param block_size: number of bytes to hash per result (must not be less + than 256; 0 means to compute only one hash of the entire segment) + @type block_size: int + @return: string of bytes representing the hash of each block, + concatenated together + @rtype: str + + @note: Many (most?) servers don't support this extension yet. + + @raise IOError: if the server doesn't support the "check-file" + extension, or possibly doesn't support the hash algorithm + requested + + @since: 1.4 + """ + t, msg = self.sftp._request(CMD_EXTENDED, 'check-file', self.handle, + hash_algorithm, long(offset), long(length), block_size) + ext = msg.get_string() + alg = msg.get_string() + data = msg.get_remainder() + return data + + def set_pipelined(self, pipelined=True): + """ + Turn on/off the pipelining of write operations to this file. When + pipelining is on, paramiko won't wait for the server response after + each write operation. Instead, they're collected as they come in. + At the first non-write operation (including L{close}), all remaining + server responses are collected. This means that if there was an error + with one of your later writes, an exception might be thrown from + within L{close} instead of L{write}. + + By default, files are I{not} pipelined. + + @param pipelined: C{True} if pipelining should be turned on for this + file; C{False} otherwise + @type pipelined: bool + + @since: 1.5 + """ + self.pipelined = pipelined + + def prefetch(self): + """ + Pre-fetch the remaining contents of this file in anticipation of + future L{read} calls. If reading the entire file, pre-fetching can + dramatically improve the download speed by avoiding roundtrip latency. + The file's contents are incrementally buffered in a background thread. + + The prefetched data is stored in a buffer until read via the L{read} + method. Once data has been read, it's removed from the buffer. The + data may be read in a random order (using L{seek}); chunks of the + buffer that haven't been read will continue to be buffered. + + @since: 1.5.1 + """ + size = self.stat().st_size + # queue up async reads for the rest of the file + chunks = [] + n = self._realpos + while n < size: + chunk = min(self.MAX_REQUEST_SIZE, size - n) + chunks.append((n, chunk)) + n += chunk + if len(chunks) > 0: + self._start_prefetch(chunks) + + def readv(self, chunks): + """ + Read a set of blocks from the file by (offset, length). This is more + efficient than doing a series of L{seek} and L{read} calls, since the + prefetch machinery is used to retrieve all the requested blocks at + once. + + @param chunks: a list of (offset, length) tuples indicating which + sections of the file to read + @type chunks: list(tuple(long, int)) + @return: a list of blocks read, in the same order as in C{chunks} + @rtype: list(str) + + @since: 1.5.4 + """ + self.sftp._log(DEBUG, 'readv(%s, %r)' % (hexlify(self.handle), chunks)) + + read_chunks = [] + for offset, size in chunks: + # don't fetch data that's already in the prefetch buffer + if self._data_in_prefetch_buffers(offset) or self._data_in_prefetch_requests(offset, size): + continue + + # break up anything larger than the max read size + while size > 0: + chunk_size = min(size, self.MAX_REQUEST_SIZE) + read_chunks.append((offset, chunk_size)) + offset += chunk_size + size -= chunk_size + + self._start_prefetch(read_chunks) + # now we can just devolve to a bunch of read()s :) + for x in chunks: + self.seek(x[0]) + yield self.read(x[1]) + + + ### internals... + + + def _get_size(self): + try: + return self.stat().st_size + except: + return 0 + + def _start_prefetch(self, chunks): + self._prefetching = True + self._prefetch_done = False + self._prefetch_reads.extend(chunks) + + t = threading.Thread(target=self._prefetch_thread, args=(chunks,)) + t.setDaemon(True) + t.start() + + def _prefetch_thread(self, chunks): + # do these read requests in a temporary thread because there may be + # a lot of them, so it may block. + for offset, length in chunks: + self.sftp._async_request(self, CMD_READ, self.handle, long(offset), int(length)) + + def _async_response(self, t, msg): + if t == CMD_STATUS: + # save exception and re-raise it on next file operation + try: + self.sftp._convert_status(msg) + except Exception, x: + self._saved_exception = x + return + if t != CMD_DATA: + raise SFTPError('Expected data') + data = msg.get_string() + offset, length = self._prefetch_reads.pop(0) + self._prefetch_data[offset] = data + if len(self._prefetch_reads) == 0: + self._prefetch_done = True + + def _check_exception(self): + "if there's a saved exception, raise & clear it" + if self._saved_exception is not None: + x = self._saved_exception + self._saved_exception = None + raise x diff --git a/tools/migration/paramiko/sftp_handle.py b/tools/migration/paramiko/sftp_handle.py new file mode 100644 index 00000000000..a6cd44a82ef --- /dev/null +++ b/tools/migration/paramiko/sftp_handle.py @@ -0,0 +1,202 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Abstraction of an SFTP file handle (for server mode). +""" + +import os + +from paramiko.common import * +from paramiko.sftp import * + + +class SFTPHandle (object): + """ + Abstract object representing a handle to an open file (or folder) in an + SFTP server implementation. Each handle has a string representation used + by the client to refer to the underlying file. + + Server implementations can (and should) subclass SFTPHandle to implement + features of a file handle, like L{stat} or L{chattr}. + """ + def __init__(self, flags=0): + """ + Create a new file handle representing a local file being served over + SFTP. If C{flags} is passed in, it's used to determine if the file + is open in append mode. + + @param flags: optional flags as passed to L{SFTPServerInterface.open} + @type flags: int + """ + self.__flags = flags + self.__name = None + # only for handles to folders: + self.__files = { } + self.__tell = None + + def close(self): + """ + When a client closes a file, this method is called on the handle. + Normally you would use this method to close the underlying OS level + file object(s). + + The default implementation checks for attributes on C{self} named + C{readfile} and/or C{writefile}, and if either or both are present, + their C{close()} methods are called. This means that if you are + using the default implementations of L{read} and L{write}, this + method's default implementation should be fine also. + """ + readfile = getattr(self, 'readfile', None) + if readfile is not None: + readfile.close() + writefile = getattr(self, 'writefile', None) + if writefile is not None: + writefile.close() + + def read(self, offset, length): + """ + Read up to C{length} bytes from this file, starting at position + C{offset}. The offset may be a python long, since SFTP allows it + to be 64 bits. + + If the end of the file has been reached, this method may return an + empty string to signify EOF, or it may also return L{SFTP_EOF}. + + The default implementation checks for an attribute on C{self} named + C{readfile}, and if present, performs the read operation on the python + file-like object found there. (This is meant as a time saver for the + common case where you are wrapping a python file object.) + + @param offset: position in the file to start reading from. + @type offset: int or long + @param length: number of bytes to attempt to read. + @type length: int + @return: data read from the file, or an SFTP error code. + @rtype: str + """ + readfile = getattr(self, 'readfile', None) + if readfile is None: + return SFTP_OP_UNSUPPORTED + try: + if self.__tell is None: + self.__tell = readfile.tell() + if offset != self.__tell: + readfile.seek(offset) + self.__tell = offset + data = readfile.read(length) + except IOError, e: + self.__tell = None + return SFTPServer.convert_errno(e.errno) + self.__tell += len(data) + return data + + def write(self, offset, data): + """ + Write C{data} into this file at position C{offset}. Extending the + file past its original end is expected. Unlike python's normal + C{write()} methods, this method cannot do a partial write: it must + write all of C{data} or else return an error. + + The default implementation checks for an attribute on C{self} named + C{writefile}, and if present, performs the write operation on the + python file-like object found there. The attribute is named + differently from C{readfile} to make it easy to implement read-only + (or write-only) files, but if both attributes are present, they should + refer to the same file. + + @param offset: position in the file to start reading from. + @type offset: int or long + @param data: data to write into the file. + @type data: str + @return: an SFTP error code like L{SFTP_OK}. + """ + writefile = getattr(self, 'writefile', None) + if writefile is None: + return SFTP_OP_UNSUPPORTED + try: + # in append mode, don't care about seeking + if (self.__flags & os.O_APPEND) == 0: + if self.__tell is None: + self.__tell = writefile.tell() + if offset != self.__tell: + writefile.seek(offset) + self.__tell = offset + writefile.write(data) + writefile.flush() + except IOError, e: + self.__tell = None + return SFTPServer.convert_errno(e.errno) + if self.__tell is not None: + self.__tell += len(data) + return SFTP_OK + + def stat(self): + """ + Return an L{SFTPAttributes} object referring to this open file, or an + error code. This is equivalent to L{SFTPServerInterface.stat}, except + it's called on an open file instead of a path. + + @return: an attributes object for the given file, or an SFTP error + code (like L{SFTP_PERMISSION_DENIED}). + @rtype: L{SFTPAttributes} I{or error code} + """ + return SFTP_OP_UNSUPPORTED + + def chattr(self, attr): + """ + Change the attributes of this file. The C{attr} object will contain + only those fields provided by the client in its request, so you should + check for the presence of fields before using them. + + @param attr: the attributes to change on this file. + @type attr: L{SFTPAttributes} + @return: an error code like L{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED + + + ### internals... + + + def _set_files(self, files): + """ + Used by the SFTP server code to cache a directory listing. (In + the SFTP protocol, listing a directory is a multi-stage process + requiring a temporary handle.) + """ + self.__files = files + + def _get_next_files(self): + """ + Used by the SFTP server code to retreive a cached directory + listing. + """ + fnlist = self.__files[:16] + self.__files = self.__files[16:] + return fnlist + + def _get_name(self): + return self.__name + + def _set_name(self, name): + self.__name = name + + +from paramiko.sftp_server import SFTPServer diff --git a/tools/migration/paramiko/sftp_server.py b/tools/migration/paramiko/sftp_server.py new file mode 100644 index 00000000000..7cc6c0c35dd --- /dev/null +++ b/tools/migration/paramiko/sftp_server.py @@ -0,0 +1,444 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Server-mode SFTP support. +""" + +import os +import errno + +from Crypto.Hash import MD5, SHA +from paramiko.common import * +from paramiko.server import SubsystemHandler +from paramiko.sftp import * +from paramiko.sftp_si import * +from paramiko.sftp_attr import * + + +# known hash algorithms for the "check-file" extension +_hash_class = { + 'sha1': SHA, + 'md5': MD5, +} + + +class SFTPServer (BaseSFTP, SubsystemHandler): + """ + Server-side SFTP subsystem support. Since this is a L{SubsystemHandler}, + it can be (and is meant to be) set as the handler for C{"sftp"} requests. + Use L{Transport.set_subsystem_handler} to activate this class. + """ + + def __init__(self, channel, name, server, sftp_si=SFTPServerInterface, *largs, **kwargs): + """ + The constructor for SFTPServer is meant to be called from within the + L{Transport} as a subsystem handler. C{server} and any additional + parameters or keyword parameters are passed from the original call to + L{Transport.set_subsystem_handler}. + + @param channel: channel passed from the L{Transport}. + @type channel: L{Channel} + @param name: name of the requested subsystem. + @type name: str + @param server: the server object associated with this channel and + subsystem + @type server: L{ServerInterface} + @param sftp_si: a subclass of L{SFTPServerInterface} to use for handling + individual requests. + @type sftp_si: class + """ + BaseSFTP.__init__(self) + SubsystemHandler.__init__(self, channel, name, server) + transport = channel.get_transport() + self.logger = util.get_logger(transport.get_log_channel() + '.sftp') + self.ultra_debug = transport.get_hexdump() + self.next_handle = 1 + # map of handle-string to SFTPHandle for files & folders: + self.file_table = { } + self.folder_table = { } + self.server = sftp_si(server, *largs, **kwargs) + + def _log(self, level, msg): + if issubclass(type(msg), list): + for m in msg: + super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + m) + else: + super(SFTPServer, self)._log(level, "[chan " + self.sock.get_name() + "] " + msg) + + def start_subsystem(self, name, transport, channel): + self.sock = channel + self._log(DEBUG, 'Started sftp server on channel %s' % repr(channel)) + self._send_server_version() + self.server.session_started() + while True: + try: + t, data = self._read_packet() + except EOFError: + self._log(DEBUG, 'EOF -- end of session') + return + except Exception, e: + self._log(DEBUG, 'Exception on channel: ' + str(e)) + self._log(DEBUG, util.tb_strings()) + return + msg = Message(data) + request_number = msg.get_int() + try: + self._process(t, request_number, msg) + except Exception, e: + self._log(DEBUG, 'Exception in server processing: ' + str(e)) + self._log(DEBUG, util.tb_strings()) + # send some kind of failure message, at least + try: + self._send_status(request_number, SFTP_FAILURE) + except: + pass + + def finish_subsystem(self): + self.server.session_ended() + super(SFTPServer, self).finish_subsystem() + # close any file handles that were left open (so we can return them to the OS quickly) + for f in self.file_table.itervalues(): + f.close() + for f in self.folder_table.itervalues(): + f.close() + self.file_table = {} + self.folder_table = {} + + def convert_errno(e): + """ + Convert an errno value (as from an C{OSError} or C{IOError}) into a + standard SFTP result code. This is a convenience function for trapping + exceptions in server code and returning an appropriate result. + + @param e: an errno code, as from C{OSError.errno}. + @type e: int + @return: an SFTP error code like L{SFTP_NO_SUCH_FILE}. + @rtype: int + """ + if e == errno.EACCES: + # permission denied + return SFTP_PERMISSION_DENIED + elif (e == errno.ENOENT) or (e == errno.ENOTDIR): + # no such file + return SFTP_NO_SUCH_FILE + else: + return SFTP_FAILURE + convert_errno = staticmethod(convert_errno) + + def set_file_attr(filename, attr): + """ + Change a file's attributes on the local filesystem. The contents of + C{attr} are used to change the permissions, owner, group ownership, + and/or modification & access time of the file, depending on which + attributes are present in C{attr}. + + This is meant to be a handy helper function for translating SFTP file + requests into local file operations. + + @param filename: name of the file to alter (should usually be an + absolute path). + @type filename: str + @param attr: attributes to change. + @type attr: L{SFTPAttributes} + """ + if sys.platform != 'win32': + # mode operations are meaningless on win32 + if attr._flags & attr.FLAG_PERMISSIONS: + os.chmod(filename, attr.st_mode) + if attr._flags & attr.FLAG_UIDGID: + os.chown(filename, attr.st_uid, attr.st_gid) + if attr._flags & attr.FLAG_AMTIME: + os.utime(filename, (attr.st_atime, attr.st_mtime)) + if attr._flags & attr.FLAG_SIZE: + open(filename, 'w+').truncate(attr.st_size) + set_file_attr = staticmethod(set_file_attr) + + + ### internals... + + + def _response(self, request_number, t, *arg): + msg = Message() + msg.add_int(request_number) + for item in arg: + if type(item) is int: + msg.add_int(item) + elif type(item) is long: + msg.add_int64(item) + elif type(item) is str: + msg.add_string(item) + elif type(item) is SFTPAttributes: + item._pack(msg) + else: + raise Exception('unknown type for ' + repr(item) + ' type ' + repr(type(item))) + self._send_packet(t, str(msg)) + + def _send_handle_response(self, request_number, handle, folder=False): + if not issubclass(type(handle), SFTPHandle): + # must be error code + self._send_status(request_number, handle) + return + handle._set_name('hx%d' % self.next_handle) + self.next_handle += 1 + if folder: + self.folder_table[handle._get_name()] = handle + else: + self.file_table[handle._get_name()] = handle + self._response(request_number, CMD_HANDLE, handle._get_name()) + + def _send_status(self, request_number, code, desc=None): + if desc is None: + try: + desc = SFTP_DESC[code] + except IndexError: + desc = 'Unknown' + # some clients expect a "langauge" tag at the end (but don't mind it being blank) + self._response(request_number, CMD_STATUS, code, desc, '') + + def _open_folder(self, request_number, path): + resp = self.server.list_folder(path) + if issubclass(type(resp), list): + # got an actual list of filenames in the folder + folder = SFTPHandle() + folder._set_files(resp) + self._send_handle_response(request_number, folder, True) + return + # must be an error code + self._send_status(request_number, resp) + + def _read_folder(self, request_number, folder): + flist = folder._get_next_files() + if len(flist) == 0: + self._send_status(request_number, SFTP_EOF) + return + msg = Message() + msg.add_int(request_number) + msg.add_int(len(flist)) + for attr in flist: + msg.add_string(attr.filename) + msg.add_string(str(attr)) + attr._pack(msg) + self._send_packet(CMD_NAME, str(msg)) + + def _check_file(self, request_number, msg): + # this extension actually comes from v6 protocol, but since it's an + # extension, i feel like we can reasonably support it backported. + # it's very useful for verifying uploaded files or checking for + # rsync-like differences between local and remote files. + handle = msg.get_string() + alg_list = msg.get_list() + start = msg.get_int64() + length = msg.get_int64() + block_size = msg.get_int() + if handle not in self.file_table: + self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + return + f = self.file_table[handle] + for x in alg_list: + if x in _hash_class: + algname = x + alg = _hash_class[x] + break + else: + self._send_status(request_number, SFTP_FAILURE, 'No supported hash types found') + return + if length == 0: + st = f.stat() + if not issubclass(type(st), SFTPAttributes): + self._send_status(request_number, st, 'Unable to stat file') + return + length = st.st_size - start + if block_size == 0: + block_size = length + if block_size < 256: + self._send_status(request_number, SFTP_FAILURE, 'Block size too small') + return + + sum_out = '' + offset = start + while offset < start + length: + blocklen = min(block_size, start + length - offset) + # don't try to read more than about 64KB at a time + chunklen = min(blocklen, 65536) + count = 0 + hash_obj = alg.new() + while count < blocklen: + data = f.read(offset, chunklen) + if not type(data) is str: + self._send_status(request_number, data, 'Unable to hash file') + return + hash_obj.update(data) + count += len(data) + offset += count + sum_out += hash_obj.digest() + + msg = Message() + msg.add_int(request_number) + msg.add_string('check-file') + msg.add_string(algname) + msg.add_bytes(sum_out) + self._send_packet(CMD_EXTENDED_REPLY, str(msg)) + + def _convert_pflags(self, pflags): + "convert SFTP-style open() flags to python's os.open() flags" + if (pflags & SFTP_FLAG_READ) and (pflags & SFTP_FLAG_WRITE): + flags = os.O_RDWR + elif pflags & SFTP_FLAG_WRITE: + flags = os.O_WRONLY + else: + flags = os.O_RDONLY + if pflags & SFTP_FLAG_APPEND: + flags |= os.O_APPEND + if pflags & SFTP_FLAG_CREATE: + flags |= os.O_CREAT + if pflags & SFTP_FLAG_TRUNC: + flags |= os.O_TRUNC + if pflags & SFTP_FLAG_EXCL: + flags |= os.O_EXCL + return flags + + def _process(self, t, request_number, msg): + self._log(DEBUG, 'Request: %s' % CMD_NAMES[t]) + if t == CMD_OPEN: + path = msg.get_string() + flags = self._convert_pflags(msg.get_int()) + attr = SFTPAttributes._from_msg(msg) + self._send_handle_response(request_number, self.server.open(path, flags, attr)) + elif t == CMD_CLOSE: + handle = msg.get_string() + if handle in self.folder_table: + del self.folder_table[handle] + self._send_status(request_number, SFTP_OK) + return + if handle in self.file_table: + self.file_table[handle].close() + del self.file_table[handle] + self._send_status(request_number, SFTP_OK) + return + self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + elif t == CMD_READ: + handle = msg.get_string() + offset = msg.get_int64() + length = msg.get_int() + if handle not in self.file_table: + self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + return + data = self.file_table[handle].read(offset, length) + if type(data) is str: + if len(data) == 0: + self._send_status(request_number, SFTP_EOF) + else: + self._response(request_number, CMD_DATA, data) + else: + self._send_status(request_number, data) + elif t == CMD_WRITE: + handle = msg.get_string() + offset = msg.get_int64() + data = msg.get_string() + if handle not in self.file_table: + self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + return + self._send_status(request_number, self.file_table[handle].write(offset, data)) + elif t == CMD_REMOVE: + path = msg.get_string() + self._send_status(request_number, self.server.remove(path)) + elif t == CMD_RENAME: + oldpath = msg.get_string() + newpath = msg.get_string() + self._send_status(request_number, self.server.rename(oldpath, newpath)) + elif t == CMD_MKDIR: + path = msg.get_string() + attr = SFTPAttributes._from_msg(msg) + self._send_status(request_number, self.server.mkdir(path, attr)) + elif t == CMD_RMDIR: + path = msg.get_string() + self._send_status(request_number, self.server.rmdir(path)) + elif t == CMD_OPENDIR: + path = msg.get_string() + self._open_folder(request_number, path) + return + elif t == CMD_READDIR: + handle = msg.get_string() + if handle not in self.folder_table: + self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + return + folder = self.folder_table[handle] + self._read_folder(request_number, folder) + elif t == CMD_STAT: + path = msg.get_string() + resp = self.server.stat(path) + if issubclass(type(resp), SFTPAttributes): + self._response(request_number, CMD_ATTRS, resp) + else: + self._send_status(request_number, resp) + elif t == CMD_LSTAT: + path = msg.get_string() + resp = self.server.lstat(path) + if issubclass(type(resp), SFTPAttributes): + self._response(request_number, CMD_ATTRS, resp) + else: + self._send_status(request_number, resp) + elif t == CMD_FSTAT: + handle = msg.get_string() + if handle not in self.file_table: + self._send_status(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + return + resp = self.file_table[handle].stat() + if issubclass(type(resp), SFTPAttributes): + self._response(request_number, CMD_ATTRS, resp) + else: + self._send_status(request_number, resp) + elif t == CMD_SETSTAT: + path = msg.get_string() + attr = SFTPAttributes._from_msg(msg) + self._send_status(request_number, self.server.chattr(path, attr)) + elif t == CMD_FSETSTAT: + handle = msg.get_string() + attr = SFTPAttributes._from_msg(msg) + if handle not in self.file_table: + self._response(request_number, SFTP_BAD_MESSAGE, 'Invalid handle') + return + self._send_status(request_number, self.file_table[handle].chattr(attr)) + elif t == CMD_READLINK: + path = msg.get_string() + resp = self.server.readlink(path) + if type(resp) is str: + self._response(request_number, CMD_NAME, 1, resp, '', SFTPAttributes()) + else: + self._send_status(request_number, resp) + elif t == CMD_SYMLINK: + # the sftp 2 draft is incorrect here! path always follows target_path + target_path = msg.get_string() + path = msg.get_string() + self._send_status(request_number, self.server.symlink(target_path, path)) + elif t == CMD_REALPATH: + path = msg.get_string() + rpath = self.server.canonicalize(path) + self._response(request_number, CMD_NAME, 1, rpath, '', SFTPAttributes()) + elif t == CMD_EXTENDED: + tag = msg.get_string() + if tag == 'check-file': + self._check_file(request_number, msg) + else: + self._send_status(request_number, SFTP_OP_UNSUPPORTED) + else: + self._send_status(request_number, SFTP_OP_UNSUPPORTED) + + +from paramiko.sftp_handle import SFTPHandle diff --git a/tools/migration/paramiko/sftp_si.py b/tools/migration/paramiko/sftp_si.py new file mode 100644 index 00000000000..401a4e996e3 --- /dev/null +++ b/tools/migration/paramiko/sftp_si.py @@ -0,0 +1,310 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{SFTPServerInterface} is an interface to override for SFTP server support. +""" + +import os + +from paramiko.common import * +from paramiko.sftp import * + + +class SFTPServerInterface (object): + """ + This class defines an interface for controlling the behavior of paramiko + when using the L{SFTPServer} subsystem to provide an SFTP server. + + Methods on this class are called from the SFTP session's thread, so you can + block as long as necessary without affecting other sessions (even other + SFTP sessions). However, raising an exception will usually cause the SFTP + session to abruptly end, so you will usually want to catch exceptions and + return an appropriate error code. + + All paths are in string form instead of unicode because not all SFTP + clients & servers obey the requirement that paths be encoded in UTF-8. + """ + + def __init__ (self, server, *largs, **kwargs): + """ + Create a new SFTPServerInterface object. This method does nothing by + default and is meant to be overridden by subclasses. + + @param server: the server object associated with this channel and + SFTP subsystem + @type server: L{ServerInterface} + """ + super(SFTPServerInterface, self).__init__(*largs, **kwargs) + + def session_started(self): + """ + The SFTP server session has just started. This method is meant to be + overridden to perform any necessary setup before handling callbacks + from SFTP operations. + """ + pass + + def session_ended(self): + """ + The SFTP server session has just ended, either cleanly or via an + exception. This method is meant to be overridden to perform any + necessary cleanup before this C{SFTPServerInterface} object is + destroyed. + """ + pass + + def open(self, path, flags, attr): + """ + Open a file on the server and create a handle for future operations + on that file. On success, a new object subclassed from L{SFTPHandle} + should be returned. This handle will be used for future operations + on the file (read, write, etc). On failure, an error code such as + L{SFTP_PERMISSION_DENIED} should be returned. + + C{flags} contains the requested mode for opening (read-only, + write-append, etc) as a bitset of flags from the C{os} module: + - C{os.O_RDONLY} + - C{os.O_WRONLY} + - C{os.O_RDWR} + - C{os.O_APPEND} + - C{os.O_CREAT} + - C{os.O_TRUNC} + - C{os.O_EXCL} + (One of C{os.O_RDONLY}, C{os.O_WRONLY}, or C{os.O_RDWR} will always + be set.) + + The C{attr} object contains requested attributes of the file if it + has to be created. Some or all attribute fields may be missing if + the client didn't specify them. + + @note: The SFTP protocol defines all files to be in "binary" mode. + There is no equivalent to python's "text" mode. + + @param path: the requested path (relative or absolute) of the file + to be opened. + @type path: str + @param flags: flags or'd together from the C{os} module indicating the + requested mode for opening the file. + @type flags: int + @param attr: requested attributes of the file if it is newly created. + @type attr: L{SFTPAttributes} + @return: a new L{SFTPHandle} I{or error code}. + @rtype L{SFTPHandle} + """ + return SFTP_OP_UNSUPPORTED + + def list_folder(self, path): + """ + Return a list of files within a given folder. The C{path} will use + posix notation (C{"/"} separates folder names) and may be an absolute + or relative path. + + The list of files is expected to be a list of L{SFTPAttributes} + objects, which are similar in structure to the objects returned by + C{os.stat}. In addition, each object should have its C{filename} + field filled in, since this is important to a directory listing and + not normally present in C{os.stat} results. The method + L{SFTPAttributes.from_stat} will usually do what you want. + + In case of an error, you should return one of the C{SFTP_*} error + codes, such as L{SFTP_PERMISSION_DENIED}. + + @param path: the requested path (relative or absolute) to be listed. + @type path: str + @return: a list of the files in the given folder, using + L{SFTPAttributes} objects. + @rtype: list of L{SFTPAttributes} I{or error code} + + @note: You should normalize the given C{path} first (see the + C{os.path} module) and check appropriate permissions before returning + the list of files. Be careful of malicious clients attempting to use + relative paths to escape restricted folders, if you're doing a direct + translation from the SFTP server path to your local filesystem. + """ + return SFTP_OP_UNSUPPORTED + + def stat(self, path): + """ + Return an L{SFTPAttributes} object for a path on the server, or an + error code. If your server supports symbolic links (also known as + "aliases"), you should follow them. (L{lstat} is the corresponding + call that doesn't follow symlinks/aliases.) + + @param path: the requested path (relative or absolute) to fetch + file statistics for. + @type path: str + @return: an attributes object for the given file, or an SFTP error + code (like L{SFTP_PERMISSION_DENIED}). + @rtype: L{SFTPAttributes} I{or error code} + """ + return SFTP_OP_UNSUPPORTED + + def lstat(self, path): + """ + Return an L{SFTPAttributes} object for a path on the server, or an + error code. If your server supports symbolic links (also known as + "aliases"), you should I{not} follow them -- instead, you should + return data on the symlink or alias itself. (L{stat} is the + corresponding call that follows symlinks/aliases.) + + @param path: the requested path (relative or absolute) to fetch + file statistics for. + @type path: str + @return: an attributes object for the given file, or an SFTP error + code (like L{SFTP_PERMISSION_DENIED}). + @rtype: L{SFTPAttributes} I{or error code} + """ + return SFTP_OP_UNSUPPORTED + + def remove(self, path): + """ + Delete a file, if possible. + + @param path: the requested path (relative or absolute) of the file + to delete. + @type path: str + @return: an SFTP error code like L{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED + + def rename(self, oldpath, newpath): + """ + Rename (or move) a file. The SFTP specification implies that this + method can be used to move an existing file into a different folder, + and since there's no other (easy) way to move files via SFTP, it's + probably a good idea to implement "move" in this method too, even for + files that cross disk partition boundaries, if at all possible. + + @note: You should return an error if a file with the same name as + C{newpath} already exists. (The rename operation should be + non-desctructive.) + + @param oldpath: the requested path (relative or absolute) of the + existing file. + @type oldpath: str + @param newpath: the requested new path of the file. + @type newpath: str + @return: an SFTP error code like L{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED + + def mkdir(self, path, attr): + """ + Create a new directory with the given attributes. The C{attr} + object may be considered a "hint" and ignored. + + The C{attr} object will contain only those fields provided by the + client in its request, so you should use C{hasattr} to check for + the presense of fields before using them. In some cases, the C{attr} + object may be completely empty. + + @param path: requested path (relative or absolute) of the new + folder. + @type path: str + @param attr: requested attributes of the new folder. + @type attr: L{SFTPAttributes} + @return: an SFTP error code like L{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED + + def rmdir(self, path): + """ + Remove a directory if it exists. The C{path} should refer to an + existing, empty folder -- otherwise this method should return an + error. + + @param path: requested path (relative or absolute) of the folder + to remove. + @type path: str + @return: an SFTP error code like L{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED + + def chattr(self, path, attr): + """ + Change the attributes of a file. The C{attr} object will contain + only those fields provided by the client in its request, so you + should check for the presence of fields before using them. + + @param path: requested path (relative or absolute) of the file to + change. + @type path: str + @param attr: requested attributes to change on the file. + @type attr: L{SFTPAttributes} + @return: an error code like L{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED + + def canonicalize(self, path): + """ + Return the canonical form of a path on the server. For example, + if the server's home folder is C{/home/foo}, the path + C{"../betty"} would be canonicalized to C{"/home/betty"}. Note + the obvious security issues: if you're serving files only from a + specific folder, you probably don't want this method to reveal path + names outside that folder. + + You may find the python methods in C{os.path} useful, especially + C{os.path.normpath} and C{os.path.realpath}. + + The default implementation returns C{os.path.normpath('/' + path)}. + """ + if os.path.isabs(path): + out = os.path.normpath(path) + else: + out = os.path.normpath('/' + path) + if sys.platform == 'win32': + # on windows, normalize backslashes to sftp/posix format + out = out.replace('\\', '/') + return out + + def readlink(self, path): + """ + Return the target of a symbolic link (or shortcut) on the server. + If the specified path doesn't refer to a symbolic link, an error + should be returned. + + @param path: path (relative or absolute) of the symbolic link. + @type path: str + @return: the target path of the symbolic link, or an error code like + L{SFTP_NO_SUCH_FILE}. + @rtype: str I{or error code} + """ + return SFTP_OP_UNSUPPORTED + + def symlink(self, target_path, path): + """ + Create a symbolic link on the server, as new pathname C{path}, + with C{target_path} as the target of the link. + + @param target_path: path (relative or absolute) of the target for + this new symbolic link. + @type target_path: str + @param path: path (relative or absolute) of the symbolic link to + create. + @type path: str + @return: an error code like C{SFTP_OK}. + @rtype: int + """ + return SFTP_OP_UNSUPPORTED diff --git a/tools/migration/paramiko/ssh_exception.py b/tools/migration/paramiko/ssh_exception.py new file mode 100644 index 00000000000..68924d0f148 --- /dev/null +++ b/tools/migration/paramiko/ssh_exception.py @@ -0,0 +1,115 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Exceptions defined by paramiko. +""" + + +class SSHException (Exception): + """ + Exception raised by failures in SSH2 protocol negotiation or logic errors. + """ + pass + + +class AuthenticationException (SSHException): + """ + Exception raised when authentication failed for some reason. It may be + possible to retry with different credentials. (Other classes specify more + specific reasons.) + + @since: 1.6 + """ + pass + + +class PasswordRequiredException (AuthenticationException): + """ + Exception raised when a password is needed to unlock a private key file. + """ + pass + + +class BadAuthenticationType (AuthenticationException): + """ + Exception raised when an authentication type (like password) is used, but + the server isn't allowing that type. (It may only allow public-key, for + example.) + + @ivar allowed_types: list of allowed authentication types provided by the + server (possible values are: C{"none"}, C{"password"}, and + C{"publickey"}). + @type allowed_types: list + + @since: 1.1 + """ + allowed_types = [] + + def __init__(self, explanation, types): + AuthenticationException.__init__(self, explanation) + self.allowed_types = types + + def __str__(self): + return SSHException.__str__(self) + ' (allowed_types=%r)' % self.allowed_types + + +class PartialAuthentication (AuthenticationException): + """ + An internal exception thrown in the case of partial authentication. + """ + allowed_types = [] + + def __init__(self, types): + AuthenticationException.__init__(self, 'partial authentication') + self.allowed_types = types + + +class ChannelException (SSHException): + """ + Exception raised when an attempt to open a new L{Channel} fails. + + @ivar code: the error code returned by the server + @type code: int + + @since: 1.6 + """ + def __init__(self, code, text): + SSHException.__init__(self, text) + self.code = code + + +class BadHostKeyException (SSHException): + """ + The host key given by the SSH server did not match what we were expecting. + + @ivar hostname: the hostname of the SSH server + @type hostname: str + @ivar key: the host key presented by the server + @type key: L{PKey} + @ivar expected_key: the host key expected + @type expected_key: L{PKey} + + @since: 1.6 + """ + def __init__(self, hostname, got_key, expected_key): + SSHException.__init__(self, 'Host key for server %s does not match!' % hostname) + self.hostname = hostname + self.key = got_key + self.expected_key = expected_key + diff --git a/tools/migration/paramiko/transport.py b/tools/migration/paramiko/transport.py new file mode 100644 index 00000000000..50e78e7dcfb --- /dev/null +++ b/tools/migration/paramiko/transport.py @@ -0,0 +1,2099 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +L{Transport} handles the core SSH2 protocol. +""" + +import os +import socket +import string +import struct +import sys +import threading +import time +import weakref + +from paramiko import util +from paramiko.auth_handler import AuthHandler +from paramiko.channel import Channel +from paramiko.common import * +from paramiko.compress import ZlibCompressor, ZlibDecompressor +from paramiko.dsskey import DSSKey +from paramiko.kex_gex import KexGex +from paramiko.kex_group1 import KexGroup1 +from paramiko.message import Message +from paramiko.packet import Packetizer, NeedRekeyException +from paramiko.primes import ModulusPack +from paramiko.rsakey import RSAKey +from paramiko.server import ServerInterface +from paramiko.sftp_client import SFTPClient +from paramiko.ssh_exception import SSHException, BadAuthenticationType, ChannelException + +# these come from PyCrypt +# http://www.amk.ca/python/writing/pycrypt/ +# i believe this on the standards track. +# PyCrypt compiled for Win32 can be downloaded from the HashTar homepage: +# http://nitace.bsd.uchicago.edu:8080/hashtar +from Crypto.Cipher import Blowfish, AES, DES3, ARC4 +from Crypto.Hash import SHA, MD5 +try: + from Crypto.Util import Counter +except ImportError: + from paramiko.util import Counter + + +# for thread cleanup +_active_threads = [] +def _join_lingering_threads(): + for thr in _active_threads: + thr.stop_thread() +import atexit +atexit.register(_join_lingering_threads) + + +class SecurityOptions (object): + """ + Simple object containing the security preferences of an ssh transport. + These are tuples of acceptable ciphers, digests, key types, and key + exchange algorithms, listed in order of preference. + + Changing the contents and/or order of these fields affects the underlying + L{Transport} (but only if you change them before starting the session). + If you try to add an algorithm that paramiko doesn't recognize, + C{ValueError} will be raised. If you try to assign something besides a + tuple to one of the fields, C{TypeError} will be raised. + """ + __slots__ = [ 'ciphers', 'digests', 'key_types', 'kex', 'compression', '_transport' ] + + def __init__(self, transport): + self._transport = transport + + def __repr__(self): + """ + Returns a string representation of this object, for debugging. + + @rtype: str + """ + return '' % repr(self._transport) + + def _get_ciphers(self): + return self._transport._preferred_ciphers + + def _get_digests(self): + return self._transport._preferred_macs + + def _get_key_types(self): + return self._transport._preferred_keys + + def _get_kex(self): + return self._transport._preferred_kex + + def _get_compression(self): + return self._transport._preferred_compression + + def _set(self, name, orig, x): + if type(x) is list: + x = tuple(x) + if type(x) is not tuple: + raise TypeError('expected tuple or list') + possible = getattr(self._transport, orig).keys() + forbidden = filter(lambda n: n not in possible, x) + if len(forbidden) > 0: + raise ValueError('unknown cipher') + setattr(self._transport, name, x) + + def _set_ciphers(self, x): + self._set('_preferred_ciphers', '_cipher_info', x) + + def _set_digests(self, x): + self._set('_preferred_macs', '_mac_info', x) + + def _set_key_types(self, x): + self._set('_preferred_keys', '_key_info', x) + + def _set_kex(self, x): + self._set('_preferred_kex', '_kex_info', x) + + def _set_compression(self, x): + self._set('_preferred_compression', '_compression_info', x) + + ciphers = property(_get_ciphers, _set_ciphers, None, + "Symmetric encryption ciphers") + digests = property(_get_digests, _set_digests, None, + "Digest (one-way hash) algorithms") + key_types = property(_get_key_types, _set_key_types, None, + "Public-key algorithms") + kex = property(_get_kex, _set_kex, None, "Key exchange algorithms") + compression = property(_get_compression, _set_compression, None, + "Compression algorithms") + + +class ChannelMap (object): + def __init__(self): + # (id -> Channel) + self._map = weakref.WeakValueDictionary() + self._lock = threading.Lock() + + def put(self, chanid, chan): + self._lock.acquire() + try: + self._map[chanid] = chan + finally: + self._lock.release() + + def get(self, chanid): + self._lock.acquire() + try: + return self._map.get(chanid, None) + finally: + self._lock.release() + + def delete(self, chanid): + self._lock.acquire() + try: + try: + del self._map[chanid] + except KeyError: + pass + finally: + self._lock.release() + + def values(self): + self._lock.acquire() + try: + return self._map.values() + finally: + self._lock.release() + + def __len__(self): + self._lock.acquire() + try: + return len(self._map) + finally: + self._lock.release() + + +class Transport (threading.Thread): + """ + An SSH Transport attaches to a stream (usually a socket), negotiates an + encrypted session, authenticates, and then creates stream tunnels, called + L{Channel}s, across the session. Multiple channels can be multiplexed + across a single session (and often are, in the case of port forwardings). + """ + + _PROTO_ID = '2.0' + _CLIENT_ID = 'paramiko_1.7.6' + + _preferred_ciphers = ( 'aes128-ctr', 'aes256-ctr', 'aes128-cbc', 'blowfish-cbc', 'aes256-cbc', '3des-cbc', + 'arcfour128', 'arcfour256' ) + _preferred_macs = ( 'hmac-sha1', 'hmac-md5', 'hmac-sha1-96', 'hmac-md5-96' ) + _preferred_keys = ( 'ssh-rsa', 'ssh-dss' ) + _preferred_kex = ( 'diffie-hellman-group1-sha1', 'diffie-hellman-group-exchange-sha1' ) + _preferred_compression = ( 'none', ) + + _cipher_info = { + 'aes128-ctr': { 'class': AES, 'mode': AES.MODE_CTR, 'block-size': 16, 'key-size': 16 }, + 'aes256-ctr': { 'class': AES, 'mode': AES.MODE_CTR, 'block-size': 16, 'key-size': 32 }, + 'blowfish-cbc': { 'class': Blowfish, 'mode': Blowfish.MODE_CBC, 'block-size': 8, 'key-size': 16 }, + 'aes128-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 16 }, + 'aes256-cbc': { 'class': AES, 'mode': AES.MODE_CBC, 'block-size': 16, 'key-size': 32 }, + '3des-cbc': { 'class': DES3, 'mode': DES3.MODE_CBC, 'block-size': 8, 'key-size': 24 }, + 'arcfour128': { 'class': ARC4, 'mode': None, 'block-size': 8, 'key-size': 16 }, + 'arcfour256': { 'class': ARC4, 'mode': None, 'block-size': 8, 'key-size': 32 }, + } + + _mac_info = { + 'hmac-sha1': { 'class': SHA, 'size': 20 }, + 'hmac-sha1-96': { 'class': SHA, 'size': 12 }, + 'hmac-md5': { 'class': MD5, 'size': 16 }, + 'hmac-md5-96': { 'class': MD5, 'size': 12 }, + } + + _key_info = { + 'ssh-rsa': RSAKey, + 'ssh-dss': DSSKey, + } + + _kex_info = { + 'diffie-hellman-group1-sha1': KexGroup1, + 'diffie-hellman-group-exchange-sha1': KexGex, + } + + _compression_info = { + # zlib@openssh.com is just zlib, but only turned on after a successful + # authentication. openssh servers may only offer this type because + # they've had troubles with security holes in zlib in the past. + 'zlib@openssh.com': ( ZlibCompressor, ZlibDecompressor ), + 'zlib': ( ZlibCompressor, ZlibDecompressor ), + 'none': ( None, None ), + } + + + _modulus_pack = None + + def __init__(self, sock): + """ + Create a new SSH session over an existing socket, or socket-like + object. This only creates the Transport object; it doesn't begin the + SSH session yet. Use L{connect} or L{start_client} to begin a client + session, or L{start_server} to begin a server session. + + If the object is not actually a socket, it must have the following + methods: + - C{send(str)}: Writes from 1 to C{len(str)} bytes, and + returns an int representing the number of bytes written. Returns + 0 or raises C{EOFError} if the stream has been closed. + - C{recv(int)}: Reads from 1 to C{int} bytes and returns them as a + string. Returns 0 or raises C{EOFError} if the stream has been + closed. + - C{close()}: Closes the socket. + - C{settimeout(n)}: Sets a (float) timeout on I/O operations. + + For ease of use, you may also pass in an address (as a tuple) or a host + string as the C{sock} argument. (A host string is a hostname with an + optional port (separated by C{":"}) which will be converted into a + tuple of C{(hostname, port)}.) A socket will be connected to this + address and used for communication. Exceptions from the C{socket} call + may be thrown in this case. + + @param sock: a socket or socket-like object to create the session over. + @type sock: socket + """ + if isinstance(sock, (str, unicode)): + # convert "host:port" into (host, port) + hl = sock.split(':', 1) + if len(hl) == 1: + sock = (hl[0], 22) + else: + sock = (hl[0], int(hl[1])) + if type(sock) is tuple: + # connect to the given (host, port) + hostname, port = sock + for (family, socktype, proto, canonname, sockaddr) in socket.getaddrinfo(hostname, port, socket.AF_UNSPEC, socket.SOCK_STREAM): + if socktype == socket.SOCK_STREAM: + af = family + addr = sockaddr + break + else: + raise SSHException('No suitable address family for %s' % hostname) + sock = socket.socket(af, socket.SOCK_STREAM) + sock.connect((hostname, port)) + # okay, normal socket-ish flow here... + threading.Thread.__init__(self) + self.setDaemon(True) + self.randpool = randpool + self.sock = sock + # Python < 2.3 doesn't have the settimeout method - RogerB + try: + # we set the timeout so we can check self.active periodically to + # see if we should bail. socket.timeout exception is never + # propagated. + self.sock.settimeout(0.1) + except AttributeError: + pass + + # negotiated crypto parameters + self.packetizer = Packetizer(sock) + self.local_version = 'SSH-' + self._PROTO_ID + '-' + self._CLIENT_ID + self.remote_version = '' + self.local_cipher = self.remote_cipher = '' + self.local_kex_init = self.remote_kex_init = None + self.local_mac = self.remote_mac = None + self.local_compression = self.remote_compression = None + self.session_id = None + self.host_key_type = None + self.host_key = None + + # state used during negotiation + self.kex_engine = None + self.H = None + self.K = None + + self.active = False + self.initial_kex_done = False + self.in_kex = False + self.authenticated = False + self._expected_packet = tuple() + self.lock = threading.Lock() # synchronization (always higher level than write_lock) + + # tracking open channels + self._channels = ChannelMap() + self.channel_events = { } # (id -> Event) + self.channels_seen = { } # (id -> True) + self._channel_counter = 1 + self.window_size = 65536 + self.max_packet_size = 34816 + self._x11_handler = None + self._tcp_handler = None + + self.saved_exception = None + self.clear_to_send = threading.Event() + self.clear_to_send_lock = threading.Lock() + self.clear_to_send_timeout = 30.0 + self.log_name = 'paramiko.transport' + self.logger = util.get_logger(self.log_name) + self.packetizer.set_log(self.logger) + self.auth_handler = None + self.global_response = None # response Message from an arbitrary global request + self.completion_event = None # user-defined event callbacks + self.banner_timeout = 15 # how long (seconds) to wait for the SSH banner + + # server mode: + self.server_mode = False + self.server_object = None + self.server_key_dict = { } + self.server_accepts = [ ] + self.server_accept_cv = threading.Condition(self.lock) + self.subsystem_table = { } + + def __repr__(self): + """ + Returns a string representation of this object, for debugging. + + @rtype: str + """ + out = '} or + L{auth_publickey }. + + @note: L{connect} is a simpler method for connecting as a client. + + @note: After calling this method (or L{start_server} or L{connect}), + you should no longer directly read from or write to the original + socket object. + + @param event: an event to trigger when negotiation is complete + (optional) + @type event: threading.Event + + @raise SSHException: if negotiation fails (and no C{event} was passed + in) + """ + self.active = True + if event is not None: + # async, return immediately and let the app poll for completion + self.completion_event = event + self.start() + return + + # synchronous, wait for a result + self.completion_event = event = threading.Event() + self.start() + while True: + event.wait(0.1) + if not self.active: + e = self.get_exception() + if e is not None: + raise e + raise SSHException('Negotiation failed.') + if event.isSet(): + break + + def start_server(self, event=None, server=None): + """ + Negotiate a new SSH2 session as a server. This is the first step after + creating a new L{Transport} and setting up your server host key(s). A + separate thread is created for protocol negotiation. + + If an event is passed in, this method returns immediately. When + negotiation is done (successful or not), the given C{Event} will + be triggered. On failure, L{is_active} will return C{False}. + + (Since 1.4) If C{event} is C{None}, this method will not return until + negotation is done. On success, the method returns normally. + Otherwise an SSHException is raised. + + After a successful negotiation, the client will need to authenticate. + Override the methods + L{get_allowed_auths }, + L{check_auth_none }, + L{check_auth_password }, and + L{check_auth_publickey } in the + given C{server} object to control the authentication process. + + After a successful authentication, the client should request to open + a channel. Override + L{check_channel_request } in the + given C{server} object to allow channels to be opened. + + @note: After calling this method (or L{start_client} or L{connect}), + you should no longer directly read from or write to the original + socket object. + + @param event: an event to trigger when negotiation is complete. + @type event: threading.Event + @param server: an object used to perform authentication and create + L{Channel}s. + @type server: L{server.ServerInterface} + + @raise SSHException: if negotiation fails (and no C{event} was passed + in) + """ + if server is None: + server = ServerInterface() + self.server_mode = True + self.server_object = server + self.active = True + if event is not None: + # async, return immediately and let the app poll for completion + self.completion_event = event + self.start() + return + + # synchronous, wait for a result + self.completion_event = event = threading.Event() + self.start() + while True: + event.wait(0.1) + if not self.active: + e = self.get_exception() + if e is not None: + raise e + raise SSHException('Negotiation failed.') + if event.isSet(): + break + + def add_server_key(self, key): + """ + Add a host key to the list of keys used for server mode. When behaving + as a server, the host key is used to sign certain packets during the + SSH2 negotiation, so that the client can trust that we are who we say + we are. Because this is used for signing, the key must contain private + key info, not just the public half. Only one key of each type (RSA or + DSS) is kept. + + @param key: the host key to add, usually an L{RSAKey } or + L{DSSKey }. + @type key: L{PKey } + """ + self.server_key_dict[key.get_name()] = key + + def get_server_key(self): + """ + Return the active host key, in server mode. After negotiating with the + client, this method will return the negotiated host key. If only one + type of host key was set with L{add_server_key}, that's the only key + that will ever be returned. But in cases where you have set more than + one type of host key (for example, an RSA key and a DSS key), the key + type will be negotiated by the client, and this method will return the + key of the type agreed on. If the host key has not been negotiated + yet, C{None} is returned. In client mode, the behavior is undefined. + + @return: host key of the type negotiated by the client, or C{None}. + @rtype: L{PKey } + """ + try: + return self.server_key_dict[self.host_key_type] + except KeyError: + pass + return None + + def load_server_moduli(filename=None): + """ + I{(optional)} + Load a file of prime moduli for use in doing group-exchange key + negotiation in server mode. It's a rather obscure option and can be + safely ignored. + + In server mode, the remote client may request "group-exchange" key + negotiation, which asks the server to send a random prime number that + fits certain criteria. These primes are pretty difficult to compute, + so they can't be generated on demand. But many systems contain a file + of suitable primes (usually named something like C{/etc/ssh/moduli}). + If you call C{load_server_moduli} and it returns C{True}, then this + file of primes has been loaded and we will support "group-exchange" in + server mode. Otherwise server mode will just claim that it doesn't + support that method of key negotiation. + + @param filename: optional path to the moduli file, if you happen to + know that it's not in a standard location. + @type filename: str + @return: True if a moduli file was successfully loaded; False + otherwise. + @rtype: bool + + @note: This has no effect when used in client mode. + """ + Transport._modulus_pack = ModulusPack(randpool) + # places to look for the openssh "moduli" file + file_list = [ '/etc/ssh/moduli', '/usr/local/etc/moduli' ] + if filename is not None: + file_list.insert(0, filename) + for fn in file_list: + try: + Transport._modulus_pack.read_file(fn) + return True + except IOError: + pass + # none succeeded + Transport._modulus_pack = None + return False + load_server_moduli = staticmethod(load_server_moduli) + + def close(self): + """ + Close this session, and any open channels that are tied to it. + """ + if not self.active: + return + self.active = False + self.packetizer.close() + self.join() + for chan in self._channels.values(): + chan._unlink() + + def get_remote_server_key(self): + """ + Return the host key of the server (in client mode). + + @note: Previously this call returned a tuple of (key type, key string). + You can get the same effect by calling + L{PKey.get_name } for the key type, and + C{str(key)} for the key string. + + @raise SSHException: if no session is currently active. + + @return: public key of the remote server + @rtype: L{PKey } + """ + if (not self.active) or (not self.initial_kex_done): + raise SSHException('No existing session') + return self.host_key + + def is_active(self): + """ + Return true if this session is active (open). + + @return: True if the session is still active (open); False if the + session is closed + @rtype: bool + """ + return self.active + + def open_session(self): + """ + Request a new channel to the server, of type C{"session"}. This + is just an alias for C{open_channel('session')}. + + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('session') + + def open_x11_channel(self, src_addr=None): + """ + Request a new channel to the client, of type C{"x11"}. This + is just an alias for C{open_channel('x11', src_addr=src_addr)}. + + @param src_addr: the source address of the x11 server (port is the + x11 port, ie. 6010) + @type src_addr: (str, int) + @return: a new L{Channel} + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + return self.open_channel('x11', src_addr=src_addr) + + def open_forwarded_tcpip_channel(self, (src_addr, src_port), (dest_addr, dest_port)): + """ + Request a new channel back to the client, of type C{"forwarded-tcpip"}. + This is used after a client has requested port forwarding, for sending + incoming connections back to the client. + + @param src_addr: originator's address + @param src_port: originator's port + @param dest_addr: local (server) connected address + @param dest_port: local (server) connected port + """ + return self.open_channel('forwarded-tcpip', (dest_addr, dest_port), (src_addr, src_port)) + + def open_channel(self, kind, dest_addr=None, src_addr=None): + """ + Request a new channel to the server. L{Channel}s are socket-like + objects used for the actual transfer of data across the session. + You may only request a channel after negotiating encryption (using + L{connect} or L{start_client}) and authenticating. + + @param kind: the kind of channel requested (usually C{"session"}, + C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"}) + @type kind: str + @param dest_addr: the destination address of this port forwarding, + if C{kind} is C{"forwarded-tcpip"} or C{"direct-tcpip"} (ignored + for other channel types) + @type dest_addr: (str, int) + @param src_addr: the source address of this port forwarding, if + C{kind} is C{"forwarded-tcpip"}, C{"direct-tcpip"}, or C{"x11"} + @type src_addr: (str, int) + @return: a new L{Channel} on success + @rtype: L{Channel} + + @raise SSHException: if the request is rejected or the session ends + prematurely + """ + chan = None + if not self.active: + # don't bother trying to allocate a channel + return None + self.lock.acquire() + try: + chanid = self._next_channel() + m = Message() + m.add_byte(chr(MSG_CHANNEL_OPEN)) + m.add_string(kind) + m.add_int(chanid) + m.add_int(self.window_size) + m.add_int(self.max_packet_size) + if (kind == 'forwarded-tcpip') or (kind == 'direct-tcpip'): + m.add_string(dest_addr[0]) + m.add_int(dest_addr[1]) + m.add_string(src_addr[0]) + m.add_int(src_addr[1]) + elif kind == 'x11': + m.add_string(src_addr[0]) + m.add_int(src_addr[1]) + chan = Channel(chanid) + self._channels.put(chanid, chan) + self.channel_events[chanid] = event = threading.Event() + self.channels_seen[chanid] = True + chan._set_transport(self) + chan._set_window(self.window_size, self.max_packet_size) + finally: + self.lock.release() + self._send_user_message(m) + while True: + event.wait(0.1); + if not self.active: + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e + if event.isSet(): + break + chan = self._channels.get(chanid) + if chan is not None: + return chan + e = self.get_exception() + if e is None: + e = SSHException('Unable to open channel.') + raise e + + def request_port_forward(self, address, port, handler=None): + """ + Ask the server to forward TCP connections from a listening port on + the server, across this SSH session. + + If a handler is given, that handler is called from a different thread + whenever a forwarded connection arrives. The handler parameters are:: + + handler(channel, (origin_addr, origin_port), (server_addr, server_port)) + + where C{server_addr} and C{server_port} are the address and port that + the server was listening on. + + If no handler is set, the default behavior is to send new incoming + forwarded connections into the accept queue, to be picked up via + L{accept}. + + @param address: the address to bind when forwarding + @type address: str + @param port: the port to forward, or 0 to ask the server to allocate + any port + @type port: int + @param handler: optional handler for incoming forwarded connections + @type handler: function(Channel, (str, int), (str, int)) + @return: the port # allocated by the server + @rtype: int + + @raise SSHException: if the server refused the TCP forward request + """ + if not self.active: + raise SSHException('SSH session not active') + address = str(address) + port = int(port) + response = self.global_request('tcpip-forward', (address, port), wait=True) + if response is None: + raise SSHException('TCP forwarding request denied') + if port == 0: + port = response.get_int() + if handler is None: + def default_handler(channel, (src_addr, src_port), (dest_addr, dest_port)): + self._queue_incoming_channel(channel) + handler = default_handler + self._tcp_handler = handler + return port + + def cancel_port_forward(self, address, port): + """ + Ask the server to cancel a previous port-forwarding request. No more + connections to the given address & port will be forwarded across this + ssh connection. + + @param address: the address to stop forwarding + @type address: str + @param port: the port to stop forwarding + @type port: int + """ + if not self.active: + return + self._tcp_handler = None + self.global_request('cancel-tcpip-forward', (address, port), wait=True) + + def open_sftp_client(self): + """ + Create an SFTP client channel from an open transport. On success, + an SFTP session will be opened with the remote host, and a new + SFTPClient object will be returned. + + @return: a new L{SFTPClient} object, referring to an sftp session + (channel) across this transport + @rtype: L{SFTPClient} + """ + return SFTPClient.from_transport(self) + + def send_ignore(self, bytes=None): + """ + Send a junk packet across the encrypted link. This is sometimes used + to add "noise" to a connection to confuse would-be attackers. It can + also be used as a keep-alive for long lived connections traversing + firewalls. + + @param bytes: the number of random bytes to send in the payload of the + ignored packet -- defaults to a random number from 10 to 41. + @type bytes: int + """ + m = Message() + m.add_byte(chr(MSG_IGNORE)) + randpool.stir() + if bytes is None: + bytes = (ord(randpool.get_bytes(1)) % 32) + 10 + m.add_bytes(randpool.get_bytes(bytes)) + self._send_user_message(m) + + def renegotiate_keys(self): + """ + Force this session to switch to new keys. Normally this is done + automatically after the session hits a certain number of packets or + bytes sent or received, but this method gives you the option of forcing + new keys whenever you want. Negotiating new keys causes a pause in + traffic both ways as the two sides swap keys and do computations. This + method returns when the session has switched to new keys. + + @raise SSHException: if the key renegotiation failed (which causes the + session to end) + """ + self.completion_event = threading.Event() + self._send_kex_init() + while True: + self.completion_event.wait(0.1) + if not self.active: + e = self.get_exception() + if e is not None: + raise e + raise SSHException('Negotiation failed.') + if self.completion_event.isSet(): + break + return + + def set_keepalive(self, interval): + """ + Turn on/off keepalive packets (default is off). If this is set, after + C{interval} seconds without sending any data over the connection, a + "keepalive" packet will be sent (and ignored by the remote host). This + can be useful to keep connections alive over a NAT, for example. + + @param interval: seconds to wait before sending a keepalive packet (or + 0 to disable keepalives). + @type interval: int + """ + self.packetizer.set_keepalive(interval, + lambda x=weakref.proxy(self): x.global_request('keepalive@lag.net', wait=False)) + + def global_request(self, kind, data=None, wait=True): + """ + Make a global request to the remote host. These are normally + extensions to the SSH2 protocol. + + @param kind: name of the request. + @type kind: str + @param data: an optional tuple containing additional data to attach + to the request. + @type data: tuple + @param wait: C{True} if this method should not return until a response + is received; C{False} otherwise. + @type wait: bool + @return: a L{Message} containing possible additional data if the + request was successful (or an empty L{Message} if C{wait} was + C{False}); C{None} if the request was denied. + @rtype: L{Message} + """ + if wait: + self.completion_event = threading.Event() + m = Message() + m.add_byte(chr(MSG_GLOBAL_REQUEST)) + m.add_string(kind) + m.add_boolean(wait) + if data is not None: + m.add(*data) + self._log(DEBUG, 'Sending global request "%s"' % kind) + self._send_user_message(m) + if not wait: + return None + while True: + self.completion_event.wait(0.1) + if not self.active: + return None + if self.completion_event.isSet(): + break + return self.global_response + + def accept(self, timeout=None): + """ + Return the next channel opened by the client over this transport, in + server mode. If no channel is opened before the given timeout, C{None} + is returned. + + @param timeout: seconds to wait for a channel, or C{None} to wait + forever + @type timeout: int + @return: a new Channel opened by the client + @rtype: L{Channel} + """ + self.lock.acquire() + try: + if len(self.server_accepts) > 0: + chan = self.server_accepts.pop(0) + else: + self.server_accept_cv.wait(timeout) + if len(self.server_accepts) > 0: + chan = self.server_accepts.pop(0) + else: + # timeout + chan = None + finally: + self.lock.release() + return chan + + def connect(self, hostkey=None, username='', password=None, pkey=None): + """ + Negotiate an SSH2 session, and optionally verify the server's host key + and authenticate using a password or private key. This is a shortcut + for L{start_client}, L{get_remote_server_key}, and + L{Transport.auth_password} or L{Transport.auth_publickey}. Use those + methods if you want more control. + + You can use this method immediately after creating a Transport to + negotiate encryption with a server. If it fails, an exception will be + thrown. On success, the method will return cleanly, and an encrypted + session exists. You may immediately call L{open_channel} or + L{open_session} to get a L{Channel} object, which is used for data + transfer. + + @note: If you fail to supply a password or private key, this method may + succeed, but a subsequent L{open_channel} or L{open_session} call may + fail because you haven't authenticated yet. + + @param hostkey: the host key expected from the server, or C{None} if + you don't want to do host key verification. + @type hostkey: L{PKey} + @param username: the username to authenticate as. + @type username: str + @param password: a password to use for authentication, if you want to + use password authentication; otherwise C{None}. + @type password: str + @param pkey: a private key to use for authentication, if you want to + use private key authentication; otherwise C{None}. + @type pkey: L{PKey} + + @raise SSHException: if the SSH2 negotiation fails, the host key + supplied by the server is incorrect, or authentication fails. + """ + if hostkey is not None: + self._preferred_keys = [ hostkey.get_name() ] + + self.start_client() + + # check host key if we were given one + if (hostkey is not None): + key = self.get_remote_server_key() + if (key.get_name() != hostkey.get_name()) or (str(key) != str(hostkey)): + self._log(DEBUG, 'Bad host key from server') + self._log(DEBUG, 'Expected: %s: %s' % (hostkey.get_name(), repr(str(hostkey)))) + self._log(DEBUG, 'Got : %s: %s' % (key.get_name(), repr(str(key)))) + raise SSHException('Bad host key from server') + self._log(DEBUG, 'Host key verified (%s)' % hostkey.get_name()) + + if (pkey is not None) or (password is not None): + if password is not None: + self._log(DEBUG, 'Attempting password auth...') + self.auth_password(username, password) + else: + self._log(DEBUG, 'Attempting public-key auth...') + self.auth_publickey(username, pkey) + + return + + def get_exception(self): + """ + Return any exception that happened during the last server request. + This can be used to fetch more specific error information after using + calls like L{start_client}. The exception (if any) is cleared after + this call. + + @return: an exception, or C{None} if there is no stored exception. + @rtype: Exception + + @since: 1.1 + """ + self.lock.acquire() + try: + e = self.saved_exception + self.saved_exception = None + return e + finally: + self.lock.release() + + def set_subsystem_handler(self, name, handler, *larg, **kwarg): + """ + Set the handler class for a subsystem in server mode. If a request + for this subsystem is made on an open ssh channel later, this handler + will be constructed and called -- see L{SubsystemHandler} for more + detailed documentation. + + Any extra parameters (including keyword arguments) are saved and + passed to the L{SubsystemHandler} constructor later. + + @param name: name of the subsystem. + @type name: str + @param handler: subclass of L{SubsystemHandler} that handles this + subsystem. + @type handler: class + """ + try: + self.lock.acquire() + self.subsystem_table[name] = (handler, larg, kwarg) + finally: + self.lock.release() + + def is_authenticated(self): + """ + Return true if this session is active and authenticated. + + @return: True if the session is still open and has been authenticated + successfully; False if authentication failed and/or the session is + closed. + @rtype: bool + """ + return self.active and (self.auth_handler is not None) and self.auth_handler.is_authenticated() + + def get_username(self): + """ + Return the username this connection is authenticated for. If the + session is not authenticated (or authentication failed), this method + returns C{None}. + + @return: username that was authenticated, or C{None}. + @rtype: string + """ + if not self.active or (self.auth_handler is None): + return None + return self.auth_handler.get_username() + + def auth_none(self, username): + """ + Try to authenticate to the server using no authentication at all. + This will almost always fail. It may be useful for determining the + list of authentication types supported by the server, by catching the + L{BadAuthenticationType} exception raised. + + @param username: the username to authenticate as + @type username: string + @return: list of auth types permissible for the next stage of + authentication (normally empty) + @rtype: list + + @raise BadAuthenticationType: if "none" authentication isn't allowed + by the server for this user + @raise SSHException: if the authentication failed due to a network + error + + @since: 1.5 + """ + if (not self.active) or (not self.initial_kex_done): + raise SSHException('No existing session') + my_event = threading.Event() + self.auth_handler = AuthHandler(self) + self.auth_handler.auth_none(username, my_event) + return self.auth_handler.wait_for_response(my_event) + + def auth_password(self, username, password, event=None, fallback=True): + """ + Authenticate to the server using a password. The username and password + are sent over an encrypted link. + + If an C{event} is passed in, this method will return immediately, and + the event will be triggered once authentication succeeds or fails. On + success, L{is_authenticated} will return C{True}. On failure, you may + use L{get_exception} to get more detailed error information. + + Since 1.1, if no event is passed, this method will block until the + authentication succeeds or fails. On failure, an exception is raised. + Otherwise, the method simply returns. + + Since 1.5, if no event is passed and C{fallback} is C{True} (the + default), if the server doesn't support plain password authentication + but does support so-called "keyboard-interactive" mode, an attempt + will be made to authenticate using this interactive mode. If it fails, + the normal exception will be thrown as if the attempt had never been + made. This is useful for some recent Gentoo and Debian distributions, + which turn off plain password authentication in a misguided belief + that interactive authentication is "more secure". (It's not.) + + If the server requires multi-step authentication (which is very rare), + this method will return a list of auth types permissible for the next + step. Otherwise, in the normal case, an empty list is returned. + + @param username: the username to authenticate as + @type username: str + @param password: the password to authenticate with + @type password: str or unicode + @param event: an event to trigger when the authentication attempt is + complete (whether it was successful or not) + @type event: threading.Event + @param fallback: C{True} if an attempt at an automated "interactive" + password auth should be made if the server doesn't support normal + password auth + @type fallback: bool + @return: list of auth types permissible for the next stage of + authentication (normally empty) + @rtype: list + + @raise BadAuthenticationType: if password authentication isn't + allowed by the server for this user (and no event was passed in) + @raise AuthenticationException: if the authentication failed (and no + event was passed in) + @raise SSHException: if there was a network error + """ + if (not self.active) or (not self.initial_kex_done): + # we should never try to send the password unless we're on a secure link + raise SSHException('No existing session') + if event is None: + my_event = threading.Event() + else: + my_event = event + self.auth_handler = AuthHandler(self) + self.auth_handler.auth_password(username, password, my_event) + if event is not None: + # caller wants to wait for event themselves + return [] + try: + return self.auth_handler.wait_for_response(my_event) + except BadAuthenticationType, x: + # if password auth isn't allowed, but keyboard-interactive *is*, try to fudge it + if not fallback or ('keyboard-interactive' not in x.allowed_types): + raise + try: + def handler(title, instructions, fields): + if len(fields) > 1: + raise SSHException('Fallback authentication failed.') + if len(fields) == 0: + # for some reason, at least on os x, a 2nd request will + # be made with zero fields requested. maybe it's just + # to try to fake out automated scripting of the exact + # type we're doing here. *shrug* :) + return [] + return [ password ] + return self.auth_interactive(username, handler) + except SSHException, ignored: + # attempt failed; just raise the original exception + raise x + return None + + def auth_publickey(self, username, key, event=None): + """ + Authenticate to the server using a private key. The key is used to + sign data from the server, so it must include the private part. + + If an C{event} is passed in, this method will return immediately, and + the event will be triggered once authentication succeeds or fails. On + success, L{is_authenticated} will return C{True}. On failure, you may + use L{get_exception} to get more detailed error information. + + Since 1.1, if no event is passed, this method will block until the + authentication succeeds or fails. On failure, an exception is raised. + Otherwise, the method simply returns. + + If the server requires multi-step authentication (which is very rare), + this method will return a list of auth types permissible for the next + step. Otherwise, in the normal case, an empty list is returned. + + @param username: the username to authenticate as + @type username: string + @param key: the private key to authenticate with + @type key: L{PKey } + @param event: an event to trigger when the authentication attempt is + complete (whether it was successful or not) + @type event: threading.Event + @return: list of auth types permissible for the next stage of + authentication (normally empty) + @rtype: list + + @raise BadAuthenticationType: if public-key authentication isn't + allowed by the server for this user (and no event was passed in) + @raise AuthenticationException: if the authentication failed (and no + event was passed in) + @raise SSHException: if there was a network error + """ + if (not self.active) or (not self.initial_kex_done): + # we should never try to authenticate unless we're on a secure link + raise SSHException('No existing session') + if event is None: + my_event = threading.Event() + else: + my_event = event + self.auth_handler = AuthHandler(self) + self.auth_handler.auth_publickey(username, key, my_event) + if event is not None: + # caller wants to wait for event themselves + return [] + return self.auth_handler.wait_for_response(my_event) + + def auth_interactive(self, username, handler, submethods=''): + """ + Authenticate to the server interactively. A handler is used to answer + arbitrary questions from the server. On many servers, this is just a + dumb wrapper around PAM. + + This method will block until the authentication succeeds or fails, + peroidically calling the handler asynchronously to get answers to + authentication questions. The handler may be called more than once + if the server continues to ask questions. + + The handler is expected to be a callable that will handle calls of the + form: C{handler(title, instructions, prompt_list)}. The C{title} is + meant to be a dialog-window title, and the C{instructions} are user + instructions (both are strings). C{prompt_list} will be a list of + prompts, each prompt being a tuple of C{(str, bool)}. The string is + the prompt and the boolean indicates whether the user text should be + echoed. + + A sample call would thus be: + C{handler('title', 'instructions', [('Password:', False)])}. + + The handler should return a list or tuple of answers to the server's + questions. + + If the server requires multi-step authentication (which is very rare), + this method will return a list of auth types permissible for the next + step. Otherwise, in the normal case, an empty list is returned. + + @param username: the username to authenticate as + @type username: string + @param handler: a handler for responding to server questions + @type handler: callable + @param submethods: a string list of desired submethods (optional) + @type submethods: str + @return: list of auth types permissible for the next stage of + authentication (normally empty). + @rtype: list + + @raise BadAuthenticationType: if public-key authentication isn't + allowed by the server for this user + @raise AuthenticationException: if the authentication failed + @raise SSHException: if there was a network error + + @since: 1.5 + """ + if (not self.active) or (not self.initial_kex_done): + # we should never try to authenticate unless we're on a secure link + raise SSHException('No existing session') + my_event = threading.Event() + self.auth_handler = AuthHandler(self) + self.auth_handler.auth_interactive(username, handler, my_event, submethods) + return self.auth_handler.wait_for_response(my_event) + + def set_log_channel(self, name): + """ + Set the channel for this transport's logging. The default is + C{"paramiko.transport"} but it can be set to anything you want. + (See the C{logging} module for more info.) SSH Channels will log + to a sub-channel of the one specified. + + @param name: new channel name for logging + @type name: str + + @since: 1.1 + """ + self.log_name = name + self.logger = util.get_logger(name) + self.packetizer.set_log(self.logger) + + def get_log_channel(self): + """ + Return the channel name used for this transport's logging. + + @return: channel name. + @rtype: str + + @since: 1.2 + """ + return self.log_name + + def set_hexdump(self, hexdump): + """ + Turn on/off logging a hex dump of protocol traffic at DEBUG level in + the logs. Normally you would want this off (which is the default), + but if you are debugging something, it may be useful. + + @param hexdump: C{True} to log protocol traffix (in hex) to the log; + C{False} otherwise. + @type hexdump: bool + """ + self.packetizer.set_hexdump(hexdump) + + def get_hexdump(self): + """ + Return C{True} if the transport is currently logging hex dumps of + protocol traffic. + + @return: C{True} if hex dumps are being logged + @rtype: bool + + @since: 1.4 + """ + return self.packetizer.get_hexdump() + + def use_compression(self, compress=True): + """ + Turn on/off compression. This will only have an affect before starting + the transport (ie before calling L{connect}, etc). By default, + compression is off since it negatively affects interactive sessions. + + @param compress: C{True} to ask the remote client/server to compress + traffic; C{False} to refuse compression + @type compress: bool + + @since: 1.5.2 + """ + if compress: + self._preferred_compression = ( 'zlib@openssh.com', 'zlib', 'none' ) + else: + self._preferred_compression = ( 'none', ) + + def getpeername(self): + """ + Return the address of the remote side of this Transport, if possible. + This is effectively a wrapper around C{'getpeername'} on the underlying + socket. If the socket-like object has no C{'getpeername'} method, + then C{("unknown", 0)} is returned. + + @return: the address if the remote host, if known + @rtype: tuple(str, int) + """ + gp = getattr(self.sock, 'getpeername', None) + if gp is None: + return ('unknown', 0) + return gp() + + def stop_thread(self): + self.active = False + self.packetizer.close() + + + ### internals... + + + def _log(self, level, msg, *args): + if issubclass(type(msg), list): + for m in msg: + self.logger.log(level, m) + else: + self.logger.log(level, msg, *args) + + def _get_modulus_pack(self): + "used by KexGex to find primes for group exchange" + return self._modulus_pack + + def _next_channel(self): + "you are holding the lock" + chanid = self._channel_counter + while self._channels.get(chanid) is not None: + self._channel_counter = (self._channel_counter + 1) & 0xffffff + chanid = self._channel_counter + self._channel_counter = (self._channel_counter + 1) & 0xffffff + return chanid + + def _unlink_channel(self, chanid): + "used by a Channel to remove itself from the active channel list" + self._channels.delete(chanid) + + def _send_message(self, data): + self.packetizer.send_message(data) + + def _send_user_message(self, data): + """ + send a message, but block if we're in key negotiation. this is used + for user-initiated requests. + """ + start = time.time() + while True: + self.clear_to_send.wait(0.1) + if not self.active: + self._log(DEBUG, 'Dropping user packet because connection is dead.') + return + self.clear_to_send_lock.acquire() + if self.clear_to_send.isSet(): + break + self.clear_to_send_lock.release() + if time.time() > start + self.clear_to_send_timeout: + raise SSHException('Key-exchange timed out waiting for key negotiation') + try: + self._send_message(data) + finally: + self.clear_to_send_lock.release() + + def _set_K_H(self, k, h): + "used by a kex object to set the K (root key) and H (exchange hash)" + self.K = k + self.H = h + if self.session_id == None: + self.session_id = h + + def _expect_packet(self, *ptypes): + "used by a kex object to register the next packet type it expects to see" + self._expected_packet = tuple(ptypes) + + def _verify_key(self, host_key, sig): + key = self._key_info[self.host_key_type](Message(host_key)) + if key is None: + raise SSHException('Unknown host key type') + if not key.verify_ssh_sig(self.H, Message(sig)): + raise SSHException('Signature verification (%s) failed.' % self.host_key_type) + self.host_key = key + + def _compute_key(self, id, nbytes): + "id is 'A' - 'F' for the various keys used by ssh" + m = Message() + m.add_mpint(self.K) + m.add_bytes(self.H) + m.add_byte(id) + m.add_bytes(self.session_id) + out = sofar = SHA.new(str(m)).digest() + while len(out) < nbytes: + m = Message() + m.add_mpint(self.K) + m.add_bytes(self.H) + m.add_bytes(sofar) + digest = SHA.new(str(m)).digest() + out += digest + sofar += digest + return out[:nbytes] + + def _get_cipher(self, name, key, iv): + if name not in self._cipher_info: + raise SSHException('Unknown client cipher ' + name) + if name in ('arcfour128', 'arcfour256'): + # arcfour cipher + cipher = self._cipher_info[name]['class'].new(key) + # as per RFC 4345, the first 1536 bytes of keystream + # generated by the cipher MUST be discarded + cipher.encrypt(" " * 1536) + return cipher + elif name.endswith("-ctr"): + # CTR modes, we need a counter + counter = Counter.new(nbits=self._cipher_info[name]['block-size'] * 8, initial_value=util.inflate_long(iv, True)) + return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv, counter) + else: + return self._cipher_info[name]['class'].new(key, self._cipher_info[name]['mode'], iv) + + def _set_x11_handler(self, handler): + # only called if a channel has turned on x11 forwarding + if handler is None: + # by default, use the same mechanism as accept() + def default_handler(channel, (src_addr, src_port)): + self._queue_incoming_channel(channel) + self._x11_handler = default_handler + else: + self._x11_handler = handler + + def _queue_incoming_channel(self, channel): + self.lock.acquire() + try: + self.server_accepts.append(channel) + self.server_accept_cv.notify() + finally: + self.lock.release() + + def run(self): + # (use the exposed "run" method, because if we specify a thread target + # of a private method, threading.Thread will keep a reference to it + # indefinitely, creating a GC cycle and not letting Transport ever be + # GC'd. it's a bug in Thread.) + + # active=True occurs before the thread is launched, to avoid a race + _active_threads.append(self) + if self.server_mode: + self._log(DEBUG, 'starting thread (server mode): %s' % hex(long(id(self)) & 0xffffffffL)) + else: + self._log(DEBUG, 'starting thread (client mode): %s' % hex(long(id(self)) & 0xffffffffL)) + try: + self.packetizer.write_all(self.local_version + '\r\n') + self._check_banner() + self._send_kex_init() + self._expect_packet(MSG_KEXINIT) + + while self.active: + if self.packetizer.need_rekey() and not self.in_kex: + self._send_kex_init() + try: + ptype, m = self.packetizer.read_message() + except NeedRekeyException: + continue + if ptype == MSG_IGNORE: + continue + elif ptype == MSG_DISCONNECT: + self._parse_disconnect(m) + self.active = False + self.packetizer.close() + break + elif ptype == MSG_DEBUG: + self._parse_debug(m) + continue + if len(self._expected_packet) > 0: + if ptype not in self._expected_packet: + raise SSHException('Expecting packet from %r, got %d' % (self._expected_packet, ptype)) + self._expected_packet = tuple() + if (ptype >= 30) and (ptype <= 39): + self.kex_engine.parse_next(ptype, m) + continue + + if ptype in self._handler_table: + self._handler_table[ptype](self, m) + elif ptype in self._channel_handler_table: + chanid = m.get_int() + chan = self._channels.get(chanid) + if chan is not None: + self._channel_handler_table[ptype](chan, m) + elif chanid in self.channels_seen: + self._log(DEBUG, 'Ignoring message for dead channel %d' % chanid) + else: + self._log(ERROR, 'Channel request for unknown channel %d' % chanid) + self.active = False + self.packetizer.close() + elif (self.auth_handler is not None) and (ptype in self.auth_handler._handler_table): + self.auth_handler._handler_table[ptype](self.auth_handler, m) + else: + self._log(WARNING, 'Oops, unhandled type %d' % ptype) + msg = Message() + msg.add_byte(chr(MSG_UNIMPLEMENTED)) + msg.add_int(m.seqno) + self._send_message(msg) + except SSHException, e: + self._log(ERROR, 'Exception: ' + str(e)) + self._log(ERROR, util.tb_strings()) + self.saved_exception = e + except EOFError, e: + self._log(DEBUG, 'EOF in transport thread') + #self._log(DEBUG, util.tb_strings()) + self.saved_exception = e + except socket.error, e: + if type(e.args) is tuple: + emsg = '%s (%d)' % (e.args[1], e.args[0]) + else: + emsg = e.args + self._log(ERROR, 'Socket exception: ' + emsg) + self.saved_exception = e + except Exception, e: + self._log(ERROR, 'Unknown exception: ' + str(e)) + self._log(ERROR, util.tb_strings()) + self.saved_exception = e + _active_threads.remove(self) + for chan in self._channels.values(): + chan._unlink() + if self.active: + self.active = False + self.packetizer.close() + if self.completion_event != None: + self.completion_event.set() + if self.auth_handler is not None: + self.auth_handler.abort() + for event in self.channel_events.values(): + event.set() + try: + self.lock.acquire() + self.server_accept_cv.notify() + finally: + self.lock.release() + self.sock.close() + + + ### protocol stages + + + def _negotiate_keys(self, m): + # throws SSHException on anything unusual + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.clear() + finally: + self.clear_to_send_lock.release() + if self.local_kex_init == None: + # remote side wants to renegotiate + self._send_kex_init() + self._parse_kex_init(m) + self.kex_engine.start_kex() + + def _check_banner(self): + # this is slow, but we only have to do it once + for i in range(100): + # give them 15 seconds for the first line, then just 2 seconds + # each additional line. (some sites have very high latency.) + if i == 0: + timeout = self.banner_timeout + else: + timeout = 2 + try: + buf = self.packetizer.readline(timeout) + except Exception, x: + raise SSHException('Error reading SSH protocol banner' + str(x)) + if buf[:4] == 'SSH-': + break + self._log(DEBUG, 'Banner: ' + buf) + if buf[:4] != 'SSH-': + raise SSHException('Indecipherable protocol version "' + buf + '"') + # save this server version string for later + self.remote_version = buf + # pull off any attached comment + comment = '' + i = string.find(buf, ' ') + if i >= 0: + comment = buf[i+1:] + buf = buf[:i] + # parse out version string and make sure it matches + segs = buf.split('-', 2) + if len(segs) < 3: + raise SSHException('Invalid SSH banner') + version = segs[1] + client = segs[2] + if version != '1.99' and version != '2.0': + raise SSHException('Incompatible version (%s instead of 2.0)' % (version,)) + self._log(INFO, 'Connected (version %s, client %s)' % (version, client)) + + def _send_kex_init(self): + """ + announce to the other side that we'd like to negotiate keys, and what + kind of key negotiation we support. + """ + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.clear() + finally: + self.clear_to_send_lock.release() + self.in_kex = True + if self.server_mode: + if (self._modulus_pack is None) and ('diffie-hellman-group-exchange-sha1' in self._preferred_kex): + # can't do group-exchange if we don't have a pack of potential primes + pkex = list(self.get_security_options().kex) + pkex.remove('diffie-hellman-group-exchange-sha1') + self.get_security_options().kex = pkex + available_server_keys = filter(self.server_key_dict.keys().__contains__, + self._preferred_keys) + else: + available_server_keys = self._preferred_keys + + randpool.stir() + m = Message() + m.add_byte(chr(MSG_KEXINIT)) + m.add_bytes(randpool.get_bytes(16)) + m.add_list(self._preferred_kex) + m.add_list(available_server_keys) + m.add_list(self._preferred_ciphers) + m.add_list(self._preferred_ciphers) + m.add_list(self._preferred_macs) + m.add_list(self._preferred_macs) + m.add_list(self._preferred_compression) + m.add_list(self._preferred_compression) + m.add_string('') + m.add_string('') + m.add_boolean(False) + m.add_int(0) + # save a copy for later (needed to compute a hash) + self.local_kex_init = str(m) + self._send_message(m) + + def _parse_kex_init(self, m): + cookie = m.get_bytes(16) + kex_algo_list = m.get_list() + server_key_algo_list = m.get_list() + client_encrypt_algo_list = m.get_list() + server_encrypt_algo_list = m.get_list() + client_mac_algo_list = m.get_list() + server_mac_algo_list = m.get_list() + client_compress_algo_list = m.get_list() + server_compress_algo_list = m.get_list() + client_lang_list = m.get_list() + server_lang_list = m.get_list() + kex_follows = m.get_boolean() + unused = m.get_int() + + self._log(DEBUG, 'kex algos:' + str(kex_algo_list) + ' server key:' + str(server_key_algo_list) + \ + ' client encrypt:' + str(client_encrypt_algo_list) + \ + ' server encrypt:' + str(server_encrypt_algo_list) + \ + ' client mac:' + str(client_mac_algo_list) + \ + ' server mac:' + str(server_mac_algo_list) + \ + ' client compress:' + str(client_compress_algo_list) + \ + ' server compress:' + str(server_compress_algo_list) + \ + ' client lang:' + str(client_lang_list) + \ + ' server lang:' + str(server_lang_list) + \ + ' kex follows?' + str(kex_follows)) + + # as a server, we pick the first item in the client's list that we support. + # as a client, we pick the first item in our list that the server supports. + if self.server_mode: + agreed_kex = filter(self._preferred_kex.__contains__, kex_algo_list) + else: + agreed_kex = filter(kex_algo_list.__contains__, self._preferred_kex) + if len(agreed_kex) == 0: + raise SSHException('Incompatible ssh peer (no acceptable kex algorithm)') + self.kex_engine = self._kex_info[agreed_kex[0]](self) + + if self.server_mode: + available_server_keys = filter(self.server_key_dict.keys().__contains__, + self._preferred_keys) + agreed_keys = filter(available_server_keys.__contains__, server_key_algo_list) + else: + agreed_keys = filter(server_key_algo_list.__contains__, self._preferred_keys) + if len(agreed_keys) == 0: + raise SSHException('Incompatible ssh peer (no acceptable host key)') + self.host_key_type = agreed_keys[0] + if self.server_mode and (self.get_server_key() is None): + raise SSHException('Incompatible ssh peer (can\'t match requested host key type)') + + if self.server_mode: + agreed_local_ciphers = filter(self._preferred_ciphers.__contains__, + server_encrypt_algo_list) + agreed_remote_ciphers = filter(self._preferred_ciphers.__contains__, + client_encrypt_algo_list) + else: + agreed_local_ciphers = filter(client_encrypt_algo_list.__contains__, + self._preferred_ciphers) + agreed_remote_ciphers = filter(server_encrypt_algo_list.__contains__, + self._preferred_ciphers) + if (len(agreed_local_ciphers) == 0) or (len(agreed_remote_ciphers) == 0): + raise SSHException('Incompatible ssh server (no acceptable ciphers)') + self.local_cipher = agreed_local_ciphers[0] + self.remote_cipher = agreed_remote_ciphers[0] + self._log(DEBUG, 'Ciphers agreed: local=%s, remote=%s' % (self.local_cipher, self.remote_cipher)) + + if self.server_mode: + agreed_remote_macs = filter(self._preferred_macs.__contains__, client_mac_algo_list) + agreed_local_macs = filter(self._preferred_macs.__contains__, server_mac_algo_list) + else: + agreed_local_macs = filter(client_mac_algo_list.__contains__, self._preferred_macs) + agreed_remote_macs = filter(server_mac_algo_list.__contains__, self._preferred_macs) + if (len(agreed_local_macs) == 0) or (len(agreed_remote_macs) == 0): + raise SSHException('Incompatible ssh server (no acceptable macs)') + self.local_mac = agreed_local_macs[0] + self.remote_mac = agreed_remote_macs[0] + + if self.server_mode: + agreed_remote_compression = filter(self._preferred_compression.__contains__, client_compress_algo_list) + agreed_local_compression = filter(self._preferred_compression.__contains__, server_compress_algo_list) + else: + agreed_local_compression = filter(client_compress_algo_list.__contains__, self._preferred_compression) + agreed_remote_compression = filter(server_compress_algo_list.__contains__, self._preferred_compression) + if (len(agreed_local_compression) == 0) or (len(agreed_remote_compression) == 0): + raise SSHException('Incompatible ssh server (no acceptable compression) %r %r %r' % (agreed_local_compression, agreed_remote_compression, self._preferred_compression)) + self.local_compression = agreed_local_compression[0] + self.remote_compression = agreed_remote_compression[0] + + self._log(DEBUG, 'using kex %s; server key type %s; cipher: local %s, remote %s; mac: local %s, remote %s; compression: local %s, remote %s' % + (agreed_kex[0], self.host_key_type, self.local_cipher, self.remote_cipher, self.local_mac, + self.remote_mac, self.local_compression, self.remote_compression)) + + # save for computing hash later... + # now wait! openssh has a bug (and others might too) where there are + # actually some extra bytes (one NUL byte in openssh's case) added to + # the end of the packet but not parsed. turns out we need to throw + # away those bytes because they aren't part of the hash. + self.remote_kex_init = chr(MSG_KEXINIT) + m.get_so_far() + + def _activate_inbound(self): + "switch on newly negotiated encryption parameters for inbound traffic" + block_size = self._cipher_info[self.remote_cipher]['block-size'] + if self.server_mode: + IV_in = self._compute_key('A', block_size) + key_in = self._compute_key('C', self._cipher_info[self.remote_cipher]['key-size']) + else: + IV_in = self._compute_key('B', block_size) + key_in = self._compute_key('D', self._cipher_info[self.remote_cipher]['key-size']) + engine = self._get_cipher(self.remote_cipher, key_in, IV_in) + mac_size = self._mac_info[self.remote_mac]['size'] + mac_engine = self._mac_info[self.remote_mac]['class'] + # initial mac keys are done in the hash's natural size (not the potentially truncated + # transmission size) + if self.server_mode: + mac_key = self._compute_key('E', mac_engine.digest_size) + else: + mac_key = self._compute_key('F', mac_engine.digest_size) + self.packetizer.set_inbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) + compress_in = self._compression_info[self.remote_compression][1] + if (compress_in is not None) and ((self.remote_compression != 'zlib@openssh.com') or self.authenticated): + self._log(DEBUG, 'Switching on inbound compression ...') + self.packetizer.set_inbound_compressor(compress_in()) + + def _activate_outbound(self): + "switch on newly negotiated encryption parameters for outbound traffic" + m = Message() + m.add_byte(chr(MSG_NEWKEYS)) + self._send_message(m) + block_size = self._cipher_info[self.local_cipher]['block-size'] + if self.server_mode: + IV_out = self._compute_key('B', block_size) + key_out = self._compute_key('D', self._cipher_info[self.local_cipher]['key-size']) + else: + IV_out = self._compute_key('A', block_size) + key_out = self._compute_key('C', self._cipher_info[self.local_cipher]['key-size']) + engine = self._get_cipher(self.local_cipher, key_out, IV_out) + mac_size = self._mac_info[self.local_mac]['size'] + mac_engine = self._mac_info[self.local_mac]['class'] + # initial mac keys are done in the hash's natural size (not the potentially truncated + # transmission size) + if self.server_mode: + mac_key = self._compute_key('F', mac_engine.digest_size) + else: + mac_key = self._compute_key('E', mac_engine.digest_size) + self.packetizer.set_outbound_cipher(engine, block_size, mac_engine, mac_size, mac_key) + compress_out = self._compression_info[self.local_compression][0] + if (compress_out is not None) and ((self.local_compression != 'zlib@openssh.com') or self.authenticated): + self._log(DEBUG, 'Switching on outbound compression ...') + self.packetizer.set_outbound_compressor(compress_out()) + if not self.packetizer.need_rekey(): + self.in_kex = False + # we always expect to receive NEWKEYS now + self._expect_packet(MSG_NEWKEYS) + + def _auth_trigger(self): + self.authenticated = True + # delayed initiation of compression + if self.local_compression == 'zlib@openssh.com': + compress_out = self._compression_info[self.local_compression][0] + self._log(DEBUG, 'Switching on outbound compression ...') + self.packetizer.set_outbound_compressor(compress_out()) + if self.remote_compression == 'zlib@openssh.com': + compress_in = self._compression_info[self.remote_compression][1] + self._log(DEBUG, 'Switching on inbound compression ...') + self.packetizer.set_inbound_compressor(compress_in()) + + def _parse_newkeys(self, m): + self._log(DEBUG, 'Switch to new keys ...') + self._activate_inbound() + # can also free a bunch of stuff here + self.local_kex_init = self.remote_kex_init = None + self.K = None + self.kex_engine = None + if self.server_mode and (self.auth_handler is None): + # create auth handler for server mode + self.auth_handler = AuthHandler(self) + if not self.initial_kex_done: + # this was the first key exchange + self.initial_kex_done = True + # send an event? + if self.completion_event != None: + self.completion_event.set() + # it's now okay to send data again (if this was a re-key) + if not self.packetizer.need_rekey(): + self.in_kex = False + self.clear_to_send_lock.acquire() + try: + self.clear_to_send.set() + finally: + self.clear_to_send_lock.release() + return + + def _parse_disconnect(self, m): + code = m.get_int() + desc = m.get_string() + self._log(INFO, 'Disconnect (code %d): %s' % (code, desc)) + + def _parse_global_request(self, m): + kind = m.get_string() + self._log(DEBUG, 'Received global request "%s"' % kind) + want_reply = m.get_boolean() + if not self.server_mode: + self._log(DEBUG, 'Rejecting "%s" global request from server.' % kind) + ok = False + elif kind == 'tcpip-forward': + address = m.get_string() + port = m.get_int() + ok = self.server_object.check_port_forward_request(address, port) + if ok != False: + ok = (ok,) + elif kind == 'cancel-tcpip-forward': + address = m.get_string() + port = m.get_int() + self.server_object.cancel_port_forward_request(address, port) + ok = True + else: + ok = self.server_object.check_global_request(kind, m) + extra = () + if type(ok) is tuple: + extra = ok + ok = True + if want_reply: + msg = Message() + if ok: + msg.add_byte(chr(MSG_REQUEST_SUCCESS)) + msg.add(*extra) + else: + msg.add_byte(chr(MSG_REQUEST_FAILURE)) + self._send_message(msg) + + def _parse_request_success(self, m): + self._log(DEBUG, 'Global request successful.') + self.global_response = m + if self.completion_event is not None: + self.completion_event.set() + + def _parse_request_failure(self, m): + self._log(DEBUG, 'Global request denied.') + self.global_response = None + if self.completion_event is not None: + self.completion_event.set() + + def _parse_channel_open_success(self, m): + chanid = m.get_int() + server_chanid = m.get_int() + server_window_size = m.get_int() + server_max_packet_size = m.get_int() + chan = self._channels.get(chanid) + if chan is None: + self._log(WARNING, 'Success for unrequested channel! [??]') + return + self.lock.acquire() + try: + chan._set_remote_channel(server_chanid, server_window_size, server_max_packet_size) + self._log(INFO, 'Secsh channel %d opened.' % chanid) + if chanid in self.channel_events: + self.channel_events[chanid].set() + del self.channel_events[chanid] + finally: + self.lock.release() + return + + def _parse_channel_open_failure(self, m): + chanid = m.get_int() + reason = m.get_int() + reason_str = m.get_string() + lang = m.get_string() + reason_text = CONNECTION_FAILED_CODE.get(reason, '(unknown code)') + self._log(INFO, 'Secsh channel %d open FAILED: %s: %s' % (chanid, reason_str, reason_text)) + self.lock.acquire() + try: + self.saved_exception = ChannelException(reason, reason_text) + if chanid in self.channel_events: + self._channels.delete(chanid) + if chanid in self.channel_events: + self.channel_events[chanid].set() + del self.channel_events[chanid] + finally: + self.lock.release() + return + + def _parse_channel_open(self, m): + kind = m.get_string() + chanid = m.get_int() + initial_window_size = m.get_int() + max_packet_size = m.get_int() + reject = False + if (kind == 'x11') and (self._x11_handler is not None): + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming x11 connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif (kind == 'forwarded-tcpip') and (self._tcp_handler is not None): + server_addr = m.get_string() + server_port = m.get_int() + origin_addr = m.get_string() + origin_port = m.get_int() + self._log(DEBUG, 'Incoming tcp forwarded connection from %s:%d' % (origin_addr, origin_port)) + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + elif not self.server_mode: + self._log(DEBUG, 'Rejecting "%s" channel request from server.' % kind) + reject = True + reason = OPEN_FAILED_ADMINISTRATIVELY_PROHIBITED + else: + self.lock.acquire() + try: + my_chanid = self._next_channel() + finally: + self.lock.release() + if kind == 'direct-tcpip': + # handle direct-tcpip requests comming from the client + dest_addr = m.get_string() + dest_port = m.get_int() + origin_addr = m.get_string() + origin_port = m.get_int() + reason = self.server_object.check_channel_direct_tcpip_request( + my_chanid, (origin_addr, origin_port), + (dest_addr, dest_port)) + else: + reason = self.server_object.check_channel_request(kind, my_chanid) + if reason != OPEN_SUCCEEDED: + self._log(DEBUG, 'Rejecting "%s" channel request from client.' % kind) + reject = True + if reject: + msg = Message() + msg.add_byte(chr(MSG_CHANNEL_OPEN_FAILURE)) + msg.add_int(chanid) + msg.add_int(reason) + msg.add_string('') + msg.add_string('en') + self._send_message(msg) + return + + chan = Channel(my_chanid) + self.lock.acquire() + try: + self._channels.put(my_chanid, chan) + self.channels_seen[my_chanid] = True + chan._set_transport(self) + chan._set_window(self.window_size, self.max_packet_size) + chan._set_remote_channel(chanid, initial_window_size, max_packet_size) + finally: + self.lock.release() + m = Message() + m.add_byte(chr(MSG_CHANNEL_OPEN_SUCCESS)) + m.add_int(chanid) + m.add_int(my_chanid) + m.add_int(self.window_size) + m.add_int(self.max_packet_size) + self._send_message(m) + self._log(INFO, 'Secsh channel %d (%s) opened.', my_chanid, kind) + if kind == 'x11': + self._x11_handler(chan, (origin_addr, origin_port)) + elif kind == 'forwarded-tcpip': + chan.origin_addr = (origin_addr, origin_port) + self._tcp_handler(chan, (origin_addr, origin_port), (server_addr, server_port)) + else: + self._queue_incoming_channel(chan) + + def _parse_debug(self, m): + always_display = m.get_boolean() + msg = m.get_string() + lang = m.get_string() + self._log(DEBUG, 'Debug msg: ' + util.safe_string(msg)) + + def _get_subsystem_handler(self, name): + try: + self.lock.acquire() + if name not in self.subsystem_table: + return (None, [], {}) + return self.subsystem_table[name] + finally: + self.lock.release() + + _handler_table = { + MSG_NEWKEYS: _parse_newkeys, + MSG_GLOBAL_REQUEST: _parse_global_request, + MSG_REQUEST_SUCCESS: _parse_request_success, + MSG_REQUEST_FAILURE: _parse_request_failure, + MSG_CHANNEL_OPEN_SUCCESS: _parse_channel_open_success, + MSG_CHANNEL_OPEN_FAILURE: _parse_channel_open_failure, + MSG_CHANNEL_OPEN: _parse_channel_open, + MSG_KEXINIT: _negotiate_keys, + } + + _channel_handler_table = { + MSG_CHANNEL_SUCCESS: Channel._request_success, + MSG_CHANNEL_FAILURE: Channel._request_failed, + MSG_CHANNEL_DATA: Channel._feed, + MSG_CHANNEL_EXTENDED_DATA: Channel._feed_extended, + MSG_CHANNEL_WINDOW_ADJUST: Channel._window_adjust, + MSG_CHANNEL_REQUEST: Channel._handle_request, + MSG_CHANNEL_EOF: Channel._handle_eof, + MSG_CHANNEL_CLOSE: Channel._handle_close, + } diff --git a/tools/migration/paramiko/util.py b/tools/migration/paramiko/util.py new file mode 100644 index 00000000000..0d6a53483ef --- /dev/null +++ b/tools/migration/paramiko/util.py @@ -0,0 +1,302 @@ +# Copyright (C) 2003-2007 Robey Pointer +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Useful functions used by the rest of paramiko. +""" + +from __future__ import generators + +import array +from binascii import hexlify, unhexlify +import sys +import struct +import traceback +import threading + +from paramiko.common import * +from paramiko.config import SSHConfig + + +# Change by RogerB - python < 2.3 doesn't have enumerate so we implement it +if sys.version_info < (2,3): + class enumerate: + def __init__ (self, sequence): + self.sequence = sequence + def __iter__ (self): + count = 0 + for item in self.sequence: + yield (count, item) + count += 1 + + +def inflate_long(s, always_positive=False): + "turns a normalized byte string into a long-int (adapted from Crypto.Util.number)" + out = 0L + negative = 0 + if not always_positive and (len(s) > 0) and (ord(s[0]) >= 0x80): + negative = 1 + if len(s) % 4: + filler = '\x00' + if negative: + filler = '\xff' + s = filler * (4 - len(s) % 4) + s + for i in range(0, len(s), 4): + out = (out << 32) + struct.unpack('>I', s[i:i+4])[0] + if negative: + out -= (1L << (8 * len(s))) + return out + +def deflate_long(n, add_sign_padding=True): + "turns a long-int into a normalized byte string (adapted from Crypto.Util.number)" + # after much testing, this algorithm was deemed to be the fastest + s = '' + n = long(n) + while (n != 0) and (n != -1): + s = struct.pack('>I', n & 0xffffffffL) + s + n = n >> 32 + # strip off leading zeros, FFs + for i in enumerate(s): + if (n == 0) and (i[1] != '\000'): + break + if (n == -1) and (i[1] != '\xff'): + break + else: + # degenerate case, n was either 0 or -1 + i = (0,) + if n == 0: + s = '\000' + else: + s = '\xff' + s = s[i[0]:] + if add_sign_padding: + if (n == 0) and (ord(s[0]) >= 0x80): + s = '\x00' + s + if (n == -1) and (ord(s[0]) < 0x80): + s = '\xff' + s + return s + +def format_binary_weird(data): + out = '' + for i in enumerate(data): + out += '%02X' % ord(i[1]) + if i[0] % 2: + out += ' ' + if i[0] % 16 == 15: + out += '\n' + return out + +def format_binary(data, prefix=''): + x = 0 + out = [] + while len(data) > x + 16: + out.append(format_binary_line(data[x:x+16])) + x += 16 + if x < len(data): + out.append(format_binary_line(data[x:])) + return [prefix + x for x in out] + +def format_binary_line(data): + left = ' '.join(['%02X' % ord(c) for c in data]) + right = ''.join([('.%c..' % c)[(ord(c)+63)//95] for c in data]) + return '%-50s %s' % (left, right) + +def hexify(s): + return hexlify(s).upper() + +def unhexify(s): + return unhexlify(s) + +def safe_string(s): + out = '' + for c in s: + if (ord(c) >= 32) and (ord(c) <= 127): + out += c + else: + out += '%%%02X' % ord(c) + return out + +# ''.join([['%%%02X' % ord(c), c][(ord(c) >= 32) and (ord(c) <= 127)] for c in s]) + +def bit_length(n): + norm = deflate_long(n, 0) + hbyte = ord(norm[0]) + if hbyte == 0: + return 1 + bitlen = len(norm) * 8 + while not (hbyte & 0x80): + hbyte <<= 1 + bitlen -= 1 + return bitlen + +def tb_strings(): + return ''.join(traceback.format_exception(*sys.exc_info())).split('\n') + +def generate_key_bytes(hashclass, salt, key, nbytes): + """ + Given a password, passphrase, or other human-source key, scramble it + through a secure hash into some keyworthy bytes. This specific algorithm + is used for encrypting/decrypting private key files. + + @param hashclass: class from L{Crypto.Hash} that can be used as a secure + hashing function (like C{MD5} or C{SHA}). + @type hashclass: L{Crypto.Hash} + @param salt: data to salt the hash with. + @type salt: string + @param key: human-entered password or passphrase. + @type key: string + @param nbytes: number of bytes to generate. + @type nbytes: int + @return: key data + @rtype: string + """ + keydata = '' + digest = '' + if len(salt) > 8: + salt = salt[:8] + while nbytes > 0: + hash_obj = hashclass.new() + if len(digest) > 0: + hash_obj.update(digest) + hash_obj.update(key) + hash_obj.update(salt) + digest = hash_obj.digest() + size = min(nbytes, len(digest)) + keydata += digest[:size] + nbytes -= size + return keydata + +def load_host_keys(filename): + """ + Read a file of known SSH host keys, in the format used by openssh, and + return a compound dict of C{hostname -> keytype ->} L{PKey }. + The hostname may be an IP address or DNS name. The keytype will be either + C{"ssh-rsa"} or C{"ssh-dss"}. + + This type of file unfortunately doesn't exist on Windows, but on posix, + it will usually be stored in C{os.path.expanduser("~/.ssh/known_hosts")}. + + Since 1.5.3, this is just a wrapper around L{HostKeys}. + + @param filename: name of the file to read host keys from + @type filename: str + @return: dict of host keys, indexed by hostname and then keytype + @rtype: dict(hostname, dict(keytype, L{PKey })) + """ + from paramiko.hostkeys import HostKeys + return HostKeys(filename) + +def parse_ssh_config(file_obj): + """ + Provided only as a backward-compatible wrapper around L{SSHConfig}. + """ + config = SSHConfig() + config.parse(file_obj) + return config + +def lookup_ssh_host_config(hostname, config): + """ + Provided only as a backward-compatible wrapper around L{SSHConfig}. + """ + return config.lookup(hostname) + +def mod_inverse(x, m): + # it's crazy how small python can make this function. + u1, u2, u3 = 1, 0, m + v1, v2, v3 = 0, 1, x + + while v3 > 0: + q = u3 // v3 + u1, v1 = v1, u1 - v1 * q + u2, v2 = v2, u2 - v2 * q + u3, v3 = v3, u3 - v3 * q + if u2 < 0: + u2 += m + return u2 + +_g_thread_ids = {} +_g_thread_counter = 0 +_g_thread_lock = threading.Lock() +def get_thread_id(): + global _g_thread_ids, _g_thread_counter, _g_thread_lock + tid = id(threading.currentThread()) + try: + return _g_thread_ids[tid] + except KeyError: + _g_thread_lock.acquire() + try: + _g_thread_counter += 1 + ret = _g_thread_ids[tid] = _g_thread_counter + finally: + _g_thread_lock.release() + return ret + +def log_to_file(filename, level=DEBUG): + "send paramiko logs to a logfile, if they're not already going somewhere" + l = logging.getLogger("paramiko") + if len(l.handlers) > 0: + return + l.setLevel(level) + f = open(filename, 'w') + lh = logging.StreamHandler(f) + lh.setFormatter(logging.Formatter('%(levelname)-.3s [%(asctime)s.%(msecs)03d] thr=%(_threadid)-3d %(name)s: %(message)s', + '%Y%m%d-%H:%M:%S')) + l.addHandler(lh) + +# make only one filter object, so it doesn't get applied more than once +class PFilter (object): + def filter(self, record): + record._threadid = get_thread_id() + return True +_pfilter = PFilter() + +def get_logger(name): + l = logging.getLogger(name) + l.addFilter(_pfilter) + return l + + +class Counter (object): + """Stateful counter for CTR mode crypto""" + def __init__(self, nbits, initial_value=1L, overflow=0L): + self.blocksize = nbits / 8 + self.overflow = overflow + # start with value - 1 so we don't have to store intermediate values when counting + # could the iv be 0? + if initial_value == 0: + self.value = array.array('c', '\xFF' * self.blocksize) + else: + x = deflate_long(initial_value - 1, add_sign_padding=False) + self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x) + + def __call__(self): + """Increament the counter and return the new value""" + i = self.blocksize - 1 + while i > -1: + c = self.value[i] = chr((ord(self.value[i]) + 1) % 256) + if c != '\x00': + return self.value.tostring() + i -= 1 + # counter reset + x = deflate_long(self.overflow, add_sign_padding=False) + self.value = array.array('c', '\x00' * (self.blocksize - len(x)) + x) + return self.value.tostring() + + def new(cls, nbits, initial_value=1L, overflow=0L): + return cls(nbits, initial_value=initial_value, overflow=overflow) + new = classmethod(new) diff --git a/tools/migration/paramiko/win_pageant.py b/tools/migration/paramiko/win_pageant.py new file mode 100644 index 00000000000..787032b8d08 --- /dev/null +++ b/tools/migration/paramiko/win_pageant.py @@ -0,0 +1,148 @@ +# Copyright (C) 2005 John Arbash-Meinel +# Modified up by: Todd Whiteman +# +# This file is part of paramiko. +# +# Paramiko is free software; you can redistribute it and/or modify it under the +# terms of the GNU Lesser General Public License as published by the Free +# Software Foundation; either version 2.1 of the License, or (at your option) +# any later version. +# +# Paramiko is distrubuted in the hope that it will be useful, but WITHOUT ANY +# WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR +# A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more +# details. +# +# You should have received a copy of the GNU Lesser General Public License +# along with Paramiko; if not, write to the Free Software Foundation, Inc., +# 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA. + +""" +Functions for communicating with Pageant, the basic windows ssh agent program. +""" + +import os +import struct +import tempfile +import mmap +import array + +# if you're on windows, you should have one of these, i guess? +# ctypes is part of standard library since Python 2.5 +_has_win32all = False +_has_ctypes = False +try: + # win32gui is preferred over win32ui to avoid MFC dependencies + import win32gui + _has_win32all = True +except ImportError: + try: + import ctypes + _has_ctypes = True + except ImportError: + pass + + +_AGENT_COPYDATA_ID = 0x804e50ba +_AGENT_MAX_MSGLEN = 8192 +# Note: The WM_COPYDATA value is pulled from win32con, as a workaround +# so we do not have to import this huge library just for this one variable. +win32con_WM_COPYDATA = 74 + + +def _get_pageant_window_object(): + if _has_win32all: + try: + hwnd = win32gui.FindWindow('Pageant', 'Pageant') + return hwnd + except win32gui.error: + pass + elif _has_ctypes: + # Return 0 if there is no Pageant window. + return ctypes.windll.user32.FindWindowA('Pageant', 'Pageant') + return None + + +def can_talk_to_agent(): + """ + Check to see if there is a "Pageant" agent we can talk to. + + This checks both if we have the required libraries (win32all or ctypes) + and if there is a Pageant currently running. + """ + if (_has_win32all or _has_ctypes) and _get_pageant_window_object(): + return True + return False + + +def _query_pageant(msg): + hwnd = _get_pageant_window_object() + if not hwnd: + # Raise a failure to connect exception, pageant isn't running anymore! + return None + + # Write our pageant request string into the file (pageant will read this to determine what to do) + filename = tempfile.mktemp('.pag') + map_filename = os.path.basename(filename) + + f = open(filename, 'w+b') + f.write(msg ) + # Ensure the rest of the file is empty, otherwise pageant will read this + f.write('\0' * (_AGENT_MAX_MSGLEN - len(msg))) + # Create the shared file map that pageant will use to read from + pymap = mmap.mmap(f.fileno(), _AGENT_MAX_MSGLEN, tagname=map_filename, access=mmap.ACCESS_WRITE) + try: + # Create an array buffer containing the mapped filename + char_buffer = array.array("c", map_filename + '\0') + char_buffer_address, char_buffer_size = char_buffer.buffer_info() + # Create a string to use for the SendMessage function call + cds = struct.pack("LLP", _AGENT_COPYDATA_ID, char_buffer_size, char_buffer_address) + + if _has_win32all: + # win32gui.SendMessage should also allow the same pattern as + # ctypes, but let's keep it like this for now... + response = win32gui.SendMessage(hwnd, win32con_WM_COPYDATA, len(cds), cds) + elif _has_ctypes: + _buf = array.array('B', cds) + _addr, _size = _buf.buffer_info() + response = ctypes.windll.user32.SendMessageA(hwnd, win32con_WM_COPYDATA, _size, _addr) + else: + response = 0 + + if response > 0: + datalen = pymap.read(4) + retlen = struct.unpack('>I', datalen)[0] + return datalen + pymap.read(retlen) + return None + finally: + pymap.close() + f.close() + # Remove the file, it was temporary only + os.unlink(filename) + + +class PageantConnection (object): + """ + Mock "connection" to an agent which roughly approximates the behavior of + a unix local-domain socket (as used by Agent). Requests are sent to the + pageant daemon via special Windows magick, and responses are buffered back + for subsequent reads. + """ + + def __init__(self): + self._response = None + + def send(self, data): + self._response = _query_pageant(data) + + def recv(self, n): + if self._response is None: + return '' + ret = self._response[:n] + self._response = self._response[n:] + if self._response == '': + self._response = None + return ret + + def close(self): + pass diff --git a/tools/migration/share_all_lus.sh b/tools/migration/share_all_lus.sh new file mode 100755 index 00000000000..0985a9caf21 --- /dev/null +++ b/tools/migration/share_all_lus.sh @@ -0,0 +1,226 @@ +# !/usr/bin/env bash + + + + # + # Copyright (C) 2011 Cloud.com, Inc. All rights reserved. + # + + +# lu_share.sh -- makes all logical units (LUs) available over iSCSI, for a specified initiator IQN +# OpenSolaris + +usage() { + printf "Usage: %s -i [ -s | -u ]\n" $(basename $0) >&2 +} + +valid_target_name() { # + echo $1 | grep ':lu:' >/dev/null + return $? +} + +target_iqn_from_target_name() { # + echo $1 | cut -d':' -f1,2,3 +} + +hg_from_initiator_iqn() { # + echo $1 + return 0 +} + +lu_name_from_target_name() { # + echo $1 | cut -d':' -f5 +} + +view_entry_from_hg_and_lu_name() { # + local hg=$1 + local lu_name=$2 + local view= + local last_view= + local last_hg= + for w in $(stmfadm list-view -l $lu_name) + do + case $w in + [0-9]*) last_view=$w + ;; + esac + + if [ "$w" == "$hg" ] + then + echo $last_view + return 0 + fi + done + return 1 +} + +create_host_group() { # + local hg=$1 + local i_iqn=$2 + local host_group= + + local result= + result=$(stmfadm create-hg $hg 2>&1) + if [ $? -ne 0 ] + then + echo $result | grep "already exists" > /dev/null + if [ $? -ne 0 ] + then + printf "%s: create-hg %s failed due to %s\n" $(basename $0) $i_iqn $result >&2 + return 11 + fi + fi + + result=$(stmfadm add-hg-member -g $hg $i_iqn 2>&1) + if [ $? -ne 0 ] + then + echo $result | grep "already exists" > /dev/null + if [ $? -ne 0 ] + then + printf "%s: unable to add %s due to %s\n" $(basename $0) $i_iqn $result >&2 + return 12 + fi + fi + return 0 +} + +add_view() { # + local i=1 + local hg=$1 + local lu=$2 + + while [ $i -lt 500 ] + do + local lun=$[ ( $RANDOM % 512 ) ] + local result= + result=$(stmfadm add-view -h $hg -n $lun $lu 2>&1) + if [ $? -eq 0 ] + then + printf "lun %s for luname %s\n" $lun $lu >&2 + #stmfadm list-view -l $lu + #sbdadm list-lu + return 0 + fi + echo $result | grep "view entry exists" > /dev/null + if [ $? -eq 0 ] + then + return 0 + fi + echo $result | grep "LUN already in use" > /dev/null + if [ $? -ne 0 ] + then + echo $result + return 1 + fi + let i=i+1 + done + printf "Unable to add view after lots of tries\n" >&2 + return 1 +} + +add_view_and_hg() { # + local i_iqn=$1 + local lu_name=$2 + local hg="Migration" + local result= + + if ! create_host_group $hg $i_iqn + then + printf "%s: create_host_group failed: %s %s\n" $(basename $0) $i_iqn $lu_name >&2 + return 22 + fi + + if ! add_view $hg $lu_name + then + return 1 + fi + + return 0 +} + +remove_view() { # + local lu_name=$1 + local hg="Migration" + local view=$(view_entry_from_hg_and_lu_name $hg $lu_name) + if [ -n "$view" ] + then + local result= + result=$(stmfadm remove-view -l $lu_name $view 2>&1) + if [ $? -ne 0 ] + then + echo $result | grep "not found" + if [ $? -eq 0 ] + then + return 0 + fi + echo $result | grep "no views found" + if [ $? -eq 0 ] + then + return 0 + fi + printf "Unable to remove view due to: $result\n" >&2 + return 5 + fi + fi + return 0 +} + +# set -x + +iflag= +sflag= +uflag= + +while getopts 'sui:' OPTION +do + case $OPTION in + i) iflag=1 + init_iqn="$OPTARG" + ;; + s) sflag=1 + ;; + u) uflag=1 + ;; + *) usage + exit 2 + ;; + esac +done + +if [ "$sflag$iflag" != "11" -a "$uflag" != "1" ] +then + usage + exit 3 +fi + +lu_names="$(stmfadm list-lu | cut -d":" -f2)" + +for lu_name in $lu_names +do + if [ "$uflag" == "1" ] + then + remove_view $lu_name + if [ $? -gt 0 ] + then + printf "%s: remove_view failed: %s\n" $(basename $0) $lu_name >&2 + exit 1 + fi + else + if [ "$sflag" == "1" ] + then + add_view_and_hg $init_iqn $lu_name + if [ $? -gt 0 ] + then + printf "%s: add_view failed: %s\n" $(basename $0) $lu_name >&2 + exit 1 + fi + fi + fi +done + +if [ "$uflag" == "1" ] +then + stmfadm delete-hg "Migration" +fi + +exit 0 diff --git a/tools/migration/upgrade.properties b/tools/migration/upgrade.properties new file mode 100644 index 00000000000..502d9a19ffb --- /dev/null +++ b/tools/migration/upgrade.properties @@ -0,0 +1,85 @@ +### Users to upgrade + +# Specify the list of user IDs to upgrade as a comma separated list; i.e. 3, 4, 5 +# This is optional; you can also run upgrade.py with a list of user IDs; i.e. "python upgrade.py 3 4 5". +USERS = + + + +### Information about the 1.0.x system + +# The management server IP +SRC_MANAGEMENT_SERVER_IP = + +# The database username and password +SRC_DB_LOGIN = +SRC_DB_PASSWORD = + +# A map between storage host IPs and root passwords +# Ex: 1.2.3.4:password1, 2.3.4.5:password2 +STORAGE_HOST_PASSWORDS = + +# The id of the zone +SRC_ZONE_ID = 1 + + + +### Information about the 2.1.x system + +# The management server IP +DEST_MANAGEMENT_SERVER_IP = localhost + +# The database username and password +DEST_DB_LOGIN = +DEST_DB_PASSWORD = + +# The private IP and root password of one of the XenServers in the 2.1.x system +# Fill this section out only if all of your XenServers have the same root password +DEST_XENSERVER_IP = +DEST_XENSERVER_PASSWORD = + +# A map between XenServer IPs in the 2.1.x system to passwords for each host +# I.e. 1.2.3.4:password1, 2.3.4.5:password2, 3.4.5.6:password3 +# Fill this section out only if your XenServers have different root passwords +DEST_XENSERVER_PASSWORDS = + +# A map between template IDs in the 1.0.x system to guest OS IDs in the 2.1.x system +# Should be in the format: [1.0 template ID]:[2.1 guest OS ID]. Ex: 3:12, 4:14, 5:64 +# To find the ID that corresponds to a guest OS, refer to the output of the API command: http://localhost:8096/client/api/?command=listOsTypes +GUEST_OS_MAP = + +# The id of the ISO you registered +DEST_ISO_ID = 201 + +# The id of the zone +DEST_ZONE_ID = 1 + +# The id of the default CentOS template +DEST_TEMPLATE_ID = 2 + +# The id of the default service offering +DEST_SERVICE_OFFERING_ID = 3 + +# The id of the default disk offering +DEST_DISK_OFFERING_ID = 5 + +# The id of the guest OS category that corresponds to Windows +DEST_WINDOWS_GUEST_OS_CATEGORY_ID = 6 + + + +### Misc. variables + +# The location of the log file +LOG_FILE = ./migrationLog + +# The location of the migrated users file +MIGRATED_ACCOUNTS_FILE = ./migratedAccounts + +# The number of retries for async API commands +ASYNC_RETRIES = 20 + +# The time to pause between retries for async API commands +ASYNC_SLEEP_TIME = 30 + + diff --git a/tools/migration/upgrade.py b/tools/migration/upgrade.py new file mode 100644 index 00000000000..529f69e103f --- /dev/null +++ b/tools/migration/upgrade.py @@ -0,0 +1,1599 @@ + + # + # Copyright (C) 2011 Cloud.com, Inc. All rights reserved. + # + + +from xml.dom.minidom import * +import urllib2 +import MySQLdb +from XenAPI import * +import time +import datetime +import os +import sys +import paramiko +import traceback + +### Logging functions + +def getTimestamp(): + return datetime.datetime.now().strftime("%m/%d/%Y | %I:%M:%S %p") + +def nonExitingLogDecorator(entryMessage): + return genDecoratorFn(entryMessage, False, False) + +def basicLogDecorator(entryMessage): + return genDecoratorFn(entryMessage, True, True) + +def verboseLogDecorator(entryMessage): + return genDecoratorFn(entryMessage, False, True) + +def genDecoratorFn(entryMessage, printToScreen, exitOnError): + def wrap(f): + def g(*args): + writeToLog("", printToScreen) + writeToLog(getTimestamp() + " | " + entryMessage, printToScreen) + + if (len(args) > 0): + argString = "" + for i in range(len(args)): + arg = args[i] + argString += str(arg) + if (i != len(args) - 1): + argString += ", " + writeToLog("args: " + argString, False) + + returnValue = None + try: + returnValue = f(*args) + except SystemExit: + sys.exit(1) + except Exception, e: + if (exitOnError): + handleError(str(e), True) + traceback.print_exc(file = GLOBALS["LOG_FILE"]) + sys.exit(1) + else: + return False + + if (returnValue in (None, False)): + if (exitOnError): + writeToLog(str(f) + " returned " + str(returnValue), False) + handleError(None, True) + sys.exit(1) + else: + return False + else: + return returnValue + return g + return wrap + +def handleError(msg, printToScreen): + writeToLog(getTimestamp() + " | " + "Failed to complete this step.", printToScreen) + if (msg != None): + writeToLog("Details: " + msg, printToScreen) + +def writeToLog(message, printToScreen): + logFile = GLOBALS.get("LOG_FILE") + if (logFile != None): + logFile.write(message) + logFile.write("\n") + if (printToScreen): + print message + +### Util classes + +class System: + def __init__(self, managementServerIp, asyncApi, xenServerIp, xenServerPassword, xenServerPasswordMap, dbName, dbLogin, dbPassword, zoneId, templateId, isoId, defaultServiceOfferingId, defaultDiskOfferingId): + self.zoneId = zoneId + self.templateId = templateId + self.isoId = isoId + self.defaultServiceOfferingId = defaultServiceOfferingId + self.defaultDiskOfferingId = defaultDiskOfferingId + self.api = System.API(managementServerIp, asyncApi) + if (dbPassword == None): + dbPassword = "" + self.db = System.DB(managementServerIp, dbName, dbLogin, dbPassword) + + self.xenServerIp = None + self.xenapi = None + if (xenServerIp != None or xenServerPasswordMap != None): + self.findXenApi(xenServerIp, xenServerPassword, xenServerPasswordMap) + self.controlDomainRef = self.findControlDomainRef() + self.sshConn = paramiko.SSHClient() + self.sshConn.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + self.sshConn.connect(self.xenServerIp, username = "root", password = self.xenServerPassword) + + @verboseLogDecorator("Finding the XenAPI connection...") + def findXenApi(self, xenServerIp, xenServerPassword, xenServerPasswordMap): + if (xenServerPasswordMap != None): + xenServerIp = xenServerPasswordMap.keys()[0] + xenServerPassword = xenServerPasswordMap[xenServerIp] + + masterXenServerIp = xenServerIp + masterXenServerPassword = xenServerPassword + session = None + try: + session = Session("http://" + masterXenServerIp + "/var/run/xend/xen-api.sock") + session.login_with_password("root", masterXenServerPassword) + except Exception, e: + if (e.details != None and len(e.details) == 2 and e.details[0] == "HOST_IS_SLAVE"): + masterXenServerIp = e.details[1] + if (xenServerPasswordMap != None): + masterXenServerPassword = xenServerPasswordMap[masterXenServerIp] + else: + masterXenServerPassword = xenServerPassword + session = Session("http://" + masterXenServerIp + "/var/run/xend/xen-api.sock") + session.login_with_password("root", masterXenServerPassword) + else: + raise + + self.xenServerIp = masterXenServerIp + self.xenServerPassword = masterXenServerPassword + self.xenapi = session.xenapi + return True + + @verboseLogDecorator("Finding the control domain for the dest system...") + def findControlDomainRef(self): + # Find the host ref for this system + hostRefs = self.xenapi.host.get_all() + systemHostRef = None + for hostRef in hostRefs: + address = self.xenapi.host.get_address(hostRef) + if (address == self.xenServerIp): + systemHostRef = hostRef + break + + if (systemHostRef == None): + raise Exception("Failed to find the XenServer host ref for " + str(self)) + + # Find the control domain ref that corresponds to this host + vmRefs = self.xenapi.VM.get_all() + for vmRef in vmRefs: + if (self.xenapi.VM.get_is_control_domain(vmRef)): + controlDomainHostRef = self.xenapi.VM.get_resident_on(vmRef) + if (controlDomainHostRef == systemHostRef): + return vmRef + return None + + @nonExitingLogDecorator("Running ssh command...") + def runSshCommand(self, command): + stdin, stdout, stderr = self.sshConn.exec_command(command) + return (stdin.channel.recv_exit_status() == 0) + + def updateIsoPermissions(self): + setParams = {"public":"1"} + whereParams = {"id":self.isoId} + return self.db.updateDbValues("vm_template", setParams, whereParams) + + def __str__(self): + description = "Management Server: %s" % (self.api.ip) + if (self.xenServerIp != None): + description += " | XenServer: %s" %(self.xenServerIp) + return description + + class API: + # Vars: ip + def __init__(self, ip, asyncApi): + self.ip = ip + self.asyncApi = asyncApi + + # Runs a synchronous API command and returns the ID of the created object, or success/failure + def runSyncApiCommand(self, command, params, objectName): + requestURL = self.buildRequestUrl(command, params) + xmlText = urllib2.urlopen(requestURL).read() + if (objectName == None): + responseName = (command + "response").lower() + return (System.API.getTagValue(xmlText, responseName, "success") == "true") + else: + return System.API.getTagValue(xmlText, objectName, "id") + + # Runs a asynchronous API command and returns the ID of the created object, or success/failure + def runAsyncApiCommand(self, command, params, objectName): + requestURL = self.buildRequestUrl(command, params) + xmlText = urllib2.urlopen(requestURL).read() + responseName = (command + "response").lower() + jobId = System.API.getTagValue(xmlText, responseName, "jobid") + objectId = System.API.getTagValue(xmlText, responseName, objectName + "id") + params = dict() + params["jobId"] = jobId + requestURL = self.buildRequestUrl("queryAsyncJobResult", params) + retries = int(GLOBALS["ASYNC_RETRIES"]) + jobResult = None + while (retries > 0): + time.sleep(float(GLOBALS["ASYNC_SLEEP_TIME"])) + xmlText = urllib2.urlopen(requestURL).read() + if (System.API.getTagValue(xmlText, "queryasyncjobresultresponse", "jobstatus") == "1"): + if (objectId != None): + return objectId + else: + return True + jobResult = System.API.getTagValue(xmlText, "queryasyncjobresultresponse", "jobresult") + retries -= 1 + raise Exception(jobResult) + + def buildRequestUrl(self, command, params): + requestURL = "http://" + self.ip + ":8096/client/api/?command=" + command + for paramKey in params.keys(): + paramVal = params.get(paramKey) + requestURL += "&" + str(paramKey) + "=" + str(paramVal) + return requestURL + + @staticmethod + def getTagValue(xmlText, objectName, tagName): + xmlDoc = parseString(xmlText) + response = xmlDoc.getElementsByTagName(objectName)[0] + for x in response.childNodes: + if (x.tagName == tagName): + return " ".join(z.wholeText for z in x.childNodes) + return None + + @staticmethod + def printApiValues(listOfAPIObjects): + for apiObject in listOfAPIObjects: + for key in apiObject.keys(): + print key + ":" + apiObject.get(key) + print " " + + class DB: + # Vars: conn + def __init__(self, ip, dbName, dbLogin, dbPassword): + self.conn = MySQLdb.connect(host = ip, user = dbLogin, passwd = dbPassword, db = dbName) + self.conn.autocommit(True) + + def getTable(self, table): + cursor = self.conn.cursor() + sql = "SELECT * from " + table + cursor.execute(sql) + for row in cursor.fetchall(): + print row + cursor.close() + + # Returns a list of hashtables, each one representing a row and using column names as keys. + def getDbValues(self, table, columns, whereParams): + values = [] + cursor = self.conn.cursor() + columnsText = ",".join(columns) + sql = "SELECT " + columnsText + " FROM " + table + if (len(whereParams) > 0): + sql += System.DB.buildSqlWhereClause(whereParams) + cursor.execute(sql) + rows = cursor.fetchall() + cursor.close() + for row in rows: + value = dict() + for i in range(len(columns)): + val = str(row[i]) + value[columns[i]] = val + values.append(value) + return values + + def updateDbValues(self, table, setParams, whereParams): + setClause = System.DB.buildSqlSetClause(setParams) + sql = "UPDATE " + table + setClause + if (len(whereParams) > 0): + sql += System.DB.buildSqlWhereClause(whereParams) + + cursor = self.conn.cursor() + cursor.execute(sql) + self.conn.commit() + cursor.close() + return True + + def insertIntoDb(self, table, setParams): + existingRecords = self.getDbValues(table, ["id"], setParams) + if (len(existingRecords) > 0): + return existingRecords[0]["id"] + columns = setParams.keys() + values = [] + for column in columns: + values.append(setParams[column]) + sql = "INSERT INTO " + table + System.DB.buildSqlInsertClause(columns, values) + cursor = self.conn.cursor() + cursor.execute(sql) + insertId = self.conn.insert_id() + self.conn.commit() + cursor.close() + return insertId + + @staticmethod + def buildSqlInsertClause(columns, values): + columnsSql = " (" + valuesSql = " VALUES (" + for i in range(len(columns)): + if (str(values[i]) == "null"): + continue + + columnsSql += columns[i] + valuesSql += "'" + str(values[i]) + "'" + if (i != (len(columns) - 1)): + columnsSql += ", " + valuesSql += ", " + else: + columnsSql += ")" + valuesSql += ")" + return columnsSql + valuesSql + + @staticmethod + def buildSqlWhereClause(params): + sql = " WHERE " + keys = params.keys() + for i in range(len(keys)): + key = str(keys[i]) + val = str(params[key]) + if ("like" in val): + val = val.split(":")[1] + sql += key + " like '" + val + "'" + elif ("neq" in val): + val = val.split(":")[1] + sql += key + " != '" + val + "'" + elif (val == "null" or val == "not null"): + sql += key + " IS " + val + else: + sql += key + " = '" + val + "'" + if (i != (len(keys) - 1)): + sql += " AND " + return sql + + @staticmethod + def buildSqlSetClause(params): + sql = " SET " + keys = params.keys() + for i in range(len(keys)): + key = keys[i] + val = params[key] + sql += key + " = " + if (val == "null"): + sql += "null" + else: + sql += "'" + val + "'" + if (i != (len(keys) - 1)): + sql += ", " + return sql + +### Data classes + +class User: + # Vars: system, id, username, password, accountId, firstname, lastname, email, accountType, accountName, domainId + + def __init__(self, system, userId, username, password, accountId, firstname, lastname, email, accountType, accountName, domainId): + self.system = system + self.id = userId + self.username = username + self.password = password + self.accountId = accountId + self.firstname = firstname + self.lastname = lastname + self.email = email + self.accountType = accountType + self.accountName = accountName + self.domainId = domainId + + def __str__(self): + return "(User: %s | %s)" % (self.username, self.system) + + def alreadyMigrated(self): + f = open(GLOBALS["MIGRATED_ACCOUNTS_FILE"], "a+") + migratedUsersCsv = f.read() + f.close() + migratedUsersEntries = migratedUsersCsv.split(",") + for migratedUsersEntry in migratedUsersEntries: + if (migratedUsersEntry.strip() == self.accountId): + return True + return False + + def tagAsMigrated(self): + if (not self.alreadyMigrated()): + f = open(GLOBALS["MIGRATED_ACCOUNTS_FILE"], "a") + f.write(self.accountId + ",") + f.close() + return True + + @staticmethod + def getByName(system, username): + columns = ["id"] + users = system.db.getDbValues("user", columns, {"username":username, "removed":"null"}) + if (len(users) > 0): + return User.get(system, users[0]["id"]) + else: + return None + + @staticmethod + def get(system, userId): + columns = ["id", "username", "password", "account_id", "firstname", "lastname", "email"] + users = system.db.getDbValues("user", columns, {"id":userId}) + if (len(users) == 0): + return None + user = users[0] + columns = ["type", "account_name", "domain_id"] + account = system.db.getDbValues("account", columns, {"id":user["account_id"]})[0] + return User(system, userId, user["username"], user["password"], user["account_id"], user["firstname"], user["lastname"], user["email"], account["type"], account["account_name"], account["domain_id"]) + + @staticmethod + def getDomain(system, domainId): + columns = ["id", "parent", "name", "owner"] + return system.db.getDbValues("domain", columns, {"id":domainId})[0] + + @staticmethod + def getDomainByName(system, domainName): + columns = ["id"] + domains = system.db.getDbValues("domain", columns, {"name":domainName}) + if (len(domains) > 0): + return domains[0] + else: + return None + + @staticmethod + def createDomain(srcSystem, destSystem, srcDomainId): + # Get the source domain + srcDomain = User.getDomain(srcSystem, srcDomainId) + + # If a domain with the same name exists in the dest system, return its ID + destDomain = User.getDomainByName(destSystem, srcDomain["name"]) + if (destDomain != None): + return destDomain["id"] + else: + # Otherwise, create a new domain in the dest system with the same name, and return its ID + # If the src domain has parent domains, we need to create these first + parentId = None + if (srcDomain["parent"] != "null"): + parentId = User.createDomain(srcSystem, destSystem, srcDomain["parent"]) + params = dict() + params["name"] = srcDomain["name"] + if (parentId != None): + params["parent"] = parentId + newDomainId = destSystem.api.runSyncApiCommand("createDomain", params, "domain") + if (newDomainId == None): + raise Exception("Failed to create domain " + srcDomain["name"]) + else: + # Set the owner for the new domain + srcOwner = User.get(srcSystem, srcDomain["owner"]) + destOwner = User.create(destSystem, srcOwner) + if (not destSystem.db.updateDbValues("domain", {"owner":destOwner.id}, {"id":newDomainId})): + raise Exception("Failed to update the owner for domain " + srcDomain["name"]) + return newDomainId + + @staticmethod + @basicLogDecorator("Creating new user...") + def create(system, srcUser): + user = User.getByName(system, srcUser.username) + if (user != None): + return user + else: + # If the user's domain doesn't exist in the system, create it + domainId = User.createDomain(srcUser.system, system, srcUser.domainId) + + params = dict() + params["username"] = srcUser.username + params["password"] = "temp" + params["firstname"] = srcUser.firstname + params["lastname"] = srcUser.lastname + params["email"] = srcUser.email + accountType = srcUser.accountType + if (accountType == "2"): + accountType = "0" + params["accounttype"] = accountType + params["account"] = srcUser.accountName + params["domainid"] = domainId + newUserId = system.api.runSyncApiCommand("createUser", params, "user") + if (newUserId != None): + if (system.db.updateDbValues("user", {"password":srcUser.password}, {"id":newUserId})): + return User.get(system, newUserId) + else: + return None + else: + return None + +class ServiceOffering: + # Vars: system, id, numCpus, speed, memory, disk + + def __init__(self, system, offeringId, numCpus, speed, memory, disk): + self.system = system + self.id = offeringId + self.numCpus = numCpus + self.speed = speed + self.memory = memory + self.disk = disk + + def __str__(self): + return "ServiceOffering: %s | id: %s | numCpus: %s | speed: %s | memory: %s | disk: %s" % (self.system, self.id, self.numCpus, self.speed, self.memory, self.disk) + + @staticmethod + def getCorrespondingServiceOffering(srcServiceOfferingId): + srcServiceOffering = ServiceOffering.getSrcSystemServiceOfferingById(srcServiceOfferingId) + destServiceOffering = ServiceOffering.getDestSystemServiceOffering(srcServiceOffering.numCpus, srcServiceOffering.speed, srcServiceOffering.memory) + return destServiceOffering + + @staticmethod + def getDestSystemServiceOffering(numCpus, speed, memory): + serviceOfferings = GLOBALS["DEST_SYSTEM"].db.getDbValues("service_offering", ["id"], {"cpu":numCpus, "speed":speed, "ram_size":memory, "guest_ip_type":"Virtualized"}) + if (len(serviceOfferings) > 0): + return ServiceOffering(GLOBALS["DEST_SYSTEM"], serviceOfferings[0]["id"], numCpus, speed, memory, None) + else: + return None + + @staticmethod + def getSrcSystemServiceOfferingByVmId(vmId): + serviceOfferingId = GLOBALS["SRC_SYSTEM"].db.getDbValues("user_vm", ["service_offering_id"], {"id":vmId})[0]["service_offering_id"] + return getSrcSystemServiceOfferingById(serviceOfferingId) + + @staticmethod + def getSrcSystemServiceOfferingById(serviceOfferingId): + columns = ["id", "cpu", "speed", "ram_size", "disk"] + serviceOfferings = GLOBALS["SRC_SYSTEM"].db.getDbValues("service_offering", columns, {"id":serviceOfferingId}) + if (len(serviceOfferings) > 0): + offering = serviceOfferings[0] + return ServiceOffering(GLOBALS["SRC_SYSTEM"], offering["id"], offering["cpu"], offering["speed"], offering["ram_size"], offering["disk"]) + else: + return None + + @staticmethod + def getSrcSystemServiceOfferings(): + serviceOfferings = [] + columns = ["id", "cpu", "speed", "ram_size", "disk"] + srcServiceOfferings = GLOBALS["SRC_SYSTEM"].db.getDbValues("service_offering", columns, {}) + for offering in srcServiceOfferings: + serviceOfferings.append(ServiceOffering(GLOBALS["SRC_SYSTEM"], offering["id"], offering["cpu"], offering["speed"], offering["ram_size"], offering["disk"])) + return serviceOfferings + +class DiskOffering: + # Vars: id, size + + def __init__(self, diskOfferingId, size): + self.id = diskOfferingId + self.size = size + + def __str__(self): + return "Disk Offering: size = %s" % (self.size) + + @staticmethod + def getDestDiskOffering(size): + columns = ["id"] + diskOfferingRows = GLOBALS["DEST_SYSTEM"].db.getDbValues("disk_offering", ["id"], {"disk_size":size, "type":"Disk"}) + if (len(diskOfferingRows) > 0): + return DiskOffering(diskOfferingRows[0]["id"], size) + else: + size = ((int(size) / 1024) + 1) * 1024 + diskOfferingRows = GLOBALS["DEST_SYSTEM"].db.getDbValues("disk_offering", ["id"], {"disk_size":size, "type":"Disk"}) + if (len(diskOfferingRows) > 0): + return DiskOffering(diskOfferingRows[0]["id"], size) + else: + return None + + @staticmethod + def getCorrespondingDiskOffering(srcServiceOfferingId): + srcServiceOffering = ServiceOffering.getSrcSystemServiceOfferingById(srcServiceOfferingId) + diskOffering = DiskOffering.getDestDiskOffering(srcServiceOffering.disk) + return diskOffering + + +class VM: + # Vars: id, system, user, serviceOfferingId, name, templateId, guestOsId + + def __init__(self, vmId, user, serviceOfferingId, guestOsId, guestOsCategoryId): + self.id = vmId + self.system = user.system + self.user = user + self.serviceOfferingId = serviceOfferingId + self.guestOsId = guestOsId + self.guestOsCategoryId = guestOsCategoryId + + def __str__(self): + return "UserVM: id = %s | username = %s | system = %s" % (self.id, self.user.username, self.system) + + def getName(self): + columns = ["name"] + return self.system.db.getDbValues("vm_instance", columns, {"id":self.id})[0]["name"] + + @basicLogDecorator("Deploying a temporary VM...") + def deployTemp(self): + params = {"account":self.user.accountName, "domainid":self.user.domainId, "zoneId":self.system.zoneId, "serviceofferingid":self.system.defaultServiceOfferingId, "templateid":self.system.templateId} + vmId = self.system.api.runAsyncApiCommand("deployVirtualMachine", params, "virtualmachine") + if (vmId in (None, False)): + return False + self.id = vmId + self.name = self.getName() + "-temp-vm" + success = self.system.db.updateDbValues("vm_instance", {"name":self.name}, {"id":self.id}) + return success + + @basicLogDecorator("Deploying a new VM for the user ...") + def deploy(self, srcVm): + params = dict() + params["account"] = self.user.accountName + params["domainid"] = self.user.domainId + params["zoneid"] = self.system.zoneId + params["serviceofferingid"] = self.serviceOfferingId + params["templateid"] = self.system.isoId + params["diskofferingid"] = self.system.defaultDiskOfferingId + vmId = self.system.api.runAsyncApiCommand("deployVirtualMachine", params, "virtualmachine") + if (vmId in (None, False)): + return None + self.id = vmId + self.name = self.getName() + "-" + str(srcVm.id) + " (" + VM.getGuestOsName(GLOBALS["DEST_SYSTEM"], self.guestOsId) + ")" + success = self.system.db.updateDbValues("vm_instance", {"name":self.name}, {"id":self.id}) + return success + + @verboseLogDecorator("Updating the guest OS ID for the VM...") + def updateGuestOsId(self): + setParams = {"guest_os_id":self.guestOsId} + whereParams = {"id":self.id} + return self.system.db.updateDbValues("vm_instance", setParams, whereParams) + + @basicLogDecorator("Starting VM...") + def start(self): + params = {"id":self.id} + return self.system.api.runAsyncApiCommand("startVirtualMachine", params, "virtualmachine") + + @basicLogDecorator("Stopping VM...") + def stop(self): + params = {"id":self.id} + if (self.system.api.asyncApi): + return self.system.api.runAsyncApiCommand("stopVirtualMachine", params, "virtualmachine") + else: + return self.system.api.runSyncApiCommand("stopVirtualMachine", params, None) + + @basicLogDecorator("Destroying temporary VM...") + def destroy(self): + params = {"id":self.id} + return self.system.api.runAsyncApiCommand("destroyVirtualMachine", params, "virtualmachine") + + @verboseLogDecorator("Detaching ISO from VM...") + def detachIso(self): + isoId = self.system.db.getDbValues("vm_instance", ["iso_id"], {"id":self.id})[0]["iso_id"] + if (isoId == "None"): + return True + + params = {"virtualmachineid":self.id} + return self.system.api.runAsyncApiCommand("detachIso", params, "virtualmachine") + + def isLinuxVm(self): + return (self.guestOsCategoryId != str(GLOBALS["DEST_WINDOWS_GUEST_OS_CATEGORY_ID"])) + + @staticmethod + def getGuestOsName(system, guestOsId): + columns = ["id", "display_name"] + guestOsList = system.db.getDbValues("guest_os", columns, {"id":guestOsId}) + if (len(guestOsList) > 0): + return guestOsList[0]["display_name"] + else: + return None + + @staticmethod + def getGuestOsCategoryId(system, guestOsId): + columns = ["category_id"] + return system.db.getDbValues("guest_os", columns, {"id":guestOsId})[0]["category_id"] + + @staticmethod + def getVmId(system, accountId, guestIpAddress): + userVms = system.db.getDbValues("user_vm", ["id"], {"account_id":accountId, "guest_ip_address":guestIpAddress}) + for userVm in userVms: + vmInstances = system.db.getDbValues("vm_instance", ["id"], {"id":userVm["id"], "removed":"null", "state":"neq:Destroyed"}) + if (len(vmInstances) > 0): + return vmInstances[0]["id"] + return None + + @staticmethod + def getVms(user): + system = user.system + vms = [] + columns = ["id", "service_offering_id"] + userVmRows = system.db.getDbValues("user_vm", columns, {"account_id":user.accountId}) + for userVmRow in userVmRows: + vmInstanceRow = system.db.getDbValues("vm_instance", ["vm_template_id", "removed"], {"id":userVmRow["id"]})[0] + if (vmInstanceRow["removed"] != "None"): + continue + + # Determine the service offering ID + serviceOfferingId = userVmRow["service_offering_id"] + + # Determine the new guest OS id and category id + templateId = vmInstanceRow["vm_template_id"] + guestOsId = GLOBALS["GUEST_OS_MAP"][templateId] + guestOsCategoryId = VM.getGuestOsCategoryId(GLOBALS["DEST_SYSTEM"], guestOsId) + + vms.append(VM(userVmRow["id"], user, serviceOfferingId, guestOsId, guestOsCategoryId)) + return vms + + @staticmethod + def getCorrespondingVm(destUser, srcVm): + system = destUser.system + columns = ["id", "guest_os_id"] + correspondingVms = system.db.getDbValues("vm_instance", columns, {"name":"like:%-" + srcVm.id + " (%"}) + if (len(correspondingVms) > 0): + correspondingVm = correspondingVms[0] + newServiceOffering = ServiceOffering.getCorrespondingServiceOffering(srcVm.serviceOfferingId) + vmId = correspondingVm["id"] + guestOsId = srcVm.guestOsId + guestOsCategoryId = VM.getGuestOsCategoryId(GLOBALS["DEST_SYSTEM"], guestOsId) + return VM(vmId, destUser, newServiceOffering.id, guestOsId, guestOsCategoryId) + else: + return None + + @staticmethod + def getTempVm(user): + system = user.system + columns = ["id"] + tempVms = system.db.getDbValues("vm_instance", columns, {"removed":"null", "state":"neq:Destroyed", "name":"like:%-temp-vm"}) + if (len(tempVms) > 0): + tempVm = tempVms[0] + return VM(tempVm["id"], user, None, None, None) + else: + return None + + @staticmethod + def getTemplate(system, templateId): + columns = ["id", "name", "format"] + templates = system.db.getDbValues("vm_template", columns, {"id":templateId}) + if (len(templates) > 0): + return templates[0] + else: + return None + + @staticmethod + def getTemplateIds(system): + templateIds = [] + columns = ["id", "unique_name"] + templates = system.db.getDbValues("vm_template", columns, {}) + for template in templates: + if (template["unique_name"] == "routing"): + continue + templateIds.append(template["id"]) + return templateIds + + @staticmethod + @basicLogDecorator("Migrating the user's VMs...") + def migrateVirtualMachines(srcUser, destUser): + # Maintain a map of src system VM IDs to dest system VM ids + vmIdMap = dict() + + # Get a list of user VMs for the source user + srcVms = VM.getVms(srcUser) + + for srcVm in srcVms: + # Try to find an existing VM in the dest system that corresponds to the VM in the src system + destVm = VM.getCorrespondingVm(destUser, srcVm) + + # If there is no corresponding VM, deploy a new VM in the dest system + if (destVm == None): + destVm = VM(None, destUser, srcVm.serviceOfferingId, srcVm.guestOsId, srcVm.guestOsCategoryId) + destVm.deploy(srcVm) + + # Add a mapping between the src VM and the dest VM + vmIdMap[srcVm.id] = destVm.id + + # Get a list of volumes for the source VM + srcVolumes = Volume.getSrcVolumes(srcUser, srcVm) + + # If these volumes have already been copied to the dest system, skip migration for this VM + vmAlreadyMigrated = True + for srcVolume in srcVolumes: + destVolume = Volume.getDestVolume(None, destVm, srcVolume.type) + if (destVolume == None): + vmAlreadyMigrated = False + break + elif (srcVolume.id != destVolume.name.split("-")[-1]): + vmAlreadyMigrated = False + break + + if (vmAlreadyMigrated): + writeToLog("\n" + str(srcVm) + " has already been migrated.", True) + continue + else: + writeToLog("\nMigrating volumes for source VM: " + str(srcVm), True) + + # Stop the dest VM + destVm.stop() + + # Stop the source VM + srcVm.stop() + + for srcVolume in srcVolumes: + destVolume = None + if (srcVolume.type == "DATADISK"): + destVolume = Volume.getDestVolume(None, destVm, "DATADISK") + if (destVolume == None): + diskOffering = DiskOffering.getCorrespondingDiskOffering(srcVm.serviceOfferingId) + destVolume = Volume(GLOBALS["DEST_SYSTEM"], None, str(destVm.id) + "-DATADISK", None, None, None, "DATA", diskOffering.id) + destVolume.createAndAttach(destVm) + else: + destVolume = Volume.getDestVolume(None, destVm, "ROOT") + + # If the dest volume is already tagged with the source volume's ID, we don't need to do a copy + if (srcVolume.id == destVolume.name.split("-")[-1]): + writeToLog(str(srcVolume) + " has already been migrated.") + continue + + # If the srcVolume's iSCSI SR isn't created on the XenServer, create it + srcHost = Host.getHost(GLOBALS["SRC_SYSTEM"], srcVolume.hostId) + srcSR = SR.getExistingSrcSr(srcHost.ip, srcHost.iqn) + if (srcSR == None): + srcSR = SR(GLOBALS["DEST_SYSTEM"], srcHost.ip, srcHost.iqn, None) + srcSR.create() + else: + writeToLog("Found existing SR: " + str(srcSR), False) + + # Find the VDI corresponding to the src volume + srcVdi = VDI(srcSR, srcVolume, None) + + # Find the SR corresponding to the dest storage pool + destStoragePool = StoragePool.getStoragePool(GLOBALS["DEST_SYSTEM"], destVolume.poolId) + destSR = SR(GLOBALS["DEST_SYSTEM"], None, None, destStoragePool.uuid) + destSR.find() + + # Copy the src VDI to the dest SR + copiedVdiUuid = srcVdi.copy(destSR) + + # If this is the rootdisk of a Linux VM, change the disk name + destVdi = VDI(destSR, destVolume, copiedVdiUuid) + if (destVolume.type == "ROOT" and srcVm.isLinuxVm()): + destVdi.changeBootableDeviceName() + + # Destroy the VM's old VDI + oldDestVdi = VDI(destSR, destVolume, destVolume.path) + oldDestVdi.destroy() + + # Update the destVolume's database record to have the UUID of the copied VDI, the virtual size of the copied VDI, and the ID of the source volume + destVolume.update(copiedVdiUuid, destVdi.getVirtualSize(), destVolume.name + "-" + srcVolume.id) + + # Detach the dest VM's ISO + destVm.detachIso() + + # Update the guest OS ID for the VM + destVm.updateGuestOsId() + + # Start the dest VM + destVm.start() + + return vmIdMap + +class Volume: + # vars: system, id, hostId, poolId, path, zoneId, iscsiName, type, diskOfferingId + + def __init__(self, system, volumeId, name, poolOrHostId, path, iscsiName, volumeType, diskOfferingId): + self.system = system + self.id = volumeId + self.name = name + if (iscsiName == None): + self.poolId = poolOrHostId + self.iscsiName = None + else: + self.hostId = poolOrHostId + self.iscsiName = iscsiName + self.path = path + self.type = volumeType + self.diskOfferingId = diskOfferingId + + def __str__(self): + return "Volume: %s | type: %s | path: %s" % (self.system, self.type, self.path) + + @basicLogDecorator("Creating a new volume and attaching it to the user's VM...") + def createAndAttach(self, destVm): + params = dict() + params["account"] = destVm.user.accountName + params["domainid"] = destVm.user.domainId + params["name"] = self.name + params["zoneid"] = self.system.zoneId + params["diskofferingid"] = self.diskOfferingId + volumeId = self.system.api.runAsyncApiCommand("createVolume", params, "volume") + if (volumeId in (None, False)): + return False + self.id = volumeId + params = dict() + params["id"] = volumeId + params["virtualmachineid"] = destVm.id + success = self.system.api.runAsyncApiCommand("attachVolume", params, "volume") + if (success in (None, False)): + return False + newVolume = Volume.getDestVolume(volumeId, None, None) + self.poolId = newVolume.poolId + self.path = newVolume.path + return True + + def update(self, volumeUuid, volumeSize, name): + setParams = {"path":volumeUuid, "size":volumeSize, "name":name} + whereParams = {"id":self.id} + return self.system.db.updateDbValues("volumes", setParams, whereParams) + + @staticmethod + def getSrcVolumes(user, vm): + volumes = [] + columns = ["id", "name", "host_id", "path", "iscsi_name", "volume_type", "offering_id"] + volumeRows = GLOBALS["SRC_SYSTEM"].db.getDbValues("volumes", columns, {"account_id":user.accountId, "instance_id":vm.id, "removed":"null"}) + for volumeRow in volumeRows: + volumes.append(Volume(GLOBALS["SRC_SYSTEM"], volumeRow["id"], volumeRow["name"], volumeRow["host_id"], volumeRow["path"], volumeRow["iscsi_name"], volumeRow["volume_type"], volumeRow["offering_id"])) + return volumes + + @staticmethod + def getDestVolume(volumeId, vm, volumeType): + columns = ["id", "name", "pool_id", "path", "disk_offering_id"] + whereParams = None + if (volumeId != None): + whereParams = {"id":volumeId} + else: + whereParams = {"instance_id":vm.id, "volume_type":volumeType} + volumeRows = GLOBALS["DEST_SYSTEM"].db.getDbValues("volumes", columns, whereParams) + if (len(volumeRows) > 0): + volumeRow = volumeRows[0] + return Volume(GLOBALS["DEST_SYSTEM"], volumeRow["id"], volumeRow["name"], volumeRow["pool_id"], volumeRow["path"], None, volumeType, volumeRow["disk_offering_id"]) + else: + return None + +class DomainRouter: + # Vars: id, system, user + + def __init__(self, user): + self.system = user.system + self.user = user + self.id = self.getId() + + def __str__(self): + return "DomainRouter: %s" % (self.user) + + @basicLogDecorator("Stopping user's router...") + def stop(self): + if (self.id == None): + raise Exception("Could not find router for " + str(self.user)) + params = {"id":self.id} + if (self.system.api.asyncApi): + return self.system.api.runAsyncApiCommand("stopRouter", params, "router") + else: + return self.system.api.runSyncApiCommand("stopRouter", params, None) + + @basicLogDecorator("Starting user's router...") + def start(self): + if (self.id == None): + return False + params = {"id":self.id} + if (self.system.api.asyncApi): + return self.system.api.runAsyncApiCommand("startRouter", params, "router") + else: + return self.system.api.runSyncApiCommand("startRouter", params, None) + + @basicLogDecorator("Rebooting user's router...") + def reboot(self): + if (self.id == None): + return False + params = {"id":self.id} + return self.system.api.runAsyncApiCommand("rebootRouter", params, "router") + + def getId(self): + routers = self.system.db.getDbValues("domain_router", ["id"], {"account_id":self.user.accountId}) + if (len(routers) > 0): + return routers[0]["id"] + else: + return None + +class PublicIp: + # Vars: system, user, address, zoneId, sourceNat, allocated + + def __init__(self, system, user, address, zoneId, sourceNat, allocated): + self.system = system + self.user = user + self.address = address + self.zoneId = zoneId + self.sourceNat = sourceNat + self.allocated = allocated + + def __str__(self): + return self.address + + def __repr__(self): + return self.address + + def allocate(self): + setParams = {"account_id":self.user.accountId, + "domain_id":self.user.domainId, + "source_nat":self.sourceNat, + "allocated":self.allocated} + whereParams = {"public_ip_address":self.address, + "data_center_id":self.zoneId} + return self.system.db.updateDbValues("user_ip_address", setParams, whereParams) + + @staticmethod + @basicLogDecorator("Clearing existing public IPs...") + def clearPublicIps(user): + system = user.system + setParams = {"account_id":"null", "domain_id":"null", "source_nat":"0", "allocated":"null"} + whereParams = {"account_id":user.accountId} + return system.db.updateDbValues("user_ip_address", setParams, whereParams) + + @staticmethod + @basicLogDecorator("Migrating allocated public IPs...") + def migrateAllocatedPublicIps(srcUser, destUser): + # Get a list of public IPs allocated to the source user + ips = PublicIp.getAllocatedPublicIps(srcUser) + + # Allocate each one of these IPs in the dest system + for ip in ips: + ip.system = GLOBALS["DEST_SYSTEM"] + ip.user = destUser + if (not ip.allocate()): + return None + + return ips + + @staticmethod + def getAllocatedPublicIps(user): + system = user.system + ips = [] + columns = ["public_ip_address", "data_center_id", "source_nat", "allocated"] + ipRows = system.db.getDbValues("user_ip_address", columns, {"account_id":user.accountId}) + for ipRow in ipRows: + ips.append(PublicIp(system, user, ipRow["public_ip_address"], ipRow["data_center_id"], ipRow["source_nat"], ipRow["allocated"])) + return ips + + @staticmethod + def getGuestIpAddress(system, vmId): + columns = ["guest_ip_address"] + guestIp = system.db.getDbValues("user_vm", columns, {"id":vmId})[0] + return guestIp["guest_ip_address"] + +class ForwardingRule: + @staticmethod + @basicLogDecorator("Migrating port forwarding and load balancer rules...") + def migrateForwardingRules(srcUser, destUser, publicIps, vmIdMap): + for publicIp in publicIps: + forwardingRules = ForwardingRule.getSrcForwardingRules(srcUser, destUser, publicIp.address, vmIdMap) + for forwardingRule in forwardingRules: + newRuleId = forwardingRule.createInDestSystem() + if (newRuleId == None): + return False + return True + + @staticmethod + def getSrcForwardingRules(srcUser, destUser, address, vmIdMap): + # vmIdMap maps UserVM database IDs in the src system to UserVM database IDs in the dest system + columns = ["id", "public_port", "private_ip_address", "private_port", "enabled", "protocol", "forwarding", "algorithm"] + ruleRows = GLOBALS["SRC_SYSTEM"].db.getDbValues("ip_forwarding", columns, {"public_ip_address":address}) + activeRules = [] + for ruleRow in ruleRows: + srcVmId = VM.getVmId(GLOBALS["SRC_SYSTEM"], srcUser.accountId, ruleRow["private_ip_address"]) + destVmId = vmIdMap.get(srcVmId) + if (destVmId == None): + continue + + if (ruleRow["forwarding"] == "1"): + activeRules.append(ForwardingRule.PortForwardingRule(address, ruleRow["public_port"], PublicIp.getGuestIpAddress(GLOBALS["DEST_SYSTEM"], destVmId), ruleRow["private_port"], ruleRow["enabled"], ruleRow["protocol"])) + else: + activeRules.append(ForwardingRule.LoadBalancerRule(destUser.accountId, address, ruleRow["public_port"], ruleRow["private_port"], destVmId, ruleRow["algorithm"])) + + return activeRules + + class PortForwardingRule: + def __init__(self, publicIp, publicPort, privateIp, privatePort, enabled, protocol): + self.publicIp = publicIp + self.publicPort = publicPort + self.privateIp = privateIp + self.privatePort = privatePort + self.enabled = enabled + self.protocol = protocol + + def createInDestSystem(self): + setParams = dict() + setParams["public_ip_address"] = self.publicIp + setParams["public_port"] = self.publicPort + setParams["private_ip_address"] = self.privateIp + setParams["private_port"] = self.privatePort + setParams["enabled"] = self.enabled + setParams["protocol"] = self.protocol + setParams["forwarding"] = "1" + setParams["algorithm"] = "null" + setParams["group_id"] = "null" + return GLOBALS["DEST_SYSTEM"].db.insertIntoDb("ip_forwarding", setParams) + + class LoadBalancerRule: + def __init__(self, accountId, ip, publicPort, privatePort, vmId, algorithm): + self.accountId = accountId + self.ip = ip + self.publicPort = publicPort + self.privatePort = privatePort + self.vmId = vmId + self.algorithm = algorithm + + def createInDestSystem(self): + setParams = dict() + setParams["name"] = str(self.publicPort) + "-" + str(self.privatePort) + setParams["account_id"] = self.accountId + setParams["ip_address"] = self.ip + setParams["public_port"] = self.publicPort + setParams["private_port"] = self.privatePort + setParams["algorithm"] = self.algorithm + newLoadBalancerRuleId = GLOBALS["DEST_SYSTEM"].db.insertIntoDb("load_balancer", setParams) + if (newLoadBalancerRuleId == None or newLoadBalancerRuleId == "0"): + return None + setParams = dict() + setParams["load_balancer_id"] = newLoadBalancerRuleId + setParams["instance_id"] = self.vmId + return GLOBALS["DEST_SYSTEM"].db.insertIntoDb("load_balancer_vm_map", setParams) + +class SR: + def __init__(self, system, ip, iqn, uuid): + self.system = system + self.ip = ip + self.iqn = iqn + self.uuid = uuid + + def __str__(self): + return "SR: %s | ip: %s | iqn: %s | uuid: %s" % (self.system, self.ip, self.iqn, self.uuid) + + @verboseLogDecorator("Finding SR...") + def find(self): + xenapi = self.system.xenapi + self.ref = xenapi.SR.get_by_name_label(self.uuid)[0] + return True + + @verboseLogDecorator("Finding source system's iSCSI SR...") + def create(self): + xenapi = self.system.xenapi + host = xenapi.host.get_all()[0] + deviceConfig = {'targetIQN': self.iqn, 'target': self.ip} + srRef = None + name = "1.0 iSCSI pool: " + self.ip + "-" + self.iqn + srRef = xenapi.SR.create(host, deviceConfig, "0", name, name, "iscsi", "user", True) + if (srRef != None): + self.ref = srRef + return True + else: + return False + + @staticmethod + def getExistingSrcSr(ip, iqn): + xenapi = GLOBALS["DEST_SYSTEM"].xenapi + srRefs = xenapi.SR.get_all() + for srRef in srRefs: + srNameLabel = xenapi.SR.get_name_label(srRef) + if (srNameLabel == "1.0 iSCSI pool: " + ip + "-" + iqn): + sr = SR(GLOBALS["DEST_SYSTEM"], ip, iqn, xenapi.SR.get_uuid(srRef)) + sr.ref = srRef + return sr + return None + + @staticmethod + @verboseLogDecorator("Forgetting all src iSCSI SRs...") + def forgetAllSrcSrs(): + xenapi = GLOBALS["DEST_SYSTEM"].xenapi + srRefs = xenapi.SR.get_all() + for srRef in srRefs: + srNameLabel = xenapi.SR.get_name_label(srRef) + if ("1.0 iSCSI pool" in srNameLabel): + # Unplug and destroy the SR's PBDs + pbdRefs = xenapi.SR.get_PBDs(srRef) + for pbdRef in pbdRefs: + xenapi.PBD.unplug(pbdRef) + xenapi.PBD.destroy(pbdRef) + + # Forget the SR + xenapi.SR.forget(srRef) + + return True + +class VDI: + def __init__(self, sr, volume, uuid): + self.sr = sr + self.volume = volume + self.uuid = uuid + self.find() + + def __str__(self): + return "VDI: %s | uuid: %s" % (self.volume, self.uuid) + + @verboseLogDecorator("Getting virtual size for VDI...") + def getVirtualSize(self): + xenapi = self.sr.system.xenapi + return xenapi.VDI.get_virtual_size(self.ref) + + @basicLogDecorator("Copying source system volume to dest system...") + def copy(self, destSR): + xenapi = self.sr.system.xenapi + newVdiRef = xenapi.VDI.copy(self.ref, destSR.ref) + return xenapi.VDI.get_uuid(newVdiRef) + + @basicLogDecorator("Destroying old volume...") + def destroy(self): + xenapi = self.sr.system.xenapi + xenapi.VDI.destroy(self.ref) + return True + + @verboseLogDecorator("Finding VDI in SR...") + def find(self): + xenapi = self.sr.system.xenapi + if (self.uuid == None): + # Run an sr-scan + xenapi.SR.scan(self.sr.ref) + + # Get a list of VDIs in the SR + vdiRefs = xenapi.SR.get_VDIs(self.sr.ref) + + # Find the VDI that has the same SCSI ID as the specified volume + volumeScsiId = self.volume.iscsiName.split(":")[-1].strip() + for vdiRef in vdiRefs: + smConfig = xenapi.VDI.get_sm_config(vdiRef) + vdiScsiId = smConfig["SCSIid"].strip()[1:] + if (vdiScsiId == volumeScsiId): + self.ref = vdiRef + self.uuid = xenapi.VDI.get_uuid(vdiRef) + return True + + return False + else: + self.ref = xenapi.VDI.get_by_uuid(self.uuid) + return True + + @basicLogDecorator("Changing disk name for VDI...") + def changeBootableDeviceName(self): + system = self.volume.system + xenapi = system.xenapi + controlDomainRef = system.controlDomainRef + + vbdRef = None + try: + # Create a VBD for the VDI + vbd = {'bootable': True, 'userdevice': '0', 'VDI': self.ref, + 'other_config': {}, 'VM': controlDomainRef, + 'mode': 'rw', 'qos_algorithm_type': '', 'qos_algorithm_params': {}, + 'type': 'Disk', 'empty': False, 'unpluggable': True} + vbdRef = xenapi.VBD.create(vbd) + + # Plug the VBD + xenapi.VBD.plug(vbdRef) + + # Create a temporary directory + if (not system.runSshCommand("mkdir -p /root/temp")): + raise Exception ("Failed to create directory /root/temp") + + # Check if /dev/xvda1 exists + xvda1Exists = system.runSshCommand("ls /dev/xvda1") + + # If /dev/xvda1 doesn't exist, work with /dev/xvda + if (not xvda1Exists): + # Mount /dev/xvda to /root/temp + if (not system.runSshCommand("mount /dev/xvda /root/temp")): + raise Exception("Failed to mount /dev/xvda to /root/temp") + writeToLog("Using /dev/xvda to change bootable device name.", False) + else: + # Mount /dev/xvda1 to /root/temp + if (not system.runSshCommand("mount /dev/xvda1 /root/temp")): + raise Exception("Failed to mount /dev/xvda1 to /root/temp") + + # If the boot directory exists under /root/temp, we can work with xvda1 + if (system.runSshCommand("ls /root/temp/boot")): + writeToLog("Using /dev/xvda1 to change bootable device name.", False) + else: + # If the boot directory doesn't exist under /root/temp, we need to work with /dev/xvda2 + + # Check that /dev/xvda2 exists + if (not system.runSshCommand("ls /dev/xvda2")): + raise Exception("/dev/xvda1 exists but /dev/xvda2 doesn't exist") + + # Unmount /dev/xvda1 + if (not system.runSshCommand("umount /root/temp")): + raise Exception("Failed to unmount /dev/xvda1") + + # Mount /dev/xvda2 + if (not system.runSshCommand("mount /dev/xvda2 /root/temp")): + raise Exception("Failed to mount /dev/xvda2") + + writeToLog("Using /dev/xvda2 to change bootable device name.", False) + + # Modify fstab, grub.conf, and device.map, if they exist + for fileToModify in ["/root/temp/etc/fstab", "/root/temp/boot/grub/grub.conf", "/root/temp/boot/grub/device.map"]: + if (system.runSshCommand("ls " + fileToModify)): + if (not system.runSshCommand("sed -i 's_/dev/sda_/dev/xvda_' " + fileToModify)): + raise Exception("Failed to modify " + fileToModify) + finally: + # Unmount /root/temp if necessary + if (system.runSshCommand("mount | grep '/root/temp'")): + if (not system.runSshCommand("umount /root/temp")): + raise Exception("Failed to unmount /root/temp") + + # Delete /root/temp + system.runSshCommand("rm -rf /root/temp") + + if (vbdRef != None): + # Unplug the VBD + xenapi.VBD.unplug(vbdRef) + + # Destroy the VBD + xenapi.VBD.destroy(vbdRef) + + return True + +class StoragePool: + # Vars: id, uuid + + def __init__(self, storagePoolId, uuid): + self.id = storagePoolId + self.uuid = uuid + + @staticmethod + def getStoragePool(system, storagePoolId): + columns = ["id", "uuid"] + storagePoolRow = system.db.getDbValues("storage_pool", columns, {"id":storagePoolId})[0] + return StoragePool(storagePoolRow["id"], storagePoolRow["uuid"]) + +class Host: + # Vars: id, ip, iqn + + def __init__(self, hostId, ip, iqn): + self.id = hostId + self.ip = ip + self.iqn = iqn + + def __str__(self): + return "Host: id: %s | ip %s" % (self.id, self.ip) + + @staticmethod + @basicLogDecorator("Sharing LUs...") + def shareAllLus(): + return Host.shareOrUnshareAllLus(True) + + @staticmethod + @basicLogDecorator("Unsharing LUs...") + def unshareAllLus(): + return Host.shareOrUnshareAllLus(False) + + @staticmethod + def shareOrUnshareAllLus(share): + # Get a map of XenServer IPs -> IQNs + xenServerIqns = Host.getXenServerIqns() + + # Get a map of storage host IPs in the source system -> passwords + storageHostPasswords = GLOBALS["STORAGE_HOST_PASSWORDS"] + + # Copy share_all_lus.sh to each storage host and run with each XenServer IQN + for ip in storageHostPasswords.keys(): + password = storageHostPasswords[ip] + sshConn = paramiko.SSHClient() + sshConn.set_missing_host_key_policy(paramiko.AutoAddPolicy()) + sshConn.connect(ip, username = "root", password = password) + sftpConn = sshConn.open_sftp() + sftpConn.put("share_all_lus.sh", "/root/share_all_lus.sh") + for xenServerIp in xenServerIqns.keys(): + iqn = xenServerIqns[xenServerIp] + command = "bash /root/share_all_lus.sh -i " + iqn + if (share): + command += " -s" + else: + command += " -u" + stdin, stdout, stderr = sshConn.exec_command(command) + if (stdin.channel.recv_exit_status() != 0): + return False + if (not share): + sftpConn.remove("/root/share_all_lus.sh") + sshConn.close() + sftpConn.close() + + return True + + @staticmethod + def getHost(system, hostId): + columns = ["id", "private_ip_address", "iqn"] + hostRow = system.db.getDbValues("host", columns, {"id":hostId})[0] + return Host(hostRow["id"], hostRow["private_ip_address"], hostRow["iqn"]) + + @staticmethod + def getStorageHostIps(): + storageHostIps = [] + columns = ["private_ip_address"] + hostRows = GLOBALS["SRC_SYSTEM"].db.getDbValues("host", columns, {"type":"Storage"}) + for hostRow in hostRows: + storageHostIps.append(hostRow["private_ip_address"]) + return storageHostIps + + @staticmethod + def getXenServerIqns(): + xenServerIqns = dict() + columns = ["private_ip_address", "url"] + hostRows = GLOBALS["DEST_SYSTEM"].db.getDbValues("host", columns, {"type":"Routing"}) + for hostRow in hostRows: + xenServerIqns[hostRow["private_ip_address"]] = hostRow["url"] + return xenServerIqns + + +### Runtime + +GLOBALS = dict() + +@basicLogDecorator("Reading upgrade.properties...") +def readUpgradeProperties(): + upgradePropertiesFile = open("upgrade.properties", "r") + upgradeProperties = upgradePropertiesFile.read().splitlines() + for upgradeProperty in upgradeProperties: + if (upgradeProperty.strip() == ""): + continue + elif (upgradeProperty.startswith("#")): + continue + else: + propList = upgradeProperty.split("=") + var = propList[0].strip() + val = propList[1].strip() + if (val == ""): + continue + GLOBALS[var] = val + + # Create the log file + logFilePath = GLOBALS["LOG_FILE"] + GLOBALS["LOG_FILE"] = open(logFilePath, "a") + + # Create the guest OS map + GLOBALS["GUEST_OS_MAP"] = csvToMap(GLOBALS["GUEST_OS_MAP"]) + + # Create the XenServer passwords map + if (GLOBALS.get("DEST_XENSERVER_PASSWORDS") != None): + GLOBALS["DEST_XENSERVER_PASSWORDS"] = csvToMap(GLOBALS.get("DEST_XENSERVER_PASSWORDS")) + + # Create the Storage Host passwords map + if (GLOBALS.get("STORAGE_HOST_PASSWORDS") == None): + raise Exception ("Please fill out the variable STORAGE_HOST_PASSWORDS in upgrade.properties.") + else: + GLOBALS["STORAGE_HOST_PASSWORDS"] = csvToMap(GLOBALS["STORAGE_HOST_PASSWORDS"]) + + # Create the list of users to upgrade + if GLOBALS.has_key("USERS"): + GLOBALS["USERS"] = [userId.strip() for userId in GLOBALS["USERS"].split(",")] + else: + GLOBALS["USERS"] = None + + return True + +def csvToMap(csv): + entries = csv.split(",") + entryMap = dict() + for entry in entries: + entryList = entry.strip().split(":") + key = entryList[0].strip() + val = entryList[1].strip() + entryMap[key] = val + return entryMap + +@basicLogDecorator("Running diagnostic...") +def runDiagnostic(): + # Either one XenServer IP and password should be specified, or a mapping between XenServer IPs and passwords should be specified + if ((GLOBALS.get("DEST_XENSERVER_IP") == None and (GLOBALS.get("DEST_XENSERVER_PASSWORD") != None or GLOBALS.get("DEST_XENSERVER_PASSWORDS") == None)) + or (GLOBALS.get("DEST_XENSERVER_IP") != None and (GLOBALS.get("DEST_XENSERVER_PASSWORD") == None or GLOBALS.get("DEST_XENSERVER_PASSWORDS") != None))): + raise Exception("Please specify the IP and root password for one XenServer (if all XenServers have the same root password), or the IPs and root passwords of all XenServers.") + + GLOBALS["SRC_SYSTEM"] = System(GLOBALS["SRC_MANAGEMENT_SERVER_IP"], False, None, None, None, "vmops", GLOBALS["SRC_DB_LOGIN"], + GLOBALS.get("SRC_DB_PASSWORD"), GLOBALS["SRC_ZONE_ID"], None, None, None, None) + + GLOBALS["DEST_SYSTEM"] = System(GLOBALS["DEST_MANAGEMENT_SERVER_IP"], True, GLOBALS.get("DEST_XENSERVER_IP"), + GLOBALS.get("DEST_XENSERVER_PASSWORD"), GLOBALS.get("DEST_XENSERVER_PASSWORDS"), + "cloud", GLOBALS["DEST_DB_LOGIN"], GLOBALS.get("DEST_DB_PASSWORD"), GLOBALS["DEST_ZONE_ID"], + GLOBALS["DEST_TEMPLATE_ID"], GLOBALS["DEST_ISO_ID"], GLOBALS["DEST_SERVICE_OFFERING_ID"], + GLOBALS["DEST_DISK_OFFERING_ID"]) + + srcSystemServiceOfferings = ServiceOffering.getSrcSystemServiceOfferings() + for srcSystemServiceOffering in srcSystemServiceOfferings: + # Every service offering in the src system must have a corresponding service offering in the dest system + destSystemServiceOffering = ServiceOffering.getCorrespondingServiceOffering(srcSystemServiceOffering.id) + if (destSystemServiceOffering == None): + raise Exception("No corresponding service offering found for: " + str(srcSystemServiceOffering)) + + # Every service offering in the src system has a corresponding disk offering in the dest system + destSystemDiskOffering = DiskOffering.getCorrespondingDiskOffering(srcSystemServiceOffering.id) + if (destSystemDiskOffering == None): + raise Exception("No corresponding disk offering found for: " + str(srcSystemServiceOffering)) + + # Every template ID in the src system has a valid entry in GUEST_OS_MAP + srcSystemTemplateIds = VM.getTemplateIds(GLOBALS["SRC_SYSTEM"]) + for templateId in srcSystemTemplateIds: + if (not GLOBALS["GUEST_OS_MAP"].has_key(templateId)): + raise Exception("No corresponding guest OS ID for templateId: " + templateId) + else: + guestOsId = GLOBALS["GUEST_OS_MAP"][templateId] + guestOsName = VM.getGuestOsName(GLOBALS["DEST_SYSTEM"], guestOsId) + if (guestOsName == None): + raise Exception("The guest OS ID that corresponds to template ID: " + templateId + " is not valid.") + + # The dest system's ISO id must be valid + template = VM.getTemplate(GLOBALS["DEST_SYSTEM"], GLOBALS["DEST_ISO_ID"]) + if (template == None or template["format"] != "ISO"): + raise Exception("The dest system ISO ID is not valid.") + + # Verify that all source system storage hosts have a password + storageHostIps = Host.getStorageHostIps() + for ip in storageHostIps: + if (ip not in GLOBALS["STORAGE_HOST_PASSWORDS"].keys()): + raise Exception("The storage host IP: " + str(ip) + " has no entry in STORAGE_HOST_PASSWORDS.") + + return True + +@basicLogDecorator("Starting CloudStack Migration (1.0 -> 2.1)...") +def upgradeUsers(userIds, onlyMigratePublicIps): + # Read variables from upgrade.properties + readUpgradeProperties() + + # Run the diagnostic + runDiagnostic() + + if (userIds == None): + if (GLOBALS["USERS"] == None): + raise Exception("Please specify one or more users to upgrade.") + else: + userIds = GLOBALS["USERS"] + + # Make sure all users are valid + for userId in userIds: + if (User.get(GLOBALS["SRC_SYSTEM"], userId) == None): + raise Exception("The user ID: " + str(userId) + " is not valid.") + + if (not onlyMigratePublicIps): + # Share all LUs + Host.shareAllLus() + + try: + for userId in userIds: + doUpgrade(userId, onlyMigratePublicIps) + return True + finally: + if (not onlyMigratePublicIps): + # Forget all iSCSI SRs + SR.forgetAllSrcSrs() + # Unshare all LUs + Host.unshareAllLus() + +def doUpgrade(userId, onlyMigratePublicIps): + # Get the specified user from the source system + srcUser = User.get(GLOBALS["SRC_SYSTEM"], userId) + + writeToLog("\nStarting migration for " + str(srcUser), True) + + # Create a new user in the destination system with the same attributes as the original user + destUser = User.create(GLOBALS["DEST_SYSTEM"], srcUser) + + if (not srcUser.alreadyMigrated()): + # Allocate the src user's public IPs in the dest system + allocatedPublicIps = PublicIp.migrateAllocatedPublicIps(srcUser, destUser) + + if (onlyMigratePublicIps): + writeToLog("\nMigrated public IPs for " + str(srcUser), True) + return + + # Stop the source user's DomR + srcUserDomR = DomainRouter(srcUser) + + # Only migrate the user's VMs if there is a DomR + if (srcUserDomR.id != None): + srcUserDomR.stop() + + # If the dest user doesn't have a DomR, deploy a temporary VM + destUserDomR = DomainRouter(destUser) + tempVm = None + if (destUserDomR.id == None): + tempVm = VM(None, destUser, None, None, None) + tempVm.deployTemp() + destUserDomR.id = destUserDomR.getId() + else: + tempVm = VM.getTempVm(destUser) + + # Migrate the source user's VM's to the dest system + vmIdMap = VM.migrateVirtualMachines(srcUser, destUser) + + # Migrate the source user's port forwarding and load balancer rules + ForwardingRule.migrateForwardingRules(srcUser, destUser, allocatedPublicIps, vmIdMap) + + # Reboot the dest user's router + destUserDomR.reboot() + + # Destroy the temporary VM + if (tempVm != None): + tempVm.destroy() + + srcUser.tagAsMigrated() + else: + writeToLog("\n" + str(srcUser) + " has already been migrated.", True) + + writeToLog("\nMigration was successful for " + str(srcUser), True) + +if (len(sys.argv) > 1): + if (sys.argv[1].lower() == "publicips"): + if (len(sys.argv) > 2): + upgradeUsers(sys.argv[2:], True) + else: + upgradeUsers(None, True) + else: + upgradeUsers(sys.argv[1:], False) +else: + upgradeUsers(None, False) +