Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

OBSRAI 2.0 #453

Open
wants to merge 55 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
55 commits
Select commit Hold shift + click to select a range
7b800c9
HF loader addition
Repcak2000 Mar 15, 2024
afcff39
feat: requirements update
Repcak2000 Apr 17, 2024
d77d4cb
fix(pre-commit.ci): auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 17, 2024
9fa95e3
fix: ruff reformat check
Repcak2000 Apr 17, 2024
48dcd43
fix: SRAI consultations
Repcak2000 Apr 19, 2024
c143abd
chore: minor changes ;)
Calychas Apr 19, 2024
60ffe33
chore: pre-commit fixes
Calychas Apr 19, 2024
f4f8ed4
update tooling
Calychas Apr 22, 2024
9446a93
add changelog
Calychas Apr 22, 2024
6e461c1
chore: hf -> huggingface
Calychas Apr 22, 2024
655e966
chore: remove config.yamls
Calychas Apr 23, 2024
1251266
feat: Initialization of base models
Repcak2000 Apr 24, 2024
aea6428
feat: regression base model init
Repcak2000 Apr 24, 2024
86b4972
feat: Vectorizer hex2vec added
Repcak2000 Apr 25, 2024
9c5348c
initial version of evaluator
mskaa3 Apr 25, 2024
d01c4c4
feat: extended evaluator with custom metrics function & loss computing
mskaa3 Apr 25, 2024
d44c551
fix: usage of evaluator
mskaa3 Apr 25, 2024
b040825
feat: dataloader allowed in evaluate function
mskaa3 Apr 25, 2024
d2ba3b6
feat: evaluate returning metrics dictionary
mskaa3 Apr 25, 2024
6b58526
feat: inital version of trainer
mskaa3 Apr 25, 2024
0d6b2b1
added untracked changes
mskaa3 Apr 25, 2024
853b9c4
feat: initial prediction method in trainer
mskaa3 Apr 25, 2024
4d1c158
feat: vectorizer standarization
Repcak2000 Apr 26, 2024
d18f1e5
fix: metrics import
mskaa3 Apr 26, 2024
af04c97
Merge branch 'base_models' of https://github.com/Repcak2000/srai into…
mskaa3 Apr 26, 2024
dce1d50
fix: circular imports in trainer fixed
mskaa3 Apr 26, 2024
f6ebc57
fix: fixed mistake on model arguments
mskaa3 Apr 26, 2024
73bb0fa
fix: reshaped target labels
mskaa3 Apr 26, 2024
31e2b0e
fix: TODO list add
Repcak2000 Apr 26, 2024
e2cb396
chore: Merge branch 'base_models' of https://github.com/Repcak2000/sr…
Repcak2000 Apr 26, 2024
3513962
feat: shifted prediction functionalities to predicor module
mskaa3 Apr 28, 2024
68ffc3f
fix: zeroed early stopping counter
mskaa3 Apr 28, 2024
6ff2938
fix on vectorizer reference"
mskaa3 Apr 28, 2024
b9776ea
fix on vectorizer reference"
mskaa3 Apr 28, 2024
31c6b0d
feat: using predictor in example, fix on h3 indexes encoding
mskaa3 Apr 28, 2024
1d98615
fix: Resolve conflicts
Repcak2000 May 6, 2024
bebd4cf
feat: categorical / numerical columns moved to dataset object - Airbn…
Repcak2000 May 7, 2024
525871d
feat: add core columns information to datasets
Repcak2000 May 10, 2024
6b70af9
feat: GeoVexEmbedder added to Vectorizer
Repcak2000 May 10, 2024
86eb16f
feat: train dev test split add
Repcak2000 May 14, 2024
658872b
fix: remove data leakage from Vectorizer
Repcak2000 May 16, 2024
45bb7c6
added initial colum names
mskaa3 May 16, 2024
4069d53
feat: Train dev test split, based on spatial h3
Repcak2000 May 17, 2024
92676b6
feat: regression generalization, type datasets addition
Repcak2000 May 21, 2024
9e5718d
feat: changed PortoTaxi preprocessing
Jakub-Polczyk-PWr May 26, 2024
0a0c575
fix: change train_test_split methods
Repcak2000 May 30, 2024
619b538
fix: typo in dataset name
Repcak2000 May 31, 2024
dbc2209
adjusted bucket split to hex values
mskaa3 Jun 2, 2024
e88b8c7
fix: avoid warnings
mskaa3 Jun 2, 2024
9653cc5
added constant to metric calculation to avoid zeroes
mskaa3 Jun 2, 2024
b8fa5b4
added normalization to metrics
mskaa3 Jun 2, 2024
b24764d
chore: introducing information about dataset
Repcak2000 Jun 3, 2024
47749e6
feat: Benchmark version added to Airbnb dataset.
Repcak2000 Jun 3, 2024
06824ef
fix: change approach to resolution based in airbnb dataset.
Repcak2000 Jun 3, 2024
12cc305
fix: change approach in train test split to resolution based.
Repcak2000 Jun 3, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ repos:
- id: conventional-pre-commit
stages: [commit-msg]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.3.7'
rev: 'v0.4.1'
hooks:
- id: ruff
types_or: [ python, pyi, jupyter ]
Expand All @@ -28,7 +28,7 @@ repos:
args: ["--config-file", "pyproject.toml"]
additional_dependencies: ['types-requests', 'types-six']
- repo: https://github.com/pdm-project/pdm
rev: 2.14.0
rev: 2.15.0
hooks:
- id: pdm-lock-check
- id: pdm-export
Expand Down
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Changed

### Added

- Initial implementation of datasets [#430](https://github.com/kraina-ai/srai/pull/430) for feature enrichment and benchmarking.

## [0.7.3] - 2024-04-21

### Changed
Expand Down
306 changes: 306 additions & 0 deletions examples/base_models/regression_model.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,306 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"import geopandas as gpd\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.optim as optim\n",
"from shapely.geometry import Polygon\n",
"\n",
"from srai.datasets import AirbnbMulticityDataset\n",
"from srai.h3 import h3_to_geoseries\n",
"from srai.models import Evaluator, Predictor, RegressionBaseModel, Trainer, Vectorizer\n",
"from srai.plotting import plot_numeric_data\n",
"from srai.regionalizers import H3Regionalizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"airbnb = AirbnbMulticityDataset()\n",
"gdf_airbnb = airbnb.load(os.getenv(\"HF_TOKEN\"))\n",
"gdf_airbnb = gdf_airbnb.loc[gdf_airbnb[\"city\"].isin([\"paris\"])]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"resolution = 8"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_gdf, test_gdf = airbnb.train_test_split_bucket_regression(gdf_airbnb)\n",
"train_gdf, dev_gdf = airbnb.train_test_split_bucket_regression(train_gdf) # get dev set"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"train_gdf, test_gdf = airbnb.train_test_split_spatial_points(gdf_airbnb)\n",
"train_gdf, dev_gdf = airbnb.train_test_split_spatial_points(train_gdf) # get dev set"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"vectorizer = Vectorizer(\n",
" gdf_train=train_gdf,\n",
" HF_dataset_object=airbnb,\n",
" embedder_type=\"Hex2VecEmbedder\",\n",
" h3_resolution=resolution,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_airbnb_train = vectorizer.get_dataset(train_gdf)\n",
"embedding_size = dataset_airbnb_train[\"X\"].shape[1]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"(train_gdf.shape[0] + test_gdf.shape[0] + dev_gdf.shape[0]) == gdf_airbnb.shape[0]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_airbnb_test = vectorizer.get_dataset(test_gdf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_airbnb_dev = vectorizer.get_dataset(dev_gdf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_airbnb_dev"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_airbnb_test"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset_airbnb_train"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"type(dataset_airbnb_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"regression_model = RegressionBaseModel(embedding_size)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"loss_fn = nn.L1Loss()\n",
"optimizer = optim.Adam(regression_model.parameters(), lr=0.001)\n",
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"args = {\n",
" \"batch_size\": 32,\n",
" \"task\": \"regression\",\n",
" \"epochs\": 50,\n",
" \"device\": device,\n",
" \"metric2look4\": \"MAE\",\n",
"}\n",
"trainer = Trainer(\n",
" model=regression_model,\n",
" train_dataset=dataset_airbnb_train,\n",
" eval_dataset=dataset_airbnb_dev,\n",
" optimizer=optimizer,\n",
" loss_fn=loss_fn,\n",
" **args,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"model, _, _ = trainer.train()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"evaluator = Evaluator(task=\"regression\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"evaluator.evaluate(model, dataset_airbnb_test, return_metrics=False)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"predictor = Predictor(\"regression\", device=device)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"_, hexes, values = predictor.predict(model, dataset_airbnb_test, resolution=resolution)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"original_label = [dataset_airbnb_test[i][\"y\"] for i in range(len(dataset_airbnb_test))]\n",
"original_hexes = [dataset_airbnb_test[i][\"X_h3_idx\"] for i in range(len(dataset_airbnb_test))]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"polygons = h3_to_geoseries(\n",
" hexes,\n",
")\n",
"preds_gdf = gpd.GeoDataFrame(geometry=polygons)\n",
"preds_gdf.crs = {\"init\": \"epsg:4326\"}\n",
"preds_gdf[\"price\"] = [tensor.item() for tensor in values]\n",
"preds_gdf[\"region_id\"] = hexes\n",
"preds_gdf.index = preds_gdf[\"region_id\"]\n",
"\n",
"original_polygons = h3_to_geoseries(original_hexes)\n",
"original_gdf = gpd.GeoDataFrame(geometry=[Polygon(polygon) for polygon in original_polygons])\n",
"original_gdf.crs = {\"init\": \"epsg:4326\"}\n",
"original_gdf[\"price\"] = [tensor.item() for tensor in original_label]\n",
"original_gdf[\"region_id\"] = original_hexes\n",
"original_gdf.index = original_gdf[\"region_id\"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"regionalizer = H3Regionalizer(resolution=resolution)\n",
"regions = regionalizer.transform(original_gdf)\n",
"plot_numeric_data(regions, \"price\", original_gdf)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_numeric_data(regions, \"price\", preds_gdf)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Loading