Skip to content

Commit 6167a3b

Browse files
committed
Add IForwardingDecryptorActivator and allow forward behavior to be altered in unit tests
1 parent f3803dd commit 6167a3b

File tree

6 files changed

+89
-89
lines changed

6 files changed

+89
-89
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System.Diagnostics.CodeAnalysis;
5+
using System;
6+
7+
namespace Microsoft.AspNetCore.DataProtection.Internal;
8+
9+
internal interface ITypeForwardingActivator : IActivator
10+
{
11+
bool TryForwardType(string originalTypeName, [NotNullWhen(true)] out Type? type);
12+
}

src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -466,26 +466,24 @@ IAuthenticatedEncryptorDescriptor IInternalXmlKeyManager.DeserializeDescriptorFr
466466
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
467467
private IAuthenticatedEncryptorDescriptorDeserializer CreateDeserializer(string descriptorDeserializerTypeName)
468468
{
469-
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
470-
? forwardedTypeName
471-
: descriptorDeserializerTypeName;
472-
var type = Type.GetType(resolvedTypeName, throwOnError: false);
473-
474-
if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer))
475-
{
476-
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
477-
}
478-
else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
479-
{
480-
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
481-
}
482-
else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
483-
{
484-
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
485-
}
486-
else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer))
469+
if (_activator is ITypeForwardingActivator forwardingActivator && forwardingActivator.TryForwardType(descriptorDeserializerTypeName, out var type))
487470
{
488-
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
471+
if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer))
472+
{
473+
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
474+
}
475+
else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
476+
{
477+
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
478+
}
479+
else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
480+
{
481+
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
482+
}
483+
else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer))
484+
{
485+
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
486+
}
489487
}
490488

491489
return _activator.CreateInstance<IAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);

src/DataProtection/DataProtection/src/TypeForwardingActivator.cs

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,14 @@
33

44
using System;
55
using System.Diagnostics.CodeAnalysis;
6+
using Microsoft.AspNetCore.DataProtection.Internal;
67
using Microsoft.Extensions.Logging;
78
using Microsoft.Extensions.Logging.Abstractions;
89

910
namespace Microsoft.AspNetCore.DataProtection;
1011

1112
#pragma warning disable CA1852 // Seal internal types
12-
internal class TypeForwardingActivator : SimpleActivator
13+
internal class TypeForwardingActivator : SimpleActivator, ITypeForwardingActivator
1314
#pragma warning restore CA1852 // Seal internal types
1415
{
1516
private const string OldNamespace = "Microsoft.AspNet.DataProtection";
@@ -30,31 +31,47 @@ public TypeForwardingActivator(IServiceProvider services, ILoggerFactory loggerF
3031
public override object CreateInstance([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type expectedBaseType, string originalTypeName)
3132
=> CreateInstance(expectedBaseType, originalTypeName, out var _);
3233

34+
public bool TryForwardType(string originalTypeName, [NotNullWhen(true)] out Type? type)
35+
=> TryForwardType(originalTypeName, out type, out _);
36+
3337
// for testing
34-
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used with forwarded types that are referenced by DataProtection assembly.")]
3538
internal object CreateInstance([DynamicallyAccessedMembers(DynamicallyAccessedMemberTypes.PublicConstructors)] Type expectedBaseType, string originalTypeName, out bool forwarded)
3639
{
37-
if (TryForwardTypeName(originalTypeName, out var forwardedTypeName))
40+
if (!TryForwardType(originalTypeName, out var forwardedType, out forwarded))
3841
{
39-
var type = Type.GetType(forwardedTypeName, false);
40-
if (type != null)
41-
{
42-
if (_logger.IsEnabled(LogLevel.Debug))
43-
{
44-
_logger.LogDebug("Forwarded activator type request from {FromType} to {ToType}",
45-
originalTypeName,
46-
forwardedTypeName);
47-
}
48-
forwarded = true;
49-
return base.CreateInstance(expectedBaseType, forwardedTypeName);
50-
}
42+
return base.CreateInstance(expectedBaseType, originalTypeName);
43+
}
44+
45+
var forwardedTypeName = forwardedType.AssemblyQualifiedName!;
46+
if (_logger.IsEnabled(LogLevel.Debug))
47+
{
48+
_logger.LogDebug("Forwarded activator type request from {FromType} to {ToType}",
49+
originalTypeName,
50+
forwardedTypeName);
5151
}
52+
return base.CreateInstance(expectedBaseType, forwardedTypeName);
53+
}
5254

53-
forwarded = false;
54-
return base.CreateInstance(expectedBaseType, originalTypeName);
55+
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used with forwarded types that are referenced by DataProtection assembly.")]
56+
private static bool TryForwardType(string originalTypeName, [NotNullWhen(true)] out Type? type, out bool forwarded)
57+
{
58+
forwarded = TryForwardTypeName(originalTypeName, out var resolvedTypeName);
59+
try
60+
{
61+
// Some exceptions are thrown regardless of the value of throwOnError.
62+
// For example, if the type is found but cannot be loaded,
63+
// a System.TypeLoadException is thrown even if throwOnError is false.
64+
type = Type.GetType(resolvedTypeName, throwOnError: false);
65+
return type is not null;
66+
}
67+
catch
68+
{
69+
type = default;
70+
return false;
71+
}
5572
}
5673

57-
internal static bool TryForwardTypeName(string originalTypeName, out string forwardedTypeName)
74+
private static bool TryForwardTypeName(string originalTypeName, out string forwardedTypeName)
5875
{
5976
forwardedTypeName = originalTypeName;
6077

src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs

Lines changed: 17 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ namespace Microsoft.AspNetCore.DataProtection.XmlEncryption;
1616

1717
internal static unsafe class XmlEncryptionExtensions
1818
{
19-
// Used for testing edge case assembly loading errors
20-
internal static Func<string, Type?> _getType = GetType;
21-
2219
public static XElement DecryptElement(this XElement element, IActivator activator)
2320
{
2421
// If no decryption necessary, return original element.
@@ -72,54 +69,29 @@ public static XElement DecryptElement(this XElement element, IActivator activato
7269

7370
private static IXmlDecryptor CreateDecryptor(IActivator activator, string decryptorTypeName)
7471
{
75-
if (!TryGetDecryptorType(decryptorTypeName, out var type))
76-
{
77-
return activator.CreateInstance<IXmlDecryptor>(decryptorTypeName);
78-
}
79-
80-
if (type == typeof(DpapiNGXmlDecryptor))
72+
if (activator is ITypeForwardingActivator typeForwarder && typeForwarder.TryForwardType(decryptorTypeName, out var type))
8173
{
82-
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
83-
}
84-
else if (type == typeof(DpapiXmlDecryptor))
85-
{
86-
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
87-
}
88-
else if (type == typeof(EncryptedXmlDecryptor))
89-
{
90-
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
91-
}
92-
else if (type == typeof(NullXmlDecryptor))
93-
{
94-
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
74+
if (type == typeof(DpapiNGXmlDecryptor))
75+
{
76+
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
77+
}
78+
else if (type == typeof(DpapiXmlDecryptor))
79+
{
80+
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
81+
}
82+
else if (type == typeof(EncryptedXmlDecryptor))
83+
{
84+
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
85+
}
86+
else if (type == typeof(NullXmlDecryptor))
87+
{
88+
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
89+
}
9590
}
9691

9792
return activator.CreateInstance<IXmlDecryptor>(decryptorTypeName);
9893
}
9994

100-
private static bool TryGetDecryptorType(string decryptorTypeName, [NotNullWhen(true)] out Type? type)
101-
{
102-
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
103-
? forwardedTypeName
104-
: decryptorTypeName;
105-
try
106-
{
107-
// Some exceptions are thrown regardless of the value of throwOnError.
108-
// For example, if the type is found but cannot be loaded,
109-
// a System.TypeLoadException is thrown even if throwOnError is false.
110-
type = _getType(resolvedTypeName);
111-
return type is not null;
112-
}
113-
catch
114-
{
115-
type = default;
116-
return false;
117-
}
118-
}
119-
120-
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
121-
private static Type? GetType(string typeName) => Type.GetType(typeName, throwOnError: false);
122-
12395
public static XElement? EncryptIfNecessary(this IXmlEncryptor encryptor, XElement element)
12496
{
12597
// If no encryption is necessary, return null.

src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/MockExtensions.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static void ReturnDescriptorGivenDeserializerTypeNameAndInput(this Mock<I
3838
/// Sets up a mock such that given the name of a decryptor class and the XML node that class's
3939
/// Decrypt method should expect returns the specified XML elmeent.
4040
/// </summary>
41-
public static void ReturnDecryptedElementGivenDecryptorTypeNameAndInput(this Mock<IActivator> mockActivator, string typeName, string expectedInputXml, string outputXml)
41+
public static void ReturnDecryptedElementGivenDecryptorTypeNameAndInput<T>(this Mock<T> mockActivator, string typeName, string expectedInputXml, string outputXml) where T : class, IActivator
4242
{
4343
mockActivator
4444
.Setup(o => o.CreateInstance(typeof(IXmlDecryptor), typeName))

src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,21 +53,22 @@ public void DecryptElement_RootNodeRequiresDecryption_Success()
5353
public void DecryptElement_MissingDecryptorType_Success()
5454
{
5555
// Arrange
56-
// We want to simulate an error loading the specified decryptor type, i.e.
57-
// Could not load file or assembly 'Azure.Extensions.AspNetCore.DataProtection.Keys,
58-
// Version=1.2.2.0, Culture=neutral, PublicKeyToken=92742159e12e44c8' or one of its dependencies.
59-
XmlEncryptionExtensions._getType = _ => throw new TypeLoadException();
60-
6156
var decryptorTypeName = typeof(MyXmlDecryptor).AssemblyQualifiedName;
6257

6358
var original = XElement.Parse(@$"
6459
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
6560
<node />
6661
</x:encryptedSecret>");
6762

68-
var mockActivator = new Mock<IActivator>();
63+
var mockActivator = new Mock<ITypeForwardingActivator>();
6964
mockActivator.ReturnDecryptedElementGivenDecryptorTypeNameAndInput(decryptorTypeName, "<node />", "<newNode />");
7065

66+
// We want to simulate an error loading the specified decryptor type, i.e.
67+
// Could not load file or assembly 'Azure.Extensions.AspNetCore.DataProtection.Keys,
68+
// Version=1.2.2.0, Culture=neutral, PublicKeyToken=92742159e12e44c8' or one of its dependencies.
69+
Type type;
70+
mockActivator.Setup(x => x.TryForwardType(decryptorTypeName, out type)).Returns(false);
71+
7172
var serviceCollection = new ServiceCollection();
7273
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
7374
var services = serviceCollection.BuildServiceProvider();

0 commit comments

Comments
 (0)