#include "pldm_interface.hpp"

#include <signal.h>
#include <sys/socket.h>

#include <stdplus/print.hpp>

#include <cstring>
#include <vector>

// MCTP message-type value placed in sockaddr_mctp::smctp_type for PLDM traffic.
constexpr uint8_t MCTP_MSG_TYPE_PLDM{0x01};

/**
 * @brief Send an encoded PLDM request to an MCTP endpoint on a given network.
 *
 * @param eid        Destination MCTP endpoint ID.
 * @param networkId  MCTP network ID written into the destination address.
 * @param mctpFd     MCTP socket descriptor.
 * @param pldmReqMsg Encoded PLDM request bytes.
 * @param reqMsgLen  Number of bytes in pldmReqMsg.
 * @return PLDM_REQUESTER_SUCCESS when sendto() accepted all reqMsgLen bytes,
 *         PLDM_REQUESTER_SEND_FAIL on a socket error or a short send.
 */
pldm_requester_rc_t PldmInterface::pldmSendAtNetwork(mctp_eid_t eid,
                                                     int networkId, int mctpFd,
                                                     const uint8_t* pldmReqMsg,
                                                     size_t reqMsgLen)
{
    struct sockaddr_mctp addr;
    memset(&addr, 0, sizeof(addr));
    addr.smctp_family = AF_MCTP;
    addr.smctp_network = networkId;
    addr.smctp_addr.s_addr = eid;
    addr.smctp_type = MCTP_MSG_TYPE_PLDM;
    addr.smctp_tag = MCTP_TAG_OWNER;

    // sendto() returns ssize_t; keep the full width instead of narrowing to
    // int.
    const ssize_t rc =
        sendto(mctpFd, pldmReqMsg, reqMsgLen, 0,
               reinterpret_cast<const struct sockaddr*>(&addr), sizeof(addr));

    if (rc == -1)
    {
        perror("Failed on sending MCTP packet");
        stdplus::print(stderr,
                       "Failed to send packet over MCTP socket with rc: "
                       "{}\n",
                       rc);
        return PLDM_REQUESTER_SEND_FAIL;
    }
    if (static_cast<size_t>(rc) != reqMsgLen)
    {
        // Presumably unreachable for a datagram MCTP socket, but a truncated
        // request must never be reported as successfully sent.
        stdplus::print(stderr,
                       "Short send over MCTP socket: sent {} of {} bytes\n",
                       rc, reqMsgLen);
        return PLDM_REQUESTER_SEND_FAIL;
    }
    return PLDM_REQUESTER_SUCCESS;
}

/**
 * @brief Read one pending message from an MCTP socket into *pldmRespMsg.
 *
 * The datagram length is first peeked (MSG_PEEK | MSG_TRUNC) without
 * consuming it, so that socket errors, empty reads and datagrams shorter than
 * a PLDM header can be rejected before the real read.
 *
 * @param eid         Endpoint ID pre-filled into the address struct. recvfrom()
 *                    overwrites that struct with the sender's address, and
 *                    the sender is not compared against this value.
 * @param mctpFd      MCTP socket descriptor.
 * @param pldmRespMsg Pointer to the destination buffer pointer; the buffer is
 *                    written in place, not allocated here.
 * @param respMsgLen  Set to the number of bytes read on success.
 * @param networkId   Network ID pre-filled into the address struct (same
 *                    caveat as eid).
 * @return PLDM_REQUESTER_SUCCESS; PLDM_REQUESTER_RECV_FAIL if recvfrom()
 *         fails or reports zero bytes; PLDM_REQUESTER_INVALID_RECV_LEN if the
 *         datagram is shorter than a PLDM header or the read length differs
 *         from the peeked length.
 */
pldm_requester_rc_t PldmInterface::mctpRecvAtNetwork(mctp_eid_t eid, int mctpFd,
                                                     uint8_t** pldmRespMsg,
                                                     size_t* respMsgLen,
                                                     int networkId)
{
    constexpr ssize_t minLen = sizeof(struct pldm_msg_hdr);
    struct sockaddr_mctp addr;
    memset(&addr, 0, sizeof(addr));
    addr.smctp_family = AF_MCTP;
    addr.smctp_addr.s_addr = eid;
    addr.smctp_type = MCTP_MSG_TYPE_PLDM;
    addr.smctp_tag = MCTP_TAG_OWNER;
    addr.smctp_network = networkId;
    socklen_t addrLen = sizeof(addr);

    // Zero-length peek with MSG_TRUNC: nothing is copied, the datagram stays
    // queued, and the return value is its full length.
    const ssize_t length =
        recvfrom(mctpFd, nullptr, 0, MSG_PEEK | MSG_TRUNC,
                 reinterpret_cast<struct sockaddr*>(&addr), &addrLen);
    if (length < 0)
    {
        // Must be caught here: the short-length branch below would otherwise
        // size a std::vector from a negative value.
        stdplus::print(stderr,
                       "Failure at MCTP socket: recvfrom failed, errno: {}\n",
                       strerror(errno));
        return PLDM_REQUESTER_RECV_FAIL;
    }
    if (length == 0)
    {
        stdplus::print(
            stderr, "Failure at MCTP socket: No length received, errno: {}\n",
            strerror(errno));
        return PLDM_REQUESTER_RECV_FAIL;
    }
    if (length < minLen)
    {
        /* read and discard, so the undersized datagram is not peeked again */
        std::vector<uint8_t> buf(static_cast<size_t>(length));
        recv(mctpFd, buf.data(), buf.size(), 0);

        stdplus::print(stderr,
                       "Failure at MCTP socket: Length less than min "
                       "bytes, length: {} and minimum length required: {}\n",
                       length, minLen);
        return PLDM_REQUESTER_INVALID_RECV_LEN;
    }

    // NOTE(review): *pldmRespMsg is written to directly; it must already point
    // at a buffer of at least `length` bytes. Nothing here allocates it or
    // checks its capacity -- confirm this contract with the callers.
    addrLen = sizeof(addr);
    const ssize_t bytes =
        recvfrom(mctpFd, *pldmRespMsg, static_cast<size_t>(length), MSG_TRUNC,
                 reinterpret_cast<struct sockaddr*>(&addr), &addrLen);
    if (length != bytes)
    {
        stdplus::print(stderr,
                       "Failure at MCTP socket: Length is not equal "
                       "to the bytes read, with length: {} and "
                       "bytes: {}\n",
                       length, bytes);
        return PLDM_REQUESTER_INVALID_RECV_LEN;
    }
    *respMsgLen = static_cast<size_t>(length);
    return PLDM_REQUESTER_SUCCESS;
}

/**
 * @brief Receive a message and confirm it is a PLDM response of usable size.
 *
 * Delegates the socket read to mctpRecvAtNetwork(), then validates the PLDM
 * header of the received bytes.
 *
 * @return The mctpRecvAtNetwork() status when the read fails;
 *         PLDM_REQUESTER_NOT_RESP_MSG if the header is not a response;
 *         PLDM_REQUESTER_RESP_MSG_TOO_SMALL if fewer than header + 1 bytes
 *         were received; PLDM_REQUESTER_SUCCESS otherwise.
 */
pldm_requester_rc_t PldmInterface::recvAtNetwork(mctp_eid_t eid, int mctpFd,
                                                 uint8_t** pldmRespMsg,
                                                 size_t* respMsgLen,
                                                 int networkId)
{
    const pldm_requester_rc_t transportStatus =
        mctpRecvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
    if (transportStatus != PLDM_REQUESTER_SUCCESS)
    {
        stdplus::print(
            stderr,
            "Failed to receive message at MCTP socket with eid: {} and netId: "
            "{}\n",
            eid, networkId);
        return transportStatus;
    }

    // mctpRecvAtNetwork() only succeeds after reading at least a full PLDM
    // header, so the header fields are safe to inspect here.
    const auto* header =
        reinterpret_cast<const struct pldm_msg_hdr*>(*pldmRespMsg);
    const bool isResponse = (header->request == PLDM_RESPONSE);
    if (!isResponse)
    {
        stdplus::print(
            stderr,
            "Failure at MCTP socket: Header is not a response for eid: {} and "
            "netId: {}\n",
            eid, networkId);
        return PLDM_REQUESTER_NOT_RESP_MSG;
    }

    // Header plus at least one payload byte (presumably the completion code).
    constexpr size_t minResponseLen =
        sizeof(struct pldm_msg_hdr) + sizeof(uint8_t);
    if (*respMsgLen < minResponseLen)
    {
        stdplus::print(
            stderr,
            "Failure at MCTP socket: Received message is smaller than expected "
            "for eid: {} and netId: {}\n",
            eid, networkId);
        return PLDM_REQUESTER_RESP_MSG_TOO_SMALL;
    }
    return PLDM_REQUESTER_SUCCESS;
}

/**
 * @brief Receive a PLDM response and check it answers the expected request.
 *
 * @param eid         Endpoint ID passed through to recvAtNetwork().
 * @param networkId   Network ID passed through to recvAtNetwork().
 * @param mctpFd      MCTP socket descriptor.
 * @param instanceId  Instance ID the response header must carry.
 * @param pldmRespMsg Receives the response bytes (see recvAtNetwork()).
 * @param respMsgLen  Set to the response length (see recvAtNetwork()).
 * @return The recvAtNetwork() status on failure,
 *         PLDM_REQUESTER_INSTANCE_ID_MISMATCH if the header's instance ID
 *         differs from instanceId, PLDM_REQUESTER_SUCCESS otherwise.
 */
pldm_requester_rc_t PldmInterface::pldmRecvAtNetwork(mctp_eid_t eid,
                                                     int networkId, int mctpFd,
                                                     uint8_t instanceId,
                                                     uint8_t** pldmRespMsg,
                                                     size_t* respMsgLen)
{
    const pldm_requester_rc_t rc =
        recvAtNetwork(eid, mctpFd, pldmRespMsg, respMsgLen, networkId);
    if (rc != PLDM_REQUESTER_SUCCESS)
    {
        return rc;
    }
    const auto* hdr =
        reinterpret_cast<const struct pldm_msg_hdr*>(*pldmRespMsg);
    if (hdr->instance_id != instanceId)
    {
        // Widen both IDs to unsigned so they print as plain integers in the
        // same way; this also yields a prvalue in case instance_id is a
        // bit-field, which could not bind to the formatter's references.
        stdplus::print(stderr,
                       "Failure at MCTP socket receive: Instance id "
                       "mismatch: Found: {}, Expected: {}\n",
                       static_cast<unsigned>(hdr->instance_id),
                       static_cast<unsigned>(instanceId));
        return PLDM_REQUESTER_INSTANCE_ID_MISMATCH;
    }
    return PLDM_REQUESTER_SUCCESS;
}
