экспортировать обученную модель pytorch в SQL запрос - PullRequest
0 голосов
/ 30 апреля 2020

Я обучил модель MLP с использованием PyTorch. Я мог бы получить все коэффициенты с помощью следующей команды и написать запрос SQL, который вычисляет прогнозируемые значения из входных столбцов. Мне интересно, есть ли уже какие-нибудь пакеты, которые создают оптимизированный запрос на основе обученной модели? Я не мог найти ни одного. Я новичок в python, и это может быть моим плохим.

    model.state_dict()
    OrderedDict([('layer_1.weight',
              tensor([[-0.1187,  0.0713,  0.1447,  ...,  0.0503, -0.1552, -0.0629],
                      [ 0.1376, -0.1211, -0.1090,  ...,  0.1125, -0.1668,  0.0193],
                      [-0.1746, -0.0655,  0.1052,  ..., -0.1250, -0.1124, -0.0515],
                      ...,
                      [ 0.0479, -0.0183, -0.1080,  ..., -0.1622,  0.1207,  0.1146],
                      [ 0.1046,  0.0433,  0.1749,  ...,  0.0446,  0.0913,  0.0843],
                      [-0.1082, -0.0713,  0.0994,  ..., -0.0986,  0.0630,  0.0865]])),
             ('layer_1.bias',
              tensor([ 0.0023,  0.0534, -0.1129,  0.0610,  0.0618,  0.1039,  0.0941,  0.0618,
                       0.1807,  0.2457,  0.2924,  0.1461,  0.3209,  0.0682,  0.0405,  0.4555,
                       0.5450, -0.0808,  0.4360,  0.2562,  0.3564,  0.0427,  0.1197,  0.4128,
                      -0.1160,  0.1907,  0.4645,  0.2666,  0.0480, -0.0931, -0.1598, -0.2190,
                       0.1503, -0.3840,  0.3384, -0.0459, -0.2578, -0.0534, -0.1121, -0.2066,
                       0.3153,  0.2321, -0.1500,  0.1564,  0.1074, -0.1102,  0.1311,  0.3268,
                       0.1681,  0.1639,  0.3823, -0.5152,  0.0198, -0.1510,  0.0537, -0.1364,
                      -0.0089,  0.1701, -0.0288,  0.1713,  0.3223, -0.0856,  0.3055, -0.0700])),
             ('layer_2.weight',
              tensor([[-0.0250, -0.1085,  0.1491,  ...,  0.0290,  0.0445, -0.1272],
                      [ 0.0177, -0.0318,  0.0338,  ...,  0.0928, -0.0348,  0.0345],
                      [-0.0754,  0.0394,  0.0109,  ..., -0.0488, -0.2314, -0.0233],
                      ...,
                      [-0.0171,  0.1082, -0.0031,  ...,  0.0544, -0.1525, -0.1215],
                      [ 0.0234, -0.0068, -0.1358,  ...,  0.0090, -0.0280, -0.0791],
                      [-0.1413, -0.0118, -0.0166,  ..., -0.0565, -0.0039, -0.1394]])),
             ('layer_2.bias',
              tensor([-0.1474,  0.0213,  0.0038, -0.0093,  0.1257,  0.1641, -0.0737, -0.0422,
                       0.0148, -0.1041, -0.1546, -0.1056, -0.0281, -0.0253,  0.1511, -0.1159,
                      -0.1745,  0.0639,  0.1032,  0.1163,  0.0912, -0.0086,  0.0816,  0.1686,
                       0.2618, -0.0659,  0.0636,  0.1432, -0.0430,  0.0131, -0.0498, -0.0811,
                       0.0673,  0.1419,  0.0460, -0.0645,  0.1385, -0.1181, -0.0434, -0.0194,
                       0.1036,  0.2060,  0.1256,  0.0346, -0.0927,  0.0912, -0.0416, -0.0017,
                      -0.0777,  0.0535,  0.0889, -0.0997, -0.1336,  0.0257,  0.1941, -0.1289,
                       0.0919,  0.0788,  0.1582, -0.0674, -0.1670,  0.0683,  0.0939,  0.1309])),
             ('layer_out.weight',
              tensor([[-0.0312, -0.1084,  0.1952, -0.1502, -0.1233, -0.0898, -0.1678, -0.1846,
                        0.2139, -0.2063, -0.0643,  0.2050, -0.1946,  0.1971,  0.1785, -0.1020,
                       -0.0392,  0.1839,  0.2022,  0.2153,  0.1727,  0.1685,  0.1376,  0.1861,
                       -0.0442,  0.1197,  0.1491,  0.2214,  0.2006,  0.1760, -0.0557, -0.1567,
                       -0.1566,  0.2069, -0.1395,  0.1566,  0.1437,  0.1537, -0.0494,  0.1463,
                        0.1348, -0.1276,  0.2090, -0.1367, -0.0331,  0.2099, -0.1707, -0.1543,
                       -0.2036,  0.2164,  0.1727,  0.1859,  0.0843,  0.1691, -0.1454,  0.1745,
                        0.1418,  0.2192, -0.0509,  0.1986, -0.0616,  0.1373,  0.1681, -0.1450]])),
             ('layer_out.bias', tensor([-0.1983])),
             ('batchnorm1.weight',
              tensor([0.9681, 0.8976, 0.9302, 0.9252, 1.1069, 0.9596, 0.9212, 1.0008, 0.9581,
                      0.9985, 1.0053, 0.9333, 0.9949, 0.8464, 0.8386, 1.3050, 0.9928, 0.9313,
                      1.0727, 0.8547, 1.0248, 0.9584, 0.9896, 1.0563, 0.9942, 0.9790, 1.0364,
                      1.0379, 1.0368, 0.9663, 1.0447, 1.0532, 1.0298, 1.0858, 1.0042, 0.9484,
                      0.8282, 1.0391, 0.9396, 0.8617, 1.1749, 0.9258, 1.0064, 1.0406, 1.1507,
                      0.9748, 1.0200, 1.0357, 0.9871, 0.8859, 0.9563, 0.9667, 0.9147, 0.9649,
                      0.8234, 1.0417, 0.9108, 0.8871, 0.9405, 1.0084, 0.8264, 1.0113, 1.0314,
                      0.9494])),
             ('batchnorm1.bias',
              tensor([-0.0311,  0.0937, -0.0572,  0.0309,  0.1525,  0.0575,  0.0393,  0.0569,
                       0.1252,  0.0776,  0.0863,  0.0097,  0.0777,  0.0121,  0.1351,  0.0913,
                      -0.0844,  0.0135,  0.0087,  0.0366,  0.1473,  0.1491,  0.0581,  0.0704,
                       0.0339, -0.0151,  0.1697, -0.0004, -0.0130, -0.0198,  0.1492,  0.0666,
                       0.0219,  0.1564,  0.0691, -0.0296, -0.0014,  0.1042,  0.0606, -0.0546,
                       0.0678, -0.0492,  0.1035, -0.0435,  0.1125,  0.0245,  0.0874,  0.0802,
                       0.0661, -0.0773,  0.0834,  0.0411, -0.0037,  0.1110,  0.0466,  0.0342,
                       0.0387, -0.0335, -0.0834,  0.0792, -0.0423,  0.0762,  0.0753, -0.0740])),
             ('batchnorm1.running_mean',
              tensor([1.3067e+06, 6.0457e+05, 8.2752e+05, 4.4208e+02, 8.2479e+03, 1.4389e+04,
                      1.0922e+06, 1.8224e+06, 1.7183e+05, 2.7979e+03, 1.0712e+06, 1.4432e+06,
                      2.5499e+02, 4.4386e-01, 2.5277e+05, 1.0932e+02, 6.4647e+01, 1.3927e+06,
                      7.2479e+02, 8.9471e+05, 2.1874e+05, 4.4251e+05, 8.6447e+05, 3.1026e+03,
                      2.4198e+05, 5.0970e+05, 2.9184e+02, 1.1396e+04, 3.1564e+05, 3.0748e+05,
                      1.6696e+04, 2.2066e+03, 2.2857e+05, 1.7602e+04, 9.2798e+03, 1.5475e+06,
                      1.0197e+00, 9.4096e+03, 1.1205e+06, 1.7919e+05, 2.9698e+01, 1.4211e+04,
                      4.6715e+01, 4.5771e+05, 2.5411e+05, 1.3596e+06, 7.1259e+05, 1.7534e+03,
                      4.2882e+05, 1.5287e+06, 3.4151e+05, 1.0956e+05, 1.3922e+06, 4.2413e+05,
                      2.5614e+05, 1.3156e+04, 7.8092e+05, 1.5409e+06, 1.0130e+06, 2.1385e+03,
                      1.1631e+02, 8.0586e+03, 1.5317e+02, 1.0621e+06])),
             ('batchnorm1.running_var',
              tensor([4.4985e+14, 1.0021e+14, 1.7813e+14, 2.1193e+09, 7.4105e+08, 2.6610e+09,
                      3.1982e+14, 8.6328e+14, 7.5264e+12, 1.0338e+08, 2.9778e+14, 5.4661e+14,
                      6.5767e+08, 3.6646e+04, 2.0381e+13, 1.1757e+07, 3.4102e+05, 5.0417e+14,
                      7.2898e+06, 2.1524e+14, 1.2777e+13, 5.7494e+13, 2.0215e+14, 1.0316e+08,
                      1.3597e+13, 7.1202e+13, 5.0234e+06, 1.7027e+09, 2.7486e+13, 2.6450e+13,
                      3.4578e+09, 5.0856e+07, 1.3995e+13, 4.0587e+09, 1.1558e+09, 6.3731e+14,
                      2.3907e+05, 1.1756e+09, 3.2539e+14, 1.3262e+13, 5.6695e+05, 2.6482e+09,
                      5.5276e+04, 5.4737e+13, 1.7663e+13, 4.8439e+14, 1.3414e+14, 6.6979e+07,
                      5.3276e+13, 6.2247e+14, 3.3580e+13, 5.2394e+12, 5.0920e+14, 4.9142e+13,
                      1.8105e+13, 2.3170e+09, 1.7177e+14, 6.3223e+14, 2.6651e+14, 1.3704e+08,
                      1.3324e+07, 8.8097e+08, 2.3162e+07, 2.9167e+14])),
             ('batchnorm1.num_batches_tracked', tensor(2528)),
             ('batchnorm2.weight',
              tensor([0.9045, 0.9330, 1.0603, 1.0511, 1.0862, 0.9880, 1.0687, 1.1034, 1.0879,
                      1.0421, 0.9636, 1.0854, 1.0573, 1.0643, 1.0651, 0.9512, 0.9282, 1.0023,
                      1.0907, 1.1078, 1.1396, 1.1438, 1.1421, 1.0270, 0.9495, 0.9524, 1.1532,
                      1.0949, 1.0830, 1.1749, 0.9481, 1.0335, 1.0247, 1.1089, 1.1051, 1.1001,
                      0.9581, 1.0120, 0.9528, 1.0461, 1.1699, 1.0827, 1.1184, 1.0253, 0.9057,
                      1.0708, 1.1039, 1.0631, 1.1334, 1.1222, 1.1016, 1.0941, 1.0189, 1.1003,
                      0.9773, 1.1064, 1.1996, 1.0941, 0.9677, 1.1146, 0.9070, 1.1774, 1.1355,
                      1.0127])),
             ('batchnorm2.bias',
              tensor([ 0.1404,  0.1515, -0.1751,  0.3049,  0.2871,  0.2794,  0.2921,  0.3167,
                      -0.1752,  0.2629,  0.1687, -0.1806,  0.2781, -0.1863, -0.1931,  0.1724,
                       0.1743, -0.1832, -0.2028, -0.1754, -0.2232, -0.2855, -0.2497, -0.1836,
                       0.2784, -0.2207, -0.2724, -0.1660, -0.1817, -0.2773,  0.1449,  0.3477,
                       0.2758, -0.1826,  0.3957, -0.1962, -0.1813, -0.1976,  0.1568, -0.2148,
                      -0.3069,  0.3683, -0.2152,  0.2839,  0.1459, -0.1983,  0.3639,  0.2938,
                       0.3777, -0.1976, -0.2180, -0.1797, -0.3297, -0.1998,  0.2473, -0.1974,
                      -0.3117, -0.1649,  0.2007, -0.2020,  0.1470, -0.3560, -0.2303,  0.2384])),
             ('batchnorm2.running_mean',
              tensor([0.0818, 0.2604, 0.2367, 0.2034, 0.2625, 0.3723, 0.0721, 0.1728, 0.1317,
                      0.1797, 0.2480, 0.2488, 0.0833, 0.2430, 0.3321, 0.1704, 0.2857, 0.3144,
                      0.2553, 0.2646, 0.2861, 0.2463, 0.2994, 0.3364, 0.9763, 0.1421, 0.1485,
                      0.2862, 0.2331, 0.2600, 0.1732, 0.1903, 0.2381, 0.3158, 0.2876, 0.1270,
                      0.2801, 0.3171, 0.2028, 0.2451, 0.2338, 0.5979, 0.2338, 0.2817, 0.1630,
                      0.3428, 0.0794, 0.2882, 0.1878, 0.2208, 0.2140, 0.1691, 0.1626, 0.2396,
                      0.4150, 0.0935, 0.2734, 0.3051, 0.2979, 0.1538, 0.1558, 0.2688, 0.2890,
                      0.3867])),
             ('batchnorm2.running_var',
              tensor([0.2264, 1.2192, 0.6055, 0.0141, 0.1366, 0.0474, 0.0718, 0.0178, 0.2406,
                      0.0279, 0.8640, 0.8436, 0.0045, 0.7779, 0.9837, 0.6758, 1.4899, 0.6993,
                      0.4628, 0.4569, 0.9133, 0.8354, 0.8984, 0.6698, 0.0982, 0.1073, 0.1528,
                      0.6010, 0.7083, 0.8089, 0.2064, 0.0199, 0.0301, 1.0915, 0.0252, 0.2126,
                      0.4634, 1.3170, 0.0137, 0.7959, 0.4566, 0.0918, 0.5987, 0.0223, 1.7209,
                      0.5086, 0.0030, 0.0365, 0.0204, 0.5189, 0.4973, 0.3750, 0.0945, 0.7348,
                      0.0676, 0.1096, 0.8977, 0.9489, 0.0863, 0.3033, 1.1825, 0.7292, 0.8726,
                      0.0234])),
             ('batchnorm2.num_batches_tracked', tensor(2528))])
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...