Can this Clojure code be optimized?

I wrote the code below for a game I am working on, but it seems a little slow. In case you haven't looked at the code yet, it is an implementation of the A* search/pathfinding algorithm. It takes about 100-600 ms on a 100x100 grid, depending on the heuristic used (and, consequently, the number of tiles visited).

There are no reflection warnings. However, I suspect boxing might be an issue, and I don't know how to get rid of it here because the computation is split among several functions. Also, I store tiles/coordinates as vectors of two numbers, like [x y]. Those numbers will be boxed, right?

A typical piece of code, if you don't want to read through all of it, is (def add-pos (partial mapv + pos)), where pos is one of those two-number vectors. There are several places where the numbers are manipulated in a similar way and put back into a vector afterwards. Is there any way to optimize code like this? Any other tips are welcome too, performance-related or otherwise.

EDIT: Thinking some more about it, I came up with a few follow-up questions: Can a Clojure function ever return primitives? Can a Clojure function ever take primitives (without any boxing)? Can I put primitives in a type/record without boxing?

(ns game.server.pathfinding
  (:use game.utils)
  (:require [clojure.math.numeric-tower :as math]
            [game.math :as gmath]
            [clojure.data.priority-map :as pm]))

(defn walkable? [x]
  (and x (= 1 x)))

(defn point->tile
  ([p] (apply point->tile p))
  ([x y] [(int x) (int y)]))

(defn get-tile
  "Gets the type of the tile at the point v in
   the grid m. v is a point in R^2, not grid indices."
  [m v]
  (get-in m (point->tile v)))

(defn integer-points
  "Given an equation: x = start + t * step, returns a list of the
   values for t that make x an integer between start and stop,
   or nil if there is no such value for t."
  [start stop step]
  (if-not (zero? step)
    (let [first-t (-> start ((if (neg? step) math/floor math/ceil))
                      (- start) (/ step))
          t-step (/ 1 (math/abs step))]
      (take-while #((if (neg? step) > <) (+ start (* step %)) stop)
                  (iterate (partial + t-step) first-t)))))

(defn crossed-tiles [[x y :as p] p2 m]
  (let [[dx dy :as diff-vec] (map - p2 p)
        ipf (fn [getter]
              (integer-points (getter p) (getter p2) (getter diff-vec)))
        x-int-ps (ipf first)
        y-int-ps (ipf second)
        get-tiles (fn [[x-indent y-indent] t]
                    (->> [(+ x-indent x (* t dx)) (+ y-indent y (* t dy))]
                         (get-tile m)))]
    (concat (map (partial get-tiles [0.5 0]) x-int-ps)
            (map (partial get-tiles [0 0.5]) y-int-ps))))

(defn clear-line?
  "Returns true if the line between p and p2 passes over only
   walkable? tiles in m, otherwise false."
  [p p2 m]
  (every? walkable? (crossed-tiles p p2 m)))

(defn clear-path?
  "Returns true if a circular object with radius r can move
   between p and p2, passing over only walkable? tiles in m,
   otherwise false.
   Note: Does not currently work for objects with a radius >= 0.5."
  [p p2 r m]
  (let [diff-vec (map (partial * r) (gmath/normalize (map - p2 p)))
        ortho1 ((fn [[x y]] (list (- y) x)) diff-vec)
        ortho2 ((fn [[x y]] (list y (- x))) diff-vec)]
    (and (clear-line? (map + ortho1 p) (map + ortho1 p2) m)
         (clear-line? (map + ortho2 p) (map + ortho2 p2) m))))

(defn straighten-path
  "Given a path in the map m, remove unnecessary nodes of
   the path. A node is removed if one can pass freely
   between the previous and the next node."
  ([m path]
   (if (> (count path) 2) (straighten-path m path nil) path))
  ([m [from mid to & tail] acc]
   (if to
     (if (clear-path? from to 0.49 m)
       (recur m (list* from to tail) acc)
       (recur m (list* mid to tail) (conj acc from)))
     (reverse (conj acc from mid)))))

(defn to-mid-points [path]
  (map (partial map (partial + 0.5)) path))

(defn to-tiles [path]
  (map (partial map int) path))

(defn a*
  "A* search for a grid of squares, mat. Tries to find a
   path from start to goal using only walkable? tiles.
   start and goal are vectors of indices into the grid,
   not points in R^2."
  [mat start goal factor]
  (let [width (count mat)
        height (count (first mat))]
    (letfn [(h [{pos :pos}] (* factor (gmath/distance pos goal)))
            (g [{:keys [pos parent]}]
              (if parent
                (+ (:g parent) (gmath/distance pos (parent :pos)))
                0))
            (make-node [parent pos]
              (let [node {:pos pos :parent parent}
                    g (g node) h (h node)
                    f (+ g h)]
                (assoc node :f f :g g :h h)))
            (get-path
              ([node] (get-path node ()))
              ([{:keys [pos parent]} path]
               (if parent
                 (recur parent (conj path pos))
                 (conj path pos))))
            (free-tile? [tile]
              (let [type (get-in mat (vec tile))]
                (and type (walkable? type))))
            (expand [closed pos]
              (let [adj [[1 0] [0 1] [-1 0] [0 -1]]
                    add-pos (partial mapv + pos)]
                (->> (take 4 (partition 2 1 (cycle adj)))
                     (map (fn [[t t2]]
                            (list* (map + t t2) (map add-pos [t t2]))))
                     (map (fn [[d t t2]]
                            (if (every? free-tile? [t t2]) d nil)))
                     (remove nil?)
                     (concat adj)
                     (map add-pos)
                     (remove (fn [[x y :as tile]]
                               (or (closed tile) (neg? x) (neg? y)
                                   (>= x width) (>= y height)
                                   (not (walkable? (get-in mat tile)))))))))
            (add-to-open [open tile->node [{:keys [pos f] :as node} & more]]
              (if node
                (if (or (not (contains? open pos))
                        (< f (open pos)))
                  (recur (assoc open pos f)
                         (assoc tile->node pos node)
                         more)
                  (recur open tile->node more))
                {:open open :tile->node tile->node}))]
      (let [start-node (make-node nil start)]
        (loop [closed #{}
               open (pm/priority-map start (:f start-node))
               tile->node {start start-node}]
          (let [[curr _] (peek open) curr-node (tile->node curr)]
            (when curr
              (if (= curr goal)
                (get-path curr-node)
                (let [exp-tiles (expand closed curr)
                      exp-nodes (map (partial make-node curr-node) exp-tiles)
                      {:keys [open tile->node]}
                      (add-to-open (pop open) tile->node exp-nodes)]
                  (recur (conj closed curr) open tile->node))))))))))

(defn find-path [mat start goal]
  (let [start-tile (point->tile start)
        goal-tile (point->tile goal)
        path (a* mat start-tile goal-tile 1) ; a* also takes a heuristic factor; 1 is an assumed default
        point-path (to-mid-points path)
        full-path (concat [start] point-path [goal])
        final-path (rest (straighten-path mat full-path))]
    final-path))

There is 1 answer.

Answer by noisesmith:

I recommend the Clojure High Performance Programming book for addressing questions like yours.

There are coercion functions that convert a value to a primitive (unboxing it if necessary): byte, short, int, long, float and double.
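A minimal sketch of how the coercion functions keep arithmetic primitive, assuming points are two-element vectors as in the question (sum-x-boxed and sum-x-prim are hypothetical names). Wrapping a value read out of a vector with long hands the compiler a primitive, and a loop local initialised with a primitive literal stays primitive across recur:

```clojure
;; Hypothetical helper, generic version: acc and x are boxed Objects,
;; so every + goes through the generic (boxed) number path.
(defn sum-x-boxed [points]
  (reduce (fn [acc [x _]] (+ acc x)) 0 points))

;; Hypothetical helper, primitive version: acc starts as the primitive
;; literal 0, and (long ...) unboxes each x, so + compiles to long math.
(defn sum-x-prim [points]
  (loop [acc 0
         ps  (seq points)]
    (if ps
      (recur (+ acc (long (nth (first ps) 0))) (next ps))
      acc)))
```

Only the loop body runs on primitives here; the vector elements themselves are still boxed Longs, and the result is boxed again when the function returns.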

*warn-on-reflection* does not cover numeric code: it will not warn when arithmetic falls back to boxed operations instead of being compiled to primitive math. The primitive-math library can force warnings for that case.
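As an alternative to primitive-math, Clojure 1.7 and later can report boxed arithmetic itself when *unchecked-math* is set to :warn-on-boxed. A minimal sketch (add-boxed and add-prim are hypothetical names); be aware that this setting also switches +, -, * and friends to unchecked, overflow-wrapping long arithmetic for the rest of the file:

```clojure
;; Put these at the top of the namespace being checked.
(set! *warn-on-reflection* true)
;; Report boxed math (Clojure 1.7+). Also enables unchecked long arithmetic.
(set! *unchecked-math* :warn-on-boxed)

;; Untyped args: the + here operates on Objects, so a boxed-math
;; warning is printed when this form is compiled.
(defn add-boxed [a b]
  (+ a b))

;; Primitive-hinted args and return: + compiles to primitive long math,
;; so no warning is expected.
(defn add-prim ^long [^long a ^long b]
  (+ a b))
```

Loading this should print a boxed-math warning for add-boxed only, which makes it a quick way to find the spots where a hint or coercion is missing.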

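You can declare primitive types for function arguments and return values: (defn foo ^long [^long x ^long y] (+ x y)). Only long and double are supported as primitive argument and return types, and only for functions of up to four arguments. Hints such as ^Integer are ordinary object hints and do not avoid boxing.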
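A sketch of a primitive function, which also bears on the EDIT questions about taking and returning primitives (dist-sq and near? are hypothetical names). Parameter hints go on the parameters and the return hint goes on the argument vector; the compiler then generates a primitive entry point, and other compiled code that calls it with primitive values skips boxing on the way in and out. Because primitive signatures allow at most four parameters, dist-sq takes the four coordinates separately, which just fits:

```clojure
;; Squared Euclidean distance on primitive doubles:
;; four ^double params, ^double return (the maximum primitive arity).
(defn dist-sq ^double [^double x1 ^double y1 ^double x2 ^double y2]
  (let [dx (- x2 x1)
        dy (- y2 y1)]
    (+ (* dx dx) (* dy dy))))

;; A caller whose locals are already primitive doubles; the call to
;; dist-sq can go through its primitive entry point without boxing.
(defn near? [^double x ^double y]
  (< (dist-sq x y 0.0 0.0) 1.0))
```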

Avoid apply if you want performance.
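In the question, the one-argument arity of point->tile re-dispatches through apply. A sketch of the same function without apply (the -apply and -direct suffixes are hypothetical, used only to compare the two versions): pull the coordinates out with nth and call the two-argument arity directly, so no argument sequence has to be built and walked:

```clojure
;; As in the question: the 1-arity spreads the point through apply.
(defn point->tile-apply
  ([p] (apply point->tile-apply p))
  ([x y] [(int x) (int y)]))

;; Same behaviour without apply: read both coordinates with nth and
;; call the fixed 2-arity directly. int truncates toward zero.
(defn point->tile-direct
  ([p] (point->tile-direct (nth p 0) (nth p 1)))
  ([x y] [(int x) (int y)]))
```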

Avoid varargs (a common reason to need apply) if you want performance. A varargs function allocates garbage on every invocation, because the extra arguments are packed into a sequence that usually isn't used outside the function body. partial returns a generic wrapper (which falls back to varargs when called with more than three arguments) that always passes boxed values through. Consider replacing (partial * x) with #(* x %); the latter can be optimized much more aggressively.
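Applied to the question's add-pos, a sketch of the same neighbour computation with a fixed-arity function instead of partial (adj mirrors the offsets in expand; add-pos, neighbours-partial and neighbours-direct are hypothetical names):

```clojure
;; The four axis-aligned neighbour offsets, as in the question's expand.
(def adj [[1 0] [0 1] [-1 0] [0 -1]])

;; Question's style: partial's generic wrapper around mapv and the
;; generic + on boxed numbers.
(defn neighbours-partial [pos]
  (let [add-pos (partial mapv + pos)]
    (mapv add-pos adj)))

;; Fixed-arity replacement: each coordinate is read once with nth,
;; coerced to a primitive long, and added with primitive +.
(defn add-pos [pos d]
  (let [x  (long (nth pos 0)) y  (long (nth pos 1))
        dx (long (nth d 0))   dy (long (nth d 1))]
    [(+ x dx) (+ y dy)]))

(defn neighbours-direct [pos]
  ;; #(...) instead of partial: a plain one-argument fn.
  (mapv #(add-pos pos %) adj))
```

The result vectors still hold boxed Longs, so this removes the generic dispatch and boxed arithmetic but not the boxing of the stored coordinates.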

There is a tradeoff with primitive single-type JVM arrays: they are mutable and fixed in length, which can lead to more complex and brittle code. However, they perform better than the standard Clojure sequential types and are available if all else fails to give you the performance you need.
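A sketch of the array route for the walkability lookups, assuming the grid only holds small integers and is indexed as (get-in mat [x y]) like in the question. The grid is copied once into a flat byte array laid out as x * height + y; grid->bytes and walkable-at? are hypothetical helpers, and the ^bytes hints let aget and aset compile to direct array access:

```clojure
;; Copy the vector-of-vectors grid into one flat byte array.
;; Assumes every cell holds a number that fits in a byte.
(defn grid->bytes [mat]
  (let [width  (count mat)
        height (count (first mat))
        ^bytes arr (byte-array (* width height))]
    (dotimes [x width]
      (dotimes [y height]
        (aset arr (+ (* x height) y) (byte (get-in mat [x y])))))
    arr))

(defn walkable-at?
  "True if tile (x, y) holds the walkable value 1.
   No bounds checking: x must be in [0, width), y in [0, height)."
  [^bytes arr height x y]
  (let [h (long height)
        i (+ (* (long x) h) (long y))]
    (== 1 (aget arr i))))
```

Because there is no bounds checking, callers still need the (neg? x) and (>= x width) style tests from expand before calling walkable-at?.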

Also, use criterium to compare alternative implementations of your code. It uses a number of tricks (such as JIT warm-up runs and triggering garbage collection before measuring) to rule out the random factors that affect execution time, so you can see what really performs best in a tight loop.
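A sketch of how such a comparison could look, pitting a hypothetical seq-based distance against a two-dimensional version on primitive doubles (both names are hypothetical; whether the question's gmath/distance resembles the generic one is an assumption, since that code isn't shown). The benchmark calls sit in a comment form so the file loads without criterium; evaluating them assumes criterium is on the classpath:

```clojure
(defn distance-generic
  "Seq-based Euclidean distance; the intermediate numbers are boxed."
  [p q]
  (Math/sqrt (reduce + (map (fn [a b] (let [d (- a b)] (* d d))) p q))))

(defn distance-prim
  "Two-dimensional distance using primitive double locals."
  ^double [p q]
  (let [dx (- (double (nth p 0)) (double (nth q 0)))
        dy (- (double (nth p 1)) (double (nth q 1)))]
    (Math/sqrt (+ (* dx dx) (* dy dy)))))

(comment
  ;; Evaluate at the REPL with criterium available.
  (require '[criterium.core :as crit])
  (crit/quick-bench (distance-generic [0 0] [3 4]))
  (crit/quick-bench (distance-prim [0 0] [3 4])))
```

Checking that both versions agree on a few inputs before benchmarking keeps the comparison honest.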

Also, regarding your representation of a point as [x y]: you can reduce the space and lookup overhead of the collection holding the coordinates with (defrecord Point [x y]), as long as you know points will remain two elements only and you don't mind changing your code to ask for (:x point) or (:y point). You could optimize further by writing or using a simple two-number Java class (at the cost of losing immutability).
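A record can also hold primitive fields, which bears on the last EDIT question: fields hinted ^long or ^double are stored unboxed. A sketch (Point and add-point are hypothetical names); note that keyword lookup such as (:x p) still returns a boxed value, while field access (.-x p) on a Point-hinted local reads the primitive directly:

```clojure
;; Fields hinted ^long are stored as primitive long fields.
(defrecord Point [^long x ^long y])

(defn add-point
  "Adds two points. With both args hinted ^Point, the .-x/.-y field
   reads are primitive longs and + compiles to primitive long math."
  [^Point a ^Point b]
  ;; Call the constructor directly: ->Point is a plain fn taking boxed args.
  (Point. (+ (.-x a) (.-x b))
          (+ (.-y a) (.-y b))))
```

Records keep value-based equality and hashing, so Point values can still serve as keys in the closed set and in the priority map.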