Skip to content

Commit

Permalink
Tests to train a keras model using MongoDBIODataset (tensorflow#1264)
Browse files Browse the repository at this point in the history
  • Loading branch information
kvignesh1420 authored Jan 7, 2021
1 parent 252a357 commit ae9e7fb
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 18 deletions.
16 changes: 9 additions & 7 deletions tensorflow_io/core/kernels/mongodb_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ class MongoDBReadableResource : public ResourceBase {

// Register the application name so we can track it in the profile logs
// on the server. This can also be done from the URI.

mongoc_client_set_appname(client_obj_, "tfio-mongo-read");

// Get a handle on the database "db_name" and collection "coll_name"
Expand Down Expand Up @@ -87,12 +86,15 @@ class MongoDBReadableResource : public ResourceBase {
const bson_t* doc;

int num_records = 0;
while (mongoc_cursor_next(cursor_obj_, &doc) &&
num_records < max_num_records) {
char* record = bson_as_canonical_extended_json(doc, NULL);
records.emplace_back(record);
num_records++;
bson_free(record);
while (num_records < max_num_records) {
if (mongoc_cursor_next(cursor_obj_, &doc)) {
char* record = bson_as_canonical_extended_json(doc, NULL);
records.emplace_back(record);
num_records++;
bson_free(record);
} else {
break;
}
}

TensorShape shape({static_cast<int32>(records.size())});
Expand Down
101 changes: 90 additions & 11 deletions tests/test_mongodb_eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,44 @@

"""Tests for the mongodb datasets"""

from datetime import datetime
import time
import json
import pytest
import socket
import requests
import pytest
import tensorflow as tf
from tensorflow import feature_column
from tensorflow.keras import layers
import tensorflow_io as tfio

# COMMON VARIABLES
# strftime/strptime pattern for ISO-8601-style UTC timestamps written to MongoDB.
TIMESTAMP_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
# Connection string for the local test container (credentials are test-only defaults).
URI = "mongodb://mongoadmin:default_password@localhost:27017"
DATABASE = "tfiodb"
COLLECTION = "test"
# Synthetic fixture data: two distinct passenger records repeated to yield
# 2000 documents total, so the read/training tests have a non-trivial volume.
RECORDS = [
    {
        "name": "person1",
        "gender": "Male",
        "age": 20,
        "fare": 80.52,
        "vip": False,
        "survived": 1,
    },
    {
        "name": "person2",
        "gender": "Female",
        "age": 20,
        "fare": 40.88,
        "vip": True,
        "survived": 0,
    },
] * 1000
# Per-field TensorSpecs used by decode_json to turn each raw JSON document
# into a dict of scalar tensors; keys must match the RECORDS field names.
SPECS = {
    "name": tf.TensorSpec(tf.TensorShape([]), tf.string),
    "gender": tf.TensorSpec(tf.TensorShape([]), tf.string),
    "age": tf.TensorSpec(tf.TensorShape([]), tf.int32),
    "fare": tf.TensorSpec(tf.TensorShape([]), tf.float64),
    "vip": tf.TensorSpec(tf.TensorShape([]), tf.bool),
    "survived": tf.TensorSpec(tf.TensorShape([]), tf.int64),
}
# Batch size used when feeding the dataset to the keras model.
BATCH_SIZE = 32


def is_container_running():
Expand All @@ -53,10 +75,9 @@ def test_writer_write():
writer = tfio.experimental.mongodb.MongoDBWriter(
uri=URI, database=DATABASE, collection=COLLECTION
)
timestamp = datetime.utcnow().strftime(TIMESTAMP_PATTERN)
for i in range(1000):
data = {"timestamp": timestamp, "key{}".format(i): "value{}".format(i)}
writer.write(data)

for record in RECORDS:
writer.write(record)


@pytest.mark.skipif(not is_container_running(), reason="The container is not running")
Expand All @@ -69,7 +90,65 @@ def test_dataset_read():
count = 0
for d in dataset:
count += 1
assert count == 1000
assert count == len(RECORDS)


@pytest.mark.skipif(not is_container_running(), reason="The container is not running")
def test_train_model():
    """Train a small tf.keras binary classifier on records streamed from MongoDB."""

    # Stream raw JSON documents from the collection and decode each one into a
    # dict of scalar tensors according to SPECS.
    dataset = tfio.experimental.mongodb.MongoDBIODataset(
        uri=URI, database=DATABASE, collection=COLLECTION
    )
    dataset = dataset.map(
        lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS)
    )
    # Split the label ("survived") out of the feature dict: (features, label).
    dataset = dataset.map(lambda v: (v, v.pop("survived")))
    dataset = dataset.batch(BATCH_SIZE)

    assert issubclass(type(dataset), tf.data.Dataset)

    # Describe how each raw field is presented to the network:
    # raw numeric fare, bucketized age, and one-hot encoded gender.
    fare_numeric = feature_column.numeric_column("fare")
    age_bucketized = feature_column.bucketized_column(
        feature_column.numeric_column("age"), boundaries=[10, 30]
    )
    gender_one_hot = feature_column.indicator_column(
        feature_column.categorical_column_with_vocabulary_list(
            "gender", ["Male", "Female"]
        )
    )
    feature_columns = [fare_numeric, age_bucketized, gender_one_hot]

    # Assemble the model: a DenseFeatures input layer over the feature columns
    # followed by a small MLP with a single-logit output head.
    model = tf.keras.Sequential(
        [
            tf.keras.layers.DenseFeatures(feature_columns),
            layers.Dense(128, activation="relu"),
            layers.Dense(128, activation="relu"),
            layers.Dropout(0.1),
            layers.Dense(1),
        ]
    )

    # from_logits=True because the final Dense layer has no activation.
    model.compile(
        optimizer="adam",
        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
        metrics=["accuracy"],
    )

    # Training completing without error is the assertion here; the test
    # exercises the full MongoDB -> tf.data -> keras pipeline.
    model.fit(dataset, epochs=5)


@pytest.mark.skipif(not is_container_running(), reason="The container is not running")
Expand Down

0 comments on commit ae9e7fb

Please sign in to comment.