diff --git a/scripts/enjoy.py b/scripts/enjoy.py index b4d56ba..9f7d95a 100644 --- a/scripts/enjoy.py +++ b/scripts/enjoy.py @@ -9,6 +9,7 @@ NTupleNetworkBasePolicy, NTupleNetworkQLearningPolicy, NTupleNetworkTDPolicy, + NTupleNetworkTDPolicySmall, ) @@ -21,7 +22,7 @@ def parse_args() -> argparse.Namespace: "--algo", default="tdl", help="RL Algorithm", - choices=["ql", "tdl"], + choices=["ql", "tdl", "tdl-small"], ) parser.add_argument( "--env", @@ -62,6 +63,7 @@ def make_policy(algo: str, trained_agent: str) -> NTupleNetworkBasePolicy: algo_policy_map = { "ql": NTupleNetworkQLearningPolicy, "tdl": NTupleNetworkTDPolicy, + "tdl-small": NTupleNetworkTDPolicySmall, } policy = algo_policy_map[algo] return policy.load(trained_agent)