Skip to content
This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

Generative model with start features #225

Merged
merged 2 commits into from
May 7, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions light/world/souls/base_soul.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def build_context(self, partner_name=None, quest_txt=None):
txt += quest_txt
txt += "\n"
return txt

def build_dialog_context(self, quest_txt=None):
# Initial context.
txt = self.build_context(quest_txt)
Expand All @@ -212,7 +212,8 @@ def build_dialog_context(self, quest_txt=None):
# reset conversation when unsafe utterances are in the history
dtxt = self.build_context(quest_txt)
dtxt = dtxt.lstrip(" ")
return txt + dtxt
final = txt + dtxt
return final

@classmethod
def load_generic_act_model(cls, generic_act_model_file):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from light.world.souls.models.generative_heuristic_model_soul import GenerativeHeuristicModelSoul

class GenerativeHeuristicModelWithStartFeatureSoul(GenerativeHeuristicModelSoul):
    """
    Generative heuristic soul whose dialogue context ends with a start feature.

    The context built by this soul is terminated by a ``START <partner>`` or
    ``CONTINUE <partner>`` line, which tells the model whether the
    conversation is just beginning or is already under way.
    """

    def add_startswith_tokens(self, context, dialogue_txt):
        """
        Join the context and dialogue text and append the start feature line.

        :param context: text returned by ``build_context``.
        :param dialogue_txt: dialogue/interaction history text; when it is
            shorter than 3 characters the conversation is treated as not
            having started yet.
        :return: ``context + dialogue_txt`` followed by a newline and either
            ``"START <partner>"`` or ``"CONTINUE <partner>"``. The partner
            name is empty when there is no last interaction partner or its
            node cannot be found in the graph.
        """
        # Extract the partner name, if the agent has a known partner.
        partner_name = ""
        partner_id = self.target_node._last_interaction_partner_id
        if partner_id is not None:
            partner = self.world.oo_graph.get_node(partner_id)
            if partner is not None:
                partner_name = partner.get_prefix_view()

        if len(dialogue_txt) < 3:
            feature = "START " + partner_name
        else:
            feature = "CONTINUE " + partner_name
        return context + dialogue_txt + "\n" + feature

    def build_dialog_context(self, quest_txt=None):
        """
        Build the model input: context, interaction history, start feature.

        :param quest_txt: optional text forwarded to ``build_context``.
        :return: the full context string, ending with the feature line added
            by ``add_startswith_tokens``.
        """
        # Initial context.
        # NOTE(review): quest_txt is passed positionally; if build_context's
        # first parameter is partner_name, this binds quest_txt to it.
        # Confirm which parameter is intended.
        txt = self.build_context(quest_txt)

        # Dialogue/interaction context.
        dtxt = ""
        agent = self.target_node
        turn_id = None
        for d in agent._last_interaction_history:
            current_turn_id = d[0][0]
            if turn_id is None or turn_id == current_turn_id:
                # Same turn as the previous entry: keep on the same line.
                dtxt += " " + d[1]
            else:
                # New turn: start a new line.
                dtxt = dtxt.lstrip(" ")
                dtxt += "\n" + d[1]
            turn_id = current_turn_id
            is_safe = d[0][2]
            if not is_safe:
                # Reset conversation when unsafe utterances are in the history.
                # NOTE(review): this replaces the dialogue with another copy
                # of the context rather than an empty string, so the context
                # appears twice and the feature becomes CONTINUE. Confirm
                # this is intended.
                dtxt = self.build_context(quest_txt)
        dtxt = dtxt.lstrip(" ")

        # Add starting context features, can help the model.
        return self.add_startswith_tokens(txt, dtxt)

7 changes: 7 additions & 0 deletions projects/dialog_and_act_trainer/SETUP.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ py convert.py -t fromfile -ffdp /checkpoint/light/projects/dialog_and_act_train

etc. (for all the files in /checkpoint/light/projects/dialog_and_act_trainer/raw/ .. )

#OR
#
# we make the final data with convo starter tokens (to try to prevent starting a convo as if in the middle):

py convert_withstarter.py -t fromfile -ffdp /checkpoint/light/projects/dialog_and_act_trainer/raw/light_train.txt -of /checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_train.txt


AND

# for the type picker (dialog, act, emote) task:
Expand Down
1 change: 1 addition & 0 deletions projects/dialog_and_act_trainer/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def clean(msg):
txt = msg["text"]
res = []
app = ""

for t in txt.split("\n"):
if t.startswith("_") and "_object_desc" not in t:
if (
Expand Down
161 changes: 161 additions & 0 deletions projects/dialog_and_act_trainer/convert_withstarter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
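Convert a dataset into the ParlAI text format, adding a conversation starter token.

The text of every converted example ends with an extra line: ``START <partner>``
when the context contains no ``_self_say``/``_partner_say`` line yet, otherwise
``CONTINUE <partner>``.

## Examples

```shell
python convert_withstarter.py -t fromfile -ffdp /path/to/light_train.txt -of /path/to/light_withstarter_train.txt
```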
"""

from parlai.core.params import ParlaiParser
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task
from parlai.utils.misc import msg_to_str, TimeLogger
import parlai.utils.logging as logging
from parlai.core.script import ParlaiScript, register_script
import random
import tempfile
import copy


def clean(msg):
    """
    Return a cleaned copy of ``msg`` whose text ends with a starter token line.

    The input message is deep-copied and left untouched. Its ``text`` field is
    processed line by line:

    * lines that do not start with ``_``, and lines containing
      ``_object_desc``, are dropped;
    * ``_self_act`` / ``_self_emote`` / ``_partner_act`` / ``_partner_emote``
      lines are not emitted on their own; their content is wrapped in ``*...*``
      and appended to the next emitted line (a trailing one with no following
      line is dropped);
    * every other ``_`` line is emitted, with the ``_self_say `` and
      ``_partner_say `` prefixes removed.

    A final line ``START <partner>`` is appended when the text contains no say
    line, otherwise ``CONTINUE <partner>``. ``<partner>`` comes from the
    ``_partner_name `` line and is empty when the text has no such line.

    :param msg: mapping with a ``text`` entry and a ``force_set`` method.
    :return: the modified deep copy of ``msg``.
    """
    msg = copy.deepcopy(msg)
    action_starts = ("_self_act", "_self_emote", "_partner_act", "_partner_emote")
    action_prefixes = ("_self_act ", "_partner_act ", "_self_emote ", "_partner_emote ")

    res = []
    app = ""  # pending " *action*" suffixes for the next emitted line
    convo_has_started = False
    # Default to an empty name so a context without a _partner_name line
    # does not fail when the starter token is built.
    partner_name = ""

    for t in msg["text"].split("\n"):
        if "_self_say " in t or "_partner_say " in t:
            convo_has_started = True
        if "_partner_name " in t:
            partner_name = t.replace("_partner_name ", "")

        if not t.startswith("_") or "_object_desc" in t:
            continue
        if t.startswith(action_starts):
            for prefix in action_prefixes:
                t = t.replace(prefix, "")
            app = app + " *" + t + "*"
        else:
            t = t.replace("_self_say ", "").replace("_partner_say ", "")
            res.append(t + app)
            app = ""

    starter = "CONTINUE " if convo_has_started else "START "
    res.append(starter + partner_name)

    msg.force_set("text", "\n".join(res))
    return msg


def dump_data(opt):
    """
    Run the task in ``opt`` and write each cleaned example to a text file.

    Every example is passed through ``clean`` and serialized with
    ``msg_to_str``; a blank line is written after an example whose
    ``episode_done`` flag is set.

    :param opt: options object; reads ``outfile``, ``num_examples``,
        ``ignore_fields``, ``log_every_n_secs`` and, when no outfile is
        given, ``task`` and ``datatype``.
    """
    import os  # local import: only needed to close the mkstemp descriptor

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    opt.log()
    ignorefields = opt.get("ignore_fields", "")
    if opt["outfile"] is None:
        fd, outfile = tempfile.mkstemp(
            prefix="{}_{}_".format(opt["task"], opt["datatype"]), suffix=".txt"
        )
        # mkstemp returns an open OS-level handle; close it since the file
        # is reopened by path below.
        os.close(fd)
    else:
        outfile = opt["outfile"]

    if opt["num_examples"] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt["num_examples"]
    log_timer = TimeLogger()

    logging.debug("starting to convert...")
    logging.info(f"saving output to {outfile}")
    # The context manager closes the file even if conversion raises.
    with open(outfile, "w") as fw:
        for _ in range(num_examples):
            world.parley()
            acts = world.get_acts()
            # The pop is evaluated eagerly, so "eval_labels" is always removed
            # and its value is used only when "labels" is absent.
            value = acts[0].get("labels", acts[0].pop("eval_labels", None))
            acts[0].force_set("labels", value)

            msg = clean(acts[0])

            txt = msg_to_str(msg, ignore_fields=ignorefields)
            fw.write(txt + "\n")
            if acts[0].get("episode_done", False):
                fw.write("\n")

            if log_timer.time() > opt["log_every_n_secs"]:
                text, _log = log_timer.log(world.total_parleys, world.num_examples())
                logging.info(text)

            if world.epoch_done():
                logging.info("epoch done")
                break


def setup_args():
    """
    Build the command line parser for the conversion script.

    :return: a ``ParlaiParser`` with the example-count, output-file,
        ignored-fields and logging-interval options registered, and
        ``datatype`` defaulting to ``"train:stream"``.
    """
    # Get command line arguments
    parser = ParlaiParser(description="Dump a task to a standardized format")

    # (flags, keyword arguments) for each option, in registration order.
    option_table = (
        (
            ("-n", "--num-examples"),
            {
                "default": -1,
                "type": int,
                "help": "Total number of exs to convert, -1 to convert all examples",
            },
        ),
        (
            ("-of", "--outfile"),
            {
                "default": None,
                "type": str,
                "help": "Output file where to save, by default will be created in tmp",
            },
        ),
        (
            ("-if", "--ignore-fields"),
            {
                "default": "id",
                "type": str,
                "help": "Ignore these fields from the message (returned with .act() )",
            },
        ),
        (
            ("-ltim", "--log-every-n-secs"),
            {"type": float, "default": 2},
        ),
    )
    for flags, settings in option_table:
        parser.add_argument(*flags, **settings)

    parser.set_defaults(datatype="train:stream")
    return parser


@register_script("convert_to_parlai", hidden=True)
class ConvertDataToParlaiFormat(ParlaiScript):
    """
    Script wrapper that exposes the conversion as a registered ParlAI script.
    """

    @classmethod
    def setup_args(cls):
        """
        Return the parser built by the module-level ``setup_args`` function.
        """
        parser = setup_args()
        return parser

    def run(self):
        """
        Convert the data described by this script's options.
        """
        options = self.opt
        return dump_data(options)


# Script entry point: seed the RNG with a fixed value, then hand control to
# the script class, which parses the command line and runs the conversion.
if __name__ == "__main__":
    random.seed(42)
    ConvertDataToParlaiFormat.main()
105 changes: 105 additions & 0 deletions projects/dialog_and_act_trainer/train/gen_withstarter_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.


from parlai_internal.projects.param_sweep_utils.param_sweep import run_grid

# Name of this sweep; also used as the last component of the save root.
SWEEP_NAME = "gen_withstarter"
name_keys = {}

# Sweep grid handed to run_grid: each key is a command line flag and each
# value is the list of settings to try for it. Every list here has a single
# entry.
grid = {
    # Training data: the two "withstarter" train files.
    "-t": [
        '"fromfile:fromfile_datapath=/checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_train.txt,fromfile:fromfile_datapath=/checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_wild_train.txt"'
    ],
    # Evaluation data: the "withstarter" wild validation file.
    "-et": [
        '"fromfile:fromfile_datapath=/checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_wild_valid.txt"',
    ],
    # 36 * 3600 (presumably seconds, i.e. 36 hours; confirm the unit).
    "--max-train-time": [(60 * 60) * 36],
    "--model": ["parlai_internal.projects.meena.unlikely.boringul:UnlikelihoodAgent"],
    "--eval-skip-generation": [False],
    "--validation-max-exs": [200],
    "-vp": [1000],
    "-vmt": ["boring_fails"],
    "-vmm": ["min"],
    "--save-after-valid": [True],
    # 30 * 3600 and 1 * 3600 respectively.
    "-vtim": [(60 * 60) * 30],
    "-stim": [(60) * 60 * 1],
    # Decoding settings.
    "--inference": ["beam"],
    "--beam_min_length": [20],
    "--beam_context_block_ngram": [-1],
    "--beam_block_ngram": [-1],
    "--beam_size": [1],
    # Unlikelihood-agent settings.
    "--ul-type": ["std0"],
    "--seq-ul-nc": [4],
    "--seq-ul-nl": [4],
    "--seq-ul-ratio": [0.25],
    "--train-boring-repeats": [True],
    "--train-context-repeats": [False],
    "--train-label-repeats": [False],
    # Model architecture.
    "--attention-dropout": [0.00],
    # NOTE(review): given as a string while the other numeric values are
    # numbers; confirm run_grid treats both the same.
    "--batchsize": ["64"],
    "--embedding-size": [2560],
    "--ffn-size": [10240],
    "--variant": ["prelayernorm"],
    "--n-heads": [32],
    "--n-positions": [128],
    "--n-encoder-layers": [2],
    "--n-decoder-layers": [24],
    "--history-add-global-end-token": [
        "end",  # hack to get newline delimiter
    ],
    "--dict-tokenizer": ["bytelevelbpe"],
    "--dict-file": [
        "/checkpoint/parlai/zoo/meena/20200319_meenav0data_tall_2.7B_adamoptimizer/20200319_13.3ppl_200kupdates/model.dict"
    ],
    "--dropout": [0.1],
    "--fp16": [True],
    # Checkpoint the run is initialized from.
    "--init-model": [
        # emily's good good
        "/checkpoint/parlai/zoo/q_function/generative2.7B_bst_0331/model",
    ],
    "--label-truncate": [128],
    # Optimization settings.
    "--lr-scheduler": ["reduceonplateau"],
    "--lr-scheduler-patience": [3],
    "--optimizer": ["adam"],
    "--relu-dropout": [0.0],
    "--activation": ["gelu"],
    "--model-parallel": ["true"],
    "--text-truncate": [128],
    "--truncate": [128],
    "--warmup_updates": [200],
    "--fp16-impl": ["mem_efficient"],
    "--update-freq": [4],
    "--gradient-clip": [0.1],
    "--skip-generation": [True],
    "--log_every_n_secs": [10],
    "-lr": [7e-6],
}

# Launch the sweep when run as a script. Models are saved under a directory
# named after SWEEP_NAME.
if __name__ == "__main__":
    run_grid(
        grid,
        name_keys,
        SWEEP_NAME,
        partition="learnfair",
        # partition='dev',
        jobtime="24:00:00",
        # NOTE(review): path inside one user's home directory; other users
        # will presumably need to change it.
        PARLAI_PATH="/private/home/jase/src/ParlAI/",
        gpus=8,
        nodes=1,
        volta=True,
        volta32=True,
        saveroot="/checkpoint/light/projects/dialog_and_act_trainer/models/"
        + SWEEP_NAME,
        include_job_id=True,
        create_model_file=True,
        hashname=False,
        # fixedname='model',
        requeue=True,
        mem_gb=400,
        data_parallel=True,
        copy_env=False,
    )
Loading