{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b706722e-bc0f-4e48-890d-76ad514ee6aa",
   "metadata": {},
   "source": [
    "# Recurrent Neural Networks for Language Modeling"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0c5876fc-0371-4862-905b-a36ae2b4ccc9",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "5885d14a-7f6c-43e8-94f5-41f1cff350f0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "2022-07-29 06:27:08.825369: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory\n"
     ]
    }
   ],
   "source": [
    "from trax.fastmath import numpy as np\n",
    "import random\n",
    "import trax"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7be0f358-f94d-431e-b046-db3496384ea3",
   "metadata": {},
   "source": [
    "## RNN Calculation Viz"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f4cf75f0-6576-4975-9e60-6011267f25dd",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[1, 1],\n",
       "              [1, 1],\n",
       "              [1, 1]], dtype=int32, weak_type=True),\n",
       " (3, 2))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_hh = np.full((3,2), 1); w_hh, w_hh.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "7cbad359-d636-464c-bd36-b66e8e0b0e86",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[9, 9, 9],\n",
       "              [9, 9, 9],\n",
       "              [9, 9, 9]], dtype=int32, weak_type=True),\n",
       " (3, 3))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_hx = np.full((3,3), 9); w_hx, w_hx.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "6f5710a6-c7b1-4085-80e3-fa9c5985e140",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[1, 1, 9, 9, 9],\n",
       "              [1, 1, 9, 9, 9],\n",
       "              [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),\n",
       " (3, 5))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_h1 = np.concatenate((w_hh, w_hx), axis=1); w_h1, w_h1.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "f4bdde3e-6935-431e-8d23-2998d007a3a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[1, 1, 9, 9, 9],\n",
       "              [1, 1, 9, 9, 9],\n",
       "              [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),\n",
       " (3, 5))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_h2 = np.hstack((w_hh, w_hx)); w_h2, w_h2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f47956c6-98c4-47b3-8dd8-5f8d1c318f25",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[1],\n",
       "              [1]], dtype=int32, weak_type=True),\n",
       " (2, 1))"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "h_t_prev = np.full((2,1),1); h_t_prev, h_t_prev.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "de1e84c6-9ad4-4b61-9180-aff7b06b0456",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[9],\n",
       "              [9],\n",
       "              [9]], dtype=int32, weak_type=True),\n",
       " (3, 1))"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x_t = np.full((3,1), 9); x_t, x_t.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "ef122a63-e80a-4e65-a616-e6336a1d6854",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[1],\n",
       "             [1],\n",
       "             [9],\n",
       "             [9],\n",
       "             [9]], dtype=int32, weak_type=True)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ax_1 = np.concatenate((h_t_prev, x_t), axis=0); ax_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "ca8358ba-9322-4e01-b485-236a00b314e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[1],\n",
       "              [1],\n",
       "              [9],\n",
       "              [9],\n",
       "              [9]], dtype=int32, weak_type=True),\n",
       " (5, 1))"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "ax_2 = np.vstack((h_t_prev, x_t)); ax_2,ax_2.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "00aaa273-8a0d-4af4-b92c-3d1d5622eaf1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[245],\n",
       "             [245],\n",
       "             [245]], dtype=int32, weak_type=True)"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_h1@ax_1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "0cfdf3c0-4f55-4d2d-a01c-216012140bfc",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([[245],\n",
       "             [245],\n",
       "             [245]], dtype=int32, weak_type=True)"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "w_hh@h_t_prev + w_hx@x_t"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5b2983ba-31d1-421a-815f-effe78493c46",
   "metadata": {},
   "source": [
    "## Vanilla RNNs, GRUs and the scan function"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fb91abe6-f917-4175-b566-ac658306d857",
   "metadata": {},
   "source": [
    "*RNN*\n",
    "\n",
    "\n",
    "\\begin{equation}\n",
    "h^{<t>}=g(W_{h}[h^{<t-1>},x^{<t>}] + b_h)\n",
    "\\label{eq: htRNN}\n",
    "\\end{equation}\n",
    "    \n",
    "\\begin{equation}\n",
    "\\hat{y}^{<t>}=g(W_{yh}h^{<t>} + b_y)\n",
    "\\label{eq: ytRNN}\n",
    "\\end{equation}\n",
    "    \n",
    "![RNN](RNN.png)\n",
    "    \n",
    "*GRU*\n",
    "  \n",
    "\\begin{equation}\n",
    "\\Gamma_r=\\sigma{(W_r[h^{<t-1>}, x^{<t>}]+b_r)}\n",
    "\\end{equation}\n",
    "\n",
    "\\begin{equation}\n",
    "\\Gamma_u=\\sigma{(W_u[h^{<t-1>}, x^{<t>}]+b_u)}\n",
    "\\end{equation}\n",
    "\n",
    "\\begin{equation}\n",
    "c^{<t>}=\\tanh{(W_h[\\Gamma_r*h^{<t-1>},x^{<t>}]+b_h)}\n",
    "\\end{equation}\n",
    "\n",
    "\\begin{equation}\n",
    "h^{<t>}=\\Gamma_u*c^{<t>}+(1-\\Gamma_u)*h^{<t-1>}\n",
    "\\end{equation}\n",
    "    \n",
    "![GRU](GRU.png)"
   ]
  },
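  {
   "cell_type": "markdown",
   "id": "concat-notation-note",
   "metadata": {},
   "source": [
    "As a worked check of the bracket notation, assume that $[h^{<t-1>}, x^{<t>}]$ denotes the vertical stack of the two column vectors (as `np.vstack` produces in the cells above) and that $W_h$ is the horizontal concatenation of $W_{hh}$ and $W_{hx}$. Under those assumptions the single product splits into two:\n",
    "\n",
    "$$W_h[h^{<t-1>}, x^{<t>}] = \\begin{bmatrix} W_{hh} & W_{hx} \\end{bmatrix} \\begin{bmatrix} h^{<t-1>} \\\\ x^{<t>} \\end{bmatrix} = W_{hh}h^{<t-1>} + W_{hx}x^{<t>}$$\n",
    "\n",
    "With the toy values used earlier, where $W_{hh}$ is a $3\\times2$ matrix of ones, $h^{<t-1>}$ a $2\\times1$ vector of ones, and $W_{hx}$ and $x^{<t>}$ are filled with nines, each row gives $2\\cdot(1\\cdot1) + 3\\cdot(9\\cdot9) = 245$, which matches the output of `w_h1@ax_1`."
   ]
  },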
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "e9d81e45-4d3c-451d-825f-e817f5e2a684",
   "metadata": {},
   "outputs": [],
   "source": [
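    "def sigmoid(x):\n",
    "    # Logistic function, applied element-wise.\n",
    "    return 1.0 / (1.0 + np.exp(-x))\n",
    "\n",
    "def forward_VRNN(inputs, weights):\n",
    "    # One vanilla RNN step: returns (y_t, h_t).\n",
    "    x_t, h_t_prev = inputs\n",
    "    # Stack [h_{t-1}; x_t] so a single matrix product covers both terms.\n",
    "    stack = np.vstack([h_t_prev, x_t])\n",
    "    # Only the first weight/bias pair and the output pair are used here.\n",
    "    w_h, _, _, w_y, b_h, _, _, b_y = weights\n",
    "    h_t = sigmoid(w_h @ stack + b_h)\n",
    "    y_t = sigmoid(w_y @ h_t + b_y)\n",
    "    return y_t, h_t\n",
    "\n",
    "def forward_GRU(inputs, weights):\n",
    "    # One GRU step: returns (y_t, h_t).\n",
    "    x_t, h_t_prev = inputs\n",
    "    stack = np.vstack([h_t_prev, x_t])\n",
    "    w_r, w_u, w_h, w_y, b_r, b_u, b_h, b_y = weights\n",
    "    T_r = sigmoid(w_r @ stack + b_r)  # relevance (reset) gate\n",
    "    T_u = sigmoid(w_u @ stack + b_u)  # update gate\n",
    "    # Candidate state from the gated previous state and the current input.\n",
    "    c_t = np.tanh(w_h @ np.vstack((T_r * h_t_prev, x_t)) + b_h)\n",
    "    # Blend candidate and previous state according to the update gate.\n",
    "    h_t = T_u * c_t + (1 - T_u) * h_t_prev\n",
    "    y_t = sigmoid(w_y @ h_t + b_y)\n",
    "    return y_t, h_t\n",
    "\n",
    "def scan(fn, elems, weights, h_0=None):\n",
    "    # Apply fn along the sequence, threading the hidden state through each step.\n",
    "    h_t = h_0\n",
    "    ys = []\n",
    "    for x in elems:\n",
    "        y, h_t = fn([x, h_t], weights)\n",
    "        ys.append(y)\n",
    "    return np.array(ys), h_t"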
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "940369ac-90af-49ba-b9ce-5abb207dfc32",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "((16, 144), (16, 1), (256, 128, 1), (128, 1))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "random.seed(10)         # Seeds Python's RNG only; the arrays below use the trax PRNG key\n",
    "emb = 128               # Embedding size\n",
    "T = 256                 # Number of time steps in the sequence\n",
    "h_dim = 16              # Hidden state dimensions\n",
    "h_0 = np.zeros((h_dim, 1))\n",
    "\n",
    "# NOTE: every draw below reuses the same key, so with the default JAX backend\n",
    "# arrays of the same shape come out identical (w1 == w2 == w3, b1 == ... == b4).\n",
    "# Split the key first if independent values are needed.\n",
    "random_key = trax.fastmath.random.get_prng(seed=0)\n",
    "w1 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, emb + h_dim))\n",
    "w2 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, emb + h_dim))\n",
    "w3 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, emb + h_dim))\n",
    "w4 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, h_dim))\n",
    "\n",
    "b1 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
    "b2 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
    "b3 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
    "b4 = trax.fastmath.random.normal(key=random_key, shape=(h_dim, 1))\n",
    "\n",
    "# Input sequence: T column vectors of size emb.\n",
    "X = trax.fastmath.random.normal(key=random_key, shape=(T, emb, 1))\n",
    "weights = [w1, w2, w3, w4, b1, b2, b3, b4]\n",
    "w1.shape, h_0.shape, X.shape, X[0].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "4f89f820-7305-4da7-b21a-bc89be6769ba",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[0.98875064],\n",
       "              [0.90948516],\n",
       "              [0.83860266],\n",
       "              [0.605238  ],\n",
       "              [0.8174015 ],\n",
       "              [0.8074013 ],\n",
       "              [0.9864478 ],\n",
       "              [0.19836329],\n",
       "              [0.23806544],\n",
       "              [0.03616815],\n",
       "              [0.9230415 ],\n",
       "              [0.897435  ],\n",
       "              [0.81876415],\n",
       "              [0.06574202],\n",
       "              [0.14956403],\n",
       "              [0.820773  ]], dtype=float32),\n",
       " DeviceArray([[ 9.9999964e-01],\n",
       "              [ 1.0000000e+00],\n",
       "              [ 9.9392438e-01],\n",
       "              [-1.2257343e-04],\n",
       "              [-7.2951213e-02],\n",
       "              [-2.5626263e-02],\n",
       "              [-2.4690714e-09],\n",
       "              [-2.0006059e-01],\n",
       "              [ 7.9664338e-01],\n",
       "              [-4.0450820e-07],\n",
       "              [-1.1183748e-03],\n",
       "              [-1.2833535e-09],\n",
       "              [-7.8878832e-09],\n",
       "              [ 9.9999976e-01],\n",
       "              [ 9.5781153e-01],\n",
       "              [-7.0657270e-06]], dtype=float32))"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "inputs = X[0], h_0\n",
    "forward_VRNN(inputs, weights)  # result discarded; only the last expression is displayed\n",
    "forward_GRU(inputs, weights)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "3cf41b53-417a-44a0-8565-3919f4fda206",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DeviceArray([2, 4, 6], dtype=int32)"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array([1,2,3])*np.array([2,2,2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "36e06206-6707-46e5-af34-0635bf13fd45",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 2.12 s, sys: 43.6 ms, total: 2.17 s\n",
      "Wall time: 2.16 s\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[[0.9844214 ],\n",
       "               [0.86315817],\n",
       "               [0.808999  ],\n",
       "               ...,\n",
       "               [0.05173129],\n",
       "               [0.21562287],\n",
       "               [0.7522172 ]],\n",
       " \n",
       "              [[0.9830031 ],\n",
       "               [0.8508622 ],\n",
       "               [0.7086585 ],\n",
       "               ...,\n",
       "               [0.02949953],\n",
       "               [0.27303317],\n",
       "               [0.9107685 ]],\n",
       " \n",
       "              [[0.5616716 ],\n",
       "               [0.42216417],\n",
       "               [0.02124172],\n",
       "               ...,\n",
       "               [0.49089074],\n",
       "               [0.6445939 ],\n",
       "               [0.0898608 ]],\n",
       " \n",
       "              ...,\n",
       " \n",
       "              [[0.96719635],\n",
       "               [0.3858212 ],\n",
       "               [0.06641304],\n",
       "               ...,\n",
       "               [0.89807695],\n",
       "               [0.8543829 ],\n",
       "               [0.7933321 ]],\n",
       " \n",
       "              [[0.9930194 ],\n",
       "               [0.75615996],\n",
       "               [0.3213525 ],\n",
       "               ...,\n",
       "               [0.46565428],\n",
       "               [0.5535333 ],\n",
       "               [0.7318801 ]],\n",
       " \n",
       "              [[0.81508607],\n",
       "               [0.3935613 ],\n",
       "               [0.07889035],\n",
       "               ...,\n",
       "               [0.2907602 ],\n",
       "               [0.2841384 ],\n",
       "               [0.30343485]]], dtype=float32),\n",
       " DeviceArray([[9.9991798e-01],\n",
       "              [9.7653759e-01],\n",
       "              [2.6515589e-05],\n",
       "              [8.7920368e-01],\n",
       "              [1.0000000e+00],\n",
       "              [9.4612682e-01],\n",
       "              [9.9921525e-01],\n",
       "              [1.0000000e+00],\n",
       "              [1.0000000e+00],\n",
       "              [1.1614484e-02],\n",
       "              [9.9999189e-01],\n",
       "              [9.9806911e-01],\n",
       "              [2.3502709e-02],\n",
       "              [1.6757324e-08],\n",
       "              [8.7104484e-09],\n",
       "              [2.2113952e-10]], dtype=float32))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time scan(fn=forward_VRNN, elems=X, weights=weights, h_0=h_0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "c82f964e-5bfe-47e2-8492-4f2ca8c0ba59",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 386 ms, sys: 0 ns, total: 386 ms\n",
      "Wall time: 384 ms\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(DeviceArray([[[0.98875064],\n",
       "               [0.90948516],\n",
       "               [0.83860266],\n",
       "               ...,\n",
       "               [0.06574202],\n",
       "               [0.14956403],\n",
       "               [0.820773  ]],\n",
       " \n",
       "              [[0.99207675],\n",
       "               [0.8274236 ],\n",
       "               [0.7019363 ],\n",
       "               ...,\n",
       "               [0.08724681],\n",
       "               [0.2603313 ],\n",
       "               [0.9303203 ]],\n",
       " \n",
       "              [[0.96868426],\n",
       "               [0.65789413],\n",
       "               [0.04706543],\n",
       "               ...,\n",
       "               [0.02912108],\n",
       "               [0.76775664],\n",
       "               [0.1878842 ]],\n",
       " \n",
       "              ...,\n",
       " \n",
       "              [[0.99862254],\n",
       "               [0.68557423],\n",
       "               [0.68682545],\n",
       "               ...,\n",
       "               [0.01584747],\n",
       "               [0.6382734 ],\n",
       "               [0.27868038]],\n",
       " \n",
       "              [[0.9979925 ],\n",
       "               [0.67977065],\n",
       "               [0.7547641 ],\n",
       "               ...,\n",
       "               [0.01567412],\n",
       "               [0.54346865],\n",
       "               [0.24606723]],\n",
       " \n",
       "              [[0.9998665 ],\n",
       "               [0.38725543],\n",
       "               [0.9630767 ],\n",
       "               ...,\n",
       "               [0.00373541],\n",
       "               [0.6321914 ],\n",
       "               [0.9655915 ]]], dtype=float32),\n",
       " DeviceArray([[ 0.99998415],\n",
       "              [ 0.99973965],\n",
       "              [ 0.93683696],\n",
       "              [-0.44143844],\n",
       "              [ 1.        ],\n",
       "              [ 0.99857235],\n",
       "              [ 0.99990374],\n",
       "              [ 1.        ],\n",
       "              [ 1.        ],\n",
       "              [ 0.85161185],\n",
       "              [ 0.99997693],\n",
       "              [ 0.9956449 ],\n",
       "              [ 0.37345803],\n",
       "              [ 0.99998754],\n",
       "              [ 0.98146665],\n",
       "              [ 0.9999966 ]], dtype=float32))"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%time scan(fn=forward_GRU, elems=X, weights=weights, h_0=h_0)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c0d8a1bc-316f-4a14-8f97-f6b498204bfa",
   "metadata": {},
   "source": [
    "## Perplexity Calculations\n",
    "\n",
    "Perplexity measures how well a probability model predicts a sample, and it is commonly used to evaluate language models; lower values mean the model finds the text less surprising. For a sequence of $N$ words $W = w_1, \\dots, w_N$ it is defined as:\n",
    "\n",
    "$$P(W) = \\sqrt[N]{\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}}$$\n",
    "\n",
    "Here $P(W)$ on the left denotes the perplexity, while $P(w_i| w_1,...,w_{i-1})$ is the probability the model assigns to word $w_i$ given the words before it.\n",
    "\n",
    "In practice, you usually work with the logarithm of this formula: a product of many small probabilities can underflow, whereas a sum of their logarithms does not.\n",
    "\n",
    "Taking the logarithm of $P(W)$ gives:\n",
    "\n",
    "$$\\log P(W) = {\\log\\left(\\sqrt[N]{\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}}\\right)}$$\n",
    "\n",
    "\n",
    "$$ = \\log\\left(\\left(\\prod_{i=1}^{N} \\frac{1}{P(w_i| w_1,...,w_{i-1})}\\right)^{\\frac{1}{N}}\\right)$$\n",
    "\n",
    "$$ = \\log\\left(\\left({\\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\\right)^{-\\frac{1}{N}}\\right)$$\n",
    "\n",
    "$$ = -\\frac{1}{N}{\\log\\left({\\prod_{i=1}^{N}{P(w_i| w_1,...,w_{i-1})}}\\right)} $$\n",
    "\n",
    "$$ = -\\frac{1}{N}{{\\sum_{i=1}^{N}{\\log P(w_i| w_1,...,w_{i-1})}}} $$\n",
    "\n",
    "The perplexity itself is recovered by exponentiating the result: $P(W) = \\exp(\\log P(W))$.\n"
   ]
  },
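  {
   "cell_type": "markdown",
   "id": "log-perplexity-sketch",
   "metadata": {},
   "source": [
    "The log form of perplexity can be sketched as follows. This is a minimal illustration, not part of the original notebook: the helper `log_perplexity` and the probabilities in `probs` are hypothetical, and plain NumPy is imported as `onp` so that the notebook's `np` (from `trax.fastmath`) is left untouched.\n",
    "\n",
    "```python\n",
    "import numpy as onp  # plain NumPy; an assumption made to keep the sketch self-contained\n",
    "\n",
    "def log_perplexity(token_probs):\n",
    "    # -1/N * sum(log p_i), where p_i is the probability the model gave the i-th actual token.\n",
    "    token_probs = onp.asarray(token_probs, dtype=onp.float64)\n",
    "    return -onp.mean(onp.log(token_probs))\n",
    "\n",
    "# Hypothetical probabilities assigned to the four tokens of a sentence.\n",
    "probs = [0.5, 0.25, 0.125, 0.5]\n",
    "\n",
    "log_pp = log_perplexity(probs)\n",
    "perplexity = onp.exp(log_pp)  # undo the logarithm to get the perplexity itself\n",
    "\n",
    "# Direct evaluation of the N-th-root formula, for comparison only.\n",
    "direct = onp.prod(1.0 / onp.asarray(probs)) ** (1.0 / len(probs))\n",
    "\n",
    "print(log_pp, perplexity, direct)\n",
    "```\n",
    "\n",
    "For a sequence this short both routes agree; the sum of logs is preferred because the direct product shrinks toward zero as the sequence grows."
   ]
  },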
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d16f656-6d8a-4cbd-adda-e9bf5b02c855",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ea5fd9b6-b740-425a-8102-f625ec8ebfe2",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:aiking] *",
   "language": "python",
   "name": "conda-env-aiking-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}