diff --git a/scala/private/paths.bzl b/scala/private/paths.bzl new file mode 100644 index 000000000..09b916d4f --- /dev/null +++ b/scala/private/paths.bzl @@ -0,0 +1,12 @@ +java_extension = ".java" + +scala_extension = ".scala" + +srcjar_extension = ".srcjar" + +def get_files_with_extension(ctx, extension): + return [ + f + for f in ctx.files.srcs + if f.basename.endswith(extension) + ] diff --git a/scala/private/phases/phase_compile.bzl b/scala/private/phases/phase_compile.bzl index e98946899..1e730a574 100644 --- a/scala/private/phases/phase_compile.bzl +++ b/scala/private/phases/phase_compile.bzl @@ -9,6 +9,13 @@ load( "@io_bazel_rules_scala//scala/private:coverage_replacements_provider.bzl", _coverage_replacements_provider = "coverage_replacements_provider", ) +load( + "@io_bazel_rules_scala//scala/private:paths.bzl", + _get_files_with_extension = "get_files_with_extension", + _java_extension = "java_extension", + _scala_extension = "scala_extension", + _srcjar_extension = "srcjar_extension", +) load( "@io_bazel_rules_scala//scala/private:rule_impls.bzl", _compile_scala = "compile_scala", @@ -16,12 +23,6 @@ load( ) load(":resources.bzl", _resource_paths = "paths") -_java_extension = ".java" - -_scala_extension = ".scala" - -_srcjar_extension = ".srcjar" - _empty_coverage_struct = struct( external = struct( replacements = {}, @@ -197,29 +198,17 @@ def _compile_or_empty( merged_provider = scala_compilation_provider, ) else: - in_srcjars = [ - f - for f in ctx.files.srcs - if f.basename.endswith(_srcjar_extension) - ] + java_srcs = _get_files_with_extension(ctx, _java_extension) + scala_srcs = _get_files_with_extension(ctx, _scala_extension) + in_srcjars = _get_files_with_extension(ctx, _srcjar_extension) all_srcjars = depset(in_srcjars, transitive = [srcjars]) - java_srcs = [ - f - for f in ctx.files.srcs - if f.basename.endswith(_java_extension) - ] - # We are not able to verify whether dependencies are used when compiling java sources # Thus we disable unused dependency checking when java sources are found if len(java_srcs) != 0: unused_dependency_checker_mode = "off" - sources = [ - f - for f in ctx.files.srcs - if f.basename.endswith(_scala_extension) - ] + java_srcs + sources = scala_srcs + java_srcs _compile_scala( ctx, ctx.label, @@ -258,7 +247,7 @@ def _compile_or_empty( # so set ijar == jar ijar = ctx.outputs.jar - source_jar = _pack_source_jar(ctx) + source_jar = _pack_source_jar(ctx, scala_srcs, in_srcjars) scala_compilation_provider = _create_scala_compilation_provider(ctx, ijar, source_jar, deps_providers) # compile the java now @@ -339,31 +328,16 @@ def _create_scala_compilation_provider(ctx, ijar, source_jar, deps_providers): runtime_deps = runtime_deps, ) -def _pack_source_jar(ctx): - # collect .scala sources and pack a source jar for Scala - scala_sources = [ - f - for f in ctx.files.srcs - if f.basename.endswith(_scala_extension) - ] - - # collect .srcjar files and pack them with the scala sources - bundled_source_jars = [ - f - for f in ctx.files.srcs - if f.basename.endswith(_srcjar_extension) - ] - scala_source_jar = java_common.pack_sources( +def _pack_source_jar(ctx, scala_srcs, in_srcjars): + return java_common.pack_sources( ctx.actions, output_jar = ctx.outputs.jar, - sources = scala_sources, - source_jars = bundled_source_jars, + sources = scala_srcs, + source_jars = in_srcjars, java_toolchain = find_java_toolchain(ctx, ctx.attr._java_toolchain), host_javabase = find_java_runtime_toolchain(ctx, ctx.attr._host_javabase), ) - return scala_source_jar - def _jacoco_offline_instrument(ctx, input_jar): if not ctx.configuration.coverage_enabled or not hasattr(ctx.attr, "_code_coverage_instrumentation_worker"): return _empty_coverage_struct diff --git a/scala/private/phases/phase_scalafmt.bzl b/scala/private/phases/phase_scalafmt.bzl index 5e8284c35..a42410980 100644 --- a/scala/private/phases/phase_scalafmt.bzl +++ b/scala/private/phases/phase_scalafmt.bzl @@ -3,6 +3,11 @@ # # Outputs to format the scala files when it is explicitly specified # +load( + "@io_bazel_rules_scala//scala/private:paths.bzl", + _scala_extension = "scala_extension", +) + def phase_scalafmt(ctx, p): if ctx.attr.format: manifest, files = _build_format(ctx) @@ -17,7 +22,7 @@ def _build_format(ctx): manifest_content = [] for src in ctx.files.srcs: # only format scala source files, not generated files - if src.path.endswith(".scala") and src.is_source: + if src.path.endswith(_scala_extension) and src.is_source: file = ctx.actions.declare_file("{}.fmt.output".format(src.short_path)) files.append(file) ctx.actions.run(