Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow existing uses of _jacobian in function names, with warning #1471

Merged
merged 3 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 86 additions & 0 deletions src/frontend/Deprecation_analysis.ml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,76 @@ let lkj_cov_message =
independent lognormal distribution on the scales, see: \
https://mc-stan.org/docs/reference-manual/deprecations.html#lkj_cov-distribution"

let functions_block_contains_jac_pe (stmts : untyped_statement list) =
(* tracking if 'jacobian' is a variable in scope *)
let jacobian_scope_id = ref 0 in
let is_jacobian_in_scope () = !jacobian_scope_id > 0 in
let current_scope_id = ref 1 in
let found_jacobian () =
if not (is_jacobian_in_scope ()) then jacobian_scope_id := !current_scope_id
in
let push_scope () = current_scope_id := !current_scope_id + 1 in
let pop_scope () =
current_scope_id := !current_scope_id - 1;
(* if the scope we just left was the one defining jacobian, reset it *)
if !jacobian_scope_id > !current_scope_id then jacobian_scope_id := 0 in
(* walk over the tree, looking for usages of jacobian+= where
there is no variable called jacobian already in scope *)
let rec f (s : untyped_statement) =
match s.stmt with
| FunDef {body; funname; _}
when String.is_suffix funname.name ~suffix:"_jacobian" ->
push_scope ();
let res = f body in
pop_scope ();
res
| Block stmts | Profile (_, stmts) ->
push_scope ();
let res = List.exists ~f stmts in
pop_scope ();
res
| For {loop_body; _} | While (_, loop_body) | ForEach (_, _, loop_body) ->
push_scope ();
let res = f loop_body in
pop_scope ();
res
| IfThenElse (_, s1, s2_opt) ->
push_scope ();
let res1 = f s1 in
pop_scope ();
push_scope ();
let res2 = match s2_opt with Some s2 -> f s2 | None -> false in
pop_scope ();
res1 || res2
| JacobianPE _ -> true
| Assignment
{ assign_lhs= LValue {lval= LVariable {name; _}; _}
; assign_op= OperatorAssign Plus
; _ }
when String.equal name "jacobian" ->
not (is_jacobian_in_scope ())
| VarDecl {variables; _} ->
if
List.exists
~f:(fun {identifier; _} -> String.equal identifier.name "jacobian")
variables
then found_jacobian ();
false
| _ -> false in
let res = List.exists ~f stmts in
(* sanity check that pushes and pops are balanced *)
if !current_scope_id <> 1 then
Common.ICE.internal_compiler_error
[%message
"functions_block_contains_jac_pe: scope tracking failed"
(!current_scope_id : int)
(!jacobian_scope_id : int)
(stmts : untyped_statement list)];
res

let set_jacobian_compatibility_mode stmts =
Fun_kind.jacobian_compat_mode := not (functions_block_contains_jac_pe stmts)

let rec collect_deprecated_expr (acc : (Location_span.t * string) list)
({expr; emeta} : (typed_expr_meta, fun_kind) expr_with) :
(Location_span.t * string) list =
Expand Down Expand Up @@ -89,6 +159,22 @@ let rec collect_deprecated_stmt fundefs (acc : (Location_span.t * string) list)
, "Functions do not need to be declared before definition; all user \
defined function names are always in scope regardless of \
definition order." ) ]
| FunDef {funname; body; _}
when !Fun_kind.jacobian_compat_mode
&& String.is_suffix funname.name ~suffix:"_jacobian" ->
let acc =
( funname.id_loc
, "Functions that end in _jacobian will change meaning in Stan 2.39. \
They will be used for the encapsulating usages of 'jacobian +=', \
and therefore not available to be called in all the same places as \
this function is now. To avoid any issues, please rename this \
function to not end in _jacobian." )
:: acc in
fold_statement collect_deprecated_expr
(collect_deprecated_stmt fundefs)
collect_deprecated_lval
(fun l _ -> l)
acc body.stmt
| Tilde {distribution; _} when String.equal distribution.name "lkj_cov" ->
let acc = (distribution.id_loc, lkj_cov_message) :: acc in
fold_statement collect_deprecated_expr
Expand Down
5 changes: 5 additions & 0 deletions src/frontend/Deprecation_analysis.mli
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,8 @@ val rename_deprecated : (string * (int * int)) String.Map.t -> string -> string
val stan_lib_deprecations : (string * (int * int)) String.Map.t
val collect_warnings : typed_program -> Warnings.t list
val remove_unneeded_forward_decls : typed_program -> typed_program

val set_jacobian_compatibility_mode : untyped_statement list -> unit
(** Pre-Stan 2.39, we need to know if _jacobian functions are
FnPlain or not. We use the presence of any jacobian+= statements
as our condition. If none are present, we assume this is old code. *)
3 changes: 3 additions & 0 deletions src/frontend/Typechecker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -454,6 +454,7 @@ let verify_fn_target_plus_equals cf loc id =
let verify_fn_jacobian_plus_equals cf loc id =
if
String.is_suffix id.name ~suffix:"_jacobian"
&& (not !Fun_kind.jacobian_compat_mode)
&& not (in_jacobian_function cf || cf.current_block = TParam)
then Semantic_error.jacobian_plusequals_not_allowed loc |> error

Expand Down Expand Up @@ -1894,6 +1895,8 @@ let add_userdefined_functions tenv stmts_opt =
match stmts_opt with
| None -> tenv
| Some {stmts; _} ->
(* TODO(2.39): Remove this workaround *)
Deprecation_analysis.set_jacobian_compatibility_mode stmts;
let f tenv (s : Ast.untyped_statement) =
match s with
| {stmt= FunDef {returntype; funname; arguments; body}; smeta= {loc}} ->
Expand Down
7 changes: 6 additions & 1 deletion src/middle/Fun_kind.ml
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ type 'e t =
| UserDefined of string * bool suffix
[@@deriving compare, sexp, hash, map, fold]

(** If true, we assume _jacobian functions are
"plain" functions for the purposes of typechecking and warnings
*)
let jacobian_compat_mode = ref false

let suffix_from_name fname =
let is_suffix suffix = Core.String.is_suffix ~suffix fname in
if is_suffix "_rng" then FnRng
else if is_suffix "_lp" then FnTarget
else if is_suffix "_jacobian" then FnJacobian
else if is_suffix "_jacobian" && not !jacobian_compat_mode then FnJacobian
else if is_suffix "_lupdf" then FnLpdf true
else if is_suffix "_lupmf" then FnLpmf true
else if is_suffix "_lpdf" then FnLpdf false
Expand Down
122 changes: 101 additions & 21 deletions test/integration/good/code-gen/cpp.expected
Original file line number Diff line number Diff line change
Expand Up @@ -3931,18 +3931,32 @@ namespace deprecated_jacobian_usage_model_namespace {
using stan::model::model_base_crtp;
using namespace stan::math;
stan::math::profile_map profiles__;
static constexpr std::array<const char*, 7> locations_array__ =
static constexpr std::array<const char*, 12> locations_array__ =
{" (found before start of program)",
" (in 'deprecated_jacobian_usage.stan', line 11, column 2 to column 20)",
" (in 'deprecated_jacobian_usage.stan', line 12, column 2 to column 21)",
" (in 'deprecated_jacobian_usage.stan', line 19, column 2 to column 20)",
" (in 'deprecated_jacobian_usage.stan', line 24, column 2 to column 28)",
" (in 'deprecated_jacobian_usage.stan', line 20, column 2 to column 21)",
" (in 'deprecated_jacobian_usage.stan', line 5, column 4 to column 22)",
" (in 'deprecated_jacobian_usage.stan', line 6, column 4 to column 18)",
" (in 'deprecated_jacobian_usage.stan', line 7, column 4 to column 20)",
" (in 'deprecated_jacobian_usage.stan', line 4, column 19 to line 8, column 3)"};
" (in 'deprecated_jacobian_usage.stan', line 4, column 19 to line 8, column 3)",
" (in 'deprecated_jacobian_usage.stan', line 11, column 4 to column 13)",
" (in 'deprecated_jacobian_usage.stan', line 10, column 28 to line 12, column 3)",
" (in 'deprecated_jacobian_usage.stan', line 15, column 4 to column 27)",
" (in 'deprecated_jacobian_usage.stan', line 14, column 19 to line 16, column 3)"};
template <typename T0__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
std::is_floating_point<T0__>>>* = nullptr>
stan::promote_args_t<T0__> foo(const T0__& x, std::ostream* pstream__);
template <typename T0__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
std::is_floating_point<T0__>>>* = nullptr>
stan::promote_args_t<T0__>
bar_jacobian(const T0__& x, std::ostream* pstream__);
template <typename T0__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
std::is_floating_point<T0__>>>* = nullptr>
stan::promote_args_t<T0__> bar(const T0__& x, std::ostream* pstream__);
// real foo(real)
template <typename T0__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
Expand All @@ -3960,16 +3974,61 @@ stan::promote_args_t<T0__> foo(const T0__& x, std::ostream* pstream__) {
(void) DUMMY_VAR__;
try {
local_scalar_t__ jacobian = DUMMY_VAR__;
current_statement__ = 3;
jacobian = 0;
current_statement__ = 4;
jacobian = (jacobian + x);
jacobian = 0;
current_statement__ = 5;
jacobian = (jacobian + x);
current_statement__ = 6;
return jacobian;
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
// real bar_jacobian(real)
template <typename T0__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
std::is_floating_point<T0__>>>*>
stan::promote_args_t<T0__>
bar_jacobian(const T0__& x, std::ostream* pstream__) {
using local_scalar_t__ = stan::promote_args_t<T0__>;
int current_statement__ = 0;
// suppress unused var warning
(void) current_statement__;
static constexpr bool propto__ = true;
// suppress unused var warning
(void) propto__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// suppress unused var warning
(void) DUMMY_VAR__;
try {
current_statement__ = 8;
return x;
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
// real bar(real)
template <typename T0__,
stan::require_all_t<stan::math::disjunction<stan::is_autodiff<T0__>,
std::is_floating_point<T0__>>>*>
stan::promote_args_t<T0__> bar(const T0__& x, std::ostream* pstream__) {
using local_scalar_t__ = stan::promote_args_t<T0__>;
int current_statement__ = 0;
// suppress unused var warning
(void) current_statement__;
static constexpr bool propto__ = true;
// suppress unused var warning
(void) propto__;
local_scalar_t__ DUMMY_VAR__(std::numeric_limits<double>::quiet_NaN());
// suppress unused var warning
(void) DUMMY_VAR__;
try {
current_statement__ = 10;
return bar_jacobian(x, pstream__);
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
}
class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_jacobian_usage_model> {
private:

Expand Down Expand Up @@ -4029,7 +4088,7 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
local_scalar_t__ jacobian = DUMMY_VAR__;
current_statement__ = 1;
jacobian = 1;
current_statement__ = 2;
current_statement__ = 3;
jacobian = (jacobian + foo(static_cast<double>(1), pstream__));
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
Expand Down Expand Up @@ -4064,7 +4123,7 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
local_scalar_t__ jacobian = DUMMY_VAR__;
current_statement__ = 1;
jacobian = 1;
current_statement__ = 2;
current_statement__ = 3;
jacobian = (jacobian + foo(static_cast<double>(1), pstream__));
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
Expand Down Expand Up @@ -4115,14 +4174,18 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
}
current_statement__ = 1;
jacobian = 1;
current_statement__ = 2;
current_statement__ = 3;
jacobian = (jacobian + foo(static_cast<double>(1), pstream__));
if (emit_transformed_parameters__) {
out__.write(jacobian);
}
if (stan::math::logical_negation(emit_generated_quantities__)) {
return ;
}
double b_jacobian = std::numeric_limits<double>::quiet_NaN();
current_statement__ = 2;
b_jacobian = bar(static_cast<double>(10), pstream__);
out__.write(b_jacobian);
} catch (const std::exception& e) {
stan::lang::rethrow_located(e, locations_array__[current_statement__]);
}
Expand Down Expand Up @@ -4166,7 +4229,11 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
names__.reserve(names__.size() + temp.size());
names__.insert(names__.end(), temp.begin(), temp.end());
}
if (emit_generated_quantities__) {}
if (emit_generated_quantities__) {
std::vector<std::string> temp{"b_jacobian"};
names__.reserve(names__.size() + temp.size());
names__.insert(names__.end(), temp.begin(), temp.end());
}
}
inline void
get_dims(std::vector<std::vector<size_t>>& dimss__, const bool
Expand All @@ -4178,7 +4245,11 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
dimss__.reserve(dimss__.size() + temp.size());
dimss__.insert(dimss__.end(), temp.begin(), temp.end());
}
if (emit_generated_quantities__) {}
if (emit_generated_quantities__) {
std::vector<std::vector<size_t>> temp{std::vector<size_t>{}};
dimss__.reserve(dimss__.size() + temp.size());
dimss__.insert(dimss__.end(), temp.begin(), temp.end());
}
}
inline void
constrained_param_names(std::vector<std::string>& param_names__, bool
Expand All @@ -4187,7 +4258,9 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
if (emit_transformed_parameters__) {
param_names__.emplace_back(std::string() + "jacobian");
}
if (emit_generated_quantities__) {}
if (emit_generated_quantities__) {
param_names__.emplace_back(std::string() + "b_jacobian");
}
}
inline void
unconstrained_param_names(std::vector<std::string>& param_names__, bool
Expand All @@ -4196,13 +4269,15 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
if (emit_transformed_parameters__) {
param_names__.emplace_back(std::string() + "jacobian");
}
if (emit_generated_quantities__) {}
if (emit_generated_quantities__) {
param_names__.emplace_back(std::string() + "b_jacobian");
}
}
inline std::string get_constrained_sizedtypes() const {
return std::string("[{\"name\":\"jacobian\",\"type\":{\"name\":\"real\"},\"block\":\"transformed_parameters\"}]");
return std::string("[{\"name\":\"jacobian\",\"type\":{\"name\":\"real\"},\"block\":\"transformed_parameters\"},{\"name\":\"b_jacobian\",\"type\":{\"name\":\"real\"},\"block\":\"generated_quantities\"}]");
}
inline std::string get_unconstrained_sizedtypes() const {
return std::string("[{\"name\":\"jacobian\",\"type\":{\"name\":\"real\"},\"block\":\"transformed_parameters\"}]");
return std::string("[{\"name\":\"jacobian\",\"type\":{\"name\":\"real\"},\"block\":\"transformed_parameters\"},{\"name\":\"b_jacobian\",\"type\":{\"name\":\"real\"},\"block\":\"generated_quantities\"}]");
}
// Begin method overload boilerplate
template <typename RNG> inline void
Expand All @@ -4213,7 +4288,7 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
pstream = nullptr) const {
const size_t num_params__ = 0;
const size_t num_transformed = emit_transformed_parameters * (1);
const size_t num_gen_quantities = emit_generated_quantities * (0);
const size_t num_gen_quantities = emit_generated_quantities * (1);
const size_t num_to_write = num_params__ + num_transformed +
num_gen_quantities;
std::vector<int> params_i;
Expand All @@ -4230,7 +4305,7 @@ class deprecated_jacobian_usage_model final : public model_base_crtp<deprecated_
pstream = nullptr) const {
const size_t num_params__ = 0;
const size_t num_transformed = emit_transformed_parameters * (1);
const size_t num_gen_quantities = emit_generated_quantities * (0);
const size_t num_gen_quantities = emit_generated_quantities * (1);
const size_t num_to_write = num_params__ + num_transformed +
num_gen_quantities;
vars = std::vector<double>(num_to_write,
Expand Down Expand Up @@ -4312,15 +4387,20 @@ Warning in 'deprecated_jacobian_usage.stan', line 6, column 4: Variable name
Warning in 'deprecated_jacobian_usage.stan', line 7, column 11: Variable name
'jacobian' will be a reserved word starting in Stan 2.38. Please rename
it!
Warning in 'deprecated_jacobian_usage.stan', line 11, column 7: Variable name
Warning in 'deprecated_jacobian_usage.stan', line 19, column 7: Variable name
'jacobian' will be a reserved word starting in Stan 2.38. Please rename
it!
Warning in 'deprecated_jacobian_usage.stan', line 11, column 7: Variable name
Warning in 'deprecated_jacobian_usage.stan', line 19, column 7: Variable name
'jacobian' will be a reserved word starting in Stan 2.38. Please rename
it!
Warning in 'deprecated_jacobian_usage.stan', line 12, column 2: Variable name
Warning in 'deprecated_jacobian_usage.stan', line 20, column 2: Variable name
'jacobian' will be a reserved word starting in Stan 2.38. Please rename
it!
Warning in 'deprecated_jacobian_usage.stan', line 10, column 7: Functions
that end in _jacobian will change meaning in Stan 2.39. They will be used
for the encapsulating usages of 'jacobian +=', and therefore not
available to be called in all the same places as this function is now. To
avoid any issues, please rename this function to not end in _jacobian.
[exit 0]
$ ../../../../../install/default/bin/stanc --print-cpp double-reject.stan
// Code generated by %%NAME%% %%VERSION%%
Expand Down
Loading