// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "hoth_updater_cli.hpp"

#include "message_util.hpp"

#include <boost/endian/arithmetic.hpp>
#include <sdbusplus/bus.hpp>
#include <sdbusplus/message.hpp>
#include <stdplus/print.hpp>
#include <stdplus/raw.hpp>

#include <chrono>
#include <exception>
#include <format>
#include <fstream>
#include <functional>
#include <iostream>
#include <iterator>
#include <span>
#include <string_view>
#include <thread>
#include <variant>
#include <vector>

using namespace std::chrono_literals;

namespace google::hoth::tools
{

namespace
{

// Per-call D-Bus timeout. For reference, a normal update run measured with
// `time` took roughly 9 seconds of wall-clock time:
//
//   root@mvbbj12-nfd01:/tmp# time ./hoth_updater --hoth_id=<HOTH_ID> update
//       --image=<IMAGE_FILE_NAME> --address=0x01ef0000
//   real    0m8.952s
//   user    0m0.290s
//   sys     0m0.030s
//
// 15 seconds presumably leaves headroom over that measurement.
constexpr std::chrono::seconds kCallTimeout{15};
// Pause between a failed attempt and the next one in retrying callers.
constexpr std::chrono::seconds kRetryDelay{30};
// Total number of attempts (the first one included) made by retrying callers.
constexpr int kAttemptLimit{3};

// Returns the D-Bus service (bus) name for the Hoth instance `hoth_id`.
// An empty id or the literal "bmc" selects the default service
// "xyz.openbmc_project.Control.Hoth". Any other id is appended as one more
// dot-separated component, e.g. "foo" -> "xyz.openbmc_project.Control.Hoth.foo".
std::string getHothService(std::string_view hoth_id)
{
    constexpr std::string_view kBaseService =
        "xyz.openbmc_project.Control.Hoth";
    const bool use_default = hoth_id.empty() || hoth_id == "bmc";
    if (use_default)
    {
        return std::string(kBaseService);
    }
    std::string qualified(kBaseService);
    qualified.append(".").append(hoth_id);
    return qualified;
}

// Returns the D-Bus object path for the Hoth instance `hoth_id`.
// An empty id or the literal "bmc" selects "/xyz/openbmc_project/Control/Hoth".
// Any other id becomes one more path component, e.g. "foo" ->
// "/xyz/openbmc_project/Control/Hoth/foo".
std::string getHothObject(std::string_view hoth_id)
{
    const bool is_default = hoth_id.empty() || hoth_id == "bmc";
    std::string path = "/xyz/openbmc_project/Control/Hoth";
    if (is_default)
    {
        return path;
    }
    path.push_back('/');
    path.append(hoth_id.begin(), hoth_id.end());
    return path;
}

// Creates, but does not send, a call to `method` on the
// xyz.openbmc_project.Control.Hoth interface of the Hoth selected by
// `hoth_id`. Callers append any arguments and dispatch it with bus.call().
sdbusplus::message::message hothMessage(sdbusplus::bus::bus& bus,
                                        std::string_view hoth_id,
                                        const char* method)
{
    const std::string target_service = getHothService(hoth_id);
    const std::string target_path = getHothObject(hoth_id);
    constexpr const char* kInterface = "xyz.openbmc_project.Control.Hoth";
    return bus.new_method_call(target_service.c_str(), target_path.c_str(),
                               kInterface, method);
}

// Creates, but does not send, an org.freedesktop.DBus.Properties.Get call
// addressed to the service/object of the Hoth selected by `hoth_id`. The
// caller appends the interface name and property name before calling it.
sdbusplus::message::message hothPropertyMessage(sdbusplus::bus::bus& bus,
                                                std::string_view hoth_id)
{
    const auto service_name = getHothService(hoth_id);
    const auto object_path = getHothObject(hoth_id);
    constexpr const char* kPropertiesInterface =
        "org.freedesktop.DBus.Properties";
    return bus.new_method_call(service_name.c_str(), object_path.c_str(),
                               kPropertiesInterface, "Get");
}

// Sends `command` as the byte-array argument of the SendHostCommand D-Bus
// method on the selected Hoth and returns the byte-array reply.
// `timeout` falls back to kCallTimeout when not given. Any exception thrown
// by the D-Bus call or while reading the reply propagates to the caller.
std::vector<uint8_t> sendHostCommand(
    sdbusplus::bus::bus& bus, std::string_view hoth_id,
    const std::span<const uint8_t> command,
    std::optional<sdbusplus::SdBusDuration> timeout = std::nullopt)
{
    auto request = hothMessage(bus, hoth_id, "SendHostCommand");
    request.append(command);
    const auto effective_timeout = timeout.value_or(kCallTimeout);
    auto reply = bus.call(request, effective_timeout);
    std::vector<uint8_t> payload;
    reply.read(payload);
    return payload;
}

// Best-effort read of `property` from the
// xyz.openbmc_project.Control.Hoth.State interface of the selected Hoth, via
// org.freedesktop.DBus.Properties.Get. Any std::exception raised while
// building the call, performing it, or decoding the reply as a T yields
// std::nullopt instead of propagating. This lets callers gather as many
// properties as possible.
template <typename T>
std::optional<T> getHothStateProperty(
    sdbusplus::bus::bus& bus, std::string_view hoth_id,
    std::string_view property,
    std::optional<sdbusplus::SdBusDuration> timeout = std::nullopt)
{
    std::optional<T> result;
    try
    {
        auto request = hothPropertyMessage(bus, hoth_id);
        request.append("xyz.openbmc_project.Control.Hoth.State", property);
        auto reply = bus.call(request, timeout.value_or(kCallTimeout));
        std::variant<T> holder{};
        reply.read(holder);
        result = std::get<T>(holder);
    }
    catch (const std::exception&)
    {
        // Intentionally ignored: property reads are best-effort.
    }
    return result;
}

// Formats the contained value with std::format's default "{}" formatting, or
// returns the placeholder "n/a" when `value` is empty.
template <typename T>
inline std::string optionalToString(const std::optional<T>& value)
{
    if (!value)
    {
        return "n/a";
    }
    return std::format("{}", value.value());
}

} // namespace

// Sends `image` to the selected Hoth through the UpdateFirmware D-Bus method.
// Makes a single attempt bounded by kCallTimeout, and exceptions from the
// call propagate. Completion is not awaited here; doUpdate polls
// GetFirmwareUpdateStatus afterwards.
void HothUpdaterCLI::updateFirmware(sdbusplus::bus::bus& bus,
                                    std::string_view hoth_id,
                                    const std::span<const uint8_t> image)
{
    auto request = hothMessage(bus, hoth_id, "UpdateFirmware");
    request.append(image);
    bus.call(request, kCallTimeout);
}

// Writes `image` at SPI address `address` on the selected Hoth through the
// SpiWrite D-Bus method. The call arguments are (u address, ay data).
//
// The call is attempted up to kAttemptLimit times. Each failure is reported
// on stdout. Before every attempt except the last, the function waits
// kRetryDelay.
//
// Throws std::runtime_error if `address` is not provided, or once every
// attempt has failed.
void HothUpdaterCLI::spiWrite(sdbusplus::bus::bus& bus,
                              std::string_view hoth_id,
                              const std::span<const uint8_t> image,
                              std::optional<uint32_t> address)
{
    // The request is built as (u address, ay data). Without an address there
    // is no well-formed request to send. Fail fast rather than issuing an
    // argument-less call kAttemptLimit times with retry sleeps in between.
    if (!address)
    {
        throw std::runtime_error("SPI write requires an address");
    }

    for (int attempt = 1; attempt <= kAttemptLimit; ++attempt)
    {
        try
        {
            sdbusplus::message::message msg =
                hothMessage(bus, hoth_id, "SpiWrite");
            msg.append(*address); // u
            msg.append(image);    // ay
            bus.call(msg, kCallTimeout);
            return;
        }
        catch (const std::exception& ex)
        {
            std::cout << "Exception caught: " << ex.what() << '\n';
            if (attempt == kAttemptLimit)
            {
                // No attempts left: don't announce a retry or sleep before
                // giving up.
                break;
            }
            // Flush so the message is visible during the sleep, even when
            // stdout is a pipe.
            std::cout << "Will retry in " << kRetryDelay.count() << " seconds"
                      << '\n'
                      << std::flush;
            std::this_thread::sleep_for(kRetryDelay);
        }
    }
    throw std::runtime_error("Retry attempt limit exhausted.");
}

// Calls GetFirmwareUpdateStatus on the selected Hoth and converts the
// returned status string into the generated FirmwareUpdateStatus enum.
// Any exception from the D-Bus call or from the string conversion
// propagates to the caller.
FirmwareUpdateStatus
    HothUpdaterCLI::getFirmwareUpdateStatus(sdbusplus::bus::bus& bus,
                                            std::string_view hoth_id)
{
    namespace hoth_server = sdbusplus::xyz::openbmc_project::Control::server;

    auto request = hothMessage(bus, hoth_id, "GetFirmwareUpdateStatus");
    auto reply = bus.call(request, kCallTimeout);
    std::string status_text;
    reply.read(status_text);
    return hoth_server::Hoth::convertFirmwareUpdateStatusFromString(
        status_text);
}

// Reads the whole file `filename` in binary mode and returns its bytes.
// Throws std::ios_base::failure if the file cannot be opened.
std::vector<uint8_t>
    HothUpdaterCLI::readFileIntoByteArray(std::string_view filename)
{
    std::ifstream image_file;
    // Turn an open failure into an exception instead of an empty image.
    image_file.exceptions(std::ios::failbit);
    // A std::string_view is not guaranteed to be NUL-terminated, so
    // filename.data() is not a safe C path. Build an owning std::string.
    image_file.open(std::string(filename), std::ios::binary);
    // istreambuf_iterator pulls raw bytes straight from the stream buffer
    // (no whitespace skipping or formatted extraction).
    std::vector<uint8_t> image(std::istreambuf_iterator<char>(image_file), {});
    return image;
}

// Implements the `update` subcommand. It reads args.imageFilename, pushes the
// image with args.updateMethod ("spi" or "update_firmware"), then polls
// GetFirmwareUpdateStatus once per second while it reports InProgress.
// The 5-minute deadline is measured from the start of this function, so it
// also covers reading the image and the write itself.
//
// Throws std::runtime_error for "payload_update" (unsupported) or any unknown
// method, when the deadline passes, or when the final status is not Done.
// Errors from reading the file or from the write propagate.
void HothUpdaterCLI::doUpdate(const Args& args)
{
    sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
    const auto end_time = std::chrono::steady_clock::now() + 5min;

    std::vector<uint8_t> image = readFileIntoByteArray(args.imageFilename);

    if (args.updateMethod == "spi")
    {
        spiWrite(bus, args.hothId, image, args.address);
    }
    else if (args.updateMethod == "update_firmware")
    {
        updateFirmware(bus, args.hothId, image);
    }
    else if (args.updateMethod == "payload_update")
    {
        throw std::runtime_error("Payload update is not supported yet");
    }
    else
    {
        // Falling through would skip the write entirely and then report the
        // status an earlier update left behind. That could be a false
        // success, so reject the method explicitly.
        throw std::runtime_error(
            std::format("Unknown update method: {}", args.updateMethod));
    }

    FirmwareUpdateStatus status = getFirmwareUpdateStatus(bus, args.hothId);
    while (status == FirmwareUpdateStatus::InProgress)
    {
        if (std::chrono::steady_clock::now() > end_time)
        {
            throw std::runtime_error("Timed out updating firmware");
        }
        std::this_thread::sleep_for(1s);
        status = getFirmwareUpdateStatus(bus, args.hothId);
    }
    if (status != FirmwareUpdateStatus::Done)
    {
        throw std::runtime_error("Update failed");
    }
}

using namespace boost::endian;

// Sends the fixed version-strings host command to the selected Hoth and
// decodes the reply bytes as HothVersionStringsRsp with
// stdplus::raw::copyFrom. Throws std::runtime_error("Failed to get versions")
// when the reply header's result field is non-zero. Exceptions from the
// D-Bus call or from copyFrom propagate (presumably including a reply
// shorter than the struct; verify against stdplus).
HothVersionStringsRsp HothUpdaterCLI::getHothVersion(sdbusplus::bus::bus& bus,
                                                     std::string_view hoth_id)
{
    // NOTE(review): these bytes appear to form an EC-style host-command
    // header: struct version 0x03; 0xfb as a checksum (all bytes sum to
    // 0x100, i.e. 0 mod 256); command 0x0002 little-endian; then zero
    // version/reserved/length fields. Confirm against the Hoth host-command
    // spec.
    static constexpr uint8_t kVersionStringsCommand[] = {
        0x03, 0xfb, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00};

    const std::vector<uint8_t> reply =
        sendHostCommand(bus, hoth_id, kVersionStringsCommand);
    auto decoded = stdplus::raw::copyFrom<HothVersionStringsRsp>(reply);

    if (decoded.header.result == 0)
    {
        return decoded;
    }
    throw std::runtime_error("Failed to get versions");
}

// Implements the `firmware-version` subcommand. It prints the RO version
// string (args.ro true) or the RW version string (args.ro false) reported by
// the selected Hoth to stdout, with no trailing newline.
void HothUpdaterCLI::doFirmwareVersion(const Args& args)
{
    sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
    auto response = getHothVersion(bus, args.hothId);

    // The version fields are presumably fixed-size char arrays, since the
    // whole struct is raw-copied from the device reply. Bound the read by
    // the field size so a string that fills its field without a NUL
    // terminator is not over-read.
    std::string_view version =
        args.ro ? std::string_view(response.version.version_string_ro,
                                   sizeof(response.version.version_string_ro))
                : std::string_view(response.version.version_string_rw,
                                   sizeof(response.version.version_string_rw));
    // Stop at the first NUL, if any. substr(0, npos) keeps the whole field.
    version = version.substr(0, version.find('\0'));

    std::cout << version;
}

// Splits `s` on every occurrence of `delim`.
// Leading and interior empty fields are kept (",a,,b" -> {"", "a", "", "b"}).
// A trailing empty field is dropped ("a," -> {"a"}), and an empty input
// yields an empty vector.
std::vector<std::string> splitString(const std::string& s, const char delim)
{
    std::vector<std::string> pieces;
    std::size_t start = 0;
    for (std::size_t pos = s.find(delim); pos != std::string::npos;
         pos = s.find(delim, start))
    {
        pieces.push_back(s.substr(start, pos - start));
        start = pos + 1;
    }
    // Whatever follows the last delimiter becomes the final piece, but only
    // if it is non-empty.
    if (start < s.size())
    {
        pieces.push_back(s.substr(start));
    }
    return pieces;
}

// Collects activation statistics from the
// xyz.openbmc_project.Control.Hoth.State properties of the selected Hoth.
// Each property is read independently and on a best-effort basis. A failed
// read leaves its field as std::nullopt rather than aborting the collection.
HothActivationStatistics
    HothUpdaterCLI::getHothActivationStatistics(sdbusplus::bus::bus& bus,
                                                std::string_view hoth_id)
{
    auto read_u32 = [&bus, hoth_id](std::string_view name) {
        return getHothStateProperty<uint32_t>(bus, hoth_id, name);
    };

    HothActivationStatistics stats;
    stats.rw_failure_code = read_u32("FirmwareUpdateFailureCode");
    stats.rw_failed_minor = read_u32("FirmwareUpdateFailedMinor");
    stats.ro_failure_code = read_u32("BootloaderUpdateFailureCode");
    stats.reset_flags = read_u32("ResetFlags");
    stats.uptime_us = getHothStateProperty<uint64_t>(bus, hoth_id, "UpTime");
    return stats;
}

// Implements the `activation-check` subcommand and emits textproto on stdout.
// It first prints `installed_version` (the expected RW version). Then, for
// each id in the comma-separated args.hothId, in order, it prints an
// `activated_versions` entry. An empty id is reported as "active" and
// addresses the default Hoth. The running RW version is compared with any
// "/<family>" suffix removed. On the first Hoth whose version differs from
// args.expectedRwVersion, it prints a `notes` line with best-effort
// activation statistics and throws std::runtime_error.
void HothUpdaterCLI::doActivationCheck(const Args& args)
{
    sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
    stdplus::print(stdout, "installed_version: \"{}\"\n",
                   args.expectedRwVersion);

    std::vector<std::string> hoth_ids = splitString(args.hothId, ',');
    if (hoth_ids.empty())
    {
        hoth_ids.push_back("");
    }

    for (const std::string& hoth_id : hoth_ids)
    {
        auto response = getHothVersion(bus, hoth_id);

        // The version field is raw-copied from the device reply and might not
        // be NUL-terminated. Bound the read by the field size, then stop at
        // the first NUL if there is one.
        std::string_view version_stripped(
            response.version.version_string_rw,
            sizeof(response.version.version_string_rw));
        version_stripped =
            version_stripped.substr(0, version_stripped.find('\0'));

        // Strip out the hoth family in the version string, because it's not
        // guaranteed to be returned completely.
        if (const size_t split_pos = version_stripped.find('/');
            split_pos != std::string_view::npos)
        {
            version_stripped = version_stripped.substr(0, split_pos);
        }

        stdplus::print(stdout,
                       "activated_versions {{\n"
                       "    key: \"{}\"\n"
                       "    value: \"{}\"\n"
                       "}}\n",
                       (hoth_id.empty() ? "active" : hoth_id),
                       version_stripped);

        if (args.expectedRwVersion != version_stripped)
        {
            HothActivationStatistics actv_status =
                getHothActivationStatistics(bus, hoth_id);
            stdplus::print(stdout,
                           "notes: \"Running RW version {}"
                           " does not match expected version {}, Status: rw={} "
                           "(ver={}) ro={} rst_flags={}, uptime={}us\"\n",
                           version_stripped, args.expectedRwVersion,
                           optionalToString(actv_status.rw_failure_code),
                           optionalToString(actv_status.rw_failed_minor),
                           optionalToString(actv_status.ro_failure_code),
                           optionalToString(actv_status.reset_flags),
                           optionalToString(actv_status.uptime_us));
            throw std::runtime_error("Activation check failed");
        }
    }
}

// Registers the global --hoth_id option and the `update`, `firmware-version`
// and `activation-check` subcommands on `app`. Parsed values land in `args`.
// Each subcommand's callback forwards to the matching `cli` method. Both
// `args` and `cli` are captured by reference, so they must outlive the parse
// that runs these callbacks.
void setupCLIApp(CLI::App& app, HothUpdaterCLI& cli, Args& args)
{
    app.require_subcommand(1);
    app.add_option("--hoth_id", args.hothId,
                   "Hoth IDs, comma-delimited for activation-check");
    auto* update = app.add_subcommand("update", "Update Hoth image");
    update->add_option("--image", args.imageFilename, "Firmware image path")
        ->required()
        ->check(CLI::ExistingFile);
    update
        ->add_option("--method", args.updateMethod,
                     "Update method, can be spi|update_firmware|payload_update")
        ->required()
        // Reject typos at parse time instead of letting them reach doUpdate.
        ->check(CLI::IsMember({"spi", "update_firmware", "payload_update"}));
    update->add_option("--address", args.address, "SPI address for data")
        ->check(CLI::NonNegativeNumber);
    update->callback([&args, &cli] { cli.doUpdate(args); });

    auto* version =
        app.add_subcommand("firmware-version", "Get firmware version");

    version->add_flag("--ro,!--rw", args.ro, "Select ro/rw partition")
        ->required();
    version->callback([&args, &cli] { cli.doFirmwareVersion(args); });

    auto* activation_check = app.add_subcommand(
        "activation-check", "Generate package activation check textproto");
    // The expected version is a value, so register it as an option. A CLI11
    // flag only takes a value through `--rw=<v>`; with `--rw <v>` the version
    // would be left over as an unexpected positional argument.
    activation_check->add_option("--rw", args.expectedRwVersion,
                                 "Expected RW version");
    // RO version check not implemented yet.
    activation_check->callback([&args, &cli] { cli.doActivationCheck(args); });
}

} // namespace google::hoth::tools
