Commit eb25b19

Author: Arash Hosseini
Commit message: write pose json
1 parent: b119759

3 files changed: +30 -7 lines


run.py (+2 -2)

@@ -27,7 +27,7 @@
                         help='if provided, resize images before they are processed. default=0x0, Recommends : 432x368 or 656x368 or 1312x736 ')
     parser.add_argument('--resize-out-ratio', type=float, default=4.0,
                         help='if provided, resize heatmaps before they are post-processed. default=1.0')
-
+    parser.add_argument('--output_json', type=str, default='/tmp/', help='writing output json dir')
     args = parser.parse_args()

     w, h = model_wh(args.resize)
@@ -47,7 +47,7 @@

     logger.info('inference image: %s in %.4f seconds.' % (args.image, elapsed))

-    image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
+    image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False, frame=0, output_json_dir=args.output_json)

     import matplotlib.pyplot as plt
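Because run.py processes a single image, frame is fixed at 0 and the script writes exactly one keypoint file. A minimal read-back sketch, assuming the default --output_json of /tmp/ (the variable names below are illustrative, not part of the commit):

# Inspect the file written by run.py; frame is fixed at 0, so the name
# follows str(0).zfill(12) + '_keypoints.json'.
import json
import os

output_json = '/tmp/'  # default value of --output_json
path = os.path.join(output_json, '000000000000_keypoints.json')
with open(path) as f:
    data = json.load(f)

print(len(data['people']), 'people detected')
if data['people']:
    # 36 values per person: [x0, y0, x1, y1, ...] for the 18 COCO parts
    print(data['people'][0]['pose_keypoints_2d'][:4])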

run_webcam.py (+9 -1)

@@ -29,12 +29,16 @@
                         help='if provided, resize heatmaps before they are post-processed. default=1.0')

     parser.add_argument('--model', type=str, default='mobilenet_thin', help='cmu / mobilenet_thin')
+
+    parser.add_argument('--output_json', type=str, default='/tmp/', help='writing output json dir')
+
     parser.add_argument('--show-process', type=bool, default=False,
                         help='for debug purpose, if enabled, speed for inference is dropped.')
     args = parser.parse_args()

     logger.debug('initialization %s : %s' % (args.model, get_graph_path(args.model)))
     w, h = model_wh(args.resize)
+
     if w > 0 and h > 0:
         e = TfPoseEstimator(get_graph_path(args.model), target_size=(w, h))
     else:
@@ -44,14 +48,18 @@
     ret_val, image = cam.read()
     logger.info('cam image=%dx%d' % (image.shape[1], image.shape[0]))

+    frame = 0
     while True:
         ret_val, image = cam.read()

         logger.debug('image process+')
         humans = e.inference(image, resize_to_default=(w > 0 and h > 0), upsample_size=args.resize_out_ratio)

         logger.debug('postprocess+')
-        image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False)
+
+
+        image = TfPoseEstimator.draw_humans(image, humans, imgcopy=False, frame=frame, output_json_dir=args.output_json)
+        frame += 1

         logger.debug('show+')
         cv2.putText(image,
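In the webcam loop the new frame counter increments after every draw_humans call, so the --output_json directory accumulates one JSON file per processed frame (000000000000_keypoints.json, 000000000001_keypoints.json, and so on). A rough post-processing sketch under that assumption; the glob pattern and names are illustrative:

# Walk the per-frame files produced by run_webcam.py in frame order.
import glob
import json
import os

output_json = '/tmp/'  # same directory passed via --output_json
for path in sorted(glob.glob(os.path.join(output_json, '*_keypoints.json'))):
    with open(path) as f:
        people = json.load(f)['people']
    print(os.path.basename(path), 'people:', len(people))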

tf_pose/estimator.py (+19 -4)

@@ -7,7 +7,8 @@
 import numpy as np
 import tensorflow as tf
 import time
-
+import json
+import os
 from tf_pose import common
 from tf_pose.common import CocoPart
 from tf_pose.tensblur.smoother import Smoother
@@ -303,6 +304,7 @@ class TfPoseEstimator:
     def __init__(self, graph_path, target_size=(320, 240), tf_config=None):
         self.target_size = target_size

+
         # load graph
         logger.info('loading graph from %s(default size=%dx%d)' % (graph_path, target_size[0], target_size[1]))
         with tf.gfile.GFile(graph_path, 'rb') as f:
@@ -378,12 +380,14 @@ def _quantize_img(npimg):
         return npimg_q

     @staticmethod
-    def draw_humans(npimg, humans, imgcopy=False):
+    def draw_humans(npimg, humans, imgcopy=False, frame=0, output_json_dir=None):
         if imgcopy:
             npimg = np.copy(npimg)
         image_h, image_w = npimg.shape[:2]
+        dc = {"people":[]}
         centers = {}
-        for human in humans:
+        for n, human in enumerate(humans):
+            flat = [0.0 for i in range(36)]
             # draw point
             for i in range(common.CocoPart.Background.value):
                 if i not in human.body_parts.keys():
@@ -392,7 +396,11 @@ def draw_humans(npimg, humans, imgcopy=False):
                 body_part = human.body_parts[i]
                 center = (int(body_part.x * image_w + 0.5), int(body_part.y * image_h + 0.5))
                 centers[i] = center
-                cv2.circle(npimg, center, 3, common.CocoColors[i], thickness=3, lineType=8, shift=0)
+                #add x
+                flat[i*2] = center[0]
+                #add y
+                flat[i*2+1] = center[1]
+                cv2.circle(npimg, center, 8, common.CocoColors[i], thickness=3, lineType=8, shift=0)

             # draw line
             for pair_order, pair in enumerate(common.CocoPairsRender):
@@ -402,6 +410,13 @@ def draw_humans(npimg, humans, imgcopy=False):
                 # npimg = cv2.line(npimg, centers[pair[0]], centers[pair[1]], common.CocoColors[pair_order], 3)
                 cv2.line(npimg, centers[pair[0]], centers[pair[1]], common.CocoColors[pair_order], 3)

+            dc["people"].append({"pose_keypoints_2d" : flat})
+
+        if output_json_dir:
+            with open(os.path.join(output_json_dir, '{0}_keypoints.json'.format(str(frame).zfill(12))), 'w') as outfile:
+                json.dump(dc, outfile)
+
+
         return npimg

     def _get_scaled_img(self, npimg, scale):
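The core of the change sits in draw_humans: for each person it fills a flat list of 36 values (18 COCO parts times x and y), leaving undetected joints at 0.0, appends the list under "people" as "pose_keypoints_2d", and, when output_json_dir is given, dumps the dictionary to a file named after the frame index zero-padded to 12 digits (e.g. 000000000000_keypoints.json). A standalone sketch of that flattening, written only to illustrate the logic; the helper name is hypothetical and not part of the commit:

# Mirror of the flattening done inside draw_humans, expressed as a helper.
def flatten_keypoints(body_parts, image_w, image_h, num_parts=18):
    flat = [0.0] * (num_parts * 2)          # [x0, y0, x1, y1, ..., x17, y17]
    for i in range(num_parts):              # CocoPart indices 0..17; Background excluded
        if i not in body_parts:
            continue                        # undetected joint stays at 0.0
        part = body_parts[i]
        flat[i * 2] = int(part.x * image_w + 0.5)      # pixel x, rounded
        flat[i * 2 + 1] = int(part.y * image_h + 0.5)  # pixel y, rounded
    return flat

Note that, unlike OpenPose's native JSON output (which stores x, y, confidence triples under the same pose_keypoints_2d key), this list carries only x and y, and a frame with no detections still produces a file containing {"people": []}.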
