// Files
// dps_lib/include/Dio.hpp
// 2026-04-22 19:36:06 +08:00
//
// 559 lines
// 21 KiB
// C++

#ifndef DIO_HPP
#define DIO_HPP
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

#include <asio.hpp>
#include <asio/ssl.hpp>

#include "Version.h"
// Shorthand for Asio's TCP protocol type, used by HttpsClient below.
// NOTE(review): declared at global scope in a header, so every includer sees `tcp`.
using tcp = asio::ip::tcp;
class HttpsClient {
public:
// ?????????????
using ResponseCallback = std::function<void(
const std::string& response,
long status_code,
const std::string& error
)>;
// ????????????
struct SyncResponse {
std::string response; // ???
long status_code = 0; // ???
std::string error; // ???????????
};
// ???????????
struct FileDownloadResult {
long status_code = 0; // HTTP???
std::string error; // ????
size_t bytes_downloaded = 0;// ??????
};
// ????????io_context?SSL???
HttpsClient(asio::io_context& io_context)
: io_context_(io_context),
ssl_context_(asio::ssl::context::tlsv12_client) {
// ??????CentOS?????
// ssl_context_.set_default_verify_paths();
// ssl_context_.set_verify_mode(asio::ssl::verify_none);
// // ?? TLS 1.0/1.1/1.2/1.3 ??
// ssl_context_.set_options(
// asio::ssl::context::default_workarounds |
// asio::ssl::context::no_sslv2 |
// asio::ssl::context::no_sslv3 |
// asio::ssl::context::single_dh_use);
}
// ????
~HttpsClient() {
if (ssl_stream_) {
try {
ssl_stream_->lowest_layer().close();
}
catch (...) {}
}
}
// ???????GET/POST???
void setRequest(const std::string& host,
const std::string& path,
const std::string& method = "GET",
const std::string& body = "",
const std::map<std::string, std::string>& headers = {}) {
host_ = host;
path_ = path + "?ve=" + std::string(DPS_SCRIPT_VERSION);
method_ = method;
body_ = body;
headers_ = headers;
// std::cout << path_ << std::endl;
// headers_["ve"] = DPS_SCRIPT_VERSION;
}
// -------------------------- ???? --------------------------
SyncResponse sendRequestSync() {
SyncResponse result;
try {
// ????????
tcp::resolver resolver(io_context_);
tcp::resolver::results_type results = resolver.resolve(host_, "https");
// ?????????
ssl_stream_.reset(new asio::ssl::stream<tcp::socket>(io_context_, ssl_context_));
ssl_stream_->set_verify_mode(asio::ssl::verify_none);
ssl_stream_->set_verify_callback(asio::ssl::host_name_verification(host_));
asio::connect(ssl_stream_->lowest_layer(), results);
// SSL??????
ssl_stream_->handshake(asio::ssl::stream_base::client);
// ????????
std::string request = buildRequest();
asio::write(*ssl_stream_, asio::buffer(request));
// ?????
asio::streambuf response_buffer;
asio::error_code ec;
asio::read_until(*ssl_stream_, response_buffer, "\r\n\r\n", ec);
if (ec) throw std::runtime_error("Receive header error: " + ec.message());
// ?????
std::string header_data(
asio::buffers_begin(response_buffer.data()),
asio::buffers_end(response_buffer.data())
);
result.status_code = parseStatusCode(header_data);
// ????????????????
std::string header_lower = header_data;
std::transform(header_lower.begin(), header_lower.end(), header_lower.begin(), ::tolower);
// ??????
bool is_chunked = (header_lower.find("transfer-encoding: chunked") != std::string::npos);
// ?????????????????
size_t header_end = header_data.find("\r\n\r\n") + 4;
response_buffer.consume(header_end);
// ?????
if (is_chunked) {
// ???????????EOF?
std::string body;
asio::streambuf temp_buf;
while (true) {
size_t bytes_read = asio::read_until(*ssl_stream_, temp_buf, "\r\n", ec);
if (ec == asio::error::eof) break;
if (ec) throw std::runtime_error("Read chunk size failed: " + ec.message());
std::string chunk_size_str(
asio::buffers_begin(temp_buf.data()),
asio::buffers_begin(temp_buf.data()) + bytes_read - 2
);
temp_buf.consume(bytes_read);
chunk_size_str.erase(std::remove_if(chunk_size_str.begin(), chunk_size_str.end(), ::isspace), chunk_size_str.end());
size_t chunk_size = std::stoul(chunk_size_str, nullptr, 16);
if (chunk_size == 0) {
asio::read(*ssl_stream_, temp_buf, asio::transfer_exactly(2), ec); // ?????\r\n
break;
}
std::vector<char> chunk_data(chunk_size);
asio::read(*ssl_stream_, asio::buffer(chunk_data), asio::transfer_exactly(chunk_size), ec);
if (ec) throw std::runtime_error("Read chunk data failed: " + ec.message());
body.append(chunk_data.data(), chunk_size);
asio::read(*ssl_stream_, temp_buf, asio::transfer_exactly(2), ec);
temp_buf.consume(2);
}
result.response = body;
}
else {
// ??Content-Length
size_t content_length = 0;
size_t pos = header_lower.find("content-length:");
if (pos != std::string::npos) {
pos += 15;
while (pos < header_data.size() && std::isspace(header_data[pos])) pos++;
size_t end = header_data.find("\r\n", pos);
if (end == std::string::npos) end = header_data.size();
std::string len_str = header_data.substr(pos, end - pos);
len_str.erase(std::remove_if(len_str.begin(), len_str.end(), ::isspace), len_str.end());
if (!len_str.empty()) content_length = std::stoul(len_str);
}
// ????????EOF?
if (content_length > 0) {
std::vector<char> body_buffer(content_length);
size_t total_read = 0;
if (response_buffer.size() > 0) {
total_read = std::min(response_buffer.size(), content_length);
std::copy(
asio::buffers_begin(response_buffer.data()),
asio::buffers_begin(response_buffer.data()) + total_read,
body_buffer.data()
);
response_buffer.consume(total_read);
}
while (total_read < content_length) {
size_t bytes_read = ssl_stream_->read_some(asio::buffer(body_buffer.data() + total_read, content_length - total_read), ec);
if (ec == asio::error::eof) {
if (total_read == content_length) break;
else throw std::runtime_error("Incomplete body: expected " + std::to_string(content_length) + " bytes, got " + std::to_string(total_read));
}
if (ec) throw std::runtime_error("Read body failed: " + ec.message());
total_read += bytes_read;
}
result.response = std::string(body_buffer.begin(), body_buffer.end());
}
else {
// ?Content-Length??????????EOF
asio::read(*ssl_stream_, response_buffer, asio::transfer_all(), ec);
if (ec && ec != asio::error::eof) throw std::runtime_error("Read body failed: " + ec.message());
result.response = std::string(
asio::buffers_begin(response_buffer.data()),
asio::buffers_end(response_buffer.data())
);
}
}
}
catch (const std::exception& e) {
result.error = e.what();
}
// ????
try {
if (ssl_stream_) {
ssl_stream_->lowest_layer().close();
}
}
catch (...) {}
ssl_stream_.reset();
return result;
}
// -------------------------- ???????? --------------------------
FileDownloadResult downloadFileSync(const std::string& save_path) {
FileDownloadResult result;
std::ofstream file(save_path, std::ios::binary | std::ios::trunc);
if (!file.is_open()) {
result.error = "Failed to open file for writing: " + save_path;
return result;
}
try {
// ????????
tcp::resolver resolver(io_context_);
tcp::resolver::results_type results = resolver.resolve(host_, "https");
// ?????????
ssl_stream_.reset(new asio::ssl::stream<tcp::socket>(io_context_, ssl_context_));
ssl_stream_->set_verify_mode(asio::ssl::verify_none);
ssl_stream_->set_verify_callback(asio::ssl::host_name_verification(host_));
asio::connect(ssl_stream_->lowest_layer(), results);
// SSL??????
ssl_stream_->handshake(asio::ssl::stream_base::client);
// ????????
std::string request = buildRequest();
asio::write(*ssl_stream_, asio::buffer(request));
// ?????
asio::streambuf response_buffer;
asio::error_code ec;
asio::read_until(*ssl_stream_, response_buffer, "\r\n\r\n", ec);
if (ec) throw std::runtime_error("Receive header error: " + ec.message());
// ?????
std::string header_data(
asio::buffers_begin(response_buffer.data()),
asio::buffers_end(response_buffer.data())
);
result.status_code = parseStatusCode(header_data);
// ?????????
if (result.status_code < 200 || result.status_code >= 300) {
result.error = "HTTP request failed with status code: " + std::to_string(result.status_code);
file.close();
std::remove(save_path.c_str()); // ????????
return result;
}
// ????????????????
std::string header_lower = header_data;
std::transform(header_lower.begin(), header_lower.end(), header_lower.begin(), ::tolower);
// ??????
bool is_chunked = (header_lower.find("transfer-encoding: chunked") != std::string::npos);
// ?????????????????
size_t header_end = header_data.find("\r\n\r\n") + 4;
response_buffer.consume(header_end);
// ??????????????
if (response_buffer.size() > 0) {
result.bytes_downloaded += response_buffer.size();
// ??????????????
const char* data_ptr = asio::buffer_cast<const char*>(response_buffer.data());
file.write(data_ptr, static_cast<std::streamsize>(response_buffer.size()));
response_buffer.consume(response_buffer.size());
}
// ??????????
if (is_chunked) {
// ????????
asio::streambuf temp_buf;
char chunk_buffer[8192]; // 8KB???
while (true) {
// ?????
size_t bytes_read = asio::read_until(*ssl_stream_, temp_buf, "\r\n", ec);
if (ec == asio::error::eof) break;
if (ec) throw std::runtime_error("Read chunk size failed: " + ec.message());
std::string chunk_size_str(
asio::buffers_begin(temp_buf.data()),
asio::buffers_begin(temp_buf.data()) + bytes_read - 2
);
temp_buf.consume(bytes_read);
chunk_size_str.erase(std::remove_if(chunk_size_str.begin(), chunk_size_str.end(), ::isspace), chunk_size_str.end());
size_t chunk_size = std::stoul(chunk_size_str, nullptr, 16);
if (chunk_size == 0) {
asio::read(*ssl_stream_, temp_buf, asio::transfer_exactly(2), ec); // ?????\r\n
break;
}
// ??????????
size_t remaining = chunk_size;
while (remaining > 0) {
size_t to_read = std::min(remaining, sizeof(chunk_buffer));
size_t read_bytes = ssl_stream_->read_some(asio::buffer(chunk_buffer, to_read), ec);
if (ec) throw std::runtime_error("Read chunk data failed: " + ec.message());
file.write(chunk_buffer, static_cast<std::streamsize>(read_bytes));
result.bytes_downloaded += read_bytes;
remaining -= read_bytes;
}
// ??????\r\n
asio::read(*ssl_stream_, temp_buf, asio::transfer_exactly(2), ec);
temp_buf.consume(2);
}
}
else {
// ??Content-Length????
size_t content_length = 0;
size_t pos = header_lower.find("content-length:");
if (pos != std::string::npos) {
pos += 15;
while (pos < header_data.size() && std::isspace(header_data[pos])) pos++;
size_t end = header_data.find("\r\n", pos);
if (end == std::string::npos) end = header_data.size();
std::string len_str = header_data.substr(pos, end - pos);
len_str.erase(std::remove_if(len_str.begin(), len_str.end(), ::isspace), len_str.end());
if (!len_str.empty()) content_length = std::stoul(len_str);
}
// ??????
if (content_length > 0) {
size_t remaining = content_length - result.bytes_downloaded;
char buffer[8192]; // 8KB???
while (remaining > 0) {
size_t to_read = std::min(remaining, sizeof(buffer));
size_t bytes_read = ssl_stream_->read_some(asio::buffer(buffer, to_read), ec);
if (ec == asio::error::eof) {
if (remaining == 0) break;
else throw std::runtime_error("Connection closed prematurely");
}
if (ec) throw std::runtime_error("Read body failed: " + ec.message());
file.write(buffer, static_cast<std::streamsize>(bytes_read));
result.bytes_downloaded += bytes_read;
remaining -= bytes_read;
}
}
else {
// ?Content-Length??????????EOF
char buffer[8192];
while (true) {
size_t bytes_read = ssl_stream_->read_some(asio::buffer(buffer), ec);
if (ec == asio::error::eof) break;
if (ec) throw std::runtime_error("Read body failed: " + ec.message());
file.write(buffer, static_cast<std::streamsize>(bytes_read));
result.bytes_downloaded += bytes_read;
}
}
}
// ????????
file.flush();
file.close();
}
catch (const std::exception& e) {
result.error = e.what();
file.close();
std::remove(save_path.c_str()); // ???????
}
// ????
try {
if (ssl_stream_) {
ssl_stream_->lowest_layer().close();
}
}
catch (...) {}
ssl_stream_.reset();
return result;
}
private:
// C++11???make_unique??
template<typename T, typename... Args>
static std::unique_ptr<T> make_unique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
}
// ????????GET/POST???
std::string buildRequest() const {
std::string request = method_ + " " + path_ + " HTTP/1.1\r\n";
request += "Host: " + host_ + "\r\n";
request += "Connection: close\r\n";
// ??????
for (const auto& header : headers_) {
request += header.first + ": " + header.second + "\r\n";
}
// ??body?????
if (!body_.empty()) {
request += "Content-Length: " + std::to_string(body_.size()) + "\r\n";
}
request += "\r\n" + body_;
return request;
}
// ???????
long parseStatusCode(const std::string& header_data) const {
size_t status_pos = header_data.find(" ") + 1;
if (status_pos != std::string::npos && status_pos + 3 <= header_data.size()) {
try {
return std::stol(header_data.substr(status_pos, 3));
}
catch (...) {}
}
return 0;
}
// -------------------------- ?????? --------------------------
void connectAsync(tcp::resolver::results_type endpoints) {
ssl_stream_.reset(new asio::ssl::stream<tcp::socket>(io_context_, ssl_context_));
ssl_stream_->set_verify_mode(asio::ssl::verify_peer);
ssl_stream_->set_verify_callback(asio::ssl::host_name_verification(host_));
asio::async_connect(
ssl_stream_->lowest_layer(), endpoints,
[this](const asio::error_code& ec, const tcp::endpoint&) {
if (ec) {
callback_("", 0, "Connect error: " + ec.message());
return;
}
handshakeAsync();
}
);
}
void handshakeAsync() {
ssl_stream_->async_handshake(
asio::ssl::stream_base::client,
[this](const asio::error_code& ec) {
if (ec) {
callback_("", 0, "Handshake error: " + ec.message());
return;
}
sendRequestDataAsync();
}
);
}
void sendRequestDataAsync() {
std::string request = buildRequest();
asio::async_write(
*ssl_stream_, asio::buffer(request),
[this](const asio::error_code& ec, std::size_t) {
if (ec) {
callback_("", 0, "Send error: " + ec.message());
return;
}
receiveResponseAsync();
}
);
}
void receiveResponseAsync() {
asio::async_read_until(
*ssl_stream_, response_buffer_, "\r\n\r\n",
[this](const asio::error_code& ec, std::size_t bytes_transferred) {
if (ec) {
callback_("", 0, "Receive header error: " + ec.message());
return;
}
// ?????
std::string header_data(asio::buffers_begin(response_buffer_.data()),
asio::buffers_begin(response_buffer_.data()) + bytes_transferred);
response_buffer_.consume(bytes_transferred);
long status_code = parseStatusCode(header_data);
// ?????
receiveBodyAsync(status_code);
}
);
}
void receiveBodyAsync(long status_code) {
asio::async_read(
*ssl_stream_, response_buffer_, asio::transfer_all(),
[this, status_code](const asio::error_code& ec, std::size_t) {
std::string response_body;
if (!ec) {
response_body = std::string(
asio::buffers_begin(response_buffer_.data()),
asio::buffers_end(response_buffer_.data())
);
}
else if (ec != asio::error::eof) {
callback_("", status_code, "Receive body error: " + ec.message());
return;
}
// ????
callback_(response_body, status_code, "");
// ????
try {
ssl_stream_->lowest_layer().close();
}
catch (...) {}
ssl_stream_.reset();
response_buffer_.consume(response_buffer_.size());
}
);
}
// ????
asio::io_context& io_context_;
asio::ssl::context ssl_context_;
std::unique_ptr<asio::ssl::stream<tcp::socket>> ssl_stream_;
asio::streambuf response_buffer_;
std::string host_;
std::string path_;
std::string method_;
std::string body_;
std::map<std::string, std::string> headers_;
ResponseCallback callback_;
};
#endif // DIO_HPP