Use internal functions in Refundable modifiers

This commit is contained in:
Amir Bandeali 2019-09-01 17:43:18 -07:00
parent fe01a150f0
commit 406a78a11a
5 changed files with 47 additions and 21 deletions

View File

@ -25,7 +25,7 @@ import "./LibRichErrors.sol";
contract ReentrancyGuard { contract ReentrancyGuard {
// Locked state of mutex. // Locked state of mutex.
bool private locked = false; bool private _locked = false;
/// @dev Functions with this modifer cannot be reentered. The mutex will be locked /// @dev Functions with this modifer cannot be reentered. The mutex will be locked
/// before function execution and unlocked after. /// before function execution and unlocked after.
@ -39,19 +39,19 @@ contract ReentrancyGuard {
internal internal
{ {
// Ensure mutex is unlocked. // Ensure mutex is unlocked.
if (locked) { if (_locked) {
LibRichErrors.rrevert( LibRichErrors.rrevert(
LibReentrancyGuardRichErrors.IllegalReentrancyError() LibReentrancyGuardRichErrors.IllegalReentrancyError()
); );
} }
// Lock mutex. // Lock mutex.
locked = true; _locked = true;
} }
function _unlockMutex() function _unlockMutex()
internal internal
{ {
// Unlock mutex. // Unlock mutex.
locked = false; _locked = false;
} }
} }

View File

@ -22,27 +22,33 @@ pragma solidity ^0.5.9;
contract Refundable { contract Refundable {
// This bool is used by the refund modifier to allow for lazily evaluated refunds. // This bool is used by the refund modifier to allow for lazily evaluated refunds.
bool internal shouldNotRefund; bool internal _shouldNotRefund;
modifier refundFinalBalance { modifier refundFinalBalance {
_; _;
if (!shouldNotRefund) { _refundNonZeroBalanceIfEnabled();
_refundNonzeroBalance();
}
} }
modifier disableRefundUntilEnd { modifier disableRefundUntilEnd {
if (shouldNotRefund) { if (_areRefundsDisabled()) {
_; _;
} else { } else {
shouldNotRefund = true; _disableRefund();
_; _;
shouldNotRefund = false; _enableRefund();
_refundNonzeroBalance(); _refundNonZeroBalance();
} }
} }
function _refundNonzeroBalance() function _refundNonZeroBalanceIfEnabled()
internal
{
if (!_areRefundsDisabled()) {
_refundNonZeroBalance();
}
}
function _refundNonZeroBalance()
internal internal
{ {
uint256 balance = address(this).balance; uint256 balance = address(this).balance;
@ -50,4 +56,24 @@ contract Refundable {
msg.sender.transfer(balance); msg.sender.transfer(balance);
} }
} }
function _disableRefund()
internal
{
_shouldNotRefund = true;
}
function _enableRefund()
internal
{
_shouldNotRefund = false;
}
function _areRefundsDisabled()
internal
view
returns (bool)
{
return _shouldNotRefund;
}
} }

View File

@ -24,17 +24,17 @@ import "../src/Refundable.sol";
contract TestRefundable is contract TestRefundable is
Refundable Refundable
{ {
function refundNonzeroBalanceExternal() function refundNonZeroBalanceExternal()
external external
payable payable
{ {
_refundNonzeroBalance(); _refundNonZeroBalance();
} }
function setShouldNotRefund(bool shouldNotRefundNew) function setShouldNotRefund(bool shouldNotRefundNew)
external external
{ {
shouldNotRefund = shouldNotRefundNew; _shouldNotRefund = shouldNotRefundNew;
} }
function getShouldNotRefund() function getShouldNotRefund()
@ -42,7 +42,7 @@ contract TestRefundable is
view view
returns (bool) returns (bool)
{ {
return shouldNotRefund; return _shouldNotRefund;
} }
function refundFinalBalanceFunction() function refundFinalBalanceFunction()

View File

@ -37,12 +37,12 @@ contract TestRefundableReceiver {
/// @dev This function tests the behavior of the `refundNonzeroBalance` function by checking whether or /// @dev This function tests the behavior of the `refundNonzeroBalance` function by checking whether or
/// not the `callCounter` state variable changes after the `refundNonzeroBalance` is called. /// not the `callCounter` state variable changes after the `refundNonzeroBalance` is called.
/// @param testRefundable The TestRefundable that should be tested against. /// @param testRefundable The TestRefundable that should be tested against.
function testRefundNonzeroBalance(TestRefundable testRefundable) function testRefundNonZeroBalance(TestRefundable testRefundable)
external external
payable payable
{ {
// Call `refundNonzeroBalance()` and forward all of the eth sent to the contract. // Call `refundNonzeroBalance()` and forward all of the eth sent to the contract.
testRefundable.refundNonzeroBalanceExternal.value(msg.value)(); testRefundable.refundNonZeroBalanceExternal.value(msg.value)();
// If the value sent was nonzero, a check that a refund was received will be executed. Otherwise, the fallback // If the value sent was nonzero, a check that a refund was received will be executed. Otherwise, the fallback
// function contains a check that will fail in the event that a value of zero was sent to the contract. // function contains a check that will fail in the event that a value of zero was sent to the contract.

View File

@ -35,14 +35,14 @@ blockchainTests('Refundable', env => {
blockchainTests.resets('refundNonzeroBalance', () => { blockchainTests.resets('refundNonzeroBalance', () => {
it('should not send a refund when no value is sent', async () => { it('should not send a refund when no value is sent', async () => {
// Send 100 wei to the refundable contract that should be refunded. // Send 100 wei to the refundable contract that should be refunded.
await receiver.testRefundNonzeroBalance.awaitTransactionSuccessAsync(refundable.address, { await receiver.testRefundNonZeroBalance.awaitTransactionSuccessAsync(refundable.address, {
value: constants.ZERO_AMOUNT, value: constants.ZERO_AMOUNT,
}); });
}); });
it('should send a full refund when nonzero value is sent', async () => { it('should send a full refund when nonzero value is sent', async () => {
// Send 100 wei to the refundable contract that should be refunded. // Send 100 wei to the refundable contract that should be refunded.
await receiver.testRefundNonzeroBalance.awaitTransactionSuccessAsync(refundable.address, { await receiver.testRefundNonZeroBalance.awaitTransactionSuccessAsync(refundable.address, {
value: ONE_HUNDRED, value: ONE_HUNDRED,
}); });
}); });