#pragma once

+ #include "../ssd_split_embeddings_cache/kv_tensor_wrapper.h"
#include "dram_kv_embedding_cache.h"

+ #include <variant>
+
+ namespace {
+ using DramKVEmbeddingCacheVariant = std::variant<
+     std::shared_ptr<kv_mem::DramKVEmbeddingCache<float>>,
+     std::shared_ptr<kv_mem::DramKVEmbeddingCache<at::Half>>>;
+ } // namespace
+
namespace kv_mem {

class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
@@ -21,61 +28,128 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
      int64_t num_shards = 8,
      int64_t num_threads = 32,
      int64_t row_storage_bitwidth = 32,
-       int64_t weight_ttl_in_hours = 2)
-       : impl_(std::make_shared<kv_mem::DramKVEmbeddingCache<float>>(
-             max_D,
-             uniform_init_lower,
-             uniform_init_upper,
-             num_shards,
-             num_threads,
-             row_storage_bitwidth,
-             weight_ttl_in_hours)) {}
+       int64_t weight_ttl_in_hours = 2) {
+     if (row_storage_bitwidth == 16) {
+       impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
+           max_D,
+           uniform_init_lower,
+           uniform_init_upper,
+           num_shards,
+           num_threads,
+           row_storage_bitwidth,
+           weight_ttl_in_hours);
+     } else if (row_storage_bitwidth == 32) {
+       impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<float>>(
+           max_D,
+           uniform_init_lower,
+           uniform_init_upper,
+           num_shards,
+           num_threads,
+           row_storage_bitwidth,
+           weight_ttl_in_hours);
+     } else {
+ throw std::runtime_error (" Failed to create recording device" );
52
+ }
53
+ }

  void set_cuda(
      at::Tensor indices,
      at::Tensor weights,
      at::Tensor count,
      int64_t timestep,
      bool is_bwd) {
-     return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
+     return std::visit(
+         [&indices, &weights, &count, &timestep](auto& ptr) {
+           if (ptr) {
+             ptr->set_cuda(indices, weights, count, timestep);
+           }
+         },
+         impl_);
  }

  void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-     return impl_->get_cuda(indices, weights, count);
+     return std::visit(
+         [&indices, &weights, &count](auto& ptr) {
+           if (ptr) {
+             ptr->get_cuda(indices, weights, count);
+           }
+         },
+         impl_);
  }

  void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-     return impl_->set(indices, weights, count);
+     return std::visit(
+         [&indices, &weights, &count](auto& ptr) {
+           if (ptr) {
+             ptr->set(indices, weights, count);
+           }
+         },
+         impl_);
  }

  void flush() {
-     return impl_->flush();
+     return std::visit(
+         [](auto& ptr) {
+           if (ptr) {
+             ptr->flush();
+           }
+         },
+         impl_);
  }

  void set_range_to_storage(
      const at::Tensor& weights,
      const int64_t start,
      const int64_t length) {
-     return impl_->set_range_to_storage(weights, start, length);
+     return std::visit(
+         [&weights, &start, &length](auto& ptr) {
+           if (ptr) {
+             ptr->set_range_to_storage(weights, start, length);
+           }
+         },
+         impl_);
  }

  void get(
      at::Tensor indices,
      at::Tensor weights,
      at::Tensor count,
      int64_t sleep_ms) {
-     return impl_->get(indices, weights, count, sleep_ms);
+     return std::visit(
+         [&indices, &weights, &count, sleep_ms](auto& ptr) {
+           if (ptr) {
+             ptr->get(indices, weights, count, sleep_ms);
+           }
+         },
+         impl_);
  }

  void wait_util_filling_work_done() {
-     return impl_->wait_util_filling_work_done();
+     return std::visit(
+         [](auto& ptr) {
+           if (ptr) {
+             ptr->wait_util_filling_work_done();
+           }
+         },
+         impl_);
  }

  at::Tensor get_keys_in_range(int64_t start, int64_t end) {
-     return impl_->get_keys_in_range(start, end);
+     return std::visit(
+         [&start, &end](auto& ptr) {
+           if (ptr) {
+             return ptr->get_keys_in_range(start, end);
+           }
+           return at::empty({0});
+         },
+         impl_);
  }

-   std::shared_ptr<kv_mem::DramKVEmbeddingCache<float>> impl_;
+  private:
+   // friend class EmbeddingRocksDBWrapper;
+   friend class ssd::KVTensorWrapper;
+
+   DramKVEmbeddingCacheVariant impl_;
};
} // namespace kv_mem
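
Reviewer note: below is a minimal, self-contained sketch (not part of this change) of the std::variant + std::visit dispatch pattern the wrapper adopts above. The element type is chosen once in the constructor from row_storage_bitwidth, and every member function forwards through a visitor that null-checks the stored shared_ptr. Names such as Cache, CacheVariant, and Wrapper are illustrative stand-ins, with uint16_t standing in for at::Half.

#include <cstdint>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <variant>

template <typename scalar_t>
struct Cache {
  void flush() {
    std::cout << "flush " << sizeof(scalar_t) << "-byte rows\n";
  }
};

// Stand-in for DramKVEmbeddingCacheVariant: one alternative per element type.
using CacheVariant = std::variant<
    std::shared_ptr<Cache<float>>,
    std::shared_ptr<Cache<uint16_t>>>;

struct Wrapper {
  explicit Wrapper(int64_t row_storage_bitwidth) {
    // Pick the element type once, based on the requested bitwidth.
    if (row_storage_bitwidth == 16) {
      impl_ = std::make_shared<Cache<uint16_t>>();
    } else if (row_storage_bitwidth == 32) {
      impl_ = std::make_shared<Cache<float>>();
    } else {
      throw std::runtime_error("unsupported row_storage_bitwidth");
    }
  }

  // Each call forwards through std::visit; the lambda null-checks the pointer
  // before calling, mirroring the wrapper methods in the diff above.
  void flush() {
    std::visit(
        [](auto& ptr) {
          if (ptr) {
            ptr->flush();
          }
        },
        impl_);
  }

  CacheVariant impl_;
};

int main() {
  Wrapper(16).flush(); // dispatches to Cache<uint16_t>
  Wrapper(32).flush(); // dispatches to Cache<float>
  return 0;
}

The variant keeps the dtype decision out of the TorchScript-facing API: callers see a single wrapper class, while the float vs. at::Half choice made at construction determines which cache instance every visit reaches, without requiring a virtual interface on DramKVEmbeddingCache.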