Skip to content

Commit

Permalink
fix: use @field_serializer for pydantic Urls (#363)
Browse files Browse the repository at this point in the history
* fix: use @field_serializer for pydantic `Url`s

ensures `model_dump` preserves python `datetime`s, which are handled by pymongo.

reverts mode='json' option to model_dump, to preserve datetimes.

fixes #349

* fix: preserve exclude_unset

* style: remove `print` for debugging

* fix: remove stale comment
  • Loading branch information
dwinston authored Nov 6, 2023
1 parent dc668ea commit ebe41c9
Show file tree
Hide file tree
Showing 22 changed files with 106 additions and 159 deletions.
4 changes: 1 addition & 3 deletions components/nmdc_runtime/workflow_execution_activity/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,7 @@ def insert_into_keys(
workflow: Workflow, data_objects: list[DataObject]
) -> dict[str, Any]:
"""Insert data object url into correct workflow input field."""
workflow_dict = workflow.model_dump(
mode="json",
)
workflow_dict = workflow.model_dump()
for key in workflow_dict["inputs"]:
for do in data_objects:
if workflow_dict["inputs"][key] == str(do.data_object_type):
Expand Down
6 changes: 2 additions & 4 deletions nmdc_runtime/api/endpoints/objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def create_object(
"""
id_supplied = supplied_object_id(
mdb, client_site, object_in.model_dump(mode="json", exclude_unset=True)
mdb, client_site, object_in.model_dump(exclude_unset=True)
)
drs_id = local_part(
id_supplied if id_supplied is not None else generate_one_id(mdb, S3_ID_NS)
Expand Down Expand Up @@ -255,9 +255,7 @@ def update_object(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"client authorized for different site_id than {object_mgr_site}",
)
doc_object_patched = merge(
doc, object_patch.model_dump(mode="json", exclude_unset=True)
)
doc_object_patched = merge(doc, object_patch.model_dump(exclude_unset=True))
mdb.operations.replace_one({"id": object_id}, doc_object_patched)
return doc_object_patched

Expand Down
4 changes: 2 additions & 2 deletions nmdc_runtime/api/endpoints/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,13 +61,13 @@ def update_operation(
detail=f"client authorized for different site_id than {site_id_op}",
)
op_patch_metadata = merge(
op_patch.model_dump(mode="json", exclude_unset=True).get("metadata", {}),
op_patch.model_dump(exclude_unset=True).get("metadata", {}),
pick(["site_id", "job", "model"], doc_op.get("metadata", {})),
)
doc_op_patched = merge(
doc_op,
assoc(
op_patch.model_dump(mode="json", exclude_unset=True),
op_patch.model_dump(exclude_unset=True),
"metadata",
op_patch_metadata,
),
Expand Down
10 changes: 5 additions & 5 deletions nmdc_runtime/api/endpoints/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,9 @@ def run_query(
id=qid,
saved_at=saved_at,
)
mdb.queries.insert_one(query.model_dump(mode="json", exclude_unset=True))
mdb.queries.insert_one(query.model_dump(exclude_unset=True))
cmd_response = _run_query(query, mdb)
return unmongo(cmd_response.model_dump(mode="json", exclude_unset=True))
return unmongo(cmd_response.model_dump(exclude_unset=True))


@router.get("/queries/{query_id}", response_model=Query)
Expand Down Expand Up @@ -107,7 +107,7 @@ def rerun_query(
check_can_delete(user)

cmd_response = _run_query(query, mdb)
return unmongo(cmd_response.model_dump(mode="json", exclude_unset=True))
return unmongo(cmd_response.model_dump(exclude_unset=True))


def _run_query(query, mdb) -> CommandResponse:
Expand All @@ -131,12 +131,12 @@ def _run_query(query, mdb) -> CommandResponse:
detail="Failed to back up to-be-deleted documents. operation aborted.",
)

q_response = mdb.command(query.cmd.model_dump(mode="json", exclude_unset=True))
q_response = mdb.command(query.cmd.model_dump(exclude_unset=True))
cmd_response: CommandResponse = command_response_for(q_type)(**q_response)
query_run = (
QueryRun(qid=query.id, ran_at=ran_at, result=cmd_response)
if cmd_response.ok
else QueryRun(qid=query.id, ran_at=ran_at, error=cmd_response)
)
mdb.query_runs.insert_one(query_run.model_dump(mode="json", exclude_unset=True))
mdb.query_runs.insert_one(query_run.model_dump(exclude_unset=True))
return cmd_response
6 changes: 1 addition & 5 deletions nmdc_runtime/api/endpoints/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,5 @@ def post_run_event(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Supplied run_event.run.id does not match run_id given in request URL.",
)
mdb.run_events.insert_one(
run_event.model_dump(
mode="json",
)
)
mdb.run_events.insert_one(run_event.model_dump())
return _get_run_summary(run_event.run.id, mdb)
4 changes: 1 addition & 3 deletions nmdc_runtime/api/endpoints/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def data_objects(
req: DataObjectListRequest = Depends(),
mdb: MongoDatabase = Depends(get_mongo_db),
):
filter_ = list_request_filter_to_mongo_filter(
req.model_dump(mode="json", exclude_unset=True)
)
filter_ = list_request_filter_to_mongo_filter(req.model_dump(exclude_unset=True))
max_page_size = filter_.pop("max_page_size", None)
page_token = filter_.pop("page_token", None)
req = ListRequest(
Expand Down
12 changes: 2 additions & 10 deletions nmdc_runtime/api/endpoints/sites.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,7 @@ def create_site(
status_code=status.HTTP_409_CONFLICT,
detail=f"site with supplied id {site.id} already exists",
)
mdb.sites.insert_one(
site.model_dump(
mode="json",
)
)
mdb.sites.insert_one(site.model_dump())
refresh_minter_requesters_from_sites()
rv = mdb.users.update_one(
{"username": user.username},
Expand Down Expand Up @@ -169,11 +165,7 @@ def put_object_in_site(
},
}
)
mdb.operations.insert_one(
op.model_dump(
mode="json",
)
)
mdb.operations.insert_one(op.model_dump())
return op


Expand Down
22 changes: 5 additions & 17 deletions nmdc_runtime/api/endpoints/users.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,7 @@ async def login_for_access_token(
detail="Incorrect username or password",
headers={"WWW-Authenticate": "Bearer"},
)
access_token_expires = timedelta(
**ACCESS_TOKEN_EXPIRES.model_dump(
mode="json",
)
)
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump())
access_token = create_access_token(
data={"sub": f"user:{user.username}"}, expires_delta=access_token_expires
)
Expand All @@ -54,21 +50,15 @@ async def login_for_access_token(
headers={"WWW-Authenticate": "Bearer"},
)
# TODO make below an absolute time
access_token_expires = timedelta(
**ACCESS_TOKEN_EXPIRES.model_dump(
mode="json",
)
)
access_token_expires = timedelta(**ACCESS_TOKEN_EXPIRES.model_dump())
access_token = create_access_token(
data={"sub": f"client:{form_data.client_id}"},
expires_delta=access_token_expires,
)
return {
"access_token": access_token,
"token_type": "bearer",
"expires": ACCESS_TOKEN_EXPIRES.model_dump(
mode="json",
),
"expires": ACCESS_TOKEN_EXPIRES.model_dump(),
}


Expand All @@ -94,10 +84,8 @@ def create_user(
check_can_create_user(requester)
mdb.users.insert_one(
UserInDB(
**user_in.model_dump(
mode="json",
),
**user_in.model_dump(),
hashed_password=get_password_hash(user_in.password),
).model_dump(mode="json", exclude_unset=True)
).model_dump(exclude_unset=True)
)
return mdb.users.find_one({"username": user_in.username})
18 changes: 6 additions & 12 deletions nmdc_runtime/api/endpoints/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -454,11 +454,11 @@ def _create_object(
mdb: MongoDatabase, object_in: DrsObjectIn, mgr_site, drs_id, self_uri
):
drs_obj = DrsObject(
**object_in.model_dump(exclude_unset=True, mode="json"),
**object_in.model_dump(exclude_unset=True),
id=drs_id,
self_uri=self_uri,
)
doc = drs_obj.model_dump(exclude_unset=True, mode="json")
doc = drs_obj.model_dump(exclude_unset=True)
doc["_mgr_site"] = mgr_site # manager site
try:
mdb.objects.insert_one(doc)
Expand Down Expand Up @@ -519,22 +519,16 @@ def _claim_job(job_id: str, mdb: MongoDatabase, site: Site):
"workflow": job.workflow,
"config": job.config,
}
).model_dump(mode="json", exclude_unset=True),
).model_dump(exclude_unset=True),
"site_id": site.id,
"model": dotted_path_for(JobOperationMetadata),
},
}
)
mdb.operations.insert_one(
op.model_dump(
mode="json",
)
)
mdb.jobs.replace_one(
{"id": job.id}, job.model_dump(mode="json", exclude_unset=True)
)
mdb.operations.insert_one(op.model_dump())
mdb.jobs.replace_one({"id": job.id}, job.model_dump(exclude_unset=True))

return op.model_dump(mode="json", exclude_unset=True)
return op.model_dump(exclude_unset=True)


@lru_cache
Expand Down
10 changes: 3 additions & 7 deletions nmdc_runtime/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,9 +234,7 @@ def ensure_initial_resources_on_boot():
collection_boot = import_module(f"nmdc_runtime.api.boot.{collection_name}")

for model in collection_boot.construct():
doc = model.model_dump(
mode="json",
)
doc = model.model_dump()
mdb[collection_name].replace_one({"id": doc["id"]}, doc, upsert=True)

username = os.getenv("API_ADMIN_USER")
Expand All @@ -248,7 +246,7 @@ def ensure_initial_resources_on_boot():
username=username,
hashed_password=get_password_hash(os.getenv("API_ADMIN_PASS")),
site_admin=[os.getenv("API_SITE_ID")],
).model_dump(mode="json", exclude_unset=True),
).model_dump(exclude_unset=True),
upsert=True,
)
mdb.users.create_index("username")
Expand All @@ -269,9 +267,7 @@ def ensure_initial_resources_on_boot():
),
)
],
).model_dump(
mode="json",
),
).model_dump(),
upsert=True,
)

Expand Down
23 changes: 23 additions & 0 deletions nmdc_runtime/api/models/object.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BaseModel,
AnyUrl,
HttpUrl,
field_serializer,
)
from typing_extensions import Annotated

Expand All @@ -31,6 +32,10 @@ class AccessURL(BaseModel):
headers: Optional[Dict[str, str]] = None
url: AnyUrl

@field_serializer("url")
def serialize_url(self, url: AnyUrl, _info):
return str(url)


class AccessMethod(BaseModel):
access_id: Optional[Annotated[str, StringConstraints(min_length=1)]] = None
Expand Down Expand Up @@ -78,6 +83,12 @@ def no_contents_means_single_blob(cls, values):
raise ValueError("no contents means no further nesting, so id required")
return values

@field_serializer("drs_uri")
def serialize_url(self, drs_uri: Optional[List[AnyUrl]], _info):
if drs_uri is not None and len(drs_uri) > 0:
return [str(u) for u in drs_uri]
return drs_uri


ContentsObject.update_forward_refs()

Expand Down Expand Up @@ -127,6 +138,10 @@ class DrsObject(DrsObjectIn):
id: DrsId
self_uri: AnyUrl

@field_serializer("self_uri")
def serialize_url(self, self_uri: AnyUrl, _info):
return str(self_uri)


Seconds = Annotated[int, Field(strict=True, gt=0)]

Expand All @@ -135,6 +150,10 @@ class ObjectPresignedUrl(BaseModel):
url: HttpUrl
expires_in: Seconds = 300

@field_serializer("url")
def serialize_url(self, url: HttpUrl, _info):
return str(url)


class DrsObjectOutBase(DrsObjectBase):
checksums: List[Checksum]
Expand All @@ -145,6 +164,10 @@ class DrsObjectOutBase(DrsObjectBase):
updated_time: Optional[datetime.datetime] = None
version: Optional[str] = None

@field_serializer("self_uri")
def serialize_url(self, slf_uri: AnyUrl, _info):
return str(self_uri)


class DrsObjectBlobOut(DrsObjectOutBase):
access_methods: List[AccessMethod]
Expand Down
6 changes: 5 additions & 1 deletion nmdc_runtime/api/models/operation.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import datetime
from typing import Generic, TypeVar, Optional, List, Any, Union

from pydantic import StringConstraints, BaseModel, HttpUrl
from pydantic import StringConstraints, BaseModel, HttpUrl, field_serializer

from nmdc_runtime.api.models.util import ResultT
from typing_extensions import Annotated
Expand Down Expand Up @@ -59,3 +59,7 @@ class ObjectPutMetadata(Metadata):
site_id: str
url: HttpUrl
expires_in_seconds: int

@field_serializer("url")
def serialize_url(self, url: HttpUrl, _info):
return str(url)
18 changes: 4 additions & 14 deletions nmdc_runtime/api/models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,7 @@ def _add_run_requested_event(run_spec: RunUserSpec, mdb: MongoDatabase, user: Us
time=now(as_str=True),
inputs=run_spec.inputs,
)
mdb.run_events.insert_one(
event.model_dump(
mode="json",
)
)
mdb.run_events.insert_one(event.model_dump())
return run_id


Expand All @@ -117,9 +113,7 @@ def _add_run_started_event(run_id: str, mdb: MongoDatabase):
job=requested.job,
type=RunEventType.STARTED,
time=now(as_str=True),
).model_dump(
mode="json",
)
).model_dump()
)
return run_id

Expand All @@ -140,9 +134,7 @@ def _add_run_fail_event(run_id: str, mdb: MongoDatabase):
job=requested.job,
type=RunEventType.FAIL,
time=now(as_str=True),
).model_dump(
mode="json",
)
).model_dump()
)
return run_id

Expand All @@ -164,8 +156,6 @@ def _add_run_complete_event(run_id: str, mdb: MongoDatabase, outputs: List[str])
type=RunEventType.COMPLETE,
time=now(as_str=True),
outputs=outputs,
).model_dump(
mode="json",
)
).model_dump()
)
return run_id
Loading

0 comments on commit ebe41c9

Please sign in to comment.