diff --git a/.gitignore b/.gitignore
index 81e4ff0c..1525c9ee 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,7 +5,8 @@ __pycache__/
 
 # C extensions
 *.so
-
+checkpoints
+logs
 # Distribution / packaging
 .Python
 build/
diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
index 91b645c2..a34f52ac 100644
--- a/gia/eval/evaluator.py
+++ b/gia/eval/evaluator.py
@@ -5,9 +5,9 @@
 
 
 class Evaluator:
-    def __init__(self, args: Arguments, task: str) -> None:
+    def __init__(self, args: Arguments, task_name: str) -> None:
         self.args = args
-        self.task = task
+        self.task_name = task_name
 
     @torch.no_grad()
     def evaluate(self, model: GiaModel) -> float:
diff --git a/gia/eval/language_modeling/language_modeling_evaluator.py b/gia/eval/language_modeling/language_modeling_evaluator.py
index 72353d62..f92a647f 100644
--- a/gia/eval/language_modeling/language_modeling_evaluator.py
+++ b/gia/eval/language_modeling/language_modeling_evaluator.py
@@ -9,11 +9,10 @@
 
 
 class LanguageModelingEvaluator(Evaluator):
-    def _evaluate(self, model: GiaModel):
-        model.eval()
+    def _evaluate(self, model: GiaModel) -> float:
         losses = []
         processor = GiaProcessor()
-        dataset = load_dataset("gia-project/gia-dataset", self.task, split="test")
+        dataset = load_dataset("gia-project/gia-dataset", self.task_name, split=self.args.test_split)
         dataset = dataset.map(lambda batch: processor(**batch), remove_columns=dataset.column_names, batched=True)
         dataloader = DataLoader(dataset, batch_size=self.args.batch_size, collate_fn=GiaDataCollator(), shuffle=True)
         for step, batch in enumerate(dataloader):
diff --git a/gia/eval/mappings.py b/gia/eval/mappings.py
deleted file mode 100644
index e7ba9d34..00000000
--- a/gia/eval/mappings.py
+++ /dev/null
@@ -1,11 +0,0 @@
-TASK_TO_ENV_MAPPING = {
-    "mujoco-ant": "Ant-v4",
-    "mujoco-halfcheetah": "HalfCheetah-v4",
-    "mujoco-hopper": "Hopper-v4",
-    "mujoco-doublependulum": "InvertedDoublePendulum-v4",
-    "mujoco-pendulum": "InvertedPendulum-v4",
-    "mujoco-reacher": "Reacher-v4",
-    "mujoco-swimmer": "Swimmer-v4",
-    "mujoco-walker": "Walker2d-v4",
-    # Atari etc...
-}
diff --git a/gia/eval/rl/__init__.py b/gia/eval/rl/__init__.py
index e9557034..f9837536 100644
--- a/gia/eval/rl/__init__.py
+++ b/gia/eval/rl/__init__.py
@@ -1,5 +1,5 @@
 from .envs import make
-from .gym_evaluator import GymEvaluator
+from .rl_evaluator import RLEvaluator
 
 
-__all__ = ["make", "GymEvaluator"]
+__all__ = ["make", "RLEvaluator"]
diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
index 38a5a735..ab02c36e 100644
--- a/gia/eval/rl/envs/core.py
+++ b/gia/eval/rl/envs/core.py
@@ -14,7 +14,7 @@
 )
 
 
-TASK_TO_ENV_MAPPING = {
+TASK_NAME_TO_ENV_ID = {
     "atari-alien": "ALE/Alien-v5",
     "atari-amidar": "ALE/Amidar-v5",
     "atari-assault": "ALE/Assault-v5",
@@ -181,7 +181,7 @@ def get_task_names() -> List[str]:
     Returns:
         list: List of environment ids
     """
-    return list(TASK_TO_ENV_MAPPING.keys())
+    return list(TASK_NAME_TO_ENV_ID.keys())
 
 
 class AtariDictObservationWrapper(ObservationWrapper):
@@ -200,7 +200,7 @@ def make_atari(task_name: str, **kwargs) -> Env:
     kwargs = {"frameskip": 1, "repeat_action_probability": 0.0, **kwargs}
     if task_name == "atari-montezumarevenge":
         kwargs["max_episode_steps"] = 18_000
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = gym.wrappers.RecordEpisodeStatistics(env)
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
@@ -248,7 +248,7 @@ def reward(self, reward):
 
 
 def make_babyai(task_name: str, **kwargs) -> Env:
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = BabyAIDictObservationWrapper(env)
     env = FloatRewardWrapper(env)
     return env
@@ -266,13 +266,13 @@ def observation(self, observation):
 
 
 def make_metaworld(task_name: str, **kwargs) -> Env:
     import metaworld  # noqa
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = ContinuousObservationDictWrapper(env)
     return env
 
 
 def make_mujoco(task_name: str, **kwargs) -> Env:
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = ContinuousObservationDictWrapper(env)
     return env
diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
index ed5d31ce..dd76b2e7 100644
--- a/gia/eval/rl/gia_agent.py
+++ b/gia/eval/rl/gia_agent.py
@@ -1,100 +1,213 @@
+from typing import List, Tuple
+
 import numpy as np
 import torch
 from datasets import load_dataset
+from gymnasium import spaces
+from torch import Tensor
 
 from gia.datasets import GiaDataCollator, Prompter
 from gia.model.gia_model import GiaModel
 from gia.processing import GiaProcessor
 
+from .envs.core import get_task_names, make
+
 
 class GiaAgent:
+    r"""
+    An RL agent that uses Gia to generate actions.
+
+    Warning:
+        The agent caches past key values from the model. This means that when you call `get_action` multiple times in
+        succession, the agent will generate actions based on all the previous actions passed to `get_action`. If you
+        want to reset the agent to generate actions based on the initial prompt, you need to call the `reset` method.
+
+    Args:
+        model (`GiaModel`):
+            The Gia model used to predict the action tokens.
+        processor (`GiaProcessor`):
+            The processor used to tokenize the observations and decode the predicted action tokens.
+        task_name (`str`):
+            The environment id. Check the available task names with `GiaAgent.get_available_task_names()`.
+        num_envs (`int`, *optional*, defaults to 1):
+            The number of parallel environments the agent acts in.
+        use_prompt (`bool`, *optional*, defaults to True):
+            Whether to prompt the model with episodes from the task dataset when `reset` is called.
+        p_prompt (`float`, *optional*, defaults to 0.25):
+            Prompt probability, forwarded to the `Prompter`.
+        p_end (`float`, *optional*, defaults to 0.1):
+            Probability of sampling the prompt from the end of an episode, forwarded to the `Prompter`.
+        min_prompt_len (`int`, *optional*, defaults to 1):
+            Minimum prompt length, forwarded to the `Prompter`.
+        max_prompt_len (`int`, *optional*, defaults to 1024):
+            Maximum prompt length, forwarded to the `Prompter`.
+        deterministic (`bool`, *optional*, defaults to False):
+            Whether to use deterministic action generation. Defaults to False.
+    """
+
     def __init__(
         self,
-        task: str,
         model: GiaModel,
-        obs_space,
-        action_space,
-        use_separator: bool = True,
-    ):
+        processor: GiaProcessor,
+        task_name: str,
+        num_envs: int = 1,
+        use_prompt: bool = True,
+        p_prompt: float = 0.25,
+        p_end: float = 0.1,
+        min_prompt_len: int = 1,
+        max_prompt_len: int = 1024,
+        deterministic: bool = False,
+    ) -> None:
+        self.processor = processor
         self.model = model
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self._num_obs_tokens = obs_space.shape[1]
-        self._num_act_tokens = action_space.shape[1]
-        self._tokens_per_step = self._num_obs_tokens + self._num_act_tokens + int(use_separator)
-        self._int_per_seq = (self.model.config.seq_len // self._tokens_per_step) - 1
-
-        dataset = load_dataset("gia-project/gia-dataset", task, split="test")
-        self.prompter = Prompter(
-            dataset,
-            min_prompt_len=self._int_per_seq,
-            max_prompt_len=self._int_per_seq,
-        )
-        self._max_kv_size = self._int_per_seq * self._tokens_per_step
+        if use_prompt:
+            dataset = load_dataset("gia-project/gia-dataset", task_name, split="test", writer_batch_size=1)
+            self.prompter = Prompter(dataset, p_prompt, p_end, min_prompt_len, max_prompt_len)
+        else:
+            self.prompter = None
 
-        self.processor = GiaProcessor()
         self.collator = GiaDataCollator()
+        self.num_envs = num_envs
+
+        # Get observation and action space
+        env = make(task_name)
+        self.observation_space = env.observation_space
+        self.action_space = env.action_space
+
+        self.deterministic = deterministic
+        self.device = next(self.model.parameters()).device
+        self._max_length = self.model.config.max_position_embeddings - 100  # TODO: check this
+
+        if not isinstance(self.observation_space, spaces.Dict):
+            raise TypeError("Unsupported observation space")
+
+        if isinstance(self.action_space, spaces.Box):
+            self._num_act_tokens = self.action_space.shape[0]
+        elif isinstance(self.action_space, spaces.Discrete):
+            self._num_act_tokens = 1
+        else:
+            raise TypeError("Unsupported action space")
+
+    @staticmethod
+    def get_available_task_names() -> List[str]:
+        """
+        Returns the available task names.
+
+        Returns:
+            List[str]: The available task names.
+        """
+        return get_task_names()
+
+    def _truncate_past_key_values(
+        self, past_key_values: Tuple[Tuple[Tensor, Tensor], ...]
+    ) -> Tuple[Tuple[Tensor, Tensor], ...]:
+        return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
 
     def reset(self) -> None:
-        prompt = self.prompter.generate_prompts(1)
-        prompt_observations = np.array([prompt["continuous_observations"][0]])
-        prompt_actions = np.array([prompt["continuous_actions"][0]])
+        if self.prompter is not None:
+            prompts = self.prompter.generate_prompts(self.num_envs)
+            processed_prompt = self.processor(
+                **prompts,
+                padding=False,
+                truncation="max_length",
+                truncation_side="left",
+                max_length=self._max_length,
+            )
+            processed_prompt = self.collator(
+                [
+                    {key: processed_prompt[key][ep_idx] for key in processed_prompt.keys()}
+                    for ep_idx in range(self.num_envs)
+                ]
+            )
+            for key in processed_prompt.keys():
+                processed_prompt[key] = processed_prompt[key].to(self.device)
+            output = self.model(**processed_prompt, use_cache=True)
+            self._past_key_values = self._truncate_past_key_values(output.past_key_values)
+        else:
+            self._past_key_values = None
 
-        processed_prompt = self.processor(
-            continuous_observations=prompt_observations,
-            continuous_actions=prompt_actions,
-            padding=False,
-            truncation="max_length",
-            truncation_side="left",
-            max_length=self.model.config.seq_len - self._num_act_tokens,
-        )
+    def get_action(self, observations: np.ndarray) -> np.ndarray:
+        """
+        Predicts the next action given the current observation.
 
-        processed_prompt = self.collator([{key: processed_prompt[key][0] for key in processed_prompt.keys()}])
-        for key in processed_prompt.keys():
-            processed_prompt[key] = processed_prompt[key].to(self.device)
-        output = self.model(**processed_prompt, use_cache=True)
-        self._past_key_values = output.past_key_values
+        Args:
+            observations (np.ndarray): The current observation
 
-    def get_action(self, obs) -> np.ndarray:
+        Returns:
+            np.ndarray: The next action
+        """
+        # Turn into episode
+        keys = observations[0].keys()
+        dict_observations = {}
+        for key in keys:
+            values = [obs[key] for obs in observations]
+            if isinstance(values[0], np.ndarray):
+                dict_observations[key] = np.expand_dims(np.stack(values), axis=1)
+            elif isinstance(values[0], str):
+                dict_observations[key] = [[value] for value in values]
+            else:
+                raise TypeError(f"Unsupported type for {key}")
+
+        # Process observations
         processed = self.processor(
-            continuous_observations=[obs],
-            continuous_actions=[],
+            **dict_observations,
             padding=False,
             truncation="max_length",
             truncation_side="left",
-            max_length=self.model.config.seq_len - self._num_act_tokens,  # ensure not to not overflow
+            max_length=self._max_length,  # ensure not to overflow
+        )
+
+        # Process and move to device
+        num_envs = len(processed["input_types"])
+        processed = self.collator(
+            [{key: processed[key][ep_idx] for key in processed.keys()} for ep_idx in range(num_envs)]
         )
-        processed = self.collator([{key: processed[key][0] for key in processed.keys()}])
         for key in processed.keys():
             processed[key] = processed[key].to(self.device)
 
-        action_tokens = []
-        for i in range(self._num_act_tokens):
+        # Generate action tokens
+        action_tokens = []
+        for _ in range(self._num_act_tokens):
             output = self.model(**processed, use_cache=True, past_key_values=self._past_key_values)
-            self._past_key_values = output.past_key_values
+            self._past_key_values = self._truncate_past_key_values(output.past_key_values)
             action_logits = output.logits[:, -1]
-            action_token = torch.argmax(action_logits, dim=1)
-            action_tokens.append(action_token)
-            processed["input_ids"] = action_token[None, :]
-
-            if i == 0:  # only needs to be done once
-                processed["loss_mask"] = torch.ones(1, 1, dtype=torch.bool, device=self.device)
-                processed["input_types"] = torch.zeros(1, 1, dtype=torch.int64, device=self.device)
-                processed["local_positions"] = -torch.ones(1, 1, dtype=torch.int64, device=self.device)
+            if self.deterministic:
+                action_token = torch.argmax(action_logits, dim=-1)
+            else:
+                action_token = torch.multinomial(torch.softmax(action_logits, dim=-1), num_samples=1).squeeze(-1)
+            action_tokens.append(action_token.tolist())
+
+            processed = {
+                "input_ids": action_token[:, None],
+                "loss_mask": torch.ones(num_envs, 1, dtype=torch.bool, device=self.device),
+                "input_types": torch.zeros(num_envs, 1, dtype=torch.int64, device=self.device),
+                "local_positions": -torch.ones(num_envs, 1, dtype=torch.int64, device=self.device),
+            }
 
-        # to ensure the KV cache includes the last action token
+        # To ensure the KV cache includes the last action token
         output = self.model(**processed, use_cache=True, past_key_values=self._past_key_values)
-        self._past_key_values = output.past_key_values
-        if self._past_key_values[0][0].shape[2] > self._max_kv_size:
-            # remove one step of tokens, to ensure context < 1024
-            self._past_key_values = [
-                (k[:, :, self._tokens_per_step :], v[:, :, self._tokens_per_step :])
-                for (k, v) in self._past_key_values
-            ]
-        action_tokens = torch.stack(action_tokens, dim=-1).cpu().tolist()
-
-        # Decode the action tokens
-        action = np.array(self.processor.decode_continuous(action_tokens))
-        # TODO: Clamp action to be in domain of action space?
-        return action
+        self._past_key_values = self._truncate_past_key_values(output.past_key_values)
+
+        # Transpose action_tokens to be (num_envs, num_act_tokens)
+        action_tokens = np.array(action_tokens, dtype=self.action_space.dtype).T.tolist()
+
+        if isinstance(self.action_space, spaces.Box):
+            # Decode the action tokens
+            actions = np.array(self.processor.decode_continuous(action_tokens), dtype=self.action_space.dtype)
+
+            # Clip the action if necessary
+            actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        elif isinstance(self.action_space, spaces.Discrete):
+            # Decode the action tokens
+            actions = np.array(self.processor.decode_discrete(action_tokens), dtype=self.action_space.dtype)
+            actions = actions.squeeze(axis=1)
+
+            # Clip the action if necessary (decoded actions are between 0 and num_bins)
+            actions = np.clip(actions, a_min=0, a_max=self.action_space.n - 1)
+        return actions
diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py
deleted file mode 100644
index f8531ee9..00000000
--- a/gia/eval/rl/gym_evaluator.py
+++ /dev/null
@@ -1,13 +0,0 @@
-import gym
-from gym.vector.vector_env import VectorEnv
-
-from gia.eval.mappings import TASK_TO_ENV_MAPPING
-from gia.eval.rl.rl_evaluator import RLEvaluator
-
-
-class GymEvaluator(RLEvaluator):
-    def _build_env(self) -> VectorEnv:
-        NUM_ENVS = 1
-        env_name = TASK_TO_ENV_MAPPING[self.task]
-        env = gym.vector.make(env_name, NUM_ENVS)
-        return env
diff --git a/gia/eval/rl/rl_evaluator.py b/gia/eval/rl/rl_evaluator.py
index c5cc423f..13600d3b 100644
--- a/gia/eval/rl/rl_evaluator.py
+++ b/gia/eval/rl/rl_evaluator.py
@@ -1,19 +1,18 @@
 import numpy as np
-from gym.vector.vector_env import VectorEnv
 from tqdm import tqdm
 
 from gia import GiaModel
 from gia.eval.evaluator import Evaluator
+from gia.eval.rl import make
 from gia.eval.rl.gia_agent import GiaAgent
+from gia.processing import GiaProcessor
 
 
 class RLEvaluator(Evaluator):
-    def _build_env(self) -> VectorEnv:  # TODO: maybe just a gym.Env ?
-        raise NotImplementedError
-
     def _evaluate(self, model: GiaModel) -> float:
-        env = self._build_env()
-        gia_agent = GiaAgent(self.task, model, env.observation_space, env.action_space)
+        env = make(self.task_name)
+        processor = GiaProcessor()  # Ideally, model.config
+        gia_agent = GiaAgent(model, processor, self.task_name, num_envs=1)
 
         returns = []
         # due to how to KV cache is used, we only can evaluate one env instance at a time
@@ -25,11 +24,11 @@ def _evaluate(self, model: GiaModel) -> float:
 
             while not done:
                 # Compute the output of the model
-                action = gia_agent.get_action(obs)
+                action = gia_agent.get_action([obs])[0]
                 obs, reward, terminated, truncated, info = env.step(action)
                 done = terminated or truncated
-                accum_rewards.append(reward[0])
+                accum_rewards.append(reward)
 
             returns.append(sum(accum_rewards))
 
         env.close()
diff --git a/gia/eval/utils.py b/gia/eval/utils.py
index 17a1dcb9..0dc1f229 100644
--- a/gia/eval/utils.py
+++ b/gia/eval/utils.py
@@ -1,7 +1,8 @@
 import subprocess
 
+from gia.eval.evaluator import Evaluator
 from gia.eval.language_modeling.language_modeling_evaluator import LanguageModelingEvaluator
-from gia.eval.rl import GymEvaluator
+from gia.eval.rl import RLEvaluator
 
 
 def is_slurm_available() -> bool:
@@ -13,19 +14,32 @@ def is_slurm_available() -> bool:
         return False
 
 
-# TODO: A nice use case for structural pattern matching?!
 EVALUATORS = {
-    "mujoco": GymEvaluator,
-    "atari": GymEvaluator,
+    "mujoco": RLEvaluator,
+    "atari": RLEvaluator,
     "oscar": LanguageModelingEvaluator,
     "ok-vqa": LanguageModelingEvaluator,
     "conceptual-captions": LanguageModelingEvaluator,
 }
 
 
-def get_evaluator(task):
-    if "-" in task:
-        domain = task.split("-")[0]  # TODO: this will have problems for ok-vqa, etc..
-        return EVALUATORS[domain]
+def get_evaluator(task_name: str) -> Evaluator:
+    """
+    Get the evaluator for a given task.
 
-    return EVALUATORS[task]
+    Args:
+        task_name (`str`):
+            The task name.
+
+    Raises:
+        `ValueError`: If the task name is unknown.
+
+    Returns:
+        evaluator (`Evaluator`):
+            The evaluator for the task.
+ """ + for domain in EVALUATORS.keys(): + if task_name.startswith(domain): + return EVALUATORS[domain] + else: + raise ValueError(f"Unknown task {task_name}") diff --git a/gia/processing/processing.py b/gia/processing/processing.py index a5a077c2..0c57a993 100644 --- a/gia/processing/processing.py +++ b/gia/processing/processing.py @@ -246,6 +246,7 @@ def __init__( ] ] self.local_positions_adder = LocalPositionsAdder(local_positions_groups) + self.use_separator = use_separator if use_separator: separator = { "input_ids": [token_shift + nb_bins], diff --git a/scripts/eval_checkpoint.py b/scripts/eval_checkpoint.py new file mode 100644 index 00000000..de2f6233 --- /dev/null +++ b/scripts/eval_checkpoint.py @@ -0,0 +1,19 @@ +from gia import GiaConfig, GiaModel +from gia.eval.rl import make +from gia.eval.rl.gia_agent import GiaAgent +from gia.processing import GiaProcessor + + +config = GiaConfig(num_layers=4, num_heads=12, hidden_size=384) +model = GiaModel(config) + + +task_name = "metaworld-assembly" +env = make(task_name, num_envs=1) +processor = GiaProcessor() +gia_agent = GiaAgent(model, processor, task_name, use_prompt=False) + +gia_agent.reset() +obs, info = env.reset() +action = gia_agent.get_action(obs) +print(action) diff --git a/tests/eval/rl/envs/test_core.py b/tests/eval/rl/envs/test_core.py index fec0940a..c008d693 100644 --- a/tests/eval/rl/envs/test_core.py +++ b/tests/eval/rl/envs/test_core.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from gia.eval.rl.envs.core import make +from gia.eval.rl import make OBS_KEYS = {"discrete_observations", "continuous_observations", "image_observations", "text_observations"} diff --git a/tests/eval/rl/test_gia_agent.py b/tests/eval/rl/test_gia_agent.py new file mode 100644 index 00000000..12b87872 --- /dev/null +++ b/tests/eval/rl/test_gia_agent.py @@ -0,0 +1,19 @@ +import pytest + +from gia import GiaConfig, GiaModel +from gia.eval.rl.gia_agent import GiaAgent +from gia.processing import GiaProcessor + + +@pytest.mark.parametrize("task_name", ["atari-alien", "babyai-action-obj-door", "metaworld-assembly", "mujoco-ant"]) +def test_gia_agent(task_name): + num_envs = 2 + config = GiaConfig(seq_len=128, hidden_size=32, nul_layers=4, num_heads=4) + model = GiaModel(config) + processor = GiaProcessor() + agent = GiaAgent(model, processor, task_name, num_envs, use_prompt=False) + agent.reset() + observations = [agent.observation_space.sample() for _ in range(num_envs)] + actions = agent.get_action(observations) + for action in actions: + assert agent.action_space.contains(action) diff --git a/tests/eval/test_mujoco_evaluator.py b/tests/eval/test_mujoco_evaluator.py index 90cc8fab..0a2d558d 100644 --- a/tests/eval/test_mujoco_evaluator.py +++ b/tests/eval/test_mujoco_evaluator.py @@ -1,6 +1,6 @@ from gia import GiaConfig, GiaModel from gia.config.arguments import Arguments -from gia.eval.rl.gym_evaluator import GymEvaluator +from gia.eval.rl.rl_evaluator import RLEvaluator def test_mujoco_evaluator(): @@ -9,5 +9,5 @@ def test_mujoco_evaluator(): args = Arguments(output_dir="tmp", n_episodes=2, task_names="mujoco-doublependulum") - evaluator = GymEvaluator(args, "mujoco-doublependulum") + evaluator = RLEvaluator(args, "mujoco-doublependulum") evaluator.evaluate(model)