Merge pull request #34 from qshi95/lemon
Lemon
Showing 322 changed files with 150,107 additions and 0 deletions.
@@ -0,0 +1,32 @@
# LEMON

This repository contains the code and pre-trained models for our EMNLP 2022 Findings paper [LEMON: Language-Based Environment Manipulation via Execution-guided Pre-training](https://arxiv.org/pdf/2201.08081.pdf).

Data
-------
The data is available in the [release](https://github.com/qshi95/LEMON/releases/tag/data). Please unzip it and put it in the `lemon_data` folder.
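
For example, a minimal setup sketch (the archive file name below is an assumption; check the release page for the actual asset name):

```bash
# Download the data archive from the LEMON release page (asset name assumed)
wget https://github.com/qshi95/LEMON/releases/download/data/lemon_data.zip
# Unzip so the contents end up in the lemon_data folder
# (adjust -d if the archive already contains a top-level lemon_data/ directory)
unzip lemon_data.zip -d lemon_data
```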

Pre-training
-------
Run the following command to preprocess the data:
```bash
bash preprocess_pretrain.bat
```

Then run the following command to pre-train the model:
```bash
bash pretrain.sh
```

Fine-tuning
-------
Run the following command to preprocess the data:
```bash
bash preprocess_finetune.bat
```

Then run the following command to fine-tune the model:
```bash
bash finetune.sh
```
@@ -0,0 +1,220 @@
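# Pre-training corpus generation for the alchemy domain: sample a random
# beaker state, sample a logical form that executes on that state with the
# strongsup RLong executor, and write tab-separated
# (logical form, initial state, final state) triples to --corpus_file.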
import sys
sys.path.append('../executor/')
from strongsup.rlong.executor import RLongExecutor
from strongsup.rlong.predicate import RLongPredicate
from strongsup.rlong.state import RLongAlchemyState
from itertools import permutations
from random import choices, choice, sample
import math
import argparse
import multiprocessing
from multiprocessing import Pool


parser = argparse.ArgumentParser()
parser.add_argument("--max_number", type=int, default=100000, help="max number of examples per dataset.")
parser.add_argument("--corpus_file", type=str, default='../corpus/pretraining_corpus_alchemy.txt', help="corpus file")
parser.add_argument("--dataset_prefix", type=str, default='alchemy', help="dataset name")
args = parser.parse_args()

fw = open(args.corpus_file, 'w')


def random_sampling(candidate_list, n, weights=None):

    result_list = []
    for _ in range(n):
        result = choices(candidate_list, k=1, weights=weights)[0]
        result_list.append(result)
    return result_list

def prepare_lf(lf):
    if isinstance(lf, str):
        lf = lf.split()
    if not all(isinstance(x, RLongPredicate) for x in lf):
        lf = [x if isinstance(x, RLongPredicate) else RLongPredicate(x)
              for x in lf]
    return lf


def postpreprocess_alchemy(states):
    return ' | '.join(states.strip().split())


def uni_executor(state, lf, dataset_prefix):
    if dataset_prefix == 'alchemy':
        state = RLongAlchemyState.from_raw_string(state)

    lf = prepare_lf(lf)
    executor = RLongExecutor(state, debug=False)

    # Direct execution
    denotation_direct = executor.execute(lf)
    # Token-by-token execution
    denotation = None
    for x in lf:
        denotation = executor.execute_predicate(x, denotation)
    assert denotation_direct == denotation

    # return denotation.world_state
    return denotation

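# The alchemy world state is a space-separated list of seven beakers, each
# written as "<position>:<contents>", where the contents are 0-4 units of a
# single color (b, g, o, p, r, y) and an empty beaker is written "_",
# e.g. "1:gg 2:_ 3:oooo 4:b 5:_ 6:rrr 7:y".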
def alchemy_state_generator():

    colors = ['b', 'g', 'o', 'p', 'r', 'y']
    num_positions = 7
    objects = []
    for i in range(num_positions):
        amt = choice([0, 1, 2, 3, 4])
        color = choice(colors)
        beaker = None
        for _ in range(amt):
            if beaker is None:
                beaker = []
            beaker.append(color)
        if beaker is None:
            beaker = '_'
        objects.append(''.join(beaker))

    states = ['{}:{}'.format(str(i+1), item) for i, item in enumerate(objects)]
    states = ' '.join(states)
    return states

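# Extend the current logical form by one action (ADrain, APour or AMix):
# enumerate candidate arguments, keep the first combination the executor
# accepts, and append it together with the action token (or with "-1 H0" in
# place of the action token when the new action repeats the previous one).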
def single_alchemy_lf_generator(states, executor, lf):

    colors = ['b', 'g', 'o', 'p', 'r', 'y']
    # Note: this may be problematic; if the object is already unique, also
    # specifying an index may be redundant. Leave it this way for now.
    object_list = ['{} PColor {} index'.format(color, ind) for color in colors for ind in range(1, 8)]
    object_list += ['{} PColor'.format(color) for color in colors]
    object_list += ['all-objects {} index'.format(ind) for ind in range(1, 8)]
    object_list += ['{} H1'.format(item) for item in [1, 2, 3, 4, -1]]
    object_list += ['{} H2'.format(item) for item in [1, 2, 3, 4, -1]]

    func = random_sampling(['APour', 'ADrain', 'AMix'], 1)[0]

    if func == 'ADrain':
        valid_objects = []
        for item in list(set(object_list)):  # arbitrary order; only one candidate is chosen
            try:
                result = executor(states, lf + ' ' + item, 'alchemy')
                if len(result.execution_stack[0]) == 1:
                    if str(result.execution_stack[0]).split(':')[1] != '_':
                        valid_objects.append(item)
                        break
            except Exception:
                pass

        assert len(valid_objects) <= 1
        object = random_sampling(valid_objects, 1)[0]
        stack = executor(states, lf + ' ' + object, 'alchemy').execution_stack[0]
        assert len(stack) == 1
        cur_len = len(str(stack[0]).split(':')[1])
        number_list = []
        if cur_len == 4:
            number_list.extend(['X1/2', 'X1/4', '4'])
        elif cur_len == 3:
            number_list.extend(['X1/3', 'X2/3', '3'])
        elif cur_len == 2:
            number_list.extend(['X1/2', '2'])
        elif cur_len == 1:
            number_list.extend(['1'])
        number = random_sampling(number_list, 1)[0]
        if lf and func == lf.split()[-1]:
            lf += ' ' + object + ' ' + number + ' ' + '-1 H0'
        else:
            lf += ' ' + object + ' ' + number + ' ' + func
        assert executor(states, lf, 'alchemy')

    elif func == 'APour':
        valid_objects = []
        for item1 in list(set(object_list)):
            for item2 in list(set(object_list)):
                try:
                    result = executor(states, lf + ' ' + item1 + ' ' + item2 + ' ' + func, 'alchemy')
                    valid_objects.append((item1, item2))
                    break
                except Exception:
                    pass
            if len(valid_objects) > 0:
                break

        assert len(valid_objects) <= 1
        object = random_sampling(valid_objects, 1)[0]
        if lf and func == lf.split()[-1]:
            lf += ' ' + object[0] + ' ' + object[1] + ' ' + '-1 H0'
        else:
            lf += ' ' + object[0] + ' ' + object[1] + ' ' + func
        assert executor(states, lf, 'alchemy')

    elif func == 'AMix':
        valid_objects = []
        for item in list(set(object_list)):
            try:
                result = executor(states, lf + ' ' + item + ' ' + func, 'alchemy')
                valid_objects.append(item)
                break
            except Exception:
                pass

        assert len(valid_objects) <= 1
        object = random_sampling(valid_objects, 1)[0]
        lf += ' ' + object + ' ' + func
        assert executor(states, lf, 'alchemy')

    return lf


def lf_generator(states, executor, max_step, dataset_prefix):

    if dataset_prefix == 'alchemy':
        func = single_alchemy_lf_generator

    count = 0
    lf = ''
    for _ in range(10):
        try:
            lf = func(states, executor, lf)
        except Exception:
            continue

        count += 1
        if count >= max_step:
            break

    return lf

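# Worker entry point: keep sampling states and logical forms, and write a
# tab-separated (logical form, initial state, final state) triple to the
# corpus file whenever the execution history has exactly max_step commands,
# until total_number examples have been produced.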
def corpus_generation(inputs):

    executor, max_step, total_number, dataset_prefix = inputs

    if dataset_prefix == 'alchemy':
        state_generator = alchemy_state_generator
        state_preprocesser = postpreprocess_alchemy

    count = 0
    while True:
        states = state_generator()
        lf = lf_generator(states, executor, max_step, dataset_prefix)
        if lf.strip():
            result = executor(states, lf.strip(), dataset_prefix)
            if len(result.command_history) == max_step:
                initial_state = state_preprocesser(states)
                final_state = state_preprocesser(str(result.world_state))
                item_row = '\t'.join([lf.strip(), initial_state, final_state])
                fw.write(item_row)
                fw.write('\n')
                count += 1
                if count % 10000 == 0:
                    print('Finish generating {} cases'.format(count))
                if count >= total_number:
                    break

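# Distribute generation over all CPU cores: for every step count from 1 to 5,
# each worker produces args.max_number // 5 // cores examples.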
if __name__ == '__main__':

    cores = multiprocessing.cpu_count()
    print("Using {} cores".format(cores))
    pool = Pool(cores)

    for i in range(1, 6):
        res = pool.map(corpus_generation,
                       zip([uni_executor]*cores, [i]*cores,
                           [int(args.max_number // 5 // cores)]*cores,
                           [args.dataset_prefix]*cores))

    pool.close()
    pool.join()
LEMON/corpus_generation/corpus_generation_split_newformat.py
63 changes: 63 additions & 0 deletions
@@ -0,0 +1,63 @@
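# Split the generated pre-training corpus into seq2seq files: remove duplicate
# lines, shuffle, hold out 20,000 examples for dev, and write source files
# ("<initial state> SEP <action sequence>") and target files (final state)
# for both train and dev.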
import json
import sys
import copy
from itertools import combinations, permutations
import math
import argparse
from random import shuffle
from remove_same import big_file_remove_same
import os

parser = argparse.ArgumentParser()
parser.add_argument("--dataset_prefix", type=str, default='alchemy', help="dataset prefix")
parser.add_argument("--root_path", type=str, default='../corpus/', help="root path of the corpus files")

args = parser.parse_args()

args.corpus_file = os.path.join(args.root_path, '{}/pretraining_corpus_{}.txt'.format(args.dataset_prefix, args.dataset_prefix))
args.remove_same_file = os.path.join(args.root_path, '{}/temp.txt'.format(args.dataset_prefix))
args.train_source_file = os.path.join(args.root_path, '{}/train.src'.format(args.dataset_prefix))
args.train_target_file = os.path.join(args.root_path, '{}/train.tgt'.format(args.dataset_prefix))
args.dev_source_file = os.path.join(args.root_path, '{}/dev.src'.format(args.dataset_prefix))
args.dev_target_file = os.path.join(args.root_path, '{}/dev.tgt'.format(args.dataset_prefix))

||
big_file_remove_same(args.corpus_file, args.remove_same_file) | ||
|
||
with open(args.remove_same_file, 'r') as f: | ||
total_data_list = f.readlines() | ||
|
||
print(len(total_data_list)) | ||
shuffle(total_data_list) | ||
|
||
train_data_list = total_data_list[:-20000] | ||
dev_data_list = total_data_list[-20000:] | ||
|
||
fw_train_src = open(args.train_source_file, 'w') | ||
fw_train_tgt = open(args.train_target_file, 'w') | ||
fw_dev_src = open(args.dev_source_file, 'w') | ||
fw_dev_tgt = open(args.dev_target_file, 'w') | ||
|
||
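# Each corpus line is "<action sequence>\t<initial state>\t<final state>";
# lines that do not split into exactly three fields are skipped.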
for item in train_data_list:
    try:
        action, prev_state, current_state = item.split('\t')
    except Exception:
        continue
    src_row = ' SEP '.join([prev_state.strip(), action.strip()])
    tgt_row = current_state.strip()
    fw_train_src.write(src_row)
    fw_train_src.write('\n')
    fw_train_tgt.write(tgt_row)
    fw_train_tgt.write('\n')

for item in dev_data_list:
    try:
        action, prev_state, current_state = item.split('\t')
    except Exception:
        continue
    src_row = ' SEP '.join([prev_state.strip(), action.strip()])
    tgt_row = current_state.strip()
    fw_dev_src.write(src_row)
    fw_dev_src.write('\n')
    fw_dev_tgt.write(tgt_row)
    fw_dev_tgt.write('\n')