diff --git a/POMDPy/examples/sepsis/sepsis.py b/POMDPy/examples/sepsis/sepsis.py
index d90f337..05b53cb 100644
--- a/POMDPy/examples/sepsis/sepsis.py
+++ b/POMDPy/examples/sepsis/sepsis.py
@@ -2,8 +2,8 @@
 import random
 import numpy as np
 import sys
-sys.path.append('/next/u/hjnam/locf/env/sepsisSimDiabetes')
-sys.path.append('/next/u/hjnam/POMDPy')
+sys.path.append('../../locf/env/sepsisSimDiabetes')
+sys.path.append('../../POMDPy')
 import pickle
 from pomdpy.discrete_pomdp import DiscreteActionPool, DiscreteObservationPool
 from pomdpy.discrete_pomdp import DiscreteAction
@@ -190,15 +190,11 @@ def generate_particles(self, prev_belief, action, obs, n_particles, prev_particl
 
     '''
     Can use a plug-in empirical simulator
-    Inp: action (dtype int), state (dtype int)
+    Input: action (dtype int), state (dtype int)
     1) # of samples
     2) exact noise level
     '''
     def empirical_simulate(self, state, action):
-        # return self.sim.step(action, state)
-        # rew = 0
-        # if action > 7:
-        #     rew += self.cost
         action = action % 8
         if (state, action) in self.empi_model.keys():
             p = self.empi_model[(state, action)]
@@ -214,18 +210,6 @@ def empirical_simulate(self, state, action):
             rew = temp.env.calculateReward()
         return BoxState(int(state), is_terminal=bool(rew != 0), r=rew), True
 
-    # def make_next_state(self, state, action):
-    #     if state.terminal:
-    #         return state.copy(), False
-    #     if type(action) is not int:
-    #         action = action.bin_number
-    #     if type(state) is not int:
-    #         state = state.position
-    #     # this should be an imagined step in the learned simulator
-    #     _, rew, done, info = self.empirical_simulate(action, state)
-    #     next_pos = info['true_state']
-    #     return BoxState(int(next_pos), is_terminal=done, r=rew), True
-
     '''
     In the real env, observation = state \cup NA
     but always return TRUE
@@ -251,10 +235,6 @@ def make_next_position(self, state, action):
         if type(state) is not int:
             state = state.position
         return self.empirical_simulate(state, action)
-        # # should be through the learned simulator
-        # _, _, _, info = self.empirical_simulate(action, state)
-        # next_position = info['true_state']
-        # return int(next_position)
 
     def get_all_observations(self):
         obs = {}
@@ -292,7 +272,7 @@ def generate_step(self, state, action, _true=False, is_mdp=False):
         else:
             result.next_state, is_legal = self.make_next_position(state, action)
 
-        ### for true runs #####
+        ### for true/eval runs #####
         # result.next_state, is_legal = self.take_real_state(state, action)
         ########################
 
@@ -302,19 +282,6 @@
         result.is_terminal = result.next_state.terminal
         return result, is_legal
 
-
-    '''
-    def mdp_generate_step(self, state, action):
-        if type(action) is int:
-            action = BoxAction(action)
-        result = StepResult()
-        result.next_state, is_legal = self.make_next_position(state, action)
-        result.action = action.copy()
-        result.observation = self.make_observation(action, result.next_state, always_obs=True)
-        result.reward = self.make_reward(state, action, result.next_state, is_legal, always_obs=True)
-        result.is_terminal = result.next_state.terminal
-        return result, is_legal
-    '''
 
     def reset_for_simulation(self):
         pass
diff --git a/README.md b/README.md
index e1b9946..c031f55 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,15 @@
 # acno_mdp
+
+Credits:
+- DVRL encoder & a2c are slightly adapted from https://github.com/maximilianigl/DVRL (Igl et al.)
+- Monte-Carlo Tree Search and POMDP tree search code is from https://github.com/pemami4911/POMDPy.
+@ARTICLE{emami2015pomdpy,
+  author = {Emami, Patrick and Hamlet, Alan J. and Crane, Carl},
+  title = {POMDPy: An Extensible Framework for Implementing POMDPs in Python},
+  year = {2015},
+}
+- Sepsis simulator code is from https://github.com/clinicalml/gumbel-max-scm.
+
 Conda packages and versions used for generating the reported results are shared in conda.yml (Note not all the packages are needed).
 
 To run the known observation belief encoder
@@ -21,5 +32,3 @@
 python pomcp.py --init_idx 256 --cost -0.1 --is_mdp 0
 2. MCTS
 python pomcp.py --init_idx 256 --cost -0.05 --is_mdp 1
-
-
diff --git a/known_obs/code/known_pf_model_alg2.py b/known_obs/code/known_pf_model_alg2.py
index c00cd45..7b9f972 100644
--- a/known_obs/code/known_pf_model_alg2.py
+++ b/known_obs/code/known_pf_model_alg2.py
@@ -89,7 +89,7 @@ def __init__(self,
         if self.halved_acts:
             # know only 2 actions (dx, dy) affect the agent dynamics
             enc_actions -= 1
-        self.action_encoding=128
+        self.action_encoding=action_encoding
         self.action_encoder = nn.Sequential(
             nn.Linear(enc_actions, action_encoding),
             nn.ReLU())
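
For context on the "plug-in empirical simulator" that empirical_simulate queries through self.empi_model above, here is a minimal sketch of a count-based empirical transition model over a tabular state/action space. The helper names build_empirical_model and sample_next_state, and the count-based construction, are illustrative assumptions and do not appear in the repository.

import numpy as np

def build_empirical_model(transitions):
    # Aggregate observed (state, action, next_state) triples into an
    # empirical categorical distribution over next states per (s, a) pair.
    counts = {}
    for s, a, s_next in transitions:
        counts.setdefault((s, a), {})
        counts[(s, a)][s_next] = counts[(s, a)].get(s_next, 0) + 1
    empi_model = {}
    for (s, a), nxt in counts.items():
        states = np.array(sorted(nxt))
        probs = np.array([nxt[s2] for s2 in states], dtype=float)
        empi_model[(s, a)] = (states, probs / probs.sum())
    return empi_model

def sample_next_state(empi_model, state, action, rng=np.random):
    # Sample a next state from the empirical distribution; fall back to
    # staying in place when the (state, action) pair was never observed.
    if (state, action) in empi_model:
        states, probs = empi_model[(state, action)]
        return int(rng.choice(states, p=probs))
    return state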