Skip to content

Commit

Permalink
cve_feed: Correct the retry and sleep time (#1410)
Browse files Browse the repository at this point in the history
### Summary

The [current
code](https://github.com/cartography-cncf/cartography/blob/4d53bce6d9f3f6703b709b70071cbcc36820a5bb/cartography/intel/cve/feed.py#L73-L113)
uses the variable `sleep_time` both for the retry backoff and for the sleep
between requests. Furthermore, `sleep_time` was not reset for the next
request, meaning that if the first request increases it to 16 seconds,
the second request starts from that value and pushes it to an ever bigger
number (the `retries` count is reset, though, which makes it worse).

Two changes here:
- The sleep between requests is set to be a dedicated variable
`sleep_between_requests`.
- I decided to rewrite the request retry logic more properly using
`HTTPAdapter`'s retry policy (a urllib3 `Retry` object).

I tried pretty hard to keep `test_call_cves_api_with_error`, which tests
the retry logic, but unfortunately that is not feasible. It also doesn't
make a lot of sense to test something that is managed by the Session
object itself, so I had to drop it.

### Testing

I think the existing integration tests can confirm the code is still
working. The retry change itself will need some manual review to see if
it makes sense.

---------

Signed-off-by: Khanh Le Do <kledo@lyft.com>
kledo-lyft authored Dec 18, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 0d02954 commit eddab71
Showing 3 changed files with 98 additions and 99 deletions.
68 changes: 50 additions & 18 deletions cartography/intel/cve/__init__.py
Original file line number Diff line number Diff line change
@@ -2,6 +2,9 @@
from datetime import datetime

import neo4j
from requests import Session
from requests.adapters import HTTPAdapter
from urllib3 import Retry

from cartography.config import Config
from cartography.intel.cve import feed
@@ -13,28 +16,34 @@
stat_handler = get_stats_client(__name__)


@timeit
def start_cve_ingestion(
neo4j_session: neo4j.Session, config: Config,
) -> None:
"""
Perform ingestion of CVE data from NIST APIs.
:param neo4j_session: Neo4J session for database interface
:param config: A cartography.config object
:return: None
"""
if not config.cve_enabled:
return
cve_api_key: str | None = config.cve_api_key if config.cve_api_key else None
def _retryable_session() -> Session:
    """
    Build a requests Session with a retrying HTTPAdapter mounted on "https://".
    :return: The configured Session
    """
    # Retry budget: 8 attempts in total, at most 1 of them for connection
    # errors, backoff factor 1, GET requests, for the listed HTTP statuses.
    retry_policy = Retry(
        total=8,
        connect=1,
        backoff_factor=1,
        status_forcelist=[429, 500, 502, 503, 504],
        allowed_methods=["GET"],
    )
    retrying_adapter = HTTPAdapter(max_retries=retry_policy)
    http_session = Session()
    # Only the "https://" prefix gets the retrying adapter.
    http_session.mount("https://", retrying_adapter)
    logger.info(f"Configured session with retry policy: {retry_policy}")
    return http_session

# sync CVE year archives, if not yet synced

def _sync_year_archives(
http_session: Session,
neo4j_session: neo4j.Session,
config: Config,
cve_api_key: str | None,
) -> None:
existing_years = feed.get_cve_sync_metadata(neo4j_session)
current_year = datetime.now().year
logger.info(f"Syncing CVE data for year archives. Existing years: {existing_years}. Current year: {current_year}")
for year in range(2002, current_year + 1):
if year in existing_years:
continue
logger.info(f"Syncing CVE data for year {year}")
cves = feed.get_published_cves_per_year(config.nist_cve_url, str(year), cve_api_key)
cves = feed.get_published_cves_per_year(http_session, config.nist_cve_url, str(year), cve_api_key)
feed_metadata = feed.transform_cve_feed(cves)
feed.load_cve_feed(neo4j_session, [feed_metadata], config.update_tag)
published_cves = feed.transform_cves(cves)
@@ -48,10 +57,16 @@ def start_cve_ingestion(
stat_handler=stat_handler,
)

# sync modified data

def _sync_modified_data(
http_session: Session,
neo4j_session: neo4j.Session,
config: Config,
cve_api_key: str | None,
) -> None:
logger.info("Syncing CVE data for modified data")
last_modified_date = feed.get_last_modified_cve_date(neo4j_session)
cves = feed.get_modified_cves(config.nist_cve_url, last_modified_date, cve_api_key)
cves = feed.get_modified_cves(http_session, config.nist_cve_url, last_modified_date, cve_api_key)
feed_metadata = feed.transform_cve_feed(cves)
feed.load_cve_feed(neo4j_session, [feed_metadata], config.update_tag)
modified_cves = feed.transform_cves(cves)
@@ -65,4 +80,21 @@ def start_cve_ingestion(
stat_handler=stat_handler,
)

# CVEs are never deleted, so we don't need to run a cleanup job

@timeit
def start_cve_ingestion(
    neo4j_session: neo4j.Session, config: Config,
) -> None:
    """
    Ingest CVE data from NIST APIs: year archives first, then modified data.
    Returns immediately when config.cve_enabled is falsy.
    :param neo4j_session: Neo4J session for database interface
    :param config: A cartography.config object
    :return: None
    """
    if not config.cve_enabled:
        return
    # A missing or empty API key is normalised to None.
    cve_api_key: str | None = config.cve_api_key or None
    sync_steps = (_sync_year_archives, _sync_modified_data)
    # Both steps share one HTTP session, which is closed when the block exits.
    with _retryable_session() as http_session:
        for sync_step in sync_steps:
            sync_step(http_session, neo4j_session=neo4j_session, config=config, cve_api_key=cve_api_key)
    # CVEs are never deleted, so we don't need to run a cleanup job
66 changes: 23 additions & 43 deletions cartography/intel/cve/feed.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from typing import Optional

import neo4j
import requests
from requests import Session

from cartography.client.core.tx import load
from cartography.client.core.tx import read_list_of_values_tx
@@ -22,7 +22,6 @@

logger = logging.getLogger(__name__)

MAX_RETRIES = 8
# Connect timeout of 30 seconds and read timeout of 120 seconds; see https://requests.readthedocs.io/en/master/user/advanced/#timeouts
CONNECT_AND_READ_TIMEOUT = (30, 120)
CVE_FEED_ID = "NIST_NVD"
@@ -68,53 +67,36 @@ def _map_cve_dict(cve_dict: Dict[Any, Any], data: Dict[Any, Any]) -> None:
cve_dict["startIndex"] = data["startIndex"]


def _call_cves_api(url: str, api_key: str | None, params: Dict[str, Any]) -> Dict[Any, Any]:
totalResults = 0
sleep_time = DEFAULT_SLEEP_TIME
retries = 0
def _call_cves_api(http_session: Session, url: str, api_key: str | None, params: Dict[str, Any]) -> Dict[Any, Any]:
total_results = 0
params["startIndex"] = 0
params["resultsPerPage"] = RESULTS_PER_PAGE
headers = {}
headers["Content-Type"] = "application/json"
headers = {"Content-Type": "application/json"}
if api_key:
sleep_between_requests = DEFAULT_SLEEP_TIME
headers["apiKey"] = api_key
else:
sleep_time = DELAYED_SLEEP_TIME # Sleep for 6 seconds between each request to avoid rate limiting
sleep_between_requests = DELAYED_SLEEP_TIME
logger.warning(
f"No NIST NVD API key provided. Increasing sleep time to {sleep_time}.",
f"No NIST NVD API key provided. Increasing sleep time to {sleep_between_requests}.",
)
results: Dict[Any, Any] = dict()

with requests.Session() as session:
while params["resultsPerPage"] > 0 or params["startIndex"] < totalResults:
logger.info(f"Calling NIST NVD API at {url} with params {params}")
try:
res = session.get(
url, params=params, headers=headers, timeout=CONNECT_AND_READ_TIMEOUT,
)
res.raise_for_status()
data = res.json()
except requests.exceptions.HTTPError:
logger.error(
f"Failed to get CVE data from NIST NVD API {res.status_code} : {res.text}",
)
retries += 1
if retries >= MAX_RETRIES:
raise
# Exponential backoff
sleep_time *= 2
time.sleep(sleep_time)
continue
_map_cve_dict(results, data)
totalResults = data["totalResults"]
params["resultsPerPage"] = data["resultsPerPage"]
params["startIndex"] += data["resultsPerPage"]
retries = 0
time.sleep(sleep_time)
while params["resultsPerPage"] > 0 or params["startIndex"] < total_results:
logger.info(f"Calling NIST NVD API at {url} with params {params}")
res = http_session.get(url, params=params, headers=headers, timeout=CONNECT_AND_READ_TIMEOUT)
res.raise_for_status()
data = res.json()
_map_cve_dict(results, data)
total_results = data["totalResults"]
params["resultsPerPage"] = data["resultsPerPage"]
params["startIndex"] += data["resultsPerPage"]
time.sleep(sleep_between_requests)
return results


def get_cves_in_batches(
http_session: Session,
nist_cve_url: str,
start_date: datetime,
end_date: datetime,
@@ -147,7 +129,7 @@ def get_cves_in_batches(
logger.info(
f"Querying CVE data between {current_start_date} and {current_end_date}",
)
batch_cves = _call_cves_api(nist_cve_url, api_key, params)
batch_cves = _call_cves_api(http_session, nist_cve_url, api_key, params)
_map_cve_dict(cves, batch_cves)
current_start_date = current_end_date
new_end_date = current_start_date + batch_size
@@ -158,9 +140,8 @@ def get_cves_in_batches(


def get_modified_cves(
nist_cve_url: str, last_modified_date: str, api_key: str | None,
http_session: Session, nist_cve_url: str, last_modified_date: str, api_key: str | None,
) -> Dict[Any, Any]:
cves = dict()
end_date = datetime.now(tz=timezone.utc)
start_date = datetime.strptime(last_modified_date, "%Y-%m-%dT%H:%M:%S").replace(
tzinfo=timezone.utc,
@@ -170,15 +151,14 @@ def get_modified_cves(
"end": "lastModEndDate",
}
cves = get_cves_in_batches(
nist_cve_url, start_date, end_date, date_param_names, api_key,
http_session, nist_cve_url, start_date, end_date, date_param_names, api_key,
)
return cves


def get_published_cves_per_year(
nist_cve_url: str, year: str, api_key: str | None,
http_session: Session, nist_cve_url: str, year: str, api_key: str | None,
) -> Dict[Any, Any]:
cves = {}
start_of_year = datetime.strptime(f"{year}-01-01", "%Y-%m-%d")
next_year = int(year) + 1
end_of_next_year = datetime.strptime(f"{next_year}-01-01", "%Y-%m-%d")
@@ -187,7 +167,7 @@ def get_published_cves_per_year(
"end": "pubEndDate",
}
cves = get_cves_in_batches(
nist_cve_url, start_of_year, end_of_next_year, date_param_names, api_key,
http_session, nist_cve_url, start_of_year, end_of_next_year, date_param_names, api_key,
)
return cves

63 changes: 25 additions & 38 deletions tests/unit/cartography/intel/cve/test_feed.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import Mock
from unittest.mock import patch

import pytest
import requests
from requests import Session

from cartography.intel.cve.feed import _call_cves_api
from cartography.intel.cve.feed import _map_cve_dict
@@ -20,14 +21,16 @@


@pytest.fixture
def mock_get():
with patch("cartography.intel.cve.feed.requests.Session") as mock_session:
session_mock = mock_session.return_value.__enter__.return_value
yield session_mock.get
def mock_session():
return MagicMock(spec=Session)


def test_call_cves_api(mock_get):
# Arrange
@pytest.fixture(autouse=True)
def mock_time_sleep(mocker):
    """Patch time.sleep for every test in this module and return the patch's mock."""
    # NOTE(review): `mocker` is presumably the pytest-mock fixture — confirm it is installed.
    sleep_mock = mocker.patch("time.sleep")
    return sleep_mock


def _mock_good_responses() -> list[Mock]:
mock_response_1 = Mock()
mock_response_1.status_code = 200
mock_response_1.json.return_value = {
@@ -93,8 +96,12 @@ def test_call_cves_api(mock_get):
"timestamp": "2024-01-10T19:30:07.520",
"vulnerabilities": [],
}
return [mock_response_1, mock_response_2, mock_response_3]

mock_get.side_effect = [mock_response_1, mock_response_2, mock_response_3]

def test_call_cves_api(mock_session):
# Arrange
mock_session.get.side_effect = _mock_good_responses()
params = {"start": "2024-01-10T00:00:00Z", "end": "2024-01-10T23:59:59Z"}
expected_result = {
"resultsPerPage": 0,
@@ -138,35 +145,15 @@ def test_call_cves_api(mock_get):
}

# Act
result = _call_cves_api(NIST_CVE_URL, API_KEY, params)
result = _call_cves_api(mock_session, NIST_CVE_URL, API_KEY, params)

# Assert
assert mock_get.call_count == 3
assert mock_session.get.call_count == 3
assert result == expected_result


@patch("cartography.intel.cve.feed.DEFAULT_SLEEP_TIME", 0)
def test_call_cves_api_with_error(mock_get: Mock):
# Arrange
mock_response = Mock()
mock_response.status_code = 404
mock_response.message = "Data error"
mock_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock_response,
)
mock_get.return_value = mock_response
params = {"start": "2024-01-10T00:00:00Z", "end": "2024-01-10T23:59:59Z"}

# Act
try:
_call_cves_api(NIST_CVE_URL, API_KEY, params)
except requests.exceptions.HTTPError as err:
assert err.response == mock_response
assert mock_get.call_count == 8


@patch("cartography.intel.cve.feed._call_cves_api")
def test_get_cves_in_batches(mock_call_cves_api: Mock):
def test_get_cves_in_batches(mock_call_cves_api: Mock, mock_session: Session):
"""
Ensure that we get the correct number of CVEs in batches of 120 days
"""
@@ -185,15 +172,15 @@ def test_get_cves_in_batches(mock_call_cves_api: Mock):
_map_cve_dict(excepted_cves, GET_CVE_API_DATA_BATCH_2)
# Act
cves = get_cves_in_batches(
NIST_CVE_URL, start_date, end_date, date_param_names, API_KEY,
mock_session, NIST_CVE_URL, start_date, end_date, date_param_names, API_KEY,
)
# Assert
assert mock_call_cves_api.call_count == 2
assert cves == excepted_cves


@patch("cartography.intel.cve.feed._call_cves_api")
def test_get_modified_cves(mock_call_cves_api: Mock):
def test_get_modified_cves(mock_call_cves_api: Mock, mock_session: Session):
# Arrange
mock_call_cves_api.side_effect = [GET_CVE_API_DATA]
last_modified_date = datetime.now(tz=timezone.utc) + timedelta(days=-1)
@@ -204,14 +191,14 @@ def test_get_modified_cves(mock_call_cves_api: Mock):
"lastModEndDate": current_date_iso8601,
}
# Act
cves = get_modified_cves(NIST_CVE_URL, last_modified_date_iso8601, API_KEY)
cves = get_modified_cves(mock_session, NIST_CVE_URL, last_modified_date_iso8601, API_KEY)
# Assert
mock_call_cves_api.assert_called_once_with(NIST_CVE_URL, API_KEY, expected_params)
mock_call_cves_api.assert_called_once_with(mock_session, NIST_CVE_URL, API_KEY, expected_params)
assert cves == GET_CVE_API_DATA


@patch("cartography.intel.cve.feed._call_cves_api")
def test_get_published_cves_per_year(mock_call_cves_api: Mock):
def test_get_published_cves_per_year(mock_call_cves_api: Mock, mock_session: Session):
# Arrange
no_cves = {
"resultsPerPage": 0,
@@ -226,7 +213,7 @@ def test_get_published_cves_per_year(mock_call_cves_api: Mock):
_map_cve_dict(expected_cves, no_cves)
mock_call_cves_api.side_effect = [GET_CVE_API_DATA, no_cves, no_cves, no_cves]
# Act
cves = get_published_cves_per_year(NIST_CVE_URL, "2024", API_KEY)
cves = get_published_cves_per_year(mock_session, NIST_CVE_URL, "2024", API_KEY)
# Assert
mock_call_cves_api.call_count == 4
assert mock_call_cves_api.call_count == 4
assert cves == expected_cves

0 comments on commit eddab71

Please sign in to comment.