diff --git a/Sources/NIOTransportServices/NIOTSErrors.swift b/Sources/NIOTransportServices/NIOTSErrors.swift index 3671f13..0d29784 100644 --- a/Sources/NIOTransportServices/NIOTSErrors.swift +++ b/Sources/NIOTransportServices/NIOTSErrors.swift @@ -57,5 +57,15 @@ public enum NIOTSErrors { /// `UnableToResolveEndpoint` is thrown when an attempt is made to resolve a local endpoint, but /// insufficient information is available to create it. public struct UnableToResolveEndpoint: NIOTSError { } + + /// `BindTimeout` is thrown when a timeout set for a `NWListenerBootstrap.bind` call has been exceeded + /// without successfully binding the address. + public struct BindTimeout: NIOTSError { + public var timeout: TimeAmount + + public init(timeout: TimeAmount) { + self.timeout = timeout + } + } } #endif diff --git a/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift b/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift index ef4eeae..d6717cb 100644 --- a/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift +++ b/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift @@ -31,6 +31,7 @@ public final class NIOTSListenerBootstrap { private var childQoS: DispatchQoS? private var tcpOptions: NWProtocolTCP.Options = .init() private var tlsOptions: NWProtocolTLS.Options? + private var bindTimeout: TimeAmount? /// Create a `NIOTSListenerBootstrap` for the `EventLoopGroup` `group`. /// @@ -140,6 +141,14 @@ public final class NIOTSListenerBootstrap { return self } + /// Specifies a timeout to apply to a bind attempt. + // + /// - parameters: + /// - timeout: The timeout that will apply to the bind attempt. + public func bindTimeout(_ timeout: TimeAmount) -> Self { + self.bindTimeout = timeout + return self + } /// Specifies a QoS to use for the server channel, instead of the default QoS for the /// event loop. @@ -186,18 +195,16 @@ public final class NIOTSListenerBootstrap { /// - host: The host to bind on. /// - port: The port to bind on. public func bind(host: String, port: Int) -> EventLoopFuture { - return self.bind0 { channel in - let p: EventLoopPromise = channel.eventLoop.makePromise() + return self.bind0 { (channel, promise) in do { // NWListener does not actually resolve hostname-based NWEndpoints // for use with requiredLocalEndpoint, so we fall back to // SocketAddress for this. let address = try SocketAddress.makeAddressResolvingHost(host, port: port) - channel.bind(to: address, promise: p) + channel.bind(to: address, promise: promise) } catch { - p.fail(error) + promise.fail(error) } - return p.futureResult } } @@ -206,8 +213,8 @@ public final class NIOTSListenerBootstrap { /// - parameters: /// - address: The `SocketAddress` to bind on. public func bind(to address: SocketAddress) -> EventLoopFuture { - return self.bind0 { channel in - channel.bind(to: address) + return self.bind0 { (channel, promise) in + channel.bind(to: address, promise: promise) } } @@ -216,15 +223,13 @@ public final class NIOTSListenerBootstrap { /// - parameters: /// - unixDomainSocketPath: The _Unix domain socket_ path to bind to. `unixDomainSocketPath` must not exist, it will be created by the system. public func bind(unixDomainSocketPath: String) -> EventLoopFuture { - return self.bind0 { channel in - let p: EventLoopPromise = channel.eventLoop.makePromise() + return self.bind0 { (channel, promise) in do { let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) - channel.bind(to: address, promise: p) + channel.bind(to: address, promise: promise) } catch { - p.fail(error) + promise.fail(error) } - return p.futureResult } } @@ -233,12 +238,12 @@ public final class NIOTSListenerBootstrap { /// - parameters: /// - endpoint: The `NWEndpoint` to bind this channel to. public func bind(endpoint: NWEndpoint) -> EventLoopFuture { - return self.bind0 { channel in - channel.triggerUserOutboundEvent(NIOTSNetworkEvents.BindToNWEndpoint(endpoint: endpoint)) + return self.bind0 { (channel, promise) in + channel.triggerUserOutboundEvent(NIOTSNetworkEvents.BindToNWEndpoint(endpoint: endpoint), promise: promise) } } - private func bind0(_ binder: @escaping (Channel) -> EventLoopFuture) -> EventLoopFuture { + private func bind0(_ binder: @escaping (Channel, EventLoopPromise) -> Void) -> EventLoopFuture { let eventLoop = self.group.next() as! NIOTSEventLoop let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) } let childChannelInit = self.childChannelInit @@ -263,7 +268,20 @@ public final class NIOTSListenerBootstrap { }.flatMap { serverChannel.register() }.flatMap { - binder(serverChannel) + let bindPromise = eventLoop.makePromise(of: Void.self) + binder(serverChannel, bindPromise) + + if let bindTimeout = self.bindTimeout { + let cancelTask = eventLoop.scheduleTask(in: bindTimeout) { + bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout)) + serverChannel.close(promise: nil) + } + + bindPromise.futureResult.whenComplete { (_: Result) in + cancelTask.cancel() + } + } + return bindPromise.futureResult }.map { serverChannel as Channel }.flatMapError { error in diff --git a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift index a2ab1ee..cb00e36 100644 --- a/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSListenerChannelTests.swift @@ -259,5 +259,21 @@ class NIOTSListenerChannelTests: XCTestCase { XCTAssertEqual(promisedChannel.remoteAddress, connection.localAddress) XCTAssertEqual(promisedChannel.localAddress, connection.remoteAddress) } + + func testBindTimeout() throws { + // Testing the bind timeout is damn fiddly, because I don't know a reliable way to force it + // to happen. The best approach I can think of is to set the timeout to "now". + // If you see this test fail, verify that it isn't a simple timing issue first. + let listener = NIOTSListenerBootstrap(group: self.group) + .bindTimeout(.nanoseconds(0)) + + do { + let channel = try listener.bind(host: "localhost", port: 0).wait() + XCTAssertNoThrow(try channel.close().wait()) + XCTFail("Did not throw") + } catch { + XCTAssertEqual(error as? NIOTSErrors.BindTimeout, NIOTSErrors.BindTimeout(timeout: .nanoseconds(0)), "unexpected error: \(error)") + } + } } #endif