kerasでfit_generatorを実行していた時にエラーが出て手こずった話
久しぶりの更新。stack overflowエラーが出て大変な思いをしたので書いておく。
ある日、激おそ深層学習をやってたらエラーが出まして。大変だった。
現状把握
エラー内容
以下の通り。
6629/144508 [>.............................] - ETA: 3:30:58 - loss: 3.1317 - acc: 0.1687
Traceback (most recent call last):
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/utils/data_utils.py", line 578, in get
    inputs = self.queue.get(block=True).get()
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/pool.py", line 644, in get
    raise self._value
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/multiprocessing/pool.py", line 119, in worker
    result = (True, func(*args, **kwds))
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/utils/data_utils.py", line 401, in get_index
    return _SHARED_SEQUENCES[uid][i]
  File "/home/***/PycharmProjects/img2speech/load_data.py", line 91, in __getitem__
    img, label = self._load_data()
  File "/home/***/PycharmProjects/img2speech/load_data.py", line 42, in _load_data
    text_size = tmp_d.textsize(text, self.font_data)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/PIL/ImageDraw.py", line 430, in textsize
    return font.getsize(text, direction, features, language, stroke_width)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/PIL/ImageFont.py", line 262, in getsize
    size, offset = self.font.getsize(text, False, direction, features, language)
OSError: stack overflow

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/***/PycharmProjects/img2speech/main.py", line 58, in <module>
    shuffle=True)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/models.py", line 1315, in fit_generator
    initial_epoch=initial_epoch)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/legacy/interfaces.py", line 91, in wrapper
    return func(*args, **kwargs)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/engine/training.py", line 2194, in fit_generator
    generator_output = next(output_generator)
  File "/home/***/.pyenv/versions/3.6.4/lib/python3.6/site-packages/keras/utils/data_utils.py", line 584, in get
    six.raise_from(StopIteration(e), e)
  File "<string>", line 3, in raise_from
StopIteration: stack overflow

Process finished with exit code 1
えぇぇぇ。まさかのstack overflow。
コード
最小限コード。
main.py
# Package imports
import os

from keras.layers import BatchNormalization, Dense, Dropout
from keras.models import Sequential
from keras.optimizers import SGD
import matplotlib.pyplot as plt
from tqdm import tqdm

import load_data

# Collect the path of every file found under ./font; tqdm reports progress
# once per directory visited by os.walk.
train_paths = [
    root + "/" + name
    for root, _dirs, files in tqdm(os.walk("./font"))
    for name in files
]

# The first 20% of the collected paths are held out for validation.
val_count = int(len(train_paths) * 0.2)
train_gen = load_data.Generator(train_paths[val_count:], batch_size=64)
val_gen = load_data.Generator(train_paths[:val_count], batch_size=64)

# Build the model: a small fully connected classifier over flattened 32x32
# images with 94 output classes.
model = Sequential([
    Dense(512, activation='sigmoid', input_shape=(32**2,)),  # input layer
    BatchNormalization(),
    Dense(256, activation='sigmoid'),  # hidden layer
    Dropout(rate=0.5),  # dropout
    Dense(94, activation='softmax'),  # output layer
])

# Compile
model.compile(
    loss='categorical_crossentropy',
    optimizer=SGD(lr=0.1),
    metrics=['acc'],
)

# Train
history = model.fit_generator(
    train_gen,
    steps_per_epoch=train_gen.num_batches_per_epoch,
    validation_data=val_gen,
    validation_steps=val_gen.num_batches_per_epoch,
    epochs=100,
    shuffle=True)
model.save("model.h5")
# A saved model can be restored later with: model = load_model("model.h5")

# Plot training and validation accuracy per epoch.
for metric in ('acc', 'val_acc'):
    plt.plot(history.history[metric], label=metric)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(loc='best')
plt.show()
load_data.py
import numpy as np
import importlib
from PIL import Image, ImageDraw, ImageFont
from keras.utils import Sequence


class Generator(Sequence):
    """Keras ``Sequence`` that renders font glyphs into training batches.

    Each font file is rendered once per character (``num_of_class``
    consecutive characters starting at ``"!"``, code point 33) and per
    rotation (0 to 175 degrees in 5-degree steps).

    Samples are produced sequentially by an internal cursor (``data_pos``);
    the ``idx`` passed to ``__getitem__`` only determines how many samples
    the batch holds, not which samples are rendered.

    NOTE(review): each glyph is drawn on a ``font_size`` x ``font_size``
    canvas and copied into a ``height`` x ``width`` slot, so ``width`` and
    ``height`` must equal ``font_size``.
    """

    _FIRST_CHAR = 33          # code point of "!", the first rendered character
    _ROTATION_STEP = 5        # degrees between consecutive rotations
    _NUM_ROTATIONS = 180 // 5  # rotations per character: 0, 5, ..., 175

    def __init__(self, data_paths, batch_size=1, width=32, height=32,
                 font_size=32, num_of_class=94):
        """Set up the cursor and load the first font.

        :param data_paths: List of font file paths
        :param batch_size: Batch size
        :param width: Image width
        :param height: Image height
        :param font_size: Size used to load each font and side length of the
            square canvas each glyph is drawn on
        :param num_of_class: Num of classes (characters rendered per font)
        """
        self.data_paths = data_paths
        self.length = len(data_paths) * num_of_class * self._NUM_ROTATIONS
        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.font_size = font_size
        self.num_of_class = num_of_class
        # Cursor: [font index, rotation index, character index]
        self.data_pos = [0, 0, 0]
        self.font_data = ImageFont.truetype(
            self.data_paths[self.data_pos[0]], self.font_size)
        self.num_batches_per_epoch = int((self.length - 1) / batch_size) + 1

    def _advance(self):
        """Move the cursor to the next (font, rotation, character) sample."""
        self.data_pos[2] += 1
        if self.data_pos[2] < self.num_of_class:
            return

        # All characters done for this rotation: go to the next rotation.
        self.data_pos[2] = 0
        self.data_pos[1] += 1
        if self.data_pos[1] < self._NUM_ROTATIONS:
            return

        # All rotations done for this font: go to the next font.
        self.data_pos[1] = 0
        # NOTE(review): these reloads look like a leftover attempt to work
        # around a Pillow failure; they presumably do not reset any native
        # Pillow state and can likely be removed -- confirm before deleting.
        importlib.reload(np)
        importlib.reload(Image)
        importlib.reload(ImageDraw)
        importlib.reload(ImageFont)
        importlib.reload(importlib)
        # Wrap around so the cursor never runs past the last font.
        self.data_pos[0] = (self.data_pos[0] + 1) % len(self.data_paths)
        self.font_data = ImageFont.truetype(
            self.data_paths[self.data_pos[0]], self.font_size)

    def _load_data(self):
        """Render the sample under the cursor, then advance the cursor.

        :return img: 2-D float array holding the grayscale glyph image
        :return label: Class index of the character that was rendered
        :raises OSError: Propagated from Pillow when a font cannot be
            loaded or a glyph cannot be measured
        """
        font_idx, rot_idx, char_idx = self.data_pos
        text = chr(char_idx + self._FIRST_CHAR)
        font_path = self.data_paths[font_idx]
        font_color = "white"
        rot = rot_idx * self._ROTATION_STEP

        # Measure the glyph on a throwaway 1x1 image.
        tmp = Image.new('RGBA', (1, 1), (0, 0, 0, 0))
        tmp_d = ImageDraw.Draw(tmp)
        text_size = tmp_d.textsize(text, self.font_data)

        # Shrink the font until the glyph fits with a 5 px margin, and keep
        # the fitted font so it is the one actually used for drawing.
        font_data = self.font_data
        limit = self.font_size - 5
        i = self.font_size
        while i > 1 and (text_size[0] > limit or text_size[1] > limit):
            i -= 1
            font_data = ImageFont.truetype(font_path, i)
            text_size = tmp_d.textsize(text, font_data)

        # Draw the glyph on a transparent canvas and rotate it.
        img = Image.new('RGBA', (self.font_size, self.font_size), (0, 0, 0, 0))
        img_d = ImageDraw.Draw(img)
        img_d.text((0, 0), text, fill=font_color, font=font_data)
        img = img.rotate(rot)

        self._advance()

        # RGBA arrays are ordered R, G, B, A: apply the luma weights to
        # R, G and B respectively.
        img = np.array(img)
        img = 0.299 * img[:, :, 0] + 0.587 * img[:, :, 1] + 0.114 * img[:, :, 2]
        # The label is the index of the character drawn above, captured
        # before the cursor moved on.
        return img, char_idx

    def __getitem__(self, idx) -> tuple:
        """Get batch data

        :param idx: Index of batch
        :return imgs: float32 array of flattened images, one row per sample
        :return labels: int16 array of one-hot labels, one row per sample
        """
        start_pos = self.batch_size * idx
        end_pos = min(start_pos + self.batch_size, self.length)
        # The final batch of an epoch may hold fewer than batch_size samples.
        batch_len = max(end_pos - start_pos, 0)

        imgs = np.empty((batch_len, self.height, self.width), dtype=np.float32)
        labels = np.zeros((batch_len, self.num_of_class), dtype=np.int16)

        for i in range(batch_len):
            img, label = self._load_data()
            imgs[i, :] = img
            labels[i][label] = 1

        # Flatten each image into a vector for the dense input layer.
        imgs = imgs.reshape((imgs.shape[0], self.height * self.width))

        return imgs, labels

    def __len__(self):
        """Batch length"""
        return self.num_batches_per_epoch

    def on_epoch_end(self):
        """Task when end of epoch"""
        pass
さあどうしよう
- keras側の呼び出しが問題なのはわかった。
- importlibでリロードして回避しようと思ったけど無理。
- while True でやる方法もあるけど格好悪いし、そもそもyieldがわからん。
1時間考えた結果……思い出した。エラーの部分は、フォントが大きすぎてPillowが扱えないんだった。Google Fonts恐るべし。エラーの部分をこうする。
try:
    text_size = tmp_d.textsize(text, self.font_data)
except OSError:
    # Pillow could not measure this glyph with the current font.  Fall back
    # to a zero size so that `text_size` is always bound afterwards; a bare
    # `pass` would leave it undefined when the measurement fails.
    text_size = (0, 0)
これで大丈夫。
最後に
この頃忙しすぎ。
個人的な質問等はこちらまで。
https://forms.gle/V6NRhoTooFw15hJdA
また、自分が参加しているRoboCup Soccer シミュレーションリーグのチームでは参加者募集中です！活動の見学、活動に参加したい方、ご連絡お待ちしております！
詳しくはこちら