Line data Source code
1 : #include "socket_pipe.hpp"
2 : #include "../util/logger.hpp"
3 :
4 : #include <vector>
5 :
6 : namespace thinger::asio {
7 :
8 18 : socket_pipe::socket_pipe(std::shared_ptr<socket> source, std::shared_ptr<socket> target)
9 18 : : source_(std::move(source)), target_(std::move(target)) {
10 18 : }
11 :
12 18 : socket_pipe::~socket_pipe() {
13 : try {
14 18 : if (on_end_) on_end_();
15 0 : } catch (...) {}
16 18 : }
17 :
18 15 : awaitable<void> socket_pipe::run() {
19 : auto self = shared_from_this();
20 : co_await (
21 : forward(source_, target_, bytes_s2t_) ||
22 : forward(target_, source_, bytes_t2s_)
23 : );
24 : cancel();
25 30 : }
26 :
27 0 : void socket_pipe::start() {
28 0 : auto self = shared_from_this();
29 0 : co_spawn(source_->get_io_context(), [self]() -> awaitable<void> {
30 : co_await self->run();
31 0 : }, detached);
32 0 : }
33 :
34 57 : void socket_pipe::cancel() {
35 57 : if (cancelled_.exchange(true)) return;
36 15 : source_->close();
37 15 : target_->close();
38 : }
39 :
40 6 : void socket_pipe::set_on_end(std::function<void()> listener) {
41 6 : on_end_ = std::move(listener);
42 6 : }
43 :
44 3 : size_t socket_pipe::bytes_source_to_target() const {
45 6 : return bytes_s2t_.load();
46 : }
47 :
48 3 : size_t socket_pipe::bytes_target_to_source() const {
49 6 : return bytes_t2s_.load();
50 : }
51 :
52 3 : std::shared_ptr<socket> socket_pipe::get_source() const {
53 3 : return source_;
54 : }
55 :
56 3 : std::shared_ptr<socket> socket_pipe::get_target() const {
57 3 : return target_;
58 : }
59 :
60 30 : awaitable<void> socket_pipe::forward(
61 : std::shared_ptr<socket> from,
62 : std::shared_ptr<socket> to,
63 : std::atomic<size_t>& bytes_transferred)
64 : {
65 : try {
66 : std::vector<uint8_t> buffer(BUFFER_SIZE);
67 : while (!cancelled_) {
68 : size_t n = co_await from->read_some(buffer.data(), BUFFER_SIZE);
69 : if (n == 0) break;
70 : co_await to->write(buffer.data(), n);
71 : bytes_transferred.fetch_add(n, std::memory_order_relaxed);
72 : }
73 : } catch (const boost::system::system_error& e) {
74 : if (e.code() != boost::asio::error::eof &&
75 : e.code() != boost::asio::error::operation_aborted) {
76 : LOG_WARNING("socket_pipe forward error: {}", e.what());
77 : }
78 : } catch (const std::exception& e) {
79 : LOG_WARNING("socket_pipe forward error: {}", e.what());
80 : }
81 : // Close both sockets to interrupt the other direction
82 : cancel();
83 60 : }
84 :
85 : } // namespace thinger::asio
|