#include "tlbmc/http/http_client.h"

#include <algorithm>
#include <cerrno>
#include <chrono>  // NOLINT
#include <cstddef>
#include <cstdio>
#include <fstream>
#include <ios>
#include <memory>
#include <string>
#include <string_view>
#include <utility>

#include "absl/functional/any_invocable.h"
#include "absl/log/log.h"
#include "absl/status/status.h"
#include "boost/asio.hpp"  //NOLINT: boost::asio is commonly used in BMC
#include "boost/asio/ip/tcp.hpp"  //NOLINT: boost::asio is commonly used in BMC
#include "resource.pb.h"
#include "openssl/err.h"
#include "openssl/x509.h"

namespace milotic_tlbmc {

// Upper bound on the size of an HTTP response body this client will accept,
// so a misbehaving or malicious server cannot make it buffer an unbounded
// amount of data (DoS). It is 300 MiB of payload plus 1 MiB of slack
// (301 MiB total). It is installed as the response parser's body_limit; the
// size of the response headers is governed separately by the parser's own
// header limit, not by this constant.
constexpr std::size_t kMaxResponseSize = std::size_t{301} * 1024 * 1024;

// Builds a client whose asynchronous work runs on a strand created from
// `io_context`.
//
// The CA bundle at options.ca_cert_path is loaded into the TLS context here.
// A load failure does not abort construction: it is logged and stored in
// ca_load_ec_, which SendRequest later reports to any TLS request.
HttpClient::HttpClient(boost::asio::io_context& io_context,
                       const Options& options)
    : executor_(boost::asio::make_strand(io_context)),
      resolver_(executor_),
      ssl_ctx_(boost::asio::ssl::context::sslv23_client),
      options_(options) {
  boost::system::error_code load_error;
  // NOLINTNEXTLINE: error code has been used.
  ssl_ctx_.load_verify_file(options.ca_cert_path, load_error);
  if (load_error) {
    LOG(ERROR) << "Failed to load CA cert from " << options.ca_cert_path << ": "
               << load_error.message();
  }
  ca_load_ec_ = load_error;

  // Servers must present a certificate that validates against the CA bundle.
  ssl_ctx_.set_verify_mode(boost::asio::ssl::verify_peer);
  InitStream();
}

// Returns true while a request is in flight. That is the span from when
// SendRequest's queued work accepts a request until just after the request's
// completion callback has returned.
// NOTE(review): is_active_ is updated by work running on the executor_ strand
// but is read here without posting to it. Unless the header declares it
// std::atomic (not visible here), calling this from another thread is a data
// race — confirm.
bool HttpClient::IsActive() {
  const bool request_in_flight = is_active_;
  return request_in_flight;
}

void HttpClient::SendRequest(
    std::string_view host, std::string_view port, std::string_view target,
    bool use_tls,
    absl::AnyInvocable<
        void(boost::beast::error_code,
             boost::beast::http::response<boost::beast::http::string_body>)>
        callback) {
  boost::asio::post(executor_, [this, host = std::string(host),
                                port = std::string(port),
                                target = std::string(target), use_tls,
                                callback = std::move(callback)]() mutable {
    if (is_active_) {
      callback(boost::asio::error::already_started, {});
      return;
    }
    buffer_.clear();
    if (use_tls && ca_load_ec_) {
      callback(ca_load_ec_, {});
      return;
    }
    is_active_ = true;

    job_ = std::make_shared<ScanJob>(ScanJob{
        .target = target,
        .callback = std::move(callback),
    });
    job_->req.version(11);
    job_->req.method(boost::beast::http::verb::get);
    job_->req.target(target);
    job_->req.set(boost::beast::http::field::host, host);

    // If TLS usage has changed, we can't reuse connection.
    if (connection_established_ && use_tls_ != use_tls) {
      connection_established_ = false;
    }
    use_tls_ = use_tls;

    // If the connection is already established and the host and port are the
    // same, we can reuse the connection and send the request immediately.
    if (connection_established_ && host_ == host && port_ == port &&
        stream_->next_layer().socket().is_open()) {
      LOG(INFO) << "Connection already established";
      stream_->next_layer().expires_after(options_.read_write_timeout);
      if (use_tls_) {
        boost::beast::http::async_write(
            *stream_, job_->req,
            [this](boost::beast::error_code ec, std::size_t bytes_transferred) {
              OnWrite(ec, bytes_transferred);
            });
      } else {
        boost::beast::http::async_write(
            stream_->next_layer(), job_->req,
            [this](boost::beast::error_code ec, std::size_t bytes_transferred) {
              OnWrite(ec, bytes_transferred);
            });
      }
      return;
    }

    InitStream();

    connection_established_ = false;
    host_ = host;
    port_ = port;
    use_tls_ = use_tls;
    resolver_.async_resolve(
        host_, port_,
        [this](boost::beast::error_code ec,
               boost::asio::ip::tcp::resolver::results_type results) {
          OnResolve(ec, results);
        });
  });
}

// Completion handler for the DNS resolution started in SendRequest.
// On success, arms the connect timeout and connects the TCP stream to the
// resolved endpoints. On failure, reports the error to the job's callback and
// marks the client idle.
void HttpClient::OnResolve(
    boost::beast::error_code ec,
    boost::asio::ip::tcp::resolver::results_type results) {
  if (!ec) {
    auto& tcp = stream_->next_layer();
    tcp.expires_after(options_.connect_timeout);
    tcp.async_connect(
        results,
        [this](boost::beast::error_code connect_ec,
               boost::asio::ip::tcp::resolver::results_type::endpoint_type
                   endpoint) { OnConnect(connect_ec, endpoint); });
    return;
  }

  LOG(ERROR) << "Resolve failed for " << job_->target << ": " << ec.message();
  job_->callback(std::move(ec), std::move(job_->res));
  is_active_ = false;
}

// Completion handler for the TCP connect started in OnResolve.
//
// On failure, the error goes to the job's callback and the client becomes
// idle. On success, the connection is marked established. For TLS, the
// handshake is then started, after setting the SNI host name when
// verify_hostname is on. For plain HTTP, the request is written right away.
void HttpClient::OnConnect(
    boost::beast::error_code ec,
    boost::asio::ip::tcp::resolver::results_type::endpoint_type endpoint) {
  if (ec) {
    LOG(ERROR) << "Connect failed for " << job_->target << ": " << ec.message();
    job_->callback(std::move(ec), std::move(job_->res));
    is_active_ = false;
    return;
  }
  connection_established_ = true;

  if (use_tls_) {
    stream_->next_layer().expires_after(options_.handshake_timeout);
    if (options_.verify_hostname &&
        !SSL_set_tlsext_host_name(stream_->native_handle(), host_.c_str())) {
      boost::beast::error_code sni_ec{static_cast<int>(::ERR_get_error()),
                                      boost::asio::error::get_ssl_category()};
      LOG(ERROR) << "Setting SNI host name failed for " << job_->target
                 << ": " << sni_ec.message();
      // TLS was never set up on this socket. Clear the flag so the next
      // SendRequest to the same host/port does not try to reuse it.
      connection_established_ = false;
      job_->callback(std::move(sni_ec), std::move(job_->res));
      is_active_ = false;
      return;
    }
    stream_->async_handshake(
        boost::asio::ssl::stream_base::client,
        [this](boost::beast::error_code handshake_ec) {
          OnHandshake(handshake_ec);
        });
  } else {
    stream_->next_layer().expires_after(options_.read_write_timeout);
    boost::beast::http::async_write(
        stream_->next_layer(), job_->req,
        [this](boost::beast::error_code write_ec,
               std::size_t bytes_transferred) {
          OnWrite(write_ec, bytes_transferred);
        });
  }
}

// Completion handler for the TLS handshake started in OnConnect.
// On success, arms the read/write timeout and writes the HTTP request over the
// encrypted stream. On failure, reports the error, marks the connection as not
// reusable, and marks the client idle.
void HttpClient::OnHandshake(boost::beast::error_code ec) {
  if (ec) {
    LOG(ERROR) << "Handshake failed for " << job_->target << ": "
               << ec.message();
    // OnConnect marks the connection established as soon as TCP connects.
    // A stream whose handshake failed (e.g. certificate verification) cannot
    // carry requests, and its socket may still be open. Clear the flag so the
    // next SendRequest to the same host/port reconnects instead of reusing it.
    connection_established_ = false;
    job_->callback(std::move(ec), std::move(job_->res));
    is_active_ = false;
    return;
  }

  stream_->next_layer().expires_after(options_.read_write_timeout);
  boost::beast::http::async_write(
      *stream_, job_->req,
      [this](boost::beast::error_code ec, std::size_t bytes_transferred) {
        OnWrite(ec, bytes_transferred);
      });
}

// Completion handler for writing the HTTP request.
// On failure, reports the error, marks the client idle and drops the
// connection for reuse. On success, reads the response into a fresh parser
// whose body is capped at kMaxResponseSize. The parsed message is stored in
// job_->res before OnRead runs.
void HttpClient::OnWrite(boost::beast::error_code ec,
                         std::size_t bytes_transferred) {
  if (ec) {
    LOG(ERROR) << "Write failed for " << job_->target << ": " << ec.message();
    job_->callback(std::move(ec), std::move(job_->res));
    is_active_ = false;
    // A connection that failed a write must not be reused.
    connection_established_ = false;
    return;
  }

  using ResponseParser =
      boost::beast::http::response_parser<boost::beast::http::string_body>;
  // The parser is shared with the completion handler so it outlives the read.
  auto parser = std::make_shared<ResponseParser>();
  parser->body_limit(kMaxResponseSize);

  auto on_read = [this, parser](boost::beast::error_code read_ec,
                                std::size_t read_bytes) {
    job_->res = parser->release();
    OnRead(read_ec, read_bytes);
  };

  stream_->next_layer().expires_after(options_.read_write_timeout);
  if (use_tls_) {
    boost::beast::http::async_read(*stream_, buffer_, *parser,
                                   std::move(on_read));
  } else {
    boost::beast::http::async_read(stream_->next_layer(), buffer_, *parser,
                                   std::move(on_read));
  }
}

// Completion handler for the response read started in OnWrite. Passes the
// result, together with whatever response OnWrite stored in job_->res, to the
// job's callback and marks the client idle.
//
// The connection stays eligible for reuse only if the read succeeded and the
// response permits keep-alive.
void HttpClient::OnRead(boost::beast::error_code ec,
                        std::size_t bytes_transferred) {
  if (ec) {
    LOG(ERROR) << "Read failed for " << job_->target << ": " << ec.message();
    // Reset the connection established flag to false if read fails.
    connection_established_ = false;
  } else if (!job_->res.keep_alive()) {
    // The server announced it will close the connection after this response
    // ("Connection: close", or HTTP/1.0 without keep-alive). Reusing it would
    // make the next request fail, so force a reconnect instead.
    connection_established_ = false;
  }

  job_->callback(std::move(ec), std::move(job_->res));
  is_active_ = false;
}

void HttpClient::InitStream() {
  stream_ =
      std::make_unique<boost::beast::ssl_stream<boost::beast::tcp_stream>>(
          executor_, ssl_ctx_);
  stream_->set_verify_mode(boost::asio::ssl::verify_peer);
  stream_->set_verify_callback([this](bool preverified,
                                      boost::asio::ssl::verify_context& ctx) {
    if (!preverified) {
      int err = X509_STORE_CTX_get_error(ctx.native_handle());
      if (err == X509_V_ERR_INVALID_PURPOSE && !options_.verify_purpose) {
        return true;
      }
      LOG(ERROR) << "Verification failed: "
                 << X509_verify_cert_error_string(err) << " (err code: " << err
                 << ")";
      return false;
    }
    if (options_.verify_hostname) {
      return boost::asio::ssl::host_name_verification(host_)(preverified, ctx);
    }
    return true;
  });
}

void HttpClient::DownloadFile(
    std::string_view host, std::string_view port, std::string_view target,
    bool use_tls, std::string_view destination_path,
    absl::AnyInvocable<
        void(boost::beast::error_code,
             boost::beast::http::response<boost::beast::http::empty_body>)>
        callback,
    absl::AnyInvocable<void(std::size_t bytes_downloaded,
                            std::size_t bytes_written,
                            std::size_t content_length)>
        progress_callback) {
  return SendRequest(
      host, port, target, use_tls,
      [destination_path = std::string(destination_path),
       callback = std::move(callback),
       progress_callback = std::move(progress_callback)](
          boost::beast::error_code ec,
          boost::beast::http::response<boost::beast::http::string_body>
              res) mutable {
        boost::beast::http::response<boost::beast::http::empty_body> empty_res;
        if (!ec) {
          empty_res.result(res.result());
          empty_res.version(res.version());
          std::ofstream f;
          f.rdbuf()->pubsetbuf(nullptr, 0);
          f.open(destination_path, std::ios::binary);
          if (f.is_open()) {
            const char* data = res.body().data();
            std::size_t total_size = res.body().size();
            constexpr std::size_t kChunkSize1Mb = 1 * 1024 * 1024;
            bool write_success = true;
            std::size_t bytes_written = 0;
            for (std::size_t offset = 0; offset < total_size;
                 offset += kChunkSize1Mb) {
              std::size_t write_size =
                  std::min(kChunkSize1Mb, total_size - offset);
              f.write(data + offset, static_cast<std::streamsize>(write_size));
              if (!f.good()) {
                ec = boost::system::error_code(
                    errno, boost::system::generic_category());
                write_success = false;
                break;
              }
              bytes_written += write_size;
              if (progress_callback) {
                progress_callback(total_size, bytes_written, total_size);
              }
            }
            if (write_success) {
              f.flush();
              if (!f.good()) {
                ec = boost::system::error_code(
                    errno, boost::system::generic_category());
              }
            }
            f.close();
          } else {
            ec = boost::system::error_code(errno,
                                           boost::system::generic_category());
          }
        }
        if (callback) {
          callback(ec, std::move(empty_res));
        }
      });
}

}  // namespace milotic_tlbmc
