Extending from TFGAN's losses: some understanding after reading the source code
- TFGAN's API defines a wide variety of loss functions, including most of the currently popular ones such as the WGAN and WGAN-GP losses (the sketch right after this list shows roughly what using them looks like).
- In many cases, however, we need to modify these losses or define our own.
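For reference, here is a rough sketch of how the built-in tuple-style losses are used. It assumes TF 1.x with `tf.contrib.gan` and a `GANModel` instance called `model` (how that is built is shown further below); the exact loss choice is just an example:

```python
import tensorflow as tf
tfgan = tf.contrib.gan  # TF 1.x only

# Rough sketch: plug two built-in losses into gan_loss.
# `model` is assumed to be a GANModel namedtuple (its fields are listed below).
gan_loss = tfgan.gan_loss(
    model,
    generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
    gradient_penalty_weight=1.0)  # turns the WGAN loss into WGAN-GP
```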
The most basic implementation of your own loss is simple; define it like this:
```python
def your_loss(gan_model, *args, **kwargs):
    loss = ...  # compute the loss from gan_model's fields
    return loss
```
- Here, `gan_model` is TFGAN's `GANModel` class (a namedtuple whose fields are listed further below).
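For concreteness, here is a minimal sketch of such a loss. It is purely illustrative: `discriminator_gen_outputs` is one of the `GANModel` fields shown later, and the cross-entropy formula is just one possible choice (the non-saturating generator loss):

```python
import tensorflow as tf

def your_loss(gan_model, weight=1.0):
    # Non-saturating generator loss: push D(G(z)) towards the "real" label.
    # `discriminator_gen_outputs` is a field of the GANModel namedtuple.
    logits = gan_model.discriminator_gen_outputs
    loss = tf.losses.sigmoid_cross_entropy(
        multi_class_labels=tf.ones_like(logits), logits=logits)
    return weight * loss
```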
In Google's usual style, though, the official implementation is never as simple as you might imagine; they manage to make the simple complicated (kidding).
If you look at the officially implemented loss functions, you will notice that none of them take a gan_model as input; instead they take all sorts of individual arguments. Why is that?
To understand it, first meet a Python class: namedtuple.
A namedtuple is roughly the declaration of a struct in C. It takes two main arguments, typename and field_names, which give the type name and the names of the data members.
Here is the example from the official docs:
```python
>>> from collections import namedtuple
>>> Point = namedtuple('Point', ['x', 'y'])
>>> p = Point(11, y=22)     # instantiate with positional args or keywords
>>> p[0] + p[1]             # indexable like a plain tuple
33
>>> x, y = p                # unpack like a regular tuple
>>> x, y
(11, 22)
>>> p.x + p.y               # fields also accessible by name
33
>>> d = p._asdict()         # convert to a dictionary
>>> d['x']
11
>>> Point(**d)              # convert from a dictionary
Point(x=11, y=22)
>>> p._replace(x=100)       # _replace() is like str.replace() but targets named fields
Point(x=100, y=22)
```
As you can see, this defines a "point" structure Point whose two fields, x and y, are stored in order.
What namedtuple returns is essentially a new structure type, which you can then instantiate just like calling a constructor.
In fact, gan_model is exactly such a namedtuple. Its fields are:
```python
collections.namedtuple('GANModel', (
    'generator_inputs',
    'generated_data',
    'generator_variables',
    'generator_scope',
    'generator_fn',
    'real_data',
    'discriminator_real_outputs',
    'discriminator_gen_outputs',
    'discriminator_variables',
    'discriminator_scope',
    'discriminator_fn',
))
```
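In normal use you do not build this tuple by hand: `tfgan.gan_model(...)` constructs it from your network functions, and you then read its fields by name exactly like `p.x` above. A rough sketch, in which the generator/discriminator functions and the input tensors are placeholders:

```python
import tensorflow as tf
tfgan = tf.contrib.gan  # TF 1.x

# gan_model() runs the two network functions and fills in every field
# of the GANModel namedtuple shown above.
model = tfgan.gan_model(
    generator_fn=my_generator_fn,          # placeholder: your generator
    discriminator_fn=my_discriminator_fn,  # placeholder: your discriminator
    real_data=real_images,                 # placeholder: a batch of real data
    generator_inputs=noise)                # placeholder: e.g. random noise

# Field access by name, just like Point.x / Point.y:
fake_images = model.generated_data
d_on_fake = model.discriminator_gen_outputs
```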
Reading the source further, you can see that the officially defined loss functions are not used directly either; each one is first wrapped by `_args_to_gan_model(loss_fn)`. This function returns a new loss function which, just like the `your_loss` defined at the beginning, takes a `gan_model` as its argument.
The source is pasted below; if you don't feel like reading it, jump straight to the conclusions:
```python
def _args_to_gan_model(loss_fn):
  """Converts a loss taking individual args to one taking a GANModel namedtuple.

  The new function has the same name as the original one.

  Args:
    loss_fn: A python function taking a `GANModel` object and returning a loss
      Tensor calculated from that object. The shape of the loss depends on
      `reduction`.

  Returns:
    A new function that takes a GANModel namedtuples and returns the same loss.
  """
  # Match arguments in `loss_fn` to elements of `namedtuple`.
  # TODO(joelshor): Properly handle `varargs` and `keywords`.
  argspec = tf_inspect.getargspec(loss_fn)
  defaults = argspec.defaults or []

  required_args = set(argspec.args[:-len(defaults)])
  args_with_defaults = argspec.args[-len(defaults):]
  default_args_dict = dict(zip(args_with_defaults, defaults))

  def new_loss_fn(gan_model, **kwargs):  # pylint:disable=missing-docstring
    def _asdict(namedtuple):
      """Returns a namedtuple as a dictionary.

      This is required because `_asdict()` in Python 3.x.x is broken in classes
      that inherit from `collections.namedtuple`. See
      https://bugs.python.org/issue24931 for more details.

      Args:
        namedtuple: An object that inherits from `collections.namedtuple`.

      Returns:
        A dictionary version of the tuple.
      """
      return {k: getattr(namedtuple, k) for k in namedtuple._fields}
    gan_model_dict = _asdict(gan_model)

    # Make sure non-tuple required args are supplied.
    args_from_tuple = set(argspec.args).intersection(set(gan_model._fields))
    required_args_not_from_tuple = required_args - args_from_tuple
    for arg in required_args_not_from_tuple:
      if arg not in kwargs:
        raise ValueError('`%s` must be supplied to %s loss function.' % (
            arg, loss_fn.__name__))

    # Make sure tuple args aren't also supplied as keyword args.
    ambiguous_args = set(gan_model._fields).intersection(set(kwargs.keys()))
    if ambiguous_args:
      raise ValueError(
          'The following args are present in both the tuple and keyword args '
          'for %s: %s' % (loss_fn.__name__, ambiguous_args))

    # Add required args to arg dictionary.
    required_args_from_tuple = required_args.intersection(args_from_tuple)
    for arg in required_args_from_tuple:
      assert arg not in kwargs
      kwargs[arg] = gan_model_dict[arg]

    # Add arguments that have defaults.
    for arg in default_args_dict:
      val_from_tuple = gan_model_dict[arg] if arg in gan_model_dict else None
      val_from_kwargs = kwargs[arg] if arg in kwargs else None
      assert not (val_from_tuple is not None and val_from_kwargs is not None)
      kwargs[arg] = (val_from_tuple if val_from_tuple is not None else
                     val_from_kwargs if val_from_kwargs is not None else
                     default_args_dict[arg])

    return loss_fn(**kwargs)

  new_docstring = """The gan_model version of %s.""" % loss_fn.__name__
  new_loss_fn.__docstring__ = new_docstring
  new_loss_fn.__name__ = loss_fn.__name__
  new_loss_fn.__module__ = loss_fn.__module__
  return new_loss_fn
```
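This wrapper is how the tuple-style losses exposed by `tf.contrib.gan.losses` are produced from the plain, tensor-argument losses. Roughly speaking, the pattern in `tuple_losses_impl.py` looks like the sketch below (the module alias `tfgan_losses` is illustrative):

```python
# Each plain loss that takes individual tensors is wrapped once, and the
# wrapped version is what you hand to gan_loss / GANEstimator.
wasserstein_generator_loss = _args_to_gan_model(
    tfgan_losses.wasserstein_generator_loss)
wasserstein_discriminator_loss = _args_to_gan_model(
    tfgan_losses.wasserstein_discriminator_loss)
```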
- You can declare a loss of your own with whatever arguments you like, but it has to be wrapped by `_args_to_gan_model` before it is passed to the estimator (the sketch after this list walks through how the wrapper fills in the arguments).
- The arguments of such a loss fall into two kinds, just as in ordinary Python: those without default values and those with defaults. For an argument without a default, if its name matches one of the fields of the `GANModel` namedtuple above, the wrapper automatically passes in that field's value from the `GANModel`; you don't have to supply it yourself.
- There must be no ambiguous arguments: a name that is filled in from the `GANModel` tuple must not also be supplied as a keyword argument.
- Extra arguments can be passed into the loss function through `**kwargs`.
- This pattern generalizes to many other places in TensorFlow and helps when figuring out how some prescribed function is supposed to be defined.
- (In practice, the simplest way to implement a loss is still the one shown at the beginning.) But if you want to build your own loss on top of the losses TensorFlow has already defined, the material above makes such modifications much more convenient.
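To make these rules concrete, here is a small, self-contained sketch of the wrapper at work. It assumes `_args_to_gan_model` from the source above has been copied into the current file (it additionally needs `from tensorflow.python.util import tf_inspect`), and it rebuilds the `GANModel` namedtuple locally just for illustration; the loss formula itself is arbitrary:

```python
import collections
import tensorflow as tf

# Local mirror of the GANModel namedtuple shown earlier (illustration only;
# in practice you would use the instance TFGAN builds for you).
GANModel = collections.namedtuple('GANModel', (
    'generator_inputs', 'generated_data', 'generator_variables',
    'generator_scope', 'generator_fn', 'real_data',
    'discriminator_real_outputs', 'discriminator_gen_outputs',
    'discriminator_variables', 'discriminator_scope', 'discriminator_fn'))

def my_generator_loss(discriminator_gen_outputs, weight=1.0, extra_scale=None):
    # `discriminator_gen_outputs` has no default and matches a GANModel field,
    # so the wrapper fills it in from the tuple automatically.
    loss = -weight * tf.reduce_mean(discriminator_gen_outputs)
    if extra_scale is not None:
        loss = loss * extra_scale
    return loss

wrapped_loss = _args_to_gan_model(my_generator_loss)  # copied from above

model = GANModel(
    generator_inputs=None, generated_data=None, generator_variables=None,
    generator_scope=None, generator_fn=None, real_data=None,
    discriminator_real_outputs=tf.constant([1.0, 2.0]),
    discriminator_gen_outputs=tf.constant([0.5, -0.5]),
    discriminator_variables=None, discriminator_scope=None,
    discriminator_fn=None)

# `discriminator_gen_outputs` is taken from the tuple, `weight` keeps its
# default, and `extra_scale` arrives through **kwargs.
loss = wrapped_loss(model, extra_scale=2.0)

# Passing a tuple field as a keyword arg as well would trigger the
# "present in both the tuple and keyword args" ValueError:
# wrapped_loss(model, discriminator_gen_outputs=tf.constant([0.0]))
```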