Ошибка: AttributeError: модуль 'transformers' не имеет атрибута 'TFBertModel' - PullRequest
1 голос
/ 09 марта 2020

Я применяю трансферное обучение с помощью python framework (PyTorch). Я получаю приведенную ниже ошибку при загрузке предварительно обученной модели PyTorch в Google Colab. После изменения кода 1 на код 2 я получил ту же ошибку.

CODE 1:  BertModel.from_pretrained
CODE 2: TFBertModel.from_pretrained
Error: AttributeError: module 'transformers' has no attribute 'TFBertModel'

Я попытался выполнить поиск по inte rnet, но не нашел никакого полезного контента.

1 Ответ

1 голос
/ 09 марта 2020

Вероятно, вам следует перечислить доступный пакет с его версией в вашем python и вашей ссылке на Colab, поскольку TFBertModel доступен только при наличии тензорного потока.

Чтобы воспроизвести вашу ошибку. Я играю в Colab следующим образом:

  1. Нет tensorflow вызывает ошибку при импорте TFBertModel
!pip install transformers
from transformers import BertModel, TFBertModel # no attribute 'TFBertModel'
!pip install tensorflow-gpu
from transformers import BertModel, TFBertModel # good to go
Непосредственно используйте BertModel
!pip install transformers
from transformers import BertModel
BertModel.from_pretrained # good to go

В результате моего тестирования вам, вероятно, следует проверить, импортируете ли вы TFBertModel, пока тензор потока удален.

Трансформаторы в основной ветке импортируют только TFBertModel if is_tf_available(), установленное в True. Вот код для if_is_tf_available():

# transformers/src/transformers/file_utils.py 
# >>> 107 lines
def is_tf_available():
    return _tf_available

# >>> 48 lines
try:
    USE_TF = os.environ.get("USE_TF", "AUTO").upper()
    USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper()

    if USE_TF in ("1", "ON", "YES", "AUTO") and USE_TORCH not in ("1", "ON", "YES"):
        import tensorflow as tf

        assert hasattr(tf, "__version__") and int(tf.__version__[0]) >= 2
        _tf_available = True  # pylint: disable=invalid-name
        logger.info("TensorFlow version {} available.".format(tf.__version__))
    else:
        logger.info("Disabling Tensorflow because USE_TORCH is set")
        _tf_available = False
except (ImportError, AssertionError):
    _tf_available = False  # pylint: disable=invalid-name
...