Как лучше всего писать собственный модуль в pytorch? - PullRequest
0 голосов
/ 28 мая 2020
class TfidfEmbeds(BaseEmbeds):
    # Input just placeholder
    def __init__(self, _=None):
        super().__init__()
        self.tfidf = TfidfVectorizer().fit([" ".join(i) for i in train_data])
        # Need these to map features
        self.tfidf_word_list = self.tfidf.get_feature_names()
        self.word_to_ix = {word: i for i, word in enumerate(self.tfidf_word_list)}
        self.cache = {}

    def forward(self, X):
        # Cache
        key = tuple(X)
        if key in self.cache:
            return cache[key]
        # Get text
        words = [word_list[i] for i in X.cpu().numpy()]
        idx = [self.word_to_ix[word] if word in self.word_to_ix else 0 for word in words]
        mask = [False if word in self.word_to_ix else True for word in words]
        outputs = self.tfidf.transform([" ".join(words)])[:, idx].toarray().reshape(-1, 1)
        outputs[mask] = 0
        self.cache[key] = outputs
        return torch.from_numpy(outputs).to(device).float()

    def get_output_dim(self):
        return 1

У меня есть приведенный выше код python для кодирования функции tfidf для списка документов. Я заключил код в Module, но, очевидно, torch не совсем подходит, поскольку я не использую компоненты nn. И невозможно спасти state_dict такого класса. Я хочу знать, как правильно создавать такие Module в pytorch, чтобы обеспечить лучшую интеграцию с сохранением .state_dict()?

Спасибо

...