diff --git a/grammar.cc b/grammar.cc index 62aa8e9..1c19295 100644 --- a/grammar.cc +++ b/grammar.cc @@ -465,22 +465,39 @@ upsert_stmt::upsert_stmt(prod *p, struct scope *s, table *v) shared_ptr statement_factory(struct scope *s) { + const auto check_d42 = [s](const std::string& stmt) { + return !s->excluded_stmts.count(stmt) && d42() == 1; + }; + + const auto check_d6 = [s](const int value, const std::string& stmt) { + return !s->excluded_stmts.count(stmt) && d6() > value; + }; + try { s->new_stmt(); - if (d42() == 1) + if (check_d42("merge")) { return make_shared((struct prod *)0, s); - if (d42() == 1) + } + + if (check_d42("insert")) { return make_shared((struct prod *)0, s); - else if (d42() == 1) + } + else if (check_d42("delete")) { return make_shared((struct prod *)0, s); - else if (d42() == 1) { + } + else if (check_d42("upsert")) { return make_shared((struct prod *)0, s); - } else if (d42() == 1) + } + else if (check_d42("update")) { return make_shared((struct prod *)0, s); - else if (d6() > 4) + } + else if (check_d6(4, "select_for_update")) { return make_shared((struct prod *)0, s); - else if (d6() > 5) + } + else if (check_d6(5, "cte")) { return make_shared((struct prod *)0, s); + } + return make_shared((struct prod *)0, s); } catch (runtime_error &e) { return statement_factory(s); diff --git a/relmodel.hh b/relmodel.hh index 713ed59..6c6a9c7 100644 --- a/relmodel.hh +++ b/relmodel.hh @@ -6,6 +6,7 @@ #include #include #include +#include #include #include #include @@ -80,6 +81,7 @@ struct scope { struct scope *parent; vector tables; // available to table_ref productions vector refs; // available to column_ref productions + std::unordered_set excluded_stmts; struct schema *schema; shared_ptr > stmt_seq; // sequence for stmt-unique identifiers scope(struct scope *parent = 0) : parent(parent) { diff --git a/sqlsmith.cc b/sqlsmith.cc index 182c7d6..8dbe18f 100644 --- a/sqlsmith.cc +++ b/sqlsmith.cc @@ -2,6 +2,7 @@ #include #include +#include #ifndef HAVE_BOOST_REGEX #include @@ -62,8 +63,8 @@ int main(int argc, char *argv[]) cerr << PACKAGE_NAME " " GITREV << endl; map options; - regex optregex("--(help|log-to|verbose|target|sqlite|monetdb|version|dump-all-graphs|dump-all-queries|seed|dry-run|max-queries|rng-state|exclude-catalog)(?:=((?:.|\n)*))?"); - + regex optregex("--(help|log-to|verbose|target|sqlite|monetdb|version|dump-all-graphs|dump-all-queries|seed|dry-run|max-queries|rng-state|exclude-catalog|exclude-stmts)(?:=((?:.|\n)*))?"); + for(char **opt = argv+1 ;opt < argv+argc; opt++) { smatch match; string s(*opt); @@ -75,26 +76,30 @@ int main(int argc, char *argv[]) } } + const std::string option_exclude_stmts = "exclude-stmts"; + if (options.count("help")) { cerr << - " --target=connstr postgres database to send queries to" << endl << + " --target=connstr postgres database to send queries to" << endl << #ifdef HAVE_LIBSQLITE3 - " --sqlite=URI SQLite database to send queries to" << endl << + " --sqlite=URI SQLite database to send queries to" << endl << #endif #ifdef HAVE_MONETDB - " --monetdb=connstr MonetDB database to send queries to" <fill_scope(scope); + if (options.count(option_exclude_stmts)) { + const auto& excluded_stmts = options[option_exclude_stmts]; + boost::regex re("[a-zA-Z_]+"); + boost::sregex_iterator it(excluded_stmts.begin(), excluded_stmts.end(), re); + boost::sregex_iterator itEnd; + for(; it != itEnd; ++it) { + std::string stmt = it->str(); + std::cout << "Excluding: " << stmt << "\n"; + + std::transform(stmt.begin(), stmt.end(), stmt.begin(), ::tolower); + scope.excluded_stmts.emplace(stmt); + } + } + if (options.count("rng-state")) { istringstream(options["rng-state"]) >> smith::rng; } else { @@ -147,7 +166,7 @@ int main(int argc, char *argv[]) loggers.push_back(l); signal(SIGINT, cerr_log_handler); } - + if (options.count("dump-all-graphs")) loggers.push_back(make_shared()); @@ -166,11 +185,12 @@ int main(int argc, char *argv[]) if (options.count("max-queries") && (queries_generated >= stol(options["max-queries"]))) return 0; + } } shared_ptr dut; - + if (options.count("sqlite")) { #ifdef HAVE_LIBSQLITE3 dut = make_shared(options["sqlite"]); @@ -180,7 +200,7 @@ int main(int argc, char *argv[]) #endif } else if(options.count("monetdb")) { -#ifdef HAVE_MONETDB +#ifdef HAVE_MONETDB dut = make_shared(options["monetdb"]); #else cerr << "Sorry, " PACKAGE_NAME " was compiled without MonetDB support." << endl; @@ -201,13 +221,13 @@ int main(int argc, char *argv[]) global_cerr_logger->report(); return 0; } - + /* Invoke top-level production to generate AST */ shared_ptr gen = statement_factory(&scope); for (auto l : loggers) l->generated(*gen); - + /* Generate SQL from AST */ ostringstream s; gen->out(s);