#!/opt/conda/bin/python
"""Launch on kubernetes."""
import argparse
import os
import sys
from typing import Dict, List
import urllib3
import yaml
from jinja2 import Environment, FileSystemLoader, select_autoescape
from kubernetes import client, config
from kubernetes.client.rest import ApiException

urllib3.disable_warnings()

KERNEL_POD_TEMPLATE_PATH = "/kernel-pod.yaml.j2"


def generate_kernel_pod_yaml(keywords):
"""Return the kubernetes pod spec as a yaml string.
- load jinja2 template from this file directory.
- substitute template variables with keywords items.
"""
j_env = Environment(
loader=FileSystemLoader(os.path.dirname(__file__)),
trim_blocks=True,
lstrip_blocks=True,
autoescape=select_autoescape(
disabled_extensions=(
"j2",
"yaml",
),
default_for_string=True,
default=True,
),
)
    # Jinja2 substitutes unmatched template variables with None even when keywords lacks
    # a corresponding item, so there is no need to check for unsubstituted variables here;
    # the Kubernetes API server will validate the pod spec instead.
k8s_yaml = j_env.get_template(KERNEL_POD_TEMPLATE_PATH).render(**keywords)
return k8s_yaml
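
# For orientation, a minimal kernel-pod.yaml.j2 that this function could render might
# look like the sketch below (hypothetical: the real template ships with the kernelspec
# and defines many more fields). Each {{ ... }} variable is supplied by the keywords
# gathered from KERNEL_-prefixed environment variables in launch_kubernetes_kernel():
#
#   apiVersion: v1
#   kind: Pod
#   metadata:
#     name: {{ kernel_pod_name }}
#     namespace: {{ kernel_namespace }}
#     labels:
#       kernel_id: {{ kernel_id }}
#   spec:
#     restartPolicy: Never
#     containers:
#       - name: kernel
#         image: {{ kernel_image }}
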
def extend_pod_env(pod_def: dict) -> dict:
"""Extends the pod_def.spec.containers[0].env stanza with current environment."""
env_stanza = pod_def["spec"]["containers"][0].get("env") or []
    # Walk the template's existing env entries, replacing those found in the current
    # environment with their values (and recording each item processed). Then append
    # all remaining environment variables that were not already present.
processed_entries: List[str] = []
for item in env_stanza:
item_name = item.get("name")
if item_name in os.environ:
item["value"] = os.environ[item_name]
processed_entries.append(item_name)
for name, value in os.environ.items():
if name not in processed_entries:
env_stanza.append({"name": name, "value": value})
pod_def["spec"]["containers"][0]["env"] = env_stanza
return pod_def
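
# Sketch of the merge behavior above (hypothetical values): given a template env of
#   [{"name": "KERNEL_ID", "value": ""}]
# and an environment containing KERNEL_ID="abc" and KERNEL_USERNAME="jovyan", the
# resulting stanza becomes
#   [{"name": "KERNEL_ID", "value": "abc"},
#    {"name": "KERNEL_USERNAME", "value": "jovyan"},
#    ...plus every other variable present in the launcher's environment]
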
# A failure reason that is common across many Kubernetes APIs but is not defined as a
# constant in the client library.
K8S_ALREADY_EXIST_REASON = "AlreadyExists"


def _parse_k8s_exception(exc: ApiException) -> str:
    """Parse the exception and return the error reason from the Kubernetes API.

    Args:
        exc (ApiException): Exception raised by the Kubernetes client

    Returns:
        str: The matched error reason, or an empty string if the exception is unrecognized
    """
    # More exceptions could be parsed, but at the time of implementation this is the only
    # one we need. The response body is JSON, so the reason value appears quoted.
    msg = f'"reason":"{K8S_ALREADY_EXIST_REASON}"'
if exc.status == 409 and exc.reason == "Conflict" and msg in exc.body:
return K8S_ALREADY_EXIST_REASON
return ""


def launch_kubernetes_kernel(
kernel_id,
port_range,
response_addr,
public_key,
spark_context_init_mode,
pod_template_file,
spark_opts_out,
kernel_class_name,
):
"""Launches a containerized kernel as a kubernetes pod."""
if os.getenv("KUBERNETES_SERVICE_HOST"):
config.load_incluster_config()
else:
config.load_kube_config()
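    # Note: in-cluster config uses the pod's mounted service-account credentials, while
    # load_kube_config() falls back to the local kubeconfig (e.g., ~/.kube/config) for
    # out-of-cluster runs.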
    # Capture keywords and their values.
    keywords = {}

    # Factory values...
    # Since Jupyter lower-cases the kernel directory to form the kernel name, we need to
    # capture the directory's case-sensitive value, which is used to locate the kernel
    # launch script within the image.
    # Ensure these key/value pairs are reflected in the environment. We'll add them to
    # the container's env stanza after the pod template is generated.
if port_range:
os.environ["PORT_RANGE"] = port_range
if public_key:
os.environ["PUBLIC_KEY"] = public_key
if response_addr:
os.environ["RESPONSE_ADDRESS"] = response_addr
if kernel_id:
os.environ["KERNEL_ID"] = kernel_id
if spark_context_init_mode:
os.environ["KERNEL_SPARK_CONTEXT_INIT_MODE"] = spark_context_init_mode
if kernel_class_name:
os.environ["KERNEL_CLASS_NAME"] = kernel_class_name
os.environ["KERNEL_NAME"] = os.path.basename(
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
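    # For example (hypothetical path): if this script resides at
    # .../kernels/R_kubernetes/scripts/launch_kubernetes.py, KERNEL_NAME becomes
    # "R_kubernetes" (two dirname() calls up from the script, then basename()).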
# Walk env variables looking for names prefixed with KERNEL_. When found, set corresponding keyword value
# with name in lower case.
for name, value in os.environ.items():
if name.startswith("KERNEL_"):
keywords[name.lower()] = yaml.safe_load(value)
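    # For example (hypothetical value): KERNEL_ID="abc-123" becomes the template keyword
    # kernel_id="abc-123". Note that yaml.safe_load() coerces typed strings, so a value
    # of "2" would reach the template as the integer 2.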
    # Substitute all template variables (wrapped with {{ }}) and generate the yaml string.
k8s_yaml = generate_kernel_pod_yaml(keywords)
    # For each k8s object (kind), call the appropriate API method. Unfortunately, there
    # isn't a single method that can take a set of objects.
    #
    # Creation of additional kinds of k8s objects can be added below. Refer to
    # https://github.com/kubernetes-client/python for API signatures. Other examples can
    # be found in
    # https://github.com/jupyter-server/enterprise_gateway/tree/main/enterprise_gateway/services/processproxies/k8s.py
#
pod_template = None
pod_created = None
kernel_namespace = keywords["kernel_namespace"]
k8s_objs = yaml.safe_load_all(k8s_yaml)
for k8s_obj in k8s_objs:
if k8s_obj.get("kind"):
if k8s_obj["kind"] == "Pod":
# print("{}".format(k8s_obj)) # useful for debug
pod_template = extend_pod_env(k8s_obj)
if pod_template_file is None:
try:
pod_created = client.CoreV1Api(client.ApiClient()).create_namespaced_pod(
body=k8s_obj, namespace=kernel_namespace
)
except ApiException as exc:
if _parse_k8s_exception(exc) == K8S_ALREADY_EXIST_REASON:
pod_created = (
client.CoreV1Api(client.ApiClient())
.list_namespaced_pod(
namespace=kernel_namespace,
label_selector=f"kernel_id={kernel_id}",
watch=False,
)
.items[0]
)
                        else:
                            raise
elif k8s_obj["kind"] == "Secret":
if pod_template_file is None:
client.CoreV1Api(client.ApiClient()).create_namespaced_secret(
body=k8s_obj, namespace=kernel_namespace
)
elif k8s_obj["kind"] == "PersistentVolumeClaim":
if pod_template_file is None:
try:
client.CoreV1Api(
client.ApiClient()
).create_namespaced_persistent_volume_claim(
body=k8s_obj, namespace=kernel_namespace
)
except ApiException as exc:
if _parse_k8s_exception(exc) == K8S_ALREADY_EXIST_REASON:
pass
                    else:
                        raise
elif k8s_obj["kind"] == "PersistentVolume":
if pod_template_file is None:
client.CoreV1Api(client.ApiClient()).create_persistent_volume(body=k8s_obj)
elif k8s_obj["kind"] == "Service":
if pod_template_file is None and pod_created is not None:
                    # Create a dependency between the pod and the service so the service
                    # is deleted when the kernel stops.
                    k8s_obj["metadata"]["ownerReferences"] = [
                        {
                            "apiVersion": "v1",
                            "kind": "Pod",
                            "name": str(pod_created.metadata.name),
                            "uid": str(pod_created.metadata.uid),
                        }
                    ]
client.CoreV1Api(client.ApiClient()).create_namespaced_service(
body=k8s_obj, namespace=kernel_namespace
)
elif k8s_obj["kind"] == "ConfigMap":
if pod_template_file is None and pod_created is not None:
                    # Create a dependency between the pod and the config map so the
                    # config map is deleted when the kernel stops.
                    k8s_obj["metadata"]["ownerReferences"] = [
                        {
                            "apiVersion": "v1",
                            "kind": "Pod",
                            "name": str(pod_created.metadata.name),
                            "uid": str(pod_created.metadata.uid),
                        }
                    ]
client.CoreV1Api(client.ApiClient()).create_namespaced_config_map(
body=k8s_obj, namespace=kernel_namespace
)
else:
                sys.exit(
                    f"ERROR - Unhandled Kubernetes object kind '{k8s_obj['kind']}' found in yaml file - "
                    "kernel launch terminating!"
                )
else:
print("ERROR processing Kubernetes yaml file - kernel launch terminating!")
print(k8s_yaml)
sys.exit(
f"ERROR - Unknown Kubernetes object '{k8s_obj}' found in yaml file - kernel launch terminating!"
)
if pod_template_file:
# TODO - construct other --conf options for things like mounts, resources, etc.
# write yaml to file...
with open(pod_template_file, "w") as stream:
yaml.dump(pod_template, stream)
# Build up additional spark options. Note the trailing space to accommodate concatenation
additional_spark_opts = (
f"--conf spark.kubernetes.driver.podTemplateFile={pod_template_file} "
f"--conf spark.kubernetes.executor.podTemplateFile={pod_template_file} "
)
additional_spark_opts += _get_spark_resources(pod_template)
if spark_opts_out:
with open(spark_opts_out, "w+") as soo_fd:
soo_fd.write(additional_spark_opts)
else: # If no spark_opts_out was specified, print to stdout in case this is an old caller
print(additional_spark_opts)


def _get_spark_resources(pod_template: Dict) -> str:
    # Gather up resources for cpu/memory requests/limits. Since GPUs require a "discovery
    # script", we'll leave those alone for now:
# https://spark.apache.org/docs/latest/running-on-kubernetes.html#resource-allocation-and-configuration-overview
#
# The config value names below are pulled from:
# https://spark.apache.org/docs/latest/running-on-kubernetes.html#container-spec
spark_resources = ""
containers = pod_template.get("spec", {}).get("containers", [])
if containers:
# We're just dealing with single-container pods at this time.
resources = containers[0].get("resources", {})
if resources:
requests = resources.get("requests", {})
if requests:
cpu_request = requests.get("cpu")
if cpu_request:
spark_resources += (
f"--conf spark.driver.cores={cpu_request} "
f"--conf spark.executor.cores={cpu_request} "
)
memory_request = requests.get("memory")
if memory_request:
spark_resources += (
f"--conf spark.driver.memory={memory_request} "
f"--conf spark.executor.memory={memory_request} "
)
limits = resources.get("limits", {})
if limits:
cpu_limit = limits.get("cpu")
if cpu_limit:
spark_resources += (
f"--conf spark.kubernetes.driver.limit.cores={cpu_limit} "
f"--conf spark.kubernetes.executor.limit.cores={cpu_limit} "
)
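                # Note: Spark exposes no separate memory-limit setting for Kubernetes
                # (unlike limit.cores), so a memory limit below overwrites the
                # spark.driver.memory / spark.executor.memory values set from the
                # request above.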
memory_limit = limits.get("memory")
if memory_limit:
spark_resources += (
f"--conf spark.driver.memory={memory_limit} "
f"--conf spark.executor.memory={memory_limit} "
)
return spark_resources
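
# As a sketch (hypothetical values): a container with requests {cpu: "2", memory: "2g"}
# and limits {cpu: "4"} yields
#   --conf spark.driver.cores=2 --conf spark.executor.cores=2
#   --conf spark.driver.memory=2g --conf spark.executor.memory=2g
#   --conf spark.kubernetes.driver.limit.cores=4 --conf spark.kubernetes.executor.limit.cores=4
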
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"--kernel-id",
dest="kernel_id",
nargs="?",
help="Indicates the id associated with the launched kernel.",
)
parser.add_argument(
"--port-range",
dest="port_range",
nargs="?",
metavar="<lowerPort>..<upperPort>",
help="Port range to impose for kernel ports",
)
parser.add_argument(
"--response-address",
dest="response_address",
nargs="?",
metavar="<ip>:<port>",
help="Connection address (<ip>:<port>) for returning connection file",
)
parser.add_argument(
"--public-key",
dest="public_key",
nargs="?",
help="Public key used to encrypt connection information",
)
parser.add_argument(
"--spark-context-initialization-mode",
dest="spark_context_init_mode",
nargs="?",
help="Indicates whether or how a spark context should be created",
)
parser.add_argument(
"--pod-template",
dest="pod_template_file",
nargs="?",
metavar="template filename",
help="When present, yaml is written to file, no launch performed.",
)
parser.add_argument(
"--spark-opts-out",
dest="spark_opts_out",
nargs="?",
metavar="additional spark options filename",
help="When present, additional spark options are written to file, "
"no launch performed, requires --pod-template.",
)
parser.add_argument(
"--kernel-class-name",
dest="kernel_class_name",
nargs="?",
help="Indicates the name of the kernel class to use. Must be a subclass of 'ipykernel.kernelbase.Kernel'.",
)
    # The following arguments are deprecated and are used only when their mirrored
    # arguments have no value. This means the default value for
    # --spark-context-initialization-mode ("none") must come from the mirrored
    # argument's default until the deprecated items have been removed.
parser.add_argument(
"--RemoteProcessProxy.kernel-id",
dest="rpp_kernel_id",
nargs="?",
help="Indicates the id associated with the launched kernel. (deprecated)",
)
parser.add_argument(
"--RemoteProcessProxy.port-range",
dest="rpp_port_range",
nargs="?",
metavar="<lowerPort>..<upperPort>",
help="Port range to impose for kernel ports (deprecated)",
)
parser.add_argument(
"--RemoteProcessProxy.response-address",
dest="rpp_response_address",
nargs="?",
metavar="<ip>:<port>",
help="Connection address (<ip>:<port>) for returning connection file (deprecated)",
)
parser.add_argument(
"--RemoteProcessProxy.public-key",
dest="rpp_public_key",
nargs="?",
help="Public key used to encrypt connection information (deprecated)",
)
parser.add_argument(
"--RemoteProcessProxy.spark-context-initialization-mode",
dest="rpp_spark_context_init_mode",
nargs="?",
help="Indicates whether or how a spark context should be created (deprecated)",
default="none",
)
arguments = vars(parser.parse_args())
kernel_id = arguments["kernel_id"] or arguments["rpp_kernel_id"]
port_range = arguments["port_range"] or arguments["rpp_port_range"]
response_addr = arguments["response_address"] or arguments["rpp_response_address"]
public_key = arguments["public_key"] or arguments["rpp_public_key"]
spark_context_init_mode = (
arguments["spark_context_init_mode"] or arguments["rpp_spark_context_init_mode"]
)
pod_template_file = arguments["pod_template_file"]
spark_opts_out = arguments["spark_opts_out"]
kernel_class_name = arguments["kernel_class_name"]
launch_kubernetes_kernel(
kernel_id,
port_range,
response_addr,
public_key,
spark_context_init_mode,
pod_template_file,
spark_opts_out,
kernel_class_name,
)
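
# Typical invocation (hypothetical values), as issued by the hosting process proxy:
#   launch_kubernetes.py --kernel-id abc-123 \
#       --response-address 10.11.12.13:8877 \
#       --public-key <base64-encoded-public-key> \
#       --spark-context-initialization-mode none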