Форма вывода numpy.einsum - PullRequest
       41

Форма вывода numpy.einsum

1 голос
/ 14 октября 2019

Есть ли элегантный способ предварительно вычислить форму результата из np.einsum заданных входных аргументов einsum (без выполнения вычисления)?

# Given a, b and signature with 
# a.shape == (1, 2, 5)
# b.shape == (4, 5)
einsum_shape('ijk,mk->ik', a, b) # returns (1, 5)

Ответы [ 2 ]

1 голос
/ 14 октября 2019

Вот то, что работает для общего числа входов и соответствующего выражения einsum, а также для конкретного случая скалярного сокращения -

def einsum_outshape(einsum_expr, inputs):
    shps = np.concatenate([in_.shape for in_ in inputs])
    p = einsum_expr.split(',')
    s = p[:-1] + p[-1].split('->')
    if s[-1]=='':
        return ()
    else:
        inop = list(map(list,s))
        return tuple(shps[(np.concatenate(inop[:-1])[:,None]==inop[-1]).argmax(0)])

Пример выполнения -

In [42]: a = np.random.rand(1,2,5)
    ...: b = np.random.rand(4,5)
    ...: c = np.random.rand(5,7,8)
    ...: d = np.random.rand(7,9)

In [43]: einsum_outshape('ijk,mk,kpq,pr->ikpqr', inputs=(a,b,c,d))
Out[43]: (1, 5, 7, 8, 9)

# Reduction to a scalar
In [44]: einsum_outshape('ijk,mk,kpq,pr->', inputs=(a,b,c,d))
Out[44]: ()
0 голосов
/ 15 октября 2019

Основываясь на ответе @ Divakar, я пришел к следующему, которое немного более читабельно и вызывает ошибки, если пропущены неподдерживаемые строки индекса.

def einsum_outshape(subscripts, *operants):
    """Compute the shape of output from `numpy.einsum`.

    Does not support ellipses.

    """
    if "." in subscripts:
        raise ValueError(f'Ellipses are not supported: {subscripts}')

    insubs, outsubs = subscripts.replace(",", "").split("->")
    if outsubs == "":
        return ()
    insubs = np.array(list(insubs))
    innumber = np.concatenate([op.shape for op in operants])
    outshape = []
    for o in outsubs:
        indices, = np.where(insubs == o)
        try:
            outshape.append(innumber[indices].max())
        except ValueError:
            raise ValueError(f'Invalid subscripts: {subscripts}')
    return tuple(outshape)
...