This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

New operators linalg_syrk, linalg_gelqf #7741

Merged: 1 commit merged into apache:master on Sep 6, 2017

Conversation

@mseeger (Contributor) commented Sep 5, 2017:

  • Works for CPU only for now; we will supply the GPU versions next.
  • Added dtype parameters to the unit-test support code, so we can run numerical tests in float64.

@asmushetzel


Examples::

// Single matrix multiply
A = [[1.0, 1.0], [1.0, 1.0]]
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]
gemm2(A, B, transpose_b = 1, alpha = 2.0)
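For reference, the docstring example above computes alpha * A * B^T. A minimal plain-Python sketch of those semantics (the real operator is MXNet's `linalg` gemm2; this reference implementation only mirrors the math):

```python
# Reference sketch of the gemm2 docstring example: alpha * A @ (B^T).
# This is NOT the MXNet operator itself, just the math it performs.

def gemm2(A, B, transpose_b=False, alpha=1.0):
    """Return alpha * A @ (B^T if transpose_b else B), lists-of-lists."""
    if transpose_b:
        B = [list(col) for col in zip(*B)]  # transpose B
    n, k, m = len(A), len(B), len(B[0])
    assert len(A[0]) == k, "inner dimensions must match"
    return [[alpha * sum(A[i][t] * B[t][j] for t in range(k))
             for j in range(m)] for i in range(n)]

A = [[1.0, 1.0], [1.0, 1.0]]               # shape (2, 2)
B = [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]   # shape (3, 2)
print(gemm2(A, B, transpose_b=True, alpha=2.0))
# -> [[4.0, 4.0, 4.0], [4.0, 4.0, 4.0]]
```

Each entry is 2.0 * (1*1 + 1*1) = 4.0, and the result shape (2, 3) comes from A (2, 2) times B^T (2, 3).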

Apparently some general renaming has happened recently for namespace reasons. All documentation now uses function names without the linalg_ prefix.

@mseeger (Contributor, Author) replied:

OK, I will change the docstrings.

@@ -251,42 +253,42 @@ NNVM_REGISTER_OP(_backward_linalg_potri)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 3, 1, potri_backward>);

-NNVM_REGISTER_OP(_linalg_trmm)
+NNVM_REGISTER_OP(linalg_trmm)
@asmushetzel (Contributor) commented:

All these operators are now registered with an initial underscore followed by an alias for backward compatibility.

@mseeger (Contributor, Author) replied:

OK, I will do this.

@@ -420,5 +421,134 @@ NNVM_REGISTER_OP(_backward_linalg_sumlogdiag)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FCompute>("FCompute<cpu>", LaOpBackward<cpu, 2, 2, 2, 1, sumlogdiag_backward>);

NNVM_REGISTER_OP(linalg_syrk)
@asmushetzel (Contributor) commented:

Initial underscore followed by an alias, as for the other new operators.

@mseeger (Contributor, Author) replied:

OK, will do this.

-if ( ndim < dim ) {
-  return false;
-}
+CHECK_GE(ndim, dim) << "Shape of input has too few dimensions";
@asmushetzel (Contributor) commented:

Please leave the original code. We may have ndim < dim when the input shape has not yet been defined at all (but may be defined later). That means we cannot do any inference at this point, but not necessarily that we never can (a subsequent pass may succeed).
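The deferred-inference pattern the reviewer describes can be sketched as follows. This is a hypothetical Python illustration only (the real code is C++ inside MXNet's shape-inference machinery); the point is that returning False means "cannot infer yet", not "invalid":

```python
# Hypothetical sketch of deferred shape inference. Returning False signals
# "not enough information yet" so a later inference pass can retry once the
# input shape becomes known, instead of failing hard with a CHECK.

def infer_shape(in_shape, required_dim):
    ndim = len(in_shape)
    if ndim < required_dim:
        # The shape may simply be undefined at this point; defer.
        return False
    # ... actual shape inference would run here ...
    return True

print(infer_shape([], 2))      # -> False (shape unknown: defer)
print(infer_shape([4, 3], 2))  # -> True
```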

// CPU/GPU-versions of LAPACK functions "gelqf", "orglq". Please refer to the
// LAPACK documentation for further details.
// Note:
// - The current implementation works for CPU only. In particular, when called
@asmushetzel (Contributor) commented:

There are no separate batch-mode functions here, so this comment should neither refer to batch-mode functions nor should we add them.
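For context on what gelqf does: it computes the LQ factorization A = L * Q of an m x n matrix (m <= n), with L lower-triangular and Q having orthonormal rows. A minimal pure-Python sketch via Gram-Schmidt on the rows, assuming full row rank (the actual operator delegates to LAPACK's gelqf/orglq; this only illustrates the math):

```python
import math

# Minimal LQ factorization sketch: A = L * Q, L lower-triangular,
# rows of Q orthonormal. Full row rank assumed; illustration only.

def lq(A):
    m, n = len(A), len(A[0])
    Q, L = [], [[0.0] * m for _ in range(m)]
    for i in range(m):
        v = list(A[i])
        for j in range(i):                        # subtract projections onto earlier Q rows
            L[i][j] = sum(A[i][t] * Q[j][t] for t in range(n))
            v = [v[t] - L[i][j] * Q[j][t] for t in range(n)]
        L[i][i] = math.sqrt(sum(x * x for x in v))
        Q.append([x / L[i][i] for x in v])        # normalized orthogonal row
    return L, Q

L, Q = lq([[3.0, 4.0], [1.0, 0.0]])
print(L[0][0])  # -> 5.0 (the norm of the first row [3, 4])
```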

@@ -526,6 +526,8 @@ Composite multiple symbols into a new one by an operator.
linalg_trmm
linalg_trsm
linalg_sumlogdiag
linalg_syrk
@asmushetzel (Contributor) commented:

Wrong indentation.

@mseeger (Contributor, Author) replied:

Done.

@piiswrong (Contributor) commented:

What's the status on this?

@mseeger (Contributor, Author) commented Sep 6, 2017:

I managed to run all unit tests in continuous-integration/jenkins/pr-head, except for the R ones. They fail for odd reasons that seem to have nothing to do with my changes.

I am trying again.

@mseeger (Contributor, Author) commented Sep 6, 2017:

Unit tests for esoteric APIs (Perl, R, Scala) fail with:

Remote call on mxnet14 failed

I fail to see what this has to do with my code changes. Not interested in these APIs (or should I be?).

@asmushetzel (Contributor) commented:

The R/Perl unit tests were failing because of a git issue. All other ones passed. I did a full review, and IMO all is fine now.

There will be a followup PR by me in a week or so that will bring in the CUDA support for these new operators.

So from my point of view, this can be integrated. Just would like to have confirmation from Matthias that the documentation changes look as expected when formatted.

@mseeger (Contributor, Author) commented Sep 6, 2017:

Yes, I checked all docstrings in the tool you sent me; they all look fine.

@piiswrong piiswrong merged commit 541158d into apache:master Sep 6, 2017
@asmushetzel (Contributor) commented:

Thank you, Eric, for the incredible turnaround time.

@@ -526,6 +526,8 @@ Composite multiple symbols into a new one by an operator.
linalg_trmm
linalg_trsm
linalg_sumlogdiag
linalg_syrk
linalg_gelqf
A Member commented:

I assume ndarray.md should be updated, too?

@mseeger mseeger deleted the mseeger-linalg-newops branch September 7, 2017 07:42
const Tensor<cpu, 2, DType>& B, DType alpha, \
DType beta, bool tA, Stream<cpu> *s) { \
check_syrk(A, B, alpha, beta, tA); \
cblas_##fname(CblasRowMajor, CblasLower, (tA ? CblasTrans : CblasNoTrans), \
@iblislin (Member) commented:

Hi, MXNet.jl got a compilation error when built from MXNet's master:
https://travis-ci.org/dmlc/MXNet.jl/jobs/272939237#L1213

any idea?

src/operator/contrib/./../linalg_impl.h:782:1: note: in expansion of macro 'LINALG_CPU_SYRK'
 LINALG_CPU_SYRK(ssyrk, float)
src/operator/contrib/./../linalg_impl.h: In function 'void linalg_syrk(const mshadow::Tensor<Device, 2, DType>&, const mshadow::Tensor<Device, 2, DType>&, DType, DType, bool, mshadow::Stream<Device>*) [with xpu = mshadow::cpu; DType = double]':
src/operator/contrib/./../linalg_impl.h:748:61: error: 'cblas_dsyrk' was not declared in this scope
                 A.dptr_, A.stride_, beta, B.dptr_, B.stride_); \
src/operator/contrib/./../linalg_impl.h:783:1: note: in expansion of macro 'LINALG_CPU_SYRK'
 LINALG_CPU_SYRK(dsyrk, double)

@mseeger (Contributor, Author) replied:

Hello, {s|d}syrk is a BLAS function, which can be called as cblas_{s|d}syrk via the CBLAS interface. For whatever reason, this does not work in your particular setup. What is quite odd is that linalg_impl.h calls a range of other cblas_XXX functions as well: cblas_*gemm, cblas_*trmm, cblas_*trsm (where * is s or d).

AFAIK, CBLAS has all of these AND cblas_*syrk.

Are you getting errors for the other cblas calls as well?

Maybe @asmushetzel has an idea what is going on here?
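For readers unfamiliar with syrk: the BLAS routine performs a symmetric rank-k update. A pure-Python reference sketch of the math (here computing the full matrix; the actual BLAS routine only writes the lower or upper triangle):

```python
# Reference sketch of what BLAS syrk computes:
#   B = alpha * A @ A.T + beta * B    (tA = False)
#   B = alpha * A.T @ A + beta * B    (tA = True)
# Illustration of the math only, not the BLAS calling convention.

def syrk(A, B, alpha=1.0, beta=0.0, tA=False):
    M = [list(col) for col in zip(*A)] if tA else A
    n, k = len(M), len(M[0])
    return [[alpha * sum(M[i][t] * M[j][t] for t in range(k)) + beta * B[i][j]
             for j in range(n)] for i in range(n)]

A = [[1.0, 2.0], [3.0, 4.0]]
B = [[0.0, 0.0], [0.0, 0.0]]
print(syrk(A, B, alpha=1.0))  # -> [[5.0, 11.0], [11.0, 25.0]]
```

The result is symmetric by construction, which is why BLAS only needs to touch one triangle.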

@iblislin (Member) commented Sep 8, 2017:

I finally figured out that MXNet.jl needs a modified cblas.h: https://github.com/dmlc/MXNet.jl/blob/master/deps/cblas.h

I will send a patch for that. @mseeger thanks for your time and explanation!

@mseeger (Contributor, Author) replied:

It seems that your cblas.h has been edited, right? This is not good. Maybe find out why Julia is doing that? We may rely on further BLAS functions in the future, so it would be better to make sure the correct cblas.h is used.

@asmushetzel (Contributor) commented Sep 8, 2017:

Looks as if we don't have a Julia build being tested on Jenkins; that's why such things slip through. Apparently the Julia build maintains its own hand-written cblas include file. Chris, can you drive sufficient automatic testing for this build?

@mseeger (Contributor, Author) commented Sep 8, 2017:

This was also my thought. Why do they edit cblas.h? Just to save some compile time?

cjolivier01 pushed a commit to cjolivier01/mxnet that referenced this pull request Sep 11, 2017
mbaijal pushed a commit to mbaijal/incubator-mxnet that referenced this pull request Sep 19, 2017
mbaijal pushed a commit to mbaijal/incubator-mxnet that referenced this pull request Sep 19, 2017
mbaijal pushed a commit to mbaijal/incubator-mxnet that referenced this pull request Sep 20, 2017
@iblislin iblislin mentioned this pull request Oct 8, 2017
crazy-cat pushed a commit to crazy-cat/incubator-mxnet that referenced this pull request Oct 26, 2017
5 participants