
Starting from TFGAN's losses: some notes after reading the source code

DingfengShi edited this page Feb 15, 2018 · 3 revisions
  • TFGAN's API defines a wide range of loss functions, including many that are popular today, such as the WGAN and WGAN-GP losses.
  • In many cases, however, we need to modify them or define our own loss.

The most basic implementation is simple; just define it like this:

def your_loss(gan_model, **kwargs):
    loss = ...  # build a loss Tensor from the fields of gan_model
    return loss
  • Here, gan_model is an instance of TFGAN's GANModel class.
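For instance, here is a minimal self-contained sketch. The plain Python floats (in place of Tensors) and the reduced stand-in namedtuple for GANModel are both illustrative assumptions, not the real TFGAN types:

```python
from collections import namedtuple

# Stand-in for TFGAN's GANModel, reduced to the one field used below.
FakeGANModel = namedtuple('FakeGANModel', ['discriminator_gen_outputs'])

def your_loss(gan_model, weight=1.0):
    # Toy generator loss: push the discriminator's output on generated
    # data towards 1.0 (least-squares style), on a plain float.
    return weight * (gan_model.discriminator_gen_outputs - 1.0) ** 2

model = FakeGANModel(discriminator_gen_outputs=0.5)
print(your_loss(model))  # 0.25
```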

However...

In Google's usual style, the official implementation is never as simple as you might imagine (simple things made complicated, one might say).
Looking at the source, you'll find that the officially implemented loss functions do not take a gan_model as input at all; their parameter lists vary widely. Why is that?

namedtuple

First, some background on a Python type: namedtuple.
A namedtuple is like a struct declaration in C. Its two main parameters are typename and field_names: the type's name and the names of its data members.
Here is the official example:

    >>> Point = namedtuple('Point', ['x', 'y'])
    >>> p = Point(11, y=22)             # instantiate with positional args or keywords
    >>> p[0] + p[1]                     # indexable like a plain tuple
    33
    >>> x, y = p                        # unpack like a regular tuple
    >>> x, y
    (11, 22)
    >>> p.x + p.y                       # fields also accessible by name
    33
    >>> d = p._asdict()                 # convert to a dictionary
    >>> d['x']
    11
    >>> Point(**d)                      # convert from a dictionary
    Point(x=11, y=22)
    >>> p._replace(x=100)               # _replace() is like str.replace() but targets named fields
    Point(x=100, y=22)

As you can see, this defines a Point type that stores two members, x and y, in order.
namedtuple returns a new type, whose instances are created as if calling a constructor.

In fact, gan_model is exactly such a namedtuple, with the following fields:

    collections.namedtuple('GANModel', (
        'generator_inputs',
        'generated_data',
        'generator_variables',
        'generator_scope',
        'generator_fn',
        'real_data',
        'discriminator_real_outputs',
        'discriminator_gen_outputs',
        'discriminator_variables',
        'discriminator_scope',
        'discriminator_fn',
    ))
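To make this concrete, here is a sketch that builds a GANModel-shaped tuple with placeholder values (None and plain lists standing in for the real Tensors, variables, and scopes):

```python
from collections import namedtuple

GANModel = namedtuple('GANModel', (
    'generator_inputs', 'generated_data', 'generator_variables',
    'generator_scope', 'generator_fn', 'real_data',
    'discriminator_real_outputs', 'discriminator_gen_outputs',
    'discriminator_variables', 'discriminator_scope', 'discriminator_fn',
))

# Fill every field with None, then replace the two we care about.
model = GANModel(*([None] * len(GANModel._fields)))._replace(
    real_data=[1.0, 2.0], generated_data=[0.9, 2.1])

print(model.real_data)    # [1.0, 2.0]
print(model._fields[0])   # generator_inputs
```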

Reading the source, you can see that the officially defined loss functions are not passed along directly; each is first wrapped by _args_to_gan_model(loss_fn). This function returns a new loss function which, like the your_loss defined at the start, takes a gan_model.
The source is below; if you don't want to read it, skip straight to the conclusions:

def _args_to_gan_model(loss_fn):
  """Converts a loss taking individual args to one taking a GANModel namedtuple.

  The new function has the same name as the original one.

  Args:
    loss_fn: A python function taking a `GANModel` object and returning a loss
      Tensor calculated from that object. The shape of the loss depends on
      `reduction`.

  Returns:
    A new function that takes a GANModel namedtuples and returns the same loss.
  """
  # Match arguments in `loss_fn` to elements of `namedtuple`.
  # TODO(joelshor): Properly handle `varargs` and `keywords`.
  argspec = tf_inspect.getargspec(loss_fn)
  defaults = argspec.defaults or []

  required_args = set(argspec.args[:-len(defaults)])
  args_with_defaults = argspec.args[-len(defaults):]
  default_args_dict = dict(zip(args_with_defaults, defaults))

  def new_loss_fn(gan_model, **kwargs):  # pylint:disable=missing-docstring
    def _asdict(namedtuple):
      """Returns a namedtuple as a dictionary.

      This is required because `_asdict()` in Python 3.x.x is broken in classes
      that inherit from `collections.namedtuple`. See
      https://bugs.python.org/issue24931 for more details.

      Args:
        namedtuple: An object that inherits from `collections.namedtuple`.

      Returns:
        A dictionary version of the tuple.
      """
      return {k: getattr(namedtuple, k) for k in namedtuple._fields}
    gan_model_dict = _asdict(gan_model)

    # Make sure non-tuple required args are supplied.
    args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
    required_args_not_from_tuple = required_args - args_from_tuple
    for arg in required_args_not_from_tuple:
      if arg not in kwargs:
        raise ValueError('`%s` must be supplied to %s loss function.' % (
            arg, loss_fn.__name__))

    # Make sure tuple args aren't also supplied as keyword args.
    ambiguous_args = set(gan_model._fields).intersection(set(kwargs.keys()))
    if ambiguous_args:
      raise ValueError(
          'The following args are present in both the tuple and keyword args '
          'for %s: %s' % (loss_fn.__name__, ambiguous_args))

    # Add required args to arg dictionary.
    required_args_from_tuple = required_args.intersection(args_from_tuple)
    for arg in required_args_from_tuple:
      assert arg not in kwargs
      kwargs[arg] = gan_model_dict[arg]

    # Add arguments that have defaults.
    for arg in default_args_dict:
      val_from_tuple = gan_model_dict[arg] if arg in gan_model_dict else None
      val_from_kwargs = kwargs[arg] if arg in kwargs else None
      assert not (val_from_tuple is not None and val_from_kwargs is not None)
      kwargs[arg] = (val_from_tuple if val_from_tuple is not None else
                     val_from_kwargs if val_from_kwargs is not None else
                     default_args_dict[arg])

    return loss_fn(**kwargs)

  new_docstring = """The gan_model version of %s.""" % loss_fn.__name__
  new_loss_fn.__docstring__ = new_docstring
  new_loss_fn.__name__ = loss_fn.__name__
  new_loss_fn.__module__ = loss_fn.__module__
  return new_loss_fn
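The matching logic above can be distilled into a short, TensorFlow-free sketch. This is a simplification (no error checks and no default-value handling), and args_to_gan_model, l1_loss, and the two-field GANModel here are illustrative stand-ins, not the real TFGAN code:

```python
import inspect
from collections import namedtuple

GANModel = namedtuple('GANModel', ('real_data', 'generated_data'))

def args_to_gan_model(loss_fn):
    """Simplified version of TFGAN's _args_to_gan_model."""
    argspec = inspect.getfullargspec(loss_fn)
    def new_loss_fn(gan_model, **kwargs):
        # Any parameter whose name matches a tuple field is filled
        # automatically; everything else must come in via kwargs.
        for arg in argspec.args:
            if arg in gan_model._fields and arg not in kwargs:
                kwargs[arg] = getattr(gan_model, arg)
        return loss_fn(**kwargs)
    return new_loss_fn

def l1_loss(real_data, generated_data, weight=1.0):
    # Toy loss on plain floats standing in for Tensors.
    return weight * abs(real_data - generated_data)

wrapped = args_to_gan_model(l1_loss)
model = GANModel(real_data=3.0, generated_data=2.5)
print(wrapped(model))              # 0.5  (both fields filled from the tuple)
print(wrapped(model, weight=2.0))  # 1.0  (extra args still pass through)
```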

Conclusions

  • You can declare your own loss with parameters of your choosing, but it must be wrapped by _args_to_gan_model before being passed to the estimator.
  • The parameters of your loss fall into two kinds, just as in ordinary Python: those without default values and those with. For a parameter without a default, if its name matches one of the fields of the GANModel namedtuple above, the wrapper automatically fills in that field's value from the GANModel; you do not supply it yourself.
  • There must be no ambiguous arguments: the same name must not be supplied both from the tuple's fields and as a keyword argument.
  • Extra arguments can be passed into the loss function through **kwargs.
  • This pattern generalizes to many implementations across TensorFlow, and helps in understanding how certain prescribed functions need to be defined.
  • (In practice, the simplest way to implement a loss is still the first approach above.) But if you want to build your own loss on top of the losses TensorFlow has already defined, the material above makes such modifications much easier.