From 406a78a11af9d720df6bd95797c196e1399a6da1 Mon Sep 17 00:00:00 2001 From: Amir Bandeali Date: Sun, 1 Sep 2019 17:43:18 -0700 Subject: [PATCH] Use internal functions in Refundable modifiers --- .../utils/contracts/src/ReentrancyGuard.sol | 8 ++-- contracts/utils/contracts/src/Refundable.sol | 44 +++++++++++++++---- .../utils/contracts/test/TestRefundable.sol | 8 ++-- .../contracts/test/TestRefundableReceiver.sol | 4 +- contracts/utils/test/refundable.ts | 4 +- 5 files changed, 47 insertions(+), 21 deletions(-) diff --git a/contracts/utils/contracts/src/ReentrancyGuard.sol b/contracts/utils/contracts/src/ReentrancyGuard.sol index 78dc68e0ea..273483e3c5 100644 --- a/contracts/utils/contracts/src/ReentrancyGuard.sol +++ b/contracts/utils/contracts/src/ReentrancyGuard.sol @@ -25,7 +25,7 @@ import "./LibRichErrors.sol"; contract ReentrancyGuard { // Locked state of mutex. - bool private locked = false; + bool private _locked = false; /// @dev Functions with this modifer cannot be reentered. The mutex will be locked /// before function execution and unlocked after. @@ -39,19 +39,19 @@ contract ReentrancyGuard { internal { // Ensure mutex is unlocked. - if (locked) { + if (_locked) { LibRichErrors.rrevert( LibReentrancyGuardRichErrors.IllegalReentrancyError() ); } // Lock mutex. - locked = true; + _locked = true; } function _unlockMutex() internal { // Unlock mutex. - locked = false; + _locked = false; } } diff --git a/contracts/utils/contracts/src/Refundable.sol b/contracts/utils/contracts/src/Refundable.sol index 091c62aed6..175773aee1 100644 --- a/contracts/utils/contracts/src/Refundable.sol +++ b/contracts/utils/contracts/src/Refundable.sol @@ -22,27 +22,33 @@ pragma solidity ^0.5.9; contract Refundable { // This bool is used by the refund modifier to allow for lazily evaluated refunds. - bool internal shouldNotRefund; + bool internal _shouldNotRefund; modifier refundFinalBalance { _; - if (!shouldNotRefund) { - _refundNonzeroBalance(); - } + _refundNonZeroBalanceIfEnabled(); } modifier disableRefundUntilEnd { - if (shouldNotRefund) { + if (_areRefundsDisabled()) { _; } else { - shouldNotRefund = true; + _disableRefund(); _; - shouldNotRefund = false; - _refundNonzeroBalance(); + _enableRefund(); + _refundNonZeroBalance(); } } - function _refundNonzeroBalance() + function _refundNonZeroBalanceIfEnabled() + internal + { + if (!_areRefundsDisabled()) { + _refundNonZeroBalance(); + } + } + + function _refundNonZeroBalance() internal { uint256 balance = address(this).balance; @@ -50,4 +56,24 @@ contract Refundable { msg.sender.transfer(balance); } } + + function _disableRefund() + internal + { + _shouldNotRefund = true; + } + + function _enableRefund() + internal + { + _shouldNotRefund = false; + } + + function _areRefundsDisabled() + internal + view + returns (bool) + { + return _shouldNotRefund; + } } diff --git a/contracts/utils/contracts/test/TestRefundable.sol b/contracts/utils/contracts/test/TestRefundable.sol index a86cd67cab..86df80185f 100644 --- a/contracts/utils/contracts/test/TestRefundable.sol +++ b/contracts/utils/contracts/test/TestRefundable.sol @@ -24,17 +24,17 @@ import "../src/Refundable.sol"; contract TestRefundable is Refundable { - function refundNonzeroBalanceExternal() + function refundNonZeroBalanceExternal() external payable { - _refundNonzeroBalance(); + _refundNonZeroBalance(); } function setShouldNotRefund(bool shouldNotRefundNew) external { - shouldNotRefund = shouldNotRefundNew; + _shouldNotRefund = shouldNotRefundNew; } function getShouldNotRefund() @@ -42,7 +42,7 @@ contract TestRefundable is view returns (bool) { - return shouldNotRefund; + return _shouldNotRefund; } function refundFinalBalanceFunction() diff --git a/contracts/utils/contracts/test/TestRefundableReceiver.sol b/contracts/utils/contracts/test/TestRefundableReceiver.sol index 046fe69833..37cc4d7ae5 100644 --- a/contracts/utils/contracts/test/TestRefundableReceiver.sol +++ b/contracts/utils/contracts/test/TestRefundableReceiver.sol @@ -37,12 +37,12 @@ contract TestRefundableReceiver { /// @dev This function tests the behavior of the `refundNonzeroBalance` function by checking whether or /// not the `callCounter` state variable changes after the `refundNonzeroBalance` is called. /// @param testRefundable The TestRefundable that should be tested against. - function testRefundNonzeroBalance(TestRefundable testRefundable) + function testRefundNonZeroBalance(TestRefundable testRefundable) external payable { // Call `refundNonzeroBalance()` and forward all of the eth sent to the contract. - testRefundable.refundNonzeroBalanceExternal.value(msg.value)(); + testRefundable.refundNonZeroBalanceExternal.value(msg.value)(); // If the value sent was nonzero, a check that a refund was received will be executed. Otherwise, the fallback // function contains a check that will fail in the event that a value of zero was sent to the contract. diff --git a/contracts/utils/test/refundable.ts b/contracts/utils/test/refundable.ts index c295e26c7d..8411716231 100644 --- a/contracts/utils/test/refundable.ts +++ b/contracts/utils/test/refundable.ts @@ -35,14 +35,14 @@ blockchainTests('Refundable', env => { blockchainTests.resets('refundNonzeroBalance', () => { it('should not send a refund when no value is sent', async () => { // Send 100 wei to the refundable contract that should be refunded. - await receiver.testRefundNonzeroBalance.awaitTransactionSuccessAsync(refundable.address, { + await receiver.testRefundNonZeroBalance.awaitTransactionSuccessAsync(refundable.address, { value: constants.ZERO_AMOUNT, }); }); it('should send a full refund when nonzero value is sent', async () => { // Send 100 wei to the refundable contract that should be refunded. - await receiver.testRefundNonzeroBalance.awaitTransactionSuccessAsync(refundable.address, { + await receiver.testRefundNonZeroBalance.awaitTransactionSuccessAsync(refundable.address, { value: ONE_HUNDRED, }); });