Skip to content

Commit 7eaf478

Browse files
Revert "[LT] Store OpKind for each IR subclass in a static field"
This reverts commit ac37ddc. Reverted pytorch#76711 on behalf of https://github.com/malfet
1 parent da565c0 commit 7eaf478

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

56 files changed

+63
-212
lines changed

Diff for: BUILD.bazel

-2
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,6 @@ libtorch_cpp_generated_sources = [
123123
"torch/csrc/autograd/generated/Functions.cpp",
124124
"torch/csrc/autograd/generated/variable_factories.h",
125125
"torch/csrc/lazy/generated/LazyIr.h",
126-
"torch/csrc/lazy/generated/LazyIr.cpp",
127126
"torch/csrc/lazy/generated/LazyNativeFunctions.h",
128127
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
129128
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
@@ -1915,7 +1914,6 @@ test_suite(
19151914
for path in [
19161915
"aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp",
19171916
"aten/src/ATen/templates/DispatchKeyNativeFunctions.h",
1918-
"aten/src/ATen/templates/LazyIr.cpp",
19191917
"aten/src/ATen/templates/LazyIr.h",
19201918
"aten/src/ATen/templates/RegisterDispatchKey.cpp",
19211919
"aten/src/ATen/native/native_functions.yaml",

Diff for: aten/src/ATen/templates/LazyIr.cpp

-8
This file was deleted.

Diff for: build.bzl

-2
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@ def define_targets(rules):
2727
srcs = [
2828
":DispatchKeyNativeFunctions.cpp",
2929
":DispatchKeyNativeFunctions.h",
30-
":LazyIr.cpp",
3130
":LazyIr.h",
3231
":RegisterDispatchKey.cpp",
3332
":native_functions.yaml",
@@ -112,7 +111,6 @@ _GENERATED_CPP = [
112111
"torch/csrc/autograd/generated/python_torch_functions_1.cpp",
113112
"torch/csrc/autograd/generated/python_torch_functions_2.cpp",
114113
"torch/csrc/autograd/generated/python_variable_methods.cpp",
115-
"torch/csrc/lazy/generated/LazyIr.cpp",
116114
"torch/csrc/lazy/generated/LazyNativeFunctions.cpp",
117115
"torch/csrc/lazy/generated/RegisterAutogradLazy.cpp",
118116
"torch/csrc/lazy/generated/RegisterLazy.cpp",

Diff for: caffe2/CMakeLists.txt

-2
Original file line numberDiff line numberDiff line change
@@ -352,7 +352,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
352352
)
353353
if(BUILD_LAZY_TS_BACKEND)
354354
list(APPEND GENERATED_CXX_TORCH
355-
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyIr.cpp"
356355
"${TORCH_SRC_DIR}/csrc/lazy/generated/LazyNativeFunctions.cpp"
357356
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterAutogradLazy.cpp"
358357
"${TORCH_SRC_DIR}/csrc/lazy/generated/RegisterLazy.cpp"
@@ -433,7 +432,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
433432
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.h"
434433
"${TORCH_ROOT}/aten/src/ATen/templates/DispatchKeyNativeFunctions.cpp"
435434
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.h"
436-
"${TORCH_ROOT}/aten/src/ATen/templates/LazyIr.cpp"
437435
"${TORCH_ROOT}/aten/src/ATen/templates/RegisterDispatchKey.cpp"
438436
"${TOOLS_PATH}/autograd/templates/VariableType.h"
439437
"${TOOLS_PATH}/autograd/templates/VariableType.cpp"

Diff for: test/cpp/lazy/test_ir.cpp

+2-6
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ namespace lazy {
1717

1818
class TestLeafNode : public Node {
1919
public:
20-
static const OpKind class_op_kind;
21-
2220
explicit TestLeafNode(size_t param)
2321
: Node(OpKind(), /* num_outputs */ 1),
2422
hash_(Hash(param)),
@@ -40,16 +38,14 @@ class TestLeafNode : public Node {
4038
size_t param_;
4139
};
4240

43-
const OpKind TestLeafNode::class_op_kind = OpKind();
44-
4541
TEST(IrTest, BasicTest) {
4642
NodePtr node1 = MakeNode<TestLeafNode>(1);
4743
NodePtr node2 = MakeNode<TestLeafNode>(2);
4844
EXPECT_NE(node1->hash(), node2->hash());
4945

5046
EXPECT_EQ(node1->num_outputs(), 1);
5147

52-
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get());
48+
const TestLeafNode* leafptr = NodeCast<TestLeafNode>(node1.get(), OpKind());
5349
EXPECT_TRUE(leafptr != nullptr);
5450
}
5551

@@ -106,7 +102,7 @@ TEST(IrTest, TsNodeTest) {
106102

107103
EXPECT_EQ(node1->num_outputs(), 1);
108104

109-
const TsNode* leafptr = dynamic_cast<const TsNode*>(node1.get());
105+
const TsNode* leafptr = NodeCast<TsNode>(node1.get(), OpKind(at::aten::view));
110106
EXPECT_TRUE(leafptr != nullptr);
111107
}
112108

Diff for: test/cpp/lazy/test_trie_cache.cpp

+13-17
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,8 @@ namespace lazy {
1313

1414
class TrieCacheNode : public Node {
1515
public:
16-
static const OpKind class_op_kind;
17-
1816
explicit TrieCacheNode(size_t id)
19-
: Node(class_op_kind, /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
17+
: Node(OpKind(), /* num_outputs */ 1), id_(id), hash_(Hash(id_)) {}
2018
~TrieCacheNode() override = default;
2119

2220
bool Equal(size_t id) const {
@@ -38,8 +36,6 @@ class TrieCacheNode : public Node {
3836
hash_t hash_;
3937
};
4038

41-
const OpKind TrieCacheNode::class_op_kind = OpKind();
42-
4339
TEST(TrieCacheTest, TestSinglePath) {
4440
FLAGS_torch_lazy_reuse_ir = true;
4541
TrieCache::Get()->Clear();
@@ -49,9 +45,9 @@ TEST(TrieCacheTest, TestSinglePath) {
4945
NodePtr c = MakeNode<TrieCacheNode>(2);
5046
TrieCache::Get()->ResetCurrent(); // MarkStep
5147

52-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
53-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
54-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
48+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
49+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
50+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
5551
TrieCache::Get()->ResetCurrent(); // MarkStep
5652
}
5753

@@ -71,20 +67,20 @@ TEST(TrieCacheTest, TestTwoPaths) {
7167
NodePtr c = MakeNode<TrieCacheNode>(2);
7268
TrieCache::Get()->ResetCurrent(); // MarkStep
7369

74-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
75-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
76-
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(3);
70+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
71+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
72+
NodePtr d = ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3);
7773
EXPECT_NE(d.get(), c.get());
7874
TrieCache::Get()->ResetCurrent(); // MarkStep
7975

80-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
81-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
82-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(3).get(), d.get());
76+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
77+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
78+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 3).get(), d.get());
8379
TrieCache::Get()->ResetCurrent(); // MarkStep
8480

85-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(0).get(), a.get());
86-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(1).get(), b.get());
87-
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(2).get(), c.get());
81+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 0).get(), a.get());
82+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 1).get(), b.get());
83+
EXPECT_EQ(ReuseOrMakeNode<TrieCacheNode>(OpKind(), 2).get(), c.get());
8884
TrieCache::Get()->ResetCurrent(); // MarkStep
8985
}
9086

Diff for: tools/build_variables.bzl

-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@
1111

1212
# This is duplicated in caffe2/CMakeLists.txt for now and not yet used in buck
1313
GENERATED_LAZY_TS_CPP = [
14-
"lazy/generated/LazyIr.cpp",
1514
"lazy/generated/LazyNativeFunctions.cpp",
1615
"lazy/generated/RegisterAutogradLazy.cpp",
1716
"lazy/generated/RegisterLazy.cpp",
@@ -426,7 +425,6 @@ lazy_tensor_ts_sources = [
426425
"torch/csrc/lazy/ts_backend/ops/expand.cpp",
427426
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
428427
"torch/csrc/lazy/ts_backend/ops/scalar.cpp",
429-
"torch/csrc/lazy/ts_backend/ops/to_copy.cpp",
430428
"torch/csrc/lazy/ts_backend/view_ops/as_strided.cpp",
431429
"torch/csrc/lazy/ts_backend/view_ops/as_strided_view_update.cpp",
432430
"torch/csrc/lazy/ts_backend/view_ops/diagonal.cpp",

Diff for: torch/csrc/lazy/core/ir.h

-14
Original file line numberDiff line numberDiff line change
@@ -175,8 +175,6 @@ inline std::ostream& operator<<(std::ostream& stream, const Node& node) {
175175
return stream;
176176
}
177177

178-
// Note: Keep this version of NodeCast for smooth PyTorch/XLA migration, and
179-
// clean up once the migration is done.
180178
template <typename T>
181179
const T* NodeCast(const Node* node, OpKind op) {
182180
if (op != node->op()) {
@@ -189,18 +187,6 @@ const T* NodeCast(const Node* node, OpKind op) {
189187
#endif
190188
}
191189

192-
template <typename T>
193-
const T* NodeCast(const Node* node) {
194-
if (T::class_op_kind != node->op()) {
195-
return nullptr;
196-
}
197-
#ifdef NDEBUG
198-
return static_cast<const T*>(node);
199-
#else
200-
return &dynamic_cast<const T&>(*node);
201-
#endif
202-
}
203-
204190

205191
// Represents a specific output produced by a node. Since the output of a node
206192
// can be composed by multiple outputs, the node+index coordinates fully qualify

Diff for: torch/csrc/lazy/core/ir_builder.h

+5-5
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ namespace torch {
1515
namespace lazy {
1616

1717
template <typename T, typename... Args>
18-
NodePtr ReuseNode(Args&&... args) {
18+
NodePtr ReuseNode(OpKind op, Args&&... args) {
1919
if (FLAGS_torch_lazy_reuse_ir) {
20-
return LookupNodeFromTrieCache<T>(std::forward<Args>(args)...);
20+
return LookupNodeFromTrieCache<T>(op, std::forward<Args>(args)...);
2121
}
2222
return nullptr;
2323
}
@@ -27,16 +27,16 @@ template <typename T, typename... Args>
2727
NodePtr MakeNode(Args&&... args) {
2828
NodePtr node = std::make_shared<T>(std::forward<Args>(args)...);
2929
if (FLAGS_torch_lazy_reuse_ir) {
30-
// If ir caching is enabled, we need to record all new nodes
30+
// If ir caching is enabled, we need to record all new nodes
3131
TrieCache::Get()->Insert(node);
3232
}
3333
return node;
3434
}
3535

3636
// op is passed in for a more efficient node casting, see the implementation of NodeCast
3737
template <typename T, typename... Args>
38-
NodePtr ReuseOrMakeNode(Args&&... args) {
39-
NodePtr node = ReuseNode<T>(std::forward<Args>(args)...);
38+
NodePtr ReuseOrMakeNode(OpKind op, Args&&... args) {
39+
NodePtr node = ReuseNode<T>(op, std::forward<Args>(args)...);
4040
if (!node) {
4141
node = MakeNode<T>(std::forward<Args>(args)...);
4242
}

Diff for: torch/csrc/lazy/core/trie.h

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,11 +52,11 @@ class TORCH_API TrieCache {
5252
};
5353

5454
template <typename T, typename... Args>
55-
NodePtr LookupNodeFromTrieCache(Args&&... args) {
55+
NodePtr LookupNodeFromTrieCache(OpKind op, Args&&... args) {
5656
auto& successors = TrieCache::Get()->Current()->successors;
5757
for (auto it = successors.begin(); it != successors.end(); it++) {
5858
NodePtr ir_node = (*it)->ir_node;
59-
const T* concrete_node = NodeCast<T>(ir_node.get());
59+
const T* concrete_node = NodeCast<T>(ir_node.get(), op);
6060
if (concrete_node && concrete_node->Equal(std::forward<Args>(args)...)) {
6161
TORCH_LAZY_COUNTER("IrNodeReused::" + std::string(typeid(T).name()), 1);
6262
TrieCache::Get()->SetCurrent(it);

Diff for: torch/csrc/lazy/ts_backend/ops/batch_norm_ops.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,6 @@
44
namespace torch {
55
namespace lazy {
66

7-
const OpKind TSNativeBatchNormBackward::class_op_kind(at::aten::native_batch_norm_backward);
8-
const OpKind TSNativeBatchNormForward::class_op_kind(at::aten::native_batch_norm);
9-
107
TSNativeBatchNormBackward::TSNativeBatchNormBackward(
118
const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
129
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,

Diff for: torch/csrc/lazy/ts_backend/ops/batch_norm_ops.h

-4
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ namespace lazy {
88
// Node for the backward batch norm operator.
99
class TSNativeBatchNormBackward : public torch::lazy::TsNode {
1010
public:
11-
static const OpKind class_op_kind;
12-
1311
TSNativeBatchNormBackward(const torch::lazy::Value& grad_out, const torch::lazy::Value& input,
1412
const torch::lazy::Value& weight, const torch::lazy::Value& running_mean,
1513
const torch::lazy::Value& running_var, const torch::lazy::Value& save_mean,
@@ -37,8 +35,6 @@ class TSNativeBatchNormBackward : public torch::lazy::TsNode {
3735

3836
class TSNativeBatchNormForward : public torch::lazy::TsNode {
3937
public:
40-
static const OpKind class_op_kind;
41-
4238
TSNativeBatchNormForward(const torch::lazy::Value& input, const torch::lazy::Value& weight,
4339
const torch::lazy::Value& bias, const torch::lazy::Value& running_mean,
4440
const torch::lazy::Value& running_var, bool training,

Diff for: torch/csrc/lazy/ts_backend/ops/cast.cpp

-3
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,6 @@ Shape NodeOutputShape(const Value& input, c10::ScalarType type) {
1515
}
1616

1717
} // namespace
18-
19-
const OpKind Cast::class_op_kind(ltc_cast);
20-
2118
Cast::Cast(
2219
const Value& input,
2320
at::ScalarType dtype,

Diff for: torch/csrc/lazy/ts_backend/ops/cast.h

-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ namespace lazy {
99

1010
class TORCH_API Cast : public TsNode {
1111
public:
12-
static const OpKind class_op_kind;
13-
1412
Cast(
1513
const Value& input,
1614
at::ScalarType dtype,

Diff for: torch/csrc/lazy/ts_backend/ops/device_data.cpp

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
namespace torch {
88
namespace lazy {
99

10-
const OpKind DeviceData::class_op_kind(ltc_device_data);
11-
1210
DeviceData::DeviceData(std::shared_ptr<BackendData> data)
1311
: TsNode(
1412
ltc_device_data,
@@ -24,7 +22,7 @@ std::string DeviceData::ToString() const {
2422
}
2523

2624
const DeviceData* DeviceData::Cast(const Node* node) {
27-
return NodeCast<DeviceData>(node);
25+
return NodeCast<DeviceData>(node, ltc_device_data);
2826
}
2927

3028
} // namespace lazy

Diff for: torch/csrc/lazy/ts_backend/ops/device_data.h

-2
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ namespace lazy {
88

99
class TORCH_API DeviceData : public TsNode {
1010
public:
11-
static const OpKind class_op_kind;
12-
1311
explicit DeviceData(std::shared_ptr<BackendData> data);
1412

1513
std::string ToString() const override;

Diff for: torch/csrc/lazy/ts_backend/ops/expand.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@
33
namespace torch {
44
namespace lazy {
55

6-
const OpKind Expand::class_op_kind(at::aten::expand);
7-
86
Expand::Expand(
97
const Value& input,
108
std::vector<int64_t> size,

Diff for: torch/csrc/lazy/ts_backend/ops/expand.h

-2
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@ namespace lazy {
99

1010
class TORCH_API Expand : public TsNode {
1111
public:
12-
static const OpKind class_op_kind;
13-
1412
Expand(const Value& input, std::vector<int64_t> size, bool is_scalar_expand);
1513

1614
std::string ToString() const override;

Diff for: torch/csrc/lazy/ts_backend/ops/random_ops.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
namespace torch {
55
namespace lazy {
66

7-
const OpKind Normal::class_op_kind(c10::Symbol::fromQualString("aten::normal_"));
8-
97
Normal::Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes)
108
: torch::lazy::TsNode(torch::lazy::OpKind(c10::Symbol::fromQualString("aten::normal_")),
119
{self}, std::move(shapes),

Diff for: torch/csrc/lazy/ts_backend/ops/random_ops.h

-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@ namespace lazy {
77

88
class Normal : public torch::lazy::TsNode {
99
public:
10-
static const OpKind class_op_kind;
11-
1210
Normal(const torch::lazy::Value& self, const double& mean, const double& std, std::vector<torch::lazy::Shape>&& shapes);
1311

1412
std::string ToString() const override;

Diff for: torch/csrc/lazy/ts_backend/ops/scalar.cpp

-2
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@ namespace lazy {
1010

1111
using at::operator<<;
1212

13-
const OpKind Scalar::class_op_kind(at::prim::Constant);
14-
1513
Scalar::Scalar(const at::Scalar& value, Shape shape)
1614
: TsNode(
1715
OpKind(at::prim::Constant),

Diff for: torch/csrc/lazy/ts_backend/ops/scalar.h

-2
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,6 @@ namespace lazy {
1212
// computation graph.
1313
class TORCH_API Scalar : public TsNode {
1414
public:
15-
static const OpKind class_op_kind;
16-
1715
Scalar(const at::Scalar& value, Shape shape);
1816
Scalar(const at::Scalar& value, c10::ScalarType type);
1917

Diff for: torch/csrc/lazy/ts_backend/ops/to_copy.cpp

-9
This file was deleted.

0 commit comments

Comments
 (0)