Line data Source code
1 : #include "tcp_socket_server.hpp"
2 : #include "workers.hpp"
3 : #include "../util/logger.hpp"
4 : #include "../util/types.hpp"
5 : #include <boost/asio/ssl.hpp>
6 :
7 : namespace thinger::asio {
8 :
9 : // Constructor with io_context providers
10 741 : tcp_socket_server::tcp_socket_server(std::string host,
11 : std::string port,
12 : io_context_provider acceptor_context_provider,
13 : io_context_provider connection_context_provider,
14 : std::set<std::string> allowed_remotes,
15 741 : std::set<std::string> forbidden_remotes)
16 741 : : socket_server_base(std::move(acceptor_context_provider),
17 741 : std::move(connection_context_provider),
18 741 : std::move(allowed_remotes),
19 741 : std::move(forbidden_remotes))
20 741 : , host_(std::move(host))
21 3705 : , port_(std::move(port))
22 : {
23 741 : }
24 :
25 : // Legacy constructor for backward compatibility
26 12 : tcp_socket_server::tcp_socket_server(std::string host,
27 : std::string port,
28 : std::set<std::string> allowed_remotes,
29 12 : std::set<std::string> forbidden_remotes)
30 : : tcp_socket_server(host, port,
31 12 : []() -> boost::asio::io_context& { return get_workers().get_thread_io_context(); },
32 12 : []() -> boost::asio::io_context& { return get_workers().get_next_io_context(); },
33 12 : std::move(allowed_remotes),
34 24 : std::move(forbidden_remotes))
35 : {
36 12 : }
37 :
38 1449 : tcp_socket_server::~tcp_socket_server() {
39 741 : close_acceptor();
40 1449 : }
41 :
42 735 : bool tcp_socket_server::stop() {
43 : // First call base class to set running_ = false
44 735 : socket_server_base::stop();
45 :
46 : // Now close the acceptor
47 735 : close_acceptor();
48 :
49 735 : return true;
50 : }
51 :
52 1476 : void tcp_socket_server::close_acceptor() {
53 : // Close the acceptor to cancel pending async operations, but do NOT
54 : // destroy it (reset) here. The async_accept handler may still be in
55 : // flight on the io_context thread and needs the acceptor alive until
56 : // the handler completes. The unique_ptr will clean up on destruction.
57 1476 : if (acceptor_ && acceptor_->is_open()) {
58 735 : boost::system::error_code ec;
59 735 : acceptor_->close(ec);
60 735 : if (ec) {
61 0 : LOG_WARNING("Error closing TCP acceptor: {}", ec.message());
62 : }
63 : }
64 1476 : }
65 :
66 0 : void tcp_socket_server::set_tcp_no_delay(bool tcp_no_delay) {
67 0 : tcp_no_delay_ = tcp_no_delay;
68 0 : }
69 :
70 52 : void tcp_socket_server::enable_ssl(bool ssl, bool client_certificate) {
71 52 : ssl_enabled_ = ssl;
72 52 : client_certificate_ = client_certificate;
73 52 : }
74 :
75 52 : void tcp_socket_server::set_ssl_context(std::shared_ptr<boost::asio::ssl::context> context) {
76 52 : ssl_context_ = std::move(context);
77 52 : }
78 :
79 52 : void tcp_socket_server::set_sni_callback(sni_callback_type callback) {
80 52 : if (ssl_context_) {
81 52 : SSL_CTX_set_tlsext_servername_callback(ssl_context_->native_handle(), callback);
82 : }
83 52 : }
84 :
85 0 : std::string tcp_socket_server::get_service_name() const {
86 0 : return (ssl_enabled_ ? "ssl_server@" : "tcp_server@") + host_ + ":" + port_;
87 : }
88 :
89 713 : uint16_t tcp_socket_server::local_port() const {
90 713 : return acceptor_ ? acceptor_->local_endpoint().port() : 0;
91 : }
92 :
93 741 : bool tcp_socket_server::create_acceptor() {
94 741 : int num_attempts = 0;
95 :
96 : // Get io_context from provider
97 741 : boost::asio::io_context& io_context = acceptor_context_provider_();
98 :
99 : // Resolve endpoint
100 741 : boost::asio::ip::tcp::endpoint endpoint;
101 : try {
102 741 : boost::asio::ip::tcp::resolver resolver(io_context);
103 741 : auto results = resolver.resolve(host_, port_);
104 738 : if (results.begin() == results.end()) {
105 0 : LOG_ERROR("no endpoints found for {}:{}", host_, port_);
106 0 : return false;
107 : }
108 :
109 738 : auto entry = *results.begin();
110 738 : endpoint = entry.endpoint();
111 744 : } catch (const boost::system::system_error& e) {
112 3 : LOG_ERROR("failed to resolve {}:{} - {}", host_, port_, e.code().message());
113 3 : return false;
114 3 : }
115 :
116 738 : bool success = false;
117 : do {
118 738 : LOG_DEBUG("starting TCP socket acceptor on {}:{}", host_, port_);
119 738 : if (num_attempts > 0) {
120 0 : std::this_thread::sleep_for(std::chrono::seconds(5));
121 : }
122 :
123 738 : acceptor_ = std::make_unique<boost::asio::ip::tcp::acceptor>(io_context);
124 738 : acceptor_->open(endpoint.protocol());
125 738 : acceptor_->set_option(boost::asio::ip::tcp::acceptor::reuse_address(true));
126 :
127 : try {
128 738 : LOG_DEBUG("binding and listening to endpoint: {}:{}",
129 : endpoint.address().to_string(), endpoint.port());
130 738 : acceptor_->bind(endpoint);
131 735 : acceptor_->listen();
132 735 : success = true;
133 3 : } catch (boost::system::system_error& error) {
134 3 : LOG_ERROR("cannot start listening on {}:{}: {}",
135 : host_, port_, error.code().message());
136 : // Reset acceptor if binding failed to avoid inconsistent state
137 3 : acceptor_.reset();
138 3 : if (max_listening_attempts_ >= 0 && num_attempts >= max_listening_attempts_) {
139 0 : return false;
140 : }
141 3 : }
142 738 : num_attempts++;
143 738 : } while (!success && (max_listening_attempts_ < 0 || num_attempts < max_listening_attempts_));
144 :
145 738 : if (success) {
146 735 : LOG_INFO("TCP server is now listening on {}:{}", host_, port_);
147 : }
148 :
149 738 : return success;
150 : }
151 :
152 1700 : void tcp_socket_server::accept_connection() {
153 : // Get next io_context from provider
154 1700 : boost::asio::io_context& io_context = connection_context_provider_();
155 :
156 : // Create socket based on SSL configuration
157 1700 : std::shared_ptr<tcp_socket> sock;
158 1700 : if (ssl_enabled_) {
159 104 : if (!ssl_context_) {
160 0 : LOG_ERROR("SSL enabled but no SSL context configured");
161 0 : return;
162 : }
163 104 : sock = std::make_shared<ssl_socket>("ssl_socket_server", io_context, ssl_context_);
164 : } else {
165 1596 : sock = std::make_shared<tcp_socket>("tcp_socket_server", io_context);
166 : }
167 :
168 1700 : auto& socket = sock->get_socket();
169 :
170 : // Start accepting a connection
171 1700 : acceptor_->async_accept(socket, [sock = std::move(sock), this](const boost::system::error_code& e) mutable {
172 1477 : if (!e) {
173 : // Get remote socket ip
174 968 : auto remote_ip = sock->get_remote_ip();
175 :
176 : // Check if IP is allowed
177 968 : if (!is_remote_allowed(remote_ip)) {
178 0 : sock->close();
179 0 : LOG_WARNING("rejecting connection from: ip: {}, port: {}, secure: {}",
180 : remote_ip, sock->get_local_port(), sock->is_secure());
181 0 : if (running_) accept_connection();
182 0 : return;
183 : }
184 :
185 968 : LOG_INFO("received connection from: ip: {}, port: {}, secure: {}",
186 : remote_ip, sock->get_local_port(), sock->is_secure());
187 :
188 968 : if (tcp_no_delay_) {
189 968 : sock->enable_tcp_no_delay();
190 : }
191 :
192 968 : if (sock->requires_handshake()) {
193 : // Use co_spawn to run the coroutine-based handshake
194 52 : co_spawn(sock->get_io_context(),
195 156 : [this, sock]() -> awaitable<void> {
196 : auto ec = co_await sock->handshake();
197 : if (ec) {
198 : LOG_ERROR("error while handling SSL handshake: {}, remote ip: {}",
199 : ec.message(), sock->get_remote_ip());
200 : co_return;
201 : }
202 : if (handler_) handler_(sock);
203 104 : },
204 : detached);
205 : } else {
206 916 : if (handler_) handler_(std::move(sock));
207 : }
208 :
209 : // Continue accepting connections
210 968 : if (running_) accept_connection();
211 968 : } else {
212 509 : if (e != boost::asio::error::operation_aborted) {
213 0 : LOG_ERROR("cannot accept more connections: {}", e.message());
214 0 : if (running_) {
215 : // Retry after a delay to avoid tight loop on persistent errors
216 : auto timer = std::make_shared<boost::asio::steady_timer>(
217 0 : acceptor_context_provider_(),
218 0 : std::chrono::seconds(1)
219 0 : );
220 0 : timer->async_wait([this, timer](const boost::system::error_code& e) {
221 0 : if (e != boost::asio::error::operation_aborted) {
222 0 : accept_connection();
223 : }
224 0 : });
225 0 : }
226 : } else {
227 509 : LOG_INFO("stop accepting connections");
228 : }
229 : }
230 : });
231 1700 : }
232 :
233 : } // namespace thinger::asio
|