Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enable parametrization of Service instances within a single Golem.run_service() call #514

Merged
merged 12 commits into from
Jul 5, 2021
25 changes: 18 additions & 7 deletions examples/simple-service-poc/simple_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,17 @@
TEXT_COLOR_YELLOW,
)

NUM_INSTANCES = 1
STARTING_TIMEOUT = timedelta(minutes=4)


class SimpleService(Service):
SIMPLE_SERVICE = "/golem/run/simple_service.py"
SIMPLE_SERVICE_CTL = "/golem/run/simulate_observations_ctl.py"

def __init__(self, *args, instance_name, **kwargs):
    """Initialize the service and remember the name given to this instance.

    :param instance_name: keyword-only label, stored as `self.name`

    All remaining positional and keyword arguments are forwarded unchanged to
    the parent class initializer.
    """
    super().__init__(*args, **kwargs)
    self.name = instance_name

@staticmethod
async def get_payload():
return await vm.repo(
Expand Down Expand Up @@ -89,7 +92,7 @@ async def shutdown(self):
print(f" --- {self._ctx.provider_name} COST: {await self._ctx.get_cost()}")


async def main(subnet_tag, driver=None, network=None):
async def main(subnet_tag, driver=None, network=None, num_instances=1):
async with Golem(
budget=1.0,
subnet_tag=subnet_tag,
Expand All @@ -107,27 +110,29 @@ async def main(subnet_tag, driver=None, network=None):
commissioning_time = datetime.now()

print(
f"{TEXT_COLOR_YELLOW}starting {pluralize(NUM_INSTANCES, 'instance')}{TEXT_COLOR_DEFAULT}"
f"{TEXT_COLOR_YELLOW}starting {pluralize(num_instances, 'instance')}{TEXT_COLOR_DEFAULT}"
)

# start the service

cluster = await golem.run_service(
SimpleService,
num_instances=NUM_INSTANCES,
instance_params=[
{"instance_name": f"simple-service-{i+1}"} for i in range(num_instances)
],
expiration=datetime.now(timezone.utc) + timedelta(minutes=120),
)

# helper functions to display / filter instances

def instances():
    """Return one human-readable status line per instance in the cluster.

    Each entry has the form ``"<name>: <state> on <provider>"``.
    """
    return [f"{s.name}: {s.state.value} on {s.provider_name}" for s in cluster.instances]

def still_running():
    """Tell whether at least one instance in the cluster is still available."""
    # a generator lets `any` stop at the first available instance instead of
    # materializing a throwaway list first
    return any(s for s in cluster.instances if s.is_available)

def still_starting():
    """Tell whether the cluster is still in its start-up phase.

    True while fewer than `num_instances` instances exist in the cluster, or
    while any existing instance is in the `starting` state.
    """
    # not all requested instances have been created yet
    if len(cluster.instances) < num_instances:
        return True
    return any(s for s in cluster.instances if s.state == ServiceState.starting)

Expand Down Expand Up @@ -168,6 +173,7 @@ def still_starting():
parser = build_parser(
"A very simple / POC example of a service running on Golem, utilizing the VM runtime"
)
parser.add_argument("--num-instances", type=int, default=1)
now = datetime.now().strftime("%Y-%m-%d_%H.%M.%S")
parser.set_defaults(log_file=f"simple-service-yapapi-{now}.log")
args = parser.parse_args()
Expand All @@ -184,7 +190,12 @@ def still_starting():

loop = asyncio.get_event_loop()
task = loop.create_task(
main(subnet_tag=args.subnet_tag, driver=args.driver, network=args.network)
main(
subnet_tag=args.subnet_tag,
driver=args.driver,
network=args.network,
num_instances=args.num_instances,
)
)

try:
Expand Down
79 changes: 79 additions & 0 deletions tests/services/test_cluster.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
import itertools
import sys
import pytest
from unittest.mock import Mock, patch, call
from yapapi.services import Cluster, Service, ServiceError


class _TestService(Service):
    """Bare `Service` subclass used as the service class of the test cluster."""


def _get_cluster():
    """Build a `Cluster` of `_TestService` whose engine and payload are mocks."""
    engine, payload = Mock(), Mock()
    return Cluster(engine=engine, service_class=_TestService, payload=payload)


@pytest.mark.parametrize(
    "kwargs, calls, error",
    [
        # only a count: every instance receives empty params
        ({"num_instances": 1}, [call({})], None),
        ({"num_instances": 3}, [call({})] * 3, None),
        # only params: one instance per element
        ({"instance_params": [{}]}, [call({})], None),
        (
            {"instance_params": [{"n": 1}, {"n": 2}]},
            [call({"n": 1}), call({"n": 2})],
            None,
        ),
        # both given: num_instances takes precedence, surplus params stay unused
        (
            {"num_instances": 2, "instance_params": [{} for _ in range(3)]},
            [call({})] * 2,
            None,
        ),
        # both given, endless params: num_instances bounds the iteration
        (
            {"num_instances": 3, "instance_params": ({"n": i} for i in itertools.count(1))},
            [call({"n": n}) for n in (1, 2, 3)],
            None,
        ),
        # both given, too few params: the available ones are spawned, then an error
        (
            {"num_instances": 4, "instance_params": [{} for _ in range(3)]},
            [call({})] * 3,
            "`instance_params` iterable depleted after 3 spawned instances.",
        ),
        # zero instances requested: nothing is spawned
        ({"num_instances": 0}, [], None),
    ],
)
@pytest.mark.asyncio
@pytest.mark.skipif(sys.version_info < (3, 8), reason="AsyncMock requires python 3.8+")
async def test_spawn_instances(kwargs, calls, error):
    """Check how `Cluster.spawn_instances` fans out to `spawn_instance`.

    :param kwargs: keyword arguments passed to `spawn_instances`
    :param calls: the exact calls expected on the patched `spawn_instance`
    :param error: expected `ServiceError` message, or None when no error is expected
    """
    with patch("yapapi.services.Cluster.spawn_instance") as spawn_instance:
        cluster = _get_cluster()

        # record the outcome first, judge it afterwards
        raised = None
        try:
            cluster.spawn_instances(**kwargs)
        except ServiceError as exc:
            raised = exc

        if raised is None:
            assert error is None, f"Expected ServiceError: {error}"
        elif error is None:
            assert False, raised
        else:
            assert str(raised) == error

    assert spawn_instance.mock_calls == calls
21 changes: 16 additions & 5 deletions yapapi/golem.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,15 +123,27 @@ async def worker(context: WorkContext, tasks: AsyncIterable[Task]):
async def run_service(
self,
service_class: Type[Service],
num_instances: int = 1,
num_instances: Optional[int] = None,
instance_params: Optional[Iterable[Dict]] = None,
payload: Optional[Payload] = None,
expiration: Optional[datetime] = None,
) -> Cluster:
"""Run a number of instances of a service represented by a given `Service` subclass.

:param service_class: a subclass of `Service` that represents the service to be run
:param num_instances: optional number of service instances to run, defaults to a single
instance
:param num_instances: optional number of service instances to run. Defaults to a single
instance, unless `instance_params` is given, in which case, the Cluster will be created
with as many instances as there are elements in the `instance_params` iterable.
If `num_instances` is set to < 1, the `Cluster` will still be created but no instances
will be spawned within it.
:param instance_params: optional list of dictionaries of keyword arguments that will be passed
to consecutive, spawned instances. The number of elements in the iterable determines the
number of instances spawned, unless `num_instances` is given, in which case the latter takes
precedence.
In other words, if both `num_instances` and `instance_params` are provided,
the Cluster will be created with the number of instances determined by `num_instances`
and if there are too few elements in the `instance_params` iterable, it will result in
an error.
:param payload: optional runtime definition for the service; if not provided, the
payload specified by the `get_payload()` method of `service_class` is used
:param expiration: optional expiration datetime for the service
Expand Down Expand Up @@ -200,9 +212,8 @@ async def main():
engine=self,
service_class=service_class,
payload=payload,
num_instances=num_instances,
expiration=expiration,
)
await self._stack.enter_async_context(cluster)
cluster.spawn_instances()
cluster.spawn_instances(num_instances=num_instances, instance_params=instance_params)
return cluster
74 changes: 54 additions & 20 deletions yapapi/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import statemachine # type: ignore
import sys
from types import TracebackType
from typing import Any, AsyncContextManager, List, Optional, Set, Tuple, Type, Union
from typing import Any, AsyncContextManager, List, Optional, Set, Tuple, Type, Union, Iterable, Dict

if sys.version_info >= (3, 7):
from contextlib import AsyncExitStack
Expand Down Expand Up @@ -38,6 +38,10 @@
cluster_ids = itertools.count(1)


class ServiceError(Exception):
    """Raised for service-level failures.

    For example, `Cluster.spawn_instances` raises it when `instance_params` is
    exhausted before the requested number of instances has been spawned.
    """


class ServiceState(statemachine.StateMachine):
"""State machine describing the state and lifecycle of a Service instance."""

Expand Down Expand Up @@ -238,15 +242,13 @@ def __init__(
engine: "_Engine",
service_class: Type[Service],
payload: Payload,
num_instances: int = 1,
expiration: Optional[datetime] = None,
):
"""Initialize this Cluster.

:param engine: an engine for running service instance
:param service_class: service specification
:param payload: definition of service runtime for this Cluster
:param num_instances: number of instances to spawn in this Cluster
:param expiration: a date before which all agreements related to running services
in this Cluster should be terminated
"""
Expand All @@ -256,7 +258,6 @@ def __init__(
self._engine = engine
self._service_class = service_class
self._payload = payload
self._num_instances = num_instances
self._expiration = expiration or datetime.now(timezone.utc) + DEFAULT_SERVICE_EXPIRATION
self._task_ids = itertools.count(1)
self._stack = AsyncExitStack()
Expand All @@ -269,7 +270,7 @@ def __init__(

def __repr__(self):
    """Describe the cluster: its id, current instance count, service class and payload."""
    # the count is taken from the live instance list rather than from a stored
    # number, so it reflects the instances the cluster currently holds
    return (
        f"Cluster {self.id}: {len(self.__instances)}x[Service: {self._service_class.__name__}, "
        f"Payload: {self._payload}]"
    )

Expand Down Expand Up @@ -398,10 +399,9 @@ def _change_state(
instance.service._exc_info = event
return instance.state != prev_state

async def _run_instance(self, ctx: WorkContext):

async def _run_instance(self, ctx: WorkContext, params: Dict):
loop = asyncio.get_event_loop()
instance = ServiceInstance(service=self._service_class(self, ctx))
instance = ServiceInstance(service=self._service_class(self, ctx, **params)) # type: ignore
self.__instances.append(instance)

logger.info("%s commissioned", instance.service)
Expand Down Expand Up @@ -494,7 +494,7 @@ def change_state(event: Union[ControlSignal, ExcInfo] = (None, None, None)) -> N

logger.info("%s decomissioned", instance.service)

async def spawn_instance(self) -> None:
async def spawn_instance(self, params: Dict) -> None:
"""Spawn a new service instance within this Cluster."""

logger.debug("spawning instance within %s", self)
Expand All @@ -518,7 +518,7 @@ async def _worker(
)

try:
instance_batches = self._run_instance(work_context)
instance_batches = self._run_instance(work_context, params)
try:
await self._engine.process_batches(agreement.id, activity, instance_batches)
except StopAsyncIteration:
Expand Down Expand Up @@ -557,21 +557,55 @@ def stop_instance(self, service: Service):
instance = self.__get_service_instance(service)
instance.control_queue.put_nowait(ControlSignal.stop)

def spawn_instances(
    self,
    num_instances: Optional[int] = None,
    instance_params: Optional[Iterable[Dict]] = None,
) -> None:
    """Spawn new instances within this Cluster.

    :param num_instances: optional number of service instances to run. Defaults to a single
        instance, unless `instance_params` is given, in which case the Cluster will spawn
        as many instances as there are elements in the `instance_params` iterable.
        If `num_instances` is not None and < 1, the method logs a warning and returns
        immediately without spawning anything.
    :param instance_params: optional iterable of dictionaries of keyword arguments that will
        be passed to consecutive, spawned instances. The number of elements in the iterable
        determines the number of instances spawned, unless `num_instances` is given, in
        which case the latter takes precedence. An empty iterable counts as "given": it
        yields no instances on its own.
    :raises ServiceError: if both arguments are provided and `instance_params` runs out
        before `num_instances` instances have been spawned. Instances spawned up to that
        point remain scheduled.
    """
    # just a sanity check
    if num_instances is not None and num_instances < 1:
        logger.warning(
            "Trying to spawn less than one instance. num_instances: %s", num_instances
        )
        return

    if instance_params is None:
        # no parameters given: default to a single instance, each one
        # receiving an empty dictionary of keyword arguments
        if num_instances is None:
            num_instances = 1
        params_iter = (dict() for _ in range(num_instances))
    else:
        # compare against None rather than relying on truthiness, so that an
        # empty list is treated the same way as an empty iterator
        params_iter = iter(instance_params)

    # when a count is given it takes precedence: never consume more than that
    if num_instances is not None:
        params_iter = itertools.islice(params_iter, num_instances)

    loop = asyncio.get_event_loop()
    spawned_instances = 0
    for params in params_iter:
        task = loop.create_task(self.spawn_instance(params))
        self._instance_tasks.add(task)
        spawned_instances += 1

    if num_instances is not None and spawned_instances < num_instances:
        raise ServiceError(
            f"`instance_params` iterable depleted after {spawned_instances} spawned instances."
        )

def stop(self):
"""Signal the whole cluster to stop."""
Expand Down