blob: 43361b46c4189d8e0cb1c5be84fe8f3741edb465 [file] [log] [blame]
#include "tlbmc/expression/expression.h"
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <memory>
#include <string>
#include <system_error> // NOLINT: system_error is commonly used in BMC
#include <utility>
#include <vector>
#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/charconv.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/string_view.h"
#include "g3/macros.h"
namespace milotic_tlbmc {
namespace expression::internal {
// An expression node holding a fixed numeric literal.
class Constant final : public Expression {
 public:
  explicit Constant(double value) : literal_(value) {}

  // Always yields the stored literal; the variable bindings are not consulted
  // and evaluation cannot fail.
  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& /*variable_maps*/)
      const override {
    return literal_;
  }

  // A literal references no variables, so the output set is left unchanged.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& /*variable_names*/) const override {}

 private:
  // The value returned by every call to Evaluate().
  const double literal_;
};
// Represents a variable.
class Variable final : public Expression {
public:
explicit Variable(absl::string_view name) : name_(name) {}
absl::StatusOr<double> Evaluate(
const absl::flat_hash_map<std::string, double>& variable_maps)
const override {
if (auto it = variable_maps.find(name_); it != variable_maps.end()) {
return it->second;
}
return absl::NotFoundError(absl::StrCat("Variable not found: ", name_));
}
void GetRequiredVariables(
absl::flat_hash_set<std::string>& variable_names) const override {
variable_names.insert(name_);
}
private:
const std::string name_;
};
// An expression node that combines two sub-expressions with one of the four
// arithmetic operators. Both operands are evaluated (left first) before the
// operator is applied; the first failing operand's error is returned as-is.
class BinaryOperation final : public Expression {
 public:
  enum class Operator { kAdd, kSubtract, kMultiply, kDivide };

  BinaryOperation(Operator op, std::unique_ptr<Expression> lhs,
                  std::unique_ptr<Expression> rhs)
      : op_(op), left_(std::move(lhs)), right_(std::move(rhs)) {}

  // Evaluates both operands, then applies the operator. Division by an exact
  // zero is reported as InvalidArgumentError rather than computed.
  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& variable_maps)
      const override {
    ECCLESIA_ASSIGN_OR_RETURN(double left_value,
                              left_->Evaluate(variable_maps));
    ECCLESIA_ASSIGN_OR_RETURN(double right_value,
                              right_->Evaluate(variable_maps));
    // Reject x / 0 up front instead of letting the floating-point division
    // produce an infinity or NaN.
    if (op_ == Operator::kDivide && right_value == 0) {
      return absl::InvalidArgumentError("Division by zero.");
    }
    switch (op_) {
      case Operator::kAdd:
        return left_value + right_value;
      case Operator::kSubtract:
        return left_value - right_value;
      case Operator::kMultiply:
        return left_value * right_value;
      case Operator::kDivide:
        return left_value / right_value;
    }
    // Only reachable if op_ somehow holds a value outside the enumerators.
    return absl::InternalError("Unreachable.");
  }

  // Collects the variables referenced by both operands.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& variable_names) const override {
    left_->GetRequiredVariables(variable_names);
    right_->GetRequiredVariables(variable_names);
  }

 private:
  const Operator op_;
  const std::unique_ptr<Expression> left_;
  const std::unique_ptr<Expression> right_;
};
// Represents the Maximum function: evaluates every operand in order and
// returns the largest value. Evaluation stops at the first operand whose
// evaluation fails, and that error is returned.
class Max final : public Expression {
 public:
  explicit Max(std::vector<std::unique_ptr<Expression>> operands)
      : operands_(std::move(operands)) {}

  absl::StatusOr<double> Evaluate(
      const absl::flat_hash_map<std::string, double>& variable_maps)
      const override {
    // The parser never builds an empty Max, but guard against direct
    // construction instead of reading operands_[0] out of bounds.
    if (operands_.empty()) {
      return absl::InvalidArgumentError(
          "Maximum requires at least one argument.");
    }
    ECCLESIA_ASSIGN_OR_RETURN(double max_val,
                              operands_.front()->Evaluate(variable_maps));
    // Start at index 1: operand 0 was already evaluated to seed max_val, so
    // re-evaluating it would only duplicate work.
    for (size_t i = 1; i < operands_.size(); ++i) {
      ECCLESIA_ASSIGN_OR_RETURN(double val,
                                operands_[i]->Evaluate(variable_maps));
      max_val = std::max(max_val, val);
    }
    return max_val;
  }

  // Collects the variables referenced by every operand.
  void GetRequiredVariables(
      absl::flat_hash_set<std::string>& variable_names) const override {
    for (const auto& operand : operands_) {
      operand->GetRequiredVariables(variable_names);
    }
  }

 private:
  const std::vector<std::unique_ptr<Expression>> operands_;
};
// Returns `s` with all leading whitespace (as classified by std::isspace)
// removed. The caller's view is not modified; `s` is taken by value.
absl::string_view ConsumeWhitespace(absl::string_view s) {
  // std::isspace has undefined behavior for negative arguments other than EOF,
  // so each char is widened through unsigned char first. This matters for
  // non-ASCII bytes on platforms where char is signed.
  while (!s.empty() &&
         std::isspace(static_cast<unsigned char>(s.front())) != 0) {
    s.remove_prefix(1);
  }
  return s;
}
// Returns the first non-whitespace character of `s` without consuming
// anything, or '\0' when `s` is empty or contains only whitespace.
char Peek(absl::string_view s) {
  const absl::string_view rest = ConsumeWhitespace(s);
  if (rest.empty()) {
    return '\0';
  }
  return rest.front();
}
// Skips leading whitespace in `s`, then removes `c` if it is the next
// character. Returns true when `c` was removed. Note that the leading
// whitespace is stripped from `s` even when `c` is not found.
bool Consume(absl::string_view& s, char c) {
  s = ConsumeWhitespace(s);
  if (s.empty() || s.front() != c) {
    return false;
  }
  s.remove_prefix(1);
  return true;
}
// Parses a primary, which is one of:
//   * a numeric literal (e.g. "3", "-2.5", ".5"),
//   * a unary minus applied to a primary (e.g. "-x", "-(a + b)"),
//   * a variable name (a letter followed by letters, digits or '_'),
//   * a call `Maximum(expr, ...)` with at least one argument, or
//   * a parenthesized expression.
// On success the parsed text is removed from the front of `s`.
absl::StatusOr<std::unique_ptr<Expression>> ParsePrimary(absl::string_view& s) {
  s = ConsumeWhitespace(s);
  if (s.empty()) {
    return absl::InvalidArgumentError("Unexpected end of expression.");
  }
  // <cctype> classifiers have undefined behavior for negative char values
  // (non-ASCII bytes on signed-char platforms), so widen via unsigned char.
  const unsigned char front = static_cast<unsigned char>(s.front());

  // Parse a number. absl::from_chars accepts a leading '-', so "-3" is
  // parsed as a single negative literal.
  if (std::isdigit(front) != 0 || front == '.' || front == '-') {
    double value = 0;
    auto [ptr, ec] = absl::from_chars(s.data(), s.data() + s.size(), value);
    if (ec == std::errc()) {
      s.remove_prefix(ptr - s.data());
      return std::make_unique<Constant>(value);
    }
    if (front == '-' && ec == std::errc::invalid_argument) {
      // The '-' does not start a numeric literal (e.g. "-x", "-(a + b)" or
      // "- 3"), so treat it as unary negation of the following primary.
      s.remove_prefix(1);
      ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> operand,
                                ParsePrimary(s));
      return std::make_unique<BinaryOperation>(
          BinaryOperation::Operator::kMultiply, std::make_unique<Constant>(-1),
          std::move(operand));
    }
    // Report a readable reason and the offending text rather than the raw
    // integer value of the std::errc enumerator.
    return absl::InvalidArgumentError(
        absl::StrCat("Failed to parse number: ",
                     std::make_error_code(ec).message(), " at '", s, "'"));
  }
  // Parse a variable name or an operator name.
  if (std::isalpha(front) != 0) {
    size_t len = 0;
    while (len < s.size() &&
           (std::isalnum(static_cast<unsigned char>(s[len])) != 0 ||
            s[len] == '_')) {
      ++len;
    }
    absl::string_view name = s.substr(0, len);
    s.remove_prefix(len);
    if (name == "Maximum") {
      if (!Consume(s, '(')) {
        return absl::InvalidArgumentError("Expected '(' after Maximum.");
      }
      std::vector<std::unique_ptr<Expression>> args;
      if (Peek(s) != ')') {
        while (true) {
          ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> arg,
                                    ParseExpression(s));
          args.push_back(std::move(arg));
          if (!Consume(s, ',')) break;
        }
      }
      if (!Consume(s, ')')) {
        return absl::InvalidArgumentError(
            "Expected ')' after Maximum arguments.");
      }
      if (args.empty()) {
        return absl::InvalidArgumentError(
            "Maximum requires at least one argument.");
      }
      return std::make_unique<Max>(std::move(args));
    }
    return std::make_unique<Variable>(name);
  }
  // Parse a parenthesized sub-expression.
  if (Consume(s, '(')) {
    ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> expr,
                              ParseExpression(s));
    if (!Consume(s, ')')) {
      return absl::InvalidArgumentError("Mismatched parentheses.");
    }
    return expr;
  }
  return absl::InvalidArgumentError(
      absl::StrCat("Unexpected token: ", s.substr(0, 1)));
}
// Parses a left-associative chain of primaries joined by '*' or '/', so that
// "a * b / c" becomes ((a * b) / c). Parsing stops at the first token that is
// not one of those operators; that token is left unconsumed in `s`.
absl::StatusOr<std::unique_ptr<Expression>> ParseFactor(absl::string_view& s) {
  ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> product,
                            ParsePrimary(s));
  for (;;) {
    const char symbol = Peek(s);
    BinaryOperation::Operator op;
    switch (symbol) {
      case '*':
        op = BinaryOperation::Operator::kMultiply;
        break;
      case '/':
        op = BinaryOperation::Operator::kDivide;
        break;
      default:
        return product;
    }
    Consume(s, symbol);
    ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> operand,
                              ParsePrimary(s));
    product = std::make_unique<BinaryOperation>(op, std::move(product),
                                                std::move(operand));
  }
}
// Parses a left-associative chain of factors joined by '+' or '-', so that
// "a - b + c" becomes ((a - b) + c). Parsing stops at the first token that is
// not one of those operators; that token is left unconsumed in `s`.
absl::StatusOr<std::unique_ptr<Expression>> ParseTerm(absl::string_view& s) {
  ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> sum, ParseFactor(s));
  for (;;) {
    const char symbol = Peek(s);
    BinaryOperation::Operator op;
    switch (symbol) {
      case '+':
        op = BinaryOperation::Operator::kAdd;
        break;
      case '-':
        op = BinaryOperation::Operator::kSubtract;
        break;
      default:
        return sum;
    }
    Consume(s, symbol);
    ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> operand,
                              ParseFactor(s));
    sum = std::make_unique<BinaryOperation>(op, std::move(sum),
                                            std::move(operand));
  }
}
// Top of the recursive-descent grammar. Precedence from loosest to tightest:
// '+'/'-' (ParseTerm), then '*'/'/' (ParseFactor), then atoms (ParsePrimary).
// Consumes the parsed text from the front of `s`.
absl::StatusOr<std::unique_ptr<Expression>> ParseExpression(
    absl::string_view& s) {
  // Addition and subtraction bind most loosely, so a full expression is
  // exactly what ParseTerm accepts.
  absl::StatusOr<std::unique_ptr<Expression>> parsed = ParseTerm(s);
  return parsed;
}
} // namespace expression::internal
// Parses the whole of `expr` into an expression tree. Fails if the grammar
// rejects the input, or if anything other than whitespace follows a complete
// expression.
absl::StatusOr<std::unique_ptr<Expression>> Parse(
    absl::string_view expr) {
  absl::string_view remaining = expr;
  ECCLESIA_ASSIGN_OR_RETURN(
      std::unique_ptr<Expression> parsed,
      expression::internal::ParseExpression(remaining));
  // Trailing whitespace is allowed; anything else left over is an error.
  remaining = expression::internal::ConsumeWhitespace(remaining);
  if (!remaining.empty()) {
    return absl::InvalidArgumentError(
        absl::StrCat("Unexpected trailing characters: ", remaining));
  }
  return parsed;
}
} // namespace milotic_tlbmc