Это следует из Википедии.
import numpy.random as rnd
import numpy as np
A_as_numbers = np.argmax(np.log(P) + rnd.gumbel(size=P.shape), axis=1)
A_one_hot = np.eye(P.shape[1])[A_as_numbers].reshape(P.shape)
Проверено на:
P = np.matrix([[1/4, 1/4, 1/4, 1/4], [1/3,1/3,1/6,1/6]])
Получил:
array([[ 1., 0., 0., 0.],
[ 0., 1., 0., 0.]])