From d54f7a62e0e66dee73eff78ce5c93acb195ce813 Mon Sep 17 00:00:00 2001 From: Alexander Date: Mon, 13 Nov 2023 10:33:14 +0000 Subject: [PATCH 1/4] test: more gradients tests --- .../Training/GradientDescentOptimizerTests.cs | 113 ++++++++++++++++++ test/Tensorflow.UnitTest/PythonTest.cs | 45 +++++-- 2 files changed, 149 insertions(+), 9 deletions(-) diff --git a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs index d766890b..f7062f00 100644 --- a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs +++ b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs @@ -1,5 +1,6 @@ using Microsoft.VisualStudio.TestTools.UnitTesting; using System; +using System.Linq; using Tensorflow; using Tensorflow.NumPy; using static Tensorflow.Binding; @@ -67,6 +68,51 @@ public void TestBasic() TestBasic(); } + private void TestMinimizeResourceVariable() where T : struct + { + var dtype = GetTypeForNumericType(); + + // train.GradientDescentOptimizer is V1 only API. + tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var var0 = tf.Variable(new[,] { { 1.0f, 2.0f } }, dtype: dtype); + var var1 = tf.Variable(new[] { 3.0 }, dtype: dtype); + var x = tf.constant(new[,] { { 4.0f }, { 5.0f } }, dtype: dtype); + + var pred = math_ops.matmul(var0, x) + var1; + var loss = pred * pred; + var sgd_op = tf.train.GradientDescentOptimizer(3.0f).minimize(loss); + + var global_variables = tf.global_variables_initializer(); + sess.run(global_variables); + + sess.run(new[] { var0, var1 }); + // Fetch params to validate initial values + self.assertAllCloseAccordingToType(new[,] { { 1.0, 2.0 } }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0 }, self.evaluate(var1)); + // Run 1 step of sgd + sgd_op.run(); + // Validate updated params + var np_pred = 1.0 * 4.0 + 2.0 * 5.0 + 3.0; + var np_grad = 2 * np_pred; + self.assertAllCloseAccordingToType( + new[,] { { 1.0 - np_grad * 4.0, 2.0 - np_grad * 5.0 } }, + self.evaluate(var0)); + self.assertAllCloseAccordingToType( + new[] { 3.0 - np_grad }, + self.evaluate(var1)); + } + } + + [TestMethod] + public void TestMinimizeResourceVariable() + { + //TODO: add np.half + TestMinimizeResourceVariable(); + TestMinimizeResourceVariable(); + } + private void TestTensorLearningRate() where T : struct { var dtype = GetTypeForNumericType(); @@ -115,5 +161,72 @@ public void TestTensorLearningRate() TestTensorLearningRate(); TestTensorLearningRate(); } + + public void TestGradWrtRef() where T : struct + { + var dtype = GetTypeForNumericType(); + + var graph = tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var opt = tf.train.GradientDescentOptimizer(3.0f); + var values = new[] { 1.0, 3.0 }; + var vars_ = values.Select( + v => tf.Variable(new[] { v }, dtype: dtype) as IVariableV1 + ).ToList(); + var grads_and_vars = opt.compute_gradients(tf.add(vars_[0], vars_[1]), vars_); + sess.run(tf.global_variables_initializer()); + foreach (var (grad, _) in grads_and_vars) + self.assertAllCloseAccordingToType(new[] { 1.0 }, self.evaluate(grad)); + + } + } + + [TestMethod] + public void TestGradWrtRef() + { + TestGradWrtRef(); + TestGradWrtRef(); + } + + public void TestWithGlobalStep() where T : struct + { + var dtype = GetTypeForNumericType(); + + tf.Graph().as_default(); + using (var sess = self.cached_session()) + { + var global_step = tf.Variable(0, trainable: false); + var var0 = tf.Variable(new[] { 1.0, 2.0 }, dtype: dtype); + var var1 = tf.Variable(new[] { 3.0, 4.0 }, dtype: dtype); + var grads0 = tf.constant(new[] { 0.1, 0.1 }, dtype: dtype); + var grads1 = tf.constant(new[] { 0.01, 0.01 }, dtype: dtype); + var grads_and_vars = new[] { + Tuple.Create(grads0, var0 as IVariableV1), + Tuple.Create(grads1, var1 as IVariableV1) + }; + var sgd_op = tf.train.GradientDescentOptimizer(3.0f) + .apply_gradients(grads_and_vars, global_step: global_step); + + sess.run(tf.global_variables_initializer()); + // Fetch params to validate initial values + self.assertAllCloseAccordingToType(new[] { 1.0, 2.0 }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0, 4.0 }, self.evaluate(var1)); + // Run 1 step of sgd + sgd_op.run(); + // Validate updated params and global_step + self.assertAllCloseAccordingToType(new[] { 1.0 - 3.0 * 0.1, 2.0 - 3.0 * 0.1 }, self.evaluate(var0)); + self.assertAllCloseAccordingToType(new[] { 3.0 - 3.0 * 0.01, 4.0 - 3.0 * 0.01 }, self.evaluate(var1)); + Assert.AreEqual(1, self.evaluate(global_step)); + } + + } + + [TestMethod] + public void TestWithGlobalStep() + { + TestWithGlobalStep(); + TestWithGlobalStep(); + } } } diff --git a/test/Tensorflow.UnitTest/PythonTest.cs b/test/Tensorflow.UnitTest/PythonTest.cs index dff65293..1ccd39f0 100644 --- a/test/Tensorflow.UnitTest/PythonTest.cs +++ b/test/Tensorflow.UnitTest/PythonTest.cs @@ -175,8 +175,8 @@ public int Compare(object? x, object? y) return 1; } - var a = (double)x; - var b = (double)y; + var a = Convert.ToDouble(x); + var b = Convert.ToDouble(y); double delta = Math.Abs(a - b); if (delta < _epsilon) @@ -187,6 +187,19 @@ public int Compare(object? x, object? y) } } + public void assertAllCloseAccordingToType( + double[,] expected, + T[,] given, + double eps = 1e-6, + float float_eps = 1e-6f) + { + Assert.AreEqual(expected.GetLength(0), given.GetLength(0)); + Assert.AreEqual(expected.GetLength(1), given.GetLength(1)); + + var flattenGiven = given.Cast().ToArray(); + assertAllCloseAccordingToType(expected, flattenGiven, eps, float_eps); + } + public void assertAllCloseAccordingToType( ICollection expected, ICollection given, @@ -267,21 +280,35 @@ public T evaluate(Tensor tensor) { var sess = tf.get_default_session(); var ndarray = tensor.eval(sess); - if (typeof(T) == typeof(double) - || typeof(T) == typeof(float) - || typeof(T) == typeof(int)) + + if (typeof(T) == typeof(int)) + { + int i = ndarray; + result = i; + } + else if (typeof(T) == typeof(float)) + { + float f = ndarray; + result = f; + } + else if (typeof(T) == typeof(double)) { - result = Convert.ChangeType(ndarray, typeof(T)); + double d = ndarray; + result = d; } - else if (typeof(T) == typeof(double[])) + else if ( + typeof(T) == typeof(double[]) + || typeof(T) == typeof(double[,])) { result = ndarray.ToMultiDimArray(); } - else if (typeof(T) == typeof(float[])) + else if (typeof(T) == typeof(float[]) + || typeof(T) == typeof(float[,])) { result = ndarray.ToMultiDimArray(); } - else if (typeof(T) == typeof(int[])) + else if (typeof(T) == typeof(int[]) + || typeof(T) == typeof(int[,])) { result = ndarray.ToMultiDimArray(); } From b3ce158ec3304469bf776bc582b847e685a9df73 Mon Sep 17 00:00:00 2001 From: novikov-alexander <79649566+novikov-alexander@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:40:06 +0300 Subject: [PATCH 2/4] Update tensor_util.cs --- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 40 +++++++++++++------ 1 file changed, 27 insertions(+), 13 deletions(-) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index f688d4d5..f2003c9d 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -1,4 +1,4 @@ -/***************************************************************************** +/***************************************************************************** Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); @@ -135,6 +135,23 @@ T[] ExpandArrayToSize(IList src) TF_DataType.TF_QINT32 }; + private static TOut[,] ConvertArray2D(TIn[,] inputArray, Func converter) + { + var rows = inputArray.GetLength(0); + var cols = inputArray.GetLength(1); + var outputArray = new TOut[rows, cols]; + + for (var i = 0; i < rows; i++) + { + for (var j = 0; j < cols; j++) + { + outputArray[i, j] = converter(inputArray[i, j]); + } + } + + return outputArray; + } + /// /// Create a TensorProto, invoked in graph mode /// @@ -157,19 +174,16 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T else if(origin_dtype != dtype) { var new_system_dtype = dtype.as_system_dtype(); - if (values is long[] long_values) - { - if (dtype == TF_DataType.TF_INT32) - values = long_values.Select(x => (int)Convert.ChangeType(x, new_system_dtype)).ToArray(); - } - else if (values is double[] double_values) + + values = values switch { - if (dtype == TF_DataType.TF_FLOAT) - values = double_values.Select(x => (float)Convert.ChangeType(x, new_system_dtype)).ToArray(); - } - else - values = Convert.ChangeType(values, new_system_dtype); - + long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), + float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), + float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), + double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), + double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle), + _ => Convert.ChangeType(values, new_system_dtype), + }; dtype = values.GetDataType(); } From 18db147eb40a07931e8421bbd63c64ce11edd558 Mon Sep 17 00:00:00 2001 From: novikov-alexander <79649566+novikov-alexander@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:40:37 +0300 Subject: [PATCH 3/4] Update GradientDescentOptimizerTests.cs --- .../Training/GradientDescentOptimizerTests.cs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs index f7062f00..3b53ff9c 100644 --- a/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs +++ b/test/TensorFlowNET.UnitTest/Training/GradientDescentOptimizerTests.cs @@ -1,4 +1,4 @@ -using Microsoft.VisualStudio.TestTools.UnitTesting; +using Microsoft.VisualStudio.TestTools.UnitTesting; using System; using System.Linq; using Tensorflow; @@ -82,7 +82,7 @@ private void TestMinimizeResourceVariable() where T : struct var pred = math_ops.matmul(var0, x) + var1; var loss = pred * pred; - var sgd_op = tf.train.GradientDescentOptimizer(3.0f).minimize(loss); + var sgd_op = tf.train.GradientDescentOptimizer(1.0f).minimize(loss); var global_variables = tf.global_variables_initializer(); sess.run(global_variables); From 483ac82cd2db273c2c0520ce6923f5951638daba Mon Sep 17 00:00:00 2001 From: novikov-alexander <79649566+novikov-alexander@users.noreply.github.com> Date: Fri, 14 Jun 2024 15:02:17 +0300 Subject: [PATCH 4/4] Update tensor_util.cs --- src/TensorFlowNET.Core/Tensors/tensor_util.cs | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/TensorFlowNET.Core/Tensors/tensor_util.cs b/src/TensorFlowNET.Core/Tensors/tensor_util.cs index f2003c9d..873579e4 100644 --- a/src/TensorFlowNET.Core/Tensors/tensor_util.cs +++ b/src/TensorFlowNET.Core/Tensors/tensor_util.cs @@ -178,10 +178,15 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T values = values switch { long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(), + long[] longValues => values, float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(), + float[] floatValues => values, float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble), + float[,] float2DValues => values, double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(), - double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle), + double[] doubleValues => values, + double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle), + double[,] double2DValues => values, _ => Convert.ChangeType(values, new_system_dtype), }; dtype = values.GetDataType();