diff --git a/CMakeLists.txt b/CMakeLists.txt index f25d127..aa44efa 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -4,15 +4,27 @@ project(echo_test) set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++11 -g") -add_library(common STATIC - address.cpp - epoll.cpp - pipe.cpp - socket.cpp - throw_error.cpp - timer.cpp - file_descriptor.cpp -) +IF(UNIX) + IF(APPLE) + add_library(common STATIC + address.cpp + kqueue.cpp + socket_apple.cpp + throw_error.cpp + timer.cpp + file_descriptor.cpp + ) + ELSE(APPLE) + add_library(common STATIC + address.cpp + epoll.cpp + socket.cpp + throw_error.cpp + timer.cpp + file_descriptor.cpp + ) + ENDIF(APPLE) +ENDIF(UNIX) add_executable(echo_server main_echo_server.cpp diff --git a/echo_server.h b/echo_server.h index afa5384..730bebc 100644 --- a/echo_server.h +++ b/echo_server.h @@ -2,7 +2,11 @@ #define ECHO_SERVER_H #include +#ifdef __APPLE__ +#include "socket_apple.h" +#else #include "socket.h" +#endif struct echo_server { diff --git a/echo_tester.h b/echo_tester.h index 0ab6f4e..0c4037f 100644 --- a/echo_tester.h +++ b/echo_tester.h @@ -1,9 +1,14 @@ #ifndef ECHO_TESTER_H #define ECHO_TESTER_H +#ifdef __APPLE__ +#include "kqueue.hpp" +#include "socket_apple.h" +#else #include "epoll.h" -#include "address.h" #include "socket.h" +#endif +#include "address.h" #include struct echo_tester diff --git a/file_descriptor.cpp b/file_descriptor.cpp index 7955a0a..da46791 100644 --- a/file_descriptor.cpp +++ b/file_descriptor.cpp @@ -121,7 +121,11 @@ size_t write_some(weak_file_descriptor fdc, void const* data, std::size_t size) int fd = fdc.getfd(); assert(fd != -1); +#ifdef __APPLE__ + ssize_t res = ::send(fd, data, size, 0); +#else ssize_t res = ::send(fd, data, size, MSG_NOSIGNAL); +#endif if (res == -1) { int err = errno; diff --git a/kqueue.cpp b/kqueue.cpp new file mode 100644 index 0000000..63a84b1 --- /dev/null +++ b/kqueue.cpp @@ -0,0 +1,241 @@ +#include "kqueue.hpp" + +#include +#include +#include +#include +#include + +#include "throw_error.h" + +using namespace sysapi; + +epoll::epoll() +{ + int r = kqueue(); + if (r == -1) + throw_error(errno, "kqueue()"); + + assert(r >= 0); + + fd_.reset(r); +} + +epoll::epoll(epoll&& rhs) + : fd_(std::move(rhs.fd_)) +{} + +epoll& epoll::operator=(epoll rhs) +{ + swap(rhs); + return *this; +} + +void epoll::swap(epoll& other) +{ + using std::swap; + swap(fd_, other.fd_); +} + +void epoll::run() +{ + for (;;) + { + std::array ev; + + again: + int timeout = run_timers_calculate_timeout(); + struct timespec tmout = {timeout, 0}; + int r = kevent(fd_.getfd(), NULL, 0, ev.data(), ev.size(), timeout == -1 ? nullptr : &tmout); + + if (r < 0) + { + int err = errno; + + if (err == EINTR) + goto again; + + throw_error(err, "kevent()"); + } + + if (r == 0) + goto again; + + assert(r > 0); + size_t num_events = static_cast(r); + assert(num_events <= ev.size()); + + for (auto i = ev.begin(); i != ev.begin() + num_events; ++i) + { + try + { + struct kevent const& ee = *i; + if (ee.ident == -1) { + continue; + } + static_cast(ee.udata)->callback(ee); + if (!deleted_events.empty()) { + for (auto k = deleted_events.begin(); k != deleted_events.end(); k++) { + for (auto j = i + 1; j != ev.begin() + num_events; j++) { + if (j->ident == k->first && j->filter == k->second) { + j->ident = -1; + } + } + } + deleted_events = {}; + } + } + catch (std::exception const& e) + { + std::cerr << "error: " << e.what() << std::endl; + } + catch (...) + { + std::cerr << "unknown exception in message loop" << std::endl; + } + } + } +} + +timer& epoll::get_timer() +{ + return timer_; +} + +void epoll::add(int fd, int16_t event, epoll_registration* reg) +{ + struct kevent ev; + EV_SET(&ev, fd, event, EV_ADD, 0, 0, reg); + + int r = kevent(fd_.getfd(), &ev, 1, NULL, 0, NULL); + if (r < 0) + throw_error(errno, "kevent(EV_ADD)"); +} + +void epoll::modify(int fd, int16_t event, epoll_registration* reg) +{ + struct kevent ev; + EV_SET(&ev, fd, event, EV_ADD, 0, 0, reg); + + int r = kevent(fd_.getfd(), &ev, 1, NULL, 0, NULL); + if (r < 0) + throw_error(errno, "kevent() MOD"); + + deleted_events.push_back({fd, event}); +} + +void epoll::remove(int fd, int16_t event) +{ + struct kevent ev; + EV_SET(&ev, fd, event, EV_DELETE, 0, 0, NULL); + + int r = kevent(fd_.getfd(), &ev, 1, NULL, 0, NULL); + if (r < 0) + throw_error(errno, "kevent(EV_DELETE)"); + + deleted_events.push_back({fd, event}); +} + +int epoll::run_timers_calculate_timeout() +{ + if (timer_.empty()) + return -1; + + timer::clock_t::time_point now = timer::clock_t::now(); + timer_.notify(now); + + if (timer_.empty()) + return -1; + + return std::chrono::duration_cast(timer_.top() - now).count(); +} + +epoll_registration::epoll_registration() + : ep() + , fd(-1) + , events() +{} + +epoll_registration::epoll_registration(epoll& ep, int fd, std::list events, callback_t callback) + : ep(&ep) + , fd(fd) + , events(events) + , callback(std::move(callback)) +{ + for (auto it = events.begin(); it != events.end(); it++) { + ep.add(fd, *it, this); + } +} + +epoll_registration::epoll_registration(epoll_registration&& rhs) + : ep(rhs.ep) + , fd(rhs.fd) + , events(rhs.events) + , callback(std::move(rhs.callback)) +{ + update(); + rhs.ep = nullptr; + rhs.fd = -1; + rhs.events = {}; + rhs.callback = callback_t(); +} + +epoll_registration::~epoll_registration() +{ + clear(); +} + +epoll_registration& epoll_registration::operator=(epoll_registration rhs) +{ + swap(rhs); + return *this; +} + +void epoll_registration::modify(std::list new_events) +{ + assert(ep); + + if (events == new_events) + return; + + for (auto it = events.begin(); it != events.end(); it++) { + ep->modify(fd, *it, this); + } + events = new_events; +} + +void epoll_registration::swap(epoll_registration& other) +{ + std::swap(ep, other.ep); + std::swap(fd, other.fd); + std::swap(events, other.events); + std::swap(callback, other.callback); + update(); + other.update(); +} + +void epoll_registration::clear() +{ + if (ep) + { + for (auto it = events.begin(); it != events.end(); it++) { + ep->remove(fd, *it); + } + ep = nullptr; + fd = -1; + events = {}; + } +} + +epoll& epoll_registration::get_epoll() const +{ + return *ep; +} + +void epoll_registration::update() +{ + if (ep) + for (auto it = events.begin(); it != events.end(); it++) { + ep->modify(fd, *it, this); + } +} diff --git a/kqueue.hpp b/kqueue.hpp new file mode 100644 index 0000000..f421720 --- /dev/null +++ b/kqueue.hpp @@ -0,0 +1,81 @@ +#ifndef kqueue_hpp +#define kqueue_hpp + +#include "file_descriptor.h" +#include "timer.h" + +#include +#include +#include +#include + +namespace sysapi +{ + struct epoll; + struct epoll_registration; + + struct epoll + { + typedef std::function action_t; + epoll(); + epoll(epoll const&) = delete; + epoll(epoll&&); + + epoll& operator=(epoll); + + void swap(epoll& other); + + void run(); + timer& get_timer(); + + private: + void add(int fd, int16_t event, epoll_registration*); + void modify(int fd, int16_t event, epoll_registration*); + void remove(int fd, int16_t event); + + int run_timers_calculate_timeout(); + + private: + file_descriptor fd_; + timer timer_; + std::list> deleted_events; + + friend struct epoll_registration; + }; + + struct epoll_registration + { + typedef std::function callback_t; + + epoll_registration(); + epoll_registration(epoll&, int fd, std::list events, callback_t callback); + epoll_registration(epoll_registration const&) = delete; + epoll_registration(epoll_registration&&); + ~epoll_registration(); + + epoll_registration& operator=(epoll_registration); + + void modify(std::list new_events); + + void swap(epoll_registration& other); + + void clear(); + epoll& get_epoll() const; + + private: + void update(); + + private: + epoll* ep; + int fd; + std::list events; + callback_t callback; + + friend struct epoll; + }; +} + +using sysapi::epoll; +using sysapi::epoll_registration; + +#endif diff --git a/main_echo_server.cpp b/main_echo_server.cpp index 372d9fc..ec8cb1a 100644 --- a/main_echo_server.cpp +++ b/main_echo_server.cpp @@ -1,6 +1,10 @@ #include +#ifdef __APPLE__ +#include "kqueue.hpp" +#else #include "epoll.h" +#endif #include "echo_server.h" int main() diff --git a/main_echo_test.cpp b/main_echo_test.cpp index c557432..a1af5df 100644 --- a/main_echo_test.cpp +++ b/main_echo_test.cpp @@ -1,6 +1,10 @@ #include +#ifdef __APPLE__ +#include "kqueue.hpp" +#else #include "epoll.h" +#endif #include "echo_tester.h" int main(int argc, char* argv[]) diff --git a/pipe.cpp b/pipe.cpp deleted file mode 100644 index fa3824c..0000000 --- a/pipe.cpp +++ /dev/null @@ -1,21 +0,0 @@ -#include "pipe.h" - -#include -#include -#include - -#include "throw_error.h" - -pipe_pair make_pipe(bool non_block) -{ - int fds[2]; - int res = pipe2(fds, O_CLOEXEC | (non_block ? O_NONBLOCK : 0)); - - if (res != 0) - { - assert(res == -1); - throw_error(errno, "pipe2()"); - } - - return pipe_pair{weak_file_descriptor{fds[0]}, weak_file_descriptor{fds[1]}}; -} diff --git a/pipe.h b/pipe.h deleted file mode 100644 index 7b167b5..0000000 --- a/pipe.h +++ /dev/null @@ -1,14 +0,0 @@ -#ifndef PIPE_H -#define PIPE_H - -#include "file_descriptor.h" - -struct pipe_pair -{ - weak_file_descriptor out; - weak_file_descriptor in; -}; - -pipe_pair make_pipe(bool non_block); - -#endif // PIPE_H diff --git a/socket_apple.cpp b/socket_apple.cpp new file mode 100644 index 0000000..9948062 --- /dev/null +++ b/socket_apple.cpp @@ -0,0 +1,252 @@ +#include "socket_apple.h" + +#include +#include +#include +#include + +#include "throw_error.h" + +namespace +{ + int get_fd_flags(int fd); + void set_fd_flags(int fd, int flags); + + file_descriptor make_socket(int domain, int type) + { + int fd = ::socket(domain, type, 0); + if (fd == -1) + throw_error(errno, "socket()"); + + return file_descriptor{fd}; + } + + void start_listen(int fd) + { + int res = ::listen(fd, SOMAXCONN); + if (res == -1) + throw_error(errno, "listen()"); + } + + void bind_socket(int fd, uint16_t port_net, uint32_t addr_net) + { + sockaddr_in saddr{}; + saddr.sin_family = AF_INET; + saddr.sin_port = port_net; + saddr.sin_addr.s_addr = addr_net; + int res = ::bind(fd, reinterpret_cast(&saddr), sizeof saddr); + if (res == -1) + throw_error(errno, "bind()"); + } + + void connect_socket(int fd, uint16_t port_net, uint32_t addr_net) + { + sockaddr_in saddr{}; + saddr.sin_family = AF_INET; + saddr.sin_port = port_net; + saddr.sin_addr.s_addr = addr_net; + + int res = ::connect(fd, reinterpret_cast(&saddr), sizeof saddr); + if (res == -1) + throw_error(errno, "connect()"); + } + + int get_fd_flags(int fd) + { + int res = fcntl(fd, F_GETFL, 0); + if (res == -1) + throw_error(errno, "fcntl(F_GETFL)"); + return res; + } + + void set_fd_flags(int fd, int flags) + { + int res = fcntl(fd, F_SETFL, flags); + if (res == -1) + throw_error(errno, "fcntl(F_SETFL)"); + } +} + +client_socket::client_socket(sysapi::epoll &ep, file_descriptor fd, on_ready_t on_disconnect) +: client_socket(ep, std::move(fd), std::move(on_disconnect), on_ready_t{}, on_ready_t{}) +{} + +client_socket::client_socket(epoll& ep, + file_descriptor fd, + on_ready_t on_disconnect, + on_ready_t on_read_ready, + on_ready_t on_write_ready) +: pimpl(new impl(ep, std::move(fd), std::move(on_disconnect), std::move(on_read_ready), std::move(on_write_ready))) +{} + +client_socket::impl::impl(sysapi::epoll &ep, file_descriptor fd, on_ready_t on_disconnect, on_ready_t on_read_ready, on_ready_t on_write_ready) + : ep(ep) + , fd(std::move(fd)) + , on_disconnect(std::move(on_disconnect)) + , on_read_ready(std::move(on_read_ready)) + , on_write_ready(std::move(on_write_ready)) + , reg(ep, this->fd.getfd(), {EVFILT_READ}, [this](struct kevent event) + { + assert(event.filter == EVFILT_READ || event.filter == EVFILT_WRITE); + bool is_destroyed = false; + assert(destroyed == nullptr); + destroyed = &is_destroyed; + try + { + if ((event.filter == EVFILT_READ && event.flags & EV_EOF) || (event.filter == EVFILT_WRITE && event.flags & EV_EOF)) + { + this->on_disconnect(); + if (is_destroyed) + return; + } + if (event.filter == EVFILT_READ) + { + this->on_read_ready(); + if (is_destroyed) + return; + } + if (event.filter == EVFILT_WRITE) + { + this->on_write_ready(); + if (is_destroyed) + return; + } + } + catch (...) + { + destroyed = nullptr; + throw; + } + + destroyed = nullptr; + }) + , destroyed(nullptr) +{} + +client_socket::impl::~impl() +{ + if (destroyed) + *destroyed = true; +} + +void client_socket::impl::update_registration() +{ + reg.modify(calculate_flags()); +} + +std::list client_socket::impl::calculate_flags() const +{ + std::list ev_list; + if (on_read_ready) + ev_list.push_back(EVFILT_READ); + if (on_write_ready) + ev_list.push_back(EVFILT_WRITE); + + return ev_list; +} + +void client_socket::set_on_read_write(on_ready_t on_read_ready, + on_ready_t on_write_ready) +{ + pimpl->on_read_ready = std::move(on_read_ready); + pimpl->on_write_ready = std::move(on_write_ready); + pimpl->update_registration(); +} + +void client_socket::set_on_read(on_ready_t on_ready) +{ + // TODO: not exception safe + pimpl->on_read_ready = on_ready; + pimpl->update_registration(); +} + +void client_socket::set_on_write(client_socket::on_ready_t on_ready) +{ + pimpl->on_write_ready = on_ready; + pimpl->update_registration(); +} + +size_t client_socket::write_some(const void *data, size_t size) +{ + return ::write_some(pimpl->fd, data, size); +} + +size_t client_socket::read_some(void* data, size_t size) +{ + return ::read_some(pimpl->fd, data, size); +} + +client_socket client_socket::connect(sysapi::epoll &ep, const ipv4_endpoint &remote, on_ready_t on_disconnect) +{ + file_descriptor fd = make_socket(AF_INET, SOCK_STREAM); + connect_socket(fd.getfd(), remote.port_net, remote.addr_net); + set_fd_flags(fd.getfd(), get_fd_flags(fd.getfd()) | O_NONBLOCK); + client_socket res{ep, std::move(fd), std::move(on_disconnect)}; + return res; +} + +server_socket::server_socket(epoll& ep, on_connected_t on_connected) + : fd(make_socket(AF_INET, SOCK_STREAM)) + , on_connected(on_connected) + , reg(ep, fd.getfd(), {EVFILT_READ}, [this](struct kevent event) { + assert(event.filter == EVFILT_READ); + this->on_connected(); + }) +{ + set_fd_flags(fd.getfd(), get_fd_flags(fd.getfd()) | O_NONBLOCK); + start_listen(fd.getfd()); +} + +server_socket::server_socket(epoll& ep, ipv4_endpoint local_endpoint, on_connected_t on_connected) + : fd(make_socket(AF_INET, SOCK_STREAM)) + , on_connected(on_connected) + , reg(ep, fd.getfd(), {EVFILT_READ}, [this](struct kevent event) { + assert(event.filter & EVFILT_READ); + this->on_connected(); + }) +{ + set_fd_flags(fd.getfd(), get_fd_flags(fd.getfd()) | O_NONBLOCK); + bind_socket(fd.getfd(), local_endpoint.port_net, local_endpoint.addr_net); + start_listen(fd.getfd()); +} + +ipv4_endpoint server_socket::local_endpoint() const +{ + sockaddr_in saddr{}; + socklen_t saddr_len = sizeof saddr; + int res = ::getsockname(fd.getfd(), reinterpret_cast(&saddr), &saddr_len); + if (res == -1) + throw_error(errno, "getsockname()"); + assert(saddr_len == sizeof saddr); + return ipv4_endpoint{saddr.sin_port, saddr.sin_addr.s_addr}; +} + +client_socket server_socket::accept(client_socket::on_ready_t on_disconnect) const +{ + int res = ::accept(fd.getfd(), nullptr, nullptr); + if (res == -1) + throw_error(errno, "accept()"); + + set_fd_flags(fd.getfd(), get_fd_flags(fd.getfd()) | O_NONBLOCK); + + const int set = 1; + ::setsockopt(res, SOL_SOCKET, SO_NOSIGPIPE, &set, sizeof(set)); // NOSIGPIPE FOR SEND + + return client_socket{reg.get_epoll(), {res}, std::move(on_disconnect)}; +} + +client_socket server_socket::accept(client_socket::on_ready_t on_disconnect, + client_socket::on_ready_t on_read_ready, + client_socket::on_ready_t on_write_ready) const +{ + int res = ::accept(fd.getfd(), nullptr, nullptr); + if (res == -1) + throw_error(errno, "accept4()"); + + set_fd_flags(fd.getfd(), get_fd_flags(fd.getfd()) | O_NONBLOCK); + + const int set = 1; + ::setsockopt(res, SOL_SOCKET, SO_NOSIGPIPE, &set, sizeof(set)); // NOSIGPIPE FOR SEND + + return client_socket{reg.get_epoll(), {res}, std::move(on_disconnect), std::move(on_read_ready), std::move(on_write_ready)}; +} diff --git a/socket_apple.h b/socket_apple.h new file mode 100644 index 0000000..3c159d6 --- /dev/null +++ b/socket_apple.h @@ -0,0 +1,87 @@ +#ifndef SOCKET_APPLE_H +#define SOCKET_APPLE_H + +#include "file_descriptor.h" +#include "address.h" +#include "kqueue.hpp" +#include +#include + +struct client_socket +{ + typedef std::function on_ready_t; + + client_socket(epoll& ep, + file_descriptor fd, + on_ready_t on_disconnect); + + client_socket(epoll& ep, + file_descriptor fd, + on_ready_t on_disconnect, + on_ready_t on_read_ready, + on_ready_t on_write_ready); + + void set_on_read_write(on_ready_t on_read_ready, on_ready_t on_write_ready); + void set_on_read(on_ready_t on_ready); + void set_on_write(on_ready_t on_ready); + + size_t write_some(void const* data, size_t size); + size_t read_some(void* data, size_t size); + + static client_socket connect(epoll& ep, ipv4_endpoint const& remote, on_ready_t on_disconnect); + +private: + struct impl + { + impl(epoll& ep, file_descriptor fd, on_ready_t on_disconnect, on_ready_t on_read_ready, on_ready_t on_write_ready); + ~impl(); + + void update_registration(); + std::list calculate_flags() const; + + epoll& ep; + file_descriptor fd; + on_ready_t on_disconnect; + on_ready_t on_read_ready; + on_ready_t on_write_ready; + epoll_registration reg; + bool* destroyed; + }; + + std::unique_ptr pimpl; +}; + +struct server_socket +{ + typedef std::function on_connected_t; + + server_socket(epoll& ep, on_connected_t on_connected); + server_socket(epoll& ep, ipv4_endpoint local_endpoint, on_connected_t on_connected); + + ipv4_endpoint local_endpoint() const; + client_socket accept(client_socket::on_ready_t on_disconnect) const; + client_socket accept(client_socket::on_ready_t on_disconnect, + client_socket::on_ready_t on_read_ready, + client_socket::on_ready_t on_write_ready) const; + +private: + file_descriptor fd; + on_connected_t on_connected; + epoll_registration reg; +}; + +struct eventfd +{ + typedef std::function on_event_t; + + eventfd(epoll& ep, bool semaphore, on_event_t on_event); + void notify(uint64_t increment = 1); + void set_on_event(on_event_t on_event); + +private: + file_descriptor fd; + on_event_t on_event; + epoll_registration reg; +}; + +#endif // SOCKET_APPLE_H diff --git a/throw_error.cpp b/throw_error.cpp index 59a235b..8dff9c7 100644 --- a/throw_error.cpp +++ b/throw_error.cpp @@ -41,8 +41,7 @@ void throw_error [[noreturn]] (int err, char const* action) { std::stringstream ss; ss << action << " failed, error: " << error_enum_name(err); - char tmp[2048]; - char const* err_msg = strerror_r(err, tmp, sizeof tmp); + char const* err_msg = strerror(err); ss << " (" << err << ", " << err_msg << ")"; throw std::runtime_error(ss.str()); }