From 08500ee71275e7b6127e7b512c0763b058acf714 Mon Sep 17 00:00:00 2001 From: Jack Grigg Date: Wed, 22 Apr 2020 10:45:51 +1200 Subject: [PATCH] ff: PrimeField: BitAnd + Shr --- bellman/src/groth16/tests/dummy_engine.rs | 19 +++- ff/ff_derive/src/lib.rs | 73 +++++++++++-- ff/src/lib.rs | 4 +- pairing/src/bls12_381/fq.rs | 73 +++++++++++++ pairing/src/bls12_381/fr.rs | 61 +++++++++++ zcash_primitives/src/jubjub/fs.rs | 127 +++++++++++++++++++++- zcash_primitives/src/pedersen_hash.rs | 8 +- 7 files changed, 348 insertions(+), 17 deletions(-) diff --git a/bellman/src/groth16/tests/dummy_engine.rs b/bellman/src/groth16/tests/dummy_engine.rs index 91f6993..86e7b18 100644 --- a/bellman/src/groth16/tests/dummy_engine.rs +++ b/bellman/src/groth16/tests/dummy_engine.rs @@ -8,7 +8,7 @@ use rand_core::RngCore; use std::cmp::Ordering; use std::fmt; use std::num::Wrapping; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::ops::{Add, AddAssign, BitAnd, Mul, MulAssign, Neg, Shr, Sub, SubAssign}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; const MODULUS_R: Wrapping = Wrapping(64513); @@ -151,6 +151,23 @@ impl MulAssign for Fr { } } +impl BitAnd for Fr { + type Output = u64; + + fn bitand(self, rhs: u64) -> u64 { + (self.0).0 as u64 & rhs + } +} + +impl Shr for Fr { + type Output = Fr; + + fn shr(mut self, rhs: u32) -> Fr { + self.0 = Wrapping((self.0).0 >> rhs); + self + } +} + impl Field for Fr { fn random(rng: &mut R) -> Self { Fr(Wrapping(rng.next_u32()) % MODULUS_R) diff --git a/ff/ff_derive/src/lib.rs b/ff/ff_derive/src/lib.rs index 17ad399..f5439c5 100644 --- a/ff/ff_derive/src/lib.rs +++ b/ff/ff_derive/src/lib.rs @@ -785,14 +785,19 @@ fn prime_field_impl( proc_macro2::Punct::new('&', proc_macro2::Spacing::Alone), ); - // (self.0).0[0], (self.0).0[1], ..., 0, 0, 0, 0, ... - let mut into_repr_params = proc_macro2::TokenStream::new(); - into_repr_params.append_separated( - (0..limbs) - .map(|i| quote! { (self.0).0[#i] }) - .chain((0..limbs).map(|_| quote! {0})), - proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), - ); + fn mont_reduce_params(a: proc_macro2::TokenStream, limbs: usize) -> proc_macro2::TokenStream { + // a.0[0], a.0[1], ..., 0, 0, 0, 0, ... + let mut mont_reduce_params = proc_macro2::TokenStream::new(); + mont_reduce_params.append_separated( + (0..limbs) + .map(|i| quote! { (#a.0).0[#i] }) + .chain((0..limbs).map(|_| quote! {0})), + proc_macro2::Punct::new(',', proc_macro2::Spacing::Alone), + ); + mont_reduce_params + } + + let mont_reduce_self_params = mont_reduce_params(quote! {self}, limbs); let top_limb_index = limbs - 1; @@ -1006,6 +1011,56 @@ fn prime_field_impl( } } + impl ::core::ops::BitAnd for #name { + type Output = u64; + + #[inline(always)] + fn bitand(mut self, rhs: u64) -> u64 { + self.mont_reduce( + #mont_reduce_self_params + ); + + (self.0).0[0] & rhs + } + } + + impl ::core::ops::Shr for #name { + type Output = #name; + + #[inline(always)] + fn shr(mut self, mut n: u32) -> #name { + if n as usize >= 64 * #limbs { + return Self::from(0); + } + + // Convert from Montgomery to native representation. + self.mont_reduce( + #mont_reduce_self_params + ); + + while n >= 64 { + let mut t = 0; + for i in (self.0).0.iter_mut().rev() { + ::core::mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in (self.0).0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + + // Convert back to Montgomery representation + self * #name(R2) + } + } + impl ::ff::PrimeField for #name { type Repr = #repr; @@ -1023,7 +1078,7 @@ fn prime_field_impl( fn into_repr(&self) -> #repr { let mut r = *self; r.mont_reduce( - #into_repr_params + #mont_reduce_self_params ); r.0 diff --git a/ff/src/lib.rs b/ff/src/lib.rs index b2f4d3a..78fbc5c 100644 --- a/ff/src/lib.rs +++ b/ff/src/lib.rs @@ -257,7 +257,9 @@ impl fmt::Display for PrimeFieldDecodingError { } /// This represents an element of a prime field. -pub trait PrimeField: Field + From { +pub trait PrimeField: + Field + From + BitAnd + Shr +{ /// The prime field can be converted back and forth into this biginteger /// representation. type Repr: PrimeFieldRepr + From; diff --git a/pairing/src/bls12_381/fq.rs b/pairing/src/bls12_381/fq.rs index 8e5d660..f9caf5e 100644 --- a/pairing/src/bls12_381/fq.rs +++ b/pairing/src/bls12_381/fq.rs @@ -1931,6 +1931,79 @@ fn test_fq_mul_assign() { } } +#[test] +fn test_fq_shr() { + let mut a = Fq::from_repr(FqRepr([ + 0xaa5cdd6172847ffd, + 0x43242c06aed55287, + 0x9ddd5b312f3dd104, + 0xc5541fd48046b7e7, + 0x16080cf4071e0b05, + 0x1225f2901aea514e, + ])) + .unwrap(); + a = a >> 0; + assert_eq!( + a.into_repr(), + FqRepr([ + 0xaa5cdd6172847ffd, + 0x43242c06aed55287, + 0x9ddd5b312f3dd104, + 0xc5541fd48046b7e7, + 0x16080cf4071e0b05, + 0x1225f2901aea514e, + ]) + ); + a = a >> 1; + assert_eq!( + a.into_repr(), + FqRepr([ + 0xd52e6eb0b9423ffe, + 0x21921603576aa943, + 0xceeead98979ee882, + 0xe2aa0fea40235bf3, + 0x0b04067a038f0582, + 0x0912f9480d7528a7, + ]) + ); + a = a >> 50; + assert_eq!( + a.into_repr(), + FqRepr([ + 0x8580d5daaa50f54b, + 0xab6625e7ba208864, + 0x83fa9008d6fcf3bb, + 0x019e80e3c160b8aa, + 0xbe52035d4a29c2c1, + 0x0000000000000244, + ]) + ); + a = a >> 130; + assert_eq!( + a.into_repr(), + FqRepr([ + 0xa0fea40235bf3cee, + 0x4067a038f0582e2a, + 0x2f9480d7528a70b0, + 0x0000000000000091, + 0x0000000000000000, + 0x0000000000000000, + ]) + ); + a = a >> 64; + assert_eq!( + a.into_repr(), + FqRepr([ + 0x4067a038f0582e2a, + 0x2f9480d7528a70b0, + 0x0000000000000091, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ]) + ); +} + #[test] fn test_fq_squaring() { let a = Fq(FqRepr([ diff --git a/pairing/src/bls12_381/fr.rs b/pairing/src/bls12_381/fr.rs index cc02d22..6bfb175 100644 --- a/pairing/src/bls12_381/fr.rs +++ b/pairing/src/bls12_381/fr.rs @@ -669,6 +669,67 @@ fn test_fr_mul_assign() { } } +#[test] +fn test_fr_shr() { + let mut a = Fr::from_repr(FrRepr([ + 0xb33fbaec482a283f, + 0x997de0d3a88cb3df, + 0x9af62d2a9a0e5525, + 0x36003ab08de70da1, + ])) + .unwrap(); + a = a >> 0; + assert_eq!( + a.into_repr(), + FrRepr([ + 0xb33fbaec482a283f, + 0x997de0d3a88cb3df, + 0x9af62d2a9a0e5525, + 0x36003ab08de70da1, + ]) + ); + a = a >> 1; + assert_eq!( + a.into_repr(), + FrRepr([ + 0xd99fdd762415141f, + 0xccbef069d44659ef, + 0xcd7b16954d072a92, + 0x1b001d5846f386d0, + ]) + ); + a = a >> 50; + assert_eq!( + a.into_repr(), + FrRepr([ + 0xbc1a7511967bf667, + 0xc5a55341caa4b32f, + 0x075611bce1b4335e, + 0x00000000000006c0, + ]) + ); + a = a >> 130; + assert_eq!( + a.into_repr(), + FrRepr([ + 0x01d5846f386d0cd7, + 0x00000000000001b0, + 0x0000000000000000, + 0x0000000000000000, + ]) + ); + a = a >> 64; + assert_eq!( + a.into_repr(), + FrRepr([ + 0x00000000000001b0, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ]) + ); +} + #[test] fn test_fr_squaring() { let a = Fr(FrRepr([ diff --git a/zcash_primitives/src/jubjub/fs.rs b/zcash_primitives/src/jubjub/fs.rs index e4bea13..81d7089 100644 --- a/zcash_primitives/src/jubjub/fs.rs +++ b/zcash_primitives/src/jubjub/fs.rs @@ -3,7 +3,8 @@ use ff::{ PrimeFieldRepr, SqrtField, }; use rand_core::RngCore; -use std::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign}; +use std::mem; +use std::ops::{Add, AddAssign, BitAnd, Mul, MulAssign, Neg, Shr, Sub, SubAssign}; use subtle::{Choice, ConditionallySelectable, ConstantTimeEq, CtOption}; use super::ToUniform; @@ -452,6 +453,69 @@ impl MulAssign for Fs { } } +impl BitAnd for Fs { + type Output = u64; + + #[inline(always)] + fn bitand(mut self, rhs: u64) -> u64 { + self.mont_reduce( + (self.0).0[0], + (self.0).0[1], + (self.0).0[2], + (self.0).0[3], + 0, + 0, + 0, + 0, + ); + (self.0).0[0] & rhs + } +} + +impl Shr for Fs { + type Output = Self; + + #[inline(always)] + fn shr(mut self, mut n: u32) -> Self { + if n as usize >= 64 * 4 { + return Self::from(0); + } + + // Convert from Montgomery to native representation. + self.mont_reduce( + (self.0).0[0], + (self.0).0[1], + (self.0).0[2], + (self.0).0[3], + 0, + 0, + 0, + 0, + ); + + while n >= 64 { + let mut t = 0; + for i in (self.0).0.iter_mut().rev() { + mem::swap(&mut t, i); + } + n -= 64; + } + + if n > 0 { + let mut t = 0; + for i in (self.0).0.iter_mut().rev() { + let t2 = *i << (64 - n); + *i >>= n; + *i |= t; + t = t2; + } + } + + // Convert back to Montgomery representation + self * Fs(R2) + } +} + impl PrimeField for Fs { type Repr = FsRepr; @@ -1400,6 +1464,67 @@ fn test_fs_mul_assign() { } } +#[test] +fn test_fs_shr() { + let mut a = Fs::from_repr(FsRepr([ + 0xb33fbaec482a283f, + 0x997de0d3a88cb3df, + 0x9af62d2a9a0e5525, + 0x06003ab08de70da1, + ])) + .unwrap(); + a = a >> 0; + assert_eq!( + a.into_repr(), + FsRepr([ + 0xb33fbaec482a283f, + 0x997de0d3a88cb3df, + 0x9af62d2a9a0e5525, + 0x06003ab08de70da1, + ]) + ); + a = a >> 1; + assert_eq!( + a.into_repr(), + FsRepr([ + 0xd99fdd762415141f, + 0xccbef069d44659ef, + 0xcd7b16954d072a92, + 0x03001d5846f386d0, + ]) + ); + a = a >> 50; + assert_eq!( + a.into_repr(), + FsRepr([ + 0xbc1a7511967bf667, + 0xc5a55341caa4b32f, + 0x075611bce1b4335e, + 0x00000000000000c0, + ]) + ); + a = a >> 130; + assert_eq!( + a.into_repr(), + FsRepr([ + 0x01d5846f386d0cd7, + 0x0000000000000030, + 0x0000000000000000, + 0x0000000000000000, + ]) + ); + a = a >> 64; + assert_eq!( + a.into_repr(), + FsRepr([ + 0x0000000000000030, + 0x0000000000000000, + 0x0000000000000000, + 0x0000000000000000, + ]) + ); +} + #[test] fn test_fr_squaring() { let a = Fs(FsRepr([ diff --git a/zcash_primitives/src/pedersen_hash.rs b/zcash_primitives/src/pedersen_hash.rs index ea182a5..afd8a73 100644 --- a/zcash_primitives/src/pedersen_hash.rs +++ b/zcash_primitives/src/pedersen_hash.rs @@ -1,7 +1,7 @@ //! Implementation of the Pedersen hash function used in Sapling. use crate::jubjub::*; -use ff::{Field, PrimeField, PrimeFieldRepr}; +use ff::Field; use std::ops::{AddAssign, Neg}; #[derive(Copy, Clone)] @@ -88,16 +88,14 @@ where let window = JubjubBls12::pedersen_hash_exp_window_size(); let window_mask = (1 << window) - 1; - let mut acc = acc.into_repr(); - let mut tmp = edwards::Point::zero(); while !acc.is_zero() { - let i = (acc.as_ref()[0] & window_mask) as usize; + let i = (acc & window_mask) as usize; tmp = tmp.add(&table[0][i], params); - acc.shr(window); + acc = acc >> window; table = &table[1..]; }