diff --git a/CHANGELOG.md b/CHANGELOG.md index 1405a14d..56f832b1 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,7 +8,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com). ### Changed -* If no log dir is specified `logging.find_log_dir()` now falls back to `tempfile.gettempdir()` instead of `/tmp/`. +* (logging) If no log dir is specified `logging.find_log_dir()` now falls back + to `tempfile.gettempdir()` instead of `/tmp/`. + +### Fixed + +* (flags) Additional kwargs (e.g. `short_name=`) to `DEFINE_multi_enum_class` + are now correctly passed to the underlying `Flag` object. ## 1.3.0 (2022-10-11) diff --git a/absl/flags/__init__.pyi b/absl/flags/__init__.pyi index 4eee59e2..7bf6842e 100644 --- a/absl/flags/__init__.pyi +++ b/absl/flags/__init__.pyi @@ -52,6 +52,9 @@ mark_flags_as_required = _validators.mark_flags_as_required mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive +# Flag modifiers. +set_default = _defines.set_default + # Key flag related functions. declare_key_flag = _defines.declare_key_flag adopt_module_key_flags = _defines.adopt_module_key_flags diff --git a/absl/flags/_defines.py b/absl/flags/_defines.py index dce53ea2..61354e94 100644 --- a/absl/flags/_defines.py +++ b/absl/flags/_defines.py @@ -859,11 +859,17 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin """ return DEFINE_flag( _flag.MultiEnumClassFlag( - name, default, help, enum_class, case_sensitive=case_sensitive), + name, + default, + help, + enum_class, + case_sensitive=case_sensitive, + **args, + ), flag_values, module_name, required=required, - **args) + ) def DEFINE_alias( # pylint: disable=invalid-name diff --git a/absl/flags/_flag.pyi b/absl/flags/_flag.pyi index 9b4a3d3a..35066445 100644 --- a/absl/flags/_flag.pyi +++ b/absl/flags/_flag.pyi @@ -20,7 +20,7 @@ import functools from absl.flags import _argument_parser import enum -from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence +from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence _T = TypeVar('_T') _ET = TypeVar('_ET', bound=enum.Enum) @@ -44,6 +44,7 @@ class Flag(Generic[_T]): using_default_value = ... # type: bool allow_overwrite = ... # type: bool allow_using_method_names = ... # type: bool + validators = ... # type: List[Callable[[Any], bool]] def __init__(self, parser: _argument_parser.ArgumentParser[_T], diff --git a/absl/flags/_flagvalues.py b/absl/flags/_flagvalues.py index 937dc6c2..6661b783 100644 --- a/absl/flags/_flagvalues.py +++ b/absl/flags/_flagvalues.py @@ -411,7 +411,9 @@ def __setitem__(self, name, flag): """Registers a new flag variable.""" fl = self._flags() if not isinstance(flag, _flag.Flag): - raise _exceptions.IllegalFlagValueError(flag) + raise _exceptions.IllegalFlagValueError( + f'Expect Flag instances, found type {type(flag)}. ' + "Maybe you didn't mean to use FlagValue.__setitem__?") if not isinstance(name, str): raise _exceptions.Error('Flag name must be a string') if not name: diff --git a/absl/flags/tests/flags_test.py b/absl/flags/tests/flags_test.py index 77ed307e..7cacbc84 100644 --- a/absl/flags/tests/flags_test.py +++ b/absl/flags/tests/flags_test.py @@ -1591,6 +1591,17 @@ def test_bad_multi_enum_flags(self): class MultiEnumClassFlagsTest(absltest.TestCase): + def test_short_name(self): + fv = flags.FlagValues() + flags.DEFINE_multi_enum_class( + 'fruit', + None, + Fruit, + 'Enum option that can occur multiple times', + flag_values=fv, + short_name='me') + self.assertEqual(fv['fruit'].short_name, 'me') + def test_define_results_in_registered_flag_with_none(self): fv = flags.FlagValues() enum_defaults = None diff --git a/absl/logging/__init__.py b/absl/logging/__init__.py index 33276cd1..f4e79675 100644 --- a/absl/logging/__init__.py +++ b/absl/logging/__init__.py @@ -86,6 +86,7 @@ import socket import struct import sys +import tempfile import threading import tempfile import time @@ -707,22 +708,26 @@ def find_log_dir(log_dir=None): FileNotFoundError: raised in Python 3 when it cannot find a log directory. OSError: raised in Python 2 when it cannot find a log directory. """ - # Get a possible log dir. + # Get a list of possible log dirs (will try to use them in order). + # NOTE: Google's internal implementation has a special handling for Google + # machines, which uses a list of directories. Hence the following uses `dirs` + # instead of a single directory. if log_dir: # log_dir was explicitly specified as an arg, so use it and it alone. - log_dir_candidate = log_dir + dirs = [log_dir] elif FLAGS['log_dir'].value: # log_dir flag was provided, so use it and it alone (this mimics the # behavior of the same flag in logging.cc). - log_dir_candidate = FLAGS['log_dir'].value + dirs = [FLAGS['log_dir'].value] else: - log_dir_candidate = tempfile.gettempdir() + dirs = [tempfile.gettempdir()] - # Test if log dir candidate is usable. - if os.path.isdir(log_dir_candidate) and os.access(log_dir_candidate, os.W_OK): - return log_dir_candidate + # Find the first usable log dir. + for d in dirs: + if os.path.isdir(d) and os.access(d, os.W_OK): + return d raise FileNotFoundError( - "Can't find a writable directory for logs, tried %s" % log_dir_candidate) + "Can't find a writable directory for logs, tried %s" % dirs) def get_absl_log_prefix(record): diff --git a/absl/testing/absltest.py b/absl/testing/absltest.py index 9071f8f6..1bbcee74 100644 --- a/absl/testing/absltest.py +++ b/absl/testing/absltest.py @@ -533,7 +533,10 @@ def open_bytes(self, mode='rb'): # currently `Any` to avoid [bad-return-type] errors in the open_* methods. @contextlib.contextmanager def _open( - self, mode: str, encoding: str = 'utf8', errors: str = 'strict' + self, + mode: str, + encoding: Optional[str] = 'utf8', + errors: Optional[str] = 'strict', ) -> Iterator[Any]: with io.open( self.full_path, mode=mode, encoding=encoding, errors=errors) as fp: @@ -638,7 +641,7 @@ def test_foo(self): self.assertTrue(os.path.exists(expected_paths[1])) self.assertEqual('foo', out_log.read_text()) - See also: :meth:`create_tempdir` for creating temporary files. + See also: :meth:`create_tempfile` for creating temporary files. Args: name: Optional name of the directory. If not given, a unique diff --git a/absl/testing/flagsaver.py b/absl/testing/flagsaver.py index 37926d7a..774c698c 100644 --- a/absl/testing/flagsaver.py +++ b/absl/testing/flagsaver.py @@ -11,7 +11,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - """Decorator and context manager for saving and restoring flag values. There are many ways to save and restore. Always use the most convenient method @@ -61,11 +60,26 @@ def some_func(): import functools import inspect +from typing import overload, Any, Callable, Mapping, Tuple, TypeVar from absl import flags FLAGS = flags.FLAGS +# The type of pre/post wrapped functions. +_CallableT = TypeVar('_CallableT', bound=Callable) + + +@overload +def flagsaver(*args: Tuple[flags.FlagHolder, Any], + **kwargs: Any) -> '_FlagOverrider': + ... + + +@overload +def flagsaver(func: _CallableT) -> _CallableT: + ... + def flagsaver(*args, **kwargs): """The main flagsaver interface. See module doc for usage.""" @@ -94,12 +108,14 @@ def flagsaver(*args, **kwargs): return _FlagOverrider(**kwargs) -def save_flag_values(flag_values=FLAGS): +def save_flag_values( + flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]: """Returns copy of flag values as a dict. Args: - flag_values: FlagValues, the FlagValues instance with which the flag will - be saved. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance with which the flag will be + saved. This should almost never need to be overridden. + Returns: Dictionary mapping keys to values. Keys are flag names, values are corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``. @@ -107,13 +123,14 @@ def save_flag_values(flag_values=FLAGS): return {name: _copy_flag_dict(flag_values[name]) for name in flag_values} -def restore_flag_values(saved_flag_values, flag_values=FLAGS): +def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]], + flag_values: flags.FlagValues = FLAGS): """Restores flag values based on the dictionary of flag values. Args: saved_flag_values: {'flag_name': value_dict, ...} - flag_values: FlagValues, the FlagValues instance from which the flag will - be restored. This should almost never need to be overridden. + flag_values: FlagValues, the FlagValues instance from which the flag will be + restored. This should almost never need to be overridden. """ new_flag_names = list(flag_values) for name in new_flag_names: @@ -127,23 +144,24 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS): flag_values[name].__dict__ = saved -def _wrap(func, overrides): +def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT: """Creates a wrapper function that saves/restores flag values. Args: - func: function object - This will be called between saving flags and - restoring flags. - overrides: {str: object} - Flag names mapped to their values. These flags - will be set after saving the original flag state. + func: This will be called between saving flags and restoring flags. + overrides: Flag names mapped to their values. These flags will be set after + saving the original flag state. Returns: - return value from func() + A wrapped version of func. """ + @functools.wraps(func) def _flagsaver_wrapper(*args, **kwargs): """Wrapper function that saves and restores flags.""" with _FlagOverrider(**overrides): return func(*args, **kwargs) + return _flagsaver_wrapper @@ -154,11 +172,11 @@ class _FlagOverrider(object): completes. """ - def __init__(self, **overrides): + def __init__(self, **overrides: Any): self._overrides = overrides self._saved_flag_values = None - def __call__(self, func): + def __call__(self, func: _CallableT) -> _CallableT: if inspect.isclass(func): raise TypeError('flagsaver cannot be applied to a class.') return _wrap(func, self._overrides) @@ -176,7 +194,7 @@ def __exit__(self, exc_type, exc_value, traceback): restore_flag_values(self._saved_flag_values, FLAGS) -def _copy_flag_dict(flag): +def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]: """Returns a copy of the flag object's ``__dict__``. It's mostly a shallow copy of the ``__dict__``, except it also does a shallow