diff --git a/tests/nvd/cpe_match/test_api.py b/tests/nvd/cpe_match/test_api.py index c5d9a39b7..6decde737 100644 --- a/tests/nvd/cpe_match/test_api.py +++ b/tests/nvd/cpe_match/test_api.py @@ -399,6 +399,63 @@ async def test_cpe_matches_request_results(self): with self.assertRaises(Exception): cpe_match = await anext(it) + async def test_cpe_match_caching(self): + match_criteria_id = uuid4() + cpe_name_id = uuid4() + + responses = create_cpe_match_responses( + match_criteria_id=match_criteria_id, + cpe_name_id=cpe_name_id, + results_per_response=3, + ) + self.http_client.get.side_effect = responses + response_matches = [ + [ + match_string["match_string"]["matches"] + for match_string in response.json.return_value["match_strings"] + ] + for response in responses + ] + + # Make matches of first match_string identical in each response + response_matches[1][0][0]["cpe_name"] = response_matches[0][0][0][ + "cpe_name" + ] + response_matches[1][0][0]["cpe_name_id"] = response_matches[0][0][0][ + "cpe_name_id" + ] + # Make matches of second match_string only have the same cpe_name_id + response_matches[1][1][0]["cpe_name_id"] = response_matches[0][1][0][ + "cpe_name_id" + ] + # Leave matches of third match_string different from each other + + it = aiter(self.api.cpe_matches(request_results=10)) + received = [item async for item in it] + + # First matches in each response of three items must be identical objects + self.assertIs(received[0].matches[0], received[3].matches[0]) + + # Second matches in each response of three items must only have same cpe_name_id + self.assertIsNot(received[1].matches[0], received[4].matches[0]) + self.assertEqual( + received[1].matches[0].cpe_name_id, + received[4].matches[0].cpe_name_id, + ) + self.assertNotEqual( + received[1].matches[0].cpe_name, received[4].matches[0].cpe_name + ) + + # Third matches in each response of three items must be different + self.assertIsNot(received[2].matches[0], received[5].matches[0]) + self.assertNotEqual( + received[2].matches[0].cpe_name_id, + received[5].matches[0].cpe_name_id, + ) + self.assertNotEqual( + received[2].matches[0].cpe_name, received[5].matches[0].cpe_name + ) + async def test_context_manager(self): async with self.api: pass