Skip to content

Commit f2c7b53

Browse files
Find inherited TryParse and BindAsync (dotnet#36688)
1 parent 8539422 commit f2c7b53

File tree

2 files changed

+133
-6
lines changed

2 files changed

+133
-6
lines changed

src/Http/Http.Extensions/test/ParameterBindingMethodCacheTests.cs

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvaria
7777
[Theory]
7878
[InlineData(typeof(TryParseStringRecord))]
7979
[InlineData(typeof(TryParseStringStruct))]
80+
[InlineData(typeof(TryParseInheritClassWithFormatProvider))]
8081
public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvariantCultureCustomType(Type type)
8182
{
8283
var methodFound = new ParameterBindingMethodCache().FindTryParseMethod(@type);
@@ -94,6 +95,24 @@ public void FindTryParseStringMethod_ReturnsTheExpectedTryParseMethodWithInvaria
9495
Assert.True(((call.Arguments[1] as ConstantExpression)!.Value as CultureInfo)!.Equals(CultureInfo.InvariantCulture));
9596
}
9697

98+
[Theory]
99+
[InlineData(typeof(TryParseNoFormatProviderRecord))]
100+
[InlineData(typeof(TryParseNoFormatProviderStruct))]
101+
[InlineData(typeof(TryParseInheritClass))]
102+
public void FindTryParseMethod_WithNoFormatProvider(Type type)
103+
{
104+
var methodFound = new ParameterBindingMethodCache().FindTryParseMethod(@type);
105+
Assert.NotNull(methodFound);
106+
107+
var call = methodFound!(Expression.Variable(type, "parsedValue")) as MethodCallExpression;
108+
Assert.NotNull(call);
109+
var parameters = call!.Method.GetParameters();
110+
111+
Assert.Equal(2, parameters.Length);
112+
Assert.Equal(typeof(string), parameters[0].ParameterType);
113+
Assert.True(parameters[1].IsOut);
114+
}
115+
97116
public static IEnumerable<object[]> TryParseStringParameterInfoData
98117
{
99118
get
@@ -249,6 +268,14 @@ public static IEnumerable<object[]> BindAsyncParameterInfoData
249268
new[]
250269
{
251270
GetFirstParameter((BindAsyncSingleArgStruct arg) => BindAsyncSingleArgStructMethod(arg)),
271+
},
272+
new[]
273+
{
274+
GetFirstParameter((InheritBindAsync arg) => InheritBindAsyncMethod(arg))
275+
},
276+
new[]
277+
{
278+
GetFirstParameter((InheritBindAsyncWithParameterInfo arg) => InheritBindAsyncWithParameterInfoMethod(arg))
252279
}
253280
};
254281
}
@@ -285,6 +312,7 @@ public void FindBindAsyncMethod_FindsNonNullableReturningBindAsyncMethodGivenNul
285312
[InlineData(typeof(InvalidTooFewArgsTryParseClass))]
286313
[InlineData(typeof(InvalidNonStaticTryParseStruct))]
287314
[InlineData(typeof(InvalidNonStaticTryParseClass))]
315+
[InlineData(typeof(TryParseWrongTypeInheritClass))]
288316
public void FindTryParseMethod_ThrowsIfInvalidTryParseOnType(Type type)
289317
{
290318
var ex = Assert.Throws<InvalidOperationException>(
@@ -308,6 +336,8 @@ public void FindTryParseMethod_IgnoresInvalidTryParseIfGoodOneFound(Type type)
308336
[InlineData(typeof(InvalidWrongReturnBindAsyncClass))]
309337
[InlineData(typeof(InvalidWrongParamBindAsyncStruct))]
310338
[InlineData(typeof(InvalidWrongParamBindAsyncClass))]
339+
[InlineData(typeof(BindAsyncWrongTypeInherit))]
340+
[InlineData(typeof(BindAsyncWithParameterInfoWrongTypeInherit))]
311341
public void FindBindAsyncMethod_ThrowsIfInvalidBindAsyncOnType(Type type)
312342
{
313343
var cache = new ParameterBindingMethodCache();
@@ -350,6 +380,8 @@ private static void NullableReturningBindAsyncStructMethod(NullableReturningBind
350380

351381
private static void BindAsyncSingleArgRecordMethod(BindAsyncSingleArgRecord arg) { }
352382
private static void BindAsyncSingleArgStructMethod(BindAsyncSingleArgStruct arg) { }
383+
private static void InheritBindAsyncMethod(InheritBindAsync arg) { }
384+
private static void InheritBindAsyncWithParameterInfoMethod(InheritBindAsyncWithParameterInfo args) { }
353385

354386
private static ParameterInfo GetFirstParameter<T>(Expression<Action<T>> expr)
355387
{
@@ -538,6 +570,67 @@ public bool TryParse(string? value, IFormatProvider formatProvider, out InvalidN
538570
}
539571
}
540572

573+
private record TryParseNoFormatProviderRecord(int Value)
574+
{
575+
public static bool TryParse(string? value, out TryParseNoFormatProviderRecord? result)
576+
{
577+
if (!int.TryParse(value, out var val))
578+
{
579+
result = null;
580+
return false;
581+
}
582+
583+
result = new TryParseNoFormatProviderRecord(val);
584+
return true;
585+
}
586+
}
587+
588+
private record struct TryParseNoFormatProviderStruct(int Value)
589+
{
590+
public static bool TryParse(string? value, out TryParseNoFormatProviderStruct result)
591+
{
592+
if (!int.TryParse(value, out var val))
593+
{
594+
result = default;
595+
return false;
596+
}
597+
598+
result = new TryParseNoFormatProviderStruct(val);
599+
return true;
600+
}
601+
}
602+
603+
private class BaseTryParseClass<T>
604+
{
605+
public static bool TryParse(string? value, out T? result)
606+
{
607+
result = default(T);
608+
return false;
609+
}
610+
}
611+
612+
private class TryParseInheritClass : BaseTryParseClass<TryParseInheritClass>
613+
{
614+
}
615+
616+
// using wrong T on purpose
617+
private class TryParseWrongTypeInheritClass : BaseTryParseClass<TryParseInheritClass>
618+
{
619+
}
620+
621+
private class BaseTryParseClassWithFormatProvider<T>
622+
{
623+
public static bool TryParse(string? value, IFormatProvider formatProvider, out T? result)
624+
{
625+
result = default(T);
626+
return false;
627+
}
628+
}
629+
630+
private class TryParseInheritClassWithFormatProvider : BaseTryParseClassWithFormatProvider<TryParseInheritClassWithFormatProvider>
631+
{
632+
}
633+
541634
private record BindAsyncRecord(int Value)
542635
{
543636
public static ValueTask<BindAsyncRecord?> BindAsync(HttpContext context, ParameterInfo parameter)
@@ -644,6 +737,40 @@ public static ValueTask<BindAsyncClassWithGoodAndBad> BindAsync(ParameterInfo pa
644737
throw new NotImplementedException();
645738
}
646739

740+
private class BaseBindAsync<T>
741+
{
742+
public static ValueTask<T?> BindAsync(HttpContext context)
743+
{
744+
return new(default(T));
745+
}
746+
}
747+
748+
private class InheritBindAsync : BaseBindAsync<InheritBindAsync>
749+
{
750+
}
751+
752+
// Using wrong T on purpose
753+
private class BindAsyncWrongTypeInherit : BaseBindAsync<InheritBindAsync>
754+
{
755+
}
756+
757+
private class BaseBindAsyncWithParameterInfo<T>
758+
{
759+
public static ValueTask<T?> BindAsync(HttpContext context, ParameterInfo parameter)
760+
{
761+
return new(default(T));
762+
}
763+
}
764+
765+
private class InheritBindAsyncWithParameterInfo : BaseBindAsyncWithParameterInfo<InheritBindAsyncWithParameterInfo>
766+
{
767+
}
768+
769+
// Using wrong T on purpose
770+
private class BindAsyncWithParameterInfoWrongTypeInherit : BaseBindAsyncWithParameterInfo<InheritBindAsync>
771+
{
772+
}
773+
647774
private class MockParameterInfo : ParameterInfo
648775
{
649776
public MockParameterInfo(Type type, string name)

src/Shared/ParameterBindingMethodCache.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) =>
106106
expression);
107107
}
108108

109-
methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(string), typeof(IFormatProvider), type.MakeByRefType() });
109+
methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(string), typeof(IFormatProvider), type.MakeByRefType() });
110110

111111
if (methodInfo is not null && methodInfo.ReturnType == typeof(bool))
112112
{
@@ -117,14 +117,14 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) =>
117117
expression);
118118
}
119119

120-
methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static, new[] { typeof(string), type.MakeByRefType() });
120+
methodInfo = type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(string), type.MakeByRefType() });
121121

122122
if (methodInfo is not null && methodInfo.ReturnType == typeof(bool))
123123
{
124124
return (expression) => Expression.Call(methodInfo, TempSourceStringExpr, expression);
125125
}
126126

127-
if (type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance) is MethodInfo invalidMethod)
127+
if (type.GetMethod("TryParse", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance | BindingFlags.FlattenHierarchy) is MethodInfo invalidMethod)
128128
{
129129
var stringBuilder = new StringBuilder();
130130
stringBuilder.AppendLine(CultureInfo.InvariantCulture, $"TryParse method found on {TypeNameHelper.GetTypeDisplayName(type, fullName: false)} with incorrect format. Must be a static method with format");
@@ -149,11 +149,11 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) =>
149149
{
150150
var hasParameterInfo = true;
151151
// There should only be one BindAsync method with these parameters since C# does not allow overloading on return type.
152-
var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext), typeof(ParameterInfo) });
152+
var methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(HttpContext), typeof(ParameterInfo) });
153153
if (methodInfo is null)
154154
{
155155
hasParameterInfo = false;
156-
methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static, new[] { typeof(HttpContext) });
156+
methodInfo = nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.FlattenHierarchy, new[] { typeof(HttpContext) });
157157
}
158158

159159
// We're looking for a method with the following signatures:
@@ -207,7 +207,7 @@ public bool HasBindAsyncMethod(ParameterInfo parameter) =>
207207
}
208208
}
209209

210-
if (nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance) is MethodInfo invalidBindMethod)
210+
if (nonNullableParameterType.GetMethod("BindAsync", BindingFlags.Public | BindingFlags.Static | BindingFlags.Instance | BindingFlags.FlattenHierarchy) is MethodInfo invalidBindMethod)
211211
{
212212
var stringBuilder = new StringBuilder();
213213
stringBuilder.AppendLine(CultureInfo.InvariantCulture, $"BindAsync method found on {TypeNameHelper.GetTypeDisplayName(nonNullableParameterType, fullName: false)} with incorrect format. Must be a static method with format");

0 commit comments

Comments
 (0)