From 63861b023cbee497d14154358a4f403ef68160ae Mon Sep 17 00:00:00 2001 From: BasarBayaz Date: Thu, 27 Jun 2024 08:44:33 +0000 Subject: [PATCH] feat: Implement partial update with PATCH and add tests This commit allows partial updates of datapoints using the PATCH method. It also includes tests to ensure the correct functionality of this feature. A new class has been added to handle the partial update functionality. --- backend/api/main.py | 136 ++++++++++++++++++++++++++++++++++++-------- tests/test_crud.py | 113 ++++++++++++++++++++++++++++++++++++ 2 files changed, 226 insertions(+), 23 deletions(-) diff --git a/backend/api/main.py b/backend/api/main.py index 1146088..0908f19 100644 --- a/backend/api/main.py +++ b/backend/api/main.py @@ -13,7 +13,6 @@ import logging import time - __version__ = "0.2.0" app = FastAPI() # enable CORS for the frontend @@ -36,6 +35,7 @@ logging.basicConfig(level=settings.LOG_LEVEL.upper(), format='%(asctime)s %(name)s %(levelname)s: %(message)s') + # Pydantic model class Datapoint(BaseModel): object_id: Optional[str] = Field(None, min_length=1, max_length=255) @@ -47,6 +47,14 @@ class Datapoint(BaseModel): description: Optional[str] = "" matchDatapoint: Optional[bool] = False + +class DatapointPartialUpdate(BaseModel): + entity_id: Optional[str] = Field(None, min_length=1, max_length=255) + entity_type: Optional[str] = Field(None, min_length=1, max_length=255) + attribute_name: Optional[str] = Field(None, min_length=1, max_length=255) + description: Optional[str] = "" + + @app.on_event("startup") async def startup(): """ @@ -124,7 +132,7 @@ async def get_datapoints(conn: asyncpg.Connection = Depends(get_connection)): If the datapoint is not found, an error will be raised.", ) async def get_datapoint( - object_id: str, conn: asyncpg.Connection = Depends(get_connection) + object_id: str, conn: asyncpg.Connection = Depends(get_connection) ): """ Get a specific datapoint from the gateway. This is to allow the frontend to display a specific datapoint in the database. @@ -145,6 +153,7 @@ async def get_datapoint( raise HTTPException(status_code=404, detail="Device not found!") return row + @app.post( "/data", response_model=Datapoint, @@ -156,7 +165,7 @@ async def get_datapoint( database that a new datapoint has been added as well as whether the topic needs to be subscribed to.", ) async def add_datapoint( - datapoint: Datapoint, conn: asyncpg.Connection = Depends(get_connection) + datapoint: Datapoint, conn: asyncpg.Connection = Depends(get_connection) ): """ Add a new datapoint to the gateway. This is to allow to add new datapoints to the gateway via the frontend. @@ -175,7 +184,7 @@ async def add_datapoint( """ datapoint.object_id = str(uuid4()) if datapoint.matchDatapoint and ( - datapoint.entity_id is None or datapoint.attribute_name is None + datapoint.entity_id is None or datapoint.attribute_name is None ): raise HTTPException( status_code=400, @@ -226,10 +235,9 @@ async def add_datapoint( if not subscribed: stream_name = "manage_topics" await app.state.notifier.xadd( - stream_name, - {'subscribe': datapoint.topic}, - ) - + stream_name, + {'subscribe': datapoint.topic}, + ) return {**datapoint.dict(), "subscribe": subscribed is None} @@ -333,6 +341,78 @@ async def update_datapoint( raise HTTPException(status_code=500, detail="Internal Server Error") +@app.patch( + "/data/{object_id}", + response_model=Datapoint, + summary="Partially update a specific datapoint in the gateway", + description="Partially update a specific datapoint in the gateway. This allows the frontend to update specific fields of a datapoint.", +) +async def partial_update_datapoint( + object_id: str, + datapoint_update: DatapointPartialUpdate, + conn: asyncpg.Connection = Depends(get_connection), +): + existing_datapoint = await conn.fetchrow( + """SELECT * FROM datapoints WHERE object_id=$1""", object_id + ) + if existing_datapoint is None: + raise HTTPException(status_code=404, detail="Datapoint not found!") + + update_data = datapoint_update.dict(exclude_unset=True) + + if 'entity_id' in update_data and 'attribute_name' not in update_data and existing_datapoint['attribute_name'] is None: + raise HTTPException( + status_code=400, + detail="attribute_name must be set if entity_id is provided!", + ) + + if 'attribute_name' in update_data and 'entity_id' not in update_data and existing_datapoint['entity_id'] is None: + raise HTTPException( + status_code=400, + detail="entity_id must be set if attribute_name is provided!", + ) + + if not update_data: + raise HTTPException( + status_code=400, + detail="No valid fields provided for update.", + ) + + try: + async with conn.transaction(): + # Dynamically build the SQL query to update only provided fields + set_clauses = ", ".join([f"{key} = ${i + 1}" for i, key in enumerate(update_data.keys())]) + values = list(update_data.values()) + [object_id] + query = f"UPDATE datapoints SET {set_clauses} WHERE object_id = ${len(values)}" + await conn.execute(query, *values) + + # Retrieve updated datapoint + updated_datapoint = await conn.fetchrow( + """SELECT * FROM datapoints WHERE object_id=$1""", object_id + ) + + # Update the datapoint in Redis + await app.state.redis.hset( + updated_datapoint['topic'], + object_id, + json.dumps( + { + "object_id": object_id, + "jsonpath": updated_datapoint['jsonpath'], + "entity_id": updated_datapoint['entity_id'], + "entity_type": updated_datapoint['entity_type'], + "attribute_name": updated_datapoint['attribute_name'], + "description": updated_datapoint['description'], + } + ), + ) + + return updated_datapoint + + except Exception as e: + logging.error(str(e)) + raise HTTPException(status_code=500, detail="Internal Server Error!") + @app.delete( "/data/{object_id}", status_code=204, @@ -340,7 +420,7 @@ async def update_datapoint( description="Delete a specific datapoint from the gateway. This is to allow the frontend to delete a datapoint from the gateway.", ) async def delete_datapoint( - object_id: str, conn: asyncpg.Connection = Depends(get_connection) + object_id: str, conn: asyncpg.Connection = Depends(get_connection) ): """ Delete a specific datapoint from the gateway. This is to allow the frontend to delete a datapoint from the gateway and unsubscribe from the topic if it is the last subscriber. @@ -376,9 +456,9 @@ async def delete_datapoint( # await app.state.notifier.publish("unsubscribe", datapoint["topic"]) stream_name = "manage_topics" await app.state.notifier.xadd( - stream_name, - {'unsubscribe': datapoint["topic"]}, - ) + stream_name, + {'unsubscribe': datapoint["topic"]}, + ) return None except Exception as e: logging.error(str(e)) @@ -427,7 +507,7 @@ async def delete_all_datapoints(conn: asyncpg.Connection = Depends(get_connectio description="Get the match status of a specific datapoint. This is to allow the frontend to check whether a datapoint is matched to an existing entity/attribute pair in the Context Broker.", ) async def get_match_status( - object_id: str, conn: asyncpg.Connection = Depends(get_connection) + object_id: str, conn: asyncpg.Connection = Depends(get_connection) ): """ Get the match status of a specific datapoint. This is to allow the frontend to check whether a datapoint is matched to an existing entity/attribute pair in the Context Broker. @@ -455,11 +535,12 @@ async def get_match_status( ) return response.status == 200 + @app.get("/system/status", - response_model=dict, - summary="Get the status of the system", - description="Get the status of the system. This is to allow the frontend to check whether the system is running properly.", -) + response_model=dict, + summary="Get the status of the system", + description="Get the status of the system. This is to allow the frontend to check whether the system is running properly.", + ) async def get_status(): checks = { "orion": await check_orion(), @@ -475,21 +556,24 @@ async def get_status(): } return system_status + @app.get("/system/version", response_model=dict, summary="Get the version of the system and the dependencies", description="Get the version of the system. This is to allow the frontend to check the version of the system and its dependencies." -) + ) async def get_version_info(): """ Return version information for the application and its dependencies. """ dependencies = ["fastapi", "aiohttp", "asyncpg", "pydantic", "redis", "uvicorn"] + def get_dependency_version(package: str): """ Get the version of a package. """ return importlib.metadata.version(package) + version_results = [get_dependency_version(dep) for dep in dependencies] version_info = { "application_version": __version__, @@ -497,6 +581,7 @@ def get_dependency_version(package: str): } return version_info + async def check_orion(): """ Check whether the Orion Context Broker is running properly. @@ -506,7 +591,7 @@ async def check_orion(): async with aiohttp.ClientSession() as session: response = await session.get(f"{ORION_URL}/version") status = response.status == 200 - latency = (time.time() - start_time)*1000 + latency = (time.time() - start_time) * 1000 return {"status": status, "latency": latency, "latency_unit": "ms", "message": None if status else "Failed to connect"} except Exception as e: @@ -514,6 +599,8 @@ async def check_orion(): logging.error(f"Error checking Orion: {e}") return {"status": False, "latency": latency, "latency_unit": "ms", "message": str(e)} + + async def check_postgres(): """ Check whether the PostgreSQL database is running properly. @@ -522,14 +609,16 @@ async def check_postgres(): try: async with app.state.pool.acquire() as connection: await connection.execute("SELECT 1") - latency = (time.time() - start_time)*1000 + latency = (time.time() - start_time) * 1000 return {"status": True, "latency": latency, "latency_unit": "ms", "message": None} except Exception as e: - latency = (time.time() - start_time)*1000 + latency = (time.time() - start_time) * 1000 logging.error(f"Error checking PostgreSQL: {e}") return {"status": False, "latency": latency, "latency_unit": "ms", "message": str(e)} + + async def check_redis(): """ Check whether the Redis cache is running properly. @@ -537,15 +626,16 @@ async def check_redis(): start_time = time.time() try: await app.state.redis.ping() - latency = (time.time() - start_time)*1000 + latency = (time.time() - start_time) * 1000 return {"status": True, "latency": latency, "latency_unit": "ms", "message": None} except Exception as e: - latency = (time.time() - start_time)*1000 + latency = (time.time() - start_time) * 1000 logging.error(f"Error checking Redis: {e}") return {"status": False, "latency": latency, "latency_unit": "ms", "message": str(e)} + if __name__ == "__main__": uvicorn.run(app, host="0.0.0.0", port=8000, reload=True, log_level=settings.LOG_LEVEL.lower()) diff --git a/tests/test_crud.py b/tests/test_crud.py index 833f7e4..0a77758 100644 --- a/tests/test_crud.py +++ b/tests/test_crud.py @@ -183,6 +183,119 @@ def test_update_put(self): json.loads(response.text).pop("matchDatapoint") ) + def test_partial_update_patch(self): + headers = { + 'Accept': 'application/json' + } + object_id = self.unmatched_object_id + + # Perform initial GET to fetch the existing datapoint + response = requests.request("GET", settings.GATEWAY_URL + "/data/" + object_id) + datapoint_basis = Datapoint( + **json.loads(response.text) + ) + + # Test case 1: Successful partial update of entity_id and attribute_name + update_data = { + "entity_id": "NewEntityID", + "attribute_name": "NewAttributeName" + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(update_data)) + self.assertTrue(response.ok) + + # Verify the update + response = requests.request("GET", settings.GATEWAY_URL + "/data/" + object_id) + updated_datapoint = Datapoint( + **json.loads(response.text) + ) + self.assertEqual(updated_datapoint.entity_id, update_data["entity_id"]) + self.assertEqual(updated_datapoint.attribute_name, update_data["attribute_name"]) + + # Test case 2: Attempt to update forbidden fields jsonpath and topic + forbidden_update_data = { + "jsonpath": "$..new_jsonpath", + "topic": "new/topic" + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(forbidden_update_data)) + self.assertFalse(response.ok) + self.assertEqual(response.status_code, 400) + + # Verify that jsonpath and topic are unchanged + response = requests.request("GET", settings.GATEWAY_URL + "/data/" + object_id) + updated_datapoint = Datapoint( + **json.loads(response.text) + ) + self.assertEqual(updated_datapoint.jsonpath, datapoint_basis.jsonpath) + self.assertEqual(updated_datapoint.topic, datapoint_basis.topic) + + # Test case 3: Provide only entity_id without attribute_name + invalid_update_data1 = { + "entity_id": "AnotherNewEntityID" + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(invalid_update_data1)) + # Check if attribute_name is set + self.assertTrue("attribute_name" in response.json()) + self.assertEqual(response.status_code, 200) + + # Test case 4: Provide only attribute_name without entity_id + invalid_update_data2 = { + "attribute_name": "AnotherNewAttributeName" + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(invalid_update_data2)) + # Check if attribute_name is set + self.assertTrue("entity_id" in response.json()) + self.assertEqual(response.status_code, 200) + + # Test case 5: Attempt to update matchDatapoint field + invalid_update_data3 = { + "matchDatapoint": True + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(invalid_update_data3)) + self.assertFalse(response.ok) + self.assertEqual(response.status_code, 400) + + # Test case 6: Update entity_type field + update_entity_type_data = { + "entity_type": "NewEntityType" + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(update_entity_type_data)) + self.assertTrue(response.ok) + + # Verify the entity_type update + response = requests.request("GET", settings.GATEWAY_URL + "/data/" + object_id) + updated_datapoint = Datapoint( + **json.loads(response.text) + ) + self.assertEqual(updated_datapoint.entity_type, update_entity_type_data["entity_type"]) + + # Test case 7: Update description field + update_description_data = { + "description": "Updated description" + } + response = requests.request("PATCH", settings.GATEWAY_URL + "/data/" + object_id, + headers=headers, + data=json.dumps(update_description_data)) + self.assertTrue(response.ok) + + # Verify the description update + response = requests.request("GET", settings.GATEWAY_URL + "/data/" + object_id) + updated_datapoint = Datapoint( + **json.loads(response.text) + ) + self.assertEqual(updated_datapoint.description, update_description_data["description"]) + def test_delete(self): object_id = self.unmatched_object_id response = requests.request("DELETE", settings.GATEWAY_URL + "/data/" + object_id)