# File: cloudstack/scripts/vm/hypervisor/kvm/imageserver/server.py
#
# (246 lines, 8.3 KiB, Python)

# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import argparse
import json
import logging
import os
import socket
import ssl
import threading
from http.server import HTTPServer
from socketserver import ThreadingMixIn
from typing import Type
try:
    from http.server import ThreadingHTTPServer
except ImportError:
    # http.server.ThreadingHTTPServer only exists on Python >= 3.7; rebuild
    # the equivalent class from its two stdlib ingredients on older versions.
    class ThreadingHTTPServer(ThreadingMixIn, HTTPServer):  # type: ignore[no-redef]
        """Fallback for ``http.server.ThreadingHTTPServer`` on Python < 3.7."""

        # Match the stdlib class, which sets daemon_threads = True: per-request
        # worker threads must not keep the process alive after the server
        # stops, regardless of which Python version supplied the class.
        daemon_threads = True
from .concurrency import ConcurrencyManager
from .config import TransferRegistry
from .constants import (
CONTROL_RECV_BUFFER,
CONTROL_SOCKET,
CONTROL_SOCKET_BACKLOG,
CONTROL_SOCKET_PERMISSIONS,
DEFAULT_HTTP_PORT,
DEFAULT_LISTEN_ADDRESS,
MAX_PARALLEL_READS,
MAX_PARALLEL_WRITES,
)
from .handler import Handler
def make_handler(
    concurrency: ConcurrencyManager,
    registry: TransferRegistry,
) -> Type[Handler]:
    """
    Return a fresh ``Handler`` subclass bound to the given shared services.

    The HTTP server constructs its handler class itself, once per request,
    and controls the constructor arguments, so dependencies cannot be handed
    in through ``__init__``. Instead they are attached as class attributes
    of a newly created subclass.

    :param concurrency: shared ``ConcurrencyManager``, stored as
        ``_concurrency`` on the returned class.
    :param registry: shared ``TransferRegistry``, stored as ``_registry`` on
        the returned class.
    :return: a new subclass of ``Handler``.
    """
    shared_concurrency = concurrency
    shared_registry = registry

    class ConfiguredHandler(Handler):
        _concurrency = shared_concurrency
        _registry = shared_registry

    return ConfiguredHandler
def _validate_config(obj: dict) -> dict:
"""
Validate and normalize a transfer config dict received over the control
socket. Returns the cleaned config or raises ValueError.
"""
backend = obj.get("backend")
if backend is None:
backend = "nbd"
if not isinstance(backend, str):
raise ValueError("invalid backend type")
backend = backend.lower()
if backend not in ("nbd", "file"):
raise ValueError(f"unsupported backend: {backend}")
if backend == "file":
file_path = obj.get("file")
if not isinstance(file_path, str) or not file_path.strip():
raise ValueError("missing/invalid file path for file backend")
return {"backend": "file", "file": file_path.strip()}
socket_path = obj.get("socket")
export = obj.get("export")
export_bitmap = obj.get("export_bitmap")
if not isinstance(socket_path, str) or not socket_path.strip():
raise ValueError("missing/invalid socket path for nbd backend")
if export is not None and (not isinstance(export, str) or not export):
raise ValueError("invalid export name")
return {
"backend": "nbd",
"socket": socket_path.strip(),
"export": export,
"export_bitmap": export_bitmap,
}
# Upper bound on one control request. A request is a single small JSON line;
# the cap stops a misbehaving client from growing the read buffer without limit.
_MAX_CONTROL_REQUEST_BYTES = 1024 * 1024


def _read_control_request(conn: socket.socket) -> bytes:
    """Read up to the first newline (or EOF) and return that line, stripped."""
    buf = bytearray()
    while True:
        chunk = conn.recv(CONTROL_RECV_BUFFER)
        if not chunk:
            break
        buf += chunk
        # Only the new chunk can contain the first newline; no need to rescan.
        if b"\n" in chunk:
            break
        if len(buf) > _MAX_CONTROL_REQUEST_BYTES:
            raise ValueError("control request too large")
    # Ignore anything after the first newline; json.loads would otherwise
    # fail with "Extra data" if a client pipelines more bytes.
    line, _, _ = bytes(buf).partition(b"\n")
    line = line.strip()
    if not line:
        raise ValueError("empty control request")
    return line


def _dispatch_control_request(msg: dict, registry: TransferRegistry) -> dict:
    """Execute one decoded control request against *registry*; return the response."""
    action = msg.get("action")
    if action == "register":
        transfer_id = msg.get("transfer_id")
        raw_config = msg.get("config")
        if not transfer_id or not isinstance(raw_config, dict):
            return {"status": "error", "message": "missing transfer_id or config"}
        try:
            config = _validate_config(raw_config)
        except ValueError as e:
            return {"status": "error", "message": str(e)}
        if not registry.register(transfer_id, config):
            return {"status": "error", "message": "invalid transfer_id"}
        return {"status": "ok", "active_transfers": registry.active_count()}
    if action == "unregister":
        transfer_id = msg.get("transfer_id")
        if not transfer_id:
            return {"status": "error", "message": "missing transfer_id"}
        remaining = registry.unregister(transfer_id)
        return {"status": "ok", "active_transfers": remaining}
    if action == "status":
        return {"status": "ok", "active_transfers": registry.active_count()}
    return {"status": "error", "message": f"unknown action: {action}"}


def _handle_control_conn(conn: socket.socket, registry: TransferRegistry) -> None:
    """
    Serve one request/response exchange on an accepted control connection.

    The client sends one JSON object terminated by a newline (or by closing
    its write side). Exactly one JSON line is written back, then the
    connection is closed. Supported ``action`` values:

    * ``register``: validate ``config`` and add it to *registry* under
      ``transfer_id``.
    * ``unregister``: remove ``transfer_id`` from *registry*.
    * ``status``: report the number of active transfers.

    Every response has ``"status": "ok"`` or ``"status": "error"`` (the
    latter with a ``message``). Unexpected failures are logged and reported
    to the client on a best-effort basis.

    :param conn: accepted stream socket; always closed on return.
    :param registry: registry of active transfers.
    """
    try:
        msg = json.loads(_read_control_request(conn))
        if not isinstance(msg, dict):
            raise ValueError("request must be a JSON object")
        resp = _dispatch_control_request(msg, registry)
        conn.sendall((json.dumps(resp) + "\n").encode("utf-8"))
    except Exception as e:
        logging.error("control socket error: %r", e)
        try:
            conn.sendall((json.dumps({"status": "error", "message": str(e)}) + "\n").encode("utf-8"))
        except OSError:
            # Best effort only: the peer may already have gone away.
            pass
    finally:
        conn.close()
def _control_listener(registry: TransferRegistry, sock_path: str) -> None:
    """
    Serve the Unix-domain control socket forever (run in a daemon thread).

    Setup steps:

    * Remove any existing file at *sock_path*, e.g. a stale socket from a
      previous run.
    * Create the parent directory if needed.
    * Bind and chmod the socket to ``CONTROL_SOCKET_PERMISSIONS``.

    Each accepted connection is handled on its own daemon thread by
    ``_handle_control_conn``.

    Setup errors (unlink, mkdir, bind, chmod, listen) propagate and end the
    thread. Errors from ``accept()`` or from starting a worker thread are
    logged and the loop continues. A transient failure, such as running out
    of file descriptors, therefore does not permanently disable the control
    channel.

    :param registry: transfer registry shared with the HTTP handlers.
    :param sock_path: filesystem path of the control socket.
    """
    import time  # local: only needed for the accept-error back-off

    # Unlink directly rather than exists()+unlink(). exists() follows
    # symlinks, so a dangling link would be skipped and bind() would then
    # fail; the check-then-act pair is also racy.
    try:
        os.unlink(sock_path)
    except FileNotFoundError:
        pass
    sock_dir = os.path.dirname(sock_path)
    # dirname() is "" for a bare file name, and os.makedirs("") raises.
    if sock_dir:
        os.makedirs(sock_dir, exist_ok=True)

    srv = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
    try:
        srv.bind(sock_path)
        # NOTE(review): between bind() and chmod() the socket carries
        # umask-derived permissions; restrict the parent directory if that
        # window matters.
        os.chmod(sock_path, CONTROL_SOCKET_PERMISSIONS)
        srv.listen(CONTROL_SOCKET_BACKLOG)
    except Exception:
        srv.close()
        raise
    logging.info("control socket listening on %s", sock_path)

    while True:
        try:
            conn, _ = srv.accept()
        except OSError as e:
            logging.error("control socket accept failed: %r", e)
            # Back off briefly so a persistent error (e.g. EMFILE) does not spin.
            time.sleep(0.1)
            continue
        try:
            threading.Thread(
                target=_handle_control_conn,
                args=(conn, registry),
                daemon=True,
            ).start()
        except RuntimeError as e:
            # Thread.start() raises RuntimeError when no new thread can be
            # created; drop this connection instead of killing the listener.
            logging.error("cannot start control connection thread: %r", e)
            conn.close()
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the image server."""
    parser = argparse.ArgumentParser(
        description="CloudStack image server backed by NBD / local file"
    )
    parser.add_argument("--listen", default=DEFAULT_LISTEN_ADDRESS, help="Address to bind")
    parser.add_argument("--port", type=int, default=DEFAULT_HTTP_PORT, help="Port to listen on")
    parser.add_argument(
        "--control-socket",
        default=CONTROL_SOCKET,
        help="Path to the Unix domain control socket",
    )
    parser.add_argument(
        "--tls-enabled",
        action="store_true",
        help="Enable TLS for the HTTP transfer endpoint",
    )
    parser.add_argument(
        "--tls-cert-file",
        default=None,
        help="Path to PEM certificate file used when TLS is enabled",
    )
    parser.add_argument(
        "--tls-key-file",
        default=None,
        help="Path to PEM private key file used when TLS is enabled",
    )
    return parser


def _make_tls_context(cert_file: str, key_file: str) -> ssl.SSLContext:
    """
    Build a server-side TLS context that refuses protocols older than TLS 1.2.

    Errors from ``load_cert_chain`` (missing file, unreadable or mismatched
    key) propagate to the caller.
    """
    context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
    if hasattr(ssl, "TLSVersion") and hasattr(context, "minimum_version"):
        context.minimum_version = ssl.TLSVersion.TLSv1_2
    else:
        # Older Pythons without minimum_version: disable legacy protocol
        # versions via option flags instead.
        if hasattr(ssl, "OP_NO_TLSv1"):
            context.options |= ssl.OP_NO_TLSv1
        if hasattr(ssl, "OP_NO_TLSv1_1"):
            context.options |= ssl.OP_NO_TLSv1_1
    context.load_cert_chain(certfile=cert_file, keyfile=key_file)
    return context


def main() -> None:
    """
    Entry point: parse arguments, start the control socket, serve HTTP(S).

    The control socket runs on a daemon thread. The HTTP server runs in the
    calling thread until interrupted, and its listening socket is closed on
    the way out.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    if args.tls_enabled and (not args.tls_cert_file or not args.tls_key_file):
        parser.error("--tls-enabled requires --tls-cert-file and --tls-key-file")
    logging.basicConfig(
        level=logging.INFO,
        format="%(asctime)s %(levelname)s %(message)s",
    )

    # Load TLS material before starting threads or binding ports, so a bad
    # certificate/key fails fast instead of after the control socket is up.
    tls_context = None
    if args.tls_enabled:
        tls_context = _make_tls_context(args.tls_cert_file, args.tls_key_file)

    registry = TransferRegistry()
    concurrency = ConcurrencyManager(MAX_PARALLEL_READS, MAX_PARALLEL_WRITES)
    handler_cls = make_handler(concurrency, registry)

    ctrl_thread = threading.Thread(
        target=_control_listener,
        args=(registry, args.control_socket),
        daemon=True,
    )
    ctrl_thread.start()

    httpd = ThreadingHTTPServer((args.listen, args.port), handler_cls)
    scheme = "http"
    if tls_context is not None:
        # Defer the TLS handshake. With the default do_handshake_on_connect=True,
        # SSLSocket.accept() performs the handshake inside the serve_forever()
        # thread, so a single slow or stalled client blocks every other
        # connection. With False, the handshake happens on the first read,
        # inside the per-request worker thread.
        httpd.socket = tls_context.wrap_socket(
            httpd.socket,
            server_side=True,
            do_handshake_on_connect=False,
        )
        scheme = "https"

    logging.info("listening on %s://%s:%d", scheme, args.listen, args.port)
    try:
        httpd.serve_forever()
    finally:
        httpd.server_close()