A* search algorithm用のモジュールを書いてみた

仕事ではOCamlもA* search algorithmも全く使っていないのですが、自宅で気分転換/リハビリがてら何か書こうとすると、すぐOCamlでA* search algorithmを書いてしまいます。たまにDijkstra法で何かgraphを探索するやつも書きます。毎回同じようなものを書くのは面倒なので、コア部分だけモジュール化してみました。A*はDijkstra法の拡張のようなもの（heuristicが常に0を返せばDijkstra法と同じ動きになる）なので、Dijkstra法にもそのまま使えそうですし。

astar.ml

(** Everything {!Make} needs to know about a search space. *)
module type RouteType = sig
  (** A position (graph node) in the search space.  {!Make} compares
      positions with polymorphic equality. *)
  type pos

  (** An accumulated cost or score. *)
  type cost

  (** Positions reachable from the given position in one move. *)
  val next_routes : pos -> pos list

  (** [cost_to_move src dst] is the cost of a single move from [src] to
      [dst]. *)
  val cost_to_move : pos -> pos -> cost

  (** Estimated remaining cost from the given position to the goal. *)
  val heuristic : pos -> cost

  (** Sum of two costs. *)
  val add_cost : cost -> cost -> cost

  (** Order on costs, using the same sign convention as [compare]. *)
  val compare_cost : cost -> cost -> int
end

module Make(Route:RouteType) : sig
  (** [run start goal init_cost] searches from [start] to [goal] with A*:
      open nodes are expanded in order of accumulated cost plus
      [Route.heuristic].  A heuristic that always returns zero gives
      Dijkstra's algorithm.  The route found is a cheapest one when the
      heuristic never overestimates the remaining cost.

      [init_cost] is the cost already accumulated at [start].  The result is
      the route in start-to-goal order; each element is [(pos, prev, cost)],
      where [prev] is the preceding position on the route ([None] for
      [start]) and [cost] is the accumulated cost at [pos].

      Positions are compared with polymorphic equality [( = )].

      @raise Failure ["empty nodeset?"] when every reachable position has
      been explored without reaching [goal]. *)
  val run : Route.pos -> Route.pos -> Route.cost ->
            (Route.pos * Route.pos option * Route.cost) list
end = struct
  (* A search node.  [cost] is the accumulated cost from the start, [score]
     is [cost] plus the heuristic estimate for [pos], and [parent] is the
     node this one was reached from ([None] for the start node).  Keeping
     the parent node itself (not just its position) lets the route be
     rebuilt without looking positions up again, even if an ancestor was
     later reopened with a cheaper cost. *)
  type node = {
    pos : Route.pos;
    cost : Route.cost;
    score : Route.cost;
    parent : node option;
  }

  (* [lower a b] is true when [a] is strictly smaller than [b] according to
     [Route.compare_cost]; costs are never compared polymorphically. *)
  let lower a b = Route.compare_cost a b < 0

  (* Remove the node with the lowest score from [nodes], returning it with
     the remaining nodes.  On ties the earliest node in [nodes] wins. *)
  let remove_minimum_score_node nodes =
    match nodes with
    | [] -> failwith "empty nodeset?"
    | first :: rest ->
      List.fold_left
        (fun (best, others) node ->
          if lower node.score best.score then (node, best :: others)
          else (best, node :: others))
        (first, []) rest

  (* Split [nodes] into the node at [pos], if any, and all other nodes. *)
  let find_same_pos nodes pos =
    match List.partition (fun node -> node.pos = pos) nodes with
    | [], others -> (None, others)
    | found :: _, others -> (Some found, others)

  (* The node for [next_pos] reached from [node] by a single move. *)
  let create_node node next_pos =
    let cost =
      Route.add_cost node.cost (Route.cost_to_move node.pos next_pos)
    in
    { pos = next_pos;
      cost;
      score = Route.add_cost cost (Route.heuristic next_pos);
      parent = Some node }

  (* Offer [candidate] to the search and return the updated (open, closed)
     sets.  An existing entry for the same position is replaced only when
     the candidate's score is strictly lower.  Otherwise both sets are
     returned unchanged, so the existing open entry is kept.  A closed entry
     that is beaten is removed from the closed set so that the position is
     expanded again with its cheaper cost. *)
  let relax (openset, closeset) candidate =
    match find_same_pos openset candidate.pos with
    | Some existing, others ->
      if lower candidate.score existing.score
      then (candidate :: others, closeset)
      else (openset, closeset)
    | None, _ ->
      (match find_same_pos closeset candidate.pos with
       | Some existing, others ->
         if lower candidate.score existing.score
         then (candidate :: openset, others)
         else (openset, closeset)
       | None, _ -> (candidate :: openset, closeset))

  (* Rebuild the route ending at [node], in start-to-goal order, in the
     [(pos, prev, cost)] form returned by [run]. *)
  let route_to node =
    let rec walk n acc =
      match n.parent with
      | None -> (n.pos, None, n.cost) :: acc
      | Some p -> walk p ((n.pos, Some p.pos, n.cost) :: acc)
    in
    walk node []

  let run start goal init_cost =
    let rec search openset closeset =
      let node, openset = remove_minimum_score_node openset in
      if node.pos = goal then route_to node
      else
        let closeset = node :: closeset in
        let openset, closeset =
          List.fold_left
            (fun sets next_pos -> relax sets (create_node node next_pos))
            (openset, closeset) (Route.next_routes node.pos)
        in
        search openset closeset
    in
    (* Score the start like every other node: accumulated cost plus the
       heuristic estimate. *)
    let start_node =
      { pos = start;
        cost = init_cost;
        score = Route.add_cost init_cost (Route.heuristic start);
        parent = None }
    in
    search [start_node] []
end

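posは経路上の位置、costは移動コストや評価値を表現できる型です。add_costはコストの加算、cost_to_moveは隣接する位置への1回の移動にかかるコスト、compare_costはコストの比較（compareと同じ符号の規約）を行う関数です。heuristicは当該位置から終点までの推定コストを返す関数で、実際のコストを上回らない（過大評価しない）値を返せば最短経路が得られます。next_routesは当該位置から移動可能な位置のリストを返す関数です。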

このFunctorを用いてDijkstra法っぽいサンプルを解く場合は以下のような感じになります。
dijkstraSample.ml

(** Vertices of the sample graph. *)
type t =
  | S
  | A
  | B
  | C
  | D
  | E
  | G

(* Directed, weighted edges of the sample graph as (src, dst, cost)
   triples, grouped by source vertex.  Order matters: the search helpers
   scan this list from front to back. *)
let routes =
  [ (S, A, 3); (S, B, 10); (S, C, 12);
    (A, D, 10); (A, B, 2);
    (B, C, 3); (B, E, 7);
    (C, E, 3);
    (D, E, 1); (D, G, 3);
    (E, D, 1); (E, G, 5) ]

(* One-letter label of a vertex, used when printing the route. *)
let string_of_point point =
  match point with
  | S -> "S" | A -> "A" | B -> "B" | C -> "C"
  | D -> "D" | E -> "E" | G -> "G"

(* Search over the explicit edge list [routes].  The heuristic is always 0,
   so the search orders nodes by accumulated cost alone. *)
module GraphAstar =
  Astar.Make (struct
    type pos = t
    type cost = int

    let heuristic _ = 0

    let add_cost a b = a + b

    let compare_cost a b = compare a b

    (* Weight of the edge prev_pos -> now_pos.  If the edge is listed more
       than once in [routes], the last entry wins. *)
    let cost_to_move prev_pos now_pos =
      let matching =
        List.filter
          (fun (src, dst, _) -> src = prev_pos && dst = now_pos)
          routes
      in
      match List.rev matching with
      | (_, _, cost) :: _ -> cost
      | [] -> failwith "not found"

    (* Destinations of edges leaving [current], in reverse order of their
       appearance in [routes]. *)
    let next_routes current =
      List.rev_map
        (fun (_, dst, _) -> dst)
        (List.filter (fun (src, _, _) -> src = current) routes)
  end)

(* Search from S to G and print each vertex of the route with the cost
   accumulated up to it.  [let ()] (rather than [let _]) makes the compiler
   check that this top-level expression really is [unit]. *)
let () =
  let resultset = GraphAstar.run S G 0 in
  List.iter
    (fun (src, _, cost) ->
      Printf.printf "pos=%s, cost=%d\n" (string_of_point src) cost)
    resultset

実行させるとこう。

komamitsu@carrot:~/git/ocaml-libastar$ ocamlc -o dijkstraSample astar.ml dijkstraSample.ml
komamitsu@carrot:~/git/ocaml-libastar$ ./dijkstraSample 
pos=S, cost=0
pos=A, cost=3
pos=B, cost=5
pos=C, cost=8
pos=E, cost=11
pos=D, cost=12
pos=G, cost=15

A*向けっぽいサンプル（迷路を解く astarSample.ml）の場合はこう。

(* The maze, one string per row: row index is y, column index is x.
   '#' cells are walls and cannot be entered; every other character is
   walkable.  'S' and 'G' are only visual markers: the start and goal
   coordinates are defined separately.  The width is taken from the first
   row, so every row should have that same length. *)
let maze = [
  "S                #   ";
  "   ##  ####      # # ";
  " #  #  #  # ###### # ";
  " ####     #        # ";
  "     #     #    #### ";
  " ##### ########    # ";
  "              ##   # ";
  " ###########   ##### ";
  " #           ##      ";
  "## ###########  #####";
  "                    G";
  ]
(* Character of [maze] at column [x] of row [y]. *)
let elm_in_maze x y =
  let row = List.nth maze y in
  row.[x]

(* A maze cell: [x] is the column index, [y] the row index. *)
type position = {
  x : int;
  y : int;
}
(* Maze dimensions: width from the first row, height from the row count. *)
let len_x, len_y = (String.length (List.nth maze 0), List.length maze)

(* Search from the top-left cell to the bottom-right cell. *)
let start = { x = 0; y = 0 }
let goal = { x = pred len_x; y = pred len_y }

(* "(x, y)" rendering of a maze cell. *)
let string_of_pos { x; y } = Printf.sprintf "(%d, %d)" x y

(* Maze costs are floats. *)
let string_of_cost cost = string_of_float cost

(* A* over the maze grid: unit-cost moves to the four axis-aligned
   neighbours, guided by straight-line distance to [goal].  (The original
   defined add_cost, cost_to_move and compare_cost twice; the duplicates
   are removed here.) *)
module MazeAstar = Astar.Make (struct
  type pos = position
  type cost = float

  (* Euclidean distance from [pos] to [goal].  Every move is one unit step
     along an axis, so this never exceeds the number of steps still
     needed. *)
  let heuristic pos =
    sqrt (
      (float_of_int (goal.x - pos.x)) ** 2.0 +.
      (float_of_int (goal.y - pos.y)) ** 2.0
    )

  let add_cost a b = a +. b

  (* Every move between neighbouring cells costs the same. *)
  let cost_to_move _prev_pos _now_pos = 1.0

  let compare_cost a b = compare a b

  (* In-bounds, non-wall cells one step up, down, left or right of
     [current].  Uses structural [<>] on the cell character; the original
     used physical [!=]. *)
  let next_routes current =
    List.fold_left
      (fun next_points (dx, dy) ->
        let (x, y) = (current.x - dx, current.y - dy) in
        if x >= 0 && y >= 0 && x < len_x && y < len_y
           && elm_in_maze x y <> '#'
        then {x; y} :: next_points
        else next_points)
      [] [(0, -1); (0, 1); (-1, 0); (1, 0)]
end)

(* Solve the maze and print it with the route drawn as '.'.  OCaml strings
   are immutable under safe-string, and [String.set] is removed in OCaml 5.
   The rows are therefore copied into [Bytes] before being marked; [maze]
   itself is left untouched. *)
let () =
  let grid = List.map Bytes.of_string maze in
  let resultset = MazeAstar.run start goal 0.0 in
  List.iter
    (fun ({x; y}, _prev, _cost) -> Bytes.set (List.nth grid y) x '.')
    resultset;
  List.iter (fun row -> print_endline (Bytes.to_string row)) grid

実行させるとこう。

komamitsu@carrot:~/git/ocaml-libastar$ ocamlc -o astarSample astar.ml astarSample.ml 
komamitsu@carrot:~/git/ocaml-libastar$ ./astarSample 
............     #...
   ##  ####.     #.#.
 #  #  #  #.######.#.
 ####     #........#.
     #     #    ####.
 ##### ########    #.
              ##   #.
 ###########   #####.
 #           ##......
## ########### .#####
               ......

まぁ、書いてみたものの自分で使うことは無さそうだなぁ...

https://github.com/komamitsu/ocaml-libastar