Strict concurrency for NIOTransportServices and tests (#228)

This commit is contained in:
Gus Cairo 2025-04-02 10:54:00 +01:00 committed by GitHub
parent a9b23220e4
commit 92bb536b7e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 552 additions and 340 deletions

View File

@ -15,13 +15,31 @@
import PackageDescription
let strictConcurrencyDevelopment = false
let strictConcurrencySettings: [SwiftSetting] = {
var initialSettings: [SwiftSetting] = []
initialSettings.append(contentsOf: [
.enableUpcomingFeature("StrictConcurrency"),
.enableUpcomingFeature("InferSendableFromCaptures"),
])
if strictConcurrencyDevelopment {
// -warnings-as-errors here is a workaround so that IDE-based development can
// get tripped up on -require-explicit-sendable.
initialSettings.append(.unsafeFlags(["-require-explicit-sendable", "-warnings-as-errors"]))
}
return initialSettings
}()
let package = Package(
name: "swift-nio-transport-services",
products: [
.library(name: "NIOTransportServices", targets: ["NIOTransportServices"])
],
dependencies: [
.package(url: "https://github.com/apple/swift-nio.git", from: "2.62.0"),
.package(url: "https://github.com/apple/swift-nio.git", from: "2.81.0"),
.package(url: "https://github.com/apple/swift-atomics.git", from: "1.0.2"),
],
targets: [
@ -33,7 +51,8 @@ let package = Package(
.product(name: "NIOFoundationCompat", package: "swift-nio"),
.product(name: "NIOTLS", package: "swift-nio"),
.product(name: "Atomics", package: "swift-atomics"),
]
],
swiftSettings: strictConcurrencySettings
),
.executableTarget(
name: "NIOTSHTTPClient",
@ -58,7 +77,8 @@ let package = Package(
.product(name: "NIOCore", package: "swift-nio"),
.product(name: "NIOEmbedded", package: "swift-nio"),
.product(name: "Atomics", package: "swift-atomics"),
]
],
swiftSettings: strictConcurrencySettings
),
]
)

View File

@ -20,11 +20,11 @@ internal class AcceptHandler<ChildChannel: Channel>: ChannelInboundHandler {
typealias InboundIn = ChildChannel
typealias InboundOut = ChildChannel
private let childChannelInitializer: ((Channel) -> EventLoopFuture<Void>)?
private let childChannelInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?
private let childChannelOptions: ChannelOptions.Storage
init(
childChannelInitializer: ((Channel) -> EventLoopFuture<Void>)?,
childChannelInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?,
childChannelOptions: ChannelOptions.Storage
) {
self.childChannelInitializer = childChannelInitializer
@ -35,11 +35,12 @@ internal class AcceptHandler<ChildChannel: Channel>: ChannelInboundHandler {
let newChannel = self.unwrapInboundIn(data)
let childLoop = newChannel.eventLoop
let ctxEventLoop = context.eventLoop
let childInitializer = self.childChannelInitializer ?? { _ in childLoop.makeSucceededFuture(()) }
let childInitializer = self.childChannelInitializer ?? { @Sendable _ in childLoop.makeSucceededFuture(()) }
let childChannelOptions = self.childChannelOptions
@inline(__always)
@Sendable @inline(__always)
func setupChildChannel() -> EventLoopFuture<Void> {
self.childChannelOptions.applyAllChannelOptions(to: newChannel).flatMap { () -> EventLoopFuture<Void> in
childChannelOptions.applyAllChannelOptions(to: newChannel).flatMap { () -> EventLoopFuture<Void> in
childLoop.assertInEventLoop()
return childInitializer(newChannel)
}
@ -48,8 +49,8 @@ internal class AcceptHandler<ChildChannel: Channel>: ChannelInboundHandler {
@inline(__always)
func fireThroughPipeline(_ future: EventLoopFuture<Void>) {
ctxEventLoop.assertInEventLoop()
future.flatMap { (_) -> EventLoopFuture<Void> in
ctxEventLoop.assertInEventLoop()
assert(ctxEventLoop === context.eventLoop)
future.assumeIsolated().flatMap { (_) -> EventLoopFuture<Void> in
guard context.channel.isActive else {
return newChannel.close().flatMapThrowing {
throw ChannelError.ioOnClosedChannel
@ -75,4 +76,7 @@ internal class AcceptHandler<ChildChannel: Channel>: ChannelInboundHandler {
}
}
}
@available(*, unavailable)
extension AcceptHandler: Sendable {}
#endif

View File

@ -41,7 +41,7 @@ import Network
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
public final class NIOTSDatagramBootstrap {
private let group: EventLoopGroup
private var channelInitializer: ((Channel) -> EventLoopFuture<Void>)?
private var channelInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)?
private var connectTimeout: TimeAmount = TimeAmount.seconds(10)
private var channelOptions = ChannelOptions.Storage()
private var qos: DispatchQoS?
@ -79,7 +79,8 @@ public final class NIOTSDatagramBootstrap {
///
/// - parameters:
/// - handler: A closure that initializes the provided `Channel`.
public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
@preconcurrency
public func channelInitializer(_ handler: @Sendable @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
self.channelInitializer = handler
return self
}
@ -180,17 +181,18 @@ public final class NIOTSDatagramBootstrap {
}
}
private func connect0(_ binder: @escaping (Channel, EventLoopPromise<Void>) -> Void) -> EventLoopFuture<Channel> {
private func connect0(
_ binder: @Sendable @escaping (Channel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<Channel> {
let conn: Channel = NIOTSDatagramChannel(
eventLoop: self.group.next() as! NIOTSEventLoop,
qos: self.qos,
udpOptions: self.udpOptions,
tlsOptions: self.tlsOptions
)
let initializer = self.channelInitializer ?? { _ in conn.eventLoop.makeSucceededFuture(()) }
let channelOptions = self.channelOptions
let initializer = self.channelInitializer ?? { @Sendable _ in conn.eventLoop.makeSucceededFuture(()) }
return conn.eventLoop.submit {
return conn.eventLoop.submit { [channelOptions, connectTimeout] in
channelOptions.applyAllChannelOptions(to: conn).flatMap {
initializer(conn)
}.flatMap {
@ -199,8 +201,8 @@ public final class NIOTSDatagramBootstrap {
}.flatMap {
let connectPromise: EventLoopPromise<Void> = conn.eventLoop.makePromise()
binder(conn, connectPromise)
let cancelTask = conn.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
let cancelTask = conn.eventLoop.scheduleTask(in: connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(connectTimeout))
conn.close(promise: nil)
}
@ -215,4 +217,7 @@ public final class NIOTSDatagramBootstrap {
}.flatMap { $0 }
}
}
@available(*, unavailable)
extension NIOTSDatagramBootstrap: Sendable {}
#endif

View File

@ -243,4 +243,7 @@ extension NIOTSDatagramChannel {
SynchronousOptions(channel: self)
}
}
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
extension NIOTSDatagramChannel: @unchecked Sendable {}
#endif

View File

@ -57,8 +57,8 @@ import Network
public final class NIOTSDatagramListenerBootstrap {
private let group: EventLoopGroup
private let childGroup: EventLoopGroup
private var serverChannelInit: ((Channel) -> EventLoopFuture<Void>)?
private var childChannelInit: ((Channel) -> EventLoopFuture<Void>)?
private var serverChannelInit: (@Sendable (Channel) -> EventLoopFuture<Void>)?
private var childChannelInit: (@Sendable (Channel) -> EventLoopFuture<Void>)?
private var serverChannelOptions = ChannelOptions.Storage()
private var childChannelOptions = ChannelOptions.Storage()
private var serverQoS: DispatchQoS?
@ -154,7 +154,9 @@ public final class NIOTSDatagramListenerBootstrap {
///
/// - parameters:
/// - initializer: A closure that initializes the provided `Channel`.
public func serverChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
@preconcurrency
public func serverChannelInitializer(_ initializer: @Sendable @escaping (Channel) -> EventLoopFuture<Void>) -> Self
{
self.serverChannelInit = initializer
return self
}
@ -167,7 +169,8 @@ public final class NIOTSDatagramListenerBootstrap {
///
/// - parameters:
/// - initializer: A closure that initializes the provided `Channel`.
public func childChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
@preconcurrency
public func childChannelInitializer(_ initializer: @Sendable @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
self.childChannelInit = initializer
return self
}
@ -305,10 +308,13 @@ public final class NIOTSDatagramListenerBootstrap {
private func bind0(
existingNWListener: NWListener? = nil,
shouldRegister: Bool,
_ binder: @escaping (NIOTSDatagramListenerChannel, EventLoopPromise<Void>) -> Void
_ binder: @Sendable @escaping (NIOTSDatagramListenerChannel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<Channel> {
let eventLoop = self.group.next() as! NIOTSEventLoop
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
let serverChannelInit =
self.serverChannelInit ?? {
@Sendable _ in eventLoop.makeSucceededFuture(())
}
let childChannelInit = self.childChannelInit
let serverChannelOptions = self.serverChannelOptions
let childChannelOptions = self.childChannelOptions
@ -339,17 +345,19 @@ public final class NIOTSDatagramListenerBootstrap {
)
}
return eventLoop.submit {
return eventLoop.submit { [bindTimeout] in
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
serverChannelInit(serverChannel)
}.flatMap {
eventLoop.assertInEventLoop()
return serverChannel.pipeline.addHandler(
AcceptHandler<NIOTSDatagramChannel>(
childChannelInitializer: childChannelInit,
childChannelOptions: childChannelOptions
return eventLoop.makeCompletedFuture {
try serverChannel.pipeline.syncOperations.addHandler(
AcceptHandler<NIOTSDatagramChannel>(
childChannelInitializer: childChannelInit,
childChannelOptions: childChannelOptions
)
)
)
}
}.flatMap {
if shouldRegister {
return serverChannel.register()
@ -360,7 +368,7 @@ public final class NIOTSDatagramListenerBootstrap {
let bindPromise = eventLoop.makePromise(of: Void.self)
binder(serverChannel, bindPromise)
if let bindTimeout = self.bindTimeout {
if let bindTimeout = bindTimeout {
let cancelTask = eventLoop.scheduleTask(in: bindTimeout) {
bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout))
serverChannel.close(promise: nil)
@ -382,4 +390,7 @@ public final class NIOTSDatagramListenerBootstrap {
}
}
}
@available(*, unavailable)
extension NIOTSDatagramListenerBootstrap: Sendable {}
#endif

View File

@ -135,7 +135,7 @@ internal final class NIOTSDatagramListenerChannel: StateManagedListenerChannel<N
tlsOptions: self.childTLSOptions
)
self.pipeline.fireChannelRead(NIOAny(newChannel))
self.pipeline.fireChannelRead(newChannel)
self.pipeline.fireChannelReadComplete()
}

View File

@ -13,11 +13,11 @@
//===----------------------------------------------------------------------===//
#if canImport(Network)
import NIOCore
@preconcurrency import Network
import Network
/// Options that can be set explicitly and only on bootstraps provided by `NIOTransportServices`.
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
public struct NIOTSChannelOptions {
public struct NIOTSChannelOptions: Sendable {
/// See: ``Types/NIOTSWaitForActivityOption``.
public static let waitForActivity = NIOTSChannelOptions.Types.NIOTSWaitForActivityOption()
@ -32,7 +32,7 @@ public struct NIOTSChannelOptions {
/// See: ``Types/NIOTSMetadataOption``
public static let metadata = {
(definition: NWProtocolDefinition) -> NIOTSChannelOptions.Types.NIOTSMetadataOption in
@Sendable (definition: NWProtocolDefinition) -> NIOTSChannelOptions.Types.NIOTSMetadataOption in
.init(definition: definition)
}
@ -66,7 +66,7 @@ public struct NIOTSChannelOptions {
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
extension NIOTSChannelOptions {
/// A namespace for ``NIOTSChannelOptions`` datastructures.
public enum Types {
public enum Types: Sendable {
/// ``NIOTSWaitForActivityOption`` controls whether the `Channel` should wait for connection changes
/// during the connection process if the connection attempt fails. If Network.framework believes that
/// a connection may succeed in future, it may transition into the `.waiting` state. By default, this option

View File

@ -41,13 +41,15 @@ import Network
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
public final class NIOTSConnectionBootstrap {
private let group: EventLoopGroup
private var _channelInitializer: ((Channel) -> EventLoopFuture<Void>)
private var channelInitializer: ((Channel) -> EventLoopFuture<Void>) {
private var _channelInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>)
private var channelInitializer: (@Sendable (Channel) -> EventLoopFuture<Void>) {
if let protocolHandlers = self.protocolHandlers {
let channelInitializer = self._channelInitializer
return { channel in
channelInitializer(channel).flatMap {
channel.pipeline.addHandlers(protocolHandlers(), position: .first)
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandlers(protocolHandlers(), position: .first)
}
}
}
} else {
@ -59,7 +61,7 @@ public final class NIOTSConnectionBootstrap {
private var qos: DispatchQoS?
private var tcpOptions: NWProtocolTCP.Options = .init()
private var tlsOptions: NWProtocolTLS.Options?
private var protocolHandlers: (() -> [ChannelHandler])? = nil
private var protocolHandlers: (@Sendable () -> [ChannelHandler])? = nil
/// Create a `NIOTSConnectionBootstrap` on the `EventLoopGroup` `group`.
///
@ -111,7 +113,8 @@ public final class NIOTSConnectionBootstrap {
///
/// - parameters:
/// - handler: A closure that initializes the provided `Channel`.
public func channelInitializer(_ handler: @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
@preconcurrency
public func channelInitializer(_ handler: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self {
self._channelInitializer = handler
return self
}
@ -229,7 +232,10 @@ public final class NIOTSConnectionBootstrap {
private func connect(
existingNWConnection: NWConnection? = nil,
shouldRegister: Bool,
_ connectAction: @escaping (NIOTSConnectionChannel, EventLoopPromise<Void>) -> Void
_ connectAction: @Sendable @escaping (
NIOTSConnectionChannel,
EventLoopPromise<Void>
) -> Void
) -> EventLoopFuture<Channel> {
let conn: NIOTSConnectionChannel
if let newConnection = existingNWConnection {
@ -250,7 +256,7 @@ public final class NIOTSConnectionBootstrap {
let initializer = self.channelInitializer
let channelOptions = self.channelOptions
return conn.eventLoop.flatSubmit {
return conn.eventLoop.flatSubmit { [connectTimeout] in
channelOptions.applyAllChannelOptions(to: conn).flatMap {
initializer(conn)
}.flatMap {
@ -263,8 +269,8 @@ public final class NIOTSConnectionBootstrap {
}.flatMap {
let connectPromise: EventLoopPromise<Void> = conn.eventLoop.makePromise()
connectAction(conn, connectPromise)
let cancelTask = conn.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
let cancelTask = conn.eventLoop.scheduleTask(in: connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(connectTimeout))
conn.close(promise: nil)
}
@ -285,7 +291,8 @@ public final class NIOTSConnectionBootstrap {
/// Per bootstrap, you can only set the `protocolHandlers` once. Typically, `protocolHandlers` are used for the TLS
/// implementation. Most notably, `NIOClientTCPBootstrap`, NIO's "universal bootstrap" abstraction, uses
/// `protocolHandlers` to add the required `ChannelHandler`s for many TLS implementations.
public func protocolHandlers(_ handlers: @escaping () -> [ChannelHandler]) -> Self {
@preconcurrency
public func protocolHandlers(_ handlers: @Sendable @escaping () -> [ChannelHandler]) -> Self {
precondition(self.protocolHandlers == nil, "protocol handlers can only be set once")
self.protocolHandlers = handlers
return self
@ -419,10 +426,10 @@ extension NIOTSConnectionBootstrap {
}
@available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *)
private func connect0<ChannelInitializerResult>(
private func connect0<ChannelInitializerResult: Sendable>(
existingNWConnection: NWConnection? = nil,
channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
registration: @escaping (NIOTSConnectionChannel, EventLoopPromise<Void>) -> Void
registration: @Sendable @escaping (NIOTSConnectionChannel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<ChannelInitializerResult> {
let connectionChannel: NIOTSConnectionChannel
if let newConnection = existingNWConnection {
@ -440,20 +447,20 @@ extension NIOTSConnectionBootstrap {
tlsOptions: self.tlsOptions
)
}
let channelInitializer = { (channel: Channel) -> EventLoopFuture<ChannelInitializerResult> in
let initializer = self.channelInitializer
return initializer(channel).flatMap { channelInitializer(channel) }
let initializer = self.channelInitializer
let channelInitializer = { @Sendable (channel: Channel) -> EventLoopFuture<ChannelInitializerResult> in
initializer(channel).flatMap { channelInitializer(channel) }
}
let channelOptions = self.channelOptions
return connectionChannel.eventLoop.flatSubmit {
return connectionChannel.eventLoop.flatSubmit { [connectTimeout] in
channelOptions.applyAllChannelOptions(to: connectionChannel).flatMap {
channelInitializer(connectionChannel)
}.flatMap { result -> EventLoopFuture<ChannelInitializerResult> in
let connectPromise: EventLoopPromise<Void> = connectionChannel.eventLoop.makePromise()
registration(connectionChannel, connectPromise)
let cancelTask = connectionChannel.eventLoop.scheduleTask(in: self.connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(self.connectTimeout))
let cancelTask = connectionChannel.eventLoop.scheduleTask(in: connectTimeout) {
connectPromise.fail(ChannelError.connectTimeout(connectTimeout))
connectionChannel.close(promise: nil)
}

View File

@ -406,7 +406,7 @@ extension NIOTSConnectionChannel {
// APIs.
var buffer = self.allocator.buffer(capacity: content.count)
buffer.writeBytes(content)
self.pipeline.fireChannelRead(NIOAny(buffer))
self.pipeline.fireChannelRead(buffer)
self.pipeline.fireChannelReadComplete()
}
@ -568,4 +568,7 @@ extension Channel {
}
}
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
extension NIOTSConnectionChannel: @unchecked Sendable {}
#endif

View File

@ -13,7 +13,11 @@
//===----------------------------------------------------------------------===//
#if canImport(Network)
import Dispatch
#if swift(<6.1)
@preconcurrency import class Dispatch.DispatchSource
#else
import class Dispatch.DispatchSource
#endif
import Foundation
import Network
@ -28,11 +32,17 @@ import NIOConcurrencyHelpers
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
public protocol QoSEventLoop: EventLoop {
/// Submit a given task to be executed by the `EventLoop` at a given `qos`.
func execute(qos: DispatchQoS, _ task: @escaping () -> Void)
@preconcurrency
func execute(qos: DispatchQoS, _ task: @escaping @Sendable () -> Void)
/// Schedule a `task` that is executed by this `NIOTSEventLoop` after the given amount of time at the
/// given `qos`.
func scheduleTask<T>(in time: TimeAmount, qos: DispatchQoS, _ task: @escaping () throws -> T) -> Scheduled<T>
@preconcurrency
func scheduleTask<T>(
in time: TimeAmount,
qos: DispatchQoS,
_ task: @escaping @Sendable () throws -> T
) -> Scheduled<T>
}
/// The lifecycle state of a given event loop.
@ -49,8 +59,9 @@ private enum LifecycleState {
case closed
}
// It's okay for NIOTSEventLoop to be unchecked Sendable, since the state is isolated to the EL.
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
internal class NIOTSEventLoop: QoSEventLoop {
internal final class NIOTSEventLoop: QoSEventLoop, @unchecked Sendable {
private let loop: DispatchQueue
private let taskQueue: DispatchQueue
private let inQueueKey: DispatchSpecificKey<UUID>
@ -114,23 +125,27 @@ internal class NIOTSEventLoop: QoSEventLoop {
loop.setSpecific(key: inQueueKey, value: self.loopID)
}
public func execute(_ task: @escaping () -> Void) {
@preconcurrency
public func execute(_ task: @escaping @Sendable () -> Void) {
self.execute(qos: self.defaultQoS, task)
}
public func execute(qos: DispatchQoS, _ task: @escaping () -> Void) {
@preconcurrency
public func execute(qos: DispatchQoS, _ task: @escaping @Sendable () -> Void) {
// Ideally we'd not accept new work while closed. Sadly, that's not possible with the current APIs for this.
self.taskQueue.async(qos: qos, execute: task)
}
public func scheduleTask<T>(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled<T> {
@preconcurrency
public func scheduleTask<T>(deadline: NIODeadline, _ task: @escaping @Sendable () throws -> T) -> Scheduled<T> {
self.scheduleTask(deadline: deadline, qos: self.defaultQoS, task)
}
@preconcurrency
public func scheduleTask<T>(
deadline: NIODeadline,
qos: DispatchQoS,
_ task: @escaping () throws -> T
_ task: @escaping @Sendable () throws -> T
) -> Scheduled<T> {
let p: EventLoopPromise<T> = self.makePromise()
@ -143,11 +158,12 @@ internal class NIOTSEventLoop: QoSEventLoop {
p.fail(EventLoopError.shutdown)
return
}
do {
p.succeed(try task())
} catch {
p.fail(error)
}
p.assumeIsolated().completeWith(
Result {
try task()
}
)
}
timerSource.resume()
@ -165,16 +181,25 @@ internal class NIOTSEventLoop: QoSEventLoop {
)
}
public func scheduleTask<T>(in time: TimeAmount, _ task: @escaping () throws -> T) -> Scheduled<T> {
@preconcurrency
public func scheduleTask<T>(
in time: TimeAmount,
_ task: @escaping @Sendable () throws -> T
) -> Scheduled<T> {
self.scheduleTask(in: time, qos: self.defaultQoS, task)
}
public func scheduleTask<T>(in time: TimeAmount, qos: DispatchQoS, _ task: @escaping () throws -> T) -> Scheduled<T>
{
@preconcurrency
public func scheduleTask<T>(
in time: TimeAmount,
qos: DispatchQoS,
_ task: @escaping @Sendable () throws -> T
) -> Scheduled<T> {
self.scheduleTask(deadline: NIODeadline.now() + time, qos: qos, task)
}
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
@preconcurrency
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping @Sendable (Error?) -> Void) {
guard self.canBeShutDownIndividually else {
// The loops cannot be shut down by individually. They need to be shut down as a group and
// `NIOTSEventLoopGroup` calls `closeGently` not this method.

View File

@ -82,7 +82,8 @@ public final class NIOTSEventLoopGroup: EventLoopGroup {
}
/// Shuts down all of the event loops, rendering them unable to perform further work.
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
@preconcurrency
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping @Sendable (Error?) -> Void) {
guard self.canBeShutDown else {
queue.async {
callback(EventLoopError.unsupportedOperation)
@ -91,19 +92,19 @@ public final class NIOTSEventLoopGroup: EventLoopGroup {
}
let g = DispatchGroup()
let q = DispatchQueue(label: "nio.transportservices.shutdowngracefullyqueue", target: queue)
var error: Error? = nil
let error: NIOLockedValueBox<Error?> = .init(nil)
for loop in self.eventLoops {
g.enter()
loop.closeGently().recover { err in
q.sync { error = err }
q.sync { error.withLockedValue({ $0 = err }) }
}.whenComplete { (_: Result<Void, Error>) in
g.leave()
}
}
g.notify(queue: q) {
callback(error)
callback(error.withLockedValue({ $0 }))
}
}
@ -145,4 +146,7 @@ public struct NIOTSClientTLSProvider: NIOClientTLSProvider {
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
extension NIOTSEventLoopGroup: @unchecked Sendable {}
@available(*, unavailable)
extension NIOTSClientTLSProvider: Sendable {}
#endif

View File

@ -57,8 +57,8 @@ import Network
public final class NIOTSListenerBootstrap {
private let group: EventLoopGroup
private let childGroup: EventLoopGroup
private var serverChannelInit: ((Channel) -> EventLoopFuture<Void>)?
private var childChannelInit: ((Channel) -> EventLoopFuture<Void>)?
private var serverChannelInit: (@Sendable (Channel) -> EventLoopFuture<Void>)?
private var childChannelInit: (@Sendable (Channel) -> EventLoopFuture<Void>)?
private var serverChannelOptions = ChannelOptions.Storage()
private var childChannelOptions = ChannelOptions.Storage()
private var serverQoS: DispatchQoS?
@ -157,7 +157,9 @@ public final class NIOTSListenerBootstrap {
///
/// - parameters:
/// - initializer: A closure that initializes the provided `Channel`.
public func serverChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
@preconcurrency
public func serverChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self
{
self.serverChannelInit = initializer
return self
}
@ -170,7 +172,8 @@ public final class NIOTSListenerBootstrap {
///
/// - parameters:
/// - initializer: A closure that initializes the provided `Channel`.
public func childChannelInitializer(_ initializer: @escaping (Channel) -> EventLoopFuture<Void>) -> Self {
@preconcurrency
public func childChannelInitializer(_ initializer: @escaping @Sendable (Channel) -> EventLoopFuture<Void>) -> Self {
self.childChannelInit = initializer
return self
}
@ -318,10 +321,10 @@ public final class NIOTSListenerBootstrap {
private func bind0(
existingNWListener: NWListener? = nil,
shouldRegister: Bool,
_ binder: @escaping (NIOTSListenerChannel, EventLoopPromise<Void>) -> Void
_ binder: @escaping @Sendable (NIOTSListenerChannel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<Channel> {
let eventLoop = self.group.next() as! NIOTSEventLoop
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
let serverChannelInit = self.serverChannelInit ?? { @Sendable _ in eventLoop.makeSucceededFuture(()) }
let childChannelInit = self.childChannelInit
let serverChannelOptions = self.serverChannelOptions
let childChannelOptions = self.childChannelOptions
@ -352,17 +355,19 @@ public final class NIOTSListenerBootstrap {
)
}
return eventLoop.submit {
return eventLoop.submit { [bindTimeout] in
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
serverChannelInit(serverChannel)
}.flatMap {
eventLoop.assertInEventLoop()
return serverChannel.pipeline.addHandler(
AcceptHandler<NIOTSConnectionChannel>(
childChannelInitializer: childChannelInit,
childChannelOptions: childChannelOptions
return eventLoop.makeCompletedFuture {
try serverChannel.pipeline.syncOperations.addHandler(
AcceptHandler<NIOTSConnectionChannel>(
childChannelInitializer: childChannelInit,
childChannelOptions: childChannelOptions
)
)
)
}
}.flatMap {
if shouldRegister {
return serverChannel.register()
@ -373,7 +378,7 @@ public final class NIOTSListenerBootstrap {
let bindPromise = eventLoop.makePromise(of: Void.self)
binder(serverChannel, bindPromise)
if let bindTimeout = self.bindTimeout {
if let bindTimeout = bindTimeout {
let cancelTask = eventLoop.scheduleTask(in: bindTimeout) {
bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout))
serverChannel.close(promise: nil)
@ -537,7 +542,7 @@ extension NIOTSListenerBootstrap {
existingNWListener: NWListener? = nil,
serverBackPressureStrategy: NIOAsyncSequenceProducerBackPressureStrategies.HighLowWatermark?,
childChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture<ChannelInitializerResult>,
registration: @escaping (NIOTSListenerChannel, EventLoopPromise<Void>) -> Void
registration: @escaping @Sendable (NIOTSListenerChannel, EventLoopPromise<Void>) -> Void
) -> EventLoopFuture<NIOAsyncChannel<ChannelInitializerResult, Never>> {
let eventLoop = self.group.next() as! NIOTSEventLoop
let serverChannelInit = self.serverChannelInit ?? { _ in eventLoop.makeSucceededFuture(()) }
@ -571,7 +576,7 @@ extension NIOTSListenerBootstrap {
)
}
return eventLoop.submit {
return eventLoop.submit { [bindTimeout] in
serverChannelOptions.applyAllChannelOptions(to: serverChannel).flatMap {
serverChannelInit(serverChannel)
}.flatMap { (_) -> EventLoopFuture<NIOAsyncChannel<ChannelInitializerResult, Never>> in
@ -585,7 +590,7 @@ extension NIOTSListenerBootstrap {
)
let asyncChannel = try NIOAsyncChannel<ChannelInitializerResult, Never>
._wrapAsyncChannelWithTransformations(
synchronouslyWrapping: serverChannel,
wrappingChannelSynchronously: serverChannel,
backPressureStrategy: serverBackPressureStrategy,
channelReadTransformation: { channel -> EventLoopFuture<(ChannelInitializerResult)> in
// The channelReadTransformation is run on the EL of the server channel
@ -600,7 +605,7 @@ extension NIOTSListenerBootstrap {
let bindPromise = eventLoop.makePromise(of: Void.self)
registration(serverChannel, bindPromise)
if let bindTimeout = self.bindTimeout {
if let bindTimeout = bindTimeout {
let cancelTask = eventLoop.scheduleTask(in: bindTimeout) {
bindPromise.fail(NIOTSErrors.BindTimeout(timeout: bindTimeout))
serverChannel.close(promise: nil)
@ -627,4 +632,6 @@ extension NIOTSListenerBootstrap {
}
}
@available(*, unavailable)
extension NIOTSListenerBootstrap: Sendable {}
#endif

View File

@ -137,7 +137,7 @@ internal final class NIOTSListenerChannel: StateManagedListenerChannel<NIOTSConn
tlsOptions: self.childTLSOptions
)
self.pipeline.fireChannelRead(NIOAny(newChannel))
self.pipeline.fireChannelRead(newChannel)
self.pipeline.fireChannelReadComplete()
}

View File

@ -13,7 +13,7 @@
//===----------------------------------------------------------------------===//
#if canImport(Network)
@preconcurrency import Network
import Network
import NIOCore
/// A tag protocol that can be used to cover all network events emitted by `NIOTransportServices`.
@ -23,7 +23,7 @@ import NIOCore
public protocol NIOTSNetworkEvent: Equatable, _NIOPreconcurrencySendable {}
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
public enum NIOTSNetworkEvents {
public enum NIOTSNetworkEvents: Sendable {
/// ``BetterPathAvailable`` is fired whenever the OS has informed NIO that there is a better
/// path available to the endpoint that this `Channel` is currently connected to,
/// e.g. the current connection is using an expensive cellular connection and

View File

@ -533,4 +533,9 @@ extension StateManagedListenerChannel {
}
}
// We inherit from StateManagedListenerChannel in NIOTSDatagramListenerChannel, so we can't mark
// it as Sendable safely.
@available(OSX 10.14, iOS 12.0, tvOS 12.0, watchOS 6.0, *)
extension StateManagedListenerChannel: @unchecked Sendable {}
#endif

View File

@ -182,7 +182,7 @@ extension StateManagedNWConnectionChannel {
preconditionFailure("nwconnection cannot be nil while channel is active")
}
func completionCallback(promise: EventLoopPromise<Void>?, sentBytes: Int) -> ((NWError?) -> Void) {
func completionCallback(promise: EventLoopPromise<Void>?, sentBytes: Int) -> (@Sendable (NWError?) -> Void) {
{ error in
if let error = error {
promise?.fail(error)
@ -314,7 +314,7 @@ extension StateManagedNWConnectionChannel {
return
}
func completionCallback(for promise: EventLoopPromise<Void>?) -> ((NWError?) -> Void) {
func completionCallback(for promise: EventLoopPromise<Void>?) -> (@Sendable (NWError?) -> Void) {
{ error in
if let error = error {
promise?.fail(error)
@ -432,7 +432,7 @@ extension StateManagedNWConnectionChannel {
// APIs.
var buffer = self.allocator.buffer(capacity: content.count)
buffer.writeBytes(content)
self.pipeline.fireChannelRead(NIOAny(buffer))
self.pipeline.fireChannelRead(buffer)
self.pipeline.fireChannelReadComplete()
}

View File

@ -42,7 +42,7 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
func testEmptyWritePromise() {
let emptyWrite = self.allocator.buffer(capacity: 0)
let emptyWritePromise = self.eventLoop.makePromise(of: Void.self)
self.channel.write(NIOAny(emptyWrite), promise: emptyWritePromise)
self.channel.write(emptyWrite, promise: emptyWritePromise)
self.channel.flush()
XCTAssertNoThrow(
try emptyWritePromise.futureResult.wait()
@ -53,7 +53,7 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
}
func testEmptyWritesNoWriteThrough() {
class OutboundTestHandler: ChannelOutboundHandler {
final class OutboundTestHandler: ChannelOutboundHandler, Sendable {
typealias OutboundIn = ByteBuffer
typealias OutboundOut = ByteBuffer
@ -75,9 +75,9 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
let emptyWrite = self.allocator.buffer(capacity: 0)
let thenEmptyWrite = self.allocator.buffer(capacity: 0)
let thenEmptyWritePromise = self.eventLoop.makePromise(of: Void.self)
self.channel.write(NIOAny(emptyWrite), promise: nil)
self.channel.write(emptyWrite, promise: nil)
self.channel.write(
NIOAny(thenEmptyWrite),
thenEmptyWrite,
promise: thenEmptyWritePromise
)
self.channel.flush()
@ -98,20 +98,20 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
case thenEmptyWrite
}
var checkOrder = CheckOrder.noWrite
someWritePromise.futureResult.whenSuccess {
someWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .noWrite)
checkOrder = .someWrite
}
thenEmptyWritePromise.futureResult.whenSuccess {
thenEmptyWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .someWrite)
checkOrder = .thenEmptyWrite
}
self.channel.write(
NIOAny(someWrite),
someWrite,
promise: someWritePromise
)
self.channel.write(
NIOAny(thenEmptyWrite),
thenEmptyWrite,
promise: thenEmptyWritePromise
)
self.channel.flush()
@ -136,20 +136,20 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
case thenEmptyWrite
}
var checkOrder = CheckOrder.noWrite
emptyWritePromise.futureResult.whenSuccess {
emptyWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .noWrite)
checkOrder = .emptyWrite
}
thenEmptyWritePromise.futureResult.whenSuccess {
thenEmptyWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .emptyWrite)
checkOrder = .thenEmptyWrite
}
self.channel.write(
NIOAny(emptyWrite),
emptyWrite,
promise: emptyWritePromise
)
self.channel.write(
NIOAny(thenEmptyWrite),
thenEmptyWrite,
promise: thenEmptyWritePromise
)
self.channel.flush()
@ -174,21 +174,21 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
case thenEmptyWrite
}
var checkOrder = CheckOrder.noWrite
emptyWritePromise.futureResult.whenSuccess {
emptyWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .noWrite)
checkOrder = .emptyWrite
}
thenSomeWritePromise.futureResult.whenSuccess {
thenSomeWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .emptyWrite)
checkOrder = .thenSomeWrite
}
thenEmptyWritePromise.futureResult.whenSuccess {
thenEmptyWritePromise.futureResult.assumeIsolated().whenSuccess {
XCTAssertEqual(checkOrder, .thenSomeWrite)
checkOrder = .thenEmptyWrite
}
self.channel.write(NIOAny(emptyWrite), promise: emptyWritePromise)
self.channel.write(NIOAny(thenSomeWrite), promise: thenSomeWritePromise)
self.channel.write(NIOAny(thenEmptyWrite), promise: thenEmptyWritePromise)
self.channel.write(emptyWrite, promise: emptyWritePromise)
self.channel.write(thenSomeWrite, promise: thenSomeWritePromise)
self.channel.write(thenEmptyWrite, promise: thenEmptyWritePromise)
self.channel.flush()
XCTAssertNoThrow(try thenEmptyWritePromise.futureResult.wait())
XCTAssertNoThrow(
@ -205,9 +205,9 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
let thenEmptyWrite = self.allocator.buffer(capacity: 0)
let thenSomeWrite = self.allocator.bufferFor(string: "then some")
let thenSomeWritePromise = self.eventLoop.makePromise(of: Void.self)
self.channel.write(NIOAny(someWrite), promise: nil)
self.channel.write(NIOAny(thenEmptyWrite), promise: nil)
self.channel.write(NIOAny(thenSomeWrite), promise: thenSomeWritePromise)
self.channel.write(someWrite, promise: nil)
self.channel.write(thenEmptyWrite, promise: nil)
self.channel.write(thenSomeWrite, promise: thenSomeWritePromise)
self.channel.flush()
XCTAssertNoThrow(try thenSomeWritePromise.futureResult.wait())
var someWriteOutput: ByteBuffer?
@ -228,7 +228,7 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
func testSomeWriteAndFlushThenSomeWriteAndFlush() {
let someWrite = self.allocator.bufferFor(string: "non empty")
var someWritePromise: EventLoopPromise<Void>! = self.eventLoop.makePromise()
self.channel.write(NIOAny(someWrite), promise: someWritePromise)
self.channel.write(someWrite, promise: someWritePromise)
self.channel.flush()
XCTAssertNoThrow(
try someWritePromise.futureResult.wait()
@ -239,7 +239,7 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
someWritePromise = nil
let thenSomeWrite = self.allocator.bufferFor(string: "then some")
var thenSomeWritePromise: EventLoopPromise<Void>! = self.eventLoop.makePromise()
self.channel.write(NIOAny(thenSomeWrite), promise: thenSomeWritePromise)
self.channel.write(thenSomeWrite, promise: thenSomeWritePromise)
self.channel.flush()
XCTAssertNoThrow(
try thenSomeWritePromise.futureResult.wait()
@ -253,7 +253,7 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
func testEmptyWriteAndFlushThenEmptyWriteAndFlush() {
let emptyWrite = self.allocator.buffer(capacity: 0)
var emptyWritePromise: EventLoopPromise<Void>! = self.eventLoop.makePromise()
self.channel.write(NIOAny(emptyWrite), promise: emptyWritePromise)
self.channel.write(emptyWrite, promise: emptyWritePromise)
self.channel.flush()
XCTAssertNoThrow(
try emptyWritePromise.futureResult.wait()
@ -264,7 +264,7 @@ class NIOFilterEmptyWritesHandlerTests: XCTestCase {
emptyWritePromise = nil
let thenEmptyWrite = self.allocator.buffer(capacity: 0)
var thenEmptyWritePromise: EventLoopPromise<Void>! = self.eventLoop.makePromise()
self.channel.write(NIOAny(thenEmptyWrite), promise: thenEmptyWritePromise)
self.channel.write(thenEmptyWrite, promise: thenEmptyWritePromise)
self.channel.flush()
XCTAssertNoThrow(
try thenEmptyWritePromise.futureResult.wait()

View File

@ -85,12 +85,12 @@ private final class TLSUserEventHandler: ChannelInboundHandler, RemovableChannel
let alpn = String(string.dropFirst(15))
context.writeAndFlush(.init(ByteBuffer(string: "alpn:\(alpn)")), promise: nil)
context.fireUserInboundEventTriggered(TLSUserEvent.handshakeCompleted(negotiatedProtocol: alpn))
context.pipeline.removeHandler(self, promise: nil)
context.pipeline.syncOperations.removeHandler(self, promise: nil)
} else if string.hasPrefix("alpn:") {
context.fireUserInboundEventTriggered(
TLSUserEvent.handshakeCompleted(negotiatedProtocol: String(string.dropFirst(5)))
)
context.pipeline.removeHandler(self, promise: nil)
context.pipeline.syncOperations.removeHandler(self, promise: nil)
} else {
context.fireChannelRead(data)
}
@ -182,7 +182,9 @@ final class AsyncChannelBootstrapTests: XCTestCase {
func testServerClientBootstrap_withAsyncChannel_andHostPort() async throws {
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
Task {
try! await eventLoopGroup.shutdownGracefully()
}
}
let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup)
@ -240,7 +242,9 @@ final class AsyncChannelBootstrapTests: XCTestCase {
func testAsyncChannelProtocolNegotiation() async throws {
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
Task {
try! await eventLoopGroup.shutdownGracefully()
}
}
let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup)
@ -251,7 +255,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
try Self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
}
}
@ -323,7 +327,9 @@ final class AsyncChannelBootstrapTests: XCTestCase {
func testAsyncChannelNestedProtocolNegotiation() async throws {
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
Task {
try! await eventLoopGroup.shutdownGracefully()
}
}
let channel = try await NIOTSListenerBootstrap(group: eventLoopGroup)
@ -334,7 +340,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureNestedProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
try Self.configureNestedProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
}
}
@ -459,7 +465,9 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
let eventLoopGroup = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! eventLoopGroup.syncShutdownGracefully()
Task {
try! await eventLoopGroup.shutdownGracefully()
}
}
let channels = NIOLockedValueBox<[Channel]>([Channel]())
@ -478,7 +486,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
port: 0
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
try Self.configureProtocolNegotiationHandlers(channel: channel).protocolNegotiationResult
}
}
@ -585,7 +593,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN)
try Self.configureProtocolNegotiationHandlers(channel: channel, proposedALPN: proposedALPN)
.protocolNegotiationResult
}
}
@ -602,7 +610,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
to: .init(ipAddress: "127.0.0.1", port: port)
) { channel in
channel.eventLoop.makeCompletedFuture {
try self.configureNestedProtocolNegotiationHandlers(
try Self.configureNestedProtocolNegotiationHandlers(
channel: channel,
proposedOuterALPN: proposedOuterALPN,
proposedInnerALPN: proposedInnerALPN
@ -612,18 +620,18 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
@discardableResult
private func configureProtocolNegotiationHandlers(
private static func configureProtocolNegotiationHandlers(
channel: Channel,
proposedALPN: TLSUserEventHandler.ALPN? = nil
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
try channel.pipeline.syncOperations.addHandler(ByteToMessageHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(MessageToByteHandler(LineDelimiterCoder()))
try channel.pipeline.syncOperations.addHandler(TLSUserEventHandler(proposedALPN: proposedALPN))
return try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return try Self.addTypedApplicationProtocolNegotiationHandler(to: channel)
}
@discardableResult
private func configureNestedProtocolNegotiationHandlers(
private static func configureNestedProtocolNegotiationHandlers(
channel: Channel,
proposedOuterALPN: TLSUserEventHandler.ALPN? = nil,
proposedInnerALPN: TLSUserEventHandler.ALPN? = nil
@ -642,7 +650,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
try channel.pipeline.syncOperations.addHandler(
TLSUserEventHandler(proposedALPN: proposedInnerALPN)
)
let negotiationFuture = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
let negotiationFuture = try Self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return negotiationFuture.protocolNegotiationResult
}
@ -651,7 +659,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
try channel.pipeline.syncOperations.addHandler(
TLSUserEventHandler(proposedALPN: proposedInnerALPN)
)
let negotiationHandler = try self.addTypedApplicationProtocolNegotiationHandler(to: channel)
let negotiationHandler = try Self.addTypedApplicationProtocolNegotiationHandler(to: channel)
return negotiationHandler.protocolNegotiationResult
}
@ -667,7 +675,7 @@ final class AsyncChannelBootstrapTests: XCTestCase {
}
@discardableResult
private func addTypedApplicationProtocolNegotiationHandler(
private static func addTypedApplicationProtocolNegotiationHandler(
to channel: Channel
) throws -> NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {
let negotiationHandler = NIOTypedApplicationProtocolNegotiationHandler<NegotiationResult> {

View File

@ -24,58 +24,41 @@ import Foundation
@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6, *)
final class NIOTSBootstrapTests: XCTestCase {
var groupBag: [NIOTSEventLoopGroup]? = nil // protected by `self.lock`
let lock = NIOLock()
override func setUp() {
self.lock.withLock {
XCTAssertNil(self.groupBag)
self.groupBag = []
}
}
override func tearDown() {
XCTAssertNoThrow(
try self.lock.withLock {
guard let groupBag = self.groupBag else {
XCTFail()
return
}
for group in groupBag {
func testBootstrapsTolerateFuturesFromDifferentEventLoopsReturnedInInitializers() throws {
let groupBag: NIOLockedValueBox<[NIOTSEventLoopGroup]> = .init([])
defer {
try! groupBag.withLockedValue {
for group in $0 {
XCTAssertNoThrow(try group.syncShutdownGracefully())
}
self.groupBag = nil
}
)
}
func freshEventLoop() -> EventLoop {
let group: NIOTSEventLoopGroup = .init(loopCount: 1, defaultQoS: .default)
self.lock.withLock {
self.groupBag!.append(group)
}
return group.next()
}
func testBootstrapsTolerateFuturesFromDifferentEventLoopsReturnedInInitializers() throws {
let childChannelDone = self.freshEventLoop().makePromise(of: Void.self)
let serverChannelDone = self.freshEventLoop().makePromise(of: Void.self)
@Sendable func freshEventLoop() -> EventLoop {
let group: NIOTSEventLoopGroup = .init(loopCount: 1, defaultQoS: .default)
groupBag.withLockedValue {
$0.append(group)
}
return group.next()
}
let childChannelDone = freshEventLoop().makePromise(of: Void.self)
let serverChannelDone = freshEventLoop().makePromise(of: Void.self)
let serverChannel = try assertNoThrowWithValue(
NIOTSListenerBootstrap(group: self.freshEventLoop())
NIOTSListenerBootstrap(group: freshEventLoop())
.childChannelInitializer { channel in
channel.eventLoop.preconditionInEventLoop()
defer {
childChannelDone.succeed(())
}
return self.freshEventLoop().makeSucceededFuture(())
return freshEventLoop().makeSucceededFuture(())
}
.serverChannelInitializer { channel in
channel.eventLoop.preconditionInEventLoop()
defer {
serverChannelDone.succeed(())
}
return self.freshEventLoop().makeSucceededFuture(())
return freshEventLoop().makeSucceededFuture(())
}
.bind(host: "127.0.0.1", port: 0)
.wait()
@ -85,10 +68,10 @@ final class NIOTSBootstrapTests: XCTestCase {
}
let client = try assertNoThrowWithValue(
NIOTSConnectionBootstrap(group: self.freshEventLoop())
NIOTSConnectionBootstrap(group: freshEventLoop())
.channelInitializer { channel in
channel.eventLoop.preconditionInEventLoop()
return self.freshEventLoop().makeSucceededFuture(())
return freshEventLoop().makeSucceededFuture(())
}
.connect(to: serverChannel.localAddress!)
.wait()
@ -140,7 +123,9 @@ final class NIOTSBootstrapTests: XCTestCase {
return try NIOTSListenerBootstrap(group: group)
.childChannelInitializer { channel in
XCTAssertEqual(0, numberOfConnections.loadThenWrappingIncrement(ordering: .relaxed))
return channel.pipeline.addHandler(TellMeIfConnectionIsTLSHandler(isTLS: isTLS))
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(TellMeIfConnectionIsTLSHandler(isTLS: isTLS))
}
}
.bind(host: "127.0.0.1", port: 0)
.wait()
@ -175,8 +160,7 @@ final class NIOTSBootstrapTests: XCTestCase {
)
.enableTLS()
var buffer = server1.allocator.buffer(capacity: 2)
buffer.writeString("NO")
let buffer = server1.allocator.buffer(string: "NO")
var maybeClient1: Channel? = nil
XCTAssertNoThrow(maybeClient1 = try bootstrap.connect(to: server1.localAddress!).wait())

View File

@ -19,6 +19,7 @@ import NIOCore
import NIOFoundationCompat
import NIOTransportServices
import Foundation
import NIOConcurrencyHelpers
@available(macOS 10.14, iOS 12.0, tvOS 12.0, watchOS 6, *)
final class ConnectRecordingHandler: ChannelOutboundHandler {
@ -73,13 +74,14 @@ final class DisableWaitingAfterConnect: ChannelOutboundHandler {
typealias OutboundOut = Any
func connect(context: ChannelHandlerContext, to address: SocketAddress, promise: EventLoopPromise<Void>?) {
do {
try context.channel.syncOptions?.setOption(NIOTSChannelOptions.waitForActivity, value: false)
} catch {
promise?.fail(error)
return
}
let f = context.channel.setOption(NIOTSChannelOptions.waitForActivity, value: false).flatMap {
context.connect(to: address)
}
if let promise = promise {
f.cascade(to: promise)
}
context.connect(to: address).cascade(to: promise)
}
}
@ -112,7 +114,7 @@ final class PromiseOnActiveHandler: ChannelInboundHandler {
}
@available(OSX 10.14, iOS 12.0, tvOS 12.0, *)
final class EventWaiter<Event>: ChannelInboundHandler {
final class EventWaiter<Event: Sendable>: ChannelInboundHandler {
typealias InboundIn = Any
typealias InboundOut = Any
@ -145,7 +147,6 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
func testConnectingToSocketAddressTraversesPipeline() throws {
let connectRecordingHandler = ConnectRecordingHandler()
let listener = try NIOTSListenerBootstrap(group: self.group)
.bind(host: "localhost", port: 0).wait()
defer {
@ -153,9 +154,14 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
let connectBootstrap = NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in channel.pipeline.addHandler(connectRecordingHandler) }
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [])
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let handler = ConnectRecordingHandler()
try channel.pipeline.syncOperations.addHandler(handler)
XCTAssertEqual(handler.connectTargets, [])
XCTAssertEqual(handler.endpointTargets, [])
}
}
let connection = try connectBootstrap.connect(to: listener.localAddress!).wait()
defer {
@ -163,13 +169,13 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
try connection.eventLoop.submit {
XCTAssertEqual(connectRecordingHandler.connectTargets, [listener.localAddress!])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [])
let handler = try connection.pipeline.syncOperations.handler(type: ConnectRecordingHandler.self)
XCTAssertEqual(handler.connectTargets, [listener.localAddress!])
XCTAssertEqual(handler.endpointTargets, [])
}.wait()
}
func testConnectingToHostPortSkipsPipeline() throws {
let connectRecordingHandler = ConnectRecordingHandler()
let listener = try NIOTSListenerBootstrap(group: self.group)
.bind(host: "localhost", port: 0).wait()
defer {
@ -177,9 +183,14 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
let connectBootstrap = NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in channel.pipeline.addHandler(connectRecordingHandler) }
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [])
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let connectRecordingHandler = ConnectRecordingHandler()
try channel.pipeline.syncOperations.addHandler(connectRecordingHandler)
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [])
}
}
let connection = try connectBootstrap.connect(host: "localhost", port: Int(listener.localAddress!.port!)).wait()
defer {
@ -187,6 +198,9 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
try connection.eventLoop.submit {
let connectRecordingHandler = try connection.pipeline.syncOperations.handler(
type: ConnectRecordingHandler.self
)
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(
connectRecordingHandler.endpointTargets,
@ -201,7 +215,6 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
func testConnectingToEndpointSkipsPipeline() throws {
let connectRecordingHandler = ConnectRecordingHandler()
let listener = try NIOTSListenerBootstrap(group: self.group)
.bind(host: "localhost", port: 0).wait()
defer {
@ -209,9 +222,14 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
let connectBootstrap = NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in channel.pipeline.addHandler(connectRecordingHandler) }
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [])
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let connectRecordingHandler = ConnectRecordingHandler()
try channel.pipeline.syncOperations.addHandler(connectRecordingHandler)
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [])
}
}
let target = NWEndpoint.hostPort(
host: "localhost",
@ -224,6 +242,9 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
try connection.eventLoop.submit {
let connectRecordingHandler = try connection.pipeline.syncOperations.handler(
type: ConnectRecordingHandler.self
)
XCTAssertEqual(connectRecordingHandler.connectTargets, [])
XCTAssertEqual(connectRecordingHandler.endpointTargets, [target])
}.wait()
@ -231,7 +252,11 @@ class NIOTSConnectionChannelTests: XCTestCase {
func testZeroLengthWritesHaveSatisfiedPromises() throws {
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(FailOnReadHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(FailOnReadHandler())
}
}
.bind(host: "localhost", port: 0).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -320,13 +345,16 @@ class NIOTSConnectionChannelTests: XCTestCase {
XCTAssertNoThrow(try listener.close().wait())
}
var writabilities = [Bool]()
let handler = WritabilityChangedHandler { newValue in
writabilities.append(newValue)
}
let writabilities: NIOLockedValueBox<[Bool]> = .init([])
let connection = try NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in channel.pipeline.addHandler(handler) }
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let handler = WritabilityChangedHandler { newValue in
writabilities.withLockedValue { $0.append(newValue) }
}
try channel.pipeline.syncOperations.addHandler(handler)
}
}
.connect(to: listener.localAddress!)
.wait()
@ -337,8 +365,6 @@ class NIOTSConnectionChannelTests: XCTestCase {
value: ChannelOptions.Types.WriteBufferWaterMark(low: 2, high: 2048)
).wait()
)
var buffer = connection.allocator.buffer(capacity: 2048)
buffer.writeBytes(repeatElement(UInt8(4), count: 2048))
// We're going to issue the following pattern of writes:
// a: 1 byte
@ -355,58 +381,61 @@ class NIOTSConnectionChannelTests: XCTestCase {
// until after the promise for d has fired: by the time the promise for e has fired it will be writable
// again.
try connection.eventLoop.submit {
var buffer = connection.allocator.buffer(capacity: 2048)
buffer.writeBytes(repeatElement(UInt8(4), count: 2048))
// Pre writing.
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
// Write a. After this write, we are still writable. When this write
// succeeds, we'll still be not writable.
connection.write(buffer.getSlice(at: 0, length: 1)).whenComplete { (_: Result<Void, Error>) in
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
}
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
// Write b. After this write we are still writable. When this write
// succeeds we'll still be not writable.
connection.write(buffer.getSlice(at: 0, length: 1)).whenComplete { (_: Result<Void, Error>) in
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
}
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
// Write c. After this write we are still writable (2047 bytes written).
// When this write succeeds we'll still be not writable (2 bytes outstanding).
connection.write(buffer.getSlice(at: 0, length: 2045)).whenComplete { (_: Result<Void, Error>) in
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
}
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
// Write d. After this write we are still writable (2048 bytes written).
// When this write succeeds we'll become writable, but critically the promise fires before
// the state change, so we'll *appear* to be unwritable.
connection.write(buffer.getSlice(at: 0, length: 1)).whenComplete { (_: Result<Void, Error>) in
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
}
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
// Write e. After this write we are now not writable (2049 bytes written).
// When this write succeeds we'll have already been writable, thanks to the previous
// write.
connection.write(buffer.getSlice(at: 0, length: 1)).whenComplete { (_: Result<Void, Error>) in
XCTAssertEqual(writabilities, [false, true])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false, true])
XCTAssertTrue(connection.isWritable)
// We close after this succeeds.
connection.close(promise: nil)
}
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
}.wait()
@ -415,7 +444,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
XCTAssertNoThrow(try connection.closeFuture.wait())
// Ok, check that the writability changes worked.
XCTAssertEqual(writabilities, [false, true])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false, true])
}
func testWritabilityChangesAfterChangingWatermarks() throws {
@ -425,23 +454,22 @@ class NIOTSConnectionChannelTests: XCTestCase {
XCTAssertNoThrow(try listener.close().wait())
}
var writabilities = [Bool]()
let handler = WritabilityChangedHandler { newValue in
writabilities.append(newValue)
}
let writabilities: NIOLockedValueBox<[Bool]> = .init([])
let connection = try NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in channel.pipeline.addHandler(handler) }
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let handler = WritabilityChangedHandler { newValue in
writabilities.withLockedValue({ $0.append(newValue) })
}
try channel.pipeline.syncOperations.addHandler(handler)
}
}
.connect(to: listener.localAddress!)
.wait()
defer {
XCTAssertNoThrow(try connection.close().wait())
}
// We're going to allocate a buffer.
var buffer = connection.allocator.buffer(capacity: 256)
buffer.writeBytes(repeatElement(UInt8(4), count: 256))
// We're going to issue a 256-byte write. This write will not cause any change in channel writability
// state.
//
@ -462,13 +490,17 @@ class NIOTSConnectionChannelTests: XCTestCase {
//
// Then we're going to set the high watermark to 1024, and the low to 256. This will change nothing.
try connection.eventLoop.submit {
// We're going to allocate a buffer.
var buffer = connection.allocator.buffer(capacity: 256)
buffer.writeBytes(repeatElement(UInt8(4), count: 256))
// Pre changes.
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
// Write. No writability change.
connection.write(buffer, promise: nil)
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
}.wait()
@ -477,7 +509,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
value: ChannelOptions.Types.WriteBufferWaterMark(low: 128, high: 256)
).flatMap {
// High to 256, low to 128. No writability change.
XCTAssertEqual(writabilities, [])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [])
XCTAssertTrue(connection.isWritable)
return connection.setOption(
@ -486,7 +518,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
)
}.flatMap {
// High to 255, low to 127. Channel becomes not writable.
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
return connection.setOption(
@ -495,7 +527,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
)
}.flatMap {
// High back to 256, low to 128. No writability change.
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
return connection.setOption(
@ -504,7 +536,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
)
}.flatMap {
// High to 1024, low to 128. No writability change.
XCTAssertEqual(writabilities, [false])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false])
XCTAssertFalse(connection.isWritable)
return connection.setOption(
@ -513,7 +545,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
)
}.flatMap {
// Low to 257, channel becomes writable again.
XCTAssertEqual(writabilities, [false, true])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false, true])
XCTAssertTrue(connection.isWritable)
return connection.setOption(
@ -522,7 +554,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
)
}.map {
// Low back to 256, no writability change.
XCTAssertEqual(writabilities, [false, true])
XCTAssertEqual(writabilities.withLockedValue({ $0 }), [false, true])
XCTAssertTrue(connection.isWritable)
}.wait()
}
@ -608,7 +640,9 @@ class NIOTSConnectionChannelTests: XCTestCase {
func testEarlyExitCanBeSetInWaitingState() throws {
let connectFuture = NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in
channel.pipeline.addHandler(DisableWaitingAfterConnect())
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(DisableWaitingAfterConnect())
}
}.connect(to: try SocketAddress(unixDomainSocketPath: "/this/path/definitely/doesnt/exist"))
do {
@ -712,7 +746,11 @@ class NIOTSConnectionChannelTests: XCTestCase {
let channel = try NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in
channel.pipeline.addHandler(PromiseOnActiveHandler(activePromise))
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
PromiseOnActiveHandler(activePromise)
)
}
}.connect(to: listener.localAddress!).wait()
XCTAssertNoThrow(try activePromise.futureResult.wait())
@ -772,10 +810,13 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
let testCompletePromise = self.group.next().makePromise(of: Void.self)
let testHandler = TestHandler(testCompletePromise: testCompletePromise)
let listener = try assertNoThrowWithValue(
NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(EchoHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
}
}
.bind(host: "localhost", port: 0).wait()
)
defer {
@ -783,7 +824,13 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
let connectBootstrap = NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in channel.pipeline.addHandler(testHandler) }
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
TestHandler(testCompletePromise: testCompletePromise)
)
}
}
let connection = try assertNoThrowWithValue(connectBootstrap.connect(to: listener.localAddress!).wait())
defer {
@ -804,7 +851,8 @@ class NIOTSConnectionChannelTests: XCTestCase {
// We expect 2.
XCTAssertNoThrow(
try connection.eventLoop.submit {
XCTAssertEqual(testHandler.readCount, 2)
let handler = try connection.pipeline.syncOperations.handler(type: TestHandler.self)
XCTAssertEqual(handler.readCount, 2)
}.wait()
)
}
@ -839,11 +887,15 @@ class NIOTSConnectionChannelTests: XCTestCase {
func testConnectingInvolvesWaiting() throws {
let loop = self.group.next()
let eventPromise = loop.makePromise(of: NIOTSNetworkEvents.WaitingForConnectivity.self)
let eventRecordingHandler = EventWaiter<NIOTSNetworkEvents.WaitingForConnectivity>(eventPromise)
// 5s is the worst-case test time: normally it'll be faster as we don't wait for this.
let connectBootstrap = NIOTSConnectionBootstrap(group: loop)
.channelInitializer { channel in channel.pipeline.addHandler(eventRecordingHandler) }
.channelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let eventRecordingHandler = EventWaiter<NIOTSNetworkEvents.WaitingForConnectivity>(eventPromise)
try channel.pipeline.syncOperations.addHandler(eventRecordingHandler)
}
}
.connectTimeout(.seconds(5))
// We choose 443 here to avoid triggering Private Relay, which can do all kinds of weird stuff to this test.
@ -868,7 +920,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
}
func testSyncOptionsAreSupported() throws {
func testSyncOptions(_ channel: Channel) {
@Sendable func testSyncOptions(_ channel: Channel) {
if let sync = channel.syncOptions {
do {
let autoRead = try sync.getOption(ChannelOptions.autoRead)
@ -919,6 +971,7 @@ class NIOTSConnectionChannelTests: XCTestCase {
func channelActive(context: ChannelHandlerContext) {
listenerChannel
.close()
.assumeIsolated()
.whenSuccess { _ in
_ = context.channel.write(ByteBuffer(data: Data()))
}
@ -943,12 +996,14 @@ class NIOTSConnectionChannelTests: XCTestCase {
let testCompletePromise = self.group.next().makePromise(of: Error.self)
let connection = try NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in
channel.pipeline.addHandler(
ForwardErrorHandler(
testCompletePromise: testCompletePromise,
listenerChannel: listener
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
ForwardErrorHandler(
testCompletePromise: testCompletePromise,
listenerChannel: listener
)
)
)
}
}
.connect(to: listener.localAddress!)
.wait()

View File

@ -20,7 +20,7 @@ import NIOTransportServices
import Foundation
extension Channel {
func wait<T>(for type: T.Type, count: Int) throws -> [T] {
func wait<T: Sendable>(for type: T.Type, count: Int) throws -> [T] {
try self.pipeline.context(name: "ByteReadRecorder").flatMap { context in
if let future = (context.handler as? ReadRecorder<T>)?.notifyForDatagrams(count) {
return future
@ -53,7 +53,7 @@ extension Channel {
}
}
final class ReadRecorder<DataType>: ChannelInboundHandler {
final class ReadRecorder<DataType: Sendable>: ChannelInboundHandler {
typealias InboundIn = DataType
typealias InboundOut = DataType
@ -120,22 +120,35 @@ final class NIOTSDatagramConnectionChannelTests: XCTestCase {
group: NIOTSEventLoopGroup,
host: String = "127.0.0.1",
port: Int = 0,
onConnect: @escaping (Channel) -> Void
onConnect: @escaping @Sendable (Channel) -> Void
) throws -> Channel {
try NIOTSDatagramListenerBootstrap(group: group)
.childChannelInitializer { childChannel in
onConnect(childChannel)
return childChannel.pipeline.addHandler(ReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
return childChannel.eventLoop.makeCompletedFuture {
try childChannel.pipeline.syncOperations.addHandler(
ReadRecorder<ByteBuffer>(),
name: "ByteReadRecorder"
)
}
}
.bind(host: host, port: port)
.wait()
}
private func buildClientChannel(group: NIOTSEventLoopGroup, host: String = "127.0.0.1", port: Int) throws -> Channel
{
private func buildClientChannel(
group: NIOTSEventLoopGroup,
host: String = "127.0.0.1",
port: Int
) throws -> Channel {
try NIOTSDatagramBootstrap(group: group)
.channelInitializer { channel in
channel.pipeline.addHandler(ReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
ReadRecorder<ByteBuffer>(),
name: "ByteReadRecorder"
)
}
}
.connect(host: host, port: port)
.wait()
@ -169,7 +182,7 @@ final class NIOTSDatagramConnectionChannelTests: XCTestCase {
}
func testSyncOptionsAreSupported() throws {
func testSyncOptions(_ channel: Channel) {
@Sendable func testSyncOptions(_ channel: Channel) {
if let sync = channel.syncOptions {
do {
let endpointReuse = try sync.getOption(NIOTSChannelOptions.allowLocalEndpointReuse)
@ -192,7 +205,12 @@ final class NIOTSDatagramConnectionChannelTests: XCTestCase {
.childChannelInitializer { channel in
testSyncOptions(channel)
promise.succeed(channel)
return channel.pipeline.addHandler(ReadRecorder<ByteBuffer>(), name: "ByteReadRecorder")
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
ReadRecorder<ByteBuffer>(),
name: "ByteReadRecorder"
)
}
}
.bind(host: "localhost", port: 0)
.wait()

View File

@ -18,6 +18,7 @@ import NIOCore
import NIOTransportServices
import Foundation
import Network
import NIOConcurrencyHelpers
func assertNoThrowWithValue<T>(
_ body: @autoclosure () throws -> T,
@ -56,44 +57,29 @@ final class ReadExpecter: ChannelInboundHandler {
struct DidNotReadError: Error {}
private var readPromise: EventLoopPromise<Void>?
private let readPromise: EventLoopPromise<Void>
private var cumulationBuffer: ByteBuffer?
private let expectedRead: ByteBuffer
var readFuture: EventLoopFuture<Void>? {
self.readPromise?.futureResult
}
init(expecting: ByteBuffer) {
init(expecting: ByteBuffer, readPromise: EventLoopPromise<Void>) {
self.readPromise = readPromise
self.cumulationBuffer = nil
self.expectedRead = expecting
}
func handlerAdded(context: ChannelHandlerContext) {
self.readPromise = context.eventLoop.makePromise()
}
func handlerRemoved(context: ChannelHandlerContext) {
if let promise = self.readPromise {
promise.fail(DidNotReadError())
}
self.readPromise.fail(DidNotReadError())
}
func channelRead(context: ChannelHandlerContext, data: NIOAny) {
var bytes = self.unwrapInboundIn(data)
if self.cumulationBuffer == nil {
self.cumulationBuffer = bytes
} else {
self.cumulationBuffer!.writeBuffer(&bytes)
}
self.cumulationBuffer.setOrWriteBuffer(&bytes)
self.maybeFulfillPromise()
}
private func maybeFulfillPromise() {
if let promise = self.readPromise, self.cumulationBuffer! == self.expectedRead {
promise.succeed(())
self.readPromise = nil
}
guard self.cumulationBuffer == self.expectedRead else { return }
self.readPromise.succeed(())
}
}
@ -180,9 +166,12 @@ final class WaitForActiveHandler: ChannelInboundHandler {
extension Channel {
/// Expect that the given bytes will be received.
func expectRead(_ bytes: ByteBuffer) -> EventLoopFuture<Void> {
let expecter = ReadExpecter(expecting: bytes)
return self.pipeline.addHandler(expecter).flatMap {
expecter.readFuture!
let readPromise = self.eventLoop.makePromise(of: Void.self)
return self.eventLoop.submit {
let expecter = ReadExpecter(expecting: bytes, readPromise: readPromise)
try self.pipeline.syncOperations.addHandler(expecter)
}.flatMap {
readPromise.futureResult
}
}
}
@ -209,7 +198,11 @@ class NIOTSEndToEndTests: XCTestCase {
func testSimpleListener() throws {
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(EchoHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
}
}
.bind(host: "localhost", port: 0).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -232,7 +225,11 @@ class NIOTSEndToEndTests: XCTestCase {
on: NWEndpoint.Port(rawValue: 0)!
)
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(EchoHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
}
}
.withNWListener(nwListenerTest).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -259,7 +256,11 @@ class NIOTSEndToEndTests: XCTestCase {
func testMultipleConnectionsOneListener() throws {
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(EchoHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
}
}
.bind(host: "localhost", port: 0).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -282,7 +283,11 @@ class NIOTSEndToEndTests: XCTestCase {
func testBasicConnectionTeardown() throws {
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(CloseOnActiveHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(CloseOnActiveHandler())
}
}
.bind(host: "localhost", port: 0).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -304,12 +309,12 @@ class NIOTSEndToEndTests: XCTestCase {
// This test is a little bit dicey, but we need 20 futures in this list.
let closeFutureSyncQueue = DispatchQueue(label: "closeFutureSyncQueue")
let closeFutureGroup = DispatchGroup()
var closeFutures: [EventLoopFuture<Void>] = []
let closeFutures: NIOLockedValueBox<[EventLoopFuture<Void>]> = .init([])
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in
closeFutureSyncQueue.sync {
closeFutures.append(channel.closeFuture)
closeFutures.withLockedValue { $0.append(channel.closeFuture) }
}
closeFutureGroup.leave()
return channel.eventLoop.makeSucceededFuture(())
@ -320,7 +325,9 @@ class NIOTSEndToEndTests: XCTestCase {
}
let bootstrap = NIOTSConnectionBootstrap(group: self.group).channelInitializer { channel in
channel.pipeline.addHandler(CloseOnActiveHandler())
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(CloseOnActiveHandler())
}
}
for _ in (0..<10) {
@ -330,7 +337,7 @@ class NIOTSEndToEndTests: XCTestCase {
closeFutureGroup.enter()
bootstrap.connect(to: listener.localAddress!).whenSuccess { channel in
closeFutureSyncQueue.sync {
closeFutures.append(channel.closeFuture)
closeFutures.withLockedValue { $0.append(channel.closeFuture) }
}
closeFutureGroup.leave()
}
@ -338,7 +345,9 @@ class NIOTSEndToEndTests: XCTestCase {
closeFutureGroup.wait()
let allClosed = closeFutureSyncQueue.sync {
EventLoopFuture<Void>.andAllComplete(closeFutures, on: self.group.next())
closeFutures.withLockedValue {
EventLoopFuture<Void>.andAllComplete($0, on: self.group.next())
}
}
XCTAssertNoThrow(try allClosed.wait())
}
@ -347,10 +356,12 @@ class NIOTSEndToEndTests: XCTestCase {
let serverSideConnectionPromise: EventLoopPromise<Channel> = self.group.next().makePromise()
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in
channel.pipeline.addHandlers([
WaitForActiveHandler(serverSideConnectionPromise),
EchoHandler(),
])
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandlers([
WaitForActiveHandler(serverSideConnectionPromise),
EchoHandler(),
])
}
}
.bind(host: "localhost", port: 0).wait()
defer {
@ -373,8 +384,9 @@ class NIOTSEndToEndTests: XCTestCase {
let halfClosedPromise: EventLoopPromise<Void> = self.group.next().makePromise()
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in
channel.pipeline.addHandler(EchoHandler()).flatMap { _ in
channel.pipeline.addHandler(HalfCloseHandler(halfClosedPromise))
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
try channel.pipeline.syncOperations.addHandler(HalfCloseHandler(halfClosedPromise))
}
}
.childChannelOption(ChannelOptions.allowRemoteHalfClosure, value: true)
@ -403,8 +415,9 @@ class NIOTSEndToEndTests: XCTestCase {
func testDisabledHalfClosureCausesFullClosure() throws {
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in
channel.pipeline.addHandler(EchoHandler()).flatMap { _ in
channel.pipeline.addHandler(FailOnHalfCloseHandler())
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
try channel.pipeline.syncOperations.addHandler(FailOnHalfCloseHandler())
}
}
.bind(host: "localhost", port: 0).wait()
@ -483,7 +496,11 @@ class NIOTSEndToEndTests: XCTestCase {
let udsPath = "/tmp/\(UUID().uuidString)_testBasicUnixSockets.sock"
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(EchoHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
}
}
.bind(unixDomainSocketPath: udsPath).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -513,7 +530,11 @@ class NIOTSEndToEndTests: XCTestCase {
let serviceEndpoint = NWEndpoint.service(name: name, type: "_niots._tcp", domain: "local", interface: nil)
let listener = try NIOTSListenerBootstrap(group: self.group)
.childChannelInitializer { channel in channel.pipeline.addHandler(EchoHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(EchoHandler())
}
}
.bind(endpoint: serviceEndpoint).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -541,7 +562,11 @@ class NIOTSEndToEndTests: XCTestCase {
let listener = try NIOTSListenerBootstrap(group: self.group)
.serverChannelOption(ChannelOptions.socket(SOL_SOCKET, SO_REUSEADDR), value: 0)
.serverChannelOption(ChannelOptions.socket(SOL_SOCKET, SO_REUSEPORT), value: 0)
.childChannelInitializer { channel in channel.pipeline.addHandler(CloseOnActiveHandler()) }
.childChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(CloseOnActiveHandler())
}
}
.bind(host: "localhost", port: 0).wait()
let address = listener.localAddress!
@ -587,11 +612,11 @@ class NIOTSEndToEndTests: XCTestCase {
let testCompletePromise = self.group.next().makePromise(of: Bool.self)
let connection = try NIOTSConnectionBootstrap(group: self.group)
.channelInitializer { channel in
channel.pipeline.addHandler(
ViabilityHandler(
testCompletePromise: testCompletePromise
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
ViabilityHandler(testCompletePromise: testCompletePromise)
)
)
}
}
.connect(to: listener.localAddress!)
.wait()

View File

@ -147,7 +147,9 @@ class NIOTSEventLoopTest: XCTestCase {
func testIndividualLoopsCannotBeShutDownWhenPartOfGroup() async throws {
let group = NIOTSEventLoopGroup(loopCount: 3)
defer {
try! group.syncShutdownGracefully()
Task {
try! await group.shutdownGracefully()
}
}
for loop in group.makeIterator() {

View File

@ -55,13 +55,18 @@ class NIOTSListenerChannelTests: XCTestCase {
}
func testBindingToSocketAddressTraversesPipeline() throws {
let bindRecordingHandler = BindRecordingHandler()
let target = try SocketAddress.makeAddressResolvingHost("localhost", port: 0)
let bindBootstrap = NIOTSListenerBootstrap(group: self.group)
.serverChannelInitializer { channel in channel.pipeline.addHandler(bindRecordingHandler) }
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
.serverChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let bindRecordingHandler = BindRecordingHandler()
try channel.pipeline.syncOperations.addHandler(
bindRecordingHandler
)
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
}
}
let listener = try bindBootstrap.bind(to: target).wait()
defer {
@ -69,18 +74,22 @@ class NIOTSListenerChannelTests: XCTestCase {
}
try self.group.next().submit {
XCTAssertEqual(bindRecordingHandler.bindTargets, [target])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
let handler = try listener.pipeline.syncOperations.handler(type: BindRecordingHandler.self)
XCTAssertEqual(handler.bindTargets, [target])
XCTAssertEqual(handler.endpointTargets, [])
}.wait()
}
func testConnectingToHostPortTraversesPipeline() throws {
let bindRecordingHandler = BindRecordingHandler()
let bindBootstrap = NIOTSListenerBootstrap(group: self.group)
.serverChannelInitializer { channel in channel.pipeline.addHandler(bindRecordingHandler) }
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
.serverChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let bindRecordingHandler = BindRecordingHandler()
try channel.pipeline.syncOperations.addHandler(bindRecordingHandler)
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
}
}
let listener = try bindBootstrap.bind(host: "localhost", port: 0).wait()
defer {
@ -88,22 +97,28 @@ class NIOTSListenerChannelTests: XCTestCase {
}
try self.group.next().submit {
let handler = try listener.pipeline.syncOperations.handler(
type: BindRecordingHandler.self
)
XCTAssertEqual(
bindRecordingHandler.bindTargets,
handler.bindTargets,
[try SocketAddress.makeAddressResolvingHost("localhost", port: 0)]
)
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
XCTAssertEqual(handler.endpointTargets, [])
}.wait()
}
func testConnectingToEndpointSkipsPipeline() throws {
let endpoint = NWEndpoint.hostPort(host: .ipv4(.loopback), port: .any)
let bindRecordingHandler = BindRecordingHandler()
let bindBootstrap = NIOTSListenerBootstrap(group: self.group)
.serverChannelInitializer { channel in channel.pipeline.addHandler(bindRecordingHandler) }
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
.serverChannelInitializer { channel in
channel.eventLoop.makeCompletedFuture {
let bindRecordingHandler = BindRecordingHandler()
try channel.pipeline.syncOperations.addHandler(bindRecordingHandler)
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [])
}
}
let listener = try bindBootstrap.bind(endpoint: endpoint).wait()
defer {
@ -111,8 +126,11 @@ class NIOTSListenerChannelTests: XCTestCase {
}
try self.group.next().submit {
XCTAssertEqual(bindRecordingHandler.bindTargets, [])
XCTAssertEqual(bindRecordingHandler.endpointTargets, [endpoint])
let handler = try listener.pipeline.syncOperations.handler(
type: BindRecordingHandler.self
)
XCTAssertEqual(handler.bindTargets, [])
XCTAssertEqual(handler.endpointTargets, [endpoint])
}.wait()
}
@ -169,7 +187,11 @@ class NIOTSListenerChannelTests: XCTestCase {
let listener = try NIOTSListenerBootstrap(group: self.group, childGroup: childGroup)
.childChannelInitializer { channel in
childChannelPromise.succeed(channel)
return channel.pipeline.addHandler(PromiseOnActiveHandler(activePromise))
return channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(
PromiseOnActiveHandler(activePromise)
)
}
}.bind(host: "localhost", port: 0).wait()
defer {
XCTAssertNoThrow(try listener.close().wait())
@ -243,7 +265,9 @@ class NIOTSListenerChannelTests: XCTestCase {
let channelPromise = self.group.next().makePromise(of: Channel.self)
let listener = try NIOTSListenerBootstrap(group: self.group)
.serverChannelInitializer { channel in
channel.pipeline.addHandler(ChannelReceiver(channelPromise))
channel.eventLoop.makeCompletedFuture {
try channel.pipeline.syncOperations.addHandler(ChannelReceiver(channelPromise))
}
}
.bind(host: "localhost", port: 0).wait()
defer {
@ -258,7 +282,9 @@ class NIOTSListenerChannelTests: XCTestCase {
// We must wait for channel active here, or the socket addresses won't be set.
let promisedChannel = try channelPromise.futureResult.flatMap { (channel) -> EventLoopFuture<Channel> in
let promiseChannelActive = channel.eventLoop.makePromise(of: Channel.self)
_ = channel.pipeline.addHandler(WaitForActiveHandler(promiseChannelActive))
try? channel.pipeline.syncOperations.addHandler(
WaitForActiveHandler(promiseChannelActive)
)
return promiseChannelActive.futureResult
}.wait()
@ -314,7 +340,7 @@ class NIOTSListenerChannelTests: XCTestCase {
}
func testSyncOptionsAreSupported() throws {
func testSyncOptions(_ channel: Channel) {
@Sendable func testSyncOptions(_ channel: Channel) {
if let sync = channel.syncOptions {
do {
let endpointReuse = try sync.getOption(NIOTSChannelOptions.allowLocalEndpointReuse)