Вот то, что работает для общего числа входов и соответствующего выражения einsum, а также для конкретного случая скалярного сокращения -
def einsum_outshape(einsum_expr, inputs):
shps = np.concatenate([in_.shape for in_ in inputs])
p = einsum_expr.split(',')
s = p[:-1] + p[-1].split('->')
if s[-1]=='':
return ()
else:
inop = list(map(list,s))
return tuple(shps[(np.concatenate(inop[:-1])[:,None]==inop[-1]).argmax(0)])
Пример выполнения -
In [42]: a = np.random.rand(1,2,5)
...: b = np.random.rand(4,5)
...: c = np.random.rand(5,7,8)
...: d = np.random.rand(7,9)
In [43]: einsum_outshape('ijk,mk,kpq,pr->ikpqr', inputs=(a,b,c,d))
Out[43]: (1, 5, 7, 8, 9)
# Reduction to a scalar
In [44]: einsum_outshape('ijk,mk,kpq,pr->', inputs=(a,b,c,d))
Out[44]: ()