From ea5ba9bd258973f9153b16793cc9e7d97386a9e2 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Tue, 8 Jan 2019 21:40:57 +0000 Subject: [PATCH] Don't transition to .registered unless registration succeeds. (#19) Motivation: In some cases, such as when an event loop has been shutdown, channel registration may fail. In these cases, we would incorrectly attempt to deregister the channel, which would fail (and in debug builds, assert). Really, we shouldn't transition into .registered until we know that we have, in fact, registered. However, we need to be cautious: we don't want to register unless we believe we're in an acceptable state to register. Modifications: Updated the state enum to perform the registration at the correct part of the state change function. Result: Harder to crash in debug mode --- .../StateManagedChannel.swift | 12 +++++++----- .../NIOTSConnectionChannelTests.swift | 15 +++++++++++++++ .../NIOTSListenerChannelTests.swift | 15 +++++++++++++++ 3 files changed, 37 insertions(+), 5 deletions(-) diff --git a/Sources/NIOTransportServices/StateManagedChannel.swift b/Sources/NIOTransportServices/StateManagedChannel.swift index 6fea1ea..72c7b0e 100644 --- a/Sources/NIOTransportServices/StateManagedChannel.swift +++ b/Sources/NIOTransportServices/StateManagedChannel.swift @@ -39,10 +39,14 @@ internal enum ChannelState { case active(ActiveSubstate) case inactive - fileprivate mutating func register() throws { + /// Unlike every other one of these methods, this one has a side-effect. This is because + /// it's impossible to correctly be in the reigstered state without verifying that + /// registration has occurred. + fileprivate mutating func register(eventLoop: NIOTSEventLoop, channel: Channel) throws { guard case .idle = self else { throw NIOTSErrors.InvalidChannelStateTransition() } + try eventLoop.register(channel) self = .registered } @@ -139,8 +143,7 @@ extension StateManagedChannel { public func register0(promise: EventLoopPromise?) { // TODO: does this need to do anything more than this? do { - try self.state.register() - try self.tsEventLoop.register(self) + try self.state.register(eventLoop: self.tsEventLoop, channel: self) self.pipeline.fireChannelRegistered() promise?.succeed(result: ()) } catch { @@ -151,8 +154,7 @@ extension StateManagedChannel { public func registerAlreadyConfigured0(promise: EventLoopPromise?) { do { - try self.state.register() - try self.tsEventLoop.register(self) + try self.state.register(eventLoop: self.tsEventLoop, channel: self) self.pipeline.fireChannelRegistered() try self.state.beginActivating() promise?.succeed(result: ()) diff --git a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift index d14b24d..1d1bb59 100644 --- a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift @@ -593,4 +593,19 @@ class NIOTSConnectionChannelTests: XCTestCase { XCTAssertNoThrow(try channel.close().wait()) } + + func testConnectingChannelsOnShutdownEventLoopsFails() throws { + let temporaryGroup = NIOTSEventLoopGroup() + XCTAssertNoThrow(try temporaryGroup.syncShutdownGracefully()) + + let bootstrap = NIOTSConnectionBootstrap(group: temporaryGroup) + + do { + _ = try bootstrap.connect(host: "localhost", port: 12345).wait() + } catch EventLoopError.shutdown { + // Expected + } catch { + XCTFail("Unexpected error: \(error)") + } + } } diff --git a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift index 96de901..852b46c 100644 --- a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift @@ -184,4 +184,19 @@ class NIOTSListenerChannelTests: XCTestCase { XCTAssertNoThrow(try childChannel.close().wait()) XCTAssertNoThrow(try channel.closeFuture.wait()) } + + func testBindingChannelsOnShutdownEventLoopsFails() throws { + let temporaryGroup = NIOTSEventLoopGroup() + XCTAssertNoThrow(try temporaryGroup.syncShutdownGracefully()) + + let bootstrap = NIOTSListenerBootstrap(group: temporaryGroup) + + do { + _ = try bootstrap.bind(host: "localhost", port: 0).wait() + } catch EventLoopError.shutdown { + // Expected + } catch { + XCTFail("Unexpected error: \(error)") + } + } }