Skip to content
This repository has been archived by the owner on Oct 11, 2023. It is now read-only.

Commit

Permalink
Merge pull request #225 from facebookresearch/bugz7
Browse files Browse the repository at this point in the history
Generative model with start features
  • Loading branch information
jaseweston authored May 7, 2021
2 parents 1bce452 + 0202de5 commit 70f4aca
Show file tree
Hide file tree
Showing 7 changed files with 349 additions and 3 deletions.
5 changes: 3 additions & 2 deletions light/world/souls/base_soul.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ def build_context(self, partner_name=None, quest_txt=None):
txt += quest_txt
txt += "\n"
return txt

def build_dialog_context(self, quest_txt=None):
# Initial context.
txt = self.build_context(quest_txt)
Expand All @@ -212,7 +212,8 @@ def build_dialog_context(self, quest_txt=None):
# reset conversation when unsafe utterances are in the history
txt = ""
dtxt = dtxt.lstrip(" ")
return txt + dtxt
final = txt + dtxt
return final

@classmethod
def load_generic_act_model(cls, generic_act_model_file):
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from light.world.souls.models.generative_heuristic_model_soul import GenerativeHeuristicModelSoul

class GenerativeHeuristicModelWithStartFeatureSoul(GenerativeHeuristicModelSoul):
    """
    Variant of GenerativeHeuristicModelSoul whose dialogue context ends with a
    conversation-state feature line ("START <partner>" or "CONTINUE <partner>").
    """

    def add_startswith_tokens(self, context, dialogue_txt):
        """
        Join the context and dialogue text and append a START/CONTINUE line.

        :param context: text placed first in the result.
        :param dialogue_txt: dialogue history text; when it is shorter than
            3 characters the conversation is treated as not yet started.
        :return: ``context + dialogue_txt + "\\n" + feature`` where ``feature``
            is ``"START "`` or ``"CONTINUE "`` followed by the partner's prefix
            view (an empty string when there is no last interaction partner
            or the partner node cannot be found in the graph).
        """
        # Look up the last interaction partner's display name, if any.
        partner_name = ""
        partner_id = self.target_node._last_interaction_partner_id
        if partner_id is not None:
            partner = self.world.oo_graph.get_node(partner_id)
            if partner is not None:
                partner_name = partner.get_prefix_view()

        if len(dialogue_txt) < 3:
            feature = "START " + partner_name
        else:
            feature = "CONTINUE " + partner_name
        return context + dialogue_txt + "\n" + feature

    def build_dialog_context(self, quest_txt=None):
        """
        Build the model input: base context, interaction history, feature line.

        Each history entry ``d`` is read as ``d[0][0]`` (turn id), ``d[0][2]``
        (safety flag) and ``d[1]`` (text). Entries sharing a turn id are joined
        with a space; a change of turn id starts a new line.

        :param quest_txt: optional text forwarded to ``build_context``.
        :return: the string produced by ``add_startswith_tokens``.
        """
        # Initial context.
        # NOTE(review): quest_txt is passed positionally here; confirm it binds
        # to the intended parameter of build_context.
        txt = self.build_context(quest_txt)
        # Dialogue/interaction context.
        dtxt = ""
        agent = self.target_node
        turn_id = None
        for d in agent._last_interaction_history:
            current_turn_id = d[0][0]
            if turn_id is None or turn_id == current_turn_id:
                dtxt += " " + d[1]
            else:
                dtxt = dtxt.lstrip(" ")
                dtxt += "\n" + d[1]
            turn_id = current_turn_id
            is_safe = d[0][2]
            if not is_safe:
                # reset conversation when unsafe utterances are in the history
                # NOTE(review): this replaces the dialogue text with a second
                # copy of the context (txt already holds one), which also
                # forces a CONTINUE feature; confirm this is intended rather
                # than resetting to an empty string.
                dtxt = self.build_context(quest_txt)
        dtxt = dtxt.lstrip(" ")

        # Add starting context features, can help the model.
        return self.add_startswith_tokens(txt, dtxt)

7 changes: 7 additions & 0 deletions projects/dialog_and_act_trainer/SETUP.txt
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ py convert.py -t fromfile -ffdp /checkpoint/light/projects/dialog_and_act_train

etc. (for all the files in /checkpoint/light/projects/dialog_and_act_trainer/raw/ .. )

#OR
#
# we make the final data we want with convo starter tokens (to try to prevent starting a convo as if in the middle):

py convert_withstarter.py -t fromfile -ffdp /checkpoint/light/projects/dialog_and_act_trainer/raw/light_train.txt -of /checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_train.txt


AND

# for the type picker (dialog, act, emote) task:
Expand Down
1 change: 1 addition & 0 deletions projects/dialog_and_act_trainer/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ def clean(msg):
txt = msg["text"]
res = []
app = ""

for t in txt.split("\n"):
if t.startswith("_") and "_object_desc" not in t:
if (
Expand Down
161 changes: 161 additions & 0 deletions projects/dialog_and_act_trainer/convert_withstarter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
#!/usr/bin/env python3

# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Convert a dataset into the ParlAI text format.
## Examples
```shell
parlai convert_data_to_parlai_format -t babi:task1k:1 --outfile /tmp/dump
```
"""

from parlai.core.params import ParlaiParser
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task
from parlai.utils.misc import msg_to_str, TimeLogger
import parlai.utils.logging as logging
from parlai.core.script import ParlaiScript, register_script
import random
import tempfile
import copy


def clean(msg):
    """
    Return a copy of ``msg`` with flattened text ending in a starter feature.

    The text is processed line by line:

    * lines that do not start with ``_`` or that contain ``_object_desc``
      are dropped;
    * ``_self_act`` / ``_self_emote`` / ``_partner_act`` / ``_partner_emote``
      lines are stripped of their prefix, wrapped in ``*...*`` and attached to
      the end of the next kept line;
    * every other ``_``-prefixed line is kept, with ``_self_say `` and
      ``_partner_say `` prefixes removed.

    A final line ``START <partner>`` is appended when the text contains no say
    line, otherwise ``CONTINUE <partner>``. ``<partner>`` is taken from the
    ``_partner_name `` line and is empty when no such line exists.

    :param msg: message exposing ``msg["text"]`` and ``force_set``; it is
        deep-copied and the input is left untouched.
    :return: the modified copy.
    """
    msg = copy.deepcopy(msg)
    lines = msg["text"].split("\n")

    # Pre-scan: has any utterance been spoken yet, and who is the partner?
    convo_has_started = False
    # Default to an empty name so a message without a `_partner_name` line
    # still gets a feature line instead of hitting an unbound local.
    partner_name = ""
    for t in lines:
        if "_self_say " in t or "_partner_say " in t:
            convo_has_started = True
        if "_partner_name " in t:
            partner_name = t.replace("_partner_name ", "")

    action_prefixes = ("_self_act", "_self_emote", "_partner_act", "_partner_emote")
    res = []
    app = ""  # pending "*action*" text waiting for the next kept line
    for t in lines:
        if not t.startswith("_") or "_object_desc" in t:
            continue
        if t.startswith(action_prefixes):
            for prefix in (
                "_self_act ",
                "_partner_act ",
                "_self_emote ",
                "_partner_emote ",
            ):
                t = t.replace(prefix, "")
            app = app + " *" + t + "*"
        else:
            t = t.replace("_self_say ", "")
            t = t.replace("_partner_say ", "")
            res.append(t + app)
            app = ""

    feature = "CONTINUE " if convo_has_started else "START "
    res.append(feature + partner_name)

    msg.force_set("text", "\n".join(res))
    return msg


def dump_data(opt):
    """
    Run the task in ``opt`` through a repeat-label agent and write the cleaned
    examples to a text file.

    Each example's first act is given a ``labels`` field (falling back to
    ``eval_labels``), passed through ``clean`` and serialised with
    ``msg_to_str``. A blank line is written after each episode.

    :param opt: options; reads ``outfile``, ``task``, ``datatype``,
        ``num_examples``, ``log_every_n_secs`` and optionally
        ``ignore_fields``. When ``outfile`` is None a temporary ``.txt`` file
        is created and used instead.
    """
    import os

    # create repeat label agent and assign it to the specified task
    agent = RepeatLabelAgent(opt)
    world = create_task(opt, agent)
    opt.log()
    ignorefields = opt.get("ignore_fields", "")
    if opt["outfile"] is None:
        fd, outfile = tempfile.mkstemp(
            prefix="{}_{}_".format(opt["task"], opt["datatype"]), suffix=".txt"
        )
        # mkstemp returns an open OS-level descriptor; the file is reopened by
        # path below, so close it here rather than leaking it.
        os.close(fd)
    else:
        outfile = opt["outfile"]

    if opt["num_examples"] == -1:
        num_examples = world.num_examples()
    else:
        num_examples = opt["num_examples"]
    log_timer = TimeLogger()

    logging.debug("starting to convert...")
    logging.info(f"saving output to {outfile}")
    # The context manager closes the file even if conversion raises midway.
    with open(outfile, "w") as fw:
        for _ in range(num_examples):
            world.parley()
            acts = world.get_acts()
            # Normalise so the label is always stored under "labels".
            value = acts[0].get("labels", acts[0].pop("eval_labels", None))
            acts[0].force_set("labels", value)

            msg = clean(acts[0])

            txt = msg_to_str(msg, ignore_fields=ignorefields)
            fw.write(txt + "\n")
            if acts[0].get("episode_done", False):
                fw.write("\n")

            if log_timer.time() > opt["log_every_n_secs"]:
                text, _log = log_timer.log(world.total_parleys, world.num_examples())
                logging.info(text)

            if world.epoch_done():
                logging.info("epoch done")
                break


def setup_args():
    """
    Create the argument parser for the conversion script.

    Registers ``--num-examples``, ``--outfile``, ``--ignore-fields`` and
    ``--log-every-n-secs`` and sets the default datatype to ``train:stream``.

    :return: the configured ParlaiParser.
    """
    # Get command line arguments
    parser = ParlaiParser(description="Dump a task to a standardized format")

    # (flags, keyword settings) for each option, registered in this order.
    option_specs = (
        (
            ("-n", "--num-examples"),
            {
                "default": -1,
                "type": int,
                "help": "Total number of exs to convert, -1 to convert all examples",
            },
        ),
        (
            ("-of", "--outfile"),
            {
                "default": None,
                "type": str,
                "help": "Output file where to save, by default will be created in tmp",
            },
        ),
        (
            ("-if", "--ignore-fields"),
            {
                "default": "id",
                "type": str,
                "help": "Ignore these fields from the message (returned with .act() )",
            },
        ),
        (
            ("-ltim", "--log-every-n-secs"),
            {"type": float, "default": 2},
        ),
    )
    for flags, settings in option_specs:
        parser.add_argument(*flags, **settings)

    parser.set_defaults(datatype="train:stream")
    return parser


@register_script("convert_to_parlai", hidden=True)
class ConvertDataToParlaiFormat(ParlaiScript):
    """
    Script wrapper that exposes ``dump_data`` through the ParlaiScript interface.
    """

    @classmethod
    def setup_args(cls):
        """
        Return the parser built by the module-level ``setup_args``.
        """
        parser = setup_args()
        return parser

    def run(self):
        """
        Convert the task described by ``self.opt`` and return ``dump_data``'s result.
        """
        result = dump_data(self.opt)
        return result


if __name__ == "__main__":
    # Fix the stdlib random seed before running the conversion script.
    random.seed(42)
    ConvertDataToParlaiFormat.main()
105 changes: 105 additions & 0 deletions projects/dialog_and_act_trainer/train/gen_withstarter_sweep.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.


from parlai_internal.projects.param_sweep_utils.param_sweep import run_grid

# Name of this sweep; also appended to the save root when the sweep is launched.
SWEEP_NAME = "gen_withstarter"
# Passed to run_grid as its name_keys argument (empty here).
name_keys = {}

# Hyperparameter grid handed to run_grid: each key is a command-line flag and
# each value is the list of settings to try. Every list here has exactly one
# entry, so the grid describes a single configuration.
grid = {
    # Training data: the starter-token LIGHT train file plus the "wild" train file.
    "-t": [
        '"fromfile:fromfile_datapath=/checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_train.txt,fromfile:fromfile_datapath=/checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_wild_train.txt"'
    ],
    # Data passed via -et: the starter-token "wild" valid file.
    "-et": [
        '"fromfile:fromfile_datapath=/checkpoint/light/projects/dialog_and_act_trainer/light_withstarter_wild_valid.txt"',
    ],
    # 36 hours, expressed in seconds.
    "--max-train-time": [(60 * 60) * 36],
    "--model": ["parlai_internal.projects.meena.unlikely.boringul:UnlikelihoodAgent"],
    # NOTE(review): --eval-skip-generation is False while --skip-generation
    # (further down) is True; confirm the combination is intended.
    "--eval-skip-generation": [False],
    "--validation-max-exs": [200],
    "-vp": [1000],
    "-vmt": ["boring_fails"],
    "-vmm": ["min"],
    "--save-after-valid": [True],
    # 30 hours, expressed in seconds.
    "-vtim": [(60 * 60) * 30],
    # 1 hour, expressed in seconds.
    "-stim": [(60) * 60 * 1],
    "--inference": ["beam"],
    "--beam_min_length": [20],
    "--beam_context_block_ngram": [-1],
    "--beam_block_ngram": [-1],
    "--beam_size": [1],
    "--ul-type": ["std0"],
    "--seq-ul-nc": [4],
    "--seq-ul-nl": [4],
    "--seq-ul-ratio": [0.25],
    "--train-boring-repeats": [True],
    "--train-context-repeats": [False],
    "--train-label-repeats": [False],
    "--attention-dropout": [0.00],
    # Given as a string, unlike the other numeric settings in this grid.
    "--batchsize": ["64"],
    "--embedding-size": [2560],
    "--ffn-size": [10240],
    "--variant": ["prelayernorm"],
    "--n-heads": [32],
    "--n-positions": [128],
    "--n-encoder-layers": [2],
    "--n-decoder-layers": [24],
    "--history-add-global-end-token": [
        "end",  # hack to get newline delimiter
    ],
    "--dict-tokenizer": ["bytelevelbpe"],
    "--dict-file": [
        "/checkpoint/parlai/zoo/meena/20200319_meenav0data_tall_2.7B_adamoptimizer/20200319_13.3ppl_200kupdates/model.dict"
    ],
    "--dropout": [0.1],
    "--fp16": [True],
    "--init-model": [
        # emily's good good
        "/checkpoint/parlai/zoo/q_function/generative2.7B_bst_0331/model",
    ],
    "--label-truncate": [128],
    "--lr-scheduler": ["reduceonplateau"],
    "--lr-scheduler-patience": [3],
    "--optimizer": ["adam"],
    "--relu-dropout": [0.0],
    "--activation": ["gelu"],
    "--model-parallel": ["true"],
    "--text-truncate": [128],
    "--truncate": [128],
    "--warmup_updates": [200],
    "--fp16-impl": ["mem_efficient"],
    "--update-freq": [4],
    "--gradient-clip": [0.1],
    "--skip-generation": [True],
    "--log_every_n_secs": [10],
    "-lr": [7e-6],
}

# Launch the sweep only when this file is executed as a script.
if __name__ == "__main__":
    # Hand the grid to run_grid together with the cluster/job settings below.
    # NOTE(review): PARLAI_PATH and saveroot are hard-coded absolute paths
    # (one under a specific user's home directory); adjust before running
    # from another account or cluster.
    run_grid(
        grid,
        name_keys,
        SWEEP_NAME,
        partition="learnfair",
        # partition='dev',
        jobtime="24:00:00",
        PARLAI_PATH="/private/home/jase/src/ParlAI/",
        gpus=8,
        nodes=1,
        volta=True,
        volta32=True,
        # Models are saved under a directory named after the sweep.
        saveroot="/checkpoint/light/projects/dialog_and_act_trainer/models/"
        + SWEEP_NAME,
        include_job_id=True,
        create_model_file=True,
        hashname=False,
        # fixedname='model',
        requeue=True,
        mem_gb=400,
        data_parallel=True,
        copy_env=False,
    )
Loading

0 comments on commit 70f4aca

Please sign in to comment.