use stop token to manage object lifetime across async callbacks, such as download async

This commit is contained in:
ITotalJustice
2025-01-14 15:35:09 +00:00
parent 4e5e1a801b
commit 64a40ae672
7 changed files with 96 additions and 37 deletions

View File

@@ -6,6 +6,7 @@
#include <functional> #include <functional>
#include <unordered_map> #include <unordered_map>
#include <algorithm> #include <algorithm>
#include <stop_token>
#include <switch.h> #include <switch.h>
namespace sphaira::curl { namespace sphaira::curl {
@@ -29,6 +30,7 @@ struct ApiResult;
using Path = fs::FsPath; using Path = fs::FsPath;
using OnComplete = std::function<void(ApiResult& result)>; using OnComplete = std::function<void(ApiResult& result)>;
using OnProgress = std::function<bool(u32 dltotal, u32 dlnow, u32 ultotal, u32 ulnow)>; using OnProgress = std::function<bool(u32 dltotal, u32 dlnow, u32 ultotal, u32 ulnow)>;
using StopToken = std::stop_token;
struct Url { struct Url {
Url() = default; Url() = default;
@@ -71,6 +73,7 @@ struct ApiResult {
struct DownloadEventData { struct DownloadEventData {
OnComplete callback; OnComplete callback;
ApiResult result; ApiResult result;
StopToken stoken;
}; };
auto Init() -> bool; auto Init() -> bool;
@@ -114,6 +117,7 @@ struct Api {
auto ToMemory(Ts&&... ts) { auto ToMemory(Ts&&... ts) {
static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified"); static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified");
static_assert(!std::disjunction_v<std::is_same<Path, Ts>...>, "Path must not valid for memory"); static_assert(!std::disjunction_v<std::is_same<Path, Ts>...>, "Path must not valid for memory");
static_assert(!std::disjunction_v<std::is_same<OnComplete, Ts>...>, "OnComplete must not be specified");
Api::set_option(std::forward<Ts>(ts)...); Api::set_option(std::forward<Ts>(ts)...);
return curl::ToMemory(*this); return curl::ToMemory(*this);
} }
@@ -122,6 +126,7 @@ struct Api {
auto ToFile(Ts&&... ts) { auto ToFile(Ts&&... ts) {
static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified"); static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified");
static_assert(std::disjunction_v<std::is_same<Path, Ts>...>, "Path must be specified"); static_assert(std::disjunction_v<std::is_same<Path, Ts>...>, "Path must be specified");
static_assert(!std::disjunction_v<std::is_same<OnComplete, Ts>...>, "OnComplete must not be specified");
Api::set_option(std::forward<Ts>(ts)...); Api::set_option(std::forward<Ts>(ts)...);
return curl::ToFile(*this); return curl::ToFile(*this);
} }
@@ -131,6 +136,7 @@ struct Api {
static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified"); static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified");
static_assert(std::disjunction_v<std::is_same<OnComplete, Ts>...>, "OnComplete must be specified"); static_assert(std::disjunction_v<std::is_same<OnComplete, Ts>...>, "OnComplete must be specified");
static_assert(!std::disjunction_v<std::is_same<Path, Ts>...>, "Path must not valid for memory"); static_assert(!std::disjunction_v<std::is_same<Path, Ts>...>, "Path must not valid for memory");
static_assert(std::disjunction_v<std::is_same<StopToken, Ts>...>, "StopToken must be specified");
Api::set_option(std::forward<Ts>(ts)...); Api::set_option(std::forward<Ts>(ts)...);
return curl::ToMemoryAsync(*this); return curl::ToMemoryAsync(*this);
} }
@@ -140,18 +146,38 @@ struct Api {
static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified"); static_assert(std::disjunction_v<std::is_same<Url, Ts>...>, "Url must be specified");
static_assert(std::disjunction_v<std::is_same<Path, Ts>...>, "Path must be specified"); static_assert(std::disjunction_v<std::is_same<Path, Ts>...>, "Path must be specified");
static_assert(std::disjunction_v<std::is_same<OnComplete, Ts>...>, "OnComplete must be specified"); static_assert(std::disjunction_v<std::is_same<OnComplete, Ts>...>, "OnComplete must be specified");
static_assert(std::disjunction_v<std::is_same<StopToken, Ts>...>, "StopToken must be specified");
Api::set_option(std::forward<Ts>(ts)...); Api::set_option(std::forward<Ts>(ts)...);
return curl::ToFileAsync(*this); return curl::ToFileAsync(*this);
} }
Url m_url; auto& GetUrl() const {
Fields m_fields{}; return m_url.m_str;
Header m_header{}; }
Flags m_flags{}; auto& GetFields() const {
Path m_path{}; return m_fields.m_str;
OnComplete m_on_complete = nullptr; }
OnProgress m_on_progress = nullptr; auto& GetHeader() const {
Priority m_prio = Priority::High; return m_header;
}
auto& GetFlags() const {
return m_flags.m_flags;
}
auto& GetPath() const {
return m_path;
}
auto& GetOnComplete() const {
return m_on_complete;
}
auto& GetOnProgress() const {
return m_on_progress;
}
auto& GetPriority() const {
return m_prio;
}
auto& GetToken() const {
return m_stoken;
}
private: private:
void SetOption(Url&& v) { void SetOption(Url&& v) {
@@ -178,6 +204,9 @@ private:
void SetOption(Priority&& v) { void SetOption(Priority&& v) {
m_prio = v; m_prio = v;
} }
void SetOption(StopToken&& v) {
m_stoken = v;
}
template <typename T> template <typename T>
void set_option(T&& t) { void set_option(T&& t) {
@@ -189,6 +218,18 @@ private:
set_option(std::forward<T>(t)); set_option(std::forward<T>(t));
set_option(std::forward<Ts>(ts)...); set_option(std::forward<Ts>(ts)...);
} }
private:
Url m_url;
Fields m_fields{};
Header m_header{};
Flags m_flags{};
Path m_path{};
OnComplete m_on_complete{nullptr};
OnProgress m_on_progress{nullptr};
Priority m_prio{Priority::High};
std::stop_source m_stop_source{};
StopToken m_stoken{m_stop_source.get_token()};
}; };
} // namespace sphaira::curl } // namespace sphaira::curl

View File

@@ -1,13 +1,16 @@
#pragma once #pragma once
#include "types.hpp" #include "types.hpp"
#include <stop_token>
namespace sphaira::ui { namespace sphaira::ui {
class Object { class Object {
public: public:
Object() = default; Object() = default;
virtual ~Object() = default; virtual ~Object() {
m_stop_source.request_stop();
}
virtual auto Draw(NVGcontext* vg, Theme* theme) -> void = 0; virtual auto Draw(NVGcontext* vg, Theme* theme) -> void = 0;
@@ -71,8 +74,14 @@ public:
m_hidden = value; m_hidden = value;
} }
auto GetToken() const {
return m_stop_source.get_token();
}
protected: protected:
Vec4 m_pos{}; Vec4 m_pos{};
// used for lifetime management across threads.
std::stop_source m_stop_source{};
bool m_hidden{false}; bool m_hidden{false};
}; };

View File

@@ -421,7 +421,9 @@ void App::Loop() {
} }
} else if constexpr(std::is_same_v<T, curl::DownloadEventData>) { } else if constexpr(std::is_same_v<T, curl::DownloadEventData>) {
log_write("[DownloadEventData] got event\n"); log_write("[DownloadEventData] got event\n");
if (arg.callback && !arg.stoken.stop_requested()) {
arg.callback(arg.result); arg.callback(arg.result);
}
} else { } else {
static_assert(false, "non-exhaustive visitor!"); static_assert(false, "non-exhaustive visitor!");
} }

View File

@@ -308,7 +308,7 @@ struct ThreadQueue {
ThreadQueueEntry entry{}; ThreadQueueEntry entry{};
entry.api = api; entry.api = api;
switch (api.m_prio) { switch (api.GetPriority()) {
case Priority::Normal: case Priority::Normal:
m_entries.emplace_back(entry); m_entries.emplace_back(entry);
break; break;
@@ -350,13 +350,13 @@ auto ProgressCallbackFunc1(void *clientp, curl_off_t dltotal, curl_off_t dlnow,
} }
auto ProgressCallbackFunc2(void *clientp, curl_off_t dltotal, curl_off_t dlnow, curl_off_t ultotal, curl_off_t ulnow) -> size_t { auto ProgressCallbackFunc2(void *clientp, curl_off_t dltotal, curl_off_t dlnow, curl_off_t ultotal, curl_off_t ulnow) -> size_t {
if (!g_running) { auto api = static_cast<Api*>(clientp);
if (!g_running || api->GetToken().stop_requested()) {
return 1; return 1;
} }
// log_write("pcall called %u %u %u %u\n", dltotal, dlnow, ultotal, ulnow); // log_write("pcall called %u %u %u %u\n", dltotal, dlnow, ultotal, ulnow);
auto callback = *static_cast<OnProgress*>(clientp); if (!api->GetOnProgress()(dltotal, dlnow, ultotal, ulnow)) {
if (!callback(dltotal, dlnow, ultotal, ulnow)) {
return 1; return 1;
} }
@@ -444,11 +444,11 @@ auto header_callback(char* b, size_t size, size_t nitems, void* userdata) -> siz
auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult { auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
fs::FsPath tmp_buf; fs::FsPath tmp_buf;
const bool has_file = !e.m_path.empty() && e.m_path != ""; const bool has_file = !e.GetPath().empty() && e.GetPath() != "";
const bool has_post = !e.m_fields.m_str.empty() && e.m_fields.m_str != ""; const bool has_post = !e.GetFields().empty() && e.GetFields() != "";
DataStruct chunk; DataStruct chunk;
Header header_in = e.m_header; Header header_in = e.GetHeader();
Header header_out; Header header_out;
fs::FsNativeSd fs; fs::FsNativeSd fs;
@@ -466,8 +466,8 @@ auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
return {}; return {};
} }
if (e.m_flags.m_flags & Flag_Cache) { if (e.GetFlags() & Flag_Cache) {
g_cache.get(e.m_path, header_in); g_cache.get(e.GetPath(), header_in);
} }
} }
@@ -475,7 +475,7 @@ auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
chunk.data.reserve(CHUNK_SIZE); chunk.data.reserve(CHUNK_SIZE);
curl_easy_reset(curl); curl_easy_reset(curl);
CURL_EASY_SETOPT_LOG(curl, CURLOPT_URL, e.m_url.m_str.c_str()); CURL_EASY_SETOPT_LOG(curl, CURLOPT_URL, e.GetUrl().c_str());
CURL_EASY_SETOPT_LOG(curl, CURLOPT_USERAGENT, "TotalJustice"); CURL_EASY_SETOPT_LOG(curl, CURLOPT_USERAGENT, "TotalJustice");
CURL_EASY_SETOPT_LOG(curl, CURLOPT_FOLLOWLOCATION, 1L); CURL_EASY_SETOPT_LOG(curl, CURLOPT_FOLLOWLOCATION, 1L);
CURL_EASY_SETOPT_LOG(curl, CURLOPT_SSL_VERIFYPEER, 0L); CURL_EASY_SETOPT_LOG(curl, CURLOPT_SSL_VERIFYPEER, 0L);
@@ -487,8 +487,8 @@ auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
CURL_EASY_SETOPT_LOG(curl, CURLOPT_HEADERDATA, &header_out); CURL_EASY_SETOPT_LOG(curl, CURLOPT_HEADERDATA, &header_out);
if (has_post) { if (has_post) {
CURL_EASY_SETOPT_LOG(curl, CURLOPT_POSTFIELDS, e.m_fields.m_str.c_str()); CURL_EASY_SETOPT_LOG(curl, CURLOPT_POSTFIELDS, e.GetFields().c_str());
log_write("setting post field: %s\n", e.m_fields.m_str.c_str()); log_write("setting post field: %s\n", e.GetFields().c_str());
} }
struct curl_slist* list = NULL; struct curl_slist* list = NULL;
@@ -517,8 +517,8 @@ auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
} }
// progress calls. // progress calls.
if (e.m_on_progress) { if (e.GetOnProgress()) {
CURL_EASY_SETOPT_LOG(curl, CURLOPT_XFERINFODATA, &e.m_on_progress); CURL_EASY_SETOPT_LOG(curl, CURLOPT_XFERINFODATA, &e);
CURL_EASY_SETOPT_LOG(curl, CURLOPT_XFERINFOFUNCTION, ProgressCallbackFunc2); CURL_EASY_SETOPT_LOG(curl, CURLOPT_XFERINFOFUNCTION, ProgressCallbackFunc2);
} else { } else {
CURL_EASY_SETOPT_LOG(curl, CURLOPT_XFERINFOFUNCTION, ProgressCallbackFunc1); CURL_EASY_SETOPT_LOG(curl, CURLOPT_XFERINFOFUNCTION, ProgressCallbackFunc1);
@@ -546,16 +546,16 @@ auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
if (res == CURLE_OK) { if (res == CURLE_OK) {
if (http_code == 304) { if (http_code == 304) {
log_write("cached download: %s\n", e.m_url.m_str.c_str()); log_write("cached download: %s\n", e.GetUrl().c_str());
} else { } else {
log_write("un-cached download: %s code: %u\n", e.m_url.m_str.c_str(), http_code); log_write("un-cached download: %s code: %u\n", e.GetUrl().c_str(), http_code);
if (e.m_flags.m_flags & Flag_Cache) { if (e.GetFlags() & Flag_Cache) {
g_cache.set(e.m_path, header_out); g_cache.set(e.GetPath(), header_out);
} }
fs.DeleteFile(e.m_path); fs.DeleteFile(e.GetPath());
fs.CreateDirectoryRecursivelyWithPath(e.m_path); fs.CreateDirectoryRecursivelyWithPath(e.GetPath());
if (R_FAILED(fs.RenameFile(tmp_buf, e.m_path))) { if (R_FAILED(fs.RenameFile(tmp_buf, e.GetPath()))) {
success = false; success = false;
} }
} }
@@ -568,8 +568,8 @@ auto DownloadInternal(CURL* curl, const Api& e) -> ApiResult {
} }
} }
log_write("Downloaded %s %s\n", e.m_url.m_str.c_str(), curl_easy_strerror(res)); log_write("Downloaded %s %s\n", e.GetUrl().c_str(), curl_easy_strerror(res));
return {success, http_code, header_out, chunk.data, e.m_path}; return {success, http_code, header_out, chunk.data, e.GetPath()};
} }
auto DownloadInternal(const Api& e) -> ApiResult { auto DownloadInternal(const Api& e) -> ApiResult {
@@ -604,8 +604,8 @@ void ThreadEntry::ThreadFunc(void* p) {
#if 1 #if 1
const auto result = DownloadInternal(data->m_curl, data->m_api); const auto result = DownloadInternal(data->m_curl, data->m_api);
if (g_running) { if (g_running && data->m_api.GetOnComplete() && !data->m_api.GetToken().stop_requested()) {
const DownloadEventData event_data{data->m_api.m_on_complete, result}; const DownloadEventData event_data{data->m_api.GetOnComplete(), result, data->m_api.GetToken()};
evman::push(std::move(event_data), false); evman::push(std::move(event_data), false);
} else { } else {
break; break;
@@ -736,14 +736,14 @@ void Exit() {
} }
auto ToMemory(const Api& e) -> ApiResult { auto ToMemory(const Api& e) -> ApiResult {
if (!e.m_path.empty()) { if (!e.GetPath().empty()) {
return {}; return {};
} }
return DownloadInternal(e); return DownloadInternal(e);
} }
auto ToFile(const Api& e) -> ApiResult { auto ToFile(const Api& e) -> ApiResult {
if (e.m_path.empty()) { if (e.GetPath().empty()) {
return {}; return {};
} }
return DownloadInternal(e); return DownloadInternal(e);

View File

@@ -663,6 +663,7 @@ EntryMenu::EntryMenu(Entry& entry, const LazyImage& default_icon, Menu& menu)
curl::Url{URL_POST_FEEDBACK}, curl::Url{URL_POST_FEEDBACK},
curl::Path{file}, curl::Path{file},
curl::Fields{post}, curl::Fields{post},
curl::StopToken{this->GetToken()},
curl::OnComplete{[](auto& result){ curl::OnComplete{[](auto& result){
if (result.success) { if (result.success) {
log_write("got feedback!\n"); log_write("got feedback!\n");
@@ -697,6 +698,7 @@ EntryMenu::EntryMenu(Entry& entry, const LazyImage& default_icon, Menu& menu)
curl::Url{url}, curl::Url{url},
curl::Path{path}, curl::Path{path},
curl::Flags{curl::Flag_Cache}, curl::Flags{curl::Flag_Cache},
curl::StopToken{this->GetToken()},
curl::OnComplete{[this, path](auto& result){ curl::OnComplete{[this, path](auto& result){
if (result.success) { if (result.success) {
if (result.code == 304) { if (result.code == 304) {
@@ -990,6 +992,7 @@ Menu::Menu(const std::vector<NroEntry>& nro_entries) : MenuBase{"AppStore"_i18n}
curl::Url{URL_JSON}, curl::Url{URL_JSON},
curl::Path{REPO_PATH}, curl::Path{REPO_PATH},
curl::Flags{curl::Flag_Cache}, curl::Flags{curl::Flag_Cache},
curl::StopToken{this->GetToken()},
curl::OnComplete{[this](auto& result){ curl::OnComplete{[this](auto& result){
if (result.success) { if (result.success) {
m_repo_download_state = ImageDownloadState::Done; m_repo_download_state = ImageDownloadState::Done;
@@ -1071,6 +1074,7 @@ void Menu::Draw(NVGcontext* vg, Theme* theme) {
curl::Url{url}, curl::Url{url},
curl::Path{path}, curl::Path{path},
curl::Flags{curl::Flag_Cache}, curl::Flags{curl::Flag_Cache},
curl::StopToken{this->GetToken()},
curl::OnComplete{[this, &image](auto& result) { curl::OnComplete{[this, &image](auto& result) {
if (result.success) { if (result.success) {
image.state = ImageDownloadState::Done; image.state = ImageDownloadState::Done;

View File

@@ -150,6 +150,7 @@ MainMenu::MainMenu() {
curl::Url{GITHUB_URL}, curl::Url{GITHUB_URL},
curl::Path{CACHE_PATH}, curl::Path{CACHE_PATH},
curl::Flags{curl::Flag_Cache}, curl::Flags{curl::Flag_Cache},
curl::StopToken{this->GetToken()},
curl::Header{ curl::Header{
{ "Accept", "application/vnd.github+json" }, { "Accept", "application/vnd.github+json" },
}, },

View File

@@ -647,6 +647,7 @@ void Menu::Draw(NVGcontext* vg, Theme* theme) {
curl::Url{url}, curl::Url{url},
curl::Path{path}, curl::Path{path},
curl::Flags{curl::Flag_Cache}, curl::Flags{curl::Flag_Cache},
curl::StopToken{this->GetToken()},
curl::OnComplete{[this, &image](auto& result) { curl::OnComplete{[this, &image](auto& result) {
if (result.success) { if (result.success) {
image.state = ImageDownloadState::Done; image.state = ImageDownloadState::Done;
@@ -733,6 +734,7 @@ void Menu::PackListDownload() {
curl::Url{packList_url}, curl::Url{packList_url},
curl::Path{packlist_path}, curl::Path{packlist_path},
curl::Flags{curl::Flag_Cache}, curl::Flags{curl::Flag_Cache},
curl::StopToken{this->GetToken()},
curl::OnComplete{[this, page_index](auto& result){ curl::OnComplete{[this, page_index](auto& result){
log_write("got themezer data\n"); log_write("got themezer data\n");
if (!result.success) { if (!result.success) {