Skip to content

Commit

Permalink
Fixing CNN shape issue (#84)
Browse files Browse the repository at this point in the history
* Initial commit that uncovered the issue.

* Small update to the notebook.
  • Loading branch information
drewoldag authored Oct 8, 2024
1 parent 3068b8f commit 6b8d978
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 3 deletions.
66 changes: 66 additions & 0 deletions docs/notebooks/TrainingAModel.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import fibad"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create an instance of a fibad object, instantiated (implicitly) with the default configuration file\n",
"fibad_instance = fibad.Fibad()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Update a few of the configuration parameters\n",
"fibad_instance.config[\"model\"][\"name\"] = \"ExampleCNN\"\n",
"fibad_instance.config[\"data_set\"][\"name\"] = \"CifarDataSet\"\n",
"fibad_instance.config[\"data_loader\"][\"batch_size\"] = 64"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Begin training the Example CNN model using the CIFAR-10 dataset\n",
"fibad_instance.train()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "fibad",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
4 changes: 2 additions & 2 deletions src/fibad/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def _validate_runtime_config(runtime_config: ConfigDict, default_config: ConfigD
"""
for key in runtime_config:
if key not in default_config:
msg = f"Runtime config contains key or section {key} which has no default defined."
msg = f"Runtime config contains key or section {key} which has no default defined. "
msg += f"All configuration keys and sections must be defined in {DEFAULT_CONFIG_FILEPATH}"
raise RuntimeError(msg)

Expand Down Expand Up @@ -138,7 +138,7 @@ def resolve_runtime_config(runtime_config_filepath: Union[Path, str, None] = Non
"""Resolve a user-supplied runtime config to where we will actually pull config from.
1) If a runtime config file is specified, we will use that file
2) If not file is specified and there is a file named "fibad_config.toml" in the cwd we will use that file
2) If no file is specified and there is a file named "fibad_config.toml" in the cwd we will use that file
3) If no file is specified and there is no file named "fibad_config.toml" in the current working directory
we will exclusively work off the configuration defaults in the packaged "fibad_default_config.toml"
file.
Expand Down
2 changes: 1 addition & 1 deletion src/fibad/models/example_cnn_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

@fibad_model
class ExampleCNN(nn.Module):
def __init__(self, config, _):
def __init__(self, config, shape):
super().__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
Expand Down

0 comments on commit 6b8d978

Please sign in to comment.