Как видно из документов TensorFlow , партия должна быть вдоль первой оси, например, первый пример в документах:
model.add(Dense(32, input_shape=(16,)))
# now the model will take as input arrays of shape (*, 16)
# and output arrays of shape (*, 32)
, поэтому, если, например, размер партии равен10, то здесь ввод будет (10, 16)
, а вывод (10, 32)
. Так почему же тогда мой плотный слой с двумя нейронами выводит (2, *)
вместо (*, 2)
?
Я вставил print()
формы того, что мой активационный слой получает от этого плотного слоя, так что вы можете увидеть shape=(2,)
после компиляции.
library(keras)
library(tidyverse)
set.seed(2019)
weibull_activate <- function(ab) {
print(k_shape(ab))
a = k_exp(ab[, 1])
b = k_softplus(ab[, 2])
a = k_reshape(a, c(-1, 1))
b = k_reshape(b, c(-1, 1))
return(k_concatenate(list(a, b)))
}
weibull_loglik_continuous <- function(y_true, y_pred) {
y_ = y_true[, 1]
u_ = y_true[, 2]
a_ = y_pred[, 1]
b_ = y_pred[, 2]
ya = (y_ + 1e-35) / a_
return(-1 * k_mean(u_ * (k_log(b_) + b_ * k_log(ya)) - k_pow(ya, b_)))
}
test_data <- tibble(
x = runif(1e5, min = -1, max = 1),
true_shape = 2*x + 2.5,
true_scale = 10*x + 10.5,
y = map2_dbl(true_shape, true_scale, rweibull, n = 1),
o = 1
)
test_model <-
keras_model_sequential() %>%
layer_dense(input_shape = 1,
units = 2,
name = "dense_1") %>%
layer_activation(weibull_activate,
name = "weibull_activate") %>%
compile(optimizer = "rmsprop",
loss = weibull_loglik_continuous)
#> Tensor("weibull_activate/Shape:0", shape=(2,), dtype=int32)
Извините за довольно большой MWE, но это также может помочь увидеть общую ситуацию. Эти 2 нейрона должны быть параметрами Вейбулла, но они продолжают отрицательно работать при обучении моим реальным данным, и даже в этом простом тесте параметры не сходятся к реальным, хотя они сходятся по чему-то. Я подозреваю, что вычисление активации и потерь, выполняемое здесь, является математически обоснованным, но бессмысленным, потому что параметры меняются местами или агрегируются по наблюдениям, или что-то в этом роде.
Создано 2019-10-26 гг. Представить пакет (v0.3.0)