blob: 4ce50107ce7798d5c7b0e47f189a5b2e67db4c67 [file] [log] [blame]
#include "tlbmc/service/hft_service.h"
#include <cstddef>
#include <cstdint>
#include <memory>
#include <queue>
#include <string>
#include <utility>
#include "absl/base/thread_annotations.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/log.h"
#include "absl/strings/substitute.h"
#include "absl/synchronization/mutex.h"
#include "grpcpp/server_context.h"
#include "grpcpp/support/server_callback.h"
#include "grpcpp/support/status.h"
#include "hft_service.pb.h"
#include "tlbmc/subscription/manager.h"
#include "payload.pb.h"
#include "subscription_params.pb.h"
#include "bmcweb_authorizer_singleton.h"
#include "zatar/bmcweb_cert_provider.h"
namespace milotic_hft {
using ::milotic::authz::BmcWebAuthorizerSingleton;
using ::milotic::redfish::BmcWebCertProvider;
namespace internal {
// Constructs a reactor.
//
// `on_done` is stored and later invoked with `this` from OnDone().
// `maximum_event_queue_size` is the queue length at which AddResponse()
// starts rejecting new responses.
ServerReactorImpl::ServerReactorImpl(
    absl::AnyInvocable<void(ServerReactorImpl *)> on_done,
    std::size_t maximum_event_queue_size)
    : on_done_(std::move(on_done)),
      maximum_event_queue_size_(maximum_event_queue_size) {
}
// Finishes the stream with `status` unless Finish() has already been issued
// or the stream has already completed, in which case this is a no-op.
void ServerReactorImpl::SafeFinish(const grpc::Status &status)
    ABSL_LOCKS_EXCLUDED(mutex_) {
  absl::MutexLock lock(&mutex_);
  const bool already_terminating =
      status_ == Status::kFinishCalled || status_ == Status::kFinished;
  if (!already_terminating) {
    // Record the transition before calling into gRPC so that later callers
    // observe kFinishCalled and back off.
    status_ = Status::kFinishCalled;
    Finish(status);
  }
}
// Queues `response` for delivery to the client and, if no write is currently
// in flight, starts writing it immediately.
//
// Returns false (and drops `response`) when the stream is finishing or
// finished, or when the queue already holds `maximum_event_queue_size_`
// entries. Returns true once the response has been queued.
bool ServerReactorImpl::AddResponse(HftResponse &&response)
    ABSL_LOCKS_EXCLUDED(mutex_) {
  absl::MutexLock lock(&mutex_);
  switch (status_) {
    case Status::kIdle:
    case Status::kWriteInFlight:
      break;
    case Status::kFinishCalled:
    case Status::kFinished:
      DLOG(WARNING) << "Stream is finished or finishing";
      return false;
  }
  if (response_queue_.size() >= maximum_event_queue_size_) {
    DLOG(WARNING) << "Response queue is full, dropping response";
    return false;
  }
  // A named rvalue reference is an lvalue; without std::move the response
  // would be copied into the queue instead of moved.
  response_queue_.push(std::move(response));
  if (status_ == Status::kIdle) {
    // Nothing is being written, so kick off the write of the element that
    // was just queued. OnWriteDone() drains the rest of the queue.
    status_ = Status::kWriteInFlight;
    DLOG(INFO) << "StartWrite from idle";
    StartWrite(&response_queue_.front());
  }
  return true;
}
// Returns a snapshot of the reactor's current state, read under `mutex_`.
ServerReactorImpl::Status ServerReactorImpl::GetStatus() const
    ABSL_LOCKS_EXCLUDED(mutex_) {
  absl::MutexLock lock(&mutex_);
  const Status snapshot = status_;
  return snapshot;
}
// Write-completion callback.
//
// On success the written response is removed from the queue and either the
// next queued response is written or the reactor returns to idle. On failure
// the stream is finished with an INTERNAL status. A completion that arrives
// while no write is marked in flight is logged and ignored.
void ServerReactorImpl::OnWriteDone(bool ok) ABSL_LOCKS_EXCLUDED(mutex_) {
  DLOG(INFO) << "OnWriteDone";
  absl::MutexLock lock(&mutex_);
  if (status_ != Status::kWriteInFlight) {
    LOG(WARNING)
        << "OnWriteDone called when the stream is not write in flight: "
        << StatusToString(status_);
    return;
  }

  if (ok) {
    // The front element is the one that was just written.
    response_queue_.pop();
    if (!response_queue_.empty()) {
      // Stay in kWriteInFlight and continue draining the queue.
      DLOG(INFO) << "StartWrite from OnWriteDone";
      StartWrite(&response_queue_.front());
      return;
    }
    status_ = Status::kIdle;
    return;
  }

  LOG(ERROR) << "Failed to write response";
  status_ = Status::kFinishCalled;
  Finish(grpc::Status(grpc::StatusCode::INTERNAL, "Failed to write response"));
}
// Cancellation callback: finishes the stream with CANCELLED unless Finish()
// has already been issued or the stream has already completed.
void ServerReactorImpl::OnCancel() {
  DLOG(INFO) << "OnCancel";
  absl::MutexLock lock(&mutex_);
  switch (status_) {
    case Status::kFinishCalled:
    case Status::kFinished:
      // Already terminating; Finish() must not be issued again.
      return;
    default:
      break;
  }
  status_ = Status::kFinishCalled;
  Finish(
      grpc::Status(grpc::StatusCode::CANCELLED, "Client cancelled requests."));
}
// Completion callback: marks the reactor as finished and, the first time
// only, hands `this` to the `on_done_` callback supplied at construction.
void ServerReactorImpl::OnDone() {
  DLOG(INFO) << "OnDone";
  bool first_completion = false;
  {
    // The lock is released before `on_done_` runs.
    absl::MutexLock lock(&mutex_);
    if (status_ != Status::kFinished) {
      status_ = Status::kFinished;
      first_completion = true;
    }
  }
  if (first_completion) {
    on_done_(this);
    // At this moment, the reactor must not be used anymore.
  }
}
} // namespace internal
// Returns the total sample rate currently recorded for `role`, or 0 when the
// role has no entry.
uint64_t HftServiceImpl::GetAccumulativeSampleRate(
    const std::string &role) const {
  absl::MutexLock lock(&mutex_);
  const auto entry = role_to_total_sample_rate_.find(role);
  if (entry == role_to_total_sample_rate_.end()) {
    return 0;
  }
  return entry->second;
}
// Returns the number of reactors whose sample rate is recorded against their
// role and will be subtracted again when they finish.
uint64_t HftServiceImpl::GetDecreaseSampleRateWhenFinishedCount() const {
  absl::MutexLock lock(&mutex_);
  return static_cast<uint64_t>(
      reactor_should_decrease_sample_rate_when_finished_.size());
}
// Returns the number of reactors currently held in `reactor_map_`.
uint64_t HftServiceImpl::GetReactorsCount() const {
  absl::MutexLock lock(&mutex_);
  return static_cast<uint64_t>(reactor_map_.size());
}
// Returns the number of roles that currently have a recorded sample rate.
uint64_t HftServiceImpl::GetRoleToTotalSampleRateCount() const {
  absl::MutexLock lock(&mutex_);
  return static_cast<uint64_t>(role_to_total_sample_rate_.size());
}
// Returns the number of reactors that have at least one subscription id
// entry in `reactor_to_subscription_ids_`.
uint64_t HftServiceImpl::GetReactorToSubscriptionIdsCount() const {
  absl::MutexLock lock(&mutex_);
  return static_cast<uint64_t>(reactor_to_subscription_ids_.size());
}
// Constructs the service from a copy of `options` and takes ownership of
// `subscription_manager`.
HftServiceImpl::HftServiceImpl(
    const HftServiceOptions &options,
    std::unique_ptr<SubscriptionManager> subscription_manager)
    : options_(options),
      subscription_manager_(std::move(subscription_manager)) {
}
// gRPC entry point for the Subscribe streaming RPC.
//
// Resolves the caller's role from the call's auth context and delegates the
// rest of the work, including handling of an authorization failure, to
// SubscribeWithStatus().
grpc::ServerWriteReactor<HftResponse> *HftServiceImpl::Subscribe(
    grpc::CallbackServerContext *context, const HftRequest *request) {
  BmcWebAuthorizerSingleton &authorizer =
      BmcWebAuthorizerSingleton::GetInstance();
  // Filled in by the authorizer as an output argument.
  std::string peer_role;
  grpc::Status authz_result = authorizer.GetPeerRoleFromAuthContext(
      *context->auth_context(), peer_role);
  return SubscribeWithStatus(request, peer_role, authz_result);
}
// Creates the streaming reactor for `request` on behalf of `role`.
//
// `authz_status` is the result of resolving the caller's role. It is replaced
// by PERMISSION_DENIED when the server does not have root certs and a prod
// signed cert. When `options_.enable_authorization` is set, a non-OK
// `authz_status` finishes the stream with that status, and an OK one is
// subject to the per-role sample rate limit (RESOURCE_EXHAUSTED when
// exceeded).
//
// A request containing a subscription whose sampling interval is not positive
// is finished with INVALID_ARGUMENT; a request with no subscriptions is
// finished with INVALID_ARGUMENT as well.
//
// The returned reactor is owned by `reactor_map_` until its done callback
// runs, which releases the role's sample rate, unsubscribes every
// subscription created here and drops the reactor from the bookkeeping maps.
grpc::ServerWriteReactor<HftResponse> *HftServiceImpl::SubscribeWithStatus(
    const HftRequest *request, const std::string &role,
    grpc::Status authz_status) {
  using internal::ServerReactorImpl;
  BmcWebAuthorizerSingleton &bmcweb_authorizer =
      BmcWebAuthorizerSingleton::GetInstance();
  uint64_t total_sample_rate = 0;
  // Set when any subscription carries a sampling interval that is zero or
  // negative. Such an interval cannot be converted into a rate: dividing by
  // zero is undefined behaviour and a negative quotient would wrap the
  // unsigned total.
  bool has_invalid_sampling_interval = false;
  // TODO(nanzhou): limit subscriptions for "all" resources as well. For now,
  // assume that subscriptions all come with explicit identifiers.
  for (const SubscriptionParams &subscription : request->subscriptions()) {
    if (subscription.sampling_interval_ms() <= 0) {
      has_invalid_sampling_interval = true;
      continue;
    }
    total_sample_rate += 1000 / subscription.sampling_interval_ms();
  }
  // Only allows subscriptions if the server has root certs and prod signed
  // cert.
  if (options_.cert_provider == nullptr ||
      options_.cert_provider->GetServerStatus() !=
          BmcWebCertProvider::ServerStatus::kWithRootCertsAndProdSignedCert) {
    authz_status = grpc::Status(
        grpc::StatusCode::PERMISSION_DENIED,
        "The server does not have root certs and prod signed cert.");
  }
  // Cleanup executed from the reactor's OnDone().
  absl::AnyInvocable<void(ServerReactorImpl *)> on_reactor_done =
      [this, total_sample_rate, role](ServerReactorImpl *reactor) {
        absl::MutexLock lock(&mutex_);
        // Give back the sample rate only if this reactor was actually charged
        // against the role below.
        if (auto it = role_to_total_sample_rate_.find(role);
            !role.empty() &&
            reactor_should_decrease_sample_rate_when_finished_.contains(
                reactor) &&
            it != role_to_total_sample_rate_.end()) {
          it->second -= total_sample_rate;
          DLOG(INFO) << "role: " << role
                     << " total_sample_rate: " << total_sample_rate
                     << " role_to_total_sample_rate_: " << it->second;
          if (it->second == 0) {
            role_to_total_sample_rate_.erase(it);
          }
        }
        reactor_should_decrease_sample_rate_when_finished_.erase(reactor);
        for (const SubscriptionManager::SubscriptionId &subscription_id :
             reactor_to_subscription_ids_[reactor]) {
          if (absl::Status status =
                  subscription_manager_->Unsubscribe(subscription_id);
              !status.ok()) {
            LOG(WARNING) << "Failed to unsubscribe subscription: " << status;
          }
        }
        reactor_to_subscription_ids_.erase(reactor);
        reactor_map_.erase(reactor);
      };
  std::shared_ptr<ServerReactorImpl> reactor =
      std::make_shared<ServerReactorImpl>(std::move(on_reactor_done),
                                          options_.maximum_event_queue_size);
  {
    absl::MutexLock lock(&mutex_);
    reactor_map_[reactor.get()] = reactor;
  }
  if (!authz_status.ok() && options_.enable_authorization) {
    reactor->SafeFinish(authz_status);
  }
  // SafeFinish() is a no-op once the stream is finishing, so an authorization
  // failure reported above takes precedence over this argument error.
  if (has_invalid_sampling_interval) {
    reactor->SafeFinish(
        grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
                     "Sampling interval must be positive for every "
                     "subscription."));
  }
  if (authz_status.ok() && options_.enable_authorization &&
      !has_invalid_sampling_interval) {
    absl::MutexLock lock(&mutex_);
    uint64_t new_total_sample_rate = total_sample_rate;
    auto it = role_to_total_sample_rate_.find(role);
    if (it != role_to_total_sample_rate_.end()) {
      new_total_sample_rate += it->second;
    }
    uint64_t sample_rate_limit = bmcweb_authorizer.GetSampleRateLimit(role);
    if (new_total_sample_rate > sample_rate_limit) {
      reactor->SafeFinish(grpc::Status(
          grpc::StatusCode::RESOURCE_EXHAUSTED,
          absl::Substitute(
              "Role $0 has reached the maximum allowed sample "
              "rate of $1, currently $2, new total sample rate "
              "after this RPC: $3",
              role, sample_rate_limit,
              it == role_to_total_sample_rate_.end() ? 0 : it->second,
              new_total_sample_rate)));
    } else {
      // Charge the role and remember to refund it when the reactor is done.
      role_to_total_sample_rate_[role] = new_total_sample_rate;
      reactor_should_decrease_sample_rate_when_finished_.insert(reactor.get());
    }
  }
  // Only create subscriptions for a stream that has not been rejected above.
  const ServerReactorImpl::Status status_before_subscribing =
      reactor->GetStatus();
  if (status_before_subscribing != ServerReactorImpl::Status::kFinishCalled &&
      status_before_subscribing != ServerReactorImpl::Status::kFinished) {
    for (const SubscriptionParams &subscription : request->subscriptions()) {
      absl::AnyInvocable<void(Payload &&)> on_data_callback =
          [reactor](Payload &&payload) mutable {
            if (!reactor) return;
            // If reactor ever reports that it is finished, we should never
            // access it again.
            if (reactor->GetStatus() == ServerReactorImpl::Status::kFinished) {
              reactor = nullptr;
              return;
            }
            // If reactor is about to finish, we should return and wait until
            // it is actually finished. We must not start a new StartWrite().
            if (reactor->GetStatus() ==
                ServerReactorImpl::Status::kFinishCalled) {
              return;
            }
            HftResponse response;
            *response.add_payloads() = std::move(payload);
            bool added = reactor->AddResponse(std::move(response));
            if (!added) {
              DLOG(WARNING) << "Failed to add response";
            }
          };
      absl::StatusOr<SubscriptionManager::SubscriptionId> subscription_id =
          subscription_manager_->AddSubscription(subscription,
                                                 std::move(on_data_callback));
      if (!subscription_id.ok()) {
        reactor->SafeFinish(
            grpc::Status(grpc::StatusCode::INTERNAL,
                         absl::Substitute("Failed to add subscription: $0",
                                          subscription_id.status().message())));
      } else {
        // Recorded so the done callback can unsubscribe it.
        absl::MutexLock lock(&mutex_);
        reactor_to_subscription_ids_[reactor.get()].push_back(*subscription_id);
      }
    }
  }
  if (request->subscriptions_size() == 0) {
    reactor->SafeFinish(grpc::Status(grpc::StatusCode::INVALID_ARGUMENT,
                                     "No subscription is provided."));
  }
  return reactor.get();
}
} // namespace milotic_hft