Deduplicate Sapling key agreement logic

This commit is contained in:
Jack Grigg 2019-03-23 17:51:30 +13:00
parent 247f3fb038
commit 2b1583d75f
No known key found for this signature in database
GPG Key ID: 9E8255172BBF9898
3 changed files with 37 additions and 17 deletions

View File

@ -58,7 +58,7 @@ use std::ffi::OsString;
use std::os::windows::ffi::OsStringExt; use std::os::windows::ffi::OsStringExt;
use sapling_crypto::primitives::{ProofGenerationKey, ViewingKey}; use sapling_crypto::primitives::{ProofGenerationKey, ViewingKey};
use zcash_primitives::{sapling::spend_sig, JUBJUB}; use zcash_primitives::{note_encryption::sapling_ka_agree, sapling::spend_sig, JUBJUB};
use zcash_proofs::{ use zcash_proofs::{
load_parameters, load_parameters,
sapling::{CommitmentTreeWitness, SaplingProvingContext, SaplingVerificationContext}, sapling::{CommitmentTreeWitness, SaplingProvingContext, SaplingVerificationContext},
@ -536,15 +536,12 @@ pub extern "system" fn librustzcash_sapling_ka_agree(
Err(_) => return false, Err(_) => return false,
}; };
// Multiply by 8 // Compute key agreement
let p = p.mul_by_cofactor(&JUBJUB); let ka = sapling_ka_agree(&sk, &p);
// Multiply by sk
let p = p.mul(sk, &JUBJUB);
// Produce result // Produce result
let result = unsafe { &mut *result }; let result = unsafe { &mut *result };
p.write(&mut result[..]).expect("length is not 32 bytes"); result.copy_from_slice(&ka);
true true
} }

View File

@ -45,6 +45,14 @@ fn convert_subgroup<E: JubjubEngine, S1, S2>(from: &Point<E, S1>) -> Point<E, S2
} }
} }
impl<E: JubjubEngine> From<&Point<E, Unknown>> for Point<E, Unknown>
{
fn from(p: &Point<E, Unknown>) -> Point<E, Unknown>
{
p.clone()
}
}
impl<E: JubjubEngine> From<Point<E, PrimeOrder>> for Point<E, Unknown> impl<E: JubjubEngine> From<Point<E, PrimeOrder>> for Point<E, Unknown>
{ {
fn from(p: Point<E, PrimeOrder>) -> Point<E, Unknown> fn from(p: Point<E, PrimeOrder>) -> Point<E, Unknown>
@ -53,6 +61,14 @@ impl<E: JubjubEngine> From<Point<E, PrimeOrder>> for Point<E, Unknown>
} }
} }
impl<E: JubjubEngine> From<&Point<E, PrimeOrder>> for Point<E, Unknown>
{
fn from(p: &Point<E, PrimeOrder>) -> Point<E, Unknown>
{
convert_subgroup(p)
}
}
impl<E: JubjubEngine, Subgroup> Clone for Point<E, Subgroup> impl<E: JubjubEngine, Subgroup> Clone for Point<E, Subgroup>
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {

View File

@ -128,14 +128,21 @@ fn generate_esk() -> Fs {
Fs::to_uniform(&buffer[..]) Fs::to_uniform(&buffer[..])
} }
fn sapling_ka_agree(esk: &Fs, pk_d: &edwards::Point<Bls12, PrimeOrder>) -> Vec<u8> { pub fn sapling_ka_agree<'a, P>(esk: &Fs, pk_d: &'a P) -> [u8; 32]
let ka = pk_d where
.mul(esk.into_repr(), &JUBJUB) edwards::Point<Bls12, Unknown>: From<&'a P>,
.double(&JUBJUB) {
.double(&JUBJUB) let p: edwards::Point<Bls12, Unknown> = pk_d.into();
.double(&JUBJUB);
let mut result = Vec::with_capacity(32); // Multiply by 8
ka.write(&mut result).expect("length is not 32 bytes"); let p = p.mul_by_cofactor(&JUBJUB);
// Multiply by esk
let p = p.mul(*esk, &JUBJUB);
// Produce result
let mut result = [0; 32];
p.write(&mut result[..]).expect("length is not 32 bytes");
result result
} }
@ -294,7 +301,7 @@ pub fn try_sapling_note_decryption(
cmu: &Fr, cmu: &Fr,
enc_ciphertext: &[u8], enc_ciphertext: &[u8],
) -> Option<(Note<Bls12>, PaymentAddress<Bls12>, Memo)> { ) -> Option<(Note<Bls12>, PaymentAddress<Bls12>, Memo)> {
let shared_secret = sapling_ka_agree(&ivk, &epk); let shared_secret = sapling_ka_agree(ivk, epk);
let key = kdf_sapling(&shared_secret, &epk); let key = kdf_sapling(&shared_secret, &epk);
let mut plaintext = Vec::with_capacity(564); let mut plaintext = Vec::with_capacity(564);
@ -328,7 +335,7 @@ pub fn try_sapling_compact_note_decryption(
cmu: &Fr, cmu: &Fr,
enc_ciphertext: &[u8], enc_ciphertext: &[u8],
) -> Option<(Note<Bls12>, PaymentAddress<Bls12>)> { ) -> Option<(Note<Bls12>, PaymentAddress<Bls12>)> {
let shared_secret = sapling_ka_agree(&ivk, &epk); let shared_secret = sapling_ka_agree(ivk, epk);
let key = kdf_sapling(&shared_secret, &epk); let key = kdf_sapling(&shared_secret, &epk);
let nonce = [0u8; 12]; let nonce = [0u8; 12];