#include "udp_socket_wrapper.h"
#include <asio_helpers.h>
#include <iostream>
#include <cinttypes>
/**
 * Creates the wrapper around a not-yet-opened UDP socket.
 *
 * @param fnOnReceive  Callback invoked with a pointer to the payload and its
 *                     size in bytes for every received datagram. It may be
 *                     empty; Open() only queues reads when it is set.
 */
cUdpSocketWrapper::cUdpSocketWrapper(std::function<void(const void*, size_t)> fnOnReceive)
    : m_oUdpSocket(m_oIoContext)
    , m_fnOnReceive(std::move(fnOnReceive))
{
    // Put the io_context into the stopped state right from the start, so that
    // m_oIoContext.stopped() reports true until it is explicitly restarted.
    m_oIoContext.stop();
}
/// Defaulted destructor: performs no explicit teardown of its own; every
/// member is released by its own destructor.
cUdpSocketWrapper::~cUdpSocketWrapper() = default;
/**
 * Opens, configures and binds the UDP socket and, when a receive callback was
 * supplied at construction, queues the initial asynchronous reads.
 *
 * @param strInterface              Address of the local interface. It selects
 *                                  the IP version of the socket, is the bind
 *                                  address for unicast (and for multicast on
 *                                  Windows) and names the interface used for
 *                                  the multicast join.
 * @param nLocalPort                Local port to bind to.
 * @param strMulticastGroup         Multicast group to join; an empty string
 *                                  results in a plain unicast bind.
 * @param nReceiveBufferSize        Size handed to each tPacketHandler that
 *                                  backs one receive slot.
 * @param nSocketReceiveBufferSize  Requested socket receive buffer size in
 *                                  bytes; 0 leaves the current size untouched.
 *
 * @return ERR_INVALID_STATE if the socket is already open, ERR_INVALID_ARG for
 *         a non-multicast group address or mismatching IP versions, the
 *         converted exception if an asio call throws, otherwise no error.
 *
 * NOTE(review): if a step after open() fails, the socket is left open and a
 * further Open() call reports "already open" - presumably the caller has to
 * close the socket before retrying; confirm.
 */
tResult cUdpSocketWrapper::Open(
    const std::string& strInterface,
    uint16_t nLocalPort,
    const std::string& strMulticastGroup,
    size_t nReceiveBufferSize,
    size_t nSocketReceiveBufferSize)
{
    if (m_oUdpSocket.is_open())
    {
        RETURN_ERROR_DESC(ERR_INVALID_STATE,
                          "Can't reconfigure socket - socket is already open.");
    }

    const auto oBindAddress = resolve_address(strInterface);
    const asio::ip::udp::socket::protocol_type protocol =
        oBindAddress.is_v4() ? asio::ip::udp::socket::protocol_type::v4()
                             : asio::ip::udp::socket::protocol_type::v6();

    // The socket has to be opened before any option can be set or it can be
    // bound; without this every following call fails on a closed descriptor.
    RETURN_IF_THROWS(m_oUdpSocket.open(protocol));
    RETURN_IF_THROWS(m_oUdpSocket.set_option(asio::ip::udp::socket::reuse_address(true)));

    if (nSocketReceiveBufferSize != 0)
    {
        try
        {
            asio::ip::udp::socket::receive_buffer_size oRcvBufferSizeOptionGet;
            m_oUdpSocket.get_option(oRcvBufferSizeOptionGet);
            m_oUdpSocket.set_option(
                asio::ip::udp::socket::receive_buffer_size(static_cast<int>(nSocketReceiveBufferSize)));
            LOG_INFO(
                "Sockets buffersize set from %d to %zu bytes.", oRcvBufferSizeOptionGet.value(),
                nSocketReceiveBufferSize);
        }
        // asio reports failures by throwing asio::system_error, never a bare
        // error_code, so catch everything like RETURN_IF_THROWS does.
        catch (...)
        {
            return adtf::base::current_exception_to_result();
        }
    }

    if (strMulticastGroup.empty())
    {
        RETURN_IF_THROWS(m_oUdpSocket.bind(asio::ip::udp::endpoint(oBindAddress, nLocalPort)));
    }
    else
    {
        const auto oMulticastAddress = resolve_address(strMulticastGroup);
        if (!oMulticastAddress.is_multicast())
        {
            RETURN_ERROR_DESC(ERR_INVALID_ARG,
                              "%s is not a multicast address.", strMulticastGroup.c_str());
        }
        if (!oBindAddress.is_unspecified() &&
            oMulticastAddress.is_v4() != oBindAddress.is_v4())
        {
            RETURN_ERROR_DESC(ERR_INVALID_ARG,
                              "Both the multicast address and the bind interface address need to be of the same ip version.");
        }

        // Windows binds to the interface address, all other platforms bind to
        // the multicast group address itself.
#ifdef _WIN32
        RETURN_IF_THROWS(m_oUdpSocket.bind(asio::ip::udp::endpoint(oBindAddress, nLocalPort)));
#else
        RETURN_IF_THROWS(m_oUdpSocket.bind(asio::ip::udp::endpoint(oMulticastAddress, nLocalPort)));
#endif

        if (oMulticastAddress.is_v4())
        {
            RETURN_IF_THROWS(m_oUdpSocket.set_option(
                asio::ip::multicast::join_group(oMulticastAddress.to_v4(), oBindAddress.to_v4())));
        }
        else
        {
            RETURN_IF_THROWS(m_oUdpSocket.set_option(
                asio::ip::multicast::join_group(oMulticastAddress.to_v6(), oBindAddress.to_v6().scope_id())));
        }
    }

    if (m_fnOnReceive)
    {
        // Allocate one handler per queue slot and register all their buffers
        // with the io_context in a single registration.
        std::vector<asio::mutable_buffer> oOriginalBuffers;
        for (size_t i = 0; i < m_nQueueDepth; ++i)
        {
            m_oHandlers[i] = std::make_shared<tPacketHandler>(i, nReceiveBufferSize);
            oOriginalBuffers.emplace_back(m_oHandlers[i]->oBuffer.data(), m_oHandlers[i]->oBuffer.size());
        }
        m_pAsioBufferRegistration = std::make_unique<asio::buffer_registration<std::vector<asio::mutable_buffer>>>(
            asio::register_buffers(m_oIoContext, oOriginalBuffers));
        for (size_t i = 0; i < m_nQueueDepth; ++i)
        {
            m_oHandlers[i]->oRegisteredBuffer = (*m_pAsioBufferRegistration)[i];
        }

        // Either one readiness wait that drains via recvmmsg(), or one
        // pending async receive per handler.
#if defined(OS_HAS_RECVMMSG) && !defined(ASIO_HAS_IO_URING_AS_DEFAULT)
        ScheduleReadMany();
#else
        for (auto& oBuffer : m_oHandlers)
        {
            ScheduleReadOne(oBuffer);
        }
#endif
    }

    RETURN_NOERROR;
}
/**
 * Sets the fixed remote endpoint that Send() addresses.
 *
 * @param strRemoteAddress  Remote address, resolved with resolve_address().
 * @param nRemotePort       Remote port.
 *
 * @return The converted exception if resolving the address or building the
 *         endpoint throws, otherwise no error.
 */
tResult cUdpSocketWrapper::SetRemote(
    const std::string& strRemoteAddress, uint16_t nRemotePort)
{
    RETURN_IF_THROWS(
        m_oUdpRemoteEndpoint = asio::ip::udp::endpoint(resolve_address(strRemoteAddress), nRemotePort));
    // Explicit success return: falling off the end of a tResult function is
    // undefined behaviour.
    RETURN_NOERROR;
}
/**
 * Sends one datagram to the remote side.
 *
 * The destination is the endpoint configured through SetRemote() if there is
 * one, otherwise the sender of the most recently received datagram.
 *
 * @param pData      Pointer to the payload.
 * @param nDataSize  Payload size in bytes.
 *
 * @return ERR_INVALID_STATE if no destination is known yet, ERR_INVALID_ARG if
 *         fewer than nDataSize bytes were sent, the converted exception if
 *         send_to() throws, otherwise no error.
 */
tResult cUdpSocketWrapper::Send(
    const void* pData,
    size_t nDataSize)
{
    asio::ip::udp::endpoint oDestination;
    if (m_oUdpRemoteEndpoint)
    {
        oDestination = *m_oUdpRemoteEndpoint;
    }
    else if (m_oLastReceivedUdpRemoteEndpoint)
    {
        oDestination = *m_oLastReceivedUdpRemoteEndpoint;
    }
    else
    {
        RETURN_ERROR_DESC(ERR_INVALID_STATE,
                          "Unable to send data via UDP, remote endpoint is not yet known.");
    }

    size_t nDataSent = 0;
    RETURN_IF_THROWS(nDataSent = m_oUdpSocket.send_to(asio::buffer(pData, nDataSize), oDestination));
    if (nDataSent != nDataSize)
    {
        // %zu is the matching conversion for size_t on every platform.
        RETURN_ERROR_DESC(ERR_INVALID_ARG,
                          "Unable to send data via UDP, data size is too large: %zu", nDataSize);
    }

    // Explicit success return: falling off the end of a tResult function is
    // undefined behaviour.
    RETURN_NOERROR;
}
{
std::unique_lock oLock(m_oRunningMutex);
if (m_bRunning)
{
}
if (m_oIoContext.stopped())
{
}
m_bRunning = true;
m_oRunningCv.notify_all();
}
{
std::unique_lock oLock(m_oRunningMutex);
{
oLock.unlock();
m_oIoContext.run();
oLock.lock();
m_bRunning = false;
m_oRunningCv.notify_all();
}
else
{
}
}
{
{
std::unique_lock oLock(m_oRunningMutex);
m_oIoContext.stop();
m_oRunningCv.wait(oLock, [this]() -> bool { return !m_bRunning; });
}
}
#if defined(OS_HAS_RECVMMSG)
void cUdpSocketWrapper::ScheduleReadMany() noexcept
{
m_oUdpSocket.async_wait(asio::ip::udp::socket::wait_read, std::bind(&cUdpSocketWrapper::ReadMany, this, std::placeholders::_1));
}
/**
 * Completion handler of the read-readiness wait queued by ScheduleReadMany().
 *
 * When the wait succeeded, the socket is drained with non-blocking recvmmsg()
 * calls of up to m_nQueueDepth datagrams each, until a call returns no
 * datagram. For every datagram the sender is stored as the last received
 * remote endpoint before the receive callback runs, so the callback sees the
 * sender of the datagram it is handling (same order as OnReceiveOne()).
 *
 * @param nError  Result of the wait. Unless it is operation_aborted, the next
 *                wait is queued before returning.
 */
void cUdpSocketWrapper::ReadMany(const std::error_code& nError) noexcept
{
    if (!nError)
    {
        while (true)
        {
            // Rebuilt for every batch: recvmmsg() overwrites msg_namelen and
            // msg_len, so the headers cannot be reused as they are.
            std::array<sockaddr_storage, m_nQueueDepth> oAddresses = {};
            std::array<iovec, m_nQueueDepth> oBuffers = {};
            std::array<mmsghdr, m_nQueueDepth> oMessages = {};
            for (size_t i = 0; i < m_nQueueDepth; ++i)
            {
                oBuffers[i].iov_base = m_oHandlers[i]->oBuffer.data();
                oBuffers[i].iov_len = m_oHandlers[i]->oBuffer.size();

                oMessages[i].msg_hdr.msg_name = &oAddresses[i];
                oMessages[i].msg_hdr.msg_namelen = sizeof(sockaddr_storage);
                oMessages[i].msg_hdr.msg_iov = &oBuffers[i];
                oMessages[i].msg_hdr.msg_iovlen = 1;
                oMessages[i].msg_len = 0;
            }

            const auto nReceived =
                ::recvmmsg(m_oUdpSocket.native_handle(), oMessages.data(), oMessages.size(), MSG_DONTWAIT, nullptr);
            if (nReceived <= 0)
            {
                // Nothing left to read (or an error): stop draining.
                break;
            }

            const auto nCount = static_cast<size_t>(nReceived);
            for (size_t i = 0; i < nCount; ++i)
            {
                const auto& oMessage = oMessages[i];

                // The kernel delivers addresses and ports in network byte
                // order, whereas asio::ip::address_v4(uint) and the endpoint
                // constructor expect host byte order - convert explicitly.
                const auto pAddress = static_cast<const sockaddr_storage*>(oMessage.msg_hdr.msg_name);
                if (pAddress)
                {
                    if (pAddress->ss_family == AF_INET)
                    {
                        const auto pIPV4 = reinterpret_cast<const sockaddr_in*>(pAddress);
                        const uint32_t nHostAddress = ntohl(pIPV4->sin_addr.s_addr);
                        const uint16_t nHostPort = ntohs(pIPV4->sin_port);
                        m_oLastReceivedUdpRemoteEndpoint = {asio::ip::address_v4(nHostAddress),
                                                            nHostPort};
                    }
                    else if (pAddress->ss_family == AF_INET6)
                    {
                        const auto pIPV6 = reinterpret_cast<const sockaddr_in6*>(pAddress);
                        const uint16_t nHostPort = ntohs(pIPV6->sin6_port);
                        // The 16 address bytes are already in the order
                        // address_v6 expects; only the port needs converting.
                        m_oLastReceivedUdpRemoteEndpoint = {
                            asio::ip::address_v6(
                                reinterpret_cast<const asio::ip::address_v6::bytes_type&>(pIPV6->sin6_addr.s6_addr),
                                pIPV6->sin6_scope_id),
                            nHostPort};
                    }
                }

                if (m_fnOnReceive && oMessage.msg_hdr.msg_iovlen > 0)
                {
                    m_fnOnReceive(oMessage.msg_hdr.msg_iov[0].iov_base, oMessage.msg_len);
                }
            }
        }
    }

    if (nError != asio::error::operation_aborted)
    {
        ScheduleReadMany();
    }
}
#endif
void cUdpSocketWrapper::ScheduleReadOne(const std::shared_ptr<tPacketHandler>& oBuffer) noexcept
{
m_oUdpSocket.async_receive_from(
oBuffer->oRegisteredBuffer, oBuffer->oEndpoint,
std::bind(&cUdpSocketWrapper::OnReceiveOne, this, std::placeholders::_1, std::placeholders::_2, oBuffer)
);
}
void cUdpSocketWrapper::OnReceiveOne(
const std::error_code& nError,
std::size_t nBytesTransferred,
const std::shared_ptr<tPacketHandler>& oBuffer) noexcept
{
if (!nError)
{
m_oLastReceivedUdpRemoteEndpoint = oBuffer->oEndpoint;
if (m_fnOnReceive)
{
m_fnOnReceive(oBuffer->oBuffer.data(), nBytesTransferred);
}
}
if (nError != asio::error::operation_aborted)
{
ScheduleReadOne(oBuffer);
}
}
#define LOG_INFO(...)
Logs an info message.
#define RETURN_ERROR(code)
Return specific error code, which requires the calling function's return type to be tResult.
std::chrono::milliseconds milliseconds
Compatibility with C++11 std::chrono::milliseconds
#define RETURN_IF_THROWS(s)
If the expression throws an exception, returns a tResult containing the exception information.