Optimize for NVIDIA cards using TensorRT #52
I'm not familiar with TensorRT at all, but it seems very powerful for Pascal architectures and beyond. If anyone has experience, please share your knowledge on how useful this is. http://docs.nvidia.com/deeplearning/sdk/tensorrt-developer-guide/index.html |
I ported the current network to Caffe for some timing tests. With the heads still on the CPU, the OpenCL implementation obtains 926 n/s. Using Caffe with cuDNN I'm able to do forward inference in minibatches of 8 at 2.5 ms, which corresponds to 3200 n/s, assuming the CPU can keep up.
Caffe model here: chessnet-mini.zip. Next up: TensorRT! |
Note that minibatches aren't a free speedup. If you do them for a single instance, the search becomes less efficient. If you want to batch over multiple running games, the implementation becomes trickier. |
One important thing to look at with TensorRT is the licensing:
|
Hmm, this seems counterintuitive to me. Why does the search become less efficient, assuming multiple threads can traverse the tree and add positions to some GPU queue?
Yeah this makes me sad :( Did you find any speedup comparing caffe to the current OpenCL implementation? |
The search algorithm is serial. There is a single best node to explore in the UCT algorithm, so with a naive implementation multiple threads would all end up expanding the same node or blocking on each other. Virtual losses improve upon this by lowering the score of the move being evaluated, so the next-best moves are investigated by other threads (assuming that eventually we might want to get to those too - which isn't always true, and in some cases we'd really have wanted more info on the best move instead). Now, with blocking, not only will you incur the efficiency penalty from virtual loss, you will incur an additional penalty because the first move being evaluated will block until the last move is queued. Thus, on each search decision, you will be missing the info from the last batch_size - 1 UCT evaluations (whereas with simple threading the lag ranges from 0 at best to the number of threads at worst), and those happen to be the most critical ones for deciding where to search next. For self-play, this problem disappears if you play batch_size games at the same time, because each game behaves as if it had a single-threaded search. |
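To make the virtual-loss discussion concrete, here is a minimal Python sketch of UCT/PUCT-style selection with virtual loss. The node fields, the exploration constant, and the value convention are illustrative and are not the project's actual implementation; the virtual loss value of 3 is the constant discussed further down in this thread.

```python
import math

C_PUCT = 1.5        # exploration constant; illustrative value
VIRTUAL_LOSS = 3    # value discussed later in this thread

class Node:
    def __init__(self, prior):
        self.prior = prior        # P(s, a) from the policy head
        self.visits = 0           # N(s, a)
        self.total_value = 0.0    # W(s, a), sum of backed-up values
        self.virtual_loss = 0     # in-flight evaluations counted as losses
        self.children = []

def puct_score(parent, child):
    # Virtual losses count as extra visits that contributed 0 to W:
    # Q drops and the visit count rises while an evaluation is in flight.
    n = child.visits + child.virtual_loss
    q = child.total_value / n if n > 0 else 0.0
    u = C_PUCT * child.prior * math.sqrt(max(parent.visits, 1)) / (1 + n)
    return q + u

def select_child(parent):
    best = max(parent.children, key=lambda c: puct_score(parent, c))
    best.virtual_loss += VIRTUAL_LOSS   # discourage other threads from taking the same path
    return best

def backup(path, value):
    # Called once the (possibly batched) NN evaluation returns.
    for node in reversed(path):
        node.virtual_loss -= VIRTUAL_LOSS  # undo the virtual loss
        node.visits += 1
        node.total_value += value
        value = -value                     # flip perspective at each ply
```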
I haven't tested recently (I have a Caffe backend for Leela, but not Leela Zero), but for batch_size = 1 cuDNN was only slightly faster. For small networks the overhead of the OS, GPU drivers, etc. becomes rather large and dominates. This seems especially true on Windows 10. For large networks and bigger batches I would expect cuDNN to have a sizeable advantage. The current Winograd in/out transforms aren't optimized, and despite the automated tuner I think cuBLAS (which backs cuDNN) is so optimized that it should beat CLBlast easily. You might well gain enough performance, especially with small networks, that the search losses from batching are offset by the performance gains.
I've been told once you can actually ask NVIDIA to get permissions to redistribute, but that won't solve the underlying licensing incompatibility for distribution. |
How do people feel about Keras as a higher-level NN framework? |
There is a BSD-licensed header-only library from NVIDIA
Well, indeed it's serial but only with a perfect NN would there be a true single best node. But you know all this very well.
I also thought about this; it would be near optimal during training. But during play we'd also like to use a multithreaded solution. How about the following: given a state s and its sorted legal actions A, we could impose a virtual loss iff the first n sorted move probabilities are approximately equal, where n is an integer > 1. Would that make sense @gcp? |
Virtual loss already does that naturally, I think. If the scores are not close, the virtual loss won't be enough to discourage the search from the node. So bad moves still aren't searched. Not sure theorizing (or trusting my intuition!) about this is worth a lot IMHO. This is the kind of stuff that needs testing. And maybe the results will be a bit unclear: leela-zero/leela-zero#526 (We ended up taking the multi-GPU support from that effort) The results might even differ depending on the OS. Windows 10 has very high overhead for GPU calls, so it would benefit more from batching them than say Linux.
Well multi-threading you can already do, the question is about whether you want to block the threads until you have a batch or not. |
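To make the "block until you have a batch" idea concrete, here is a rough sketch of that variant: search threads hand positions to a shared evaluator and block until a full batch has been run on the GPU. The names are illustrative (not the project's code), batch size 8 matches the minibatch size benchmarked above, and a real implementation would also need a timeout or flush so lone threads don't wait forever.

```python
import threading

class BatchEvaluator:
    """Search threads call evaluate(); the thread that completes a batch runs one GPU call."""

    def __init__(self, net_forward, batch_size=8):
        self.net_forward = net_forward   # callable: list of positions -> list of (policy, value)
        self.batch_size = batch_size
        self.pending = []                # (position, event, result_slot) tuples
        self.lock = threading.Lock()

    def evaluate(self, position):
        done = threading.Event()
        slot = {}
        with self.lock:
            self.pending.append((position, done, slot))
            batch = None
            if len(self.pending) >= self.batch_size:
                batch, self.pending = self.pending, []
        if batch is not None:
            results = self.net_forward([p for p, _, _ in batch])  # one batched GPU call
            for (_, ev, s), res in zip(batch, results):
                s["result"] = res
                ev.set()
        done.wait()            # block here until our batch was flushed
        return slot["result"]
```

This is exactly where the blocking penalty described above comes from: until batch_size positions have been queued, every thread in the batch is stalled and the search proceeds on stale information.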
Ah of course... But then I wonder why it's set to 3? That would imply two threads never go down the same node, unless #threads > #actions for a certain node. Why is it so high? Edit: Maybe I should study the theory again. |
A single loss isn't necessarily enough to disqualify a node if it has a much higher score than the next best one. |
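A toy numeric illustration of that point, with made-up visit counts and the exploration term ignored:

```python
# Made-up numbers: move A is clearly better than move B.
W_a, N_a = 45.0, 60   # Q_A = 0.75
W_b, N_b = 11.0, 20   # Q_B = 0.55
VIRTUAL_LOSS = 3

q_a = W_a / N_a                        # 0.750
q_a_vl = W_a / (N_a + VIRTUAL_LOSS)    # ~0.714: three extra "lost" visits worth 0
q_b = W_b / N_b                        # 0.550

# Even with the virtual loss applied, A still scores well above B, so the
# next thread picks A again; virtual loss only diverts the search when the
# candidate scores are already close.
print(q_a, q_a_vl, q_b)
```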
@brianprichardson, I don't see a need for it tbh? Why do you ask? Also it seems a bit off topic here. @vondele thanks, that might be interesting! I'll first run some experiments just to see if this is a worthy effort at all. I think we should first put the entire net on the GPU with the OpenCL implementation #51. And then experiment with a queue. @gcp I see, I made the wrong assumption with respect to the virtual loss and what it's added to without looking at the code. This renders my whole discussion about it invalid, sorry. I was wondering what GPU you use yourself? Just curious. |
@Error323 I mentioned Keras after Caffe came up. I like Keras as it runs with numerous underlying learning environments. Perhaps I misunderstand the direction of this issue. I find Keras easier to understand than more crafted (and optimized) NN frameworks. Of course, as long as a large number of people can easily contribute to the distributed phase, it really does not matter what is under the covers. |
@brianprichardson ah I see, I agree, Keras is nice and clean for creating a model. But here I was indeed talking about the inference part in C++. |
@gcp, @glinscott I did try TensorRT with INT8 using giexec; the results are quite amazing:
Certainly the OpenCL implementation is severely underestimated here, and I expect it to be on the order of 2500 nps with batch size 8 and the heads on the GPU. Another thing that was interesting to me is that the INT8 version didn't seem to help at all for small batch sizes, given this network. I expect INT8 to become more useful as the network increases in complexity. I used the following commands for TensorRT:
/usr/src/tensorrt/bin/giexec --deploy=/home/fhuizing/chessnet-mini.prototxt --batch=8 --output=score --output=pol_out --avgRuns=1000
/usr/src/tensorrt/bin/giexec --deploy=/home/fhuizing/chessnet-mini.prototxt --batch=8 --int8 --output=score --output=pol_out --avgRuns=1000
Edit: From this I conclude that we should support the option of TensorRT for NVIDIA users who want to compile things themselves. I think most of the people willing to help out in the training phase are programmers anyway.
Edit2: Note that the results for Caffe and TensorRT were obtained using random weights. |
I have a pile of different ones, though mostly NVIDIA now after I got a bit tired of the number of bugs in AMD's OpenCL implementation. The amazing thing from your measurement is that TensorRT FP32 outperforms cuDNN by a factor of 8. |
TensorRT does a lot of optimizations and fusions to the network when creating the plan (the final network); the Caffe network runs just as-is. That's probably the main cause. Ideally we'd write something for OpenCL that also applies these optimizations like TensorRT does, but that's probably a lot of work. |
In Leela Zero there is now a fused output-transform-batchnorm-relu-input-transform kernel, but that is only about a 10% speedup. I'm not sure where the factor of 8 can come from, but I suspect it's mainline Caffe not selecting the best convolution kernel (which NVIDIA's branch will do fine). |
I just compiled NVIDIA's Caffe with optimal settings AFAIK, but it didn't show an improvement:
|
Ok then TensorRT is just magic. |
But seriously, an order of magnitude is no joke. It'll probably be even more substantial for Go. What do you think @gcp? I'll try to implement something; do you happen to have a conversion of your weights.txt to Caffe? Do you agree with this approach? |
I don't know, I'm puzzled about where the factor of 8 comes from. Depending on TensorRT (or cuDNN, for that matter) makes distribution of the whole thing impossible, and while it's nice for people who download and compile everything themselves, or when running on the cloud, it does limit the wider usability. That's why I'd be interested in understanding the difference. We won't ever equal cuBLAS in OpenCL, so cuDNN should always be a bit ahead. But where does the additional 8x speedup come from?
No, I used to do weight exports with a rather hacked-up version of (old) Leela that was intertwined with NVIDIA-Caffe. I don't have an importer, and I rewrote everything in TensorFlow partly to get a clean solution. Perhaps measuring cuDNN forward performance in TensorFlow is just as easy? |
Agreed, this would be very useful to know. Let me first see if I can make a TensorRT version whose outputs are correct while still maintaining the speedup.
When doing TensorFlow training with a batch size of 8 it computes ~1400 pos/s. That is forward and backward combined. Given Caffe's timed benchmarks, a forward pass is slightly slower than a backward pass, so forward-only inference should be a bit less than twice as fast; multiplying by ~1.8 gives ~2500 nps. |
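Spelled out as a back-of-the-envelope calculation (the 1.8 factor is the estimate taken from the Caffe forward/backward timings above, not a measured value):

```python
train_pos_per_s = 1400        # forward + backward, batch size 8
# Forward is slightly slower than backward, so it takes a bit more than half
# of a training step; forward-only should therefore be a bit less than 2x faster.
forward_only_factor = 1.8     # assumed from the Caffe timings
inference_estimate = train_pos_per_s * forward_only_factor
print(inference_estimate)     # 2520.0, i.e. roughly 2500 nps
```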
Bah, so far I'm failing in two cases. First of all, I'm unable to convert a TensorFlow model to a UFF file for importing into TensorRT: the Merge operator is not supported, and officially TensorRT only supports up to TensorFlow 1.3. Secondly, I wanted to obtain a frozen model in TensorFlow and perform inference and timing benchmarks on that model, but this fails because of my lack of knowledge, and I was hoping that someone could help here. I've added the following function to tfprocess.py:
def save_frozen_model(self, filename):
# We use a built-in TF helper to export variables to constants
output_graph_def = tf.graph_util.convert_variables_to_constants(
self.session, # The session is used to retrieve the weights
tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
['policy_head', 'value_head'] # The output node names are used to select the useful nodes
)
# Finally we serialize and dump the output graph to the filesystem
with tf.gfile.GFile(filename, "wb") as f:
f.write(output_graph_def.SerializeToString())
print("%d ops in the final graph." % len(output_graph_def.node)) This works fine, it generates a frozen model with the following output:
But for inference, it just fails hard whatever I try. I added a new script:
import tensorflow as tf
import parse
import numpy as np
import argparse
def load_graph(frozen_graph_filename):
# We load the protobuf file from the disk and parse it to retrieve the
# unserialized graph_def
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
    # Then, we import the graph_def into a new Graph and return it
with tf.Graph().as_default() as graph:
# The name var will prefix every op/nodes in your graph
# Since we load everything in a new graph, this is not needed
tf.import_graph_def(graph_def)
return graph
if __name__ == '__main__':
# Let's allow the user to pass the filename as an argument
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", default="results/frozen_model.pb", type=str, help="Frozen model file to import")
args = parser.parse_args()
# We use our "load_graph" function
graph = load_graph(args.frozen_model_filename)
# We can verify that we can access the list of operations in the graph
for op in graph.get_operations():
print(op.name)
# prefix/Placeholder/inputs_placeholder
# ...
# prefix/Accuracy/predictions
pos, probs, winner = parse.generate_fake_pos()
pos = np.reshape(pos, (-1, 120, 8, 8))
print(pos.shape)
# We access the input and output nodes
training = graph.get_tensor_by_name('import/Placeholder_3:0')
x = graph.get_tensor_by_name('import/input_conv:0')
y = graph.get_tensor_by_name('import/policy_head:0')
z = graph.get_tensor_by_name('import/value_head:0')
# We launch a Session
with tf.Session(graph=graph) as sess:
y_out, z_out = sess.run([y, z], feed_dict={
x: pos, training: False
})
print(y_out)
        print(z_out)
The error in question is:
Do any of you guys reading this understand what I'm doing wrong? @gcp @glinscott? |
Oh, I figured the tensorflow part out. The real input is the |
tensorflow/tensorflow#16253: it seems TensorFlow is integrating TensorRT together with NVIDIA. |
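If that integration lands, converting the frozen graph would presumably look something like the sketch below, based on the contrib API described in that PR. The module path, function name, and parameters are assumptions and may differ in the released version; the output node names are the ones used when freezing the model above.

```python
import tensorflow as tf
import tensorflow.contrib.tensorrt as trt  # module path per the PR; may change on release

# Load the frozen GraphDef produced by save_frozen_model above.
with tf.gfile.GFile("results/frozen_model.pb", "rb") as f:
    frozen_graph_def = tf.GraphDef()
    frozen_graph_def.ParseFromString(f.read())

# Replace TensorRT-compatible subgraphs with TRT engine ops; unsupported ops
# (such as the Merge ops mentioned earlier) stay as regular TensorFlow ops.
trt_graph_def = trt.create_inference_graph(
    input_graph_def=frozen_graph_def,
    outputs=["policy_head", "value_head"],
    max_batch_size=8,
    precision_mode="FP32",
)
```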
Yeah, sorry, that was way too vague. @gcp let me explain properly. I think the tensor names are in the wrong places for the graph constructed by tfprocess.py:
diff --git a/training/tf/tfprocess.py b/training/tf/tfprocess.py
index 24130b8..83f33b6 100644
--- a/training/tf/tfprocess.py
+++ b/training/tf/tfprocess.py
@@ -63,7 +63,7 @@ class TFProcess:
self.x = next_train_batch[0] # tf.placeholder(tf.float32, [None, 120, 8 * 8])
self.y_ = next_train_batch[1] # tf.placeholder(tf.float32, [None, 1924])
self.z_ = next_train_batch[2] # tf.placeholder(tf.float32, [None, 1])
- self.training = tf.placeholder(tf.bool)
+ self.training = tf.placeholder(tf.bool, name='input_training')
self.batch_norm_count = 0
self.y_conv, self.z_conv = self.construct_net(self.x)
@@ -283,6 +283,19 @@ class TFProcess:
#v = self.session.run(output, feed_dict={self.x:data, self.training:False})
#print('input_conv', v)
+ def save_frozen_model(self, filename):
+ # We use a built-in TF helper to export variables to constants
+ output_graph_def = tf.graph_util.convert_variables_to_constants(
+ self.session, # The session is used to retrieve the weights
+ tf.get_default_graph().as_graph_def(), # The graph_def is used to retrieve the nodes
+ ['policy_head', 'value_head'] # The output node names are used to select the usefull nodes
+ )
+
+ # Finally we serialize and dump the output graph to the filesystem
+ with tf.gfile.GFile(filename, "wb") as f:
+ f.write(output_graph_def.SerializeToString())
+ print("%d ops in the final graph." % len(output_graph_def.node))
+
def get_batchnorm_key(self):
result = "bn" + str(self.batch_norm_count)
self.batch_norm_count += 1
@@ -356,13 +369,12 @@ class TFProcess:
# NCHW format
# batch, 120 input channels, 8 x 8
- x_planes = tf.reshape(planes, [-1, 120, 8, 8])
+ x_planes = tf.reshape(planes, [-1, 120, 8, 8], name='input_planes')
# Input convolution
flow = self.conv_block(x_planes, filter_size=3,
input_channels=120,
- output_channels=RESIDUAL_FILTERS,
- name='input_conv')
+ output_channels=RESIDUAL_FILTERS)
# Residual tower
for _ in range(0, RESIDUAL_BLOCKS):
            flow = self.residual_block(flow, RESIDUAL_FILTERS)
In doing so, one can perform inference on a frozen model:
import tensorflow as tf
import parse
import numpy as np
import argparse
def load_graph(frozen_graph_filename):
with tf.gfile.GFile(frozen_graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
return graph
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument("--frozen_model_filename", type=str, help="Frozen model file to import")
args = parser.parse_args()
graph = load_graph(args.frozen_model_filename)
pos, probs, winner = parse.generate_fake_pos()
pos = np.reshape(pos, (-1, 120, 8, 8))
print(pos.shape)
# We access the input and output nodes
training = graph.get_tensor_by_name('import/input_training:0')
x = graph.get_tensor_by_name('import/input_planes:0')
y = graph.get_tensor_by_name('import/policy_head:0')
z = graph.get_tensor_by_name('import/value_head:0')
# We launch a Session
with tf.Session(graph=graph) as sess:
y_out, z_out = sess.run([y, z], feed_dict={
x: pos, training: False
})
print(y_out.shape)
print(z_out.shape) |
@Error323 Did you eventually manage to convert the frozen TensorFlow model to UFF format? |
I also cannot convert my ResNet50 model, trained with a Keras implementation, to UFF: unsupported operation: Merge.
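For reference, the conversion step that fails with the Merge error would look roughly like this, assuming the `uff` converter shipped with TensorRT 3.x (untested here; argument names per that package's documentation, output names as used when freezing the model above):

```python
import uff

# Convert the frozen TensorFlow graph produced by save_frozen_model above.
# 'policy_head' and 'value_head' are the output node names used when freezing.
uff_model = uff.from_tensorflow_frozen_model(
    frozen_file="results/frozen_model.pb",
    output_nodes=["policy_head", "value_head"],
    output_filename="chessnet.uff",
)
# The converter walks the graph backwards from the output nodes; hitting an
# unsupported op on that path (e.g. Merge) aborts with the error above.
```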
https://developer.nvidia.com/tensorrt should be able to give significant performance gains when doing inference.