From 7e0d6509fde1733fcfe17f92f5c65abc8597fa8a Mon Sep 17 00:00:00 2001 From: Arno Candel Date: Tue, 30 Dec 2014 15:41:08 -0800 Subject: [PATCH] Add extra options to CreateFrame to make (potentially sparse) binary columns. --- src/main/java/hex/CreateFrame.java | 10 ++++++++- src/main/java/water/fvec/FrameCreator.java | 24 ++++++++++++++++------ 2 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/main/java/hex/CreateFrame.java b/src/main/java/hex/CreateFrame.java index 8124f007ab..324250d0bb 100644 --- a/src/main/java/hex/CreateFrame.java +++ b/src/main/java/hex/CreateFrame.java @@ -41,6 +41,12 @@ public class CreateFrame extends Request2 { @API(help = "Fraction of categorical columns (for randomize=true)", filter = Default.class, dmin = 0, dmax = 1, json=true) public double categorical_fraction = 0.2; + @API(help = "Fraction of binary columns (for randomize=true)", filter = Default.class, dmin = 0, dmax = 1, json=true) + public double binary_fraction = 0.1; + + @API(help = "Approximate fraction of 1's in binary columns", filter = Default.class, dmin = 0, dmax = 1, json=true) + public double binary_ones_fraction = 0.02; + @API(help = "Factor levels for categorical variables", filter = Default.class, lmin = 2, json=true) public int factors = 100; @@ -60,9 +66,11 @@ public class CreateFrame extends Request2 { @Override public Response serve() { try { - if (integer_fraction + categorical_fraction > 1) throw new IllegalArgumentException("Integer and categorical fractions must add up to <= 1."); + if (integer_fraction + binary_fraction + categorical_fraction > 1) throw new IllegalArgumentException("Integer, binary and categorical fractions must add up to <= 1."); if (Math.abs(missing_fraction) > 1) throw new IllegalArgumentException("Missing fraction must be between 0 and 1."); if (Math.abs(integer_fraction) > 1) throw new IllegalArgumentException("Integer fraction must be between 0 and 1."); + if (Math.abs(binary_fraction) > 1) throw new IllegalArgumentException("Binary fraction must be between 0 and 1."); + if (Math.abs(binary_ones_fraction) > 1) throw new IllegalArgumentException("Binary ones fraction must be between 0 and 1."); if (Math.abs(categorical_fraction) > 1) throw new IllegalArgumentException("Categorical fraction must be between 0 and 1."); if (categorical_fraction > 0 && factors <= 1) throw new IllegalArgumentException("Factors must be larger than 2 for categorical data."); if (response_factors < 1) throw new IllegalArgumentException("Response factors must be either 1 (real-valued response), or >=2 (factor levels)."); diff --git a/src/main/java/water/fvec/FrameCreator.java b/src/main/java/water/fvec/FrameCreator.java index e31b08c235..1588f371d0 100644 --- a/src/main/java/water/fvec/FrameCreator.java +++ b/src/main/java/water/fvec/FrameCreator.java @@ -29,15 +29,18 @@ public FrameCreator(CreateFrame createFrame, Key job) { int catcols = (int)(_createFrame.categorical_fraction * _createFrame.cols); int intcols = (int)(_createFrame.integer_fraction * _createFrame.cols); - int realcols = _createFrame.cols - catcols - intcols; + int bincols = (int)(_createFrame.binary_fraction * _createFrame.cols); + int realcols = _createFrame.cols - catcols - intcols - bincols; assert(catcols >= 0); assert(intcols >= 0); + assert(bincols >= 0); assert(realcols >= 0); - _cat_cols = Arrays.copyOfRange(shuffled_idx, 0, catcols); - _int_cols = Arrays.copyOfRange(shuffled_idx, catcols, catcols+intcols); - _real_cols = Arrays.copyOfRange(shuffled_idx, catcols+intcols, catcols+intcols+realcols); + _cat_cols = Arrays.copyOfRange(shuffled_idx, 0, catcols); + _int_cols = Arrays.copyOfRange(shuffled_idx, catcols, catcols+intcols); + _real_cols = Arrays.copyOfRange(shuffled_idx, catcols+intcols, catcols+intcols+realcols); + _bin_cols = Arrays.copyOfRange(shuffled_idx, catcols+intcols+realcols, catcols+intcols+realcols+bincols); // create domains for categorical variables if (_createFrame.randomize) { @@ -63,6 +66,7 @@ public FrameCreator(CreateFrame createFrame, Key job) { private int[] _cat_cols; private int[] _int_cols; private int[] _real_cols; + private int[] _bin_cols; private String[][] _domain; private Frame _out; final private Key _job; @@ -79,7 +83,7 @@ public FrameCreator(CreateFrame createFrame, Key job) { _out.delete_and_lock(_job); // fill with random values - new FrameRandomizer(_createFrame, _cat_cols, _int_cols, _real_cols).doAll(_out); + new FrameRandomizer(_createFrame, _cat_cols, _int_cols, _real_cols, _bin_cols).doAll(_out); //overwrite a fraction with N/A new MissingInserter(this, _createFrame.seed, _createFrame.missing_fraction).asyncExec(_out); @@ -95,12 +99,14 @@ private static class FrameRandomizer extends MRTask2 { final private int[] _cat_cols; final private int[] _int_cols; final private int[] _real_cols; + final private int[] _bin_cols; - public FrameRandomizer(CreateFrame createFrame, int[] cat_cols, int[] int_cols, int[] real_cols){ + public FrameRandomizer(CreateFrame createFrame, int[] cat_cols, int[] int_cols, int[] real_cols, int[] bin_cols){ _createFrame = createFrame; _cat_cols = cat_cols; _int_cols = int_cols; _real_cols = real_cols; + _bin_cols = bin_cols; } //row+col-dependent RNG for reproducibility with different number of VMs, chunks, etc. @@ -143,6 +149,12 @@ else if (_createFrame.positive_response) cs[c].set0(r, _createFrame.real_range * (1 - 2 * rng.nextDouble())); } } + for (int c : _bin_cols) { + for (int r = 0; r < cs[c]._len; r++) { + setSeed(rng, c, cs[c]._start + r); + cs[c].set0(r, rng.nextFloat() > _createFrame.binary_ones_fraction ? 0 : 1); + } + } } }