// The contents of this file are in the public domain. See LICENSE_FOR_EXAMPLE_PROGRAMS.txt
/*
    This example shows how to train a semantic segmentation net using the PASCAL VOC2012
    dataset.  For an introduction to what segmentation is, see the accompanying header file
    dnn_semantic_segmentation_ex.h.

    Instructions on how to run the example:
    1. Download the PASCAL VOC2012 data, and untar it somewhere.
       http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar
    2. Build the dnn_semantic_segmentation_train_ex example program.
    3. Run:
       ./dnn_semantic_segmentation_train_ex /path/to/VOC2012
    4. Wait while the network is being trained.
    5. Build the dnn_semantic_segmentation_ex example program.
    6. Run:
       ./dnn_semantic_segmentation_ex /path/to/VOC2012-or-other-images

    It would be a good idea to become familiar with dlib's DNN tooling before reading this
    example, so you should first work through dnn_introduction_ex.cpp and
    dnn_introduction2_ex.cpp.
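
    Once training has finished, the saved network can be loaded for inference roughly
    like this (a minimal sketch; anet_type and semantic_segmentation_net_filename are
    defined in dnn_semantic_segmentation_ex.h, and the image file name is just a
    placeholder):

        anet_type net;
        deserialize(semantic_segmentation_net_filename) >> net;
        matrix<rgb_pixel> image;
        load_image(image, "some_image.jpg");
        const matrix<uint16_t> predicted_labels = net(image);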
*/

#include "dnn_semantic_segmentation_ex.h"

#include <iostream>
#include <dlib/data_io.h>
#include <dlib/image_transforms.h>
#include <dlib/dir_nav.h>
#include <iterator>
#include <thread>

using namespace std;
using namespace dlib;

// A single training sample. A mini-batch comprises many of these.
struct training_sample
{
    matrix<rgb_pixel> input_image;
    matrix<uint16_t> label_image; // The ground-truth label of each pixel.
};

// ----------------------------------------------------------------------------------------

rectangle make_random_cropping_rect(
    const matrix<rgb_pixel>& img,
    dlib::rand& rnd
)
{
    // figure out what rectangle we want to crop from the image
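    // The crop is a square whose side is a random fraction (between mins and maxs)
    // of the image's smaller dimension.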
    double mins = 0.466666666, maxs = 0.875;
    auto scale = mins + rnd.get_random_double()*(maxs-mins);
    auto size = scale*std::min(img.nr(), img.nc());
    rectangle rect(size, size);
    // randomly shift the box around
    point offset(rnd.get_random_32bit_number()%(img.nc()-rect.width()),
                 rnd.get_random_32bit_number()%(img.nr()-rect.height()));
    return move_rect(rect, offset);
}

// ----------------------------------------------------------------------------------------

void randomly_crop_image (
    const matrix<rgb_pixel>& input_image,
    const matrix<uint16_t>& label_image,
    training_sample& crop,
    dlib::rand& rnd
)
{
    const auto rect = make_random_cropping_rect(input_image, rnd);

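    // Map the crop rectangle to a fixed 227x227 chip; extract_image_chip below
    // rescales the selected region to these dimensions.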
    const chip_details chip(rect, chip_dims(227, 227));

    // Crop the input image.
    extract_image_chip(input_image, chip, crop.input_image, interpolate_bilinear());

    // Crop the labels correspondingly. However, note that here bilinear
    // interpolation would make absolutely no sense - you wouldn't say that
    // a bicycle is half-way between an aeroplane and a bird, would you?
    extract_image_chip(label_image, chip, crop.label_image, interpolate_nearest_neighbor());

    // Also randomly flip the input image and the labels.
    if (rnd.get_random_double() > 0.5)
    {
        crop.input_image = fliplr(crop.input_image);
        crop.label_image = fliplr(crop.label_image);
    }

    // And then randomly adjust the colors.
    apply_random_color_offset(crop.input_image, rnd);
}

// ----------------------------------------------------------------------------------------

// Calculate the per-pixel accuracy on a dataset whose file names are supplied as a parameter.
double calculate_accuracy(anet_type& anet, const std::vector<image_info>& dataset)
{
    int num_right = 0;
    int num_wrong = 0;

    matrix<rgb_pixel> input_image;
    matrix<rgb_pixel> rgb_label_image;
    matrix<uint16_t> index_label_image;
    matrix<uint16_t> net_output;

    for (const auto& info : dataset)
    {
        // Load the input image.
        load_image(input_image, info.image_filename);

        // Load the ground-truth (RGB) labels.
        load_image(rgb_label_image, info.class_label_filename);

        // Create predictions for each pixel. At this point, the type of each prediction
        // is an index (a value between 0 and 20). Note that the net may return an image
        // that is not exactly the same size as the input.
        const matrix<uint16_t> temp = anet(input_image);

        // Convert the RGB values to indexes.
        rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

        // Crop the net output to be exactly the same size as the input.
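        // The crop window is an input-sized rectangle centered on the middle of the
        // net output.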
        const chip_details chip(
            centered_rect(temp.nc() / 2, temp.nr() / 2, input_image.nc(), input_image.nr()),
            chip_dims(input_image.nr(), input_image.nc())
        );
        extract_image_chip(temp, chip, net_output, interpolate_nearest_neighbor());

        const long nr = index_label_image.nr();
        const long nc = index_label_image.nc();

        // Compare the predicted values to the ground-truth values.
        for (long r = 0; r < nr; ++r)
        {
            for (long c = 0; c < nc; ++c)
            {
                const uint16_t truth = index_label_image(r, c);
                if (truth != dlib::loss_multiclass_log_per_pixel_::label_to_ignore)
                {
                    const uint16_t prediction = net_output(r, c);
                    if (prediction == truth)
                    {
                        ++num_right;
                    }
                    else
                    {
                        ++num_wrong;
                    }
                }
            }
        }
    }

    // Return the accuracy estimate.
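    // Note that this is plain per-pixel accuracy; the official VOC segmentation
    // benchmark reports mean intersection-over-union instead.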
    return num_right / static_cast<double>(num_right + num_wrong);
}

// ----------------------------------------------------------------------------------------

int main(int argc, char** argv) try
{
    if (argc < 2 || argc > 3)
    {
        cout << "To run this program you need a copy of the PASCAL VOC2012 dataset." << endl;
        cout << endl;
        cout << "You call this program like this: " << endl;
        cout << "./dnn_semantic_segmentation_train_ex /path/to/VOC2012 [minibatch-size]" << endl;
        return 1;
    }

    cout << "\nSCANNING PASCAL VOC2012 DATASET\n" << endl;

    const auto listing = get_pascal_voc2012_train_listing(argv[1]);
    cout << "images in dataset: " << listing.size() << endl;
    if (listing.size() == 0)
    {
        cout << "Didn't find the VOC2012 dataset. " << endl;
        return 1;
    }

    // A mini-batch smaller than the default can be used on GPUs with less memory.
    const unsigned int minibatch_size = argc == 3 ? std::stoi(argv[2]) : 23;
    cout << "mini-batch size: " << minibatch_size << endl;

    const double initial_learning_rate = 0.1;
    const double weight_decay = 0.0001;
    const double momentum = 0.9;

    bnet_type bnet;
    dnn_trainer<bnet_type> trainer(bnet,sgd(weight_decay, momentum));
    trainer.be_verbose();
    trainer.set_learning_rate(initial_learning_rate);
    trainer.set_synchronization_file("pascal_voc2012_trainer_state_file.dat", std::chrono::minutes(10));
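    // If this file already exists, training resumes from the saved state; delete the
    // file to start over from scratch.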
    // This threshold is probably excessively large.
    trainer.set_iterations_without_progress_threshold(5000);
    // Since the progress threshold is so large might as well set the batch normalization
    // stats window to something big too.
    set_all_bn_running_stats_window_sizes(bnet, 1000);

    // Output training parameters.
    cout << endl << trainer << endl;

    std::vector<matrix<rgb_pixel>> samples;
    std::vector<matrix<uint16_t>> labels;

    // Start a bunch of threads that read images from disk and pull out random crops.  It's
    // important to be sure to feed the GPU fast enough to keep it busy.  Using multiple
    // threads for this kind of data preparation helps us do that.  Each thread puts the
    // crops into the data queue.
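    // A bounded pipe: enqueue() blocks once 200 samples are buffered, so the loader
    // threads can't run arbitrarily far ahead of the trainer.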
    dlib::pipe<training_sample> data(200);
    auto f = [&data, &listing](time_t seed)
    {
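        // Seed each thread's generator differently so the threads produce different
        // random crop sequences.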
        dlib::rand rnd(time(0)+seed);
        matrix<rgb_pixel> input_image;
        matrix<rgb_pixel> rgb_label_image;
        matrix<uint16_t> index_label_image;
        training_sample temp;
        while(data.is_enabled())
        {
            // Pick a random input image.
            const image_info& image_info = listing[rnd.get_random_32bit_number()%listing.size()];

            // Load the input image.
            load_image(input_image, image_info.image_filename);

            // Load the ground-truth (RGB) labels.
            load_image(rgb_label_image, image_info.class_label_filename);

            // Convert the RGB values to indexes.
            rgb_label_image_to_index_label_image(rgb_label_image, index_label_image);

            // Randomly pick a part of the image.
            randomly_crop_image(input_image, index_label_image, temp, rnd);

            // Push the result to be used by the trainer.
            data.enqueue(temp);
        }
    };
    std::thread data_loader1([f](){ f(1); });
    std::thread data_loader2([f](){ f(2); });
    std::thread data_loader3([f](){ f(3); });
    std::thread data_loader4([f](){ f(4); });

    // The main training loop.  Keep making mini-batches and giving them to the trainer.
    // We will run until the learning rate has dropped below 1e-4.
    while(trainer.get_learning_rate() >= 1e-4)
    {
        samples.clear();
        labels.clear();

        // make a mini-batch
        training_sample temp;
        while(samples.size() < minibatch_size)
        {
            data.dequeue(temp);

            samples.push_back(std::move(temp.input_image));
            labels.push_back(std::move(temp.label_image));
        }

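        // train_one_step() hands the mini-batch to the trainer's internal thread and
        // returns immediately, so we can start assembling the next batch right away.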
        trainer.train_one_step(samples, labels);
    }

    // Training done, tell threads to stop and make sure to wait for them to finish before
    // moving on.
    data.disable();
    data_loader1.join();
    data_loader2.join();
    data_loader3.join();
    data_loader4.join();

    // Also wait for the trainer's own thread to finish processing any queued
    // mini-batches; get_net() blocks until it has.
    trainer.get_net();

    bnet.clean();
    cout << "saving network" << endl;
    serialize(semantic_segmentation_net_filename) << bnet;


    // Make a copy of the network to use it for inference.
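    // Assigning the bnet_type to an anet_type converts the batch normalization layers
    // into affine layers that use the learned running statistics, which is what you
    // want at inference time.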
    anet_type anet = bnet;

    cout << "Testing the network..." << endl;

    // Find the accuracy of the newly trained network on both the training and the validation sets.
    cout << "train accuracy  :  " << calculate_accuracy(anet, get_pascal_voc2012_train_listing(argv[1])) << endl;
    cout << "val accuracy    :  " << calculate_accuracy(anet, get_pascal_voc2012_val_listing(argv[1])) << endl;
}
catch(const std::exception& e)
{
    cout << e.what() << endl;
    return 1;
}