I think I found the problem:
Owl_algodiff.D.pack_flt takes a float and converts it into an Owl_algodiff.D.t that represents a constant, so the wrapped function carries no derivative information. Owl_optimise.D.minimise_fun does not calculate derivatives on its own.
So with the following workaround (still relying on a float -> float function as input), I get the right result:
#require "owl";;
let my_func x = x ** x;;
+let delta = 1e-6;;
+let my_func_diff x =
+ ( my_func (x +. delta) -. my_func (x -. delta) ) /.
+ (delta +. delta);;
+
let pack_flt = Owl_algodiff.D.pack_flt;;
let unpack_flt = Owl_algodiff.D.unpack_flt;;
-let my_func_wrapped x = x |> unpack_flt |> my_func |> pack_flt;;
+let my_func_wrapped x =
+ let open Owl_algodiff.D.Maths in
+ pack_flt (my_func (unpack_flt x)) +
+ pack_flt (my_func_diff (unpack_flt x)) * (x - (pack_flt (unpack_flt x)));;
-let opt_params = Owl_optimise.D.Params.config 10.;;
+let opt_params = Owl_optimise.D.Params.config 10000.;;
let result = Owl_optimise.D.minimise_fun opt_params my_func_wrapped (pack_flt 0.567) |> snd;;
let _ = result |> unpack_flt |> Float.to_string |> print_endline;;
Note that I (ab)use pack_flt (unpack_flt x) to convert x to a constant here.
I also see that Owl provides numerical differentiation. I tried to use that instead of manually specifying a delta, but the code got somewhat messy (though it does work):
#require "owl";;
(* The objective to minimise: x raised to the power x, on plain floats. *)
let my_func x = Float.pow x x;;
(* Short local names for the Owl_algodiff.D conversions used below:
   [pack_flt] turns a float into an Owl_algodiff.D.t value and
   [unpack_flt] turns such a value back into a float. *)
let pack_flt v = Owl_algodiff.D.pack_flt v;;
let unpack_flt v = Owl_algodiff.D.unpack_flt v;;
(* Lifts the float -> float function [my_func] to a function on
   Owl_algodiff.D.t values.

   The result is built as

     my_func x0 + slope * (x - x0)

   where [x0] is the float currently held by [x] and [slope] is obtained
   from the numerical gradient of [my_func] at [x0]. Only the factor
   [(x - x0)] refers to [x] itself; every other term is a freshly packed
   float. *)
let my_func_wrapped x =
  let module Numdiff = Owl_numdiff_generic.Make (Owl.Arr) in
  (* Plain float held by [x]; computed once and reused below. *)
  let x0 = unpack_flt x in
  (* [my_func] adapted to read its argument from index 0 of an array,
     which is the form handed to [Numdiff.grad] here. *)
  let on_arr a = my_func (Owl.Arr.get a [|0|]) in
  let slope =
    Owl.Arr.get (Numdiff.grad on_arr (Owl.Arr.create [|1|] x0)) [|0|]
  in
  let value_term = pack_flt (my_func x0) in
  let slope_term = pack_flt slope in
  (* Packing the unpacked float again yields a value that is meant to act
     as a constant copy of [x]. *)
  let anchor = pack_flt x0 in
  (* The Maths operators are opened only for the final combination. *)
  Owl_algodiff.D.Maths.(value_term + (slope_term * (x - anchor)));;
(* Optimiser configuration. The positional float is handed to
   [Params.config]; presumably it is the epoch/iteration budget -- confirm
   against the Owl documentation. *)
let opt_params = Owl_optimise.D.Params.config 10000.;;
(* Minimise [my_func_wrapped] starting from 0.567 and keep the second
   component of the pair returned by [minimise_fun]. *)
let result =
  let _, minimiser =
    Owl_optimise.D.minimise_fun opt_params my_func_wrapped (pack_flt 0.567)
  in
  minimiser;;
(* Print the result as a plain float. [let () =] is used instead of
   [let _ =] so the compiler checks that the expression is [unit] rather
   than silently discarding a value of any type. *)
let () = print_endline (Float.to_string (unpack_flt result));;
Lots and lots of packing and wrapping and unpacking.
My questions:
- What would be the idiomatic way to solve this problem (“I have a `float -> float` function and want to find the argument where the result is minimized”) using Owl?
- I got a bit confused by `Owl_algodiff` vs `Owl.Algodiff`. What is the difference? I suspect `Owl_algodiff` is more the internal module and `Owl.Algodiff` is the official API? I noticed, for example, that we have `Owl.Arr` but no `Owl_arr`.
- Am I supposed to use `Owl_numdiff_generic` in higher-level code at all? I see it belongs to `owl-base` rather than the `owl` library.