Skip to content
This repository has been archived by the owner on Jul 17, 2024. It is now read-only.

Commit

Permalink
feat: Improve ScoreAnalysis debug information (#105)
Browse files Browse the repository at this point in the history
  • Loading branch information
zepfred authored Jul 5, 2024
1 parent 81bbd40 commit 4396e2d
Show file tree
Hide file tree
Showing 3 changed files with 290 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def find_stub_files(stub_root: str):
test_suite='tests',
python_requires='>=3.10',
install_requires=[
'JPype1>=1.5.0',
'JPype1>=1.5.0'
],
cmdclass={'build_py': FetchDependencies},
package_data={
Expand Down
165 changes: 163 additions & 2 deletions tests/test_solution_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,19 @@
from timefold.solver.config import *
from timefold.solver.score import *

import inspect
import re

from ai.timefold.solver.core.api.score import ScoreExplanation as JavaScoreExplanation
from ai.timefold.solver.core.api.score.analysis import (
ConstraintAnalysis as JavaConstraintAnalysis,
MatchAnalysis as JavaMatchAnalysis,
ScoreAnalysis as JavaScoreAnalysis)
from ai.timefold.solver.core.api.score.constraint import Indictment as JavaIndictment
from ai.timefold.solver.core.api.score.constraint import (ConstraintRef as JavaConstraintRef,
ConstraintMatch as JavaConstraintMatch,
ConstraintMatchTotal as JavaConstraintMatchTotal)

from dataclasses import dataclass, field
from typing import Annotated, List

Expand All @@ -18,8 +31,8 @@ class Entity:
def my_constraints(constraint_factory: ConstraintFactory):
return [
constraint_factory.for_each(Entity)
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('package', 'Maximize Value'),
.reward(SimpleScore.ONE, lambda entity: entity.value)
.as_constraint('package', 'Maximize Value'),
]


Expand Down Expand Up @@ -127,6 +140,27 @@ def assert_score_analysis(problem: Solution, score_analysis: ScoreAnalysis):
assert_constraint_analysis(problem, constraint_analysis)


def assert_score_analysis_summary(score_analysis: ScoreAnalysis):
summary = score_analysis.summary
assert "Explanation of score (3):" in summary
assert "Constraint matches:" in summary
assert "3: constraint (Maximize Value) has 3 matches:" in summary
assert "1: justified with" in summary

summary_str = str(score_analysis)
assert summary == summary_str

match = score_analysis.constraint_analyses[0]
match_summary = match.summary
assert "Explanation of score (3):" in match_summary
assert "Constraint matches:" in match_summary
assert "3: constraint (Maximize Value) has 3 matches:" in match_summary
assert "1: justified with" in match_summary

match_summary_str = str(match)
assert match_summary == match_summary_str


def assert_solution_manager(solution_manager: SolutionManager[Solution]):
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
assert problem.score is None
Expand All @@ -140,6 +174,9 @@ def assert_solution_manager(solution_manager: SolutionManager[Solution]):
score_analysis = solution_manager.analyze(problem)
assert_score_analysis(problem, score_analysis)

score_analysis = solution_manager.analyze(problem)
assert_score_analysis_summary(score_analysis)


def test_solver_manager_score_manager():
with SolverManager.create(SolverFactory.create(solver_config)) as solver_manager:
Expand All @@ -148,3 +185,127 @@ def test_solver_manager_score_manager():

def test_solver_factory_score_manager():
assert_solution_manager(SolutionManager.create(SolverFactory.create(solver_config)))


def test_score_manager_solution_initialization():
solution_manager = SolutionManager.create(SolverFactory.create(solver_config))
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
score_analysis = solution_manager.analyze(problem)
assert score_analysis.is_solution_initialized

second_problem: Solution = Solution([Entity('A', None), Entity('B', None), Entity('C', None)], [1, 2, 3])
second_score_analysis = solution_manager.analyze(second_problem)
assert not second_score_analysis.is_solution_initialized


def test_score_manager_diff():
solution_manager = SolutionManager.create(SolverFactory.create(solver_config))
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
score_analysis = solution_manager.analyze(problem)
second_problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1), Entity('D', 1)], [1, 2, 3])
second_score_analysis = solution_manager.analyze(second_problem)
diff = score_analysis.diff(second_score_analysis)
assert diff.score.score == -1

diff_operation = score_analysis - second_score_analysis
assert diff_operation.score.score == -1

constraint_analyses = score_analysis.constraint_analyses
assert len(constraint_analyses) == 1


def test_score_manager_constraint_analysis_map():
solution_manager = SolutionManager.create(SolverFactory.create(solver_config))
problem: Solution = Solution([Entity('A', 1), Entity('B', 1), Entity('C', 1)], [1, 2, 3])
score_analysis = solution_manager.analyze(problem)
constraints = score_analysis.constraint_analyses
assert len(constraints) == 1

constraint_analysis = score_analysis.constraint_analysis('package', 'Maximize Value')
assert constraint_analysis.constraint_name == 'Maximize Value'

constraint_analysis = score_analysis.constraint_analysis(ConstraintRef('package', 'Maximize Value'))
assert constraint_analysis.constraint_name == 'Maximize Value'
assert constraint_analysis.match_count == 3


def test_score_manager_constraint_ref():
constraint_ref = ConstraintRef.parse_id('package/Maximize Value')

assert constraint_ref.package_name == 'package'
assert constraint_ref.constraint_name == 'Maximize Value'


ignored_java_functions = {
'equals',
'getClass',
'hashCode',
'notify',
'notifyAll',
'toString',
'wait',
'compareTo',
}

ignored_java_functions_per_class = {
'Indictment': {'getJustification'}, # deprecated
'ConstraintRef': {'of', 'packageName', 'constraintName'}, # built-in constructor and properties with @dataclass
'ConstraintAnalysis': {'summarize'}, # using summary instead
'ScoreAnalysis': {'summarize'}, # using summary instead
'ConstraintMatch': {
'getConstraintRef', # built-in constructor and properties with @dataclass
'getConstraintPackage', # deprecated
'getConstraintName', # deprecated
'getConstraintId', # deprecated
'getJustificationList', # deprecated
'getJustification', # built-in constructor and properties with @dataclass
'getScore', # built-in constructor and properties with @dataclass
'getIndictedObjectList', # built-in constructor and properties with @dataclass
},
'ConstraintMatchTotal': {
'getConstraintRef', # built-in constructor and properties with @dataclass
'composeConstraintId', # deprecated
'getConstraintPackage', # deprecated
'getConstraintName', # deprecated
'getConstraintId', # deprecated
'getConstraintMatchCount', # built-in constructor and properties with @dataclass
'getConstraintMatchSet', # built-in constructor and properties with @dataclass
'getConstraintWeight', # built-in constructor and properties with @dataclass
'getScore', # built-in constructor and properties with @dataclass
},
}


def test_has_all_methods():
missing = []
for python_type, java_type in ((ScoreExplanation, JavaScoreExplanation),
(ScoreAnalysis, JavaScoreAnalysis),
(ConstraintAnalysis, JavaConstraintAnalysis),
(ScoreExplanation, JavaScoreExplanation),
(ConstraintMatch, JavaConstraintMatch),
(ConstraintMatchTotal, JavaConstraintMatchTotal),
(ConstraintRef, JavaConstraintRef),
(Indictment, JavaIndictment)):
type_name = python_type.__name__
ignored_java_functions_type = ignored_java_functions_per_class[
type_name] if type_name in ignored_java_functions_per_class else {}

for function_name, function_impl in inspect.getmembers(java_type, inspect.isfunction):
if function_name in ignored_java_functions or function_name in ignored_java_functions_type:
continue

snake_case_name = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', function_name)
snake_case_name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', snake_case_name).lower()
snake_case_name_without_prefix = re.sub('(.)([A-Z][a-z]+)', r'\1_\2',
function_name[3:] if function_name.startswith(
"get") else function_name)
snake_case_name_without_prefix = re.sub('([a-z0-9])([A-Z])', r'\1_\2',
snake_case_name_without_prefix).lower()
if not hasattr(python_type, snake_case_name) and not hasattr(python_type, snake_case_name_without_prefix):
missing.append((java_type, python_type, snake_case_name))

if missing:
assertion_msg = ''
for java_type, python_type, snake_case_name in missing:
assertion_msg += f'{python_type} is missing a method ({snake_case_name}) from java_type ({java_type}).)\n'
raise AssertionError(assertion_msg)
Loading

0 comments on commit 4396e2d

Please sign in to comment.