Skip to content

Commit

Permalink
chore: refactor create_hec_token fixture, add default pytest cancella…
Browse files Browse the repository at this point in the history
…tion in case of exception occurence
  • Loading branch information
mkolasinski-splunk committed Nov 26, 2023
1 parent 9cae1db commit 97cf053
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 52 deletions.
14 changes: 14 additions & 0 deletions pytest_splunk_addon/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@

test_generator = None

EXC_MAP = {
Exception: 1
}

def pytest_configure(config):
"""
Expand Down Expand Up @@ -115,6 +118,7 @@ def pytest_sessionstart(session):
SampleXdistGenerator.tokenized_event_source = session.config.getoption(
"tokenized_event_source"
).lower()
session.__exc_limits = EXC_MAP
if (
SampleXdistGenerator.tokenized_event_source == "store_new"
and session.config.getoption("ingest_events").lower()
Expand Down Expand Up @@ -198,5 +202,15 @@ def init_pytest_splunk_addon_logger():
return logger


def pytest_exception_interact(node, call, report):
session = node.session
type_ = call.excinfo.type

if type_ in session.__exc_limits:
if session.__exc_limits[type_] == 0:
pytest.exit(f"Reached max exception for type: {type_}")
else:
session.__exc_limits[type_] -= 1

init_pytest_splunk_addon_logger()
LOGGER = logging.getLogger("pytest-splunk-addon")
117 changes: 65 additions & 52 deletions pytest_splunk_addon/splunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,13 +107,6 @@ def pytest_addoption(parser):
default="8088",
help="Splunk HTTP event collector port. default is 8088.",
)
group.addoption(
"--splunk-hec-token",
action="store",
dest="splunk_hec_token",
help='Splunk HTTP event collector token.',
required=True
)
group.addoption(
"--splunk-port",
action="store",
Expand Down Expand Up @@ -201,13 +194,6 @@ def pytest_addoption(parser):
default="514",
help="SC4S Port. default is 514",
)
group.addoption(
"--sc4s-version",
action="store",
dest="sc4s_version",
default="latest",
help="SC4S version. default is latest",
)
group.addoption(
"--thread-count",
action="store",
Expand Down Expand Up @@ -549,7 +535,6 @@ def splunk_docker(
os.environ["SPLUNK_APP_ID"] = config["package"]["id"]
except Exception:
os.environ["SPLUNK_APP_ID"] = "TA_package"
os.environ["SPLUNK_HEC_TOKEN"] = request.config.getoption("splunk_hec_token")
os.environ["SPLUNK_USER"] = request.config.getoption("splunk_user")
os.environ["SPLUNK_PASSWORD"] = request.config.getoption("splunk_password")
os.environ["SPLUNK_VERSION"] = request.config.getoption("splunk_version")
Expand Down Expand Up @@ -686,64 +671,93 @@ def splunk_rest_uri(splunk):
@pytest.fixture(scope="session")
def splunk_hec_uri(request, splunk):
"""
Provides a uri to the Splunk hec port
Provides a uri to the Splunk services/collector endpoint
"""
splunk_session = requests.Session()
splunk_session.headers = {
"Authorization": f'Splunk {request.config.getoption("splunk_hec_token")}'
}
uri = f'{request.config.getoption("splunk_hec_scheme")}://{splunk["forwarder_host"]}:{splunk["port_hec"]}/services/collector'
LOGGER.info("Fetched splunk_hec_uri=%s", uri)

return splunk_session, uri
return uri

@pytest.fixture(scope="session")
def splunk_inputs_uri(request, splunk):
"""
Provides a uri to the Splunk data/inputs/all endpoint
Provides a uri to the Splunk services/data/inputs/http endpoint
"""
uri = f'{request.config.getoption("splunk_hec_scheme")}://{splunk["forwarder_host"]}:{splunk["port"]}/services/data/inputs/http'
LOGGER.info("Fetched splunk_inputs_uri=%s", uri)

return uri

def extract_token_from_xml(xml_content):
root = ET.fromstring(xml_content)
elements_with_name_attrib = [element for element in root.iter() if 'name' in element.attrib]
for element in elements_with_name_attrib:
if element.attrib["name"] == "token":
return element.text

@pytest.fixture(scope="session")
def create_hec_token(request, splunk_inputs_uri):
splunk_token_name = "splunk_hec_token_psa"
response = requests.post(
splunk_inputs_uri,
verify=False,
auth=(request.config.getoption("splunk_user"), request.config.getoption("splunk_password")),
data=f"name={splunk_token_name}",
)
"""
Creates an HEC token in Splunk instance. Exports its value to SPLUNK_HEC_TOKEN env variable.
Returns:
requests.Session: A session with headers containing Authorization: Splunk <HEC token>.
"""
# Default splunk HEC token name
splunk_token_name = "splunk_hec_token"
try:
response = _create_new_token(request, splunk_inputs_uri, splunk_token_name)
token_value = _extract_token_from_xml(response.text)

except Exception as e:
logging.error(f"Failed to create HEC token: {e}")
raise

splunk_session = requests.Session()
splunk_session.headers = {"Authorization": f'Splunk {token_value}'}
os.environ["SPLUNK_HEC_TOKEN"] = splunk_session.headers.get("Authorization", "").split(" ")[1]

return splunk_session

def _create_new_token(request, splunk_inputs_uri, splunk_token_name):
try:
response = requests.post(
splunk_inputs_uri,
verify=False,
auth=(request.config.getoption("splunk_user"), request.config.getoption("splunk_password")),
data=f"name={splunk_token_name}",
)
response.raise_for_status()
token_value = extract_token_from_xml(response.text)
return response

except requests.exceptions.HTTPError as e:
if e.response.status_code == 409:
existing_token_response = requests.get(
f"{splunk_inputs_uri}/{splunk_token_name}",
verify=False,
auth=(request.config.getoption("splunk_user"), request.config.getoption("splunk_password")),
)
existing_token_response.raise_for_status()
token_value = extract_token_from_xml(existing_token_response.text)
# Token already exists; attempt to retrieve the existing token
return _get_existing_token(request, splunk_inputs_uri, splunk_token_name)
else:
raise Exception("HEC token creation failed!")
raise Exception(f"HTTP error during token creation: {e}") from e

except Exception as e:
logging.error(f"An error occurred during HEC token creation: {e}")
raise Exception(f"HEC token creation failed: {e}") from e

def _get_existing_token(request, splunk_inputs_uri, splunk_token_name):
try:
response = requests.get(
f"{splunk_inputs_uri}/{splunk_token_name}",
verify=False,
auth=(request.config.getoption("splunk_user"), request.config.getoption("splunk_password")),
)
response.raise_for_status()
return response

except Exception as e:
logging.error(f"Failed to retrieve existing HEC token: {e}")
raise Exception(f"Failed to retrieve existing HEC token: {e}") from e

def _extract_token_from_xml(xml_content):
"""
Extracts token value from a xml formatted content
"""
root = ET.fromstring(xml_content)
elements_with_name_attrib = [element for element in root.iter() if 'name' in element.attrib]
for element in elements_with_name_attrib:
if element.attrib["name"] == "token":
return element.text

splunk_session = requests.Session()
splunk_session.headers = {
"Authorization": f'Splunk {token_value}'
}
os.environ["SPLUNK_HEC_TOKEN"] = splunk_session.headers["Authorization"].split(" ")[1]
return splunk_session

@pytest.fixture(scope="session")
def splunk_web_uri(request, splunk):
Expand Down Expand Up @@ -775,7 +789,6 @@ def splunk_ingest_data(request, splunk_hec_uri, sc4s, uf, splunk_events_cleanup,
"""
if request.config.getoption("ingest_events").lower() in ["n", "no", "false", "f"]:
return
#create_hec_token(splunk_inputs_uri=splunk_inputs_uri)
global PYTEST_XDIST_TESTRUNUID
if (
"PYTEST_XDIST_WORKER" not in os.environ
Expand All @@ -789,7 +802,7 @@ def splunk_ingest_data(request, splunk_hec_uri, sc4s, uf, splunk_events_cleanup,
"uf_username": uf.get("uf_username"),
"uf_password": uf.get("uf_password"),
"session_headers": create_hec_token.headers,
"splunk_hec_uri": splunk_hec_uri[1],
"splunk_hec_uri": splunk_hec_uri,
"sc4s_host": sc4s[0], # for sc4s
"sc4s_port": sc4s[1][514], # for sc4s
}
Expand Down

0 comments on commit 97cf053

Please sign in to comment.