class TfidfEmbeds(BaseEmbeds):
# Input just placeholder
def __init__(self, _=None):
super().__init__()
self.tfidf = TfidfVectorizer().fit([" ".join(i) for i in train_data])
# Need these to map features
self.tfidf_word_list = self.tfidf.get_feature_names()
self.word_to_ix = {word: i for i, word in enumerate(self.tfidf_word_list)}
self.cache = {}
def forward(self, X):
# Cache
key = tuple(X)
if key in self.cache:
return cache[key]
# Get text
words = [word_list[i] for i in X.cpu().numpy()]
idx = [self.word_to_ix[word] if word in self.word_to_ix else 0 for word in words]
mask = [False if word in self.word_to_ix else True for word in words]
outputs = self.tfidf.transform([" ".join(words)])[:, idx].toarray().reshape(-1, 1)
outputs[mask] = 0
self.cache[key] = outputs
return torch.from_numpy(outputs).to(device).float()
def get_output_dim(self):
return 1
У меня есть приведенный выше код python для кодирования функции tfidf для списка документов. Я заключил код в Module
, но, очевидно, torch не совсем подходит, поскольку я не использую компоненты nn
. И невозможно спасти state_dict
такого класса. Я хочу знать, как правильно создавать такие Module
в pytorch, чтобы обеспечить лучшую интеграцию с сохранением .state_dict()
?
Спасибо