4
\$\begingroup\$

I was wondering why the following implementation of a suffix tree is 2000 times slower than a similar data structure in C++ (I tested it with python3, pypy3 and pypy2). I know the bottleneck is the traverse method. Do you spot any opportunities for optimisation in the following code? Thanks!

from sys import stdin, stdout, stderr, setrecursionlimit
setrecursionlimit(100000)

def read():
    """Return the next line of standard input without trailing whitespace."""
    line = stdin.readline()
    return line.rstrip()

def readint():
    """Read one line through read() and convert it to an int."""
    token = read()
    return int(token)

def make_node(_pos, _len):
    """Claim the next free node slot and return its index.

    Stores ``_pos`` in ``fpos`` and ``_len`` in ``slen`` at the current value
    of the module-level counter ``sz``, then advances ``sz`` by one.
    """
    global sz
    new_id = sz
    fpos[new_id] = _pos
    slen[new_id] = _len
    sz = new_id + 1
    return new_id

def go_edge():
    """Move the active node down while ``pos`` exceeds the next edge's length.

    The next edge is the child of ``node`` keyed by ``s[n - pos]`` (node 0 is
    used when there is no such child). Each step makes that child the active
    node and subtracts its edge length from ``pos``.
    """
    global pos, node
    while True:
        child = to[node].get(s[n - pos], 0)
        if pos <= slen[child]:
            break
        node = child
        pos -= slen[child]

def add_letter(c):
    """Append character code ``c`` to the text and update the tree in place.

    Works entirely through module-level state: the text buffer ``s`` and its
    length ``n``, the per-node tables ``to``/``link``/``fpos``/``slen``, and
    the active position ``node``/``pos``. Returns ``None``.
    """
    global s, n, sz, to, link, fpos, slen, pos, node
    s[n] = c
    n += 1
    # NOTE(review): pos appears to count the trailing characters of s that
    # still have to be inserted -- confirm against the original algorithm.
    pos += 1
    # Node whose link[] entry is written on the next step; with last == 0 the
    # write lands on link[0].
    last = 0
    while(pos > 0):
        go_edge()
        edge = s[n - pos]
        v = to[node].get(edge, 0)  # 0 means: no child under this character
        # Computed before the v == 0 test; for v == 0 it reads through
        # fpos[0] and the value is only used by the elif/else branches.
        t = s[fpos[v] + pos - 1]
        if (v == 0):
            # No child for this character: attach a new node of length inf.
            to[node][edge] = make_node(n - pos, inf)
            link[last] = node
            last = 0
        elif (t == c):
            # The edge already continues with c: nothing to add for now.
            link[last] = node
            return
        else:
            # Split v's edge after pos - 1 characters: u takes the head of the
            # edge, gets a new child for c and keeps v under character t.
            u = make_node(fpos[v], pos - 1)
            to[u][c] = make_node(n - 1, inf)
            to[u][t] = v
            fpos[v] += pos - 1
            slen[v] -= pos - 1
            to[node][edge] = u
            link[last] = u
            last = u
        # At the root shorten the pending part; elsewhere follow link[]
        # (presumably the suffix link -- confirm).
        if(node == 0):
            pos -= 1
        else:
            node = link[node]

def init_tree(st):
    """Reset all module-level tree state and insert every character of ``st``.

    The tables are sized to ``2 * len(st) + 1`` entries. Node 0 gets edge
    length ``inf``; characters are fed to add_letter() as ``ord`` codes.
    """
    global slen, ans, inf, maxn, s, to, fpos, link, node, pos, sz, n

    inf = int(1e9)
    maxn = 2 * len(st) + 1

    s = [0] * maxn
    to = [{} for _ in range(maxn)]
    fpos = [0] * maxn
    slen = [0] * maxn
    link = [0] * maxn
    slen[0] = inf

    node = pos = n = 0
    sz = 1
    ans = 0

    for ch in st:
        add_letter(ord(ch))


# Benchmark input: ten "T", 1777 "A" and a "$" at the end; the query is
# a run of 1750 "A".
text = "T" * 10 + "A" * 1777 + "$"
query = "A" * 1750

def traverse_edge(st, idx, start, end):
    """Compare ``st[idx:]`` with ``text[start:end + 1]`` character by character.

    Returns -1 on the first mismatch, the final ``idx`` when it reached the
    module-level ``len_st``, and 0 when the comparison stopped earlier (edge
    or text exhausted) without a mismatch.
    """
    global len_text, len_st
    stop = min(end + 1, len_text)
    for k in range(start, stop):
        if idx >= len_st:
            break
        if text[k] != st[idx]:
            return -1
        idx += 1
    return idx if idx == len_st else 0

def edgelen(v, init, e):
    """Return the inclusive span ``e - init + 1``, or 0 when ``v`` is node 0."""
    return 0 if v == 0 else e - init + 1

def dfs(node, leafs, off):
    """Collect ``fpos[child] - off`` for every long-edge node below ``node``.

    ``node`` is a ``(key, children)`` pair where ``children`` maps a character
    code to a node index. A child whose ``slen`` exceeds 190000000 is recorded
    in ``leafs``; any other child is descended into with its edge length
    added to ``off``.
    """
    if not node:
        return
    _, subtree = node
    if isinstance(subtree, int):  # not a mapping of children: nothing to visit
        return
    for edge_key, child in subtree.items():
        child_len = slen[child]
        if child_len > 190000000:
            leafs.append(fpos[child] - off)
            continue
        dfs((edge_key, to[child]), leafs, off + child_len)

def traverse(v, st, idx, depth=0):
    """Match ``st[idx:]`` downward from node ``v`` and return the hits found.

    Args:
        v: index of the node whose incoming edge is matched first (the edge
            of node 0 is skipped).
        st: the query string; its length is expected in the global ``len_st``.
        idx: position in ``st`` of the next character to match.
        depth: offset forwarded to dfs() once the query is fully matched.

    Returns:
        The list filled by dfs() when the whole query was matched, otherwise
        an empty list. Results are memoised in the global ``cache`` under
        ``(v, st)`` on the paths that store them.

    NOTE(review): a query that ends on an edge whose node has no children
    also yields an empty list, because dfs() only reports children.
    """
    global len_st
    key = (v, st)
    result = cache.get(key, -2)
    if result != -2:
        return result
    if idx >= len_st:
        # Nothing left to match (an empty query): report no hits instead of
        # indexing past the end of ``st`` further down.
        cache[key] = []
        return []
    init = fpos[v]
    end = fpos[v]+slen[v]
    e = end-1
    if v != 0:
        r = traverse_edge(st, idx, init, e)
        if r == -1:
            # Mismatch on this edge.
            cache[key] = []
            return []
        if r != 0:
            # The query ended on this edge: gather everything below v.
            leafs = []
            dfs((v, to[v]), leafs, depth + r)
            return leafs
    # The edge was consumed completely; continue with the matching child.
    idx = idx + edgelen(v, init, e)
    if idx > len_st:
        cache[key] = []
        return []
    children = to[v]
    k = ord(st[idx])
    if k not in children:
        return []
    matches = traverse(children[k], st, idx, depth)
    cache[key] = matches
    return matches


def main():
    """Build the tree for ``text`` and run the ``query`` lookup 1000 times.

    Sets the globals ``len_text``, ``len_st`` and ``cache`` that the traversal
    helpers read, then writes the matches of every run to standard output,
    one value per line.
    """
    global text, len_st, len_text, cache
    init_tree(text)
    len_text = len(text)
    len_st = len(query)
    r = []
    for _ in range(1000):
        # A fresh cache per run, so every run repeats the whole traversal.
        cache = {}
        matches = traverse(0, query, 0)
        r.append("\n".join(str(m) for m in matches))
    stdout.write("\n".join(r)+"\n")


# Run the benchmark only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
\$\endgroup\$
0

2 Answers 2

5
\$\begingroup\$

So, about all those globals. They make your code really hard to read, I have almost no idea what is going on (after having read it twice). The general recommendation is to only use global objects if you really have to. Otherwise pass the objects either as arguments to the function, or write a class which has the relevant variables as attributes and keep the state like this.

Another thing that makes your code a lot easier to understand (both for other people and for yourself in two months' time) is docstrings, which document what each function does, the parameters it takes and what it returns or modifies.

\$\endgroup\$
1
  • \$\begingroup\$ Thanks for the suggestions. I was just porting this code ideone.com/sT8Vd1, explained in codeforces.com/blog/entry/16780. The original code uses global variables (which is very common in competitive programming implementations), so I tried to keep my port as close to it as I could to make the two easier to compare in my tests. I also appreciate your advice about using docstrings to make the code clearer. \$\endgroup\$ Commented Apr 28, 2020 at 15:20
-2
\$\begingroup\$

I finally gave up trying to optimize for PyPy and wrote the code I needed in C++ instead. Here it is, in case somebody is curious:

#include <bits/stdc++.h>

using namespace std;

// Every later use of `fpos` in this file is rewritten to `adla`.
// NOTE(review): presumably this avoids a clash with a standard-library name
// made visible by <bits/stdc++.h> and `using namespace std` -- confirm.
#define fpos adla
// Edge length given to node 0 and to nodes created with make_node(..., inf).
const int inf = 1e9;
// Capacity of every fixed-size table and buffer below.
const int maxn = 1e6+1;
//#define maxn 100005

// Characters inserted so far; add_letter() appends to it.
char s[maxn];
// Per-node children: character code -> child node index.
map<int,int> to[maxn];
// Per-node edge length, edge start position in s, and link target.
int slen[maxn], fpos[maxn], link[maxn];
// Active position used by go_edge() and add_letter().
int node, pos;
// sz: next node index handed out by make_node(); n: characters stored in s.
int sz = 1, n = 0;

// Text and query buffers filled by scanf in main(), with their lengths.
int text_len;
char text[maxn];
int query_len;
char query[maxn];

// Claims the next free node index, records its edge start (_pos) and edge
// length (_len), and returns that index.
int make_node(int _pos, int _len) {
    const int id = sz;
    ++sz;
    fpos[id] = _pos;
    slen[id] = _len;
    return id;
}

// Moves the active node down while `pos` is larger than the length of the
// edge selected by s[n - pos]; each step subtracts that edge's length.
void go_edge() {
    for (;;) {
        // operator[] is kept on purpose: like the original, a missing key is
        // default-inserted and reads as node 0.
        const int child = to[node][s[n - pos]];
        if (pos <= slen[child]) {
            break;
        }
        node = child;
        pos -= slen[child];
    }
}

// Appends character code c to the text buffer s and updates the tree tables
// (to / link / fpos / slen) and the active position (node / pos) in place.
void add_letter(int c) {
    s[n++] = c;
    // NOTE(review): pos appears to count the trailing characters of s that
    // still have to be inserted -- confirm against the original algorithm.
    pos++;
    // Node whose link[] entry is written on the next step; with last == 0
    // the write lands on link[0].
    int last = 0;
    while(pos > 0) {
        go_edge();
        int edge = s[n - pos];
        //int &v = to[node][edge];
        // operator[] default-inserts 0 for a missing key; 0 means "no child".
        int v = to[node][edge];
        // Read before the v == 0 test; only used by the else-if/else branches.
        int t = s[fpos[v] + pos - 1];
        if(v == 0) {
            // No child for this character: attach a new node of length inf.
            //v = make_node(n - pos, inf);
            to[node][edge] = make_node(n - pos, inf);
            link[last] = node;
            last = 0;
        }
        else if(t == c) {
            // The edge already continues with c: nothing to add for now.
            link[last] = node;
            return;
        }
        else {
            // Split v's edge after pos - 1 characters: u takes the head of
            // the edge, gets a new child for c and keeps v under character t.
            int u = make_node(fpos[v], pos - 1);
            to[u][c] = make_node(n - 1, inf);
            to[u][t] = v;
            fpos[v] += pos - 1;
            slen [v] -= pos - 1;
            // v = u;
            to[node][edge] = u;
            link[last] = u;
            last = u;
        }
        // At the root shorten the pending part; elsewhere follow link[]
        // (presumably the suffix link -- confirm).
        if(node == 0)
            pos--;
        else
            node = link[node];
    }
}

// Compares st[idx..] with text[start..end] character by character.
// Returns -1 on the first mismatch, the final idx when it reached st_len,
// and 0 when the comparison stopped earlier without a mismatch.
int traverse_edge(char st[], int st_len, int idx, int start, int end) {
    for (int k = start; k <= end && k < text_len && idx < st_len; ++k, ++idx) {
        if (text[k] != st[idx]) {
            return -1;
        }
    }
    return idx == st_len ? idx : 0;
}

// Returns the inclusive span e - init + 1, or 0 when v is node 0.
int edgelen(int v, int init, int e) {
    return v == 0 ? 0 : e - init + 1;
}

void dfs(map<int, int> tree, vector<vector<int>> leafs, int off) {
    if (tree.empty()) {
        return;
    }
        /* for ( auto [k, value] : tree) { */ // C++17
        for ( auto it : tree) {
            int value = it.second;
            if (slen[value] > maxn / 10) {
                vector<int> leaf;
                leaf.push_back(fpos[value]);
                leaf.push_back(slen[value]);
                leaf.push_back(fpos[value]-off);
                leafs.push_back(leaf);
            } else {
                dfs(to[value], leafs, off+slen[value]);
            }
        }
}

// Matches st[idx..] downward from node v.
//
// v     - node whose incoming edge is matched first (node 0's is skipped).
// tree  - child map used to pick the next node; every caller in this file
//         passes to[v]. Taken by const reference: the by-value parameter
//         copied the whole map at every recursion step.
// st    - query characters; its length is read from the global query_len.
// idx   - position in st of the next character to match.
// depth - passed through unchanged to the recursive call.
//
// Returns {true, {}} when the whole query was matched and {false, {}}
// otherwise; the vector is always empty.
pair<bool, vector<int>> traverse(int v, const map<int,int>& tree, char st[], int idx, int depth) {
    const int init = fpos[v];
    const int e = init + slen[v] - 1;
    if (v != 0) {
        const int r = traverse_edge(st, query_len, idx, init, e);
        if (r == -1) {
            // Mismatch on this edge.
            return {false, vector<int>{}};
        }
        if (r != 0) {
            // The query ended on this edge.
            return {true, vector<int>{}};
        }
    }
    // The edge was consumed completely; continue with the matching child.
    idx = idx + edgelen(v, init, e);
    if (idx > query_len) {
        return {false, vector<int>{}};
    }
    const int k = int(st[idx]);
    // A reference instead of a copy of the child map.
    const map<int,int>& children = to[v];
    if (children.find(k) == children.end()) {
        return {false, vector<int>{}};
    }
    // As before, the next node is read from `tree`; a missing key reads as 0,
    // which is what operator[] on the old by-value copy produced.
    const auto hit = tree.find(k);
    const int vv = (hit != tree.end()) ? hit->second : 0;
    return traverse(vv, to[vv], st, idx, depth);
}


// Reads t test cases from standard input. Each case is a text, a query
// count q and q queries; prints "y" or "n" per query depending on whether
// traverse() reports a full match. Stops quietly on malformed or truncated
// input.
int main() {
    int t;
    if (scanf("%d", &t) != 1) {
        return 0;
    }
    while (t--) {
        // The field width keeps scanf inside text[maxn] (maxn - 1 characters
        // plus the terminating NUL); update it if maxn changes.
        if (scanf("%1000000s", text) != 1) {
            break;
        }
        text_len = strlen(text);

        // Only nodes below sz were touched by the previous case, so clearing
        // those is enough; the rest of the table is still empty.
        for (int i = 0; i < sz; i++) {
            to[i].clear();
        }
        sz = 1;
        n = 0;
        pos = 0;
        node = 0;
        slen[0] = inf;

        for (int i = 0; i < text_len; i++) {
            add_letter((int)text[i]);
        }
        add_letter((int)'$');

        int q;
        if (scanf("%d", &q) != 1) {
            break;
        }
        while (q--) {
            // Same width limit as for text: query has maxn bytes as well.
            if (scanf("%1000000s", query) != 1) {
                break;
            }
            query_len = strlen(query);
            pair<bool, vector<int>> results = traverse(0, to[0], query, 0, 0);
            printf("%s\n", results.first ? "y" : "n");
        }
    }
    return 0;
}
```
\$\endgroup\$

You must log in to answer this question.

Start asking to get answers

Find the answer to your question by asking.

Ask question

Explore related questions

See similar questions with these tags.