-
Notifications
You must be signed in to change notification settings - Fork 80
/
util.bzl
352 lines (293 loc) · 11.9 KB
/
util.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
load("@bazel_tools//tools/cpp:lib_cc_configure.bzl", "get_cpu_value")
load("@bazel_skylib//lib:paths.bzl", "paths")
load("@bazel_skylib//lib:versions.bzl", "versions")
def fail_on_err(return_value, prefix = None):
    """Unwrap a `(result, error)` pair, aborting evaluation on error.

    Args:
      return_value: Pair; a truthy second element signals an error and is
        used as the failure message.
      prefix: optional, String; prepended to the error message when truthy.

    Returns:
      The first element of `return_value` when no error was indicated.
    """
    (value, error) = return_value
    if not error:
        return value

    # `fail` aborts evaluation, so nothing runs past this call.
    fail(prefix + error if prefix else error)
def is_supported_platform(repository_ctx):
    """Report whether `nix-build` can be located via `repository_ctx.which`.

    Args:
      repository_ctx: The repository context of the current repository rule.

    Returns:
      True if `repository_ctx.which("nix-build")` found something, else False.
    """
    nix_build = repository_ctx.which("nix-build")
    return nix_build != None
def _is_executable(repository_ctx, path):
    """Check whether the owner-execute bit (0o100) is set on `path`.

    The permission bits are read by running the `stat` executable found on
    PATH, using BSD or GNU style arguments depending on which flavour it is.

    Args:
      repository_ctx: The repository context of the current repository rule.
      path: The file whose mode should be inspected.

    Returns:
      False if no `stat` executable is available, otherwise whether the
      owner-execute bit is set in the reported mode.

    Fails if `stat` exits with a non-zero return code or prints no mode,
    instead of failing later on an unparsable empty string.
    """
    stat_exe = repository_ctx.which("stat")
    if stat_exe == None:
        return False

    # A hack to detect if stat in Nix shell is BSD stat as BSD stat does not
    # support --version flag
    is_bsd_stat = repository_ctx.execute([stat_exe, "--version"]).return_code != 0
    if is_bsd_stat:
        stat_args = ["-f", "%Lp", path]
    else:
        stat_args = ["-c", "%a", path]

    arguments = [stat_exe] + stat_args
    exec_result = repository_ctx.execute(arguments)
    stdout = exec_result.stdout.strip()

    # Surface a failing `stat` call explicitly; otherwise `int("", 8)` below
    # would abort with an error that hides the real cause.
    if exec_result.return_code != 0 or not stdout:
        fail("Could not determine the file mode of '{path}'.\nCommand: {command}\nReturn code: {return_code}\nError output: {stderr}\n".format(
            path = path,
            command = " ".join([repr(str(a)) for a in arguments]),
            return_code = exec_result.return_code,
            stderr = exec_result.stderr,
        ))

    # The mode is printed in octal by both stat flavours.
    mode = int(stdout, 8)
    return (mode & 0o100) != 0
def external_repository_root(label):
    """Join the label's workspace root, package and name into one path.

    Empty components are skipped so no doubled or leading `/` is produced.

    Args:
      label: An object exposing `workspace_root`, `package` and `name`.

    Returns:
      The non-empty components joined with `/`.
    """
    parts = []
    for piece in (label.workspace_root, label.package, label.name):
        if piece:
            parts.append(piece)
    return "/".join(parts)
def cp(repository_ctx, src, dest = None):
    """Copy a file into the current external repository.

    The file content is read and re-written through the repository context,
    carrying over the executable bit reported by `_is_executable`.

    Args:
      repository_ctx: The repository context of the current repository rule.
      src: The source file. Must be a Label if dest is None.
      dest: Optional, the target path within the current repository root.
        When omitted it is derived from `src` via `external_repository_root`.

    Returns:
      The dest value (either the one passed in or the derived one).
    """
    derive_dest = dest == None
    if derive_dest and type(src) != "Label":
        fail("src must be a Label if dest is not specified explicitly.")
    if derive_dest:
        dest = external_repository_root(src)

    source = repository_ctx.path(src)
    target = repository_ctx.path(dest)
    keep_executable = _is_executable(repository_ctx, source)

    # Write the content read from the source to the target location.
    content = repository_ctx.read(source)
    repository_ctx.file(
        target,
        content,
        executable = keep_executable,
        legacy_utf8 = False,
    )
    return dest
def execute_or_fail(repository_ctx, arguments, failure_message = "", *args, **kwargs):
    """Run `repository_ctx.execute()` and fail on a non-zero return code.

    Args:
      repository_ctx: The repository context of the current repository rule.
      arguments: The command line, forwarded to `repository_ctx.execute`.
      failure_message: Text placed at the top of the failure report.
      *args: Extra positional arguments forwarded to `execute`.
      **kwargs: Extra keyword arguments forwarded to `execute`.

    Returns:
      The result of `repository_ctx.execute` when the return code is zero.
    """
    result = repository_ctx.execute(arguments, *args, **kwargs)
    if not result.return_code:
        return result

    # Each argument is shown via repr() so whitespace in it stays visible.
    shown_command = " ".join([repr(str(argument)) for argument in arguments])

    # Start stderr on a fresh line and prefix every line of it with " > ".
    shown_stderr = " > ".join(("\n" + result.stderr).splitlines(True))

    fail("""
{failure_message}
Command: {command}
Return code: {return_code}
Error output: {stderr}
""".format(
        failure_message = failure_message,
        command = shown_command,
        return_code = result.return_code,
        stderr = shown_stderr,
    ))
def label_string(label):
    """Render an optional Label as text.

    Args:
      label: A Label, or a falsy value such as None.

    Returns:
      The label wrapped in double quotes, or the string `None` when `label`
      is falsy.
    """
    return '"%s"' % label if label else "None"
def executable_path(repository_ctx, exe_name, extra_msg = ""):
    """Look up an executable on PATH, failing when it cannot be found.

    Args:
      repository_ctx: The repository context of the current repository rule.
      exe_name: Name of the executable handed to `repository_ctx.which`.
      extra_msg: Optional text appended to the failure message.

    Returns:
      The value returned by `repository_ctx.which(exe_name)`.
    """
    located = repository_ctx.which(exe_name)
    if located != None:
        return located

    # Separate the optional hint from the main sentence with a single space.
    hint = ""
    if extra_msg:
        hint = " " + extra_msg
    fail("Could not find the `{}` executable in PATH.{}\n".format(exe_name, hint))
def find_children(repository_ctx, target_dir):
    """List the direct children of `target_dir` using `find`.

    Runs `find -L <target_dir> -maxdepth 1 -mindepth 1 -print0` and splits
    the NUL-separated output.

    Args:
      repository_ctx: The repository context of the current repository rule.
      target_dir: The directory whose direct children are listed.

    Returns:
      A list of the paths printed by `find`; an empty list when `find`
      printed nothing (for example for an empty directory).
    """
    find_args = [
        executable_path(repository_ctx, "find"),
        "-L",
        target_dir,
        "-maxdepth",
        "1",
        # otherwise the directory is printed as well
        "-mindepth",
        "1",
        # filenames can contain \n
        "-print0",
    ]
    exec_result = execute_or_fail(repository_ctx, find_args)
    entries = exec_result.stdout.rstrip("\000")

    # Splitting an empty string would yield [""], i.e. one bogus child with
    # an empty path, so report "no children" explicitly.
    if not entries:
        return []
    return entries.split("\000")
def default_constraints(repository_ctx):
    """Calculate the default CPU and OS constraints based on the host platform.

    The host is identified by the value returned from `get_cpu_value`.
    Unknown values fall back to x86_64 and linux.

    Args:
      repository_ctx: The repository context of the current repository rule.

    Returns:
      A list containing the cpu and os constraints.
    """
    cpu_value = get_cpu_value(repository_ctx)

    # Newer Bazel releases report Intel macOS as "darwin_x86_64" rather than
    # "darwin"; both spellings are accepted so such hosts are not mistaken
    # for linux.
    cpu = {
        "darwin": "@platforms//cpu:x86_64",
        "darwin_x86_64": "@platforms//cpu:x86_64",
        "darwin_arm64": "@platforms//cpu:arm64",
        "aarch64": "@platforms//cpu:arm64",
    }.get(cpu_value, "@platforms//cpu:x86_64")
    os = {
        "darwin": "@platforms//os:osx",
        "darwin_x86_64": "@platforms//os:osx",
        "darwin_arm64": "@platforms//os:osx",
    }.get(cpu_value, "@platforms//os:linux")
    return [cpu, os]
def ensure_constraints_pure(default_constraints, target_constraints = [], exec_constraints = []):
    """Build exec and target constraints for repository rules.

    If these are user-provided, then they are passed through.
    Otherwise, use the provided default constraints.
    In either case, exec_constraints always contain the support_nix
    constraint, so the toolchain can be rejected on non-Nix environments.

    The inputs are never modified: the returned lists are fresh copies.

    Args:
      default_constraints: Fall-back constraints, used for both results when
        neither `target_constraints` nor `exec_constraints` is given.
      target_constraints: optional, User provided target_constraints.
      exec_constraints: optional, User provided exec_constraints.

    Returns:
      exec_constraints, The generated list of exec constraints
      target_constraints, The generated list of target constraints
    """
    if not target_constraints and not exec_constraints:
        target_constraints = default_constraints
        exec_constraints = default_constraints

    # Copy before appending. Sharing one list between the two results (or
    # with the caller's defaults) would leak the support_nix constraint into
    # the target constraints and mutate the caller's list.
    target_constraints = list(target_constraints)
    exec_constraints = list(exec_constraints)

    exec_constraints.append("@rules_nixpkgs_core//constraints:support_nix")
    return exec_constraints, target_constraints
def ensure_constraints(repository_ctx):
    """Build exec and target constraints for repository rules.

    Reads the `target_constraints` and `exec_constraints` attributes of the
    rule and hands them, together with the result of
    `default_constraints(repository_ctx)` as fall-back, to
    `ensure_constraints_pure`.

    Args:
      repository_ctx: The repository context of the current repository rule.

    Returns:
      exec_constraints, The generated list of exec constraints
      target_constraints, The generated list of target constraints
    """
    fallback = default_constraints(repository_ctx)
    attrs = repository_ctx.attr
    return ensure_constraints_pure(
        default_constraints = fallback,
        target_constraints = attrs.target_constraints,
        exec_constraints = attrs.exec_constraints,
    )
def parse_expand_location(string):
    """Split a string into literal text and `$(location ...)` commands.

    Produces a list of `(command, argument)` pairs where command is one of:

    - `string`: argument is literal text to emit as-is. `$$` in the input
      yields a literal `$`.
    - `location`: argument is the label text that followed `location `.

    Attrs:
      string: string, The string to parse.

    Returns:
      (result, error):
        result: The generated list of pairs, or None on error.
        error: string or None, This is set if an error occurred.
    """
    commands = []
    total = len(string)
    cursor = 0
    location_prefix = "location "

    # Starlark has no `while`; every `$` command consumes at least two
    # characters, so `total` iterations are always enough.
    for _ in range(total):
        dollar = string.find("$", cursor)
        if dollar == -1:
            dollar = total

        # Literal text between the cursor and the next `$` (or the end).
        if dollar > cursor:
            commands.append(("string", string[cursor:dollar]))
        if dollar == total:
            break

        marker = string[dollar:dollar + 2]
        if marker == "$$":
            # Escaped dollar sign.
            commands.append(("string", "$"))
            cursor = dollar + 2
            continue
        if marker != "$(":
            return (None, "Unescaped '$' in location expansion at position {} of input.".format(dollar))

        # `$(...)` group: locate the closing parenthesis first.
        body_start = dollar + 2
        body_end = string.find(")", body_start)
        if body_end == -1:
            return (None, "Unbalanced parentheses in location expansion for '{}'.".format(string[dollar:]))

        body = string[body_start:body_end]
        if not body.startswith(location_prefix):
            return (None, "Unrecognized location expansion '$({})'.".format(body))
        commands.append(("location", body[len(location_prefix):]))
        cursor = body_end + 1

    return (commands, None)
def resolve_label(label_str, labels):
    """Look up the path of the known label that `label_str` refers to.

    A known label matches when resolving `label_str` relative to it yields
    that same label again.

    Attr:
      label_str: string, String representation of a label.
      labels: dict from String to path: Known label-string to path mappings.

    Returns:
      (path, error):
        path: path, The path to the resolved label, or None on error.
        error: string or None, This is set if there is no match or more than
          one match.
    """
    matches = []
    for (known_str, known_path) in labels.items():
        known_label = Label(known_str)
        if known_label.relative(label_str) == known_label:
            matches.append((known_str, known_path))

    if not matches:
        return (None, "Unknown label '{}' in location expansion.".format(label_str))
    if len(matches) > 1:
        candidates = ", ".join([str(known_str) for (known_str, _) in matches])
        return (None, "Ambiguous label '{}' in location expansion. Candidates: {}".format(
            label_str,
            candidates,
        ))

    (_, resolved_path) = matches[0]
    return (resolved_path, None)
def expand_location(repository_ctx, string, labels, attr = None):
    """Replace every `$(location label)` in a string with a relative path.

    Fails on unexpected occurrences of `$`.
    Use `$$` to insert a verbatim `$`.

    Attrs:
      repository_ctx: The repository rule context.
      string: string, Replace instances of `$(location )` in this string.
      labels: dict from string to path: Known label-string to path mappings.
      attr: string, The rule attribute to use for error reporting.

    Returns:
      The string with all instances of `$(location )` replaced by paths.
      Each path is made relative to `repository_ctx.path(".")` and prefixed
      with `./`.
    """
    (parsed, error) = parse_expand_location(string)
    if error != None:
        fail(error, attr)

    pieces = []
    for (command, argument) in parsed:
        if command == "string":
            pieces.append(argument)
            continue
        if command != "location":
            fail("Internal error: Unknown location expansion command '{}'.".format(command), attr)

        (path, error) = resolve_label(argument, labels)
        if error != None:
            fail(error, attr)

        # Express the resolved path relative to the repository root.
        absolute = str(repository_ctx.path(path))
        root = str(repository_ctx.path("."))
        pieces.append(paths.join(".", paths.relativize(absolute, root)))

    return "".join(pieces)
def is_bazel_version_at_least(threshold):
    """Check if the current Bazel version is higher than or equal to a threshold.

    Args:
      threshold: string: minimum desired version of Bazel

    Returns:
      threshold_met, from_source_version: bool, bool: tuple where
        the first item states whether the threshold was met, and the second
        is True when `versions.get()` returned an empty value (indicating a
        from-source build). In that case the first item is always False.
    """
    bazel_version = versions.get()
    if not bazel_version:
        # No version to compare against.
        return (False, True)
    return (versions.is_at_least(threshold, bazel_version), False)