Skip to content

Commit

Permalink
basic_agent multi-tab update
Browse files: browse the repository at this point in the history
  • Loading branch information
gasse committed Oct 17, 2024
1 parent 34a2182 commit 6375634
Showing 1 changed file with 31 additions and 4 deletions.
35 changes: 31 additions & 4 deletions demo_agent/basic_agent.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import base64
import dataclasses
import numpy as np
import io
import logging

import numpy as np
from PIL import Image

from browsergym.experiments import Agent, AbstractAgentArgs
from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.core.action.python import PythonActionSet
from browsergym.experiments import AbstractAgentArgs, Agent
from browsergym.utils.obs import flatten_axtree_to_str, flatten_dom_to_str, prune_html

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -40,6 +40,9 @@ def obs_preprocessor(self, obs: dict) -> dict:
"goal_object": obs["goal_object"],
"last_action": obs["last_action"],
"last_action_error": obs["last_action_error"],
"open_pages_urls": obs["open_pages_urls"],
"open_pages_titles": obs["open_pages_titles"],
"active_page_index": obs["active_page_index"],
"axtree_txt": flatten_axtree_to_str(obs["axtree_object"]),
"pruned_html": prune_html(flatten_dom_to_str(obs["dom_object"])),
}
Expand Down Expand Up @@ -68,7 +71,7 @@ def __init__(
self.openai_client = OpenAI()

self.action_set = HighLevelActionSet(
subsets=["chat", "bid", "infeas"], # define a subset of the action space
subsets=["chat", "tab", "nav", "bid", "infeas"], # define a subset of the action space
# subsets=["chat", "bid", "coord", "infeas"] # allow the agent to also use x,y coordinates
strict=False, # less strict on the parsing of the actions
multiaction=False, # does not enable the agent to take multiple actions at once
Expand Down Expand Up @@ -151,6 +154,29 @@ def get_action(self, obs: dict) -> tuple[str, dict]:
# goal_object is directly presented as a list of openai-style messages
user_msgs.extend(obs["goal_object"])

# append title and url of all open tabs (the active tab is flagged)
user_msgs.append(
{
"type": "text",
"text": f"""\
# Currently open tabs
""",
}
)
for page_index, (page_url, page_title) in enumerate(
zip(obs["open_pages_urls"], obs["open_pages_titles"])
):
user_msgs.append(
{
"type": "text",
"text": f"""\
Tab {page_index}{" (active tab)" if page_index == obs["active_page_index"] else ""}
Title: {page_title}
URL: {page_url}
""",
}
)

# append page AXTree (if asked)
if self.use_axtree:
user_msgs.append(
Expand Down Expand Up @@ -234,6 +260,7 @@ def get_action(self, obs: dict) -> tuple[str, dict]:
{
"type": "text",
"text": f"""\
{action}
""",
}
Expand Down Expand Up @@ -261,7 +288,7 @@ def get_action(self, obs: dict) -> tuple[str, dict]:
"text": f"""\
# Next action
You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, the current state of the page before deciding on your next action.
You will now think step by step and produce your next best action. Reflect on your past actions, any resulting error message, and the current state of the page before deciding on your next action.
""",
}
)
Expand Down

0 comments on commit 6375634

Please sign in to comment.