Commit 0e4dd1a

Configure in components in kv format

prabhuomkar committed Aug 14, 2023
1 parent bd9a5f3
Showing 25 changed files with 91 additions and 88 deletions.
30 changes: 15 additions & 15 deletions api/config/config.go
@@ -56,21 +56,21 @@ type (

// ML ...
ML struct {
- Places bool `envconfig:"SMRITI_ML_PLACES" default:"true"`
- Classification bool `envconfig:"SMRITI_ML_CLASSIFICATION" default:"true"`
- OCR bool `envconfig:"SMRITI_ML_OCR" default:"true"`
- Search bool `envconfig:"SMRITI_ML_SEARCH" default:"true"`
- Faces bool `envconfig:"SMRITI_ML_FACES" default:"true"`
- PlacesProvider string `envconfig:"SMRITI_ML_PLACES_PROVIDER" default:"openstreetmap"`
- ClassificationProvider string `envconfig:"SMRITI_ML_CLASSIFICATION_PROVIDER" default:"pytorch"`
- ClassificationParams []string `envconfig:"SMRITI_ML_CLASSIFICATION_PARAMS" default:"classification_v20230731.pt"`
- OCRProvider string `envconfig:"SMRITI_ML_OCR_PROVIDER" default:"paddlepaddle"`
- OCRParams []string `envconfig:"SMRITI_ML_OCR_PARAMS" default:"det_infer,rec_infer,cls_infer"`
- SearchProvider string `envconfig:"SMRITI_ML_SEARCH_PROVIDER" default:"pytorch"`
- SearchParams []string `envconfig:"SMRITI_ML_SEARCH_PARAMS" default:"search_tokenizer,search_processor,search_text_v20230731.pt,search_vision_v20230731.pt"`
- FacesProvider string `envconfig:"SMRITI_ML_FACES_PROVIDER" default:"pytorch"`
- FacesParams []string `envconfig:"SMRITI_ML_FACES_PARAMS" default:"1,0.9,vggface2"`
- MetadataParams []string `envconfig:"SMRITI_ML_METADATA_PARAMS" default:"512"`
+ Places bool `envconfig:"SMRITI_ML_PLACES" default:"true"`
+ Classification bool `envconfig:"SMRITI_ML_CLASSIFICATION" default:"true"`
+ OCR bool `envconfig:"SMRITI_ML_OCR" default:"true"`
+ Search bool `envconfig:"SMRITI_ML_SEARCH" default:"true"`
+ Faces bool `envconfig:"SMRITI_ML_FACES" default:"true"`
+ PlacesProvider string `envconfig:"SMRITI_ML_PLACES_PROVIDER" default:"openstreetmap"`
+ ClassificationProvider string `envconfig:"SMRITI_ML_CLASSIFICATION_PROVIDER" default:"pytorch"`
+ ClassificationParams string `envconfig:"SMRITI_ML_CLASSIFICATION_PARAMS" default:"{\"file\":\"classification_v20230731.pt\"}"`
+ OCRProvider string `envconfig:"SMRITI_ML_OCR_PROVIDER" default:"paddlepaddle"`
+ OCRParams string `envconfig:"SMRITI_ML_OCR_PARAMS" default:"{\"det_model_dir\":\"det_infer\",\"rec_model_dir\":\"rec_infer\",\"cls_model_dir\":\"cls_infer\"}"`
+ SearchProvider string `envconfig:"SMRITI_ML_SEARCH_PROVIDER" default:"pytorch"`
+ SearchParams string `envconfig:"SMRITI_ML_SEARCH_PARAMS" default:"{\"tokenizer_dir\":\"search_tokenizer\",\"processor_dir\":\"search_processor\",\"text_file\":\"search_text_v20230731.pt\",\"vision_file\":\"search_vision_v20230731.pt\"}"` //nolint:lll
+ FacesProvider string `envconfig:"SMRITI_ML_FACES_PROVIDER" default:"pytorch"`
+ FacesParams string `envconfig:"SMRITI_ML_FACES_PARAMS" default:"{\"minutes\":\"1\",\"face_threshold\":\"0.9\",\"model\":\"vggface2\"}"`
+ MetadataParams string `envconfig:"SMRITI_ML_METADATA_PARAMS" default:"{\"thumbnail_size\":\"512\"}"`
}

// Feature ...
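
For context, the *Params fields change from []string (positional, comma-separated defaults) to a single JSON-encoded key/value string. A minimal Python sketch of what the new OCR default decodes to; the values are the defaults from the struct above, and the override line is only an illustration:

import json

# The new defaults are JSON objects carried inside env-var strings (kv format).
ocr_params = json.loads(
    '{"det_model_dir":"det_infer","rec_model_dir":"rec_infer","cls_model_dir":"cls_infer"}'
)
print(ocr_params["det_model_dir"])  # -> det_infer

# Overriding a param set is now a JSON object rather than a comma-separated list, e.g.:
#   export SMRITI_ML_OCR_PARAMS='{"det_model_dir":"my_det","rec_model_dir":"my_rec","cls_model_dir":"my_cls"}'
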
6 changes: 3 additions & 3 deletions api/internal/service/api.go
@@ -30,9 +30,9 @@ type Service struct {

func (s *Service) GetWorkerConfig(_ context.Context, _ *empty.Empty) (*api.ConfigResponse, error) {
type WorkerTask struct {
- Name string `json:"name"`
- Source string `json:"source,omitempty"`
- Params []string `json:"params,omitempty"`
+ Name string `json:"name"`
+ Source string `json:"source,omitempty"`
+ Params string `json:"params,omitempty"`
}
var workerTasks []WorkerTask
if len(s.Config.ML.MetadataParams) > 0 {
19 changes: 9 additions & 10 deletions api/internal/service/api_test.go
@@ -149,17 +149,16 @@ func TestGetWorkerConfig(t *testing.T) {
"get worker config with success with all config",
&config.Config{ML: config.ML{
Places: true, PlacesProvider: "openstreetmap",
- Classification: true, ClassificationProvider: "pytorch", ClassificationParams: []string{"model-file-name.pt"},
- OCR: true, OCRProvider: "paddlepaddle", OCRParams: []string{"ocr-v1-model.pt"},
- Search: true, SearchProvider: "pytorch", SearchParams: []string{"search-model.pt"},
- Faces: true, FacesParams: []string{"http://faces/model/link"},
- MetadataParams: []string{"512"},
+ Classification: true, ClassificationProvider: "pytorch", ClassificationParams: `{"file":"model-file-name.pt"}`,
+ OCR: true, OCRProvider: "paddlepaddle", OCRParams: `{"det_model_dir":"/det_infer"}`,
+ Search: true, SearchProvider: "pytorch", SearchParams: `{"tokenizer_dir":"/tokenizer"}`,
+ Faces: true, FacesParams: `{"face_threshold":"0.9"}`,
+ MetadataParams: `{"thumbnail_size":"512"}`,
}},
- []byte(`[{"name":"metadata","params":["512"]},` +
- `{"name":"places","source":"openstreetmap"},{"name":"classification","source":"pytorch",` +
- `"params":["model-file-name.pt"]},{"name":"ocr","source":"paddlepaddle","params":["ocr-v1-model.pt"]}` +
- `,{"name":"search","source":"pytorch","params":["search-model.pt"]}` +
- `,{"name":"faces","params":["http://faces/model/link"]}]`),
+ []byte(`[{"name":"metadata","params":"{\"thumbnail_size\":\"512\"}"},{"name":"places","source":"openstreetmap"},` +
+ `{"name":"classification","source":"pytorch","params":"{\"file\":\"model-file-name.pt\"}"},{"name":"ocr","source":"paddlepaddle",` +
+ `"params":"{\"det_model_dir\":\"/det_infer\"}"},{"name":"search","source":"pytorch","params":"{\"tokenizer_dir\":\"/tokenizer\"}"},` +
+ `{"name":"faces","params":"{\"face_threshold\":\"0.9\"}"}]`),
nil,
},
{
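
Note on the expected payload above: because Params is now a plain string, the JSON it carries is serialized a second time inside the response, which is why the quotes appear escaped. A small, purely illustrative Python sketch of unwrapping it (not code from the repo):

import json

payload = b'[{"name":"metadata","params":"{\\"thumbnail_size\\":\\"512\\"}"}]'
tasks = json.loads(payload)              # decode the response body
params = json.loads(tasks[0]["params"])  # decode the embedded params string
print(params["thumbnail_size"])          # -> 512
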
2 changes: 1 addition & 1 deletion worker/src/components/classification.py
@@ -11,7 +11,7 @@

class Classification(Component):
"""Classification Component"""
- def __init__(self, api_stub: APIStub, source: str, params: list[str]) -> None:
+ def __init__(self, api_stub: APIStub, source: str, params: dict) -> None:
super().__init__('classification', api_stub)
self.model = init_classification(source, params)

4 changes: 2 additions & 2 deletions worker/src/components/faces.py
@@ -16,10 +16,10 @@

class Faces(Component):
"""Faces Component"""
- def __init__(self, api_stub: APIStub, source: str, params: list[str]) -> None:
+ def __init__(self, api_stub: APIStub, source: str, params: dict) -> None:
super().__init__('faces', api_stub)
self.source = init_faces(source, params)
- schedule.every(int(params[0])).minutes.do(self.cluster)
+ schedule.every(int(params['minutes'])).minutes.do(self.cluster)

async def process(self, mediaitem_user_id: str, mediaitem_id: str, _: str, metadata: dict) -> None:
"""Process faces detection for mediaitem"""
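
With kv-format params, the clustering interval is looked up by the minutes key instead of list position 0. A minimal sketch of that scheduling call, assuming the schedule package and the default params from config.go:

import schedule

params = {'minutes': '1', 'face_threshold': '0.9', 'model': 'vggface2'}

def cluster():
    print('clustering faces...')

# Key-based lookup replaces params[0]
schedule.every(int(params['minutes'])).minutes.do(cluster)
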
4 changes: 2 additions & 2 deletions worker/src/components/metadata.py
@@ -36,9 +36,9 @@ class Metadata(Component):
'video/webm', 'video/3gpp', 'video/3gpp2',
]

- def __init__(self, api_stub: APIStub, params: list[str]) -> None:
+ def __init__(self, api_stub: APIStub, params: dict) -> None:
super().__init__('metadata', api_stub)
- self.thumbnail_size = int(params[0])
+ self.thumbnail_size = int(params['thumbnail_size'])

# pylint: disable=too-many-statements,too-many-branches
async def process(self, mediaitem_user_id: str, mediaitem_id: str, mediaitem_file_path: str, _: dict) -> dict:
2 changes: 1 addition & 1 deletion worker/src/components/ocr.py
@@ -8,7 +8,7 @@

class OCR(Component):
"""OCR Component"""
- def __init__(self, api_stub: APIStub, source: str, params: list[str]) -> None:
+ def __init__(self, api_stub: APIStub, source: str, params: dict) -> None:
super().__init__('ocr', api_stub)
self.model = init_ocr(source, params)

2 changes: 2 additions & 0 deletions worker/src/main.py
@@ -95,6 +95,8 @@ async def serve() -> None:
components = []
search_model = None
for item in cfg:
+ if 'params' in item:
+ item['params'] = json.loads(item['params'])
if item['name'] == 'metadata':
components.append(Metadata(api_stub=api_stub, params=item['params']))
elif item['name'] == 'places':
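
The worker now decodes each task's params string once, up front, so every component constructor receives a plain dict. A hedged sketch of that idea with a hand-written cfg (the real config comes from GetWorkerConfig):

import json

cfg = [
    {'name': 'metadata', 'params': '{"thumbnail_size":"512"}'},
    {'name': 'places', 'source': 'openstreetmap'},
]
for item in cfg:
    if 'params' in item:
        item['params'] = json.loads(item['params'])  # JSON string -> dict
print(cfg[0]['params']['thumbnail_size'])  # -> 512
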
4 changes: 2 additions & 2 deletions worker/src/providers/classification/pytorch.py
@@ -8,8 +8,8 @@
class PyTorchModule:
"""PyTorchModule Classification"""

- def __init__(self, params: list[str]) -> None:
- self.module = torch.jit.load(f'/models/classification/{params[0]}')
+ def __init__(self, params: dict) -> None:
+ self.module = torch.jit.load(f'/models/classification/{params["file"]}')

def classify(self, mediaitem_user_id: str, mediaitem_id: str, input_file: str) -> dict:
"""Classify categories for mediaitem"""
2 changes: 1 addition & 1 deletion worker/src/providers/classification/utils.py
@@ -2,7 +2,7 @@
from src.providers.classification.pytorch import PyTorchModule


- def init_classification(name: str, params: list[str]) -> None | PyTorchModule:
+ def init_classification(name: str, params: dict) -> None | PyTorchModule:
"""Initialize classification model by name"""
if name == 'pytorch':
return PyTorchModule(params)
6 changes: 3 additions & 3 deletions worker/src/providers/faces/pytorch.py
@@ -10,15 +10,15 @@
class PyTorchModule:
"""PyTorchModule Faces"""

- def __init__(self, params: list[str]) -> None:
+ def __init__(self, params: dict) -> None:
os.environ['TORCH_HOME'] = '/'
try:
os.symlink('/models/faces/', '/checkpoints')
except Exception as exp:
logging.error(f'error creating symlink: {str(exp)}')
- self.prob_threshold = float(params[1])
+ self.prob_threshold = float(params['face_threshold'])
self.det_model = MTCNN(keep_all=True)
- self.rec_model = InceptionResnetV1(pretrained=params[2], classify=False)
+ self.rec_model = InceptionResnetV1(pretrained=params['model'], classify=False)
if self.rec_model:
self.rec_model.eval()

2 changes: 1 addition & 1 deletion worker/src/providers/faces/utils.py
@@ -2,7 +2,7 @@
from src.providers.faces.pytorch import PyTorchModule


- def init_faces(name: str, params: list[str]) -> None | PyTorchModule:
+ def init_faces(name: str, params: dict) -> None | PyTorchModule:
"""Initialize faces by name"""
if name == 'pytorch':
return PyTorchModule(params)
8 changes: 5 additions & 3 deletions worker/src/providers/ocr/paddlepaddle.py
@@ -10,9 +10,11 @@
class PaddleModule:
"""PaddleModule OCR"""

- def __init__(self, params: list[str]) -> None:
- self.model = PaddleOCR(show_log=False, use_angle_cls=True, lang='en', det_model_dir=f'/models/ocr/{params[0]}',
- rec_model_dir=f'/models/ocr/{params[1]}', cls_model_dir=f'/models/ocr/{params[2]}')
+ def __init__(self, params: dict) -> None:
+ self.model = PaddleOCR(show_log=False, use_angle_cls=True, lang='en',
+ det_model_dir=f'/models/ocr/{params["det_model_dir"]}',
+ rec_model_dir=f'/models/ocr/{params["rec_model_dir"]}',
+ cls_model_dir=f'/models/ocr/{params["cls_model_dir"]}')

def extract(self, mediaitem_user_id: str, mediaitem_id: str, mediaitem_type: str, input_file: str) -> dict:
"""Extract text from mediaitem"""
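
The model directories are now picked out of the params dict by name rather than by index. An illustrative sketch of building those paths from the kv params (defaults taken from config.go; paths only, no PaddleOCR call):

params = {'det_model_dir': 'det_infer', 'rec_model_dir': 'rec_infer', 'cls_model_dir': 'cls_infer'}
model_dirs = {key: f'/models/ocr/{value}' for key, value in params.items()}
print(model_dirs['det_model_dir'])  # -> /models/ocr/det_infer
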
2 changes: 1 addition & 1 deletion worker/src/providers/ocr/utils.py
@@ -2,7 +2,7 @@
from src.providers.ocr.paddlepaddle import PaddleModule


- def init_ocr(name: str, params: list[str]) -> None | PaddleModule:
+ def init_ocr(name: str, params: dict) -> None | PaddleModule:
"""Initialize ocr model by name"""
if name == 'paddlepaddle':
return PaddleModule(params)
10 changes: 5 additions & 5 deletions worker/src/providers/search/pytorch.py
@@ -10,11 +10,11 @@
class PyTorchModule:
"""PyTorchModule Search"""

- def __init__(self, params: list[str]) -> None:
- self.tokenizer = AutoTokenizer.from_pretrained(f'/models/search/{params[0]}')
- self.processor = AutoImageProcessor.from_pretrained(f'/models/search/{params[1]}')
- self.text_module = torch.jit.load(f'/models/search/{params[2]}')
- self.vision_module = torch.jit.load(f'/models/search/{params[3]}')
+ def __init__(self, params: dict) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(f'/models/search/{params["tokenizer_dir"]}')
+ self.processor = AutoImageProcessor.from_pretrained(f'/models/search/{params["processor_dir"]}')
+ self.text_module = torch.jit.load(f'/models/search/{params["text_file"]}')
+ self.vision_module = torch.jit.load(f'/models/search/{params["vision_file"]}')

def generate_embedding(self, input_type: str, data: any):
"""Generate text embedding from text"""
2 changes: 1 addition & 1 deletion worker/src/providers/search/utils.py
@@ -2,7 +2,7 @@
from src.providers.search.pytorch import PyTorchModule


- def init_search(name: str, params: list[str]) -> None | PyTorchModule:
+ def init_search(name: str, params: dict) -> None | PyTorchModule:
"""Initialize search model by name"""
if name == 'pytorch':
return PyTorchModule(params)
10 changes: 5 additions & 5 deletions worker/tests/components/test_classification.py
@@ -12,7 +12,7 @@
@mock.patch('src.components.Classification._grpc_save_mediaitem_thing', return_value=None)
@pytest.mark.asyncio
async def test_classification_process_success(_, __):
- classification = Classification(APIStub(channel=grpc.insecure_channel('')), 'pytorch', ['model_name.pt'])
+ classification = Classification(APIStub(channel=grpc.insecure_channel('')), 'pytorch', {'file':'model_name.pt'})
classification.model = mock.MagicMock()
classification.model.classify.return_value = dict({'userId':'userId','id':'id','name':'name'})
result = await classification.process('mediaitem_user_id', 'mediaitem_id', None,
@@ -23,7 +23,7 @@ async def test_classification_process_success_with_keywords(_, __):
@mock.patch('src.components.Classification._grpc_save_mediaitem_thing', return_value=None)
@pytest.mark.asyncio
async def test_classification_process_success_with_keywords(_, __):
- classification = Classification(APIStub(channel=grpc.insecure_channel('')), 'pytorch', ['model_name.pt'])
+ classification = Classification(APIStub(channel=grpc.insecure_channel('')), 'pytorch', {'file':'model_name.pt'})
classification.model = mock.MagicMock()
classification.model.classify.return_value = dict({'userId':'userId','id':'id','name':'name'})
result = await classification.process('mediaitem_user_id', 'mediaitem_id', None,
@@ -33,14 +33,14 @@ async def test_classification_process_success_with_keywords(_, __):
@mock.patch('torch.jit.load', return_value=None)
@pytest.mark.asyncio
async def test_classification_process_success_no_metadata(_):
- result = await Classification(None, 'pytorch', ['model_name.pt']).process('mediaitem_user_id', 'mediaitem_id', None, None)
+ result = await Classification(None, 'pytorch', {'file':'model_name.pt'}).process('mediaitem_user_id', 'mediaitem_id', None, None)
assert result == None

@mock.patch('torch.jit.load', return_value=None)
@mock.patch('src.components.Classification._grpc_save_mediaitem_thing', return_value=None)
@pytest.mark.asyncio
async def test_classification_process_failed_process_exception(_, __):
- classification = Classification(None, 'pytorch', ['model_name.pt'])
+ classification = Classification(None, 'pytorch', {'file':'model_name.pt'})
classification.model = mock.MagicMock()
classification.model.classify.side_effect = Exception('some exception')
result = await classification.process('mediaitem_user_id', 'mediaitem_id', None,
@@ -50,7 +50,7 @@ async def test_classification_process_failed_process_exception(_, __):
@mock.patch('torch.jit.load', return_value=None)
@pytest.mark.asyncio
async def test_classification_process_grpc_exception(_):
- classification = Classification(APIStub(channel=grpc.insecure_channel('')), 'pytorch', ['model_name.pt'])
+ classification = Classification(APIStub(channel=grpc.insecure_channel('')), 'pytorch', {'file':'model_name.pt'})
classification.model = mock.MagicMock()
classification.model.classify.return_value = dict({'userId':'userId','id':'id','name':'name'})
grpc_mock = mock.MagicMock()
14 changes: 7 additions & 7 deletions worker/tests/components/test_faces.py
@@ -12,7 +12,7 @@
@mock.patch('src.providers.faces.PyTorchModule.__init__', return_value=None)
@pytest.mark.asyncio
async def test_faces_process_success(_):
- faces = Faces(APIStub(channel=grpc.insecure_channel('')), 'pytorch', ['1', '0.9', 'vggface2'])
+ faces = Faces(APIStub(channel=grpc.insecure_channel('')), 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'})
faces.source = mock.MagicMock()
faces.source.detect.return_value = dict({'userId':'userId','id':'id','embeddings':[[0.4,0.2]]})
result = await faces.process('mediaitem_user_id', 'mediaitem_id', None,
@@ -22,13 +22,13 @@ async def test_faces_process_success(_):
@mock.patch('src.providers.faces.PyTorchModule.__init__', return_value=None)
@pytest.mark.asyncio
async def test_faces_process_success_no_result(_):
- result = await Faces(None, 'pytorch', ['1', '0.9', 'vggface2']).process('mediaitem_user_id', 'mediaitem_id', None, None)
+ result = await Faces(None, 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'}).process('mediaitem_user_id', 'mediaitem_id', None, None)
assert result == None

@mock.patch('src.providers.faces.PyTorchModule.__init__', return_value=None)
@pytest.mark.asyncio
async def test_faces_process_failed_process_exception(_):
- faces = Faces(None, 'pytorch', ['1', '0.9', 'vggface2'])
+ faces = Faces(None, 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'})
faces.source = mock.MagicMock()
faces.source.detect.side_effect = Exception('some exception')
result = await faces.process('mediaitem_user_id', 'mediaitem_id', None,
@@ -51,7 +51,7 @@ async def test_faces_cluster_success(_, __):
MediaItemFaceEmbedding(id='face-id-5', mediaItemId='mediaitem-id-5', peopleId='', embedding=MediaItemEmbedding(embedding=[33.42,41.24])),
]
)
- faces = Faces(api_stub_mock, 'pytorch', ['1', '0.9', 'vggface2'])
+ faces = Faces(api_stub_mock, 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'})
faces.cluster()

@mock.patch('src.providers.faces.PyTorchModule.__init__', return_value=None)
@@ -61,7 +61,7 @@ async def test_faces_cluster_success_no_embeddings(_, __):
api_stub_mock = mock.MagicMock()
api_stub_mock.GetUsers.return_value = GetUsersResponse(users=['user-id'])
api_stub_mock.GetMediaItemFaceEmbeddings.return_value = MediaItemFaceEmbeddingsResponse(mediaItemFaceEmbeddings=[])
- faces = Faces(api_stub_mock, 'pytorch', ['1', '0.9', 'vggface2'])
+ faces = Faces(api_stub_mock, 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'})
faces.cluster()

@mock.patch('src.providers.faces.PyTorchModule.__init__', return_value=None)
@@ -70,7 +70,7 @@ async def test_faces_cluster_failed_cluster_exception(_):
api_stub_mock = mock.MagicMock()
api_stub_mock.GetUsers.return_value = GetUsersResponse(users=['user-id'])
api_stub_mock.GetMediaItemFaceEmbeddings.side_effect = grpc.RpcError(Exception('some error'))
- faces = Faces(api_stub_mock, 'pytorch', ['1', '0.9', 'vggface2'])
+ faces = Faces(api_stub_mock, 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'})
faces.cluster()

@mock.patch('src.providers.faces.PyTorchModule.__init__', return_value=None)
@@ -85,5 +85,5 @@ async def test_faces_cluster_failed_grpc_exception(_):
]
)
api_stub_mock.SaveMediaItemPeople.side_effect = grpc.RpcError(Exception('some error'))
- faces = Faces(api_stub_mock, 'pytorch', ['1', '0.9', 'vggface2'])
+ faces = Faces(api_stub_mock, 'pytorch', {'minutes':'1','face_threshold':'0.9','model':'vggface2'})
faces.cluster()