Skip to content

Commit

Permalink
print out embeddings for illustrative learning
Browse files Browse the repository at this point in the history
  • Loading branch information
henrythe9th committed Jan 13, 2025
1 parent b524afe commit 9e88891
Showing 1 changed file with 136 additions and 10 deletions.
146 changes: 136 additions & 10 deletions ch02/01_main-chapter-code/ch02.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -1709,12 +1709,29 @@
"execution_count": 47,
"id": "0b9e344d-03a6-4f2c-b723-67b6a20c5041",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-0.4105, 0.1759, -1.1781, ..., 0.2422, 1.8863, -0.7444],\n",
" [ 0.2254, 0.2377, 0.7024, ..., -0.1327, -0.5048, 1.5622],\n",
" [-0.7796, 0.2877, -0.3258, ..., -0.7638, 0.0910, 0.5602],\n",
" ...,\n",
" [ 0.5587, 0.1098, 0.3590, ..., -0.0373, 1.4842, 0.6784],\n",
" [ 0.1251, -0.8082, -2.4348, ..., 1.0868, 0.6819, -0.4818],\n",
" [ 0.2434, 0.0151, 0.3175, ..., 0.3685, 0.3924, 0.7938]],\n",
" requires_grad=True)\n"
]
}
],
"source": [
"vocab_size = 50257\n",
"output_dim = 256\n",
"\n",
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)"
"token_embedding_layer = torch.nn.Embedding(vocab_size, output_dim)\n",
"print(token_embedding_layer.weight)"
]
},
{
Expand Down Expand Up @@ -1782,13 +1799,46 @@
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 4, 256])\n"
"torch.Size([8, 4, 256])\n",
"tensor([[[ 1.1678, -0.3909, -0.1767, ..., 0.1349, 0.7360, 0.1036],\n",
" [-0.3847, 1.1523, -1.9311, ..., 0.3216, -0.0090, 0.1268],\n",
" [-0.1908, 0.7605, -2.0030, ..., 0.3729, 0.4844, -1.4316],\n",
" [ 0.2240, -0.6976, 0.1289, ..., -1.7083, 1.2737, 1.1384]],\n",
"\n",
" [[-0.7294, 0.3682, 1.0276, ..., 0.4348, -2.7288, -2.8317],\n",
" [ 1.7466, -0.3932, -0.8052, ..., -0.8955, -0.7593, 0.0674],\n",
" [ 1.0403, -0.8981, -0.4447, ..., 0.3280, 0.1324, 0.9342],\n",
" [-1.7905, -0.3360, 1.4520, ..., 2.7658, -0.5026, 0.0061]],\n",
"\n",
" [[-1.3703, 0.1330, -0.7088, ..., 0.6933, 1.6739, -0.2127],\n",
" [-1.1414, 1.9915, -0.2790, ..., -0.1318, 1.3707, 1.4974],\n",
" [-0.0174, 0.1929, 0.7441, ..., 1.8306, 0.1374, 0.5985],\n",
" [ 0.7680, -0.7084, 0.5994, ..., 0.1701, 1.6316, -0.3582]],\n",
"\n",
" ...,\n",
"\n",
" [[ 1.3656, -1.8724, 0.3464, ..., 0.2038, -0.4580, 0.1668],\n",
" [-0.7478, 0.2314, -1.5561, ..., -0.7006, -1.4071, 1.9468],\n",
" [-1.1293, -0.5649, -0.6807, ..., -0.1769, -1.0311, -1.3673],\n",
" [ 0.1264, -0.0062, 0.3948, ..., -1.4024, -0.7518, 0.2758]],\n",
"\n",
" [[-0.2547, -0.1153, -0.2032, ..., -0.0501, 0.3482, -0.3122],\n",
" [-0.3852, -0.5523, -0.0681, ..., -0.8820, -0.3880, -0.0975],\n",
" [-1.3012, 0.0518, -0.0203, ..., 0.3899, 1.5529, -0.1066],\n",
" [-0.1291, -0.8259, 0.2547, ..., 0.3598, 0.1151, 0.1591]],\n",
"\n",
" [[-1.3012, 0.0518, -0.0203, ..., 0.3899, 1.5529, -0.1066],\n",
" [ 0.3554, 0.8244, -0.2780, ..., 0.2390, 1.1231, 0.0898],\n",
" [ 0.7745, -0.4805, 0.6817, ..., -0.1028, 1.7001, -1.6556],\n",
" [-0.0286, 1.3302, -2.2592, ..., 0.5255, 0.2511, 1.1557]]],\n",
" grad_fn=<EmbeddingBackward0>)\n"
]
}
],
"source": [
"token_embeddings = token_embedding_layer(inputs)\n",
"print(token_embeddings.shape)"
"print(token_embeddings.shape)\n",
"print(token_embeddings)"
]
},
{
Expand All @@ -1804,10 +1854,24 @@
"execution_count": 51,
"id": "cc048e20-7ac8-417e-81f5-8fe6f9a4fe07",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Parameter containing:\n",
"tensor([[-0.6580, -1.2299, 0.0347, ..., -1.4686, 0.6399, -0.5408],\n",
" [ 1.0752, 2.7793, -0.9888, ..., -0.2912, 1.2106, -0.4063],\n",
" [-0.5240, 0.3168, 0.0776, ..., 0.8753, 1.3297, 0.7725],\n",
" [ 0.9148, -0.6863, 0.9455, ..., 0.0252, 0.1518, 0.2224]],\n",
" requires_grad=True)\n"
]
}
],
"source": [
"context_length = max_length\n",
"pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)"
"pos_embedding_layer = torch.nn.Embedding(context_length, output_dim)\n",
"print(pos_embedding_layer.weight)"
]
},
{
Expand All @@ -1820,13 +1884,19 @@
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([4, 256])\n"
"torch.Size([4, 256])\n",
"tensor([[ 2.2182, -1.3569, -0.1239, ..., 0.7485, -2.0677, -0.8651],\n",
" [-0.3519, -2.5110, -1.3511, ..., 0.1736, 0.6988, 1.8490],\n",
" [ 0.8965, 0.3749, 0.0378, ..., -1.1351, 1.5106, 2.1837],\n",
" [ 0.7844, -0.9181, 1.3495, ..., 0.8581, 0.7744, 0.9833]],\n",
" grad_fn=<EmbeddingBackward0>)\n"
]
}
],
"source": [
"pos_embeddings = pos_embedding_layer(torch.arange(max_length))\n",
"print(pos_embeddings.shape)"
"print(pos_embeddings.shape)\n",
"print(pos_embeddings)"
]
},
{
Expand All @@ -1847,13 +1917,69 @@
"name": "stdout",
"output_type": "stream",
"text": [
"torch.Size([8, 4, 256])\n"
"torch.Size([8, 4, 256])\n",
"tensor([[[ 3.3860e+00, -1.7478e+00, -3.0061e-01, ..., 8.8339e-01,\n",
" -1.3317e+00, -7.6146e-01],\n",
" [-7.3660e-01, -1.3587e+00, -3.2822e+00, ..., 4.9528e-01,\n",
" 6.8979e-01, 1.9758e+00],\n",
" [ 7.0569e-01, 1.1354e+00, -1.9652e+00, ..., -7.6219e-01,\n",
" 1.9950e+00, 7.5211e-01],\n",
" [ 1.0084e+00, -1.6157e+00, 1.4784e+00, ..., -8.5024e-01,\n",
" 2.0481e+00, 2.1217e+00]],\n",
"\n",
" [[ 1.4887e+00, -9.8865e-01, 9.0368e-01, ..., 1.1833e+00,\n",
" -4.7965e+00, -3.6968e+00],\n",
" [ 1.3947e+00, -2.9042e+00, -2.1563e+00, ..., -7.2184e-01,\n",
" -6.0412e-02, 1.9164e+00],\n",
" [ 1.9368e+00, -5.2321e-01, -4.0685e-01, ..., -8.0718e-01,\n",
" 1.6429e+00, 3.1179e+00],\n",
" [-1.0061e+00, -1.2540e+00, 2.8015e+00, ..., 3.6239e+00,\n",
" 2.7181e-01, 9.8939e-01]],\n",
"\n",
" [[ 8.4783e-01, -1.2239e+00, -8.3269e-01, ..., 1.4419e+00,\n",
" -3.9378e-01, -1.0778e+00],\n",
" [-1.4933e+00, -5.1949e-01, -1.6301e+00, ..., 4.1825e-02,\n",
" 2.0696e+00, 3.3464e+00],\n",
" [ 8.7912e-01, 5.6780e-01, 7.8196e-01, ..., 6.9545e-01,\n",
" 1.6479e+00, 2.7822e+00],\n",
" [ 1.5524e+00, -1.6265e+00, 1.9489e+00, ..., 1.0282e+00,\n",
" 2.4060e+00, 6.2512e-01]],\n",
"\n",
" ...,\n",
"\n",
" [[ 3.5838e+00, -3.2293e+00, 2.2252e-01, ..., 9.5236e-01,\n",
" -2.5257e+00, -6.9832e-01],\n",
" [-1.0997e+00, -2.2796e+00, -2.9072e+00, ..., -5.2700e-01,\n",
" -7.0828e-01, 3.7958e+00],\n",
" [-2.3279e-01, -1.9005e-01, -6.4289e-01, ..., -1.3121e+00,\n",
" 4.7947e-01, 8.1640e-01],\n",
" [ 9.1078e-01, -9.2422e-01, 1.7443e+00, ..., -5.4439e-01,\n",
" 2.2631e-02, 1.2591e+00]],\n",
"\n",
" [[ 1.9635e+00, -1.4722e+00, -3.2714e-01, ..., 6.9841e-01,\n",
" -1.7195e+00, -1.1773e+00],\n",
" [-7.3718e-01, -3.0633e+00, -1.4192e+00, ..., -7.0835e-01,\n",
" 3.1082e-01, 1.7515e+00],\n",
" [-4.0476e-01, 4.2667e-01, 1.7501e-02, ..., -7.4529e-01,\n",
" 3.0635e+00, 2.0771e+00],\n",
" [ 6.5526e-01, -1.7440e+00, 1.6042e+00, ..., 1.2179e+00,\n",
" 8.8957e-01, 1.1425e+00]],\n",
"\n",
" [[ 9.1692e-01, -1.3051e+00, -1.4425e-01, ..., 1.1384e+00,\n",
" -5.1479e-01, -9.7169e-01],\n",
" [ 3.4773e-03, -1.6866e+00, -1.6290e+00, ..., 4.1264e-01,\n",
" 1.8220e+00, 1.9388e+00],\n",
" [ 1.6710e+00, -1.0559e-01, 7.1958e-01, ..., -1.2380e+00,\n",
" 3.2107e+00, 5.2809e-01],\n",
" [ 7.5577e-01, 4.1219e-01, -9.0969e-01, ..., 1.3836e+00,\n",
" 1.0255e+00, 2.1390e+00]]], grad_fn=<AddBackward0>)\n"
]
}
],
"source": [
"input_embeddings = token_embeddings + pos_embeddings\n",
"print(input_embeddings.shape)"
"print(input_embeddings.shape)\n",
"print(input_embeddings)"
]
},
{
Expand Down

0 comments on commit 9e88891

Please sign in to comment.