diff --git a/src/main/java/org/logicng/explanations/smus/SmusComputation.java b/src/main/java/org/logicng/explanations/smus/SmusComputation.java index 4a66ab9c..81c024a9 100644 --- a/src/main/java/org/logicng/explanations/smus/SmusComputation.java +++ b/src/main/java/org/logicng/explanations/smus/SmusComputation.java @@ -30,8 +30,6 @@ import static org.logicng.handlers.Handler.aborted; import static org.logicng.handlers.Handler.start; -import static org.logicng.handlers.OptimizationHandler.satHandler; -import static org.logicng.solvers.maxsat.OptimizationConfig.OptimizationType.SAT_OPTIMIZATION; import static org.logicng.util.CollectionHelper.difference; import static org.logicng.util.CollectionHelper.nullSafe; @@ -39,7 +37,9 @@ import org.logicng.datastructures.Tristate; import org.logicng.formulas.Formula; import org.logicng.formulas.FormulaFactory; +import org.logicng.formulas.Literal; import org.logicng.formulas.Variable; +import org.logicng.handlers.Handler; import org.logicng.handlers.MaxSATHandler; import org.logicng.handlers.OptimizationHandler; import org.logicng.handlers.SATHandler; @@ -52,15 +52,16 @@ import org.logicng.solvers.functions.OptimizationFunction; import org.logicng.solvers.maxsat.OptimizationConfig; import org.logicng.solvers.maxsat.algorithms.MaxSAT; +import org.logicng.util.FormulaHelper; import java.util.ArrayList; +import java.util.Collection; import java.util.List; import java.util.Map; import java.util.Set; import java.util.SortedSet; import java.util.TreeMap; import java.util.TreeSet; -import java.util.function.Consumer; import java.util.stream.Collectors; /** @@ -137,10 +138,29 @@ public static

List

computeSmus( final List additionalConstraints, final FormulaFactory f, final OptimizationConfig config) { - if (config.getOptimizationType() == SAT_OPTIMIZATION) { - return computeSmusSAT(propositions, additionalConstraints, f, config.getOptimizationHandler()); - } else { - return computeSmusMaxSAT(propositions, additionalConstraints, f, config); + final Handler handler = getHandler(config); + start(handler); + final OptSolver growSolver = OptSolver.create(f, config); + growSolver.addConstraint(nullSafe(additionalConstraints)); + final Map propositionMapping = createPropositionsMapping(propositions, growSolver, f); + final boolean sat = growSolver.sat(propositionMapping.keySet()); + if (sat || growSolver.aborted()) { + return null; + } + final OptSolver hSolver = OptSolver.create(f, config); + while (true) { + final SortedSet h = hSolver.minimize(propositionMapping.keySet()); + if (h == null || aborted(handler)) { + return null; + } + final SortedSet c = grow(growSolver, h, propositionMapping.keySet()); + if (aborted(handler)) { + return null; + } + if (c == null) { + return h.stream().map(propositionMapping::get).collect(Collectors.toList()); + } + hSolver.addConstraint(f.or(c)); } } @@ -194,152 +214,183 @@ public static List computeSmusForFormulas( return smus == null ? null : smus.stream().map(Proposition::formula).collect(Collectors.toList()); } - private static

List

computeSmusSAT( - final List

propositions, - final List additionalConstraints, - final FormulaFactory f, - final OptimizationHandler handler) { - start(handler); - final SATSolver growSolver = MiniSat.miniSat(f); - growSolver.add(nullSafe(additionalConstraints)); - final Map propositionMapping = createPropositionsMapping(propositions, growSolver::add, f); - final boolean sat = growSolver.sat(satHandler(handler), propositionMapping.keySet()) == Tristate.TRUE; - if (sat || aborted(handler)) { - return null; - } - final SATSolver hSolver = MiniSat.miniSat(f); - while (true) { - final SortedSet h = minimumHs(hSolver, propositionMapping.keySet(), handler); - if (h == null || aborted(handler)) { - return null; - } - final SortedSet c = grow(growSolver, h, propositionMapping.keySet(), handler); - if (aborted(handler)) { - return null; - } - if (c == null) { - return h.stream().map(propositionMapping::get).collect(Collectors.toList()); - } - hSolver.add(f.or(c)); - } - } - - private static

List

computeSmusMaxSAT( - final List

propositions, - final List additionalConstraints, - final FormulaFactory f, - final OptimizationConfig config) { - final MaxSATHandler handler = config.getMaxSATHandler(); - start(handler); - final List growSolverConstraints = new ArrayList<>(nullSafe(additionalConstraints)); - final Map propositionMapping = createPropositionsMapping(propositions, growSolverConstraints::add, f); - final boolean sat = sat(growSolverConstraints, propositionMapping.keySet(), handler, f); - if (sat || aborted(handler)) { - return null; - } - final List hSolverConstraints = new ArrayList<>(); - while (true) { - final SortedSet h = minimumHs(hSolverConstraints, propositionMapping.keySet(), config, f); - if (h == null || aborted(handler)) { - return null; - } - final SortedSet c = grow(growSolverConstraints, h, propositionMapping.keySet(), config, f); - if (aborted(handler)) { - return null; - } - if (c == null) { - return h.stream().map(propositionMapping::get).collect(Collectors.toList()); - } - hSolverConstraints.add(f.or(c)); - } + private static Handler getHandler(final OptimizationConfig config) { + return config.getOptimizationType() == OptimizationConfig.OptimizationType.SAT_OPTIMIZATION + ? config.getOptimizationHandler() + : config.getMaxSATHandler(); } private static

Map createPropositionsMapping( - final List

propositions, final Consumer consumer, final FormulaFactory f) { + final List

propositions, final OptSolver solver, final FormulaFactory f) { final Map propositionMapping = new TreeMap<>(); for (final P proposition : propositions) { final Variable selector = f.variable(PROPOSITION_SELECTOR + propositionMapping.size()); propositionMapping.put(selector, proposition); - consumer.accept(f.equivalence(selector, proposition.formula())); + solver.addConstraint(f.equivalence(selector, proposition.formula())); } return propositionMapping; } - private static boolean sat(final List constraints, final Set variables, final MaxSATHandler handler, final FormulaFactory f) { - final SATHandler satHandler = handler == null ? null : handler.satHandler(); - final SATSolver satSolver = MiniSat.miniSat(f); - satSolver.add(constraints); - return satSolver.sat(satHandler, variables) == Tristate.TRUE; + private static SortedSet grow(final OptSolver growSolver, final SortedSet h, final Set variables) { + growSolver.saveState(); + growSolver.addConstraint(h); + final SortedSet maxModel = growSolver.maximize(variables); + if (maxModel == null) { + return null; + } + growSolver.loadState(); + return difference(variables, maxModel, TreeSet::new); } - private static SortedSet minimumHs( - final SATSolver hSolver, final Set variables, final OptimizationHandler handler) { - final Assignment minimumHsModel = hSolver.execute(OptimizationFunction.builder() - .handler(handler) - .literals(variables) - .minimize().build()); - return aborted(handler) ? null : new TreeSet<>(minimumHsModel.positiveVariables()); - } + private abstract static class OptSolver { + protected final FormulaFactory f; + protected final OptimizationConfig config; - private static SortedSet minimumHs( - final List constraints, - final Set variables, - final OptimizationConfig config, - final FormulaFactory f) { - final MaxSATSolver maxSatSolver = config.genMaxSATSolver(f); - constraints.forEach(maxSatSolver::addHardFormula); - for (final Variable v : variables) { - maxSatSolver.addSoftFormula(v.negate(), 1); + OptSolver(final FormulaFactory f, final OptimizationConfig config) { + this.f = f; + this.config = config; } - final MaxSATHandler handler = config.getMaxSATHandler(); - final MaxSAT.MaxSATResult result = maxSatSolver.solve(handler); - if (result == MaxSAT.MaxSATResult.UNDEF || aborted(handler)) { - return null; + + public static OptSolver create(final FormulaFactory f, final OptimizationConfig config) { + if (config.getOptimizationType() == OptimizationConfig.OptimizationType.SAT_OPTIMIZATION) { + return new SatOptSolver(f, config); + } else { + return new MaxSatOptSolver(f, config); + } + } + + abstract void addConstraint(final Formula formula); + + abstract void addConstraint(final Collection formulas); + + abstract boolean sat(final Collection variables); + + abstract void saveState(); + + abstract void loadState(); + + abstract SortedSet maximize(final Collection targetLiterals); + + SortedSet minimize(final Collection targetLiterals) { + return maximize(FormulaHelper.negateLiterals(targetLiterals, TreeSet::new)); } - return new TreeSet<>(maxSatSolver.model().positiveVariables()); + + abstract boolean aborted(); } - private static SortedSet grow( - final SATSolver growSolver, - final SortedSet h, - final Set variables, - final OptimizationHandler handler) { - final SolverState solverState = growSolver.saveState(); - growSolver.add(h); - final Assignment maxModel = growSolver.execute(OptimizationFunction.builder() - .handler(handler) - .literals(variables) - .maximize().build()); - if (maxModel == null || aborted(handler)) { - return null; - } else { - growSolver.loadState(solverState); - return difference(variables, maxModel.positiveVariables(), TreeSet::new); + private static class SatOptSolver extends OptSolver { + private final MiniSat solver; + private SolverState state; + + SatOptSolver(final FormulaFactory f, final OptimizationConfig config) { + super(f, config); + this.solver = MiniSat.miniSat(f); + this.state = null; + } + + @Override + void addConstraint(final Formula formula) { + this.solver.add(formula); + } + + @Override + void addConstraint(final Collection formulas) { + this.solver.add(formulas); + } + + @Override + boolean sat(final Collection variables) { + final SATHandler satHandler = this.config.getOptimizationHandler() == null ? null : this.config.getOptimizationHandler().satHandler(); + return this.solver.sat(satHandler, variables) == Tristate.TRUE; + } + + @Override + void saveState() { + this.state = this.solver.saveState(); + } + + @Override + void loadState() { + if (this.state != null) { + this.solver.loadState(this.state); + this.state = null; + } + } + + @Override + SortedSet maximize(final Collection targetLiterals) { + final OptimizationFunction optFunction = OptimizationFunction.builder() + .handler(this.config.getOptimizationHandler()) + .literals(targetLiterals) + .maximize().build(); + final Assignment model = this.solver.execute(optFunction); + return model == null || aborted() ? null : new TreeSet<>(model.positiveVariables()); + } + + @Override + boolean aborted() { + return Handler.aborted(this.config.getOptimizationHandler()); } } - private static SortedSet grow( - final List constraints, - final SortedSet h, - final Set variables, - final OptimizationConfig config, - final FormulaFactory f) { - final MaxSATSolver maxSatSolver = config.genMaxSATSolver(f); - constraints.forEach(maxSatSolver::addHardFormula); - h.forEach(maxSatSolver::addHardFormula); - for (final Variable v : variables) { - maxSatSolver.addSoftFormula(v, 1); + private static class MaxSatOptSolver extends OptSolver { + private List constraints; + private int saveIdx; + + public MaxSatOptSolver(final FormulaFactory f, final OptimizationConfig config) { + super(f, config); + this.constraints = new ArrayList<>(); + this.saveIdx = -1; } - final MaxSATHandler handler = config.getMaxSATHandler(); - final MaxSAT.MaxSATResult result = maxSatSolver.solve(handler); - if (result == MaxSAT.MaxSATResult.UNDEF || aborted(handler)) { - return null; + + @Override + void addConstraint(final Formula formula) { + this.constraints.add(formula); } - final Assignment maxModel = maxSatSolver.model(); - if (maxModel == null) { - return null; - } else { - return difference(variables, maxModel.positiveVariables(), TreeSet::new); + + @Override + void addConstraint(final Collection formulas) { + this.constraints.addAll(formulas); + } + + @Override + boolean sat(final Collection variables) { + final SATHandler satHandler = this.config.getMaxSATHandler() == null ? null : this.config.getMaxSATHandler().satHandler(); + final SATSolver satSolver = MiniSat.miniSat(this.f); + satSolver.add(this.constraints); + return satSolver.sat(satHandler, variables) == Tristate.TRUE; + } + + @Override + void saveState() { + this.saveIdx = this.constraints.size(); + } + + @Override + void loadState() { + if (this.saveIdx != -1) { + this.constraints = this.constraints.subList(0, this.saveIdx); + this.saveIdx = -1; + } + } + + @Override + SortedSet maximize(final Collection targetLiterals) { + final MaxSATSolver maxSatSolver = this.config.genMaxSATSolver(this.f); + this.constraints.forEach(maxSatSolver::addHardFormula); + for (final Literal lit : targetLiterals) { + maxSatSolver.addSoftFormula(lit, 1); + } + final MaxSATHandler handler = this.config.getMaxSATHandler(); + final MaxSAT.MaxSATResult result = maxSatSolver.solve(handler); + return result == MaxSAT.MaxSATResult.UNDEF || result == MaxSAT.MaxSATResult.UNSATISFIABLE || aborted() + ? null + : new TreeSet<>(maxSatSolver.model().positiveVariables()); + } + + @Override + boolean aborted() { + return Handler.aborted(this.config.getMaxSATHandler()); } } }