Line data Source code
1 : #include "websocket.hpp"
2 : #include "../../util/logger.hpp"
3 : #include "tcp_socket.hpp"
4 :
5 : #include <random>
6 :
7 : namespace thinger::asio {
8 :
9 : using random_bytes_engine = std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
10 :
11 : std::atomic<unsigned long> websocket::connections{0};
12 :
13 76 : websocket::websocket(std::shared_ptr<socket> sock, bool binary, bool server)
14 : : socket("websocket", sock->get_io_context())
15 76 : , socket_(sock)
16 76 : , timer_(sock->get_io_context())
17 76 : , binary_(binary)
18 304 : , server_role_(server) {
19 76 : ++connections;
20 76 : LOG_DEBUG("websocket created");
21 76 : }
22 :
23 76 : websocket::~websocket() {
24 76 : --connections;
25 76 : LOG_DEBUG("releasing websocket");
26 76 : timer_.cancel();
27 76 : }
28 :
29 38 : void websocket::unmask(uint8_t buffer[], std::size_t size) {
30 38 : LOG_DEBUG("unmasking payload. size: {}", size);
31 2498 : for (size_t i = 0; i < size; ++i) {
32 2460 : buffer[i] ^= mask_[i % MASK_SIZE_BYTES];
33 : }
34 38 : }
35 :
36 38 : void websocket::start_timeout() {
37 38 : timer_.expires_after(CONNECTION_TIMEOUT_SECONDS);
38 38 : timer_.async_wait([this](const boost::system::error_code& e) {
39 38 : if (e) {
40 38 : if (e != boost::asio::error::operation_aborted) {
41 0 : LOG_ERROR("error on timeout: {}", e.message());
42 : }
43 38 : return;
44 : }
45 0 : if (data_received_) {
46 0 : data_received_ = false;
47 0 : return start_timeout();
48 : } else {
49 0 : if (!pending_ping_) {
50 0 : pending_ping_ = true;
51 0 : co_spawn(io_context_, [this]() -> awaitable<void> {
52 : co_await send_ping();
53 : if (socket_->is_open()) {
54 : start_timeout();
55 : }
56 0 : }, detached);
57 : } else {
58 0 : LOG_DEBUG("websocket ping timeout... closing connection!");
59 0 : close();
60 : }
61 : }
62 : });
63 38 : }
64 :
65 72 : awaitable<void> websocket::send_close(uint8_t buffer[], size_t size) {
66 : LOG_DEBUG("sending close frame");
67 : co_await send_message(0x88, buffer, size);
68 : close_sent_ = true;
69 : LOG_DEBUG("close frame sent");
70 :
71 : if (!close_received_) {
72 : // Race: read the close ack vs timeout
73 : timer_.cancel();
74 : timer_.expires_after(std::chrono::seconds{5});
75 :
76 38 : auto read_close_ack = [this]() -> awaitable<void> {
77 : uint8_t buf[125];
78 : boost::system::error_code ec;
79 : while (!close_received_ && socket_->is_open()) {
80 : co_await read_frame(buf, sizeof(buf), ec);
81 : if (ec) break;
82 : }
83 76 : };
84 :
85 38 : auto wait_timeout = [this]() -> awaitable<void> {
86 : auto [ec] = co_await timer_.async_wait(use_nothrow_awaitable);
87 76 : };
88 :
89 : co_await (read_close_ack() || wait_timeout());
90 :
91 : if (!close_received_) {
92 : LOG_WARNING("timeout while waiting close acknowledgement");
93 : }
94 : }
95 : close();
96 144 : }
97 :
98 0 : awaitable<void> websocket::send_ping(uint8_t buffer[], size_t size) {
99 : LOG_DEBUG("sending ping frame");
100 : co_await send_message(0x09, buffer, size);
101 : LOG_DEBUG("ping frame sent");
102 0 : }
103 :
104 0 : awaitable<void> websocket::send_pong(uint8_t buffer[], size_t size) {
105 : LOG_DEBUG("sending pong frame");
106 : co_await send_message(0x0A, buffer, size);
107 : LOG_DEBUG("pong frame sent");
108 0 : }
109 :
110 150 : awaitable<size_t> websocket::read_frame(uint8_t buffer[], size_t max_size, boost::system::error_code& ec) {
111 : // If there's remaining data in current frame, read it
112 : if (frame_remaining_ > 0) {
113 : auto read_size = std::min(frame_remaining_, max_size);
114 : auto bytes = co_await socket_->read(buffer, read_size);
115 :
116 : if (bytes == 0) {
117 : ec = boost::asio::error::connection_reset;
118 : co_return 0;
119 : }
120 :
121 : if (masked_) unmask(buffer, bytes);
122 : frame_remaining_ -= bytes;
123 :
124 : co_return bytes;
125 : }
126 :
127 : // Read frame header (2 bytes minimum)
128 : if (co_await socket_->read(buffer_, 2) != 2) {
129 : ec = boost::asio::error::connection_reset;
130 : co_return 0;
131 : }
132 : data_received_ = true;
133 :
134 : uint8_t data_type = buffer_[0];
135 : uint8_t fin = data_type & 0b10000000;
136 : uint8_t rsv = data_type & 0b01110000;
137 :
138 : if (rsv) {
139 : LOG_ERROR("invalid RSV parameters");
140 : ec = boost::asio::error::invalid_argument;
141 : co_return 0;
142 : }
143 :
144 : uint8_t opcode = data_type & 0x0F;
145 : uint8_t data_size = buffer_[1] & ~(1 << 7);
146 : uint8_t masked = buffer_[1] & 0b10000000;
147 :
148 : LOG_DEBUG("decoded frame header. fin: {}, opcode: 0x{:02X} mask: {} data_size: {}", fin, opcode, masked, data_size);
149 :
150 : if (!masked && server_role_) {
151 : LOG_ERROR("client is not masking the information");
152 : ec = boost::asio::error::invalid_argument;
153 : co_return 0;
154 : }
155 :
156 : masked_ = masked;
157 : fin_ = fin;
158 : opcode_ = opcode;
159 :
160 : // Handle opcodes
161 : if (new_message_) {
162 : switch (opcode) {
163 : case 0x0:
164 : LOG_ERROR("received continuation message as the first message!");
165 : ec = boost::asio::error::invalid_argument;
166 : co_return 0;
167 : case 0x1: // text
168 : case 0x2: // binary
169 : message_opcode_ = opcode;
170 : break;
171 : case 0x8: // close
172 : case 0x9: // ping
173 : case 0xA: // pong
174 : if (!fin) {
175 : LOG_ERROR("control frame messages cannot be fragmented");
176 : ec = boost::asio::error::invalid_argument;
177 : co_return 0;
178 : }
179 : break;
180 : default:
181 : LOG_ERROR("received unknown websocket opcode: {}", (int)opcode);
182 : ec = boost::asio::error::invalid_argument;
183 : co_return 0;
184 : }
185 : } else {
186 : // Continuation frame expected
187 : if (opcode != 0x0 && opcode < 0x8) {
188 : LOG_ERROR("unexpected fragment type. expecting a continuation frame");
189 : ec = boost::asio::error::invalid_argument;
190 : co_return 0;
191 : }
192 : }
193 :
194 : // Determine payload length
195 : uint64_t payload_size = data_size;
196 : if (data_size == 126) {
197 : if (co_await socket_->read(buffer_, 2) != 2) {
198 : ec = boost::asio::error::connection_reset;
199 : co_return 0;
200 : }
201 : payload_size = (buffer_[0] << 8) | buffer_[1];
202 : } else if (data_size == 127) {
203 : if (co_await socket_->read(buffer_, 8) != 8) {
204 : ec = boost::asio::error::connection_reset;
205 : co_return 0;
206 : }
207 : payload_size = 0;
208 : for (int i = 0; i < 8; ++i) {
209 : payload_size = (payload_size << 8) | buffer_[i];
210 : }
211 : }
212 :
213 : frame_remaining_ = payload_size;
214 :
215 : // Read mask if present
216 : if (masked_) {
217 : if (co_await socket_->read(mask_, MASK_SIZE_BYTES) != MASK_SIZE_BYTES) {
218 : ec = boost::asio::error::connection_reset;
219 : co_return 0;
220 : }
221 : }
222 :
223 : // Handle control frames
224 : if (opcode >= 0x8) {
225 : // Read control frame payload
226 : uint8_t control_buffer[125];
227 : size_t control_size = std::min(static_cast<size_t>(payload_size), size_t(125));
228 : if (control_size > 0) {
229 : if (co_await socket_->read(control_buffer, control_size) != control_size) {
230 : ec = boost::asio::error::connection_reset;
231 : co_return 0;
232 : }
233 : if (masked_) unmask(control_buffer, control_size);
234 : }
235 : frame_remaining_ = 0;
236 :
237 : switch (opcode) {
238 : case 0x8: // close
239 : LOG_DEBUG("received close frame");
240 : close_received_ = true;
241 : if (!close_sent_) {
242 : co_await send_close();
243 : }
244 : ec = boost::asio::error::connection_aborted;
245 : co_return 0;
246 : case 0x9: // ping
247 : LOG_DEBUG("received ping frame");
248 : co_await send_pong(control_buffer, control_size);
249 : co_return co_await read_frame(buffer, max_size, ec);
250 : case 0xA: // pong
251 : LOG_DEBUG("received pong frame");
252 : pending_ping_ = false;
253 : data_received_ = false;
254 : co_return co_await read_frame(buffer, max_size, ec);
255 : }
256 : }
257 :
258 : // Read data frame payload
259 : if (frame_remaining_ == 0) {
260 : if (fin_) new_message_ = true;
261 : co_return 0;
262 : }
263 :
264 : auto read_size = std::min(frame_remaining_, max_size);
265 : auto bytes = co_await socket_->read(buffer, read_size);
266 :
267 : if (bytes == 0) {
268 : ec = boost::asio::error::connection_reset;
269 : co_return 0;
270 : }
271 :
272 : if (masked_) unmask(buffer, bytes);
273 : frame_remaining_ -= bytes;
274 :
275 : if (frame_remaining_ == 0 && fin_) {
276 : new_message_ = true;
277 : } else if (frame_remaining_ == 0) {
278 : new_message_ = false;
279 : }
280 :
281 : co_return bytes;
282 300 : }
283 :
284 146 : awaitable<size_t> websocket::send_message(uint8_t opcode, const uint8_t buffer[], size_t size) {
285 : std::lock_guard<std::mutex> lock(write_mutex_);
286 :
287 : uint8_t header_size = 2;
288 : output_[0] = 0x80 | opcode;
289 :
290 : if (size <= 125) {
291 : output_[1] = size;
292 : } else if (size <= 65535) {
293 : output_[1] = 126;
294 : output_[2] = (size >> 8) & 0xff;
295 : output_[3] = size & 0xff;
296 : header_size += 2;
297 : } else {
298 : output_[1] = 127;
299 : for (int i = 0; i < 8; ++i) {
300 : output_[2 + i] = (size >> ((7 - i) * 8)) & 0xff;
301 : }
302 : header_size += 8;
303 : }
304 :
305 : // Create output buffers
306 : std::vector<boost::asio::const_buffer> output_buffers;
307 : output_buffers.push_back(boost::asio::buffer(output_, header_size));
308 :
309 : // Handle masking for client role
310 : std::vector<uint8_t> masked_data;
311 : if (!server_role_) {
312 : output_[1] |= 0b10000000;
313 :
314 : static random_bytes_engine rbe;
315 : uint8_t mask[MASK_SIZE_BYTES];
316 : for (int i = 0; i < MASK_SIZE_BYTES; ++i) {
317 : mask[i] = rbe();
318 : output_[header_size + i] = mask[i];
319 : }
320 : header_size += MASK_SIZE_BYTES;
321 :
322 : // Mask the data
323 : masked_data.resize(size);
324 : for (size_t i = 0; i < size; ++i) {
325 : masked_data[i] = buffer[i] ^ mask[i % MASK_SIZE_BYTES];
326 : }
327 :
328 : output_buffers.clear();
329 : output_buffers.push_back(boost::asio::buffer(output_, header_size));
330 : output_buffers.push_back(boost::asio::buffer(masked_data));
331 : } else {
332 : output_buffers.push_back(boost::asio::buffer(buffer, size));
333 : }
334 :
335 : LOG_DEBUG("sending websocket data. header: {}, payload: {}", header_size, size);
336 :
337 : auto bytes = co_await socket_->write(output_buffers);
338 : if (bytes == 0) {
339 : co_return 0;
340 : }
341 : co_return bytes - header_size;
342 292 : }
343 :
344 112 : awaitable<size_t> websocket::read_some(uint8_t buffer[], size_t max_size) {
345 : boost::system::error_code ec;
346 : auto bytes = co_await read_frame(buffer, max_size, ec);
347 : if (ec) {
348 : if (ec != boost::asio::error::connection_aborted) {
349 : LOG_ERROR("websocket read_some error: {}", ec.message());
350 : }
351 : co_return 0;
352 : }
353 : co_return bytes;
354 224 : }
355 :
356 0 : awaitable<size_t> websocket::read(uint8_t buffer[], size_t size) {
357 : boost::system::error_code ec;
358 : auto bytes = co_await read_frame(buffer, size, ec);
359 : if (ec) {
360 : if (ec != boost::asio::error::connection_aborted) {
361 : LOG_ERROR("websocket read error: {}", ec.message());
362 : }
363 : co_return 0;
364 : }
365 : co_return bytes;
366 0 : }
367 :
368 0 : awaitable<size_t> websocket::read(boost::asio::streambuf& buffer, size_t size) {
369 : // Not supported for websocket
370 : co_return 0;
371 0 : }
372 :
373 0 : awaitable<size_t> websocket::read_until(boost::asio::streambuf& buffer, std::string_view delim) {
374 : // Not supported for websocket
375 : co_return 0;
376 0 : }
377 :
378 4 : awaitable<size_t> websocket::write(const uint8_t buffer[], size_t size) {
379 : co_return co_await send_message(binary_ ? 0x02 : 0x01, buffer, size);
380 8 : }
381 :
382 70 : awaitable<size_t> websocket::write(std::string_view str) {
383 : co_return co_await send_message(binary_ ? 0x02 : 0x01,
384 : reinterpret_cast<const uint8_t*>(str.data()), str.size());
385 140 : }
386 :
387 0 : awaitable<size_t> websocket::write(const std::vector<boost::asio::const_buffer>& buffers) {
388 : // Not supported for websocket
389 : co_return 0;
390 0 : }
391 :
392 0 : awaitable<boost::system::error_code> websocket::wait(boost::asio::socket_base::wait_type type) {
393 : co_return co_await socket_->wait(type);
394 0 : }
395 :
396 0 : awaitable<boost::system::error_code> websocket::connect(
397 : const std::string& host,
398 : const std::string& port,
399 : std::chrono::seconds timeout) {
400 : co_return co_await socket_->connect(host, port, timeout);
401 0 : }
402 :
403 110 : void websocket::close() {
404 110 : timer_.cancel();
405 110 : socket_->close();
406 110 : }
407 :
408 38 : awaitable<void> websocket::close_graceful() {
409 : if (!close_sent_ && socket_->is_open()) {
410 : co_await send_close();
411 : } else if (socket_->is_open()) {
412 : close();
413 : }
414 76 : }
415 :
416 0 : void websocket::cancel() {
417 0 : socket_->cancel();
418 0 : }
419 :
420 0 : bool websocket::requires_handshake() const {
421 0 : return socket_->requires_handshake();
422 : }
423 :
424 0 : awaitable<boost::system::error_code> websocket::handshake(const std::string& host) {
425 : co_return co_await socket_->handshake(host);
426 0 : }
427 :
428 340 : bool websocket::is_open() const {
429 340 : return socket_->is_open();
430 : }
431 :
432 0 : bool websocket::is_secure() const {
433 0 : return socket_->is_secure();
434 : }
435 :
436 0 : size_t websocket::available() const {
437 0 : return socket_->available();
438 : }
439 :
440 0 : std::string websocket::get_remote_ip() const {
441 0 : return socket_->get_remote_ip();
442 : }
443 :
444 0 : std::string websocket::get_local_port() const {
445 0 : return socket_->get_local_port();
446 : }
447 :
448 0 : std::string websocket::get_remote_port() const {
449 0 : return socket_->get_remote_port();
450 : }
451 :
452 38 : std::size_t websocket::remaining_in_frame() const {
453 38 : return frame_remaining_;
454 : }
455 :
456 36 : bool websocket::is_message_complete() const {
457 36 : return new_message_;
458 : }
459 :
460 : }
|