blob: 8a4e29b53b053f03fd1a79b66645baef07837ac4 [file] [log] [blame]
#include "condition.h"
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
#include "google/protobuf/timestamp.pb.h"
#include "convert_proto.h"
#include "daemon_context.h"
#include "safepower_agent.pb.h"
#include "state_fields.h"
#include "state_merge.h"
#include "state_updater.h"
#include "absl/base/nullability.h"
#include "absl/functional/any_invocable.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "absl/time/time.h"
#include "google/protobuf/repeated_ptr_field.h"
#include "bmc/status_macros.h"
namespace safepower_agent {
using ::safepower_agent_proto::AllOfCondition;
using ::safepower_agent_proto::AnyOfCondition;
using ::safepower_agent_proto::BootState;
using ::safepower_agent_proto::ComparisonType;
using ::safepower_agent_proto::ComponentState;
using ::safepower_agent_proto::ConnectionState;
using ::safepower_agent_proto::DaemonState;
using ::safepower_agent_proto::PowerState;
using ::safepower_agent_proto::StateMatchedCondition;
using ::safepower_agent_proto::SystemState;
using ::safepower_agent_proto::TimeoutCondition;
// Applies `comparison_type` to two directly comparable values.
//
// `a` is the observed value and `expected` is the value taken from the
// condition, so e.g. COMPARISON_TYPE_GT yields `a > expected`. Both arguments
// must deduce to the same `T`.
//
// Returns InvalidArgumentError for any comparison type not handled below.
template <typename T>
absl::StatusOr<bool> Compare(ComparisonType comparison_type, T&& a,
                             T&& expected) {
  switch (comparison_type) {
    case ComparisonType::COMPARISON_TYPE_EQ:
      return a == expected;
    case ComparisonType::COMPARISON_TYPE_NE:
      return a != expected;
    case ComparisonType::COMPARISON_TYPE_GT:
      return a > expected;
    case ComparisonType::COMPARISON_TYPE_GE:
      return a >= expected;
    case ComparisonType::COMPARISON_TYPE_LT:
      return a < expected;
    case ComparisonType::COMPARISON_TYPE_LE:
      return a <= expected;
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unknown comparison type: ", comparison_type));
  }
}
// Compares a single field, described by the descriptor `field`, of two protos
// of the same type.
//
// The values are only compared when the field is present in both `a` and
// `expected`; if either side lacks it, the field counts as matching (`true`).
template <typename Proto, typename FieldDefT>
absl::StatusOr<bool> CompareField(const FieldDefT& field,
                                  ComparisonType comparison_type,
                                  const Proto& a, const Proto& expected) {
  const bool present_in_both = field.is_in(a) && field.is_in(expected);
  if (!present_in_both) {
    return true;
  }
  return Compare(comparison_type, field.get_from(a), field.get_from(expected));
}
// True only for a successful comparison whose value is `true`; an error status
// is treated as "does not match".
static bool CompareResultMatches(const absl::StatusOr<bool>& result) {
  return result.ok() ? *result : false;
}
// Compares two protos of type `Proto` field by field, using the field
// descriptors listed in `ProtoFields<Proto>::kFields`.
//
// Fields are visited in the order of `kFields`. The walk stops at the first
// field whose comparison errors out or evaluates to `false`, and that field's
// result is returned. If every field matches (or there are no fields), the
// result is `true`.
template <typename Proto>
absl::StatusOr<bool> Compare(ComparisonType comparison_type, const Proto& a,
                             const Proto& expected) {
  const auto compare_each = [&](auto&&... field_defs) {
    absl::StatusOr<bool> last = true;
    [[maybe_unused]] const auto check_one = [&](const auto& field_def) {
      last = CompareField(field_def, comparison_type, a, expected);
      return CompareResultMatches(last);
    };
    // `&&` folds short-circuit, so fields after the first failure are skipped.
    static_cast<void>((check_one(field_defs) && ...));
    return last;
  };
  return std::apply(compare_each, ProtoFields<Proto>::kFields);
}
// Specialization for ComponentState, whose `state` is a oneof: the comparison
// is delegated to whichever concrete state both sides hold.
//
// Returns InvalidArgumentError when:
// - the two sides hold different state kinds;
// - no state is set; or
// - the state kind is not one of the cases handled below.
// Before this change the last case silently returned `false` ("no match"),
// which was inconsistent with the other unknown-type branches in this file.
template <>
absl::StatusOr<bool> Compare(
    safepower_agent_proto::ComparisonType comparison_type,
    const ComponentState& a, const ComponentState& expected) {
  if (a.state_case() != expected.state_case()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Mismatched component state types: ", a.state_case(),
                     " vs expected ", expected.state_case()));
  }
  switch (a.state_case()) {
    case ComponentState::kDaemonState:
      return Compare(comparison_type, a.daemon_state(),
                     expected.daemon_state());
    case ComponentState::kBootState:
      return Compare(comparison_type, a.boot_state(), expected.boot_state());
    case ComponentState::kPowerState:
      return Compare(comparison_type, a.power_state(), expected.power_state());
    case ComponentState::kConnectionState:
      return Compare(comparison_type, a.connection_state(),
                     expected.connection_state());
    case ComponentState::STATE_NOT_SET:
      return absl::InvalidArgumentError("Component state not set");
  }
  // Reachable only for a oneof case value not handled above. Report it instead
  // of pretending the states simply did not match.
  return absl::InvalidArgumentError(
      absl::StrCat("Unknown component state type: ", a.state_case()));
}
// Result of evaluating a condition (sub)tree: an overall status plus the
// conditions responsible for that result. GetMatchResult lists a leaf
// condition when it matched or failed to evaluate. An OK status with an empty
// list means "no match".
using MatchResult = std::tuple<absl::Status, Condition::MatchList>;
// Forward declaration of the generic dispatcher. It is defined further down,
// but the any_of/all_of overloads that follow recurse into it.
static MatchResult ConditionMatches(
const safepower_agent_proto::Condition& condition,
const safepower_agent_proto::SystemState& system_state,
absl::Time start_time);
// any_of: evaluates every child condition (there is no short-circuit) and
// reports the concatenation of all child match lists. Child statuses are
// merged with absl::Status::Update, which keeps the first error.
static MatchResult ConditionMatches(const AnyOfCondition& any_of,
                                    const SystemState& system_state,
                                    absl::Time start_time) {
  absl::Status combined_status;
  Condition::MatchList matched;
  for (const auto& child : any_of.conditions()) {
    MatchResult child_result =
        ConditionMatches(child, system_state, start_time);
    const Condition::MatchList& child_matches = std::get<1>(child_result);
    matched.insert(matched.end(), child_matches.begin(), child_matches.end());
    combined_status.Update(std::get<0>(child_result));
  }
  return {combined_status, matched};
}
// all_of: evaluates every child condition.
//
// There is deliberately no early exit on the first non-matching child: a
// later child may still produce an error, and that error must be reported.
//
// Result:
// - If all children are OK, the combined match list is returned only when
//   every child matched; otherwise the list is emptied to mean "no match".
// - If any child reported an error, the first error (per Status::Update) is
//   returned together with whatever matches were collected.
static MatchResult ConditionMatches(const AllOfCondition& all_of,
                                    const SystemState& system_state,
                                    absl::Time start_time) {
  absl::Status combined_status;
  Condition::MatchList matched;
  bool some_child_unmatched = false;
  for (const auto& child : all_of.conditions()) {
    auto [child_status, child_matches] =
        ConditionMatches(child, system_state, start_time);
    if (child_matches.empty()) {
      some_child_unmatched = true;
    }
    matched.insert(matched.end(), child_matches.begin(), child_matches.end());
    combined_status.Update(child_status);
  }
  const bool report_no_match = some_child_unmatched && combined_status.ok();
  if (report_no_match) {
    matched.clear();
  }
  return {combined_status, matched};
}
// Evaluates one state condition against `system_state`.
//
// - Returns false when the referenced node or component is absent.
// - Otherwise compares the component's current state against the condition's
//   expected state using the condition's comparison type.
// - A match on a condition flagged `abort` becomes an AbortedError.
// - A condition without an expected state is an InvalidArgumentError.
static absl::StatusOr<bool> ConditionMatches(
    const StateMatchedCondition& condition, const SystemState& system_state) {
  const auto& component_ref = condition.system_component();

  const auto& nodes = system_state.node_state();
  const auto node_it = nodes.find(component_ref.node_entity_tag());
  if (node_it == nodes.end()) return false;

  const auto& components = node_it->second.component_state();
  const auto component_it = components.find(component_ref.name());
  if (component_it == components.end()) return false;

  const ComponentState& expected_state = condition.component_state();
  if (expected_state.state_case() == ComponentState::STATE_NOT_SET) {
    return absl::InvalidArgumentError("Component state not set");
  }
  ASSIGN_OR_RETURN(bool matched,
                   Compare(condition.comparison_type(), component_it->second,
                           expected_state));
  if (!matched) return false;
  if (condition.abort()) {
    return absl::AbortedError(
        absl::StrCat("Abort condition matched: ", condition));
  }
  return true;
}
// Evaluates a timeout condition. It matches once the daemon clock has reached
// the deadline, which is either `start_time + timeout_ms` or an absolute
// `deadline`. A reached deadline on a condition flagged `abort` becomes an
// AbortedError. An unset or unknown time kind is an InvalidArgumentError.
static absl::StatusOr<bool> ConditionMatches(const TimeoutCondition& condition,
                                             absl::Time start_time) {
  absl::Time deadline;
  if (condition.time_case() == TimeoutCondition::kTimeoutMs) {
    deadline = start_time + absl::Milliseconds(condition.timeout_ms());
  } else if (condition.time_case() == TimeoutCondition::kDeadline) {
    deadline = ConvertTime(condition.deadline());
  } else {
    return absl::InvalidArgumentError(
        absl::StrCat("Unknown timeout type: ", condition));
  }
  const bool expired = deadline <= DaemonContext::Get().now();
  if (expired && condition.abort()) {
    return absl::AbortedError(
        absl::StrCat("Abort condition matched: ", condition));
  }
  return expired;
}
// Converts a leaf condition's boolean outcome into a MatchResult.
// - Non-match: OK status and an empty list.
// - Match: OK status and `condition` in the list.
// - Evaluation error: the error status, with `condition` in the list.
static MatchResult GetMatchResult(
    const safepower_agent_proto::Condition& condition,
    absl::StatusOr<bool> matches) {
  const bool plain_miss = matches.ok() && !*matches;
  if (plain_miss) {
    return {absl::OkStatus(), {}};
  }
  return {matches.status(), {&condition}};
}
// Dispatches on the condition's oneof.
// - Composite conditions (any_of / all_of) recurse into their children.
// - Leaf conditions (state / timeout) are evaluated and wrapped by
//   GetMatchResult.
// - Any other case yields InvalidArgumentError, with `condition` in the list.
static MatchResult ConditionMatches(
    const safepower_agent_proto::Condition& condition,
    const safepower_agent_proto::SystemState& system_state,
    absl::Time start_time) {
  using ConditionProto = safepower_agent_proto::Condition;
  switch (condition.condition_type_case()) {
    case ConditionProto::kAnyOf:
      return ConditionMatches(condition.any_of(), system_state, start_time);
    case ConditionProto::kAllOf:
      return ConditionMatches(condition.all_of(), system_state, start_time);
    case ConditionProto::kStateCondition: {
      absl::StatusOr<bool> state_matches =
          ConditionMatches(condition.state_condition(), system_state);
      return GetMatchResult(condition, std::move(state_matches));
    }
    case ConditionProto::kTimeout: {
      absl::StatusOr<bool> timeout_matches =
          ConditionMatches(condition.timeout(), start_time);
      return GetMatchResult(condition, std::move(timeout_matches));
    }
    default:
      break;
  }
  return {absl::InvalidArgumentError(
              absl::StrCat("Unknown condition type: ", condition)),
          {&condition}};
}
// Creates a Condition that evaluates `condition`.
//
// `original_start_time` is stored in `start_time_`. It is the reference point
// for relative (`timeout_ms`) timeouts, both when they are evaluated and when
// their timers are registered.
//
// `listener_` is constructed with `*this` so that its UpdateState hook can call
// back into this Condition.
Condition::Condition(const safepower_agent_proto::Condition& condition,
absl::Time original_start_time)
: listener_(*this),
condition_(condition),
start_time_(original_start_time) {}
// Cancels every timer job that RegisterTimeouts recorded in `pending_jobs_`.
// A NotFound result is ignored (presumably the job already ran); any other
// cancellation failure is logged.
//
// NOTE(review): the ".complete" job scheduled by TriggerCallback is not
// cancelled here. Its closure owns the user callback and a MatchList of
// pointers into `condition_`. If this object can be destroyed before that job
// runs, the callback could see dangling pointers. Confirm the intended
// lifetime with the owner.
Condition::~Condition() {
  for (const std::string& timer_job : pending_jobs_) {
    const absl::Status cancel_status =
        DaemonContext::Get().scheduler().CancelCall(timer_job);
    const bool benign = cancel_status.ok() || absl::IsNotFound(cancel_status);
    if (!benign) {
      LOG(ERROR) << "Failed to cancel job: " << cancel_status;
    }
  }
}
// Starts waiting for the condition to match.
//
// `unique_id` prefixes the names of the scheduler jobs this Condition creates:
// the timeout timers and the "<unique_id>.complete" callback job. When the
// condition matches or fails to evaluate, `callback` is scheduled with the
// status and the responsible conditions.
//
// Arguments are validated before any member is modified, so a rejected call
// leaves this Condition untouched.
//
// Returns InvalidArgumentError if `callback` or `state_updater` is null.
// Otherwise returns any error from registering the timeout jobs.
absl::Status Condition::WaitForMatch(
    absl::string_view unique_id,
    absl::Nonnull<std::shared_ptr<StateUpdater<SystemState>>> state_updater,
    absl::AnyInvocable<void(absl::Status, Condition::MatchList) &&> callback) {
  if (callback == nullptr) {
    return absl::InvalidArgumentError("Callback must not be null");
  }
  // The timer jobs registered below dereference `state_updater_` when they
  // fire. Reject null here instead of crashing later.
  if (state_updater == nullptr) {
    return absl::InvalidArgumentError("State updater must not be null");
  }
  complete_job_name_ = absl::StrCat(unique_id, ".complete");
  state_updater_ = std::move(state_updater);
  callback_ = std::move(callback);
  RETURN_IF_ERROR(RegisterTimeouts(condition_, unique_id));
  listener_.Listen(state_updater_);
  return absl::OkStatus();
}
std::tuple<absl::Status, Condition::MatchList> Condition::CheckState(
const SystemState& state) {
return ConditionMatches(condition_, state, start_time_);
}
// Schedules the user callback to run with zero delay under the job name
// `complete_job_name_`, passing it `match_status` and `match_list`.
//
// The callback is single-shot: it is taken out of `callback_` here, and any
// later call returns FailedPreconditionError. `std::exchange` resets
// `callback_` to null explicitly. A moved-from absl::AnyInvocable is not
// guaranteed to be empty, and the "not active" check depends on it being null.
absl::Status Condition::TriggerCallback(absl::Status match_status,
                                        MatchList match_list) {
  if (callback_ == nullptr) {
    return absl::FailedPreconditionError("Condition is not active");
  }
  return DaemonContext::Get().scheduler().DelayCall(
      [callback = std::exchange(callback_, nullptr),
       match_status = std::move(match_status),
       match_list = std::move(match_list)]() mutable {
        std::move(callback)(std::move(match_status), std::move(match_list));
      },
      absl::ZeroDuration(), complete_job_name_);
}
void Condition::Listener::UpdateState(const SystemState& current_state,
const SystemState& update) {
auto [match_status, match_list] =
condition_.CheckState(MergeState(current_state, update));
if (!match_status.ok() || !match_list.empty()) {
absl::Status status = condition_.TriggerCallback(std::move(match_status),
std::move(match_list));
if (!status.ok()) {
LOG(ERROR) << "Failed to trigger callback: " << status;
}
}
}
// Walks `condition` and, for every timeout leaf, schedules a job at the
// timeout's deadline.
//
// Each job pushes an empty update through `state_updater_` so that the
// condition is re-checked. Job names are "<job_name_prefix>.timer"; composite
// conditions extend the prefix with the child index. Every name is appended to
// `pending_jobs_` before scheduling, so the destructor can cancel the job.
//
// Returns InvalidArgumentError for an unknown condition or timeout type, or
// the scheduler's error if a job cannot be registered.
absl::Status Condition::RegisterTimeouts(
    const safepower_agent_proto::Condition& condition,
    absl::string_view job_name_prefix) {
  using ConditionProto = safepower_agent_proto::Condition;
  switch (condition.condition_type_case()) {
    case ConditionProto::kStateCondition:
      // State conditions need no timer.
      return absl::OkStatus();
    case ConditionProto::kAnyOf:
      RETURN_IF_ERROR(
          RegisterTimeouts(condition.any_of().conditions(), job_name_prefix));
      return absl::OkStatus();
    case ConditionProto::kAllOf:
      RETURN_IF_ERROR(
          RegisterTimeouts(condition.all_of().conditions(), job_name_prefix));
      return absl::OkStatus();
    case ConditionProto::kTimeout:
      break;
    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Unknown condition type: ", condition));
  }

  // Only timeout leaves get here.
  const TimeoutCondition& timer_spec = condition.timeout();
  absl::Time fire_at;
  if (timer_spec.time_case() == TimeoutCondition::kTimeoutMs) {
    fire_at = start_time_ + absl::Milliseconds(timer_spec.timeout_ms());
  } else if (timer_spec.time_case() == TimeoutCondition::kDeadline) {
    fire_at = ConvertTime(timer_spec.deadline());
  } else {
    return absl::InvalidArgumentError(
        absl::StrCat("Unknown timeout type: ", condition));
  }

  pending_jobs_.push_back(absl::StrCat(job_name_prefix, ".timer"));
  RETURN_IF_ERROR(DaemonContext::Get().scheduler().DelayCall(
      [this] {
        // Push an empty update so the listeners re-check the condition.
        state_updater_->UpdateState({});
      },
      fire_at - DaemonContext::Get().now(), pending_jobs_.back()));
  return absl::OkStatus();
}
// Registers timeouts for each child of a composite condition. Child `i` gets
// the job-name prefix "<job_name_prefix>.<i>", so sibling timers have distinct
// names. Stops at the first error.
absl::Status Condition::RegisterTimeouts(
    const google::protobuf::RepeatedPtrField<safepower_agent_proto::Condition>&
        conditions,
    absl::string_view job_name_prefix) {
  for (int child_index = 0; child_index < conditions.size(); ++child_index) {
    const std::string child_prefix =
        absl::StrCat(job_name_prefix, ".", child_index);
    RETURN_IF_ERROR(RegisterTimeouts(conditions[child_index], child_prefix));
  }
  return absl::OkStatus();
}
} // namespace safepower_agent