I am writing code to find two numbers in a binary search tree that add up to a given number. My current code is below; the implementation is in the function `two_sum`
at the bottom. The code is somewhat object-oriented, using an iterator to keep state. I am new to functional programming and am wondering whether there is a more idiomatic way to write this.
open Core
(* A tree node: the key stored here plus its two subtrees.
   NOTE(review): the BST ordering invariant (left keys < [key] < right keys)
   is presumably maintained by code not shown in this file; the in-order
   iterators and [two_sum] below rely on it. *)
type 'a node = {
key: 'a;
left: 'a t; (* subtree visited before [key] in an in-order walk *)
right: 'a t; (* subtree visited after [key] in an in-order walk *)
}
(* A possibly-empty binary tree; defined together with [node] because each
   refers to the other. *)
and 'a t =
| Empty
| Node of 'a node
module Dfs = struct
  (** Immutable cursor for an in-order walk over a tree.

      - [bst]: the subtree still to be descended into before the next step.
      - [current]: the key the cursor is on, or [None] once the walk has run
        out of nodes.
      - [stack]: nodes whose own key has not been visited yet, most recently
        pushed first.

      A given cursor should be driven in one direction only (always
      {!next_inorder} or always {!prev_inorder}): the stack holds different
      nodes for the two directions, so mixing them revisits keys. *)
  type 'a iterator = {
    bst : 'a t;
    current : 'a option;
    stack : 'a node list;
  }

  (** [create_iterator bst] builds a fresh cursor over [bst] with an empty
      stack. Its [current] is the root key ([None] for an empty tree), which
      is not the first in-order key: step once with {!next_inorder} to reach
      the smallest key, or with {!prev_inorder} to reach the largest. *)
  let create_iterator (bst : 'a t) : 'a iterator =
    let current =
      match bst with
      | Empty -> None
      | Node { key; _ } -> Some key
    in
    { bst; current; stack = [] }

  (** [next_inorder it] moves to the next key in ascending order: it walks
      down the left spine of [it.bst], pushing every node passed, then pops
      the top of the stack, makes that node's key current and leaves its
      right subtree to be explored on the following step. When both the
      subtree and the stack are exhausted, [current] becomes [None]. *)
  let rec next_inorder (it : 'a iterator) : 'a iterator =
    match it.bst, it.stack with
    | Node n, pending ->
      next_inorder { it with bst = n.left; stack = n :: pending }
    | Empty, top :: rest ->
      { bst = top.right; current = Some top.key; stack = rest }
    | Empty, [] -> { bst = Empty; current = None; stack = [] }

  (** [prev_inorder it] is the mirror image of {!next_inorder}: it moves to
      the next key in descending order by walking down right spines and
      resuming in the left subtree of each popped node. *)
  let rec prev_inorder (it : 'a iterator) : 'a iterator =
    match it.bst, it.stack with
    | Node n, pending ->
      prev_inorder { it with bst = n.right; stack = n :: pending }
    | Empty, top :: rest ->
      { bst = top.left; current = Some top.key; stack = rest }
    | Empty, [] -> { bst = Empty; current = None; stack = [] }
end
(* [two_sum nums target] looks for two keys [lo < hi] in [nums] with
   [lo + hi = target], and returns the first such pair found as
   [Some (lo, hi)], or [None] when there is none. An empty or single-key
   tree yields [None].

   Two in-order cursors walk the tree from opposite ends: [left] ascends
   from the smallest key and [right] descends from the largest. A sum that
   is too small advances [left]; a sum that is too large retreats [right].
   Once the cursors meet, every candidate pair has been ruled out.

   Assumes [nums] satisfies the BST ordering invariant (maintained
   elsewhere). *)
let two_sum (nums : int t) (target : int) : (int * int) option =
  let rec loop (left : int Dfs.iterator) (right : int Dfs.iterator)
    : (int * int) option =
    (* A cursor with no current key (e.g. [nums] is [Empty]) means there is
       nothing left to pair, so report [None] rather than raising. *)
    match left.Dfs.current, right.Dfs.current with
    | None, _ | _, None -> None
    | Some lo, Some hi ->
      if lo >= hi then None
      else
        (* Plain comparisons instead of matching [compare]'s result on the
           literals -1/1: [compare] only promises negative/zero/positive. *)
        let sum = lo + hi in
        if sum > target then loop left (Dfs.prev_inorder right)
        else if sum < target then loop (Dfs.next_inorder left) right
        else Some (lo, hi)
  in
  loop
    (Dfs.next_inorder (Dfs.create_iterator nums))
    (Dfs.prev_inorder (Dfs.create_iterator nums))
(Code for maintaining the BST invariant is not shown, since the question is mainly about the iteration.)