// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <flasher/logging.hpp>
#include <flasher/ops.hpp>
#include <stdplus/exception.hpp>

#include <algorithm>
#include <cstddef>
#include <cstring>
#include <format>
#include <optional>
#include <span>
#include <stdexcept>
#include <vector>

namespace flasher
{
namespace ops
{

/// Reads `buf.size()` bytes of `dev` starting at byte offset `off` into `buf`
/// (via Device::readAtExact), then emits an Info-level " RD@<off>#<len>"
/// trace entry. If the device read throws, the exception propagates and
/// nothing is logged.
static void readAtExact(Device& dev, std::span<std::byte> buf, size_t off)
{
    dev.readAtExact(buf, off);
    const size_t len = buf.size();
    LOG(LogLevel::Info, " RD@{}#{}", off, len);
}

/// Programs `dev` with data from `reader`, erasing and writing only where the
/// current device contents call for it.
///
/// Data is handled in stride-sized chunks. For every chunk, the buffered
/// device contents are compared with the (mutated) reader contents:
/// - erase blocks are erased only when `dev.needsErase()` reports it;
/// - `dev.shrinkWritePre()` / `dev.shrinkWritePost()` trim what is written.
///
/// If `dev_offset` is not erase-aligned, the device bytes between the start
/// of its erase block and `dev_offset` are read back and re-used as source
/// data so they survive an erase. Likewise, once all input is buffered, it is
/// padded to the next erase boundary with the current device contents.
///
/// Devices reporting an erase size of 0 are handed to `write()` with the same
/// arguments.
///
/// @param dev           Device to program.
/// @param dev_offset    Byte offset on `dev` where the reader data starts.
/// @param reader        Source of the new data.
/// @param reader_offset Byte offset in `reader` to start reading from.
/// @param mut           Transform applied via `forward()` to every chunk read
///                      from `reader`, together with that chunk's offset.
/// @param max_size      Maximum number of bytes to take from `reader`;
///                      reading stops earlier if the reader hits EOF.
/// @param stride_size   Chunk size; defaults to `dev.recommendedStride()` and
///                      is passed through `dev.eraseAlignEnd()` before use.
/// @param noread        If true, device contents are not read for comparison
///                      and every byte is treated as differing from the
///                      input. The partial erase blocks at either end are
///                      still read so their bytes are preserved.
/// @throws std::invalid_argument if `dev_offset` is beyond the end of `dev`
///         or the stride is 0.
/// @throws std::runtime_error if, while reading back device contents, the
///         end of the device is reached before `max_size` bytes are covered.
void automatic(Device& dev, size_t dev_offset, Reader& reader,
               size_t reader_offset, Mutate& mut, size_t max_size,
               std::optional<size_t> stride_size, bool noread)
{
    // Without erase blocks none of the erase bookkeeping below applies.
    if (dev.getEraseSize() == 0)
    {
        return write(dev, dev_offset, reader, reader_offset, mut, max_size,
                     stride_size, noread);
    }

    if (dev_offset > dev.getSize())
    {
        throw std::invalid_argument(std::format(
            "Device smaller than offset, {} < {}", dev.getSize(), dev_offset));
    }
    size_t stride = stride_size ? *stride_size : dev.recommendedStride();
    if (stride == 0)
    {
        throw std::invalid_argument("Stride cannot be 0");
    }
    stride = dev.eraseAlignEnd(stride);

    // Each refill loop below tops its buffer up only while it holds less than
    // two strides, adding at most one stride per read, so a buffer holds less
    // than three strides of data. The reader side may additionally be padded
    // out to the next erase boundary at the end of the input. In noread mode
    // dev_data is grown to mirror reader_data, so the device buffer needs the
    // same capacity as the reader buffer or it can be overrun.
    const size_t buf_size = stride * 3 + dev.getEraseSize();
    std::vector<std::byte> reader_buf_v(buf_size), dev_buf_v(buf_size);
    const std::span<std::byte> reader_buf(reader_buf_v), dev_buf(dev_buf_v);

    // reader_data and dev_data are prefixes of their buffers. Both begin at
    // device offset `dev_offset - dev_data.size()`, i.e. dev_offset is the
    // device offset just past the end of dev_data. erase_offset is the device
    // offset up to which the erase decision has already been made.
    std::span<std::byte> reader_data, dev_data;
    auto start = dev.eraseAlignStart(dev_offset), erase_offset = start;
    if (start != dev_offset)
    {
        // dev_offset lies inside an erase block: feed the existing bytes in
        // front of it back in as source data so rewriting the whole block
        // keeps them intact.
        dev_data = dev_buf.subspan(0, dev_offset - start);
        reader_data = reader_buf.subspan(0, dev_data.size());
        readAtExact(dev, dev_data, start);
        std::memcpy(reader_data.data(), dev_data.data(), dev_data.size());
        // Saturating add: max_size stays unchanged if the sum would wrap.
        max_size = std::max(max_size, max_size + dev_data.size());
    }
    bool eof = false;
    // Next slice of dev_buf to read device contents into: at most one stride,
    // never beyond max_size or the end of the device.
    const auto next_dev_buf = [&]() {
        auto next_size = std::min(
            {stride, max_size - dev_data.size(), dev.getSize() - dev_offset});
        if (next_size == 0)
        {
            throw std::runtime_error("Device not large enough for reader");
        }
        return dev_buf.subspan(dev_data.size(), next_size);
    };
    // True while dev_data holds less than two strides and more data remains.
    const auto too_little_dev = [&]() {
        return dev_data.size() < stride * 2 && max_size > dev_data.size();
    };
    while (max_size > 0)
    {
        // Refill reader_data while it holds less than two strides and more
        // input remains.
        while (reader_data.size() < stride * 2 &&
               max_size > reader_data.size() && !eof)
        {
            try
            {
                auto new_reader_data = reader.readAt(
                    reader_buf.subspan(
                        reader_data.size(),
                        std::min(stride, max_size - reader_data.size())),
                    reader_offset);
                // Report the reader offset that was read, just as the other
                // trace entries report the offset of their own operation.
                LOG(LogLevel::Info, " RF@{}#{}", reader_offset,
                    new_reader_data.size());
                mut.forward(new_reader_data, reader_offset);
                reader_data = reader_buf.subspan(0, reader_data.size() +
                                                        new_reader_data.size());
                reader_offset += new_reader_data.size();
            }
            catch (const stdplus::exception::Eof&)
            {
                // The reader ran out of data: only what is buffered remains.
                eof = true;
                max_size = reader_data.size();
            }
            if (reader_data.size() == max_size)
            {
                // All input is buffered. Pad it to the next erase boundary
                // with the current device contents so the tail of the last
                // erase block is preserved.
                max_size = dev.eraseAlignEnd(reader_data.size() -
                                             dev_data.size() + dev_offset) -
                           dev_offset + dev_data.size();
                if (max_size > reader_data.size())
                {
                    readAtExact(
                        dev,
                        reader_buf.subspan(reader_data.size(),
                                           max_size - reader_data.size()),
                        dev_offset + reader_data.size() - dev_data.size());
                    reader_data = reader_buf.subspan(0, max_size);
                }
            }
        }
        if (noread)
        {
            if (too_little_dev())
            {
                // Instead of reading the device, assume it holds the bitwise
                // complement of the new data so every byte differs.
                for (size_t i = dev_data.size(); i < reader_data.size(); ++i)
                {
                    dev_buf[i] = std::byte{0xff} ^ reader_data[i];
                }
                dev_offset += reader_data.size() - dev_data.size();
                dev_data = dev_buf.subspan(0, reader_data.size());
            }
        }
        else
        {
            while (too_little_dev())
            {
                auto new_dev_data = dev.readAt(next_dev_buf(), dev_offset);
                LOG(LogLevel::Info, " RD@{}#{}", dev_offset,
                    new_dev_data.size());
                dev_data =
                    dev_buf.subspan(0, dev_data.size() + new_dev_data.size());
                dev_offset += new_dev_data.size();
            }
        }
        // When less than a stride of dev_data has a settled erase decision,
        // decide for the next run of whole erase blocks (at most one stride)
        // past erase_offset.
        if (dev_data.size() - (dev_offset - erase_offset) < stride &&
            dev_offset > erase_offset)
        {
            auto erase_buf_offset =
                dev_data.size() - (dev_offset - erase_offset);
            auto to_erase = dev_data.subspan(erase_buf_offset);
            to_erase = to_erase.subspan(
                0,
                std::min(to_erase.size() - to_erase.size() % dev.getEraseSize(),
                         stride));
            if (dev.needsErase(to_erase, reader_data.subspan(erase_buf_offset,
                                                             to_erase.size())))
            {
                dev.eraseBlocks(erase_offset / dev.getEraseSize(),
                                to_erase.size() / dev.getEraseSize());
                LOG(LogLevel::Info, " ED@{}#{}", erase_offset, to_erase.size());
                // Presumably updates the buffered copy to the erased state so
                // later comparisons match the device; confirm in Device.
                dev.mockErase(to_erase);
            }
            erase_offset += to_erase.size();
        }
        // Only the front of dev_data whose erase decision is settled may be
        // written, at most one stride per iteration.
        auto dev_erased = dev_data.subspan(
            0, std::min(stride, dev_data.size() - (dev_offset - erase_offset)));
        // Leading bytes consumed without writing (per dev.shrinkWritePre).
        auto skip = dev.shrinkWritePre(dev_erased, reader_data);
        auto consumed = skip;
        if (skip == 0)
        {
            auto size = dev.shrinkWritePost(dev_erased, reader_data);
            consumed = std::min(dev_erased.size(), reader_data.size());
            if (size > 0)
            {
                auto written = dev.writeAt(reader_data.subspan(0, size),
                                           dev_offset - dev_data.size())
                                   .size();
                LOG(LogLevel::Info, " WD@{}#{}", dev_offset - dev_data.size(),
                    written);
                // A short write consumes only what reached the device; the
                // remainder stays buffered for the next iteration.
                if (written != size)
                {
                    consumed = written;
                }
            }
        }
        // Drop the consumed bytes from the front of both buffers.
        std::memmove(reader_data.data(), reader_data.data() + consumed,
                     reader_data.size() - consumed);
        reader_data = reader_data.subspan(0, reader_data.size() - consumed);
        std::memmove(dev_data.data(), dev_data.data() + consumed,
                     dev_data.size() - consumed);
        dev_data = dev_data.subspan(0, dev_data.size() - consumed);
        max_size -= consumed;
    }
    LOG(LogLevel::Info, "\n");
}

} // namespace ops
} // namespace flasher
