Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add vpn capability properties #658

Merged
merged 6 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/ssh/ssh.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ async def get_payload():
image_hash="ea233c6774b1621207a48e10b46e3e1f944d881911f499f5cbac546a",
min_mem_gib=0.5,
min_storage_gib=2.0,
capabilities=[vm.VM_CAPS_VPN],
)

async def run(self):
Expand Down
52 changes: 39 additions & 13 deletions tests/props/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ class Foo(props_base.Model):
bar: str = props_base.prop("bar", "cafebiba")
max_baz: int = props_base.constraint("baz", "<=", 100)
min_baz: int = props_base.constraint("baz", ">=", 1)
lst: list = props_base.constraint("lst", "=", default_factory=list)


@dataclass
Expand All @@ -23,7 +24,7 @@ class FooZero(props_base.Model):

def test_constraint_fields():
fields = Foo.constraint_fields()
assert len(fields) == 2
assert len(fields) == 3
assert any(f.name == "max_baz" for f in fields)
assert all(f.name != "bar" for f in fields)

Expand All @@ -41,79 +42,104 @@ def test_constraint_to_str():
assert props_base.constraint_to_str(foo.max_baz, max_baz) == "(baz<=42)"


@pytest.mark.parametrize(
"value, constraint_str",
[
(["one"], "(lst=one)"),
(["one", "two"], "(&(lst=one)\n\t(lst=two))"),
],
)
def test_constraint_to_str_list(value, constraint_str):
foo = Foo(lst=value)
lst = [f for f in foo.constraint_fields() if f.name == "lst"][0]
assert props_base.constraint_to_str(foo.lst, lst) == constraint_str


def test_constraint_model_serialize():
foo = Foo()
constraints = props_base.constraint_model_serialize(foo)
assert constraints == ["(baz<=100)", "(baz>=1)"]
assert constraints == ["(baz<=100)", "(baz>=1)", ""]


@pytest.mark.parametrize(
"model, operator, result, error",
[
(
Foo,
Foo(),
None,
"(&(baz<=100)\n\t(baz>=1))",
False,
),
(
Foo,
Foo(),
"&",
"(&(baz<=100)\n\t(baz>=1))",
False,
),
(
Foo,
Foo(lst=["one"]),
None,
"(&(baz<=100)\n\t(baz>=1)\n\t(lst=one))",
False,
),
(
Foo(lst=["one", "other"]),
None,
"(&(baz<=100)\n\t(baz>=1)\n\t(&(lst=one)\n\t(lst=other)))",
False,
),
(
Foo(),
"|",
"(|(baz<=100)\n\t(baz>=1))",
False,
),
(
Foo,
Foo(),
"!",
None,
True,
),
(
FooToo,
FooToo(),
"!",
"(!(baz=21))",
False,
),
(
FooToo,
FooToo(),
"&",
"(baz=21)",
False,
),
(
FooToo,
FooToo(),
"|",
"(baz=21)",
False,
),
(
FooZero,
FooZero(),
None,
"(&)",
False,
),
(
FooZero,
FooZero(),
"&",
"(&)",
False,
),
(
FooZero,
FooZero(),
"|",
"(|)",
False,
),
],
)
def test_join_str_constraints(model, operator, result, error):
args = [props_base.constraint_model_serialize(model())]
args = [props_base.constraint_model_serialize(model)]
if operator:
args.append(operator)
try:
Expand Down
23 changes: 20 additions & 3 deletions yapapi/payload/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,15 @@
from dataclasses import dataclass, field
from enum import Enum
import logging
from typing import Optional
import sys
from typing import Optional, List
from typing_extensions import Final

if sys.version_info > (3, 8):
from typing import Literal
else:
from typing_extensions import Literal
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Both Final and Literal are in typing since 3.8.
I have no opinion on whether we should always import them from _extensions or base import on sys.version_info, but I'm sure we should have them together.


from srvresolver.srv_resolver import SRVResolver, SRVRecord # type: ignore

from yapapi.payload.package import (
Expand All @@ -22,6 +29,10 @@

logger = logging.getLogger(__name__)

VM_CAPS_VPN: str = "vpn"

VmCaps = Literal["vpn"]


@dataclass
class InfVm(InfBase):
Expand All @@ -47,7 +58,10 @@ class _VmConstraints:
min_mem_gib: float = prop_base.constraint(inf.INF_MEM, operator=">=")
min_storage_gib: float = prop_base.constraint(inf.INF_STORAGE, operator=">=")
min_cpu_threads: int = prop_base.constraint(inf.INF_THREADS, operator=">=")
# cores: int = prop_base.constraint(inf.INF_CORES, operator=">=")

capabilities: List[VmCaps] = prop_base.constraint(
"golem.runtime.capabilities", operator="=", default_factory=list
)

runtime: str = prop_base.constraint(inf.INF_RUNTIME_NAME, operator="=", default=RUNTIME_VM)

Expand Down Expand Up @@ -80,6 +94,7 @@ async def repo(
min_mem_gib: float = 0.5,
min_storage_gib: float = 2.0,
min_cpu_threads: int = 1,
capabilities: Optional[List[VmCaps]] = None,
johny-b marked this conversation as resolved.
Show resolved Hide resolved
) -> Package:
"""
Build a reference to application package.
Expand All @@ -89,13 +104,15 @@ async def repo(
:param min_mem_gib: minimal memory required to execute application code
:param min_storage_gib: minimal disk storage to execute tasks
:param min_cpu_threads: minimal available logical CPU cores
:param capabilities: an optional list of required vm capabilities
:return: the payload definition for the given VM image
"""
capabilities = capabilities or list()
return _VmPackage(
repo_url=resolve_repo_srv(_DEFAULT_REPO_SRV),
image_hash=image_hash,
image_url=image_url,
constraints=_VmConstraints(min_mem_gib, min_storage_gib, min_cpu_threads),
constraints=_VmConstraints(min_mem_gib, min_storage_gib, min_cpu_threads, capabilities),
)


Expand Down
14 changes: 11 additions & 3 deletions yapapi/props/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ class ModelFieldType(enum.Enum):
property = "property"


def constraint(key: str, operator: ConstraintOperator = "=", default=MISSING):
def constraint(
key: str, operator: ConstraintOperator = "=", default=MISSING, default_factory=MISSING
):
"""
Return a constraint-type dataclass field for a Model.

Expand All @@ -227,8 +229,9 @@ def constraint(key: str, operator: ConstraintOperator = "=", default=MISSING):
['(baz<=100)']
```
"""
return field(
return field( # type: ignore # the default / default_factory exception is resolved by the `field` function
default=default,
default_factory=default_factory,
metadata={
PROP_KEY: key,
PROP_OPERATOR: operator,
Expand Down Expand Up @@ -271,7 +274,10 @@ def constraint_to_str(value, f: Field) -> str:
:param value: the value of the the constraint field
:param f: the dataclass field for this constraint
"""
return f"({f.metadata[PROP_KEY]}{f.metadata[PROP_OPERATOR]}{value})"
if type(value) == list:
return join_str_constraints([constraint_to_str(v, f) for v in value]) if value else ""
else:
return f"({f.metadata[PROP_KEY]}{f.metadata[PROP_OPERATOR]}{value})"


def constraint_model_serialize(m: Model) -> List[str]:
Expand Down Expand Up @@ -316,6 +322,8 @@ def join_str_constraints(constraints: List[str], operator: ConstraintGroupOperat
(bar<=128))
```
"""
constraints = [c for c in constraints if c]

if operator == "!":
if len(constraints) == 1:
return f"({operator}{constraints[0]})"
Expand Down