Add new typed async bootstrap APIs back and drop SPI (#191)

* Revert "Back out SPI(AsyncChannel) changes"

This reverts commit 33d2b2993f.

* Add new typed async bootstrap APIs back and drop SPI

# Motivation
We just merged the removal of the `AsyncChannel` SPI in NIO and can now add back the new APIs in transport services as well.

# Modification
This PR brings back the previous SPI and promotes it to API.

# Result
New typed async bootstraps API for `NIOTransportServices`.

* George review
This commit is contained in:
Franz Busch 2023-10-25 14:09:26 +01:00 committed by GitHub
parent 16ca413e3f
commit ebf8b9c365
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 1100 additions and 1 deletions

View File

@ -21,7 +21,7 @@ let package = Package(
.library(name: "NIOTransportServices", targets: ["NIOTransportServices"]),
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.58.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.60.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
.package(url: "https://github.com/apple/swift-docc-plugin", from: "1.0.0"),
],

View File

@ -274,6 +274,183 @@ public final class NIOTSConnectionBootstrap {
}
}
// MARK: Async connect methods with arbitrary payload
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOTSConnectionBootstrap {
/// Specify the `host` and `port` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - host: The host to connect to.
/// - port: The port to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func connect<Output: Sendable>(
host: String,
port: Int,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
let validPortRange = Int(UInt16.min)...Int(UInt16.max)
guard validPortRange.contains(port), let actualPort = NWEndpoint.Port(rawValue: UInt16(port)) else {
throw NIOTSErrors.InvalidPort(port: port)
}
return try await self.connect(
endpoint: NWEndpoint.hostPort(host: .init(host), port: actualPort),
channelInitializer: channelInitializer
)
}
/// Specify the `address` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - address: The address to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func connect<Output: Sendable>(
to address: SocketAddress,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
try await self.connect0(
channelInitializer: channelInitializer,
registration: { connectionChannel, promise in
connectionChannel.register().whenComplete { result in
switch result {
case .success:
connectionChannel.connect(to: address, promise: promise)
case .failure(let error):
promise.fail(error)
}
}
}
).get()
}
/// Specify the `unixDomainSocket` path to connect to for the UDS `Channel` that will be established.
///
/// - Parameters:
/// - unixDomainSocketPath: The _Unix domain socket_ path to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func connect<Output: Sendable>(
unixDomainSocketPath: String,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
let address = try SocketAddress(unixDomainSocketPath: unixDomainSocketPath)
return try await self.connect(
to: address,
channelInitializer: channelInitializer
)
}
/// Specify the `endpoint` to connect to for the TCP `Channel` that will be established.
///
/// - Parameters:
/// - endpoint: The endpoint to connect to.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func connect<Output: Sendable>(
endpoint: NWEndpoint,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
try await self.connect0(
channelInitializer: channelInitializer,
registration: { connectionChannel, promise in
connectionChannel.register().whenComplete { result in
switch result {
case .success:
connectionChannel.triggerUserOutboundEvent(
NIOTSNetworkEvents.ConnectToNWEndpoint(endpoint: endpoint),
promise: promise
)
case .failure(let error):
promise.fail(error)
}
}
}
).get()
}
/// Use a pre-existing `NWConnection` to connect a `Channel`.
///
/// - Parameters:
/// - connection: The `NWConnection` to wrap.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `connect`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func withExistingNWConnection<Output: Sendable>(
_ connection: NWConnection,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> Output {
try await self.connect0(
existingNWConnection: connection,
channelInitializer: channelInitializer,
registration: { connectionChannel, promise in
connectionChannel.registerAlreadyConfigured0(promise: promise)
}
).get()
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func connect0<ChannelInitializerResult>(
existingNWConnection: NWConnection? = nil,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
registration: @escaping (NIOTSConnectionChannel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<ChannelInitializerResult> {
let connectionChannel: NIOTSConnectionChannel
if let newConnection = existingNWConnection {
connectionChannel = NIOTSConnectionChannel(
wrapping: newConnection,
on: self.group.next() as! NIOTSEventLoop,
tcpOptions: self.tcpOptions,
tlsOptions: self.tlsOptions
)
} else {
connectionChannel = NIOTSConnectionChannel(
eventLoop: self.group.next() as! NIOTSEventLoop,
qos: self.qos,
tcpOptions: self.tcpOptions,
tlsOptions: self.tlsOptions
)
}
let channelInitializer = { (channel: Channel) -> EventLoopFuture<ChannelInitializerResult> in
let initializer = self.channelInitializer
return initializer(channel).flatMap { channelInitializer(channel) }
}
let channelOptions = self.channelOptions
return connectionChannel.eventLoop.flatSubmit {
return channelOptions.applyAllChannelOptions(to: connectionChannel).flatMap {
channelInitializer(connectionChannel)
}.flatMap { result -> EventLoopFuture<ChannelInitializerResult> in
let connectPromise: EventLoopPromise<Void> = connectionChannel.eventLoop.makePromise()
registration(connectionChannel, connectPromise)
let cancelTask = connectionChannel.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
connectionChannel.close(promise: nil)
}
connectPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
cancelTask.cancel()
}
return connectPromise.futureResult.map { result }
}.flatMapErrorThrowing {
connectionChannel.close(promise: nil)
throw $0
}
}
}
}
@available(*, unavailable)
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
extension NIOTSConnectionBootstrap: Sendable {}

View File

@ -377,4 +377,229 @@ public final class NIOTSListenerBootstrap {
}
}
// MARK: Async bind methods with arbitrary payload
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
extension NIOTSListenerBootstrap {
/// Bind the `NIOTSListenerChannel` to `host` and `port`.
///
/// - Parameters:
/// - host: The host to bind on.
/// - port: The port to bind on.
/// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func bind<Output: Sendable>(
host: String,
port: Int,
serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> NIOAsyncChannel<Output, Never> {
let validPortRange = Int(UInt16.min)...Int(UInt16.max)
guard validPortRange.contains(port) else {
throw NIOTSErrors.InvalidPort(port: port)
}
return try await self.bind0(
serverBackPressureStrategy: serverBackPressureStrategy,
childChannelInitializer: childChannelInitializer,
registration: { (serverChannel, promise) in
serverChannel.register().whenComplete { result in
switch result {
case .success:
do {
// NWListener does not actually resolve hostname-based NWEndpoints
// for use with requiredLocalEndpoint, so we fall back to
// SocketAddress for this.
let address = try SocketAddress.makeAddressResolvingHost(host, port: port)
serverChannel.bind(to: address, promise: promise)
} catch {
promise.fail(error)
}
case .failure(let error):
promise.fail(error)
}
}
}
).get()
}
/// Bind the `NIOTSListenerChannel` to `address`.
///
/// - Parameters:
/// - address: The `SocketAddress` to bind on.
/// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func bind<Output: Sendable>(
to address: SocketAddress,
serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> NIOAsyncChannel<Output, Never> {
return try await self.bind0(
serverBackPressureStrategy: serverBackPressureStrategy,
childChannelInitializer: childChannelInitializer,
registration: { (serverChannel, promise) in
serverChannel.register().whenComplete { result in
switch result {
case .success:
serverChannel.bind(to: address, promise: promise)
case .failure(let error):
promise.fail(error)
}
}
}
).get()
}
/// Bind the `NIOTSListenerChannel` to a given `NWEndpoint`.
///
/// - Parameters:
/// - endpoint: The `NWEndpoint` to bind this channel to.
/// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func bind<Output: Sendable>(
endpoint: NWEndpoint,
serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> NIOAsyncChannel<Output, Never> {
return try await self.bind0(
serverBackPressureStrategy: serverBackPressureStrategy,
childChannelInitializer: childChannelInitializer,
registration: { (serverChannel, promise) in
serverChannel.register().whenComplete { result in
switch result {
case .success:
serverChannel.triggerUserOutboundEvent(NIOTSNetworkEvents.BindToNWEndpoint(endpoint: endpoint), promise: promise)
case .failure(let error):
promise.fail(error)
}
}
}
).get()
}
/// Bind the `NIOTSListenerChannel` to an existing `NWListener`.
///
/// - Parameters:
/// - listener: The NWListener to wrap.
/// - serverBackPressureStrategy: The back pressure strategy used by the server socket channel.
/// - channelInitializer: A closure to initialize the channel. The return value of this closure is returned from the `bind`
/// method.
/// - Returns: The result of the channel initializer.
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
public func withNWListener<Output: Sendable>(
_ listener: NWListener,
serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark? = nil,
childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<Output>
) async throws -> NIOAsyncChannel<Output, Never> {
return try await self.bind0(
existingNWListener: listener,
serverBackPressureStrategy: serverBackPressureStrategy,
childChannelInitializer: childChannelInitializer,
registration: { (serverChannel, promise) in
serverChannel.registerAlreadyConfigured0(promise: promise)
}
).get()
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func bind0<ChannelInitializerResult>(
existingNWListener: NWListener? = nil,
serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
registration: @escaping (NIOTSListenerChannel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<NIOAsyncChannel<ChannelInitializerResult, Never>> {
let eventLoop = self.group.next() as! NIOTSEventLoop
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
let childChannelInit = self.childChannelInit
let serverChannelOptions = self.serverChannelOptions
let childChannelOptions = self.childChannelOptions
let serverChannel: NIOTSListenerChannel
if let newListener = existingNWListener {
serverChannel = NIOTSListenerChannel(
wrapping: newListener,
on: self.group.next() as! NIOTSEventLoop,
qos: self.serverQoS,
tcpOptions: self.tcpOptions,
tlsOptions: self.tlsOptions,
childLoopGroup: self.childGroup,
childChannelQoS: self.childQoS,
childTCPOptions: self.tcpOptions,
childTLSOptions: self.tlsOptions
)
} else {
serverChannel = NIOTSListenerChannel(
eventLoop: eventLoop,
qos: self.serverQoS,
tcpOptions: self.tcpOptions,
tlsOptions: self.tlsOptions,
childLoopGroup: self.childGroup,
childChannelQoS: self.childQoS,
childTCPOptions: self.tcpOptions,
childTLSOptions: self.tlsOptions
)
}
return eventLoop.submit {
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
serverChannelInit(serverChannel)
}.flatMap { (_) -> EventLoopFuture<NIOAsyncChannel<ChannelInitializerResult, Never>> in
do {
try serverChannel.pipeline.syncOperations.addHandler(
AcceptHandler<NIOTSConnectionChannel>(childChannelInitializer: childChannelInit, childChannelOptions: childChannelOptions),
name: "AcceptHandler"
)
let asyncChannel = try NIOAsyncChannel<ChannelInitializerResult, Never>
._wrapAsyncChannelWithTransformations(
synchronouslyWrapping: serverChannel,
backPressureStrategy: serverBackPressureStrategy,
channelReadTransformation: { channel -> EventLoopFuture<(ChannelInitializerResult)> in
// The channelReadTransformation is run on the EL of the server channel
// We have to make sure that we execute child channel initializer on the
// EL of the child channel.
channel.eventLoop.flatSubmit {
childChannelInitializer(channel)
}
}
)
let bindPromise = eventLoop.makePromise(of: Void.self)
registration(serverChannel, bindPromise)
if let bindTimeout = self.bindTimeout {
let cancelTask = eventLoop.scheduleTask(in: bindTimeout) {
bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout))
serverChannel.close(promise: nil)
}
bindPromise.futureResult.whenComplete { (_: Result<Void, Error>) in
cancelTask.cancel()
}
}
return bindPromise.futureResult
.map { (_) -> NIOAsyncChannel<ChannelInitializerResult, Never> in asyncChannel
}
} catch {
return eventLoop.makeFailedFuture(error)
}
}.flatMapError { error -> EventLoopFuture<NIOAsyncChannel<ChannelInitializerResult, Never>> in
serverChannel.close0(error: error, mode: .all, promise: nil)
return eventLoop.makeFailedFuture(error)
}
}.flatMap {
$0
}
}
}
#endif

View File

@ -0,0 +1,697 @@
//===----------------------------------------------------------------------===//
//
// This source file is part of the SwiftNIO open source project
//
// Copyright (c) 2023 Apple Inc. and the SwiftNIO project authors
// Licensed under Apache License v2.0
//
// See LICENSE.txt for license information
// See CONTRIBUTORS.txt for the list of SwiftNIO project authors
//
// SPDX-License-Identifier: Apache-2.0
//
//===----------------------------------------------------------------------===//
#if canImport(Network)
import NIOConcurrencyHelpers
import NIOTransportServices
@_spi(AsyncChannel) import NIOCore
import XCTest
@_spi(AsyncChannel) import NIOTLS
private final class LineDelimiterCoder: ByteToMessageDecoder, MessageToByteEncoder {
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
private let newLine = "\n".utf8.first!
func decode(context: ChannelHandlerContext, buffer: inout ByteBuffer) throws -> DecodingState {
let readable = buffer.withUnsafeReadableBytes { $0.firstIndex(of: self.newLine) }
if let readable = readable {
context.fireChannelRead(self.wrapInboundOut(buffer.readSlice(length: readable)!))
buffer.moveReaderIndex(forwardBy: 1)
return .continue
}
return .needMoreData
}
func encode(data: ByteBuffer, out: inout ByteBuffer) throws {
out.writeImmutableBuffer(data)
out.writeString("\n")
}
}
private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannelHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = ByteBuffer
enum ALPN: String {
case string
case byte
case unknown
}
private var proposedALPN: ALPN?
init(
proposedALPN: ALPN? = nil
) {
self.proposedALPN = proposedALPN
}
func handlerAdded(context: ChannelHandlerContext) {
guard context.channel.isActive else {
return
}
if let proposedALPN = self.proposedALPN {
self.proposedALPN = nil
context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil)
}
context.fireChannelActive()
}
func channelActive(context: ChannelHandlerContext) {
if let proposedALPN = self.proposedALPN {
context.writeAndFlush(.init(ByteBuffer(string: "negotiate-alpn:\(proposedALPN.rawValue)")), promise: nil)
}
context.fireChannelActive()
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let buffer = self.unwrapInboundIn(data)
let string = String(buffer: buffer)
if string.hasPrefix("negotiate-alpn:") {
let alpn = String(string.dropFirst(15))
context.writeAndFlush(.init(ByteBuffer(string: "alpn:\(alpn)")), promise: nil)
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn))
context.pipeline.removeHandler(self, promise: nil)
} else if string.hasPrefix("alpn:") {
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5))))
context.pipeline.removeHandler(self, promise: nil)
} else {
context.fireChannelRead(data)
}
}
}
private final class ByteBufferToStringHandler: ChannelDuplexHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = String
typealias OutboundIn = String
typealias OutboundOut = ByteBuffer
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let buffer = self.unwrapInboundIn(data)
context.fireChannelRead(self.wrapInboundOut(String(buffer: buffer)))
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let buffer = ByteBuffer(string: self.unwrapOutboundIn(data))
context.write(.init(buffer), promise: promise)
}
}
private final class ByteBufferToByteHandler: ChannelDuplexHandler {
typealias InboundIn = ByteBuffer
typealias InboundOut = UInt8
typealias OutboundIn = UInt8
typealias OutboundOut = ByteBuffer
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
var buffer = self.unwrapInboundIn(data)
let byte = buffer.readInteger(as: UInt8.self)!
context.fireChannelRead(self.wrapInboundOut(byte))
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let buffer = ByteBuffer(integer: self.unwrapOutboundIn(data))
context.write(.init(buffer), promise: promise)
}
}
private final class AddressedEnvelopingHandler: ChannelDuplexHandler {
typealias InboundIn = AddressedEnvelope<ByteBuffer>
typealias InboundOut = ByteBuffer
typealias OutboundIn = ByteBuffer
typealias OutboundOut = Any
var remoteAddress: SocketAddress?
init(remoteAddress: SocketAddress? = nil) {
self.remoteAddress = remoteAddress
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let envelope = self.unwrapInboundIn(data)
self.remoteAddress = envelope.remoteAddress
context.fireChannelRead(self.wrapInboundOut(envelope.data))
}
func write(context: ChannelHandlerContext, data: NIOAny, promise: EventLoopPromise<Void>?) {
let buffer = self.unwrapOutboundIn(data)
if let remoteAddress = self.remoteAddress {
context.write(self.wrapOutboundOut(AddressedEnvelope(remoteAddress: remoteAddress, data: buffer)), promise: promise)
return
}
context.write(self.wrapOutboundOut(buffer), promise: promise)
}
}
final class AsyncChannelBootstrapTests: XCTestCase {
enum NegotiationResult {
case string(NIOAsyncChannel<String, String>)
case byte(NIOAsyncChannel<UInt8, UInt8>)
}
struct ProtocolNegotiationError: Error {}
enum StringOrByte: Hashable {
case string(String)
case byte(UInt8)
}
func testServerClientBootstrap_withAsyncChannel_andHostPort() async throws {
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
}
let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.autoRead, value: true)
.bind(
host: "127.0.0.1",
port: 0
) { channel -> EventLoopFuture<NIOAsyncChannel<String, String>> in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
return try NIOAsyncChannel(
synchronouslyWrapping: channel,
configuration: .init(
inboundType: String.self,
outboundType: String.self
)
)
}
}
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var iterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { _ in
for try await childChannel in channel.inbound {
for try await value in childChannel.inbound {
continuation.yield(.string(value))
}
}
}
}
let stringChannel = try await self.makeClientChannel(eventLoopGroup: eventLoopGroup, port: channel.channel.localAddress!.port!)
try await stringChannel.outbound.write("hello")
await XCTAsyncAssertEqual(await iterator.next(), .string("hello"))
group.cancelAll()
}
}
func testAsyncChannelProtocolNegotiation() async throws {
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
}
let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.autoRead, value: true)
.bind(
host: "127.0.0.1",
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
}
}
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var serverIterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await negotiationResult in channel.inbound {
group.addTask {
switch try await negotiationResult.get() {
case .string(let channel):
for try await value in channel.inbound {
continuation.yield(.string(value))
}
case .byte(let channel):
for try await value in channel.inbound {
continuation.yield(.byte(value))
}
}
}
}
}
}
let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .string
)
switch try await stringNegotiationResult.get() {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outbound.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
let byteNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .byte
)
switch try await byteNegotiationResult.get() {
case .string:
preconditionFailure()
case .byte(let byteChannel):
// This is the actual content
try await byteChannel.outbound.write(UInt8(8))
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
}
group.cancelAll()
}
}
func testAsyncChannelNestedProtocolNegotiation() async throws {
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
}
let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.childChannelOption(ChannelOptions.autoRead, value: true)
.bind(
host: "127.0.0.1",
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureNestedProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
}
}
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var serverIterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await negotiationResult in channel.inbound {
group.addTask {
switch try await negotiationResult.get().get() {
case .string(let channel):
for try await value in channel.inbound {
continuation.yield(.string(value))
}
case .byte(let channel):
for try await value in channel.inbound {
continuation.yield(.byte(value))
}
}
}
}
}
}
let stringStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .string,
proposedInnerALPN: .string
)
switch try await stringStringNegotiationResult.get().get() {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outbound.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
let byteStringNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .byte,
proposedInnerALPN: .string
)
switch try await byteStringNegotiationResult.get().get() {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outbound.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
let byteByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .byte,
proposedInnerALPN: .byte
)
switch try await byteByteNegotiationResult.get().get() {
case .string:
preconditionFailure()
case .byte(let byteChannel):
// This is the actual content
try await byteChannel.outbound.write(UInt8(8))
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
}
let stringByteNegotiationResult = try await self.makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedOuterALPN: .string,
proposedInnerALPN: .byte
)
switch try await stringByteNegotiationResult.get().get() {
case .string:
preconditionFailure()
case .byte(let byteChannel):
// This is the actual content
try await byteChannel.outbound.write(UInt8(8))
await XCTAsyncAssertEqual(await serverIterator.next(), .byte(8))
}
group.cancelAll()
}
}
func testAsyncChannelProtocolNegotiation_whenFails() async throws {
final class CollectingHandler: ChannelInboundHandler {
typealias InboundIn = Channel
private let channels: NIOLockedValueBox<[Channel]>
init(channels: NIOLockedValueBox<[Channel]>) {
self.channels = channels
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
let channel = self.unwrapInboundIn(data)
self.channels.withLockedValue { $0.append(channel) }
context.fireChannelRead(data)
}
}
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
}
let channels = NIOLockedValueBox<[Channel]>([Channel]())
let channel: NIOAsyncChannel<EventLoopFuture<NegotiationResult>, Never> = try await NIOTSListenerBootstrap(group: eventLoopGroup)
.serverChannelOption(ChannelOptions.socketOption(.so_reuseaddr), value: 1)
.serverChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(CollectingHandler(channels: channels))
}
}
.childChannelOption(ChannelOptions.autoRead, value: true)
.bind(
host: "127.0.0.1",
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
}
}
try await withThrowingTaskGroup(of: Void.self) { group in
let (stream, continuation) = AsyncStream<StringOrByte>.makeStream()
var serverIterator = stream.makeAsyncIterator()
group.addTask {
try await withThrowingTaskGroup(of: Void.self) { group in
for try await negotiationResult in channel.inbound {
group.addTask {
switch try await negotiationResult.get() {
case .string(let channel):
for try await value in channel.inbound {
continuation.yield(.string(value))
}
case .byte(let channel):
for try await value in channel.inbound {
continuation.yield(.byte(value))
}
}
}
}
}
}
let failedProtocolNegotiation = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .unknown
)
await XCTAssertThrowsError(
try await failedProtocolNegotiation.get()
)
// Let's check that we can still open a new connection
let stringNegotiationResult = try await self.makeClientChannelWithProtocolNegotiation(
eventLoopGroup: eventLoopGroup,
port: channel.channel.localAddress!.port!,
proposedALPN: .string
)
switch try await stringNegotiationResult.get() {
case .string(let stringChannel):
// This is the actual content
try await stringChannel.outbound.write("hello")
await XCTAsyncAssertEqual(await serverIterator.next(), .string("hello"))
case .byte:
preconditionFailure()
}
let failedInboundChannel = channels.withLockedValue { channels -> Channel in
XCTAssertEqual(channels.count, 2)
return channels[0]
}
// We are waiting here to make sure the channel got closed
try await failedInboundChannel.closeFuture.get()
group.cancelAll()
}
}
// MARK: - Test Helpers
private func makeClientChannel(eventLoopGroup: EventLoopGroup, port: Int) async throws -> NIOAsyncChannel<String, String> {
return try await NIOTSConnectionBootstrap(group: eventLoopGroup)
.connect(
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(AddressedEnvelopingHandler())
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
return try NIOAsyncChannel(
synchronouslyWrapping: channel,
configuration: .init(
inboundType: String.self,
outboundType: String.self
)
)
}
}
}
private func makeClientChannelWithProtocolNegotiation(
eventLoopGroup: EventLoopGroup,
port: Int,
proposedALPN: TLSUserEventHandler.ALPN
) async throws -> EventLoopFuture<NegotiationResult> {
return try await NIOTSConnectionBootstrap(group: eventLoopGroup)
.connect(
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
return channel.eventLoop.makeCompletedFuture {
return try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN).protocolNegotiationResult
}
}
}
private func makeClientChannelWithNestedProtocolNegotiation(
eventLoopGroup: EventLoopGroup,
port: Int,
proposedOuterALPN: TLSUserEventHandler.ALPN,
proposedInnerALPN: TLSUserEventHandler.ALPN
) async throws -> EventLoopFuture<EventLoopFuture<NegotiationResult>> {
return try await NIOTSConnectionBootstrap(group: eventLoopGroup)
.connect(
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
return channel.eventLoop.makeCompletedFuture {
try self.configureNestedProtocolNegotiationHandlers(
channel: channel,
proposedOuterALPN: proposedOuterALPN,
proposedInnerALPN: proposedInnerALPN
).protocolNegotiationResult
}
}
}
@discardableResult
private func configureProtocolNegotiationHandlers(
channel: Channel,
proposedALPN: TLSUserEventHandler.ALPN? = nil
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN))
return try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
}
@discardableResult
private func configureNestedProtocolNegotiationHandlers(
channel: Channel,
proposedOuterALPN: TLSUserEventHandler.ALPN? = nil,
proposedInnerALPN: TLSUserEventHandler.ALPN? = nil
) throws -> NIOTypedApplicationProtocolNegotiationHandler<EventLoopFuture<NegotiationResult>> {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedOuterALPN))
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<EventLoopFuture<NegotiationResult>> { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
case "string":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN))
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return negotiationFuture.protocolNegotiationResult
}
case "byte":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedInnerALPN))
let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return negotiationHandler.protocolNegotiationResult
}
default:
return channel.close().flatMapThrowing { throw ProtocolNegotiationError() }
}
case .fallback:
return channel.close().flatMapThrowing { throw ProtocolNegotiationError() }
}
}
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
return negotiationHandler
}
@discardableResult
private func addTypedApplicationProtocolNegotiationHandler(to channel: Channel) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> { alpnResult, channel in
switch alpnResult {
case .negotiated(let alpn):
switch alpn {
case "string":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteBufferToStringHandler())
let asyncChannel: NIOAsyncChannel<String, String> = try NIOAsyncChannel(
synchronouslyWrapping: channel
)
return NegotiationResult.string(asyncChannel)
}
case "byte":
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ByteBufferToByteHandler())
let asyncChannel: NIOAsyncChannel<UInt8, UInt8> = try NIOAsyncChannel(
synchronouslyWrapping: channel
)
return NegotiationResult.byte(asyncChannel)
}
default:
return channel.close().flatMapThrowing { throw ProtocolNegotiationError() }
}
case .fallback:
return channel.close().flatMapThrowing { throw ProtocolNegotiationError() }
}
}
try channel.pipeline.syncOperations.addHandler(negotiationHandler)
return negotiationHandler
}
}
extension AsyncStream {
fileprivate static func makeStream(
of elementType: Element.Type = Element.self,
bufferingPolicy limit: Continuation.BufferingPolicy = .unbounded
) -> (stream: AsyncStream<Element>, continuation: AsyncStream<Element>.Continuation) {
var continuation: AsyncStream<Element>.Continuation!
let stream = AsyncStream<Element>(bufferingPolicy: limit) { continuation = $0 }
return (stream: stream, continuation: continuation!)
}
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
private func XCTAsyncAssertEqual<Element: Equatable>(_ lhs: @autoclosure () async throws -> Element, _ rhs: @autoclosure () async throws -> Element, file: StaticString = #filePath, line: UInt = #line) async rethrows {
let lhsResult = try await lhs()
let rhsResult = try await rhs()
XCTAssertEqual(lhsResult, rhsResult, file: file, line: line)
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
private func XCTAsyncAssertThrowsError<T>(
_ expression: @autoclosure () async throws -> T,
_ message: @autoclosure () -> String = "",
file: StaticString = #filePath,
line: UInt = #line,
_ errorHandler: (_ error: Error) -> Void = { _ in }
) async {
do {
_ = try await expression()
XCTFail(message(), file: file, line: line)
} catch {
errorHandler(error)
}
}
@available(macOS 10.15, iOS 13.0, watchOS 6.0, tvOS 13.0, *)
internal func XCTAssertThrowsError<T>(
_ expression: @autoclosure () async throws -> T,
file: StaticString = #filePath,
line: UInt = #line,
verify: (Error) -> Void = { _ in }
) async {
do {
_ = try await expression()
XCTFail("Expression did not throw error", file: file, line: line)
} catch {
verify(error)
}
}
#endif