Skip to content

Commit

Permalink
Update tensor_util.cs
Browse files Browse the repository at this point in the history
  • Loading branch information
novikov-alexander authored Jun 14, 2024
1 parent b21a58a commit 483ac82
Showing 1 changed file with 6 additions and 1 deletion.
7 changes: 6 additions & 1 deletion src/TensorFlowNET.Core/Tensors/tensor_util.cs
Original file line number Diff line number Diff line change
Expand Up @@ -178,10 +178,15 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
values = values switch
{
long[] longValues when dtype == TF_DataType.TF_INT32 => longValues.Select(x => (int)x).ToArray(),
long[] longValues => values,
float[] floatValues when dtype == TF_DataType.TF_DOUBLE => floatValues.Select(x => (double)x).ToArray(),
float[] floatValues => values,
float[,] float2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(float2DValues, Convert.ToDouble),
float[,] float2DValues => values,
double[] doubleValues when dtype == TF_DataType.TF_FLOAT => doubleValues.Select(x => (float)x).ToArray(),
double[,] double2DValues when dtype == TF_DataType.TF_DOUBLE => ConvertArray2D(double2DValues, Convert.ToSingle),
double[] doubleValues => values,
double[,] double2DValues when dtype == TF_DataType.TF_FLOAT => ConvertArray2D(double2DValues, Convert.ToSingle),
double[,] double2DValues => values,
_ => Convert.ChangeType(values, new_system_dtype),
};
dtype = values.GetDataType();
Expand Down

0 comments on commit 483ac82

Please sign in to comment.