Skip to content

Commit c695ea0

Browse files
JamesNKkhellang
authored andcommitted
Check type name for known value before getting type
1 parent f3803dd commit c695ea0

File tree

6 files changed

+140
-45
lines changed

6 files changed

+140
-45
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Diagnostics.CodeAnalysis;
6+
7+
namespace Microsoft.AspNetCore.DataProtection.Internal;
8+
9+
internal sealed class DefaultTypeNameResolver : ITypeNameResolver
10+
{
11+
public static readonly DefaultTypeNameResolver Instance = new();
12+
13+
private DefaultTypeNameResolver()
14+
{
15+
}
16+
17+
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used to resolve statically known types that are referenced by DataProtection assembly.")]
18+
public bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type)
19+
{
20+
try
21+
{
22+
// Some exceptions are thrown regardless of the value of throwOnError.
23+
// For example, if the type is found but cannot be loaded,
24+
// a System.TypeLoadException is thrown even if throwOnError is false.
25+
type = Type.GetType(typeName, throwOnError: false);
26+
return type != null;
27+
}
28+
catch
29+
{
30+
type = null;
31+
return false;
32+
}
33+
}
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Diagnostics.CodeAnalysis;
6+
7+
namespace Microsoft.AspNetCore.DataProtection.Internal;
8+
9+
internal interface ITypeNameResolver
10+
{
11+
bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type);
12+
}

src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ public sealed class XmlKeyManager : IKeyManager, IInternalXmlKeyManager
4949
private const string RevokeAllKeysValue = "*";
5050

5151
private readonly IActivator _activator;
52+
private readonly ITypeNameResolver _typeNameResolver;
5253
private readonly AlgorithmConfiguration _authenticatedEncryptorConfiguration;
5354
private readonly IKeyEscrowSink? _keyEscrowSink;
5455
private readonly IInternalXmlKeyManager _internalKeyManager;
@@ -112,6 +113,7 @@ internal XmlKeyManager(
112113
var escrowSinks = keyManagementOptions.Value.KeyEscrowSinks;
113114
_keyEscrowSink = escrowSinks.Count > 0 ? new AggregateKeyEscrowSink(escrowSinks) : null;
114115
_activator = activator;
116+
_typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;
115117
TriggerAndResetCacheExpirationToken(suppressLogging: true);
116118
_internalKeyManager = _internalKeyManager ?? this;
117119
_encryptorFactories = keyManagementOptions.Value.AuthenticatedEncryptorFactories;
@@ -469,21 +471,20 @@ private IAuthenticatedEncryptorDescriptorDeserializer CreateDeserializer(string
469471
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
470472
? forwardedTypeName
471473
: descriptorDeserializerTypeName;
472-
var type = Type.GetType(resolvedTypeName, throwOnError: false);
473474

474-
if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer))
475+
if (TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(AuthenticatedEncryptorDescriptorDeserializer)))
475476
{
476477
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
477478
}
478-
else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
479+
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer)))
479480
{
480481
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
481482
}
482-
else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
483+
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer)))
483484
{
484485
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
485486
}
486-
else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer))
487+
else if (TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer)))
487488
{
488489
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
489490
}

src/DataProtection/DataProtection/src/TypeExtensions.cs

+12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Diagnostics.CodeAnalysis;
6+
using Microsoft.AspNetCore.DataProtection.Internal;
67

78
namespace Microsoft.AspNetCore.DataProtection;
89

@@ -39,4 +40,15 @@ public static Type GetTypeWithTrimFriendlyErrorMessage(string typeName)
3940
throw new InvalidOperationException($"Unable to load type '{typeName}'. If the app is published with trimming then this type may have been trimmed. Ensure the type's assembly is excluded from trimming.", ex);
4041
}
4142
}
43+
44+
public static bool MatchType(string resolvedTypeName, ITypeNameResolver typeNameResolver, Type matchType)
45+
{
46+
// Before attempting to resolve the name to a type, check if it starts with the full name of the type.
47+
if (matchType.FullName != null && resolvedTypeName.StartsWith(matchType.FullName, StringComparison.Ordinal))
48+
{
49+
return typeNameResolver.TryResolveType(resolvedTypeName, out var resolvedType) && resolvedType == matchType;
50+
}
51+
52+
return false;
53+
}
4254
}

src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs

+8-34
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,6 @@ namespace Microsoft.AspNetCore.DataProtection.XmlEncryption;
1616

1717
internal static unsafe class XmlEncryptionExtensions
1818
{
19-
// Used for testing edge case assembly loading errors
20-
internal static Func<string, Type?> _getType = GetType;
21-
2219
public static XElement DecryptElement(this XElement element, IActivator activator)
2320
{
2421
// If no decryption necessary, return original element.
@@ -72,54 +69,31 @@ public static XElement DecryptElement(this XElement element, IActivator activato
7269

7370
private static IXmlDecryptor CreateDecryptor(IActivator activator, string decryptorTypeName)
7471
{
75-
if (!TryGetDecryptorType(decryptorTypeName, out var type))
76-
{
77-
return activator.CreateInstance<IXmlDecryptor>(decryptorTypeName);
78-
}
72+
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
73+
? forwardedTypeName
74+
: decryptorTypeName;
75+
var typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;
7976

80-
if (type == typeof(DpapiNGXmlDecryptor))
77+
if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(DpapiNGXmlDecryptor)))
8178
{
8279
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
8380
}
84-
else if (type == typeof(DpapiXmlDecryptor))
81+
else if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(DpapiXmlDecryptor)))
8582
{
8683
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
8784
}
88-
else if (type == typeof(EncryptedXmlDecryptor))
85+
else if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(EncryptedXmlDecryptor)))
8986
{
9087
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
9188
}
92-
else if (type == typeof(NullXmlDecryptor))
89+
else if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(NullXmlDecryptor)))
9390
{
9491
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
9592
}
9693

9794
return activator.CreateInstance<IXmlDecryptor>(decryptorTypeName);
9895
}
9996

100-
private static bool TryGetDecryptorType(string decryptorTypeName, [NotNullWhen(true)] out Type? type)
101-
{
102-
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
103-
? forwardedTypeName
104-
: decryptorTypeName;
105-
try
106-
{
107-
// Some exceptions are thrown regardless of the value of throwOnError.
108-
// For example, if the type is found but cannot be loaded,
109-
// a System.TypeLoadException is thrown even if throwOnError is false.
110-
type = _getType(resolvedTypeName);
111-
return type is not null;
112-
}
113-
catch
114-
{
115-
type = default;
116-
return false;
117-
}
118-
}
119-
120-
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
121-
private static Type? GetType(string typeName) => Type.GetType(typeName, throwOnError: false);
122-
12397
public static XElement? EncryptIfNecessary(this IXmlEncryptor encryptor, XElement element)
12498
{
12599
// If no encryption is necessary, return null.

src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs

+68-6
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,9 @@ public void DecryptElement_RootNodeRequiresDecryption_Success()
5050
}
5151

5252
[Fact]
53-
public void DecryptElement_MissingDecryptorType_Success()
53+
public void DecryptElement_CustomType_TypeNameResolverNotCalled()
5454
{
5555
// Arrange
56-
// We want to simulate an error loading the specified decryptor type, i.e.
57-
// Could not load file or assembly 'Azure.Extensions.AspNetCore.DataProtection.Keys,
58-
// Version=1.2.2.0, Culture=neutral, PublicKeyToken=92742159e12e44c8' or one of its dependencies.
59-
XmlEncryptionExtensions._getType = _ => throw new TypeLoadException();
60-
6156
var decryptorTypeName = typeof(MyXmlDecryptor).AssemblyQualifiedName;
6257

6358
var original = XElement.Parse(@$"
@@ -67,6 +62,7 @@ public void DecryptElement_MissingDecryptorType_Success()
6762

6863
var mockActivator = new Mock<IActivator>();
6964
mockActivator.ReturnDecryptedElementGivenDecryptorTypeNameAndInput(decryptorTypeName, "<node />", "<newNode />");
65+
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
7066

7167
var serviceCollection = new ServiceCollection();
7268
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
@@ -78,6 +74,72 @@ public void DecryptElement_MissingDecryptorType_Success()
7874

7975
// Assert
8076
XmlAssert.Equal("<newNode />", retVal);
77+
Type resolvedType;
78+
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Never());
79+
}
80+
81+
[Fact]
82+
public void DecryptElement_KnownType_TypeNameResolverCalled()
83+
{
84+
// Arrange
85+
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;
86+
87+
var original = XElement.Parse(@$"
88+
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
89+
<node>
90+
<value />
91+
</node>
92+
</x:encryptedSecret>");
93+
94+
var mockActivator = new Mock<IActivator>();
95+
mockActivator.Setup(o => o.CreateInstance(typeof(NullXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
96+
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
97+
var resolvedType = typeof(NullXmlDecryptor);
98+
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny<string>(), out resolvedType)).Returns(true);
99+
100+
var serviceCollection = new ServiceCollection();
101+
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
102+
var services = serviceCollection.BuildServiceProvider();
103+
var activator = services.GetActivator();
104+
105+
// Act
106+
var retVal = original.DecryptElement(activator);
107+
108+
// Assert
109+
XmlAssert.Equal("<value />", retVal);
110+
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
111+
}
112+
113+
[Fact]
114+
public void DecryptElement_KnownType_UnableToResolveType_Success()
115+
{
116+
// Arrange
117+
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;
118+
119+
var original = XElement.Parse(@$"
120+
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
121+
<node>
122+
<value />
123+
</node>
124+
</x:encryptedSecret>");
125+
126+
var mockActivator = new Mock<IActivator>();
127+
mockActivator.Setup(o => o.CreateInstance(typeof(IXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
128+
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
129+
var resolvedType = typeof(NullXmlDecryptor);
130+
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny<string>(), out resolvedType)).Returns(false);
131+
132+
var serviceCollection = new ServiceCollection();
133+
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
134+
var services = serviceCollection.BuildServiceProvider();
135+
var activator = services.GetActivator();
136+
137+
// Act
138+
var retVal = original.DecryptElement(activator);
139+
140+
// Assert
141+
XmlAssert.Equal("<value />", retVal);
142+
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
81143
}
82144

83145
[Fact]

0 commit comments

Comments
 (0)