Idiomatically functional way of finding two numbers in BST which add to a target

I am writing code to find two numbers in a binary search tree which add to a given number. The code I currently have is below, with the implementation in the function two_sum at the bottom. The code is somewhat object-oriented, using an iterator to keep state. I am new to functional programming and am wondering if there is a more idiomatic way to write this.

open Core

type 'a node = {
  key: 'a;
  left: 'a t;
  right: 'a t;
}
and 'a t =
  | Empty
  | Node of 'a node

module Dfs = struct

  type 'a iterator = {
    bst: 'a t;
    current: 'a option;
    stack: 'a node list;
  }

  let create_iterator (bst: 'a t) : 'a iterator =
    {
      bst = bst;
      current =
        (match bst with
         | Empty -> None
         | Node node -> Some node.key);
      stack = [];
    }

  let rec next_inorder (it: 'a iterator) : 'a iterator =
    match it.bst with
    | Empty ->
      (* Parenthesized so the inner match does not capture the Node case *)
      (match it.stack with
       | h::t -> { bst = h.right; current = Some h.key; stack = t }
       | [] -> { bst = Empty; current = None; stack = [] })
    | Node node ->
      next_inorder { bst = node.left; current = it.current; stack = node::it.stack }

  let rec prev_inorder (it: 'a iterator) : 'a iterator =
    match it.bst with
    | Empty ->
      (* Parenthesized so the inner match does not capture the Node case *)
      (match it.stack with
       | h::t -> { bst = h.left; current = Some h.key; stack = t }
       | [] -> { bst = Empty; current = None; stack = [] })
    | Node node ->
      prev_inorder { bst = node.right; current = it.current; stack = node::it.stack }

end

(* Returns the first solution found, or None *)
let two_sum (nums: int t) (target: int) : (int * int) option =
  let rec loop (left: int Dfs.iterator) (right: int Dfs.iterator) : (int * int) option =
    let left_value = Option.value_exn left.current in
    let right_value = Option.value_exn right.current in
    (* The two iterators have met or crossed: no pair exists *)
    if left_value >= right_value then None
    else
      let sum = left_value + right_value in
      match compare target sum with
      | -1 -> loop left (Dfs.prev_inorder right)
      | 1 ->  loop (Dfs.next_inorder left) right
      | 0 ->  Some (left_value, right_value)
      | _ ->  assert false
  in
  let left = Dfs.create_iterator nums in
  let right = Dfs.create_iterator nums in
  loop (Dfs.next_inorder left) (Dfs.prev_inorder right)

(Code for keeping the invariant of the BST is not shown, as the main point of the question is the iteration.)


Your algorithm does not use the fact that the keys are in a binary search tree; indeed, your use of iterators shows that you are only interested in traversing the elements in increasing or decreasing order. This suggests you should first reduce your tree to an increasing list, then make your algorithm work on lists instead of iterators (lists or streams are definitely more idiomatic in functional programming than iterators, so this should answer your question).

First, how to reduce the tree to a list in increasing order? This is not too hard to do directly with a recursion. But to be more idiomatic, I write a fold that encodes the particular recursion scheme for trees that we need. If you use a tree from a library, fold should already be defined.

let rec fold ~f tree init =
  match tree with
  | Empty -> init
  | Node node ->
    (* Thread the accumulator through left subtree, key, right subtree *)
    init
    |> fold ~f node.left
    |> f node.key
    |> fold ~f node.right

let decreasing_elements tree =
  fold ~f:(fun key list -> key::list) tree []

Then I encode the algorithm that finds a pair with given sum (actually my function builds the list of all pairs with given sum, but that can be changed easily). It takes the same list, in increasing and in decreasing order. Pattern matching with guard conditions is well suited here.

let rec find_pairs target increasing decreasing =
  match increasing, decreasing with
  | x::bigger, y::smaller when x >= y -> [] (* don't get each pair twice *)
  | x::bigger, y::smaller when x + y = target ->
    (x,y) :: find_pairs target bigger smaller
  | x::bigger, y::smaller when x + y < target ->
    find_pairs target bigger decreasing
  | x::bigger, y::smaller (* when x + y > target *) ->
    find_pairs target increasing smaller
  | _ -> []
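
As a sketch of the "can be changed easily" remark, here is one way a first-match variant might look. The name find_first_pair is made up for this illustration, and the code uses only the standard library (no Core), so take it as an assumption-laden example rather than the version benchmarked later in the thread.

```ocaml
(* Hypothetical first-match variant: returns the first pair whose sum is
   [target], or None. [increasing] and [decreasing] hold the same keys in
   opposite orders. *)
let rec find_first_pair target increasing decreasing =
  match increasing, decreasing with
  | x :: bigger, y :: smaller when x < y ->
    let sum = x + y in
    if sum = target then Some (x, y)
    else if sum < target then find_first_pair target bigger decreasing
    else find_first_pair target increasing smaller
  (* Either a list ran out or the two cursors met: no pair exists. *)
  | _ -> None
```

Stopping at the first match means the recursion is a plain tail call, so no result list is built.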

Then combine both parts:

let find_pairs_in_tree target tree =
  let decreasing = decreasing_elements tree in
  let increasing = List.rev decreasing in
  find_pairs target increasing decreasing 

Compared to your solution, this always builds the full list of elements, which costs some efficiency. The extreme case is when the min and max elements sum to the target: the pair is found at the first step, yet the whole tree has already been traversed. If that matters to you, one solution is to use lazy lists instead of lists. You would need two functions to get increasing and decreasing lazy lists from the tree; find_pairs would then be mostly the same (with the additional lazy syntax).
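
To make the lazy idea concrete, here is a minimal sketch that uses the standard library's Seq module as a stand-in for lazy lists. Seq is thunk-based and does not memoize, so it is not the same thing as 'a Lazy.t. The flattened tree type and the names increasing_from, decreasing_from and find_pair are assumptions made for this example; it stops at the first pair and avoids Core.

```ocaml
(* Same shape as the question's tree, flattened into one type for brevity. *)
type 'a tree =
  | Empty
  | Node of 'a tree * 'a * 'a tree

(* Keys of [t] in increasing order, followed by [rest]. Nothing is
   traversed until the resulting sequence is forced. *)
let rec increasing_from (t : 'a tree) (rest : 'a Seq.t) : 'a Seq.t =
  fun () ->
    match t with
    | Empty -> rest ()
    | Node (l, k, r) ->
      increasing_from l (fun () -> Seq.Cons (k, increasing_from r rest)) ()

(* Mirror image: keys of [t] in decreasing order, followed by [rest]. *)
let rec decreasing_from (t : 'a tree) (rest : 'a Seq.t) : 'a Seq.t =
  fun () ->
    match t with
    | Empty -> rest ()
    | Node (l, k, r) ->
      decreasing_from r (fun () -> Seq.Cons (k, decreasing_from l rest)) ()

(* First pair of distinct keys summing to [target], or None. The loop passes
   already-forced nodes around so that each element is forced only once. *)
let find_pair (target : int) (t : int tree) : (int * int) option =
  let rec loop inc dec =
    match inc, dec with
    | Seq.Cons (x, bigger), Seq.Cons (y, smaller) ->
      if x >= y then None
      else if x + y = target then Some (x, y)
      else if x + y < target then loop (bigger ()) dec
      else loop inc (smaller ())
    | _ -> None
  in
  loop (increasing_from t Seq.empty ()) (decreasing_from t Seq.empty ())
```

Only the part of the tree between the two cursors and the ends is ever visited, so when the min and max sum to the target the search touches just the leftmost and rightmost paths.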

Is there a way to do this that does not involve creating a new data structure proportional to the size of the original BST?

Even with lazy lists, I am seeing a 5-10x slowdown when folding into two lazy lists and then pattern matching, compared to having an iterator provide the values to compare.

For example, some sort of “bifold” which does not need to store the elements (list) or thunks (lazy list) entirely in memory but just the elements which need to be accessed?

I did some benchmarking, and my results do not quite agree with yours. I implemented five versions, without optimizing much (I am not competent in optimizing code).

  1. The usual lists, as I proposed in my first answer.
  2. The same with lazy lists, using the type 'a Lazy.t.
  3. The same as lazy lists, but because we do not use the memoisation, I used the type unit -> 'a instead of 'a Lazy.t
    (so forcing the same value twice would take double the time).
  4. An implementation of generators; this is similar to your version.
  5. An implementation using zippers. Zippers are a functional way to “walk inside” a data structure.
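
For readers who have not met zippers, here is a minimal sketch of what "walking inside" a tree can look like. It is not the code from the benchmark program: the flattened tree type and the names crumb, zipper, leftmost, climb, next and to_list are all assumptions made for this illustration, and it uses only the standard library.

```ocaml
type 'a tree =
  | Empty
  | Node of 'a tree * 'a * 'a tree

(* A crumb records what was left behind when stepping down one level. *)
type 'a crumb =
  | Left_of of 'a * 'a tree   (* went left: parent key and right sibling *)
  | Right_of of 'a tree * 'a  (* went right: left sibling and parent key *)

(* The focused subtree plus the path back to the root. *)
type 'a zipper = { focus : 'a tree; path : 'a crumb list }

let key z = match z.focus with Node (_, k, _) -> Some k | Empty -> None

(* Descend along left children to the smallest key under the focus. *)
let rec leftmost z =
  match z.focus with
  | Node ((Node _ as l), k, r) ->
    leftmost { focus = l; path = Left_of (k, r) :: z.path }
  | _ -> z

(* Go up, rebuilding nodes, until we arrive at a parent from its left child. *)
let rec climb z =
  match z.path with
  | [] -> None
  | Left_of (k, r) :: rest -> Some { focus = Node (z.focus, k, r); path = rest }
  | Right_of (l, k) :: rest -> climb { focus = Node (l, k, z.focus); path = rest }

(* In-order successor of the focused node, if any. *)
let next z =
  match z.focus with
  | Empty -> None
  | Node (_, _, Empty) -> climb z
  | Node (l, k, (Node _ as r)) ->
    Some (leftmost { focus = r; path = Right_of (l, k) :: z.path })

(* Walk the whole tree with the zipper, collecting keys in increasing order. *)
let to_list t =
  let rec go acc = function
    | None -> List.rev acc
    | Some z ->
      (match key z with
       | Some k -> go (k :: acc) (next z)
       | None -> List.rev acc)
  in
  go [] (Some (leftmost { focus = t; path = [] }))
```

A symmetric rightmost/prev pair would give the decreasing walk; the two-sum loop then holds one zipper at each end, much like the two iterators in the question.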

The program is in a git repository here.

I generate an instance with a single valid pair whose two elements sit at the 5% and 95% positions in the distribution (this can be changed easily). This is what I get for about 80k keys.

│ Name              │ Time/Run │  mWd/Run │  mjWd/Run │  Prom/Run │ Percentage │
│ stdlib lists      │ 749.07us │ 124.11kw │ 9_815.53w │ 9_815.53w │    100.00% │
│ lazy lists        │ 168.40us │  32.89kw │   225.38w │   225.38w │     22.48% │
│ simili lazy lists │ 100.06us │  53.08kw │    35.76w │    35.76w │     13.36% │
│ generators        │  91.26us │  14.93kw │     6.96w │     6.96w │     12.18% │
│ zippers           │ 201.36us │  70.47kw │   630.82w │   630.82w │     26.88% │

And with the pair in position 20%-80% I get:

│ Name              │ Time/Run │  mWd/Run │   mjWd/Run │   Prom/Run │ Percentage │
│ stdlib lists      │ 968.39us │ 146.83kw │ 13_728.07w │ 13_728.07w │    100.00% │
│ lazy lists        │ 690.88us │ 120.81kw │  2_686.45w │  2_686.45w │     71.34% │
│ simili lazy lists │ 413.36us │ 196.99kw │    133.25w │    133.25w │     42.69% │
│ generators        │ 370.01us │  54.86kw │     27.06w │     27.06w │     38.21% │
│ zippers           │ 948.56us │ 261.98kw │  8_710.53w │  8_710.53w │     97.95% │