Benchmarks (#173)
gasse authored Oct 17, 2024
1 parent 89eccd5 commit 3a33a69
Showing 13 changed files with 3,195 additions and 8 deletions.
1 change: 1 addition & 0 deletions browsergym/experiments/requirements.txt
@@ -1,2 +1,3 @@
 browsergym-core==0.8.1
 tiktoken>=0.4
+dataclasses-json
9 changes: 5 additions & 4 deletions browsergym/experiments/src/bgym/__init__.py
@@ -1,12 +1,13 @@
+from browsergym.core.action.base import AbstractActionSet
+from browsergym.core.action.highlevel import HighLevelActionSet
+from browsergym.core.action.python import PythonActionSet
 from browsergym.experiments.agent import Agent, AgentInfo
+from browsergym.experiments.benchmark import Benchmark, HighLevelActionSetArgs
 from browsergym.experiments.loop import (
+    AbstractAgentArgs,
     EnvArgs,
     ExpArgs,
-    AbstractAgentArgs,
     ExpResult,
     StepInfo,
     StepTimestamps,
 )
-from browsergym.core.action.base import AbstractActionSet
-from browsergym.core.action.highlevel import HighLevelActionSet
-from browsergym.core.action.python import PythonActionSet
376 changes: 376 additions & 0 deletions browsergym/experiments/src/browsergym/experiments/benchmark.py
@@ -0,0 +1,376 @@
import fnmatch
import io
import logging
import pkgutil
from dataclasses import dataclass, field
from typing import Literal, Optional

import numpy as np
import pandas as pd
from dataclasses_json import DataClassJsonMixin, config

from browsergym.core.action.highlevel import HighLevelActionSet
from browsergym.experiments.loop import SEED_MAX, EnvArgs

logger = logging.getLogger(__name__)


@dataclass
class HighLevelActionSetArgs(DataClassJsonMixin):
subsets: list[HighLevelActionSet.ActionSubset]
# custom_actions: list[callable] | None # non-serializable argument, not supported
multiaction: bool
strict: bool
retry_with_force: bool
demo_mode: Literal["off", "default", "all_blue", "only_visible_elements"]

def make_action_set(self):
return HighLevelActionSet(
subsets=self.subsets,
custom_actions=None,
multiaction=self.multiaction,
strict=self.strict,
retry_with_force=self.retry_with_force,
demo_mode=self.demo_mode,
)
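
# Usage sketch (an illustration, not part of this commit): the dataclass is JSON
# round-trippable via DataClassJsonMixin, and make_action_set() materializes the
# non-serializable HighLevelActionSet on demand.
#
#   args = HighLevelActionSetArgs(
#       subsets=["bid"],
#       multiaction=False,
#       strict=False,
#       retry_with_force=False,
#       demo_mode="off",
#   )
#   assert HighLevelActionSetArgs.from_json(args.to_json()) == args
#   action_set = args.make_action_set()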


@dataclass
class Benchmark(DataClassJsonMixin):
name: str
high_level_action_set_args: HighLevelActionSetArgs
env_args_list: list[EnvArgs]
task_metadata: Optional[pd.DataFrame] = field(
default_factory=lambda: None,
metadata=config(
encoder=lambda df: df.to_dict(orient="records") if df is not None else None,
decoder=lambda items: pd.DataFrame(items) if items is not None else None,
),
)

def __post_init__(self):
        # if no metadata is present, generate a dataframe with a single "task_name" column
if self.task_metadata is None:
unique_task_names = list(set([env_args.task_name for env_args in self.env_args_list]))
self.task_metadata = pd.DataFrame(
[{"task_name": task_name} for task_name in unique_task_names]
)
# make sure all tasks in env_args are in the metadata
metadata_tasks = list(self.task_metadata["task_name"])
assert all([env_args.task_name in metadata_tasks for env_args in self.env_args_list])

def subset_from_split(self, split: Literal["train", "eval", "test"]):
split_column = "split"

# check for a split column in metadata
        if split_column not in self.task_metadata.columns:
raise NotImplementedError(
f"This benchmark does not provide train/eval/test {split_column} (missing split column in task metadata)"
)

# recover the target split
sub_benchmark = self.subset_from_regexp(split_column, regexp=f"^{split}$")

# check that the split exists (non-empty task list)
if not sub_benchmark.env_args_list:
raise ValueError(f"The {split} split for this benchmark is empty.")

return sub_benchmark

def subset_from_glob(self, column, glob):
subset = self.subset_from_regexp(column, regexp=fnmatch.translate(glob))
subset.name = f"{self.name}[{column}={glob}]"
return subset

def subset_from_regexp(self, column, regexp):
# extract the filtered task_name subset
task_name_subset = task_list_from_metadata(self.task_metadata, {column: regexp})

# return the sub benchmark
return Benchmark(
name=f"{self.name}[{column}=/{regexp}/]",
high_level_action_set_args=self.high_level_action_set_args,
env_args_list=[
env_args
for env_args in self.env_args_list
if env_args.task_name in task_name_subset
],
task_metadata=self.task_metadata,
)


def task_metadata(benchmark_name: str):
return task_metadata_from_csv(
io.StringIO(
pkgutil.get_data(__name__, f"task_metadata/{benchmark_name}.csv").decode("utf-8")
)
)


def task_metadata_from_csv(filepath):
return pd.read_csv(filepath).fillna("")


def task_list_from_metadata(metadata: pd.DataFrame, filter: dict[str, str] = {}):
df = metadata
# filter the desired columns (AND filter)
for col_name, regex in filter.items():
col_filter = df[col_name].astype(str).str.contains(regex, regex=True)
df = df[col_filter]
# return only the task names
return list(df["task_name"])
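
# Illustrative example (assumed metadata columns, not from this commit): the
# filters are ANDed regexes over metadata columns, e.g.
#
#   task_list_from_metadata(task_metadata("miniwob"), {"miniwob_category": "original|nodelay"})
#
# returns every task name whose "miniwob_category" matches the regex.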


# These are meant as the default high-level action sets for evaluating agents fairly on each benchmark.
# They are mostly arbitrary; what matters is that different agents are evaluated with the same action set, for fairness.
DEFAULT_HIGHLEVEL_ACTION_SET_ARGS = {
# loosely from https://github.com/Farama-Foundation/miniwob-plusplus/blob/1bab0dffe34e92cc1049fe9443542029bf7e44a9/miniwob/action.py#L122
"miniwob": HighLevelActionSetArgs(
subsets=["bid", "coord"],
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
# loosely from https://github.com/Farama-Foundation/miniwob-plusplus/blob/1bab0dffe34e92cc1049fe9443542029bf7e44a9/miniwob/action.py#L160
"miniwob_bid": HighLevelActionSetArgs(
subsets=["bid"],
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
# loosely from https://github.com/Farama-Foundation/miniwob-plusplus/blob/1bab0dffe34e92cc1049fe9443542029bf7e44a9/miniwob/action.py#L173
"miniwob_coord": HighLevelActionSetArgs(
subsets=["coord"],
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
"workarena_l1": HighLevelActionSetArgs(
subsets=["chat", "bid"], # no need for infeasible action
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
"workarena": HighLevelActionSetArgs(
subsets=["chat", "infeas", "bid"],
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
# from https://arxiv.org/abs/2307.13854
"webarena": HighLevelActionSetArgs(
subsets=["chat", "infeas", "bid", "tab", "nav"],
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
# from https://arxiv.org/abs/2401.13649
"visualwebarena": HighLevelActionSetArgs(
subsets=["chat", "infeas", "bid", "tab", "nav"],
multiaction=False,
strict=False,
retry_with_force=False,
demo_mode="off",
),
}
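
# Sketch (not part of the original file): materializing one of these defaults
# into an actual action set, e.g. to build an agent prompt.
#
#   action_set = DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"].make_action_set()
#   print(action_set.describe(with_long_description=False, with_examples=True))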

# all benchmarks are callables designed for lazy loading, i.e. `bench = BENCHMARKS["miniwob_all"]()`
BENCHMARKS = {
"miniwob_all": lambda: Benchmark(
name="miniwob_all",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(metadata=task_metadata("miniwob")),
max_steps=10,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"miniwob_webgum": lambda: Benchmark(
name="miniwob_webgum",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
metadata=task_metadata("miniwob"), filter={"webgum_subset": "True"}
),
max_steps=10,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"miniwob_tiny_test": lambda: Benchmark(
name="miniwob_tiny_test",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=["miniwob.click-dialog", "miniwob.click-checkboxes"],
max_steps=5,
n_repeats=2,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"miniwob_train": lambda: Benchmark(
name="miniwob_train",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
metadata=task_metadata("miniwob"),
filter={"miniwob_category": "original|nodelay|debug|additional"},
),
max_steps=10,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"miniwob_test": lambda: Benchmark(
name="miniwob_test",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["miniwob"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
metadata=task_metadata("miniwob"), filter={"miniwob_category": "hidden test"}
),
max_steps=10,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("miniwob"),
),
"webarena": lambda: Benchmark(
name="webarena",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["webarena"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(metadata=task_metadata("webarena")),
max_steps=15,
n_repeats=1,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("webarena"),
),
"visualwebarena": lambda: Benchmark(
name="visualwebarena",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["visualwebarena"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(metadata=task_metadata("visualwebarena")),
max_steps=15,
n_repeats=1,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("visualwebarena"),
),
"workarena_l1": lambda: Benchmark(
name="workarena_l1",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena_l1"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
metadata=task_metadata("workarena"), filter={"level": "l1"}
),
max_steps=15,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("workarena"),
),
"workarena_l1_sort": lambda: Benchmark(
name="workarena_l1_sort",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena_l1"],
env_args_list=_make_env_args_list_from_repeat_tasks(
task_list=task_list_from_metadata(
metadata=task_metadata("workarena"), filter={"level": "l1", "category": "list-sort"}
),
max_steps=15,
n_repeats=10,
seeds_rng=np.random.RandomState(42),
),
task_metadata=task_metadata("workarena"),
),
"workarena_l2_agent_curriculum_eval": lambda: Benchmark(
name="workarena_l2_agent_curriculum_eval",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena"],
env_args_list=_make_env_args_list_from_workarena_curriculum(
level="l2",
task_category_filter=None,
meta_seed=42, # meta seed for evaluation curriculum
max_steps=50,
curriculum_type="agent",
),
task_metadata=task_metadata("workarena"),
),
"workarena_l3_agent_curriculum_eval": lambda: Benchmark(
name="workarena_l3_agent_curriculum_eval",
high_level_action_set_args=DEFAULT_HIGHLEVEL_ACTION_SET_ARGS["workarena"],
env_args_list=_make_env_args_list_from_workarena_curriculum(
level="l3",
task_category_filter=None,
meta_seed=42, # meta seed for evaluation curriculum
max_steps=50,
curriculum_type="agent",
),
task_metadata=task_metadata("workarena"),
),
}
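
# Usage sketch (an illustration only, assuming the metadata columns used above):
# benchmarks instantiate lazily, and the subset helpers on Benchmark produce
# filtered sub-benchmarks.
#
#   bench = BENCHMARKS["miniwob_all"]()
#   clicks = bench.subset_from_glob("task_name", "miniwob.click-*")
#   originals = bench.subset_from_regexp("miniwob_category", "^original$")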


def _make_env_args_list_from_workarena_curriculum(
level: Literal["l2", "l3"],
    task_category_filter: Optional[str],
meta_seed: int,
max_steps: int,
curriculum_type: Literal["human", "agent"],
):
"""
    Returns a predefined WorkArena task curriculum (i.e., a list of task and seed combinations).
"""
assert level in ("l2", "l3")
assert curriculum_type in ("human", "agent")

env_args_list = []

from browsergym.workarena import get_all_tasks_agents

all_task_tuples = get_all_tasks_agents(
filter=f"{level}.{task_category_filter}" if task_category_filter else level,
meta_seed=meta_seed,
is_agent_curriculum=(curriculum_type == "agent"),
)

for task, seed in all_task_tuples:
task_name = task.get_task_id()
env_args_list.append(EnvArgs(task_name=task_name, task_seed=seed, max_steps=max_steps))

return env_args_list


def _make_env_args_list_from_repeat_tasks(
task_list: list[str], max_steps: int, n_repeats: int, seeds_rng: np.random.RandomState
):
"""
    Generates a list of `len(task_list)` x `n_repeats` environment arguments, using randomly generated seeds.
"""
env_args_list = []
for task in task_list:
for seed in seeds_rng.randint(low=0, high=SEED_MAX, size=n_repeats):
env_args_list.append(
EnvArgs(
task_name=task,
task_seed=int(seed),
max_steps=max_steps,
headless=True,
record_video=False,
wait_for_user_message=False,
viewport=None,
slow_mo=None,
storage_state=None,
task_kwargs=None,
)
)

return env_args_list
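
Usage sketch (an illustration of the APIs defined above, not part of the commit): a full Benchmark serializes to JSON, including its task-metadata DataFrame via the records encoder/decoder.

    from browsergym.experiments.benchmark import BENCHMARKS, Benchmark

    bench = BENCHMARKS["miniwob_tiny_test"]()        # lazy instantiation
    restored = Benchmark.from_json(bench.to_json())  # DataFrame -> records -> DataFrame
    assert restored.name == bench.name
    assert len(restored.env_args_list) == len(bench.env_args_list)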