diff --git a/python/nano/test/pytorch/tests/test_plugin.py b/python/nano/test/pytorch/tests/test_plugin.py index d9dca0dc405..f17b4161137 100644 --- a/python/nano/test/pytorch/tests/test_plugin.py +++ b/python/nano/test/pytorch/tests/test_plugin.py @@ -26,6 +26,7 @@ from bigdl.nano.pytorch import Trainer from test.pytorch.utils._train_torch_lightning import create_data_loader, data_transform +from test.pytorch.utils._train_torch_lightning import create_test_data_loader from test.pytorch.tests.test_lightning import ResNet18 num_classes = 10 @@ -41,6 +42,8 @@ class TestPlugin(TestCase): optimizer = torch.optim.Adam(model.parameters(), lr=0.01) data_loader = create_data_loader(data_dir, batch_size, num_workers, data_transform, subset=dataset_size) + test_data_loader = create_test_data_loader(data_dir, batch_size, num_workers, + data_transform, subset=dataset_size) def setUp(self): test_dir = os.path.dirname(__file__) @@ -56,8 +59,8 @@ def test_trainer_subprocess_plugin(self): ) trainer = Trainer(num_processes=2, distributed_backend="subprocess", max_epochs=4) - trainer.fit(pl_model, self.data_loader, self.data_loader) - trainer.test(pl_model, self.data_loader) + trainer.fit(pl_model, self.data_loader, self.test_data_loader) + trainer.test(pl_model, self.test_data_loader) if __name__ == '__main__': diff --git a/python/nano/test/pytorch/tests/test_plugin_ipex.py b/python/nano/test/pytorch/tests/test_plugin_ipex.py new file mode 100644 index 00000000000..d29151dd8d7 --- /dev/null +++ b/python/nano/test/pytorch/tests/test_plugin_ipex.py @@ -0,0 +1,67 @@ +# +# Copyright 2016 The BigDL Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import os +import pytest +from unittest import TestCase + +import torch +import torchmetrics +from torch import nn + +from bigdl.nano.pytorch.lightning import LightningModuleFromTorch +from bigdl.nano.pytorch import Trainer + +from test.pytorch.utils._train_torch_lightning import create_data_loader, data_transform +from test.pytorch.utils._train_torch_lightning import create_test_data_loader +from test.pytorch.tests.test_lightning import ResNet18 + +num_classes = 10 +batch_size = 32 +dataset_size = 256 +num_workers = 0 +data_dir = os.path.join(os.path.dirname(__file__), "../data") + + +class TestPlugin(TestCase): + model = ResNet18(pretrained=False, include_top=False, freeze=True) + loss = nn.CrossEntropyLoss() + optimizer = torch.optim.Adam(model.parameters(), lr=0.01) + data_loader = create_data_loader(data_dir, batch_size, num_workers, + data_transform, subset=dataset_size) + test_data_loader = create_test_data_loader(data_dir, batch_size, num_workers, + data_transform, subset=dataset_size) + + def setUp(self): + test_dir = os.path.dirname(__file__) + project_test_dir = os.path.abspath( + os.path.join(os.path.join(os.path.join(test_dir, ".."), ".."), "..") + ) + os.environ['PYTHONPATH'] = project_test_dir + + def test_trainer_subprocess_plugin(self): + pl_model = LightningModuleFromTorch( + self.model, self.loss, self.optimizer, + metrics=[torchmetrics.F1(num_classes), torchmetrics.Accuracy(num_classes=10)] + ) + trainer = Trainer(num_processes=2, distributed_backend="subprocess", + max_epochs=4, use_ipex=True) + trainer.fit(pl_model, self.data_loader, self.test_data_loader) + trainer.test(pl_model, self.test_data_loader) + + +if __name__ == '__main__': + pytest.main([__file__])