Skip to content

Add a ChannelOption for a server to query the authenticated username. #58

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Sources/NIOSSH/Child Channels/ChildChannelOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@ public struct SSHChildChannelOptions {

/// - seealso: `SSHChannelTypeOption`.
public static let sshChannelType: SSHChildChannelOptions.Types.SSHChannelTypeOption = .init()

/// - seealso: `UsernameOption`.
public static let username: SSHChildChannelOptions.Types.UsernameOption = .init()
}

extension SSHChildChannelOptions {
Expand Down Expand Up @@ -53,4 +56,11 @@ extension SSHChildChannelOptions.Types {

public init() {}
}

/// `UsernameOption` allows users to query the authenticated username of the channel.
public struct UsernameOption: ChannelOption {
public typealias Value = String?

public init() {}
}
}
5 changes: 5 additions & 0 deletions Sources/NIOSSH/Child Channels/SSHChannelMultiplexer.swift
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,9 @@ extension SSHChannelMultiplexer {
self.erroredChannels.append(channelID)
}
}

// The username which the server accepted in authorization
var username: String? { delegate?.username }
}

// MARK: Calls from SSH handlers.
Expand Down Expand Up @@ -218,6 +221,8 @@ extension SSHChannelMultiplexer {
protocol SSHMultiplexerDelegate {
var channel: Channel? { get }

var username: String? { get }

func writeFromChildChannel(_: SSHMessage, _: EventLoopPromise<Void>?)

func flushFromChildChannel()
Expand Down
2 changes: 2 additions & 0 deletions Sources/NIOSSH/Child Channels/SSHChildChannel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ extension SSHChildChannel: Channel, ChannelCore {
// This force-unwrap is safe: we set type before we call the initializer, so
// users can only get this after this value is set.
return self.type! as! Option.Value
case _ as SSHChildChannelOptions.Types.UsernameOption:
return multiplexer.username as! Option.Value
case _ as ChannelOptions.Types.AutoReadOption:
return self.autoRead as! Option.Value
case _ as ChannelOptions.Types.AllowRemoteHalfClosureOption:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

protocol AcceptsUserAuthMessages {
var userAuthStateMachine: UserAuthenticationStateMachine { get set }
var connectionAttributes: SSHConnectionStateMachine.Attributes? { get }
}

extension AcceptsUserAuthMessages {
Expand All @@ -39,9 +40,21 @@ extension AcceptsUserAuthMessages {

mutating func receiveUserAuthRequest(_ message: SSHMessage.UserAuthRequestMessage) throws -> SSHConnectionStateMachine.StateMachineInboundProcessResult {
let result = try self.userAuthStateMachine.receiveUserAuthRequest(message)


let attr = connectionAttributes

if let future = result {
return .possibleFutureMessage(future.map(Self.transform(_:)))
return .possibleFutureMessage(future.map{
switch $0 {
case .success:
attr?.username = message.username
return SSHMultiMessage(.userAuthSuccess)
case .failure(let message):
return SSHMultiMessage(.userAuthFailure(message))
case .publicKeyOK(let message):
return SSHMultiMessage(.userAuthPKOK(message))
}
})
} else {
return .noMessage
}
Expand All @@ -65,17 +78,6 @@ extension AcceptsUserAuthMessages {
}
}

private static func transform(_ result: NIOSSHUserAuthenticationResponseMessage) -> SSHMultiMessage {
switch result {
case .success:
return SSHMultiMessage(.userAuthSuccess)
case .failure(let message):
return SSHMultiMessage(.userAuthFailure(message))
case .publicKeyOK(let message):
return SSHMultiMessage(.userAuthPKOK(message))
}
}

private static func transform(_ result: SSHMessage.UserAuthRequestMessage?) -> SSHMultiMessage? {
result.map { SSHMultiMessage(.userAuthRequest($0)) }
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,15 +57,25 @@ struct SSHConnectionStateMachine {
case sentDisconnect(SSHConnectionRole)
}

class Attributes {
var username: String? = nil
}

/// The state of this state machine.
private var state: State


/// Attributes of the connection which can be changed by messages handlers
private let attributes: Attributes

var username: String? { attributes.username }

private static let defaultTransportProtectionSchemes: [NIOSSHTransportProtection.Type] = [
AES256GCMOpenSSHTransportProtection.self, AES128GCMOpenSSHTransportProtection.self,
]

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type] = Self.defaultTransportProtectionSchemes) {
self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes))
self.attributes = Attributes()
self.state = .idle(IdleState(role: role, protectionSchemes: protectionSchemes, attributes:self.attributes))
}

func start() -> SSHMultiMessage? {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@ extension SSHConnectionStateMachine {

internal var sessionIdentifier: ByteBuffer

internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previous: UserAuthenticationState) {
self.role = previous.role
self.serializer = previous.serializer
self.parser = previous.parser
self.remoteVersion = previous.remoteVersion
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.connectionAttributes = previous.connectionAttributes
}

init(_ previous: RekeyingReceivedNewKeysState) {
Expand All @@ -47,6 +50,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = previous.remoteVersion
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.connectionAttributes = previous.connectionAttributes
}

init(_ previous: RekeyingSentNewKeysState) {
Expand All @@ -56,6 +60,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = previous.remoteVersion
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.connectionAttributes = previous.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,13 @@ extension SSHConnectionStateMachine {

internal var protectionSchemes: [NIOSSHTransportProtection.Type]

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type]) {
internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(role: SSHConnectionRole, protectionSchemes: [NIOSSHTransportProtection.Type], attributes: SSHConnectionStateMachine.Attributes) {
self.role = role
self.serializer = SSHPacketSerializer()
self.protectionSchemes = protectionSchemes
self.connectionAttributes = attributes
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,16 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(sentVersionState state: SentVersionState, allocator: ByteBufferAllocator, loop: EventLoop, remoteVersion: String) {
self.role = state.role
self.parser = state.parser
self.serializer = state.serializer
self.remoteVersion = remoteVersion
self.protectionSchemes = state.protectionSchemes
self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: state.role, remoteVersion: remoteVersion, protectionSchemes: state.protectionSchemes, previousSessionIdentifier: nil)
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extension SSHConnectionStateMachine {

internal var sessionIdentifier: ByteBuffer

internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previous: ActiveState, allocator: ByteBufferAllocator, loop: EventLoop) {
self.role = previous.role
self.serializer = previous.serializer
Expand All @@ -42,6 +44,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentifier = previous.sessionIdentifier
self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: previous.role, remoteVersion: previous.remoteVersion, protectionSchemes: previous.protectionSchemes, previousSessionIdentifier: self.sessionIdentifier)
self.connectionAttributes = previous.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ extension SSHConnectionStateMachine {
/// The user auth state machine that drives user authentication.
var userAuthStateMachine: UserAuthenticationStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(keyExchangeState state: KeyExchangeState,
loop: EventLoop) {
self.role = state.role
Expand All @@ -53,6 +55,7 @@ extension SSHConnectionStateMachine {
self.userAuthStateMachine = UserAuthenticationStateMachine(role: self.role,
loop: loop,
sessionID: self.sessionIdentifier)
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previousState: RekeyingState) {
self.role = previousState.role
self.parser = previousState.parser
Expand All @@ -44,6 +46,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previousState: RekeyingState) {
self.role = previousState.role
self.parser = previousState.parser
Expand All @@ -44,6 +46,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var keyExchangeStateMachine: SSHKeyExchangeStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previousState: ReceivedKexInitWhenActiveState) {
self.role = previousState.role
self.parser = previousState.parser
Expand All @@ -43,6 +45,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}

init(_ previousState: SentKexInitWhenActiveState) {
Expand All @@ -53,6 +56,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previousState.protectionSchemes
self.sessionIdentifier = previousState.sessionIdentitifier
self.keyExchangeStateMachine = previousState.keyExchangeStateMachine
self.connectionAttributes = previousState.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ extension SSHConnectionStateMachine {

internal var keyExchangeStateMachine: SSHKeyExchangeStateMachine

internal weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(_ previous: ActiveState, allocator: ByteBufferAllocator, loop: EventLoop) {
self.role = previous.role
self.serializer = previous.serializer
Expand All @@ -42,6 +44,7 @@ extension SSHConnectionStateMachine {
self.protectionSchemes = previous.protectionSchemes
self.sessionIdentitifier = previous.sessionIdentifier
self.keyExchangeStateMachine = SSHKeyExchangeStateMachine(allocator: allocator, loop: loop, role: self.role, remoteVersion: self.remoteVersion, protectionSchemes: self.protectionSchemes, previousSessionIdentifier: previous.sessionIdentifier)
self.connectionAttributes = previous.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ extension SSHConnectionStateMachine {
/// The user auth state machine that drives user authentication.
var userAuthStateMachine: UserAuthenticationStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(keyExchangeState state: KeyExchangeState,
loop: EventLoop) {
self.role = state.role
Expand All @@ -53,6 +55,7 @@ extension SSHConnectionStateMachine {
self.userAuthStateMachine = UserAuthenticationStateMachine(role: self.role,
loop: loop,
sessionID: self.sessionIdentifier)
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ extension SSHConnectionStateMachine {

var protectionSchemes: [NIOSSHTransportProtection.Type]

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

private let allocator: ByteBufferAllocator

init(idleState state: IdleState, allocator: ByteBufferAllocator) {
Expand All @@ -37,6 +39,7 @@ extension SSHConnectionStateMachine {

self.parser = SSHPacketParser(allocator: allocator)
self.allocator = allocator
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ extension SSHConnectionStateMachine {
/// The backing state machine.
var userAuthStateMachine: UserAuthenticationStateMachine

weak var connectionAttributes: SSHConnectionStateMachine.Attributes?

init(sentNewKeysState state: SentNewKeysState) {
self.role = state.role
self.parser = state.parser
Expand All @@ -43,6 +45,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = state.remoteVersion
self.protectionSchemes = state.protectionSchemes
self.sessionIdentifier = state.sessionIdentifier
self.connectionAttributes = state.connectionAttributes
}

init(receivedNewKeysState state: ReceivedNewKeysState) {
Expand All @@ -53,6 +56,7 @@ extension SSHConnectionStateMachine {
self.remoteVersion = state.remoteVersion
self.protectionSchemes = state.protectionSchemes
self.sessionIdentifier = state.sessionIdentifier
self.connectionAttributes = state.connectionAttributes
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions Sources/NIOSSH/NIOSSHHandler.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ public final class NIOSSHHandler {

private var pendingGlobalRequestResponses: CircularBuffer<PendingGlobalRequestResponse?>

// The authenticated username, if there was one.
var username: String? { stateMachine.username }

public init(role: SSHConnectionRole, allocator: ByteBufferAllocator, inboundChildChannelInitializer: ((Channel, SSHChannelType) -> EventLoopFuture<Void>)?) {
self.stateMachine = SSHConnectionStateMachine(role: role)
self.pendingWrite = false
Expand Down
2 changes: 2 additions & 0 deletions Tests/NIOSSHTests/ChildChannelMultiplexerTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import XCTest
/// This reduces the testing surface area somewhat, which greatly helps us to test the
/// implementation of the multiplexer and child channels.
final class DummyDelegate: SSHMultiplexerDelegate {
var username : String? = "dummy"

var _channel: EmbeddedChannel = EmbeddedChannel()

var writes: MarkedCircularBuffer<(SSHMessage, EventLoopPromise<Void>?)> = MarkedCircularBuffer(initialCapacity: 8)
Expand Down