{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "-Jv7Y4hXwt0j"
   },
   "source": [
    "# Assignment 4:  Question duplicates\n",
    "\n",
    "Welcome to the fourth assignment of course 3. In this assignment you will explore Siamese networks applied to natural language processing. You will further explore the fundamentals of Trax and you will be able to implement a more complicated structure using it. By completing this assignment, you will learn how to implement models with different architectures. \n",
    "\n",
    "## Important Note on Submission to the AutoGrader\n",
    "\n",
    "Before submitting your assignment to the AutoGrader, please make sure you are not doing the following:\n",
    "\n",
    "1. You have not added any _extra_ `print` statement(s) in the assignment.\n",
    "2. You have not added any _extra_ code cell(s) in the assignment.\n",
    "3. You have not changed any of the function parameters.\n",
    "4. You are not using any global variables inside your graded exercises. Unless specifically instructed to do so, please refrain from it and use the local variables instead.\n",
    "5. You are not changing the assignment code where it is not required, like creating _extra_ variables.\n",
    "\n",
    "If you do any of the following, you will get something like, `Grader not found` (or similarly unexpected) error upon submitting your assignment. Before asking for help/debugging the errors in your assignment, check for these first. If this is the case, and you don't remember the changes you have made, you can get a fresh copy of the assignment by following these [instructions](https://www.coursera.org/learn/sequence-models-in-nlp/supplement/6ThZO/how-to-refresh-your-workspace).\n",
    "\n",
    "## Outline\n",
    "\n",
    "- [Overview](#0)\n",
    "- [Part 1: Importing the Data](#1)\n",
    "    - [1.1 Loading in the data](#1.1)\n",
    "    - [1.2 Converting a question to a tensor](#1.2)\n",
    "    - [1.3 Understanding the iterator](#1.3)\n",
    "        - [Exercise 01](#ex01)\n",
    "- [Part 2: Defining the Siamese model](#2)\n",
    "    - [2.1 Understanding Siamese Network](#2.1)\n",
    "        - [Exercise 02](#ex02)\n",
    "    - [2.2 Hard  Negative Mining](#2.2)\n",
    "        - [Exercise 03](#ex03)\n",
    "- [Part 3: Training](#3)\n",
    "    - [3.1 Training the model](#3.1)\n",
    "        - [Exercise 04](#ex04)\n",
    "- [Part 4: Evaluation](#4)\n",
    "    - [4.1 Evaluating your siamese network](#4.1)\n",
    "    - [4.2 Classify](#4.2)\n",
    "        - [Exercise 05](#ex05)\n",
    "- [Part 5: Testing with your own questions](#5)\n",
    "    - [Exercise 06](#ex06)\n",
    "- [On Siamese networks](#6)\n",
    "\n",
    "<a name='0'></a>\n",
    "### Overview\n",
    "In this assignment, concretely you will: \n",
    "\n",
    "- Learn about Siamese networks\n",
    "- Understand how the triplet loss works\n",
    "- Understand how to evaluate accuracy\n",
    "- Use cosine similarity between the model's outputted vectors\n",
    "- Use the data generator to get batches of questions\n",
    "- Predict using your own model\n",
    "\n",
    "By now, you are familiar with trax and know how to make use of classes to define your model. We will start this homework by asking you to preprocess the data the same way you did in the previous assignments. After processing the data you will build a classifier that will allow you to identify whether two questions are the same or not. \n",
    "<img src = \"images/meme.png\" style=\"width:550px;height:300px;\"/>\n",
    "\n",
    "\n",
    "You will process the data first and then pad in a similar way you have done in the previous assignment. Your model will take in the two question embeddings, run them through an LSTM, and then compare the outputs of the two sub networks using cosine similarity. Before taking a deep dive into the model, start by importing the data set.\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4sF9Hqzgwt0l"
   },
   "source": [
    "<a name='1'></a>\n",
    "# Part 1: Importing the Data\n",
    "<a name='1.1'></a>\n",
    "### 1.1 Loading in the data\n",
    "\n",
    "You will be using the Quora question answer dataset to build a model that could identify similar questions. This is a useful task because you don't want to have several versions of the same question posted. Several times when teaching I end up responding to similar questions on piazza, or on other community forums. This data set has been labeled for you. Run the cell below to import some of the packages you will be using. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "zdACgs491cs2",
    "outputId": "b31042ef-845b-46b8-c783-185e96b135f7"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "[nltk_data] Downloading package punkt to /home/jovyan/nltk_data...\n",
      "[nltk_data]   Unzipping tokenizers/punkt.zip.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import nltk\n",
    "import trax\n",
    "from trax import layers as tl\n",
    "from trax.supervised import training\n",
    "from trax.fastmath import numpy as fastnp\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import random as rnd\n",
    "\n",
    "import w4_unittest\n",
    "\n",
    "nltk.download('punkt')\n",
    "\n",
    "# set random seeds\n",
    "rnd.seed(34)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "3GYhQRMspitx"
   },
   "source": [
    "**Notice that for this assignment Trax's numpy is referred to as `fastnp`, while regular numpy is referred to as `np`.**\n",
    "\n",
    "You will now load in the data set. We have done some preprocessing for you. If you have taken the deeplearning specialization, this is a slightly different training method than the one you have seen there. If you have not, then don't worry about it, we will explain everything. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 528
    },
    "colab_type": "code",
    "id": "sXWBVGWnpity",
    "outputId": "afa90d4d-fed7-43b8-bcba-48c95d600ad5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of question pairs:  404351\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>id</th>\n",
       "      <th>qid1</th>\n",
       "      <th>qid2</th>\n",
       "      <th>question1</th>\n",
       "      <th>question2</th>\n",
       "      <th>is_duplicate</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "      <td>2</td>\n",
       "      <td>What is the step by step guide to invest in sh...</td>\n",
       "      <td>What is the step by step guide to invest in sh...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>1</td>\n",
       "      <td>3</td>\n",
       "      <td>4</td>\n",
       "      <td>What is the story of Kohinoor (Koh-i-Noor) Dia...</td>\n",
       "      <td>What would happen if the Indian government sto...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>2</td>\n",
       "      <td>5</td>\n",
       "      <td>6</td>\n",
       "      <td>How can I increase the speed of my internet co...</td>\n",
       "      <td>How can Internet speed be increased by hacking...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>3</td>\n",
       "      <td>7</td>\n",
       "      <td>8</td>\n",
       "      <td>Why am I mentally very lonely? How can I solve...</td>\n",
       "      <td>Find the remainder when [math]23^{24}[/math] i...</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>4</td>\n",
       "      <td>9</td>\n",
       "      <td>10</td>\n",
       "      <td>Which one dissolve in water quikly sugar, salt...</td>\n",
       "      <td>Which fish would survive in salt water?</td>\n",
       "      <td>0</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   id  qid1  qid2                                          question1  \\\n",
       "0   0     1     2  What is the step by step guide to invest in sh...   \n",
       "1   1     3     4  What is the story of Kohinoor (Koh-i-Noor) Dia...   \n",
       "2   2     5     6  How can I increase the speed of my internet co...   \n",
       "3   3     7     8  Why am I mentally very lonely? How can I solve...   \n",
       "4   4     9    10  Which one dissolve in water quikly sugar, salt...   \n",
       "\n",
       "                                           question2  is_duplicate  \n",
       "0  What is the step by step guide to invest in sh...             0  \n",
       "1  What would happen if the Indian government sto...             0  \n",
       "2  How can Internet speed be increased by hacking...             0  \n",
       "3  Find the remainder when [math]23^{24}[/math] i...             0  \n",
       "4            Which fish would survive in salt water?             0  "
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data = pd.read_csv(\"data/questions.csv\")\n",
    "N=len(data)\n",
    "print('Number of question pairs: ', N)\n",
    "data.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "gkSQTu7Ypit0"
   },
   "source": [
    "We first split the data into a train and test set. The test set will be used later to evaluate our model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "z00A7vEMpit1",
    "outputId": "c12ae7e8-a959-4f56-aa29-6ad34abc1c81"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set: 300000 Test set: 10240\n"
     ]
    }
   ],
   "source": [
    "N_train = 300000\n",
    "N_test  = 10*1024\n",
    "data_train = data[:N_train]\n",
    "data_test  = data[N_train:N_train+N_test]\n",
    "print(\"Train set:\", len(data_train), \"Test set:\", len(data_test))\n",
    "del(data) # remove to free memory"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FbqIRRyEpit4"
   },
   "source": [
    "As explained in the lectures, we select only the question pairs that are duplicate to train the model. <br>\n",
    "We build two batches as input for the Siamese network and we assume that question $q1_i$ (question $i$ in the first batch) is a duplicate of $q2_i$ (question $i$ in the second batch), but all other questions in the second batch are not duplicates of $q1_i$.  \n",
    "The test set uses the original pairs of questions and the status describing if the questions are duplicates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "Xi_TwXxxpit4",
    "outputId": "f146046f-9c0d-4d8a-ecf8-8d6a4a5371f7"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "number of duplicate questions:  111486\n",
      "indexes of first ten duplicate questions: [5, 7, 11, 12, 13, 15, 16, 18, 20, 29]\n"
     ]
    }
   ],
   "source": [
    "td_index = (data_train['is_duplicate'] == 1).to_numpy()\n",
    "td_index = [i for i, x in enumerate(td_index) if x] \n",
    "print('number of duplicate questions: ', len(td_index))\n",
    "print('indexes of first ten duplicate questions:', td_index[:10])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "3I9oXSsKpit7",
    "outputId": "6f6bd3a1-219f-4fb3-a524-450c38bf44ba"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Astrology: I am a Capricorn Sun Cap moon and cap rising...what does that say about me?\n",
      "I'm a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me?\n",
      "is_duplicate:  1\n"
     ]
    }
   ],
   "source": [
    "print(data_train['question1'][5])  #  Example of question duplicates (first one in data)\n",
    "print(data_train['question2'][5])\n",
    "print('is_duplicate: ', data_train['is_duplicate'][5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "XHpZO58Dss_v"
   },
   "outputs": [],
   "source": [
    "Q1_train_words = np.array(data_train['question1'][td_index])\n",
    "Q2_train_words = np.array(data_train['question2'][td_index])\n",
    "\n",
    "Q1_test_words = np.array(data_test['question1'])\n",
    "Q2_test_words = np.array(data_test['question2'])\n",
    "y_test  = np.array(data_test['is_duplicate'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "P5vBkxunpiuB"
   },
   "source": [
    "Above, you have seen that you only took the duplicated questions for training our model. <br>You did so on purpose, because the data generator will produce batches $([q1_1, q1_2, q1_3, ...]$, $[q2_1, q2_2,q2_3, ...])$  where $q1_i$ and $q2_k$ are duplicate if and only if $i = k$.\n",
    "\n",
    "<br>Let's print to see what your data looks like."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 170
    },
    "colab_type": "code",
    "id": "joyrS1XEpLWn",
    "outputId": "3257cde7-3164-40d9-910e-fa91eae917a0"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TRAINING QUESTIONS:\n",
      "\n",
      "Question 1:  Astrology: I am a Capricorn Sun Cap moon and cap rising...what does that say about me?\n",
      "Question 2:  I'm a triple Capricorn (Sun, Moon and ascendant in Capricorn) What does this say about me? \n",
      "\n",
      "Question 1:  What would a Trump presidency mean for current international master’s students on an F1 visa?\n",
      "Question 2:  How will a Trump presidency affect the students presently in US or planning to study in US? \n",
      "\n",
      "TESTING QUESTIONS:\n",
      "\n",
      "Question 1:  How do I prepare for interviews for cse?\n",
      "Question 2:  What is the best way to prepare for cse? \n",
      "\n",
      "is_duplicate = 0 \n",
      "\n"
     ]
    }
   ],
   "source": [
    "print('TRAINING QUESTIONS:\\n')\n",
    "print('Question 1: ', Q1_train_words[0])\n",
    "print('Question 2: ', Q2_train_words[0], '\\n')\n",
    "print('Question 1: ', Q1_train_words[5])\n",
    "print('Question 2: ', Q2_train_words[5], '\\n')\n",
    "\n",
    "print('TESTING QUESTIONS:\\n')\n",
    "print('Question 1: ', Q1_test_words[0])\n",
    "print('Question 2: ', Q2_test_words[0], '\\n')\n",
    "print('is_duplicate =', y_test[0], '\\n')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "WC_BZU3XpiuF"
   },
   "source": [
    "You will now encode each word of the selected duplicate pairs with an index. <br> Given a question, you can then just encode it as a list of numbers.  \n",
    "\n",
    "First you tokenize the questions using `nltk.word_tokenize`. <br>\n",
    "You need a python default dictionary which later, during inference, assigns the values $0$ to all Out Of Vocabulary (OOV) words.<br>\n",
    "Then you encode each word of the selected duplicate pairs with an index. Given a question, you can then just encode it as a list of numbers. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "QbCoIgLQpiuF"
   },
   "outputs": [],
   "source": [
    "#create arrays\n",
    "Q1_train = np.empty_like(Q1_train_words)\n",
    "Q2_train = np.empty_like(Q2_train_words)\n",
    "\n",
    "Q1_test = np.empty_like(Q1_test_words)\n",
    "Q2_test = np.empty_like(Q2_test_words)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "m9ZmfpGWpiuI",
    "outputId": "d2995c9a-92b4-4892-d34b-c77b94b27134"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The length of the vocabulary is:  36268\n"
     ]
    }
   ],
   "source": [
    "# Building the vocabulary with the train set         (this might take a minute)\n",
    "from collections import defaultdict\n",
    "\n",
    "vocab = defaultdict(lambda: 0)\n",
    "vocab['<PAD>'] = 1\n",
    "\n",
    "for idx in range(len(Q1_train_words)):\n",
    "    Q1_train[idx] = nltk.word_tokenize(Q1_train_words[idx])\n",
    "    Q2_train[idx] = nltk.word_tokenize(Q2_train_words[idx])\n",
    "    q = Q1_train[idx] + Q2_train[idx]\n",
    "    for word in q:\n",
    "        if word not in vocab:\n",
    "            vocab[word] = len(vocab) + 1\n",
    "print('The length of the vocabulary is: ', len(vocab))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "TTMRF8eZpiuK",
    "outputId": "f81d4dc1-7cf9-4476-a454-467b54fe4dc4"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1\n",
      "2\n",
      "0\n"
     ]
    }
   ],
   "source": [
    "print(vocab['<PAD>'])\n",
    "print(vocab['Astrology'])\n",
    "print(vocab['Astronomy'])  #not in vocabulary, returns 0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "5sDs36m81g6f"
   },
   "outputs": [],
   "source": [
    "for idx in range(len(Q1_test_words)): \n",
    "    Q1_test[idx] = nltk.word_tokenize(Q1_test_words[idx])\n",
    "    Q2_test[idx] = nltk.word_tokenize(Q2_test_words[idx])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "3QgGE9KlpiuP",
    "outputId": "19c3cf93-cf0d-4f8f-da99-e75481f16599"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Train set has reduced to:  111486\n",
      "Test set length:  10240\n"
     ]
    }
   ],
   "source": [
    "print('Train set has reduced to: ', len(Q1_train) ) \n",
    "print('Test set length: ', len(Q1_test) ) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "BDcxEmX31y3d"
   },
   "source": [
    "<a name='1.2'></a>\n",
    "### 1.2 Converting a question to a tensor\n",
    "\n",
    "You will now convert every question to a tensor, or an array of numbers, using your vocabulary built above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "zOhNa-sapiuS"
   },
   "outputs": [],
   "source": [
    "# Converting questions to array of integers\n",
    "for i in range(len(Q1_train)):\n",
    "    Q1_train[i] = [vocab[word] for word in Q1_train[i]]\n",
    "    Q2_train[i] = [vocab[word] for word in Q2_train[i]]\n",
    "\n",
    "        \n",
    "for i in range(len(Q1_test)):\n",
    "    Q1_test[i] = [vocab[word] for word in Q1_test[i]]\n",
    "    Q2_test[i] = [vocab[word] for word in Q2_test[i]]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "Dpawm38dpiuU",
    "outputId": "ef1aa65b-c89b-46f9-a9cf-f73748f1ee56"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "first question in the train set:\n",
      "\n",
      "Astrology: I am a Capricorn Sun Cap moon and cap rising...what does that say about me? \n",
      "\n",
      "encoded version:\n",
      "[2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21] \n",
      "\n",
      "first question in the test set:\n",
      "\n",
      "How do I prepare for interviews for cse? \n",
      "\n",
      "encoded version:\n",
      "[32, 38, 4, 107, 65, 1015, 65, 11509, 21]\n"
     ]
    }
   ],
   "source": [
    "print('first question in the train set:\\n')\n",
    "print(Q1_train_words[0], '\\n') \n",
    "print('encoded version:')\n",
    "print(Q1_train[0],'\\n')\n",
    "\n",
    "print('first question in the test set:\\n')\n",
    "print(Q1_test_words[0], '\\n')\n",
    "print('encoded version:')\n",
    "print(Q1_test[0]) "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "SuggGPaQpiuY"
   },
   "source": [
    "You will now split your train set into a training/validation set so that you can use it to train and evaluate your Siamese model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 68
    },
    "colab_type": "code",
    "id": "BmhrWPtgpiuY",
    "outputId": "7272fb74-79e6-499a-ce95-d11b9edcd64a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Number of duplicate questions:  111486\n",
      "The length of the training set is:   89188\n",
      "The length of the validation set is:  22298\n"
     ]
    }
   ],
   "source": [
    "# Splitting the data\n",
    "cut_off = int(len(Q1_train)*.8)\n",
    "train_Q1, train_Q2 = Q1_train[:cut_off], Q2_train[:cut_off]\n",
    "val_Q1, val_Q2 = Q1_train[cut_off: ], Q2_train[cut_off:]\n",
    "print('Number of duplicate questions: ', len(Q1_train))\n",
    "print(\"The length of the training set is:  \", len(train_Q1))\n",
    "print(\"The length of the validation set is: \", len(val_Q1))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "iFOR19cX2TQs"
   },
   "source": [
    "<a name='1.3'></a>\n",
    "### 1.3 Understanding the iterator \n",
    "\n",
    "Most of the time in Natural Language Processing, and AI in general we use batches when training our data sets. If you were to use stochastic gradient descent with one example at a time, it will take you forever to build a model. In this example, we show you how you can build a data generator that takes in $Q1$ and $Q2$ and returns a batch of size `batch_size`  in the following format $([q1_1, q1_2, q1_3, ...]$, $[q2_1, q2_2,q2_3, ...])$. The tuple consists of two arrays and each array has `batch_size` questions. Again, $q1_i$ and $q2_i$ are duplicates, but they are not duplicates with any other elements in the batch. \n",
    "\n",
    "<br>\n",
    "\n",
    "The command ```next(data_generator)```returns the next batch. This iterator returns the data in a format that you could directly use in your model when computing the feed-forward of your algorithm. This iterator returns a pair of arrays of questions. \n",
    "\n",
    "<a name='ex01'></a>\n",
    "### Exercise 01\n",
    "\n",
    "**Instructions:**  \n",
    "Implement the data generator below. Here are some things you will need. \n",
    "\n",
    "- While true loop.\n",
    "- if `index >= len_Q1`, set the `idx` to $0$.\n",
    "- The generator should return shuffled batches of data. To achieve this without modifying the actual question lists, a list containing the indexes of the questions is created. This list can be shuffled and used to get random batches everytime the index is reset.\n",
    "- Append elements of $Q1$ and $Q2$ to `input1` and `input2` respectively.\n",
    "- if `len(input1) == batch_size`, determine `max_len` as the longest question in `input1` and `input2`. Ceil `max_len` to a power of $2$ (for computation purposes) using the following command:  `max_len = 2**int(np.ceil(np.log2(max_len)))`.\n",
    "- Pad every question by `vocab['<PAD>']` until you get the length `max_len`.\n",
    "- Use yield to return `input1, input2`. \n",
    "- Don't forget to reset `input1, input2`  to empty arrays at the end (data generator resumes from where it last left)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "ibchgos48MtA"
   },
   "outputs": [],
   "source": [
    "# UNQ_C1 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "# GRADED FUNCTION: data_generator\n",
    "def data_generator(Q1, Q2, batch_size, pad=1, shuffle=True):\n",
    "    \"\"\"Generator function that yields batches of data\n",
    "\n",
    "    Args:\n",
    "        Q1 (list): List of transformed (to tensor) questions.\n",
    "        Q2 (list): List of transformed (to tensor) questions.\n",
    "        batch_size (int): Number of elements per batch.\n",
    "        pad (int, optional): Pad character from the vocab. Defaults to 1.\n",
    "        shuffle (bool, optional): If the batches should be randomnized or not. Defaults to True.\n",
    "    Yields:\n",
    "        tuple: Of the form (input1, input2) with types (numpy.ndarray, numpy.ndarray)\n",
    "        NOTE: input1: inputs to your model [q1a, q2a, q3a, ...] i.e. (q1a,q1b) are duplicates\n",
    "              input2: targets to your model [q1b, q2b,q3b, ...] i.e. (q1a,q2i) i!=a are not duplicates\n",
    "    \"\"\"\n",
    "\n",
    "    input1 = []\n",
    "    input2 = []\n",
    "    idx = 0\n",
    "    len_q = len(Q1)\n",
    "    question_indexes = [*range(len_q)]\n",
    "    \n",
    "    if shuffle:\n",
    "        rnd.shuffle(question_indexes)\n",
    "    \n",
    "    ### START CODE HERE (Replace instances of 'None' with your code) ###\n",
    "    while True:\n",
    "        if idx >= len_q:\n",
    "            # if idx is greater than or equal to len_q, set idx accordingly \n",
    "            # (Hint: look at the instructions above)\n",
    "            idx = 0\n",
    "            # shuffle to get random batches if shuffle is set to True\n",
    "            if shuffle:\n",
    "                rnd.shuffle(question_indexes) \n",
    "        \n",
    "        # get questions at the `question_indexes[idx]` position in Q1 and Q2\n",
    "        q1 = Q1[question_indexes[idx]]\n",
    "        q2 = Q2[question_indexes[idx]]\n",
    "        \n",
    "        # increment idx by 1\n",
    "        idx += 1\n",
    "        # append q1\n",
    "        input1.append(q1)\n",
    "        # append q2\n",
    "        input2.append(q2)\n",
    "        if len(input1) == batch_size:\n",
    "            # determine max_len as the longest question in input1 & input 2\n",
    "            # Hint: use the `max` function. \n",
    "            # take max of input1 & input2 and then max out of the two of them.\n",
    "            max_len = max(max(len(i1), len(i2)) for i1, i2 in zip(input1, input2))\n",
    "            # pad to power-of-2 (Hint: look at the instructions above)\n",
    "            max_len = 2**int(np.ceil(np.log2(max_len)))\n",
    "            b1 = [] \n",
    "            b2 = [] \n",
    "            for q1, q2 in zip(input1, input2):\n",
    "                # add [pad] to q1 until it reaches max_len\n",
    "                q1 = q1+[pad for i in range(max_len - len(q1))]\n",
    "                # add [pad] to q2 until it reaches max_len\n",
    "                q2 = q2+[pad for i in range(max_len - len(q2))]                \n",
    "                # append q1\n",
    "                b1.append(q1)\n",
    "                # append q2\n",
    "                b2.append(q2)\n",
    "            # use b1 and b2\n",
    "            yield np.array(b1), np.array(b2)\n",
    "    ### END CODE HERE ###\n",
    "            # reset the batches\n",
    "            input1, input2 = [], []  # reset the batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204
    },
    "colab_type": "code",
    "id": "ZFZeBPnW8Mlb",
    "outputId": "7a31cd19-55dc-4b97-f288-6c59c6a34b53"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "First questions  :  \n",
      " [[  30   87   78  134 2132 1981   28   78  594   21    1    1    1    1\n",
      "     1    1]\n",
      " [  30   55   78 3541 1460   28   56  253   21    1    1    1    1    1\n",
      "     1    1]] \n",
      "\n",
      "Second questions :  \n",
      " [[  30  156   78  134 2132 9508   21    1    1    1    1    1    1    1\n",
      "     1    1]\n",
      " [  30  156   78 3541 1460  131   56  253   21    1    1    1    1    1\n",
      "     1    1]]\n"
     ]
    }
   ],
   "source": [
    "batch_size = 2\n",
    "res1, res2 = next(data_generator(train_Q1, train_Q2, batch_size))\n",
    "print(\"First questions  : \",'\\n', res1, '\\n')\n",
    "print(\"Second questions : \",'\\n', res2)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "tWJ1L9m2piui"
   },
   "source": [
    "**Note**: The following expected output is valid only if you run the above test cell **_once_** (first time). The output will change on each execution.\n",
    "\n",
    "If you think your implementation is correct and it is not matching the output, make sure to restart the kernel and run all the cells from the top again. \n",
    "\n",
    "**Expected Output:**\n",
    "```CPP\n",
    "First questions  :  \n",
    " [[  30   87   78  134 2132 1981   28   78  594   21    1    1    1    1\n",
    "     1    1]\n",
    " [  30   55   78 3541 1460   28   56  253   21    1    1    1    1    1\n",
    "     1    1]] \n",
    "\n",
    "Second questions :  \n",
    " [[  30  156   78  134 2132 9508   21    1    1    1    1    1    1    1\n",
    "     1    1]\n",
    " [  30  156   78 3541 1460  131   56  253   21    1    1    1    1    1\n",
    "     1    1]]\n",
    "```\n",
    "Now that you have your generator, you can just call it and it will return tensors which correspond to your questions in the Quora data set.<br>Now you can go ahead and start building your neural network. \n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m All tests passed\n"
     ]
    }
   ],
   "source": [
    "# Test your function\n",
    "w4_unittest.test_data_generator(data_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KmZRBoaMwt0w"
   },
   "source": [
    "<a name='2'></a>\n",
    "# Part 2: Defining the Siamese model\n",
    "\n",
    "<a name='2.1'></a>\n",
    "\n",
    "### 2.1 Understanding Siamese Network \n",
    "A Siamese network is a neural network which uses the same weights while working in tandem on two different input vectors to compute comparable output vectors.The Siamese network you are about to implement looks like this:\n",
    "\n",
    "<img src = \"images/siamese.png\" style=\"width:600px;height:300px;\"/>\n",
    "\n",
    "You get the question embedding, run it through an LSTM layer, normalize $v_1$ and $v_2$, and finally use a triplet loss (explained below) to get the corresponding cosine similarity for each pair of questions. As usual, you will start by importing the data set. The triplet loss makes use of a baseline (anchor) input that is compared to a positive (truthy) input and a negative (falsy) input. The distance from the baseline (anchor) input to the positive (truthy) input is minimized, and the distance from the baseline (anchor) input to the negative (falsy) input is maximized. In math equations, you are trying to maximize the following.\n",
    "\n",
    "$$\\mathcal{L}(A, P, N)=\\max \\left(\\|\\mathrm{f}(A)-\\mathrm{f}(P)\\|^{2}-\\|\\mathrm{f}(A)-\\mathrm{f}(N)\\|^{2}+\\alpha, 0\\right)$$\n",
    "\n",
    "$A$ is the anchor input, for example $q1_1$, $P$ the duplicate input, for example, $q2_1$, and $N$ the negative input (the non duplicate question), for example $q2_2$.<br>\n",
    "$\\alpha$ is a margin; you can think about it as a safety net, or by how much you want to push the duplicates from the non duplicates. \n",
    "<br>\n",
    "\n",
    "<a name='ex02'></a>\n",
    "### Exercise 02\n",
    "\n",
    "**Instructions:** Implement the `Siamese` function below. You should be using all the objects explained below. \n",
    "\n",
    "To implement this model, you will be using `trax`. Concretely, you will be using the following functions.\n",
    "\n",
    "\n",
    "- `tl.Serial`: Combinator that applies layers serially (by function composition) allows you set up the overall structure of the feedforward. [docs](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Serial) / [source code](https://github.com/google/trax/blob/1372b903bb66b0daccee19fd0b1fdf44f659330b/trax/layers/combinators.py#L26)\n",
    "    - You can pass in the layers as arguments to `Serial`, separated by commas. \n",
    "    - For example: `tl.Serial(tl.Embeddings(...), tl.Mean(...), tl.Dense(...), tl.LogSoftmax(...))` \n",
    "\n",
    "\n",
    "-  `tl.Embedding`: Maps discrete tokens to vectors. It will have shape (vocabulary length X dimension of output vectors). The dimension of output vectors (also called d_feature) is the number of elements in the word embedding. [docs](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Embedding) / [source code](https://github.com/google/trax/blob/1372b903bb66b0daccee19fd0b1fdf44f659330b/trax/layers/core.py#L113)\n",
    "    - `tl.Embedding(vocab_size, d_feature)`.\n",
    "    - `vocab_size` is the number of unique words in the given vocabulary.\n",
    "    - `d_feature` is the number of elements in the word embedding (some choices for a word embedding size range from 150 to 300, for example).\n",
    "\n",
    "\n",
    "-  `tl.LSTM` The LSTM layer. It leverages another Trax layer called [`LSTMCell`](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.rnn.LSTMCell). The number of units should be specified and should match the number of elements in the word embedding. [docs](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.rnn.LSTM) / [source code](https://github.com/google/trax/blob/1372b903bb66b0daccee19fd0b1fdf44f659330b/trax/layers/rnn.py#L87)\n",
    "    - `tl.LSTM(n_units)` Builds an LSTM layer of n_units.\n",
    "    \n",
    "    \n",
    "- `tl.Mean`: Computes the mean across a desired axis. Mean uses one tensor axis to form groups of values and replaces each group with the mean value of that group. [docs](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.core.Mean) / [source code](https://github.com/google/trax/blob/1372b903bb66b0daccee19fd0b1fdf44f659330b/trax/layers/core.py#L276)\n",
    "    - `tl.Mean(axis=1)` mean over columns.\n",
    "\n",
    "\n",
    "- `tl.Fn` Layer with no weights that applies the function f, which should be specified using a lambda syntax. [docs](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.base.Fn) / [source doce](https://github.com/google/trax/blob/70f5364dcaf6ec11aabbd918e5f5e4b0f5bfb995/trax/layers/base.py#L576)\n",
    "    - $x$ -> This is used for cosine similarity.\n",
    "    - `tl.Fn('Normalize', lambda x: normalize(x))` Returns a layer with no weights that applies the function `f`\n",
    "    \n",
    "    \n",
    "- `tl.parallel`: It is a combinator layer (like `Serial`) that applies a list of layers in parallel to its inputs. [docs](https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.combinators.Parallel) / [source code](https://github.com/google/trax/blob/37aba571a89a8ad86be76a569d0ec4a46bdd8642/trax/layers/combinators.py#L152)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "hww76f8_wt0x"
   },
   "outputs": [],
   "source": [
    "# UNQ_C2 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "# GRADED FUNCTION: Siamese\n",
    "def Siamese(vocab_size=41699, d_model=128, mode='train'):\n",
    "    \"\"\"Returns a Siamese model.\n",
    "\n",
    "    Args:\n",
    "        vocab_size (int, optional): Length of the vocabulary. Defaults to len(vocab).\n",
    "        d_model (int, optional): Depth of the model. Defaults to 128.\n",
    "        mode (str, optional): 'train', 'eval' or 'predict', predict mode is for fast inference. Defaults to 'train'.\n",
    "\n",
    "    Returns:\n",
    "        trax.layers.combinators.Parallel: A Siamese model. \n",
    "    \"\"\"\n",
    "\n",
    "    def normalize(x):  # normalizes the vectors to have L2 norm 1\n",
    "        return x / fastnp.sqrt(fastnp.sum(x * x, axis=-1, keepdims=True))\n",
    "    \n",
    "    ### START CODE HERE (Replace instances of 'None' with your code) ###\n",
    "    q_processor = tl.Serial( # Processor will run on Q1 and Q2. \n",
    "        tl.Embedding(vocab_size=vocab_size, d_feature=d_model), # Embedding layer\n",
    "        tl.LSTM(n_units=d_model), # LSTM layer\n",
    "        tl.Mean(axis=1), # Mean over columns\n",
    "        tl.Fn('Normalize', lambda x: normalize(x)), # Apply normalize function\n",
    "    )  # Returns one vector of shape [batch_size, d_model]. \n",
    "    \n",
    "    ### END CODE HERE ###\n",
    "    \n",
    "    # Run on Q1 and Q2 in parallel.\n",
    "    model = tl.Parallel(q_processor, q_processor)\n",
    "    return model\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "es2gfwZypiul"
   },
   "source": [
    "Setup the Siamese network model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 255
    },
    "colab_type": "code",
    "id": "kvQ_jf52-JAn",
    "outputId": "d409460d-2ffb-4ae6-8745-ddcfa1d892ad"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Parallel_in2_out2[\n",
      "  Serial[\n",
      "    Embedding_41699_128\n",
      "    LSTM_128\n",
      "    Mean\n",
      "    Normalize\n",
      "  ]\n",
      "  Serial[\n",
      "    Embedding_41699_128\n",
      "    LSTM_128\n",
      "    Mean\n",
      "    Normalize\n",
      "  ]\n",
      "]\n"
     ]
    }
   ],
   "source": [
    "# check your model\n",
    "model = Siamese()\n",
    "print(model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "LMK9zqhHpiuo"
   },
   "source": [
    "**Expected output:**  \n",
    "\n",
    "```CPP\n",
    "Parallel_in2_out2[\n",
    "  Serial[\n",
    "    Embedding_41699_128\n",
    "    LSTM_128\n",
    "    Mean\n",
    "    Normalize\n",
    "  ]\n",
    "  Serial[\n",
    "    Embedding_41699_128\n",
    "    LSTM_128\n",
    "    Mean\n",
    "    Normalize\n",
    "  ]\n",
    "]\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m All tests passed\n"
     ]
    }
   ],
   "source": [
    "# Test your function\n",
    "w4_unittest.test_Siamese(Siamese)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "KVo1Gvripiuo"
   },
   "source": [
    "<a name='2.2'></a>\n",
    "\n",
    "### 2.2 Hard  Negative Mining\n",
    "\n",
    "\n",
    "You will now implement the `TripletLoss`.<br>\n",
    "As explained in the lecture, loss is composed of two terms. One term utilizes the mean of all the non duplicates, the second utilizes the *closest negative*. Our loss expression is then:\n",
    " \n",
    "\\begin{align}\n",
    " \\mathcal{Loss_{1}(A,P,N)} &=\\max \\left( -cos(A,P)  + mean_{neg} +\\alpha, 0\\right) \\\\\n",
    " \\mathcal{Loss_{2}(A,P,N)} &=\\max \\left( -cos(A,P)  + closest_{neg} +\\alpha, 0\\right) \\\\\n",
    "\\mathcal{Loss(A,P,N)} &= mean(Loss_1 + Loss_2) \\\\\n",
    "\\end{align}\n",
    "\n",
    "\n",
    "Further, two sets of instructions are provided. The first set provides a brief description of the task. If that set proves insufficient, a more detailed set can be displayed.  \n",
    "\n",
    "<a name='ex03'></a>\n",
    "### Exercise 03\n",
    "\n",
    "**Instructions (Brief):** Here is a list of things you should do: <br>\n",
    "\n",
    "- As this will be run inside trax, use `fastnp.xyz` when using any `xyz` numpy function\n",
    "- Use `fastnp.dot` to calculate the similarity matrix $v_1v_2^T$ of dimension `batch_size` x `batch_size`\n",
    "- Take the score of the duplicates on the diagonal `fastnp.diagonal`\n",
    "- Use the `trax` functions `fastnp.eye` and `fastnp.maximum` for the identity matrix and the maximum."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "GWsX-Wz3piup"
   },
   "source": [
    "<details>  \n",
    "<summary>\n",
    "    <font size=\"3\" color=\"darkgreen\"><b>More Detailed Instructions </b></font>\n",
    "</summary>\n",
    "We'll describe the algorithm using a detailed example. Below, V1, V2 are the output of the normalization blocks in our model. Here we will use a batch_size of 4 and a d_model of 3. As explained in lecture, the inputs, Q1, Q2 are arranged so that corresponding inputs are duplicates while non-corresponding entries are not. The outputs will have the same pattern.\n",
    "<img src = \"images/C3_W4_triploss1.png\" style=\"width:1021px;height:229px;\"/>\n",
    "This testcase arranges the outputs, v1,v2, to highlight different scenarios. Here, the first two outputs V1[0], V2[0] match exactly - so the model is generating the same vector for Q1[0] and Q2[0] inputs. The second outputs differ, circled in orange, we set, V2[1] is set to match V2[2], simulating a model which is generating very poor results. V1[3] and V2[3] match exactly again while V1[4] and V2[4] are set to be exactly wrong - 180 degrees from each other, circled in blue. \n",
    "\n",
    "The first step is to compute the cosine similarity matrix or `score` in the code. As explained in lecture, this is $$V_1 V_2^T$$ This is generated with `fastnp.dot`.\n",
    "<img src = \"images/C3_W4_triploss2.png\" style=\"width:959px;height:236px;\"/>\n",
    "The clever arrangement of inputs creates the data needed for positive *and* negative examples without having to run all pair-wise combinations. Because Q1[n] is a duplicate of only Q2[n], other combinations are explicitly created negative examples or *Hard Negative* examples. The matrix multiplication efficiently produces the cosine similarity of all positive/negative combinations as shown above on the left side of the diagram. 'Positive' are the results of duplicate examples and 'negative' are the results of explicitly created negative examples. The results for our test case are as expected, V1[0]V2[0] match producing '1' while our other 'positive' cases (in green) don't match well, as was arranged. The V2[2] was set to match V1[3] producing a poor match at `score[2,2]` and an undesired 'negative' case of a '1' shown in grey. \n",
    "\n",
    "With the similarity matrix (`score`) we can begin to implement the loss equations. First, we can extract $$cos(A,P)$$ by utilizing `fastnp.diagonal`. The goal is to grab all the green entries in the diagram above. This is `positive` in the code.\n",
    "\n",
    "Next, we will create the *closest_negative*. This is the nonduplicate entry in V2 that is closest (has largest cosine similarity) to an entry in V1. Each row, n, of `score` represents all comparisons of the results of Q1[n] vs Q2[x] within a batch. A specific example in our testcase is row `score[2,:]`. It has the cosine similarity of V1[2] and V2[x]. The *closest_negative*, as was arranged, is V2[2] which has a score of 1. This is the maximum value of the 'negative' entries (blue entries in the diagram).\n",
    "\n",
    "To implement this, we need to pick the maximum entry on a row of `score`, ignoring the 'positive'/green entries. To avoid selecting the 'positive'/green entries, we can make them larger negative numbers. Multiply `fastnp.eye(batch_size)` with 2.0 and subtract it out of `scores`. The result is `negative_without_positive`. Now we can use `fastnp.max`, row by row (axis=1), to select the maximum which is `closest_negative`.\n",
    "\n",
    "Next, we'll create *mean_negative*. As the name suggests, this is the mean of all the 'negative'/blue values in `score` on a row by row basis. We can use `fastnp.eye(batch_size)` and a constant, this time to create a mask with zeros on the diagonal. Element-wise multiply this with `score` to get just the 'negative values. This is `negative_zero_on_duplicate` in the code. Compute the mean by using `fastnp.sum` on `negative_zero_on_duplicate` for `axis=1` and divide it by `(batch_size - 1)` . This is `mean_negative`.\n",
    "\n",
    "Now, we can compute loss using the two equations above and `fastnp.maximum`. This will form `triplet_loss1` and `triplet_loss2`. \n",
    "\n",
    "`triple_loss` is the `fastnp.mean` of the sum of the two individual losses.\n",
    "\n",
    "Once you have this code matching the expected results, you can clip out the section between `### START CODE HERE` and `### END CODE HERE` it out and insert it into TripletLoss below. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 64,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "oJM8EQiopiuv"
   },
   "outputs": [],
   "source": [
    "# UNQ_C3 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "# GRADED FUNCTION: TripletLossFn\n",
    "def TripletLossFn(v1, v2, margin=0.25):\n",
    "    \"\"\"Custom Loss function.\n",
    "\n",
    "    Args:\n",
    "        v1 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q1.\n",
    "        v2 (numpy.ndarray): Array with dimension (batch_size, model_dimension) associated to Q2.\n",
    "        margin (float, optional): Desired margin. Defaults to 0.25.\n",
    "\n",
    "    Returns:\n",
    "        jax.interpreters.xla.DeviceArray: Triplet Loss.\n",
    "    \"\"\"\n",
    "    ### START CODE HERE (Replace instances of 'None' with your code) ###\n",
    "    \n",
    "    # use fastnp to take the dot product of the two batches (don't forget to transpose the second argument)\n",
    "    scores =  fastnp.dot(v1, v2.T) # pairwise cosine sim\n",
    "    #print('scores', scores)\n",
    "    # calculate new batch size\n",
    "    batch_size = len(scores) # @KEEPTHIS\n",
    "    #print('batch_size', batch_size)\n",
    "    # use fastnp to grab all postive `diagonal` entries in `scores`\n",
    "    positive = fastnp.diag(scores)  # the positive ones (duplicates)\n",
    "    # multiply `fastnp.eye(batch_size)` with 2.0 and subtract it out of `scores`\n",
    "    negative_without_positive = scores - fastnp.eye(batch_size)*2\n",
    "    # take the row by row `max` of `negative_without_positive`. \n",
    "    # Hint: negative_without_positive.max(axis = [?])\n",
    "    closest_negative = negative_without_positive.max(axis=1)\n",
    "    #print('closest_negative', closest_negative)\n",
    "    # subtract `fastnp.eye(batch_size)` out of 1.0 and do element-wise multiplication with `scores`\n",
    "    negative_zero_on_duplicate = scores*(1.0- fastnp.eye(batch_size))\n",
    "    # use `fastnp.sum` on `negative_zero_on_duplicate` for `axis=1` and divide it by `(batch_size - 1)`\n",
    "    mean_negative = fastnp.sum(negative_zero_on_duplicate, axis=1)/ (batch_size -1)\n",
    "    # compute `fastnp.maximum` among 0.0 and `A`\n",
    "    # where A = subtract `positive` from `margin` and add `closest_negative`\n",
    "    # IMPORTANT: DO NOT create an extra variable 'A'\n",
    "    triplet_loss1 = fastnp.maximum(0.0, margin - positive + closest_negative)\n",
    "    # compute `fastnp.maximum` among 0.0 and `B`\n",
    "    # where B = subtract `positive` from `margin` and add `mean_negative`\n",
    "    # IMPORTANT: DO NOT create an extra variable 'B'\n",
    "    triplet_loss2 = fastnp.maximum(0.0, margin - positive  + mean_negative)\n",
    "    # add the two losses together and take the `fastnp.mean` of it\n",
    "    triplet_loss = fastnp.mean(triplet_loss1+triplet_loss2)\n",
    "    \n",
    "    ### END CODE HERE ###\n",
    "    \n",
    "    return triplet_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 65,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Triplet Loss: 0.5\n"
     ]
    }
   ],
   "source": [
    "v1 = np.array([[ 0.26726124,  0.53452248,  0.80178373],[-0.5178918 , -0.57543534, -0.63297887]])\n",
    "v2 = np.array([[0.26726124, 0.53452248, 0.80178373],[0.5178918 , 0.57543534, 0.63297887]])\n",
    "print(\"Triplet Loss:\", TripletLossFn(v1,v2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**Expected Output:**\n",
    "```CPP\n",
    "Triplet Loss: 0.5\n",
    "```   "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m All tests passed\n"
     ]
    }
   ],
   "source": [
    "# Test your function\n",
    "w4_unittest.test_TripletLossFn(TripletLossFn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "r974ozuHYAom"
   },
   "source": [
    "To make a layer out of a function with no trainable variables, use `tl.Fn`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 67,
   "metadata": {},
   "outputs": [],
   "source": [
    "from functools import partial\n",
    "def TripletLoss(margin=0.25):\n",
    "    triplet_loss_fn = partial(TripletLossFn, margin=margin)\n",
    "    return tl.Fn('TripletLoss', triplet_loss_fn)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lsvjaCQ6wt02"
   },
   "source": [
    "<a name='3'></a>\n",
    "\n",
    "# Part 3: Training\n",
    "\n",
    "Now you are going to train your model. As usual, you have to define the cost function and the optimizer. You also have to feed in the built model. Before, going into the training, we will use a special data set up. We will define the inputs using the data generator we built above. The lambda function acts as a seed to remember the last batch that was given. Run the cell below to get the question pairs inputs. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 68,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 51
    },
    "colab_type": "code",
    "id": "iPk7gh-nzCBg",
    "outputId": "a2e8525d-f89a-4d9d-c0d6-bd7406f0246a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train_Q1.shape  (89188,)\n",
      "val_Q1.shape    (22298,)\n"
     ]
    }
   ],
   "source": [
    "batch_size = 256\n",
    "train_generator = data_generator(train_Q1, train_Q2, batch_size, vocab['<PAD>'])\n",
    "val_generator = data_generator(val_Q1, val_Q2, batch_size, vocab['<PAD>'])\n",
    "print('train_Q1.shape ', train_Q1.shape)\n",
    "print('val_Q1.shape   ', val_Q1.shape)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "IgFMfH5awt07"
   },
   "source": [
    "<a name='3.1'></a>\n",
    "\n",
    "### 3.1 Training the model\n",
    "\n",
    "You will now write a function that takes in your model and trains it. To train your model you have to decide how many times you want to iterate over the entire data set; each iteration is defined as an `epoch`. For each epoch, you have to go over all the data, using your training iterator.\n",
    "\n",
    "<a name='ex04'></a>\n",
    "### Exercise 04\n",
    "\n",
    "**Instructions:** Implement the `train_model` below to train the neural network above. Here is a list of things you should do, as already shown in lecture 7: \n",
    "\n",
    "- Create `TrainTask` and `EvalTask`\n",
    "- Create the training loop `trax.supervised.training.Loop`\n",
    "- Pass in the following depending on the context (train_task or eval_task):\n",
    "    - `labeled_data=generator`\n",
    "    - `metrics=[TripletLoss()]`,\n",
    "    - `loss_layer=TripletLoss()`\n",
    "    - `optimizer=trax.optimizers.Adam` with learning rate of 0.01\n",
    "    - `lr_schedule=trax.lr.warmup_and_rsqrt_decay(400, 0.01)`,\n",
    "    - `output_dir=output_dir`\n",
    "\n",
    "\n",
    "You will be using your triplet loss function with Adam optimizer. Please read the [trax](https://trax-ml.readthedocs.io/en/latest/trax.optimizers.html?highlight=adam#trax.optimizers.adam.Adam) documentation to get a full understanding. \n",
    "\n",
    "This function should return a `training.Loop` object. To read more about this check the [docs](https://trax-ml.readthedocs.io/en/latest/trax.supervised.html?highlight=loop#trax.supervised.training.Loop)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 69,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "_kbtfz4T_m7x"
   },
   "outputs": [],
   "source": [
    "# UNQ_C4 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "# GRADED FUNCTION: train_model\n",
    "def train_model(Siamese, TripletLoss\n",
    "                , train_generator, val_generator, output_dir='model/'):\n",
    "    \"\"\"Training the Siamese Model\n",
    "\n",
    "    Args:\n",
    "        Siamese (function): Function that returns the Siamese model.\n",
    "        TripletLoss (function): Function that defines the TripletLoss loss function.\n",
    "        lr_schedule (function): Trax multifactor schedule function.\n",
    "        train_generator (generator, optional): Training generator. Defaults to train_generator.\n",
    "        val_generator (generator, optional): Validation generator. Defaults to val_generator.\n",
    "        output_dir (str, optional): Path to save model to. Defaults to 'model/'.\n",
    "\n",
    "    Returns:\n",
    "        trax.supervised.training.Loop: Training loop for the model.\n",
    "    \"\"\"\n",
    "    output_dir = os.path.expanduser(output_dir)\n",
    "\n",
    "    ### START CODE HERE (Replace instances of 'None' with your code) ###\n",
    "\n",
    "    train_task = training.TrainTask( \n",
    "        labeled_data=train_generator,      # Use generator (train)\n",
    "        loss_layer=TripletLoss(),        # Use triplet loss. Don't forget to instantiate this object\n",
    "        optimizer=trax.optimizers.Adam(0.01),         # Don't forget to add the learning rate parameter\n",
    "        lr_schedule=trax.lr.warmup_and_rsqrt_decay(400, 0.01) # Use Trax multifactor schedule function\n",
    "    )\n",
    "\n",
    "    eval_task = training.EvalTask(\n",
    "        labeled_data=val_generator,      # Use generator (val)\n",
    "        metrics=[TripletLoss()],         # Use triplet loss. Don't forget to instantiate this object\n",
    "    )\n",
    "    \n",
    "    ### END CODE HERE ###\n",
    "\n",
    "    training_loop = training.Loop(Siamese(),\n",
    "                                  train_task,\n",
    "                                  eval_tasks=[eval_task],\n",
    "                                  output_dir=output_dir)\n",
    "\n",
    "    return training_loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 391
    },
    "colab_type": "code",
    "id": "-3KXjmBo_6Xa",
    "outputId": "9d57f731-1534-4218-e744-783359d5cd19"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "Step      1: Total number of trainable weights: 5469056\n",
      "Step      1: Ran 1 train steps in 4.68 secs\n",
      "Step      1: train TripletLoss |  0.49999499\n",
      "Step      1: eval  TripletLoss |  0.49999827\n"
     ]
    }
   ],
   "source": [
    "train_steps = 5\n",
    "training_loop = train_model(Siamese, TripletLoss, train_generator, val_generator)\n",
    "training_loop.run(train_steps)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The model was only trained for 5 steps due to the constraints of this environment. For the rest of the assignment you will be using a pretrained model but now you should understand how the training can be done using Trax."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 71,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m All tests passed\n"
     ]
    }
   ],
   "source": [
    "# Test your function\n",
    "w4_unittest.test_train_model(train_model, Siamese, TripletLoss, data_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "abKPe7d4wt1C"
   },
   "source": [
    "<a name='4'></a>\n",
    "\n",
    "# Part 4:  Evaluation  \n",
    "\n",
    "<a name='4.1'></a>\n",
    "\n",
    "### 4.1 Evaluating your siamese network\n",
    "\n",
    "In this section you will learn how to evaluate a Siamese network. You will first start by loading a pretrained model and then you will use it to predict. "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "3OtmlEuOwt1D"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(((array([[-0.1548367 ,  0.5842879 , -0.491473  , ...,  0.20171408,\n",
       "           -1.1991485 ,  0.24797706],\n",
       "          [ 1.3291764 , -0.631177  , -0.8041899 , ..., -0.11724447,\n",
       "            1.1201009 , -1.6056383 ],\n",
       "          [-0.07363441,  0.60029113, -0.53114337, ...,  0.34642273,\n",
       "            1.2527287 ,  0.24330486],\n",
       "          ...,\n",
       "          [-1.3343723 ,  1.8407435 , -0.9178211 , ..., -0.8050336 ,\n",
       "           -0.52903235, -0.6519946 ],\n",
       "          [ 0.6470662 ,  1.1115725 , -1.2664155 , ..., -1.0939023 ,\n",
       "            0.35975415, -0.9333337 ],\n",
       "          [-1.7249486 , -0.26522425, -0.8502843 , ...,  0.0217114 ,\n",
       "           -0.9187999 ,  0.57652336]], dtype=float32),\n",
       "   (((), ((), ())),\n",
       "    ((array([[-0.090519  , -0.17231764,  0.08899491, ...,  0.03802169,\n",
       "               0.04847963,  0.10207134],\n",
       "             [ 0.09968922, -0.12788954, -0.02883548, ..., -0.01054264,\n",
       "               0.09648962, -0.14594312],\n",
       "             [ 0.31670982,  0.2782086 ,  0.03118353, ...,  0.07789616,\n",
       "              -0.10716873,  0.27891394],\n",
       "             ...,\n",
       "             [-0.07981083, -0.1666319 ,  0.22607869, ..., -0.0145308 ,\n",
       "              -0.10772758, -0.07736145],\n",
       "             [-0.21396244, -0.27984133,  0.09139609, ..., -0.06796359,\n",
       "              -0.13749957, -0.02988345],\n",
       "             [ 0.09710262, -0.04201529,  0.23600577, ..., -0.00910869,\n",
       "              -0.07826541, -0.06900913]], dtype=float32),\n",
       "      array([ 0.74866724,  0.23604804,  0.61066806,  0.46742544,  0.84660137,\n",
       "              0.6759479 ,  0.75458455,  0.63004017,  0.7348204 ,  0.7810847 ,\n",
       "              0.58264714,  0.80906653,  0.56512594,  0.46575448,  0.5894925 ,\n",
       "              0.52730995,  0.8338528 ,  0.5772276 ,  0.65321517,  0.597102  ,\n",
       "              0.60202324,  0.5726475 ,  0.21979097,  0.30167374,  0.7590757 ,\n",
       "              0.823802  ,  0.7020812 ,  0.89135915,  0.6727817 ,  0.62519836,\n",
       "              0.52391505,  0.5917081 ,  0.87297624,  0.7066292 ,  0.7613032 ,\n",
       "              0.7474745 ,  0.83107007,  0.6900485 ,  0.77471846,  0.42949966,\n",
       "              0.49603695,  0.8286887 ,  0.8028574 ,  0.8557113 ,  0.73552644,\n",
       "              0.7040203 ,  0.6676977 ,  0.8428011 ,  0.85005385,  0.83686626,\n",
       "              0.8439175 ,  0.9119055 ,  0.8712652 ,  0.76832306,  0.94966036,\n",
       "              0.9400879 ,  0.40259257,  0.8512523 ,  0.70599616,  0.8428879 ,\n",
       "              0.26889625,  0.4928807 ,  0.69416577,  0.8385073 ,  0.63600576,\n",
       "              0.665324  ,  0.94755393,  0.21964471,  0.9927555 ,  0.57788414,\n",
       "              0.4923829 ,  0.7651692 ,  0.8147946 ,  0.42811084,  0.78248423,\n",
       "              0.80160993,  0.0615592 ,  0.42542377,  0.9748414 ,  0.5906    ,\n",
       "              0.8284891 ,  0.5356175 ,  0.8901237 ,  0.5789534 ,  0.53019756,\n",
       "              0.6117212 ,  0.6344657 ,  0.7970178 ,  0.7678922 ,  0.56659377,\n",
       "              0.38351488,  0.46815988,  0.92958826,  0.75536066,  0.7945257 ,\n",
       "              0.81163955, -0.03411549,  0.6552565 ,  0.89075   ,  0.12296801,\n",
       "              0.5060967 ,  0.6040902 ,  0.706429  ,  0.5993115 ,  0.94161844,\n",
       "              0.7323816 ,  0.5769668 ,  0.76575243,  0.82241976,  0.5021452 ,\n",
       "              0.39745304,  0.5555639 ,  0.8681804 ,  0.35063615,  0.79127926,\n",
       "              0.6759815 ,  0.6214066 ,  0.8014635 ,  0.6679272 ,  0.73128015,\n",
       "              0.9667262 ,  0.6471471 ,  0.7636121 ,  0.80299324,  0.56265306,\n",
       "              0.78288954,  0.7186831 ,  0.86952096,  0.93183666,  0.9428032 ,\n",
       "              1.0660955 ,  0.9234194 ,  0.98314553,  0.9036281 ,  0.9164918 ,\n",
       "              0.96618134,  0.9356835 ,  0.9602327 ,  0.9254187 ,  0.97522515,\n",
       "              1.0308713 ,  0.96419185,  0.90409976,  0.8842241 ,  1.0028604 ,\n",
       "              0.8919036 ,  0.90017885,  1.0468265 ,  0.95297223,  0.8282836 ,\n",
       "              0.8075595 ,  0.88657457,  0.99655706,  0.8724078 ,  0.9252944 ,\n",
       "              1.0361383 ,  0.91071635,  0.8898618 ,  0.8949262 ,  0.9023262 ,\n",
       "              0.93774265,  0.99980384,  0.9492172 ,  0.9804592 ,  1.0260873 ,\n",
       "              0.9242303 ,  0.94389015,  1.0269797 ,  0.92412454,  0.8666998 ,\n",
       "              0.9186517 ,  0.91434723,  0.93527025,  0.87238604,  0.92576706,\n",
       "              1.0864905 ,  0.9672682 ,  0.9451934 ,  0.9366875 ,  0.97492445,\n",
       "              1.0422153 ,  0.96439254,  1.0562335 ,  0.9816442 ,  0.9188951 ,\n",
       "              0.96140105,  0.94802064,  0.971322  ,  0.909987  ,  0.93718976,\n",
       "              0.92520225,  0.9487348 ,  0.9173629 ,  0.9022114 ,  0.98291016,\n",
       "              0.9203712 ,  1.0528622 ,  0.92637384,  0.8964494 ,  0.87238085,\n",
       "              0.8651392 ,  0.92889744,  0.96058226,  1.03564   ,  1.0819371 ,\n",
       "              0.9664545 ,  1.003696  ,  0.9388842 ,  0.9729707 ,  1.0196446 ,\n",
       "              0.96996737,  0.9818085 ,  0.8986951 ,  0.915359  ,  0.8793746 ,\n",
       "              1.0020589 ,  0.97678035,  0.9033994 ,  0.89549446,  0.91473526,\n",
       "              0.96713454,  0.9447215 ,  0.94387925,  0.9418829 ,  0.87416893,\n",
       "              0.91230744,  0.98161036,  1.037545  ,  0.96248883,  0.92170656,\n",
       "              0.9302868 ,  0.9691254 ,  1.0190876 ,  1.0284687 ,  0.9252369 ,\n",
       "              0.8912538 ,  0.9601474 ,  0.9675769 ,  0.9491154 ,  0.96084094,\n",
       "              0.9740903 ,  0.91221315,  0.936837  ,  0.92450714,  0.9187654 ,\n",
       "              0.9460794 ,  0.85430056,  0.96632856,  0.877542  ,  0.8769376 ,\n",
       "              0.96478283,  0.93429935,  0.9077798 ,  0.945345  ,  0.8814642 ,\n",
       "              1.0224122 ,  1.1819276 ,  1.1743892 ,  1.1259301 ,  1.1365324 ,\n",
       "              1.0201721 ,  0.77878535,  1.1089861 ,  1.3017315 ,  1.0079747 ,\n",
       "              0.88814735,  1.1703241 ,  1.0726731 ,  1.2037915 ,  1.2322062 ,\n",
       "              1.1280077 ,  1.1659031 ,  0.99659234,  1.0427486 ,  1.0784584 ,\n",
       "              1.0045304 ,  1.1760176 ,  1.1747589 ,  1.1270305 ,  1.1475794 ,\n",
       "              1.1064138 ,  0.94072586,  1.1092551 ,  1.0405165 ,  1.0915463 ,\n",
       "              1.1692449 ,  1.1306821 ,  1.1056664 ,  0.96762   ,  1.044192  ,\n",
       "              1.1142045 ,  1.043103  ,  1.1768122 ,  1.095521  ,  0.9751092 ,\n",
       "              1.1367201 ,  1.1491212 ,  0.9467336 ,  0.8791624 ,  0.97153306,\n",
       "              1.1327726 ,  1.1412548 ,  0.8003143 ,  1.063433  ,  1.0218472 ,\n",
       "              0.97348887,  0.8892381 ,  0.9056184 ,  1.0377039 ,  1.0159701 ,\n",
       "              0.9115247 ,  0.9698118 ,  1.1991931 ,  1.0183407 ,  0.93539894,\n",
       "              1.090577  ,  1.1811533 ,  1.0967308 ,  1.135492  ,  1.049782  ,\n",
       "              1.1525736 ,  1.1287878 ,  0.9547346 ,  1.1560799 ,  1.0942348 ,\n",
       "              1.1362269 ,  1.1183137 ,  1.1266384 ,  0.9379652 ,  1.0124723 ,\n",
       "              1.0144356 ,  1.0998818 ,  1.3319428 ,  1.0870364 ,  0.87949634,\n",
       "              1.1151806 ,  1.0183502 ,  0.7299382 ,  0.94768083,  1.1095407 ,\n",
       "              1.1219616 ,  1.1669122 ,  1.095584  ,  0.94217265,  0.8934846 ,\n",
       "              1.1007344 ,  1.1559246 ,  1.1674374 ,  1.0095923 ,  0.8762718 ,\n",
       "              1.0119897 ,  1.0808157 ,  1.3190006 ,  1.1605103 ,  1.045878  ,\n",
       "              1.1813673 ,  1.1257927 ,  1.1406978 ,  1.1264644 ,  1.1466924 ,\n",
       "              1.1002922 ,  1.170897  ,  1.1333318 ,  1.0016005 ,  1.1115596 ,\n",
       "              1.1753446 ,  1.1535839 ,  1.1837677 ,  1.0424278 ,  1.1577648 ,\n",
       "              0.94917965,  1.2191473 ,  1.153406  ,  0.85698205,  1.1744432 ,\n",
       "              1.1687523 ,  0.83516914,  1.1171178 ,  1.136751  ,  0.93364644,\n",
       "              1.1793485 ,  1.0649722 ,  1.0637076 ,  0.948769  ,  0.95589656,\n",
       "              0.9806558 ,  0.9719903 ,  0.9753443 ,  0.83935803,  0.99856   ,\n",
       "              1.0142294 ,  0.9422684 ,  0.88958335,  0.83948064,  1.0275595 ,\n",
       "              0.985004  ,  0.9091298 ,  0.98124814,  0.9453334 ,  0.970391  ,\n",
       "              0.98279506,  0.959381  ,  1.0146968 ,  1.0164036 ,  1.0244045 ,\n",
       "              0.99479854,  0.8352394 ,  0.9564135 ,  1.013886  ,  0.9127604 ,\n",
       "              1.0300385 ,  1.015432  ,  1.0020359 ,  1.0248142 ,  1.0215341 ,\n",
       "              0.97478676,  0.988786  ,  0.9202841 ,  0.99523634,  0.94407266,\n",
       "              0.88930994,  1.0252285 ,  1.0044427 ,  0.7899437 ,  0.9401787 ,\n",
       "              0.9649993 ,  0.906628  ,  0.96424896,  1.0199634 ,  0.96326023,\n",
       "              0.9143608 ,  1.0259404 ,  0.8494232 ,  0.96625495,  0.8660369 ,\n",
       "              0.8509821 ,  0.9552434 ,  0.79410475,  0.93694836,  0.8247366 ,\n",
       "              0.98595744,  0.96510786,  0.8712254 ,  0.88906306,  0.9786764 ,\n",
       "              1.0831393 ,  0.9558825 ,  0.94343257,  1.0149219 ,  1.0587987 ,\n",
       "              0.88914365,  0.89742774,  0.94791895,  1.0467327 ,  0.8927613 ,\n",
       "              0.98992395,  0.9364004 ,  0.9667834 ,  0.92568415,  0.98102164,\n",
       "              0.95136875,  0.863204  ,  0.92163223,  1.0005571 ,  0.9337726 ,\n",
       "              0.92428803,  0.92654896,  0.95457834,  1.0239131 ,  1.0114415 ,\n",
       "              1.0324125 ,  0.9158145 ,  0.8941093 ,  0.99975413,  0.9539995 ,\n",
       "              1.014431  ,  0.9227971 ,  0.9062155 ,  0.8609443 ,  0.94716686,\n",
       "              0.9145248 ,  1.0938901 ,  0.8660885 ,  1.0017323 ,  0.95409846,\n",
       "              1.0142331 ,  0.98524636,  1.0461007 ,  0.8833331 ,  0.9143298 ,\n",
       "              1.0509038 ,  0.8950872 ,  1.0015632 ,  1.0264977 ,  1.0189135 ,\n",
       "              0.91027725,  0.95289606,  0.99035126,  0.9531608 ,  1.0130571 ,\n",
       "              0.9590856 ,  0.92991644,  0.9554187 ,  1.0178174 ,  0.9741912 ,\n",
       "              0.98893434,  0.95927787,  0.93184817,  1.0099845 ,  0.94556713,\n",
       "              0.90718126,  0.9819065 ], dtype=float32)),),\n",
       "    ()),\n",
       "   (),\n",
       "   ()),\n",
       "  {'__marker_for_cached_weights_': ()}),\n",
       " (((), (((), ((), ())), ((), ()), ()), (), ()),\n",
       "  {'__marker_for_cached_state_': ()}))"
      ]
     },
     "execution_count": 72,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Loading in the saved model\n",
    "model = Siamese()\n",
    "model.init_from_file('model.pkl.gz')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "QDi4MBiKpivF"
   },
   "source": [
    "<a name='4.2'></a>\n",
    "### 4.2 Classify\n",
    "To determine the accuracy of the model, we will utilize the test set that was configured earlier. While in training we used only positive examples, the test data, Q1_test, Q2_test and y_test, is setup as pairs of questions, some of which are duplicates some are not. \n",
    "This routine will run all the test question pairs through the model, compute the cosine simlarity of each pair, threshold it and compare the result to  y_test - the correct response from the data set. The results are accumulated to produce an accuracy.\n",
    "\n",
    "\n",
    "<a name='ex05'></a>\n",
    "### Exercise 05\n",
    "\n",
    "**Instructions**  \n",
    " - Loop through the incoming data in `batch_size` chunks\n",
    " - Use the data generator to load `q1`, `q2` a batch at a time. **Don't forget to set `shuffle=False`!**\n",
    " - copy a `batch_size` chunk of `y` into `y_test`\n",
    " - compute `v1`, `v2` using the model\n",
    " - for each element of the batch\n",
    "        - compute the cos similarity of each pair of entries, `v1[j]`,`v2[j]`\n",
    "        - determine if `d` > threshold\n",
    "        - increment accuracy if that result matches the expected results (`y_test[j]`)\n",
    " - compute the final accuracy and return\n",
    " \n",
    "Due to some limitations of this environment, running classify multiple times may result in the kernel failing. If that happens *Restart Kernal & clear output* and then run from the top. During development, consider using a smaller set of data to reduce the number of calls to model(). "
   ]
  },
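  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The per-pair loop in the instructions above can also be expressed with array operations. The following is a minimal sketch on made-up vectors, not part of the graded solution: the helper name `batch_accuracy` is hypothetical, plain NumPy is used instead of Trax, and the rows of `v1` and `v2` are assumed to be unit-normalized so that their dot product equals the cosine similarity.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def batch_accuracy(v1, v2, y, threshold):\n",
    "    # Row-wise dot product: one similarity score per question pair\n",
    "    d = np.sum(v1 * v2, axis=1)\n",
    "    # Predict 'duplicate' when the similarity exceeds the threshold\n",
    "    pred = d > threshold\n",
    "    # Fraction of predictions that agree with the labels\n",
    "    return np.mean(pred == y.astype(bool))\n",
    "\n",
    "# Toy unit-length vectors (illustrative values only)\n",
    "v1 = np.array([[1.0, 0.0], [0.0, 1.0]])\n",
    "v2 = np.array([[1.0, 0.0], [1.0, 0.0]])\n",
    "y = np.array([1, 0])\n",
    "print(batch_accuracy(v1, v2, y, threshold=0.7))  # 1.0\n",
    "```\n",
    "\n",
    "The first pair has similarity 1 and the second has similarity 0, so both predictions match the labels. The graded function should still follow the loop structure given in the instructions."
   ]
  },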
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "colab": {},
    "colab_type": "code",
    "id": "K-h6ZH507fUm"
   },
   "outputs": [],
   "source": [
    "# UNQ_C5 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "# GRADED FUNCTION: classify\n",
    "def classify(test_Q1, test_Q2, y, threshold, model, vocab, data_generator=data_generator, batch_size=64):\n",
    "    \"\"\"Function to test the accuracy of the model.\n",
    "\n",
    "    Args:\n",
    "        test_Q1 (numpy.ndarray): Array of Q1 questions.\n",
    "        test_Q2 (numpy.ndarray): Array of Q2 questions.\n",
    "        y (numpy.ndarray): Array of actual target.\n",
    "        threshold (float): Desired threshold.\n",
    "        model (trax.layers.combinators.Parallel): The Siamese model.\n",
    "        vocab (collections.defaultdict): The vocabulary used.\n",
    "        data_generator (function): Data generator function. Defaults to data_generator.\n",
    "        batch_size (int, optional): Size of the batches. Defaults to 64.\n",
    "\n",
    "    Returns:\n",
    "        float: Accuracy of the model.\n",
    "    \"\"\"    \n",
    "    \n",
    "    \n",
    "    accuracy = 0\n",
    "    ### START CODE HERE (Replace instances of 'None' with your code) ###\n",
    "    for i in range(0, len(test_Q1), batch_size):\n",
    "        # Call the data generator (built in Ex 01) with shuffle= None\n",
    "        # use batch size chuncks of questions as Q1 & Q2 arguments of the data generator. e.g x[i:i + batch_size]\n",
    "        # Hint: use `vocab['<PAD>']` for the `pad` argument of the data generator\n",
    "        q1, q2 = next(data_generator(test_Q1[i:i+batch_size], test_Q2[i:i+batch_size], batch_size, \n",
    "                                     pad=vocab['<PAD>'], shuffle=None))\n",
    "        # use batch size chuncks of actual output targets (same syntax as example above)\n",
    "        y_test = y[i:i+batch_size]\n",
    "        # Call the model    \n",
    "        v1, v2 = model((q1,q2))\n",
    "\n",
    "        for j in range(batch_size):\n",
    "            # take dot product to compute cos similarity of each pair of entries, v1[j], v2[j]\n",
    "            # don't forget to transpose the second argument\n",
    "            d = fastnp.dot(v1[j], v2[j].T)\n",
    "            # is d greater than the threshold?\n",
    "            res = d>threshold\n",
    "            # increment accurancy if y_test is equal `res`\n",
    "            accuracy += (y_test[j]==res)\n",
    "    # compute accuracy using accuracy and total length of test questions\n",
    "    accuracy = accuracy/len(test_Q1)\n",
    "    ### END CODE HERE ###\n",
    "    \n",
    "    return accuracy"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 80,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "colab_type": "code",
    "id": "yeQjHxkfpivH",
    "outputId": "103b8449-896f-403d-f011-583df70afdae"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Accuracy 0.69091797\n"
     ]
    }
   ],
   "source": [
    "# this takes around 1 minute\n",
    "accuracy = classify(Q1_test,Q2_test, y_test, 0.7, model, vocab, batch_size = 512) \n",
    "print(\"Accuracy\", accuracy)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "CsokYZwhpivJ"
   },
   "source": [
    "**Expected Result**  \n",
    "Accuracy ~0.69"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 81,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m All tests passed\n"
     ]
    }
   ],
   "source": [
    "# Test your function\n",
    "w4_unittest.test_classify(classify, vocab, data_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "4-STC44Ywt1I"
   },
   "source": [
    "<a name='5'></a>\n",
    "\n",
    "# Part 5: Testing with your own questions\n",
    "\n",
    "In this section you will test the model with your own questions. You will write a function `predict` which takes two questions as input and returns $1$ or $0$ depending on whether the question pair is a duplicate or not.   \n",
    "\n",
    "But first, we build a reverse vocabulary that allows to map encoded questions back to words: "
   ]
  },
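  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A reverse vocabulary can be sketched as follows. This is a minimal illustration on a toy dictionary, not the notebook's actual `vocab`: the names `toy_vocab` and `inv_vocab` and the index values are made up for the example, and it assumes every word maps to a unique integer index.\n",
    "\n",
    "```python\n",
    "# Toy word-to-index mapping (hypothetical values)\n",
    "toy_vocab = {'<PAD>': 1, 'when': 2, 'will': 3, 'i': 4}\n",
    "\n",
    "# Invert the mapping: index -> word\n",
    "inv_vocab = {index: word for word, index in toy_vocab.items()}\n",
    "\n",
    "# Decode an encoded, padded question back to tokens\n",
    "encoded = [2, 3, 4, 1]\n",
    "decoded = [inv_vocab[i] for i in encoded]\n",
    "print(decoded)  # ['when', 'will', 'i', '<PAD>']\n",
    "```\n",
    "\n",
    "Decoding the padded tensors printed by `predict` in this way makes it easier to check which tokens the model actually received."
   ]
  },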
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "21h3Y0FNpivK"
   },
   "source": [
    "Write a function `predict`that takes in two questions, the model, and the vocabulary and returns whether the questions are duplicates ($1$) or not duplicates ($0$) given a similarity threshold. \n",
    "\n",
    "<a name='ex06'></a>\n",
    "### Exercise 06\n",
    "\n",
    "\n",
    "**Instructions:** \n",
    "- Tokenize your question using `nltk.word_tokenize` \n",
    "- Create Q1,Q2 by encoding your questions as a list of numbers using vocab\n",
    "- pad Q1,Q2 with next(data_generator([Q1], [Q2],1,vocab['<PAD>']))\n",
    "- use model() to create v1, v2\n",
    "- compute the cosine similarity (dot product) of v1, v2\n",
    "- compute res by comparing d to the threshold\n"
   ]
  },
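  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoding step in the instructions above can be illustrated on a toy vocabulary. This is a hedged sketch rather than the graded solution: `toy_vocab` is hypothetical, `str.split` stands in for `nltk.word_tokenize` to keep the example dependency-free, and mapping unknown words to `0` is an assumption made for the example.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Toy vocabulary; unknown words fall back to 0 (an assumption for this sketch)\n",
    "toy_vocab = defaultdict(lambda: 0, {'<PAD>': 1, 'when': 2, 'will': 3, 'i': 4})\n",
    "\n",
    "# str.split is used here in place of nltk.word_tokenize\n",
    "tokens = 'when will i return'.split()\n",
    "\n",
    "# Encode each token as its vocabulary index\n",
    "encoded = [toy_vocab[word] for word in tokens]\n",
    "print(encoded)  # [2, 3, 4, 0]\n",
    "```\n",
    "\n",
    "Note that indexing a `defaultdict` with a missing key also inserts that key, so the vocabulary grows as unseen words are looked up."
   ]
  },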
  {
   "cell_type": "code",
   "execution_count": 86,
   "metadata": {},
   "outputs": [],
   "source": [
    "# UNQ_C6 (UNIQUE CELL IDENTIFIER, DO NOT EDIT)\n",
    "# GRADED FUNCTION: predict\n",
    "def predict(question1, question2, threshold, model, vocab, data_generator=data_generator, verbose=False):\n",
    "    \"\"\"Function for predicting if two questions are duplicates.\n",
    "\n",
    "    Args:\n",
    "        question1 (str): First question.\n",
    "        question2 (str): Second question.\n",
    "        threshold (float): Desired threshold.\n",
    "        model (trax.layers.combinators.Parallel): The Siamese model.\n",
    "        vocab (collections.defaultdict): The vocabulary used.\n",
    "        data_generator (function): Data generator function. Defaults to data_generator.\n",
    "        verbose (bool, optional): If the results should be printed out. Defaults to False.\n",
    "\n",
    "    Returns:\n",
    "        bool: True if the questions are duplicates, False otherwise.\n",
    "    \"\"\"\n",
    "    ### START CODE HERE (Replace instances of 'None' with your code) ###\n",
    "    # use `nltk` word tokenize function to tokenize\n",
    "    q1 = nltk.word_tokenize(question1)  # tokenize\n",
    "    q2 = nltk.word_tokenize(question2)  # tokenize\n",
    "    Q1, Q2 = [], [] # @KEEPTHIS\n",
    "    for word in q1:  # encode q1\n",
    "        # increment by checking the 'word' index in `vocab`\n",
    "        Q1 += [vocab[word]]\n",
    "        # GRADING COMMENT: Also valid to use\n",
    "        # Q1.extend([vocab[word]])\n",
    "    for word in q2:  # encode q2\n",
    "        # increment by checking the 'word' index in `vocab`\n",
    "        Q2 += [vocab[word]]\n",
    "        # GRADING COMMENT: Also valid to use\n",
    "        # Q2.extend([vocab[word]])\n",
    "    \n",
    "    # GRADING COMMENT: Q1 and Q2 need to be nested inside another list\n",
    "    # Call the data generator (built in Ex 01) using next()\n",
    "    # pass [Q1] & [Q2] as Q1 & Q2 arguments of the data generator. Set batch size as 1\n",
    "    # Hint: use `vocab['<PAD>']` for the `pad` argument of the data generator\n",
    "    Q1, Q2 = next(data_generator([Q1], [Q2], batch_size=1, pad=vocab['<PAD>'], shuffle=None))\n",
    "    # Call the model\n",
    "    v1, v2 = model((Q1, Q2))\n",
    "    # take dot product to compute cos similarity of each pair of entries, v1, v2\n",
    "    # don't forget to transpose the second argument\n",
    "    d = fastnp.dot(v1[0], v2[0].T)\n",
    "    # is d greater than the threshold?\n",
    "    res = d > threshold\n",
    "    \n",
    "    ### END CODE HERE ###\n",
    "    \n",
    "    if(verbose):\n",
    "        print(\"Q1  = \", Q1, \"\\nQ2  = \", Q2)\n",
    "        print(\"d   = \", d)\n",
    "        print(\"res = \", res)\n",
    "\n",
    "    return res"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 87,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "Raojyhw3z7HE",
    "outputId": "b0907aaf-63c0-448d-99b0-012359381a97"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q1  =  [[585  76   4  46  53  21   1   1]] \n",
      "Q2  =  [[ 585   33    4   46   53 7280   21    1]]\n",
      "d   =  0.8811324\n",
      "res =  True\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DeviceArray(True, dtype=bool)"
      ]
     },
     "execution_count": 87,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feel free to try with your own questions\n",
    "question1 = \"When will I see you?\"\n",
    "question2 = \"When can I see you again?\"\n",
    "# 1 means it is duplicated, 0 otherwise\n",
    "predict(question1 , question2, 0.7, model, vocab, verbose = True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "7OEKCa_hpivP"
   },
   "source": [
    "##### Expected Output\n",
    "If input is:\n",
    "```CPP\n",
    "question1 = \"When will I see you?\"\n",
    "question2 = \"When can I see you again?\"\n",
    "```\n",
    "\n",
    "Output is (d may vary a bit):\n",
    "```CPP\n",
    "Q1  =  [[585  76   4  46  53  21   1   1]] \n",
    "Q2  =  [[ 585   33    4   46   53 7280   21    1]]\n",
    "d   =  0.88113236\n",
    "res =  True\n",
    "True\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 88,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 102
    },
    "colab_type": "code",
    "id": "DZccIQ_lpivQ",
    "outputId": "3ed0af7e-5d44-4eb3-cebe-d6f74abe3e41"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Q1  =  [[  443  1145  3159  1169    78 29017    21     1]] \n",
      "Q2  =  [[  443  1145    60 15302    28    78  7431    21]]\n",
      "d   =  0.4775358\n",
      "res =  False\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "DeviceArray(False, dtype=bool)"
      ]
     },
     "execution_count": 88,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Feel free to try with your own questions\n",
    "question1 = \"Do they enjoy eating the dessert?\"\n",
    "question2 = \"Do they like hiking in the desert?\"\n",
    "# 1 means it is duplicated, 0 otherwise\n",
    "predict(question1 , question2, 0.7, model, vocab, verbose=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "lWrt-yCMpivS"
   },
   "source": [
    "##### Expected output\n",
    "\n",
    "If input is:\n",
    "```CPP\n",
    "question1 = \"Do they enjoy eating the dessert?\"\n",
    "question2 = \"Do they like hiking in the desert?\"\n",
    "```\n",
    "\n",
    "Output  (d may vary a bit):\n",
    "\n",
    "```CPP\n",
    "Q1  =  [[  443  1145  3159  1169    78 29017    21     1]] \n",
    "Q2  =  [[  443  1145    60 15302    28    78  7431    21]]\n",
    "d   =  0.477536\n",
    "res =  False\n",
    "False\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "NAfV3l5Zwt1L"
   },
   "source": [
    "You can see that the Siamese network is capable of catching complicated structures. Concretely it can identify question duplicates although the questions do not have many words in common. \n",
    " "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 89,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[92m All tests passed\n"
     ]
    }
   ],
   "source": [
    "# Test your function\n",
    "w4_unittest.test_predict(predict, vocab, data_generator)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "FsE8tdTLwt1M"
   },
   "source": [
    "<a name='6'></a>\n",
    "\n",
    "###  <span style=\"color:blue\"> On Siamese networks </span>\n",
    "\n",
    "Siamese networks are important and useful. Many times there are several questions that are already asked in quora, or other platforms and you can use Siamese networks to avoid question duplicates. \n",
    "\n",
    "Congratulations, you have now built a powerful system that can recognize question duplicates. In the next course we will use transformers for machine translation, summarization, question answering, and chatbots. \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}