1
1
from functools import lru_cache
2
+ from pytorch_kinematics .transforms .rotation_conversions import tensor_axis_and_angle_to_matrix
3
+ from pytorch_kinematics .transforms .rotation_conversions import tensor_axis_and_d_to_pris_matrix
2
4
from typing import Optional , Sequence
3
5
4
6
import numpy as np
5
7
import torch
6
8
7
9
import pytorch_kinematics .transforms as tf
8
- import zpk_cpp
9
10
from pytorch_kinematics import jacobian
10
11
from pytorch_kinematics .frame import Frame , Link , Joint
11
12
@@ -68,7 +69,7 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
68
69
69
70
# As we traverse the kinematic tree, each frame is assigned an index.
70
71
# We use this index to build a flat representation of the tree.
71
- # parent_indices and joint_indices all use this indexing scheme.
72
+ # parents_indices and joint_indices all use this indexing scheme.
72
73
# The root frame will be index 0 and the first frame of the root frame's children will be index 1,
73
74
# then the child of that frame will be index 2, etc. In other words, it's a depth-first ordering.
74
75
self .parents_indices = [] # list of indices from 0 (root) to the given frame
@@ -89,9 +90,9 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
89
90
self .frame_to_idx [name_strip ] = idx
90
91
self .idx_to_frame [idx ] = name_strip
91
92
if parent_idx == - 1 :
92
- self .parents_indices .append ([])
93
+ self .parents_indices .append ([idx ])
93
94
else :
94
- self .parents_indices .append (self .parents_indices [parent_idx ] + [parent_idx ])
95
+ self .parents_indices .append (self .parents_indices [parent_idx ] + [idx ])
95
96
96
97
is_fixed = root .joint .joint_type == 'fixed'
97
98
@@ -134,6 +135,7 @@ def to(self, dtype=None, device=None):
134
135
self .identity = self .identity .to (device = self .device , dtype = self .dtype )
135
136
self .parents_indices = [p .to (dtype = torch .long , device = self .device ) for p in self .parents_indices ]
136
137
self .joint_type_indices = self .joint_type_indices .to (dtype = torch .long , device = self .device )
138
+ self .joint_indices = self .joint_indices .to (dtype = torch .long , device = self .device )
137
139
self .axes = self .axes .to (dtype = self .dtype , device = self .device )
138
140
self .link_offsets = [l if l is None else l .to (dtype = self .dtype , device = self .device ) for l in self .link_offsets ]
139
141
self .joint_offsets = [j if j is None else j .to (dtype = self .dtype , device = self .device ) for j in
@@ -298,25 +300,36 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
298
300
th = self .ensure_tensor (th )
299
301
th = torch .atleast_2d (th )
300
302
301
- b , n = th .shape
303
+ import zpk_cpp
304
+ frame_transforms = zpk_cpp .fk (
305
+ frame_indices ,
306
+ self .axes ,
307
+ th ,
308
+ self .parents_indices ,
309
+ self .joint_type_indices ,
310
+ self .joint_indices ,
311
+ self .joint_offsets ,
312
+ self .link_offsets
313
+ )
314
+
315
+ frame_names_and_transform3ds = {self .idx_to_frame [frame_idx ]: tf .Transform3d (matrix = transform ) for
316
+ frame_idx , transform in frame_transforms .items ()}
317
+
318
+ return frame_names_and_transform3ds
319
+
320
+ def forward_kinematics_py (self , th , frame_indices : Optional = None ):
321
+ if frame_indices is None :
322
+ frame_indices = self .get_all_frame_indices ()
323
+
324
+ th = self .ensure_tensor (th )
325
+ th = torch .atleast_2d (th )
326
+
327
+ b = th .shape [0 ]
302
328
303
329
axes_expanded = self .axes .unsqueeze (0 ).repeat (b , 1 , 1 )
304
330
305
- # TODO: reimplement in CPP
306
- # frame_transforms = zpk_cpp.fk(
307
- # frame_indices,
308
- # axes_expanded,
309
- # th,
310
- # self.parent_indices,
311
- # self.joint_indices,
312
- # self.joint_offsets,
313
- # self.link_offsets
314
- # )
315
-
316
- from pytorch_kinematics .transforms .rotation_conversions import tensor_axis_and_angle_to_matrix
317
- from pytorch_kinematics .transforms .rotation_conversions import tensor_axis_and_d_to_pris_matrix
318
331
frame_transforms = {}
319
- b = th . size ( 0 )
332
+
320
333
# compute all joint transforms at once first
321
334
# in order to handle multiple joint types without branching, we create all possible transforms
322
335
# for all joint types and then select the appropriate one for each joint.
@@ -327,9 +340,8 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
327
340
frame_transform = torch .eye (4 ).to (th ).unsqueeze (0 ).repeat (b , 1 , 1 )
328
341
329
342
# iterate down the list and compose the transform
330
- chain_indices = torch .cat ((self .parents_indices [frame_idx ], frame_idx [None ]))
331
- for chain_idx in chain_indices :
332
- if chain_idx .item () in frame_transforms and False : # DEBUGGING
343
+ for chain_idx in self .parents_indices [frame_idx ]:
344
+ if chain_idx .item () in frame_transforms :
333
345
frame_transform = frame_transforms [chain_idx .item ()]
334
346
else :
335
347
link_offset_i = self .link_offsets [chain_idx ]
@@ -481,13 +493,34 @@ def jacobian(self, th, locations=None):
481
493
482
494
def forward_kinematics (self , th , end_only : bool = True ):
483
495
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
496
+ frame_indices , th = self .convert_serial_inputs_to_chain_inputs (end_only , th )
497
+
498
+ mat = super ().forward_kinematics (th , frame_indices )
499
+
500
+ if end_only :
501
+ return mat [self ._serial_frames [- 1 ].name ]
502
+ else :
503
+ return mat
504
+
505
+ def forward_kinematics_py (self , th , end_only : bool = True ):
506
+ """ Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
507
+ frame_indices , th = self .convert_serial_inputs_to_chain_inputs (end_only , th )
508
+
509
+ mat = super ().forward_kinematics_py (th , frame_indices )
510
+
511
+ if end_only :
512
+ return mat [self ._serial_frames [- 1 ].name ]
513
+ else :
514
+ return mat
515
+
516
+
517
+ def convert_serial_inputs_to_chain_inputs (self , end_only , th ):
484
518
if end_only :
485
519
frame_indices = self .get_frame_indices (self ._serial_frames [- 1 ].name )
486
520
else :
487
521
# pass through default behavior for frame indices being None, which is currently
488
522
# to return all frames.
489
523
frame_indices = None
490
-
491
524
if get_th_size (th ) < self .n_joints :
492
525
# if th is only a partial list of joints, assume it's a list of joints for only the serial chain.
493
526
partial_th = th
@@ -500,13 +533,7 @@ def forward_kinematics(self, th, end_only: bool = True):
500
533
jnt_idx = self .joint_indices [k ]
501
534
if frame .joint .joint_type != 'fixed' :
502
535
th [jnt_idx ] = partial_th_i
503
-
504
- mat = super ().forward_kinematics (th , frame_indices )
505
-
506
- if end_only :
507
- return mat [self ._serial_frames [- 1 ].name ]
508
- else :
509
- return mat
536
+ return frame_indices , th
510
537
511
538
def forward_kinematics_slow (self , th , world = None , end_only = True ):
512
539
if world is None :
0 commit comments