| #include "condition.h" |
| |
| #include <memory> |
| #include <string> |
| #include <tuple> |
| #include <utility> |
| #include <vector> |
| |
| #include "google/protobuf/timestamp.pb.h" |
| #include "convert_proto.h" |
| #include "daemon_context.h" |
| #include "safepower_agent.pb.h" |
| #include "state_fields.h" |
| #include "state_merge.h" |
| #include "state_updater.h" |
| #include "absl/base/nullability.h" |
| #include "absl/functional/any_invocable.h" |
| #include "absl/log/check.h" |
| #include "absl/log/log.h" |
| #include "absl/status/status.h" |
| #include "absl/status/statusor.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/string_view.h" |
| #include "absl/time/time.h" |
| #include "google/protobuf/repeated_ptr_field.h" |
| #include "bmc/status_macros.h" |
| |
| namespace safepower_agent { |
| |
| using ::safepower_agent_proto::AllOfCondition; |
| using ::safepower_agent_proto::AnyOfCondition; |
| using ::safepower_agent_proto::BootState; |
| using ::safepower_agent_proto::ComparisonType; |
| using ::safepower_agent_proto::ComponentState; |
| using ::safepower_agent_proto::ConnectionState; |
| using ::safepower_agent_proto::DaemonState; |
| using ::safepower_agent_proto::PowerState; |
| using ::safepower_agent_proto::StateMatchedCondition; |
| using ::safepower_agent_proto::SystemState; |
| using ::safepower_agent_proto::TimeoutCondition; |
| |
| template <typename T> |
| absl::StatusOr<bool> Compare(ComparisonType comparison_type, T&& a, |
| T&& expected) { |
| switch (comparison_type) { |
| case ComparisonType::COMPARISON_TYPE_EQ: |
| return a == expected; |
| case ComparisonType::COMPARISON_TYPE_NE: |
| return a != expected; |
| case ComparisonType::COMPARISON_TYPE_GT: |
| return a > expected; |
| case ComparisonType::COMPARISON_TYPE_GE: |
| return a >= expected; |
| case ComparisonType::COMPARISON_TYPE_LT: |
| return a < expected; |
| case ComparisonType::COMPARISON_TYPE_LE: |
| return a <= expected; |
| default: |
| return absl::InvalidArgumentError( |
| absl::StrCat("Unknown comparison type: ", comparison_type)); |
| return false; |
| } |
| } |
| |
| template <typename Proto, typename FieldDefT> |
| absl::StatusOr<bool> CompareField(const FieldDefT& field, |
| ComparisonType comparison_type, |
| const Proto& a, const Proto& expected) { |
| if (field.is_in(a) && field.is_in(expected)) { |
| return Compare(comparison_type, field.get_from(a), |
| field.get_from(expected)); |
| } else { |
| return true; |
| } |
| } |
| |
| static bool CompareResultMatches(const absl::StatusOr<bool>& result) { |
| return result.ok() && *result; |
| } |
| |
| template <typename Proto> |
| absl::StatusOr<bool> Compare(ComparisonType comparison_type, const Proto& a, |
| const Proto& expected) { |
| return std::apply( |
| [&](auto&&... field_defs) { |
| absl::StatusOr<bool> result = true; |
| (CompareResultMatches( |
| result = CompareField(field_defs, comparison_type, a, expected)) && |
| ...); |
| return result; |
| }, |
| ProtoFields<Proto>::kFields); |
| } |
| |
| template <> |
| absl::StatusOr<bool> Compare( |
| safepower_agent_proto::ComparisonType comparison_type, |
| const ComponentState& a, const ComponentState& expected) { |
| if (a.state_case() != expected.state_case()) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("Mismatched component state types: ", a.state_case(), |
| " vs expected ", expected.state_case())); |
| } |
| switch (a.state_case()) { |
| case ComponentState::kDaemonState: |
| return Compare(comparison_type, a.daemon_state(), |
| expected.daemon_state()); |
| case ComponentState::kBootState: |
| return Compare(comparison_type, a.boot_state(), expected.boot_state()); |
| case ComponentState::kPowerState: |
| return Compare(comparison_type, a.power_state(), expected.power_state()); |
| case ComponentState::kConnectionState: |
| return Compare(comparison_type, a.connection_state(), |
| expected.connection_state()); |
| case ComponentState::STATE_NOT_SET: |
| return absl::InvalidArgumentError("Component state not set"); |
| } |
| return false; |
| } |
| |
// Result of evaluating a condition tree. The first element is the aggregated
// status. The second lists the conditions that matched or that produced an
// error. An OK status with an empty list means "no match".
using MatchResult = std::tuple<absl::Status, Condition::MatchList>;

// Forward declaration of the top-level dispatcher. The any_of / all_of
// overloads below recurse into it, but it is defined only after all the
// per-kind overloads it dispatches to.
static MatchResult ConditionMatches(
    const safepower_agent_proto::Condition& condition,
    const safepower_agent_proto::SystemState& system_state,
    absl::Time start_time);
| |
| static MatchResult ConditionMatches(const AnyOfCondition& any_of, |
| const SystemState& system_state, |
| absl::Time start_time) { |
| Condition::MatchList result; |
| absl::Status status = absl::OkStatus(); |
| for (const safepower_agent_proto::Condition& condition : |
| any_of.conditions()) { |
| auto [match_status, match_list] = |
| ConditionMatches(condition, system_state, start_time); |
| result.insert(result.end(), match_list.begin(), match_list.end()); |
| status.Update(match_status); |
| } |
| return {status, result}; |
| } |
| |
| static MatchResult ConditionMatches(const AllOfCondition& all_of, |
| const SystemState& system_state, |
| absl::Time start_time) { |
| bool matches_all = true; |
| absl::Status status = absl::OkStatus(); |
| Condition::MatchList result; |
| for (const safepower_agent_proto::Condition& condition : |
| all_of.conditions()) { |
| auto [match_status, match_list] = |
| ConditionMatches(condition, system_state, start_time); |
| result.insert(result.end(), match_list.begin(), match_list.end()); |
| if (match_list.empty()) { |
| // Make sure we don't return early if there are any other conditions. It |
| // is possible that the other conditions will return an error, which will |
| // override the return value. |
| matches_all = false; |
| } |
| status.Update(match_status); |
| } |
| if (!matches_all && status.ok()) { |
| result.clear(); |
| } |
| return {status, result}; |
| } |
| |
| static absl::StatusOr<bool> ConditionMatches( |
| const StateMatchedCondition& condition, const SystemState& system_state) { |
| auto node_state_found = system_state.node_state().find( |
| condition.system_component().node_entity_tag()); |
| if (node_state_found == system_state.node_state().end()) { |
| return false; |
| } |
| auto component_state_found = node_state_found->second.component_state().find( |
| condition.system_component().name()); |
| if (component_state_found == |
| node_state_found->second.component_state().end()) { |
| return false; |
| } |
| |
| if (condition.component_state().state_case() == |
| ComponentState::STATE_NOT_SET) { |
| return absl::InvalidArgumentError("Component state not set"); |
| } |
| |
| ASSIGN_OR_RETURN(bool matches, Compare(condition.comparison_type(), |
| component_state_found->second, |
| condition.component_state())); |
| if (matches && condition.abort()) { |
| return absl::AbortedError( |
| absl::StrCat("Abort condition matched: ", condition)); |
| } |
| return matches; |
| } |
| |
| static absl::StatusOr<bool> ConditionMatches(const TimeoutCondition& condition, |
| absl::Time start_time) { |
| bool matches = false; |
| switch (condition.time_case()) { |
| case TimeoutCondition::kTimeoutMs: |
| matches = start_time + absl::Milliseconds(condition.timeout_ms()) <= |
| DaemonContext::Get().now(); |
| break; |
| case TimeoutCondition::kDeadline: |
| matches = ConvertTime(condition.deadline()) <= DaemonContext::Get().now(); |
| break; |
| default: |
| return absl::InvalidArgumentError( |
| absl::StrCat("Unknown timeout type: ", condition)); |
| } |
| if (matches && condition.abort()) { |
| return absl::AbortedError( |
| absl::StrCat("Abort condition matched: ", condition)); |
| } |
| return matches; |
| } |
| |
| static MatchResult GetMatchResult( |
| const safepower_agent_proto::Condition& condition, |
| absl::StatusOr<bool> matches) { |
| if (!matches.ok()) { |
| return {matches.status(), {&condition}}; |
| } |
| if (*matches) { |
| return {absl::OkStatus(), {&condition}}; |
| } |
| return {absl::OkStatus(), {}}; |
| } |
| |
| static MatchResult ConditionMatches( |
| const safepower_agent_proto::Condition& condition, |
| const safepower_agent_proto::SystemState& system_state, |
| absl::Time start_time) { |
| switch (condition.condition_type_case()) { |
| case safepower_agent_proto::Condition::kAnyOf: |
| return ConditionMatches(condition.any_of(), system_state, start_time); |
| case safepower_agent_proto::Condition::kAllOf: |
| return ConditionMatches(condition.all_of(), system_state, start_time); |
| case safepower_agent_proto::Condition::kStateCondition: |
| return GetMatchResult( |
| condition, |
| ConditionMatches(condition.state_condition(), system_state)); |
| case safepower_agent_proto::Condition::kTimeout: |
| return GetMatchResult(condition, |
| ConditionMatches(condition.timeout(), start_time)); |
| default: |
| return {absl::InvalidArgumentError( |
| absl::StrCat("Unknown condition type: ", condition)), |
| {&condition}}; |
| } |
| } |
| |
// Creates an evaluator for `condition`. `original_start_time` is the
// reference point for relative (`timeout_ms`) timeouts, both when the
// condition is evaluated (CheckState) and when timers are scheduled
// (RegisterTimeouts).
// NOTE(review): `listener_` receives `*this` while the object is still being
// constructed; presumably Listener only stores the reference. Confirm against
// the Listener definition.
Condition::Condition(const safepower_agent_proto::Condition& condition,
                     absl::Time original_start_time)
    : listener_(*this),
      condition_(condition),
      start_time_(original_start_time) {}
| |
| Condition::~Condition() { |
| for (const std::string& job_name : pending_jobs_) { |
| absl::Status status = DaemonContext::Get().scheduler().CancelCall(job_name); |
| if (!status.ok() && !absl::IsNotFound(status)) { |
| LOG(ERROR) << "Failed to cancel job: " << status; |
| } |
| } |
| } |
| |
| absl::Status Condition::WaitForMatch( |
| absl::string_view unique_id, |
| absl::Nonnull<std::shared_ptr<StateUpdater<SystemState>>> state_updater, |
| absl::AnyInvocable<void(absl::Status, Condition::MatchList) &&> callback) { |
| complete_job_name_ = absl::StrCat(unique_id, ".complete"); |
| state_updater_ = std::move(state_updater); |
| if (callback == nullptr) { |
| return absl::InvalidArgumentError("Callback must not be null"); |
| } |
| callback_ = std::move(callback); |
| RETURN_IF_ERROR(RegisterTimeouts(condition_, unique_id)); |
| listener_.Listen(state_updater_); |
| return absl::OkStatus(); |
| } |
| |
| std::tuple<absl::Status, Condition::MatchList> Condition::CheckState( |
| const SystemState& state) { |
| return ConditionMatches(condition_, state, start_time_); |
| } |
| |
| absl::Status Condition::TriggerCallback(absl::Status match_status, |
| MatchList match_list) { |
| if (callback_ == nullptr) { |
| return absl::FailedPreconditionError("Condition is not active"); |
| } |
| return DaemonContext::Get().scheduler().DelayCall( |
| [callback = std::move(callback_), match_status = std::move(match_status), |
| match_list = std::move(match_list)]() mutable { |
| std::move(callback)(std::move(match_status), std::move(match_list)); |
| }, |
| absl::ZeroDuration(), complete_job_name_); |
| } |
| |
| void Condition::Listener::UpdateState(const SystemState& current_state, |
| const SystemState& update) { |
| auto [match_status, match_list] = |
| condition_.CheckState(MergeState(current_state, update)); |
| if (!match_status.ok() || !match_list.empty()) { |
| absl::Status status = condition_.TriggerCallback(std::move(match_status), |
| std::move(match_list)); |
| if (!status.ok()) { |
| LOG(ERROR) << "Failed to trigger callback: " << status; |
| } |
| } |
| } |
| |
| absl::Status Condition::RegisterTimeouts( |
| const safepower_agent_proto::Condition& condition, |
| absl::string_view job_name_prefix) { |
| switch (condition.condition_type_case()) { |
| case safepower_agent_proto::Condition::kAnyOf: |
| RETURN_IF_ERROR( |
| RegisterTimeouts(condition.any_of().conditions(), job_name_prefix)); |
| break; |
| case safepower_agent_proto::Condition::kAllOf: |
| RETURN_IF_ERROR( |
| RegisterTimeouts(condition.all_of().conditions(), job_name_prefix)); |
| break; |
| case safepower_agent_proto::Condition::kStateCondition: |
| break; |
| case safepower_agent_proto::Condition::kTimeout: { |
| absl::Time deadline; |
| switch (condition.timeout().time_case()) { |
| case TimeoutCondition::kTimeoutMs: |
| deadline = start_time_ + |
| absl::Milliseconds(condition.timeout().timeout_ms()); |
| break; |
| case TimeoutCondition::kDeadline: |
| deadline = ConvertTime(condition.timeout().deadline()); |
| break; |
| default: |
| return absl::InvalidArgumentError( |
| absl::StrCat("Unknown timeout type: ", condition)); |
| } |
| pending_jobs_.push_back(absl::StrCat(job_name_prefix, ".timer")); |
| RETURN_IF_ERROR(DaemonContext::Get().scheduler().DelayCall( |
| [this] { |
| // Push empty update to trigger a check. |
| state_updater_->UpdateState({}); |
| }, |
| deadline - DaemonContext::Get().now(), pending_jobs_.back())); |
| break; |
| } |
| default: |
| return absl::InvalidArgumentError( |
| absl::StrCat("Unknown condition type: ", condition)); |
| } |
| return absl::OkStatus(); |
| } |
| |
| absl::Status Condition::RegisterTimeouts( |
| const google::protobuf::RepeatedPtrField<safepower_agent_proto::Condition>& |
| conditions, |
| absl::string_view job_name_prefix) { |
| int index = 0; |
| for (const safepower_agent_proto::Condition& condition : conditions) { |
| RETURN_IF_ERROR(RegisterTimeouts( |
| condition, absl::StrCat(job_name_prefix, ".", index++))); |
| } |
| return absl::OkStatus(); |
| } |
| |
| } // namespace safepower_agent |