diff --git a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters index 2df7b1de94..d7db2a17cd 100644 --- a/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters +++ b/src/AppInstallerCLITests/AppInstallerCLITests.vcxproj.filters @@ -27,7 +27,16 @@ {52b26063-ccff-40ae-a420-243327dcca25} - + + + {b35e1a8b-1961-46d5-98e5-adc8e52c54a9} + + + {d055e02a-59c9-4ae4-9405-579dac6ff59b} + + + {13d4d227-0f04-4e57-a663-c3c535438ab3} + @@ -59,146 +68,146 @@ Source Files - - Source Files - - - Source Files - Source Files - + Source Files - + Source Files - - Source Files + + Source Files\Repository - - Source Files + + Source Files\CLI - - Source Files + + Source Files\CLI - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\CLI - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\CLI - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - Source Files + Source Files\Common - Source Files + Source Files\CLI - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\Repository - Source Files + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Common - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Repository - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - + Source Files - - Source Files + + Source Files\Common - - Source Files + + Source Files\Common - - Source Files + + Source Files\CLI + + + Source Files\CLI + + + Source Files\Common @@ -479,7 +488,7 @@ TestData - + TestData @@ -506,7 +515,7 @@ TestData - + TestData diff --git a/src/AppInstallerCLITests/CompositeSource.cpp b/src/AppInstallerCLITests/CompositeSource.cpp index e1eb0865a3..273fd27571 100644 --- a/src/AppInstallerCLITests/CompositeSource.cpp +++ b/src/AppInstallerCLITests/CompositeSource.cpp @@ -50,11 +50,11 @@ struct ComponentTestSource : public TestSource // A helper to create the sources used by the majority of tests in this file. struct CompositeTestSetup { - CompositeTestSetup() : Composite("*Tests") + CompositeTestSetup(CompositeSearchBehavior behavior = CompositeSearchBehavior::Installed) : Composite("*Tests") { Installed = std::make_shared("InstalledTestSource1"); Available = std::make_shared("AvailableTestSource1"); - Composite.SetInstalledSource(Source{ Installed }); + Composite.SetInstalledSource(Source{ Installed }, behavior); Composite.AddAvailableSource(Source{ Available }); } @@ -942,3 +942,23 @@ TEST_CASE("CompositeSource_TrackingFound_NotInstalled", "[CompositeSource]") REQUIRE(result.Matches.empty()); } + +TEST_CASE("CompositeSource_NullInstalledVersion", "[CompositeSource]") +{ + CompositeTestSetup setup; + setup.Installed->Everything.Matches.emplace_back(MakeAvailable(), Criteria()); + + // We are mostly testing to see if a null installed version causes an AV or not + SearchResult result = setup.Search(); + REQUIRE(result.Matches.size() == 0); +} + +TEST_CASE("CompositeSource_NullAvailableVersion", "[CompositeSource]") +{ + CompositeTestSetup setup{ CompositeSearchBehavior::AvailablePackages }; + setup.Available->Everything.Matches.emplace_back(MakeInstalled(), Criteria()); + + // We are mostly testing to see if a null available version causes an AV or not + SearchResult result = setup.Search(); + REQUIRE(result.Matches.size() == 1); +} diff --git a/src/AppInstallerCLITests/SQLiteIndex.cpp b/src/AppInstallerCLITests/SQLiteIndex.cpp index f63c56bb4e..63fbf627ad 100644 --- a/src/AppInstallerCLITests/SQLiteIndex.cpp +++ b/src/AppInstallerCLITests/SQLiteIndex.cpp @@ -3153,4 +3153,31 @@ TEST_CASE("SQLiteIndex_ManifestArpVersion_ValidateManifestAgainstIndex", "[sqlit // Add different version should result in failure. manifest.Version = "10.1"; REQUIRE_THROWS(ValidateManifestArpVersion(&index, manifest)); -} \ No newline at end of file +} + +TEST_CASE("SQLiteIndex_CheckConsistency_FindEmbeddedNull", "[sqliteindex]") +{ + TempFile tempFile{ "repolibtest_tempdb"s, ".db"s }; + INFO("Using temporary file named: " << tempFile.GetPath()); + + SQLiteIndex index = CreateTestIndex(tempFile, Schema::Version::Latest()); + + Manifest manifest; + manifest.Id = "Foo"; + manifest.Version = "10.0"; + manifest.Installers.push_back({}); + manifest.Installers[0].InstallerType = InstallerTypeEnum::Exe; + manifest.Installers[0].AppsAndFeaturesEntries.push_back({}); + manifest.Installers[0].AppsAndFeaturesEntries[0].DisplayVersion = "1.0"; + manifest.Installers[0].AppsAndFeaturesEntries.push_back({}); + manifest.Installers[0].AppsAndFeaturesEntries[1].DisplayVersion = "1.1"; + + index.AddManifest(manifest, "path"); + + // Inject a null character using SQL without binding since we block it + Connection connection = Connection::Create(tempFile, Connection::OpenDisposition::ReadWrite); + Statement update = Statement::Create(connection, "Update versions set version = '10.0'||char(0)||'After Null' where version = '10.0'"); + update.Execute(); + + REQUIRE(!index.CheckConsistency(true)); +} diff --git a/src/AppInstallerCLITests/SQLiteWrapper.cpp b/src/AppInstallerCLITests/SQLiteWrapper.cpp index a003ef975d..d27b20fd24 100644 --- a/src/AppInstallerCLITests/SQLiteWrapper.cpp +++ b/src/AppInstallerCLITests/SQLiteWrapper.cpp @@ -2,6 +2,7 @@ // Licensed under the MIT License. #include "pch.h" #include "TestCommon.h" +#include #include #include @@ -286,6 +287,19 @@ TEST_CASE("SQLiteWrapper_EscapeStringForLike", "[sqlitewrapper]") REQUIRE(expected == output); } +TEST_CASE("SQLiteWrapper_BindWithEmbeddedNull", "[sqlitewrapper]") +{ + Connection connection = Connection::Create(SQLITE_MEMORY_DB_CONNECTION_TARGET, Connection::OpenDisposition::Create); + + CreateSimpleTestTable(connection); + + int firstVal = 1; + std::string secondVal = "test"; + secondVal[1] = '\0'; + + REQUIRE_THROWS_HR(InsertIntoSimpleTestTable(connection, firstVal, secondVal), APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL); +} + TEST_CASE("SQLBuilder_SimpleSelectBind", "[sqlbuilder]") { Connection connection = Connection::Create(SQLITE_MEMORY_DB_CONNECTION_TARGET, Connection::OpenDisposition::Create); diff --git a/src/AppInstallerCLITests/Strings.cpp b/src/AppInstallerCLITests/Strings.cpp index 547e5fd66c..86183ee224 100644 --- a/src/AppInstallerCLITests/Strings.cpp +++ b/src/AppInstallerCLITests/Strings.cpp @@ -99,6 +99,10 @@ TEST_CASE("NormalizedString", "[strings]") // Ligature fi => f + i std::string_view input2 = u8"\xFB01"; REQUIRE(NormalizedString(input2) == u8"fi"); + + // Embedded null + std::string_view input3{ "Test\0Case", 9 }; + REQUIRE(NormalizedString(input3) == "Test Case"); } TEST_CASE("Trim", "[strings]") @@ -203,4 +207,12 @@ TEST_CASE("SplitIntoWords", "[strings]") // 私のテスト = "My test" according to an online translator // Split as "私" "の" "テスト" REQUIRE(SplitIntoWords("\xe7\xa7\x81\xe3\x81\xae\xe3\x83\x86\xe3\x82\xb9\xe3\x83\x88") == std::vector{ "\xe7\xa7\x81", "\xe3\x81\xae", "\xe3\x83\x86\xe3\x82\xb9\xe3\x83\x88" }); -} \ No newline at end of file +} + +TEST_CASE("ReplaceEmbeddedNullCharacters", "[strings]") +{ + std::string test = "Test Parts"; + test[4] = '\0'; + ReplaceEmbeddedNullCharacters(test); + REQUIRE(test == "Test Parts"); +} diff --git a/src/AppInstallerCommonCore/AppInstallerStrings.cpp b/src/AppInstallerCommonCore/AppInstallerStrings.cpp index 9d1e6cd9e4..d09fc0e03c 100644 --- a/src/AppInstallerCommonCore/AppInstallerStrings.cpp +++ b/src/AppInstallerCommonCore/AppInstallerStrings.cpp @@ -333,6 +333,17 @@ namespace AppInstaller::Utility } } + void ReplaceEmbeddedNullCharacters(std::string& s, char c) + { + for (size_t i = 0; i < s.length(); ++i) + { + if (s[i] == '\0') + { + s[i] = c; + } + } + } + std::string ToLower(std::string_view in) { std::string result(in); diff --git a/src/AppInstallerCommonCore/Errors.cpp b/src/AppInstallerCommonCore/Errors.cpp index f41de288e4..3f36a8199e 100644 --- a/src/AppInstallerCommonCore/Errors.cpp +++ b/src/AppInstallerCommonCore/Errors.cpp @@ -220,6 +220,8 @@ namespace AppInstaller return "Organization policies are preventing installation. Contact your admin."; case APPINSTALLER_CLI_ERROR_INSTALL_DEPENDENCIES: return "Failed to install package dependencies."; + case APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL: + return "Embedded null characters are disallowed for SQLite"; default: return "Unknown Error Code"; } diff --git a/src/AppInstallerCommonCore/Public/AppInstallerErrors.h b/src/AppInstallerCommonCore/Public/AppInstallerErrors.h index 8284fcc00d..0435b00399 100644 --- a/src/AppInstallerCommonCore/Public/AppInstallerErrors.h +++ b/src/AppInstallerCommonCore/Public/AppInstallerErrors.h @@ -102,6 +102,7 @@ #define APPINSTALLER_CLI_ERROR_PORTABLE_UNINSTALL_FAILED ((HRESULT)0x8A150057) #define APPINSTALLER_CLI_ERROR_ARP_VERSION_VALIDATION_FAILED ((HRESULT)0x8A150058) #define APPINSTALLER_CLI_ERROR_UNSUPPORTED_ARGUMENT ((HRESULT)0x8A150059) +#define APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL ((HRESULT)0x8A15005A) // Install errors. #define APPINSTALLER_CLI_ERROR_INSTALL_PACKAGE_IN_USE ((HRESULT)0x8A150101) diff --git a/src/AppInstallerCommonCore/Public/AppInstallerStrings.h b/src/AppInstallerCommonCore/Public/AppInstallerStrings.h index d8e4f0216d..140c5b6800 100644 --- a/src/AppInstallerCommonCore/Public/AppInstallerStrings.h +++ b/src/AppInstallerCommonCore/Public/AppInstallerStrings.h @@ -24,22 +24,26 @@ namespace AppInstaller::Utility // Normalizes a UTF16 string to the given form. std::wstring Normalize(std::wstring_view input, NORM_FORM form = NORM_FORM::NormalizationKC); + // Replaces any embedded null character found within the string. + void ReplaceEmbeddedNullCharacters(std::string& s, char c = ' '); + // Type to hold and force a normalized UTF8 string. - template + // Also enables further normalization by replacing embedded null characters with spaces. + template struct NormalizedUTF8 : public std::string { NormalizedUTF8() = default; template - NormalizedUTF8(const char(&s)[Size]) : std::string(Normalize(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }, Form)) {} + NormalizedUTF8(const char(&s)[Size]) { AssignValue(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }); } - NormalizedUTF8(std::string_view sv) : std::string(Normalize(sv, Form)) {} + NormalizedUTF8(std::string_view sv) { AssignValue(sv); } - NormalizedUTF8(std::string& s) : std::string(Normalize(s, Form)) {} - NormalizedUTF8(const std::string& s) : std::string(Normalize(s, Form)) {} - NormalizedUTF8(std::string&& s) : std::string(Normalize(s, Form)) {} + NormalizedUTF8(std::string& s) { AssignValue(s); } + NormalizedUTF8(const std::string& s) { AssignValue(s); } + NormalizedUTF8(std::string&& s) { AssignValue(s); } - NormalizedUTF8(std::wstring_view sv) : std::string(ConvertToUTF8(Normalize(sv, Form))) {} + NormalizedUTF8(std::wstring_view sv) { AssignValue(sv); } NormalizedUTF8(const NormalizedUTF8& other) = default; NormalizedUTF8& operator=(const NormalizedUTF8& other) = default; @@ -50,30 +54,51 @@ namespace AppInstaller::Utility template NormalizedUTF8& operator=(const char(&s)[Size]) { - assign(Normalize(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }, Form)); + AssignValue(std::string_view{ s, (s[Size - 1] == '\0' ? Size - 1 : Size) }); return *this; } NormalizedUTF8& operator=(std::string_view sv) { - assign(Normalize(sv, Form)); + AssignValue(sv); return *this; } NormalizedUTF8& operator=(const std::string& s) { - assign(Normalize(s, Form)); + AssignValue(s); return *this; } NormalizedUTF8& operator=(std::string&& s) { - assign(Normalize(s, Form)); + AssignValue(s); return *this; } + + private: + void AssignValue(std::string_view sv) + { + assign(Normalize(sv, Form)); + + if constexpr (ConvertEmbeddedNullToSpace) + { + ReplaceEmbeddedNullCharacters(*this); + } + } + + void AssignValue(std::wstring_view sv) + { + assign(ConvertToUTF8(Normalize(sv, Form))); + + if constexpr (ConvertEmbeddedNullToSpace) + { + ReplaceEmbeddedNullCharacters(*this); + } + } }; - using NormalizedString = NormalizedUTF8<>; + using NormalizedString = NormalizedUTF8; // Compares the two UTF8 strings in a case insensitive manner. // Use this if one of the values is a known value, and thus ToLower is sufficient. diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp index 5b97c3c423..3fcdedb69b 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/Interface_1_0.cpp @@ -393,53 +393,35 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0 { bool result = true; - // Check the manifest table references to it's 1:1 tables - if (result || log) - { - result = ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = ManifestTable::CheckConsistency(connection, log) && result; +#define AICLI_CHECK_CONSISTENCY(_check_) \ + if (result || log) \ + { \ + result = _check_ && result; \ } - // Check the pathpaths table for consistency - if (result || log) - { - result = PathPartTable::CheckConsistency(connection, log) && result; - } + // Check the manifest table references to it's 1:1 tables + AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(ManifestTable::CheckConsistency(connection, log)); + + // Check the 1:1 tables' consistency + AICLI_CHECK_CONSISTENCY(IdTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(NameTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(MonikerTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(VersionTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(ChannelTable::CheckConsistency(connection, log)); + + // Check the pathparts table for consistency + AICLI_CHECK_CONSISTENCY(PathPartTable::CheckConsistency(connection, log)); // Check the 1:N map tables for consistency - if (result || log) - { - result = TagsTable::CheckConsistency(connection, log) && result; - } + AICLI_CHECK_CONSISTENCY(TagsTable::CheckConsistency(connection, log)); + AICLI_CHECK_CONSISTENCY(CommandsTable::CheckConsistency(connection, log)); - if (result || log) - { - result = CommandsTable::CheckConsistency(connection, log) && result; - } +#undef AICLI_CHECK_CONSISTENCY return result; } diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp index 4980da4351..6dcaca1c6b 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToManyTable.cpp @@ -355,6 +355,13 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0 result = result && secondaryResult; } + if (!result && !log) + { + return result; + } + + result = OneToOneTableCheckConsistency(connection, tableName, valueName, log) && result; + return result; } diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp index ba5d1347a8..be3f4bd239 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.cpp @@ -176,5 +176,34 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0 builder.Execute(connection); } + + bool OneToOneTableCheckConsistency(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, bool log) + { + // Build a select statement to find values that contain an embedded null character + // Such as: + // Select count(*) from table where instr(value,char(0))>0 + SQLite::Builder::StatementBuilder builder; + builder. + Select({ SQLite::RowIDName, valueName }). + From(tableName). + WhereValueContainsEmbeddedNullCharacter(valueName); + + SQLite::Statement select = builder.Prepare(connection); + bool result = true; + + while (select.Step()) + { + result = false; + + if (!log) + { + break; + } + + AICLI_LOG(Repo, Info, << " [INVALID] value in table [" << tableName << "] at row [" << select.GetColumn(0) << "] contains an embedded null character and starts with [" << select.GetColumn(1) << "]"); + } + + return result; + } } } diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h index 51e0360118..02594c758b 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/OneToOneTable.h @@ -37,6 +37,9 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0 // Removes the given row by its rowid if it is no longer referenced. void OneToOneTableDeleteById(SQLite::Connection& connection, std::string_view tableName, SQLite::rowid_t id); + + // Checks the consistency of the table. + bool OneToOneTableCheckConsistency(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, bool log); } // A table that represents a value that is 1:1 with a primary entry. @@ -132,5 +135,11 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0 { return details::OneToOneTableIsEmpty(connection, TableInfo::TableName()); } + + // Checks the consistency of the table. + static bool CheckConsistency(const SQLite::Connection& connection, bool log) + { + return details::OneToOneTableCheckConsistency(connection, TableInfo::TableName(), TableInfo::ValueName(), log); + } }; } diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp index 5cea7d41c7..d6185e8efb 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_0/PathPartTable.cpp @@ -373,27 +373,59 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_0 // Select l.rowid, l.parent from pathparts as l left outer join pathparts as r on l.parent = r.rowid where l.parent is not null and r.pathpart is null constexpr std::string_view s_left = "left"sv; constexpr std::string_view s_right = "right"sv; - - SQLite::Builder::StatementBuilder builder; - builder. - Select({ QCol(s_left, SQLite::RowIDName), QCol(s_left, s_PathPartTable_ParentValue_Name) }). - From(s_PathPartTable_Table_Name).As(s_left). - LeftOuterJoin(s_PathPartTable_Table_Name).As(s_right).On(QCol(s_left, s_PathPartTable_ParentValue_Name), QCol(s_right, SQLite::RowIDName)). - Where(QCol(s_left, s_PathPartTable_ParentValue_Name)).IsNotNull().And(QCol(s_right, s_PathPartTable_PartValue_Name)).IsNull(); - - SQLite::Statement select = builder.Prepare(connection); bool result = true; - while (select.Step()) { - result = false; + SQLite::Builder::StatementBuilder builder; + builder. + Select({ QCol(s_left, SQLite::RowIDName), QCol(s_left, s_PathPartTable_ParentValue_Name) }). + From(s_PathPartTable_Table_Name).As(s_left). + LeftOuterJoin(s_PathPartTable_Table_Name).As(s_right).On(QCol(s_left, s_PathPartTable_ParentValue_Name), QCol(s_right, SQLite::RowIDName)). + Where(QCol(s_left, s_PathPartTable_ParentValue_Name)).IsNotNull().And(QCol(s_right, s_PathPartTable_PartValue_Name)).IsNull(); + + SQLite::Statement select = builder.Prepare(connection); - if (!log) + while (select.Step()) { - break; + result = false; + + if (!log) + { + break; + } + + AICLI_LOG(Repo, Info, << " [INVALID] pathparts [" << select.GetColumn(0) << "] refers to " << s_PathPartTable_ParentValue_Name << " [" << select.GetColumn(1) << "]"); } + } - AICLI_LOG(Repo, Info, << " [INVALID] pathparts [" << select.GetColumn(0) << "] refers to " << s_PathPartTable_ParentValue_Name << " [" << select.GetColumn(1) << "]"); + if (!result && !log) + { + return result; + } + + { + // Build a select statement to find values that contain an embedded null character + // Such as: + // Select count(*) from table where instr(value,char(0))>0 + SQLite::Builder::StatementBuilder builder; + builder. + Select({ SQLite::RowIDName, s_PathPartTable_PartValue_Name }). + From(s_PathPartTable_Table_Name). + WhereValueContainsEmbeddedNullCharacter(s_PathPartTable_PartValue_Name); + + SQLite::Statement select = builder.Prepare(connection); + + while (select.Step()) + { + result = false; + + if (!log) + { + break; + } + + AICLI_LOG(Repo, Info, << " [INVALID] value in table [" << s_PathPartTable_Table_Name << "] at row [" << select.GetColumn(0) << "] contains an embedded null character and starts with [" << select.GetColumn(1) << "]"); + } } return result; diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp index 381e5ef29c..d4c7668098 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.cpp @@ -1,560 +1,560 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -#include "pch.h" -#include "DependenciesTable.h" -#include "SQLiteStatementBuilder.h" -#include "winget\DependenciesGraph.h" -#include "Microsoft/Schema/1_0/OneToOneTable.h" -#include "Microsoft/Schema/1_0/IdTable.h" -#include "Microsoft/Schema/1_0/ManifestTable.h" -#include "Microsoft/Schema/1_0/VersionTable.h" -#include "Microsoft/Schema/1_0/Interface.h" -#include "Microsoft/Schema/1_0/ChannelTable.h" - -namespace AppInstaller::Repository::Microsoft::Schema::V1_4 -{ - using namespace AppInstaller; - using namespace std::string_view_literals; - using namespace SQLite::Builder; - using namespace Schema::V1_0; - using QCol = SQLite::Builder::QualifiedColumn; - - static constexpr std::string_view s_DependenciesTable_Table_Name = "dependencies"sv; - static constexpr std::string_view s_DependenciesTable_Index_Name = "dependencies_pkindex"sv; - static constexpr std::string_view s_DependenciesTable_Manifest_Column_Name = "manifest"sv; - static constexpr std::string_view s_DependenciesTable_MinVersion_Column_Name = "min_version"sv; - static constexpr std::string_view s_DependenciesTable_PackageId_Column_Name = "package_id"; - - namespace - { - struct DependencyTableRow - { - SQLite::rowid_t m_packageRowId; - SQLite::rowid_t m_manifestRowId; - - // Ideally this should be version row id, the version string is more needed than row id, - // this prevents converting back and forth between version row id and version string. - std::optional m_version; - - bool operator <(const DependencyTableRow& rhs) const - { - auto lhsVersion = m_version.has_value() ? m_version.value() : ""; - auto rhsVersion = rhs.m_version.has_value() ? rhs.m_version.value() : ""; - return std::tie(m_packageRowId, m_manifestRowId, lhsVersion) < std::tie(rhs.m_packageRowId, rhs.m_manifestRowId, rhsVersion); - } - }; - - void ThrowOnMissingPackageNodes(std::vector& missingPackageNodes) - { - if (!missingPackageNodes.empty()) - { - std::string missingPackages{ missingPackageNodes.begin()->Id }; - std::for_each( - missingPackageNodes.begin() + 1, - missingPackageNodes.end(), - [&](auto& dep) { missingPackages.append(", " + dep.Id); }); - THROW_HR_MSG(APPINSTALLER_CLI_ERROR_MISSING_PACKAGE, "Missing packages: %hs", missingPackages.c_str()); - } - } - - std::set GetAndLinkDependencies( - SQLite::Connection& connection, - const Manifest::Manifest& manifest, - SQLite::rowid_t manifestRowId, - Manifest::DependencyType dependencyType) - { - std::set dependencies; - std::vector missingPackageNodes; - - for (const auto& installer : manifest.Installers) - { - installer.Dependencies.ApplyToType(dependencyType, [&](Manifest::Dependency dependency) - { - auto packageRowId = IdTable::SelectIdByValue(connection, dependency.Id); - std::optional version; - - if (!packageRowId.has_value()) - { - missingPackageNodes.emplace_back(dependency); - return; - } - - if (dependency.MinVersion.has_value()) - { - version = dependency.MinVersion.value().ToString(); - } - - dependencies.emplace(DependencyTableRow{ packageRowId.value(), manifestRowId, version }); - }); - } - - ThrowOnMissingPackageNodes(missingPackageNodes); - - return dependencies; - } - - bool RemoveDependenciesByRowIds(SQLite::Connection& connection, std::vector dependencyTableRows) - { - using namespace SQLite::Builder; - bool tableUpdated = false; - if (dependencyTableRows.empty()) - { - return tableUpdated; - } - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_rowid"); - - SQLite::Builder::StatementBuilder builder; - builder - .DeleteFrom(s_DependenciesTable_Table_Name) - .Where(s_DependenciesTable_PackageId_Column_Name).Equals(Unbound) - .And(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound); - - - SQLite::Statement deleteStmt = builder.Prepare(connection); - for (auto row : dependencyTableRows) - { - deleteStmt.Reset(); - deleteStmt.Bind(1, row.m_packageRowId); - deleteStmt.Bind(2, row.m_manifestRowId); - deleteStmt.Execute(true); - tableUpdated = true; - } - - savepoint.Commit(); - return tableUpdated; - } - - bool InsertManifestDependencies( - SQLite::Connection& connection, - std::set& dependenciesTableRows) - { - using namespace SQLite::Builder; - using namespace Schema::V1_0; - bool tableUpdated = false; - - StatementBuilder insertBuilder; - insertBuilder.InsertInto(s_DependenciesTable_Table_Name) - .Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_MinVersion_Column_Name, s_DependenciesTable_PackageId_Column_Name }) - .Values(Unbound, Unbound, Unbound); - SQLite::Statement insert = insertBuilder.Prepare(connection); - - - for (const auto& dep : dependenciesTableRows) - { - insert.Reset(); - insert.Bind(1, dep.m_manifestRowId); - - if (dep.m_version.has_value()) - { - insert.Bind(2, VersionTable::EnsureExists(connection, dep.m_version.value())); - } - else - { - insert.Bind(2, nullptr); - } - - insert.Bind(3, dep.m_packageRowId); - - insert.Execute(true); - tableUpdated = true; - } - - return tableUpdated; - } - } - - bool DependenciesTable::Exists(const SQLite::Connection& connection) - { - using namespace SQLite; - - Builder::StatementBuilder builder; - builder.Select(Builder::RowCount).From(Builder::Schema::MainTable). - Where(Builder::Schema::TypeColumn).Equals(Builder::Schema::Type_Table).And(Builder::Schema::NameColumn).Equals(s_DependenciesTable_Table_Name); - - Statement statement = builder.Prepare(connection); - THROW_HR_IF(E_UNEXPECTED, !statement.Step()); - return statement.GetColumn(0) != 0; - } - - std::string_view DependenciesTable::TableName() - { - return s_DependenciesTable_Table_Name; - } - - void DependenciesTable::Create(SQLite::Connection& connection) - { - using namespace SQLite::Builder; - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createDependencyTable_v1_4"); - constexpr std::string_view dependencyIndexByVersionId = "dependencies_version_id_index"; - constexpr std::string_view dependencyIndexByPackageId = "dependencies_package_id_index"; - - StatementBuilder createTableBuilder; - createTableBuilder.CreateTable(TableName()).BeginColumns(); - createTableBuilder.Column(IntegerPrimaryKey()); - - std::array notNullableDependenciesColumns - { - DependenciesTableColumnInfo{ s_DependenciesTable_Manifest_Column_Name }, - DependenciesTableColumnInfo{ s_DependenciesTable_PackageId_Column_Name } - }; - - std::array nullableDependenciesColumns - { - DependenciesTableColumnInfo{ s_DependenciesTable_MinVersion_Column_Name } - }; - - // Add dependencies column tables not null columns. - for (const DependenciesTableColumnInfo& value : notNullableDependenciesColumns) - { - createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId).NotNull()); - } - - // Add dependencies column tables null columns. - for (const DependenciesTableColumnInfo& value : nullableDependenciesColumns) - { - createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId)); - } - - createTableBuilder.EndColumns(); - - createTableBuilder.Execute(connection); - - // Primary key index by package rowid and manifest rowid. - StatementBuilder createPKIndexBuilder; - createPKIndexBuilder.CreateUniqueIndex(s_DependenciesTable_Index_Name).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_PackageId_Column_Name }); - createPKIndexBuilder.Execute(connection); - - // Index of dependency by Manifest id. - StatementBuilder createIndexByManifestIdBuilder; - createIndexByManifestIdBuilder.CreateIndex(dependencyIndexByVersionId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_MinVersion_Column_Name }); - createIndexByManifestIdBuilder.Execute(connection); - - // Index of dependency by package id. - StatementBuilder createIndexByPackageIdBuilder; - createIndexByPackageIdBuilder.CreateIndex(dependencyIndexByPackageId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_PackageId_Column_Name }); - createIndexByPackageIdBuilder.Execute(connection); - - savepoint.Commit(); - } - - void DependenciesTable::AddDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId) - { - if (!Exists(connection)) - { - return; - } - - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "add_dependencies_v1_4"); - - auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package); - if (!dependencies.size()) - { - return; - } - - InsertManifestDependencies(connection, dependencies); - - savepoint.Commit(); - } - - bool DependenciesTable::UpdateDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId) - { - if (!Exists(connection)) - { - return false; - } - - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "update_dependencies_v1_4"); - - const auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package); - auto existingDependencies = GetDependenciesByManifestRowId(connection, manifestRowId); - - // Get dependencies to add. - std::set toAddDependencies; - std::copy_if( - dependencies.begin(), - dependencies.end(), - std::inserter(toAddDependencies, toAddDependencies.begin()), - [&](DependencyTableRow dep) - { - Utility::NormalizedString version = dep.m_version.has_value() ? dep.m_version.value() : ""; - return existingDependencies.find(std::make_pair(dep.m_packageRowId, version)) == existingDependencies.end(); - } - ); - - // Get dependencies to remove. - std::vector toRemoveDependencies; - std::for_each( - existingDependencies.begin(), - existingDependencies.end(), - [&](std::pair row) - { - if (dependencies.find(DependencyTableRow{row.first, manifestRowId, row.second}) == dependencies.end()) - { - toRemoveDependencies.emplace_back(DependencyTableRow{ row.first, manifestRowId }); - } - } - ); - - bool tableUpdated = InsertManifestDependencies(connection, toAddDependencies); - tableUpdated = RemoveDependenciesByRowIds(connection, toRemoveDependencies) || tableUpdated; - savepoint.Commit(); - - return tableUpdated; - } - - void DependenciesTable::RemoveDependencies(SQLite::Connection& connection, SQLite::rowid_t manifestRowId) - { - if (!Exists(connection)) - { - return; - } - - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_manifest_v1_4"); - - SQLite::Builder::StatementBuilder builder; - builder.DeleteFrom(s_DependenciesTable_Table_Name).Where(s_DependenciesTable_Manifest_Column_Name).Equals(manifestRowId); - - builder.Execute(connection); - savepoint.Commit(); - } - - std::vector> DependenciesTable::GetDependentsById(const SQLite::Connection& connection, Manifest::string_t packageId) - { - constexpr std::string_view depTableAlias = "dep"; - constexpr std::string_view minVersionAlias = "minV"; - constexpr std::string_view packageIdAlias = "pId"; - - - StatementBuilder builder; - // Find all manifest that depend on this package. - // SELECT [dep].[manifest], [pId].[id], [minV].[version] FROM [dependencies] AS [dep] - // JOIN [versions] AS [minV] ON [dep].[min_version] = [minV].[rowid] - // JOIN [ids] AS [pId] ON [pId].[rowid] = [dep].[package_id] - // WHERE [pId].[id] = ? - builder.Select() - .Column(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name)) - .Column(QCol(packageIdAlias, IdTable::ValueName())) - .Column(QCol(minVersionAlias, VersionTable::ValueName())) - .From({ s_DependenciesTable_Table_Name }).As(depTableAlias) - .Join({ VersionTable::TableName() }).As(minVersionAlias) - .On(QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name), QCol(minVersionAlias, SQLite::RowIDName)) - .Join({ IdTable::TableName() }).As(packageIdAlias) - .On(QCol(packageIdAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name)) - .Where(QCol(packageIdAlias, IdTable::ValueName())).Equals(Unbound); - - SQLite::Statement stmt = builder.Prepare(connection); - stmt.Bind(1, std::string{ packageId }); - - std::vector> resultSet; - - while (stmt.Step()) - { - resultSet.emplace_back( - std::make_pair(stmt.GetColumn(0), Utility::NormalizedString(stmt.GetColumn(2)))); - } - - return resultSet; - } - - std::set> DependenciesTable::GetDependenciesByManifestRowId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId) - { - SQLite::Builder::StatementBuilder builder; - - constexpr std::string_view depTableAlias = "dep"; - constexpr std::string_view minVersionAlias = "minV"; - - std::set> resultSet; - - // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep] - // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version] - // WHERE [dep].[manifest] = ? - builder.Select() - .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name)) - .Column(QCol(minVersionAlias, VersionTable::ValueName())) - .From({ s_DependenciesTable_Table_Name }).As(depTableAlias) - .Join({ VersionTable::TableName() }).As(minVersionAlias) - .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name)) - .Where(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name)).Equals(Unbound); - - SQLite::Statement select = builder.Prepare(connection); - - select.Bind(1, manifestRowId); - while (select.Step()) - { - Utility::NormalizedString version = ""; - if (!select.GetColumnIsNull(1)) - { - version = select.GetColumn(1); - } - resultSet.emplace(std::make_pair(select.GetColumn(0), version)); - } - - return resultSet; - } - - void DependenciesTable::PrepareForPackaging(SQLite::Connection& connection) - { - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "prepareForPacking_V1_4"); - - StatementBuilder dropIndexBuilder; - dropIndexBuilder.DropIndex({ s_DependenciesTable_Index_Name }); - dropIndexBuilder.Execute(connection); - - StatementBuilder dropTableBuilder; - dropTableBuilder.DropTable({ s_DependenciesTable_Table_Name }); - dropTableBuilder.Execute(connection); - - savepoint.Commit(); - } - - bool DependenciesTable::DependenciesTableCheckConsistency(const SQLite::Connection& connection, bool log) - { - StatementBuilder builder; - - if (!Exists(connection)) - { - return true; - } - - builder.Select(QCol(s_DependenciesTable_Table_Name, SQLite::RowIDName)) - .From(s_DependenciesTable_Table_Name) - .LeftOuterJoin(IdTable::TableName()) - .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_PackageId_Column_Name), QCol(IdTable::TableName(), SQLite::RowIDName)) - .LeftOuterJoin(ManifestTable::TableName()) - .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_Manifest_Column_Name), QCol(ManifestTable::TableName(), SQLite::RowIDName)) - .LeftOuterJoin(VersionTable::TableName()) - .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name), QCol(VersionTable::TableName(), SQLite::RowIDName)) - .Where(QCol(ManifestTable::TableName(), SQLite::RowIDName)).IsNull() - .Or(QCol(VersionTable::TableName(), SQLite::RowIDName)).IsNull().And(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name)).IsNotNull() - .Or(QCol(IdTable::TableName(), SQLite::RowIDName)).IsNull(); - - SQLite::Statement select = builder.Prepare(connection); - - bool result = true; - - while (select.Step()) - { - result = false; - - if (!log) - { - break; - } - - AICLI_LOG(Repo, Info, << " [INVALID] rowid [" << select.GetColumn(0) << "]"); - } - - return result; - } - - bool DependenciesTable::IsValueReferenced(const SQLite::Connection& connection, std::string_view tableName, SQLite::rowid_t valueRowId) - { - if (!Exists(connection)) - { - return false; - } - - StatementBuilder builder; - - if (tableName != V1_0::VersionTable::TableName()) - { - return false; - } - - std::array columns = { s_DependenciesTable_MinVersion_Column_Name }; - bool referenced = false; - - for(auto column: columns) - { - builder.Select(SQLite::RowIDName).From(s_DependenciesTable_Table_Name).Where(column).Equals(Unbound).Limit(1); - - SQLite::Statement select = builder.Prepare(connection); - - select.Bind(1, valueRowId); - if (select.Step()) - { - referenced = true; - break; - } - } - - return referenced; - } - - std::vector DependenciesTable::GetDependenciesMinVersionsRowIdByManifestId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId) - { - if (!Exists(connection)) - { - return {}; - } - - StatementBuilder builder; - - std::vector result; - - // Find all versions for manifest row. - // SELECT [min_version] FROM [dependencies] - // WHERE [manifest] = ? - builder.Select() - .Column(s_DependenciesTable_MinVersion_Column_Name) - .From({ s_DependenciesTable_Table_Name }) - .Where(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound); - - auto select = builder.Prepare(connection); - - select.Bind(1, manifestRowId); - - while (select.Step()) - { - if (!select.GetColumnIsNull(0)) - { - result.emplace_back(select.GetColumn(0)); - } - } - - return result; - } - +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include "pch.h" +#include "DependenciesTable.h" +#include "SQLiteStatementBuilder.h" +#include "winget\DependenciesGraph.h" +#include "Microsoft/Schema/1_0/OneToOneTable.h" +#include "Microsoft/Schema/1_0/IdTable.h" +#include "Microsoft/Schema/1_0/ManifestTable.h" +#include "Microsoft/Schema/1_0/VersionTable.h" +#include "Microsoft/Schema/1_0/Interface.h" +#include "Microsoft/Schema/1_0/ChannelTable.h" + +namespace AppInstaller::Repository::Microsoft::Schema::V1_4 +{ + using namespace AppInstaller; + using namespace std::string_view_literals; + using namespace SQLite::Builder; + using namespace Schema::V1_0; + using QCol = SQLite::Builder::QualifiedColumn; + + static constexpr std::string_view s_DependenciesTable_Table_Name = "dependencies"sv; + static constexpr std::string_view s_DependenciesTable_Index_Name = "dependencies_pkindex"sv; + static constexpr std::string_view s_DependenciesTable_Manifest_Column_Name = "manifest"sv; + static constexpr std::string_view s_DependenciesTable_MinVersion_Column_Name = "min_version"sv; + static constexpr std::string_view s_DependenciesTable_PackageId_Column_Name = "package_id"; + + namespace + { + struct DependencyTableRow + { + SQLite::rowid_t m_packageRowId; + SQLite::rowid_t m_manifestRowId; + + // Ideally this should be version row id, the version string is more needed than row id, + // this prevents converting back and forth between version row id and version string. + std::optional m_version; + + bool operator <(const DependencyTableRow& rhs) const + { + auto lhsVersion = m_version.has_value() ? m_version.value() : ""; + auto rhsVersion = rhs.m_version.has_value() ? rhs.m_version.value() : ""; + return std::tie(m_packageRowId, m_manifestRowId, lhsVersion) < std::tie(rhs.m_packageRowId, rhs.m_manifestRowId, rhsVersion); + } + }; + + void ThrowOnMissingPackageNodes(std::vector& missingPackageNodes) + { + if (!missingPackageNodes.empty()) + { + std::string missingPackages{ missingPackageNodes.begin()->Id }; + std::for_each( + missingPackageNodes.begin() + 1, + missingPackageNodes.end(), + [&](auto& dep) { missingPackages.append(", " + dep.Id); }); + THROW_HR_MSG(APPINSTALLER_CLI_ERROR_MISSING_PACKAGE, "Missing packages: %hs", missingPackages.c_str()); + } + } + + std::set GetAndLinkDependencies( + SQLite::Connection& connection, + const Manifest::Manifest& manifest, + SQLite::rowid_t manifestRowId, + Manifest::DependencyType dependencyType) + { + std::set dependencies; + std::vector missingPackageNodes; + + for (const auto& installer : manifest.Installers) + { + installer.Dependencies.ApplyToType(dependencyType, [&](Manifest::Dependency dependency) + { + auto packageRowId = IdTable::SelectIdByValue(connection, dependency.Id); + std::optional version; + + if (!packageRowId.has_value()) + { + missingPackageNodes.emplace_back(dependency); + return; + } + + if (dependency.MinVersion.has_value()) + { + version = dependency.MinVersion.value().ToString(); + } + + dependencies.emplace(DependencyTableRow{ packageRowId.value(), manifestRowId, version }); + }); + } + + ThrowOnMissingPackageNodes(missingPackageNodes); + + return dependencies; + } + + bool RemoveDependenciesByRowIds(SQLite::Connection& connection, std::vector dependencyTableRows) + { + using namespace SQLite::Builder; + bool tableUpdated = false; + if (dependencyTableRows.empty()) + { + return tableUpdated; + } + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_rowid"); + + SQLite::Builder::StatementBuilder builder; + builder + .DeleteFrom(s_DependenciesTable_Table_Name) + .Where(s_DependenciesTable_PackageId_Column_Name).Equals(Unbound) + .And(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound); + + + SQLite::Statement deleteStmt = builder.Prepare(connection); + for (auto row : dependencyTableRows) + { + deleteStmt.Reset(); + deleteStmt.Bind(1, row.m_packageRowId); + deleteStmt.Bind(2, row.m_manifestRowId); + deleteStmt.Execute(true); + tableUpdated = true; + } + + savepoint.Commit(); + return tableUpdated; + } + + bool InsertManifestDependencies( + SQLite::Connection& connection, + std::set& dependenciesTableRows) + { + using namespace SQLite::Builder; + using namespace Schema::V1_0; + bool tableUpdated = false; + + StatementBuilder insertBuilder; + insertBuilder.InsertInto(s_DependenciesTable_Table_Name) + .Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_MinVersion_Column_Name, s_DependenciesTable_PackageId_Column_Name }) + .Values(Unbound, Unbound, Unbound); + SQLite::Statement insert = insertBuilder.Prepare(connection); + + + for (const auto& dep : dependenciesTableRows) + { + insert.Reset(); + insert.Bind(1, dep.m_manifestRowId); + + if (dep.m_version.has_value()) + { + insert.Bind(2, VersionTable::EnsureExists(connection, dep.m_version.value())); + } + else + { + insert.Bind(2, nullptr); + } + + insert.Bind(3, dep.m_packageRowId); + + insert.Execute(true); + tableUpdated = true; + } + + return tableUpdated; + } + } + + bool DependenciesTable::Exists(const SQLite::Connection& connection) + { + using namespace SQLite; + + Builder::StatementBuilder builder; + builder.Select(Builder::RowCount).From(Builder::Schema::MainTable). + Where(Builder::Schema::TypeColumn).Equals(Builder::Schema::Type_Table).And(Builder::Schema::NameColumn).Equals(s_DependenciesTable_Table_Name); + + Statement statement = builder.Prepare(connection); + THROW_HR_IF(E_UNEXPECTED, !statement.Step()); + return statement.GetColumn(0) != 0; + } + + std::string_view DependenciesTable::TableName() + { + return s_DependenciesTable_Table_Name; + } + + void DependenciesTable::Create(SQLite::Connection& connection) + { + using namespace SQLite::Builder; + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createDependencyTable_v1_4"); + constexpr std::string_view dependencyIndexByVersionId = "dependencies_version_id_index"; + constexpr std::string_view dependencyIndexByPackageId = "dependencies_package_id_index"; + + StatementBuilder createTableBuilder; + createTableBuilder.CreateTable(TableName()).BeginColumns(); + createTableBuilder.Column(IntegerPrimaryKey()); + + std::array notNullableDependenciesColumns + { + DependenciesTableColumnInfo{ s_DependenciesTable_Manifest_Column_Name }, + DependenciesTableColumnInfo{ s_DependenciesTable_PackageId_Column_Name } + }; + + std::array nullableDependenciesColumns + { + DependenciesTableColumnInfo{ s_DependenciesTable_MinVersion_Column_Name } + }; + + // Add dependencies column tables not null columns. + for (const DependenciesTableColumnInfo& value : notNullableDependenciesColumns) + { + createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId).NotNull()); + } + + // Add dependencies column tables null columns. + for (const DependenciesTableColumnInfo& value : nullableDependenciesColumns) + { + createTableBuilder.Column(ColumnBuilder(value.Name, Type::RowId)); + } + + createTableBuilder.EndColumns(); + + createTableBuilder.Execute(connection); + + // Primary key index by package rowid and manifest rowid. + StatementBuilder createPKIndexBuilder; + createPKIndexBuilder.CreateUniqueIndex(s_DependenciesTable_Index_Name).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_Manifest_Column_Name, s_DependenciesTable_PackageId_Column_Name }); + createPKIndexBuilder.Execute(connection); + + // Index of dependency by Manifest id. + StatementBuilder createIndexByManifestIdBuilder; + createIndexByManifestIdBuilder.CreateIndex(dependencyIndexByVersionId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_MinVersion_Column_Name }); + createIndexByManifestIdBuilder.Execute(connection); + + // Index of dependency by package id. + StatementBuilder createIndexByPackageIdBuilder; + createIndexByPackageIdBuilder.CreateIndex(dependencyIndexByPackageId).On(s_DependenciesTable_Table_Name).Columns({ s_DependenciesTable_PackageId_Column_Name }); + createIndexByPackageIdBuilder.Execute(connection); + + savepoint.Commit(); + } + + void DependenciesTable::AddDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId) + { + if (!Exists(connection)) + { + return; + } + + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "add_dependencies_v1_4"); + + auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package); + if (!dependencies.size()) + { + return; + } + + InsertManifestDependencies(connection, dependencies); + + savepoint.Commit(); + } + + bool DependenciesTable::UpdateDependencies(SQLite::Connection& connection, const Manifest::Manifest& manifest, SQLite::rowid_t manifestRowId) + { + if (!Exists(connection)) + { + return false; + } + + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "update_dependencies_v1_4"); + + const auto dependencies = GetAndLinkDependencies(connection, manifest, manifestRowId, Manifest::DependencyType::Package); + auto existingDependencies = GetDependenciesByManifestRowId(connection, manifestRowId); + + // Get dependencies to add. + std::set toAddDependencies; + std::copy_if( + dependencies.begin(), + dependencies.end(), + std::inserter(toAddDependencies, toAddDependencies.begin()), + [&](DependencyTableRow dep) + { + Utility::NormalizedString version = dep.m_version.has_value() ? dep.m_version.value() : ""; + return existingDependencies.find(std::make_pair(dep.m_packageRowId, version)) == existingDependencies.end(); + } + ); + + // Get dependencies to remove. + std::vector toRemoveDependencies; + std::for_each( + existingDependencies.begin(), + existingDependencies.end(), + [&](std::pair row) + { + if (dependencies.find(DependencyTableRow{row.first, manifestRowId, row.second}) == dependencies.end()) + { + toRemoveDependencies.emplace_back(DependencyTableRow{ row.first, manifestRowId }); + } + } + ); + + bool tableUpdated = InsertManifestDependencies(connection, toAddDependencies); + tableUpdated = RemoveDependenciesByRowIds(connection, toRemoveDependencies) || tableUpdated; + savepoint.Commit(); + + return tableUpdated; + } + + void DependenciesTable::RemoveDependencies(SQLite::Connection& connection, SQLite::rowid_t manifestRowId) + { + if (!Exists(connection)) + { + return; + } + + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, std::string{ s_DependenciesTable_Table_Name } + "remove_dependencies_by_manifest_v1_4"); + + SQLite::Builder::StatementBuilder builder; + builder.DeleteFrom(s_DependenciesTable_Table_Name).Where(s_DependenciesTable_Manifest_Column_Name).Equals(manifestRowId); + + builder.Execute(connection); + savepoint.Commit(); + } + + std::vector> DependenciesTable::GetDependentsById(const SQLite::Connection& connection, Manifest::string_t packageId) + { + constexpr std::string_view depTableAlias = "dep"; + constexpr std::string_view minVersionAlias = "minV"; + constexpr std::string_view packageIdAlias = "pId"; + + + StatementBuilder builder; + // Find all manifest that depend on this package. + // SELECT [dep].[manifest], [pId].[id], [minV].[version] FROM [dependencies] AS [dep] + // JOIN [versions] AS [minV] ON [dep].[min_version] = [minV].[rowid] + // JOIN [ids] AS [pId] ON [pId].[rowid] = [dep].[package_id] + // WHERE [pId].[id] = ? + builder.Select() + .Column(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name)) + .Column(QCol(packageIdAlias, IdTable::ValueName())) + .Column(QCol(minVersionAlias, VersionTable::ValueName())) + .From({ s_DependenciesTable_Table_Name }).As(depTableAlias) + .Join({ VersionTable::TableName() }).As(minVersionAlias) + .On(QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name), QCol(minVersionAlias, SQLite::RowIDName)) + .Join({ IdTable::TableName() }).As(packageIdAlias) + .On(QCol(packageIdAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name)) + .Where(QCol(packageIdAlias, IdTable::ValueName())).Equals(Unbound); + + SQLite::Statement stmt = builder.Prepare(connection); + stmt.Bind(1, std::string{ packageId }); + + std::vector> resultSet; + + while (stmt.Step()) + { + resultSet.emplace_back( + std::make_pair(stmt.GetColumn(0), Utility::NormalizedString(stmt.GetColumn(2)))); + } + + return resultSet; + } + + std::set> DependenciesTable::GetDependenciesByManifestRowId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId) + { + SQLite::Builder::StatementBuilder builder; + + constexpr std::string_view depTableAlias = "dep"; + constexpr std::string_view minVersionAlias = "minV"; + + std::set> resultSet; + + // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep] + // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version] + // WHERE [dep].[manifest] = ? + builder.Select() + .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name)) + .Column(QCol(minVersionAlias, VersionTable::ValueName())) + .From({ s_DependenciesTable_Table_Name }).As(depTableAlias) + .Join({ VersionTable::TableName() }).As(minVersionAlias) + .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name)) + .Where(QCol(depTableAlias, s_DependenciesTable_Manifest_Column_Name)).Equals(Unbound); + + SQLite::Statement select = builder.Prepare(connection); + + select.Bind(1, manifestRowId); + while (select.Step()) + { + Utility::NormalizedString version = ""; + if (!select.GetColumnIsNull(1)) + { + version = select.GetColumn(1); + } + resultSet.emplace(std::make_pair(select.GetColumn(0), version)); + } + + return resultSet; + } + + void DependenciesTable::PrepareForPackaging(SQLite::Connection& connection) + { + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "prepareForPacking_V1_4"); + + StatementBuilder dropIndexBuilder; + dropIndexBuilder.DropIndex({ s_DependenciesTable_Index_Name }); + dropIndexBuilder.Execute(connection); + + StatementBuilder dropTableBuilder; + dropTableBuilder.DropTable({ s_DependenciesTable_Table_Name }); + dropTableBuilder.Execute(connection); + + savepoint.Commit(); + } + + bool DependenciesTable::CheckConsistency(const SQLite::Connection& connection, bool log) + { + StatementBuilder builder; + + if (!Exists(connection)) + { + return true; + } + + builder.Select(QCol(s_DependenciesTable_Table_Name, SQLite::RowIDName)) + .From(s_DependenciesTable_Table_Name) + .LeftOuterJoin(IdTable::TableName()) + .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_PackageId_Column_Name), QCol(IdTable::TableName(), SQLite::RowIDName)) + .LeftOuterJoin(ManifestTable::TableName()) + .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_Manifest_Column_Name), QCol(ManifestTable::TableName(), SQLite::RowIDName)) + .LeftOuterJoin(VersionTable::TableName()) + .On(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name), QCol(VersionTable::TableName(), SQLite::RowIDName)) + .Where(QCol(ManifestTable::TableName(), SQLite::RowIDName)).IsNull() + .Or(QCol(VersionTable::TableName(), SQLite::RowIDName)).IsNull().And(QCol(s_DependenciesTable_Table_Name, s_DependenciesTable_MinVersion_Column_Name)).IsNotNull() + .Or(QCol(IdTable::TableName(), SQLite::RowIDName)).IsNull(); + + SQLite::Statement select = builder.Prepare(connection); + + bool result = true; + + while (select.Step()) + { + result = false; + + if (!log) + { + break; + } + + AICLI_LOG(Repo, Info, << " [INVALID] row in [" << s_DependenciesTable_Table_Name << "] rowid [" << select.GetColumn(0) << "]"); + } + + return result; + } + + bool DependenciesTable::IsValueReferenced(const SQLite::Connection& connection, std::string_view tableName, SQLite::rowid_t valueRowId) + { + if (!Exists(connection)) + { + return false; + } + + StatementBuilder builder; + + if (tableName != V1_0::VersionTable::TableName()) + { + return false; + } + + std::array columns = { s_DependenciesTable_MinVersion_Column_Name }; + bool referenced = false; + + for(auto column: columns) + { + builder.Select(SQLite::RowIDName).From(s_DependenciesTable_Table_Name).Where(column).Equals(Unbound).Limit(1); + + SQLite::Statement select = builder.Prepare(connection); + + select.Bind(1, valueRowId); + if (select.Step()) + { + referenced = true; + break; + } + } + + return referenced; + } + + std::vector DependenciesTable::GetDependenciesMinVersionsRowIdByManifestId(const SQLite::Connection& connection, SQLite::rowid_t manifestRowId) + { + if (!Exists(connection)) + { + return {}; + } + + StatementBuilder builder; + + std::vector result; + + // Find all versions for manifest row. + // SELECT [min_version] FROM [dependencies] + // WHERE [manifest] = ? + builder.Select() + .Column(s_DependenciesTable_MinVersion_Column_Name) + .From({ s_DependenciesTable_Table_Name }) + .Where(s_DependenciesTable_Manifest_Column_Name).Equals(Unbound); + + auto select = builder.Prepare(connection); + + select.Bind(1, manifestRowId); + + while (select.Step()) + { + if (!select.GetColumnIsNull(0)) + { + result.emplace_back(select.GetColumn(0)); + } + } + + return result; + } + std::vector> DependenciesTable::GetAllDependenciesWithMinVersions(const SQLite::Connection& connection) { - if (!Exists(connection)) - { - return {}; - } - + if (!Exists(connection)) + { + return {}; + } + std::vector> result; - constexpr std::string_view depTableAlias = "dep"; - constexpr std::string_view minVersionAlias = "minV"; - - StatementBuilder builder; - - // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep] - // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version] - builder.Select() - .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name)) - .Column(QCol(minVersionAlias, VersionTable::ValueName())) - .From({ s_DependenciesTable_Table_Name }).As(depTableAlias) - .Join({ VersionTable::TableName() }).As(minVersionAlias) - .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name)); - - SQLite::Statement select = builder.Prepare(connection); - - while (select.Step()) - { - Utility::NormalizedString version = ""; - if (!select.GetColumnIsNull(1)) - { - version = select.GetColumn(1); - } - - if (!version.empty()) - { - result.emplace_back(std::make_pair(select.GetColumn(0), version)); - } - } - + constexpr std::string_view depTableAlias = "dep"; + constexpr std::string_view minVersionAlias = "minV"; + + StatementBuilder builder; + + // SELECT [dep].[package_id], [minV].[version] FROM [dependencies] AS [dep] + // JOIN [versions] AS [minV] ON [minV].[rowid] = [dep].[min_version] + builder.Select() + .Column(QCol(depTableAlias, s_DependenciesTable_PackageId_Column_Name)) + .Column(QCol(minVersionAlias, VersionTable::ValueName())) + .From({ s_DependenciesTable_Table_Name }).As(depTableAlias) + .Join({ VersionTable::TableName() }).As(minVersionAlias) + .On(QCol(minVersionAlias, SQLite::RowIDName), QCol(depTableAlias, s_DependenciesTable_MinVersion_Column_Name)); + + SQLite::Statement select = builder.Prepare(connection); + + while (select.Step()) + { + Utility::NormalizedString version = ""; + if (!select.GetColumnIsNull(1)) + { + version = select.GetColumn(1); + } + + if (!version.empty()) + { + result.emplace_back(std::make_pair(select.GetColumn(0), version)); + } + } + return result; - } + } } \ No newline at end of file diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h index 703b58ccc1..5a37d6a4d9 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/DependenciesTable.h @@ -41,7 +41,7 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_4 static std::vector> GetDependentsById(const SQLite::Connection& connection, AppInstaller::Manifest::string_t packageId); // Check dependencies table consistency. - static bool DependenciesTableCheckConsistency(const SQLite::Connection& connection, bool log); + static bool CheckConsistency(const SQLite::Connection& connection, bool log); // Checks if the row id is present in the column denoted by the value supplied. static bool IsValueReferenced(const SQLite::Connection& connection, std::string_view valueName, SQLite::rowid_t valueRowId); diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp index 9c3bf3a686..439c10e861 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_4/Interface_1_4.cpp @@ -117,7 +117,7 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_4 // If the v1.3 index was consistent, or if full logging of inconsistency was requested, check the v1.4 data. if (result || log) { - result = DependenciesTable::DependenciesTableCheckConsistency(connection, log) && result; + result = DependenciesTable::CheckConsistency(connection, log) && result; } if (result || log) diff --git a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp index 38844df57c..a3d6703265 100644 --- a/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp +++ b/src/AppInstallerRepositoryCore/Microsoft/Schema/1_5/Interface_1_5.cpp @@ -1,202 +1,202 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. -#include "pch.h" -#include "Microsoft/Schema/1_5/Interface.h" -#include "Microsoft/Schema/1_5/ArpVersionVirtualTable.h" -#include "Microsoft/Schema/1_0/ManifestTable.h" -#include "Microsoft/Schema/1_0/VersionTable.h" - -namespace AppInstaller::Repository::Microsoft::Schema::V1_5 -{ - Interface::Interface(Utility::NormalizationVersion normVersion) : V1_4::Interface(normVersion) - { - } - - Schema::Version Interface::GetVersion() const - { - return { 1, 5 }; - } - - void Interface::CreateTables(SQLite::Connection& connection, CreateOptions options) - { - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createtables_v1_5"); - - V1_4::Interface::CreateTables(connection, options); - - V1_0::ManifestTable::AddColumn(connection, { ArpMinVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::Int64}); - V1_0::ManifestTable::AddColumn(connection, { ArpMaxVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::Int64 }); - - savepoint.Commit(); - } - - SQLite::rowid_t Interface::AddManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath) - { - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "addmanifest_v1_5"); - - SQLite::rowid_t manifestId = V1_4::Interface::AddManifest(connection, manifest, relativePath); - - auto arpVersionRange = manifest.GetArpVersionRange(); - Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString(); - Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString(); - SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion); - SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion); - V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId); - V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId); - - savepoint.Commit(); - - return manifestId; - } - - std::pair Interface::UpdateManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath) - { - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "updatemanifest_v1_5"); - - auto [indexModified, manifestId] = V1_4::Interface::UpdateManifest(connection, manifest, relativePath); - - auto [oldMinVersionId, oldMaxVersionId] = - V1_0::ManifestTable::GetIdsById(connection, manifestId); - - auto arpVersionRange = manifest.GetArpVersionRange(); - Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString(); - Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString(); - - SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion); - SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion); - - // For cleaning up the old entries after update if applicable - bool cleanOldMinVersionId = false; - bool cleanOldMaxVersionId = false; - - if (arpMinVersionId != oldMinVersionId) - { - V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId); - cleanOldMinVersionId = true; - indexModified = true; - } - - if (arpMaxVersionId != oldMaxVersionId) - { - V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId); - cleanOldMaxVersionId = true; - indexModified = true; - } - - if (cleanOldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMinVersionId)) - { - V1_0::VersionTable::DeleteById(connection, oldMinVersionId); - } - - if (cleanOldMaxVersionId && oldMaxVersionId != oldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMaxVersionId)) - { - V1_0::VersionTable::DeleteById(connection, oldMaxVersionId); - } - - savepoint.Commit(); - - return { indexModified, manifestId }; - } - - void Interface::RemoveManifestById(SQLite::Connection& connection, SQLite::rowid_t manifestId) - { - // Get the old arp version ids of the values from the manifest table - auto [arpMinVersionId, arpMaxVersionId] = - V1_0::ManifestTable::GetIdsById(connection, manifestId); - - SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "RemoveManifestById_v1_5"); - - // Removes the manifest. - V1_4::Interface::RemoveManifestById(connection, manifestId); - - // Remove the versions that are not needed. - if (NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMinVersionId)) - { - V1_0::VersionTable::DeleteById(connection, arpMinVersionId); - } - - if (arpMaxVersionId != arpMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMaxVersionId)) - { - V1_0::VersionTable::DeleteById(connection, arpMaxVersionId); - } - - savepoint.Commit(); - } - - bool Interface::NotNeeded(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, SQLite::rowid_t id) const - { - bool result = V1_4::Interface::NotNeeded(connection, tableName, valueName, id); - - if (result && tableName == V1_0::VersionTable::TableName()) - { - if (valueName != V1_0::VersionTable::ValueName()) - { - result = !V1_0::ManifestTable::IsValueReferenced(connection, V1_0::VersionTable::ValueName(), id) && result; - } - if (valueName != ArpMinVersionVirtualTable::ManifestColumnName()) - { - result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMinVersionVirtualTable::ManifestColumnName(), id) && result; - } - if (valueName != ArpMaxVersionVirtualTable::ManifestColumnName()) - { - result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMaxVersionVirtualTable::ManifestColumnName(), id) && result; - } - } - - return result; - } - - bool Interface::CheckConsistency(const SQLite::Connection& connection, bool log) const - { - bool result = V1_4::Interface::CheckConsistency(connection, log); - - // If the v1.4 index was consistent, or if full logging of inconsistency was requested, check the v1.5 data. - if (result || log) - { - result = V1_0::ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = V1_0::ManifestTable::CheckConsistency(connection, log) && result; - } - - if (result || log) - { - result = ValidateArpVersionConsistency(connection, log) && result; - } - - return result; - } - - std::optional Interface::GetPropertyByManifestIdInternal(const SQLite::Connection& connection, SQLite::rowid_t manifestId, PackageVersionProperty property) const - { - switch (property) - { - case AppInstaller::Repository::PackageVersionProperty::ArpMinVersion: - return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId)); - case AppInstaller::Repository::PackageVersionProperty::ArpMaxVersion: - return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId)); - default: - return V1_4::Interface::GetPropertyByManifestIdInternal(connection, manifestId, property); - } - } - +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. +#include "pch.h" +#include "Microsoft/Schema/1_5/Interface.h" +#include "Microsoft/Schema/1_5/ArpVersionVirtualTable.h" +#include "Microsoft/Schema/1_0/ManifestTable.h" +#include "Microsoft/Schema/1_0/VersionTable.h" + +namespace AppInstaller::Repository::Microsoft::Schema::V1_5 +{ + Interface::Interface(Utility::NormalizationVersion normVersion) : V1_4::Interface(normVersion) + { + } + + Schema::Version Interface::GetVersion() const + { + return { 1, 5 }; + } + + void Interface::CreateTables(SQLite::Connection& connection, CreateOptions options) + { + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "createtables_v1_5"); + + V1_4::Interface::CreateTables(connection, options); + + V1_0::ManifestTable::AddColumn(connection, { ArpMinVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::RowId }); + V1_0::ManifestTable::AddColumn(connection, { ArpMaxVersionVirtualTable::ManifestColumnName(), SQLite::Builder::Type::RowId }); + + savepoint.Commit(); + } + + SQLite::rowid_t Interface::AddManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath) + { + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "addmanifest_v1_5"); + + SQLite::rowid_t manifestId = V1_4::Interface::AddManifest(connection, manifest, relativePath); + + auto arpVersionRange = manifest.GetArpVersionRange(); + Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString(); + Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString(); + SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion); + SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion); + V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId); + V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId); + + savepoint.Commit(); + + return manifestId; + } + + std::pair Interface::UpdateManifest(SQLite::Connection& connection, const Manifest::Manifest& manifest, const std::optional& relativePath) + { + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "updatemanifest_v1_5"); + + auto [indexModified, manifestId] = V1_4::Interface::UpdateManifest(connection, manifest, relativePath); + + auto [oldMinVersionId, oldMaxVersionId] = + V1_0::ManifestTable::GetIdsById(connection, manifestId); + + auto arpVersionRange = manifest.GetArpVersionRange(); + Manifest::string_t arpMinVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMinVersion().ToString(); + Manifest::string_t arpMaxVersion = arpVersionRange.IsEmpty() ? "" : arpVersionRange.GetMaxVersion().ToString(); + + SQLite::rowid_t arpMinVersionId = V1_0::VersionTable::EnsureExists(connection, arpMinVersion); + SQLite::rowid_t arpMaxVersionId = V1_0::VersionTable::EnsureExists(connection, arpMaxVersion); + + // For cleaning up the old entries after update if applicable + bool cleanOldMinVersionId = false; + bool cleanOldMaxVersionId = false; + + if (arpMinVersionId != oldMinVersionId) + { + V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMinVersionId); + cleanOldMinVersionId = true; + indexModified = true; + } + + if (arpMaxVersionId != oldMaxVersionId) + { + V1_0::ManifestTable::UpdateValueIdById(connection, manifestId, arpMaxVersionId); + cleanOldMaxVersionId = true; + indexModified = true; + } + + if (cleanOldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMinVersionId)) + { + V1_0::VersionTable::DeleteById(connection, oldMinVersionId); + } + + if (cleanOldMaxVersionId && oldMaxVersionId != oldMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), oldMaxVersionId)) + { + V1_0::VersionTable::DeleteById(connection, oldMaxVersionId); + } + + savepoint.Commit(); + + return { indexModified, manifestId }; + } + + void Interface::RemoveManifestById(SQLite::Connection& connection, SQLite::rowid_t manifestId) + { + // Get the old arp version ids of the values from the manifest table + auto [arpMinVersionId, arpMaxVersionId] = + V1_0::ManifestTable::GetIdsById(connection, manifestId); + + SQLite::Savepoint savepoint = SQLite::Savepoint::Create(connection, "RemoveManifestById_v1_5"); + + // Removes the manifest. + V1_4::Interface::RemoveManifestById(connection, manifestId); + + // Remove the versions that are not needed. + if (NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMinVersionId)) + { + V1_0::VersionTable::DeleteById(connection, arpMinVersionId); + } + + if (arpMaxVersionId != arpMinVersionId && NotNeeded(connection, V1_0::VersionTable::TableName(), V1_0::VersionTable::ValueName(), arpMaxVersionId)) + { + V1_0::VersionTable::DeleteById(connection, arpMaxVersionId); + } + + savepoint.Commit(); + } + + bool Interface::NotNeeded(const SQLite::Connection& connection, std::string_view tableName, std::string_view valueName, SQLite::rowid_t id) const + { + bool result = V1_4::Interface::NotNeeded(connection, tableName, valueName, id); + + if (result && tableName == V1_0::VersionTable::TableName()) + { + if (valueName != V1_0::VersionTable::ValueName()) + { + result = !V1_0::ManifestTable::IsValueReferenced(connection, V1_0::VersionTable::ValueName(), id) && result; + } + if (valueName != ArpMinVersionVirtualTable::ManifestColumnName()) + { + result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMinVersionVirtualTable::ManifestColumnName(), id) && result; + } + if (valueName != ArpMaxVersionVirtualTable::ManifestColumnName()) + { + result = !V1_0::ManifestTable::IsValueReferenced(connection, ArpMaxVersionVirtualTable::ManifestColumnName(), id) && result; + } + } + + return result; + } + + bool Interface::CheckConsistency(const SQLite::Connection& connection, bool log) const + { + bool result = V1_4::Interface::CheckConsistency(connection, log); + + // If the v1.4 index was consistent, or if full logging of inconsistency was requested, check the v1.5 data. + if (result || log) + { + result = V1_0::ManifestTable::CheckConsistency(connection, log) && result; + } + + if (result || log) + { + result = V1_0::ManifestTable::CheckConsistency(connection, log) && result; + } + + if (result || log) + { + result = ValidateArpVersionConsistency(connection, log) && result; + } + + return result; + } + + std::optional Interface::GetPropertyByManifestIdInternal(const SQLite::Connection& connection, SQLite::rowid_t manifestId, PackageVersionProperty property) const + { + switch (property) + { + case AppInstaller::Repository::PackageVersionProperty::ArpMinVersion: + return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId)); + case AppInstaller::Repository::PackageVersionProperty::ArpMaxVersion: + return std::get<0>(V1_0::ManifestTable::GetValuesById(connection, manifestId)); + default: + return V1_4::Interface::GetPropertyByManifestIdInternal(connection, manifestId, property); + } + } + bool Interface::ValidateArpVersionConsistency(const SQLite::Connection& connection, bool log) const { - try - { - bool result = true; - - // Search everything - SearchRequest request; - auto searchResult = Search(connection, request); - for (auto const& match : searchResult.Matches) - { - // Get arp version ranges for each package to check - std::vector ranges; - auto versionKeys = GetVersionKeysById(connection, match.first); - for (auto const& versionKey : versionKeys) - { + try + { + bool result = true; + + // Search everything + SearchRequest request; + auto searchResult = Search(connection, request); + for (auto const& match : searchResult.Matches) + { + // Get arp version ranges for each package to check + std::vector ranges; + auto versionKeys = GetVersionKeysById(connection, match.first); + for (auto const& versionKey : versionKeys) + { auto manifestRowId = GetManifestIdByKey(connection, match.first, versionKey.GetVersion().ToString(), versionKey.GetChannel().ToString()); if (manifestRowId) { @@ -210,28 +210,28 @@ namespace AppInstaller::Repository::Microsoft::Schema::V1_5 { ranges.emplace_back(Utility::VersionRange{ Utility::Version{ std::move(arpMinVersion) }, Utility::Version{ std::move(arpMaxVersion) } }); } - } - } - - // Check overlap - if (Utility::HasOverlapInVersionRanges(ranges)) - { - AICLI_LOG(Repo, Error, << "Overlapped Arp version ranges found for package. PackageRowId: " << match.first); - result = false; - - if (!log) - { - break; - } - } - } - - return result; - } - catch (...) - { - AICLI_LOG(Repo, Error, << "ValidateArpVersionConsistency() encountered internal error. Returning false."); - return false; + } + } + + // Check overlap + if (Utility::HasOverlapInVersionRanges(ranges)) + { + AICLI_LOG(Repo, Error, << "Overlapped Arp version ranges found for package. PackageRowId: " << match.first); + result = false; + + if (!log) + { + break; + } + } + } + + return result; + } + catch (...) + { + AICLI_LOG(Repo, Error, << "ValidateArpVersionConsistency() encountered internal error. Returning false."); + return false; } - } + } } \ No newline at end of file diff --git a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp index 13e1a0b70d..e431375fb5 100644 --- a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp +++ b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.cpp @@ -307,6 +307,20 @@ namespace AppInstaller::Repository::SQLite::Builder return *this; } + StatementBuilder& StatementBuilder::WhereValueContainsEmbeddedNullCharacter(std::string_view column) + { + OutputColumns(m_stream, " WHERE instr(", column); + m_stream << ",char(0))>0"; + return *this; + } + + StatementBuilder& StatementBuilder::WhereValueContainsEmbeddedNullCharacter(const QualifiedColumn& column) + { + OutputColumns(m_stream, " WHERE instr(", column); + m_stream << ",char(0))>0"; + return *this; + } + StatementBuilder& StatementBuilder::Equals(details::unbound_t) { AppendOpAndBinder(Op::Equals); diff --git a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h index e82a8f4ea9..3bcb632d6e 100644 --- a/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h +++ b/src/AppInstallerRepositoryCore/SQLiteStatementBuilder.h @@ -217,6 +217,11 @@ namespace AppInstaller::Repository::SQLite::Builder StatementBuilder& Where(std::string_view column); StatementBuilder& Where(const QualifiedColumn& column); + // A full filter clause looking for an embedded null character. + // Is extremely specific to consistency checks, and so a more detailed construct is not required. + StatementBuilder& WhereValueContainsEmbeddedNullCharacter(std::string_view column); + StatementBuilder& WhereValueContainsEmbeddedNullCharacter(const QualifiedColumn& column); + // Indicate the operation of the filter clause. template StatementBuilder& Equals(const ValueType& value) diff --git a/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp b/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp index 5e72b13f6d..a19d720734 100644 --- a/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp +++ b/src/AppInstallerRepositoryCore/SQLiteWrapper.cpp @@ -50,8 +50,14 @@ namespace AppInstaller::Repository::SQLite THROW_IF_SQLITE_FAILED(sqlite3_bind_null(stmt, index)); } + void ThrowIfContainsEmbeddedNullCharacter(std::string_view v) + { + THROW_HR_IF(APPINSTALLER_CLI_ERROR_BIND_WITH_EMBEDDED_NULL, v.find('\0') != std::string_view::npos); + } + void ParameterSpecificsImpl::Bind(sqlite3_stmt* stmt, int index, const std::string& v) { + ThrowIfContainsEmbeddedNullCharacter(v); THROW_IF_SQLITE_FAILED(sqlite3_bind_text64(stmt, index, v.c_str(), v.size(), SQLITE_TRANSIENT, SQLITE_UTF8)); } @@ -70,6 +76,7 @@ namespace AppInstaller::Repository::SQLite } else { + ThrowIfContainsEmbeddedNullCharacter(v); THROW_IF_SQLITE_FAILED(sqlite3_bind_text64(stmt, index, v.data(), v.size(), SQLITE_TRANSIENT, SQLITE_UTF8)); } }