[ TEST ] add torch input and output test data for mixed precision
This PR adds torch mixed-precision golden data generation, along with the
input and output data used by the tests.

Also includes some fixes to the tests.

Resolves:

**Self evaluation:**
1. Build test:	 [X]Passed [ ]Failed [ ]Skipped
2. Run test:	 [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: jijoong.moon <[email protected]>
jijoongmoon committed May 24, 2024
1 parent 7bd4989 commit 8030518
Showing 23 changed files with 159 additions and 24 deletions.
1 change: 1 addition & 0 deletions debian/nntrainer-dev.install
@@ -16,6 +16,7 @@
/usr/include/nntrainer/blas_interface.h
/usr/include/nntrainer/var_grad.h
/usr/include/nntrainer/weight.h
/usr/include/nntrainer/blas_avx.h
# todo: update dataset headers
/usr/include/nntrainer/databuffer.h
/usr/include/nntrainer/databuffer_factory.h
2 changes: 1 addition & 1 deletion meson_options.txt
@@ -40,7 +40,7 @@ option('enable-fp16', type: 'boolean', value: false)
option('enable-cublas', type: 'boolean', value: false)
option('enable-openmp', type: 'boolean', value: true)
option('enable-neon', type: 'boolean', value: false)
option('enable-avx', type: 'boolean', value: false)
option('enable-avx', type: 'boolean', value: true)
option('enable-opencl', type: 'boolean', value: false)

# ml-api dependency (to enable, install capi-inference from github.com/nnstreamer/api )
9 changes: 8 additions & 1 deletion nntrainer/graph/network_graph.cpp
@@ -624,8 +624,15 @@ void NetworkGraph::addLayer(std::shared_ptr<LayerNode> layer) {

InPlace
NetworkGraph::canExecuteInPlace(const std::shared_ptr<LayerNode> &lnode) {
if (!lnode->supportInPlace())

if (!lnode->supportInPlace()) {
return InPlace::NONE;
}

if (lnode->getType() == InputLayer::type &&
!istrequal(getTensorType()[2], "FP32")) {
return InPlace::NONE;
}

/** layers which behave as a no-op - flatten */
auto no_op = [](const std::shared_ptr<LayerNode> &lnode) {
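
The new check above keeps an input layer out of in-place execution whenever the model's activation tensor type (the third entry of getTensorType()) is not FP32, since an FP32 input buffer cannot share memory with an FP16 activation tensor. A minimal stand-alone sketch of that decision, using hypothetical names rather than the nntrainer API:

#include <iostream>
#include <string>

enum class InPlace { NONE, RESTRICTING, NON_RESTRICTING };

// Hypothetical stand-in for the graph-level decision above: an input layer
// may only run in place when the activation data type matches its FP32 buffer.
InPlace canInputLayerRunInPlace(const std::string &activation_type) {
  return activation_type == "FP32" ? InPlace::NON_RESTRICTING : InPlace::NONE;
}

int main() {
  std::cout << (canInputLayerRunInPlace("FP16") == InPlace::NONE) << '\n'; // 1
  std::cout << (canInputLayerRunInPlace("FP32") == InPlace::NONE) << '\n'; // 0
  return 0;
}
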
7 changes: 5 additions & 2 deletions nntrainer/layers/input_layer.cpp
@@ -33,7 +33,9 @@ namespace nntrainer {
static constexpr size_t SINGLE_INOUT_IDX = 0;

InputLayer::InputLayer() :
Layer(), input_props(props::Normalization(), props::Standardization()) {}
Layer(),
input_props(props::Normalization(), props::Standardization()),
is_inplace(true) {}

void InputLayer::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, input_props);
@@ -82,8 +84,9 @@ void InputLayer::finalize(InitLayerContext &context) {
* activation data type is not fp32, then it does not support in-place
* operation.
*/
if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32)
if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32) {
is_inplace = false;
}
}

} /* namespace nntrainer */
3 changes: 3 additions & 0 deletions nntrainer/layers/layer_node.cpp
@@ -475,6 +475,9 @@ void LayerNode::read(std::ifstream &file, bool opt_var) {
/// @note shared weights are only read at the first access
if (run_context->isGradientLastAccess(i)) {
run_context->getWeight(i).read(file);
if (run_context->isMixedPrecision(i) && getTrainable()) {
run_context->getWeightFP32(i).copyData(run_context->getWeight(i));
}
}
}
}
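
Mixed-precision training keeps an FP32 master copy of each trainable low-precision weight; the change above refreshes that copy when weights are read from a file, so the optimizer keeps updating full-precision values. A rough self-contained sketch of the idea, with hypothetical types standing in for nntrainer's Tensor and _FP16:

#include <cstddef>
#include <vector>

// Hypothetical 16-bit stand-in; nntrainer uses _FP16 and Tensor::copyData().
struct half16 { float v; };
inline float to_float(half16 h) { return h.v; }

struct MixedWeight {
  std::vector<half16> var;  // low-precision weight used in forward/backward
  std::vector<float> var32; // FP32 master copy updated by the optimizer
};

// After reading the low-precision weight from a checkpoint, refresh the
// master copy so the next optimizer step starts from the loaded values.
inline void syncMasterCopy(MixedWeight &w) {
  w.var32.resize(w.var.size());
  for (std::size_t i = 0; i < w.var.size(); ++i)
    w.var32[i] = to_float(w.var[i]);
}
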
4 changes: 4 additions & 0 deletions nntrainer/layers/loss/loss_layer.cpp
@@ -24,6 +24,10 @@ void LossLayer::finalize(InitLayerContext &context) {
nntrainer::TensorDataTypeInfo>::from_string("FP32"));

context.setOutputDimensions(output_dim);

is_inplace = true;
if (context.getActivationDataType() != ml::train::TensorDim::DataType::FP32)
is_inplace = false;
}

void LossLayer::updateLoss(RunLayerContext &context, const Tensor &l) {
4 changes: 4 additions & 0 deletions nntrainer/layers/loss/loss_layer.h
@@ -47,6 +47,8 @@ class LossLayer : public Layer {
*/
virtual bool supportBackwarding() const override { return true; }

bool supportInPlace() const override { return is_inplace; }

/**
* @copydoc Layer::requireLabel()
*/
@@ -69,6 +71,8 @@

Tensor
l; /**< loss tensor to store intermediate value to calculate loss value */

bool is_inplace;
};

} // namespace nntrainer
4 changes: 2 additions & 2 deletions nntrainer/layers/loss/mse_loss_layer.cpp
@@ -61,11 +61,11 @@ void MSELossLayer::calcDerivative(RunLayerContext &context) {
if (ret_derivative.empty())
ret_derivative = context.getOutgoingDerivative(SINGLE_INOUT_IDX)
.clone(ml::train::TensorDim::DataType::FP32);

Tensor empty_tensor1;
Tensor &y = context.getInput(SINGLE_INOUT_IDX).getDataType() ==
ml::train::TensorDim::DataType::FP32
? context.getInput(SINGLE_INOUT_IDX)
: empty_tensor;
: empty_tensor1;

if (y.empty())
y = context.getInput(SINGLE_INOUT_IDX)
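
The rename above fixes what looks like an aliasing problem: when the activations are not FP32, both FP32 fallbacks in calcDerivative previously bound to the same scratch empty_tensor, so filling one clone also changed what the other reference saw. Giving the second fallback its own empty_tensor1 keeps the clones independent. A small self-contained illustration of the aliasing, with std::vector standing in for Tensor:

#include <cassert>
#include <vector>

using FakeTensor = std::vector<float>; // hypothetical stand-in for Tensor

int main() {
  FakeTensor empty_tensor;

  // Both references bind to the same scratch object...
  FakeTensor &ret_derivative = empty_tensor;
  FakeTensor &y = empty_tensor;
  ret_derivative = {1.f, 2.f}; // first FP32 clone
  y = {3.f, 4.f};              // second clone silently overwrites the first
  assert(ret_derivative == y); // aliased: the derivative values are gone

  // ...whereas a separate scratch tensor per fallback keeps them independent.
  FakeTensor empty_tensor1;
  FakeTensor &d = empty_tensor;
  FakeTensor &label = empty_tensor1;
  d = {1.f, 2.f};
  label = {3.f, 4.f};
  assert(d != label);
  return 0;
}
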
1 change: 1 addition & 0 deletions nntrainer/layers/loss/mse_loss_layer.h
@@ -51,6 +51,7 @@ class MSELossLayer : public LossLayer {
const std::string getType() const override { return MSELossLayer::type; };

inline static const std::string type = "mse";

};
} // namespace nntrainer

4 changes: 3 additions & 1 deletion nntrainer/optimizers/adam.cpp
@@ -79,8 +79,10 @@ void Adam::applyGradient(RunOptimizerContext &context) {
? context.getGradient()
: empty_tensor;

if (x_grad.empty())
if (x_grad.empty()) {
x_grad = context.getGradient().clone(ml::train::TensorDim::DataType::FP32);
context.applyLossScale(x_grad);
}

auto &beta1 = std::get<PropsB1>(adam_props).get();
auto &beta2 = std::get<PropsB2>(adam_props).get();
13 changes: 13 additions & 0 deletions nntrainer/optimizers/optimizer_context.cpp
@@ -49,4 +49,17 @@ void RunOptimizerContext::applyGradient(double lr) const {
void RunOptimizerContext::applyGradient(double lr, Tensor &updated_grad) const {
weight->applyGradient(lr, updated_grad);
}

/**
* @brief Apply loss scale to gradient (full precision)
*/
void RunOptimizerContext::applyLossScale(Tensor &fp32_grad) {
if (!weight->isMixedPrecision())
return;
if (fp32_grad.getDataType() != ml::train::TensorDim::DataType::FP32)
throw std::invalid_argument(
"gradient should be full precision to maintain accuracy");
float loss_scale = weight->getLossScale();
fp32_grad.divide_i(loss_scale);
}
} // namespace nntrainer
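
Together with the Adam change above and the removal of the divide from Weight::applyGradient further down, gradient unscaling now happens in one place, on the full-precision copy of the gradient, right before the optimizer update. The usual loss-scaling recipe: multiply the loss by loss_scale so small FP16 gradients do not flush to zero during backward, then divide the FP32 gradient by the same factor before applying it. A minimal self-contained sketch of that arithmetic (hypothetical helper, not the nntrainer API):

#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical helper mirroring RunOptimizerContext::applyLossScale():
// divide a full-precision gradient by the loss scale used during backward.
void applyLossScale(std::vector<float> &fp32_grad, float loss_scale) {
  for (float &g : fp32_grad)
    g /= loss_scale;
}

int main() {
  const float loss_scale = 65536.0f; // value used by the mixed-precision test below
  // Backward ran on loss * loss_scale, so the raw gradients are scaled up.
  std::vector<float> grad = {0.5f * loss_scale, -0.25f * loss_scale};

  applyLossScale(grad, loss_scale); // unscale once, before the optimizer step
  assert(std::fabs(grad[0] - 0.5f) < 1e-6f);
  assert(std::fabs(grad[1] + 0.25f) < 1e-6f);
  return 0;
}
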
5 changes: 5 additions & 0 deletions nntrainer/optimizers/optimizer_context.h
@@ -97,6 +97,11 @@ class RunOptimizerContext {
*/
double getLearningRate() const { return learning_rate; }

/**
* @brief Apply loss scale to gradient (full precision)
*/
void applyLossScale(Tensor &fp32_grad);

private:
Weight *weight; /**< weights for the optimizer */
size_t iteration; /**< iteration number */
5 changes: 4 additions & 1 deletion nntrainer/tensor/blas_interface.cpp
@@ -842,7 +842,10 @@ void scopy(const unsigned int N, const float *X, const int incX, float *Y,
#ifdef BLAS_NUM_THREADS
openblas_set_num_threads(BLAS_NUM_THREADS);
#endif
cblas_scopy(N, X, incX, Y, incY);
// cblas_scopy(N, (float*)(X), incX, (float*)(Y), incY);
// temporarily replace cblas_scopy with a raw copy loop.
for (unsigned int i = 0; i < N; ++i)
Y[i * incY] = X[i * incX];
#else
scopy_raw(N, X, incX, Y, incY);
#endif
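
The fallback above swaps cblas_scopy for a plain strided loop with the same semantics, Y[i * incY] = X[i * incX] for i in [0, N); the in-code comment only calls it a temporary replacement. A tiny self-contained illustration of that strided-copy behavior:

#include <cassert>

// Copy every incX-th element of X into every incY-th slot of Y, N times.
void scopyLike(unsigned int N, const float *X, int incX, float *Y, int incY) {
  for (unsigned int i = 0; i < N; ++i)
    Y[i * incY] = X[i * incX];
}

int main() {
  const float x[6] = {1, 2, 3, 4, 5, 6};
  float y[3] = {0, 0, 0};
  scopyLike(3, x, 2, y, 1); // picks x[0], x[2], x[4]
  assert(y[0] == 1 && y[1] == 3 && y[2] == 5);
  return 0;
}
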
1 change: 0 additions & 1 deletion nntrainer/tensor/weight.cpp
@@ -130,7 +130,6 @@ Weight::Weight(Tensor *v, Tensor *g, Tensor *v32, const WeightRegularizer reg,
void Weight::applyGradient(double lr, Tensor &updated_grad) {
if (isMixedPrecision() &&
updated_grad.getDataType() == ml::train::TensorDim::DataType::FP32) {
updated_grad.divide(loss_scale);
var32->add_i(updated_grad, -lr);
quantizeWeight();
return;
9 changes: 8 additions & 1 deletion nntrainer/tensor/weight.h
@@ -323,7 +323,7 @@ class Weight : public Var_Grad {
* @return false otherwise
*/
bool isMixedPrecision() const {
return var->getDataType() != ml::train::TensorDim::DataType::FP32;
return ((var->getDataType() != ml::train::TensorDim::DataType::FP32));
}

/**
@@ -356,6 +356,13 @@
*/
void setLossScale(float scale) { loss_scale = scale; };


/**
* @brief get loss scale
*
*/
const float getLossScale() { return loss_scale; };

private:
static constexpr float epsilon = 1e-6; /**< epsilon for zero comparison */
static constexpr float epsilon_decay =
Binary file modified packaging/unittest_layers.tar.gz
Binary file modified packaging/unittest_models_v3.tar.gz
23 changes: 23 additions & 0 deletions test/include/nntrainer_test_util.h
@@ -347,6 +347,29 @@ float mse(Ta *A, Tb *B, uint32_t size) {
return mse;
}

/**
 * @brief calculate mean squared error
*
* @param A const prediction data
* @param B const reference data
* @param size data size
 * @return mean squared error value
*/
template <typename Ta = float, typename Tb = float>
float mse(const Ta *A, const Tb *B, uint32_t size) {
float pred;
float ref;
float mse_error = 0;
for (uint32_t i = 0; i < size; i++) {
pred = A[i];
ref = B[i];
float diff = pred - ref;
mse_error += pow(diff, 2);
}
float mse = mse_error / size;
return mse;
}

/**
* @brief A helper struct for performing static_cast operations on types.
*
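
The added overload mirrors the existing mse() helper but takes const pointers, so read-only golden buffers can be compared without casts. A short usage sketch, assuming nntrainer_test_util.h is on the include path:

#include <cassert>
#include <cmath>

#include "nntrainer_test_util.h" // provides the mse() templates shown above

int main() {
  const float pred[3] = {1.0f, 2.0f, 3.0f};
  const float gold[3] = {1.0f, 2.5f, 3.0f};
  // (0^2 + 0.5^2 + 0^2) / 3 ≈ 0.0833
  float err = mse<float, float>(pred, gold, 3);
  assert(std::fabs(err - 0.25f / 3.0f) < 1e-6f);
  return 0;
}
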
33 changes: 25 additions & 8 deletions test/input_gen/recorder_v2.py
@@ -31,15 +31,32 @@


def _get_writer(file):
def write_fn(items, type = "int32"):
def write_fn(items, type = 'float32'):
if not isinstance(items, (list, tuple)):
items = [items]

for item in items:
print(item.numel(), " -0-----")
print(item)
np.array([item.numel()], dtype="int32").tofile(file)
a=np.array(item.detach().cpu(),dtype=type)
np.array([item.numel()], dtype='int32').tofile(file)
a=np.array(item.detach().cpu(), dtype=type)
a.tofile(file)
print(a.dtype)

return items

return write_fn

def _get_writer_mixed(file):
def write_fn(items, num_type = 'int32', type = 'float32'):
if not isinstance(items, (list, tuple)):
items = [items]

for item in items:
print(item.numel(), " -0-----")
print(item)
np.array([item.numel()], dtype=num_type).tofile(file)
a=np.array(item.detach().cpu(), dtype=type)
a.tofile(file)
print(a.dtype)

@@ -110,9 +127,9 @@ def record_iteration_with_amp(write_fn):
else:
inputs = _rand_like(input_dims, dtype=input_dtype if input_dtype is not None else float)
labels = _rand_like(label_dims, dtype=float)
write_fn(inputs[0])
write_fn(labels[0])
write_fn(list(t for _, t in params_translated(model_)),'float16')
write_fn(inputs,'int32', 'float32')
write_fn(labels, 'int32', 'float32')
write_fn(list(t for _, t in params_translated(model_)),'int16','float16')

output = model_(inputs[0], labels[0])

@@ -136,14 +153,14 @@ def record_iteration_with_amp(write_fn):
#

scaler.update()
write_fn(output)
write_fn(output,'int32','float32')

with open(file_name, "wb") as f:
# write number of iterations
print("iteration : ", iteration)
np.array([iteration], dtype="int32").tofile(f)

write_fn = _get_writer(f)
write_fn = _get_writer_mixed(f)
for _ in range(iteration):
record_iteration_with_amp(write_fn)

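
The writers above serialize each tensor as a small header-plus-payload record, an element count followed by the raw values, with a single int32 iteration count at the start of the file. Inputs, labels and outputs are written with an int32 count and a float32 payload, while the half-precision weights use an int16 count and a float16 payload (which is what the size-read note in nntrainer_test_util.cpp below refers to). A rough C++ reader for the float32 records, assuming exactly this count-then-values layout:

#include <cstddef>
#include <cstdint>
#include <fstream>
#include <stdexcept>
#include <vector>

// Read one float32 record written by the python writer above:
// [int32 element count][count * float32 values].
std::vector<float> readFp32Record(std::ifstream &file) {
  int32_t count = 0;
  file.read(reinterpret_cast<char *>(&count), sizeof(count));
  if (!file || count < 0)
    throw std::runtime_error("invalid golden data record header");
  std::vector<float> values(static_cast<std::size_t>(count));
  file.read(reinterpret_cast<char *>(values.data()),
            static_cast<std::streamsize>(count * sizeof(float)));
  if (!file)
    throw std::runtime_error("truncated golden data record");
  return values;
}
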
1 change: 1 addition & 0 deletions test/nntrainer_test_util.cpp
@@ -332,6 +332,7 @@ void sizeCheckedReadTensor(nntrainer::Tensor &t, std::ifstream &file,
nntrainer::checkedRead(file, (char *)&sz, sizeof(unsigned));
} else if (t.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
// This needs to be fixed. sz is always unsigned int type.
nntrainer::checkedRead(file, (char *)&sz, sizeof(_FP16));
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
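
The new comment flags a quirk: sz is an unsigned int, yet for FP16 tensors it is read with sizeof(_FP16), i.e. 2 bytes. That happens to line up with the recorder above, which writes the element count of FP16 weights as int16, but it is fragile. One possible cleanup, sketched here as an assumption rather than as part of this commit, is a fixed 4-byte count on both the writer and the reader side:

#include <cstdint>
#include <fstream>

// Hypothetical cleanup: read the element count with a fixed 4-byte width,
// independent of the tensor's data type (the writer must do the same).
uint32_t readElementCount(std::ifstream &file) {
  uint32_t sz = 0;
  file.read(reinterpret_cast<char *>(&sz), sizeof(sz));
  return sz;
}
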
48 changes: 45 additions & 3 deletions test/unittest/models/models_test_utils.cpp
@@ -50,8 +50,41 @@ static sharedConstTensors toSharedTensors(const std::vector<Tensor> &ts) {
static void verify(const nntrainer::Tensor &actual,
const nntrainer::Tensor &expected,
const std::string &error_msg) {
bool equal = false;

if (actual.getDataType() == ml::train::TensorDim::DataType::FP32 &&
expected.getDataType() == ml::train::TensorDim::DataType::FP32) {
equal = (actual == expected);
if (!equal) {
float mseError = mse<float>(actual.getData<float>(),
expected.getData<float>(), actual.size());
if (mseError > 1e-4) {
equal = false;
} else {
equal = true;
}
}
}

#ifdef ENABLE_FP16
if (!equal) {
if (actual.getDataType() == ml::train::TensorDim::DataType::FP16 &&
expected.getDataType() == ml::train::TensorDim::DataType::FP16) {
float mseError = mse<_FP16>(actual.getData<_FP16>(),
expected.getData<_FP16>(), actual.size());
if (mseError > 1e-2) {
equal = false;
} else {
equal = true;
}
}
}
#endif

if (!equal) {
nntrainer::Tensor diff = actual.subtract(expected);
const float *diff_data = diff.getData();

if (actual != expected) {
std::cout
<< "============================================================\n";
std::cout << "\033[1;33m" << error_msg << "\033[0m\n";
@@ -60,8 +93,6 @@
<< " - " << expected;

if (actual.getDim() == expected.getDim()) {
nntrainer::Tensor diff = actual.subtract(expected);
const float *diff_data = diff.getData();
std::cout << "\033[1;33mdifference\033[0m " << diff;
std::cout << "number of data: " << diff.size() << std::endl;
std::cout << "\033[4;33mMAX DIFF: "
@@ -119,6 +150,12 @@ class IterationForGolden {
}

Tensor &t = rc.getWeight(i);

if (t.getDataType() != ml::train::TensorDim::DataType::FP32) {
Tensor &t32 = rc.getWeightFP32(i);
weights32.push_back(t32);
}

weights.push_back(t);
expected_weights.push_back(t.clone());
}
@@ -158,6 +195,10 @@
} else {
for (unsigned int i = 0; i < weights.size(); ++i) {
weights.at(i).fill(expected_weights.at(i));
if (iteration == 0 &&
weights.at(i).getDataType() != ml::train::TensorDim::DataType::FP32)
weights32.at(i).fill(
weights.at(i).clone(ml::train::TensorDim::DataType::FP32));
}
}

@@ -174,6 +215,7 @@
std::vector<Tensor> inputs;
std::vector<Tensor> labels;
std::vector<Tensor> weights;
std::vector<Tensor> weights32;
std::vector<Tensor> expected_weights;
std::vector<Tensor> expected_outputs;
};
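
verify() used to demand bit-exact equality; with mixed precision the FP16 path accumulates rounding error, so the comparison now falls back to an MSE tolerance, 1e-4 for FP32 and a looser 1e-2 for FP16. A compact sketch of that acceptance rule (hypothetical helper, not the test-suite API):

// Hypothetical acceptance rule mirroring the verify() change above:
// exact match, or mean squared error below a per-precision tolerance.
bool outputsMatch(bool exact_equal, float mse_error, bool is_fp16) {
  if (exact_equal)
    return true;
  const float tol = is_fp16 ? 1e-2f : 1e-4f;
  return mse_error <= tol;
}
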
2 changes: 1 addition & 1 deletion test/unittest/models/unittest_models_mixed_precision.cpp
@@ -25,7 +25,7 @@ using namespace nntrainer;
static std::unique_ptr<NeuralNetwork> fc_mixed_training() {
std::unique_ptr<NeuralNetwork> nn(new NeuralNetwork());
nn->setProperty(
{"batch_size=1", "model_tensor_type=FP16-FP16", "loss_scale=128"});
{"batch_size=1", "model_tensor_type=FP16-FP16", "loss_scale=65536"});

auto graph = makeGraph({
{"input", {"name=in", "input_shape=1:1:3"}},