| #include "tlbmc/expression/expression.h" |
| |
| #include <algorithm> |
| #include <cctype> |
| #include <cstddef> |
| #include <memory> |
| #include <string> |
| #include <system_error> // NOLINT: system_error is commonly used in BMC |
| #include <utility> |
| #include <vector> |
| |
| #include "absl/container/flat_hash_map.h" |
| #include "absl/container/flat_hash_set.h" |
| #include "absl/status/status.h" |
| #include "absl/status/statusor.h" |
| #include "absl/strings/charconv.h" |
| #include "absl/strings/str_cat.h" |
| #include "absl/strings/string_view.h" |
| #include "g3/macros.h" |
| |
| namespace milotic_tlbmc { |
| namespace expression::internal { |
| |
| // Represents a constant value. |
| class Constant final : public Expression { |
| public: |
| explicit Constant(double value) : value_(value) {} |
| absl::StatusOr<double> Evaluate( |
| const absl::flat_hash_map<std::string, double>& /*variable_maps*/) |
| const override { |
| return value_; |
| } |
| |
| void GetRequiredVariables( |
| absl::flat_hash_set<std::string>& /*variable_names*/) const override {} |
| |
| private: |
| const double value_; |
| }; |
| |
| // Represents a variable. |
| class Variable final : public Expression { |
| public: |
| explicit Variable(absl::string_view name) : name_(name) {} |
| absl::StatusOr<double> Evaluate( |
| const absl::flat_hash_map<std::string, double>& variable_maps) |
| const override { |
| if (auto it = variable_maps.find(name_); it != variable_maps.end()) { |
| return it->second; |
| } |
| return absl::NotFoundError(absl::StrCat("Variable not found: ", name_)); |
| } |
| |
| void GetRequiredVariables( |
| absl::flat_hash_set<std::string>& variable_names) const override { |
| variable_names.insert(name_); |
| } |
| |
| private: |
| const std::string name_; |
| }; |
| |
| // Represents a binary operation. |
| class BinaryOperation final : public Expression { |
| public: |
| enum class Operator { kAdd, kSubtract, kMultiply, kDivide }; |
| |
| BinaryOperation(Operator op, std::unique_ptr<Expression> lhs, |
| std::unique_ptr<Expression> rhs) |
| : operator_(op), lhs_(std::move(lhs)), rhs_(std::move(rhs)) {} |
| |
| absl::StatusOr<double> Evaluate( |
| const absl::flat_hash_map<std::string, double>& variable_maps) |
| const override { |
| ECCLESIA_ASSIGN_OR_RETURN(double lhs_val, lhs_->Evaluate(variable_maps)); |
| ECCLESIA_ASSIGN_OR_RETURN(double rhs_val, rhs_->Evaluate(variable_maps)); |
| |
| switch (operator_) { |
| case Operator::kAdd: |
| return lhs_val + rhs_val; |
| case Operator::kSubtract: |
| return lhs_val - rhs_val; |
| case Operator::kMultiply: |
| return lhs_val * rhs_val; |
| case Operator::kDivide: |
| if (rhs_val == 0) { |
| return absl::InvalidArgumentError("Division by zero."); |
| } |
| return lhs_val / rhs_val; |
| } |
| return absl::InternalError("Unreachable."); |
| } |
| |
| void GetRequiredVariables( |
| absl::flat_hash_set<std::string>& variable_names) const override { |
| lhs_->GetRequiredVariables(variable_names); |
| rhs_->GetRequiredVariables(variable_names); |
| } |
| |
| private: |
| const Operator operator_; |
| const std::unique_ptr<Expression> lhs_; |
| const std::unique_ptr<Expression> rhs_; |
| }; |
| |
| // Represents the Maximum function. |
| class Max final : public Expression { |
| public: |
| explicit Max(std::vector<std::unique_ptr<Expression>> operands) |
| : operands_(std::move(operands)) {} |
| |
| absl::StatusOr<double> Evaluate( |
| const absl::flat_hash_map<std::string, double>& variable_maps) |
| const override { |
| // operands are guaranteed to be non-empty by the parser. |
| ECCLESIA_ASSIGN_OR_RETURN(double max_val, |
| operands_[0]->Evaluate(variable_maps)); |
| for (const auto& operand : operands_) { |
| ECCLESIA_ASSIGN_OR_RETURN(double val, operand->Evaluate(variable_maps)); |
| max_val = std::max(max_val, val); |
| } |
| return max_val; |
| } |
| |
| void GetRequiredVariables( |
| absl::flat_hash_set<std::string>& variable_names) const override { |
| for (const auto& operand : operands_) { |
| operand->GetRequiredVariables(variable_names); |
| } |
| } |
| |
| private: |
| const std::vector<std::unique_ptr<Expression>> operands_; |
| }; |
| |
| // Consumes whitespace until the first non-whitespace character. |
| absl::string_view ConsumeWhitespace(absl::string_view s) { |
| while (!s.empty() && std::isspace(s.front()) != 0) { |
| s.remove_prefix(1); |
| } |
| return s; |
| } |
| |
| // Returns the first character in the string, or '\0' if the string is empty. |
| char Peek(absl::string_view s) { |
| s = ConsumeWhitespace(s); |
| return s.empty() ? '\0' : s.front(); |
| } |
| |
| // Consumes the given character if it is present. Returns true if the character |
| // is found, false otherwise. |
| bool Consume(absl::string_view& s, char c) { |
| s = ConsumeWhitespace(s); |
| if (!s.empty() && s.front() == c) { |
| s.remove_prefix(1); |
| return true; |
| } |
| return false; |
| } |
| |
| // Parses a primary, which is a number or a variable or an primary operator. |
| absl::StatusOr<std::unique_ptr<Expression>> ParsePrimary(absl::string_view& s) { |
| s = ConsumeWhitespace(s); |
| if (s.empty()) { |
| return absl::InvalidArgumentError("Unexpected end of expression."); |
| } |
| |
| // Parse a number. |
| if ((std::isdigit(s.front()) != 0) || s.front() == '.' || s.front() == '-') { |
| double value; |
| auto [ptr, ec] = absl::from_chars(s.data(), s.data() + s.size(), value); |
| if (ec == std::errc()) { |
| s.remove_prefix(ptr - s.data()); |
| return std::make_unique<Constant>(value); |
| } |
| return absl::InvalidArgumentError( |
| absl::StrCat("Failed to parse number: ", ec)); |
| } |
| |
| // Parse a variable name or a operator name. |
| if (std::isalpha(s.front()) != 0) { |
| size_t len = 0; |
| while (len < s.size() && (std::isalnum(s[len]) != 0 || s[len] == '_')) { |
| len++; |
| } |
| absl::string_view name = s.substr(0, len); |
| s.remove_prefix(len); |
| if (name == "Maximum") { |
| if (!Consume(s, '(')) { |
| return absl::InvalidArgumentError("Expected '(' after Maximum."); |
| } |
| std::vector<std::unique_ptr<Expression>> args; |
| if (Peek(s) != ')') { |
| while (true) { |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> arg, |
| ParseExpression(s)); |
| args.push_back(std::move(arg)); |
| if (!Consume(s, ',')) break; |
| } |
| } |
| if (!Consume(s, ')')) { |
| return absl::InvalidArgumentError( |
| "Expected ')' after Maximum arguments."); |
| } |
| if (args.empty()) { |
| return absl::InvalidArgumentError( |
| "Maximum requires at least one argument."); |
| } |
| return std::make_unique<Max>(std::move(args)); |
| } |
| return std::make_unique<Variable>(name); |
| } |
| |
| if (Consume(s, '(')) { |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> expr, |
| ParseExpression(s)); |
| if (!Consume(s, ')')) { |
| return absl::InvalidArgumentError("Mismatched parentheses."); |
| } |
| return expr; |
| } |
| |
| return absl::InvalidArgumentError( |
| absl::StrCat("Unexpected token: ", s.substr(0, 1))); |
| } |
| |
| // Parses a factor, which is a primary followed by a sequence of multiplication |
| // and division operators. |
| absl::StatusOr<std::unique_ptr<Expression>> ParseFactor(absl::string_view& s) { |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> lhs, ParsePrimary(s)); |
| |
| while (true) { |
| if (Peek(s) == '*') { |
| Consume(s, '*'); |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> rhs, |
| ParsePrimary(s)); |
| lhs = std::make_unique<BinaryOperation>( |
| BinaryOperation::Operator::kMultiply, std::move(lhs), |
| std::move(rhs)); |
| } else if (Peek(s) == '/') { |
| Consume(s, '/'); |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> rhs, |
| ParsePrimary(s)); |
| lhs = std::make_unique<BinaryOperation>( |
| BinaryOperation::Operator::kDivide, std::move(lhs), std::move(rhs)); |
| } else { |
| break; |
| } |
| } |
| return lhs; |
| } |
| |
| // Parses a term, which is a factor followed by a sequence of binary |
| // operations. |
| absl::StatusOr<std::unique_ptr<Expression>> ParseTerm(absl::string_view& s) { |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> lhs, ParseFactor(s)); |
| |
| while (true) { |
| if (Peek(s) == '+') { |
| Consume(s, '+'); |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> rhs, |
| ParseFactor(s)); |
| lhs = std::make_unique<BinaryOperation>(BinaryOperation::Operator::kAdd, |
| std::move(lhs), std::move(rhs)); |
| } else if (Peek(s) == '-') { |
| Consume(s, '-'); |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> rhs, |
| ParseFactor(s)); |
| lhs = std::make_unique<BinaryOperation>( |
| BinaryOperation::Operator::kSubtract, std::move(lhs), |
| std::move(rhs)); |
| } else { |
| break; |
| } |
| } |
| return lhs; |
| } |
| |
| absl::StatusOr<std::unique_ptr<Expression>> ParseExpression( |
| absl::string_view& s) { |
| return ParseTerm(s); |
| } |
| |
| } // namespace expression::internal |
| |
| absl::StatusOr<std::unique_ptr<Expression>> Parse( |
| absl::string_view expr) { |
| ECCLESIA_ASSIGN_OR_RETURN(std::unique_ptr<Expression> result, |
| expression::internal::ParseExpression(expr)); |
| |
| // Remove any trailing whitespace. |
| absl::string_view left_over = expression::internal::ConsumeWhitespace(expr); |
| if (!left_over.empty()) { |
| return absl::InvalidArgumentError( |
| absl::StrCat("Unexpected trailing characters: ", left_over)); |
| } |
| return result; |
| } |
| |
| } // namespace milotic_tlbmc |