diff --git a/README.md b/README.md
index 7c8a82a704..595da21956 100644
--- a/README.md
+++ b/README.md
@@ -273,6 +273,7 @@ P.S: The `.py` file in `Runnable Demo` can be found in `dizoo`
| 29 |[dmc2gym](https://github.com/denisyarats/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/dmc2gym/dmc2gym_cheetah.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/dmc2gym)
[env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/dmc2gym.html)
[环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/dmc2gym_zh.html) |
| 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs)
环境指南 |
| 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym-pybullet-drones/gym-pybullet-drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs)
环境指南 |
+| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs)
环境指南 |
![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
diff --git a/ding/envs/env_manager/base_env_manager.py b/ding/envs/env_manager/base_env_manager.py
index e54eafe926..47dfdd5732 100644
--- a/ding/envs/env_manager/base_env_manager.py
+++ b/ding/envs/env_manager/base_env_manager.py
@@ -203,7 +203,9 @@ def done(self) -> bool:
@property
def method_name_list(self) -> list:
- return ['reset', 'step', 'seed', 'close', 'enable_save_replay', 'render']
+ return [
+ 'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure'
+ ]
def env_state_done(self, env_id: int) -> bool:
return self._env_states[env_id] == EnvState.DONE
@@ -418,6 +420,19 @@ def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
replay_path = [replay_path] * self.env_num
self._env_replay_path = replay_path
+ def enable_save_figure(self, env_id: int, figure_path: str) -> None:
+ """
+ Overview:
+ Set the directory in which the specified env saves its figures.
+ Arguments:
+ - env_id (:obj:`int`): The id of the env whose figures will be saved.
+ - figure_path (:obj:`str`): The directory path for saving figures.
+ """
+ if isinstance(figure_path, str):
+ self._envs[env_id].enable_save_figure(figure_path)
+ else:
+ raise TypeError("invalid figure_path arguments type: {}".format(type(figure_path)))
+
def close(self) -> None:
"""
Overview:
@@ -431,6 +446,9 @@ def close(self) -> None:
self._env_states[i] = EnvState.VOID
self._closed = True
+ def reward_shaping(self, env_id: int, transitions: List[dict]) -> List[dict]:
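+ """
+ Overview:
+ Apply the env-defined reward shaping to the transitions of one finished episode.
+ Arguments:
+ - env_id (:obj:`int`): The id of the env whose shaping rule is applied.
+ - transitions (:obj:`List[dict]`): The transitions of the finished episode.
+ Returns:
+ - transitions (:obj:`List[dict]`): The transitions with shaped rewards.
+ """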
+ return self._envs[env_id].reward_shaping(transitions)
+
@property
def closed(self) -> bool:
return self._closed
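# Editor's note: a minimal sketch (not part of this patch) of how the two new env-manager
# hooks are meant to be driven from a collector / evaluator loop. `env_manager`, `env_id`
# and `transitions` are placeholders for the objects those workers already hold, and
# `on_episode_done` / './figures' are hypothetical names; the collector and evaluator
# changes below wire the real calls up through the `reward_shaping` and `figure_path` keys.
def on_episode_done(env_manager, env_id, transitions):
    # re-weight the finished episode with the env-defined shaping rule
    transitions = env_manager.reward_shaping(env_id, transitions)
    # ask the underlying env to render and save its evaluation figures
    env_manager.enable_save_figure(env_id, './figures')
    return transitions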
diff --git a/ding/worker/collector/episode_serial_collector.py b/ding/worker/collector/episode_serial_collector.py
index 9fcf733779..6d89462bf0 100644
--- a/ding/worker/collector/episode_serial_collector.py
+++ b/ding/worker/collector/episode_serial_collector.py
@@ -21,7 +21,9 @@ class EpisodeSerialCollector(ISerialCollector):
envstep
"""
- config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False)
+ config = dict(
+ deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False, reward_shaping=False
+ )
def __init__(
self,
@@ -251,6 +253,8 @@ def collect(self,
# prepare data
if timestep.done:
transitions = to_tensor_transitions(self._traj_buffer[env_id])
+ if self._cfg.reward_shaping:
+ transitions = self._env.reward_shaping(env_id, transitions)
if self._cfg.get_train_sample:
train_sample = self._policy.get_train_sample(transitions)
return_data.extend(train_sample)
diff --git a/ding/worker/collector/interaction_serial_evaluator.py b/ding/worker/collector/interaction_serial_evaluator.py
index 3e4645260a..32222c0804 100644
--- a/ding/worker/collector/interaction_serial_evaluator.py
+++ b/ding/worker/collector/interaction_serial_evaluator.py
@@ -241,6 +241,9 @@ def eval(
continue
if t.done:
# Env reset is done by env_manager automatically.
+ if 'figure_path' in self._cfg and self._cfg.figure_path is not None:
+ self._env.enable_save_figure(env_id, self._cfg.figure_path)
self._policy.reset([env_id])
reward = t.info['final_eval_reward']
if 'episode_info' in t.info:
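# Editor's note: a hedged sketch (not part of this patch) of how `figure_path` is supplied
# to the evaluator; it mirrors the new entry script dizoo/beergame/entry/beergame_eval.py
# further down in this patch, and the './figures' value is illustrative.
cfg.policy.eval.evaluator.figure_path = './figures'
evaluator = InteractionSerialEvaluator(
    cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
)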
diff --git a/dizoo/beergame/beergame.png b/dizoo/beergame/beergame.png
new file mode 100644
index 0000000000..ec76d6e5ca
Binary files /dev/null and b/dizoo/beergame/beergame.png differ
diff --git a/dizoo/beergame/config/beergame_onppo_config.py b/dizoo/beergame/config/beergame_onppo_config.py
new file mode 100644
index 0000000000..7ad87a23fe
--- /dev/null
+++ b/dizoo/beergame/config/beergame_onppo_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+beergame_ppo_config = dict(
+ exp_name='beergame_ppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=200,
+ role=0, # 0-3 : retailer, warehouse, distributor, manufacturer
+ agent_type='bs',
+ # type of the co-players: 'bs' - base-stock policy, 'Strm' - Sterman formula modelling typical human behavior
+ demandDistribution=0
+ # distribution of demand, default=0, '0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data'
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=50, # stateDim * multPerdInpt = 5 * 10
+ action_shape=5, # the quantity relative to the arriving order
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=320,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ adv_norm=True,
+ value_norm=True,
+ # for on-policy PPO with recompute_adv, the key `done` in the data is needed to split
+ # trajectories, so ignore_done=False is normally required; since the key `traj_flag`
+ # is kept in the data as a backup for `done`, ignore_done=True can be used here.
+ ignore_done=True,
+ ),
+ collect=dict(
+ n_episode=8,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ collector=dict(
+ get_train_sample=True,
+ reward_shaping=True, # whether to use the total reward to reshape each agent's reward
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+beergame_ppo_config = EasyDict(beergame_ppo_config)
+main_config = beergame_ppo_config
+beergame_ppo_create_config = dict(
+ env=dict(
+ type='beergame',
+ import_names=['dizoo.beergame.envs.beergame_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+ collector=dict(type='episode', ),
+)
+beergame_ppo_create_config = EasyDict(beergame_ppo_create_config)
+create_config = beergame_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c beergame_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/dizoo/beergame/entry/beergame_eval.py b/dizoo/beergame/entry/beergame_eval.py
new file mode 100644
index 0000000000..5299107e78
--- /dev/null
+++ b/dizoo/beergame/entry/beergame_eval.py
@@ -0,0 +1,42 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import InteractionSerialEvaluator
+from ding.envs import BaseEnvManager
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.beergame.config.beergame_onppo_config import beergame_ppo_config, beergame_ppo_create_config
+from ding.envs import get_vec_env_setting
+from functools import partial
+
+
+def main(cfg, seed=0):
+ env_fn = None
+ cfg, create_cfg = beergame_ppo_config, beergame_ppo_create_config
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ cfg.env.manager.auto_reset = False
+ evaluator_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ model = VAC(**cfg.policy.model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ policy = PPOPolicy(cfg.policy, model=model)
+ # set the directory in which evaluation figures are saved
+ cfg.policy.eval.evaluator.figure_path = './'
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ # load the trained checkpoint; 'model path' below is a placeholder for your own checkpoint path
+ model.load_state_dict(torch.load('model path', map_location='cpu')["model"])
+ evaluator.eval(None, -1, -1)
+
+
+if __name__ == "__main__":
+ beergame_ppo_config.exp_name = 'beergame_evaluate'
+ main(beergame_ppo_config)
\ No newline at end of file
diff --git a/dizoo/beergame/envs/BGAgent.py b/dizoo/beergame/envs/BGAgent.py
new file mode 100644
index 0000000000..7c4088c1d0
--- /dev/null
+++ b/dizoo/beergame/envs/BGAgent.py
@@ -0,0 +1,152 @@
+# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
+import argparse
+import numpy as np
+
+
+# Here we want to define the agent class for the BeerGame
+class Agent(object):
+ # initializes the agents with initial values for IL, OO and saves self.agentNum for recognizing the agents.
+ def __init__(
+ self, agentNum: int, IL: int, AO: int, AS: int, c_h: float, c_p: float, eta: int, compuType: str,
+ config: argparse.Namespace
+ ) -> None:
+ self.agentNum = agentNum
+ self.IL = IL # Inventory level of each agent - changes during the game
+ self.OO = 0 # Open order of each agent - changes during the game
+ self.ASInitial = AS # the initial arriving shipment.
+ self.ILInitial = IL # IL at which we start each game with this number
+ self.AOInitial = AO # AO at which we start each game
+ self.config = config # an instance of config is stored inside the class
+ self.curState = [] # the current state of the game
+ self.nextState = []
+ self.curReward = 0 # the reward observed at the current step
+ self.cumReward = 0 # cumulative reward; reset at the beginning of each episode
+ self.totRew = 0 # total reward of all players, attributed to the current player
+ self.c_h = c_h # holding cost
+ self.c_p = c_p # backorder cost
+ self.eta = eta # the total cost regularizer
+ self.AS = np.zeros((1, 1)) # arrived shipment
+ self.AO = np.zeros((1, 1)) # arrived order
+ self.action = 0 # the action at time t
+ self.compType = compuType
+ # self.compTypeTrain = compuType # rnd -> random / srdqn-> srdqn / Strm-> formula-Rong2008 / bs -> optimal policy if exists
+ # self.compTypeTest = compuType # rnd -> random / srdqn-> srdqn / Strm-> formula-Rong2008 / bs -> optimal policy if exists
+ self.alpha_b = self.config.alpha_b[self.agentNum] # parameters for the formula
+ self.betta_b = self.config.betta_b[self.agentNum] # parameters for the formula
+ if self.config.demandDistribution == 0:
+ self.a_b = np.mean((self.config.demandUp, self.config.demandLow)) # parameters for the formula
+ self.b_b = np.mean((self.config.demandUp, self.config.demandLow)) * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ elif self.config.demandDistribution == 1 or self.config.demandDistribution == 3 or self.config.demandDistribution == 4:
+ self.a_b = self.config.demandMu # parameters for the formula
+ self.b_b = self.config.demandMu * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ elif self.config.demandDistribution == 2:
+ self.a_b = 8 # parameters for the formula
+ self.b_b = (3 / 4.) * 8 * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ elif self.config.demandDistribution == 3:
+ self.a_b = 10 # parameters for the formula
+ self.b_b = 7 * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ else:
+ raise Exception('The demand distribution is not defined or it is not a valid type.')
+
+ self.hist = [] # this is used for plotting - keeps the history for only one game
+ self.hist2 = [] # this is used for animation usage
+ self.srdqnBaseStock = [] # this holds the base stock levels that srdqn has come up with
+ self.T = 0
+ self.bsBaseStock = 0
+ self.init_bsBaseStock = 0
+ self.nextObservation = []
+
+ if self.compType == 'srdqn':
+ # sets the initial input of the network
+ self.currentState = np.stack(
+ [self.curState for _ in range(self.config.multPerdInpt)], axis=0
+ ) # multPerdInpt observations stacked. each row is an observation
+
+ # reset player information
+ def resetPlayer(self, T: int):
+ self.IL = self.ILInitial
+ self.OO = 0
+ self.AS = np.squeeze(
+ np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10))
+ ) # arrived shipment
+ self.AO = np.squeeze(
+ np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10))
+ ) # arrived order
+ if self.agentNum != 0:
+ for i in range(self.config.leadRecOrderUp_aux[self.agentNum - 1]):
+ self.AO[i] = self.AOInitial[self.agentNum - 1]
+ for i in range(self.config.leadRecItemUp[self.agentNum]):
+ self.AS[i] = self.ASInitial
+ self.curReward = 0 # the reward observed at the current step
+ self.cumReward = 0 # cumulative reward; reset at the beginning of each episode
+ self.action = []
+ self.hist = []
+ self.hist2 = []
+ self.srdqnBaseStock = [] # this holds the base stock levels that srdqn has come up with
+ self.T = T
+ self.curObservation = self.getCurState(1) # this function gets the current state of the game
+ self.nextObservation = []
+ if self.compType == 'srdqn':
+ self.currentState = np.stack([self.curObservation for _ in range(self.config.multPerdInpt)], axis=0)
+
+ # updates the IL and OO at time t, after receiving the shipment that arrives at `time`
+ def recieveItems(self, time: int) -> None:
+ self.IL = self.IL + self.AS[time] # inventory level update
+ self.OO = self.OO - self.AS[time] # in-transit inventory (open order) update
+
+ # find action Value associated with the action list
+ def actionValue(self, curTime: int) -> int:
+ if self.config.fixedAction:
+ a = self.config.actionList[np.argmax(self.action)]
+ else:
+ # "d + x" rule
+ if self.compType == 'srdqn':
+ a = max(0, self.config.actionList[np.argmax(self.action)] * self.config.action_step + self.AO[curTime])
+ elif self.compType == 'rnd':
+ a = max(0, self.config.actionList[np.argmax(self.action)] + self.AO[curTime])
+ else:
+ a = max(0, self.config.actionListOpt[np.argmax(self.action)])
+
+ return a
+
+ # getReward returns the reward at the current state
+ def getReward(self) -> None:
+ # cost (holding + backorder) for one time unit, divided by 200
+ self.curReward = (self.c_p * max(0, -self.IL) + self.c_h * max(0, self.IL)) / 200.
+ self.curReward = -self.curReward
+ # make the reward negative, because it is a cost
+
+ # sum total reward of each agent
+ self.cumReward = self.config.gamma * self.cumReward + self.curReward
+
+ # This function returns a np.array of the current state of the agent
+ def getCurState(self, t: int) -> np.ndarray:
+ if self.config.ifUseASAO:
+ if self.config.if_use_AS_t_plus_1:
+ curState = np.array(
+ [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t], self.AO[t]]
+ )
+ else:
+ curState = np.array(
+ [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t - 1], self.AO[t]]
+ )
+ else:
+ curState = np.array([-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO])
+
+ if self.config.ifUseActionInD:
+ a = self.config.actionList[np.argmax(self.action)]
+ curState = np.concatenate((curState, np.array([a])))
+
+ return curState
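# Editor's note: a worked example with illustrative numbers (not part of the patch) of how
# one per-step observation and reward are built from the fields above.
import numpy as np

c_h, c_p = 2.0, 2.0                  # holding / backorder cost (config defaults ch1, cp1)
IL, OO, AS_t, AO_t = -3, 5, 2, 4     # hypothetical inventory level, open orders, arrivals

# 5-dim state when ifUseASAO is True: [backorder, on-hand inventory, open orders, AS, AO]
curState = np.array([-1 * (IL < 0) * IL, 1 * (IL > 0) * IL, OO, AS_t, AO_t])
# -> [3, 0, 5, 2, 4]; stacking multPerdInpt=10 such rows gives the 50-dim obs in the config

# per-step reward: the negated holding + backorder cost, divided by 200
reward = -(c_p * max(0, -IL) + c_h * max(0, IL)) / 200.
# -> -(2.0 * 3 + 2.0 * 0) / 200 = -0.03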
diff --git a/dizoo/beergame/envs/__init__.py b/dizoo/beergame/envs/__init__.py
new file mode 100644
index 0000000000..d4ffbfd452
--- /dev/null
+++ b/dizoo/beergame/envs/__init__.py
@@ -0,0 +1,2 @@
+from .clBeergame import clBeerGame
+from .beergame_core import BeerGame
diff --git a/dizoo/beergame/envs/beergame_core.py b/dizoo/beergame/envs/beergame_core.py
new file mode 100644
index 0000000000..2f0ac61910
--- /dev/null
+++ b/dizoo/beergame/envs/beergame_core.py
@@ -0,0 +1,112 @@
+from __future__ import print_function
+from dizoo.beergame.envs import clBeerGame
+from torch import Tensor
+import numpy as np
+import random
+from .utils import get_config, update_config
+import gym
+import os
+from typing import Optional
+
+
+class BeerGame():
+
+ def __init__(self, role: int, agent_type: str, demandDistribution: int) -> None:
+ self._cfg, unparsed = get_config()
+ self._role = role
+ # prepare loggers and directories
+ # prepare_dirs_and_logger(self._cfg)
+ self._cfg = update_config(self._cfg)
+
+ # set agent type
+ if agent_type == 'bs':
+ self._cfg.agentTypes = ["bs", "bs", "bs", "bs"]
+ elif agent_type == 'Strm':
+ self._cfg.agentTypes = ["Strm", "Strm", "Strm", "Strm"]
+ self._cfg.agentTypes[role] = "srdqn"
+
+ self._cfg.demandDistribution = demandDistribution
+
+ # load demands:0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data
+ if self._cfg.observation_data:
+ adsr = 'data/demandTr-obs-'
+ elif self._cfg.demandDistribution == 3:
+ if self._cfg.scaled:
+ adsr = 'data/basket_data/scaled'
+ else:
+ adsr = 'data/basket_data'
+ direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy')
+ self._demandTr = np.load(direc)
+ print("loaded training set=", direc)
+ elif self._cfg.demandDistribution == 4:
+ if self._cfg.scaled:
+ adsr = 'data/forecast_data/scaled'
+ else:
+ adsr = 'data/forecast_data'
+ direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy')
+ self._demandTr = np.load(direc)
+ print("loaded training set=", direc)
+ else:
+ if self._cfg.demandDistribution == 0: # uniform
+ self._demandTr = np.random.randint(0, self._cfg.demandUp, size=[self._cfg.demandSize, self._cfg.TUp])
+ elif self._cfg.demandDistribution == 1: # normal distribution
+ self._demandTr = np.round(
+ np.random.normal(
+ self._cfg.demandMu, self._cfg.demandSigma, size=[self._cfg.demandSize, self._cfg.TUp]
+ )
+ ).astype(int)
+ elif self._cfg.demandDistribution == 2: # the sequence of 4,4,4,4,8,...
+ self._demandTr = np.concatenate(
+ (4 * np.ones((self._cfg.demandSize, 4)), 8 * np.ones((self._cfg.demandSize, 98))), axis=1
+ ).astype(int)
+
+ # initialize an instance of BeerGame
+ self._env = clBeerGame(self._cfg)
+ self.observation_space = gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._cfg.stateDim * self._cfg.multPerdInpt, ),
+ dtype=np.float32
+ ) # state_space = state_dim * m (considering the reward delay)
+ self.action_space = gym.spaces.Discrete(self._cfg.actionListLen) # length of action list
+ self.reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+
+ # get the length of the demand.
+ self._demand_len = np.shape(self._demandTr)[0]
+
+ def reset(self):
+ self._env.resetGame(demand=self._demandTr[random.randint(0, self._demand_len - 1)])
+ obs = [i for item in self._env.players[self._role].currentState for i in item]
+ return obs
+
+ def seed(self, seed: int) -> None:
+ self._seed = seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+ pass
+
+ def step(self, action: np.ndarray):
+ self._env.handelAction(action)
+ self._env.next()
+ newstate = np.append(
+ self._env.players[self._role].currentState[1:, :], [self._env.players[self._role].nextObservation], axis=0
+ )
+ self._env.players[self._role].currentState = newstate
+ obs = [i for item in newstate for i in item]
+ rew = self._env.players[self._role].curReward
+ done = (self._env.curTime == self._env.T)
+ info = {}
+ return obs, rew, done, info
+
+ def reward_shaping(self, reward: Tensor) -> Tensor:
+ self._totRew, self._cumReward = self._env.distTotReward(self._role)
+ reward += (self._cfg.distCoeff / 3) * ((self._totRew - self._cumReward) / (self._env.T))
+ return reward
+
+ def enable_save_figure(self, figure_path: Optional[str] = None) -> None:
+ self._cfg.ifSaveFigure = True
+ if figure_path is None:
+ figure_path = './'
+ self._cfg.figure_dir = figure_path
+ self._env.doTestMid(self._demandTr[random.randint(0, self._demand_len - 1)])
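# Editor's note: a minimal usage sketch (not part of the patch). It assumes the default
# argparse config and uniform demand (demandDistribution=0), so no external data files
# are needed; the chosen action index is arbitrary.
import numpy as np
from dizoo.beergame.envs.beergame_core import BeerGame

env = BeerGame(role=0, agent_type='bs', demandDistribution=0)  # retailer, base-stock co-players
obs = env.reset()                 # flat list of stateDim * multPerdInpt = 50 values
action = np.array(2)              # index into the 5-element "d + x" action list
obs, rew, done, info = env.step(action)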
diff --git a/dizoo/beergame/envs/beergame_env.py b/dizoo/beergame/envs/beergame_env.py
new file mode 100644
index 0000000000..c8e76de9fe
--- /dev/null
+++ b/dizoo/beergame/envs/beergame_env.py
@@ -0,0 +1,84 @@
+import gym
+import numpy as np
+from dizoo.beergame.envs.beergame_core import BeerGame
+from typing import Union, List, Optional
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+import copy
+
+
+@ENV_REGISTRY.register('beergame')
+class BeerGameEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = BeerGame(self._cfg.role, self._cfg.agent_type, self._cfg.demandDistribution)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = self._env.reward_space
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._final_eval_reward = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray) and action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ obs, rew, done, info = self._env.step(action)
+ self._final_eval_reward += rew
+ if done:
+ info['final_eval_reward'] = self._final_eval_reward
+ obs = to_ndarray(obs).astype(np.float32)
+ rew = to_ndarray([rew]).astype(np.float32) # wrap the reward into an array with shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def reward_shaping(self, transitions: List[dict]) -> List[dict]:
+ new_transitions = copy.deepcopy(transitions)
+ for trans in new_transitions:
+ trans['reward'] = self._env.reward_shaping(trans['reward'])
+ return new_transitions
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ def enable_save_figure(self, figure_path: Optional[str] = None) -> None:
+ self._env.enable_save_figure(figure_path)
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Beergame Env"
diff --git a/dizoo/beergame/envs/clBeergame.py b/dizoo/beergame/envs/clBeergame.py
new file mode 100644
index 0000000000..9a237fbb68
--- /dev/null
+++ b/dizoo/beergame/envs/clBeergame.py
@@ -0,0 +1,439 @@
+# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
+import numpy as np
+from random import randint
+from .BGAgent import Agent
+from matplotlib import rc
+rc('text', usetex=True)
+from .plotting import plotting, savePlot
+import matplotlib.pyplot as plt
+import os
+import time
+from time import gmtime, strftime
+
+
+class clBeerGame(object):
+
+ def __init__(self, config):
+ self.config = config
+ self.curGame = 0 # The number associated with the current game (counter of the game)
+ self.curTime = 0
+ self.totIterPlayed = 0 # total iterations of the game, played so far in this and previous games
+ self.players = self.createAgent() # create the agents
+ self.T = 0
+ self.demand = []
+ self.ifOptimalSolExist = self.config.ifOptimalSolExist
+ self.getOptimalSol()
+ self.totRew = 0 # total reward of all players, attributed to the current player
+ self.resultTest = []
+ self.runnerMidlResults = [] # stores the results to use in runner comparisons
+ self.runnerFinlResults = [] # stores the results to use in runner comparisons
+ self.middleTestResult = [
+ ] # stores the whole middle results of bs, Strm, and random to avoid repeating the same tests multiple times.
+ self.runNumber = 0 # the run number which is used when the runner is used
+ self.strNum = 0 # the run number which is used when the runner is used
+
+ # createAgent : Create agent objects (agentNum,IL,OO,c_h,c_p,type,config)
+ def createAgent(self):
+ agentTypes = self.config.agentTypes
+ return [
+ Agent(
+ i, self.config.ILInit[i], self.config.AOInit, self.config.ASInit[i], self.config.c_h[i],
+ self.config.c_p[i], self.config.eta[i], agentTypes[i], self.config
+ ) for i in range(self.config.NoAgent)
+ ]
+
+ # planHorizon : Find a random planning horizon
+ def planHorizon(self):
+ # TLow: minimum number for the planning horizon # TUp: maximum number for the planning horizon
+ # output: The planning horizon which is chosen randomly.
+ return randint(self.config.TLow, self.config.TUp)
+
+ # this function resets the game for start of the new game
+ def resetGame(self, demand: np.ndarray):
+ self.demand = demand
+ self.curTime = 0
+ self.curGame += 1
+ self.totIterPlayed += self.T
+ self.T = self.planHorizon()
+ # reset the required information of player for each episode
+ for k in range(0, self.config.NoAgent):
+ self.players[k].resetPlayer(self.T)
+
+ # update OO when there are initial IL,AO,AS
+ self.update_OO()
+
+ # correction on cost at time T according to the cost of the other players
+ def getTotRew(self):
+ totRew = 0
+ for i in range(self.config.NoAgent):
+ # sum all rewards for the agents and make correction
+ totRew += self.players[i].cumReward
+
+ for i in range(self.config.NoAgent):
+ self.players[i].curReward += self.players[i].eta * (totRew - self.players[i].cumReward) # /(self.T)
+
+ # make correction to the rewards in the experience replay for all iterations of current game
+ def distTotReward(self, role: int):
+ totRew = 0
+ optRew = 0.1 # small constant offset kept from the reference implementation
+ for i in range(self.config.NoAgent):
+ # sum all rewards for the agents and make correction
+ totRew += self.players[i].cumReward
+ totRew += optRew
+
+ return totRew, self.players[role].cumReward
+
+ def getAction(self, k: int, action: np.ndarray, playType="train"):
+ if playType == "train":
+ if self.players[k].compType == "srdqn":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ self.players[k].action[action] = 1
+ elif self.players[k].compType == "Strm":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt)\
+ - max(0, round(self.players[k].AO[self.curTime] + \
+ self.players[k].alpha_b*(self.players[k].IL - self.players[k].a_b) + \
+ self.players[k].betta_b*(self.players[k].OO - self.players[k].b_b)))))] = 1
+ elif self.players[k].compType == "rnd":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ a = np.random.randint(self.config.actionListLen)
+ self.players[k].action[a] = 1
+ elif self.players[k].compType == "bs":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+ if self.config.demandDistribution == 2:
+ if self.curTime and self.config.use_initial_BS <= 4:
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \
+ max(0, (self.players[k].int_bslBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1
+ else:
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \
+ max(0, (self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1
+ else:
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \
+ max(0, (self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1
+ elif playType == "test":
+ if self.players[k].compTypeTest == "srdqn":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ self.players[k].action = self.players[k].brain.getDNNAction(self.playType)
+ elif self.players[k].compTypeTest == "Strm":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,round(self.players[k].AO[self.curTime] +\
+ self.players[k].alpha_b*(self.players[k].IL - self.players[k].a_b) +\
+ self.players[k].betta_b*(self.players[k].OO - self.players[k].b_b)))))] = 1
+ elif self.players[k].compTypeTest == "rnd":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ a = np.random.randint(self.config.actionListLen)
+ self.players[k].action[a] = 1
+ elif self.players[k].compTypeTest == "bs":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+
+ if self.config.demandDistribution == 2:
+ if self.curTime and self.config.use_initial_BS <= 4:
+ self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,(self.players[k].int_bslBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1
+ else:
+ self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,(self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1
+ else:
+ self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,(self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1
+ else:
+ # not a valid player is defined.
+ raise Exception('The player type is not defined or it is not a valid type.!')
+
+ def next(self):
+ # get a random leadtime
+ leadTimeIn = randint(
+ self.config.leadRecItemLow[self.config.NoAgent - 1], self.config.leadRecItemUp[self.config.NoAgent - 1]
+ )
+ # handle the most upstream received shipment
+ self.players[self.config.NoAgent - 1].AS[self.curTime +
+ leadTimeIn] += self.players[self.config.NoAgent -
+ 1].actionValue(self.curTime)
+
+ for k in range(self.config.NoAgent - 1, -1, -1): # [3,2,1,0]
+
+ # get current IL and Backorder
+ current_IL = max(0, self.players[k].IL)
+ current_backorder = max(0, -self.players[k].IL)
+
+ # TODO: get the AS and AO from the UI and update the corresponding variables here
+
+ # increase IL and decrease OO based on the action, for the next period
+ self.players[k].recieveItems(self.curTime)
+
+ # observe the reward
+ possible_shipment = min(
+ current_IL + self.players[k].AS[self.curTime], current_backorder + self.players[k].AO[self.curTime]
+ )
+
+ # plan arrivals of the items to the downstream agent
+ if self.players[k].agentNum > 0:
+ leadTimeIn = randint(self.config.leadRecItemLow[k - 1], self.config.leadRecItemUp[k - 1])
+ self.players[k - 1].AS[self.curTime + leadTimeIn] += possible_shipment
+
+ # update IL
+ self.players[k].IL -= self.players[k].AO[self.curTime]
+ # observe the reward
+ self.players[k].getReward()
+ self.players[k].hist[-1][-2] = self.players[k].curReward
+ self.players[k].hist2[-1][-2] = self.players[k].curReward
+
+ # update next observation
+ self.players[k].nextObservation = self.players[k].getCurState(self.curTime + 1)
+
+ if self.config.ifUseTotalReward:
+ # correction on cost at time T
+ if self.curTime == self.T:
+ self.getTotRew()
+
+ self.curTime += 1
+
+ def handelAction(self, action: np.ndarray, playType="train"):
+ # get random lead time
+ leadTime = randint(self.config.leadRecOrderLow[0], self.config.leadRecOrderUp[0])
+ # set AO
+ self.players[0].AO[self.curTime] += self.demand[self.curTime]
+ for k in range(0, self.config.NoAgent):
+ self.getAction(k, action, playType)
+
+ self.players[k].srdqnBaseStock += [self.players[k].actionValue( \
+ self.curTime) + self.players[k].IL + self.players[k].OO]
+
+ # update hist for the plots
+ self.players[k].hist += [[self.curTime, self.players[k].IL, self.players[k].OO,\
+ self.players[k].actionValue(self.curTime), self.players[k].curReward, self.players[k].srdqnBaseStock[-1]]]
+
+ if self.players[k].compType == "srdqn":
+ self.players[k].hist2 += [[self.curTime, self.players[k].IL, self.players[k].OO, self.players[k].AO[self.curTime], self.players[k].AS[self.curTime], \
+ self.players[k].actionValue(self.curTime), self.players[k].curReward, \
+ self.config.actionList[np.argmax(self.players[k].action)]]]
+
+ else:
+ self.players[k].hist2 += [[self.curTime, self.players[k].IL, self.players[k].OO, self.players[k].AO[self.curTime], self.players[k].AS[self.curTime], \
+ self.players[k].actionValue(self.curTime), self.players[k].curReward, 0]]
+
+ # updates OO and AO at time t+1
+ self.players[k].OO += self.players[k].actionValue(self.curTime) # open order level update
+ leadTime = randint(self.config.leadRecOrderLow[k], self.config.leadRecOrderUp[k])
+ if self.players[k].agentNum < self.config.NoAgent - 1:
+ self.players[k + 1].AO[self.curTime + leadTime] += self.players[k].actionValue(
+ self.curTime
+ ) # open order level update
+
+ # check the Shang and Song (2003) condition; if it holds, obtain the base stock policy values for each agent
+ def getOptimalSol(self):
+ # if self.config.NoAgent !=1:
+ if self.config.NoAgent != 1 and 1 == 2:
+ # check the Shang and Song (2003) condition.
+ for k in range(self.config.NoAgent - 1):
+ if not (self.players[k].c_h == self.players[k + 1].c_h and self.players[k + 1].c_p == 0):
+ self.ifOptimalSolExist = False
+
+ # if the Shang and Song (2003) condition satisfied, it runs the algorithm
+ if self.ifOptimalSolExist == True:
+ calculations = np.zeros((7, self.config.NoAgent))
+ for k in range(self.config.NoAgent):
+ # DL_high
+ calculations[0][k] = ((self.config.leadRecItemLow + self.config.leadRecItemUp + 2) / 2 \
+ + (self.config.leadRecOrderLow + self.config.leadRecOrderUp + 2) / 2) * \
+ (self.config.demandUp - self.config.demandLow - 1)
+ if k > 0:
+ calculations[0][k] += calculations[0][k - 1]
+ # probability_high
+ nominator_ch = 0
+ low_denominator_ch = 0
+ for j in range(k, self.config.NoAgent):
+ if j < self.config.NoAgent - 1:
+ nominator_ch += self.players[j + 1].c_h
+ low_denominator_ch += self.players[j].c_h
+ if k == 0:
+ high_denominator_ch = low_denominator_ch
+ calculations[2][k] = (self.players[0].c_p +
+ nominator_ch) / (self.players[0].c_p + low_denominator_ch + 0.0)
+ # probability_low
+ calculations[3][k] = (self.players[0].c_p +
+ nominator_ch) / (self.players[0].c_p + high_denominator_ch + 0.0)
+ # S_high
+ calculations[4] = np.round(np.multiply(calculations[0], calculations[2]))
+ # S_low
+ calculations[5] = np.round(np.multiply(calculations[0], calculations[3]))
+ # S_avg
+ calculations[6] = np.round(np.mean(calculations[4:6], axis=0))
+ # S', set the base stock values into each agent.
+ for k in range(self.config.NoAgent):
+ if k == 0:
+ self.players[k].bsBaseStock = calculations[6][k]
+
+ else:
+ self.players[k].bsBaseStock = calculations[6][k] - calculations[6][k - 1]
+ if self.players[k].bsBaseStock < 0:
+ self.players[k].bsBaseStock = 0
+ elif self.config.NoAgent == 1:
+ if self.config.demandDistribution == 0:
+ self.players[0].bsBaseStock = np.ceil(
+ self.config.c_h[0] / (self.config.c_h[0] + self.config.c_p[0] + 0.0)
+ ) * ((self.config.demandUp - self.config.demandLow - 1) / 2) * self.config.leadRecItemUp
+ elif 1 == 1:
+ f = self.config.f
+ f_init = self.config.f_init
+ for k in range(self.config.NoAgent):
+ self.players[k].bsBaseStock = f[k]
+ self.players[k].int_bslBaseStock = f_init[k]
+
+ def update_OO(self):
+ for k in range(0, self.config.NoAgent):
+ if k < self.config.NoAgent - 1:
+ self.players[k].OO = sum(self.players[k + 1].AO) + sum(self.players[k].AS)
+ else:
+ self.players[k].OO = sum(self.players[k].AS)
+
+ def doTestMid(self, demandTs):
+ self.resultTest = []
+ m = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
+ self.doTest(m, demandTs)
+ print("---------------------------------------------------------------------------------------")
+ resultSummary = np.array(self.resultTest).mean(axis=0).tolist()
+
+ result_srdqn = ', '.join(map("{:.2f}".format, resultSummary[0]))
+ result_rand = ', '.join(map("{:.2f}".format, resultSummary[1]))
+ result_strm = ', '.join(map("{:.2f}".format, resultSummary[2]))
+ if self.ifOptimalSolExist:
+ result_bs = ', '.join(map("{:.2f}".format, resultSummary[3]))
+ print(
+ 'SUMMARY; {0:s}; ITER= {1:d}; OURPOLICY= [{2:s}]; SUM = {3:2.4f}; Rand= [{4:s}]; SUM = {5:2.4f}; STRM= [{6:s}]; SUM = {7:2.4f}; BS= [{8:s}]; SUM = {9:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), self.curGame, result_srdqn, sum(resultSummary[0]),
+ result_rand, sum(resultSummary[1]), result_strm, sum(resultSummary[2]), result_bs,
+ sum(resultSummary[3])
+ )
+ )
+
+ else:
+ print(
+ 'SUMMARY; {0:s}; ITER= {1:d}; OURPOLICY= [{2:s}]; SUM = {3:2.4f}; Rand= [{4:s}]; SUM = {5:2.4f}; STRM= [{6:s}]; SUM = {7:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), self.curGame, result_srdqn, sum(resultSummary[0]),
+ result_rand, sum(resultSummary[1]), result_strm, sum(resultSummary[2])
+ )
+ )
+
+ print("=======================================================================================")
+
+ def doTest(self, m, demand):
+ import matplotlib.pyplot as plt
+ if self.config.ifSaveFigure:
+ plt.figure(self.curGame, figsize=(12, 8), dpi=80, facecolor='w', edgecolor='k')
+
+ # self.demand = demand
+ # use dnn to get output.
+ Rsltdnn, plt = self.tester(self.config.agentTypes, plt, 'b', 'OurPolicy', m)
+ baseStockdata = self.players[0].srdqnBaseStock
+ # use random to get output.
+ RsltRnd, plt = self.tester(["rnd", "rnd", "rnd", "rnd"], plt, 'y', 'RAND', m)
+
+ # use formula to get output.
+ RsltStrm, plt = self.tester(["Strm", "Strm", "Strm", "Strm"], plt, 'g', 'Strm', m)
+
+ # use optimal strategy to get output, if it works.
+ if self.ifOptimalSolExist:
+ if self.config.agentTypes == ["srdqn", "Strm", "Strm", "Strm"]:
+ Rsltbs, plt = self.tester(["bs", "Strm", "Strm", "Strm"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["Strm", "srdqn", "Strm", "Strm"]:
+ Rsltbs, plt = self.tester(["Strm", "bs", "Strm", "Strm"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["Strm", "Strm", "srdqn", "Strm"]:
+ Rsltbs, plt = self.tester(["Strm", "Strm", "bs", "Strm"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["Strm", "Strm", "Strm", "srdqn"]:
+ Rsltbs, plt = self.tester(["Strm", "Strm", "Strm", "bs"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["srdqn", "rnd", "rnd", "rnd"]:
+ Rsltbs, plt = self.tester(["bs", "rnd", "rnd", "rnd"], plt, 'r', 'RND-BS', m)
+ elif self.config.agentTypes == ["rnd", "srdqn", "rnd", "rnd"]:
+ Rsltbs, plt = self.tester(["rnd", "bs", "rnd", "rnd"], plt, 'r', 'RND-BS', m)
+ elif self.config.agentTypes == ["rnd", "rnd", "srdqn", "rnd"]:
+ Rsltbs, plt = self.tester(["rnd", "rnd", "bs", "rnd"], plt, 'r', 'RND-BS', m)
+ elif self.config.agentTypes == ["rnd", "rnd", "rnd", "srdqn"]:
+ Rsltbs, plt = self.tester(["rnd", "rnd", "rnd", "bs"], plt, 'r', 'RND-BS', m)
+ else:
+ Rsltbs, plt = self.tester(["bs", "bs", "bs", "bs"], plt, 'r', 'BS', m)
+ # hold the results of the optimal solution
+ self.middleTestResult += [[RsltRnd, RsltStrm, Rsltbs]]
+ else:
+ self.middleTestResult += [[RsltRnd, RsltStrm]]
+
+ else:
+ # return the obtained results into their lists
+ RsltRnd = self.middleTestResult[m][0]
+ RsltStrm = self.middleTestResult[m][1]
+ if self.ifOptimalSolExist:
+ Rsltbs = self.middleTestResult[m][2]
+
+ # save the figure
+ if self.config.ifSaveFigure:
+ savePlot(self.players, self.curGame, Rsltdnn, RsltStrm, Rsltbs, RsltRnd, self.config, m)
+ plt.close()
+
+ result_srdqn = ', '.join(map("{:.2f}".format, Rsltdnn))
+ result_rand = ', '.join(map("{:.2f}".format, RsltRnd))
+ result_strm = ', '.join(map("{:.2f}".format, RsltStrm))
+ if self.ifOptimalSolExist:
+ result_bs = ', '.join(map("{:.2f}".format, Rsltbs))
+ print(
+ 'output; {0:s}; Iter= {1:s}; SRDQN= [{2:s}]; sum = {3:2.4f}; Rand= [{4:s}]; sum = {5:2.4f}; Strm= [{6:s}]; sum = {7:2.4f}; BS= [{8:s}]; sum = {9:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), str(str(self.curGame) + "-" + str(m)), result_srdqn,
+ sum(Rsltdnn), result_rand, sum(RsltRnd), result_strm, sum(RsltStrm), result_bs, sum(Rsltbs)
+ )
+ )
+ self.resultTest += [[Rsltdnn, RsltRnd, RsltStrm, Rsltbs]]
+
+ else:
+ print(
+ 'output; {0:s}; Iter= {1:s}; SRDQN= [{2:s}]; sum = {3:2.4f}; Rand= [{4:s}]; sum = {5:2.4f}; Strm= [{6:s}]; sum = {7:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), str(str(self.curGame) + "-" + str(m)), result_srdqn,
+ sum(Rsltdnn), result_rand, sum(RsltRnd), result_strm, sum(RsltStrm)
+ )
+ )
+
+ self.resultTest += [[Rsltdnn, RsltRnd, RsltStrm]]
+
+ return sum(Rsltdnn)
+
+ def tester(self, testType, plt, colori, labeli, m):
+
+ # set computation type for test
+ for k in range(0, self.config.NoAgent):
+ # self.players[k].compTypeTest = testType[k]
+ self.players[k].compType = testType[k]
+ # run the episode to get the results.
+ if labeli != 'OurPolicy':
+ result = self.playGame(self.demand)
+ else:
+ result = [-1 * self.players[i].cumReward for i in range(0, self.config.NoAgent)]
+ # add the results into the figure
+ if self.config.ifSaveFigure:
+ plt = plotting(plt, [np.array(self.players[i].hist) for i in range(0, self.config.NoAgent)], colori, labeli)
+ if self.config.ifsaveHistInterval and ((self.curGame == 0) or (self.curGame == 1) or (self.curGame == 2) or (self.curGame == 3) or ((self.curGame - 1) % self.config.saveHistInterval == 0)\
+ or ((self.curGame) % self.config.saveHistInterval == 0) or ((self.curGame) % self.config.saveHistInterval == 1) \
+ or ((self.curGame) % self.config.saveHistInterval == 2)) :
+ for k in range(0, self.config.NoAgent):
+ name = labeli + "-" + str(self.curGame) + "-" + "player" + "-" + str(k) + "-" + str(m)
+ np.save(os.path.join(self.config.model_dir, name), np.array(self.players[k].hist2))
+
+ # save the figure of base stocks
+ # if self.config.ifSaveFigure and (self.curGame in range(self.config.saveFigInt[0],self.config.saveFigInt[1])):
+ # for k in range(self.config.NoAgent):
+ # if self.players[k].compTypeTest == 'dnn':
+ # plotBaseStock(self.players[k].srdqnBaseStock, 'b', 'base stock of agent '+ str(self.players[k].agentNum), self.curGame, self.config, m)
+
+ return result, plt
+
+ def playGame(self, demand):
+ self.resetGame(demand)
+
+ # run the game
+ while self.curTime < self.T:
+ self.handelAction(np.array(0)) # action won't be used.
+ self.next()
+ return [-1 * self.players[i].cumReward for i in range(0, self.config.NoAgent)]
diff --git a/dizoo/beergame/envs/plotting.py b/dizoo/beergame/envs/plotting.py
new file mode 100644
index 0000000000..57776c9641
--- /dev/null
+++ b/dizoo/beergame/envs/plotting.py
@@ -0,0 +1,72 @@
+# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
+import os
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from pylab import *
+
+
+# plotting
+def plotting(plt, data, colori, pltLabel):
+ # plt.hold(True)
+
+ for i in range(np.shape(data)[0]):
+ plt.subplot(4, 5, 5 * i + 1)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[1, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('IL')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 2)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[2, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('OO')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 3)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[3, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('a')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 4)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[5, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('OUTL')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 5)
+ plt.plot(np.transpose(data[i])[0, :], -1 * np.transpose(data[i])[4, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('r')
+ plt.grid(True)
+
+ return plt
+
+
+def savePlot(players, curGame, Rsltdnn, RsltFrmu, RsltOptm, RsltRnd, config, m):
+ #add title to plot
+ if config.if_titled_figure:
+ plt.suptitle(
+ "sum OurPolicy=" + str(round(sum(Rsltdnn), 2)) + "; sum Strm=" + str(round(sum(RsltFrmu), 2)) +
+ "; sum BS=" + str(round(sum(RsltOptm), 2)) + "; sum Rnd=" + str(round(sum(RsltRnd), 2)) + "\n" +
+ "Ag OurPolicy=" + str([round(Rsltdnn[i], 2) for i in range(config.NoAgent)]) + "; Ag Strm=" +
+ str([round(RsltFrmu[i], 2) for i in range(config.NoAgent)]) + "; Ag BS=" +
+ str([round(RsltOptm[i], 2) for i in range(config.NoAgent)]) + "; Ag Rnd=" +
+ str([round(RsltRnd[i], 2) for i in range(config.NoAgent)]),
+ fontsize=12
+ )
+
+ #insert legend to the figure
+ legend = plt.legend(bbox_to_anchor=(-1.4, -.165, 1., -.102), shadow=True, ncol=4)
+
+ # configures spaces between subplots
+ plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=.5, hspace=.5)
+ # save the figure
+ path = os.path.join(config.figure_dir, 'saved_figures/')
+ if not os.path.exists(path):
+ os.mkdir(path)
+ plt.savefig(path + str(curGame) + '-' + str(m) + '.png', format='png')
+ print("figure" + str(curGame) + ".png saved in folder \"saved_figures\"")
+ plt.close(curGame)
diff --git a/dizoo/beergame/envs/utils.py b/dizoo/beergame/envs/utils.py
new file mode 100644
index 0000000000..2a6cf6f83d
--- /dev/null
+++ b/dizoo/beergame/envs/utils.py
@@ -0,0 +1,355 @@
+import argparse
+import os
+import numpy as np
+
+
+def str2bool(v):
+ return v.lower() in ('true', '1')
+
+
+arg_lists = []
+parser = argparse.ArgumentParser()
+
+
+def add_argument_group(name):
+ arg = parser.add_argument_group(name)
+ arg_lists.append(arg)
+ return arg
+
+
+# crm
+game_arg = add_argument_group('BeerGame')
+game_arg.add_argument('--task', type=str, default='bg')
+game_arg.add_argument(
+ '--fixedAction',
+ type=str2bool,
+ default='False',
+ help='if you want actions in [0, actionMax], set it to True; with False the action range is [actionLow, actionUp]'
+)
+game_arg.add_argument(
+ '--observation_data',
+ type=str2bool,
+ default=False,
+ help='if it is True, it uses data generated based on a few real-world observations'
+)
+game_arg.add_argument('--data_id', type=int, default=22, help='the default item id for the basket dataset')
+game_arg.add_argument('--TLow', type=int, default=100, help='duration of one GAME (lower bound)')
+game_arg.add_argument('--TUp', type=int, default=100, help='duration of one GAME (upper bound)')
+game_arg.add_argument(
+ '--demandDistribution',
+ type=int,
+ default=0,
+ help='0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data'
+)
+game_arg.add_argument(
+ '--scaled', type=str2bool, default=False, help='if true it uses the (if) existing scaled parameters'
+)
+game_arg.add_argument('--demandSize', type=int, default=6100, help='the size of demand dataset')
+game_arg.add_argument('--demandLow', type=int, default=0, help='the lower bound of random demand')
+game_arg.add_argument('--demandUp', type=int, default=3, help='the upper bound of random demand')
+game_arg.add_argument('--demandMu', type=float, default=10, help='the mu of the normal distribution for demand ')
+game_arg.add_argument('--demandSigma', type=float, default=2, help='the sigma of the normal distribution for demand ')
+game_arg.add_argument('--actionMax', type=int, default=2, help='it works when fixedAction is True')
+game_arg.add_argument(
+ '--actionUp', type=int, default=2, help='bounds on my decision (upper bound), it works when fixedAction is True'
+)
+game_arg.add_argument(
+ '--actionLow', type=int, default=-2, help='bounds on my decision (lower bound), it works when fixedAction is True'
+)
+game_arg.add_argument(
+ '--action_step', type=int, default=1, help='The obtained action value by dnn is multiplied by this value'
+)
+game_arg.add_argument('--actionList', type=list, default=[], help='The list of the available actions')
+game_arg.add_argument('--actionListLen', type=int, default=0, help='the length of the action list')
+game_arg.add_argument(
+ '--actionListOpt', type=int, default=0, help='the action list which is used in optimal and sterman'
+)
+game_arg.add_argument('--actionListLenOpt', type=int, default=0, help='the length of the actionlistopt')
+game_arg.add_argument('--agentTypes', type=list, default=['dnn', 'dnn', 'dnn', 'dnn'], help='the player types')
+game_arg.add_argument(
+ '--agent_type1', type=str, default='dnn', help='the player types for agent 1, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument(
+ '--agent_type2', type=str, default='dnn', help='the player types for agent 2, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument(
+ '--agent_type3', type=str, default='dnn', help='the player types for agent 3, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument(
+ '--agent_type4', type=str, default='dnn', help='the player types for agent 4, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument('--NoAgent', type=int, default=4, help='number of agents, currently it should be in {1,2,3,4}')
+game_arg.add_argument('--cp1', type=float, default=2.0, help='shortage cost of player 1')
+game_arg.add_argument('--cp2', type=float, default=0.0, help='shortage cost of player 2')
+game_arg.add_argument('--cp3', type=float, default=0.0, help='shortage cost of player 3')
+game_arg.add_argument('--cp4', type=float, default=0.0, help='shortage cost of player 4')
+game_arg.add_argument('--ch1', type=float, default=2.0, help='holding cost of player 1')
+game_arg.add_argument('--ch2', type=float, default=2.0, help='holding cost of player 2')
+game_arg.add_argument('--ch3', type=float, default=2.0, help='holding cost of player 3')
+game_arg.add_argument('--ch4', type=float, default=2.0, help='holding cost of player 4')
+game_arg.add_argument('--alpha_b1', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 1')
+game_arg.add_argument('--alpha_b2', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 2')
+game_arg.add_argument('--alpha_b3', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 3')
+game_arg.add_argument('--alpha_b4', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 4')
+game_arg.add_argument('--betta_b1', type=float, default=-0.2, help='beta of Sterman formula parameter for player 1')
+game_arg.add_argument('--betta_b2', type=float, default=-0.2, help='beta of Sterman formula parameter for player 2')
+game_arg.add_argument('--betta_b3', type=float, default=-0.2, help='beta of Sterman formula parameter for player 3')
+game_arg.add_argument('--betta_b4', type=float, default=-0.2, help='beta of Sterman formula parameter for player 4')
+game_arg.add_argument('--eta', type=list, default=[0, 4, 4, 4], help='the total cost regularizer')
+game_arg.add_argument('--distCoeff', type=int, default=20, help='the coefficient used to distribute the total reward')
+game_arg.add_argument(
+ '--ifUseTotalReward',
+ type=str2bool,
+ default='False',
+ help='if you want to have the total rewards in the experience replay, set it to true.'
+)
+game_arg.add_argument(
+ '--ifUsedistTotReward',
+ type=str2bool,
+ default='True',
+ help='If use correction to the rewards in the experience replay for all iterations of current game'
+)
+game_arg.add_argument(
+ '--ifUseASAO',
+ type=str2bool,
+ default='True',
+ help='if use AS and AO, i.e., received shipment and received orders in the input of DNN'
+)
+game_arg.add_argument('--ifUseActionInD', type=str2bool, default='False', help='if use action in the input of DNN')
+game_arg.add_argument(
+ '--stateDim', type=int, default=5, help='Number of elements in the state descriptor - Depends on ifUseASAO'
+)
+game_arg.add_argument('--iftl', type=str2bool, default=False, help='if apply transfer learning')
+game_arg.add_argument(
+ '--ifTransferFromSmallerActionSpace',
+ type=str2bool,
+ default=False,
+ help='if want to transfer knowledge from a network with different action space size.'
+)
+game_arg.add_argument(
+ '--baseActionSize',
+ type=int,
+ default=5,
+ help='if ifTransferFromSmallerActionSpace is true, this determines the size of action space of saved network'
+)
+game_arg.add_argument(
+ '--tlBaseBrain',
+ type=int,
+ default=3,
+ help='the gameConfig of the base network for re-training with transfer-learning'
+)
+game_arg.add_argument('--baseDemandDistribution', type=int, default=0, help='same as the demandDistribution')
+game_arg.add_argument(
+ '--MultiAgent', type=str2bool, default=False, help='if run multi-agent RL model, not fully operational'
+)
+game_arg.add_argument(
+ '--MultiAgentRun',
+ type=list,
+ default=[True, True, True, True],
+ help='In the multi-RL setting, it determines which agent should get training.'
+)
+game_arg.add_argument(
+ '--if_use_AS_t_plus_1', type=str2bool, default='False', help='if use AS[t+1], not AS[t] in the input of DNN'
+)
+game_arg.add_argument(
+ '--ifSinglePathExist',
+ type=str2bool,
+ default=False,
+ help='If true it uses the predefined path in pre_model_dir and does not merge it with demandDistribution.'
+)
+game_arg.add_argument('--gamma', type=float, default=.99, help='discount factor for reward')
+game_arg.add_argument(
+ '--multPerdInpt', type=int, default=10, help='Number of history records which we feed into the network'
+)
+
+# parameters of the leadtimes
+leadtimes_arg = add_argument_group('leadtimes')
+leadtimes_arg.add_argument(
+ '--leadRecItemLow', type=list, default=[2, 2, 2, 4], help='the min lead time for receiving items'
+)
+leadtimes_arg.add_argument(
+ '--leadRecItemUp', type=list, default=[2, 2, 2, 4], help='the max lead time for receiving items'
+)
+leadtimes_arg.add_argument(
+ '--leadRecOrderLow', type=list, default=[2, 2, 2, 0], help='the min lead time for receiving orders'
+)
+leadtimes_arg.add_argument(
+ '--leadRecOrderUp', type=list, default=[2, 2, 2, 0], help='the max lead time for receiving orders'
+)
+leadtimes_arg.add_argument('--ILInit', type=list, default=[0, 0, 0, 0], help='')
+leadtimes_arg.add_argument('--AOInit', type=list, default=[0, 0, 0, 0], help='')
+leadtimes_arg.add_argument('--ASInit', type=list, default=[0, 0, 0, 0], help='the initial shipment of each agent')
+leadtimes_arg.add_argument('--leadRecItem1', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecItem2', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecItem3', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecItem4', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecOrder1', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--leadRecOrder2', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--leadRecOrder3', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--leadRecOrder4', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--ILInit1', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--ILInit2', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--ILInit3', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--ILInit4', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--AOInit1', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--AOInit2', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--AOInit3', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--AOInit4', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--ASInit1', type=int, default=0, help='the initial arriving shipment of the agent')
+leadtimes_arg.add_argument('--ASInit2', type=int, default=0, help='the initial arriving shipment of the agent')
+leadtimes_arg.add_argument('--ASInit3', type=int, default=0, help='the initial arriving shipment of the agent')
+leadtimes_arg.add_argument('--ASInit4', type=int, default=0, help='the initial arriving shipment of the agent')
+
+# test
+test_arg = add_argument_group('testing')
+test_arg.add_argument(
+ '--testRepeatMid',
+ type=int,
+ default=50,
+ help='the number of episodes used for testing in the middle of training'
+)
+test_arg.add_argument('--testInterval', type=int, default=100, help='every xx games compute "test error"')
+test_arg.add_argument(
+ '--ifSaveFigure', type=str2bool, default=True, help='if is it True, save the figures in each testing.'
+)
+test_arg.add_argument(
+ '--if_titled_figure',
+ type=str2bool,
+ default='True',
+ help='if is it True, save the figures with details in the title.'
+)
+test_arg.add_argument(
+ '--ifsaveHistInterval', type=str2bool, default=False, help='if every xx games save details of the episode'
+)
+test_arg.add_argument('--saveHistInterval', type=int, default=50000, help='every xx games save details of the play')
+test_arg.add_argument('--Ttest', type=int, default=100, help='it defines the number of periods in the test cases')
+test_arg.add_argument(
+ '--ifOptimalSolExist',
+ type=str2bool,
+ default=True,
+ help='if the instance has optimal base stock policy, set it to True, otherwise it should be False.'
+)
+test_arg.add_argument('--f1', type=float, default=8, help='base stock policy decision of player 1')
+test_arg.add_argument('--f2', type=float, default=8, help='base stock policy decision of player 2')
+test_arg.add_argument('--f3', type=float, default=0, help='base stock policy decision of player 3')
+test_arg.add_argument('--f4', type=float, default=0, help='base stock policy decision of player 4')
+test_arg.add_argument(
+ '--f_init',
+ type=list,
+ default=[32, 32, 32, 24],
+ help='base stock policy decision for 4 time-steps on the C(4,8) demand distribution'
+)
+test_arg.add_argument('--use_initial_BS', type=str2bool, default=False, help='If use f_init set it to True')
+
+# reporting
+reporting_arg = add_argument_group('reporting')
+reporting_arg.add_argument('--Rsltdnn', type=list, default=[], help='the result of dnn play tests will be saved here')
+reporting_arg.add_argument(
+ '--RsltRnd', type=list, default=[], help='the result of random play tests will be saved here'
+)
+reporting_arg.add_argument(
+ '--RsltStrm', type=list, default=[], help='the result of heuristic formula play tests will be saved here'
+)
+reporting_arg.add_argument(
+ '--Rsltbs', type=list, default=[], help='the result of optimal play tests will be saved here'
+)
+reporting_arg.add_argument(
+ '--ifSaveHist',
+ type=str2bool,
+ default='False',
+ help=
+ 'if it is true, it saves the history, prediction, and randBatch in each period. WARNING: only set it True in small runs, since it saves a huge amount of files.'
+)
+
+
+# buildActionList: actions for the beer game problem
+def buildActionList(config):
+ aDiv = 1 # difference in the action list
+ if config.fixedAction:
+ actions = list(
+ range(0, config.actionMax + 1, aDiv)
+ ) # creates the action list [0, ..., actionMax]
+ else:
+ actions = list(range(config.actionLow, config.actionUp + 1, aDiv))
+ return actions
+
+
+# specify the dimension of the state of the game
+def getStateDim(config):
+ if config.ifUseASAO:
+ stateDim = 5
+ else:
+ stateDim = 3
+
+ if config.ifUseActionInD:
+ stateDim += 1
+
+ return stateDim
+
+
+def set_optimal(config):
+ if config.demandDistribution == 0:
+ if config.cp1 == 2 and config.ch1 == 2 and config.ch2 == 2 and config.ch3 == 2 and config.ch4 == 2:
+ config.f1 = 8.
+ config.f2 = 8.
+ config.f3 = 0.
+ config.f4 = 0.
+
+
+def get_config():
+ config, unparsed = parser.parse_known_args()
+ config = update_config(config)
+
+ return config, unparsed
+
+
+def fill_leadtime_initial_values(config):
+ config.leadRecItemLow = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4]
+ config.leadRecItemUp = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4]
+ config.leadRecOrderLow = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4]
+ config.leadRecOrderUp = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4]
+ config.ILInit = [config.ILInit1, config.ILInit2, config.ILInit3, config.ILInit4]
+ config.AOInit = [config.AOInit1, config.AOInit2, config.AOInit3, config.AOInit4]
+ config.ASInit = [config.ASInit1, config.ASInit2, config.ASInit3, config.ASInit4]
+
+
+def get_auxuliary_leadtime_initial_values(config):
+ config.leadRecOrderUp_aux = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4]
+ config.leadRecItemUp_aux = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4]
+
+
+def fix_lead_time_manufacturer(config):
+ if config.leadRecOrder4 > 0:
+ config.leadRecItem4 += config.leadRecOrder4
+ config.leadRecOrder4 = 0
+
+
+def set_sterman_parameters(config):
+ config.alpha_b = [config.alpha_b1, config.alpha_b2, config.alpha_b3, config.alpha_b4]
+ config.betta_b = [config.betta_b1, config.betta_b2, config.betta_b3, config.betta_b4]
+
+
+def update_config(config):
+ config.actionList = buildActionList(config) # The list of the available actions
+ config.actionListLen = len(config.actionList) # the length of the action list
+
+ set_optimal(config)
+ config.f = [config.f1, config.f2, config.f3, config.f4] # [6.4, 2.88, 2.08, 0.8]
+
+ config.actionListLen = len(config.actionList)
+ if config.demandDistribution == 0:
+ config.actionListOpt = list(range(0, int(max(config.actionUp * 30 + 1, 3 * sum(config.f))), 1))
+ else:
+ config.actionListOpt = list(range(0, int(max(config.actionUp * 30 + 1, 7 * sum(config.f))), 1))
+ config.actionListLenOpt = len(config.actionListOpt)
+
+ config.c_h = [config.ch1, config.ch2, config.ch3, config.ch4]
+ config.c_p = [config.cp1, config.cp2, config.cp3, config.cp4]
+
+ config.stateDim = getStateDim(config) # Number of elements in the state description - Depends on ifUseASAO
+ get_auxuliary_leadtime_initial_values(config)
+ fix_lead_time_manufacturer(config)
+ fill_leadtime_initial_values(config)
+ set_sterman_parameters(config)
+
+ return config
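# Editor's note: an illustrative check (not part of the patch) of the fields that
# update_config derives from the argparse defaults above; get_config uses
# parse_known_args, so it tolerates unrelated CLI flags.
from dizoo.beergame.envs.utils import get_config

cfg, _ = get_config()
print(cfg.actionList)     # [-2, -1, 0, 1, 2] with the default actionLow / actionUp
print(cfg.actionListLen)  # 5, matching action_shape in beergame_onppo_config.py
print(cfg.stateDim)       # 5 when ifUseASAO is True, so obs_shape = 5 * multPerdInpt = 50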