Skip to content

Add option to interpret request headers as Latin1 #18255

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Feb 14, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ public partial class KestrelServerOptions
{
internal System.Security.Cryptography.X509Certificates.X509Certificate2 DefaultCertificate { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
internal bool IsDevCertLoaded { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
internal bool Latin1RequestHeaders { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } [System.Runtime.CompilerServices.CompilerGeneratedAttribute]set { } }
internal System.Collections.Generic.List<Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions> ListenOptions { [System.Runtime.CompilerServices.CompilerGeneratedAttribute]get { throw null; } }
internal void ApplyDefaultCert(Microsoft.AspNetCore.Server.Kestrel.Https.HttpsConnectionAdapterOptions httpsOptions) { }
internal void ApplyEndpointDefaults(Microsoft.AspNetCore.Server.Kestrel.Core.ListenOptions listenOptions) { }
Expand Down Expand Up @@ -433,6 +434,7 @@ public ConfigurationReader(Microsoft.Extensions.Configuration.IConfiguration con
public System.Collections.Generic.IDictionary<string, Microsoft.AspNetCore.Server.Kestrel.Core.Internal.CertificateConfig> Certificates { get { throw null; } }
public Microsoft.AspNetCore.Server.Kestrel.Core.Internal.EndpointDefaults EndpointDefaults { get { throw null; } }
public System.Collections.Generic.IEnumerable<Microsoft.AspNetCore.Server.Kestrel.Core.Internal.EndpointConfig> Endpoints { get { throw null; } }
public bool Latin1RequestHeaders { get { throw null; } }
}
internal partial class HttpConnectionContext
{
Expand Down Expand Up @@ -879,7 +881,7 @@ public void ThrowRequestTargetRejected(System.Span<byte> target) { }
}
internal sealed partial class HttpRequestHeaders : Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpHeaders
{
public HttpRequestHeaders(bool reuseHeaderValues = true) { }
public HttpRequestHeaders(bool reuseHeaderValues = true, bool useLatin1 = false) { }
public bool HasConnection { get { throw null; } }
public bool HasTransferEncoding { get { throw null; } }
public Microsoft.Extensions.Primitives.StringValues HeaderAccept { get { throw null; } set { } }
Expand Down Expand Up @@ -1614,7 +1616,6 @@ internal static partial class HttpUtilities
public const string Http2Version = "HTTP/2";
public const string HttpsUriScheme = "https://";
public const string HttpUriScheme = "http://";
public static string GetAsciiOrUTF8StringNonNullCharacters(this System.Span<byte> span) { throw null; }
public static string GetAsciiStringEscaped(this System.Span<byte> span, int maxChars) { throw null; }
public static string GetAsciiStringNonNullCharacters(this System.Span<byte> span) { throw null; }
public static string GetHeaderName(this System.Span<byte> span) { throw null; }
Expand All @@ -1624,6 +1625,7 @@ internal static partial class HttpUtilities
public static Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpMethod GetKnownMethod(string value) { throw null; }
[System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]internal unsafe static Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpVersion GetKnownVersion(byte* location, int length) { throw null; }
[System.Runtime.CompilerServices.MethodImpl(System.Runtime.CompilerServices.MethodImplOptions.AggressiveInlining)]public static bool GetKnownVersion(this System.Span<byte> span, out Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpVersion knownVersion, out byte length) { throw null; }
public static string GetRequestHeaderStringNonNullCharacters(this System.Span<byte> span, bool useLatin1) { throw null; }
public static bool IsHostHeaderValid(string hostText) { throw null; }
public static string MethodToString(Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpMethod method) { throw null; }
public static string SchemeToString(Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.HttpScheme scheme) { throw null; }
Expand Down Expand Up @@ -1683,6 +1685,7 @@ public StringUtilities() { }
public static bool BytesOrdinalEqualsStringAndAscii(string previousValue, System.Span<byte> newValue) { throw null; }
public static string ConcatAsHexSuffix(string str, char separator, uint number) { throw null; }
public unsafe static bool TryGetAsciiString(byte* input, char* output, int count) { throw null; }
public unsafe static bool TryGetLatin1String(byte* input, char* output, int count) { throw null; }
}
internal partial class TimeoutControl : Microsoft.AspNetCore.Server.Kestrel.Core.Features.IConnectionTimeoutFeature, Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Infrastructure.ITimeoutControl
{
Expand Down
17 changes: 16 additions & 1 deletion src/Servers/Kestrel/Core/src/Internal/ConfigurationReader.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) .NET Foundation. All rights reserved.
// Copyright (c) .NET Foundation. All rights reserved.
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
Expand All @@ -15,11 +15,13 @@ internal class ConfigurationReader
private const string EndpointDefaultsKey = "EndpointDefaults";
private const string EndpointsKey = "Endpoints";
private const string UrlKey = "Url";
private const string Latin1RequestHeadersKey = "Latin1RequestHeaders";

private IConfiguration _configuration;
private IDictionary<string, CertificateConfig> _certificates;
private IList<EndpointConfig> _endpoints;
private EndpointDefaults _endpointDefaults;
private bool? _latin1RequestHeaders;

public ConfigurationReader(IConfiguration configuration)
{
Expand Down Expand Up @@ -65,6 +67,19 @@ public IEnumerable<EndpointConfig> Endpoints
}
}

/// <summary>
/// Gets the boolean stored under the <c>Latin1RequestHeaders</c> configuration key.
/// </summary>
/// <remarks>
/// The configuration is queried on first access only; the result is kept in a
/// nullable backing field and reused by later reads.
/// </remarks>
public bool Latin1RequestHeaders
{
    get
    {
        // No cached value yet: read it from configuration and remember it.
        if (!_latin1RequestHeaders.HasValue)
        {
            _latin1RequestHeaders = _configuration.GetValue<bool>(Latin1RequestHeadersKey);
        }

        // The field is guaranteed to hold a value at this point.
        return _latin1RequestHeaders.GetValueOrDefault();
    }
}

private void ReadCertificates()
{
_certificates = new Dictionary<string, CertificateConfig>(0);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6105,7 +6105,7 @@ public unsafe void Append(Span<byte> name, Span<byte> value)
}

// We didn't have a previous matching header value, or have already added a header, so get the string for this value.
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters();
var valueStr = value.GetRequestHeaderStringNonNullCharacters(_useLatin1);
if ((_bits & flag) == 0)
{
// We didn't already have a header set, so add a new one.
Expand All @@ -6123,7 +6123,7 @@ public unsafe void Append(Span<byte> name, Span<byte> value)
// The header was not one of the "known" headers.
// Convert value to string first, because passing two spans causes 8 bytes stack zeroing in
// this method with rep stosd, which is slower than necessary.
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters();
var valueStr = value.GetRequestHeaderStringNonNullCharacters(_useLatin1);
AppendUnknownHeaders(name, valueStr);
}
}
Expand Down
8 changes: 6 additions & 2 deletions src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,11 @@ public HttpProtocol(HttpConnectionContext context)
_context = context;

ServerOptions = ServiceContext.ServerOptions;
HttpRequestHeaders = new HttpRequestHeaders(reuseHeaderValues: !ServerOptions.DisableStringReuse);

HttpRequestHeaders = new HttpRequestHeaders(
reuseHeaderValues: !ServerOptions.DisableStringReuse,
useLatin1: ServerOptions.Latin1RequestHeaders);

HttpResponseControl = this;
}

Expand Down Expand Up @@ -513,7 +517,7 @@ public void OnTrailer(Span<byte> name, Span<byte> value)
}

string key = name.GetHeaderName();
var valueStr = value.GetAsciiOrUTF8StringNonNullCharacters();
var valueStr = value.GetRequestHeaderStringNonNullCharacters(ServerOptions.Latin1RequestHeaders);
RequestTrailers.Append(key, valueStr);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ namespace Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http
internal sealed partial class HttpRequestHeaders : HttpHeaders
{
private readonly bool _reuseHeaderValues;
private readonly bool _useLatin1;
private long _previousBits = 0;

public HttpRequestHeaders(bool reuseHeaderValues = true)
public HttpRequestHeaders(bool reuseHeaderValues = true, bool useLatin1 = false)
{
_reuseHeaderValues = reuseHeaderValues;
_useLatin1 = useLatin1;
}

public void OnHeadersComplete()
Expand Down Expand Up @@ -80,7 +82,7 @@ private void AppendContentLength(Span<byte> value)
parsed < 0 ||
consumed != value.Length)
{
BadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetAsciiOrUTF8StringNonNullCharacters());
BadHttpRequestException.Throw(RequestRejectionReason.InvalidContentLength, value.GetRequestHeaderStringNonNullCharacters(_useLatin1));
}

_contentLength = parsed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ public static unsafe string GetAsciiStringNonNullCharacters(this Span<byte> span
fixed (char* output = asciiString)
fixed (byte* buffer = span)
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// StringUtilities.TryGetAsciiString returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
{
Expand All @@ -130,7 +130,7 @@ public static unsafe string GetAsciiStringNonNullCharacters(this Span<byte> span
return asciiString;
}

public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
private static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte> span)
{
if (span.IsEmpty)
{
Expand All @@ -142,7 +142,7 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte
fixed (char* output = resultString)
fixed (byte* buffer = span)
{
// This version if AsciiUtilities returns null if there are any null (0 byte) characters
// StringUtilities.TryGetAsciiString returns null if there are any null (0 byte) characters
// in the string
if (!StringUtilities.TryGetAsciiString(buffer, output, span.Length))
{
Expand All @@ -162,9 +162,36 @@ public static unsafe string GetAsciiOrUTF8StringNonNullCharacters(this Span<byte
}
}
}

return resultString;
}

/// <summary>
/// Builds a string containing one character per byte of <paramref name="span"/>,
/// filled in by <c>StringUtilities.TryGetLatin1String</c>.
/// </summary>
/// <param name="span">The raw bytes to convert.</param>
/// <returns>
/// <see cref="string.Empty"/> for an empty span; otherwise a string of <c>span.Length</c> characters.
/// </returns>
/// <exception cref="InvalidOperationException">
/// The conversion helper reported failure, which it does when the input contains a 0 byte.
/// </exception>
private static unsafe string GetLatin1StringNonNullCharacters(this Span<byte> span)
{
    var byteCount = span.Length;
    if (byteCount == 0)
    {
        return string.Empty;
    }

    // Allocate the result up front; the helper writes the characters into it in place.
    var decoded = new string('\0', byteCount);

    fixed (char* destination = decoded)
    fixed (byte* source = span)
    {
        // The helper returns false if there are any null (0 byte) characters in the input.
        var succeeded = StringUtilities.TryGetLatin1String(source, destination, byteCount);
        if (!succeeded)
        {
            // null characters are considered invalid
            throw new InvalidOperationException();
        }
    }

    return decoded;
}

/// <summary>
/// Converts the bytes of a request header value to a string, picking the decoder
/// according to <paramref name="useLatin1"/>.
/// </summary>
/// <param name="span">The raw header value bytes.</param>
/// <param name="useLatin1">
/// <c>true</c> to use the Latin1 decoder; <c>false</c> to use the ASCII-or-UTF8 decoder.
/// </param>
/// <returns>The decoded header value.</returns>
public static string GetRequestHeaderStringNonNullCharacters(this Span<byte> span, bool useLatin1)
{
    if (useLatin1)
    {
        return GetLatin1StringNonNullCharacters(span);
    }

    return GetAsciiOrUTF8StringNonNullCharacters(span);
}

public static string GetAsciiStringEscaped(this Span<byte> span, int maxChars)
{
var sb = new StringBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
// Licensed under the Apache License, Version 2.0. See License.txt in the project root for license information.

using System;
using System.Buffers.Binary;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
Expand All @@ -17,6 +16,9 @@ internal class StringUtilities
[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
{
Debug.Assert(input != null);
Debug.Assert(output != null);

// Calculate end position
var end = input + count;
// Start as valid
Expand Down Expand Up @@ -115,6 +117,111 @@ out Unsafe.AsRef<Vector<short>>(output),
return isValid;
}

[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public static unsafe bool TryGetLatin1String(byte* input, char* output, int count)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are we using unsafe code here instead of Span<byte> and Span<char>? To match TryGetAsciiString?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah. I'm trying to keep this as close to TryGetAsciiString as possible since it's conceptually quite similar.

{
Debug.Assert(input != null);
Debug.Assert(output != null);

// Calculate end position
var end = input + count;
// Start as valid
var isValid = true;

do
{
// If Vector not-accelerated or remaining less than vector size
if (!Vector.IsHardwareAccelerated || input > end - Vector<sbyte>.Count)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This clause should be end - input <= Vector<sbyte>.Count.

{
if (IntPtr.Size == 8) // Use Intrinsic switch for branch elimination
{
// 64-bit: Loop longs by default
while (input <= end - sizeof(long))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This clause should be end - input >= sizeof(long).

{
isValid &= CheckBytesNotNull(((long*)input)[0]);

output[0] = (char)input[0];
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could use Bmi2.X64.ParallelBitDeposit if supported. Same for the 32-bit path.
See

if (Bmi2.X64.IsSupported)
{
// BMI2 will work regardless of the processor's endianness.
((ulong*)output)[0] = Bmi2.X64.ParallelBitDeposit((ulong)value, 0x00FF00FF_00FF00FFul);
((ulong*)output)[1] = Bmi2.X64.ParallelBitDeposit((ulong)(value >> 32), 0x00FF00FF_00FF00FFul);
}
else
{
output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
output[4] = (char)input[4];
output[5] = (char)input[5];
output[6] = (char)input[6];
output[7] = (char)input[7];
}

output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];
output[4] = (char)input[4];
output[5] = (char)input[5];
output[6] = (char)input[6];
output[7] = (char)input[7];

input += sizeof(long);
output += sizeof(long);
}
if (input <= end - sizeof(int))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end - input >= sizeof(int)

{
isValid &= CheckBytesNotNull(((int*)input)[0]);

output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];

input += sizeof(int);
output += sizeof(int);
}
}
else
{
// 32-bit: Loop ints by default
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There's no need for this branch. Just do loops of stride 8 on all platforms, and ignore the IntPtr.Size value.

while (input <= end - sizeof(int))
{
isValid &= CheckBytesNotNull(((int*)input)[0]);

output[0] = (char)input[0];
output[1] = (char)input[1];
output[2] = (char)input[2];
output[3] = (char)input[3];

input += sizeof(int);
output += sizeof(int);
}
}
if (input <= end - sizeof(short))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end - input >= sizeof(short)

{
isValid &= CheckBytesNotNull(((short*)input)[0]);

output[0] = (char)input[0];
output[1] = (char)input[1];

input += sizeof(short);
output += sizeof(short);
}
if (input < end)
{
isValid &= CheckBytesNotNull(((sbyte*)input)[0]);
output[0] = (char)input[0];
}

return isValid;
}

// do/while as entry condition already checked
do
{
// Use byte/ushort instead of signed equivalents to ensure it doesn't fill based on the high bit.
var vector = Unsafe.AsRef<Vector<byte>>(input);
isValid &= CheckBytesNotNull(vector);
Vector.Widen(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this method is hot, you could add explicit paths for AVX2, SSE2, because

  • Vector operates on the widest op-set available, i.e. AVX, so some iterations need to be done sequentially instead of parallel as SSE would allow
  • Vector-methods may result in somewhat not ideal codegen.

For reference see

public static unsafe bool TryGetAsciiString(byte* input, char* output, int count)
from #17556

Except for CheckBytesInAsciiRange, both methods are similar, so they could be combined into one and, with some tricks, still let the JIT generate the proper code for both flavors.
I could do this as part of my PR, if this one gets merged first (also add the BMI-path, comment above).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We don't expect this method to be hot. This is mostly here so Bing can accept non-UTF8/ASCII encoded headers from legacy clients. This is opt-in via an undocumented IConfiguration flag. Even when opted-in, this method will only be hit to decode non-ASCII request headers.

I'll look into merging #17556 to master soon. I'll convert GetLatin1String to use your TryGetAsciiString optimizations when I do.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That is if we don't just end up using Encoding.Latin1 in .NET 5 which we very well might.

vector,
out Unsafe.AsRef<Vector<ushort>>(output),
out Unsafe.AsRef<Vector<ushort>>(output + Vector<ushort>.Count));

input += Vector<byte>.Count;
output += Vector<byte>.Count;
} while (input <= end - Vector<byte>.Count);

// Vector path done, loop back to do non-Vector
// If is a exact multiple of vector size, bail now
} while (input < end);

return isValid;
}

[MethodImpl(MethodImplOptions.AggressiveOptimization)]
public unsafe static bool BytesOrdinalEqualsStringAndAscii(string previousValue, Span<byte> newValue)
{
Expand Down Expand Up @@ -421,7 +528,7 @@ private static bool CheckBytesInAsciiRange(Vector<sbyte> check)
// Validate: bytes != 0 && bytes <= 127
// Subtract 1 from all bytes to move 0 to high bits
// bitwise or with self to catch all > 127 bytes
// mask off high bits and check if 0
// mask off non high bits and check if 0

[MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
private static bool CheckBytesInAsciiRange(long check)
Expand All @@ -444,5 +551,39 @@ private static bool CheckBytesInAsciiRange(short check)

/// <summary>
/// Returns <c>true</c> when <paramref name="check"/> is strictly positive,
/// i.e. it is neither zero nor negative.
/// </summary>
private static bool CheckBytesInAsciiRange(sbyte check)
{
    return 0 < check;
}

/// <summary>
/// Returns <c>true</c> when no element of <paramref name="check"/> is zero.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
private static bool CheckBytesNotNull(Vector<byte> check)
{
    // Vectorized comparison against an all-zero vector: any equal element is a null byte.
    var containsNull = Vector.EqualsAny(check, Vector<byte>.Zero);
    return !containsNull;
}

// Validate: bytes != 0
// Subtract 1 from all bytes to move 0 to high bits
// bitwise and with ~check so high bits are only set for bytes that were originally 0
// mask off non high bits and check if 0

/// <summary>
/// Returns <c>true</c> when none of the eight bytes packed into <paramref name="check"/> is zero.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)] // Needs a push
private static bool CheckBytesNotNull(long check)
{
    const long LowBits = 0x0101010101010101L;
    const long HighBits = unchecked((long)0x8080808080808080L);

    // Subtracting 1 from every byte turns a 0 byte into one with its high bit set.
    var decremented = check - LowBits;

    // "& ~check" drops bytes whose high bit was already set in the input, so only
    // bytes that were originally 0 can leave a high bit behind.
    var nullMarkers = decremented & ~check & HighBits;

    return nullMarkers == 0;
}

/// <summary>
/// Returns <c>true</c> when none of the four bytes packed into <paramref name="check"/> is zero.
/// </summary>
private static bool CheckBytesNotNull(int check)
{
    const int LowBits = 0x01010101;
    const int HighBits = unchecked((int)0x80808080);

    // Subtracting 1 from every byte turns a 0 byte into one with its high bit set.
    var decremented = check - LowBits;

    // "& ~check" drops bytes whose high bit was already set in the input.
    var nullMarkers = decremented & ~check & HighBits;

    return nullMarkers == 0;
}

/// <summary>
/// Returns <c>true</c> when neither of the two bytes packed into <paramref name="check"/> is zero.
/// </summary>
private static bool CheckBytesNotNull(short check)
{
    const short HighBits = unchecked((short)0x8080);

    // The arithmetic is carried out at int width: the short operands are widened first.
    // Subtracting 1 from every byte turns a 0 byte into one with its high bit set.
    int decremented = check - 0x0101;

    // "& ~check" drops bytes whose high bit was already set in the input.
    int nullMarkers = decremented & ~check & HighBits;

    return nullMarkers == 0;
}

/// <summary>
/// Returns <c>true</c> when <paramref name="check"/> is anything other than zero.
/// </summary>
private static bool CheckBytesNotNull(sbyte check)
{
    return !(check == 0);
}
}
}
Loading