diff --git a/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp3.0.cs b/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp3.0.cs index 06dfb1465f70..e431d6366903 100644 --- a/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp3.0.cs +++ b/src/Servers/Kestrel/Core/ref/Microsoft.AspNetCore.Server.Kestrel.Core.netcoreapp3.0.cs @@ -484,6 +484,7 @@ public void Dispose() { } public System.Threading.Tasks.ValueTask FlushAsync(System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } public System.Memory GetMemory(int sizeHint = 0) { throw null; } public System.Span GetSpan(int sizeHint = 0) { throw null; } + public void Reset() { } public System.Threading.Tasks.ValueTask Write100ContinueAsync() { throw null; } public System.Threading.Tasks.ValueTask WriteChunkAsync(System.ReadOnlySpan buffer, System.Threading.CancellationToken cancellationToken) { throw null; } public System.Threading.Tasks.Task WriteDataAsync(System.ReadOnlySpan buffer, System.Threading.CancellationToken cancellationToken = default(System.Threading.CancellationToken)) { throw null; } @@ -954,6 +955,7 @@ public partial interface IHttpOutputProducer System.Threading.Tasks.ValueTask FlushAsync(System.Threading.CancellationToken cancellationToken); System.Memory GetMemory(int sizeHint = 0); System.Span GetSpan(int sizeHint = 0); + void Reset(); System.Threading.Tasks.ValueTask Write100ContinueAsync(); System.Threading.Tasks.ValueTask WriteChunkAsync(System.ReadOnlySpan data, System.Threading.CancellationToken cancellationToken); System.Threading.Tasks.Task WriteDataAsync(System.ReadOnlySpan data, System.Threading.CancellationToken cancellationToken); @@ -1297,6 +1299,7 @@ public void Dispose() { } public System.Span GetSpan(int sizeHint = 0) { throw null; } void Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.IHttpOutputAborter.Abort(Microsoft.AspNetCore.Connections.ConnectionAbortedException abortReason) { } System.Threading.Tasks.ValueTask Microsoft.AspNetCore.Server.Kestrel.Core.Internal.Http.IHttpOutputProducer.WriteChunkAsync(System.ReadOnlySpan data, System.Threading.CancellationToken cancellationToken) { throw null; } + public void Reset() { } public System.Threading.Tasks.ValueTask Write100ContinueAsync() { throw null; } public System.Threading.Tasks.Task WriteChunkAsync(System.ReadOnlySpan span, System.Threading.CancellationToken cancellationToken) { throw null; } public System.Threading.Tasks.Task WriteDataAsync(System.ReadOnlySpan data, System.Threading.CancellationToken cancellationToken) { throw null; } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs index 5ad99ccabe75..f620a4be1e10 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/Http1OutputProducer.cs @@ -3,6 +3,7 @@ using System; using System.Buffers; +using System.Collections.Generic; using System.Diagnostics; using System.IO.Pipelines; using System.Threading; @@ -26,6 +27,9 @@ public class Http1OutputProducer : IHttpOutputProducer, IHttpOutputAborter, IDis // "0\r\n\r\n" private static ReadOnlySpan EndChunkedResponseBytes => new byte[] { (byte)'0', (byte)'\r', (byte)'\n', (byte)'\r', (byte)'\n' }; + private const int BeginChunkLengthMax = 5; + private const int EndChunkLength = 2; + private readonly string _connectionId; private readonly ConnectionContext _connectionContext; private readonly IKestrelTrace _log; @@ -40,21 +44,28 @@ public class Http1OutputProducer : IHttpOutputProducer, IHttpOutputAborter, IDis private bool _completed; private bool _aborted; private long _unflushedBytes; - private bool _autoChunk; + private readonly PipeWriter _pipeWriter; - private const int MemorySizeThreshold = 1024; - private const int BeginChunkLengthMax = 5; - private const int EndChunkLength = 2; + private IMemoryOwner _fakeMemoryOwner; // Chunked responses need to be treated uniquely when using GetMemory + Advance. // We need to know the size of the data written to the chunk before calling Advance on the // PipeWriter, meaning we internally track how far we have advanced through a current chunk (_advancedBytesForChunk). // Once write or flush is called, we modify the _currentChunkMemory to prepend the size of data written // and append the end terminator. + + private bool _autoChunk; private int _advancedBytesForChunk; private Memory _currentChunkMemory; private bool _currentChunkMemoryUpdated; - private IMemoryOwner _fakeMemoryOwner; + + // Fields needed to store writes before calling either startAsync or Write/FlushAsync + // These should be cleared by the end of the request + private List _completedSegments; + private Memory _currentSegment; + private IMemoryOwner _currentSegmentOwner; + private int _position; + private bool _startCalled; public Http1OutputProducer( PipeWriter pipeWriter, @@ -158,6 +169,10 @@ public Memory GetMemory(int sizeHint = 0) { return GetFakeMemory(sizeHint); } + else if (!_startCalled) + { + return LeasedMemory(sizeHint); + } else if (_autoChunk) { return GetChunkedMemory(sizeHint); @@ -177,6 +192,10 @@ public Span GetSpan(int sizeHint = 0) { return GetFakeMemory(sizeHint).Span; } + else if (!_startCalled) + { + return LeasedMemory(sizeHint).Span; + } else if (_autoChunk) { return GetChunkedMemory(sizeHint).Span; @@ -197,16 +216,23 @@ public void Advance(int bytes) return; } - if (_autoChunk) + if (!_startCalled) { - if (bytes < 0) + if (bytes >= 0) { - throw new ArgumentOutOfRangeException(nameof(bytes)); - } + if (_currentSegment.Length - bytes < _position) + { + throw new ArgumentOutOfRangeException("Can't advance past buffer size."); + } - if (bytes + _advancedBytesForChunk > _currentChunkMemory.Length - BeginChunkLengthMax - EndChunkLength) + _position += bytes; + } + } + else if (_autoChunk) + { + if (_advancedBytesForChunk > _currentChunkMemory.Length - BeginChunkLengthMax - EndChunkLength - bytes) { - throw new InvalidOperationException("Can't advance past buffer size."); + throw new ArgumentOutOfRangeException("Can't advance past buffer size."); } _advancedBytesForChunk += bytes; } @@ -238,6 +264,7 @@ public ValueTask WriteChunkAsync(ReadOnlySpan buffer, Cancell { var writer = new BufferWriter(_pipeWriter); CommitChunkInternal(ref writer, buffer); + _unflushedBytes += writer.BytesCommitted; } } @@ -260,7 +287,6 @@ private void CommitChunkInternal(ref BufferWriter writer, ReadOnlySp } writer.Commit(); - _unflushedBytes += writer.BytesCommitted; } public void WriteResponseHeaders(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders, bool autoChunk) @@ -288,8 +314,52 @@ private void WriteResponseHeadersInternal(ref BufferWriter writer, i writer.Commit(); - _unflushedBytes += writer.BytesCommitted; _autoChunk = autoChunk; + WriteDataWrittenBeforeHeaders(ref writer); + _unflushedBytes += writer.BytesCommitted; + + _startCalled = true; + } + + private void WriteDataWrittenBeforeHeaders(ref BufferWriter writer) + { + if (_completedSegments != null) + { + foreach (var segment in _completedSegments) + { + if (_autoChunk) + { + CommitChunkInternal(ref writer, segment.Span); + } + else + { + writer.Write(segment.Span); + writer.Commit(); + } + segment.Return(); + } + + _completedSegments.Clear(); + } + + if (!_currentSegment.IsEmpty) + { + var segment = _currentSegment.Slice(0, _position); + + if (_autoChunk) + { + CommitChunkInternal(ref writer, segment.Span); + } + else + { + writer.Write(segment.Span); + writer.Commit(); + } + + _position = 0; + + DisposeCurrentSegment(); + } } public void Dispose() @@ -302,10 +372,28 @@ public void Dispose() _fakeMemoryOwner = null; } + // Call dispose on any memory that wasn't written. + if (_completedSegments != null) + { + foreach (var segment in _completedSegments) + { + segment.Return(); + } + } + + DisposeCurrentSegment(); + CompletePipe(); } } + private void DisposeCurrentSegment() + { + _currentSegmentOwner?.Dispose(); + _currentSegmentOwner = null; + _currentSegment = default; + } + private void CompletePipe() { if (!_pipeWriterCompleted) @@ -382,10 +470,21 @@ public ValueTask FirstWriteChunkedAsync(int statusCode, string reas CommitChunkInternal(ref writer, buffer); + _unflushedBytes += writer.BytesCommitted; + return FlushAsync(cancellationToken); } } + public void Reset() + { + Debug.Assert(_currentSegmentOwner == null); + Debug.Assert(_completedSegments == null || _completedSegments.Count == 0); + _autoChunk = false; + _startCalled = false; + _currentChunkMemoryUpdated = false; + } + private ValueTask WriteAsync( ReadOnlySpan buffer, CancellationToken cancellationToken = default) @@ -454,7 +553,7 @@ private Memory GetChunkedMemory(int sizeHint) } var memoryMaxLength = _currentChunkMemory.Length - BeginChunkLengthMax - EndChunkLength; - if (_advancedBytesForChunk >= memoryMaxLength - Math.Min(MemorySizeThreshold, sizeHint)) + if (_advancedBytesForChunk >= memoryMaxLength - sizeHint && _advancedBytesForChunk > 0) { // Chunk is completely written, commit it to the pipe so GetMemory will return a new chunk of memory. var writer = new BufferWriter(_pipeWriter); @@ -506,5 +605,91 @@ private Memory GetFakeMemory(int sizeHint) } return _fakeMemoryOwner.Memory; } + + private Memory LeasedMemory(int sizeHint) + { + EnsureCapacity(sizeHint); + return _currentSegment.Slice(_position); + } + + private void EnsureCapacity(int sizeHint) + { + // Only subtracts _position from the current segment length if it's non-null. + // If _currentSegment is null, it returns 0. + var remainingSize = _currentSegment.Length - _position; + + // If the sizeHint is 0, any capacity will do + // Otherwise, the buffer must have enough space for the entire size hint, or we need to add a segment. + if ((sizeHint == 0 && remainingSize > 0) || (sizeHint > 0 && remainingSize >= sizeHint)) + { + // We have capacity in the current segment + return; + } + + AddSegment(sizeHint); + } + + private void AddSegment(int sizeHint = 0) + { + if (_currentSegment.Length != 0) + { + // We're adding a segment to the list + if (_completedSegments == null) + { + _completedSegments = new List(); + } + + // Position might be less than the segment length if there wasn't enough space to satisfy the sizeHint when + // GetMemory was called. In that case we'll take the current segment and call it "completed", but need to + // ignore any empty space in it. + _completedSegments.Add(new CompletedBuffer(_currentSegmentOwner, _currentSegment, _position)); + } + + if (sizeHint <= _memoryPool.MaxBufferSize) + { + // Get a new buffer using the minimum segment size, unless the size hint is larger than a single segment. + // Also, the size cannot be larger than the MaxBufferSize of the MemoryPool + var owner = _memoryPool.Rent(Math.Min(sizeHint, _memoryPool.MaxBufferSize)); + _currentSegment = owner.Memory; + _currentSegmentOwner = owner; + } + else + { + _currentSegment = new byte[sizeHint]; + _currentSegmentOwner = null; + } + + _position = 0; + } + + + /// + /// Holds a byte[] from the pool and a size value. Basically a Memory but guaranteed to be backed by an ArrayPool byte[], so that we know we can return it. + /// + private readonly struct CompletedBuffer + { + private readonly IMemoryOwner _memoryOwner; + + public Memory Buffer { get; } + public int Length { get; } + + public ReadOnlySpan Span => Buffer.Span.Slice(0, Length); + + public CompletedBuffer(IMemoryOwner owner, Memory buffer, int length) + { + _memoryOwner = owner; + + Buffer = buffer; + Length = length; + } + + public void Return() + { + if (_memoryOwner != null) + { + _memoryOwner.Dispose(); + } + } + } } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs index c575895a7768..9cb304c270a9 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/HttpProtocol.cs @@ -45,7 +45,11 @@ public abstract partial class HttpProtocol : IDefaultHttpContextContainer, IHttp // Keep-alive is default for HTTP/1.1 and HTTP/2; parsing and errors will change its value // volatile, see: https://msdn.microsoft.com/en-us/library/x13ttww7.aspx protected volatile bool _keepAlive = true; - private bool _canWriteResponseBody; + // _canWriteResponseBody is set in CreateResponseHeaders. + // If we are writing with GetMemory/Advance before calling StartAsync, assume we can write and throw away contents if we can't. + private bool _canWriteResponseBody = true; + private bool _hasAdvanced; + private bool _isLeasedMemoryInvalid = true; private bool _autoChunk; protected Exception _applicationException; private BadHttpRequestException _requestRejectedException; @@ -351,6 +355,10 @@ public void Reset() RequestHeaders = HttpRequestHeaders; ResponseHeaders = HttpResponseHeaders; + _isLeasedMemoryInvalid = true; + _hasAdvanced = false; + _canWriteResponseBody = true; + if (_scheme == null) { var tlsFeature = ConnectionFeatures?[typeof(ITlsConnectionFeature)]; @@ -380,6 +388,8 @@ public void Reset() } } + Output?.Reset(); + _requestHeadersParsed = 0; _responseBytesWritten = 0; @@ -921,6 +931,8 @@ private void ProduceStart(bool appCompleted) return; } + _isLeasedMemoryInvalid = true; + _requestProcessingStatus = RequestProcessingStatus.HeadersCommitted; var responseHeaders = CreateResponseHeaders(appCompleted); @@ -1066,7 +1078,7 @@ private HttpResponseHeaders CreateResponseHeaders(bool appCompleted) { _keepAlive = false; } - else if (appCompleted || !_canWriteResponseBody) + else if ((appCompleted || !_canWriteResponseBody) && !_hasAdvanced) // Avoid setting contentLength of 0 if we wrote data before calling CreateResponseHeaders { // Don't set the Content-Length header automatically for HEAD requests, 204 responses, or 304 responses. if (CanAutoSetContentLengthZeroResponseHeader()) @@ -1268,6 +1280,21 @@ public void ReportApplicationError(Exception ex) public void Advance(int bytes) { + if (bytes < 0) + { + throw new ArgumentOutOfRangeException(nameof(bytes)); + } + else if (bytes > 0) + { + _hasAdvanced = true; + } + + if (_isLeasedMemoryInvalid) + { + throw new InvalidOperationException("Invalid ordering of calling StartAsync and Advance. " + + "Call StartAsync before calling GetMemory/GetSpan and Advance."); + } + if (_canWriteResponseBody) { VerifyAndUpdateWrite(bytes); @@ -1276,7 +1303,6 @@ public void Advance(int bytes) else { HandleNonBodyResponseWrite(); - // For HEAD requests, we still use the number of bytes written for logging // how many bytes were written. VerifyAndUpdateWrite(bytes); @@ -1285,27 +1311,16 @@ public void Advance(int bytes) public Memory GetMemory(int sizeHint = 0) { - ThrowIfResponseNotStarted(); - + _isLeasedMemoryInvalid = false; return Output.GetMemory(sizeHint); } public Span GetSpan(int sizeHint = 0) { - ThrowIfResponseNotStarted(); - + _isLeasedMemoryInvalid = false; return Output.GetSpan(sizeHint); } - [StackTraceHidden] - private void ThrowIfResponseNotStarted() - { - if (!HasResponseStarted) - { - throw new InvalidOperationException(CoreStrings.StartAsyncBeforeGetMemory); - } - } - public ValueTask FlushPipeAsync(CancellationToken cancellationToken) { if (!HasResponseStarted) @@ -1338,6 +1353,7 @@ public void Complete(Exception ex) ApplicationAbort(); } } + Output.Complete(); } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs index bc4509e8b347..bcdcd8dd102d 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http/IHttpOutputProducer.cs @@ -25,5 +25,6 @@ public interface IHttpOutputProducer void Complete(); ValueTask FirstWriteAsync(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders, bool autoChunk, ReadOnlySpan data, CancellationToken cancellationToken); ValueTask FirstWriteChunkedAsync(int statusCode, string reasonPhrase, HttpResponseHeaders responseHeaders, bool autoChunk, ReadOnlySpan data, CancellationToken cancellationToken); + void Reset(); } } diff --git a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs index 9de710baf6d2..c454b465b0b5 100644 --- a/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs +++ b/src/Servers/Kestrel/Core/src/Internal/Http2/Http2OutputProducer.cs @@ -279,6 +279,10 @@ public void Complete() // This will noop for now. See: https://github.com/aspnet/AspNetCore/issues/7370 } + public void Reset() + { + } + private async ValueTask ProcessDataWrites() { FlushResult flushResult = default; diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedResponseTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedResponseTests.cs index 919f3827e56b..6c36b041bf48 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedResponseTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ChunkedResponseTests.cs @@ -422,13 +422,14 @@ await connection.Receive( } [Fact] - public async Task ChunksWithGetMemoryBeforeFirstFlushStillFlushes() + public async Task ChunksWithGetMemoryAfterStartAsyncBeforeFirstFlushStillFlushes() { var testContext = new TestServiceContext(LoggerFactory); using (var server = new TestServer(async httpContext => { var response = httpContext.Response; + await response.StartAsync(); var memory = response.BodyWriter.GetMemory(); var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); @@ -466,6 +467,51 @@ await connection.Receive( } } + [Fact] + public async Task ChunksWithGetMemoryBeforeFirstFlushStillFlushes() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + + var memory = response.BodyWriter.GetMemory(); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); + fisrtPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(6); + + memory = response.BodyWriter.GetMemory(); + var secondPartOfResponse = Encoding.ASCII.GetBytes("World!"); + secondPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(6); + + await response.BodyWriter.FlushAsync(); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "c", + "Hello World!", + "0", + "", + ""); + } + + await server.StopAsync(); + } + } + [Fact] public async Task ChunksWithGetMemoryLargeWriteBeforeFirstFlush() { @@ -476,7 +522,6 @@ public async Task ChunksWithGetMemoryLargeWriteBeforeFirstFlush() using (var server = new TestServer(async httpContext => { var response = httpContext.Response; - await response.StartAsync(); var memory = response.BodyWriter.GetMemory(); length.Value = memory.Length; @@ -524,7 +569,7 @@ await connection.Receive( } [Fact] - public async Task ChunksWithGetMemoryWithInitialFlushWorks() + public async Task ChunksWithGetMemoryAndStartAsyncWithInitialFlushWorks() { var length = new IntAsRef(); var semaphore = new SemaphoreSlim(initialCount: 0); @@ -581,6 +626,65 @@ await connection.Receive( } } + [Fact] + public async Task ChunksWithGetMemoryBeforeFlushEdgeCase() + { + var length = 0; + var semaphore = new SemaphoreSlim(initialCount: 0); + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + + await response.StartAsync(); + + var memory = response.BodyWriter.GetMemory(); + length = memory.Length - 1; + semaphore.Release(); + + var fisrtPartOfResponse = Encoding.ASCII.GetBytes(new string('a', length)); + fisrtPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(length); + + var secondMemory = response.BodyWriter.GetMemory(6); + + var secondPartOfResponse = Encoding.ASCII.GetBytes("World!"); + secondPartOfResponse.CopyTo(secondMemory); + response.BodyWriter.Advance(6); + + await response.BodyWriter.FlushAsync(); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + + // Wait for length to be set + await semaphore.WaitAsync(); + + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + length.ToString("x"), + new string('a', length), + "6", + "World!", + "0", + "", + ""); + } + + await server.StopAsync(); + } + } + [Fact] public async Task ChunkGetMemoryMultipleAdvance() { @@ -633,6 +737,7 @@ public async Task ChunkGetSpanMultipleAdvance() using (var server = new TestServer(async httpContext => { var response = httpContext.Response; + await response.StartAsync(); // To avoid using span in an async method void NonAsyncMethod() @@ -647,8 +752,6 @@ void NonAsyncMethod() response.BodyWriter.Advance(6); } - await response.StartAsync(); - NonAsyncMethod(); }, testContext)) { @@ -687,6 +790,7 @@ public async Task ChunkGetMemoryAndWrite() await response.StartAsync(); var memory = response.BodyWriter.GetMemory(4096); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); fisrtPartOfResponse.CopyTo(memory); response.BodyWriter.Advance(6); @@ -717,6 +821,48 @@ await connection.Receive( } } + [Fact] + public async Task ChunkGetMemoryAndWriteWithoutStart() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + + var memory = response.BodyWriter.GetMemory(4096); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); + fisrtPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(6); + + await response.WriteAsync("World!"); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "6", + "Hello ", + "6", + "World!", + "0", + "", + ""); + } + + await server.StopAsync(); + } + } + [Fact] public async Task GetMemoryWithSizeHint() { @@ -758,10 +904,47 @@ await connection.Receive( } } + [Fact] + public async Task GetMemoryWithSizeHintWithoutStartAsync() + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + + var memory = response.BodyWriter.GetMemory(0); + + Assert.Equal(4096, memory.Length); + + memory = response.BodyWriter.GetMemory(1000000); + Assert.Equal(1000000, memory.Length); + await Task.CompletedTask; + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + + await server.StopAsync(); + } + } + [Theory] [InlineData(15)] [InlineData(255)] - public async Task ChunkGetMemoryWithSmallerSizesWork(int writeSize) + public async Task ChunkGetMemoryWithoutStartWithSmallerSizesWork(int writeSize) { var testContext = new TestServiceContext(LoggerFactory); @@ -769,12 +952,53 @@ public async Task ChunkGetMemoryWithSmallerSizesWork(int writeSize) { var response = httpContext.Response; - await response.StartAsync(); var memory = response.BodyWriter.GetMemory(4096); var fisrtPartOfResponse = Encoding.ASCII.GetBytes(new string('a', writeSize)); fisrtPartOfResponse.CopyTo(memory); response.BodyWriter.Advance(writeSize); + await response.BodyWriter.FlushAsync(); + }, testContext)) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host: ", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {testContext.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + writeSize.ToString("X").ToLower(), + new string('a', writeSize), + "0", + "", + ""); + } + + await server.StopAsync(); + } + } + + [Theory] + [InlineData(15)] + [InlineData(255)] + public async Task ChunkGetMemoryWithStartWithSmallerSizesWork(int writeSize) + { + var testContext = new TestServiceContext(LoggerFactory); + + using (var server = new TestServer(async httpContext => + { + var response = httpContext.Response; + + var memory = response.BodyWriter.GetMemory(4096); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes(new string('a', writeSize)); + fisrtPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(writeSize); + await response.BodyWriter.FlushAsync(); }, testContext)) { using (var connection = server.CreateConnection()) @@ -806,7 +1030,7 @@ public async Task ChunkedWithBothPipeAndStreamWorks() using (var server = new TestServer(async httpContext => { var response = httpContext.Response; - await response.StartAsync(); + var memory = response.BodyWriter.GetMemory(4096); var fisrtPartOfResponse = Encoding.ASCII.GetBytes("hello,"); fisrtPartOfResponse.CopyTo(memory); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs index 095d1287092a..5ed364f3b111 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/Http2/Http2StreamTests.cs @@ -2625,6 +2625,52 @@ await ExpectAsync(Http2FrameType.SETTINGS, [Fact] public async Task GetMemoryAdvance_Works() + { + var headers = new[] + { + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(httpContext => + { + var response = httpContext.Response; + var memory = response.BodyWriter.GetMemory(); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes("hello,"); + fisrtPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(6); + + memory = response.BodyWriter.GetMemory(); + var secondPartOfResponse = Encoding.ASCII.GetBytes(" world"); + secondPartOfResponse.CopyTo(memory); + response.BodyWriter.Advance(6); + return Task.CompletedTask; + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 37, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 12, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(2, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + + Assert.True(_helloWorldBytes.AsSpan().SequenceEqual(dataFrame.PayloadSequence.ToArray())); + } + + [Fact] + public async Task GetMemoryAdvance_WithStartAsync_Works() { var headers = new[] { @@ -2682,7 +2728,6 @@ await InitializeConnectionAsync(async httpContext => { var response = httpContext.Response; await response.StartAsync(); - var memory = response.BodyWriter.GetMemory(); Assert.Equal(4096, memory.Length); var fisrtPartOfResponse = Encoding.ASCII.GetBytes(new string('a', memory.Length)); @@ -2884,7 +2929,6 @@ await InitializeConnectionAsync(async httpContext => { var response = httpContext.Response; - await response.StartAsync(); var memory = response.BodyWriter.GetMemory(4096); var fisrtPartOfResponse = Encoding.ASCII.GetBytes("hello,"); @@ -2960,6 +3004,49 @@ await InitializeConnectionAsync(async httpContext => Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); } + [Fact] + public async Task WriteAsync_GetMemoryWithSizeHintAlwaysReturnsSameSizeStartAsync() + { + var headers = new[] +{ + new KeyValuePair(HeaderNames.Method, "GET"), + new KeyValuePair(HeaderNames.Path, "/"), + new KeyValuePair(HeaderNames.Scheme, "http"), + }; + await InitializeConnectionAsync(async httpContext => + { + var response = httpContext.Response; + + var memory = response.BodyWriter.GetMemory(0); + Assert.Equal(4096, memory.Length); + + memory = response.BodyWriter.GetMemory(4096); + Assert.Equal(4096, memory.Length); + + await Task.CompletedTask; + }); + + await StartStreamAsync(1, headers, endStream: true); + + var headersFrame = await ExpectAsync(Http2FrameType.HEADERS, + withLength: 55, + withFlags: (byte)Http2HeadersFrameFlags.END_HEADERS, + withStreamId: 1); + var dataFrame = await ExpectAsync(Http2FrameType.DATA, + withLength: 0, + withFlags: (byte)Http2DataFrameFlags.END_STREAM, + withStreamId: 1); + + await StopConnectionAsync(expectedLastStreamId: 1, ignoreNonGoAwayFrames: false); + + _hpackDecoder.Decode(headersFrame.PayloadSequence, endHeaders: false, handler: this); + + Assert.Equal(3, _decodedHeaders.Count); + Assert.Contains("date", _decodedHeaders.Keys, StringComparer.OrdinalIgnoreCase); + Assert.Equal("200", _decodedHeaders[HeaderNames.Status]); + Assert.Equal("0", _decodedHeaders[HeaderNames.ContentLength]); + } + [Fact] public async Task WriteAsync_BothPipeAndStreamWorks() { @@ -3037,7 +3124,7 @@ await InitializeConnectionAsync(async httpContext => { var response = httpContext.Response; response.ContentLength = 12; - await response.StartAsync(); + await Task.CompletedTask; void NonAsyncMethod() { @@ -3084,11 +3171,10 @@ public async Task ContentLengthWithGetMemoryWorks() new KeyValuePair(HeaderNames.Path, "/"), new KeyValuePair(HeaderNames.Scheme, "http"), }; - await InitializeConnectionAsync(async httpContext => + await InitializeConnectionAsync(httpContext => { var response = httpContext.Response; response.ContentLength = 12; - await response.StartAsync(); var memory = response.BodyWriter.GetMemory(4096); var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); @@ -3098,6 +3184,7 @@ await InitializeConnectionAsync(async httpContext => var secondPartOfResponse = Encoding.ASCII.GetBytes("World!"); secondPartOfResponse.CopyTo(memory.Slice(6)); response.BodyWriter.Advance(6); + return Task.CompletedTask; }); await StartStreamAsync(1, headers, endStream: true); @@ -3174,8 +3261,8 @@ await InitializeConnectionAsync(async httpContext => { var response = httpContext.Response; response.ContentLength = 54; - await response.StartAsync(); var memory = response.BodyWriter.GetMemory(4096); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes("hello,"); fisrtPartOfResponse.CopyTo(memory); response.BodyWriter.Advance(6); diff --git a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs index 55445d144fee..3abfea307cb3 100644 --- a/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs +++ b/src/Servers/Kestrel/test/InMemory.FunctionalTests/ResponseTests.cs @@ -3240,17 +3240,16 @@ await connection.Receive( } [Fact] - public async Task AdvanceWithTooLargeOfAValueThrowInvalidOperationException() + public async Task AdvanceNegativeValueThrowsArgumentOutOfRangeExceptionWithStart() { var testContext = new TestServiceContext(LoggerFactory); - using (var server = new TestServer(async httpContext => + using (var server = new TestServer(httpContext => { var response = httpContext.Response; - await response.StartAsync(); - - Assert.Throws(() => response.BodyWriter.Advance(1)); + Assert.Throws(() => response.BodyWriter.Advance(-1)); + return Task.CompletedTask; }, testContext)) { using (var connection = server.CreateConnection()) @@ -3263,9 +3262,7 @@ await connection.Send( await connection.Receive( "HTTP/1.1 200 OK", $"Date: {testContext.DateHeaderValue}", - "Transfer-Encoding: chunked", - "", - "0", + "Content-Length: 0", "", ""); } @@ -3275,13 +3272,15 @@ await connection.Receive( } [Fact] - public async Task GetMemoryBeforeStartAsyncThrows() + public async Task AdvanceWithTooLargeOfAValueThrowInvalidOperationException() { var testContext = new TestServiceContext(LoggerFactory); using (var server = new TestServer(httpContext => { - Assert.Throws(() => httpContext.Response.BodyWriter.GetMemory()); + var response = httpContext.Response; + + Assert.Throws(() => response.BodyWriter.Advance(1)); return Task.CompletedTask; }, testContext)) { @@ -3295,7 +3294,9 @@ await connection.Send( await connection.Receive( "HTTP/1.1 200 OK", $"Date: {testContext.DateHeaderValue}", - "Content-Length: 0", + "Transfer-Encoding: chunked", + "", + "0", "", ""); } @@ -3305,30 +3306,24 @@ await connection.Receive( } [Fact] - public async Task ContentLengthWithGetSpanWorks() + public async Task ContentLengthWithoutStartAsyncWithGetSpanWorks() { var testContext = new TestServiceContext(LoggerFactory); - using (var server = new TestServer(async httpContext => + using (var server = new TestServer(httpContext => { var response = httpContext.Response; response.ContentLength = 12; - await response.StartAsync(); - void NonAsyncMethod() - { - var span = response.BodyWriter.GetSpan(4096); - var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); - fisrtPartOfResponse.CopyTo(span); - response.BodyWriter.Advance(6); - - var secondPartOfResponse = Encoding.ASCII.GetBytes("World!"); - secondPartOfResponse.CopyTo(span.Slice(6)); - response.BodyWriter.Advance(6); - } - - NonAsyncMethod(); + var span = response.BodyWriter.GetSpan(4096); + var fisrtPartOfResponse = Encoding.ASCII.GetBytes("Hello "); + fisrtPartOfResponse.CopyTo(span); + response.BodyWriter.Advance(6); + var secondPartOfResponse = Encoding.ASCII.GetBytes("World!"); + secondPartOfResponse.CopyTo(span.Slice(6)); + response.BodyWriter.Advance(6); + return Task.CompletedTask; }, testContext)) { using (var connection = server.CreateConnection()) @@ -3359,6 +3354,7 @@ public async Task ContentLengthWithGetMemoryWorks() { var response = httpContext.Response; response.ContentLength = 12; + await response.StartAsync(); var memory = response.BodyWriter.GetMemory(4096); @@ -3461,7 +3457,7 @@ await connection.Receive( } [Fact] - public async Task ResponseBodyPipeCompleteWithoutExceptionDoesNotThrow() + public async Task ResponseBodyWriterCompleteWithoutExceptionDoesNotThrow() { using (var server = new TestServer(async httpContext => { @@ -3488,7 +3484,7 @@ await connection.Receive( } [Fact] - public async Task ResponseBodyPipeCompleteWithoutExceptionWritesDoNotThrow() + public async Task ResponseBodyWriterCompleteWithoutExceptionWritesDoNotThrow() { using (var server = new TestServer(async httpContext => { @@ -3516,6 +3512,217 @@ await connection.Receive( } } + [Fact] + public async Task ResponseAdvanceStateIsResetWithMultipleReqeusts() + { + var secondRequest = false; + using (var server = new TestServer(async httpContext => + { + if (secondRequest) + { + return; + } + + var memory = httpContext.Response.BodyWriter.GetMemory(); + Encoding.ASCII.GetBytes("a").CopyTo(memory); + httpContext.Response.BodyWriter.Advance(1); + await httpContext.Response.BodyWriter.FlushAsync(); + secondRequest = true; + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "1", + "a", + "0", + "", + ""); + + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ResponseStartCalledAndAutoChunkStateIsResetWithMultipleReqeusts() + { + using (var server = new TestServer(async httpContext => + { + var memory = httpContext.Response.BodyWriter.GetMemory(); + Encoding.ASCII.GetBytes("a").CopyTo(memory); + httpContext.Response.BodyWriter.Advance(1); + await httpContext.Response.BodyWriter.FlushAsync(); + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "1", + "a", + "0", + "", + ""); + + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "1", + "a", + "0", + "", + ""); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ResponseStartCalledStateIsResetWithMultipleReqeusts() + { + var flip = false; + using (var server = new TestServer(async httpContext => + { + if (flip) + { + httpContext.Response.ContentLength = 1; + var memory = httpContext.Response.BodyWriter.GetMemory(); + Encoding.ASCII.GetBytes("a").CopyTo(memory); + httpContext.Response.BodyWriter.Advance(1); + await httpContext.Response.BodyWriter.FlushAsync(); + } + else + { + var memory = httpContext.Response.BodyWriter.GetMemory(); + Encoding.ASCII.GetBytes("a").CopyTo(memory); + httpContext.Response.BodyWriter.Advance(1); + await httpContext.Response.BodyWriter.FlushAsync(); + } + flip = !flip; + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + for (var i = 0; i < 3; i++) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "1", + "a", + "0", + "", + ""); + + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 1", + "", + "a"); + } + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ResponseIsLeasedMemoryInvalidStateIsResetWithMultipleReqeusts() + { + var secondRequest = false; + using (var server = new TestServer(httpContext => + { + if (secondRequest) + { + Assert.Throws(() => httpContext.Response.BodyWriter.Advance(1)); + return Task.CompletedTask; + } + + var memory = httpContext.Response.BodyWriter.GetMemory(); + return Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + await server.StopAsync(); + } + } + [Fact] public async Task ResponsePipeWriterCompleteWithException() { @@ -3580,12 +3787,111 @@ await connection.Receive( } [Fact] - public async Task ResponseCompleteGetMemoryAdvanceInLoopDoesNotThrow() + public async Task ResponseCompleteGetMemoryReturnsRentedMemoryWithoutStartAsync() + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.BodyWriter.Complete(); + var memory = httpContext.Response.BodyWriter.GetMemory(); // Shouldn't throw + Assert.Equal(4096, memory.Length); + + await Task.CompletedTask; + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Content-Length: 0", + "", + ""); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ResponseGetMemoryAndStartAsyncMemoryReturnsNewMemory() + { + using (var server = new TestServer(async httpContext => + { + var memory = httpContext.Response.BodyWriter.GetMemory(); + Assert.Equal(4096, memory.Length); + + await httpContext.Response.StartAsync(); + // Original memory is disposed, don't compare against it. + + memory = httpContext.Response.BodyWriter.GetMemory(); + Assert.NotEqual(4096, memory.Length); + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "0", + "", + ""); + } + await server.StopAsync(); + } + } + + + [Fact] + public async Task ResponseGetMemoryAndStartAsyncAdvanceThrows() { using (var server = new TestServer(async httpContext => { + var memory = httpContext.Response.BodyWriter.GetMemory(); + await httpContext.Response.StartAsync(); + Assert.Throws(() => httpContext.Response.BodyWriter.Advance(1)); + + }, new TestServiceContext(LoggerFactory))) + { + using (var connection = server.CreateConnection()) + { + await connection.Send( + "GET / HTTP/1.1", + "Host:", + "", + ""); + await connection.Receive( + "HTTP/1.1 200 OK", + $"Date: {server.Context.DateHeaderValue}", + "Transfer-Encoding: chunked", + "", + "0", + "", + ""); + } + await server.StopAsync(); + } + } + + [Fact] + public async Task ResponseCompleteGetMemoryAdvanceInLoopDoesNotThrow() + { + using (var server = new TestServer(async httpContext => + { + httpContext.Response.BodyWriter.Complete(); for (var i = 0; i < 5; i++) { @@ -3653,12 +3959,12 @@ public async Task ResponseSetBodyToSameValueTwiceGetPipeMultipleTimesDifferentOb { var memoryStream = new MemoryStream(); httpContext.Response.Body = memoryStream; - var bodyPipe1 = httpContext.Response.BodyWriter; + var BodyWriter1 = httpContext.Response.BodyWriter; httpContext.Response.Body = memoryStream; - var bodyPipe2 = httpContext.Response.BodyWriter; + var BodyWriter2 = httpContext.Response.BodyWriter; - Assert.NotEqual(bodyPipe1, bodyPipe2); + Assert.NotEqual(BodyWriter1, BodyWriter2); await Task.CompletedTask; }, new TestServiceContext(LoggerFactory))) { @@ -3715,7 +4021,7 @@ await connection.Receive( } [Fact] - public async Task ResponseSetPipeAndBodyPipeIsWrapped() + public async Task ResponseSetPipeAndBodyWriterIsWrapped() { using (var server = new TestServer(async httpContext => { @@ -3746,7 +4052,7 @@ await connection.Receive( } [Fact] - public async Task ResponseWriteToBodyPipeAndStreamAllBlocksDisposed() + public async Task ResponseWriteToBodyWriterAndStreamAllBlocksDisposed() { using (var server = new TestServer(async httpContext => {