Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Telemetry Headers #15

Merged
merged 6 commits into from
Feb 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions tests/unit-tests/client_test.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@
# -*- coding: utf-8 -*-

import os
import re
import unittest

import pytest

from xata.client import XataClient
from xata.client import XataClient, SDK_VERSION

PATTERNS_UUID4 = re.compile(r"^[\da-f]{8}-([\da-f]{4}-){3}[\da-f]{12}$", re.IGNORECASE)
PATTERNS_SDK_VERSION = re.compile(r"^[0-9]{1,3}.[0-9]{1,3}.[0-9]{1,3}$")

class TestXataClient(unittest.TestCase):

"""
'apiKey': self.api_key,
'location': self.api_key_location,
'workspaceId': self.workspace_id,
'region': self.region,
'dbName': self.db_name,
'branchName': self.branch_name,
"""

class TestXataClient(unittest.TestCase):
def test_init_api_key_with_params(self):
api_key = "param_ABCDEF123456789"

Expand Down Expand Up @@ -73,3 +67,36 @@ def test_init_db_url_invalid_combinations(self):

with pytest.raises(Exception):
XataClient(db_url="db_url", workspace_id="ws_id", db_name="db_name")

def test_sdk_version(self):
db_url = "https://py-sdk-unit-test-12345.eu-west-1.xata.sh/db/testopia-042"
client = XataClient(db_url=db_url)
cfg = client.get_config()

assert "version" in cfg
assert PATTERNS_SDK_VERSION.match(cfg["version"])
assert SDK_VERSION == cfg["version"]

def test_telemetry_headers(self):
api_key = "this-key-42"
client1 = XataClient(api_key=api_key, workspace_id="ws_id")
headers1 = client1.get_headers()

assert len(headers1) == 4
assert "authorization" in headers1
assert headers1["authorization"] == f"Bearer {api_key}"
assert "x-xata-client-id" in headers1
assert PATTERNS_UUID4.match(headers1["x-xata-client-id"])
assert "x-xata-session-id" in headers1
assert PATTERNS_UUID4.match(headers1["x-xata-session-id"])
assert headers1["x-xata-client-id"] != headers1["x-xata-session-id"]
assert "x-xata-agent" in headers1
assert headers1['x-xata-agent'] == f"client=PY_SDK;version={SDK_VERSION};"

api_key = "this-key-42"
client2 = XataClient(api_key=api_key, workspace_id="ws_id")
headers2 = client2.get_headers()

assert headers1["x-xata-client-id"] != headers2["x-xata-client-id"]
assert headers1["x-xata-session-id"] != headers2["x-xata-session-id"]
assert headers1['x-xata-agent'] == headers2['x-xata-agent']
2 changes: 2 additions & 0 deletions xata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# -*- coding: utf-8 -*-

from .client import XataClient

__all__ = ("XataClient",)
19 changes: 18 additions & 1 deletion xata/client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# -*- coding: utf-8 -*-

import importlib.metadata
import json
import os
import uuid
from typing import Literal, Optional
from urllib.parse import urljoin

Expand All @@ -14,6 +17,8 @@
UnauthorizedException,
)

SDK_VERSION = importlib.metadata.version(__package__ or __name__)

PERSONAL_API_KEY_LOCATION = "~/.config/xata/key"
DEFAULT_BASE_URL_DOMAIN = "xata.sh"
DEFAULT_CONTROL_PLANE_DOMAIN = "api.xata.io"
Expand Down Expand Up @@ -106,7 +111,12 @@ def __init__(
self.branch_name = (
self.get_branch_name_if_configured() if branch_name is None else branch_name
)
self.headers = {"authorization": f"Bearer {self.api_key}"}
self.headers = {
"authorization": f"Bearer {self.api_key}",
"x-xata-client-id": str(uuid.uuid4()),
"x-xata-session-id": str(uuid.uuid4()),
"x-xata-agent": f"client=PY_SDK;version={SDK_VERSION};",
}

def get_config(self) -> dict:
"""
Expand All @@ -119,8 +129,15 @@ def get_config(self) -> dict:
"region": self.region,
"dbName": self.db_name,
"branchName": self.branch_name,
"version": SDK_VERSION,
}

def get_headers(self) -> dict:
"""
Get the static headers that are iniatilized on client init.
"""
return self.headers

def get_api_key(self) -> tuple[str, ApiKeyLocation]:
if os.environ.get("XATA_API_KEY") is not None:
return os.environ.get("XATA_API_KEY"), "env"
Expand Down