Skip to content

Commit

Permalink
Add extra options to CreateFrame to make (potentially sparse) binary …
Browse files Browse the repository at this point in the history
…columns.
  • Loading branch information
arnocandel committed Dec 30, 2014
1 parent 859da97 commit 7e0d650
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
10 changes: 9 additions & 1 deletion src/main/java/hex/CreateFrame.java
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,12 @@ public class CreateFrame extends Request2 {
@API(help = "Fraction of categorical columns (for randomize=true)", filter = Default.class, dmin = 0, dmax = 1, json=true)
public double categorical_fraction = 0.2;

@API(help = "Fraction of binary columns (for randomize=true)", filter = Default.class, dmin = 0, dmax = 1, json=true)
public double binary_fraction = 0.1;

@API(help = "Approximate fraction of 1's in binary columns", filter = Default.class, dmin = 0, dmax = 1, json=true)
public double binary_ones_fraction = 0.02;

@API(help = "Factor levels for categorical variables", filter = Default.class, lmin = 2, json=true)
public int factors = 100;

Expand All @@ -60,9 +66,11 @@ public class CreateFrame extends Request2 {

@Override public Response serve() {
try {
if (integer_fraction + categorical_fraction > 1) throw new IllegalArgumentException("Integer and categorical fractions must add up to <= 1.");
if (integer_fraction + binary_fraction + categorical_fraction > 1) throw new IllegalArgumentException("Integer, binary and categorical fractions must add up to <= 1.");
if (Math.abs(missing_fraction) > 1) throw new IllegalArgumentException("Missing fraction must be between 0 and 1.");
if (Math.abs(integer_fraction) > 1) throw new IllegalArgumentException("Integer fraction must be between 0 and 1.");
if (Math.abs(binary_fraction) > 1) throw new IllegalArgumentException("Binary fraction must be between 0 and 1.");
if (Math.abs(binary_ones_fraction) > 1) throw new IllegalArgumentException("Binary ones fraction must be between 0 and 1.");
if (Math.abs(categorical_fraction) > 1) throw new IllegalArgumentException("Categorical fraction must be between 0 and 1.");
if (categorical_fraction > 0 && factors <= 1) throw new IllegalArgumentException("Factors must be larger than 2 for categorical data.");
if (response_factors < 1) throw new IllegalArgumentException("Response factors must be either 1 (real-valued response), or >=2 (factor levels).");
Expand Down
24 changes: 18 additions & 6 deletions src/main/java/water/fvec/FrameCreator.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,18 @@ public FrameCreator(CreateFrame createFrame, Key job) {

int catcols = (int)(_createFrame.categorical_fraction * _createFrame.cols);
int intcols = (int)(_createFrame.integer_fraction * _createFrame.cols);
int realcols = _createFrame.cols - catcols - intcols;
int bincols = (int)(_createFrame.binary_fraction * _createFrame.cols);
int realcols = _createFrame.cols - catcols - intcols - bincols;

assert(catcols >= 0);
assert(intcols >= 0);
assert(bincols >= 0);
assert(realcols >= 0);

_cat_cols = Arrays.copyOfRange(shuffled_idx, 0, catcols);
_int_cols = Arrays.copyOfRange(shuffled_idx, catcols, catcols+intcols);
_real_cols = Arrays.copyOfRange(shuffled_idx, catcols+intcols, catcols+intcols+realcols);
_cat_cols = Arrays.copyOfRange(shuffled_idx, 0, catcols);
_int_cols = Arrays.copyOfRange(shuffled_idx, catcols, catcols+intcols);
_real_cols = Arrays.copyOfRange(shuffled_idx, catcols+intcols, catcols+intcols+realcols);
_bin_cols = Arrays.copyOfRange(shuffled_idx, catcols+intcols+realcols, catcols+intcols+realcols+bincols);

// create domains for categorical variables
if (_createFrame.randomize) {
Expand All @@ -63,6 +66,7 @@ public FrameCreator(CreateFrame createFrame, Key job) {
private int[] _cat_cols;
private int[] _int_cols;
private int[] _real_cols;
private int[] _bin_cols;
private String[][] _domain;
private Frame _out;
final private Key _job;
Expand All @@ -79,7 +83,7 @@ public FrameCreator(CreateFrame createFrame, Key job) {
_out.delete_and_lock(_job);

// fill with random values
new FrameRandomizer(_createFrame, _cat_cols, _int_cols, _real_cols).doAll(_out);
new FrameRandomizer(_createFrame, _cat_cols, _int_cols, _real_cols, _bin_cols).doAll(_out);

//overwrite a fraction with N/A
new MissingInserter(this, _createFrame.seed, _createFrame.missing_fraction).asyncExec(_out);
Expand All @@ -95,12 +99,14 @@ private static class FrameRandomizer extends MRTask2<FrameRandomizer> {
final private int[] _cat_cols;
final private int[] _int_cols;
final private int[] _real_cols;
final private int[] _bin_cols;

public FrameRandomizer(CreateFrame createFrame, int[] cat_cols, int[] int_cols, int[] real_cols){
public FrameRandomizer(CreateFrame createFrame, int[] cat_cols, int[] int_cols, int[] real_cols, int[] bin_cols){
_createFrame = createFrame;
_cat_cols = cat_cols;
_int_cols = int_cols;
_real_cols = real_cols;
_bin_cols = bin_cols;
}

//row+col-dependent RNG for reproducibility with different number of VMs, chunks, etc.
Expand Down Expand Up @@ -143,6 +149,12 @@ else if (_createFrame.positive_response)
cs[c].set0(r, _createFrame.real_range * (1 - 2 * rng.nextDouble()));
}
}
for (int c : _bin_cols) {
for (int r = 0; r < cs[c]._len; r++) {
setSeed(rng, c, cs[c]._start + r);
cs[c].set0(r, rng.nextFloat() > _createFrame.binary_ones_fraction ? 0 : 1);
}
}
}
}

Expand Down

0 comments on commit 7e0d650

Please sign in to comment.