Skip to content

Commit

Permalink
upgrade to pydantic 2 (#744)
Browse files Browse the repository at this point in the history
This PR upgrades langserve to pydantic 2.

* Added a failing unit test that has 2 known failures (that need to be
fixed in langchain-core)
* Deprecation warnings will be resolved separately.
  • Loading branch information
eyurtsev authored Sep 9, 2024
1 parent 41a9d79 commit 5aedbf7
Show file tree
Hide file tree
Showing 18 changed files with 1,490 additions and 1,455 deletions.
94 changes: 0 additions & 94 deletions .github/workflows/_pydantic_compatibility.yml

This file was deleted.

7 changes: 0 additions & 7 deletions .github/workflows/langserve_ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,6 @@ jobs:
with:
working-directory: .
secrets: inherit

pydantic-compatibility:
uses:
./.github/workflows/_pydantic_compatibility.yml
with:
working-directory: .
secrets: inherit
test:
timeout-minutes: 10
runs-on: ubuntu-latest
Expand Down
53 changes: 53 additions & 0 deletions langserve/_pydantic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any, Dict, Type, cast

from pydantic import BaseModel, ConfigDict, RootModel
from pydantic.json_schema import (
DEFAULT_REF_TEMPLATE,
GenerateJsonSchema,
JsonSchemaMode,
)


def _create_root_model(name: str, type_: Any) -> Type[RootModel]:
    """Create a pydantic ``RootModel`` subclass called ``name`` wrapping ``type_``.

    The generated class overrides ``schema`` and ``model_json_schema`` so that
    the top-level ``title`` of the produced JSON schema is ``name``.

    Args:
        name: Class name of the generated model; also written into the
            schema as its ``title``.
        type_: Annotation used for the model's ``root`` field.

    Returns:
        The newly created ``RootModel`` subclass.
    """

    def schema(
        cls: Type[BaseModel],
        by_alias: bool = True,
        ref_template: str = DEFAULT_REF_TEMPLATE,
    ) -> Dict[str, Any]:
        # super() is anchored on the generated class (bound below, resolved
        # lazily through the closure) instead of ``cls``: ``super(cls, cls)``
        # would resolve back to this very method when invoked on a subclass
        # of the generated model and recurse without end.
        # The type checker complains about schema not being defined in the
        # superclass, hence the ignore.
        schema_ = super(custom_root_type, cls).schema(  # type: ignore[misc]
            by_alias=by_alias, ref_template=ref_template
        )
        schema_["title"] = name
        return schema_

    def model_json_schema(
        cls: Type[BaseModel],
        by_alias: bool = True,
        ref_template: str = DEFAULT_REF_TEMPLATE,
        # ``Type[...]`` (not ``type[...]``): the annotation is evaluated when
        # the function is defined, and this matches the rest of the module.
        schema_generator: Type[GenerateJsonSchema] = GenerateJsonSchema,
        mode: JsonSchemaMode = "validation",
    ) -> Dict[str, Any]:
        # Same anchoring as in ``schema`` above to stay safe under subclassing.
        # The type checker complains about model_json_schema not being defined
        # in the superclass, hence the ignore.
        schema_ = super(  # type: ignore[misc]
            custom_root_type, cls
        ).model_json_schema(
            by_alias=by_alias,
            ref_template=ref_template,
            schema_generator=schema_generator,
            mode=mode,
        )
        schema_["title"] = name
        return schema_

    base_class_attributes = {
        # The single ``root`` field carries the wrapped type.
        "__annotations__": {"root": type_},
        "model_config": ConfigDict(arbitrary_types_allowed=True),
        "schema": classmethod(schema),
        "model_json_schema": classmethod(model_json_schema),
        # Should replace __module__ with caller based on stack frame.
        "__module__": "langserve._pydantic",
    }

    custom_root_type = type(name, (RootModel,), base_class_attributes)
    return cast(Type[RootModel], custom_root_type)
56 changes: 26 additions & 30 deletions langserve/api_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,15 @@
from langsmith import client as ls_client
from langsmith.schemas import FeedbackIngestToken
from langsmith.utils import tracing_is_enabled
from pydantic import BaseModel, Field, RootModel, ValidationError, create_model
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from typing_extensions import TypedDict

from langserve._pydantic import _create_root_model
from langserve.callbacks import AsyncEventAggregatorCallback, CallbackEventDict
from langserve.lzstring import LZString
from langserve.playground import serve_playground
from langserve.pydantic_v1 import BaseModel, Field, ValidationError, create_model
from langserve.schema import (
BatchResponseMetadata,
CustomUserType,
Expand Down Expand Up @@ -256,10 +257,12 @@ def _update_config_with_defaults(
}
metadata.update(hosted_metadata)

non_overridable_default_config = RunnableConfig(
run_name=run_name,
metadata=metadata,
)
non_overridable_default_config: RunnableConfig = {
"metadata": metadata,
}

if run_name:
non_overridable_default_config["run_name"] = run_name

# merge_configs is last-writer-wins, so we specifically pass in the
# overridable configs first, then the user provided configs, then
Expand All @@ -280,8 +283,8 @@ def _update_config_with_defaults(

def _unpack_input(validated_model: BaseModel) -> Any:
"""Unpack the decoded input from the validated model."""
if hasattr(validated_model, "__root__"):
model = validated_model.__root__
if isinstance(validated_model, RootModel):
model = validated_model.root
else:
model = validated_model

Expand All @@ -305,7 +308,7 @@ def _rename_pydantic_model(model: Type[BaseModel], prefix: str) -> Type[BaseMode
"""Rename the given pydantic model to the given name."""
return create_model(
prefix + model.__name__,
__config__=model.__config__,
__config__=model.model_config,
**{
fieldname: (
_rename_pydantic_model(field.annotation, prefix)
Expand All @@ -314,10 +317,10 @@ def _rename_pydantic_model(model: Type[BaseModel], prefix: str) -> Type[BaseMode
Field(
field.default,
title=fieldname,
description=field.field_info.description,
description=field.description,
),
)
for fieldname, field in model.__fields__.items()
for fieldname, field in model.model_fields.items()
},
)

Expand All @@ -334,7 +337,7 @@ def _resolve_model(
if isclass(type_) and issubclass(type_, BaseModel):
model = type_
else:
model = create_model(default_name, __root__=(type_, ...))
model = _create_root_model(default_name, type_)

hash_ = model.schema_json()

Expand Down Expand Up @@ -367,11 +370,7 @@ def _add_namespace_to_model(namespace: str, model: Type[BaseModel]) -> Type[Base
A new model with name prepended with the given namespace.
"""
model_with_unique_name = _rename_pydantic_model(model, namespace)
if "run_id" in model_with_unique_name.__annotations__:
# Help resolve reference by providing namespace references
model_with_unique_name.update_forward_refs(uuid=uuid)
else:
model_with_unique_name.update_forward_refs()
model_with_unique_name.model_rebuild()
return model_with_unique_name


Expand Down Expand Up @@ -404,7 +403,7 @@ def _with_validation_error_translation() -> Generator[None, None, None]:
try:
yield
except ValidationError as e:
raise RequestValidationError(e.errors(), body=e.model)
raise RequestValidationError(e.errors())


def _json_encode_response(model: BaseModel) -> JSONResponse:
Expand All @@ -424,39 +423,36 @@ def _json_encode_response(model: BaseModel) -> JSONResponse:

if isinstance(model, InvokeBaseResponse):
# Invoke Response
# Collapse '__root__' from output field if it exists. This is done
# Collapse 'root' from output field if it exists. This is done
# automatically by fastapi when annotating request and response with models.
# We need to do this manually since we're using vanilla JSONResponse
if isinstance(obj["output"], dict) and "__root__" in obj["output"]:
obj["output"] = obj["output"]["__root__"]
if isinstance(obj["output"], dict) and "root" in obj["output"]:
obj["output"] = obj["output"]["root"]

if "callback_events" in obj:
for idx, callback_event in enumerate(obj["callback_events"]):
if isinstance(callback_event, dict) and "__root__" in callback_event:
obj["callback_events"][idx] = callback_event["__root__"]
if isinstance(callback_event, dict) and "root" in callback_event:
obj["callback_events"][idx] = callback_event["root"]
elif isinstance(model, BatchBaseResponse):
if not isinstance(obj["output"], list):
raise AssertionError("Expected output to be a list")

# Collapse '__root__' from output field if it exists. This is done
# Collapse 'root' from output field if it exists. This is done
# automatically by fastapi when annotating request and response with models.
# We need to do this manually since we're using vanilla JSONResponse
outputs = obj["output"]
for idx, output in enumerate(outputs):
if isinstance(output, dict) and "__root__" in output:
outputs[idx] = output["__root__"]
if isinstance(output, dict) and "root" in output:
outputs[idx] = output["root"]

if "callback_events" in obj:
if not isinstance(obj["callback_events"], list):
raise AssertionError("Expected callback_events to be a list")

for callback_events in obj["callback_events"]:
for idx, callback_event in enumerate(callback_events):
if (
isinstance(callback_event, dict)
and "__root__" in callback_event
):
callback_events[idx] = callback_event["__root__"]
if isinstance(callback_event, dict) and "root" in callback_event:
callback_events[idx] = callback_event["root"]
else:
raise AssertionError(
f"Expected a InvokeBaseResponse or BatchBaseResponse got: {type(model)}"
Expand Down
3 changes: 1 addition & 2 deletions langserve/playground.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@

from fastapi.responses import Response
from langchain_core.runnables import Runnable

from langserve.pydantic_v1 import BaseModel
from pydantic import BaseModel


class PlaygroundTemplate(Template):
Expand Down
33 changes: 0 additions & 33 deletions langserve/pydantic_v1.py

This file was deleted.

9 changes: 5 additions & 4 deletions langserve/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
from typing import Dict, List, Optional, Union
from uuid import UUID

from pydantic import BaseModel # Floats between v1 and v2

from langserve.pydantic_v1 import BaseModel as BaseModelV1
from langserve.pydantic_v1 import Field
from pydantic import (
BaseModel,
Field,
)
from pydantic import BaseModel as BaseModelV1


class CustomUserType(BaseModelV1):
Expand Down
Loading

0 comments on commit 5aedbf7

Please sign in to comment.