
I wrote an edit distance algorithm in both Clojure and Scala.

The Scala version runs 70x faster than the Clojure one.

Clojure:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [i (range 1 (inc n0))] (aset-long distances i 0 i))
    (doseq [j (range 1 (inc n1))] (aset-long distances 0 j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget distances i (dec j))
            del (aget distances (dec i) j)
            match (aget distances (dec i) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset-long distances i j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset-long distances i j (inc min-dist))
          :else (aset-long distances i j min-dist))))
    (aget distances n0 n1)))

Scala:

import scala.collection.mutable.ArrayBuffer

def editDistance(s0: Array[Char], s1: Array[Char]): Int = {
  val n0 = s0.length
  val n1 = s1.length
  val distances = Array.fill(n0 + 1)(ArrayBuffer.fill(n1 + 1)(0))
  for (j <- 0 to n1) { distances(0)(j) = j }
  for (i <- 0 to n0) { distances(i)(0) = i }
  for (i <- 1 to n0; j <- 1 to n1) {
    val ins = distances(i)(j - 1)
    val del = distances(i - 1)(j)
    val matches = distances(i - 1)(j - 1)
    val minDist = (ins :: del :: matches :: Nil).reduceLeft(_ min _)
    if (matches != minDist)
      distances(i)(j) = minDist + 1
    else if (s0(i - 1) == s1(j - 1))
      distances(i)(j) = minDist
    else
      distances(i)(j) = minDist + 1
  }
  distances(n0)(n1)
}

I am using Java arrays in Clojure to get the best performance. I have tried adding type hints wherever aget is called, but then my code performs even worse (which might be expected, since make-array already creates a typed array). I have also overridden the Clojure :jvm-opts in project.clj. Still, the smallest performance gap I get is 70x.

What's wrong with my use of Java arrays in Clojure?

Thanks for any insight.

  • Have you run this through a profiler? In particular, pay attention to memory consumption. Commented Jul 30, 2016 at 16:51
  • @Anony-Mousse Indeed, reflective calls through java.lang.reflect.Method consume >90% of memory. How could this happen, considering the distances 2D array is typed? Commented Jul 30, 2016 at 18:23
  • Maybe some lambda expressions. Does Clojure generate Java 8 bytecode with method references? Commented Jul 30, 2016 at 18:26
  • I don't know how to check this. If that were the problem, how could I force Clojure to generate method references? Thanks. Commented Jul 30, 2016 at 18:41
  • Be sure to also check out HipHip for low-level tasks: github.com/plumatic/hiphip Commented Jul 30, 2016 at 20:35

1 Answer


I think I figured out where the problem lies.

As you mentioned in the comment, the reflection calls consume most of the time. Here's why.

Before analyzing the code, I set *warn-on-reflection* to true:

(set! *warn-on-reflection* true)
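With the flag set, the compiler prints a warning for every Java interop call whose target type it cannot infer. A minimal illustration, unrelated to arrays (the function names `len-reflective` and `len-hinted` are hypothetical):

```clojure
(set! *warn-on-reflection* true)

;; No hint: the compiler cannot know the type of s, so .length is
;; resolved by reflection at runtime, and a reflection warning is
;; printed when this form is compiled.
(defn len-reflective [s]
  (.length s))

;; With a ^String hint the call compiles to a direct String.length()
;; call, and no warning is printed.
(defn len-hinted [^String s]
  (.length s))

(println (len-reflective "hello") (len-hinted "hello"))
```

Loading this should print one reflection warning, for the `.length` call in `len-reflective`; both functions still return the same result.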

Then, if you look at the source of aset, or of the macro that generates the aset-long function, you'll see that for 4+ arities they use apply to invoke the function recursively. The same goes for aget with 3+ arities. I'm not 100% sure, but I believe information about the argument types is lost when a function is applied. Also, if you look closely at the definitions of aget and aset, you may notice that they can be inlined during compilation (aget only in its two-argument form, aset only in its three-argument form). We definitely want that:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;; I've unwound all aget/aset calls so they can be inlined by the compiler.
    ;; I'm also type hinting the first argument of the top-level aget/aset calls.
    ;; The reason is explained next.
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    ;; we can leave this, since it is not placed within loop
    (aget distances n0 n1)))
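The variadic and nested forms can be compared in isolation. The sketch below (the function name `nested-vs-variadic` is hypothetical) writes to a small hinted `long[][]` and reads the same cell both ways. Both reads return the same value, but only the nested two-argument `aget` is eligible for inlining, while `(aget grid 1 2)` goes through the variadic arity that calls `apply`:

```clojure
(defn nested-vs-variadic []
  ;; A hinted 2 x 3 long[][]; every cell starts at 0.
  (let [^"[[J" grid (make-array Long/TYPE 2 3)]
    ;; Nested write: (aget grid 1) returns row 1, the ^longs hint marks it
    ;; as long[], so the three-argument aset can be inlined.
    (aset ^longs (aget grid 1) 2 42)
    [;; Variadic read: the 3+ arity of aget calls apply internally.
     (aget grid 1 2)
     ;; Nested read: both agets use the inlinable two-argument form.
     (aget ^longs (aget grid 1) 2)]))

(println (nested-vs-variadic)) ; prints [42 42]
```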

Let's compile our new function. Remember the *warn-on-reflection* var we set at the beginning? Because it is true, the compiler produces a bunch of warnings during compilation:

Reflection warning, core.clj:75:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:76:23 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
Reflection warning, core.clj:77:25 - call to static method aget on clojure.lang.RT can't be resolved (argument types: unknown, int).
...

The problem is that Clojure cannot infer the static type of (make-array Long/TYPE (inc n0) (inc n1)), so it marks it as unknown. We need to type hint it:

(let [...
      ;; type hint for 2d array of primitive longs
      ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))
      ...]
   ...)
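The string "[[J" is the JVM's binary class name for `long[][]`: one `[` per array dimension, followed by `J`, the type code for `long`. The `^longs` hint used in the code is Clojure's shorthand for the one-dimensional `"[J"`. A quick check (the helper `array-class-name` is hypothetical):

```clojure
;; Returns the JVM binary class name of an array object.
(defn array-class-name [arr]
  (.getName (class arr)))

(println (array-class-name (make-array Long/TYPE 2 2))) ; prints [[J
(println (array-class-name (long-array 3)))             ; prints [J
```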

At this point, it seems that we're all set. The final version is below:

(defn edit-distance
  "['seq of char' 'seq of char']"
  [s0 s1]
  (let [n0 (count s0)
        n1 (count s1)
        ^"[[J" distances (make-array Long/TYPE (inc n0) (inc n1))]
    ;;initialize distances
    (doseq [^long i (range 1 (inc n0))] (aset ^longs (aget distances i) 0 i))
    (doseq [^long j (range 1 (inc n1))] (aset ^longs (aget distances 0) j j))

    (doseq [i (range 1 (inc n0)), j (range 1 (inc n1))]
      (let [ins (aget ^longs (aget distances i) (dec j))
            del (aget ^longs (aget distances (dec i))  j)
            match (aget ^longs (aget distances (dec i)) (dec j))
            min-dist (min ins del match)]
        (cond
          (not= match min-dist) (aset ^longs (aget distances i) j (inc min-dist))
          (not= (nth s0 (dec i)) (nth s1 (dec j))) (aset ^longs (aget distances i) j (inc min-dist))
          :else (aset ^longs (aget distances i) j min-dist))))
    (aget distances n0 n1)))

Here are the benchmarks:

before:

> (time (edit-distance i1 i2))
"Elapsed time: 4601.025555 msecs"
291

after:

> (time (edit-distance i1 i2))
"Elapsed time: 27.782828 msecs"
291
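A further variation, not part of the answer and not benchmarked here: `doseq` over `range` hands out boxed `Long` values, whereas `dotimes` binds a primitive `long` counter, so replacing the loops with `dotimes` may remove the remaining boxing. This sketch assumes both inputs are Strings (the original accepts any seq of chars, and `nth` on a non-indexed seq takes linear time) and uses the standard Levenshtein recurrence; the name `edit-distance-dotimes` is hypothetical:

```clojure
(defn edit-distance-dotimes
  "Levenshtein distance between two strings, using a hinted long[][]
  table and dotimes loops (whose counters are primitive longs)."
  [^String s0 ^String s1]
  (let [n0 (.length s0)
        n1 (.length s1)
        ^"[[J" d (make-array Long/TYPE (inc n0) (inc n1))]
    ;; First column and first row: distance to/from the empty string.
    (dotimes [i (inc n0)] (aset ^longs (aget d i) 0 i))
    (dotimes [j (inc n1)] (aset ^longs (aget d 0) j j))
    (dotimes [a n0]
      (let [i (inc a)
            ^longs prev (aget d a)   ; row i-1
            ^longs row (aget d i)]   ; row i
        (dotimes [b n1]
          (let [j (inc b)
                cost (if (= (.charAt s0 (int a)) (.charAt s1 (int b))) 0 1)]
            (aset row j (min (inc (aget row b))          ; insertion
                             (inc (aget prev j))         ; deletion
                             (+ (aget prev b) cost)))))))  ; match or substitution
    (aget ^longs (aget d n0) n1)))
```

For example, `(edit-distance-dotimes "kitten" "sitting")` should return 3. Holding `prev` and `row` in hinted locals also means each inner iteration only indexes one-dimensional `long[]` arrays.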

1 Comment

Thanks, now I have also learned how helpful checking the source code can be.
