Skip to content

Commit

Permalink
Merge pull request #349 from Kacperos155/refactoring-Statement&Column
Browse files Browse the repository at this point in the history
Refactoring of Statement and Column classes
  • Loading branch information
SRombauts authored Mar 29, 2022
2 parents 454a2e2 + c5b3aa8 commit 9158225
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 289 deletions.
12 changes: 7 additions & 5 deletions include/SQLiteCpp/Column.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@
#include <SQLiteCpp/Exception.h>

#include <string>
#include <memory>
#include <climits> // For INT_MAX

// Forward declarations to avoid inclusion of <sqlite3.h> in a header
struct sqlite3_stmt;

namespace SQLite
{
Expand All @@ -26,7 +29,6 @@ extern const int TEXT; ///< SQLITE_TEXT
extern const int BLOB; ///< SQLITE_BLOB
extern const int Null; ///< SQLITE_NULL


/**
* @brief Encapsulation of a Column in a row of the result pointed by the prepared Statement.
*
Expand All @@ -52,7 +54,7 @@ class Column
* @param[in] aStmtPtr Shared pointer to the prepared SQLite Statement Object.
* @param[in] aIndex Index of the column in the row of result, starting at 0
*/
Column(Statement::Ptr& aStmtPtr, int aIndex) noexcept;
explicit Column(const Statement::TStatementPtr& aStmtPtr, int aIndex);

// default destructor: the finalization will be done by the destructor of the last shared pointer
// default copy constructor and assignment operator are perfectly suited :
Expand Down Expand Up @@ -250,8 +252,8 @@ class Column
}

private:
Statement::Ptr mStmtPtr; ///< Shared Pointer to the prepared SQLite Statement Object
int mIndex; ///< Index of the column in the row of result, starting at 0
Statement::TStatementPtr mStmtPtr; ///< Shared Pointer to the prepared SQLite Statement Object
int mIndex; ///< Index of the column in the row of result, starting at 0
};

/**
Expand Down Expand Up @@ -281,7 +283,7 @@ T Statement::getColumns()
template<typename T, const int... Is>
T Statement::getColumns(const std::integer_sequence<int, Is...>)
{
return T{Column(mStmtPtr, Is)...};
return T{Column(mpPreparedStatement, Is)...};
}

#endif
Expand Down
175 changes: 44 additions & 131 deletions include/SQLiteCpp/Statement.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

#include <string>
#include <map>
#include <climits> // For INT_MAX
#include <memory>

// Forward declarations to avoid inclusion of <sqlite3.h> in a header
struct sqlite3;
Expand Down Expand Up @@ -51,8 +51,6 @@ extern const int OK; ///< SQLITE_OK
*/
class Statement
{
friend class Column; // For access to Statement::Ptr inner class

public:
/**
* @brief Compile and register the SQL query for the provided SQLite Database Connection
Expand All @@ -62,7 +60,7 @@ class Statement
*
* Exception is thrown in case of error, then the Statement object is NOT constructed.
*/
Statement(Database& aDatabase, const char* apQuery);
Statement(const Database& aDatabase, const char* apQuery);

/**
* @brief Compile and register the SQL query for the provided SQLite Database Connection
Expand All @@ -72,7 +70,7 @@ class Statement
*
* Exception is thrown in case of error, then the Statement object is NOT constructed.
*/
Statement(Database &aDatabase, const std::string& aQuery) :
Statement(const Database& aDatabase, const std::string& aQuery) :
Statement(aDatabase, aQuery.c_str())
{}

Expand All @@ -82,6 +80,7 @@ class Statement
* @param[in] aStatement Statement to move
*/
Statement(Statement&& aStatement) noexcept;
Statement& operator=(Statement&& aStatement) noexcept = default;

// Statement is non-copyable
Statement(const Statement&) = delete;
Expand Down Expand Up @@ -123,39 +122,20 @@ class Statement
// => if you know what you are doing, use bindNoCopy() instead of bind()

SQLITECPP_PURE_FUNC
int getIndex(const char * const apName);
int getIndex(const char * const apName) const;

/**
* @brief Bind an int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const int aValue);
void bind(const int aIndex, const int32_t aValue);
/**
* @brief Bind a 32bits unsigned int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const unsigned aValue);

#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW)
/**
* @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const long aValue)
{
bind(aIndex, static_cast<int>(aValue));
}
#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS)
/**
* @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const long aValue)
{
bind(aIndex, static_cast<long long>(aValue));
}
#endif

void bind(const int aIndex, const uint32_t aValue);
/**
* @brief Bind a 64bits int value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const int aIndex, const long long aValue);
void bind(const int aIndex, const int64_t aValue);
/**
* @brief Bind a double (64bits float) value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
Expand Down Expand Up @@ -210,39 +190,21 @@ class Statement
/**
* @brief Bind an int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const int aValue)
void bind(const char* apName, const int32_t aValue)
{
bind(getIndex(apName), aValue);
}
/**
* @brief Bind a 32bits unsigned int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const unsigned aValue)
void bind(const char* apName, const uint32_t aValue)
{
bind(getIndex(apName), aValue);
}

#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW)
/**
* @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const long aValue)
{
bind(apName, static_cast<int>(aValue));
}
#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS)
/**
* @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const long aValue)
{
bind(apName, static_cast<long long>(aValue));
}
#endif
/**
* @brief Bind a 64bits int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const char* apName, const long long aValue)
void bind(const char* apName, const int64_t aValue)
{
bind(getIndex(apName), aValue);
}
Expand Down Expand Up @@ -325,46 +287,28 @@ class Statement
/**
* @brief Bind an int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const int aValue)
void bind(const std::string& aName, const int32_t aValue)
{
bind(aName.c_str(), aValue);
}
/**
* @brief Bind a 32bits unsigned int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const unsigned aValue)
void bind(const std::string& aName, const uint32_t aValue)
{
bind(aName.c_str(), aValue);
}

#if (LONG_MAX == INT_MAX) // 4 bytes "long" type means the data model is ILP32 or LLP64 (Win64 Visual C++ and MinGW)
/**
* @brief Bind a 32bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const long aValue)
{
bind(aName.c_str(), static_cast<int>(aValue));
}
#else // 8 bytes "long" type means the data model is LP64 (Most Unix-like, Windows when using Cygwin; z/OS)
/**
* @brief Bind a 64bits long value to a parameter "?", "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const long aValue)
{
bind(aName.c_str(), static_cast<long long>(aValue));
}
#endif
/**
* @brief Bind a 64bits int value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const long long aValue)
void bind(const std::string& aName, const int64_t aValue)
{
bind(aName.c_str(), aValue);
}
/**
* @brief Bind a double (64bits float) value to a named parameter "?NNN", ":VVV", "@VVV" or "$VVV" in the SQL prepared statement (aIndex >= 1)
*/
void bind(const std::string& aName, const double aValue)
void bind(const std::string& aName, const double aValue)
{
bind(aName.c_str(), aValue);
}
Expand Down Expand Up @@ -519,7 +463,7 @@ class Statement
* Thus, you should instead extract immediately its data (getInt(), getText()...)
* and use or copy this data for any later usage.
*/
Column getColumn(const int aIndex);
Column getColumn(const int aIndex) const;

/**
* @brief Return a copy of the column data specified by its column name (less efficient than using an index)
Expand Down Expand Up @@ -550,7 +494,7 @@ class Statement
*
* Throw an exception if the specified name is not one of the aliased name of the columns in the result.
*/
Column getColumn(const char* apName);
Column getColumn(const char* apName) const;

#if __cplusplus >= 201402L || (defined(_MSC_VER) && _MSC_VER >= 1900) // c++14: Visual Studio 2015
/**
Expand Down Expand Up @@ -673,7 +617,7 @@ class Statement
}

// Return a UTF-8 string containing the SQL text of prepared statement with bound parameters expanded.
std::string getExpandedSQL();
std::string getExpandedSQL() const;

/// Return the number of columns in the result set returned by the prepared statement
int getColumnCount() const
Expand Down Expand Up @@ -701,52 +645,8 @@ class Statement
/// Return UTF-8 encoded English language explanation of the most recent failed API call (if any).
const char* getErrorMsg() const noexcept;

private:
/**
* @brief Shared pointer to the sqlite3_stmt SQLite Statement Object.
*
* Manage the finalization of the sqlite3_stmt with a reference counter.
*
* This is a internal class, not part of the API (hence full documentation is in the cpp).
*/
// TODO Convert this whole custom pointer to a C++11 std::shared_ptr with a custom deleter
class Ptr
{
public:
// Prepare the statement and initialize its reference counter
Ptr(sqlite3* apSQLite, std::string& aQuery);
// Copy constructor increments the ref counter
Ptr(const Ptr& aPtr);

// Move constructor
Ptr(Ptr&& aPtr);

// Decrement the ref counter and finalize the sqlite3_stmt when it reaches 0
~Ptr();

/// Inline cast operator returning the pointer to SQLite Database Connection Handle
operator sqlite3*() const
{
return mpSQLite;
}

/// Inline cast operator returning the pointer to SQLite Statement Object
operator sqlite3_stmt*() const
{
return mpStmt;
}

private:
/// @{ Unused/forbidden copy/assignment operator
Ptr& operator=(const Ptr& aPtr);
/// @}

private:
sqlite3* mpSQLite; //!< Pointer to SQLite Database Connection Handle
sqlite3_stmt* mpStmt; //!< Pointer to SQLite Statement Object
unsigned int* mpRefCount; //!< Pointer to the heap allocated reference counter of the sqlite3_stmt
//!< (to share it with Column objects)
};
/// Shared pointer to SQLite Prepared Statement Object
using TStatementPtr = std::shared_ptr<sqlite3_stmt>;

private:
/**
Expand All @@ -758,7 +658,7 @@ class Statement
{
if (SQLite::OK != aRet)
{
throw SQLite::Exception(mStmtPtr, aRet);
throw SQLite::Exception(mpSQLite, aRet);
}
}

Expand All @@ -784,17 +684,30 @@ class Statement
}
}

private:
/// Map of columns index by name (mutable so getColumnIndex can be const)
typedef std::map<std::string, int> TColumnNames;
/**
* @brief Prepare statement object.
*
* @return Shared pointer to prepared statement object
*/
TStatementPtr prepareStatement();

private:
std::string mQuery; //!< UTF-8 SQL Query
Ptr mStmtPtr; //!< Shared Pointer to the prepared SQLite Statement Object
int mColumnCount; //!< Number of columns in the result of the prepared statement
mutable TColumnNames mColumnNames; //!< Map of columns index by name (mutable so getColumnIndex can be const)
bool mbHasRow; //!< true when a row has been fetched with executeStep()
bool mbDone; //!< true when the last executeStep() had no more row to fetch
/**
* @brief Return a prepared statement object.
*
* Throw an exception if the statement object was not prepared.
* @return raw pointer to Prepared Statement Object
*/
sqlite3_stmt* getPreparedStatement() const;

std::string mQuery; //!< UTF-8 SQL Query
sqlite3* mpSQLite; //!< Pointer to SQLite Database Connection Handle
TStatementPtr mpPreparedStatement; //!< Shared Pointer to the prepared SQLite Statement Object
int mColumnCount{0}; //!< Number of columns in the result of the prepared statement
bool mbHasRow{false}; //!< true when a row has been fetched with executeStep()
bool mbDone{false}; //!< true when the last executeStep() had no more row to fetch

/// Map of columns index by name (mutable so getColumnIndex can be const)
mutable std::map<std::string, int> mColumnNames{};
};


Expand Down
Loading

0 comments on commit 9158225

Please sign in to comment.