Generic GiaAgent (#98)
* generic agent

* babyai wrapper

* Gia agent with test

* improve docstring and type hint

* generic make

* metaworld gymnasium

* new metaworld style

* update generate dataset

* try to fix generate dataset

* get_task_names

* get_task_names for test

* some modifications in GiaAgent

* new make function

* use ale-v5 for atari

* general gia agent

* tmp fix

* lr scheduler

* lr scheduler

* update test_core

* fix scheduler

* drop sample_factory dep

* rm babyai test

* remove unused import

* reorder import

* TASK_NAME_TO_ENV_ID

* fix get_evaluator

* remove gym evaluator

* use make in rl evaluator

* task -> task_name

* rlevaluator

* fix case

* test_split in LanguageModelingEvaluator

* gia agent for images

* fix rl evaluator _evaluate

* fix RLEvaluator

* fix rl evaluator
qgallouedec authored Aug 15, 2023
1 parent 7d23690 commit dc7ba8c
Showing 15 changed files with 264 additions and 123 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -5,7 +5,8 @@ __pycache__/
 
 # C extensions
 *.so
-
+checkpoints
+logs
 # Distribution / packaging
 .Python
 build/
4 changes: 2 additions & 2 deletions gia/eval/evaluator.py
@@ -5,9 +5,9 @@
 
 
 class Evaluator:
-    def __init__(self, args: Arguments, task: str) -> None:
+    def __init__(self, args: Arguments, task_name: str) -> None:
         self.args = args
-        self.task = task
+        self.task_name = task_name
 
     @torch.no_grad()
     def evaluate(self, model: GiaModel) -> float:
5 changes: 2 additions & 3 deletions gia/eval/language_modeling/language_modeling_evaluator.py
@@ -9,11 +9,10 @@
 
 
 class LanguageModelingEvaluator(Evaluator):
-    def _evaluate(self, model: GiaModel):
-        model.eval()
+    def _evaluate(self, model: GiaModel) -> float:
         losses = []
         processor = GiaProcessor()
-        dataset = load_dataset("gia-project/gia-dataset", self.task, split="test")
+        dataset = load_dataset("gia-project/gia-dataset", self.task_name, split=self.args.test_split)
         dataset = dataset.map(lambda batch: processor(**batch), remove_columns=dataset.column_names, batched=True)
         dataloader = DataLoader(dataset, batch_size=self.args.batch_size, collate_fn=GiaDataCollator(), shuffle=True)
         for step, batch in enumerate(dataloader):
11 changes: 0 additions & 11 deletions gia/eval/mappings.py

This file was deleted.

4 changes: 2 additions & 2 deletions gia/eval/rl/__init__.py
@@ -1,5 +1,5 @@
 from .envs import make
-from .gym_evaluator import GymEvaluator
+from .rl_evaluator import RLEvaluator
 
 
-__all__ = ["make", "GymEvaluator"]
+__all__ = ["make", "RLEvaluator"]
12 changes: 6 additions & 6 deletions gia/eval/rl/envs/core.py
@@ -14,7 +14,7 @@
 )
 
 
-TASK_TO_ENV_MAPPING = {
+TASK_NAME_TO_ENV_ID = {
     "atari-alien": "ALE/Alien-v5",
     "atari-amidar": "ALE/Amidar-v5",
     "atari-assault": "ALE/Assault-v5",
@@ -181,7 +181,7 @@ def get_task_names() -> List[str]:
     Returns:
         list: List of environment ids
     """
-    return list(TASK_TO_ENV_MAPPING.keys())
+    return list(TASK_NAME_TO_ENV_ID.keys())
 
 
 class AtariDictObservationWrapper(ObservationWrapper):
@@ -200,7 +200,7 @@ def make_atari(task_name: str, **kwargs) -> Env:
     kwargs = {"frameskip": 1, "repeat_action_probability": 0.0, **kwargs}
     if task_name == "atari-montezumarevenge":
         kwargs["max_episode_steps"] = 18_000
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = gym.wrappers.RecordEpisodeStatistics(env)
     env = NoopResetEnv(env, noop_max=30)
     env = MaxAndSkipEnv(env, skip=4)
@@ -248,7 +248,7 @@ def reward(self, reward):
 
 
 def make_babyai(task_name: str, **kwargs) -> Env:
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = BabyAIDictObservationWrapper(env)
     env = FloatRewardWrapper(env)
     return env
@@ -266,13 +266,13 @@ def observation(self, observation):
 def make_metaworld(task_name: str, **kwargs) -> Env:
     import metaworld  # noqa
 
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = ContinuousObservationDictWrapper(env)
     return env
 
 
 def make_mujoco(task_name: str, **kwargs) -> Env:
-    env = gym.make(TASK_TO_ENV_MAPPING[task_name], **kwargs)
+    env = gym.make(TASK_NAME_TO_ENV_ID[task_name], **kwargs)
     env = ContinuousObservationDictWrapper(env)
     return env
 
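As a quick illustration (not part of this commit), environments are now resolved by task name through the renamed mapping; a hedged sketch using only names visible in the diffs (make and get_task_names are imported from .envs.core in gia_agent.py), with "atari-alien" as an example key:

# Hedged sketch of task-name based environment construction after this change.
from gia.eval.rl.envs.core import TASK_NAME_TO_ENV_ID, get_task_names, make

print(TASK_NAME_TO_ENV_ID["atari-alien"])  # "ALE/Alien-v5"
print(get_task_names()[:3])                # first few task names from the mapping
env = make("atari-alien")                  # presumably dispatches to make_atari
observation, info = env.reset()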
243 changes: 178 additions & 65 deletions gia/eval/rl/gia_agent.py
@@ -1,100 +1,213 @@
from typing import List, Tuple

import numpy as np
import torch
from datasets import load_dataset
from gymnasium import spaces
from torch import Tensor

from gia.datasets import GiaDataCollator, Prompter
from gia.model.gia_model import GiaModel
from gia.processing import GiaProcessor

from .envs.core import get_task_names, make


class GiaAgent:
r"""
An RL agent that uses Gia to generate actions.
Warning:
The agent caches past key values from the model. This means that when you call `get_action` multiple times in
succession, the agent will generate actions based on all the previous actions passed to `get_action`. If you
want to reset the agent to generate actions based on the initial prompt, you need to call the `reset` method.
Args:
pretrained_model_name_or_path (`str` or `os.PathLike`):
Can be either:
- A string, the *model id* of a pretrained model configuration hosted inside a model repo on
huggingface.co.
- A path to a *directory* containing a configuration file saved using the
[`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
e.g., `./my_model_directory/`.
- A path or url to a saved configuration JSON *file*, e.g.,
`./my_model_directory/configuration.json`.
task_name (`str`):
The environment id. Check the available task names with `GiaAgent.get_available_task_names()`.
collator (`GiaDataCollator`):
The GiaDataCollator to use for collating processed observations.
observation_space (`spaces.Space`):
The observation space.
action_space (`spaces.Space`):
The action space.
prompter (`Prompter`, *optional*, defaults to None):
The Prompter to use for generating prompts. When None, the agent will not use prompts. Defaults to None.
deterministic (`bool`, *optional*, defaults to False):
Whether to use deterministic action generation. Defaults to False.
"""

def __init__(
self,
task: str,
model: GiaModel,
obs_space,
action_space,
use_separator: bool = True,
):
processor: GiaProcessor,
task_name: str,
num_envs: int = 1,
use_prompt: bool = True,
p_prompt: float = 0.25,
p_end: float = 0.1,
min_prompt_len: int = 1,
max_prompt_len: int = 1024,
deterministic: bool = False,
) -> None:
self.processor = processor
self.model = model
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

self._num_obs_tokens = obs_space.shape[1]
self._num_act_tokens = action_space.shape[1]
self._tokens_per_step = self._num_obs_tokens + self._num_act_tokens + int(use_separator)
self._int_per_seq = (self.model.config.seq_len // self._tokens_per_step) - 1

dataset = load_dataset("gia-project/gia-dataset", task, split="test")
self.prompter = Prompter(
dataset,
min_prompt_len=self._int_per_seq,
max_prompt_len=self._int_per_seq,
)

self._max_kv_size = self._int_per_seq * self._tokens_per_step
if use_prompt:
dataset = load_dataset("gia-project/gia-dataset", task_name, split="test", writer_batch_size=1)
self.prompter = Prompter(dataset, p_prompt, p_end, min_prompt_len, max_prompt_len)
else:
self.prompter = None

self.processor = GiaProcessor()
self.collator = GiaDataCollator()

self.num_envs = num_envs

# Get observation and action space
env = make(task_name)
self.observation_space = env.observation_space
self.action_space = env.action_space

self.deterministic = deterministic
self.device = next(self.model.parameters()).device
self._max_length = self.model.config.max_position_embeddings - 100 # TODO: check this

if not isinstance(self.observation_space, spaces.Dict):
raise TypeError("Unsupported observation space")

if isinstance(self.action_space, spaces.Box):
self._num_act_tokens = self.action_space.shape[0]
elif isinstance(self.action_space, spaces.Discrete):
self._num_act_tokens = 1
else:
raise TypeError("Unsupported action space")

@staticmethod
def get_available_task_names() -> List[str]:
"""
Returns the available task names.
Returns:
List[str]: The available task names.
"""
return get_task_names()

def _truncate_past_key_values(
self, past_key_values: Tuple[Tuple[Tensor, Tensor], ...]
) -> Tuple[Tuple[Tensor, Tensor], ...]:
return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)

def reset(self) -> None:
prompt = self.prompter.generate_prompts(1)
prompt_observations = np.array([prompt["continuous_observations"][0]])
prompt_actions = np.array([prompt["continuous_actions"][0]])
if self.prompter is not None:
prompts = self.prompter.generate_prompts(self.num_envs)
processed_prompt = self.processor(
**prompts,
padding=False,
truncation="max_length",
truncation_side="left",
max_length=self._max_length,
)
processed_prompt = self.collator(
[
{key: processed_prompt[key][ep_idx] for key in processed_prompt.keys()}
for ep_idx in range(self.num_envs)
]
)
for key in processed_prompt.keys():
processed_prompt[key] = processed_prompt[key].to(self.device)
output = self.model(**processed_prompt, use_cache=True)
self._past_key_values = self._truncate_past_key_values(output.past_key_values)
else:
self._past_key_values = None

processed_prompt = self.processor(
continuous_observations=prompt_observations,
continuous_actions=prompt_actions,
padding=False,
truncation="max_length",
truncation_side="left",
max_length=self.model.config.seq_len - self._num_act_tokens,
)
def get_action(self, observations: np.ndarray) -> np.ndarray:
"""
Predicts the next action given the current observation.
processed_prompt = self.collator([{key: processed_prompt[key][0] for key in processed_prompt.keys()}])
for key in processed_prompt.keys():
processed_prompt[key] = processed_prompt[key].to(self.device)
output = self.model(**processed_prompt, use_cache=True)
self._past_key_values = output.past_key_values
Args:
observations (np.ndarray): The current observation
def get_action(self, obs) -> np.ndarray:
Returns:
np.ndarray: The next action
"""
# Turn into episode
keys = observations[0].keys()
dict_observations = {}
for key in keys:
values = [obs[key] for obs in observations]
if isinstance(values[0], np.ndarray):
dict_observations[key] = np.expand_dims(np.stack(values), axis=1)
elif isinstance(values[0], str):
dict_observations[key] = [[value] for value in values]
else:
raise TypeError(f"Unsupported type for {key}")

# Process observations
processed = self.processor(
continuous_observations=[obs],
continuous_actions=[],
**dict_observations,
padding=False,
truncation="max_length",
truncation_side="left",
max_length=self.model.config.seq_len - self._num_act_tokens, # ensure not to overflow
max_length=self._max_length, # ensure not to overflow
)

# Process and move to device
num_envs = len(processed["input_types"])
processed = self.collator(
[{key: processed[key][ep_idx] for key in processed.keys()} for ep_idx in range(num_envs)]
)
processed = self.collator([{key: processed[key][0] for key in processed.keys()}])
for key in processed.keys():
processed[key] = processed[key].to(self.device)
action_tokens = []

for i in range(self._num_act_tokens):
# Generate action tokens
action_tokens = []
for _ in range(self._num_act_tokens):
output = self.model(**processed, use_cache=True, past_key_values=self._past_key_values)
self._past_key_values = output.past_key_values
self._past_key_values = self._truncate_past_key_values(output.past_key_values)
action_logits = output.logits[:, -1]
action_token = torch.argmax(action_logits, dim=1)
action_tokens.append(action_token)

processed["input_ids"] = action_token[None, :]
if i == 0: # only needs to be done once
processed["loss_mask"] = torch.ones(1, 1, dtype=torch.bool, device=self.device)
processed["input_types"] = torch.zeros(1, 1, dtype=torch.int64, device=self.device)
processed["local_positions"] = -torch.ones(1, 1, dtype=torch.int64, device=self.device)
if self.deterministic:
action_token = torch.argmax(action_logits, dim=-1)
else:
action_token = torch.multinomial(torch.softmax(action_logits, dim=-1), num_samples=1).squeeze(-1)
action_tokens.append(action_token.tolist())

processed = {
"input_ids": action_token[:, None],
"loss_mask": torch.ones(num_envs, 1, dtype=torch.bool, device=self.device),
"input_types": torch.zeros(num_envs, 1, dtype=torch.int64, device=self.device),
"local_positions": -torch.ones(num_envs, 1, dtype=torch.int64, device=self.device),
}

# to ensure the KV cache includes the last action token
# To ensure the KV cache includes the last action token
output = self.model(**processed, use_cache=True, past_key_values=self._past_key_values)
self._past_key_values = output.past_key_values
if self._past_key_values[0][0].shape[2] > self._max_kv_size:
# remove one step of tokens, to ensure context < 1024
self._past_key_values = [
(k[:, :, self._tokens_per_step :], v[:, :, self._tokens_per_step :])
for (k, v) in self._past_key_values
]
action_tokens = torch.stack(action_tokens, dim=-1).cpu().tolist()

# Decode the action tokens
action = np.array(self.processor.decode_continuous(action_tokens))
# TODO: Clamp action to be in domain of action space?
return action
self._past_key_values = self._truncate_past_key_values(output.past_key_values)

# Transpose action_tokens to be (num_envs, num_act_tokens)
action_tokens = np.array(action_tokens, dtype=self.action_space.dtype).T.tolist()

if isinstance(self.action_space, spaces.Box):
# Decode the action tokens
actions = np.array(self.processor.decode_continuous(action_tokens), dtype=self.action_space.dtype)

# Clip the action if necessary
actions = np.clip(actions, self.action_space.low, self.action_space.high)

elif isinstance(self.action_space, spaces.Discrete):
# Decode the action tokens
actions = np.array(self.processor.decode_discrete(action_tokens), dtype=self.action_space.dtype)
actions = actions.squeeze(axis=1)

# Clip the action if necessary (decoded actions are between 0 and num_bins)
actions = np.clip(actions, a_min=0, a_max=self.action_space.n - 1)
return actions
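To close the loop, a minimal rollout sketch (not part of this commit) for the reworked GiaAgent, based on the constructor and get_action shown above; how the trained GiaModel is loaded, and the "mujoco-ant" task name, are assumptions:

# Hedged sketch: drive the new GiaAgent in a single-environment rollout.
from gia.eval.rl import make
from gia.eval.rl.gia_agent import GiaAgent
from gia.processing import GiaProcessor

model = ...  # assumption: a trained GiaModel, loaded however the project does it
processor = GiaProcessor()
task_name = "mujoco-ant"  # illustrative; see GiaAgent.get_available_task_names()

env = make(task_name)
agent = GiaAgent(model, processor, task_name, num_envs=1, deterministic=True)

observation, _ = env.reset()
agent.reset()  # samples a prompt (if enabled) and primes the KV cache
done = False
while not done:
    actions = agent.get_action([observation])  # one observation dict per env
    observation, reward, terminated, truncated, _ = env.step(actions[0])
    done = terminated or truncated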