У меня есть следующий код, взятый непосредственно из здесь с некоторыми довольно небольшими изменениями:
import pandas as pd
import torch
import json
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config
from torch import cuda
df = pd.read_pickle('df_final.pkl')
model = T5ForConditionalGeneration.from_pretrained('t5-base')
tokenizer = T5Tokenizer.from_pretrained('t5-base')
device = 'cuda' if cuda.is_available() else 'cpu'
text = ''.join(df[(df['col1'] == 'type') & (df['col2'] == 2)].col3.to_list())
preprocess_text = text.strip().replace("\n","")
t5_prepared_Text = "summarize: "+preprocess_text
#print ("original text preprocessed: \n", preprocess_text)
tokenized_text = tokenizer.encode(t5_prepared_Text, return_tensors="pt", max_length = 500000).to(device)
# summmarize
summary_ids = model.generate(tokenized_text,
num_beams=4,
no_repeat_ngram_size=2,
min_length=30,
max_length=100,
early_stopping=True)
output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
print ("\n\nSummarized text: \n",output)
При выполнении части model_generate()
я получаю такую ошибку:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-12-e8e9819a85dc> in <module>
12 min_length=30,
13 max_length=100,
---> 14 early_stopping=True).to(device)
15
16 output = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
~\Anaconda3\lib\site-packages\torch\autograd\grad_mode.py in decorate_no_grad(*args, **kwargs)
47 def decorate_no_grad(*args, **kwargs):
48 with self:
---> 49 return func(*args, **kwargs)
50 return decorate_no_grad
51
~\Anaconda3\lib\site-packages\transformers\generation_utils.py in generate(self, input_ids, max_length, min_length, do_sample, early_stopping, num_beams, temperature, top_k, top_p, repetition_penalty, bad_words_ids, bos_token_id, pad_token_id, eos_token_id, length_penalty, no_repeat_ngram_size, num_return_sequences, attention_mask, decoder_start_token_id, use_cache, **model_specific_kwargs)
383 encoder = self.get_encoder()
384
--> 385 encoder_outputs: tuple = encoder(input_ids, attention_mask=attention_mask)
386
387 # Expand input ids if num_beams > 1 or num_return_sequences > 1
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~\Anaconda3\lib\site-packages\transformers\modeling_t5.py in forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, past_key_value_states, use_cache, output_attentions, output_hidden_states, return_dict)
701 if inputs_embeds is None:
702 assert self.embed_tokens is not None, "You have to intialize the model with valid token embeddings"
--> 703 inputs_embeds = self.embed_tokens(input_ids)
704
705 batch_size, seq_length = input_shape
~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in __call__(self, *input, **kwargs)
539 result = self._slow_forward(*input, **kwargs)
540 else:
--> 541 result = self.forward(*input, **kwargs)
542 for hook in self._forward_hooks.values():
543 hook_result = hook(self, input, result)
~\Anaconda3\lib\site-packages\torch\nn\modules\sparse.py in forward(self, input)
112 return F.embedding(
113 input, self.weight, self.padding_idx, self.max_norm,
--> 114 self.norm_type, self.scale_grad_by_freq, self.sparse)
115
116 def extra_repr(self):
~\Anaconda3\lib\site-packages\torch\nn\functional.py in embedding(input, weight, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse)
1482 # remove once script supports set_grad_enabled
1483 _no_grad_embedding_renorm_(weight, input, max_norm, norm_type)
-> 1484 return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
1485
1486
RuntimeError: Expected object of device type cuda but got device type cpu for argument #1 'self' in call to _th_index_select
Я искал эту ошибку и нашел некоторые другие темы, такие как this one и this one, но они мне не очень помогли, так как их случай кажется полностью разные. В моем случае не было создано никаких пользовательских экземпляров или классов, поэтому я не знаю, как исправить это или откуда возникла ошибка.
Не могли бы вы сказать мне, откуда возникает ошибка и как я могу исправить?
Заранее большое спасибо.