diff --git a/src/mgmt/rpc/server/unit_tests/test_rpcserver.cc b/src/mgmt/rpc/server/unit_tests/test_rpcserver.cc index 4c67d681e28..39639a46bc8 100644 --- a/src/mgmt/rpc/server/unit_tests/test_rpcserver.cc +++ b/src/mgmt/rpc/server/unit_tests/test_rpcserver.cc @@ -27,6 +27,9 @@ #include #include #include +#include +#include +#include #include #include @@ -335,6 +338,10 @@ struct ScopedLocalSocket : shared::rpc::IPCSocketClient { } }; +struct TestableIPCSocketClient : shared::rpc::IPCSocketClient { + using shared::rpc::IPCSocketClient::_safe_write; +}; + // helper function to send a request and update the promise when the response is done. // This is to be used in a multithread test. void @@ -344,6 +351,72 @@ send_request(std::string json, std::promise p) auto resp = rpc_client.query(json); p.set_value(resp); } + +TEST_CASE("IPCSocketClient write returns when the peer stops reading", "[socket][client]") +{ + int fds[2]; + REQUIRE(::socketpair(AF_UNIX, SOCK_STREAM, 0, fds) == 0); + + int const flags = ::fcntl(fds[0], F_GETFL, 0); + REQUIRE(flags >= 0); + REQUIRE(::fcntl(fds[0], F_SETFL, flags | O_NONBLOCK) == 0); + + std::vector fill(4096, 'x'); + while (true) { + ssize_t const ret = ::write(fds[0], fill.data(), fill.size()); + if (ret < 0) { + REQUIRE((errno == EAGAIN || errno == EWOULDBLOCK)); + break; + } + REQUIRE(ret > 0); + } + + TestableIPCSocketClient rpc_client; + pid_t const pid = ::fork(); + REQUIRE(pid >= 0); + if (pid == 0) { + ::close(fds[1]); + char const byte = 'x'; + auto const ret = rpc_client._safe_write(fds[0], &byte, 1); + auto const err = errno; + ::close(fds[0]); + _exit(ret == -1 && err == ETIMEDOUT ? 0 : 1); + } + + int status = 0; + bool child_exited = false; + auto const child_deadline = std::chrono::steady_clock::now() + std::chrono::seconds(10); + while (std::chrono::steady_clock::now() < child_deadline) { + auto const wait_ret = ::waitpid(pid, &status, WNOHANG); + if (wait_ret == pid) { + child_exited = true; + break; + } + REQUIRE(wait_ret == 0); + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + } + + if (!child_exited) { + auto const wait_ret = ::waitpid(pid, &status, WNOHANG); + if (wait_ret == pid) { + child_exited = true; + } else { + REQUIRE(wait_ret == 0); + } + } + + if (!child_exited) { + ::kill(pid, SIGKILL); + REQUIRE(::waitpid(pid, &status, 0) == pid); + FAIL("_safe_write did not return when the nonblocking socket stayed unwritable"); + } + + ::close(fds[0]); + ::close(fds[1]); + + REQUIRE(WIFEXITED(status)); + REQUIRE(WEXITSTATUS(status) == 0); +} } // namespace TEST_CASE("Sending 'concurrent' requests to the rpc server.", "[thread]") { diff --git a/src/shared/rpc/IPCSocketClient.cc b/src/shared/rpc/IPCSocketClient.cc index bed9c58dc6f..6c3797b2ed7 100644 --- a/src/shared/rpc/IPCSocketClient.cc +++ b/src/shared/rpc/IPCSocketClient.cc @@ -89,13 +89,28 @@ IPCSocketClient::connect(std::chrono::milliseconds wait_ms, int attempts) std::int64_t IPCSocketClient::_safe_write(int fd, const char *buffer, int len) { + constexpr int WRITE_READY_TIMEOUT_MS = 1000; + std::int64_t written{0}; while (written < len) { const ssize_t ret = ::write(fd, buffer + written, len - written); if (ret == -1) { - if (errno == EAGAIN || errno == EINTR) { + if (errno == EINTR) { continue; } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + auto const ready = write_ready(fd, WRITE_READY_TIMEOUT_MS); + if (ready == 1) { + continue; + } + if (ready == 0) { + errno = ETIMEDOUT; + } + } + return -1; + } + if (ret == 0) { + errno = EIO; return -1; } written += ret;