{ "cells": [ { "cell_type": "markdown", "id": "b706722e-bc0f-4e48-890d-76ad514ee6aa", "metadata": {}, "source": [ "# Recurrent Neural Networks for Language Modeling" ] }, { "cell_type": "markdown", "id": "0c5876fc-0371-4862-905b-a36ae2b4ccc9", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": 1, "id": "5885d14a-7f6c-43e8-94f5-41f1cff350f0", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-07-29 06:27:08.825369: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n" ] } ], "source": [ "from trax.fastmath import numpy as np\n", "import random\n", "import trax" ] }, { "cell_type": "markdown", "id": "7be0f358-f94d-431e-b046-db3496384ea3", "metadata": {}, "source": [ "## RNN Calculation Viz" ] }, { "cell_type": "code", "execution_count": 2, "id": "f4cf75f0-6576-4975-9e60-6011267f25dd", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "WARNING:absl:No GPU/TPU found, falling back to CPU. 
(Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" ] }, { "data": { "text/plain": [ "(DeviceArray([[1, 1],\n", " [1, 1],\n", " [1, 1]], dtype=int32, weak_type=True),\n", " (3, 2))" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w_hh = np.full((3,2), 1); w_hh, w_hh.shape" ] }, { "cell_type": "code", "execution_count": 3, "id": "7cbad359-d636-464c-bd36-b66e8e0b0e86", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[9, 9, 9],\n", " [9, 9, 9],\n", " [9, 9, 9]], dtype=int32, weak_type=True),\n", " (3, 3))" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w_hx = np.full((3,3), 9); w_hx, w_hx.shape" ] }, { "cell_type": "code", "execution_count": 4, "id": "6f5710a6-c7b1-4085-80e3-fa9c5985e140", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[1, 1, 9, 9, 9],\n", " [1, 1, 9, 9, 9],\n", " [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),\n", " (3, 5))" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w_h1 = np.concatenate((w_hh, w_hx), axis=1); w_h1, w_h1.shape" ] }, { "cell_type": "code", "execution_count": 5, "id": "f4bdde3e-6935-431e-8d23-2998d007a3a2", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[1, 1, 9, 9, 9],\n", " [1, 1, 9, 9, 9],\n", " [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),\n", " (3, 5))" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w_h2 = np.hstack((w_hh, w_hx)); w_h2, w_h2.shape" ] }, { "cell_type": "code", "execution_count": 6, "id": "f47956c6-98c4-47b3-8dd8-5f8d1c318f25", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[1],\n", " [1]], dtype=int32, weak_type=True),\n", " (2, 1))" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "h_t_prev = np.full((2,1),1); h_t_prev, h_t_prev.shape" ] }, { "cell_type": "code", "execution_count": 
7, "id": "de1e84c6-9ad4-4b61-9180-aff7b06b0456", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[9],\n", " [9],\n", " [9]], dtype=int32, weak_type=True),\n", " (3, 1))" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_t = np.full((3,1), 9); x_t, x_t.shape" ] }, { "cell_type": "code", "execution_count": 8, "id": "ef122a63-e80a-4e65-a616-e6336a1d6854", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[1],\n", " [1],\n", " [9],\n", " [9],\n", " [9]], dtype=int32, weak_type=True)" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ax_1 = np.concatenate((h_t_prev, x_t), axis=0); ax_1" ] }, { "cell_type": "code", "execution_count": 9, "id": "ca8358ba-9322-4e01-b485-236a00b314e4", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[1],\n", " [1],\n", " [9],\n", " [9],\n", " [9]], dtype=int32, weak_type=True),\n", " (5, 1))" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ax_2 = np.vstack((h_t_prev, x_t)); ax_2,ax_2.shape" ] }, { "cell_type": "code", "execution_count": 10, "id": "00aaa273-8a0d-4af4-b92c-3d1d5622eaf1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[245],\n", " [245],\n", " [245]], dtype=int32, weak_type=True)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w_h1@ax_1" ] }, { "cell_type": "code", "execution_count": 11, "id": "0cfdf3c0-4f55-4d2d-a01c-216012140bfc", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([[245],\n", " [245],\n", " [245]], dtype=int32, weak_type=True)" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "w_hh@h_t_prev + w_hx@x_t" ] }, { "cell_type": "markdown", "id": "5b2983ba-31d1-421a-815f-effe78493c46", "metadata": {}, "source": [ "## Vanilla RNNs, GRUs and the scan function" ] }, { "cell_type": 
"markdown", "id": "fb91abe6-f917-4175-b566-ac658306d857", "metadata": {}, "source": [ "*RNN*\n", "\n", "\n", "\\begin{equation}\n", "h^{<t>}=g(W_{h}[h^{<t-1>},x^{<t>}] + b_h)\n", "\\label{eq: htRNN}\n", "\\end{equation}\n", " \n", "\\begin{equation}\n", "\\hat{y}^{<t>}=g(W_{yh}h^{<t>} + b_y)\n", "\\label{eq: ytRNN}\n", "\\end{equation}\n", " \n", "\n", " \n", "*GRU*\n", " \n", "\\begin{equation}\n", "\\Gamma_r=\\sigma{(W_r[h^{<t-1>}, x^{<t>}]+b_r)}\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "\\Gamma_u=\\sigma{(W_u[h^{<t-1>}, x^{<t>}]+b_u)}\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "c^{<t>}=\\tanh{(W_h[\\Gamma_r*h^{<t-1>},x^{<t>}]+b_h)}\n", "\\end{equation}\n", "\n", "\\begin{equation}\n", "h^{<t>}=\\Gamma_u*c^{<t>}+(1-\\Gamma_u)*h^{<t-1>}\n", "\\end{equation}\n", "\n", "Here $[a, b]$ denotes vertical concatenation, $*$ is element-wise multiplication, and $g$ is an activation function (the code in this notebook uses the sigmoid)." ] }, { "cell_type": "code", "execution_count": 12, "id": "e9d81e45-4d3c-451d-825f-e817f5e2a684", "metadata": {}, "outputs": [], "source": [ "def sigmoid(x):\n", "    return 1.0/(1.0+np.exp(-x))\n", "\n", "def forward_VRNN(inputs, weights):\n", "    # One vanilla RNN step; the sigmoid plays the role of g for both h_t and y_t\n", "    x_t, h_t_prev = inputs\n", "    stack = np.vstack([h_t_prev, x_t])\n", "    # The weight list is shared with the GRU; only w_h, w_y, b_h and b_y are used here\n", "    w_h, _, _, w_y, b_h, _, _, b_y = weights\n", "    h_t = sigmoid(w_h@stack + b_h)\n", "    y_t = sigmoid(w_y@h_t + b_y)\n", "    return y_t, h_t\n", "\n", "def forward_GRU(inputs, weights):\n", "    # One GRU step: relevance gate T_r, update gate T_u, candidate state c_t\n", "    x_t, h_t_prev = inputs\n", "    stack = np.vstack([h_t_prev, x_t])\n", "    w_r, w_u, w_h, w_y, b_r, b_u, b_h, b_y = weights\n", "    T_r = sigmoid(w_r@stack + b_r)\n", "    T_u = sigmoid(w_u@stack + b_u)\n", "    c_t = np.tanh(w_h@np.vstack((T_r*h_t_prev, x_t)) + b_h)\n", "    h_t = T_u*c_t + (1-T_u)*h_t_prev\n", "    y_t = sigmoid(w_y@h_t + b_y)\n", "    return y_t, h_t\n", "\n", "def scan(fn, elems, weights, h_0=None):\n", "    # Apply fn to each element in order, carrying the hidden state forward\n", "    h_t = h_0\n", "    ys = []\n", "    for x in elems:\n", "        y, h_t = fn([x, h_t], weights)\n", "        ys.append(y)\n", "    return np.array(ys), h_t" ] }, { "cell_type": "code", "execution_count": 13, "id": "940369ac-90af-49ba-b9ce-5abb207dfc32", "metadata": {}, "outputs": [ { "data": { 
"text/plain": [ "((16, 144), (16, 1), (256, 128, 1), (128, 1))" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "random.seed(10)\n", "emb = 128 # Embedding (input) dimension\n", "T = 256 # Number of time steps in the sequence\n", "h_dim = 16 # Hidden state dimension\n", "h_0 = np.zeros((h_dim, 1));h_0.shape\n", "\n", "# Note: the same PRNG key is reused for every draw below, so arrays of equal shape (e.g. w1, w2, w3) are identical\n", "random_key = trax.fastmath.random.get_prng(seed=0)\n", "w1 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, emb+h_dim)); w1\n", "w2 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, emb+h_dim)); w2\n", "w3 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, emb+h_dim)); w3\n", "\n", "w4 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, h_dim)); w4\n", "\n", "b1 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, 1)); b1\n", "b2 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, 1)); b2\n", "b3 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, 1)); b3\n", "\n", "b4 = trax.fastmath.random.normal(key = random_key, \n", " shape = (h_dim, 1)); b4\n", "X = trax.fastmath.random.normal(key = random_key, \n", " shape = (T, emb, 1)); X\n", "weights = [w1, w2, w3, w4, b1, b2, b3, b4]; weights\n", "w1.shape, h_0.shape, X.shape, X[0].shape" ] }, { "cell_type": "code", "execution_count": 14, "id": "4f89f820-7305-4da7-b21a-bc89be6769ba", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(DeviceArray([[0.98875064],\n", " [0.90948516],\n", " [0.83860266],\n", " [0.605238 ],\n", " [0.8174015 ],\n", " [0.8074013 ],\n", " [0.9864478 ],\n", " [0.19836329],\n", " [0.23806544],\n", " [0.03616815],\n", " [0.9230415 ],\n", " [0.897435 ],\n", " [0.81876415],\n", " [0.06574202],\n", " [0.14956403],\n", " [0.820773 ]], dtype=float32),\n", " DeviceArray([[ 9.9999964e-01],\n", " [ 1.0000000e+00],\n", " [ 9.9392438e-01],\n", " [-1.2257343e-04],\n", " [-7.2951213e-02],\n", " [-2.5626263e-02],\n", " [-2.4690714e-09],\n", " 
[-2.0006059e-01],\n", " [ 7.9664338e-01],\n", " [-4.0450820e-07],\n", " [-1.1183748e-03],\n", " [-1.2833535e-09],\n", " [-7.8878832e-09],\n", " [ 9.9999976e-01],\n", " [ 9.5781153e-01],\n", " [-7.0657270e-06]], dtype=float32))" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "inputs = X[0], h_0\n", "weights = weights\n", "forward_VRNN(inputs, weights)\n", "forward_GRU(inputs, weights)" ] }, { "cell_type": "code", "execution_count": 15, "id": "3cf41b53-417a-44a0-8565-3919f4fda206", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "DeviceArray([2, 4, 6], dtype=int32)" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.array([1,2,3])*np.array([2,2,2])" ] }, { "cell_type": "code", "execution_count": 16, "id": "36e06206-6707-46e5-af34-0635bf13fd45", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 2.12 s, sys: 43.6 ms, total: 2.17 s\n", "Wall time: 2.16 s\n" ] }, { "data": { "text/plain": [ "(DeviceArray([[[0.9844214 ],\n", " [0.86315817],\n", " [0.808999 ],\n", " ...,\n", " [0.05173129],\n", " [0.21562287],\n", " [0.7522172 ]],\n", " \n", " [[0.9830031 ],\n", " [0.8508622 ],\n", " [0.7086585 ],\n", " ...,\n", " [0.02949953],\n", " [0.27303317],\n", " [0.9107685 ]],\n", " \n", " [[0.5616716 ],\n", " [0.42216417],\n", " [0.02124172],\n", " ...,\n", " [0.49089074],\n", " [0.6445939 ],\n", " [0.0898608 ]],\n", " \n", " ...,\n", " \n", " [[0.96719635],\n", " [0.3858212 ],\n", " [0.06641304],\n", " ...,\n", " [0.89807695],\n", " [0.8543829 ],\n", " [0.7933321 ]],\n", " \n", " [[0.9930194 ],\n", " [0.75615996],\n", " [0.3213525 ],\n", " ...,\n", " [0.46565428],\n", " [0.5535333 ],\n", " [0.7318801 ]],\n", " \n", " [[0.81508607],\n", " [0.3935613 ],\n", " [0.07889035],\n", " ...,\n", " [0.2907602 ],\n", " [0.2841384 ],\n", " [0.30343485]]], dtype=float32),\n", " DeviceArray([[9.9991798e-01],\n", " [9.7653759e-01],\n", " 
[2.6515589e-05],\n", " [8.7920368e-01],\n", " [1.0000000e+00],\n", " [9.4612682e-01],\n", " [9.9921525e-01],\n", " [1.0000000e+00],\n", " [1.0000000e+00],\n", " [1.1614484e-02],\n", " [9.9999189e-01],\n", " [9.9806911e-01],\n", " [2.3502709e-02],\n", " [1.6757324e-08],\n", " [8.7104484e-09],\n", " [2.2113952e-10]], dtype=float32))" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time scan(fn=forward_VRNN, elems=X, weights=weights, h_0=h_0)" ] }, { "cell_type": "code", "execution_count": 17, "id": "c82f964e-5bfe-47e2-8492-4f2ca8c0ba59", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "CPU times: user 386 ms, sys: 0 ns, total: 386 ms\n", "Wall time: 384 ms\n" ] }, { "data": { "text/plain": [ "(DeviceArray([[[0.98875064],\n", " [0.90948516],\n", " [0.83860266],\n", " ...,\n", " [0.06574202],\n", " [0.14956403],\n", " [0.820773 ]],\n", " \n", " [[0.99207675],\n", " [0.8274236 ],\n", " [0.7019363 ],\n", " ...,\n", " [0.08724681],\n", " [0.2603313 ],\n", " [0.9303203 ]],\n", " \n", " [[0.96868426],\n", " [0.65789413],\n", " [0.04706543],\n", " ...,\n", " [0.02912108],\n", " [0.76775664],\n", " [0.1878842 ]],\n", " \n", " ...,\n", " \n", " [[0.99862254],\n", " [0.68557423],\n", " [0.68682545],\n", " ...,\n", " [0.01584747],\n", " [0.6382734 ],\n", " [0.27868038]],\n", " \n", " [[0.9979925 ],\n", " [0.67977065],\n", " [0.7547641 ],\n", " ...,\n", " [0.01567412],\n", " [0.54346865],\n", " [0.24606723]],\n", " \n", " [[0.9998665 ],\n", " [0.38725543],\n", " [0.9630767 ],\n", " ...,\n", " [0.00373541],\n", " [0.6321914 ],\n", " [0.9655915 ]]], dtype=float32),\n", " DeviceArray([[ 0.99998415],\n", " [ 0.99973965],\n", " [ 0.93683696],\n", " [-0.44143844],\n", " [ 1. ],\n", " [ 0.99857235],\n", " [ 0.99990374],\n", " [ 1. ],\n", " [ 1. 
],\n", " [ 0.85161185],\n", " [ 0.99997693],\n", " [ 0.9956449 ],\n", " [ 0.37345803],\n", " [ 0.99998754],\n", " [ 0.98146665],\n", " [ 0.9999966 ]], dtype=float32))" ] }, "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "%time scan(fn=forward_GRU, elems=X, weights=weights, h_0=h_0)" ] }, { "cell_type": "markdown", "id": "c0d8a1bc-316f-4a14-8f97-f6b498204bfa", "metadata": {}, "source": [ "## Perplexity Calculations\n", "\n", "Perplexity measures how well a probability model predicts a sample, and it is commonly used to evaluate language models. For a sequence $W = w_1,...,w_N$ of $N$ words, it is defined as: \n", "\n", "$$P(W) = \\sqrt[N]{\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}}$$\n", "\n", "In practice, you usually work with the log of that formula: it turns a product of many small probabilities into a sum, so the computation is less prone to underflow.\n", "\n", "Taking the logarithm of $P(W)$ gives:\n", "\n", "$$\\log P(W) = {\\log\\left(\\sqrt[N]{\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}}\\right)}$$\n", "\n", "\n", "$$ = \\log\\left(\\left(\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}\\right)^{\\frac{1}{N}}\\right)$$\n", "\n", "$$ = \\log\\left(\\left({\\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\\right)^{-\\frac{1}{N}}\\right)$$\n", "\n", "$$ = -\\frac{1}{N}{\\log\\left({\\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\\right)} $$\n", "\n", "$$ = -\\frac{1}{N}{{\\sum_{i=1}^{N}{\\log P(w_i| w_1,...,w_{i-1})}}} $$\n", "\n", "So the log perplexity is the negative mean log-probability of the words, and exponentiating it recovers $P(W)$.\n" ] }, { "cell_type": "code", "execution_count": null, "id": "1d16f656-6d8a-4cbd-adda-e9bf5b02c855", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "ea5fd9b6-b740-425a-8102-f625ec8ebfe2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python [conda env:aiking] *", "language": "python", "name": "conda-env-aiking-py" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", 
"name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.13" } }, "nbformat": 4, "nbformat_minor": 5 }