Skip to content

Commit

Permalink
Use ctx.actions.args() in compile_scala (#1124)
Browse files Browse the repository at this point in the history
  • Loading branch information
simuons authored Feb 23, 2021
1 parent 8c2f294 commit 4130036
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 127 deletions.
159 changes: 43 additions & 116 deletions scala/private/rule_impls.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,6 @@ def expand_location(ctx, flags):
data = []
return [ctx.expand_location(f, data) for f in flags]

def _join_path(args, sep = ","):
return sep.join([f.path for f in args])

# Return the first non-empty arg. If all are empty, return the last.
def first_non_empty(*args):
for arg in args:
Expand Down Expand Up @@ -64,131 +61,62 @@ def compile_scala(
# look for any plugins:
input_plugins = plugins
plugins = _collect_plugin_paths(plugins)
internal_plugin_jars = []
compiler_classpath_jars = cjars
if dependency_info.dependency_mode != "direct":
compiler_classpath_jars = transitive_compile_jars
optional_scalac_args = ""
classpath_resources = []
if (hasattr(ctx.files, "classpath_resources")):
classpath_resources = ctx.files.classpath_resources

optional_scalac_args_map = {}

if dependency_info.use_analyzer:
dep_plugin = ctx.attr._dependency_analyzer_plugin
plugins = depset(transitive = [plugins, dep_plugin.files])
internal_plugin_jars = ctx.files._dependency_analyzer_plugin

if dependency_info.need_indirect_info:
transitive_cjars_list = transitive_compile_jars.to_list()
indirect_jars = _join_path(transitive_cjars_list)
indirect_targets = ",".join([str(labels[j.path]) for j in transitive_cjars_list])

optional_scalac_args_map["IndirectJars"] = indirect_jars
optional_scalac_args_map["IndirectTargets"] = indirect_targets

if dependency_info.unused_deps_mode != "off":
ignored_targets = ",".join([str(d) for d in unused_dependency_checker_ignored_targets])
optional_scalac_args_map["UnusedDepsIgnoredTargets"] = ignored_targets

if dependency_info.need_direct_info:
cjars_list = cjars.to_list()
if dependency_info.need_direct_jars:
direct_jars = _join_path(cjars_list)
optional_scalac_args_map["DirectJars"] = direct_jars
if dependency_info.need_direct_targets:
direct_targets = ",".join([str(labels[j.path]) for j in cjars_list])
optional_scalac_args_map["DirectTargets"] = direct_targets

optional_scalac_args = "\n".join([
"{k}: {v}".format(k = k, v = v)
# We sort the arguments for input stability and reproducibility
for (k, v) in sorted(optional_scalac_args_map.items())
])

plugins_list = plugins.to_list()
plugin_arg = _join_path(plugins_list)

separator = ctx.configuration.host_path_separator
compiler_classpath = _join_path(compiler_classpath_jars.to_list(), separator)
plugins = depset(transitive = [plugins, ctx.attr._dependency_analyzer_plugin.files])

toolchain = ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"]
compiler_classpath_jars = cjars if dependency_info.dependency_mode == "direct" else transitive_compile_jars
classpath_resources = getattr(ctx.files, "classpath_resources", [])
scalacopts = [ctx.expand_location(v, input_plugins) for v in toolchain.scalacopts + in_scalacopts]
resource_paths = _resource_paths(resources, resource_strip_prefix)
enable_diagnostics_report = toolchain.enable_diagnostics_report

scalac_args = """
CurrentTarget: {current_target}
Classpath: {cp}
ClasspathResourceSrcs: {classpath_resource_src}
Files: {files}
JarOutput: {out}
Manifest: {manifest}
Plugins: {plugin_arg}
PrintCompileTime: {print_compile_time}
ExpectJavaOutput: {expect_java_output}
ResourceTargets: {resource_targets}
ResourceSources: {resource_sources}
ResourceJars: {resource_jars}
ScalacOpts: {scala_opts}
SourceJars: {srcjars}
StrictDepsMode: {strict_deps_mode}
UnusedDependencyCheckerMode: {unused_dependency_checker_mode}
DependencyTrackingMethod: {dependency_tracking_method}
StatsfileOutput: {statsfile_output}
EnableDiagnosticsReport: {enable_diagnostics_report}
DiagnosticsFile: {diagnostics_output}
""".format(
current_target = str(target_label),
out = output.path,
manifest = manifest.path,
# Using ':::' as delimiter because ',' can collide with actual scalac options
# https://github.com/bazelbuild/rules_scala/issues/1049
scala_opts = ":::".join(scalacopts),
print_compile_time = print_compile_time,
expect_java_output = expect_java_output,
plugin_arg = plugin_arg,
cp = compiler_classpath,
classpath_resource_src = _join_path(classpath_resources),
files = _join_path(sources),
srcjars = _join_path(all_srcjars.to_list()),
# the resource paths need to be aligned in order
resource_targets = ",".join([p[0] for p in resource_paths]),
resource_sources = ",".join([p[1] for p in resource_paths]),
resource_jars = _join_path(resource_jars),
strict_deps_mode = dependency_info.strict_deps_mode,
unused_dependency_checker_mode = dependency_info.unused_deps_mode,
dependency_tracking_method = dependency_info.dependency_tracking_method,
statsfile_output = statsfile.path,
enable_diagnostics_report = enable_diagnostics_report,
diagnostics_output = diagnosticsfile.path,
)
args = ctx.actions.args()
args.set_param_file_format("multiline")
args.use_param_file(param_file_arg = "@%s", use_always = True)
args.add("--CurrentTarget", target_label)
args.add("--JarOutput", output)
args.add("--Manifest", manifest)
args.add("--PrintCompileTime", print_compile_time)
args.add("--ExpectJavaOutput", expect_java_output)
args.add("--StrictDepsMode", dependency_info.strict_deps_mode)
args.add("--UnusedDependencyCheckerMode", dependency_info.unused_deps_mode)
args.add("--DependencyTrackingMethod", dependency_info.dependency_tracking_method)
args.add("--StatsfileOutput", statsfile)
args.add("--EnableDiagnosticsReport", enable_diagnostics_report)
args.add("--DiagnosticsFile", diagnosticsfile)
args.add_joined("--Classpath", compiler_classpath_jars, join_with = ctx.configuration.host_path_separator)
args.add_joined("--ClasspathResourceSrcs", classpath_resources, join_with = ",")
args.add_joined("--Files", sources, join_with = ",")
args.add_joined("--Plugins", plugins, join_with = ",")
args.add_joined("--ResourceTargets", [p[0] for p in resource_paths], join_with = ",")
args.add_joined("--ResourceSources", [p[1] for p in resource_paths], join_with = ",")
args.add_joined("--ResourceJars", resource_jars, join_with = ",")
args.add_joined("--ScalacOpts", scalacopts, join_with = ":::")
args.add_joined("--SourceJars", all_srcjars, join_with = ",")

argfile = ctx.actions.declare_file(
"%s_scalac_worker_input" % target_label.name,
sibling = output,
)
if dependency_info.need_direct_info:
if dependency_info.need_direct_jars:
args.add_joined("--DirectJars", cjars, join_with = ",")
if dependency_info.need_direct_targets:
args.add_joined("--DirectTargets", [labels[j.path] for j in cjars.to_list()], join_with = ",")

ctx.actions.write(
output = argfile,
content = scalac_args + optional_scalac_args,
)
if dependency_info.need_indirect_info:
args.add_joined("--IndirectJars", transitive_compile_jars, join_with = ",")
args.add_joined("--IndirectTargets", [labels[j.path] for j in transitive_compile_jars.to_list()], join_with = ",")

if dependency_info.unused_deps_mode != "off":
args.add_joined("--UnusedDepsIgnoredTargets", unused_dependency_checker_ignored_targets, join_with = ",")

outs = [output, statsfile, diagnosticsfile]

ins = (
compiler_classpath_jars.to_list() + all_srcjars.to_list() + list(sources) +
plugins_list + internal_plugin_jars + classpath_resources + resources +
resource_jars + [manifest, argfile]
ins = depset(
direct = [manifest] + sources + classpath_resources + resources + resource_jars,
transitive = [compiler_classpath_jars, all_srcjars, plugins],
)

# scalac_jvm_flags passed in on the target override scalac_jvm_flags passed in on the
# toolchain
final_scalac_jvm_flags = first_non_empty(
scalac_jvm_flags,
ctx.toolchains["@io_bazel_rules_scala//scala:toolchain_type"].scalac_jvm_flags,
)
# scalac_jvm_flags passed in on the target override scalac_jvm_flags passed in on the toolchain
final_scalac_jvm_flags = first_non_empty(scalac_jvm_flags, toolchain.scalac_jvm_flags)

ctx.actions.run(
inputs = ins,
Expand All @@ -208,8 +136,7 @@ DiagnosticsFile: {diagnostics_output}
arguments = [
"--jvm_flag=%s" % f
for f in expand_location(ctx, final_scalac_jvm_flags)
] + ["@" + argfile.path],
# diagnostics_file = diagnosticsfile TODO: add diagnostics_file argument whenever this argument is supported: https://github.com/bazelbuild/rules_scala/issues/1215
] + [args],
)

def compile_java(ctx, source_jars, source_files, output, extra_javac_opts, providers_of_dependencies):
Expand Down
16 changes: 5 additions & 11 deletions src/java/io/bazel/rulesscala/scalac/CompileOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,12 @@ private static List<Resource> getResources(Map<String, String> args) {
return resources;
}

private static HashMap<String, String> buildArgMap(List<String> lines) {
HashMap hm = new HashMap();
for (String line : lines) {
String[] lSplit = line.split(": ");
if (lSplit.length > 2) {
throw new RuntimeException("Bad arg, should have at most 1 space/2 spans. arg: " + line);
}
if (lSplit.length > 1) {
hm.put(lSplit[0], lSplit[1]);
}
private static Map<String, String> buildArgMap(List<String> lines) {
Map<String, String> args = new HashMap<>();
for (int i = 0; i < lines.size(); i += 2) {
args.put(lines.get(i).substring(2), lines.get(i + 1));
}
return hm;
return args;
}

protected static String[] getTripleColonList(Map<String, String> m, String k) {
Expand Down

0 comments on commit 4130036

Please sign in to comment.