Я переписал реализацию Dynamic Time Warping из обычного питона в Tensorflow. Но это действительно медленно - намного медленнее, чем предварительные вычисления расстояний и загрузка их в Tensorflow в качестве данных. Я не могу понять, почему это медленно или как его улучшить.
Я также попытался преобразовать другие реализации DTW с автографом, но безуспешно. Есть предложения?
def tfDTW(s1, s2):
r = tf.cast(tf.shape(s1)[0], tf.int32)
c = tf.cast(tf.shape(s2)[0], tf.int32)
window = tf.math.reduce_max([r,c])
max_step = max_dist = 1e7
penalty = psi = tf.constant(0, dtype=tf.float64)
length = tf.math.reduce_min([c + 1, tf.math.abs(r - c) + 2 * (window - 1) + 1 + 1 + 1])
indices = [0,-1]
dtw = tf.one_hot(indices, depth = length,
on_value=0.0, off_value=1e7,
axis=-1) # output: [2,length]
dtw=tf.cast(dtw, tf.float64)
last_under_max_dist = tf.constant(0)
skip = tf.constant(0)
i0 = tf.constant(1)
i1 = tf.constant(0)
psi_shortest = 1e7
#
#
def condition1(i, r, dtw, i0, i1, skip, last_under_max_dist):
return tf.less(i, r)
def body1(i, r, dtw, i0, i1, skip, last_under_max_dist):
#
#
prev_last_under_max_dist = tf.cond(tf.equal(last_under_max_dist, -1), lambda: tf.cast(tf.constant(1e7), tf.int32), lambda: last_under_max_dist)
last_under_max_dist = tf.constant(-1)
skipp = skip
skip = tf.reduce_max([0, i - tf.reduce_max([0, r - c]) - window + 1])
i0 = 1 - i0
i1 = 1 - i1
dtw = tf.cond(tf.equal(i1, 0), lambda: tf.concat([tf.fill([1, length], tf.constant(1e7, dtype=tf.float64)), [dtw[1]]], 0), lambda: tf.concat([[dtw[0]], tf.fill([1, length], tf.constant(1e7, dtype=tf.float64))], 0) ) #dtw[i1, :] = np.inf
j_start = tf.reduce_max([0, i - tf.reduce_max([0, r - c]) - window + 1])
j_end = tf.reduce_min([c, i + tf.reduce_max([0, c - r]) + window])
skip = tf.constant(0) #tf.cond(tf.equal(dtw.get_shape()[1], c+1), lambda: 0, lambda: skip )
#if psi != 0 and j_start == 0 and i < psi: dtw[i1, 0] = 0 #psi always ==0
def condition2(j, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp):
return tf.math.logical_and(tf.greater(j, j_start-1), tf.less(j,j_end))
def body2(j, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp):
d = (tf.gather(s1, i) - tf.gather(s2, j))*(tf.gather(s1, i) - tf.gather(s2, j))
d = tf.cast(d, tf.float64)
minval = tf.cast(tf.math.reduce_min([dtw[i0, j - skipp],
dtw[i0, j + 1 - skipp] + penalty,
dtw[i1, j - skip] + penalty]), tf.float64)
indices = tf.cond(tf.equal(i1, 0), lambda: tf.stack([j + 1 - skip, -1] ), lambda: tf.stack([-1, j + 1 - skip]) )
minusdtw = tf.one_hot(indices, depth = length,
on_value=-1*dtw[i1, j + 1 - skip], off_value=tf.constant(0.0, dtype=tf.float64),
axis=-1) # output: [2,length]
replacement = tf.one_hot(indices, depth = length,
on_value=tf.reduce_min([d + minval, 1e7]), off_value=tf.constant(0.0, dtype=tf.float64),
axis=-1) # output: [2,length]
dtw = dtw + minusdtw + replacement
last_under_max_dist = j
return tf.add(j, 1), dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp
#
b = tf.while_loop(condition2, body2, [j_start, dtw, j_start, j_end, last_under_max_dist, prev_last_under_max_dist, skip, skipp ],
[j_start.get_shape(), tf.TensorShape((2,None)), j_start.get_shape(), j_end.get_shape(), last_under_max_dist.get_shape(), prev_last_under_max_dist.get_shape(), skip.get_shape(), skipp.get_shape() ])
return tf.add(i, 1), r, b[1], i0, i1, skip, b[4]
#
a = tf.while_loop(condition1, body1, [tf.constant(0), r, dtw, i0, i1, skip, tf.constant(0) ],
[tf.constant(0).get_shape(), r.get_shape(), tf.TensorShape((None,None)), i0.get_shape(), i1.get_shape(), skip.get_shape(), tf.constant(0).get_shape() ])
maindtw = a[2]
d = tf.math.sqrt(maindtw [a[4]][ tf.reduce_min([c, c + window - 1]) - skip])
return d
import tensorflow as tf
import numpy as np
graph = tf.Graph()
sess = tf.InteractiveSession()
s1 = tf.constant([10, 0, 1, 2, 1, 0, 1, 0, 0,14,22])
s2 = tf.constant([10, 1, 2, 0, 0, 0, 0])
tfDTW(s1, s2).eval() #26.13426869074396