Line data Source code
1 : #include "websocket.hpp"
2 : #include "../../util/logger.hpp"
3 : #include "tcp_socket.hpp"
4 :
5 : #include <random>
6 :
7 : namespace thinger::asio {
8 :
9 : using random_bytes_engine = std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
10 :
11 : std::atomic<unsigned long> websocket::connections{0};
12 :
13 76 : websocket::websocket(std::shared_ptr<socket> sock, bool binary, bool server)
14 : : socket("websocket", sock->get_io_context())
15 76 : , socket_(sock)
16 76 : , timer_(sock->get_io_context())
17 76 : , binary_(binary)
18 304 : , server_role_(server) {
19 76 : ++connections;
20 76 : LOG_DEBUG("websocket created");
21 76 : }
22 :
23 76 : websocket::~websocket() {
24 76 : --connections;
25 76 : LOG_DEBUG("releasing websocket");
26 76 : timer_.cancel();
27 76 : }
28 :
29 38 : void websocket::unmask(uint8_t buffer[], std::size_t size) {
30 38 : LOG_DEBUG("unmasking payload. size: {}", size);
31 2498 : for (size_t i = 0; i < size; ++i) {
32 2460 : buffer[i] ^= mask_[i % MASK_SIZE_BYTES];
33 : }
34 38 : }
35 :
36 38 : void websocket::start_timeout() {
37 38 : timer_.expires_after(CONNECTION_TIMEOUT_SECONDS);
38 38 : timer_.async_wait([this](const boost::system::error_code& e) {
39 38 : if (e) {
40 38 : if (e != boost::asio::error::operation_aborted) {
41 0 : LOG_ERROR("error on timeout: {}", e.message());
42 : }
43 38 : return;
44 : }
45 0 : if (data_received_) {
46 0 : data_received_ = false;
47 0 : return start_timeout();
48 : } else {
49 0 : if (!pending_ping_) {
50 0 : pending_ping_ = true;
51 0 : co_spawn(io_context_, [this]() -> awaitable<void> {
52 : co_await send_ping();
53 : if (socket_->is_open()) {
54 : start_timeout();
55 : }
56 0 : }, detached);
57 : } else {
58 0 : LOG_DEBUG("websocket ping timeout... closing connection!");
59 0 : close();
60 : }
61 : }
62 : });
63 38 : }
64 :
65 72 : awaitable<void> websocket::send_close(uint8_t buffer[], size_t size) {
66 : LOG_DEBUG("sending close frame");
67 : co_await send_message(0x88, buffer, size);
68 : close_sent_ = true;
69 : LOG_DEBUG("close frame sent");
70 :
71 : if (!close_received_) {
72 : // Race: read the close ack vs timeout
73 : timer_.cancel();
74 : timer_.expires_after(std::chrono::seconds{5});
75 :
76 38 : auto read_close_ack = [this]() -> awaitable<void> {
77 : uint8_t buf[125];
78 : boost::system::error_code ec;
79 : while (!close_received_ && socket_->is_open()) {
80 : co_await read_frame(buf, sizeof(buf), ec);
81 : if (ec) break;
82 : }
83 76 : };
84 :
85 38 : auto wait_timeout = [this]() -> awaitable<void> {
86 : auto [ec] = co_await timer_.async_wait(use_nothrow_awaitable);
87 76 : };
88 :
89 : co_await (read_close_ack() || wait_timeout());
90 :
91 : if (!close_received_) {
92 : LOG_WARNING("timeout while waiting close acknowledgement");
93 : }
94 : }
95 : close();
96 144 : }
97 :
98 0 : awaitable<void> websocket::send_ping(uint8_t buffer[], size_t size) {
99 : LOG_DEBUG("sending ping frame");
100 : co_await send_message(0x09, buffer, size);
101 : LOG_DEBUG("ping frame sent");
102 0 : }
103 :
104 0 : awaitable<void> websocket::send_pong(uint8_t buffer[], size_t size) {
105 : LOG_DEBUG("sending pong frame");
106 : co_await send_message(0x0A, buffer, size);
107 : LOG_DEBUG("pong frame sent");
108 0 : }
109 :
110 150 : awaitable<size_t> websocket::read_frame(uint8_t buffer[], size_t max_size, boost::system::error_code& ec) {
111 : // If there's remaining data in current frame, read it
112 : if (frame_remaining_ > 0) {
113 : auto read_size = std::min(frame_remaining_, max_size);
114 : auto [read_ec, bytes] = co_await socket_->read(buffer, read_size);
115 :
116 : if (read_ec) {
117 : ec = read_ec;
118 : co_return 0;
119 : }
120 :
121 : if (masked_) unmask(buffer, bytes);
122 : frame_remaining_ -= bytes;
123 :
124 : co_return bytes;
125 : }
126 :
127 : // Read frame header (2 bytes minimum)
128 : {
129 : auto [read_ec, bytes] = co_await socket_->read(buffer_, 2);
130 : if (read_ec) {
131 : ec = read_ec;
132 : co_return 0;
133 : }
134 : }
135 : data_received_ = true;
136 :
137 : uint8_t data_type = buffer_[0];
138 : uint8_t fin = data_type & 0b10000000;
139 : uint8_t rsv = data_type & 0b01110000;
140 :
141 : if (rsv) {
142 : LOG_ERROR("invalid RSV parameters");
143 : ec = boost::asio::error::invalid_argument;
144 : co_return 0;
145 : }
146 :
147 : uint8_t opcode = data_type & 0x0F;
148 : uint8_t data_size = buffer_[1] & ~(1 << 7);
149 : uint8_t masked = buffer_[1] & 0b10000000;
150 :
151 : LOG_DEBUG("decoded frame header. fin: {}, opcode: 0x{:02X} mask: {} data_size: {}", fin, opcode, masked, data_size);
152 :
153 : if (!masked && server_role_) {
154 : LOG_ERROR("client is not masking the information");
155 : ec = boost::asio::error::invalid_argument;
156 : co_return 0;
157 : }
158 :
159 : masked_ = masked;
160 : fin_ = fin;
161 : opcode_ = opcode;
162 :
163 : // Handle opcodes
164 : if (new_message_) {
165 : switch (opcode) {
166 : case 0x0:
167 : LOG_ERROR("received continuation message as the first message!");
168 : ec = boost::asio::error::invalid_argument;
169 : co_return 0;
170 : case 0x1: // text
171 : case 0x2: // binary
172 : message_opcode_ = opcode;
173 : break;
174 : case 0x8: // close
175 : case 0x9: // ping
176 : case 0xA: // pong
177 : if (!fin) {
178 : LOG_ERROR("control frame messages cannot be fragmented");
179 : ec = boost::asio::error::invalid_argument;
180 : co_return 0;
181 : }
182 : break;
183 : default:
184 : LOG_ERROR("received unknown websocket opcode: {}", (int)opcode);
185 : ec = boost::asio::error::invalid_argument;
186 : co_return 0;
187 : }
188 : } else {
189 : // Continuation frame expected
190 : if (opcode != 0x0 && opcode < 0x8) {
191 : LOG_ERROR("unexpected fragment type. expecting a continuation frame");
192 : ec = boost::asio::error::invalid_argument;
193 : co_return 0;
194 : }
195 : }
196 :
197 : // Determine payload length
198 : uint64_t payload_size = data_size;
199 : if (data_size == 126) {
200 : auto [read_ec, bytes] = co_await socket_->read(buffer_, 2);
201 : if (read_ec) {
202 : ec = read_ec;
203 : co_return 0;
204 : }
205 : payload_size = (buffer_[0] << 8) | buffer_[1];
206 : } else if (data_size == 127) {
207 : auto [read_ec, bytes] = co_await socket_->read(buffer_, 8);
208 : if (read_ec) {
209 : ec = read_ec;
210 : co_return 0;
211 : }
212 : payload_size = 0;
213 : for (int i = 0; i < 8; ++i) {
214 : payload_size = (payload_size << 8) | buffer_[i];
215 : }
216 : }
217 :
218 : frame_remaining_ = payload_size;
219 :
220 : // Read mask if present
221 : if (masked_) {
222 : auto [read_ec, bytes] = co_await socket_->read(mask_, MASK_SIZE_BYTES);
223 : if (read_ec) {
224 : ec = read_ec;
225 : co_return 0;
226 : }
227 : }
228 :
229 : // Handle control frames
230 : if (opcode >= 0x8) {
231 : // Read control frame payload
232 : uint8_t control_buffer[125];
233 : size_t control_size = std::min(static_cast<size_t>(payload_size), size_t(125));
234 : if (control_size > 0) {
235 : auto [read_ec, bytes] = co_await socket_->read(control_buffer, control_size);
236 : if (read_ec) {
237 : ec = read_ec;
238 : co_return 0;
239 : }
240 : if (masked_) unmask(control_buffer, control_size);
241 : }
242 : frame_remaining_ = 0;
243 :
244 : switch (opcode) {
245 : case 0x8: // close
246 : LOG_DEBUG("received close frame");
247 : close_received_ = true;
248 : if (!close_sent_) {
249 : co_await send_close();
250 : }
251 : ec = boost::asio::error::connection_aborted;
252 : co_return 0;
253 : case 0x9: // ping
254 : LOG_DEBUG("received ping frame");
255 : co_await send_pong(control_buffer, control_size);
256 : co_return co_await read_frame(buffer, max_size, ec);
257 : case 0xA: // pong
258 : LOG_DEBUG("received pong frame");
259 : pending_ping_ = false;
260 : data_received_ = false;
261 : co_return co_await read_frame(buffer, max_size, ec);
262 : }
263 : }
264 :
265 : // Read data frame payload
266 : if (frame_remaining_ == 0) {
267 : if (fin_) new_message_ = true;
268 : co_return 0;
269 : }
270 :
271 : auto read_size = std::min(frame_remaining_, max_size);
272 : auto [read_ec, bytes] = co_await socket_->read(buffer, read_size);
273 :
274 : if (read_ec) {
275 : ec = read_ec;
276 : co_return 0;
277 : }
278 :
279 : if (masked_) unmask(buffer, bytes);
280 : frame_remaining_ -= bytes;
281 :
282 : if (frame_remaining_ == 0 && fin_) {
283 : new_message_ = true;
284 : } else if (frame_remaining_ == 0) {
285 : new_message_ = false;
286 : }
287 :
288 : co_return bytes;
289 300 : }
290 :
291 146 : awaitable<io_result> websocket::send_message(uint8_t opcode, const uint8_t buffer[], size_t size) {
292 : std::lock_guard<std::mutex> lock(write_mutex_);
293 :
294 : uint8_t header_size = 2;
295 : output_[0] = 0x80 | opcode;
296 :
297 : if (size <= 125) {
298 : output_[1] = size;
299 : } else if (size <= 65535) {
300 : output_[1] = 126;
301 : output_[2] = (size >> 8) & 0xff;
302 : output_[3] = size & 0xff;
303 : header_size += 2;
304 : } else {
305 : output_[1] = 127;
306 : for (int i = 0; i < 8; ++i) {
307 : output_[2 + i] = (size >> ((7 - i) * 8)) & 0xff;
308 : }
309 : header_size += 8;
310 : }
311 :
312 : // Create output buffers
313 : std::vector<boost::asio::const_buffer> output_buffers;
314 : output_buffers.push_back(boost::asio::buffer(output_, header_size));
315 :
316 : // Handle masking for client role
317 : std::vector<uint8_t> masked_data;
318 : if (!server_role_) {
319 : output_[1] |= 0b10000000;
320 :
321 : static random_bytes_engine rbe;
322 : uint8_t mask[MASK_SIZE_BYTES];
323 : for (int i = 0; i < MASK_SIZE_BYTES; ++i) {
324 : mask[i] = rbe();
325 : output_[header_size + i] = mask[i];
326 : }
327 : header_size += MASK_SIZE_BYTES;
328 :
329 : // Mask the data
330 : masked_data.resize(size);
331 : for (size_t i = 0; i < size; ++i) {
332 : masked_data[i] = buffer[i] ^ mask[i % MASK_SIZE_BYTES];
333 : }
334 :
335 : output_buffers.clear();
336 : output_buffers.push_back(boost::asio::buffer(output_, header_size));
337 : output_buffers.push_back(boost::asio::buffer(masked_data));
338 : } else {
339 : output_buffers.push_back(boost::asio::buffer(buffer, size));
340 : }
341 :
342 : LOG_DEBUG("sending websocket data. header: {}, payload: {}", header_size, size);
343 :
344 : auto [ec, bytes] = co_await socket_->write(output_buffers);
345 : if (ec) {
346 : co_return io_result{ec, 0};
347 : }
348 : co_return io_result{ec, bytes - header_size};
349 292 : }
350 :
351 112 : awaitable<io_result> websocket::read_some(uint8_t buffer[], size_t max_size) {
352 : boost::system::error_code ec;
353 : auto bytes = co_await read_frame(buffer, max_size, ec);
354 : if (ec) {
355 : if (ec != boost::asio::error::connection_aborted) {
356 : LOG_ERROR("websocket read_some error: {}", ec.message());
357 : }
358 : co_return io_result{ec, 0};
359 : }
360 : co_return io_result{boost::system::error_code{}, bytes};
361 224 : }
362 :
363 0 : awaitable<io_result> websocket::read(uint8_t buffer[], size_t size) {
364 : boost::system::error_code ec;
365 : auto bytes = co_await read_frame(buffer, size, ec);
366 : if (ec) {
367 : if (ec != boost::asio::error::connection_aborted) {
368 : LOG_ERROR("websocket read error: {}", ec.message());
369 : }
370 : co_return io_result{ec, 0};
371 : }
372 : co_return io_result{boost::system::error_code{}, bytes};
373 0 : }
374 :
375 0 : awaitable<io_result> websocket::read(boost::asio::streambuf& buffer, size_t size) {
376 : // Not supported for websocket
377 : co_return io_result{boost::asio::error::operation_not_supported, 0};
378 0 : }
379 :
380 0 : awaitable<io_result> websocket::read_until(boost::asio::streambuf& buffer, std::string_view delim) {
381 : // Not supported for websocket
382 : co_return io_result{boost::asio::error::operation_not_supported, 0};
383 0 : }
384 :
385 4 : awaitable<io_result> websocket::write(const uint8_t buffer[], size_t size) {
386 : co_return co_await send_message(binary_ ? 0x02 : 0x01, buffer, size);
387 8 : }
388 :
389 70 : awaitable<io_result> websocket::write(std::string_view str) {
390 : co_return co_await send_message(binary_ ? 0x02 : 0x01,
391 : reinterpret_cast<const uint8_t*>(str.data()), str.size());
392 140 : }
393 :
394 0 : awaitable<io_result> websocket::write(const std::vector<boost::asio::const_buffer>& buffers) {
395 : // Not supported for websocket
396 : co_return io_result{boost::asio::error::operation_not_supported, 0};
397 0 : }
398 :
399 0 : awaitable<boost::system::error_code> websocket::wait(boost::asio::socket_base::wait_type type) {
400 : co_return co_await socket_->wait(type);
401 0 : }
402 :
403 0 : awaitable<boost::system::error_code> websocket::connect(
404 : const std::string& host,
405 : const std::string& port,
406 : std::chrono::seconds timeout) {
407 : co_return co_await socket_->connect(host, port, timeout);
408 0 : }
409 :
410 110 : void websocket::close() {
411 110 : timer_.cancel();
412 110 : socket_->close();
413 110 : }
414 :
415 38 : awaitable<void> websocket::close_graceful() {
416 : if (!close_sent_ && socket_->is_open()) {
417 : co_await send_close();
418 : } else if (socket_->is_open()) {
419 : close();
420 : }
421 76 : }
422 :
423 0 : void websocket::cancel() {
424 0 : socket_->cancel();
425 0 : }
426 :
427 0 : bool websocket::requires_handshake() const {
428 0 : return socket_->requires_handshake();
429 : }
430 :
431 0 : awaitable<boost::system::error_code> websocket::handshake(const std::string& host) {
432 : co_return co_await socket_->handshake(host);
433 0 : }
434 :
435 266 : bool websocket::is_open() const {
436 266 : return socket_->is_open();
437 : }
438 :
439 0 : bool websocket::is_secure() const {
440 0 : return socket_->is_secure();
441 : }
442 :
443 0 : size_t websocket::available() const {
444 0 : return socket_->available();
445 : }
446 :
447 0 : std::string websocket::get_remote_ip() const {
448 0 : return socket_->get_remote_ip();
449 : }
450 :
451 0 : std::string websocket::get_local_port() const {
452 0 : return socket_->get_local_port();
453 : }
454 :
455 0 : std::string websocket::get_remote_port() const {
456 0 : return socket_->get_remote_port();
457 : }
458 :
459 38 : std::size_t websocket::remaining_in_frame() const {
460 38 : return frame_remaining_;
461 : }
462 :
463 36 : bool websocket::is_message_complete() const {
464 36 : return new_message_;
465 : }
466 :
467 : }
|