Рассчитать рекурсивный EMA с периодом горения в Esper - PullRequest
0 голосов
/ 05 марта 2020

В качестве упражнения я пытаюсь вычислить рекурсивную EMA с периодом записи в Esper, EPL. У него умеренно сложная логика запуска c, и я подумал, что это будет хорошим тестом для оценки того, чего может достичь Эспер.

Предполагая поток значений x1, x2, x3 через равные промежутки времени мы хотим вычислить:

let p = 0.1
a = average(x1, x2, x3, x4, x5)    // Assume 5, in reality use a parameter
y1 = p * x1 + (p - 1) * a          // Recursive calculation initialized with look-ahead average
y2 = p * x2 + (p - 1) * y1
y3 = p * x3 + (p - 1) * y2
   ....

The final stream should only publish y5, y6, y7, ...

Я играл с контекстом, который генерирует событие, содержащее среднее значение a, и это событие вызывает второй контекст, который начинает рекурсивные вычисления. Но к тому времени, когда я пытаюсь заставить первый контекст запускаться только один раз и второй контекст, чтобы обрабатывать начальный случай, используя a и рекурсивные последующие события, я получаю грязный клубок логики c.

Есть ли прямой способ решения этой проблемы?

(я игнорирую использование пользовательского агрегатора, так как это учебное упражнение)

1 Ответ

0 голосов
/ 05 марта 2020

Это не отвечает на вопрос, но может быть полезно - реализация как пользовательская функция агрегирования, протестированная с esper 7.1.0

public class EmaFactory implements AggregationFunctionFactory {

    int burn = 0;

    @Override
    public void setFunctionName(String s) {
        // Don't know why/when this is called
    }

    @Override
    public void validate(AggregationValidationContext ctx) {
        @SuppressWarnings("rawtypes")
        Class[] p = ctx.getParameterTypes();
        if ((p.length != 3)) {
            throw new IllegalArgumentException(String.format(
                "Ema aggregation required three parameters, received %d",
                p.length));
        }

        if (
            !(
                (p[0] == Double.class || p[0] == double.class) ||
                    (p[1] == Double.class || p[1] == double.class) ||
                    (p[2] == Integer.class || p[2] == int.class))) {
            throw new IllegalArgumentException(
                String.format(
                    "Arguments to Ema aggregation must of types (Double, Double, Integer), got (%s, %s, %s)\n",
                    p[0].getName(), p[1].getName(), p[2].getName()) +
                    "This should be made nicer, see AggregationMethodFactorySum.java in the Esper source code for " +
                    "examples of correctly dealing with multiple types"
            );
        }

        if (!ctx.getIsConstantValue()[2]) {
            throw new IllegalArgumentException(
                "Third argument 'burn' to Ema aggregation must be constant"
            );
        }
        ;


        burn = (int) ctx.getConstantValues()[2];
    }

    @Override
    public AggregationMethod newAggregator() {
        return new EmaAggregationFunction(burn);
    }

    @SuppressWarnings("rawtypes")
    @Override
    public Class getValueType() {
        return Double.class;
    }
}


public class EmaAggregationFunction implements AggregationMethod {

    final private int burnLength;
    private double[] burnValues;
    private int count = 0;
    private double value = 0.;

    EmaAggregationFunction(int burn) {
        this.burnLength = burn;
        this.burnValues = new double[burn];
    }

    private void update(double x, double alpha) {
        if (count < burnLength) {
            value += x;
            burnValues[count++] = x;

            if (count == burnLength) {
                value /= count;
                for (double v : burnValues) {
                    value = alpha * v + (1 - alpha) * value;
                }
                // in case burn is long, free memory
                burnValues = null;
            }
        } else {
            value = alpha * x + (1 - alpha) * value;
        }
    }

    @Override
    public void enter(Object tmp) {
        Object[] o = (Object[]) tmp;
        assert o[0] != null;
        assert o[1] != null;
        assert o[2] != null;
        assert (int) o[2] == burnLength;
        update((double) o[0], (double) o[1]);
    }

    @Override
    public void leave(Object o) {

    }

    @Override
    public Object getValue() {
        if (count < burnLength) {
            return null;
        } else {
            return value;
        }
    }

    @Override
    public void clear() {
        // I don't know when / why this is called - this part untested
        count = 0;
        value = 0.;
        burnValues = new double[burnLength];
    }
}


public class TestEmaAggregation {
    private EPRuntime epRuntime;
    private SupportUpdateListener listener = new SupportUpdateListener();

    void send(int id, double value) {
        epRuntime.sendEvent(
            new HashMap<String, Object>() {{
                put("id", id);
                put("value", value);
            }},
        "CalculationEvent");
    }

    @BeforeEach
    public void beforeEach() {
        EPServiceProvider provider = EPServiceProviderManager.getDefaultProvider();
        EPAdministrator epAdministrator = provider.getEPAdministrator();
        epRuntime = provider.getEPRuntime();

        ConfigurationOperations config = epAdministrator.getConfiguration();
        config.addPlugInAggregationFunctionFactory("ema", EmaFactory.class.getName());
        config.addEventType(
        "CalculationEvent",
            new HashMap<String, Object>() {{ put("id", Integer.class); put("value", Double.class); }}
        );
        EPStatement stmt = epAdministrator.createEPL("select ema(value, 0.1, 5) as ema from CalculationEvent where value is not null");
        stmt.addListener(listener);
    }

    Double getEma() {
        return (Double)listener.assertOneGetNewAndReset().get("ema");
    }

    @Test
    public void someTest() {
        send(1, 1);
        assertEquals(null, getEma());
        send(1, 2);
        assertEquals(null, getEma());
        send(1, 3);
        assertEquals(null, getEma());
        send(1, 4);
        assertEquals(null, getEma());

        // Last of the burn period
        // We expect:
        // a = (1+2+3+4+5) / 5 = 3
        // y1 = 0.1 * 1 + 0.9 * 3 = 2.8
        // y2 = 0.1 * 2 + 0.9 * 2.8
        //    ... leading to
        // y5 = 3.08588
        send(1, 5);
        assertEquals(3.08588, getEma(), 1e-10);

        // Outside burn period
        send(1, 6);
        assertEquals(3.377292, getEma(), 1e-10);
        send(1, 7);
        assertEquals(3.7395628, getEma(), 1e-10);
        send(1, 8);
        assertEquals(4.16560652, getEma(), 1e-10);
    }
}
...