Skip to content

Commit

Permalink
Merge pull request #26 from AndrewRook/non_nfldb_data
Browse files Browse the repository at this point in the history
training on non-nfldb data
  • Loading branch information
AndrewRook authored Aug 12, 2016
2 parents 271c013 + 5cc2fe9 commit 6cdee3d
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 10 deletions.
26 changes: 16 additions & 10 deletions nflwin/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,11 +161,14 @@ def train_model(self,
"""
self._training_seasons = []
self._training_season_types = []
if source_data == "nfldb":
source_data = utilities.get_nfldb_play_data(season_years=training_seasons,
season_types=training_season_types)
self._training_seasons = training_seasons
self._training_season_types = training_season_types
if isinstance(source_data, basestring):
if source_data == "nfldb":
source_data = utilities.get_nfldb_play_data(season_years=training_seasons,
season_types=training_season_types)
self._training_seasons = training_seasons
self._training_season_types = training_season_types
else:
raise ValueError("WPModel: if source_data is a string, it must be 'nfldb'")
target_col = source_data[target_colname]
feature_cols = source_data.drop(target_colname, axis=1)
self.model.fit(feature_cols, target_col)
Expand Down Expand Up @@ -249,11 +252,14 @@ def validate_model(self,

self._validation_seasons = []
self._validation_season_types = []
if source_data == "nfldb":
source_data = utilities.get_nfldb_play_data(season_years=validation_seasons,
season_types=validation_season_types)
self._validation_seasons = validation_seasons
self._validation_season_types = validation_season_types
# Only the literal string "nfldb" is accepted as a string source; any other
# string is rejected here so it fails loudly before the column lookup below.
if isinstance(source_data, basestring):
    if source_data == "nfldb":
        # Validation must query and record the validation_* arguments.
        # Using the training_* names here was a copy-paste slip from
        # train_model and would not load/record the validation seasons.
        source_data = utilities.get_nfldb_play_data(season_years=validation_seasons,
                                                    season_types=validation_season_types)
        self._validation_seasons = validation_seasons
        self._validation_season_types = validation_season_types
    else:
        raise ValueError("WPModel: if source_data is a string, it must be 'nfldb'")

target_col = source_data[target_colname]
feature_cols = source_data.drop(target_colname, axis=1)
Expand Down
91 changes: 91 additions & 0 deletions nflwin/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,104 @@

from nflwin import model


class TestDefaults(object):
    """Checks on a WPModel constructed with no arguments."""

    def test_column_descriptions_set(self):
        """The column_descriptions attribute must be a mapping type."""
        descriptions = model.WPModel().column_descriptions
        assert isinstance(descriptions, collections.Mapping)

class TestModelTrain(object):
    """Exercises WPModel.train_model with string and DataFrame sources."""

    def test_bad_string(self):
        """An unrecognised string passed as source_data must raise ValueError."""
        untrained = model.WPModel()
        with pytest.raises(ValueError):
            untrained.train_model(source_data="this is a bad string")

    def test_dataframe_input(self):
        """Training on an in-memory DataFrame should complete without raising."""
        wp = model.WPModel()
        # Ten plays, keyed 0..9 per column; dict(enumerate(...)) builds the
        # same {row_index: value} mapping as spelling each key out by hand.
        plays = {
            'offense_won': dict(enumerate(
                [True, False, False, False, False, True, True, True, True, False])),
            'home_team': dict(enumerate(['NYG'] * 10)),
            'away_team': dict(enumerate(['DAL'] * 10)),
            'gsis_id': dict(enumerate(['2012090500'] * 10)),
            'play_id': dict(enumerate(
                [35, 57, 79, 103, 125, 150, 171, 190, 212, 252])),
            'seconds_elapsed': dict(enumerate(
                [0.0, 4.0, 11.0, 55.0, 62.0, 76.0, 113.0, 153.0, 159.0, 171.0])),
            'down': dict(enumerate([0, 1, 2, 3, 4, 1, 2, 3, 4, 1])),
            'curr_home_score': dict(enumerate([0] * 10)),
            'offense_team': dict(enumerate(
                ['DAL', 'NYG', 'NYG', 'NYG', 'NYG', 'DAL', 'DAL', 'DAL', 'DAL', 'NYG'])),
            'curr_away_score': dict(enumerate([0] * 10)),
            'yardline': dict(enumerate(
                [-15.0, -34.0, -34.0, -29.0, -29.0, -26.0, -23.0, -31.0, -31.0, -37.0])),
            'drive_id': dict(enumerate([1, 1, 1, 1, 1, 2, 2, 2, 2, 3])),
            'yards_to_go': dict(enumerate([0, 10, 10, 5, 5, 10, 7, 15, 15, 10])),
            'quarter': dict(enumerate(['Q1'] * 10)),
        }
        frame = pd.DataFrame(plays)
        wp.train_model(source_data=frame)

class TestModelValidate(object):
    """Exercises WPModel.validate_model with string and DataFrame sources."""

    def setup_method(self, method):
        """Build the ten-play fixture (raw dict and DataFrame) before each test."""
        # dict(enumerate(...)) yields the same {row_index: value} mapping as
        # writing out keys 0..9 explicitly for every column.
        fixture = {
            'offense_won': dict(enumerate(
                [True, False, False, False, False, True, True, True, True, False])),
            'home_team': dict(enumerate(['NYG'] * 10)),
            'away_team': dict(enumerate(['DAL'] * 10)),
            'gsis_id': dict(enumerate(['2012090500'] * 10)),
            'play_id': dict(enumerate(
                [35, 57, 79, 103, 125, 150, 171, 190, 212, 252])),
            'seconds_elapsed': dict(enumerate(
                [0.0, 4.0, 11.0, 55.0, 62.0, 76.0, 113.0, 153.0, 159.0, 171.0])),
            'down': dict(enumerate([0, 1, 2, 3, 4, 1, 2, 3, 4, 1])),
            'curr_home_score': dict(enumerate([0] * 10)),
            'offense_team': dict(enumerate(
                ['DAL', 'NYG', 'NYG', 'NYG', 'NYG', 'DAL', 'DAL', 'DAL', 'DAL', 'NYG'])),
            'curr_away_score': dict(enumerate([0] * 10)),
            'yardline': dict(enumerate(
                [-15.0, -34.0, -34.0, -29.0, -29.0, -26.0, -23.0, -31.0, -31.0, -37.0])),
            'drive_id': dict(enumerate([1, 1, 1, 1, 1, 2, 2, 2, 2, 3])),
            'yards_to_go': dict(enumerate([0, 10, 10, 5, 5, 10, 7, 15, 15, 10])),
            'quarter': dict(enumerate(['Q1'] * 10)),
        }
        self.test_data = fixture
        self.test_df = pd.DataFrame(fixture)

    def test_bad_string(self):
        """After training, an unrecognised string source must raise ValueError."""
        trained = model.WPModel()
        trained.train_model(source_data=self.test_df)
        with pytest.raises(ValueError):
            trained.validate_model(source_data="this is bad data")

    def test_dataframe_input(self):
        """Validating a trained model on a DataFrame should not raise."""
        trained = model.WPModel()
        trained.train_model(source_data=self.test_df)
        trained.validate_model(source_data=self.test_df)

class TestTestDistribution(object):
"""Tests the _test_distribution static method of WPModel."""
Expand Down

0 comments on commit 6cdee3d

Please sign in to comment.