forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcpp_c10d_extension.hpp
122 lines (95 loc) · 3.88 KB
/
cpp_c10d_extension.hpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#pragma once
#include <torch/extension.h>
#include <deque>
#include <exception>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include <chrono>
#include <pybind11/chrono.h>
#include <c10d/ProcessGroup.hpp>
#include <c10d/Store.hpp>
#include <c10d/Types.hpp>
#include <c10d/Utils.hpp>
namespace c10d {
//
// ProcessGroupTest implements dummy bindings for c10d.
//
class ProcessGroupTest : public ProcessGroup {
public:
class WorkTest : public ProcessGroup::Work {
public:
WorkTest() {}
virtual ~WorkTest();
bool isCompleted() override;
bool isSuccess() const override;
bool wait(std::chrono::milliseconds timeout) override;
protected:
friend class ProcessGroupTest;
};
explicit ProcessGroupTest(int rank = -1, int size = -1);
virtual ~ProcessGroupTest();
c10::intrusive_ptr<ProcessGroup::Work> broadcast(
std::vector<at::Tensor>& data,
const BroadcastOptions& opts = BroadcastOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce(
std::vector<at::Tensor>& tensors,
const AllreduceOptions& opts = AllreduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allreduce_coalesced(
std::vector<at::Tensor>& tensors,
const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce(
std::vector<at::Tensor>& tensors,
const ReduceOptions& opts = ReduceOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> allgather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> _allgather_base(
at::Tensor& outputBuffer,
at::Tensor& inputBuffer,
const AllgatherOptions& opts = AllgatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> barrier(
const BarrierOptions& opts = BarrierOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> gather(
std::vector<std::vector<at::Tensor>>& outputTensors,
std::vector<at::Tensor>& inputTensors,
const GatherOptions& opts = GatherOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ScatterOptions& opts = ScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> reduce_scatter(
std::vector<at::Tensor>& outputTensors,
std::vector<std::vector<at::Tensor>>& inputTensors,
const ReduceScatterOptions& opts = ReduceScatterOptions()) override;
c10::intrusive_ptr<ProcessGroup::Work> send(
std::vector<at::Tensor>& tensors,
int dstRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recv(
std::vector<at::Tensor>& tensors,
int srcRank,
int tag) override;
c10::intrusive_ptr<ProcessGroup::Work> recvAnysource(
std::vector<at::Tensor>& tensor,
int tag) override;
// Create a new ProcessGroupTest instance
static c10::intrusive_ptr<ProcessGroup> createProcessGroupTest(
const c10::intrusive_ptr<::c10d::Store>& store,
int rank,
int size,
const std::chrono::duration<float>& timeout);
static void ProcessGroupTestConstructor() __attribute__((constructor)) {
py::object module = py::module::import("torch.distributed");
py::object register_backend = module.attr("Backend").attr("register_backend");
// The first parameter is the backend name used by user in invoking
// torch.distributed.init_process_group().
// Note it could be different with module name. For example, the module
// name is "torch_test" but the backend name is "test".
// The second parameter is the instantiation function.
register_backend("test", py::cpp_function(createProcessGroupTest));
}
};
} // namespace c10d