Hard Examples Mining in Keras

Posted on Mon 23 October 2017 • Tagged with Deep Learning, Code Snippets, Weird Code

In deep learning, one usually works through a framework's high-level interface. No need to compute gradients by hand: just stack layers in your favorite Keras.

The bad thing: as soon as the framework is a black box, black-magic effects appear. Today I faced one such weird effect while implementing a kind of hard example mining.

Let's have a look:

from os import environ

import numpy as np
from keras.models import Model
from keras.layers import Dense, Input

environ['CUDA_VISIBLE_DEVICES'] = ''
# run on CPU: my GPUs are busy fitting some Kaggle datasets

def gen(batch_size):
    while True:
        x_data = np.random.rand(batch_size, 100)
        y_data = np.random.rand(batch_size, 2)
        yield x_data, y_data


def mine_hard_samples(model, datagen, batch_size):
    while True:
        samples, targets = [], []
        # collect samples the model currently predicts badly
        while len(samples) < batch_size:
            x_data, y_data = next(datagen)
            preds = model.predict(x_data)
            errors = np.abs(preds - y_data).max(axis=-1) > .99
            samples += x_data[errors].tolist()
            targets += y_data[errors].tolist()

        # pad with regular samples up to two full batches
        regular_samples = batch_size * 2 - len(samples)
        x_data, y_data = next(datagen)
        samples += x_data[:regular_samples].tolist()
        targets += y_data[:regular_samples].tolist()

        samples, targets = map(np.array, (samples, targets))

        # shuffle hard and regular samples together, yield two batches
        idx = np.arange(batch_size * 2)
        np.random.shuffle(idx)
        batch1, batch2 = np.split(idx, 2)
        yield samples[batch1], targets[batch1]
        yield samples[batch2], targets[batch2]


def make_model():
    inp = Input((100,))
    x = Dense(10)(inp)
    out = Dense(2)(x)

    model = Model(inputs=inp, outputs=out)
    model.compile(optimizer='adam', loss='mse')
    return model


def valid_main():
    model = make_model()

    x, _ = next(gen(64))
    model.predict(x)
    # magic that helps

    model.fit_generator(mine_hard_samples(model, gen(64), 64),
                        steps_per_epoch=10)


def invalid_main():
    model = make_model()
    model.fit_generator(mine_hard_samples(model, gen(64), 64),
                        steps_per_epoch=10)


if __name__ == '__main__':
    valid_main()
    invalid_main()
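The mining criterion used in mine_hard_samples can be isolated into a plain NumPy sketch (the .99 threshold mirrors the one above; the toy arrays are made up for illustration):

```python
import numpy as np

def select_hard(preds, targets, threshold=.99):
    """Boolean mask of 'hard' samples: those whose worst
    per-output absolute error exceeds the threshold."""
    errors = np.abs(preds - targets).max(axis=-1)
    return errors > threshold

preds = np.array([[0.1, 0.2], [0.9, 0.1]])
targets = np.array([[0.1, 0.3], [0.9, 1.5]])
mask = select_hard(preds, targets)
# first sample: max error 0.1 (easy); second: 1.4 (hard)
```

In the generator above, the same boolean mask indexes both x_data and y_data, so only hard samples accumulate until a batch is full.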

There is no need to walk through the whole code; just focus on the last two functions. For some mysterious reason, calling predict inside the hard-sample generator raises a ValueError. Yet a single predict call made beforehand is enough to make everything work.

In an ideal world, I would trace this down into the deep core of Keras and TensorFlow. (The likely culprit: fit_generator consumes the generator from a worker thread, while Keras builds its predict function lazily on the first predict call, so the dummy call presumably just forces that build to happen in the main thread.) In the real world, I'm happy enough that this weird fix works.