diff --git a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters
index 2df7b1de94..d7db2a17cd 100644
--- a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters
+++ b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters
@@ -27,7 +27,16 @@
{52b26063-ccff-40ae-a420-243327dcca25}
-
+
+
+ {b35e1a8b-1961-46d5-98e5-adc8e52c54a9}
+
+
+ {d055e02a-59c9-4ae4-9405-579dac6ff59b}
+
+
+ {13d4d227-0f04-4e57-a663-c3c535438ab3}
+
@@ -59,146 +68,146 @@
Source Files
-
- Source Files
-
-
- Source Files
-
Source Files
-
+
Source Files
-
+
Source Files
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\CLI
-
- Source Files
+
+ Source Files\CLI
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\CLI
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\CLI
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
- Source Files
+ Source Files\Common
- Source Files
+ Source Files\CLI
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Repository
- Source Files
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Repository
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
+
Source Files
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\Common
-
- Source Files
+
+ Source Files\CLI
+
+
+ Source Files\CLI
+
+
+ Source Files\Common
@@ -479,7 +488,7 @@
TestData
-
+
TestData
@@ -506,7 +515,7 @@
TestData
-
+
TestData
diff --git a/src/AppInstallerCLITests/CompositeSource.cpp b/src/AppInstallerCLITests/CompositeSource.cpp
index e1eb0865a3..273fd27571 100644
--- a/src/AppInstallerCLITests/CompositeSource.cpp
+++ b/src/AppInstallerCLITests/CompositeSource.cpp
@@ -50,11 +50,11 @@ struct ComponentTestSource : public TestSource
// A helper to create the sources used by the majority of tests in this file.
struct CompositeTestSetup
{
- CompositeTestSetup() : Composite("*Tests")
+ CompositeTestSetup(CompositeSearchBehavior behavior = CompositeSearchBehavior::Installed) : Composite("*Tests")
{
Installed = std::make_shared("InstalledTestSource1");
Available = std::make_shared("AvailableTestSource1");
- Composite.SetInstalledSource(Source{ Installed });
+ Composite.SetInstalledSource(Source{ Installed }, behavior);
Composite.AddAvailableSource(Source{ Available });
}
@@ -942,3 +942,23 @@ TEST_CASE("CompositeSource_TrackingFound_NotInstalled", "[CompositeSource]")
REQUIRE(result.Matches.empty());
}
+
+TEST_CASE("CompositeSource_NullInstalledVersion", "[CompositeSource]")
+{
+ CompositeTestSetup setup;
+ setup.Installed->Everything.Matches.emplace_back(MakeAvailable(), Criteria());
+
+ // We are mostly testing to see if a null installed version causes an AV or not
+ SearchResult result = setup.Search();
+ REQUIRE(result.Matches.size() == 0);
+}
+
+TEST_CASE("CompositeSource_NullAvailableVersion", "[CompositeSource]")
+{
+ CompositeTestSetup setup{ CompositeSearchBehavior::AvailablePackages };
+ setup.Available->Everything.Matches.emplace_back(MakeInstalled(), Criteria());
+
+ // We are mostly testing to see if a null available version causes an AV or not
+ SearchResult result = setup.Search();
+ REQUIRE(result.Matches.size() == 1);
+}
diff --git a/src/AppInstallerCLITests/SQLiteIndex.cpp b/src/AppInstallerCLITests/SQLiteIndex.cpp
index f63c56bb4e..63fbf627ad 100644
--- a/src/AppInstallerCLITests/SQLiteIndex.cpp
+++ b/src/AppInstallerCLITests/SQLiteIndex.cpp
@@ -3153,4 +3153,31 @@ TEST_CASE("SQLiteIndex_ManifestArpVersion_ValidateManifestAgainstIndex", "[sqlit
// Add different version should result in failure.
manifest.Version = "10.1";
REQUIRE_THROWS(ValidateManifestArpVersion(&index, manifest));
-}
\ No newline at end of file
+}
+
+TEST_CASE("SQLiteIndex_CheckConsistency_FindEmbeddedNull", "[sqliteindex]")
+{
+ TempFile tempFile{ "repolibtest_tempdb"s, ".db"s };
+ INFO("Using temporary file named: " << tempFile.GetPath());
+
+ SQLiteIndex index = CreateTestIndex(tempFile, Schema::Version::Latest());
+
+ Manifest manifest;
+ manifest.Id = "Foo";
+ manifest.Version = "10.0";
+ manifest.Installers.push_back({});
+ manifest.Installers[0].InstallerType = InstallerTypeEnum::Exe;
+ manifest.Installers[0].AppsAndFeaturesEntries.push_back({});
+ manifest.Installers[0].AppsAndFeaturesEntries[0].DisplayVersion = "1.0";
+ manifest.Installers[0].AppsAndFeaturesEntries.push_back({});
+ manifest.Installers[0].AppsAndFeaturesEntries[1].DisplayVersion = "1.1";
+
+ index.AddManifest(manifest, "path");
+
+ // Inject a null character using SQL without binding since we block it
+ Connection connection = Connection::Create(tempFile, Connection::OpenDisposition::ReadWrite);
+ Statement update = Statement::Create(connection, "Update versions set version = '10.0'||char(0)||'After Null' where version = '10.0'");
+ update.Execute();
+
+ REQUIRE(!index.CheckConsistency(true));
+}
diff --git a/src/AppInstallerCLITests/SQLiteWrapper.cpp b/src/AppInstallerCLITests/SQLiteWrapper.cpp
index a003ef975d..d27b20fd24 100644
--- a/src/AppInstallerCLITests/SQLiteWrapper.cpp
+++ b/src/AppInstallerCLITests/SQLiteWrapper.cpp
@@ -2,6 +2,7 @@
// Licensed under the MIT License.
#include "pch.h"
#include "TestCommon.h"
+#include
#include
#include
@@ -286,6 +287,19 @@ TEST_CASE("SQLiteWrapper_EscapeStringForLike", "[sqlitewrapper]")
REQUIRE(expected == output);
}
+TEST_CASE("SQLiteWrapper_BindWithEmbeddedNull", "[sqlitewrapper]")
+{
+ Connection connection = Connection::Create(SQLITE_MEMORY_DB_CONNECTION_TARGET, Connection::OpenDisposition::Create);
+
+ CreateSimpleTestTable(connection);
+
+ int firstVal = 1;
+ std::string secondVal = "test";
+ secondVal[1] = '\0';
+
+ REQUIRE_THROWS_HR(InsertIntoSimpleTestTable(connection, firstVal, secondVal), APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL);
+}
+
TEST_CASE("SQLBuilder_SimpleSelectBind", "[sqlbuilder]")
{
Connection connection = Connection::Create(SQLITE_MEMORY_DB_CONNECTION_TARGET, Connection::OpenDisposition::Create);
diff --git a/src/AppInstallerCLITests/Strings.cpp b/src/AppInstallerCLITests/Strings.cpp
index 547e5fd66c..86183ee224 100644
--- a/src/AppInstallerCLITests/Strings.cpp
+++ b/src/AppInstallerCLITests/Strings.cpp
@@ -99,6 +99,10 @@ TEST_CASE("NormalizedString", "[strings]")
// Ligature fi => f + i
std::string_view input2 = u8"\xFB01";
REQUIRE(NormalizedString(input2) == u8"fi");
+
+ // Embedded null
+ std::string_view input3{ "Test\0Case", 9 };
+ REQUIRE(NormalizedString(input3) == "Test Case");
}
TEST_CASE("Trim", "[strings]")
@@ -203,4 +207,12 @@ TEST_CASE("SplitIntoWords", "[strings]")
// 私のテスト = "My test" according to an online translator
// Split as "私" "の" "テスト"
REQUIRE(SplitIntoWords("\xe7\xa7\x81\xe3\x81\xae\xe3\x83\x86\xe3\x82\xb9\xe3\x83\x88") == std::vector{ "\xe7\xa7\x81", "\xe3\x81\xae", "\xe3\x83\x86\xe3\x82\xb9\xe3\x83\x88" });
-}
\ No newline at end of file
+}
+
+TEST_CASE("ReplaceEmbeddedNullCharacters", "[strings]")
+{
+ std::string test = "Test Parts";
+ test[4] = '\0';
+ ReplaceEmbeddedNullCharacters(test);
+ REQUIRE(test == "Test Parts");
+}
diff --git a/src/AppInstallerCommonCore/AppInstallerStrings.cpp b/src/AppInstallerCommonCore/AppInstallerStrings.cpp
index 9d1e6cd9e4..d09fc0e03c 100644
--- a/src/AppInstallerCommonCore/AppInstallerStrings.cpp
+++ b/src/AppInstallerCommonCore/AppInstallerStrings.cpp
@@ -333,6 +333,17 @@ namespace AppInstaller::Utility
}
}
+ void ReplaceEmbeddedNullCharacters(std::string& s, char c)
+ {
+ for (size_t i = 0; i < s.length(); ++i)
+ {
+ if (s[i] == '\0')
+ {
+ s[i] = c;
+ }
+ }
+ }
+
std::string ToLower(std::string_view in)
{
std::string result(in);
diff --git a/src/AppInstallerCommonCore/Errors.cpp b/src/AppInstallerCommonCore/Errors.cpp
index f41de288e4..3f36a8199e 100644
--- a/src/AppInstallerCommonCore/Errors.cpp
+++ b/src/AppInstallerCommonCore/Errors.cpp
@@ -220,6 +220,8 @@ namespace AppInstaller
return "Organization policies are preventing installation. Contact your admin.";
case APPINSTALLER_CLI_ERROR_INSTALL_DEPENDENCIES:
return "Failed to install package dependencies.";
+ case APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL:
+ return "Embedded null characters are disallowed for SQLite";
default:
return "Unknown Error Code";
}
diff --git a/src/AppInstallerCommonCore/Public/AppInstallerErrors.h b/src/AppInstallerCommonCore/Public/AppInstallerErrors.h
index 8284fcc00d..0435b00399 100644
--- a/src/AppInstallerCommonCore/Public/AppInstallerErrors.h
+++ b/src/AppInstallerCommonCore/Public/AppInstallerErrors.h
@@ -102,6 +102,7 @@
#define APPINSTALLER_CLI_ERROR_PORTABLE_UNINSTALL_FAILED ((HRESULT)0x8A150057)
#define APPINSTALLER_CLI_ERROR_ARP_VERSION_VALIDATION_FAILED ((HRESULT)0x8A150058)
#define APPINSTALLER_CLI_ERROR_UNSUPPORTED_ARGUMENT ((HRESULT)0x8A150059)
+#define APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL ((HRESULT)0x8A15005A)
// Install errors.
#define APPINSTALLER_CLI_ERROR_INSTALL_PACKAGE_IN_USE ((HRESULT)0x8A150101)
diff --git a/src/AppInstallerCommonCore/Public/AppInstallerStrings.h b/src/AppInstallerCommonCore/Public/AppInstallerStrings.h
index d8e4f0216d..140c5b6800 100644
--- a/src/AppInstallerCommonCore/Public/AppInstallerStrings.h
+++ b/src/AppInstallerCommonCore/Public/AppInstallerStrings.h
@@ -24,22 +24,26 @@ namespace AppInstaller::Utility
// Normalizes a UTF16 string to the given form.
std::wstring Normalize(std::wstring_view input, NORM_FORM form = NORM_FORM::NormalizationKC);
+ // Replaces any embedded null character found within the string.
+ void ReplaceEmbeddedNullCharacters(std::string& s, char c = ' ');
+
// Type to hold and force a normalized UTF8 string.
- template
+ // Also enables further normalization by replacing embedded null characters with spaces.
+ template
struct NormalizedUTF8 : public std::string
{
NormalizedUTF8() = default;
template
- NormalizedUTF8(const char(&s)[Size]) : std::string(Normalize(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }, Form)) {}
+ NormalizedUTF8(const char(&s)[Size]) { AssignValue(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }); }
- NormalizedUTF8(std::string_view sv) : std::string(Normalize(sv, Form)) {}
+ NormalizedUTF8(std::string_view sv) { AssignValue(sv); }
- NormalizedUTF8(std::string& s) : std::string(Normalize(s, Form)) {}
- NormalizedUTF8(const std::string& s) : std::string(Normalize(s, Form)) {}
- NormalizedUTF8(std::string&& s) : std::string(Normalize(s, Form)) {}
+ NormalizedUTF8(std::string& s) { AssignValue(s); }
+ NormalizedUTF8(const std::string& s) { AssignValue(s); }
+ NormalizedUTF8(std::string&& s) { AssignValue(s); }
- NormalizedUTF8(std::wstring_view sv) : std::string(ConvertToUTF8(Normalize(sv, Form))) {}
+ NormalizedUTF8(std::wstring_view sv) { AssignValue(sv); }
NormalizedUTF8(const NormalizedUTF8& other) = default;
NormalizedUTF8& operator=(const NormalizedUTF8& other) = default;
@@ -50,30 +54,51 @@ namespace AppInstaller::Utility
template
NormalizedUTF8& operator=(const char(&s)[Size])
{
- assign(Normalize(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }, Form));
+ AssignValue(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) });
return *this;
}
NormalizedUTF8& operator=(std::string_view sv)
{
- assign(Normalize(sv, Form));
+ AssignValue(sv);
return *this;
}
NormalizedUTF8& operator=(const std::string& s)
{
- assign(Normalize(s, Form));
+ AssignValue(s);
return *this;
}
NormalizedUTF8& operator=(std::string&& s)
{
- assign(Normalize(s, Form));
+ AssignValue(s);
return *this;
}
+
+ private:
+ void AssignValue(std::string_view sv)
+ {
+ assign(Normalize(sv, Form));
+
+ if constexpr (ConvertEmbeddedNullToSpace)
+ {
+ ReplaceEmbeddedNullCharacters(*this);
+ }
+ }
+
+ void AssignValue(std::wstring_view sv)
+ {
+ assign(ConvertToUTF8(Normalize(sv, Form)));
+
+ if constexpr (ConvertEmbeddedNullToSpace)
+ {
+ ReplaceEmbeddedNullCharacters(*this);
+ }
+ }
};
- using NormalizedString = NormalizedUTF8<>;
+ using NormalizedString = NormalizedUTF8;
// Compares the two UTF8 strings in a case insensitive manner.
// Use this if one of the values is a known value, and thus ToLower is sufficient.
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp
index 5b97c3c423..3fcdedb69b 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp
@@ -393,53 +393,35 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0
{
bool result = true;
- // Check the manifest table references to it's 1:1 tables
- if (result || log)
- {
- result = ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = ManifestTable::CheckConsistency(connection, log) && result;
+#define AICLI_CHECK_CONSISTENCY(_check_) \
+ if (result || log) \
+ { \
+ result = _check_ && result; \
}
- // Check the pathpaths table for consistency
- if (result || log)
- {
- result = PathPartTable::CheckConsistency(connection, log) && result;
- }
+ // Check the manifest table references to it's 1:1 tables
+ AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log));
+
+ // Check the 1:1 tables' consistency
+ AICLI_CHECK_CONSISTENCY(IdTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(NameTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(MonikerTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(VersionTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(ChannelTable::CheckConsistency(connection, log));
+
+ // Check the pathparts table for consistency
+ AICLI_CHECK_CONSISTENCY(PathPartTable::CheckConsistency(connection, log));
// Check the 1:N map tables for consistency
- if (result || log)
- {
- result = TagsTable::CheckConsistency(connection, log) && result;
- }
+ AICLI_CHECK_CONSISTENCY(TagsTable::CheckConsistency(connection, log));
+ AICLI_CHECK_CONSISTENCY(CommandsTable::CheckConsistency(connection, log));
- if (result || log)
- {
- result = CommandsTable::CheckConsistency(connection, log) && result;
- }
+#undef AICLI_CHECK_CONSISTENCY
return result;
}
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp
index 4980da4351..6dcaca1c6b 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp
@@ -355,6 +355,13 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0
result = result && secondaryResult;
}
+ if (!result && !log)
+ {
+ return result;
+ }
+
+ result = OneToOneTableCheckConsistency(connection, tableName, valueName, log) && result;
+
return result;
}
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp
index ba5d1347a8..be3f4bd239 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp
@@ -176,5 +176,34 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0
builder.Execute(connection);
}
+
+ bool OneToOneTableCheckConsistency(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, bool log)
+ {
+ // Build a select statement to find values that contain an embedded null character
+ // Such as:
+ // Select count(*) from table where instr(value,char(0))>0
+ SQLite::Builder::StatementBuilder builder;
+ builder.
+ Select({ SQLite::RowIDName, valueName }).
+ From(tableName).
+ WhereValueContainsEmbeddedNullCharacter(valueName);
+
+ SQLite::Statement select = builder.Prepare(connection);
+ bool result = true;
+
+ while (select.Step())
+ {
+ result = false;
+
+ if (!log)
+ {
+ break;
+ }
+
+ AICLI_LOG(Repo, Info, << " [INVALID] value in table [" << tableName << "] at row [" << select.GetColumn(0) << "] contains an embedded null character and starts with [" << select.GetColumn(1) << "]");
+ }
+
+ return result;
+ }
}
}
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h
index 51e0360118..02594c758b 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h
@@ -37,6 +37,9 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0
// Removes the given row by its rowid if it is no longer referenced.
void OneToOneTableDeleteById(SQLite::Connection& connection, std::string_view tableName, SQLite::rowid_t id);
+
+ // Checks the consistency of the table.
+ bool OneToOneTableCheckConsistency(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, bool log);
}
// A table that represents a value that is 1:1 with a primary entry.
@@ -132,5 +135,11 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0
{
return details::OneToOneTableIsEmpty(connection, TableInfo::TableName());
}
+
+ // Checks the consistency of the table.
+ static bool CheckConsistency(const SQLite::Connection& connection, bool log)
+ {
+ return details::OneToOneTableCheckConsistency(connection, TableInfo::TableName(), TableInfo::ValueName(), log);
+ }
};
}
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp
index 5cea7d41c7..d6185e8efb 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp
@@ -373,27 +373,59 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0
// Select l.rowid, l.parent from pathparts as l left outer join pathparts as r on l.parent = r.rowid where l.parent is not null and r.pathpart is null
constexpr std::string_view s_left = "left"sv;
constexpr std::string_view s_right = "right"sv;
-
- SQLite::Builder::StatementBuilder builder;
- builder.
- Select({ QCol(s_left, SQLite::RowIDName), QCol(s_left, s_PathPartTable_ParentValue_Name) }).
- From(s_PathPartTable_Table_Name).As(s_left).
- LeftOuterJoin(s_PathPartTable_Table_Name).As(s_right).On(QCol(s_left, s_PathPartTable_ParentValue_Name), QCol(s_right, SQLite::RowIDName)).
- Where(QCol(s_left, s_PathPartTable_ParentValue_Name)).IsNotNull().And(QCol(s_right, s_PathPartTable_PartValue_Name)).IsNull();
-
- SQLite::Statement select = builder.Prepare(connection);
bool result = true;
- while (select.Step())
{
- result = false;
+ SQLite::Builder::StatementBuilder builder;
+ builder.
+ Select({ QCol(s_left, SQLite::RowIDName), QCol(s_left, s_PathPartTable_ParentValue_Name) }).
+ From(s_PathPartTable_Table_Name).As(s_left).
+ LeftOuterJoin(s_PathPartTable_Table_Name).As(s_right).On(QCol(s_left, s_PathPartTable_ParentValue_Name), QCol(s_right, SQLite::RowIDName)).
+ Where(QCol(s_left, s_PathPartTable_ParentValue_Name)).IsNotNull().And(QCol(s_right, s_PathPartTable_PartValue_Name)).IsNull();
+
+ SQLite::Statement select = builder.Prepare(connection);
- if (!log)
+ while (select.Step())
{
- break;
+ result = false;
+
+ if (!log)
+ {
+ break;
+ }
+
+ AICLI_LOG(Repo, Info, << " [INVALID] pathparts [" << select.GetColumn(0) << "] refers to " << s_PathPartTable_ParentValue_Name << " [" << select.GetColumn(1) << "]");
}
+ }
- AICLI_LOG(Repo, Info, << " [INVALID] pathparts [" << select.GetColumn(0) << "] refers to " << s_PathPartTable_ParentValue_Name << " [" << select.GetColumn(1) << "]");
+ if (!result && !log)
+ {
+ return result;
+ }
+
+ {
+ // Build a select statement to find values that contain an embedded null character
+ // Such as:
+ // Select count(*) from table where instr(value,char(0))>0
+ SQLite::Builder::StatementBuilder builder;
+ builder.
+ Select({ SQLite::RowIDName, s_PathPartTable_PartValue_Name }).
+ From(s_PathPartTable_Table_Name).
+ WhereValueContainsEmbeddedNullCharacter(s_PathPartTable_PartValue_Name);
+
+ SQLite::Statement select = builder.Prepare(connection);
+
+ while (select.Step())
+ {
+ result = false;
+
+ if (!log)
+ {
+ break;
+ }
+
+ AICLI_LOG(Repo, Info, << " [INVALID] value in table [" << s_PathPartTable_Table_Name << "] at row [" << select.GetColumn(0) << "] contains an embedded null character and starts with [" << select.GetColumn(1) << "]");
+ }
}
return result;
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp
index 381e5ef29c..d4c7668098 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp
@@ -1,560 +1,560 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-#include "pch.h"
-#include "DependenciesTable.h"
-#include "SQLiteStatementBuilder.h"
-#include "winget\DependenciesGraph.h"
-#include "Microsoft/Schema/1_0/OneToOneTable.h"
-#include "Microsoft/Schema/1_0/IdTable.h"
-#include "Microsoft/Schema/1_0/ManifestTable.h"
-#include "Microsoft/Schema/1_0/VersionTable.h"
-#include "Microsoft/Schema/1_0/Interface.h"
-#include "Microsoft/Schema/1_0/ChannelTable.h"
-
-namespace AppInstaller::Repository::Microsoft::Schema::V1_4
-{
- using namespace AppInstaller;
- using namespace std::string_view_literals;
- using namespace SQLite::Builder;
- using namespace Schema::V1_0;
- using QCol = SQLite::Builder::QualifiedColumn;
-
- static constexpr std::string_view s_DependenciesTable_Table_Name = "dependencies"sv;
- static constexpr std::string_view s_DependenciesTable_Index_Name = "dependencies_pkindex"sv;
- static constexpr std::string_view s_DependenciesTable_Manifest_Column_Name = "manifest"sv;
- static constexpr std::string_view s_DependenciesTable_MinVersion_Column_Name = "min_version"sv;
- static constexpr std::string_view s_DependenciesTable_PackageId_Column_Name = "package_id";
-
- namespace
- {
- struct DependencyTableRow
- {
- SQLite::rowid_t m_packageRowId;
- SQLite::rowid_t m_manifestRowId;
-
- // Ideally this should be version row id, the version string is more needed than row id,
- // this prevents converting back and forth between version row id and version string.
- std::optional m_version;
-
- bool operator <(const DependencyTableRow& rhs) const
- {
- auto lhsVersion = m_version.has_value() ? m_version.value() : "";
- auto rhsVersion = rhs.m_version.has_value() ? rhs.m_version.value() : "";
- return std::tie(m_packageRowId, m_manifestRowId, lhsVersion) < std::tie(rhs.m_packageRowId, rhs.m_manifestRowId, rhsVersion);
- }
- };
-
- void ThrowOnMissingPackageNodes(std::vector& missingPackageNodes)
- {
- if (!missingPackageNodes.empty())
- {
- std::string missingPackages{ missingPackageNodes.begin()->Id };
- std::for_each(
- missingPackageNodes.begin() + 1,
- missingPackageNodes.end(),
- [&](auto& dep) { missingPackages.append(", " + dep.Id); });
- THROW_HR_MSG(APPINSTALLER_CLI_ERROR_MISSING_PACKAGE, "Missing packages: %hs", missingPackages.c_str());
- }
- }
-
- std::set GetAndLinkDependencies(
- SQLite::Connection& connection,
- const Manifest::Manifest& manifest,
- SQLite::rowid_t manifestRowId,
- Manifest::DependencyType dependencyType)
- {
- std::set dependencies;
- std::vector missingPackageNodes;
-
- for (const auto& installer : manifest.Installers)
- {
- installer.Dependencies.ApplyToType(dependencyType, [&](Manifest::Dependency dependency)
- {
- auto packageRowId = IdTable::SelectIdByValue(connection, dependency.Id);
- std::optional version;
-
- if (!packageRowId.has_value())
- {
- missingPackageNodes.emplace_back(dependency);
- return;
- }
-
- if (dependency.MinVersion.has_value())
- {
- version = dependency.MinVersion.value().ToString();
- }
-
- dependencies.emplace(DependencyTableRow{ packageRowId.value(), manifestRowId, version });
- });
- }
-
- ThrowOnMissingPackageNodes(missingPackageNodes);
-
- return dependencies;
- }
-
- bool RemoveDependenciesByRowIds(SQLite::Connection& connection, std::vector dependencyTableRows)
- {
- using namespace SQLite::Builder;
- bool tableUpdated = false;
- if (dependencyTableRows.empty())
- {
- return tableUpdated;
- }
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_rowid");
-
- SQLite::Builder::StatementBuilder builder;
- builder
- .DeleteFrom(s_DependenciesTable_Table_Name)
- .Where(s_DependenciesTable_PackageId_Column_Name).Equals(Unbound)
- .And(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound);
-
-
- SQLite::Statement deleteStmt = builder.Prepare(connection);
- for (auto row : dependencyTableRows)
- {
- deleteStmt.Reset();
- deleteStmt.Bind(1, row.m_packageRowId);
- deleteStmt.Bind(2, row.m_manifestRowId);
- deleteStmt.Execute(true);
- tableUpdated = true;
- }
-
- savepoint.Commit();
- return tableUpdated;
- }
-
- bool InsertManifestDependencies(
- SQLite::Connection& connection,
- std::set& dependenciesTableRows)
- {
- using namespace SQLite::Builder;
- using namespace Schema::V1_0;
- bool tableUpdated = false;
-
- StatementBuilder insertBuilder;
- insertBuilder.InsertInto(s_DependenciesTable_Table_Name)
- .Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_MinVersion_Column_Name, s_DependenciesTable_PackageId_Column_Name })
- .Values(Unbound, Unbound, Unbound);
- SQLite::Statement insert = insertBuilder.Prepare(connection);
-
-
- for (const auto& dep : dependenciesTableRows)
- {
- insert.Reset();
- insert.Bind(1, dep.m_manifestRowId);
-
- if (dep.m_version.has_value())
- {
- insert.Bind(2, VersionTable::EnsureExists(connection, dep.m_version.value()));
- }
- else
- {
- insert.Bind(2, nullptr);
- }
-
- insert.Bind(3, dep.m_packageRowId);
-
- insert.Execute(true);
- tableUpdated = true;
- }
-
- return tableUpdated;
- }
- }
-
- bool DependenciesTable::Exists(const SQLite::Connection& connection)
- {
- using namespace SQLite;
-
- Builder::StatementBuilder builder;
- builder.Select(Builder::RowCount).From(Builder::Schema::MainTable).
- Where(Builder::Schema::TypeColumn).Equals(Builder::Schema::Type_Table).And(Builder::Schema::NameColumn).Equals(s_DependenciesTable_Table_Name);
-
- Statement statement = builder.Prepare(connection);
- THROW_HR_IF(E_UNEXPECTED, !statement.Step());
- return statement.GetColumn(0) != 0;
- }
-
- std::string_view DependenciesTable::TableName()
- {
- return s_DependenciesTable_Table_Name;
- }
-
- void DependenciesTable::Create(SQLite::Connection& connection)
- {
- using namespace SQLite::Builder;
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createDependencyTable_v1_4");
- constexpr std::string_view dependencyIndexByVersionId = "dependencies_version_id_index";
- constexpr std::string_view dependencyIndexByPackageId = "dependencies_package_id_index";
-
- StatementBuilder createTableBuilder;
- createTableBuilder.CreateTable(TableName()).BeginColumns();
- createTableBuilder.Column(IntegerPrimaryKey());
-
- std::array notNullableDependenciesColumns
- {
- DependenciesTableColumnInfo{ s_DependenciesTable_Manifest_Column_Name },
- DependenciesTableColumnInfo{ s_DependenciesTable_PackageId_Column_Name }
- };
-
- std::array nullableDependenciesColumns
- {
- DependenciesTableColumnInfo{ s_DependenciesTable_MinVersion_Column_Name }
- };
-
- // Add dependencies column tables not null columns.
- for (const DependenciesTableColumnInfo& value : notNullableDependenciesColumns)
- {
- createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId).NotNull());
- }
-
- // Add dependencies column tables null columns.
- for (const DependenciesTableColumnInfo& value : nullableDependenciesColumns)
- {
- createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId));
- }
-
- createTableBuilder.EndColumns();
-
- createTableBuilder.Execute(connection);
-
- // Primary key index by package rowid and manifest rowid.
- StatementBuilder createPKIndexBuilder;
- createPKIndexBuilder.CreateUniqueIndex(s_DependenciesTable_Index_Name).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_PackageId_Column_Name });
- createPKIndexBuilder.Execute(connection);
-
- // Index of dependency by Manifest id.
- StatementBuilder createIndexByManifestIdBuilder;
- createIndexByManifestIdBuilder.CreateIndex(dependencyIndexByVersionId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_MinVersion_Column_Name });
- createIndexByManifestIdBuilder.Execute(connection);
-
- // Index of dependency by package id.
- StatementBuilder createIndexByPackageIdBuilder;
- createIndexByPackageIdBuilder.CreateIndex(dependencyIndexByPackageId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_PackageId_Column_Name });
- createIndexByPackageIdBuilder.Execute(connection);
-
- savepoint.Commit();
- }
-
- void DependenciesTable::AddDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId)
- {
- if (!Exists(connection))
- {
- return;
- }
-
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "add_dependencies_v1_4");
-
- auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package);
- if (!dependencies.size())
- {
- return;
- }
-
- InsertManifestDependencies(connection, dependencies);
-
- savepoint.Commit();
- }
-
- bool DependenciesTable::UpdateDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId)
- {
- if (!Exists(connection))
- {
- return false;
- }
-
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "update_dependencies_v1_4");
-
- const auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package);
- auto existingDependencies = GetDependenciesByManifestRowId(connection, manifestRowId);
-
- // Get dependencies to add.
- std::set toAddDependencies;
- std::copy_if(
- dependencies.begin(),
- dependencies.end(),
- std::inserter(toAddDependencies, toAddDependencies.begin()),
- [&](DependencyTableRow dep)
- {
- Utility::NormalizedString version = dep.m_version.has_value() ? dep.m_version.value() : "";
- return existingDependencies.find(std::make_pair(dep.m_packageRowId, version)) == existingDependencies.end();
- }
- );
-
- // Get dependencies to remove.
- std::vector toRemoveDependencies;
- std::for_each(
- existingDependencies.begin(),
- existingDependencies.end(),
- [&](std::pair row)
- {
- if (dependencies.find(DependencyTableRow{row.first, manifestRowId, row.second}) == dependencies.end())
- {
- toRemoveDependencies.emplace_back(DependencyTableRow{ row.first, manifestRowId });
- }
- }
- );
-
- bool tableUpdated = InsertManifestDependencies(connection, toAddDependencies);
- tableUpdated = RemoveDependenciesByRowIds(connection, toRemoveDependencies) || tableUpdated;
- savepoint.Commit();
-
- return tableUpdated;
- }
-
- void DependenciesTable::RemoveDependencies(SQLite::Connection& connection, SQLite::rowid_t manifestRowId)
- {
- if (!Exists(connection))
- {
- return;
- }
-
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_manifest_v1_4");
-
- SQLite::Builder::StatementBuilder builder;
- builder.DeleteFrom(s_DependenciesTable_Table_Name).Where(s_DependenciesTable_Manifest_Column_Name).Equals(manifestRowId);
-
- builder.Execute(connection);
- savepoint.Commit();
- }
-
- std::vector> DependenciesTable::GetDependentsById(const SQLite::Connection& connection, Manifest::string_t packageId)
- {
- constexpr std::string_view depTableAlias = "dep";
- constexpr std::string_view minVersionAlias = "minV";
- constexpr std::string_view packageIdAlias = "pId";
-
-
- StatementBuilder builder;
- // Find all manifest that depend on this package.
- // SELECT [dep].[manifest], [pId].[id], [minV].[version] FROM [dependencies] AS [dep]
- // JOIN [versions] AS [minV] ON [dep].[min_version] = [minV].[rowid]
- // JOIN [ids] AS [pId] ON [pId].[rowid] = [dep].[package_id]
- // WHERE [pId].[id] = ?
- builder.Select()
- .Column(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name))
- .Column(QCol(packageIdAlias, IdTable::ValueName()))
- .Column(QCol(minVersionAlias, VersionTable::ValueName()))
- .From({ s_DependenciesTable_Table_Name }).As(depTableAlias)
- .Join({ VersionTable::TableName() }).As(minVersionAlias)
- .On(QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name), QCol(minVersionAlias, SQLite::RowIDName))
- .Join({ IdTable::TableName() }).As(packageIdAlias)
- .On(QCol(packageIdAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name))
- .Where(QCol(packageIdAlias, IdTable::ValueName())).Equals(Unbound);
-
- SQLite::Statement stmt = builder.Prepare(connection);
- stmt.Bind(1, std::string{ packageId });
-
- std::vector> resultSet;
-
- while (stmt.Step())
- {
- resultSet.emplace_back(
- std::make_pair(stmt.GetColumn(0), Utility::NormalizedString(stmt.GetColumn(2))));
- }
-
- return resultSet;
- }
-
- std::set> DependenciesTable::GetDependenciesByManifestRowId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId)
- {
- SQLite::Builder::StatementBuilder builder;
-
- constexpr std::string_view depTableAlias = "dep";
- constexpr std::string_view minVersionAlias = "minV";
-
- std::set> resultSet;
-
- // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep]
- // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version]
- // WHERE [dep].[manifest] = ?
- builder.Select()
- .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name))
- .Column(QCol(minVersionAlias, VersionTable::ValueName()))
- .From({ s_DependenciesTable_Table_Name }).As(depTableAlias)
- .Join({ VersionTable::TableName() }).As(minVersionAlias)
- .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name))
- .Where(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name)).Equals(Unbound);
-
- SQLite::Statement select = builder.Prepare(connection);
-
- select.Bind(1, manifestRowId);
- while (select.Step())
- {
- Utility::NormalizedString version = "";
- if (!select.GetColumnIsNull(1))
- {
- version = select.GetColumn(1);
- }
- resultSet.emplace(std::make_pair(select.GetColumn(0), version));
- }
-
- return resultSet;
- }
-
- void DependenciesTable::PrepareForPackaging(SQLite::Connection& connection)
- {
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "prepareForPacking_V1_4");
-
- StatementBuilder dropIndexBuilder;
- dropIndexBuilder.DropIndex({ s_DependenciesTable_Index_Name });
- dropIndexBuilder.Execute(connection);
-
- StatementBuilder dropTableBuilder;
- dropTableBuilder.DropTable({ s_DependenciesTable_Table_Name });
- dropTableBuilder.Execute(connection);
-
- savepoint.Commit();
- }
-
- bool DependenciesTable::DependenciesTableCheckConsistency(const SQLite::Connection& connection, bool log)
- {
- StatementBuilder builder;
-
- if (!Exists(connection))
- {
- return true;
- }
-
- builder.Select(QCol(s_DependenciesTable_Table_Name, SQLite::RowIDName))
- .From(s_DependenciesTable_Table_Name)
- .LeftOuterJoin(IdTable::TableName())
- .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_PackageId_Column_Name), QCol(IdTable::TableName(), SQLite::RowIDName))
- .LeftOuterJoin(ManifestTable::TableName())
- .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_Manifest_Column_Name), QCol(ManifestTable::TableName(), SQLite::RowIDName))
- .LeftOuterJoin(VersionTable::TableName())
- .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name), QCol(VersionTable::TableName(), SQLite::RowIDName))
- .Where(QCol(ManifestTable::TableName(), SQLite::RowIDName)).IsNull()
- .Or(QCol(VersionTable::TableName(), SQLite::RowIDName)).IsNull().And(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name)).IsNotNull()
- .Or(QCol(IdTable::TableName(), SQLite::RowIDName)).IsNull();
-
- SQLite::Statement select = builder.Prepare(connection);
-
- bool result = true;
-
- while (select.Step())
- {
- result = false;
-
- if (!log)
- {
- break;
- }
-
- AICLI_LOG(Repo, Info, << " [INVALID] rowid [" << select.GetColumn(0) << "]");
- }
-
- return result;
- }
-
- bool DependenciesTable::IsValueReferenced(const SQLite::Connection& connection, std::string_view tableName, SQLite::rowid_t valueRowId)
- {
- if (!Exists(connection))
- {
- return false;
- }
-
- StatementBuilder builder;
-
- if (tableName != V1_0::VersionTable::TableName())
- {
- return false;
- }
-
- std::array columns = { s_DependenciesTable_MinVersion_Column_Name };
- bool referenced = false;
-
- for(auto column: columns)
- {
- builder.Select(SQLite::RowIDName).From(s_DependenciesTable_Table_Name).Where(column).Equals(Unbound).Limit(1);
-
- SQLite::Statement select = builder.Prepare(connection);
-
- select.Bind(1, valueRowId);
- if (select.Step())
- {
- referenced = true;
- break;
- }
- }
-
- return referenced;
- }
-
- std::vector DependenciesTable::GetDependenciesMinVersionsRowIdByManifestId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId)
- {
- if (!Exists(connection))
- {
- return {};
- }
-
- StatementBuilder builder;
-
- std::vector result;
-
- // Find all versions for manifest row.
- // SELECT [min_version] FROM [dependencies]
- // WHERE [manifest] = ?
- builder.Select()
- .Column(s_DependenciesTable_MinVersion_Column_Name)
- .From({ s_DependenciesTable_Table_Name })
- .Where(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound);
-
- auto select = builder.Prepare(connection);
-
- select.Bind(1, manifestRowId);
-
- while (select.Step())
- {
- if (!select.GetColumnIsNull(0))
- {
- result.emplace_back(select.GetColumn(0));
- }
- }
-
- return result;
- }
-
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+#include "pch.h"
+#include "DependenciesTable.h"
+#include "SQLiteStatementBuilder.h"
+#include "winget\DependenciesGraph.h"
+#include "Microsoft/Schema/1_0/OneToOneTable.h"
+#include "Microsoft/Schema/1_0/IdTable.h"
+#include "Microsoft/Schema/1_0/ManifestTable.h"
+#include "Microsoft/Schema/1_0/VersionTable.h"
+#include "Microsoft/Schema/1_0/Interface.h"
+#include "Microsoft/Schema/1_0/ChannelTable.h"
+
+namespace AppInstaller::Repository::Microsoft::Schema::V1_4
+{
+ using namespace AppInstaller;
+ using namespace std::string_view_literals;
+ using namespace SQLite::Builder;
+ using namespace Schema::V1_0;
+ using QCol = SQLite::Builder::QualifiedColumn;
+
+ static constexpr std::string_view s_DependenciesTable_Table_Name = "dependencies"sv;
+ static constexpr std::string_view s_DependenciesTable_Index_Name = "dependencies_pkindex"sv;
+ static constexpr std::string_view s_DependenciesTable_Manifest_Column_Name = "manifest"sv;
+ static constexpr std::string_view s_DependenciesTable_MinVersion_Column_Name = "min_version"sv;
+ static constexpr std::string_view s_DependenciesTable_PackageId_Column_Name = "package_id";
+
+ namespace
+ {
+ struct DependencyTableRow
+ {
+ SQLite::rowid_t m_packageRowId;
+ SQLite::rowid_t m_manifestRowId;
+
+ // Ideally this should be version row id, the version string is more needed than row id,
+ // this prevents converting back and forth between version row id and version string.
+ std::optional m_version;
+
+ bool operator <(const DependencyTableRow& rhs) const
+ {
+ auto lhsVersion = m_version.has_value() ? m_version.value() : "";
+ auto rhsVersion = rhs.m_version.has_value() ? rhs.m_version.value() : "";
+ return std::tie(m_packageRowId, m_manifestRowId, lhsVersion) < std::tie(rhs.m_packageRowId, rhs.m_manifestRowId, rhsVersion);
+ }
+ };
+
+ void ThrowOnMissingPackageNodes(std::vector& missingPackageNodes)
+ {
+ if (!missingPackageNodes.empty())
+ {
+ std::string missingPackages{ missingPackageNodes.begin()->Id };
+ std::for_each(
+ missingPackageNodes.begin() + 1,
+ missingPackageNodes.end(),
+ [&](auto& dep) { missingPackages.append(", " + dep.Id); });
+ THROW_HR_MSG(APPINSTALLER_CLI_ERROR_MISSING_PACKAGE, "Missing packages: %hs", missingPackages.c_str());
+ }
+ }
+
+ std::set GetAndLinkDependencies(
+ SQLite::Connection& connection,
+ const Manifest::Manifest& manifest,
+ SQLite::rowid_t manifestRowId,
+ Manifest::DependencyType dependencyType)
+ {
+ std::set dependencies;
+ std::vector missingPackageNodes;
+
+ for (const auto& installer : manifest.Installers)
+ {
+ installer.Dependencies.ApplyToType(dependencyType, [&](Manifest::Dependency dependency)
+ {
+ auto packageRowId = IdTable::SelectIdByValue(connection, dependency.Id);
+ std::optional version;
+
+ if (!packageRowId.has_value())
+ {
+ missingPackageNodes.emplace_back(dependency);
+ return;
+ }
+
+ if (dependency.MinVersion.has_value())
+ {
+ version = dependency.MinVersion.value().ToString();
+ }
+
+ dependencies.emplace(DependencyTableRow{ packageRowId.value(), manifestRowId, version });
+ });
+ }
+
+ ThrowOnMissingPackageNodes(missingPackageNodes);
+
+ return dependencies;
+ }
+
+ bool RemoveDependenciesByRowIds(SQLite::Connection& connection, std::vector dependencyTableRows)
+ {
+ using namespace SQLite::Builder;
+ bool tableUpdated = false;
+ if (dependencyTableRows.empty())
+ {
+ return tableUpdated;
+ }
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_rowid");
+
+ SQLite::Builder::StatementBuilder builder;
+ builder
+ .DeleteFrom(s_DependenciesTable_Table_Name)
+ .Where(s_DependenciesTable_PackageId_Column_Name).Equals(Unbound)
+ .And(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound);
+
+
+ SQLite::Statement deleteStmt = builder.Prepare(connection);
+ for (auto row : dependencyTableRows)
+ {
+ deleteStmt.Reset();
+ deleteStmt.Bind(1, row.m_packageRowId);
+ deleteStmt.Bind(2, row.m_manifestRowId);
+ deleteStmt.Execute(true);
+ tableUpdated = true;
+ }
+
+ savepoint.Commit();
+ return tableUpdated;
+ }
+
+ bool InsertManifestDependencies(
+ SQLite::Connection& connection,
+ std::set& dependenciesTableRows)
+ {
+ using namespace SQLite::Builder;
+ using namespace Schema::V1_0;
+ bool tableUpdated = false;
+
+ StatementBuilder insertBuilder;
+ insertBuilder.InsertInto(s_DependenciesTable_Table_Name)
+ .Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_MinVersion_Column_Name, s_DependenciesTable_PackageId_Column_Name })
+ .Values(Unbound, Unbound, Unbound);
+ SQLite::Statement insert = insertBuilder.Prepare(connection);
+
+
+ for (const auto& dep : dependenciesTableRows)
+ {
+ insert.Reset();
+ insert.Bind(1, dep.m_manifestRowId);
+
+ if (dep.m_version.has_value())
+ {
+ insert.Bind(2, VersionTable::EnsureExists(connection, dep.m_version.value()));
+ }
+ else
+ {
+ insert.Bind(2, nullptr);
+ }
+
+ insert.Bind(3, dep.m_packageRowId);
+
+ insert.Execute(true);
+ tableUpdated = true;
+ }
+
+ return tableUpdated;
+ }
+ }
+
+ bool DependenciesTable::Exists(const SQLite::Connection& connection)
+ {
+ using namespace SQLite;
+
+ Builder::StatementBuilder builder;
+ builder.Select(Builder::RowCount).From(Builder::Schema::MainTable).
+ Where(Builder::Schema::TypeColumn).Equals(Builder::Schema::Type_Table).And(Builder::Schema::NameColumn).Equals(s_DependenciesTable_Table_Name);
+
+ Statement statement = builder.Prepare(connection);
+ THROW_HR_IF(E_UNEXPECTED, !statement.Step());
+ return statement.GetColumn(0) != 0;
+ }
+
+ std::string_view DependenciesTable::TableName()
+ {
+ return s_DependenciesTable_Table_Name;
+ }
+
+ void DependenciesTable::Create(SQLite::Connection& connection)
+ {
+ using namespace SQLite::Builder;
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createDependencyTable_v1_4");
+ constexpr std::string_view dependencyIndexByVersionId = "dependencies_version_id_index";
+ constexpr std::string_view dependencyIndexByPackageId = "dependencies_package_id_index";
+
+ StatementBuilder createTableBuilder;
+ createTableBuilder.CreateTable(TableName()).BeginColumns();
+ createTableBuilder.Column(IntegerPrimaryKey());
+
+ std::array notNullableDependenciesColumns
+ {
+ DependenciesTableColumnInfo{ s_DependenciesTable_Manifest_Column_Name },
+ DependenciesTableColumnInfo{ s_DependenciesTable_PackageId_Column_Name }
+ };
+
+ std::array nullableDependenciesColumns
+ {
+ DependenciesTableColumnInfo{ s_DependenciesTable_MinVersion_Column_Name }
+ };
+
+ // Add dependencies column tables not null columns.
+ for (const DependenciesTableColumnInfo& value : notNullableDependenciesColumns)
+ {
+ createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId).NotNull());
+ }
+
+ // Add dependencies column tables null columns.
+ for (const DependenciesTableColumnInfo& value : nullableDependenciesColumns)
+ {
+ createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId));
+ }
+
+ createTableBuilder.EndColumns();
+
+ createTableBuilder.Execute(connection);
+
+ // Primary key index by package rowid and manifest rowid.
+ StatementBuilder createPKIndexBuilder;
+ createPKIndexBuilder.CreateUniqueIndex(s_DependenciesTable_Index_Name).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_PackageId_Column_Name });
+ createPKIndexBuilder.Execute(connection);
+
+ // Index of dependency by Manifest id.
+ StatementBuilder createIndexByManifestIdBuilder;
+ createIndexByManifestIdBuilder.CreateIndex(dependencyIndexByVersionId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_MinVersion_Column_Name });
+ createIndexByManifestIdBuilder.Execute(connection);
+
+ // Index of dependency by package id.
+ StatementBuilder createIndexByPackageIdBuilder;
+ createIndexByPackageIdBuilder.CreateIndex(dependencyIndexByPackageId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_PackageId_Column_Name });
+ createIndexByPackageIdBuilder.Execute(connection);
+
+ savepoint.Commit();
+ }
+
+ void DependenciesTable::AddDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId)
+ {
+ if (!Exists(connection))
+ {
+ return;
+ }
+
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "add_dependencies_v1_4");
+
+ auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package);
+ if (!dependencies.size())
+ {
+ return;
+ }
+
+ InsertManifestDependencies(connection, dependencies);
+
+ savepoint.Commit();
+ }
+
+ bool DependenciesTable::UpdateDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId)
+ {
+ if (!Exists(connection))
+ {
+ return false;
+ }
+
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "update_dependencies_v1_4");
+
+ const auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package);
+ auto existingDependencies = GetDependenciesByManifestRowId(connection, manifestRowId);
+
+ // Get dependencies to add.
+ std::set toAddDependencies;
+ std::copy_if(
+ dependencies.begin(),
+ dependencies.end(),
+ std::inserter(toAddDependencies, toAddDependencies.begin()),
+ [&](DependencyTableRow dep)
+ {
+ Utility::NormalizedString version = dep.m_version.has_value() ? dep.m_version.value() : "";
+ return existingDependencies.find(std::make_pair(dep.m_packageRowId, version)) == existingDependencies.end();
+ }
+ );
+
+ // Get dependencies to remove.
+ std::vector toRemoveDependencies;
+ std::for_each(
+ existingDependencies.begin(),
+ existingDependencies.end(),
+ [&](std::pair row)
+ {
+ if (dependencies.find(DependencyTableRow{row.first, manifestRowId, row.second}) == dependencies.end())
+ {
+ toRemoveDependencies.emplace_back(DependencyTableRow{ row.first, manifestRowId });
+ }
+ }
+ );
+
+ bool tableUpdated = InsertManifestDependencies(connection, toAddDependencies);
+ tableUpdated = RemoveDependenciesByRowIds(connection, toRemoveDependencies) || tableUpdated;
+ savepoint.Commit();
+
+ return tableUpdated;
+ }
+
+ void DependenciesTable::RemoveDependencies(SQLite::Connection& connection, SQLite::rowid_t manifestRowId)
+ {
+ if (!Exists(connection))
+ {
+ return;
+ }
+
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_manifest_v1_4");
+
+ SQLite::Builder::StatementBuilder builder;
+ builder.DeleteFrom(s_DependenciesTable_Table_Name).Where(s_DependenciesTable_Manifest_Column_Name).Equals(manifestRowId);
+
+ builder.Execute(connection);
+ savepoint.Commit();
+ }
+
+ std::vector> DependenciesTable::GetDependentsById(const SQLite::Connection& connection, Manifest::string_t packageId)
+ {
+ constexpr std::string_view depTableAlias = "dep";
+ constexpr std::string_view minVersionAlias = "minV";
+ constexpr std::string_view packageIdAlias = "pId";
+
+
+ StatementBuilder builder;
+ // Find all manifest that depend on this package.
+ // SELECT [dep].[manifest], [pId].[id], [minV].[version] FROM [dependencies] AS [dep]
+ // JOIN [versions] AS [minV] ON [dep].[min_version] = [minV].[rowid]
+ // JOIN [ids] AS [pId] ON [pId].[rowid] = [dep].[package_id]
+ // WHERE [pId].[id] = ?
+ builder.Select()
+ .Column(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name))
+ .Column(QCol(packageIdAlias, IdTable::ValueName()))
+ .Column(QCol(minVersionAlias, VersionTable::ValueName()))
+ .From({ s_DependenciesTable_Table_Name }).As(depTableAlias)
+ .Join({ VersionTable::TableName() }).As(minVersionAlias)
+ .On(QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name), QCol(minVersionAlias, SQLite::RowIDName))
+ .Join({ IdTable::TableName() }).As(packageIdAlias)
+ .On(QCol(packageIdAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name))
+ .Where(QCol(packageIdAlias, IdTable::ValueName())).Equals(Unbound);
+
+ SQLite::Statement stmt = builder.Prepare(connection);
+ stmt.Bind(1, std::string{ packageId });
+
+ std::vector> resultSet;
+
+ while (stmt.Step())
+ {
+ resultSet.emplace_back(
+ std::make_pair(stmt.GetColumn(0), Utility::NormalizedString(stmt.GetColumn(2))));
+ }
+
+ return resultSet;
+ }
+
+ std::set> DependenciesTable::GetDependenciesByManifestRowId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId)
+ {
+ SQLite::Builder::StatementBuilder builder;
+
+ constexpr std::string_view depTableAlias = "dep";
+ constexpr std::string_view minVersionAlias = "minV";
+
+ std::set> resultSet;
+
+ // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep]
+ // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version]
+ // WHERE [dep].[manifest] = ?
+ builder.Select()
+ .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name))
+ .Column(QCol(minVersionAlias, VersionTable::ValueName()))
+ .From({ s_DependenciesTable_Table_Name }).As(depTableAlias)
+ .Join({ VersionTable::TableName() }).As(minVersionAlias)
+ .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name))
+ .Where(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name)).Equals(Unbound);
+
+ SQLite::Statement select = builder.Prepare(connection);
+
+ select.Bind(1, manifestRowId);
+ while (select.Step())
+ {
+ Utility::NormalizedString version = "";
+ if (!select.GetColumnIsNull(1))
+ {
+ version = select.GetColumn(1);
+ }
+ resultSet.emplace(std::make_pair(select.GetColumn(0), version));
+ }
+
+ return resultSet;
+ }
+
+ void DependenciesTable::PrepareForPackaging(SQLite::Connection& connection)
+ {
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "prepareForPacking_V1_4");
+
+ StatementBuilder dropIndexBuilder;
+ dropIndexBuilder.DropIndex({ s_DependenciesTable_Index_Name });
+ dropIndexBuilder.Execute(connection);
+
+ StatementBuilder dropTableBuilder;
+ dropTableBuilder.DropTable({ s_DependenciesTable_Table_Name });
+ dropTableBuilder.Execute(connection);
+
+ savepoint.Commit();
+ }
+
+ bool DependenciesTable::CheckConsistency(const SQLite::Connection& connection, bool log)
+ {
+ StatementBuilder builder;
+
+ if (!Exists(connection))
+ {
+ return true;
+ }
+
+ builder.Select(QCol(s_DependenciesTable_Table_Name, SQLite::RowIDName))
+ .From(s_DependenciesTable_Table_Name)
+ .LeftOuterJoin(IdTable::TableName())
+ .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_PackageId_Column_Name), QCol(IdTable::TableName(), SQLite::RowIDName))
+ .LeftOuterJoin(ManifestTable::TableName())
+ .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_Manifest_Column_Name), QCol(ManifestTable::TableName(), SQLite::RowIDName))
+ .LeftOuterJoin(VersionTable::TableName())
+ .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name), QCol(VersionTable::TableName(), SQLite::RowIDName))
+ .Where(QCol(ManifestTable::TableName(), SQLite::RowIDName)).IsNull()
+ .Or(QCol(VersionTable::TableName(), SQLite::RowIDName)).IsNull().And(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name)).IsNotNull()
+ .Or(QCol(IdTable::TableName(), SQLite::RowIDName)).IsNull();
+
+ SQLite::Statement select = builder.Prepare(connection);
+
+ bool result = true;
+
+ while (select.Step())
+ {
+ result = false;
+
+ if (!log)
+ {
+ break;
+ }
+
+ AICLI_LOG(Repo, Info, << " [INVALID] row in [" << s_DependenciesTable_Table_Name << "] rowid [" << select.GetColumn(0) << "]");
+ }
+
+ return result;
+ }
+
+ bool DependenciesTable::IsValueReferenced(const SQLite::Connection& connection, std::string_view tableName, SQLite::rowid_t valueRowId)
+ {
+ if (!Exists(connection))
+ {
+ return false;
+ }
+
+ StatementBuilder builder;
+
+ if (tableName != V1_0::VersionTable::TableName())
+ {
+ return false;
+ }
+
+ std::array columns = { s_DependenciesTable_MinVersion_Column_Name };
+ bool referenced = false;
+
+ for(auto column: columns)
+ {
+ builder.Select(SQLite::RowIDName).From(s_DependenciesTable_Table_Name).Where(column).Equals(Unbound).Limit(1);
+
+ SQLite::Statement select = builder.Prepare(connection);
+
+ select.Bind(1, valueRowId);
+ if (select.Step())
+ {
+ referenced = true;
+ break;
+ }
+ }
+
+ return referenced;
+ }
+
+ std::vector DependenciesTable::GetDependenciesMinVersionsRowIdByManifestId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId)
+ {
+ if (!Exists(connection))
+ {
+ return {};
+ }
+
+ StatementBuilder builder;
+
+ std::vector result;
+
+ // Find all versions for manifest row.
+ // SELECT [min_version] FROM [dependencies]
+ // WHERE [manifest] = ?
+ builder.Select()
+ .Column(s_DependenciesTable_MinVersion_Column_Name)
+ .From({ s_DependenciesTable_Table_Name })
+ .Where(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound);
+
+ auto select = builder.Prepare(connection);
+
+ select.Bind(1, manifestRowId);
+
+ while (select.Step())
+ {
+ if (!select.GetColumnIsNull(0))
+ {
+ result.emplace_back(select.GetColumn(0));
+ }
+ }
+
+ return result;
+ }
+
std::vector> DependenciesTable::GetAllDependenciesWithMinVersions(const SQLite::Connection& connection)
{
- if (!Exists(connection))
- {
- return {};
- }
-
+ if (!Exists(connection))
+ {
+ return {};
+ }
+
std::vector> result;
- constexpr std::string_view depTableAlias = "dep";
- constexpr std::string_view minVersionAlias = "minV";
-
- StatementBuilder builder;
-
- // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep]
- // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version]
- builder.Select()
- .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name))
- .Column(QCol(minVersionAlias, VersionTable::ValueName()))
- .From({ s_DependenciesTable_Table_Name }).As(depTableAlias)
- .Join({ VersionTable::TableName() }).As(minVersionAlias)
- .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name));
-
- SQLite::Statement select = builder.Prepare(connection);
-
- while (select.Step())
- {
- Utility::NormalizedString version = "";
- if (!select.GetColumnIsNull(1))
- {
- version = select.GetColumn(1);
- }
-
- if (!version.empty())
- {
- result.emplace_back(std::make_pair(select.GetColumn(0), version));
- }
- }
-
+ constexpr std::string_view depTableAlias = "dep";
+ constexpr std::string_view minVersionAlias = "minV";
+
+ StatementBuilder builder;
+
+ // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep]
+ // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version]
+ builder.Select()
+ .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name))
+ .Column(QCol(minVersionAlias, VersionTable::ValueName()))
+ .From({ s_DependenciesTable_Table_Name }).As(depTableAlias)
+ .Join({ VersionTable::TableName() }).As(minVersionAlias)
+ .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name));
+
+ SQLite::Statement select = builder.Prepare(connection);
+
+ while (select.Step())
+ {
+ Utility::NormalizedString version = "";
+ if (!select.GetColumnIsNull(1))
+ {
+ version = select.GetColumn(1);
+ }
+
+ if (!version.empty())
+ {
+ result.emplace_back(std::make_pair(select.GetColumn(0), version));
+ }
+ }
+
return result;
- }
+ }
}
\ No newline at end of file
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h
index 703b58ccc1..5a37d6a4d9 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h
@@ -41,7 +41,7 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_4
static std::vector> GetDependentsById(const SQLite::Connection& connection, AppInstaller::Manifest::string_t packageId);
// Check dependencies table consistency.
- static bool DependenciesTableCheckConsistency(const SQLite::Connection& connection, bool log);
+ static bool CheckConsistency(const SQLite::Connection& connection, bool log);
// Checks if the row id is present in the column denoted by the value supplied.
static bool IsValueReferenced(const SQLite::Connection& connection, std::string_view valueName, SQLite::rowid_t valueRowId);
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp
index 9c3bf3a686..439c10e861 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp
@@ -117,7 +117,7 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_4
// If the v1.3 index was consistent, or if full logging of inconsistency was requested, check the v1.4 data.
if (result || log)
{
- result = DependenciesTable::DependenciesTableCheckConsistency(connection, log) && result;
+ result = DependenciesTable::CheckConsistency(connection, log) && result;
}
if (result || log)
diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp
index 38844df57c..a3d6703265 100644
--- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp
+++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp
@@ -1,202 +1,202 @@
-// Copyright (c) Microsoft Corporation.
-// Licensed under the MIT License.
-#include "pch.h"
-#include "Microsoft/Schema/1_5/Interface.h"
-#include "Microsoft/Schema/1_5/ArpVersionVirtualTable.h"
-#include "Microsoft/Schema/1_0/ManifestTable.h"
-#include "Microsoft/Schema/1_0/VersionTable.h"
-
-namespace AppInstaller::Repository::Microsoft::Schema::V1_5
-{
- Interface::Interface(Utility::NormalizationVersion normVersion) : V1_4::Interface(normVersion)
- {
- }
-
- Schema::Version Interface::GetVersion() const
- {
- return { 1, 5 };
- }
-
- void Interface::CreateTables(SQLite::Connection& connection, CreateOptions options)
- {
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createtables_v1_5");
-
- V1_4::Interface::CreateTables(connection, options);
-
- V1_0::ManifestTable::AddColumn(connection, { ArpMinVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::Int64});
- V1_0::ManifestTable::AddColumn(connection, { ArpMaxVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::Int64 });
-
- savepoint.Commit();
- }
-
- SQLite::rowid_t Interface::AddManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath)
- {
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "addmanifest_v1_5");
-
- SQLite::rowid_t manifestId = V1_4::Interface::AddManifest(connection, manifest, relativePath);
-
- auto arpVersionRange = manifest.GetArpVersionRange();
- Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString();
- Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString();
- SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion);
- SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion);
- V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId);
- V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId);
-
- savepoint.Commit();
-
- return manifestId;
- }
-
- std::pair Interface::UpdateManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath)
- {
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "updatemanifest_v1_5");
-
- auto [indexModified, manifestId] = V1_4::Interface::UpdateManifest(connection, manifest, relativePath);
-
- auto [oldMinVersionId, oldMaxVersionId] =
- V1_0::ManifestTable::GetIdsById(connection, manifestId);
-
- auto arpVersionRange = manifest.GetArpVersionRange();
- Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString();
- Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString();
-
- SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion);
- SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion);
-
- // For cleaning up the old entries after update if applicable
- bool cleanOldMinVersionId = false;
- bool cleanOldMaxVersionId = false;
-
- if (arpMinVersionId != oldMinVersionId)
- {
- V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId);
- cleanOldMinVersionId = true;
- indexModified = true;
- }
-
- if (arpMaxVersionId != oldMaxVersionId)
- {
- V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId);
- cleanOldMaxVersionId = true;
- indexModified = true;
- }
-
- if (cleanOldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMinVersionId))
- {
- V1_0::VersionTable::DeleteById(connection, oldMinVersionId);
- }
-
- if (cleanOldMaxVersionId && oldMaxVersionId != oldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMaxVersionId))
- {
- V1_0::VersionTable::DeleteById(connection, oldMaxVersionId);
- }
-
- savepoint.Commit();
-
- return { indexModified, manifestId };
- }
-
- void Interface::RemoveManifestById(SQLite::Connection& connection, SQLite::rowid_t manifestId)
- {
- // Get the old arp version ids of the values from the manifest table
- auto [arpMinVersionId, arpMaxVersionId] =
- V1_0::ManifestTable::GetIdsById(connection, manifestId);
-
- SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "RemoveManifestById_v1_5");
-
- // Removes the manifest.
- V1_4::Interface::RemoveManifestById(connection, manifestId);
-
- // Remove the versions that are not needed.
- if (NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMinVersionId))
- {
- V1_0::VersionTable::DeleteById(connection, arpMinVersionId);
- }
-
- if (arpMaxVersionId != arpMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMaxVersionId))
- {
- V1_0::VersionTable::DeleteById(connection, arpMaxVersionId);
- }
-
- savepoint.Commit();
- }
-
- bool Interface::NotNeeded(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, SQLite::rowid_t id) const
- {
- bool result = V1_4::Interface::NotNeeded(connection, tableName, valueName, id);
-
- if (result && tableName == V1_0::VersionTable::TableName())
- {
- if (valueName != V1_0::VersionTable::ValueName())
- {
- result = !V1_0::ManifestTable::IsValueReferenced(connection, V1_0::VersionTable::ValueName(), id) && result;
- }
- if (valueName != ArpMinVersionVirtualTable::ManifestColumnName())
- {
- result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMinVersionVirtualTable::ManifestColumnName(), id) && result;
- }
- if (valueName != ArpMaxVersionVirtualTable::ManifestColumnName())
- {
- result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMaxVersionVirtualTable::ManifestColumnName(), id) && result;
- }
- }
-
- return result;
- }
-
- bool Interface::CheckConsistency(const SQLite::Connection& connection, bool log) const
- {
- bool result = V1_4::Interface::CheckConsistency(connection, log);
-
- // If the v1.4 index was consistent, or if full logging of inconsistency was requested, check the v1.5 data.
- if (result || log)
- {
- result = V1_0::ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = V1_0::ManifestTable::CheckConsistency(connection, log) && result;
- }
-
- if (result || log)
- {
- result = ValidateArpVersionConsistency(connection, log) && result;
- }
-
- return result;
- }
-
- std::optional Interface::GetPropertyByManifestIdInternal(const SQLite::Connection& connection, SQLite::rowid_t manifestId, PackageVersionProperty property) const
- {
- switch (property)
- {
- case AppInstaller::Repository::PackageVersionProperty::ArpMinVersion:
- return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId));
- case AppInstaller::Repository::PackageVersionProperty::ArpMaxVersion:
- return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId));
- default:
- return V1_4::Interface::GetPropertyByManifestIdInternal(connection, manifestId, property);
- }
- }
-
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+#include "pch.h"
+#include "Microsoft/Schema/1_5/Interface.h"
+#include "Microsoft/Schema/1_5/ArpVersionVirtualTable.h"
+#include "Microsoft/Schema/1_0/ManifestTable.h"
+#include "Microsoft/Schema/1_0/VersionTable.h"
+
+namespace AppInstaller::Repository::Microsoft::Schema::V1_5
+{
+ Interface::Interface(Utility::NormalizationVersion normVersion) : V1_4::Interface(normVersion)
+ {
+ }
+
+ Schema::Version Interface::GetVersion() const
+ {
+ return { 1, 5 };
+ }
+
+ void Interface::CreateTables(SQLite::Connection& connection, CreateOptions options)
+ {
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createtables_v1_5");
+
+ V1_4::Interface::CreateTables(connection, options);
+
+ V1_0::ManifestTable::AddColumn(connection, { ArpMinVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::RowId });
+ V1_0::ManifestTable::AddColumn(connection, { ArpMaxVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::RowId });
+
+ savepoint.Commit();
+ }
+
+ SQLite::rowid_t Interface::AddManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath)
+ {
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "addmanifest_v1_5");
+
+ SQLite::rowid_t manifestId = V1_4::Interface::AddManifest(connection, manifest, relativePath);
+
+ auto arpVersionRange = manifest.GetArpVersionRange();
+ Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString();
+ Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString();
+ SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion);
+ SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion);
+ V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId);
+ V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId);
+
+ savepoint.Commit();
+
+ return manifestId;
+ }
+
+ std::pair Interface::UpdateManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath)
+ {
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "updatemanifest_v1_5");
+
+ auto [indexModified, manifestId] = V1_4::Interface::UpdateManifest(connection, manifest, relativePath);
+
+ auto [oldMinVersionId, oldMaxVersionId] =
+ V1_0::ManifestTable::GetIdsById(connection, manifestId);
+
+ auto arpVersionRange = manifest.GetArpVersionRange();
+ Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString();
+ Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString();
+
+ SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion);
+ SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion);
+
+ // For cleaning up the old entries after update if applicable
+ bool cleanOldMinVersionId = false;
+ bool cleanOldMaxVersionId = false;
+
+ if (arpMinVersionId != oldMinVersionId)
+ {
+ V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId);
+ cleanOldMinVersionId = true;
+ indexModified = true;
+ }
+
+ if (arpMaxVersionId != oldMaxVersionId)
+ {
+ V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId);
+ cleanOldMaxVersionId = true;
+ indexModified = true;
+ }
+
+ if (cleanOldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMinVersionId))
+ {
+ V1_0::VersionTable::DeleteById(connection, oldMinVersionId);
+ }
+
+ if (cleanOldMaxVersionId && oldMaxVersionId != oldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMaxVersionId))
+ {
+ V1_0::VersionTable::DeleteById(connection, oldMaxVersionId);
+ }
+
+ savepoint.Commit();
+
+ return { indexModified, manifestId };
+ }
+
+ void Interface::RemoveManifestById(SQLite::Connection& connection, SQLite::rowid_t manifestId)
+ {
+ // Get the old arp version ids of the values from the manifest table
+ auto [arpMinVersionId, arpMaxVersionId] =
+ V1_0::ManifestTable::GetIdsById(connection, manifestId);
+
+ SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "RemoveManifestById_v1_5");
+
+ // Removes the manifest.
+ V1_4::Interface::RemoveManifestById(connection, manifestId);
+
+ // Remove the versions that are not needed.
+ if (NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMinVersionId))
+ {
+ V1_0::VersionTable::DeleteById(connection, arpMinVersionId);
+ }
+
+ if (arpMaxVersionId != arpMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMaxVersionId))
+ {
+ V1_0::VersionTable::DeleteById(connection, arpMaxVersionId);
+ }
+
+ savepoint.Commit();
+ }
+
+ bool Interface::NotNeeded(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, SQLite::rowid_t id) const
+ {
+ bool result = V1_4::Interface::NotNeeded(connection, tableName, valueName, id);
+
+ if (result && tableName == V1_0::VersionTable::TableName())
+ {
+ if (valueName != V1_0::VersionTable::ValueName())
+ {
+ result = !V1_0::ManifestTable::IsValueReferenced(connection, V1_0::VersionTable::ValueName(), id) && result;
+ }
+ if (valueName != ArpMinVersionVirtualTable::ManifestColumnName())
+ {
+ result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMinVersionVirtualTable::ManifestColumnName(), id) && result;
+ }
+ if (valueName != ArpMaxVersionVirtualTable::ManifestColumnName())
+ {
+ result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMaxVersionVirtualTable::ManifestColumnName(), id) && result;
+ }
+ }
+
+ return result;
+ }
+
+ bool Interface::CheckConsistency(const SQLite::Connection& connection, bool log) const
+ {
+ bool result = V1_4::Interface::CheckConsistency(connection, log);
+
+ // If the v1.4 index was consistent, or if full logging of inconsistency was requested, check the v1.5 data.
+ if (result || log)
+ {
+ result = V1_0::ManifestTable::CheckConsistency(connection, log) && result;
+ }
+
+ if (result || log)
+ {
+ result = V1_0::ManifestTable::CheckConsistency(connection, log) && result;
+ }
+
+ if (result || log)
+ {
+ result = ValidateArpVersionConsistency(connection, log) && result;
+ }
+
+ return result;
+ }
+
+ std::optional Interface::GetPropertyByManifestIdInternal(const SQLite::Connection& connection, SQLite::rowid_t manifestId, PackageVersionProperty property) const
+ {
+ switch (property)
+ {
+ case AppInstaller::Repository::PackageVersionProperty::ArpMinVersion:
+ return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId));
+ case AppInstaller::Repository::PackageVersionProperty::ArpMaxVersion:
+ return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId));
+ default:
+ return V1_4::Interface::GetPropertyByManifestIdInternal(connection, manifestId, property);
+ }
+ }
+
bool Interface::ValidateArpVersionConsistency(const SQLite::Connection& connection, bool log) const
{
- try
- {
- bool result = true;
-
- // Search everything
- SearchRequest request;
- auto searchResult = Search(connection, request);
- for (auto const& match : searchResult.Matches)
- {
- // Get arp version ranges for each package to check
- std::vector ranges;
- auto versionKeys = GetVersionKeysById(connection, match.first);
- for (auto const& versionKey : versionKeys)
- {
+ try
+ {
+ bool result = true;
+
+ // Search everything
+ SearchRequest request;
+ auto searchResult = Search(connection, request);
+ for (auto const& match : searchResult.Matches)
+ {
+ // Get arp version ranges for each package to check
+ std::vector ranges;
+ auto versionKeys = GetVersionKeysById(connection, match.first);
+ for (auto const& versionKey : versionKeys)
+ {
auto manifestRowId = GetManifestIdByKey(connection, match.first, versionKey.GetVersion().ToString(), versionKey.GetChannel().ToString());
if (manifestRowId)
{
@@ -210,28 +210,28 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_5
{
ranges.emplace_back(Utility::VersionRange{ Utility::Version{ std::move(arpMinVersion) }, Utility::Version{ std::move(arpMaxVersion) } });
}
- }
- }
-
- // Check overlap
- if (Utility::HasOverlapInVersionRanges(ranges))
- {
- AICLI_LOG(Repo, Error, << "Overlapped Arp version ranges found for package. PackageRowId: " << match.first);
- result = false;
-
- if (!log)
- {
- break;
- }
- }
- }
-
- return result;
- }
- catch (...)
- {
- AICLI_LOG(Repo, Error, << "ValidateArpVersionConsistency() encountered internal error. Returning false.");
- return false;
+ }
+ }
+
+ // Check overlap
+ if (Utility::HasOverlapInVersionRanges(ranges))
+ {
+ AICLI_LOG(Repo, Error, << "Overlapped Arp version ranges found for package. PackageRowId: " << match.first);
+ result = false;
+
+ if (!log)
+ {
+ break;
+ }
+ }
+ }
+
+ return result;
+ }
+ catch (...)
+ {
+ AICLI_LOG(Repo, Error, << "ValidateArpVersionConsistency() encountered internal error. Returning false.");
+ return false;
}
- }
+ }
}
\ No newline at end of file
diff --git a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp
index 13e1a0b70d..e431375fb5 100644
--- a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp
+++ b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp
@@ -307,6 +307,20 @@ namespace AppInstaller::Repository::SQLite::Builder
return *this;
}
+ StatementBuilder& StatementBuilder::WhereValueContainsEmbeddedNullCharacter(std::string_view column)
+ {
+ OutputColumns(m_stream, " WHERE instr(", column);
+ m_stream << ",char(0))>0";
+ return *this;
+ }
+
+ StatementBuilder& StatementBuilder::WhereValueContainsEmbeddedNullCharacter(const QualifiedColumn& column)
+ {
+ OutputColumns(m_stream, " WHERE instr(", column);
+ m_stream << ",char(0))>0";
+ return *this;
+ }
+
StatementBuilder& StatementBuilder::Equals(details::unbound_t)
{
AppendOpAndBinder(Op::Equals);
diff --git a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h
index e82a8f4ea9..3bcb632d6e 100644
--- a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h
+++ b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h
@@ -217,6 +217,11 @@ namespace AppInstaller::Repository::SQLite::Builder
StatementBuilder& Where(std::string_view column);
StatementBuilder& Where(const QualifiedColumn& column);
+ // A full filter clause looking for an embedded null character.
+ // Is extremely specific to consistency checks, and so a more detailed construct is not required.
+ StatementBuilder& WhereValueContainsEmbeddedNullCharacter(std::string_view column);
+ StatementBuilder& WhereValueContainsEmbeddedNullCharacter(const QualifiedColumn& column);
+
// Indicate the operation of the filter clause.
template
StatementBuilder& Equals(const ValueType& value)
diff --git a/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp b/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp
index 5e72b13f6d..a19d720734 100644
--- a/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp
+++ b/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp
@@ -50,8 +50,14 @@ namespace AppInstaller::Repository::SQLite
THROW_IF_SQLITE_FAILED(sqlite3_bind_null(stmt, index));
}
+ void ThrowIfContainsEmbeddedNullCharacter(std::string_view v)
+ {
+ THROW_HR_IF(APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL, v.find('\0') != std::string_view::npos);
+ }
+
void ParameterSpecificsImpl::Bind(sqlite3_stmt* stmt, int index, const std::string& v)
{
+ ThrowIfContainsEmbeddedNullCharacter(v);
THROW_IF_SQLITE_FAILED(sqlite3_bind_text64(stmt, index, v.c_str(), v.size(), SQLITE_TRANSIENT, SQLITE_UTF8));
}
@@ -70,6 +76,7 @@ namespace AppInstaller::Repository::SQLite
}
else
{
+ ThrowIfContainsEmbeddedNullCharacter(v);
THROW_IF_SQLITE_FAILED(sqlite3_bind_text64(stmt, index, v.data(), v.size(), SQLITE_TRANSIENT, SQLITE_UTF8));
}
}