diff --git a/guppylang/checker/core.py b/guppylang/checker/core.py index c2ec79a0..d919a14b 100644 --- a/guppylang/checker/core.py +++ b/guppylang/checker/core.py @@ -30,6 +30,7 @@ nat_type_def, none_type_def, sized_iter_type_def, + string_type_def, tuple_type_def, ) from guppylang.tys.param import Parameter @@ -237,6 +238,7 @@ def default() -> "Globals": nat_type_def, int_type_def, float_type_def, + string_type_def, list_type_def, array_type_def, sized_iter_type_def, diff --git a/guppylang/checker/expr_checker.py b/guppylang/checker/expr_checker.py index d6753e30..3f1dcc7f 100644 --- a/guppylang/checker/expr_checker.py +++ b/guppylang/checker/expr_checker.py @@ -117,6 +117,7 @@ is_sized_iter_type, list_type, nat_type, + string_type, ) from guppylang.tys.param import ConstParam, TypeParam from guppylang.tys.subst import Inst, Subst @@ -1179,6 +1180,8 @@ def python_value_to_guppy_type( match v: case bool(): return bool_type() + case str(): + return string_type() # Only resolve `int` to `nat` if the user specifically asked for it case int(n) if type_hint == nat_type() and n >= 0: return nat_type() diff --git a/guppylang/compiler/expr_compiler.py b/guppylang/compiler/expr_compiler.py index e724cdec..a28cec4c 100644 --- a/guppylang/compiler/expr_compiler.py +++ b/guppylang/compiler/expr_compiler.py @@ -9,6 +9,7 @@ import hugr.std.float import hugr.std.int import hugr.std.logic +import hugr.std.prelude from hugr import Hugr, Wire, ops from hugr import tys as ht from hugr import val as hv @@ -596,6 +597,8 @@ def python_value_to_hugr(v: Any, exp_ty: Type) -> hv.Value | None: match v: case bool(): return hv.bool_value(v) + case str(): + return hugr.std.prelude.StringVal(v) case int(): return hugr.std.int.IntVal(v, width=NumericType.INT_WIDTH) case float(): diff --git a/guppylang/std/builtins.py b/guppylang/std/builtins.py index 85a9be7d..fbc09b06 100644 --- a/guppylang/std/builtins.py +++ b/guppylang/std/builtins.py @@ -46,6 +46,7 @@ list_type_def, nat_type_def, sized_iter_type_def, + string_type_def, ) guppy.init_module(import_builtins=False) @@ -121,6 +122,12 @@ def __or__(self: bool, other: bool) -> bool: ... def __xor__(self: bool, other: bool) -> bool: ... +@guppy.extend_type(string_type_def) +class String: + @guppy.custom(checker=UnsupportedChecker(), higher_order_value=False) + def __new__(x): ... + + @guppy.extend_type(nat_type_def) class Nat: @guppy.custom(NoopCompiler()) @@ -890,10 +897,6 @@ def sorted(x): ... def staticmethod(x): ... -@guppy.custom(checker=UnsupportedChecker(), higher_order_value=False) -def str(x): ... - - @guppy.custom(checker=UnsupportedChecker(), higher_order_value=False) def sum(x): ... diff --git a/guppylang/tys/builtin.py b/guppylang/tys/builtin.py index 740869dc..370aaff7 100644 --- a/guppylang/tys/builtin.py +++ b/guppylang/tys/builtin.py @@ -168,6 +168,14 @@ def _sized_iter_to_hugr(args: Sequence[Argument]) -> ht.Type: float_type_def = _NumericTypeDef( DefId.fresh(), "float", None, NumericType(NumericType.Kind.Float) ) +string_type_def = OpaqueTypeDef( + id=DefId.fresh(), + name="str", + defined_at=None, + params=[], + always_linear=False, + to_hugr=lambda _: hugr.std.PRELUDE.get_type("string").instantiate([]), +) list_type_def = _ListTypeDef( id=DefId.fresh(), name="list", @@ -216,6 +224,10 @@ def float_type() -> NumericType: return NumericType(NumericType.Kind.Float) +def string_type() -> OpaqueType: + return OpaqueType([], string_type_def) + + def list_type(element_ty: Type) -> OpaqueType: return OpaqueType([TypeArg(element_ty)], list_type_def) @@ -236,6 +248,10 @@ def is_bool_type(ty: Type) -> bool: return isinstance(ty, OpaqueType) and ty.defn == bool_type_def +def is_string_type(ty: Type) -> bool: + return isinstance(ty, OpaqueType) and ty.defn == string_type_def + + def is_list_type(ty: Type) -> bool: return isinstance(ty, OpaqueType) and ty.defn == list_type_def diff --git a/tests/error/misc_errors/unsupported_const.err b/tests/error/misc_errors/unsupported_const.err index baf3536d..20810ec4 100644 --- a/tests/error/misc_errors/unsupported_const.err +++ b/tests/error/misc_errors/unsupported_const.err @@ -1,8 +1,8 @@ -Error: Unsupported constant (at $FILE:7:8) +Error: Unsupported constant (at $FILE:6:9) | -5 | @compile_guppy -6 | def foo() -> None: -7 | x = "foo" - | ^^^^^ Type `str` is not supported +4 | @compile_guppy +5 | def foo() -> None: +6 | x = -2j + | ^^ Type `complex` is not supported Guppy compilation failed due to 1 previous error diff --git a/tests/error/misc_errors/unsupported_const.py b/tests/error/misc_errors/unsupported_const.py index 1959d05b..928f4948 100644 --- a/tests/error/misc_errors/unsupported_const.py +++ b/tests/error/misc_errors/unsupported_const.py @@ -1,7 +1,6 @@ -from guppylang.std.builtins import array from tests.util import compile_guppy @compile_guppy def foo() -> None: - x = "foo" + x = -2j diff --git a/tests/integration/test_py.py b/tests/integration/test_py.py index ac9dce41..95a4bd0f 100644 --- a/tests/integration/test_py.py +++ b/tests/integration/test_py.py @@ -117,6 +117,14 @@ def foo() -> None: validate(foo) +def test_strings(validate): + @compile_guppy + def foo() -> None: + x: str = py("a" + "b") + + validate(foo) + + @pytest.mark.skipif(not tket2_installed, reason="Tket2 is not installed") @pytest.mark.skip("Fails because of extensions in types #343") def test_pytket_single_qubit(validate): diff --git a/tests/integration/test_strings.py b/tests/integration/test_strings.py new file mode 100644 index 00000000..b584bc52 --- /dev/null +++ b/tests/integration/test_strings.py @@ -0,0 +1,35 @@ +from guppylang.decorator import guppy +from guppylang.module import GuppyModule +from tests.util import compile_guppy + +import pytest + +def test_basic_type(validate): + @compile_guppy + def foo(x: str) -> str: + return x + + validate(foo) + + +def test_basic_value(validate): + @compile_guppy + def foo() -> str: + x = "Hello World" + return x + + validate(foo) + + +def test_struct(validate): + module = GuppyModule("module") + + @guppy.struct(module) + class StringStruct: + x: str + + @guppy(module) + def main(s: StringStruct) -> None: + StringStruct("Lorem Ipsum") + + validate(module.compile()) \ No newline at end of file