Skip to content
This repository has been archived by the owner on Oct 19, 2022. It is now read-only.

Commit

Permalink
Update __init__.py
Browse files Browse the repository at this point in the history
  • Loading branch information
SireInsectus committed Oct 3, 2022
1 parent 6f165ee commit 9c6c5b4
Showing 1 changed file with 26 additions and 2 deletions.
28 changes: 26 additions & 2 deletions src/dbacademy_gems/dbgems/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@

dbgems_module = sys.modules[globals()['__name__']]


# noinspection PyGlobalUndefined
def __init_globals():
# noinspection PyUnresolvedReferences
import dbruntime

global __is_initialized
Expand Down Expand Up @@ -46,17 +48,20 @@ def __init_globals():

dbgems_module.dbutils = dbutils


def deprecation_logging_enabled():
status = spark.conf.get("dbacademy.deprecation.logging", None)
return status is not None and status.lower() == "enabled"


def print_warning(title: str, message: str, length: int = 100):
title_len = length - len(title) - 3
print(f"""* {title.upper()} {("*"*title_len)}""")
for line in message.split("\n"):
print(f"* {line}")
print("*"*length)


def deprecated(reason=None):
def decorator(inner_function):
def wrapper(*args, **kwargs):
Expand All @@ -75,21 +80,26 @@ def wrapper(*args, **kwargs):
return wrapper
return decorator


@deprecated(reason="Use dbgems.dbutils instead.")
def get_dbutils(): # -> dbruntime.dbutils.DBUtils:
return dbgems_module.dbutils


@deprecated(reason="Use dbgems.spark instead.")
def get_spark_session() -> pyspark.sql.SparkSession:
return dbgems_module.spark


@deprecated(reason="Use dbgems.sc instead.")
def get_session_context() -> pyspark.context.SparkContext:
return dbgems_module.sc


def sql(query):
return spark.sql(query)


def get_parameter(name, default_value=""):
from py4j.protocol import Py4JJavaError
try:
Expand All @@ -102,6 +112,7 @@ def get_parameter(name, default_value=""):
else:
return default_value


def get_cloud():
with open("/databricks/common/conf/deploy.conf") as f:
for line in f:
Expand Down Expand Up @@ -147,6 +158,7 @@ def get_workspace_id() -> str:
# noinspection PyUnresolvedReferences
return dbutils.entry_point.getDbutils().notebook().getContext().workspaceId().getOrElse(None)


def get_notebook_path() -> str:
# noinspection PyUnresolvedReferences
return dbutils.entry_point.getDbutils().notebook().getContext().notebookPath().getOrElse(None)
Expand All @@ -159,6 +171,7 @@ def get_notebook_name() -> str:
def get_notebook_dir(offset=-1) -> str:
return "/".join(get_notebook_path().split("/")[:offset])


def get_notebooks_api_endpoint() -> str:
# noinspection PyUnresolvedReferences
return dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiUrl().getOrElse(None)
Expand All @@ -168,12 +181,14 @@ def get_notebooks_api_token() -> str:
# noinspection PyUnresolvedReferences
return dbutils.notebook.entry_point.getDbutils().notebook().getContext().apiToken().getOrElse(None)


def jprint(value: dict, indent: int = 4):
assert type(value) == dict or type(value) == list, f"Expected value to be of type \"dict\" or \"list\", found \"{type(value)}\"."

import json
print(json.dumps(value, indent=indent))


@deprecated(reason="Use dbacademy.dbrest.clusters.get_current_spark_version() instead.")
def get_current_spark_version(client=None):
if includes_dbrest:
Expand All @@ -187,6 +202,7 @@ def get_current_spark_version(client=None):
else:
raise Exception(f"Cannot use rest API with-out including dbacademy.dbrest")


@deprecated(reason="Use dbacademy.dbrest.clusters.get_current_instance_pool_id() instead.")
def get_current_instance_pool_id(client=None):
if includes_dbrest:
Expand Down Expand Up @@ -215,10 +231,12 @@ def get_current_node_type_id(client=None):
else:
raise Exception(f"Cannot use rest API with-out including dbacademy.dbrest")


def sort_semantic_versions(versions: List[str]) -> List[str]:
versions.sort(key=lambda v: (int(v.split(".")[0]) * 10000) + (int(v.split(".")[1]) * 100) + int(v.split(".")[2]))
return versions


def lookup_all_module_versions(module: str, github_org: str = "databricks-academy") -> List[str]:
import requests

Expand All @@ -230,6 +248,7 @@ def lookup_all_module_versions(module: str, github_org: str = "databricks-academ
versions = [t.get("name")[1:] for t in response.json()]
return sort_semantic_versions(versions)


def lookup_current_module_version(module: str, dist_version: str = "0.0.0", default: str = "v0.0.0") -> str:
import json, pkg_resources

Expand All @@ -245,10 +264,12 @@ def lookup_current_module_version(module: str, dist_version: str = "0.0.0", defa

return requested_revision


def is_curriculum_workspace() -> bool:
host_name = get_browser_host_name(default_value="unknown")
return host_name.startswith("curriculum-") and host_name.endswith(".cloud.databricks.com")


def validate_dependencies(module: str, curriculum_workspaces_only=True) -> bool:
# Don't do anything unless this is in one of the Curriculum Workspaces
testable = curriculum_workspaces_only is False or is_curriculum_workspace()
Expand All @@ -271,11 +292,11 @@ def validate_dependencies(module: str, curriculum_workspaces_only=True) -> bool:
return True # They match, all done!

print_warning(title=f"Outdated Dependency",
message=f"You are using version {current_version} but the latest version is {versions[-1]}.\n"+
message=f"You are using version \"{current_version}\" but the latest version is \"{versions[-1]}\".\n" +
f"Please update your dependencies on the module \"{module}\" at your earliest convenience.")
else:
print_warning(title=f"Invalid Dependency",
message=f"You are using the branch or commit hash {current_version} but the latest version is {versions[-1]}.\n"+
message=f"You are using the branch or commit hash \"{current_version}\" but the latest version is \"{versions[-1]}\".\n" +
f"Please update your dependencies on the module \"{module}\" at your earliest convenience.")
except Exception as e:
if testable:
Expand All @@ -285,6 +306,7 @@ def validate_dependencies(module: str, curriculum_workspaces_only=True) -> bool:

return False


# noinspection PyUnresolvedReferences
def proof_of_life(expected_get_username,
expected_get_tag,
Expand Down Expand Up @@ -367,6 +389,7 @@ def proof_of_life(expected_get_username,

print("All tests passed!")


def display_html(html) -> None:
import inspect
caller_frame = inspect.currentframe().f_back
Expand All @@ -378,6 +401,7 @@ def display_html(html) -> None:
caller_frame = caller_frame.f_back
raise ValueError("displayHTML not found in any caller frames.")


def display(html) -> None:
import inspect
caller_frame = inspect.currentframe().f_back
Expand Down

0 comments on commit 9c6c5b4

Please sign in to comment.