Skip to content

Commit

Permalink
Make stack_snapshot usable in another module's extension
Browse files Browse the repository at this point in the history
With bzlmod, when using `stack_snapshot` in a module extension of another module,
rules_haskell does not have visibility of generated repositories (such as `rules_haskell_stack`) and cannot build labels to its targets using the `Label` constructor.

To work around this, an optional `label_builder` parameter can be passed to `stack_snapshot`, which is then used to canonicalize various labels.
For the same reason, the `stack_update` attribute to `_stack_snapshot` is now always present (so we do not have to hardcode its label), and is made into a string so that we only fetch the `stack_update` repository when explicitly reading from it.
  • Loading branch information
ylecornec committed Jun 8, 2023
1 parent f5d8e9f commit 9c321ce
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 10 deletions.
35 changes: 25 additions & 10 deletions haskell/cabal.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -1536,8 +1536,7 @@ def _download_packages(repository_ctx, snapshot, pinned):

if stack_unpack:
# Enforce dependency on stack_update.
# Hard-coded label since `stack_update` is `None` in this case.
repository_ctx.read(Label("@rules_haskell_stack_update//:stack_update"))
repository_ctx.read(repository_ctx.path(Label(repository_ctx.attr.stack_update)))
_download_packages_unpinned(repository_ctx, snapshot, stack_unpack)

def _download_packages_unpinned(repository_ctx, snapshot, resolved):
Expand Down Expand Up @@ -1586,6 +1585,9 @@ def _to_string_keyed_label_list_dict(d):
out.setdefault(string_key, []).append(label)
return out

def _is_bzlmod_enabled():
return str(Label("@rules_haskell//:BUILD.bazel")).startswith("@@")

def _label_to_string(label):
if check_bazel_version("6.0.0")[0]:
# `str` serializes the label to its canonical name starting from bazel 6
Expand Down Expand Up @@ -1913,7 +1915,7 @@ def _stack_snapshot_impl(repository_ctx):
# Resolve and fetch packages
if repository_ctx.attr.stack_snapshot_json == None:
# Enforce dependency on stack_update
repository_ctx.read(repository_ctx.attr.stack_update)
repository_ctx.read(repository_ctx.path(Label(repository_ctx.attr.stack_update)))
resolved = _resolve_packages(
repository_ctx,
snapshot,
Expand Down Expand Up @@ -1986,7 +1988,7 @@ packages = {
executables = all_components[name].exe,
sublibs = all_components[name].sublibs,
deps = [
str(Label("@{}//:{}".format(repository_ctx.attr.unmangled_repo_name, dep)))
"@{}//:{}".format(repository_ctx.attr.unmangled_repo_name, dep)
for dep in spec["dependencies"]
if all_components[dep].lib
],
Expand Down Expand Up @@ -2069,8 +2071,13 @@ haskell_library(
for dep in spec["dependencies"]
for exe in all_components[dep].exe
] + tools

setup_deps = [
_label_to_string(Label("@{}//:{}".format(repository_ctx.attr.unmangled_repo_name, name)).relative(label))
_label_to_string(Label("{}{}//:{}".format(
"@@" if _is_bzlmod_enabled() else "@",
repository_ctx.attr.name,
name,
)).relative(label))
for label in repository_ctx.attr.setup_deps.get(name, [])
]
if all_components[name].lib:
Expand Down Expand Up @@ -2223,7 +2230,7 @@ _stack_snapshot = repository_rule(
"components": attr.string_list_dict(),
"components_dependencies": attr.string_dict(),
"stack": attr.label(),
"stack_update": attr.label(),
"stack_update": attr.string(),
"verbose": attr.bool(default = False),
"custom_toolchain_libraries": attr.string_list(default = []),
"enable_custom_toolchain_libraries": attr.bool(default = False),
Expand Down Expand Up @@ -2434,6 +2441,7 @@ def stack_snapshot(
netrc = "",
toolchain_libraries = None,
setup_stack = True,
label_builder = lambda l: Label(l),
**kwargs):
"""Use Stack to download and extract Cabal source distributions.
Expand Down Expand Up @@ -2631,11 +2639,11 @@ def stack_snapshot(
# Allow overriding stack binary at workspace level by `use_stack()`.
# Otherwise this is a no-op.
if native.existing_rule("rules_haskell_stack") or not setup_stack:
stack = Label("@rules_haskell_stack//:stack")
stack = label_builder("@rules_haskell_stack//:stack")

if not stack:
_fetch_stack(name = "rules_haskell_stack")
stack = Label("@rules_haskell_stack//:stack")
stack = label_builder("@rules_haskell_stack//:stack")

# Execute stack update once before executing _stack_snapshot.
# This is to avoid multiple concurrent executions of stack update,
Expand Down Expand Up @@ -2663,12 +2671,19 @@ def stack_snapshot(
custom_toolchain_libraries = toolchain_libraries,
enable_custom_toolchain_libraries = toolchain_libraries != None,
)
canonical_setup_deps = {
k: [
str(label_builder(label)) if label.startswith("@") else label
for label in labels
]
for (k, labels) in setup_deps.items()
}
_stack_snapshot(
name = name,
unmangled_repo_name = name,
stack = stack,
# Dependency for ordered execution, stack update before stack unpack.
stack_update = None if stack_snapshot_json else "@rules_haskell_stack_update//:stack_update",
stack_update = str(label_builder("@rules_haskell_stack_update//:stack_update")),
# TODO Remove _from_string_keyed_label_list_dict once following issue
# is resolved: https://github.com/bazelbuild/bazel/issues/7989.
extra_deps = _from_string_keyed_label_list_dict(extra_deps),
Expand All @@ -2681,7 +2696,7 @@ def stack_snapshot(
packages = packages,
flags = flags,
haddock = haddock,
setup_deps = setup_deps,
setup_deps = canonical_setup_deps,
tools = tools,
components = components,
components_dependencies = components_dependencies,
Expand Down
7 changes: 7 additions & 0 deletions rules_haskell_tests/non_module_deps_2.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@ load("@os_info//:os_info.bzl", "is_linux", "is_windows")
load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive")
load("@toolchains_libraries//:toolchain_libraries.bzl", "toolchain_libraries")

label_builder = lambda x: Label(x)

def repositories(*, bzlmod):
# In a separate repo because not all platforms support zlib.
stack_snapshot(
Expand All @@ -13,6 +15,7 @@ def repositories(*, bzlmod):
local_snapshot = "//:stackage_snapshot.yaml",
packages = ["zlib"],
stack_snapshot_json = "//:stackage-zlib-snapshot.json" if not is_windows else None,
label_builder = label_builder,
)

# Vendor data-default-instances-containers and data-default-instances-old-local
Expand Down Expand Up @@ -121,6 +124,7 @@ haskell_library(
"data-default-instances-old-locale": "@data-default-ol//:lib",
"ghc-paths": "@rules_haskell//tools/ghc-paths",
},
label_builder = label_builder,
)

http_archive(
Expand Down Expand Up @@ -196,6 +200,7 @@ haskell_cabal_library(
"quickcheck-io": ["@Cabal//:Cabal"],
},
stack_snapshot_json = "//:stackage-pinning-test_snapshot.json" if not is_windows else None,
label_builder = label_builder,
)

stack_snapshot(
Expand Down Expand Up @@ -296,6 +301,7 @@ haskell_cabal_library(
vendored_packages = {
"ghc-paths": "@rules_haskell//tools/ghc-paths",
},
label_builder = label_builder,
)

stack_snapshot(
Expand All @@ -307,6 +313,7 @@ haskell_cabal_library(
],
stack_snapshot_json = "//tests/asterius/stack_toolchain_libraries:snapshot.json",
toolchain_libraries = toolchain_libraries,
label_builder = label_builder,
) if is_linux else None

def _non_module_deps_2_impl(ctx):
Expand Down

0 comments on commit 9c321ce

Please sign in to comment.