diff --git a/docs/tutorials/image_captioning.ipynb b/docs/tutorials/image_captioning.ipynb index 39613b254..e67690211 100644 --- a/docs/tutorials/image_captioning.ipynb +++ b/docs/tutorials/image_captioning.ipynb @@ -1,2141 +1,2155 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "K2s1A9eLRPEj" - }, - "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "VRLVEKiTEn04" - }, - "outputs": [], - "source": [ - "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", - "# you may not use this file except in compliance with the License.\n", - "# You may obtain a copy of the License at\n", - "#\n", - "# https://www.apache.org/licenses/LICENSE-2.0\n", - "#\n", - "# Unless required by applicable law or agreed to in writing, software\n", - "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", - "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", - "# See the License for the specific language governing permissions and\n", - "# limitations under the License." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EFwSaNB8jF7s" - }, - "source": [ - "\u003cstyle\u003e\n", - "td {\n", - " text-align: center;\n", - "}\n", - "\n", - "th {\n", - " text-align: center;\n", - "}\n", - "\u003c/style\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Cffg2i257iMS" - }, - "source": [ - "# Image captioning with visual attention\n", - "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://www.tensorflow.org/text/tutorials/image_captioning\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/tf_logo_32px.png\" /\u003eView on TensorFlow.org\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/text/blob/master/docs/tutorials/image_captioning.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca target=\"_blank\" href=\"https://github.com/tensorflow/text/blob/master/docs/tutorials/image_captioning.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView on GitHub\u003c/a\u003e\n", - " \u003c/td\u003e\n", - " \u003ctd\u003e\n", - " \u003ca href=\"https://storage.googleapis.com/tensorflow_docs/text/docs/tutorials/image_captioning.ipynb\"\u003e\u003cimg src=\"https://www.tensorflow.org/images/download_logo_32px.png\" /\u003eDownload notebook\u003c/a\u003e\n", - " \u003c/td\u003e\n", - "\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "QASbY_HGo4Lq" - }, - "source": [ - "Given an image like the example below, your goal is to generate a\n", - "caption such as \"a surfer riding on a wave\".\n", - "\n", - "\u003ctable style=\"text-align: center;\"\u003e\n", - "\u003ctr\u003e\n", - " \u003ctd\u003e\n", - " \u003cimg src=\"https://tensorflow.org/images/surf.jpg\"/\u003e\n", - " \u003c/td\u003e\n", - "\u003c/tr\u003e\n", - "\u003ctr\u003e\n", - " \u003cth\u003eA man surfing, from \u003ca href=https://commons.wikimedia.org/wiki/Surfing#/media/File:Surfing_in_Hawaii.jpg\u003ewikimedia\u003c/a\u003e\u003c/th\u003e\n", - "\u003c/tr\u003e\n", - "\u003c/table\u003e\n", - "\n", - "The model 
architecture used here is inspired by [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044), but has been updated to use a 2-layer Transformer-decoder. To get the most out of this tutorial you should have some experience with [text generation](https://www.tensorflow.org/text/tutorials/text_generation), [seq2seq models \u0026 attention](https://www.tensorflow.org/text/tutorials/nmt_with_attention), or [transformers](https://www.tensorflow.org/text/tutorials/transformer)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6HbD8n0w7d3F" - }, - "source": [ - "The model architecture built in this tutorial is shown below. Features are extracted from the image, and passed to the cross-attention layers of the Transformer-decoder.\n", - "\n", - "\u003ctable\u003e\n", - "\u003ctr\u003e\n", - " \u003cth\u003eThe model architecture\u003c/th\u003e\n", - "\u003c/tr\u003e\n", - "\u003ctr\u003e\n", - " \u003ctd\u003e\n", - " \u003cimg width=400 src=\"https://tensorflow.org/images/tutorials/transformer/ImageCaptioning.png\"/\u003e\n", - " \u003c/td\u003e\n", - "\u003c/tr\u003e\n", - "\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1IxifZKT6vXQ" - }, - "source": [ - "The transformer decoder is mainly built from attention layers. It uses self-attention to process the sequence being generated, and it uses cross-attention to attend to the image.\n", - "\n", - "By inspecting the attention weights of the cross attention layers you will see what parts of the image the model is looking at as it generates words.\n", - "\n", - "![Prediction](https://tensorflow.org/images/imcap_prediction.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "87us2sLVdwME" - }, - "source": [ - "This notebook is an end-to-end example. When you run the notebook, it downloads a dataset, extracts and caches the image features, and trains a decoder model. It then uses the model to generate captions on new images." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5bwwk4uxRz6A" - }, - "source": [ - "## Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "gc06pTaBbl72" - }, - "outputs": [], - "source": [ - "!apt install --allow-change-held-packages libcudnn8=8.6.0.163-1+cuda11.8" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2R1hQGtZEi8Y" - }, - "outputs": [], - "source": [ - "!pip uninstall -y tensorflow estimator keras" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "5Xbt8BkPv8Ou" - }, - "outputs": [], - "source": [ - "!pip install -U tensorflow_text tensorflow tensorflow_datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7TGZmOuqMia9" - }, - "outputs": [], - "source": [ - "!pip install einops" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nQ6q39Vd-y-7" - }, - "source": [ - "This tutorial uses lots of imports, mostly for loading the dataset(s)." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "cellView": "form", - "id": "U8l4RJ0XRPEm" - }, - "outputs": [], - "source": [ - "#@title\n", - "import concurrent.futures\n", - "import collections\n", - "import dataclasses\n", - "import hashlib\n", - "import itertools\n", - "import json\n", - "import math\n", - "import os\n", - "import pathlib\n", - "import random\n", - "import re\n", - "import string\n", - "import time\n", - "import urllib.request\n", - "\n", - "import einops\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pandas as pd\n", - "from PIL import Image\n", - "import requests\n", - "import tqdm\n", - "\n", - "import tensorflow as tf\n", - "import tensorflow_hub as hub\n", - "import tensorflow_text as text\n", - "import tensorflow_datasets as tfds" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Kl9qGnjWrv80" - }, - "source": [ - "## [Optional] Data handling\n", - "\n", - "This section downloads a captions dataset and prepares it for training. It tokenizes the input text, and caches the results of running all the images through a pretrained feature-extractor model. It's not critical to understand everything in this section.\n", - "\n", - " \u003csection class=\"expandable tfo-display-only-on-site\"\u003e\n", - " \u003cbutton type=\"button\" class=\"button-red button expand-control\"\u003eToggle section\u003c/button\u003e\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "q5e_SigQFiWf" - }, - "source": [ - "### Choose a dataset\n", - "\n", - "This tutorial is set up to give a choice of datasets. Either [Flickr8k](https://www.ijcai.org/Proceedings/15/Papers/593.pdf) or a small slice of the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/) dataset. 
These two are downloaded and converted from scratch, but it wouldn't be hard to convert the tutorial to use the caption datasets available in [TensorFlow Datasets](https://www.tensorflow.org/datasets): [Coco Captions](https://www.tensorflow.org/datasets/catalog/coco_captions) and the full [Conceptual Captions](https://www.tensorflow.org/datasets/community_catalog/huggingface/conceptual_captions).\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "wqGXX9Dc5c0v" - }, - "source": [ - "#### Flickr8k" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kaNy_l7tGuAZ" - }, - "outputs": [], - "source": [ - "def flickr8k(path='flickr8k'):\n", - " path = pathlib.Path(path)\n", - "\n", - " if len(list(path.rglob('*'))) \u003c 16197:\n", - " tf.keras.utils.get_file(\n", - " origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',\n", - " cache_dir='.',\n", - " cache_subdir=path,\n", - " extract=True)\n", - " tf.keras.utils.get_file(\n", - " origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',\n", - " cache_dir='.',\n", - " cache_subdir=path,\n", - " extract=True)\n", - " \n", - " captions = (path/\"Flickr8k.token.txt\").read_text().splitlines()\n", - " captions = (line.split('\\t') for line in captions)\n", - " captions = ((fname.split('#')[0], caption) for (fname, caption) in captions)\n", - "\n", - " cap_dict = collections.defaultdict(list)\n", - " for fname, cap in captions:\n", - " cap_dict[fname].append(cap)\n", - "\n", - " train_files = (path/'Flickr_8k.trainImages.txt').read_text().splitlines()\n", - " train_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in train_files]\n", - "\n", - " test_files = (path/'Flickr_8k.testImages.txt').read_text().splitlines()\n", - " test_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in test_files]\n", - "\n", - " train_ds = tf.data.experimental.from_list(train_captions)\n", - " test_ds = tf.data.experimental.from_list(test_captions)\n", - "\n", - " return train_ds, test_ds" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zQICBAF4FmSL" - }, - "source": [ - "#### Conceptual Captions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vQwnxXZXRl12" - }, - "outputs": [], - "source": [ - "def conceptual_captions(*, data_dir=\"conceptual_captions\", num_train, num_val):\n", - " def iter_index(index_path):\n", - " with open(index_path) as f:\n", - " for line in f:\n", - " caption, url = line.strip().split('\\t')\n", - " yield caption, url\n", - "\n", - " def download_image_urls(data_dir, urls):\n", - " ex = concurrent.futures.ThreadPoolExecutor(max_workers=100)\n", - " def save_image(url):\n", - " hash = hashlib.sha1(url.encode())\n", - " # Name the files after the hash of the URL.\n", - " file_path = data_dir/f'{hash.hexdigest()}.jpeg'\n", - " if file_path.exists():\n", - " # Only download each file once.\n", - " return file_path\n", - "\n", - " try:\n", - " result = requests.get(url, timeout=5)\n", - " except Exception:\n", - " file_path = None\n", - " else:\n", - " file_path.write_bytes(result.content)\n", - " return file_path\n", - " \n", - " result = []\n", - " out_paths = ex.map(save_image, urls)\n", - " for file_path in tqdm.tqdm(out_paths, total=len(urls)):\n", - " result.append(file_path)\n", - "\n", - " return result\n", - "\n", - " def ds_from_index_file(index_path, data_dir, count):\n", - " 
data_dir.mkdir(exist_ok=True)\n", - " index = list(itertools.islice(iter_index(index_path), count))\n", - " captions = [caption for caption, url in index]\n", - " urls = [url for caption, url in index]\n", - "\n", - " paths = download_image_urls(data_dir, urls)\n", - "\n", - " new_captions = []\n", - " new_paths = []\n", - " for cap, path in zip(captions, paths):\n", - " if path is None:\n", - " # Download failed, so skip this pair.\n", - " continue\n", - " new_captions.append(cap)\n", - " new_paths.append(path)\n", - " \n", - " new_paths = [str(p) for p in new_paths]\n", - "\n", - " ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))\n", - " ds = ds.map(lambda path,cap: (path, cap[tf.newaxis])) # 1 caption per image\n", - " return ds\n", - "\n", - " data_dir = pathlib.Path(data_dir)\n", - " train_index_path = tf.keras.utils.get_file(\n", - " origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',\n", - " cache_subdir=data_dir,\n", - " cache_dir='.')\n", - " \n", - " val_index_path = tf.keras.utils.get_file(\n", - " origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',\n", - " cache_subdir=data_dir,\n", - " cache_dir='.')\n", - " \n", - " train_raw = ds_from_index_file(train_index_path, data_dir=data_dir/'train', count=num_train)\n", - " test_raw = ds_from_index_file(val_index_path, data_dir=data_dir/'val', count=num_val)\n", - "\n", - " return train_raw, test_raw" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rBAagBw5p-TM" - }, - "source": [ - "#### Download the dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WFtTZaobquNr" - }, - "source": [ - "The Flickr8k is a good choice because it contains 5-captions per image, more data for a smaller download." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "EJySPbzJ4Wxw" - }, - "outputs": [], - "source": [ - "choose = 'flickr8k'\n", - "\n", - "if choose == 'flickr8k':\n", - " train_raw, test_raw = flickr8k()\n", - "else:\n", - " train_raw, test_raw = conceptual_captions(num_train=10000, num_val=5000)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-UAc275FHxm8" - }, - "source": [ - "The loaders for both datasets above return `tf.data.Dataset`s containing `(image_path, captions)` pairs. The Flickr8k dataset contains 5 captions per image, while Conceptual Captions has 1:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sAQSps5F8RQI" - }, - "outputs": [], - "source": [ - "train_raw.element_spec" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "xIa0ZaP4tBez" - }, - "outputs": [], - "source": [ - "for ex_path, ex_captions in train_raw.take(1):\n", - " print(ex_path)\n", - " print(ex_captions)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8cSW4u-ORPFQ" - }, - "source": [ - "### Image feature extractor\n", - "\n", - "You will use an image model (pretrained on imagenet) to extract the features from each image. 
The model was trained as an image classifier, but setting `include_top=False` returns the model without the final classification layer, so you can use the last layer of feature-maps: \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "IlUckK8Zfikv" - }, - "outputs": [], - "source": [ - "IMAGE_SHAPE=(224, 224, 3)\n", - "mobilenet = tf.keras.applications.MobileNetV3Small(\n", - " input_shape=IMAGE_SHAPE,\n", - " include_top=False,\n", - " include_preprocessing=True)\n", - "mobilenet.trainable=False" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Dojkiou9gL3R" - }, - "source": [ - "Here's a function to load an image and resize it for the model:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zXR0217aRPFR" - }, - "outputs": [], - "source": [ - "def load_image(image_path):\n", - " img = tf.io.read_file(image_path)\n", - " img = tf.io.decode_jpeg(img, channels=3)\n", - " img = tf.image.resize(img, IMAGE_SHAPE[:-1])\n", - " return img" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-JyQ7zS6gzZh" - }, - "source": [ - "The model returns a feature map for each image in the input batch:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sY86n2i6wJNm" - }, - "outputs": [], - "source": [ - "test_img_batch = load_image(ex_path)[tf.newaxis, :]\n", - "\n", - "print(test_img_batch.shape)\n", - "print(mobilenet(test_img_batch).shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nyqH3zFwRPFi" - }, - "source": [ - "### Setup the text tokenizer/vectorizer\n", - "\n", - "You will transform the text captions into integer sequences using the [TextVectorization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer, with the following steps:\n", - "\n", - "* Use [adapt](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization#adapt) to iterate over all captions, split the captions into words, and compute a vocabulary of the top words.\n", - "* Tokenize all captions by mapping each word to its index in the vocabulary. All output sequences will be padded to length 50.\n", - "* Create word-to-index and index-to-word mappings to display results." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NroZIzB90hD3" - }, - "outputs": [], - "source": [ - "def standardize(s):\n", - " s = tf.strings.lower(s)\n", - " s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')\n", - " s = tf.strings.join(['[START]', s, '[END]'], separator=' ')\n", - " return s" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "n9SQOXFsyS36" - }, - "outputs": [], - "source": [ - "# Use the top 5000 words for a vocabulary.\n", - "vocabulary_size = 5000\n", - "tokenizer = tf.keras.layers.TextVectorization(\n", - " max_tokens=vocabulary_size,\n", - " standardize=standardize,\n", - " ragged=True)\n", - "# Learn the vocabulary from the caption data." 
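As a quick sanity check (a sketch only, assuming the cells above have already run), `standardize` lowercases a caption, strips punctuation, and wraps it in the `[START]`/`[END]` markers:

```python
# Illustrative only: inspect what the standardize() helper defined above
# produces for a single caption string.
example = standardize("A surfer rides the wave!")
print(example.numpy())  # b'[START] a surfer rides the wave [END]'
```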
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oJGE34aiRPFo" - }, - "outputs": [], - "source": [ - "tokenizer.adapt(train_raw.map(lambda fp,txt: txt).unbatch().batch(1024))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oRahTDtWhJIf" - }, - "outputs": [], - "source": [ - "tokenizer.get_vocabulary()[:10]" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-2mGxD33JCxN" - }, - "outputs": [], - "source": [ - "t = tokenizer([['a cat in a hat'], ['a robot dog']])\n", - "t" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8Q44tNQVRPFt" - }, - "outputs": [], - "source": [ - "# Create mappings for words to indices and indices to words.\n", - "word_to_index = tf.keras.layers.StringLookup(\n", - " mask_token=\"\",\n", - " vocabulary=tokenizer.get_vocabulary())\n", - "index_to_word = tf.keras.layers.StringLookup(\n", - " mask_token=\"\",\n", - " vocabulary=tokenizer.get_vocabulary(),\n", - " invert=True)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qo-cfCX3LnHs" - }, - "outputs": [], - "source": [ - "w = index_to_word(t)\n", - "w.to_list()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rrUUfGc65vAT" - }, - "outputs": [], - "source": [ - "tf.strings.reduce_join(w, separator=' ', axis=-1).numpy()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "uEWM9xrYcg45" - }, - "source": [ - "### Prepare the datasets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6aX0Z_98S2tN" - }, - "source": [ - "The `train_raw` and `test_raw` datasets contain 1:many `(image, captions)` pairs. \n", - "\n", - "This function will replicate the image so there are 1:1 images to captions:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3_Lqwl9NiGT0" - }, - "outputs": [], - "source": [ - "def match_shapes(images, captions):\n", - " caption_shape = einops.parse_shape(captions, 'b c')\n", - " captions = einops.rearrange(captions, 'b c -\u003e (b c)')\n", - " images = einops.repeat(\n", - " images, 'b ... -\u003e (b c) ...',\n", - " c = caption_shape['c'])\n", - " return images, captions" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CZGUsuGzUfzt" - }, - "outputs": [], - "source": [ - "for ex_paths, ex_captions in train_raw.batch(32).take(1):\n", - " break\n", - "\n", - "print('image paths:', ex_paths.shape)\n", - "print('captions:', ex_captions.shape)\n", - "print()\n", - "\n", - "ex_paths, ex_captions = match_shapes(images=ex_paths, captions=ex_captions)\n", - "\n", - "print('image_paths:', ex_paths.shape)\n", - "print('captions:', ex_captions.shape)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8ENR_-swVhnm" - }, - "source": [ - "To be compatible with keras training the dataset should contain `(inputs, labels)` pairs. For text generation the tokens are both an input and the labels, shifted by one step. 
This function will convert an `(images, texts)` pair to an `((images, input_tokens), label_tokens)` pair:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2DsgQ_hZT4C2" - }, - "outputs": [], - "source": [ - "def prepare_txt(imgs, txts):\n", - " tokens = tokenizer(txts)\n", - "\n", - " input_tokens = tokens[..., :-1]\n", - " label_tokens = tokens[..., 1:]\n", - " return (imgs, input_tokens), label_tokens" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DA1x2j0JXX-N" - }, - "source": [ - "This function adds operations to a dataset. The steps are:\n", - "\n", - "1. Load the images (and ignore images that fail to load).\n", - "2. Replicate images to match the number of captions.\n", - "3. Shuffle and rebatch the `image, caption` pairs.\n", - "4. Tokenize the text, shift the tokens and add `label_tokens`.\n", - "5. Convert the text from a `RaggedTensor` representation to padded dense `Tensor` representation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "4_Pt9zldjQ0q" - }, - "outputs": [], - "source": [ - "def prepare_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):\n", - " # Load the images and make batches.\n", - " ds = (ds\n", - " .shuffle(10000)\n", - " .map(lambda path, caption: (load_image(path), caption))\n", - " .apply(tf.data.experimental.ignore_errors())\n", - " .batch(batch_size))\n", - "\n", - " def to_tensor(inputs, labels):\n", - " (images, in_tok), out_tok = inputs, labels\n", - " return (images, in_tok.to_tensor()), out_tok.to_tensor()\n", - "\n", - " return (ds\n", - " .map(match_shapes, tf.data.AUTOTUNE)\n", - " .unbatch()\n", - " .shuffle(shuffle_buffer)\n", - " .batch(batch_size)\n", - " .map(prepare_txt, tf.data.AUTOTUNE)\n", - " .map(to_tensor, tf.data.AUTOTUNE)\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "LrQ85t1GNfpQ" - }, - "source": [ - "You could install the feature extractor in your model and train on the datasets like this:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1KlhOG5cjQ0r" - }, - "outputs": [], - "source": [ - "train_ds = prepare_dataset(train_raw, tokenizer)\n", - "train_ds.element_spec" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "d7Zy9F3zX7i2" - }, - "outputs": [], - "source": [ - "test_ds = prepare_dataset(test_raw, tokenizer)\n", - "test_ds.element_spec" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XZyKygJ8S8zW" - }, - "source": [ - "### [Optional] Cache the image features" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "eHKhSKhti6NS" - }, - "source": [ - "Since the image feature extractor is not changing, and this tutorial is not using image augmentation, the image features can be cached. Same for the text tokenization. The time it takes to set up the cache is earned back on each epoch during training and validation. 
The code below defines two functions `save_dataset` and `load_dataset`: " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9N1MX5ym6xm5" - }, - "outputs": [], - "source": [ - "def save_dataset(ds, save_path, image_model, tokenizer, shards=10, batch_size=32):\n", - " # Load the images and make batches.\n", - " ds = (ds\n", - " .map(lambda path, caption: (load_image(path), caption))\n", - " .apply(tf.data.experimental.ignore_errors())\n", - " .batch(batch_size))\n", - "\n", - " # Run the feature extractor on each batch\n", - " # Don't do this in a .map, because tf.data runs on the CPU. \n", - " def gen():\n", - " for (images, captions) in tqdm.tqdm(ds): \n", - " feature_maps = image_model(images)\n", - "\n", - " feature_maps, captions = match_shapes(feature_maps, captions)\n", - " yield feature_maps, captions\n", - "\n", - " # Wrap the generator in a new tf.data.Dataset.\n", - " new_ds = tf.data.Dataset.from_generator(\n", - " gen,\n", - " output_signature=(\n", - " tf.TensorSpec(shape=image_model.output_shape),\n", - " tf.TensorSpec(shape=(None,), dtype=tf.string)))\n", - "\n", - " # Apply the tokenization \n", - " new_ds = (new_ds\n", - " .map(prepare_txt, tf.data.AUTOTUNE)\n", - " .unbatch()\n", - " .shuffle(1000))\n", - "\n", - " # Save the dataset into shard files.\n", - " def shard_func(i, item):\n", - " return i % shards\n", - " new_ds.enumerate().save(save_path, shard_func=shard_func)\n", - "\n", - "def load_dataset(save_path, batch_size=32, shuffle=1000, cycle_length=2):\n", - " def custom_reader_func(datasets):\n", - " datasets = datasets.shuffle(1000)\n", - " return datasets.interleave(lambda x: x, cycle_length=cycle_length)\n", - " \n", - " ds = tf.data.Dataset.load(save_path, reader_func=custom_reader_func)\n", - "\n", - " def drop_index(i, x):\n", - " return x\n", - "\n", - " ds = (ds\n", - " .map(drop_index, tf.data.AUTOTUNE)\n", - " .shuffle(shuffle)\n", - " .padded_batch(batch_size)\n", - " .prefetch(tf.data.AUTOTUNE))\n", - " return ds" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tNdzrenxB3Yy" - }, - "outputs": [], - "source": [ - "save_dataset(train_raw, 'train_cache', mobilenet, tokenizer)\n", - "save_dataset(test_raw, 'test_cache', mobilenet, tokenizer)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "798DtfH51UI8" - }, - "source": [ - " \u003c/section\u003e\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GI265LiDslr2" - }, - "source": [ - "## Data ready for training\n", - "\n", - "After those preprocessing steps, here are the datasets:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Pwic2YCjHZmV" - }, - "outputs": [], - "source": [ - "train_ds = load_dataset('train_cache')\n", - "test_ds = load_dataset('test_cache')" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3B80JXj7HloX" - }, - "outputs": [], - "source": [ - "train_ds.element_spec" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "5jfb8qknlsKi" - }, - "source": [ - "The dataset now returns `(input, label)` pairs suitable for training with keras. The `inputs` are `(images, input_tokens)` pairs. The `images` have been processed with the feature-extractor model. For each location in the `input_tokens` the model looks at the text so far and tries to predict the next which is lined up at the same location in the `labels`." 
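A minimal, self-contained sketch of that one-step shift (the token IDs below are made up for illustration):

```python
import tensorflow as tf

# The labels are the input tokens moved one position to the left, so at every
# location the model is trained to predict the token that comes next.
tokens = tf.constant([[2, 7, 11, 15, 3]])  # e.g. [START], a, man, surfing, [END]
input_tokens = tokens[..., :-1]            # what the model reads
label_tokens = tokens[..., 1:]             # what the model should predict
print(input_tokens.numpy())  # [[ 2  7 11 15]]
print(label_tokens.numpy())  # [[ 7 11 15  3]]
```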
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "YJBEwuXLZQdw" - }, - "outputs": [], - "source": [ - "for (inputs, ex_labels) in train_ds.take(1):\n", - " (ex_img, ex_in_tok) = inputs\n", - "\n", - "print(ex_img.shape)\n", - "print(ex_in_tok.shape)\n", - "print(ex_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "22R58DzZoF17" - }, - "source": [ - "The input tokens and the labels are the same, just shifted by 1 step:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "V7h5UGftn1hT" - }, - "outputs": [], - "source": [ - "print(ex_in_tok[0].numpy())\n", - "print(ex_labels[0].numpy())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "DfICM49WFpIb" - }, - "source": [ - "## A Transformer decoder model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ONyjuWsmZoyO" - }, - "source": [ - "This model assumes that the pretrained image encoder is sufficient, and just focuses on building the text decoder. This tutorial uses a 2-layer Transformer-decoder.\n", - "\n", - "The implementations are almost identical to those in the [Transformers tutorial](https://www.tensorflow.org/text/tutorials/transformer). Refer back to it for more details.\n", - "\n", - "\u003ctable\u003e\n", - "\u003ctr\u003e\n", - " \u003cth\u003eThe Transformer encoder and decoder.\u003c/th\u003e\n", - "\u003c/tr\u003e\n", - "\u003ctr\u003e\n", - " \u003ctd\u003e\n", - " \u003cimg width=400 src=\"https://www.tensorflow.org/images/tutorials/transformer/Transformer-1layer-words.png\"/\u003e\n", - " \u003c/td\u003e\n", - "\u003c/tr\u003e\n", - "\u003c/table\u003e" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qiRXWwIKNybB" - }, - "source": [ - "The model will be implemented in three main parts: \n", - "\n", - "1. Input - The token embedding and positional encoding (`SeqEmbedding`).\n", - "1. Decoder - A stack of transformer decoder layers (`DecoderLayer`) where each contains:\n", - " 1. A causal self attention later (`CausalSelfAttention`), where each output location can attend to the output so far.\n", - " 1. A cross attention layer (`CrossAttention`) where each output location can attend to the input image.\n", - " 1. A feed forward network (`FeedForward`) layer which further processes each output location independently.\n", - "1. Output - A multiclass-classification over the output vocabulary.\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "_ngm3SQMCaYU" - }, - "source": [ - "### Input" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "i9suaARZGPKw" - }, - "source": [ - "The input text has already been split up into tokens and converted to sequences of IDs. \n", - "\n", - "Remember that unlike a CNN or RNN the Transformer's attention layers are invariant to the order of the sequence. Without some positional input, it just sees an unordered set not a sequence. 
So in addition to a simple vector embedding for each token ID, the embedding layer will also include an embedding for each position in the sequence.\n", - "\n", - "The `SeqEmbedding` layer defined below:\n", - "\n", - "- It looks up the embedding vector for each token.\n", - "- It looks up an embedding vector for each sequence location.\n", - "- It adds the two together.\n", - "- It uses `mask_zero=True` to initialize the keras-masks for the model.\n", - "\n", - "Note: This implementation learns the position embeddings instead of using fixed embeddings like in the [Transformer tutorial](https://www.tensorflow.org/text/tutorials/transformer). Learning the embeddings is slightly less code, but doesn't generalize to longer sequences." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "P91LU2F0a9Ga" - }, - "outputs": [], - "source": [ - "class SeqEmbedding(tf.keras.layers.Layer):\n", - " def __init__(self, vocab_size, max_length, depth):\n", - " super().__init__()\n", - " self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)\n", - "\n", - " self.token_embedding = tf.keras.layers.Embedding(\n", - " input_dim=vocab_size,\n", - " output_dim=depth,\n", - " mask_zero=True)\n", - " \n", - " self.add = tf.keras.layers.Add()\n", - "\n", - " def call(self, seq):\n", - " seq = self.token_embedding(seq) # (batch, seq, depth)\n", - "\n", - " x = tf.range(tf.shape(seq)[1]) # (seq)\n", - " x = x[tf.newaxis, :] # (1, seq)\n", - " x = self.pos_embedding(x) # (1, seq, depth)\n", - "\n", - " return self.add([seq,x])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "II1mD-bBCdMB" - }, - "source": [ - "### Decoder" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GHMLeMtKPTCW" - }, - "source": [ - "The decoder is a standard Transformer-decoder, it contains a stack of `DecoderLayers` where each contains three sublayers: a `CausalSelfAttention`, a `CrossAttention`, and a`FeedForward`. The implementations are almost identical to the [Transformer tutorial](https://www.tensorflow.org/text/tutorials/transformer), refer to it for more details.\n", - "\n", - "The `CausalSelfAttention` layer is below:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6JTLiX3lKooQ" - }, - "outputs": [], - "source": [ - "class CausalSelfAttention(tf.keras.layers.Layer):\n", - " def __init__(self, **kwargs):\n", - " super().__init__()\n", - " self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)\n", - " # Use Add instead of + so the keras mask propagates through.\n", - " self.add = tf.keras.layers.Add() \n", - " self.layernorm = tf.keras.layers.LayerNormalization()\n", - " \n", - " def call(self, x):\n", - " attn = self.mha(query=x, value=x,\n", - " use_causal_mask=True)\n", - " x = self.add([x, attn])\n", - " return self.layernorm(x)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8c66OTRwQfd8" - }, - "source": [ - "The `CrossAttention` layer is below. Note the use of `return_attention_scores`." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "rIY6Vu2pLBAO" - }, - "outputs": [], - "source": [ - "class CrossAttention(tf.keras.layers.Layer):\n", - " def __init__(self,**kwargs):\n", - " super().__init__()\n", - " self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)\n", - " self.add = tf.keras.layers.Add() \n", - " self.layernorm = tf.keras.layers.LayerNormalization()\n", - " \n", - " def call(self, x, y, **kwargs):\n", - " attn, attention_scores = self.mha(\n", - " query=x, value=y,\n", - " return_attention_scores=True)\n", - " \n", - " self.last_attention_scores = attention_scores\n", - "\n", - " x = self.add([x, attn])\n", - " return self.layernorm(x)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8Hn5p6f-RE0C" - }, - "source": [ - "The `FeedForward` layer is below. Remember that a `layers.Dense` layer is applied to the last axis of the input. The input will have a shape of `(batch, sequence, channels)`, so it automatically applies pointwise across the `batch` and `sequence` axes. " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cWKrl7teOnH2" - }, - "outputs": [], - "source": [ - "class FeedForward(tf.keras.layers.Layer):\n", - " def __init__(self, units, dropout_rate=0.1):\n", - " super().__init__()\n", - " self.seq = tf.keras.Sequential([\n", - " tf.keras.layers.Dense(units=2*units, activation='relu'),\n", - " tf.keras.layers.Dense(units=units),\n", - " tf.keras.layers.Dropout(rate=dropout_rate),\n", - " ])\n", - "\n", - " self.layernorm = tf.keras.layers.LayerNormalization()\n", - " \n", - " def call(self, x):\n", - " x = x + self.seq(x)\n", - " return self.layernorm(x)\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lbXoiVNPRoJc" - }, - "source": [ - "Next arrange these three layers into a larger `DecoderLayer`. Each decoder layer applies the three smaller layers in sequence. After each sublayer the shape of `out_seq` is `(batch, sequence, channels)`. The decoder layer also returns the `attention_scores` for later visualizations." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ydcW5KZZHou7" - }, - "outputs": [], - "source": [ - "class DecoderLayer(tf.keras.layers.Layer):\n", - " def __init__(self, units, num_heads=1, dropout_rate=0.1):\n", - " super().__init__()\n", - " \n", - " self.self_attention = CausalSelfAttention(num_heads=num_heads,\n", - " key_dim=units,\n", - " dropout=dropout_rate)\n", - " self.cross_attention = CrossAttention(num_heads=num_heads,\n", - " key_dim=units,\n", - " dropout=dropout_rate)\n", - " self.ff = FeedForward(units=units, dropout_rate=dropout_rate)\n", - " \n", - "\n", - " def call(self, inputs, training=False):\n", - " in_seq, out_seq = inputs\n", - "\n", - " # Text input\n", - " out_seq = self.self_attention(out_seq)\n", - "\n", - " out_seq = self.cross_attention(out_seq, in_seq)\n", - " \n", - " self.last_attention_scores = self.cross_attention.last_attention_scores\n", - "\n", - " out_seq = self.ff(out_seq)\n", - "\n", - " return out_seq" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-lgbYrF5Csqu" - }, - "source": [ - "### Output" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "VcnKZkrklAQf" - }, - "source": [ - "At minimum the output layer needs a `layers.Dense` layer to generate logit-predictions for each token at each location." 
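For reference, a bare-bones version of that head would be a single `Dense` layer (a sketch only; the `TokenOutput` layer defined below is what this tutorial actually uses):

```python
# Hypothetical minimal output head: maps decoder features of shape
# (batch, sequence, units) to logits of shape (batch, sequence, vocab_size).
plain_output_layer = tf.keras.layers.Dense(units=tokenizer.vocabulary_size())
```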
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6WQD87efena5" - }, - "source": [ - "But there are a few other features you can add to make this work a little better:\n", - "\n", - "1. **Handle bad tokens**: The model will be generating text. It should\n", - " never generate a pad, unknown, or start token (`''`, `'[UNK]'`, \n", - " `'[START]'`). So set the bias for these to a large negative value.\n", - "\n", - " \u003e Note: You'll need to ignore these tokens in the loss function as well. \n", - "\n", - "2. **Smart initialization**: The default initialization of a dense layer will\n", - " give a model that initially predicts each token with almost uniform\n", - " likelihood. The actual token distribution is far from uniform. The\n", - " optimal value for the initial bias of the output layer is the log of the\n", - " probability of each token. So include an `adapt` method to count the tokens\n", - " and set the optimal initial bias. This reduces the initial loss from the\n", - " entropy of the uniform distribution (`log(vocabulary_size)`) to the marginal\n", - " entropy of the distribution (`-p*log(p)`).\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "CeWw2SFDHUfo" - }, - "outputs": [], - "source": [ - "#@title\n", - "class TokenOutput(tf.keras.layers.Layer):\n", - " def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), **kwargs):\n", - " super().__init__()\n", - " \n", - " self.dense = tf.keras.layers.Dense(\n", - " units=tokenizer.vocabulary_size(), **kwargs)\n", - " self.tokenizer = tokenizer\n", - " self.banned_tokens = banned_tokens\n", - "\n", - " self.bias = None\n", - "\n", - " def adapt(self, ds):\n", - " counts = collections.Counter()\n", - " vocab_dict = {name: id \n", - " for id, name in enumerate(self.tokenizer.get_vocabulary())}\n", - "\n", - " for tokens in tqdm.tqdm(ds):\n", - " counts.update(tokens.numpy().flatten())\n", - "\n", - " counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))\n", - " counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())\n", - "\n", - " counts_arr = counts_arr[:]\n", - " for token in self.banned_tokens:\n", - " counts_arr[vocab_dict[token]] = 0\n", - "\n", - " total = counts_arr.sum()\n", - " p = counts_arr/total\n", - " p[counts_arr==0] = 1.0\n", - " log_p = np.log(p) # log(1) == 0\n", - "\n", - " entropy = -(log_p*p).sum()\n", - "\n", - " print()\n", - " print(f\"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}\")\n", - " print(f\"Marginal entropy: {entropy:0.2f}\")\n", - "\n", - " self.bias = log_p\n", - " self.bias[counts_arr==0] = -1e9\n", - "\n", - " def call(self, x):\n", - " x = self.dense(x)\n", - " # TODO(b/250038731): Fix this.\n", - " # An Add layer doesn't work because of the different shapes.\n", - " # This clears the mask, that's okay because it prevents keras from rescaling\n", - " # the losses.\n", - " return x + self.bias\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xzQHqANd1A6Q" - }, - "source": [ - "The smart initialization will significantly reduce the initial loss:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GGnOQyc501B2" - }, - "outputs": [], - "source": [ - "output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))\n", - "# This might run a little faster if the dataset didn't also have to load the image data.\n", - "output_layer.adapt(train_ds.map(lambda inputs, labels: labels))" - ] - }, - { - "cell_type": 
"markdown", - "metadata": { - "id": "3gq-ICN7bD-u" - }, - "source": [ - "### Build the model" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "gou4fPH_SWgH" - }, - "source": [ - "To build the model, you need to combine several parts:\n", - "\n", - "1. The image `feature_extractor` and the text `tokenizer` and.\n", - "1. The `seq_embedding` layer, to convert batches of token-IDs to \n", - " vectors `(batch, sequence, channels)`.\n", - "3. The stack of `DecoderLayers` layers that will process the text and image data.\n", - "4. The `output_layer` which returns a pointwise prediction of what the next word should be." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bHCISYehH1f6" - }, - "outputs": [], - "source": [ - "class Captioner(tf.keras.Model):\n", - " @classmethod\n", - " def add_method(cls, fun):\n", - " setattr(cls, fun.__name__, fun)\n", - " return fun\n", - "\n", - " def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,\n", - " units=256, max_length=50, num_heads=1, dropout_rate=0.1):\n", - " super().__init__()\n", - " self.feature_extractor = feature_extractor\n", - " self.tokenizer = tokenizer\n", - " self.word_to_index = tf.keras.layers.StringLookup(\n", - " mask_token=\"\",\n", - " vocabulary=tokenizer.get_vocabulary())\n", - " self.index_to_word = tf.keras.layers.StringLookup(\n", - " mask_token=\"\",\n", - " vocabulary=tokenizer.get_vocabulary(),\n", - " invert=True) \n", - "\n", - " self.seq_embedding = SeqEmbedding(\n", - " vocab_size=tokenizer.vocabulary_size(),\n", - " depth=units,\n", - " max_length=max_length)\n", - "\n", - " self.decoder_layers = [\n", - " DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)\n", - " for n in range(num_layers)]\n", - "\n", - " self.output_layer = output_layer" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "YW390dOz9T-x" - }, - "source": [ - "When you call the model, for training, it receives an `image, txt` pair. To make this function more usable, be flexible about the input:\n", - "\n", - "* If the image has 3 channels run it through the feature_extractor. Otherwise assume that it has been already. Similarly\n", - "* If the text has dtype `tf.string` run it through the tokenizer.\n", - "\n", - "After that running the model is only a few steps:\n", - "\n", - "1. Flatten the extracted image features, so they can be input to the decoder layers.\n", - "2. Look up the token embeddings.\n", - "3. Run the stack of `DecoderLayer`s, on the image features and text embeddings.\n", - "4. 
Run the output layer to predict the next token at each position.\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "lPdb7I4h9Ulo" - }, - "outputs": [], - "source": [ - " @Captioner.add_method\n", - " def call(self, inputs):\n", - " image, txt = inputs\n", - "\n", - " if image.shape[-1] == 3:\n", - " # Apply the feature-extractor, if you get an RGB image.\n", - " image = self.feature_extractor(image)\n", - " \n", - " # Flatten the feature map\n", - " image = einops.rearrange(image, 'b h w c -\u003e b (h w) c')\n", - "\n", - "\n", - " if txt.dtype == tf.string:\n", - " # Apply the tokenizer if you get string inputs.\n", - " txt = tokenizer(txt)\n", - "\n", - " txt = self.seq_embedding(txt)\n", - "\n", - " # Look at the image\n", - " for dec_layer in self.decoder_layers:\n", - " txt = dec_layer(inputs=(image, txt))\n", - " \n", - " txt = self.output_layer(txt)\n", - "\n", - " return txt" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kmM7aZQsLiyU" - }, - "outputs": [], - "source": [ - "model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,\n", - " units=256, dropout_rate=0.5, num_layers=2, num_heads=2)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "xGvOcLQKghXN" - }, - "source": [ - "### Generate captions\n", - "\n", - "Before getting into training, write a bit of code to generate captions. You'll use this to see how training is progressing.\n", - "\n", - "Start by downloading a test image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "cwFcdMqC-jE2" - }, - "outputs": [], - "source": [ - "image_url = 'https://tensorflow.org/images/surf.jpg'\n", - "image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)\n", - "image = load_image(image_path)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IRBIiTkubmxA" - }, - "source": [ - "To caption an image with this model:\n", - "\n", - "- Extract the `img_features`\n", - "- Initialize the list of output tokens with a `[START]` token.\n", - "- Pass `img_features` and `tokens` into the model.\n", - " - It returns a list of logits.\n", - " - Choose the next token based on those logits. 
\n", - " - Add it to the list of tokens, and continue the loop.\n", - " - If it generates an `'[END]'` token, break out of the loop.\n", - "\n", - "So add a \"simple\" method to do just that:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Nf1Jie9ef_Cg" - }, - "outputs": [], - "source": [ - "@Captioner.add_method\n", - "def simple_gen(self, image, temperature=1):\n", - " initial = self.word_to_index([['[START]']]) # (batch, sequence)\n", - " img_features = self.feature_extractor(image[tf.newaxis, ...])\n", - "\n", - " tokens = initial # (batch, sequence)\n", - " for n in range(50):\n", - " preds = self((img_features, tokens)).numpy() # (batch, sequence, vocab)\n", - " preds = preds[:,-1, :] #(batch, vocab)\n", - " if temperature==0:\n", - " next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)\n", - " else:\n", - " next = tf.random.categorical(preds/temperature, num_samples=1) # (batch, 1)\n", - " tokens = tf.concat([tokens, next], axis=1) # (batch, sequence) \n", - "\n", - " if next[0] == self.word_to_index('[END]'):\n", - " break\n", - " words = index_to_word(tokens[0, 1:-1])\n", - " result = tf.strings.reduce_join(words, axis=-1, separator=' ')\n", - " return result.numpy().decode()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TxN2NPX2zB8y" - }, - "source": [ - "Here are some generated captions for that image, the model's untrained, so they don't make much sense yet:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sPm96CccvHnq" - }, - "outputs": [], - "source": [ - "for t in (0.0, 0.5, 1.0):\n", - " result = model.simple_gen(image, temperature=t)\n", - " print(result)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "JefwCRZ8z-Ah" - }, - "source": [ - "The temperature parameter allows you to interpolate between 3 modes:\n", - "\n", - "1. Greedy decoding (`temperature=0.0`) - Chooses the most likely next token at each step.\n", - "2. Random sampling according to the logits (`temperature=1.0`).\n", - "3. Uniform random sampling (`temperature \u003e\u003e 1.0`). \n", - "\n", - "Since the model is untrained, and it used the frequency-based initialization, the \"greedy\" output (first) usually only contains the most common tokens: `['a', '.', '[END]']`." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "r0FpTvaPkqON" - }, - "source": [ - "## Train" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IKcwZdqObK-U" - }, - "source": [ - "To train the model you'll need several additional components:\n", - "\n", - "- The Loss and metrics\n", - "- The Optimizer\n", - "- Optional Callbacks" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "g5IW2mWa2sAG" - }, - "source": [ - "### Losses and metrics" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "XbpbDQTw1lOW" - }, - "source": [ - "Here's an implementation of a masked loss and accuracy:\n", - "\n", - "When calculating the mask for the loss, note the `loss \u003c 1e8`. This term discards the artificial, impossibly high losses for the `banned_tokens`." 
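To see why that cutoff works, recall that `TokenOutput` adds a bias of roughly `-1e9` to the banned tokens, so any label that points at one of them (for example an out-of-vocabulary word mapped to `[UNK]`) produces a cross-entropy on the order of `1e9`. A rough, stand-alone illustration with made-up logits:

```python
import tensorflow as tf

# Pretend token 2 is banned: its logit carries the -1e9 bias, so using it as a
# label yields a loss of about 1e9, far above the 1e8 threshold in the mask.
logits = tf.constant([[0.0, 0.0, -1e9]])
labels = tf.constant([2])
print(tf.nn.sparse_softmax_cross_entropy_with_logits(labels, logits).numpy())  # ~[1.e+09]
```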
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "s24im3FqxAfT" - }, - "outputs": [], - "source": [ - "def masked_loss(labels, preds): \n", - " loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)\n", - "\n", - " mask = (labels != 0) \u0026 (loss \u003c 1e8) \n", - " mask = tf.cast(mask, loss.dtype)\n", - "\n", - " loss = loss*mask\n", - " loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)\n", - " return loss\n", - "\n", - "def masked_acc(labels, preds):\n", - " mask = tf.cast(labels!=0, tf.float32)\n", - " preds = tf.argmax(preds, axis=-1)\n", - " labels = tf.cast(labels, tf.int64)\n", - " match = tf.cast(preds == labels, mask.dtype)\n", - " acc = tf.reduce_sum(match*mask)/tf.reduce_sum(mask)\n", - " return acc" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "zOhjHqgv3F2e" - }, - "source": [ - "### Callbacks" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3dyQN9UfJYEd" - }, - "source": [ - "For feedback during training setup a `keras.callbacks.Callback` to generate some captions for the surfer image at the end of each epoch." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "IKDwbZOCZ-AP" - }, - "outputs": [], - "source": [ - "class GenerateText(tf.keras.callbacks.Callback):\n", - " def __init__(self):\n", - " image_url = 'https://tensorflow.org/images/surf.jpg'\n", - " image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)\n", - " self.image = load_image(image_path)\n", - "\n", - " def on_epoch_end(self, epochs=None, logs=None):\n", - " print()\n", - " print()\n", - " for t in (0.0, 0.5, 1.0):\n", - " result = self.model.simple_gen(self.image, temperature=t)\n", - " print(result)\n", - " print()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1yNA3_RAsdl0" - }, - "source": [ - "It generates three output strings, like the earlier example, like before the first is \"greedy\", choosing the argmax of the logits at each step." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "IGVLpzo13rcA" - }, - "outputs": [], - "source": [ - "g = GenerateText()\n", - "g.model = model\n", - "g.on_epoch_end(0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "MAxp4KZRKDk9" - }, - "source": [ - "Also use `callbacks.EarlyStopping` to terminate training when the model starts to overfit." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "MjzrwGZp23xx" - }, - "outputs": [], - "source": [ - "callbacks = [\n", - " GenerateText(),\n", - " tf.keras.callbacks.EarlyStopping(\n", - " patience=5, restore_best_weights=True)]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZBaJhQpcG8u0" - }, - "source": [ - "### Train" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WBXG0dCDKO55" - }, - "source": [ - "Configure and execute the training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2OR5ZpAII__u" - }, - "outputs": [], - "source": [ - "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n", - " loss=masked_loss,\n", - " metrics=[masked_acc])" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ro955bQ2KR0X" - }, - "source": [ - "For more frequent reporting, use the `Dataset.repeat()` method, and set the `steps_per_epoch` and `validation_steps` arguments to `Model.fit`. 
\n", - "\n", - "With this setup on `Flickr8k` a full pass over the dataset is 900+ batches, but below the reporting-epochs are 100 steps." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "3aB0baOVMZe9" - }, - "outputs": [], - "source": [ - "history = model.fit(\n", - " train_ds.repeat(),\n", - " steps_per_epoch=100,\n", - " validation_data=test_ds.repeat(),\n", - " validation_steps=20,\n", - " epochs=100,\n", - " callbacks=callbacks)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "P634LfVgw-eV" - }, - "source": [ - "Plot the loss and accuracy over the training run:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6Wn8KSkUw916" - }, - "outputs": [], - "source": [ - "plt.plot(history.history['loss'], label='loss')\n", - "plt.plot(history.history['val_loss'], label='val_loss')\n", - "plt.ylim([0, max(plt.ylim())])\n", - "plt.xlabel('Epoch #')\n", - "plt.ylabel('CE/token')\n", - "plt.legend()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "yZQ78b2Kxw-T" - }, - "outputs": [], - "source": [ - "plt.plot(history.history['masked_acc'], label='accuracy')\n", - "plt.plot(history.history['val_masked_acc'], label='val_accuracy')\n", - "plt.ylim([0, max(plt.ylim())])\n", - "plt.xlabel('Epoch #')\n", - "plt.ylabel('CE/token')\n", - "plt.legend()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "SQN1qT7KNqbL" - }, - "source": [ - "## Attention plots" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "E9XJaC2b2J23" - }, - "source": [ - "Now, using the trained model, run that `simple_gen` method on the image:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1UQPtNTb2eu3" - }, - "outputs": [], - "source": [ - "result = model.simple_gen(image, temperature=0.0)\n", - "result" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "7NXbmeLGN1bJ" - }, - "source": [ - "Split the output back into tokens:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "zHKOpm0w5Xto" - }, - "outputs": [], - "source": [ - "str_tokens = result.split()\n", - "str_tokens.append('[END]')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fE-AjuAV55Qo" - }, - "source": [ - "The `DecoderLayers` each cache the attention scores for their `CrossAttention` layer. The shape of each attention map is `(batch=1, heads, sequence, image)`:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "XZpyuQvq2q-B" - }, - "outputs": [], - "source": [ - "attn_maps = [layer.last_attention_scores for layer in model.decoder_layers]\n", - "[map.shape for map in attn_maps]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "T42ImsWv6oHG" - }, - "source": [ - "So stack the maps along the `batch` axis, then average over the `(batch, heads)` axes, while splitting the `image` axis back into `height, width`:\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ojwtvnkh6mS-" - }, - "outputs": [], - "source": [ - "attention_maps = tf.concat(attn_maps, axis=0)\n", - "attention_maps = einops.reduce(\n", - " attention_maps,\n", - " 'batch heads sequence (height width) -\u003e sequence height width',\n", - " height=7, width=7,\n", - " reduction='mean')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "4TM7rA3zGpJW" - }, - "source": [ - "Now you have a single attention map, for each sequence prediction. 
The values in each map should sum to `1.`" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "ASWmWerGCZp3" - }, - "outputs": [], - "source": [ - "einops.reduce(attention_maps, 'sequence height width -\u003e sequence', reduction='sum')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fv7XYGFUd-U7" - }, - "source": [ - "So here is where the model was focusing attention while generating each token of the output:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "fD_y7PD6RPGt" - }, - "outputs": [], - "source": [ - "def plot_attention_maps(image, str_tokens, attention_map):\n", - " fig = plt.figure(figsize=(16, 9))\n", - "\n", - " len_result = len(str_tokens)\n", - " \n", - " titles = []\n", - " for i in range(len_result):\n", - " map = attention_map[i]\n", - " grid_size = max(int(np.ceil(len_result/2)), 2)\n", - " ax = fig.add_subplot(3, grid_size, i+1)\n", - " titles.append(ax.set_title(str_tokens[i]))\n", - " img = ax.imshow(image)\n", - " ax.imshow(map, cmap='gray', alpha=0.6, extent=img.get_extent(),\n", - " clim=[0.0, np.max(map)])\n", - "\n", - " plt.tight_layout()" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "PI4NAAws9rvY" - }, - "outputs": [], - "source": [ - "plot_attention_maps(image/255, str_tokens, attention_maps)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "riTz0abQKMkV" - }, - "source": [ - "Now put that together into a more usable function:" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mktpfW-SKQIJ" - }, - "outputs": [], - "source": [ - "@Captioner.add_method\n", - "def run_and_show_attention(self, image, temperature=0.0):\n", - " result_txt = self.simple_gen(image, temperature)\n", - " str_tokens = result_txt.split()\n", - " str_tokens.append('[END]')\n", - "\n", - " attention_maps = [layer.last_attention_scores for layer in self.decoder_layers]\n", - " attention_maps = tf.concat(attention_maps, axis=0)\n", - " attention_maps = einops.reduce(\n", - " attention_maps,\n", - " 'batch heads sequence (height width) -\u003e sequence height width',\n", - " height=7, width=7,\n", - " reduction='mean')\n", - " \n", - " plot_attention_maps(image/255, str_tokens, attention_maps)\n", - " t = plt.suptitle(result_txt)\n", - " t.set_y(1.05)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "FntRkY11OiMw" - }, - "outputs": [], - "source": [ - "run_and_show_attention(model, image)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Rprk3HEvZuxb" - }, - "source": [ - "## Try it on your own images\n", - "\n", - "For fun, below you're provided a method you can use to caption your own images with the model you've just trained. 
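For example, captioning a photo that is already on disk might look like this (the file path below is a made-up placeholder, and the image must be a JPEG for `load_image`):

```python
# Hypothetical usage with your own local image.
my_image = load_image('path/to/my_photo.jpg')
print(model.simple_gen(my_image, temperature=0.0))
model.run_and_show_attention(my_image)
```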
Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for strange results!)\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9Psd1quzaAWg" - }, - "outputs": [], - "source": [ - "image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'\n", - "image_path = tf.keras.utils.get_file(origin=image_url)\n", - "image = load_image(image_path)\n", - "\n", - "run_and_show_attention(model, image)" - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "image_captioning.ipynb", - "toc_visible": true - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "K2s1A9eLRPEj" + }, + "source": [ + "##### Copyright 2018 The TensorFlow Authors.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "VRLVEKiTEn04" + }, + "outputs": [], + "source": [ + "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EFwSaNB8jF7s" + }, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Cffg2i257iMS" + }, + "source": [ + "# Image captioning with visual attention\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + "
\n", + " View on TensorFlow.org\n", + " \n", + " Run in Google Colab\n", + " \n", + " View on GitHub\n", + " \n", + " Download notebook\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "QASbY_HGo4Lq" + }, + "source": [ + "Given an image like the example below, your goal is to generate a\n", + "caption such as \"a surfer riding on a wave\".\n", + "\n", + "\n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + "
\n", + " \n", + "
A man surfing, from wikimedia
\n", + "\n", + "The model architecture used here is inspired by [Show, Attend and Tell: Neural Image Caption Generation with Visual Attention](https://arxiv.org/abs/1502.03044), but has been updated to use a 2-layer Transformer-decoder. To get the most out of this tutorial you should have some experience with [text generation](https://www.tensorflow.org/text/tutorials/text_generation), [seq2seq models & attention](https://www.tensorflow.org/text/tutorials/nmt_with_attention), or [transformers](https://www.tensorflow.org/text/tutorials/transformer)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6HbD8n0w7d3F" + }, + "source": [ + "The model architecture built in this tutorial is shown below. Features are extracted from the image, and passed to the cross-attention layers of the Transformer-decoder.\n", + "\n", + "\n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + "
The model architecture\n", + "\n", + "  *(architecture diagram: image features feed the cross-attention layers of the Transformer decoder)*\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1IxifZKT6vXQ" + }, + "source": [ + "The transformer decoder is mainly built from attention layers. It uses self-attention to process the sequence being generated, and it uses cross-attention to attend to the image.\n", + "\n", + "By inspecting the attention weights of the cross attention layers you will see what parts of the image the model is looking at as it generates words.\n", + "\n", + "![Prediction](https://tensorflow.org/images/imcap_prediction.png)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "87us2sLVdwME" + }, + "source": [ + "This notebook is an end-to-end example. When you run the notebook, it downloads a dataset, extracts and caches the image features, and trains a decoder model. It then uses the model to generate captions on new images." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5bwwk4uxRz6A" + }, + "source": [ + "## Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "gc06pTaBbl72" + }, + "outputs": [], + "source": [ + "!apt install --allow-change-held-packages libcudnn8=8.6.0.163-1+cuda11.8" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2R1hQGtZEi8Y" + }, + "outputs": [], + "source": [ + "!pip uninstall -y tensorflow estimator keras" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "5Xbt8BkPv8Ou" + }, + "outputs": [], + "source": [ + "!pip install -U tensorflow_text tensorflow tensorflow_datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7TGZmOuqMia9" + }, + "outputs": [], + "source": [ + "!pip install einops" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nQ6q39Vd-y-7" + }, + "source": [ + "This tutorial uses lots of imports, mostly for loading the dataset(s)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "cellView": "form", + "id": "U8l4RJ0XRPEm" + }, + "outputs": [], + "source": [ + "#@title\n", + "import concurrent.futures\n", + "import collections\n", + "import dataclasses\n", + "import hashlib\n", + "import itertools\n", + "import json\n", + "import math\n", + "import os\n", + "import pathlib\n", + "import random\n", + "import re\n", + "import string\n", + "import time\n", + "import urllib.request\n", + "\n", + "import einops\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "from PIL import Image\n", + "import requests\n", + "import tqdm\n", + "\n", + "import tensorflow as tf\n", + "import tensorflow_hub as hub\n", + "import tensorflow_text as text\n", + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Kl9qGnjWrv80" + }, + "source": [ + "## [Optional] Data handling\n", + "\n", + "This section downloads a captions dataset and prepares it for training. It tokenizes the input text, and caches the results of running all the images through a pretrained feature-extractor model. It's not critical to understand everything in this section.\n", + "\n", + "
\n", + " \n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "q5e_SigQFiWf" + }, + "source": [ + "### Choose a dataset\n", + "\n", + "This tutorial is set up to give a choice of datasets. Either [Flickr8k](https://www.ijcai.org/Proceedings/15/Papers/593.pdf) or a small slice of the [Conceptual Captions](https://ai.google.com/research/ConceptualCaptions/) dataset. These two are downloaded and converted from scratch, but it wouldn't be hard to convert the tutorial to use the caption datasets available in [TensorFlow Datasets](https://www.tensorflow.org/datasets): [Coco Captions](https://www.tensorflow.org/datasets/catalog/coco_captions) and the full [Conceptual Captions](https://www.tensorflow.org/datasets/community_catalog/huggingface/conceptual_captions).\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "wqGXX9Dc5c0v" + }, + "source": [ + "#### Flickr8k" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kaNy_l7tGuAZ" + }, + "outputs": [], + "source": [ + "def flickr8k(path='flickr8k'):\n", + " path = pathlib.Path(path)\n", + "\n", + " if len(list(path.rglob('*'))) < 16197:\n", + " tf.keras.utils.get_file(\n", + " origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip',\n", + " cache_dir='.',\n", + " cache_subdir=path,\n", + " extract=True)\n", + " tf.keras.utils.get_file(\n", + " origin='https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip',\n", + " cache_dir='.',\n", + " cache_subdir=path,\n", + " extract=True)\n", + " \n", + " captions = (path/\"Flickr8k.token.txt\").read_text().splitlines()\n", + " captions = (line.split('\\t') for line in captions)\n", + " captions = ((fname.split('#')[0], caption) for (fname, caption) in captions)\n", + "\n", + " cap_dict = collections.defaultdict(list)\n", + " for fname, cap in captions:\n", + " cap_dict[fname].append(cap)\n", + "\n", + " train_files = (path/'Flickr_8k.trainImages.txt').read_text().splitlines()\n", + " train_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in train_files]\n", + "\n", + " test_files = (path/'Flickr_8k.testImages.txt').read_text().splitlines()\n", + " test_captions = [(str(path/'Flicker8k_Dataset'/fname), cap_dict[fname]) for fname in test_files]\n", + "\n", + " train_ds = tf.data.experimental.from_list(train_captions)\n", + " test_ds = tf.data.experimental.from_list(test_captions)\n", + "\n", + " return train_ds, test_ds" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zQICBAF4FmSL" + }, + "source": [ + "#### Conceptual Captions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vQwnxXZXRl12" + }, + "outputs": [], + "source": [ + "def conceptual_captions(*, data_dir=\"conceptual_captions\", num_train, num_val):\n", + " def iter_index(index_path):\n", + " with open(index_path) as f:\n", + " for line in f:\n", + " caption, url = line.strip().split('\\t')\n", + " yield caption, url\n", + "\n", + " def download_image_urls(data_dir, urls):\n", + " ex = concurrent.futures.ThreadPoolExecutor(max_workers=100)\n", + " def save_image(url):\n", + " hash = hashlib.sha1(url.encode())\n", + " # Name the files after the hash of the URL.\n", + " file_path = data_dir/f'{hash.hexdigest()}.jpeg'\n", + " if file_path.exists():\n", + " # Only download each file once.\n", + " return file_path\n", + "\n", + " try:\n", + " result = requests.get(url, timeout=5)\n", + " except Exception:\n", + " file_path = None\n", + " 
else:\n", + " file_path.write_bytes(result.content)\n", + " return file_path\n", + " \n", + " result = []\n", + " out_paths = ex.map(save_image, urls)\n", + " for file_path in tqdm.tqdm(out_paths, total=len(urls)):\n", + " result.append(file_path)\n", + "\n", + " return result\n", + "\n", + " def ds_from_index_file(index_path, data_dir, count):\n", + " data_dir.mkdir(exist_ok=True)\n", + " index = list(itertools.islice(iter_index(index_path), count))\n", + " captions = [caption for caption, url in index]\n", + " urls = [url for caption, url in index]\n", + "\n", + " paths = download_image_urls(data_dir, urls)\n", + "\n", + " new_captions = []\n", + " new_paths = []\n", + " for cap, path in zip(captions, paths):\n", + " if path is None:\n", + " # Download failed, so skip this pair.\n", + " continue\n", + " new_captions.append(cap)\n", + " new_paths.append(path)\n", + " \n", + " new_paths = [str(p) for p in new_paths]\n", + "\n", + " ds = tf.data.Dataset.from_tensor_slices((new_paths, new_captions))\n", + " ds = ds.map(lambda path,cap: (path, cap[tf.newaxis])) # 1 caption per image\n", + " return ds\n", + "\n", + " data_dir = pathlib.Path(data_dir)\n", + " train_index_path = tf.keras.utils.get_file(\n", + " origin='https://storage.googleapis.com/gcc-data/Train/GCC-training.tsv',\n", + " cache_subdir=data_dir,\n", + " cache_dir='.')\n", + " \n", + " val_index_path = tf.keras.utils.get_file(\n", + " origin='https://storage.googleapis.com/gcc-data/Validation/GCC-1.1.0-Validation.tsv',\n", + " cache_subdir=data_dir,\n", + " cache_dir='.')\n", + " \n", + " train_raw = ds_from_index_file(train_index_path, data_dir=data_dir/'train', count=num_train)\n", + " test_raw = ds_from_index_file(val_index_path, data_dir=data_dir/'val', count=num_val)\n", + "\n", + " return train_raw, test_raw" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rBAagBw5p-TM" + }, + "source": [ + "#### Download the dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WFtTZaobquNr" + }, + "source": [ + "The Flickr8k is a good choice because it contains 5-captions per image, more data for a smaller download." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EJySPbzJ4Wxw" + }, + "outputs": [], + "source": [ + "choose = 'flickr8k'\n", + "\n", + "if choose == 'flickr8k':\n", + " train_raw, test_raw = flickr8k()\n", + "else:\n", + " train_raw, test_raw = conceptual_captions(num_train=10000, num_val=5000)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-UAc275FHxm8" + }, + "source": [ + "The loaders for both datasets above return `tf.data.Dataset`s containing `(image_path, captions)` pairs. The Flickr8k dataset contains 5 captions per image, while Conceptual Captions has 1:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sAQSps5F8RQI" + }, + "outputs": [], + "source": [ + "train_raw.element_spec" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xIa0ZaP4tBez" + }, + "outputs": [], + "source": [ + "for ex_path, ex_captions in train_raw.take(1):\n", + " print(ex_path)\n", + " print(ex_captions)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8cSW4u-ORPFQ" + }, + "source": [ + "### Image feature extractor\n", + "\n", + "You will use an image model (pretrained on imagenet) to extract the features from each image. 
The model was trained as an image classifier, but setting `include_top=False` returns the model without the final classification layer, so you can use the last layer of feature-maps: \n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IlUckK8Zfikv" + }, + "outputs": [], + "source": [ + "IMAGE_SHAPE=(224, 224, 3)\n", + "mobilenet = tf.keras.applications.MobileNetV3Small(\n", + " input_shape=IMAGE_SHAPE,\n", + " include_top=False,\n", + " include_preprocessing=True)\n", + "mobilenet.trainable=False" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Dojkiou9gL3R" + }, + "source": [ + "Here's a function to load an image and resize it for the model:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zXR0217aRPFR" + }, + "outputs": [], + "source": [ + "def load_image(image_path):\n", + " img = tf.io.read_file(image_path)\n", + " img = tf.io.decode_jpeg(img, channels=3)\n", + " img = tf.image.resize(img, IMAGE_SHAPE[:-1])\n", + " return img" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-JyQ7zS6gzZh" + }, + "source": [ + "The model returns a feature map for each image in the input batch:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sY86n2i6wJNm" + }, + "outputs": [], + "source": [ + "test_img_batch = load_image(ex_path)[tf.newaxis, :]\n", + "\n", + "print(test_img_batch.shape)\n", + "print(mobilenet(test_img_batch).shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nyqH3zFwRPFi" + }, + "source": [ + "### Setup the text tokenizer/vectorizer\n", + "\n", + "You will transform the text captions into integer sequences using the [TextVectorization](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization) layer, with the following steps:\n", + "\n", + "* Use [adapt](https://www.tensorflow.org/api_docs/python/tf/keras/layers/TextVectorization#adapt) to iterate over all captions, split the captions into words, and compute a vocabulary of the top words.\n", + "* Tokenize all captions by mapping each word to its index in the vocabulary. All output sequences will be padded to length 50.\n", + "* Create word-to-index and index-to-word mappings to display results." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NroZIzB90hD3" + }, + "outputs": [], + "source": [ + "def standardize(s):\n", + " s = tf.strings.lower(s)\n", + " s = tf.strings.regex_replace(s, f'[{re.escape(string.punctuation)}]', '')\n", + " s = tf.strings.join(['[START]', s, '[END]'], separator=' ')\n", + " return s" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "n9SQOXFsyS36" + }, + "outputs": [], + "source": [ + "# Use the top 5000 words for a vocabulary.\n", + "vocabulary_size = 5000\n", + "tokenizer = tf.keras.layers.TextVectorization(\n", + " max_tokens=vocabulary_size,\n", + " standardize=standardize,\n", + " ragged=True)\n", + "# Learn the vocabulary from the caption data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oJGE34aiRPFo" + }, + "outputs": [], + "source": [ + "tokenizer.adapt(train_raw.map(lambda fp,txt: txt).unbatch().batch(1024))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oRahTDtWhJIf" + }, + "outputs": [], + "source": [ + "tokenizer.get_vocabulary()[:10]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-2mGxD33JCxN" + }, + "outputs": [], + "source": [ + "t = tokenizer([['a cat in a hat'], ['a robot dog']])\n", + "t" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8Q44tNQVRPFt" + }, + "outputs": [], + "source": [ + "# Create mappings for words to indices and indices to words.\n", + "word_to_index = tf.keras.layers.StringLookup(\n", + " mask_token=\"\",\n", + " vocabulary=tokenizer.get_vocabulary())\n", + "index_to_word = tf.keras.layers.StringLookup(\n", + " mask_token=\"\",\n", + " vocabulary=tokenizer.get_vocabulary(),\n", + " invert=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qo-cfCX3LnHs" + }, + "outputs": [], + "source": [ + "w = index_to_word(t)\n", + "w.to_list()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rrUUfGc65vAT" + }, + "outputs": [], + "source": [ + "tf.strings.reduce_join(w, separator=' ', axis=-1).numpy()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "uEWM9xrYcg45" + }, + "source": [ + "### Prepare the datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6aX0Z_98S2tN" + }, + "source": [ + "The `train_raw` and `test_raw` datasets contain 1:many `(image, captions)` pairs. \n", + "\n", + "This function will replicate the image so there are 1:1 images to captions:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3_Lqwl9NiGT0" + }, + "outputs": [], + "source": [ + "def match_shapes(images, captions):\n", + " caption_shape = einops.parse_shape(captions, 'b c')\n", + " captions = einops.rearrange(captions, 'b c -> (b c)')\n", + " images = einops.repeat(\n", + " images, 'b ... -> (b c) ...',\n", + " c = caption_shape['c'])\n", + " return images, captions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CZGUsuGzUfzt" + }, + "outputs": [], + "source": [ + "for ex_paths, ex_captions in train_raw.batch(32).take(1):\n", + " break\n", + "\n", + "print('image paths:', ex_paths.shape)\n", + "print('captions:', ex_captions.shape)\n", + "print()\n", + "\n", + "ex_paths, ex_captions = match_shapes(images=ex_paths, captions=ex_captions)\n", + "\n", + "print('image_paths:', ex_paths.shape)\n", + "print('captions:', ex_captions.shape)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8ENR_-swVhnm" + }, + "source": [ + "To be compatible with keras training the dataset should contain `(inputs, labels)` pairs. For text generation the tokens are both an input and the labels, shifted by one step. 
This function will convert an `(images, texts)` pair to an `((images, input_tokens), label_tokens)` pair:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2DsgQ_hZT4C2" + }, + "outputs": [], + "source": [ + "def prepare_txt(imgs, txts):\n", + " tokens = tokenizer(txts)\n", + "\n", + " input_tokens = tokens[..., :-1]\n", + " label_tokens = tokens[..., 1:]\n", + " return (imgs, input_tokens), label_tokens" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DA1x2j0JXX-N" + }, + "source": [ + "This function adds operations to a dataset. The steps are:\n", + "\n", + "1. Load the images (and ignore images that fail to load).\n", + "2. Replicate images to match the number of captions.\n", + "3. Shuffle and rebatch the `image, caption` pairs.\n", + "4. Tokenize the text, shift the tokens and add `label_tokens`.\n", + "5. Convert the text from a `RaggedTensor` representation to padded dense `Tensor` representation." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4_Pt9zldjQ0q" + }, + "outputs": [], + "source": [ + "def prepare_dataset(ds, tokenizer, batch_size=32, shuffle_buffer=1000):\n", + " # Load the images and make batches.\n", + " ds = (ds\n", + " .shuffle(10000)\n", + " .map(lambda path, caption: (load_image(path), caption))\n", + " .apply(tf.data.experimental.ignore_errors())\n", + " .batch(batch_size))\n", + "\n", + " def to_tensor(inputs, labels):\n", + " (images, in_tok), out_tok = inputs, labels\n", + " return (images, in_tok.to_tensor()), out_tok.to_tensor()\n", + "\n", + " return (ds\n", + " .map(match_shapes, tf.data.AUTOTUNE)\n", + " .unbatch()\n", + " .shuffle(shuffle_buffer)\n", + " .batch(batch_size)\n", + " .map(prepare_txt, tf.data.AUTOTUNE)\n", + " .map(to_tensor, tf.data.AUTOTUNE)\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LrQ85t1GNfpQ" + }, + "source": [ + "You could install the feature extractor in your model and train on the datasets like this:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1KlhOG5cjQ0r" + }, + "outputs": [], + "source": [ + "train_ds = prepare_dataset(train_raw, tokenizer)\n", + "train_ds.element_spec" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "d7Zy9F3zX7i2" + }, + "outputs": [], + "source": [ + "test_ds = prepare_dataset(test_raw, tokenizer)\n", + "test_ds.element_spec" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XZyKygJ8S8zW" + }, + "source": [ + "### [Optional] Cache the image features" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eHKhSKhti6NS" + }, + "source": [ + "Since the image feature extractor is not changing, and this tutorial is not using image augmentation, the image features can be cached. Same for the text tokenization. The time it takes to set up the cache is earned back on each epoch during training and validation. 
The code below defines two functions `save_dataset` and `load_dataset`: " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9N1MX5ym6xm5" + }, + "outputs": [], + "source": [ + "def save_dataset(ds, save_path, image_model, tokenizer, shards=10, batch_size=32):\n", + " # Load the images and make batches.\n", + " ds = (ds\n", + " .map(lambda path, caption: (load_image(path), caption))\n", + " .apply(tf.data.experimental.ignore_errors())\n", + " .batch(batch_size))\n", + "\n", + " # Run the feature extractor on each batch\n", + " # Don't do this in a .map, because tf.data runs on the CPU. \n", + " def gen():\n", + " for (images, captions) in tqdm.tqdm(ds): \n", + " feature_maps = image_model(images)\n", + "\n", + " feature_maps, captions = match_shapes(feature_maps, captions)\n", + " yield feature_maps, captions\n", + "\n", + " # Wrap the generator in a new tf.data.Dataset.\n", + " new_ds = tf.data.Dataset.from_generator(\n", + " gen,\n", + " output_signature=(\n", + " tf.TensorSpec(shape=image_model.output_shape),\n", + " tf.TensorSpec(shape=(None,), dtype=tf.string)))\n", + "\n", + " # Apply the tokenization \n", + " new_ds = (new_ds\n", + " .map(prepare_txt, tf.data.AUTOTUNE)\n", + " .unbatch()\n", + " .shuffle(1000))\n", + "\n", + " # Save the dataset into shard files.\n", + " def shard_func(i, item):\n", + " return i % shards\n", + " new_ds.enumerate().save(save_path, shard_func=shard_func)\n", + "\n", + "def load_dataset(save_path, batch_size=32, shuffle=1000, cycle_length=2):\n", + " def custom_reader_func(datasets):\n", + " datasets = datasets.shuffle(1000)\n", + " return datasets.interleave(lambda x: x, cycle_length=cycle_length)\n", + " \n", + " ds = tf.data.Dataset.load(save_path, reader_func=custom_reader_func)\n", + "\n", + " def drop_index(i, x):\n", + " return x\n", + "\n", + " ds = (ds\n", + " .map(drop_index, tf.data.AUTOTUNE)\n", + " .shuffle(shuffle)\n", + " .padded_batch(batch_size)\n", + " .prefetch(tf.data.AUTOTUNE))\n", + " return ds" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tNdzrenxB3Yy" + }, + "outputs": [], + "source": [ + "save_dataset(train_raw, 'train_cache', mobilenet, tokenizer)\n", + "save_dataset(test_raw, 'test_cache', mobilenet, tokenizer)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "798DtfH51UI8" + }, + "source": [ + "
\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GI265LiDslr2" + }, + "source": [ + "## Data ready for training\n", + "\n", + "After those preprocessing steps, here are the datasets:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Pwic2YCjHZmV" + }, + "outputs": [], + "source": [ + "train_ds = load_dataset('train_cache')\n", + "test_ds = load_dataset('test_cache')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3B80JXj7HloX" + }, + "outputs": [], + "source": [ + "train_ds.element_spec" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "5jfb8qknlsKi" + }, + "source": [ + "The dataset now returns `(input, label)` pairs suitable for training with keras. The `inputs` are `(images, input_tokens)` pairs. The `images` have been processed with the feature-extractor model. For each location in the `input_tokens` the model looks at the text so far and tries to predict the next which is lined up at the same location in the `labels`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "YJBEwuXLZQdw" + }, + "outputs": [], + "source": [ + "for (inputs, ex_labels) in train_ds.take(1):\n", + " (ex_img, ex_in_tok) = inputs\n", + "\n", + "print(ex_img.shape)\n", + "print(ex_in_tok.shape)\n", + "print(ex_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "22R58DzZoF17" + }, + "source": [ + "The input tokens and the labels are the same, just shifted by 1 step:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V7h5UGftn1hT" + }, + "outputs": [], + "source": [ + "print(ex_in_tok[0].numpy())\n", + "print(ex_labels[0].numpy())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "DfICM49WFpIb" + }, + "source": [ + "## A Transformer decoder model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ONyjuWsmZoyO" + }, + "source": [ + "This model assumes that the pretrained image encoder is sufficient, and just focuses on building the text decoder. This tutorial uses a 2-layer Transformer-decoder.\n", + "\n", + "The implementations are almost identical to those in the [Transformers tutorial](https://www.tensorflow.org/text/tutorials/transformer). Refer back to it for more details.\n", + "\n", + "\n", + "\n", + " \n", + "\n", + "\n", + " \n", + "\n", + "
The Transformer encoder and decoder.\n", + "\n", + "  *(diagram)*\n", + "
" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qiRXWwIKNybB" + }, + "source": [ + "The model will be implemented in three main parts: \n", + "\n", + "1. Input - The token embedding and positional encoding (`SeqEmbedding`).\n", + "1. Decoder - A stack of transformer decoder layers (`DecoderLayer`) where each contains:\n", + " 1. A causal self attention later (`CausalSelfAttention`), where each output location can attend to the output so far.\n", + " 1. A cross attention layer (`CrossAttention`) where each output location can attend to the input image.\n", + " 1. A feed forward network (`FeedForward`) layer which further processes each output location independently.\n", + "1. Output - A multiclass-classification over the output vocabulary.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_ngm3SQMCaYU" + }, + "source": [ + "### Input" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "i9suaARZGPKw" + }, + "source": [ + "The input text has already been split up into tokens and converted to sequences of IDs. \n", + "\n", + "Remember that unlike a CNN or RNN the Transformer's attention layers are invariant to the order of the sequence. Without some positional input, it just sees an unordered set not a sequence. So in addition to a simple vector embedding for each token ID, the embedding layer will also include an embedding for each position in the sequence.\n", + "\n", + "The `SeqEmbedding` layer defined below:\n", + "\n", + "- It looks up the embedding vector for each token.\n", + "- It looks up an embedding vector for each sequence location.\n", + "- It adds the two together.\n", + "- It uses `mask_zero=True` to initialize the keras-masks for the model.\n", + "\n", + "Note: This implementation learns the position embeddings instead of using fixed embeddings like in the [Transformer tutorial](https://www.tensorflow.org/text/tutorials/transformer). Learning the embeddings is slightly less code, but doesn't generalize to longer sequences." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "P91LU2F0a9Ga" + }, + "outputs": [], + "source": [ + "class SeqEmbedding(tf.keras.layers.Layer):\n", + " def __init__(self, vocab_size, max_length, depth):\n", + " super().__init__()\n", + " self.pos_embedding = tf.keras.layers.Embedding(input_dim=max_length, output_dim=depth)\n", + "\n", + " self.token_embedding = tf.keras.layers.Embedding(\n", + " input_dim=vocab_size,\n", + " output_dim=depth,\n", + " mask_zero=True)\n", + " \n", + " self.add = tf.keras.layers.Add()\n", + "\n", + " def call(self, seq):\n", + " seq = self.token_embedding(seq) # (batch, seq, depth)\n", + "\n", + " x = tf.range(tf.shape(seq)[1]) # (seq)\n", + " x = x[tf.newaxis, :] # (1, seq)\n", + " x = self.pos_embedding(x) # (1, seq, depth)\n", + "\n", + " return self.add([seq,x])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "II1mD-bBCdMB" + }, + "source": [ + "### Decoder" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GHMLeMtKPTCW" + }, + "source": [ + "The decoder is a standard Transformer-decoder, it contains a stack of `DecoderLayers` where each contains three sublayers: a `CausalSelfAttention`, a `CrossAttention`, and a`FeedForward`. 
The implementations are almost identical to the [Transformer tutorial](https://www.tensorflow.org/text/tutorials/transformer), refer to it for more details.\n", + "\n", + "The `CausalSelfAttention` layer is below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6JTLiX3lKooQ" + }, + "outputs": [], + "source": [ + "class CausalSelfAttention(tf.keras.layers.Layer):\n", + " def __init__(self, **kwargs):\n", + " super().__init__()\n", + " self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)\n", + " # Use Add instead of + so the keras mask propagates through.\n", + " self.add = tf.keras.layers.Add() \n", + " self.layernorm = tf.keras.layers.LayerNormalization()\n", + " \n", + " def call(self, x):\n", + " attn = self.mha(query=x, value=x,\n", + " use_causal_mask=True)\n", + " x = self.add([x, attn])\n", + " return self.layernorm(x)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8c66OTRwQfd8" + }, + "source": [ + "The `CrossAttention` layer is below. Note the use of `return_attention_scores`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rIY6Vu2pLBAO" + }, + "outputs": [], + "source": [ + "class CrossAttention(tf.keras.layers.Layer):\n", + " def __init__(self,**kwargs):\n", + " super().__init__()\n", + " self.mha = tf.keras.layers.MultiHeadAttention(**kwargs)\n", + " self.add = tf.keras.layers.Add() \n", + " self.layernorm = tf.keras.layers.LayerNormalization()\n", + " \n", + " def call(self, x, y, **kwargs):\n", + " attn, attention_scores = self.mha(\n", + " query=x, value=y,\n", + " return_attention_scores=True)\n", + " \n", + " self.last_attention_scores = attention_scores\n", + "\n", + " x = self.add([x, attn])\n", + " return self.layernorm(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Hn5p6f-RE0C" + }, + "source": [ + "The `FeedForward` layer is below. Remember that a `layers.Dense` layer is applied to the last axis of the input. The input will have a shape of `(batch, sequence, channels)`, so it automatically applies pointwise across the `batch` and `sequence` axes. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cWKrl7teOnH2" + }, + "outputs": [], + "source": [ + "class FeedForward(tf.keras.layers.Layer):\n", + " def __init__(self, units, dropout_rate=0.1):\n", + " super().__init__()\n", + " self.seq = tf.keras.Sequential([\n", + " tf.keras.layers.Dense(units=2*units, activation='relu'),\n", + " tf.keras.layers.Dense(units=units),\n", + " tf.keras.layers.Dropout(rate=dropout_rate),\n", + " ])\n", + "\n", + " self.layernorm = tf.keras.layers.LayerNormalization()\n", + " \n", + " def call(self, x):\n", + " x = x + self.seq(x)\n", + " return self.layernorm(x)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "lbXoiVNPRoJc" + }, + "source": [ + "Next arrange these three layers into a larger `DecoderLayer`. Each decoder layer applies the three smaller layers in sequence. After each sublayer the shape of `out_seq` is `(batch, sequence, channels)`. The decoder layer also returns the `attention_scores` for later visualizations." 
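+    "\n",
+    "As a quick aside before assembling the `DecoderLayer`: here is a minimal, standalone sketch (toy tensors, illustrative only) of what `use_causal_mask=True` does inside the self-attention sublayer:\n",
+    "\n",
+    "```python\n",
+    "import numpy as np\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "mha = tf.keras.layers.MultiHeadAttention(num_heads=1, key_dim=4)\n",
+    "x = tf.random.normal([1, 5, 4])  # (batch, sequence, channels)\n",
+    "_, scores = mha(query=x, value=x,\n",
+    "                use_causal_mask=True,\n",
+    "                return_attention_scores=True)\n",
+    "\n",
+    "# scores has shape (batch, heads, query, key); with the causal mask each row\n",
+    "# only puts weight on the current and earlier positions (lower-triangular).\n",
+    "print(np.round(scores[0, 0].numpy(), 2))\n",
+    "```\n"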
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ydcW5KZZHou7" + }, + "outputs": [], + "source": [ + "class DecoderLayer(tf.keras.layers.Layer):\n", + " def __init__(self, units, num_heads=1, dropout_rate=0.1):\n", + " super().__init__()\n", + " \n", + " self.self_attention = CausalSelfAttention(num_heads=num_heads,\n", + " key_dim=units,\n", + " dropout=dropout_rate)\n", + " self.cross_attention = CrossAttention(num_heads=num_heads,\n", + " key_dim=units,\n", + " dropout=dropout_rate)\n", + " self.ff = FeedForward(units=units, dropout_rate=dropout_rate)\n", + " \n", + "\n", + " def call(self, inputs, training=False):\n", + " in_seq, out_seq = inputs\n", + "\n", + " # Text input\n", + " out_seq = self.self_attention(out_seq)\n", + "\n", + " out_seq = self.cross_attention(out_seq, in_seq)\n", + " \n", + " self.last_attention_scores = self.cross_attention.last_attention_scores\n", + "\n", + " out_seq = self.ff(out_seq)\n", + "\n", + " return out_seq" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-lgbYrF5Csqu" + }, + "source": [ + "### Output" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "VcnKZkrklAQf" + }, + "source": [ + "At minimum the output layer needs a `layers.Dense` layer to generate logit-predictions for each token at each location." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6WQD87efena5" + }, + "source": [ + "But there are a few other features you can add to make this work a little better:\n", + "\n", + "1. **Handle bad tokens**: The model will be generating text. It should\n", + " never generate a pad, unknown, or start token (`''`, `'[UNK]'`, \n", + " `'[START]'`). So set the bias for these to a large negative value.\n", + "\n", + " > Note: You'll need to ignore these tokens in the loss function as well. \n", + "\n", + "2. **Smart initialization**: The default initialization of a dense layer will\n", + " give a model that initially predicts each token with almost uniform\n", + " likelihood. The actual token distribution is far from uniform. The\n", + " optimal value for the initial bias of the output layer is the log of the\n", + " probability of each token. So include an `adapt` method to count the tokens\n", + " and set the optimal initial bias. 
This reduces the initial loss from the\n", + " entropy of the uniform distribution (`log(vocabulary_size)`) to the marginal\n", + " entropy of the distribution (`-p*log(p)`).\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "CeWw2SFDHUfo" + }, + "outputs": [], + "source": [ + "#@title\n", + "class TokenOutput(tf.keras.layers.Layer):\n", + " def __init__(self, tokenizer, banned_tokens=('', '[UNK]', '[START]'), **kwargs):\n", + " super().__init__()\n", + " \n", + " self.dense = tf.keras.layers.Dense(\n", + " units=tokenizer.vocabulary_size(), **kwargs)\n", + " self.tokenizer = tokenizer\n", + " self.banned_tokens = banned_tokens\n", + "\n", + " self.bias = None\n", + "\n", + " def adapt(self, ds):\n", + " counts = collections.Counter()\n", + " vocab_dict = {name: id \n", + " for id, name in enumerate(self.tokenizer.get_vocabulary())}\n", + "\n", + " for tokens in tqdm.tqdm(ds):\n", + " counts.update(tokens.numpy().flatten())\n", + "\n", + " counts_arr = np.zeros(shape=(self.tokenizer.vocabulary_size(),))\n", + " counts_arr[np.array(list(counts.keys()), dtype=np.int32)] = list(counts.values())\n", + "\n", + " counts_arr = counts_arr[:]\n", + " for token in self.banned_tokens:\n", + " counts_arr[vocab_dict[token]] = 0\n", + "\n", + " total = counts_arr.sum()\n", + " p = counts_arr/total\n", + " p[counts_arr==0] = 1.0\n", + " log_p = np.log(p) # log(1) == 0\n", + "\n", + " entropy = -(log_p*p).sum()\n", + "\n", + " print()\n", + " print(f\"Uniform entropy: {np.log(self.tokenizer.vocabulary_size()):0.2f}\")\n", + " print(f\"Marginal entropy: {entropy:0.2f}\")\n", + "\n", + " self.bias = log_p\n", + " self.bias[counts_arr==0] = -1e9\n", + "\n", + " def call(self, x):\n", + " x = self.dense(x)\n", + " # TODO(b/250038731): Fix this.\n", + " # An Add layer doesn't work because of the different shapes.\n", + " # This clears the mask, that's okay because it prevents keras from rescaling\n", + " # the losses.\n", + " return x + self.bias\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xzQHqANd1A6Q" + }, + "source": [ + "The smart initialization will significantly reduce the initial loss:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GGnOQyc501B2" + }, + "outputs": [], + "source": [ + "output_layer = TokenOutput(tokenizer, banned_tokens=('', '[UNK]', '[START]'))\n", + "# This might run a little faster if the dataset didn't also have to load the image data.\n", + "output_layer.adapt(train_ds.map(lambda inputs, labels: labels))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3gq-ICN7bD-u" + }, + "source": [ + "### Build the model" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "gou4fPH_SWgH" + }, + "source": [ + "To build the model, you need to combine several parts:\n", + "\n", + "1. The image `feature_extractor` and the text `tokenizer` and.\n", + "1. The `seq_embedding` layer, to convert batches of token-IDs to \n", + " vectors `(batch, sequence, channels)`.\n", + "3. The stack of `DecoderLayers` layers that will process the text and image data.\n", + "4. The `output_layer` which returns a pointwise prediction of what the next word should be." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bHCISYehH1f6" + }, + "outputs": [], + "source": [ + "class Captioner(tf.keras.Model):\n", + " @classmethod\n", + " def add_method(cls, fun):\n", + " setattr(cls, fun.__name__, fun)\n", + " return fun\n", + "\n", + " def __init__(self, tokenizer, feature_extractor, output_layer, num_layers=1,\n", + " units=256, max_length=50, num_heads=1, dropout_rate=0.1):\n", + " super().__init__()\n", + " self.feature_extractor = feature_extractor\n", + " self.tokenizer = tokenizer\n", + " self.word_to_index = tf.keras.layers.StringLookup(\n", + " mask_token=\"\",\n", + " vocabulary=tokenizer.get_vocabulary())\n", + " self.index_to_word = tf.keras.layers.StringLookup(\n", + " mask_token=\"\",\n", + " vocabulary=tokenizer.get_vocabulary(),\n", + " invert=True) \n", + "\n", + " self.seq_embedding = SeqEmbedding(\n", + " vocab_size=tokenizer.vocabulary_size(),\n", + " depth=units,\n", + " max_length=max_length)\n", + "\n", + " self.decoder_layers = [\n", + " DecoderLayer(units, num_heads=num_heads, dropout_rate=dropout_rate)\n", + " for n in range(num_layers)]\n", + "\n", + " self.output_layer = output_layer" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "YW390dOz9T-x" + }, + "source": [ + "When you call the model, for training, it receives an `image, txt` pair. To make this function more usable, be flexible about the input:\n", + "\n", + "* If the image has 3 channels run it through the feature_extractor. Otherwise assume that it has been already. Similarly\n", + "* If the text has dtype `tf.string` run it through the tokenizer.\n", + "\n", + "After that running the model is only a few steps:\n", + "\n", + "1. Flatten the extracted image features, so they can be input to the decoder layers.\n", + "2. Look up the token embeddings.\n", + "3. Run the stack of `DecoderLayer`s, on the image features and text embeddings.\n", + "4. Run the output layer to predict the next token at each position.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "lPdb7I4h9Ulo" + }, + "outputs": [], + "source": [ + " @Captioner.add_method\n", + " def call(self, inputs):\n", + " image, txt = inputs\n", + "\n", + " if image.shape[-1] == 3:\n", + " # Apply the feature-extractor, if you get an RGB image.\n", + " image = self.feature_extractor(image)\n", + " \n", + " # Flatten the feature map\n", + " image = einops.rearrange(image, 'b h w c -> b (h w) c')\n", + "\n", + "\n", + " if txt.dtype == tf.string:\n", + " # Apply the tokenizer if you get string inputs.\n", + " txt = tokenizer(txt)\n", + "\n", + " txt = self.seq_embedding(txt)\n", + "\n", + " # Look at the image\n", + " for dec_layer in self.decoder_layers:\n", + " txt = dec_layer(inputs=(image, txt))\n", + " \n", + " txt = self.output_layer(txt)\n", + "\n", + " return txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kmM7aZQsLiyU" + }, + "outputs": [], + "source": [ + "model = Captioner(tokenizer, feature_extractor=mobilenet, output_layer=output_layer,\n", + " units=576, dropout_rate=0.5, num_layers=2, num_heads=2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xGvOcLQKghXN" + }, + "source": [ + "### Generate captions\n", + "\n", + "Before getting into training, write a bit of code to generate captions. 
You'll use this to see how training is progressing.\n", + "\n", + "Start by downloading a test image:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "cwFcdMqC-jE2" + }, + "outputs": [], + "source": [ + "image_url = 'https://tensorflow.org/images/surf.jpg'\n", + "image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)\n", + "image = load_image(image_path)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IRBIiTkubmxA" + }, + "source": [ + "To caption an image with this model:\n", + "\n", + "- Extract the `img_features`\n", + "- Initialize the list of output tokens with a `[START]` token.\n", + "- Pass `img_features` and `tokens` into the model.\n", + " - It returns a list of logits.\n", + " - Choose the next token based on those logits. \n", + " - Add it to the list of tokens, and continue the loop.\n", + " - If it generates an `'[END]'` token, break out of the loop.\n", + "\n", + "So add a \"simple\" method to do just that:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Nf1Jie9ef_Cg" + }, + "outputs": [], + "source": [ + "@Captioner.add_method\n", + "def simple_gen(self, image, temperature=1):\n", + " initial = self.word_to_index([['[START]']]) # (batch, sequence)\n", + " img_features = self.feature_extractor(image[tf.newaxis, ...])\n", + "\n", + " tokens = initial # (batch, sequence)\n", + " for n in range(50):\n", + " preds = self((img_features, tokens)).numpy() # (batch, sequence, vocab)\n", + " preds = preds[:,-1, :] #(batch, vocab)\n", + " if temperature==0:\n", + " next = tf.argmax(preds, axis=-1)[:, tf.newaxis] # (batch, 1)\n", + " else:\n", + " next = tf.random.categorical(preds/temperature, num_samples=1) # (batch, 1)\n", + " tokens = tf.concat([tokens, next], axis=1) # (batch, sequence) \n", + "\n", + " if next[0] == self.word_to_index('[END]'):\n", + " break\n", + " words = index_to_word(tokens[0, 1:-1])\n", + " result = tf.strings.reduce_join(words, axis=-1, separator=' ')\n", + " return result.numpy().decode()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TxN2NPX2zB8y" + }, + "source": [ + "Here are some generated captions for that image, the model's untrained, so they don't make much sense yet:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sPm96CccvHnq" + }, + "outputs": [], + "source": [ + "for t in (0.0, 0.5, 1.0):\n", + " result = model.simple_gen(image, temperature=t)\n", + " print(result)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "JefwCRZ8z-Ah" + }, + "source": [ + "The temperature parameter allows you to interpolate between 3 modes:\n", + "\n", + "1. Greedy decoding (`temperature=0.0`) - Chooses the most likely next token at each step.\n", + "2. Random sampling according to the logits (`temperature=1.0`).\n", + "3. Uniform random sampling (`temperature >> 1.0`). \n", + "\n", + "Since the model is untrained, and it used the frequency-based initialization, the \"greedy\" output (first) usually only contains the most common tokens: `['a', '.', '[END]']`." 
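+    "\n",
+    "To make the effect of `temperature` concrete, here is a small, self-contained sketch (toy logits, not taken from the model) of how dividing the logits by the temperature changes what gets sampled:\n",
+    "\n",
+    "```python\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "logits = tf.constant([[2.0, 1.0, 0.5, -1.0]])  # toy logits over a 4-token vocabulary\n",
+    "\n",
+    "# temperature == 0: greedy decoding, always the argmax.\n",
+    "print(tf.argmax(logits, axis=-1).numpy())\n",
+    "\n",
+    "# temperature > 0: sample; small values sharpen the distribution,\n",
+    "# large values flatten it towards uniform.\n",
+    "for t in (0.5, 1.0, 10.0):\n",
+    "  samples = tf.random.categorical(logits / t, num_samples=10)\n",
+    "  print(t, samples.numpy()[0])\n",
+    "```\n"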
+ ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "r0FpTvaPkqON" + }, + "source": [ + "## Train" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IKcwZdqObK-U" + }, + "source": [ + "To train the model you'll need several additional components:\n", + "\n", + "- The Loss and metrics\n", + "- The Optimizer\n", + "- Optional Callbacks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "g5IW2mWa2sAG" + }, + "source": [ + "### Losses and metrics" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "XbpbDQTw1lOW" + }, + "source": [ + "Here's an implementation of a masked loss and accuracy:\n", + "\n", + "When calculating the mask for the loss, note the `loss < 1e8`. This term discards the artificial, impossibly high losses for the `banned_tokens`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "s24im3FqxAfT" + }, + "outputs": [], + "source": [ + "def masked_loss(labels, preds):\n", + " labels = tf.cast(labels, tf.int64) \n", + " loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels, preds)\n", + "\n", + " mask = (labels != 0) & (loss < 1e8) \n", + " mask = tf.cast(mask, loss.dtype)\n", + "\n", + " loss = loss*mask\n", + " loss = tf.reduce_sum(loss)/tf.reduce_sum(mask)\n", + " return loss\n", + "\n", + "def masked_acc(labels, preds):\n", + " mask = tf.cast(labels!=0, tf.float32)\n", + " preds = tf.argmax(preds, axis=-1)\n", + " labels = tf.cast(labels, tf.int64)\n", + " match = tf.cast(preds == labels, mask.dtype)\n", + " acc = tf.reduce_sum(match*mask)/tf.reduce_sum(mask)\n", + " return acc" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "zOhjHqgv3F2e" + }, + "source": [ + "### Callbacks" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3dyQN9UfJYEd" + }, + "source": [ + "For feedback during training setup a `keras.callbacks.Callback` to generate some captions for the surfer image at the end of each epoch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IKDwbZOCZ-AP" + }, + "outputs": [], + "source": [ + "class GenerateText(tf.keras.callbacks.Callback):\n", + " def __init__(self):\n", + " image_url = 'https://tensorflow.org/images/surf.jpg'\n", + " image_path = tf.keras.utils.get_file('surf.jpg', origin=image_url)\n", + " self.image = load_image(image_path)\n", + "\n", + " def on_epoch_end(self, epochs=None, logs=None):\n", + " print()\n", + " print()\n", + " for t in (0.0, 0.5, 1.0):\n", + " result = self.model.simple_gen(self.image, temperature=t)\n", + " print(result)\n", + " print()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1yNA3_RAsdl0" + }, + "source": [ + "It generates three output strings, like the earlier example, like before the first is \"greedy\", choosing the argmax of the logits at each step." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "IGVLpzo13rcA" + }, + "outputs": [], + "source": [ + "g = GenerateText()\n", + "g.set_model(model)\n", + "g.on_epoch_end(0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "MAxp4KZRKDk9" + }, + "source": [ + "Also use `callbacks.EarlyStopping` to terminate training when the model starts to overfit." 
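+    "\n",
+    "(Optional) Before wiring up the callbacks, you can sanity-check the `masked_loss` and `masked_acc` defined above on made-up tensors (illustrative only):\n",
+    "\n",
+    "```python\n",
+    "import tensorflow as tf\n",
+    "\n",
+    "# A fake batch: two real tokens followed by padding (0), random logits over 5 classes.\n",
+    "labels = tf.constant([[2, 3, 0, 0]], dtype=tf.int64)\n",
+    "logits = tf.random.normal([1, 4, 5])\n",
+    "\n",
+    "print(masked_loss(labels, logits).numpy())  # roughly log(5) for random logits\n",
+    "print(masked_acc(labels, logits).numpy())   # only the 2 non-padding positions count\n",
+    "```\n"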
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "MjzrwGZp23xx" + }, + "outputs": [], + "source": [ + "callbacks = [\n", + " GenerateText(),\n", + " tf.keras.callbacks.EarlyStopping(\n", + " patience=5, restore_best_weights=True)]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZBaJhQpcG8u0" + }, + "source": [ + "### Train" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WBXG0dCDKO55" + }, + "source": [ + "Configure and execute the training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2OR5ZpAII__u" + }, + "outputs": [], + "source": [ + "model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),\n", + " loss=masked_loss,\n", + " metrics=[masked_acc])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ro955bQ2KR0X" + }, + "source": [ + "For more frequent reporting, use the `Dataset.repeat()` method, and set the `steps_per_epoch` and `validation_steps` arguments to `Model.fit`. \n", + "\n", + "With this setup on `Flickr8k` a full pass over the dataset is 900+ batches, but below the reporting-epochs are 100 steps." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "3aB0baOVMZe9" + }, + "outputs": [], + "source": [ + "history = model.fit(\n", + " train_ds.repeat(),\n", + " steps_per_epoch=100,\n", + " validation_data=test_ds.repeat(),\n", + " validation_steps=20,\n", + " epochs=100,\n", + " callbacks=callbacks)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "P634LfVgw-eV" + }, + "source": [ + "Plot the loss and accuracy over the training run:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6Wn8KSkUw916" + }, + "outputs": [], + "source": [ + "plt.plot(history.history['loss'], label='loss')\n", + "plt.plot(history.history['val_loss'], label='val_loss')\n", + "plt.ylim([0, max(plt.ylim())])\n", + "plt.xlabel('Epoch #')\n", + "plt.ylabel('CE/token')\n", + "plt.legend()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "yZQ78b2Kxw-T" + }, + "outputs": [], + "source": [ + "plt.plot(history.history['masked_acc'], label='accuracy')\n", + "plt.plot(history.history['val_masked_acc'], label='val_accuracy')\n", + "plt.ylim([0, max(plt.ylim())])\n", + "plt.xlabel('Epoch #')\n", + "plt.ylabel('CE/token')\n", + "plt.legend()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "SQN1qT7KNqbL" + }, + "source": [ + "## Attention plots" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "E9XJaC2b2J23" + }, + "source": [ + "Now, using the trained model, run that `simple_gen` method on the image:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1UQPtNTb2eu3" + }, + "outputs": [], + "source": [ + "result = model.simple_gen(image, temperature=0.0)\n", + "result" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7NXbmeLGN1bJ" + }, + "source": [ + "Split the output back into tokens:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "zHKOpm0w5Xto" + }, + "outputs": [], + "source": [ + "str_tokens = result.split()\n", + "str_tokens.append('[END]')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fE-AjuAV55Qo" + }, + "source": [ + "The `DecoderLayers` each cache the attention scores for their `CrossAttention` layer. 
The shape of each attention map is `(batch=1, heads, sequence, image)`:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "XZpyuQvq2q-B" + }, + "outputs": [], + "source": [ + "attn_maps = [layer.last_attention_scores for layer in model.decoder_layers]\n", + "[map.shape for map in attn_maps]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T42ImsWv6oHG" + }, + "source": [ + "So stack the maps along the `batch` axis, then average over the `(batch, heads)` axes, while splitting the `image` axis back into `height, width`:\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ojwtvnkh6mS-" + }, + "outputs": [], + "source": [ + "attention_maps = tf.concat(attn_maps, axis=0)\n", + "attention_maps = einops.reduce(\n", + " attention_maps,\n", + " 'batch heads sequence (height width) -> sequence height width',\n", + " height=7, width=7,\n", + " reduction='mean')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4TM7rA3zGpJW" + }, + "source": [ + "Now you have a single attention map, for each sequence prediction. The values in each map should sum to `1.`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "ASWmWerGCZp3" + }, + "outputs": [], + "source": [ + "einops.reduce(attention_maps, 'sequence height width -> sequence', reduction='sum')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fv7XYGFUd-U7" + }, + "source": [ + "So here is where the model was focusing attention while generating each token of the output:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "fD_y7PD6RPGt" + }, + "outputs": [], + "source": [ + "def plot_attention_maps(image, str_tokens, attention_map):\n", + " fig = plt.figure(figsize=(16, 9))\n", + "\n", + " len_result = len(str_tokens)\n", + " \n", + " titles = []\n", + " for i in range(len_result):\n", + " map = attention_map[i]\n", + " grid_size = max(int(np.ceil(len_result/2)), 2)\n", + " ax = fig.add_subplot(3, grid_size, i+1)\n", + " titles.append(ax.set_title(str_tokens[i]))\n", + " img = ax.imshow(image)\n", + " ax.imshow(map, cmap='gray', alpha=0.6, extent=img.get_extent(),\n", + " clim=[0.0, np.max(map)])\n", + "\n", + " plt.tight_layout()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PI4NAAws9rvY" + }, + "outputs": [], + "source": [ + "plot_attention_maps(image/255, str_tokens, attention_maps)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "riTz0abQKMkV" + }, + "source": [ + "Now put that together into a more usable function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mktpfW-SKQIJ" + }, + "outputs": [], + "source": [ + "@Captioner.add_method\n", + "def run_and_show_attention(self, image, temperature=0.0):\n", + " result_txt = self.simple_gen(image, temperature)\n", + " str_tokens = result_txt.split()\n", + " str_tokens.append('[END]')\n", + "\n", + " attention_maps = [layer.last_attention_scores for layer in self.decoder_layers]\n", + " attention_maps = tf.concat(attention_maps, axis=0)\n", + " attention_maps = einops.reduce(\n", + " attention_maps,\n", + " 'batch heads sequence (height width) -> sequence height width',\n", + " height=7, width=7,\n", + " reduction='mean')\n", + " \n", + " plot_attention_maps(image/255, str_tokens, attention_maps)\n", + " t = plt.suptitle(result_txt)\n", + " t.set_y(1.05)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 
null, + "metadata": { + "id": "FntRkY11OiMw" + }, + "outputs": [], + "source": [ + "run_and_show_attention(model, image)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Rprk3HEvZuxb" + }, + "source": [ + "## Try it on your own images\n", + "\n", + "For fun, below you're provided a method you can use to caption your own images with the model you've just trained. Keep in mind, it was trained on a relatively small amount of data, and your images may be different from the training data (so be prepared for strange results!)\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "9Psd1quzaAWg" + }, + "outputs": [], + "source": [ + "image_url = 'https://tensorflow.org/images/bedroom_hrnet_tutorial.jpg'\n", + "image_path = tf.keras.utils.get_file(origin=image_url)\n", + "image = load_image(image_path)\n", + "\n", + "run_and_show_attention(model, image)" + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "image_captioning.ipynb", + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 4 }