Skip to content

Commit

Permalink
chore(weave): Pull str limit magic numbers into const #3042
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewtruong authored Nov 21, 2024
1 parent a323429 commit 5c8a0b5
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 25 deletions.
28 changes: 12 additions & 16 deletions tests/integrations/integration_utilities_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@
_truncated_str,
truncate_op_name,
)
from weave.trace_server.constants import MAX_OP_NAME_LENGTH

MAX_RUN_NAME_LENGTH = 128
NON_HASH_LIMIT = 5


def test_truncate_op_name_less_than_limit() -> None:
name = _make_string_of_length(MAX_RUN_NAME_LENGTH - 1)
name = _make_string_of_length(MAX_OP_NAME_LENGTH - 1)
trunc = truncate_op_name(name)
assert trunc == name


def test_truncate_op_name_at_limit() -> None:
name = _make_string_of_length(MAX_RUN_NAME_LENGTH)
name = _make_string_of_length(MAX_OP_NAME_LENGTH)
trunc = truncate_op_name(name)
assert trunc == name

Expand All @@ -29,15 +29,13 @@ def test_truncate_op_name_too_short_for_hash() -> None:
if tail_len <= chars_to_remove:
with pytest.raises(ValueError):
name, trunc = _truncated_str(
tail_len, MAX_RUN_NAME_LENGTH + chars_to_remove
tail_len, MAX_OP_NAME_LENGTH + chars_to_remove
)
else:
name, trunc = _truncated_str(
tail_len, MAX_RUN_NAME_LENGTH + chars_to_remove
)
assert trunc == name[:MAX_RUN_NAME_LENGTH]
name, trunc = _truncated_str(tail_len, MAX_OP_NAME_LENGTH + chars_to_remove)
assert trunc == name[:MAX_OP_NAME_LENGTH]

name, trunc = _truncated_str(NON_HASH_LIMIT + 1, MAX_RUN_NAME_LENGTH + 1)
name, trunc = _truncated_str(NON_HASH_LIMIT + 1, MAX_OP_NAME_LENGTH + 1)
assert (
trunc
== "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa.a_0_a"
Expand All @@ -49,24 +47,22 @@ def test_truncate_op_name_too_short_for_hash() -> None:
if tail_len <= chars_to_remove:
with pytest.raises(ValueError):
name, trunc = _truncated_str(
tail_len, MAX_RUN_NAME_LENGTH + chars_to_remove
tail_len, MAX_OP_NAME_LENGTH + chars_to_remove
)
else:
name, trunc = _truncated_str(
tail_len, MAX_RUN_NAME_LENGTH + chars_to_remove
)
assert trunc == name[:MAX_RUN_NAME_LENGTH]
name, trunc = _truncated_str(tail_len, MAX_OP_NAME_LENGTH + chars_to_remove)
assert trunc == name[:MAX_OP_NAME_LENGTH]


def test_truncate_op_name_with_digest() -> None:
name = _make_string_of_length(MAX_RUN_NAME_LENGTH + 1)
name = _make_string_of_length(MAX_OP_NAME_LENGTH + 1)
trunc = truncate_op_name(name)
assert (
trunc
== "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa_b325_aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"
)

name = _make_string_of_length(MAX_RUN_NAME_LENGTH + 10)
name = _make_string_of_length(MAX_OP_NAME_LENGTH + 10)
trunc = truncate_op_name(name)
assert (
trunc
Expand Down
7 changes: 3 additions & 4 deletions weave/integrations/integration_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@

from weave.trace.refs import OpRef, parse_uri
from weave.trace.weave_client import Call, CallsIter

MAX_RUN_NAME_LENGTH = 128
from weave.trace_server.constants import MAX_OP_NAME_LENGTH


def make_pythonic_function_name(name: str) -> str:
Expand All @@ -17,10 +16,10 @@ def make_pythonic_function_name(name: str) -> str:


def truncate_op_name(name: str) -> str:
if len(name) <= MAX_RUN_NAME_LENGTH:
if len(name) <= MAX_OP_NAME_LENGTH:
return name

trim_amount_needed = len(name) - MAX_RUN_NAME_LENGTH
trim_amount_needed = len(name) - MAX_OP_NAME_LENGTH
parts = name.split(".")
last_part = parts[-1]

Expand Down
5 changes: 3 additions & 2 deletions weave/trace/weave_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
from weave.trace.table import Table
from weave.trace.util import deprecated
from weave.trace.vals import WeaveObject, WeaveTable, make_trace_obj
from weave.trace_server.constants import MAX_OBJECT_NAME_LENGTH
from weave.trace_server.ids import generate_id
from weave.trace_server.interface.feedback_types import RUNNABLE_FEEDBACK_TYPE_PREFIX
from weave.trace_server.trace_server_interface import (
Expand Down Expand Up @@ -1604,8 +1605,8 @@ def sanitize_object_name(name: str) -> str:
res = re.sub(r"([._-]{2,})+", "-", re.sub(r"[^\w._]+", "-", name)).strip("-_")
if not res:
raise ValueError(f"Invalid object name: {name}")
if len(res) > 128:
res = res[:128]
if len(res) > MAX_OBJECT_NAME_LENGTH:
res = res[:MAX_OBJECT_NAME_LENGTH]
return res


Expand Down
4 changes: 4 additions & 0 deletions weave/trace_server/constants.py
Original file line number Diff line number Diff line change
@@ -1 +1,5 @@
COMPLETIONS_CREATE_OP_NAME = "weave.completions_create"

MAX_DISPLAY_NAME_LENGTH = 128
MAX_OP_NAME_LENGTH = 128
MAX_OBJECT_NAME_LENGTH = 128
11 changes: 8 additions & 3 deletions weave/trace_server/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
from typing import Any, Literal, Optional

from weave.trace_server import refs_internal, validation_util
from weave.trace_server.constants import (
MAX_DISPLAY_NAME_LENGTH,
MAX_OBJECT_NAME_LENGTH,
MAX_OP_NAME_LENGTH,
)
from weave.trace_server.errors import InvalidRequest

# Temporary flag to disable database-side validation of object ids.
Expand Down Expand Up @@ -39,14 +44,14 @@ def parent_id_validator(s: Optional[str]) -> Optional[str]:
def display_name_validator(s: Optional[str]) -> Optional[str]:
if s is None:
return None
return validation_util.require_max_str_len(s, 128)
return validation_util.require_max_str_len(s, MAX_DISPLAY_NAME_LENGTH)


def op_name_validator(s: str) -> str:
if refs_internal.string_will_be_interpreted_as_ref(s):
validation_util.require_internal_ref_uri(s, refs_internal.InternalOpRef)
else:
validation_util.require_max_str_len(s, 128)
validation_util.require_max_str_len(s, MAX_OP_NAME_LENGTH)

return s

Expand Down Expand Up @@ -86,7 +91,7 @@ def _validate_object_name_charset(name: str) -> None:
def object_id_validator(s: str) -> str:
if SHOULD_ENFORCE_OBJ_ID_CHARSET:
_validate_object_name_charset(s)
return validation_util.require_max_str_len(s, 128)
return validation_util.require_max_str_len(s, MAX_OBJECT_NAME_LENGTH)


def refs_list_validator(s: list[str]) -> list[str]:
Expand Down

0 comments on commit 5c8a0b5

Please sign in to comment.