Image server unit tests

This commit is contained in:
Abhisar Sinha 2026-03-23 20:53:37 +05:30 committed by Abhishek Kumar
parent 81fc6d5da6
commit dad314a8a6
6 changed files with 1734 additions and 0 deletions

View File

@ -0,0 +1,16 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

View File

@ -0,0 +1,440 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Shared infrastructure for the image-server test suite (stdlib unittest only).
Provides:
- A singleton image server process started once for the entire test run.
- Control-socket helpers using pure-Python AF_UNIX (no socat).
- qemu-nbd server management.
- Transfer registration / teardown helpers.
- HTTP helper functions.
"""
import functools
import json
import logging
import os
import random
import select
import shutil
import signal
import socket
import subprocess
import sys
import tempfile
import time
import unittest
import uuid
from pathlib import Path
from typing import Any, Dict, Optional
# Test-wide tuning constants shared by every test module.
IMAGE_SIZE = 1 * 1024 * 1024  # 1 MiB
SERVER_STARTUP_TIMEOUT = 10  # seconds to wait for the server's control socket
QEMU_NBD_STARTUP_TIMEOUT = 5  # seconds to wait for the qemu-nbd Unix socket
HTTP_TIMEOUT = 30  # seconds per HTTP request

# Harness logging goes to stderr so it interleaves with unittest's output.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stderr,
    format="%(asctime)s [TEST] %(message)s",
)
log = logging.getLogger(__name__)
def randbytes(seed, n):
    """Return *n* deterministic pseudo-random bytes derived from *seed*.

    A dedicated random.Random instance is used so repeated calls with the
    same (seed, n) always yield identical data, independent of global
    random-module state (works on Python 3.6+, which lacks
    random.Random.randbytes()).
    """
    # getrandbits() rejects a zero bit count on Python < 3.9, so handle
    # the degenerate case explicitly instead of letting it raise.
    if n <= 0:
        return b""
    rng = random.Random(seed)
    return rng.getrandbits(8 * n).to_bytes(n, "big")
def test_timeout(seconds):
"""Decorator that fails a test if it exceeds *seconds* (SIGALRM, Unix only)."""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
def _alarm(signum, frame):
raise TimeoutError(
"{} timed out after {}s".format(func.__qualname__, seconds)
)
prev = signal.signal(signal.SIGALRM, _alarm)
signal.alarm(seconds)
try:
return func(*args, **kwargs)
finally:
signal.alarm(0)
signal.signal(signal.SIGALRM, prev)
return wrapper
return decorator
# ── Singleton state shared across all test modules ──────────────────────
# Populated lazily by get_tmp_dir() / get_image_server() and reset by
# shutdown_image_server().
_tmp_dir: Optional[str] = None  # shared scratch directory for the whole run
_server_proc: Optional[subprocess.Popen] = None  # running image-server process
_server_info: Optional[Dict[str, Any]] = None  # cached connection info dict
def _free_port() -> int:
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
s.bind(("127.0.0.1", 0))
return s.getsockname()[1]
def control_socket_send(sock_path: str, message: dict, retries: int = 5) -> dict:
    """Send a JSON message to the control socket and return the parsed response.

    The request is newline-terminated JSON; after sending, the write side
    is half-closed so the server sees EOF, then the full reply is read
    until the server closes the connection.

    Retries with linear back-off on socket-level failures. Raises the last
    OSError when every attempt fails, or ValueError when retries < 1.
    """
    payload = (json.dumps(message) + "\n").encode("utf-8")
    last_err: Optional[OSError] = None
    for attempt in range(retries):
        try:
            with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as s:
                s.settimeout(5)
                s.connect(sock_path)
                s.sendall(payload)
                # Half-close: tells the server the request is complete.
                s.shutdown(socket.SHUT_WR)
                data = b""
                while True:
                    chunk = s.recv(4096)
                    if not chunk:
                        break
                    data += chunk
                return json.loads(data.strip())
        except OSError as e:
            # OSError covers ConnectionRefusedError and BlockingIOError too.
            last_err = e
            time.sleep(0.1 * (attempt + 1))  # linear back-off between attempts
    if last_err is None:
        # The loop body never ran; raising None would be a TypeError.
        raise ValueError("retries must be >= 1")
    raise last_err
def _wait_for_control_socket(sock_path: str, timeout: float = SERVER_STARTUP_TIMEOUT) -> None:
    """Poll until a status round-trip on *sock_path* succeeds.

    Raises RuntimeError when the server does not answer within *timeout*.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            reply = control_socket_send(sock_path, {"action": "status"})
            if reply.get("status") == "ok":
                return
        except (ConnectionRefusedError, FileNotFoundError, OSError):
            # Server not accepting connections yet; keep polling.
            pass
        time.sleep(0.2)
    raise RuntimeError(
        f"Image server control socket at {sock_path} not ready within {timeout}s"
    )
def _wait_for_nbd_socket(sock_path: str, timeout: float = QEMU_NBD_STARTUP_TIMEOUT) -> None:
    """Block until *sock_path* exists and accepts a connection.

    Raises RuntimeError when qemu-nbd is not serving within *timeout*.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if os.path.exists(sock_path):
            probe = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            try:
                probe.settimeout(1)
                probe.connect(sock_path)
                return
            except OSError:
                # Socket file present but qemu-nbd not accepting yet.
                pass
            finally:
                probe.close()
        time.sleep(0.2)
    raise RuntimeError(
        f"qemu-nbd socket at {sock_path} not ready within {timeout}s"
    )
def get_tmp_dir() -> str:
    """Return the test run's shared scratch directory, creating it lazily."""
    global _tmp_dir
    if _tmp_dir is not None:
        return _tmp_dir
    _tmp_dir = tempfile.mkdtemp(prefix="imageserver_test_")
    return _tmp_dir
def get_image_server() -> Dict[str, Any]:
    """Return the singleton image-server info dict, starting it if needed.

    The returned dict contains:
      base_url  -- "http://127.0.0.1:<port>" root for image requests
      port      -- TCP port the server listens on
      ctrl_sock -- path to the Unix control socket
      send      -- callable(msg: dict) -> dict over the control socket

    Raises RuntimeError (including the child's captured stdout/stderr)
    when the server does not come up within SERVER_STARTUP_TIMEOUT.
    """
    global _server_proc, _server_info
    if _server_info is not None:
        return _server_info
    tmp = get_tmp_dir()
    port = _free_port()
    ctrl_sock = os.path.join(tmp, "ctrl.sock")
    # Run "python -m imageserver" with the package's parent directory on
    # PYTHONPATH so the module is importable regardless of the cwd the
    # tests were launched from.
    imageserver_pkg = str(Path(__file__).resolve().parent.parent)
    parent_dir = str(Path(imageserver_pkg).parent)
    env = os.environ.copy()
    env["PYTHONPATH"] = parent_dir + os.pathsep + env.get("PYTHONPATH", "")
    proc = subprocess.Popen(
        [
            sys.executable, "-m", "imageserver",
            "--listen", "127.0.0.1",
            "--port", str(port),
            "--control-socket", ctrl_sock,
        ],
        cwd=parent_dir,
        env=env,
        stdout=subprocess.PIPE,
        stderr=subprocess.PIPE,
    )
    _server_proc = proc
    try:
        _wait_for_control_socket(ctrl_sock)
    except RuntimeError:
        # Startup failed: kill the child and surface its output.
        proc.kill()
        stdout, stderr = proc.communicate(timeout=5)
        raise RuntimeError(
            f"Image server failed to start.\nstdout: {stdout.decode()}\nstderr: {stderr.decode()}"
        )

    def send(msg: dict) -> dict:
        # Bound helper so callers don't need to know the socket path.
        return control_socket_send(ctrl_sock, msg)

    _server_info = {
        "base_url": f"http://127.0.0.1:{port}",
        "port": port,
        "ctrl_sock": ctrl_sock,
        "send": send,
    }
    return _server_info
def shutdown_image_server() -> None:
    """Stop the singleton server (if running) and delete the shared tmp dir."""
    global _server_proc, _server_info, _tmp_dir
    proc = _server_proc
    if proc is not None:
        # Close the captured pipes first so the child cannot block on a
        # full pipe buffer while shutting down.
        for stream in (proc.stdout, proc.stderr):
            if stream:
                try:
                    stream.close()
                except Exception:
                    pass
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            # Escalate to SIGKILL when a polite terminate is ignored.
            proc.kill()
            proc.wait(timeout=5)
        _server_proc = None
        _server_info = None
    if _tmp_dir is not None:
        shutil.rmtree(_tmp_dir, ignore_errors=True)
        _tmp_dir = None
# ── qemu-nbd server ────────────────────────────────────────────────────
class QemuNbdServer:
    """Manages a qemu-nbd process exporting a raw image over a Unix socket."""

    def __init__(self, image_path: str, socket_path: str, image_size: int = IMAGE_SIZE):
        self.image_path = image_path
        self.socket_path = socket_path
        self.image_size = image_size
        self._proc: Optional[subprocess.Popen] = None

    def start(self) -> None:
        """Create the backing image if missing and launch qemu-nbd."""
        if not os.path.exists(self.image_path):
            # A sparse file truncated to size is sufficient for raw format.
            with open(self.image_path, "wb") as f:
                f.truncate(self.image_size)
        command = [
            "qemu-nbd",
            "--socket", self.socket_path,
            "--format", "raw",
            "--persistent",
            "--shared=8",
            "--cache=none",
            self.image_path,
        ]
        self._proc = subprocess.Popen(
            command, stdout=subprocess.PIPE, stderr=subprocess.PIPE
        )
        # Block until the export socket actually accepts connections.
        _wait_for_nbd_socket(self.socket_path)

    def stop(self) -> None:
        """Terminate qemu-nbd, escalating to SIGKILL if it will not exit."""
        proc = self._proc
        if proc is None:
            return
        for stream in (proc.stdout, proc.stderr):
            if stream:
                try:
                    stream.close()
                except Exception:
                    pass
        proc.terminate()
        try:
            proc.wait(timeout=5)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait(timeout=5)
        self._proc = None
# ── Factory helpers ─────────────────────────────────────────────────────
def make_tmp_image(data=None, image_size=IMAGE_SIZE) -> str:
    """Create a temp raw image file in the shared tmp dir; return its path.

    Writes *data* verbatim when given, otherwise fills the file with
    deterministic pseudo-random bytes (seed 42) of *image_size* length.
    """
    path = os.path.join(get_tmp_dir(), f"img_{uuid.uuid4().hex[:8]}.raw")
    payload = data if data is not None else randbytes(42, image_size)
    with open(path, "wb") as f:
        f.write(payload)
    return path
def make_file_transfer(data=None, image_size=IMAGE_SIZE):
    """Create a temp file and register a file-backend transfer for it.

    Returns (transfer_id, url, file_path, cleanup_callable); call the
    cleanup to unregister the transfer and delete the backing file.
    """
    srv = get_image_server()
    path = make_tmp_image(data=data, image_size=image_size)
    transfer_id = f"file-{uuid.uuid4().hex[:8]}"
    resp = srv["send"]({
        "action": "register",
        "transfer_id": transfer_id,
        "config": {"backend": "file", "file": path},
    })
    assert resp["status"] == "ok", f"register failed: {resp}"

    def cleanup():
        # Unregister first so the server drops its handle on the file.
        srv["send"]({"action": "unregister", "transfer_id": transfer_id})
        try:
            os.unlink(path)
        except FileNotFoundError:
            pass

    url = f"{srv['base_url']}/images/{transfer_id}"
    return transfer_id, url, path, cleanup
def make_nbd_transfer(image_size=IMAGE_SIZE):
    """Start a qemu-nbd server and register an NBD-backend transfer for it.

    Returns (transfer_id, url, QemuNbdServer, cleanup_callable); the
    cleanup unregisters the transfer, stops qemu-nbd, and removes both
    the image file and the export socket.
    """
    srv = get_image_server()
    tmp = get_tmp_dir()
    img_path = os.path.join(tmp, f"nbd_{uuid.uuid4().hex[:8]}.raw")
    sock_path = os.path.join(tmp, f"nbd_{uuid.uuid4().hex[:8]}.sock")
    server = QemuNbdServer(img_path, sock_path, image_size=image_size)
    server.start()
    transfer_id = f"nbd-{uuid.uuid4().hex[:8]}"
    resp = srv["send"]({
        "action": "register",
        "transfer_id": transfer_id,
        "config": {"backend": "nbd", "socket": sock_path},
    })
    assert resp["status"] == "ok", f"register failed: {resp}"

    def cleanup():
        srv["send"]({"action": "unregister", "transfer_id": transfer_id})
        server.stop()
        for leftover in (img_path, sock_path):
            try:
                os.unlink(leftover)
            except FileNotFoundError:
                pass

    url = f"{srv['base_url']}/images/{transfer_id}"
    return transfer_id, url, server, cleanup
# ── HTTP helpers ────────────────────────────────────────────────────────
import urllib.request
import urllib.error
def http_get(url, headers=None, timeout=HTTP_TIMEOUT):
    """GET *url*; returns the open HTTPResponse (caller reads and closes)."""
    request = urllib.request.Request(url, headers=headers or {})
    return urllib.request.urlopen(request, timeout=timeout)
def http_put(url, data, headers=None, timeout=HTTP_TIMEOUT):
    """PUT *data* to *url*; caller-supplied headers override Content-Length."""
    hdrs = {"Content-Length": str(len(data)), **(headers or {})}
    request = urllib.request.Request(url, data=data, headers=hdrs, method="PUT")
    return urllib.request.urlopen(request, timeout=timeout)
def http_post(url, data=b"", headers=None, timeout=HTTP_TIMEOUT):
    """POST *data* (empty body by default) to *url*."""
    request = urllib.request.Request(
        url, data=data, headers=dict(headers or {}), method="POST"
    )
    return urllib.request.urlopen(request, timeout=timeout)
def http_options(url, timeout=HTTP_TIMEOUT):
    """Send an OPTIONS request to *url* and return the response."""
    request = urllib.request.Request(url, method="OPTIONS")
    return urllib.request.urlopen(request, timeout=timeout)
def http_patch(url, data, headers=None, timeout=HTTP_TIMEOUT):
    """PATCH *data* to *url* with optional extra headers."""
    request = urllib.request.Request(
        url, data=data, headers=dict(headers or {}), method="PATCH"
    )
    return urllib.request.urlopen(request, timeout=timeout)
# ── Base TestCase with shared setUp/tearDown ────────────────────────────
class ImageServerTestCase(unittest.TestCase):
    """
    Base class for image-server tests.
    The singleton server is started (once per run) before any test method.
    Subclasses needing a file or NBD transfer create it in setUp() and
    dispose of it in tearDown().
    """

    @classmethod
    def setUpClass(cls):
        cls.server = get_image_server()
        cls.base_url = cls.server["base_url"]

    def ctrl(self, msg):
        """Send a control-socket message via the server's bound send helper."""
        return self.server["send"](msg)

    def _make_tmp_image(self, data=None):
        return make_tmp_image(data=data)

    def _register_file_transfer(self, data=None):
        return make_file_transfer(data=data)

    def _register_nbd_transfer(self):
        return make_nbd_transfer()

    @staticmethod
    def dump_server_logs():
        """Print any buffered server stderr for post-mortem debugging."""
        proc = _server_proc
        if proc is None or proc.stderr is None:
            return
        try:
            # Zero-timeout select: only read when stderr has data right now.
            if select.select([proc.stderr], [], [], 0)[0]:
                blob = proc.stderr.read1(64 * 1024)
                if blob:
                    sys.stderr.write("\n=== IMAGE SERVER STDERR ===\n")
                    sys.stderr.write(blob.decode(errors="replace"))
                    sys.stderr.write("\n=== END SERVER STDERR ===\n")
        except Exception:
            pass

View File

@ -0,0 +1,397 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
Multi-operation sequences, parallel reads across multiple transfer objects,
cross-backend scenarios, and edge cases.
"""
import json
import logging
import unittest
import urllib.error
from concurrent.futures import ThreadPoolExecutor, as_completed
from .test_base import (
IMAGE_SIZE,
ImageServerTestCase,
http_get,
http_patch,
http_post,
http_put,
make_file_transfer,
make_nbd_transfer,
randbytes,
shutdown_image_server,
test_timeout,
)
log = logging.getLogger(__name__)
# Upper bound for draining every future from as_completed() in one test.
FUTURES_TIMEOUT = 60  # seconds for as_completed to collect all results
def _fetch(url, headers=None):
    """GET *url* and return the body bytes, always closing the response."""
    resp = http_get(url, headers=headers)
    try:
        body = resp.read()
    finally:
        resp.close()
    return body
class TestParallelReadsFileBackend(ImageServerTestCase):
    """Multiple concurrent GET requests to multiple file-backed transfers."""

    @test_timeout(120)
    def test_parallel_reads_single_file_transfer(self):
        # Eight non-overlapping Range reads of one transfer, issued in
        # parallel; each must return the matching slice of the source data.
        data = randbytes(500, IMAGE_SIZE)
        tid, url, path, cleanup = make_file_transfer(data=data)
        try:
            results = {}
            with ThreadPoolExecutor(max_workers=8) as pool:
                futures = {}
                for i in range(8):
                    start = i * (IMAGE_SIZE // 8)
                    end = start + (IMAGE_SIZE // 8) - 1  # Range end is inclusive
                    f = pool.submit(
                        _fetch, url, headers={"Range": f"bytes={start}-{end}"}
                    )
                    futures[f] = (start, end)
                for f in as_completed(futures, timeout=FUTURES_TIMEOUT):
                    start, end = futures[f]
                    results[(start, end)] = f.result()
            for (start, end), chunk in sorted(results.items()):
                self.assertEqual(chunk, data[start:end + 1], f"Mismatch at {start}-{end}")
        finally:
            cleanup()

    @test_timeout(120)
    def test_parallel_reads_multiple_file_transfers(self):
        """Parallel reads across 4 different file-backed transfer objects."""
        transfers = []
        try:
            # Distinct deterministic data per transfer (different seeds).
            for i in range(4):
                data = randbytes(600 + i, IMAGE_SIZE)
                tid, url, path, cleanup = make_file_transfer(data=data)
                transfers.append((tid, url, data, cleanup))
            with ThreadPoolExecutor(max_workers=8) as pool:
                futures = {}
                # Two full-body reads per transfer, all in flight at once.
                for idx, (tid, url, data, _) in enumerate(transfers):
                    for j in range(2):
                        f = pool.submit(_fetch, url)
                        futures[f] = (idx, data)
                for f in as_completed(futures, timeout=FUTURES_TIMEOUT):
                    idx, expected_data = futures[f]
                    got = f.result()
                    self.assertEqual(got, expected_data, f"Transfer {idx} mismatch")
        finally:
            # Clean up whatever was successfully created, even on failure.
            for _, _, _, cleanup in transfers:
                cleanup()
class TestParallelReadsNbdBackend(ImageServerTestCase):
    """Multiple concurrent GET requests to multiple NBD-backed transfers."""

    @test_timeout(120)
    def test_parallel_reads_single_nbd_transfer(self):
        # Populate one NBD transfer via HTTP PUT, then issue 8
        # non-overlapping Range reads in parallel and verify each slice.
        data = randbytes(700, IMAGE_SIZE)
        tid, url, nbd_server, cleanup = make_nbd_transfer()
        try:
            log.info("Writing %d bytes to NBD transfer %s", IMAGE_SIZE, tid)
            http_put(url, data)
            log.info("NBD write done, starting 8 parallel range reads")
            results = {}
            with ThreadPoolExecutor(max_workers=8) as pool:
                futures = {}
                for i in range(8):
                    start = i * (IMAGE_SIZE // 8)
                    end = start + (IMAGE_SIZE // 8) - 1  # inclusive Range end
                    f = pool.submit(
                        _fetch, url, headers={"Range": f"bytes={start}-{end}"}
                    )
                    futures[f] = (start, end)
                completed = 0
                for f in as_completed(futures, timeout=FUTURES_TIMEOUT):
                    start, end = futures[f]
                    results[(start, end)] = f.result()
                    completed += 1
                    log.info("NBD range read %d/8 done: bytes=%d-%d", completed, start, end)
            for (start, end), chunk in sorted(results.items()):
                self.assertEqual(chunk, data[start:end + 1], f"Mismatch at {start}-{end}")
        finally:
            cleanup()

    @test_timeout(120)
    def test_parallel_reads_multiple_nbd_transfers(self):
        """Parallel reads across 4 different NBD-backed transfer objects."""
        transfers = []
        try:
            # Create 4 transfers, each seeded with distinct data.
            for i in range(4):
                data = randbytes(800 + i, IMAGE_SIZE)
                log.info("Setting up NBD transfer %d", i)
                tid, url, nbd_server, cleanup = make_nbd_transfer()
                log.info("Writing data to NBD transfer %d (tid=%s)", i, tid)
                http_put(url, data)
                transfers.append((tid, url, data, cleanup))
                log.info("NBD transfer %d ready", i)
            log.info("Starting parallel reads across %d NBD transfers", len(transfers))
            with ThreadPoolExecutor(max_workers=8) as pool:
                futures = {}
                # Two full reads per transfer, all concurrent.
                for idx, (tid, url, data, _) in enumerate(transfers):
                    for j in range(2):
                        f = pool.submit(_fetch, url)
                        futures[f] = (idx, data)
                completed = 0
                for f in as_completed(futures, timeout=FUTURES_TIMEOUT):
                    idx, expected_data = futures[f]
                    got = f.result()
                    completed += 1
                    log.info("Read %d/%d done: NBD transfer idx=%d, %d bytes",
                             completed, len(futures), idx, len(got))
                    self.assertEqual(got, expected_data, f"NBD transfer {idx} mismatch")
        finally:
            # Tear down whatever was created, even after a mid-loop failure.
            for _, _, _, cleanup in transfers:
                cleanup()
class TestParallelReadsMixedBackends(ImageServerTestCase):
    """Parallel reads across a mix of file and NBD transfers simultaneously."""

    @test_timeout(120)
    def test_parallel_reads_file_and_nbd_mixed(self):
        transfers = []
        try:
            # Two file-backed transfers created with data in place...
            for i in range(2):
                log.info("Setting up file transfer %d", i)
                data = randbytes(900 + i, IMAGE_SIZE)
                tid, url, path, cleanup = make_file_transfer(data=data)
                transfers.append(("file", tid, url, data, cleanup))
                log.info("File transfer %d ready: tid=%s", i, tid)
            # ...plus two NBD-backed transfers populated via HTTP PUT.
            for i in range(2):
                log.info("Setting up NBD transfer %d", i)
                data = randbytes(950 + i, IMAGE_SIZE)
                tid, url, nbd_server, cleanup = make_nbd_transfer()
                log.info("NBD transfer %d registered: tid=%s, writing data...", i, tid)
                http_put(url, data)
                transfers.append(("nbd", tid, url, data, cleanup))
                log.info("NBD transfer %d ready", i)
            log.info("Starting parallel reads across %d transfers (2 file + 2 nbd)",
                     len(transfers))
            with ThreadPoolExecutor(max_workers=8) as pool:
                futures = {}
                # Two full reads per transfer, interleaving both backends.
                for idx, (backend_type, tid, url, data, _) in enumerate(transfers):
                    for j in range(2):
                        f = pool.submit(_fetch, url)
                        futures[f] = (idx, backend_type, data)
                completed = 0
                for f in as_completed(futures, timeout=FUTURES_TIMEOUT):
                    idx, backend_type, expected = futures[f]
                    got = f.result()
                    completed += 1
                    log.info("Read %d/%d done: %s transfer idx=%d, %d bytes",
                             completed, len(futures), backend_type, idx, len(got))
                    self.assertEqual(got, expected, f"{backend_type} transfer {idx} mismatch")
            log.info("All parallel mixed reads completed successfully")
        except TimeoutError:
            # The test_timeout watchdog fired: capture server output for
            # post-mortem analysis before re-raising.
            log.error("TIMEOUT in mixed parallel reads — dumping server logs")
            self.dump_server_logs()
            raise
        finally:
            for _, _, _, _, cleanup in transfers:
                cleanup()
class TestWriteThenReadNbd(ImageServerTestCase):
    """Multi-step write sequences on NBD backend."""

    def setUp(self):
        # Fresh NBD transfer per test; torn down in tearDown().
        self._tid, self._url, self._nbd, self._cleanup = make_nbd_transfer()

    def tearDown(self):
        self._cleanup()

    def test_partial_writes_then_full_read(self):
        # Zero-fill the image, PATCH four distinct quarter-size ranges
        # (each filled with its index byte), then verify via one full GET.
        http_put(self._url, b"\x00" * IMAGE_SIZE)
        chunk_size = IMAGE_SIZE // 4
        for i in range(4):
            offset = i * chunk_size
            end = offset + chunk_size - 1  # Range end is inclusive
            data = bytes([i & 0xFF]) * chunk_size
            http_patch(self._url, data, headers={
                "Range": f"bytes={offset}-{end}",
                "Content-Type": "application/octet-stream",
                "Content-Length": str(chunk_size),
            })
        resp = http_get(self._url)
        full = resp.read()
        for i in range(4):
            offset = i * chunk_size
            self.assertEqual(full[offset:offset + chunk_size], bytes([i & 0xFF]) * chunk_size)

    def test_zero_then_extents(self):
        # Fill with random data, zero the first half via a JSON PATCH op,
        # then check that the extents report still covers the whole image.
        http_put(self._url, randbytes(1000, IMAGE_SIZE))
        payload = json.dumps({"op": "zero", "size": IMAGE_SIZE // 2, "offset": 0}).encode()
        http_patch(self._url, payload, headers={
            "Content-Type": "application/json",
            "Content-Length": str(len(payload)),
        })
        resp = http_get(f"{self._url}/extents")
        extents = json.loads(resp.read())
        total = sum(e["length"] for e in extents)
        self.assertEqual(total, IMAGE_SIZE)

    def test_write_flush_read(self):
        # A PUT with ?flush=y must report flushed=true and read back intact.
        data = randbytes(1001, IMAGE_SIZE)
        resp = http_put(f"{self._url}?flush=y", data)
        body = json.loads(resp.read())
        self.assertTrue(body["flushed"])
        resp2 = http_get(self._url)
        self.assertEqual(resp2.read(), data)
class TestWriteThenReadFile(ImageServerTestCase):
    """PUT/GET round-trip against a file-backend transfer."""

    def setUp(self):
        self._tid, self._url, self._path, self._cleanup = make_file_transfer()

    def tearDown(self):
        self._cleanup()

    def test_put_then_get_roundtrip(self):
        payload = randbytes(1100, IMAGE_SIZE)
        http_put(self._url, payload)
        self.assertEqual(http_get(self._url).read(), payload)
class TestRegisterUseUnregisterUse(ImageServerTestCase):
    """A transfer is reachable while registered and 404s after unregister."""

    def test_unregistered_transfer_returns_404(self):
        data = randbytes(1200, IMAGE_SIZE)
        tid, url, path, cleanup = make_file_transfer(data=data)
        try:
            # While registered, the content is served in full.
            resp = http_get(url)
            self.assertEqual(resp.read(), data)
        finally:
            # Always unregister — even when the read/assert above fails —
            # so a failing test cannot leak the transfer registration.
            cleanup()
        # After cleanup the same URL must answer 404.
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_get(url)
        self.assertEqual(ctx.exception.code, 404)
class TestMultipleTransfersSimultaneous(ImageServerTestCase):
    """File and NBD transfers can be served concurrently."""

    @test_timeout(120)
    def test_operate_on_file_and_nbd_concurrently(self):
        file_data = randbytes(1300, IMAGE_SIZE)
        nbd_data = randbytes(1301, IMAGE_SIZE)
        ftid, furl, fpath, fcleanup = make_file_transfer(data=file_data)
        try:
            # Created inside the outer try so a failure starting qemu-nbd
            # still releases the file transfer registered above.
            ntid, nurl, nbd_server, ncleanup = make_nbd_transfer()
            try:
                log.info("Writing data to NBD transfer %s", ntid)
                http_put(nurl, nbd_data)
                log.info("Starting concurrent file + NBD reads")
                with ThreadPoolExecutor(max_workers=4) as pool:
                    f_file = pool.submit(_fetch, furl)
                    f_nbd = pool.submit(_fetch, nurl)
                    self.assertEqual(f_file.result(timeout=FUTURES_TIMEOUT), file_data)
                    self.assertEqual(f_nbd.result(timeout=FUTURES_TIMEOUT), nbd_data)
                log.info("Concurrent reads completed successfully")
            finally:
                ncleanup()
        finally:
            fcleanup()
class TestLargeChunkedTransfer(ImageServerTestCase):
    """PUT/GET round-trips whose payload spans multiple CHUNK_SIZE boundaries."""

    def _roundtrip(self, make_transfer, seed):
        # Shared body: register, upload, download, compare, always clean up.
        tid, url, backend, cleanup = make_transfer()
        try:
            payload = randbytes(seed, IMAGE_SIZE)
            http_put(url, payload)
            self.assertEqual(http_get(url).read(), payload)
        finally:
            cleanup()

    def test_put_larger_than_chunk_size_file(self):
        """Upload data that spans multiple CHUNK_SIZE boundaries."""
        self._roundtrip(make_file_transfer, 1400)

    def test_nbd_put_larger_than_chunk_size(self):
        self._roundtrip(make_nbd_transfer, 1401)
class TestEdgeCases(ImageServerTestCase):
    """Miscellaneous error paths and less common endpoints."""

    def test_get_not_found_path(self):
        # Anything outside the /images/ namespace must 404.
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_get(f"{self.base_url}/not/images/path")
        self.assertEqual(ctx.exception.code, 404)

    def test_post_unknown_tail(self):
        tid, url, path, cleanup = make_file_transfer()
        try:
            # POST to an unrecognized sub-path of a valid transfer.
            with self.assertRaises(urllib.error.HTTPError) as ctx:
                http_post(f"{url}/unknown")
            self.assertEqual(ctx.exception.code, 404)
        finally:
            cleanup()

    def test_get_extents_then_flush_nbd(self):
        tid, url, nbd_server, cleanup = make_nbd_transfer()
        try:
            http_put(url, randbytes(1500, IMAGE_SIZE))
            extents_resp = http_get(f"{url}/extents")
            self.assertEqual(extents_resp.status, 200)
            extents_resp.read()  # drain the body before the next request
            flush_body = json.loads(http_post(f"{url}/flush").read())
            self.assertTrue(flush_body["ok"])
        finally:
            cleanup()
if __name__ == "__main__":
    # Ensure the singleton server and temp dir are torn down even though
    # unittest.main() normally exits via SystemExit.
    try:
        unittest.main()
    finally:
        shutdown_image_server()

View File

@ -0,0 +1,258 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for the Unix domain control socket protocol (register / unregister / status)."""
import json
import socket
import unittest
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed
from .test_base import ImageServerTestCase, make_tmp_image, shutdown_image_server, test_timeout
class TestStatus(ImageServerTestCase):
    """The `status` action reports health and the active-transfer count."""

    def test_status_returns_ok(self):
        resp = self.ctrl({"action": "status"})
        self.assertEqual(resp["status"], "ok")
        self.assertIn("active_transfers", resp)

    def test_status_count_is_integer(self):
        count = self.ctrl({"action": "status"})["active_transfers"]
        self.assertIsInstance(count, int)
        self.assertGreaterEqual(count, 0)
class TestRegister(ImageServerTestCase):
    """Validation and bookkeeping of the `register` control action."""

    def test_register_file_backend(self):
        # Happy path: a file-backed transfer registers and counts as active.
        img = make_tmp_image()
        tid = f"test-{uuid.uuid4().hex[:8]}"
        try:
            resp = self.ctrl({
                "action": "register",
                "transfer_id": tid,
                "config": {"backend": "file", "file": img},
            })
            self.assertEqual(resp["status"], "ok")
            self.assertGreaterEqual(resp["active_transfers"], 1)
        finally:
            self.ctrl({"action": "unregister", "transfer_id": tid})

    def test_register_nbd_backend(self):
        # The socket path is not probed at registration time, so an
        # NBD config pointing at a nonexistent socket still registers.
        tid = f"test-{uuid.uuid4().hex[:8]}"
        try:
            resp = self.ctrl({
                "action": "register",
                "transfer_id": tid,
                "config": {"backend": "nbd", "socket": "/tmp/fake.sock"},
            })
            self.assertEqual(resp["status"], "ok")
        finally:
            self.ctrl({"action": "unregister", "transfer_id": tid})

    def test_register_increments_active_count(self):
        img = make_tmp_image()
        before = self.ctrl({"action": "status"})["active_transfers"]
        tid = f"test-{uuid.uuid4().hex[:8]}"
        try:
            self.ctrl({
                "action": "register",
                "transfer_id": tid,
                "config": {"backend": "file", "file": img},
            })
            after = self.ctrl({"action": "status"})["active_transfers"]
            self.assertEqual(after, before + 1)
        finally:
            self.ctrl({"action": "unregister", "transfer_id": tid})

    # ── Invalid-request validation ─────────────────────────────────────

    def test_register_missing_transfer_id(self):
        img = make_tmp_image()
        resp = self.ctrl({
            "action": "register",
            "config": {"backend": "file", "file": img},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_empty_transfer_id(self):
        img = make_tmp_image()
        resp = self.ctrl({
            "action": "register",
            "transfer_id": "",
            "config": {"backend": "file", "file": img},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_missing_config(self):
        resp = self.ctrl({
            "action": "register",
            "transfer_id": f"test-{uuid.uuid4().hex[:8]}",
        })
        self.assertEqual(resp["status"], "error")

    def test_register_invalid_backend(self):
        resp = self.ctrl({
            "action": "register",
            "transfer_id": f"test-{uuid.uuid4().hex[:8]}",
            "config": {"backend": "invalid"},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_file_missing_path(self):
        resp = self.ctrl({
            "action": "register",
            "transfer_id": f"test-{uuid.uuid4().hex[:8]}",
            "config": {"backend": "file"},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_nbd_missing_socket(self):
        resp = self.ctrl({
            "action": "register",
            "transfer_id": f"test-{uuid.uuid4().hex[:8]}",
            "config": {"backend": "nbd"},
        })
        self.assertEqual(resp["status"], "error")

    # ── Transfer-id sanitization (ids appear in /images/<id> URLs) ─────

    def test_register_path_traversal_rejected(self):
        img = make_tmp_image()
        resp = self.ctrl({
            "action": "register",
            "transfer_id": "../etc/passwd",
            "config": {"backend": "file", "file": img},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_dot_rejected(self):
        img = make_tmp_image()
        resp = self.ctrl({
            "action": "register",
            "transfer_id": ".",
            "config": {"backend": "file", "file": img},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_slash_rejected(self):
        img = make_tmp_image()
        resp = self.ctrl({
            "action": "register",
            "transfer_id": "a/b",
            "config": {"backend": "file", "file": img},
        })
        self.assertEqual(resp["status"], "error")

    def test_register_duplicate_replaces(self):
        # Re-registering the same id replaces the transfer rather than
        # growing the active count.
        img = make_tmp_image()
        tid = f"test-{uuid.uuid4().hex[:8]}"
        try:
            self.ctrl({
                "action": "register",
                "transfer_id": tid,
                "config": {"backend": "file", "file": img},
            })
            count_before = self.ctrl({"action": "status"})["active_transfers"]
            self.ctrl({
                "action": "register",
                "transfer_id": tid,
                "config": {"backend": "file", "file": img},
            })
            count_after = self.ctrl({"action": "status"})["active_transfers"]
            self.assertEqual(count_after, count_before)
        finally:
            self.ctrl({"action": "unregister", "transfer_id": tid})
class TestUnregister(ImageServerTestCase):
    """`unregister` removes transfers and tolerates unknown ids."""

    def test_unregister_existing(self):
        tid = f"test-{uuid.uuid4().hex[:8]}"
        self.ctrl({
            "action": "register",
            "transfer_id": tid,
            "config": {"backend": "file", "file": make_tmp_image()},
        })
        before = self.ctrl({"action": "status"})["active_transfers"]
        resp = self.ctrl({"action": "unregister", "transfer_id": tid})
        self.assertEqual(resp["status"], "ok")
        self.assertEqual(resp["active_transfers"], before - 1)

    def test_unregister_nonexistent(self):
        # Unregistering an unknown id is treated as idempotent, not an error.
        resp = self.ctrl({"action": "unregister", "transfer_id": "does-not-exist"})
        self.assertEqual(resp["status"], "ok")

    def test_unregister_missing_id(self):
        resp = self.ctrl({"action": "unregister"})
        self.assertEqual(resp["status"], "error")
class TestUnknownAction(ImageServerTestCase):
    """Actions the server does not recognize are rejected with an error."""

    def test_unknown_action(self):
        resp = self.ctrl({"action": "foobar"})
        self.assertEqual(resp["status"], "error")
        message = resp.get("message", "")
        self.assertIn("unknown", message.lower())
class TestMalformed(ImageServerTestCase):
    """Raw non-JSON input on the control socket must yield an error reply."""

    def test_malformed_json(self):
        sock_path = self.server["ctrl_sock"]
        with socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) as conn:
            conn.settimeout(5)
            conn.connect(sock_path)
            conn.sendall(b"not valid json\n")
            # Half-close so the server sees EOF and sends its reply.
            conn.shutdown(socket.SHUT_WR)
            raw = b""
            while True:
                part = conn.recv(4096)
                if not part:
                    break
                raw += part
        resp = json.loads(raw.strip())
        self.assertEqual(resp["status"], "error")
class TestConcurrentRegistrations(ImageServerTestCase):
    """Twenty simultaneous register calls must all succeed."""

    @test_timeout(60)
    def test_concurrent_registers(self):
        img = make_tmp_image()
        tids = [f"conc-{uuid.uuid4().hex[:8]}" for _ in range(20)]

        def register_one(tid):
            return self.ctrl({
                "action": "register",
                "transfer_id": tid,
                "config": {"backend": "file", "file": img},
            })

        try:
            with ThreadPoolExecutor(max_workers=10) as pool:
                futures = [pool.submit(register_one, tid) for tid in tids]
                responses = [f.result() for f in as_completed(futures, timeout=30)]
            self.assertTrue(all(r["status"] == "ok" for r in responses))
        finally:
            # Unregister every id, including any that failed to register.
            for tid in tids:
                self.ctrl({"action": "unregister", "transfer_id": tid})
if __name__ == "__main__":
    # Tear down the singleton server even though unittest.main() normally
    # exits via SystemExit.
    try:
        unittest.main()
    finally:
        shutdown_image_server()

View File

@ -0,0 +1,230 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for HTTP operations against a file-backend transfer."""
import json
import os
import unittest
import urllib.error
from .test_base import (
IMAGE_SIZE,
ImageServerTestCase,
http_get,
http_options,
http_patch,
http_post,
http_put,
make_file_transfer,
randbytes,
shutdown_image_server,
)
class FileBackendTestCase(ImageServerTestCase):
    """Base that creates a file-backend transfer per test.

    setUp sets attributes that subclasses read directly (names are part of
    the interface and must not change):
      - self._tid: transfer id registered with the server
      - self._url: HTTP endpoint URL for the transfer
      - self._path: path to the backing file on disk
      - self._cleanup: callable invoked in tearDown to release the transfer
    """

    def setUp(self):
        # make_file_transfer() registers a fresh file-backed transfer and
        # returns everything needed to talk to it and tear it down.
        self._tid, self._url, self._path, self._cleanup = make_file_transfer()

    def tearDown(self):
        self._cleanup()
class TestOptions(FileBackendTestCase):
    """OPTIONS on a file-backend transfer: features and allowed methods."""

    def test_options_returns_features(self):
        response = http_options(self._url)
        self.assertEqual(response.status, 200)
        payload = json.loads(response.read())
        self.assertIn("flush", payload["features"])
        self.assertGreaterEqual(payload["max_readers"], 1)
        self.assertGreaterEqual(payload["max_writers"], 1)

    def test_options_allowed_methods(self):
        allow = http_options(self._url).getheader("Access-Control-Allow-Methods")
        for verb in ("GET", "PUT", "POST", "OPTIONS"):
            self.assertIn(verb, allow)
class TestGetFull(FileBackendTestCase):
    """Whole-image GET against the file backend."""

    def test_get_full_returns_file_content(self):
        with open(self._path, "rb") as fh:
            on_disk = fh.read()
        response = http_get(self._url)
        self.assertEqual(response.status, 200)
        body = response.read()
        # Compare lengths first for a clearer failure, then the bytes.
        self.assertEqual(len(body), len(on_disk))
        self.assertEqual(body, on_disk)

    def test_get_full_content_type(self):
        response = http_get(self._url)
        response.read()
        content_type = response.getheader("Content-Type")
        self.assertIn("application/octet-stream", content_type)

    def test_get_full_content_length(self):
        response = http_get(self._url)
        response.read()
        reported = int(response.getheader("Content-Length"))
        self.assertEqual(reported, os.path.getsize(self._path))
class TestGetRange(FileBackendTestCase):
    """Ranged GET (byte-range requests) against the file backend."""

    def _slice_of_file(self, start, length=None):
        # Read the authoritative bytes straight from the backing file;
        # negative start gives a suffix slice, length=None reads to EOF.
        with open(self._path, "rb") as fh:
            data = fh.read()
        return data[start:] if length is None else data[start:start + length]

    def test_get_range_partial(self):
        response = http_get(self._url, headers={"Range": "bytes=100-299"})
        self.assertEqual(response.status, 206)
        self.assertEqual(response.read(), self._slice_of_file(100, 200))

    def test_get_range_content_range_header(self):
        total = os.path.getsize(self._path)
        response = http_get(self._url, headers={"Range": "bytes=0-99"})
        self.assertEqual(response.status, 206)
        response.read()
        self.assertEqual(response.getheader("Content-Range"), f"bytes 0-99/{total}")

    def test_get_range_suffix(self):
        response = http_get(self._url, headers={"Range": "bytes=-100"})
        self.assertEqual(response.status, 206)
        self.assertEqual(response.read(), self._slice_of_file(-100))

    def test_get_range_open_ended(self):
        response = http_get(self._url, headers={"Range": f"bytes={IMAGE_SIZE - 50}-"})
        self.assertEqual(response.status, 206)
        self.assertEqual(response.read(), self._slice_of_file(IMAGE_SIZE - 50))

    def test_get_range_unsatisfiable(self):
        out_of_bounds = f"bytes={IMAGE_SIZE + 100}-{IMAGE_SIZE + 200}"
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_get(self._url, headers={"Range": out_of_bounds})
        self.assertEqual(ctx.exception.code, 416)
class TestPut(FileBackendTestCase):
    """Full-image PUT uploads against the file backend."""

    def test_put_full_upload(self):
        upload = randbytes(99, IMAGE_SIZE)
        response = http_put(self._url, upload)
        result = json.loads(response.read())
        self.assertEqual(response.status, 200)
        self.assertTrue(result["ok"])
        self.assertEqual(result["bytes_written"], IMAGE_SIZE)
        # The backing file must now hold exactly the uploaded bytes.
        with open(self._path, "rb") as fh:
            self.assertEqual(fh.read(), upload)

    def test_put_with_flush(self):
        upload = randbytes(100, IMAGE_SIZE)
        result = json.loads(http_put(f"{self._url}?flush=y", upload).read())
        self.assertTrue(result["ok"])
        self.assertTrue(result["flushed"])

    def test_put_verify_by_get(self):
        upload = randbytes(101, IMAGE_SIZE)
        http_put(self._url, upload)
        self.assertEqual(http_get(self._url).read(), upload)

    def test_put_with_content_range_rejected(self):
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_put(self._url, b"x" * 100, headers={"Content-Range": "bytes 0-99/*"})
        self.assertEqual(ctx.exception.code, 400)

    def test_put_with_range_header_rejected(self):
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_put(self._url, b"x" * 100, headers={"Range": "bytes=0-99"})
        self.assertEqual(ctx.exception.code, 400)
class TestFlush(FileBackendTestCase):
    """Explicit flush via POST /flush on the file backend."""

    def test_post_flush(self):
        response = http_post(self._url + "/flush")
        self.assertEqual(response.status, 200)
        self.assertTrue(json.loads(response.read())["ok"])
class TestPatchRejected(FileBackendTestCase):
    """PATCH must be rejected on a file-backend transfer."""

    def test_patch_rejected_for_file(self):
        payload = json.dumps({"op": "zero", "size": 100}).encode()
        headers = {
            "Content-Type": "application/json",
            "Content-Length": str(len(payload)),
        }
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_patch(self._url, payload, headers=headers)
        self.assertEqual(ctx.exception.code, 400)
class TestExtentsRejected(FileBackendTestCase):
    """Extent reporting must be rejected on a file-backend transfer."""

    def test_extents_rejected_for_file(self):
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_get(self._url + "/extents")
        self.assertEqual(ctx.exception.code, 400)
class TestUnknownImage(ImageServerTestCase):
    """Requests against an unregistered transfer id must yield 404."""

    def _expect_404(self, request):
        # Run the given zero-arg request and assert it fails with 404.
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            request()
        self.assertEqual(ctx.exception.code, 404)

    def test_get_unknown_image(self):
        self._expect_404(lambda: http_get(f"{self.base_url}/images/nonexistent-id"))

    def test_put_unknown_image(self):
        self._expect_404(lambda: http_put(f"{self.base_url}/images/nonexistent-id", b"data"))

    def test_options_unknown_image(self):
        self._expect_404(lambda: http_options(f"{self.base_url}/images/nonexistent-id"))
class TestRoundTrip(FileBackendTestCase):
    """Write-then-read consistency through the HTTP API (file backend)."""

    def test_put_then_get_roundtrip(self):
        data = randbytes(200, IMAGE_SIZE)
        http_put(self._url, data)
        self.assertEqual(http_get(self._url).read(), data)

    def test_put_then_ranged_get_roundtrip(self):
        data = randbytes(201, IMAGE_SIZE)
        http_put(self._url, data)
        ranged = http_get(self._url, headers={"Range": "bytes=512-1023"})
        self.assertEqual(ranged.read(), data[512:1024])

    def test_multiple_puts_last_wins(self):
        older = randbytes(300, IMAGE_SIZE)
        newer = randbytes(301, IMAGE_SIZE)
        for payload in (older, newer):
            http_put(self._url, payload)
        # Only the most recent upload should be visible.
        self.assertEqual(http_get(self._url).read(), newer)
if __name__ == "__main__":
    try:
        unittest.main()
    finally:
        # unittest.main() exits via SystemExit; the finally clause still runs,
        # so the shared image-server process is torn down on every exit path.
        shutdown_image_server()

View File

@ -0,0 +1,393 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Tests for HTTP operations against an NBD-backend transfer (real qemu-nbd)."""
import json
import unittest
import urllib.error
import urllib.request
from .test_base import (
IMAGE_SIZE,
ImageServerTestCase,
http_get,
http_options,
http_patch,
http_post,
http_put,
make_nbd_transfer,
randbytes,
shutdown_image_server,
)
class NbdBackendTestCase(ImageServerTestCase):
    """Base that creates an NBD-backend transfer per test.

    setUp sets attributes that subclasses read directly (names are part of
    the interface and must not change):
      - self._tid: transfer id registered with the server
      - self._url: HTTP endpoint URL for the transfer
      - self._nbd: handle for the NBD export (subclasses read .image_path)
      - self._cleanup: callable invoked in tearDown to release resources
    """

    def setUp(self):
        # make_nbd_transfer() registers a fresh NBD-backed transfer and
        # returns everything needed to talk to it and tear it down.
        self._tid, self._url, self._nbd, self._cleanup = make_nbd_transfer()

    def tearDown(self):
        self._cleanup()
class TestOptions(NbdBackendTestCase):
    """OPTIONS on an NBD-backend transfer."""

    def test_options_returns_extents_feature(self):
        response = http_options(self._url)
        self.assertEqual(response.status, 200)
        self.assertIn("extents", json.loads(response.read())["features"])

    def test_options_includes_patch_method(self):
        allow = http_options(self._url).getheader("Access-Control-Allow-Methods")
        self.assertIn("PATCH", allow)

    def test_options_has_capabilities(self):
        capabilities = json.loads(http_options(self._url).read())
        self.assertGreaterEqual(capabilities["max_readers"], 1)
        self.assertGreaterEqual(capabilities["max_writers"], 1)
class TestGetFull(NbdBackendTestCase):
    """Whole-image GET against the NBD backend."""

    def test_get_full_returns_image_data(self):
        with open(self._nbd.image_path, "rb") as fh:
            on_disk = fh.read()
        response = http_get(self._url)
        body = response.read()
        self.assertEqual(response.status, 200)
        # Compare lengths first for a clearer failure, then the bytes.
        self.assertEqual(len(body), len(on_disk))
        self.assertEqual(body, on_disk)

    def test_get_full_content_length(self):
        response = http_get(self._url)
        response.read()
        self.assertEqual(int(response.getheader("Content-Length")), IMAGE_SIZE)
class TestGetRange(NbdBackendTestCase):
    """Ranged GET against the NBD backend (image seeded via PUT first)."""

    def test_get_range_partial(self):
        seeded = randbytes(50, IMAGE_SIZE)
        http_put(self._url, seeded)
        response = http_get(self._url, headers={"Range": "bytes=100-299"})
        self.assertEqual(response.status, 206)
        self.assertEqual(response.read(), seeded[100:300])

    def test_get_range_content_range_header(self):
        response = http_get(self._url, headers={"Range": "bytes=0-99"})
        self.assertEqual(response.status, 206)
        response.read()
        self.assertEqual(response.getheader("Content-Range"), f"bytes 0-99/{IMAGE_SIZE}")

    def test_get_range_suffix(self):
        seeded = randbytes(51, IMAGE_SIZE)
        http_put(self._url, seeded)
        response = http_get(self._url, headers={"Range": "bytes=-100"})
        self.assertEqual(response.status, 206)
        self.assertEqual(response.read(), seeded[-100:])

    def test_get_range_unsatisfiable(self):
        out_of_bounds = f"bytes={IMAGE_SIZE + 100}-{IMAGE_SIZE + 200}"
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_get(self._url, headers={"Range": out_of_bounds})
        self.assertEqual(ctx.exception.code, 416)
class TestPutFull(NbdBackendTestCase):
    """Full-image PUT against the NBD backend."""

    def test_put_full_upload(self):
        upload = randbytes(60, IMAGE_SIZE)
        response = http_put(self._url, upload)
        result = json.loads(response.read())
        self.assertEqual(response.status, 200)
        self.assertTrue(result["ok"])
        self.assertEqual(result["bytes_written"], IMAGE_SIZE)
        # Read the image back through the API to confirm the write.
        self.assertEqual(http_get(self._url).read(), upload)

    def test_put_with_flush(self):
        upload = randbytes(61, IMAGE_SIZE)
        result = json.loads(http_put(f"{self._url}?flush=y", upload).read())
        self.assertTrue(result["ok"])
        self.assertTrue(result["flushed"])
class TestPutRange(NbdBackendTestCase):
    """Sub-range PUT (Content-Range) against the NBD backend.

    The NBD backend accepts ``Content-Range: bytes S-E/*`` on PUT, writing
    only the covered span and leaving the rest of the image untouched.
    """

    def test_put_content_range(self):
        """A ranged PUT updates exactly the covered bytes."""
        base_data = randbytes(70, IMAGE_SIZE)
        http_put(self._url, base_data)
        patch_data = b"\xAB" * 512
        resp = http_put(self._url, patch_data, headers={
            "Content-Range": "bytes 0-511/*",
            "Content-Length": str(len(patch_data)),
        })
        body = json.loads(resp.read())
        self.assertEqual(resp.status, 200)
        self.assertTrue(body["ok"])
        self.assertEqual(body["bytes_written"], 512)
        # Patched span reads back, and the neighbouring span is untouched.
        resp2 = http_get(self._url, headers={"Range": "bytes=0-511"})
        self.assertEqual(resp2.read(), patch_data)
        resp3 = http_get(self._url, headers={"Range": "bytes=512-1023"})
        self.assertEqual(resp3.read(), base_data[512:1024])

    def test_put_content_range_with_flush(self):
        """A ranged PUT with ?flush=y reports flushed and persists the bytes."""
        base_data = b"\x00" * IMAGE_SIZE
        http_put(self._url, base_data)
        patch_data = b"\xFF" * 256
        resp = http_put(f"{self._url}?flush=y", patch_data, headers={
            "Content-Range": "bytes 1024-1279/*",
            "Content-Length": str(len(patch_data)),
        })
        body = json.loads(resp.read())
        self.assertTrue(body["ok"])
        self.assertTrue(body["flushed"])
        # Fix: the original test only checked the ok/flushed flags and never
        # verified the write landed; read the patched span back so a flush
        # path that drops data is actually caught.
        resp2 = http_get(self._url, headers={"Range": "bytes=1024-1279"})
        self.assertEqual(resp2.read(), patch_data)
class TestPatchRange(NbdBackendTestCase):
    """Binary PATCH with a Range header against the NBD backend."""

    def _patch_bytes(self, start, payload):
        # Issue a binary ranged PATCH covering [start, start + len(payload)).
        end = start + len(payload) - 1
        return http_patch(self._url, payload, headers={
            "Range": f"bytes={start}-{end}",
            "Content-Type": "application/octet-stream",
            "Content-Length": str(len(payload)),
        })

    def test_patch_binary_range(self):
        seeded = randbytes(80, IMAGE_SIZE)
        http_put(self._url, seeded)
        payload = b"\xCD" * 1024
        response = self._patch_bytes(2048, payload)
        result = json.loads(response.read())
        self.assertEqual(response.status, 200)
        self.assertTrue(result["ok"])
        self.assertEqual(result["bytes_written"], 1024)
        readback = http_get(self._url, headers={"Range": "bytes=2048-3071"})
        self.assertEqual(readback.read(), payload)

    def test_patch_multiple_ranges_preserves_unwritten(self):
        seeded = randbytes(81, IMAGE_SIZE)
        http_put(self._url, seeded)
        first, second = b"\x11" * 256, b"\x22" * 256
        self._patch_bytes(0, first)
        self._patch_bytes(512, second)
        # The gap between the two patched spans must still hold the seed data.
        window = http_get(self._url, headers={"Range": "bytes=0-767"}).read()
        self.assertEqual(window[:256], first)
        self.assertEqual(window[256:512], seeded[256:512])
        self.assertEqual(window[512:768], second)
class TestPatchZero(NbdBackendTestCase):
    """JSON zero-op PATCH against the NBD backend."""

    def _zero(self, op_body):
        # Send a JSON PATCH request built from the given op dictionary.
        payload = json.dumps(op_body).encode()
        return http_patch(self._url, payload, headers={
            "Content-Type": "application/json",
            "Content-Length": str(len(payload)),
        })

    def test_patch_zero(self):
        http_put(self._url, randbytes(90, IMAGE_SIZE))
        response = self._zero({"op": "zero", "size": 4096, "offset": 0})
        result = json.loads(response.read())
        self.assertEqual(response.status, 200)
        self.assertTrue(result["ok"])
        readback = http_get(self._url, headers={"Range": "bytes=0-4095"})
        self.assertEqual(readback.read(), b"\x00" * 4096)

    def test_patch_zero_with_flush(self):
        http_put(self._url, b"\xFF" * IMAGE_SIZE)
        response = self._zero({"op": "zero", "size": 512, "offset": 1024, "flush": True})
        self.assertTrue(json.loads(response.read())["ok"])
        readback = http_get(self._url, headers={"Range": "bytes=1024-1535"})
        self.assertEqual(readback.read(), b"\x00" * 512)

    def test_patch_zero_preserves_neighbors(self):
        seeded = randbytes(91, IMAGE_SIZE)
        http_put(self._url, seeded)
        self._zero({"op": "zero", "size": 256, "offset": 512})
        # Bytes on either side of the zeroed span must be untouched.
        window = http_get(self._url, headers={"Range": "bytes=0-1023"}).read()
        self.assertEqual(window[:512], seeded[:512])
        self.assertEqual(window[512:768], b"\x00" * 256)
        self.assertEqual(window[768:1024], seeded[768:1024])
class TestPatchFlush(NbdBackendTestCase):
    """JSON flush-op PATCH against the NBD backend."""

    def test_patch_flush_op(self):
        payload = json.dumps({"op": "flush"}).encode()
        headers = {
            "Content-Type": "application/json",
            "Content-Length": str(len(payload)),
        }
        response = http_patch(self._url, payload, headers=headers)
        self.assertEqual(response.status, 200)
        self.assertTrue(json.loads(response.read())["ok"])
class TestPostFlush(NbdBackendTestCase):
    """Explicit flush via POST /flush on the NBD backend."""

    def test_post_flush(self):
        response = http_post(self._url + "/flush")
        self.assertEqual(response.status, 200)
        self.assertTrue(json.loads(response.read())["ok"])
class TestExtents(NbdBackendTestCase):
    """Allocation/dirty extent reporting on the NBD backend."""

    def _fetch_extents(self, query=""):
        # GET the extents sub-resource and decode the JSON list.
        response = http_get(f"{self._url}/extents{query}")
        self.assertEqual(response.status, 200)
        return json.loads(response.read())

    def test_get_allocation_extents(self):
        extents = self._fetch_extents()
        self.assertIsInstance(extents, list)
        self.assertGreaterEqual(len(extents), 1)
        for extent in extents:
            for key in ("start", "length", "zero"):
                self.assertIn(key, extent)

    def test_extents_cover_full_image(self):
        covered = sum(extent["length"] for extent in self._fetch_extents())
        self.assertEqual(covered, IMAGE_SIZE)

    def test_extents_dirty_context_without_bitmap(self):
        extents = self._fetch_extents("?context=dirty")
        self.assertIsInstance(extents, list)
        self.assertGreaterEqual(len(extents), 1)
        for extent in extents:
            self.assertIn("dirty", extent)
            self.assertTrue(extent["dirty"])

    def test_extents_after_write_and_zero(self):
        http_put(self._url, randbytes(95, IMAGE_SIZE))
        zero_req = json.dumps({"op": "zero", "size": 4096, "offset": 0}).encode()
        http_patch(self._url, zero_req, headers={
            "Content-Type": "application/json",
            "Content-Length": str(len(zero_req)),
        })
        # Extents must still tile the whole image after mixed writes/zeroes.
        extents = self._fetch_extents()
        self.assertGreaterEqual(len(extents), 1)
        self.assertEqual(sum(e["length"] for e in extents), IMAGE_SIZE)
class TestErrorCases(NbdBackendTestCase):
    """Malformed requests must be rejected with 400."""

    def _expect_400_patch(self, op_body):
        # Send a JSON PATCH with the given op body and assert a 400 rejection.
        payload = json.dumps(op_body).encode()
        with self.assertRaises(urllib.error.HTTPError) as ctx:
            http_patch(self._url, payload, headers={
                "Content-Type": "application/json",
                "Content-Length": str(len(payload)),
            })
        self.assertEqual(ctx.exception.code, 400)

    def test_patch_unsupported_op(self):
        self._expect_400_patch({"op": "invalid"})

    def test_patch_zero_missing_size(self):
        self._expect_400_patch({"op": "zero", "offset": 0})

    def test_put_missing_content_length(self):
        # Raw http.client request so Content-Length can be omitted entirely;
        # urllib would add the header for us.
        import http.client
        from urllib.parse import urlparse
        target = urlparse(self._url)
        conn = http.client.HTTPConnection(target.hostname, target.port, timeout=30)
        try:
            conn.putrequest("PUT", target.path)
            conn.endheaders()
            self.assertEqual(conn.getresponse().status, 400)
        finally:
            conn.close()
class TestRoundTrip(NbdBackendTestCase):
    """End-to-end write/read consistency on the NBD backend."""

    def test_write_read_full_roundtrip(self):
        data = randbytes(110, IMAGE_SIZE)
        http_put(self._url, data)
        self.assertEqual(http_get(self._url).read(), data)

    def test_write_read_range_roundtrip(self):
        data = randbytes(111, IMAGE_SIZE)
        http_put(self._url, data)
        spans = [(0, 255), (1024, 2047), (IMAGE_SIZE - 512, IMAGE_SIZE - 1)]
        for start, end in spans:
            response = http_get(self._url, headers={"Range": f"bytes={start}-{end}"})
            self.assertEqual(response.read(), data[start:end + 1])

    def test_range_write_read_roundtrip(self):
        http_put(self._url, b"\x00" * IMAGE_SIZE)
        chunk = randbytes(112, 4096)
        http_put(self._url, chunk, headers={
            "Content-Range": "bytes 8192-12287/*",
            "Content-Length": str(len(chunk)),
        })
        written = http_get(self._url, headers={"Range": "bytes=8192-12287"})
        self.assertEqual(written.read(), chunk)
        # A span the ranged PUT did not cover must remain zeroed.
        untouched = http_get(self._url, headers={"Range": "bytes=0-4095"})
        self.assertEqual(untouched.read(), b"\x00" * 4096)
if __name__ == "__main__":
    try:
        unittest.main()
    finally:
        # unittest.main() exits via SystemExit; the finally clause still runs,
        # so the shared image-server process is torn down on every exit path.
        shutdown_image_server()