How to train your own object detection models using the TensorFlow Object Detection API (2020 Update)

This started as a summary of this nice tutorial, but has since become its own thing.

Prerequisites

  1. Choose a TensorFlow installation. TensorFlow 1 and 2 have different neural networks available, so check here and here to make your choice.

    • Tip: if you opt for one of the TF1 models, note that the Object Detection API is only officially compatible with TF 1.15.0, which works only with CUDA 10.0 (unless you compile from source). From personal experience, I know that TF 1.12 and earlier versions no longer work with the Object Detection API.
  2. Install TensorFlow.
  3. Download the TensorFlow models repository and install the Object Detection API [TF1] [TF2].
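Once everything is installed, a quick sanity check can save some debugging later. The sketch below is only an illustration, not part of the official installation instructions: it looks up the installed TensorFlow version, maps it to the compatibility notes above and checks whether the object_detection package can be found. The helper names (installed_tf_version, od_api_advice) are hypothetical.

```python
import importlib.util
from typing import Optional


def installed_tf_version() -> Optional[str]:
    """Return the installed TensorFlow version, or None if TensorFlow is missing."""
    if importlib.util.find_spec("tensorflow") is None:
        return None
    import tensorflow as tf  # slow import, but works on every Python version TF supports
    return tf.__version__


def od_api_advice(version: Optional[str]) -> str:
    """Map a version string to the compatibility notes given in this post."""
    if version is None:
        return "TensorFlow is not installed"
    major, minor = (int(part) for part in version.split(".")[:2])
    if major >= 2:
        return "TF2: use the TF2 models and installation instructions"
    if (major, minor) == (1, 15):
        return "TF 1.15: the officially supported version for TF1 models"
    if (major, minor) <= (1, 12):
        return "TF 1.12 or older: not working with the Object Detection API (see the tip above)"
    return "TF1 other than 1.15: not officially supported by the Object Detection API"


if __name__ == "__main__":
    version = installed_tf_version()
    print("TensorFlow:", version, "->", od_api_advice(version))
    # The API is installed as a top-level package named object_detection.
    found = importlib.util.find_spec("object_detection") is not None
    print("object_detection importable:", found)
```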

Annotating images

  1. Install labelImg. It is a Python package, so you can install it via pip, but I recommend the version from GitHub.

  2. Annotate your dataset using labelImg. The annotations for each image are saved to their own XML file, named after the original image file but with the .xml extension.
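By default, labelImg writes these files in the Pascal VOC XML layout. To see what a conversion script has to pull out of each file, here is a minimal parser sketch using only Python's xml.etree.ElementTree. The Box record, its field order and the parse_labelimg_xml name are hypothetical choices for illustration, and the example XML is abbreviated and made up rather than taken from a real dataset.

```python
import xml.etree.ElementTree as ET
from typing import List, NamedTuple

# Abbreviated, made-up example of a labelImg (Pascal VOC) annotation file.
EXAMPLE_XML = """<annotation>
  <folder>images</folder>
  <filename>cat1.jpg</filename>
  <size><width>640</width><height>480</height><depth>3</depth></size>
  <object>
    <name>cat</name>
    <difficult>0</difficult>
    <bndbox><xmin>10</xmin><ymin>20</ymin><xmax>200</xmax><ymax>220</ymax></bndbox>
  </object>
</annotation>"""


class Box(NamedTuple):
    """One annotated object, flattened into a single CSV-style row."""
    filename: str
    width: int
    height: int
    label: str
    xmin: int
    ymin: int
    xmax: int
    ymax: int


def parse_labelimg_xml(xml_text: str) -> List[Box]:
    """Return one Box per <object> element in a labelImg annotation."""
    root = ET.fromstring(xml_text)
    filename = root.findtext("filename")
    width = int(root.findtext("size/width"))
    height = int(root.findtext("size/height"))
    boxes = []
    for obj in root.iter("object"):
        bndbox = obj.find("bndbox")
        # Coordinates are usually integers, but float() also tolerates values like "10.0".
        xmin, ymin, xmax, ymax = (
            int(float(bndbox.findtext(tag))) for tag in ("xmin", "ymin", "xmax", "ymax")
        )
        boxes.append(Box(filename, width, height, obj.findtext("name"), xmin, ymin, xmax, ymax))
    return boxes


if __name__ == "__main__":
    for box in parse_labelimg_xml(EXAMPLE_XML):
        print(box)
```

For a file on disk, pass Path(xml_path).read_text() to the function; writing the boxes from every file with csv.writer gives the kind of single CSV file used in the next section.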

Serializing the dataset

For these steps, I recommend a collection of scripts I made, which are available in this repository. All the scripts mentioned in this section take their arguments from the command line and print help messages with the -h/--help flag. Check the repository's README for more details, if needed.

  1. Use this script to convert the XML files generated by labelImg into a single CSV file.

    • Optional: Use this script to split the CSV file into two, one with training examples and one with evaluation examples. Let's call them train.csv and val.csv. Images are selected randomly, and there are options to stratify examples by class, making sure that objects from all classes are present in both datasets. A typical split uses 75% to 80% of the annotated objects for training and the rest for evaluation.
  2. Create a “label map” for your classes. You can check some examples to understand what they look like. You can also generate one from your original CSV file with this script.

  3. Use this script to convert each of your CSV files into a TFRecord file (e.g. train.record and eval.record), the serialized data format TensorFlow works with most naturally. You’ll need to point to the directory where the image files are stored and to the label map generated in the previous step.

    • Tip: if you notice mistakes during the creation of these files, you can check their contents and compare them with the ones in these examples.
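A label map is a small text file in protobuf text format that assigns an integer id to each class name, with ids starting at 1 because 0 is reserved for the background. As a rough illustration of what the generating script in step 2 produces, the sketch below builds one from the class column of the CSV file. The column name "class", the alphabetical id order and the helper names are assumptions; whichever ids you end up with, the same label map must be used when creating the TFRecords and in the training pipeline.

```python
import csv
import io
from typing import Iterable, List, TextIO


def classes_from_csv(csv_file: TextIO, column: str = "class") -> List[str]:
    """Collect the distinct class names, sorted so the ids are reproducible."""
    return sorted({row[column] for row in csv.DictReader(csv_file)})


def label_map_pbtxt(class_names: Iterable[str]) -> str:
    """Render a label map; ids start at 1 because 0 means background."""
    items = []
    for item_id, name in enumerate(class_names, start=1):
        # Escape backslashes and single quotes so the quoted name stays valid.
        escaped = name.replace("\\", "\\\\").replace("'", "\\'")
        items.append(f"item {{\n  id: {item_id}\n  name: '{escaped}'\n}}\n")
    return "\n".join(items)


if __name__ == "__main__":
    # Made-up rows in a filename,width,height,class,xmin,ymin,xmax,ymax layout.
    example_csv = io.StringIO(
        "filename,width,height,class,xmin,ymin,xmax,ymax\n"
        "a.jpg,640,480,dog,1,2,30,40\n"
        "b.jpg,640,480,cat,5,6,70,80\n"
        "c.jpg,640,480,dog,9,9,90,90\n"
    )
    print(label_map_pbtxt(classes_from_csv(example_csv)))
```

To save it, write the returned string to a file such as label_map.pbtxt (a hypothetical name) with Path.write_text.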

Preparing the training pipeline

  1. Download the neural network model of choice from either the Detection Model Zoo [TF1][TF2] or from the models trained for classification available here and here. This is the step where your choice of TensorFlow version makes a difference. In my experience, many of the classification models work with TF 1.15, but I don't know whether they work with TF 2.

  2. Provide a training pipeline, which is a file with .config extension that describes the training procedure. The models provided in the Detection Zoo come with their own pipelines inside their .tar.gz file, but the classification models do not. In this situation, your options are to:

    • download one that is close enough from here (I have successfully done this to train classification MobileNets V1, V2 and V3 for detection).
    • create your own, by following this tutorial.

    The pipeline config file has some fields that must be adjusted before training is started. The first thing you’ll definitely want to keep an eye on is the num_classes attribute, which you’ll need to change to the number of classes in your personal dataset.

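    Other important fields are the ones containing the PATH_TO_BE_CONFIGURED string. In these fields, you’ll need to point to the files they ask for, such as the label map, the training and evaluation TFRecords and the neural network checkpoint, which is a file with an extension like .ckpt or .ckpt.data-####-of-####. This file is also included in the .tar.gz archive. Note that TensorFlow refers to a checkpoint by its prefix, so the path you provide should end in model.ckpt (or similar) rather than in the full .data-####-of-#### file name.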

    If you are using a model from the Detection Zoo, set the fine_tune_checkpoint_type field to "detection"; otherwise, set it to "classification".

    There are additional parameters that affect how much RAM the training process consumes, as well as the quality of the training. Settings such as the batch size or the number of batches TensorFlow can prefetch and keep in memory may considerably increase the RAM required, but I won’t cover them here, since tuning them involves a lot of trial and error.
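If you'd rather not edit the config by hand every time, the edits described above can be scripted. The sketch below is plain text substitution with Python's re module, not the Object Detection API's own config utilities. It assumes the one-field-per-line layout of the Zoo's pipeline.config files and that top-level blocks close with a } at the start of a line. The helper names, the example excerpt and all paths are hypothetical.

```python
import re

# Trimmed, made-up excerpt containing the fields discussed above.
EXAMPLE_CONFIG = """model {
  ssd {
    num_classes: 90
  }
}
train_config: {
  fine_tune_checkpoint: "PATH_TO_BE_CONFIGURED/model.ckpt"
  fine_tune_checkpoint_type: "detection"
}
train_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_train.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
eval_input_reader: {
  tf_record_input_reader {
    input_path: "PATH_TO_BE_CONFIGURED/mscoco_val.record"
  }
  label_map_path: "PATH_TO_BE_CONFIGURED/mscoco_label_map.pbtxt"
}
"""


def set_field(text: str, field: str, value, quoted: bool = True) -> str:
    """Replace the value on every line of the form `field: ...`."""
    rendered = f'"{value}"' if quoted else str(value)
    pattern = rf"^([ \t]*{re.escape(field)}:[ \t]*).*$"
    # A function replacement keeps backslashes (e.g. in Windows paths) literal.
    new_text, count = re.subn(pattern, lambda m: m.group(1) + rendered, text, flags=re.MULTILINE)
    if count == 0:
        raise ValueError(f"field {field!r} not found in config")
    return new_text


def set_in_block(text: str, block: str, field: str, value) -> str:
    """Like set_field, but only inside one top-level block such as eval_input_reader."""
    start = text.index(block)
    end = text.index("\n}", start)  # assumes the block closes with "}" at column 0
    return text[:start] + set_field(text[start:end], field, value) + text[end:]


def configure_pipeline(text, num_classes, checkpoint, checkpoint_type,
                       label_map, train_record, eval_record):
    text = set_field(text, "num_classes", num_classes, quoted=False)
    text = set_field(text, "fine_tune_checkpoint", checkpoint)
    text = set_field(text, "fine_tune_checkpoint_type", checkpoint_type)
    text = set_field(text, "label_map_path", label_map)
    text = set_in_block(text, "train_input_reader", "input_path", train_record)
    text = set_in_block(text, "eval_input_reader", "input_path", eval_record)
    return text


if __name__ == "__main__":
    print(configure_pipeline(
        EXAMPLE_CONFIG, num_classes=3,
        checkpoint="pretrained/model.ckpt", checkpoint_type="detection",
        label_map="data/label_map.pbtxt",
        train_record="data/train.record", eval_record="data/eval.record",
    ))
```

For a real file, read it with Path("pipeline.config").read_text(), run it through configure_pipeline and write the result back. Raising an error when a field is missing is deliberate, so a typo or an older config layout doesn't go unnoticed. For anything beyond these simple fields, the API's config_util module, which parses the file as a protobuf, is the more robust option.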

Training the network

  1. Train the model. To do it locally, follow the steps available here: [TF1][TF2].

    Optional: to monitor training progress, start TensorBoard with its --logdir pointing to the --model_dir path used in the previous step.

  2. Export the network, like this.

    Tip: if training completes successfully but you get a scary “data loss error” like this one when exporting, make sure you pass the checkpoint prefix to the export script rather than the full file name. For example, if your checkpoint file is named model.ckpt-50000.data-00000-of-00001 or model.ckpt.data-00000-of-00001, pass model.ckpt-50000 or model.ckpt, respectively.

  3. In the output directory you passed to the export script, a file called frozen_inference_graph.pb will be created (this is the output of the TF1 exporter; the TF2 exporter writes a saved_model directory instead). Use this file alongside the label map to detect objects in your application. An example of how to do this can be found in this notebook from the Models repository. Alternatively, you can give my package dodo detector a try, which uses these same files but abstracts away the inner workings of TensorFlow. You can see it in action in this Gist.
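The export tip above boils down to stripping the suffix from a checkpoint file name. As a small sketch in plain Python (the helper names are hypothetical), the first function below does exactly that, and the second picks the prefix with the highest step number in a training directory by looking at its .index files. This assumes your checkpoints follow the usual prefix-step naming, such as model.ckpt-50000 in TF1 or ckpt-5 in TF2.

```python
import re
from pathlib import Path
from typing import Optional

# A checkpoint is a set of files sharing one prefix; these are the per-file suffixes.
_SUFFIX = re.compile(r"\.(index|meta|data-\d+-of-\d+)$")
# Trailing global step in prefixes such as model.ckpt-50000 or ckpt-5.
_STEP = re.compile(r"-(\d+)$")


def checkpoint_prefix(filename: str) -> str:
    """model.ckpt-50000.data-00000-of-00001 -> model.ckpt-50000"""
    return _SUFFIX.sub("", filename)


def latest_checkpoint_prefix(model_dir: str) -> Optional[str]:
    """Return the path prefix of the highest-step checkpoint in model_dir, if any."""
    best, best_step = None, -1
    for index_file in Path(model_dir).glob("*.index"):
        prefix = checkpoint_prefix(index_file.name)
        match = _STEP.search(prefix)
        step = int(match.group(1)) if match else 0  # prefixes without a step count as 0
        if step > best_step:
            best, best_step = str(index_file.with_name(prefix)), step
    return best


if __name__ == "__main__":
    for name in ("model.ckpt-50000.data-00000-of-00001", "model.ckpt.data-00000-of-00001"):
        print(name, "->", checkpoint_prefix(name))
```

With TF1, the resulting prefix is what you pass to export_inference_graph.py through --trained_checkpoint_prefix. The TF2 exporter (exporter_main_v2.py) takes the training directory through --trained_checkpoint_dir instead, so there you don't need the prefix at all.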

Final Tips

In the data augmentation section of the training pipeline, you can add or remove options to try to improve training. Some of the available options are listed here.
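Each augmentation is a data_augmentation_options entry inside train_config. As a rough sketch, the Python helper below inserts such entries into a pipeline config as plain text, leaving each option's body empty so its defaults apply. The helper name is hypothetical, and it assumes the train_config block opens on its own line as in the Zoo configs. random_horizontal_flip and random_adjust_brightness are just two examples of option names; check the list linked above for the ones you actually want.

```python
import re
from typing import Iterable


def add_augmentations(config_text: str, options: Iterable[str]) -> str:
    """Insert one data_augmentation_options entry per option at the top of train_config."""
    match = re.search(r"^train_config[ \t]*:?[ \t]*\{[ \t]*\n", config_text, flags=re.MULTILINE)
    if match is None:
        raise ValueError("no top-level train_config block found")
    entries = "".join(
        f"  data_augmentation_options {{\n    {name} {{\n    }}\n  }}\n"
        for name in options
        if f"{name} {{" not in config_text  # skip options that are already present
    )
    return config_text[: match.end()] + entries + config_text[match.end():]


if __name__ == "__main__":
    config = 'train_config: {\n  batch_size: 8\n  fine_tune_checkpoint: "pretrained/model.ckpt"\n}\n'
    print(add_augmentations(config, ["random_horizontal_flip", "random_adjust_brightness"]))
```

Because options already present are skipped, running the helper twice on the same file leaves it unchanged, which makes it safe to use in a setup script.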



