F beta score for Keras

Posted on Mon 24 April 2017 • Tagged with Deep Learning, Code Snippets, Kaggle, Keras

I'm a newbie in deep learning, so it makes sense to practice a bit on simple problems first. So I've decided to join the Kaggle contest Planet: Understanding the Amazon from Space. The contest looks promising and not too complicated compared to other deep learning competitions: there is not too much data, no image segmentation or localization is required, and the data is properly preprocessed. The goal is to assign tags (e.g. forest, agriculture, roads) to satellite image patches. A nice problem for a beginner, isn't it?

satellite image samples

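The competition uses the F2 score as its metric, a special case of the F beta score with beta = 2. The F1 score is the more popular choice, so I wonder why they chose beta = 2; a beta above 1 weights recall more heavily than precision, which may be the reason. The metric is available in scikit-learn, but that implementation works on NumPy arrays rather than on backend tensors, so it can't be plugged into a Keras model as a metric during training. So it makes sense to write one for Keras, which is the easiest DL framework for prototyping. The code was tested with the TensorFlow backend for Keras.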
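To get a feel for what beta = 2 changes, here is a minimal pure-Python sketch that computes F beta directly from counts of true positives, false positives and false negatives. The helper name `fbeta_from_counts` and the counts are made up for illustration; they are not part of the competition or of any library.

```python
def fbeta_from_counts(tp, fp, fn, beta=2.0):
    """F-beta computed from raw counts (hypothetical helper for illustration)."""
    if tp == 0:
        return 0.0  # no true positives: define the score as zero
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    beta_squared = beta ** 2
    return (beta_squared + 1) * precision * recall / (beta_squared * precision + recall)


# Made-up counts: two classifiers with mirrored precision and recall.
recall_heavy = dict(tp=8, fp=8, fn=2)      # precision 0.5, recall 0.8
precision_heavy = dict(tp=8, fp=2, fn=8)   # precision 0.8, recall 0.5

for name, counts in [("recall-heavy", recall_heavy),
                     ("precision-heavy", precision_heavy)]:
    f1 = fbeta_from_counts(beta=1.0, **counts)
    f2 = fbeta_from_counts(beta=2.0, **counts)
    print("{}: F1 = {:.3f}, F2 = {:.3f}".format(name, f1, f2))
```

Both classifiers get the same F1, but F2 is noticeably higher for the recall-heavy one, which is exactly the effect of choosing beta = 2.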

import numpy as np
from sklearn.metrics import fbeta_score
from keras import backend as K

def fbeta(y_true, y_pred, threshold_shift=0):
    beta = 2

    # just in case of hipster activation at the final layer
    y_pred = K.clip(y_pred, 0, 1)

    # shifting the prediction threshold from .5 if needed
    y_pred_bin = K.round(y_pred + threshold_shift)

    tp = K.sum(K.round(y_true * y_pred_bin)) + K.epsilon()
    fp = K.sum(K.round(K.clip(y_pred_bin - y_true, 0, 1)))
    # use the binarized predictions here too, so threshold_shift also affects false negatives
    fn = K.sum(K.round(K.clip(y_true - y_pred_bin, 0, 1)))

    precision = tp / (tp + fp)
    recall = tp / (tp + fn)

    beta_squared = beta ** 2
    return (beta_squared + 1) * (precision * recall) / (beta_squared * precision + recall)

y_true, y_pred = np.round(np.random.rand(100)), np.random.rand(100)

# K.eval evaluates the tensor through the backend, no explicit session handling needed
fbeta_keras = K.eval(fbeta(K.variable(y_true), K.variable(y_pred)))
fbeta_sklearn = fbeta_score(y_true, np.round(y_pred), beta=2)

print('Scores are {:.3f} (sklearn) and {:.3f} (keras)'.format(fbeta_sklearn, fbeta_keras))
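The `threshold_shift` argument deserves a small illustration: since predictions are binarized with `round(p + threshold_shift)`, a positive shift effectively lowers the cut-off from 0.5 to roughly 0.5 - threshold_shift. Below is a rough pure-Python sketch of that effect, with no Keras involved. The labels, the probabilities and the helper names `binarize` and `f2_score` are made up for the example.

```python
def binarize(probabilities, threshold_shift=0.0):
    # Same trick as in fbeta: clip to [0, 1], add the shift, then round.
    return [int(round(min(max(p, 0.0), 1.0) + threshold_shift))
            for p in probabilities]


def f2_score(y_true, y_bin):
    # Count true positives, false positives and false negatives.
    tp = sum(1 for t, p in zip(y_true, y_bin) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_bin) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_bin) if t == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 5 * precision * recall / (4 * precision + recall)


# Made-up labels and predicted probabilities.
y_true = [1, 1, 1, 1, 0, 0, 0, 0]
y_prob = [0.9, 0.6, 0.4, 0.35, 0.45, 0.2, 0.1, 0.7]

f2_default = f2_score(y_true, binarize(y_prob))                       # cut-off 0.5
f2_shifted = f2_score(y_true, binarize(y_prob, threshold_shift=0.2))  # cut-off ~0.3

print('F2 is {:.3f} (no shift) and {:.3f} (shift 0.2)'.format(f2_default, f2_shifted))
```

On this toy data the shifted threshold recovers the two missed positives at the cost of one extra false positive, and because F2 favours recall the score goes up. On real data the best shift has to be found on a validation set.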

Same code at Kaggle