Skip to content

Commit

Permalink
Fix equality for ScheduledDiagnostics
Browse files Browse the repository at this point in the history
`ScheduledDiagnostics` didn't implement ==. As a result, they were
considered equal when they were identical (===). This is not the
intended behavior because of RefValues, which are effectively pointers.
This commit implements == for Schedules and ScheduledDiagnostics by
looking at the value of the Ref.
  • Loading branch information
Sbozzolo committed Oct 21, 2024
1 parent 25a369d commit 9bdf8ee
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 1 deletion.
9 changes: 8 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
# NEWS

main
v0.2.9
-------

## Bug fixes

### Acquiring ownership with `compute!` PR [#88](https://github.com/CliMA/ClimaDiagnostics.jl/pull/88).

Prior to this version, `ClimaDiagnostics` would directly store use the output
returned by `compute!` functions the first time they are called. This leads to
problems when the output is a reference to an existing object since multiple
diagnostics would modify the same object. Now, `ClimaDiagnostics` makes a copy
of the return object so that it is no longer necessary to do so in the
`compute!` function.

### Correctly de-duplicate `ScheduledDiagnostics` [#93](https://github.com/CliMA/ClimaDiagnostics.jl/pull/93).

This version fixes a bug where `ScheduledDiagnostics` were not correctly
de-duplicated because `==` was not implemented correctly.

v0.2.8
-------

Expand Down
9 changes: 9 additions & 0 deletions src/ScheduledDiagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -218,4 +218,13 @@ function output_long_name(sd::ScheduledDiagnostic)
return sd.output_long_name
end

function Base.:(==)(sd1::T, sd2::T) where {T <: ScheduledDiagnostic}
# We provide == because we don't want to compare with === because we have
# RefValues
return all(
getproperty(sd, p) == getproperty(sd2, p) for p in propertynames(sd)
)
end


end
17 changes: 17 additions & 0 deletions src/Schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ function Base.show(io::IO, schedule::AbstractSchedule)
print(io, short_name(schedule))
end

function Base.:(==)(schedule1::T, schedule2::T) where {T <: AbstractSchedule}
# The Schedules are the identical when the properties are the same, but for
# Refs, we have to unpack the value because we don't want to compare
# pointers.
for p in propertynames(schedule1)
if getproperty(schedule2, p) isa Base.RefValue
getproperty(schedule1, p)[] == getproperty(schedule2, p)[] ||
return false
else
getproperty(schedule1, p) == getproperty(schedule2, p) ||
return false
end
end
return true
end


"""
DivisorSchedule
Expand Down
1 change: 1 addition & 0 deletions src/clima_diagnostics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ function DiagnosticsHandler(scheduled_diagnostics, Y, p, t; dt = nothing)
scheduled_diagnostics_keys = Int[]

unique_scheduled_diagnostics = unique(scheduled_diagnostics)

if length(unique_scheduled_diagnostics) != length(scheduled_diagnostics)
@warn "Given list of diagnostics contains duplicates, removing them"
end
Expand Down
7 changes: 7 additions & 0 deletions test/integration_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,13 +78,20 @@ function setup_integrator(output_dir; context, more_compute_diagnostics = 0)
variable = simple_var,
output_writer = h5_writer,
)
inst_every3s_diagnostic_another = ClimaDiagnostics.ScheduledDiagnostic(
variable = simple_var,
output_writer = h5_writer,
)
scheduled_diagnostics = [
average_diagnostic,
inst_diagnostic,
inst_diagnostic_h5,
inst_every3s_diagnostic,
inst_every3s_diagnostic_another,
]

@test length(unique(scheduled_diagnostics)) == 4

# Add more weight, useful for stressing allocations
compute_diagnostic = ClimaDiagnostics.ScheduledDiagnostic(
variable = simple_var,
Expand Down
4 changes: 4 additions & 0 deletions test/schedules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ include("TestTools.jl")
scheduled_func = EveryDtSchedule(dt_callback)
@test "$scheduled_func" == "0.2s"

scheduled_func_test1 = EveryDtSchedule(dt_callback; t_last = 0.1)
scheduled_func_test2 = EveryDtSchedule(dt_callback; t_last = 0.1)
@test scheduled_func_test1 == scheduled_func_test2

dt_callback2 = 0.3
t_last2 = 0.1
scheduled_func2 = EveryDtSchedule(dt_callback2; t_last = t_last2)
Expand Down

0 comments on commit 9bdf8ee

Please sign in to comment.