diff --git a/Sources/NIOTransportServices/StateManagedChannel.swift b/Sources/NIOTransportServices/StateManagedChannel.swift index 6fea1ea..72c7b0e 100644 --- a/Sources/NIOTransportServices/StateManagedChannel.swift +++ b/Sources/NIOTransportServices/StateManagedChannel.swift @@ -39,10 +39,14 @@ internal enum ChannelState { case active(ActiveSubstate) case inactive - fileprivate mutating func register() throws { + /// Unlike every other one of these methods, this one has a side-effect. This is because + /// it's impossible to correctly be in the reigstered state without verifying that + /// registration has occurred. + fileprivate mutating func register(eventLoop: NIOTSEventLoop, channel: Channel) throws { guard case .idle = self else { throw NIOTSErrors.InvalidChannelStateTransition() } + try eventLoop.register(channel) self = .registered } @@ -139,8 +143,7 @@ extension StateManagedChannel { public func register0(promise: EventLoopPromise?) { // TODO: does this need to do anything more than this? do { - try self.state.register() - try self.tsEventLoop.register(self) + try self.state.register(eventLoop: self.tsEventLoop, channel: self) self.pipeline.fireChannelRegistered() promise?.succeed(result: ()) } catch { @@ -151,8 +154,7 @@ extension StateManagedChannel { public func registerAlreadyConfigured0(promise: EventLoopPromise?) { do { - try self.state.register() - try self.tsEventLoop.register(self) + try self.state.register(eventLoop: self.tsEventLoop, channel: self) self.pipeline.fireChannelRegistered() try self.state.beginActivating() promise?.succeed(result: ()) diff --git a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift index d14b24d..1d1bb59 100644 --- a/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSConnectionChannelTests.swift @@ -593,4 +593,19 @@ class NIOTSConnectionChannelTests: XCTestCase { XCTAssertNoThrow(try channel.close().wait()) } + + func testConnectingChannelsOnShutdownEventLoopsFails() throws { + let temporaryGroup = NIOTSEventLoopGroup() + XCTAssertNoThrow(try temporaryGroup.syncShutdownGracefully()) + + let bootstrap = NIOTSConnectionBootstrap(group: temporaryGroup) + + do { + _ = try bootstrap.connect(host: "localhost", port: 12345).wait() + } catch EventLoopError.shutdown { + // Expected + } catch { + XCTFail("Unexpected error: \(error)") + } + } } diff --git a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift index 96de901..852b46c 100644 --- a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift @@ -184,4 +184,19 @@ class NIOTSListenerChannelTests: XCTestCase { XCTAssertNoThrow(try childChannel.close().wait()) XCTAssertNoThrow(try channel.closeFuture.wait()) } + + func testBindingChannelsOnShutdownEventLoopsFails() throws { + let temporaryGroup = NIOTSEventLoopGroup() + XCTAssertNoThrow(try temporaryGroup.syncShutdownGracefully()) + + let bootstrap = NIOTSListenerBootstrap(group: temporaryGroup) + + do { + _ = try bootstrap.bind(host: "localhost", port: 0).wait() + } catch EventLoopError.shutdown { + // Expected + } catch { + XCTFail("Unexpected error: \(error)") + } + } }