Refactor AssistantBench output directories (#242)
---------

Co-authored-by: Maxime Gasse <[email protected]>
ThibaultLSDC and gasse authored Nov 15, 2024
1 parent 8e56811 commit a02a541
Showing 4 changed files with 127 additions and 31 deletions.
@@ -6,14 +6,17 @@
VALID_AB_TASK_IDS = []
TEST_AB_TASK_IDS = []


# register a toy easy task for testing implementation
gym_id = f"assistantbench.imp.0"
register_task(
gym_id,
task.AssistantBenchTask,
task_kwargs={
"task_id": f"imp.0",
"output_file_path": "./assistantbench-predictions-imp.jsonl",
},
default_task_kwargs={
"save_predictions": False, # can be overriden
},
)
TOY_AB_TASK_IDS.append(gym_id)
@@ -24,7 +27,12 @@
register_task(
gym_id,
task.AssistantBenchTask,
task_kwargs={"task_id": f"validation.{task_id}"},
task_kwargs={
"task_id": f"validation.{task_id}",
},
default_task_kwargs={
"save_predictions": False, # can be overriden
},
)
VALID_AB_TASK_IDS.append(gym_id)

@@ -36,7 +44,9 @@
task.AssistantBenchTask,
task_kwargs={
"task_id": f"test.{task_id}",
"output_file_path": "./assistantbench-predictions-test.jsonl",
},
default_task_kwargs={
"save_predictions": True, # can be overriden
},
)
TEST_AB_TASK_IDS.append(gym_id)
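Usage note (illustrative, not part of the commit): with save_predictions now living in default_task_kwargs, it can be overridden per environment at creation time. A minimal sketch, assuming the usual browsergym/<task_name> gym id prefix and that assistantbench.validation.0 is among the registered tasks:

import gymnasium as gym
import browsergym.assistantbench  # importing the package registers the assistantbench.* tasks

# Hypothetical override of the registered default (save_predictions=False) for a
# validation task; the gym id prefix and the task id are assumptions.
env = gym.make(
    "browsergym/assistantbench.validation.0",
    task_kwargs={
        "save_predictions": True,
        "output_file": "./assistantbench-predictions-validation.jsonl",
    },
)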
54 changes: 47 additions & 7 deletions browsergym/assistantbench/src/browsergym/assistantbench/task.py
@@ -1,4 +1,5 @@
import logging
import os
from typing import Dict, Tuple

from datasets import load_dataset
@@ -11,6 +12,18 @@

logger = logging.getLogger(__name__)

_DEFAULT_OUTPUT_FILE = None


def set_default_output_file(output_file: str):
global _DEFAULT_OUTPUT_FILE
_DEFAULT_OUTPUT_FILE = output_file


def get_default_output_file():
return _DEFAULT_OUTPUT_FILE


# Load dataset

DATA_DATASET = "AssistantBench/AssistantBench"
@@ -60,12 +73,15 @@ def get_task_id(cls) -> str:
"""
raise NotImplementedError

def __init__(self, seed: int, task_id: str, output_file_path: str = None) -> None:
def __init__(
self, seed: int, task_id: str, output_file: str = None, save_predictions: bool = False
) -> None:
"""
Args:
seed (int): Random seed for task initialization.
task_id (str): Unique identifier for the task (for the BrowserGym environment).
output_file_path (str, optional): Path to the output file for saving results, needed for test set.
output_file (str, optional): Path to the output file for saving results, needed for test set.
save_predictions (bool, optional): Save predictions to the output file (yes/no).
"""
super().__init__(seed)
self.locale = "en-US"
@@ -76,11 +92,31 @@ def __init__(self, seed: int, task_id: str, output_file_path: str = None) -> Non
self.goal = tasks[str(self.task_id)]
self.gold = gold_answers[str(self.task_id)]
self.ab_task_id = ids[self.task_id]
self.output_file_path = output_file_path
self.save_predictions = save_predictions

self.output_file = output_file

# set output_file using the global default value, if not provided in constructor
if not self.output_file:
self.output_file = get_default_output_file()
# use env variable as a last resort
if not self.output_file:
self.output_file = os.getenv("ASSISTANTBENCH_OUTPUT_FILE", None)

if self.save_predictions and self.output_file:
logger.info(f"Task prediction will be written to output file {self.output_file}")

def setup(self, page: Page) -> Tuple[str, dict]:
logger.info(f"Navigating to start url: {self.start_url}")
page.goto(self.start_url, timeout=10000)
if self.save_predictions and self.output_file:
# create an empty task entry in the output file (will raise an Exception if the entry is already there)
add_prediction_to_jsonl(
file_path=self.output_file,
task_id=self.ab_task_id,
prediction="",
override_if_exists=False,
)
return self.goal, {}

def teardown(self) -> None:
@@ -93,10 +129,14 @@ def validate(self, page: Page, chat_messages: list[dict]) -> Tuple[float, bool,
if chat_messages and chat_messages[-1]["role"] == "assistant":
done = True
prediction = chat_messages[-1]["message"]
accuracy, has_ans = question_scorer(prediction, self.gold)
if self.output_file_path:
if self.save_predictions and self.output_file:
# update the task entry in the output file
add_prediction_to_jsonl(
self.output_file_path, self.ab_task_id, prediction, True
) # save answer to file
file_path=self.output_file,
task_id=self.ab_task_id,
prediction=prediction,
override_if_exists=True,
)
accuracy, has_ans = question_scorer(prediction, self.gold)

return accuracy, done, msg, info
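For clarity (not part of the commit): output_file is resolved in three steps above — the constructor argument first, then the module-level default, then the ASSISTANTBENCH_OUTPUT_FILE environment variable. A minimal sketch of the two fallback mechanisms:

import os
from browsergym.assistantbench.task import set_default_output_file, get_default_output_file

# Process-wide default, used when no output_file is passed to the task constructor:
set_default_output_file("./assistantbench-predictions.jsonl")
assert get_default_output_file() == "./assistantbench-predictions.jsonl"

# Last-resort fallback, read only if neither the constructor argument nor the
# module-level default is set:
os.environ["ASSISTANTBENCH_OUTPUT_FILE"] = "./assistantbench-predictions.jsonl"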
64 changes: 45 additions & 19 deletions browsergym/assistantbench/src/browsergym/assistantbench/utils.py
@@ -1,6 +1,8 @@
import json
import logging
import os
import pathlib
import time

logger = logging.getLogger(__name__)

@@ -9,39 +11,63 @@ def add_prediction_to_jsonl(
file_path: str, task_id: str, prediction: object, override_if_exists: bool
) -> None:
"""
WARNING: this is not multiprocessing-safe.
Multiprocessing-safe file write.
"""
lock_file_path = pathlib.Path(file_path).with_suffix(".lock")
lock_max_wait = 10 # 10 seconds

# Acquire lock (atomic file creation)
start_time = time.time()
while True:
try:
fd = os.open(lock_file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY)
with os.fdopen(fd, "w") as f:
f.write("lock")
break
except FileExistsError:
# give up if max wait time reached
seconds_waited = time.time() - start_time
if seconds_waited >= lock_max_wait:
raise RuntimeError(
f"Lock file could not be acquired after {seconds_waited} seconds ({lock_file_path})"
)
# wait for lock release
logger.info(f"Waiting for lock file to be released: {lock_file_path}")
time.sleep(1) # 1 sec

logger.info(f"Lock file acquired: {lock_file_path}")

# Check if the file exists, if not, create it
if not os.path.exists(file_path):
with open(file_path, "w") as f:
pass # Create an empty file

# Load existing data
# Load existing data, if any
data = []
with open(file_path, "r") as f:
for line in f:
if line.strip(): # Ensure no empty lines
data.append(json.loads(line))
if os.path.exists(file_path):
with open(file_path, "r") as f:
data.extend([json.loads(line) for line in f if line.strip()]) # Skip empty lines

# Check if task_id already exists
existing_record = next((entry for entry in data if entry["id"] == task_id), None)

if existing_record:
if not override_if_exists:
logger.warning(
f"Task ID '{task_id}' already exists. Not overriding as 'override_if_exists' is set to False."
)
return
else:
logger.warning(
f"Task ID '{task_id}' already exists. Overriding as 'override_if_exists' is set to True."
)
existing_record["answer"] = prediction
else:
# Add or update the record
if not existing_record:
# Add new record
data.append({"id": task_id, "answer": prediction})
elif override_if_exists:
# Update existing record
existing_record["answer"] = prediction
else:
raise ValueError(
f"Prediction for task ID {repr(task_id)} already exists in file {file_path}."
)

# Write updated data back to the file
# Write data back to the file
with open(file_path, "w") as f:
for entry in data:
f.write(json.dumps(entry) + "\n")

# Release lock (remove file)
os.remove(lock_file_path)
logger.info(f"Lock file released: {lock_file_path}")
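A hedged usage sketch of the helper above, mirroring how setup() and validate() call it (the task id and prediction values are made up):

from browsergym.assistantbench.utils import add_prediction_to_jsonl

output_file = "./assistantbench-predictions-test.jsonl"

# As in setup(): create an empty placeholder entry; raises ValueError if the
# task id is already present in the file.
add_prediction_to_jsonl(
    file_path=output_file,
    task_id="example-task-id",  # hypothetical AssistantBench task id
    prediction="",
    override_if_exists=False,
)

# As in validate(): overwrite the placeholder with the agent's final answer.
add_prediction_to_jsonl(
    file_path=output_file,
    task_id="example-task-id",
    prediction="final answer text",
    override_if_exists=True,
)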
24 changes: 22 additions & 2 deletions browsergym/experiments/src/browsergym/experiments/loop.py
@@ -45,7 +45,15 @@ class EnvArgs(DataClassJsonMixin):
storage_state: Optional[str | Path | dict] = None
task_kwargs: Optional[dict] = None # use default value from BrowserGym

def make_env(self, action_mapping, exp_dir):
def make_env(self, action_mapping, exp_dir, exp_task_kwargs: dict = {}):
"""
Instantiates the BrowserGym environment corresponding to the arguments (with some tweaks).
Args:
action_mapping: overrides the action mapping of the environment.
exp_dir: will set some environment parameters (e.g., record_video_dir) with respect to the directory where the experiment is running.
exp_task_kwargs: use with caution! Will override task parameters to experiment-specific values. Useful to set different server configs for different experiments, or output file paths within the experiment's folder (e.g., assistantbench).
"""
extra_kwargs = {}
if self.record_video:
extra_kwargs["record_video_dir"] = exp_dir
@@ -57,6 +65,15 @@ def make_env(self, action_mapping, exp_dir):
extra_kwargs["pw_context_kwargs"] = {"storage_state": self.storage_state}
if self.task_kwargs is not None:
extra_kwargs["task_kwargs"] = self.task_kwargs
if exp_task_kwargs:
extra_kwargs["task_kwargs"] = extra_kwargs.get("task_kwargs", {}) | exp_task_kwargs

# assistantbench hack, write the task output (agent prediction) to a file in the experiment's directory
# TODO: find a better way to deal with this
if self.task_name.startswith("assistantbench.test"):
extra_kwargs["task_kwargs"] = extra_kwargs.get("task_kwargs", {}) | {
"output_file": exp_dir / "assistantbench-prediction.json"
}

return gym.make(
_get_env_name(self.task_name),
@@ -214,9 +231,12 @@ def run(self):
logger.info(f"Running experiment {self.exp_name} in:\n {self.exp_dir}")
agent = self.agent_args.make_agent()
logger.debug(f"Agent created.")

env = self.env_args.make_env(
action_mapping=agent.action_set.to_python_code, exp_dir=self.exp_dir
action_mapping=agent.action_set.to_python_code,
exp_dir=self.exp_dir,
)

logger.debug(f"Environment created.")

step_info = StepInfo(step=0)
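A small, self-contained illustration (not from the commit) of the merge semantics used in make_env above: the dict-union operator (Python 3.9+) lets right-hand-side keys win, so exp_task_kwargs overrides user-provided task_kwargs, and the assistantbench.test hack in turn pins output_file inside the experiment directory.

from pathlib import Path

exp_dir = Path("./results/my_run")  # hypothetical experiment directory

task_kwargs = {"save_predictions": True, "output_file": "./user-chosen.jsonl"}
exp_task_kwargs = {"output_file": exp_dir / "assistantbench-prediction.json"}

merged = task_kwargs | exp_task_kwargs  # right-hand side takes precedence
assert merged["output_file"] == exp_dir / "assistantbench-prediction.json"
assert merged["save_predictions"] is True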
