From ae9e7fbe36badc5e61ed8c137c282934ebd21699 Mon Sep 17 00:00:00 2001
From: Vignesh Kothapalli
Date: Fri, 8 Jan 2021 02:08:53 +0530
Subject: [PATCH] Tests to train a keras model using MongoDBIODataset (#1264)

---
 tensorflow_io/core/kernels/mongodb_kernels.cc |  16 +--
 tests/test_mongodb_eager.py                   | 101 ++++++++++++++++--
 2 files changed, 99 insertions(+), 18 deletions(-)

diff --git a/tensorflow_io/core/kernels/mongodb_kernels.cc b/tensorflow_io/core/kernels/mongodb_kernels.cc
index cc342ab27..cc4356c21 100644
--- a/tensorflow_io/core/kernels/mongodb_kernels.cc
+++ b/tensorflow_io/core/kernels/mongodb_kernels.cc
@@ -59,7 +59,6 @@ class MongoDBReadableResource : public ResourceBase {
 
     // Register the application name so we can track it in the profile logs
     // on the server. This can also be done from the URI.
-
    mongoc_client_set_appname(client_obj_, "tfio-mongo-read");
 
     // Get a handle on the database "db_name" and collection "coll_name"
@@ -87,12 +86,15 @@
     const bson_t* doc;
     int num_records = 0;
 
-    while (mongoc_cursor_next(cursor_obj_, &doc) &&
-           num_records < max_num_records) {
-      char* record = bson_as_canonical_extended_json(doc, NULL);
-      records.emplace_back(record);
-      num_records++;
-      bson_free(record);
+    while (num_records < max_num_records) {
+      if (mongoc_cursor_next(cursor_obj_, &doc)) {
+        char* record = bson_as_canonical_extended_json(doc, NULL);
+        records.emplace_back(record);
+        num_records++;
+        bson_free(record);
+      } else {
+        break;
+      }
     }
 
     TensorShape shape({static_cast<int64>(records.size())});
diff --git a/tests/test_mongodb_eager.py b/tests/test_mongodb_eager.py
index b8d0688ee..ce1c7c8db 100644
--- a/tests/test_mongodb_eager.py
+++ b/tests/test_mongodb_eager.py
@@ -15,22 +15,44 @@
 """Tests for the mongodb datasets"""
 
-from datetime import datetime
-import time
-import json
-import pytest
 import socket
-import requests
+import pytest
 import tensorflow as tf
 from tensorflow import feature_column
 from tensorflow.keras import layers
 import tensorflow_io as tfio
 
 # COMMON VARIABLES
-TIMESTAMP_PATTERN = "%Y-%m-%dT%H:%M:%S.%fZ"
 URI = "mongodb://mongoadmin:default_password@localhost:27017"
 DATABASE = "tfiodb"
 COLLECTION = "test"
+RECORDS = [
+    {
+        "name": "person1",
+        "gender": "Male",
+        "age": 20,
+        "fare": 80.52,
+        "vip": False,
+        "survived": 1,
+    },
+    {
+        "name": "person2",
+        "gender": "Female",
+        "age": 20,
+        "fare": 40.88,
+        "vip": True,
+        "survived": 0,
+    },
+] * 1000
+SPECS = {
+    "name": tf.TensorSpec(tf.TensorShape([]), tf.string),
+    "gender": tf.TensorSpec(tf.TensorShape([]), tf.string),
+    "age": tf.TensorSpec(tf.TensorShape([]), tf.int32),
+    "fare": tf.TensorSpec(tf.TensorShape([]), tf.float64),
+    "vip": tf.TensorSpec(tf.TensorShape([]), tf.bool),
+    "survived": tf.TensorSpec(tf.TensorShape([]), tf.int64),
+}
+BATCH_SIZE = 32
 
 
 def is_container_running():
@@ -53,10 +75,9 @@ def test_writer_write():
     """Test the writer resource"""
     writer = tfio.experimental.mongodb.MongoDBWriter(
         uri=URI, database=DATABASE, collection=COLLECTION
     )
-    timestamp = datetime.utcnow().strftime(TIMESTAMP_PATTERN)
-    for i in range(1000):
-        data = {"timestamp": timestamp, "key{}".format(i): "value{}".format(i)}
-        writer.write(data)
+
+    for record in RECORDS:
+        writer.write(record)
 
 
 @pytest.mark.skipif(not is_container_running(), reason="The container is not running")
@@ -69,7 +90,65 @@ def test_dataset_read():
     """Test the dataset resource"""
     dataset = tfio.experimental.mongodb.MongoDBIODataset(
         uri=URI, database=DATABASE, collection=COLLECTION
     )
     count = 0
     for d in dataset:
         count += 1
-    assert count == 1000
+    assert count == len(RECORDS)
+
+
+@pytest.mark.skipif(not is_container_running(), reason="The container is not running")
+def test_train_model():
+    """Test the dataset by training a tf.keras model"""
+
+    dataset = tfio.experimental.mongodb.MongoDBIODataset(
+        uri=URI, database=DATABASE, collection=COLLECTION
+    )
+    dataset = dataset.map(
+        lambda x: tfio.experimental.serialization.decode_json(x, specs=SPECS)
+    )
+    dataset = dataset.map(lambda v: (v, v.pop("survived")))
+    dataset = dataset.batch(BATCH_SIZE)
+
+    assert issubclass(type(dataset), tf.data.Dataset)
+
+    feature_columns = []
+
+    # Numeric column
+    fare_column = feature_column.numeric_column("fare")
+    feature_columns.append(fare_column)
+
+    # Bucketized column
+    age = feature_column.numeric_column("age")
+    age_buckets = feature_column.bucketized_column(age, boundaries=[10, 30])
+    feature_columns.append(age_buckets)
+
+    # Categorical column
+    gender = feature_column.categorical_column_with_vocabulary_list(
+        "gender", ["Male", "Female"]
+    )
+    gender_indicator = feature_column.indicator_column(gender)
+    feature_columns.append(gender_indicator)
+
+    # Convert the feature columns into a tf.keras layer
+    feature_layer = tf.keras.layers.DenseFeatures(feature_columns)
+
+    # Build the model
+    model = tf.keras.Sequential(
+        [
+            feature_layer,
+            layers.Dense(128, activation="relu"),
+            layers.Dense(128, activation="relu"),
+            layers.Dropout(0.1),
+            layers.Dense(1),
+        ]
+    )
+
+    # Compile the model
+    model.compile(
+        optimizer="adam",
+        loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
+        metrics=["accuracy"],
+    )
+
+    # train the model
+    model.fit(dataset, epochs=5)
 
 
 @pytest.mark.skipif(not is_container_running(), reason="The container is not running")