Update sepsis.py #1

Open
wants to merge 3 commits into base: main
41 changes: 4 additions & 37 deletions POMDPy/examples/sepsis/sepsis.py
@@ -2,8 +2,8 @@
import random
import numpy as np
import sys
-sys.path.append('/next/u/hjnam/locf/env/sepsisSimDiabetes')
-sys.path.append('/next/u/hjnam/POMDPy')
+sys.path.append('../../locf/env/sepsisSimDiabetes')
+sys.path.append('../..//POMDPy')
import pickle
from pomdpy.discrete_pomdp import DiscreteActionPool, DiscreteObservationPool
from pomdpy.discrete_pomdp import DiscreteAction
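Editor's note: the commits swap machine-specific absolute paths for relative ones. As a hedged alternative (assuming the relative paths are meant to be taken from the script's own directory), the same entries could be anchored to the file location so imports survive being launched from elsewhere; `_THIS_DIR` is an illustrative name, not part of the PR.

```python
import os
import sys

# Minimal sketch: resolve sibling packages relative to this file rather than
# the shell's working directory (directory layout assumed from the PR's paths).
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(_THIS_DIR, '..', '..', 'locf', 'env', 'sepsisSimDiabetes'))
sys.path.append(os.path.join(_THIS_DIR, '..', '..', 'POMDPy'))
```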
@@ -190,15 +190,11 @@ def generate_particles(self, prev_belief, action, obs, n_particles, prev_particl

'''
Can use a plug-in empirical simulator
-Inp: action (dtype int), state (dtype int)
+Input: action (dtype int), state (dtype int)
1) # of samples
2) exact noise level
'''
def empirical_simulate(self, state, action):
-# return self.sim.step(action, state)
-# rew = 0
-# if action > 7:
-#     rew += self.cost
action = action % 8
if (state, action) in self.empi_model.keys():
p = self.empi_model[(state, action)]
@@ -214,18 +210,6 @@ def empirical_simulate(self, state, action):
rew = temp.env.calculateReward()
return BoxState(int(state), is_terminal=bool(rew != 0), r=rew), True
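Editor's note: the docstring above mentions a plug-in empirical simulator. Below is a hedged sketch of how such a model could be fit from logged transitions, producing the `(state, action) -> probability vector` lookup that `empirical_simulate` appears to index into; `build_empirical_model` and its arguments are illustrative names, not part of the repo.

```python
from collections import Counter, defaultdict
import numpy as np

def build_empirical_model(transitions, n_states):
    """Estimate P(s' | s, a) from logged (s, a, s') triples (illustrative sketch)."""
    counts = defaultdict(Counter)
    for s, a, s_next in transitions:
        counts[(s, a)][s_next] += 1
    model = {}
    for (s, a), ctr in counts.items():
        p = np.zeros(n_states)
        for s_next, c in ctr.items():
            p[s_next] = c
        model[(s, a)] = p / p.sum()  # empirical next-state distribution
    return model

# Sampling a next state from the fitted model, as empirical_simulate seems to do:
# next_state = np.random.choice(n_states, p=model[(state, action)])
```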

-# def make_next_state(self, state, action):
-# if state.terminal:
-# return state.copy(), False
-# if type(action) is not int:
-# action = action.bin_number
-# if type(state) is not int:
-# state = state.position
-# # this should be an imagined step in the learned simulator
-# _, rew, done, info = self.empirical_simulate(action, state)
-# next_pos = info['true_state']
-# return BoxState(int(next_pos), is_terminal=done, r=rew), True

'''
In the real env, observation = state \cup NA
but always return TRUE
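Editor's note: the docstring above ("observation = state ∪ NA") together with `action = action % 8` suggests the ACNO-style convention of pairing each of the 8 base sepsis actions with an observe/skip choice; the removed `if action > 7: rew += self.cost` comment suggests indices 8-15 are the costly observing variants. A hedged sketch of that convention follows; `NA_OBS`, `split_action`, and `make_observation_sketch` are illustrative, not taken from the code.

```python
NA_OBS = -1  # placeholder "not available" observation

def split_action(action, n_base_actions=8):
    """Map a combined action index onto (base action, observe flag)."""
    return action % n_base_actions, action >= n_base_actions

def make_observation_sketch(action, next_state, n_base_actions=8):
    """Return the true next state if the agent paid to observe, otherwise NA."""
    base_action, observe = split_action(action, n_base_actions)
    return next_state if observe else NA_OBS
```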
@@ -251,10 +235,6 @@ def make_next_position(self, state, action):
if type(state) is not int:
state = state.position
return self.empirical_simulate(state, action)
-# # should be through the learned simulator
-# _, _, _, info = self.empirical_simulate(action, state)
-# next_position = info['true_state']
-# return int(next_position)

def get_all_observations(self):
obs = {}
@@ -292,7 +272,7 @@ def generate_step(self, state, action, _true=False, is_mdp=False):
else:
result.next_state, is_legal = self.make_next_position(state, action)

-### for true runs #####
+### for true/eval runs #####
# result.next_state, is_legal = self.take_real_state(state, action)
########################

@@ -302,19 +282,6 @@ def generate_step(self, state, action, _true=False, is_mdp=False):
result.is_terminal = result.next_state.terminal
return result, is_legal


-'''
-def mdp_generate_step(self, state, action):
-if type(action) is int:
-action = BoxAction(action)
-result = StepResult()
-result.next_state, is_legal = self.make_next_position(state, action)
-result.action = action.copy()
-result.observation = self.make_observation(action, result.next_state, always_obs=True)
-result.reward = self.make_reward(state, action, result.next_state, is_legal, always_obs=True)
-result.is_terminal = result.next_state.terminal
-return result, is_legal
-'''

def reset_for_simulation(self):
pass
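Editor's note: for orientation, here is a hedged sketch of how a POMCP-style random rollout might consume `generate_step` during simulation; the `rollout` helper and its `num_actions` argument are illustrative, not part of POMDPy.

```python
import random

def rollout(model, state, num_actions, depth=20, discount=0.95):
    """Monte-Carlo rollout under a uniform-random policy (illustrative sketch).

    Assumes model.generate_step(state, action) returns a (StepResult, is_legal)
    pair, as in the diff above.
    """
    total_return, factor = 0.0, 1.0
    for _ in range(depth):
        action = random.randrange(num_actions)
        result, _is_legal = model.generate_step(state, action)
        total_return += factor * result.reward
        factor *= discount
        if result.is_terminal:
            break
        state = result.next_state
    return total_return
```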
13 changes: 11 additions & 2 deletions README.md
@@ -1,4 +1,15 @@
# acno_mdp

+Credits:
+- DVRL encoder & a2c are slightly adapted from https://github.com/maximilianigl/DVRL (Igl et al.)
+- Monte-Carlo Tree Search and POMDP tree search code is from https://github.com/pemami4911/POMDPy:
+  @ARTICLE{emami2015pomdpy,
+    author = {Emami, Patrick and Hamlet, Alan J. and Crane, Carl},
+    title = {POMDPy: An Extensible Framework for Implementing POMDPs in Python},
+    year = {2015},
+  }
+- Sepsis simulator code is from https://github.com/clinicalml/gumbel-max-scm.

Conda packages and versions used for generating the reported results are shared in conda.yml (Note not all the packages are needed).

To run the known observation belief encoder
@@ -21,5 +32,3 @@ python pomcp.py --init_idx 256 --cost -0.1 --is_mdp 0

2. MCTS
python pomcp.py --init_idx 256 --cost -0.05 --is_mdp 1


2 changes: 1 addition & 1 deletion known_obs/code/known_pf_model_alg2.py
@@ -89,7 +89,7 @@ def __init__(self,
if self.halved_acts: # know only 2 actions (dx, dy) affect the agent dynamics
enc_actions -= 1

-self.action_encoding=128
+self.action_encoding=action_encoding
self.action_encoder = nn.Sequential(
nn.Linear(enc_actions, action_encoding),
nn.ReLU()