diff --git a/trt_dev/RMSNorm/RMSNorm.ipynb b/trt_dev/RMSNorm/RMSNorm.ipynb new file mode 100644 index 0000000..fd93900 --- /dev/null +++ b/trt_dev/RMSNorm/RMSNorm.ipynb @@ -0,0 +1,394 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0e7c0a0e", + "metadata": {}, + "source": [ + "# RMSNorm with TensorRT" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "19e3269d", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "from cuda import cudart\n", + "import torch\n", + "from torch import Tensor, nn\n", + "import tensorrt as trt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "07811b68", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "PyTorch version: 2.1.0a0+4136153\n", + "TensorRT version: 8.6.1\n" + ] + } + ], + "source": [ + "print(\"PyTorch version: \" + torch.__version__)\n", + "print(\"TensorRT version: \" + trt.__version__)" + ] + }, + { + "cell_type": "markdown", + "id": "6ee4707c", + "metadata": {}, + "source": [ + "## Generate input and data shape" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "9499ca7d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputH0 : (1, 2, 2)\n", + "[[[0. 1.]\n", + " [2. 3.]]]\n" + ] + } + ], + "source": [ + "# Input tensor shape NCHW\n", + "nIn, hIn, wIn = 1, 2, 2\n", + "\n", + "# Output tensor shape C\n", + "cOut = 2\n", + "\n", + "# Input tensor\n", + "data = np.arange(hIn * wIn, dtype=np.float32).reshape(nIn, hIn, wIn)\n", + "\n", + "# fully connected weight\n", + "weight = np.ones(cOut * hIn * wIn, dtype=np.float32).reshape(cOut, hIn * wIn)\n", + "\n", + "# fully connected bias\n", + "bias = np.zeros(cOut, dtype=np.float32)\n", + "\n", + "print(\"inputH0 :\", data.shape)\n", + "print(data)" + ] + }, + { + "cell_type": "markdown", + "id": "18128183", + "metadata": {}, + "source": [ + "## 1. 
RMSNorm by PyTorch " + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "9318f587", + "metadata": {}, + "outputs": [], + "source": [ + "class RMSNorm(nn.Module):\n", + " def __init__(self, dim: int, eps: float = 1e-6):\n", + " super().__init__()\n", + " self.eps = eps\n", + " # The gamma parameter\n", + " self.weight = nn.Parameter(torch.ones(dim))\n", + "\n", + " def _norm(self, x: torch.Tensor):\n", + " # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)\n", + " # rsqrt: 1 / sqrt(x)\n", + " return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)\n", + "\n", + " def forward(self, x: torch.Tensor):\n", + " # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)\n", + " return self.weight * self._norm(x.float()).type_as(x)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "9f31c1c4", + "metadata": {}, + "outputs": [], + "source": [ + "def test_torch(nIn, hIn, wIn, cOut, raw_data, weight, bias):\n", + " data = torch.tensor(raw_data).reshape(-1)\n", + " \n", + " model = RMSNorm(1)\n", + "\n", + " output = model(data)\n", + "\n", + " return output" + ] + }, + { + "cell_type": "markdown", + "id": "a6e87667", + "metadata": {}, + "source": [ + "## PyTorch Testing" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "b5446246", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "RMSNorm_output_torch : torch.Size([4])\n", + "tensor([0.0000, 0.5345, 1.0690, 1.6036], grad_fn=)\n" + ] + } + ], + "source": [ + "torch_output = test_torch(nIn, hIn, wIn, cOut, data, weight, bias)\n", + "print(\"RMSNorm_output_torch :\", torch_output.shape)\n", + "print(torch_output)" + ] + }, + { + "cell_type": "markdown", + "id": "a4ec993e", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "6b7fa9b4", + "metadata": {}, + "source": [ + "## 2. 
RMSNorm with TensorRT" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6431574b", + "metadata": {}, + "outputs": [], + "source": [ + "def trt_create(nIn, hIn, cOut, weight, bias):\n", + " # Config TensorRT Logger, Builder, Network\n", + " logger = trt.Logger(trt.Logger.ERROR)\n", + " builder = trt.Builder(logger)\n", + "\n", + " network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n", + " config = builder.create_builder_config()\n", + "\n", + " # input\n", + " inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (nIn, -1, hIn))\n", + "\n", + " # dynamic shape optimization\n", + " profile = builder.create_optimization_profile();\n", + " profile.set_shape(\"inputT0\", (nIn, 1, hIn), (nIn, 2, hIn), (nIn, 3, hIn)) \n", + " config.add_optimization_profile(profile)\n", + "\n", + " # RMSNorm Layer: 1) Square: X^2 -> 2) Sum: sum of all x^2 -> 3) Mean: 1/N -> 4) Root: sqrt(X) -> 5) Division: 1/X\n", + " print(\"inputT0.shape :\")\n", + " print(inputT0.shape)\n", + " # 1) Square: X^2\n", + " RMSNorm_Square_layer = network.add_elementwise(inputT0, inputT0, op=trt.ElementWiseOperation.PROD)\n", + " print(\"RMSNorm_Square_layer.get_output(0).shape :\")\n", + " print(RMSNorm_Square_layer.get_output(0).shape)\n", + " # 2) Sum: sum of all X^2\n", + " RMSNorm_Sum_layer = network.add_reduce(RMSNorm_Square_layer.get_output(0), op=trt.ReduceOperation.SUM, axes=1, keep_dims=True)\n", + " print(\"RMSNorm_Sum_layer.get_output(0).shape :\")\n", + " print(RMSNorm_Sum_layer.get_output(0).shape)\n", + " # 3) Mean: 1/N\n", + " RMSNorm_Mean_layer = network.add_reduce(RMSNorm_Sum_layer.get_output(0), op=trt.ReduceOperation.AVG, axes=7, keep_dims=True)\n", + " print(\"RMSNorm_Mean_layer.get_output(0).shape :\")\n", + " print(RMSNorm_Mean_layer.get_output(0).shape)\n", + " # 4) Root: sqrt(X)\n", + " RMSNorm_Sqrt_layer = network.add_unary(RMSNorm_Mean_layer.get_output(0), op=trt.UnaryOperation.SQRT)\n", + " 
print(\"RMSNorm_Sqrt_layer.get_output(0).shape :\")\n", + " print(RMSNorm_Sqrt_layer.get_output(0).shape)\n", + " # 5) Division: 1/X\n", + " RMSNorm_Div_layer = network.add_elementwise(inputT0, RMSNorm_Sqrt_layer.get_output(0), op=trt.ElementWiseOperation.DIV)\n", + " print(\"RMSNorm_Div_layer.get_output(0).shape :\")\n", + " print(RMSNorm_Div_layer.get_output(0).shape)\n", + " # output\n", + " network.mark_output(RMSNorm_Div_layer.get_output(0))\n", + "\n", + " engineString = builder.build_serialized_network(network, config)\n", + " \n", + " return engineString" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "4807937f", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "inputT0.shape :\n", + "(1, -1, 2)\n", + "RMSNorm_Square_layer.get_output(0).shape :\n", + "(1, -1, 2)\n", + "RMSNorm_Sum_layer.get_output(0).shape :\n", + "(1, -1, 2)\n", + "RMSNorm_Mean_layer.get_output(0).shape :\n", + "(1, 1, 1)\n", + "RMSNorm_Sqrt_layer.get_output(0).shape :\n", + "(1, 1, 1)\n", + "RMSNorm_Div_layer.get_output(0).shape :\n", + "(1, -1, 2)\n" + ] + } + ], + "source": [ + "trt_engineStr = trt_create(nIn, hIn, cOut, weight, bias)" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "3580ecec", + "metadata": {}, + "outputs": [], + "source": [ + "def trt_inference(nIn, hIn, cOut, engineString, raw_data):\n", + " print(engineString)\n", + " print(\"Runtime\")\n", + " logger = trt.Logger(trt.Logger.ERROR)\n", + " engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)\n", + " context = engine.create_execution_context()\n", + "\n", + " # dynamic shape configure\n", + " print(\"Set input shape\")\n", + " context.set_input_shape(\"inputT0\", (nIn, 2, hIn))\n", + " context.set_binding_shape(0, (nIn, 2, hIn))\n", + " origin_inputshape = context.get_binding_shape(0)\n", + "\n", + " print(\"Set input shape completed\")\n", + "\n", + " data = np.array(raw_data)\n", + "\n", + " _, stream = 
cudart.cudaStreamCreate()\n", + " print(\"Reshaping\")\n", + "\n", + " inputH0 = np.ascontiguousarray(data.reshape(-1))\n", + " outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))\n", + " print(\"Reshaped\")\n", + "\n", + " # initialize input and output data\n", + " _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)\n", + " _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)\n", + "\n", + " # move input to device\n", + " cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)\n", + "\n", + " # execute\n", + " print(\"execute\")\n", + " context.execute_async_v2([int(inputD0), int(outputD0)], stream)\n", + "\n", + " # move output back to host\n", + " cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)\n", + "\n", + " # wait for everything\n", + " cudart.cudaStreamSynchronize(stream)\n", + "\n", + " cudart.cudaStreamDestroy(stream)\n", + " cudart.cudaFree(inputD0)\n", + " cudart.cudaFree(outputD0)\n", + "\n", + " return outputH0" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "a2a806c0", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "Runtime\n", + "Set input shape\n", + "Set input shape completed\n", + "Reshaping\n", + "Reshaped\n", + "execute\n", + "output_trt : (4,)\n", + "[0. 
# RMSNorm by PyTorch
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension.

    Computes ``weight * x / sqrt(mean(x ** 2, dim=-1) + eps)``.

    Args:
        dim: Length of the learnable gain ``weight`` (initialised to ones);
            it must broadcast against the last dimension of the input.
        eps: Constant added to the mean square before the reciprocal square
            root, so an all-zero row produces zeros instead of NaN.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        # The gamma parameter
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x: torch.Tensor) -> torch.Tensor:
        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)
        # rsqrt: 1 / sqrt(x)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalise in float32, cast back to the input dtype, then apply gamma:
        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)
        return self.weight * self._norm(x.float()).type_as(x)


def test_torch(nIn, hIn, wIn, cOut, raw_data, weight, bias, eps=1e-6):
    """Run the PyTorch reference RMSNorm and return the flattened result.

    ``raw_data`` is viewed as ``(nIn, seq_len, hIn)`` and each row is
    normalised over its last (``hIn``) axis.  The previous version flattened
    the input and built ``RMSNorm(1)``, which normalised over every element of
    the tensor at once instead of per row.

    Args:
        nIn: Batch size.
        hIn: Size of the normalised (last) axis.
        wIn, cOut, weight, bias: Unused; kept so the signature stays parallel
            to ``test_trt``.
        raw_data: Array-like with ``nIn * seq_len * hIn`` elements.
        eps: Epsilon forwarded to ``RMSNorm``.

    Returns:
        1-D tensor with as many elements as ``raw_data``.

    Raises:
        RuntimeError: If ``raw_data`` cannot be viewed as ``(nIn, -1, hIn)``.
    """
    data = torch.as_tensor(raw_data, dtype=torch.float32).reshape(nIn, -1, hIn)

    model = RMSNorm(hIn, eps=eps)

    output = model(data)

    # Callers have always received a flat tensor; keep that shape.
    return output.reshape(-1)


def _cuda_check(result):
    """Unwrap a cuda-python ``(status, *values)`` tuple, raising on failure."""
    status, *values = result
    if status != cudart.cudaError_t.cudaSuccess:
        raise RuntimeError(f"CUDA runtime call failed: {status}")
    if not values:
        return None
    return values[0] if len(values) == 1 else tuple(values)


def _build_rmsnorm_engine(logger, nIn, hIn, seq_len, eps):
    """Build and serialise a TensorRT engine computing x / sqrt(mean(x^2) + eps)."""
    builder = trt.Builder(logger)

    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))
    config = builder.create_builder_config()

    # input: (batch, dynamic sequence length, feature)
    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (nIn, -1, hIn))

    # dynamic shape optimization; the range always contains the actual length
    profile = builder.create_optimization_profile()
    profile.set_shape("inputT0", (nIn, 1, hIn), (nIn, seq_len, hIn), (nIn, max(seq_len, 3), hIn))
    config.add_optimization_profile(profile)

    # RMSNorm: 1) Square -> 2) Mean over last axis -> 3) + eps -> 4) Sqrt -> 5) Divide

    # 1) Square: X^2
    square_layer = network.add_elementwise(inputT0, inputT0, op=trt.ElementWiseOperation.PROD)

    # 2) Mean of X^2.  `axes` is a bitmask (bit i selects dimension i), so the
    #    last axis of a rank-3 tensor is 1 << 2.  The old graph used axes=1
    #    (batch axis) followed by axes=7 (all axes), i.e. one scalar per tensor.
    last_axis_mask = 1 << 2
    mean_layer = network.add_reduce(square_layer.get_output(0), op=trt.ReduceOperation.AVG,
                                    axes=last_axis_mask, keep_dims=True)

    # 3) Add eps so an all-zero row does not divide by zero (matches the
    #    PyTorch reference).  The numpy array must outlive the build call.
    eps_weights = np.full((1, 1, 1), eps, dtype=np.float32)
    eps_layer = network.add_constant((1, 1, 1), eps_weights)
    shifted_layer = network.add_elementwise(mean_layer.get_output(0), eps_layer.get_output(0),
                                            op=trt.ElementWiseOperation.SUM)

    # 4) Root: sqrt(mean + eps)
    sqrt_layer = network.add_unary(shifted_layer.get_output(0), op=trt.UnaryOperation.SQRT)

    # 5) Division: X / rms, broadcasting (nIn, seq, 1) over (nIn, seq, hIn).
    #    Gamma is not applied: the reference initialises it to ones.
    div_layer = network.add_elementwise(inputT0, sqrt_layer.get_output(0),
                                        op=trt.ElementWiseOperation.DIV)

    # output (named so it can be queried without relying on binding indices)
    outputT0 = div_layer.get_output(0)
    outputT0.name = "outputT0"
    network.mark_output(outputT0)

    engineString = builder.build_serialized_network(network, config)
    if engineString is None:
        raise RuntimeError("TensorRT failed to build the RMSNorm engine")

    return engineString


def _run_rmsnorm_engine(logger, engineString, inputH0):
    """Deserialise the engine, run it once on ``inputH0`` and return the host output."""
    print("Runtime")
    runtime = trt.Runtime(logger)  # keep the runtime alive alongside the engine
    engine = runtime.deserialize_cuda_engine(engineString)
    if engine is None:
        raise RuntimeError("TensorRT failed to deserialize the RMSNorm engine")
    context = engine.create_execution_context()

    # dynamic shape configure
    print("Set input shape")
    context.set_input_shape("inputT0", inputH0.shape)

    outputH0 = np.empty(tuple(context.get_tensor_shape("outputT0")),
                        dtype=trt.nptype(engine.get_tensor_dtype("outputT0")))

    stream = _cuda_check(cudart.cudaStreamCreate())
    inputD0 = outputD0 = None
    try:
        # initialize input and output data
        inputD0 = _cuda_check(cudart.cudaMallocAsync(inputH0.nbytes, stream))
        outputD0 = _cuda_check(cudart.cudaMallocAsync(outputH0.nbytes, stream))

        # move input to device
        _cuda_check(cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes,
                                           cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream))

        # execute
        print("execute")
        if not context.execute_async_v2([int(inputD0), int(outputD0)], stream):
            raise RuntimeError("TensorRT failed to enqueue the RMSNorm engine")

        # move output back to host
        _cuda_check(cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes,
                                           cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream))

        # wait for everything
        _cuda_check(cudart.cudaStreamSynchronize(stream))
    finally:
        # Best-effort cleanup: release buffers and stream even when a step failed.
        for devicePtr in (inputD0, outputD0):
            if devicePtr is not None:
                cudart.cudaFree(devicePtr)
        cudart.cudaStreamDestroy(stream)

    return outputH0


def test_trt(nIn, hIn, wIn, cOut, raw_data, weight, bias, eps=1e-6):
    """Build a TensorRT RMSNorm engine, run it on ``raw_data`` and return the output.

    ``raw_data`` is converted to contiguous float32 and viewed as
    ``(nIn, seq_len, hIn)``; ``seq_len`` is derived from the data instead of
    being hard-coded to 2.  Each row is normalised over its last axis.

    Args:
        nIn: Batch size.
        hIn: Size of the normalised (last) axis.
        wIn, cOut, weight, bias: Unused; kept for interface compatibility.
        raw_data: Array-like with ``nIn * seq_len * hIn`` elements.
        eps: Value added to the mean square before the square root.

    Returns:
        numpy array of shape ``(nIn, seq_len, hIn)`` with the engine output.

    Raises:
        ValueError: If ``raw_data`` cannot be viewed as ``(nIn, -1, hIn)``.
        RuntimeError: If building, deserialising or running the engine fails,
            or if a CUDA runtime call reports an error.
    """
    inputH0 = np.ascontiguousarray(np.asarray(raw_data, dtype=np.float32).reshape(nIn, -1, hIn))
    seq_len = inputH0.shape[1]

    logger = trt.Logger(trt.Logger.ERROR)
    engineString = _build_rmsnorm_engine(logger, nIn, hIn, seq_len, eps)

    return _run_rmsnorm_engine(logger, engineString, inputH0)


# These helpers take required arguments and are not pytest tests; opt them out
# of collection in case they are imported into a test module.
test_torch.__test__ = False
test_trt.__test__ = False


def main():
    """Run the TensorRT and PyTorch RMSNorm on a small tensor and print both."""
    # Input tensor shape (N, H, W); consumed as (batch, seq_len, feature)
    nIn, hIn, wIn = 1, 2, 2

    # Output tensor shape C (unused by RMSNorm, kept for the shared signature)
    cOut = 2

    # Input tensor (element count includes nIn so the reshape holds for nIn > 1)
    data = np.arange(nIn * hIn * wIn, dtype=np.float32).reshape(nIn, hIn, wIn)

    # fully connected weight (unused by RMSNorm)
    weight = np.ones(cOut * hIn * wIn, dtype=np.float32).reshape(cOut, hIn * wIn)

    # fully connected bias (unused by RMSNorm)
    bias = np.zeros(cOut, dtype=np.float32)

    print("inputH0 :", data.shape)
    print(data)

    output_trt = test_trt(nIn, hIn, wIn, cOut, data, weight, bias).reshape(-1)
    print("output_trt :", output_trt.shape)
    print(output_trt)

    output_torch = test_torch(nIn, hIn, wIn, cOut, data, weight, bias)
    print("output_torch :", output_torch.shape)
    print(output_torch)

    # Report how closely the two implementations agree.
    print("max_abs_diff :", float(np.max(np.abs(output_trt - output_torch.detach().numpy()))))


if __name__ == "__main__":
    main()