diff --git a/core/src/main/java/com/google/bitcoin/protocols/channels/PaymentChannelServerState.java b/core/src/main/java/com/google/bitcoin/protocols/channels/PaymentChannelServerState.java index 5287940e..f672cf76 100644 --- a/core/src/main/java/com/google/bitcoin/protocols/channels/PaymentChannelServerState.java +++ b/core/src/main/java/com/google/bitcoin/protocols/channels/PaymentChannelServerState.java @@ -343,9 +343,9 @@ public class PaymentChannelServerState { storedServerChannel = null; StoredPaymentChannelServerStates channels = (StoredPaymentChannelServerStates) wallet.getExtensions().get(StoredPaymentChannelServerStates.EXTENSION_ID); - channels.closeChannel(temp); // Calls this method again for us - checkState(state.compareTo(State.CLOSING) >= 0); - return closedFuture; + channels.closeChannel(temp); // May call this method again for us (if it wasn't the original caller) + if (state.compareTo(State.CLOSING) >= 0) + return closedFuture; } if (state.ordinal() < State.READY.ordinal()) { diff --git a/core/src/main/java/com/google/bitcoin/protocols/channels/StoredPaymentChannelServerStates.java b/core/src/main/java/com/google/bitcoin/protocols/channels/StoredPaymentChannelServerStates.java index 7add1d8d..be3344d1 100644 --- a/core/src/main/java/com/google/bitcoin/protocols/channels/StoredPaymentChannelServerStates.java +++ b/core/src/main/java/com/google/bitcoin/protocols/channels/StoredPaymentChannelServerStates.java @@ -18,9 +18,12 @@ package com.google.bitcoin.protocols.channels; import java.io.*; import java.util.*; +import java.util.concurrent.locks.ReentrantLock; import com.google.bitcoin.core.*; +import com.google.bitcoin.utils.Locks; import com.google.common.annotations.VisibleForTesting; +import net.jcip.annotations.GuardedBy; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkNotNull; @@ -32,12 +35,14 @@ import static com.google.common.base.Preconditions.checkNotNull; public class StoredPaymentChannelServerStates implements WalletExtension { static final String EXTENSION_ID = StoredPaymentChannelServerStates.class.getName(); - @VisibleForTesting final Map mapChannels = new HashMap(); + @GuardedBy("lock") @VisibleForTesting final Map mapChannels = new HashMap(); private final Wallet wallet; private final TransactionBroadcaster broadcaster; private final Timer channelTimeoutHandler = new Timer(); + private final ReentrantLock lock = Locks.lock("StoredPaymentChannelServerStates"); + /** * The offset between the refund transaction's lock time and the time channels will be automatically closed. * This defines a window during which we must get the last payment transaction verified, ie it should allow time for @@ -64,19 +69,25 @@ public class StoredPaymentChannelServerStates implements WalletExtension { *

Removes the given channel from this set of {@link StoredServerChannel}s and notifies the wallet of a change to * this wallet extension.

*/ - public synchronized void closeChannel(StoredServerChannel channel) { + public void closeChannel(StoredServerChannel channel) { + lock.lock(); + try { + if (mapChannels.remove(channel.contract.getHash()) == null) + return; + } finally { + lock.unlock(); + } synchronized (channel) { - if (channel.connectedHandler != null) - channel.connectedHandler.close(); // connectedHandler will be reset to null in connectionClosed + if (channel.connectedHandler != null) // connectedHandler will be reset to null in connectionClosed + channel.connectedHandler.close(); // Closes the actual connection, not the channel try {//TODO add event listener to PaymentChannelServerStateManager - channel.getState(wallet, broadcaster).close(); // Closes the actual connection, not the channel + channel.getState(wallet, broadcaster).close(); } catch (ValueOutOfRangeException e) { - e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. + e.printStackTrace(); } catch (VerificationException e) { - e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. + e.printStackTrace(); } channel.state = null; - mapChannels.remove(channel.contract.getHash()); } wallet.addOrUpdateExtension(this); } @@ -84,8 +95,13 @@ public class StoredPaymentChannelServerStates implements WalletExtension { /** * Gets the {@link StoredServerChannel} with the given channel id (ie contract transaction hash). */ - public synchronized StoredServerChannel getChannel(Sha256Hash id) { - return mapChannels.get(id); + public StoredServerChannel getChannel(Sha256Hash id) { + lock.lock(); + try { + return mapChannels.get(id); + } finally { + lock.unlock(); + } } /** @@ -95,16 +111,21 @@ public class StoredPaymentChannelServerStates implements WalletExtension { *

Because there must be only one, canonical {@link StoredServerChannel} per channel, this method throws if the * channel is already present in the set of channels.

*/ - public synchronized void putChannel(final StoredServerChannel channel) { - checkArgument(mapChannels.put(channel.contract.getHash(), checkNotNull(channel)) == null); - channelTimeoutHandler.schedule(new TimerTask() { - @Override - public void run() { - closeChannel(channel); - } - // Add the difference between real time and Utils.now() so that test-cases can use a mock clock. - }, new Date((channel.refundTransactionUnlockTimeSecs + CHANNEL_EXPIRE_OFFSET)*1000L - + (System.currentTimeMillis() - Utils.now().getTime()))); + public void putChannel(final StoredServerChannel channel) { + lock.lock(); + try { + checkArgument(mapChannels.put(channel.contract.getHash(), checkNotNull(channel)) == null); + channelTimeoutHandler.schedule(new TimerTask() { + @Override + public void run() { + closeChannel(channel); + } + // Add the difference between real time and Utils.now() so that test-cases can use a mock clock. + }, new Date((channel.refundTransactionUnlockTimeSecs + CHANNEL_EXPIRE_OFFSET)*1000L + + (System.currentTimeMillis() - Utils.now().getTime()))); + } finally { + lock.unlock(); + } } @Override @@ -118,7 +139,8 @@ public class StoredPaymentChannelServerStates implements WalletExtension { } @Override - public synchronized byte[] serializeWalletExtension() { + public byte[] serializeWalletExtension() { + lock.lock(); try { ByteArrayOutputStream out = new ByteArrayOutputStream(); ObjectOutputStream oos = new ObjectOutputStream(out); @@ -128,17 +150,24 @@ public class StoredPaymentChannelServerStates implements WalletExtension { return out.toByteArray(); } catch (IOException e) { throw new RuntimeException(e); + } finally { + lock.unlock(); } } @Override - public synchronized void deserializeWalletExtension(Wallet containingWallet, byte[] data) throws Exception { - checkArgument(containingWallet == wallet); - ByteArrayInputStream inStream = new ByteArrayInputStream(data); - ObjectInputStream ois = new ObjectInputStream(inStream); - while (inStream.available() > 0) { - StoredServerChannel channel = (StoredServerChannel)ois.readObject(); - putChannel(channel); + public void deserializeWalletExtension(Wallet containingWallet, byte[] data) throws Exception { + lock.lock(); + try { + checkArgument(containingWallet == wallet); + ByteArrayInputStream inStream = new ByteArrayInputStream(data); + ObjectInputStream ois = new ObjectInputStream(inStream); + while (inStream.available() > 0) { + StoredServerChannel channel = (StoredServerChannel)ois.readObject(); + putChannel(channel); + } + } finally { + lock.unlock(); } } } diff --git a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java index 400d2f11..1b43003a 100644 --- a/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java +++ b/core/src/main/java/com/google/bitcoin/protocols/niowrapper/ProtobufServer.java @@ -25,7 +25,9 @@ import java.nio.channels.ServerSocketChannel; import java.nio.channels.SocketChannel; import java.nio.channels.spi.SelectorProvider; import java.util.Iterator; +import java.util.concurrent.locks.ReentrantLock; +import com.google.bitcoin.utils.Locks; import com.google.common.annotations.VisibleForTesting; import org.slf4j.LoggerFactory; @@ -48,6 +50,7 @@ public class ProtobufServer { private static final int BUFFER_SIZE_UPPER_BOUND = 65536; private class ConnectionHandler extends MessageWriteTarget { + private final ReentrantLock lock = Locks.lock("protobufServerConnectionHandler"); private final ByteBuffer dbuf; private final SocketChannel channel; private final ProtobufParser parser; @@ -66,13 +69,16 @@ public class ProtobufServer { } @Override - synchronized void writeBytes(byte[] message) { + void writeBytes(byte[] message) { + lock.lock(); try { if (channel.write(ByteBuffer.wrap(message)) != message.length) throw new IOException("Couldn't write all of message to socket"); } catch (IOException e) { log.error("Error writing message to connection, closing connection", e); closeConnection(); + } finally { + lock.unlock(); } } @@ -86,10 +92,17 @@ public class ProtobufServer { connectionClosed(); } - private synchronized void connectionClosed() { - if (!closeCalled) + private void connectionClosed() { + boolean callClosed = false; + lock.lock(); + try { + callClosed = !closeCalled; + closeCalled = true; + } finally { + lock.unlock(); + } + if (callClosed) parser.connectionClosed(); - closeCalled = true; } }