Skip to content

Commit

Permalink
DRY when handling Scala version values (#20264)
Browse files Browse the repository at this point in the history
Refactor into using a new `ScalaVersion` data class when handling Scala versions.
  • Loading branch information
alonsodomin authored Dec 8, 2023
1 parent ee40ca3 commit 5f0bc44
Show file tree
Hide file tree
Showing 15 changed files with 199 additions and 116 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ async def resolve_scalapb_runtime_for_resolve(
scalapb: ScalaPBSubsystem,
) -> ScalaPBRuntimeForResolve:
scala_version = scala_subsystem.version_for_resolve(request.resolve_name)
# TODO: Does not handle Scala 3 suffix which is just `_3` nor X.Y.Z versions.
scala_binary_version, _, _ = scala_version.rpartition(".")
scala_binary_version = scala_version.binary
version = scalapb.version

addresses = find_jvm_artifacts_or_raise(
Expand Down
32 changes: 12 additions & 20 deletions src/python/pants/backend/codegen/protobuf/scala/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@
ProtobufSourceTarget,
)
from pants.backend.scala.target_types import ScalaSourceField
from pants.backend.scala.util_rules.versions import (
ScalaArtifactsForVersionRequest,
ScalaArtifactsForVersionResult,
ScalaVersion,
)
from pants.core.goals.generate_lockfiles import GenerateToolLockfileSentinel
from pants.core.util_rules import distdir
from pants.core.util_rules.external_tool import DownloadedExternalTool, ExternalToolRequest
Expand Down Expand Up @@ -46,7 +51,7 @@
from pants.jvm.dependency_inference import artifact_mapper
from pants.jvm.goals import lockfile
from pants.jvm.jdk_rules import InternalJdk, JvmProcess
from pants.jvm.resolve.common import ArtifactRequirements, Coordinate, GatherJvmCoordinatesRequest
from pants.jvm.resolve.common import ArtifactRequirements, GatherJvmCoordinatesRequest
from pants.jvm.resolve.coursier_fetch import ToolClasspath, ToolClasspathRequest
from pants.jvm.resolve.jvm_tool import GenerateJvmLockfileFromTool, GenerateJvmToolLockfileSentinel
from pants.jvm.target_types import PrefixedJvmJdkField, PrefixedJvmResolveField
Expand Down Expand Up @@ -234,7 +239,7 @@ async def materialize_jvm_plugins(
return MaterializedJvmPlugins(merged_plugins_digest, materialized_plugins)


SHIM_SCALA_VERSION = "2.13.7"
SHIM_SCALA_VERSION = ScalaVersion.parse("2.13.7")


# TODO(13879): Consolidate compilation of wrapper binaries to common rules.
Expand All @@ -253,30 +258,17 @@ async def setup_scalapb_shim_classfiles(

scalapb_shim_source = FileContent("ScalaPBShim.scala", scalapb_shim_content)

lockfile_request = await Get(GenerateJvmLockfileFromTool, ScalapbcToolLockfileSentinel())
lockfile_request, scala_artifacts = await MultiGet(
Get(GenerateJvmLockfileFromTool, ScalapbcToolLockfileSentinel()),
Get(ScalaArtifactsForVersionResult, ScalaArtifactsForVersionRequest(SHIM_SCALA_VERSION)),
)
tool_classpath, shim_classpath, source_digest = await MultiGet(
Get(
ToolClasspath,
ToolClasspathRequest(
prefix="__toolcp",
artifact_requirements=ArtifactRequirements.from_coordinates(
[
Coordinate(
group="org.scala-lang",
artifact="scala-compiler",
version=SHIM_SCALA_VERSION,
),
Coordinate(
group="org.scala-lang",
artifact="scala-library",
version=SHIM_SCALA_VERSION,
),
Coordinate(
group="org.scala-lang",
artifact="scala-reflect",
version=SHIM_SCALA_VERSION,
),
]
scala_artifacts.all_coordinates
),
),
),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def resolve_scrooge_thrift_java_runtime_for_resolve(
scala_subsystem: ScalaSubsystem,
) -> ScroogeThriftJavaRuntimeForResolve:
scala_version = scala_subsystem.version_for_resolve(request.resolve_name)
scala_binary_version, _, _ = scala_version.rpartition(".")
scala_binary_version = scala_version.binary
addresses = find_jvm_artifacts_or_raise(
required_coordinates=[
UnversionedCoordinate(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def resolve_scrooge_thrift_scala_runtime_for_resolve(
scala_subsystem: ScalaSubsystem,
) -> ScroogeThriftScalaRuntimeForResolve:
scala_version = scala_subsystem.version_for_resolve(request.resolve_name)
scala_binary_version, _, _ = scala_version.rpartition(".")
scala_binary_version = scala_version.binary
addresses = find_jvm_artifacts_or_raise(
required_coordinates=[
UnversionedCoordinate(
Expand Down
13 changes: 4 additions & 9 deletions src/python/pants/backend/scala/bsp/rules.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from pants.backend.scala.util_rules.versions import (
ScalaArtifactsForVersionRequest,
ScalaArtifactsForVersionResult,
ScalaVersion,
)
from pants.base.build_root import BuildRoot
from pants.bsp.protocol import BSPHandlerMapping
Expand Down Expand Up @@ -158,7 +159,7 @@ async def collect_thirdparty_modules(
)


async def _materialize_scala_runtime_jars(scala_version: str) -> Snapshot:
async def _materialize_scala_runtime_jars(scala_version: ScalaVersion) -> Snapshot:
scala_artifacts = await Get(
ScalaArtifactsForVersionResult, ScalaArtifactsForVersionRequest(scala_version)
)
Expand Down Expand Up @@ -283,17 +284,11 @@ async def bsp_resolve_scala_metadata(
java_version=f"1.{jdk.jre_major_version}",
)

scala_version_parts = scala_version.split(".")
scala_binary_version = (
".".join(scala_version_parts[0:2])
if int(scala_version_parts[0]) < 3
else scala_version_parts[0]
)
return BSPBuildTargetsMetadataResult(
metadata=ScalaBuildTarget(
scala_organization="org.scala-lang",
scala_version=scala_version,
scala_binary_version=scala_binary_version,
scala_version=str(scala_version),
scala_binary_version=scala_version.binary,
platform=ScalaPlatform.JVM,
jars=scala_jar_uris,
jvm_build_target=jvm_build_target,
Expand Down
3 changes: 2 additions & 1 deletion src/python/pants/backend/scala/compile/scalac.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from pants.backend.scala.util_rules.versions import (
ScalaArtifactsForVersionRequest,
ScalaArtifactsForVersionResult,
ScalaVersion,
)
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.fs import EMPTY_DIGEST, Digest, MergeDigests
Expand Down Expand Up @@ -58,7 +59,7 @@ class CompileScalaSourceRequest(ClasspathEntryRequest):

@dataclass(frozen=True)
class ScalaLibraryRequest:
version: str
version: ScalaVersion


# TODO: This code is duplicated in the scalac and BSP rules.
Expand Down
45 changes: 19 additions & 26 deletions src/python/pants/backend/scala/dependency_inference/scala_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,11 @@

from pants.backend.scala.subsystems.scala import ScalaSubsystem
from pants.backend.scala.subsystems.scalac import Scalac
from pants.backend.scala.util_rules.versions import (
ScalaArtifactsForVersionRequest,
ScalaArtifactsForVersionResult,
ScalaVersion,
)
from pants.core.goals.generate_lockfiles import DEFAULT_TOOL_LOCKFILE, GenerateToolLockfileSentinel
from pants.core.util_rules.source_files import SourceFiles, SourceFilesRequest
from pants.engine.fs import (
Expand All @@ -30,7 +35,7 @@
from pants.jvm.compile import ClasspathEntry
from pants.jvm.jdk_rules import InternalJdk, JvmProcess
from pants.jvm.jdk_rules import rules as jdk_rules
from pants.jvm.resolve.common import ArtifactRequirements, Coordinate
from pants.jvm.resolve.common import ArtifactRequirements
from pants.jvm.resolve.coursier_fetch import ToolClasspath, ToolClasspathRequest
from pants.jvm.resolve.jvm_tool import GenerateJvmLockfileFromTool, GenerateJvmToolLockfileSentinel
from pants.jvm.subsystems import JvmSubsystem
Expand All @@ -43,8 +48,8 @@
logger = logging.getLogger(__name__)


_PARSER_SCALA_VERSION = "2.13.8"
_PARSER_SCALA_BINARY_VERSION = _PARSER_SCALA_VERSION.rpartition(".")[0]
_PARSER_SCALA_VERSION = ScalaVersion.parse("2.13.8")
_PARSER_SCALA_BINARY_VERSION = _PARSER_SCALA_VERSION.binary


class ScalaParserToolLockfileSentinel(GenerateJvmToolLockfileSentinel):
Expand Down Expand Up @@ -194,7 +199,7 @@ class ScalaParserCompiledClassfiles(ClasspathEntry):
@dataclass(frozen=True)
class AnalyzeScalaSourceRequest:
source_files: SourceFiles
scala_version: str
scala_version: ScalaVersion
source3: bool


Expand Down Expand Up @@ -274,7 +279,7 @@ async def analyze_scala_source_dependencies(
"org.pantsbuild.backend.scala.dependency_inference.ScalaParser",
analysis_output_path,
source_path,
request.scala_version,
str(request.scala_version),
str(request.source3),
],
input_digest=prefixed_source_files_digest,
Expand Down Expand Up @@ -319,8 +324,9 @@ async def setup_scala_parser_classfiles(jdk: InternalJdk) -> ScalaParserCompiled

parser_source = FileContent("ScalaParser.scala", parser_source_content)

parser_lockfile_request = await Get(
GenerateJvmLockfileFromTool, ScalaParserToolLockfileSentinel()
parser_lockfile_request, scala_artifacts = await MultiGet(
Get(GenerateJvmLockfileFromTool, ScalaParserToolLockfileSentinel()),
Get(ScalaArtifactsForVersionResult, ScalaArtifactsForVersionRequest(_PARSER_SCALA_VERSION)),
)

tool_classpath, parser_classpath, source_digest = await MultiGet(
Expand All @@ -329,23 +335,7 @@ async def setup_scala_parser_classfiles(jdk: InternalJdk) -> ScalaParserCompiled
ToolClasspathRequest(
prefix="__toolcp",
artifact_requirements=ArtifactRequirements.from_coordinates(
[
Coordinate(
group="org.scala-lang",
artifact="scala-compiler",
version=_PARSER_SCALA_VERSION,
),
Coordinate(
group="org.scala-lang",
artifact="scala-library",
version=_PARSER_SCALA_VERSION,
),
Coordinate(
group="org.scala-lang",
artifact="scala-reflect",
version=_PARSER_SCALA_VERSION,
),
]
scala_artifacts.all_coordinates
),
),
),
Expand Down Expand Up @@ -397,15 +387,18 @@ async def setup_scala_parser_classfiles(jdk: InternalJdk) -> ScalaParserCompiled


@rule
def generate_scala_parser_lockfile_request(
async def generate_scala_parser_lockfile_request(
_: ScalaParserToolLockfileSentinel,
) -> GenerateJvmLockfileFromTool:
scala_artifacts = await Get(
ScalaArtifactsForVersionResult, ScalaArtifactsForVersionRequest(_PARSER_SCALA_VERSION)
)
return GenerateJvmLockfileFromTool(
artifact_inputs=FrozenOrderedSet(
{
f"org.scalameta:scalameta_{_PARSER_SCALA_BINARY_VERSION}:4.8.7",
f"io.circe:circe-generic_{_PARSER_SCALA_BINARY_VERSION}:0.14.1",
f"org.scala-lang:scala-library:{_PARSER_SCALA_VERSION}",
scala_artifacts.library_coordinate.to_coord_str(),
}
),
artifact_option_name="n/a",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
ScalaSourceDependencyAnalysis,
)
from pants.backend.scala.target_types import ScalaSourceField, ScalaSourceTarget
from pants.backend.scala.util_rules import versions
from pants.build_graph.address import Address
from pants.core.util_rules import source_files
from pants.core.util_rules.source_files import SourceFilesRequest
Expand All @@ -37,6 +38,7 @@ def rule_runner() -> RuleRunner:
*target_types.rules(),
*jvm_util_rules.rules(),
*process.rules(),
*versions.rules(),
QueryRule(AnalyzeScalaSourceRequest, (SourceFilesRequest,)),
QueryRule(ScalaSourceDependencyAnalysis, (AnalyzeScalaSourceRequest,)),
],
Expand Down
4 changes: 2 additions & 2 deletions src/python/pants/backend/scala/resolve/lockfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,9 @@ async def validate_scala_runtime_is_present_in_resolve(
artifact.coordinate.group == SCALA_LIBRARY_GROUP
and artifact.coordinate.artifact == scala_artifacts.library_coordinate.artifact
):
if artifact.coordinate.version != scala_version:
if artifact.coordinate.version != str(scala_version):
raise ConflictingScalaLibraryVersionInResolveError(
request.resolve_name, scala_version, artifact.coordinate
request.resolve_name, str(scala_version), artifact.coordinate
)

# This does not `break` so the loop can validate the entire set of requirements to ensure no conflicting
Expand Down
11 changes: 4 additions & 7 deletions src/python/pants/backend/scala/subsystems/scala.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,12 @@

from __future__ import annotations

import logging

from pants.backend.scala.util_rules.versions import ScalaVersion
from pants.option.option_types import BoolOption, DictOption
from pants.option.subsystem import Subsystem
from pants.util.strutil import softwrap

DEFAULT_SCALA_VERSION = "2.13.6"

_logger = logging.getLogger(__name__)
DEFAULT_SCALA_VERSION = ScalaVersion.parse("2.13.6")


class ScalaSubsystem(Subsystem):
Expand Down Expand Up @@ -42,8 +39,8 @@ class ScalaSubsystem(Subsystem):
advanced=True,
)

def version_for_resolve(self, resolve: str) -> str:
def version_for_resolve(self, resolve: str) -> ScalaVersion:
version = self._version_for_resolve.get(resolve)
if version:
return version
return ScalaVersion.parse(version)
return DEFAULT_SCALA_VERSION
Loading

0 comments on commit 5f0bc44

Please sign in to comment.