Skip to content

Commit

Permalink
Implementation with minimal changes
Browse files Browse the repository at this point in the history
  • Loading branch information
khsrali committed Oct 2, 2024
1 parent 07c30db commit 1a92d2c
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 56 deletions.
11 changes: 2 additions & 9 deletions src/plumpy/process_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@

INTENT_KEY = 'intent'
MESSAGE_KEY = 'message'
FORCE_KILL_KEY = 'force_kill'


class Intent:
Expand Down Expand Up @@ -197,19 +196,17 @@ async def play_process(self, pid: 'PID_TYPE') -> 'ProcessResult':
result = await asyncio.wrap_future(future)
return result

async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_kill: bool = False) -> 'ProcessResult':
async def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> 'ProcessResult':
"""
Kill the process
:param pid: the pid of the process to kill
:param msg: optional kill message
:return: True if killed, False otherwise
"""
breakpoint()
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
message[FORCE_KILL_KEY] = force_kill

# Wait for the communication to go through
kill_future = self._communicator.rpc_send(pid, message)
Expand Down Expand Up @@ -378,7 +375,7 @@ def play_all(self) -> None:
"""
self._communicator.broadcast_send(None, subject=Intent.PLAY)

def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_kill: bool = False) -> kiwipy.Future:
def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None) -> kiwipy.Future:
"""
Kill the process
Expand All @@ -387,11 +384,9 @@ def kill_process(self, pid: 'PID_TYPE', msg: Optional[Any] = None, force_kill: b
:return: a response future from the process to be killed
"""
breakpoint()
message = copy.copy(KILL_MSG)
if msg is not None:
message[MESSAGE_KEY] = msg
message[FORCE_KILL_KEY] = force_kill

return self._communicator.rpc_send(pid, message)

Expand All @@ -410,7 +405,6 @@ def continue_process(
nowait: bool = False,
no_reply: bool = False
) -> Union[None, PID_TYPE, ProcessResult]:
breakpoint()
message = create_continue_body(pid=pid, tag=tag, nowait=nowait)
return self.task_send(message, no_reply=no_reply)

Expand Down Expand Up @@ -485,7 +479,6 @@ def task_send(self, message: Any, no_reply: bool = False) -> Optional[Any]:
:param no_reply: if True, this call will be fire-and-forget, i.e. no return value
:return: the response from the remote side (if no_reply=False)
"""
breakpoint()
return self._communicator.task_send(message, no_reply=no_reply)


Expand Down
3 changes: 3 additions & 0 deletions src/plumpy/process_states.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
'Continue',
'Interruption',
'KillInterruption',
'ForceKillInterruption',
'PauseInterruption',
]

Expand All @@ -50,6 +51,8 @@ class Interruption(Exception):
class KillInterruption(Interruption):
pass

class ForceKillInterruption(Interruption):
pass

class PauseInterruption(Interruption):
pass
Expand Down
64 changes: 17 additions & 47 deletions src/plumpy/processes.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,29 +51,7 @@

__all__ = ['Process', 'ProcessSpec', 'BundleKeys', 'TransitionFailed']


#file_handler = logging.FileHandler(filename='tmp.log')
#stdout_handler = logging.StreamHandler(stream=sys.stdout)
#handlers = [file_handler, stdout_handler]
#
#logging.basicConfig(
# level=logging.DEBUG,
# format='[%(asctime)s] {%(filename)s:%(lineno)d} %(levelname)s - %(message)s',
# handlers=handlers
#)

#file_handler = logging.FileHandler(filename="/Users/alexgo/code/aiida-core/plumpy2.log")
#stdout_handler = logging.StreamHandler(stream=sys.stdout)
#handlers = [file_handler, stdout_handler]
#
#logging.basicConfig(
# level=logging.DEBUG,
# format='[%(asctime)s] {%(pathname)s:%(lineno)d} %(levelname)s - %(message)s',
# handlers=handlers
#)
_LOGGER = logging.getLogger(__name__)


PROCESS_STACK = ContextVar('process stack', default=[])


Expand Down Expand Up @@ -411,8 +389,8 @@ def logger(self) -> logging.Logger:
:return: The logger.
"""
#if self._logger is not None:
# return self._logger
if self._logger is not None:
return self._logger

return _LOGGER

Expand Down Expand Up @@ -930,7 +908,6 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
:param msg: the message
:return: the outcome of processing the message, the return value will be sent back as a response to the sender
"""
breakpoint()
self.logger.debug("Process<%s>: received RPC message with communicator '%s': %r", self.pid, _comm, msg)

intent = msg[process_comms.INTENT_KEY]
Expand All @@ -940,11 +917,7 @@ def message_receive(self, _comm: kiwipy.Communicator, msg: Dict[str, Any]) -> An
if intent == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.KILL:
breakpoint()
# have problems to pass new argument get
# Error: failed to kill Process<699>: Process.kill() got an unexpected keyword argument 'force_kill'
#return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None), force_kill=msg.get(process_comms.FORCE_KILL_KEY, False))
return self._schedule_rpc(self.kill, msg=msg)
return self._schedule_rpc(self.kill, msg=msg.get(process_comms.MESSAGE_KEY, None))
if intent == process_comms.Intent.STATUS:
status_info: Dict[str, Any] = {}
self.get_status_info(status_info)
Expand All @@ -961,7 +934,6 @@ def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any,
:param _comm: the communicator that sent the message
:param msg: the message
"""
breakpoint()
# pylint: disable=unused-argument
self.logger.debug(
"Process<%s>: received broadcast message '%s' with communicator '%s': %r", self.pid, subject, _comm, body
Expand All @@ -973,7 +945,6 @@ def broadcast_receive(self, _comm: kiwipy.Communicator, body: Any, sender: Any,
if subject == process_comms.Intent.PAUSE:
return self._schedule_rpc(self.pause, msg=body)
if subject == process_comms.Intent.KILL:
# TODO deal with this
return self._schedule_rpc(self.kill, msg=body)
return None

Expand Down Expand Up @@ -1096,7 +1067,7 @@ def _create_interrupt_action(self, exception: process_states.Interruption) -> fu
do_pause = functools.partial(self._do_pause, str(exception))
return futures.CancellableAction(do_pause, cookie=exception)

if isinstance(exception, process_states.KillInterruption):
if isinstance(exception, process_states.KillInterruption) or isinstance(exception, process_states.ForceKillInterruption):

def do_kill(_next_state: process_states.State) -> Any:
try:
Expand Down Expand Up @@ -1155,21 +1126,15 @@ def fail(self, exception: Optional[BaseException], trace_back: Optional[Tracebac
"""
self.transition_to(process_states.ProcessState.EXCEPTED, exception, trace_back)

def kill(self, msg: Union[dict, None] = None, force_kill: bool = False) -> Union[bool, asyncio.Future]:
def kill(self, msg: Union[str, None] = None) -> Union[bool, asyncio.Future]:
"""
Kill the process
# PR_COMMENT have not figured out how to integrate force_kill as argument
# so I just pass the dict
:param msg: An optional kill message
"""
breakpoint()
if msg is None:
force_kill = False
if isinstance(msg, str) and '-F' in msg:
force_kill = True
else:
force_kill = msg.get(process_comms.FORCE_KILL_KEY, False)

force_kill = False
if self.state == process_states.ProcessState.KILLED:
# Already killed
return True
Expand All @@ -1178,20 +1143,25 @@ def kill(self, msg: Union[dict, None] = None, force_kill: bool = False) -> Union
# Can't kill
return False

if self._killing:
if self._killing and not force_kill:
# Already killing
return self._killing

if self._stepping and not force_kill:
if force_kill:
# Skip interrupting the state and go straight to killed
interrupt_exception = process_states.ForceKillInterruption(msg)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)

elif self._stepping:
# Ask the step function to pause by setting this flag and giving the
# caller back a future
interrupt_exception = process_states.KillInterruption(msg.get(process_comms.MESSAGE_KEY, None))
interrupt_exception = process_states.KillInterruption(msg) # type: ignore
self._set_interrupt_action_from_exception(interrupt_exception)
self._killing = self._interrupt_action
self._state.interrupt(interrupt_exception)
return cast(futures.CancellableAction, self._interrupt_action)

breakpoint()
self.transition_to(process_states.ProcessState.KILLED, msg)
return True

Expand Down

0 comments on commit 1a92d2c

Please sign in to comment.