Битовая операция туды cuda в pytorch - PullRequest
1 голос
/ 30 марта 2020

Я хотел бы выполнить некоторую битовую операцию с torch.tensor в cuda, например <<, >>, или извлечь каждый бит, который представляет число с плавающей запятой, например 0 01101 01010101012 (3555H) для 0,333 в float16.

Что я делаю сейчас, как показано ниже:

def _decompose(self, value, exp_bias=None):
    '''
    decompose a single into sign, exp and mant
    '''
    if exp_bias is None:
        exp_bias = self.exp_bias
    # smallest non-zero float point
    descriminator = torch.tensor((2 ** (-exp_bias)) / 2).type_as(value)
    sign = (value > descriminator).type_as(value)
    sign -= (value < -descriminator).type_as(value)
    value = value.abs()
    exp = torch.log2(value).floor()
    mant = value / (2 ** exp)
    return sign, exp, mant

Есть ли способ достичь такой функции? Или что-то не так с моим кодом? Спасибо.

...