Как ускорить вычисления Tensorflow для решения системы ODE методом runge-kutta? - PullRequest
0 голосов
/ 16 апреля 2019

Я реализовал график Tensorflow, который решает систему OED, используя метод Рунге-Кутты, чтобы оптимизировать решение по некоторым параметрам уравнения. Когда я запускаю график для необходимого количества шагов runge-kutta, он работает очень медленно. Есть ли способ заставить его работать быстрее? Позже мне понадобится использовать график для вычисления градиентов и обратного распространения.

ODE описаны в этом документе: http://web.mit.edu/~gari/www/papers/ieeetbe50p289.pdf

И метод Рунге-Кутта описан здесь: https://en.wikipedia.org/wiki/Runge%E2%80%93Kutta_methods

Для сравнения я реализовал код в numpy, и вычисления выполняются намного быстрее. Я также реализовал только один шаг runge-kutta, а затем открыл сеанс и запускал каждый шаг отдельно. Это было также намного быстрее, чем проходить все ступени Рудж-кутта вместе. Оба подхода мне не подходят, потому что я хочу вычислить градиенты после выполнения ряда шагов в решении runge-kutta.

Вот код, который выполняет одношаговое вычисление в runge-kutta:

import math

import matplotlib.pyplot as plt
import tensorflow as tf
from runge_kutta.tf_mod_op import tf_mod
from runge_kutta import utils
import time
def runge_kutta_single_step_net(x_curr, y_curr, z_curr, t_curr, input_params):

    h = 1 / 512
    A = 0.005  # mV
    f1 = 0.1  # mean 1
    f2 = 0.25  # mean 2
    c1 = 0.01  # std 1
    c2 = 0.01  # std 2
    N, rrpc = utils.generate_omega_function(f1, f2, c1, c2, h)
    rrpc = tf.constant(rrpc)

    a_p = input_params[0]
    a_q = input_params[3]
    a_r = input_params[6]
    a_s = input_params[9]
    a_t = input_params[12]

    b_p = input_params[1]
    b_q = input_params[4]
    b_r = input_params[7]
    b_s = input_params[10]
    b_t = input_params[13]

    theta_p = input_params[2]
    theta_q = input_params[5]
    theta_r = input_params[8]
    theta_s = input_params[11]
    theta_t = input_params[14]

    alpha = 1 - tf.sqrt(x_curr * x_curr + y_curr * y_curr)
    cast = tf.cast((t_curr / h), tf.int32)
    tensor_temp = 1 + cast
    tensor_temp = tf.reshape(tensor_temp, [])
    omega = tf.cond(tf.equal(rrpc[tensor_temp], 0), lambda: math.inf,
                    lambda: 2.0 * math.pi / rrpc[tensor_temp])
    d_x_d_t_next = alpha * x_curr - omega * y_curr

    d_y_d_t_next = alpha * y_curr + omega * x_curr

    theta = tf.atan2(y_curr, x_curr)
    delta_theta_p = tf_mod(theta - theta_p, 2 * math.pi)
    delta_theta_q = tf_mod(theta - theta_q, 2 * math.pi)
    delta_theta_r = tf_mod(theta - theta_r, 2 * math.pi)
    delta_theta_s = tf_mod(theta - theta_s, 2 * math.pi)
    delta_theta_t = tf_mod(theta - theta_t, 2 * math.pi)

    z_p = a_p * delta_theta_p * \
          tf.exp((- delta_theta_p * delta_theta_p / (2 * b_p * b_p)))

    z_q = a_q * delta_theta_q * \
          tf.exp((- delta_theta_q * delta_theta_q / (2 * b_q * b_q)))

    z_r = a_r * delta_theta_r * \
          tf.exp((- delta_theta_r * delta_theta_r / (2 * b_r * b_r)))

    z_s = a_s * delta_theta_s * \
          tf.exp((- delta_theta_s * delta_theta_s / (2 * b_s * b_s)))

    z_t = a_t * delta_theta_t * \
          tf.exp((- delta_theta_t * delta_theta_t / (2 * b_t * b_t)))

    z_0_t = A * tf.sin(2 * math.pi * f2 * t_curr)

    d_z_d_t_next = -1 * (z_p + z_q + z_r + z_s + z_t) - (z_curr - z_0_t)

    k1_x = h * d_x_d_t_next

    k1_y = h * d_y_d_t_next

    k1_z = h * d_z_d_t_next

    # K2 - Stage:
    cast = tf.cast(((t_curr + h / 2) / h), tf.int32)
    tensor_temp = 1 + cast
    tensor_temp = tf.reshape(tensor_temp, [])
    k2_omega = tf.cond(tf.equal(rrpc[tensor_temp], 0), lambda: math.inf,
                       lambda: 2.0 * math.pi / rrpc[tensor_temp])
    # k2_omega = 2.0 * math.pi / rrpc[tensor_temp]

    k2_alpha = 1 - tf.sqrt((x_curr + k1_x / 2) * (x_curr + k1_x / 2) +
                           (y_curr + k1_y / 2) * (y_curr + k1_y / 2))

    k2_x = h * (
            k2_alpha * (x_curr + k1_x / 2) - k2_omega * (y_curr + k1_y / 2))

    k2_y = h * (
            k2_alpha * (y_curr + k1_y / 2) + k2_omega * (x_curr + k1_x / 2))

    k2_theta = tf.atan2(y_curr + k1_y / 2, x_curr + k1_x / 2)
    k2_delta_theta_p = tf_mod(k2_theta - theta_p, 2 * math.pi)
    k2_delta_theta_q = tf_mod(k2_theta - theta_q, 2 * math.pi)
    k2_delta_theta_r = tf_mod(k2_theta - theta_r, 2 * math.pi)
    k2_delta_theta_s = tf_mod(k2_theta - theta_s, 2 * math.pi)
    k2_delta_theta_t = tf_mod(k2_theta - theta_t, 2 * math.pi)
    k2_z_p = a_p * k2_delta_theta_p * \
             tf.exp((- k2_delta_theta_p * k2_delta_theta_p / (2 * b_p * b_p)))
    k2_z_q = a_q * k2_delta_theta_q * \
             tf.exp((- k2_delta_theta_q * k2_delta_theta_q / (2 * b_q * b_q)))
    k2_z_r = a_r * k2_delta_theta_r * \
             tf.exp((- k2_delta_theta_r * k2_delta_theta_r / (2 * b_r * b_r)))
    k2_z_s = a_s * k2_delta_theta_s * \
             tf.exp((- k2_delta_theta_s * k2_delta_theta_s / (2 * b_s * b_s)))
    k2_z_t = a_t * k2_delta_theta_t * \
             tf.exp((- k2_delta_theta_t * k2_delta_theta_t / (2 * b_t * b_t)))

    z_2_t = A * tf.sin(2 * math.pi * f2 * (t_curr + h / 2))

    k2_z = h * (
            -1 * (k2_z_p + k2_z_q + k2_z_r + k2_z_s + k2_z_t) - (z_curr + k1_z / 2 - z_2_t))

    # K3 STAGE:
    cast = tf.cast(((t_curr + h / 2) / h), tf.int32)
    tensor_temp = 1 + cast
    tensor_temp = tf.reshape(tensor_temp, [])
    k3_omega = tf.cond(tf.equal(rrpc[tensor_temp], 0), lambda: math.inf,
                       lambda: 2.0 * math.pi / rrpc[tensor_temp])
    # k3_omega = 2.0 * math.pi / rrpc[tensor_temp]

    k3_alpha = 1 - tf.sqrt((x_curr + k2_x / 2) * (x_curr + k2_x / 2) +
                           (y_curr + k2_y / 2) * (y_curr + k2_y / 2))

    k3_x = h * (
            k3_alpha * (x_curr + k2_x / 2) - k3_omega * (y_curr + k2_y / 2))

    k3_y = h * (
            k3_alpha * (y_curr + k2_y / 2) + k3_omega * (x_curr + k2_x / 2))

    k3_theta = tf.atan2(y_curr + k2_y / 2, x_curr + k2_x / 2)
    k3_delta_theta_p = tf_mod(k3_theta - theta_p, 2 * math.pi)
    k3_delta_theta_q = tf_mod(k3_theta - theta_q, 2 * math.pi)
    k3_delta_theta_r = tf_mod(k3_theta - theta_r, 2 * math.pi)
    k3_delta_theta_s = tf_mod(k3_theta - theta_s, 2 * math.pi)
    k3_delta_theta_t = tf_mod(k3_theta - theta_t, 2 * math.pi)
    k3_z_p = a_p * k3_delta_theta_p * \
             tf.exp((- k3_delta_theta_p * k3_delta_theta_p / (2 * b_p * b_p)))
    k3_z_q = a_q * k3_delta_theta_q * \
             tf.exp((- k3_delta_theta_q * k3_delta_theta_q / (2 * b_q * b_q)))
    k3_z_r = a_r * k3_delta_theta_r * \
             tf.exp((- k3_delta_theta_r * k3_delta_theta_r / (2 * b_r * b_r)))
    k3_z_s = a_s * k3_delta_theta_s * \
             tf.exp((- k3_delta_theta_s * k3_delta_theta_s / (2 * b_s * b_s)))
    k3_z_t = a_t * k3_delta_theta_t * \
             tf.exp((- k3_delta_theta_t * k3_delta_theta_t / (2 * b_t * b_t)))

    z_3_t = A * tf.sin(2 * math.pi * f2 * (t_curr + h / 2))

    k3_z = h * (
            -1 * (k3_z_p + k3_z_q + k3_z_r + k3_z_s + k3_z_t) - (z_curr + k2_z / 2 - z_3_t))

    # K4 STAGE:
    cast = tf.cast(((t_curr + h) / h), tf.int32)
    tensor_temp = 1 + cast
    tensor_temp = tf.reshape(tensor_temp, [])
    k4_omega = tf.cond(tf.equal(rrpc[tensor_temp], 0), lambda: math.inf,
                       lambda: 2.0 * math.pi / rrpc[tensor_temp])
    # k4_omega = 2.0 * math.pi / rrpc[tensor_temp]

    k4_alpha = 1 - tf.sqrt((x_curr + k3_x) * (x_curr + k3_x) +
                           (y_curr + k3_y) * (y_curr + k2_y))
    k4_x = h * (k4_alpha * (x_curr + k3_x) - k4_omega * (y_curr + k3_y))

    k4_y = h * (k4_alpha * (y_curr + k3_y) + k4_omega * (x_curr + k3_x))

    k4_theta = tf.atan2(y_curr + k3_y, x_curr + k3_x)
    k4_delta_theta_p = tf_mod(k4_theta - theta_p, 2 * math.pi)
    k4_delta_theta_q = tf_mod(k4_theta - theta_q, 2 * math.pi)
    k4_delta_theta_r = tf_mod(k4_theta - theta_r, 2 * math.pi)
    k4_delta_theta_s = tf_mod(k4_theta - theta_s, 2 * math.pi)
    k4_delta_theta_t = tf_mod(k4_theta - theta_t, 2 * math.pi)
    k4_z_p = a_p * k4_delta_theta_p * \
             tf.exp((- k4_delta_theta_p * k4_delta_theta_p / (2 * b_p * b_p)))
    k4_z_q = a_q * k4_delta_theta_q * \
             tf.exp((- k4_delta_theta_q * k4_delta_theta_q / (2 * b_q * b_q)))
    k4_z_r = a_r * k4_delta_theta_r * \
             tf.exp((- k4_delta_theta_r * k4_delta_theta_r / (2 * b_r * b_r)))
    k4_z_s = a_s * k4_delta_theta_s * \
             tf.exp((- k4_delta_theta_s * k4_delta_theta_s / (2 * b_s * b_s)))
    k4_z_t = a_t * k4_delta_theta_t * \
             tf.exp((- k4_delta_theta_t * k4_delta_theta_t / (2 * b_t * b_t)))

    z_4_t = A * tf.sin(2 * math.pi * f2 * (t_curr + h))

    k4_z = h * (-1 * (k4_z_p + k4_z_q + k4_z_r + k4_z_s + k4_z_t) - (z_curr + k3_z - z_4_t))

    # Calculate next stage:
    x_next = x_curr + (1 / 6) * (k1_x + 2 * k2_x + 2 * k3_x + k4_x)
    y_next = y_curr + (1 / 6) * (k1_y + 2 * k2_y + 2 * k3_y + k4_y)
    z_next = z_curr + (1 / 6) * (k1_z + 2 * k2_z + 2 * k3_z + k4_z)

    return x_next, y_next, z_next

А вот цикл, который строит 514 шагов:

def build_full_graph():
    params = [1.2, 0.25, -60.0 * math.pi / 180.0, -5.0, 0.1, -15.0 * math.pi / 180.0,
          30.0, 0.1, 0.0 * math.pi / 180.0, -7.5, 0.1, 15.0 * math.pi / 180.0, 0.75, 0.4, 90.0 * math.pi / 180.0]

    input_params = params
    x_curr = tf.placeholder(dtype=tf.float32, shape=[])
    y_curr = tf.placeholder(dtype=tf.float32, shape=[])
    z_curr = tf.placeholder(dtype=tf.float32, shape=[])

    t = 0
    x_next, y_next, z_next = runge_kutta_single_step_net(x_curr, y_curr, z_curr, t, input_params)
    x_t = [x_next]
    y_t = [y_next]
    z_t = [z_next]

    for i in range(514):
        t += 1 / 512
        start = time.time()
        x_next, y_next, z_next = runge_kutta_single_step_net(x_next, y_next, z_next, t, input_params)
        end = time.time()
        tf.logging.info("Iter took: %f", end - start)
        x_t.append(x_next)
        y_t.append(y_next)
        z_t.append(z_next)

    with tf.Session() as sess:
        x = -0.417750770388669
        y = -0.9085616622823985
        z = -0.004551233843726818
        t = 0
        tf.logging.info("Forward pass starts...")
        start = time.time()

        ecg_res = sess.run(z_t, feed_dict={x_curr: x, y_curr: y, z_curr: z})

        end = time.time()
        tf.logging.info("Forward pass took: %f", end - start)
    tf.logging.info("Ploting result")
    plt.plot(ecg_res)
    plt.title("Runge-Kutta heart-beat")
    plt.show()

Когда я запускаю код, он дает мне правильные результаты, но расчет занимает более 2 минут. Есть ли способ сделать граф тензорного потока более эффективным?

...