diff --git a/cartography/intel/cve/__init__.py b/cartography/intel/cve/__init__.py
index 7ef1ab626..ca40ac49b 100644
--- a/cartography/intel/cve/__init__.py
+++ b/cartography/intel/cve/__init__.py
@@ -2,6 +2,9 @@
 from datetime import datetime
 
 import neo4j
+from requests import Session
+from requests.adapters import HTTPAdapter
+from urllib3 import Retry
 
 from cartography.config import Config
 from cartography.intel.cve import feed
@@ -13,28 +16,34 @@
 stat_handler = get_stats_client(__name__)
 
 
-@timeit
-def start_cve_ingestion(
-    neo4j_session: neo4j.Session, config: Config,
-) -> None:
-    """
-    Perform ingestion of CVE data from NIST APIs.
-    :param neo4j_session: Neo4J session for database interface
-    :param config: A cartography.config object
-    :return: None
-    """
-    if not config.cve_enabled:
-        return
-    cve_api_key: str | None = config.cve_api_key if config.cve_api_key else None
+def _retryable_session() -> Session:
+    session = Session()
+    retry_policy = Retry(
+        total=8,
+        connect=1,
+        backoff_factor=1,
+        status_forcelist=[429, 500, 502, 503, 504],
+        allowed_methods=["GET"],
+    )
+    session.mount("https://", HTTPAdapter(max_retries=retry_policy))
+    logger.info(f"Configured session with retry policy: {retry_policy}")
+    return session
 
-    # sync CVE year archives, if not yet synced
+
+def _sync_year_archives(
+    http_session: Session,
+    neo4j_session: neo4j.Session,
+    config: Config,
+    cve_api_key: str | None,
+) -> None:
     existing_years = feed.get_cve_sync_metadata(neo4j_session)
     current_year = datetime.now().year
+    logger.info(f"Syncing CVE data for year archives. Existing years: {existing_years}. Current year: {current_year}")
     for year in range(2002, current_year + 1):
         if year in existing_years:
             continue
         logger.info(f"Syncing CVE data for year {year}")
-        cves = feed.get_published_cves_per_year(config.nist_cve_url, str(year), cve_api_key)
+        cves = feed.get_published_cves_per_year(http_session, config.nist_cve_url, str(year), cve_api_key)
         feed_metadata = feed.transform_cve_feed(cves)
         feed.load_cve_feed(neo4j_session, [feed_metadata], config.update_tag)
         published_cves = feed.transform_cves(cves)
@@ -48,10 +57,16 @@ def start_cve_ingestion(
             stat_handler=stat_handler,
         )
 
-    # sync modified data
+
+def _sync_modified_data(
+    http_session: Session,
+    neo4j_session: neo4j.Session,
+    config: Config,
+    cve_api_key: str | None,
+) -> None:
     logger.info("Syncing CVE data for modified data")
     last_modified_date = feed.get_last_modified_cve_date(neo4j_session)
-    cves = feed.get_modified_cves(config.nist_cve_url, last_modified_date, cve_api_key)
+    cves = feed.get_modified_cves(http_session, config.nist_cve_url, last_modified_date, cve_api_key)
     feed_metadata = feed.transform_cve_feed(cves)
     feed.load_cve_feed(neo4j_session, [feed_metadata], config.update_tag)
     modified_cves = feed.transform_cves(cves)
@@ -65,4 +80,21 @@ def start_cve_ingestion(
         stat_handler=stat_handler,
     )
 
-    # CVEs are never deleted, so we don't need to run a cleanup job
+
+@timeit
+def start_cve_ingestion(
+    neo4j_session: neo4j.Session, config: Config,
+) -> None:
+    """
+    Perform ingestion of CVE data from NIST APIs.
+    :param neo4j_session: Neo4J session for database interface
+    :param config: A cartography.config object
+    :return: None
+    """
+    if not config.cve_enabled:
+        return
+    cve_api_key: str | None = config.cve_api_key if config.cve_api_key else None
+    with _retryable_session() as http_session:
+        _sync_year_archives(http_session, neo4j_session=neo4j_session, config=config, cve_api_key=cve_api_key)
+        _sync_modified_data(http_session, neo4j_session=neo4j_session, config=config, cve_api_key=cve_api_key)
+    # CVEs are never deleted, so we don't need to run a cleanup job
diff --git a/cartography/intel/cve/feed.py b/cartography/intel/cve/feed.py
index 316bc7037..79a123dbd 100644
--- a/cartography/intel/cve/feed.py
+++ b/cartography/intel/cve/feed.py
@@ -11,7 +11,7 @@
 from typing import Optional
 
 import neo4j
-import requests
+from requests import Session
 
 from cartography.client.core.tx import load
 from cartography.client.core.tx import read_list_of_values_tx
@@ -22,7 +22,6 @@
 
 logger = logging.getLogger(__name__)
 
-MAX_RETRIES = 8
 # Connect and read timeouts of 120 seconds each; see https://requests.readthedocs.io/en/master/user/advanced/#timeouts
 CONNECT_AND_READ_TIMEOUT = (30, 120)
 CVE_FEED_ID = "NIST_NVD"
@@ -68,53 +67,36 @@ def _map_cve_dict(cve_dict: Dict[Any, Any], data: Dict[Any, Any]) -> None:
     cve_dict["startIndex"] = data["startIndex"]
 
 
-def _call_cves_api(url: str, api_key: str | None, params: Dict[str, Any]) -> Dict[Any, Any]:
-    totalResults = 0
-    sleep_time = DEFAULT_SLEEP_TIME
-    retries = 0
+def _call_cves_api(http_session: Session, url: str, api_key: str | None, params: Dict[str, Any]) -> Dict[Any, Any]:
+    total_results = 0
     params["startIndex"] = 0
     params["resultsPerPage"] = RESULTS_PER_PAGE
-    headers = {}
-    headers["Content-Type"] = "application/json"
+    headers = {"Content-Type": "application/json"}
     if api_key:
+        sleep_between_requests = DEFAULT_SLEEP_TIME
         headers["apiKey"] = api_key
     else:
-        sleep_time = DELAYED_SLEEP_TIME  # Sleep for 6 seconds between each request to avoid rate limiting
+        sleep_between_requests = DELAYED_SLEEP_TIME
         logger.warning(
-            f"No NIST NVD API key provided. Increasing sleep time to {sleep_time}.",
+            f"No NIST NVD API key provided. Increasing sleep time to {sleep_between_requests}.",
        )
     results: Dict[Any, Any] = dict()
-    with requests.Session() as session:
-        while params["resultsPerPage"] > 0 or params["startIndex"] < totalResults:
-            logger.info(f"Calling NIST NVD API at {url} with params {params}")
-            try:
-                res = session.get(
-                    url, params=params, headers=headers, timeout=CONNECT_AND_READ_TIMEOUT,
-                )
-                res.raise_for_status()
-                data = res.json()
-            except requests.exceptions.HTTPError:
-                logger.error(
-                    f"Failed to get CVE data from NIST NVD API {res.status_code} : {res.text}",
-                )
-                retries += 1
-                if retries >= MAX_RETRIES:
-                    raise
-                # Exponential backoff
-                sleep_time *= 2
-                time.sleep(sleep_time)
-                continue
-            _map_cve_dict(results, data)
-            totalResults = data["totalResults"]
-            params["resultsPerPage"] = data["resultsPerPage"]
-            params["startIndex"] += data["resultsPerPage"]
-            retries = 0
-            time.sleep(sleep_time)
+    while params["resultsPerPage"] > 0 or params["startIndex"] < total_results:
+        logger.info(f"Calling NIST NVD API at {url} with params {params}")
+        res = http_session.get(url, params=params, headers=headers, timeout=CONNECT_AND_READ_TIMEOUT)
+        res.raise_for_status()
+        data = res.json()
+        _map_cve_dict(results, data)
+        total_results = data["totalResults"]
+        params["resultsPerPage"] = data["resultsPerPage"]
+        params["startIndex"] += data["resultsPerPage"]
+        time.sleep(sleep_between_requests)
     return results
 
 
 def get_cves_in_batches(
+    http_session: Session,
     nist_cve_url: str,
     start_date: datetime,
     end_date: datetime,
@@ -147,7 +129,7 @@
         logger.info(
             f"Querying CVE data between {current_start_date} and {current_end_date}",
         )
-        batch_cves = _call_cves_api(nist_cve_url, api_key, params)
+        batch_cves = _call_cves_api(http_session, nist_cve_url, api_key, params)
         _map_cve_dict(cves, batch_cves)
         current_start_date = current_end_date
         new_end_date = current_start_date + batch_size
@@ -158,9 +140,8 @@
 
 
 def get_modified_cves(
-    nist_cve_url: str, last_modified_date: str, api_key: str | None,
+    http_session: Session, nist_cve_url: str, last_modified_date: str, api_key: str | None,
 ) -> Dict[Any, Any]:
-    cves = dict()
     end_date = datetime.now(tz=timezone.utc)
     start_date = datetime.strptime(last_modified_date, "%Y-%m-%dT%H:%M:%S").replace(
         tzinfo=timezone.utc,
@@ -170,15 +151,14 @@
         "end": "lastModEndDate",
     }
     cves = get_cves_in_batches(
-        nist_cve_url, start_date, end_date, date_param_names, api_key,
+        http_session, nist_cve_url, start_date, end_date, date_param_names, api_key,
     )
     return cves
 
 
 def get_published_cves_per_year(
-    nist_cve_url: str, year: str, api_key: str | None,
+    http_session: Session, nist_cve_url: str, year: str, api_key: str | None,
 ) -> Dict[Any, Any]:
-    cves = {}
     start_of_year = datetime.strptime(f"{year}-01-01", "%Y-%m-%d")
     next_year = int(year) + 1
     end_of_next_year = datetime.strptime(f"{next_year}-01-01", "%Y-%m-%d")
@@ -187,7 +167,7 @@
         "end": "pubEndDate",
     }
     cves = get_cves_in_batches(
-        nist_cve_url, start_of_year, end_of_next_year, date_param_names, api_key,
+        http_session, nist_cve_url, start_of_year, end_of_next_year, date_param_names, api_key,
     )
     return cves
 
diff --git a/tests/unit/cartography/intel/cve/test_feed.py b/tests/unit/cartography/intel/cve/test_feed.py
index 8af9c448b..758769f6e 100644
--- a/tests/unit/cartography/intel/cve/test_feed.py
+++ b/tests/unit/cartography/intel/cve/test_feed.py
@@ -1,11 +1,12 @@
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
+from unittest.mock import MagicMock
 from unittest.mock import Mock
 from unittest.mock import patch
 
 import pytest
-import requests
+from requests import Session
 
 from cartography.intel.cve.feed import _call_cves_api
 from cartography.intel.cve.feed import _map_cve_dict
@@ -20,14 +21,16 @@
 
 
 @pytest.fixture
-def mock_get():
-    with patch("cartography.intel.cve.feed.requests.Session") as mock_session:
-        session_mock = mock_session.return_value.__enter__.return_value
-        yield session_mock.get
+def mock_session():
+    return MagicMock(spec=Session)
 
 
-def test_call_cves_api(mock_get):
-    # Arrange
+@pytest.fixture(autouse=True)
+def mock_time_sleep(mocker):
+    return mocker.patch("time.sleep")
+
+
+def _mock_good_responses() -> list[Mock]:
     mock_response_1 = Mock()
     mock_response_1.status_code = 200
     mock_response_1.json.return_value = {
@@ -93,8 +96,12 @@
         "timestamp": "2024-01-10T19:30:07.520",
         "vulnerabilities": [],
     }
+    return [mock_response_1, mock_response_2, mock_response_3]
 
-    mock_get.side_effect = [mock_response_1, mock_response_2, mock_response_3]
+
+def test_call_cves_api(mock_session):
+    # Arrange
+    mock_session.get.side_effect = _mock_good_responses()
     params = {"start": "2024-01-10T00:00:00Z", "end": "2024-01-10T23:59:59Z"}
     expected_result = {
         "resultsPerPage": 0,
@@ -138,35 +145,15 @@
     }
 
     # Act
-    result = _call_cves_api(NIST_CVE_URL, API_KEY, params)
+    result = _call_cves_api(mock_session, NIST_CVE_URL, API_KEY, params)
 
     # Assert
-    assert mock_get.call_count == 3
+    assert mock_session.get.call_count == 3
     assert result == expected_result
 
 
-@patch("cartography.intel.cve.feed.DEFAULT_SLEEP_TIME", 0)
-def test_call_cves_api_with_error(mock_get: Mock):
-    # Arrange
-    mock_response = Mock()
-    mock_response.status_code = 404
-    mock_response.message = "Data error"
-    mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
-        response=mock_response,
-    )
-    mock_get.return_value = mock_response
-    params = {"start": "2024-01-10T00:00:00Z", "end": "2024-01-10T23:59:59Z"}
-
-    # Act
-    try:
-        _call_cves_api(NIST_CVE_URL, API_KEY, params)
-    except requests.exceptions.HTTPError as err:
-        assert err.response == mock_response
-        assert mock_get.call_count == 8
-
-
 @patch("cartography.intel.cve.feed._call_cves_api")
-def test_get_cves_in_batches(mock_call_cves_api: Mock):
+def test_get_cves_in_batches(mock_call_cves_api: Mock, mock_session: Session):
     """
     Ensure that we get the correct number of CVEs in batches of 120 days
     """
@@ -185,7 +172,7 @@ def test_get_cves_in_batches(mock_call_cves_api: Mock):
     _map_cve_dict(excepted_cves, GET_CVE_API_DATA_BATCH_2)
     # Act
     cves = get_cves_in_batches(
-        NIST_CVE_URL, start_date, end_date, date_param_names, API_KEY,
+        mock_session, NIST_CVE_URL, start_date, end_date, date_param_names, API_KEY,
     )
     # Assert
     assert mock_call_cves_api.call_count == 2
@@ -193,7 +180,7 @@
 
 
 @patch("cartography.intel.cve.feed._call_cves_api")
-def test_get_modified_cves(mock_call_cves_api: Mock):
+def test_get_modified_cves(mock_call_cves_api: Mock, mock_session: Session):
     # Arrange
     mock_call_cves_api.side_effect = [GET_CVE_API_DATA]
     last_modified_date = datetime.now(tz=timezone.utc) + timedelta(days=-1)
@@ -204,14 +191,14 @@
         "lastModEndDate": current_date_iso8601,
     }
     # Act
-    cves = get_modified_cves(NIST_CVE_URL, last_modified_date_iso8601, API_KEY)
+    cves = get_modified_cves(mock_session, NIST_CVE_URL, last_modified_date_iso8601, API_KEY)
     # Assert
-    mock_call_cves_api.assert_called_once_with(NIST_CVE_URL, API_KEY, expected_params)
+    mock_call_cves_api.assert_called_once_with(mock_session, NIST_CVE_URL, API_KEY, expected_params)
     assert cves == GET_CVE_API_DATA
 
 
 @patch("cartography.intel.cve.feed._call_cves_api")
-def test_get_published_cves_per_year(mock_call_cves_api: Mock):
+def test_get_published_cves_per_year(mock_call_cves_api: Mock, mock_session: Session):
     # Arrange
     no_cves = {
         "resultsPerPage": 0,
@@ -226,7 +213,7 @@
     _map_cve_dict(expected_cves, no_cves)
     mock_call_cves_api.side_effect = [GET_CVE_API_DATA, no_cves, no_cves, no_cves]
     # Act
-    cves = get_published_cves_per_year(NIST_CVE_URL, "2024", API_KEY)
+    cves = get_published_cves_per_year(mock_session, NIST_CVE_URL, "2024", API_KEY)
     # Assert
-    mock_call_cves_api.call_count == 4
+    assert mock_call_cves_api.call_count == 4
     assert cves == expected_cves