
Cartoon or Photo? – Image detection with Python

This all started with me just wanting a fast way to sort through image folders and remove cartoon images. That led me down a spiraling rabbit hole of possibilities, from using OpenCV for different types of detection to training a Machine Learning model from scratch with Keras. This article will go through the options available and see how accurate they end up being.

What makes an image a cartoon?

Model is Jessica Marie Frye

If we are going to detect the differences between photos and cartoons, first we need to figure out how they are different and, importantly, how they are different in a quantifiable way. These are the measurable differences I could think up:

  • Cartoons have smoother gradients
  • Real images use a larger color palette
  • Cartoons usually have drawn edge outlines

Below we will try each of these options and see how well they fare.

Results First

You’re probably more interested in what will work the best for you and less about how I spent days toiling away at this, so to cut to the chase, here is how everything performed.

The best OpenCV contender turned out to be counting colors, with a combined 75% accuracy.

Overall, Machine Learning Image Classification using the Xception model wiped the floor with a combined 96% accuracy. This probably isn’t even as good as it could get with further fine tuning, but was more than good enough for my own needs.

These results were also with a hard threshold set to force every image into either the real or the cartoon bucket. For my own use I narrow the ranges that count as a definite answer and put the “unsure” images in another folder, further reducing any errant picks.
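As a rough sketch of that three-bucket idea, assuming a score between 0 and 1 where higher means "more likely a real photo": the function name `bucket`, the folder names, and the default cutoffs are only illustrative (the cutoffs mirror the ones used in the Keras script further down), so adjust them to whatever your own scores look like.

```python
def bucket(score: float, low: float = 0.02, high: float = 0.98) -> str:
    """Map a score in [0, 1] to a destination folder name (hypothetical names)."""
    if score >= high:
        return "real"
    if score <= low:
        return "cartoon"
    # Anything in between is left for a human to look at
    return "unsure"


print(bucket(0.995))  # real
print(bucket(0.50))   # unsure
print(bucket(0.01))   # cartoon
```

Narrowing the gap between `low` and `high` sorts more images automatically, at the cost of more wrong picks.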

Machine Learning with Keras is the obvious pick if you have a good set of data to train with and a computer beefy enough to process it. However, it’s not as portable and takes much longer to set up than simply trying out one of the OpenCV methods.

OpenCV Gradient Differences

The first method we will use is pretty straightforward. We will blur the image a little, compare it to its unaltered form, and quantify the difference. First, we will need opencv installed for python. Go into your venv and run:

pip install opencv-python

Now open up your IDE and create a new .py file to get started. The first thing we are going to do is read the image with opencv, which is the cv2 module. We will use a JPEG image, which opencv loads in BGR (blue, green, red) channel order.

import cv2

img = cv2.imread("/path/to/my/image.jpg")

Next we will blur the image a little using a bilateral filter, to even out the colors. We will also resize the image to a standard size so the blur across every image is the same.

img = cv2.resize(img, (1024, 1024))
color_blurred = cv2.bilateralFilter(img, 6, 250, 250)

You can check out the result to see how strong the effect is by previewing the image. Press any key to close the window.

# Optional Preview
cv2.imshow("blurred", color_blurred)
cv2.waitKey(0)
cv2.destroyAllWindows()

Then we need to compare this new color_blurred image to the original image. I accomplished this by comparing the histograms. We will have to do that for each color channel individually.

diffs = []
# OpenCV stores channels in blue, green, red order
for k, color in enumerate(('b', 'g', 'r')):
    print(f"Comparing histogram for color {color}")
    # calcHist expects a list of images, so wrap each one in []
    real_histogram = cv2.calcHist([img], [k], None, [256], [0, 256])
    color_histogram = cv2.calcHist([color_blurred], [k], None, [256], [0, 256])
    diffs.append(cv2.compareHist(real_histogram, color_histogram, cv2.HISTCMP_CORREL))

result = sum(diffs) / 3

compareHist with HISTCMP_CORREL will give us a correlation between -1 and 1 (one being the most similar). We will need to set a threshold for how similar the two versions of a cartoon should be. I have mine set at 0.98.

if result > 0.98:
    print("It's a cartoon!")
else: 
    print("It's a photo!")
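To get a feel for what that correlation number means, here is a plain Python sketch of the same kind of calculation (a Pearson correlation between two histograms). It is only an illustration and not OpenCV's implementation; the name `hist_correlation` and the handling of flat histograms are my own choices.

```python
from math import sqrt
from typing import Sequence


def hist_correlation(h1: Sequence[float], h2: Sequence[float]) -> float:
    """Pearson correlation between two equally sized histograms."""
    mean1 = sum(h1) / len(h1)
    mean2 = sum(h2) / len(h2)
    numerator = sum((a - mean1) * (b - mean2) for a, b in zip(h1, h2))
    denominator = sqrt(
        sum((a - mean1) ** 2 for a in h1) * sum((b - mean2) ** 2 for b in h2)
    )
    # Flat histograms have no variance; this sketch treats them as identical
    return numerator / denominator if denominator else 1.0


print(hist_correlation([1, 2, 3, 4], [1, 2, 3, 4]))  # 1.0
print(hist_correlation([1, 2, 3, 4], [4, 3, 2, 1]))  # -1.0
```

Blurring an image that was already smooth barely moves its histogram, so the score stays very close to 1, which is why the cutoff sits so high.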

And that’s it! Now you can test it out and see how it works for your images. I found this iteration to work very fast and have a roughly 70% proper detection rate.

Let’s put it all together into a usable script!

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
from typing import Union
import argparse

import cv2


def is_cartoon(
    image: Union[str, Path],
    threshold: float = 0.98,
    preview: bool = False,
) -> bool:
    # read and resize image
    img = cv2.imread(str(image))
    img = cv2.resize(img, (1024, 1024))

    # blur the image to "even out" the colors
    color_blurred = cv2.bilateralFilter(img, 6, 250, 250)

    if preview:
        cv2.imshow("blurred", color_blurred)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # compare the colors from the original image to blurred one.
    diffs = []
    # OpenCV stores channels in blue, green, red order
    for k, color in enumerate(("b", "g", "r")):
        # print(f"Comparing histogram for color {color}")
        # calcHist expects a list of images, so wrap each one in []
        real_histogram = cv2.calcHist([img], [k], None, [256], [0, 256])
        color_histogram = cv2.calcHist([color_blurred], [k], None, [256], [0, 256])
        diffs.append(
            cv2.compareHist(real_histogram, color_histogram, cv2.HISTCMP_CORREL)
        )

    return sum(diffs) / 3 > threshold


def command_line_options():
    args = argparse.ArgumentParser(
        "blur_compare",
        description="Determine if an image is likely a cartoon or photo.",
    )
    args.add_argument(
        "-p",
        "--preview",
        action="store_true",
        help="Show the blurred image",
    )
    args.add_argument(
        "-t",
        "--threshold",
        type=float,
        help="Cutoff threshold",
        default=0.98,
    )
    args.add_argument(
        "image",
        type=Path,
        help="Path to image file",
    )
    return vars(args.parse_args())


if __name__ == "__main__":
    options = command_line_options()
    if not options["image"].exists():
        raise FileNotFoundError(f"No image exists at {options['image'].absolute()}")
    if is_cartoon(**options):
        print(f"{options['image'].name} is a cartoon!")
    else:
        print(f"{options['image'].name} is a photo!")
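The 0.98 default worked for my images, but yours may want a different cutoff. One way to pick it, sketched below, is to collect the similarity score for a handful of images you have already labelled by hand and try several thresholds. The helper names and the scores in `labelled` are made up for illustration; real scores would come from the `sum(diffs) / 3` value in the script.

```python
from typing import Iterable, List, Tuple

# (similarity score, True if the image really is a cartoon)
Labelled = List[Tuple[float, bool]]


def accuracy_at(scores: Labelled, threshold: float) -> float:
    """Fraction of labelled images that `score > threshold` classifies correctly."""
    correct = sum((score > threshold) == is_cartoon for score, is_cartoon in scores)
    return correct / len(scores)


def best_threshold(scores: Labelled, candidates: Iterable[float]) -> float:
    """Return the candidate threshold with the highest accuracy."""
    return max(candidates, key=lambda t: accuracy_at(scores, t))


# Made-up scores, for illustration only
labelled = [(0.995, True), (0.991, True), (0.985, False), (0.97, False), (0.96, False)]
candidates = [0.95, 0.96, 0.97, 0.98, 0.99]

for t in candidates:
    print(t, accuracy_at(labelled, t))
print("best:", best_threshold(labelled, candidates))  # best: 0.99
```

The same sweep works for the other OpenCV methods below, since they all boil down to comparing one number against a threshold.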

OpenCV Color Counting

Using a subset of 512 colors in a 1024×1024 image, determine how much of the image can be reproduced

The next possible way to approach the problem was figuring out how many colors were used for the majority of the image. We will use the same code from above to load and resize the image, but this time we will add a loop to count all the colors (slow way, faster option below):

    # Find count of each color
    a = {}
    for item in img.reshape(-1, 3):  # one (b, g, r) row per pixel
        value = tuple(item)
        if value not in a:
            a[value] = 1
        else:
            a[value] += 1

Next we will sort the dictionary by the most used color, and add the counts of the top 512 colors together.

The actual calculation is a lot of functionality in a small block of code. First let’s get a visual preview of what is happening by re-creating the image with only the selected colors with:

mask = numpy.zeros(img.shape[:2], dtype=bool)

for color, _ in sorted(a.items(), key=lambda pair: pair[1], reverse=True)[:512]:
    mask |= (img == color).all(-1)

img[~mask] = (255, 255, 255)

cv2.imshow("img", img)
cv2.waitKey(0)
cv2.destroyAllWindows()

Here’s the code that basically calculates what you are seeing. We divide the sum of all the top colors by the size of the image to get what percent could be recreated with just those colors.

    # Identify the percent of the image that uses the top 512 colors
    most_common_colors = sum([x[1] for x in sorted(a.items(), key=lambda pair: pair[1], reverse=True)[:512]])
    return (most_common_colors / (1024 * 1024)) > 0.3 # new threshold
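If the dictionary bookkeeping feels heavy, the same counting idea can be sketched with `collections.Counter` from the standard library. This is an alternative illustration rather than the code I actually ran; `top_color_share` is a made-up name, and it takes any iterable of `(b, g, r)` tuples (for a real image something like `map(tuple, img.reshape(-1, 3).tolist())` should do).

```python
from collections import Counter
from typing import Iterable, Tuple

Pixel = Tuple[int, int, int]


def top_color_share(pixels: Iterable[Pixel], top_n: int = 512) -> float:
    """Fraction of pixels covered by the `top_n` most frequent colors."""
    counts = Counter(pixels)
    total = sum(counts.values())
    if total == 0:
        return 0.0
    covered = sum(count for _, count in counts.most_common(top_n))
    return covered / total


# Tiny fake "image": 6 black pixels, 3 white pixels, 1 odd one out
pixels = [(0, 0, 0)] * 6 + [(255, 255, 255)] * 3 + [(10, 20, 30)]
print(top_color_share(pixels, top_n=2))  # 0.9
```

A cartoon with large flat areas pushes this share up, while a photo spreads its pixels across far more colors.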

This script is pretty similar to the last one, and has a slightly higher success rate on my data set at 76%! It is 20% better at detecting what is a cartoon, but 10% worse at photos. It is also much slower, some 20 to 50 times slower (about a second per image instead of 0.05 of a second). However, we can speed that up with some numpy trickery.

    # Replace everything after "Find count of each color" with this faster version, but it doesn't work with the preview.
    flattened = numpy.reshape(img, ((1024 * 1024), 3))
    # Pack each (b, g, r) pixel into one integer; the weights must be at least
    # 256 apart so two different colors never produce the same sum
    multiplied = numpy.multiply(flattened, [65_536, 256, 1])
    sums = multiplied.sum(axis=1)
    unique, counts = numpy.unique(sums, return_counts=True)

    # Identify the percent of the image that uses the top 512 colors
    most_common_colors = sum(sorted(counts, reverse=True)[:512])
    return (most_common_colors / (1024 * 1024)) > threshold
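The trick relies on squashing the three channel values into a single number per pixel, so the choice of weights matters. Here is a small plain Python sketch (the `pack` helper is hypothetical) showing that weights spaced less than 256 apart, such as 100, can map two different colors onto the same number, while 65,536 / 256 / 1 keeps every 8-bit color distinct.

```python
from typing import Tuple

Pixel = Tuple[int, int, int]


def pack(pixel: Pixel, weights: Tuple[int, int, int]) -> int:
    """Collapse a (b, g, r) pixel into one integer using per-channel weights."""
    return sum(channel * weight for channel, weight in zip(pixel, weights))


first, second = (0, 1, 0), (0, 0, 100)  # two clearly different colors

narrow = (100_000, 100, 1)
wide = (65_536, 256, 1)

print(pack(first, narrow) == pack(second, narrow))  # True, the colors collide
print(pack(first, wide) == pack(second, wide))      # False, they stay distinct
```

Collisions merge unrelated colors into one bucket, which inflates the top 512 count a little.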

Here is the entire script with the slower version that works with the preview image. However, after finding a good threshold I would suggest replacing the code with the faster version above.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
from typing import Union
import argparse

import cv2
import numpy


def is_cartoon(
    image: Union[str, Path],
    threshold: float = 0.3,
    preview: bool = False,
) -> bool:
    # read and resize image
    img = cv2.imread(str(image))
    img = cv2.resize(img, (1024, 1024))

    # Find count of each color
    a = {}
    for row in img:
        for item in row:
            value = tuple(item)
            if value not in a:
                a[value] = 1
            else:
                a[value] += 1

    if preview:
        mask = numpy.zeros(img.shape[:2], dtype=bool)

        for color, _ in sorted(a.items(), key=lambda pair: pair[1], reverse=True)[:512]:
            mask |= (img == color).all(-1)

        img[~mask] = (255, 255, 255)

        cv2.imshow("img", img)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    # Identify the percent of the image that uses the top 512 colors
    most_common_colors = sum(
        [x[1] for x in sorted(a.items(), key=lambda pair: pair[1], reverse=True)[:512]]
    )
    return (most_common_colors / (1024 * 1024)) > threshold


def command_line_options():
    args = argparse.ArgumentParser(
        "blur_compare",
        description="Determine if an image is likely a cartoon or photo.",
    )
    args.add_argument(
        "-p",
        "--preview",
        action="store_true",
        help="Show the image reduced to its top colors",
    )
    args.add_argument(
        "-t",
        "--threshold",
        type=float,
        help="Cutoff threshold",
        default=0.3,
    )
    args.add_argument(
        "image",
        type=Path,
        help="Path to image file",
    )
    return vars(args.parse_args())


if __name__ == "__main__":
    options = command_line_options()
    if not options["image"].exists():
        raise FileNotFoundError(f"No image exists at {options['image'].absolute()}")
    if is_cartoon(**options):
        print(f"{options['image'].name} is a cartoon!")
    else:
        print(f"{options['image'].name} is a photo!")

OpenCV Edge Detection

I haven’t had much luck with this one: it is only about 55% successful at determining what type of image it is. A coin toss is about as accurate. I don’t recommend using it as is, but if you have any ideas or improvements, let me hear them!

#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from pathlib import Path
from typing import Union
import argparse

import cv2
import numpy


def is_cartoon(
    image: Union[str, Path],
    threshold: float = 4500,
    preview: bool = False,
) -> bool:
    # read and resize image
    img = cv2.imread(str(image))
    img = cv2.resize(img, (1024, 1024))

    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    blurred_gray = cv2.medianBlur(gray, 3)

    edges = cv2.adaptiveThreshold(
        gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 5, 10
    )
    blurred_edges = cv2.adaptiveThreshold(
        blurred_gray, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 5, 10
    )

    if preview:
        cv2.imshow("edges", edges)
        cv2.imshow("blurred edges", blurred_edges)
        cv2.waitKey(0)
        cv2.destroyAllWindows()

    count_1 = numpy.count_nonzero(edges)
    count_2 = numpy.count_nonzero(blurred_edges)

    return abs(count_2 - count_1) < threshold


def command_line_options():
    args = argparse.ArgumentParser(
        "blur_compare",
        description="Determine if an image is likely a cartoon or photo.",
    )
    args.add_argument(
        "-p",
        "--preview",
        action="store_true",
        help="Show the detected edges",
    )
    args.add_argument(
        "-t",
        "--threshold",
        type=float,
        help="Cutoff threshold",
        default=4500,
    )
    args.add_argument(
        "image",
        type=Path,
        help="Path to image file",
    )
    return vars(args.parse_args())


if __name__ == "__main__":
    options = command_line_options()
    if not options["image"].exists():
        raise FileNotFoundError(f"No image exists at {options['image'].absolute()}")
    if is_cartoon(**options):
        print(f"{options['image'].name} is a cartoon!")
    else:
        print(f"{options['image'].name} is a photo!")

Keras Machine Learning Model

Time to take the kid gloves off, let’s use some ML image detection.

First we have to train a minified Xception model with a lot of hand-picked good data. I used 556 cartoon images and 2295 real images in the training datasets. Then I used that model to sort another 2000+ unsorted images.

For step-by-step details, please check out the tutorial I used for this.

import shutil
from pathlib import Path

import tensorflow as tf

root_dir = Path("/training/")
image_size = (180, 180)
batch_size = 32
epochs = 20
model_name = "my_model"


def make_model(input_shape, num_classes):
    inputs = tf.keras.Input(shape=input_shape)
    # Image augmentation block
    data_augmentation = tf.keras.Sequential(
        [
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.1),
        ]
    )
    x = data_augmentation(inputs)

    # Entry block
    x = tf.keras.layers.Rescaling(1.0 / 255)(x)
    x = tf.keras.layers.Conv2D(32, 3, strides=2, padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)

    x = tf.keras.layers.Conv2D(64, 3, padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)

    previous_block_activation = x  # Set aside residual

    for size in [128, 256, 512, 728]:
        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.SeparableConv2D(size, 3, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)

        x = tf.keras.layers.Activation("relu")(x)
        x = tf.keras.layers.SeparableConv2D(size, 3, padding="same")(x)
        x = tf.keras.layers.BatchNormalization()(x)

        x = tf.keras.layers.MaxPooling2D(3, strides=2, padding="same")(x)

        # Project residual
        residual = tf.keras.layers.Conv2D(size, 1, strides=2, padding="same")(
            previous_block_activation
        )
        x = tf.keras.layers.add([x, residual])  # Add back residual
        previous_block_activation = x  # Set aside next residual

    x = tf.keras.layers.SeparableConv2D(1024, 3, padding="same")(x)
    x = tf.keras.layers.BatchNormalization()(x)
    x = tf.keras.layers.Activation("relu")(x)

    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    if num_classes == 2:
        activation = "sigmoid"
        units = 1
    else:
        activation = "softmax"
        units = num_classes

    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(units, activation=activation)(x)
    return tf.keras.Model(inputs, outputs)


def train_data():
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(root_dir),
        validation_split=0.2,
        subset="training",
        seed=1337,
        image_size=image_size,
        batch_size=batch_size,
    )
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        str(root_dir),
        validation_split=0.2,
        subset="validation",
        seed=1337,
        image_size=image_size,
        batch_size=batch_size,
    )

    train_ds = train_ds.prefetch(buffer_size=32)
    val_ds = val_ds.prefetch(buffer_size=32)

    model = make_model(input_shape=image_size + (3,), num_classes=2)
    tf.keras.utils.plot_model(model, show_shapes=True)

    callbacks = [
        tf.keras.callbacks.ModelCheckpoint(f"{model_name}_{{epoch}}.h5"),
    ]
    model.compile(
        optimizer=tf.keras.optimizers.Adam(1e-3),
        loss="binary_crossentropy",
        metrics=["accuracy"],
    )
    model.fit(
        train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds,
    )


def clean_images():
    move_to = root_dir.parent / "bad_data"  # needs to be not in the training directory
    move_to.mkdir(exist_ok=True)
    moved = 0
    for directory in root_dir.glob("*"):
        if directory.name == move_to.name:
            continue
        if directory.is_dir():
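            for file in directory.glob("*"):
                # Only JPEG files with a JFIF header are kept for training
                is_jpeg = file.name.lower().endswith(("jpg", "jpeg"))
                if is_jpeg:
                    # Close the file again before it is possibly moved
                    with file.open("rb") as handle:
                        is_jpeg = b"JFIF" in handle.read(10)
                if not is_jpeg: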
                    shutil.move(file, move_to / file.name)
                    moved += 1
        print("moved unclean data", moved, "from", directory)


def move_images():
    model = tf.keras.models.load_model(f"{model_name}_{epochs}.h5")
    cartoon_dir = Path("/cartoon/")
    real_dir = Path("/real/")

    real, cartoon, unknown = 0, 0, 0

    for file in Path("/unsorted/").iterdir():
        # Only consider common image extensions
        if file.suffix.lower() not in (".jpg", ".jpeg", ".png"):
            continue

        img = tf.keras.preprocessing.image.load_img(
            str(file), target_size=image_size
        )
        img_array = tf.keras.preprocessing.image.img_to_array(img)
        img_array = tf.expand_dims(img_array, 0)  # Create batch axis

        predictions = model.predict(img_array)
        score = predictions[0][0]
        if score > 0.98:
            real += 1
            shutil.move(file, real_dir / file.name)
        elif score < 0.02:
            cartoon += 1
            shutil.move(file, cartoon_dir / file.name)
        else:
            unknown += 1
            print(f"Could not figure out {file} as it was {score * 100}")
    print(f"Moved {real} to real and {cartoon} to cartoon, {unknown} were unmoved")


if __name__ == '__main__':
    clean_images()
    train_data()
    move_images()
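One thing the script takes for granted is the folder layout: image_dataset_from_directory expects one sub-folder per class under the training directory and, by default, numbers the classes in alphanumerical order of the folder names. The sketch below assumes folders named `cartoon` and `real` (hypothetical names, use your own) and shows how to check the image counts and which class ends up as 0 or 1 before training. The helper names are made up, and the demo builds a throwaway directory so it can run anywhere.

```python
import tempfile
from pathlib import Path
from typing import Dict


def count_per_class(root: Path) -> Dict[str, int]:
    """Number of files in each class sub-folder of `root`."""
    return {
        folder.name: sum(1 for item in folder.iterdir() if item.is_file())
        for folder in sorted(root.iterdir())
        if folder.is_dir()
    }


def label_indices(root: Path) -> Dict[str, int]:
    """Class name to label index, assuming alphanumerical ordering."""
    names = sorted(folder.name for folder in root.iterdir() if folder.is_dir())
    return {name: index for index, name in enumerate(names)}


# Demo on a temporary directory with empty placeholder files
with tempfile.TemporaryDirectory() as tmp:
    demo_root = Path(tmp)
    for name, amount in (("real", 3), ("cartoon", 2)):
        (demo_root / name).mkdir()
        for i in range(amount):
            (demo_root / name / f"{i}.jpg").write_bytes(b"")
    print(count_per_class(demo_root))  # {'cartoon': 2, 'real': 3}
    print(label_indices(demo_root))    # {'cartoon': 0, 'real': 1}
```

With that naming, a sigmoid score near 1 means "real" and a score near 0 means "cartoon", which is how move_images reads it.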

I personally think 95% overall accuracy is amazing for no additional tuning, especially since this is not exactly the model’s intended use case. Generally we think about classifying items in the image (cat vs dog), not the type of image (anime cat vs real cat).