Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions example/calculator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,11 @@ int main(int argc, char** argv)
std::cout << "Usage: mpcalculator <fd>\n";
return 1;
}
int fd;
if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) {
std::cerr << argv[1] << " is not a number or is larger than an int\n";
return 1;
}
mp::SocketId socket{mp::StartSpawned(argv[1])};
mp::EventLoop loop("mpcalculator", LogPrint);
std::unique_ptr<Init> init = std::make_unique<InitImpl>();
mp::ServeStream<InitInterface>(loop, fd, *init);
mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)};
mp::ServeStream<InitInterface>(loop, kj::mv(stream), *init);
loop.loop();
return 0;
}
10 changes: 5 additions & 5 deletions example/example.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,14 @@ namespace fs = std::filesystem;

static auto Spawn(mp::EventLoop& loop, const std::string& process_argv0, const std::string& new_exe_name)
{
int pid;
const int fd = mp::SpawnProcess(pid, [&](int fd) -> std::vector<std::string> {
auto pair{mp::SocketPair()};
mp::ProcessId pid{mp::SpawnProcess(pair[0], [&](mp::ConnectInfo info) -> std::vector<std::string> {
fs::path path = process_argv0;
path.remove_filename();
path.append(new_exe_name);
return {path.string(), std::to_string(fd)};
});
return std::make_tuple(mp::ConnectStream<InitInterface>(loop, fd), pid);
return {path.string(), std::move(info)};
})};
return std::make_tuple(mp::ConnectStream<InitInterface>(loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(pair[1])), pid);
}

static void LogPrint(mp::LogMessage log_data)
Expand Down
9 changes: 3 additions & 6 deletions example/printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,11 @@ int main(int argc, char** argv)
std::cout << "Usage: mpprinter <fd>\n";
return 1;
}
int fd;
if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) {
std::cerr << argv[1] << " is not a number or is larger than an int\n";
return 1;
}
mp::SocketId socket{mp::StartSpawned(argv[1])};
mp::EventLoop loop("mpprinter", LogPrint);
std::unique_ptr<Init> init = std::make_unique<InitImpl>();
mp::ServeStream<InitInterface>(loop, fd, *init);
mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)};
mp::ServeStream<InitInterface>(loop, std::move(stream), *init);
loop.loop();
return 0;
}
29 changes: 19 additions & 10 deletions include/mp/proxy-io.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,17 @@ class Logger

std::string LongThreadName(const char* exe_name);

using Stream = kj::Own<kj::AsyncIoStream>;

inline SocketId StreamSocketId(const Stream& stream)
{
if (stream) KJ_IF_MAYBE(fd, stream->getFd()) return *fd;
#ifdef WIN32
if (stream) KJ_IF_MAYBE(handle, stream->getWin32Handle()) return reinterpret_cast<SocketId>(*handle);
#endif
throw std::logic_error("Stream socket unset");
}

//! Event loop implementation.
//!
//! Cap'n Proto threading model is very simple: all I/O operations are
Expand Down Expand Up @@ -283,11 +294,12 @@ class EventLoop
//! Callback functions to run on async thread.
std::optional<CleanupList> m_async_fns MP_GUARDED_BY(m_mutex);

//! Pipe read handle used to wake up the event loop thread.
int m_wait_fd = -1;
//! Socket pair used to post and wait for wakeups to the event loop thread.
kj::Own<kj::AsyncIoStream> m_wait_stream;
kj::Own<kj::AsyncIoStream> m_post_stream;

//! Pipe write handle used to wake up the event loop thread.
int m_post_fd = -1;
//! Synchronous writer used to write to m_post_stream.
kj::Own<kj::OutputStream> m_post_writer;

//! Number of clients holding references to ProxyServerBase objects that
//! reference this event loop.
Expand Down Expand Up @@ -679,13 +691,11 @@ struct ThreadContext
//! over the stream. Also create a new Connection object embedded in the
//! client that is freed when the client is closed.
template <typename InitInterface>
std::unique_ptr<ProxyClient<InitInterface>> ConnectStream(EventLoop& loop, int fd)
std::unique_ptr<ProxyClient<InitInterface>> ConnectStream(EventLoop& loop, kj::Own<kj::AsyncIoStream> stream)
{
typename InitInterface::Client init_client(nullptr);
std::unique_ptr<Connection> connection;
loop.sync([&] {
auto stream =
loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP);
connection = std::make_unique<Connection>(loop, kj::mv(stream));
init_client = connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs<InitInterface>();
Connection* connection_ptr = connection.get();
Expand Down Expand Up @@ -735,10 +745,9 @@ void _Listen(EventLoop& loop, kj::Own<kj::ConnectionReceiver>&& listener, InitIm
//! Given stream file descriptor and an init object, handle requests on the
//! stream by calling methods on the Init object.
template <typename InitInterface, typename InitImpl>
void ServeStream(EventLoop& loop, int fd, InitImpl& init)
void ServeStream(EventLoop& loop, kj::Own<kj::AsyncIoStream> stream, InitImpl& init)
{
_Serve<InitInterface>(
loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP), init);
_Serve<InitInterface>(loop, kj::mv(stream), init);
}

//! Given listening socket file descriptor and an init object, handle incoming
Expand Down
44 changes: 35 additions & 9 deletions include/mp/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#include <variant>
#include <vector>

#ifdef WIN32
#include <winsock2.h>
#endif

namespace mp {

//! Generic utility functions used by capnp code.
Expand Down Expand Up @@ -216,22 +220,44 @@ std::string ThreadName(const char* exe_name);
//! errors in python unit tests.
std::string LogEscape(const kj::StringTree& string, size_t max_size);

#ifdef WIN32
using ProcessId = uintptr_t;
using SocketId = uintptr_t;
constexpr SocketId SocketError{INVALID_SOCKET};
#else
using ProcessId = int;
using SocketId = int;
constexpr SocketId SocketError{-1};
#endif

//! Information about parent process passed to child process. On unix this is
//! just the inherited int file descriptor formatted as a string. On windows,
//! this is a path to a named path pipe the parent process will write
//! WSADuplicateSocket info to.
using ConnectInfo = std::string;

//! Callback type used by SpawnProcess below.
using FdToArgsFn = std::function<std::vector<std::string>(int fd)>;
using ConnectInfoToArgsFn = std::function<std::vector<std::string>(const ConnectInfo&)>;

//! Create a socket pair that can be used to communicate within a process or
//! between parent and child processes.
std::array<SocketId, 2> SocketPair();

//! Spawn a new process that communicates with the current process over provided
//! socket argument. Calls connect_info_to_args callback with a connection
//! string that needs to be passed to the child process, and executes the
//! argv command line it returns. Returns child process id.
ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args);

//! Spawn a new process that communicates with the current process over a socket
//! pair. Returns pid through an output argument, and file descriptor for the
//! local side of the socket. Invokes fd_to_args callback with the remote file
//! descriptor number which returns the command line arguments that should be
//! used to execute the process, and which should have the remote file
//! descriptor embedded in whatever format the child process expects.
int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args);
//! Initialize spawned child process using the ConnectInfo string passed to it,
//! returning a socket id for communicating with the parent process.
SocketId StartSpawned(const ConnectInfo& connect_info);

//! Call execvp with vector args.
void ExecProcess(const std::vector<std::string>& args);

//! Wait for a process to exit and return its exit code.
int WaitProcess(int pid);
int WaitProcess(ProcessId pid);

inline char* CharCast(char* c) { return c; }
inline char* CharCast(unsigned char* c) { return (char*)c; }
Expand Down
92 changes: 73 additions & 19 deletions src/mp/proxy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,15 @@
#include <optional>
#include <stdexcept>
#include <string>
#include <sys/socket.h>
#include <thread>
#include <tuple>
#include <unistd.h>
#include <utility>

#ifndef WIN32
#include <sys/socket.h>
#include <unistd.h>
#endif

namespace mp {

thread_local ThreadContext g_thread_context;
Expand Down Expand Up @@ -66,10 +69,9 @@ void EventLoopRef::reset(bool relock) MP_NO_TSA
loop->m_num_clients -= 1;
if (loop->done()) {
loop->m_cv.notify_all();
int post_fd{loop->m_post_fd};
loop_lock->unlock();
char buffer = 0;
KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon)
loop->m_post_writer->write(&buffer, 1);
// By default, do not try to relock `loop_lock` after writing,
// because the event loop could wake up and destroy itself and the
// mutex might no longer exist.
Expand All @@ -96,6 +98,20 @@ Connection::~Connection()
// after the calls finish.
m_rpc_system.reset();

// shutdownWrite is needed on Windows so pending data in the m_stream socket
// will be sent instead of discarded when m_stream is destroyed. On unix,
// this doesn't seem to be needed because data is sent more reliably.
//
// Sending pending data is important if the connection is a socketpair
// because when one side of the socketpair is closed, the other side doesn't
// seem to receive any onDisconnect event. So it is important for the other
// side to instead receive Cap'n Proto "release" messages (see `struct
// Release` in capnp/rpc.capnp) from local Client objects being being
// destroyed so the remote side can free resources and shut down cleanly.
// Without this call, Server objects corresponding to the Client objects on
// the other side of the connection are not freed by Cap'n Proto.
m_stream->shutdownWrite();

// ProxyClient cleanup handlers are in sync list, and ProxyServer cleanup
// handlers are in the async list.
//
Expand Down Expand Up @@ -192,17 +208,59 @@ void EventLoop::addAsyncCleanup(std::function<void()> fn)
startAsyncThread();
}

#ifdef WIN32
//! Synchronous socket output stream. Cap'n Proto library only provides limited
//! support for synchronous IO. It provides `FdOutputStream` which wraps unix
//! file descriptors and calls write() internally, and `HandleOutStream` which
//! wraps windows HANDLE values and calls WriteFile() internally. This class
//! just provides analagous functionality wrapping SOCKET values and calls
//! send() internally.
class SocketOutputStream : public kj::OutputStream {
public:
explicit SocketOutputStream(SOCKET socket) : m_socket(socket) {}

void write(const void* buffer, size_t size) override;

private:
SOCKET m_socket;
};

static constexpr size_t WRITE_CLAMP_SIZE = 1u << 30; // 1GB clamp for Windows, like FdOutputStream

void SocketOutputStream::write(const void* buffer, size_t size) {
const char* pos = reinterpret_cast<const char*>(buffer);

while (size > 0) {
int n = send(m_socket, pos, static_cast<int>(kj::min(size, WRITE_CLAMP_SIZE)), 0);

KJ_WIN32(n != SOCKET_ERROR, "send() failed");
KJ_ASSERT(n > 0, "send() returned zero.");

pos += n;
size -= n;
}
}
#endif

EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context)
: m_exe_name(exe_name),
m_io_context(kj::setupAsyncIo()),
m_task_set(new kj::TaskSet(m_error_handler)),
m_log_opts(std::move(log_opts)),
m_context(context)
{
int fds[2];
KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds));
m_wait_fd = fds[0];
m_post_fd = fds[1];
auto pipe = m_io_context.provider->newTwoWayPipe();
m_wait_stream = kj::mv(pipe.ends[0]);
m_post_stream = kj::mv(pipe.ends[1]);
KJ_IF_MAYBE(fd, m_post_stream->getFd()) {
m_post_writer = kj::heap<kj::FdOutputStream>(*fd);
#ifdef WIN32
} else KJ_IF_MAYBE(handle, m_post_stream->getWin32Handle()) {
m_post_writer = kj::heap<SocketOutputStream>(reinterpret_cast<SOCKET>(*handle));
#endif
} else {
throw std::logic_error("Could not get file descriptor for new pipe.");
}
}

EventLoop::~EventLoop()
Expand All @@ -211,8 +269,8 @@ EventLoop::~EventLoop()
const Lock lock(m_mutex);
KJ_ASSERT(m_post_fn == nullptr);
KJ_ASSERT(!m_async_fns);
KJ_ASSERT(m_wait_fd == -1);
KJ_ASSERT(m_post_fd == -1);
KJ_ASSERT(!m_wait_stream);
KJ_ASSERT(!m_post_stream);
KJ_ASSERT(m_num_clients == 0);

// Spin event loop. wait for any promises triggered by RPC shutdown.
Expand All @@ -232,9 +290,7 @@ void EventLoop::loop()
m_async_fns.emplace();
}

kj::Own<kj::AsyncIoStream> wait_stream{
m_io_context.lowLevelProvider->wrapSocketFd(m_wait_fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)};
int post_fd{m_post_fd};
kj::Own<kj::AsyncIoStream>& wait_stream{m_wait_stream};
char buffer = 0;
for (;;) {
const size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope);
Expand All @@ -246,7 +302,7 @@ void EventLoop::loop()
m_cv.notify_all();
} else if (done()) {
// Intentionally do not break if m_post_fn was set, even if done()
// would return true, to ensure that the EventLoopRef write(post_fd)
// would return true, to ensure that the EventLoopRef write(post_stream)
// call always succeeds and the loop does not exit between the time
// that the done condition is set and the write call is made.
break;
Expand All @@ -256,10 +312,9 @@ void EventLoop::loop()
m_task_set.reset();
MP_LOG(*this, Log::Info) << "EventLoop::loop bye.";
wait_stream = nullptr;
KJ_SYSCALL(::close(post_fd));
const Lock lock(m_mutex);
m_wait_fd = -1;
m_post_fd = -1;
m_wait_stream = nullptr;
m_post_stream = nullptr;
m_async_fns.reset();
m_cv.notify_all();
}
Expand All @@ -274,10 +329,9 @@ void EventLoop::post(kj::Function<void()> fn)
EventLoopRef ref(*this, &lock);
m_cv.wait(lock.m_lock, [this]() MP_REQUIRES(m_mutex) { return m_post_fn == nullptr; });
m_post_fn = &fn;
int post_fd{m_post_fd};
Unlock(lock, [&] {
char buffer = 0;
KJ_SYSCALL(write(post_fd, &buffer, 1));
m_post_writer->write(&buffer, 1);
});
m_cv.wait(lock.m_lock, [this, &fn]() MP_REQUIRES(m_mutex) { return m_post_fn != &fn; });
}
Expand Down
Loading
Loading