NVMe: Fix callback race condition and handle nested refresh calls

Move callbacks to local variables before execution to avoid iterator
invalidation if callbacks modify the callback list. Reset 'refreshing'
flag before executing callbacks to allow nested refresh() calls.

Added test_nvme_nested_refresh.cpp to verify the fix.

Tested: Verified with new unit test (test_nvme_nested_refresh).

Google-Bug-Id: 485660452
Change-Id: Ie542cd04f80ef58cf320078eec98822cb7208060
Signed-off-by: Guangzong Chen <guangzong@google.com>
diff --git a/src/NVMeCache.hpp b/src/NVMeCache.hpp index bf8e538..3e2cac2 100644 --- a/src/NVMeCache.hpp +++ b/src/NVMeCache.hpp
@@ -320,33 +320,38 @@ self->executingTask.reset(); self->enqueue(time, task); - for (auto&& func : task->partialRefreshCBs) + auto partialRefreshCBs = std::move(task->partialRefreshCBs); + task->partialRefreshCBs.clear(); + for (auto&& func : partialRefreshCBs) { if (func) { func({}); } } - task->partialRefreshCBs.clear(); if (complete) { - for (auto&& func : task->refreshCBs) - { - if (func) - { - func({}); - } - } - for (auto&& func : task->completeCBs) - { - if (func) - { - func({}); - } - } - task->refreshCBs.clear(); task->refreshing = false; + auto refreshCBs = std::move(task->refreshCBs); + auto completeCBs = std::move(task->completeCBs); + task->refreshCBs.clear(); + task->completeCBs.clear(); + + for (auto&& func : refreshCBs) + { + if (func) + { + func({}); + } + } + for (auto&& func : completeCBs) + { + if (func) + { + func({}); + } + } } }); });
diff --git a/tests/meson.build b/tests/meson.build index a0a3a3e..44c5993 100644 --- a/tests/meson.build +++ b/tests/meson.build
@@ -38,6 +38,17 @@ ) ) +test( + 'test_nvme_nested_refresh', + executable( + 'test_nvme_nested_refresh', + 'test_nvme_nested_refresh.cpp', + cpp_args: ['-UBOOST_ASIO_NO_DEPRECATED', '-UBOOST_ASIO_DISABLE_THREADS', '-UBOOST_ASIO_HAS_IO_URING','-DBUILDDIR='+ meson.current_build_dir(), '-DNVME_UNIT_TEST=1'], + dependencies: [ut_deps_list], + include_directories: '../src', + ) +) + # enable the nvme unit test only for CI docker because the test requires dbus if get_option('nvme').enabled() and run_command('/usr/bin/bash', '-c', '[ ! -f /.dockerenv ]').returncode() == 1 @@ -97,4 +108,3 @@ ) endif -
diff --git a/tests/test_nvme_nested_refresh.cpp b/tests/test_nvme_nested_refresh.cpp new file mode 100644 index 0000000..142c9e9 --- /dev/null +++ b/tests/test_nvme_nested_refresh.cpp
@@ -0,0 +1,90 @@ +#include "NVMeCache.hpp" + +#include <boost/asio.hpp> +#include <boost/asio/spawn.hpp> + +#include <gmock/gmock.h> +#include <gtest/gtest.h> + +class TestNVMeNestedRefresh : public testing::Test +{ + public: + boost::asio::io_context io; + std::shared_ptr<Scheduler<std::chrono::steady_clock>> scheduler; + + TestNVMeNestedRefresh() + { + scheduler = std::make_shared<Scheduler<std::chrono::steady_clock>>(io); + } +}; + +class MetricMock : public MetricBase<std::chrono::steady_clock> +{ + public: + explicit MetricMock( + std::shared_ptr<Scheduler<std::chrono::steady_clock>> scheduler) : + MetricBase<std::chrono::steady_clock>( + std::move(scheduler), std::chrono::steady_clock::duration::max()) + {} + + MOCK_METHOD( + void, readDevice, + (std::function<void(std::error_code ec, size_t size, bool complete)> && + cb), + (noexcept, override)); + + MOCK_METHOD((std::string_view), getIdentifier, (), + (const, noexcept, override)); + MOCK_METHOD(bool, isCacheValid, (), (const, noexcept, override)); + MOCK_METHOD((std::tuple<std::chrono::time_point<ClockType>, + std::chrono::time_point<ClockType>, + std::span<const uint8_t>>), + getCache, (), (const, noexcept, override)); +}; + +TEST_F(TestNVMeNestedRefresh, NestedRefresh) +{ + auto mtx = std::make_shared<MetricMock>(scheduler); + + EXPECT_CALL(*mtx, readDevice) + .WillRepeatedly( + [](std::function<void(std::error_code ec, size_t size, + bool complete)>&& cb) { cb({}, 0, true); }); + + scheduler->start(); + scheduler->dequeue(); + + bool cb1Called = false; + bool cb2Called = false; + + mtx->refresh([&](std::error_code ec) { + EXPECT_FALSE(ec); + cb1Called = true; + + // Issue 1: isRefreshing() should be false when the callback is + // triggered for completion. In the bugged version, it is still true.
+ EXPECT_FALSE(mtx->isRefreshing()); + + // Issue 2: Nested refresh + mtx->refresh([&](std::error_code ec) { + EXPECT_FALSE(ec); + cb2Called = true; + }); + }); + + io.run_one(); // Run the first refresh + io.poll(); // Run any remaining handlers + + EXPECT_TRUE(cb1Called); + + // In the bugged version, cb2 will NEVER be called because it was cleared + // from the callback list and task->refreshing was reset to false by the + // first refresh's completion logic. + EXPECT_TRUE(cb2Called); +} + +int main(int argc, char** argv) +{ + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +}