From 5fd1458c245d5741b3c8ebe55489f590c6ca8f15 Mon Sep 17 00:00:00 2001 From: Franz Busch Date: Wed, 9 Aug 2023 13:13:21 +0100 Subject: [PATCH] Back out SPI(AsyncChannel) changes (#184) # Motivation We want to release a new `NIOTS` version without the SPI changes for now. # Modification This PR backs out the new `NIOAsyncChannel` APIs. # Result No more SPI usage so we can safely release. --- .../NIOTSConnectionBootstrap.swift | 184 +---- .../NIOTSListenerBootstrap.swift | 227 +----- .../NIOTSAsyncBootstrapTests.swift | 697 ------------------ 3 files changed, 2 insertions(+), 1106 deletions(-) delete mode 100644 Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift diff --git a/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift b/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift index ab97123..86a7300 100644 --- a/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift +++ b/Sources/NIOTransportServices/NIOTSConnectionBootstrap.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// #if canImport(Network) -@_spi(AsyncChannel) import NIOCore +import NIOCore import Dispatch import Network @@ -274,188 +274,6 @@ public final class NIOTSConnectionBootstrap { } } -// MARK: Async connect methods with arbitrary payload - -@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -extension NIOTSConnectionBootstrap { - /// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established. - /// - /// - Parameters: - /// - host: The host to connect to. - /// - port: The port to connect to. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func connect( - host: String, - port: Int, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> Output { - let validPortRange = Int(UInt16.min)...Int(UInt16.max) - guard validPortRange.contains(port), let actualPort = NWEndpoint.Port(rawValue: UInt16(port)) else { - throw NIOTSErrors.InvalidPort(port: port) - } - - return try await self.connect( - endpoint: NWEndpoint.hostPort(host: .init(host), port: actualPort), - channelInitializer: channelInitializer - ) - } - - /// Specify the `address` to connect to for the TCP `Channel` that will be established. - /// - /// - Parameters: - /// - address: The address to connect to. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func connect( - to address: SocketAddress, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> Output { - try await self.connect0( - channelInitializer: channelInitializer, - registration: { connectionChannel, promise in - connectionChannel.register().whenComplete { result in - switch result { - case .success: - connectionChannel.connect(to: address, promise: promise) - case .failure(let error): - promise.fail(error) - } - } - } - ).get() - } - - /// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established. - /// - /// - Parameters: - /// - unixDomainSocketPath: The _Unix domain socket_ path to connect to. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func connect( - unixDomainSocketPath: String, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> Output { - let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath) - return try await self.connect( - to: address, - channelInitializer: channelInitializer - ) - } - - /// Specify the `endpoint` to connect to for the TCP `Channel` that will be established. - /// - /// - Parameters: - /// - endpoint: The endpoint to connect to. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func connect( - endpoint: NWEndpoint, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> Output { - try await self.connect0( - channelInitializer: channelInitializer, - registration: { connectionChannel, promise in - connectionChannel.register().whenComplete { result in - switch result { - case .success: - connectionChannel.triggerUserOutboundEvent( - NIOTSNetworkEvents.ConnectToNWEndpoint(endpoint: endpoint), - promise: promise - ) - case .failure(let error): - promise.fail(error) - } - } - } - ).get() - } - - /// Use a pre-existing `NWConnection` to connect a `Channel`. - /// - /// - Parameters: - /// - connection: The `NWConnection` to wrap. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func withExistingNWConnection( - _ connection: NWConnection, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> Output { - try await self.connect0( - existingNWConnection: connection, - channelInitializer: channelInitializer, - registration: { connectionChannel, promise in - connectionChannel.registerAlreadyConfigured0(promise: promise) - } - ).get() - } - - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - private func connect0( - existingNWConnection: NWConnection? = nil, - channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - registration: @escaping (NIOTSConnectionChannel, EventLoopPromise) -> Void - ) -> EventLoopFuture { - let connectionChannel: NIOTSConnectionChannel - if let newConnection = existingNWConnection { - connectionChannel = NIOTSConnectionChannel( - wrapping: newConnection, - on: self.group.next() as! NIOTSEventLoop, - tcpOptions: self.tcpOptions, - tlsOptions: self.tlsOptions - ) - } else { - connectionChannel = NIOTSConnectionChannel( - eventLoop: self.group.next() as! NIOTSEventLoop, - qos: self.qos, - tcpOptions: self.tcpOptions, - tlsOptions: self.tlsOptions - ) - } - let channelInitializer = { (channel: Channel) -> EventLoopFuture in - let initializer = self.channelInitializer - return initializer(channel).flatMap { channelInitializer(channel) } - } - let channelOptions = self.channelOptions - - return connectionChannel.eventLoop.flatSubmit { - return channelOptions.applyAllChannelOptions(to: connectionChannel).flatMap { - channelInitializer(connectionChannel) - }.flatMap { result -> EventLoopFuture in - let connectPromise: EventLoopPromise = connectionChannel.eventLoop.makePromise() - registration(connectionChannel, connectPromise) - let cancelTask = connectionChannel.eventLoop.scheduleTask(in: self.connectTimeout) { - connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout)) - connectionChannel.close(promise: nil) - } - - connectPromise.futureResult.whenComplete { (_: Result) in - cancelTask.cancel() - } - return connectPromise.futureResult.map { result } - }.flatMapErrorThrowing { - connectionChannel.close(promise: nil) - throw $0 - } - } - } -} - @available(*, unavailable) @available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *) extension NIOTSConnectionBootstrap: Sendable {} diff --git a/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift b/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift index 0ea3dbf..62a56f6 100644 --- a/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift +++ b/Sources/NIOTransportServices/NIOTSListenerBootstrap.swift @@ -13,7 +13,7 @@ //===----------------------------------------------------------------------===// #if canImport(Network) -@_spi(AsyncChannel) import NIOCore +import NIOCore import Dispatch import Network @@ -377,229 +377,4 @@ public final class NIOTSListenerBootstrap { } } -// MARK: Async bind methods with arbitrary payload - -@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) -extension NIOTSListenerBootstrap { - /// Bind the `NIOTSListenerChannel` to `host` and `port`. - /// - /// - Parameters: - /// - host: The host to bind on. - /// - port: The port to bind on. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func bind( - host: String, - port: Int, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, - childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> NIOAsyncChannel { - let validPortRange = Int(UInt16.min)...Int(UInt16.max) - guard validPortRange.contains(port) else { - throw NIOTSErrors.InvalidPort(port: port) - } - - return try await self.bind0( - serverBackpressureStrategy: serverBackpressureStrategy, - childChannelInitializer: childChannelInitializer, - registration: { (serverChannel, promise) in - serverChannel.register().whenComplete { result in - switch result { - case .success: - do { - // NWListener does not actually resolve hostname-based NWEndpoints - // for use with requiredLocalEndpoint, so we fall back to - // SocketAddress for this. - let address = try SocketAddress.makeAddressResolvingHost(host, port: port) - serverChannel.bind(to: address, promise: promise) - } catch { - promise.fail(error) - } - case .failure(let error): - promise.fail(error) - } - } - } - ).get() - } - - /// Bind the `NIOTSListenerChannel` to `address`. - /// - /// - Parameters: - /// - address: The `SocketAddress` to bind on. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func bind( - to address: SocketAddress, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, - childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> NIOAsyncChannel { - return try await self.bind0( - serverBackpressureStrategy: serverBackpressureStrategy, - childChannelInitializer: childChannelInitializer, - registration: { (serverChannel, promise) in - serverChannel.register().whenComplete { result in - switch result { - case .success: - serverChannel.bind(to: address, promise: promise) - case .failure(let error): - promise.fail(error) - } - } - } - ).get() - } - - /// Bind the `NIOTSListenerChannel` to a given `NWEndpoint`. - /// - /// - Parameters: - /// - endpoint: The `NWEndpoint` to bind this channel to. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func bind( - endpoint: NWEndpoint, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, - childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> NIOAsyncChannel { - return try await self.bind0( - serverBackpressureStrategy: serverBackpressureStrategy, - childChannelInitializer: childChannelInitializer, - registration: { (serverChannel, promise) in - serverChannel.register().whenComplete { result in - switch result { - case .success: - serverChannel.triggerUserOutboundEvent(NIOTSNetworkEvents.BindToNWEndpoint(endpoint: endpoint), promise: promise) - case .failure(let error): - promise.fail(error) - } - } - } - ).get() - } - - /// Bind the `NIOTSListenerChannel` to an existing `NWListener`. - /// - /// - Parameters: - /// - listener: The NWListener to wrap. - /// - serverBackpressureStrategy: The back pressure strategy used by the server socket channel. - /// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind` - /// method. - /// - Returns: The result of the channel initializer. - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - @_spi(AsyncChannel) - public func withNWListener( - _ listener: NWListener, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil, - childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture - ) async throws -> NIOAsyncChannel { - return try await self.bind0( - existingNWListener: listener, - serverBackpressureStrategy: serverBackpressureStrategy, - childChannelInitializer: childChannelInitializer, - registration: { (serverChannel, promise) in - serverChannel.registerAlreadyConfigured0(promise: promise) - } - ).get() - } - - @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) - private func bind0( - existingNWListener: NWListener? = nil, - serverBackpressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?, - childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, - registration: @escaping (NIOTSListenerChannel, EventLoopPromise) -> Void - ) -> EventLoopFuture> { - let eventLoop = self.group.next() as! NIOTSEventLoop - let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) } - let childChannelInit = self.childChannelInit - let serverChannelOptions = self.serverChannelOptions - let childChannelOptions = self.childChannelOptions - - let serverChannel: NIOTSListenerChannel - if let newListener = existingNWListener { - serverChannel = NIOTSListenerChannel(wrapping: newListener, - on: self.group.next() as! NIOTSEventLoop, - qos: self.serverQoS, - tcpOptions: self.tcpOptions, - tlsOptions: self.tlsOptions, - childLoopGroup: self.childGroup, - childChannelQoS: self.childQoS, - childTCPOptions: self.tcpOptions, - childTLSOptions: self.tlsOptions) - } else { - serverChannel = NIOTSListenerChannel(eventLoop: eventLoop, - qos: self.serverQoS, - tcpOptions: self.tcpOptions, - tlsOptions: self.tlsOptions, - childLoopGroup: self.childGroup, - childChannelQoS: self.childQoS, - childTCPOptions: self.tcpOptions, - childTLSOptions: self.tlsOptions) - } - - return eventLoop.submit { - serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap { - serverChannelInit(serverChannel) - }.flatMap { (_) -> EventLoopFuture> in - do { - try serverChannel.pipeline.syncOperations.addHandler( - AcceptHandler(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions), - name: "AcceptHandler" - ) - let asyncChannel = try NIOAsyncChannel - .wrapAsyncChannelWithTransformations( - synchronouslyWrapping: serverChannel, - backpressureStrategy: serverBackpressureStrategy, - channelReadTransformation: { channel -> EventLoopFuture<(ChannelInitializerResult)> in - // The channelReadTransformation is run on the EL of the server channel - // We have to make sure that we execute child channel initializer on the - // EL of the child channel. - channel.eventLoop.flatSubmit { - childChannelInitializer(channel) - } - } - ) - - let bindPromise = eventLoop.makePromise(of: Void.self) - registration(serverChannel, bindPromise) - - if let bindTimeout = self.bindTimeout { - let cancelTask = eventLoop.scheduleTask(in: bindTimeout) { - bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout)) - serverChannel.close(promise: nil) - } - - bindPromise.futureResult.whenComplete { (_: Result) in - cancelTask.cancel() - } - } - - return bindPromise.futureResult - .map { (_) -> NIOAsyncChannel in asyncChannel - } - } catch { - return eventLoop.makeFailedFuture(error) - } - }.flatMapError { error -> EventLoopFuture> in - serverChannel.close0(error: error, mode: .all, promise: nil) - return eventLoop.makeFailedFuture(error) - } - }.flatMap { - $0 - } - } -} - #endif diff --git a/Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift b/Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift deleted file mode 100644 index fc4dd9e..0000000 --- a/Tests/NIOTransportServicesTests/NIOTSAsyncBootstrapTests.swift +++ /dev/null @@ -1,697 +0,0 @@ -//===----------------------------------------------------------------------===// -// -// This source file is part of the SwiftNIO open source project -// -// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors -// Licensed under Apache License v2.0 -// -// See LICENSE.txt for license information -// See CONTRIBUTORS.txt for the list of SwiftNIO project authors -// -// SPDX-License-Identifier: Apache-2.0 -// -//===----------------------------------------------------------------------===// - -#if canImport(Network) -import NIOConcurrencyHelpers -@_spi(AsyncChannel) import NIOTransportServices -@_spi(AsyncChannel) import NIOCore -import XCTest -@_spi(AsyncChannel) import NIOTLS - -private final class LineDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder { - typealias InboundIn = ByteBuffer - typealias InboundOut = ByteBuffer - - private let newLine = "\n".utf8.first! - - func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState { - let readable = buffer.withUnsafeReadableBytes { $0.firstIndex(of: self.newLine) } - if let readable = readable { - context.fireChannelRead(self.wrapInboundOut(buffer.readSlice(length: readable)!)) - buffer.moveReaderIndex(forwardBy: 1) - return .continue - } - return .needMoreData - } - - func encode(data: ByteBuffer, out: inout ByteBuffer) throws { - out.writeImmutableBuffer(data) - out.writeString("\n") - } -} - -private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannelHandler { - typealias InboundIn = ByteBuffer - typealias InboundOut = ByteBuffer - enum ALPN: String { - case string - case byte - case unknown - } - - private var proposedALPN: ALPN? - - init( - proposedALPN: ALPN? = nil - ) { - self.proposedALPN = proposedALPN - } - - func handlerAdded(context: ChannelHandlerContext) { - guard context.channel.isActive else { - return - } - - if let proposedALPN = self.proposedALPN { - self.proposedALPN = nil - context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil) - } - context.fireChannelActive() - } - - func channelActive(context: ChannelHandlerContext) { - if let proposedALPN = self.proposedALPN { - context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil) - } - context.fireChannelActive() - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let buffer = self.unwrapInboundIn(data) - let string = String(buffer: buffer) - - if string.hasPrefix("negotiate-alpn:") { - let alpn = String(string.dropFirst(15)) - context.writeAndFlush(.init(ByteBuffer(string: "alpn:\(alpn)")), promise: nil) - context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn)) - context.pipeline.removeHandler(self, promise: nil) - } else if string.hasPrefix("alpn:") { - context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5)))) - context.pipeline.removeHandler(self, promise: nil) - } else { - context.fireChannelRead(data) - } - } -} - -private final class ByteBufferToStringHandler: ChannelDuplexHandler { - typealias InboundIn = ByteBuffer - typealias InboundOut = String - typealias OutboundIn = String - typealias OutboundOut = ByteBuffer - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let buffer = self.unwrapInboundIn(data) - context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer))) - } - - func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let buffer = ByteBuffer(string: self.unwrapOutboundIn(data)) - context.write(.init(buffer), promise: promise) - } -} - -private final class ByteBufferToByteHandler: ChannelDuplexHandler { - typealias InboundIn = ByteBuffer - typealias InboundOut = UInt8 - typealias OutboundIn = UInt8 - typealias OutboundOut = ByteBuffer - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - var buffer = self.unwrapInboundIn(data) - let byte = buffer.readInteger(as: UInt8.self)! - context.fireChannelRead(self.wrapInboundOut(byte)) - } - - func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let buffer = ByteBuffer(integer: self.unwrapOutboundIn(data)) - context.write(.init(buffer), promise: promise) - } -} - -private final class AddressedEnvelopingHandler: ChannelDuplexHandler { - typealias InboundIn = AddressedEnvelope - typealias InboundOut = ByteBuffer - typealias OutboundIn = ByteBuffer - typealias OutboundOut = Any - - var remoteAddress: SocketAddress? - - init(remoteAddress: SocketAddress? = nil) { - self.remoteAddress = remoteAddress - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let envelope = self.unwrapInboundIn(data) - self.remoteAddress = envelope.remoteAddress - - context.fireChannelRead(self.wrapInboundOut(envelope.data)) - } - - func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise?) { - let buffer = self.unwrapOutboundIn(data) - if let remoteAddress = self.remoteAddress { - context.write(self.wrapOutboundOut(AddressedEnvelope(remoteAddress: remoteAddress, data: buffer)), promise: promise) - return - } - - context.write(self.wrapOutboundOut(buffer), promise: promise) - } -} - -final class AsyncChannelBootstrapTests: XCTestCase { - enum NegotiationResult { - case string(NIOAsyncChannel) - case byte(NIOAsyncChannel) - } - - struct ProtocolNegotiationError: Error {} - - enum StringOrByte: Hashable { - case string(String) - case byte(UInt8) - } - - func testServerClientBootstrap_withAsyncChannel_andHostPort() async throws { - let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) - defer { - try! eventLoopGroup.syncShutdownGracefully() - } - - let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelOption(ChannelOptions.autoRead, value: true) - .bind( - host: "127.0.0.1", - port: 0 - ) { channel -> EventLoopFuture> in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - inboundType: String.self, - outboundType: String.self - ) - ) - } - } - - try await withThrowingTaskGroup(of: Void.self) { group in - let (stream, continuation) = AsyncStream.makeStream() - var iterator = stream.makeAsyncIterator() - - group.addTask { - try await withThrowingTaskGroup(of: Void.self) { _ in - for try await childChannel in channel.inboundStream { - for try await value in childChannel.inboundStream { - continuation.yield(.string(value)) - } - } - } - } - - let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!) - try await stringChannel.outboundWriter.write("hello") - - await XCTAsyncAssertEqual(await iterator.next(), .string("hello")) - - group.cancelAll() - } - } - - func testAsyncChannelProtocolNegotiation() async throws { - let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) - defer { - try! eventLoopGroup.syncShutdownGracefully() - } - - let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelOption(ChannelOptions.autoRead, value: true) - .bind( - host: "127.0.0.1", - port: 0 - ) { channel in - channel.eventLoop.makeCompletedFuture { - try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult - } - } - - try await withThrowingTaskGroup(of: Void.self) { group in - let (stream, continuation) = AsyncStream.makeStream() - var serverIterator = stream.makeAsyncIterator() - - group.addTask { - try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { - group.addTask { - switch try await negotiationResult.getResult() { - case .string(let channel): - for try await value in channel.inboundStream { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inboundStream { - continuation.yield(.byte(value)) - } - } - } - } - } - } - - let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedALPN: .string - ) - switch try await stringNegotiationResult.getResult() { - case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") - await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) - case .byte: - preconditionFailure() - } - - let byteNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedALPN: .byte - ) - switch try await byteNegotiationResult.getResult() { - case .string: - preconditionFailure() - case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) - await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) - } - - group.cancelAll() - } - } - - func testAsyncChannelNestedProtocolNegotiation() async throws { - let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) - defer { - try! eventLoopGroup.syncShutdownGracefully() - } - - let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .childChannelOption(ChannelOptions.autoRead, value: true) - .bind( - host: "127.0.0.1", - port: 0 - ) { channel in - channel.eventLoop.makeCompletedFuture { - try self.configureNestedProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult - } - } - - try await withThrowingTaskGroup(of: Void.self) { group in - let (stream, continuation) = AsyncStream.makeStream() - var serverIterator = stream.makeAsyncIterator() - - group.addTask { - try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { - group.addTask { - switch try await negotiationResult.getResult() { - case .string(let channel): - for try await value in channel.inboundStream { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inboundStream { - continuation.yield(.byte(value)) - } - } - } - } - } - } - - let stringStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedOuterALPN: .string, - proposedInnerALPN: .string - ) - switch try await stringStringNegotiationResult.getResult() { - case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") - await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) - case .byte: - preconditionFailure() - } - - let byteStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedOuterALPN: .byte, - proposedInnerALPN: .string - ) - switch try await byteStringNegotiationResult.getResult() { - case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") - await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) - case .byte: - preconditionFailure() - } - - let byteByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedOuterALPN: .byte, - proposedInnerALPN: .byte - ) - switch try await byteByteNegotiationResult.getResult() { - case .string: - preconditionFailure() - case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) - await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) - } - - let stringByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedOuterALPN: .string, - proposedInnerALPN: .byte - ) - switch try await stringByteNegotiationResult.getResult() { - case .string: - preconditionFailure() - case .byte(let byteChannel): - // This is the actual content - try await byteChannel.outboundWriter.write(UInt8(8)) - await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8)) - } - - group.cancelAll() - } - } - - func testAsyncChannelProtocolNegotiation_whenFails() async throws { - final class CollectingHandler: ChannelInboundHandler { - typealias InboundIn = Channel - - private let channels: NIOLockedValueBox<[Channel]> - - init(channels: NIOLockedValueBox<[Channel]>) { - self.channels = channels - } - - func channelRead(context: ChannelHandlerContext, data: NIOAny) { - let channel = self.unwrapInboundIn(data) - - self.channels.withLockedValue { $0.append(channel) } - - context.fireChannelRead(data) - } - } - let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3) - defer { - try! eventLoopGroup.syncShutdownGracefully() - } - let channels = NIOLockedValueBox<[Channel]>([Channel]()) - - let channel: NIOAsyncChannel>, Never> = try await NIOTSListenerBootstrap(group: eventLoopGroup) - .serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1) - .serverChannelInitializer { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels)) - } - } - .childChannelOption(ChannelOptions.autoRead, value: true) - .bind( - host: "127.0.0.1", - port: 0 - ) { channel in - channel.eventLoop.makeCompletedFuture { - try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult - } - } - - try await withThrowingTaskGroup(of: Void.self) { group in - let (stream, continuation) = AsyncStream.makeStream() - var serverIterator = stream.makeAsyncIterator() - - group.addTask { - try await withThrowingTaskGroup(of: Void.self) { group in - for try await negotiationResult in channel.inboundStream { - group.addTask { - switch try await negotiationResult.getResult() { - case .string(let channel): - for try await value in channel.inboundStream { - continuation.yield(.string(value)) - } - case .byte(let channel): - for try await value in channel.inboundStream { - continuation.yield(.byte(value)) - } - } - } - } - } - } - - let failedProtocolNegotiation = try await self.makeClientChannelWithProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedALPN: .unknown - ) - await XCTAssertThrowsError( - try await failedProtocolNegotiation.getResult() - ) - - // Let's check that we can still open a new connection - let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation( - eventLoopGroup: eventLoopGroup, - port: channel.channel.localAddress!.port!, - proposedALPN: .string - ) - switch try await stringNegotiationResult.getResult() { - case .string(let stringChannel): - // This is the actual content - try await stringChannel.outboundWriter.write("hello") - await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello")) - case .byte: - preconditionFailure() - } - - let failedInboundChannel = channels.withLockedValue { channels -> Channel in - XCTAssertEqual(channels.count, 2) - return channels[0] - } - - // We are waiting here to make sure the channel got closed - try await failedInboundChannel.closeFuture.get() - - group.cancelAll() - } - } - - // MARK: - Test Helpers - - private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel { - return try await NIOTSConnectionBootstrap(group: eventLoopGroup) - .connect( - to: .init(ipAddress: "127.0.0.1", port: port) - ) { channel in - channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler()) - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - return try NIOAsyncChannel( - synchronouslyWrapping: channel, - configuration: .init( - inboundType: String.self, - outboundType: String.self - ) - ) - } - } - } - - private func makeClientChannelWithProtocolNegotiation( - eventLoopGroup: EventLoopGroup, - port: Int, - proposedALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { - return try await NIOTSConnectionBootstrap(group: eventLoopGroup) - .connect( - to: .init(ipAddress: "127.0.0.1", port: port) - ) { channel in - return channel.eventLoop.makeCompletedFuture { - return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN).protocolNegotiationResult - } - } - } - - private func makeClientChannelWithNestedProtocolNegotiation( - eventLoopGroup: EventLoopGroup, - port: Int, - proposedOuterALPN: TLSUserEventHandler.ALPN, - proposedInnerALPN: TLSUserEventHandler.ALPN - ) async throws -> EventLoopFuture> { - return try await NIOTSConnectionBootstrap(group: eventLoopGroup) - .connect( - to: .init(ipAddress: "127.0.0.1", port: port) - ) { channel in - return channel.eventLoop.makeCompletedFuture { - try self.configureNestedProtocolNegotiationHandlers( - channel: channel, - proposedOuterALPN: proposedOuterALPN, - proposedInnerALPN: proposedInnerALPN - ).protocolNegotiationResult - } - } - } - - @discardableResult - private func configureProtocolNegotiationHandlers( - channel: Channel, - proposedALPN: TLSUserEventHandler.ALPN? = nil - ) throws -> NIOTypedApplicationProtocolNegotiationHandler { - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN)) - return try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - } - - @discardableResult - private func configureNestedProtocolNegotiationHandlers( - channel: Channel, - proposedOuterALPN: TLSUserEventHandler.ALPN? = nil, - proposedInnerALPN: TLSUserEventHandler.ALPN? = nil - ) throws -> NIOTypedApplicationProtocolNegotiationHandler { - try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder())) - try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN)) - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in - switch alpnResult { - case .negotiated(let alpn): - switch alpn { - case "string": - return channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) - let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - - return NIOProtocolNegotiationResult(deferredResult: negotiationFuture.protocolNegotiationResult) - } - case "byte": - return channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN)) - let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel) - - return NIOProtocolNegotiationResult(deferredResult: negotiationHandler.protocolNegotiationResult) - } - default: - return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } - } - case .fallback: - return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } - } - } - try channel.pipeline.syncOperations.addHandler(negotiationHandler) - return negotiationHandler - } - - @discardableResult - private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> NIOTypedApplicationProtocolNegotiationHandler { - let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler { alpnResult, channel in - switch alpnResult { - case .negotiated(let alpn): - switch alpn { - case "string": - return channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler()) - let asyncChannel: NIOAsyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel - ) - - return NIOProtocolNegotiationResult(result: NegotiationResult.string(asyncChannel)) - } - case "byte": - return channel.eventLoop.makeCompletedFuture { - try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler()) - - let asyncChannel: NIOAsyncChannel = try NIOAsyncChannel( - synchronouslyWrapping: channel - ) - - return NIOProtocolNegotiationResult(result: NegotiationResult.byte(asyncChannel)) - } - default: - return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } - } - case .fallback: - return channel.close().flatMapThrowing { throw ProtocolNegotiationError() } - } - } - - try channel.pipeline.syncOperations.addHandler(negotiationHandler) - return negotiationHandler - } -} - -extension AsyncStream { - fileprivate static func makeStream( - of elementType: Element.Type = Element.self, - bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded - ) -> (stream: AsyncStream, continuation: AsyncStream.Continuation) { - var continuation: AsyncStream.Continuation! - let stream = AsyncStream(bufferingPolicy: limit) { continuation = $0 } - return (stream: stream, continuation: continuation!) - } -} - -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -private func XCTAsyncAssertEqual(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows { - let lhsResult = try await lhs() - let rhsResult = try await rhs() - XCTAssertEqual(lhsResult, rhsResult, file: file, line: line) -} - -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -private func XCTAsyncAssertThrowsError( - _ expression: @autoclosure () async throws -> T, - _ message: @autoclosure () -> String = "", - file: StaticString = #filePath, - line: UInt = #line, - _ errorHandler: (_ error: Error) -> Void = { _ in } -) async { - do { - _ = try await expression() - XCTFail(message(), file: file, line: line) - } catch { - errorHandler(error) - } -} - -@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *) -internal func XCTAssertThrowsError( - _ expression: @autoclosure () async throws -> T, - file: StaticString = #filePath, - line: UInt = #line, - verify: (Error) -> Void = { _ in } -) async { - do { - _ = try await expression() - XCTFail("Expression did not throw error", file: file, line: line) - } catch { - verify(error) - } -} -#endif