
Commit 0040646

Siqiao Chen authored and facebook-github-bot committed
fix the type hack in dramKV wrapper (#4012)
Summary:
Pull Request resolved: #4012

X-link: facebookresearch/FBGEMM#1099

In the previous diff, we hard-coded the DRAM wrapper to float. In this diff, we allow the call site to customize the type of the weights stored in DRAM; FP32 and FP16 are currently supported.

Reviewed By: emlin

Differential Revision: D73477947

fbshipit-source-id: 25dda60b61a2a257e3548a7710df75ef4d6b9f88
1 parent 566f289 commit 0040646
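The core of the change is the constructor: instead of always instantiating DramKVEmbeddingCache<float>, the wrapper now holds a std::variant over the FP32 and FP16 instantiations and selects one from row_storage_bitwidth at construction time. Below is a minimal standalone sketch of that dispatch pattern, not the real implementation: Cache is a hypothetical stand-in for kv_mem::DramKVEmbeddingCache, and short stands in for the 16-bit at::Half so the snippet compiles without ATen.

#include <cstdint>
#include <cstdio>
#include <memory>
#include <stdexcept>
#include <variant>

// Hypothetical stand-in for kv_mem::DramKVEmbeddingCache<weight_type>,
// which actually lives in dram_kv_embedding_cache.h.
template <typename weight_type>
struct Cache {
  explicit Cache(int64_t max_D) : max_D(max_D) {}
  int64_t max_D;
};

// One variant alternative per supported weight type; `short` stands in
// for at::Half here.
using CacheVariant = std::variant<
    std::shared_ptr<Cache<float>>,
    std::shared_ptr<Cache<short>>>;

// Runtime dispatch on the requested bitwidth, mirroring the wrapper's
// new constructor.
CacheVariant make_cache(int64_t row_storage_bitwidth, int64_t max_D) {
  if (row_storage_bitwidth == 16) {
    return std::make_shared<Cache<short>>(max_D);
  }
  if (row_storage_bitwidth == 32) {
    return std::make_shared<Cache<float>>(max_D);
  }
  throw std::runtime_error("unsupported row_storage_bitwidth");
}

int main() {
  CacheVariant cache = make_cache(16, 128);
  // Every member call goes through std::visit: the generic lambda is
  // instantiated once per alternative, so no common virtual base is needed.
  std::visit(
      [](auto& ptr) {
        if (ptr) {
          std::printf("max_D = %lld\n", static_cast<long long>(ptr->max_D));
        }
      },
      cache);
  return 0;
}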

File tree

1 file changed (+92, -18 lines)

fbgemm_gpu/src/dram_kv_embedding_cache/dram_kv_embedding_cache_wrapper.h

Lines changed: 92 additions & 18 deletions
@@ -8,8 +8,15 @@
 
 #pragma once
 
+#include "../ssd_split_embeddings_cache/kv_tensor_wrapper.h"
 #include "dram_kv_embedding_cache.h"
 
+namespace {
+using DramKVEmbeddingCacheVariant = std::variant<
+    std::shared_ptr<kv_mem::DramKVEmbeddingCache<float>>,
+    std::shared_ptr<kv_mem::DramKVEmbeddingCache<at::Half>>>;
+}
+
 namespace kv_mem {
 
 class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
@@ -21,61 +28,128 @@ class DramKVEmbeddingCacheWrapper : public torch::jit::CustomClassHolder {
       int64_t num_shards = 8,
       int64_t num_threads = 32,
       int64_t row_storage_bitwidth = 32,
-      int64_t weight_ttl_in_hours = 2)
-      : impl_(std::make_shared<kv_mem::DramKVEmbeddingCache<float>>(
-            max_D,
-            uniform_init_lower,
-            uniform_init_upper,
-            num_shards,
-            num_threads,
-            row_storage_bitwidth,
-            weight_ttl_in_hours)) {}
+      int64_t weight_ttl_in_hours = 2) {
+    if (row_storage_bitwidth == 16) {
+      impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<at::Half>>(
+          max_D,
+          uniform_init_lower,
+          uniform_init_upper,
+          num_shards,
+          num_threads,
+          row_storage_bitwidth,
+          weight_ttl_in_hours);
+    } else if (row_storage_bitwidth == 32) {
+      impl_ = std::make_shared<kv_mem::DramKVEmbeddingCache<float>>(
+          max_D,
+          uniform_init_lower,
+          uniform_init_upper,
+          num_shards,
+          num_threads,
+          row_storage_bitwidth,
+          weight_ttl_in_hours);
+    } else {
+      throw std::runtime_error("Failed to create recording device");
+    }
+  }
 
   void set_cuda(
       at::Tensor indices,
       at::Tensor weights,
       at::Tensor count,
       int64_t timestep,
       bool is_bwd) {
-    return impl_->set_cuda(indices, weights, count, timestep, is_bwd);
+    return std::visit(
+        [&indices, &weights, &count, &timestep](auto& ptr) {
+          if (ptr) {
+            ptr->set_cuda(indices, weights, count, timestep);
+          }
+        },
+        impl_);
   }
 
   void get_cuda(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-    return impl_->get_cuda(indices, weights, count);
+    return std::visit(
+        [&indices, &weights, &count](auto& ptr) {
+          if (ptr) {
+            ptr->get_cuda(indices, weights, count);
+          }
+        },
+        impl_);
   }
 
   void set(at::Tensor indices, at::Tensor weights, at::Tensor count) {
-    return impl_->set(indices, weights, count);
+    return std::visit(
+        [&indices, &weights, &count](auto& ptr) {
+          if (ptr) {
+            ptr->set(indices, weights, count);
+          }
+        },
+        impl_);
   }
 
   void flush() {
-    return impl_->flush();
+    return std::visit(
+        [](auto& ptr) {
+          if (ptr) {
+            ptr->flush();
+          }
+        },
+        impl_);
   }
 
   void set_range_to_storage(
       const at::Tensor& weights,
       const int64_t start,
       const int64_t length) {
-    return impl_->set_range_to_storage(weights, start, length);
+    return std::visit(
+        [&weights, &start, &length](auto& ptr) {
+          if (ptr) {
+            ptr->set_range_to_storage(weights, start, length);
+          }
+        },
+        impl_);
   }
 
   void get(
       at::Tensor indices,
       at::Tensor weights,
       at::Tensor count,
       int64_t sleep_ms) {
-    return impl_->get(indices, weights, count, sleep_ms);
+    return std::visit(
+        [&indices, &weights, &count, sleep_ms](auto& ptr) {
+          if (ptr) {
+            ptr->get(indices, weights, count, sleep_ms);
+          }
+        },
+        impl_);
   }
 
   void wait_util_filling_work_done() {
-    return impl_->wait_util_filling_work_done();
+    return std::visit(
+        [](auto& ptr) {
+          if (ptr) {
+            ptr->wait_util_filling_work_done();
+          }
+        },
+        impl_);
   }
 
   at::Tensor get_keys_in_range(int64_t start, int64_t end) {
-    return impl_->get_keys_in_range(start, end);
+    return std::visit(
+        [&start, &end](auto& ptr) {
+          if (ptr) {
+            return ptr->get_keys_in_range(start, end);
+          }
+          return at::empty({0});
+        },
+        impl_);
   }
 
-  std::shared_ptr<kv_mem::DramKVEmbeddingCache<float>> impl_;
+ private:
+  // friend class EmbeddingRocksDBWrapper;
+  friend class ssd::KVTensorWrapper;
+
+  DramKVEmbeddingCacheVariant impl_;
 };
 
 } // namespace kv_mem
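One detail worth noting in get_keys_in_range above: when std::visit is used for its return value, every return path of the visitor must yield the same type for every variant alternative, which is why the null-pointer branch returns at::empty({0}) rather than falling through. A minimal standalone sketch of that constraint, with hypothetical impl types and std::vector<int64_t> standing in for at::Tensor:

#include <cstdint>
#include <memory>
#include <variant>
#include <vector>

// Two unrelated impl types standing in for the FP32/FP16 caches.
struct ImplA { std::vector<int64_t> keys{1, 2, 3}; };
struct ImplB { std::vector<int64_t> keys{4, 5}; };

using Impl = std::variant<std::shared_ptr<ImplA>, std::shared_ptr<ImplB>>;

// The visitor must return the same type on every path and for every
// alternative, hence the explicit empty-result fallback when the pointer
// is null (mirroring `return at::empty({0})` in the wrapper).
std::vector<int64_t> keys_or_empty(const Impl& impl) {
  return std::visit(
      [](const auto& ptr) -> std::vector<int64_t> {
        if (ptr) {
          return ptr->keys;
        }
        return {};
      },
      impl);
}

int main() {
  Impl impl = std::make_shared<ImplB>();
  return static_cast<int>(keys_or_empty(impl).size()); // exits with 2
}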
