Skip to content

Commit d9386b6

Browse files
committed
add internal utilities for debugging XLA objects from Idris
1 parent 016c4cd commit d9386b6

File tree

16 files changed

+150
-7
lines changed

16 files changed

+150
-7
lines changed

backend/src/ffi.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,13 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
16+
#include <iostream>
17+
1618
extern "C" {
19+
void print_address(void* ptr) {
20+
std::cout << ptr << std::endl;
21+
}
22+
1723
int sizeof_int() {
1824
return sizeof(int);
1925
}

backend/src/ffi.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
1616
extern "C" {
17+
void print_address(void* ptr); // doesn't belong here?
18+
1719
int sizeof_int();
1820

1921
void set_array_int(int* arr, int idx, int value);

backend/src/tensorflow/compiler/xla/client/xla_builder.cpp

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ extern "C" {
7676
return reinterpret_cast<XlaComputation*>(non_stack);
7777
}
7878

79+
Shape* XlaBuilder_GetShape(XlaBuilder& s, XlaOp& op) {
80+
auto& s_ = reinterpret_cast<xla::XlaBuilder&>(s);
81+
auto& op_ = reinterpret_cast<xla::XlaOp&>(op);
82+
xla::Shape shape = s_.GetShape(op_).ConsumeValueOrDie();
83+
return reinterpret_cast<Shape*>(new xla::Shape(shape));
84+
}
85+
7986
const char* XlaBuilder_OpToString(XlaBuilder& s, XlaOp& op) {
8087
auto& s_ = reinterpret_cast<xla::XlaBuilder&>(s);
8188
auto& op_ = reinterpret_cast<xla::XlaOp&>(op);

backend/src/tensorflow/compiler/xla/client/xla_builder.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@ extern "C" {
4848
const char* XlaBuilder_name(XlaBuilder& s);
4949
XlaBuilder* CreateSubBuilder(XlaBuilder& s, const char* computation_name);
5050
XlaComputation* XlaBuilder_Build(XlaBuilder& s, XlaOp& root);
51+
Shape* XlaBuilder_GetShape(XlaBuilder& s, XlaOp& op);
5152
const char* XlaBuilder_OpToString(XlaBuilder& s, XlaOp& op);
5253

5354
/*

backend/src/tensorflow/compiler/xla/shape.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,30 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
*/
16+
#include <algorithm>
17+
#include <cstring>
18+
#include <string>
19+
1620
#include "tensorflow/compiler/xla/shape.h"
1721

1822
#include "shape.h"
1923

2024
extern "C" {
25+
const char* c_string_copy(std::string str) {
26+
char *res = NULL;
27+
auto len = str.length();
28+
res = (char *) malloc(len + 1);
29+
strncpy(res, str.c_str(), len);
30+
res[len] = '\0';
31+
return res;
32+
}
33+
2134
void Shape_delete(Shape* s) {
2235
delete reinterpret_cast<xla::Shape*>(s);
2336
}
37+
38+
const char* Shape_DebugString(Shape& s) {
39+
auto& s_ = reinterpret_cast<xla::Shape&>(s);
40+
return c_string_copy(s_.DebugString());
41+
}
2442
}

spidr.ipkg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ modules =
99
BayesianOptimization.Acquisition,
1010
BayesianOptimization.Morphisms,
1111

12+
Compiler.Debug,
1213
Compiler.Eval,
1314
Compiler.Expr,
1415
Compiler.LiteralRW,

src/Compiler/Debug.idr

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
{--
2+
Copyright 2022 Joel Berkeley
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
--}
16+
module Compiler.Debug
17+
18+
import Compiler.Xla.TensorFlow.Compiler.Xla.Shape
19+
import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder
20+
21+
export
22+
shapeString : XlaBuilder -> XlaOp -> IO String
23+
shapeString builder op = map debugString (getShape builder op)

src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ export
3636
%foreign (libxla "XlaBuilder_Build")
3737
prim__build : GCAnyPtr -> GCAnyPtr -> AnyPtr
3838

39+
export
40+
%foreign (libxla "XlaBuilder_GetShape")
41+
prim__getShape : GCAnyPtr -> GCAnyPtr -> PrimIO AnyPtr
42+
3943
export
4044
%foreign (libxla "XlaBuilder_OpToString")
4145
prim__opToString : GCAnyPtr -> GCAnyPtr -> String

src/Compiler/Xla/Prim/TensorFlow/Compiler/Xla/Shape.idr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,7 @@ import Compiler.Xla.Prim.Util
2222
export
2323
%foreign (libxla "Shape_delete")
2424
prim__delete : AnyPtr -> PrimIO ()
25+
26+
export
27+
%foreign (libxla "Shape_DebugString")
28+
prim__debugString : GCAnyPtr -> String

src/Compiler/Xla/Prim/Util.idr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,10 @@ public export
2121
libxla : String -> String
2222
libxla fname = "C:" ++ fname ++ ",libc_xla_extension"
2323

24+
export
25+
%foreign (libxla "print_address")
26+
prim__printAddress : GCAnyPtr -> PrimIO ()
27+
2428
export
2529
%foreign (libxla "sizeof_int")
2630
sizeofInt : Int

src/Compiler/Xla/TensorFlow/Compiler/Xla/Client/XlaBuilder.idr

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,14 @@ build (MkXlaBuilder ptr) (MkXlaOp root)= do
6565
computationPtr <- onCollectAny computationPtr XlaComputation.delete
6666
pure (MkXlaComputation computationPtr)
6767

68+
-- is this in IO?
69+
export
70+
getShape : HasIO io => XlaBuilder -> XlaOp -> io Xla.Shape
71+
getShape (MkXlaBuilder builder) (MkXlaOp op) = do
72+
shape <- primIO $ prim__getShape builder op
73+
shape <- onCollectAny shape Shape.delete
74+
pure (MkShape shape)
75+
6876
export
6977
opToString : XlaBuilder -> XlaOp -> String
7078
opToString (MkXlaBuilder builderPtr) (MkXlaOp opPtr) = prim__opToString builderPtr opPtr

src/Compiler/Xla/TensorFlow/Compiler/Xla/Shape.idr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,7 @@ namespace Xla
2525
export
2626
delete : AnyPtr -> IO ()
2727
delete = primIO . prim__delete
28+
29+
export
30+
debugString : Shape -> String
31+
debugString (MkShape shape) = prim__debugString shape

src/Compiler/Xla/Util.idr

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,10 @@ boolToCInt : Bool -> Int
3333
boolToCInt True = 1
3434
boolToCInt False = 0
3535

36+
export
37+
printAddress : GCAnyPtr -> IO ()
38+
printAddress = primIO . prim__printAddress
39+
3640
public export
3741
data IntArray : Type where
3842
MkIntArray : GCPtr Int -> IntArray

test.ipkg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ executable = test
1616
main = Main
1717

1818
modules =
19+
Unit.Compiler.TestDebug,
1920
Unit.Model.TestKernel,
2021
Unit.Util.TestHashable,
2122
Unit.TestDistribution,

test/Main.idr

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import Hedgehog
2121
import TestUtils
2222
import Utils.TestComparison
2323

24+
import Unit.Compiler.TestDebug
2425
import Unit.Model.TestKernel
2526
import Unit.Util.TestHashable
2627
import Unit.TestDistribution
@@ -32,11 +33,12 @@ covering
3233
main : IO ()
3334
main = test [
3435
Utils.TestComparison.group
35-
, TestUtils.group
36-
, Unit.Util.TestHashable.group
37-
, Unit.TestUtil.group
38-
, Unit.TestLiteral.group
39-
, Unit.TestTensor.group
40-
, Unit.TestDistribution.group
41-
, Unit.Model.TestKernel.group
36+
-- , TestUtils.group
37+
-- , Unit.Util.TestHashable.group
38+
-- , Unit.TestUtil.group
39+
-- , Unit.TestLiteral.group
40+
, Unit.Compiler.TestDebug.group
41+
-- , Unit.TestTensor.group
42+
-- , Unit.TestDistribution.group
43+
-- , Unit.Model.TestKernel.group
4244
]

test/Unit/Compiler/TestDebug.idr

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
{--
2+
Copyright 2022 Joel Berkeley
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
--}
16+
module Unit.Compiler.TestDebug
17+
18+
import Compiler.Computation
19+
import Compiler.Debug
20+
import Compiler.LiteralRW
21+
import Compiler.Xla.TensorFlow.Compiler.Xla.Literal
22+
import Compiler.Xla.TensorFlow.Compiler.Xla.XlaData
23+
import Compiler.Xla.TensorFlow.Compiler.Xla.Client.XlaBuilder
24+
import Literal
25+
26+
import Utils.Comparison
27+
import Utils.Cases
28+
29+
export
30+
shapeString : Property
31+
shapeString = fixedProperty $ do
32+
let str : String = unsafePerformIO $ do
33+
lit <- write {dtype=S32} [[0, 1, 2], [3, 4, 5]]
34+
builder <- mkXlaBuilder ""
35+
op <- constantLiteral builder lit
36+
shapeString builder op
37+
38+
str ===
39+
"element_type: S32\n" ++
40+
"dimensions: 2\n" ++
41+
"dimensions: 3\n" ++
42+
"layout {\n" ++
43+
" minor_to_major: 1\n" ++
44+
" minor_to_major: 0\n" ++
45+
" format: DENSE\n" ++
46+
"}\n" ++
47+
"is_dynamic_dimension: false\n" ++
48+
"is_dynamic_dimension: false\n"
49+
50+
export covering
51+
group : Group
52+
group = MkGroup "Debug" $ [
53+
("shape string", shapeString)
54+
]

0 commit comments

Comments
 (0)