usb: change api so that every packet sent is crc32c, update python usb api, add automated tests for usb.

This commit is contained in:
ITotalJustice
2025-08-31 06:12:02 +01:00
parent b6b1af5959
commit 22e965521a
13 changed files with 469 additions and 71 deletions

114
tools/test_usb_export.py Normal file
View File

@@ -0,0 +1,114 @@
import unittest
import crc32c
import os
from usb_common import CMD_EXPORT, CMD_QUIT, RESULT_OK, RESULT_ERROR
class FakeUsb:
def __init__(self, files=None):
# files: list of tuples (filename: str, data: bytes)
self.files = files or [("testfile.bin", b"testdata")]
self._cmd_index = 0
self._file_index = 0
self._data_index = 0
self.results = []
self._reading_filename = True
self._reading_data = False
self._current_data = b""
self._current_data_offset = 0
self._current_data_sent = 0
self._current_file = None
self._send_data_header_calls = 0
def wait_for_connect(self):
pass
def get_send_header(self):
# Simulate command sequence: export for each file, then quit
if self._cmd_index < len(self.files):
filename, data = self.files[self._cmd_index]
self._current_file = (filename, data)
self._cmd_index += 1
self._reading_filename = True
self._reading_data = False
self._current_data = data
self._current_data_offset = 0
self._current_data_sent = 0
self._send_data_header_calls = 0
return [CMD_EXPORT, len(filename.encode("utf-8")), 0]
else:
return [CMD_QUIT, 0, 0]
def read(self, size):
# Simulate reading file name or data
if self._reading_filename:
filename = self._current_file[0].encode("utf-8")
self._reading_filename = False
self._reading_data = True
return filename[:size]
elif self._reading_data:
# Return file data for export
data = self._current_data[self._current_data_sent:self._current_data_sent+size]
self._current_data_sent += len(data)
return data
else:
return b""
def get_send_data_header(self):
# Simulate sending data in one chunk, then finish
if self._send_data_header_calls == 0:
self._send_data_header_calls += 1
data = self._current_data
crc = crc32c.crc32c(data)
return [0, len(data), crc]
else:
return [0, 0, 0] # End of transfer
def send_result(self, result):
self.results.append(result)
# test case for usb_export.py
class TestUsbExport(unittest.TestCase):
def setUp(self):
self.root = "test_output"
os.makedirs(self.root, exist_ok=True)
# 100 files named test1.bin, test2.bin, ..., test10.bin, each with different sizes
self.files = [
(f"test{i+1}.bin", bytes([65 + i]) * (i * 100 + 1)) for i in range(100)
]
self.fake_usb = FakeUsb(files=self.files)
def tearDown(self):
# Clean up created files/folders
for f in os.listdir(self.root):
os.remove(os.path.join(self.root, f))
os.rmdir(self.root)
def test_export_multiple_files(self):
from usb_export import get_file_name, create_file_folder, wait_for_input
# Simulate the main loop for all files
for filename, data in self.files:
cmd, name_len, _ = self.fake_usb.get_send_header()
self.assertEqual(cmd, CMD_EXPORT)
file_name = get_file_name(self.fake_usb, name_len)
self.assertEqual(file_name, filename)
full_path = create_file_folder(self.root, file_name)
self.fake_usb.send_result(RESULT_OK)
wait_for_input(self.fake_usb, full_path)
# Check file was created and contents match
with open(full_path, "rb") as f:
filedata = f.read()
self.assertEqual(filedata, data)
# After all files, should get CMD_QUIT
cmd, _, _ = self.fake_usb.get_send_header()
self.assertEqual(cmd, CMD_QUIT)
self.assertIn(RESULT_OK, self.fake_usb.results)
if __name__ == "__main__":
unittest.main()

103
tools/test_usb_install.py Normal file
View File

@@ -0,0 +1,103 @@
import unittest
import tempfile
import os
import shutil
import crc32c
from usb_common import RESULT_OK
# Simpler FakeUsb for file transfer
class FakeUsb:
def __init__(self, filedata):
self.filedata = filedata # bytes
self.results = []
self.writes = []
self._send_data_header_calls = 0
self._current_data = filedata
def wait_for_connect(self):
pass
def get_send_data_header(self):
# Simulate sending the file in one chunk, then finish
if self._send_data_header_calls == 0:
self._send_data_header_calls += 1
data = self._current_data
crc = crc32c.crc32c(data)
return [0, len(data), crc]
else:
return [0, 0, 0]
def send_result(self, result, arg2=0, arg3=0):
self.results.append((result, arg2, arg3))
def write(self, data):
self.writes.append(data)
class TestUsbInstall(unittest.TestCase):
def setUp(self):
import random
self.tempdir = tempfile.mkdtemp()
# 100 files named test1.nsp, test2.xci, test3.nsz, test4.xcz, ..., cycling extensions, each with random sizes (0-2048 bytes)
extensions = ["nsp", "xci", "nsz", "xcz"]
self.files = [
(f"test{i+1}.{extensions[i % 4]}", os.urandom(random.randint(0, 2048))) for i in range(100)
]
self.filepaths = []
for fname, data in self.files:
fpath = os.path.join(self.tempdir, fname)
with open(fpath, "wb") as f:
f.write(data)
self.filepaths.append(fpath)
def tearDown(self):
shutil.rmtree(self.tempdir)
def test_multiple_file_install(self):
from usb_install import add_file_to_install_list, paths, wait_for_input
paths.clear()
for fpath in self.filepaths:
add_file_to_install_list(fpath)
for idx, (fname, data) in enumerate(self.files):
fake_usb = FakeUsb(data)
wait_for_input(fake_usb, idx)
# Check that the file on disk matches expected data
with open(self.filepaths[idx], "rb") as f:
filedata = f.read()
self.assertEqual(filedata, data)
found = False
for result, arg2, arg3 in fake_usb.results:
if result == RESULT_OK and arg2 == len(data):
found = True
self.assertTrue(found)
def test_directory_install(self):
from usb_install import add_file_to_install_list, paths, wait_for_input
paths.clear()
for fpath in self.filepaths:
add_file_to_install_list(fpath)
for idx, (fname, data) in enumerate(self.files):
fake_usb = FakeUsb(data)
wait_for_input(fake_usb, idx)
# Check that the file on disk matches expected data
with open(self.filepaths[idx], "rb") as f:
filedata = f.read()
self.assertEqual(filedata, data)
found = False
for result, arg2, arg3 in fake_usb.results:
if result == RESULT_OK and arg2 == len(data):
found = True
self.assertTrue(found)
if __name__ == "__main__":
unittest.main()

View File

@@ -2,9 +2,11 @@ import struct
import usb.core
import usb.util
import time
import crc32c
# magic number (SPH0) for the script and switch.
MAGIC = 0x53504830
PACKET_SIZE = 24
# commands
CMD_QUIT = 0
@@ -19,14 +21,85 @@ RESULT_ERROR = 1
FLAG_NONE = 0
FLAG_STREAM = 1 << 0
# disabled, see usbds.cpp usbDsEndpoint_SetZlt
ENABLE_ZLT = 0
class UsbPacket:
STRUCT_FORMAT = "<6I" # 6 unsigned 32-bit ints, little-endian
def __init__(self, magic=MAGIC, arg2=0, arg3=0, arg4=0, arg5=0, crc32c_val=0):
self.magic = magic
self.arg2 = arg2
self.arg3 = arg3
self.arg4 = arg4
self.arg5 = arg5
self.crc32c = crc32c_val
def pack(self):
self.generate_crc32c()
return struct.pack(self.STRUCT_FORMAT, self.magic, self.arg2, self.arg3, self.arg4, self.arg5, self.crc32c)
@classmethod
def unpack(cls, data):
fields = struct.unpack(cls.STRUCT_FORMAT, data)
return cls(*fields)
def calculate_crc32c(self):
data = struct.pack("<5I", self.magic, self.arg2, self.arg3, self.arg4, self.arg5)
return crc32c.crc32c(data)
def generate_crc32c(self):
self.crc32c = self.calculate_crc32c()
def verify(self):
if self.crc32c != self.calculate_crc32c():
raise ValueError("CRC32C mismatch")
if self.magic != MAGIC:
raise ValueError("Bad magic")
return True
class SendPacket(UsbPacket):
@classmethod
def build(cls, cmd, arg3=0, arg4=0):
packet = cls(MAGIC, cmd, arg3, arg4)
packet.generate_crc32c()
return packet
def get_cmd(self):
return self.arg2
class ResultPacket(UsbPacket):
@classmethod
def build(cls, result, arg3=0, arg4=0):
packet = cls(MAGIC, result, arg3, arg4)
packet.generate_crc32c()
return packet
def verify(self):
super().verify()
if self.arg2 != RESULT_OK:
raise ValueError("Result not OK")
return True
class SendDataPacket(UsbPacket):
@classmethod
def build(cls, offset, size, crc32c_val):
arg2 = (offset >> 32) & 0xFFFFFFFF
arg3 = offset & 0xFFFFFFFF
packet = cls(MAGIC, arg2, arg3, size, crc32c_val)
packet.generate_crc32c()
return packet
def get_offset(self):
return (self.arg2 << 32) | self.arg3
def get_size(self):
return self.arg4
def get_crc32c(self):
return self.arg5
class Usb:
def __init__(self):
self.__out_ep = None
self.__in_ep = None
self.__packet_size = 0
def wait_for_connect(self) -> None:
print("waiting for switch")
@@ -61,29 +134,23 @@ class Usb:
print("iManufacturer: {} iProduct: {} iSerialNumber: {}".format(dev.manufacturer, dev.product, dev.serial_number))
print("bcdUSB: {} bMaxPacketSize0: {}".format(hex(dev.bcdUSB), dev.bMaxPacketSize0))
self.__packet_size = 1 << dev.bMaxPacketSize0
def read(self, size: int, timeout: int = 0) -> bytes:
if (ENABLE_ZLT and size and (size % self.__packet_size) == 0):
size += 1
return self.__in_ep.read(size, timeout)
def write(self, buf: bytes, timeout: int = 0) -> int:
return self.__out_ep.write(data=buf, timeout=timeout)
def get_send_header(self) -> tuple[int, int, int]:
header = self.read(16)
[magic, arg2, arg3, arg4] = struct.unpack('<IIII', header)
if magic != MAGIC:
raise Exception("Unexpected magic {}".format(magic))
return arg2, arg3, arg4
packet = SendPacket.unpack(self.read(PACKET_SIZE))
packet.verify()
return packet.get_cmd(), packet.arg3, packet.arg4
def get_send_data_header(self) -> tuple[int, int, int]:
header = self.read(16)
return struct.unpack('<QII', header)
packet = SendDataPacket.unpack(self.read(PACKET_SIZE))
packet.verify()
return packet.get_offset(), packet.get_size(), packet.get_crc32c()
def send_result(self, result: int, arg3: int = 0, arg4: int = 0) -> None:
send_data = struct.pack('<IIII', MAGIC, result, arg3, arg4)
send_data = ResultPacket.build(result, arg3, arg4).pack()
self.write(send_data)

View File

@@ -62,7 +62,7 @@ if __name__ == '__main__':
if (not os.path.isdir(root_path)):
raise ValueError('must be a dir!')
usb: Usb = Usb()
usb = Usb()
try:
# get usb endpoints.

View File

@@ -127,14 +127,14 @@ if __name__ == '__main__':
else:
raise ValueError('must be a file!')
usb: Usb = Usb()
usb = Usb()
try:
# get usb endpoints.
usb.wait_for_connect()
# build string table.
string_table: bytes
string_table = bytes()
for [_, path] in paths:
string_table += bytes(Path(path).name.__str__(), 'utf8') + b'\n'