diff --git a/src/Http/Routing/perf/Microbenchmarks/EndpointRoutingShortCircuitBenchmark.cs b/src/Http/Routing/perf/Microbenchmarks/EndpointRoutingShortCircuitBenchmark.cs new file mode 100644 index 000000000000..9fc01b45acde --- /dev/null +++ b/src/Http/Routing/perf/Microbenchmarks/EndpointRoutingShortCircuitBenchmark.cs @@ -0,0 +1,113 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Diagnostics; +using BenchmarkDotNet.Attributes; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; +using Microsoft.AspNetCore.Routing.Matching; +using Microsoft.AspNetCore.Routing.ShortCircuit; +using Microsoft.Extensions.Logging.Abstractions; +using Microsoft.Extensions.Options; +using Microsoft.Extensions.Primitives; + +namespace Microsoft.AspNetCore.Routing; + +public class EndpointRoutingShortCircuitBenchmark +{ + private EndpointRoutingMiddleware _normalEndpointMiddleware; + private EndpointRoutingMiddleware _shortCircuitEndpointMiddleware; + + [GlobalSetup] + public void Setup() + { + var normalEndpoint = new Endpoint(context => Task.CompletedTask, new EndpointMetadataCollection(), "normal"); + + _normalEndpointMiddleware = new EndpointRoutingMiddleware( + new BenchmarkMatcherFactory(normalEndpoint), + NullLogger.Instance, + new BenchmarkEndpointRouteBuilder(), + new BenchmarkEndpointDataSource(), + new DiagnosticListener("benchmark"), + Options.Create(new RouteOptions()), + context => Task.CompletedTask); + + var shortCircuitEndpoint = new Endpoint(context => Task.CompletedTask, new EndpointMetadataCollection(new ShortCircuitMetadata(200)), "shortcircuit"); + + _shortCircuitEndpointMiddleware = new EndpointRoutingMiddleware( + new BenchmarkMatcherFactory(shortCircuitEndpoint), + NullLogger.Instance, + new BenchmarkEndpointRouteBuilder(), + new BenchmarkEndpointDataSource(), + new DiagnosticListener("benchmark"), + Options.Create(new RouteOptions()), + context => Task.CompletedTask); + + } + + [Benchmark] + public async Task NormalEndpoint() + { + var context = new DefaultHttpContext(); + await _normalEndpointMiddleware.Invoke(context); + } + + [Benchmark] + public async Task ShortCircuitEndpoint() + { + var context = new DefaultHttpContext(); + await _shortCircuitEndpointMiddleware.Invoke(context); + } +} + +internal class BenchmarkMatcherFactory : MatcherFactory +{ + private readonly Endpoint _endpoint; + + public BenchmarkMatcherFactory(Endpoint endpoint) + { + _endpoint = endpoint; + } + + public override Matcher CreateMatcher(EndpointDataSource dataSource) + { + return new BenchmarkMatcher(_endpoint); + } + + internal class BenchmarkMatcher : Matcher + { + private Endpoint _endpoint; + + public BenchmarkMatcher(Endpoint endpoint) + { + _endpoint = endpoint; + } + + public override Task MatchAsync(HttpContext httpContext) + { + httpContext.SetEndpoint(_endpoint); + return Task.CompletedTask; + } + } +} + +internal class BenchmarkEndpointRouteBuilder : IEndpointRouteBuilder +{ + public IServiceProvider ServiceProvider => throw new NotImplementedException(); + + public ICollection DataSources => new List(); + + public IApplicationBuilder CreateApplicationBuilder() + { + throw new NotImplementedException(); + } +} +internal class BenchmarkEndpointDataSource : EndpointDataSource +{ + public override IReadOnlyList Endpoints => throw new NotImplementedException(); + + public override IChangeToken GetChangeToken() + { + throw new NotImplementedException(); + } +} diff --git a/src/Http/Routing/src/EndpointMiddleware.cs b/src/Http/Routing/src/EndpointMiddleware.cs index cf91b97f7426..3658b5084be6 100644 --- a/src/Http/Routing/src/EndpointMiddleware.cs +++ b/src/Http/Routing/src/EndpointMiddleware.cs @@ -33,6 +33,7 @@ public Task Invoke(HttpContext httpContext) var endpoint = httpContext.GetEndpoint(); if (endpoint is not null) { + // This check should be kept in sync with the one in EndpointRoutingMiddleware if (!_routeOptions.SuppressCheckForUnhandledSecurityMetadata) { if (endpoint.Metadata.GetMetadata() is not null && diff --git a/src/Http/Routing/src/EndpointRoutingMiddleware.cs b/src/Http/Routing/src/EndpointRoutingMiddleware.cs index 709936791705..ebd02ff64939 100644 --- a/src/Http/Routing/src/EndpointRoutingMiddleware.cs +++ b/src/Http/Routing/src/EndpointRoutingMiddleware.cs @@ -4,9 +4,13 @@ using System.Diagnostics; using System.Diagnostics.CodeAnalysis; using System.Runtime.CompilerServices; +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Cors.Infrastructure; using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing.Matching; +using Microsoft.AspNetCore.Routing.ShortCircuit; using Microsoft.Extensions.Logging; +using Microsoft.Extensions.Options; namespace Microsoft.AspNetCore.Routing; @@ -19,7 +23,7 @@ internal sealed partial class EndpointRoutingMiddleware private readonly EndpointDataSource _endpointDataSource; private readonly DiagnosticListener _diagnosticListener; private readonly RequestDelegate _next; - + private readonly RouteOptions _routeOptions; private Task? _initializationTask; public EndpointRoutingMiddleware( @@ -28,6 +32,7 @@ public EndpointRoutingMiddleware( IEndpointRouteBuilder endpointRouteBuilder, EndpointDataSource rootCompositeEndpointDataSource, DiagnosticListener diagnosticListener, + IOptions routeOptions, RequestDelegate next) { ArgumentNullException.ThrowIfNull(endpointRouteBuilder); @@ -36,6 +41,7 @@ public EndpointRoutingMiddleware( _logger = logger ?? throw new ArgumentNullException(nameof(logger)); _diagnosticListener = diagnosticListener ?? throw new ArgumentNullException(nameof(diagnosticListener)); _next = next ?? throw new ArgumentNullException(nameof(next)); + _routeOptions = routeOptions.Value; // rootCompositeEndpointDataSource is a constructor parameter only so it always gets disposed by DI. This ensures that any // disposable EndpointDataSources also get disposed. _endpointDataSource is a component of rootCompositeEndpointDataSource. @@ -102,6 +108,12 @@ private Task SetRoutingAndContinue(HttpContext httpContext) } Log.MatchSuccess(_logger, endpoint); + + var shortCircuitMetadata = endpoint.Metadata.GetMetadata(); + if (shortCircuitMetadata is not null) + { + return ExecuteShortCircuit(shortCircuitMetadata, endpoint, httpContext); + } } return _next(httpContext); @@ -115,6 +127,75 @@ static void Write(DiagnosticListener diagnosticListener, HttpContext httpContext } } + private Task ExecuteShortCircuit(ShortCircuitMetadata shortCircuitMetadata, Endpoint endpoint, HttpContext httpContext) + { + // This check should be kept in sync with the one in EndpointMiddleware + if (!_routeOptions.SuppressCheckForUnhandledSecurityMetadata) + { + if (endpoint.Metadata.GetMetadata() is not null) + { + ThrowCannotShortCircuitAnAuthRouteException(endpoint); + } + + if (endpoint.Metadata.GetMetadata() is not null) + { + ThrowCannotShortCircuitACorsRouteException(endpoint); + } + } + + if (shortCircuitMetadata.StatusCode.HasValue) + { + httpContext.Response.StatusCode = shortCircuitMetadata.StatusCode.Value; + } + + if (endpoint.RequestDelegate is not null) + { + if (!_logger.IsEnabled(LogLevel.Information)) + { + // Avoid the AwaitRequestTask state machine allocation if logging is disabled. + return endpoint.RequestDelegate(httpContext); + } + + Log.ExecutingEndpoint(_logger, endpoint); + + try + { + var requestTask = endpoint.RequestDelegate(httpContext); + if (!requestTask.IsCompletedSuccessfully) + { + return AwaitRequestTask(endpoint, requestTask, _logger); + } + } + catch + { + Log.ExecutedEndpoint(_logger, endpoint); + throw; + } + + Log.ExecutedEndpoint(_logger, endpoint); + + return Task.CompletedTask; + + static async Task AwaitRequestTask(Endpoint endpoint, Task requestTask, ILogger logger) + { + try + { + await requestTask; + } + finally + { + Log.ExecutedEndpoint(logger, endpoint); + } + } + + } + else + { + Log.ShortCircuitedEndpoint(_logger, endpoint); + } + return Task.CompletedTask; + } + // Initialization is async to avoid blocking threads while reflection and things // of that nature take place. // @@ -165,6 +246,18 @@ private Task InitializeCoreAsync() } } + private static void ThrowCannotShortCircuitAnAuthRouteException(Endpoint endpoint) + { + throw new InvalidOperationException($"Endpoint {endpoint.DisplayName} contains authorization metadata, " + + "but this endpoint is marked with short circuit and it will execute on Routing Middleware."); + } + + private static void ThrowCannotShortCircuitACorsRouteException(Endpoint endpoint) + { + throw new InvalidOperationException($"Endpoint {endpoint.DisplayName} contains CORS metadata, " + + "but this endpoint is marked with short circuit and it will execute on Routing Middleware."); + } + private static partial class Log { public static void MatchSuccess(ILogger logger, Endpoint endpoint) @@ -181,5 +274,14 @@ public static void MatchSkipped(ILogger logger, Endpoint endpoint) [LoggerMessage(3, LogLevel.Debug, "Endpoint '{EndpointName}' already set, skipping route matching.", EventName = "MatchingSkipped")] private static partial void MatchingSkipped(ILogger logger, string? endpointName); + + [LoggerMessage(4, LogLevel.Information, "The endpoint '{EndpointName}' is being executed without running additional middleware.", EventName = "ExecutingEndpoint")] + public static partial void ExecutingEndpoint(ILogger logger, Endpoint endpointName); + + [LoggerMessage(5, LogLevel.Information, "The endpoint '{EndpointName}' has been executed without running additional middleware.", EventName = "ExecutedEndpoint")] + public static partial void ExecutedEndpoint(ILogger logger, Endpoint endpointName); + + [LoggerMessage(6, LogLevel.Information, "The endpoint '{EndpointName}' is being short circuited without running additional middleware or producing a response.", EventName = "ShortCircuitedEndpoint")] + public static partial void ShortCircuitedEndpoint(ILogger logger, Endpoint endpointName); } } diff --git a/src/Http/Routing/src/PublicAPI.Unshipped.txt b/src/Http/Routing/src/PublicAPI.Unshipped.txt index fff1ced433fd..d73a26d7c26d 100644 --- a/src/Http/Routing/src/PublicAPI.Unshipped.txt +++ b/src/Http/Routing/src/PublicAPI.Unshipped.txt @@ -1,4 +1,8 @@ #nullable enable Microsoft.AspNetCore.Routing.RouteHandlerServices static Microsoft.AspNetCore.Routing.RouteHandlerServices.Map(Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! endpoints, string! pattern, System.Delegate! handler, System.Collections.Generic.IEnumerable! httpMethods, System.Func! populateMetadata, System.Func! createRequestDelegate) -> Microsoft.AspNetCore.Builder.RouteHandlerBuilder! +Microsoft.AspNetCore.Builder.RouteShortCircuitEndpointConventionBuilderExtensions +Microsoft.AspNetCore.Routing.RouteShortCircuitEndpointRouteBuilderExtensions +static Microsoft.AspNetCore.Builder.RouteShortCircuitEndpointConventionBuilderExtensions.ShortCircuit(this Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! builder, int? statusCode = null) -> Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! +static Microsoft.AspNetCore.Routing.RouteShortCircuitEndpointRouteBuilderExtensions.MapShortCircuit(this Microsoft.AspNetCore.Routing.IEndpointRouteBuilder! builder, int statusCode, params string![]! routePrefixes) -> Microsoft.AspNetCore.Builder.IEndpointConventionBuilder! static Microsoft.Extensions.DependencyInjection.RoutingServiceCollectionExtensions.AddRoutingCore(this Microsoft.Extensions.DependencyInjection.IServiceCollection! services) -> Microsoft.Extensions.DependencyInjection.IServiceCollection! diff --git a/src/Http/Routing/src/ShortCircuit/RouteShortCircuitEndpointConventionBuilderExtensions.cs b/src/Http/Routing/src/ShortCircuit/RouteShortCircuitEndpointConventionBuilderExtensions.cs new file mode 100644 index 000000000000..15a867bd1278 --- /dev/null +++ b/src/Http/Routing/src/ShortCircuit/RouteShortCircuitEndpointConventionBuilderExtensions.cs @@ -0,0 +1,39 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Routing.ShortCircuit; + +namespace Microsoft.AspNetCore.Builder; + +/// +/// Short circuit extension methods for . +/// +public static class RouteShortCircuitEndpointConventionBuilderExtensions +{ + private static readonly ShortCircuitMetadata _200ShortCircuitMetadata = new ShortCircuitMetadata(200); + private static readonly ShortCircuitMetadata _401ShortCircuitMetadata = new ShortCircuitMetadata(401); + private static readonly ShortCircuitMetadata _404ShortCircuitMetadata = new ShortCircuitMetadata(404); + private static readonly ShortCircuitMetadata _nullShortCircuitMetadata = new ShortCircuitMetadata(null); + + /// + /// Short circuit the endpoint(s). + /// The execution of the endpoint will happen in UseRouting middleware instead of UseEndpoint. + /// + /// The endpoint convention builder. + /// The status code to set in the response. + /// The original convention builder parameter. + public static IEndpointConventionBuilder ShortCircuit(this IEndpointConventionBuilder builder, int? statusCode = null) + { + var metadata = statusCode switch + { + 200 => _200ShortCircuitMetadata, + 401 => _401ShortCircuitMetadata, + 404 => _404ShortCircuitMetadata, + null => _nullShortCircuitMetadata, + _ => new ShortCircuitMetadata(statusCode) + }; + + builder.Add(b => b.Metadata.Add(metadata)); + return builder; + } +} diff --git a/src/Http/Routing/src/ShortCircuit/RouteShortCircuitEndpointRouteBuilderExtensions.cs b/src/Http/Routing/src/ShortCircuit/RouteShortCircuitEndpointRouteBuilderExtensions.cs new file mode 100644 index 000000000000..3d0e3f6bc4a6 --- /dev/null +++ b/src/Http/Routing/src/ShortCircuit/RouteShortCircuitEndpointRouteBuilderExtensions.cs @@ -0,0 +1,68 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Http; + +namespace Microsoft.AspNetCore.Routing; + +/// +/// Provides extension methods for to add short circuited endpoints. +/// +public static class RouteShortCircuitEndpointRouteBuilderExtensions +{ + private static readonly RequestDelegate _shortCircuitDelegate = (context) => Task.CompletedTask; + /// + /// Adds a to the that matches HTTP GET requests + /// for the specified prefixes. + /// + ///The to add the route to. + /// The status code to set in the response. + /// An array of route prefixes to be short circuited. + /// A that can be used to further customize the endpoint. + public static IEndpointConventionBuilder MapShortCircuit(this IEndpointRouteBuilder builder, int statusCode, params string[] routePrefixes) + { + var group = builder.MapGroup(""); + foreach (var routePrefix in routePrefixes) + { + string route; + if (routePrefix.EndsWith("/", StringComparison.OrdinalIgnoreCase)) + { + route = $"{routePrefix}{{**catchall}}"; + } + else + { + route = $"{routePrefix}/{{**catchall}}"; + } + group.Map(route, _shortCircuitDelegate) + .ShortCircuit(statusCode) + .Add(endpoint => + { + endpoint.DisplayName = $"ShortCircuit {endpoint.DisplayName}"; + ((RouteEndpointBuilder)endpoint).Order = int.MaxValue; + }); + } + + return new EndpointConventionBuilder(group); + } + + private sealed class EndpointConventionBuilder : IEndpointConventionBuilder + { + private readonly IEndpointConventionBuilder _endpointConventionBuilder; + + public EndpointConventionBuilder(IEndpointConventionBuilder endpointConventionBuilder) + { + _endpointConventionBuilder = endpointConventionBuilder; + } + + public void Add(Action convention) + { + _endpointConventionBuilder.Add(convention); + } + + public void Finally(Action finalConvention) + { + _endpointConventionBuilder.Finally(finalConvention); + } + } +} diff --git a/src/Http/Routing/src/ShortCircuit/ShortCircuitMetadata.cs b/src/Http/Routing/src/ShortCircuit/ShortCircuitMetadata.cs new file mode 100644 index 000000000000..f75345a8ac23 --- /dev/null +++ b/src/Http/Routing/src/ShortCircuit/ShortCircuitMetadata.cs @@ -0,0 +1,14 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +namespace Microsoft.AspNetCore.Routing.ShortCircuit; + +internal sealed class ShortCircuitMetadata +{ + public int? StatusCode { get; } + + public ShortCircuitMetadata(int? statusCode) + { + StatusCode = statusCode; + } +} diff --git a/src/Http/Routing/test/FunctionalTests/ShortCircuitTests.cs b/src/Http/Routing/test/FunctionalTests/ShortCircuitTests.cs new file mode 100644 index 000000000000..d65cd8f8d153 --- /dev/null +++ b/src/Http/Routing/test/FunctionalTests/ShortCircuitTests.cs @@ -0,0 +1,98 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System.Net; +using Microsoft.AspNetCore.Builder; +using Microsoft.AspNetCore.Hosting; +using Microsoft.AspNetCore.TestHost; +using Microsoft.Extensions.DependencyInjection; +using Microsoft.Extensions.Hosting; + +namespace Microsoft.AspNetCore.Routing.FunctionalTests; + +public class ShortCircuitTests +{ + [Fact] + public async Task ShortCircuitTest() + { + using var host = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .Configure(app => + { + app.UseRouting(); + app.Use((context, next) => + { + context.Response.Headers["NotSet"] = "No!"; + return next(context); + }); + app.UseEndpoints(b => + { + b.Map("/shortcircuit", context => + { + context.Response.Headers["Set"] = "Yes!"; + return Task.CompletedTask; + }) + .ShortCircuit(); + }); + }) + .UseTestServer(); + }) + .ConfigureServices(services => + { + services.AddRouting(); + }) + .Build(); + + using var server = host.GetTestServer(); + + await host.StartAsync(); + + var response = await server.CreateRequest("/shortcircuit").SendAsync("GET"); + + Assert.True(response.Headers.Contains("Set")); + Assert.False(response.Headers.Contains("NotSet")); + } + + [Fact] + public async Task MapShortCircuitTest() + { + using var host = new HostBuilder() + .ConfigureWebHost(webHostBuilder => + { + webHostBuilder + .Configure(app => + { + app.UseRouting(); + app.Use((context, next) => + { + context.Response.Headers["NotSet"] = "No!"; + return next(context); + }); + app.UseEndpoints(b => + { + b.MapShortCircuit((int)HttpStatusCode.NotFound, "/shortcircuit"); + }); + }) + .UseTestServer(); + }) + .ConfigureServices(services => + { + services.AddRouting(); + }) + .Build(); + + using var server = host.GetTestServer(); + + await host.StartAsync(); + + var response1 = await server.CreateRequest("/shortcircuit").SendAsync("GET"); + Assert.Equal(HttpStatusCode.NotFound, response1.StatusCode); + Assert.False(response1.Headers.Contains("NotSet")); + + var response2 = await server.CreateRequest("/shortcircuit/whatever").SendAsync("GET"); + Assert.Equal(HttpStatusCode.NotFound, response2.StatusCode); + Assert.False(response2.Headers.Contains("NotSet")); + } +} diff --git a/src/Http/Routing/test/UnitTests/EndpointRoutingMiddlewareTest.cs b/src/Http/Routing/test/UnitTests/EndpointRoutingMiddlewareTest.cs index ba1506102c63..03a35f4689e3 100644 --- a/src/Http/Routing/test/UnitTests/EndpointRoutingMiddlewareTest.cs +++ b/src/Http/Routing/test/UnitTests/EndpointRoutingMiddlewareTest.cs @@ -10,6 +10,7 @@ using Microsoft.Extensions.Logging; using Microsoft.Extensions.Logging.Abstractions; using Microsoft.Extensions.Logging.Testing; +using Microsoft.Extensions.Options; using Moq; namespace Microsoft.AspNetCore.Routing; @@ -155,6 +156,74 @@ public async Task Invoke_InitializationFailure_AllowsReinitialization() .Verify(f => f.CreateMatcher(It.IsAny()), Times.Exactly(2)); } + [Fact] + public async Task ShortCircuitWithoutStatusCode() + { + // Arrange + var httpContext = CreateHttpContext(); + + var middleware = CreateMiddleware( + matcherFactory: new ShortCircuitMatcherFactory(null, false, false), + next: context => + { + // should not be reached + throw new Exception(); + }); + + // Act + await middleware.Invoke(httpContext); + + // Assert + Assert.True((bool)httpContext.Items["ShortCircuit"]); + Assert.Equal(200, httpContext.Response.StatusCode); + } + + [Fact] + public async Task ShortCircuitWithStatusCode() + { + // Arrange + var httpContext = CreateHttpContext(); + + var middleware = CreateMiddleware( + matcherFactory: new ShortCircuitMatcherFactory(404, false, false), + next: context => + { + // should not be reached + throw new Exception(); + }); + + // Act + await middleware.Invoke(httpContext); + + // Assert + Assert.True((bool)httpContext.Items["ShortCircuit"]); + Assert.Equal(404, httpContext.Response.StatusCode); + } + + [InlineData(404, true, true)] + [InlineData(404, false, true)] + [InlineData(404, true, false)] + [InlineData(null, true, true)] + [InlineData(null, false, true)] + [InlineData(null, true, false)] + [Theory] + public async Task ThrowIfSecurityMetadataPresent(int? statusCode, bool hasAuthMetadata, bool hasCorsMetadata) + { + // Arrange + var httpContext = CreateHttpContext(); + + var middleware = CreateMiddleware( + matcherFactory: new ShortCircuitMatcherFactory(statusCode, hasAuthMetadata, hasCorsMetadata), + next: context => + { + // should not be reached + throw new Exception(); + }); + + // Act + await Assert.ThrowsAsync(() => middleware.Invoke(httpContext)); + } + private HttpContext CreateHttpContext() { var httpContext = new DefaultHttpContext @@ -182,6 +251,7 @@ private EndpointRoutingMiddleware CreateMiddleware( new DefaultEndpointRouteBuilder(Mock.Of()), new DefaultEndpointDataSource(), listener, + Options.Create(new RouteOptions()), next); return middleware; diff --git a/src/Http/Routing/test/UnitTests/TestConstants.cs b/src/Http/Routing/test/UnitTests/TestConstants.cs index 3aba3f9d71e7..61a07bd86392 100644 --- a/src/Http/Routing/test/UnitTests/TestConstants.cs +++ b/src/Http/Routing/test/UnitTests/TestConstants.cs @@ -1,4 +1,4 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. using Microsoft.AspNetCore.Http; @@ -8,4 +8,9 @@ namespace Microsoft.AspNetCore.Routing; public static class TestConstants { internal static readonly RequestDelegate EmptyRequestDelegate = (context) => Task.CompletedTask; + internal static readonly RequestDelegate ShortCircuitRequestDelegate = (context) => + { + context.Items["ShortCircuit"] = true; + return Task.CompletedTask; + }; } diff --git a/src/Http/Routing/test/UnitTests/TestObjects/TestMatcherFactory.cs b/src/Http/Routing/test/UnitTests/TestObjects/TestMatcherFactory.cs index ece8696ebad2..240462a3b115 100644 --- a/src/Http/Routing/test/UnitTests/TestObjects/TestMatcherFactory.cs +++ b/src/Http/Routing/test/UnitTests/TestObjects/TestMatcherFactory.cs @@ -1,7 +1,11 @@ -// Licensed to the .NET Foundation under one or more agreements. +// Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. +using Microsoft.AspNetCore.Authorization; +using Microsoft.AspNetCore.Cors.Infrastructure; +using Microsoft.AspNetCore.Http; using Microsoft.AspNetCore.Routing.Matching; +using Microsoft.AspNetCore.Routing.ShortCircuit; namespace Microsoft.AspNetCore.Routing.TestObjects; @@ -19,3 +23,63 @@ public override Matcher CreateMatcher(EndpointDataSource dataSource) return new TestMatcher(_isHandled); } } + +internal class ShortCircuitMatcherFactory : MatcherFactory +{ + private readonly int? _statusCode; + private readonly bool _hasAuthMetadata; + private readonly bool _hasCorsMetadata; + + public ShortCircuitMatcherFactory(int? statusCode, bool hasAuthMetadata, bool hasCorsMetadata) + { + _statusCode = statusCode; + _hasAuthMetadata = hasAuthMetadata; + _hasCorsMetadata = hasCorsMetadata; + } + + public override Matcher CreateMatcher(EndpointDataSource dataSource) + { + return new ShortCircuitMatcher(_statusCode, _hasAuthMetadata, _hasCorsMetadata); + } +} + +internal class ShortCircuitMatcher : Matcher +{ + private readonly int? _statusCode; + private readonly bool _hasAuthMetadata; + private readonly bool _hasCorsMetadata; + + public ShortCircuitMatcher(int? statusCode, bool hasAuthMetadata, bool hasCorsMetadata) + { + _statusCode = statusCode; + _hasAuthMetadata = hasAuthMetadata; + _hasCorsMetadata = hasCorsMetadata; + } + + public override Task MatchAsync(HttpContext httpContext) + { + var metadataList = new List + { + new ShortCircuitMetadata(_statusCode) + }; + + if (_hasAuthMetadata) + { + metadataList.Add(new AuthorizeAttribute()); + } + + if (_hasCorsMetadata) + { + metadataList.Add(new CorsMetadata()); + } + + var metadata = new EndpointMetadataCollection(metadataList); + httpContext.SetEndpoint(new Endpoint(TestConstants.ShortCircuitRequestDelegate, metadata, "Short Circuit Endpoint")); + + return Task.CompletedTask; + } +} + +internal class CorsMetadata : ICorsMetadata +{ +}