Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: more gradient optimizer tests #1217

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions src/TensorFlowNET.Core/Tensors/tensor_util.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
/*****************************************************************************
/*****************************************************************************
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -135,6 +135,23 @@ T[] ExpandArrayToSize<T>(IList<T> src)
TF_DataType.TF_QINT32
};

/// <summary>
/// Builds a new rectangular array of the same shape as <paramref name="inputArray"/>,
/// where each element is the result of applying <paramref name="converter"/> to the
/// element at the same position in the input.
/// </summary>
/// <typeparam name="TIn">Element type of the source array.</typeparam>
/// <typeparam name="TOut">Element type of the produced array.</typeparam>
/// <param name="inputArray">Two-dimensional source array; it is only read.</param>
/// <param name="converter">Function invoked once per element, in row-major order.</param>
/// <returns>A freshly allocated <c>TOut[rows, cols]</c> holding the converted values.</returns>
private static TOut[,] ConvertArray2D<TIn, TOut>(TIn[,] inputArray, Func<TIn, TOut> converter)
{
    int rowCount = inputArray.GetLength(0);
    int columnCount = inputArray.GetLength(1);
    TOut[,] converted = new TOut[rowCount, columnCount];

    // Walk the array row by row so the converter sees elements in row-major order.
    for (int row = 0; row < rowCount; ++row)
    {
        for (int column = 0; column < columnCount; ++column)
        {
            TIn element = inputArray[row, column];
            converted[row, column] = converter(element);
        }
    }

    return converted;
}

/// <summary>
/// Create a TensorProto, invoked in graph mode
/// </summary>
Expand All @@ -157,19 +174,21 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
else if(origin_dtype != dtype)
{
var new_system_dtype = dtype.as_system_dtype();
if (values is long[] long_values)
{
if (dtype == TF_DataType.TF_INT32)
values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray();
}
else if (values is double[] double_values)

values = values switch
{
if (dtype == TF_DataType.TF_FLOAT)
values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray();
}
else
values = Convert.ChangeType(values, new_system_dtype);

long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
long[] longValues => values,
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
float[] floatValues => values,
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
float[,] float2DValues => values,
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
double[] doubleValues => values,
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
double[,] double2DValues => values,
_ => Convert.ChangeType(values, new_system_dtype),
};
dtype = values.GetDataType();
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.VisualStudio.TestTools.UnitTesting;
using System;
using System.Linq;
using Tensorflow;
using Tensorflow.NumPy;
using static Tensorflow.Binding;
Expand Down Expand Up @@ -67,6 +68,51 @@ public void TestBasic()
TestBasic<double>();
}

/// <summary>
/// Runs a single gradient-descent step (learning rate 1.0) that minimizes
/// (var0 x + var1)^2 and compares the variables, before and after the step,
/// against values computed by hand below.
/// </summary>
/// <typeparam name="T">Numeric element type; translated to a dtype by GetTypeForNumericType.</typeparam>
private void TestMinimizeResourceVariable<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

// train.GradientDescentOptimizer is V1 only API.
tf.Graph().as_default();
using (var sess = self.cached_session())
{
// var0 is 1x2, x is 2x1, so matmul(var0, x) is 1x1; var1 holds a single element.
var var0 = tf.Variable(new[,] { { 1.0f, 2.0f } }, dtype: dtype);
var var1 = tf.Variable(new[] { 3.0 }, dtype: dtype);
var x = tf.constant(new[,] { { 4.0f }, { 5.0f } }, dtype: dtype);

var pred = math_ops.matmul(var0, x) + var1;
var loss = pred * pred;
var sgd_op = tf.train.GradientDescentOptimizer(1.0f).minimize(loss);

var global_variables = tf.global_variables_initializer();
sess.run(global_variables);

// NOTE(review): the result of this run is discarded; presumably it only forces the variables to be evaluated once - confirm.
sess.run(new[] { var0, var1 });
// Fetch params to validate initial values
self.assertAllCloseAccordingToType<T>(new[,] { { 1.0, 2.0 } }, self.evaluate<T[,]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd
sgd_op.run();
// Validate updated params
// Forward value with the initial parameters: 1*4 + 2*5 + 3.
var np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0;
// Derivative of pred^2 with respect to pred; the expected values below scale it by x for var0 and by 1 for var1.
var np_grad = 2 * np_pred;
self.assertAllCloseAccordingToType(
new[,] { { 1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0 } },
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Wanglongzhi2001 Hm, now it calculates but the test doesn't pass. However, the code corresponds to TensorFlow original test. I have to check math there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There just was a small typo, but I didn't have time to debug it :-D

self.evaluate<T[,]>(var0));
self.assertAllCloseAccordingToType(
new[] { 3.0 - np_grad },
self.evaluate<T[]>(var1));
}
}

/// <summary>
/// Test entry point: runs the generic minimize check once for float and once for double.
/// </summary>
[TestMethod]
public void TestMinimizeResourceVariable()
{
//TODO: add np.half
TestMinimizeResourceVariable<float>();
TestMinimizeResourceVariable<double>();
}

private void TestTensorLearningRate<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();
Expand Down Expand Up @@ -115,5 +161,72 @@ public void TestTensorLearningRate()
TestTensorLearningRate<float>();
TestTensorLearningRate<double>();
}

/// <summary>
/// Creates two single-element variables, asks the optimizer for the gradients of
/// their sum with respect to each of them, and asserts every gradient is close to 1.
/// </summary>
/// <typeparam name="T">Numeric element type; translated to a dtype by GetTypeForNumericType.</typeparam>
public void TestGradWrtRef<T>() where T : struct
{
    var dtype = GetTypeForNumericType<T>();

    // The value returned by as_default() is never read, so it is not stored
    // (same form as the other helpers in this file).
    tf.Graph().as_default();
    using (var sess = self.cached_session())
    {
        var opt = tf.train.GradientDescentOptimizer(3.0f);
        var values = new[] { 1.0, 3.0 };
        var vars_ = values.Select(
            v => tf.Variable(new[] { v }, dtype: dtype) as IVariableV1
        ).ToList();
        var grads_and_vars = opt.compute_gradients(tf.add(vars_[0], vars_[1]), vars_);
        sess.run(tf.global_variables_initializer());

        // d(a + b)/da = d(a + b)/db = 1, so each returned gradient must be [1.0].
        foreach (var (grad, _) in grads_and_vars)
        {
            self.assertAllCloseAccordingToType(new[] { 1.0 }, self.evaluate<T[]>(grad));
        }
    }
}

/// <summary>
/// Test entry point: runs the generic gradient check once for float and once for double.
/// </summary>
[TestMethod]
public void TestGradWrtRef()
{
TestGradWrtRef<float>();
TestGradWrtRef<double>();
}

/// <summary>
/// Applies fixed, pre-built gradients to two variables with a learning rate of 3.0
/// while passing a global_step variable, then asserts that each variable equals
/// its initial value minus 3.0 times its gradient and that global_step evaluates to 1.
/// </summary>
/// <typeparam name="T">Numeric element type; translated to a dtype by GetTypeForNumericType.</typeparam>
public void TestWithGlobalStep<T>() where T : struct
{
var dtype = GetTypeForNumericType<T>();

tf.Graph().as_default();
using (var sess = self.cached_session())
{
// Step counter starts at 0 and is created with trainable: false.
var global_step = tf.Variable(0, trainable: false);
var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype);
var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype);
var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype);
var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype);
// Gradients are supplied as constants rather than computed, paired with the variable each one updates.
var grads_and_vars = new[] {
Tuple.Create(grads0, var0 as IVariableV1),
Tuple.Create(grads1, var1 as IVariableV1)
};
var sgd_op = tf.train.GradientDescentOptimizer(3.0f)
.apply_gradients(grads_and_vars, global_step: global_step);

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@AsakusaRinne why does apply_gradiens take System.Tuple while zip produces (T1, T2). Would it be better to replace or extend apply_gradients interface to support valuetuple as well?

sess.run(tf.global_variables_initializer());
// Fetch params to validate initial values
self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate<T[]>(var1));
// Run 1 step of sgd
sgd_op.run();
// Validate updated params and global_step
// Expected values follow var - learning_rate * grad with learning_rate = 3.0.
self.assertAllCloseAccordingToType(new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, self.evaluate<T[]>(var0));
self.assertAllCloseAccordingToType(new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, self.evaluate<T[]>(var1));
Assert.AreEqual(1, self.evaluate<int>(global_step));
}

}

/// <summary>
/// Test entry point: runs the generic global-step check once for float and once for double.
/// </summary>
[TestMethod]
public void TestWithGlobalStep()
{
TestWithGlobalStep<float>();
TestWithGlobalStep<double>();
}
}
}
45 changes: 36 additions & 9 deletions test/Tensorflow.UnitTest/PythonTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ public int Compare(object? x, object? y)
return 1;
}

var a = (double)x;
var b = (double)y;
var a = Convert.ToDouble(x);
var b = Convert.ToDouble(y);

double delta = Math.Abs(a - b);
if (delta < _epsilon)
Expand All @@ -187,6 +187,19 @@ public int Compare(object? x, object? y)
}
}

/// <summary>
/// Two-dimensional overload: asserts that <paramref name="expected"/> and
/// <paramref name="given"/> have the same length in both dimensions, then flattens
/// <paramref name="given"/> and forwards to the collection-based overload.
/// </summary>
/// <typeparam name="T">Element type of the array under test.</typeparam>
/// <param name="expected">Expected values as a rectangular array of doubles.</param>
/// <param name="given">Actual values to check.</param>
/// <param name="eps">Tolerance forwarded unchanged to the collection-based overload.</param>
/// <param name="float_eps">Tolerance forwarded unchanged to the collection-based overload.</param>
public void assertAllCloseAccordingToType<T>(
    double[,] expected,
    T[,] given,
    double eps = 1e-6,
    float float_eps = 1e-6f)
{
    // Compare the shape one dimension at a time: rows first, then columns.
    for (var dimension = 0; dimension < 2; dimension++)
    {
        Assert.AreEqual(expected.GetLength(dimension), given.GetLength(dimension));
    }

    // Cast<T>() enumerates the rectangular array element by element, giving a flat T[].
    T[] flattened = given.Cast<T>().ToArray();
    assertAllCloseAccordingToType(expected, flattened, eps, float_eps);
}

public void assertAllCloseAccordingToType<T>(
ICollection expected,
ICollection<T> given,
Expand Down Expand Up @@ -267,21 +280,35 @@ public T evaluate<T>(Tensor tensor)
{
var sess = tf.get_default_session();
var ndarray = tensor.eval(sess);
if (typeof(T) == typeof(double)
|| typeof(T) == typeof(float)
|| typeof(T) == typeof(int))

if (typeof(T) == typeof(int))
{
int i = ndarray;
result = i;
}
else if (typeof(T) == typeof(float))
{
float f = ndarray;
result = f;
}
else if (typeof(T) == typeof(double))
{
result = Convert.ChangeType(ndarray, typeof(T));
double d = ndarray;
result = d;
}
else if (typeof(T) == typeof(double[]))
else if (
typeof(T) == typeof(double[])
|| typeof(T) == typeof(double[,]))
{
result = ndarray.ToMultiDimArray<double>();
}
else if (typeof(T) == typeof(float[]))
else if (typeof(T) == typeof(float[])
|| typeof(T) == typeof(float[,]))
{
result = ndarray.ToMultiDimArray<float>();
}
else if (typeof(T) == typeof(int[]))
else if (typeof(T) == typeof(int[])
|| typeof(T) == typeof(int[,]))
{
result = ndarray.ToMultiDimArray<int>();
}
Expand Down
Loading