// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "asset.hpp"
#include "ec_util.hpp"
#include "firmware_mtd_updater.hpp"
#include "firmware_spi_updater.hpp"
#include "host_command.hpp"
#include "hoth.hpp"
#include "hoth_state.hpp"
#include "log_collector_util.hpp"
#include "message_file.hpp"
#include "message_intf.hpp"
#include "message_reinit.hpp"
#include "mtd_util.hpp"
#include "payload_update.hpp"
#include "sys.hpp"
#include "sys_interface.hpp"
#include "version.hpp"

#include <bits/getopt_core.h>
#include <stdlib.h>
#include <string.h>
#include <systemd/sd-daemon.h>
#include <unistd.h>

#include <boost/asio.hpp>
#include <boost/asio/io_context.hpp>
#include <sdbusplus/asio/connection.hpp>
#include <sdbusplus/asio/object_server.hpp>
#include <sdbusplus/bus.hpp>
#include <stdplus/print.hpp>

#include <exception>
#include <format>
#include <memory>
#include <string>
#include <string_view>

#ifdef HAVE_USB
#include "libhoth_usb.hpp"
#include "libusb_impl.hpp"
#include "message_hoth_usb.hpp"
#endif

#include <xyz/openbmc_project/Control/Hoth/error.hpp>

// D-Bus object path and well-known bus name of the Hoth control interface.
// When -n <NAME> is given, main() appends "/<NAME>" to the path and ".<NAME>"
// to the bus name.
const char* HOTH_CONTROL_PATH = "/xyz/openbmc_project/Control/Hoth";
const char* HOTH_CONTROL_BUS = "xyz.openbmc_project.Control.Hoth";
// Parent paths for the objects this daemon publishes: the software version
// objects (hoth[_<NAME>]_ro / hoth[_<NAME>]_rw) and the inventory asset
// (hoth[_<NAME>]).
const char* VERSION_BASE_PATH = "/xyz/openbmc_project/software";
const char* ASSET_BASE_PATH = "/xyz/openbmc_project/inventory";

// EntityManager service queried (and watched) for board configuration, and
// the interface-name prefix that marks an object as a Hoth configuration.
constexpr static const char* kEmService = "xyz.openbmc_project.EntityManager";
constexpr static const char* kConfigPathPrefix =
    "xyz.openbmc_project.Configuration.Hoth";

// NOLINTNEXTLINE(google-build-using-namespace)
using namespace std::string_literals;
using sdbusplus::error::xyz::openbmc_project::control::hoth::InterfaceError;

// Only interface names from EntityManager are inspected, never property
// values, so the variant carries no real alternatives.
// NOTE(review): this presumably relies on sdbusplus skipping variant payloads
// that match none of the listed alternatives; confirm for the sdbusplus
// version in use.
using Value = std::variant<std::monostate>;
// Interface name -> (property name -> value) for a single object.
using ObjectType =
    std::unordered_map<std::string, std::unordered_map<std::string, Value>>;
// Decoded reply of org.freedesktop.DBus.ObjectManager.GetManagedObjects:
// one (object path, interfaces) pair per object.
using ManagedObjectType =
    std::vector<std::pair<sdbusplus::message::object_path, ObjectType>>;

/**
 * Print the command-line help for this daemon to stderr.
 *
 * The option list mirrors the getopt() string parsed in main()
 * ("bu:m:n:a:r::i").
 *
 * @param name  Program name shown on the "Usage:" line (normally argv[0]).
 */
static void usage(const char* name)
{
    stdplus::print(stderr, "\
Usage: {} [OPTIONS]\n\
Bridge commands between a D-Bus interface and Hoth's mailbox\n\n\
  -b                        Blocks the use of legacy payload verify commands.\n\n\
  -u <UART_CHANNEL_ID>      Collects and streams the device UART logs.\n\n\
  -m <MAILBOX>              Use the file at <MAILBOX> as the mailbox.\n\
                            A \"usb:\" prefix selects the USB transport\n\
                            (only when built with USB support).\n\
                            If omitted, it will be located automatically.\n\n\
  -n <NAME>                 Use the name <NAME> for the hoth on DBus.\n\
                            If omitted, it will use the legacy empty name.\n\n\
  -a <ADDRESS_MODE>         SPI address mode (address can be 3 or 4 bytes).\n\
                            If omitted, it will use 4 bytes by default.\n\n\
  -i                        Ignore the SPI address mode when flashing through\n\
                            the SPI updater.\n\n\
  -r<TYPE>                  Resets target before SPI flash. Possible types:\n\
     needed:        Default if no <TYPE> provided. Put the target into reset\n\
                    if SPS passthrough enabled or status unknown.\n\
     never:         Same as no -r. Never put the target in reset.\n\
     ignore_fail:   Put target in reset if SPS passthrough enabled, but ignore\n\
                    SPS status command failures.\n\
     needed_active: Same as \"needed\" but it tries disabling SPS passthrough\n\
                    when it is enabled.\n",
                   name);
}

/**
 * Derive the board (parent) object path for this daemon's Hoth asset.
 *
 * @param assetObj          Inventory path of the Hoth asset; only its final
 *                          path component (the Hoth name) is compared.
 * @param objPath           Object path reported by EntityManager.
 * @param ifcAndProperties  Interfaces exposed at objPath.
 * @return objPath with its final component removed, when objPath ends in the
 *         same component as assetObj AND exposes at least one interface whose
 *         name begins with kConfigPathPrefix; otherwise an empty string.
 */
static std::string findBoardObjPath(const std::string& assetObj,
                                    const std::string& objPath,
                                    const ObjectType& ifcAndProperties)
{
    const auto objSlash = objPath.find_last_of('/');
    const auto assetSlash = assetObj.find_last_of('/');

    // The Hoth names (last path components) must agree. This does not depend
    // on any particular interface, so it is checked once up front.
    const bool sameHoth =
        objPath.substr(objSlash + 1) == assetObj.substr(assetSlash + 1);
    if (!sameHoth)
    {
        return std::string();
    }

    // Only objects carrying a Configuration.Hoth* interface qualify.
    for (const auto& entry : ifcAndProperties)
    {
        if (entry.first.starts_with(kConfigPathPrefix))
        {
            return objPath.substr(0, objSlash);
        }
    }

    return std::string();
}

/**
 * Pack a UART channel identifier string into a 32-bit channel id.
 *
 * Each character contributes one byte, most significant first, so "ROOT"
 * becomes 0x524F4F54. Characters are treated as unsigned bytes. An empty
 * string yields 0. In strings longer than four characters only the last four
 * survive, because earlier bytes are shifted out of the 32-bit result.
 *
 * @param uartChannelStr  Channel string passed with -u (may be empty).
 * @return The packed channel id.
 */
uint32_t getUartChannelId(const std::string& uartChannelStr)
{
    uint32_t uartChannelId = 0;
    for (const char c : uartChannelStr)
    {
        // Shift one byte (two hex digits) to make room for the next char.
        // Convert through unsigned char so that bytes >= 0x80 are not
        // sign-extended into the upper bits when plain char is signed.
        uartChannelId = (uartChannelId << 8) |
                        static_cast<uint32_t>(static_cast<unsigned char>(c));
    }
    return uartChannelId;
}

using google::hoth::internal::ResetMode;

/**
 * Translate the optional argument of -r into a ResetMode.
 *
 * @param str  Text attached to -r, or nullptr when -r was given without one.
 * @return The matching mode. A missing argument selects ResetMode::needed.
 *
 * Unrecognized text is reported on stderr and the process exits with
 * EXIT_FAILURE.
 */
ResetMode parseReset(char* str)
{
    if (str == nullptr)
    {
        return ResetMode::needed;
    }

    const std::string_view type{str};
    if (type == "needed")
    {
        return ResetMode::needed;
    }
    if (type == "never")
    {
        return ResetMode::never;
    }
    if (type == "ignore_fail")
    {
        return ResetMode::ignore;
    }
    if (type == "needed_active")
    {
        return ResetMode::needed_active;
    }

    stdplus::print(stderr, "Unknown reset type {}\n", str);
    exit(EXIT_FAILURE);
}

/**
 * Parse the argument of the -a option into an SPI address size in bytes.
 *
 * std::stoi is still used, so leading whitespace and trailing characters are
 * tolerated exactly as before. A non-numeric or out-of-range argument, or any
 * value other than the documented 3 or 4, now prints an error plus usage and
 * exits. Previously std::stoi's exception escaped main() and the value was
 * silently narrowed to uint8_t.
 *
 * @param str       The option argument (optarg); non-null because "a:"
 *                  requires an argument.
 * @param progName  argv[0], forwarded to usage().
 * @return 3 or 4.
 */
static uint8_t parseAddressSize(const char* str, const char* progName)
{
    int value = 0;
    try
    {
        value = std::stoi(str);
    }
    catch (const std::exception&)
    {
        // Leave value invalid so the check below reports it.
        value = 0;
    }

    if (value != 3 && value != 4)
    {
        stdplus::print(stderr,
                       "Invalid SPI address mode '{}': expected 3 or 4\n", str);
        usage(progName);
        exit(EXIT_FAILURE);
    }
    return static_cast<uint8_t>(value);
}

int main(int argc, char* argv[])
{
    bool allow_legacy_verify = true;
    const char *mailbox = nullptr, *name = "", *uartChannel = "";
    uint8_t addressSize = 4;
    int opt;
    ResetMode targetReset = ResetMode::never;
    bool ignoreAddressMode = false;

    // In "r::" the double colon makes the -r argument optional. It must be
    // attached (e.g. -rnever), and optarg is null when it is omitted.
    while ((opt = getopt(argc, argv, "bu:m:n:a:r::i")) != -1)
    {
        switch (opt)
        {
            case 'b':
                allow_legacy_verify = false;
                break;
            case 'u':
                uartChannel = optarg;
                break;
            case 'm':
                mailbox = optarg;
                break;
            case 'n':
                name = optarg;
                break;
            case 'a':
                addressSize = parseAddressSize(optarg, argv[0]);
                break;
            case 'r':
                targetReset = parseReset(optarg);
                break;
            case 'i':
                ignoreAddressMode = true;
                break;
            default:
                usage(argv[0]);
                exit(EXIT_FAILURE);
        }
    }

    // These objects are created inside the try blocks below, but they must
    // stay alive while io.run() serves D-Bus requests, so main() owns them.
    std::unique_ptr<google::hoth::MessageIntf> msg;
    std::unique_ptr<google::hoth::internal::HostCommandImpl> hostCmd;
    std::unique_ptr<google::hoth::internal::PayloadUpdateImpl> payloadUpdate;
    std::unique_ptr<google::hoth::internal::EcUtilImpl> ecUtil;
    std::unique_ptr<google::hoth::Hoth> hoth;
    std::unique_ptr<google::hoth::HothState> hoth_state;
    std::unique_ptr<google::hoth::RoVersion> swVerRo;
    std::unique_ptr<google::hoth::RwVersion> swVerRw;
    std::unique_ptr<google::hoth::Asset> asset;
    std::unique_ptr<google::hoth::internal::FirmwareUpdater> firmwareUpdater;

    google::hoth::internal::RateLimiter rateLimiter(kRateLimiterMilliSeconds);
    std::unique_ptr<google::hoth::internal::LogCollectorUtil> logCollectorUtil =
        std::make_unique<google::hoth::internal::LogCollectorUtil>(
            rateLimiter, kAsyncWaitTimeInSeconds);

    // Backends for the MTD firmware updater. mtd is cleared below for USB
    // mailboxes, which switches firmware updates to the SPI updater.
    google::hoth::internal::Mtd* mtd = &google::hoth::internal::mtdImpl;
    google::hoth::internal::Sys* sys = &google::hoth::internal::sys_impl;
    boost::asio::io_context io;

    std::shared_ptr<sdbusplus::asio::connection> systemBus =
        std::make_shared<sdbusplus::asio::connection>(io);
    sdbusplus::asio::object_server objectServer(systemBus, true);
    objectServer.add_manager("/");

    // Retry count handed to HostCommandImpl. Lowered to 1 when the mailbox is
    // the automatically located MTD partition.
    size_t max_retries = 2;

    try
    {
        // Mailbox selection:
        //   no -m         -> the "hoth-mailbox" MTD partition
        //   -m usb:<path> -> USB transport (only when built with HAVE_USB)
        //   -m <file>     -> the given file
        constexpr std::string_view usbPrefix = "usb:";
        if (!mailbox)
        {
            msg = std::make_unique<google::hoth::MessageFile>(
                google::hoth::internal::mtdImpl.findPartition("hoth-mailbox"));
            max_retries = 1;
        }
        else if (strncmp(mailbox, usbPrefix.data(), usbPrefix.size()) == 0)
        {
#ifdef HAVE_USB
            msg = std::make_unique<google::hoth::MessageReinit<
                google::hoth::MessageHothUSB, std::string_view,
                google::hoth::LibusbIntf*, google::hoth::LibHothUsbIntf*>>(
                mailbox + usbPrefix.size(), &google::hoth::libusb_impl,
                &google::hoth::libhoth_usb);
            mtd = nullptr;
#else
            throw std::logic_error("Missing USB support");
#endif
        }
        else
        {
            msg = std::make_unique<google::hoth::MessageFile>(mailbox);
        }

        hostCmd = std::make_unique<google::hoth::internal::HostCommandImpl>(
            msg.get(), &io, logCollectorUtil.get(), name, max_retries,
            allow_legacy_verify, getUartChannelId(uartChannel));
        payloadUpdate =
            std::make_unique<google::hoth::internal::PayloadUpdateImpl>(
                hostCmd.get());
        ecUtil =
            std::make_unique<google::hoth::internal::EcUtilImpl>(hostCmd.get());
        if (mtd)
        {
            firmwareUpdater =
                std::make_unique<google::hoth::internal::FirmwareMtdUpdater>(
                    mtd, sys);
        }
        else
        {
            // Only this path consumes the -a, -r and -i options.
            firmwareUpdater =
                std::make_unique<google::hoth::internal::FirmwareSpiUpdater>(
                    hostCmd.get(), addressSize, targetReset, ignoreAddressMode);
        }

        // To facilitate debugging Hoth errors, emit a snapshot of its console.
        hostCmd->collectHothLogsAsync(true);
    }
    catch (const std::exception& e)
    {
        stdplus::print(stderr, "Error setting up Hoth interface: {}\n",
                       e.what());
        exit(EXIT_FAILURE);
    }

    // Publish the control object and claim the bus name, both suffixed with
    // the -n name when one was given.
    try
    {
        std::string obj = HOTH_CONTROL_PATH, svc = HOTH_CONTROL_BUS;
        if (name[0] != '\0')
        {
            obj += "/";
            obj += name;
            svc += ".";
            svc += name;
        }
        hoth = std::make_unique<google::hoth::Hoth>(
            *systemBus, obj.c_str(), hostCmd.get(), payloadUpdate.get(),
            ecUtil.get(), firmwareUpdater.get());
        hoth_state = std::make_unique<google::hoth::HothState>(
            *systemBus, io, obj.c_str(), hostCmd.get(), ecUtil.get());
        systemBus->request_name(svc.c_str());
    }
    catch (const std::exception& e)
    {
        stdplus::print(stderr,
                       "Error registering Hoth control object with D-Bus: {}\n",
                       e.what());
        exit(EXIT_FAILURE);
    }

    // Publish the RO and RW firmware version objects.
    try
    {
        std::string ro_obj, rw_obj;

        std::string ver_obj =
            std::format("{}/hoth{}", VERSION_BASE_PATH,
                        name[0] != '\0' ? std::string("_") + name : "");

        ro_obj = ver_obj + "_ro";
        rw_obj = ver_obj + "_rw";

        swVerRo = std::make_unique<google::hoth::RoVersion>(
            *systemBus, ro_obj.c_str(), hostCmd.get());
        swVerRw = std::make_unique<google::hoth::RwVersion>(
            *systemBus, rw_obj.c_str(), hostCmd.get());
    }
    catch (const std::exception& e)
    {
        stdplus::print(
            stderr,
            "Error registering Software Version objects with D-Bus: {}\n",
            e.what());
        exit(EXIT_FAILURE);
    }
    sd_notify(0, "READY=1");

    std::string assetObj =
        std::format("{}/hoth{}", ASSET_BASE_PATH,
                    name[0] != '\0' ? std::string("_") + name : "");
    std::string board_obj;

    // Ask EntityManager for its current objects. A D-Bus failure here is only
    // logged; the InterfacesAdded match below can still create the asset.
    ManagedObjectType managedObjs;
    try
    {
        auto method = systemBus->new_method_call(
            kEmService, "/xyz/openbmc_project/inventory",
            "org.freedesktop.DBus.ObjectManager", "GetManagedObjects");
        systemBus->call(method).read(managedObjs);
    }
    catch (const sdbusplus::exception::SdBusError& e)
    {
        stdplus::print(
            stderr, "Failed to get managed objects from entity manager: {}\n",
            e.what());
    }

    /* Find the board object path */
    for (const auto& [path, ifcToProperties] : managedObjs)
    {
        board_obj = findBoardObjPath(assetObj, path.str, ifcToProperties);
        if (!board_obj.empty())
        {
            try
            {
                asset = std::make_unique<google::hoth::Asset>(
                    *systemBus, assetObj, board_obj, hostCmd.get());
            }
            catch (const std::exception& e)
            {
                stdplus::print(
                    stderr, "Error registering Asset Object with D-Bus: {}\n",
                    e.what());
                exit(EXIT_FAILURE);
            }

            break;
        }
    }

    // (Re)create the asset whenever EntityManager later adds a matching Hoth
    // configuration. Capturing locals by reference is safe because the
    // callback only runs from io.run() below, while main() is still active.
    auto ifcAddedMatch = std::make_unique<sdbusplus::bus::match_t>(
        *systemBus,
        sdbusplus::bus::match::rules::interfacesAdded() +
            sdbusplus::bus::match::rules::sender(kEmService),
        [&systemBus, &hostCmd, &asset, &assetObj](sdbusplus::message_t& msg) {
            sdbusplus::message::object_path objPath;
            ObjectType ifcAndProperties;
            try
            {
                msg.read(objPath, ifcAndProperties);
            }
            catch (const std::exception& e)
            {
                stdplus::print(stderr,
                               "Error reading objects from EM service: {}\n",
                               e.what());
                return;
            }

            std::string boardObj =
                findBoardObjPath(assetObj, objPath.str, ifcAndProperties);
            if (!boardObj.empty())
            {
                try
                {
                    // Destroy the previous asset before building its
                    // replacement. Presumably this frees the D-Bus path, which
                    // the new object reuses (same assetObj).
                    if (asset)
                    {
                        asset.reset();
                    }
                    asset = std::make_unique<google::hoth::Asset>(
                        *systemBus, assetObj, boardObj, hostCmd.get());
                }
                catch (const std::exception& e)
                {
                    stdplus::print(
                        stderr,
                        "Error registering Asset Object with D-Bus: {}\n",
                        e.what());
                    exit(EXIT_FAILURE);
                }
            }
        });

    // Only InterfaceError is handled here; it ends the daemon with failure.
    try
    {
        io.run();
    }
    catch (const InterfaceError& e)
    {
        stdplus::print(stderr,
                       "Lost connection with Hoth. Shutting down. e={}\n",
                       e.what());
        return EXIT_FAILURE;
    }
    return 0;
}
