Я пытаюсь реализовать пользовательский GRUCell, чтобы сохранять значения вентилей r и z. Код идентичен оригинальному GRUCell из keras/layers/recurrent.py.
Единственное отличие состоит в том, что я добавил два атрибута экземпляра для хранения тензоров (строки 48-49) и метод для их получения. Тензоры сохраняются в строках 189-190 или 235-236 (в зависимости от значения implementation).
class GRUCellTweak(Layer):
    """GRU cell that additionally records its reset (r) and update (z) gates.

    The step computation is a regular GRU step. The additions are two lists,
    ``reset_gate_history`` and ``update_gate_history``, to which every
    invocation of ``call`` appends the gate tensors it builds, and
    ``get_gates_history`` which returns both lists.

    NOTE(review): the lists hold backend tensors (not numeric arrays), are
    never cleared, and grow by one entry per ``call`` invocation. Presumably,
    with a graph backend, the entries are symbolic tensors that depend on the
    model's input placeholder, so they have to be evaluated with the input fed
    (e.g. through ``K.function``) rather than with a bare ``K.eval`` -- confirm
    against the backend in use.

    Args:
        units: Size of the hidden state and of the output; also the width of
            each of the three gate slices of the kernels.
        activation: Identifier resolved with ``activations.get``; applied to
            the candidate state.
        recurrent_activation: Identifier resolved with ``activations.get``;
            applied to the update and reset gates.
        use_bias: Whether a bias weight is created and added.
        kernel_initializer: Initializer for the input kernel.
        recurrent_initializer: Initializer for the recurrent kernel.
        bias_initializer: Initializer for the bias.
        kernel_regularizer: Regularizer for the input kernel.
        recurrent_regularizer: Regularizer for the recurrent kernel.
        bias_regularizer: Regularizer for the bias.
        kernel_constraint: Constraint for the input kernel.
        recurrent_constraint: Constraint for the recurrent kernel.
        bias_constraint: Constraint for the bias.
        dropout: Dropout rate for the inputs, clamped to ``[0, 1]``.
        recurrent_dropout: Dropout rate for the previous state, clamped to
            ``[0, 1]``.
        implementation: ``1`` computes each gate with its own matrix product;
            any other value projects all gates with one fused product.
        reset_after: If true, the reset gate is applied after the recurrent
            matrix product (and a separate recurrent bias is created);
            otherwise it is applied before.
        **kwargs: Forwarded to the ``Layer`` constructor.
    """

    def __init__(self, units,
                 activation='tanh',
                 recurrent_activation='sigmoid',
                 use_bias=True,
                 kernel_initializer='glorot_uniform',
                 recurrent_initializer='orthogonal',
                 bias_initializer='zeros',
                 kernel_regularizer=None,
                 recurrent_regularizer=None,
                 bias_regularizer=None,
                 kernel_constraint=None,
                 recurrent_constraint=None,
                 bias_constraint=None,
                 dropout=0.,
                 recurrent_dropout=0.,
                 implementation=2,
                 reset_after=False,
                 **kwargs):
        """Store the configuration and create empty gate-history lists."""
        super(GRUCellTweak, self).__init__(**kwargs)
        self.units = units
        self.activation = activations.get(activation)
        self.recurrent_activation = activations.get(recurrent_activation)
        self.use_bias = use_bias

        self.kernel_initializer = initializers.get(kernel_initializer)
        self.recurrent_initializer = initializers.get(recurrent_initializer)
        self.bias_initializer = initializers.get(bias_initializer)

        self.kernel_regularizer = regularizers.get(kernel_regularizer)
        self.recurrent_regularizer = regularizers.get(recurrent_regularizer)
        self.bias_regularizer = regularizers.get(bias_regularizer)

        self.kernel_constraint = constraints.get(kernel_constraint)
        self.recurrent_constraint = constraints.get(recurrent_constraint)
        self.bias_constraint = constraints.get(bias_constraint)

        # Clamp both rates into [0, 1].
        self.dropout = min(1., max(0., dropout))
        self.recurrent_dropout = min(1., max(0., recurrent_dropout))
        self.implementation = implementation
        self.reset_after = reset_after
        self.state_size = self.units
        self.output_size = self.units
        self._dropout_mask = None
        self._recurrent_dropout_mask = None

        # tweak: one entry is appended to each list per `call` invocation
        self.reset_gate_history = []
        self.update_gate_history = []

    def build(self, input_shape):
        """Create the kernel, recurrent kernel and bias, plus per-gate views.

        Args:
            input_shape: Shape of the cell input; only its last entry (the
                feature dimension) is used.
        """
        input_dim = input_shape[-1]

        if isinstance(self.recurrent_initializer, initializers.Identity):
            # The recurrent kernel is (units, 3 * units), so an identity
            # initializer is emulated by tiling identity blocks side by side.
            def recurrent_identity(shape, gain=1., dtype=None):
                del dtype
                return gain * np.concatenate(
                    [np.identity(shape[0])] * (shape[1] // shape[0]), axis=1)

            self.recurrent_initializer = recurrent_identity

        self.kernel = self.add_weight(shape=(input_dim, self.units * 3),
                                      name='kernel',
                                      initializer=self.kernel_initializer,
                                      regularizer=self.kernel_regularizer,
                                      constraint=self.kernel_constraint)
        self.recurrent_kernel = self.add_weight(
            shape=(self.units, self.units * 3),
            name='recurrent_kernel',
            initializer=self.recurrent_initializer,
            regularizer=self.recurrent_regularizer,
            constraint=self.recurrent_constraint)

        if self.use_bias:
            if not self.reset_after:
                bias_shape = (3 * self.units,)
            else:
                # separate biases for input and recurrent kernels
                # Note: the shape is intentionally different from CuDNNGRU
                # biases `(2 * 3 * self.units,)`, so that we can distinguish
                # the classes when loading and converting saved weights.
                bias_shape = (2, 3 * self.units)
            self.bias = self.add_weight(shape=bias_shape,
                                        name='bias',
                                        initializer=self.bias_initializer,
                                        regularizer=self.bias_regularizer,
                                        constraint=self.bias_constraint)
            if not self.reset_after:
                self.input_bias, self.recurrent_bias = self.bias, None
            else:
                # NOTE: need to flatten, since slicing in CNTK gives 2D array
                self.input_bias = K.flatten(self.bias[0])
                self.recurrent_bias = K.flatten(self.bias[1])
        else:
            self.bias = None

        # update gate
        self.kernel_z = self.kernel[:, :self.units]
        self.recurrent_kernel_z = self.recurrent_kernel[:, :self.units]
        # reset gate
        self.kernel_r = self.kernel[:, self.units: self.units * 2]
        self.recurrent_kernel_r = self.recurrent_kernel[:,
                                                        self.units:
                                                        self.units * 2]
        # new gate
        self.kernel_h = self.kernel[:, self.units * 2:]
        self.recurrent_kernel_h = self.recurrent_kernel[:, self.units * 2:]

        if self.use_bias:
            # bias for inputs
            self.input_bias_z = self.input_bias[:self.units]
            self.input_bias_r = self.input_bias[self.units: self.units * 2]
            self.input_bias_h = self.input_bias[self.units * 2:]
            # bias for hidden state - just for compatibility with CuDNN
            if self.reset_after:
                self.recurrent_bias_z = self.recurrent_bias[:self.units]
                self.recurrent_bias_r = (
                    self.recurrent_bias[self.units: self.units * 2])
                self.recurrent_bias_h = self.recurrent_bias[self.units * 2:]
        else:
            self.input_bias_z = None
            self.input_bias_r = None
            self.input_bias_h = None
            if self.reset_after:
                self.recurrent_bias_z = None
                self.recurrent_bias_r = None
                self.recurrent_bias_h = None
        self.built = True

    def call(self, inputs, states, training=None):
        """Run one GRU step and record the gate tensors it produces.

        Args:
            inputs: Input tensor for the current step.
            states: Sequence whose first element is the previous hidden state.
            training: Forwarded to the dropout-mask generator.

        Returns:
            Tuple ``(h, [h])``: the new hidden state as the step output and as
            the single-element list of next states.
        """
        h_tm1 = states[0]  # previous memory

        # Dropout masks are created lazily, once, and then reused.
        if 0 < self.dropout < 1 and self._dropout_mask is None:
            self._dropout_mask = _generate_dropout_mask(
                K.ones_like(inputs),
                self.dropout,
                training=training,
                count=3)
        if (0 < self.recurrent_dropout < 1 and
                self._recurrent_dropout_mask is None):
            self._recurrent_dropout_mask = _generate_dropout_mask(
                K.ones_like(h_tm1),
                self.recurrent_dropout,
                training=training,
                count=3)

        # dropout matrices for input units
        dp_mask = self._dropout_mask
        # dropout matrices for recurrent units
        rec_dp_mask = self._recurrent_dropout_mask

        if self.implementation == 1:
            if 0. < self.dropout < 1.:
                inputs_z = inputs * dp_mask[0]
                inputs_r = inputs * dp_mask[1]
                inputs_h = inputs * dp_mask[2]
            else:
                inputs_z = inputs
                inputs_r = inputs
                inputs_h = inputs

            x_z = K.dot(inputs_z, self.kernel_z)
            x_r = K.dot(inputs_r, self.kernel_r)
            x_h = K.dot(inputs_h, self.kernel_h)
            if self.use_bias:
                x_z = K.bias_add(x_z, self.input_bias_z)
                x_r = K.bias_add(x_r, self.input_bias_r)
                x_h = K.bias_add(x_h, self.input_bias_h)

            if 0. < self.recurrent_dropout < 1.:
                h_tm1_z = h_tm1 * rec_dp_mask[0]
                h_tm1_r = h_tm1 * rec_dp_mask[1]
                h_tm1_h = h_tm1 * rec_dp_mask[2]
            else:
                h_tm1_z = h_tm1
                h_tm1_r = h_tm1
                h_tm1_h = h_tm1

            recurrent_z = K.dot(h_tm1_z, self.recurrent_kernel_z)
            recurrent_r = K.dot(h_tm1_r, self.recurrent_kernel_r)
            if self.reset_after and self.use_bias:
                recurrent_z = K.bias_add(recurrent_z, self.recurrent_bias_z)
                recurrent_r = K.bias_add(recurrent_r, self.recurrent_bias_r)

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # tweak to store gate activations
            self.reset_gate_history.append(r)
            self.update_gate_history.append(z)

            # reset gate applied after/before matrix multiplication
            if self.reset_after:
                recurrent_h = K.dot(h_tm1_h, self.recurrent_kernel_h)
                if self.use_bias:
                    recurrent_h = K.bias_add(recurrent_h,
                                             self.recurrent_bias_h)
                recurrent_h = r * recurrent_h
            else:
                recurrent_h = K.dot(r * h_tm1_h, self.recurrent_kernel_h)

            hh = self.activation(x_h + recurrent_h)
        else:
            if 0. < self.dropout < 1.:
                inputs *= dp_mask[0]

            # inputs projected by all gate matrices at once
            matrix_x = K.dot(inputs, self.kernel)
            if self.use_bias:
                # biases: bias_z_i, bias_r_i, bias_h_i
                matrix_x = K.bias_add(matrix_x, self.input_bias)
            x_z = matrix_x[:, :self.units]
            x_r = matrix_x[:, self.units: 2 * self.units]
            x_h = matrix_x[:, 2 * self.units:]

            if 0. < self.recurrent_dropout < 1.:
                h_tm1 *= rec_dp_mask[0]

            if self.reset_after:
                # hidden state projected by all gate matrices at once
                matrix_inner = K.dot(h_tm1, self.recurrent_kernel)
                if self.use_bias:
                    matrix_inner = K.bias_add(matrix_inner,
                                              self.recurrent_bias)
            else:
                # hidden state projected separately for update/reset and new
                matrix_inner = K.dot(h_tm1,
                                     self.recurrent_kernel[:, :2 * self.units])

            recurrent_z = matrix_inner[:, :self.units]
            recurrent_r = matrix_inner[:, self.units: 2 * self.units]

            z = self.recurrent_activation(x_z + recurrent_z)
            r = self.recurrent_activation(x_r + recurrent_r)

            # tweak to store gate activations
            self.reset_gate_history.append(r)
            self.update_gate_history.append(z)

            if self.reset_after:
                recurrent_h = r * matrix_inner[:, 2 * self.units:]
            else:
                recurrent_h = K.dot(r * h_tm1,
                                    self.recurrent_kernel[:, 2 * self.units:])

            hh = self.activation(x_h + recurrent_h)

        # previous and candidate state mixed by update gate
        h = z * h_tm1 + (1 - z) * hh

        if 0 < self.dropout + self.recurrent_dropout:
            if training is None:
                h._uses_learning_phase = True

        return h, [h]

    def get_config(self):
        """Return the constructor arguments merged with the base config."""
        config = {'units': self.units,
                  'activation': activations.serialize(self.activation),
                  'recurrent_activation':
                      activations.serialize(self.recurrent_activation),
                  'use_bias': self.use_bias,
                  'kernel_initializer':
                      initializers.serialize(self.kernel_initializer),
                  'recurrent_initializer':
                      initializers.serialize(self.recurrent_initializer),
                  'bias_initializer':
                      initializers.serialize(self.bias_initializer),
                  'kernel_regularizer':
                      regularizers.serialize(self.kernel_regularizer),
                  'recurrent_regularizer':
                      regularizers.serialize(self.recurrent_regularizer),
                  'bias_regularizer':
                      regularizers.serialize(self.bias_regularizer),
                  'kernel_constraint':
                      constraints.serialize(self.kernel_constraint),
                  'recurrent_constraint':
                      constraints.serialize(self.recurrent_constraint),
                  'bias_constraint':
                      constraints.serialize(self.bias_constraint),
                  'dropout': self.dropout,
                  'recurrent_dropout': self.recurrent_dropout,
                  'implementation': self.implementation,
                  'reset_after': self.reset_after}
        base_config = super(GRUCellTweak, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def get_gates_history(self):
        """Return the recorded gate tensors.

        Returns:
            Tuple ``(reset_gate_history, update_gate_history)``: the list of
            reset-gate (``r``) tensors first, then the list of update-gate
            (``z``) tensors. The lists themselves are returned, not copies.
        """
        # Fixed: the method was declared without `self`, so calling it on an
        # instance raised TypeError before the body could run.
        return self.reset_gate_history, self.update_gate_history
Модель определяется как:
# Functional graph: Conv1D -> BatchNormalization -> RNN(GRUCellTweak) -> Dense.
input_tensor = Input(shape=(timesteps, features), name="input")

conv_features = Conv1D(
    filters=128,
    kernel_size=6,
    strides=1,
    padding='same',
    dilation_rate=1,
    use_bias=True,
)(input_tensor)
normalized = BatchNormalization()(conv_features)

# The custom cell is wrapped in a generic RNN layer; with return_state=True
# the layer call yields a second value, unpacked here as `states`.
cell = GRUCellTweak(512)
gru_layer = keras.layers.recurrent.RNN(
    cell,
    return_sequences=True,
    return_state=True,
)
s_gru, states = gru_layer(normalized)

# Output head: a single linear unit whose bias starts at a constant 20.15.
output_bias_init = tf.keras.initializers.Constant(value=20.15)
output_layer = Dense(
    1,
    activation='linear',
    name="output",
    kernel_initializer='truncated_normal',
    bias_initializer=output_bias_init,
)
out = output_layer(s_gru)

model = Model(inputs=input_tensor, outputs=[out])
Я могу обучить модель, однако когда я пытаюсь получить тензоры следующим образом (sample — массив NumPy):
# Run one test window through the model, then try to read the recorded gates.
sample = test_x[7:8, :, :]
pred = model.predict(sample)
rnn = model.layers[-2].cell  # select the rnn layer
# Fixed: get_gates_history() returns (reset, update), i.e. (r, z); the
# original unpacked it as `z, r`, which swapped the two names.
r, z = rnn.get_gates_history()
# NOTE(review): this is the call that raises the error reported below.
# K.eval() runs the tensor without feeding anything, while r[0] depends on the
# model's Input placeholder. Presumably the value has to be fetched with the
# input fed, e.g. K.function([model.input], [r[0]])([sample])[0] -- confirm.
print(K.eval(r[0]))
Я получаю следующую трассировку стека:
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _do_call(self, fn, *args)
1364 try:
-> 1365 return fn(*args)
1366 except errors.OpError as e:
9 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run_fn(feed_dict, fetch_list, target_list, options, run_metadata)
1349 return self._call_tf_sessionrun(options, feed_dict, fetch_list,
-> 1350 target_list, run_metadata)
1351
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _call_tf_sessionrun(self, options, feed_dict, fetch_list, target_list, run_metadata)
1442 fetch_list, target_list,
-> 1443 run_metadata)
1444
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: You must feed a value for placeholder tensor 'input_11' with dtype float and shape [?,1250,2]
[[{{node input_11}}]]
(1) Invalid argument: You must feed a value for placeholder tensor 'input_11' with dtype float and shape [?,1250,2]
[[{{node input_11}}]]
[[rnn_5/Sigmoid/_959]]
0 successful operations.
0 derived errors ignored.
During handling of the above exception, another exception occurred:
InvalidArgumentError Traceback (most recent call last)
<ipython-input-53-2c38b84327d9> in <module>()
5 z, r = rnn.get_gates_history()
6 print(r[0].shape, z[0].shape)
----> 7 print(K.eval(r[0]))
8 return
9 sample = test_x[7:8, :, :]
/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py in eval(x)
701 {{np_implementation}}
702 """
--> 703 return to_dense(x).eval(session=get_session())
704
705
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py in eval(self, feed_dict, session)
796
797 """
--> 798 return _eval_using_default_session(self, feed_dict, self.graph, session)
799
800 def experimental_ref(self):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py in _eval_using_default_session(tensors, feed_dict, graph, session)
5405 "the tensor's graph is different from the session's "
5406 "graph.")
-> 5407 return session.run(tensors, feed_dict)
5408
5409
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
954 try:
955 result = self._run(None, fetches, feed_dict, options_ptr,
--> 956 run_metadata_ptr)
957 if run_metadata:
958 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
1178 if final_fetches or final_targets or (handle and feed_dict_tensor):
1179 results = self._do_run(handle, final_targets, final_fetches,
-> 1180 feed_dict_tensor, options, run_metadata)
1181 else:
1182 results = []
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _do_run(self, handle, target_list, fetch_list, feed_dict, options, run_metadata)
1357 if handle is None:
1358 return self._do_call(_run_fn, feeds, fetches, targets, options,
-> 1359 run_metadata)
1360 else:
1361 return self._do_call(_prun_fn, handle, feeds, fetches)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/client/session.py in _do_call(self, fn, *args)
1382 '\nsession_config.graph_options.rewrite_options.'
1383 'disable_meta_optimizer = True')
-> 1384 raise type(e)(node_def, op, message)
1385
1386 def _extend_graph(self):
InvalidArgumentError: 2 root error(s) found.
(0) Invalid argument: You must feed a value for placeholder tensor 'input_11' with dtype float and shape [?,1250,2]
[[node input_11 (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1748) ]]
(1) Invalid argument: You must feed a value for placeholder tensor 'input_11' with dtype float and shape [?,1250,2]
[[node input_11 (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1748) ]]
[[rnn_5/Sigmoid/_959]]
0 successful operations.
0 derived errors ignored.
Original stack trace for 'input_11':
File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
"__main__", mod_spec)
File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
exec(code, run_globals)
File "/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py", line 16, in <module>
app.launch_new_instance()
File "/usr/local/lib/python3.6/dist-packages/traitlets/config/application.py", line 664, in launch_instance
app.start()
File "/usr/local/lib/python3.6/dist-packages/ipykernel/kernelapp.py", line 477, in start
ioloop.IOLoop.instance().start()
File "/usr/local/lib/python3.6/dist-packages/tornado/ioloop.py", line 888, in start
handler_func(fd_obj, events)
File "/usr/local/lib/python3.6/dist-packages/tornado/stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/zmq/eventloop/zmqstream.py", line 450, in _handle_events
self._handle_recv()
File "/usr/local/lib/python3.6/dist-packages/zmq/eventloop/zmqstream.py", line 480, in _handle_recv
self._run_callback(callback, msg)
File "/usr/local/lib/python3.6/dist-packages/zmq/eventloop/zmqstream.py", line 432, in _run_callback
callback(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tornado/stack_context.py", line 277, in null_wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/ipykernel/kernelbase.py", line 283, in dispatcher
return self.dispatch_shell(stream, msg)
File "/usr/local/lib/python3.6/dist-packages/ipykernel/kernelbase.py", line 235, in dispatch_shell
handler(stream, idents, msg)
File "/usr/local/lib/python3.6/dist-packages/ipykernel/kernelbase.py", line 399, in execute_request
user_expressions, allow_stdin)
File "/usr/local/lib/python3.6/dist-packages/ipykernel/ipkernel.py", line 196, in do_execute
res = shell.run_cell(code, store_history=store_history, silent=silent)
File "/usr/local/lib/python3.6/dist-packages/ipykernel/zmqshell.py", line 533, in run_cell
return super(ZMQInteractiveShell, self).run_cell(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py", line 2718, in run_cell
interactivity=interactivity, compiler=compiler, result=result)
File "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py", line 2822, in run_ast_nodes
if self.run_code(code, result):
File "/usr/local/lib/python3.6/dist-packages/IPython/core/interactiveshell.py", line 2882, in run_code
exec(code_obj, self.user_global_ns, self.user_ns)
File "<ipython-input-51-5af6297ceb19>", line 3, in <module>
model = create_model(train_x.shape[1], train_x.shape[2])
File "<ipython-input-30-e8b0b573d8d5>", line 9, in create_model
input_tensor = Input(shape=(timesteps, features), name="input")
File "/usr/local/lib/python3.6/dist-packages/keras/engine/input_layer.py", line 178, in Input
input_tensor=tensor)
File "/usr/local/lib/python3.6/dist-packages/keras/legacy/interfaces.py", line 91, in wrapper
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/keras/engine/input_layer.py", line 87, in __init__
name=self.name)
File "/usr/local/lib/python3.6/dist-packages/keras/backend/tensorflow_backend.py", line 541, in placeholder
x = tf.placeholder(dtype, shape=shape, name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/array_ops.py", line 2619, in placeholder
return gen_array_ops.placeholder(dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_array_ops.py", line 6669, in placeholder
"Placeholder", dtype=dtype, shape=shape, name=name)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/op_def_library.py", line 794, in _apply_op_helper
op_def=op_def)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/util/deprecation.py", line 507, in new_func
return func(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 3357, in create_op
attrs, op_def, compute_device)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 3426, in _create_op_internal
op_def=op_def)
File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py", line 1748, in __init__
self._traceback = tf_stack.extract_stack()
Как мне получить значения тензоров вентилей z и r?