Skip to content

Commit

Permalink
Merge pull request #66 from alextercete/master
Browse files Browse the repository at this point in the history
Avoid NullReferenceException with async methods
  • Loading branch information
kzu committed Dec 13, 2013
2 parents af32e6d + ad78301 commit b24199e
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
33 changes: 33 additions & 0 deletions Source/EmptyDefaultValueProvider.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
#if !NET3x
using System.Threading.Tasks;
#endif

namespace Moq
{
Expand Down Expand Up @@ -86,6 +89,13 @@ private static object GetReferenceTypeDefault(Type valueType)
{
return new object[0].AsQueryable();
}
#if !NET3x
else if (valueType == typeof(Task))
{
// Task<T> inherits from Task, so just return Task<bool>
return GetCompletedTaskWithResult(false);
}
#endif
else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(IEnumerable<>))
{
var genericListType = typeof(List<>).MakeGenericType(valueType.GetGenericArguments()[0]);
Expand All @@ -101,6 +111,15 @@ private static object GetReferenceTypeDefault(Type valueType)
.MakeGenericMethod(genericType)
.Invoke(null, new[] { Activator.CreateInstance(genericListType) });
}
#if !NET3x
else if (valueType.IsGenericType && valueType.GetGenericTypeDefinition() == typeof(Task<>))
{
var genericType = valueType.GetGenericArguments()[0];

return GetCompletedTaskWithResult(
genericType.IsValueType ? GetValueTypeDefault(genericType) : GetReferenceTypeDefault(genericType));
}
#endif

return null;
}
Expand All @@ -115,5 +134,19 @@ private static object GetValueTypeDefault(Type valueType)

return Activator.CreateInstance(valueType);
}

#if !NET3x
private static Task GetCompletedTaskWithResult(object result)
{
var type = result.GetType();
var tcs = Activator.CreateInstance(typeof (TaskCompletionSource<>).MakeGenericType(type));

var setResultMethod = tcs.GetType().GetMethod("SetResult");
var taskProperty = tcs.GetType().GetProperty("Task");

setResultMethod.Invoke(tcs, new[] {result});
return (Task) taskProperty.GetValue(tcs, null);
}
#endif
}
}
4 changes: 3 additions & 1 deletion Source/Language/IReturnsExtensions.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
#if !NET3x
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
Expand Down Expand Up @@ -34,3 +35,4 @@ public static IReturnsResult<TMock> ThrowsAsync<TMock, TResult>(this IReturns<TM
}
}
}
#endif
45 changes: 45 additions & 0 deletions UnitTests/EmptyDefaultValueProviderFixture.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using System.Collections.Generic;
using System.Linq;
#if !NET3x
using System.Threading.Tasks;
#endif
using Xunit;

namespace Moq.Tests
Expand Down Expand Up @@ -107,6 +110,43 @@ public void ProvideEmptyQueryableObjects()
Assert.Equal(0, ((IQueryable)value).Cast<object>().Count());
}

#if !NET3x
[Fact]
public void ProvidesDefaultTask()
{
var provider = new EmptyDefaultValueProvider();

var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskValue").GetGetMethod());

Assert.NotNull(value);
Assert.True(((Task)value).IsCompleted);
}

[Fact]
public void ProvidesDefaultGenericTask()
{
var provider = new EmptyDefaultValueProvider();

var value = provider.ProvideDefault(typeof(IFoo).GetProperty("GenericTaskValue").GetGetMethod());

Assert.NotNull(value);
Assert.True(((Task)value).IsCompleted);
Assert.Equal(default(int), ((Task<int>)value).Result);
}

[Fact]
public void ProvidesDefaultTaskOfGenericTask()
{
var provider = new EmptyDefaultValueProvider();

var value = provider.ProvideDefault(typeof(IFoo).GetProperty("TaskOfGenericTaskValue").GetGetMethod());

Assert.NotNull(value);
Assert.True(((Task)value).IsCompleted);
Assert.Equal(default(int), ((Task<Task<int>>) value).Result.Result);
}
#endif

public interface IFoo
{
object Object { get; set; }
Expand All @@ -120,6 +160,11 @@ public interface IFoo
IBar[] Bars { get; set; }
IQueryable<int> Queryable { get; }
IQueryable QueryableObjects { get; }
#if !NET3x
Task TaskValue { get; set; }
Task<int> GenericTaskValue { get; set; }
Task<Task<int>> TaskOfGenericTaskValue { get; set; }
#endif
}

public interface IBar { }
Expand Down

0 comments on commit b24199e

Please sign in to comment.