Skip to content

Commit

Permalink
Avoid NullReferenceException with async methods
Browse files Browse the repository at this point in the history
When using async methods and loose Mock Behavior, calls to Task methods
give NullReferenceException, since default(Task) is null. Calls to
Task.Wait() and Task<T>.Result should follow the loose behavior: do
nothing and return the default value for T, respectively.

Related to #64.
  • Loading branch information
Alex Tercete committed Dec 13, 2013
1 parent af32e6d commit 7bf05bd
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 0 deletions.
25 changes: 25 additions & 0 deletions Source/EmptyDefaultValueProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Threading.Tasks;

namespace Moq
{
Expand Down Expand Up @@ -86,6 +87,11 @@ private static object GetReferenceTypeDefault(Type valueType)
{
return new object[0].AsQueryable();
}
else if (valueType == typeof(Task))
{
// Task<T> inherits from Task, so just return Task<bool>
return GetCompletedTaskWithResult(false);
}
else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
var genericListType = typeof(List<>).MakeGenericType(valueType.GetGenericArguments()[0]);
Expand All @@ -101,6 +107,13 @@ private static object GetReferenceTypeDefault(Type valueType)
.MakeGenericMethod(genericType)
.Invoke(null, new[] { Activator.CreateInstance(genericListType) });
}
else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(Task<>))
{
var genericType = valueType.GetGenericArguments()[0];

return GetCompletedTaskWithResult(
genericType.IsValueType ? GetValueTypeDefault(genericType) : GetReferenceTypeDefault(genericType));
}

return null;
}
Expand All @@ -115,5 +128,17 @@ private static object GetValueTypeDefault(Type valueType)

return Activator.CreateInstance(valueType);
}

private static Task GetCompletedTaskWithResult(object result)
{
var type = result.GetType();

This comment has been minimized.

Copy link
@danielcweber

danielcweber Dec 15, 2013

Will throw when result == null - which is whenever GetCompletedTaskWithResult is called to get a completed Task<T> where T is a reference type.

This comment has been minimized.

Copy link
@alextercete

alextercete Dec 15, 2013

Of course it will. Thanks for pointing that out, I'll fix this shortly.

var tcs = Activator.CreateInstance(typeof (TaskCompletionSource<>).MakeGenericType(type));

var setResultMethod = tcs.GetType().GetMethod("SetResult");
var taskProperty = tcs.GetType().GetProperty("Task");

setResultMethod.Invoke(tcs, new[] {result});
return (Task) taskProperty.GetValue(tcs, null);
}
}
}
39 changes: 39 additions & 0 deletions UnitTests/EmptyDefaultValueProviderFixture.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading.Tasks;
using Xunit;

namespace Moq.Tests
Expand Down Expand Up @@ -107,6 +108,41 @@ public void ProvideEmptyQueryableObjects()
Assert.Equal(0, ((IQueryable)value).Cast<object>().Count());
}

[Fact]
public void ProvidesDefaultTask()
{
var provider = new EmptyDefaultValueProvider();

var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskValue").GetGetMethod());

Assert.NotNull(value);
Assert.True(((Task)value).IsCompleted);
}

[Fact]
public void ProvidesDefaultGenericTask()
{
var provider = new EmptyDefaultValueProvider();

var value = provider.ProvideDefault(typeof(IFoo).GetProperty("GenericTaskValue").GetGetMethod());

Assert.NotNull(value);
Assert.True(((Task)value).IsCompleted);
Assert.Equal(default(int), ((Task<int>)value).Result);
}

[Fact]
public void ProvidesDefaultTaskOfGenericTask()
{
var provider = new EmptyDefaultValueProvider();

var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskOfGenericTaskValue").GetGetMethod());

Assert.NotNull(value);
Assert.True(((Task)value).IsCompleted);
Assert.Equal(default(int), ((Task<Task<int>>) value).Result.Result);
}

public interface IFoo
{
object Object { get; set; }
Expand All @@ -120,6 +156,9 @@ public interface IFoo
IBar[] Bars { get; set; }
IQueryable<int> Queryable { get; }
IQueryable QueryableObjects { get; }
Task TaskValue { get; set; }
Task<int> GenericTaskValue { get; set; }
Task<Task<int>> TaskOfGenericTaskValue { get; set; }
}

public interface IBar { }
Expand Down

0 comments on commit 7bf05bd

Please sign in to comment.