diff --git a/configs/_base_/datasets/300w.py b/configs/_base_/datasets/300w.py
index 10c343a2ad..2c3728da1d 100644
--- a/configs/_base_/datasets/300w.py
+++ b/configs/_base_/datasets/300w.py
@@ -11,373 +11,123 @@
homepage='https://ibug.doc.ic.ac.uk/resources/300-W/',
),
keypoint_info={
- 0:
- dict(
- name='kpt-0', id=0, color=[255, 255, 255], type='', swap='kpt-16'),
- 1:
- dict(
- name='kpt-1', id=1, color=[255, 255, 255], type='', swap='kpt-15'),
- 2:
- dict(
- name='kpt-2', id=2, color=[255, 255, 255], type='', swap='kpt-14'),
- 3:
- dict(
- name='kpt-3', id=3, color=[255, 255, 255], type='', swap='kpt-13'),
- 4:
- dict(
- name='kpt-4', id=4, color=[255, 255, 255], type='', swap='kpt-12'),
- 5:
- dict(
- name='kpt-5', id=5, color=[255, 255, 255], type='', swap='kpt-11'),
- 6:
- dict(
- name='kpt-6', id=6, color=[255, 255, 255], type='', swap='kpt-10'),
- 7:
- dict(name='kpt-7', id=7, color=[255, 255, 255], type='', swap='kpt-9'),
- 8:
- dict(name='kpt-8', id=8, color=[255, 255, 255], type='', swap=''),
- 9:
- dict(name='kpt-9', id=9, color=[255, 255, 255], type='', swap='kpt-7'),
+ 0: dict(name='kpt-0', id=0, color=[255, 0, 0], type='', swap='kpt-16'),
+ 1: dict(name='kpt-1', id=1, color=[255, 0, 0], type='', swap='kpt-15'),
+ 2: dict(name='kpt-2', id=2, color=[255, 0, 0], type='', swap='kpt-14'),
+ 3: dict(name='kpt-3', id=3, color=[255, 0, 0], type='', swap='kpt-13'),
+ 4: dict(name='kpt-4', id=4, color=[255, 0, 0], type='', swap='kpt-12'),
+ 5: dict(name='kpt-5', id=5, color=[255, 0, 0], type='', swap='kpt-11'),
+ 6: dict(name='kpt-6', id=6, color=[255, 0, 0], type='', swap='kpt-10'),
+ 7: dict(name='kpt-7', id=7, color=[255, 0, 0], type='', swap='kpt-9'),
+ 8: dict(name='kpt-8', id=8, color=[255, 0, 0], type='', swap=''),
+ 9: dict(name='kpt-9', id=9, color=[255, 0, 0], type='', swap='kpt-7'),
10:
- dict(
- name='kpt-10', id=10, color=[255, 255, 255], type='',
- swap='kpt-6'),
+ dict(name='kpt-10', id=10, color=[255, 0, 0], type='', swap='kpt-6'),
11:
- dict(
- name='kpt-11', id=11, color=[255, 255, 255], type='',
- swap='kpt-5'),
+ dict(name='kpt-11', id=11, color=[255, 0, 0], type='', swap='kpt-5'),
12:
- dict(
- name='kpt-12', id=12, color=[255, 255, 255], type='',
- swap='kpt-4'),
+ dict(name='kpt-12', id=12, color=[255, 0, 0], type='', swap='kpt-4'),
13:
- dict(
- name='kpt-13', id=13, color=[255, 255, 255], type='',
- swap='kpt-3'),
+ dict(name='kpt-13', id=13, color=[255, 0, 0], type='', swap='kpt-3'),
14:
- dict(
- name='kpt-14', id=14, color=[255, 255, 255], type='',
- swap='kpt-2'),
+ dict(name='kpt-14', id=14, color=[255, 0, 0], type='', swap='kpt-2'),
15:
- dict(
- name='kpt-15', id=15, color=[255, 255, 255], type='',
- swap='kpt-1'),
+ dict(name='kpt-15', id=15, color=[255, 0, 0], type='', swap='kpt-1'),
16:
- dict(
- name='kpt-16', id=16, color=[255, 255, 255], type='',
- swap='kpt-0'),
+ dict(name='kpt-16', id=16, color=[255, 0, 0], type='', swap='kpt-0'),
17:
- dict(
- name='kpt-17',
- id=17,
- color=[255, 255, 255],
- type='',
- swap='kpt-26'),
+ dict(name='kpt-17', id=17, color=[255, 0, 0], type='', swap='kpt-26'),
18:
- dict(
- name='kpt-18',
- id=18,
- color=[255, 255, 255],
- type='',
- swap='kpt-25'),
+ dict(name='kpt-18', id=18, color=[255, 0, 0], type='', swap='kpt-25'),
19:
- dict(
- name='kpt-19',
- id=19,
- color=[255, 255, 255],
- type='',
- swap='kpt-24'),
+ dict(name='kpt-19', id=19, color=[255, 0, 0], type='', swap='kpt-24'),
20:
- dict(
- name='kpt-20',
- id=20,
- color=[255, 255, 255],
- type='',
- swap='kpt-23'),
+ dict(name='kpt-20', id=20, color=[255, 0, 0], type='', swap='kpt-23'),
21:
- dict(
- name='kpt-21',
- id=21,
- color=[255, 255, 255],
- type='',
- swap='kpt-22'),
+ dict(name='kpt-21', id=21, color=[255, 0, 0], type='', swap='kpt-22'),
22:
- dict(
- name='kpt-22',
- id=22,
- color=[255, 255, 255],
- type='',
- swap='kpt-21'),
+ dict(name='kpt-22', id=22, color=[255, 0, 0], type='', swap='kpt-21'),
23:
- dict(
- name='kpt-23',
- id=23,
- color=[255, 255, 255],
- type='',
- swap='kpt-20'),
+ dict(name='kpt-23', id=23, color=[255, 0, 0], type='', swap='kpt-20'),
24:
- dict(
- name='kpt-24',
- id=24,
- color=[255, 255, 255],
- type='',
- swap='kpt-19'),
+ dict(name='kpt-24', id=24, color=[255, 0, 0], type='', swap='kpt-19'),
25:
- dict(
- name='kpt-25',
- id=25,
- color=[255, 255, 255],
- type='',
- swap='kpt-18'),
+ dict(name='kpt-25', id=25, color=[255, 0, 0], type='', swap='kpt-18'),
26:
- dict(
- name='kpt-26',
- id=26,
- color=[255, 255, 255],
- type='',
- swap='kpt-17'),
- 27:
- dict(name='kpt-27', id=27, color=[255, 255, 255], type='', swap=''),
- 28:
- dict(name='kpt-28', id=28, color=[255, 255, 255], type='', swap=''),
- 29:
- dict(name='kpt-29', id=29, color=[255, 255, 255], type='', swap=''),
- 30:
- dict(name='kpt-30', id=30, color=[255, 255, 255], type='', swap=''),
+ dict(name='kpt-26', id=26, color=[255, 0, 0], type='', swap='kpt-17'),
+ 27: dict(name='kpt-27', id=27, color=[255, 0, 0], type='', swap=''),
+ 28: dict(name='kpt-28', id=28, color=[255, 0, 0], type='', swap=''),
+ 29: dict(name='kpt-29', id=29, color=[255, 0, 0], type='', swap=''),
+ 30: dict(name='kpt-30', id=30, color=[255, 0, 0], type='', swap=''),
31:
- dict(
- name='kpt-31',
- id=31,
- color=[255, 255, 255],
- type='',
- swap='kpt-35'),
+ dict(name='kpt-31', id=31, color=[255, 0, 0], type='', swap='kpt-35'),
32:
- dict(
- name='kpt-32',
- id=32,
- color=[255, 255, 255],
- type='',
- swap='kpt-34'),
- 33:
- dict(name='kpt-33', id=33, color=[255, 255, 255], type='', swap=''),
+ dict(name='kpt-32', id=32, color=[255, 0, 0], type='', swap='kpt-34'),
+ 33: dict(name='kpt-33', id=33, color=[255, 0, 0], type='', swap=''),
34:
- dict(
- name='kpt-34',
- id=34,
- color=[255, 255, 255],
- type='',
- swap='kpt-32'),
+ dict(name='kpt-34', id=34, color=[255, 0, 0], type='', swap='kpt-32'),
35:
- dict(
- name='kpt-35',
- id=35,
- color=[255, 255, 255],
- type='',
- swap='kpt-31'),
+ dict(name='kpt-35', id=35, color=[255, 0, 0], type='', swap='kpt-31'),
36:
- dict(
- name='kpt-36',
- id=36,
- color=[255, 255, 255],
- type='',
- swap='kpt-45'),
+ dict(name='kpt-36', id=36, color=[255, 0, 0], type='', swap='kpt-45'),
37:
- dict(
- name='kpt-37',
- id=37,
- color=[255, 255, 255],
- type='',
- swap='kpt-44'),
+ dict(name='kpt-37', id=37, color=[255, 0, 0], type='', swap='kpt-44'),
38:
- dict(
- name='kpt-38',
- id=38,
- color=[255, 255, 255],
- type='',
- swap='kpt-43'),
+ dict(name='kpt-38', id=38, color=[255, 0, 0], type='', swap='kpt-43'),
39:
- dict(
- name='kpt-39',
- id=39,
- color=[255, 255, 255],
- type='',
- swap='kpt-42'),
+ dict(name='kpt-39', id=39, color=[255, 0, 0], type='', swap='kpt-42'),
40:
- dict(
- name='kpt-40',
- id=40,
- color=[255, 255, 255],
- type='',
- swap='kpt-47'),
- 41:
- dict(
- name='kpt-41',
- id=41,
- color=[255, 255, 255],
- type='',
- swap='kpt-46'),
- 42:
- dict(
- name='kpt-42',
- id=42,
- color=[255, 255, 255],
- type='',
- swap='kpt-39'),
- 43:
- dict(
- name='kpt-43',
- id=43,
- color=[255, 255, 255],
- type='',
- swap='kpt-38'),
- 44:
- dict(
- name='kpt-44',
- id=44,
- color=[255, 255, 255],
- type='',
- swap='kpt-37'),
- 45:
- dict(
- name='kpt-45',
- id=45,
- color=[255, 255, 255],
- type='',
- swap='kpt-36'),
- 46:
- dict(
- name='kpt-46',
- id=46,
- color=[255, 255, 255],
- type='',
- swap='kpt-41'),
- 47:
- dict(
- name='kpt-47',
- id=47,
- color=[255, 255, 255],
- type='',
- swap='kpt-40'),
- 48:
- dict(
- name='kpt-48',
- id=48,
- color=[255, 255, 255],
- type='',
- swap='kpt-54'),
- 49:
- dict(
- name='kpt-49',
- id=49,
- color=[255, 255, 255],
- type='',
- swap='kpt-53'),
- 50:
- dict(
- name='kpt-50',
- id=50,
- color=[255, 255, 255],
- type='',
- swap='kpt-52'),
- 51:
- dict(name='kpt-51', id=51, color=[255, 255, 255], type='', swap=''),
- 52:
- dict(
- name='kpt-52',
- id=52,
- color=[255, 255, 255],
- type='',
- swap='kpt-50'),
- 53:
- dict(
- name='kpt-53',
- id=53,
- color=[255, 255, 255],
- type='',
- swap='kpt-49'),
- 54:
- dict(
- name='kpt-54',
- id=54,
- color=[255, 255, 255],
- type='',
- swap='kpt-48'),
- 55:
- dict(
- name='kpt-55',
- id=55,
- color=[255, 255, 255],
- type='',
- swap='kpt-59'),
- 56:
- dict(
- name='kpt-56',
- id=56,
- color=[255, 255, 255],
- type='',
- swap='kpt-58'),
- 57:
- dict(name='kpt-57', id=57, color=[255, 255, 255], type='', swap=''),
- 58:
- dict(
- name='kpt-58',
- id=58,
- color=[255, 255, 255],
- type='',
- swap='kpt-56'),
- 59:
- dict(
- name='kpt-59',
- id=59,
- color=[255, 255, 255],
- type='',
- swap='kpt-55'),
- 60:
- dict(
- name='kpt-60',
- id=60,
- color=[255, 255, 255],
- type='',
- swap='kpt-64'),
- 61:
- dict(
- name='kpt-61',
- id=61,
- color=[255, 255, 255],
- type='',
- swap='kpt-63'),
- 62:
- dict(name='kpt-62', id=62, color=[255, 255, 255], type='', swap=''),
- 63:
- dict(
- name='kpt-63',
- id=63,
- color=[255, 255, 255],
- type='',
- swap='kpt-61'),
- 64:
- dict(
- name='kpt-64',
- id=64,
- color=[255, 255, 255],
- type='',
- swap='kpt-60'),
- 65:
- dict(
- name='kpt-65',
- id=65,
- color=[255, 255, 255],
- type='',
- swap='kpt-67'),
- 66:
- dict(name='kpt-66', id=66, color=[255, 255, 255], type='', swap=''),
- 67:
- dict(
- name='kpt-67',
- id=67,
- color=[255, 255, 255],
- type='',
- swap='kpt-65'),
+ dict(name='kpt-40', id=40, color=[255, 0, 0], type='', swap='kpt-47'),
+ 41: dict(
+ name='kpt-41', id=41, color=[255, 0, 0], type='', swap='kpt-46'),
+ 42: dict(
+ name='kpt-42', id=42, color=[255, 0, 0], type='', swap='kpt-39'),
+ 43: dict(
+ name='kpt-43', id=43, color=[255, 0, 0], type='', swap='kpt-38'),
+ 44: dict(
+ name='kpt-44', id=44, color=[255, 0, 0], type='', swap='kpt-37'),
+ 45: dict(
+ name='kpt-45', id=45, color=[255, 0, 0], type='', swap='kpt-36'),
+ 46: dict(
+ name='kpt-46', id=46, color=[255, 0, 0], type='', swap='kpt-41'),
+ 47: dict(
+ name='kpt-47', id=47, color=[255, 0, 0], type='', swap='kpt-40'),
+ 48: dict(
+ name='kpt-48', id=48, color=[255, 0, 0], type='', swap='kpt-54'),
+ 49: dict(
+ name='kpt-49', id=49, color=[255, 0, 0], type='', swap='kpt-53'),
+ 50: dict(
+ name='kpt-50', id=50, color=[255, 0, 0], type='', swap='kpt-52'),
+ 51: dict(name='kpt-51', id=51, color=[255, 0, 0], type='', swap=''),
+ 52: dict(
+ name='kpt-52', id=52, color=[255, 0, 0], type='', swap='kpt-50'),
+ 53: dict(
+ name='kpt-53', id=53, color=[255, 0, 0], type='', swap='kpt-49'),
+ 54: dict(
+ name='kpt-54', id=54, color=[255, 0, 0], type='', swap='kpt-48'),
+ 55: dict(
+ name='kpt-55', id=55, color=[255, 0, 0], type='', swap='kpt-59'),
+ 56: dict(
+ name='kpt-56', id=56, color=[255, 0, 0], type='', swap='kpt-58'),
+ 57: dict(name='kpt-57', id=57, color=[255, 0, 0], type='', swap=''),
+ 58: dict(
+ name='kpt-58', id=58, color=[255, 0, 0], type='', swap='kpt-56'),
+ 59: dict(
+ name='kpt-59', id=59, color=[255, 0, 0], type='', swap='kpt-55'),
+ 60: dict(
+ name='kpt-60', id=60, color=[255, 0, 0], type='', swap='kpt-64'),
+ 61: dict(
+ name='kpt-61', id=61, color=[255, 0, 0], type='', swap='kpt-63'),
+ 62: dict(name='kpt-62', id=62, color=[255, 0, 0], type='', swap=''),
+ 63: dict(
+ name='kpt-63', id=63, color=[255, 0, 0], type='', swap='kpt-61'),
+ 64: dict(
+ name='kpt-64', id=64, color=[255, 0, 0], type='', swap='kpt-60'),
+ 65: dict(
+ name='kpt-65', id=65, color=[255, 0, 0], type='', swap='kpt-67'),
+ 66: dict(name='kpt-66', id=66, color=[255, 0, 0], type='', swap=''),
+ 67: dict(
+ name='kpt-67', id=67, color=[255, 0, 0], type='', swap='kpt-65'),
},
skeleton_info={},
joint_weights=[1.] * 68,
diff --git a/configs/_base_/datasets/aflw.py b/configs/_base_/datasets/aflw.py
index bf534cbb75..cf5e10964d 100644
--- a/configs/_base_/datasets/aflw.py
+++ b/configs/_base_/datasets/aflw.py
@@ -13,70 +13,31 @@
'team-bischof/lrs/downloads/aflw/',
),
keypoint_info={
- 0:
- dict(name='kpt-0', id=0, color=[255, 255, 255], type='', swap='kpt-5'),
- 1:
- dict(name='kpt-1', id=1, color=[255, 255, 255], type='', swap='kpt-4'),
- 2:
- dict(name='kpt-2', id=2, color=[255, 255, 255], type='', swap='kpt-3'),
- 3:
- dict(name='kpt-3', id=3, color=[255, 255, 255], type='', swap='kpt-2'),
- 4:
- dict(name='kpt-4', id=4, color=[255, 255, 255], type='', swap='kpt-1'),
- 5:
- dict(name='kpt-5', id=5, color=[255, 255, 255], type='', swap='kpt-0'),
- 6:
- dict(
- name='kpt-6', id=6, color=[255, 255, 255], type='', swap='kpt-11'),
- 7:
- dict(
- name='kpt-7', id=7, color=[255, 255, 255], type='', swap='kpt-10'),
- 8:
- dict(name='kpt-8', id=8, color=[255, 255, 255], type='', swap='kpt-9'),
- 9:
- dict(name='kpt-9', id=9, color=[255, 255, 255], type='', swap='kpt-8'),
+ 0: dict(name='kpt-0', id=0, color=[255, 0, 0], type='', swap='kpt-5'),
+ 1: dict(name='kpt-1', id=1, color=[255, 0, 0], type='', swap='kpt-4'),
+ 2: dict(name='kpt-2', id=2, color=[255, 0, 0], type='', swap='kpt-3'),
+ 3: dict(name='kpt-3', id=3, color=[255, 0, 0], type='', swap='kpt-2'),
+ 4: dict(name='kpt-4', id=4, color=[255, 0, 0], type='', swap='kpt-1'),
+ 5: dict(name='kpt-5', id=5, color=[255, 0, 0], type='', swap='kpt-0'),
+ 6: dict(name='kpt-6', id=6, color=[255, 0, 0], type='', swap='kpt-11'),
+ 7: dict(name='kpt-7', id=7, color=[255, 0, 0], type='', swap='kpt-10'),
+ 8: dict(name='kpt-8', id=8, color=[255, 0, 0], type='', swap='kpt-9'),
+ 9: dict(name='kpt-9', id=9, color=[255, 0, 0], type='', swap='kpt-8'),
10:
- dict(
- name='kpt-10', id=10, color=[255, 255, 255], type='',
- swap='kpt-7'),
+ dict(name='kpt-10', id=10, color=[255, 0, 0], type='', swap='kpt-7'),
11:
- dict(
- name='kpt-11', id=11, color=[255, 255, 255], type='',
- swap='kpt-6'),
+ dict(name='kpt-11', id=11, color=[255, 0, 0], type='', swap='kpt-6'),
12:
- dict(
- name='kpt-12',
- id=12,
- color=[255, 255, 255],
- type='',
- swap='kpt-14'),
- 13:
- dict(name='kpt-13', id=13, color=[255, 255, 255], type='', swap=''),
+ dict(name='kpt-12', id=12, color=[255, 0, 0], type='', swap='kpt-14'),
+ 13: dict(name='kpt-13', id=13, color=[255, 0, 0], type='', swap=''),
14:
- dict(
- name='kpt-14',
- id=14,
- color=[255, 255, 255],
- type='',
- swap='kpt-12'),
+ dict(name='kpt-14', id=14, color=[255, 0, 0], type='', swap='kpt-12'),
15:
- dict(
- name='kpt-15',
- id=15,
- color=[255, 255, 255],
- type='',
- swap='kpt-17'),
- 16:
- dict(name='kpt-16', id=16, color=[255, 255, 255], type='', swap=''),
+ dict(name='kpt-15', id=15, color=[255, 0, 0], type='', swap='kpt-17'),
+ 16: dict(name='kpt-16', id=16, color=[255, 0, 0], type='', swap=''),
17:
- dict(
- name='kpt-17',
- id=17,
- color=[255, 255, 255],
- type='',
- swap='kpt-15'),
- 18:
- dict(name='kpt-18', id=18, color=[255, 255, 255], type='', swap='')
+ dict(name='kpt-17', id=17, color=[255, 0, 0], type='', swap='kpt-15'),
+ 18: dict(name='kpt-18', id=18, color=[255, 0, 0], type='', swap='')
},
skeleton_info={},
joint_weights=[1.] * 19,
diff --git a/configs/_base_/datasets/coco_wholebody_face.py b/configs/_base_/datasets/coco_wholebody_face.py
index 7c9ee3350e..a3fe1e5b33 100644
--- a/configs/_base_/datasets/coco_wholebody_face.py
+++ b/configs/_base_/datasets/coco_wholebody_face.py
@@ -12,425 +12,131 @@
),
keypoint_info={
0:
- dict(
- name='face-0',
- id=0,
- color=[255, 255, 255],
- type='',
- swap='face-16'),
+ dict(name='face-0', id=0, color=[255, 0, 0], type='', swap='face-16'),
1:
- dict(
- name='face-1',
- id=1,
- color=[255, 255, 255],
- type='',
- swap='face-15'),
+ dict(name='face-1', id=1, color=[255, 0, 0], type='', swap='face-15'),
2:
- dict(
- name='face-2',
- id=2,
- color=[255, 255, 255],
- type='',
- swap='face-14'),
+ dict(name='face-2', id=2, color=[255, 0, 0], type='', swap='face-14'),
3:
- dict(
- name='face-3',
- id=3,
- color=[255, 255, 255],
- type='',
- swap='face-13'),
+ dict(name='face-3', id=3, color=[255, 0, 0], type='', swap='face-13'),
4:
- dict(
- name='face-4',
- id=4,
- color=[255, 255, 255],
- type='',
- swap='face-12'),
+ dict(name='face-4', id=4, color=[255, 0, 0], type='', swap='face-12'),
5:
- dict(
- name='face-5',
- id=5,
- color=[255, 255, 255],
- type='',
- swap='face-11'),
+ dict(name='face-5', id=5, color=[255, 0, 0], type='', swap='face-11'),
6:
- dict(
- name='face-6',
- id=6,
- color=[255, 255, 255],
- type='',
- swap='face-10'),
+ dict(name='face-6', id=6, color=[255, 0, 0], type='', swap='face-10'),
7:
- dict(
- name='face-7', id=7, color=[255, 255, 255], type='',
- swap='face-9'),
- 8:
- dict(name='face-8', id=8, color=[255, 255, 255], type='', swap=''),
+ dict(name='face-7', id=7, color=[255, 0, 0], type='', swap='face-9'),
+ 8: dict(name='face-8', id=8, color=[255, 0, 0], type='', swap=''),
9:
- dict(
- name='face-9', id=9, color=[255, 255, 255], type='',
- swap='face-7'),
+ dict(name='face-9', id=9, color=[255, 0, 0], type='', swap='face-7'),
10:
- dict(
- name='face-10',
- id=10,
- color=[255, 255, 255],
- type='',
- swap='face-6'),
+ dict(name='face-10', id=10, color=[255, 0, 0], type='', swap='face-6'),
11:
- dict(
- name='face-11',
- id=11,
- color=[255, 255, 255],
- type='',
- swap='face-5'),
+ dict(name='face-11', id=11, color=[255, 0, 0], type='', swap='face-5'),
12:
- dict(
- name='face-12',
- id=12,
- color=[255, 255, 255],
- type='',
- swap='face-4'),
+ dict(name='face-12', id=12, color=[255, 0, 0], type='', swap='face-4'),
13:
- dict(
- name='face-13',
- id=13,
- color=[255, 255, 255],
- type='',
- swap='face-3'),
+ dict(name='face-13', id=13, color=[255, 0, 0], type='', swap='face-3'),
14:
- dict(
- name='face-14',
- id=14,
- color=[255, 255, 255],
- type='',
- swap='face-2'),
+ dict(name='face-14', id=14, color=[255, 0, 0], type='', swap='face-2'),
15:
- dict(
- name='face-15',
- id=15,
- color=[255, 255, 255],
- type='',
- swap='face-1'),
+ dict(name='face-15', id=15, color=[255, 0, 0], type='', swap='face-1'),
16:
- dict(
- name='face-16',
- id=16,
- color=[255, 255, 255],
- type='',
- swap='face-0'),
- 17:
- dict(
- name='face-17',
- id=17,
- color=[255, 255, 255],
- type='',
- swap='face-26'),
- 18:
- dict(
- name='face-18',
- id=18,
- color=[255, 255, 255],
- type='',
- swap='face-25'),
- 19:
- dict(
- name='face-19',
- id=19,
- color=[255, 255, 255],
- type='',
- swap='face-24'),
- 20:
- dict(
- name='face-20',
- id=20,
- color=[255, 255, 255],
- type='',
- swap='face-23'),
- 21:
- dict(
- name='face-21',
- id=21,
- color=[255, 255, 255],
- type='',
- swap='face-22'),
- 22:
- dict(
- name='face-22',
- id=22,
- color=[255, 255, 255],
- type='',
- swap='face-21'),
- 23:
- dict(
- name='face-23',
- id=23,
- color=[255, 255, 255],
- type='',
- swap='face-20'),
- 24:
- dict(
- name='face-24',
- id=24,
- color=[255, 255, 255],
- type='',
- swap='face-19'),
- 25:
- dict(
- name='face-25',
- id=25,
- color=[255, 255, 255],
- type='',
- swap='face-18'),
- 26:
- dict(
- name='face-26',
- id=26,
- color=[255, 255, 255],
- type='',
- swap='face-17'),
- 27:
- dict(name='face-27', id=27, color=[255, 255, 255], type='', swap=''),
- 28:
- dict(name='face-28', id=28, color=[255, 255, 255], type='', swap=''),
- 29:
- dict(name='face-29', id=29, color=[255, 255, 255], type='', swap=''),
- 30:
- dict(name='face-30', id=30, color=[255, 255, 255], type='', swap=''),
- 31:
- dict(
- name='face-31',
- id=31,
- color=[255, 255, 255],
- type='',
- swap='face-35'),
- 32:
- dict(
- name='face-32',
- id=32,
- color=[255, 255, 255],
- type='',
- swap='face-34'),
- 33:
- dict(name='face-33', id=33, color=[255, 255, 255], type='', swap=''),
- 34:
- dict(
- name='face-34',
- id=34,
- color=[255, 255, 255],
- type='',
- swap='face-32'),
- 35:
- dict(
- name='face-35',
- id=35,
- color=[255, 255, 255],
- type='',
- swap='face-31'),
- 36:
- dict(
- name='face-36',
- id=36,
- color=[255, 255, 255],
- type='',
- swap='face-45'),
- 37:
- dict(
- name='face-37',
- id=37,
- color=[255, 255, 255],
- type='',
- swap='face-44'),
- 38:
- dict(
- name='face-38',
- id=38,
- color=[255, 255, 255],
- type='',
- swap='face-43'),
- 39:
- dict(
- name='face-39',
- id=39,
- color=[255, 255, 255],
- type='',
- swap='face-42'),
- 40:
- dict(
- name='face-40',
- id=40,
- color=[255, 255, 255],
- type='',
- swap='face-47'),
- 41:
- dict(
- name='face-41',
- id=41,
- color=[255, 255, 255],
- type='',
- swap='face-46'),
- 42:
- dict(
- name='face-42',
- id=42,
- color=[255, 255, 255],
- type='',
- swap='face-39'),
- 43:
- dict(
- name='face-43',
- id=43,
- color=[255, 255, 255],
- type='',
- swap='face-38'),
- 44:
- dict(
- name='face-44',
- id=44,
- color=[255, 255, 255],
- type='',
- swap='face-37'),
- 45:
- dict(
- name='face-45',
- id=45,
- color=[255, 255, 255],
- type='',
- swap='face-36'),
- 46:
- dict(
- name='face-46',
- id=46,
- color=[255, 255, 255],
- type='',
- swap='face-41'),
- 47:
- dict(
- name='face-47',
- id=47,
- color=[255, 255, 255],
- type='',
- swap='face-40'),
- 48:
- dict(
- name='face-48',
- id=48,
- color=[255, 255, 255],
- type='',
- swap='face-54'),
- 49:
- dict(
- name='face-49',
- id=49,
- color=[255, 255, 255],
- type='',
- swap='face-53'),
- 50:
- dict(
- name='face-50',
- id=50,
- color=[255, 255, 255],
- type='',
- swap='face-52'),
- 51:
- dict(name='face-51', id=52, color=[255, 255, 255], type='', swap=''),
- 52:
- dict(
- name='face-52',
- id=52,
- color=[255, 255, 255],
- type='',
- swap='face-50'),
- 53:
- dict(
- name='face-53',
- id=53,
- color=[255, 255, 255],
- type='',
- swap='face-49'),
- 54:
- dict(
- name='face-54',
- id=54,
- color=[255, 255, 255],
- type='',
- swap='face-48'),
- 55:
- dict(
- name='face-55',
- id=55,
- color=[255, 255, 255],
- type='',
- swap='face-59'),
- 56:
- dict(
- name='face-56',
- id=56,
- color=[255, 255, 255],
- type='',
- swap='face-58'),
- 57:
- dict(name='face-57', id=57, color=[255, 255, 255], type='', swap=''),
- 58:
- dict(
- name='face-58',
- id=58,
- color=[255, 255, 255],
- type='',
- swap='face-56'),
- 59:
- dict(
- name='face-59',
- id=59,
- color=[255, 255, 255],
- type='',
- swap='face-55'),
- 60:
- dict(
- name='face-60',
- id=60,
- color=[255, 255, 255],
- type='',
- swap='face-64'),
- 61:
- dict(
- name='face-61',
- id=61,
- color=[255, 255, 255],
- type='',
- swap='face-63'),
- 62:
- dict(name='face-62', id=62, color=[255, 255, 255], type='', swap=''),
- 63:
- dict(
- name='face-63',
- id=63,
- color=[255, 255, 255],
- type='',
- swap='face-61'),
- 64:
- dict(
- name='face-64',
- id=64,
- color=[255, 255, 255],
- type='',
- swap='face-60'),
- 65:
- dict(
- name='face-65',
- id=65,
- color=[255, 255, 255],
- type='',
- swap='face-67'),
- 66:
- dict(name='face-66', id=66, color=[255, 255, 255], type='', swap=''),
- 67:
- dict(
- name='face-67',
- id=67,
- color=[255, 255, 255],
- type='',
- swap='face-65')
+ dict(name='face-16', id=16, color=[255, 0, 0], type='', swap='face-0'),
+ 17: dict(
+ name='face-17', id=17, color=[255, 0, 0], type='', swap='face-26'),
+ 18: dict(
+ name='face-18', id=18, color=[255, 0, 0], type='', swap='face-25'),
+ 19: dict(
+ name='face-19', id=19, color=[255, 0, 0], type='', swap='face-24'),
+ 20: dict(
+ name='face-20', id=20, color=[255, 0, 0], type='', swap='face-23'),
+ 21: dict(
+ name='face-21', id=21, color=[255, 0, 0], type='', swap='face-22'),
+ 22: dict(
+ name='face-22', id=22, color=[255, 0, 0], type='', swap='face-21'),
+ 23: dict(
+ name='face-23', id=23, color=[255, 0, 0], type='', swap='face-20'),
+ 24: dict(
+ name='face-24', id=24, color=[255, 0, 0], type='', swap='face-19'),
+ 25: dict(
+ name='face-25', id=25, color=[255, 0, 0], type='', swap='face-18'),
+ 26: dict(
+ name='face-26', id=26, color=[255, 0, 0], type='', swap='face-17'),
+ 27: dict(name='face-27', id=27, color=[255, 0, 0], type='', swap=''),
+ 28: dict(name='face-28', id=28, color=[255, 0, 0], type='', swap=''),
+ 29: dict(name='face-29', id=29, color=[255, 0, 0], type='', swap=''),
+ 30: dict(name='face-30', id=30, color=[255, 0, 0], type='', swap=''),
+ 31: dict(
+ name='face-31', id=31, color=[255, 0, 0], type='', swap='face-35'),
+ 32: dict(
+ name='face-32', id=32, color=[255, 0, 0], type='', swap='face-34'),
+ 33: dict(name='face-33', id=33, color=[255, 0, 0], type='', swap=''),
+ 34: dict(
+ name='face-34', id=34, color=[255, 0, 0], type='', swap='face-32'),
+ 35: dict(
+ name='face-35', id=35, color=[255, 0, 0], type='', swap='face-31'),
+ 36: dict(
+ name='face-36', id=36, color=[255, 0, 0], type='', swap='face-45'),
+ 37: dict(
+ name='face-37', id=37, color=[255, 0, 0], type='', swap='face-44'),
+ 38: dict(
+ name='face-38', id=38, color=[255, 0, 0], type='', swap='face-43'),
+ 39: dict(
+ name='face-39', id=39, color=[255, 0, 0], type='', swap='face-42'),
+ 40: dict(
+ name='face-40', id=40, color=[255, 0, 0], type='', swap='face-47'),
+ 41: dict(
+ name='face-41', id=41, color=[255, 0, 0], type='', swap='face-46'),
+ 42: dict(
+ name='face-42', id=42, color=[255, 0, 0], type='', swap='face-39'),
+ 43: dict(
+ name='face-43', id=43, color=[255, 0, 0], type='', swap='face-38'),
+ 44: dict(
+ name='face-44', id=44, color=[255, 0, 0], type='', swap='face-37'),
+ 45: dict(
+ name='face-45', id=45, color=[255, 0, 0], type='', swap='face-36'),
+ 46: dict(
+ name='face-46', id=46, color=[255, 0, 0], type='', swap='face-41'),
+ 47: dict(
+ name='face-47', id=47, color=[255, 0, 0], type='', swap='face-40'),
+ 48: dict(
+ name='face-48', id=48, color=[255, 0, 0], type='', swap='face-54'),
+ 49: dict(
+ name='face-49', id=49, color=[255, 0, 0], type='', swap='face-53'),
+ 50: dict(
+ name='face-50', id=50, color=[255, 0, 0], type='', swap='face-52'),
+ 51: dict(name='face-51', id=52, color=[255, 0, 0], type='', swap=''),
+ 52: dict(
+ name='face-52', id=52, color=[255, 0, 0], type='', swap='face-50'),
+ 53: dict(
+ name='face-53', id=53, color=[255, 0, 0], type='', swap='face-49'),
+ 54: dict(
+ name='face-54', id=54, color=[255, 0, 0], type='', swap='face-48'),
+ 55: dict(
+ name='face-55', id=55, color=[255, 0, 0], type='', swap='face-59'),
+ 56: dict(
+ name='face-56', id=56, color=[255, 0, 0], type='', swap='face-58'),
+ 57: dict(name='face-57', id=57, color=[255, 0, 0], type='', swap=''),
+ 58: dict(
+ name='face-58', id=58, color=[255, 0, 0], type='', swap='face-56'),
+ 59: dict(
+ name='face-59', id=59, color=[255, 0, 0], type='', swap='face-55'),
+ 60: dict(
+ name='face-60', id=60, color=[255, 0, 0], type='', swap='face-64'),
+ 61: dict(
+ name='face-61', id=61, color=[255, 0, 0], type='', swap='face-63'),
+ 62: dict(name='face-62', id=62, color=[255, 0, 0], type='', swap=''),
+ 63: dict(
+ name='face-63', id=63, color=[255, 0, 0], type='', swap='face-61'),
+ 64: dict(
+ name='face-64', id=64, color=[255, 0, 0], type='', swap='face-60'),
+ 65: dict(
+ name='face-65', id=65, color=[255, 0, 0], type='', swap='face-67'),
+ 66: dict(name='face-66', id=66, color=[255, 0, 0], type='', swap=''),
+ 67: dict(
+ name='face-67', id=67, color=[255, 0, 0], type='', swap='face-65')
},
skeleton_info={},
joint_weights=[1.] * 68,
diff --git a/configs/_base_/datasets/cofw.py b/configs/_base_/datasets/cofw.py
index 2fb7ad2f8d..d528bf2f2f 100644
--- a/configs/_base_/datasets/cofw.py
+++ b/configs/_base_/datasets/cofw.py
@@ -10,124 +10,47 @@
homepage='http://www.vision.caltech.edu/xpburgos/ICCV13/',
),
keypoint_info={
- 0:
- dict(name='kpt-0', id=0, color=[255, 255, 255], type='', swap='kpt-1'),
- 1:
- dict(name='kpt-1', id=1, color=[255, 255, 255], type='', swap='kpt-0'),
- 2:
- dict(name='kpt-2', id=2, color=[255, 255, 255], type='', swap='kpt-3'),
- 3:
- dict(name='kpt-3', id=3, color=[255, 255, 255], type='', swap='kpt-2'),
- 4:
- dict(name='kpt-4', id=4, color=[255, 255, 255], type='', swap='kpt-6'),
- 5:
- dict(name='kpt-5', id=5, color=[255, 255, 255], type='', swap='kpt-7'),
- 6:
- dict(name='kpt-6', id=6, color=[255, 255, 255], type='', swap='kpt-4'),
- 7:
- dict(name='kpt-7', id=7, color=[255, 255, 255], type='', swap='kpt-5'),
- 8:
- dict(name='kpt-8', id=8, color=[255, 255, 255], type='', swap='kpt-9'),
- 9:
- dict(name='kpt-9', id=9, color=[255, 255, 255], type='', swap='kpt-8'),
+ 0: dict(name='kpt-0', id=0, color=[255, 0, 0], type='', swap='kpt-1'),
+ 1: dict(name='kpt-1', id=1, color=[255, 0, 0], type='', swap='kpt-0'),
+ 2: dict(name='kpt-2', id=2, color=[255, 0, 0], type='', swap='kpt-3'),
+ 3: dict(name='kpt-3', id=3, color=[255, 0, 0], type='', swap='kpt-2'),
+ 4: dict(name='kpt-4', id=4, color=[255, 0, 0], type='', swap='kpt-6'),
+ 5: dict(name='kpt-5', id=5, color=[255, 0, 0], type='', swap='kpt-7'),
+ 6: dict(name='kpt-6', id=6, color=[255, 0, 0], type='', swap='kpt-4'),
+ 7: dict(name='kpt-7', id=7, color=[255, 0, 0], type='', swap='kpt-5'),
+ 8: dict(name='kpt-8', id=8, color=[255, 0, 0], type='', swap='kpt-9'),
+ 9: dict(name='kpt-9', id=9, color=[255, 0, 0], type='', swap='kpt-8'),
10:
- dict(
- name='kpt-10',
- id=10,
- color=[255, 255, 255],
- type='',
- swap='kpt-11'),
+ dict(name='kpt-10', id=10, color=[255, 0, 0], type='', swap='kpt-11'),
11:
- dict(
- name='kpt-11',
- id=11,
- color=[255, 255, 255],
- type='',
- swap='kpt-10'),
+ dict(name='kpt-11', id=11, color=[255, 0, 0], type='', swap='kpt-10'),
12:
- dict(
- name='kpt-12',
- id=12,
- color=[255, 255, 255],
- type='',
- swap='kpt-14'),
+ dict(name='kpt-12', id=12, color=[255, 0, 0], type='', swap='kpt-14'),
13:
- dict(
- name='kpt-13',
- id=13,
- color=[255, 255, 255],
- type='',
- swap='kpt-15'),
+ dict(name='kpt-13', id=13, color=[255, 0, 0], type='', swap='kpt-15'),
14:
- dict(
- name='kpt-14',
- id=14,
- color=[255, 255, 255],
- type='',
- swap='kpt-12'),
+ dict(name='kpt-14', id=14, color=[255, 0, 0], type='', swap='kpt-12'),
15:
- dict(
- name='kpt-15',
- id=15,
- color=[255, 255, 255],
- type='',
- swap='kpt-13'),
+ dict(name='kpt-15', id=15, color=[255, 0, 0], type='', swap='kpt-13'),
16:
- dict(
- name='kpt-16',
- id=16,
- color=[255, 255, 255],
- type='',
- swap='kpt-17'),
+ dict(name='kpt-16', id=16, color=[255, 0, 0], type='', swap='kpt-17'),
17:
- dict(
- name='kpt-17',
- id=17,
- color=[255, 255, 255],
- type='',
- swap='kpt-16'),
+ dict(name='kpt-17', id=17, color=[255, 0, 0], type='', swap='kpt-16'),
18:
- dict(
- name='kpt-18',
- id=18,
- color=[255, 255, 255],
- type='',
- swap='kpt-19'),
+ dict(name='kpt-18', id=18, color=[255, 0, 0], type='', swap='kpt-19'),
19:
- dict(
- name='kpt-19',
- id=19,
- color=[255, 255, 255],
- type='',
- swap='kpt-18'),
- 20:
- dict(name='kpt-20', id=20, color=[255, 255, 255], type='', swap=''),
- 21:
- dict(name='kpt-21', id=21, color=[255, 255, 255], type='', swap=''),
+ dict(name='kpt-19', id=19, color=[255, 0, 0], type='', swap='kpt-18'),
+ 20: dict(name='kpt-20', id=20, color=[255, 0, 0], type='', swap=''),
+ 21: dict(name='kpt-21', id=21, color=[255, 0, 0], type='', swap=''),
22:
- dict(
- name='kpt-22',
- id=22,
- color=[255, 255, 255],
- type='',
- swap='kpt-23'),
+ dict(name='kpt-22', id=22, color=[255, 0, 0], type='', swap='kpt-23'),
23:
- dict(
- name='kpt-23',
- id=23,
- color=[255, 255, 255],
- type='',
- swap='kpt-22'),
- 24:
- dict(name='kpt-24', id=24, color=[255, 255, 255], type='', swap=''),
- 25:
- dict(name='kpt-25', id=25, color=[255, 255, 255], type='', swap=''),
- 26:
- dict(name='kpt-26', id=26, color=[255, 255, 255], type='', swap=''),
- 27:
- dict(name='kpt-27', id=27, color=[255, 255, 255], type='', swap=''),
- 28:
- dict(name='kpt-28', id=28, color=[255, 255, 255], type='', swap='')
+ dict(name='kpt-23', id=23, color=[255, 0, 0], type='', swap='kpt-22'),
+ 24: dict(name='kpt-24', id=24, color=[255, 0, 0], type='', swap=''),
+ 25: dict(name='kpt-25', id=25, color=[255, 0, 0], type='', swap=''),
+ 26: dict(name='kpt-26', id=26, color=[255, 0, 0], type='', swap=''),
+ 27: dict(name='kpt-27', id=27, color=[255, 0, 0], type='', swap=''),
+ 28: dict(name='kpt-28', id=28, color=[255, 0, 0], type='', swap='')
},
skeleton_info={},
joint_weights=[1.] * 29,
diff --git a/configs/_base_/datasets/wflw.py b/configs/_base_/datasets/wflw.py
index bed6f56f30..80c29b696c 100644
--- a/configs/_base_/datasets/wflw.py
+++ b/configs/_base_/datasets/wflw.py
@@ -10,572 +10,182 @@
homepage='https://wywu.github.io/projects/LAB/WFLW.html',
),
keypoint_info={
- 0:
- dict(
- name='kpt-0', id=0, color=[255, 255, 255], type='', swap='kpt-32'),
- 1:
- dict(
- name='kpt-1', id=1, color=[255, 255, 255], type='', swap='kpt-31'),
- 2:
- dict(
- name='kpt-2', id=2, color=[255, 255, 255], type='', swap='kpt-30'),
- 3:
- dict(
- name='kpt-3', id=3, color=[255, 255, 255], type='', swap='kpt-29'),
- 4:
- dict(
- name='kpt-4', id=4, color=[255, 255, 255], type='', swap='kpt-28'),
- 5:
- dict(
- name='kpt-5', id=5, color=[255, 255, 255], type='', swap='kpt-27'),
- 6:
- dict(
- name='kpt-6', id=6, color=[255, 255, 255], type='', swap='kpt-26'),
- 7:
- dict(
- name='kpt-7', id=7, color=[255, 255, 255], type='', swap='kpt-25'),
- 8:
- dict(
- name='kpt-8', id=8, color=[255, 255, 255], type='', swap='kpt-24'),
- 9:
- dict(
- name='kpt-9', id=9, color=[255, 255, 255], type='', swap='kpt-23'),
+ 0: dict(name='kpt-0', id=0, color=[255, 0, 0], type='', swap='kpt-32'),
+ 1: dict(name='kpt-1', id=1, color=[255, 0, 0], type='', swap='kpt-31'),
+ 2: dict(name='kpt-2', id=2, color=[255, 0, 0], type='', swap='kpt-30'),
+ 3: dict(name='kpt-3', id=3, color=[255, 0, 0], type='', swap='kpt-29'),
+ 4: dict(name='kpt-4', id=4, color=[255, 0, 0], type='', swap='kpt-28'),
+ 5: dict(name='kpt-5', id=5, color=[255, 0, 0], type='', swap='kpt-27'),
+ 6: dict(name='kpt-6', id=6, color=[255, 0, 0], type='', swap='kpt-26'),
+ 7: dict(name='kpt-7', id=7, color=[255, 0, 0], type='', swap='kpt-25'),
+ 8: dict(name='kpt-8', id=8, color=[255, 0, 0], type='', swap='kpt-24'),
+ 9: dict(name='kpt-9', id=9, color=[255, 0, 0], type='', swap='kpt-23'),
10:
- dict(
- name='kpt-10',
- id=10,
- color=[255, 255, 255],
- type='',
- swap='kpt-22'),
+ dict(name='kpt-10', id=10, color=[255, 0, 0], type='', swap='kpt-22'),
11:
- dict(
- name='kpt-11',
- id=11,
- color=[255, 255, 255],
- type='',
- swap='kpt-21'),
+ dict(name='kpt-11', id=11, color=[255, 0, 0], type='', swap='kpt-21'),
12:
- dict(
- name='kpt-12',
- id=12,
- color=[255, 255, 255],
- type='',
- swap='kpt-20'),
+ dict(name='kpt-12', id=12, color=[255, 0, 0], type='', swap='kpt-20'),
13:
- dict(
- name='kpt-13',
- id=13,
- color=[255, 255, 255],
- type='',
- swap='kpt-19'),
+ dict(name='kpt-13', id=13, color=[255, 0, 0], type='', swap='kpt-19'),
14:
- dict(
- name='kpt-14',
- id=14,
- color=[255, 255, 255],
- type='',
- swap='kpt-18'),
+ dict(name='kpt-14', id=14, color=[255, 0, 0], type='', swap='kpt-18'),
15:
- dict(
- name='kpt-15',
- id=15,
- color=[255, 255, 255],
- type='',
- swap='kpt-17'),
- 16:
- dict(name='kpt-16', id=16, color=[255, 255, 255], type='', swap=''),
+ dict(name='kpt-15', id=15, color=[255, 0, 0], type='', swap='kpt-17'),
+ 16: dict(name='kpt-16', id=16, color=[255, 0, 0], type='', swap=''),
17:
- dict(
- name='kpt-17',
- id=17,
- color=[255, 255, 255],
- type='',
- swap='kpt-15'),
+ dict(name='kpt-17', id=17, color=[255, 0, 0], type='', swap='kpt-15'),
18:
- dict(
- name='kpt-18',
- id=18,
- color=[255, 255, 255],
- type='',
- swap='kpt-14'),
+ dict(name='kpt-18', id=18, color=[255, 0, 0], type='', swap='kpt-14'),
19:
- dict(
- name='kpt-19',
- id=19,
- color=[255, 255, 255],
- type='',
- swap='kpt-13'),
+ dict(name='kpt-19', id=19, color=[255, 0, 0], type='', swap='kpt-13'),
20:
- dict(
- name='kpt-20',
- id=20,
- color=[255, 255, 255],
- type='',
- swap='kpt-12'),
+ dict(name='kpt-20', id=20, color=[255, 0, 0], type='', swap='kpt-12'),
21:
- dict(
- name='kpt-21',
- id=21,
- color=[255, 255, 255],
- type='',
- swap='kpt-11'),
+ dict(name='kpt-21', id=21, color=[255, 0, 0], type='', swap='kpt-11'),
22:
- dict(
- name='kpt-22',
- id=22,
- color=[255, 255, 255],
- type='',
- swap='kpt-10'),
+ dict(name='kpt-22', id=22, color=[255, 0, 0], type='', swap='kpt-10'),
23:
- dict(
- name='kpt-23', id=23, color=[255, 255, 255], type='',
- swap='kpt-9'),
+ dict(name='kpt-23', id=23, color=[255, 0, 0], type='', swap='kpt-9'),
24:
- dict(
- name='kpt-24', id=24, color=[255, 255, 255], type='',
- swap='kpt-8'),
+ dict(name='kpt-24', id=24, color=[255, 0, 0], type='', swap='kpt-8'),
25:
- dict(
- name='kpt-25', id=25, color=[255, 255, 255], type='',
- swap='kpt-7'),
+ dict(name='kpt-25', id=25, color=[255, 0, 0], type='', swap='kpt-7'),
26:
- dict(
- name='kpt-26', id=26, color=[255, 255, 255], type='',
- swap='kpt-6'),
+ dict(name='kpt-26', id=26, color=[255, 0, 0], type='', swap='kpt-6'),
27:
- dict(
- name='kpt-27', id=27, color=[255, 255, 255], type='',
- swap='kpt-5'),
+ dict(name='kpt-27', id=27, color=[255, 0, 0], type='', swap='kpt-5'),
28:
- dict(
- name='kpt-28', id=28, color=[255, 255, 255], type='',
- swap='kpt-4'),
+ dict(name='kpt-28', id=28, color=[255, 0, 0], type='', swap='kpt-4'),
29:
- dict(
- name='kpt-29', id=29, color=[255, 255, 255], type='',
- swap='kpt-3'),
+ dict(name='kpt-29', id=29, color=[255, 0, 0], type='', swap='kpt-3'),
30:
- dict(
- name='kpt-30', id=30, color=[255, 255, 255], type='',
- swap='kpt-2'),
+ dict(name='kpt-30', id=30, color=[255, 0, 0], type='', swap='kpt-2'),
31:
- dict(
- name='kpt-31', id=31, color=[255, 255, 255], type='',
- swap='kpt-1'),
+ dict(name='kpt-31', id=31, color=[255, 0, 0], type='', swap='kpt-1'),
32:
- dict(
- name='kpt-32', id=32, color=[255, 255, 255], type='',
- swap='kpt-0'),
+ dict(name='kpt-32', id=32, color=[255, 0, 0], type='', swap='kpt-0'),
33:
- dict(
- name='kpt-33',
- id=33,
- color=[255, 255, 255],
- type='',
- swap='kpt-46'),
+ dict(name='kpt-33', id=33, color=[255, 0, 0], type='', swap='kpt-46'),
34:
- dict(
- name='kpt-34',
- id=34,
- color=[255, 255, 255],
- type='',
- swap='kpt-45'),
+ dict(name='kpt-34', id=34, color=[255, 0, 0], type='', swap='kpt-45'),
35:
- dict(
- name='kpt-35',
- id=35,
- color=[255, 255, 255],
- type='',
- swap='kpt-44'),
+ dict(name='kpt-35', id=35, color=[255, 0, 0], type='', swap='kpt-44'),
36:
- dict(
- name='kpt-36',
- id=36,
- color=[255, 255, 255],
- type='',
- swap='kpt-43'),
- 37:
- dict(
- name='kpt-37',
- id=37,
- color=[255, 255, 255],
- type='',
- swap='kpt-42'),
- 38:
- dict(
- name='kpt-38',
- id=38,
- color=[255, 255, 255],
- type='',
- swap='kpt-50'),
- 39:
- dict(
- name='kpt-39',
- id=39,
- color=[255, 255, 255],
- type='',
- swap='kpt-49'),
- 40:
- dict(
- name='kpt-40',
- id=40,
- color=[255, 255, 255],
- type='',
- swap='kpt-48'),
- 41:
- dict(
- name='kpt-41',
- id=41,
- color=[255, 255, 255],
- type='',
- swap='kpt-47'),
- 42:
- dict(
- name='kpt-42',
- id=42,
- color=[255, 255, 255],
- type='',
- swap='kpt-37'),
- 43:
- dict(
- name='kpt-43',
- id=43,
- color=[255, 255, 255],
- type='',
- swap='kpt-36'),
- 44:
- dict(
- name='kpt-44',
- id=44,
- color=[255, 255, 255],
- type='',
- swap='kpt-35'),
- 45:
- dict(
- name='kpt-45',
- id=45,
- color=[255, 255, 255],
- type='',
- swap='kpt-34'),
- 46:
- dict(
- name='kpt-46',
- id=46,
- color=[255, 255, 255],
- type='',
- swap='kpt-33'),
- 47:
- dict(
- name='kpt-47',
- id=47,
- color=[255, 255, 255],
- type='',
- swap='kpt-41'),
- 48:
- dict(
- name='kpt-48',
- id=48,
- color=[255, 255, 255],
- type='',
- swap='kpt-40'),
- 49:
- dict(
- name='kpt-49',
- id=49,
- color=[255, 255, 255],
- type='',
- swap='kpt-39'),
- 50:
- dict(
- name='kpt-50',
- id=50,
- color=[255, 255, 255],
- type='',
- swap='kpt-38'),
- 51:
- dict(name='kpt-51', id=51, color=[255, 255, 255], type='', swap=''),
- 52:
- dict(name='kpt-52', id=52, color=[255, 255, 255], type='', swap=''),
- 53:
- dict(name='kpt-53', id=53, color=[255, 255, 255], type='', swap=''),
- 54:
- dict(name='kpt-54', id=54, color=[255, 255, 255], type='', swap=''),
- 55:
- dict(
- name='kpt-55',
- id=55,
- color=[255, 255, 255],
- type='',
- swap='kpt-59'),
- 56:
- dict(
- name='kpt-56',
- id=56,
- color=[255, 255, 255],
- type='',
- swap='kpt-58'),
- 57:
- dict(name='kpt-57', id=57, color=[255, 255, 255], type='', swap=''),
- 58:
- dict(
- name='kpt-58',
- id=58,
- color=[255, 255, 255],
- type='',
- swap='kpt-56'),
- 59:
- dict(
- name='kpt-59',
- id=59,
- color=[255, 255, 255],
- type='',
- swap='kpt-55'),
- 60:
- dict(
- name='kpt-60',
- id=60,
- color=[255, 255, 255],
- type='',
- swap='kpt-72'),
- 61:
- dict(
- name='kpt-61',
- id=61,
- color=[255, 255, 255],
- type='',
- swap='kpt-71'),
- 62:
- dict(
- name='kpt-62',
- id=62,
- color=[255, 255, 255],
- type='',
- swap='kpt-70'),
- 63:
- dict(
- name='kpt-63',
- id=63,
- color=[255, 255, 255],
- type='',
- swap='kpt-69'),
- 64:
- dict(
- name='kpt-64',
- id=64,
- color=[255, 255, 255],
- type='',
- swap='kpt-68'),
- 65:
- dict(
- name='kpt-65',
- id=65,
- color=[255, 255, 255],
- type='',
- swap='kpt-75'),
- 66:
- dict(
- name='kpt-66',
- id=66,
- color=[255, 255, 255],
- type='',
- swap='kpt-74'),
- 67:
- dict(
- name='kpt-67',
- id=67,
- color=[255, 255, 255],
- type='',
- swap='kpt-73'),
- 68:
- dict(
- name='kpt-68',
- id=68,
- color=[255, 255, 255],
- type='',
- swap='kpt-64'),
- 69:
- dict(
- name='kpt-69',
- id=69,
- color=[255, 255, 255],
- type='',
- swap='kpt-63'),
- 70:
- dict(
- name='kpt-70',
- id=70,
- color=[255, 255, 255],
- type='',
- swap='kpt-62'),
- 71:
- dict(
- name='kpt-71',
- id=71,
- color=[255, 255, 255],
- type='',
- swap='kpt-61'),
- 72:
- dict(
- name='kpt-72',
- id=72,
- color=[255, 255, 255],
- type='',
- swap='kpt-60'),
- 73:
- dict(
- name='kpt-73',
- id=73,
- color=[255, 255, 255],
- type='',
- swap='kpt-67'),
- 74:
- dict(
- name='kpt-74',
- id=74,
- color=[255, 255, 255],
- type='',
- swap='kpt-66'),
- 75:
- dict(
- name='kpt-75',
- id=75,
- color=[255, 255, 255],
- type='',
- swap='kpt-65'),
- 76:
- dict(
- name='kpt-76',
- id=76,
- color=[255, 255, 255],
- type='',
- swap='kpt-82'),
- 77:
- dict(
- name='kpt-77',
- id=77,
- color=[255, 255, 255],
- type='',
- swap='kpt-81'),
- 78:
- dict(
- name='kpt-78',
- id=78,
- color=[255, 255, 255],
- type='',
- swap='kpt-80'),
- 79:
- dict(name='kpt-79', id=79, color=[255, 255, 255], type='', swap=''),
- 80:
- dict(
- name='kpt-80',
- id=80,
- color=[255, 255, 255],
- type='',
- swap='kpt-78'),
- 81:
- dict(
- name='kpt-81',
- id=81,
- color=[255, 255, 255],
- type='',
- swap='kpt-77'),
- 82:
- dict(
- name='kpt-82',
- id=82,
- color=[255, 255, 255],
- type='',
- swap='kpt-76'),
- 83:
- dict(
- name='kpt-83',
- id=83,
- color=[255, 255, 255],
- type='',
- swap='kpt-87'),
- 84:
- dict(
- name='kpt-84',
- id=84,
- color=[255, 255, 255],
- type='',
- swap='kpt-86'),
- 85:
- dict(name='kpt-85', id=85, color=[255, 255, 255], type='', swap=''),
- 86:
- dict(
- name='kpt-86',
- id=86,
- color=[255, 255, 255],
- type='',
- swap='kpt-84'),
- 87:
- dict(
- name='kpt-87',
- id=87,
- color=[255, 255, 255],
- type='',
- swap='kpt-83'),
- 88:
- dict(
- name='kpt-88',
- id=88,
- color=[255, 255, 255],
- type='',
- swap='kpt-92'),
- 89:
- dict(
- name='kpt-89',
- id=89,
- color=[255, 255, 255],
- type='',
- swap='kpt-91'),
- 90:
- dict(name='kpt-90', id=90, color=[255, 255, 255], type='', swap=''),
- 91:
- dict(
- name='kpt-91',
- id=91,
- color=[255, 255, 255],
- type='',
- swap='kpt-89'),
- 92:
- dict(
- name='kpt-92',
- id=92,
- color=[255, 255, 255],
- type='',
- swap='kpt-88'),
- 93:
- dict(
- name='kpt-93',
- id=93,
- color=[255, 255, 255],
- type='',
- swap='kpt-95'),
- 94:
- dict(name='kpt-94', id=94, color=[255, 255, 255], type='', swap=''),
- 95:
- dict(
- name='kpt-95',
- id=95,
- color=[255, 255, 255],
- type='',
- swap='kpt-93'),
- 96:
- dict(
- name='kpt-96',
- id=96,
- color=[255, 255, 255],
- type='',
- swap='kpt-97'),
- 97:
- dict(
- name='kpt-97',
- id=97,
- color=[255, 255, 255],
- type='',
- swap='kpt-96')
+ dict(name='kpt-36', id=36, color=[255, 0, 0], type='', swap='kpt-43'),
+ 37: dict(
+ name='kpt-37', id=37, color=[255, 0, 0], type='', swap='kpt-42'),
+ 38: dict(
+ name='kpt-38', id=38, color=[255, 0, 0], type='', swap='kpt-50'),
+ 39: dict(
+ name='kpt-39', id=39, color=[255, 0, 0], type='', swap='kpt-49'),
+ 40: dict(
+ name='kpt-40', id=40, color=[255, 0, 0], type='', swap='kpt-48'),
+ 41: dict(
+ name='kpt-41', id=41, color=[255, 0, 0], type='', swap='kpt-47'),
+ 42: dict(
+ name='kpt-42', id=42, color=[255, 0, 0], type='', swap='kpt-37'),
+ 43: dict(
+ name='kpt-43', id=43, color=[255, 0, 0], type='', swap='kpt-36'),
+ 44: dict(
+ name='kpt-44', id=44, color=[255, 0, 0], type='', swap='kpt-35'),
+ 45: dict(
+ name='kpt-45', id=45, color=[255, 0, 0], type='', swap='kpt-34'),
+ 46: dict(
+ name='kpt-46', id=46, color=[255, 0, 0], type='', swap='kpt-33'),
+ 47: dict(
+ name='kpt-47', id=47, color=[255, 0, 0], type='', swap='kpt-41'),
+ 48: dict(
+ name='kpt-48', id=48, color=[255, 0, 0], type='', swap='kpt-40'),
+ 49: dict(
+ name='kpt-49', id=49, color=[255, 0, 0], type='', swap='kpt-39'),
+ 50: dict(
+ name='kpt-50', id=50, color=[255, 0, 0], type='', swap='kpt-38'),
+ 51: dict(name='kpt-51', id=51, color=[255, 0, 0], type='', swap=''),
+ 52: dict(name='kpt-52', id=52, color=[255, 0, 0], type='', swap=''),
+ 53: dict(name='kpt-53', id=53, color=[255, 0, 0], type='', swap=''),
+ 54: dict(name='kpt-54', id=54, color=[255, 0, 0], type='', swap=''),
+ 55: dict(
+ name='kpt-55', id=55, color=[255, 0, 0], type='', swap='kpt-59'),
+ 56: dict(
+ name='kpt-56', id=56, color=[255, 0, 0], type='', swap='kpt-58'),
+ 57: dict(name='kpt-57', id=57, color=[255, 0, 0], type='', swap=''),
+ 58: dict(
+ name='kpt-58', id=58, color=[255, 0, 0], type='', swap='kpt-56'),
+ 59: dict(
+ name='kpt-59', id=59, color=[255, 0, 0], type='', swap='kpt-55'),
+ 60: dict(
+ name='kpt-60', id=60, color=[255, 0, 0], type='', swap='kpt-72'),
+ 61: dict(
+ name='kpt-61', id=61, color=[255, 0, 0], type='', swap='kpt-71'),
+ 62: dict(
+ name='kpt-62', id=62, color=[255, 0, 0], type='', swap='kpt-70'),
+ 63: dict(
+ name='kpt-63', id=63, color=[255, 0, 0], type='', swap='kpt-69'),
+ 64: dict(
+ name='kpt-64', id=64, color=[255, 0, 0], type='', swap='kpt-68'),
+ 65: dict(
+ name='kpt-65', id=65, color=[255, 0, 0], type='', swap='kpt-75'),
+ 66: dict(
+ name='kpt-66', id=66, color=[255, 0, 0], type='', swap='kpt-74'),
+ 67: dict(
+ name='kpt-67', id=67, color=[255, 0, 0], type='', swap='kpt-73'),
+ 68: dict(
+ name='kpt-68', id=68, color=[255, 0, 0], type='', swap='kpt-64'),
+ 69: dict(
+ name='kpt-69', id=69, color=[255, 0, 0], type='', swap='kpt-63'),
+ 70: dict(
+ name='kpt-70', id=70, color=[255, 0, 0], type='', swap='kpt-62'),
+ 71: dict(
+ name='kpt-71', id=71, color=[255, 0, 0], type='', swap='kpt-61'),
+ 72: dict(
+ name='kpt-72', id=72, color=[255, 0, 0], type='', swap='kpt-60'),
+ 73: dict(
+ name='kpt-73', id=73, color=[255, 0, 0], type='', swap='kpt-67'),
+ 74: dict(
+ name='kpt-74', id=74, color=[255, 0, 0], type='', swap='kpt-66'),
+ 75: dict(
+ name='kpt-75', id=75, color=[255, 0, 0], type='', swap='kpt-65'),
+ 76: dict(
+ name='kpt-76', id=76, color=[255, 0, 0], type='', swap='kpt-82'),
+ 77: dict(
+ name='kpt-77', id=77, color=[255, 0, 0], type='', swap='kpt-81'),
+ 78: dict(
+ name='kpt-78', id=78, color=[255, 0, 0], type='', swap='kpt-80'),
+ 79: dict(name='kpt-79', id=79, color=[255, 0, 0], type='', swap=''),
+ 80: dict(
+ name='kpt-80', id=80, color=[255, 0, 0], type='', swap='kpt-78'),
+ 81: dict(
+ name='kpt-81', id=81, color=[255, 0, 0], type='', swap='kpt-77'),
+ 82: dict(
+ name='kpt-82', id=82, color=[255, 0, 0], type='', swap='kpt-76'),
+ 83: dict(
+ name='kpt-83', id=83, color=[255, 0, 0], type='', swap='kpt-87'),
+ 84: dict(
+ name='kpt-84', id=84, color=[255, 0, 0], type='', swap='kpt-86'),
+ 85: dict(name='kpt-85', id=85, color=[255, 0, 0], type='', swap=''),
+ 86: dict(
+ name='kpt-86', id=86, color=[255, 0, 0], type='', swap='kpt-84'),
+ 87: dict(
+ name='kpt-87', id=87, color=[255, 0, 0], type='', swap='kpt-83'),
+ 88: dict(
+ name='kpt-88', id=88, color=[255, 0, 0], type='', swap='kpt-92'),
+ 89: dict(
+ name='kpt-89', id=89, color=[255, 0, 0], type='', swap='kpt-91'),
+ 90: dict(name='kpt-90', id=90, color=[255, 0, 0], type='', swap=''),
+ 91: dict(
+ name='kpt-91', id=91, color=[255, 0, 0], type='', swap='kpt-89'),
+ 92: dict(
+ name='kpt-92', id=92, color=[255, 0, 0], type='', swap='kpt-88'),
+ 93: dict(
+ name='kpt-93', id=93, color=[255, 0, 0], type='', swap='kpt-95'),
+ 94: dict(name='kpt-94', id=94, color=[255, 0, 0], type='', swap=''),
+ 95: dict(
+ name='kpt-95', id=95, color=[255, 0, 0], type='', swap='kpt-93'),
+ 96: dict(
+ name='kpt-96', id=96, color=[255, 0, 0], type='', swap='kpt-97'),
+ 97: dict(
+ name='kpt-97', id=97, color=[255, 0, 0], type='', swap='kpt-96')
},
skeleton_info={},
joint_weights=[1.] * 98,
diff --git a/configs/animal_2d_keypoint/topdown_heatmap/ap10k/resnet_ap10k.yml b/configs/animal_2d_keypoint/topdown_heatmap/ap10k/resnet_ap10k.yml
new file mode 100644
index 0000000000..11c6d912ac
--- /dev/null
+++ b/configs/animal_2d_keypoint/topdown_heatmap/ap10k/resnet_ap10k.yml
@@ -0,0 +1,41 @@
+Collections:
+- Name: SimpleBaseline2D
+ Paper:
+ Title: Simple baselines for human pose estimation and tracking
+ URL: http://openaccess.thecvf.com/content_ECCV_2018/html/Bin_Xiao_Simple_Baselines_for_ECCV_2018_paper.html
+ README: https://github.com/open-mmlab/mmpose/blob/master/docs/en/papers/algorithms/simplebaseline2d.md
+Models:
+- Config: configs/animal_2d_keypoint/topdown_heatmap/ap10k/td-hm_res50_8xb64-210e_ap10k-256x256.py
+ In Collection: SimpleBaseline2D
+ Alias: animal
+ Metadata:
+ Architecture: &id001
+ - SimpleBaseline2D
+ Training Data: AP-10K
+ Name: topdown_heatmap_res50_ap10k_256x256
+ Results:
+ - Dataset: AP-10K
+ Metrics:
+ AP: 0.680
+ AP@0.5: 0.926
+ AP@0.75: 0.738
+ APL: 0.687
+ APM: 0.552
+ Task: Animal 2D Keypoint
+ Weights: https://download.openmmlab.com/mmpose/animal/resnet/res50_ap10k_256x256-35760eb8_20211029.pth
+- Config: configs/animal_2d_keypoint/topdown_heatmap/ap10k/td-hm_res101_8xb64-210e_ap10k-256x256.py
+ In Collection: SimpleBaseline2D
+ Metadata:
+ Architecture: *id001
+ Training Data: AP-10K
+ Name: topdown_heatmap_res101_ap10k_256x256
+ Results:
+ - Dataset: AP-10K
+ Metrics:
+ AP: 0.681
+ AP@0.5: 0.921
+ AP@0.75: 0.751
+ APL: 0.690
+ APM: 0.545
+ Task: Animal 2D Keypoint
+ Weights: https://download.openmmlab.com/mmpose/animal/resnet/res101_ap10k_256x256-9edfafb9_20211029.pth
diff --git a/configs/body_2d_keypoint/topdown_heatmap/coco/hrnet_coco.yml b/configs/body_2d_keypoint/topdown_heatmap/coco/hrnet_coco.yml
index 0131493c15..86a305d223 100644
--- a/configs/body_2d_keypoint/topdown_heatmap/coco/hrnet_coco.yml
+++ b/configs/body_2d_keypoint/topdown_heatmap/coco/hrnet_coco.yml
@@ -7,6 +7,7 @@ Collections:
Models:
- Config: configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py
In Collection: HRNet
+ Alias: human
Metadata:
Architecture: &id001
- HRNet
diff --git a/configs/face_2d_keypoint/topdown_heatmap/wflw/hrnetv2_wflw.yml b/configs/face_2d_keypoint/topdown_heatmap/wflw/hrnetv2_wflw.yml
new file mode 100644
index 0000000000..3a0a4a454a
--- /dev/null
+++ b/configs/face_2d_keypoint/topdown_heatmap/wflw/hrnetv2_wflw.yml
@@ -0,0 +1,27 @@
+Collections:
+- Name: HRNetv2
+ Paper:
+ Title: Deep High-Resolution Representation Learning for Visual Recognition
+ URL: https://arxiv.org/abs/1908.07919
+ README: https://github.com/open-mmlab/mmpose/blob/1.x/docs/src/papers/backbones/hrnetv2.md
+Models:
+- Config: configs/face_2d_keypoint/topdown_heatmap/wflw/td-hm_hrnetv2-w18_8xb64-60e_wflw-256x256.py
+ In Collection: HRNetv2
+ Alias: face
+ Metadata:
+ Architecture:
+ - HRNetv2
+ Training Data: WFLW
+ Name: topdown_heatmap_hrnetv2_w18_wflw_256x256
+ Results:
+ - Dataset: WFLW
+ Metrics:
+ NME blur: 4.58
+ NME expression: 4.33
+ NME illumination: 3.99
+ NME makeup: 3.94
+ NME occlusion: 4.83
+ NME pose: 6.97
+ NME test: 4.06
+ Task: Face 2D Keypoint
+ Weights: https://download.openmmlab.com/mmpose/face/hrnetv2/hrnetv2_w18_wflw_256x256-2bf032a6_20210125.pth
diff --git a/configs/hand_2d_keypoint/topdown_heatmap/onehand10k/resnet_onehand10k.yml b/configs/hand_2d_keypoint/topdown_heatmap/onehand10k/resnet_onehand10k.yml
new file mode 100644
index 0000000000..828427899c
--- /dev/null
+++ b/configs/hand_2d_keypoint/topdown_heatmap/onehand10k/resnet_onehand10k.yml
@@ -0,0 +1,24 @@
+Collections:
+- Name: SimpleBaseline2D
+ Paper:
+ Title: Simple baselines for human pose estimation and tracking
+ URL: http://openaccess.thecvf.com/content_ECCV_2018/html/Bin_Xiao_Simple_Baselines_for_ECCV_2018_paper.html
+ README: https://github.com/open-mmlab/mmpose/blob/master/docs/en/papers/algorithms/simplebaseline2d.md
+Models:
+- Config: configs/hand_2d_keypoint/topdown_heatmap/onehand10k/td-hm_res50_8xb32-210e_onehand10k-256x256.py
+ In Collection: SimpleBaseline2D
+ Alias: hand
+ Metadata:
+ Architecture:
+ - SimpleBaseline2D
+ - ResNet
+ Training Data: OneHand10K
+ Name: topdown_heatmap_res50_onehand10k_256x256
+ Results:
+ - Dataset: OneHand10K
+ Metrics:
+ AUC: 0.555
+ EPE: 25.16
+ PCK@0.2: 0.989
+ Task: Hand 2D Keypoint
+ Weights: https://download.openmmlab.com/mmpose/hand/resnet/res50_onehand10k_256x256-739c8639_20210330.pth
diff --git a/demo/docs/2d_face_demo.md b/demo/docs/2d_face_demo.md
index 13ff380b69..09b8ceb330 100644
--- a/demo/docs/2d_face_demo.md
+++ b/demo/docs/2d_face_demo.md
@@ -1,29 +1,28 @@
## 2D Face Keypoint Demo
-We provide a demo script to test a single image or video with face detectors and top-down pose estimators, Please install `face_recognition` before running the demo, by:
+We provide a demo script to test a single image or video with hand detectors and top-down pose estimators. Assume that you have already installed [mmdet](https://github.com/open-mmlab/mmdetection) with version >= 3.0.
-```
-pip install face_recognition
-```
-
-For more details, please refer to [face_recognition](https://github.com/ageitgey/face_recognition).
+**Face Box Model Preparation:** The pre-trained face box estimation model can be found in [mmdet model zoo](/demo/docs/mmdet_modelzoo.md).
### 2D Face Image Demo
```shell
-python demo/topdown_face_demo.py \
+python demo/topdown_demo_with_mmdet.py \
+ ${MMDET_CONFIG_FILE} ${MMDET_CHECKPOINT_FILE} \
${MMPOSE_CONFIG_FILE} ${MMPOSE_CHECKPOINT_FILE} \
--input ${INPUT_PATH} [--output-root ${OUTPUT_DIR}] \
[--show] [--device ${GPU_ID or CPU}] [--save-predictions] \
[--draw-heatmap ${DRAW_HEATMAP}] [--radius ${KPT_RADIUS}] \
- [--kpt-thr ${KPT_SCORE_THR}]
+ [--kpt-thr ${KPT_SCORE_THR}] [--bbox-thr ${BBOX_SCORE_THR}]
```
The pre-trained face keypoint estimation models can be found from [model zoo](https://mmpose.readthedocs.io/en/1.x/model_zoo/face_2d_keypoint.html).
Take [aflw model](https://download.openmmlab.com/mmpose/face/hrnetv2/hrnetv2_w18_aflw_256x256-f2bbc62b_20210125.pth) as an example:
```shell
-python demo/topdown_face_demo.py \
+python demo/topdown_demo_with_mmdet.py \
+ demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py \
+ https://download.openmmlab.com/mmpose/mmdet_pretrained/yolo-x_8xb8-300e_coco-face_13274d7c.pth \
configs/face_2d_keypoint/topdown_heatmap/aflw/td-hm_hrnetv2-w18_8xb64-60e_aflw-256x256.py \
https://download.openmmlab.com/mmpose/face/hrnetv2/hrnetv2_w18_aflw_256x256-f2bbc62b_20210125.pth \
--input tests/data/cofw/001766.jpg \
@@ -32,14 +31,16 @@ python demo/topdown_face_demo.py \
Visualization result:
-
+
If you use a heatmap-based model and set argument `--draw-heatmap`, the predicted heatmap will be visualized together with the keypoints.
To save visualized results on disk:
```shell
-python demo/topdown_face_demo.py \
+python demo/topdown_demo_with_mmdet.py \
+ demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py \
+ https://download.openmmlab.com/mmpose/mmdet_pretrained/yolo-x_8xb8-300e_coco-face_13274d7c.pth \
configs/face_2d_keypoint/topdown_heatmap/aflw/td-hm_hrnetv2-w18_8xb64-60e_aflw-256x256.py \
https://download.openmmlab.com/mmpose/face/hrnetv2/hrnetv2_w18_aflw_256x256-f2bbc62b_20210125.pth \
--input tests/data/cofw/001766.jpg \
@@ -51,7 +52,9 @@ To save the predicted results on disk, please specify `--save-predictions`.
To run demos on CPU:
```shell
-python demo/topdown_face_demo.py \
+python demo/topdown_demo_with_mmdet.py \
+ demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py \
+ https://download.openmmlab.com/mmpose/mmdet_pretrained/yolo-x_8xb8-300e_coco-face_13274d7c.pth \
configs/face_2d_keypoint/topdown_heatmap/aflw/td-hm_hrnetv2-w18_8xb64-60e_aflw-256x256.py \
https://download.openmmlab.com/mmpose/face/hrnetv2/hrnetv2_w18_aflw_256x256-f2bbc62b_20210125.pth \
--input tests/data/cofw/001766.jpg \
@@ -63,14 +66,16 @@ python demo/topdown_face_demo.py \
Videos share the same interface with images. The difference is that the `${INPUT_PATH}` for videos can be the local path or **URL** link to video file.
```shell
-python demo/topdown_face_demo.py \
+python demo/topdown_demo_with_mmdet.py \
+ demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py \
+ https://download.openmmlab.com/mmpose/mmdet_pretrained/yolo-x_8xb8-300e_coco-face_13274d7c.pth \
configs/face_2d_keypoint/topdown_heatmap/aflw/td-hm_hrnetv2-w18_8xb64-60e_aflw-256x256.py \
https://download.openmmlab.com/mmpose/face/hrnetv2/hrnetv2_w18_aflw_256x256-f2bbc62b_20210125.pth \
--input demo/resources/ \
--show --draw-heatmap --output-root vis_results
```
-
+
The original video can be downloaded from [Google Drive](https://drive.google.com/file/d/1kQt80t6w802b_vgVcmiV_QfcSJ3RWzmb/view?usp=sharing).
diff --git a/demo/docs/mmdet_modelzoo.md b/demo/docs/mmdet_modelzoo.md
index d438a5e982..a50be168a5 100644
--- a/demo/docs/mmdet_modelzoo.md
+++ b/demo/docs/mmdet_modelzoo.md
@@ -15,6 +15,16 @@ For hand bounding box detection, we simply train our hand box models on onehand1
| :---------------------------------------------------------------- | :----: | :---------------------------------------------------------------: | :--------------------------------------------------------------: |
| [Cascade_R-CNN X-101-64x4d-FPN-1class](/demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py) | 0.817 | [ckpt](https://download.openmmlab.com/mmpose/mmdet_pretrained/cascade_rcnn_x101_64x4d_fpn_20e_onehand10k-dac19597_20201030.pth) | [log](https://download.openmmlab.com/mmpose/mmdet_pretrained/cascade_rcnn_x101_64x4d_fpn_20e_onehand10k_20201030.log.json) |
+### Face Bounding Box Detection Models
+
+For face bounding box detection, we train a YOLOX detector on COCO-face data using MMDetection.
+
+#### Hand detection results on OneHand10K test set
+
+| Arch | Box AP | ckpt |
+| :-------------------------------------------------------------- | :----: | :----------------------------------------------------------------------------------------------------: |
+| [YOLOX-s](/demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py) | 0.408 | [ckpt](https://download.openmmlab.com/mmpose/mmdet_pretrained/yolo-x_8xb8-300e_coco-face_13274d7c.pth) |
+
### Animal Bounding Box Detection Models
#### COCO animals
diff --git a/demo/inferencer_demo.py b/demo/inferencer_demo.py
new file mode 100644
index 0000000000..df4877e6d9
--- /dev/null
+++ b/demo/inferencer_demo.py
@@ -0,0 +1,104 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from argparse import ArgumentParser
+
+from mmpose.apis.inferencers import MMPoseInferencer
+
+
+def parse_args():
+ parser = ArgumentParser()
+ parser.add_argument(
+ 'inputs', type=str, help='Input image/video path or folder path.')
+ parser.add_argument(
+ '--pose2d',
+ type=str,
+ default=None,
+ help='Pretrained 2D pose estimation algorithm. It\'s the path to the '
+ 'config file or the model name defined in metafile.')
+ parser.add_argument(
+ '--pose2d-weights',
+ type=str,
+ default=None,
+ help='Path to the custom checkpoint file of the selected pose model. '
+ 'If it is not specified and "pose2d" is a model name of metafile, '
+ 'the weights will be loaded from metafile.')
+ parser.add_argument(
+ '--det-model',
+ type=str,
+ default=None,
+ help='Config path or alias of detection model.')
+ parser.add_argument(
+ '--det-weights',
+ type=str,
+ default=None,
+ help='Path to the checkpoints of detection model.')
+ parser.add_argument(
+ '--det-cat-ids',
+ type=int,
+ nargs='+',
+ default=None,
+ help='Category id for detection model.')
+ parser.add_argument(
+ '--device',
+ type=str,
+ default=None,
+ help='Device used for inference. '
+ 'If not specified, the available device will be automatically used.')
+ parser.add_argument(
+ '--show',
+ action='store_true',
+ help='Display the image/video in a popup window.')
+ parser.add_argument(
+ '--bbox-thr',
+ type=float,
+ default=0.3,
+ help='Bounding box score threshold')
+ parser.add_argument(
+ '--nms-thr',
+ type=float,
+ default=0.3,
+ help='IoU threshold for bounding box NMS')
+ parser.add_argument(
+ '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
+ parser.add_argument(
+ '--radius',
+ type=int,
+ default=3,
+ help='Keypoint radius for visualization.')
+ parser.add_argument(
+ '--thickness',
+ type=int,
+ default=1,
+ help='Link thickness for visualization.')
+ parser.add_argument(
+ '--vis-out-dir',
+ type=str,
+ default='',
+ help='Directory for saving visualized results.')
+ parser.add_argument(
+ '--pred-out-dir',
+ type=str,
+ default='',
+ help='Directory for saving inference results.')
+
+ call_args = vars(parser.parse_args())
+
+ init_kws = [
+ 'pose2d', 'pose2d_weights', 'device', 'det_model', 'det_weights',
+ 'det_cat_ids'
+ ]
+ init_args = {}
+ for init_kw in init_kws:
+ init_args[init_kw] = call_args.pop(init_kw)
+
+ return init_args, call_args
+
+
+def main():
+ init_args, call_args = parse_args()
+ inferencer = MMPoseInferencer(**init_args)
+ for _ in inferencer(**call_args):
+ pass
+
+
+if __name__ == '__main__':
+ main()
diff --git a/demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py b/demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py
new file mode 100644
index 0000000000..9180b831e6
--- /dev/null
+++ b/demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py
@@ -0,0 +1,306 @@
+train_cfg = dict(type='EpochBasedTrainLoop', max_epochs=300, val_interval=10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+param_scheduler = [
+ dict(
+ type='mmdet.QuadraticWarmupLR',
+ by_epoch=True,
+ begin=0,
+ end=5,
+ convert_to_iter_based=True),
+ dict(
+ type='CosineAnnealingLR',
+ eta_min=0.0005,
+ begin=5,
+ T_max=285,
+ end=285,
+ by_epoch=True,
+ convert_to_iter_based=True),
+ dict(type='ConstantLR', by_epoch=True, factor=1, begin=285, end=300)
+]
+optim_wrapper = dict(
+ type='OptimWrapper',
+ optimizer=dict(
+ type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005, nesterov=True),
+ paramwise_cfg=dict(norm_decay_mult=0.0, bias_decay_mult=0.0))
+auto_scale_lr = dict(enable=False, base_batch_size=64)
+default_scope = 'mmdet'
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(type='CheckpointHook', interval=10, max_keep_ckpts=3),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='DetVisualizationHook'))
+env_cfg = dict(
+ cudnn_benchmark=False,
+ mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
+ dist_cfg=dict(backend='nccl'))
+vis_backends = [dict(type='LocalVisBackend')]
+visualizer = dict(
+ type='DetLocalVisualizer',
+ vis_backends=[dict(type='LocalVisBackend')],
+ name='visualizer')
+log_processor = dict(type='LogProcessor', window_size=50, by_epoch=True)
+log_level = 'INFO'
+load_from = 'https://download.openmmlab.com/mmdetection/' \
+ 'v2.0/yolox/yolox_s_8x8_300e_coco/' \
+ 'yolox_s_8x8_300e_coco_20211121_095711-4592a793.pth'
+resume = False
+img_scale = (640, 640)
+model = dict(
+ type='YOLOX',
+ data_preprocessor=dict(
+ type='DetDataPreprocessor',
+ pad_size_divisor=32,
+ batch_augments=[
+ dict(
+ type='BatchSyncRandomResize',
+ random_size_range=(480, 800),
+ size_divisor=32,
+ interval=10)
+ ]),
+ backbone=dict(
+ type='CSPDarknet',
+ deepen_factor=0.33,
+ widen_factor=0.5,
+ out_indices=(2, 3, 4),
+ use_depthwise=False,
+ spp_kernal_sizes=(5, 9, 13),
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish')),
+ neck=dict(
+ type='YOLOXPAFPN',
+ in_channels=[128, 256, 512],
+ out_channels=128,
+ num_csp_blocks=1,
+ use_depthwise=False,
+ upsample_cfg=dict(scale_factor=2, mode='nearest'),
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish')),
+ bbox_head=dict(
+ type='YOLOXHead',
+ num_classes=1,
+ in_channels=128,
+ feat_channels=128,
+ stacked_convs=2,
+ strides=(8, 16, 32),
+ use_depthwise=False,
+ norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
+ act_cfg=dict(type='Swish'),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='sum',
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='IoULoss',
+ mode='square',
+ eps=1e-16,
+ reduction='sum',
+ loss_weight=5.0),
+ loss_obj=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ reduction='sum',
+ loss_weight=1.0),
+ loss_l1=dict(type='L1Loss', reduction='sum', loss_weight=1.0)),
+ train_cfg=dict(assigner=dict(type='SimOTAAssigner', center_radius=2.5)),
+ test_cfg=dict(score_thr=0.01, nms=dict(type='nms', iou_threshold=0.65)))
+data_root = 'data/coco/'
+dataset_type = 'CocoDataset'
+file_client_args = dict(backend='disk')
+train_pipeline = [
+ dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
+ dict(
+ type='RandomAffine', scaling_ratio_range=(0.1, 2),
+ border=(-320, -320)),
+ dict(
+ type='MixUp',
+ img_scale=(640, 640),
+ ratio_range=(0.8, 1.6),
+ pad_val=114.0),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
+ dict(
+ type='Pad',
+ pad_to_square=True,
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
+ dict(type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
+ dict(type='PackDetInputs')
+]
+train_dataset = dict(
+ type='MultiImageMixDataset',
+ dataset=dict(
+ type='CocoDataset',
+ data_root='data/coco/',
+ ann_file='annotations/instances_train2017.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=[
+ dict(
+ type='LoadImageFromFile',
+ file_client_args=dict(backend='disk')),
+ dict(type='LoadAnnotations', with_bbox=True)
+ ],
+ filter_cfg=dict(filter_empty_gt=False, min_size=32)),
+ pipeline=[
+ dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
+ dict(
+ type='RandomAffine',
+ scaling_ratio_range=(0.1, 2),
+ border=(-320, -320)),
+ dict(
+ type='MixUp',
+ img_scale=(640, 640),
+ ratio_range=(0.8, 1.6),
+ pad_val=114.0),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
+ dict(
+ type='Pad',
+ pad_to_square=True,
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
+ dict(
+ type='FilterAnnotations', min_gt_bbox_wh=(1, 1), keep_empty=False),
+ dict(type='PackDetInputs')
+ ])
+test_pipeline = [
+ dict(type='LoadImageFromFile', file_client_args=dict(backend='disk')),
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
+ dict(
+ type='Pad',
+ pad_to_square=True,
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+]
+train_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ persistent_workers=True,
+ sampler=dict(type='DefaultSampler', shuffle=True),
+ dataset=dict(
+ type='MultiImageMixDataset',
+ dataset=dict(
+ type='CocoDataset',
+ data_root='data/coco/',
+ ann_file='annotations/coco_face_train.json',
+ data_prefix=dict(img='train2017/'),
+ pipeline=[
+ dict(
+ type='LoadImageFromFile',
+ file_client_args=dict(backend='disk')),
+ dict(type='LoadAnnotations', with_bbox=True)
+ ],
+ filter_cfg=dict(filter_empty_gt=False, min_size=32),
+ metainfo=dict(CLASSES=('person', ), PALETTE=(220, 20, 60))),
+ pipeline=[
+ dict(type='Mosaic', img_scale=(640, 640), pad_val=114.0),
+ dict(
+ type='RandomAffine',
+ scaling_ratio_range=(0.1, 2),
+ border=(-320, -320)),
+ dict(
+ type='MixUp',
+ img_scale=(640, 640),
+ ratio_range=(0.8, 1.6),
+ pad_val=114.0),
+ dict(type='YOLOXHSVRandomAug'),
+ dict(type='RandomFlip', prob=0.5),
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
+ dict(
+ type='Pad',
+ pad_to_square=True,
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
+ dict(
+ type='FilterAnnotations',
+ min_gt_bbox_wh=(1, 1),
+ keep_empty=False),
+ dict(type='PackDetInputs')
+ ]))
+val_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='CocoDataset',
+ data_root='data/coco/',
+ ann_file='annotations/coco_face_val.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=[
+ dict(
+ type='LoadImageFromFile',
+ file_client_args=dict(backend='disk')),
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
+ dict(
+ type='Pad',
+ pad_to_square=True,
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+ ],
+ metainfo=dict(CLASSES=('person', ), PALETTE=(220, 20, 60))))
+test_dataloader = dict(
+ batch_size=8,
+ num_workers=4,
+ persistent_workers=True,
+ drop_last=False,
+ sampler=dict(type='DefaultSampler', shuffle=False),
+ dataset=dict(
+ type='CocoDataset',
+ data_root='data/coco/',
+ ann_file='annotations/coco_face_val.json',
+ data_prefix=dict(img='val2017/'),
+ test_mode=True,
+ pipeline=[
+ dict(
+ type='LoadImageFromFile',
+ file_client_args=dict(backend='disk')),
+ dict(type='Resize', scale=(640, 640), keep_ratio=True),
+ dict(
+ type='Pad',
+ pad_to_square=True,
+ pad_val=dict(img=(114.0, 114.0, 114.0))),
+ dict(type='LoadAnnotations', with_bbox=True),
+ dict(
+ type='PackDetInputs',
+ meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
+ 'scale_factor'))
+ ],
+ metainfo=dict(CLASSES=('person', ), PALETTE=(220, 20, 60))))
+val_evaluator = dict(
+ type='CocoMetric',
+ ann_file='data/coco/annotations/coco_face_val.json',
+ metric='bbox')
+test_evaluator = dict(
+ type='CocoMetric',
+ ann_file='data/coco/annotations/instances_val2017.json',
+ metric='bbox')
+max_epochs = 300
+num_last_epochs = 15
+interval = 10
+base_lr = 0.01
+custom_hooks = [
+ dict(type='YOLOXModeSwitchHook', num_last_epochs=15, priority=48),
+ dict(type='SyncNormHook', priority=48),
+ dict(
+ type='EMAHook',
+ ema_type='ExpMomentumEMA',
+ momentum=0.0001,
+ strict_load=False,
+ update_buffers=True,
+ priority=49)
+]
+metainfo = dict(CLASSES=('person', ), PALETTE=(220, 20, 60))
+launcher = 'pytorch'
diff --git a/demo/topdown_face_demo.py b/demo/topdown_face_demo.py
deleted file mode 100644
index 3442d098e3..0000000000
--- a/demo/topdown_face_demo.py
+++ /dev/null
@@ -1,218 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-import mimetypes
-import os
-import tempfile
-from argparse import ArgumentParser
-
-import json_tricks as json
-import mmcv
-import mmengine
-import numpy as np
-
-from mmpose.apis import inference_topdown
-from mmpose.apis import init_model as init_pose_estimator
-from mmpose.evaluation.functional import nms
-from mmpose.registry import VISUALIZERS
-from mmpose.structures import merge_data_samples, split_instances
-
-try:
- import face_recognition
- has_face_det = True
-except (ImportError, ModuleNotFoundError):
- has_face_det = False
-
-
-def process_face_det_results(face_det_results):
- """Process det results, and return a list of bboxes.
-
- :param face_det_results: (top, right, bottom and left)
- :return: a list of detected bounding boxes (x,y,x,y)-format
- """
-
- person_results = []
- for bbox in face_det_results:
- # left, top, right, bottom
- person_results.append([bbox[3], bbox[0], bbox[1], bbox[2]])
- person_results = np.array(person_results)
-
- return person_results
-
-
-def process_one_image(args, img_path, pose_estimator, visualizer,
- show_interval):
- """Visualize predicted keypoints (and heatmaps) of one image."""
-
- # predict bbox
- image = face_recognition.load_image_file(img_path)
- face_det_results = face_recognition.face_locations(image)
- bboxes = process_face_det_results(face_det_results)
-
- bboxes = np.concatenate((bboxes, np.ones((bboxes.shape[0], 1))), axis=1)
- bboxes = bboxes[nms(bboxes, args.nms_thr), :4]
-
- # predict keypoints
- pose_results = inference_topdown(pose_estimator, img_path, bboxes)
- data_samples = merge_data_samples(pose_results)
-
- # show the results
- img = mmcv.imread(img_path, channel_order='rgb')
-
- out_file = None
- if args.output_root:
- out_file = f'{args.output_root}/{os.path.basename(img_path)}'
-
- visualizer.add_datasample(
- 'result',
- img,
- data_sample=data_samples,
- draw_gt=False,
- draw_heatmap=args.draw_heatmap,
- draw_bbox=args.draw_bbox,
- show=args.show,
- wait_time=show_interval,
- out_file=out_file,
- kpt_score_thr=args.kpt_thr)
-
- return data_samples.pred_instances
-
-
-def main():
- """Visualize the demo images.
-
- Use `face_recognition` to detect the face.
- """
- parser = ArgumentParser()
- parser.add_argument('pose_config', help='Config file for pose')
- parser.add_argument('pose_checkpoint', help='Checkpoint file for pose')
- parser.add_argument(
- '--input', type=str, default='', help='Image/Video file')
- parser.add_argument(
- '--show',
- action='store_true',
- default=False,
- help='whether to show img')
- parser.add_argument(
- '--output-root',
- type=str,
- default='',
- help='root of the output img file. '
- 'Default not saving the visualization images.')
- parser.add_argument(
- '--save-predictions',
- action='store_true',
- default=False,
- help='whether to save predicted results')
- parser.add_argument(
- '--device', default='cuda:0', help='Device used for inference')
- parser.add_argument(
- '--nms-thr',
- type=float,
- default=0.3,
- help='IoU threshold for bounding box NMS')
- parser.add_argument(
- '--kpt-thr', type=float, default=0.3, help='Keypoint score threshold')
- parser.add_argument(
- '--draw-heatmap',
- action='store_true',
- default=False,
- help='Draw heatmap predicted by the model')
- parser.add_argument(
- '--radius',
- type=int,
- default=2,
- help='Keypoint radius for visualization')
- parser.add_argument(
- '--thickness',
- type=int,
- default=1,
- help='Link thickness for visualization')
- parser.add_argument(
- '--draw-bbox', action='store_true', help='Draw bboxes of instances')
-
- assert has_face_det, 'Please install face_recognition to run the demo. ' \
- '"pip install face_recognition", For more details, ' \
- 'see https://github.com/ageitgey/face_recognition'
-
- args = parser.parse_args()
-
- assert args.show or (args.output_root != '')
- assert args.input != ''
- if args.output_root:
- mmengine.mkdir_or_exist(args.output_root)
- if args.save_predictions:
- assert args.output_root != ''
- args.pred_save_path = f'{args.output_root}/results_' \
- f'{os.path.splitext(os.path.basename(args.input))[0]}.json'
-
- # build pose estimator
- pose_estimator = init_pose_estimator(
- args.pose_config,
- args.pose_checkpoint,
- device=args.device,
- cfg_options=dict(
- model=dict(test_cfg=dict(output_heatmaps=args.draw_heatmap))))
-
- # init visualizer
- pose_estimator.cfg.visualizer.radius = args.radius
- pose_estimator.cfg.visualizer.line_width = args.thickness
- visualizer = VISUALIZERS.build(pose_estimator.cfg.visualizer)
- # the dataset_meta is loaded from the checkpoint and
- # then pass to the model in init_pose_estimator
- visualizer.set_dataset_meta(pose_estimator.dataset_meta)
- visualizer.kpt_color = 'red'
-
- input_type = mimetypes.guess_type(args.input)[0].split('/')[0]
- if input_type == 'image':
- pred_instances = process_one_image(
- args, args.input, pose_estimator, visualizer, show_interval=0)
- pred_instances_list = split_instances(pred_instances)
-
- elif input_type == 'video':
- tmp_folder = tempfile.TemporaryDirectory()
- video = mmcv.VideoReader(args.input)
- progressbar = mmengine.ProgressBar(len(video))
- video.cvt2frames(tmp_folder.name, show_progress=False)
- output_root = args.output_root
- args.output_root = tmp_folder.name
- pred_instances_list = []
-
- for frame_id, img_fname in enumerate(os.listdir(tmp_folder.name)):
- pred_instances = process_one_image(
- args,
- f'{tmp_folder.name}/{img_fname}',
- pose_estimator,
- visualizer,
- show_interval=1)
- progressbar.update()
- pred_instances_list.append(
- dict(
- frame_id=frame_id,
- instances=split_instances(pred_instances)))
-
- if output_root:
- mmcv.frames2video(
- tmp_folder.name,
- f'{output_root}/{os.path.basename(args.input)}',
- fps=video.fps,
- fourcc='mp4v',
- show_progress=False)
- tmp_folder.cleanup()
-
- else:
- args.save_predictions = False
- raise ValueError(
- f'file {os.path.basename(args.input)} has invalid format.')
-
- if args.save_predictions:
- with open(args.pred_save_path, 'w') as f:
- json.dump(
- dict(
- meta_info=pose_estimator.dataset_meta,
- instance_info=pred_instances_list),
- f,
- indent='\t')
- print(f'predictions have been saved at {args.pred_save_path}')
-
-
-if __name__ == '__main__':
- main()
diff --git a/docs/en/user_guides/inference.md b/docs/en/user_guides/inference.md
index b247d819fb..e9af2adee1 100644
--- a/docs/en/user_guides/inference.md
+++ b/docs/en/user_guides/inference.md
@@ -9,6 +9,122 @@ In MMPose, a model is defined by a configuration file and existing model paramet
To start with, we recommend HRNet model with [this configuration file](/configs/body_2d_keypoint/topdown_heatmap/coco/td-hm_hrnet-w32_8xb64-210e_coco-256x192.py) and [this checkpoint file](https://download.openmmlab.com/mmpose/top_down/hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth). It is recommended to download the checkpoint file to `checkpoints` directory.
+## Out-of-the-box inferencer
+
+MMPose offers a comprehensive API for inference, known as `MMPoseInferencer`. This API enables users to perform inference on both images and videos using all the models supported by MMPose. Furthermore, the API provides automatic visualization of inference results and allows for the convenient saving of predictions.
+
+Here is an example of inference on a given image using the pre-trained human pose estimator.
+
+```python
+from mmpose.apis import MMPoseInferencer
+
+img_path = 'tests/data/coco/000000000785.jpg' # you can specify your own picture path
+
+# build the inferencer with model alias
+inferencer = MMPoseInferencer('human')
+
+# The MMPoseInferencer API utilizes a lazy inference strategy,
+# whereby it generates a prediction generator when provided with input
+result_generator = inferencer(img_path, show=True)
+result = next(result_generator)
+```
+
+If everything works fine, you will see the following image in a new window.
+![inferencer_result_coco](https://user-images.githubusercontent.com/26127467/220008302-4a57fd44-0978-408e-8351-600e5513316a.jpg)
+
+The variable `result` is a dictionary that contains two keys, `'visualization'` and `'predictions'`. The key `'visualization'` is intended to contain the visualization results. However, as the `return_vis` argument was not specified, this list remains blank. On the other hand, the key `'predictions'` is a list that contains the estimated keypoints for each individual instance.
+
+### CLI tool
+
+A command-line interface (CLI) tool for the inferencer is also available: `demo/inferencer_demo.py`. This tool enables users to perform inference with the same model and inputs using the following command:
+
+```bash
+python demo/inferencer_demo.py 'tests/data/coco/000000000785.jpg' \
+ --pose2d 'human' --show --pred-out-dir 'predictions'
+```
+
+The predictions will be save in `predictions/000000000785.json`.
+
+### Custom pose estimation models
+
+The inferencer provides several methods that can be used to customize the models employed:
+
+```python
+
+# build the inferencer with model alias
+# the available aliases include 'human', 'hand', 'face' and 'animal'
+inferencer = MMPoseInferencer('human')
+
+# build the inferencer with model config name
+inferencer = MMPoseInferencer('td-hm_hrnet-w32_8xb64-210e_coco-256x192')
+
+# build the inferencer with model config path and checkpoint path/URL
+inferencer = MMPoseInferencer(
+ pose2d='configs/body_2d_keypoint/topdown_heatmap/coco/' \
+ 'td-hm_hrnet-w32_8xb64-210e_coco-256x192.py',
+ pose2d_weights='https://download.openmmlab.com/mmpose/top_down/' \
+ 'hrnet/hrnet_w32_coco_256x192-c78dce93_20200708.pth'
+)
+```
+
+In addition, top-down pose estimators also require an object detection model. The inferencer is capable of inferring the instance type for models trained with datasets supported in MMPose, and subsequently constructing the necessary object detection model. Alternatively, users may also manually specify the detection model using the following methods:
+
+```python
+
+# specify detection model by alias
+# the available aliases include 'human', 'hand', 'face', 'animal',
+# as well as any additional aliases defined in mmdet
+inferencer = MMPoseInferencer(
+ # suppose the pose estimator is trained on custom dataset
+ pose2d='custom_human_pose_estimator.py',
+ pose2d_weights='custom_human_pose_estimator.pth',
+ det_model='human'
+)
+
+# specify detection model with model config name
+inferencer = MMPoseInferencer(
+ pose2d='human',
+ det_model='yolox_l_8x8_300e_coco',
+ det_cat_ids=[0], # the category id of 'human' class
+)
+
+# specify detection model with config path and checkpoint path/URL
+inferencer = MMPoseInferencer(
+ pose2d='human',
+ det_model=f'{PATH_TO_MMDET}/configs/yolox/yolox_l_8x8_300e_coco.py',
+ det_weights='https://download.openmmlab.com/mmdetection/v2.0/' \
+ 'yolox/yolox_l_8x8_300e_coco/' \
+ 'yolox_l_8x8_300e_coco_20211126_140236-d3bd2b23.pth',
+ det_cat_ids=[0], # the category id of 'human' class
+)
+```
+
+### Input format
+
+The inferencer is capable of processing a range of input types, which includes the following:
+
+- A path to an image
+- A path to a video
+- A path to a folder (which will cause all images in that folder to be inferred)
+- An image array
+- A list of image arrays
+- A webcam (in which case the `input` parameter should be set to either `'webcam'` or `'webcam:{CAMERA_ID}'`)
+
+### Output settings
+
+The inferencer is capable of both visualizing and saving predictions. The relevant arguments are as follows:
+
+| Argument | Description |
+| ------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `show` | Determines whether the image or video should be displayed in a pop-up window. |
+| `radius` | Sets the keypoint radius for visualization. |
+| `thickness` | Sets the link thickness for visualization. |
+| `return_vis` | Determines whether visualization images should be included in the results. |
+| `vis_out_dir` | Specifies the folder path for saving the visualization images. If not set, the visualization images will not be saved. |
+| `return_datasample` | Determines whether to return the prediction in the format of `PoseDataSample`. |
+| `pred_out_dir` | Specifies the folder path for saving the predictions. If not set, the predictions will not be saved. |
+| `out_dir` | If `vis_out_dir` or `pred_out_dir` is not set, the values will be set to `f'{out_dir}/visualization'` or `f'{out_dir}/predictions'`, respectively. |
+
## High-level APIs for inference
MMPose provides high-level Python APIs for inference on a given image:
diff --git a/mmpose/apis/__init__.py b/mmpose/apis/__init__.py
index 8534c95e61..ff7149e453 100644
--- a/mmpose/apis/__init__.py
+++ b/mmpose/apis/__init__.py
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_bottomup, inference_topdown, init_model
+from .inferencers import MMPoseInferencer, Pose2DInferencer
-__all__ = ['init_model', 'inference_topdown', 'inference_bottomup']
+__all__ = [
+ 'init_model', 'inference_topdown', 'inference_bottomup',
+ 'Pose2DInferencer', 'MMPoseInferencer'
+]
diff --git a/mmpose/apis/inferencers/__init__.py b/mmpose/apis/inferencers/__init__.py
new file mode 100644
index 0000000000..3c21a02e08
--- /dev/null
+++ b/mmpose/apis/inferencers/__init__.py
@@ -0,0 +1,5 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .mmpose_inferencer import MMPoseInferencer
+from .pose2d_inferencer import Pose2DInferencer
+
+__all__ = ['Pose2DInferencer', 'MMPoseInferencer']
diff --git a/mmpose/apis/inferencers/base_mmpose_inferencer.py b/mmpose/apis/inferencers/base_mmpose_inferencer.py
new file mode 100644
index 0000000000..d99dcc1b68
--- /dev/null
+++ b/mmpose/apis/inferencers/base_mmpose_inferencer.py
@@ -0,0 +1,444 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import mimetypes
+import os
+import shutil
+import tempfile
+import warnings
+from collections import defaultdict
+from typing import (Any, Callable, Dict, Generator, List, Optional, Sequence,
+ Union)
+
+import cv2
+import mmcv
+import mmengine
+import numpy as np
+import torch.nn as nn
+from mmengine.config import Config, ConfigDict
+from mmengine.dataset import Compose
+from mmengine.fileio import (get_file_backend, isdir, join_path,
+ list_dir_or_file)
+from mmengine.infer.infer import BaseInferencer
+from mmengine.runner.checkpoint import _load_checkpoint_to_model
+from mmengine.structures import InstanceData
+
+from mmpose.apis.inference import dataset_meta_from_config
+from mmpose.structures import PoseDataSample, split_instances
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ConfigType = Union[Config, ConfigDict]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+class BaseMMPoseInferencer(BaseInferencer):
+ """The base class for MMPose inferencers."""
+
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr'}
+ forward_kwargs: set = set()
+ visualize_kwargs: set = {
+ 'return_vis',
+ 'show',
+ 'wait_time',
+ 'radius',
+ 'thickness',
+ 'kpt_thr',
+ 'vis_out_dir',
+ }
+ postprocess_kwargs: set = {'pred_out_dir'}
+
+ def _load_weights_to_model(self, model: nn.Module,
+ checkpoint: Optional[dict],
+ cfg: Optional[ConfigType]) -> None:
+ """Loading model weights and meta information from cfg and checkpoint.
+
+ Subclasses could override this method to load extra meta information
+ from ``checkpoint`` and ``cfg`` to model.
+
+ Args:
+ model (nn.Module): Model to load weights and meta information.
+ checkpoint (dict, optional): The loaded checkpoint.
+ cfg (Config or ConfigDict, optional): The loaded config.
+ """
+ if checkpoint is not None:
+ _load_checkpoint_to_model(model, checkpoint)
+ checkpoint_meta = checkpoint.get('meta', {})
+ # save the dataset_meta in the model for convenience
+ if 'dataset_meta' in checkpoint_meta:
+ # mmpose 1.x
+ model.dataset_meta = checkpoint_meta['dataset_meta']
+ else:
+ warnings.warn(
+ 'dataset_meta are not saved in the checkpoint\'s '
+ 'meta data, load via config.')
+ model.dataset_meta = dataset_meta_from_config(
+ cfg, dataset_mode='train')
+ else:
+ warnings.warn('Checkpoint is not loaded, and the inference '
+ 'result is calculated by the randomly initialized '
+ 'model!')
+ model.dataset_meta = dataset_meta_from_config(
+ cfg, dataset_mode='train')
+
+ def _inputs_to_list(self, inputs: InputsType) -> list:
+ """Preprocess the inputs to a list.
+
+ Preprocess inputs to a list according to its type:
+
+ - list or tuple: return inputs
+ - str:
+ - Directory path: return all files in the directory
+ - other cases: return a list containing the string. The string
+ could be a path to file, a url or other types of string
+ according to the task.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+
+ Returns:
+ list: List of input for the :meth:`preprocess`.
+ """
+ self._video_input = False
+
+ if isinstance(inputs, str):
+ backend = get_file_backend(inputs)
+ if hasattr(backend, 'isdir') and isdir(inputs):
+ # Backends like HttpsBackend do not implement `isdir`, so only
+ # those backends that implement `isdir` could accept the
+ # inputs as a directory
+ filepath_list = [
+ join_path(inputs, fname)
+ for fname in list_dir_or_file(inputs, list_dir=False)
+ ]
+ inputs = []
+ for filepath in filepath_list:
+ input_type = mimetypes.guess_type(filepath)[0].split(
+ '/')[0]
+ if input_type == 'image':
+ inputs.append(filepath)
+ inputs.sort()
+ else:
+ # if inputs is a path to a video file, it will be converted
+ # to a list containing separated frame filenames
+ input_type = mimetypes.guess_type(inputs)[0].split('/')[0]
+ if input_type == 'video':
+ self._video_input = True
+ # split video frames into a temporary folder
+ frame_folder = tempfile.TemporaryDirectory()
+ video = mmcv.VideoReader(inputs)
+ self.video_info = dict(
+ fps=video.fps,
+ name=os.path.basename(inputs),
+ frame_folder=frame_folder)
+ video.cvt2frames(frame_folder.name, show_progress=False)
+ frames = sorted(list_dir_or_file(frame_folder.name))
+ inputs = [join_path(frame_folder.name, f) for f in frames]
+
+ if not isinstance(inputs, (list, tuple)):
+ inputs = [inputs]
+
+ return list(inputs)
+
+ def _get_webcam_inputs(self, inputs: str) -> Generator:
+ """Sets up and returns a generator function that reads frames from a
+ webcam input. The generator function returns a new frame each time it
+ is iterated over.
+
+ Args:
+ inputs (str): A string describing the webcam input, in the format
+ "webcam:id".
+
+ Returns:
+ A generator function that yields frames from the webcam input.
+
+ Raises:
+ ValueError: If the inputs string is not in the expected format.
+ """
+
+ # Ensure the inputs string is in the expected format.
+ inputs = inputs.lower()
+ assert inputs.startswith('webcam'), f'Expected input to start with ' \
+ f'"webcam", but got "{inputs}"'
+
+ # Parse the camera ID from the inputs string.
+ inputs_ = inputs.split(':')
+ if len(inputs_) == 1:
+ camera_id = 0
+ elif len(inputs_) == 2 and str.isdigit(inputs_[1]):
+ camera_id = int(inputs_[1])
+ else:
+ raise ValueError(
+ f'Expected webcam input to have format "webcam:id", '
+ f'but got "{inputs}"')
+
+ # Attempt to open the video capture object.
+ vcap = cv2.VideoCapture(camera_id)
+ if not vcap.isOpened():
+ warnings.warn(f'Cannot open camera (ID={camera_id})')
+ return []
+
+ # Set video input flag and metadata.
+ self._video_input = True
+ self.video_info = dict(fps=10, name='webcam.mp4', frame_folder=None)
+
+ # Set up webcam reader generator function.
+ self._window_closing = False
+
+ def _webcam_reader() -> Generator:
+ while True:
+ if self._window_closing:
+ vcap.release()
+ break
+
+ ret_val, frame = vcap.read()
+ if not ret_val:
+ break
+
+ yield frame
+
+ return _webcam_reader()
+
+ def _visualization_window_on_close(self, event):
+ self._window_closing = True
+
+ def _init_pipeline(self, cfg: ConfigType) -> Callable:
+ """Initialize the test pipeline.
+
+ Args:
+ cfg (ConfigType): model config path or dict
+
+ Returns:
+ A pipeline to handle various input data, such as ``str``,
+ ``np.ndarray``. The returned pipeline will be used to process
+ a single data.
+ """
+ return Compose(cfg.test_dataloader.dataset.pipeline)
+
+ def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
+ """Process the inputs into a model-feedable format.
+
+ Args:
+ inputs (InputsType): Inputs given by user.
+ batch_size (int): batch size. Defaults to 1.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
+ List[str or np.ndarray]: List of original inputs in the batch
+ """
+
+ for i, input in enumerate(inputs):
+ data_infos = self.preprocess_single(input, index=i, **kwargs)
+ # only supports inference with batch size 1
+ yield self.collate_fn(data_infos), [input]
+
+ def visualize(self,
+ inputs: list,
+ preds: List[PoseDataSample],
+ return_vis: bool = False,
+ show: bool = False,
+ wait_time: float = 0,
+ radius: int = 3,
+ thickness: int = 1,
+ kpt_thr: float = 0.3,
+ vis_out_dir: str = '',
+ window_name: str = '',
+ window_close_event_handler: Optional[Callable] = None
+ ) -> List[np.ndarray]:
+ """Visualize predictions.
+
+ Args:
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+ preds (Any): Predictions of the model.
+ return_vis (bool): Whether to return images with predicted results.
+ show (bool): Whether to display the image in a popup window.
+ Defaults to False.
+ wait_time (float): The interval of show (ms). Defaults to 0
+ radius (int): Keypoint radius for visualization. Defaults to 3
+ thickness (int): Link thickness for visualization. Defaults to 1
+ kpt_thr (float): The threshold to visualize the keypoints.
+ Defaults to 0.3
+ vis_out_dir (str, optional): Directory to save visualization
+ results w/o predictions. If left as empty, no file will
+ be saved. Defaults to ''.
+ window_name (str, optional): Title of display window.
+ window_close_event_handler (callable, optional):
+
+ Returns:
+ List[np.ndarray]: Visualization results.
+ """
+ if (not return_vis) and (not show) and (not vis_out_dir):
+ return
+
+ if getattr(self, 'visualizer', None) is None:
+ raise ValueError('Visualization needs the "visualizer" term'
+ 'defined in the config, but got None.')
+
+ self.visualizer.radius = radius
+ self.visualizer.line_width = thickness
+
+ results = []
+
+ for single_input, pred in zip(inputs, preds):
+ if isinstance(single_input, str):
+ img = mmcv.imread(single_input, channel_order='rgb')
+ elif isinstance(single_input, np.ndarray):
+ img = mmcv.bgr2rgb(single_input.copy())
+ else:
+ raise ValueError('Unsupported input type: '
+ f'{type(single_input)}')
+
+ img_name = os.path.basename(pred.metainfo['img_path'])
+
+ if vis_out_dir:
+ if self._video_input:
+ out_file = join_path(vis_out_dir, 'vis_frames', img_name)
+ else:
+ out_file = join_path(vis_out_dir, img_name)
+ else:
+ out_file = None
+
+ # since visualization and inference utilize the same process,
+ # the wait time is reduced when a video input is utilized,
+ # thereby eliminating the issue of inference getting stuck.
+ wait_time = 1e-5 if self._video_input else wait_time
+
+ window_name = window_name if window_name else img_name
+
+ visualization = self.visualizer.add_datasample(
+ window_name,
+ img,
+ pred,
+ draw_gt=False,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file,
+ kpt_score_thr=kpt_thr)
+ results.append(visualization)
+
+ if show and not hasattr(self, '_window_close_cid'):
+ if window_close_event_handler is None:
+ window_close_event_handler = \
+ self._visualization_window_on_close
+ self._window_close_cid = \
+ self.visualizer.manager.canvas.mpl_connect(
+ 'close_event',
+ window_close_event_handler
+ )
+
+ if return_vis:
+ return results
+ else:
+ return []
+
+ def postprocess(
+ self,
+ preds: List[PoseDataSample],
+ visualization: List[np.ndarray],
+ return_datasample=False,
+ pred_out_dir: str = '',
+ ) -> dict:
+ """Process the predictions and visualization results from ``forward``
+ and ``visualize``.
+
+ This method should be responsible for the following tasks:
+
+ 1. Convert datasamples into a json-serializable dict if needed.
+ 2. Pack the predictions and visualization results and return them.
+ 3. Dump or log the predictions.
+
+ Args:
+ preds (List[Dict]): Predictions of the model.
+ visualization (np.ndarray): Visualized predictions.
+ return_datasample (bool): Whether to return results as
+ datasamples. Defaults to False.
+ pred_out_dir (str): Directory to save the inference results w/o
+ visualization. If left as empty, no file will be saved.
+ Defaults to ''.
+
+ Returns:
+ dict: Inference and visualization results with key ``predictions``
+ and ``visualization``
+
+ - ``visualization (Any)``: Returned by :meth:`visualize`
+ - ``predictions`` (dict or DataSample): Returned by
+ :meth:`forward` and processed in :meth:`postprocess`.
+ If ``return_datasample=False``, it usually should be a
+ json-serializable dict containing only basic data elements such
+ as strings and numbers.
+ """
+
+ result_dict = defaultdict(list)
+
+ result_dict['visualization'] = visualization
+ for pred in preds:
+ if not return_datasample:
+ # convert datasamples to list of instance predictions
+ pred = split_instances(pred.pred_instances)
+ result_dict['predictions'].append(pred)
+
+ if pred_out_dir != '':
+ if self._video_input:
+ pred_out_dir = join_path(pred_out_dir, 'pred_frames')
+
+ for pred, data_sample in zip(result_dict['predictions'], preds):
+ fname = os.path.splitext(
+ os.path.basename(
+ data_sample.metainfo['img_path']))[0] + '.json'
+ mmengine.dump(
+ pred, join_path(pred_out_dir, fname), indent=' ')
+
+ return result_dict
+
+ def _merge_outputs(self, vis_out_dir: str, pred_out_dir: str,
+ **kwargs: Dict[str, Any]) -> None:
+ """Merge the visualized frames and predicted instance outputs and save
+ them.
+
+ Args:
+ vis_out_dir (str): Path to the directory where the visualized
+ frames are saved.
+ pred_out_dir (str): Path to the directory where the predicted
+ instance outputs are saved.
+ **kwargs: Other arguments that are not used in this method.
+ """
+ assert self._video_input
+
+ if vis_out_dir != '':
+ vis_frame_out_dir = join_path(vis_out_dir, 'vis_frames')
+ if not isdir(vis_frame_out_dir) or len(
+ os.listdir(vis_frame_out_dir)) == 0:
+ warnings.warn(
+ f'{vis_frame_out_dir} does not exist or is empty.')
+ else:
+ mmcv.frames2video(
+ vis_frame_out_dir,
+ join_path(vis_out_dir, self.video_info['name']),
+ fps=self.video_info['fps'],
+ fourcc='mp4v',
+ show_progress=False)
+ shutil.rmtree(vis_frame_out_dir)
+
+ if pred_out_dir != '':
+ pred_frame_out_dir = join_path(pred_out_dir, 'pred_frames')
+ if not isdir(pred_frame_out_dir) or len(
+ os.listdir(pred_frame_out_dir)) == 0:
+ warnings.warn(
+ f'{pred_frame_out_dir} does not exist or is empty.')
+ else:
+ predictions = []
+ pred_files = list_dir_or_file(pred_frame_out_dir)
+ for frame_id, pred_file in enumerate(sorted(pred_files)):
+ predictions.append({
+ 'frame_id':
+ frame_id,
+ 'instances':
+ mmengine.load(
+ join_path(pred_frame_out_dir, pred_file))
+ })
+ fname = os.path.splitext(
+ os.path.basename(self.video_info['name']))[0] + '.json'
+ mmengine.dump(
+ predictions, join_path(pred_out_dir, fname), indent=' ')
+ shutil.rmtree(pred_frame_out_dir)
diff --git a/mmpose/apis/inferencers/mmpose_inferencer.py b/mmpose/apis/inferencers/mmpose_inferencer.py
new file mode 100644
index 0000000000..f5b23fb125
--- /dev/null
+++ b/mmpose/apis/inferencers/mmpose_inferencer.py
@@ -0,0 +1,275 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Dict, List, Optional, Sequence, Union
+
+import numpy as np
+from mmengine.config import Config, ConfigDict
+from mmengine.fileio import join_path
+from mmengine.infer.infer import ModelType
+from mmengine.structures import InstanceData
+from rich.progress import track
+
+from mmpose.structures import PoseDataSample
+from .base_mmpose_inferencer import BaseMMPoseInferencer
+from .pose2d_inferencer import Pose2DInferencer
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ConfigType = Union[Config, ConfigDict]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+class MMPoseInferencer(BaseMMPoseInferencer):
+ """MMPose Inferencer. It's a unified inferencer interface for pose
+ estimation task, currently including: Pose2D. and it can be used to perform
+ 2D keypoint detection.
+
+ Args:
+ pose2d (str, optional): Pretrained 2D pose estimation algorithm.
+ It's the path to the config file or the model name defined in
+ metafile. For example, it could be:
+
+ - model alias, e.g. ``'body'``,
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
+ - config path
+
+ Defaults to ``None``.
+ pose2d_weights (str, optional): Path to the custom checkpoint file of
+ the selected pose2d model. If it is not specified and "pose2d" is
+ a model name of metafile, the weights will be loaded from
+ metafile. Defaults to None.
+ device (str, optional): Device to run inference. If None, the
+ available device will be automatically used. Defaults to None.
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
+ det_model(str, optional): Config path or alias of detection model.
+ Defaults to None.
+ det_weights(str, optional): Path to the checkpoints of detection
+ model. Defaults to None.
+ det_cat_ids(int or list[int], optional): Category id for
+ detection model. Defaults to None.
+ """
+
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr'}
+ forward_kwargs: set = set()
+ visualize_kwargs: set = {
+ 'return_vis',
+ 'show',
+ 'wait_time',
+ 'radius',
+ 'thickness',
+ 'kpt_thr',
+ 'vis_out_dir',
+ }
+ postprocess_kwargs: set = {'pred_out_dir'}
+
+ def __init__(self,
+ pose2d: Optional[str] = None,
+ pose2d_weights: Optional[str] = None,
+ device: Optional[str] = None,
+ scope: str = 'mmpose',
+ det_model: Optional[Union[ModelType, str]] = None,
+ det_weights: Optional[str] = None,
+ det_cat_ids: Optional[Union[int, List]] = None) -> None:
+
+ if pose2d is None:
+ raise ValueError('2d pose estimation algorithm should provided.')
+
+ self.visualizer = None
+ if pose2d is not None:
+ self.pose2d_inferencer = Pose2DInferencer(pose2d, pose2d_weights,
+ device, scope, det_model,
+ det_weights, det_cat_ids)
+ self.mode = 'pose2d'
+
+ def preprocess(self, inputs: InputsType, batch_size: int = 1, **kwargs):
+ """Process the inputs into a model-feedable format.
+
+ Args:
+ inputs (InputsType): Inputs given by user.
+ batch_size (int): batch size. Defaults to 1.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
+ List[str or np.ndarray]: List of original inputs in the batch
+ """
+
+ for i, input in enumerate(inputs):
+ data_batch = {}
+ if 'pose2d' in self.mode:
+ data_infos = self.pose2d_inferencer.preprocess_single(
+ input, index=i, **kwargs)
+ data_batch['pose2d'] = self.pose2d_inferencer.collate_fn(
+ data_infos)
+ # only supports inference with batch size 1
+ yield data_batch, [input]
+
+ def forward(self, inputs: InputType, **forward_kwargs) -> PredType:
+ """Forward the inputs to the model.
+
+ Args:
+ inputs (InputsType): The inputs to be forwarded.
+
+ Returns:
+ Dict: The prediction results. Possibly with keys "pose2d".
+ """
+ result = {}
+ if self.mode == 'pose2d':
+ data_samples = self.pose2d_inferencer.forward(
+ inputs['pose2d'], **forward_kwargs)
+ result['pose2d'] = data_samples
+
+ return result
+
+ def __call__(
+ self,
+ inputs: InputsType,
+ return_datasample: bool = False,
+ batch_size: int = 1,
+ out_dir: Optional[str] = None,
+ **kwargs,
+ ) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+ return_datasample (bool): Whether to return results as
+ :obj:`BaseDataElement`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ out_dir (str, optional): directory to save visualization
+ results and predictions. Will be overoden if vis_out_dir or
+ pred_out_dir are given. Defaults to None
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+ Each key in kwargs should be in the corresponding set of
+ ``preprocess_kwargs``, ``forward_kwargs``,
+ ``visualize_kwargs`` and ``postprocess_kwargs``.
+
+ Returns:
+ dict: Inference and visualization results.
+ """
+ if out_dir is not None:
+ if 'vis_out_dir' not in kwargs:
+ kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
+ if 'pred_out_dir' not in kwargs:
+ kwargs['pred_out_dir'] = f'{out_dir}/predictions'
+ (
+ preprocess_kwargs,
+ forward_kwargs,
+ visualize_kwargs,
+ postprocess_kwargs,
+ ) = self._dispatch_kwargs(**kwargs)
+
+ # preprocessing
+ if isinstance(inputs, str) and inputs.startswith('webcam'):
+ inputs = self._get_webcam_inputs(inputs)
+ batch_size = 1
+ if not visualize_kwargs.get('show', False):
+ warnings.warn('The display mode is closed when using webcam '
+ 'input. It will be turned on automatically.')
+ visualize_kwargs['show'] = True
+ else:
+ inputs = self._inputs_to_list(inputs)
+
+ inputs = self.preprocess(
+ inputs, batch_size=batch_size, **preprocess_kwargs)
+
+ preds = []
+ if 'pose2d' not in self.mode or not hasattr(self.pose2d_inferencer,
+ 'detector'):
+ inputs = track(inputs, description='Inference')
+
+ for proc_inputs, ori_inputs in inputs:
+ preds = self.forward(proc_inputs, **forward_kwargs)
+
+ visualization = self.visualize(ori_inputs, preds,
+ **visualize_kwargs)
+ results = self.postprocess(preds, visualization, return_datasample,
+ **postprocess_kwargs)
+ yield results
+
+ # merge visualization and prediction results
+ if self._video_input:
+ self._merge_outputs(**visualize_kwargs, **postprocess_kwargs)
+
+ def visualize(self, inputs: InputsType, preds: PredType,
+ **kwargs) -> List[np.ndarray]:
+ """Visualize predictions.
+
+ Args:
+ inputs (list): Inputs preprocessed by :meth:`_inputs_to_list`.
+ preds (Any): Predictions of the model.
+ return_vis (bool): Whether to return images with predicted results.
+ show (bool): Whether to display the image in a popup window.
+ Defaults to False.
+ show_interval (int): The interval of show (s). Defaults to 0
+ radius (int): Keypoint radius for visualization. Defaults to 3
+ thickness (int): Link thickness for visualization. Defaults to 1
+ kpt_thr (float): The threshold to visualize the keypoints.
+ Defaults to 0.3
+ vis_out_dir (str, optional): directory to save visualization
+ results w/o predictions. If left as empty, no file will
+ be saved. Defaults to ''.
+
+ Returns:
+ List[np.ndarray]: Visualization results.
+ """
+
+ if 'pose2d' in self.mode:
+ window_name = ''
+ if self._video_input:
+ window_name = self.video_info['name']
+ if kwargs.get('vis_out_dir', ''):
+ kwargs['vis_out_dir'] = join_path(kwargs['vis_out_dir'],
+ 'vis_frames')
+ if kwargs.get('show', False):
+ kwargs['wait_time'] = 1e-5
+ return self.pose2d_inferencer.visualize(
+ inputs,
+ preds['pose2d'],
+ window_name=window_name,
+ window_close_event_handler=self._visualization_window_on_close,
+ **kwargs)
+
+ def postprocess(
+ self,
+ preds: List[PoseDataSample],
+ visualization: List[np.ndarray],
+ return_datasample=False,
+ pred_out_dir: str = '',
+ ) -> dict:
+ """Process the predictions and visualization results from ``forward``
+ and ``visualize``.
+
+ This method should be responsible for the following tasks:
+
+ 1. Convert datasamples into a json-serializable dict if needed.
+ 2. Pack the predictions and visualization results and return them.
+ 3. Dump or log the predictions.
+
+ Args:
+ preds (List[Dict]): Predictions of the model.
+ visualization (np.ndarray): Visualized predictions.
+ return_datasample (bool): Whether to return results as
+ datasamples. Defaults to False.
+ pred_out_dir (str): Directory to save the inference results w/o
+ visualization. If left as empty, no file will be saved.
+ Defaults to ''.
+
+ Returns:
+ dict: Inference and visualization results with key ``predictions``
+ and ``visualization``
+
+ - ``visualization (Any)``: Returned by :meth:`visualize`
+ - ``predictions`` (dict or DataSample): Returned by
+ :meth:`forward` and processed in :meth:`postprocess`.
+ If ``return_datasample=False``, it usually should be a
+ json-serializable dict containing only basic data elements such
+ as strings and numbers.
+ """
+
+ if 'pose2d' in self.mode:
+ return super().postprocess(preds['pose2d'], visualization,
+ return_datasample, pred_out_dir)
diff --git a/mmpose/apis/inferencers/pose2d_inferencer.py b/mmpose/apis/inferencers/pose2d_inferencer.py
new file mode 100644
index 0000000000..30ebd7c711
--- /dev/null
+++ b/mmpose/apis/inferencers/pose2d_inferencer.py
@@ -0,0 +1,246 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import warnings
+from typing import Dict, List, Optional, Sequence, Tuple, Union
+
+import mmcv
+import numpy as np
+from mmdet.apis.det_inferencer import DetInferencer
+from mmengine.config import Config, ConfigDict
+from mmengine.infer.infer import ModelType
+from mmengine.registry import init_default_scope
+from mmengine.structures import InstanceData
+from rich.progress import track
+
+from mmpose.evaluation.functional import nms
+from mmpose.registry import DATASETS, INFERENCERS
+from mmpose.structures import merge_data_samples
+from .base_mmpose_inferencer import BaseMMPoseInferencer
+from .utils import default_det_models
+
+InstanceList = List[InstanceData]
+InputType = Union[str, np.ndarray]
+InputsType = Union[InputType, Sequence[InputType]]
+PredType = Union[InstanceData, InstanceList]
+ImgType = Union[np.ndarray, Sequence[np.ndarray]]
+ConfigType = Union[Config, ConfigDict]
+ResType = Union[Dict, List[Dict], InstanceData, List[InstanceData]]
+
+
+@INFERENCERS.register_module(name='pose-estimation')
+@INFERENCERS.register_module()
+class Pose2DInferencer(BaseMMPoseInferencer):
+ """The inferencer for 2D pose estimation.
+
+ Args:
+ model (str, optional): Pretrained 2D pose estimation algorithm.
+ It's the path to the config file or the model name defined in
+ metafile. For example, it could be:
+
+ - model alias, e.g. ``'body'``,
+ - config name, e.g. ``'simcc_res50_8xb64-210e_coco-256x192'``,
+ - config path
+
+ Defaults to ``None``.
+ weights (str, optional): Path to the checkpoint. If it is not
+ specified and "model" is a model name of metafile, the weights
+ will be loaded from metafile. Defaults to None.
+ device (str, optional): Device to run inference. If None, the
+ available device will be automatically used. Defaults to None.
+ scope (str, optional): The scope of the model. Defaults to "mmpose".
+ det_model(str, optional): Config path or alias of detection model.
+ Defaults to None.
+ det_weights(str, optional): Path to the checkpoints of detection
+ model. Defaults to None.
+ det_cat_ids(int or list[int], optional): Category id for
+ detection model. Defaults to None.
+ """
+
+ preprocess_kwargs: set = {'bbox_thr', 'nms_thr'}
+ forward_kwargs: set = set()
+ visualize_kwargs: set = {
+ 'return_vis',
+ 'show',
+ 'wait_time',
+ 'radius',
+ 'thickness',
+ 'kpt_thr',
+ 'vis_out_dir',
+ }
+ postprocess_kwargs: set = {'pred_out_dir'}
+
+ def __init__(self,
+ model: Union[ModelType, str],
+ weights: Optional[str] = None,
+ device: Optional[str] = None,
+ scope: Optional[str] = 'mmpose',
+ det_model: Optional[Union[ModelType, str]] = None,
+ det_weights: Optional[str] = None,
+ det_cat_ids: Optional[Union[int, Tuple]] = None) -> None:
+
+ init_default_scope(scope)
+ super().__init__(
+ model=model, weights=weights, device=device, scope=scope)
+
+ # assign dataset metainfo to self.visualizer
+ self.visualizer.set_dataset_meta(self.model.dataset_meta)
+
+ # initialize detector for top-down models
+ if self.cfg.data_mode == 'topdown':
+ if det_model is None:
+ det_model = DATASETS.get(
+ self.cfg.dataset_type).__module__.split(
+ 'datasets.')[-1].split('.')[0].lower()
+ det_info = default_det_models[det_model]
+ det_model, det_weights, det_cat_ids = det_info[
+ 'model'], det_info['weights'], det_info['cat_ids']
+
+ self.detector = DetInferencer(
+ det_model, det_weights, device=device)
+ if isinstance(det_cat_ids, (tuple, list)):
+ self.det_cat_ids = det_cat_ids
+ else:
+ self.det_cat_ids = (det_cat_ids, )
+
+ self._video_input = False
+
+ def preprocess_single(self,
+ input: InputType,
+ index: int,
+ bbox_thr: float = 0.3,
+ nms_thr: float = 0.3):
+ """Process a single input into a model-feedable format.
+
+ Args:
+ input (InputType): Input given by user.
+ index (int): index of the input
+ bbox_thr (float): threshold for bounding box detection.
+ Defaults to 0.3.
+ nms_thr (float): IoU threshold for bounding box NMS.
+ Defaults to 0.3.
+
+ Yields:
+ Any: Data processed by the ``pipeline`` and ``collate_fn``.
+ """
+
+ if isinstance(input, str):
+ data_info = dict(img_path=input)
+ else:
+ data_info = dict(img=input, img_path=f'{index}.jpg'.rjust(10, '0'))
+ data_info.update(self.model.dataset_meta)
+
+ if self.cfg.data_mode == 'topdown':
+ det_results = self.detector(
+ input, return_datasample=True)['predictions']
+ pred_instance = det_results[0].pred_instances.cpu().numpy()
+ bboxes = np.concatenate(
+ (pred_instance.bboxes, pred_instance.scores[:, None]), axis=1)
+
+ label_mask = np.zeros(len(bboxes), dtype=np.uint8)
+ for cat_id in self.det_cat_ids:
+ label_mask = np.logical_or(label_mask,
+ pred_instance.labels == cat_id)
+
+ bboxes = bboxes[np.logical_and(label_mask,
+ pred_instance.scores > bbox_thr)]
+ bboxes = bboxes[nms(bboxes, nms_thr)]
+
+ data_infos = []
+ if len(bboxes) > 0:
+ for bbox in bboxes:
+ inst = data_info.copy()
+ inst['bbox'] = bbox[None, :4]
+ inst['bbox_score'] = bbox[4:5]
+ data_infos.append(self.pipeline(inst))
+ else:
+ inst = data_info.copy()
+
+ # get bbox from the image size
+ if isinstance(input, str):
+ input = mmcv.imread(input)
+ h, w = input.shape[:2]
+
+ inst['bbox'] = np.array([[0, 0, w, h]], dtype=np.float32)
+ inst['bbox_score'] = np.ones(1, dtype=np.float32)
+ data_infos.append(self.pipeline(inst))
+
+ else: # bottom-up
+ data_infos = [self.pipeline(data_info)]
+
+ return data_infos
+
+ def forward(self, inputs: Union[dict, tuple]):
+ data_samples = super().forward(inputs)
+ if self.cfg.data_mode == 'topdown':
+ data_samples = [merge_data_samples(data_samples)]
+ return data_samples
+
+ def __call__(
+ self,
+ inputs: InputsType,
+ return_datasample: bool = False,
+ batch_size: int = 1,
+ out_dir: Optional[str] = None,
+ **kwargs,
+ ) -> dict:
+ """Call the inferencer.
+
+ Args:
+ inputs (InputsType): Inputs for the inferencer.
+ return_datasample (bool): Whether to return results as
+ :obj:`BaseDataElement`. Defaults to False.
+ batch_size (int): Batch size. Defaults to 1.
+ out_dir (str, optional): directory to save visualization
+ results and predictions. Will be overoden if vis_out_dir or
+ pred_out_dir are given. Defaults to None
+ **kwargs: Key words arguments passed to :meth:`preprocess`,
+ :meth:`forward`, :meth:`visualize` and :meth:`postprocess`.
+ Each key in kwargs should be in the corresponding set of
+ ``preprocess_kwargs``, ``forward_kwargs``,
+ ``visualize_kwargs`` and ``postprocess_kwargs``.
+
+ Returns:
+ dict: Inference and visualization results.
+ """
+ if out_dir is not None:
+ if 'vis_out_dir' not in kwargs:
+ kwargs['vis_out_dir'] = f'{out_dir}/visualizations'
+ if 'pred_out_dir' not in kwargs:
+ kwargs['pred_out_dir'] = f'{out_dir}/predictions'
+
+ (
+ preprocess_kwargs,
+ forward_kwargs,
+ visualize_kwargs,
+ postprocess_kwargs,
+ ) = self._dispatch_kwargs(**kwargs)
+
+ # preprocessing
+ if isinstance(inputs, str) and inputs.startswith('webcam'):
+ inputs = self._get_webcam_inputs(inputs)
+ batch_size = 1
+ if not visualize_kwargs.get('show', False):
+ warnings.warn('The display mode is closed when using webcam '
+ 'input. It will be turned on automatically.')
+ visualize_kwargs['show'] = True
+ else:
+ inputs = self._inputs_to_list(inputs)
+
+ inputs = self.preprocess(
+ inputs, batch_size=batch_size, **preprocess_kwargs)
+
+ preds = []
+ if not hasattr(self, 'detector'):
+ inputs = track(inputs, description='Inference')
+
+ for proc_inputs, ori_inputs in inputs:
+ preds = self.forward(proc_inputs, **forward_kwargs)
+
+ visualization = self.visualize(ori_inputs, preds,
+ **visualize_kwargs)
+ results = self.postprocess(preds, visualization, return_datasample,
+ **postprocess_kwargs)
+ yield results
+
+ # merge visualization and prediction results
+ if self._video_input:
+ self._merge_outputs(**visualize_kwargs, **postprocess_kwargs)
diff --git a/mmpose/apis/inferencers/utils/__init__.py b/mmpose/apis/inferencers/utils/__init__.py
new file mode 100644
index 0000000000..e43e7b6734
--- /dev/null
+++ b/mmpose/apis/inferencers/utils/__init__.py
@@ -0,0 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from .default_det_models import default_det_models
+
+__all__ = ['default_det_models']
diff --git a/mmpose/apis/inferencers/utils/default_det_models.py b/mmpose/apis/inferencers/utils/default_det_models.py
new file mode 100644
index 0000000000..96fda2cf14
--- /dev/null
+++ b/mmpose/apis/inferencers/utils/default_det_models.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os.path as osp
+
+from mmengine.config.utils import MODULE2PACKAGE
+from mmengine.utils import get_installed_path
+
+mmpose_path = get_installed_path(MODULE2PACKAGE['mmpose'])
+
+default_det_models = dict(
+ human=dict(
+ model=osp.join(mmpose_path, '.mim',
+ 'demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'),
+ weights='https://download.openmmlab.com/mmdetection/v2.0/'
+ 'faster_rcnn/faster_rcnn_r50_fpn_1x_coco/'
+ 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
+ cat_ids=(0, )),
+ face=dict(
+ model=osp.join(mmpose_path, '.mim',
+ 'demo/mmdetection_cfg/yolox-s_8xb8-300e_coco-face.py'),
+ weights='https://download.openmmlab.com/mmpose/mmdet_pretrained/'
+ 'yolo-x_8xb8-300e_coco-face_13274d7c.pth',
+ cat_ids=(0, )),
+ hand=dict(
+ model=osp.join(
+ mmpose_path, '.mim',
+ 'demo/mmdetection_cfg/cascade_rcnn_x101_64x4d_fpn_1class.py'),
+ weights='https://download.openmmlab.com/mmpose/mmdet_pretrained/'
+ 'cascade_rcnn_x101_64x4d_fpn_20e_onehand10k-dac19597_20201030.pth',
+ cat_ids=(0, )),
+ animal=dict(
+ model=osp.join(mmpose_path, '.mim',
+ 'demo/mmdetection_cfg/faster_rcnn_r50_fpn_coco.py'),
+ weights='https://download.openmmlab.com/mmdetection/v2.0/'
+ 'faster_rcnn/faster_rcnn_r50_fpn_1x_coco/'
+ 'faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth',
+ cat_ids=(15, 16, 17, 18, 19, 20, 21, 22, 23)),
+)
+
+default_det_models['body'] = default_det_models['human']
+default_det_models['wholebody'] = default_det_models['human']
diff --git a/mmpose/registry.py b/mmpose/registry.py
index 78bd75c64b..f1c080565f 100644
--- a/mmpose/registry.py
+++ b/mmpose/registry.py
@@ -11,6 +11,7 @@
from mmengine.registry import DATASETS as MMENGINE_DATASETS
from mmengine.registry import EVALUATOR as MMENGINE_EVALUATOR
from mmengine.registry import HOOKS as MMENGINE_HOOKS
+from mmengine.registry import INFERENCERS as MMENGINE_INFERENCERS
from mmengine.registry import LOG_PROCESSORS as MMENGINE_LOG_PROCESSORS
from mmengine.registry import LOOPS as MMENGINE_LOOPS
from mmengine.registry import METRICS as MMENGINE_METRICS
@@ -125,3 +126,9 @@
# manager keypoint encoder/decoder
KEYPOINT_CODECS = Registry('KEYPOINT_CODECS', locations=['mmpose.codecs'])
+
+# manage inferencer
+INFERENCERS = Registry(
+ 'inferencer',
+ parent=MMENGINE_INFERENCERS,
+ locations=['mmpose.apis.inferencers'])
diff --git a/mmpose/visualization/local_visualizer.py b/mmpose/visualization/local_visualizer.py
index 545ac53be3..d4030be820 100644
--- a/mmpose/visualization/local_visualizer.py
+++ b/mmpose/visualization/local_visualizer.py
@@ -506,3 +506,5 @@ def add_datasample(self,
else:
# save drawn_img to backends
self.add_image(name, drawn_img, step)
+
+ return self.get_image()
diff --git a/model-index.yml b/model-index.yml
index 12690e94ef..d961b808aa 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -6,3 +6,6 @@ Import:
- configs/body_2d_keypoint/simcc/coco/resnet_coco.yml
- configs/body_2d_keypoint/simcc/coco/mobilenetv2_coco.yml
- configs/body_2d_keypoint/simcc/coco/vipnas_coco.yml
+- configs/animal_2d_keypoint/topdown_heatmap/ap10k/resnet_ap10k.yml
+- configs/face_2d_keypoint/topdown_heatmap/wflw/hrnetv2_wflw.yml
+- configs/hand_2d_keypoint/topdown_heatmap/onehand10k/resnet_onehand10k.yml
diff --git a/tests/test_apis/test_inferencers/test_mmpose_inferencer.py b/tests/test_apis/test_inferencers/test_mmpose_inferencer.py
new file mode 100644
index 0000000000..3df85fc46e
--- /dev/null
+++ b/tests/test_apis/test_inferencers/test_mmpose_inferencer.py
@@ -0,0 +1,74 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from collections import defaultdict
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+
+import mmcv
+
+from mmpose.apis.inferencers import MMPoseInferencer
+from mmpose.structures import PoseDataSample
+
+
+class TestMMPoseInferencer(TestCase):
+
+ def test_call(self):
+
+ # top-down model
+ inferencer = MMPoseInferencer('human')
+
+ img_path = 'tests/data/coco/000000197388.jpg'
+ img = mmcv.imread(img_path)
+
+ # `inputs` is path to an image
+ inputs = img_path
+ results1 = next(inferencer(inputs, return_vis=True))
+ self.assertIn('visualization', results1)
+ self.assertSequenceEqual(results1['visualization'][0].shape, img.shape)
+ self.assertIn('predictions', results1)
+ self.assertIn('keypoints', results1['predictions'][0][0])
+ self.assertEqual(len(results1['predictions'][0][0]['keypoints']), 17)
+
+ # `inputs` is an image array
+ inputs = img
+ results2 = next(inferencer(inputs))
+ self.assertEqual(
+ len(results1['predictions'][0]), len(results2['predictions'][0]))
+ self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
+ results2['predictions'][0][0]['keypoints'])
+ results2 = next(inferencer(inputs, return_datasample=True))
+ self.assertIsInstance(results2['predictions'][0], PoseDataSample)
+
+ # `inputs` is path to a directory
+ inputs = osp.dirname(img_path)
+ with TemporaryDirectory() as tmp_dir:
+ # only save visualizations
+ for res in inferencer(inputs, vis_out_dir=tmp_dir):
+ pass
+ self.assertEqual(len(os.listdir(tmp_dir)), 4)
+ # save both visualizations and predictions
+ results3 = defaultdict(list)
+ for res in inferencer(inputs, out_dir=tmp_dir):
+ for key in res:
+ results3[key].extend(res[key])
+ self.assertEqual(len(os.listdir(f'{tmp_dir}/visualizations')), 4)
+ self.assertEqual(len(os.listdir(f'{tmp_dir}/predictions')), 4)
+ self.assertEqual(len(results3['predictions']), 4)
+ self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
+ results3['predictions'][3][0]['keypoints'])
+
+ # `inputs` is path to a video
+ inputs = 'tests/data/posetrack18/videos/000001_mpiinew_test/' \
+ '000001_mpiinew_test.mp4'
+ with TemporaryDirectory() as tmp_dir:
+ results = defaultdict(list)
+ for res in inferencer(inputs, out_dir=tmp_dir):
+ for key in res:
+ results[key].extend(res[key])
+ self.assertIn('000001_mpiinew_test.mp4',
+ os.listdir(f'{tmp_dir}/visualizations'))
+ self.assertIn('000001_mpiinew_test.json',
+ os.listdir(f'{tmp_dir}/predictions'))
+ self.assertTrue(inferencer._video_input)
+ self.assertIn(len(results['predictions']), (4, 5))
diff --git a/tests/test_apis/test_inferencers/test_pose2d_inferencer.py b/tests/test_apis/test_inferencers/test_pose2d_inferencer.py
new file mode 100644
index 0000000000..f402d05d19
--- /dev/null
+++ b/tests/test_apis/test_inferencers/test_pose2d_inferencer.py
@@ -0,0 +1,119 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import os
+import os.path as osp
+from collections import defaultdict
+from tempfile import TemporaryDirectory
+from unittest import TestCase
+
+import mmcv
+import torch
+from mmengine.infer.infer import BaseInferencer
+
+from mmpose.apis.inferencers import Pose2DInferencer
+from mmpose.structures import PoseDataSample
+
+
+class TestPose2DInferencer(TestCase):
+
+ def _test_init(self):
+
+ # 1. init with config path and checkpoint
+ inferencer = Pose2DInferencer(
+ model='configs/body_2d_keypoint/simcc/coco/'
+ 'simcc_res50_8xb64-210e_coco-256x192.py',
+ weights='https://download.openmmlab.com/mmpose/'
+ 'v1/body_2d_keypoint/simcc/coco/'
+ 'simcc_res50_8xb64-210e_coco-256x192-8e0f5b59_20220919.pth',
+ )
+ self.assertIsInstance(inferencer.model, torch.nn.Module)
+ self.assertIsInstance(inferencer.detector, BaseInferencer)
+ self.assertSequenceEqual(inferencer.det_cat_ids, (0, ))
+
+ # 2. init with config name
+ inferencer = Pose2DInferencer(
+ model='td-hm_res50_8xb32-210e_onehand10k-256x256')
+ self.assertIsInstance(inferencer.model, torch.nn.Module)
+ self.assertIsInstance(inferencer.detector, BaseInferencer)
+ self.assertSequenceEqual(inferencer.det_cat_ids, (0, ))
+
+ # 3. init with alias
+ with self.assertWarnsRegex(
+ Warning, 'dataset_meta are not saved in '
+ 'the checkpoint\'s meta data, load via config.'):
+ inferencer = Pose2DInferencer(model='animal')
+ self.assertIsInstance(inferencer.model, torch.nn.Module)
+ self.assertIsInstance(inferencer.detector, BaseInferencer)
+ self.assertSequenceEqual(inferencer.det_cat_ids,
+ (15, 16, 17, 18, 19, 20, 21, 22, 23))
+
+ # 4. init with bottom-up model
+ inferencer = Pose2DInferencer(
+ model='configs/body_2d_keypoint/dekr/coco/'
+ 'dekr_hrnet-w32_8xb10-140e_coco-512x512.py',
+ weights='https://download.openmmlab.com/mmpose/v1/'
+ 'body_2d_keypoint/dekr/coco/'
+ 'dekr_hrnet-w32_8xb10-140e_coco-512x512_ac7c17bf-20221228.pth',
+ )
+ self.assertIsInstance(inferencer.model, torch.nn.Module)
+ self.assertFalse(hasattr(inferencer, 'detector'))
+
+ def test_call(self):
+
+ # top-down model
+ inferencer = Pose2DInferencer('human')
+
+ img_path = 'tests/data/coco/000000197388.jpg'
+ img = mmcv.imread(img_path)
+
+ # `inputs` is path to an image
+ inputs = img_path
+ results1 = next(inferencer(inputs, return_vis=True))
+ self.assertIn('visualization', results1)
+ self.assertSequenceEqual(results1['visualization'][0].shape, img.shape)
+ self.assertIn('predictions', results1)
+ self.assertIn('keypoints', results1['predictions'][0][0])
+ self.assertEqual(len(results1['predictions'][0][0]['keypoints']), 17)
+
+ # `inputs` is an image array
+ inputs = img
+ results2 = next(inferencer(inputs))
+ self.assertEqual(
+ len(results1['predictions'][0]), len(results2['predictions'][0]))
+ self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
+ results2['predictions'][0][0]['keypoints'])
+ results2 = next(inferencer(inputs, return_datasample=True))
+ self.assertIsInstance(results2['predictions'][0], PoseDataSample)
+
+ # `inputs` is path to a directory
+ inputs = osp.dirname(img_path)
+
+ with TemporaryDirectory() as tmp_dir:
+ # only save visualizations
+ for res in inferencer(inputs, vis_out_dir=tmp_dir):
+ pass
+ self.assertEqual(len(os.listdir(tmp_dir)), 4)
+ # save both visualizations and predictions
+ results3 = defaultdict(list)
+ for res in inferencer(inputs, out_dir=tmp_dir):
+ for key in res:
+ results3[key].extend(res[key])
+ self.assertEqual(len(os.listdir(f'{tmp_dir}/visualizations')), 4)
+ self.assertEqual(len(os.listdir(f'{tmp_dir}/predictions')), 4)
+ self.assertEqual(len(results3['predictions']), 4)
+ self.assertSequenceEqual(results1['predictions'][0][0]['keypoints'],
+ results3['predictions'][3][0]['keypoints'])
+
+ # `inputs` is path to a video
+ inputs = 'tests/data/posetrack18/videos/000001_mpiinew_test/' \
+ '000001_mpiinew_test.mp4'
+ with TemporaryDirectory() as tmp_dir:
+ results = defaultdict(list)
+ for res in inferencer(inputs, out_dir=tmp_dir):
+ for key in res:
+ results[key].extend(res[key])
+ self.assertIn('000001_mpiinew_test.mp4',
+ os.listdir(f'{tmp_dir}/visualizations'))
+ self.assertIn('000001_mpiinew_test.json',
+ os.listdir(f'{tmp_dir}/predictions'))
+ self.assertTrue(inferencer._video_input)
+ self.assertIn(len(results['predictions']), (4, 5))
diff --git a/tests/test_engine/test_hooks/test_visualization_hook.py b/tests/test_engine/test_hooks/test_visualization_hook.py
index 8eab9670de..3e4a202198 100644
--- a/tests/test_engine/test_hooks/test_visualization_hook.py
+++ b/tests/test_engine/test_hooks/test_visualization_hook.py
@@ -27,7 +27,7 @@ def _rand_poses(num_boxes, h, w):
class TestVisualizationHook(TestCase):
def setUp(self) -> None:
- PoseLocalVisualizer.get_instance('visualizer')
+ PoseLocalVisualizer.get_instance('test_visualization_hook')
data_sample = PoseDataSample()
data_sample.set_metainfo({