blob: e4546d7de3d3d4870db1ed9afdc3a371e06d7625 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <flasher/logging.hpp>
#include <flasher/ops.hpp>
#include <stdplus/exception.hpp>
#include <algorithm>
#include <cstddef>
#include <format>
#include <optional>
#include <span>
#include <stdexcept>
#include <vector>
namespace flasher
{
namespace ops
{
/**
 * Discards the first `count` bytes of `data` by moving the remaining bytes to
 * the front of the same storage.
 *
 * @param[in] data  - pending bytes; the underlying storage is modified in place
 * @param[in] count - number of leading bytes to discard; assumed to be
 *                    <= data.size() (callers only pass a count bounded by the
 *                    pending sizes)
 * @return A span over the retained bytes, starting at data.data()
 */
static std::span<std::byte> dropFront(std::span<std::byte> data, size_t count)
{
    // Nothing to move. Returning early also avoids handing a possibly-null
    // data() pointer to a copy routine, and std::copy would be undefined with
    // identical source and destination starts.
    if (count == 0)
    {
        return data;
    }
    // The destination begins before the source, so a forward std::copy is
    // well-defined for these overlapping ranges.
    const auto rest = data.subspan(count);
    std::copy(rest.begin(), rest.end(), data.begin());
    return data.first(rest.size());
}

/**
 * Writes data taken from `reader` onto `dev`, one stride-sized chunk at a
 * time, skipping regions the device reports as not needing a write.
 *
 * @param[in] dev           - device being written
 * @param[in] dev_offset    - device byte offset where writing starts; must not
 *                            exceed dev.getSize()
 * @param[in] reader        - data source, read with readAt() until
 *                            stdplus::exception::Eof is thrown or max_size
 *                            bytes have been staged
 * @param[in] reader_offset - reader byte offset where reading starts
 * @param[in] mut           - mut.forward() is applied to every chunk read from
 *                            `reader` before it is compared/written
 * @param[in] max_size      - maximum number of bytes to write; lowered to the
 *                            staged amount when the reader reaches EOF
 * @param[in] stride_size   - chunk size; dev.recommendedStride() when unset
 * @param[in] noread        - when true the current device contents are not
 *                            read; dev.mockErase() fills the comparison buffer
 *                            instead
 *
 * @throws std::invalid_argument if dev_offset is past the end of the device or
 *         the stride is 0
 * @throws std::runtime_error if the device runs out of space before all
 *         requested data could be staged
 */
void write(Device& dev, size_t dev_offset, Reader& reader, size_t reader_offset,
           Mutate& mut, size_t max_size, std::optional<size_t> stride_size,
           bool noread)
{
    if (dev_offset > dev.getSize())
    {
        throw std::invalid_argument(std::format(
            "Device smaller than offset, {} < {}", dev.getSize(), dev_offset));
    }
    size_t stride = stride_size ? *stride_size : dev.recommendedStride();
    if (stride == 0)
    {
        throw std::invalid_argument("Stride cannot be 0");
    }
    // A refill only starts while fewer than `stride` bytes are pending and
    // requests at most `stride` more, so 2 * stride bytes always suffice.
    std::vector<std::byte> reader_buf_v(stride * 2), dev_buf_v(stride * 2);
    const std::span<std::byte> reader_buf(reader_buf_v), dev_buf(dev_buf_v);
    // Pending (not yet consumed) bytes, always kept at the front of the
    // matching buffer so data() never becomes null.
    std::span<std::byte> reader_data = reader_buf.first(0);
    std::span<std::byte> dev_data = dev_buf.first(0);
    bool eof = false;
    // Free tail of dev_buf to fill next, clamped to the stride, to the bytes
    // still wanted, and to the space left on the device.
    const auto next_dev_buf = [&]() {
        auto next_size = std::min(
            {stride, max_size - dev_data.size(), dev.getSize() - dev_offset});
        if (next_size == 0)
        {
            throw std::runtime_error("Device not large enough for reader");
        }
        return dev_buf.subspan(dev_data.size(), next_size);
    };
    // True while less than a stride of device data is pending and more is
    // still wanted.
    const auto too_little_dev = [&]() {
        return dev_data.size() < stride && max_size > dev_data.size();
    };
    while (max_size > 0)
    {
        // Stage up to a stride of (mutated) reader data.
        while (reader_data.size() < stride && max_size > reader_data.size() &&
               !eof)
        {
            try
            {
                auto new_reader_data = reader.readAt(
                    reader_buf.subspan(
                        reader_data.size(),
                        std::min(stride, max_size - reader_data.size())),
                    reader_offset);
                // Report the reader offset the chunk was read from.
                LOG(LogLevel::Info, " RF@{}#{}", reader_offset,
                    new_reader_data.size());
                mut.forward(new_reader_data, reader_offset);
                reader_data = reader_buf.subspan(0, reader_data.size() +
                                                        new_reader_data.size());
                reader_offset += new_reader_data.size();
            }
            catch (const stdplus::exception::Eof&)
            {
                // Only the bytes already staged can still be written.
                eof = true;
                max_size = reader_data.size();
            }
        }
        // Stage the matching device bytes to compare against.
        if (noread)
        {
            if (too_little_dev())
            {
                auto erased = next_dev_buf();
                dev.mockErase(erased);
                dev_data = dev_buf.subspan(0, dev_data.size() + erased.size());
                dev_offset += erased.size();
            }
        }
        else
        {
            // readAt() may return fewer bytes than requested, so keep going.
            while (too_little_dev())
            {
                auto new_dev_data = dev.readAt(next_dev_buf(), dev_offset);
                LOG(LogLevel::Info, " RD@{}#{}", dev_offset,
                    new_dev_data.size());
                dev_data =
                    dev_buf.subspan(0, dev_data.size() + new_dev_data.size());
                dev_offset += new_dev_data.size();
            }
        }
        // A nonzero result is consumed without writing anything.
        auto start = dev.shrinkWritePre(dev_data, reader_data);
        auto consumed = start;
        if (start == 0)
        {
            auto size = dev.shrinkWritePost(dev_data, reader_data);
            consumed = std::min(dev_data.size(), reader_data.size());
            if (size > 0)
            {
                // dev_data starts dev_data.size() bytes before dev_offset.
                const size_t write_offset = dev_offset - dev_data.size();
                auto written =
                    dev.writeAt(reader_data.subspan(0, size), write_offset)
                        .size();
                LOG(LogLevel::Info, " WD@{}#{}", write_offset, written);
                // On a short write only the written bytes are consumed; the
                // remainder is retried on the next iteration.
                if (written != size)
                {
                    consumed = written;
                }
            }
        }
        reader_data = dropFront(reader_data, consumed);
        dev_data = dropFront(dev_data, consumed);
        max_size -= consumed;
    }
    // Terminate the progress log output.
    LOG(LogLevel::Info, "\n");
}
} // namespace ops
} // namespace flasher