Skip to content

Commit 57b74fe

Browse files
author
Peter
committed
finish translating to C++ and test for correctness with Val model
1 parent 1318261 commit 57b74fe

File tree

5 files changed

+362
-84
lines changed

5 files changed

+362
-84
lines changed

.clang-format

+168
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,168 @@
1+
---
2+
Language: Cpp
3+
# BasedOnStyle: Google
4+
AccessModifierOffset: -1
5+
AlignAfterOpenBracket: Align
6+
AlignConsecutiveMacros: false
7+
AlignConsecutiveAssignments: false
8+
AlignConsecutiveDeclarations: false
9+
AlignEscapedNewlines: Left
10+
AlignOperands: true
11+
AlignTrailingComments: true
12+
AllowAllArgumentsOnNextLine: true
13+
AllowAllConstructorInitializersOnNextLine: true
14+
AllowAllParametersOfDeclarationOnNextLine: true
15+
AllowShortBlocksOnASingleLine: Never
16+
AllowShortCaseLabelsOnASingleLine: false
17+
AllowShortFunctionsOnASingleLine: All
18+
AllowShortLambdasOnASingleLine: All
19+
AllowShortIfStatementsOnASingleLine: WithoutElse
20+
AllowShortLoopsOnASingleLine: true
21+
AlwaysBreakAfterDefinitionReturnType: None
22+
AlwaysBreakAfterReturnType: None
23+
AlwaysBreakBeforeMultilineStrings: true
24+
AlwaysBreakTemplateDeclarations: Yes
25+
BinPackArguments: true
26+
BinPackParameters: true
27+
BraceWrapping:
28+
AfterCaseLabel: false
29+
AfterClass: false
30+
AfterControlStatement: false
31+
AfterEnum: false
32+
AfterFunction: false
33+
AfterNamespace: false
34+
AfterObjCDeclaration: false
35+
AfterStruct: false
36+
AfterUnion: false
37+
AfterExternBlock: false
38+
BeforeCatch: false
39+
BeforeElse: false
40+
IndentBraces: false
41+
SplitEmptyFunction: true
42+
SplitEmptyRecord: true
43+
SplitEmptyNamespace: true
44+
BreakBeforeBinaryOperators: None
45+
BreakBeforeBraces: Attach
46+
BreakBeforeInheritanceComma: false
47+
BreakInheritanceList: BeforeColon
48+
BreakBeforeTernaryOperators: true
49+
BreakConstructorInitializersBeforeComma: false
50+
BreakConstructorInitializers: BeforeColon
51+
BreakAfterJavaFieldAnnotations: false
52+
BreakStringLiterals: true
53+
ColumnLimit: 120
54+
CommentPragmas: '^ IWYU pragma:'
55+
CompactNamespaces: false
56+
ConstructorInitializerAllOnOneLineOrOnePerLine: true
57+
ConstructorInitializerIndentWidth: 4
58+
ContinuationIndentWidth: 4
59+
Cpp11BracedListStyle: true
60+
DeriveLineEnding: true
61+
DerivePointerAlignment: true
62+
DisableFormat: false
63+
ExperimentalAutoDetectBinPacking: false
64+
FixNamespaceComments: true
65+
ForEachMacros:
66+
- foreach
67+
- Q_FOREACH
68+
- BOOST_FOREACH
69+
IncludeBlocks: Regroup
70+
IncludeCategories:
71+
- Regex: '^<ext/.*\.h>'
72+
Priority: 2
73+
SortPriority: 0
74+
- Regex: '^<.*\.h>'
75+
Priority: 1
76+
SortPriority: 0
77+
- Regex: '^<.*'
78+
Priority: 2
79+
SortPriority: 0
80+
- Regex: '.*'
81+
Priority: 3
82+
SortPriority: 0
83+
IncludeIsMainRegex: '([-_](test|unittest))?$'
84+
IncludeIsMainSourceRegex: ''
85+
IndentCaseLabels: true
86+
IndentGotoLabels: true
87+
IndentPPDirectives: None
88+
IndentWidth: 2
89+
IndentWrappedFunctionNames: false
90+
JavaScriptQuotes: Leave
91+
JavaScriptWrapImports: true
92+
KeepEmptyLinesAtTheStartOfBlocks: false
93+
MacroBlockBegin: ''
94+
MacroBlockEnd: ''
95+
MaxEmptyLinesToKeep: 1
96+
NamespaceIndentation: None
97+
ObjCBinPackProtocolList: Never
98+
ObjCBlockIndentWidth: 2
99+
ObjCSpaceAfterProperty: false
100+
ObjCSpaceBeforeProtocolList: true
101+
PenaltyBreakAssignment: 2
102+
PenaltyBreakBeforeFirstCallParameter: 1
103+
PenaltyBreakComment: 300
104+
PenaltyBreakFirstLessLess: 120
105+
PenaltyBreakString: 1000
106+
PenaltyBreakTemplateDeclaration: 10
107+
PenaltyExcessCharacter: 1000000
108+
PenaltyReturnTypeOnItsOwnLine: 200
109+
PointerAlignment: Left
110+
RawStringFormats:
111+
- Language: Cpp
112+
Delimiters:
113+
- cc
114+
- CC
115+
- cpp
116+
- Cpp
117+
- CPP
118+
- 'c++'
119+
- 'C++'
120+
CanonicalDelimiter: ''
121+
BasedOnStyle: google
122+
- Language: TextProto
123+
Delimiters:
124+
- pb
125+
- PB
126+
- proto
127+
- PROTO
128+
EnclosingFunctions:
129+
- EqualsProto
130+
- EquivToProto
131+
- PARSE_PARTIAL_TEXT_PROTO
132+
- PARSE_TEST_PROTO
133+
- PARSE_TEXT_PROTO
134+
- ParseTextOrDie
135+
- ParseTextProtoOrDie
136+
CanonicalDelimiter: ''
137+
BasedOnStyle: google
138+
ReflowComments: true
139+
SortIncludes: true
140+
SortUsingDeclarations: true
141+
SpaceAfterCStyleCast: false
142+
SpaceAfterLogicalNot: false
143+
SpaceAfterTemplateKeyword: true
144+
SpaceBeforeAssignmentOperators: true
145+
SpaceBeforeCpp11BracedList: false
146+
SpaceBeforeCtorInitializerColon: true
147+
SpaceBeforeInheritanceColon: true
148+
SpaceBeforeParens: ControlStatements
149+
SpaceBeforeRangeBasedForLoopColon: true
150+
SpaceInEmptyBlock: false
151+
SpaceInEmptyParentheses: false
152+
SpacesBeforeTrailingComments: 2
153+
SpacesInAngles: false
154+
SpacesInConditionalStatement: false
155+
SpacesInContainerLiterals: true
156+
SpacesInCStyleCastParentheses: false
157+
SpacesInParentheses: false
158+
SpacesInSquareBrackets: false
159+
SpaceBeforeSquareBrackets: false
160+
Standard: Auto
161+
StatementMacros:
162+
- Q_UNUSED
163+
- QT_REQUIRE_VERSION
164+
TabWidth: 8
165+
UseCRLF: false
166+
UseTab: Never
167+
...
168+

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33

44
# This is needed in order to build the C++ extension
55
import torch
6-
print(f'>>>>> {torch.__version__} <<<<<')
76
ext = cpp_extension.CppExtension('zpk_cpp', ['src/zpk_cpp/pk.cpp'], extra_compile_args=['-std=c++17'])
87
setup_args = dict(
98
packages=find_packages(where="src"),

src/pytorch_kinematics/chain.py

+57-30
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
from functools import lru_cache
2+
from pytorch_kinematics.transforms.rotation_conversions import tensor_axis_and_angle_to_matrix
3+
from pytorch_kinematics.transforms.rotation_conversions import tensor_axis_and_d_to_pris_matrix
24
from typing import Optional, Sequence
35

46
import numpy as np
57
import torch
68

79
import pytorch_kinematics.transforms as tf
8-
import zpk_cpp
910
from pytorch_kinematics import jacobian
1011
from pytorch_kinematics.frame import Frame, Link, Joint
1112

@@ -68,7 +69,7 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
6869

6970
# As we traverse the kinematic tree, each frame is assigned an index.
7071
# We use this index to build a flat representation of the tree.
71-
# parent_indices and joint_indices all use this indexing scheme.
72+
# parents_indices and joint_indices all use this indexing scheme.
7273
# The root frame will be index 0 and the first frame of the root frame's children will be index 1,
7374
# then the child of that frame will be index 2, etc. In other words, it's a depth-first ordering.
7475
self.parents_indices = [] # list of indices from 0 (root) to the given frame
@@ -89,9 +90,9 @@ def __init__(self, root_frame, dtype=torch.float32, device="cpu"):
8990
self.frame_to_idx[name_strip] = idx
9091
self.idx_to_frame[idx] = name_strip
9192
if parent_idx == -1:
92-
self.parents_indices.append([])
93+
self.parents_indices.append([idx])
9394
else:
94-
self.parents_indices.append(self.parents_indices[parent_idx] + [parent_idx])
95+
self.parents_indices.append(self.parents_indices[parent_idx] + [idx])
9596

9697
is_fixed = root.joint.joint_type == 'fixed'
9798

@@ -134,6 +135,7 @@ def to(self, dtype=None, device=None):
134135
self.identity = self.identity.to(device=self.device, dtype=self.dtype)
135136
self.parents_indices = [p.to(dtype=torch.long, device=self.device) for p in self.parents_indices]
136137
self.joint_type_indices = self.joint_type_indices.to(dtype=torch.long, device=self.device)
138+
self.joint_indices = self.joint_indices.to(dtype=torch.long, device=self.device)
137139
self.axes = self.axes.to(dtype=self.dtype, device=self.device)
138140
self.link_offsets = [l if l is None else l.to(dtype=self.dtype, device=self.device) for l in self.link_offsets]
139141
self.joint_offsets = [j if j is None else j.to(dtype=self.dtype, device=self.device) for j in
@@ -298,25 +300,36 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
298300
th = self.ensure_tensor(th)
299301
th = torch.atleast_2d(th)
300302

301-
b, n = th.shape
303+
import zpk_cpp
304+
frame_transforms = zpk_cpp.fk(
305+
frame_indices,
306+
self.axes,
307+
th,
308+
self.parents_indices,
309+
self.joint_type_indices,
310+
self.joint_indices,
311+
self.joint_offsets,
312+
self.link_offsets
313+
)
314+
315+
frame_names_and_transform3ds = {self.idx_to_frame[frame_idx]: tf.Transform3d(matrix=transform) for
316+
frame_idx, transform in frame_transforms.items()}
317+
318+
return frame_names_and_transform3ds
319+
320+
def forward_kinematics_py(self, th, frame_indices: Optional = None):
321+
if frame_indices is None:
322+
frame_indices = self.get_all_frame_indices()
323+
324+
th = self.ensure_tensor(th)
325+
th = torch.atleast_2d(th)
326+
327+
b = th.shape[0]
302328

303329
axes_expanded = self.axes.unsqueeze(0).repeat(b, 1, 1)
304330

305-
# TODO: reimplement in CPP
306-
# frame_transforms = zpk_cpp.fk(
307-
# frame_indices,
308-
# axes_expanded,
309-
# th,
310-
# self.parent_indices,
311-
# self.joint_indices,
312-
# self.joint_offsets,
313-
# self.link_offsets
314-
# )
315-
316-
from pytorch_kinematics.transforms.rotation_conversions import tensor_axis_and_angle_to_matrix
317-
from pytorch_kinematics.transforms.rotation_conversions import tensor_axis_and_d_to_pris_matrix
318331
frame_transforms = {}
319-
b = th.size(0)
332+
320333
# compute all joint transforms at once first
321334
# in order to handle multiple joint types without branching, we create all possible transforms
322335
# for all joint types and then select the appropriate one for each joint.
@@ -327,9 +340,8 @@ def forward_kinematics(self, th, frame_indices: Optional = None):
327340
frame_transform = torch.eye(4).to(th).unsqueeze(0).repeat(b, 1, 1)
328341

329342
# iterate down the list and compose the transform
330-
chain_indices = torch.cat((self.parents_indices[frame_idx], frame_idx[None]))
331-
for chain_idx in chain_indices:
332-
if chain_idx.item() in frame_transforms and False: # DEBUGGING
343+
for chain_idx in self.parents_indices[frame_idx]:
344+
if chain_idx.item() in frame_transforms:
333345
frame_transform = frame_transforms[chain_idx.item()]
334346
else:
335347
link_offset_i = self.link_offsets[chain_idx]
@@ -481,13 +493,34 @@ def jacobian(self, th, locations=None):
481493

482494
def forward_kinematics(self, th, end_only: bool = True):
483495
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
496+
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(end_only, th)
497+
498+
mat = super().forward_kinematics(th, frame_indices)
499+
500+
if end_only:
501+
return mat[self._serial_frames[-1].name]
502+
else:
503+
return mat
504+
505+
def forward_kinematics_py(self, th, end_only: bool = True):
506+
""" Like the base class, except `th` only needs to contain the joints in the SerialChain, not all joints. """
507+
frame_indices, th = self.convert_serial_inputs_to_chain_inputs(end_only, th)
508+
509+
mat = super().forward_kinematics_py(th, frame_indices)
510+
511+
if end_only:
512+
return mat[self._serial_frames[-1].name]
513+
else:
514+
return mat
515+
516+
517+
def convert_serial_inputs_to_chain_inputs(self, end_only, th):
484518
if end_only:
485519
frame_indices = self.get_frame_indices(self._serial_frames[-1].name)
486520
else:
487521
# pass through default behavior for frame indices being None, which is currently
488522
# to return all frames.
489523
frame_indices = None
490-
491524
if get_th_size(th) < self.n_joints:
492525
# if th is only a partial list of joints, assume it's a list of joints for only the serial chain.
493526
partial_th = th
@@ -500,13 +533,7 @@ def forward_kinematics(self, th, end_only: bool = True):
500533
jnt_idx = self.joint_indices[k]
501534
if frame.joint.joint_type != 'fixed':
502535
th[jnt_idx] = partial_th_i
503-
504-
mat = super().forward_kinematics(th, frame_indices)
505-
506-
if end_only:
507-
return mat[self._serial_frames[-1].name]
508-
else:
509-
return mat
536+
return frame_indices, th
510537

511538
def forward_kinematics_slow(self, th, world=None, end_only=True):
512539
if world is None:

0 commit comments

Comments
 (0)