diff --git a/clients/client-python/gravitino/api/expressions/__init__.py b/clients/client-python/gravitino/api/expressions/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/clients/client-python/gravitino/api/expressions/expression.py b/clients/client-python/gravitino/api/expressions/expression.py new file mode 100644 index 00000000000..41669042cd4 --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/expression.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from gravitino.api.expressions.named_reference import NamedReference + + +class Expression(ABC): + """Base class of the public logical expression API.""" + + EMPTY_EXPRESSION: list[Expression] = [] + """ + `EMPTY_EXPRESSION` is only used as an input when the default `children` method builds the result. + """ + + EMPTY_NAMED_REFERENCE: list[NamedReference] = [] + """ + `EMPTY_NAMED_REFERENCE` is only used as an input when the default `references` method builds + the result array to avoid repeatedly allocating an empty array. + """ + + @abstractmethod + def children(self) -> list[Expression]: + """Returns a list of the children of this node. Children should not change.""" + pass + + def references(self) -> list[NamedReference]: + """Returns a list of fields or columns that are referenced by this expression.""" + + ref_set: set[NamedReference] = set() + for child in self.children(): + ref_set.update(child.references()) + return list(ref_set) diff --git a/clients/client-python/gravitino/api/expressions/function_expression.py b/clients/client-python/gravitino/api/expressions/function_expression.py new file mode 100644 index 00000000000..7664cf9bf85 --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/function_expression.py @@ -0,0 +1,92 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations +from abc import abstractmethod + +from gravitino.api.expressions.expression import Expression + + +class FunctionExpression(Expression): + """ + The interface of a function expression. A function expression is an expression that takes a + function name and a list of arguments. + """ + + @staticmethod + def of(function_name: str, *arguments: Expression) -> FuncExpressionImpl: + """ + Creates a new FunctionExpression with the given function name. + If no arguments are provided, it uses an empty expression. + + :param function_name: The name of the function. + :param arguments: The arguments to the function (optional). + :return: The created FunctionExpression. + """ + arguments = list(arguments) if arguments else Expression.EMPTY_EXPRESSION + return FuncExpressionImpl(function_name, arguments) + + @abstractmethod + def function_name(self) -> str: + """Returns the function name.""" + + @abstractmethod + def arguments(self) -> list[Expression]: + """Returns the arguments passed to the function.""" + + def children(self) -> list[Expression]: + """Returns the arguments as children.""" + return self.arguments() + + +class FuncExpressionImpl(FunctionExpression): + """ + A concrete implementation of the FunctionExpression interface. + """ + + _function_name: str + _arguments: list[Expression] + + def __init__(self, function_name: str, arguments: list[Expression]): + super().__init__() + self._function_name = function_name + self._arguments = arguments + + def function_name(self) -> str: + return self._function_name + + def arguments(self) -> list[Expression]: + return self._arguments + + def __str__(self) -> str: + if not self._arguments: + return f"{self._function_name}()" + arguments_str = ", ".join(map(str, self._arguments)) + return f"{self._function_name}({arguments_str})" + + def __eq__(self, other: FuncExpressionImpl) -> bool: + if self is other: + return True + if other is None or self.__class__ is not other.__class__: + return False + return ( + self._function_name == other.function_name() + and self._arguments == other.arguments() + ) + + def __hash__(self) -> int: + return hash((self._function_name, tuple(self._arguments))) diff --git a/clients/client-python/gravitino/api/expressions/literals/__init__.py b/clients/client-python/gravitino/api/expressions/literals/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/literals/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/clients/client-python/gravitino/api/expressions/literals/literal.py b/clients/client-python/gravitino/api/expressions/literals/literal.py new file mode 100644 index 00000000000..676b9ef4ce5 --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/literals/literal.py @@ -0,0 +1,43 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from abc import abstractmethod +from typing import List, TypeVar, Generic + +from gravitino.api.expressions.expression import Expression +from gravitino.api.types.type import Type + +T = TypeVar("T") + + +class Literal(Generic[T], Expression): + """ + Represents a constant literal value in the public expression API. + """ + + @abstractmethod + def value(self) -> T: + """The literal value.""" + raise NotImplementedError("Subclasses must implement the `value` method.") + + @abstractmethod + def data_type(self) -> Type: + """The data type of the literal.""" + raise NotImplementedError("Subclasses must implement the `data_type` method.") + + def children(self) -> List[Expression]: + return Expression.EMPTY_EXPRESSION diff --git a/clients/client-python/gravitino/api/expressions/literals/literals.py b/clients/client-python/gravitino/api/expressions/literals/literals.py new file mode 100644 index 00000000000..c4d07338bcc --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/literals/literals.py @@ -0,0 +1,137 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import decimal +from typing import TypeVar +from datetime import date, time, datetime + +from gravitino.api.expressions.literals.literal import Literal +from gravitino.api.types.type import Type +from gravitino.api.types.types import Types + +T = TypeVar("T") + + +class LiteralImpl(Literal[T]): + """Creates a literal with the given type value.""" + + _value: T + _data_type: Type + + def __init__(self, value: T, data_type: Type): + self._value = value + self._data_type = data_type + + def value(self) -> T: + return self._value + + def data_type(self) -> Type: + return self._data_type + + def __eq__(self, other: object) -> bool: + if not isinstance(other, LiteralImpl): + return False + return (self._value == other._value) and (self._data_type == other._data_type) + + def __hash__(self): + return hash((self._value, self._data_type)) + + def __str__(self): + return f"LiteralImpl(value={self._value}, data_type={self._data_type})" + + +class Literals: + """The helper class to create literals to pass into Apache Gravitino.""" + + NULL = LiteralImpl(None, Types.NullType.get()) + + @staticmethod + def of(value: T, data_type: Type) -> Literal[T]: + return LiteralImpl(value, data_type) + + @staticmethod + def boolean_literal(value: bool) -> LiteralImpl[bool]: + return LiteralImpl(value, Types.BooleanType.get()) + + @staticmethod + def byte_literal(value: str) -> LiteralImpl[str]: + return LiteralImpl(value, Types.ByteType.get()) + + @staticmethod + def unsigned_byte_literal(value: str) -> LiteralImpl[str]: + return LiteralImpl(value, Types.ByteType.unsigned()) + + @staticmethod + def short_literal(value: int) -> LiteralImpl[int]: + return LiteralImpl(value, Types.ShortType.get()) + + @staticmethod + def unsigned_short_literal(value: int) -> LiteralImpl[int]: + return LiteralImpl(value, Types.ShortType.unsigned()) + + @staticmethod + def integer_literal(value: int) -> LiteralImpl[int]: + return LiteralImpl(value, Types.IntegerType.get()) + + @staticmethod + def unsigned_integer_literal(value: int) -> LiteralImpl[int]: + return LiteralImpl(value, Types.IntegerType.unsigned()) + + @staticmethod + def long_literal(value: int) -> LiteralImpl[int]: + return LiteralImpl(value, Types.LongType.get()) + + @staticmethod + def unsigned_long_literal(value: int) -> LiteralImpl[int]: + return LiteralImpl(value, Types.LongType.unsigned()) + + @staticmethod + def float_literal(value: float) -> LiteralImpl[float]: + return LiteralImpl(value, Types.FloatType.get()) + + @staticmethod + def double_literal(value: float) -> LiteralImpl[float]: + return LiteralImpl(value, Types.DoubleType.get()) + + @staticmethod + def decimal_literal(value: decimal.Decimal) -> LiteralImpl[decimal.Decimal]: + precision: int = len(value.as_tuple().digits) + scale: int = -value.as_tuple().exponent + return LiteralImpl(value, Types.DecimalType.of(max(precision, scale), scale)) + + @staticmethod + def date_literal(value: date) -> Literal[date]: + return LiteralImpl(value, Types.DateType.get()) + + @staticmethod + def time_literal(value: time) -> Literal[time]: + return Literals.of(value, Types.TimeType.get()) + + @staticmethod + def timestamp_literal(value: datetime) -> Literal[datetime]: + return Literals.of(value, Types.TimestampType.without_time_zone()) + + @staticmethod + def timestamp_literal_from_string(value: str) -> Literal[datetime]: + return Literals.timestamp_literal(datetime.fromisoformat(value)) + + @staticmethod + def string_literal(value: str) -> Literal[str]: + return LiteralImpl(value, Types.StringType.get()) + + @staticmethod + def varchar_literal(length: int, value: str) -> Literal[str]: + return LiteralImpl(value, Types.VarCharType.of(length)) diff --git a/clients/client-python/gravitino/api/expressions/named_reference.py b/clients/client-python/gravitino/api/expressions/named_reference.py new file mode 100644 index 00000000000..3b766b4ac23 --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/named_reference.py @@ -0,0 +1,86 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from gravitino.api.expressions.expression import Expression + + +class NamedReference(Expression): + """ + Represents a field or column reference in the public logical expression API. + """ + + @staticmethod + def field(field_name: list[str]) -> FieldReference: + """ + Returns a FieldReference for the given field name(s). The array of field name(s) is + used to reference nested fields. For example, if we have a struct column named "student" with a + data type of StructType{"name": StringType, "age": IntegerType}, we can reference the field + "name" by calling field("student", "name"). + + @param field_name the field name(s) + @return a FieldReference for the given field name(s) + """ + return FieldReference(field_name) + + @staticmethod + def field_from_column(column_name: str) -> FieldReference: + """Returns a FieldReference for the given column name.""" + return FieldReference([column_name]) + + def field_name(self) -> list[str]: + """ + Returns the referenced field name as a list of string parts. + Must be implemented by subclasses. + """ + raise NotImplementedError("Subclasses must implement this method.") + + def children(self) -> list[Expression]: + """Named references do not have children.""" + return Expression.EMPTY_EXPRESSION + + def references(self) -> list[NamedReference]: + """Named references reference themselves.""" + return [self] + + +class FieldReference(NamedReference): + """ + A NamedReference that references a field or column. + """ + + _field_names: list[str] + + def __init__(self, field_names: list[str]): + super().__init__() + self._field_names = field_names + + def field_name(self) -> list[str]: + return self._field_names + + def __eq__(self, other: object) -> bool: + if isinstance(other, FieldReference): + return self._field_names == other._field_names + return False + + def __hash__(self) -> int: + return hash(tuple(self._field_names)) + + def __str__(self) -> str: + """Returns the string representation of the field reference.""" + return ".".join(self._field_names) diff --git a/clients/client-python/gravitino/api/expressions/unparsed_expression.py b/clients/client-python/gravitino/api/expressions/unparsed_expression.py new file mode 100644 index 00000000000..55ca327567f --- /dev/null +++ b/clients/client-python/gravitino/api/expressions/unparsed_expression.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +from gravitino.api.expressions.expression import Expression + + +class UnparsedExpression(Expression): + """ + Represents an expression that is not parsed yet. + The parsed expression is represented by FunctionExpression, literal.py, or NamedReference. + """ + + def unparsed_expression(self) -> str: + """ + Returns the unparsed expression as a string. + """ + raise NotImplementedError("Subclasses must implement this method.") + + def children(self) -> list[Expression]: + """ + Unparsed expressions do not have children. + """ + return Expression.EMPTY_EXPRESSION + + @staticmethod + def of(unparsed_expression: str) -> UnparsedExpressionImpl: + """ + Creates a new UnparsedExpression with the given unparsed expression. + + + :param unparsed_expression: The unparsed expression as a string. + :return: The created UnparsedExpression. + """ + return UnparsedExpressionImpl(unparsed_expression) + + +class UnparsedExpressionImpl(UnparsedExpression): + """ + An implementation of the UnparsedExpression interface. + """ + + def __init__(self, unparsed_expression: str): + super().__init__() + self._unparsed_expression = unparsed_expression + + def unparsed_expression(self) -> str: + return self._unparsed_expression + + def __eq__(self, other: object) -> bool: + if isinstance(other, UnparsedExpressionImpl): + return self._unparsed_expression == other._unparsed_expression + return False + + def __hash__(self) -> int: + return hash(self._unparsed_expression) + + def __str__(self) -> str: + """ + Returns the string representation of the unparsed expression. + """ + return f"UnparsedExpressionImpl{{unparsedExpression='{self._unparsed_expression}'}}" diff --git a/clients/client-python/gravitino/api/types/__init__.py b/clients/client-python/gravitino/api/types/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/clients/client-python/gravitino/api/types/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/clients/client-python/gravitino/api/type.py b/clients/client-python/gravitino/api/types/type.py similarity index 100% rename from clients/client-python/gravitino/api/type.py rename to clients/client-python/gravitino/api/types/type.py diff --git a/clients/client-python/gravitino/api/types.py b/clients/client-python/gravitino/api/types/types.py similarity index 99% rename from clients/client-python/gravitino/api/types.py rename to clients/client-python/gravitino/api/types/types.py index b82ac2b6844..63684211a9a 100644 --- a/clients/client-python/gravitino/api/types.py +++ b/clients/client-python/gravitino/api/types/types.py @@ -123,7 +123,7 @@ def get(cls) -> "ShortType": return cls(True) @classmethod - def unsigned(cls): + def unsigned(cls) -> "ShortType": return cls(False) def name(self) -> Name: diff --git a/clients/client-python/tests/unittests/test_expressions.py b/clients/client-python/tests/unittests/test_expressions.py new file mode 100644 index 00000000000..6054c1fde67 --- /dev/null +++ b/clients/client-python/tests/unittests/test_expressions.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from typing import List +from gravitino.api.expressions.expression import Expression +from gravitino.api.expressions.named_reference import NamedReference + + +class MockExpression(Expression): + """Mock implementation of the Expression class for testing.""" + + def __init__( + self, children: List[Expression] = None, references: List[NamedReference] = None + ): + self._children = children if children else [] + self._references = references if references else [] + + def children(self) -> List[Expression]: + return self._children + + def references(self) -> List[NamedReference]: + if self._references: + return self._references + return super().references() + + +class TestExpression(unittest.TestCase): + def test_empty_expression(self): + expr = MockExpression() + self.assertEqual(expr.children(), []) + self.assertEqual(expr.references(), []) + + def test_expression_with_references(self): + ref = NamedReference.field(["student", "name"]) + child = MockExpression(references=[ref]) + expr = MockExpression(children=[child]) + self.assertEqual(expr.children(), [child]) + self.assertEqual(expr.references(), [ref]) + + def test_multiple_children(self): + ref1 = NamedReference.field(["student", "name"]) + ref2 = NamedReference.field(["student", "age"]) + child1 = MockExpression(references=[ref1]) + child2 = MockExpression(references=[ref2]) + expr = MockExpression(children=[child1, child2]) + self.assertCountEqual(expr.references(), [ref1, ref2]) diff --git a/clients/client-python/tests/unittests/test_function_expression.py b/clients/client-python/tests/unittests/test_function_expression.py new file mode 100644 index 00000000000..deaa2089e23 --- /dev/null +++ b/clients/client-python/tests/unittests/test_function_expression.py @@ -0,0 +1,62 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from gravitino.api.expressions.function_expression import ( + FunctionExpression, + FuncExpressionImpl, +) +from gravitino.api.expressions.expression import Expression + + +class MockExpression(Expression): + """Mock implementation of the Expression class for testing.""" + + def children(self): + return [] + + def references(self): + return [] + + def __str__(self): + return "MockExpression()" + + +class TestFunctionExpression(unittest.TestCase): + def test_function_without_arguments(self): + func = FuncExpressionImpl("SUM", []) + self.assertEqual(func.function_name(), "SUM") + self.assertEqual(func.arguments(), []) + self.assertEqual(str(func), "SUM()") + + def test_function_with_arguments(self): + arg1 = MockExpression() + arg2 = MockExpression() + func = FuncExpressionImpl("SUM", [arg1, arg2]) + self.assertEqual(func.function_name(), "SUM") + self.assertEqual(func.arguments(), [arg1, arg2]) + self.assertEqual(str(func), "SUM(MockExpression(), MockExpression())") + + def test_function_equality(self): + func1 = FuncExpressionImpl("SUM", []) + func2 = FuncExpressionImpl("SUM", []) + self.assertEqual(func1, func2) + self.assertEqual(hash(func1), hash(func2)) + + def test_function_of_static_method(self): + func = FunctionExpression.of("SUM", MockExpression()) + self.assertEqual(func.function_name(), "SUM") diff --git a/clients/client-python/tests/unittests/test_literals.py b/clients/client-python/tests/unittests/test_literals.py new file mode 100644 index 00000000000..d9c96b7bab0 --- /dev/null +++ b/clients/client-python/tests/unittests/test_literals.py @@ -0,0 +1,95 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import unittest +from datetime import date, time, datetime +from decimal import Decimal + +from gravitino.api.expressions.literals.literals import Literals +from gravitino.api.types.types import Types + + +class TestLiterals(unittest.TestCase): + def test_null_literal(self): + null_val = Literals.NULL + self.assertEqual(null_val.value(), None) + self.assertEqual(null_val.data_type(), Types.NullType.get()) + + def test_boolean_literal(self): + bool_val = Literals.boolean_literal(True) + self.assertEqual(bool_val.value(), True) + self.assertEqual(bool_val.data_type(), Types.BooleanType.get()) + + def test_integer_literal(self): + int_val = Literals.integer_literal(42) + self.assertEqual(int_val.value(), 42) + self.assertEqual(int_val.data_type(), Types.IntegerType.get()) + + def test_string_literal(self): + str_val = Literals.string_literal("Hello World") + self.assertEqual(str_val.value(), "Hello World") + self.assertEqual(str_val.data_type(), Types.StringType.get()) + + def test_decimal_literal(self): + decimal_val = Literals.decimal_literal(Decimal("0.00")) + self.assertEqual(decimal_val.value(), Decimal("0.00")) + self.assertEqual(decimal_val.data_type(), Types.DecimalType.of(2, 2)) + + def test_date_literal(self): + date_val = Literals.date_literal(date(2023, 1, 1)) + self.assertEqual(date_val.value(), date(2023, 1, 1)) + self.assertEqual(date_val.data_type(), Types.DateType.get()) + + def test_time_literal(self): + time_val = Literals.time_literal(time(12, 30, 45)) + self.assertEqual(time_val.value(), time(12, 30, 45)) + self.assertEqual(time_val.data_type(), Types.TimeType.get()) + + def test_timestamp_literal(self): + timestamp_val = Literals.timestamp_literal(datetime(2023, 1, 1, 12, 30, 45)) + self.assertEqual(timestamp_val.value(), datetime(2023, 1, 1, 12, 30, 45)) + self.assertEqual( + timestamp_val.data_type(), Types.TimestampType.without_time_zone() + ) + + def test_timestamp_literal_from_string(self): + timestamp_val = Literals.timestamp_literal_from_string("2023-01-01T12:30:45") + self.assertEqual(timestamp_val.value(), datetime(2023, 1, 1, 12, 30, 45)) + self.assertEqual( + timestamp_val.data_type(), Types.TimestampType.without_time_zone() + ) + + def test_varchar_literal(self): + varchar_val = Literals.varchar_literal(10, "Test String") + self.assertEqual(varchar_val.value(), "Test String") + self.assertEqual(varchar_val.data_type(), Types.VarCharType.of(10)) + + def test_equality(self): + int_val1 = Literals.integer_literal(42) + int_val2 = Literals.integer_literal(42) + int_val3 = Literals.integer_literal(10) + self.assertTrue(int_val1 == int_val2) + self.assertFalse(int_val1 == int_val3) + + def test_hash(self): + int_val1 = Literals.integer_literal(42) + int_val2 = Literals.integer_literal(42) + self.assertEqual(hash(int_val1), hash(int_val2)) + + def test_unequal_literals(self): + int_val = Literals.integer_literal(42) + str_val = Literals.string_literal("Hello") + self.assertFalse(int_val == str_val) diff --git a/clients/client-python/tests/unittests/test_named_reference.py b/clients/client-python/tests/unittests/test_named_reference.py new file mode 100644 index 00000000000..a9942aec7fc --- /dev/null +++ b/clients/client-python/tests/unittests/test_named_reference.py @@ -0,0 +1,39 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from gravitino.api.expressions.named_reference import NamedReference, FieldReference + + +class TestNamedReference(unittest.TestCase): + def test_field_reference_creation(self): + field = FieldReference(["student", "name"]) + self.assertEqual(field.field_name(), ["student", "name"]) + self.assertEqual(str(field), "student.name") + + def test_field_reference_equality(self): + field1 = FieldReference(["student", "name"]) + field2 = FieldReference(["student", "name"]) + self.assertEqual(field1, field2) + self.assertEqual(hash(field1), hash(field2)) + + def test_named_reference_static_methods(self): + ref = NamedReference.field(["student", "name"]) + self.assertEqual(ref.field_name(), ["student", "name"]) + + ref2 = NamedReference.field_from_column("student") + self.assertEqual(ref2.field_name(), ["student"]) diff --git a/clients/client-python/tests/unittests/test_types.py b/clients/client-python/tests/unittests/test_types.py index e241b420acc..bf5685c4ac5 100644 --- a/clients/client-python/tests/unittests/test_types.py +++ b/clients/client-python/tests/unittests/test_types.py @@ -17,7 +17,7 @@ import unittest -from gravitino.api.types import Types, Name +from gravitino.api.types.types import Types, Name class TestTypes(unittest.TestCase): diff --git a/clients/client-python/tests/unittests/test_unparsed_expression.py b/clients/client-python/tests/unittests/test_unparsed_expression.py new file mode 100644 index 00000000000..809caf67d48 --- /dev/null +++ b/clients/client-python/tests/unittests/test_unparsed_expression.py @@ -0,0 +1,34 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import unittest +from gravitino.api.expressions.unparsed_expression import UnparsedExpressionImpl + + +class TestUnparsedExpression(unittest.TestCase): + def test_unparsed_expression_creation(self): + expr = UnparsedExpressionImpl("some_expression") + self.assertEqual(expr.unparsed_expression(), "some_expression") + self.assertEqual( + str(expr), "UnparsedExpressionImpl{unparsedExpression='some_expression'}" + ) + + def test_unparsed_expression_equality(self): + expr1 = UnparsedExpressionImpl("some_expression") + expr2 = UnparsedExpressionImpl("some_expression") + self.assertEqual(expr1, expr2) + self.assertEqual(hash(expr1), hash(expr2))