LCOV - code coverage report
Current view: top level - asio/sockets - websocket.cpp (source / functions) Coverage Total Hit
Test: coverage_filtered.info Lines: 55.2 % 96 53
Test Date: 2026-02-20 15:38:22 Functions: 50.0 % 36 18

            Line data    Source code
       1              : #include "websocket.hpp"
       2              : #include "../../util/logger.hpp"
       3              : #include "tcp_socket.hpp"
       4              : 
       5              : #include <random>
       6              : 
       7              : namespace thinger::asio {
       8              : 
       9              : using random_bytes_engine = std::independent_bits_engine<std::default_random_engine, CHAR_BIT, unsigned char>;
      10              : 
      11              : std::atomic<unsigned long> websocket::connections{0};
      12              : 
      13           76 : websocket::websocket(std::shared_ptr<socket> sock, bool binary, bool server)
      14              :     : socket("websocket", sock->get_io_context())
      15           76 :     , socket_(sock)
      16           76 :     , timer_(sock->get_io_context())
      17           76 :     , binary_(binary)
      18          304 :     , server_role_(server) {
      19           76 :     ++connections;
      20           76 :     LOG_DEBUG("websocket created");
      21           76 : }
      22              : 
      23           76 : websocket::~websocket() {
      24           76 :     --connections;
      25           76 :     LOG_DEBUG("releasing websocket");
      26           76 :     timer_.cancel();
      27           76 : }
      28              : 
      29           38 : void websocket::unmask(uint8_t buffer[], std::size_t size) {
      30           38 :     LOG_DEBUG("unmasking payload. size: {}", size);
      31         2498 :     for (size_t i = 0; i < size; ++i) {
      32         2460 :         buffer[i] ^= mask_[i % MASK_SIZE_BYTES];
      33              :     }
      34           38 : }
      35              : 
      36           38 : void websocket::start_timeout() {
      37           38 :     timer_.expires_after(CONNECTION_TIMEOUT_SECONDS);
      38           38 :     timer_.async_wait([this](const boost::system::error_code& e) {
      39           38 :         if (e) {
      40           38 :             if (e != boost::asio::error::operation_aborted) {
      41            0 :                 LOG_ERROR("error on timeout: {}", e.message());
      42              :             }
      43           38 :             return;
      44              :         }
      45            0 :         if (data_received_) {
      46            0 :             data_received_ = false;
      47            0 :             return start_timeout();
      48              :         } else {
      49            0 :             if (!pending_ping_) {
      50            0 :                 pending_ping_ = true;
      51            0 :                 co_spawn(io_context_, [this]() -> awaitable<void> {
      52              :                     co_await send_ping();
      53              :                     if (socket_->is_open()) {
      54              :                         start_timeout();
      55              :                     }
      56            0 :                 }, detached);
      57              :             } else {
      58            0 :                 LOG_DEBUG("websocket ping timeout... closing connection!");
      59            0 :                 close();
      60              :             }
      61              :         }
      62              :     });
      63           38 : }
      64              : 
      65           72 : awaitable<void> websocket::send_close(uint8_t buffer[], size_t size) {
      66              :     LOG_DEBUG("sending close frame");
      67              :     co_await send_message(0x88, buffer, size);
      68              :     close_sent_ = true;
      69              :     LOG_DEBUG("close frame sent");
      70              : 
      71              :     if (!close_received_) {
      72              :         // Race: read the close ack vs timeout
      73              :         timer_.cancel();
      74              :         timer_.expires_after(std::chrono::seconds{5});
      75              : 
      76           38 :         auto read_close_ack = [this]() -> awaitable<void> {
      77              :             uint8_t buf[125];
      78              :             boost::system::error_code ec;
      79              :             while (!close_received_ && socket_->is_open()) {
      80              :                 co_await read_frame(buf, sizeof(buf), ec);
      81              :                 if (ec) break;
      82              :             }
      83           76 :         };
      84              : 
      85           38 :         auto wait_timeout = [this]() -> awaitable<void> {
      86              :             auto [ec] = co_await timer_.async_wait(use_nothrow_awaitable);
      87           76 :         };
      88              : 
      89              :         co_await (read_close_ack() || wait_timeout());
      90              : 
      91              :         if (!close_received_) {
      92              :             LOG_WARNING("timeout while waiting close acknowledgement");
      93              :         }
      94              :     }
      95              :     close();
      96          144 : }
      97              : 
      98            0 : awaitable<void> websocket::send_ping(uint8_t buffer[], size_t size) {
      99              :     LOG_DEBUG("sending ping frame");
     100              :     co_await send_message(0x09, buffer, size);
     101              :     LOG_DEBUG("ping frame sent");
     102            0 : }
     103              : 
     104            0 : awaitable<void> websocket::send_pong(uint8_t buffer[], size_t size) {
     105              :     LOG_DEBUG("sending pong frame");
     106              :     co_await send_message(0x0A, buffer, size);
     107              :     LOG_DEBUG("pong frame sent");
     108            0 : }
     109              : 
     110          150 : awaitable<size_t> websocket::read_frame(uint8_t buffer[], size_t max_size, boost::system::error_code& ec) {
     111              :     // If there's remaining data in current frame, read it
     112              :     if (frame_remaining_ > 0) {
     113              :         auto read_size = std::min(frame_remaining_, max_size);
     114              :         auto bytes = co_await socket_->read(buffer, read_size);
     115              : 
     116              :         if (bytes == 0) {
     117              :             ec = boost::asio::error::connection_reset;
     118              :             co_return 0;
     119              :         }
     120              : 
     121              :         if (masked_) unmask(buffer, bytes);
     122              :         frame_remaining_ -= bytes;
     123              : 
     124              :         co_return bytes;
     125              :     }
     126              : 
     127              :     // Read frame header (2 bytes minimum)
     128              :     if (co_await socket_->read(buffer_, 2) != 2) {
     129              :         ec = boost::asio::error::connection_reset;
     130              :         co_return 0;
     131              :     }
     132              :     data_received_ = true;
     133              : 
     134              :     uint8_t data_type = buffer_[0];
     135              :     uint8_t fin = data_type & 0b10000000;
     136              :     uint8_t rsv = data_type & 0b01110000;
     137              : 
     138              :     if (rsv) {
     139              :         LOG_ERROR("invalid RSV parameters");
     140              :         ec = boost::asio::error::invalid_argument;
     141              :         co_return 0;
     142              :     }
     143              : 
     144              :     uint8_t opcode = data_type & 0x0F;
     145              :     uint8_t data_size = buffer_[1] & ~(1 << 7);
     146              :     uint8_t masked = buffer_[1] & 0b10000000;
     147              : 
     148              :     LOG_DEBUG("decoded frame header. fin: {}, opcode: 0x{:02X} mask: {} data_size: {}", fin, opcode, masked, data_size);
     149              : 
     150              :     if (!masked && server_role_) {
     151              :         LOG_ERROR("client is not masking the information");
     152              :         ec = boost::asio::error::invalid_argument;
     153              :         co_return 0;
     154              :     }
     155              : 
     156              :     masked_ = masked;
     157              :     fin_ = fin;
     158              :     opcode_ = opcode;
     159              : 
     160              :     // Handle opcodes
     161              :     if (new_message_) {
     162              :         switch (opcode) {
     163              :             case 0x0:
     164              :                 LOG_ERROR("received continuation message as the first message!");
     165              :                 ec = boost::asio::error::invalid_argument;
     166              :                 co_return 0;
     167              :             case 0x1: // text
     168              :             case 0x2: // binary
     169              :                 message_opcode_ = opcode;
     170              :                 break;
     171              :             case 0x8: // close
     172              :             case 0x9: // ping
     173              :             case 0xA: // pong
     174              :                 if (!fin) {
     175              :                     LOG_ERROR("control frame messages cannot be fragmented");
     176              :                     ec = boost::asio::error::invalid_argument;
     177              :                     co_return 0;
     178              :                 }
     179              :                 break;
     180              :             default:
     181              :                 LOG_ERROR("received unknown websocket opcode: {}", (int)opcode);
     182              :                 ec = boost::asio::error::invalid_argument;
     183              :                 co_return 0;
     184              :         }
     185              :     } else {
     186              :         // Continuation frame expected
     187              :         if (opcode != 0x0 && opcode < 0x8) {
     188              :             LOG_ERROR("unexpected fragment type. expecting a continuation frame");
     189              :             ec = boost::asio::error::invalid_argument;
     190              :             co_return 0;
     191              :         }
     192              :     }
     193              : 
     194              :     // Determine payload length
     195              :     uint64_t payload_size = data_size;
     196              :     if (data_size == 126) {
     197              :         if (co_await socket_->read(buffer_, 2) != 2) {
     198              :             ec = boost::asio::error::connection_reset;
     199              :             co_return 0;
     200              :         }
     201              :         payload_size = (buffer_[0] << 8) | buffer_[1];
     202              :     } else if (data_size == 127) {
     203              :         if (co_await socket_->read(buffer_, 8) != 8) {
     204              :             ec = boost::asio::error::connection_reset;
     205              :             co_return 0;
     206              :         }
     207              :         payload_size = 0;
     208              :         for (int i = 0; i < 8; ++i) {
     209              :             payload_size = (payload_size << 8) | buffer_[i];
     210              :         }
     211              :     }
     212              : 
     213              :     frame_remaining_ = payload_size;
     214              : 
     215              :     // Read mask if present
     216              :     if (masked_) {
     217              :         if (co_await socket_->read(mask_, MASK_SIZE_BYTES) != MASK_SIZE_BYTES) {
     218              :             ec = boost::asio::error::connection_reset;
     219              :             co_return 0;
     220              :         }
     221              :     }
     222              : 
     223              :     // Handle control frames
     224              :     if (opcode >= 0x8) {
     225              :         // Read control frame payload
     226              :         uint8_t control_buffer[125];
     227              :         size_t control_size = std::min(static_cast<size_t>(payload_size), size_t(125));
     228              :         if (control_size > 0) {
     229              :             if (co_await socket_->read(control_buffer, control_size) != control_size) {
     230              :                 ec = boost::asio::error::connection_reset;
     231              :                 co_return 0;
     232              :             }
     233              :             if (masked_) unmask(control_buffer, control_size);
     234              :         }
     235              :         frame_remaining_ = 0;
     236              : 
     237              :         switch (opcode) {
     238              :             case 0x8: // close
     239              :                 LOG_DEBUG("received close frame");
     240              :                 close_received_ = true;
     241              :                 if (!close_sent_) {
     242              :                     co_await send_close();
     243              :                 }
     244              :                 ec = boost::asio::error::connection_aborted;
     245              :                 co_return 0;
     246              :             case 0x9: // ping
     247              :                 LOG_DEBUG("received ping frame");
     248              :                 co_await send_pong(control_buffer, control_size);
     249              :                 co_return co_await read_frame(buffer, max_size, ec);
     250              :             case 0xA: // pong
     251              :                 LOG_DEBUG("received pong frame");
     252              :                 pending_ping_ = false;
     253              :                 data_received_ = false;
     254              :                 co_return co_await read_frame(buffer, max_size, ec);
     255              :         }
     256              :     }
     257              : 
     258              :     // Read data frame payload
     259              :     if (frame_remaining_ == 0) {
     260              :         if (fin_) new_message_ = true;
     261              :         co_return 0;
     262              :     }
     263              : 
     264              :     auto read_size = std::min(frame_remaining_, max_size);
     265              :     auto bytes = co_await socket_->read(buffer, read_size);
     266              : 
     267              :     if (bytes == 0) {
     268              :         ec = boost::asio::error::connection_reset;
     269              :         co_return 0;
     270              :     }
     271              : 
     272              :     if (masked_) unmask(buffer, bytes);
     273              :     frame_remaining_ -= bytes;
     274              : 
     275              :     if (frame_remaining_ == 0 && fin_) {
     276              :         new_message_ = true;
     277              :     } else if (frame_remaining_ == 0) {
     278              :         new_message_ = false;
     279              :     }
     280              : 
     281              :     co_return bytes;
     282          300 : }
     283              : 
     284          146 : awaitable<size_t> websocket::send_message(uint8_t opcode, const uint8_t buffer[], size_t size) {
     285              :     std::lock_guard<std::mutex> lock(write_mutex_);
     286              : 
     287              :     uint8_t header_size = 2;
     288              :     output_[0] = 0x80 | opcode;
     289              : 
     290              :     if (size <= 125) {
     291              :         output_[1] = size;
     292              :     } else if (size <= 65535) {
     293              :         output_[1] = 126;
     294              :         output_[2] = (size >> 8) & 0xff;
     295              :         output_[3] = size & 0xff;
     296              :         header_size += 2;
     297              :     } else {
     298              :         output_[1] = 127;
     299              :         for (int i = 0; i < 8; ++i) {
     300              :             output_[2 + i] = (size >> ((7 - i) * 8)) & 0xff;
     301              :         }
     302              :         header_size += 8;
     303              :     }
     304              : 
     305              :     // Create output buffers
     306              :     std::vector<boost::asio::const_buffer> output_buffers;
     307              :     output_buffers.push_back(boost::asio::buffer(output_, header_size));
     308              : 
     309              :     // Handle masking for client role
     310              :     std::vector<uint8_t> masked_data;
     311              :     if (!server_role_) {
     312              :         output_[1] |= 0b10000000;
     313              : 
     314              :         static random_bytes_engine rbe;
     315              :         uint8_t mask[MASK_SIZE_BYTES];
     316              :         for (int i = 0; i < MASK_SIZE_BYTES; ++i) {
     317              :             mask[i] = rbe();
     318              :             output_[header_size + i] = mask[i];
     319              :         }
     320              :         header_size += MASK_SIZE_BYTES;
     321              : 
     322              :         // Mask the data
     323              :         masked_data.resize(size);
     324              :         for (size_t i = 0; i < size; ++i) {
     325              :             masked_data[i] = buffer[i] ^ mask[i % MASK_SIZE_BYTES];
     326              :         }
     327              : 
     328              :         output_buffers.clear();
     329              :         output_buffers.push_back(boost::asio::buffer(output_, header_size));
     330              :         output_buffers.push_back(boost::asio::buffer(masked_data));
     331              :     } else {
     332              :         output_buffers.push_back(boost::asio::buffer(buffer, size));
     333              :     }
     334              : 
     335              :     LOG_DEBUG("sending websocket data. header: {}, payload: {}", header_size, size);
     336              : 
     337              :     auto bytes = co_await socket_->write(output_buffers);
     338              :     if (bytes == 0) {
     339              :         co_return 0;
     340              :     }
     341              :     co_return bytes - header_size;
     342          292 : }
     343              : 
     344          112 : awaitable<size_t> websocket::read_some(uint8_t buffer[], size_t max_size) {
     345              :     boost::system::error_code ec;
     346              :     auto bytes = co_await read_frame(buffer, max_size, ec);
     347              :     if (ec) {
     348              :         if (ec != boost::asio::error::connection_aborted) {
     349              :             LOG_ERROR("websocket read_some error: {}", ec.message());
     350              :         }
     351              :         co_return 0;
     352              :     }
     353              :     co_return bytes;
     354          224 : }
     355              : 
     356            0 : awaitable<size_t> websocket::read(uint8_t buffer[], size_t size) {
     357              :     boost::system::error_code ec;
     358              :     auto bytes = co_await read_frame(buffer, size, ec);
     359              :     if (ec) {
     360              :         if (ec != boost::asio::error::connection_aborted) {
     361              :             LOG_ERROR("websocket read error: {}", ec.message());
     362              :         }
     363              :         co_return 0;
     364              :     }
     365              :     co_return bytes;
     366            0 : }
     367              : 
     368            0 : awaitable<size_t> websocket::read(boost::asio::streambuf& buffer, size_t size) {
     369              :     // Not supported for websocket
     370              :     co_return 0;
     371            0 : }
     372              : 
     373            0 : awaitable<size_t> websocket::read_until(boost::asio::streambuf& buffer, std::string_view delim) {
     374              :     // Not supported for websocket
     375              :     co_return 0;
     376            0 : }
     377              : 
     378            4 : awaitable<size_t> websocket::write(const uint8_t buffer[], size_t size) {
     379              :     co_return co_await send_message(binary_ ? 0x02 : 0x01, buffer, size);
     380            8 : }
     381              : 
     382           70 : awaitable<size_t> websocket::write(std::string_view str) {
     383              :     co_return co_await send_message(binary_ ? 0x02 : 0x01,
     384              :         reinterpret_cast<const uint8_t*>(str.data()), str.size());
     385          140 : }
     386              : 
     387            0 : awaitable<size_t> websocket::write(const std::vector<boost::asio::const_buffer>& buffers) {
     388              :     // Not supported for websocket
     389              :     co_return 0;
     390            0 : }
     391              : 
     392            0 : awaitable<boost::system::error_code> websocket::wait(boost::asio::socket_base::wait_type type) {
     393              :     co_return co_await socket_->wait(type);
     394            0 : }
     395              : 
     396            0 : awaitable<boost::system::error_code> websocket::connect(
     397              :     const std::string& host,
     398              :     const std::string& port,
     399              :     std::chrono::seconds timeout) {
     400              :     co_return co_await socket_->connect(host, port, timeout);
     401            0 : }
     402              : 
     403          110 : void websocket::close() {
     404          110 :     timer_.cancel();
     405          110 :     socket_->close();
     406          110 : }
     407              : 
     408           38 : awaitable<void> websocket::close_graceful() {
     409              :     if (!close_sent_ && socket_->is_open()) {
     410              :         co_await send_close();
     411              :     } else if (socket_->is_open()) {
     412              :         close();
     413              :     }
     414           76 : }
     415              : 
     416            0 : void websocket::cancel() {
     417            0 :     socket_->cancel();
     418            0 : }
     419              : 
     420            0 : bool websocket::requires_handshake() const {
     421            0 :     return socket_->requires_handshake();
     422              : }
     423              : 
     424            0 : awaitable<boost::system::error_code> websocket::handshake(const std::string& host) {
     425              :     co_return co_await socket_->handshake(host);
     426            0 : }
     427              : 
     428          340 : bool websocket::is_open() const {
     429          340 :     return socket_->is_open();
     430              : }
     431              : 
     432            0 : bool websocket::is_secure() const {
     433            0 :     return socket_->is_secure();
     434              : }
     435              : 
     436            0 : size_t websocket::available() const {
     437            0 :     return socket_->available();
     438              : }
     439              : 
     440            0 : std::string websocket::get_remote_ip() const {
     441            0 :     return socket_->get_remote_ip();
     442              : }
     443              : 
     444            0 : std::string websocket::get_local_port() const {
     445            0 :     return socket_->get_local_port();
     446              : }
     447              : 
     448            0 : std::string websocket::get_remote_port() const {
     449            0 :     return socket_->get_remote_port();
     450              : }
     451              : 
     452           38 : std::size_t websocket::remaining_in_frame() const {
     453           38 :     return frame_remaining_;
     454              : }
     455              : 
     456           36 : bool websocket::is_message_complete() const {
     457           36 :     return new_message_;
     458              : }
     459              : 
     460              : }
        

Generated by: LCOV version 2.0-1