Skip to content

Commit

Permalink
feat: Implement partial update with PATCH and add tests
Browse files Browse the repository at this point in the history
This commit allows partial updates of datapoints using the PATCH method. It also includes tests to ensure the correct functionality of this feature. A new class has been added to handle the partial update functionality.
  • Loading branch information
basarbyz committed Jun 27, 2024
1 parent 8eac4b7 commit 63861b0
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 23 deletions.
136 changes: 113 additions & 23 deletions backend/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import logging
import time


__version__ = "0.2.0"
app = FastAPI()
# enable CORS for the frontend
Expand All @@ -36,6 +35,7 @@
logging.basicConfig(level=settings.LOG_LEVEL.upper(),
format='%(asctime)s %(name)s %(levelname)s: %(message)s')


# Pydantic model
class Datapoint(BaseModel):
object_id: Optional[str] = Field(None, min_length=1, max_length=255)
Expand All @@ -47,6 +47,14 @@ class Datapoint(BaseModel):
description: Optional[str] = ""
matchDatapoint: Optional[bool] = False


class DatapointPartialUpdate(BaseModel):
entity_id: Optional[str] = Field(None, min_length=1, max_length=255)
entity_type: Optional[str] = Field(None, min_length=1, max_length=255)
attribute_name: Optional[str] = Field(None, min_length=1, max_length=255)
description: Optional[str] = ""


@app.on_event("startup")
async def startup():
"""
Expand Down Expand Up @@ -124,7 +132,7 @@ async def get_datapoints(conn: asyncpg.Connection = Depends(get_connection)):
If the datapoint is not found, an error will be raised.",
)
async def get_datapoint(
object_id: str, conn: asyncpg.Connection = Depends(get_connection)
object_id: str, conn: asyncpg.Connection = Depends(get_connection)
):
"""
Get a specific datapoint from the gateway. This is to allow the frontend to display a specific datapoint in the database.
Expand All @@ -145,6 +153,7 @@ async def get_datapoint(
raise HTTPException(status_code=404, detail="Device not found!")
return row


@app.post(
"/data",
response_model=Datapoint,
Expand All @@ -156,7 +165,7 @@ async def get_datapoint(
database that a new datapoint has been added as well as whether the topic needs to be subscribed to.",
)
async def add_datapoint(
datapoint: Datapoint, conn: asyncpg.Connection = Depends(get_connection)
datapoint: Datapoint, conn: asyncpg.Connection = Depends(get_connection)
):
"""
Add a new datapoint to the gateway. This is to allow to add new datapoints to the gateway via the frontend.
Expand All @@ -175,7 +184,7 @@ async def add_datapoint(
"""
datapoint.object_id = str(uuid4())
if datapoint.matchDatapoint and (
datapoint.entity_id is None or datapoint.attribute_name is None
datapoint.entity_id is None or datapoint.attribute_name is None
):
raise HTTPException(
status_code=400,
Expand Down Expand Up @@ -226,10 +235,9 @@ async def add_datapoint(
if not subscribed:
stream_name = "manage_topics"
await app.state.notifier.xadd(
stream_name,
{'subscribe': datapoint.topic},
)

stream_name,
{'subscribe': datapoint.topic},
)

return {**datapoint.dict(), "subscribe": subscribed is None}

Expand Down Expand Up @@ -333,14 +341,86 @@ async def update_datapoint(
raise HTTPException(status_code=500, detail="Internal Server Error")


@app.patch(
"/data/{object_id}",
response_model=Datapoint,
summary="Partially update a specific datapoint in the gateway",
description="Partially update a specific datapoint in the gateway. This allows the frontend to update specific fields of a datapoint.",
)
async def partial_update_datapoint(
object_id: str,
datapoint_update: DatapointPartialUpdate,
conn: asyncpg.Connection = Depends(get_connection),
):
existing_datapoint = await conn.fetchrow(
"""SELECT * FROM datapoints WHERE object_id=$1""", object_id
)
if existing_datapoint is None:
raise HTTPException(status_code=404, detail="Datapoint not found!")

update_data = datapoint_update.dict(exclude_unset=True)

if 'entity_id' in update_data and 'attribute_name' not in update_data and existing_datapoint['attribute_name'] is None:
raise HTTPException(
status_code=400,
detail="attribute_name must be set if entity_id is provided!",
)

if 'attribute_name' in update_data and 'entity_id' not in update_data and existing_datapoint['entity_id'] is None:
raise HTTPException(
status_code=400,
detail="entity_id must be set if attribute_name is provided!",
)

if not update_data:
raise HTTPException(
status_code=400,
detail="No valid fields provided for update.",
)

try:
async with conn.transaction():
# Dynamically build the SQL query to update only provided fields
set_clauses = ", ".join([f"{key} = ${i + 1}" for i, key in enumerate(update_data.keys())])
values = list(update_data.values()) + [object_id]
query = f"UPDATE datapoints SET {set_clauses} WHERE object_id = ${len(values)}"
await conn.execute(query, *values)

# Retrieve updated datapoint
updated_datapoint = await conn.fetchrow(
"""SELECT * FROM datapoints WHERE object_id=$1""", object_id
)

# Update the datapoint in Redis
await app.state.redis.hset(
updated_datapoint['topic'],
object_id,
json.dumps(
{
"object_id": object_id,
"jsonpath": updated_datapoint['jsonpath'],
"entity_id": updated_datapoint['entity_id'],
"entity_type": updated_datapoint['entity_type'],
"attribute_name": updated_datapoint['attribute_name'],
"description": updated_datapoint['description'],
}
),
)

return updated_datapoint

except Exception as e:
logging.error(str(e))
raise HTTPException(status_code=500, detail="Internal Server Error!")

@app.delete(
"/data/{object_id}",
status_code=204,
summary="Delete a specific datapoint from the gateway",
description="Delete a specific datapoint from the gateway. This is to allow the frontend to delete a datapoint from the gateway.",
)
async def delete_datapoint(
object_id: str, conn: asyncpg.Connection = Depends(get_connection)
object_id: str, conn: asyncpg.Connection = Depends(get_connection)
):
"""
Delete a specific datapoint from the gateway. This is to allow the frontend to delete a datapoint from the gateway and unsubscribe from the topic if it is the last subscriber.
Expand Down Expand Up @@ -376,9 +456,9 @@ async def delete_datapoint(
# await app.state.notifier.publish("unsubscribe", datapoint["topic"])
stream_name = "manage_topics"
await app.state.notifier.xadd(
stream_name,
{'unsubscribe': datapoint["topic"]},
)
stream_name,
{'unsubscribe': datapoint["topic"]},
)
return None
except Exception as e:
logging.error(str(e))
Expand Down Expand Up @@ -427,7 +507,7 @@ async def delete_all_datapoints(conn: asyncpg.Connection = Depends(get_connectio
description="Get the match status of a specific datapoint. This is to allow the frontend to check whether a datapoint is matched to an existing entity/attribute pair in the Context Broker.",
)
async def get_match_status(
object_id: str, conn: asyncpg.Connection = Depends(get_connection)
object_id: str, conn: asyncpg.Connection = Depends(get_connection)
):
"""
Get the match status of a specific datapoint. This is to allow the frontend to check whether a datapoint is matched to an existing entity/attribute pair in the Context Broker.
Expand Down Expand Up @@ -455,11 +535,12 @@ async def get_match_status(
)
return response.status == 200


@app.get("/system/status",
response_model=dict,
summary="Get the status of the system",
description="Get the status of the system. This is to allow the frontend to check whether the system is running properly.",
)
response_model=dict,
summary="Get the status of the system",
description="Get the status of the system. This is to allow the frontend to check whether the system is running properly.",
)
async def get_status():
checks = {
"orion": await check_orion(),
Expand All @@ -475,28 +556,32 @@ async def get_status():
}
return system_status


@app.get("/system/version",
response_model=dict,
summary="Get the version of the system and the dependencies",
description="Get the version of the system. This is to allow the frontend to check the version of the system and its dependencies."
)
)
async def get_version_info():
"""
Return version information for the application and its dependencies.
"""
dependencies = ["fastapi", "aiohttp", "asyncpg", "pydantic", "redis", "uvicorn"]

def get_dependency_version(package: str):
"""
Get the version of a package.
"""
return importlib.metadata.version(package)

version_results = [get_dependency_version(dep) for dep in dependencies]
version_info = {
"application_version": __version__,
"dependencies": dict(zip(dependencies, version_results))
}
return version_info


async def check_orion():
"""
Check whether the Orion Context Broker is running properly.
Expand All @@ -506,14 +591,16 @@ async def check_orion():
async with aiohttp.ClientSession() as session:
response = await session.get(f"{ORION_URL}/version")
status = response.status == 200
latency = (time.time() - start_time)*1000
latency = (time.time() - start_time) * 1000
return {"status": status, "latency": latency, "latency_unit": "ms",
"message": None if status else "Failed to connect"}
except Exception as e:
latency = time.time() - start_time
logging.error(f"Error checking Orion: {e}")
return {"status": False, "latency": latency,
"latency_unit": "ms", "message": str(e)}


async def check_postgres():
"""
Check whether the PostgreSQL database is running properly.
Expand All @@ -522,30 +609,33 @@ async def check_postgres():
try:
async with app.state.pool.acquire() as connection:
await connection.execute("SELECT 1")
latency = (time.time() - start_time)*1000
latency = (time.time() - start_time) * 1000
return {"status": True, "latency": latency,
"latency_unit": "ms", "message": None}
except Exception as e:
latency = (time.time() - start_time)*1000
latency = (time.time() - start_time) * 1000
logging.error(f"Error checking PostgreSQL: {e}")
return {"status": False, "latency": latency,
"latency_unit": "ms", "message": str(e)}


async def check_redis():
"""
Check whether the Redis cache is running properly.
"""
start_time = time.time()
try:
await app.state.redis.ping()
latency = (time.time() - start_time)*1000
latency = (time.time() - start_time) * 1000
return {"status": True, "latency": latency,
"latency_unit": "ms", "message": None}
except Exception as e:
latency = (time.time() - start_time)*1000
latency = (time.time() - start_time) * 1000
logging.error(f"Error checking Redis: {e}")
return {"status": False, "latency": latency,
"latency_unit": "ms", "message": str(e)}


if __name__ == "__main__":
uvicorn.run(app, host="0.0.0.0", port=8000, reload=True,
log_level=settings.LOG_LEVEL.lower())
Loading

0 comments on commit 63861b0

Please sign in to comment.