diff --git a/src/circuit/boolean.rs b/src/circuit/boolean.rs index f5d9827..08f407e 100644 --- a/src/circuit/boolean.rs +++ b/src/circuit/boolean.rs @@ -649,6 +649,152 @@ impl Boolean { variable: ch }.into()) } + + /// Computes (a and b) xor (a and c) xor (b and c) + pub fn sha256_maj<'a, E, CS>( + mut cs: CS, + a: &'a Self, + b: &'a Self, + c: &'a Self, + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + let maj_value = match (a.get_value(), b.get_value(), c.get_value()) { + (Some(a), Some(b), Some(c)) => { + // (a and b) xor (a and c) xor (b and c) + Some((a & b) ^ (a & c) ^ (b & c)) + }, + _ => None + }; + + match (a, b, c) { + (&Boolean::Constant(_), + &Boolean::Constant(_), + &Boolean::Constant(_)) => { + // They're all constants, so we can just compute the value. + + return Ok(Boolean::Constant(maj_value.expect("they're all constants"))); + }, + (&Boolean::Constant(false), b, c) => { + // If a is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (b and c) + return Boolean::and( + cs, + b, + c + ); + }, + (a, &Boolean::Constant(false), c) => { + // If b is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and c) + return Boolean::and( + cs, + a, + c + ); + }, + (a, b, &Boolean::Constant(false)) => { + // If c is false, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and b) + return Boolean::and( + cs, + a, + b + ); + }, + (a, b, &Boolean::Constant(true)) => { + // If c is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a and b) xor (a) xor (b) + // equals + // not ((not a) and (not b)) + return Ok(Boolean::and( + cs, + &a.not(), + &b.not() + )?.not()); + }, + (a, &Boolean::Constant(true), c) => { + // If b is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (a) xor (a and c) xor (c) + return Ok(Boolean::and( + cs, + &a.not(), + &c.not() + )?.not()); + }, + (&Boolean::Constant(true), b, c) => { + // If a is true, + // (a and b) xor (a and c) xor (b and c) + // equals + // (b) xor (c) xor (b and c) + return Ok(Boolean::and( + cs, + &b.not(), + &c.not() + )?.not()); + }, + (&Boolean::Is(_), &Boolean::Is(_), &Boolean::Is(_)) | + (&Boolean::Is(_), &Boolean::Is(_), &Boolean::Not(_)) | + (&Boolean::Is(_), &Boolean::Not(_), &Boolean::Is(_)) | + (&Boolean::Is(_), &Boolean::Not(_), &Boolean::Not(_)) | + (&Boolean::Not(_), &Boolean::Is(_), &Boolean::Is(_)) | + (&Boolean::Not(_), &Boolean::Is(_), &Boolean::Not(_)) | + (&Boolean::Not(_), &Boolean::Not(_), &Boolean::Is(_)) | + (&Boolean::Not(_), &Boolean::Not(_), &Boolean::Not(_)) + => {} + } + + let maj = cs.alloc(|| "maj", || { + maj_value.get().map(|v| { + if *v { + E::Fr::one() + } else { + E::Fr::zero() + } + }) + })?; + + // ¬(¬a ∧ ¬b) ∧ ¬(¬a ∧ ¬c) ∧ ¬(¬b ∧ ¬c) + // (1 - ((1 - a) * (1 - b))) * (1 - ((1 - a) * (1 - c))) * (1 - ((1 - b) * (1 - c))) + // (a + b - ab) * (a + c - ac) * (b + c - bc) + // -2abc + ab + ac + bc + // a (-2bc + b + c) + bc + // + // (b) * (c) = (bc) + // (2bc - b - c) * (a) = bc - maj + + let bc = Self::and( + cs.namespace(|| "b and c"), + b, + c + )?; + + cs.enforce( + || "maj computation", + |_| bc.lc(CS::one(), E::Fr::one()) + + &bc.lc(CS::one(), E::Fr::one()) + - &b.lc(CS::one(), E::Fr::one()) + - &c.lc(CS::one(), E::Fr::one()), + |_| a.lc(CS::one(), E::Fr::one()), + |_| bc.lc(CS::one(), E::Fr::one()) - maj + ); + + Ok(AllocatedBit { + value: maj_value, + variable: maj + }.into()) + } } impl From for Boolean { @@ -1349,4 +1495,88 @@ mod test { } } } + + #[test] + fn test_boolean_sha256_maj() { + let variants = [ + OperandType::True, + OperandType::False, + OperandType::AllocatedTrue, + OperandType::AllocatedFalse, + OperandType::NegatedAllocatedTrue, + OperandType::NegatedAllocatedFalse + ]; + + for first_operand in variants.iter().cloned() { + for second_operand in variants.iter().cloned() { + for third_operand in variants.iter().cloned() { + let mut cs = TestConstraintSystem::::new(); + + let a; + let b; + let c; + + // maj = (a and b) xor (a and c) xor (b and c) + let expected = (first_operand.val() & second_operand.val()) ^ + (first_operand.val() & third_operand.val()) ^ + (second_operand.val() & third_operand.val()); + + { + let mut dyn_construct = |operand, name| { + let cs = cs.namespace(|| name); + + match operand { + OperandType::True => Boolean::constant(true), + OperandType::False => Boolean::constant(false), + OperandType::AllocatedTrue => Boolean::from(AllocatedBit::alloc(cs, Some(true)).unwrap()), + OperandType::AllocatedFalse => Boolean::from(AllocatedBit::alloc(cs, Some(false)).unwrap()), + OperandType::NegatedAllocatedTrue => Boolean::from(AllocatedBit::alloc(cs, Some(true)).unwrap()).not(), + OperandType::NegatedAllocatedFalse => Boolean::from(AllocatedBit::alloc(cs, Some(false)).unwrap()).not(), + } + }; + + a = dyn_construct(first_operand, "a"); + b = dyn_construct(second_operand, "b"); + c = dyn_construct(third_operand, "c"); + } + + let maj = Boolean::sha256_maj(&mut cs, &a, &b, &c).unwrap(); + + assert!(cs.is_satisfied()); + + assert_eq!(maj.get_value().unwrap(), expected); + + if first_operand.is_constant() || + second_operand.is_constant() || + third_operand.is_constant() + { + if first_operand.is_constant() && + second_operand.is_constant() && + third_operand.is_constant() + { + assert_eq!(cs.num_constraints(), 0); + } + } + else + { + assert_eq!(cs.get("maj"), { + if expected { + Fr::one() + } else { + Fr::zero() + } + }); + cs.set("maj", { + if expected { + Fr::zero() + } else { + Fr::one() + } + }); + assert_eq!(cs.which_is_unsatisfied().unwrap(), "maj computation"); + } + } + } + } + } }