Kylin/AsioZeroMQ/BasicSocket.inl
2023-07-21 14:07:27 +08:00

332 lines
12 KiB
C++

#ifndef BASICSOCKET_INL
#define BASICSOCKET_INL
#include "BasicSocket.h"
#include "ErrorCode.h"
#include "SocketService.h"
#include <boost/asio/posix/stream_descriptor.hpp>
#include <boost/asio/post.hpp>
#include <string>
#include <zmq.h>
#include "BoostLog.h"
namespace ZeroMQ {
/// Creates a socket of the requested type.
///
/// The Service instance is obtained from @p context through
/// boost::asio::use_service, and is then asked to construct the
/// implementation object stored in m_impl.
///
/// @param context io_context whose Service instance backs this socket.
/// @param type    Socket type forwarded to Service::construct.
template <typename Service>
BasicSocket<Service>::BasicSocket(boost::asio::io_context &context, SocketType type)
    : m_service(boost::asio::use_service<Service>(context)) {
    m_service.construct(m_impl, type);
}
/// Hands the implementation object back to the service for destruction.
template <typename Service>
BasicSocket<Service>::~BasicSocket() {
    m_service.destroy(m_impl);
}
/// Connects the socket to @p address, delegating to the service.
///
/// @param address Endpoint to connect to.
/// @param error   Passed through to Service::connect, which reports failures in it.
template <typename Service>
void BasicSocket<Service>::connect(std::string_view address, boost::system::error_code &error) {
    m_service.connect(m_impl, std::move(address), error);
}
/// Throwing flavour of connect().
///
/// @param address Endpoint to connect to.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
void BasicSocket<Service>::connect(std::string_view address) {
    boost::system::error_code ec;
    connect(address, ec);
    if (ec) {
        throw ec;
    }
}
/// Reports whether the socket has a usable implementation.
///
/// @return true when both m_impl and the socket handle it holds are non-null.
template <typename Service>
bool BasicSocket<Service>::connected() const {
    if (m_impl == nullptr) {
        return false;
    }
    return m_impl->socket != nullptr;
}
/// Binds the socket to @p address while holding the implementation mutex.
///
/// @param address Endpoint to bind to.
/// @param error   Set from makeErrorCode() when zmq_bind fails; left untouched on success.
template <typename Service>
void BasicSocket<Service>::bind(std::string_view address, boost::system::error_code &error) {
    std::lock_guard lock_guard(m_impl->mutex);
    // zmq_bind expects a NUL-terminated C string, which std::string_view::data()
    // does not guarantee; copy into a std::string so the terminator is present.
    const auto status = zmq_bind(m_impl->socket, std::string(address).c_str());
    if (status < 0) {
        error = makeErrorCode();
    }
}
/// Throwing flavour of bind().
///
/// @param address Endpoint to bind to.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
void BasicSocket<Service>::bind(std::string_view address) {
    boost::system::error_code ec;
    bind(address, ec);
    if (ec) {
        throw ec;
    }
}
/// Sets an integral socket option through the service.
///
/// The option tag argument is used only for overload selection; the option id
/// comes from the @c Option template parameter.
///
/// @param value Value to set; its address and sizeof are handed to the service.
/// @param error Passed through to Service::setOption.
template <typename Service>
template <int Option, class T, bool BoolUnit>
void BasicSocket<Service>::setOption(IntegralOption<Option, T, BoolUnit>, const T &value,
                                     boost::system::error_code &error) {
    static_assert(std::is_integral_v<T>, "T must be integral");
    const void *const data = &value;
    m_service.setOption(m_impl, Option, data, sizeof(value), error);
}
/// Throwing flavour of the integral setOption().
///
/// @param value Value to set.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
template <int Option, class T, bool BoolUnit>
void BasicSocket<Service>::setOption(IntegralOption<Option, T, BoolUnit>, const T &value) {
    using Tag = IntegralOption<Option, T, BoolUnit>;
    boost::system::error_code ec;
    setOption(Tag(), value, ec);
    if (ec) {
        throw ec;
    }
}
/// Sets a byte-array socket option through the service.
///
/// @param buffer Bytes to set; its data pointer and size are handed to the service.
/// @param error  Passed through to Service::setOption.
template <typename Service>
template <int Option, int NullTerm>
void BasicSocket<Service>::setOption(ArrayOption<Option, NullTerm>, const std::string_view &buffer,
                                     boost::system::error_code &error) {
    const auto *const bytes = buffer.data();
    const auto length = buffer.size();
    m_service.setOption(m_impl, Option, bytes, length, error);
}
/// Reads an integral socket option through the service.
///
/// The option tag argument is used only for overload selection; the option id
/// comes from the @c Option template parameter.
///
/// @param error Passed through to Service::option.
/// @return The value written by the service, or a value-initialised T (zero)
///         if the service did not write one.
template <typename Service>
template <int Option, class T, bool BoolUnit>
T BasicSocket<Service>::option(IntegralOption<Option, T, BoolUnit>, boost::system::error_code &error) const {
    static_assert(std::is_integral<T>::value, "T must be integral");
    // Value-initialise so a failed query never returns an indeterminate value.
    T value{};
    size_t size = sizeof value;
    m_service.option(m_impl, Option, &value, &size, error);
    // The size check is only meaningful when the query succeeded.
    assert(error || size == sizeof value);
    return value;
}
/// Throwing flavour of option().
///
/// @return The option value obtained from the non-throwing overload.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
template <int Option, class T, bool BoolUnit>
T BasicSocket<Service>::option(IntegralOption<Option, T, BoolUnit>) const {
    using Tag = IntegralOption<Option, T, BoolUnit>;
    boost::system::error_code ec;
    auto result = option(Tag(), ec);
    if (ec) {
        throw ec;
    }
    return result;
}
/// @return The io_context reported by the underlying service.
template <typename Service>
boost::asio::io_context &BasicSocket<Service>::ioContext() const {
    boost::asio::io_context &context = m_service.get_io_context();
    return context;
}
/// Sends the contents of a single buffer with zmq_send.
///
/// @param buffer Bytes to send.
/// @param flags  Send flags, cast to int and passed to zmq_send.
/// @param error  Set from makeErrorCode() on failure; left untouched on success.
/// @return Number of bytes reported by zmq_send, or 0 on failure.
template <typename Service>
size_t BasicSocket<Service>::send(boost::asio::const_buffer buffer, SendFlags flags, boost::system::error_code &error) {
    const int nbytes = zmq_send(m_impl->socket, buffer.data(), buffer.size(), static_cast<int>(flags));
    if (nbytes < 0) {
        error = makeErrorCode();
        // Do not leak the negative status through the unsigned return type,
        // where it would read as an enormous byte count.
        return 0;
    }
    return static_cast<size_t>(nbytes);
}
/// Sends every buffer of a sequence as one multipart message.
///
/// Every buffer except the last is sent with ZMQ_SNDMORE OR-ed into @p flags.
///
/// @param buffers Buffer sequence to send.
/// @param flags   Base send flags applied to every part.
/// @param error   Set by the single-buffer send() on failure.
/// @return Total number of bytes sent, or 0 as soon as any part fails.
template <typename Service>
template <typename ConstBufferSequence>
typename boost::enable_if<boost::has_range_const_iterator<ConstBufferSequence>, size_t>::type
BasicSocket<Service>::send(const ConstBufferSequence &buffers, SendFlags flags, boost::system::error_code &error) {
    const int baseFlags = static_cast<int>(flags);
    const auto finish = std::end(buffers);
    size_t total = 0;
    for (auto current = std::begin(buffers); current != finish; ++current) {
        // Only the final part goes out without the "more parts follow" marker.
        const bool isLastPart = (std::next(current) == finish);
        const int partFlags = isLastPart ? baseFlags : (baseFlags | ZMQ_SNDMORE);
        total += send(*current, static_cast<SendFlags>(partFlags), error);
        if (error) {
            return 0u;
        }
    }
    return total;
}
/// Throwing flavour of the buffer-sequence send().
///
/// @return Total number of bytes sent.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
template <typename ConstBufferSequence>
typename boost::enable_if<boost::has_range_const_iterator<ConstBufferSequence>, size_t>::type
BasicSocket<Service>::send(const ConstBufferSequence &buffers, SendFlags flags) {
    boost::system::error_code ec;
    const auto sent = send(buffers, flags, ec);
    if (ec) {
        throw ec;
    }
    return sent;
}
/// Sends a Message with zmq_msg_send.
///
/// @param message Message whose handle is passed to zmq_msg_send.
/// @param flags   Send flags, cast to int and passed to zmq_msg_send.
/// @param error   Set from makeErrorCode() on failure; left untouched on success.
/// @return Number of bytes reported by zmq_msg_send, or 0 on failure.
template <typename Service>
size_t BasicSocket<Service>::send(Message &&message, SendFlags flags, boost::system::error_code &error) {
    const int nbytes = zmq_msg_send(message.handle(), m_impl->socket, static_cast<int>(flags));
    if (nbytes < 0) {
        error = makeErrorCode();
        // Do not leak the negative status through the unsigned return type.
        return 0;
    }
    return static_cast<size_t>(nbytes);
}
/// Throwing flavour of the Message send().
///
/// @return Number of bytes sent.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
size_t BasicSocket<Service>::send(Message &&message, SendFlags flags) {
    boost::system::error_code ec;
    const auto sent = send(std::move(message), flags, ec);
    if (ec) {
        throw ec;
    }
    return sent;
}
/// Receives one message part into @p message while holding the implementation mutex.
///
/// @param message Destination message; its handle is passed to zmq_msg_recv.
/// @param flags   Receive flags, cast to int and passed to zmq_msg_recv.
/// @param error   Set from makeErrorCode() on failure; left untouched on success.
/// @return Number of bytes reported by zmq_msg_recv, or 0 on failure.
template <typename Service>
std::size_t BasicSocket<Service>::receive(Message &message, RecvFlags flags, boost::system::error_code &error) {
    std::lock_guard lock_guard(m_impl->mutex);
    BOOST_ASSERT_MSG(m_impl->socket, "Invalid socket");
    const int nbytes = zmq_msg_recv(message.handle(), m_impl->socket, static_cast<int>(flags));
    if (nbytes < 0) {
        error = makeErrorCode();
        // Do not leak the negative status through the unsigned return type.
        return 0;
    }
    return static_cast<std::size_t>(nbytes);
}
/// Throwing flavour of the Message receive().
///
/// @return Number of bytes received.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
std::size_t BasicSocket<Service>::receive(Message &message, RecvFlags flags) {
    boost::system::error_code ec;
    const auto received = receive(message, flags, ec);
    if (ec) {
        throw ec;
    }
    return received;
}
/// Receives one message part into a caller-supplied buffer with zmq_recv.
///
/// @param buffer Destination buffer.
/// @param flags  Receive flags, cast to int and passed to zmq_recv.
/// @param error  Set from makeErrorCode() on failure; left untouched on success.
/// @return Value reported by zmq_recv, or 0 on failure.
template <typename Service>
size_t BasicSocket<Service>::receive(boost::asio::mutable_buffer buffer, RecvFlags flags,
                                     boost::system::error_code &error) {
    const int nbytes = zmq_recv(m_impl->socket, buffer.data(), buffer.size(), static_cast<int>(flags));
    if (nbytes < 0) {
        error = makeErrorCode();
        // Do not leak the negative status through the unsigned return type.
        return 0;
    }
    return static_cast<size_t>(nbytes);
}
/// Receives a multipart message into a sequence of buffers, one part per buffer.
///
/// Parts are read until the socket reports no further parts or the buffers run
/// out. If parts remain after the last buffer has been filled, or if
/// @p buffers is empty, @p error is set to no_buffer_space.
///
/// @param buffers Destination buffers, filled in order.
/// @param flags   Receive flags applied to every part.
/// @param error   Set on receive failure or when there are too few buffers.
/// @return The size reported for each part received so far (also on error).
/// @note The ReceiveMore query uses the throwing option() overload, so this
///       function can still throw if that query fails.
template <typename Service>
template <typename MutableBufferSequence>
typename boost::enable_if<boost::has_range_const_iterator<MutableBufferSequence>, std::vector<size_t>>::type
BasicSocket<Service>::receive(const MutableBufferSequence &buffers, RecvFlags flags, boost::system::error_code &error) {
    std::vector<size_t> ret;
    auto iterator = std::begin(buffers);
    const auto last = std::end(buffers);
    if (iterator == last) {
        // No buffer to receive into: report it instead of dereferencing end().
        error = makeErrorCode(boost::system::errc::no_buffer_space);
        return ret;
    }
    bool more = false;
    do {
        const auto size = receive(*iterator, flags, error);
        if (error) return ret;
        ret.push_back(size);
        ++iterator;
        // Ask once per part whether further parts of the same message follow.
        more = static_cast<bool>(option(ReceiveMore));
    } while (more && iterator != last);
    if (more) error = makeErrorCode(boost::system::errc::no_buffer_space);
    return ret;
}
/// Throwing flavour of the buffer-sequence receive().
///
/// @return The per-part sizes produced by the non-throwing overload.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
template <typename MutableBufferSequence>
typename boost::enable_if<boost::has_range_const_iterator<MutableBufferSequence>, std::vector<size_t>>::type
BasicSocket<Service>::receive(const MutableBufferSequence &buffers, RecvFlags flags) {
    boost::system::error_code ec;
    auto partSizes = receive(buffers, flags, ec);
    if (ec) {
        throw ec;
    }
    return partSizes;
}
/// Throwing flavour of the single-buffer receive().
///
/// @return The value returned by the non-throwing overload.
/// @throws boost::system::error_code when the non-throwing overload reports an error.
template <typename Service>
size_t BasicSocket<Service>::receive(const boost::asio::mutable_buffer &buffer, RecvFlags flags) {
    boost::system::error_code ec;
    const auto received = receive(std::move(buffer), flags, ec);
    if (ec) {
        throw ec;
    }
    return received;
}
/// Starts an asynchronous receive into a sequence of buffers.
///
/// If the socket's Events option already has ZMQ_POLLIN set, a task is posted
/// to the service's io_context that performs a Dontwait receive and calls
/// handler(error, size). Otherwise the function waits for read-readiness on
/// m_impl->descriptor and then either receives or starts over.
///
/// @param buffers Destination buffers; captured by reference in both lambdas.
/// @param handler Called as handler(error_code, sizes).
///
/// NOTE(review): buffers is captured by reference, so it presumably has to stay
/// alive until the handler has run — confirm against callers.
/// NOTE(review): option(Events) is the throwing overload, so a failed query
/// would throw from inside the wait callback — confirm this is intended.
template <typename Service>
template <typename MutableBufferSequence, typename ReadHandler>
void BasicSocket<Service>::asyncReceive(const MutableBufferSequence &buffers, ReadHandler &&handler) {
// using namespace boost::asio::posix;
// Fast path: data is already readable, so complete via a posted task rather than waiting.
if (option(Events) & ZMQ_POLLIN) {
boost::asio::post(m_service.get_io_context(), [&buffers, this, handler{std::move(handler)}]() {
boost::system::error_code error;
auto size = receive(buffers, RecvFlags::Dontwait, error);
handler(error, size);
});
return;
}
// Slow path: wait until the descriptor becomes readable.
m_impl->descriptor->async_wait(StreamType::wait_read, [this, &buffers, handler{std::move(handler)}](
const boost::system::error_code &waitError) {
if (waitError) {
// The wait itself failed: report it with a default-constructed size argument.
handler(waitError, {});
return;
}
if (option(Events) & ZMQ_POLLIN) {
boost::system::error_code error;
auto size = receive(buffers, RecvFlags::Dontwait, error);
return handler(error, size);
} else {
// Descriptor was readable but no message is ready to be received: start the operation again.
asyncReceive(buffers, handler);
}
});
}
/// Starts an asynchronous receive of one message part into @p message.
///
/// If the socket's Events option already has ZMQ_POLLIN set, a task is posted
/// to the service's io_context that performs a Dontwait receive and calls
/// handler(error, size). Otherwise the function waits for read-readiness on
/// m_impl->descriptor and then either receives or starts over.
///
/// @param message Destination message; captured by reference in both lambdas.
/// @param handler Called as handler(error_code, size).
///
/// NOTE(review): message is captured by reference, so it presumably has to stay
/// alive until the handler has run — confirm against callers.
/// NOTE(review): option(Events) is the throwing overload, so a failed query
/// would throw from inside the wait callback — confirm this is intended.
template <typename Service>
template <typename ReadHandler>
void BasicSocket<Service>::asyncReceive(Message &message, ReadHandler &&handler) {
// using namespace boost::asio::posix;
// Fast path: data is already readable, so complete via a posted task rather than waiting.
if (option(Events) & ZMQ_POLLIN) {
boost::asio::post(m_service.get_io_context(), [&message, this, handler{std::move(handler)}]() {
boost::system::error_code error;
auto size = receive(message, RecvFlags::Dontwait, error);
handler(error, size);
});
return;
}
// Slow path: wait until the descriptor becomes readable.
m_impl->descriptor->async_wait(StreamType::wait_read, [this, &message, handler{std::move(handler)}](
const boost::system::error_code &waitError) {
if (waitError) {
// The wait itself failed: report it with a size of 0.
handler(waitError, 0);
return;
}
if (option(Events) & ZMQ_POLLIN) {
boost::system::error_code error;
auto size = receive(message, RecvFlags::Dontwait, error);
return handler(error, size);
} else {
// Descriptor was readable but no message is ready to be received: start the operation again.
asyncReceive(message, handler);
}
});
}
/// Receives all parts of one multipart message and writes them to @p out.
///
/// Parts are received one at a time until a part reports that no more follow.
///
/// @tparam CheckN When true, at most @p n parts are accepted.
/// @param out   Output iterator receiving each part by move; advanced in place.
/// @param n     Maximum number of parts; only consulted when CheckN is true.
/// @param flags Receive flags applied to every part.
/// @param error Set by receive() on failure.
/// @return Number of parts written to @p out (also when an error stops the loop).
/// @throws std::runtime_error when CheckN is true and another part would exceed @p n.
template <typename Service>
template <bool CheckN, class OutputIt>
size_t BasicSocket<Service>::receiveMultipart(OutputIt &out, size_t n, RecvFlags flags,
                                              boost::system::error_code &error) {
    size_t partCount = 0;
    Message part;
    bool morePartsFollow = true;
    while (morePartsFollow) {
        if (CheckN && partCount >= n) {
            throw std::runtime_error("Too many message parts in recv_multipart_n");
        }
        receive(part, flags, error);
        if (error) {
            return partCount;
        }
        ++partCount;
        // Query the flag before the part is moved out.
        morePartsFollow = part.more();
        *out++ = std::move(part);
    }
    return partCount;
}
/// Starts an asynchronous receive of a whole multipart message.
///
/// If the socket's Events option already has ZMQ_POLLIN set, the message is
/// received immediately and @p handler is invoked inline, before this function
/// returns. Otherwise the function waits for read-readiness on
/// m_impl->descriptor and receives from the wait callback.
///
/// @param out     Output iterator receiving each message part; copied into the
///                wait callback.
/// @param handler Called as handler(error_code, part_count).
///
/// NOTE(review): option(Events) is the throwing overload, so a failed query
/// would throw from inside the wait callback — confirm this is intended.
template <typename Service>
template <class OutputIt, typename ReadHandler>
void BasicSocket<Service>::asyncReceiveMultipart(OutputIt out, ReadHandler &&handler) {
    // using namespace boost::asio::posix;
    boost::system::error_code error;
    if (option(Events) & ZMQ_POLLIN) {
        auto size = receiveMultipart<false>(out, 0, RecvFlags::Dontwait, error);
        return handler(error, size);
    }
    m_impl->descriptor->async_wait(StreamType::wait_read, [this, out, handler{std::move(handler)}](
                                       const boost::system::error_code &waitError) mutable {
        if (waitError) {
            handler(waitError, 0);
            return;
        }
        if ((option(Events) & ZMQ_POLLIN) == 0) {
            // Descriptor was readable but no message is ready: start the operation again.
            asyncReceiveMultipart(out, handler);
            return;
        }
        boost::system::error_code error;
        const size_t size = receiveMultipart<false>(out, 0, RecvFlags::Dontwait, error);
        if (error && error.value() == EAGAIN) {
            // Nothing could be read after all. Start the operation again instead
            // of returning without ever calling the handler.
            asyncReceiveMultipart(out, handler);
            return;
        }
        handler(error, size);
    });
}
} // namespace ZeroMQ
#endif // BASICSOCKET_INL