Skip to content

Commit

Permalink
simple fix
Browse files Browse the repository at this point in the history
  • Loading branch information
SangbumChoi committed Oct 19, 2024
1 parent bb5cc96 commit 899cb96
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 58 deletions.
66 changes: 54 additions & 12 deletions docs/source/en/model_doc/vitpose.md
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,49 @@ boxes = [pascal_voc_to_coco(boxes.cpu().numpy())]

image_processor = VitPoseImageProcessor.from_pretrained("nielsr/vitpose-base-simple")
model = VitPoseForPoseEstimation.from_pretrained("nielsr/vitpose-base-simple")
config = VitPoseConfig()

keypoint_edges = [
[15, 13],
[13, 11],
[16, 14],
[14, 12],
[11, 12],
[5, 11],
[6, 12],
[5, 6],
[5, 7],
[6, 8],
[7, 9],
[8, 10],
[1, 2],
[0, 1],
[0, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
],
keypoint_nodes = [
"Nose",
"L_Eye",
"R_Eye",
"L_Ear",
"R_Ear",
"L_Shoulder",
"R_Shoulder",
"L_Elbow",
"R_Elbow",
"L_Wrist",
"R_Wrist",
"L_Hip",
"R_Hip",
"L_Knee",
"R_Knee",
"L_Ankle",
"R_Ankle",
],

config = VitPoseConfig(keypoint_edges=keypoint_edges, keypoint_nodes=keypoint_nodes)

# Stage 2. Run ViTPose
pixel_values = image_processor(image, boxes=boxes, return_tensors="pt").pixel_values
Expand Down Expand Up @@ -130,11 +172,11 @@ def draw_points(image, keypoints, keypoint_colors, keypoint_score_threshold, rad
else:
cv2.circle(image, (x_coord, y_coord), radius, color, -1)

def draw_links(image, keypoints, keypoint_connections, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
def draw_links(image, keypoints, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight, stick_width = 2):
height, width, _ = image.shape
if keypoint_connections is not None and link_colors is not None:
assert len(link_colors) == len(keypoint_connections)
for sk_id, sk in enumerate(keypoint_connections):
if keypoint_edges is not None and link_colors is not None:
assert len(link_colors) == len(keypoint_edges)
for sk_id, sk in enumerate(keypoint_edges):
x1, y1, score1 = (int(keypoints[sk[0], 0]), int(keypoints[sk[0], 1]), keypoints[sk[0], 2])
x2, y2, score2 = (int(keypoints[sk[1], 0]), int(keypoints[sk[1], 1]), keypoints[sk[1], 2])
if (
Expand Down Expand Up @@ -169,7 +211,7 @@ def draw_links(image, keypoints, keypoint_connections, link_colors, keypoint_sco
def visualize_keypoints(
image,
pose_result,
keypoint_connections=None,
keypoint_edges=None,
keypoint_score_threshold=0.3,
keypoint_colors=None,
link_colors=None,
Expand All @@ -185,8 +227,8 @@ def visualize_keypoints(
pose_result (`List[numpy.ndarray]`):
The poses to draw. Each element is a set of K keypoints as a Kx3 numpy.ndarray, where each keypoint
is represented as x, y, score.
keypoint_connections (`List[tuple]`, *optional*):
Mapping index of the keypoint_connections links.
keypoint_edges (`List[tuple]`, *optional*):
Mapping index of the keypoint_edges links.
keypoint_score_threshold (`float`, *optional*, defaults to 0.3):
Minimum score of keypoints to be shown.
keypoint_colors (`numpy.ndarray`, *optional*):
Expand All @@ -210,12 +252,12 @@ def visualize_keypoints(
draw_points(image, keypoints, keypoint_colors, keypoint_score_threshold, radius, show_keypoint_weight)

# draw links
draw_links(image, keypoints, keypoint_connections, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight)
draw_links(image, keypoints, keypoint_edges, link_colors, keypoint_score_threshold, thickness, show_keypoint_weight)

return image

# Note: keypoint_connections and color palette are dataset-specific
keypoint_connections = config.skeleton_edges
# Note: keypoint_edges and color palette are dataset-specific
keypoint_edges = config.keypoint_edges

palette = np.array(
[
Expand Down Expand Up @@ -250,7 +292,7 @@ pose_results = [result["keypoints"] for result in pose_results]
result = visualize_keypoints(
np.array(image),
pose_result,
keypoint_connections=keypoint_connections,
keypoint_edges=keypoint_edges,
keypoint_score_threshold=0.3,
keypoint_colors=keypoint_colors,
link_colors=link_colors,
Expand Down
46 changes: 0 additions & 46 deletions src/transformers/models/vitpose/configuration_vitpose.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,6 @@ class VitPoseConfig(PretrainedConfig):
Factor to upscale the feature maps coming from the ViT backbone.
use_simple_decoder (`bool`, *optional*, defaults to `True`):
Whether to use a `VitPoseSimpleDecoder` to decode the feature maps from the backbone into heatmaps. Otherwise it uses `VitPoseClassicDecoder`.
skeleton_edges (`list`, *optional*, defaults to `[[15, 13], [13, 11], [16, 14], [14, 12], [11, 12], [5, 11], [6, 12], [5, 6], [5, 7], [6, 8], [7, 9], [8, 10], [1, 2], [0, 1], [0, 2], [1, 3], [2, 4], [3, 5], [4, 6]]`):
List of edges connecting skeleton nodes, each edge represented by two node indices. This edges are based on MSCOCO.
skeleton_nodes (`list`, *optional*, defaults to `['Nose', 'L_Eye', 'R_Eye', 'L_Ear', 'R_Ear', 'L_Shoulder', 'R_Shoulder', 'L_Elbow', 'R_Elbow', 'L_Wrist', 'R_Wrist', 'L_Hip', 'R_Hip', 'L_Knee', 'R_Knee', 'L_Ankle', 'R_Ankle']`):
List of node names representing different body parts in the skeleton. This edges are based on MSCOCO.
Example:
Expand Down Expand Up @@ -87,46 +83,6 @@ def __init__(
initializer_range: float = 0.02,
scale_factor: int = 4,
use_simple_decoder: bool = True,
skeleton_edges: list = [
[15, 13],
[13, 11],
[16, 14],
[14, 12],
[11, 12],
[5, 11],
[6, 12],
[5, 6],
[5, 7],
[6, 8],
[7, 9],
[8, 10],
[1, 2],
[0, 1],
[0, 2],
[1, 3],
[2, 4],
[3, 5],
[4, 6],
],
skeleton_nodes: list = [
"Nose",
"L_Eye",
"R_Eye",
"L_Ear",
"R_Ear",
"L_Shoulder",
"R_Shoulder",
"L_Elbow",
"R_Elbow",
"L_Wrist",
"R_Wrist",
"L_Hip",
"R_Hip",
"L_Knee",
"R_Knee",
"L_Ankle",
"R_Ankle",
],
**kwargs,
):
super().__init__(**kwargs)
Expand Down Expand Up @@ -163,5 +119,3 @@ def __init__(
self.initializer_range = initializer_range
self.scale_factor = scale_factor
self.use_simple_decoder = use_simple_decoder
self.skeleton_edges = skeleton_edges
self.skeleton_nodes = skeleton_nodes

0 comments on commit 899cb96

Please sign in to comment.