Use internal functions in Refundable modifiers

This commit is contained in:
Amir Bandeali 2019-09-01 17:43:18 -07:00
parent fe01a150f0
commit 406a78a11a
5 changed files with 47 additions and 21 deletions

View File

@ -25,7 +25,7 @@ import "./LibRichErrors.sol";
contract ReentrancyGuard {
// Locked state of mutex.
bool private locked = false;
bool private _locked = false;
/// @dev Functions with this modifer cannot be reentered. The mutex will be locked
/// before function execution and unlocked after.
@ -39,19 +39,19 @@ contract ReentrancyGuard {
internal
{
// Ensure mutex is unlocked.
if (locked) {
if (_locked) {
LibRichErrors.rrevert(
LibReentrancyGuardRichErrors.IllegalReentrancyError()
);
}
// Lock mutex.
locked = true;
_locked = true;
}
function _unlockMutex()
internal
{
// Unlock mutex.
locked = false;
_locked = false;
}
}

View File

@ -22,27 +22,33 @@ pragma solidity ^0.5.9;
contract Refundable {
// This bool is used by the refund modifier to allow for lazily evaluated refunds.
bool internal shouldNotRefund;
bool internal _shouldNotRefund;
modifier refundFinalBalance {
_;
if (!shouldNotRefund) {
_refundNonzeroBalance();
}
_refundNonZeroBalanceIfEnabled();
}
modifier disableRefundUntilEnd {
if (shouldNotRefund) {
if (_areRefundsDisabled()) {
_;
} else {
shouldNotRefund = true;
_disableRefund();
_;
shouldNotRefund = false;
_refundNonzeroBalance();
_enableRefund();
_refundNonZeroBalance();
}
}
function _refundNonzeroBalance()
function _refundNonZeroBalanceIfEnabled()
internal
{
if (!_areRefundsDisabled()) {
_refundNonZeroBalance();
}
}
function _refundNonZeroBalance()
internal
{
uint256 balance = address(this).balance;
@ -50,4 +56,24 @@ contract Refundable {
msg.sender.transfer(balance);
}
}
function _disableRefund()
internal
{
_shouldNotRefund = true;
}
function _enableRefund()
internal
{
_shouldNotRefund = false;
}
function _areRefundsDisabled()
internal
view
returns (bool)
{
return _shouldNotRefund;
}
}

View File

@ -24,17 +24,17 @@ import "../src/Refundable.sol";
contract TestRefundable is
Refundable
{
function refundNonzeroBalanceExternal()
function refundNonZeroBalanceExternal()
external
payable
{
_refundNonzeroBalance();
_refundNonZeroBalance();
}
function setShouldNotRefund(bool shouldNotRefundNew)
external
{
shouldNotRefund = shouldNotRefundNew;
_shouldNotRefund = shouldNotRefundNew;
}
function getShouldNotRefund()
@ -42,7 +42,7 @@ contract TestRefundable is
view
returns (bool)
{
return shouldNotRefund;
return _shouldNotRefund;
}
function refundFinalBalanceFunction()

View File

@ -37,12 +37,12 @@ contract TestRefundableReceiver {
/// @dev This function tests the behavior of the `refundNonzeroBalance` function by checking whether or
/// not the `callCounter` state variable changes after the `refundNonzeroBalance` is called.
/// @param testRefundable The TestRefundable that should be tested against.
function testRefundNonzeroBalance(TestRefundable testRefundable)
function testRefundNonZeroBalance(TestRefundable testRefundable)
external
payable
{
// Call `refundNonzeroBalance()` and forward all of the eth sent to the contract.
testRefundable.refundNonzeroBalanceExternal.value(msg.value)();
testRefundable.refundNonZeroBalanceExternal.value(msg.value)();
// If the value sent was nonzero, a check that a refund was received will be executed. Otherwise, the fallback
// function contains a check that will fail in the event that a value of zero was sent to the contract.

View File

@ -35,14 +35,14 @@ blockchainTests('Refundable', env => {
blockchainTests.resets('refundNonzeroBalance', () => {
it('should not send a refund when no value is sent', async () => {
// Send 100 wei to the refundable contract that should be refunded.
await receiver.testRefundNonzeroBalance.awaitTransactionSuccessAsync(refundable.address, {
await receiver.testRefundNonZeroBalance.awaitTransactionSuccessAsync(refundable.address, {
value: constants.ZERO_AMOUNT,
});
});
it('should send a full refund when nonzero value is sent', async () => {
// Send 100 wei to the refundable contract that should be refunded.
await receiver.testRefundNonzeroBalance.awaitTransactionSuccessAsync(refundable.address, {
await receiver.testRefundNonZeroBalance.awaitTransactionSuccessAsync(refundable.address, {
value: ONE_HUNDRED,
});
});