From 1761d4eafa64d59c16df340a70f7b64dc764f360 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 6 Dec 2018 12:22:58 +0000 Subject: [PATCH] Make Channel.isActive thread-safe. (#16) Motivation: Sadly I overlooked the fact that Channel.isActive is supposed to be safe to call from multiple threads: the implementation here was not. Modifications: Store the active state into an Atomic. Result: It will be thread-safe to ask if a channel is active. --- .../NIOTSConnectionChannel.swift | 3 ++ .../NIOTSListenerChannel.swift | 4 ++ .../StateManagedChannel.swift | 13 ++++--- .../NIOTSConnectionChannelTests.swift | 37 +++++++++++++++++++ 4 files changed, 51 insertions(+), 6 deletions(-) diff --git a/Sources/NIOTransportServices/NIOTSConnectionChannel.swift b/Sources/NIOTransportServices/NIOTSConnectionChannel.swift index bcacba1..b261dae 100644 --- a/Sources/NIOTransportServices/NIOTSConnectionChannel.swift +++ b/Sources/NIOTransportServices/NIOTSConnectionChannel.swift @@ -176,6 +176,9 @@ internal final class NIOTSConnectionChannel { /// The state of this connection channel. internal var state: ChannelState = .idle + /// The active state, used for safely reporting the channel state across threads. + internal var isActive0: Atomic = Atomic(value: false) + /// The kinds of channel activation this channel supports internal let supportedActivationType: ActivationType = .connect diff --git a/Sources/NIOTransportServices/NIOTSListenerChannel.swift b/Sources/NIOTransportServices/NIOTSListenerChannel.swift index 1cb9472..be78969 100644 --- a/Sources/NIOTransportServices/NIOTSListenerChannel.swift +++ b/Sources/NIOTransportServices/NIOTSListenerChannel.swift @@ -16,6 +16,7 @@ import Foundation import NIO import NIOFoundationCompat +import NIOConcurrencyHelpers import Dispatch import Network @@ -60,6 +61,9 @@ internal final class NIOTSListenerChannel { /// The kinds of channel activation this channel supports internal let supportedActivationType: ActivationType = .bind + /// The active state, used for safely reporting the channel state across threads. + internal var isActive0: Atomic = Atomic(value: false) + /// Whether a call to NWListener.receive has been made, but the completion /// handler has not yet been invoked. private var outstandingRead: Bool = false diff --git a/Sources/NIOTransportServices/StateManagedChannel.swift b/Sources/NIOTransportServices/StateManagedChannel.swift index 401d391..6fea1ea 100644 --- a/Sources/NIOTransportServices/StateManagedChannel.swift +++ b/Sources/NIOTransportServices/StateManagedChannel.swift @@ -16,6 +16,7 @@ import Foundation import NIO import NIOFoundationCompat +import NIOConcurrencyHelpers import Dispatch import Network @@ -93,6 +94,8 @@ internal protocol StateManagedChannel: Channel, ChannelCore { var state: ChannelState { get set } + var isActive0: Atomic { get set } + var tsEventLoop: NIOTSEventLoop { get } var closePromise: EventLoopPromise { get } @@ -119,12 +122,7 @@ extension StateManagedChannel { /// Whether this channel is currently active. public var isActive: Bool { - switch self.state { - case .active: - return true - case .idle, .registered, .activating, .inactive: - return false - } + return self.isActive0.load() } /// Whether this channel is currently closed. This is not necessary for the public @@ -196,6 +194,8 @@ extension StateManagedChannel { return } + self.isActive0.store(false) + self.doClose0(error: error) switch oldState { @@ -240,6 +240,7 @@ extension StateManagedChannel { return } + self.isActive0.store(true) promise?.succeed(result: ()) self.pipeline.fireChannelActive() self.readIfNeeded0() diff --git a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift index 3a5a589..d14b24d 100644 --- a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift @@ -80,6 +80,22 @@ final class DisableWaitingAfterConnect: ChannelOutboundHandler { } +final class PromiseOnActiveHandler: ChannelInboundHandler { + typealias InboundIn = Any + typealias InboudOut = Any + + private let promise: EventLoopPromise + + init(_ promise: EventLoopPromise) { + self.promise = promise + } + + func channelActive(ctx: ChannelHandlerContext) { + self.promise.succeed(result: ()) + } +} + + class NIOTSConnectionChannelTests: XCTestCase { private var group: NIOTSEventLoopGroup! @@ -556,4 +572,25 @@ class NIOTSConnectionChannelTests: XCTestCase { let conn = try connectFuture.wait() XCTAssertNoThrow(try conn.close().wait()) } + + func testCanSafelyInvokeActiveFromMultipleThreads() throws { + // This test exists to trigger TSAN violations if we screw things up. + let listener = try NIOTSListenerBootstrap(group: self.group) + .bind(host: "localhost", port: 0).wait() + defer { + XCTAssertNoThrow(try listener.close().wait()) + } + + let activePromise: EventLoopPromise = self.group.next().newPromise() + + let channel = try NIOTSConnectionBootstrap(group: self.group) + .channelInitializer { channel in + channel.pipeline.add(handler: PromiseOnActiveHandler(activePromise)) + }.connect(to: listener.localAddress!).wait() + + XCTAssertNoThrow(try activePromise.futureResult.wait()) + XCTAssertTrue(channel.isActive) + + XCTAssertNoThrow(try channel.close().wait()) + } }