Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PTDT-2863]: Feature schema attributes #1930

Merged
merged 20 commits into from
Dec 18, 2024
Merged
Show file tree
Hide file tree
Changes from 18 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions libs/labelbox/src/labelbox/schema/ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,11 @@
from labelbox.schema.tool_building.tool_type_mapping import (
map_tool_type_to_tool_cls,
)
from labelbox.schema.tool_building.types import (
FeatureSchemaAttribute,
FeatureSchemaAttributes,
)
import warnings


class DeleteFeatureFromOntologyResult:
Expand Down Expand Up @@ -73,6 +78,7 @@ class Tool:
classifications: (list)
schema_id: (str)
feature_schema_id: (str)
attributes: (list)
"""
Tim-Kerr marked this conversation as resolved.
Show resolved Hide resolved

class Type(Enum):
Expand All @@ -95,6 +101,13 @@ class Type(Enum):
classifications: List[Classification] = field(default_factory=list)
schema_id: Optional[str] = None
feature_schema_id: Optional[str] = None
attributes: Optional[FeatureSchemaAttributes] = None

def __post_init__(self):
if self.attributes is not None:
warnings.warn(
"Attributes are an experimental feature and may change in the future."
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

)

@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
Expand All @@ -109,6 +122,12 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
for c in dictionary["classifications"]
],
color=dictionary["color"],
attributes=[
FeatureSchemaAttribute.from_dict(attr)
for attr in dictionary.get("attributes", []) or []
]
if dictionary.get("attributes")
else None,
)

def asdict(self) -> Dict[str, Any]:
Expand All @@ -122,6 +141,9 @@ def asdict(self) -> Dict[str, Any]:
],
"schemaNodeId": self.schema_id,
"featureSchemaId": self.feature_schema_id,
"attributes": [a.asdict() for a in self.attributes]
if self.attributes is not None
else None,
}

def add_classification(self, classification: Classification) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

from lbox.exceptions import InconsistentOntologyException

from labelbox.schema.tool_building.types import FeatureSchemaId
from labelbox.schema.tool_building.types import (
FeatureSchemaId,
vbrodsky marked this conversation as resolved.
Show resolved Hide resolved
FeatureSchemaAttributes,
FeatureSchemaAttribute,
)


@dataclass
Expand Down Expand Up @@ -42,6 +46,7 @@ class Classification:
schema_id: (str)
feature_schema_id: (str)
scope: (str)
attributes: (list)
"""

class Type(Enum):
Expand Down Expand Up @@ -70,6 +75,7 @@ class UIMode(Enum):
ui_mode: Optional[UIMode] = (
None # How this classification should be answered (e.g. hotkeys / autocomplete, etc)
)
attributes: Optional[FeatureSchemaAttributes] = None

def __post_init__(self):
if self.name is None:
Expand All @@ -88,6 +94,10 @@ def __post_init__(self):
else:
if self.instructions is None:
self.instructions = self.name
if self.attributes is not None:
warnings.warn(
"Attributes are an experimental feature and may change in the future."
)

@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> "Classification":
Expand All @@ -103,6 +113,12 @@ def from_dict(cls, dictionary: Dict[str, Any]) -> "Classification":
schema_id=dictionary.get("schemaNodeId", None),
feature_schema_id=dictionary.get("featureSchemaId", None),
scope=cls.Scope(dictionary.get("scope", cls.Scope.GLOBAL)),
attributes=[
FeatureSchemaAttribute.from_dict(attr)
for attr in dictionary.get("attributes", []) or []
]
if dictionary.get("attributes")
else None,
)

def asdict(self, is_subclass: bool = False) -> Dict[str, Any]:
Expand All @@ -118,6 +134,9 @@ def asdict(self, is_subclass: bool = False) -> Dict[str, Any]:
"options": [o.asdict() for o in self.options],
"schemaNodeId": self.schema_id,
"featureSchemaId": self.feature_schema_id,
"attributes": [a.asdict() for a in self.attributes]
if self.attributes is not None
else None,
}
if (
self.class_type == self.Type.RADIO
Expand Down
31 changes: 29 additions & 2 deletions libs/labelbox/src/labelbox/schema/tool_building/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,33 @@
from typing import Annotated

from typing import Annotated, List
from pydantic import Field


from dataclasses import dataclass

from typing import Any, Dict, List


@dataclass
class FeatureSchemaAttribute:
attributeName: str
attributeValue: str

def asdict(self):
return {
"attributeName": self.attributeName,
"attributeValue": self.attributeValue,
}

@classmethod
def from_dict(cls, dictionary: Dict[str, Any]) -> "FeatureSchemaAttribute":
return cls(
attributeName=dictionary["attributeName"],
attributeValue=dictionary["attributeValue"],
)


FeatureSchemaId = Annotated[str, Field(min_length=25, max_length=25)]
SchemaId = Annotated[str, Field(min_length=25, max_length=25)]
FeatureSchemaAttributes = Annotated[
List[FeatureSchemaAttribute], Field(default_factory=list)
vbrodsky marked this conversation as resolved.
Show resolved Hide resolved
]
71 changes: 71 additions & 0 deletions libs/labelbox/tests/integration/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from labelbox.schema.data_row import DataRowMetadataField
from labelbox.schema.ontology_kind import OntologyKind
from labelbox.schema.user import User
from labelbox.schema.tool_building.types import FeatureSchemaAttribute


@pytest.fixture
Expand Down Expand Up @@ -552,6 +553,76 @@ def point():
)


@pytest.fixture
def auto_ocr_text_value_class():
return Classification(
class_type=Classification.Type.TEXT,
name="Auto OCR Text Value",
instructions="Text value for ocr bboxes",
scope=Classification.Scope.GLOBAL,
required=False,
attributes=[
FeatureSchemaAttribute(
attributeName="auto-ocr-text-value", attributeValue="true"
)
],
)


@pytest.fixture
def auto_ocr_bbox(auto_ocr_text_value_class):
return Tool(
tool=Tool.Type.BBOX,
name="Auto ocr bbox",
color="ff0000",
attributes=[
FeatureSchemaAttribute(
attributeName="auto-ocr", attributeValue="true"
)
],
classifications=[auto_ocr_text_value_class],
)


@pytest.fixture
def requires_connection_classification():
return Classification(
name="Requires connection radio",
instructions="Classification that requires a connection",
class_type=Classification.Type.RADIO,
attributes=[
FeatureSchemaAttribute(
attributeName="requires-connection", attributeValue="true"
)
],
options=[Option(value="A"), Option(value="B")],
)


@pytest.fixture
def requires_connection_classification_feature_schema(
client, requires_connection_classification
):
created_feature_schema = client.upsert_feature_schema(
requires_connection_classification.asdict()
)
yield created_feature_schema
client.delete_unused_feature_schema(
created_feature_schema.normalized["featureSchemaId"]
)


@pytest.fixture
def auto_ocr_bbox_feature_schema(client, auto_ocr_bbox):
created_feature_schema = client.upsert_feature_schema(
auto_ocr_bbox.asdict()
)
yield created_feature_schema
client.delete_unused_feature_schema(
created_feature_schema.normalized["featureSchemaId"]
)


@pytest.fixture
def feature_schema(client, point):
created_feature_schema = client.upsert_feature_schema(point.asdict())
Expand Down
26 changes: 26 additions & 0 deletions libs/labelbox/tests/integration/test_feature_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,29 @@ def test_does_not_include_used_feature_schema(client, feature_schema):
assert feature_schema_id not in unused_feature_schemas

client.delete_unused_ontology(ontology.uid)


def test_upsert_tool_with_attributes(auto_ocr_bbox_feature_schema):
auto_ocr_attributes = auto_ocr_bbox_feature_schema.normalized["attributes"]
auto_ocr_text_value_attributes = auto_ocr_bbox_feature_schema.normalized[
"classifications"
][0]["attributes"]
assert auto_ocr_attributes == [
{"attributeName": "auto-ocr", "attributeValue": "true"}
]
assert auto_ocr_text_value_attributes == [
{"attributeName": "auto-ocr-text-value", "attributeValue": "true"}
]


def test_upsert_classification_with_attributes(
requires_connection_classification_feature_schema,
):
requires_connection_attributes = (
requires_connection_classification_feature_schema.normalized[
"attributes"
]
)
assert requires_connection_attributes == [
{"attributeName": "requires-connection", "attributeValue": "true"}
]
20 changes: 20 additions & 0 deletions libs/labelbox/tests/unit/test_unit_ontology.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"color": "#FF0000",
"tool": "polygon",
"classifications": [],
"attributes": None,
},
{
"schemaNodeId": None,
Expand All @@ -24,6 +25,7 @@
"color": "#FF0000",
"tool": "superpixel",
"classifications": [],
"attributes": None,
},
{
"schemaNodeId": None,
Expand All @@ -32,6 +34,12 @@
"name": "bbox",
"color": "#FF0000",
"tool": "rectangle",
"attributes": [
{
"attributeName": "auto-ocr",
"attributeValue": "true",
}
],
"classifications": [
{
"schemaNodeId": None,
Expand All @@ -56,6 +64,7 @@
"name": "nested nested text",
"type": "text",
"options": [],
"attributes": None,
}
],
},
Expand All @@ -67,6 +76,12 @@
"options": [],
},
],
"attributes": [
{
"attributeName": "requires-connection",
"attributeValue": "true",
}
],
},
{
"schemaNodeId": None,
Expand All @@ -76,6 +91,7 @@
"name": "nested text",
"type": "text",
"options": [],
"attributes": None,
},
],
},
Expand All @@ -87,6 +103,7 @@
"color": "#FF0000",
"tool": "point",
"classifications": [],
"attributes": None,
},
{
"schemaNodeId": None,
Expand All @@ -96,6 +113,7 @@
"color": "#FF0000",
"tool": "line",
"classifications": [],
"attributes": None,
},
{
"schemaNodeId": None,
Expand All @@ -105,6 +123,7 @@
"color": "#FF0000",
"tool": "named-entity",
"classifications": [],
"attributes": None,
},
],
"classifications": [
Expand All @@ -117,6 +136,7 @@
"type": "radio",
"scope": "global",
"uiMode": "searchable",
"attributes": None,
"options": [
{
"schemaNodeId": None,
Expand Down
1 change: 1 addition & 0 deletions libs/labelbox/tests/unit/test_unit_prompt_issue_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def test_as_dict():
"schemaNodeId": None,
"featureSchemaId": None,
"scope": "global",
"attributes": None,
}
],
"color": None,
Expand Down
Loading