Skip to content

[feat] support alpn negotiation in ssl context #247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 3, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 80 additions & 9 deletions src/main/java/org/jruby/ext/openssl/SSLContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@

import javax.net.ssl.KeyManager;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.X509ExtendedKeyManager;
Expand All @@ -53,13 +54,15 @@
import org.jruby.Ruby;
import org.jruby.RubyArray;
import org.jruby.RubyClass;
import org.jruby.RubyString;
import org.jruby.RubyFixnum;
import org.jruby.RubyHash;
import org.jruby.RubyInteger;
import org.jruby.RubyModule;
import org.jruby.RubyNumeric;
import org.jruby.RubyObject;
import org.jruby.RubySymbol;
import org.jruby.RubyProc;
import org.jruby.anno.JRubyMethod;
import org.jruby.common.IRubyWarnings.ID;
import org.jruby.runtime.Arity;
Expand Down Expand Up @@ -242,6 +245,8 @@ public static void createSSLContext(final Ruby runtime, final RubyModule SSL) {
SSLContext.addReadWriteAttribute(context, "tmp_dh_callback");
SSLContext.addReadWriteAttribute(context, "servername_cb");
SSLContext.addReadWriteAttribute(context, "renegotiation_cb");
SSLContext.addReadWriteAttribute(context, "alpn_protocols");
SSLContext.addReadWriteAttribute(context, "alpn_select_cb");

SSLContext.defineAlias("ssl_timeout", "timeout");
SSLContext.defineAlias("ssl_timeout=", "timeout=");
Expand Down Expand Up @@ -451,6 +456,29 @@ public IRubyObject setup(final ThreadContext context) {
// SSL_CTX_set_tlsext_servername_callback(ctx, ssl_servername_cb);
}

final String[] alpnProtocols;

value = getInstanceVariable("@alpn_protocols");
if ( value != null && ! value.isNil() ) {
IRubyObject[] alpn_protocols = ((RubyArray) value).toJavaArrayMaybeUnsafe();
String[] protocols = new String[alpn_protocols.length];
for(int i = 0; i < protocols.length; i++) {
protocols[i] = alpn_protocols[i].convertToString().asJavaString();
}
alpnProtocols = protocols;
} else {
alpnProtocols = null;
}

final RubyProc alpnSelectCb;
value = getInstanceVariable("@alpn_select_cb");
if ( value != null && ! value.isNil() ) {
alpnSelectCb = (RubyProc) value;
} else {
alpnSelectCb = null;
}


// NOTE: no API under javax.net to support session get/new/remove callbacks
/*
val = ossl_sslctx_get_sess_id_ctx(self);
Expand All @@ -477,7 +505,8 @@ public IRubyObject setup(final ThreadContext context) {
*/

try {
internalContext = createInternalContext(context, cert, key, store, clientCert, extraChainCert, verifyMode, timeout);
internalContext = createInternalContext(context, cert, key, store, clientCert, extraChainCert,
verifyMode, timeout, alpnProtocols, alpnSelectCb);
}
catch (GeneralSecurityException e) {
throw newSSLError(runtime, e);
Expand Down Expand Up @@ -505,7 +534,7 @@ private RubyArray matchedCiphersWithCache(final ThreadContext context) {
private RubyArray matchedCiphers(final ThreadContext context) {
final Ruby runtime = context.runtime;
try {
final String[] supported = getSupportedCipherSuites(protocol);
final String[] supported = getSupportedCipherSuites(runtime, protocol);
final Collection<CipherStrings.Def> cipherDefs =
CipherStrings.matchingCiphers(this.ciphers, supported, false);

Expand Down Expand Up @@ -688,14 +717,48 @@ private static class CipherListCache {
}
}

private static String[] getSupportedCipherSuites(final String protocol)
void setApplicationProtocolsOrSelector(final SSLEngine engine) {
setApplicationProtocolSelector(engine);
setApplicationProtocols(engine);
}

private void setApplicationProtocolSelector(final SSLEngine engine) {
final RubyProc alpn_select_cb = internalContext.alpnSelectCallback;
if (alpn_select_cb != null) {
engine.setHandshakeApplicationProtocolSelector((_engine, protocols) -> {
final Ruby runtime = getRuntime();
IRubyObject[] rubyProtocols = new IRubyObject[protocols.size()];
int i = 0; for (String protocol : protocols) {
rubyProtocols[i++] = runtime.newString(protocol);
}

IRubyObject[] args = new IRubyObject[] { RubyArray.newArray(runtime, rubyProtocols) };
IRubyObject selected_protocol = alpn_select_cb.call(runtime.getCurrentContext(), args);
if (selected_protocol != null && !selected_protocol.isNil()) {
return ((RubyString) selected_protocol).asJavaString();
}
return null; // callback returned nil - none of the advertised names are acceptable
});
}
}

private void setApplicationProtocols(final SSLEngine engine) {
final String[] alpn_protocols = internalContext.alpnProtocols;
if (alpn_protocols != null) {
SSLParameters params = engine.getSSLParameters();
params.setApplicationProtocols(alpn_protocols);
engine.setSSLParameters(params);
}
}

private static String[] getSupportedCipherSuites(Ruby runtime, final String protocol)
throws GeneralSecurityException {
return dummySSLEngine(protocol).getSupportedCipherSuites();
return dummySSLEngine(runtime, protocol).getSupportedCipherSuites();
}

private static SSLEngine dummySSLEngine(final String protocol) throws GeneralSecurityException {
private static SSLEngine dummySSLEngine(Ruby runtime, final String protocol) throws GeneralSecurityException {
javax.net.ssl.SSLContext sslContext = SecurityHelper.getSSLContext(protocol);
sslContext.init(null, null, null);
sslContext.init(null, null, OpenSSL.getSecureRandom(runtime));
return sslContext.createSSLEngine();
}

Expand Down Expand Up @@ -899,8 +962,9 @@ static RubyClass _SSLContext(final Ruby runtime) {
private InternalContext createInternalContext(ThreadContext context,
final X509Cert xCert, final PKey pKey, final Store store,
final List<X509AuxCertificate> clientCert, final List<X509AuxCertificate> extraChainCert,
final int verifyMode, final int timeout) throws NoSuchAlgorithmException, KeyManagementException {
InternalContext internalContext = new InternalContext(xCert, pKey, store, clientCert, extraChainCert, verifyMode, timeout);
final int verifyMode, final int timeout,
final String[] alpnProtocols, final RubyProc alpnSelectCb) throws NoSuchAlgorithmException, KeyManagementException {
InternalContext internalContext = new InternalContext(xCert, pKey, store, clientCert, extraChainCert, verifyMode, timeout, alpnProtocols, alpnSelectCb);
internalContext.initSSLContext(context);
return internalContext;
}
Expand All @@ -917,7 +981,9 @@ private class InternalContext {
final List<X509AuxCertificate> clientCert,
final List<X509AuxCertificate> extraChainCert,
final int verifyMode,
final int timeout) throws NoSuchAlgorithmException {
final int timeout,
final String[] alpnProtocols,
final RubyProc alpnSelectCallback) throws NoSuchAlgorithmException {

if ( pKey != null && xCert != null ) {
this.privateKey = pKey.getPrivateKey();
Expand All @@ -935,6 +1001,8 @@ private class InternalContext {
this.extraChainCert = extraChainCert;
this.verifyMode = verifyMode;
this.timeout = timeout;
this.alpnProtocols = alpnProtocols;
this.alpnSelectCallback = alpnSelectCallback;

// initialize SSL context :

Expand Down Expand Up @@ -982,6 +1050,9 @@ void initSSLContext(final ThreadContext context) throws KeyManagementException {

private final int timeout;

private final String[] alpnProtocols;
private final RubyProc alpnSelectCallback;

private final javax.net.ssl.SSLContext sslContext;

// part of ssl_verify_cert_chain
Expand Down
13 changes: 11 additions & 2 deletions src/main/java/org/jruby/ext/openssl/SSLSocket.java
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,9 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean serv
dummy = ByteBuffer.allocate(0);
this.engine = engine;
copySessionSetupIfSet(context);

sslContext.setApplicationProtocolsOrSelector(engine);

return engine;
}

Expand All @@ -238,6 +241,12 @@ private SSLEngine ossl_ssl_setup(final ThreadContext context, final boolean serv
@JRubyMethod(name = "context")
public final SSLContext context() { return this.sslContext; }

@JRubyMethod(name = "alpn_protocol")
public IRubyObject alpn_protocol(final ThreadContext context) {
final String protocol = engine.getApplicationProtocol();
return protocol == null ? context.nil : RubyString.newString(context.runtime, protocol);
}

@JRubyMethod(name = "sync")
public IRubyObject sync(final ThreadContext context) {
final CallSite[] sites = getMetaClass().getExtraCallSites();
Expand Down Expand Up @@ -283,7 +292,7 @@ private IRubyObject connectImpl(final ThreadContext context, final boolean block

try {
if ( ! initialHandshake ) {
SSLEngine engine = ossl_ssl_setup(context, true);
SSLEngine engine = ossl_ssl_setup(context, false);
engine.setUseClientMode(true);
engine.beginHandshake();
handshakeStatus = engine.getHandshakeStatus();
Expand Down Expand Up @@ -343,7 +352,7 @@ private IRubyObject acceptImpl(final ThreadContext context, final boolean blocki

try {
if ( ! initialHandshake ) {
final SSLEngine engine = ossl_ssl_setup(context, false);
final SSLEngine engine = ossl_ssl_setup(context, true);
engine.setUseClientMode(false);
final IRubyObject verify_mode = verify_mode(context);
if ( verify_mode != context.nil ) {
Expand Down
21 changes: 21 additions & 0 deletions src/test/ruby/ssl/test_session.rb
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,27 @@ def test_session
end
end

def test_alpn_protocol_selection_ary
advertised = ["h2", "http/1.1"]
ctx_proc = Proc.new { |ctx|
ctx.alpn_select_cb = -> (protocols) {
assert_equal Array, protocols.class
assert_equal advertised, protocols
protocols.first
}
}
start_server0(PORT, OpenSSL::SSL::VERIFY_NONE, true, ctx_proc: ctx_proc) do |server, port|
sock = TCPSocket.new("127.0.0.1", port)
ctx = OpenSSL::SSL::SSLContext.new("TLSv1_2")
ctx.alpn_protocols = advertised
ssl = OpenSSL::SSL::SSLSocket.new(sock, ctx)
ssl.sync_close = true
ssl.connect
assert_equal("h2", ssl.alpn_protocol)
ssl.puts "abc"; assert_equal "abc\n", ssl.gets
end
end

def test_exposes_session_error
OpenSSL::SSL::Session::SessionError
end
Expand Down