diff --git a/src/lib/classify.ml b/src/lib/classify.ml new file mode 100644 index 0000000..afdb958 --- /dev/null +++ b/src/lib/classify.ml @@ -0,0 +1,175 @@ + +module List = ListLabels +open Util + +type 'a probabilities = ('a * float) list + +let most_likely = function + | [] -> invalidArg "Classify.most_likely: empty probabilities" + | h::tl -> + List.fold_left ~f:(fun ((_,p1) as c1) ((_,p2) as c2) -> + if p2 > p1 then c2 else c1) ~init:h tl + |> fst + +let multiply_ref = ref true +let prod_arr, prod_arr2 = + if !multiply_ref then + (fun f x -> Array.fold_left (fun p x -> p *. f x) 1.0 x), + (fun f x y -> Array.fold2 (fun p x y -> p *. f x y) 1.0 x y) + else + (fun f x -> Array.fold_left (fun s x -> s +. log (f x)) 0.0 x |> exp), + (fun f x y -> Array.fold2 (fun s x y -> s +. log (f x y)) 0.0 x y |> exp) + +type ('cls, 'ftr) naive_bayes = + (* Store the class prior in last element of the array. *) + { table : ('cls * float array) list + ; to_feature_array : 'ftr -> int array + ; features : int + } + +let eval ?(bernoulli=false) nb b = + let evidence = ref 0.0 in + let to_likelihood class_probs = + let idx = nb.to_feature_array b in + if bernoulli then + let set = Array.to_list idx in + prod_arr (fun i -> + if List.mem i ~set then + class_probs.(i) + else + (1.0 -. class_probs.(i))) + (Array.init nb.features (fun x -> x)) + else + prod_arr (fun i -> class_probs.(i)) idx + in + let byc = + List.map nb.table ~f:(fun (c, class_probs) -> + let prior = class_probs.(nb.features) in + let likelihood = to_likelihood class_probs in + let prob = prior *. likelihood in + evidence := !evidence +. prob; + (c, prob)) + in + List.map byc ~f:(fun (c, prob) -> (c, prob /. !evidence)) + +let within a b x = max a (min x b) + +type smoothing = + { factor : float + ; feature_space_size : int array + } + +let estimate ?smoothing ~feature_size to_ftr_arr data = + if data = [] then + invalidArg "Classify.estimate: Nothing to train on" + else + let aa = feature_size + 1 in + let update arr idx = + Array.iter (fun i -> arr.(i) <- arr.(i) + 1) idx; + (* keep track of the class count at the end of array. *) + arr.(feature_size) <- arr.(feature_size) + 1; + in + let (total, all) = + List.fold_left data + ~f:(fun (total, asc) (label, feature) -> + let n_asc = + try + let fr = List.assoc label asc in + update fr (to_ftr_arr feature); + asc + with Not_found -> + let fr = Array.make aa 0 in + update fr (to_ftr_arr feature); + (label, fr) :: asc + in + total + 1, n_asc) + ~init:(0, []) + in + let totalf = float total in + let cls_sz = float (List.length all) in + let to_prior_prob, to_lkhd_prob = + match smoothing with + | None -> + (fun count bkgrnd _ -> count /. bkgrnd), + (fun count bkgrnd _ -> count /. bkgrnd) + | Some s -> + (* TODO: Issue warning? Fail? *) + let sf = within 0.0 1.0 s.factor in + let fss = Array.map float s.feature_space_size in + (fun count bkgrnd space_size -> + (count +. sf) /. (bkgrnd +. sf *. space_size)), + (fun count bkgrnd idx -> + (count +. sf) /. (bkgrnd +. sf *. fss.(idx))) + in + let table = + List.map all ~f:(fun (cl, attr_count) -> + let prior_count = float attr_count.(feature_size) in + let likelihood = + Array.init aa (fun i -> + to_lkhd_prob (float attr_count.(i)) prior_count i) + in + (* Store the prior at the end. *) + likelihood.(feature_size) <- to_prior_prob prior_count totalf cls_sz; + cl, likelihood) + in + { table + ; to_feature_array = to_ftr_arr + ; features = feature_size + } + +type 'a gauss_bayes = + { table : ('a * float * (float * float) array) list + ; features : int + } + +let gauss_eval gb features = + if Array.length features <> gb.features then + invalidArg "Classify:gauss_eval: Expected a features array of %d features." + gb.features; + let prod = + prod_arr2 (fun (mean,std) y -> Distributions.normal_pdf ~mean ~std y) + in + let evidence = ref 0.0 in + let byc = + List.map gb.table ~f:(fun (c, prior, class_params) -> + let likelihood = prod class_params features in + let prob = prior *. likelihood in + evidence := !evidence +. prob; + (c, prob)) + in + List.map byc ~f:(fun (c, prob) -> (c, prob /. !evidence)) + +let gauss_estimate data = + if data = [] then + invalidArg "Classify.gauss_estimate: Nothing to train on!" + else + let update = Array.map2 Running.update in + let init = Array.map Running.init in + let features = Array.length (snd (List.hd data)) in + let total, by_class = + List.fold_left data + ~f:(fun (t, acc) (cls, attr) -> + try + let (cf, rsar) = List.assoc cls acc in + let acc' = List.remove_assoc cls acc in + let nrs = update rsar attr in + let cf' = cf + 1 in + (t + 1, (cls, (cf', nrs)) :: acc') + with Not_found -> + (t + 1, (cls, (1, (init attr))) :: acc)) + ~init:(0, []) + in + let totalf = float total in + (* A lot of the literature in estimating Naive Bayes focuses on estimating + the parameters using Maximum Likelihood. The Running estimate of variance + computes the unbiased form. Not certain if we should implement the + n/(n-1) conversion below. *) + let table = + let select rs = rs.Running.mean, (sqrt rs.Running.var) in + by_class + |> List.map ~f:(fun (c, (cf, rsarr)) -> + let class_prior = (float cf) /. totalf in + let attr_params = Array.map select rsarr in + (c, class_prior, attr_params)) + in + { table ; features } diff --git a/src/lib/classify.mli b/src/lib/classify.mli new file mode 100644 index 0000000..529a6c6 --- /dev/null +++ b/src/lib/classify.mli @@ -0,0 +1,53 @@ + +(** The classifiers below assign a discrete probability distribution over the + list of class 'a in their training set. *) +type 'a probabilities = ('a * float) list + +(** [most_likely probabilities] returns the most likely class from the + discrete probability distribution. *) +val most_likely : 'a probabilities -> 'a + +(** A discrete Naive Bayes classifier of class ['cls] by observing + features ['ftr]. *) +type ('cls, 'ftr) naive_bayes + +(** When estimating a probability distribution by counting observed instances + in the feature space we may want to smooth the values, particularly if our + training data is sparse. + + [http://en.wikipedia.org/wiki/Additive_smoothing] + *) +type smoothing = + { factor : float (** Multiplicative factor *) + ; feature_space_size : int array (** Size of the space of each feature. + Must be at least [feature_size] long.*) + } + +(** [estimate smoothing feature_size to_feature_array training_data] trains a + discrete Naive Bayes classifier based on the [training_data]. + [to_feature_array] maps a feature to an integer array of length + [feature_size]. Optionally, additive [smoothing] is applied to the final + estimates if provided. +*) +val estimate : ?smoothing:smoothing -> feature_size:int -> + ('ftr -> int array) -> ('cls * 'ftr) list -> + ('cls, 'ftr) naive_bayes + +(** [eval bernoulli classifier feature] classifies [feature] + according to [classifier]. if [bernoulli] is specified we treat the + underlying distribution as Bernoulli (as opposed to Multinomial) and + estimate the likelihood with (1-p_i) for features [i] that are missing + from [feature]. +*) +val eval : ?bernoulli:bool -> ('cls, 'ftr) naive_bayes -> 'ftr -> 'cls probabilities + +(** A continuous Gaussian Naive Bayes classifier of class ['cls]. The + feature space is assumed to be a float array. *) +type 'cls gauss_bayes + +(** [gauss_estimate training_data] trains a Gaussian Naive Bayes classifier from + [training_data], where all of the data are of the same length; feature size. *) +val gauss_estimate : ('cls * float array) list -> 'cls gauss_bayes + +(** [gauss_eval classifier feature] classify the [feature] using the [classifier]. *) +val gauss_eval : 'cls gauss_bayes -> float array -> 'cls probabilities diff --git a/src/lib/classify.mlt b/src/lib/classify.mlt new file mode 100644 index 0000000..5e46f87 --- /dev/null +++ b/src/lib/classify.mlt @@ -0,0 +1,212 @@ + +open Test_utils + +let () = + (* The IRIS data set from http://archive.ics.uci.edu/ml/datasets/Iris*) + Test.add_simple_test ~title:"Classify: Naive Multinomial Bayes." + (fun () -> + (** This example comes from + Bayesian Reasoning and Machine Learning by David Barber + http://web4.cs.ucl.ac.uk/staff/D.Barber/pmwiki/pmwiki.php?n=Brml.HomePage + *) + let feature_map = function + | `shortbread -> 0 + | `lager -> 1 + | `whiskey -> 2 + | `porridge -> 3 + | `football -> 4 + in + let data = + [ + (`English, [`whiskey; `porridge; `football]); + (`English, [`shortbread; `whiskey; `porridge]); + (`English, [`shortbread; `lager; `football]); + (`English, [`shortbread; `lager]); + (`English, [`lager; `football]); + (`English, [`porridge]); + (`Scottish, [`shortbread; `porridge; `football]); + (`Scottish, [`shortbread; `lager; `football]); + (`Scottish, [`shortbread; `lager; `whiskey; `porridge]); + (`Scottish, [`shortbread; `lager; `porridge]); + (`Scottish, [`shortbread; `lager; `porridge; `football]); + (`Scottish, [`shortbread; `whiskey; `porridge]); + (`Scottish, [`shortbread; `whiskey]) + ] + in + let to_feature_arr l = l |> List.map ~f:feature_map |> Array.of_list in + let naiveb = estimate ~feature_size:5 to_feature_arr data in + let sample = [ `shortbread ; `whiskey; `porridge ] in + let result = eval ~bernoulli:true naiveb sample in + let expect = + [(`Scottish, 0.807627593942793); (`English, 0.192372406057206957)] + in + Assert.is_true (expect = result)); + + (* The IRIS data set from http://archive.ics.uci.edu/ml/datasets/Iris*) + Test.add_simple_test ~title:"Classify: Naive Gaussian Bayes." + (fun () -> + let iris = + [ + `setosa,[|5.1;3.5;1.4;0.2|]; + `setosa,[|4.9;3.0;1.4;0.2|]; + `setosa,[|4.7;3.2;1.3;0.2|]; + `setosa,[|4.6;3.1;1.5;0.2|]; + `setosa,[|5.0;3.6;1.4;0.2|]; + `setosa,[|5.4;3.9;1.7;0.4|]; + `setosa,[|4.6;3.4;1.4;0.3|]; + `setosa,[|5.0;3.4;1.5;0.2|]; + `setosa,[|4.4;2.9;1.4;0.2|]; + `setosa,[|4.9;3.1;1.5;0.1|]; + `setosa,[|5.4;3.7;1.5;0.2|]; + `setosa,[|4.8;3.4;1.6;0.2|]; + `setosa,[|4.8;3.0;1.4;0.1|]; + `setosa,[|4.3;3.0;1.1;0.1|]; + `setosa,[|5.8;4.0;1.2;0.2|]; + `setosa,[|5.7;4.4;1.5;0.4|]; + `setosa,[|5.4;3.9;1.3;0.4|]; + `setosa,[|5.1;3.5;1.4;0.3|]; + `setosa,[|5.7;3.8;1.7;0.3|]; + `setosa,[|5.1;3.8;1.5;0.3|]; + `setosa,[|5.4;3.4;1.7;0.2|]; + `setosa,[|5.1;3.7;1.5;0.4|]; + `setosa,[|4.6;3.6;1.0;0.2|]; + `setosa,[|5.1;3.3;1.7;0.5|]; + `setosa,[|4.8;3.4;1.9;0.2|]; + `setosa,[|5.0;3.0;1.6;0.2|]; + `setosa,[|5.0;3.4;1.6;0.4|]; + `setosa,[|5.2;3.5;1.5;0.2|]; + `setosa,[|5.2;3.4;1.4;0.2|]; + `setosa,[|4.7;3.2;1.6;0.2|]; + `setosa,[|4.8;3.1;1.6;0.2|]; + `setosa,[|5.4;3.4;1.5;0.4|]; + `setosa,[|5.2;4.1;1.5;0.1|]; + `setosa,[|5.5;4.2;1.4;0.2|]; + `setosa,[|4.9;3.1;1.5;0.2|]; + `setosa,[|5.0;3.2;1.2;0.2|]; + `setosa,[|5.5;3.5;1.3;0.2|]; + `setosa,[|4.9;3.6;1.4;0.1|]; + `setosa,[|4.4;3.0;1.3;0.2|]; + `setosa,[|5.1;3.4;1.5;0.2|]; + `setosa,[|5.0;3.5;1.3;0.3|]; + `setosa,[|4.5;2.3;1.3;0.3|]; + `setosa,[|4.4;3.2;1.3;0.2|]; + `setosa,[|5.0;3.5;1.6;0.6|]; + `setosa,[|5.1;3.8;1.9;0.4|]; + `setosa,[|4.8;3.0;1.4;0.3|]; + `setosa,[|5.1;3.8;1.6;0.2|]; + `setosa,[|4.6;3.2;1.4;0.2|]; + `setosa,[|5.3;3.7;1.5;0.2|]; + `setosa,[|5.0;3.3;1.4;0.2|]; + `versicolor,[|7.0;3.2;4.7;1.4|]; + `versicolor,[|6.4;3.2;4.5;1.5|]; + `versicolor,[|6.9;3.1;4.9;1.5|]; + `versicolor,[|5.5;2.3;4.0;1.3|]; + `versicolor,[|6.5;2.8;4.6;1.5|]; + `versicolor,[|5.7;2.8;4.5;1.3|]; + `versicolor,[|6.3;3.3;4.7;1.6|]; + `versicolor,[|4.9;2.4;3.3;1.0|]; + `versicolor,[|6.6;2.9;4.6;1.3|]; + `versicolor,[|5.2;2.7;3.9;1.4|]; + `versicolor,[|5.0;2.0;3.5;1.0|]; + `versicolor,[|5.9;3.0;4.2;1.5|]; + `versicolor,[|6.0;2.2;4.0;1.0|]; + `versicolor,[|6.1;2.9;4.7;1.4|]; + `versicolor,[|5.6;2.9;3.6;1.3|]; + `versicolor,[|6.7;3.1;4.4;1.4|]; + `versicolor,[|5.6;3.0;4.5;1.5|]; + `versicolor,[|5.8;2.7;4.1;1.0|]; + `versicolor,[|6.2;2.2;4.5;1.5|]; + `versicolor,[|5.6;2.5;3.9;1.1|]; + `versicolor,[|5.9;3.2;4.8;1.8|]; + `versicolor,[|6.1;2.8;4.0;1.3|]; + `versicolor,[|6.3;2.5;4.9;1.5|]; + `versicolor,[|6.1;2.8;4.7;1.2|]; + `versicolor,[|6.4;2.9;4.3;1.3|]; + `versicolor,[|6.6;3.0;4.4;1.4|]; + `versicolor,[|6.8;2.8;4.8;1.4|]; + `versicolor,[|6.7;3.0;5.0;1.7|]; + `versicolor,[|6.0;2.9;4.5;1.5|]; + `versicolor,[|5.7;2.6;3.5;1.0|]; + `versicolor,[|5.5;2.4;3.8;1.1|]; + `versicolor,[|5.5;2.4;3.7;1.0|]; + `versicolor,[|5.8;2.7;3.9;1.2|]; + `versicolor,[|6.0;2.7;5.1;1.6|]; + `versicolor,[|5.4;3.0;4.5;1.5|]; + `versicolor,[|6.0;3.4;4.5;1.6|]; + `versicolor,[|6.7;3.1;4.7;1.5|]; + `versicolor,[|6.3;2.3;4.4;1.3|]; + `versicolor,[|5.6;3.0;4.1;1.3|]; + `versicolor,[|5.5;2.5;4.0;1.3|]; + `versicolor,[|5.5;2.6;4.4;1.2|]; + `versicolor,[|6.1;3.0;4.6;1.4|]; + `versicolor,[|5.8;2.6;4.0;1.2|]; + `versicolor,[|5.0;2.3;3.3;1.0|]; + `versicolor,[|5.6;2.7;4.2;1.3|]; + `versicolor,[|5.7;3.0;4.2;1.2|]; + `versicolor,[|5.7;2.9;4.2;1.3|]; + `versicolor,[|6.2;2.9;4.3;1.3|]; + `versicolor,[|5.1;2.5;3.0;1.1|]; + `versicolor,[|5.7;2.8;4.1;1.3|]; + `virginica,[|6.3;3.3;6.0;2.5|]; + `virginica,[|5.8;2.7;5.1;1.9|]; + `virginica,[|7.1;3.0;5.9;2.1|]; + `virginica,[|6.3;2.9;5.6;1.8|]; + `virginica,[|6.5;3.0;5.8;2.2|]; + `virginica,[|7.6;3.0;6.6;2.1|]; + `virginica,[|4.9;2.5;4.5;1.7|]; + `virginica,[|7.3;2.9;6.3;1.8|]; + `virginica,[|6.7;2.5;5.8;1.8|]; + `virginica,[|7.2;3.6;6.1;2.5|]; + `virginica,[|6.5;3.2;5.1;2.0|]; + `virginica,[|6.4;2.7;5.3;1.9|]; + `virginica,[|6.8;3.0;5.5;2.1|]; + `virginica,[|5.7;2.5;5.0;2.0|]; + `virginica,[|5.8;2.8;5.1;2.4|]; + `virginica,[|6.4;3.2;5.3;2.3|]; + `virginica,[|6.5;3.0;5.5;1.8|]; + `virginica,[|7.7;3.8;6.7;2.2|]; + `virginica,[|7.7;2.6;6.9;2.3|]; + `virginica,[|6.0;2.2;5.0;1.5|]; + `virginica,[|6.9;3.2;5.7;2.3|]; + `virginica,[|5.6;2.8;4.9;2.0|]; + `virginica,[|7.7;2.8;6.7;2.0|]; + `virginica,[|6.3;2.7;4.9;1.8|]; + `virginica,[|6.7;3.3;5.7;2.1|]; + `virginica,[|7.2;3.2;6.0;1.8|]; + `virginica,[|6.2;2.8;4.8;1.8|]; + `virginica,[|6.1;3.0;4.9;1.8|]; + `virginica,[|6.4;2.8;5.6;2.1|]; + `virginica,[|7.2;3.0;5.8;1.6|]; + `virginica,[|7.4;2.8;6.1;1.9|]; + `virginica,[|7.9;3.8;6.4;2.0|]; + `virginica,[|6.4;2.8;5.6;2.2|]; + `virginica,[|6.3;2.8;5.1;1.5|]; + `virginica,[|6.1;2.6;5.6;1.4|]; + `virginica,[|7.7;3.0;6.1;2.3|]; + `virginica,[|6.3;3.4;5.6;2.4|]; + `virginica,[|6.4;3.1;5.5;1.8|]; + `virginica,[|6.0;3.0;4.8;1.8|]; + `virginica,[|6.9;3.1;5.4;2.1|]; + `virginica,[|6.7;3.1;5.6;2.4|]; + `virginica,[|6.9;3.1;5.1;2.3|]; + `virginica,[|5.8;2.7;5.1;1.9|]; + `virginica,[|6.8;3.2;5.9;2.3|]; + `virginica,[|6.7;3.3;5.7;2.5|]; + `virginica,[|6.7;3.0;5.2;2.3|]; + `virginica,[|6.3;2.5;5.0;1.9|]; + `virginica,[|6.5;3.0;5.2;2.0|]; + `virginica,[|6.2;3.4;5.4;2.3|]; + `virginica,[|5.9;3.0;5.1;1.8|]; + ] + in + let iris_gb = gauss_estimate iris in + let result = + List.map ~f:(fun (v, d) -> gauss_eval iris_gb d |> most_likely, v) iris + in + let different = + List.fold_left result ~f:(fun s (x,y) -> if x = y then s else s + 1) + ~init:0 + in + Assert.is_true (different = 6)); + + () diff --git a/src/lib/oml.ml b/src/lib/oml.ml index 2cf44aa..0a35580 100644 --- a/src/lib/oml.ml +++ b/src/lib/oml.ml @@ -13,3 +13,4 @@ module Sampling = Sampling module Solvers = Solvers module Running = Running module Svd = Svd +module Classify = Classify