diff --git a/include/dbng.hpp b/include/dbng.hpp index ce1487c4..3f0c5193 100644 --- a/include/dbng.hpp +++ b/include/dbng.hpp @@ -54,6 +54,16 @@ class dbng { return db_.update(t, std::forward(args)...); } + template + int replace(const T &t, Args &&...args) { + return db_.replace(t, std::forward(args)...); + } + + template + int replace(const std::vector &t, Args &&...args) { + return db_.replace(t, std::forward(args)...); + } + template bool delete_records(Args &&...where_conditon) { return db_.template delete_records( diff --git a/include/mysql.hpp b/include/mysql.hpp index 387dab94..a8275e73 100644 --- a/include/mysql.hpp +++ b/include/mysql.hpp @@ -100,23 +100,35 @@ class mysql { template int insert(const T &t, bool get_insert_id = false, Args &&...args) { - return insert_impl(false, t, get_insert_id, std::forward(args)...); + return insert_impl(OptType::insert, t, get_insert_id, + std::forward(args)...); } template int insert(const std::vector &t, bool get_insert_id = false, Args &&...args) { - return insert_impl(false, t, get_insert_id, std::forward(args)...); + return insert_impl(OptType::insert, t, get_insert_id, + std::forward(args)...); + } + + template + int replace(const T &t, Args &&...args) { + return insert_impl(OptType::replace, t, false, std::forward(args)...); + } + + template + int replace(const std::vector &t, Args &&...args) { + return insert_impl(OptType::replace, t, false, std::forward(args)...); } template int update(const T &t, Args &&...args) { - return insert_impl(true, t, false, std::forward(args)...); + return update_impl(t, std::forward(args)...); } template int update(const std::vector &t, Args &&...args) { - return insert_impl(true, t, false, std::forward(args)...); + return update_impl(t, std::forward(args)...); } template @@ -617,18 +629,27 @@ class mysql { } template - int stmt_execute(bool update, const T &t) { - reset_error(); + int stmt_execute(const T &t, OptType type, bool condition) { std::vector param_binds; - - iguana::for_each(t, [&t, ¶m_binds, update, this](auto item, auto i) { - if (is_auto_key(iguana::get_name(), iguana::get_name(i).data()) && - !update) { + iguana::for_each(t, [&t, ¶m_binds, type, this](auto item, auto i) { + if (type == OptType::insert && + is_auto_key(iguana::get_name(i).data())) { return; } set_param_bind(param_binds, t.*item); }); + if (condition && type == OptType::update) { + iguana::for_each(t, [&t, ¶m_binds, this](auto item, auto i) { + std::string field_name = "`"; + field_name += iguana::get_name(i).data(); + field_name += "`"; + if (is_conflict_key(field_name)) { + set_param_bind(param_binds, t.*item); + } + }); + } + if (mysql_stmt_bind_param(stmt_, ¶m_binds[0])) { set_last_error(mysql_stmt_error(stmt_)); return INT_MIN; @@ -648,9 +669,38 @@ class mysql { } template - int insert_impl(bool update, const T &t, bool get_insert_id = false, + int insert_impl(OptType type, const T &t, bool get_insert_id, Args &&...args) { - std::string sql = generate_insert_sql(update); + std::string sql = generate_insert_sql(type == OptType::insert); + return insert_or_update_impl(t, sql, type, get_insert_id); + } + + template + int insert_impl(OptType type, const std::vector &v, bool get_insert_id, + Args &&...args) { + std::string sql = generate_insert_sql(type == OptType::insert); + return insert_or_update_impl(v, sql, type, get_insert_id); + } + + template + int update_impl(const T &t, Args &&...args) { + bool condition = true; + std::string sql = + generate_update_sql(condition, std::forward(args)...); + return insert_or_update_impl(t, sql, OptType::update, false, condition); + } + + template + int update_impl(const std::vector &v, Args &&...args) { + bool condition = true; + std::string sql = + generate_update_sql(condition, std::forward(args)...); + return insert_or_update_impl(v, sql, OptType::update, false, condition); + } + + template + int insert_or_update_impl(const T &t, const std::string &sql, OptType type, + bool get_insert_id = false, bool condition = true) { #ifdef ORMPP_ENABLE_LOG std::cout << sql << std::endl; #endif @@ -667,7 +717,7 @@ class mysql { auto guard = guard_statment(stmt_); - if (stmt_execute(update, t) == INT_MIN) { + if (stmt_execute(t, type, condition) == INT_MIN) { set_last_error(mysql_stmt_error(stmt_)); return INT_MIN; } @@ -675,10 +725,10 @@ class mysql { return get_insert_id ? stmt_->mysql->insert_id : 1; } - template - int insert_impl(bool update, const std::vector &t, - bool get_insert_id = false, Args &&...args) { - std::string sql = generate_insert_sql(update); + template + int insert_or_update_impl(const std::vector &v, const std::string &sql, + OptType type, bool get_insert_id = false, + bool condition = true) { #ifdef ORMPP_ENABLE_LOG std::cout << sql << std::endl; #endif @@ -699,14 +749,14 @@ class mysql { return INT_MIN; } - for (auto &item : t) { - if (stmt_execute(update, item) == INT_MIN) { + for (auto &item : v) { + if (stmt_execute(item, type, condition) == INT_MIN) { rollback(); return INT_MIN; } } - return commit() ? (get_insert_id ? stmt_->mysql->insert_id : (int)t.size()) + return commit() ? get_insert_id ? stmt_->mysql->insert_id : (int)v.size() : INT_MIN; } diff --git a/include/postgresql.hpp b/include/postgresql.hpp index 59888d85..98c53603 100644 --- a/include/postgresql.hpp +++ b/include/postgresql.hpp @@ -75,22 +75,32 @@ class postgresql { template constexpr int insert(const T &t, Args &&...args) { - return insert_impl(false, t, std::forward(args)...); + return insert_impl(OptType::insert, t, std::forward(args)...); } template constexpr int insert(const std::vector &v, Args &&...args) { - return insert_impl(false, v, std::forward(args)...); + return insert_impl(OptType::insert, v, std::forward(args)...); + } + + template + constexpr int replace(const T &t, Args &&...args) { + return insert_impl(OptType::replace, t, std::forward(args)...); + } + + template + constexpr int replace(const std::vector &v, Args &&...args) { + return insert_impl(OptType::replace, v, std::forward(args)...); } template constexpr int update(const T &t, Args &&...args) { - return insert_impl(true, t, std::forward(args)...); + return update_impl(t, std::forward(args)...); } template constexpr int update(const std::vector &v, Args &&...args) { - return insert_impl(true, v, std::forward(args)...); + return update_impl(v, std::forward(args)...); } template @@ -359,16 +369,24 @@ class postgresql { } template - constexpr int stmt_execute(bool update, const T &t) { + constexpr int stmt_execute(const T &t, OptType type, bool condition) { std::vector> param_values; - iguana::for_each(t, [&t, ¶m_values, update, this](auto item, auto i) { - if (is_auto_key(iguana::get_name(), iguana::get_name(i).data()) && - !update) { + iguana::for_each(t, [&t, ¶m_values, type, this](auto item, auto i) { + if (type == OptType::insert && + is_auto_key(iguana::get_name(i).data())) { return; } set_param_values(param_values, t.*item); }); + if (condition && type == OptType::update) { + iguana::for_each(t, [&t, ¶m_values, this](auto item, auto i) { + if (is_conflict_key(iguana::get_name(i).data())) { + set_param_values(param_values, t.*item); + } + }); + } + if (param_values.empty()) return INT_MIN; @@ -385,26 +403,51 @@ class postgresql { } template - constexpr int insert_impl(bool update, const T &t, Args &&...args) { + constexpr int insert_impl(OptType type, const T &t, Args &&...args) { + std::string sql = generate_insert_sql(type == OptType::insert, + std::forward(args)...); + return insert_or_update_impl(t, sql, type); + } + + template + constexpr int insert_impl(OptType type, const std::vector &v, + Args &&...args) { + std::string sql = generate_insert_sql(type == OptType::insert, + std::forward(args)...); + return insert_or_update_impl(v, sql, type); + } + + template + int update_impl(const T &t, Args &&...args) { + bool condition = true; + std::string sql = + generate_update_sql(condition, std::forward(args)...); + return insert_or_update_impl(t, sql, OptType::update, condition); + } + + template + int update_impl(const std::vector &v, Args &&...args) { + bool condition = true; std::string sql = - update ? generate_update_sql(std::forward(args)...) - : generate_insert_sql(update); + generate_update_sql(condition, std::forward(args)...); + return insert_or_update_impl(v, sql, OptType::update, condition); + } + + template + int insert_or_update_impl(const T &t, const std::string &sql, OptType type, + bool condition = true) { #ifdef ORMPP_ENABLE_LOG std::cout << sql << std::endl; #endif if (!prepare(sql)) { return INT_MIN; } - - return stmt_execute(update, t); + return stmt_execute(t, type, condition); } - template - constexpr int insert_impl(bool update, const std::vector &v, - Args &&...args) { - std::string sql = - update ? generate_update_sql(std::forward(args)...) - : generate_insert_sql(update); + template + int insert_or_update_impl(const std::vector &v, const std::string &sql, + OptType type, bool condition = true) { #ifdef ORMPP_ENABLE_LOG std::cout << sql << std::endl; #endif @@ -417,7 +460,7 @@ class postgresql { } for (auto &item : v) { - if (stmt_execute(update, item) == INT_MIN) { + if (stmt_execute(item, type, condition) == INT_MIN) { rollback(); return INT_MIN; } @@ -429,45 +472,6 @@ class postgresql { return (int)v.size(); } - template - inline std::string generate_update_sql(Args &&...args) { - constexpr auto SIZE = iguana::get_value(); - std::string sql = "insert into "; - auto name = get_name(); - append(sql, name.data()); - int index = 0; - std::string set; - std::string fields = "("; - std::string values = "values("; - for (auto i = 0; i < SIZE; ++i) { - std::string field_name = iguana::get_name(i).data(); - std::string value = "$" + std::to_string(++index); - set += field_name + "=" + value; - fields += field_name; - values += value; - if (i < SIZE - 1) { - fields += ","; - values += ","; - set += ","; - } - else { - fields += ")"; - values += ")"; - set += ";"; - } - } - std::string conflict = "on conflict("; - if constexpr (sizeof...(Args) > 0) { - append(conflict, args...); - } - else { - conflict += get_conflict_key(iguana::get_name()); - } - conflict += ")"; - append(sql, fields, values, conflict, "do update set", set); - return sql; - } - template constexpr void set_param_values(std::vector> ¶m_values, T &&value) { diff --git a/include/sqlite.hpp b/include/sqlite.hpp index dd85035e..0771f0e8 100644 --- a/include/sqlite.hpp +++ b/include/sqlite.hpp @@ -73,23 +73,35 @@ class sqlite { template int insert(const T &t, bool get_insert_id = false, Args &&...args) { - return insert_impl(false, t, get_insert_id, std::forward(args)...); + return insert_impl(OptType::insert, t, get_insert_id, + std::forward(args)...); } template int insert(const std::vector &t, bool get_insert_id = false, Args &&...args) { - return insert_impl(false, t, get_insert_id, std::forward(args)...); + return insert_impl(OptType::insert, t, get_insert_id, + std::forward(args)...); + } + + template + int replace(const T &t, Args &&...args) { + return insert_impl(OptType::replace, t, false, std::forward(args)...); + } + + template + int replace(const std::vector &t, Args &&...args) { + return insert_impl(OptType::replace, t, false, std::forward(args)...); } template int update(const T &t, Args &&...args) { - return insert_impl(true, t, false, std::forward(args)...); + return update_impl(t, std::forward(args)...); } template int update(const std::vector &t, Args &&...args) { - return insert_impl(true, t, false, std::forward(args)...); + return update_impl(t, std::forward(args)...); } template @@ -357,20 +369,31 @@ class sqlite { } template - constexpr int stmt_execute(bool update, const T &t) { + constexpr int stmt_execute(const T &t, OptType type, bool condition) { int index = 0; bool bind_ok = true; - iguana::for_each(t, [&t, &bind_ok, &index, update, this](auto item, - auto i) { - if ((is_auto_key(iguana::get_name(), iguana::get_name(i).data()) || - !bind_ok) && - !update) { + iguana::for_each(t, [&t, &bind_ok, &index, type, this](auto item, auto i) { + if ((type == OptType::insert && + is_auto_key(iguana::get_name(i).data())) || + !bind_ok) { return; } bind_ok = set_param_bind(t.*item, index + 1); index++; }); + if (condition && type == OptType::update) { + iguana::for_each(t, [&t, &bind_ok, &index, this](auto item, auto i) { + if (!bind_ok) { + return; + } + if (is_conflict_key(iguana::get_name(i).data())) { + bind_ok = set_param_bind(t.*item, index + 1); + index++; + } + }); + } + if (!bind_ok) { set_last_error(sqlite3_errmsg(handle_)); return INT_MIN; @@ -463,9 +486,38 @@ class sqlite { } template - int insert_impl(bool update, const T &t, bool get_insert_id = false, + int insert_impl(OptType type, const T &t, bool get_insert_id, + Args &&...args) { + std::string sql = generate_insert_sql(type == OptType::insert); + return insert_or_update_impl(t, sql, type, get_insert_id); + } + + template + int insert_impl(OptType type, const std::vector &v, bool get_insert_id, Args &&...args) { - std::string sql = generate_insert_sql(update); + std::string sql = generate_insert_sql(type == OptType::insert); + return insert_or_update_impl(v, sql, type, get_insert_id); + } + + template + int update_impl(const T &t, Args &&...args) { + bool condition = true; + std::string sql = + generate_update_sql(condition, std::forward(args)...); + return insert_or_update_impl(t, sql, OptType::update, false, condition); + } + + template + int update_impl(const std::vector &v, Args &&...args) { + bool condition = true; + std::string sql = + generate_update_sql(condition, std::forward(args)...); + return insert_or_update_impl(v, sql, OptType::update, false, condition); + } + + template + int insert_or_update_impl(const T &t, const std::string &sql, OptType type, + bool get_insert_id = false, bool condition = true) { #ifdef ORMPP_ENABLE_LOG std::cout << sql << std::endl; #endif @@ -477,7 +529,7 @@ class sqlite { auto guard = guard_statment(stmt_); - if (stmt_execute(update, t) == INT_MIN) { + if (stmt_execute(t, type, condition) == INT_MIN) { set_last_error(sqlite3_errmsg(handle_)); return INT_MIN; } @@ -485,10 +537,10 @@ class sqlite { return get_insert_id ? sqlite3_last_insert_rowid(handle_) : 1; } - template - int insert_impl(bool update, const std::vector &v, - bool get_insert_id = false, Args &&...args) { - std::string sql = generate_insert_sql(update); + template + int insert_or_update_impl(const std::vector &v, const std::string &sql, + OptType type, bool get_insert_id = false, + bool condition = true) { #ifdef ORMPP_ENABLE_LOG std::cout << sql << std::endl; #endif @@ -505,7 +557,7 @@ class sqlite { } for (auto &item : v) { - if (stmt_execute(update, item) == INT_MIN) { + if (stmt_execute(item, type, condition) == INT_MIN) { rollback(); return INT_MIN; } @@ -517,8 +569,8 @@ class sqlite { } } - return commit() ? (get_insert_id ? sqlite3_last_insert_rowid(handle_) - : (int)v.size()) + return commit() ? get_insert_id ? sqlite3_last_insert_rowid(handle_) + : (int)v.size() : INT_MIN; } diff --git a/include/utility.hpp b/include/utility.hpp index 91201965..3850dde9 100644 --- a/include/utility.hpp +++ b/include/utility.hpp @@ -19,9 +19,10 @@ inline int add_auto_key_field(std::string_view key, std::string_view value) { return 0; } -inline auto is_auto_key(std::string_view key, std::string_view value) { - auto it = g_ormpp_auto_key_map.find(key); - return it == g_ormpp_auto_key_map.end() ? false : it->second == value; +template +inline auto is_auto_key(std::string_view field_name) { + auto it = g_ormpp_auto_key_map.find(iguana::get_name()); + return it == g_ormpp_auto_key_map.end() ? false : it->second == field_name; } #define REGISTER_AUTO_KEY(STRUCT_NAME, KEY) \ @@ -115,6 +116,7 @@ inline auto sort_tuple(const std::tuple &tp) { } } +enum class OptType { insert, update, replace }; enum class DBType { mysql, sqlite, postgresql, unknown }; template @@ -199,9 +201,83 @@ inline std::string get_fields() { return fields; } +template >> +inline std::vector get_conflict_keys() { + static std::vector res; + if (!res.empty()) { + return res; + } + std::stringstream s(get_conflict_key(iguana::get_name()).data()); + while (s.good()) { + std::string str; + getline(s, str, ','); + if (str.front() == ' ') { + str.erase(0); + } + if (str.back() == ' ') { + str.pop_back(); + } +#ifdef ORMPP_ENABLE_MYSQL + str.insert(0, "`"); + str.append("`"); +#endif + res.emplace_back(str); + } + return res; +} + template -inline std::string generate_insert_sql(bool replace) { - std::string sql = replace ? "replace into " : "insert into "; +inline auto is_conflict_key(std::string_view field_name) { + for (const auto &it : get_conflict_keys()) { + if (it == field_name) { + return true; + } + } + return false; +} + +template +inline std::string generate_insert_sql(bool insert, Args &&...args) { +#ifdef ORMPP_ENABLE_PG + if (!insert) { + constexpr auto SIZE = iguana::get_value(); + std::string sql = "insert into "; + auto name = get_name(); + append(sql, name.data()); + int index = 0; + std::string set; + std::string fields = "("; + std::string values = "values("; + for (auto i = 0; i < SIZE; ++i) { + std::string field_name = iguana::get_name(i).data(); + std::string value = "$" + std::to_string(++index); + append(set, field_name, "=", value); + fields += field_name; + values += value; + if (i < SIZE - 1) { + fields += ","; + values += ","; + set += ","; + } + else { + fields += ")"; + values += ")"; + set += ";"; + } + } + std::string conflict = "on conflict("; + if constexpr (sizeof...(Args) > 0) { + append(conflict, args...); + } + else { + conflict += get_conflict_key(iguana::get_name()); + } + conflict += ")"; + append(sql, fields, values, conflict, "do update set", set); + return sql; + } +#endif + std::string sql = insert ? "insert into " : "replace into "; constexpr auto SIZE = iguana::get_value(); auto name = get_name(); append(sql, name.data()); @@ -211,7 +287,7 @@ inline std::string generate_insert_sql(bool replace) { std::string values = "values("; for (size_t i = 0; i < SIZE; ++i) { std::string field_name = iguana::get_name(i).data(); - if (is_auto_key(iguana::get_name(), field_name)) { + if (insert && is_auto_key(field_name)) { continue; } #ifdef ORMPP_ENABLE_PG @@ -245,6 +321,50 @@ inline std::string generate_insert_sql(bool replace) { return sql; } +template +inline std::string generate_update_sql(bool &condition, Args &&...args) { + constexpr auto SIZE = iguana::get_value(); + std::string sql = "update "; + auto name = get_name(); + append(sql, name.data()); + append(sql, "set"); + + int index = 0; + std::string fields; + for (size_t i = 0; i < SIZE; ++i) { + std::string field_name = iguana::get_name(i).data(); +#ifdef ORMPP_ENABLE_MYSQL + fields += "`" + field_name + "`"; +#else + fields += field_name; +#endif +#ifdef ORMPP_ENABLE_PG + append(fields, " =", "$" + std::to_string(++index)); +#else + fields += " = ?"; +#endif + if (i < SIZE - 1) { + fields += ", "; + } + } + std::string conflict = "where 1=1"; + if constexpr (sizeof...(Args) > 0) { + append(conflict, " and", args...); + condition = false; + } + else { + for (const auto &it : get_conflict_keys()) { +#ifdef ORMPP_ENABLE_PG + append(conflict, " and", it, "=", "$" + std::to_string(++index)); +#else + append(conflict, " and", it, "= ?"); +#endif + } + } + append(sql, fields, conflict); + return sql; +} + inline bool is_empty(const std::string &t) { return t.empty(); } template diff --git a/tests/test_ormpp.cpp b/tests/test_ormpp.cpp index 7f192b69..7071597d 100644 --- a/tests/test_ormpp.cpp +++ b/tests/test_ormpp.cpp @@ -494,6 +494,101 @@ TEST_CASE("insert query") { } } +TEST_CASE("update replace") { +#ifdef ORMPP_ENABLE_MYSQL + dbng mysql; + if (mysql.connect(ip, username, password, db)) { + mysql.execute("drop table if exists person"); + mysql.create_datatable(ormpp_auto_key{"id"}); + mysql.insert({"purecpp", 100}); + auto vec = mysql.query(); + CHECK(vec.size() == 1); + vec.front().name = "update"; + vec.front().age = 200; + mysql.update(vec.front()); + vec = mysql.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "update"); + CHECK(vec.front().age == 200); + mysql.update({"purecpp", 100, 1}, "id=1"); + vec = mysql.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "purecpp"); + CHECK(vec.front().age == 100); + vec.front().name = "update"; + vec.front().age = 200; + mysql.replace(vec.front()); + vec = mysql.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "update"); + CHECK(vec.front().age == 200); + } +#endif +#ifdef ORMPP_ENABLE_PG + dbng postgres; + if (postgres.connect(ip, username, password, db)) { + postgres.execute("drop table if exists person"); + postgres.create_datatable(ormpp_auto_key{"id"}); + postgres.insert({"purecpp", 100}); + auto vec = postgres.query(); + CHECK(vec.size() == 1); + vec.front().name = "update"; + vec.front().age = 200; + postgres.update(vec.front()); + vec = postgres.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "update"); + CHECK(vec.front().age == 200); + postgres.update({"purecpp", 100, 1}, "id=1"); + vec = postgres.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "purecpp"); + CHECK(vec.front().age == 100); + vec.front().name = "update"; + vec.front().age = 200; + postgres.replace(vec.front()); + vec = postgres.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "update"); + CHECK(vec.front().age == 200); + postgres.replace({"purecpp", 100, 1}, "id"); + vec = postgres.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "purecpp"); + CHECK(vec.front().age == 100); + } +#endif +#ifdef ORMPP_ENABLE_SQLITE3 + dbng sqlite; + if (sqlite.connect(db)) { + sqlite.execute("drop table if exists person"); + sqlite.create_datatable(ormpp_auto_key{"id"}); + sqlite.insert({"purecpp", 100}); + auto vec = sqlite.query(); + CHECK(vec.size() == 1); + vec.front().name = "update"; + vec.front().age = 200; + sqlite.update(vec.front()); + vec = sqlite.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "update"); + CHECK(vec.front().age == 200); + sqlite.update({"purecpp", 100, 1}, "id=1"); + vec = sqlite.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "purecpp"); + CHECK(vec.front().age == 100); + vec.front().name = "update"; + vec.front().age = 200; + sqlite.replace(vec.front()); + vec = sqlite.query(); + CHECK(vec.size() == 1); + CHECK(vec.front().name == "update"); + CHECK(vec.front().age == 200); + } +#endif +} + TEST_CASE("update") { ormpp_key key{"code"}; ormpp_not_null not_null{{"code", "age"}};