Skip to content

Commit

Permalink
fix: ensure that options modules from sections are unimported
Browse files Browse the repository at this point in the history
fix: ensure that unimports for sections are properly scoped
  • Loading branch information
zachdaniel committed Jun 10, 2024
1 parent e20d7ee commit d59b26d
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 52 deletions.
2 changes: 2 additions & 0 deletions lib/spark/code_helpers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ defmodule Spark.CodeHelpers do
)
when is_list(fn_args) do
fn_name = generate_unique_function_name(quoted_fn, key)

arity =
case fn_args do
[{:when, _, args_with_clause}] ->
Expand All @@ -214,6 +215,7 @@ defmodule Spark.CodeHelpers do
other ->
Enum.count(other)
end

function = generate_captured_function_caller(fn_name, arity, caller)

function_defs =
Expand Down
106 changes: 63 additions & 43 deletions lib/spark/dsl/extension.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1009,51 +1009,58 @@ defmodule Spark.Dsl.Extension do
[]
end

entity_imports ++
section_imports ++
opts_import ++
configured_imports ++
patch_module_imports ++
unimports ++
other_extension_unimports ++
[
quote generated: true do
unquote(body[:do])
all_the_code =
entity_imports ++
section_imports ++
opts_import ++
configured_imports ++
patch_module_imports ++
unimports ++
other_extension_unimports ++
[
quote generated: true do
unquote(body[:do])

current_config =
Process.get(
{__MODULE__, :spark, unquote(section_path)},
%{entities: [], opts: []}
)
current_config =
Process.get(
{__MODULE__, :spark, unquote(section_path)},
%{entities: [], opts: []}
)

opts =
case Spark.Options.validate(
Keyword.new(current_config.opts),
Map.get(unquote(Macro.escape(section)), :schema, [])
) do
{:ok, opts} ->
opts

{:error, error} ->
raise Spark.Error.DslError,
module: __MODULE__,
message: error,
path: unquote(section_path)
end
opts =
case Spark.Options.validate(
Keyword.new(current_config.opts),
Map.get(unquote(Macro.escape(section)), :schema, [])
) do
{:ok, opts} ->
opts

{:error, error} ->
raise Spark.Error.DslError,
module: __MODULE__,
message: error,
path: unquote(section_path)
end

Process.put(
{__MODULE__, :spark, unquote(section_path)},
%{
entities: current_config.entities,
opts: opts
}
)
end
] ++
configured_unimports ++
patch_module_unimports ++
other_extension_reimports ++
opts_unimport ++ entity_unimports ++ section_unimports
Process.put(
{__MODULE__, :spark, unquote(section_path)},
%{
entities: current_config.entities,
opts: opts
}
)
end
] ++
configured_unimports ++
patch_module_unimports ++
other_extension_reimports ++
opts_unimport ++ entity_unimports ++ section_unimports

quote do
with do
unquote(all_the_code)
end
end
end
end
end
Expand Down Expand Up @@ -1187,6 +1194,19 @@ defmodule Spark.Dsl.Extension do
[mod | nested_mod_name] ++ [Macro.camelize(to_string(nested_section.name))]
)

opts_unimport =
if section.schema == [] do
[]
else
[
quote generated: true do
import unquote(opts_mod_name), only: []
end
]
end

unimports = opts_unimport ++ unimports

Spark.Dsl.Extension.async_compile(agent, fn ->
{:module, module, _, _} =
Module.create(
Expand All @@ -1201,7 +1221,7 @@ defmodule Spark.Dsl.Extension do
unquote(agent),
unquote(extension),
unquote(Macro.escape(nested_section)),
unquote(unimports),
unquote(Macro.escape(unimports)),
unquote(path ++ [section.name]),
unquote(mod)
)
Expand Down
18 changes: 9 additions & 9 deletions test/code_helpers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@ defmodule Spark.CodeHelpersTest do
__ENV__
)

{:module, module, _, _} =
Module.create(
Foo,
quote do
unquote(funs)
{:module, module, _, _} =
Module.create(
Foo,
quote do
unquote(funs)

def fun, do: unquote(code)
end,
Macro.Env.location(__ENV__)
)
def fun, do: unquote(code)
end,
Macro.Env.location(__ENV__)
)

assert :erlang.fun_info(module.fun())[:arity] == 2
end
Expand Down

0 comments on commit d59b26d

Please sign in to comment.