kumitatepazuru's blog

中学生のメモブログ。みんなの役に立ちたい。

kerasでfit_generatorを実行していた時にエラーが出て手こずった話

f:id:kumitatepazuru:20200715181903p:plain

久しぶりの更新。stack overflowエラーが出て大変な思いをしたので書いておく。

ある日、激おそ深層学習をやってたらエラーが出まして。大変だった。

現状把握

エラー内容

以下の通り。

  6629/144508 [>.............................] - ETA: 3:30:58 - loss: 3.1317 - acc: 0.1687Traceback (most recent call last):
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/utils/data_utils.py", line 578, in get
    inputs = self.queue.get(block=True).get()
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/utils/data_utils.py", line 401, in get_index
    return _SHARED_SEQUENCES[uid][i]
  File "/home/***/PycharmProjects/img2speech/load_data.py", line 91, in __getitem__
    img, label = self._load_data()
  File "/home/***/PycharmProjects/img2speech/load_data.py", line 42, in _load_data
    text_size = tmp_d.textsize(text, self.font_data)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/PIL/ImageDraw.py", line 430, in textsize
    return font.getsize(text, direction, features, language, stroke_width)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/PIL/ImageFont.py", line 262, in getsize
    size, offset = self.font.getsize(text, False, direction, features, language)
OSError: stack overflow

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/***/PycharmProjects/img2speech/main.py", line 58, in <module>
    shuffle=True)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/models.py", line 1315, in fit_generator
    initial_epoch=initial_epoch)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/engine/training.py", line 2194, in fit_generator
    generator_output = next(output_generator)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/utils/data_utils.py", line 584, in get
    six.raise_from(StopIteration(e), e)
  File "<string>", line 3, in raise_from
StopIteration: stack overflow

Process finished with exit code 1

えぇぇぇ。まさかのstack overflow。

コード

最小限のコード。

main.py

# パッケージのインポート
import os

from keras.layers import BatchNormalization, Dense, Dropout
from keras.models import Sequential
from keras.optimizers import SGD
import matplotlib.pyplot as plt
from tqdm import tqdm

import load_data


# Gather the path of every file found anywhere below ./font.
train_paths = []
for root, dirs, files in tqdm(os.walk("./font")):
    train_paths.extend(root + "/" + name for name in files)

# The first 20% of the collected paths are held out for validation.
val_count = int(len(train_paths) * 0.2)

train_gen = load_data.Generator(train_paths[val_count:], batch_size=64)
val_gen = load_data.Generator(train_paths[:val_count], batch_size=64)

# Build the model: 32x32 flattened input -> 512 -> 256 -> 94 classes.
model = Sequential([
    Dense(512, activation='sigmoid', input_shape=(32**2,)),  # input layer
    BatchNormalization(),
    Dense(256, activation='sigmoid'),  # hidden layer
    Dropout(rate=0.5),  # dropout
    Dense(94, activation='softmax'),  # output layer
])

# Compile
model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(lr=0.1),
    metrics=['acc'],
)

# Train
history = model.fit_generator(
    train_gen,
    steps_per_epoch=train_gen.num_batches_per_epoch,
    validation_data=val_gen,
    validation_steps=val_gen.num_batches_per_epoch,
    epochs=100,
    shuffle=True,
)
model.save("model.h5")
# model = load_model("model.h5")

# Plot training and validation accuracy per epoch.
for metric in ('acc', 'val_acc'):
    plt.plot(history.history[metric], label=metric)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='best')
plt.show()

load_data.py

import numpy as np
import importlib
from PIL import Image, ImageDraw, ImageFont
from keras.utils import Sequence


class Generator(Sequence):
    """Keras ``Sequence`` that renders font glyphs into training batches.

    Each font file in ``data_paths`` is rendered once per character and per
    rotation.  Characters are the ``num_of_class`` code points starting at
    ``'!'`` (33); rotations are 0, 5, ..., 175 degrees.  Samples are ordered
    with the character changing fastest, then the rotation, then the font,
    and every sample is fully determined by its index, so batches may be
    requested in any order.
    """

    _ROTATION_STEP = 5                      # degrees between two rotations
    _NUM_ROTATIONS = 180 // _ROTATION_STEP  # rotations rendered per glyph
    _FIRST_CHAR_CODE = 33                   # code point of class 0 ('!')
    _FIT_MARGIN = 5                         # pixels a glyph must stay inside

    def __init__(self, data_paths, batch_size=1, width=32, height=32, font_size=32, num_of_class=94):
        """Set up the generator and load the first font.

        :param data_paths: List of font file paths
        :param batch_size: Batch size
        :param width: Image width
        :param height: Image height
        :param font_size: Point size used to load fonts; also the side length
            of the square canvas each glyph is drawn on
        :param num_of_class: Num of classes (characters rendered per font)
        """

        self.data_paths = data_paths
        self.length = len(data_paths) * num_of_class * self._NUM_ROTATIONS
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.font_size = font_size
        self.num_of_class = num_of_class
        # [font index, rotation index, character index] of the next sample.
        self.data_pos = [0, 0, 0]
        # Always the font that data_pos[0] refers to.
        self.font_data = ImageFont.truetype(self.data_paths[self.data_pos[0]], self.font_size)
        self.num_batches_per_epoch = int((self.length - 1) / batch_size) + 1

    def _seek(self, index):
        """Point ``data_pos`` (and ``font_data``) at sample ``index``."""
        per_font = self.num_of_class * self._NUM_ROTATIONS
        font_idx, rest = divmod(index, per_font)
        rot_idx, char_idx = divmod(rest, self.num_of_class)
        if font_idx != self.data_pos[0]:
            # Only reload when the font actually changes.
            self.font_data = ImageFont.truetype(self.data_paths[font_idx], self.font_size)
        self.data_pos = [font_idx, rot_idx, char_idx]

    def _load_data(self, index=None):
        """Render one sample.

        :param index: Sample index to render; ``None`` renders the sample
            ``data_pos`` currently points at

        :return img: 2-D float array holding the grayscale glyph
        :return label: Class index of the rendered character
        """
        if index is not None:
            self._seek(index % self.length)

        font_idx, rot_idx, char_idx = self.data_pos
        text = chr(char_idx + self._FIRST_CHAR_CODE)
        font_path = self.data_paths[font_idx]
        font_color = "white"
        rot = rot_idx * self._ROTATION_STEP

        # Measure the glyph on a throwaway 1x1 canvas.
        tmp = Image.new('RGBA', (1, 1), (0, 0, 0, 0))
        tmp_d = ImageDraw.Draw(tmp)
        font_data = self.font_data
        text_size = tmp_d.textsize(text, self.font_data)

        # Shrink the font until the glyph fits inside the canvas margin.
        # The lower bound on the size keeps the loop from reaching size 0.
        limit = self.font_size - self._FIT_MARGIN
        i = self.font_size
        while (text_size[0] > limit or text_size[1] > limit) and i > 1:
            i -= 1
            font_data = ImageFont.truetype(font_path, i)
            text_size = tmp_d.textsize(text, font_data)

        # Draw with the fitted font on a transparent canvas, then rotate.
        img = Image.new('RGBA', (self.font_size, self.font_size), (0, 0, 0, 0))
        img_d = ImageDraw.Draw(img)
        img_d.text((0, 0), text, fill=font_color, font=font_data)
        img = img.rotate(rot)

        # RGBA -> luma; the array channels are ordered R, G, B, A.
        img = np.array(img)
        img = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2]

        # Advance to the following sample, wrapping after the last one so the
        # cursor never runs past the final font.
        current = (font_idx * self._NUM_ROTATIONS + rot_idx) * self.num_of_class + char_idx
        self._seek((current + 1) % self.length)

        return img, char_idx

    def __getitem__(self, idx) -> tuple:
        """Get batch data

        :param idx: Index of batch

        :return imgs: float32 array of flattened images, one row per sample
        :return labels: int16 one-hot array of labels, one row per sample
        """

        start_pos = self.batch_size * idx
        end_pos = min(start_pos + self.batch_size, self.length)
        # The last batch may be shorter than batch_size.
        count = max(end_pos - start_pos, 0)

        imgs = np.empty((count, self.height, self.width), dtype=np.float32)
        labels = np.zeros((count, self.num_of_class), dtype=np.int16)

        for i in range(count):
            img, label = self._load_data(start_pos + i)
            imgs[i] = img
            labels[i, label] = 1

        # Flatten every image into a single feature vector.
        imgs = imgs.reshape((count, self.height * self.width))

        return imgs, labels

    def __len__(self):
        """Batch length"""

        return self.num_batches_per_epoch

    def on_epoch_end(self):
        """Task when end of epoch (nothing to do: samples are addressed by index)"""
        pass

さあどうしよう

  • keras側の呼び出しが問題なのはわかった。
  • importlibでリロードして回避しようと思ったけど無理。
  • while Trueでやる方法もあるけど格好悪いし、そもそもyieldがわからん。

1時間考えた結果……思い出した。エラーの箇所は、フォントが大きすぎてPillowが扱えないんだった。Google Fonts恐るべし。エラーの箇所を次のように変更する。

try:
    text_size = tmp_d.textsize(text, self.font_data)
except OSError:
    # Pillow could not measure this glyph.  Assign an explicit fallback:
    # with a bare `pass` the name `text_size` would stay unbound and the
    # next statement that reads it would raise UnboundLocalError.
    # (0, 0) means "nothing to shrink", so the size-fitting loop is skipped.
    text_size = (0, 0)

これで大丈夫。

最後に

この頃忙しすぎ。


個人的な質問等はこちらまで。

https://forms.gle/V6NRhoTooFw15hJdA

また、自分が参加しているRoboCup Soccerシミュレーションリーグのチームでは参加者募集中です!活動の見学、活動に参加したい方、ご連絡お待ちしております!

詳しくはこちら

kumitatepazuru.github.io