Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Generate scala code from protobuf using @com_google_protobuf//:protoc #705

Closed
wants to merge 15 commits into from
Closed
116 changes: 113 additions & 3 deletions scala_proto/scala_proto.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -482,6 +482,9 @@ def _gen_proto_srcjar_impl(ctx):
srcjars = srcjarsattr,
)

"""
Deprecated: use scala_proto_gen instead
"""
scala_proto_srcjar = rule(
_gen_proto_srcjar_impl,
attrs = {
Expand Down Expand Up @@ -573,11 +576,12 @@ def scalapb_proto_library(
flags.append("flat_package")
if with_single_line_to_string:
flags.append("single_line_to_string")
scala_proto_srcjar(

_scalapb_proto_gen_with_jvm_deps(
name = srcjar,
flags = flags,
generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator",
deps = deps,
flags = flags,
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin",
visibility = visibility,
)

Expand All @@ -594,3 +598,109 @@ def scalapb_proto_library(
scalac_jvm_flags = scalac_jvm_flags,
visibility = visibility,
)

def _scala_proto_gen_attrs(deps_providers):
return {
"deps": attr.label_list(mandatory = True, providers = deps_providers),
"blacklisted_protos" : attr.label_list(providers = [["proto"]]),
"flags": attr.string_list(default = []),
"plugin": attr.label(executable = True, cfg = "host"),
"_protoc": attr.label(executable = True, cfg = "host", default = "@com_google_protobuf//:protoc")
}

_scala_proto_gen_outputs = {
"srcjar": "lib%{name}.srcjar",
}

def _scalapb_proto_gen_with_jvm_deps_impl(ctx):
jvm_deps = [p for p in ctx.attr.deps if hasattr(p, "proto") == False]

if "java_conversions" in ctx.attr.flags and len(jvm_deps) == 0:
fail("must have at least one jvm dependency if with_java is True (java_conversions is turned on)")

_scala_proto_gen_impl(ctx)

deps_jars = collect_jars(jvm_deps)

srcjarsattr = struct(srcjar = ctx.outputs.srcjar)
scalaattr = struct(
outputs = None,
compile_jars = deps_jars.compile_jars,
transitive_runtime_jars = deps_jars.transitive_runtime_jars,
)
java_provider = create_java_provider(scalaattr, depset())
return struct(
scala = scalaattr,
providers = [java_provider],
srcjars = srcjarsattr,
)

_scalapb_proto_gen_with_jvm_deps = rule(
_scalapb_proto_gen_with_jvm_deps_impl,
attrs = _scala_proto_gen_attrs([["proto"], [JavaInfo]]),
outputs = _scala_proto_gen_outputs,
)

def _strip_root(file, roots):
"""Strip first matching root which comes from proto_library(proto_source_root)
It assumes that proto_source_root are unique.
It should go away once generation is moved to aspects and roots can be handled for each proto_library individualy.
"""
for root in roots:
prefix = root + "/" if file.is_source else file.root.path + "/" + root + "/"
if file.path.startswith(prefix):
return file.path.replace(prefix, "")
return file.short_path

def _scala_proto_gen_impl(ctx):
protos = [p for p in ctx.attr.deps if hasattr(p, "proto")] # because scalapb_proto_library passes JavaInfo as well
descriptors = depset([f for dep in protos for f in dep.proto.transitive_descriptor_sets]).to_list()
sources = depset([f for dep in protos for f in dep.proto.transitive_sources]).to_list()
roots = depset([f for dep in protos for f in dep.proto.transitive_proto_path]).to_list()
inputs = depset([_strip_root(f, roots) for f in _retained_protos(sources, ctx.attr.blacklisted_protos)]).to_list()

srcdotjar = ctx.actions.declare_file("_" + ctx.label.name + "_src.jar")

ctx.actions.run(
inputs = [ctx.executable._protoc, ctx.executable.plugin] + descriptors,
outputs = [srcdotjar],
arguments = [
"--plugin=protoc-gen-scala=" + ctx.executable.plugin.path,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is going to prevent using workers right? So you have to spin up a new JVM for each file, right?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes it would prevent using workers. I thought this is what is happening now as well. Now I think my assumptions might be wrong as protoc-bridge is doing some tricks to reuse jvm. I'll dig more and update.

"--scala_out=%s:%s" % (",".join(ctx.attr.flags), srcdotjar.path),
"--descriptor_set_in=" + ":".join([descriptor.path for descriptor in descriptors])]
+ inputs,
executable = ctx.executable._protoc,
mnemonic = "ScalaProtoGen",
use_default_shell_env = True,
)

ctx.actions.run_shell(
command = "cp $1 $2",
inputs = [srcdotjar],
outputs = [ctx.outputs.srcjar],
arguments = [srcdotjar.path, ctx.outputs.srcjar.path])

"""Generates code with scala plugin passed to implicit @com_google_protobuf//:protoc

Example:
scala_proto_gen(
name = "a_proto_scala",
deps = [":a_proto"],
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin")

Args:
deps: List of proto_library rules to generate code for
blacklisted_protos: List of proto_library rules to exclude from protoc inputs
(used for libraries that comes from runtime like any.proto)
flags: list of plugin flags passed to --scala_out
plugin: an executable passed to --plugin=protoc-gen-scala= which implements protoc plugin contract
https://github.com/protocolbuffers/protobuf/blob/master/src/google/protobuf/compiler/plugin.proto

Outputs:
Single srcjar with generated sources for all deps and all the transitives
"""
scala_proto_gen = rule(
_scala_proto_gen_impl,
attrs = _scala_proto_gen_attrs(deps_providers = [["proto"]]),
outputs = _scala_proto_gen_outputs,
)
11 changes: 11 additions & 0 deletions src/scala/scripts/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,14 @@ scala_binary(
":scalapb_generator_lib",
],
)

scala_binary(
name = "scalapb_plugin",
srcs = ["ScalaPBPlugin.scala"],
main_class = "scripts.ScalaPBPlugin",
deps = [
"//external:io_bazel_rules_scala/dependency/proto/scalapb_plugin",
"//external:io_bazel_rules_scala/dependency/com_google_protobuf/protobuf_java",
],
visibility = ["//visibility:public"],
)
10 changes: 10 additions & 0 deletions src/scala/scripts/ScalaPBPlugin.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package scripts

import com.google.protobuf.compiler.PluginProtos.CodeGeneratorRequest.parseFrom
import scalapb.compiler.ProtobufGenerator.handleCodeGeneratorRequest

object ScalaPBPlugin extends App {

handleCodeGeneratorRequest(parseFrom(System.in)).writeTo(System.out)

}
11 changes: 6 additions & 5 deletions test/proto/BUILD
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
load(
"//scala_proto:scala_proto.bzl",
"scalapb_proto_library",
"scala_proto_srcjar"
"scala_proto_srcjar",
"scala_proto_gen",
)

load(
Expand Down Expand Up @@ -96,16 +97,16 @@ scalapb_proto_library(
deps = [":test_service"],
)

scala_proto_srcjar(
scala_proto_gen(
name = "test1_proto_scala",
deps = ["//test/proto2:test"],
generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator")
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin")

scala_proto_srcjar(
scala_proto_gen(
name = "test2_proto_scala_with_blacklisted_test1_proto_scala",
deps = [":test2"],
blacklisted_protos = ["//test/proto2:test"],
generator = "@io_bazel_rules_scala//src/scala/scripts:scalapb_generator")
plugin = "@io_bazel_rules_scala//src/scala/scripts:scalapb_plugin")

scala_library(
name = "lib_scala_should_fail_on_duplicated_sources_unless_duplicates_are_blacklisted",
Expand Down