Как я могу использовать метод .predict для модели tenorflow, которую я скачал с FTP-сервера? - PullRequest
0 голосов
/ 25 марта 2020

Я создал модель тензорного потока, загрузил ее на FTP-сервер и затем загрузил. Как я могу использовать метод .predict для этой загруженной модели?

Это функция для загрузки на FTP

def upload_model(self):
    domain_name, username, password = self._get_ftp_credentials()

    ftp = FTP(domain_name)
    ftp.login(user=username,
              passwd=password)

    ftp.cwd('chatbot_models')

    with open('models/models/model.h5', 'rb') as fp:
        ftp.storlines("STOR " + 'model.h5', fp)

    ftp.quit()

Это функция для загрузки с FTP

def _get_model(self):
    domain_name, username, password = self._get_ftp_credentials()

    ftp = FTP(domain_name)
    ftp.login(user=username,
              passwd=password)

    ftp.cwd('chatbot_models')

    filename = 'model.h5'

    model = open(filename, 'wb')
    ftp.retrbinary('RETR ' + filename, model.write, 1024)

    ftp.quit()

    return model

Когда я пытаюсь использовать метод .predict с загруженной моделью, он говорит:

AttributeError: '_io.BufferedWriter' объект не имеет атрибута'gnast '

Полный код:

class QuestionAnswerer(DataProcessor):
    def __init__(self, json_data):
        super().__init__(json_data)
        self.model = self._get_model()
        self.tokenizer = self.create_tokenizer_obj()

    def _get_ftp_credentials(self):
        env_path = Path('../env') / '.env'
        load_dotenv(verbose=True, dotenv_path=env_path)

        domain_name = os.getenv("FTP_HOST")
        ftp_username = os.getenv("FTP_USERNAME")
        ftp_password = os.getenv("FTP_PASSWORD")

        return domain_name, ftp_username, ftp_password

    def _get_model(self):
        domain_name, username, password = self._get_ftp_credentials()

        with FTP(domain_name, username, password) as ftp:
            ftp.cwd('chatbot_models')

            with open('model.h5', 'wb') as file:
                ftp.retrbinary('RETR ' + 'model.h5', file.write)

        return file

    def get_answer(self, question):
        tags = self.get_tags()

        words = [question]
        X_test_tokens = self.tokenizer.texts_to_sequences(words)

        X_test_pad = pad_sequences(X_test_tokens, maxlen=len(tags), padding='post')
        results = self.model.predict(X_test_pad)

Вот как я тестирую свой класс и мою модель:

from socialworks_neural_networks import QuestionAnswerer
from tensorflow.keras.models import load_model
import json

with open('questionss_responses.json') as file:
    data = json.load(file)

model = QuestionAnswerer(data)
model._get_model()

model_1 = load_model('model.h5')
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...