diff --git a/src/circuit/uint32.rs b/src/circuit/uint32.rs index 6bb4847..e254132 100644 --- a/src/circuit/uint32.rs +++ b/src/circuit/uint32.rs @@ -199,6 +199,85 @@ impl UInt32 { } } + fn triop( + mut cs: CS, + a: &Self, + b: &Self, + c: &Self, + tri_fn: F, + circuit_fn: U + ) -> Result + where E: Engine, + CS: ConstraintSystem, + F: Fn(u32, u32, u32) -> u32, + U: Fn(&mut CS, usize, &Boolean, &Boolean, &Boolean) -> Result + { + let new_value = match (a.value, b.value, c.value) { + (Some(a), Some(b), Some(c)) => { + Some(tri_fn(a, b, c)) + }, + _ => None + }; + + let bits = a.bits.iter() + .zip(b.bits.iter()) + .zip(c.bits.iter()) + .enumerate() + .map(|(i, ((a, b), c))| circuit_fn(&mut cs, i, a, b, c)) + .collect::>()?; + + Ok(UInt32 { + bits: bits, + value: new_value + }) + } + + /// Compute the `maj` value (a and b) xor (a and c) xor (b and c) + /// during SHA256. + pub fn sha256_maj( + cs: CS, + a: &Self, + b: &Self, + c: &Self + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + Self::triop(cs, a, b, c, |a, b, c| (a & b) ^ (a & c) ^ (b & c), + |cs, i, a, b, c| { + Boolean::sha256_maj( + cs.namespace(|| format!("maj {}", i)), + a, + b, + c + ) + } + ) + } + + /// Compute the `ch` value `(a and b) xor ((not a) and c)` + /// during SHA256. + pub fn sha256_ch( + cs: CS, + a: &Self, + b: &Self, + c: &Self + ) -> Result + where E: Engine, + CS: ConstraintSystem + { + Self::triop(cs, a, b, c, |a, b, c| (a & b) ^ ((!a) & c), + |cs, i, a, b, c| { + Boolean::sha256_ch( + cs.namespace(|| format!("ch {}", i)), + a, + b, + c + ) + } + ) + } + /// XOR this `UInt32` with another `UInt32` pub fn xor( &self, @@ -597,4 +676,86 @@ mod test { } } } + + #[test] + fn test_uint32_sha256_maj() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0653]); + + for _ in 0..1000 { + let mut cs = TestConstraintSystem::::new(); + + let a: u32 = rng.gen(); + let b: u32 = rng.gen(); + let c: u32 = rng.gen(); + + let mut expected = (a & b) ^ (a & c) ^ (b & c); + + let a_bit = UInt32::alloc(cs.namespace(|| "a_bit"), Some(a)).unwrap(); + let b_bit = UInt32::constant(b); + let c_bit = UInt32::alloc(cs.namespace(|| "c_bit"), Some(c)).unwrap(); + + let r = UInt32::sha256_maj(&mut cs, &a_bit, &b_bit, &c_bit).unwrap(); + + assert!(cs.is_satisfied()); + + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + &Boolean::Is(ref b) => { + assert!(b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Not(ref b) => { + assert!(!b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Constant(b) => { + assert!(b == (expected & 1 == 1)); + } + } + + expected >>= 1; + } + } + } + + #[test] + fn test_uint32_sha256_ch() { + let mut rng = XorShiftRng::from_seed([0x5dbe6259, 0x8d313d76, 0x3237db17, 0xe5bc0653]); + + for _ in 0..1000 { + let mut cs = TestConstraintSystem::::new(); + + let a: u32 = rng.gen(); + let b: u32 = rng.gen(); + let c: u32 = rng.gen(); + + let mut expected = (a & b) ^ ((!a) & c); + + let a_bit = UInt32::alloc(cs.namespace(|| "a_bit"), Some(a)).unwrap(); + let b_bit = UInt32::constant(b); + let c_bit = UInt32::alloc(cs.namespace(|| "c_bit"), Some(c)).unwrap(); + + let r = UInt32::sha256_ch(&mut cs, &a_bit, &b_bit, &c_bit).unwrap(); + + assert!(cs.is_satisfied()); + + assert!(r.value == Some(expected)); + + for b in r.bits.iter() { + match b { + &Boolean::Is(ref b) => { + assert!(b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Not(ref b) => { + assert!(!b.get_value().unwrap() == (expected & 1 == 1)); + }, + &Boolean::Constant(b) => { + assert!(b == (expected & 1 == 1)); + } + } + + expected >>= 1; + } + } + } }