* generic agent
* babyai wrapper
* Gia agent with test
* improve docstring and type hints
* generic make
* metaworld gymnasium
* new metaworld style
* update generate dataset
* try to fix generate dataset
* get_task_names
* get_task_names for test
* some modifications in GiaAgent
* new make function
* use ale-v5 for Atari
* general Gia agent
* temporary fix
* lr scheduler
* lr scheduler
* update test_core
* fix scheduler
* drop sample_factory dependency
* remove babyai test
* remove unused imports
* reorder imports
* TASK_NAME_TO_ENV_ID
* fix get_evaluator
* remove gym evaluator
* use make in the RL evaluator
* task -> task_name
* RLEvaluator
* fix case
* test_split in LanguageModelingEvaluator
* Gia agent for images
* fix RLEvaluator._evaluate
* fix RLEvaluator
* fix RL evaluator
1 parent 7d23690 · commit dc7ba8c

Showing 15 changed files with 264 additions and 123 deletions.
```diff
@@ -5,7 +5,8 @@ __pycache__/
 
 # C extensions
 *.so
 
+checkpoints
+logs
 # Distribution / packaging
 .Python
 build/
```
This file was deleted.
```diff
@@ -1,5 +1,5 @@
 from .envs import make
-from .gym_evaluator import GymEvaluator
+from .rl_evaluator import RLEvaluator
 
-__all__ = ["make", "GymEvaluator"]
+__all__ = ["make", "RLEvaluator"]
```
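For downstream code, the exported evaluator class changes accordingly. A minimal sketch of the updated import; the package path of this `__init__.py` is not shown in the diff, so `gia.eval` is an assumption for illustration:

```python
# Before this commit: from gia.eval import GymEvaluator  (now removed)
# After this commit, RLEvaluator is the exported evaluator class.
from gia.eval import RLEvaluator, make  # "gia.eval" is an assumed package path
```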
```diff
@@ -1,100 +1,213 @@
 from typing import List, Tuple
 
 import numpy as np
 import torch
 from datasets import load_dataset
+from gymnasium import spaces
 from torch import Tensor
 
 from gia.datasets import GiaDataCollator, Prompter
 from gia.model.gia_model import GiaModel
 from gia.processing import GiaProcessor
 
+from .envs.core import get_task_names, make
+
 
 class GiaAgent:
+    r"""
+    An RL agent that uses Gia to generate actions.
+
+    Warning:
+        The agent caches past key values from the model. This means that when you call `get_action` multiple times in
+        succession, the agent will generate actions based on all the previous actions passed to `get_action`. If you
+        want to reset the agent to generate actions based on the initial prompt, you need to call the `reset` method.
+
+    Args:
+        pretrained_model_name_or_path (`str` or `os.PathLike`):
+            Can be either:
+
+            - A string, the *model id* of a pretrained model configuration hosted inside a model repo on
+              huggingface.co.
+            - A path to a *directory* containing a configuration file saved using the
+              [`~PretrainedConfig.save_pretrained`] method, or the [`~PreTrainedModel.save_pretrained`] method,
+              e.g., `./my_model_directory/`.
+            - A path or url to a saved configuration JSON *file*, e.g., `./my_model_directory/configuration.json`.
+        task_name (`str`):
+            The environment id. Check the available task names with `GiaAgent.get_available_task_names()`.
+        collator (`GiaDataCollator`):
+            The GiaDataCollator to use for collating processed observations.
+        observation_space (`spaces.Space`):
+            The observation space.
+        action_space (`spaces.Space`):
+            The action space.
+        prompter (`Prompter`, *optional*, defaults to None):
+            The Prompter to use for generating prompts. When None, the agent will not use prompts. Defaults to None.
+        deterministic (`bool`, *optional*, defaults to False):
+            Whether to use deterministic action generation. Defaults to False.
+    """
+
     def __init__(
         self,
-        task: str,
         model: GiaModel,
-        obs_space,
-        action_space,
-        use_separator: bool = True,
-    ):
+        processor: GiaProcessor,
+        task_name: str,
+        num_envs: int = 1,
+        use_prompt: bool = True,
+        p_prompt: float = 0.25,
+        p_end: float = 0.1,
+        min_prompt_len: int = 1,
+        max_prompt_len: int = 1024,
+        deterministic: bool = False,
+    ) -> None:
+        self.processor = processor
         self.model = model
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-        self._num_obs_tokens = obs_space.shape[1]
-        self._num_act_tokens = action_space.shape[1]
-        self._tokens_per_step = self._num_obs_tokens + self._num_act_tokens + int(use_separator)
-        self._int_per_seq = (self.model.config.seq_len // self._tokens_per_step) - 1
-
-        dataset = load_dataset("gia-project/gia-dataset", task, split="test")
-        self.prompter = Prompter(
-            dataset,
-            min_prompt_len=self._int_per_seq,
-            max_prompt_len=self._int_per_seq,
-        )
-
-        self._max_kv_size = self._int_per_seq * self._tokens_per_step
-
-        self.processor = GiaProcessor()
+
+        if use_prompt:
+            dataset = load_dataset("gia-project/gia-dataset", task_name, split="test", writer_batch_size=1)
+            self.prompter = Prompter(dataset, p_prompt, p_end, min_prompt_len, max_prompt_len)
+        else:
+            self.prompter = None
+
         self.collator = GiaDataCollator()
 
+        self.num_envs = num_envs
+
+        # Get observation and action space
+        env = make(task_name)
+        self.observation_space = env.observation_space
+        self.action_space = env.action_space
+
+        self.deterministic = deterministic
+        self.device = next(self.model.parameters()).device
+        self._max_length = self.model.config.max_position_embeddings - 100  # TODO: check this
+
+        if not isinstance(self.observation_space, spaces.Dict):
+            raise TypeError("Unsupported observation space")
+
+        if isinstance(self.action_space, spaces.Box):
+            self._num_act_tokens = self.action_space.shape[0]
+        elif isinstance(self.action_space, spaces.Discrete):
+            self._num_act_tokens = 1
+        else:
+            raise TypeError("Unsupported action space")
+
+    @staticmethod
+    def get_available_task_names() -> List[str]:
+        """
+        Returns the available task names.
+
+        Returns:
+            List[str]: The available task names.
+        """
+        return get_task_names()
+
+    def _truncate_past_key_values(
+        self, past_key_values: Tuple[Tuple[Tensor, Tensor], ...]
+    ) -> Tuple[Tuple[Tensor, Tensor], ...]:
+        return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
+
     def reset(self) -> None:
-        prompt = self.prompter.generate_prompts(1)
-        prompt_observations = np.array([prompt["continuous_observations"][0]])
-        prompt_actions = np.array([prompt["continuous_actions"][0]])
-
-        processed_prompt = self.processor(
-            continuous_observations=prompt_observations,
-            continuous_actions=prompt_actions,
-            padding=False,
-            truncation="max_length",
-            truncation_side="left",
-            max_length=self.model.config.seq_len - self._num_act_tokens,
-        )
-
-        processed_prompt = self.collator([{key: processed_prompt[key][0] for key in processed_prompt.keys()}])
-        for key in processed_prompt.keys():
-            processed_prompt[key] = processed_prompt[key].to(self.device)
-        output = self.model(**processed_prompt, use_cache=True)
-        self._past_key_values = output.past_key_values
-
-    def get_action(self, obs) -> np.ndarray:
+        if self.prompter is not None:
+            prompts = self.prompter.generate_prompts(self.num_envs)
+            processed_prompt = self.processor(
+                **prompts,
+                padding=False,
+                truncation="max_length",
+                truncation_side="left",
+                max_length=self._max_length,
+            )
+            processed_prompt = self.collator(
+                [
+                    {key: processed_prompt[key][ep_idx] for key in processed_prompt.keys()}
+                    for ep_idx in range(self.num_envs)
+                ]
+            )
+            for key in processed_prompt.keys():
+                processed_prompt[key] = processed_prompt[key].to(self.device)
+            output = self.model(**processed_prompt, use_cache=True)
+            self._past_key_values = self._truncate_past_key_values(output.past_key_values)
+        else:
+            self._past_key_values = None
+
+    def get_action(self, observations: np.ndarray) -> np.ndarray:
+        """
+        Predicts the next action given the current observation.
+
+        Args:
+            observations (np.ndarray): The current observation
+
+        Returns:
+            np.ndarray: The next action
+        """
+        # Turn into episode
+        keys = observations[0].keys()
+        dict_observations = {}
+        for key in keys:
+            values = [obs[key] for obs in observations]
+            if isinstance(values[0], np.ndarray):
+                dict_observations[key] = np.expand_dims(np.stack(values), axis=1)
+            elif isinstance(values[0], str):
+                dict_observations[key] = [[value] for value in values]
+            else:
+                raise TypeError(f"Unsupported type for {key}")
+
         # Process observations
         processed = self.processor(
-            continuous_observations=[obs],
-            continuous_actions=[],
+            **dict_observations,
            padding=False,
             truncation="max_length",
             truncation_side="left",
-            max_length=self.model.config.seq_len - self._num_act_tokens,  # ensure not to not overflow
+            max_length=self._max_length,  # ensure we do not overflow
         )
 
-        processed = self.collator([{key: processed[key][0] for key in processed.keys()}])
+        # Process and move to device
+        num_envs = len(processed["input_types"])
+        processed = self.collator(
+            [{key: processed[key][ep_idx] for key in processed.keys()} for ep_idx in range(num_envs)]
+        )
         for key in processed.keys():
             processed[key] = processed[key].to(self.device)
-        action_tokens = []
 
-        for i in range(self._num_act_tokens):
+        # Generate action tokens
+        action_tokens = []
+        for _ in range(self._num_act_tokens):
             output = self.model(**processed, use_cache=True, past_key_values=self._past_key_values)
-            self._past_key_values = output.past_key_values
+            self._past_key_values = self._truncate_past_key_values(output.past_key_values)
             action_logits = output.logits[:, -1]
-            action_token = torch.argmax(action_logits, dim=1)
-            action_tokens.append(action_token)
-
-            processed["input_ids"] = action_token[None, :]
-            if i == 0:  # only needs to be done once
-                processed["loss_mask"] = torch.ones(1, 1, dtype=torch.bool, device=self.device)
-                processed["input_types"] = torch.zeros(1, 1, dtype=torch.int64, device=self.device)
-                processed["local_positions"] = -torch.ones(1, 1, dtype=torch.int64, device=self.device)
+            if self.deterministic:
+                action_token = torch.argmax(action_logits, dim=-1)
+            else:
+                action_token = torch.multinomial(torch.softmax(action_logits, dim=-1), num_samples=1).squeeze(-1)
+            action_tokens.append(action_token.tolist())
+
+            processed = {
+                "input_ids": action_token[:, None],
+                "loss_mask": torch.ones(num_envs, 1, dtype=torch.bool, device=self.device),
+                "input_types": torch.zeros(num_envs, 1, dtype=torch.int64, device=self.device),
+                "local_positions": -torch.ones(num_envs, 1, dtype=torch.int64, device=self.device),
+            }
 
-        # to ensure the KV cache includes the last action token
+        # To ensure the KV cache includes the last action token
         output = self.model(**processed, use_cache=True, past_key_values=self._past_key_values)
-        self._past_key_values = output.past_key_values
-        if self._past_key_values[0][0].shape[2] > self._max_kv_size:
-            # remove one step of tokens, to ensure context < 1024
-            self._past_key_values = [
-                (k[:, :, self._tokens_per_step :], v[:, :, self._tokens_per_step :])
-                for (k, v) in self._past_key_values
-            ]
-        action_tokens = torch.stack(action_tokens, dim=-1).cpu().tolist()
-
-        # Decode the action tokens
-        action = np.array(self.processor.decode_continuous(action_tokens))
-        # TODO: Clamp action to be in domain of action space?
-        return action
+        self._past_key_values = self._truncate_past_key_values(output.past_key_values)
+
+        # Transpose action_tokens to be (num_envs, num_act_tokens)
+        action_tokens = np.array(action_tokens, dtype=self.action_space.dtype).T.tolist()
+
+        if isinstance(self.action_space, spaces.Box):
+            # Decode the action tokens
+            actions = np.array(self.processor.decode_continuous(action_tokens), dtype=self.action_space.dtype)
+
+            # Clip the action if necessary
+            actions = np.clip(actions, self.action_space.low, self.action_space.high)
+
+        elif isinstance(self.action_space, spaces.Discrete):
+            # Decode the action tokens
+            actions = np.array(self.processor.decode_discrete(action_tokens), dtype=self.action_space.dtype)
+            actions = actions.squeeze(axis=1)
+
+            # Clip the action if necessary (decoded actions are between 0 and num_bins)
+            actions = np.clip(actions, a_min=0, a_max=self.action_space.n - 1)
+
+        return actions
```
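The docstring warning above has a practical consequence: `get_action` keeps extending the cached past key values, so `reset` must be called at every episode boundary, not just once. A minimal usage sketch under stated assumptions: the checkpoint id and the import paths for `GiaAgent` and `make` do not appear in this diff and are hypothetical, and `GiaModel` is assumed to follow the usual `from_pretrained` convention; only `gia.model.gia_model` and `gia.processing` are confirmed by the imports above.

```python
from gia.agent import GiaAgent  # hypothetical import path
from gia.eval.rl.envs.core import make  # hypothetical import path
from gia.model.gia_model import GiaModel
from gia.processing import GiaProcessor

model = GiaModel.from_pretrained("gia-project/gia")  # hypothetical checkpoint id
processor = GiaProcessor()

# Pick any task exposed through get_task_names()
task_name = GiaAgent.get_available_task_names()[0]
agent = GiaAgent(model, processor, task_name, num_envs=1, deterministic=True)

env = make(task_name)
observation, _ = env.reset()
agent.reset()  # builds the prompt (if any) and initializes the KV cache

done = False
while not done:
    # get_action expects a list with one observation dict per environment
    action = agent.get_action([observation])[0]
    observation, reward, terminated, truncated, _ = env.step(action)
    done = terminated or truncated

env.close()
```

With `num_envs > 1`, pass one observation dict per vectorized environment on every call, since the shared KV cache is built with a fixed batch size.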