Я обучил модель MLP (многослойный перцептрон) с помощью PyTorch. Командой, приведённой ниже, я могу получить все коэффициенты модели и вручную написать SQL-запрос, который вычисляет прогнозируемые значения по входным столбцам. Существуют ли уже готовые пакеты, которые генерируют оптимизированный SQL-запрос по обученной модели? Мне не удалось найти ни одного, но я новичок в Python, так что, возможно, просто плохо искал.
model.state_dict()
OrderedDict([('layer_1.weight',
tensor([[-0.1187, 0.0713, 0.1447, ..., 0.0503, -0.1552, -0.0629],
[ 0.1376, -0.1211, -0.1090, ..., 0.1125, -0.1668, 0.0193],
[-0.1746, -0.0655, 0.1052, ..., -0.1250, -0.1124, -0.0515],
...,
[ 0.0479, -0.0183, -0.1080, ..., -0.1622, 0.1207, 0.1146],
[ 0.1046, 0.0433, 0.1749, ..., 0.0446, 0.0913, 0.0843],
[-0.1082, -0.0713, 0.0994, ..., -0.0986, 0.0630, 0.0865]])),
('layer_1.bias',
tensor([ 0.0023, 0.0534, -0.1129, 0.0610, 0.0618, 0.1039, 0.0941, 0.0618,
0.1807, 0.2457, 0.2924, 0.1461, 0.3209, 0.0682, 0.0405, 0.4555,
0.5450, -0.0808, 0.4360, 0.2562, 0.3564, 0.0427, 0.1197, 0.4128,
-0.1160, 0.1907, 0.4645, 0.2666, 0.0480, -0.0931, -0.1598, -0.2190,
0.1503, -0.3840, 0.3384, -0.0459, -0.2578, -0.0534, -0.1121, -0.2066,
0.3153, 0.2321, -0.1500, 0.1564, 0.1074, -0.1102, 0.1311, 0.3268,
0.1681, 0.1639, 0.3823, -0.5152, 0.0198, -0.1510, 0.0537, -0.1364,
-0.0089, 0.1701, -0.0288, 0.1713, 0.3223, -0.0856, 0.3055, -0.0700])),
('layer_2.weight',
tensor([[-0.0250, -0.1085, 0.1491, ..., 0.0290, 0.0445, -0.1272],
[ 0.0177, -0.0318, 0.0338, ..., 0.0928, -0.0348, 0.0345],
[-0.0754, 0.0394, 0.0109, ..., -0.0488, -0.2314, -0.0233],
...,
[-0.0171, 0.1082, -0.0031, ..., 0.0544, -0.1525, -0.1215],
[ 0.0234, -0.0068, -0.1358, ..., 0.0090, -0.0280, -0.0791],
[-0.1413, -0.0118, -0.0166, ..., -0.0565, -0.0039, -0.1394]])),
('layer_2.bias',
tensor([-0.1474, 0.0213, 0.0038, -0.0093, 0.1257, 0.1641, -0.0737, -0.0422,
0.0148, -0.1041, -0.1546, -0.1056, -0.0281, -0.0253, 0.1511, -0.1159,
-0.1745, 0.0639, 0.1032, 0.1163, 0.0912, -0.0086, 0.0816, 0.1686,
0.2618, -0.0659, 0.0636, 0.1432, -0.0430, 0.0131, -0.0498, -0.0811,
0.0673, 0.1419, 0.0460, -0.0645, 0.1385, -0.1181, -0.0434, -0.0194,
0.1036, 0.2060, 0.1256, 0.0346, -0.0927, 0.0912, -0.0416, -0.0017,
-0.0777, 0.0535, 0.0889, -0.0997, -0.1336, 0.0257, 0.1941, -0.1289,
0.0919, 0.0788, 0.1582, -0.0674, -0.1670, 0.0683, 0.0939, 0.1309])),
('layer_out.weight',
tensor([[-0.0312, -0.1084, 0.1952, -0.1502, -0.1233, -0.0898, -0.1678, -0.1846,
0.2139, -0.2063, -0.0643, 0.2050, -0.1946, 0.1971, 0.1785, -0.1020,
-0.0392, 0.1839, 0.2022, 0.2153, 0.1727, 0.1685, 0.1376, 0.1861,
-0.0442, 0.1197, 0.1491, 0.2214, 0.2006, 0.1760, -0.0557, -0.1567,
-0.1566, 0.2069, -0.1395, 0.1566, 0.1437, 0.1537, -0.0494, 0.1463,
0.1348, -0.1276, 0.2090, -0.1367, -0.0331, 0.2099, -0.1707, -0.1543,
-0.2036, 0.2164, 0.1727, 0.1859, 0.0843, 0.1691, -0.1454, 0.1745,
0.1418, 0.2192, -0.0509, 0.1986, -0.0616, 0.1373, 0.1681, -0.1450]])),
('layer_out.bias', tensor([-0.1983])),
('batchnorm1.weight',
tensor([0.9681, 0.8976, 0.9302, 0.9252, 1.1069, 0.9596, 0.9212, 1.0008, 0.9581,
0.9985, 1.0053, 0.9333, 0.9949, 0.8464, 0.8386, 1.3050, 0.9928, 0.9313,
1.0727, 0.8547, 1.0248, 0.9584, 0.9896, 1.0563, 0.9942, 0.9790, 1.0364,
1.0379, 1.0368, 0.9663, 1.0447, 1.0532, 1.0298, 1.0858, 1.0042, 0.9484,
0.8282, 1.0391, 0.9396, 0.8617, 1.1749, 0.9258, 1.0064, 1.0406, 1.1507,
0.9748, 1.0200, 1.0357, 0.9871, 0.8859, 0.9563, 0.9667, 0.9147, 0.9649,
0.8234, 1.0417, 0.9108, 0.8871, 0.9405, 1.0084, 0.8264, 1.0113, 1.0314,
0.9494])),
('batchnorm1.bias',
tensor([-0.0311, 0.0937, -0.0572, 0.0309, 0.1525, 0.0575, 0.0393, 0.0569,
0.1252, 0.0776, 0.0863, 0.0097, 0.0777, 0.0121, 0.1351, 0.0913,
-0.0844, 0.0135, 0.0087, 0.0366, 0.1473, 0.1491, 0.0581, 0.0704,
0.0339, -0.0151, 0.1697, -0.0004, -0.0130, -0.0198, 0.1492, 0.0666,
0.0219, 0.1564, 0.0691, -0.0296, -0.0014, 0.1042, 0.0606, -0.0546,
0.0678, -0.0492, 0.1035, -0.0435, 0.1125, 0.0245, 0.0874, 0.0802,
0.0661, -0.0773, 0.0834, 0.0411, -0.0037, 0.1110, 0.0466, 0.0342,
0.0387, -0.0335, -0.0834, 0.0792, -0.0423, 0.0762, 0.0753, -0.0740])),
('batchnorm1.running_mean',
tensor([1.3067e+06, 6.0457e+05, 8.2752e+05, 4.4208e+02, 8.2479e+03, 1.4389e+04,
1.0922e+06, 1.8224e+06, 1.7183e+05, 2.7979e+03, 1.0712e+06, 1.4432e+06,
2.5499e+02, 4.4386e-01, 2.5277e+05, 1.0932e+02, 6.4647e+01, 1.3927e+06,
7.2479e+02, 8.9471e+05, 2.1874e+05, 4.4251e+05, 8.6447e+05, 3.1026e+03,
2.4198e+05, 5.0970e+05, 2.9184e+02, 1.1396e+04, 3.1564e+05, 3.0748e+05,
1.6696e+04, 2.2066e+03, 2.2857e+05, 1.7602e+04, 9.2798e+03, 1.5475e+06,
1.0197e+00, 9.4096e+03, 1.1205e+06, 1.7919e+05, 2.9698e+01, 1.4211e+04,
4.6715e+01, 4.5771e+05, 2.5411e+05, 1.3596e+06, 7.1259e+05, 1.7534e+03,
4.2882e+05, 1.5287e+06, 3.4151e+05, 1.0956e+05, 1.3922e+06, 4.2413e+05,
2.5614e+05, 1.3156e+04, 7.8092e+05, 1.5409e+06, 1.0130e+06, 2.1385e+03,
1.1631e+02, 8.0586e+03, 1.5317e+02, 1.0621e+06])),
('batchnorm1.running_var',
tensor([4.4985e+14, 1.0021e+14, 1.7813e+14, 2.1193e+09, 7.4105e+08, 2.6610e+09,
3.1982e+14, 8.6328e+14, 7.5264e+12, 1.0338e+08, 2.9778e+14, 5.4661e+14,
6.5767e+08, 3.6646e+04, 2.0381e+13, 1.1757e+07, 3.4102e+05, 5.0417e+14,
7.2898e+06, 2.1524e+14, 1.2777e+13, 5.7494e+13, 2.0215e+14, 1.0316e+08,
1.3597e+13, 7.1202e+13, 5.0234e+06, 1.7027e+09, 2.7486e+13, 2.6450e+13,
3.4578e+09, 5.0856e+07, 1.3995e+13, 4.0587e+09, 1.1558e+09, 6.3731e+14,
2.3907e+05, 1.1756e+09, 3.2539e+14, 1.3262e+13, 5.6695e+05, 2.6482e+09,
5.5276e+04, 5.4737e+13, 1.7663e+13, 4.8439e+14, 1.3414e+14, 6.6979e+07,
5.3276e+13, 6.2247e+14, 3.3580e+13, 5.2394e+12, 5.0920e+14, 4.9142e+13,
1.8105e+13, 2.3170e+09, 1.7177e+14, 6.3223e+14, 2.6651e+14, 1.3704e+08,
1.3324e+07, 8.8097e+08, 2.3162e+07, 2.9167e+14])),
('batchnorm1.num_batches_tracked', tensor(2528)),
('batchnorm2.weight',
tensor([0.9045, 0.9330, 1.0603, 1.0511, 1.0862, 0.9880, 1.0687, 1.1034, 1.0879,
1.0421, 0.9636, 1.0854, 1.0573, 1.0643, 1.0651, 0.9512, 0.9282, 1.0023,
1.0907, 1.1078, 1.1396, 1.1438, 1.1421, 1.0270, 0.9495, 0.9524, 1.1532,
1.0949, 1.0830, 1.1749, 0.9481, 1.0335, 1.0247, 1.1089, 1.1051, 1.1001,
0.9581, 1.0120, 0.9528, 1.0461, 1.1699, 1.0827, 1.1184, 1.0253, 0.9057,
1.0708, 1.1039, 1.0631, 1.1334, 1.1222, 1.1016, 1.0941, 1.0189, 1.1003,
0.9773, 1.1064, 1.1996, 1.0941, 0.9677, 1.1146, 0.9070, 1.1774, 1.1355,
1.0127])),
('batchnorm2.bias',
tensor([ 0.1404, 0.1515, -0.1751, 0.3049, 0.2871, 0.2794, 0.2921, 0.3167,
-0.1752, 0.2629, 0.1687, -0.1806, 0.2781, -0.1863, -0.1931, 0.1724,
0.1743, -0.1832, -0.2028, -0.1754, -0.2232, -0.2855, -0.2497, -0.1836,
0.2784, -0.2207, -0.2724, -0.1660, -0.1817, -0.2773, 0.1449, 0.3477,
0.2758, -0.1826, 0.3957, -0.1962, -0.1813, -0.1976, 0.1568, -0.2148,
-0.3069, 0.3683, -0.2152, 0.2839, 0.1459, -0.1983, 0.3639, 0.2938,
0.3777, -0.1976, -0.2180, -0.1797, -0.3297, -0.1998, 0.2473, -0.1974,
-0.3117, -0.1649, 0.2007, -0.2020, 0.1470, -0.3560, -0.2303, 0.2384])),
('batchnorm2.running_mean',
tensor([0.0818, 0.2604, 0.2367, 0.2034, 0.2625, 0.3723, 0.0721, 0.1728, 0.1317,
0.1797, 0.2480, 0.2488, 0.0833, 0.2430, 0.3321, 0.1704, 0.2857, 0.3144,
0.2553, 0.2646, 0.2861, 0.2463, 0.2994, 0.3364, 0.9763, 0.1421, 0.1485,
0.2862, 0.2331, 0.2600, 0.1732, 0.1903, 0.2381, 0.3158, 0.2876, 0.1270,
0.2801, 0.3171, 0.2028, 0.2451, 0.2338, 0.5979, 0.2338, 0.2817, 0.1630,
0.3428, 0.0794, 0.2882, 0.1878, 0.2208, 0.2140, 0.1691, 0.1626, 0.2396,
0.4150, 0.0935, 0.2734, 0.3051, 0.2979, 0.1538, 0.1558, 0.2688, 0.2890,
0.3867])),
('batchnorm2.running_var',
tensor([0.2264, 1.2192, 0.6055, 0.0141, 0.1366, 0.0474, 0.0718, 0.0178, 0.2406,
0.0279, 0.8640, 0.8436, 0.0045, 0.7779, 0.9837, 0.6758, 1.4899, 0.6993,
0.4628, 0.4569, 0.9133, 0.8354, 0.8984, 0.6698, 0.0982, 0.1073, 0.1528,
0.6010, 0.7083, 0.8089, 0.2064, 0.0199, 0.0301, 1.0915, 0.0252, 0.2126,
0.4634, 1.3170, 0.0137, 0.7959, 0.4566, 0.0918, 0.5987, 0.0223, 1.7209,
0.5086, 0.0030, 0.0365, 0.0204, 0.5189, 0.4973, 0.3750, 0.0945, 0.7348,
0.0676, 0.1096, 0.8977, 0.9489, 0.0863, 0.3033, 1.1825, 0.7292, 0.8726,
0.0234])),
('batchnorm2.num_batches_tracked', tensor(2528))])