Skip to content

Commit

Permalink
Merge pull request #206 from yilei/push_up_to_494901390_2
Browse files Browse the repository at this point in the history
Push up to 494901390
  • Loading branch information
yilei authored Dec 15, 2022
2 parents 83adb26 + 8da67e5 commit 1f1c14f
Show file tree
Hide file tree
Showing 9 changed files with 86 additions and 31 deletions.
8 changes: 7 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,13 @@ The format is based on [Keep a Changelog](https://keepachangelog.com).

### Changed

* If no log dir is specified `logging.find_log_dir()` now falls back to `tempfile.gettempdir()` instead of `/tmp/`.
* (logging) If no log dir is specified `logging.find_log_dir()` now falls back
to `tempfile.gettempdir()` instead of `/tmp/`.

### Fixed

* (flags) Additional kwargs (e.g. `short_name=`) to `DEFINE_multi_enum_class`
are now correctly passed to the underlying `Flag` object.

## 1.3.0 (2022-10-11)

Expand Down
3 changes: 3 additions & 0 deletions absl/flags/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ mark_flags_as_required = _validators.mark_flags_as_required
mark_flags_as_mutual_exclusive = _validators.mark_flags_as_mutual_exclusive
mark_bool_flags_as_mutual_exclusive = _validators.mark_bool_flags_as_mutual_exclusive

# Flag modifiers.
set_default = _defines.set_default

# Key flag related functions.
declare_key_flag = _defines.declare_key_flag
adopt_module_key_flags = _defines.adopt_module_key_flags
Expand Down
10 changes: 8 additions & 2 deletions absl/flags/_defines.py
Original file line number Diff line number Diff line change
Expand Up @@ -859,11 +859,17 @@ def DEFINE_multi_enum_class( # pylint: disable=invalid-name,redefined-builtin
"""
return DEFINE_flag(
_flag.MultiEnumClassFlag(
name, default, help, enum_class, case_sensitive=case_sensitive),
name,
default,
help,
enum_class,
case_sensitive=case_sensitive,
**args,
),
flag_values,
module_name,
required=required,
**args)
)


def DEFINE_alias( # pylint: disable=invalid-name
Expand Down
3 changes: 2 additions & 1 deletion absl/flags/_flag.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ import functools
from absl.flags import _argument_parser
import enum

from typing import Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence
from typing import Callable, Text, TypeVar, Generic, Iterable, Type, List, Optional, Any, Union, Sequence

_T = TypeVar('_T')
_ET = TypeVar('_ET', bound=enum.Enum)
Expand All @@ -44,6 +44,7 @@ class Flag(Generic[_T]):
using_default_value = ... # type: bool
allow_overwrite = ... # type: bool
allow_using_method_names = ... # type: bool
validators = ... # type: List[Callable[[Any], bool]]

def __init__(self,
parser: _argument_parser.ArgumentParser[_T],
Expand Down
4 changes: 3 additions & 1 deletion absl/flags/_flagvalues.py
Original file line number Diff line number Diff line change
Expand Up @@ -411,7 +411,9 @@ def __setitem__(self, name, flag):
"""Registers a new flag variable."""
fl = self._flags()
if not isinstance(flag, _flag.Flag):
raise _exceptions.IllegalFlagValueError(flag)
raise _exceptions.IllegalFlagValueError(
f'Expect Flag instances, found type {type(flag)}. '
"Maybe you didn't mean to use FlagValue.__setitem__?")
if not isinstance(name, str):
raise _exceptions.Error('Flag name must be a string')
if not name:
Expand Down
11 changes: 11 additions & 0 deletions absl/flags/tests/flags_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,17 @@ def test_bad_multi_enum_flags(self):

class MultiEnumClassFlagsTest(absltest.TestCase):

def test_short_name(self):
fv = flags.FlagValues()
flags.DEFINE_multi_enum_class(
'fruit',
None,
Fruit,
'Enum option that can occur multiple times',
flag_values=fv,
short_name='me')
self.assertEqual(fv['fruit'].short_name, 'me')

def test_define_results_in_registered_flag_with_none(self):
fv = flags.FlagValues()
enum_defaults = None
Expand Down
21 changes: 13 additions & 8 deletions absl/logging/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
import socket
import struct
import sys
import tempfile
import threading
import tempfile
import time
Expand Down Expand Up @@ -707,22 +708,26 @@ def find_log_dir(log_dir=None):
FileNotFoundError: raised in Python 3 when it cannot find a log directory.
OSError: raised in Python 2 when it cannot find a log directory.
"""
# Get a possible log dir.
# Get a list of possible log dirs (will try to use them in order).
# NOTE: Google's internal implementation has a special handling for Google
# machines, which uses a list of directories. Hence the following uses `dirs`
# instead of a single directory.
if log_dir:
# log_dir was explicitly specified as an arg, so use it and it alone.
log_dir_candidate = log_dir
dirs = [log_dir]
elif FLAGS['log_dir'].value:
# log_dir flag was provided, so use it and it alone (this mimics the
# behavior of the same flag in logging.cc).
log_dir_candidate = FLAGS['log_dir'].value
dirs = [FLAGS['log_dir'].value]
else:
log_dir_candidate = tempfile.gettempdir()
dirs = [tempfile.gettempdir()]

# Test if log dir candidate is usable.
if os.path.isdir(log_dir_candidate) and os.access(log_dir_candidate, os.W_OK):
return log_dir_candidate
# Find the first usable log dir.
for d in dirs:
if os.path.isdir(d) and os.access(d, os.W_OK):
return d
raise FileNotFoundError(
"Can't find a writable directory for logs, tried %s" % log_dir_candidate)
"Can't find a writable directory for logs, tried %s" % dirs)


def get_absl_log_prefix(record):
Expand Down
7 changes: 5 additions & 2 deletions absl/testing/absltest.py
Original file line number Diff line number Diff line change
Expand Up @@ -533,7 +533,10 @@ def open_bytes(self, mode='rb'):
# currently `Any` to avoid [bad-return-type] errors in the open_* methods.
@contextlib.contextmanager
def _open(
self, mode: str, encoding: str = 'utf8', errors: str = 'strict'
self,
mode: str,
encoding: Optional[str] = 'utf8',
errors: Optional[str] = 'strict',
) -> Iterator[Any]:
with io.open(
self.full_path, mode=mode, encoding=encoding, errors=errors) as fp:
Expand Down Expand Up @@ -638,7 +641,7 @@ def test_foo(self):
self.assertTrue(os.path.exists(expected_paths[1]))
self.assertEqual('foo', out_log.read_text())
See also: :meth:`create_tempdir` for creating temporary files.
See also: :meth:`create_tempfile` for creating temporary files.
Args:
name: Optional name of the directory. If not given, a unique
Expand Down
50 changes: 34 additions & 16 deletions absl/testing/flagsaver.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Decorator and context manager for saving and restoring flag values.
There are many ways to save and restore. Always use the most convenient method
Expand Down Expand Up @@ -61,11 +60,26 @@ def some_func():

import functools
import inspect
from typing import overload, Any, Callable, Mapping, Tuple, TypeVar

from absl import flags

FLAGS = flags.FLAGS

# The type of pre/post wrapped functions.
_CallableT = TypeVar('_CallableT', bound=Callable)


@overload
def flagsaver(*args: Tuple[flags.FlagHolder, Any],
**kwargs: Any) -> '_FlagOverrider':
...


@overload
def flagsaver(func: _CallableT) -> _CallableT:
...


def flagsaver(*args, **kwargs):
"""The main flagsaver interface. See module doc for usage."""
Expand Down Expand Up @@ -94,26 +108,29 @@ def flagsaver(*args, **kwargs):
return _FlagOverrider(**kwargs)


def save_flag_values(flag_values=FLAGS):
def save_flag_values(
flag_values: flags.FlagValues = FLAGS) -> Mapping[str, Mapping[str, Any]]:
"""Returns copy of flag values as a dict.
Args:
flag_values: FlagValues, the FlagValues instance with which the flag will
be saved. This should almost never need to be overridden.
flag_values: FlagValues, the FlagValues instance with which the flag will be
saved. This should almost never need to be overridden.
Returns:
Dictionary mapping keys to values. Keys are flag names, values are
corresponding ``__dict__`` members. E.g. ``{'key': value_dict, ...}``.
"""
return {name: _copy_flag_dict(flag_values[name]) for name in flag_values}


def restore_flag_values(saved_flag_values, flag_values=FLAGS):
def restore_flag_values(saved_flag_values: Mapping[str, Mapping[str, Any]],
flag_values: flags.FlagValues = FLAGS):
"""Restores flag values based on the dictionary of flag values.
Args:
saved_flag_values: {'flag_name': value_dict, ...}
flag_values: FlagValues, the FlagValues instance from which the flag will
be restored. This should almost never need to be overridden.
flag_values: FlagValues, the FlagValues instance from which the flag will be
restored. This should almost never need to be overridden.
"""
new_flag_names = list(flag_values)
for name in new_flag_names:
Expand All @@ -127,23 +144,24 @@ def restore_flag_values(saved_flag_values, flag_values=FLAGS):
flag_values[name].__dict__ = saved


def _wrap(func, overrides):
def _wrap(func: _CallableT, overrides: Mapping[str, Any]) -> _CallableT:
"""Creates a wrapper function that saves/restores flag values.
Args:
func: function object - This will be called between saving flags and
restoring flags.
overrides: {str: object} - Flag names mapped to their values. These flags
will be set after saving the original flag state.
func: This will be called between saving flags and restoring flags.
overrides: Flag names mapped to their values. These flags will be set after
saving the original flag state.
Returns:
return value from func()
A wrapped version of func.
"""

@functools.wraps(func)
def _flagsaver_wrapper(*args, **kwargs):
"""Wrapper function that saves and restores flags."""
with _FlagOverrider(**overrides):
return func(*args, **kwargs)

return _flagsaver_wrapper


Expand All @@ -154,11 +172,11 @@ class _FlagOverrider(object):
completes.
"""

def __init__(self, **overrides):
def __init__(self, **overrides: Any):
self._overrides = overrides
self._saved_flag_values = None

def __call__(self, func):
def __call__(self, func: _CallableT) -> _CallableT:
if inspect.isclass(func):
raise TypeError('flagsaver cannot be applied to a class.')
return _wrap(func, self._overrides)
Expand All @@ -176,7 +194,7 @@ def __exit__(self, exc_type, exc_value, traceback):
restore_flag_values(self._saved_flag_values, FLAGS)


def _copy_flag_dict(flag):
def _copy_flag_dict(flag: flags.Flag) -> Mapping[str, Any]:
"""Returns a copy of the flag object's ``__dict__``.
It's mostly a shallow copy of the ``__dict__``, except it also does a shallow
Expand Down

0 comments on commit 1f1c14f

Please sign in to comment.