#include "tlbmc/expression/expression.h"

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <memory>
#include <string>
#include <system_error>  // NOLINT: system_error is commonly used in BMC
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/charconv.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "g3/macros.h"

namespace milotic_tlbmc {
namespace expression::internal {

// A leaf node holding a literal number. Its value never depends on the
// variable bindings passed to Evaluate().
class Constant final : public Expression {
 public:
  explicit Constant(double literal) : literal_(literal) {}

  // Always succeeds with the stored literal; `variable_maps` is ignored.
  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& /*variable_maps*/)
      const override {
    return absl::StatusOr<double>(literal_);
  }

  // A literal references no variables, so nothing is added to the set.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& /*variable_names*/) const override {}

 private:
  const double literal_;
};

// A leaf node whose value is looked up by name in the bindings supplied at
// evaluation time.
class Variable final : public Expression {
 public:
  explicit Variable(absl::string_view name) : var_name_(name) {}

  // Returns the value bound to this variable in `variable_maps`, or a
  // NotFoundError naming the variable when no binding exists.
  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& variable_maps)
      const override {
    const auto found = variable_maps.find(var_name_);
    if (found == variable_maps.end()) {
      return absl::NotFoundError(
          absl::StrCat("Variable not found: ", var_name_));
    }
    return found->second;
  }

  // Adds this variable's own name to `variable_names`.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& variable_names) const override {
    variable_names.emplace(var_name_);
  }

 private:
  const std::string var_name_;
};

// An interior node that combines the values of two sub-expressions with one
// of the four basic arithmetic operators.
class BinaryOperation final : public Expression {
 public:
  enum class Operator { kAdd, kSubtract, kMultiply, kDivide };

  BinaryOperation(Operator op, std::unique_ptr<Expression> lhs,
                  std::unique_ptr<Expression> rhs)
      : op_(op), left_(std::move(lhs)), right_(std::move(rhs)) {}

  // Evaluates the left operand, then the right operand (returning early on
  // the first failure), and applies the operator to the two results. A
  // divisor that compares equal to zero yields InvalidArgumentError.
  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& variable_maps)
      const override {
    ECCLESIA_ASSIGN_OR_RETURN(double left_value,
                              left_->Evaluate(variable_maps));
    ECCLESIA_ASSIGN_OR_RETURN(double right_value,
                              right_->Evaluate(variable_maps));

    // Rejected up front so that every switch arm below is pure arithmetic.
    if (op_ == Operator::kDivide && right_value == 0) {
      return absl::InvalidArgumentError("Division by zero.");
    }
    switch (op_) {
      case Operator::kAdd:
        return left_value + right_value;
      case Operator::kSubtract:
        return left_value - right_value;
      case Operator::kMultiply:
        return left_value * right_value;
      case Operator::kDivide:
        return left_value / right_value;
    }
    // Only reachable if `op_` holds a value outside the declared enumerators.
    return absl::InternalError("Unreachable.");
  }

  // Collects the variables referenced by both operands into `variable_names`.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& variable_names) const override {
    left_->GetRequiredVariables(variable_names);
    right_->GetRequiredVariables(variable_names);
  }

 private:
  const Operator op_;
  const std::unique_ptr<Expression> left_;
  const std::unique_ptr<Expression> right_;
};

// Represents the built-in `Maximum(...)` function: the largest value among one
// or more operand expressions.
class Max final : public Expression {
 public:
  explicit Max(std::vector<std::unique_ptr<Expression>> operands)
      : operands_(std::move(operands)) {}

  // Evaluates each operand exactly once, in order, and returns the largest
  // result. The first operand that fails to evaluate aborts evaluation and its
  // status is returned.
  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& variable_maps)
      const override {
    // The parser never builds an empty Max, but this constructor does not
    // enforce that, so guard here instead of indexing out of bounds.
    if (operands_.empty()) {
      return absl::InvalidArgumentError(
          "Maximum requires at least one argument.");
    }
    ECCLESIA_ASSIGN_OR_RETURN(double max_val,
                              operands_.front()->Evaluate(variable_maps));
    // Start at the second operand: the first one was just evaluated above and
    // must not be evaluated a second time.
    for (size_t i = 1; i < operands_.size(); ++i) {
      ECCLESIA_ASSIGN_OR_RETURN(double val,
                                operands_[i]->Evaluate(variable_maps));
      max_val = std::max(max_val, val);
    }
    return max_val;
  }

  // Collects the variables referenced by every operand into `variable_names`.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& variable_names) const override {
    for (const auto& operand : operands_) {
      operand->GetRequiredVariables(variable_names);
    }
  }

 private:
  const std::vector<std::unique_ptr<Expression>> operands_;
};

// Returns `s` with all leading whitespace (as classified by std::isspace)
// removed. The result is a view into the same buffer as `s`.
absl::string_view ConsumeWhitespace(absl::string_view s) {
  // std::isspace has undefined behavior for negative values other than EOF,
  // which a plain `char` can hold (e.g. non-ASCII bytes where `char` is
  // signed), so widen through `unsigned char` first.
  while (!s.empty() &&
         std::isspace(static_cast<unsigned char>(s.front())) != 0) {
    s.remove_prefix(1);
  }
  return s;
}

// Returns the first non-whitespace character of `s` without consuming
// anything, or '\0' when `s` is empty or contains only whitespace.
char Peek(absl::string_view s) {
  const absl::string_view rest = ConsumeWhitespace(s);
  if (rest.empty()) {
    return '\0';
  }
  return rest.front();
}

// Skips leading whitespace in `s` and then, if the next character equals `c`,
// removes it as well. Returns true iff `c` was removed. The leading whitespace
// is stripped from `s` even when `c` does not follow it.
bool Consume(absl::string_view& s, char c) {
  s = ConsumeWhitespace(s);
  const bool matched = !s.empty() && s.front() == c;
  if (matched) {
    s.remove_prefix(1);
  }
  return matched;
}

// Parses a primary expression: a numeric literal (which may carry a leading
// '-'), a variable name, a call to the built-in `Maximum(...)` function, or a
// parenthesized sub-expression. Leading whitespace is skipped, and on success
// `s` is advanced past the consumed text.
// NOTE(review): nesting via parentheses / Maximum recurses without a depth
// limit; very deeply nested input grows the call stack proportionally.
absl::StatusOr<std::unique_ptr<Expression>> ParsePrimary(absl::string_view& s) {
  s = ConsumeWhitespace(s);
  if (s.empty()) {
    return absl::InvalidArgumentError("Unexpected end of expression.");
  }

  // The <cctype> classifiers have undefined behavior for negative values other
  // than EOF, which a plain `char` can hold (e.g. non-ASCII bytes where `char`
  // is signed), so always widen through `unsigned char` before classifying.
  const auto to_uchar = [](char c) { return static_cast<unsigned char>(c); };
  const char first = s.front();

  // Parse a number.
  if (std::isdigit(to_uchar(first)) != 0 || first == '.' || first == '-') {
    double value = 0;
    auto [ptr, ec] = absl::from_chars(s.data(), s.data() + s.size(), value);
    if (ec == std::errc()) {
      s.remove_prefix(ptr - s.data());
      return std::make_unique<Constant>(value);
    }
    // Describe the failure in words (instead of the raw numeric errc value)
    // and show where it happened.
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to parse number: ",
                     std::make_error_code(ec).message(), " at '", s, "'"));
  }

  // Parse a variable name or the name of a built-in function.
  if (std::isalpha(to_uchar(first)) != 0) {
    size_t len = 0;
    while (len < s.size() &&
           (std::isalnum(to_uchar(s[len])) != 0 || s[len] == '_')) {
      ++len;
    }
    const absl::string_view name = s.substr(0, len);
    s.remove_prefix(len);
    if (name == "Maximum") {
      if (!Consume(s, '(')) {
        return absl::InvalidArgumentError("Expected '(' after Maximum.");
      }
      std::vector<std::unique_ptr<Expression>> args;
      if (Peek(s) != ')') {
        // Comma-separated argument list.
        while (true) {
          ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> arg,
                                    ParseExpression(s));
          args.push_back(std::move(arg));
          if (!Consume(s, ',')) {
            break;
          }
        }
      }
      if (!Consume(s, ')')) {
        return absl::InvalidArgumentError(
            "Expected ')' after Maximum arguments.");
      }
      if (args.empty()) {
        return absl::InvalidArgumentError(
            "Maximum requires at least one argument.");
      }
      return std::make_unique<Max>(std::move(args));
    }
    return std::make_unique<Variable>(name);
  }

  // Parse a parenthesized sub-expression.
  if (Consume(s, '(')) {
    ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> expr,
                              ParseExpression(s));
    if (!Consume(s, ')')) {
      return absl::InvalidArgumentError("Mismatched parentheses.");
    }
    return expr;
  }

  return absl::InvalidArgumentError(
      absl::StrCat("Unexpected token: ", s.substr(0, 1)));
}

// Parses a factor: one or more primaries joined by left-associative '*' and
// '/' operators, so "a / b * c" is built as ((a / b) * c). `s` is advanced
// past every consumed token.
absl::StatusOr<std::unique_ptr<Expression>> ParseFactor(absl::string_view& s) {
  ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> lhs, ParsePrimary(s));

  for (char next = Peek(s); next == '*' || next == '/'; next = Peek(s)) {
    Consume(s, next);
    ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> rhs,
                              ParsePrimary(s));
    lhs = std::make_unique<BinaryOperation>(
        next == '*' ? BinaryOperation::Operator::kMultiply
                    : BinaryOperation::Operator::kDivide,
        std::move(lhs), std::move(rhs));
  }
  return lhs;
}

// Parses a term: one or more factors joined by left-associative '+' and '-'
// operators, so "a - b + c" is built as ((a - b) + c). Because the operands
// are factors, '*' and '/' bind tighter than '+' and '-'. `s` is advanced past
// every consumed token.
absl::StatusOr<std::unique_ptr<Expression>> ParseTerm(absl::string_view& s) {
  ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> lhs, ParseFactor(s));

  for (;;) {
    const char sign = Peek(s);
    if (sign != '+' && sign != '-') {
      break;
    }
    Consume(s, sign);
    ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> rhs,
                              ParseFactor(s));
    const BinaryOperation::Operator op =
        sign == '+' ? BinaryOperation::Operator::kAdd
                    : BinaryOperation::Operator::kSubtract;
    lhs = std::make_unique<BinaryOperation>(op, std::move(lhs),
                                            std::move(rhs));
  }
  return lhs;
}

// Entry point of the recursive-descent grammar implemented in this file:
//   expression := term
//   term       := factor (('+' | '-') factor)*
//   factor     := primary (('*' | '/') primary)*
//   primary    := number | name | '(' expression ')'
//               | 'Maximum' '(' expression (',' expression)* ')'
// On success `s` is positioned just past the consumed text; rejecting any
// trailing input is left to the caller.
absl::StatusOr<std::unique_ptr<Expression>> ParseExpression(
    absl::string_view& s) {
  absl::StatusOr<std::unique_ptr<Expression>> parsed = ParseTerm(s);
  return parsed;
}

}  // namespace expression::internal

// Parses the whole of `expr` as a single arithmetic expression. Fails if
// parsing fails or if anything other than whitespace follows the expression.
absl::StatusOr<std::unique_ptr<Expression>> Parse(
    absl::string_view expr) {
  // ParseExpression advances the view it is given, so use a separate cursor.
  absl::string_view cursor = expr;
  ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> root,
                            expression::internal::ParseExpression(cursor));

  // Only whitespace may follow a complete expression.
  cursor = expression::internal::ConsumeWhitespace(cursor);
  if (cursor.empty()) {
    return root;
  }
  return absl::InvalidArgumentError(
      absl::StrCat("Unexpected trailing characters: ", cursor));
}

}  // namespace milotic_tlbmc
