Line data Source code
1 : #include "../../util/logger.hpp"
2 : #include "../../util/hex.hpp"
3 : #include "../../util/types.hpp"
4 : #include "websocket_connection.hpp"
5 : #include <utility>
6 : #include "../util/utf8.hpp"
7 :
8 : namespace thinger::http{
9 :
10 : std::atomic<unsigned long> websocket_connection::connections(0);
11 :
12 38 : websocket_connection::websocket_connection(std::shared_ptr<asio::websocket> socket) :
13 114 : ws_(std::move(socket))
14 : {
15 38 : connections++;
16 38 : LOG_LEVEL(2, "websocket connection created. current: {}", (unsigned long) connections);
17 :
18 38 : }
19 :
20 38 : websocket_connection::~websocket_connection()
21 : {
22 38 : connections--;
23 38 : LOG_LEVEL(1, "releasing websocket connection. current: {}", (unsigned long) connections);
24 38 : }
25 :
26 38 : void websocket_connection::on_message(std::function<void(std::string, bool binary)> callback){
27 38 : on_frame_callback_ = std::move(callback);
28 38 : }
29 :
30 38 : void websocket_connection::start_read_loop()
31 : {
32 38 : co_spawn(ws_->get_io_context(),
33 114 : [this, self = shared_from_this()]() -> awaitable<void> {
34 : // RAII guard: clears the on_message callback when the coroutine
35 : // frame is destroyed (normal exit, exception, or io_context
36 : // teardown). This breaks shared_ptr cycles that occur when the
37 : // user captures shared_ptr<websocket_connection> in the callback.
38 : struct cycle_guard {
39 : std::function<void(std::string, bool)>& ref;
40 38 : ~cycle_guard() { ref = nullptr; }
41 : } guard{on_frame_callback_};
42 :
43 : co_await read_loop();
44 76 : },
45 : detached);
46 38 : }
47 :
48 38 : awaitable<void> websocket_connection::read_loop()
49 : {
50 : size_t next_read_size = DEFAULT_BUFFER_SIZE;
51 :
52 : while (ws_->is_open()) {
53 : LOG_LEVEL(2, "waiting websocket data");
54 :
55 : auto buf = buffer_.prepare(next_read_size);
56 : auto bytes_transferred = co_await ws_->read_some(
57 : static_cast<uint8_t*>(buf.data()), buf.size());
58 :
59 : // read_some returns 0 on error or close
60 : if (bytes_transferred == 0) {
61 : break;
62 : }
63 :
64 : LOG_LEVEL(2, "socket read: {} bytes", bytes_transferred);
65 :
66 : buffer_.commit(bytes_transferred);
67 :
68 : // get remaining data in the frame
69 : auto remaining = ws_->remaining_in_frame();
70 :
71 : // no pending data to read from the frame
72 : if(remaining == 0){
73 :
74 : // is the message complete ? FIN flag is set
75 : if(ws_->is_message_complete()){
76 :
77 : auto readable = buffer_.data();
78 : auto* data_ptr = static_cast<const uint8_t*>(readable.data());
79 : auto data_size = readable.size();
80 :
81 : // check if the message is a valid UTF8 message
82 : if(!ws_->is_binary()){
83 : if(!utf8::is_valid(data_ptr, data_size)){
84 : LOG_ERROR("invalid UTF8 message received!");
85 : co_return;
86 : }
87 : }
88 :
89 : if (on_frame_callback_) {
90 : std::string data(reinterpret_cast<const char*>(data_ptr), data_size);
91 : LOG_DEBUG("decoded payload: '{}'", util::lowercase_hex_encode(data));
92 : on_frame_callback_(std::move(data), ws_->is_binary());
93 : }
94 :
95 : // clear processed buffer
96 : buffer_.consume(buffer_.size());
97 : }
98 :
99 : next_read_size = DEFAULT_BUFFER_SIZE;
100 :
101 : }else{
102 : // check if the buffer is not going to overflow
103 : if(buffer_.size() + remaining > buffer_.max_size()){
104 : LOG_ERROR("websocket buffer overflow. closing connection");
105 : co_return;
106 : }
107 :
108 : // next iteration will prepare enough space for remaining frame data
109 : next_read_size = remaining;
110 : }
111 : }
112 76 : }
113 :
114 36 : void websocket_connection::process_out_queue()
115 : {
116 36 : if(out_queue_.empty() || writing_) return;
117 36 : writing_ = true;
118 :
119 36 : co_spawn(ws_->get_io_context(),
120 108 : [this, self = shared_from_this()]() -> awaitable<void> {
121 : while(!out_queue_.empty() && ws_->is_open()) {
122 : LOG_LEVEL(2, "handling websocket write, remaining in queue: {}", out_queue_.size());
123 : auto& data = out_queue_.front();
124 : ws_->set_binary(data.second);
125 :
126 : co_await ws_->write(std::string_view(data.first));
127 :
128 : if (!ws_->is_open()) break;
129 :
130 : LOG_DEBUG("message sent, remaining in queue: {}", out_queue_.size());
131 : out_queue_.pop();
132 : }
133 : writing_ = false;
134 72 : },
135 : detached);
136 : }
137 :
138 0 : std::shared_ptr<asio::socket> websocket_connection::release_socket(){
139 : // cancel pending async i/io requests on this socket
140 0 : ws_->cancel();
141 :
142 : // return socket
143 0 : return ws_;
144 : }
145 :
146 38 : void websocket_connection::start(){
147 : // handle timeout on websocket
148 38 : ws_->start_timeout();
149 :
150 : // initiates async reading on socket
151 38 : start_read_loop();
152 38 : }
153 :
154 0 : void websocket_connection::stop(){
155 0 : execute([this, self = shared_from_this()]{
156 0 : LOG_LEVEL(1, "closing websocket");
157 0 : co_spawn(ws_->get_io_context(),
158 0 : [this, self]() -> awaitable<void> {
159 : co_await ws_->close_graceful();
160 : LOG_LEVEL(1, "websocket closed");
161 0 : },
162 : detached);
163 0 : });
164 0 : }
165 :
166 36 : bool websocket_connection::congested_connection(){
167 36 : return out_queue_.size()>=MAX_OUTPUT_MESSAGES;
168 : }
169 :
170 4 : void websocket_connection::send_binary(std::string data){
171 4 : execute([this, data = std::move(data)]() mutable {
172 : // stop pushing more messages if the connection is congested
173 4 : if(congested_connection()){
174 0 : LOG_WARNING("websocket is congested. discarding packets!");
175 0 : return;
176 : }
177 :
178 4 : LOG_LEVEL(2, "adding frame to websocket queue");
179 4 : out_queue_.emplace(std::move(data), true);
180 4 : process_out_queue();
181 : });
182 4 : }
183 :
184 32 : void websocket_connection::send_text(std::string text){
185 32 : execute([this, data = std::move(text)]() mutable {
186 : // stop pushing more messages if the connection is congested
187 32 : if(congested_connection()){
188 0 : LOG_WARNING("websocket is congested. discarding packets!");
189 0 : return;
190 : }
191 :
192 32 : LOG_LEVEL(2, "adding frame to websocket queue");
193 32 : out_queue_.emplace(std::move(data), false);
194 32 : process_out_queue();
195 : });
196 32 : }
197 :
198 : }
|