Machine Learning Provider Abstraction Layer #70
Thanks for bringing this up; we are very interested in this question. Providing unified APIs and data-loading procedures is one of the areas where we can add value compared to existing deep learning libraries. Data loading and processing is one of Spark's main strengths. Let us know your suggestions. Our current plan is to provide interfaces that various backends can implement. For the network, the interface would look like this:

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.types.StructType

    trait NetInterface {
      def forward(rowIt: Iterator[Row]): Array[Row]
      def forwardBackward(rowIt: Iterator[Row]): Unit
      def getWeights(): WeightCollection
      def setWeights(weights: WeightCollection): Unit
      def outputSchema(): StructType
    }

For the Solver:

    trait Solver {
      def step(rowIt: Iterator[Row]): Unit
    }

Data would be loaded in a unified way from Spark DataFrames. We are working on this in the javacpp+dataframes branch; see, for example, this file.
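To make the contract above concrete, here is a minimal sketch written in Python so it runs without Spark or SparkNet. Everything in it is hypothetical: rows are plain dicts standing in for Spark `Row`s, a `WeightCollection` is a dict mapping a name to a list of floats, `output_schema` returns column names instead of a `StructType`, and `LinearNet`/`SGDSolver` are toy backends (a one-feature linear model trained with plain gradient descent), not SparkNet classes.

```python
from abc import ABC, abstractmethod
from typing import Dict, Iterator, List

# Hypothetical stand-ins: a row is a dict of column name -> value, and a
# weight collection is a dict of parameter name -> flat list of floats.
Row = Dict[str, float]
WeightCollection = Dict[str, List[float]]


class NetInterface(ABC):
    """Python mirror of the proposed Scala trait (illustration only)."""

    @abstractmethod
    def forward(self, rows: Iterator[Row]) -> List[Row]: ...

    @abstractmethod
    def forward_backward(self, rows: Iterator[Row]) -> None: ...

    @abstractmethod
    def get_weights(self) -> WeightCollection: ...

    @abstractmethod
    def set_weights(self, weights: WeightCollection) -> None: ...

    @abstractmethod
    def output_schema(self) -> List[str]: ...


class Solver(ABC):
    @abstractmethod
    def step(self, rows: Iterator[Row]) -> None: ...


class LinearNet(NetInterface):
    """Toy backend: predicts w * x + b and computes squared-error gradients."""

    def __init__(self) -> None:
        self.weights: WeightCollection = {"w": [0.0], "b": [0.0]}
        self.grads: WeightCollection = {"w": [0.0], "b": [0.0]}

    def forward(self, rows):
        w, b = self.weights["w"][0], self.weights["b"][0]
        return [{"prediction": w * r["x"] + b} for r in rows]

    def forward_backward(self, rows):
        w, b = self.weights["w"][0], self.weights["b"][0]
        gw = gb = 0.0
        n = 0
        for r in rows:
            err = w * r["x"] + b - r["y"]
            gw += err * r["x"]
            gb += err
            n += 1
        # Mean gradient of 0.5 * err^2 over the batch.
        self.grads = {"w": [gw / max(n, 1)], "b": [gb / max(n, 1)]}

    def get_weights(self):
        # Return copies so callers cannot mutate the net's state by accident.
        return {k: list(v) for k, v in self.weights.items()}

    def set_weights(self, weights):
        self.weights = {k: list(v) for k, v in weights.items()}

    def output_schema(self):
        return ["prediction"]


class SGDSolver(Solver):
    """Toy solver: one forward/backward pass, then a gradient-descent update."""

    def __init__(self, net: LinearNet, lr: float) -> None:
        self.net, self.lr = net, lr

    def step(self, rows):
        self.net.forward_backward(rows)
        updated = {
            k: [p - self.lr * g for p, g in zip(v, self.net.grads[k])]
            for k, v in self.net.get_weights().items()
        }
        self.net.set_weights(updated)
```

Because `Solver.step` receives only rows, the solver in this sketch holds a reference to its net and reads the gradients from it after `forward_backward`; how a real backend exposes gradients to its solver is left open by the proposed traits.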
Awesome! From our perspective, this may be considerably less complicated than we anticipated. I am interested in trying out that trait with another ML library. Which test(s) would you suggest running to best validate the usability of your NetInterface with the OtherMlLibrary framework?
Thanks, the least complicated approaches are often the best. I can sketch how we plan to implement the interface for TensorFlow. Assume you have a TensorFlow graph definition like this (in Python):

    import tensorflow as tf
    sess = tf.InteractiveSession()
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.nn.softmax(tf.matmul(x, W) + b)
    tf.initialize_all_variables().run()

You can then serialize the graph and the weights in the following way:
SparkNet would provide a TensorFlowNet class that implements the Net trait and takes as its constructor argument the protocol buffer definition generated by g.as_graph_def(). Furthermore, there would be a procedure for loading the weights saved by tf.train.Saver into a WeightCollection object, and an implementation of setWeights that loads those weights into the network. If you are interested in pursuing this, you can start from the JavaCPP TensorFlow implementation and implement the TensorFlowNet and TensorFlowSolver classes. This is a high priority for us, but we would like to improve a few other things before we get to it.
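As a rough illustration of what a TensorFlowNet would compute for the graph above, the following pure-Python sketch evaluates y = softmax(xW + b) and implements `set_weights`/`get_weights` with shape checks. It is hypothetical throughout: the class name `SoftmaxRegressionNet`, the weight keys "W" and "b", and the dict format standing in for a WeightCollection are assumptions; it does not call TensorFlow or JavaCPP, and the step of reading a tf.train.Saver checkpoint into the dict is not shown.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def softmax(z: List[float]) -> List[float]:
    # Subtract the max before exponentiating for numerical stability.
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    total = sum(exps)
    return [e / total for e in exps]


class SoftmaxRegressionNet:
    """Hypothetical pure-Python stand-in for y = softmax(x W + b)."""

    def __init__(self, n_in: int, n_out: int) -> None:
        self.n_in, self.n_out = n_in, n_out
        # Zero initialization, mirroring tf.zeros in the graph definition.
        self.W: Matrix = [[0.0] * n_out for _ in range(n_in)]
        self.b: List[float] = [0.0] * n_out

    def set_weights(self, weights: Dict[str, list]) -> None:
        W, b = weights["W"], weights["b"]
        # Validate shapes before touching state, so a bad load leaves the net unchanged.
        if len(W) != self.n_in or any(len(row) != self.n_out for row in W):
            raise ValueError("W must have shape [n_in, n_out]")
        if len(b) != self.n_out:
            raise ValueError("b must have shape [n_out]")
        self.W = [list(row) for row in W]
        self.b = list(b)

    def get_weights(self) -> Dict[str, list]:
        return {"W": [list(row) for row in self.W], "b": list(self.b)}

    def forward(self, rows) -> List[List[float]]:
        out = []
        for x in rows:
            # logits[j] = sum_i x[i] * W[i][j] + b[j], i.e. one row of x W + b
            logits = [
                sum(x[i] * self.W[i][j] for i in range(self.n_in)) + self.b[j]
                for j in range(self.n_out)
            ]
            out.append(softmax(logits))
        return out
```

For the graph in this thread, the net would be constructed as `SoftmaxRegressionNet(784, 10)`, so each 784-value input row maps to 10 class probabilities.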
OK, I will first dig a bit into the javacpp-presets for background and then
On 2016-02-16 21:52 GMT-08:00, Philipp Moritz wrote:
Great, any progress on this will be very helpful for the project; don't hesitate to ask questions if you run into problems. We have a fair amount of experience with JavaCPP by now and might be able to help. To get started, you can try running both ExampleTrainer.java from the TensorFlow preset and our CIFAR training app in the SparkNet javacpp+dataframes branch. The branch is almost ready to merge; we just haven't gotten around to creating the AMI yet.
I am a developer on a team considering using SparkNet with an ML library other than Caffe. The intent of this issue is to capture discussion of an ML Provider Abstraction Layer (MLPAL?) that would allow Caffe and SomeOtherMLLibrary to be plugged in interchangeably.
To the core committers: do you already have thoughts and/or a roadmap for this? In any case, our thoughts will start appearing here.