diff --git a/Package.swift b/Package.swift index 0b264bd..de7156b 100644 --- a/Package.swift +++ b/Package.swift @@ -21,7 +21,7 @@ let package = Package( .library(name: "NIOTransportServices", targets: ["NIOTransportServices"]), ], dependencies: [ - .package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"), + .package(url: "https://github.com/apple/swift-nio.git", from: "2.60.0"), .package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"), .package(url: "https://github.com/apple/swift-docc-plugin", from: "1.0.0"), ], diff --git a/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift b/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift index 86a7300..277ea02 100644 --- a/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift +++ b/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift @@ -274,6 +274,183 @@ public final class NIOTSConnectionBootstrap { } } +// MARK: Async connect methods with arbitrary payload + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOTSConnectionBootstrap { + /// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - host: The host to connect to. + /// - port: The port to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect( + host: String, + port: Int, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + let validPortRange = Int(UInt16.min)...Int(UInt16.max) + guard validPortRange.contains(port), let actualPort = NWEndpoint.Port(rawValue: UInt16(port)) else { + throw NIOTSErrors.InvalidPort(port: port) + } + + return try await self.connect( + endpoint: NWEndpoint.hostPort(host: .init(host), port: actualPort), + channelInitializer: channelInitializer + ) + } + + /// Specify the `address` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - address: The address to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect( + to address: SocketAddress, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self.connect0( + channelInitializer: channelInitializer, + registration: { connectionChannel, promise in + connectionChannel.register().whenComplete { result in + switch result { + case .success: + connectionChannel.connect(to: address, promise: promise) + case .failure(let error): + promise.fail(error) + } + } + } + ).get() + } + + /// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established. + /// + /// - Parameters: + /// - unixDomainSocketPath: The _Unix domain socket_ path to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect( + unixDomainSocketPath: String, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) + return try await self.connect( + to: address, + channelInitializer: channelInitializer + ) + } + + /// Specify the `endpoint` to connect to for the TCP `Channel` that will be established. + /// + /// - Parameters: + /// - endpoint: The endpoint to connect to. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func connect( + endpoint: NWEndpoint, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self.connect0( + channelInitializer: channelInitializer, + registration: { connectionChannel, promise in + connectionChannel.register().whenComplete { result in + switch result { + case .success: + connectionChannel.triggerUserOutboundEvent( + NIOTSNetworkEvents.ConnectToNWEndpoint(endpoint: endpoint), + promise: promise + ) + case .failure(let error): + promise.fail(error) + } + } + } + ).get() + } + + /// Use a pre-existing `NWConnection` to connect a `Channel`. + /// + /// - Parameters: + /// - connection: The `NWConnection` to wrap. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withExistingNWConnection( + _ connection: NWConnection, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> Output { + try await self.connect0( + existingNWConnection: connection, + channelInitializer: channelInitializer, + registration: { connectionChannel, promise in + connectionChannel.registerAlreadyConfigured0(promise: promise) + } + ).get() + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private func connect0( + existingNWConnection: NWConnection? = nil, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + registration: @escaping (NIOTSConnectionChannel, EventLoopPromise) -> Void + ) -> EventLoopFuture { + let connectionChannel: NIOTSConnectionChannel + if let newConnection = existingNWConnection { + connectionChannel = NIOTSConnectionChannel( + wrapping: newConnection, + on: self.group.next() as! NIOTSEventLoop, + tcpOptions: self.tcpOptions, + tlsOptions: self.tlsOptions + ) + } else { + connectionChannel = NIOTSConnectionChannel( + eventLoop: self.group.next() as! NIOTSEventLoop, + qos: self.qos, + tcpOptions: self.tcpOptions, + tlsOptions: self.tlsOptions + ) + } + let channelInitializer = { (channel: Channel) -> EventLoopFuture in + let initializer = self.channelInitializer + return initializer(channel).flatMap { channelInitializer(channel) } + } + let channelOptions = self.channelOptions + + return connectionChannel.eventLoop.flatSubmit { + return channelOptions.applyAllChannelOptions(to: connectionChannel).flatMap { + channelInitializer(connectionChannel) + }.flatMap { result -> EventLoopFuture in + let connectPromise: EventLoopPromise = connectionChannel.eventLoop.makePromise() + registration(connectionChannel, connectPromise) + let cancelTask = connectionChannel.eventLoop.scheduleTask(in: self.connectTimeout) { + connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout)) + connectionChannel.close(promise: nil) + } + + connectPromise.futureResult.whenComplete { (_: Result) in + cancelTask.cancel() + } + return connectPromise.futureResult.map { result } + }.flatMapErrorThrowing { + connectionChannel.close(promise: nil) + throw $0 + } + } + } +} + @available(*, unavailable) @available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) extension NIOTSConnectionBootstrap: Sendable {} diff --git a/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift b/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift index 522a8e7..ed5241a 100644 --- a/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift +++ b/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift @@ -377,4 +377,229 @@ public final class NIOTSListenerBootstrap { } } +// MARK: Async bind methods with arbitrary payload + +@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) +extension NIOTSListenerBootstrap { + /// Bind the `NIOTSListenerChannel` to `host` and `port`. + /// + /// - Parameters: + /// - host: The host to bind on. + /// - port: The port to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind( + host: String, + port: Int, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> NIOAsyncChannel { + let validPortRange = Int(UInt16.min)...Int(UInt16.max) + guard validPortRange.contains(port) else { + throw NIOTSErrors.InvalidPort(port: port) + } + + return try await self.bind0( + serverBackPressureStrategy: serverBackPressureStrategy, + childChannelInitializer: childChannelInitializer, + registration: { (serverChannel, promise) in + serverChannel.register().whenComplete { result in + switch result { + case .success: + do { + // NWListener does not actually resolve hostname-based NWEndpoints + // for use with requiredLocalEndpoint, so we fall back to + // SocketAddress for this. + let address = try SocketAddress.makeAddressResolvingHost(host, port: port) + serverChannel.bind(to: address, promise: promise) + } catch { + promise.fail(error) + } + case .failure(let error): + promise.fail(error) + } + } + } + ).get() + } + + /// Bind the `NIOTSListenerChannel` to `address`. + /// + /// - Parameters: + /// - address: The `SocketAddress` to bind on. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind( + to address: SocketAddress, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> NIOAsyncChannel { + return try await self.bind0( + serverBackPressureStrategy: serverBackPressureStrategy, + childChannelInitializer: childChannelInitializer, + registration: { (serverChannel, promise) in + serverChannel.register().whenComplete { result in + switch result { + case .success: + serverChannel.bind(to: address, promise: promise) + case .failure(let error): + promise.fail(error) + } + } + } + ).get() + } + + /// Bind the `NIOTSListenerChannel` to a given `NWEndpoint`. + /// + /// - Parameters: + /// - endpoint: The `NWEndpoint` to bind this channel to. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func bind( + endpoint: NWEndpoint, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> NIOAsyncChannel { + return try await self.bind0( + serverBackPressureStrategy: serverBackPressureStrategy, + childChannelInitializer: childChannelInitializer, + registration: { (serverChannel, promise) in + serverChannel.register().whenComplete { result in + switch result { + case .success: + serverChannel.triggerUserOutboundEvent(NIOTSNetworkEvents.BindToNWEndpoint(endpoint: endpoint), promise: promise) + case .failure(let error): + promise.fail(error) + } + } + } + ).get() + } + + /// Bind the `NIOTSListenerChannel` to an existing `NWListener`. + /// + /// - Parameters: + /// - listener: The NWListener to wrap. + /// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel. + /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` + /// method. + /// - Returns: The result of the channel initializer. + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + public func withNWListener( + _ listener: NWListener, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture + ) async throws -> NIOAsyncChannel { + return try await self.bind0( + existingNWListener: listener, + serverBackPressureStrategy: serverBackPressureStrategy, + childChannelInitializer: childChannelInitializer, + registration: { (serverChannel, promise) in + serverChannel.registerAlreadyConfigured0(promise: promise) + } + ).get() + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private func bind0( + existingNWListener: NWListener? = nil, + serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, + childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + registration: @escaping (NIOTSListenerChannel, EventLoopPromise) -> Void + ) -> EventLoopFuture> { + let eventLoop = self.group.next() as! NIOTSEventLoop + let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) } + let childChannelInit = self.childChannelInit + let serverChannelOptions = self.serverChannelOptions + let childChannelOptions = self.childChannelOptions + + let serverChannel: NIOTSListenerChannel + if let newListener = existingNWListener { + serverChannel = NIOTSListenerChannel( + wrapping: newListener, + on: self.group.next() as! NIOTSEventLoop, + qos: self.serverQoS, + tcpOptions: self.tcpOptions, + tlsOptions: self.tlsOptions, + childLoopGroup: self.childGroup, + childChannelQoS: self.childQoS, + childTCPOptions: self.tcpOptions, + childTLSOptions: self.tlsOptions + ) + } else { + serverChannel = NIOTSListenerChannel( + eventLoop: eventLoop, + qos: self.serverQoS, + tcpOptions: self.tcpOptions, + tlsOptions: self.tlsOptions, + childLoopGroup: self.childGroup, + childChannelQoS: self.childQoS, + childTCPOptions: self.tcpOptions, + childTLSOptions: self.tlsOptions + ) + } + + return eventLoop.submit { + serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap { + serverChannelInit(serverChannel) + }.flatMap { (_) -> EventLoopFuture> in + do { + try serverChannel.pipeline.syncOperations.addHandler( + AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions), + name: "AcceptHandler" + ) + let asyncChannel = try NIOAsyncChannel + ._wrapAsyncChannelWithTransformations( + synchronouslyWrapping: serverChannel, + backPressureStrategy: serverBackPressureStrategy, + channelReadTransformation: { channel -> EventLoopFuture<(ChannelInitializerResult)> in + // The channelReadTransformation is run on the EL of the server channel + // We have to make sure that we execute child channel initializer on the + // EL of the child channel. + channel.eventLoop.flatSubmit { + childChannelInitializer(channel) + } + } + ) + + let bindPromise = eventLoop.makePromise(of: Void.self) + registration(serverChannel, bindPromise) + + if let bindTimeout = self.bindTimeout { + let cancelTask = eventLoop.scheduleTask(in: bindTimeout) { + bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout)) + serverChannel.close(promise: nil) + } + + bindPromise.futureResult.whenComplete { (_: Result) in + cancelTask.cancel() + } + } + + return bindPromise.futureResult + .map { (_) -> NIOAsyncChannel in asyncChannel + } + } catch { + return eventLoop.makeFailedFuture(error) + } + }.flatMapError { error -> EventLoopFuture> in + serverChannel.close0(error: error, mode: .all, promise: nil) + return eventLoop.makeFailedFuture(error) + } + }.flatMap { + $0 + } + } +} + #endif diff --git a/Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift b/Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift new file mode 100644 index 0000000..f127d1b --- /dev/null +++ b/Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift @@ -0,0 +1,697 @@ +//===----------------------------------------------------------------------===// +// +// This source file is part of the SwiftNIO open source project +// +// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors +// Licensed under Apache License v2.0 +// +// See LICENSE.txt for license information +// See CONTRIBUTORS.txt for the list of SwiftNIO project authors +// +// SPDX-License-Identifier: Apache-2.0 +// +//===----------------------------------------------------------------------===// + +#if canImport(Network) +import NIOConcurrencyHelpers +import NIOTransportServices +@_spi(AsyncChannel) import NIOCore +import XCTest +@_spi(AsyncChannel) import NIOTLS + +private final class LineDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + + private let newLine = "\n".utf8.first! + + func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { + let readable = buffer.withUnsafeReadableBytes { $0.firstIndex(of: self.newLine) } + if let readable = readable { + context.fireChannelRead(self.wrapInboundOut(buffer.readSlice(length: readable)!)) + buffer.moveReaderIndex(forwardBy: 1) + return .continue + } + return .needMoreData + } + + func encode(data: ByteBuffer, out: inout ByteBuffer) throws { + out.writeImmutableBuffer(data) + out.writeString("\n") + } +} + +private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannelHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = ByteBuffer + enum ALPN: String { + case string + case byte + case unknown + } + + private var proposedALPN: ALPN? + + init( + proposedALPN: ALPN? = nil + ) { + self.proposedALPN = proposedALPN + } + + func handlerAdded(context: ChannelHandlerContext) { + guard context.channel.isActive else { + return + } + + if let proposedALPN = self.proposedALPN { + self.proposedALPN = nil + context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil) + } + context.fireChannelActive() + } + + func channelActive(context: ChannelHandlerContext) { + if let proposedALPN = self.proposedALPN { + context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil) + } + context.fireChannelActive() + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let buffer = self.unwrapInboundIn(data) + let string = String(buffer: buffer) + + if string.hasPrefix("negotiate-alpn:") { + let alpn = String(string.dropFirst(15)) + context.writeAndFlush(.init(ByteBuffer(string: "alpn:\(alpn)")), promise: nil) + context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn)) + context.pipeline.removeHandler(self, promise: nil) + } else if string.hasPrefix("alpn:") { + context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5)))) + context.pipeline.removeHandler(self, promise: nil) + } else { + context.fireChannelRead(data) + } + } +} + +private final class ByteBufferToStringHandler: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = String + typealias OutboundIn = String + typealias OutboundOut = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let buffer = self.unwrapInboundIn(data) + context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer))) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let buffer = ByteBuffer(string: self.unwrapOutboundIn(data)) + context.write(.init(buffer), promise: promise) + } +} + +private final class ByteBufferToByteHandler: ChannelDuplexHandler { + typealias InboundIn = ByteBuffer + typealias InboundOut = UInt8 + typealias OutboundIn = UInt8 + typealias OutboundOut = ByteBuffer + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + var buffer = self.unwrapInboundIn(data) + let byte = buffer.readInteger(as: UInt8.self)! + context.fireChannelRead(self.wrapInboundOut(byte)) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let buffer = ByteBuffer(integer: self.unwrapOutboundIn(data)) + context.write(.init(buffer), promise: promise) + } +} + +private final class AddressedEnvelopingHandler: ChannelDuplexHandler { + typealias InboundIn = AddressedEnvelope + typealias InboundOut = ByteBuffer + typealias OutboundIn = ByteBuffer + typealias OutboundOut = Any + + var remoteAddress: SocketAddress? + + init(remoteAddress: SocketAddress? = nil) { + self.remoteAddress = remoteAddress + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let envelope = self.unwrapInboundIn(data) + self.remoteAddress = envelope.remoteAddress + + context.fireChannelRead(self.wrapInboundOut(envelope.data)) + } + + func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { + let buffer = self.unwrapOutboundIn(data) + if let remoteAddress = self.remoteAddress { + context.write(self.wrapOutboundOut(AddressedEnvelope(remoteAddress: remoteAddress, data: buffer)), promise: promise) + return + } + + context.write(self.wrapOutboundOut(buffer), promise: promise) + } +} + +final class AsyncChannelBootstrapTests: XCTestCase { + enum NegotiationResult { + case string(NIOAsyncChannel) + case byte(NIOAsyncChannel) + } + + struct ProtocolNegotiationError: Error {} + + enum StringOrByte: Hashable { + case string(String) + case byte(UInt8) + } + + func testServerClientBootstrap_withAsyncChannel_andHostPort() async throws { + let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + + let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(ChannelOptions.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel -> EventLoopFuture> in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel( + synchronouslyWrapping: channel, + configuration: .init( + inboundType: String.self, + outboundType: String.self + ) + ) + } + } + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var iterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { _ in + for try await childChannel in channel.inbound { + for try await value in childChannel.inbound { + continuation.yield(.string(value)) + } + } + } + } + + let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!) + try await stringChannel.outbound.write("hello") + + await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) + + group.cancelAll() + } + } + + func testAsyncChannelProtocolNegotiation() async throws { + let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + + let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(ChannelOptions.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult + } + } + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var serverIterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { group in + for try await negotiationResult in channel.inbound { + group.addTask { + switch try await negotiationResult.get() { + case .string(let channel): + for try await value in channel.inbound { + continuation.yield(.string(value)) + } + case .byte(let channel): + for try await value in channel.inbound { + continuation.yield(.byte(value)) + } + } + } + } + } + } + + let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedALPN: .string + ) + switch try await stringNegotiationResult.get() { + case .string(let stringChannel): + // This is the actual content + try await stringChannel.outbound.write("hello") + await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) + case .byte: + preconditionFailure() + } + + let byteNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedALPN: .byte + ) + switch try await byteNegotiationResult.get() { + case .string: + preconditionFailure() + case .byte(let byteChannel): + // This is the actual content + try await byteChannel.outbound.write(UInt8(8)) + await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) + } + + group.cancelAll() + } + } + + func testAsyncChannelNestedProtocolNegotiation() async throws { + let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + + let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .childChannelOption(ChannelOptions.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try self.configureNestedProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult + } + } + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var serverIterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { group in + for try await negotiationResult in channel.inbound { + group.addTask { + switch try await negotiationResult.get().get() { + case .string(let channel): + for try await value in channel.inbound { + continuation.yield(.string(value)) + } + case .byte(let channel): + for try await value in channel.inbound { + continuation.yield(.byte(value)) + } + } + } + } + } + } + + let stringStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedOuterALPN: .string, + proposedInnerALPN: .string + ) + switch try await stringStringNegotiationResult.get().get() { + case .string(let stringChannel): + // This is the actual content + try await stringChannel.outbound.write("hello") + await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) + case .byte: + preconditionFailure() + } + + let byteStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedOuterALPN: .byte, + proposedInnerALPN: .string + ) + switch try await byteStringNegotiationResult.get().get() { + case .string(let stringChannel): + // This is the actual content + try await stringChannel.outbound.write("hello") + await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) + case .byte: + preconditionFailure() + } + + let byteByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedOuterALPN: .byte, + proposedInnerALPN: .byte + ) + switch try await byteByteNegotiationResult.get().get() { + case .string: + preconditionFailure() + case .byte(let byteChannel): + // This is the actual content + try await byteChannel.outbound.write(UInt8(8)) + await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) + } + + let stringByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedOuterALPN: .string, + proposedInnerALPN: .byte + ) + switch try await stringByteNegotiationResult.get().get() { + case .string: + preconditionFailure() + case .byte(let byteChannel): + // This is the actual content + try await byteChannel.outbound.write(UInt8(8)) + await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) + } + + group.cancelAll() + } + } + + func testAsyncChannelProtocolNegotiation_whenFails() async throws { + final class CollectingHandler: ChannelInboundHandler { + typealias InboundIn = Channel + + private let channels: NIOLockedValueBox<[Channel]> + + init(channels: NIOLockedValueBox<[Channel]>) { + self.channels = channels + } + + func channelRead(context: ChannelHandlerContext, data: NIOAny) { + let channel = self.unwrapInboundIn(data) + + self.channels.withLockedValue { $0.append(channel) } + + context.fireChannelRead(data) + } + } + let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) + defer { + try! eventLoopGroup.syncShutdownGracefully() + } + let channels = NIOLockedValueBox<[Channel]>([Channel]()) + + let channel: NIOAsyncChannel, Never> = try await NIOTSListenerBootstrap(group: eventLoopGroup) + .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) + .serverChannelInitializer { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels)) + } + } + .childChannelOption(ChannelOptions.autoRead, value: true) + .bind( + host: "127.0.0.1", + port: 0 + ) { channel in + channel.eventLoop.makeCompletedFuture { + try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult + } + } + + try await withThrowingTaskGroup(of: Void.self) { group in + let (stream, continuation) = AsyncStream.makeStream() + var serverIterator = stream.makeAsyncIterator() + + group.addTask { + try await withThrowingTaskGroup(of: Void.self) { group in + for try await negotiationResult in channel.inbound { + group.addTask { + switch try await negotiationResult.get() { + case .string(let channel): + for try await value in channel.inbound { + continuation.yield(.string(value)) + } + case .byte(let channel): + for try await value in channel.inbound { + continuation.yield(.byte(value)) + } + } + } + } + } + } + + let failedProtocolNegotiation = try await self.makeClientChannelWithProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedALPN: .unknown + ) + await XCTAssertThrowsError( + try await failedProtocolNegotiation.get() + ) + + // Let's check that we can still open a new connection + let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( + eventLoopGroup: eventLoopGroup, + port: channel.channel.localAddress!.port!, + proposedALPN: .string + ) + switch try await stringNegotiationResult.get() { + case .string(let stringChannel): + // This is the actual content + try await stringChannel.outbound.write("hello") + await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) + case .byte: + preconditionFailure() + } + + let failedInboundChannel = channels.withLockedValue { channels -> Channel in + XCTAssertEqual(channels.count, 2) + return channels[0] + } + + // We are waiting here to make sure the channel got closed + try await failedInboundChannel.closeFuture.get() + + group.cancelAll() + } + } + + // MARK: - Test Helpers + + private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel { + return try await NIOTSConnectionBootstrap(group: eventLoopGroup) + .connect( + to: .init(ipAddress: "127.0.0.1", port: port) + ) { channel in + channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler()) + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + return try NIOAsyncChannel( + synchronouslyWrapping: channel, + configuration: .init( + inboundType: String.self, + outboundType: String.self + ) + ) + } + } + } + + private func makeClientChannelWithProtocolNegotiation( + eventLoopGroup: EventLoopGroup, + port: Int, + proposedALPN: TLSUserEventHandler.ALPN + ) async throws -> EventLoopFuture { + return try await NIOTSConnectionBootstrap(group: eventLoopGroup) + .connect( + to: .init(ipAddress: "127.0.0.1", port: port) + ) { channel in + return channel.eventLoop.makeCompletedFuture { + return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN).protocolNegotiationResult + } + } + } + + private func makeClientChannelWithNestedProtocolNegotiation( + eventLoopGroup: EventLoopGroup, + port: Int, + proposedOuterALPN: TLSUserEventHandler.ALPN, + proposedInnerALPN: TLSUserEventHandler.ALPN + ) async throws -> EventLoopFuture> { + return try await NIOTSConnectionBootstrap(group: eventLoopGroup) + .connect( + to: .init(ipAddress: "127.0.0.1", port: port) + ) { channel in + return channel.eventLoop.makeCompletedFuture { + try self.configureNestedProtocolNegotiationHandlers( + channel: channel, + proposedOuterALPN: proposedOuterALPN, + proposedInnerALPN: proposedInnerALPN + ).protocolNegotiationResult + } + } + } + + @discardableResult + private func configureProtocolNegotiationHandlers( + channel: Channel, + proposedALPN: TLSUserEventHandler.ALPN? = nil + ) throws -> NIOTypedApplicationProtocolNegotiationHandler { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN)) + return try self.addTypedApplicationProtocolNegotiationHandler(to: channel) + } + + @discardableResult + private func configureNestedProtocolNegotiationHandlers( + channel: Channel, + proposedOuterALPN: TLSUserEventHandler.ALPN? = nil, + proposedInnerALPN: TLSUserEventHandler.ALPN? = nil + ) throws -> NIOTypedApplicationProtocolNegotiationHandler> { + try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) + try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN)) + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler> { alpnResult, channel in + switch alpnResult { + case .negotiated(let alpn): + switch alpn { + case "string": + return channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) + let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) + + return negotiationFuture.protocolNegotiationResult + } + case "byte": + return channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) + let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) + + return negotiationHandler.protocolNegotiationResult + } + default: + return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } + } + case .fallback: + return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } + } + } + try channel.pipeline.syncOperations.addHandler(negotiationHandler) + return negotiationHandler + } + + @discardableResult + private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> NIOTypedApplicationProtocolNegotiationHandler { + let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in + switch alpnResult { + case .negotiated(let alpn): + switch alpn { + case "string": + return channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) + let asyncChannel: NIOAsyncChannel = try NIOAsyncChannel( + synchronouslyWrapping: channel + ) + + return NegotiationResult.string(asyncChannel) + } + case "byte": + return channel.eventLoop.makeCompletedFuture { + try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler()) + + let asyncChannel: NIOAsyncChannel = try NIOAsyncChannel( + synchronouslyWrapping: channel + ) + + return NegotiationResult.byte(asyncChannel) + } + default: + return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } + } + case .fallback: + return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } + } + } + + try channel.pipeline.syncOperations.addHandler(negotiationHandler) + return negotiationHandler + } +} + +extension AsyncStream { + fileprivate static func makeStream( + of elementType: Element.Type = Element.self, + bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded + ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { + var continuation: AsyncStream.Continuation! + let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } + return (stream: stream, continuation: continuation!) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private func XCTAsyncAssertEqual(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows { + let lhsResult = try await lhs() + let rhsResult = try await rhs() + XCTAssertEqual(lhsResult, rhsResult, file: file, line: line) +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +private func XCTAsyncAssertThrowsError( + _ expression: @autoclosure () async throws -> T, + _ message: @autoclosure () -> String = "", + file: StaticString = #filePath, + line: UInt = #line, + _ errorHandler: (_ error: Error) -> Void = { _ in } +) async { + do { + _ = try await expression() + XCTFail(message(), file: file, line: line) + } catch { + errorHandler(error) + } +} + +@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) +internal func XCTAssertThrowsError( + _ expression: @autoclosure () async throws -> T, + file: StaticString = #filePath, + line: UInt = #line, + verify: (Error) -> Void = { _ in } +) async { + do { + _ = try await expression() + XCTFail("Expression did not throw error", file: file, line: line) + } catch { + verify(error) + } +} +#endif