#include "action_context.h"

#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>

#include "google/protobuf/timestamp.pb.h"
#include "condition.h"
#include "convert_proto.h"
#include "daemon_context.h"
#include "safepower_agent.pb.h"
#include "state_persistence.pb.h"
#include "state_updater.h"
#include "absl/base/nullability.h"
#include "absl/functional/any_invocable.h"
#include "absl/functional/bind_front.h"
#include "absl/log/check.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
#include "absl/time/time.h"
#include "bmc/status_macros.h"

namespace safepower_agent {

using ::safepower_agent_persistence_proto::SavedActions;
using ::safepower_agent_proto::ActionState;
using ::safepower_agent_proto::ActionStateChange;
using ::safepower_agent_proto::ActionStateLog;

// Returns a fresh state log for a newly created action: stamped with the
// daemon epoch, in ACTION_STATE_INIT, with a single history entry whose
// timestamp is "now" (GetStartTime() later reads this entry as the start time).
ActionStateLog ActionContext::NewInitialState() {
  ActionStateLog log;
  log.set_current_state(safepower_agent_proto::ACTION_STATE_INIT);
  log.set_epoch_ms(DaemonContext::Get().epoch_ms());
  ActionStateChange* first_entry = log.add_history();
  SetTimestampToNow(*first_entry->mutable_changed_at());
  return log;
}

// Returns the time at which the action started, i.e. the timestamp of the
// first history entry of `state`. A log without history is a programming
// error (DFATAL); in release builds the current daemon time is used instead.
static absl::Time GetStartTime(
    const safepower_agent_proto::ActionStateLog& state) {
  if (state.history_size() > 0) [[likely]] {
    return ConvertTime(state.history(0).changed_at());
  }
  LOG(DFATAL) << "No start time in initial state";
  return DaemonContext::Get().now();
}

// Builds the context for one action run.
//
// `token` is unused here; presumably it restricts who may construct an
// ActionContext (TODO confirm against the header). `initial_state` seeds the
// state updater and also provides the start time (its first history entry, see
// GetStartTime) for the optional precondition and validation conditions.
//
// NOTE(review): the precondition_/validation_ initializers read request_ and
// initial_state. C++ runs member initializers in declaration order, so this
// assumes request_ is declared before precondition_ and validation_ in the
// header — confirm before reordering members.
ActionContext::ActionContext(CreationToken /*token*/,
                             ActionContextManager& manager,
                             std::string action_id,
                             safepower_agent_proto::StartActionRequest request,
                             Action action_impl, ActionStateLog initial_state)

    : manager_(manager),
      request_(std::move(request)),
      action_impl_(std::move(action_impl)),
      // Conditions are only built when the request specifies them.
      precondition_(
          request_.has_precondition()
              ? std::make_optional<Condition>(request_.precondition(),
                                              GetStartTime(initial_state))
              : std::nullopt),
      validation_(request_.has_validation()
                      ? std::make_optional<Condition>(
                            request_.validation(), GetStartTime(initial_state))
                      : std::nullopt),
      // initial_state is copied rather than moved: it is also read by the
      // condition initializers above (whether it could be moved depends on
      // member declaration order — see the note at the top).
      action_state_updater_(
          std::make_shared<StateUpdater<ActionStateLog>>(initial_state)),
      action_id_(std::move(action_id)) {}

// Starts waiting for `condition` to match against the manager's system state,
// invoking `callback` (under this action's task name) with the result.
//
// If the wait cannot be started, the action is moved to ACTION_STATE_ERROR and
// the failure status is returned. Entering ERROR through SetState() already
// moves action_impl_ into a scheduled Finish() call. A second, synchronous
// Finish() here would therefore either be a no-op on a moved-from Action or
// return the implementation to the manager twice. It would also lock
// actions_mutex_ (via FinishAction) while a caller such as LoadSavedActions()
// may already hold it, so it is intentionally not done.
absl::Status ActionContext::StartCheckingCondition(
    Condition& condition,
    absl::AnyInvocable<void(absl::Status, Condition::MatchList)> callback) {
  absl::Status status = condition.WaitForMatch(execution_task_name(),
                                               manager_.system_state_updater(),
                                               std::move(callback));
  if (status.ok()) {
    return absl::OkStatus();
  }
  LOG(ERROR) << "Failed to wait for condition for " << action_id_ << ": "
             << status;
  ActionStateChange change;
  SetStatus(*change.mutable_status(), status);
  SetState(ActionState::ACTION_STATE_ERROR, std::move(change));
  return status;
}

// Resumes (or starts) the action from whatever state its log is in.
//
// - INIT: persist and schedule the first transition.
// - CHECKING_PRECONDITION / VALIDATING_FINAL_STATE: restart the wait on the
//   corresponding condition; if that condition is missing, fall back to
//   RUNNING_ACTION / ERROR respectively.
// - RUNNING_ACTION: assume the action already ran and move on.
// - SUCCESS / ERROR: nothing to do.
// Any other value yields InvalidArgumentError.
absl::Status ActionContext::Activate() {
  absl::MutexLock lock(&mutex_);
  const ActionState resumed_state =
      action_state_updater_->state().current_state();
  LOG(INFO) << "Activate " << action_id_ << " in state "
            << safepower_agent_proto::ActionState_Name(resumed_state);

  switch (resumed_state) {
    case safepower_agent_proto::ACTION_STATE_INIT:
      RETURN_IF_ERROR(EnterStateInit());
      break;

    case safepower_agent_proto::ACTION_STATE_CHECKING_PRECONDITION:
      if (!precondition_.has_value()) {
        LOG(WARNING)
            << action_id_
            << " has no precondition but is in checking precondition state";
        SetState(ActionState::ACTION_STATE_RUNNING_ACTION);
        break;
      }
      return EnterStateCheckingPrecondition();

    case safepower_agent_proto::ACTION_STATE_RUNNING_ACTION:
      LOG(INFO) << action_id_
                << " resumed in running action state. Assuming the action was "
                   "already run successfully.";
      SetState(validation_.has_value()
                   ? ActionState::ACTION_STATE_VALIDATING_FINAL_STATE
                   : ActionState::ACTION_STATE_SUCCESS);
      break;

    case safepower_agent_proto::ACTION_STATE_VALIDATING_FINAL_STATE: {
      if (validation_.has_value()) {
        return EnterStateValidatingFinalState();
      }
      ActionStateChange missing_validation;
      SetStatus(*missing_validation.mutable_status(),
                absl::InvalidArgumentError(
                    "No validation specified in validating state"));
      LOG(WARNING) << action_id_
                   << " has no validation but is in validating state";
      SetState(ActionState::ACTION_STATE_ERROR, std::move(missing_validation));
      break;
    }

    case safepower_agent_proto::ACTION_STATE_ERROR:
    case safepower_agent_proto::ACTION_STATE_SUCCESS:
      // Terminal states: nothing left to resume.
      break;

    default:
      return absl::InvalidArgumentError(
          absl::StrCat("Action state is unknown:",
                       safepower_agent_proto::ActionState_Name(resumed_state)));
  }
  return absl::OkStatus();
}

// Tears down the action: cancels any pending scheduled work and, if the action
// implementation is still owned, hands it back to the manager.
ActionContext::~ActionContext() {
  // Always try to cancel any call still scheduled under this action's task
  // name. NotFound is expected when nothing is pending (the action already
  // finished or never started) and is ignored. Any other failure is unexpected
  // and is reported. It is intentionally not retried: looping in a destructor
  // on a persistently failing scheduler would hang teardown.
  const absl::Status cancel_status =
      DaemonContext::Get().scheduler().CancelCall(execution_task_name());
  if (!cancel_status.ok() && !absl::IsNotFound(cancel_status)) {
    LOG(DFATAL) << "Failed to cancel action " << action_id_ << ": "
                << cancel_status;
  }
  // If mutex_ is still held, a method is running on this object while it is
  // being destroyed; report it and leave action_impl_ untouched.
  if (!mutex_.TryLock()) {
    LOG(DFATAL) << "Action " << action_id_ << " is still running";
    return;
  }
  Action action_impl = std::move(action_impl_);
  mutex_.Unlock();
  // Finish() may call FinishAction(), which takes the manager's lock, so it
  // runs after mutex_ has been released.
  Finish(std::move(action_impl));
}

// Returns `action_impl` to the manager so the action can be reserved again.
// A null implementation means there is nothing to hand back (it was already
// returned, or never reserved — e.g. a context reloaded in a final state).
void ActionContext::Finish(Action action_impl) {
  if (action_impl == nullptr) return;
  manager_.FinishAction(request_.action(), std::move(action_impl));
}

// Appends a copy of every matched condition to `change.matching_conditions`.
static void AssignConditions(ActionStateChange& change,
                             const Condition::MatchList& matches) {
  for (const auto* matched : matches) {
    change.add_matching_conditions()->CopyFrom(*matched);
  }
}

// Transitions this action to `new_state`.
//
// `change_info` becomes the new history entry; it is timestamped now and
// records the previous state. The Enter* hook for `new_state` runs first.
// Only if the hook succeeds is the transition published through
// action_state_updater_ and persisted; otherwise the action is moved to
// ACTION_STATE_ERROR instead. SUCCESS and ERROR are final: the action
// implementation is handed back to the manager by a scheduled Finish() call.
// Persistence failures are logged, not propagated.
//
// Callers in this file invoke it with mutex_ held.
void ActionContext::SetState(safepower_agent_proto::ActionState new_state,
                             ActionStateChange change_info) {
  LOG(INFO) << "SetState " << action_id_ << " "
            << safepower_agent_proto::ActionState_Name(new_state) << ": "
            << change_info;
  ActionStateLog state_change;
  state_change.set_epoch_ms(DaemonContext::Get().epoch_ms());
  state_change.set_current_state(new_state);
  ActionStateChange* history_item = state_change.add_history();
  *history_item = std::move(change_info);
  SetTimestampToNow(*history_item->mutable_changed_at());
  history_item->set_previous_state(
      action_state_updater_->state().current_state());

  bool is_final = false;
  absl::Status status;
  switch (new_state) {
    default:
      LOG(DFATAL) << "Unexpected state: "
                  << safepower_agent_proto::ActionState_Name(new_state);
      break;
    case safepower_agent_proto::ACTION_STATE_CHECKING_PRECONDITION:
      status = EnterStateCheckingPrecondition();
      break;
    case safepower_agent_proto::ACTION_STATE_RUNNING_ACTION:
      status = EnterStateRunningAction();
      break;
    case safepower_agent_proto::ACTION_STATE_VALIDATING_FINAL_STATE:
      status = EnterStateValidatingFinalState();
      break;
    case safepower_agent_proto::ACTION_STATE_SUCCESS:
    case safepower_agent_proto::ACTION_STATE_ERROR:
      // std::exchange (not std::move) guarantees action_impl_ is left null,
      // so later Finish() calls (e.g. from the destructor) are no-ops instead
      // of returning a moved-from implementation to the manager a second time.
      status = DaemonContext::Get().scheduler().DelayCall(
          [this, action_impl = std::exchange(action_impl_, nullptr)]() mutable {
            Finish(std::move(action_impl));
          },
          absl::ZeroDuration(), execution_task_name());
      is_final = true;
      break;
  }
  if (!status.ok()) {
    LOG(ERROR) << "Failed to enter state "
               << safepower_agent_proto::ActionState_Name(new_state) << ": "
               << status;
    // ERROR is the fallback for every other failure, so a failure to enter
    // ERROR itself must not recurse. The stored state only changes after a
    // successful entry, so checking current_state() == ERROR here would never
    // stop the recursion.
    if (new_state == ActionState::ACTION_STATE_ERROR) {
      LOG(DFATAL) << "State change failed too many times";
      return;
    }
    ActionStateChange error_change;
    SetStatus(*error_change.mutable_status(), status);
    SetState(ActionState::ACTION_STATE_ERROR, std::move(error_change));
    return;
  }
  action_state_updater_->UpdateState(state_change, is_final);

  // Persist only this action's new state log.
  SavedActions saved_actions;
  auto [it, inserted] =
      saved_actions.mutable_actions()->insert({action_id_, {}});
  CHECK(inserted);
  *it->second.mutable_action_state_log() = std::move(state_change);
  absl::Status write_status =
      DaemonContext::Get().persistent_storage_manager().WriteSavedActionsChange(
          saved_actions);
  if (!write_status.ok()) {
    LOG(ERROR) << "Failed to persist action state: " << write_status;
  }
}

// Persists the original request together with the current state log (both
// fields are what ReloadActionContext() consumes), then schedules
// NextStateInit() under this action's task name.
absl::Status ActionContext::EnterStateInit() {
  SavedActions snapshot;
  auto& entry = (*snapshot.mutable_actions())[action_id_];
  *entry.mutable_original_request() = request_;
  *entry.mutable_action_state_log() = action_state_updater_->state();
  RETURN_IF_ERROR(
      DaemonContext::Get().persistent_storage_manager().WriteSavedActionsChange(
          snapshot));
  return DaemonContext::Get().scheduler().DelayCall(
      [this] { NextStateInit(); }, absl::ZeroDuration(), execution_task_name());
}

// Leaves INIT: check the precondition first if one was requested, otherwise
// go straight to running the action.
void ActionContext::NextStateInit() {
  absl::MutexLock lock(&mutex_);
  SetState(precondition_.has_value()
               ? ActionState::ACTION_STATE_CHECKING_PRECONDITION
               : ActionState::ACTION_STATE_RUNNING_ACTION);
}

// Starts waiting on the precondition; NextStatePreconditionMatched() receives
// the outcome. Callers only reach this when precondition_ holds a value.
absl::Status ActionContext::EnterStateCheckingPrecondition() {
  Condition& precondition = *precondition_;
  return StartCheckingCondition(
      precondition, [this](absl::Status status, Condition::MatchList matches) {
        NextStatePreconditionMatched(std::move(status), std::move(matches));
      });
}

// Handles the precondition result: on success the action moves on to
// RUNNING_ACTION, otherwise to ERROR. Either way the matched conditions are
// recorded in the state change.
void ActionContext::NextStatePreconditionMatched(absl::Status status,
                                                 Condition::MatchList matches) {
  absl::MutexLock lock(&mutex_);
  ActionStateChange change;
  AssignConditions(change, matches);
  if (status.ok()) {
    SetState(ActionState::ACTION_STATE_RUNNING_ACTION, std::move(change));
    return;
  }
  LOG(ERROR) << "Precondition failed: " << status;
  SetStatus(*change.mutable_status(), status);
  SetState(ActionState::ACTION_STATE_ERROR, std::move(change));
}

// Schedules the action to run. If a validation condition exists, it is first
// checked against the current system state. When it already matches, or the
// check itself fails, the action is skipped and the validation outcome is
// reported directly via NextStateValidationCompleted().
absl::Status ActionContext::EnterStateRunningAction() {
  const bool has_validation = validation_.has_value();
  if (has_validation) {
    auto [check_status, matched] =
        validation_->CheckState(manager_.system_state_updater()->state());
    const bool skip_action = !check_status.ok() || !matched.empty();
    if (skip_action) {
      return DaemonContext::Get().scheduler().DelayCall(
          std::bind(&ActionContext::NextStateValidationCompleted, this,
                    check_status, std::move(matched)),
          absl::ZeroDuration(), execution_task_name());
    }
  }
  return DaemonContext::Get().scheduler().DelayCall(
      [this] { RunAction(); }, absl::ZeroDuration(), execution_task_name());
}

void ActionContext::RunAction() {
  absl::MutexLock lock(&mutex_);
  action_impl_(request_.action(), [this](absl::Status status) {
    // NextStateActionRan must be called asynchronously to avoid deadlocks.
    absl::Status delay_status = DaemonContext::Get().scheduler().DelayCall(
        std::bind(&ActionContext::NextStateActionRan, this, status),
        absl::ZeroDuration(), execution_task_name());
    if (!delay_status.ok()) {
      LOG(DFATAL) << "Failed to delay next state: " << delay_status;
      NextStateActionRan(status);
    }
  });
}

// Handles the action's completion status. A failure moves the action to
// ERROR. On success it moves to VALIDATING_FINAL_STATE if validation was
// requested, otherwise straight to SUCCESS.
void ActionContext::NextStateActionRan(absl::Status status) {
  absl::MutexLock lock(&mutex_);
  if (status.ok()) {
    SetState(validation_.has_value()
                 ? ActionState::ACTION_STATE_VALIDATING_FINAL_STATE
                 : ActionState::ACTION_STATE_SUCCESS);
    return;
  }
  LOG(ERROR) << "Action failed: " << status;
  ActionStateChange failure;
  SetStatus(*failure.mutable_status(), status);
  SetState(ActionState::ACTION_STATE_ERROR, std::move(failure));
}

// Starts waiting on the validation condition; NextStateValidationCompleted()
// receives the outcome. Callers only reach this when validation_ holds a
// value.
absl::Status ActionContext::EnterStateValidatingFinalState() {
  Condition& validation = *validation_;
  return StartCheckingCondition(
      validation, [this](absl::Status status, Condition::MatchList matches) {
        NextStateValidationCompleted(std::move(status), std::move(matches));
      });
}

// Handles the validation result: SUCCESS when `status` is OK, ERROR (with the
// status recorded) otherwise. Matched conditions are recorded in both cases.
void ActionContext::NextStateValidationCompleted(absl::Status status,
                                                 Condition::MatchList matches) {
  absl::MutexLock lock(&mutex_);
  ActionStateChange change;
  AssignConditions(change, matches);
  const bool validated = status.ok();
  if (!validated) {
    LOG(ERROR) << "Validation failed: " << status;
    SetStatus(*change.mutable_status(), status);
  }
  SetState(validated ? ActionState::ACTION_STATE_SUCCESS
                     : ActionState::ACTION_STATE_ERROR,
           std::move(change));
}

// Registers the implementation for `action`. Registering the same action twice
// fails with AlreadyExistsError and leaves the first registration in place.
absl::Status ActionContextManager::RegisterAction(
    const safepower_agent_proto::Action& action,
    ActionContext::Action action_impl) {
  absl::MutexLock lock(&actions_mutex_);
  const bool inserted =
      actions_.try_emplace(action, std::move(action_impl)).second;
  if (inserted) return absl::OkStatus();
  return absl::AlreadyExistsError(
      absl::StrCat("Action already registered: ", action));
}

// Rebuilds and activates an ActionContext from its persisted form.
//
// Actions that had not yet reached a final state re-reserve their
// implementation (ReserveAction does not lock; the caller in this file,
// LoadSavedActions, holds actions_mutex_). Final-state actions get a null
// implementation.
//
// NOTE(review): if Activate() fails, the new context is destroyed here while
// the caller still holds actions_mutex_. ~ActionContext may call Finish() ->
// FinishAction(), which locks actions_mutex_ again. Verify this cannot happen
// for the reachable failure paths.
absl::StatusOr<std::unique_ptr<ActionContext>>
ActionContextManager::ReloadActionContext(
    std::string action_id,
    safepower_agent_persistence_proto::SavedAction saved_action) {
  Action reserved_impl = nullptr;
  const bool needs_impl = !ActionContext::IsFinalState(
      saved_action.action_state_log().current_state());
  if (needs_impl) {
    ASSIGN_OR_RETURN(reserved_impl,
                     ReserveAction(saved_action.original_request().action()));
  }

  std::unique_ptr<ActionContext> reloaded = std::make_unique<ActionContext>(
      ActionContext::CreationToken{}, *this, std::move(action_id),
      std::move(*saved_action.mutable_original_request()),
      std::move(reserved_impl),
      std::move(*saved_action.mutable_action_state_log()));
  RETURN_IF_ERROR(reloaded->Activate());
  return reloaded;
}

// Reloads every persisted action and registers it in running_actions_.
//
// Individual reload failures are logged and skipped. A duplicate action ID
// aborts loading with AlreadyExistsError.
absl::Status ActionContextManager::LoadSavedActions() {
  ASSIGN_OR_RETURN(
      SavedActions saved_actions,
      DaemonContext::Get().persistent_storage_manager().ReadSavedActions());
  // Declared before the lock so it is destroyed only after the lock is
  // released (locals are destroyed in reverse order). ~ActionContext may call
  // FinishAction(), which locks actions_mutex_. This is the same pattern
  // StartAction() uses.
  std::unique_ptr<ActionContext> rejected_context;
  absl::MutexLock lock(&actions_mutex_);

  for (auto& [action_id, saved_action] : *saved_actions.mutable_actions()) {
    // NOTE(review): ReloadActionContext() destroys the context itself if
    // Activate() fails, i.e. while actions_mutex_ is held here.
    absl::StatusOr<std::unique_ptr<ActionContext>> action_context_or =
        ReloadActionContext(action_id, std::move(saved_action));
    if (!action_context_or.ok()) {
      LOG(ERROR) << "Failed to reload action: " << action_context_or.status();
      continue;
    }
    std::unique_ptr<ActionContext> action_context =
        *std::move(action_context_or);
    absl::string_view action_id_ref = action_context->action_id();
    auto [it, inserted] =
        running_actions_.try_emplace(action_id_ref, std::move(action_context));
    if (!inserted) {
      // try_emplace leaves its arguments untouched when the key exists, so
      // action_context still owns the context. Park it outside the lock's
      // lifetime instead of destroying it under actions_mutex_.
      rejected_context = std::move(action_context);
      return absl::AlreadyExistsError(
          absl::StrCat("Duplicate action ID: ", rejected_context->action_id()));
    }
  }
  return absl::OkStatus();
}

// Takes the registered implementation of `action` out of actions_, leaving a
// null slot that marks the action as in use until FinishAction() puts it back.
//
// Returns NotFoundError for unregistered actions and FailedPreconditionError if
// the action is already reserved. Does not lock actions_mutex_ itself; the
// callers in this file (StartAction, and ReloadActionContext via
// LoadSavedActions) hold it.
absl::StatusOr<ActionContext::Action> ActionContextManager::ReserveAction(
    const safepower_agent_proto::Action& action) {
  auto found_action_impl = actions_.find(action);
  if (found_action_impl == actions_.end()) {
    return absl::NotFoundError(absl::StrCat("Action not found: ", action));
  }
  if (found_action_impl->second == nullptr) {
    return absl::FailedPreconditionError(
        absl::StrCat("Action already started: ", action));
  }
  // The null slot is the "reserved" marker checked above and in
  // FinishAction(). std::move would leave the slot in an unspecified state
  // (e.g. a moved-from std::function may stay non-empty), so reset it
  // explicitly.
  return std::exchange(found_action_impl->second, nullptr);
}

// Returns a new action ID of the form "<daemon epoch_ms>-<sequence number>".
std::string ActionContextManager::NextActionId() {
  const auto sequence = next_action_id_++;
  return absl::StrFormat("%u-%u", DaemonContext::Get().epoch_ms(), sequence);
}

// Reserves the requested action, creates and activates a context for it with a
// fresh ID, and registers it in running_actions_. Returns a non-owning pointer
// to the registered context.
absl::StatusOr<ActionContext*> ActionContextManager::StartAction(
    safepower_agent_proto::StartActionRequest request) {
  // Declared ahead of the lock on purpose: locals are destroyed in reverse
  // order, so if this function bails out the context is destroyed only after
  // actions_mutex_ is released.
  std::unique_ptr<ActionContext> new_context;
  absl::MutexLock lock(&actions_mutex_);
  ASSIGN_OR_RETURN(ActionContext::Action reserved_impl,
                   ReserveAction(request.action()));
  new_context = std::make_unique<ActionContext>(
      ActionContext::CreationToken{}, *this, NextActionId(), std::move(request),
      std::move(reserved_impl));
  RETURN_IF_ERROR(new_context->Activate());

  absl::string_view new_id = new_context->action_id();
  auto [slot, inserted] =
      running_actions_.try_emplace(new_id, std::move(new_context));
  if (inserted) {
    return slot->second.get();
  }
  return absl::AlreadyExistsError(
      absl::StrCat("Duplicate action ID: ", new_context->action_id()));
}

// Puts `action_impl` back into the (null, i.e. reserved) slot for `action`.
// Unknown actions and actions that are not currently reserved are reported
// with DFATAL and otherwise ignored.
void ActionContextManager::FinishAction(
    const safepower_agent_proto::Action& action,
    ActionContext::Action action_impl) {
  absl::MutexLock lock(&actions_mutex_);
  auto slot = actions_.find(action);
  if (slot == actions_.end()) {
    LOG(DFATAL) << "Action not found: " << action;
  } else if (slot->second != nullptr) {
    LOG(DFATAL) << "Action not started: " << action;
  } else {
    slot->second = std::move(action_impl);
  }
}

// Looks up a running action by ID. Returns a non-owning pointer, or nullptr if
// no such action is registered.
ActionContext* ActionContextManager::GetActionContext(
    absl::string_view action_id) {
  absl::MutexLock lock(&actions_mutex_);
  const auto found = running_actions_.find(action_id);
  return found != running_actions_.end() ? found->second.get() : nullptr;
}

// Appends every registered action (reserved or not) to `response.actions`.
void ActionContextManager::GetSupportedActions(
    safepower_agent_proto::GetSupportedActionsResponse& response) const {
  absl::MutexLock lock(&actions_mutex_);
  for (const auto& registered : actions_) {
    *response.add_actions() = registered.first;
  }
}

}  // namespace safepower_agent
