входные данные исчерпаны, генератор может генерировать как минимум пакеты «steps_per_epoch * epochs». Возможно, вам понадобится использовать функцию repeat ()? - PullRequest
0 голосов
/ 01 августа 2020

Я пытаюсь запустить простой код, но он обучен только для одной эпохи и остановился.

Можете ли вы дать мне решение?

Мой полный код ниже как простой код, код basi c.

Предупреждение находится самое большее ниже.

большая часть кода работает хорошо, но функция соответствия также работает, но ее недостаточно.

import numpy as np
import os
import PIL
import PIL.Image
import tensorflow as tf
print(tf.__version__)

import pathlib
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url, 
                                   fname='flower_photos', 
                                   untar=True)
data_dir = pathlib.Path(data_dir)

image_count = len(list(data_dir.glob('*/*.jpg')))
print(image_count)

roses = list(data_dir.glob('roses/*'))
PIL.Image.open(str(roses[0]))

batch_size = 32
img_height = 180
img_width = 180

train_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="training",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size
    )

num_train = len(np.concatenate([i for x, i in train_ds], axis=0))

val_ds = tf.keras.preprocessing.image_dataset_from_directory(
    data_dir,
    validation_split=0.2,
    subset="validation",
    seed=123,
    image_size=(img_height, img_width),
    batch_size=batch_size)

num_test = len(np.concatenate([i for x, i in val_ds], axis=0))


class_names = train_ds.class_names
print(class_names)

import matplotlib.pyplot as plt

plt.figure(figsize=(10, 10))
for images, labels in train_ds.take(1):
    for i in range(9):
        ax = plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy().astype("uint8"))
        plt.title(class_names[labels[i]])
        plt.axis("off")

for image_batch, labels_batch in train_ds:
    print(image_batch.shape)
    print(labels_batch.shape)
    break

from tensorflow.keras import layers

normalization_layer = tf.keras.layers.experimental.preprocessing.Rescaling(1./255)

normalized_ds = train_ds.map(lambda x, y: (normalization_layer(x), y))
image_batch, labels_batch = next(iter(normalized_ds))
first_image = image_batch[0]
# Notice the pixels values are now in `[0,1]`.
print(np.min(first_image), np.max(first_image)) 

nor_val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y))


# Convolutional Neural Network

import tensorflow as tf
from tensorflow.keras import datasets, layers, models

model  = models.Sequential()
model.add(layers.Conv2D(32,(3,3),activation='relu',input_shape = (180, 180,3)))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Conv2D(64,(3,3),activation='relu'))
model.add(layers.MaxPooling2D((2,2)))
model.add(layers.Conv2D(64,(3,3),activation='relu'))
model.add(layers.Flatten())
model.add(layers.Dense(64,activation = 'relu'))
model.add(layers.Dense(5,activation = 'softmax'))

model.summary()

model.compile(optimizer = 'adam', loss = 'sparse_categorical_crossentropy', metrics=['accuracy'])

histoy = model.fit(normalized_ds,
        steps_per_epoch= (num_train//batch_size),
        epochs=20,
        shuffle=True,
        validation_data=nor_val_ds,  
        validation_steps = (num_test//batch_size) 
       )

В процессе обучения сеть выводит:

Epoch 1/20
91/91 [==============================] - 63s 696ms/step - loss: 0.8712 - accuracy: 0.6672 - val_loss: 0.9402 - val_accuracy: 0.6293
Epoch 2/20
 1/91 [..............................] - ETA: 0s - loss: 0.8736 - accuracy: 0.6667WARNING:tensorflow:Your input ran out of data; interrupting training. Make sure that your dataset or generator can generate at least `steps_per_epoch * epochs` batches (in this case, 1820 batches). You may need to use the repeat() function when building your dataset.
 1/91 [..............................] - 4s 4s/step - loss: 0.8736 - accuracy: 0.6667 - val_loss: 0.9902 - val_accuracy: 0.6179
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...