From dad314a8a6d804621f02b65c358b2e4dca88b1b1 Mon Sep 17 00:00:00 2001 From: Abhisar Sinha <63767682+abh1sar@users.noreply.github.com> Date: Mon, 23 Mar 2026 20:53:37 +0530 Subject: [PATCH] Image server unittests --- .../kvm/imageserver/tests/__init__.py | 16 + .../kvm/imageserver/tests/test_base.py | 440 ++++++++++++++++++ .../imageserver/tests/test_combinations.py | 397 ++++++++++++++++ .../imageserver/tests/test_control_socket.py | 258 ++++++++++ .../imageserver/tests/test_file_backend.py | 230 +++++++++ .../kvm/imageserver/tests/test_nbd_backend.py | 393 ++++++++++++++++ 6 files changed, 1734 insertions(+) create mode 100644 scripts/vm/hypervisor/kvm/imageserver/tests/__init__.py create mode 100644 scripts/vm/hypervisor/kvm/imageserver/tests/test_base.py create mode 100644 scripts/vm/hypervisor/kvm/imageserver/tests/test_combinations.py create mode 100644 scripts/vm/hypervisor/kvm/imageserver/tests/test_control_socket.py create mode 100644 scripts/vm/hypervisor/kvm/imageserver/tests/test_file_backend.py create mode 100644 scripts/vm/hypervisor/kvm/imageserver/tests/test_nbd_backend.py diff --git a/scripts/vm/hypervisor/kvm/imageserver/tests/__init__.py b/scripts/vm/hypervisor/kvm/imageserver/tests/__init__.py new file mode 100644 index 00000000000..0ccbeeeafb7 --- /dev/null +++ b/scripts/vm/hypervisor/kvm/imageserver/tests/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/scripts/vm/hypervisor/kvm/imageserver/tests/test_base.py b/scripts/vm/hypervisor/kvm/imageserver/tests/test_base.py new file mode 100644 index 00000000000..91e7eda79ed --- /dev/null +++ b/scripts/vm/hypervisor/kvm/imageserver/tests/test_base.py @@ -0,0 +1,440 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Shared infrastructure for the image-server test suite (stdlib unittest only). + +Provides: +- A singleton image server process started once for the entire test run. +- Control-socket helpers using pure-Python AF_UNIX (no socat). +- qemu-nbd server management. +- Transfer registration / teardown helpers. +- HTTP helper functions. +""" + +import functools +import json +import logging +import os +import random +import select +import shutil +import signal +import socket +import subprocess +import sys +import tempfile +import time +import unittest +import uuid +from pathlib import Path +from typing import Any, Dict, Optional + +IMAGE_SIZE = 1 * 1024 * 1024 # 1 MiB +SERVER_STARTUP_TIMEOUT = 10 +QEMU_NBD_STARTUP_TIMEOUT = 5 +HTTP_TIMEOUT = 30 # seconds per HTTP request + +logging.basicConfig( + level=logging.INFO, + stream=sys.stderr, + format="%(asctime)s [TEST] %(message)s", +) +log = logging.getLogger(__name__) + + +def randbytes(seed, n): + """Generate n deterministic pseudo-random bytes (works on Python 3.6+).""" + rng = random.Random(seed) + return rng.getrandbits(8 * n).to_bytes(n, "big") + + +def test_timeout(seconds): + """Decorator that fails a test if it exceeds *seconds* (SIGALRM, Unix only).""" + def decorator(func): + @functools.wraps(func) + def wrapper(*args, **kwargs): + def _alarm(signum, frame): + raise TimeoutError( + "{} timed out after {}s".format(func.__qualname__, seconds) + ) + prev = signal.signal(signal.SIGALRM, _alarm) + signal.alarm(seconds) + try: + return func(*args, **kwargs) + finally: + signal.alarm(0) + signal.signal(signal.SIGALRM, prev) + return wrapper + return decorator + +# ── Singleton state shared across all test modules ────────────────────── + +_tmp_dir: Optional[str] = None +_server_proc: Optional[subprocess.Popen] = None +_server_info: Optional[Dict[str, Any]] = None + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def control_socket_send(sock_path: str, message: dict, retries: int = 5) -> dict: + """Send a JSON message to the control socket and return the parsed response.""" + payload = (json.dumps(message) + "\n").encode("utf-8") + last_err = None + for attempt in range(retries): + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.settimeout(5) + s.connect(sock_path) + s.sendall(payload) + s.shutdown(socket.SHUT_WR) + data = b"" + while True: + chunk = s.recv(4096) + if not chunk: + break + data += chunk + return json.loads(data.strip()) + except (BlockingIOError, ConnectionRefusedError, OSError) as e: + last_err = e + time.sleep(0.1 * (attempt + 1)) + raise last_err + + +def _wait_for_control_socket(sock_path: str, timeout: float = SERVER_STARTUP_TIMEOUT) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + try: + resp = control_socket_send(sock_path, {"action": "status"}) + if resp.get("status") == "ok": + return + except (ConnectionRefusedError, FileNotFoundError, OSError): + pass + time.sleep(0.2) + raise RuntimeError( + f"Image server control socket at {sock_path} not ready within {timeout}s" + ) + + +def _wait_for_nbd_socket(sock_path: str, timeout: float = QEMU_NBD_STARTUP_TIMEOUT) -> None: + deadline = time.monotonic() + timeout + while time.monotonic() < deadline: + if os.path.exists(sock_path): + try: + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.settimeout(1) + s.connect(sock_path) + return + except (ConnectionRefusedError, OSError): + pass + time.sleep(0.2) + raise RuntimeError( + f"qemu-nbd socket at {sock_path} not ready within {timeout}s" + ) + + +def get_tmp_dir() -> str: + global _tmp_dir + if _tmp_dir is None: + _tmp_dir = tempfile.mkdtemp(prefix="imageserver_test_") + return _tmp_dir + + +def get_image_server() -> Dict[str, Any]: + """Return the singleton image-server info dict, starting it if needed.""" + global _server_proc, _server_info + + if _server_info is not None: + return _server_info + + tmp = get_tmp_dir() + port = _free_port() + ctrl_sock = os.path.join(tmp, "ctrl.sock") + + imageserver_pkg = str(Path(__file__).resolve().parent.parent) + parent_dir = str(Path(imageserver_pkg).parent) + + env = os.environ.copy() + env["PYTHONPATH"] = parent_dir + os.pathsep + env.get("PYTHONPATH", "") + + proc = subprocess.Popen( + [ + sys.executable, "-m", "imageserver", + "--listen", "127.0.0.1", + "--port", str(port), + "--control-socket", ctrl_sock, + ], + cwd=parent_dir, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _server_proc = proc + + try: + _wait_for_control_socket(ctrl_sock) + except RuntimeError: + proc.kill() + stdout, stderr = proc.communicate(timeout=5) + raise RuntimeError( + f"Image server failed to start.\nstdout: {stdout.decode()}\nstderr: {stderr.decode()}" + ) + + def send(msg: dict) -> dict: + return control_socket_send(ctrl_sock, msg) + + _server_info = { + "base_url": f"http://127.0.0.1:{port}", + "port": port, + "ctrl_sock": ctrl_sock, + "send": send, + } + return _server_info + + +def shutdown_image_server() -> None: + global _server_proc, _server_info, _tmp_dir + if _server_proc is not None: + for pipe in (_server_proc.stdout, _server_proc.stderr): + if pipe: + try: + pipe.close() + except Exception: + pass + _server_proc.terminate() + try: + _server_proc.wait(timeout=5) + except subprocess.TimeoutExpired: + _server_proc.kill() + _server_proc.wait(timeout=5) + _server_proc = None + _server_info = None + if _tmp_dir is not None: + shutil.rmtree(_tmp_dir, ignore_errors=True) + _tmp_dir = None + + +# ── qemu-nbd server ──────────────────────────────────────────────────── + +class QemuNbdServer: + """Manages a qemu-nbd process exporting a raw image over a Unix socket.""" + + def __init__(self, image_path: str, socket_path: str, image_size: int = IMAGE_SIZE): + self.image_path = image_path + self.socket_path = socket_path + self.image_size = image_size + self._proc: Optional[subprocess.Popen] = None + + def start(self) -> None: + if not os.path.exists(self.image_path): + with open(self.image_path, "wb") as f: + f.truncate(self.image_size) + + self._proc = subprocess.Popen( + [ + "qemu-nbd", + "--socket", self.socket_path, + "--format", "raw", + "--persistent", + "--shared=8", + "--cache=none", + self.image_path, + ], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + ) + _wait_for_nbd_socket(self.socket_path) + + def stop(self) -> None: + if self._proc is not None: + for pipe in (self._proc.stdout, self._proc.stderr): + if pipe: + try: + pipe.close() + except Exception: + pass + self._proc.terminate() + try: + self._proc.wait(timeout=5) + except subprocess.TimeoutExpired: + self._proc.kill() + self._proc.wait(timeout=5) + self._proc = None + + +# ── Factory helpers ───────────────────────────────────────────────────── + +def make_tmp_image(data=None, image_size=IMAGE_SIZE) -> str: + """Create a temp raw image file in the shared tmp dir; return path.""" + tmp = get_tmp_dir() + path = os.path.join(tmp, f"img_{uuid.uuid4().hex[:8]}.raw") + if data is not None: + with open(path, "wb") as f: + f.write(data) + else: + with open(path, "wb") as f: + f.write(randbytes(42, image_size)) + return path + + +def make_file_transfer(data=None, image_size=IMAGE_SIZE): + """ + Create a temp file + register a file-backend transfer. + Returns (transfer_id, url, file_path, cleanup_callable). + """ + srv = get_image_server() + path = make_tmp_image(data=data, image_size=image_size) + transfer_id = f"file-{uuid.uuid4().hex[:8]}" + resp = srv["send"]({ + "action": "register", + "transfer_id": transfer_id, + "config": {"backend": "file", "file": path}, + }) + assert resp["status"] == "ok", f"register failed: {resp}" + url = f"{srv['base_url']}/images/{transfer_id}" + + def cleanup(): + srv["send"]({"action": "unregister", "transfer_id": transfer_id}) + try: + os.unlink(path) + except FileNotFoundError: + pass + + return transfer_id, url, path, cleanup + + +def make_nbd_transfer(image_size=IMAGE_SIZE): + """ + Create a qemu-nbd server + register an NBD-backend transfer. + Returns (transfer_id, url, QemuNbdServer, cleanup_callable). + """ + srv = get_image_server() + tmp = get_tmp_dir() + img_path = os.path.join(tmp, f"nbd_{uuid.uuid4().hex[:8]}.raw") + sock_path = os.path.join(tmp, f"nbd_{uuid.uuid4().hex[:8]}.sock") + + server = QemuNbdServer(img_path, sock_path, image_size=image_size) + server.start() + + transfer_id = f"nbd-{uuid.uuid4().hex[:8]}" + resp = srv["send"]({ + "action": "register", + "transfer_id": transfer_id, + "config": {"backend": "nbd", "socket": sock_path}, + }) + assert resp["status"] == "ok", f"register failed: {resp}" + url = f"{srv['base_url']}/images/{transfer_id}" + + def cleanup(): + srv["send"]({"action": "unregister", "transfer_id": transfer_id}) + server.stop() + for p in (img_path, sock_path): + try: + os.unlink(p) + except FileNotFoundError: + pass + + return transfer_id, url, server, cleanup + + +# ── HTTP helpers ──────────────────────────────────────────────────────── + +import urllib.request +import urllib.error + + +def http_get(url, headers=None, timeout=HTTP_TIMEOUT): + req = urllib.request.Request(url, headers=headers or {}) + return urllib.request.urlopen(req, timeout=timeout) + + +def http_put(url, data, headers=None, timeout=HTTP_TIMEOUT): + hdrs = {"Content-Length": str(len(data))} + if headers: + hdrs.update(headers) + req = urllib.request.Request(url, data=data, headers=hdrs, method="PUT") + return urllib.request.urlopen(req, timeout=timeout) + + +def http_post(url, data=b"", headers=None, timeout=HTTP_TIMEOUT): + hdrs = {} + if headers: + hdrs.update(headers) + req = urllib.request.Request(url, data=data, headers=hdrs, method="POST") + return urllib.request.urlopen(req, timeout=timeout) + + +def http_options(url, timeout=HTTP_TIMEOUT): + req = urllib.request.Request(url, method="OPTIONS") + return urllib.request.urlopen(req, timeout=timeout) + + +def http_patch(url, data, headers=None, timeout=HTTP_TIMEOUT): + hdrs = {} + if headers: + hdrs.update(headers) + req = urllib.request.Request(url, data=data, headers=hdrs, method="PATCH") + return urllib.request.urlopen(req, timeout=timeout) + + +# ── Base TestCase with shared setUp/tearDown ──────────────────────────── + +class ImageServerTestCase(unittest.TestCase): + """ + Base class for image-server tests. + + Ensures the image server is running before any test method. + Subclasses that need a file or NBD transfer should set them up + in setUp() and tear down in tearDown(). + """ + + @classmethod + def setUpClass(cls): + cls.server = get_image_server() + cls.base_url = cls.server["base_url"] + + def ctrl(self, msg): + """Send a control-socket message; wraps server['send'] to avoid descriptor issues.""" + return self.server["send"](msg) + + def _make_tmp_image(self, data=None): + return make_tmp_image(data=data) + + def _register_file_transfer(self, data=None): + return make_file_transfer(data=data) + + def _register_nbd_transfer(self): + return make_nbd_transfer() + + @staticmethod + def dump_server_logs(): + """Read any available server stderr and print it for post-mortem debugging.""" + if _server_proc is None or _server_proc.stderr is None: + return + try: + if select.select([_server_proc.stderr], [], [], 0)[0]: + data = _server_proc.stderr.read1(64 * 1024) + if data: + sys.stderr.write("\n=== IMAGE SERVER STDERR ===\n") + sys.stderr.write(data.decode(errors="replace")) + sys.stderr.write("\n=== END SERVER STDERR ===\n") + except Exception: + pass diff --git a/scripts/vm/hypervisor/kvm/imageserver/tests/test_combinations.py b/scripts/vm/hypervisor/kvm/imageserver/tests/test_combinations.py new file mode 100644 index 00000000000..509f9fde05a --- /dev/null +++ b/scripts/vm/hypervisor/kvm/imageserver/tests/test_combinations.py @@ -0,0 +1,397 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Multi-operation sequences, parallel reads across multiple transfer objects, +cross-backend scenarios, and edge cases. +""" + +import json +import logging +import unittest +import urllib.error +from concurrent.futures import ThreadPoolExecutor, as_completed + +from .test_base import ( + IMAGE_SIZE, + ImageServerTestCase, + http_get, + http_patch, + http_post, + http_put, + make_file_transfer, + make_nbd_transfer, + randbytes, + shutdown_image_server, + test_timeout, +) + +log = logging.getLogger(__name__) +FUTURES_TIMEOUT = 60 # seconds for as_completed to collect all results + + +def _fetch(url, headers=None): + """GET *url* and return the body bytes, properly closing the response.""" + resp = http_get(url, headers=headers) + try: + return resp.read() + finally: + resp.close() + + +class TestParallelReadsFileBackend(ImageServerTestCase): + """Multiple concurrent GET requests to multiple file-backed transfers.""" + + @test_timeout(120) + def test_parallel_reads_single_file_transfer(self): + data = randbytes(500, IMAGE_SIZE) + tid, url, path, cleanup = make_file_transfer(data=data) + try: + results = {} + with ThreadPoolExecutor(max_workers=8) as pool: + futures = {} + for i in range(8): + start = i * (IMAGE_SIZE // 8) + end = start + (IMAGE_SIZE // 8) - 1 + f = pool.submit( + _fetch, url, headers={"Range": f"bytes={start}-{end}"} + ) + futures[f] = (start, end) + + for f in as_completed(futures, timeout=FUTURES_TIMEOUT): + start, end = futures[f] + results[(start, end)] = f.result() + + for (start, end), chunk in sorted(results.items()): + self.assertEqual(chunk, data[start:end + 1], f"Mismatch at {start}-{end}") + finally: + cleanup() + + @test_timeout(120) + def test_parallel_reads_multiple_file_transfers(self): + """Parallel reads across 4 different file-backed transfer objects.""" + transfers = [] + try: + for i in range(4): + data = randbytes(600 + i, IMAGE_SIZE) + tid, url, path, cleanup = make_file_transfer(data=data) + transfers.append((tid, url, data, cleanup)) + + with ThreadPoolExecutor(max_workers=8) as pool: + futures = {} + for idx, (tid, url, data, _) in enumerate(transfers): + for j in range(2): + f = pool.submit(_fetch, url) + futures[f] = (idx, data) + + for f in as_completed(futures, timeout=FUTURES_TIMEOUT): + idx, expected_data = futures[f] + got = f.result() + self.assertEqual(got, expected_data, f"Transfer {idx} mismatch") + finally: + for _, _, _, cleanup in transfers: + cleanup() + + +class TestParallelReadsNbdBackend(ImageServerTestCase): + """Multiple concurrent GET requests to multiple NBD-backed transfers.""" + + @test_timeout(120) + def test_parallel_reads_single_nbd_transfer(self): + data = randbytes(700, IMAGE_SIZE) + tid, url, nbd_server, cleanup = make_nbd_transfer() + try: + log.info("Writing %d bytes to NBD transfer %s", IMAGE_SIZE, tid) + http_put(url, data) + log.info("NBD write done, starting 8 parallel range reads") + + results = {} + with ThreadPoolExecutor(max_workers=8) as pool: + futures = {} + for i in range(8): + start = i * (IMAGE_SIZE // 8) + end = start + (IMAGE_SIZE // 8) - 1 + f = pool.submit( + _fetch, url, headers={"Range": f"bytes={start}-{end}"} + ) + futures[f] = (start, end) + + completed = 0 + for f in as_completed(futures, timeout=FUTURES_TIMEOUT): + start, end = futures[f] + results[(start, end)] = f.result() + completed += 1 + log.info("NBD range read %d/8 done: bytes=%d-%d", completed, start, end) + + for (start, end), chunk in sorted(results.items()): + self.assertEqual(chunk, data[start:end + 1], f"Mismatch at {start}-{end}") + finally: + cleanup() + + @test_timeout(120) + def test_parallel_reads_multiple_nbd_transfers(self): + """Parallel reads across 4 different NBD-backed transfer objects.""" + transfers = [] + try: + for i in range(4): + data = randbytes(800 + i, IMAGE_SIZE) + log.info("Setting up NBD transfer %d", i) + tid, url, nbd_server, cleanup = make_nbd_transfer() + log.info("Writing data to NBD transfer %d (tid=%s)", i, tid) + http_put(url, data) + transfers.append((tid, url, data, cleanup)) + log.info("NBD transfer %d ready", i) + + log.info("Starting parallel reads across %d NBD transfers", len(transfers)) + with ThreadPoolExecutor(max_workers=8) as pool: + futures = {} + for idx, (tid, url, data, _) in enumerate(transfers): + for j in range(2): + f = pool.submit(_fetch, url) + futures[f] = (idx, data) + + completed = 0 + for f in as_completed(futures, timeout=FUTURES_TIMEOUT): + idx, expected_data = futures[f] + got = f.result() + completed += 1 + log.info("Read %d/%d done: NBD transfer idx=%d, %d bytes", + completed, len(futures), idx, len(got)) + self.assertEqual(got, expected_data, f"NBD transfer {idx} mismatch") + finally: + for _, _, _, cleanup in transfers: + cleanup() + + +class TestParallelReadsMixedBackends(ImageServerTestCase): + """Parallel reads across a mix of file and NBD transfers simultaneously.""" + + @test_timeout(120) + def test_parallel_reads_file_and_nbd_mixed(self): + transfers = [] + try: + for i in range(2): + log.info("Setting up file transfer %d", i) + data = randbytes(900 + i, IMAGE_SIZE) + tid, url, path, cleanup = make_file_transfer(data=data) + transfers.append(("file", tid, url, data, cleanup)) + log.info("File transfer %d ready: tid=%s", i, tid) + + for i in range(2): + log.info("Setting up NBD transfer %d", i) + data = randbytes(950 + i, IMAGE_SIZE) + tid, url, nbd_server, cleanup = make_nbd_transfer() + log.info("NBD transfer %d registered: tid=%s, writing data...", i, tid) + http_put(url, data) + transfers.append(("nbd", tid, url, data, cleanup)) + log.info("NBD transfer %d ready", i) + + log.info("Starting parallel reads across %d transfers (2 file + 2 nbd)", + len(transfers)) + with ThreadPoolExecutor(max_workers=8) as pool: + futures = {} + for idx, (backend_type, tid, url, data, _) in enumerate(transfers): + for j in range(2): + f = pool.submit(_fetch, url) + futures[f] = (idx, backend_type, data) + + completed = 0 + for f in as_completed(futures, timeout=FUTURES_TIMEOUT): + idx, backend_type, expected = futures[f] + got = f.result() + completed += 1 + log.info("Read %d/%d done: %s transfer idx=%d, %d bytes", + completed, len(futures), backend_type, idx, len(got)) + self.assertEqual(got, expected, f"{backend_type} transfer {idx} mismatch") + + log.info("All parallel mixed reads completed successfully") + except TimeoutError: + log.error("TIMEOUT in mixed parallel reads — dumping server logs") + self.dump_server_logs() + raise + finally: + for _, _, _, _, cleanup in transfers: + cleanup() + + +class TestWriteThenReadNbd(ImageServerTestCase): + """Multi-step write sequences on NBD backend.""" + + def setUp(self): + self._tid, self._url, self._nbd, self._cleanup = make_nbd_transfer() + + def tearDown(self): + self._cleanup() + + def test_partial_writes_then_full_read(self): + http_put(self._url, b"\x00" * IMAGE_SIZE) + + chunk_size = IMAGE_SIZE // 4 + for i in range(4): + offset = i * chunk_size + end = offset + chunk_size - 1 + data = bytes([i & 0xFF]) * chunk_size + http_patch(self._url, data, headers={ + "Range": f"bytes={offset}-{end}", + "Content-Type": "application/octet-stream", + "Content-Length": str(chunk_size), + }) + + resp = http_get(self._url) + full = resp.read() + for i in range(4): + offset = i * chunk_size + self.assertEqual(full[offset:offset + chunk_size], bytes([i & 0xFF]) * chunk_size) + + def test_zero_then_extents(self): + http_put(self._url, randbytes(1000, IMAGE_SIZE)) + + payload = json.dumps({"op": "zero", "size": IMAGE_SIZE // 2, "offset": 0}).encode() + http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + + resp = http_get(f"{self._url}/extents") + extents = json.loads(resp.read()) + total = sum(e["length"] for e in extents) + self.assertEqual(total, IMAGE_SIZE) + + def test_write_flush_read(self): + data = randbytes(1001, IMAGE_SIZE) + resp = http_put(f"{self._url}?flush=y", data) + body = json.loads(resp.read()) + self.assertTrue(body["flushed"]) + + resp2 = http_get(self._url) + self.assertEqual(resp2.read(), data) + + +class TestWriteThenReadFile(ImageServerTestCase): + def setUp(self): + self._tid, self._url, self._path, self._cleanup = make_file_transfer() + + def tearDown(self): + self._cleanup() + + def test_put_then_get_roundtrip(self): + data = randbytes(1100, IMAGE_SIZE) + http_put(self._url, data) + resp = http_get(self._url) + self.assertEqual(resp.read(), data) + + +class TestRegisterUseUnregisterUse(ImageServerTestCase): + def test_unregistered_transfer_returns_404(self): + data = randbytes(1200, IMAGE_SIZE) + tid, url, path, cleanup = make_file_transfer(data=data) + + resp = http_get(url) + self.assertEqual(resp.read(), data) + + cleanup() + + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_get(url) + self.assertEqual(ctx.exception.code, 404) + + +class TestMultipleTransfersSimultaneous(ImageServerTestCase): + @test_timeout(120) + def test_operate_on_file_and_nbd_concurrently(self): + file_data = randbytes(1300, IMAGE_SIZE) + nbd_data = randbytes(1301, IMAGE_SIZE) + + ftid, furl, fpath, fcleanup = make_file_transfer(data=file_data) + ntid, nurl, nbd_server, ncleanup = make_nbd_transfer() + + try: + log.info("Writing data to NBD transfer %s", ntid) + http_put(nurl, nbd_data) + + log.info("Starting concurrent file + NBD reads") + with ThreadPoolExecutor(max_workers=4) as pool: + f_file = pool.submit(_fetch, furl) + f_nbd = pool.submit(_fetch, nurl) + + self.assertEqual(f_file.result(timeout=FUTURES_TIMEOUT), file_data) + self.assertEqual(f_nbd.result(timeout=FUTURES_TIMEOUT), nbd_data) + log.info("Concurrent reads completed successfully") + finally: + fcleanup() + ncleanup() + + +class TestLargeChunkedTransfer(ImageServerTestCase): + def test_put_larger_than_chunk_size_file(self): + """Upload data that spans multiple CHUNK_SIZE boundaries.""" + tid, url, path, cleanup = make_file_transfer() + try: + data = randbytes(1400, IMAGE_SIZE) + http_put(url, data) + resp = http_get(url) + self.assertEqual(resp.read(), data) + finally: + cleanup() + + def test_nbd_put_larger_than_chunk_size(self): + tid, url, nbd_server, cleanup = make_nbd_transfer() + try: + data = randbytes(1401, IMAGE_SIZE) + http_put(url, data) + resp = http_get(url) + self.assertEqual(resp.read(), data) + finally: + cleanup() + + +class TestEdgeCases(ImageServerTestCase): + def test_get_not_found_path(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_get(f"{self.base_url}/not/images/path") + self.assertEqual(ctx.exception.code, 404) + + def test_post_unknown_tail(self): + tid, url, path, cleanup = make_file_transfer() + try: + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_post(f"{url}/unknown") + self.assertEqual(ctx.exception.code, 404) + finally: + cleanup() + + def test_get_extents_then_flush_nbd(self): + tid, url, nbd_server, cleanup = make_nbd_transfer() + try: + http_put(url, randbytes(1500, IMAGE_SIZE)) + + resp = http_get(f"{url}/extents") + self.assertEqual(resp.status, 200) + resp.read() + + resp2 = http_post(f"{url}/flush") + body = json.loads(resp2.read()) + self.assertTrue(body["ok"]) + finally: + cleanup() + + +if __name__ == "__main__": + try: + unittest.main() + finally: + shutdown_image_server() diff --git a/scripts/vm/hypervisor/kvm/imageserver/tests/test_control_socket.py b/scripts/vm/hypervisor/kvm/imageserver/tests/test_control_socket.py new file mode 100644 index 00000000000..187592ff107 --- /dev/null +++ b/scripts/vm/hypervisor/kvm/imageserver/tests/test_control_socket.py @@ -0,0 +1,258 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for the Unix domain control socket protocol (register / unregister / status).""" + +import json +import socket +import unittest +import uuid +from concurrent.futures import ThreadPoolExecutor, as_completed + +from .test_base import ImageServerTestCase, make_tmp_image, shutdown_image_server, test_timeout + + +class TestStatus(ImageServerTestCase): + def test_status_returns_ok(self): + resp = self.ctrl({"action": "status"}) + self.assertEqual(resp["status"], "ok") + self.assertIn("active_transfers", resp) + + def test_status_count_is_integer(self): + resp = self.ctrl({"action": "status"}) + self.assertIsInstance(resp["active_transfers"], int) + self.assertGreaterEqual(resp["active_transfers"], 0) + + +class TestRegister(ImageServerTestCase): + def test_register_file_backend(self): + img = make_tmp_image() + tid = f"test-{uuid.uuid4().hex[:8]}" + try: + resp = self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "file", "file": img}, + }) + self.assertEqual(resp["status"], "ok") + self.assertGreaterEqual(resp["active_transfers"], 1) + finally: + self.ctrl({"action": "unregister", "transfer_id": tid}) + + def test_register_nbd_backend(self): + tid = f"test-{uuid.uuid4().hex[:8]}" + try: + resp = self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "nbd", "socket": "/tmp/fake.sock"}, + }) + self.assertEqual(resp["status"], "ok") + finally: + self.ctrl({"action": "unregister", "transfer_id": tid}) + + def test_register_increments_active_count(self): + img = make_tmp_image() + before = self.ctrl({"action": "status"})["active_transfers"] + tid = f"test-{uuid.uuid4().hex[:8]}" + try: + self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "file", "file": img}, + }) + after = self.ctrl({"action": "status"})["active_transfers"] + self.assertEqual(after, before + 1) + finally: + self.ctrl({"action": "unregister", "transfer_id": tid}) + + def test_register_missing_transfer_id(self): + img = make_tmp_image() + resp = self.ctrl({ + "action": "register", + "config": {"backend": "file", "file": img}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_empty_transfer_id(self): + img = make_tmp_image() + resp = self.ctrl({ + "action": "register", + "transfer_id": "", + "config": {"backend": "file", "file": img}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_missing_config(self): + resp = self.ctrl({ + "action": "register", + "transfer_id": f"test-{uuid.uuid4().hex[:8]}", + }) + self.assertEqual(resp["status"], "error") + + def test_register_invalid_backend(self): + resp = self.ctrl({ + "action": "register", + "transfer_id": f"test-{uuid.uuid4().hex[:8]}", + "config": {"backend": "invalid"}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_file_missing_path(self): + resp = self.ctrl({ + "action": "register", + "transfer_id": f"test-{uuid.uuid4().hex[:8]}", + "config": {"backend": "file"}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_nbd_missing_socket(self): + resp = self.ctrl({ + "action": "register", + "transfer_id": f"test-{uuid.uuid4().hex[:8]}", + "config": {"backend": "nbd"}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_path_traversal_rejected(self): + img = make_tmp_image() + resp = self.ctrl({ + "action": "register", + "transfer_id": "../etc/passwd", + "config": {"backend": "file", "file": img}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_dot_rejected(self): + img = make_tmp_image() + resp = self.ctrl({ + "action": "register", + "transfer_id": ".", + "config": {"backend": "file", "file": img}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_slash_rejected(self): + img = make_tmp_image() + resp = self.ctrl({ + "action": "register", + "transfer_id": "a/b", + "config": {"backend": "file", "file": img}, + }) + self.assertEqual(resp["status"], "error") + + def test_register_duplicate_replaces(self): + img = make_tmp_image() + tid = f"test-{uuid.uuid4().hex[:8]}" + try: + self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "file", "file": img}, + }) + count_before = self.ctrl({"action": "status"})["active_transfers"] + self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "file", "file": img}, + }) + count_after = self.ctrl({"action": "status"})["active_transfers"] + self.assertEqual(count_after, count_before) + finally: + self.ctrl({"action": "unregister", "transfer_id": tid}) + + +class TestUnregister(ImageServerTestCase): + def test_unregister_existing(self): + img = make_tmp_image() + tid = f"test-{uuid.uuid4().hex[:8]}" + self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "file", "file": img}, + }) + before = self.ctrl({"action": "status"})["active_transfers"] + resp = self.ctrl({"action": "unregister", "transfer_id": tid}) + self.assertEqual(resp["status"], "ok") + self.assertEqual(resp["active_transfers"], before - 1) + + def test_unregister_nonexistent(self): + resp = self.ctrl({"action": "unregister", "transfer_id": "does-not-exist"}) + self.assertEqual(resp["status"], "ok") + + def test_unregister_missing_id(self): + resp = self.ctrl({"action": "unregister"}) + self.assertEqual(resp["status"], "error") + + +class TestUnknownAction(ImageServerTestCase): + def test_unknown_action(self): + resp = self.ctrl({"action": "foobar"}) + self.assertEqual(resp["status"], "error") + self.assertIn("unknown", resp.get("message", "").lower()) + + +class TestMalformed(ImageServerTestCase): + def test_malformed_json(self): + sock_path = self.server["ctrl_sock"] + payload = b"not valid json\n" + with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s: + s.settimeout(5) + s.connect(sock_path) + s.sendall(payload) + s.shutdown(socket.SHUT_WR) + data = b"" + while True: + chunk = s.recv(4096) + if not chunk: + break + data += chunk + resp = json.loads(data.strip()) + self.assertEqual(resp["status"], "error") + + +class TestConcurrentRegistrations(ImageServerTestCase): + @test_timeout(60) + def test_concurrent_registers(self): + img = make_tmp_image() + tids = [f"conc-{uuid.uuid4().hex[:8]}" for _ in range(20)] + results = [] + + def register_one(tid): + return self.ctrl({ + "action": "register", + "transfer_id": tid, + "config": {"backend": "file", "file": img}, + }) + + try: + with ThreadPoolExecutor(max_workers=10) as pool: + futures = {pool.submit(register_one, tid): tid for tid in tids} + for f in as_completed(futures, timeout=30): + results.append(f.result()) + + self.assertTrue(all(r["status"] == "ok" for r in results)) + finally: + for tid in tids: + self.ctrl({"action": "unregister", "transfer_id": tid}) + + +if __name__ == "__main__": + try: + unittest.main() + finally: + shutdown_image_server() diff --git a/scripts/vm/hypervisor/kvm/imageserver/tests/test_file_backend.py b/scripts/vm/hypervisor/kvm/imageserver/tests/test_file_backend.py new file mode 100644 index 00000000000..be6eb259cc3 --- /dev/null +++ b/scripts/vm/hypervisor/kvm/imageserver/tests/test_file_backend.py @@ -0,0 +1,230 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for HTTP operations against a file-backend transfer.""" + +import json +import os +import unittest +import urllib.error + +from .test_base import ( + IMAGE_SIZE, + ImageServerTestCase, + http_get, + http_options, + http_patch, + http_post, + http_put, + make_file_transfer, + randbytes, + shutdown_image_server, +) + + +class FileBackendTestCase(ImageServerTestCase): + """Base that creates a file-backend transfer per test.""" + + def setUp(self): + self._tid, self._url, self._path, self._cleanup = make_file_transfer() + + def tearDown(self): + self._cleanup() + + +class TestOptions(FileBackendTestCase): + def test_options_returns_features(self): + resp = http_options(self._url) + self.assertEqual(resp.status, 200) + body = json.loads(resp.read()) + self.assertIn("flush", body["features"]) + self.assertGreaterEqual(body["max_readers"], 1) + self.assertGreaterEqual(body["max_writers"], 1) + + def test_options_allowed_methods(self): + resp = http_options(self._url) + methods = resp.getheader("Access-Control-Allow-Methods") + for m in ("GET", "PUT", "POST", "OPTIONS"): + self.assertIn(m, methods) + + +class TestGetFull(FileBackendTestCase): + def test_get_full_returns_file_content(self): + with open(self._path, "rb") as f: + expected = f.read() + resp = http_get(self._url) + self.assertEqual(resp.status, 200) + data = resp.read() + self.assertEqual(len(data), len(expected)) + self.assertEqual(data, expected) + + def test_get_full_content_type(self): + resp = http_get(self._url) + resp.read() + self.assertIn("application/octet-stream", resp.getheader("Content-Type")) + + def test_get_full_content_length(self): + resp = http_get(self._url) + resp.read() + self.assertEqual(int(resp.getheader("Content-Length")), os.path.getsize(self._path)) + + +class TestGetRange(FileBackendTestCase): + def test_get_range_partial(self): + with open(self._path, "rb") as f: + f.seek(100) + expected = f.read(200) + resp = http_get(self._url, headers={"Range": "bytes=100-299"}) + self.assertEqual(resp.status, 206) + self.assertEqual(resp.read(), expected) + + def test_get_range_content_range_header(self): + size = os.path.getsize(self._path) + resp = http_get(self._url, headers={"Range": "bytes=0-99"}) + self.assertEqual(resp.status, 206) + resp.read() + self.assertEqual(resp.getheader("Content-Range"), f"bytes 0-99/{size}") + + def test_get_range_suffix(self): + with open(self._path, "rb") as f: + expected = f.read()[-100:] + resp = http_get(self._url, headers={"Range": "bytes=-100"}) + self.assertEqual(resp.status, 206) + self.assertEqual(resp.read(), expected) + + def test_get_range_open_ended(self): + with open(self._path, "rb") as f: + f.seek(IMAGE_SIZE - 50) + expected = f.read() + resp = http_get(self._url, headers={"Range": f"bytes={IMAGE_SIZE - 50}-"}) + self.assertEqual(resp.status, 206) + self.assertEqual(resp.read(), expected) + + def test_get_range_unsatisfiable(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_get(self._url, headers={"Range": f"bytes={IMAGE_SIZE + 100}-{IMAGE_SIZE + 200}"}) + self.assertEqual(ctx.exception.code, 416) + + +class TestPut(FileBackendTestCase): + def test_put_full_upload(self): + new_data = randbytes(99, IMAGE_SIZE) + resp = http_put(self._url, new_data) + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + self.assertEqual(body["bytes_written"], IMAGE_SIZE) + + with open(self._path, "rb") as f: + self.assertEqual(f.read(), new_data) + + def test_put_with_flush(self): + new_data = randbytes(100, IMAGE_SIZE) + resp = http_put(f"{self._url}?flush=y", new_data) + body = json.loads(resp.read()) + self.assertTrue(body["ok"]) + self.assertTrue(body["flushed"]) + + def test_put_verify_by_get(self): + new_data = randbytes(101, IMAGE_SIZE) + http_put(self._url, new_data) + resp = http_get(self._url) + self.assertEqual(resp.read(), new_data) + + def test_put_with_content_range_rejected(self): + data = b"x" * 100 + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_put(self._url, data, headers={"Content-Range": "bytes 0-99/*"}) + self.assertEqual(ctx.exception.code, 400) + + def test_put_with_range_header_rejected(self): + data = b"x" * 100 + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_put(self._url, data, headers={"Range": "bytes=0-99"}) + self.assertEqual(ctx.exception.code, 400) + + +class TestFlush(FileBackendTestCase): + def test_post_flush(self): + resp = http_post(f"{self._url}/flush") + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + + +class TestPatchRejected(FileBackendTestCase): + def test_patch_rejected_for_file(self): + data = json.dumps({"op": "zero", "size": 100}).encode() + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_patch(self._url, data, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(data)), + }) + self.assertEqual(ctx.exception.code, 400) + + +class TestExtentsRejected(FileBackendTestCase): + def test_extents_rejected_for_file(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_get(f"{self._url}/extents") + self.assertEqual(ctx.exception.code, 400) + + +class TestUnknownImage(ImageServerTestCase): + def test_get_unknown_image(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_get(f"{self.base_url}/images/nonexistent-id") + self.assertEqual(ctx.exception.code, 404) + + def test_put_unknown_image(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_put(f"{self.base_url}/images/nonexistent-id", b"data") + self.assertEqual(ctx.exception.code, 404) + + def test_options_unknown_image(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_options(f"{self.base_url}/images/nonexistent-id") + self.assertEqual(ctx.exception.code, 404) + + +class TestRoundTrip(FileBackendTestCase): + def test_put_then_get_roundtrip(self): + payload = randbytes(200, IMAGE_SIZE) + http_put(self._url, payload) + resp = http_get(self._url) + self.assertEqual(resp.read(), payload) + + def test_put_then_ranged_get_roundtrip(self): + payload = randbytes(201, IMAGE_SIZE) + http_put(self._url, payload) + resp = http_get(self._url, headers={"Range": "bytes=512-1023"}) + self.assertEqual(resp.read(), payload[512:1024]) + + def test_multiple_puts_last_wins(self): + first = randbytes(300, IMAGE_SIZE) + second = randbytes(301, IMAGE_SIZE) + http_put(self._url, first) + http_put(self._url, second) + resp = http_get(self._url) + self.assertEqual(resp.read(), second) + + +if __name__ == "__main__": + try: + unittest.main() + finally: + shutdown_image_server() diff --git a/scripts/vm/hypervisor/kvm/imageserver/tests/test_nbd_backend.py b/scripts/vm/hypervisor/kvm/imageserver/tests/test_nbd_backend.py new file mode 100644 index 00000000000..4c0e66003b3 --- /dev/null +++ b/scripts/vm/hypervisor/kvm/imageserver/tests/test_nbd_backend.py @@ -0,0 +1,393 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Tests for HTTP operations against an NBD-backend transfer (real qemu-nbd).""" + +import json +import unittest +import urllib.error +import urllib.request + +from .test_base import ( + IMAGE_SIZE, + ImageServerTestCase, + http_get, + http_options, + http_patch, + http_post, + http_put, + make_nbd_transfer, + randbytes, + shutdown_image_server, +) + + +class NbdBackendTestCase(ImageServerTestCase): + """Base that creates an NBD-backend transfer per test.""" + + def setUp(self): + self._tid, self._url, self._nbd, self._cleanup = make_nbd_transfer() + + def tearDown(self): + self._cleanup() + + +class TestOptions(NbdBackendTestCase): + def test_options_returns_extents_feature(self): + resp = http_options(self._url) + self.assertEqual(resp.status, 200) + body = json.loads(resp.read()) + self.assertIn("extents", body["features"]) + + def test_options_includes_patch_method(self): + resp = http_options(self._url) + methods = resp.getheader("Access-Control-Allow-Methods") + self.assertIn("PATCH", methods) + + def test_options_has_capabilities(self): + resp = http_options(self._url) + body = json.loads(resp.read()) + self.assertGreaterEqual(body["max_readers"], 1) + self.assertGreaterEqual(body["max_writers"], 1) + + +class TestGetFull(NbdBackendTestCase): + def test_get_full_returns_image_data(self): + with open(self._nbd.image_path, "rb") as f: + expected = f.read() + resp = http_get(self._url) + data = resp.read() + self.assertEqual(resp.status, 200) + self.assertEqual(len(data), len(expected)) + self.assertEqual(data, expected) + + def test_get_full_content_length(self): + resp = http_get(self._url) + resp.read() + self.assertEqual(int(resp.getheader("Content-Length")), IMAGE_SIZE) + + +class TestGetRange(NbdBackendTestCase): + def test_get_range_partial(self): + test_data = randbytes(50, IMAGE_SIZE) + http_put(self._url, test_data) + + resp = http_get(self._url, headers={"Range": "bytes=100-299"}) + self.assertEqual(resp.status, 206) + self.assertEqual(resp.read(), test_data[100:300]) + + def test_get_range_content_range_header(self): + resp = http_get(self._url, headers={"Range": "bytes=0-99"}) + self.assertEqual(resp.status, 206) + resp.read() + self.assertEqual(resp.getheader("Content-Range"), f"bytes 0-99/{IMAGE_SIZE}") + + def test_get_range_suffix(self): + test_data = randbytes(51, IMAGE_SIZE) + http_put(self._url, test_data) + + resp = http_get(self._url, headers={"Range": "bytes=-100"}) + self.assertEqual(resp.status, 206) + self.assertEqual(resp.read(), test_data[-100:]) + + def test_get_range_unsatisfiable(self): + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_get(self._url, headers={"Range": f"bytes={IMAGE_SIZE + 100}-{IMAGE_SIZE + 200}"}) + self.assertEqual(ctx.exception.code, 416) + + +class TestPutFull(NbdBackendTestCase): + def test_put_full_upload(self): + new_data = randbytes(60, IMAGE_SIZE) + resp = http_put(self._url, new_data) + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + self.assertEqual(body["bytes_written"], IMAGE_SIZE) + + resp2 = http_get(self._url) + self.assertEqual(resp2.read(), new_data) + + def test_put_with_flush(self): + new_data = randbytes(61, IMAGE_SIZE) + resp = http_put(f"{self._url}?flush=y", new_data) + body = json.loads(resp.read()) + self.assertTrue(body["ok"]) + self.assertTrue(body["flushed"]) + + +class TestPutRange(NbdBackendTestCase): + def test_put_content_range(self): + base_data = randbytes(70, IMAGE_SIZE) + http_put(self._url, base_data) + + patch_data = b"\xAB" * 512 + resp = http_put(self._url, patch_data, headers={ + "Content-Range": "bytes 0-511/*", + "Content-Length": str(len(patch_data)), + }) + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + self.assertEqual(body["bytes_written"], 512) + + resp2 = http_get(self._url, headers={"Range": "bytes=0-511"}) + self.assertEqual(resp2.read(), patch_data) + + resp3 = http_get(self._url, headers={"Range": "bytes=512-1023"}) + self.assertEqual(resp3.read(), base_data[512:1024]) + + def test_put_content_range_with_flush(self): + base_data = b"\x00" * IMAGE_SIZE + http_put(self._url, base_data) + + patch_data = b"\xFF" * 256 + resp = http_put(f"{self._url}?flush=y", patch_data, headers={ + "Content-Range": "bytes 1024-1279/*", + "Content-Length": str(len(patch_data)), + }) + body = json.loads(resp.read()) + self.assertTrue(body["ok"]) + self.assertTrue(body["flushed"]) + + +class TestPatchRange(NbdBackendTestCase): + def test_patch_binary_range(self): + base_data = randbytes(80, IMAGE_SIZE) + http_put(self._url, base_data) + + patch_data = b"\xCD" * 1024 + resp = http_patch(self._url, patch_data, headers={ + "Range": "bytes=2048-3071", + "Content-Type": "application/octet-stream", + "Content-Length": str(len(patch_data)), + }) + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + self.assertEqual(body["bytes_written"], 1024) + + resp2 = http_get(self._url, headers={"Range": "bytes=2048-3071"}) + self.assertEqual(resp2.read(), patch_data) + + def test_patch_multiple_ranges_preserves_unwritten(self): + base_data = randbytes(81, IMAGE_SIZE) + http_put(self._url, base_data) + + patch1 = b"\x11" * 256 + http_patch(self._url, patch1, headers={ + "Range": "bytes=0-255", + "Content-Type": "application/octet-stream", + "Content-Length": "256", + }) + + patch2 = b"\x22" * 256 + http_patch(self._url, patch2, headers={ + "Range": "bytes=512-767", + "Content-Type": "application/octet-stream", + "Content-Length": "256", + }) + + resp = http_get(self._url, headers={"Range": "bytes=0-767"}) + got = resp.read() + self.assertEqual(got[:256], patch1) + self.assertEqual(got[256:512], base_data[256:512]) + self.assertEqual(got[512:768], patch2) + + +class TestPatchZero(NbdBackendTestCase): + def test_patch_zero(self): + data = randbytes(90, IMAGE_SIZE) + http_put(self._url, data) + + payload = json.dumps({"op": "zero", "size": 4096, "offset": 0}).encode() + resp = http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + + resp2 = http_get(self._url, headers={"Range": "bytes=0-4095"}) + self.assertEqual(resp2.read(), b"\x00" * 4096) + + def test_patch_zero_with_flush(self): + data = b"\xFF" * IMAGE_SIZE + http_put(self._url, data) + + payload = json.dumps({"op": "zero", "size": 512, "offset": 1024, "flush": True}).encode() + resp = http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + body = json.loads(resp.read()) + self.assertTrue(body["ok"]) + + resp2 = http_get(self._url, headers={"Range": "bytes=1024-1535"}) + self.assertEqual(resp2.read(), b"\x00" * 512) + + def test_patch_zero_preserves_neighbors(self): + data = randbytes(91, IMAGE_SIZE) + http_put(self._url, data) + + payload = json.dumps({"op": "zero", "size": 256, "offset": 512}).encode() + http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + + resp = http_get(self._url, headers={"Range": "bytes=0-1023"}) + got = resp.read() + self.assertEqual(got[:512], data[:512]) + self.assertEqual(got[512:768], b"\x00" * 256) + self.assertEqual(got[768:1024], data[768:1024]) + + +class TestPatchFlush(NbdBackendTestCase): + def test_patch_flush_op(self): + payload = json.dumps({"op": "flush"}).encode() + resp = http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + + +class TestPostFlush(NbdBackendTestCase): + def test_post_flush(self): + resp = http_post(f"{self._url}/flush") + body = json.loads(resp.read()) + self.assertEqual(resp.status, 200) + self.assertTrue(body["ok"]) + + +class TestExtents(NbdBackendTestCase): + def test_get_allocation_extents(self): + resp = http_get(f"{self._url}/extents") + self.assertEqual(resp.status, 200) + extents = json.loads(resp.read()) + self.assertIsInstance(extents, list) + self.assertGreaterEqual(len(extents), 1) + for ext in extents: + self.assertIn("start", ext) + self.assertIn("length", ext) + self.assertIn("zero", ext) + + def test_extents_cover_full_image(self): + resp = http_get(f"{self._url}/extents") + extents = json.loads(resp.read()) + total = sum(e["length"] for e in extents) + self.assertEqual(total, IMAGE_SIZE) + + def test_extents_dirty_context_without_bitmap(self): + resp = http_get(f"{self._url}/extents?context=dirty") + self.assertEqual(resp.status, 200) + extents = json.loads(resp.read()) + self.assertIsInstance(extents, list) + self.assertGreaterEqual(len(extents), 1) + for ext in extents: + self.assertIn("dirty", ext) + self.assertTrue(ext["dirty"]) + + def test_extents_after_write_and_zero(self): + http_put(self._url, randbytes(95, IMAGE_SIZE)) + + payload = json.dumps({"op": "zero", "size": 4096, "offset": 0}).encode() + http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + + resp = http_get(f"{self._url}/extents") + extents = json.loads(resp.read()) + self.assertGreaterEqual(len(extents), 1) + total = sum(e["length"] for e in extents) + self.assertEqual(total, IMAGE_SIZE) + + +class TestErrorCases(NbdBackendTestCase): + def test_patch_unsupported_op(self): + payload = json.dumps({"op": "invalid"}).encode() + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + self.assertEqual(ctx.exception.code, 400) + + def test_patch_zero_missing_size(self): + payload = json.dumps({"op": "zero", "offset": 0}).encode() + with self.assertRaises(urllib.error.HTTPError) as ctx: + http_patch(self._url, payload, headers={ + "Content-Type": "application/json", + "Content-Length": str(len(payload)), + }) + self.assertEqual(ctx.exception.code, 400) + + def test_put_missing_content_length(self): + import http.client + from urllib.parse import urlparse + parsed = urlparse(self._url) + conn = http.client.HTTPConnection(parsed.hostname, parsed.port, timeout=30) + try: + conn.putrequest("PUT", parsed.path) + conn.endheaders() + resp = conn.getresponse() + self.assertEqual(resp.status, 400) + finally: + conn.close() + + +class TestRoundTrip(NbdBackendTestCase): + def test_write_read_full_roundtrip(self): + payload = randbytes(110, IMAGE_SIZE) + http_put(self._url, payload) + resp = http_get(self._url) + self.assertEqual(resp.read(), payload) + + def test_write_read_range_roundtrip(self): + payload = randbytes(111, IMAGE_SIZE) + http_put(self._url, payload) + + for start, end in [(0, 255), (1024, 2047), (IMAGE_SIZE - 512, IMAGE_SIZE - 1)]: + resp = http_get(self._url, headers={"Range": f"bytes={start}-{end}"}) + self.assertEqual(resp.read(), payload[start:end + 1]) + + def test_range_write_read_roundtrip(self): + http_put(self._url, b"\x00" * IMAGE_SIZE) + + chunk = randbytes(112, 4096) + http_put(self._url, chunk, headers={ + "Content-Range": "bytes 8192-12287/*", + "Content-Length": str(len(chunk)), + }) + + resp = http_get(self._url, headers={"Range": "bytes=8192-12287"}) + self.assertEqual(resp.read(), chunk) + + resp2 = http_get(self._url, headers={"Range": "bytes=0-4095"}) + self.assertEqual(resp2.read(), b"\x00" * 4096) + + +if __name__ == "__main__": + try: + unittest.main() + finally: + shutdown_image_server()