Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve weblinx towards single turn loading #32

Merged
merged 7 commits
Apr 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 19 additions & 3 deletions modeling/llama/processing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
from functools import partial
from typing import Callable

Expand Down Expand Up @@ -139,6 +140,8 @@ def build_prompt_records_for_llama_truncated(
max_candidates_tokens=65 * 10,
add_unused_len_to_cands=True,
allow_iterative_reduction=False,
use_tokenizer_template=False,
template_tokenizer=None,
parser=None,
):
"""
Expand Down Expand Up @@ -221,9 +224,22 @@ def build_prompt_records_for_llama_truncated(
# Add the unused length to the candidates
num_html_tokens = len(tokenizer.tokenize(html))
num_utter_tokens = len(tokenizer.tokenize(utterance_context))
num_prev_turns_tokens = len(
tokenizer.tokenize(" ".join(prev_turns_text_list))
)
if use_tokenizer_template:
if template_tokenizer is None:
raise ValueError(
"template_tokenizer must be provided when use_tokenizer_template is True."
)
prev_turns_merged_copy = deepcopy(prev_turns_merged)
if prev_turns_merged[0]['role'] == 'assistant':
# insert a dummy user turn
prev_turns_merged_copy.insert(0, {'role': 'user', 'content': ''})
num_prev_turns_tokens = len(template_tokenizer.apply_chat_template(
[{'role': 'system', 'content': ''}, *prev_turns_merged_copy], tokenize=True
))
else:
num_prev_turns_tokens = len(
tokenizer.tokenize(" ".join(prev_turns_text_list))
)
remain_html_tokens = max_html_tokens - num_html_tokens
remain_utter_tokens = max_utterance_tokens - num_utter_tokens
remain_prev_turns_tokens = max_prev_turns_tokens - num_prev_turns_tokens
Expand Down
2,055 changes: 2,055 additions & 0 deletions tests/demonstrations/candidates_unittest.jsonl

Large diffs are not rendered by default.

51 changes: 51 additions & 0 deletions tests/test_build_prompt_records.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
import sys
import unittest

import weblinx as wl
from weblinx.processing import load_candidate_elements

# import llama's processing code to test
from modeling.llama.processing import (
format_candidates,
format_utterances,
format_utterances_truncated,
get_speaker,
multi_attempt_format_prev_turns_truncated,
)


# if needed, run this function to get candidates:
def create_candidates_unittest_jsonl():
    """
    Regenerate tests/demonstrations/candidates_unittest.jsonl.

    Loads the full candidate-elements file, keeps only the entries belonging
    to a single demonstration, and writes them out one JSON object per line
    so the unit tests can run against a small fixture.

    NOTE(review): this is a manual helper, not part of the automated test
    run — it depends on local data under wl_data/ being present.
    """
    import json

    demo_name = "aaabtsd"  # change if needed
    candidate_path = "wl_data/candidates/test_geo.jsonl"  # change if needed
    save_path = "tests/demonstrations/candidates_unittest.jsonl"

    # group_keys=None yields a flat list of candidate dicts (module-level
    # import of load_candidate_elements is reused; the previous local
    # re-import and the unused pathlib.Path import were removed).
    candidate_elements = load_candidate_elements(candidate_path, group_keys=None)
    filt_elems = [e for e in candidate_elements if e["demo_name"] == demo_name]

    with open(save_path, "w") as f:
        for elem in filt_elems:
            f.write(json.dumps(elem) + "\n")


class TestBuildPromptRecords(unittest.TestCase):
    """Tests for the prompt-record building helpers in modeling.llama.processing."""

    def setUp(self):
        # Fixture demonstration checked into the repo under tests/demonstrations/;
        # see create_candidates_unittest_jsonl for how the jsonl file was produced.
        self.demo = wl.Demonstration("aaabtsd", base_dir="./tests/demonstrations")
        # load tests/demonstrations/candidates_unittest.jsonl
        self.candidates = load_candidate_elements(
            "tests/demonstrations/candidates_unittest.jsonl"
        )

    def test_format_candidates(self):
        """
        Tests the format_candidates function to ensure it returns the expected
        string representation of a list of candidates, including the candidate
        index and the candidate's intent.
        """
        # NOTE(review): this test body is empty, so it currently passes
        # vacuously — the assertions described in the docstring still need
        # to be written.


# Allow running this test module directly:
#   python tests/test_build_prompt_records.py
if __name__ == "__main__":
    unittest.main()
36 changes: 36 additions & 0 deletions tests/test_processing_prompt.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import unittest
from weblinx.processing.prompt import find_turns_with_instructor_chat
import weblinx as wl


class TestProcessingPrompt(unittest.TestCase):
    """Tests for helpers in weblinx.processing.prompt."""

    def setUp(self):
        # Demonstration fixture bundled with the test suite.
        self.demo = wl.Demonstration("aaabtsd", base_dir="./tests/demonstrations")

    def test_find_turns_with_instructor_chat(self):
        """
        Test the find_turns_with_instructor_chat function to ensure it correctly
        filters out turns that contain instructor chat. It checks that the output
        list contains only turns with instructor chat.
        """
        replay = wl.Replay.from_demonstration(self.demo)
        turn = replay[15]

        result = find_turns_with_instructor_chat(
            replay, turn, speaker="instructor", num_prev_turns=5
        )

        # In this demo, we checked that there are 3 turns with instructor chat
        # so it should return a list of 3 turns
        self.assertEqual(len(result), 3)

        # Cross-check against a hand-rolled filter, which must agree exactly
        # with the library function's output.
        start_index = max(0, turn.index - 5)
        result_filter = [
            t
            for t in replay
            if t.get("speaker") == "instructor" and t.index < start_index
        ]
        self.assertEqual(len(result_filter), 3)
        self.assertEqual(result, result_filter)
39 changes: 25 additions & 14 deletions weblinx/processing/dom.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def prune_tree(


def clean_and_prune_tree(
dom_tree, cands_turn, max_depth=1, max_children=5, max_sibling=2
dom_tree, cands_turn=None, candidate_uids=None, max_depth=1, max_children=5, max_sibling=2,
):
"""
This function will clean and prune the tree based on the candidates in the cands_turn. This
Expand All @@ -341,7 +341,12 @@ def clean_and_prune_tree(
The tree to clean and prune.

cands_turn : list
The list of candidates for the turn.
The list of candidates for the turn. If this is None, we are expected to pass in the
`candidate_uids`; otherwise an error will be raised.

candidate_uids : list, optional
The list of candidate uids to keep. If this is None, we are expected to pass in the
`cands_turn`; otherwise an error will be raised.

max_depth : int, optional
The maximum depth to prune the tree. Defaults to 1.
Expand All @@ -360,23 +365,29 @@ def clean_and_prune_tree(
Raises
------
ValueError
If cands_turn is None.
If cands_turn is None and candidate_uids is None. Alternatively, if both
cands_turn and candidate_uids are passed in, an error will be raised.
"""
if cands_turn is None:
if cands_turn is None and candidate_uids is None:
raise ValueError(
"cands_turn cannot be None. The dom_tree cannot be pruned this way."
"cands_turn or candidate_uids must be provided. The dom_tree cannot be pruned this way."
)

if cands_turn is not None:
candidate_uids = [cand["uid"] for cand in cands_turn]
dom_tree = prune_tree(
dom_tree,
set(candidate_uids),
max_depth=max_depth,
max_children=max_children,
max_sibling=max_sibling,
if cands_turn is not None and candidate_uids is not None:
raise ValueError(
"cands_turn and candidate_uids cannot both be provided. Please provide only one."
)
remove_uid_when_not_candidate(dom_tree, candidate_uids=candidate_uids)
if candidate_uids is None:
candidate_uids = [cand["uid"] for cand in cands_turn]

dom_tree = prune_tree(
dom_tree,
set(candidate_uids),
max_depth=max_depth,
max_children=max_children,
max_sibling=max_sibling,
)
remove_uid_when_not_candidate(dom_tree, candidate_uids=candidate_uids)

remove_html_comments(dom_tree)
sanitize_elem_attributes(dom_tree)
Expand Down
11 changes: 8 additions & 3 deletions weblinx/processing/prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,9 +353,13 @@ def find_turns_with_instructor_chat(
This output of this function should be used by format_utterances to display the utterances.
"""
start_index = max(0, turn.index - num_prev_turns)
instructor_chat_turns = replay.filter_turns(
lambda_func = (
lambda turn: turn.get("speaker") == speaker and turn.index < start_index
)
if isinstance(replay, list):
instructor_chat_turns = list(filter(lambda_func, replay))
else:
instructor_chat_turns = replay.filter_turns(lambda_func)
return instructor_chat_turns


Expand Down Expand Up @@ -638,7 +642,7 @@ def select_turns_and_candidates_for_prompts(

remove_turns_without_elements : bool, optional
Whether to remove turns that do not have elements. Defaults to True.

Returns
-------
list
Expand Down Expand Up @@ -672,7 +676,8 @@ def select_turns_and_candidates_for_prompts(
turns = filter_turns(
turns,
lambda turn: not (
turn.intent in ("click", "change", "textinput", "submit") and turn.element is None
turn.intent in ("click", "change", "textinput", "submit")
and turn.element is None
),
)

Expand Down
19 changes: 19 additions & 0 deletions weblinx/processing/truncation.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,25 @@ def convert_elem_dict_to_str(elem_dict: dict, remove_empty=False):

return element_str

def convert_elem_dict_to_str_dmr(elem_dict: dict) -> str:
    """
    Convert an element dictionary to a string.

    The well-known fields (tag, xpath, text, bbox, attributes, children) are
    emitted first, in that fixed order, as one ``[[key]] value`` entry per
    line; any remaining keys are appended afterwards in dict order. The input
    dictionary is not modified. Raises KeyError if one of the well-known
    fields is missing, matching the previous pop-based behavior.
    """
    remaining = deepcopy(elem_dict)
    primary_keys = ("tag", "xpath", "text", "bbox", "attributes", "children")
    parts = [f"[[{key}]] {remaining.pop(key)}" for key in primary_keys]
    # Any extra keys go at the end, each on its own line.
    parts.extend(f"[[{k}]] {v}" for k, v in remaining.items())
    return "\n".join(parts)

def truncate_cands_turn(
cands_turn: list,
Expand Down
Loading