diff --git a/README.md b/README.md index 7c8a82a704..595da21956 100644 --- a/README.md +++ b/README.md @@ -273,6 +273,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo` | 29 |[dmc2gym](https://github.com/denisyarats/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/dmc2gym/dmc2gym_cheetah.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/dmc2gym)
[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/dmc2gym.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/dmc2gym_zh.html) | | 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs)
环境指南 | | 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym-pybullet-drones/gym-pybullet-drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)
环境指南 | +| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)
环境指南 | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space diff --git a/ding/envs/env_manager/base_env_manager.py b/ding/envs/env_manager/base_env_manager.py index e54eafe926..47dfdd5732 100644 --- a/ding/envs/env_manager/base_env_manager.py +++ b/ding/envs/env_manager/base_env_manager.py @@ -203,7 +203,9 @@ def done(self) -> bool: @property def method_name_list(self) -> list: - return ['reset', 'step', 'seed', 'close', 'enable_save_replay', 'render'] + return [ + 'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure' + ] def env_state_done(self, env_id: int) -> bool: return self._env_states[env_id] == EnvState.DONE @@ -418,6 +420,19 @@ def enable_save_replay(self, replay_path: Union[List[str], str]) -> None: replay_path = [replay_path] * self.env_num self._env_replay_path = replay_path + def enable_save_figure(self, env_id: int, figure_path: Union[List[str], str]) -> None: + """ + Overview: + Set each env's replay save path. + Arguments: + - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \ + Or one path for all environments. + """ + if isinstance(figure_path, str): + self._env[env_id].enable_save_figure(figure_path) + else: + raise TypeError("invalid figure_path arguments type: {}".format(type(figure_path))) + def close(self) -> None: """ Overview: @@ -431,6 +446,9 @@ def close(self) -> None: self._env_states[i] = EnvState.VOID self._closed = True + def reward_shaping(self, env_id: int, transitions: List[dict]) -> List[dict]: + return self._envs[env_id].reward_shaping(transitions) + @property def closed(self) -> bool: return self._closed diff --git a/ding/worker/collector/episode_serial_collector.py b/ding/worker/collector/episode_serial_collector.py index 9fcf733779..6d89462bf0 100644 --- a/ding/worker/collector/episode_serial_collector.py +++ b/ding/worker/collector/episode_serial_collector.py @@ -21,7 +21,9 @@ class EpisodeSerialCollector(ISerialCollector): envstep """ - config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False) + config = dict( + deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False, reward_shaping=False + ) def __init__( self, @@ -251,6 +253,8 @@ def collect(self, # prepare data if timestep.done: transitions = to_tensor_transitions(self._traj_buffer[env_id]) + if self._cfg.reward_shaping: + self._env.reward_shaping(env_id, transitions) if self._cfg.get_train_sample: train_sample = self._policy.get_train_sample(transitions) return_data.extend(train_sample) diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py index 3e4645260a..32222c0804 100644 --- a/ding/worker/collector/interaction_serial_evaluator.py +++ b/ding/worker/collector/interaction_serial_evaluator.py @@ -241,6 +241,9 @@ def eval( continue if t.done: # Env reset is done by env_manager automatically. 
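+                        # when a figure_path is configured, ask the env manager to save this env's evaluation figures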
+ if 'figure_path' in self._cfg: + if self._cfg.figure_path is not None: + self._env.enable_save_figure(env_id, self._cfg.figure_path) self._policy.reset([env_id]) reward = t.info['final_eval_reward'] if 'episode_info' in t.info: diff --git a/dizoo/beergame/beergame.png b/dizoo/beergame/beergame.png new file mode 100644 index 0000000000..ec76d6e5ca Binary files /dev/null and b/dizoo/beergame/beergame.png differ diff --git a/dizoo/beergame/config/beergame_onppo_config.py b/dizoo/beergame/config/beergame_onppo_config.py new file mode 100644 index 0000000000..7ad87a23fe --- /dev/null +++ b/dizoo/beergame/config/beergame_onppo_config.py @@ -0,0 +1,70 @@ +from easydict import EasyDict + +beergame_ppo_config = dict( + exp_name='beergame_ppo_seed0', + env=dict( + collector_env_num=8, + evaluator_env_num=8, + n_evaluator_episode=8, + stop_value=200, + role=0, # 0-3 : retailer, warehouse, distributor, manufacturer + agent_type='bs', + # type of co-player, 'bs'- base stock, 'Strm'- use Sterman formula to model typical human behavior + demandDistribution=0 + # distribution of demand, default=0, '0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data' + ), + policy=dict( + cuda=True, + recompute_adv=True, + action_space='discrete', + model=dict( + obs_shape=50, # statedim * multPerdInpt= 5 * 10 + action_shape=5, # the quantity relative to the arriving order + action_space='discrete', + encoder_hidden_size_list=[64, 64, 128], + actor_head_hidden_size=128, + critic_head_hidden_size=128, + ), + learn=dict( + epoch_per_collect=10, + batch_size=320, + learning_rate=3e-4, + entropy_weight=0.001, + adv_norm=True, + value_norm=True, + # for onppo, when we recompute adv, we need the key done in data to split traj, so we must + # use ignore_done=False here, + # but when we add key traj_flag in data as the backup for key done, we could choose to use ignore_done=True + # for halfcheetah, the length=1000 + ignore_done=True, + ), + collect=dict( + n_episode=8, + discount_factor=0.99, + gae_lambda=0.95, + collector=dict( + get_train_sample=True, + reward_shaping=True, # whether use total reward to reshape reward + ), + ), + eval=dict(evaluator=dict(eval_freq=500, )), + ), +) +beergame_ppo_config = EasyDict(beergame_ppo_config) +main_config = beergame_ppo_config +beergame_ppo_create_config = dict( + env=dict( + type='beergame', + import_names=['dizoo.beergame.envs.beergame_env'], + ), + env_manager=dict(type='base'), + policy=dict(type='ppo'), + collector=dict(type='episode', ), +) +beergame_ppo_create_config = EasyDict(beergame_ppo_create_config) +create_config = beergame_ppo_create_config + +if __name__ == "__main__": + # or you can enter `ding -m serial_onpolicy -c beergame_onppo_config.py -s 0` + from ding.entry import serial_pipeline_onpolicy + serial_pipeline_onpolicy([main_config, create_config], seed=0) diff --git a/dizoo/beergame/entry/beergame_eval.py b/dizoo/beergame/entry/beergame_eval.py new file mode 100644 index 0000000000..5299107e78 --- /dev/null +++ b/dizoo/beergame/entry/beergame_eval.py @@ -0,0 +1,42 @@ +import os +import torch +from tensorboardX import SummaryWriter + +from ding.config import compile_config +from ding.worker import InteractionSerialEvaluator +from ding.envs import BaseEnvManager +from ding.policy import PPOPolicy +from ding.model import VAC +from ding.utils import set_pkg_seed +from dizoo.beergame.config.beergame_onppo_config import beergame_ppo_config, beergame_ppo_create_config +from ding.envs import get_vec_env_setting +from 
functools import partial + + +def main(cfg, seed=0): + env_fn = None + cfg, create_cfg = beergame_ppo_config, beergame_ppo_create_config + cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True) + collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num + + env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env) + cfg.env.manager.auto_reset = False + evaluator_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager) + evaluator_env.seed(seed, dynamic_seed=False) + set_pkg_seed(seed, use_cuda=cfg.policy.cuda) + model = VAC(**cfg.policy.model) + tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) + policy = PPOPolicy(cfg.policy, model=model) + # set the path to save figure + cfg.policy.eval.evaluator.figure_path = './' + evaluator = InteractionSerialEvaluator( + cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name + ) + # load model + model.load_state_dict(torch.load('model path', map_location='cpu')["model"]) + evaluator.eval(None, -1, -1) + + +if __name__ == "__main__": + beergame_ppo_config.exp_name = 'beergame_evaluate' + main(beergame_ppo_config) \ No newline at end of file diff --git a/dizoo/beergame/envs/BGAgent.py b/dizoo/beergame/envs/BGAgent.py new file mode 100644 index 0000000000..7c4088c1d0 --- /dev/null +++ b/dizoo/beergame/envs/BGAgent.py @@ -0,0 +1,152 @@ +# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL. +import argparse +import numpy as np + + +# Here we want to define the agent class for the BeerGame +class Agent(object): + # initializes the agents with initial values for IL, OO and saves self.agentNum for recognizing the agents. + def __init__( + self, agentNum: int, IL: int, AO: int, AS: int, c_h: float, c_p: float, eta: int, compuType: str, + config: argparse.Namespace + ) -> None: + self.agentNum = agentNum + self.IL = IL # Inventory level of each agent - changes during the game + self.OO = 0 # Open order of each agent - changes during the game + self.ASInitial = AS # the initial arriving shipment. + self.ILInitial = IL # IL at which we start each game with this number + self.AOInitial = AO # OO at which we start each game with this number + self.config = config # an instance of config is stored inside the class + self.curState = [] # this function gets the current state of the game + self.nextState = [] + self.curReward = 0 # the reward observed at the current step + self.cumReward = 0 # cumulative reward; reset at the begining of each episode + self.totRew = 0 # it is reward of all players obtained for the current player. 
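+        # cost parameters: c_h is charged per unit of on-hand inventory, c_p per unit of backorder,
+        # and eta weights the shared team-cost correction applied in getTotRew()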
+ self.c_h = c_h # holding cost + self.c_p = c_p # backorder cost + self.eta = eta # the total cost regulazer + self.AS = np.zeros((1, 1)) # arriced shipment + self.AO = np.zeros((1, 1)) # arrived order + self.action = 0 # the action at time t + self.compType = compuType + # self.compTypeTrain = compuType # rnd -> random / srdqn-> srdqn / Strm-> formula-Rong2008 / bs -> optimal policy if exists + # self.compTypeTest = compuType # rnd -> random / srdqn-> srdqn / Strm-> formula-Rong2008 / bs -> optimal policy if exists + self.alpha_b = self.config.alpha_b[self.agentNum] # parameters for the formula + self.betta_b = self.config.betta_b[self.agentNum] # parameters for the formula + if self.config.demandDistribution == 0: + self.a_b = np.mean((self.config.demandUp, self.config.demandLow)) # parameters for the formula + self.b_b = np.mean((self.config.demandUp, self.config.demandLow)) * ( + np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) + + np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum])) + ) # parameters for the formula + elif self.config.demandDistribution == 1 or self.config.demandDistribution == 3 or self.config.demandDistribution == 4: + self.a_b = self.config.demandMu # parameters for the formula + self.b_b = self.config.demandMu * ( + np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) + + np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum])) + ) # parameters for the formula + elif self.config.demandDistribution == 2: + self.a_b = 8 # parameters for the formula + self.b_b = (3 / 4.) * 8 * ( + np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) + + np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum])) + ) # parameters for the formula + elif self.config.demandDistribution == 3: + self.a_b = 10 # parameters for the formula + self.b_b = 7 * ( + np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) + + np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum])) + ) # parameters for the formula + else: + raise Exception('The demand distribution is not defined or it is not a valid type.!') + + self.hist = [] # this is used for plotting - keeps the history for only one game + self.hist2 = [] # this is used for animation usage + self.srdqnBaseStock = [] # this holds the base stock levels that srdqn has came up with. added on Nov 8, 2017 + self.T = 0 + self.bsBaseStock = 0 + self.init_bsBaseStock = 0 + self.nextObservation = [] + + if self.compType == 'srdqn': + # sets the initial input of the network + self.currentState = np.stack( + [self.curState for _ in range(self.config.multPerdInpt)], axis=0 + ) # multPerdInpt observations stacked. 
each row is an observation + + # reset player information + def resetPlayer(self, T: int): + self.IL = self.ILInitial + self.OO = 0 + self.AS = np.squeeze( + np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10)) + ) # arriced shipment + self.AO = np.squeeze( + np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10)) + ) # arrived order + if self.agentNum != 0: + for i in range(self.config.leadRecOrderUp_aux[self.agentNum - 1]): + self.AO[i] = self.AOInitial[self.agentNum - 1] + for i in range(self.config.leadRecItemUp[self.agentNum]): + self.AS[i] = self.ASInitial + self.curReward = 0 # the reward observed at the current step + self.cumReward = 0 # cumulative reward; reset at the begining of each episode + self.action = [] + self.hist = [] + self.hist2 = [] + self.srdqnBaseStock = [] # this holds the base stock levels that srdqn has came up with. added on Nov 8, 2017 + self.T = T + self.curObservation = self.getCurState(1) # this function gets the current state of the game + self.nextObservation = [] + if self.compType == 'srdqn': + self.currentState = np.stack([self.curObservation for _ in range(self.config.multPerdInpt)], axis=0) + + # updates the IL and OO at time t, after recieving "rec" number of items + def recieveItems(self, time: int) -> None: + self.IL = self.IL + self.AS[time] # inverntory level update + self.OO = self.OO - self.AS[time] # invertory in transient update + + # find action Value associated with the action list + def actionValue(self, curTime: int) -> int: + if self.config.fixedAction: + a = self.config.actionList[np.argmax(self.action)] + else: + # "d + x" rule + if self.compType == 'srdqn': + a = max(0, self.config.actionList[np.argmax(self.action)] * self.config.action_step + self.AO[curTime]) + elif self.compType == 'rnd': + a = max(0, self.config.actionList[np.argmax(self.action)] + self.AO[curTime]) + else: + a = max(0, self.config.actionListOpt[np.argmax(self.action)]) + + return a + + # getReward returns the reward at the current state + def getReward(self) -> None: + # cost (holding + backorder) for one time unit + self.curReward = (self.c_p * max(0, -self.IL) + self.c_h * max(0, self.IL)) / 200. 
# self.config.Ttest # + self.curReward = -self.curReward + # make reward negative, because it is the cost + + # sum total reward of each agent + self.cumReward = self.config.gamma * self.cumReward + self.curReward + + # This function returns a np.array of the current state of the agent + def getCurState(self, t: int) -> np.ndarray: + if self.config.ifUseASAO: + if self.config.if_use_AS_t_plus_1: + curState = np.array( + [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t], self.AO[t]] + ) + else: + curState = np.array( + [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t - 1], self.AO[t]] + ) + else: + curState = np.array([-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO]) + + if self.config.ifUseActionInD: + a = self.config.actionList[np.argmax(self.action)] + curState = np.concatenate((curState, np.array([a]))) + + return curState diff --git a/dizoo/beergame/envs/__init__.py b/dizoo/beergame/envs/__init__.py new file mode 100644 index 0000000000..d4ffbfd452 --- /dev/null +++ b/dizoo/beergame/envs/__init__.py @@ -0,0 +1,2 @@ +from .clBeergame import clBeerGame +from .beergame_core import BeerGame diff --git a/dizoo/beergame/envs/beergame_core.py b/dizoo/beergame/envs/beergame_core.py new file mode 100644 index 0000000000..2f0ac61910 --- /dev/null +++ b/dizoo/beergame/envs/beergame_core.py @@ -0,0 +1,112 @@ +from __future__ import print_function +from dizoo.beergame.envs import clBeerGame +from torch import Tensor +import numpy as np +import random +from .utils import get_config, update_config +import gym +import os +from typing import Optional + + +class BeerGame(): + + def __init__(self, role: int, agent_type: str, demandDistribution: int) -> None: + self._cfg, unparsed = get_config() + self._role = role + # prepare loggers and directories + # prepare_dirs_and_logger(self._cfg) + self._cfg = update_config(self._cfg) + + # set agent type + if agent_type == 'bs': + self._cfg.agentTypes = ["bs", "bs", "bs", "bs"] + elif agent_type == 'Strm': + self._cfg.agentTypes = ["Strm", "Strm", "Strm", "Strm"] + self._cfg.agentTypes[role] = "srdqn" + + self._cfg.demandDistribution = demandDistribution + + # load demands:0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data + if self._cfg.observation_data: + adsr = 'data/demandTr-obs-' + elif self._cfg.demandDistribution == 3: + if self._cfg.scaled: + adsr = 'data/basket_data/scaled' + else: + adsr = 'data/basket_data' + direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy') + self._demandTr = np.load(direc) + print("loaded training set=", direc) + elif self._cfg.demandDistribution == 4: + if self._cfg.scaled: + adsr = 'data/forecast_data/scaled' + else: + adsr = 'data/forecast_data' + direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy') + self._demandTr = np.load(direc) + print("loaded training set=", direc) + else: + if self._cfg.demandDistribution == 0: # uniform + self._demandTr = np.random.randint(0, self._cfg.demandUp, size=[self._cfg.demandSize, self._cfg.TUp]) + elif self._cfg.demandDistribution == 1: # normal distribution + self._demandTr = np.round( + np.random.normal( + self._cfg.demandMu, self._cfg.demandSigma, size=[self._cfg.demandSize, self._cfg.TUp] + ) + ).astype(int) + elif self._cfg.demandDistribution == 2: # the sequence of 4,4,4,4,8,... 
+ self._demandTr = np.concatenate( + (4 * np.ones((self._cfg.demandSize, 4)), 8 * np.ones((self._cfg.demandSize, 98))), axis=1 + ).astype(int) + + # initilize an instance of Beergame + self._env = clBeerGame(self._cfg) + self.observation_space = gym.spaces.Box( + low=float("-inf"), + high=float("inf"), + shape=(self._cfg.stateDim * self._cfg.multPerdInpt, ), + dtype=np.float32 + ) # state_space = state_dim * m (considering the reward delay) + self.action_space = gym.spaces.Discrete(self._cfg.actionListLen) # length of action list + self.reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32) + + # get the length of the demand. + self._demand_len = np.shape(self._demandTr)[0] + + def reset(self): + self._env.resetGame(demand=self._demandTr[random.randint(0, self._demand_len - 1)]) + obs = [i for item in self._env.players[self._role].currentState for i in item] + return obs + + def seed(self, seed: int) -> None: + self._seed = seed + np.random.seed(self._seed) + + def close(self) -> None: + pass + + def step(self, action: np.ndarray): + self._env.handelAction(action) + self._env.next() + newstate = np.append( + self._env.players[self._role].currentState[1:, :], [self._env.players[self._role].nextObservation], axis=0 + ) + self._env.players[self._role].currentState = newstate + obs = [i for item in newstate for i in item] + rew = self._env.players[self._role].curReward + done = (self._env.curTime == self._env.T) + info = {} + return obs, rew, done, info + + def reward_shaping(self, reward: Tensor) -> Tensor: + self._totRew, self._cumReward = self._env.distTotReward(self._role) + reward += (self._cfg.distCoeff / 3) * ((self._totRew - self._cumReward) / (self._env.T)) + return reward + + def enable_save_figure(self, figure_path: Optional[str] = None) -> None: + self._cfg.ifSaveFigure = True + if figure_path is None: + figure_path = './' + self._cfg.figure_dir = figure_path + self._env.doTestMid(self._demandTr[random.randint(0, self._demand_len - 1)]) diff --git a/dizoo/beergame/envs/beergame_env.py b/dizoo/beergame/envs/beergame_env.py new file mode 100644 index 0000000000..c8e76de9fe --- /dev/null +++ b/dizoo/beergame/envs/beergame_env.py @@ -0,0 +1,84 @@ +import numpy as np +from dizoo.beergame.envs.beergame_core import BeerGame +from typing import Union, List, Optional + +from ding.envs import BaseEnv, BaseEnvTimestep +from ding.utils import ENV_REGISTRY +from ding.torch_utils import to_ndarray +import copy + + +@ENV_REGISTRY.register('beergame') +class BeerGameEnv(BaseEnv): + + def __init__(self, cfg: dict) -> None: + self._cfg = cfg + self._init_flag = False + + def reset(self) -> np.ndarray: + if not self._init_flag: + self._env = BeerGame(self._cfg.role, self._cfg.agent_type, self._cfg.demandDistribution) + self._observation_space = self._env.observation_space + self._action_space = self._env.action_space + self._reward_space = self._env.reward_space + self._init_flag = True + if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed: + np_seed = 100 * np.random.randint(1, 1000) + self._env.seed(self._seed + np_seed) + elif hasattr(self, '_seed'): + self._env.seed(self._seed) + self._final_eval_reward = 0 + obs = self._env.reset() + obs = to_ndarray(obs).astype(np.float32) + return obs + + def close(self) -> None: + if self._init_flag: + self._env.close() + self._init_flag = False + + def seed(self, seed: int, dynamic_seed: bool = True) -> None: + self._seed = seed + self._dynamic_seed = dynamic_seed + 
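+        # also seed numpy's global RNG, which the underlying beergame core uses to draw demand sequences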
np.random.seed(self._seed) + + def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep: + if isinstance(action, np.ndarray) and action.shape == (1, ): + action = action.squeeze() # 0-dim array + obs, rew, done, info = self._env.step(action) + self._final_eval_reward += rew + if done: + info['final_eval_reward'] = self._final_eval_reward + obs = to_ndarray(obs).astype(np.float32) + rew = to_ndarray([rew]).astype(np.float32) # wrapped to be transfered to a array with shape (1,) + return BaseEnvTimestep(obs, rew, done, info) + + def reward_shaping(self, transitions: List[dict]) -> List[dict]: + new_transitions = copy.deepcopy(transitions) + for trans in new_transitions: + trans['reward'] = self._env.reward_shaping(trans['reward']) + return new_transitions + + def random_action(self) -> np.ndarray: + random_action = self.action_space.sample() + if isinstance(random_action, int): + random_action = to_ndarray([random_action], dtype=np.int64) + return random_action + + def enable_save_figure(self, figure_path: Optional[str] = None) -> None: + self._env.enable_save_figure(figure_path) + + @property + def observation_space(self) -> int: + return self._observation_space + + @property + def action_space(self) -> int: + return self._action_space + + @property + def reward_space(self) -> int: + return self._reward_space + + def __repr__(self) -> str: + return "DI-engine Beergame Env" diff --git a/dizoo/beergame/envs/clBeergame.py b/dizoo/beergame/envs/clBeergame.py new file mode 100644 index 0000000000..9a237fbb68 --- /dev/null +++ b/dizoo/beergame/envs/clBeergame.py @@ -0,0 +1,439 @@ +# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL. +import numpy as np +from random import randint +from .BGAgent import Agent +from matplotlib import rc +rc('text', usetex=True) +from .plotting import plotting, savePlot +import matplotlib.pyplot as plt +import os +import time +from time import gmtime, strftime + + +class clBeerGame(object): + + def __init__(self, config): + self.config = config + self.curGame = 0 # The number associated with the current game (counter of the game) + self.curTime = 0 + self.totIterPlayed = 0 # total iterations of the game, played so far in this and previous games + self.players = self.createAgent() # create the agents + self.T = 0 + self.demand = [] + self.ifOptimalSolExist = self.config.ifOptimalSolExist + self.getOptimalSol() + self.totRew = 0 # it is reward of all players obtained for the current player. + self.resultTest = [] + self.runnerMidlResults = [] # stores the results to use in runner comparisons + self.runnerFinlResults = [] # stores the results to use in runner comparisons + self.middleTestResult = [ + ] # stores the whole middle results of bs, Strm, and random to avoid doing same tests multiple of times. + self.runNumber = 0 # the runNumber which is used when use runner + self.strNum = 0 # the runNumber which is used when use runner + + # createAgent : Create agent objects (agentNum,IL,OO,c_h,c_p,type,config) + def createAgent(self): + agentTypes = self.config.agentTypes + return [ + Agent( + i, self.config.ILInit[i], self.config.AOInit, self.config.ASInit[i], self.config.c_h[i], + self.config.c_p[i], self.config.eta[i], agentTypes[i], self.config + ) for i in range(self.config.NoAgent) + ] + + # planHorizon : Find a random planning horizon + def planHorizon(self): + # TLow: minimum number for the planning horizon # TUp: maximum number for the planning horizon + # output: The planning horizon which is chosen randomly. 
+ return randint(self.config.TLow, self.config.TUp) + + # this function resets the game for start of the new game + def resetGame(self, demand: np.ndarray): + self.demand = demand + self.curTime = 0 + self.curGame += 1 + self.totIterPlayed += self.T + self.T = self.planHorizon() + # reset the required information of player for each episode + for k in range(0, self.config.NoAgent): + self.players[k].resetPlayer(self.T) + + # update OO when there are initial IL,AO,AS + self.update_OO() + + # correction on cost at time T according to the cost of the other players + def getTotRew(self): + totRew = 0 + for i in range(self.config.NoAgent): + # sum all rewards for the agents and make correction + totRew += self.players[i].cumReward + + for i in range(self.config.NoAgent): + self.players[i].curReward += self.players[i].eta * (totRew - self.players[i].cumReward) # /(self.T) + + # make correction to the rewards in the experience replay for all iterations of current game + def distTotReward(self, role: int): + totRew = 0 + optRew = 0.1 # why? + for i in range(self.config.NoAgent): + # sum all rewards for the agents and make correction + totRew += self.players[i].cumReward + totRew += optRew + + return totRew, self.players[role].cumReward + + def getAction(self, k: int, action: np.ndarray, playType="train"): + if playType == "train": + if self.players[k].compType == "srdqn": + self.players[k].action = np.zeros(self.config.actionListLen) + self.players[k].action[action] = 1 + elif self.players[k].compType == "Strm": + self.players[k].action = np.zeros(self.config.actionListLenOpt) + self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt)\ + - max(0, round(self.players[k].AO[self.curTime] + \ + self.players[k].alpha_b*(self.players[k].IL - self.players[k].a_b) + \ + self.players[k].betta_b*(self.players[k].OO - self.players[k].b_b)))))] = 1 + elif self.players[k].compType == "rnd": + self.players[k].action = np.zeros(self.config.actionListLen) + a = np.random.randint(self.config.actionListLen) + self.players[k].action[a] = 1 + elif self.players[k].compType == "bs": + self.players[k].action = np.zeros(self.config.actionListLenOpt) + if self.config.demandDistribution == 2: + if self.curTime and self.config.use_initial_BS <= 4: + self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \ + max(0, (self.players[k].int_bslBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1 + else: + self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \ + max(0, (self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1 + else: + self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \ + max(0, (self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1 + elif playType == "test": + if self.players[k].compTypeTest == "srdqn": + self.players[k].action = np.zeros(self.config.actionListLen) + self.players[k].action = self.players[k].brain.getDNNAction(self.playType) + elif self.players[k].compTypeTest == "Strm": + self.players[k].action = np.zeros(self.config.actionListLenOpt) + + self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt)-\ + max(0,round(self.players[k].AO[self.curTime] +\ + self.players[k].alpha_b*(self.players[k].IL - self.players[k].a_b) +\ + self.players[k].betta_b*(self.players[k].OO - self.players[k].b_b)))))] = 1 + elif 
self.players[k].compTypeTest == "rnd": + self.players[k].action = np.zeros(self.config.actionListLen) + a = np.random.randint(self.config.actionListLen) + self.players[k].action[a] = 1 + elif self.players[k].compTypeTest == "bs": + self.players[k].action = np.zeros(self.config.actionListLenOpt) + + if self.config.demandDistribution == 2: + if self.curTime and self.config.use_initial_BS <= 4: + self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\ + max(0,(self.players[k].int_bslBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1 + else: + self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\ + max(0,(self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1 + else: + self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\ + max(0,(self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1 + else: + # not a valid player is defined. + raise Exception('The player type is not defined or it is not a valid type.!') + + def next(self): + # get a random leadtime + leadTimeIn = randint( + self.config.leadRecItemLow[self.config.NoAgent - 1], self.config.leadRecItemUp[self.config.NoAgent - 1] + ) + # handle the most upstream recieved shipment + self.players[self.config.NoAgent - 1].AS[self.curTime + + leadTimeIn] += self.players[self.config.NoAgent - + 1].actionValue(self.curTime) + + for k in range(self.config.NoAgent - 1, -1, -1): # [3,2,1,0] + + # get current IL and Backorder + current_IL = max(0, self.players[k].IL) + current_backorder = max(0, -self.players[k].IL) + + # TODO: We have get the AS and AO from the UI and update our AS and AO, so that code update the corresponding variables + + # increase IL and decrease OO based on the action, for the next period + self.players[k].recieveItems(self.curTime) + + # observe the reward + possible_shipment = min( + current_IL + self.players[k].AS[self.curTime], current_backorder + self.players[k].AO[self.curTime] + ) + + # plan arrivals of the items to the downstream agent + if self.players[k].agentNum > 0: + leadTimeIn = randint(self.config.leadRecItemLow[k - 1], self.config.leadRecItemUp[k - 1]) + self.players[k - 1].AS[self.curTime + leadTimeIn] += possible_shipment + + # update IL + self.players[k].IL -= self.players[k].AO[self.curTime] + # observe the reward + self.players[k].getReward() + self.players[k].hist[-1][-2] = self.players[k].curReward + self.players[k].hist2[-1][-2] = self.players[k].curReward + + # update next observation + self.players[k].nextObservation = self.players[k].getCurState(self.curTime + 1) + + if self.config.ifUseTotalReward: + # correction on cost at time T + if self.curTime == self.T: + self.getTotRew() + + self.curTime += 1 + + def handelAction(self, action: np.ndarray, playType="train"): + # get random lead time + leadTime = randint(self.config.leadRecOrderLow[0], self.config.leadRecOrderUp[0]) + # set AO + self.players[0].AO[self.curTime] += self.demand[self.curTime] + for k in range(0, self.config.NoAgent): + self.getAction(k, action, playType) + + self.players[k].srdqnBaseStock += [self.players[k].actionValue( \ + self.curTime) + self.players[k].IL + self.players[k].OO] + + # update hist for the plots + self.players[k].hist += [[self.curTime, self.players[k].IL, self.players[k].OO,\ + self.players[k].actionValue(self.curTime), self.players[k].curReward, 
self.players[k].srdqnBaseStock[-1]]] + + if self.players[k].compType == "srdqn": + self.players[k].hist2 += [[self.curTime, self.players[k].IL, self.players[k].OO, self.players[k].AO[self.curTime], self.players[k].AS[self.curTime], \ + self.players[k].actionValue(self.curTime), self.players[k].curReward, \ + self.config.actionList[np.argmax(self.players[k].action)]]] + + else: + self.players[k].hist2 += [[self.curTime, self.players[k].IL, self.players[k].OO, self.players[k].AO[self.curTime], self.players[k].AS[self.curTime], \ + self.players[k].actionValue(self.curTime), self.players[k].curReward, 0]] + + # updates OO and AO at time t+1 + self.players[k].OO += self.players[k].actionValue(self.curTime) # open order level update + leadTime = randint(self.config.leadRecOrderLow[k], self.config.leadRecOrderUp[k]) + if self.players[k].agentNum < self.config.NoAgent - 1: + self.players[k + 1].AO[self.curTime + leadTime] += self.players[k].actionValue( + self.curTime + ) # open order level update + + # check the Shang and Song (2003) condition, and if it works, obtains the base stock policy values for each agent + def getOptimalSol(self): + # if self.config.NoAgent !=1: + if self.config.NoAgent != 1 and 1 == 2: + # check the Shang and Song (2003) condition. + for k in range(self.config.NoAgent - 1): + if not (self.players[k].c_h == self.players[k + 1].c_h and self.players[k + 1].c_p == 0): + self.ifOptimalSolExist = False + + # if the Shang and Song (2003) condition satisfied, it runs the algorithm + if self.ifOptimalSolExist == True: + calculations = np.zeros((7, self.config.NoAgent)) + for k in range(self.config.NoAgent): + # DL_high + calculations[0][k] = ((self.config.leadRecItemLow + self.config.leadRecItemUp + 2) / 2 \ + + (self.config.leadRecOrderLow + self.config.leadRecOrderUp + 2) / 2) * \ + (self.config.demandUp - self.config.demandLow - 1) + if k > 0: + calculations[0][k] += calculations[0][k - 1] + # probability_high + nominator_ch = 0 + low_denominator_ch = 0 + for j in range(k, self.config.NoAgent): + if j < self.config.NoAgent - 1: + nominator_ch += self.players[j + 1].c_h + low_denominator_ch += self.players[j].c_h + if k == 0: + high_denominator_ch = low_denominator_ch + calculations[2][k] = (self.players[0].c_p + + nominator_ch) / (self.players[0].c_p + low_denominator_ch + 0.0) + # probability_low + calculations[3][k] = (self.players[0].c_p + + nominator_ch) / (self.players[0].c_p + high_denominator_ch + 0.0) + # S_high + calculations[4] = np.round(np.multiply(calculations[0], calculations[2])) + # S_low + calculations[5] = np.round(np.multiply(calculations[0], calculations[3])) + # S_avg + calculations[6] = np.round(np.mean(calculations[4:6], axis=0)) + # S', set the base stock values into each agent. 
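+            # S' step: difference the cumulative order-up-to levels so each agent keeps only its own (echelon) base stock, clipped at zero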
+ for k in range(self.config.NoAgent): + if k == 0: + self.players[k].bsBaseStock = calculations[6][k] + + else: + self.players[k].bsBaseStock = calculations[6][k] - calculations[6][k - 1] + if self.players[k].bsBaseStock < 0: + self.players[k].bsBaseStock = 0 + elif self.config.NoAgent == 1: + if self.config.demandDistribution == 0: + self.players[0].bsBaseStock = np.ceil( + self.config.c_h[0] / (self.config.c_h[0] + self.config.c_p[0] + 0.0) + ) * ((self.config.demandUp - self.config.demandLow - 1) / 2) * self.config.leadRecItemUp + elif 1 == 1: + f = self.config.f + f_init = self.config.f_init + for k in range(self.config.NoAgent): + self.players[k].bsBaseStock = f[k] + self.players[k].int_bslBaseStock = f_init[k] + + def update_OO(self): + for k in range(0, self.config.NoAgent): + if k < self.config.NoAgent - 1: + self.players[k].OO = sum(self.players[k + 1].AO) + sum(self.players[k].AS) + else: + self.players[k].OO = sum(self.players[k].AS) + + def doTestMid(self, demandTs): + self.resultTest = [] + m = strftime("%Y-%m-%d-%H-%M-%S", gmtime()) + self.doTest(m, demandTs) + print("---------------------------------------------------------------------------------------") + resultSummary = np.array(self.resultTest).mean(axis=0).tolist() + + result_srdqn = ', '.join(map("{:.2f}".format, resultSummary[0])) + result_rand = ', '.join(map("{:.2f}".format, resultSummary[1])) + result_strm = ', '.join(map("{:.2f}".format, resultSummary[2])) + if self.ifOptimalSolExist: + result_bs = ', '.join(map("{:.2f}".format, resultSummary[3])) + print( + 'SUMMARY; {0:s}; ITER= {1:d}; OURPOLICY= [{2:s}]; SUM = {3:2.4f}; Rand= [{4:s}]; SUM = {5:2.4f}; STRM= [{6:s}]; SUM = {7:2.4f}; BS= [{8:s}]; SUM = {9:2.4f}' + .format( + strftime("%Y-%m-%d %H:%M:%S", gmtime()), self.curGame, result_srdqn, sum(resultSummary[0]), + result_rand, sum(resultSummary[1]), result_strm, sum(resultSummary[2]), result_bs, + sum(resultSummary[3]) + ) + ) + + else: + print( + 'SUMMARY; {0:s}; ITER= {1:d}; OURPOLICY= [{2:s}]; SUM = {3:2.4f}; Rand= [{4:s}]; SUM = {5:2.4f}; STRM= [{6:s}]; SUM = {7:2.4f}' + .format( + strftime("%Y-%m-%d %H:%M:%S", gmtime()), self.curGame, result_srdqn, sum(resultSummary[0]), + result_rand, sum(resultSummary[1]), result_strm, sum(resultSummary[2]) + ) + ) + + print("=======================================================================================") + + def doTest(self, m, demand): + import matplotlib.pyplot as plt + if self.config.ifSaveFigure: + plt.figure(self.curGame, figsize=(12, 8), dpi=80, facecolor='w', edgecolor='k') + + # self.demand = demand + # use dnn to get output. + Rsltdnn, plt = self.tester(self.config.agentTypes, plt, 'b', 'OurPolicy', m) + baseStockdata = self.players[0].srdqnBaseStock + # # use random to get output. + RsltRnd, plt = self.tester(["rnd", "rnd", "rnd", "rnd"], plt, 'y', 'RAND', m) + + # use formual to get output. + RsltStrm, plt = self.tester(["Strm", "Strm", "Strm", "Strm"], plt, 'g', 'Strm', m) + + # use optimal strategy to get output, if it works. 
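+        # swap the srdqn seat for 'bs' while keeping the co-players, so the benchmark matches the trained configuration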
+ if self.ifOptimalSolExist: + if self.config.agentTypes == ["srdqn", "Strm", "Strm", "Strm"]: + Rsltbs, plt = self.tester(["bs", "Strm", "Strm", "Strm"], plt, 'r', 'Strm-BS', m) + elif self.config.agentTypes == ["Strm", "srdqn", "Strm", "Strm"]: + Rsltbs, plt = self.tester(["Strm", "bs", "Strm", "Strm"], plt, 'r', 'Strm-BS', m) + elif self.config.agentTypes == ["Strm", "Strm", "srdqn", "Strm"]: + Rsltbs, plt = self.tester(["Strm", "Strm", "bs", "Strm"], plt, 'r', 'Strm-BS', m) + elif self.config.agentTypes == ["Strm", "Strm", "Strm", "srdqn"]: + Rsltbs, plt = self.tester(["Strm", "Strm", "Strm", "bs"], plt, 'r', 'Strm-BS', m) + elif self.config.agentTypes == ["srdqn", "rnd", "rnd", "rnd"]: + Rsltbs, plt = self.tester(["bs", "rnd", "rnd", "rnd"], plt, 'r', 'RND-BS', m) + elif self.config.agentTypes == ["rnd", "srdqn", "rnd", "rnd"]: + Rsltbs, plt = self.tester(["rnd", "bs", "rnd", "rnd"], plt, 'r', 'RND-BS', m) + elif self.config.agentTypes == ["rnd", "rnd", "srdqn", "rnd"]: + Rsltbs, plt = self.tester(["rnd", "rnd", "bs", "rnd"], plt, 'r', 'RND-BS', m) + elif self.config.agentTypes == ["rnd", "rnd", "rnd", "srdqn"]: + Rsltbs, plt = self.tester(["rnd", "rnd", "rnd", "bs"], plt, 'r', 'RND-BS', m) + else: + Rsltbs, plt = self.tester(["bs", "bs", "bs", "bs"], plt, 'r', 'BS', m) + # hold the results of the optimal solution + self.middleTestResult += [[RsltRnd, RsltStrm, Rsltbs]] + else: + self.middleTestResult += [[RsltRnd, RsltStrm]] + + else: + # return the obtained results into their lists + RsltRnd = self.middleTestResult[m][0] + RsltStrm = self.middleTestResult[m][1] + if self.ifOptimalSolExist: + Rsltbs = self.middleTestResult[m][2] + + # save the figure + if self.config.ifSaveFigure: + savePlot(self.players, self.curGame, Rsltdnn, RsltStrm, Rsltbs, RsltRnd, self.config, m) + plt.close() + + result_srdqn = ', '.join(map("{:.2f}".format, Rsltdnn)) + result_rand = ', '.join(map("{:.2f}".format, RsltRnd)) + result_strm = ', '.join(map("{:.2f}".format, RsltStrm)) + if self.ifOptimalSolExist: + result_bs = ', '.join(map("{:.2f}".format, Rsltbs)) + print( + 'output; {0:s}; Iter= {1:s}; SRDQN= [{2:s}]; sum = {3:2.4f}; Rand= [{4:s}]; sum = {5:2.4f}; Strm= [{6:s}]; sum = {7:2.4f}; BS= [{8:s}]; sum = {9:2.4f}' + .format( + strftime("%Y-%m-%d %H:%M:%S", gmtime()), str(str(self.curGame) + "-" + str(m)), result_srdqn, + sum(Rsltdnn), result_rand, sum(RsltRnd), result_strm, sum(RsltStrm), result_bs, sum(Rsltbs) + ) + ) + self.resultTest += [[Rsltdnn, RsltRnd, RsltStrm, Rsltbs]] + + else: + print( + 'output; {0:s}; Iter= {1:s}; SRDQN= [{2:s}]; sum = {3:2.4f}; Rand= [{4:s}]; sum = {5:2.4f}; Strm= [{6:s}]; sum = {7:2.4f}' + .format( + strftime("%Y-%m-%d %H:%M:%S", gmtime()), str(str(self.curGame) + "-" + str(m)), result_srdqn, + sum(Rsltdnn), result_rand, sum(RsltRnd), result_strm, sum(RsltStrm) + ) + ) + + self.resultTest += [[Rsltdnn, RsltRnd, RsltStrm]] + + return sum(Rsltdnn) + + def tester(self, testType, plt, colori, labeli, m): + + # set computation type for test + for k in range(0, self.config.NoAgent): + # self.players[k].compTypeTest = testType[k] + self.players[k].compType = testType[k] + # run the episode to get the results. 
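+        # 'OurPolicy' was already rolled out by the external policy, so reuse its accumulated rewards instead of replaying the game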
+ if labeli != 'OurPolicy': + result = self.playGame(self.demand) + else: + result = [-1 * self.players[i].cumReward for i in range(0, self.config.NoAgent)] + # add the results into the figure + if self.config.ifSaveFigure: + plt = plotting(plt, [np.array(self.players[i].hist) for i in range(0, self.config.NoAgent)], colori, labeli) + if self.config.ifsaveHistInterval and ((self.curGame == 0) or (self.curGame == 1) or (self.curGame == 2) or (self.curGame == 3) or ((self.curGame - 1) % self.config.saveHistInterval == 0)\ + or ((self.curGame) % self.config.saveHistInterval == 0) or ((self.curGame) % self.config.saveHistInterval == 1) \ + or ((self.curGame) % self.config.saveHistInterval == 2)) : + for k in range(0, self.config.NoAgent): + name = labeli + "-" + str(self.curGame) + "-" + "player" + "-" + str(k) + "-" + str(m) + np.save(os.path.join(self.config.model_dir, name), np.array(self.players[k].hist2)) + + # save the figure of base stocks + # if self.config.ifSaveFigure and (self.curGame in range(self.config.saveFigInt[0],self.config.saveFigInt[1])): + # for k in range(self.config.NoAgent): + # if self.players[k].compTypeTest == 'dnn': + # plotBaseStock(self.players[k].srdqnBaseStock, 'b', 'base stock of agent '+ str(self.players[k].agentNum), self.curGame, self.config, m) + + return result, plt + + def playGame(self, demand): + self.resetGame(demand) + + # run the game + while self.curTime < self.T: + self.handelAction(np.array(0)) # action won't be used. + self.next() + return [-1 * self.players[i].cumReward for i in range(0, self.config.NoAgent)] diff --git a/dizoo/beergame/envs/plotting.py b/dizoo/beergame/envs/plotting.py new file mode 100644 index 0000000000..57776c9641 --- /dev/null +++ b/dizoo/beergame/envs/plotting.py @@ -0,0 +1,72 @@ +# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL. 
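+# plotting() overlays per-agent curves of IL, OO, action, order-up-to level and reward over time;
+# savePlot() adds a summary title/legend and writes the figure under config.figure_dir/saved_figures/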
+import os +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from pylab import * + + +# plotting +def plotting(plt, data, colori, pltLabel): + # plt.hold(True) + + for i in range(np.shape(data)[0]): + plt.subplot(4, 5, 5 * i + 1) + plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[1, :], colori, label=pltLabel) + plt.xlabel('Time') + plt.ylabel('IL') + plt.grid(True) + + plt.subplot(4, 5, 5 * i + 2) + plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[2, :], colori, label=pltLabel) + plt.xlabel('Time') + plt.ylabel('OO') + plt.grid(True) + + plt.subplot(4, 5, 5 * i + 3) + plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[3, :], colori, label=pltLabel) + plt.xlabel('Time') + plt.ylabel('a') + plt.grid(True) + + plt.subplot(4, 5, 5 * i + 4) + plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[5, :], colori, label=pltLabel) + plt.xlabel('Time') + plt.ylabel('OUTL') + plt.grid(True) + + plt.subplot(4, 5, 5 * i + 5) + plt.plot(np.transpose(data[i])[0, :], -1 * np.transpose(data[i])[4, :], colori, label=pltLabel) + plt.xlabel('Time') + plt.ylabel('r') + plt.grid(True) + + return plt + + +def savePlot(players, curGame, Rsltdnn, RsltFrmu, RsltOptm, RsltRnd, config, m): + #add title to plot + if config.if_titled_figure: + plt.suptitle( + "sum OurPolicy=" + str(round(sum(Rsltdnn), 2)) + "; sum Strm=" + str(round(sum(RsltFrmu), 2)) + + "; sum BS=" + str(round(sum(RsltOptm), 2)) + "; sum Rnd=" + str(round(sum(RsltRnd), 2)) + "\n" + + "Ag OurPolicy=" + str([round(Rsltdnn[i], 2) for i in range(config.NoAgent)]) + "; Ag Strm=" + + str([round(RsltFrmu[i], 2) for i in range(config.NoAgent)]) + "; Ag BS=" + + str([round(RsltOptm[i], 2) for i in range(config.NoAgent)]) + "; Ag Rnd=" + + str([round(RsltRnd[i], 2) for i in range(config.NoAgent)]), + fontsize=12 + ) + + #insert legend to the figure + legend = plt.legend(bbox_to_anchor=(-1.4, -.165, 1., -.102), shadow=True, ncol=4) + + # configures spaces between subplots + plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=.5, hspace=.5) + # save the figure + path = os.path.join(config.figure_dir, 'saved_figures/') + if not os.path.exists(path): + os.mkdir(path) + plt.savefig(path + str(curGame) + '-' + str(m) + '.png', format='png') + print("figure" + str(curGame) + ".png saved in folder \"saved_figures\"") + plt.close(curGame) diff --git a/dizoo/beergame/envs/utils.py b/dizoo/beergame/envs/utils.py new file mode 100644 index 0000000000..2a6cf6f83d --- /dev/null +++ b/dizoo/beergame/envs/utils.py @@ -0,0 +1,355 @@ +import argparse +import os +import numpy as np + + +def str2bool(v): + return v.lower() in ('true', '1') + + +arg_lists = [] +parser = argparse.ArgumentParser() + + +def add_argument_group(name): + arg = parser.add_argument_group(name) + arg_lists.append(arg) + return arg + + +# crm +game_arg = add_argument_group('BeerGame') +game_arg.add_argument('--task', type=str, default='bg') +game_arg.add_argument( + '--fixedAction', + type=str2bool, + default='False', + help='if you want to have actions in [0,actionMax] set it to True. 
with False it will set it [actionLow, actionUp]' +) +game_arg.add_argument( + '--observation_data', + type=str2bool, + default=False, + help='if it is True, then it uses the data that is generated by based on few real world observation' +) +game_arg.add_argument('--data_id', type=int, default=22, help='the default item id for the basket dataset') +game_arg.add_argument('--TLow', type=int, default=100, help='duration of one GAME (lower bound)') +game_arg.add_argument('--TUp', type=int, default=100, help='duration of one GAME (upper bound)') +game_arg.add_argument( + '--demandDistribution', + type=int, + default=0, + help='0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data' +) +game_arg.add_argument( + '--scaled', type=str2bool, default=False, help='if true it uses the (if) existing scaled parameters' +) +game_arg.add_argument('--demandSize', type=int, default=6100, help='the size of demand dataset') +game_arg.add_argument('--demandLow', type=int, default=0, help='the lower bound of random demand') +game_arg.add_argument('--demandUp', type=int, default=3, help='the upper bound of random demand') +game_arg.add_argument('--demandMu', type=float, default=10, help='the mu of the normal distribution for demand ') +game_arg.add_argument('--demandSigma', type=float, default=2, help='the sigma of the normal distribution for demand ') +game_arg.add_argument('--actionMax', type=int, default=2, help='it works when fixedAction is True') +game_arg.add_argument( + '--actionUp', type=int, default=2, help='bounds on my decision (upper bound), it works when fixedAction is True' +) +game_arg.add_argument( + '--actionLow', type=int, default=-2, help='bounds on my decision (lower bound), it works when fixedAction is True' +) +game_arg.add_argument( + '--action_step', type=int, default=1, help='The obtained action value by dnn is multiplied by this value' +) +game_arg.add_argument('--actionList', type=list, default=[], help='The list of the available actions') +game_arg.add_argument('--actionListLen', type=int, default=0, help='the length of the action list') +game_arg.add_argument( + '--actionListOpt', type=int, default=0, help='the action list which is used in optimal and sterman' +) +game_arg.add_argument('--actionListLenOpt', type=int, default=0, help='the length of the actionlistopt') +game_arg.add_argument('--agentTypes', type=list, default=['dnn', 'dnn', 'dnn', 'dnn'], help='the player types') +game_arg.add_argument( + '--agent_type1', type=str, default='dnn', help='the player types for agent 1, it can be dnn, Strm, bs, rnd' +) +game_arg.add_argument( + '--agent_type2', type=str, default='dnn', help='the player types for agent 2, it can be dnn, Strm, bs, rnd' +) +game_arg.add_argument( + '--agent_type3', type=str, default='dnn', help='the player types for agent 3, it can be dnn, Strm, bs, rnd' +) +game_arg.add_argument( + '--agent_type4', type=str, default='dnn', help='the player types for agent 4, it can be dnn, Strm, bs, rnd' +) +game_arg.add_argument('--NoAgent', type=int, default=4, help='number of agents, currently it should be in {1,2,3,4}') +game_arg.add_argument('--cp1', type=float, default=2.0, help='shortage cost of player 1') +game_arg.add_argument('--cp2', type=float, default=0.0, help='shortage cost of player 2') +game_arg.add_argument('--cp3', type=float, default=0.0, help='shortage cost of player 3') +game_arg.add_argument('--cp4', type=float, default=0.0, help='shortage cost of player 4') +game_arg.add_argument('--ch1', type=float, 
default=2.0, help='holding cost of player 1') +game_arg.add_argument('--ch2', type=float, default=2.0, help='holding cost of player 2') +game_arg.add_argument('--ch3', type=float, default=2.0, help='holding cost of player 3') +game_arg.add_argument('--ch4', type=float, default=2.0, help='holding cost of player 4') +game_arg.add_argument('--alpha_b1', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 1') +game_arg.add_argument('--alpha_b2', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 2') +game_arg.add_argument('--alpha_b3', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 3') +game_arg.add_argument('--alpha_b4', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 4') +game_arg.add_argument('--betta_b1', type=float, default=-0.2, help='beta of Sterman formula parameter for player 1') +game_arg.add_argument('--betta_b2', type=float, default=-0.2, help='beta of Sterman formula parameter for player 2') +game_arg.add_argument('--betta_b3', type=float, default=-0.2, help='beta of Sterman formula parameter for player 3') +game_arg.add_argument('--betta_b4', type=float, default=-0.2, help='beta of Sterman formula parameter for player 4') +game_arg.add_argument('--eta', type=list, default=[0, 4, 4, 4], help='the total cost regulazer') +game_arg.add_argument('--distCoeff', type=int, default=20, help='the total cost regulazer') +game_arg.add_argument( + '--ifUseTotalReward', + type=str2bool, + default='False', + help='if you want to have the total rewards in the experience replay, set it to true.' +) +game_arg.add_argument( + '--ifUsedistTotReward', + type=str2bool, + default='True', + help='If use correction to the rewards in the experience replay for all iterations of current game' +) +game_arg.add_argument( + '--ifUseASAO', + type=str2bool, + default='True', + help='if use AS and AO, i.e., received shipment and received orders in the input of DNN' +) +game_arg.add_argument('--ifUseActionInD', type=str2bool, default='False', help='if use action in the input of DNN') +game_arg.add_argument( + '--stateDim', type=int, default=5, help='Number of elements in the state desciptor - Depends on ifUseASAO' +) +game_arg.add_argument('--iftl', type=str2bool, default=False, help='if apply transfer learning') +game_arg.add_argument( + '--ifTransferFromSmallerActionSpace', + type=str2bool, + default=False, + help='if want to transfer knowledge from a network with different action space size.' +) +game_arg.add_argument( + '--baseActionSize', + type=int, + default=5, + help='if ifTransferFromSmallerActionSpace is true, this determines the size of action space of saved network' +) +game_arg.add_argument( + '--tlBaseBrain', + type=int, + default=3, + help='the gameConfig of the base network for re-training with transfer-learning' +) +game_arg.add_argument('--baseDemandDistribution', type=int, default=0, help='same as the demandDistribution') +game_arg.add_argument( + '--MultiAgent', type=str2bool, default=False, help='if run multi-agent RL model, not fully operational' +) +game_arg.add_argument( + '--MultiAgentRun', + type=list, + default=[True, True, True, True], + help='In the multi-RL setting, it determines which agent should get training.' 
+) +game_arg.add_argument( + '--if_use_AS_t_plus_1', type=str2bool, default='False', help='if use AS[t+1], not AS[t] in the input of DNN' +) +game_arg.add_argument( + '--ifSinglePathExist', + type=str2bool, + default=False, + help='If true it uses the predefined path in pre_model_dir and does not merge it with demandDistribution.' +) +game_arg.add_argument('--gamma', type=float, default=.99, help='discount factor for reward') +game_arg.add_argument( + '--multPerdInpt', type=int, default=10, help='Number of history records which we feed into network' +) + +# parameters of the leadtimes +leadtimes_arg = add_argument_group('leadtimes') +leadtimes_arg.add_argument( + '--leadRecItemLow', type=list, default=[2, 2, 2, 4], help='the min lead time for receiving items' +) +leadtimes_arg.add_argument( + '--leadRecItemUp', type=list, default=[2, 2, 2, 4], help='the max lead time for receiving items' +) +leadtimes_arg.add_argument( + '--leadRecOrderLow', type=int, default=[2, 2, 2, 0], help='the min lead time for receiving orders' +) +leadtimes_arg.add_argument( + '--leadRecOrderUp', type=int, default=[2, 2, 2, 0], help='the max lead time for receiving orders' +) +leadtimes_arg.add_argument('--ILInit', type=list, default=[0, 0, 0, 0], help='') +leadtimes_arg.add_argument('--AOInit', type=list, default=[0, 0, 0, 0], help='') +leadtimes_arg.add_argument('--ASInit', type=list, default=[0, 0, 0, 0], help='the initial shipment of each agent') +leadtimes_arg.add_argument('--leadRecItem1', type=int, default=2, help='the min lead time for receiving items') +leadtimes_arg.add_argument('--leadRecItem2', type=int, default=2, help='the min lead time for receiving items') +leadtimes_arg.add_argument('--leadRecItem3', type=int, default=2, help='the min lead time for receiving items') +leadtimes_arg.add_argument('--leadRecItem4', type=int, default=2, help='the min lead time for receiving items') +leadtimes_arg.add_argument('--leadRecOrder1', type=int, default=2, help='the min lead time for receiving order') +leadtimes_arg.add_argument('--leadRecOrder2', type=int, default=2, help='the min lead time for receiving order') +leadtimes_arg.add_argument('--leadRecOrder3', type=int, default=2, help='the min lead time for receiving order') +leadtimes_arg.add_argument('--leadRecOrder4', type=int, default=2, help='the min lead time for receiving order') +leadtimes_arg.add_argument('--ILInit1', type=int, default=0, help='the initial inventory level of the agent') +leadtimes_arg.add_argument('--ILInit2', type=int, default=0, help='the initial inventory level of the agent') +leadtimes_arg.add_argument('--ILInit3', type=int, default=0, help='the initial inventory level of the agent') +leadtimes_arg.add_argument('--ILInit4', type=int, default=0, help='the initial inventory level of the agent') +leadtimes_arg.add_argument('--AOInit1', type=int, default=0, help='the initial arriving order of the agent') +leadtimes_arg.add_argument('--AOInit2', type=int, default=0, help='the initial arriving order of the agent') +leadtimes_arg.add_argument('--AOInit3', type=int, default=0, help='the initial arriving order of the agent') +leadtimes_arg.add_argument('--AOInit4', type=int, default=0, help='the initial arriving order of the agent') +leadtimes_arg.add_argument('--ASInit1', type=int, default=0, help='the initial arriving shipment of the agent') +leadtimes_arg.add_argument('--ASInit2', type=int, default=0, help='the initial arriving shipment of the agent') +leadtimes_arg.add_argument('--ASInit3', type=int, default=0, help='the initial 
arriving shipment of the agent') +leadtimes_arg.add_argument('--ASInit4', type=int, default=0, help='the initial arriving shipment of the agent') + +# test +test_arg = add_argument_group('testing') +test_arg.add_argument( + '--testRepeatMid', + type=int, + default=50, + help='it is number of episodes which is going to be used for testing in the middle of training' +) +test_arg.add_argument('--testInterval', type=int, default=100, help='every xx games compute "test error"') +test_arg.add_argument( + '--ifSaveFigure', type=str2bool, default=True, help='if is it True, save the figures in each testing.' +) +test_arg.add_argument( + '--if_titled_figure', + type=str2bool, + default='True', + help='if is it True, save the figures with details in the title.' +) +test_arg.add_argument( + '--ifsaveHistInterval', type=str2bool, default=False, help='if every xx games save details of the episode' +) +test_arg.add_argument('--saveHistInterval', type=int, default=50000, help='every xx games save details of the play') +test_arg.add_argument('--Ttest', type=int, default=100, help='it defines the number of periods in the test cases') +test_arg.add_argument( + '--ifOptimalSolExist', + type=str2bool, + default=True, + help='if the instance has optimal base stock policy, set it to True, otherwise it should be False.' +) +test_arg.add_argument('--f1', type=float, default=8, help='base stock policy decision of player 1') +test_arg.add_argument('--f2', type=float, default=8, help='base stock policy decision of player 2') +test_arg.add_argument('--f3', type=float, default=0, help='base stock policy decision of player 3') +test_arg.add_argument('--f4', type=float, default=0, help='base stock policy decision of player 4') +test_arg.add_argument( + '--f_init', + type=list, + default=[32, 32, 32, 24], + help='base stock policy decision for 4 time-steps on the C(4,8) demand distribution' +) +test_arg.add_argument('--use_initial_BS', type=str2bool, default=False, help='If use f_init set it to True') + +# reporting +reporting_arg = add_argument_group('reporting') +reporting_arg.add_argument('--Rsltdnn', type=list, default=[], help='the result of dnn play tests will be saved here') +reporting_arg.add_argument( + '--RsltRnd', type=list, default=[], help='the result of random play tests will be saved here' +) +reporting_arg.add_argument( + '--RsltStrm', type=list, default=[], help='the result of heuristic fomula play tests will be saved here' +) +reporting_arg.add_argument( + '--Rsltbs', type=list, default=[], help='the result of optimal play tests will be saved here' +) +reporting_arg.add_argument( + '--ifSaveHist', + type=str2bool, + default='False', + help= + 'if it is true, saves history, prediction, and the randBatch in each period, WARNING: just make it True in small runs, it saves huge amount of files.' 
+) + + +# buildActionList: actions for the beer game problem +def buildActionList(config): + aDiv = 1 # difference in the action list + if config.fixedAction: + actions = list( + range(0, config.actionMax + 1, aDiv) + ) # If you put the second argument =11, creates an actionlist from 0..xx + else: + actions = list(range(config.actionLow, config.actionUp + 1, aDiv)) + return actions + + +# specify the dimension of the state of the game +def getStateDim(config): + if config.ifUseASAO: + stateDim = 5 + else: + stateDim = 3 + + if config.ifUseActionInD: + stateDim += 1 + + return stateDim + + +def set_optimal(config): + if config.demandDistribution == 0: + if config.cp1 == 2 and config.ch1 == 2 and config.ch2 == 2 and config.ch3 == 2 and config.ch4 == 2: + config.f1 = 8. + config.f2 = 8. + config.f3 = 0. + config.f4 = 0. + + +def get_config(): + config, unparsed = parser.parse_known_args() + config = update_config(config) + + return config, unparsed + + +def fill_leadtime_initial_values(config): + config.leadRecItemLow = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4] + config.leadRecItemUp = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4] + config.leadRecOrderLow = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4] + config.leadRecOrderUp = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4] + config.ILInit = [config.ILInit1, config.ILInit2, config.ILInit3, config.ILInit4] + config.AOInit = [config.AOInit1, config.AOInit2, config.AOInit3, config.AOInit4] + config.ASInit = [config.ASInit1, config.ASInit2, config.ASInit3, config.ASInit4] + + +def get_auxuliary_leadtime_initial_values(config): + config.leadRecOrderUp_aux = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4] + config.leadRecItemUp_aux = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4] + + +def fix_lead_time_manufacturer(config): + if config.leadRecOrder4 > 0: + config.leadRecItem4 += config.leadRecOrder4 + config.leadRecOrder4 = 0 + + +def set_sterman_parameters(config): + config.alpha_b = [config.alpha_b1, config.alpha_b2, config.alpha_b3, config.alpha_b4] + config.betta_b = [config.betta_b1, config.betta_b2, config.betta_b3, config.betta_b4] + + +def update_config(config): + config.actionList = buildActionList(config) # The list of the available actions + config.actionListLen = len(config.actionList) # the length of the action list + + set_optimal(config) + config.f = [config.f1, config.f2, config.f3, config.f4] # [6.4, 2.88, 2.08, 0.8] + + config.actionListLen = len(config.actionList) + if config.demandDistribution == 0: + config.actionListOpt = list(range(0, int(max(config.actionUp * 30 + 1, 3 * sum(config.f))), 1)) + else: + config.actionListOpt = list(range(0, int(max(config.actionUp * 30 + 1, 7 * sum(config.f))), 1)) + config.actionListLenOpt = len(config.actionListOpt) + + config.c_h = [config.ch1, config.ch2, config.ch3, config.ch4] + config.c_p = [config.cp1, config.cp2, config.cp3, config.cp4] + + config.stateDim = getStateDim(config) # Number of elements in the state description - Depends on ifUseASAO + get_auxuliary_leadtime_initial_values(config) + fix_lead_time_manufacturer(config) + fill_leadtime_initial_values(config) + set_sterman_parameters(config) + + return config