diff --git a/contracts/exchange-libs/CHANGELOG.json b/contracts/exchange-libs/CHANGELOG.json index f8ce90f4df..3fdc4345ba 100644 --- a/contracts/exchange-libs/CHANGELOG.json +++ b/contracts/exchange-libs/CHANGELOG.json @@ -57,6 +57,14 @@ { "note": "Add reference functions for `LibMath` and `LibFillResults`", "pr": "TODO" + }, + { + "note": "Bring in revamped `LibMath` tests from the `contracts-exchange` package.", + "pr": "TODO" + }, + { + "note": "Remove unecessary zero-denominator checks in `LibMath`.", + "pr": "TODO" } ] }, diff --git a/contracts/exchange-libs/contracts/src/LibMath.sol b/contracts/exchange-libs/contracts/src/LibMath.sol index 6cc082b052..bc145499af 100644 --- a/contracts/exchange-libs/contracts/src/LibMath.sol +++ b/contracts/exchange-libs/contracts/src/LibMath.sol @@ -41,10 +41,6 @@ contract LibMath is pure returns (uint256 partialAmount) { - if (denominator == 0) { - LibRichErrors._rrevert(LibMathRichErrors.DivisionByZeroError()); - } - if (_isRoundingErrorFloor( numerator, denominator, @@ -79,10 +75,6 @@ contract LibMath is pure returns (uint256 partialAmount) { - if (denominator == 0) { - LibRichErrors._rrevert(LibMathRichErrors.DivisionByZeroError()); - } - if (_isRoundingErrorCeil( numerator, denominator, @@ -122,10 +114,6 @@ contract LibMath is pure returns (uint256 partialAmount) { - if (denominator == 0) { - LibRichErrors._rrevert(LibMathRichErrors.DivisionByZeroError()); - } - partialAmount = _safeDiv( _safeMul(numerator, target), denominator @@ -147,10 +135,6 @@ contract LibMath is pure returns (uint256 partialAmount) { - if (denominator == 0) { - LibRichErrors._rrevert(LibMathRichErrors.DivisionByZeroError()); - } - // _safeDiv computes `floor(a / b)`. We use the identity (a, b integer): // ceil(a / b) = floor((a + b - 1) / b) // To implement `ceil(a / b)` using _safeDiv. diff --git a/contracts/exchange-libs/contracts/test/TestLibs.sol b/contracts/exchange-libs/contracts/test/TestLibs.sol index 8f2925af83..43beae3161 100644 --- a/contracts/exchange-libs/contracts/test/TestLibs.sol +++ b/contracts/exchange-libs/contracts/test/TestLibs.sol @@ -73,6 +73,40 @@ contract TestLibs is return partialAmount; } + function safeGetPartialAmountFloor( + uint256 numerator, + uint256 denominator, + uint256 target + ) + public + pure + returns (uint256 partialAmount) + { + partialAmount = _safeGetPartialAmountFloor( + numerator, + denominator, + target + ); + return partialAmount; + } + + function safeGetPartialAmountCeil( + uint256 numerator, + uint256 denominator, + uint256 target + ) + public + pure + returns (uint256 partialAmount) + { + partialAmount = _safeGetPartialAmountCeil( + numerator, + denominator, + target + ); + return partialAmount; + } + function isRoundingErrorFloor( uint256 numerator, uint256 denominator, diff --git a/contracts/exchange-libs/src/reference_functions.ts b/contracts/exchange-libs/src/reference_functions.ts index 6a41601166..8585f0a64c 100644 --- a/contracts/exchange-libs/src/reference_functions.ts +++ b/contracts/exchange-libs/src/reference_functions.ts @@ -21,8 +21,11 @@ export function isRoundingErrorFloor( if (numerator.eq(0) || target.eq(0)) { return false; } - const remainder = numerator.multipliedBy(target).mod(denominator); - return safeMul(new BigNumber(1000), remainder).gte(safeMul(numerator, target)); + const remainder = numerator.times(target).mod(denominator); + // Need to do this separately because solidity evaluates RHS of the comparison expression first. + const rhs = safeMul(numerator, target); + const lhs = safeMul(new BigNumber(1000), remainder); + return lhs.gte(rhs); } export function isRoundingErrorCeil( @@ -36,9 +39,12 @@ export function isRoundingErrorCeil( if (numerator.eq(0) || target.eq(0)) { return false; } - let remainder = numerator.multipliedBy(target).mod(denominator); + let remainder = numerator.times(target).mod(denominator); remainder = safeSub(denominator, remainder).mod(denominator); - return safeMul(new BigNumber(1000), remainder).gte(safeMul(numerator, target)); + // Need to do this separately because solidity evaluates RHS of the comparison expression first. + const rhs = safeMul(numerator, target); + const lhs = safeMul(new BigNumber(1000), remainder); + return lhs.gte(rhs); } export function safeGetPartialAmountFloor( @@ -46,9 +52,6 @@ export function safeGetPartialAmountFloor( denominator: BigNumber, target: BigNumber, ): BigNumber { - if (denominator.eq(0)) { - throw new LibMathRevertErrors.DivisionByZeroError(); - } if (isRoundingErrorFloor(numerator, denominator, target)) { throw new LibMathRevertErrors.RoundingError(numerator, denominator, target); } @@ -63,9 +66,6 @@ export function safeGetPartialAmountCeil( denominator: BigNumber, target: BigNumber, ): BigNumber { - if (denominator.eq(0)) { - throw new LibMathRevertErrors.DivisionByZeroError(); - } if (isRoundingErrorCeil(numerator, denominator, target)) { throw new LibMathRevertErrors.RoundingError(numerator, denominator, target); } @@ -83,9 +83,6 @@ export function getPartialAmountFloor( denominator: BigNumber, target: BigNumber, ): BigNumber { - if (denominator.eq(0)) { - throw new LibMathRevertErrors.DivisionByZeroError(); - } return safeDiv( safeMul(numerator, target), denominator, @@ -97,9 +94,6 @@ export function getPartialAmountCeil( denominator: BigNumber, target: BigNumber, ): BigNumber { - if (denominator.eq(0)) { - throw new LibMathRevertErrors.DivisionByZeroError(); - } return safeDiv( safeAdd( safeMul(numerator, target), diff --git a/contracts/exchange-libs/test/lib_math.ts b/contracts/exchange-libs/test/lib_math.ts new file mode 100644 index 0000000000..92da9baecc --- /dev/null +++ b/contracts/exchange-libs/test/lib_math.ts @@ -0,0 +1,109 @@ +import { + blockchainTests, + describe, + testCombinatoriallyWithReferenceFunc, + uint256Values, +} from '@0x/contracts-test-utils'; +import { BigNumber } from '@0x/utils'; +import * as _ from 'lodash'; + +import { artifacts, ReferenceFunctions, TestLibsContract } from '../src'; + +const CHAIN_ID = 1337; + +blockchainTests('LibMath', env => { + let libsContract: TestLibsContract; + + before(async () => { + libsContract = await TestLibsContract.deployFrom0xArtifactAsync( + artifacts.TestLibs, + env.provider, + env.txDefaults, + new BigNumber(CHAIN_ID), + ); + }); + + // Wrap a reference function with identical arguments in a promise. + function createAsyncReferenceFunction( + ref: (...args: any[]) => T, + ): (...args: any[]) => Promise { + return async (...args: any[]): Promise => { + return ref(...args); + }; + } + + function createContractTestFunction( + name: string, + ): (...args: any[]) => Promise { + return async (...args: any[]): Promise => { + const method = (libsContract as any)[name] as { callAsync: (...args: any[]) => Promise }; + return method.callAsync(...args); + }; + } + + describe('getPartialAmountFloor', () => { + describe.optional('combinatorial tests', () => { + testCombinatoriallyWithReferenceFunc( + 'getPartialAmountFloor', + createAsyncReferenceFunction(ReferenceFunctions.getPartialAmountFloor), + createContractTestFunction('getPartialAmountFloor'), + [uint256Values, uint256Values, uint256Values], + ); + }); + }); + + describe('getPartialAmountCeil', () => { + describe.optional('combinatorial tests', () => { + testCombinatoriallyWithReferenceFunc( + 'getPartialAmountCeil', + createAsyncReferenceFunction(ReferenceFunctions.getPartialAmountCeil), + createContractTestFunction('getPartialAmountCeil'), + [uint256Values, uint256Values, uint256Values], + ); + }); + }); + + describe('safeGetPartialAmountFloor', () => { + describe.optional('combinatorial tests', () => { + testCombinatoriallyWithReferenceFunc( + 'safeGetPartialAmountFloor', + createAsyncReferenceFunction(ReferenceFunctions.safeGetPartialAmountFloor), + createContractTestFunction('safeGetPartialAmountFloor'), + [uint256Values, uint256Values, uint256Values], + ); + }); + }); + + describe('safeGetPartialAmountCeil', () => { + describe.optional('combinatorial tests', () => { + testCombinatoriallyWithReferenceFunc( + 'safeGetPartialAmountCeil', + createAsyncReferenceFunction(ReferenceFunctions.safeGetPartialAmountCeil), + createContractTestFunction('safeGetPartialAmountCeil'), + [uint256Values, uint256Values, uint256Values], + ); + }); + }); + + describe('isRoundingErrorFloor', () => { + describe.optional('combinatorial tests', () => { + testCombinatoriallyWithReferenceFunc( + 'isRoundingErrorFloor', + createAsyncReferenceFunction(ReferenceFunctions.isRoundingErrorFloor), + createContractTestFunction('isRoundingErrorFloor'), + [uint256Values, uint256Values, uint256Values], + ); + }); + }); + + describe('isRoundingErrorCeil', () => { + describe.optional('combinatorial tests', () => { + testCombinatoriallyWithReferenceFunc( + 'isRoundingErrorCeil', + createAsyncReferenceFunction(ReferenceFunctions.isRoundingErrorCeil), + createContractTestFunction('isRoundingErrorCeil'), + [uint256Values, uint256Values, uint256Values], + ); + }); + }); +});