Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 190 additions & 67 deletions rpc/server/rpc_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
#include "rpc/server/rpc_server.h"
#include "proto/subspace.pb.h"
#include <inttypes.h>
#include <poll.h>
#include <stdio.h>
#include <unistd.h>
#include <vector>

namespace subspace {

Expand Down Expand Up @@ -42,15 +45,22 @@ absl::Status RpcServer::RegisterMethod(
google::protobuf::Any *, co::Coroutine *)>
callback,
MethodOptions &&options) {
if (methods_.find(method) != methods_.end()) {
return absl::AlreadyExistsError("Method already registered: " + method);
}

methods_[method] = std::make_shared<Method>(
this, method, std::string{request_type}, std::string{response_type}, options.slot_size,
options.num_slots, std::move(callback),
options.id == -1 ? ++next_method_id_ : options.id);
return absl::OkStatus();
return RegisterMethodAsync(
method,
[method, callback = std::move(callback)](
const google::protobuf::Any &req, co::Coroutine *c,
std::function<void(std::unique_ptr<google::protobuf::Any>)> reply,
std::function<void(std::string)> error_reply) {
auto res = std::make_unique<google::protobuf::Any>();
auto status = callback(req, res.get(), c);
if (!status.ok()) {
error_reply(absl::StrFormat("Error executing method %s: %s", method,
status.ToString()));
return;
}
reply(std::move(res));
},
std::move(options), request_type, response_type);
}

// Register void method.
Expand All @@ -59,23 +69,43 @@ absl::Status RpcServer::RegisterMethod(
std::function<absl::Status(const google::protobuf::Any &, co::Coroutine *)>
callback,
MethodOptions &&options) {
if (methods_.find(method) != methods_.end()) {
return absl::AlreadyExistsError("Method already registered: " + method);
}
methods_[method] = std::make_shared<Method>(
this, method, std::string{request_type}, "subspace.VoidMessage", options.slot_size,
options.num_slots,
[callback](const google::protobuf::Any &req, google::protobuf::Any *res,
co::Coroutine *c) {
return RegisterMethodAsync(
method,
[method, callback = std::move(callback)](
const google::protobuf::Any &req, co::Coroutine *c,
std::function<void(std::unique_ptr<google::protobuf::Any>)> reply,
std::function<void(std::string)> error_reply) {
auto status = callback(req, c);
if (!status.ok()) {
return status;
error_reply(absl::StrFormat("Error executing method %s: %s", method,
status.ToString()));
return;
}
// The response for this void method is a VoidMessage packed
// into a google.protobuf.Any.
auto res = std::make_unique<google::protobuf::Any>();
res->PackFrom(VoidMessage());
return absl::OkStatus();
reply(std::move(res));
},
std::move(options), request_type, "subspace.VoidMessage");
}

absl::Status RpcServer::RegisterMethodAsync(
const std::string &method,
std::function<void(
const google::protobuf::Any &, co::Coroutine *,
std::function<void(std::unique_ptr<google::protobuf::Any>)>,
std::function<void(std::string)>)>
callback,
MethodOptions &&options, std::string_view request_type,
std::string_view response_type) {
if (methods_.find(method) != methods_.end()) {
return absl::AlreadyExistsError("Method already registered: " + method);
}

methods_[method] = std::make_shared<Method>(
this, method, std::string{request_type}, std::string{response_type},
options.slot_size, options.num_slots, std::move(callback),
options.id == -1 ? ++next_method_id_ : options.id);
return absl::OkStatus();
}
Expand Down Expand Up @@ -441,20 +471,50 @@ RpcServer::CreateSession(uint64_t client_id) {

session->methods.insert({method->id, method_instance});

AddCoroutine(std::make_unique<co::Coroutine>(
*scheduler_,
[server = shared_from_this(), session,
method_instance](co::Coroutine *c) {
if (method_instance->method->IsStreaming()) {
if (method->IsStreaming()) {
AddCoroutine(std::make_unique<co::Coroutine>(
*scheduler_,
[server = shared_from_this(), session,
method_instance](co::Coroutine *c) {
SessionStreamingMethodCoroutine(std::move(server), session,
method_instance, c);
} else {
SessionMethodCoroutine(std::move(server), session, method_instance,
c);
}
},
absl::StrFormat("Session %d Method %s", session->session_id,
method->name.c_str())));
},
absl::StrFormat("Session %d Method %s", session->session_id,
method->name.c_str())));
} else {
// Non-streaming methods use a two-coroutine pipeline: a request
// coroutine that invokes the (possibly async) handler and a response
// coroutine that publishes completed replies. They communicate via a
// SharedPtrPipe of ReplyItems.
auto pipe = toolbelt::SharedPtrPipe<internal::ReplyItem>::Create();
if (!pipe.ok()) {
logger_.Log(toolbelt::LogLevel::kError,
"Failed to create reply queue for method %s: %s",
method->name.c_str(), pipe.status().ToString().c_str());
return pipe.status();
}
method_instance->reply_queue =
std::make_shared<internal::ReplyQueue>(std::move(*pipe));

AddCoroutine(std::make_unique<co::Coroutine>(
*scheduler_,
[server = shared_from_this(), session,
method_instance](co::Coroutine *c) {
SessionRequestCoroutine(std::move(server), session, method_instance,
c);
},
absl::StrFormat("Session %d Method %s request", session->session_id,
method->name.c_str())));
AddCoroutine(std::make_unique<co::Coroutine>(
*scheduler_,
[server = shared_from_this(), session,
method_instance](co::Coroutine *c) {
SessionResponseCoroutine(std::move(server), session,
method_instance, c);
},
absl::StrFormat("Session %d Method %s response", session->session_id,
method->name.c_str())));
}
}
sessions_[session->session_id] = session;
logger_.Log(toolbelt::LogLevel::kDebug, "Created session: %d",
Expand All @@ -468,7 +528,7 @@ absl::Status RpcServer::DestroySession(int session_id) {
return absl::OkStatus();
}

void RpcServer::SessionMethodCoroutine(
void RpcServer::SessionRequestCoroutine(
std::shared_ptr<RpcServer> server, std::shared_ptr<Session> session,
std::shared_ptr<MethodInstance> method_instance, co::Coroutine *c) {
while (server->running_) {
Expand All @@ -483,8 +543,9 @@ void RpcServer::SessionMethodCoroutine(
if (*s == server->interrupt_pipe_.ReadFd().Fd()) {
break;
}
subspace::RpcRequest request;
bool request_ok = false;
// Drain all requests currently available for this method. Requests for
// other sessions can share the same channel, so only this session's
// messages are dispatched to the handler below.
for (;;) {
auto m = method_instance->request_subscriber->ReadMessage();
if (!m.ok()) {
Expand All @@ -495,46 +556,108 @@ void RpcServer::SessionMethodCoroutine(
break;
}
if (m->length == 0) {
// No message, continue waiting.
// No more messages, go back to waiting.
break;
}
subspace::RpcRequest request;
if (!request.ParseFromArray(m->buffer, m->length)) {
server->logger_.Log(toolbelt::LogLevel::kError,
"Error parsing request for method %s: %s",
method_instance->method->name.c_str(),
m.status().ToString().c_str());
"Error parsing request for method %s",
method_instance->method->name.c_str());
continue;
}
if (request.session_id() == session->session_id) {
request_ok = true;
break;
if (request.session_id() != session->session_id) {
continue;
}

// Give the handler reply handles tied to this request. The handler runs
// on this request coroutine (c), so the completed result is pushed into
// the reply queue from c and the response coroutine routes it back to the
// client that issued this request. Writes pass c so that, if the reply
// pipe is momentarily full, the write yields this coroutine (letting the
// response coroutine drain it) instead of blocking the scheduler thread.
// No locking is needed: the server runs on a single coroutine scheduler
// thread. Reply functions must only be invoked from c.
auto queue = method_instance->reply_queue;
auto reply_fn =
[queue, c, session_id = session->session_id,
request_id = request.request_id(), client_id = request.client_id()](
std::unique_ptr<google::protobuf::Any> response) mutable {
auto item = std::make_shared<internal::ReplyItem>();
item->session_id = session_id;
item->request_id = request_id;
item->client_id = client_id;
item->response = std::move(response);
(void)queue->pipe.Write(std::move(item), c);
};
auto error_fn =
[queue, c, session_id = session->session_id,
request_id = request.request_id(),
client_id = request.client_id()](std::string error_msg) mutable {
auto item = std::make_shared<internal::ReplyItem>();
item->session_id = session_id;
item->request_id = request_id;
item->client_id = client_id;
item->error_message = std::move(error_msg);
(void)queue->pipe.Write(std::move(item), c);
};

server->logger_.Log(toolbelt::LogLevel::kDebug, "Calling method %s",
method_instance->method->name.c_str());
method_instance->method->async_callback(
request.argument(), c, std::move(reply_fn), std::move(error_fn));
}
if (!request_ok) {
}
}

void RpcServer::SessionResponseCoroutine(
std::shared_ptr<RpcServer> server, std::shared_ptr<Session> session,
std::shared_ptr<MethodInstance> method_instance, co::Coroutine *c) {
auto &queue = *method_instance->reply_queue;
// dup the interrupt fd so this coroutine can wait on it concurrently with
// the request coroutine without violating the epoll "one waiter per fd"
// restriction.
toolbelt::FileDescriptor interrupt(
::dup(server->interrupt_pipe_.ReadFd().Fd()));

while (server->running_) {
// Wait for the next completed reply for this method, or for shutdown.
int fd = c->Wait(std::vector<int>{queue.pipe.ReadFd().Fd(), interrupt.Fd()},
POLLIN);
if (fd == interrupt.Fd()) {
break;
}
if (fd != queue.pipe.ReadFd().Fd()) {
continue;
}

subspace::RpcResponse response;
response.set_session_id(session->session_id);
response.set_request_id(request.request_id());
response.set_client_id(request.client_id());
auto *result = response.mutable_result();
server->logger_.Log(toolbelt::LogLevel::kDebug, "Calling method %s",
method_instance->method->name.c_str());
absl::Status method_status =
method_instance->method->callback(request.argument(), result, c);
if (!method_status.ok()) {
auto item_or = queue.pipe.Read(c);
if (!item_or.ok()) {
server->logger_.Log(toolbelt::LogLevel::kError,
"Error executing method %s: %s",
"Error reading queued response for method %s: %s",
method_instance->method->name.c_str(),
method_status.ToString().c_str());
response.set_error(absl::StrFormat("Error executing method %s: %s",
method_instance->method->name,
method_status.ToString()));
item_or.status().ToString().c_str());
continue;
}
auto item = std::move(*item_or);

// Convert the completed handler result into the protocol response expected
// by the waiting client.
subspace::RpcResponse response;
response.set_session_id(item->session_id);
response.set_request_id(item->request_id);
response.set_client_id(item->client_id);
if (!item->error_message.empty()) {
response.set_error(item->error_message);
} else if (item->response != nullptr) {
response.mutable_result()->CopyFrom(*item->response);
} else {
response.set_error("handler produced a null response");
}

uint64_t length = response.ByteSizeLong();
absl::StatusOr<void *> buffer;
bool got_buffer = false;
for (;;) {
buffer = method_instance->response_publisher->GetMessageBuffer(
int32_t(length));
Expand All @@ -543,36 +666,36 @@ void RpcServer::SessionMethodCoroutine(
"Error getting buffer for method %s: %s",
method_instance->method->name.c_str(),
buffer.status().ToString().c_str());
response.set_error(absl::StrFormat(
"Error getting buffer for method %s: %s",
method_instance->method->name, buffer.status().ToString()));
return;
break;
}
if (*buffer != nullptr) {
got_buffer = true;
break;
}
if (!server->interrupt_pipe_.ReadFd().Valid()) {
return;
}
// Buffer is not ready, wait and try again.
auto status = method_instance->response_publisher->Wait(
server->interrupt_pipe_.ReadFd(), c);
// The response channel is temporarily full; wait for the client to free
// a slot or for shutdown.
auto status = method_instance->response_publisher->Wait(interrupt, c);
if (!status.ok()) {
server->logger_.Log(toolbelt::LogLevel::kError,
"Error waiting for buffer: %s",
status.status().ToString().c_str());
return;
}
if (*status == server->interrupt_pipe_.ReadFd().Fd()) {
if (*status == interrupt.Fd()) {
return;
}
}
// We got a buffer, fill it in and send it.
if (!got_buffer) {
continue;
}
if (!response.SerializeToArray(*buffer, length)) {
server->logger_.Log(toolbelt::LogLevel::kError,
"Error serializing response for method %s",
method_instance->method->name.c_str());
break;
continue;
}
server->logger_.Log(
toolbelt::LogLevel::kDebug, "Publishing response for method %s: %s",
Expand Down
Loading
Loading