Произведение всех узлов на пути дерева - PullRequest
1 голос
/ 11 апреля 2020

Я изучал алгоритм МО. В этом я нашел вопрос. В которой мы должны создать программу, которая будет принимать входные данные n для n узлов дерева, а затем n-1 пар u и v, обозначающих связь между узлом u и узлом v. После этого получим значения n узлов.

Тогда мы зададим q запросов. Для каждого запроса мы берем ввод k и l, которые обозначают два узла этого дерева. Теперь нам нужно найти произведение всех узлов на пути k и l (включая k и l).

Я хочу использовать алгоритм МО. https://codeforces.com/blog/entry/43230

Но я не могу сделать код. Кто-нибудь может мне помочь в этом.

Базовый c код для этого будет:

int n, q;
int nxt[ N ], to[ N ], hd[ N ];

struct Que{
    int u, v, id;
} que[ N ];

void init() {
    // read how many nodes and how many queries
    cin >> n >> q;
    // read the edge of tree
    for ( int i = 1 ; i < n ; ++ i ) {
        int u, v; cin >> u >> v;
        // save the tree using adjacency list
        nxt[ i << 1 | 0 ] = hd[ u ];
        to[ i << 1 | 0 ] = v;
        hd[ u ] = i << 1 | 0;

        nxt[ i << 1 | 1 ] = hd[ v ];
        to[ i << 1 | 1 ] = u;
        hd[ v ] = i << 1 | 1;
    }

    for ( int i = 0 ; i < q ; ++ i ) {
        // read queries
        cin >> que[ i ].u >> que[ i ].v;
        que[ i ].id = i;
    }
}

int dfn[ N ], dfn_, block_id[ N ], block_;

int stk[ N ], stk_;

void dfs( int u, int f ) {
    dfn[ u ] = dfn_++;

    int saved_rbp = stk_;

    for ( int v_ = hd[ u ] ; v_ ; v_ = nxt[ v_ ] ) {
        if ( to[ v_ ] == f ) continue;
        dfs( to[ v_ ], u );
        if ( stk_ - saved_rbp < SQRT_N ) continue;
        for ( ++ block_ ; stk_ != saved_rbp ; )
             block_id[ stk[ -- stk_ ] ] = block_;
    }

    stk[ stk_ ++ ] = u;
}

bool inPath[ N ];

void SymmetricDifference( int u ) {
    if ( inPath[ u ] ) {
        // remove this edge
    } else {
        // add this edge
    }
    inPath[ u ] ^= 1;
}
void traverse( int& origin_u, int u ) {
    for ( int g = lca( origin_u, u ) ; origin_u != g ; origin_u = parent_of[ origin_u ] )
        SymmetricDifference( origin_u );
    for ( int v = u ; v != origin_u ; v = parent_of[ v ] )
        SymmetricDifference( v );
    origin_u = u;
}

void solve() {
    // construct blocks using dfs
    dfs( 1, 1 );
    while ( stk_ ) block_id[ stk[ -- stk_ ] ] = block_;
    // re-order our queries
    sort( que, que + q, [] ( const Que& x, const Que& y ) {
        return tie( block_id[ x.u ], dfn[ x.v ] ) < tie( block_id[ y.u ], dfn[ y.v ] );
    } );
    // apply mo's algorithm on tree
    int U = 1, V = 1;
    for ( int i = 0 ; i < q ; ++ i ) {
        pass( U, que[ i ].u );
        pass( V, que[ i ].v );
        // we could our answer of que[ i ].id
    }
}

Ответы [ 2 ]

1 голос
/ 13 апреля 2020

Эта проблема является небольшой модификацией блога, которым вы поделились.

Теги задачи: - Алгоритм MO, деревья, LCA, бинарный подъем, сито, предварительное вычисление, простые множители

Предварительные вычисления: - Просто нам нужно сделать некоторые предварительные вычисления с seiveOfErothenesis , чтобы сохранить наивысший простой множитель каждого элемента, возможный во входных ограничениях. Затем, используя это, мы будем хранить все основные факторы и их мощности для каждого элемента во входном массиве в другой матрице.

Наблюдение: - с ограничениями, которые вы видите, их может быть очень мало такие простые числа возможны для каждого элемента. Для элемента (10 ^ 6) может быть максимально 7 простых факторов.

Изменить MO Al go В блоге: - Теперь в нашем методе вычислений нам просто нужно сохранить карту, в которой будет храниться текущий счетчик основного множителя. При добавлении или вычитании каждого элемента при решении запросов мы будем перебирать основные факторы этого элемента и делить наш результат (сохраняя общее число факторов) на старое число этого простого числа, а затем обновлять счет этого простого и кратного нашего результата с новым счетом. (Это будет O (7) max для каждого сложения / вычитания).

Сложность: - O (T * (( N + Q) * sqrt (N) * F)) где F - 7 в нашем случае. F - сложность вашего метода проверки ().

  • T - количество тестовых примеров во входном файле.
  • N - размер входного массива.
  • Q - количество запросов.

Ниже приведена реализация вышеуказанного подхода в JAVA. computePrimePowers () и check () - это те методы, которые вам интересны.

import java.util.*;
import java.io.*;

public class Main {

    static int BLOCK_SIZE;
    static int ar[];
    static ArrayList<Integer> graph[];
    static StringBuffer sb = new StringBuffer();

    static boolean notPrime[] = new boolean[1000001];
    static int hpf[] = new int[1000001];
    static void seive(){
        notPrime[0] = true;
        notPrime[1] = true;
        for(int i = 2; i < 1000001; i++){
            if(!notPrime[i]){
                hpf[i] = i;
            for(int j = 2 * i; j < 1000001; j += i){
                notPrime[j] = true;
                hpf[j] = i;
            }
        }
    }
    }

    static long modI[] = new long[1000001];
    static void computeModI() {
        for(int i = 1; i < 1000001; i++) {
            modI[i] = pow(i, 1000000005);
        }
    }
    static long pow(long x, long y) { 
        if (y == 0) 
            return 1; 

        long p = pow(x, y / 2);
        p = (p >= 1000000007) ? p % 1000000007 : p;
        p = p * p;
        p = (p >= 1000000007) ? p % 1000000007 : p;

        if ((y & 1) == 0) 
            return p; 
        else {
            long tt = x * p;
            return (tt >= 1000000007) ? tt % 1000000007 : tt; 
        }
    }

    public static void main(String[] args) throws Exception {
        Reader s = new Reader();
        int test = s.nextInt();
        seive();
        computeModI();
        for(int ii = 0; ii < test; ii++){
            int n = s.nextInt();
            lcaTable = new int[19][n + 1];
            graph = new ArrayList[n + 1];
            arrPrimes = new int[n + 1][7][2];
            primeCnt = new int[1000001];
            visited = new int[n + 1];
            ar = new int[n + 1];
            for(int i = 0; i < graph.length; i++) graph[i] = new ArrayList<>();
            for(int i = 1; i < n; i++){
                int u = s.nextInt(), v = s.nextInt();
                graph[u].add(v);
                graph[v].add(u);
            }
            int ip = 1; while(ip <= n) ar[ip++] = s.nextInt();

            computePrimePowers();

            int q = s.nextInt();
            LVL = new int[n + 1];

            dfsTime = 0;
            dfs(1, -1);

            BLOCK_SIZE = (int) Math.sqrt(dfsTime);
            int Q[][] = new int[q][4];
            int i = 0;
            while(q-- > 0) {
                int u = s.nextInt(), v = s.nextInt();
                Q[i][0] = lca(u, v);
                if (l[u] > l[v]) {
                    int temp = u; u = v; v = temp;
                }
                if (Q[i][0] == u) {
                    Q[i][1] = l[u];
                    Q[i][2] = l[v];
                }
                else {
                    Q[i][1] = r[u]; // left at col1 in query
                    Q[i][2] = l[v]; // right at col2
                }
                Q[i][3] = i;
                i++;
            }
            Arrays.sort(Q, new Comparator<int[]>() {
                @Override
                public int compare(int[] x, int[] y) {
                    int block_x = (x[1] - 1) / (BLOCK_SIZE + 1);
                    int block_y = (y[1] - 1) / (BLOCK_SIZE + 1);
                    if(block_x != block_y)
                        return block_x - block_y;
                    return x[2] - y[2];
                }
            });
            solveQueries(Q);
        }
        System.out.println(sb);
    }

    static long res;
    private static void solveQueries(int [][] Q) {
        int M = Q.length;
        long results[] = new long[M];
        res = 1;
        int curL = Q[0][1], curR = Q[0][1] - 1;
        int i = 0;
        while(i < M){
            while (curL < Q[i][1]) check(ID[curL++]);
            while (curL > Q[i][1]) check(ID[--curL]);
            while (curR < Q[i][2]) check(ID[++curR]);
            while (curR > Q[i][2]) check(ID[curR--]);

            int u = ID[curL], v = ID[curR];

            if (Q[i][0] != u && Q[i][0] != v) check(Q[i][0]);

            results[Q[i][3]] = res;

            if (Q[i][0] != u && Q[i][0] != v) check(Q[i][0]);

            i++;
        }

        i = 0;
        while(i < M) sb.append(results[i++] + "\n");
    }

    static int visited[];
    static int primeCnt[];
    private static void check(int x) {
        if(visited[x] == 1){
            for(int i = 0; i < 7; i++) {
                int c = arrPrimes[x][i][1];
                int pp = arrPrimes[x][i][0];
                if(pp == 0) break;
                long tem = res * modI[primeCnt[pp] + 1];
                res = (tem >= 1000000007) ? tem % 1000000007 : tem;
                primeCnt[pp] -= c;
                tem = res * (primeCnt[pp] + 1);
                res = (tem >= 1000000007) ? tem % 1000000007 : tem;
            }
        }
        else if(visited[x] == 0){
            for(int i = 0; i < 7; i++) {
                int c = arrPrimes[x][i][1];
                int pp = arrPrimes[x][i][0];
                if(pp == 0) break;
                long tem = res * modI[primeCnt[pp] + 1];
                res = (tem >= 1000000007) ? tem % 1000000007 : tem;
                primeCnt[pp] += c;
                tem = res * (primeCnt[pp] + 1);
                res = (tem >= 1000000007) ? tem % 1000000007 : tem;
            }
        }
        visited[x] ^= 1;
    }

    static int arrPrimes[][][];
    static void computePrimePowers() {
        int n = arrPrimes.length;
        int i = 0;
        while(i < n) {
            int ele = ar[i];
            int k = 0;
            while(ele > 1) {
                int c = 0;
                int pp = hpf[ele];
                while(hpf[ele] == pp) {
                    c++; ele /= pp;
                }
                arrPrimes[i][k][0] = pp;
                arrPrimes[i][k][1] = c;
                k++;
            }
            i++;
        }
    }

    static int dfsTime;
    static int l[] = new int[1000001], r[] = new int[1000001], ID[] = new int[1000001], LVL[], lcaTable[][];
    static void dfs(int u, int p){
        l[u] = ++dfsTime;
        ID[dfsTime] = u;
        int i = 1;
        while(i < 19) {
            lcaTable[i][u] = lcaTable[i - 1][lcaTable[i - 1][u]];
            i++;
        }
        i = 0;
        while(i < graph[u].size()){
            int v = graph[u].get(i);
            i++;
            if (v == p) continue;
            LVL[v] = LVL[u] + 1;
            lcaTable[0][v] = u;
            dfs(v, u);
        }
        r[u] = ++dfsTime;
        ID[dfsTime] = u;
    }

    static int lca(int u, int v){
        if (LVL[u] > LVL[v]) {
            int temp = u;
            u = v; v = temp;
        }
        int i = 18;
        while(i >= 0) {
            if (LVL[v] - (1 << i) >= LVL[u]) v = lcaTable[i][v];
            i--;
        }

        if (u == v) return u;

        i = 18;
        while(i >= 0){
            if (lcaTable[i][u] != lcaTable[i][v]){
                u = lcaTable[i][u];
                v = lcaTable[i][v];
            }
            i--;
        }
        return lcaTable[0][u];
    }
}
0 голосов
/ 11 апреля 2020
// SIMILAR SOLUTION FOR FINDING NUMBER OF DISTINCT ELEMENTS FROM U TO V
// USING MO's ALGORITHM
#include <bits/stdc++.h>
using namespace std;

const int MAXN = 40005;
const int MAXM = 100005;
const int LN = 19;

int N, M, K, cur, A[MAXN], LVL[MAXN], DP[LN][MAXN];
int BL[MAXN << 1], ID[MAXN << 1], VAL[MAXN], ANS[MAXM];
int d[MAXN], l[MAXN], r[MAXN];
bool VIS[MAXN];
vector < int > adjList[MAXN];

struct query{
    int id, l, r, lc;
    bool operator < (const query& rhs){
        return (BL[l] == BL[rhs.l]) ? (r < rhs.r) : (BL[l] < BL[rhs.l]);
    }
}Q[MAXM];

// Set up Stuff
void dfs(int u, int par){
    l[u] = ++cur; 
    ID[cur] = u;
    for (int i = 1; i < LN; i++) DP[i][u] = DP[i - 1][DP[i - 1][u]];
    for (int i = 0; i < adjList[u].size(); i++){
        int v = adjList[u][i];
        if (v == par) continue;
        LVL[v] = LVL[u] + 1;
        DP[0][v] = u;
        dfs(v, u);
    }
    r[u] = ++cur; ID[cur] = u;
}

// Function returns lca of (u) and (v)
inline int lca(int u, int v){
    if (LVL[u] > LVL[v]) swap(u, v);
    for (int i = LN - 1; i >= 0; i--)
        if (LVL[v] - (1 << i) >= LVL[u]) v = DP[i][v];
    if (u == v) return u;
    for (int i = LN - 1; i >= 0; i--){
        if (DP[i][u] != DP[i][v]){
            u = DP[i][u];
            v = DP[i][v];
        }
    }
    return DP[0][u];
}

inline void check(int x, int& res){
    // If (x) occurs twice, then don't consider it's value 
    if ( (VIS[x]) and (--VAL[A[x]] == 0) ) res--; 
    else if ( (!VIS[x]) and (VAL[A[x]]++ == 0) ) res++;
    VIS[x] ^= 1;
}

void compute(){

    // Perform standard Mo's Algorithm
    int curL = Q[0].l, curR = Q[0].l - 1, res = 0;

    for (int i = 0; i < M; i++){

        while (curL < Q[i].l) check(ID[curL++], res);
        while (curL > Q[i].l) check(ID[--curL], res);
        while (curR < Q[i].r) check(ID[++curR], res);
        while (curR > Q[i].r) check(ID[curR--], res);

        int u = ID[curL], v = ID[curR];

        // Case 2
        if (Q[i].lc != u and Q[i].lc != v) check(Q[i].lc, res);

        ANS[Q[i].id] = res;

        if (Q[i].lc != u and Q[i].lc != v) check(Q[i].lc, res);
    }

    for (int i = 0; i < M; i++) printf("%d\n", ANS[i]);
}

int main(){

    int u, v, x;

    while (scanf("%d %d", &N, &M) != EOF){

        // Cleanup
        cur = 0;
        memset(VIS, 0, sizeof(VIS));
        memset(VAL, 0, sizeof(VAL));
        for (int i = 1; i <= N; i++) adjList[i].clear();

        // Inputting Values
        for (int i = 1; i <= N; i++) scanf("%d", &A[i]);
        memcpy(d + 1, A + 1, sizeof(int) * N);

        // Compressing Coordinates
        sort(d + 1, d + N + 1);
        K = unique(d + 1, d + N + 1) - d - 1;
        for (int i = 1; i <= N; i++) A[i] = lower_bound(d + 1, d + K + 1, A[i]) - d;

        // Inputting Tree
        for (int i = 1; i < N; i++){
            scanf("%d %d", &u, &v);
            adjList[u].push_back(v);
            adjList[v].push_back(u);
        }

        // Preprocess
        DP[0][1] = 1;
        dfs(1, -1);
        int size = sqrt(cur);

        for (int i = 1; i <= cur; i++) BL[i] = (i - 1) / size + 1;

        for (int i = 0; i < M; i++){
            scanf("%d %d", &u, &v);
            Q[i].lc = lca(u, v);
            if (l[u] > l[v]) swap(u, v);
            if (Q[i].lc == u) Q[i].l = l[u], Q[i].r = l[v];
            else Q[i].l = r[u], Q[i].r = l[v];
            Q[i].id = i;
        }

        sort(Q, Q + M);
        compute();
    }
}

Демо

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...