Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented RMSNorm Layer #1

Merged
merged 1 commit into from
Oct 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
394 changes: 394 additions & 0 deletions trt_dev/RMSNorm/RMSNorm.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,394 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "0e7c0a0e",
"metadata": {},
"source": [
"# RMSNorm with TensorRT"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "19e3269d",
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from cuda import cudart\n",
"import torch\n",
"from torch import Tensor, nn\n",
"import tensorrt as trt"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "07811b68",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"PyTorch version: 2.1.0a0+4136153\n",
"TensorRT version: 8.6.1\n"
]
}
],
"source": [
"# Record the library versions this notebook was executed with.\n",
"print(f\"PyTorch version: {torch.__version__}\")\n",
"print(f\"TensorRT version: {trt.__version__}\")"
]
},
{
"cell_type": "markdown",
"id": "6ee4707c",
"metadata": {},
"source": [
"## Generate input and data shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9499ca7d",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"inputH0 : (1, 2, 2)\n",
"[[[0. 1.]\n",
"  [2. 3.]]]\n"
]
}
],
"source": [
"# Input tensor shape: (nIn, hIn, wIn)\n",
"nIn, hIn, wIn = 1, 2, 2\n",
"\n",
"# Output channel count; only sizes the placeholder weight/bias below\n",
"cOut = 2\n",
"\n",
"# Input tensor: 0, 1, 2, ... laid out as (nIn, hIn, wIn).\n",
"# The element count includes nIn so the reshape also works when nIn > 1.\n",
"data = np.arange(nIn * hIn * wIn, dtype=np.float32).reshape(nIn, hIn, wIn)\n",
"\n",
"# Placeholder weight, all ones, shape (cOut, hIn * wIn).\n",
"# It is passed to test_torch / trt_create, which do not read it.\n",
"weight = np.ones((cOut, hIn * wIn), dtype=np.float32)\n",
"\n",
"# Placeholder bias, all zeros, shape (cOut,); also passed along but not read\n",
"bias = np.zeros(cOut, dtype=np.float32)\n",
"\n",
"print(\"inputH0 :\", data.shape)\n",
"print(data)"
]
},
{
"cell_type": "markdown",
"id": "18128183",
"metadata": {},
"source": [
"## 1. RMSNorm by PyTorch "
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "9318f587",
"metadata": {},
"outputs": [],
"source": [
"class RMSNorm(nn.Module):\n",
"    \"\"\"RMS normalization over the last axis, scaled by a learnable gain vector.\"\"\"\n",
"\n",
"    def __init__(self, dim: int, eps: float = 1e-6):\n",
"        \"\"\"Create the layer.\n",
"\n",
"        Args:\n",
"            dim: Length of the gain vector (initialized to ones).\n",
"            eps: Constant added to the mean square before the reciprocal square root.\n",
"        \"\"\"\n",
"        super().__init__()\n",
"        # Learnable gain (the gamma parameter)\n",
"        self.weight = nn.Parameter(torch.ones(dim))\n",
"        self.eps = eps\n",
"\n",
"    def _norm(self, x: torch.Tensor):\n",
"        \"\"\"Return x scaled by 1 / sqrt(mean(x^2 over the last axis) + eps).\"\"\"\n",
"        # (B, Seq_Len, Dim) -> (B, Seq_Len, 1)\n",
"        mean_square = x.pow(2).mean(-1, keepdim=True)\n",
"        # rsqrt(v) = 1 / sqrt(v)\n",
"        inv_rms = torch.rsqrt(mean_square + self.eps)\n",
"        # (B, Seq_Len, Dim) * (B, Seq_Len, 1) = (B, Seq_Len, Dim)\n",
"        return x * inv_rms\n",
"\n",
"    def forward(self, x: torch.Tensor):\n",
"        \"\"\"Normalize x in float32, cast back with type_as(x), then apply the gain.\"\"\"\n",
"        normalized = self._norm(x.float()).type_as(x)\n",
"        # (Dim) * (B, Seq_Len, Dim) = (B, Seq_Len, Dim)\n",
"        return self.weight * normalized"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "9f31c1c4",
"metadata": {},
"outputs": [],
"source": [
"def test_torch(nIn, hIn, wIn, cOut, raw_data, weight, bias):\n",
"    \"\"\"Compute the PyTorch RMSNorm reference for raw_data and return it flattened.\n",
"\n",
"    The input keeps its own shape, so every row along the last axis is normalized\n",
"    separately and the RMSNorm gain has one entry per feature. Flattening the input\n",
"    first and using RMSNorm(1) would instead normalize over every element at once.\n",
"\n",
"    Args:\n",
"        nIn, hIn, wIn, cOut: Unused; kept so existing calls keep working.\n",
"        raw_data: Array-like input with at least one axis; converted to float32.\n",
"        weight, bias: Unused; kept so existing calls keep working.\n",
"\n",
"    Returns:\n",
"        1-D float32 tensor holding the normalized values in row-major order.\n",
"    \"\"\"\n",
"    data = torch.tensor(raw_data, dtype=torch.float32)\n",
"\n",
"    # Size the gain by the axis that is actually normalized (the last one)\n",
"    model = RMSNorm(data.shape[-1])\n",
"\n",
"    # Reference values only, so no autograd graph is recorded\n",
"    with torch.no_grad():\n",
"        output = model(data)\n",
"\n",
"    return output.reshape(-1)"
]
},
{
"cell_type": "markdown",
"id": "a6e87667",
"metadata": {},
"source": [
"## PyTorch Testing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "b5446246",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSNorm_output_torch : torch.Size([4])\n",
"tensor([0.0000, 0.5345, 1.0690, 1.6036], grad_fn=<MulBackward0>)\n"
]
}
],
"source": [
"# Run the PyTorch reference on the generated input and show its shape and values\n",
"torch_output = test_torch(\n",
"    nIn, hIn, wIn, cOut, raw_data=data, weight=weight, bias=bias\n",
")\n",
"print(\"RMSNorm_output_torch :\", torch_output.shape)\n",
"print(torch_output)"
]
},
{
"cell_type": "markdown",
"id": "a4ec993e",
"metadata": {},
"source": [
"---"
]
},
{
"cell_type": "markdown",
"id": "6b7fa9b4",
"metadata": {},
"source": [
"## 2. RMSNorm with TensorRT"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6431574b",
"metadata": {},
"outputs": [],
"source": [
"def trt_create(nIn, hIn, cOut, weight, bias, eps=1e-6):\n",
"    \"\"\"Build and serialize a TensorRT engine that applies RMSNorm over the last axis.\n",
"\n",
"    The network has one float32 input named 'inputT0' of shape (nIn, S, hIn), where\n",
"    S is dynamic (optimization profile: min 1, opt 2, max 3). The single output has\n",
"    the same shape and holds x / sqrt(mean(x^2 over the last axis) + eps).\n",
"    No gain is applied.\n",
"\n",
"    Args:\n",
"        nIn: Static size of the first axis.\n",
"        hIn: Static size of the last axis, the one that is normalized.\n",
"        cOut: Unused; kept so existing calls keep working.\n",
"        weight: Unused; kept so existing calls keep working.\n",
"        bias: Unused; kept so existing calls keep working.\n",
"        eps: Constant added to the mean square before the square root, so an\n",
"            all-zero row yields 0 instead of 0/0.\n",
"\n",
"    Returns:\n",
"        The serialized engine produced by builder.build_serialized_network.\n",
"\n",
"    Raises:\n",
"        RuntimeError: If TensorRT fails to build the engine.\n",
"    \"\"\"\n",
"    # Config TensorRT Logger, Builder, Network\n",
"    logger = trt.Logger(trt.Logger.ERROR)\n",
"    builder = trt.Builder(logger)\n",
"\n",
"    network = builder.create_network(1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH))\n",
"    config = builder.create_builder_config()\n",
"\n",
"    # input\n",
"    inputT0 = network.add_input('inputT0', trt.DataType.FLOAT, (nIn, -1, hIn))\n",
"\n",
"    # dynamic shape optimization\n",
"    profile = builder.create_optimization_profile()\n",
"    profile.set_shape(\"inputT0\", (nIn, 1, hIn), (nIn, 2, hIn), (nIn, 3, hIn))\n",
"    config.add_optimization_profile(profile)\n",
"\n",
"    # add_reduce takes its axes as a bitmask: bit i selects axis i.\n",
"    # The input always has three axes, so bit 2 is the last one.\n",
"    last_axis_mask = 1 << 2\n",
"\n",
"    # RMSNorm Layer: 1) Square: X^2 -> 2) Mean over last axis -> 3) Add eps -> 4) Root: sqrt(X) -> 5) Division\n",
"    print(\"inputT0.shape :\")\n",
"    print(inputT0.shape)\n",
"    # 1) Square: X^2\n",
"    RMSNorm_Square_layer = network.add_elementwise(inputT0, inputT0, op=trt.ElementWiseOperation.PROD)\n",
"    print(\"RMSNorm_Square_layer.get_output(0).shape :\")\n",
"    print(RMSNorm_Square_layer.get_output(0).shape)\n",
"    # 2) Mean of X^2 over the last axis only: (nIn, S, hIn) -> (nIn, S, 1)\n",
"    RMSNorm_Mean_layer = network.add_reduce(RMSNorm_Square_layer.get_output(0), op=trt.ReduceOperation.AVG, axes=last_axis_mask, keep_dims=True)\n",
"    print(\"RMSNorm_Mean_layer.get_output(0).shape :\")\n",
"    print(RMSNorm_Mean_layer.get_output(0).shape)\n",
"    # 3) Add eps. The host array must stay referenced until the engine is built,\n",
"    #    which the local variable guarantees.\n",
"    eps_host = np.full((1, 1, 1), eps, dtype=np.float32)\n",
"    RMSNorm_Eps_const = network.add_constant(eps_host.shape, trt.Weights(eps_host))\n",
"    RMSNorm_Eps_layer = network.add_elementwise(RMSNorm_Mean_layer.get_output(0), RMSNorm_Eps_const.get_output(0), op=trt.ElementWiseOperation.SUM)\n",
"    # 4) Root: sqrt(X)\n",
"    RMSNorm_Sqrt_layer = network.add_unary(RMSNorm_Eps_layer.get_output(0), op=trt.UnaryOperation.SQRT)\n",
"    print(\"RMSNorm_Sqrt_layer.get_output(0).shape :\")\n",
"    print(RMSNorm_Sqrt_layer.get_output(0).shape)\n",
"    # 5) Division: X / sqrt(...), broadcasting (nIn, S, 1) over the last axis\n",
"    RMSNorm_Div_layer = network.add_elementwise(inputT0, RMSNorm_Sqrt_layer.get_output(0), op=trt.ElementWiseOperation.DIV)\n",
"    print(\"RMSNorm_Div_layer.get_output(0).shape :\")\n",
"    print(RMSNorm_Div_layer.get_output(0).shape)\n",
"    # output\n",
"    network.mark_output(RMSNorm_Div_layer.get_output(0))\n",
"\n",
"    engineString = builder.build_serialized_network(network, config)\n",
"    # build_serialized_network returns None on failure; fail here, not at inference\n",
"    if engineString is None:\n",
"        raise RuntimeError(\"Failed to build the TensorRT engine\")\n",
"\n",
"    return engineString"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "4807937f",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"inputT0.shape :\n",
"(1, -1, 2)\n",
"RMSNorm_Square_layer.get_output(0).shape :\n",
"(1, -1, 2)\n",
"RMSNorm_Sum_layer.get_output(0).shape :\n",
"(1, -1, 2)\n",
"RMSNorm_Mean_layer.get_output(0).shape :\n",
"(1, 1, 1)\n",
"RMSNorm_Sqrt_layer.get_output(0).shape :\n",
"(1, 1, 1)\n",
"RMSNorm_Div_layer.get_output(0).shape :\n",
"(1, -1, 2)\n"
]
}
],
"source": [
"# Build the serialized TensorRT engine once; the inference cell reuses it\n",
"trt_engineStr = trt_create(nIn=nIn, hIn=hIn, cOut=cOut, weight=weight, bias=bias)"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "3580ecec",
"metadata": {},
"outputs": [],
"source": [
"def trt_inference(nIn, hIn, cOut, engineString, raw_data):\n",
"    \"\"\"Run a serialized TensorRT engine on raw_data and return the output as a NumPy array.\n",
"\n",
"    raw_data is converted to the engine's input dtype and laid out as (nIn, S, hIn).\n",
"    S is taken from the data instead of being fixed at 2, so any length allowed by\n",
"    the engine's optimization profile can be run.\n",
"\n",
"    Args:\n",
"        nIn: Size of the first axis of the engine input.\n",
"        hIn: Size of the last axis of the engine input.\n",
"        cOut: Unused; kept so existing calls keep working.\n",
"        engineString: Serialized engine with one input and one output, in that order.\n",
"        raw_data: Array-like input whose element count is a multiple of nIn * hIn.\n",
"\n",
"    Returns:\n",
"        Host array with the shape and dtype TensorRT reports for the output tensor.\n",
"\n",
"    Raises:\n",
"        RuntimeError: If the engine cannot be deserialized or execution fails.\n",
"        ValueError: If raw_data cannot be reshaped to (nIn, S, hIn) or TensorRT\n",
"            rejects that shape.\n",
"    \"\"\"\n",
"    print(\"Runtime\")\n",
"    logger = trt.Logger(trt.Logger.ERROR)\n",
"    engine = trt.Runtime(logger).deserialize_cuda_engine(engineString)\n",
"    if engine is None:\n",
"        raise RuntimeError(\"Failed to deserialize the TensorRT engine\")\n",
"    context = engine.create_execution_context()\n",
"\n",
"    # I/O tensors are addressed by name; index 0 is the input, index 1 the output\n",
"    input_name = engine.get_tensor_name(0)\n",
"    output_name = engine.get_tensor_name(1)\n",
"\n",
"    # Host input in the engine's dtype, so the raw bytes copied to the device match\n",
"    input_dtype = trt.nptype(engine.get_tensor_dtype(input_name))\n",
"    inputH0 = np.ascontiguousarray(np.asarray(raw_data, dtype=input_dtype).reshape(nIn, -1, hIn))\n",
"\n",
"    # dynamic shape configure\n",
"    print(\"Set input shape\")\n",
"    if not context.set_input_shape(input_name, inputH0.shape):\n",
"        raise ValueError(f\"Input shape {inputH0.shape} was rejected by the engine\")\n",
"    print(\"Set input shape completed\")\n",
"\n",
"    print(\"Reshaping\")\n",
"    # The output shape is only resolved once the input shape has been set\n",
"    outputH0 = np.empty(tuple(context.get_tensor_shape(output_name)), dtype=trt.nptype(engine.get_tensor_dtype(output_name)))\n",
"    print(\"Reshaped\")\n",
"\n",
"    _, stream = cudart.cudaStreamCreate()\n",
"    inputD0 = None\n",
"    outputD0 = None\n",
"    try:\n",
"        # initialize input and output data\n",
"        _, inputD0 = cudart.cudaMallocAsync(inputH0.nbytes, stream)\n",
"        _, outputD0 = cudart.cudaMallocAsync(outputH0.nbytes, stream)\n",
"\n",
"        # move input to device\n",
"        cudart.cudaMemcpyAsync(inputD0, inputH0.ctypes.data, inputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyHostToDevice, stream)\n",
"\n",
"        # execute\n",
"        print(\"execute\")\n",
"        if not context.execute_async_v2([int(inputD0), int(outputD0)], stream):\n",
"            raise RuntimeError(\"TensorRT execution failed\")\n",
"\n",
"        # move output back to host\n",
"        cudart.cudaMemcpyAsync(outputH0.ctypes.data, outputD0, outputH0.nbytes, cudart.cudaMemcpyKind.cudaMemcpyDeviceToHost, stream)\n",
"\n",
"        # wait for everything\n",
"        cudart.cudaStreamSynchronize(stream)\n",
"    finally:\n",
"        # Release device buffers and the stream even when a step above raised\n",
"        for device_ptr in (inputD0, outputD0):\n",
"            if device_ptr is not None:\n",
"                cudart.cudaFree(device_ptr)\n",
"        cudart.cudaStreamDestroy(stream)\n",
"\n",
"    return outputH0"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "a2a806c0",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"<tensorrt.tensorrt.IHostMemory object at 0x7f4381fa7370>\n",
"Runtime\n",
"Set input shape\n",
"Set input shape completed\n",
"Reshaping\n",
"Reshaped\n",
"execute\n",
"output_trt : (4,)\n",
"[0. 0.5345225 1.069045 1.6035675]\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_1410/3070280412.py:11: DeprecationWarning: Use set_input_shape instead.\n",
"  context.set_binding_shape(0, (nIn, 2, hIn))\n",
"/tmp/ipykernel_1410/3070280412.py:12: DeprecationWarning: Use get_tensor_shape instead.\n",
"  origin_inputshape = context.get_binding_shape(0)\n",
"/tmp/ipykernel_1410/3070280412.py:22: DeprecationWarning: Use get_tensor_shape instead.\n",
"  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))\n",
"/tmp/ipykernel_1410/3070280412.py:22: DeprecationWarning: Use get_tensor_dtype instead.\n",
"  outputH0 = np.empty(context.get_binding_shape(1), dtype=trt.nptype(engine.get_binding_dtype(1)))\n"
]
}
],
"source": [
"# Run the engine on the generated input and flatten the result,\n",
"# so it lines up with the 1-D PyTorch output\n",
"trt_output = trt_inference(nIn, hIn, cOut, trt_engineStr, data).reshape(-1)\n",
"print(\"output_trt :\", trt_output.shape)\n",
"print(trt_output)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9d007208",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading