Skip to content

Commit

Permalink
Changing the argument type of terminate_fn to AsyncEnv so that the is…
Browse files Browse the repository at this point in the history
…_successful function of TaskEval can be used as terminate_fn.

PiperOrigin-RevId: 715929660
  • Loading branch information
The android_world Authors committed Jan 16, 2025
1 parent 4c966eb commit 93bc588
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 8 deletions.
9 changes: 3 additions & 6 deletions android_world/episode_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@

import dataclasses
from typing import Any, Callable, Optional

from android_env import env_interface
from android_world import constants
from android_world.agents import base_agent
from android_world.env import interface
import termcolor


Expand All @@ -45,9 +44,7 @@ def run_episode(
agent: base_agent.EnvironmentInteractingAgent,
max_n_steps: int = 10,
start_on_home_screen: bool = False,
termination_fn: (
Callable[[env_interface.AndroidEnvInterface], float] | None
) = None,
termination_fn: Callable[[interface.AsyncEnv], float] | None = None,
) -> EpisodeResult:
"""Runs an agent on goal, e.g., "turn off wifi".
Expand Down Expand Up @@ -83,7 +80,7 @@ def run_episode(
print('Completed step {:d}.'.format(step_n + 1))
assert constants.STEP_NUMBER not in result.data
output.append(result.data | {constants.STEP_NUMBER: step_n})
if termination_fn(agent.env.controller):
if termination_fn(agent.env):
print('Environment ends episode.')
return EpisodeResult(
done=True,
Expand Down
4 changes: 2 additions & 2 deletions android_world/task_evals/miniwob/miniwob_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ def get_episode_reward(env: env_interface.AndroidEnvInterface) -> float:
return float(int(reward))


def is_episode_terminated(env: env_interface.AndroidEnvInterface) -> bool:
def is_episode_terminated(env: interface.AsyncEnv) -> bool:
"""Checks if the current episode is terminated."""
return get_episode_reward(env) != 0.0
return get_episode_reward(env.controller.env) != 0.0


class MiniWoBTask(task_eval.TaskEval, abc.ABC):
Expand Down

0 comments on commit 93bc588

Please sign in to comment.