htc: implement much of worker receive logic
This commit is contained in:
@@ -38,6 +38,46 @@ namespace ams::htclow::mux {
|
||||
}
|
||||
}
|
||||
|
||||
Result Mux::CheckReceivedHeader(const PacketHeader &header) const {
|
||||
/* Check the packet signature. */
|
||||
AMS_ASSERT(header.signature == HtcGen2Signature);
|
||||
|
||||
/* Switch on the packet type. */
|
||||
switch (header.packet_type) {
|
||||
case PacketType_Data:
|
||||
R_UNLESS(header.version == m_version, htclow::ResultProtocolError());
|
||||
R_UNLESS(header.body_size <= sizeof(PacketBody), htclow::ResultProtocolError());
|
||||
break;
|
||||
case PacketType_MaxData:
|
||||
R_UNLESS(header.version == m_version, htclow::ResultProtocolError());
|
||||
R_UNLESS(header.body_size == 0, htclow::ResultProtocolError());
|
||||
break;
|
||||
case PacketType_Error:
|
||||
R_UNLESS(header.body_size == 0, htclow::ResultProtocolError());
|
||||
break;
|
||||
AMS_UNREACHABLE_DEFAULT_CASE();
|
||||
}
|
||||
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
Result Mux::ProcessReceivePacket(const PacketHeader &header, const void *body, size_t body_size) {
|
||||
/* Lock ourselves. */
|
||||
std::scoped_lock lk(m_mutex);
|
||||
|
||||
/* Process for the channel. */
|
||||
if (m_channel_impl_map.Exists(header.channel)) {
|
||||
R_TRY(this->CheckChannelExist(header.channel));
|
||||
|
||||
return m_channel_impl_map[header.channel].ProcessReceivePacket(header, body, body_size);
|
||||
} else {
|
||||
if (header.packet_type == PacketType_Data || header.packet_type == PacketType_MaxData) {
|
||||
this->SendErrorPacket(header.channel);
|
||||
}
|
||||
return htclow::ResultChannelNotExist();
|
||||
}
|
||||
}
|
||||
|
||||
void Mux::UpdateChannelState() {
|
||||
/* Lock ourselves. */
|
||||
std::scoped_lock lk(m_mutex);
|
||||
@@ -62,4 +102,14 @@ namespace ams::htclow::mux {
|
||||
}
|
||||
}
|
||||
|
||||
Result Mux::CheckChannelExist(impl::ChannelInternalType channel) {
|
||||
R_UNLESS(m_channel_impl_map.Exists(channel), htclow::ResultChannelNotExist());
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
Result Mux::SendErrorPacket(impl::ChannelInternalType channel) {
|
||||
/* TODO */
|
||||
AMS_ABORT("Mux::SendErrorPacket");
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -37,8 +37,15 @@ namespace ams::htclow::mux {
|
||||
|
||||
void SetVersion(u16 version);
|
||||
|
||||
Result CheckReceivedHeader(const PacketHeader &header) const;
|
||||
Result ProcessReceivePacket(const PacketHeader &header, const void *body, size_t body_size);
|
||||
|
||||
void UpdateChannelState();
|
||||
void UpdateMuxState();
|
||||
private:
|
||||
Result CheckChannelExist(impl::ChannelInternalType channel);
|
||||
|
||||
Result SendErrorPacket(impl::ChannelInternalType channel);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -28,6 +28,110 @@ namespace ams::htclow::mux {
|
||||
m_send_buffer.SetVersion(version);
|
||||
}
|
||||
|
||||
Result ChannelImpl::CheckState(std::initializer_list<ChannelState> states) const {
|
||||
/* Determine if we have a matching state. */
|
||||
bool match = false;
|
||||
for (const auto &state : states) {
|
||||
match |= m_state == state;
|
||||
}
|
||||
|
||||
/* If we do, we're good. */
|
||||
R_SUCCEED_IF(match);
|
||||
|
||||
/* Otherwise, return appropriate failure error. */
|
||||
if (m_state == ChannelState_Disconnected) {
|
||||
return htclow::ResultInvalidChannelStateDisconnected();
|
||||
} else {
|
||||
return htclow::ResultInvalidChannelState();
|
||||
}
|
||||
}
|
||||
|
||||
Result ChannelImpl::CheckPacketVersion(s16 version) const {
|
||||
R_UNLESS(version == m_version, htclow::ResultChannelVersionNotMatched());
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
|
||||
Result ChannelImpl::ProcessReceivePacket(const PacketHeader &header, const void *body, size_t body_size) {
|
||||
switch (header.packet_type) {
|
||||
case PacketType_Data:
|
||||
return this->ProcessReceiveDataPacket(header.version, header.share, header.offset, body, body_size);
|
||||
case PacketType_MaxData:
|
||||
return this->ProcessReceiveMaxDataPacket(header.version, header.share);
|
||||
case PacketType_Error:
|
||||
return this->ProcessReceiveErrorPacket();
|
||||
default:
|
||||
return htclow::ResultProtocolError();
|
||||
}
|
||||
}
|
||||
|
||||
Result ChannelImpl::ProcessReceiveDataPacket(s16 version, u64 share, u32 offset, const void *body, size_t body_size) {
|
||||
/* Check our state. */
|
||||
R_TRY(this->CheckState({ChannelState_Connectable, ChannelState_Connected}));
|
||||
|
||||
/* Check the packet version. */
|
||||
R_TRY(this->CheckPacketVersion(version));
|
||||
|
||||
/* Check that offset matches. */
|
||||
R_UNLESS(offset == static_cast<u32>(m_offset), htclow::ResultProtocolError());
|
||||
|
||||
/* Check for flow control, if we should. */
|
||||
if (m_config.flow_control_enabled) {
|
||||
/* Check that the share increases monotonically. */
|
||||
if (m_share.has_value()) {
|
||||
R_UNLESS(m_share.value() <= share, htclow::ResultProtocolError());
|
||||
}
|
||||
|
||||
/* Update our share. */
|
||||
m_share = share;
|
||||
|
||||
/* Signal our event. */
|
||||
this->SignalSendPacketEvent();
|
||||
}
|
||||
|
||||
/* Update our offset. */
|
||||
m_offset += body_size;
|
||||
|
||||
/* Write the packet body. */
|
||||
R_ABORT_UNLESS(m_receive_buffer.Write(body, body_size));
|
||||
|
||||
/* Notify the data was received. */
|
||||
m_task_manager->NotifyReceiveData(m_channel, m_receive_buffer.GetDataSize());
|
||||
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
Result ChannelImpl::ProcessReceiveMaxDataPacket(s16 version, u64 share) {
|
||||
/* Check our state. */
|
||||
R_TRY(this->CheckState({ChannelState_Connectable, ChannelState_Connected}));
|
||||
|
||||
/* Check the packet version. */
|
||||
R_TRY(this->CheckPacketVersion(version));
|
||||
|
||||
/* Check for flow control, if we should. */
|
||||
if (m_config.flow_control_enabled) {
|
||||
/* Check that the share increases monotonically. */
|
||||
if (m_share.has_value()) {
|
||||
R_UNLESS(m_share.value() <= share, htclow::ResultProtocolError());
|
||||
}
|
||||
|
||||
/* Update our share. */
|
||||
m_share = share;
|
||||
|
||||
/* Signal our event. */
|
||||
this->SignalSendPacketEvent();
|
||||
}
|
||||
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
Result ChannelImpl::ProcessReceiveErrorPacket() {
|
||||
if (m_state == ChannelState_Connected || m_state == ChannelState_Disconnected) {
|
||||
this->ShutdownForce();
|
||||
}
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
void ChannelImpl::UpdateState() {
|
||||
/* Check if shutdown must be forced. */
|
||||
if (m_state_machine->IsUnsupportedServiceChannelToShutdown(m_channel)) {
|
||||
@@ -83,4 +187,10 @@ namespace ams::htclow::mux {
|
||||
}
|
||||
}
|
||||
|
||||
void ChannelImpl::SignalSendPacketEvent() {
|
||||
if (m_event != nullptr) {
|
||||
m_event->Signal();
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -42,9 +42,10 @@ namespace ams::htclow::mux {
|
||||
SendBuffer m_send_buffer;
|
||||
RingBuffer m_receive_buffer;
|
||||
s16 m_version;
|
||||
/* TODO: Channel config */
|
||||
ChannelConfig m_config;
|
||||
/* TODO: tracking variables. */
|
||||
std::optional<u64> m_108;
|
||||
u64 m_offset;
|
||||
std::optional<u64> m_share;
|
||||
os::Event m_state_change_event;
|
||||
ChannelState m_state;
|
||||
public:
|
||||
@@ -52,11 +53,22 @@ namespace ams::htclow::mux {
|
||||
|
||||
void SetVersion(s16 version);
|
||||
|
||||
Result ProcessReceivePacket(const PacketHeader &header, const void *body, size_t body_size);
|
||||
|
||||
void UpdateState();
|
||||
private:
|
||||
void ShutdownForce();
|
||||
void SetState(ChannelState state);
|
||||
void SetStateWithoutCheck(ChannelState state);
|
||||
|
||||
void SignalSendPacketEvent();
|
||||
|
||||
Result CheckState(std::initializer_list<ChannelState> states) const;
|
||||
Result CheckPacketVersion(s16 version) const;
|
||||
|
||||
Result ProcessReceiveDataPacket(s16 version, u64 share, u32 offset, const void *body, size_t body_size);
|
||||
Result ProcessReceiveMaxDataPacket(s16 version, u64 share);
|
||||
Result ProcessReceiveErrorPacket();
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -41,6 +41,10 @@ namespace ams::htclow::mux {
|
||||
ChannelImplMap(PacketFactory *pf, ctrl::HtcctrlStateMachine *sm, TaskManager *tm, os::Event *ev);
|
||||
|
||||
ChannelImpl &GetChannelImpl(impl::ChannelInternalType channel);
|
||||
|
||||
bool Exists(impl::ChannelInternalType channel) const {
|
||||
return m_map.find(channel) != m_map.end();
|
||||
}
|
||||
private:
|
||||
ChannelImpl &GetChannelImpl(int index);
|
||||
public:
|
||||
|
||||
@@ -0,0 +1,48 @@
|
||||
/*
|
||||
* Copyright (c) 2018-2020 Atmosphère-NX
|
||||
*
|
||||
* This program is free software; you can redistribute it and/or modify it
|
||||
* under the terms and conditions of the GNU General Public License,
|
||||
* version 2, as published by the Free Software Foundation.
|
||||
*
|
||||
* This program is distributed in the hope it will be useful, but WITHOUT
|
||||
* ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
|
||||
* FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
|
||||
* more details.
|
||||
*
|
||||
* You should have received a copy of the GNU General Public License
|
||||
* along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
#include <stratosphere.hpp>
|
||||
#include "htclow_mux_ring_buffer.hpp"
|
||||
|
||||
namespace ams::htclow::mux {
|
||||
|
||||
Result RingBuffer::Write(const void *data, size_t size) {
|
||||
/* Validate pre-conditions. */
|
||||
AMS_ASSERT(!m_is_read_only);
|
||||
|
||||
/* Check that our buffer can hold the data. */
|
||||
R_UNLESS(m_buffer != nullptr, htclow::ResultChannelBufferOverflow());
|
||||
R_UNLESS(m_data_size + size <= m_buffer_size, htclow::ResultChannelBufferOverflow());
|
||||
|
||||
/* Determine position and copy sizes. */
|
||||
const size_t pos = (m_data_size + m_offset) % m_buffer_size;
|
||||
const size_t left = m_buffer_size - pos;
|
||||
const size_t over = size - left;
|
||||
|
||||
/* Copy. */
|
||||
if (left != 0) {
|
||||
std::memcpy(static_cast<u8 *>(m_buffer) + pos, data, left);
|
||||
}
|
||||
if (over != 0) {
|
||||
std::memcpy(m_buffer, static_cast<const u8 *>(data) + left, over);
|
||||
}
|
||||
|
||||
/* Update our data size. */
|
||||
m_data_size += size;
|
||||
|
||||
return ResultSuccess();
|
||||
}
|
||||
|
||||
}
|
||||
@@ -29,6 +29,10 @@ namespace ams::htclow::mux {
|
||||
bool m_has_copied;
|
||||
public:
|
||||
RingBuffer() : m_buffer(), m_read_only_buffer(), m_is_read_only(true), m_buffer_size(), m_data_size(), m_offset(), m_has_copied(false) { /* ... */ }
|
||||
|
||||
size_t GetDataSize() { return m_data_size; }
|
||||
|
||||
Result Write(const void *data, size_t size);
|
||||
};
|
||||
|
||||
}
|
||||
|
||||
@@ -26,6 +26,14 @@ namespace ams::htclow::mux {
|
||||
}
|
||||
}
|
||||
|
||||
void TaskManager::NotifyReceiveData(impl::ChannelInternalType channel, size_t size) {
|
||||
for (auto i = 0; i < MaxTaskCount; ++i) {
|
||||
if (m_valid[i] && m_tasks[i].channel == channel && m_tasks[i].size <= size) {
|
||||
this->CompleteTask(i, EventTrigger_ReceiveData);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void TaskManager::NotifyConnectReady() {
|
||||
for (auto i = 0; i < MaxTaskCount; ++i) {
|
||||
if (m_valid[i] && m_tasks[i].type == TaskType_Connect) {
|
||||
|
||||
@@ -22,6 +22,7 @@ namespace ams::htclow::mux {
|
||||
|
||||
enum EventTrigger : u8 {
|
||||
EventTrigger_Disconnect = 1,
|
||||
EventTrigger_ReceiveData = 2,
|
||||
EventTrigger_ConnectReady = 11,
|
||||
};
|
||||
|
||||
@@ -40,7 +41,7 @@ namespace ams::htclow::mux {
|
||||
bool has_event_trigger;
|
||||
EventTrigger event_trigger;
|
||||
TaskType type;
|
||||
u64 _38;
|
||||
size_t size;
|
||||
};
|
||||
private:
|
||||
bool m_valid[MaxTaskCount];
|
||||
@@ -49,6 +50,7 @@ namespace ams::htclow::mux {
|
||||
TaskManager() : m_valid() { /* ... */ }
|
||||
|
||||
void NotifyDisconnect(impl::ChannelInternalType channel);
|
||||
void NotifyReceiveData(impl::ChannelInternalType channel, size_t size);
|
||||
void NotifyConnectReady();
|
||||
private:
|
||||
void CompleteTask(int index, EventTrigger trigger);
|
||||
|
||||
Reference in New Issue
Block a user