# Recurrent Neural Networks for Language Model

## Imports

In [1]:
from trax.fastmath import numpy as np
import random
import trax

2022-07-29 06:27:08.825369: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory


## RNN Calculation Viz

In [2]:
w_hh = np.full((3,2), 1); w_hh, w_hh.shape



(DeviceArray([[1, 1],
              [1, 1],
              [1, 1]], dtype=int32, weak_type=True),
 (3, 2))

In [3]:
w_hx = np.full((3,3), 9); w_hx, w_hx.shape

(DeviceArray([[9, 9, 9],
              [9, 9, 9],
              [9, 9, 9]], dtype=int32, weak_type=True),
 (3, 3))

In [4]:
w_h1 = np.concatenate((w_hh, w_hx), axis=1); w_h1, w_h1.shape

(DeviceArray([[1, 1, 9, 9, 9],
              [1, 1, 9, 9, 9],
              [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),
 (3, 5))

In [5]:
w_h2 = np.hstack((w_hh, w_hx)); w_h2, w_h2.shape

(DeviceArray([[1, 1, 9, 9, 9],
              [1, 1, 9, 9, 9],
              [1, 1, 9, 9, 9]], dtype=int32, weak_type=True),
 (3, 5))

In [6]:
h_t_prev = np.full((2,1),1); h_t_prev, h_t_prev.shape

(DeviceArray([[1],
              [1]], dtype=int32, weak_type=True),
 (2, 1))

In [7]:
x_t = np.full((3,1), 9); x_t, x_t.shape

(DeviceArray([[9],
              [9],
              [9]], dtype=int32, weak_type=True),
 (3, 1))

In [8]:
ax_1 = np.concatenate((h_t_prev, x_t), axis=0); ax_1

DeviceArray([[1],
             [1],
             [9],
             [9],
             [9]], dtype=int32, weak_type=True)

In [9]:
ax_2 = np.vstack((h_t_prev, x_t)); ax_2,ax_2.shape

(DeviceArray([[1],
              [1],
              [9],
              [9],
              [9]], dtype=int32, weak_type=True),
 (5, 1))

In [10]:
w_h1@ax_1

DeviceArray([[245],
             [245],
             [245]], dtype=int32, weak_type=True)

In [11]:
w_hh@h_t_prev + w_hx@x_t

DeviceArray([[245],
             [245],
             [245]], dtype=int32, weak_type=True)

## Vanilla RNNs, GRUs and the scan function

*RNN*


\begin{equation}
h^{<t>}=g(W_{h}[h^{<t-1>},x^{<t>}] + b_h)
\label{eq: htRNN}
\end{equation}
    
\begin{equation}
\hat{y}^{<t>}=g(W_{yh}h^{<t>} + b_y)
\label{eq: ytRNN}
\end{equation}
    
![RNN](RNN.png)
    
*GRU*
  
\begin{equation}
\Gamma_r=\sigma{(W_r[h^{<t-1>}, x^{<t>}]+b_r)}
\end{equation}

\begin{equation}
\Gamma_u=\sigma{(W_u[h^{<t-1>}, x^{<t>}]+b_u)}
\end{equation}

\begin{equation}
c^{<t>}=\tanh{(W_h[\Gamma_r*h^{<t-1>},x^{<t>}]+b_h)}
\end{equation}

\begin{equation}
h^{<t>}=\Gamma_u*c^{<t>}+(1-\Gamma_u)*h^{<t-1>}
\end{equation}
    
![GRU](GRU.png)

In [12]:
def sigmoid(x):
    return 1.0/(1.0+np.exp(-x))

def forward_VRNN(inputs, weights):
    x_t, h_t_prev = inputs
    stack = np.vstack([h_t_prev, x_t])
    w_h, _, _,w_y, b_h, _,_,b_y = weights
    h_t = sigmoid(w_h@stack + b_h)
    # print(h_t.shape)
    y_t = sigmoid(w_y@h_t+b_y)
    return y_t, h_t

def forward_GRU(inputs, weights):
    x_t, h_t_prev = inputs
    stack = np.vstack([h_t_prev, x_t])
    w_r, w_u, w_h,w_y, b_r, b_u,b_h,b_y = weights
    T_r = sigmoid(w_r@stack + b_r)
    T_u = sigmoid(w_u@stack + b_u)
    c_t = np.tanh(w_h@np.vstack((T_r*h_t_prev, x_t)) + b_h)
    h_t = T_u*c_t+(1-T_u)*h_t_prev
    y_t = sigmoid(w_y@h_t+b_y)
    return y_t, h_t

def scan(fn, elems, weights, h_0=None):
    h_t = h_0
    ys = []
    for x in elems:
        y, h_t = fn([x, h_t], weights)
        ys.append(y)
    return np.array(ys), h_t

In [13]:
random.seed(10)
emb = 128
T = 256                 # Number of variables in sequence
h_dim = 16              # Hidden state dimensions
h_0 = np.zeros((h_dim, 1));h_0.shape

random_key = trax.fastmath.random.get_prng(seed=0)
w1 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, emb+h_dim)); w1
w2 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, emb+h_dim)); w2
w3 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, emb+h_dim)); w3

w4 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, h_dim)); w4

b1 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, 1)); b1
b2 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, 1)); b2
b3 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, 1)); b3

b4 = trax.fastmath.random.normal(key = random_key, 
                                 shape = (h_dim, 1)); b4
X = trax.fastmath.random.normal(key = random_key, 
                                 shape = (T, emb, 1)); X
weights = [w1, w2, w3, w4, b1, b2, b3, b4]; weights
w1.shape, h_0.shape, X.shape, X[0].shape

((16, 144), (16, 1), (256, 128, 1), (128, 1))

In [14]:
inputs = X[0], h_0
weights = weights
forward_VRNN(inputs, weights)
forward_GRU(inputs, weights)

(DeviceArray([[0.98875064],
              [0.90948516],
              [0.83860266],
              [0.605238  ],
              [0.8174015 ],
              [0.8074013 ],
              [0.9864478 ],
              [0.19836329],
              [0.23806544],
              [0.03616815],
              [0.9230415 ],
              [0.897435  ],
              [0.81876415],
              [0.06574202],
              [0.14956403],
              [0.820773  ]], dtype=float32),
 DeviceArray([[ 9.9999964e-01],
              [ 1.0000000e+00],
              [ 9.9392438e-01],
              [-1.2257343e-04],
              [-7.2951213e-02],
              [-2.5626263e-02],
              [-2.4690714e-09],
              [-2.0006059e-01],
              [ 7.9664338e-01],
              [-4.0450820e-07],
              [-1.1183748e-03],
              [-1.2833535e-09],
              [-7.8878832e-09],
              [ 9.9999976e-01],
              [ 9.5781153e-01],
              [-7.0657270e-06]], dtype=float32))

In [15]:
np.array([1,2,3])*np.array([2,2,2])

DeviceArray([2, 4, 6], dtype=int32)

In [16]:
%time scan(fn=forward_VRNN, elems=X, weights=weights, h_0=h_0)

CPU times: user 2.12 s, sys: 43.6 ms, total: 2.17 s
Wall time: 2.16 s


(DeviceArray([[[0.9844214 ],
               [0.86315817],
               [0.808999  ],
               ...,
               [0.05173129],
               [0.21562287],
               [0.7522172 ]],
 
              [[0.9830031 ],
               [0.8508622 ],
               [0.7086585 ],
               ...,
               [0.02949953],
               [0.27303317],
               [0.9107685 ]],
 
              [[0.5616716 ],
               [0.42216417],
               [0.02124172],
               ...,
               [0.49089074],
               [0.6445939 ],
               [0.0898608 ]],
 
              ...,
 
              [[0.96719635],
               [0.3858212 ],
               [0.06641304],
               ...,
               [0.89807695],
               [0.8543829 ],
               [0.7933321 ]],
 
              [[0.9930194 ],
               [0.75615996],
               [0.3213525 ],
               ...,
               [0.46565428],
               [0.5535333 ],
               [0.7318801 

In [17]:
%time scan(fn=forward_GRU, elems=X, weights=weights, h_0=h_0)

CPU times: user 386 ms, sys: 0 ns, total: 386 ms
Wall time: 384 ms


(DeviceArray([[[0.98875064],
               [0.90948516],
               [0.83860266],
               ...,
               [0.06574202],
               [0.14956403],
               [0.820773  ]],
 
              [[0.99207675],
               [0.8274236 ],
               [0.7019363 ],
               ...,
               [0.08724681],
               [0.2603313 ],
               [0.9303203 ]],
 
              [[0.96868426],
               [0.65789413],
               [0.04706543],
               ...,
               [0.02912108],
               [0.76775664],
               [0.1878842 ]],
 
              ...,
 
              [[0.99862254],
               [0.68557423],
               [0.68682545],
               ...,
               [0.01584747],
               [0.6382734 ],
               [0.27868038]],
 
              [[0.9979925 ],
               [0.67977065],
               [0.7547641 ],
               ...,
               [0.01567412],
               [0.54346865],
               [0.24606723

## Perplexity Calculations

The perplexity is a metric that measures how well a probability model predicts a sample and it is commonly used to evaluate language models. It is defined as: 

$$P(W) = \sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{n-1})}}$$

As an implementation hack, you would usually take the log of that formula (so the computation is less prone to underflow problems)

After taking the logarithm of $P(W)$ you have:

$$log P(W) = {\log\left(\sqrt[N]{\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{n-1})}}\right)}$$


$$ = \log\left(\left(\prod_{i=1}^{N} \frac{1}{P(w_i| w_1,...,w_{n-1})}\right)^{\frac{1}{N}}\right)$$

$$ = \log\left(\left({\prod_{i=1}^{N}{P(w_i| w_1,...,w_{n-1})}}\right)^{-\frac{1}{N}}\right)$$

$$ = -\frac{1}{N}{\log\left({\prod_{i=1}^{N}{P(w_i| w_1,...,w_{n-1})}}\right)} $$

$$ = -\frac{1}{N}{{\sum_{i=1}^{N}{\log P(w_i| w_1,...,w_{n-1})}}} $$
