{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "21847ae1-b8c4-4fa1-8b05-e5d045b0e71d",
   "metadata": {},
   "source": [
    "# Word Embeddings"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81e5c8bf-485a-4ce4-9787-942ae301a6d0",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "id": "c955e907-9fdc-4871-a34a-cd4abf93479b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np \n",
    "import scipy as sp \n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "\n",
    "import seaborn as sns\n",
    "from collections import Counter\n",
    "import nltk\n",
    "import re\n",
    "import emoji\n",
    "\n",
    "from fastcore.all import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "id": "e624b09a-b35e-4f73-8ff2-548cad7cdd69",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to\n",
      "[nltk_data]     /home/rahul.saraf/nltk_data...\n",
      "[nltk_data]   Package punkt is already up-to-date!\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Download NLTK's 'punkt' data (presumably what nltk.word_tokenize in tokenize() needs);\n",
    "# the call returns True and skips the download when the package is already up to date.\n",
    "nltk.download('punkt')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b8c160c1-4e49-4098-b5bf-ee0057fdb8da",
   "metadata": {},
   "source": [
    "## Define / Get Corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "id": "7e86c638-ae2b-4adb-bc48-c22351bd10b1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Who ❤️ \"word embeddings\" in 2020? I do!!!'"
      ]
     },
     "execution_count": 28,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Toy corpus: an emoji, a quoted phrase, a number and a run of repeated punctuation\n",
    "corpus = 'Who ❤️ \"word embeddings\" in 2020? I do!!!'\n",
    "corpus"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "1fb68e40-f18f-430f-9f6e-94b9f58e907a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# em = 'Hey 😷😷😷'\n",
    "# em_split_emoji = emoji.get_emoji_regexp().split(em)\n",
    "# em_split_emoji"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "f63b0033-2f7b-4208-844a-cf950aa47d74",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['who', '❤️', 'word', 'embeddings', 'in', '.', 'i', 'do', '.']"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def tokenize(corpus):\n",
    "    \"\"\"Split `corpus` into lower-cased tokens.\n",
    "\n",
    "    Runs of the characters , ! ? ; - are collapsed into a single '.', the text is\n",
    "    split with `nltk.word_tokenize`, and only alphabetic tokens, '.' and emoji are\n",
    "    kept (non-alphabetic tokens such as '2020' or quote marks are dropped).\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    corpus : str\n",
    "        Raw text to tokenize.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of str\n",
    "        The kept tokens, lower-cased, in their original order.\n",
    "    \"\"\"\n",
    "    data = re.sub(r'[,!?;-]+', '.', corpus)\n",
    "    tokens = nltk.word_tokenize(data)\n",
    "    # Look emoji up in a set built once per call; testing `in dict.values()` for every\n",
    "    # non-alphabetic token is a linear scan of the whole emoji table each time.\n",
    "    emojis = set(emoji.get_emoji_unicode_dict('en').values())\n",
    "    return [token.lower() for token in tokens\n",
    "            if token.isalpha() or token == '.' or token in emojis]\n",
    "\n",
    "tokenize(corpus)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "2d34931c-0fe3-4f80-beec-5e96ab1accad",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['i', 'am', 'happy', 'because', 'i', 'am', 'learning']"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check on plain text (no punctuation, numbers or emoji): every token is kept, lower-cased\n",
    "tokenize('I am happy because I am learning')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "id": "1a961543-5e47-4179-90c6-3b84bcfe9350",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(['who', '❤️', 'embeddings', 'in'], 'word')"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def sliding_window(tokens, C=2):\n",
    "    \"\"\"Yield (context_words, center_word) pairs from a window sliding over `tokens`.\n",
    "\n",
    "    Only positions with C tokens on both sides become centres, so the first and\n",
    "    last C tokens never appear as a centre word and every context has 2*C words.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    tokens : list of str\n",
    "        Tokens to scan; must support len(), slicing and + (e.g. a list).\n",
    "    C : int, default 2\n",
    "        Number of context words taken on each side of the centre word.\n",
    "\n",
    "    Yields\n",
    "    ------\n",
    "    (list of str, str)\n",
    "        The C tokens before plus the C tokens after the centre, and the centre token.\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    ValueError\n",
    "        If C is negative (raised when iteration starts, as this is a generator).\n",
    "    \"\"\"\n",
    "    if C < 0:\n",
    "        raise ValueError(f'C must be non-negative, got {C}')\n",
    "    for i in range(C, len(tokens) - C):\n",
    "        center_word = tokens[i]\n",
    "        context_words = tokens[i - C:i] + tokens[i + 1:i + C + 1]\n",
    "        yield context_words, center_word\n",
    "\n",
    "g = sliding_window(tokenize(corpus))\n",
    "next(g)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "40e48f48-c879-4c57-811d-fd0208a66e43",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['who', '❤️', 'embeddings', 'in'] word\n",
      "['❤️', 'word', 'in', '.'] embeddings\n",
      "['word', 'embeddings', '.', 'i'] in\n",
      "['embeddings', 'in', 'i', 'do'] .\n",
      "['in', '.', 'do', '.'] i\n"
     ]
    }
   ],
   "source": [
    "# One line per window: the context words, then the centre word\n",
    "for context_words, center_word in sliding_window(tokenize(corpus)):\n",
    "    print(f'{context_words} {center_word}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "id": "b4982306-c716-46fa-a3e8-df845282bb33",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>.</th>\n",
       "      <th>do</th>\n",
       "      <th>embeddings</th>\n",
       "      <th>i</th>\n",
       "      <th>in</th>\n",
       "      <th>who</th>\n",
       "      <th>word</th>\n",
       "      <th>❤️</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>1</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   .  do  embeddings  i  in  who  word  ❤️\n",
       "0  0   0           0  0   0    1     0   0\n",
       "1  0   0           0  0   0    0     0   1\n",
       "2  0   0           0  0   0    0     1   0\n",
       "3  0   0           1  0   0    0     0   0\n",
       "4  0   0           0  0   1    0     0   0\n",
       "5  1   0           0  0   0    0     0   0\n",
       "6  0   0           0  1   0    0     0   0\n",
       "7  0   1           0  0   0    0     0   0\n",
       "8  1   0           0  0   0    0     0   0"
      ]
     },
     "execution_count": 81,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokens = tokenize(corpus)\n",
    "# Index the unique vocabulary rather than token positions: a repeated token such as\n",
    "# '.' must get a single index, and each word's column must hold exactly one 1.\n",
    "vocab = sorted(set(tokens))\n",
    "ind2word = dict(enumerate(vocab))\n",
    "word2ind = {word: idx for idx, word in ind2word.items()}\n",
    "# Rows are the indices 0..V-1 from ind2word and columns are the words, so\n",
    "# one_hot[word] is that word's one-hot vector over the vocabulary.\n",
    "one_hot = pd.get_dummies(pd.Series(ind2word, name='vocab'))\n",
    "assert (one_hot.to_numpy().sum(axis=0) == 1).all(), 'every word must be one-hot'\n",
    "one_hot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 92,
   "id": "c7aa73a2-5b27-45d4-8e09-64b507de4067",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(array([0.25, 0.25, 0.  , 0.25, 0.25, 0.  , 0.  , 0.  , 0.  ]),\n",
       " array([0, 0, 1, 0, 0, 0, 0, 0, 0], dtype=uint8))"
      ]
     },
     "execution_count": 92,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "g = sliding_window(tokenize(corpus))\n",
    "context_words, center_word = next(g)\n",
    "# Mean of the context words' one-hot columns, and the centre word's own one-hot column\n",
    "one_hot.loc[:, context_words].mean(axis=1).to_numpy(), one_hot.loc[:, center_word].to_numpy()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "23e35871-0421-48c5-ae3b-7d42a3173118",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}