diff --git a/include/SQLiteCpp/Column.h b/include/SQLiteCpp/Column.h index 0163d791..bc349f98 100644 --- a/include/SQLiteCpp/Column.h +++ b/include/SQLiteCpp/Column.h @@ -14,8 +14,11 @@ #include #include +#include #include // For INT_MAX +// Forward declarations to avoid inclusion of in a header +struct sqlite3_stmt; namespace SQLite { @@ -26,7 +29,6 @@ extern const int TEXT; ///< SQLITE_TEXT extern const int BLOB; ///< SQLITE_BLOB extern const int Null; ///< SQLITE_NULL - /** * @brief Encapsulation of a Column in a row of the result pointed by the prepared Statement. * @@ -52,7 +54,7 @@ class Column * @param[in] aStmtPtr Shared pointer to the prepared SQLite Statement Object. * @param[in] aIndex Index of the column in the row of result, starting at 0 */ - Column(Statement::Ptr& aStmtPtr, int aIndex) noexcept; + explicit Column(const Statement::TStatementPtr& aStmtPtr, int aIndex); // default destructor: the finalization will be done by the destructor of the last shared pointer // default copy constructor and assignment operator are perfectly suited : @@ -250,8 +252,8 @@ class Column } private: - Statement::Ptr mStmtPtr; ///< Shared Pointer to the prepared SQLite Statement Object - int mIndex; ///< Index of the column in the row of result, starting at 0 + Statement::TStatementPtr mStmtPtr; ///< Shared Pointer to the prepared SQLite Statement Object + int mIndex; ///< Index of the column in the row of result, starting at 0 }; /** @@ -281,7 +283,7 @@ T Statement::getColumns() template T Statement::getColumns(const std::integer_sequence) { - return T{Column(mStmtPtr, Is)...}; + return T{Column(mpPreparedStatement, Is)...}; } #endif diff --git a/include/SQLiteCpp/Statement.h b/include/SQLiteCpp/Statement.h index 15d53184..c81e215d 100644 --- a/include/SQLiteCpp/Statement.h +++ b/include/SQLiteCpp/Statement.h @@ -15,7 +15,7 @@ #include #include -#include // For INT_MAX +#include // Forward declarations to avoid inclusion of in a header struct sqlite3; @@ -51,8 +51,6 @@ extern const int OK; ///< SQLITE_OK */ class Statement { - friend class Column; // For access to Statement::Ptr inner class - public: /** * @brief Compile and register the SQL query for the provided SQLite Database Connection @@ -62,7 +60,7 @@ class Statement * * Exception is thrown in case of error, then the Statement object is NOT constructed. */ - Statement(Database& aDatabase, const char* apQuery); + Statement(const Database& aDatabase, const char* apQuery); /** * @brief Compile and register the SQL query for the provided SQLite Database Connection @@ -72,7 +70,7 @@ class Statement * * Exception is thrown in case of error, then the Statement object is NOT constructed. */ - Statement(Database &aDatabase, const std::string& aQuery) : + Statement(const Database& aDatabase, const std::string& aQuery) : Statement(aDatabase, aQuery.c_str()) {} @@ -82,6 +80,7 @@ class Statement * @param[in] aStatement Statement to move */ Statement(Statement&& aStatement) noexcept; + Statement& operator=(Statement&& aStatement) noexcept = default; // Statement is non-copyable Statement(const Statement&) = delete; @@ -123,39 +122,20 @@ class Statement // => if you know what you are doing, use bindNoCopy() instead of bind() SQLITECPP_PURE_FUNC - int getIndex(const char * const apName); + int getIndex(const char * const apName) const; /** * @brief Bind an int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const int aIndex, const int aValue); + void bind(const int aIndex, const int32_t aValue); /** * @brief Bind a 32bits unsigned int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const int aIndex, const unsigned aValue); - -#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW) - /** - * @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) - */ - void bind(const int aIndex, const long aValue) - { - bind(aIndex, static_cast(aValue)); - } -#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS) - /** - * @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) - */ - void bind(const int aIndex, const long aValue) - { - bind(aIndex, static_cast(aValue)); - } -#endif - + void bind(const int aIndex, const uint32_t aValue); /** * @brief Bind a 64bits int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const int aIndex, const long long aValue); + void bind(const int aIndex, const int64_t aValue); /** * @brief Bind a double (64bits float) value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ @@ -210,39 +190,21 @@ class Statement /** * @brief Bind an int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const char* apName, const int aValue) + void bind(const char* apName, const int32_t aValue) { bind(getIndex(apName), aValue); } /** * @brief Bind a 32bits unsigned int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const char* apName, const unsigned aValue) + void bind(const char* apName, const uint32_t aValue) { bind(getIndex(apName), aValue); } - -#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW) - /** - * @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) - */ - void bind(const char* apName, const long aValue) - { - bind(apName, static_cast(aValue)); - } -#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS) - /** - * @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) - */ - void bind(const char* apName, const long aValue) - { - bind(apName, static_cast(aValue)); - } -#endif /** * @brief Bind a 64bits int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const char* apName, const long long aValue) + void bind(const char* apName, const int64_t aValue) { bind(getIndex(apName), aValue); } @@ -325,46 +287,28 @@ class Statement /** * @brief Bind an int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const std::string& aName, const int aValue) + void bind(const std::string& aName, const int32_t aValue) { bind(aName.c_str(), aValue); } /** * @brief Bind a 32bits unsigned int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const std::string& aName, const unsigned aValue) + void bind(const std::string& aName, const uint32_t aValue) { bind(aName.c_str(), aValue); } - -#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW) - /** - * @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) - */ - void bind(const std::string& aName, const long aValue) - { - bind(aName.c_str(), static_cast(aValue)); - } -#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS) - /** - * @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) - */ - void bind(const std::string& aName, const long aValue) - { - bind(aName.c_str(), static_cast(aValue)); - } -#endif /** * @brief Bind a 64bits int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const std::string& aName, const long long aValue) + void bind(const std::string& aName, const int64_t aValue) { bind(aName.c_str(), aValue); } /** * @brief Bind a double (64bits float) value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1) */ - void bind(const std::string& aName, const double aValue) + void bind(const std::string& aName, const double aValue) { bind(aName.c_str(), aValue); } @@ -519,7 +463,7 @@ class Statement * Thus, you should instead extract immediately its data (getInt(), getText()...) * and use or copy this data for any later usage. */ - Column getColumn(const int aIndex); + Column getColumn(const int aIndex) const; /** * @brief Return a copy of the column data specified by its column name (less efficient than using an index) @@ -550,7 +494,7 @@ class Statement * * Throw an exception if the specified name is not one of the aliased name of the columns in the result. */ - Column getColumn(const char* apName); + Column getColumn(const char* apName) const; #if __cplusplus >= 201402L || (defined(_MSC_VER) && _MSC_VER >= 1900) // c++14: Visual Studio 2015 /** @@ -673,7 +617,7 @@ class Statement } // Return a UTF-8 string containing the SQL text of prepared statement with bound parameters expanded. - std::string getExpandedSQL(); + std::string getExpandedSQL() const; /// Return the number of columns in the result set returned by the prepared statement int getColumnCount() const @@ -701,52 +645,8 @@ class Statement /// Return UTF-8 encoded English language explanation of the most recent failed API call (if any). const char* getErrorMsg() const noexcept; -private: - /** - * @brief Shared pointer to the sqlite3_stmt SQLite Statement Object. - * - * Manage the finalization of the sqlite3_stmt with a reference counter. - * - * This is a internal class, not part of the API (hence full documentation is in the cpp). - */ - // TODO Convert this whole custom pointer to a C++11 std::shared_ptr with a custom deleter - class Ptr - { - public: - // Prepare the statement and initialize its reference counter - Ptr(sqlite3* apSQLite, std::string& aQuery); - // Copy constructor increments the ref counter - Ptr(const Ptr& aPtr); - - // Move constructor - Ptr(Ptr&& aPtr); - - // Decrement the ref counter and finalize the sqlite3_stmt when it reaches 0 - ~Ptr(); - - /// Inline cast operator returning the pointer to SQLite Database Connection Handle - operator sqlite3*() const - { - return mpSQLite; - } - - /// Inline cast operator returning the pointer to SQLite Statement Object - operator sqlite3_stmt*() const - { - return mpStmt; - } - - private: - /// @{ Unused/forbidden copy/assignment operator - Ptr& operator=(const Ptr& aPtr); - /// @} - - private: - sqlite3* mpSQLite; //!< Pointer to SQLite Database Connection Handle - sqlite3_stmt* mpStmt; //!< Pointer to SQLite Statement Object - unsigned int* mpRefCount; //!< Pointer to the heap allocated reference counter of the sqlite3_stmt - //!< (to share it with Column objects) - }; + /// Shared pointer to SQLite Prepared Statement Object + using TStatementPtr = std::shared_ptr; private: /** @@ -758,7 +658,7 @@ class Statement { if (SQLite::OK != aRet) { - throw SQLite::Exception(mStmtPtr, aRet); + throw SQLite::Exception(mpSQLite, aRet); } } @@ -784,17 +684,30 @@ class Statement } } -private: - /// Map of columns index by name (mutable so getColumnIndex can be const) - typedef std::map TColumnNames; + /** + * @brief Prepare statement object. + * + * @return Shared pointer to prepared statement object + */ + TStatementPtr prepareStatement(); -private: - std::string mQuery; //!< UTF-8 SQL Query - Ptr mStmtPtr; //!< Shared Pointer to the prepared SQLite Statement Object - int mColumnCount; //!< Number of columns in the result of the prepared statement - mutable TColumnNames mColumnNames; //!< Map of columns index by name (mutable so getColumnIndex can be const) - bool mbHasRow; //!< true when a row has been fetched with executeStep() - bool mbDone; //!< true when the last executeStep() had no more row to fetch + /** + * @brief Return a prepared statement object. + * + * Throw an exception if the statement object was not prepared. + * @return raw pointer to Prepared Statement Object + */ + sqlite3_stmt* getPreparedStatement() const; + + std::string mQuery; //!< UTF-8 SQL Query + sqlite3* mpSQLite; //!< Pointer to SQLite Database Connection Handle + TStatementPtr mpPreparedStatement; //!< Shared Pointer to the prepared SQLite Statement Object + int mColumnCount{0}; //!< Number of columns in the result of the prepared statement + bool mbHasRow{false}; //!< true when a row has been fetched with executeStep() + bool mbDone{false}; //!< true when the last executeStep() had no more row to fetch + + /// Map of columns index by name (mutable so getColumnIndex can be const) + mutable std::map mColumnNames{}; }; diff --git a/src/Column.cpp b/src/Column.cpp index 1a8005f2..f5dc0d98 100644 --- a/src/Column.cpp +++ b/src/Column.cpp @@ -26,30 +26,34 @@ const int Null = SQLITE_NULL; // Encapsulation of a Column in a row of the result pointed by the prepared Statement. -Column::Column(Statement::Ptr& aStmtPtr, int aIndex) noexcept : +Column::Column(const Statement::TStatementPtr& aStmtPtr, int aIndex) : mStmtPtr(aStmtPtr), mIndex(aIndex) { + if (!aStmtPtr) + { + throw SQLite::Exception("Statement was destroyed"); + } } // Return the named assigned to this result column (potentially aliased) const char* Column::getName() const noexcept { - return sqlite3_column_name(mStmtPtr, mIndex); + return sqlite3_column_name(mStmtPtr.get(), mIndex); } #ifdef SQLITE_ENABLE_COLUMN_METADATA // Return the name of the table column that is the origin of this result column const char* Column::getOriginName() const noexcept { - return sqlite3_column_origin_name(mStmtPtr, mIndex); + return sqlite3_column_origin_name(mStmtPtr.get(), mIndex); } #endif // Return the integer value of the column specified by its index starting at 0 int Column::getInt() const noexcept { - return sqlite3_column_int(mStmtPtr, mIndex); + return sqlite3_column_int(mStmtPtr.get(), mIndex); } // Return the unsigned integer value of the column specified by its index starting at 0 @@ -61,26 +65,26 @@ unsigned Column::getUInt() const noexcept // Return the 64bits integer value of the column specified by its index starting at 0 long long Column::getInt64() const noexcept { - return sqlite3_column_int64(mStmtPtr, mIndex); + return sqlite3_column_int64(mStmtPtr.get(), mIndex); } // Return the double value of the column specified by its index starting at 0 double Column::getDouble() const noexcept { - return sqlite3_column_double(mStmtPtr, mIndex); + return sqlite3_column_double(mStmtPtr.get(), mIndex); } // Return a pointer to the text value (NULL terminated string) of the column specified by its index starting at 0 const char* Column::getText(const char* apDefaultValue /* = "" */) const noexcept { - const char* pText = reinterpret_cast(sqlite3_column_text(mStmtPtr, mIndex)); + auto pText = reinterpret_cast(sqlite3_column_text(mStmtPtr.get(), mIndex)); return (pText?pText:apDefaultValue); } // Return a pointer to the blob value (*not* NULL terminated) of the column specified by its index starting at 0 const void* Column::getBlob() const noexcept { - return sqlite3_column_blob(mStmtPtr, mIndex); + return sqlite3_column_blob(mStmtPtr.get(), mIndex); } // Return a std::string to a TEXT or BLOB column @@ -88,23 +92,23 @@ std::string Column::getString() const { // Note: using sqlite3_column_blob and not sqlite3_column_text // - no need for sqlite3_column_text to add a \0 on the end, as we're getting the bytes length directly - const char *data = static_cast(sqlite3_column_blob(mStmtPtr, mIndex)); + auto data = static_cast(sqlite3_column_blob(mStmtPtr.get(), mIndex)); // SQLite docs: "The safest policy is to invokeā€¦ sqlite3_column_blob() followed by sqlite3_column_bytes()" // Note: std::string is ok to pass nullptr as first arg, if length is 0 - return std::string(data, sqlite3_column_bytes(mStmtPtr, mIndex)); + return std::string(data, sqlite3_column_bytes(mStmtPtr.get(), mIndex)); } // Return the type of the value of the column int Column::getType() const noexcept { - return sqlite3_column_type(mStmtPtr, mIndex); + return sqlite3_column_type(mStmtPtr.get(), mIndex); } // Return the number of bytes used by the text value of the column int Column::getBytes() const noexcept { - return sqlite3_column_bytes(mStmtPtr, mIndex); + return sqlite3_column_bytes(mStmtPtr.get(), mIndex); } // Standard std::ostream inserter diff --git a/src/Statement.cpp b/src/Statement.cpp index 0b9b3e5b..64e16d2d 100644 --- a/src/Statement.cpp +++ b/src/Statement.cpp @@ -20,23 +20,24 @@ namespace SQLite { -Statement::Statement(Database &aDatabase, const char* apQuery) : +Statement::Statement(const Database& aDatabase, const char* apQuery) : mQuery(apQuery), - mStmtPtr(aDatabase.getHandle(), mQuery), // prepare the SQL query, and ref count (needs Database friendship) - mColumnCount(0), - mbHasRow(false), - mbDone(false) + mpSQLite(aDatabase.getHandle()), + mpPreparedStatement(prepareStatement()) // prepare the SQL query (needs Database friendship) { - mColumnCount = sqlite3_column_count(mStmtPtr); + mColumnCount = sqlite3_column_count(mpPreparedStatement.get()); } Statement::Statement(Statement&& aStatement) noexcept : mQuery(std::move(aStatement.mQuery)), - mStmtPtr(std::move(aStatement.mStmtPtr)), + mpSQLite(aStatement.mpSQLite), + mpPreparedStatement(std::move(aStatement.mpPreparedStatement)), mColumnCount(aStatement.mColumnCount), mbHasRow(aStatement.mbHasRow), - mbDone(aStatement.mbDone) + mbDone(aStatement.mbDone), + mColumnNames(std::move(aStatement.mColumnNames)) { + aStatement.mpSQLite = nullptr; aStatement.mColumnCount = 0; aStatement.mbHasRow = false; aStatement.mbDone = false; @@ -53,53 +54,53 @@ int Statement::tryReset() noexcept { mbHasRow = false; mbDone = false; - return sqlite3_reset(mStmtPtr); + return sqlite3_reset(mpPreparedStatement.get()); } // Clears away all the bindings of a prepared statement (can be associated with #reset() above). void Statement::clearBindings() { - const int ret = sqlite3_clear_bindings(mStmtPtr); + const int ret = sqlite3_clear_bindings(getPreparedStatement()); check(ret); } -int Statement::getIndex(const char * const apName) +int Statement::getIndex(const char * const apName) const { - return sqlite3_bind_parameter_index(mStmtPtr, apName); + return sqlite3_bind_parameter_index(getPreparedStatement(), apName); } -// Bind an int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement -void Statement::bind(const int aIndex, const int aValue) +// Bind an 32bits int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement +void Statement::bind(const int aIndex, const int32_t aValue) { - const int ret = sqlite3_bind_int(mStmtPtr, aIndex, aValue); + const int ret = sqlite3_bind_int(getPreparedStatement(), aIndex, aValue); check(ret); } // Bind a 32bits unsigned int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement -void Statement::bind(const int aIndex, const unsigned aValue) +void Statement::bind(const int aIndex, const uint32_t aValue) { - const int ret = sqlite3_bind_int64(mStmtPtr, aIndex, aValue); + const int ret = sqlite3_bind_int64(getPreparedStatement(), aIndex, aValue); check(ret); } // Bind a 64bits int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement -void Statement::bind(const int aIndex, const long long aValue) +void Statement::bind(const int aIndex, const int64_t aValue) { - const int ret = sqlite3_bind_int64(mStmtPtr, aIndex, aValue); + const int ret = sqlite3_bind_int64(getPreparedStatement(), aIndex, aValue); check(ret); } // Bind a double (64bits float) value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bind(const int aIndex, const double aValue) { - const int ret = sqlite3_bind_double(mStmtPtr, aIndex, aValue); + const int ret = sqlite3_bind_double(getPreparedStatement(), aIndex, aValue); check(ret); } // Bind a string value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bind(const int aIndex, const std::string& aValue) { - const int ret = sqlite3_bind_text(mStmtPtr, aIndex, aValue.c_str(), + const int ret = sqlite3_bind_text(getPreparedStatement(), aIndex, aValue.c_str(), static_cast(aValue.size()), SQLITE_TRANSIENT); check(ret); } @@ -107,21 +108,21 @@ void Statement::bind(const int aIndex, const std::string& aValue) // Bind a text value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bind(const int aIndex, const char* apValue) { - const int ret = sqlite3_bind_text(mStmtPtr, aIndex, apValue, -1, SQLITE_TRANSIENT); + const int ret = sqlite3_bind_text(getPreparedStatement(), aIndex, apValue, -1, SQLITE_TRANSIENT); check(ret); } // Bind a binary blob value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bind(const int aIndex, const void* apValue, const int aSize) { - const int ret = sqlite3_bind_blob(mStmtPtr, aIndex, apValue, aSize, SQLITE_TRANSIENT); + const int ret = sqlite3_bind_blob(getPreparedStatement(), aIndex, apValue, aSize, SQLITE_TRANSIENT); check(ret); } // Bind a string value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bindNoCopy(const int aIndex, const std::string& aValue) { - const int ret = sqlite3_bind_text(mStmtPtr, aIndex, aValue.c_str(), + const int ret = sqlite3_bind_text(getPreparedStatement(), aIndex, aValue.c_str(), static_cast(aValue.size()), SQLITE_STATIC); check(ret); } @@ -129,21 +130,21 @@ void Statement::bindNoCopy(const int aIndex, const std::string& aValue) // Bind a text value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bindNoCopy(const int aIndex, const char* apValue) { - const int ret = sqlite3_bind_text(mStmtPtr, aIndex, apValue, -1, SQLITE_STATIC); + const int ret = sqlite3_bind_text(getPreparedStatement(), aIndex, apValue, -1, SQLITE_STATIC); check(ret); } // Bind a binary blob value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bindNoCopy(const int aIndex, const void* apValue, const int aSize) { - const int ret = sqlite3_bind_blob(mStmtPtr, aIndex, apValue, aSize, SQLITE_STATIC); + const int ret = sqlite3_bind_blob(getPreparedStatement(), aIndex, apValue, aSize, SQLITE_STATIC); check(ret); } // Bind a NULL value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement void Statement::bind(const int aIndex) { - const int ret = sqlite3_bind_null(mStmtPtr, aIndex); + const int ret = sqlite3_bind_null(getPreparedStatement(), aIndex); check(ret); } @@ -154,9 +155,9 @@ bool Statement::executeStep() const int ret = tryExecuteStep(); if ((SQLITE_ROW != ret) && (SQLITE_DONE != ret)) // on row or no (more) row ready, else it's a problem { - if (ret == sqlite3_errcode(mStmtPtr)) + if (ret == sqlite3_errcode(mpSQLite)) { - throw SQLite::Exception(mStmtPtr, ret); + throw SQLite::Exception(mpSQLite, ret); } else { @@ -177,9 +178,9 @@ int Statement::exec() { throw SQLite::Exception("exec() does not expect results. Use executeStep."); } - else if (ret == sqlite3_errcode(mStmtPtr)) + else if (ret == sqlite3_errcode(mpSQLite)) { - throw SQLite::Exception(mStmtPtr, ret); + throw SQLite::Exception(mpSQLite, ret); } else { @@ -188,59 +189,50 @@ int Statement::exec() } // Return the number of rows modified by those SQL statements (INSERT, UPDATE or DELETE) - return sqlite3_changes(mStmtPtr); + return sqlite3_changes(mpSQLite); } int Statement::tryExecuteStep() noexcept { - if (false == mbDone) + if (mbDone) { - const int ret = sqlite3_step(mStmtPtr); - if (SQLITE_ROW == ret) // one row is ready : call getColumn(N) to access it - { - mbHasRow = true; - } - else if (SQLITE_DONE == ret) // no (more) row ready : the query has finished executing - { - mbHasRow = false; - mbDone = true; - } - else - { - mbHasRow = false; - mbDone = false; - } + return SQLITE_MISUSE; // Statement needs to be reseted ! + } - return ret; + const int ret = sqlite3_step(mpPreparedStatement.get()); + if (SQLITE_ROW == ret) // one row is ready : call getColumn(N) to access it + { + mbHasRow = true; } else { - // Statement needs to be reseted ! - return SQLITE_MISUSE; + mbHasRow = false; + mbDone = SQLITE_DONE == ret; // check if the query has finished executing } + return ret; } // Return a copy of the column data specified by its index starting at 0 // (use the Column copy-constructor) -Column Statement::getColumn(const int aIndex) +Column Statement::getColumn(const int aIndex) const { checkRow(); checkIndex(aIndex); // Share the Statement Object handle with the new Column created - return Column(mStmtPtr, aIndex); + return Column(mpPreparedStatement, aIndex); } // Return a copy of the column data specified by its column name starting at 0 // (use the Column copy-constructor) -Column Statement::getColumn(const char* apName) +Column Statement::getColumn(const char* apName) const { checkRow(); const int index = getColumnIndex(apName); // Share the Statement Object handle with the new Column created - return Column(mStmtPtr, index); + return Column(mpPreparedStatement, index); } // Test if the column is NULL @@ -248,21 +240,21 @@ bool Statement::isColumnNull(const int aIndex) const { checkRow(); checkIndex(aIndex); - return (SQLITE_NULL == sqlite3_column_type(mStmtPtr, aIndex)); + return (SQLITE_NULL == sqlite3_column_type(getPreparedStatement(), aIndex)); } bool Statement::isColumnNull(const char* apName) const { checkRow(); const int index = getColumnIndex(apName); - return (SQLITE_NULL == sqlite3_column_type(mStmtPtr, index)); + return (SQLITE_NULL == sqlite3_column_type(getPreparedStatement(), index)); } // Return the named assigned to the specified result column (potentially aliased) const char* Statement::getColumnName(const int aIndex) const { checkIndex(aIndex); - return sqlite3_column_name(mStmtPtr, aIndex); + return sqlite3_column_name(getPreparedStatement(), aIndex); } #ifdef SQLITE_ENABLE_COLUMN_METADATA @@ -270,7 +262,7 @@ const char* Statement::getColumnName(const int aIndex) const const char* Statement::getColumnOriginName(const int aIndex) const { checkIndex(aIndex); - return sqlite3_column_origin_name(mStmtPtr, aIndex); + return sqlite3_column_origin_name(getPreparedStatement(), aIndex); } #endif @@ -282,24 +274,24 @@ int Statement::getColumnIndex(const char* apName) const { for (int i = 0; i < mColumnCount; ++i) { - const char* pName = sqlite3_column_name(mStmtPtr, i); + const char* pName = sqlite3_column_name(getPreparedStatement(), i); mColumnNames[pName] = i; } } - const TColumnNames::const_iterator iIndex = mColumnNames.find(apName); + const auto iIndex = mColumnNames.find(apName); if (iIndex == mColumnNames.end()) { throw SQLite::Exception("Unknown column name."); } - return (*iIndex).second; + return iIndex->second; } const char * Statement::getColumnDeclaredType(const int aIndex) const { checkIndex(aIndex); - const char * result = sqlite3_column_decltype(mStmtPtr, aIndex); + const char * result = sqlite3_column_decltype(getPreparedStatement(), aIndex); if (!result) { throw SQLite::Exception("Could not determine declared column type."); @@ -313,119 +305,64 @@ const char * Statement::getColumnDeclaredType(const int aIndex) const // Get number of rows modified by last INSERT, UPDATE or DELETE statement (not DROP table). int Statement::getChanges() const noexcept { - return sqlite3_changes(mStmtPtr); + return sqlite3_changes(mpSQLite); } int Statement::getBindParameterCount() const noexcept { - return sqlite3_bind_parameter_count(mStmtPtr); + return sqlite3_bind_parameter_count(mpPreparedStatement.get()); } // Return the numeric result code for the most recent failed API call (if any). int Statement::getErrorCode() const noexcept { - return sqlite3_errcode(mStmtPtr); + return sqlite3_errcode(mpSQLite); } // Return the extended numeric result code for the most recent failed API call (if any). int Statement::getExtendedErrorCode() const noexcept { - return sqlite3_extended_errcode(mStmtPtr); + return sqlite3_extended_errcode(mpSQLite); } // Return UTF-8 encoded English language explanation of the most recent failed API call (if any). const char* Statement::getErrorMsg() const noexcept { - return sqlite3_errmsg(mStmtPtr); + return sqlite3_errmsg(mpSQLite); } // Return a UTF-8 string containing the SQL text of prepared statement with bound parameters expanded. -std::string Statement::getExpandedSQL() { - char* expanded = sqlite3_expanded_sql(mStmtPtr); +std::string Statement::getExpandedSQL() const { + char* expanded = sqlite3_expanded_sql(getPreparedStatement()); std::string expandedString(expanded); sqlite3_free(expanded); return expandedString; } -//////////////////////////////////////////////////////////////////////////////// -// Internal class : shared pointer to the sqlite3_stmt SQLite Statement Object -//////////////////////////////////////////////////////////////////////////////// - -/** - * @brief Prepare the statement and initialize its reference counter - * - * @param[in] apSQLite The sqlite3 database connexion - * @param[in] aQuery The SQL query string to prepare - */ -Statement::Ptr::Ptr(sqlite3* apSQLite, std::string& aQuery) : - mpSQLite(apSQLite), - mpStmt(NULL), - mpRefCount(NULL) +// Prepare SQLite statement object and return shared pointer to this object +Statement::TStatementPtr Statement::prepareStatement() { - const int ret = sqlite3_prepare_v2(apSQLite, aQuery.c_str(), static_cast(aQuery.size()), &mpStmt, NULL); + sqlite3_stmt* statement; + const int ret = sqlite3_prepare_v2(mpSQLite, mQuery.c_str(), static_cast(mQuery.size()), &statement, nullptr); if (SQLITE_OK != ret) { - throw SQLite::Exception(apSQLite, ret); + throw SQLite::Exception(mpSQLite, ret); } - // Initialize the reference counter of the sqlite3_stmt : - // used to share the mStmtPtr between Statement and Column objects; - // This is needed to enable Column objects to live longer than the Statement objet it refers to. - mpRefCount = new unsigned int(1); // NOLINT(readability/casting) -} - -/** - * @brief Copy constructor increments the ref counter - * - * @param[in] aPtr Pointer to copy - */ -Statement::Ptr::Ptr(const Statement::Ptr& aPtr) : - mpSQLite(aPtr.mpSQLite), - mpStmt(aPtr.mpStmt), - mpRefCount(aPtr.mpRefCount) -{ - assert(mpRefCount); - assert(0 != *mpRefCount); - - // Increment the reference counter of the sqlite3_stmt, - // asking not to finalize the sqlite3_stmt during the lifetime of the new objet - ++(*mpRefCount); -} - -Statement::Ptr::Ptr(Ptr&& aPtr) : - mpSQLite(aPtr.mpSQLite), - mpStmt(aPtr.mpStmt), - mpRefCount(aPtr.mpRefCount) -{ - aPtr.mpSQLite = nullptr; - aPtr.mpStmt = nullptr; - aPtr.mpRefCount = nullptr; + return Statement::TStatementPtr(statement, [](sqlite3_stmt* stmt) + { + sqlite3_finalize(stmt); + }); } -/** - * @brief Decrement the ref counter and finalize the sqlite3_stmt when it reaches 0 - */ -Statement::Ptr::~Ptr() +// Return prepered statement object or throw +sqlite3_stmt* Statement::getPreparedStatement() const { - if (mpRefCount) + sqlite3_stmt* ret = mpPreparedStatement.get(); + if (ret) { - assert(0 != *mpRefCount); - - // Decrement and check the reference counter of the sqlite3_stmt - --(*mpRefCount); - if (0 == *mpRefCount) - { - // If count reaches zero, finalize the sqlite3_stmt, as no Statement nor Column objet use it anymore. - // No need to check the return code, as it is the same as the last statement evaluation. - sqlite3_finalize(mpStmt); - - // and delete the reference counter - delete mpRefCount; - mpRefCount = nullptr; - mpStmt = nullptr; - } - // else, the finalization will be done later, by the last object + return ret; } + throw SQLite::Exception("Statement was not prepared."); } - } // namespace SQLite diff --git a/tests/Column_test.cpp b/tests/Column_test.cpp index bf31c23c..adc26bd4 100644 --- a/tests/Column_test.cpp +++ b/tests/Column_test.cpp @@ -241,3 +241,39 @@ TEST(Column, stream) std::string content = ss.str(); EXPECT_EQ(content, str); } + +TEST(Column, shared_ptr) +{ + // Create a new database + SQLite::Database db(":memory:", SQLite::OPEN_READWRITE|SQLite::OPEN_CREATE); + EXPECT_EQ(0, db.exec("CREATE TABLE test (id INTEGER PRIMARY KEY, msg TEXT)")); + EXPECT_EQ(1, db.exec(R"(INSERT INTO test VALUES (42, "fortytwo"))")); + const char* query_str = "SELECT id, msg FROM test"; + + std::unique_ptr query{ new SQLite::Statement(db, query_str) }; + query->executeStep(); + + auto column0 = query->getColumn(0); + auto column1 = query->getColumn(1); + query.reset(); + + EXPECT_EQ(42, column0.getInt()); + EXPECT_STREQ("fortytwo", column1.getText()); + + query.reset(new SQLite::Statement(db, query_str)); + query->executeStep(); + column0 = query->getColumn(0); + EXPECT_EQ(true, column0.isInteger()); + query->executeStep(); // query is done + + // Undefined behavior + // auto x = column0.getInt(); + + query.reset(); + + // Undefined behavior + // auto x = column0.getInt(); + // bool isInt = column0.isInteger(); + + EXPECT_STREQ("id", column0.getName()); +} diff --git a/tests/Statement_test.cpp b/tests/Statement_test.cpp index fbaec4b3..8a79a18f 100644 --- a/tests/Statement_test.cpp +++ b/tests/Statement_test.cpp @@ -119,7 +119,6 @@ TEST(Statement, moveConstructor) EXPECT_EQ(2, query.getColumnCount()); SQLite::Statement moved = std::move(query); EXPECT_TRUE(query.getQuery().empty()); - EXPECT_EQ(0, query.getColumnCount()); EXPECT_FALSE(moved.getQuery().empty()); EXPECT_EQ(2, moved.getColumnCount()); // Execute @@ -128,6 +127,16 @@ TEST(Statement, moveConstructor) EXPECT_FALSE(moved.isDone()); EXPECT_FALSE(query.hasRow()); EXPECT_FALSE(query.isDone()); + + // Const statement lookup + const auto const_query = std::move(moved); + auto index = const_query.getColumnIndex("value"); + EXPECT_EQ(1, index); + EXPECT_NO_THROW(const_query.getColumn(index)); + + // Moved statements should throw + EXPECT_THROW(query.getColumnIndex("value"), SQLite::Exception); + EXPECT_THROW(query.getColumn(index), SQLite::Exception); } #endif