Solving MAXSAT and saying a few words about it.
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

#### 240 lines 6.9 KiB Raw Permalink Blame History

 `// vim: ts=4 sw=4 et` `use std::cell::RefCell;` `use std::collections::hash_map::HashMap;` `use std::collections::hash_set::HashSet;` `use std::rc::Rc;` `/*****************************` `* *` `* “Every oven has a vent.” *` `* 〜Graham Northup *` `* *` `*****************************/` `/// Represents a clause with a nonnegative weight. Not to be confused with [crate::clause::Clause].` `struct ClauseInner {` ` weight: usize,` ` satisfied: RefCell,` `}` `#[derive(Clone)]` `struct Clause(Rc);` `pub struct Var {` ` name: String,` ` clauses_pos: Vec,` ` clauses_neg: Vec,` ` assignment: RefCell>,` ` best: RefCell>,` `}` `impl Var {` ` fn get_sat_list(&self) -> Option<&Vec> {` ` match *self.assignment.borrow() {` ` None => None,` ` Some(true) => Some(&self.clauses_pos),` ` Some(false) => Some(&self.clauses_neg),` ` }` ` }` `}` `impl core::hash::Hash for Clause {` ` fn hash(&self, hasher: &mut H) {` ` (&*self.0 as *const _ as *const ()).hash(hasher);` ` }` `}` `impl core::cmp::PartialEq for Clause {` ` fn eq(&self, other: &Self) -> bool {` ` (&*self.0 as *const _ as *const ()) == (&*other.0 as *const _ as *const ())` ` }` `}` `impl core::cmp::Eq for Clause {}` `pub fn from_list(list: Vec<(crate::clause::Clause, usize)>) -> (usize, Vec) {` ` let mut map = HashMap::new();` ` let mut sumweight = 0;` ` for (clause, weight) in list {` ` sumweight += weight;` ` let newclause = Clause(Rc::new(ClauseInner {` ` weight,` ` satisfied: RefCell::new(false),` ` }));` ` for lit in clause.iter() {` ` match map.get(&lit.get_variable()) {` ` None => {` ` let var = RefCell::new(Var {` ` name: lit.get_variable().get_name().clone(),` ` clauses_pos: if !lit.is_neg() {` ` vec![newclause.clone()]` ` } else {` ` vec![]` ` },` ` clauses_neg: if lit.is_neg() {` ` vec![newclause.clone()]` ` } else {` ` vec![]` ` },` ` assignment: RefCell::new(None),` ` best: RefCell::new(None),` ` });` ` map.insert(lit.get_variable(), var);` ` }` ` Some(var) => {` ` if lit.is_neg() {` ` 
var.borrow_mut().clauses_neg.push(newclause.clone());` ` } else {` ` var.borrow_mut().clauses_pos.push(newclause.clone());` ` }` ` }` ` }` ` }` ` }` ` let mut vars = Vec::new();` ` for (_, var) in map {` ` vars.push(var.into_inner());` ` }` ` (sumweight, vars)` `}` `// Panics if the var isn't assigned.` `fn add_weight(var: &Var) -> (usize, HashSet) {` ` let mut set = HashSet::new();` ` let mut weight = 0;` ` let clauses = var.get_sat_list().unwrap();` ` for clause in clauses {` ` if *clause.0.satisfied.borrow() == false {` ` *clause.0.satisfied.borrow_mut() = true;` ` set.insert(clause.clone());` ` weight += clause.0.weight;` ` }` ` }` ` (weight, set)` `}` `fn remove_weight(clauses: HashSet) {` ` for clause in clauses {` ` *clause.0.satisfied.borrow_mut() = false;` ` }` `}` `fn solve_minsat_inner(vars: &Vec, start: usize, end: usize, bound: usize, f: &mut F)` `where` ` F: FnMut(&Vec, usize),` `{` ` if start == end {` ` f(vars, bound);` ` return;` ` }` ` {` ` let mut rm = vars[start].assignment.borrow_mut();` ` *rm = Some(false);` ` }` ` let (weight, s) = add_weight(&vars[start]);` ` if weight <= bound {` ` solve_minsat_inner(vars, start + 1, end, bound - weight, f);` ` }` ` remove_weight(s);` ` {` ` let mut rm = vars[start].assignment.borrow_mut();` ` *rm = Some(true);` ` }` ` let (weight, s) = add_weight(&vars[start]);` ` if weight <= bound {` ` solve_minsat_inner(vars, start + 1, end, bound - weight, f);` ` }` ` remove_weight(s);` ` {` ` let mut rm = vars[start].assignment.borrow_mut();` ` *rm = None;` ` }` `}` `fn maybe_minimize(vars: &Vec, bound: usize) {` ` let replace;` ` match *vars[0].best.borrow() {` ` Some((weight, _)) => replace = weight < bound,` ` None => replace = true,` ` }` ` if replace {` ` for var in vars {` ` *var.best.borrow_mut() =` ` Some((bound, var.assignment.borrow().clone().unwrap_or(false)));` ` }` ` }` `}` `fn reset(vars: &Vec) {` ` for var in vars {` ` *var.assignment.borrow_mut() = None;` ` *var.best.borrow_mut() = None;` ` }` `}` `pub 
fn solve_minsat_recursive(vars: &Vec, start: usize, end: usize) -> usize {` ` if end == start + 1 {` ` let mut nb = 0;` ` for clause in &vars[start].clauses_neg {` ` nb += clause.0.weight;` ` }` ` let mut pb = 0;` ` for clause in &vars[start].clauses_pos {` ` pb += clause.0.weight;` ` }` ` return usize::min(nb, pb);` ` }` ` let midpoint = (start + end) / 2;` ` let left = solve_minsat_recursive(vars, start, midpoint);` ` let right = solve_minsat_recursive(vars, midpoint, end);` ` solve_minsat_inner(vars, start, end, left + right, &mut |vars, bound| {` ` maybe_minimize(vars, bound)` ` });` ` let rv = left + right - vars[0].best.borrow().unwrap().0;` ` reset(vars);` ` rv` `}` `#[cfg(test)]` `mod test {` ` use super::*;` ` #[test]` ` // I was going to use "naïve" instead of "simple", but non-ASCII in identifiers is currently` ` // feature-gated.` ` fn simple_solver_ex1() {` ` clause_vars![a, b, c, d];` ` let mut cwl = Vec::new();` ` cwl.push((make_clause!(a | b | -c | d), 1));` ` cwl.push((make_clause!(-a | b | -c | d), 1));` ` cwl.push((make_clause!(-a | b | d), 1));` ` cwl.push((make_clause!(a | c | -d), 1));` ` cwl.push((make_clause!(-a | -b), 1));` ` cwl.push((make_clause!(a | b | c | d), 1));` ` let (bound, vars) = from_list(cwl);` ` solve_minsat_inner(&vars, 0, vars.len(), bound, &mut |vars, bound| {` ` maybe_minimize(vars, bound)` ` });` ` for var in vars {` ` print!("{} = {}, ", var.name, var.best.borrow().unwrap().1);` ` }` ` println!("");` ` }` ` #[test]` ` fn recursive_solver_ex1() {` ` clause_vars![a, b, c, d];` ` let mut cwl = Vec::new();` ` cwl.push((make_clause!(a | b | -c | d), 1));` ` cwl.push((make_clause!(-a | b | -c | d), 1));` ` cwl.push((make_clause!(-a | b | d), 1));` ` cwl.push((make_clause!(a | c | -d), 1));` ` cwl.push((make_clause!(-a | -b), 1));` ` cwl.push((make_clause!(a | b | c | d), 1));` ` let (_, vars) = from_list(cwl);` ` println!("min: {}", solve_minsat_recursive(&vars, 0, vars.len()));` ` }` ```} ```