diff --git a/ff/ff_derive/src/lib.rs b/ff/ff_derive/src/lib.rs index 7b19d96..dceb508 100644 --- a/ff/ff_derive/src/lib.rs +++ b/ff/ff_derive/src/lib.rs @@ -47,7 +47,7 @@ pub fn prime_field(input: proc_macro::TokenStream) -> proc_macro::TokenStream { let mut gen = proc_macro2::TokenStream::new(); let (constants_impl, sqrt_impl) = - prime_field_constants_and_sqrt(&ast.ident, &repr_ident, modulus, limbs, generator); + prime_field_constants_and_sqrt(&ast.ident, &repr_ident, &modulus, limbs, generator); gen.extend(constants_impl); gen.extend(prime_field_repr_impl(&repr_ident, limbs)); @@ -383,7 +383,7 @@ fn test_exp() { fn prime_field_constants_and_sqrt( name: &syn::Ident, repr: &syn::Ident, - modulus: BigUint, + modulus: &BigUint, limbs: usize, generator: BigUint, ) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) { @@ -396,28 +396,26 @@ fn prime_field_constants_and_sqrt( let repr_shave_bits = (64 * limbs as u32) - biguint_num_bits(modulus.clone()); // Compute R = 2**(64 * limbs) mod m - let r = (BigUint::one() << (limbs * 64)) % &modulus; + let r = (BigUint::one() << (limbs * 64)) % modulus; // modulus - 1 = 2^s * t let mut s: u32 = 0; - let mut t = &modulus - BigUint::from_str("1").unwrap(); + let mut t = modulus - BigUint::from_str("1").unwrap(); while t.is_even() { t = t >> 1; s += 1; } // Compute 2^s root of unity given the generator - let root_of_unity = biguint_to_u64_vec( - (exp(generator.clone(), &t, &modulus) * &r) % &modulus, - limbs, - ); - let generator = biguint_to_u64_vec((generator.clone() * &r) % &modulus, limbs); + let root_of_unity = + biguint_to_u64_vec((exp(generator.clone(), &t, &modulus) * &r) % modulus, limbs); + let generator = biguint_to_u64_vec((generator.clone() * &r) % modulus, limbs); - let sqrt_impl = if (&modulus % BigUint::from_str("4").unwrap()) + let sqrt_impl = if (modulus % BigUint::from_str("4").unwrap()) == BigUint::from_str("3").unwrap() { let mod_plus_1_over_4 = - biguint_to_u64_vec((&modulus + BigUint::from_str("1").unwrap()) >> 2, limbs); + biguint_to_u64_vec((modulus + BigUint::from_str("1").unwrap()) >> 2, limbs); quote! { impl ::ff::SqrtField for #name { @@ -436,7 +434,7 @@ fn prime_field_constants_and_sqrt( } } } - } else if (&modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { + } else if (modulus % BigUint::from_str("16").unwrap()) == BigUint::from_str("1").unwrap() { let t_minus_1_over_2 = biguint_to_u64_vec((&t - BigUint::one()) >> 1, limbs); quote! { @@ -490,10 +488,10 @@ fn prime_field_constants_and_sqrt( }; // Compute R^2 mod m - let r2 = biguint_to_u64_vec((&r * &r) % &modulus, limbs); + let r2 = biguint_to_u64_vec((&r * &r) % modulus, limbs); let r = biguint_to_u64_vec(r, limbs); - let modulus = biguint_to_real_u64_vec(modulus, limbs); + let modulus = biguint_to_real_u64_vec(modulus.clone(), limbs); // Compute -m^-1 mod 2**64 by exponentiating by totient(2**64) - 1 let mut inv = 1u64;