
Separate training and runtime attention #174

Merged: 12 commits from fix-attentions-in-time into master, Dec 12, 2016
Conversation

@cifkao (Member) commented Dec 2, 2016

Should fix the issue mentioned here:

The problem is that the same Attention object is used for constructing both the training and runtime parts of the graph, and therefore attentions_in_time contains tensors from both parts. Because I'm using attentions_in_time for visualization, I was getting images like this:

[TensorBoard screenshot of the attention visualization]

@jlibovicky (Contributor) left a comment:

I am sorry, but this is not a good solution to this issue. The reason is that if we wanted to use source sentence coverage (i.e., cumulative attention over time) to compute the next attention distribution (as in this paper), we would get into trouble. The correct way would be to instantiate two attention objects for each way of running the decoder and collect the attentions_in_time only from one of them.

I admit it will probably require a bigger intervention, but it is definitely worth it.

@cifkao (Member Author) commented Dec 3, 2016

The correct way would be to instantiate two attention objects for each way of running the decoder and collect the attentions_in_time only from one of them.

I'm not sure I understand this correctly. Did you mean "one attention object for each way of running the decoder"? That would probably be a cleaner solution, but it won't make the TF computation any different, will it? Both Attention objects will be doing the same thing, except that each one will have its own attentions_in_time.

@cifkao (Member Author) commented Dec 3, 2016

Is this what you had in mind? It still seems to me that with regard to CoverageAttention, there is really no difference from my first commit.

@jlibovicky force-pushed the fix-attentions-in-time branch 2 times, most recently from 7d017dc to c04b76e, on December 8, 2016 13:40
@jlibovicky dismissed their stale review on December 8, 2016 13:50

I made the last commits; they need a review.

@jlibovicky (Contributor) commented:

After the recent changes, this PR:

  • introduces an Attentive class, a common ancestor of all attentive encoders (note that Python 3 allows multiple inheritance); creating an attention object is centralized there
  • changes the decoders so that they no longer take a pre-made attention_object from an encoder; instead they call a get_attention_object method that instantiates one for them (a sketch of the pattern follows below)

The benefits of handling the attention object like this are:

  • it fixes the bug of collecting attentions from both the training and the runtime run of an RNN decoder into a single collection
  • it allows multiple decoders to attend over the same encoder.
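Purely for illustration, here is a minimal sketch of the pattern described above. The names (Attentive, get_attention_object, _attention_tensor, _attention_mask, runtime_mode) follow this PR, but the bodies, the constructor signature and the scope naming are simplified assumptions, not the actual repository code.

# Hedged sketch of the Attentive mixin; simplified, not the merged code.
class Attentive(object):
    """Common ancestor of all encoders a decoder can attend over."""

    def __init__(self, attention_type, **kwargs):
        self._attention_type = attention_type
        self._attention_kwargs = kwargs

    @property
    def _attention_tensor(self):
        """Tensor over which the attention is done."""
        raise NotImplementedError("The encoder must define _attention_tensor.")

    @property
    def _attention_mask(self):
        """Input weights masking out padded positions."""
        raise NotImplementedError("The encoder must define _attention_mask.")

    def get_attention_object(self, runtime=False):
        """Build a separate Attention object for the training or the runtime
        part of the graph; the two share variables (same scope, reuse turned
        on for the runtime copy) but keep separate attentions_in_time."""
        if not self._attention_type:
            return None
        return self._attention_type(
            self._attention_tensor,
            scope="attention_{}".format(type(self).__name__),
            input_weights=self._attention_mask,
            runtime_mode=runtime,
            **self._attention_kwargs)


def collect_attention_objects(encoders, runtime=False):
    # Decoder side (sketch): one Attention object per run mode, so a
    # visualization of attentions_in_time no longer mixes training and
    # runtime tensors.
    return [enc.get_attention_object(runtime=runtime)
            for enc in encoders if isinstance(enc, Attentive)]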

@jlibovicky self-assigned this on Dec 8, 2016
@jlibovicky (Contributor) commented:

@cifkao, I can't add you among the reviewers, but have a look at the code and tell me whether it is what you wanted to do.

@@ -24,7 +24,8 @@ def __init__(self, dimension, output_shape, data_id):
         self.encoded = tf.tanh(tf.matmul(self.flat, project_w) + project_b)

         self.attention_tensor = None
-        self.attention_object = None
+        self.attention_object_train = None
+        self.attention_object_runtime = None

Member Author:

I guess this should inherit from Attentive too.

Member:

I don't know whether we've implemented attention in this encoder. It was not used in our submission for the multimodal task, so it might be safe to ignore this for now.

@cifkao (Member Author) commented Dec 8, 2016

I think you forgot about image_encoder (or we should revert my changes to it). Otherwise, LGTM.

"""
        self.scope = scope
        self.attentions_in_time = []
        self.attention_states = attention_states
        self.input_weights = input_weights

-       with tf.variable_scope(scope):
+       with tf.variable_scope(scope, reuse=runtime_mode):
Member:

This could potentially be dangerous: if scope is not a string but a tf.VariableScope object that already has reuse set to not runtime_mode, I'm not sure how it will behave. To be safe, I would add some reasonable asserts before and after it.

I can imagine a situation where runtime_mode is False and for some reason we enter this function a second time with the same scope. Then in the better case it crashes, and in the worse case it creates a new set of variables, tells nobody, and doesn't train correctly.

Contributor:

Is this reuse even necessary here? When reuse is set one level up, it propagates to the inner scopes too, doesn't it?
Otherwise, it seems to me that it should never create new variables; it should always crash.
As for what the asserts should look like, I can't think of anything reasonable.
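For context, a minimal self-contained sketch of the behaviour being discussed, assuming the TF 1.x variable_scope API this code uses: reuse set on an outer scope propagates into inner scopes, and attempting to create an already existing variable without reuse raises a ValueError instead of silently making a second copy.

import tensorflow as tf

def make_attention_vars():
    # Creates or reuses "att_w" depending on the enclosing scope's reuse flag.
    return tf.get_variable("att_w", shape=[8, 8])

with tf.variable_scope("decoder") as dec_scope:
    train_w = make_attention_vars()        # variables created here

with tf.variable_scope(dec_scope, reuse=True):
    runtime_w = make_attention_vars()      # reuse=True: same variable as train_w
    with tf.variable_scope("attention"):
        pass                               # inner scopes inherit reuse=True

with tf.variable_scope(dec_scope):         # entered again, reuse not set
    try:
        make_attention_vars()              # second creation attempt
    except ValueError as err:
        print("TF refuses to silently re-create the variable:", err)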


        assert hasattr(self, "name")
        assert hasattr(self, "_padding")
        assert hasattr(self, "_attention_tensor")
Member:

Although it's probably OK, it doesn't seem like the best thing to dictate from the parent class where in the children's __init__ method super has to be called. Moreover, it seems to me that it should really be the other way around: super should be called as the first statement (à la Java). Can't the things it needs be passed in as constructor arguments?

Contributor:

The Java-like solution would be to make these abstract members here that would have to be overridden; then super could be the first call in the constructor. But it seemed to me that this would only needlessly lengthen the already long encoder code.

Contributor:

If you think that way is more correct, I'll gladly rewrite it like that, but I'm not sure myself.

Member Author:

Not that I understand multiple inheritance in Python very well, but this could be a problem if we wanted to have more classes like this that add functionality to encoders through inheritance.

Wouldn't it be better if Attentive did nothing in the constructor and got initialized only when some method is called?

There's also the option of not solving this with inheritance at all.

Member:

It should be in the constructor because of how the TensorFlow graph is built from the config file; there isn't much room for calling additional methods there.
But the decoder, which already receives finished encoders, could call it, that's true. That would also mean that if we don't use attention, this part of the graph wouldn't be needlessly built in the encoder.

Member:

Meaning that Decoder._collect_attention_objects would additionally call something like if encoder.supports_attention: encoder.create_attention_graph() on the encoders, which would return an Attentive or Attention object.

@jlibovicky, could it work like this? The hierarchy would change so that instead of encoder is_a Attentive it would be encoder has_a attention.

We can also leave this for later so that more and more code doesn't keep piling onto this pull request.

Member Author:

I meant it so that the encoder would call super, which would do nothing, and only later in the constructor, once everything is ready, it would call self._init_attention().

The Attention object is already created only when it is requested from the decoder, if I'm not mistaken.

Contributor:

And what would that self._init_attention() do? Right now the constructor only checks the hasattr conditions. If those are thrown out, super() can be called at the very beginning of the encoder's constructor. The reason I did it this way is so that it crashes already during encoder initialization if the right tensors weren't created there, and not only when the get_attention_object method is called, because that happens in the decoder and it could be confusing later when debugging config files.

    def get_attention_object(self, runtime: bool=False):
        # pylint: disable=no-member
        if self._attention_type and self._attention_tensor is None:
            raise Exception("Can't get attention: missing attention tensor.")
Member:

I would use a ValueError, since it's about a bad value. Or create some more specific error in this module.
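A tiny sketch of the second option, a module-specific error; the class name below is hypothetical and does not exist in the PR:

class MissingAttentionTensorError(ValueError):
    """Hypothetical error: the encoder was configured with an attention
    type but never built an attention tensor."""

# Inside Attentive.get_attention_object (sketch):
#     if self._attention_type and self._attention_tensor is None:
#         raise MissingAttentionTensorError(
#             "Can't get attention: missing attention tensor.")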

-                name),
-            dropout_placeholder=self.dropout_placeholder,
-            input_weights=att_in_weights)
+        super(CNNEncoder, self).__init__(attention_type)
Member:

This one bothers me here: first, for the reasons I described above, and second, because super().__init__(attention_type) would be enough.

-            dropout_placeholder=self.dropout_placeholder,
-            input_weights=weight_tensor,
-            max_fertility=attention_fertility)
+        super(FactoredEncoder, self).__init__(
Member:

ditto

    def feed_dict(self, dataset, train=False):
-       factors = {data_id: dataset.get_series(data_id) for data_id in self.data_ids}
+       factors = {data_id: dataset.get_series(
+           data_id) for data_id in self.data_ids}
Member:

I prefer to break the line right before the for, which then aligns under data_id, but to each their own... :-)

Contributor:

Me too; this one is autopep8's doing, I run it on files that have way too many long lines.

-            input_weights=self.padding,
-            max_fertility=attention_fertility) if attention_type else None
+        super(SentenceEncoder, self).__init__(
+            attention_type, attention_fertility=attention_fertility)
Member:

Use super().__init__; and think about whether this could be done better without it.

@jindrahelcl (Member) commented Dec 9, 2016 via email

@jindrahelcl (Member) commented Dec 9, 2016 via email

@jindrahelcl (Member) commented Dec 9, 2016 via email

@jindrahelcl (Member) commented Dec 9, 2016 via email

@jlibovicky (Contributor) commented:

It could, but it would probably 1. be less clear, and 2. not remove the problem that it would only fail during decoder construction. For now I like the Java-like idea: Attentive will have abstract properties and the encoders will override them.

@jlibovicky force-pushed the fix-attentions-in-time branch from 0d2a8d8 to b318c4f on December 12, 2016 09:49
@jlibovicky (Contributor) commented:

Now there is a variant of Attentive with abstract properties. Comments, suggestions?

    @property
    def _attention_tensor(self):
        """Tensor over which the attention is done."""
        raise NotImplementedError(
Member Author:

Why not use @abstractmethod?
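For comparison, a hedged sketch of what the @abstractmethod variant suggested here could look like, using abc from the standard library; this is not the code that was merged:

from abc import ABCMeta, abstractmethod

class Attentive(metaclass=ABCMeta):
    def __init__(self, attention_type, **kwargs):
        self._attention_type = attention_type
        self._attention_kwargs = kwargs

    @property
    @abstractmethod
    def _attention_tensor(self):
        """Tensor over which the attention is done."""

    @property
    @abstractmethod
    def _attention_mask(self):
        """Input weights masking out padded positions."""

With ABCMeta, an encoder subclass that does not override both properties cannot even be instantiated (it raises TypeError at construction time), so it fails as early as the hasattr checks discussed above rather than only when the property is first accessed.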


# pylint: disable=too-few-public-methods
class Attentive(object):
    def __init__(self, attention_type, **kwargs):
Member:

It would be good to add docstrings here, but I don't want to delay the merge any longer.

        encoder_state = tf.concat(
            1, [encoder_state, backward_encoder_state])

        self.encoded = encoder_state

-       self.attention_tensor = \
+       self.__attention_tensor = \
Member:

Two underscores? Isn't that too much?
Plus, the backslash can be refactored away by breaking the line right after reshape( instead.

Contributor:

Two underscores because it's private, not protected.

Member:

ok


    @property
    def _attention_mask(self):
        return self.__attention_weights
Member:

mask, weights, one underscore, two underscores... that's a bit of a mess.

Contributor:

With two underscores it's the private field; with one underscore it's the property inherited from Attentive.
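To illustrate the convention being discussed (hypothetical class name, simplified body): the double-underscore attribute is a name-mangled private field of the concrete encoder, while the single-underscore name is the property declared by Attentive that exposes it.

class SentenceEncoderSketch(object):      # hypothetical, for illustration only
    def __init__(self, states):
        # Stored as _SentenceEncoderSketch__attention_tensor (name mangling),
        # so neither subclasses nor the Attentive base can clash with it.
        self.__attention_tensor = states

    @property
    def _attention_tensor(self):
        # The single-underscore property that Attentive expects encoders to provide.
        return self.__attention_tensor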

Member:

ok

Contributor:

I would also consider using a cached_property; then the separate attributes wouldn't have to be there at all. In general, I have a feeling that introducing these cached properties could make the code considerably clearer and more modular. I'll open a discussion issue about it.
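A hedged sketch of what such a cached_property could look like as a simple non-data descriptor; functools.cached_property was not in the standard library at the time, and both the decorator below and the _build_attention_tensor helper are assumptions for illustration:

class cached_property(object):
    """Run the wrapped method once per instance and cache its result."""

    def __init__(self, func):
        self._func = func
        self.__doc__ = func.__doc__

    def __get__(self, obj, owner=None):
        if obj is None:
            return self
        value = self._func(obj)
        # Cache under the method's name: later lookups hit the instance
        # dict directly and never reach this descriptor again.
        obj.__dict__[self._func.__name__] = value
        return value


class SomeEncoder(object):                # hypothetical usage
    @cached_property
    def _attention_tensor(self):
        # Built lazily on first access, then reused; no separate
        # __attention_tensor attribute is needed.
        return self._build_attention_tensor()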

Member:

I'm probably in favor.

-       self.attention_tensor = tf.concat(2, outputs_bidi_tup)
-       self.attention_tensor = self._dropout(self.attention_tensor)
+       self.__attention_tensor = tf.concat(2, outputs_bidi_tup)
+       self.__attention_tensor = self._dropout(self._attention_tensor)
Member:

ditto

@@ -112,16 +119,16 @@ def _create_input_placeholders(self):
         self.inputs = tf.placeholder(tf.int32, shape=[None, self.max_input_len],
                                      name="encoder_input")

-        self.padding = tf.placeholder(
+        self.__input_weights = tf.placeholder(
Member:

ditto

Contributor:

I'll rename weights to mask.

Member:

ok :-)

@@ -91,7 +91,7 @@ def __init__(self,
             dtype=tf.float32)

         self.__attention_tensor = tf.concat(2, outputs_bidi_tup)
-        self.__attention_tensor = self._dropout(self._attention_tensor)
+        self.__attention_tensor = self._dropout(self.__attention_tensor)
Member:

And the tests passed before this?

Contributor:

Why wouldn't they pass? It just took a detour through the property instead of accessing the attribute directly (I was lucky that it did the same thing, but semantically it was wrong).

Member:

Yeah, right... hm.

@jlibovicky force-pushed the fix-attentions-in-time branch from d4af86f to cb69c61 on December 12, 2016 14:23
@jlibovicky merged commit 52a851d into master on Dec 12, 2016
@jlibovicky deleted the fix-attentions-in-time branch on December 12, 2016 14:48