Skip to content

Commit

Permalink
Reorder loops in im2col_2d_cl given resource strategy issue. Reenable…
Browse files Browse the repository at this point in the history
… relevant test. Use 5000 MNIST samples rather than full dataset for faster testing
  • Loading branch information
thesps committed Nov 10, 2021
1 parent 7f75add commit fa1ff24
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
13 changes: 5 additions & 8 deletions hls4ml/templates/vivado/nnet_utils/nnet_conv2d_resource.h
Original file line number Diff line number Diff line change
Expand Up @@ -154,20 +154,18 @@ void im2col_2d_cl(
const int col)
{
int index = 0;
for (int channel = CONFIG_T::n_chan; channel--; data++) {
for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) {
#pragma HLS UNROLL
for (int kernel_row = 0; kernel_row < CONFIG_T::filt_height; kernel_row++) {
int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height;
for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) {
int input_row = -CONFIG_T::pad_top + kernel_row * CONFIG_T::dilation_height + row * CONFIG_T::stride_height;
for (int kernel_col = 0; kernel_col < CONFIG_T::filt_width; kernel_col++) {
for (int channel = 0; channel < CONFIG_T::n_chan; channel++) {
if (input_row < 0 || input_row >= CONFIG_T::in_height) {
data_col[index++] = 0;
} else {
int input_col = -CONFIG_T::pad_left + kernel_col * CONFIG_T::dilation_width + col * CONFIG_T::stride_width;
if (input_col >= 0 && input_col < CONFIG_T::in_width) {
//*(data_col++) = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan];
data_col[index++] = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan];
data_col[index++] = data[input_row * CONFIG_T::in_width * CONFIG_T::n_chan + input_col * CONFIG_T::n_chan + channel];
} else {
//*(data_col++) = 0;
data_col[index++] = 0;
}
}
Expand Down Expand Up @@ -209,7 +207,6 @@ void conv_2d_resource_cl(
FiltLoop:
for (int k = 0; k < CONFIG_T::n_filt; k++) {
res[i * CONFIG_T::out_width * CONFIG_T::n_filt + j * CONFIG_T::n_filt + k] = res_col[k];
//res[k * CONFIG_T::out_height * CONFIG_T::out_width + i * CONFIG_T::out_width + j] = res_col[k]; // Transposed order
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions test/pytest/test_cnn_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,9 @@ def mnist_model():
model.load_weights('../../example-models/keras/qkeras_mnist_cnn_weights.h5')
return model

# TODO: add ('io_parallel', 'resource') when it can pass
# https://github.com/fastmachinelearning/hls4ml/issues/375
@pytest.fixture
@pytest.mark.parametrize('settings', [('io_parallel', 'latency'),
('io_parallel', 'resource'),
('io_stream', 'latency'),
('io_stream', 'resource')])
def hls_model(settings):
Expand All @@ -49,10 +48,12 @@ def hls_model(settings):
return hls_model

@pytest.mark.parametrize('settings', [('io_parallel', 'latency'),
('io_parallel', 'resource'),
('io_stream', 'latency'),
('io_stream', 'resource')])
def test_accuracy(mnist_data, mnist_model, hls_model):
x_train, y_train, x_test, y_test = mnist_data
x_test, y_test = x_test[:5000], y_test[:5000]
model = mnist_model
# model under test predictions and accuracy
y_keras = model.predict(x_test)
Expand Down

0 comments on commit fa1ff24

Please sign in to comment.