-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #55 from lincc-frameworks/generate_data
Add a toy dataset generator function
- Loading branch information
Showing
3 changed files
with
80 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .generation import * # noqa |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import numpy as np | ||
|
||
from nested_pandas import NestedFrame | ||
|
||
|
||
def generate_data(n_base, n_layer, seed=None) -> NestedFrame: | ||
"""Generates a toy dataset. | ||
Parameters | ||
---------- | ||
n_base : int | ||
The number of rows to generate for the base layer | ||
n_layer : int, or dict | ||
The number of rows per n_base row to generate for a nested layer. | ||
Alternatively, a dictionary of layer label, layer_size pairs may be | ||
specified to created multiple nested columns with custom sizing. | ||
seed : int | ||
A seed to use for random generation of data | ||
Returns | ||
------- | ||
NestedFrame | ||
The constructed NestedFrame. | ||
Examples | ||
-------- | ||
>>> nested_pandas.datasets.generate_data(10,100) | ||
>>> nested_pandas.datasets.generate_data(10, {"nested_a": 100, "nested_b": 200}) | ||
""" | ||
# use provided seed, "None" acts as if no seed is provided | ||
randomstate = np.random.RandomState(seed=seed) | ||
|
||
# Generate base data | ||
base_data = {"a": randomstate.random(n_base), "b": randomstate.random(n_base) * 2} | ||
base_nf = NestedFrame(data=base_data) | ||
|
||
# In case of int, create a single nested layer called "nested" | ||
if isinstance(n_layer, int): | ||
n_layer = {"nested": n_layer} | ||
|
||
# It should be a dictionary | ||
if isinstance(n_layer, dict): | ||
for key in n_layer: | ||
layer_size = n_layer[key] | ||
layer_data = { | ||
"t": randomstate.random(layer_size * n_base) * 20, | ||
"flux": randomstate.random(layer_size * n_base) * 100, | ||
"band": randomstate.choice(["r", "g"], size=layer_size * n_base), | ||
"index": np.arange(layer_size * n_base) % n_base, | ||
} | ||
layer_nf = NestedFrame(data=layer_data).set_index("index") | ||
base_nf = base_nf.add_nested(layer_nf, key) | ||
return base_nf | ||
else: | ||
raise TypeError("Input to n_layer is not an int or dict.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import pytest | ||
from nested_pandas.datasets import generate_data | ||
|
||
|
||
@pytest.mark.parametrize("n_layers", [10, {"nested_a": 10, "nested_b": 20}]) | ||
def test_generate_data(n_layers): | ||
"""test the data generator function""" | ||
nf = generate_data(10, n_layers, seed=1) | ||
|
||
if isinstance(n_layers, int): | ||
assert len(nf.nested.nest.to_flat()) == 100 | ||
|
||
elif isinstance(n_layers, dict): | ||
assert "nested_a" in nf.columns | ||
assert "nested_b" in nf.columns | ||
|
||
assert len(nf.nested_a.nest.to_flat()) == 100 | ||
assert len(nf.nested_b.nest.to_flat()) == 200 | ||
|
||
|
||
def test_generate_data_bad_input(): | ||
"""test a poor n_layer input to generate_data""" | ||
with pytest.raises(TypeError): | ||
generate_data(10, "nested", seed=1) |