Add `enable-secureboot-enforcement` command to hoth_updater_cli.

This change introduces a new subcommand to `hoth_updater_cli` that allows enabling secure boot enforcement on the Hoth device.

TESTED=todo

Google-Bug-Id:442614164

Change-Id: I07c6a16d167d8bb2698caabc4b929c581828f843
Signed-off-by: Christian Kungler ckungler@google.com
diff --git a/tools/hoth_updater_cli.cpp b/tools/hoth_updater_cli.cpp
index 01bd1b1..f8541cf 100644
--- a/tools/hoth_updater_cli.cpp
+++ b/tools/hoth_updater_cli.cpp
@@ -92,21 +92,6 @@
                                "org.freedesktop.DBus.Properties", "Get");
 }
 
-std::vector<uint8_t> sendHostCommand(
-    sdbusplus::bus::bus& bus, std::string_view hoth_id,
-    const std::span<const uint8_t> command,
-    std::optional<sdbusplus::SdBusDuration> timeout = std::nullopt)
-{
-    sdbusplus::message::message msg =
-        hothMessage(bus, hoth_id, "SendHostCommand");
-    msg.append(command);
-    sdbusplus::message::message resp =
-        bus.call(msg, timeout.value_or(kCallTimeout));
-    std::vector<uint8_t> result;
-    resp.read(result);
-    return result;
-}
-
 template <typename T>
 std::optional<T> getHothStateProperty(
     sdbusplus::bus::bus& bus, std::string_view hoth_id,
@@ -137,6 +122,20 @@
 
 } // namespace
 
+std::vector<uint8_t> HothUpdaterCLI::sendHostCommand(
+    sdbusplus::bus::bus& bus, std::string_view hoth_id,
+    const std::span<const uint8_t> command)
+{
+    sdbusplus::message::message msg =
+        hothMessage(bus, hoth_id, "SendHostCommand");
+    msg.append(command);
+    sdbusplus::message::message resp =
+        bus.call(msg, kCallTimeout);
+    std::vector<uint8_t> result;
+    resp.read(result);
+    return result;
+}
+
 void HothUpdaterCLI::updateFirmware(sdbusplus::bus::bus& bus,
                                     std::string_view hoth_id,
                                     const std::span<const uint8_t> image)
@@ -203,6 +202,11 @@
 void HothUpdaterCLI::doUpdate(const Args& args)
 {
     sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
+    doUpdateLogic(bus, args);
+}
+
+void HothUpdaterCLI::doUpdateLogic(sdbusplus::bus::bus& bus, const Args& args)
+{
     auto end_time = std::chrono::steady_clock::now() + 5min;
 
     std::vector<uint8_t> image = readFileIntoByteArray(args.imageFilename);
@@ -257,6 +261,12 @@
 void HothUpdaterCLI::doFirmwareVersion(const Args& args)
 {
     sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
+    doFirmwareVersionLogic(bus, args);
+}
+
+void HothUpdaterCLI::doFirmwareVersionLogic(sdbusplus::bus::bus& bus,
+                                            const Args& args)
+{
     auto response = getHothVersion(bus, args.hothId);
 
     if (args.ro)
@@ -315,6 +325,12 @@
 void HothUpdaterCLI::doActivationCheck(const Args& args)
 {
     sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
+    doActivationCheckLogic(bus, args);
+}
+
+void HothUpdaterCLI::doActivationCheckLogic(sdbusplus::bus::bus& bus,
+                                            const Args& args)
+{
     stdplus::print(stdout, "installed_version: \"{}\"\n",
                    args.expectedRwVersion);
 
@@ -364,6 +380,57 @@
     }
 }
 
+// from libhoth command HOTH_PRV_CMD_HOTH_SET_SECURE_BOOT_ENFORCEMENT
+constexpr uint16_t kSetSecureBootEnforcementCmd = 0x3E0D;
+
+struct SetSecureBootEnforcementReq {
+  uint8_t enabled = 1;
+  uint8_t reserved[3] = {0};
+};
+
+void HothUpdaterCLI::doEnableSecurebootEnforcement(const Args& args)
+{
+    sdbusplus::bus::bus bus = sdbusplus::bus::new_default();
+    doEnableSecurebootEnforcementLogic(bus, args);
+}
+
+void HothUpdaterCLI::doEnableSecurebootEnforcementLogic(
+    sdbusplus::bus::bus& bus, const Args& args)
+{
+    google::hoth::internal::ReqHeader header;
+    SetSecureBootEnforcementReq payload;
+    google::hoth::internal::populateReqHeader(kSetSecureBootEnforcementCmd, 0,
+                                              &payload, sizeof(payload),
+                                              &header);
+
+    auto header_span = stdplus::raw::asSpan<uint8_t>(header);
+    auto payload_span = stdplus::raw::asSpan<uint8_t>(payload);
+    std::vector<uint8_t> command;
+    command.insert(command.end(), header_span.begin(), header_span.end());
+    command.insert(command.end(), payload_span.begin(), payload_span.end());
+
+    std::vector<std::string> hoth_ids = splitString(args.hothId, ',');
+    if (hoth_ids.empty())
+    {
+        hoth_ids.push_back("");
+    }
+
+    for (const std::string& hoth_id : hoth_ids)
+    {
+        std::vector<uint8_t> resp_bytes =
+            sendHostCommand(bus, hoth_id, command);
+
+        auto response = stdplus::raw::copyFrom<
+            google::hoth::internal::RspHeader>(resp_bytes);
+
+        if (response.result != 0)
+        {
+            throw std::runtime_error(
+                "Failed to enable secure boot enforcement");
+        }
+    }
+}
+
 void setupCLIApp(CLI::App& app, HothUpdaterCLI& cli, Args& args)
 {
     app.require_subcommand(1);
@@ -394,6 +461,11 @@
                                "Expected RW version");
     // RO version check not implemented yet.
     activation_check->callback([&args, &cli] { cli.doActivationCheck(args); });
+
+    auto* enable_secureboot = app.add_subcommand(
+        "enable-secureboot-enforcement", "Enable secure boot enforcement");
+    enable_secureboot->callback(
+        [&args, &cli] { cli.doEnableSecurebootEnforcement(args); });
 }
 
 } // namespace google::hoth::tools
diff --git a/tools/hoth_updater_cli.hpp b/tools/hoth_updater_cli.hpp
index 440d28a..f520749 100644
--- a/tools/hoth_updater_cli.hpp
+++ b/tools/hoth_updater_cli.hpp
@@ -80,6 +80,12 @@
     void doUpdate(const Args& args);
     void doActivationCheck(const Args& args);
     void doFirmwareVersion(const Args& args);
+    void doEnableSecurebootEnforcement(const Args& args);
+    void doUpdateLogic(sdbusplus::bus::bus& bus, const Args& args);
+    void doActivationCheckLogic(sdbusplus::bus::bus& bus, const Args& args);
+    void doFirmwareVersionLogic(sdbusplus::bus::bus& bus, const Args& args);
+    void doEnableSecurebootEnforcementLogic(sdbusplus::bus::bus& bus,
+                                            const Args& args);
     virtual void spiWrite(sdbusplus::bus::bus& bus, std::string_view hoth_id,
                           std::span<const uint8_t> image,
                           std::optional<uint32_t> address);
@@ -93,6 +99,9 @@
     virtual HothActivationStatistics
         getHothActivationStatistics(sdbusplus::bus::bus& bus,
                                     std::string_view hoth_id);
+    virtual std::vector<uint8_t>
+        sendHostCommand(sdbusplus::bus::bus& bus, std::string_view hoth_id,
+                        std::span<const uint8_t> command);
 };
 
 void setupCLIApp(CLI::App& app, HothUpdaterCLI& cli, Args& args);
diff --git a/tools/meson.build b/tools/meson.build
index c23a6f8..24d9a84 100644
--- a/tools/meson.build
+++ b/tools/meson.build
@@ -10,7 +10,7 @@
   ],
   include_directories: [hothd_headers, hothtools_headers],
   implicit_include_directories: false,
-  dependencies: libhothd_deps)
+  dependencies: [libhothd_dep])
 
 libhothtools_dep = declare_dependency(
   dependencies: libhothd_deps,
diff --git a/tools/test/hoth_updater_cli_test.cpp b/tools/test/hoth_updater_cli_test.cpp
index 253f8ca..f9598b7 100644
--- a/tools/test/hoth_updater_cli_test.cpp
+++ b/tools/test/hoth_updater_cli_test.cpp
@@ -43,6 +43,9 @@
                 (std::string_view));
     MOCK_METHOD(FirmwareUpdateStatus, getFirmwareUpdateStatus,
                 (sdbusplus::bus::bus & bus, std::string_view hoth_id));
+    MOCK_METHOD(std::vector<uint8_t>, sendHostCommand,
+                (sdbusplus::bus::bus & bus, std::string_view hoth_id,
+                 const std::span<const uint8_t> command));
 };
 
 TEST(PayloadUpdateCLITest, splitStringTest)
@@ -85,9 +88,10 @@
                                  }};
     Args args;
     args.expectedRwVersion = "2.3.456";
+    sdbusplus::bus::bus bus{nullptr};
     EXPECT_CALL(cli, getHothVersion(_, _)).Times(1).WillOnce(Return(rsp));
     EXPECT_CALL(cli, getHothActivationStatistics(_, _)).Times(0);
-    cli.doActivationCheck(args);
+    cli.doActivationCheckLogic(bus, args);
 }
 
 // Test doActivationCheck when version mismatch.
@@ -116,11 +120,12 @@
     };
     Args args;
     args.expectedRwVersion = "3.4.567";
+    sdbusplus::bus::bus bus{nullptr};
     EXPECT_CALL(cli, getHothVersion(_, _)).Times(1).WillOnce(Return(rsp));
     EXPECT_CALL(cli, getHothActivationStatistics(_, _))
         .Times(1)
         .WillOnce(Return(stats));
-    EXPECT_THROW(cli.doActivationCheck(args), std::runtime_error);
+    EXPECT_THROW(cli.doActivationCheckLogic(bus, args), std::runtime_error);
 }
 
 // Test for making sure the doUpdate loop works correctly
@@ -131,6 +136,7 @@
     args.imageFilename = "dummy";
 
     std::vector<uint8_t> dummy_image = {0, 1, 2, 3};
+    sdbusplus::bus::bus bus{nullptr};
     EXPECT_CALL(cli, readFileIntoByteArray(_))
         .Times(1)
         .WillOnce(Return(dummy_image));
@@ -143,7 +149,34 @@
         .WillOnce(Return(FirmwareUpdateStatus::InProgress))
         .WillOnce(Return(FirmwareUpdateStatus::Done));
 
-    cli.doUpdate(args);
+    cli.doUpdateLogic(bus, args);
+}
+
+// Test for enabling secureboot enforcement successfully
+TEST_F(HothUpdaterCLITest, EnableSecurebootEnforcementSuccess)
+{
+    Args args;
+    args.hothId = "inst0";
+    const std::vector<uint8_t> successResponse = {0x03, 0xfc, 0x00, 0x00,
+                                                  0x00, 0x00, 0x00, 0x00};
+    sdbusplus::bus::bus bus{nullptr};
+    EXPECT_CALL(cli, sendHostCommand(_, args.hothId, _))
+        .WillOnce(Return(successResponse));
+    EXPECT_NO_THROW(cli.doEnableSecurebootEnforcementLogic(bus, args));
+}
+
+// Test for enabling secureboot enforcement failure
+TEST_F(HothUpdaterCLITest, EnableSecurebootEnforcementFailure)
+{
+    Args args;
+    args.hothId = "inst0";
+    const std::vector<uint8_t> failureResponse = {0x03, 0xfc, 0x01, 0x00,
+                                                  0x00, 0x00, 0x00, 0x00};
+    sdbusplus::bus::bus bus{nullptr};
+    EXPECT_CALL(cli, sendHostCommand(_, args.hothId, _))
+        .WillOnce(Return(failureResponse));
+    EXPECT_THROW(cli.doEnableSecurebootEnforcementLogic(bus, args),
+                 std::runtime_error);
 }
 
 } // namespace