// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "hoth_update_unittest.hpp"

#include <string_view>
#include <vector>

#include <gmock/gmock.h>

using ::testing::_;
using ::testing::ContainerEq;
using ::testing::NotNull;
using ::testing::Return;

using namespace std::literals;

namespace ipmi_hoth
{

// Short aliases for the D-Bus update callback and status types used below.
using FirmwareUpdateStatus = internal::DbusUpdate::FirmwareUpdateStatus;
using Cb = internal::DbusUpdate::Cb;

class HothUpdateStatTest : public HothUpdateTest
{
  protected:
    // Output parameter for hvn.stat(). Value-initialized ({}) so that its
    // members start zero/empty rather than indeterminate, even if BlobMeta
    // has no default member initializers.
    blobs::BlobMeta meta_{};
    // Expected stat result. Value-initialized so any member a test does not
    // set explicitly compares against zero/empty.
    blobs::BlobMeta expected_meta_{};
};
// Same fixture as HothUpdateStatTest; the separate name groups the tests
// that call stat() with a session ID.
struct HothUpdateSessionStatTest : HothUpdateStatTest
{};

struct MockCancel : stdplus::Cancelable
{
    MOCK_METHOD(void, cancel, (), (noexcept, override));
};

// Payload written by the tests. It contains an embedded NUL byte, so it is
// not a plain C string; the full 14 bytes must be carried through.
static const std::string test_str = [] {
    constexpr char raw[] = "Hello,\0 world!";
    // sizeof(raw) - 1 drops only the implicit trailing NUL, keeping the
    // embedded one.
    return std::string(raw, sizeof(raw) - 1);
}();
static const std::vector<uint8_t> test_buf(test_str.begin(), test_str.end());

TEST_F(HothUpdateStatTest, InvalidStatIsRejected)
{
    // This test opens no session, so stat by blob path is expected to be
    // rejected by the handler's session check.
    const bool statOk = hvn.stat(legacyPath, &meta_);
    EXPECT_FALSE(statOk);
}

TEST_F(HothUpdateStatTest, StatBeforeCommitReturnsInitialState)
{
    // After open and before any commit, stat by blob path should report an
    // empty blob that is open for writing.
    expected_meta_.size = 0;
    expected_meta_.blobState = blobs::StateFlags::open_write;

    // The open below is expected to ping hothd exactly once.
    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));
    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));

    EXPECT_TRUE(hvn.stat(legacyPath, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

// The remaining tests use the session-ID overload of stat, since stat by blob
// ID calls the session-ID variant after its initial checks.

TEST_F(HothUpdateSessionStatTest, InvalidSessionStatIsRejected)
{
    // This test opens no session, so a stat on session ID 0 is expected to
    // fail the handler's session check.
    const bool statOk = hvn.stat(0, &meta_);
    EXPECT_FALSE(statOk);
}

TEST_F(HothUpdateSessionStatTest, SessionStatBeforeCommitReturnsInitialState)
{
    // Right after open, before any write or commit, a session stat should
    // report an empty blob open for writing. The only D-Bus mock interaction
    // expected is the ping made by open; stat sets up no status query.
    expected_meta_.size = 0;
    expected_meta_.blobState = blobs::StateFlags::open_write;

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));
    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));

    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

TEST_F(HothUpdateSessionStatTest, SessionStatAfterWriteMetadataLengthMatches)
{
    // Verify that after a write, the size reported by session stat matches
    // the amount of data written, without any firmware-status D-Bus query.

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    EXPECT_TRUE(hvn.stat(session_, &meta_));

    // The whole buffer was written starting at offset 0, so the reported
    // size equals the buffer size.
    expected_meta_.size = test_buf.size();
    expected_meta_.blobState = blobs::StateFlags::open_write;
    EXPECT_EQ(meta_, expected_meta_);
}

TEST_F(HothUpdateSessionStatTest, SessionStatAfterErrorCommitReturnsStatus)
{
    // A commit that the handler rejects up front (this one passes non-empty
    // commit data) must leave the blob in its pre-commit state. The session
    // stat that follows reports that state without any status D-Bus query.
    expected_meta_.size = test_buf.size();
    expected_meta_.blobState = blobs::StateFlags::open_write;

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));
    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));

    // Write the full buffer at offset 0.
    const bool wrote = hvn.write(session_, 0, test_buf);
    EXPECT_TRUE(wrote);

    const bool committed = hvn.commit(session_, std::vector<uint8_t>{1, 2, 3});
    EXPECT_FALSE(committed);

    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

TEST_F(HothUpdateSessionStatTest, SessionStatErrorReturnsCommitError)
{
    // Verify that mocking GetFirmwareUpdateStatus output to Error after
    // a successful commit makes session stat return 'commit_error' status

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    // Capture the UpdateFirmware completion callback so the test controls
    // when it fires and with which status.
    Cb cb;
    EXPECT_CALL(dbus,
                UpdateFirmware(std::string_view(""), ContainerEq(test_buf), _))
        .WillOnce([&](std::string_view, const std::vector<uint8_t>&, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.commit(session_, std::vector<uint8_t>()));
    cb(FirmwareUpdateStatus::InProgress);

    // While the update is in progress, stat issues a status query; capture
    // that callback too.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(std::string_view(""), _))
        .WillOnce([&](std::string_view, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    // Use the fixture's meta_ (value-initialized) for consistency with the
    // other tests instead of an uninitialized local BlobMeta.
    EXPECT_TRUE(hvn.stat(session_, &meta_));

    EXPECT_EQ(meta_.size, test_buf.size());
    // empty() avoids a signed/unsigned comparison of size() against 0.
    EXPECT_TRUE(meta_.metadata.empty());
    EXPECT_EQ(meta_.blobState,
              blobs::StateFlags::open_write | blobs::StateFlags::committing);

    // Report failure; the next stat must surface commit_error.
    cb(FirmwareUpdateStatus::Error);
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_.blobState,
              blobs::StateFlags::open_write | blobs::StateFlags::commit_error);
}

TEST_F(HothUpdateSessionStatTest, SessionStatInProgressReturnsCommitting)
{
    // Verify that mocking GetFirmwareUpdateStatus output to InProgress after
    // a successful commit makes session stat return 'committing' status

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    // Capture the UpdateFirmware completion callback so the test controls
    // when it fires and with which status.
    Cb cb;
    EXPECT_CALL(dbus,
                UpdateFirmware(std::string_view(""), ContainerEq(test_buf), _))
        .WillOnce([&](std::string_view, const std::vector<uint8_t>&, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.commit(session_, std::vector<uint8_t>()));
    cb(FirmwareUpdateStatus::InProgress);

    // Each stat made after the previous status query has completed is
    // expected to issue a fresh query, hence exactly two calls below.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(std::string_view(""), _))
        .Times(2)
        .WillRepeatedly([&](std::string_view, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    cb(FirmwareUpdateStatus::InProgress);

    expected_meta_.size = test_buf.size();
    expected_meta_.blobState =
        blobs::StateFlags::open_write | blobs::StateFlags::committing;
    EXPECT_EQ(meta_, expected_meta_);

    // Check that repeated stats after callback trigger another call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

TEST_F(HothUpdateSessionStatTest, SessionStatDoneReturnsCommitted)
{
    // Verify that mocking GetFirmwareUpdateStatus output to Done after
    // a successful commit makes session stat return 'committed' status

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    // Capture the UpdateFirmware completion callback so the test controls
    // when it fires and with which status.
    Cb cb;
    EXPECT_CALL(dbus,
                UpdateFirmware(std::string_view(""), ContainerEq(test_buf), _))
        .WillOnce([&](std::string_view, const std::vector<uint8_t>&, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.commit(session_, std::vector<uint8_t>()));
    cb(FirmwareUpdateStatus::InProgress);

    // The next stat queries the update status; capture that callback.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(std::string_view(""), _))
        .WillOnce([&](std::string_view, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.stat(session_, &meta_));

    // Until the status query answers, the blob is still committing.
    expected_meta_.size = test_buf.size();
    expected_meta_.blobState =
        blobs::StateFlags::open_write | blobs::StateFlags::committing;
    EXPECT_EQ(meta_, expected_meta_);

    // Report completion; the next stat must surface 'committed'.
    cb(FirmwareUpdateStatus::Done);
    expected_meta_.blobState =
        blobs::StateFlags::open_write | blobs::StateFlags::committed;
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

TEST_F(HothUpdateSessionStatTest, MultipleInProgressReturnsCommitting)
{
    // Verify that repeated session stat calls while committing keep reporting
    // 'committing' without stacking up D-Bus requests. No status query is
    // made while UpdateFirmware is still outstanding. After that, only a
    // single GetFirmwareUpdateStatus call is issued until its callback fires.

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    // Each mocked D-Bus call hands back a stdplus::Cancel wrapping `c`, so the
    // test can check exactly when that pending-request handle is cancelled.
    testing::StrictMock<MockCancel> c;
    Cb cb;
    EXPECT_CALL(dbus,
                UpdateFirmware(std::string_view(""), ContainerEq(test_buf), _))
        .WillOnce([&](std::string_view, const std::vector<uint8_t>&, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(&c);
        });
    EXPECT_TRUE(hvn.commit(session_, std::vector<uint8_t>()));

    // If commit is still outstanding we should not issue a new command.
    // Enforce that explicitly instead of relying on how the mock treats
    // unexpected calls.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(_, _)).Times(0);
    expected_meta_.size = test_buf.size();
    expected_meta_.blobState =
        blobs::StateFlags::open_write | blobs::StateFlags::committing;
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);

    // Completing the UpdateFirmware request is expected to cancel `c`.
    EXPECT_CALL(c, cancel());
    cb(FirmwareUpdateStatus::InProgress);
    testing::Mock::VerifyAndClearExpectations(&c);

    // With no request outstanding, the next stat queries status exactly once.
    // This expectation is newer than the Times(0) one above, so gMock matches
    // it first. Any extra call then exceeds its WillOnce cardinality.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(std::string_view(""), _))
        .WillOnce([&](std::string_view, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(&c);
        });
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);

    // Completing the status query is likewise expected to cancel `c`.
    EXPECT_CALL(c, cancel());
    cb(FirmwareUpdateStatus::InProgress);
}

TEST_F(HothUpdateSessionStatTest, IdempotentDoneReturnsCommitted)
{
    // Verify that once UpdateFirmware reports Done, repeated session stat
    // calls keep returning 'committed' without issuing any
    // GetFirmwareUpdateStatus query.

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    Cb cb;
    EXPECT_CALL(dbus,
                UpdateFirmware(std::string_view(""), ContainerEq(test_buf), _))
        .WillOnce([&](std::string_view, const std::vector<uint8_t>&, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.commit(session_, std::vector<uint8_t>()));
    cb(FirmwareUpdateStatus::Done);

    // The final status is already known, so stat must not query it again.
    // Make that explicit rather than relying on how the mock treats
    // unexpected calls.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(_, _)).Times(0);

    expected_meta_.size = test_buf.size();
    expected_meta_.blobState =
        blobs::StateFlags::open_write | blobs::StateFlags::committed;
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

TEST_F(HothUpdateSessionStatTest, IdempotentErrorReturnsCommitError)
{
    // Verify that once UpdateFirmware reports Error, repeated session stat
    // calls keep returning 'commit_error' without issuing any
    // GetFirmwareUpdateStatus query.

    EXPECT_CALL(dbus, pingHothd(std::string_view(""))).WillOnce(Return(true));

    EXPECT_TRUE(hvn.open(session_, hvn.requiredFlags(), legacyPath));
    // session, offset, data
    EXPECT_TRUE(hvn.write(session_, 0, test_buf));

    Cb cb;
    EXPECT_CALL(dbus,
                UpdateFirmware(std::string_view(""), ContainerEq(test_buf), _))
        .WillOnce([&](std::string_view, const std::vector<uint8_t>&, Cb&& icb) {
            cb = std::move(icb);
            return stdplus::Cancel(std::nullopt);
        });
    EXPECT_TRUE(hvn.commit(session_, std::vector<uint8_t>()));
    cb(FirmwareUpdateStatus::Error);

    // The final status is already known, so stat must not query it again.
    // Make that explicit rather than relying on how the mock treats
    // unexpected calls.
    EXPECT_CALL(dbus, GetFirmwareUpdateStatus(_, _)).Times(0);

    expected_meta_.size = test_buf.size();
    expected_meta_.blobState =
        blobs::StateFlags::open_write | blobs::StateFlags::commit_error;
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
    // Shouldn't trigger a new call
    EXPECT_TRUE(hvn.stat(session_, &meta_));
    EXPECT_EQ(meta_, expected_meta_);
}

} // namespace ipmi_hoth
