Skip to content

Commit 07f7f17

Browse files
JamesNKkhellang
authored andcommitted
Check type name for known value before getting type
1 parent 6167a3b commit 07f7f17

File tree

6 files changed

+165
-35
lines changed

6 files changed

+165
-35
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Diagnostics.CodeAnalysis;
6+
7+
namespace Microsoft.AspNetCore.DataProtection.Internal;
8+
9+
internal sealed class DefaultTypeNameResolver : ITypeNameResolver
10+
{
11+
public static readonly DefaultTypeNameResolver Instance = new();
12+
13+
private DefaultTypeNameResolver()
14+
{
15+
}
16+
17+
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType is only used to resolve statically known types that are referenced by DataProtection assembly.")]
18+
public bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type)
19+
{
20+
try
21+
{
22+
// Some exceptions are thrown regardless of the value of throwOnError.
23+
// For example, if the type is found but cannot be loaded,
24+
// a System.TypeLoadException is thrown even if throwOnError is false.
25+
type = Type.GetType(typeName, throwOnError: false);
26+
return type != null;
27+
}
28+
catch
29+
{
30+
type = null;
31+
return false;
32+
}
33+
}
34+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using System;
5+
using System.Diagnostics.CodeAnalysis;
6+
7+
namespace Microsoft.AspNetCore.DataProtection.Internal;
8+
9+
internal interface ITypeNameResolver
10+
{
11+
bool TryResolveType(string typeName, [NotNullWhen(true)] out Type? type);
12+
}

src/DataProtection/DataProtection/src/KeyManagement/XmlKeyManager.cs

+20-17
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ public sealed class XmlKeyManager : IKeyManager, IInternalXmlKeyManager
4949
private const string RevokeAllKeysValue = "*";
5050

5151
private readonly IActivator _activator;
52+
private readonly ITypeNameResolver _typeNameResolver;
5253
private readonly AlgorithmConfiguration _authenticatedEncryptorConfiguration;
5354
private readonly IKeyEscrowSink? _keyEscrowSink;
5455
private readonly IInternalXmlKeyManager _internalKeyManager;
@@ -112,6 +113,7 @@ internal XmlKeyManager(
112113
var escrowSinks = keyManagementOptions.Value.KeyEscrowSinks;
113114
_keyEscrowSink = escrowSinks.Count > 0 ? new AggregateKeyEscrowSink(escrowSinks) : null;
114115
_activator = activator;
116+
_typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;
115117
TriggerAndResetCacheExpirationToken(suppressLogging: true);
116118
_internalKeyManager = _internalKeyManager ?? this;
117119
_encryptorFactories = keyManagementOptions.Value.AuthenticatedEncryptorFactories;
@@ -466,24 +468,25 @@ IAuthenticatedEncryptorDescriptor IInternalXmlKeyManager.DeserializeDescriptorFr
466468
[UnconditionalSuppressMessage("Trimmer", "IL2057", Justification = "Type.GetType result is only useful with types that are referenced by DataProtection assembly.")]
467469
private IAuthenticatedEncryptorDescriptorDeserializer CreateDeserializer(string descriptorDeserializerTypeName)
468470
{
469-
if (_activator is ITypeForwardingActivator forwardingActivator && forwardingActivator.TryForwardType(descriptorDeserializerTypeName, out var type))
471+
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(descriptorDeserializerTypeName, out var forwardedTypeName)
472+
? forwardedTypeName
473+
: descriptorDeserializerTypeName;
474+
475+
if (TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(AuthenticatedEncryptorDescriptorDeserializer)))
470476
{
471-
if (type == typeof(AuthenticatedEncryptorDescriptorDeserializer))
472-
{
473-
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
474-
}
475-
else if (type == typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
476-
{
477-
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
478-
}
479-
else if (type == typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer) && RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
480-
{
481-
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
482-
}
483-
else if (type == typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer))
484-
{
485-
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
486-
}
477+
return _activator.CreateInstance<AuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
478+
}
479+
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(CngCbcAuthenticatedEncryptorDescriptorDeserializer)))
480+
{
481+
return _activator.CreateInstance<CngCbcAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
482+
}
483+
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows) && TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(CngGcmAuthenticatedEncryptorDescriptorDeserializer)))
484+
{
485+
return _activator.CreateInstance<CngGcmAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
486+
}
487+
else if (TypeExtensions.MatchType(resolvedTypeName, _typeNameResolver, typeof(ManagedAuthenticatedEncryptorDescriptorDeserializer)))
488+
{
489+
return _activator.CreateInstance<ManagedAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);
487490
}
488491

489492
return _activator.CreateInstance<IAuthenticatedEncryptorDescriptorDeserializer>(descriptorDeserializerTypeName);

src/DataProtection/DataProtection/src/TypeExtensions.cs

+12
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
using System;
55
using System.Diagnostics.CodeAnalysis;
6+
using Microsoft.AspNetCore.DataProtection.Internal;
67

78
namespace Microsoft.AspNetCore.DataProtection;
89

@@ -39,4 +40,15 @@ public static Type GetTypeWithTrimFriendlyErrorMessage(string typeName)
3940
throw new InvalidOperationException($"Unable to load type '{typeName}'. If the app is published with trimming then this type may have been trimmed. Ensure the type's assembly is excluded from trimming.", ex);
4041
}
4142
}
43+
44+
public static bool MatchType(string resolvedTypeName, ITypeNameResolver typeNameResolver, Type matchType)
45+
{
46+
// Before attempting to resolve the name to a type, check if it starts with the full name of the type.
47+
if (matchType.FullName != null && resolvedTypeName.StartsWith(matchType.FullName, StringComparison.Ordinal))
48+
{
49+
return typeNameResolver.TryResolveType(resolvedTypeName, out var resolvedType) && resolvedType == matchType;
50+
}
51+
52+
return false;
53+
}
4254
}

src/DataProtection/DataProtection/src/XmlEncryption/XmlEncryptionExtensions.cs

+19-17
Original file line numberDiff line numberDiff line change
@@ -69,24 +69,26 @@ public static XElement DecryptElement(this XElement element, IActivator activato
6969

7070
private static IXmlDecryptor CreateDecryptor(IActivator activator, string decryptorTypeName)
7171
{
72-
if (activator is ITypeForwardingActivator typeForwarder && typeForwarder.TryForwardType(decryptorTypeName, out var type))
72+
var resolvedTypeName = TypeForwardingActivator.TryForwardTypeName(decryptorTypeName, out var forwardedTypeName)
73+
? forwardedTypeName
74+
: decryptorTypeName;
75+
var typeNameResolver = activator as ITypeNameResolver ?? DefaultTypeNameResolver.Instance;
76+
77+
if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(DpapiNGXmlDecryptor)))
7378
{
74-
if (type == typeof(DpapiNGXmlDecryptor))
75-
{
76-
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
77-
}
78-
else if (type == typeof(DpapiXmlDecryptor))
79-
{
80-
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
81-
}
82-
else if (type == typeof(EncryptedXmlDecryptor))
83-
{
84-
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
85-
}
86-
else if (type == typeof(NullXmlDecryptor))
87-
{
88-
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
89-
}
79+
return activator.CreateInstance<DpapiNGXmlDecryptor>(decryptorTypeName);
80+
}
81+
else if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(DpapiXmlDecryptor)))
82+
{
83+
return activator.CreateInstance<DpapiXmlDecryptor>(decryptorTypeName);
84+
}
85+
else if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(EncryptedXmlDecryptor)))
86+
{
87+
return activator.CreateInstance<EncryptedXmlDecryptor>(decryptorTypeName);
88+
}
89+
else if (TypeExtensions.MatchType(resolvedTypeName, typeNameResolver, typeof(NullXmlDecryptor)))
90+
{
91+
return activator.CreateInstance<NullXmlDecryptor>(decryptorTypeName);
9092
}
9193

9294
return activator.CreateInstance<IXmlDecryptor>(decryptorTypeName);

src/DataProtection/DataProtection/test/Microsoft.AspNetCore.DataProtection.Tests/XmlEncryption/XmlEncryptionExtensionsTests.cs

+68-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ public void DecryptElement_RootNodeRequiresDecryption_Success()
5050
}
5151

5252
[Fact]
53-
public void DecryptElement_MissingDecryptorType_Success()
53+
public void DecryptElement_CustomType_TypeNameResolverNotCalled()
5454
{
5555
// Arrange
5656
var decryptorTypeName = typeof(MyXmlDecryptor).AssemblyQualifiedName;
@@ -62,6 +62,7 @@ public void DecryptElement_MissingDecryptorType_Success()
6262

6363
var mockActivator = new Mock<ITypeForwardingActivator>();
6464
mockActivator.ReturnDecryptedElementGivenDecryptorTypeNameAndInput(decryptorTypeName, "<node />", "<newNode />");
65+
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
6566

6667
// We want to simulate an error loading the specified decryptor type, i.e.
6768
// Could not load file or assembly 'Azure.Extensions.AspNetCore.DataProtection.Keys,
@@ -79,6 +80,72 @@ public void DecryptElement_MissingDecryptorType_Success()
7980

8081
// Assert
8182
XmlAssert.Equal("<newNode />", retVal);
83+
Type resolvedType;
84+
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Never());
85+
}
86+
87+
[Fact]
88+
public void DecryptElement_KnownType_TypeNameResolverCalled()
89+
{
90+
// Arrange
91+
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;
92+
93+
var original = XElement.Parse(@$"
94+
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
95+
<node>
96+
<value />
97+
</node>
98+
</x:encryptedSecret>");
99+
100+
var mockActivator = new Mock<IActivator>();
101+
mockActivator.Setup(o => o.CreateInstance(typeof(NullXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
102+
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
103+
var resolvedType = typeof(NullXmlDecryptor);
104+
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny<string>(), out resolvedType)).Returns(true);
105+
106+
var serviceCollection = new ServiceCollection();
107+
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
108+
var services = serviceCollection.BuildServiceProvider();
109+
var activator = services.GetActivator();
110+
111+
// Act
112+
var retVal = original.DecryptElement(activator);
113+
114+
// Assert
115+
XmlAssert.Equal("<value />", retVal);
116+
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
117+
}
118+
119+
[Fact]
120+
public void DecryptElement_KnownType_UnableToResolveType_Success()
121+
{
122+
// Arrange
123+
var decryptorTypeName = typeof(NullXmlDecryptor).AssemblyQualifiedName;
124+
125+
var original = XElement.Parse(@$"
126+
<x:encryptedSecret decryptorType='{decryptorTypeName}' xmlns:x='http://schemas.asp.net/2015/03/dataProtection'>
127+
<node>
128+
<value />
129+
</node>
130+
</x:encryptedSecret>");
131+
132+
var mockActivator = new Mock<IActivator>();
133+
mockActivator.Setup(o => o.CreateInstance(typeof(IXmlDecryptor), decryptorTypeName)).Returns(new NullXmlDecryptor());
134+
var mockTypeNameResolver = mockActivator.As<ITypeNameResolver>();
135+
var resolvedType = typeof(NullXmlDecryptor);
136+
mockTypeNameResolver.Setup(mockTypeNameResolver => mockTypeNameResolver.TryResolveType(It.IsAny<string>(), out resolvedType)).Returns(false);
137+
138+
var serviceCollection = new ServiceCollection();
139+
serviceCollection.AddSingleton<IActivator>(mockActivator.Object);
140+
var services = serviceCollection.BuildServiceProvider();
141+
var activator = services.GetActivator();
142+
143+
// Act
144+
var retVal = original.DecryptElement(activator);
145+
146+
// Assert
147+
XmlAssert.Equal("<value />", retVal);
148+
mockTypeNameResolver.Verify(o => o.TryResolveType(It.IsAny<string>(), out resolvedType), Times.Once());
82149
}
83150

84151
[Fact]

0 commit comments

Comments
 (0)