Skip to content

Commit

Permalink
add ipex multiprocessing test (#4724)
Browse files Browse the repository at this point in the history
  • Loading branch information
yangw1234 authored May 30, 2022
1 parent c364d88 commit ad0efc9
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 2 deletions.
7 changes: 5 additions & 2 deletions python/nano/test/pytorch/tests/test_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from bigdl.nano.pytorch import Trainer

from test.pytorch.utils._train_torch_lightning import create_data_loader, data_transform
from test.pytorch.utils._train_torch_lightning import create_test_data_loader
from test.pytorch.tests.test_lightning import ResNet18

num_classes = 10
Expand All @@ -41,6 +42,8 @@ class TestPlugin(TestCase):
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
data_loader = create_data_loader(data_dir, batch_size, num_workers,
data_transform, subset=dataset_size)
test_data_loader = create_test_data_loader(data_dir, batch_size, num_workers,
data_transform, subset=dataset_size)

def setUp(self):
test_dir = os.path.dirname(__file__)
Expand All @@ -56,8 +59,8 @@ def test_trainer_subprocess_plugin(self):
)
trainer = Trainer(num_processes=2, distributed_backend="subprocess",
max_epochs=4)
trainer.fit(pl_model, self.data_loader, self.data_loader)
trainer.test(pl_model, self.data_loader)
trainer.fit(pl_model, self.data_loader, self.test_data_loader)
trainer.test(pl_model, self.test_data_loader)


if __name__ == '__main__':
Expand Down
67 changes: 67 additions & 0 deletions python/nano/test/pytorch/tests/test_plugin_ipex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright 2016 The BigDL Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import os
import pytest
from unittest import TestCase

import torch
import torchmetrics
from torch import nn

from bigdl.nano.pytorch.lightning import LightningModuleFromTorch
from bigdl.nano.pytorch import Trainer

from test.pytorch.utils._train_torch_lightning import create_data_loader, data_transform
from test.pytorch.utils._train_torch_lightning import create_test_data_loader
from test.pytorch.tests.test_lightning import ResNet18

num_classes = 10
batch_size = 32
dataset_size = 256
num_workers = 0
data_dir = os.path.join(os.path.dirname(__file__), "../data")


class TestPlugin(TestCase):
model = ResNet18(pretrained=False, include_top=False, freeze=True)
loss = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
data_loader = create_data_loader(data_dir, batch_size, num_workers,
data_transform, subset=dataset_size)
test_data_loader = create_test_data_loader(data_dir, batch_size, num_workers,
data_transform, subset=dataset_size)

def setUp(self):
test_dir = os.path.dirname(__file__)
project_test_dir = os.path.abspath(
os.path.join(os.path.join(os.path.join(test_dir, ".."), ".."), "..")
)
os.environ['PYTHONPATH'] = project_test_dir

def test_trainer_subprocess_plugin(self):
pl_model = LightningModuleFromTorch(
self.model, self.loss, self.optimizer,
metrics=[torchmetrics.F1(num_classes), torchmetrics.Accuracy(num_classes=10)]
)
trainer = Trainer(num_processes=2, distributed_backend="subprocess",
max_epochs=4, use_ipex=True)
trainer.fit(pl_model, self.data_loader, self.test_data_loader)
trainer.test(pl_model, self.test_data_loader)


if __name__ == '__main__':
pytest.main([__file__])

0 comments on commit ad0efc9

Please sign in to comment.