diff --git a/Package.swift b/Package.swift index a583dfe..9a31d54 100644 --- a/Package.swift +++ b/Package.swift @@ -23,7 +23,7 @@ let package = Package( .executable(name: "NIOTSHTTPServer", targets: ["NIOTSHTTPServer"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.11.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.15.0"), ], targets: [ .target(name: "NIOTransportServices", diff --git a/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift b/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift index 2d28b39..c85c34f 100644 --- a/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift +++ b/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift @@ -28,6 +28,7 @@ public final class NIOTSConnectionBootstrap { private var qos: DispatchQoS? private var tcpOptions: NWProtocolTCP.Options = .init() private var tlsOptions: NWProtocolTLS.Options? + private var protocolHandlers: Optional<() -> [ChannelHandler]> = nil /// Create a `NIOTSConnectionBootstrap` on the `EventLoopGroup` `group`. /// @@ -193,8 +194,23 @@ public final class NIOTSConnectionBootstrap { } } } + + /// Sets the protocol handlers that will be added to the front of the `ChannelPipeline` right after the + /// `channelInitializer` has been called. + /// + /// Per bootstrap, you can only set the `protocolHandlers` once. Typically, `protocolHandlers` are used for the TLS + /// implementation. Most notably, `NIOClientTCPBootstrap`, NIO's "universal bootstrap" abstraction, uses + /// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations. + public func protocolHandlers(_ handlers: @escaping () -> [ChannelHandler]) -> Self { + precondition(self.protocolHandlers == nil, "protocol handlers can only be set once") + self.protocolHandlers = handlers + return self + } } +@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) +extension NIOTSConnectionBootstrap: NIOClientTCPBootstrapProtocol {} + // This is a backport of ChannelOptions.Storage from SwiftNIO because the initializer wasn't public, so we couldn't actually build it. // When https://github.com/apple/swift-nio/pull/988 is in a shipped release, we can remove this and simply bump our lowest supported version of SwiftNIO. @available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) diff --git a/Sources/NIOTransportServices/NIOTSEventLoopGroup.swift b/Sources/NIOTransportServices/NIOTSEventLoopGroup.swift index 66fc4c6..9806700 100644 --- a/Sources/NIOTransportServices/NIOTSEventLoopGroup.swift +++ b/Sources/NIOTransportServices/NIOTSEventLoopGroup.swift @@ -85,4 +85,34 @@ public final class NIOTSEventLoopGroup: EventLoopGroup { return EventLoopIterator(self.eventLoops) } } + +/// A TLS provider to bootstrap TLS-enabled connections with `NIOClientTCPBootstrap`. +/// +/// Example: +/// +/// // Creating the "universal bootstrap" with the `NIOTSClientTLSProvider`. +/// let tlsProvider = NIOTSClientTLSProvider() +/// let bootstrap = NIOClientTCPBootstrap(NIOTSConnectionBootstrap(group: group), tls: tlsProvider) +/// +/// // Bootstrapping a connection using the "universal bootstrapping mechanism" +/// let connection = bootstrap.enableTLS() +/// .connect(host: "example.com", port: 443) +/// .wait() +@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) +public struct NIOTSClientTLSProvider: NIOClientTLSProvider { + public typealias Bootstrap = NIOTSConnectionBootstrap + + let tlsOptions: NWProtocolTLS.Options + + /// Construct the TLS provider. + public init(tlsOptions: NWProtocolTLS.Options = NWProtocolTLS.Options()) { + self.tlsOptions = tlsOptions + } + + /// Enable TLS on the bootstrap. This is not a function you will typically call as a user, it is called by + /// `NIOClientTCPBootstrap`. + public func enableTLS(_ bootstrap: NIOTSConnectionBootstrap) -> NIOTSConnectionBootstrap { + return bootstrap.tlsOptions(self.tlsOptions) + } +} #endif diff --git a/Tests/NIOTransportServicesTests/NIOTSBootstrapTests.swift b/Tests/NIOTransportServicesTests/NIOTSBootstrapTests.swift index 9e904d1..90bfdf5 100644 --- a/Tests/NIOTransportServicesTests/NIOTSBootstrapTests.swift +++ b/Tests/NIOTransportServicesTests/NIOTSBootstrapTests.swift @@ -92,6 +92,99 @@ final class NIOTSBootstrapTests: XCTestCase { XCTAssertNoThrow(try childChannelDone.futureResult.wait()) XCTAssertNoThrow(try serverChannelDone.futureResult.wait()) } + + func testUniveralBootstrapWorks() { + final class TellMeIfConnectionIsTLSHandler: ChannelInboundHandler { + typealias InboundIn = ByteBuffer + typealias OutboundOut = ByteBuffer + + private let isTLS: EventLoopPromise + private var buffer: ByteBuffer? + + init(isTLS: EventLoopPromise) { + self.isTLS = isTLS + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buffer = self.unwrapInboundIn(data) + + if self.buffer == nil { + self.buffer = buffer + } else { + self.buffer!.writeBuffer(&buffer) + } + + switch self.buffer!.readBytes(length: 2) { + case .some([0x16, 0x03]): // TLS ClientHello always starts with 0x16, 0x03 + self.isTLS.succeed(true) + context.channel.close(promise: nil) + case .some(_): + self.isTLS.succeed(false) + context.channel.close(promise: nil) + case .none: + // not enough data + () + } + } + } + let group = NIOTSEventLoopGroup() + func makeServer(isTLS: EventLoopPromise) throws -> Channel { + let numberOfConnections = NIOAtomic.makeAtomic(value: 0) + return try NIOTSListenerBootstrap(group: group) + .childChannelInitializer { channel in + XCTAssertEqual(0, numberOfConnections.add(1)) + return channel.pipeline.addHandler(TellMeIfConnectionIsTLSHandler(isTLS: isTLS)) + } + .bind(host: "127.0.0.1", port: 0) + .wait() + } + + let isTLSConnection1 = group.next().makePromise(of: Bool.self) + let isTLSConnection2 = group.next().makePromise(of: Bool.self) + + var maybeServer1: Channel? = nil + var maybeServer2: Channel? = nil + + XCTAssertNoThrow(maybeServer1 = try makeServer(isTLS: isTLSConnection1)) + XCTAssertNoThrow(maybeServer2 = try makeServer(isTLS: isTLSConnection2)) + + guard let server1 = maybeServer1, let server2 = maybeServer2 else { + XCTFail("can't make servers") + return + } + defer { + XCTAssertNoThrow(try server1.close().wait()) + XCTAssertNoThrow(try server2.close().wait()) + } + + let tlsOptions = NWProtocolTLS.Options() + let bootstrap = NIOClientTCPBootstrap(NIOTSConnectionBootstrap(group: group), + tls: NIOTSClientTLSProvider(tlsOptions: tlsOptions)) + let tlsBootstrap = NIOClientTCPBootstrap(NIOTSConnectionBootstrap(group: group), + tls: NIOTSClientTLSProvider()) + .enableTLS() + + var buffer = server1.allocator.buffer(capacity: 2) + buffer.writeString("NO") + + var maybeClient1: Channel? = nil + XCTAssertNoThrow(maybeClient1 = try bootstrap.connect(to: server1.localAddress!).wait()) + guard let client1 = maybeClient1 else { + XCTFail("can't connect to server1") + return + } + XCTAssertNoThrow(try client1.writeAndFlush(buffer).wait()) + + // The TLS connection won't actually succeed but it takes Network.framework a while to tell us, we don't + // actually care because we're only interested in the first 2 bytes which we're waiting for below. + tlsBootstrap.connect(to: server2.localAddress!).whenSuccess { channel in + XCTFail("TLS connection succeeded but really shouldn't have: \(channel)") + channel.writeAndFlush(buffer, promise: nil) + } + + XCTAssertNoThrow(XCTAssertFalse(try isTLSConnection1.futureResult.wait())) + XCTAssertNoThrow(XCTAssertTrue(try isTLSConnection2.futureResult.wait())) + } } extension Channel {