diff --git a/cereal/SConscript b/cereal/SConscript index a58a9490ce..73dc61844b 100644 --- a/cereal/SConscript +++ b/cereal/SConscript @@ -13,7 +13,7 @@ cereal = env.Library('cereal', [f'gen/cpp/{s}.c++' for s in schema_files]) # Build messaging services_h = env.Command(['services.h'], ['services.py'], 'python3 ' + cereal_dir.path + '/services.py > $TARGET') -env.Program('messaging/bridge', ['messaging/bridge.cc', 'messaging/msgq_to_zmq.cc'], LIBS=[msgq, common, 'pthread']) +env.Program('messaging/bridge', ['messaging/bridge.cc', 'messaging/msgq_to_zmq.cc', 'messaging/bridge_zmq.cc'], LIBS=[msgq, common, 'pthread']) socketmaster = env.Library('socketmaster', ['messaging/socketmaster.cc']) diff --git a/cereal/messaging/bridge.cc b/cereal/messaging/bridge.cc index fb92c575c9..2eb3de32c0 100644 --- a/cereal/messaging/bridge.cc +++ b/cereal/messaging/bridge.cc @@ -25,14 +25,14 @@ void msgq_to_zmq(const std::vector &endpoints, const std::string &i } void zmq_to_msgq(const std::vector &endpoints, const std::string &ip) { - auto poller = std::make_unique(); + auto poller = std::make_unique(); auto pub_context = std::make_unique(); - auto sub_context = std::make_unique(); - std::map sub2pub; + auto sub_context = std::make_unique(); + std::map sub2pub; for (auto endpoint : endpoints) { auto pub_sock = new MSGQPubSocket(); - auto sub_sock = new ZMQSubSocket(); + auto sub_sock = new BridgeZmqSubSocket(); size_t queue_size = services.at(endpoint).queue_size; pub_sock->connect(pub_context.get(), endpoint, true, queue_size); sub_sock->connect(sub_context.get(), endpoint, ip, false); diff --git a/cereal/messaging/bridge_zmq.cc b/cereal/messaging/bridge_zmq.cc new file mode 100644 index 0000000000..5c56673b47 --- /dev/null +++ b/cereal/messaging/bridge_zmq.cc @@ -0,0 +1,170 @@ +#include "cereal/messaging/bridge_zmq.h" + +#include +#include +#include + +static size_t fnv1a_hash(const std::string &str) { + const size_t fnv_prime = 0x100000001b3; + size_t hash_value = 0xcbf29ce484222325; + for (char c : str) { + hash_value ^= (unsigned char)c; + hash_value *= fnv_prime; + } + return hash_value; +} + +// FIXME: This is a hack to get the port number from the socket name, might have collisions. +static int get_port(std::string endpoint) { + size_t hash_value = fnv1a_hash(endpoint); + int start_port = 8023; + int max_port = 65535; + return start_port + (hash_value % (max_port - start_port)); +} + +BridgeZmqContext::BridgeZmqContext() { + context = zmq_ctx_new(); +} + +BridgeZmqContext::~BridgeZmqContext() { + if (context != nullptr) { + zmq_ctx_term(context); + } +} + +void BridgeZmqMessage::init(size_t sz) { + size = sz; + data = new char[size]; +} + +void BridgeZmqMessage::init(char *d, size_t sz) { + size = sz; + data = new char[size]; + memcpy(data, d, size); +} + +void BridgeZmqMessage::close() { + if (size > 0) { + delete[] data; + } + data = nullptr; + size = 0; +} + +BridgeZmqMessage::~BridgeZmqMessage() { + close(); +} + +int BridgeZmqSubSocket::connect(BridgeZmqContext *context, std::string endpoint, std::string address, bool conflate, bool check_endpoint) { + sock = zmq_socket(context->getRawContext(), ZMQ_SUB); + if (sock == nullptr) { + return -1; + } + + zmq_setsockopt(sock, ZMQ_SUBSCRIBE, "", 0); + + if (conflate) { + int arg = 1; + zmq_setsockopt(sock, ZMQ_CONFLATE, &arg, sizeof(int)); + } + + int reconnect_ivl = 500; + zmq_setsockopt(sock, ZMQ_RECONNECT_IVL_MAX, &reconnect_ivl, sizeof(reconnect_ivl)); + + full_endpoint = "tcp://" + address + ":"; + if (check_endpoint) { + full_endpoint += std::to_string(get_port(endpoint)); + } else { + full_endpoint += endpoint; + } + + return zmq_connect(sock, full_endpoint.c_str()); +} + +void BridgeZmqSubSocket::setTimeout(int timeout) { + zmq_setsockopt(sock, ZMQ_RCVTIMEO, &timeout, sizeof(int)); +} + +Message *BridgeZmqSubSocket::receive(bool non_blocking) { + zmq_msg_t msg; + assert(zmq_msg_init(&msg) == 0); + + int flags = non_blocking ? ZMQ_DONTWAIT : 0; + int rc = zmq_msg_recv(&msg, sock, flags); + + Message *ret = nullptr; + if (rc >= 0) { + ret = new BridgeZmqMessage; + ret->init((char *)zmq_msg_data(&msg), zmq_msg_size(&msg)); + } + + zmq_msg_close(&msg); + return ret; +} + +BridgeZmqSubSocket::~BridgeZmqSubSocket() { + if (sock != nullptr) { + zmq_close(sock); + } +} + +int BridgeZmqPubSocket::connect(BridgeZmqContext *context, std::string endpoint, bool check_endpoint) { + sock = zmq_socket(context->getRawContext(), ZMQ_PUB); + if (sock == nullptr) { + return -1; + } + + full_endpoint = "tcp://*:"; + if (check_endpoint) { + full_endpoint += std::to_string(get_port(endpoint)); + } else { + full_endpoint += endpoint; + } + + // ZMQ pub sockets cannot be shared between processes, so we need to ensure pid stays the same. + pid = getpid(); + + return zmq_bind(sock, full_endpoint.c_str()); +} + +int BridgeZmqPubSocket::sendMessage(Message *message) { + assert(pid == getpid()); + return zmq_send(sock, message->getData(), message->getSize(), ZMQ_DONTWAIT); +} + +int BridgeZmqPubSocket::send(char *data, size_t size) { + assert(pid == getpid()); + return zmq_send(sock, data, size, ZMQ_DONTWAIT); +} + +BridgeZmqPubSocket::~BridgeZmqPubSocket() { + if (sock != nullptr) { + zmq_close(sock); + } +} + +void BridgeZmqPoller::registerSocket(BridgeZmqSubSocket *socket) { + assert(num_polls + 1 < (sizeof(polls) / sizeof(polls[0]))); + polls[num_polls].socket = socket->getRawSocket(); + polls[num_polls].events = ZMQ_POLLIN; + + sockets.push_back(socket); + num_polls++; +} + +std::vector BridgeZmqPoller::poll(int timeout) { + std::vector ret; + + int rc = zmq_poll(polls, num_polls, timeout); + if (rc < 0) { + return ret; + } + + for (size_t i = 0; i < num_polls; i++) { + if (polls[i].revents) { + ret.push_back(sockets[i]); + } + } + + return ret; +} diff --git a/cereal/messaging/bridge_zmq.h b/cereal/messaging/bridge_zmq.h new file mode 100644 index 0000000000..eab48a91e8 --- /dev/null +++ b/cereal/messaging/bridge_zmq.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include +#include + +#include + +#include "msgq/ipc.h" + +class BridgeZmqContext { +public: + BridgeZmqContext(); + void *getRawContext() { return context; } + ~BridgeZmqContext(); + +private: + void *context = nullptr; +}; + +class BridgeZmqMessage : public Message { +public: + void init(size_t size) override; + void init(char *data, size_t size) override; + void close() override; + size_t getSize() override { return size; } + char *getData() override { return data; } + ~BridgeZmqMessage() override; + +private: + char *data = nullptr; + size_t size = 0; +}; + +class BridgeZmqSubSocket { +public: + int connect(BridgeZmqContext *context, std::string endpoint, std::string address, bool conflate = false, bool check_endpoint = true); + void setTimeout(int timeout); + Message *receive(bool non_blocking = false); + void *getRawSocket() { return sock; } + ~BridgeZmqSubSocket(); + +private: + void *sock = nullptr; + std::string full_endpoint; +}; + +class BridgeZmqPubSocket { +public: + int connect(BridgeZmqContext *context, std::string endpoint, bool check_endpoint = true); + int sendMessage(Message *message); + int send(char *data, size_t size); + void *getRawSocket() { return sock; } + ~BridgeZmqPubSocket(); + +private: + void *sock = nullptr; + std::string full_endpoint; + int pid = -1; +}; + +class BridgeZmqPoller { +public: + void registerSocket(BridgeZmqSubSocket *socket); + std::vector poll(int timeout); + +private: + static constexpr size_t MAX_BRIDGE_ZMQ_POLLERS = 128; + std::vector sockets; + zmq_pollitem_t polls[MAX_BRIDGE_ZMQ_POLLERS] = {}; + size_t num_polls = 0; +}; diff --git a/cereal/messaging/msgq_to_zmq.cc b/cereal/messaging/msgq_to_zmq.cc index 7f8c738d4d..930b617e7b 100644 --- a/cereal/messaging/msgq_to_zmq.cc +++ b/cereal/messaging/msgq_to_zmq.cc @@ -22,14 +22,14 @@ static std::string recv_zmq_msg(void *sock) { } void MsgqToZmq::run(const std::vector &endpoints, const std::string &ip) { - zmq_context = std::make_unique(); + zmq_context = std::make_unique(); msgq_context = std::make_unique(); // Create ZMQPubSockets for each endpoint for (const auto &endpoint : endpoints) { auto &socket_pair = socket_pairs.emplace_back(); socket_pair.endpoint = endpoint; - socket_pair.pub_sock = std::make_unique(); + socket_pair.pub_sock = std::make_unique(); int ret = socket_pair.pub_sock->connect(zmq_context.get(), endpoint); if (ret != 0) { printf("Failed to create ZMQ publisher for [%s]: %s\n", endpoint.c_str(), zmq_strerror(zmq_errno())); @@ -49,7 +49,7 @@ void MsgqToZmq::run(const std::vector &endpoints, const std::string for (auto sub_sock : msgq_poller->poll(100)) { // Process messages for each socket - ZMQPubSocket *pub_sock = sub2pub.at(sub_sock); + BridgeZmqPubSocket *pub_sock = sub2pub.at(sub_sock); for (int i = 0; i < MAX_MESSAGES_PER_SOCKET; ++i) { auto msg = std::unique_ptr(sub_sock->receive(true)); if (!msg) break; @@ -72,7 +72,7 @@ void MsgqToZmq::zmqMonitorThread() { // Set up ZMQ monitor for each pub socket for (int i = 0; i < socket_pairs.size(); ++i) { std::string addr = "inproc://op-bridge-monitor-" + std::to_string(i); - zmq_socket_monitor(socket_pairs[i].pub_sock->sock, addr.c_str(), ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_DISCONNECTED); + zmq_socket_monitor(socket_pairs[i].pub_sock->getRawSocket(), addr.c_str(), ZMQ_EVENT_ACCEPTED | ZMQ_EVENT_DISCONNECTED); void *monitor_socket = zmq_socket(zmq_context->getRawContext(), ZMQ_PAIR); zmq_connect(monitor_socket, addr.c_str()); @@ -130,7 +130,7 @@ void MsgqToZmq::zmqMonitorThread() { // Clean up monitor sockets for (int i = 0; i < pollitems.size(); ++i) { - zmq_socket_monitor(socket_pairs[i].pub_sock->sock, nullptr, 0); + zmq_socket_monitor(socket_pairs[i].pub_sock->getRawSocket(), nullptr, 0); zmq_close(pollitems[i].socket); } cv.notify_one(); diff --git a/cereal/messaging/msgq_to_zmq.h b/cereal/messaging/msgq_to_zmq.h index ebdbe5df69..ac92631fe7 100644 --- a/cereal/messaging/msgq_to_zmq.h +++ b/cereal/messaging/msgq_to_zmq.h @@ -7,9 +7,8 @@ #include #include -#define private public #include "msgq/impl_msgq.h" -#include "msgq/impl_zmq.h" +#include "cereal/messaging/bridge_zmq.h" class MsgqToZmq { public: @@ -22,16 +21,16 @@ protected: struct SocketPair { std::string endpoint; - std::unique_ptr pub_sock; + std::unique_ptr pub_sock; std::unique_ptr sub_sock; int connected_clients = 0; }; std::unique_ptr msgq_context; - std::unique_ptr zmq_context; + std::unique_ptr zmq_context; std::mutex mutex; std::condition_variable cv; std::unique_ptr msgq_poller; - std::map sub2pub; + std::map sub2pub; std::vector socket_pairs; };