Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow existing uses of _jacobian in function names, with warning #1471

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,76 @@
independent lognormal distribution on the scales, see: \
https://mc-stan.org/docs/reference-manual/deprecations.html#lkj_cov-distribution"

let functions_block_contains_jac_pe (stmts : untyped_statement list) =
(* tracking if 'jacobian' is a variable in scope *)
let jacobian_scope_id = ref 0 in
let is_jacobian_in_scope () = !jacobian_scope_id > 0 in
let current_scope_id = ref 1 in
let found_jacobian () =
if not (is_jacobian_in_scope ()) then jacobian_scope_id := !current_scope_id

Check warning on line 50 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L50

Added line #L50 was not covered by tests
in
let push_scope () = current_scope_id := !current_scope_id + 1 in
let pop_scope () =
current_scope_id := !current_scope_id - 1;
(* if the scope we just left was the one defining jacobian, reset it *)
if !jacobian_scope_id > !current_scope_id then jacobian_scope_id := 0 in

Check warning on line 56 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L56

Added line #L56 was not covered by tests
(* walk over the tree, looking for usages of jacobian+= where
there is no variable called jacobian already in scope *)
let rec f (s : untyped_statement) =
match s.stmt with
| FunDef {body; funname; _}
when String.is_suffix funname.name ~suffix:"_jacobian" ->
push_scope ();
let res = f body in
pop_scope ();
res
| Block stmts | Profile (_, stmts) ->

Check warning on line 67 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L67

Added line #L67 was not covered by tests
push_scope ();
let res = List.exists ~f stmts in
pop_scope ();
res
| For {loop_body; _} | While (_, loop_body) | ForEach (_, _, loop_body) ->

Check warning on line 72 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L72

Added line #L72 was not covered by tests
push_scope ();
let res = f loop_body in
pop_scope ();
res
| IfThenElse (_, s1, s2_opt) ->

Check warning on line 77 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L74-L77

Added lines #L74 - L77 were not covered by tests
push_scope ();
let res1 = f s1 in
pop_scope ();
push_scope ();
let res2 = match s2_opt with Some s2 -> f s2 | None -> false in

Check warning on line 82 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L79-L82

Added lines #L79 - L82 were not covered by tests
pop_scope ();
res1 || res2
| JacobianPE _ -> true

Check warning on line 85 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L84-L85

Added lines #L84 - L85 were not covered by tests
| Assignment
{ assign_lhs= LValue {lval= LVariable {name; _}; _}
; assign_op= OperatorAssign Plus
; _ }
when String.equal name "jacobian" ->
not (is_jacobian_in_scope ())
| VarDecl {variables; _} ->
if
List.exists
~f:(fun {identifier; _} -> String.equal identifier.name "jacobian")
variables
then found_jacobian ();

Check warning on line 97 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L97

Added line #L97 was not covered by tests
false
| _ -> false in
let res = List.exists ~f stmts in
(* sanity check that pushes and pops are balanced *)
if !current_scope_id <> 1 then
Common.ICE.internal_compiler_error

Check warning on line 103 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L103

Added line #L103 was not covered by tests
[%message
"functions_block_contains_jac_pe: scope tracking failed"
(!current_scope_id : int)
(!jacobian_scope_id : int)
(stmts : untyped_statement list)];
res

let set_jacobian_compatibility_mode stmts =
Fun_kind.jacobian_compat_mode := not (functions_block_contains_jac_pe stmts)

let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) :
(Location_span.t * string) list =
Expand Down Expand Up @@ -89,6 +159,22 @@
, "Functions do not need to be declared before definition; all user \
defined function names are always in scope regardless of \
definition order." ) ]
| FunDef {funname; body; _}
when !Fun_kind.jacobian_compat_mode
&& String.is_suffix funname.name ~suffix:"_jacobian" ->
let acc =
( funname.id_loc
, "Functions that end in _jacobian will change meaning in Stan 2.39. \
They will be used for the encapsulating usages of 'jacobian +=', \
and therefore not available to be called in all the same places as \
this function is now. To avoid any issues, please rename this \
function to not end in _jacobian." )
:: acc in
fold_statement collect_deprecated_expr
(collect_deprecated_stmt fundefs)
collect_deprecated_lval
(fun l _ -> l)

Check warning on line 176 in src/frontend/Deprecation_analysis.ml

View check run for this annotation

Codecov / codecov/patch

src/frontend/Deprecation_analysis.ml#L176

Added line #L176 was not covered by tests
acc body.stmt
| Tilde {distribution; _} when String.equal distribution.name "lkj_cov" ->
let acc = (distribution.id_loc, lkj_cov_message) :: acc in
fold_statement collect_deprecated_expr
Expand Down
5 changes: 5 additions & 0 deletions src/frontend/Deprecation_analysis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ val rename_deprecated : (string * (int * int)) String.Map.t -> string -> string
val stan_lib_deprecations : (string * (int * int)) String.Map.t
val collect_warnings : typed_program -> Warnings.t list
val remove_unneeded_forward_decls : typed_program -> typed_program

val set_jacobian_compatibility_mode : untyped_statement list -> unit
(** Pre-Stan 2.39, we need to know if _jacobian functions are
FnPlain or not. We use the presence of any jacobian+= statements
as our condition. If none are present, we assume this is old code. *)
14 changes: 6 additions & 8 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ let verify_fn_target_plus_equals cf loc id =
let verify_fn_jacobian_plus_equals cf loc id =
if
String.is_suffix id.name ~suffix:"_jacobian"
&& (not !Fun_kind.jacobian_compat_mode)
&& not (in_jacobian_function cf || cf.current_block = TParam)
then Semantic_error.jacobian_plusequals_not_allowed loc |> error

Expand Down Expand Up @@ -913,13 +914,6 @@ let check_expression_of_scalar_or_type cf tenv t e name =

(* -- Statements ------------------------------------------------- *)
(* non returning functions *)
let verify_nrfn_target loc cf id =
if
String.is_suffix id.name ~suffix:"_lp"
&& not
(in_lp_function cf || cf.current_block = Model
|| cf.current_block = TParam)
then Semantic_error.target_plusequals_outside_model_or_logprob loc |> error

let check_nrfn loc tenv id es =
match Env.find tenv id.name with
Expand Down Expand Up @@ -960,7 +954,9 @@ let check_nrfn loc tenv id es =
let check_nr_fn_app loc cf tenv id es =
let tes = List.map ~f:(check_expression cf tenv) es in
verify_identifier id;
verify_nrfn_target loc cf id;
verify_fn_target_plus_equals cf loc id;
verify_fn_jacobian_plus_equals cf loc id;
verify_fn_rng cf loc id;
check_nrfn loc tenv id tes

(* target plus-equals / jacobian plus-equals *)
Expand Down Expand Up @@ -1894,6 +1890,8 @@ let add_userdefined_functions tenv stmts_opt =
match stmts_opt with
| None -> tenv
| Some {stmts; _} ->
(* TODO(2.39): Remove this workaround *)
Deprecation_analysis.set_jacobian_compatibility_mode stmts;
let f tenv (s : Ast.untyped_statement) =
match s with
| {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} ->
Expand Down
7 changes: 6 additions & 1 deletion src/middle/Fun_kind.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ type 'e t =
| UserDefined of string * bool suffix
[@@deriving compare, sexp, hash, map, fold]

(** If true, we assume _jacobian functions are
"plain" functions for the purposes of typechecking and warnings
*)
let jacobian_compat_mode = ref false

let suffix_from_name fname =
let is_suffix suffix = Core.String.is_suffix ~suffix fname in
if is_suffix "_rng" then FnRng
else if is_suffix "_lp" then FnTarget
else if is_suffix "_jacobian" then FnJacobian
else if is_suffix "_jacobian" && not !jacobian_compat_mode then FnJacobian
else if is_suffix "_lupdf" then FnLpdf true
else if is_suffix "_lupmf" then FnLpmf true
else if is_suffix "_lpdf" then FnLpdf false
Expand Down
9 changes: 9 additions & 0 deletions test/integration/bad/err-jacobian-plusequals-scope4.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
functions {
// void return type to check function statement, rather than expression
void foo_jacobian() {
jacobian += 1;
}
}
transformed data {
foo_jacobian();
}
9 changes: 9 additions & 0 deletions test/integration/bad/err_void_rng_check.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
functions {
void foo_rng(real x){
print(normal_rng(0,x));
}
}

model {
foo_rng(1.0);
}
24 changes: 24 additions & 0 deletions test/integration/bad/stanc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -1083,6 +1083,18 @@ Semantic error in 'err-jacobian-plusequals-scope3.stan', line 14, column 11 to c
15: }
-------------------------------------------------

The jacobian adjustment can only be applied in the transformed parameters block or in functions ending with _jacobian
[exit 1]
$ ../../../../install/default/bin/stanc err-jacobian-plusequals-scope4.stan
Semantic error in 'err-jacobian-plusequals-scope4.stan', line 8, column 2 to column 17:
-------------------------------------------------
6: }
7: transformed data {
8: foo_jacobian();
^
9: }
-------------------------------------------------

The jacobian adjustment can only be applied in the transformed parameters block or in functions ending with _jacobian
[exit 1]
$ ../../../../install/default/bin/stanc err-minus-types.stan
Expand Down Expand Up @@ -1245,6 +1257,18 @@ Syntax error in 'err-transformed-params.stan', line 4, column 0 to column 11, pa
-------------------------------------------------

"transformed parameters {", "model {" or "generated quantities {" expected after end of parameters block.
[exit 1]
$ ../../../../install/default/bin/stanc err_void_rng_check.stan
Semantic error in 'err_void_rng_check.stan', line 8, column 4 to column 17:
-------------------------------------------------
6:
7: model {
8: foo_rng(1.0);
^
9: }
-------------------------------------------------

Random number generators are only allowed in transformed data block, generated quantities block or user-defined functions with names ending in _rng.
[exit 1]
$ ../../../../install/default/bin/stanc expect_statement_seq_close_brace.stan
Syntax error in 'expect_statement_seq_close_brace.stan', line 6, column 0 to column 0, parsing error:
Expand Down
4 changes: 4 additions & 0 deletions test/integration/cli-args/warn-pedantic/stanc.expected
Original file line number Diff line number Diff line change
Expand Up @@ -546,6 +546,10 @@ Warning in 'jacobian_warning_user.stan', line 5, column 2: Left-hand side of
using jacobian += in the transformed parameters block.
[exit 0]
$ ../../../../../install/default/bin/stanc --warn-pedantic lp_fun.stan
Warning in 'lp_fun.stan', line 10, column 2: Using _lp functions in
transformed parameters is deprecated and will be disallowed in Stan 2.39.
Use an _jacobian function instead, as this allows change of variable
adjustments which are conditionally enabled by the algorithms.
Warning: The parameter y has 2 priors.
[exit 0]
$ ../../../../../install/default/bin/stanc --warn-pedantic missing-prior-false-alarm.stan
Expand Down
Loading