// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <unistd.h>

#include <flasher/file.hpp>
#include <flasher/mutate.hpp>
#include <flasher/ops.hpp>
#include <flashupdate/args.hpp>
#include <flashupdate/flash.hpp>
#include <flashupdate/logging.hpp>

#include <algorithm>
#include <charconv>
#include <cstdint>
#include <cstring>
#include <exception>
#include <filesystem>
#include <format>
#include <iterator>
#include <memory>
#include <optional>
#include <stdexcept>
#include <string>
#include <string_view>
#include <system_error>
#include <utility>
#include <vector>

namespace flashupdate
{
namespace flash
{

using stdplus::fd::OpenAccess;
using stdplus::fd::OpenFlag;
using stdplus::fd::OpenFlags;

// Read `filename` and return its first line, without the trailing '\n'.
//
// The buffer is sized from std::filesystem::file_size() and filled with a
// single readAt() at offset 0. Everything from the first newline onward is
// discarded.
//
// Throws std::runtime_error when no newline is present in the buffer. Errors
// from opening/reading the file or querying its size propagate unchanged.
std::string FlashHelper::readMtdFileText(const std::string& filename)
{
    LOG(LogLevel::Debug, "Reading Mtd File {}\n", filename);

    auto openArgs = flasher::ModArgs(filename);
    auto source = openFile(openArgs, OpenFlags(OpenAccess::ReadOnly));

    std::vector<std::byte> buffer(std::filesystem::file_size(filename));
    source->readAt(buffer, 0);

    const auto lineEnd =
        std::find(buffer.begin(), buffer.end(), static_cast<std::byte>('\n'));
    if (lineEnd == buffer.end())
    {
        throw std::runtime_error("not able to find newline in the mtd file");
    }

    // Reinterpreting std::byte storage as char is permitted (char may alias
    // any object), so the line can be copied straight into the result.
    const auto lineLength = static_cast<std::size_t>(lineEnd - buffer.begin());
    return std::string(reinterpret_cast<const char*>(buffer.data()),
                       lineLength);
}

// Find the mtd device whose label equals `name`.
//
// Each entry under /sys/class/mtd/ is checked by reading the first line of its
// `name` attribute via readMtdFileText(). Entries whose `name` file cannot be
// read are logged at Debug level and skipped.
//
// Returns the entry's directory name (e.g. "mtd0") on a match, or an empty
// string when no entry matches. Errors from iterating /sys/class/mtd/ itself
// are not caught and propagate to the caller.
std::string FlashHelper::findMtdDevice(const std::string& name)
{
    for (const auto& entry :
         std::filesystem::directory_iterator("/sys/class/mtd/"))
    {
        try
        {
            auto mtdName = this->readMtdFileText(
                std::format("{}/name", entry.path().c_str()));

            if (name == mtdName)
            {
                // Return an owning copy of the last path component instead of
                // `string_view::data()`, which is only correct if the view
                // happens to end at a NUL terminator.
                return entry.path().filename().string();
            }
        }
        catch (const std::exception& e)
        {
            LOG(LogLevel::Debug, "failed to check mtd name: err {}\n",
                e.what());
        }
    }
    return std::string();
}

// Default constructor: create an owned FlashHelper.
// `helper` is the non-owning pointer used for helper calls (e.g.
// findMtdDevice), so setFlashHelper() can later redirect it without touching
// ownership.
Flash::Flash()
{
    helper = (helperPtr = std::make_unique<FlashHelper>()).get();
}

// Construct with a configuration and MUX policy.
//
// config:  flash configuration; taken by value and moved into the member, so
//          callers passing an rvalue pay for no copy at all.
// keepMux: when true, cleanup() leaves the MUX/driver state untouched.
Flash::Flash(Config config, bool keepMux) : Flash()
{
    this->config = std::move(config);
    this->keepMux = keepMux;
}

// Destructor: restore the MUX/driver state via cleanup().
//
// cleanup() performs sysfs file I/O that can throw. Destructors are implicitly
// noexcept, so an escaping exception would call std::terminate. Failures are
// logged instead.
Flash::~Flash()
{
    try
    {
        cleanup();
    }
    catch (const std::exception& e)
    {
        LOG(LogLevel::Error, "failed to cleanup the flash: {}", e.what());
    }
    catch (...)
    {
        LOG(LogLevel::Error, "failed to cleanup the flash: unknown error");
    }
}

// Write config.flash.deviceId into the "bind" (bindValue == true) or "unbind"
// (bindValue == false) file under the config.flash.driver directory.
// Presumably this is the Linux driver-core bind/unbind sysfs interface —
// confirm against the configured driver path.
// Errors from opening or writing the file propagate to the caller.
void bindDriver(const Config& config, bool bindValue)
{
    const char* attribute = bindValue ? "bind" : "unbind";
    auto attributeArgs =
        flasher::ModArgs(std::format("{}/{}", config.flash.driver, attribute));
    auto attributeFile =
        openFile(attributeArgs, OpenFlags(OpenAccess::WriteOnly));

    const auto& deviceId = config.flash.deviceId;
    std::vector<std::byte> payload(deviceId.size());
    std::memcpy(payload.data(), deviceId.data(), deviceId.size());
    attributeFile->writeAtExact(payload, 0);
}

// Resolve either the primary flash or the staging partition.
//
// primary:      true selects config.flash.primary. false selects the secondary
//               partition at config.flash.stagingIndex, ignoring any other
//               targeted index.
// expectedSize: forwarded unchanged to the index-based overload.
//
// Not marked `inline`: this is an out-of-class definition in a .cpp file.
// `inline` would make it unusable (ODR/link failure) from any other
// translation unit.
std::optional<std::pair<std::string, uint32_t>>
    Flash::getFlash(bool primary, std::optional<size_t> expectedSize)
{
    if (primary)
    {
        return getFlash(std::nullopt, expectedSize);
    }

    // Use the staging index from the metadata instead of the targeted one.
    return getFlash(
        static_cast<std::optional<uint8_t>>(config.flash.stagingIndex),
        expectedSize);
}

// Resolve a configured flash partition into a flasher location and its size.
//
// secondary:    index into config.flash.secondary; nullopt selects
//               config.flash.primary.
// expectedSize: for mtd locations, the size the device must report. For
//               non-mtd (file-backed) locations, the size used to create the
//               image file when it does not exist yet.
//
// Returns {location, size} on success. Returns nullopt when:
//   - the secondary index is out of range,
//   - the location has no ',' separator,
//   - a file-backed image cannot be inspected, or
//   - no mtd device carries the requested label.
// Throws std::runtime_error when the mtd size cannot be parsed or does not
// match expectedSize. Errors from the other file operations propagate.
//
// Outside DEV_WORKFLOW this also selects the partition's MUX GPIO (if any) and
// binds the SPI driver. It records the GPIO and sets needMuxReset so cleanup()
// can undo that state.
std::optional<std::pair<std::string, uint32_t>>
    Flash::getFlash(std::optional<uint8_t> secondary,
                    std::optional<size_t> expectedSize)
{
    // Guard the lookup below: an out-of-range index would read past the end of
    // the configured secondary partitions.
    if (secondary &&
        static_cast<size_t>(*secondary) >= std::size(config.flash.secondary))
    {
        LOG(LogLevel::Error, "secondary partition index {} is out of range",
            *secondary);
        return std::nullopt;
    }

    // Bind by reference: the partition is only read, no copy is needed.
    const auto& partition = secondary.has_value()
                                ? config.flash.secondary[*secondary]
                                : config.flash.primary;
    std::string_view location = partition.location;

    auto index = location.find_last_of(',');
    if (index == std::string_view::npos)
    {
        return std::nullopt;
    }
    auto name = location.substr(index + 1);

    if (!location.starts_with("mtd"))
    {
        // non-mtd device path is expected to be in the format of
        //   fake,type=simple,erase=0,fake.img
        // Last element is the image
        try
        {
            // Create a file if the expectedSize is valid
            if (!std::filesystem::exists(name) && expectedSize)
            {
                LOG(LogLevel::Info, "Creating the stage file.");
                auto argFile = flasher::ModArgs(name);
                auto file = openFile(argFile, OpenFlags(OpenAccess::WriteOnly)
                                                  .set(OpenFlag::Create)
                                                  .set(OpenFlag::Trunc));
                file->truncate(*expectedSize);
            }
            std::filesystem::path path(name);
            uint32_t size = std::filesystem::file_size(path);
            return std::make_pair(partition.location, size);
        }
        catch (std::filesystem::filesystem_error const& e)
        {
            LOG(LogLevel::Error, "failed find the partition: {}", e.what());
            return std::nullopt;
        }
    }

#ifndef DEV_WORKFLOW

    // Check if the driver for the SPI flash is already in use.
    bool spiDriverExists =
        access(std::format("{}/{}", config.flash.driver, config.flash.deviceId)
                   .c_str(),
               F_OK) == 0;

    // Only setup Flash Driver if it is needed.
    // Cleaning up before the driver is ready can cause the kernel to crash.
    // If the flash is not used, the cleanup might happen too soon and cause the
    // issue.
    if (partition.muxSelect)
    {
        std::string gpio =
            std::format("/sys/class/gpio/gpio{}/", *partition.muxSelect);

        LOG(LogLevel::Info, "Select the MUX with {}", gpio);

        // Expose the GPIO if it does not exist yet.
        if (access(gpio.c_str(), F_OK) == -1)
        {
            auto argFile = flasher::ModArgs("/sys/class/gpio/export");
            auto file = openFile(argFile, OpenFlags(OpenAccess::WriteOnly)
                                              .set(OpenFlag::Create)
                                              .set(OpenFlag::Trunc));
            // Export the GPIO of the partition being selected, i.e. the same
            // GPIO whose directory was probed above. The primary partition
            // may use a different MUX GPIO, or none at all.
            std::string data = std::to_string(*partition.muxSelect);
            std::vector<std::byte> file_out(data.size());
            std::memcpy(file_out.data(), data.data(), data.size());

            file->writeAtExact(file_out, 0);
        }

        if (spiDriverExists)
        {
            // Get MUX GPIO Value and reset the spi driver if the MUX is not set
            // before the Driver
            auto argFile = flasher::ModArgs(std::format(
                "/sys/class/gpio/gpio{}/value", *partition.muxSelect));
            auto gpioValue = openFile(argFile, OpenFlags(OpenAccess::ReadOnly));
            std::vector<std::byte> value(1);
            gpioValue->readAt(value, 0);

            // The MUX is not set before the driver so unbind the driver.
            if (value[0] == std::byte('0'))
            {
                LOG(LogLevel::Info, "Reset SPI Driver first. GPIO was not set");
                bindDriver(config, /*bindValue=*/false);
                spiDriverExists = false;
            }
        }

        LOG(LogLevel::Info,
            "Select the MUX with gpio{} to enable the firmware flash",
            *partition.muxSelect);

        auto argFile = flasher::ModArgs(std::format(
            "/sys/class/gpio/gpio{}/direction", *partition.muxSelect));
        auto file = openFile(argFile, OpenFlags(OpenAccess::WriteOnly)
                                          .set(OpenFlag::Create)
                                          .set(OpenFlag::Trunc));
        std::string data = "high";
        std::vector<std::byte> file_out(data.size());
        std::memcpy(file_out.data(), data.data(), data.size());
        file->writeAtExact(file_out, 0);

        // Remember the GPIO so cleanup() can drive it low again.
        resetGPIOs.emplace(*partition.muxSelect);
    }

    // Bind driver if it doesn't already exist.
    if (!spiDriverExists)
    {
        bindDriver(config, /*bindValue=*/true);
        LOG(LogLevel::Info, "bound {} to {}", config.flash.deviceId,
            config.flash.driver);
    }

    // After successfully setting the mux, setting the flag to prepare for
    // cleanup.
    needMuxReset = true;
#endif

    // Mtd dev path is expected to be in the format of
    //   mtd,bios-primary
    // Last element is the label of the mtd device
    std::string mtd = helper->findMtdDevice(std::string(name));
    if (mtd.empty())
    {
        LOG(LogLevel::Info, "failed to find the mtd device with label of {}",
            name);
        return std::nullopt;
    }

    // The sysfs size attribute is parsed strictly: the whole first line must
    // be a decimal uint32_t.
    std::string sizeStr =
        helper->readMtdFileText(std::format("/sys/class/mtd/{}/size", mtd));
    uint32_t size;
    auto [ptr, ec]{
        std::from_chars(sizeStr.data(), sizeStr.data() + sizeStr.size(), size)};
    if (ec != std::errc())
    {
        throw std::runtime_error(
            std::format("failed to convert string to uint32_t: {}",
                        std::make_error_code(ec).message()));
    }
    if (ptr != sizeStr.data() + sizeStr.size())
    {
        throw std::runtime_error("converted invalid characters");
    }

    LOG(LogLevel::Info, "using {} as the firmware flash with size of {}", name,
        size);

    if (expectedSize && size != *expectedSize)
    {
        throw std::runtime_error(
            std::format("Device size does not match expected, Want {}, got {}",
                        *expectedSize, size));
    }
    return std::make_pair(std::format("mtd,/dev/{}", mtd), size);
}

// Undo the MUX/driver setup performed by getFlash().
//
// Does nothing when no setup was recorded (needMuxReset is false) or when the
// caller asked to keep the MUX (keepMux). Otherwise, outside DEV_WORKFLOW:
//   - unbinds config.flash.deviceId from the driver if it is still bound, and
//   - writes "low" to the direction file of every GPIO recorded in
//     resetGPIOs.
// After a successful run the recorded state is cleared. A later call (e.g. the
// destructor after an explicit cleanup()) is then a no-op and does not
// re-drive GPIOs that someone else may have selected in the meantime.
//
// Errors from the file operations propagate. The state is left intact in that
// case, so a later call retries.
void Flash::cleanup()
{
    if (!needMuxReset || keepMux)
    {
        return;
    }

#ifndef DEV_WORKFLOW
    LOG(LogLevel::Info, "Cleanup the MUX");
    LOG(LogLevel::Info, "unbind {} to {}", config.flash.deviceId,
        config.flash.driver);

    // Check if the driver for the SPI flash is already removed.
    if (access(std::format("{}/{}", config.flash.driver, config.flash.deviceId)
                   .c_str(),
               F_OK) != -1)
    {
        bindDriver(config, /*bindValue=*/false);
    }

    // Switch mux to host
    // It is possible that there are multiple MUX selected.
    for (const auto& gpio : resetGPIOs)
    {
        LOG(LogLevel::Info, "set gpio{} to low", gpio);
        auto argFile = flasher::ModArgs(
            std::format("/sys/class/gpio/gpio{}/direction", gpio));
        auto file = openFile(argFile, OpenFlags(OpenAccess::WriteOnly)
                                          .set(OpenFlag::Create)
                                          .set(OpenFlag::Trunc));
        std::string data = "low";
        std::vector<std::byte> file_out = std::vector<std::byte>(data.size());
        memcpy(file_out.data(), data.data(), data.size());
        file->writeAtExact(file_out, 0);
    }

    // Every recorded MUX has been released; forget them.
    resetGPIOs.clear();
#endif

    // Cleanup completed: make repeated calls no-ops.
    needMuxReset = false;
}

// Redirect helper calls to `helper` (non-owning; the caller keeps ownership).
// A null pointer is ignored, so the currently active helper stays in place.
void Flash::setFlashHelper(FlashHelper* helper)
{
    if (helper != nullptr)
    {
        this->helper = helper;
    }
}

} // namespace flash
} // namespace flashupdate
