blob: 405e779f2856abb480dd009e0165d0e9d4146b80 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <flasher/logging.hpp>
#include <flasher/ops.hpp>
#include <stdplus/exception.hpp>
#include <algorithm>
#include <cstddef>
#include <cstring>
#include <format>
#include <optional>
#include <span>
#include <stdexcept>
#include <vector>
namespace flasher
{
namespace ops
{
/// Reads `buf.size()` bytes from `dev` at device offset `off` by delegating to
/// Device::readAtExact(), then logs the read as " RD@<off>#<size>".
/// Nothing is logged if the device read throws.
static void readAtExact(Device& dev, std::span<std::byte> buf, size_t off)
{
    const size_t len = buf.size();
    dev.readAtExact(buf, off);
    LOG(LogLevel::Info, " RD@{}#{}", off, len);
}
/// Streams data from `reader` onto `dev`, erasing and writing in
/// `stride`-sized steps.
///
/// Devices reporting an erase size of 0 are delegated entirely to write().
/// Otherwise two parallel windows are maintained over the same indexing:
///   - reader_data: the bytes that should end up on the device;
///   - dev_data:    the bytes currently on the device (or, in `noread` mode, a
///                  synthesized complement of the reader bytes).
/// Index i in either window refers to device address
/// `dev_offset - dev_data.size() + i`. The device decides what to erase and
/// what to write through needsErase(), shrinkWritePre() and shrinkWritePost().
///
/// @param dev           Device being programmed.
/// @param dev_offset    Device address at which the reader data is placed.
/// @param reader        Source of the data to write.
/// @param reader_offset Offset within `reader` to start reading from.
/// @param mut           Passed every chunk read from `reader` (mut.forward()).
/// @param max_size      Maximum number of reader bytes to place on the device;
///                      lowered to what is available if `reader` hits EOF.
/// @param stride_size   Step size; defaults to dev.recommendedStride() and is
///                      then passed through dev.eraseAlignEnd().
/// @param noread        Do not read the device to compare against the reader
///                      data. The partial erase blocks before `dev_offset` and
///                      after the last reader byte are still read so they can
///                      be preserved.
/// @throws std::invalid_argument if `dev_offset` is past the end of the device
///         or the stride is 0.
/// @throws std::runtime_error if the device has no room left for the reader
///         data (only detected when the device is being read, i.e. !noread).
void automatic(Device& dev, size_t dev_offset, Reader& reader,
               size_t reader_offset, Mutate& mut, size_t max_size,
               std::optional<size_t> stride_size, bool noread)
{
    // Devices without an erase granularity need none of the erase
    // bookkeeping below.
    if (dev.getEraseSize() == 0)
    {
        return write(dev, dev_offset, reader, reader_offset, mut, max_size,
                     stride_size, noread);
    }
    if (dev_offset > dev.getSize())
    {
        throw std::invalid_argument(std::format(
            "Device smaller than offset, {} < {}", dev.getSize(), dev_offset));
    }
    size_t stride = stride_size ? *stride_size : dev.recommendedStride();
    if (stride == 0)
    {
        throw std::invalid_argument("Stride cannot be 0");
    }
    stride = dev.eraseAlignEnd(stride);

    // Each window holds at most three strides. The reader window gets one
    // extra erase block of headroom for the device tail that is appended
    // after the final reader byte (see the eraseAlignEnd() padding below).
    std::vector<std::byte> reader_buf_v(stride * 3 + dev.getEraseSize()),
        dev_buf_v(stride * 3);
    const std::span<std::byte> reader_buf(reader_buf_v), dev_buf(dev_buf_v);
    // Populated prefixes of the two windows. Invariant: dev_offset is the
    // device address just past the last byte held in dev_data, so window
    // index i maps to device address dev_offset - dev_data.size() + i.
    std::span<std::byte> reader_data, dev_data;
    // erase_offset is the device address up to which the erase decision has
    // already been made (erased, or found unnecessary).
    auto start = dev.eraseAlignStart(dev_offset), erase_offset = start;
    if (start != dev_offset)
    {
        // dev_offset is not erase aligned. Preload the device bytes between
        // the start of its erase block and dev_offset into BOTH windows, so
        // the target data for those bytes equals their current contents.
        dev_data = dev_buf.subspan(0, dev_offset - start);
        reader_data = reader_buf.subspan(0, dev_data.size());
        readAtExact(dev, dev_data, start);
        std::memcpy(reader_data.data(), dev_data.data(), dev_data.size());
        // max_size now counts from the window start. Grow it by the preloaded
        // prefix, saturating instead of wrapping if it is already near
        // SIZE_MAX.
        max_size = std::max(max_size, max_size + dev_data.size());
    }
    bool eof = false;
    // Next free slice of the device window: at most one stride, no more than
    // is still wanted, and no more than the device has left.
    const auto next_dev_buf = [&]() {
        auto next_size = std::min(
            {stride, max_size - dev_data.size(), dev.getSize() - dev_offset});
        if (next_size == 0)
        {
            throw std::runtime_error("Device not large enough for reader");
        }
        return dev_buf.subspan(dev_data.size(), next_size);
    };
    // True while the device window holds under two strides and more bytes
    // are still wanted.
    const auto too_little_dev = [&]() {
        return dev_data.size() < stride * 2 && max_size > dev_data.size();
    };
    while (max_size > 0)
    {
        // Fill the reader window until it holds at least two strides, all of
        // max_size, or the reader reports EOF.
        while (reader_data.size() < stride * 2 &&
               max_size > reader_data.size() && !eof)
        {
            try
            {
                auto new_reader_data = reader.readAt(
                    reader_buf.subspan(
                        reader_data.size(),
                        std::min(stride, max_size - reader_data.size())),
                    reader_offset);
                // Report the reader offset this chunk was read from.
                LOG(LogLevel::Info, " RF@{}#{}", reader_offset,
                    new_reader_data.size());
                mut.forward(new_reader_data, reader_offset);
                reader_data = reader_buf.subspan(0, reader_data.size() +
                                                        new_reader_data.size());
                reader_offset += new_reader_data.size();
            }
            catch (const stdplus::exception::Eof&)
            {
                // Reader exhausted: what is buffered is all that is left.
                eof = true;
                max_size = reader_data.size();
            }
            if (reader_data.size() == max_size)
            {
                // All remaining reader data is buffered. Extend the target to
                // the end of the containing erase block, filling the tail with
                // the current device contents so that tail's target data is
                // its existing contents.
                max_size = dev.eraseAlignEnd(reader_data.size() -
                                             dev_data.size() + dev_offset) -
                           dev_offset + dev_data.size();
                if (max_size > reader_data.size())
                {
                    readAtExact(
                        dev,
                        reader_buf.subspan(reader_data.size(),
                                           max_size - reader_data.size()),
                        dev_offset + reader_data.size() - dev_data.size());
                    reader_data = reader_buf.subspan(0, max_size);
                }
            }
        }
        if (noread)
        {
            // The device is not read. Synthesize "device" bytes that differ
            // from every reader byte (0xff XOR), presumably so no byte is
            // treated as already correct by the checks below.
            if (too_little_dev())
            {
                for (size_t i = dev_data.size(); i < reader_data.size(); ++i)
                {
                    dev_buf[i] = std::byte{0xff} ^ reader_data[i];
                }
                dev_offset += reader_data.size() - dev_data.size();
                dev_data = dev_buf.subspan(0, reader_data.size());
            }
        }
        else
        {
            // Fill the device window until it holds at least two strides or
            // all of max_size.
            while (too_little_dev())
            {
                auto new_dev_data = dev.readAt(next_dev_buf(), dev_offset);
                LOG(LogLevel::Info, " RD@{}#{}", dev_offset,
                    new_dev_data.size());
                dev_data =
                    dev_buf.subspan(0, dev_data.size() + new_dev_data.size());
                dev_offset += new_dev_data.size();
            }
        }
        // When less than a stride of erase-decided data is left at the front
        // of the window, decide erasure for the next erase-aligned chunk (at
        // most one stride) starting at erase_offset.
        if (dev_data.size() - (dev_offset - erase_offset) < stride &&
            dev_offset > erase_offset)
        {
            // Window index corresponding to device address erase_offset.
            auto erase_buf_offset =
                dev_data.size() - (dev_offset - erase_offset);
            auto to_erase = dev_data.subspan(erase_buf_offset);
            to_erase = to_erase.subspan(
                0, std::min(to_erase.size() -
                                to_erase.size() % dev.getEraseSize(),
                            stride));
            if (dev.needsErase(to_erase, reader_data.subspan(erase_buf_offset,
                                                             to_erase.size())))
            {
                dev.eraseBlocks(erase_offset / dev.getEraseSize(),
                                to_erase.size() / dev.getEraseSize());
                LOG(LogLevel::Info, " ED@{}#{}", erase_offset, to_erase.size());
                // Presumably rewrites the in-memory copy to match the erased
                // device contents.
                dev.mockErase(to_erase);
            }
            erase_offset += to_erase.size();
        }
        // Only the erase-decided prefix of the window (at most one stride)
        // is eligible to be written this iteration.
        auto dev_erased = dev_data.subspan(
            0, std::min(stride, dev_data.size() - (dev_offset - erase_offset)));
        // A non-zero result is consumed as a leading run that is not
        // written. Zero means write from the front of the window.
        auto skip = dev.shrinkWritePre(dev_erased, reader_data);
        auto consumed = skip;
        if (skip == 0)
        {
            auto size = dev.shrinkWritePost(dev_erased, reader_data);
            consumed = std::min(dev_erased.size(), reader_data.size());
            if (size > 0)
            {
                auto written = dev.writeAt(reader_data.subspan(0, size),
                                           dev_offset - dev_data.size())
                                   .size();
                LOG(LogLevel::Info, " WD@{}#{}", dev_offset - dev_data.size(),
                    written);
                // On a short write, advance only past what actually landed.
                if (written != size)
                {
                    consumed = written;
                }
            }
        }
        // Slide both windows forward by the consumed bytes. Buffer indexing
        // stays aligned because dev_offset is unchanged while dev_data
        // shrinks from the front.
        std::memmove(reader_data.data(), reader_data.data() + consumed,
                     reader_data.size() - consumed);
        reader_data = reader_data.subspan(0, reader_data.size() - consumed);
        std::memmove(dev_data.data(), dev_data.data() + consumed,
                     dev_data.size() - consumed);
        dev_data = dev_data.subspan(0, dev_data.size() - consumed);
        max_size -= consumed;
    }
    LOG(LogLevel::Info, "\n");
}
} // namespace ops
} // namespace flasher