From 2aa750cb4820ab551566e550d0c4e6ad697d87c8 Mon Sep 17 00:00:00 2001 From: Alexander Schepanovski Date: Mon, 18 Nov 2019 17:16:48 +0100 Subject: [PATCH] perf: switch schema validation library for config --- dvc/config.py | 230 +++++++++++++++++++------------------------------- 1 file changed, 85 insertions(+), 145 deletions(-) diff --git a/dvc/config.py b/dvc/config.py index bd569864cb..b883ef259b 100644 --- a/dvc/config.py +++ b/dvc/config.py @@ -8,12 +8,8 @@ import re import configobj -from schema import And -from schema import Optional -from schema import Regex -from schema import Schema -from schema import SchemaError -from schema import Use +from voluptuous import Schema, Required, Optional, Invalid +from voluptuous import All, Any, Lower, Range, Coerce, Match from dvc.exceptions import DvcException from dvc.exceptions import NotDvcRepoError @@ -50,82 +46,43 @@ def __init__(self, command, cause=None): def supported_cache_type(types): - """Checks if link type config option has a valid value. + """Checks if link type config option consists only of valid values. Args: types (list/string): type(s) of links that dvc should try out. """ + if types is None: + return None if isinstance(types, str): types = [typ.strip() for typ in types.split(",")] - for typ in types: - if typ not in ["reflink", "hardlink", "symlink", "copy"]: - return False - return True - - -def is_bool(val): - """Checks that value is a boolean. - - Args: - val (str): string value verify. - - Returns: - bool: True if value stands for boolean, False otherwise. - """ - return val.lower() in ["true", "false"] - - -def to_bool(val): - """Converts value to boolean. - - Args: - val (str): string to convert to boolean. - - Returns: - bool: True if value.lower() == 'true', False otherwise. - """ - return val.lower() == "true" - - -def is_whole(val): - """Checks that value is a whole integer. - - Args: - val (str): number string to verify. - - Returns: - bool: True if val is a whole number, False otherwise. - """ - return int(val) >= 0 + unsupported = set(types) - {"reflink", "hardlink", "symlink", "copy"} + if unsupported: + raise Invalid( + "Unsupported cache type(s): {}".format(", ".join(unsupported)) + ) -def is_percent(val): - """Checks that value is a percent. + return types - Args: - val (str): number string to verify. - Returns: - bool: True if 0<=value<=100, False otherwise. - """ - return int(val) >= 0 and int(val) <= 100 +# Checks that value is either true or false and converts it to bool +Bool = All( + Lower, + Any("true", "false"), + lambda v: v == "true", + msg="expected true or false", +) +to_bool = Schema(Bool) -class Choices(object): +def Choices(*choices): """Checks that value belongs to the specified set of values Args: *choices: pass allowed values as arguments, or pass a list or tuple as a single argument """ - - def __init__(self, *choices): - if len(choices) == 1 and isinstance(choices[0], (list, tuple)): - choices = choices[0] - self.choices = choices - - def __call__(self, value): - return value in self.choices + return Any(*choices, msg="expected one of {}".format(",".join(choices))) class Config(object): # pylint: disable=too-many-instance-attributes @@ -158,28 +115,22 @@ class Config(object): # pylint: disable=too-many-instance-attributes LEVEL_GLOBAL = 2 LEVEL_SYSTEM = 3 - BOOL_SCHEMA = And(str, is_bool, Use(to_bool)) - SECTION_CORE = "core" SECTION_CORE_LOGLEVEL = "loglevel" - SECTION_CORE_LOGLEVEL_SCHEMA = And( - Use(str.lower), Choices("info", "debug", "warning", "error") + SECTION_CORE_LOGLEVEL_SCHEMA = All( + Lower, Choices("info", "debug", "warning", "error") ) SECTION_CORE_REMOTE = "remote" - SECTION_CORE_INTERACTIVE_SCHEMA = BOOL_SCHEMA SECTION_CORE_INTERACTIVE = "interactive" SECTION_CORE_ANALYTICS = "analytics" - SECTION_CORE_ANALYTICS_SCHEMA = BOOL_SCHEMA SECTION_CORE_CHECKSUM_JOBS = "checksum_jobs" - SECTION_CORE_CHECKSUM_JOBS_SCHEMA = And(Use(int), lambda x: x > 0) SECTION_CACHE = "cache" SECTION_CACHE_DIR = "dir" SECTION_CACHE_TYPE = "type" - SECTION_CACHE_TYPE_SCHEMA = supported_cache_type SECTION_CACHE_PROTECTED = "protected" SECTION_CACHE_SHARED = "shared" - SECTION_CACHE_SHARED_SCHEMA = And(Use(str.lower), Choices("group")) + SECTION_CACHE_SHARED_SCHEMA = All(Lower, Choices("group")) SECTION_CACHE_LOCAL = "local" SECTION_CACHE_S3 = "s3" SECTION_CACHE_GS = "gs" @@ -188,34 +139,26 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_CACHE_AZURE = "azure" SECTION_CACHE_SLOW_LINK_WARNING = "slow_link_warning" SECTION_CACHE_SCHEMA = { - Optional(SECTION_CACHE_LOCAL): str, - Optional(SECTION_CACHE_S3): str, - Optional(SECTION_CACHE_GS): str, - Optional(SECTION_CACHE_HDFS): str, - Optional(SECTION_CACHE_SSH): str, - Optional(SECTION_CACHE_AZURE): str, - Optional(SECTION_CACHE_DIR): str, - Optional(SECTION_CACHE_TYPE, default=None): SECTION_CACHE_TYPE_SCHEMA, - Optional(SECTION_CACHE_PROTECTED, default=False): BOOL_SCHEMA, - Optional(SECTION_CACHE_SHARED): SECTION_CACHE_SHARED_SCHEMA, - Optional(PRIVATE_CWD): str, - Optional(SECTION_CACHE_SLOW_LINK_WARNING, default=True): BOOL_SCHEMA, + SECTION_CACHE_LOCAL: str, + SECTION_CACHE_S3: str, + SECTION_CACHE_GS: str, + SECTION_CACHE_HDFS: str, + SECTION_CACHE_SSH: str, + SECTION_CACHE_AZURE: str, + SECTION_CACHE_DIR: str, + SECTION_CACHE_TYPE: supported_cache_type, + Optional(SECTION_CACHE_PROTECTED, default=False): Bool, + SECTION_CACHE_SHARED: SECTION_CACHE_SHARED_SCHEMA, + PRIVATE_CWD: str, + Optional(SECTION_CACHE_SLOW_LINK_WARNING, default=True): Bool, } SECTION_CORE_SCHEMA = { - Optional(SECTION_CORE_LOGLEVEL): And( - str, Use(str.lower), SECTION_CORE_LOGLEVEL_SCHEMA - ), - Optional(SECTION_CORE_REMOTE, default=""): And(str, Use(str.lower)), - Optional( - SECTION_CORE_INTERACTIVE, default=False - ): SECTION_CORE_INTERACTIVE_SCHEMA, - Optional( - SECTION_CORE_ANALYTICS, default=True - ): SECTION_CORE_ANALYTICS_SCHEMA, - Optional( - SECTION_CORE_CHECKSUM_JOBS, default=None - ): SECTION_CORE_CHECKSUM_JOBS_SCHEMA, + SECTION_CORE_LOGLEVEL: SECTION_CORE_LOGLEVEL_SCHEMA, + SECTION_CORE_REMOTE: Lower, + Optional(SECTION_CORE_INTERACTIVE, default=False): Bool, + Optional(SECTION_CORE_ANALYTICS, default=True): Bool, + SECTION_CORE_CHECKSUM_JOBS: All(Coerce(int), Range(1)), } # backward compatibility @@ -230,15 +173,15 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_AWS_SSE = "sse" SECTION_AWS_ACL = "acl" SECTION_AWS_SCHEMA = { - SECTION_AWS_STORAGEPATH: str, - Optional(SECTION_AWS_REGION): str, - Optional(SECTION_AWS_PROFILE): str, - Optional(SECTION_AWS_CREDENTIALPATH): str, - Optional(SECTION_AWS_ENDPOINT_URL): str, - Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA, - Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA, - Optional(SECTION_AWS_SSE): str, - Optional(SECTION_AWS_ACL): str, + Required(SECTION_AWS_STORAGEPATH): str, + SECTION_AWS_REGION: str, + SECTION_AWS_PROFILE: str, + SECTION_AWS_CREDENTIALPATH: str, + SECTION_AWS_ENDPOINT_URL: str, + Optional(SECTION_AWS_LIST_OBJECTS, default=False): Bool, + Optional(SECTION_AWS_USE_SSL, default=True): Bool, + SECTION_AWS_SSE: str, + SECTION_AWS_ACL: str, } # backward compatibility @@ -247,14 +190,14 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_GCP_CREDENTIALPATH = SECTION_AWS_CREDENTIALPATH SECTION_GCP_PROJECTNAME = "projectname" SECTION_GCP_SCHEMA = { - SECTION_GCP_STORAGEPATH: str, - Optional(SECTION_GCP_PROJECTNAME): str, + Required(SECTION_GCP_STORAGEPATH): str, + SECTION_GCP_PROJECTNAME: str, } # backward compatibility SECTION_LOCAL = "local" SECTION_LOCAL_STORAGEPATH = SECTION_AWS_STORAGEPATH - SECTION_LOCAL_SCHEMA = {SECTION_LOCAL_STORAGEPATH: str} + SECTION_LOCAL_SCHEMA = {Required(SECTION_LOCAL_STORAGEPATH): str} SECTION_AZURE_CONNECTION_STRING = "connection_string" # Alibabacloud oss options @@ -274,51 +217,48 @@ class Config(object): # pylint: disable=too-many-instance-attributes SECTION_REMOTE_GSS_AUTH = "gss_auth" SECTION_REMOTE_NO_TRAVERSE = "no_traverse" SECTION_REMOTE_SCHEMA = { - SECTION_REMOTE_URL: str, - Optional(SECTION_AWS_REGION): str, - Optional(SECTION_AWS_PROFILE): str, - Optional(SECTION_AWS_CREDENTIALPATH): str, - Optional(SECTION_AWS_ENDPOINT_URL): str, - Optional(SECTION_AWS_LIST_OBJECTS, default=False): BOOL_SCHEMA, - Optional(SECTION_AWS_USE_SSL, default=True): BOOL_SCHEMA, - Optional(SECTION_AWS_SSE): str, - Optional(SECTION_AWS_ACL): str, - Optional(SECTION_GCP_PROJECTNAME): str, - Optional(SECTION_CACHE_TYPE): SECTION_CACHE_TYPE_SCHEMA, - Optional(SECTION_CACHE_PROTECTED, default=False): BOOL_SCHEMA, - Optional(SECTION_REMOTE_USER): str, - Optional(SECTION_REMOTE_PORT): Use(int), - Optional(SECTION_REMOTE_KEY_FILE): str, - Optional(SECTION_REMOTE_TIMEOUT): Use(int), - Optional(SECTION_REMOTE_PASSWORD): str, - Optional(SECTION_REMOTE_ASK_PASSWORD): BOOL_SCHEMA, - Optional(SECTION_REMOTE_GSS_AUTH): BOOL_SCHEMA, - Optional(SECTION_AZURE_CONNECTION_STRING): str, - Optional(SECTION_OSS_ACCESS_KEY_ID): str, - Optional(SECTION_OSS_ACCESS_KEY_SECRET): str, - Optional(SECTION_OSS_ENDPOINT): str, - Optional(PRIVATE_CWD): str, - Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): BOOL_SCHEMA, + Required(SECTION_REMOTE_URL): str, + SECTION_AWS_REGION: str, + SECTION_AWS_PROFILE: str, + SECTION_AWS_CREDENTIALPATH: str, + SECTION_AWS_ENDPOINT_URL: str, + Optional(SECTION_AWS_LIST_OBJECTS, default=False): Bool, + Optional(SECTION_AWS_USE_SSL, default=True): Bool, + SECTION_AWS_SSE: str, + SECTION_AWS_ACL: str, + SECTION_GCP_PROJECTNAME: str, + SECTION_CACHE_TYPE: supported_cache_type, + Optional(SECTION_CACHE_PROTECTED, default=False): Bool, + SECTION_REMOTE_USER: str, + SECTION_REMOTE_PORT: Coerce(int), + SECTION_REMOTE_KEY_FILE: str, + SECTION_REMOTE_TIMEOUT: Coerce(int), + SECTION_REMOTE_PASSWORD: str, + SECTION_REMOTE_ASK_PASSWORD: Bool, + SECTION_REMOTE_GSS_AUTH: Bool, + SECTION_AZURE_CONNECTION_STRING: str, + SECTION_OSS_ACCESS_KEY_ID: str, + SECTION_OSS_ACCESS_KEY_SECRET: str, + SECTION_OSS_ENDPOINT: str, + PRIVATE_CWD: str, + Optional(SECTION_REMOTE_NO_TRAVERSE, default=True): Bool, } SECTION_STATE = "state" SECTION_STATE_ROW_LIMIT = "row_limit" SECTION_STATE_ROW_CLEANUP_QUOTA = "row_cleanup_quota" SECTION_STATE_SCHEMA = { - Optional(SECTION_STATE_ROW_LIMIT): And(Use(int), is_whole), - Optional(SECTION_STATE_ROW_CLEANUP_QUOTA): And(Use(int), is_percent), + SECTION_STATE_ROW_LIMIT: All(Coerce(int), Range(1)), + SECTION_STATE_ROW_CLEANUP_QUOTA: All(Coerce(int), Range(0, 100)), } SCHEMA = { Optional(SECTION_CORE, default={}): SECTION_CORE_SCHEMA, - Optional(Regex(SECTION_REMOTE_REGEX)): SECTION_REMOTE_SCHEMA, + Match(SECTION_REMOTE_REGEX): SECTION_REMOTE_SCHEMA, Optional(SECTION_CACHE, default={}): SECTION_CACHE_SCHEMA, Optional(SECTION_STATE, default={}): SECTION_STATE_SCHEMA, - # backward compatibility - Optional(SECTION_AWS, default={}): SECTION_AWS_SCHEMA, - Optional(SECTION_GCP, default={}): SECTION_GCP_SCHEMA, - Optional(SECTION_LOCAL, default={}): SECTION_LOCAL_SCHEMA, } + COMPILED_SCHEMA = Schema(SCHEMA) def __init__(self, dvc_dir=None, validate=True): self.dvc_dir = dvc_dir @@ -457,9 +397,9 @@ def load(self): d = self.config.dict() try: - d = Schema(self.SCHEMA).validate(d) - except SchemaError as exc: - raise ConfigError("config format error", cause=exc) + d = self.COMPILED_SCHEMA(d) + except Invalid as exc: + raise ConfigError(str(exc), cause=exc) self.config = configobj.ConfigObj(d, write_empty_values=True) def save(self, config=None):