Line data Source code
1 : #include "../../util/logger.hpp"
2 : #include "../../util/hex.hpp"
3 : #include "../../util/types.hpp"
4 : #include "websocket_connection.hpp"
5 : #include <utility>
6 : #include "../util/utf8.hpp"
7 :
8 : namespace thinger::http{
9 :
10 : std::atomic<unsigned long> websocket_connection::connections(0);
11 :
12 38 : websocket_connection::websocket_connection(std::shared_ptr<asio::websocket> socket) :
13 114 : ws_(std::move(socket))
14 : {
15 38 : connections++;
16 38 : LOG_LEVEL(2, "websocket connection created. current: {}", (unsigned long) connections);
17 :
18 38 : }
19 :
20 38 : websocket_connection::~websocket_connection()
21 : {
22 38 : connections--;
23 38 : LOG_LEVEL(1, "releasing websocket connection. current: {}", (unsigned long) connections);
24 38 : }
25 :
26 38 : void websocket_connection::on_message(std::function<void(std::string, bool binary)> callback){
27 38 : on_frame_callback_ = std::move(callback);
28 38 : }
29 :
30 38 : void websocket_connection::start_read_loop()
31 : {
32 38 : co_spawn(ws_->get_io_context(),
33 114 : [this, self = shared_from_this()]() -> awaitable<void> {
34 : // RAII guard: clears the on_message callback when the coroutine
35 : // frame is destroyed (normal exit, exception, or io_context
36 : // teardown). This breaks shared_ptr cycles that occur when the
37 : // user captures shared_ptr<websocket_connection> in the callback.
38 : struct cycle_guard {
39 : std::function<void(std::string, bool)>& ref;
40 38 : ~cycle_guard() { ref = nullptr; }
41 : } guard{on_frame_callback_};
42 :
43 : co_await read_loop();
44 76 : },
45 : detached);
46 38 : }
47 :
48 38 : awaitable<void> websocket_connection::read_loop()
49 : {
50 : size_t next_read_size = DEFAULT_BUFFER_SIZE;
51 :
52 : while (ws_->is_open()) {
53 : LOG_LEVEL(2, "waiting websocket data");
54 :
55 : auto buf = buffer_.prepare(next_read_size);
56 : auto [ec, bytes_transferred] = co_await ws_->read_some(
57 : static_cast<uint8_t*>(buf.data()), buf.size());
58 :
59 : if (ec) {
60 : break;
61 : }
62 :
63 : LOG_LEVEL(2, "socket read: {} bytes", bytes_transferred);
64 :
65 : buffer_.commit(bytes_transferred);
66 :
67 : // get remaining data in the frame
68 : auto remaining = ws_->remaining_in_frame();
69 :
70 : // no pending data to read from the frame
71 : if(remaining == 0){
72 :
73 : // is the message complete ? FIN flag is set
74 : if(ws_->is_message_complete()){
75 :
76 : auto readable = buffer_.data();
77 : auto* data_ptr = static_cast<const uint8_t*>(readable.data());
78 : auto data_size = readable.size();
79 :
80 : // check if the message is a valid UTF8 message
81 : if(!ws_->is_binary()){
82 : if(!utf8::is_valid(data_ptr, data_size)){
83 : LOG_ERROR("invalid UTF8 message received!");
84 : co_return;
85 : }
86 : }
87 :
88 : if (on_frame_callback_) {
89 : std::string data(reinterpret_cast<const char*>(data_ptr), data_size);
90 : LOG_DEBUG("decoded payload: '{}'", util::lowercase_hex_encode(data));
91 : on_frame_callback_(std::move(data), ws_->is_binary());
92 : }
93 :
94 : // clear processed buffer
95 : buffer_.consume(buffer_.size());
96 : }
97 :
98 : next_read_size = DEFAULT_BUFFER_SIZE;
99 :
100 : }else{
101 : // check if the buffer is not going to overflow
102 : if(buffer_.size() + remaining > buffer_.max_size()){
103 : LOG_ERROR("websocket buffer overflow. closing connection");
104 : co_return;
105 : }
106 :
107 : // next iteration will prepare enough space for remaining frame data
108 : next_read_size = remaining;
109 : }
110 : }
111 76 : }
112 :
113 36 : void websocket_connection::process_out_queue()
114 : {
115 36 : if(out_queue_.empty() || writing_) return;
116 36 : writing_ = true;
117 :
118 36 : co_spawn(ws_->get_io_context(),
119 108 : [this, self = shared_from_this()]() -> awaitable<void> {
120 : while(!out_queue_.empty() && ws_->is_open()) {
121 : LOG_LEVEL(2, "handling websocket write, remaining in queue: {}", out_queue_.size());
122 : auto& data = out_queue_.front();
123 : ws_->set_binary(data.second);
124 :
125 : auto [write_ec, write_bytes] = co_await ws_->write(std::string_view(data.first));
126 :
127 : if (write_ec) break;
128 :
129 : LOG_DEBUG("message sent, remaining in queue: {}", out_queue_.size());
130 : out_queue_.pop();
131 : }
132 : writing_ = false;
133 72 : },
134 : detached);
135 : }
136 :
137 0 : std::shared_ptr<asio::socket> websocket_connection::release_socket(){
138 : // cancel pending async i/io requests on this socket
139 0 : ws_->cancel();
140 :
141 : // return socket
142 0 : return ws_;
143 : }
144 :
145 38 : void websocket_connection::start(){
146 : // handle timeout on websocket
147 38 : ws_->start_timeout();
148 :
149 : // initiates async reading on socket
150 38 : start_read_loop();
151 38 : }
152 :
153 0 : void websocket_connection::stop(){
154 0 : execute([this, self = shared_from_this()]{
155 0 : LOG_LEVEL(1, "closing websocket");
156 0 : co_spawn(ws_->get_io_context(),
157 0 : [this, self]() -> awaitable<void> {
158 : co_await ws_->close_graceful();
159 : LOG_LEVEL(1, "websocket closed");
160 0 : },
161 : detached);
162 0 : });
163 0 : }
164 :
165 36 : bool websocket_connection::congested_connection(){
166 36 : return out_queue_.size()>=MAX_OUTPUT_MESSAGES;
167 : }
168 :
169 4 : void websocket_connection::send_binary(std::string data){
170 4 : execute([this, data = std::move(data)]() mutable {
171 : // stop pushing more messages if the connection is congested
172 4 : if(congested_connection()){
173 0 : LOG_WARNING("websocket is congested. discarding packets!");
174 0 : return;
175 : }
176 :
177 4 : LOG_LEVEL(2, "adding frame to websocket queue");
178 4 : out_queue_.emplace(std::move(data), true);
179 4 : process_out_queue();
180 : });
181 4 : }
182 :
183 32 : void websocket_connection::send_text(std::string text){
184 32 : execute([this, data = std::move(text)]() mutable {
185 : // stop pushing more messages if the connection is congested
186 32 : if(congested_connection()){
187 0 : LOG_WARNING("websocket is congested. discarding packets!");
188 0 : return;
189 : }
190 :
191 32 : LOG_LEVEL(2, "adding frame to websocket queue");
192 32 : out_queue_.emplace(std::move(data), false);
193 32 : process_out_queue();
194 : });
195 32 : }
196 :
197 : }
|