{
"cells": [
{
"cell_type": "markdown",
"id": "b706722e-bc0f-4e48-890d-76ad514ee6aa",
"metadata": {},
"source": [
"# Recurrent Neural Networks for Language Modeling"
]
},
{
"cell_type": "markdown",
"id": "0c5876fc-0371-4862-905b-a36ae2b4ccc9",
"metadata": {},
"source": [
"## Imports"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "5885d14a-7f6c-43e8-94f5-41f1cff350f0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2022-07-29 06:27:08.825369: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
]
}
],
"source": [
"from trax.fastmath import numpy as np\n",
"import random\n",
"import trax"
]
},
{
"cell_type": "markdown",
"id": "7be0f358-f94d-431e-b046-db3496384ea3",
"metadata": {},
"source": [
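"## Visualizing the RNN Calculation\n",
"\n",
"The cells below check that $W_{hh}h^{<t-1>} + W_{hx}x^{<t>}$ equals a single product $W_h[h^{<t-1>}, x^{<t>}]$. Here $W_h$ places $W_{hh}$ and $W_{hx}$ side by side (`np.concatenate(..., axis=1)` or `np.hstack`), and $[h^{<t-1>}, x^{<t>}]$ stacks the two column vectors on top of each other (`axis=0` or `np.vstack`). This is why the RNN equations below can use one weight matrix $W_h$."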
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "f4cf75f0-6576-4975-9e60-6011267f25dd",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
]
},
{
"data": {
"text/plain": [
"(DeviceArray([[1, 1],\n",
" [1, 1],\n",
" [1, 1]], dtype=int32, weak_type=True),\n",
" (3, 2))"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w_hh = np.full((3,2), 1); w_hh, w_hh.shape"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "7cbad359-d636-464c-bd36-b66e8e0b0e86",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[9, 9, 9],\n",
" [9, 9, 9],\n",
" [9, 9, 9]], dtype=int32, weak_type=True),\n",
" (3, 3))"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w_hx = np.full((3,3), 9); w_hx, w_hx.shape"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "6f5710a6-c7b1-4085-80e3-fa9c5985e140",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[1, 1, 9, 9, 9],\n",
" [1, 1, 9, 9, 9],\n",
" [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),\n",
" (3, 5))"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w_h1 = np.concatenate((w_hh, w_hx), axis=1); w_h1, w_h1.shape"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "f4bdde3e-6935-431e-8d23-2998d007a3a2",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[1, 1, 9, 9, 9],\n",
" [1, 1, 9, 9, 9],\n",
" [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),\n",
" (3, 5))"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w_h2 = np.hstack((w_hh, w_hx)); w_h2, w_h2.shape"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "f47956c6-98c4-47b3-8dd8-5f8d1c318f25",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[1],\n",
" [1]], dtype=int32, weak_type=True),\n",
" (2, 1))"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"h_t_prev = np.full((2,1),1); h_t_prev, h_t_prev.shape"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "de1e84c6-9ad4-4b61-9180-aff7b06b0456",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[9],\n",
" [9],\n",
" [9]], dtype=int32, weak_type=True),\n",
" (3, 1))"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"x_t = np.full((3,1), 9); x_t, x_t.shape"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "ef122a63-e80a-4e65-a616-e6336a1d6854",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[1],\n",
" [1],\n",
" [9],\n",
" [9],\n",
" [9]], dtype=int32, weak_type=True)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ax_1 = np.concatenate((h_t_prev, x_t), axis=0); ax_1"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "ca8358ba-9322-4e01-b485-236a00b314e4",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[1],\n",
" [1],\n",
" [9],\n",
" [9],\n",
" [9]], dtype=int32, weak_type=True),\n",
" (5, 1))"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"ax_2 = np.vstack((h_t_prev, x_t)); ax_2,ax_2.shape"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "00aaa273-8a0d-4af4-b92c-3d1d5622eaf1",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[245],\n",
" [245],\n",
" [245]], dtype=int32, weak_type=True)"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w_h1@ax_1"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "0cfdf3c0-4f55-4d2d-a01c-216012140bfc",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([[245],\n",
" [245],\n",
" [245]], dtype=int32, weak_type=True)"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"w_hh@h_t_prev + w_hx@x_t"
]
},
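{
"cell_type": "markdown",
"id": "added-concat-check-md",
"metadata": {},
"source": [
"The two results above agree for these constant matrices. The identity holds in general: multiplying a matrix split into column blocks by a vector split into matching row blocks gives the sum of the block products. This is a minimal sketch with random values. It assumes plain NumPy imported as `onp`, so it does not shadow the `trax.fastmath` `np` used in this notebook. The sizes and variable names are illustrative."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-concat-check",
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp\n",
"\n",
"rng_demo = onp.random.default_rng(0)\n",
"w_hh_demo = rng_demo.normal(size=(3, 3))    # hidden-to-hidden weights\n",
"w_hx_demo = rng_demo.normal(size=(3, 4))    # input-to-hidden weights\n",
"h_prev_demo = rng_demo.normal(size=(3, 1))  # previous hidden state\n",
"x_demo = rng_demo.normal(size=(4, 1))       # current input\n",
"\n",
"# One product with side-by-side weights and stacked inputs ...\n",
"stacked_demo = onp.hstack([w_hh_demo, w_hx_demo]) @ onp.vstack([h_prev_demo, x_demo])\n",
"# ... equals the sum of the two separate products (up to float rounding).\n",
"separate_demo = w_hh_demo @ h_prev_demo + w_hx_demo @ x_demo\n",
"onp.allclose(stacked_demo, separate_demo)"
]
},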
{
"cell_type": "markdown",
"id": "5b2983ba-31d1-421a-815f-effe78493c46",
"metadata": {},
"source": [
"## Vanilla RNNs, GRUs and the scan function"
]
},
{
"cell_type": "markdown",
"id": "fb91abe6-f917-4175-b566-ac658306d857",
"metadata": {},
"source": [
"*RNN*\n",
"\n",
"\\begin{equation}\n",
"h^{<t>}=g(W_{h}[h^{<t-1>},x^{<t>}] + b_h)\n",
"\\label{eq: htRNN}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\hat{y}^{<t>}=g(W_{yh}h^{<t>} + b_y)\n",
"\\label{eq: ytRNN}\n",
"\\end{equation}\n",
"\n",
"*GRU*\n",
"\n",
"\\begin{equation}\n",
"\\Gamma_r=\\sigma{(W_r[h^{<t-1>}, x^{<t>}]+b_r)}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"\\Gamma_u=\\sigma{(W_u[h^{<t-1>}, x^{<t>}]+b_u)}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"c^{<t>}=\\tanh{(W_h[\\Gamma_r*h^{<t-1>},x^{<t>}]+b_h)}\n",
"\\end{equation}\n",
"\n",
"\\begin{equation}\n",
"h^{<t>}=\\Gamma_u*c^{<t>}+(1-\\Gamma_u)*h^{<t-1>}\n",
"\\end{equation}\n",
"\n",
"Here $[h^{<t-1>}, x^{<t>}]$ stacks the two column vectors on top of each other, $\\sigma$ is the sigmoid function, and $*$ denotes element-wise multiplication. In the code below, both models compute the output $\\hat{y}^{<t>}$ with the second RNN equation, and the sigmoid is used for $g$."
]
},
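{
"cell_type": "markdown",
"id": "added-gru-update-md",
"metadata": {},
"source": [
"The last GRU equation makes $h^{<t>}$ an element-wise interpolation between the previous state $h^{<t-1>}$ and the candidate $c^{<t>}$. Wherever $\\Gamma_u$ is 0, the old state is copied through unchanged. Wherever it is 1, the candidate replaces it. This is a small sketch of that single equation with hand-picked gate values. It assumes plain NumPy imported as `onp` (so it does not shadow the notebook's `np`). `gru_update` is a hypothetical helper, not part of the notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-gru-update",
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp\n",
"\n",
"def gru_update(gamma_u, c_t, h_prev):\n",
"    # h^{<t>} = Gamma_u * c^{<t>} + (1 - Gamma_u) * h^{<t-1>}, element-wise\n",
"    return gamma_u * c_t + (1 - gamma_u) * h_prev\n",
"\n",
"h_prev_gru = onp.array([[0.5], [-0.5]])  # previous hidden state\n",
"c_t_gru = onp.array([[0.9], [0.1]])      # candidate state\n",
"\n",
"keep = gru_update(onp.zeros((2, 1)), c_t_gru, h_prev_gru)           # gate closed: copy h^{<t-1>}\n",
"overwrite = gru_update(onp.ones((2, 1)), c_t_gru, h_prev_gru)       # gate open: take c^{<t>}\n",
"mixed = gru_update(onp.array([[1.0], [0.0]]), c_t_gru, h_prev_gru)  # per-unit choice\n",
"keep.ravel(), overwrite.ravel(), mixed.ravel()"
]
},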
{
"cell_type": "code",
"execution_count": 12,
"id": "e9d81e45-4d3c-451d-825f-e817f5e2a684",
"metadata": {},
"outputs": [],
"source": [
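"def sigmoid(x):\n",
"    return 1.0 / (1.0 + np.exp(-x))\n",
"\n",
"def forward_VRNN(inputs, weights):\n",
"    # One vanilla RNN step; the activation g is the sigmoid here.\n",
"    x_t, h_t_prev = inputs\n",
"    stack = np.vstack([h_t_prev, x_t])        # [h^{<t-1>}, x^{<t>}]\n",
"    w_h, _, _, w_y, b_h, _, _, b_y = weights  # the VRNN uses only w_h, w_y, b_h, b_y\n",
"    h_t = sigmoid(w_h @ stack + b_h)          # new hidden state\n",
"    y_t = sigmoid(w_y @ h_t + b_y)            # output\n",
"    return y_t, h_t\n",
"\n",
"def forward_GRU(inputs, weights):\n",
"    # One GRU step, following the equations above.\n",
"    x_t, h_t_prev = inputs\n",
"    stack = np.vstack([h_t_prev, x_t])\n",
"    w_r, w_u, w_h, w_y, b_r, b_u, b_h, b_y = weights\n",
"    T_r = sigmoid(w_r @ stack + b_r)          # relevance (reset) gate Gamma_r\n",
"    T_u = sigmoid(w_u @ stack + b_u)          # update gate Gamma_u\n",
"    c_t = np.tanh(w_h @ np.vstack((T_r * h_t_prev, x_t)) + b_h)  # candidate state\n",
"    h_t = T_u * c_t + (1 - T_u) * h_t_prev    # interpolate candidate and previous state\n",
"    y_t = sigmoid(w_y @ h_t + b_y)            # output\n",
"    return y_t, h_t\n",
"\n",
"def scan(fn, elems, weights, h_0=None):\n",
"    # Apply fn step by step over elems, threading the hidden state h_t\n",
"    # through the loop. Returns the stacked outputs and the final hidden state.\n",
"    h_t = h_0\n",
"    ys = []\n",
"    for x in elems:\n",
"        y, h_t = fn([x, h_t], weights)\n",
"        ys.append(y)\n",
"    return np.array(ys), h_t"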
]
},
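{
"cell_type": "markdown",
"id": "added-scan-toy-md",
"metadata": {},
"source": [
"The `scan` function above is a plain Python loop. It threads the hidden state from one step to the next and collects every output. The toy sketch below uses the same contract, minus the `weights` argument, to compute a running sum. `scan_demo` and `running_sum` are hypothetical names, and plain NumPy imported as `onp` is assumed. JAX also ships a compiled version of this pattern, `jax.lax.scan(f, init, xs)`. Its step function has the signature `f(carry, x) -> (carry, y)`, which reverses the order of the `(y, h_t)` pair returned here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-scan-toy",
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp\n",
"\n",
"def scan_demo(fn, elems, h_0):\n",
"    # Same loop as scan above: fn([x, h]) returns (y, h_next)\n",
"    h_t = h_0\n",
"    ys = []\n",
"    for x in elems:\n",
"        y, h_t = fn([x, h_t])\n",
"        ys.append(y)\n",
"    return onp.array(ys), h_t\n",
"\n",
"def running_sum(inputs):\n",
"    x, h = inputs\n",
"    h_next = h + x         # the carried state accumulates the inputs\n",
"    return h_next, h_next  # output y and new state are the same here\n",
"\n",
"ys_demo, h_last_demo = scan_demo(running_sum, onp.array([1, 2, 3, 4]), 0)\n",
"ys_demo, h_last_demo"
]
},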
{
"cell_type": "code",
"execution_count": 13,
"id": "940369ac-90af-49ba-b9ce-5abb207dfc32",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"((16, 144), (16, 1), (256, 128, 1), (128, 1))"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"random.seed(10)  # seeds Python's random module; the trax draws below use random_key instead\n",
"emb = 128        # embedding (input) dimension\n",
"T = 256          # sequence length (number of time steps)\n",
"h_dim = 16       # hidden state dimension\n",
"h_0 = np.zeros((h_dim, 1))  # initial hidden state\n",
"\n",
"# Note: every call below reuses random_key, so draws with the same shape are\n",
"# identical (w1 == w2 == w3 and b1 == b2 == b3 == b4). This is harmless for a\n",
"# shape and timing demo; split the key (trax.fastmath.random.split) to get\n",
"# independent weights.\n",
"random_key = trax.fastmath.random.get_prng(seed=0)\n",
"w1 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, emb + h_dim))\n",
"w2 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, emb + h_dim))\n",
"w3 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, emb + h_dim))\n",
"w4 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, h_dim))  # output weights w_y\n",
"\n",
"b1 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
"b2 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
"b3 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
"b4 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
"\n",
"X = trax.fastmath.random.normal(key=random_key, shape=(T, emb, 1))  # input sequence\n",
"weights = [w1, w2, w3, w4, b1, b2, b3, b4]\n",
"w1.shape, h_0.shape, X.shape, X[0].shape"
]
},
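{
"cell_type": "markdown",
"id": "added-prng-key-md",
"metadata": {},
"source": [
"Every weight and bias above is drawn from the same `random_key`. A key-based PRNG returns the same numbers for the same key and shape, so `w1`, `w2` and `w3` come out identical, and so do `b1` to `b4`. As a result, the two gates in `forward_GRU` compute the same values. The sketch below mimics both situations with NumPy's seeding API, assuming plain NumPy (`onp`) as a stand-in for the trax/JAX key mechanism. Re-seeding with the same seed plays the role of reusing a key, and `SeedSequence.spawn` plays the role of splitting it (in trax, `trax.fastmath.random.split` serves that purpose)."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-prng-key",
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp\n",
"\n",
"shape_demo = (16, 144)  # same shape as w1, w2 and w3 above\n",
"\n",
"# Reusing one seed (like reusing one PRNG key) reproduces the same draws.\n",
"w_same_a = onp.random.default_rng(0).normal(size=shape_demo)\n",
"w_same_b = onp.random.default_rng(0).normal(size=shape_demo)\n",
"\n",
"# Spawning child seeds (like splitting a key) gives independent streams.\n",
"child_a, child_b = onp.random.SeedSequence(0).spawn(2)\n",
"w_indep_a = onp.random.default_rng(child_a).normal(size=shape_demo)\n",
"w_indep_b = onp.random.default_rng(child_b).normal(size=shape_demo)\n",
"\n",
"onp.array_equal(w_same_a, w_same_b), onp.array_equal(w_indep_a, w_indep_b)"
]
},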
{
"cell_type": "code",
"execution_count": 14,
"id": "4f89f820-7305-4da7-b21a-bc89be6769ba",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(DeviceArray([[0.98875064],\n",
" [0.90948516],\n",
" [0.83860266],\n",
" [0.605238 ],\n",
" [0.8174015 ],\n",
" [0.8074013 ],\n",
" [0.9864478 ],\n",
" [0.19836329],\n",
" [0.23806544],\n",
" [0.03616815],\n",
" [0.9230415 ],\n",
" [0.897435 ],\n",
" [0.81876415],\n",
" [0.06574202],\n",
" [0.14956403],\n",
" [0.820773 ]], dtype=float32),\n",
" DeviceArray([[ 9.9999964e-01],\n",
" [ 1.0000000e+00],\n",
" [ 9.9392438e-01],\n",
" [-1.2257343e-04],\n",
" [-7.2951213e-02],\n",
" [-2.5626263e-02],\n",
" [-2.4690714e-09],\n",
" [-2.0006059e-01],\n",
" [ 7.9664338e-01],\n",
" [-4.0450820e-07],\n",
" [-1.1183748e-03],\n",
" [-1.2833535e-09],\n",
" [-7.8878832e-09],\n",
" [ 9.9999976e-01],\n",
" [ 9.5781153e-01],\n",
" [-7.0657270e-06]], dtype=float32))"
]
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"inputs = X[0], h_0\n",
"forward_VRNN(inputs, weights)  # result not shown: only the last expression is displayed\n",
"forward_GRU(inputs, weights)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "3cf41b53-417a-44a0-8565-3919f4fda206",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"DeviceArray([2, 4, 6], dtype=int32)"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# `*` is element-wise (Hadamard) multiplication, as used by the GRU gates\n",
"np.array([1,2,3])*np.array([2,2,2])"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "36e06206-6707-46e5-af34-0635bf13fd45",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 2.12 s, sys: 43.6 ms, total: 2.17 s\n",
"Wall time: 2.16 s\n"
]
},
{
"data": {
"text/plain": [
"(DeviceArray([[[0.9844214 ],\n",
" [0.86315817],\n",
" [0.808999 ],\n",
" ...,\n",
" [0.05173129],\n",
" [0.21562287],\n",
" [0.7522172 ]],\n",
" \n",
" [[0.9830031 ],\n",
" [0.8508622 ],\n",
" [0.7086585 ],\n",
" ...,\n",
" [0.02949953],\n",
" [0.27303317],\n",
" [0.9107685 ]],\n",
" \n",
" [[0.5616716 ],\n",
" [0.42216417],\n",
" [0.02124172],\n",
" ...,\n",
" [0.49089074],\n",
" [0.6445939 ],\n",
" [0.0898608 ]],\n",
" \n",
" ...,\n",
" \n",
" [[0.96719635],\n",
" [0.3858212 ],\n",
" [0.06641304],\n",
" ...,\n",
" [0.89807695],\n",
" [0.8543829 ],\n",
" [0.7933321 ]],\n",
" \n",
" [[0.9930194 ],\n",
" [0.75615996],\n",
" [0.3213525 ],\n",
" ...,\n",
" [0.46565428],\n",
" [0.5535333 ],\n",
" [0.7318801 ]],\n",
" \n",
" [[0.81508607],\n",
" [0.3935613 ],\n",
" [0.07889035],\n",
" ...,\n",
" [0.2907602 ],\n",
" [0.2841384 ],\n",
" [0.30343485]]], dtype=float32),\n",
" DeviceArray([[9.9991798e-01],\n",
" [9.7653759e-01],\n",
" [2.6515589e-05],\n",
" [8.7920368e-01],\n",
" [1.0000000e+00],\n",
" [9.4612682e-01],\n",
" [9.9921525e-01],\n",
" [1.0000000e+00],\n",
" [1.0000000e+00],\n",
" [1.1614484e-02],\n",
" [9.9999189e-01],\n",
" [9.9806911e-01],\n",
" [2.3502709e-02],\n",
" [1.6757324e-08],\n",
" [8.7104484e-09],\n",
" [2.2113952e-10]], dtype=float32))"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time scan(fn=forward_VRNN, elems=X, weights=weights, h_0=h_0)"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c82f964e-5bfe-47e2-8492-4f2ca8c0ba59",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 386 ms, sys: 0 ns, total: 386 ms\n",
"Wall time: 384 ms\n"
]
},
{
"data": {
"text/plain": [
"(DeviceArray([[[0.98875064],\n",
" [0.90948516],\n",
" [0.83860266],\n",
" ...,\n",
" [0.06574202],\n",
" [0.14956403],\n",
" [0.820773 ]],\n",
" \n",
" [[0.99207675],\n",
" [0.8274236 ],\n",
" [0.7019363 ],\n",
" ...,\n",
" [0.08724681],\n",
" [0.2603313 ],\n",
" [0.9303203 ]],\n",
" \n",
" [[0.96868426],\n",
" [0.65789413],\n",
" [0.04706543],\n",
" ...,\n",
" [0.02912108],\n",
" [0.76775664],\n",
" [0.1878842 ]],\n",
" \n",
" ...,\n",
" \n",
" [[0.99862254],\n",
" [0.68557423],\n",
" [0.68682545],\n",
" ...,\n",
" [0.01584747],\n",
" [0.6382734 ],\n",
" [0.27868038]],\n",
" \n",
" [[0.9979925 ],\n",
" [0.67977065],\n",
" [0.7547641 ],\n",
" ...,\n",
" [0.01567412],\n",
" [0.54346865],\n",
" [0.24606723]],\n",
" \n",
" [[0.9998665 ],\n",
" [0.38725543],\n",
" [0.9630767 ],\n",
" ...,\n",
" [0.00373541],\n",
" [0.6321914 ],\n",
" [0.9655915 ]]], dtype=float32),\n",
" DeviceArray([[ 0.99998415],\n",
" [ 0.99973965],\n",
" [ 0.93683696],\n",
" [-0.44143844],\n",
" [ 1. ],\n",
" [ 0.99857235],\n",
" [ 0.99990374],\n",
" [ 1. ],\n",
" [ 1. ],\n",
" [ 0.85161185],\n",
" [ 0.99997693],\n",
" [ 0.9956449 ],\n",
" [ 0.37345803],\n",
" [ 0.99998754],\n",
" [ 0.98146665],\n",
" [ 0.9999966 ]], dtype=float32))"
]
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"%time scan(fn=forward_GRU, elems=X, weights=weights, h_0=h_0)"
]
},
{
"cell_type": "markdown",
"id": "c0d8a1bc-316f-4a14-8f97-f6b498204bfa",
"metadata": {},
"source": [
"## Perplexity Calculations\n",
"\n",
"Perplexity measures how well a probability model predicts a sample, and it is commonly used to evaluate language models. For a sequence $W = w_1, \\dots, w_N$ it is defined as:\n",
"\n",
"$$PP(W) = \\sqrt[N]{\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}}$$\n",
"\n",
"Lower perplexity means the model assigns higher probability to the observed words.\n",
"\n",
"In practice you compute the logarithm of this formula instead. Multiplying many probabilities smaller than 1 quickly underflows floating-point precision, whereas summing their logarithms does not.\n",
"\n",
"Taking the logarithm of $PP(W)$ gives:\n",
"\n",
"$$\\log PP(W) = \\log\\left(\\sqrt[N]{\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}}\\right)$$\n",
"\n",
"$$ = \\log\\left(\\left(\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}\\right)^{\\frac{1}{N}}\\right)$$\n",
"\n",
"$$ = \\log\\left(\\left(\\prod_{i=1}^{N}P(w_i| w_1,...,w_{i-1})\\right)^{-\\frac{1}{N}}\\right)$$\n",
"\n",
"$$ = -\\frac{1}{N}\\log\\left(\\prod_{i=1}^{N}P(w_i| w_1,...,w_{i-1})\\right)$$\n",
"\n",
"$$ = -\\frac{1}{N}\\sum_{i=1}^{N}\\log P(w_i| w_1,...,w_{i-1})$$\n",
"\n",
"The perplexity itself is recovered by exponentiating this average negative log-probability: $PP(W) = \\exp\\left(-\\frac{1}{N}\\sum_{i=1}^{N}\\log P(w_i| w_1,...,w_{i-1})\\right)$."
]
},
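{
"cell_type": "markdown",
"id": "added-perplexity-md",
"metadata": {},
"source": [
"This is a minimal sketch comparing the direct formula with the log-space form derived above. It assumes plain NumPy (`onp`) and assumes the per-word conditional probabilities $P(w_i| w_1,...,w_{i-1})$ are already available as an array. `perplexity_direct` and `perplexity_log` are hypothetical helpers. For a short sequence the two agree. For 1000 words, the direct product overflows to infinity, while the log-space version still returns the right answer. The long case also shows a useful sanity check: a model that assigns probability $1/V$ to every word has perplexity $V$."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-perplexity",
"metadata": {},
"outputs": [],
"source": [
"import numpy as onp\n",
"\n",
"def perplexity_direct(probs):\n",
"    # N-th root of prod(1 / P(w_i | ...)): overflows for long sequences\n",
"    probs = onp.asarray(probs, dtype=onp.float64)\n",
"    with onp.errstate(over='ignore'):\n",
"        return onp.prod(1.0 / probs) ** (1.0 / probs.size)\n",
"\n",
"def perplexity_log(probs):\n",
"    # exp(-(1/N) * sum(log P(w_i | ...))), the log-space form\n",
"    probs = onp.asarray(probs, dtype=onp.float64)\n",
"    return onp.exp(-onp.mean(onp.log(probs)))\n",
"\n",
"short_probs = [0.5, 0.25, 0.125]\n",
"long_probs = onp.full(1000, 1e-3)  # 1000 words, each predicted with probability 1/1000\n",
"\n",
"print(perplexity_direct(short_probs), perplexity_log(short_probs))  # both approximately 4\n",
"print(perplexity_direct(long_probs), perplexity_log(long_probs))    # inf vs approximately 1000"
]
},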
{
"cell_type": "code",
"execution_count": null,
"id": "ea5fd9b6-b740-425a-8102-f625ec8ebfe2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python [conda env:aiking] *",
"language": "python",
"name": "conda-env-aiking-py"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
}
},
"nbformat": 4,
"nbformat_minor": 5
}