LCOV - code coverage report
Current view: top level - asio/sockets - websocket.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 55.2 % 96 53
Test Date: 2026-04-21 17:49:55 Functions: 50.0 % 36 18

            Line data    Source code
#include "websocket.hpp"
#include "../../util/logger.hpp"
#include "tcp_socket.hpp"

#include <algorithm>
#include <climits>
#include <random>
       6              : 
       7              : namespace thinger::asio {
       8              : 
       9              : using random_bytes_engine = std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
      10              : 
      11              : std::atomic<unsigned long> websocket::connections{0};
      12              : 
      13           76 : websocket::websocket(std::shared_ptr<socket> sock, bool binary, bool server)
      14              :     : socket("websocket", sock->get_io_context())
      15           76 :     , socket_(sock)
      16           76 :     , timer_(sock->get_io_context())
      17           76 :     , binary_(binary)
      18          304 :     , server_role_(server) {
      19           76 :     ++connections;
      20           76 :     LOG_DEBUG("websocket created");
      21           76 : }
      22              : 
      23           76 : websocket::~websocket() {
      24           76 :     --connections;
      25           76 :     LOG_DEBUG("releasing websocket");
      26           76 :     timer_.cancel();
      27           76 : }
      28              : 
      29           38 : void websocket::unmask(uint8_t buffer[], std::size_t size) {
      30           38 :     LOG_DEBUG("unmasking payload. size: {}", size);
      31         2498 :     for (size_t i = 0; i < size; ++i) {
      32         2460 :         buffer[i] ^= mask_[i % MASK_SIZE_BYTES];
      33              :     }
      34           38 : }
      35              : 
      36           38 : void websocket::start_timeout() {
      37           38 :     timer_.expires_after(CONNECTION_TIMEOUT_SECONDS);
      38           38 :     timer_.async_wait([this](const boost::system::error_code& e) {
      39           38 :         if (e) {
      40           38 :             if (e != boost::asio::error::operation_aborted) {
      41            0 :                 LOG_ERROR("error on timeout: {}", e.message());
      42              :             }
      43           38 :             return;
      44              :         }
      45            0 :         if (data_received_) {
      46            0 :             data_received_ = false;
      47            0 :             return start_timeout();
      48              :         } else {
      49            0 :             if (!pending_ping_) {
      50            0 :                 pending_ping_ = true;
      51            0 :                 co_spawn(io_context_, [this]() -> awaitable<void> {
      52              :                     co_await send_ping();
      53              :                     if (socket_->is_open()) {
      54              :                         start_timeout();
      55              :                     }
      56            0 :                 }, detached);
      57              :             } else {
      58            0 :                 LOG_DEBUG("websocket ping timeout... closing connection!");
      59            0 :                 close();
      60              :             }
      61              :         }
      62              :     });
      63           38 : }
      64              : 
      65           72 : awaitable<void> websocket::send_close(uint8_t buffer[], size_t size) {
      66              :     LOG_DEBUG("sending close frame");
      67              :     co_await send_message(0x88, buffer, size);
      68              :     close_sent_ = true;
      69              :     LOG_DEBUG("close frame sent");
      70              : 
      71              :     if (!close_received_) {
      72              :         // Race: read the close ack vs timeout
      73              :         timer_.cancel();
      74              :         timer_.expires_after(std::chrono::seconds{5});
      75              : 
      76           38 :         auto read_close_ack = [this]() -> awaitable<void> {
      77              :             uint8_t buf[125];
      78              :             boost::system::error_code ec;
      79              :             while (!close_received_ && socket_->is_open()) {
      80              :                 co_await read_frame(buf, sizeof(buf), ec);
      81              :                 if (ec) break;
      82              :             }
      83           76 :         };
      84              : 
      85           38 :         auto wait_timeout = [this]() -> awaitable<void> {
      86              :             auto [ec] = co_await timer_.async_wait(use_nothrow_awaitable);
      87           76 :         };
      88              : 
      89              :         co_await (read_close_ack() || wait_timeout());
      90              : 
      91              :         if (!close_received_) {
      92              :             LOG_WARNING("timeout while waiting close acknowledgement");
      93              :         }
      94              :     }
      95              :     close();
      96          144 : }
      97              : 
      98            0 : awaitable<void> websocket::send_ping(uint8_t buffer[], size_t size) {
      99              :     LOG_DEBUG("sending ping frame");
     100              :     co_await send_message(0x09, buffer, size);
     101              :     LOG_DEBUG("ping frame sent");
     102            0 : }
     103              : 
     104            0 : awaitable<void> websocket::send_pong(uint8_t buffer[], size_t size) {
     105              :     LOG_DEBUG("sending pong frame");
     106              :     co_await send_message(0x0A, buffer, size);
     107              :     LOG_DEBUG("pong frame sent");
     108            0 : }
     109              : 
     110          150 : awaitable<size_t> websocket::read_frame(uint8_t buffer[], size_t max_size, boost::system::error_code& ec) {
     111              :     // If there's remaining data in current frame, read it
     112              :     if (frame_remaining_ > 0) {
     113              :         auto read_size = std::min(frame_remaining_, max_size);
     114              :         auto [read_ec, bytes] = co_await socket_->read(buffer, read_size);
     115              : 
     116              :         if (read_ec) {
     117              :             ec = read_ec;
     118              :             co_return 0;
     119              :         }
     120              : 
     121              :         if (masked_) unmask(buffer, bytes);
     122              :         frame_remaining_ -= bytes;
     123              : 
     124              :         co_return bytes;
     125              :     }
     126              : 
     127              :     // Read frame header (2 bytes minimum)
     128              :     {
     129              :         auto [read_ec, bytes] = co_await socket_->read(buffer_, 2);
     130              :         if (read_ec) {
     131              :             ec = read_ec;
     132              :             co_return 0;
     133              :         }
     134              :     }
     135              :     data_received_ = true;
     136              : 
     137              :     uint8_t data_type = buffer_[0];
     138              :     uint8_t fin = data_type & 0b10000000;
     139              :     uint8_t rsv = data_type & 0b01110000;
     140              : 
     141              :     if (rsv) {
     142              :         LOG_ERROR("invalid RSV parameters");
     143              :         ec = boost::asio::error::invalid_argument;
     144              :         co_return 0;
     145              :     }
     146              : 
     147              :     uint8_t opcode = data_type & 0x0F;
     148              :     uint8_t data_size = buffer_[1] & ~(1 << 7);
     149              :     uint8_t masked = buffer_[1] & 0b10000000;
     150              : 
     151              :     LOG_DEBUG("decoded frame header. fin: {}, opcode: 0x{:02X} mask: {} data_size: {}", fin, opcode, masked, data_size);
     152              : 
     153              :     if (!masked && server_role_) {
     154              :         LOG_ERROR("client is not masking the information");
     155              :         ec = boost::asio::error::invalid_argument;
     156              :         co_return 0;
     157              :     }
     158              : 
     159              :     masked_ = masked;
     160              :     fin_ = fin;
     161              :     opcode_ = opcode;
     162              : 
     163              :     // Handle opcodes
     164              :     if (new_message_) {
     165              :         switch (opcode) {
     166              :             case 0x0:
     167              :                 LOG_ERROR("received continuation message as the first message!");
     168              :                 ec = boost::asio::error::invalid_argument;
     169              :                 co_return 0;
     170              :             case 0x1: // text
     171              :             case 0x2: // binary
     172              :                 message_opcode_ = opcode;
     173              :                 break;
     174              :             case 0x8: // close
     175              :             case 0x9: // ping
     176              :             case 0xA: // pong
     177              :                 if (!fin) {
     178              :                     LOG_ERROR("control frame messages cannot be fragmented");
     179              :                     ec = boost::asio::error::invalid_argument;
     180              :                     co_return 0;
     181              :                 }
     182              :                 break;
     183              :             default:
     184              :                 LOG_ERROR("received unknown websocket opcode: {}", (int)opcode);
     185              :                 ec = boost::asio::error::invalid_argument;
     186              :                 co_return 0;
     187              :         }
     188              :     } else {
     189              :         // Continuation frame expected
     190              :         if (opcode != 0x0 && opcode < 0x8) {
     191              :             LOG_ERROR("unexpected fragment type. expecting a continuation frame");
     192              :             ec = boost::asio::error::invalid_argument;
     193              :             co_return 0;
     194              :         }
     195              :     }
     196              : 
     197              :     // Determine payload length
     198              :     uint64_t payload_size = data_size;
     199              :     if (data_size == 126) {
     200              :         auto [read_ec, bytes] = co_await socket_->read(buffer_, 2);
     201              :         if (read_ec) {
     202              :             ec = read_ec;
     203              :             co_return 0;
     204              :         }
     205              :         payload_size = (buffer_[0] << 8) | buffer_[1];
     206              :     } else if (data_size == 127) {
     207              :         auto [read_ec, bytes] = co_await socket_->read(buffer_, 8);
     208              :         if (read_ec) {
     209              :             ec = read_ec;
     210              :             co_return 0;
     211              :         }
     212              :         payload_size = 0;
     213              :         for (int i = 0; i < 8; ++i) {
     214              :             payload_size = (payload_size << 8) | buffer_[i];
     215              :         }
     216              :     }
     217              : 
     218              :     frame_remaining_ = payload_size;
     219              : 
     220              :     // Read mask if present
     221              :     if (masked_) {
     222              :         auto [read_ec, bytes] = co_await socket_->read(mask_, MASK_SIZE_BYTES);
     223              :         if (read_ec) {
     224              :             ec = read_ec;
     225              :             co_return 0;
     226              :         }
     227              :     }
     228              : 
     229              :     // Handle control frames
     230              :     if (opcode >= 0x8) {
     231              :         // Read control frame payload
     232              :         uint8_t control_buffer[125];
     233              :         size_t control_size = std::min(static_cast<size_t>(payload_size), size_t(125));
     234              :         if (control_size > 0) {
     235              :             auto [read_ec, bytes] = co_await socket_->read(control_buffer, control_size);
     236              :             if (read_ec) {
     237              :                 ec = read_ec;
     238              :                 co_return 0;
     239              :             }
     240              :             if (masked_) unmask(control_buffer, control_size);
     241              :         }
     242              :         frame_remaining_ = 0;
     243              : 
     244              :         switch (opcode) {
     245              :             case 0x8: // close
     246              :                 LOG_DEBUG("received close frame");
     247              :                 close_received_ = true;
     248              :                 if (!close_sent_) {
     249              :                     co_await send_close();
     250              :                 }
     251              :                 ec = boost::asio::error::connection_aborted;
     252              :                 co_return 0;
     253              :             case 0x9: // ping
     254              :                 LOG_DEBUG("received ping frame");
     255              :                 co_await send_pong(control_buffer, control_size);
     256              :                 co_return co_await read_frame(buffer, max_size, ec);
     257              :             case 0xA: // pong
     258              :                 LOG_DEBUG("received pong frame");
     259              :                 pending_ping_ = false;
     260              :                 data_received_ = false;
     261              :                 co_return co_await read_frame(buffer, max_size, ec);
     262              :         }
     263              :     }
     264              : 
     265              :     // Read data frame payload
     266              :     if (frame_remaining_ == 0) {
     267              :         if (fin_) new_message_ = true;
     268              :         co_return 0;
     269              :     }
     270              : 
     271              :     auto read_size = std::min(frame_remaining_, max_size);
     272              :     auto [read_ec, bytes] = co_await socket_->read(buffer, read_size);
     273              : 
     274              :     if (read_ec) {
     275              :         ec = read_ec;
     276              :         co_return 0;
     277              :     }
     278              : 
     279              :     if (masked_) unmask(buffer, bytes);
     280              :     frame_remaining_ -= bytes;
     281              : 
     282              :     if (frame_remaining_ == 0 && fin_) {
     283              :         new_message_ = true;
     284              :     } else if (frame_remaining_ == 0) {
     285              :         new_message_ = false;
     286              :     }
     287              : 
     288              :     co_return bytes;
     289          300 : }
     290              : 
     291          146 : awaitable<io_result> websocket::send_message(uint8_t opcode, const uint8_t buffer[], size_t size) {
     292              :     std::lock_guard<std::mutex> lock(write_mutex_);
     293              : 
     294              :     uint8_t header_size = 2;
     295              :     output_[0] = 0x80 | opcode;
     296              : 
     297              :     if (size <= 125) {
     298              :         output_[1] = size;
     299              :     } else if (size <= 65535) {
     300              :         output_[1] = 126;
     301              :         output_[2] = (size >> 8) & 0xff;
     302              :         output_[3] = size & 0xff;
     303              :         header_size += 2;
     304              :     } else {
     305              :         output_[1] = 127;
     306              :         for (int i = 0; i < 8; ++i) {
     307              :             output_[2 + i] = (size >> ((7 - i) * 8)) & 0xff;
     308              :         }
     309              :         header_size += 8;
     310              :     }
     311              : 
     312              :     // Create output buffers
     313              :     std::vector<boost::asio::const_buffer> output_buffers;
     314              :     output_buffers.push_back(boost::asio::buffer(output_, header_size));
     315              : 
     316              :     // Handle masking for client role
     317              :     std::vector<uint8_t> masked_data;
     318              :     if (!server_role_) {
     319              :         output_[1] |= 0b10000000;
     320              : 
     321              :         static random_bytes_engine rbe;
     322              :         uint8_t mask[MASK_SIZE_BYTES];
     323              :         for (int i = 0; i < MASK_SIZE_BYTES; ++i) {
     324              :             mask[i] = rbe();
     325              :             output_[header_size + i] = mask[i];
     326              :         }
     327              :         header_size += MASK_SIZE_BYTES;
     328              : 
     329              :         // Mask the data
     330              :         masked_data.resize(size);
     331              :         for (size_t i = 0; i < size; ++i) {
     332              :             masked_data[i] = buffer[i] ^ mask[i % MASK_SIZE_BYTES];
     333              :         }
     334              : 
     335              :         output_buffers.clear();
     336              :         output_buffers.push_back(boost::asio::buffer(output_, header_size));
     337              :         output_buffers.push_back(boost::asio::buffer(masked_data));
     338              :     } else {
     339              :         output_buffers.push_back(boost::asio::buffer(buffer, size));
     340              :     }
     341              : 
     342              :     LOG_DEBUG("sending websocket data. header: {}, payload: {}", header_size, size);
     343              : 
     344              :     auto [ec, bytes] = co_await socket_->write(output_buffers);
     345              :     if (ec) {
     346              :         co_return io_result{ec, 0};
     347              :     }
     348              :     co_return io_result{ec, bytes - header_size};
     349          292 : }
     350              : 
     351          112 : awaitable<io_result> websocket::read_some(uint8_t buffer[], size_t max_size) {
     352              :     boost::system::error_code ec;
     353              :     auto bytes = co_await read_frame(buffer, max_size, ec);
     354              :     if (ec) {
     355              :         if (ec != boost::asio::error::connection_aborted) {
     356              :             LOG_ERROR("websocket read_some error: {}", ec.message());
     357              :         }
     358              :         co_return io_result{ec, 0};
     359              :     }
     360              :     co_return io_result{boost::system::error_code{}, bytes};
     361          224 : }
     362              : 
     363            0 : awaitable<io_result> websocket::read(uint8_t buffer[], size_t size) {
     364              :     boost::system::error_code ec;
     365              :     auto bytes = co_await read_frame(buffer, size, ec);
     366              :     if (ec) {
     367              :         if (ec != boost::asio::error::connection_aborted) {
     368              :             LOG_ERROR("websocket read error: {}", ec.message());
     369              :         }
     370              :         co_return io_result{ec, 0};
     371              :     }
     372              :     co_return io_result{boost::system::error_code{}, bytes};
     373            0 : }
     374              : 
     375            0 : awaitable<io_result> websocket::read(boost::asio::streambuf& buffer, size_t size) {
     376              :     // Not supported for websocket
     377              :     co_return io_result{boost::asio::error::operation_not_supported, 0};
     378            0 : }
     379              : 
     380            0 : awaitable<io_result> websocket::read_until(boost::asio::streambuf& buffer, std::string_view delim) {
     381              :     // Not supported for websocket
     382              :     co_return io_result{boost::asio::error::operation_not_supported, 0};
     383            0 : }
     384              : 
     385            4 : awaitable<io_result> websocket::write(const uint8_t buffer[], size_t size) {
     386              :     co_return co_await send_message(binary_ ? 0x02 : 0x01, buffer, size);
     387            8 : }
     388              : 
     389           70 : awaitable<io_result> websocket::write(std::string_view str) {
     390              :     co_return co_await send_message(binary_ ? 0x02 : 0x01,
     391              :         reinterpret_cast<const uint8_t*>(str.data()), str.size());
     392          140 : }
     393              : 
     394            0 : awaitable<io_result> websocket::write(const std::vector<boost::asio::const_buffer>& buffers) {
     395              :     // Not supported for websocket
     396              :     co_return io_result{boost::asio::error::operation_not_supported, 0};
     397            0 : }
     398              : 
     399            0 : awaitable<boost::system::error_code> websocket::wait(boost::asio::socket_base::wait_type type) {
     400              :     co_return co_await socket_->wait(type);
     401            0 : }
     402              : 
     403            0 : awaitable<boost::system::error_code> websocket::connect(
     404              :     const std::string& host,
     405              :     const std::string& port,
     406              :     std::chrono::seconds timeout) {
     407              :     co_return co_await socket_->connect(host, port, timeout);
     408            0 : }
     409              : 
     410          110 : void websocket::close() {
     411          110 :     timer_.cancel();
     412          110 :     socket_->close();
     413          110 : }
     414              : 
     415           38 : awaitable<void> websocket::close_graceful() {
     416              :     if (!close_sent_ && socket_->is_open()) {
     417              :         co_await send_close();
     418              :     } else if (socket_->is_open()) {
     419              :         close();
     420              :     }
     421           76 : }
     422              : 
     423            0 : void websocket::cancel() {
     424            0 :     socket_->cancel();
     425            0 : }
     426              : 
     427            0 : bool websocket::requires_handshake() const {
     428            0 :     return socket_->requires_handshake();
     429              : }
     430              : 
     431            0 : awaitable<boost::system::error_code> websocket::handshake(const std::string& host) {
     432              :     co_return co_await socket_->handshake(host);
     433            0 : }
     434              : 
     435          266 : bool websocket::is_open() const {
     436          266 :     return socket_->is_open();
     437              : }
     438              : 
     439            0 : bool websocket::is_secure() const {
     440            0 :     return socket_->is_secure();
     441              : }
     442              : 
     443            0 : size_t websocket::available() const {
     444            0 :     return socket_->available();
     445              : }
     446              : 
     447            0 : std::string websocket::get_remote_ip() const {
     448            0 :     return socket_->get_remote_ip();
     449              : }
     450              : 
     451            0 : std::string websocket::get_local_port() const {
     452            0 :     return socket_->get_local_port();
     453              : }
     454              : 
     455            0 : std::string websocket::get_remote_port() const {
     456            0 :     return socket_->get_remote_port();
     457              : }
     458              : 
     459           38 : std::size_t websocket::remaining_in_frame() const {
     460           38 :     return frame_remaining_;
     461              : }
     462              : 
     463           36 : bool websocket::is_message_complete() const {
     464           36 :     return new_message_;
     465              : }
     466              : 
     467              : }
        

Generated by: LCOV version 2.0-1