usb: change api so that every packet sent is crc32c, update python usb api, add automated tests for usb.
This commit is contained in:
33
.github/workflows/python-usb-export.yml
vendored
Normal file
33
.github/workflows/python-usb-export.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: USB Export Python Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
paths: &python_usb_export_paths
|
||||
- 'tools/test_usb_export.py'
|
||||
- 'tools/usb_export.py'
|
||||
- 'tools/usb_common.py'
|
||||
- 'tools/requirements.txt'
|
||||
- '.github/workflows/python-usb-export.yml'
|
||||
pull_request:
|
||||
paths: *python_usb_export_paths
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r tools/requirements.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
python3 tools/test_usb_export.py
|
||||
33
.github/workflows/python-usb-install.yml
vendored
Normal file
33
.github/workflows/python-usb-install.yml
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
name: USB Install Python Tests
|
||||
|
||||
on:
|
||||
push:
|
||||
paths: &python_usb_install_paths
|
||||
- 'tools/test_usb_install.py'
|
||||
- 'tools/usb_install.py'
|
||||
- 'tools/usb_common.py'
|
||||
- 'tools/requirements.txt'
|
||||
- '.github/workflows/python-usb-install.yml'
|
||||
pull_request:
|
||||
paths: *python_usb_install_paths
|
||||
|
||||
jobs:
|
||||
test:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Set up Python 3.11
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.11'
|
||||
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -r tools/requirements.txt
|
||||
|
||||
- name: Run tests
|
||||
run: |
|
||||
python3 tools/test_usb_install.py
|
||||
@@ -7,7 +7,7 @@ namespace sphaira::usb::api {
|
||||
|
||||
enum : u32 {
|
||||
MAGIC = 0x53504830,
|
||||
PACKET_SIZE = 16,
|
||||
PACKET_SIZE = 24,
|
||||
};
|
||||
|
||||
enum : u32 {
|
||||
@@ -26,39 +26,87 @@ enum : u32 {
|
||||
FLAG_STREAM = 1 << 0,
|
||||
};
|
||||
|
||||
struct SendHeader {
|
||||
u32 magic;
|
||||
u32 cmd;
|
||||
u32 arg3;
|
||||
u32 arg4;
|
||||
struct UsbPacket {
|
||||
u32 magic{};
|
||||
u32 arg2{};
|
||||
u32 arg3{};
|
||||
u32 arg4{};
|
||||
u32 arg5{};
|
||||
u32 crc32c{}; // crc32 over the above 16 bytes.
|
||||
|
||||
protected:
|
||||
u32 CalculateCrc32c() const {
|
||||
return crc32cCalculate(this, 20);
|
||||
}
|
||||
|
||||
void GenerateCrc32c() {
|
||||
crc32c = CalculateCrc32c();
|
||||
}
|
||||
|
||||
Result Verify() const {
|
||||
R_UNLESS(crc32c == CalculateCrc32c(), 1); // todo: add error code.
|
||||
R_UNLESS(magic == MAGIC, Result_UsbBadMagic);
|
||||
R_SUCCEED();
|
||||
}
|
||||
};
|
||||
|
||||
struct ResultHeader {
|
||||
u32 magic;
|
||||
u32 result;
|
||||
u32 arg3;
|
||||
u32 arg4;
|
||||
struct SendPacket : UsbPacket {
|
||||
static SendPacket Build(u32 cmd, u32 arg3 = 0, u32 arg4 = 0) {
|
||||
SendPacket packet{MAGIC, cmd, arg3, arg4};
|
||||
packet.GenerateCrc32c();
|
||||
return packet;
|
||||
}
|
||||
|
||||
Result Verify() const {
|
||||
R_UNLESS(magic == MAGIC, Result_UsbBadMagic);
|
||||
R_UNLESS(result == RESULT_OK, 1); // todo: create error code.
|
||||
return UsbPacket::Verify();
|
||||
}
|
||||
|
||||
u32 GetCmd() const {
|
||||
return arg2;
|
||||
}
|
||||
};
|
||||
|
||||
struct ResultPacket : UsbPacket {
|
||||
static ResultPacket Build(u32 result, u32 arg3 = 0, u32 arg4 = 0) {
|
||||
ResultPacket packet{MAGIC, result, arg3, arg4};
|
||||
packet.GenerateCrc32c();
|
||||
return packet;
|
||||
}
|
||||
|
||||
Result Verify() const {
|
||||
R_TRY(UsbPacket::Verify());
|
||||
R_UNLESS(arg2 == RESULT_OK, 1); // todo: create error code.
|
||||
R_SUCCEED();
|
||||
}
|
||||
};
|
||||
|
||||
struct SendDataHeader {
|
||||
u64 offset;
|
||||
u32 size;
|
||||
u32 crc32c;
|
||||
struct SendDataPacket : UsbPacket {
|
||||
static SendDataPacket Build(u64 off, u32 size, u32 crc32c) {
|
||||
SendDataPacket packet{MAGIC, u32(off >> 32), u32(off), size, crc32c};
|
||||
packet.GenerateCrc32c();
|
||||
return packet;
|
||||
}
|
||||
|
||||
Result Verify() const {
|
||||
return UsbPacket::Verify();
|
||||
}
|
||||
|
||||
u64 GetOffset() const {
|
||||
return (u64(arg2) << 32) | arg3;
|
||||
}
|
||||
|
||||
u32 GetSize() const {
|
||||
return arg4;
|
||||
}
|
||||
|
||||
u32 GetCrc32c() const {
|
||||
return arg5;
|
||||
}
|
||||
};
|
||||
|
||||
static_assert(sizeof(SendHeader) == PACKET_SIZE);
|
||||
static_assert(sizeof(ResultHeader) == PACKET_SIZE);
|
||||
static_assert(sizeof(SendDataHeader) == PACKET_SIZE);
|
||||
static_assert(sizeof(UsbPacket) == PACKET_SIZE);
|
||||
static_assert(sizeof(SendPacket) == PACKET_SIZE);
|
||||
static_assert(sizeof(ResultPacket) == PACKET_SIZE);
|
||||
static_assert(sizeof(SendDataPacket) == PACKET_SIZE);
|
||||
|
||||
} // namespace sphaira::usb::api
|
||||
|
||||
@@ -25,8 +25,8 @@ struct Usb {
|
||||
Result CloseFile();
|
||||
|
||||
private:
|
||||
Result SendAndVerify(const void* data, u32 size, u64 timeout, api::ResultHeader* out = nullptr);
|
||||
Result SendAndVerify(const void* data, u32 size, api::ResultHeader* out = nullptr);
|
||||
Result SendAndVerify(const void* data, u32 size, u64 timeout, api::ResultPacket* out = nullptr);
|
||||
Result SendAndVerify(const void* data, u32 size, api::ResultPacket* out = nullptr);
|
||||
|
||||
private:
|
||||
std::unique_ptr<usb::UsbDs> m_usb{};
|
||||
|
||||
@@ -26,8 +26,8 @@ struct Usb {
|
||||
Result CloseFile();
|
||||
|
||||
private:
|
||||
Result SendAndVerify(const void* data, u32 size, u64 timeout, api::ResultHeader* out = nullptr);
|
||||
Result SendAndVerify(const void* data, u32 size, api::ResultHeader* out = nullptr);
|
||||
Result SendAndVerify(const void* data, u32 size, u64 timeout, api::ResultPacket* out = nullptr);
|
||||
Result SendAndVerify(const void* data, u32 size, api::ResultPacket* out = nullptr);
|
||||
|
||||
private:
|
||||
std::unique_ptr<usb::UsbDs> m_usb{};
|
||||
|
||||
@@ -19,7 +19,7 @@ Usb::Usb(u64 transfer_timeout) {
|
||||
|
||||
Usb::~Usb() {
|
||||
if (m_was_connected && R_SUCCEEDED(m_usb->IsUsbConnected(0))) {
|
||||
SendHeader send_header{MAGIC, CMD_QUIT};
|
||||
const auto send_header = SendPacket::Build(CMD_QUIT);
|
||||
SendAndVerify(&send_header, sizeof(send_header));
|
||||
}
|
||||
}
|
||||
@@ -35,7 +35,7 @@ Result Usb::WaitForConnection(std::string_view path, u64 timeout) {
|
||||
R_TRY(m_open_result);
|
||||
R_TRY(m_usb->IsUsbConnected(timeout));
|
||||
|
||||
SendHeader send_header{MAGIC, CMD_EXPORT, (u32)path.length()};
|
||||
const auto send_header = SendPacket::Build(CMD_EXPORT, path.length());
|
||||
R_TRY(SendAndVerify(&send_header, sizeof(send_header), timeout));
|
||||
R_TRY(SendAndVerify(path.data(), path.length(), timeout));
|
||||
|
||||
@@ -44,7 +44,7 @@ Result Usb::WaitForConnection(std::string_view path, u64 timeout) {
|
||||
}
|
||||
|
||||
Result Usb::CloseFile() {
|
||||
SendDataHeader send_header{0, 0};
|
||||
const auto send_header = SendDataPacket::Build(0, 0, 0);
|
||||
|
||||
return SendAndVerify(&send_header, sizeof(send_header));
|
||||
}
|
||||
@@ -54,17 +54,18 @@ void Usb::SignalCancel() {
|
||||
}
|
||||
|
||||
Result Usb::Write(const void* buf, u64 off, u32 size) {
|
||||
SendDataHeader send_header{off, size, crc32cCalculate(buf, size)};
|
||||
const auto send_header = SendDataPacket::Build(off, size, crc32cCalculate(buf, size));
|
||||
|
||||
R_TRY(SendAndVerify(&send_header, sizeof(send_header)));
|
||||
return SendAndVerify(buf, size);
|
||||
}
|
||||
|
||||
// casts away const, but it does not modify the buffer!
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, u64 timeout, ResultHeader* out) {
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, u64 timeout, ResultPacket* out) {
|
||||
R_TRY(m_usb->TransferAll(false, const_cast<void*>(data), size, timeout));
|
||||
|
||||
ResultHeader recv_header;
|
||||
|
||||
ResultPacket recv_header;
|
||||
R_TRY(m_usb->TransferAll(true, &recv_header, sizeof(recv_header), timeout));
|
||||
R_TRY(recv_header.Verify());
|
||||
|
||||
@@ -72,7 +73,7 @@ Result Usb::SendAndVerify(const void* data, u32 size, u64 timeout, ResultHeader*
|
||||
R_SUCCEED();
|
||||
}
|
||||
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, ResultHeader* out) {
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, ResultPacket* out) {
|
||||
return SendAndVerify(data, size, m_usb->GetTransferTimeout(), out);
|
||||
}
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ Usb::Usb(u64 transfer_timeout) {
|
||||
|
||||
Usb::~Usb() {
|
||||
if (m_was_connected && R_SUCCEEDED(m_usb->IsUsbConnected(0))) {
|
||||
SendHeader send_header{MAGIC, CMD_QUIT};
|
||||
const auto send_header = SendPacket::Build(CMD_QUIT);
|
||||
SendAndVerify(&send_header, sizeof(send_header));
|
||||
}
|
||||
}
|
||||
@@ -35,8 +35,8 @@ Result Usb::WaitForConnection(u64 timeout, std::vector<std::string>& out_names)
|
||||
R_TRY(m_open_result);
|
||||
R_TRY(m_usb->IsUsbConnected(timeout));
|
||||
|
||||
SendHeader send_header{MAGIC, RESULT_OK};
|
||||
ResultHeader recv_header;
|
||||
const auto send_header = SendPacket::Build(RESULT_OK);
|
||||
ResultPacket recv_header;
|
||||
R_TRY(SendAndVerify(&send_header, sizeof(send_header), timeout, &recv_header))
|
||||
|
||||
std::vector<char> names(recv_header.arg3);
|
||||
@@ -50,15 +50,14 @@ Result Usb::WaitForConnection(u64 timeout, std::vector<std::string>& out_names)
|
||||
}
|
||||
}
|
||||
|
||||
m_flags = recv_header.arg4;
|
||||
m_was_connected = true;
|
||||
R_SUCCEED();
|
||||
}
|
||||
|
||||
Result Usb::OpenFile(u32 index, s64& file_size) {
|
||||
log_write("doing open file\n");
|
||||
SendHeader send_header{MAGIC, CMD_OPEN, index};
|
||||
ResultHeader recv_header;
|
||||
const auto send_header = SendPacket::Build(CMD_OPEN, index);
|
||||
ResultPacket recv_header;
|
||||
R_TRY(SendAndVerify(&send_header, sizeof(send_header), &recv_header))
|
||||
log_write("did open file\n");
|
||||
|
||||
@@ -72,7 +71,7 @@ Result Usb::OpenFile(u32 index, s64& file_size) {
|
||||
}
|
||||
|
||||
Result Usb::CloseFile() {
|
||||
SendDataHeader send_header{0, 0};
|
||||
const auto send_header = SendDataPacket::Build(0, 0, 0);
|
||||
|
||||
return SendAndVerify(&send_header, sizeof(send_header));
|
||||
}
|
||||
@@ -86,8 +85,8 @@ u32 Usb::GetFlags() const {
|
||||
}
|
||||
|
||||
Result Usb::Read(void* buf, u64 off, u32 size, u64* bytes_read) {
|
||||
SendDataHeader send_header{off, size};
|
||||
ResultHeader recv_header;
|
||||
const auto send_header = SendDataPacket::Build(off, size, 0);
|
||||
ResultPacket recv_header;
|
||||
R_TRY(SendAndVerify(&send_header, sizeof(send_header), &recv_header))
|
||||
|
||||
// adjust the size and read the data.
|
||||
@@ -102,10 +101,10 @@ Result Usb::Read(void* buf, u64 off, u32 size, u64* bytes_read) {
|
||||
}
|
||||
|
||||
// casts away const, but it does not modify the buffer!
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, u64 timeout, ResultHeader* out) {
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, u64 timeout, ResultPacket* out) {
|
||||
R_TRY(m_usb->TransferAll(false, const_cast<void*>(data), size, timeout));
|
||||
|
||||
ResultHeader recv_header;
|
||||
ResultPacket recv_header;
|
||||
R_TRY(m_usb->TransferAll(true, &recv_header, sizeof(recv_header), timeout));
|
||||
R_TRY(recv_header.Verify());
|
||||
|
||||
@@ -113,7 +112,7 @@ Result Usb::SendAndVerify(const void* data, u32 size, u64 timeout, ResultHeader*
|
||||
R_SUCCEED();
|
||||
}
|
||||
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, ResultHeader* out) {
|
||||
Result Usb::SendAndVerify(const void* data, u32 size, ResultPacket* out) {
|
||||
return SendAndVerify(data, size, m_usb->GetTransferTimeout(), out);
|
||||
}
|
||||
|
||||
|
||||
@@ -53,7 +53,7 @@ Result Usb::WaitForConnection(u64 timeout, std::span<const std::string> names) {
|
||||
}
|
||||
|
||||
// send.
|
||||
SendHeader send_header;
|
||||
SendPacket send_header;
|
||||
R_TRY(m_usb->TransferAll(true, &send_header, sizeof(send_header), timeout));
|
||||
R_TRY(send_header.Verify());
|
||||
|
||||
@@ -67,14 +67,14 @@ Result Usb::WaitForConnection(u64 timeout, std::span<const std::string> names) {
|
||||
}
|
||||
|
||||
Result Usb::PollCommands() {
|
||||
SendHeader send_header;
|
||||
SendPacket send_header;
|
||||
R_TRY(m_usb->TransferAll(true, &send_header, sizeof(send_header)));
|
||||
R_TRY(send_header.Verify());
|
||||
|
||||
if (send_header.cmd == CMD_QUIT) {
|
||||
if (send_header.GetCmd() == CMD_QUIT) {
|
||||
R_TRY(SendResult(RESULT_OK));
|
||||
R_THROW(Result_UsbUploadExit);
|
||||
} else if (send_header.cmd == CMD_OPEN) {
|
||||
} else if (send_header.GetCmd() == CMD_OPEN) {
|
||||
s64 file_size;
|
||||
u16 flags;
|
||||
R_TRY(Open(send_header.arg3, file_size, flags));
|
||||
@@ -92,11 +92,11 @@ Result Usb::file_transfer_loop() {
|
||||
log_write("doing file transfer\n");
|
||||
|
||||
// get offset + size.
|
||||
SendDataHeader send_header;
|
||||
SendDataPacket send_header;
|
||||
R_TRY(m_usb->TransferAll(true, &send_header, sizeof(send_header)));
|
||||
|
||||
// check if we should finish now.
|
||||
if (send_header.offset == 0 && send_header.size == 0) {
|
||||
if (send_header.GetOffset() == 0 && send_header.GetSize() == 0) {
|
||||
log_write("finished\n");
|
||||
R_TRY(SendResult(RESULT_OK));
|
||||
return Result_UsbUploadExit;
|
||||
@@ -104,10 +104,10 @@ Result Usb::file_transfer_loop() {
|
||||
|
||||
// read file and calculate the hash.
|
||||
u64 bytes_read;
|
||||
m_buf.resize(send_header.size);
|
||||
m_buf.resize(send_header.GetSize());
|
||||
log_write("reading buffer: %zu\n", m_buf.size());
|
||||
|
||||
R_TRY(Read(m_buf.data(), send_header.offset, m_buf.size(), &bytes_read));
|
||||
R_TRY(Read(m_buf.data(), send_header.GetOffset(), m_buf.size(), &bytes_read));
|
||||
const auto crc32 = crc32Calculate(m_buf.data(), m_buf.size());
|
||||
|
||||
log_write("read the buffer: %zu\n", bytes_read);
|
||||
@@ -125,7 +125,7 @@ Result Usb::file_transfer_loop() {
|
||||
}
|
||||
|
||||
Result Usb::SendResult(u32 result, u32 arg3, u32 arg4) {
|
||||
ResultHeader recv_header{MAGIC, result, arg3, arg4};
|
||||
auto recv_header = api::ResultPacket::Build(result, arg3, arg4);
|
||||
return m_usb->TransferAll(false, &recv_header, sizeof(recv_header));
|
||||
}
|
||||
|
||||
|
||||
114
tools/test_usb_export.py
Normal file
114
tools/test_usb_export.py
Normal file
@@ -0,0 +1,114 @@
|
||||
import unittest
|
||||
import crc32c
|
||||
import os
|
||||
|
||||
from usb_common import CMD_EXPORT, CMD_QUIT, RESULT_OK, RESULT_ERROR
|
||||
|
||||
class FakeUsb:
|
||||
def __init__(self, files=None):
|
||||
# files: list of tuples (filename: str, data: bytes)
|
||||
self.files = files or [("testfile.bin", b"testdata")]
|
||||
self._cmd_index = 0
|
||||
self._file_index = 0
|
||||
self._data_index = 0
|
||||
self.results = []
|
||||
self._reading_filename = True
|
||||
self._reading_data = False
|
||||
self._current_data = b""
|
||||
self._current_data_offset = 0
|
||||
self._current_data_sent = 0
|
||||
self._current_file = None
|
||||
self._send_data_header_calls = 0
|
||||
|
||||
def wait_for_connect(self):
|
||||
pass
|
||||
|
||||
def get_send_header(self):
|
||||
# Simulate command sequence: export for each file, then quit
|
||||
if self._cmd_index < len(self.files):
|
||||
filename, data = self.files[self._cmd_index]
|
||||
self._current_file = (filename, data)
|
||||
self._cmd_index += 1
|
||||
self._reading_filename = True
|
||||
self._reading_data = False
|
||||
self._current_data = data
|
||||
self._current_data_offset = 0
|
||||
self._current_data_sent = 0
|
||||
self._send_data_header_calls = 0
|
||||
return [CMD_EXPORT, len(filename.encode("utf-8")), 0]
|
||||
else:
|
||||
return [CMD_QUIT, 0, 0]
|
||||
|
||||
def read(self, size):
|
||||
# Simulate reading file name or data
|
||||
if self._reading_filename:
|
||||
filename = self._current_file[0].encode("utf-8")
|
||||
self._reading_filename = False
|
||||
self._reading_data = True
|
||||
return filename[:size]
|
||||
elif self._reading_data:
|
||||
# Return file data for export
|
||||
data = self._current_data[self._current_data_sent:self._current_data_sent+size]
|
||||
self._current_data_sent += len(data)
|
||||
return data
|
||||
else:
|
||||
return b""
|
||||
|
||||
def get_send_data_header(self):
|
||||
# Simulate sending data in one chunk, then finish
|
||||
if self._send_data_header_calls == 0:
|
||||
self._send_data_header_calls += 1
|
||||
data = self._current_data
|
||||
crc = crc32c.crc32c(data)
|
||||
return [0, len(data), crc]
|
||||
else:
|
||||
return [0, 0, 0] # End of transfer
|
||||
|
||||
def send_result(self, result):
|
||||
self.results.append(result)
|
||||
|
||||
# test case for usb_export.py
|
||||
class TestUsbExport(unittest.TestCase):
|
||||
def setUp(self):
|
||||
self.root = "test_output"
|
||||
os.makedirs(self.root, exist_ok=True)
|
||||
# 100 files named test1.bin, test2.bin, ..., test10.bin, each with different sizes
|
||||
self.files = [
|
||||
(f"test{i+1}.bin", bytes([65 + i]) * (i * 100 + 1)) for i in range(100)
|
||||
]
|
||||
self.fake_usb = FakeUsb(files=self.files)
|
||||
|
||||
def tearDown(self):
|
||||
# Clean up created files/folders
|
||||
for f in os.listdir(self.root):
|
||||
os.remove(os.path.join(self.root, f))
|
||||
os.rmdir(self.root)
|
||||
|
||||
def test_export_multiple_files(self):
|
||||
from usb_export import get_file_name, create_file_folder, wait_for_input
|
||||
|
||||
# Simulate the main loop for all files
|
||||
for filename, data in self.files:
|
||||
cmd, name_len, _ = self.fake_usb.get_send_header()
|
||||
self.assertEqual(cmd, CMD_EXPORT)
|
||||
|
||||
file_name = get_file_name(self.fake_usb, name_len)
|
||||
self.assertEqual(file_name, filename)
|
||||
|
||||
full_path = create_file_folder(self.root, file_name)
|
||||
self.fake_usb.send_result(RESULT_OK)
|
||||
|
||||
wait_for_input(self.fake_usb, full_path)
|
||||
|
||||
# Check file was created and contents match
|
||||
with open(full_path, "rb") as f:
|
||||
filedata = f.read()
|
||||
self.assertEqual(filedata, data)
|
||||
|
||||
# After all files, should get CMD_QUIT
|
||||
cmd, _, _ = self.fake_usb.get_send_header()
|
||||
self.assertEqual(cmd, CMD_QUIT)
|
||||
self.assertIn(RESULT_OK, self.fake_usb.results)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
103
tools/test_usb_install.py
Normal file
103
tools/test_usb_install.py
Normal file
@@ -0,0 +1,103 @@
|
||||
import unittest
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import crc32c
|
||||
|
||||
from usb_common import RESULT_OK
|
||||
|
||||
# Simpler FakeUsb for file transfer
|
||||
class FakeUsb:
|
||||
def __init__(self, filedata):
|
||||
self.filedata = filedata # bytes
|
||||
self.results = []
|
||||
self.writes = []
|
||||
self._send_data_header_calls = 0
|
||||
self._current_data = filedata
|
||||
|
||||
def wait_for_connect(self):
|
||||
pass
|
||||
|
||||
def get_send_data_header(self):
|
||||
# Simulate sending the file in one chunk, then finish
|
||||
if self._send_data_header_calls == 0:
|
||||
self._send_data_header_calls += 1
|
||||
data = self._current_data
|
||||
crc = crc32c.crc32c(data)
|
||||
return [0, len(data), crc]
|
||||
else:
|
||||
return [0, 0, 0]
|
||||
|
||||
def send_result(self, result, arg2=0, arg3=0):
|
||||
self.results.append((result, arg2, arg3))
|
||||
|
||||
def write(self, data):
|
||||
self.writes.append(data)
|
||||
|
||||
class TestUsbInstall(unittest.TestCase):
|
||||
def setUp(self):
|
||||
import random
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
# 100 files named test1.nsp, test2.xci, test3.nsz, test4.xcz, ..., cycling extensions, each with random sizes (0-2048 bytes)
|
||||
extensions = ["nsp", "xci", "nsz", "xcz"]
|
||||
self.files = [
|
||||
(f"test{i+1}.{extensions[i % 4]}", os.urandom(random.randint(0, 2048))) for i in range(100)
|
||||
]
|
||||
self.filepaths = []
|
||||
|
||||
for fname, data in self.files:
|
||||
fpath = os.path.join(self.tempdir, fname)
|
||||
with open(fpath, "wb") as f:
|
||||
f.write(data)
|
||||
self.filepaths.append(fpath)
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.tempdir)
|
||||
|
||||
def test_multiple_file_install(self):
|
||||
from usb_install import add_file_to_install_list, paths, wait_for_input
|
||||
paths.clear()
|
||||
|
||||
for fpath in self.filepaths:
|
||||
add_file_to_install_list(fpath)
|
||||
|
||||
for idx, (fname, data) in enumerate(self.files):
|
||||
fake_usb = FakeUsb(data)
|
||||
wait_for_input(fake_usb, idx)
|
||||
|
||||
# Check that the file on disk matches expected data
|
||||
with open(self.filepaths[idx], "rb") as f:
|
||||
filedata = f.read()
|
||||
self.assertEqual(filedata, data)
|
||||
found = False
|
||||
|
||||
for result, arg2, arg3 in fake_usb.results:
|
||||
if result == RESULT_OK and arg2 == len(data):
|
||||
found = True
|
||||
self.assertTrue(found)
|
||||
|
||||
def test_directory_install(self):
|
||||
from usb_install import add_file_to_install_list, paths, wait_for_input
|
||||
paths.clear()
|
||||
|
||||
for fpath in self.filepaths:
|
||||
add_file_to_install_list(fpath)
|
||||
|
||||
for idx, (fname, data) in enumerate(self.files):
|
||||
fake_usb = FakeUsb(data)
|
||||
wait_for_input(fake_usb, idx)
|
||||
|
||||
# Check that the file on disk matches expected data
|
||||
with open(self.filepaths[idx], "rb") as f:
|
||||
filedata = f.read()
|
||||
|
||||
self.assertEqual(filedata, data)
|
||||
found = False
|
||||
|
||||
for result, arg2, arg3 in fake_usb.results:
|
||||
if result == RESULT_OK and arg2 == len(data):
|
||||
found = True
|
||||
self.assertTrue(found)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@@ -2,9 +2,11 @@ import struct
|
||||
import usb.core
|
||||
import usb.util
|
||||
import time
|
||||
import crc32c
|
||||
|
||||
# magic number (SPH0) for the script and switch.
|
||||
MAGIC = 0x53504830
|
||||
PACKET_SIZE = 24
|
||||
|
||||
# commands
|
||||
CMD_QUIT = 0
|
||||
@@ -19,14 +21,85 @@ RESULT_ERROR = 1
|
||||
FLAG_NONE = 0
|
||||
FLAG_STREAM = 1 << 0
|
||||
|
||||
# disabled, see usbds.cpp usbDsEndpoint_SetZlt
|
||||
ENABLE_ZLT = 0
|
||||
class UsbPacket:
|
||||
STRUCT_FORMAT = "<6I" # 6 unsigned 32-bit ints, little-endian
|
||||
|
||||
def __init__(self, magic=MAGIC, arg2=0, arg3=0, arg4=0, arg5=0, crc32c_val=0):
|
||||
self.magic = magic
|
||||
self.arg2 = arg2
|
||||
self.arg3 = arg3
|
||||
self.arg4 = arg4
|
||||
self.arg5 = arg5
|
||||
self.crc32c = crc32c_val
|
||||
|
||||
def pack(self):
|
||||
self.generate_crc32c()
|
||||
return struct.pack(self.STRUCT_FORMAT, self.magic, self.arg2, self.arg3, self.arg4, self.arg5, self.crc32c)
|
||||
|
||||
@classmethod
|
||||
def unpack(cls, data):
|
||||
fields = struct.unpack(cls.STRUCT_FORMAT, data)
|
||||
return cls(*fields)
|
||||
|
||||
def calculate_crc32c(self):
|
||||
data = struct.pack("<5I", self.magic, self.arg2, self.arg3, self.arg4, self.arg5)
|
||||
return crc32c.crc32c(data)
|
||||
|
||||
def generate_crc32c(self):
|
||||
self.crc32c = self.calculate_crc32c()
|
||||
|
||||
def verify(self):
|
||||
if self.crc32c != self.calculate_crc32c():
|
||||
raise ValueError("CRC32C mismatch")
|
||||
if self.magic != MAGIC:
|
||||
raise ValueError("Bad magic")
|
||||
return True
|
||||
|
||||
class SendPacket(UsbPacket):
|
||||
@classmethod
|
||||
def build(cls, cmd, arg3=0, arg4=0):
|
||||
packet = cls(MAGIC, cmd, arg3, arg4)
|
||||
packet.generate_crc32c()
|
||||
return packet
|
||||
|
||||
def get_cmd(self):
|
||||
return self.arg2
|
||||
|
||||
class ResultPacket(UsbPacket):
|
||||
@classmethod
|
||||
def build(cls, result, arg3=0, arg4=0):
|
||||
packet = cls(MAGIC, result, arg3, arg4)
|
||||
packet.generate_crc32c()
|
||||
return packet
|
||||
|
||||
def verify(self):
|
||||
super().verify()
|
||||
if self.arg2 != RESULT_OK:
|
||||
raise ValueError("Result not OK")
|
||||
return True
|
||||
|
||||
class SendDataPacket(UsbPacket):
|
||||
@classmethod
|
||||
def build(cls, offset, size, crc32c_val):
|
||||
arg2 = (offset >> 32) & 0xFFFFFFFF
|
||||
arg3 = offset & 0xFFFFFFFF
|
||||
packet = cls(MAGIC, arg2, arg3, size, crc32c_val)
|
||||
packet.generate_crc32c()
|
||||
return packet
|
||||
|
||||
def get_offset(self):
|
||||
return (self.arg2 << 32) | self.arg3
|
||||
|
||||
def get_size(self):
|
||||
return self.arg4
|
||||
|
||||
def get_crc32c(self):
|
||||
return self.arg5
|
||||
|
||||
class Usb:
|
||||
def __init__(self):
|
||||
self.__out_ep = None
|
||||
self.__in_ep = None
|
||||
self.__packet_size = 0
|
||||
|
||||
def wait_for_connect(self) -> None:
|
||||
print("waiting for switch")
|
||||
@@ -61,29 +134,23 @@ class Usb:
|
||||
|
||||
print("iManufacturer: {} iProduct: {} iSerialNumber: {}".format(dev.manufacturer, dev.product, dev.serial_number))
|
||||
print("bcdUSB: {} bMaxPacketSize0: {}".format(hex(dev.bcdUSB), dev.bMaxPacketSize0))
|
||||
self.__packet_size = 1 << dev.bMaxPacketSize0
|
||||
|
||||
def read(self, size: int, timeout: int = 0) -> bytes:
|
||||
if (ENABLE_ZLT and size and (size % self.__packet_size) == 0):
|
||||
size += 1
|
||||
return self.__in_ep.read(size, timeout)
|
||||
|
||||
def write(self, buf: bytes, timeout: int = 0) -> int:
|
||||
return self.__out_ep.write(data=buf, timeout=timeout)
|
||||
|
||||
def get_send_header(self) -> tuple[int, int, int]:
|
||||
header = self.read(16)
|
||||
[magic, arg2, arg3, arg4] = struct.unpack('<IIII', header)
|
||||
|
||||
if magic != MAGIC:
|
||||
raise Exception("Unexpected magic {}".format(magic))
|
||||
|
||||
return arg2, arg3, arg4
|
||||
packet = SendPacket.unpack(self.read(PACKET_SIZE))
|
||||
packet.verify()
|
||||
return packet.get_cmd(), packet.arg3, packet.arg4
|
||||
|
||||
def get_send_data_header(self) -> tuple[int, int, int]:
|
||||
header = self.read(16)
|
||||
return struct.unpack('<QII', header)
|
||||
packet = SendDataPacket.unpack(self.read(PACKET_SIZE))
|
||||
packet.verify()
|
||||
return packet.get_offset(), packet.get_size(), packet.get_crc32c()
|
||||
|
||||
def send_result(self, result: int, arg3: int = 0, arg4: int = 0) -> None:
|
||||
send_data = struct.pack('<IIII', MAGIC, result, arg3, arg4)
|
||||
send_data = ResultPacket.build(result, arg3, arg4).pack()
|
||||
self.write(send_data)
|
||||
|
||||
@@ -62,7 +62,7 @@ if __name__ == '__main__':
|
||||
if (not os.path.isdir(root_path)):
|
||||
raise ValueError('must be a dir!')
|
||||
|
||||
usb: Usb = Usb()
|
||||
usb = Usb()
|
||||
|
||||
try:
|
||||
# get usb endpoints.
|
||||
|
||||
@@ -127,14 +127,14 @@ if __name__ == '__main__':
|
||||
else:
|
||||
raise ValueError('must be a file!')
|
||||
|
||||
usb: Usb = Usb()
|
||||
usb = Usb()
|
||||
|
||||
try:
|
||||
# get usb endpoints.
|
||||
usb.wait_for_connect()
|
||||
|
||||
# build string table.
|
||||
string_table: bytes
|
||||
string_table = bytes()
|
||||
for [_, path] in paths:
|
||||
string_table += bytes(Path(path).name.__str__(), 'utf8') + b'\n'
|
||||
|
||||
|
||||
Reference in New Issue
Block a user