diff --git a/Sources/NIOTransportServices/NIOTSConnectionChannel.swift b/Sources/NIOTransportServices/NIOTSConnectionChannel.swift index 6b540c5..f8eadad 100644 --- a/Sources/NIOTransportServices/NIOTSConnectionChannel.swift +++ b/Sources/NIOTransportServices/NIOTSConnectionChannel.swift @@ -67,6 +67,18 @@ private struct ConnectionChannelOptions { private typealias PendingWrite = (data: ByteBuffer, promise: EventLoopPromise?) +internal struct AddressCache { + // deliberately lets because they must always be updated together (so forcing `init` is useful). + let local: Optional + let remote: Optional + + init(local: SocketAddress?, remote: SocketAddress?) { + self.local = local + self.remote = remote + } +} + + /// A structure that manages backpressure signaling on this channel. @available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) private struct BackpressureManager { @@ -211,6 +223,12 @@ internal final class NIOTSConnectionChannel { /// Whether to use peer-to-peer connectivity when connecting to Bonjour services. private var enablePeerToPeer = false + /// The cache of the local and remote socket addresses. Must be accessed using _addressCacheLock. + private var _addressCache = AddressCache(local: nil, remote: nil) + + /// A lock that guards the _addressCache. + private let _addressCacheLock = Lock() + /// Create a `NIOTSConnectionChannel` on a given `NIOTSEventLoop`. /// /// Note that `NIOTSConnectionChannel` objects cannot be created on arbitrary loops types. @@ -257,19 +275,15 @@ extension NIOTSConnectionChannel: Channel { /// The local address for this channel. public var localAddress: SocketAddress? { - if self.eventLoop.inEventLoop { - return try? self.localAddress0() - } else { - return self.connectionQueue.sync { try? self.localAddress0() } + return self._addressCacheLock.withLock { + return self._addressCache.local } } /// The remote address for this channel. public var remoteAddress: SocketAddress? { - if self.eventLoop.inEventLoop { - return try? self.remoteAddress0() - } else { - return self.connectionQueue.sync { try? self.remoteAddress0() } + return self._addressCacheLock.withLock { + return self._addressCache.remote } } @@ -748,6 +762,15 @@ extension NIOTSConnectionChannel { private func connectionComplete0() { let promise = self.connectPromise self.connectPromise = nil + + // Before becoming active, update the cached addresses. + let localAddress = try? self.localAddress0() + let remoteAddress = try? self.remoteAddress0() + + self._addressCacheLock.withLock { + self._addressCache = AddressCache(local: localAddress, remote: remoteAddress) + } + self.becomeActive0(promise: promise) if let metadata = self.nwConnection?.metadata(definition: NWProtocolTLS.definition) as? NWProtocolTLS.Metadata { diff --git a/Sources/NIOTransportServices/NIOTSListenerChannel.swift b/Sources/NIOTransportServices/NIOTSListenerChannel.swift index 50af01f..b7001e8 100644 --- a/Sources/NIOTransportServices/NIOTSListenerChannel.swift +++ b/Sources/NIOTransportServices/NIOTSListenerChannel.swift @@ -95,6 +95,12 @@ internal final class NIOTSListenerChannel { /// The TLS options to use for child channels. private let childTLSOptions: NWProtocolTLS.Options? + /// The cache of the local and remote socket addresses. Must be accessed using _addressCacheLock. + private var _addressCache = AddressCache(local: nil, remote: nil) + + /// A lock that guards the _addressCache. + private let _addressCacheLock = Lock() + /// Create a `NIOTSListenerChannel` on a given `NIOTSEventLoop`. /// @@ -133,19 +139,15 @@ extension NIOTSListenerChannel: Channel { /// The local address for this channel. public var localAddress: SocketAddress? { - if self.eventLoop.inEventLoop { - return try? self.localAddress0() - } else { - return self.connectionQueue.sync { try? self.localAddress0() } + return self._addressCacheLock.withLock { + return self._addressCache.local } } /// The remote address for this channel. public var remoteAddress: SocketAddress? { - if self.eventLoop.inEventLoop { - return try? self.remoteAddress0() - } else { - return self.connectionQueue.sync { try? self.remoteAddress0() } + return self._addressCacheLock.withLock { + return self._addressCache.remote } } @@ -456,6 +458,14 @@ extension NIOTSListenerChannel { private func bindComplete0() { let promise = self.bindPromise self.bindPromise = nil + + // Before becoming active, update the cached addresses. Remote is always nil. + let localAddress = try? self.localAddress0() + + self._addressCacheLock.withLock { + self._addressCache = AddressCache(local: localAddress, remote: nil) + } + self.becomeActive0(promise: promise) } } diff --git a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift index 4af147c..5602c42 100644 --- a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift @@ -710,8 +710,32 @@ class NIOTSConnectionChannelTests: XCTestCase { XCTAssertNoThrow(try connection.eventLoop.submit { XCTAssertEqual(testHandler.readCount, 2) }.wait()) + } - + func testLoadingAddressesInMultipleQueues() throws { + let listener = try NIOTSListenerBootstrap(group: self.group) + .bind(host: "localhost", port: 0).wait() + defer { + XCTAssertNoThrow(try listener.close().wait()) + } + + let ourSyncQueue = DispatchQueue(label: "ourSyncQueue") + + let workFuture = NIOTSConnectionBootstrap(group: self.group).connect(to: listener.localAddress!).map { channel -> Channel in + XCTAssertTrue(channel.eventLoop.inEventLoop) + + ourSyncQueue.sync { + XCTAssertFalse(channel.eventLoop.inEventLoop) + + // These will crash before we apply our fix. + XCTAssertNotNil(channel.localAddress) + XCTAssertNotNil(channel.remoteAddress) + } + + return channel + }.flatMap { $0.close() } + + XCTAssertNoThrow(try workFuture.wait()) } } #endif diff --git a/Tests/NIOTransportServicesTests/NIOTSEndToEndTests.swift b/Tests/NIOTransportServicesTests/NIOTSEndToEndTests.swift index ab8840e..337d9bb 100644 --- a/Tests/NIOTransportServicesTests/NIOTSEndToEndTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSEndToEndTests.swift @@ -160,6 +160,27 @@ final class FailOnHalfCloseHandler: ChannelInboundHandler { } +final class WaitForActiveHandler: ChannelInboundHandler { + typealias InboundIn = Any + + private let activePromise: EventLoopPromise + + init(_ promise: EventLoopPromise) { + self.activePromise = promise + } + + func handlerAdded(context: ChannelHandlerContext) { + if context.channel.isActive { + self.activePromise.succeed(context.channel) + } + } + + func channelActive(context: ChannelHandlerContext) { + self.activePromise.succeed(context.channel) + } +} + + extension Channel { /// Expect that the given bytes will be received. func expectRead(_ bytes: ByteBuffer) -> EventLoopFuture { @@ -298,8 +319,10 @@ class NIOTSEndToEndTests: XCTestCase { let serverSideConnectionPromise: EventLoopPromise = self.group.next().makePromise() let listener = try NIOTSListenerBootstrap(group: self.group) .childChannelInitializer { channel in - serverSideConnectionPromise.succeed(channel) - return channel.pipeline.addHandler(EchoHandler()) + return channel.pipeline.addHandlers([ + WaitForActiveHandler(serverSideConnectionPromise), + EchoHandler() + ]) } .bind(host: "localhost", port: 0).wait() defer { diff --git a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift index cb00e36..b875201 100644 --- a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift @@ -237,6 +237,7 @@ class NIOTSListenerChannelTests: XCTestCase { func channelRead(context: ChannelHandlerContext, data: NIOAny) { let channel = self.unwrapInboundIn(data) self.promise.succeed(channel) + context.fireChannelRead(data) } } @@ -255,7 +256,13 @@ class NIOTSListenerChannelTests: XCTestCase { XCTAssertNoThrow(try connection.close().wait()) } - let promisedChannel = try channelPromise.futureResult.wait() + // We must wait for channel active here, or the socket addresses won't be set. + let promisedChannel = try channelPromise.futureResult.flatMap { (channel) -> EventLoopFuture in + let promiseChannelActive = channel.eventLoop.makePromise(of: Channel.self) + _ = channel.pipeline.addHandler(WaitForActiveHandler(promiseChannelActive)) + return promiseChannelActive.futureResult + }.wait() + XCTAssertEqual(promisedChannel.remoteAddress, connection.localAddress) XCTAssertEqual(promisedChannel.localAddress, connection.remoteAddress) } @@ -275,5 +282,31 @@ class NIOTSListenerChannelTests: XCTestCase { XCTAssertEqual(error as? NIOTSErrors.BindTimeout, NIOTSErrors.BindTimeout(timeout: .nanoseconds(0)), "unexpected error: \(error)") } } + + func testLoadingAddressesInMultipleQueues() throws { + let listener = try NIOTSListenerBootstrap(group: self.group) + .bind(host: "localhost", port: 0).wait() + defer { + XCTAssertNoThrow(try listener.close().wait()) + } + + let ourSyncQueue = DispatchQueue(label: "ourSyncQueue") + + let workFuture = NIOTSConnectionBootstrap(group: self.group).connect(to: listener.localAddress!).map { channel -> Channel in + XCTAssertTrue(listener.eventLoop.inEventLoop) + + ourSyncQueue.sync { + XCTAssertFalse(listener.eventLoop.inEventLoop) + + // These will crash before we apply our fix. + XCTAssertNotNil(listener.localAddress) + XCTAssertNil(listener.remoteAddress) + } + + return channel + }.flatMap { $0.close() } + + XCTAssertNoThrow(try workFuture.wait()) + } } #endif