diff --git a/tests/test_updater_ng.py b/tests/test_updater_ng.py index eec10d73cb..8c0d48eb9a 100644 --- a/tests/test_updater_ng.py +++ b/tests/test_updater_ng.py @@ -16,6 +16,7 @@ from tests import utils from tuf import ngclient +from tuf.ngclient.updater import _is_target_in_pathpattern logger = logging.getLogger(__name__) @@ -148,6 +149,29 @@ def test_refresh_with_only_local_root(self): # Get targetinfo for 'file3.txt' listed in the delegated role1 targetinfo3= self.repository_updater.get_one_valid_targetinfo('file3.txt') + + def test_is_target_in_pathpattern(self): + supported_use_cases = [ + ("foo.tgz", "foo.tgz"), + ("foo.tgz", "*"), + ("targets/foo.tgz", "*"), + ("foo.tgz", "*.tgz"), + ("foo-version-a.tgz", "foo-version-?.tgz"), + ("targets/foo.tgz", "targets/*.tgz"), + ("foo/bar/zoo/k.tgz", "foo/bar/zoo/*"), + ("foo/bar/zoo/k.tgz", "foo/bar/*") + ] + for targetname, pathpattern in supported_use_cases: + self.assertTrue(_is_target_in_pathpattern(targetname, pathpattern)) + + invalid_use_cases = [ + ("targets/foo.tgz", "*.tgz"), + ("*.tgz", "/foo.tgz"), + ("foo-version-?.tgz", "foo-version-alpha.tgz") + ] + for pathpattern, targetname in invalid_use_cases: + self.assertFalse(_is_target_in_pathpattern(pathpattern, targetname)) + if __name__ == '__main__': utils.configure_test_logging(sys.argv) unittest.main() diff --git a/tuf/ngclient/updater.py b/tuf/ngclient/updater.py index 850f46b9cf..242cabbb41 100644 --- a/tuf/ngclient/updater.py +++ b/tuf/ngclient/updater.py @@ -463,6 +463,22 @@ def _preorder_depth_first_walk(self, target_filepath) -> Dict: return {"filepath": target_filepath, "fileinfo": target} +def _is_target_in_pathpattern(targetname, pathpattern): + if pathpattern == "*": + return True + pathpattern_dir = os.path.dirname(pathpattern) + target_dir = os.path.dirname(targetname) + if pathpattern_dir == target_dir: + return fnmatch.fnmatch(targetname, pathpattern) + # Check that pathpattern_dir contains the target_dir and `*` is used at the + # end of the pattern. + # For example, targetname "foo/bar/zoo/k.tgz" and pattpatern "foo/bar/*". + elif target_dir.startswith(pathpattern_dir) and pathpattern[-1] == "*": + return True + # pathpattern dir != targetname dir means target is not in pathpattern + return False + + def _visit_child_role(child_role: Dict, target_filepath: str) -> str: """ @@ -527,8 +543,8 @@ def _visit_child_role(child_role: Dict, target_filepath: str) -> str: # target without a leading path separator - make sure to strip any # leading path separators so that a match is made. # Example: "foo.tgz" should match with "/*.tgz". - if fnmatch.fnmatch( - target_filepath.lstrip(os.sep), child_role_path.lstrip(os.sep) + if _is_target_in_pathpattern( + child_role_path.lstrip(os.sep), target_filepath.lstrip(os.sep), ): logger.debug( "Child role %s is allowed to sign for %s",