From b2122ac15ad6158ca28b6d21b1e44fa63e508f14 Mon Sep 17 00:00:00 2001 From: ruclz Date: Tue, 21 Jan 2025 07:06:10 +0000 Subject: [PATCH 01/14] AMX enable --- CMakeLists.txt | 35 +- examples/cpp/example_mt_search_bf16.cpp | 578 ++++++++++++++++++++++++ hnswlib/hnswalg.h | 64 +++ hnswlib/hnswlib.h | 26 ++ hnswlib/space_ip.h | 396 +++++++++++++++- hnswlib/space_l2.h | 4 + hnswlib/stop_condition.h | 5 +- 7 files changed, 1105 insertions(+), 3 deletions(-) create mode 100644 examples/cpp/example_mt_search_bf16.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index be0d40f0..2cdf0c4f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,6 +6,7 @@ project(hnswlib include(GNUInstallDirs) include(CheckCXXCompilerFlag) +SET(AMXEnable true) add_library(hnswlib INTERFACE) add_library(hnswlib::hnswlib ALIAS hnswlib) @@ -48,11 +49,39 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) + SET( CMAKE_CXX_FLAGS "-O0 -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() +if(AMXEnable) + set(CMAKE_CXX_FLAGS_BACKUP "${CMAKE_CXX_FLAGS}") + # 添加需要测试的编译器选项 + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -mamx-bf16") + + # 测试代码片段 + set(SOURCE_CODE " + #if !defined(__AMX_BF16__) + #error \"AMX not supported\" + #endif + int main() { return 0; } + ") + + check_cxx_source_compiles("${SOURCE_CODE}" HAS_AMX_SUPPORT) + + if(HAS_AMX_SUPPORT) + message(STATUS "Compiler supports AMX") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP} -mamx-bf16") + #add_compile_definitions(USE_AMX=1) + else() + message(STATUS "Compiler does NOT support AMX") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") + endif() +endif() + + + # 恢复CMAKE_REQUIRED_FLAGS的原始值 + # examples add_executable(example_search examples/cpp/example_search.cpp) target_link_libraries(example_search hnswlib) @@ -72,6 +101,10 @@ if(HNSWLIB_EXAMPLES) add_executable(example_mt_search examples/cpp/example_mt_search.cpp) target_link_libraries(example_mt_search hnswlib) + + add_executable(example_mt_search_bf16 examples/cpp/example_mt_search_bf16.cpp) + target_link_libraries(example_mt_search_bf16 hnswlib) + add_executable(example_mt_filter examples/cpp/example_mt_filter.cpp) target_link_libraries(example_mt_filter hnswlib) diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp new file mode 100644 index 00000000..e00306c4 --- /dev/null +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -0,0 +1,578 @@ +#include "../../hnswlib/hnswlib.h" +#include + + +// Multithreaded executor +// The helper function copied from python_bindings/bindings.cpp (and that itself is copied from nmslib) +// An alternative is using #pragme omp parallel for or any other C++ threading + +template +using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); +static int8_t vector_dot_product(const void* a, const void* b, const void *qty_ptr) { + + uint32_t length = * (uint32_t*)qty_ptr; + + int32_t final_result = 0; + size_t i = 0; + int8_t *a_tmp=(int8_t *)a; + int8_t *b_tmp=(int8_t *)b; + if(length>=64){ + __m512i sum = _mm512_setzero_si512(); // 初始化累加和为 0 + for (; i+64 <= length; i += 64) { + // 加载数据 + __m512i va = _mm512_loadu_si512((void*)&a_tmp[i]); + __m512i vb = _mm512_loadu_si512((void*)&b_tmp[i]); + // 执行点积运算 
+ //std::cout << "we are 32 lines" <(temp_x); + __m512bh v2_f16 = reinterpret_cast<__m512bh&>(temp_y); + + // 计算BF16的点积,并将结果累加到vr_f32 + vr_f32 = _mm512_dpbf16_ps(vr_f32, v1_f16, v2_f16); + } + + // 将vr_f32寄存器的值存入result数组 + _mm512_storeu_ps(result, vr_f32); + + // 累加result数组的所有元素,获得最终的点积结果 + float dot_product = 0.0f; + for (int j = 0; j < 16; j++) { + dot_product += result[j]; + } + + // 处理剩余的元素(小于32的部分) +/* for (; i < dim; i++) { + float x_val = bf162float(x[i]); + float y_val = bf162float(y[i]); + dot_product += x_val * y_val; + } */ + //printf("%d %f ",dim,dot_product); + return 1-dot_product; +} + + +static int8_t fvec_inner_product_int8_avx2int8(const void* a, const void* b, const void *qty_ptr) { + //exit(-1); + const uint8_t* pvec_u8 = (const uint8_t*)a; + const int8_t* pvec_s8 = (const int8_t*)b; + size_t qty32 = *((size_t*)qty_ptr) / 32; + const uint8_t* pend_u8 = pvec_u8 + 32 * qty32; + + // 初始化累加和为 0 + __m256i sum256 = _mm256_setzero_si256(); + __m256i v1, v2, v3; + + // 创建一个包含 1 的 128 位向量 + __m128i one = _mm_set1_epi16(1); + // 广播 1 到 256 位向量 + __m256i agg_base = _mm256_broadcastw_epi16(one); + + while (pvec_u8 < pend_u8) { + v1 = _mm256_loadu_si256((__m256i*)pvec_u8); + v2 = _mm256_loadu_si256((__m256i*)pvec_s8); + v3 = _mm256_maddubs_epi16(v1, v2); + sum256 = _mm256_add_epi32(sum256, _mm256_madd_epi16(v3, agg_base)); + pvec_u8 += 32; + pvec_s8 += 32; + } + + // 处理剩余数据 + for (size_t i = 32 * qty32; i < *((size_t*)qty_ptr); i++) { + sum256 = _mm256_add_epi32(sum256, _mm256_set1_epi32(pvec_u8[i] * pvec_s8[i])); + } + + // 将 SIMD 寄存器中的结果累积到一个标量值 + int32_t result[8]; + _mm256_storeu_si256((__m256i*)result, sum256); + + int8_t dotsum = 0; + for (int i = 0; i < 8; ++i) { + dotsum += result[i]; + } + std::cout<127) res=127; + else if(res<-128) res=-128; + return static_cast(res); +} + +class Int8InnerProductSpace : public hnswlib::SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + Int8InnerProductSpace(size_t dim) { + fstdistfunc_ = vector_dot_product_opt_avx512; + dim_ = dim; + data_size_ = dim * sizeof(int8_t); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + void *get_dist_func_param() { + return &dim_; + } + ~Int8InnerProductSpace() {} +}; +class Bf16InnerProductSpace : public hnswlib::SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + Bf16InnerProductSpace(size_t dim) { + fstdistfunc_ = vector_dot_product_bf16; + dim_ = dim * 2; + data_size_ = dim * 2 * sizeof(uint16_t); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + void *get_dist_func_param() { + return &dim_; + } + ~Bf16InnerProductSpace() {} +}; + +void setThreadAffinity(std::thread::native_handle_type handle, size_t cpuId) { + cpu_set_t cpuset; + CPU_ZERO(&cpuset); + CPU_SET(cpuId, &cpuset); + + int rc = pthread_setaffinity_np(handle, sizeof(cpuset), &cpuset); + if (rc != 0) { + throw std::system_error(rc, std::generic_category(), "pthread_setaffinity_np"); + } +} +template +inline void ParallelFor_Build(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr 
lastException = nullptr; + std::mutex lastExceptMutex; + + + int dimSizeperThread = (end-start)/numThreads; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + setThreadAffinity(pthread_self(), threadId); + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). + */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + +template +inline void ParallelFor(size_t start, size_t end, size_t numThreads, Function fn) { + if (numThreads <= 0) { + numThreads = std::thread::hardware_concurrency(); + } + + if (numThreads == 1) { + for (size_t id = start; id < end; id++) { + fn(id, 0); + } + } else { + std::vector threads; + std::atomic current(start); + + // keep track of exceptions in threads + // https://stackoverflow.com/a/32428427/1713196 + std::exception_ptr lastException = nullptr; + std::mutex lastExceptMutex; + + + int dimSizeperThread = (end-start)/numThreads; + + for (size_t threadId = 0; threadId < numThreads; ++threadId) { + threads.push_back(std::thread([&, threadId] { + while (true) { + size_t id = current.fetch_add(1); + + if (id >= end) { + break; + } + + try { + fn(id, threadId); + } catch (...) { + std::unique_lock lastExcepLock(lastExceptMutex); + lastException = std::current_exception(); + /* + * This will work even when current is the largest value that + * size_t can fit, because fetch_add returns the previous value + * before the increment (what will result in overflow + * and produce 0 instead of current + 1). 
+ */ + current = end; + break; + } + } + })); + } + for (auto &thread : threads) { + thread.join(); + } + if (lastException) { + std::rethrow_exception(lastException); + } + } +} + +int call_scalar(hnswlib::HierarchicalNSW* alg_hnsw,Int8InnerProductSpace & space,int8_t* data,int dim, size_t max_elements,int top_k,int num_threads){ + std::vector neighbors(max_elements); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); + hnswlib::labeltype label = result.top().second; + neighbors[row] = label; + }); + float correct = 0; + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = neighbors[i]; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + return 0; +} + +// int call_AMX(hnswlib::HierarchicalNSW* alg_hnsw,Int8InnerProductSpace & space,int8_t* data,int dim, size_t max_elements,int top_k,int num_threads){ +// //init_onednn(); +// std::vector neighbors(max_elements); +// ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { +// std::priority_queue> result = alg_hnsw->searchKnnAMX(data + dim * row, 1); +// hnswlib::labeltype label = result.top().second; +// neighbors[row] = label; +// }); +// float correct = 0; +// for (int i = 0; i < max_elements; i++) { +// hnswlib::labeltype label = neighbors[i]; +// if (label == i) correct++; +// } +// float recall = correct / max_elements; +// std::cout << "Recall: " << recall << "\n"; +// return 0; +// } + +int call_scalar_fp32(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::InnerProductSpace& space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ + std::vector neighbors(max_elements); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, top_k); + hnswlib::labeltype label = result.top().second; + neighbors[row] = label; + }); + float correct = 0; + for (size_t i = 0; i < max_elements; i++) { + hnswlib::labeltype label = neighbors[i]; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + return 0; +} +// int call_scalar_bf16(hnswlib::HierarchicalNSW* alg_hnsw,Bf16InnerProductSpace& space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ +// std::vector neighbors(max_elements); +// ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { +// std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); +// hnswlib::labeltype label = result.top().second; +// neighbors[row] = label; +// }); +// float correct = 0; +// for (int i = 0; i < max_elements; i++) { +// hnswlib::labeltype label = neighbors[i]; +// if (label == i) correct++; +// } +// float recall = correct / max_elements; +// std::cout << "Recall: " << recall << "\n"; +// return 0; +// } + + +int call_AMX_fp32(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::InnerProductSpace & space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ + //init_onednn(); + std::vector neighbors(max_elements); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); + hnswlib::labeltype label = result.top().second; + neighbors[row] = label; + }); + float correct = 0; + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = neighbors[i]; + if (label == i) correct++; + } 
+ float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + return 0; +} + +// int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,Bf16InnerProductSpace & space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ +// //init_onednn(); +// std::vector neighbors(max_elements); +// ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { +// std::priority_queue> result = alg_hnsw->searchKnnAMX_bf16(data + dim * row, 1); +// hnswlib::labeltype label = result.top().second; +// neighbors[row] = label; +// }); +// float correct = 0; +// for (int i = 0; i < max_elements; i++) { +// hnswlib::labeltype label = neighbors[i]; +// if (label == i) correct++; +// } +// float recall = correct / max_elements; +// std::cout << "Recall: " << recall << "\n"; +// return 0; +// } +int main() { + int true_dim=1024; + int dim = true_dim/2; // Dimension of the elements + size_t max_elements = 100*1024; // Maximum number of elements, should be known beforehand + int M = 32; // Tightly connected with internal dimensionality of the data + size_t nq = max_elements; + // strongly affects the memory consumption + int ef_construction = 200; // Controls index search speed/build speed tradeoff + int num_threads = 16; // Number of threads for operations with index + + int top_k=1; + + int iteration=3; + float correct = 0; + + // Generate random data + std::mt19937 rng; + rng.seed(47); + std::uniform_real_distribution<> distrib_real; + float* data_fp32 = (float* )aligned_alloc(64,true_dim * max_elements*sizeof(float)); + float* query_fp32 = (float* )aligned_alloc(64,true_dim * nq*sizeof(float)); + float* data_bf16 = (float* )aligned_alloc(64,dim * max_elements*sizeof(float)); + + const char* amx_bf16_env = std::getenv("BF16_AMX"); + bool amx_enable_bf16 = amx_bf16_env ? std::stoi(amx_bf16_env) : false; + + const char* amx_fp32_env = std::getenv("FP32_AMX"); + bool amx_enable_fp32 = amx_fp32_env ? std::stoi(amx_fp32_env) : false; + + const char* def_fp32_env = std::getenv("FP32_DEF"); + bool def_enable_fp32 = def_fp32_env ? std::stoi(def_fp32_env) : false; + + const char* avx512_bf16_env = std::getenv("BF16_AVX512"); + bool avx512_enable_bf16 = avx512_bf16_env ? std::stoi(avx512_bf16_env) : false; + + + + uint16_t *bf_data = (uint16_t* ) data_bf16; + for (size_t i = 0; i < true_dim * max_elements; i++) { + float tmp = (distrib_real(rng)); + data_fp32[i] = tmp; + query_fp32[i] = tmp; +/* uint32_t *int32_data =(uint32_t *) &tmp; + bf_data[i]=*int32_data >> 16; */ + } + + hnswlib::InnerProductSpace space_fp32(true_dim); + hnswlib::HierarchicalNSW* alg_hnsw_fp32 = new hnswlib::HierarchicalNSW(&space_fp32, max_elements, M, ef_construction); + // Add data to index + ParallelFor_Build(0, max_elements, 43, [&](size_t row, size_t threadId) { + alg_hnsw_fp32->addPoint((void*)(data_fp32 + true_dim * row), row); + }); + + // Query the elements for themselves and measure recall + + // Bf16InnerProductSpace space_bf16(dim); + // hnswlib::HierarchicalNSW* alg_hnsw_bf16 = new hnswlib::HierarchicalNSW(&space_bf16, max_elements, M, ef_construction); + // ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + // alg_hnsw_bf16->addPoint((void*)(data_bf16 + dim * row), row); + // }); + + + + std::chrono::_V2::system_clock::time_point start_scalar_fp32,end_scalar_fp32, + start_scalar_bf16,end_scalar_bf16, + start_AMX_fp32,end_AMX_fp32, + start_AMX_bf16,end_AMX_bf16; + // if(def_enable_fp32){ + // std::cout << "Default FP32 search start." 
<<"\n"; + // start_scalar_fp32 = std::chrono::high_resolution_clock::now(); + // for(int i=0;i duration_scalar_fp32 = end_scalar_fp32 - start_scalar_fp32; + std::chrono::duration duration_scalar_bf16 = end_scalar_bf16 - start_scalar_bf16; + std::chrono::duration duration_AMX_fp32 = end_AMX_fp32 - start_AMX_fp32; + std::chrono::duration duration_AMX_bf16 = end_AMX_bf16 - start_AMX_bf16; + + + if(def_enable_fp32) std::cout << "Time taken for default fp32:" << duration_scalar_fp32.count()/iteration/nq< class HierarchicalNSW : public AlgorithmInterface { public: @@ -36,6 +39,7 @@ class HierarchicalNSW : public AlgorithmInterface { std::unique_ptr visited_list_pool_{nullptr}; + // Locks operations with element by label value mutable std::vector label_op_locks_; @@ -54,6 +58,10 @@ class HierarchicalNSW : public AlgorithmInterface { size_t data_size_{0}; DISTFUNC fstdistfunc_; + +#if defined(USE_AMX) + AMXDISTFUNC amxdistfunc_; +#endif void *dist_func_param_{nullptr}; mutable std::mutex label_lookup_lock; // lock for label_lookup_ @@ -101,6 +109,9 @@ class HierarchicalNSW : public AlgorithmInterface { num_deleted_ = 0; data_size_ = s->get_data_size(); fstdistfunc_ = s->get_dist_func(); +#if defined(USE_AMX) + amxdistfunc_ = s->get_amx_dist_func(); +#endif dist_func_param_ = s->get_dist_func_param(); if ( M <= 10000 ) { M_ = M; @@ -374,6 +385,22 @@ class HierarchicalNSW : public AlgorithmInterface { _mm_prefetch((char *) (data + 2), _MM_HINT_T0); #endif +#ifdef USE_AMX + int count=0; + size_t dim=(*(size_t *)dist_func_param_); + for (size_t j = 1; j <= size; j++) { + int candidate_id = *(data + j); + if (!(visited_array[candidate_id] == visited_array_tag)) { + char *currObj1 = (getDataByInternalId(candidate_id)); + mydata[count++]=currObj1; + } + } + if(size>0 ){ + memset(res,0,sizeof(float)*count); + amxdistfunc_((const void**)mydata,(const void*)data_point,(const void*)&dim,count,1,(float*)res); + } + count=0; +#endif for (size_t j = 1; j <= size; j++) { int candidate_id = *(data + j); // if (candidate_id == 0) continue; @@ -386,7 +413,11 @@ class HierarchicalNSW : public AlgorithmInterface { visited_array[candidate_id] = visited_array_tag; char *currObj1 = (getDataByInternalId(candidate_id)); +#ifdef USE_AMX + dist_t dist=((dist_t*)res)[count++]; +#else dist_t dist = fstdistfunc_(data_point, currObj1, dist_func_param_); +#endif bool flag_consider_candidate; if (!bare_bone_search && stop_condition) { @@ -1269,6 +1300,8 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue> searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + + std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1287,6 +1320,36 @@ class HierarchicalNSW : public AlgorithmInterface { metric_distance_computations+=size; tableint *datal = (tableint *) (data + 1); + +#if defined(USE_AMX) + enable_amx(); + + size_t dim=(size_t)(*(size_t *)dist_func_param_); + if(mydata==NULL){ + mydata=(void**)malloc(sizeof(dist_t*)*maxM0_); + res=(float*) malloc(maxM0_*sizeof(float)); + memset(res,0,maxM0_*sizeof(float)); + //printf("We are 1443lines\n"); + } + + for (int i= 0; i max_elements_) @@ -1299,6 +1362,7 @@ class HierarchicalNSW : public AlgorithmInterface { changed = true; } } +#endif } } diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 7ccfbba5..00065b9d 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -20,6 +20,12 @@ #endif #endif +// #ifdef __AMX_BF16__ +// #define USE_AMX +// #endif + +// #define USE_AMX + #if defined(USE_AVX) || 
defined(USE_SSE) #ifdef _MSC_VER #include @@ -114,6 +120,14 @@ static bool AVX512Capable() { } return HW_AVX512F && avx512Supported; } + +static bool AMXCapable() +{ + unsigned int eax, ebx, ecx, edx; + if (!__get_cpuid_count(7, 0, &eax, &ebx, &ecx, &edx)) + return false; + return (edx & (1 << 24)) && (edx & (1 << 25)); // Check for AMX-TILE and AMX-BF16 +} #endif #include @@ -170,6 +184,12 @@ static void readBinaryPOD(std::istream &in, T &podRef) { template using DISTFUNC = MTYPE(*)(const void *, const void *, const void *); +#if defined(USE_AMX) +template +using AMXDISTFUNC = MTYPE(*)(const void **, const void *, const void *, size_t, size_t,float*); +#endif + + template class SpaceInterface { public: @@ -178,6 +198,12 @@ class SpaceInterface { virtual DISTFUNC get_dist_func() = 0; +#if defined(USE_AMX) + virtual AMXDISTFUNC get_amx_dist_func() { + return 0; // TODO + }; +#endif + virtual void *get_dist_func_param() = 0; virtual ~SpaceInterface() {} diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index 0e6834c1..a6172698 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -1,8 +1,49 @@ #pragma once #include "hnswlib.h" - +#include +#include +#include +#include namespace hnswlib { +#if defined(USE_AMX) + +#define u64 unsigned long long +#define u8 unsigned char +#define u16 unsigned short int + +#define XFEATURE_XTILECFG 17 +#define XFEATURE_XTILEDATA 18 +#define XFEATURE_MASK_XTILECFG (1 << XFEATURE_XTILECFG) +#define XFEATURE_MASK_XTILEDATA (1 << XFEATURE_XTILEDATA) +#define XFEATURE_MASK_XTILE (XFEATURE_MASK_XTILECFG | XFEATURE_MASK_XTILEDATA) +#define ARCH_GET_XCOMP_PERM 0x1022 +#define ARCH_REQ_XCOMP_PERM 0x1023 + +int enable_amx() { + unsigned long bitmask = 0; + long status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status) { + std::cout << "SYS_arch_prctl(READ) error" << std::endl; + return 0; + } + if (bitmask & XFEATURE_MASK_XTILEDATA) { + return 1; + } + status = syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA); + if (0 != status) { + std::cout << "SYS_arch_prctl(WRITE) error" << std::endl; + return 0; + } + status = syscall(SYS_arch_prctl, ARCH_GET_XCOMP_PERM, &bitmask); + if (0 != status || !(bitmask & XFEATURE_MASK_XTILEDATA)) { + std::cout << "SYS_arch_prctl(READ) error" << std::endl; + return 0; + } + return 1; +} +#endif + static float InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { size_t qty = *((size_t *) qty_ptr); @@ -18,6 +59,8 @@ InnerProductDistance(const void *pVect1, const void *pVect2, const void *qty_ptr return 1.0f - InnerProduct(pVect1, pVect2, qty_ptr); } + + #if defined(USE_AVX) // Favor using AVX if available. 
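// Usage sketch (illustrative only, not one of the hunks in this patch): enable_amx() above only
// asks the kernel for XTILEDATA permission via ARCH_REQ_XCOMP_PERM (available on Linux 5.16+);
// callers still have to load a 64-byte tile configuration before the first tile instruction and
// release tile state afterwards. Assuming the AMXCapable() helper this series adds to hnswlib.h,
// a minimal call sequence could look like:
//
//     if (AMXCapable() && enable_amx()) {        // CPUID feature check + kernel permission
//         _tile_loadconfig(cfg);                 // cfg: caller-provided 64-byte tile palette
//         // _tile_loadd / _tile_dpbf16ps / _tile_stored ...
//         _tile_release();                       // drop this thread's tile state
//     }
//
// On current kernels the permission, once granted, applies process-wide, so repeated
// enable_amx() calls are redundant but harmless; the tile configuration, by contrast, is
// per-thread register state, which is why the batch kernels below cache it behind
// thread_local flags.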
@@ -250,6 +293,7 @@ InnerProductDistanceSIMD16ExtAVX(const void *pVect1v, const void *pVect2v, const #endif + #if defined(USE_SSE) static float @@ -304,6 +348,7 @@ InnerProductDistanceSIMD16ExtSSE(const void *pVect1v, const void *pVect2v, const #endif + #if defined(USE_SSE) || defined(USE_AVX) || defined(USE_AVX512) static DISTFUNC InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; @@ -339,8 +384,236 @@ InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, } #endif +#if defined(USE_AMX) +float amx_inner_product_matrix_fp32( char **floatLibraryMatrix, char *floatQueryMatrix, uint64_t dims,uint64_t batchSizeA, + uint64_t batchSizeB, float *results){ + int DIM=32; + int blockCount=(dims)/DIM; + int tailCount=dims%DIM; + unsigned char maBf16[1024] __attribute__((aligned(64))); + unsigned char mbBf16[1024] __attribute__((aligned(64))); + + //thread_local unsigned char *mbBf16=NULL; + thread_local char *preQuery=NULL; + thread_local char cfg[64]={0}; + thread_local bool init_mem=false; + + if(!init_mem){ + +/* if(!mbBf16){ + mbBf16 =(unsigned char *)aligned_alloc(64,sizeof(char)*dims*4); + } */ + cfg[0]=1; + cfg[16]=DIM*2; + cfg[48] = 16; // row->M + // matrix B need a layout rearragement + cfg[16+1*2] = batchSizeB*2*2; // col = N*4 + cfg[48+1] = DIM/2; // row = K/4 + + cfg[22]=DIM*2; + cfg[51] = 16; // row->M + // matrix B need a layout rearragement + cfg[24] = batchSizeB*2*2; // col = N*4 + cfg[52] = DIM/2; // row = K/4 + + cfg[26]=DIM*2; + cfg[53] = 16; // row->M + // matrix B need a layout rearragement + cfg[28] = batchSizeB*2*2; // col = N*4 + cfg[54] = DIM/2; // row = K/4 + + cfg[16+2*2] = (batchSizeB*4); // N*sizeof(int32) + cfg[48+2] = 16; + + init_mem = true; + + _tile_loadconfig((void *)cfg); + } + __m512i high_bits; + __m512i low_bits; + __m512i all_bits; + int i=0; + for( i = 0; i < blockCount/3; i+=1) { + int index=3*i; + for(int j = 0; j < batchSizeA; j++) { + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatLibraryMatrix[j] + index * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatLibraryMatrix[j] + index * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(maBf16 + j * DIM * 2 , all_bits); + } + + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatQueryMatrix + index * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatQueryMatrix + index * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(mbBf16 , all_bits); + + _tile_loadd(0,maBf16, 64); + _tile_loadd(1,mbBf16 , 4); + + for(int j = 0; j < batchSizeA; j++) { + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatLibraryMatrix[j] + (index+1) * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatLibraryMatrix[j] + (index+1) * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(maBf16 + j * DIM * 2 , all_bits); + } + + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatQueryMatrix + (index+1) * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatQueryMatrix + (index+1) * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(mbBf16 , all_bits); + + _tile_loadd(3,maBf16, 64); + _tile_loadd(4,mbBf16 , 4); + + for(int j = 0; j < batchSizeA; j++) { + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatLibraryMatrix[j] + (index+2) * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatLibraryMatrix[j] + (index+2) * DIM * 4 + 
64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(maBf16 + j * DIM * 2 , all_bits); + } + + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatQueryMatrix + (index+2) * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatQueryMatrix + (index+2) * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(mbBf16 , all_bits); + + _tile_loadd(5,maBf16, 64); + _tile_loadd(6,mbBf16 , 4); + + _tile_dpbf16ps(2,0,1); + _tile_dpbf16ps(2,3,4); + _tile_dpbf16ps(2,5,6); + } + switch(blockCount%3){ + case 0: break; + case 1: + for(int j = 0; j < batchSizeA; j++) { + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatLibraryMatrix[j] + 3 * i * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatLibraryMatrix[j] + 3 * i * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(maBf16 + j * DIM * 2 , all_bits); + } + + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatQueryMatrix + 3*i * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatQueryMatrix + 3*i * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(mbBf16 , all_bits); + + _tile_loadd(0,maBf16, 64); + _tile_loadd(1,mbBf16, 4); + _tile_dpbf16ps(2,0,1); + break; + + case 2: + for(int j = 0; j < batchSizeA; j++) { + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatLibraryMatrix[j] + 3*i * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatLibraryMatrix[j] + 3*i * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(maBf16 + j * DIM * 2 , all_bits); + } + + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatQueryMatrix + 3*i * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatQueryMatrix + 3*i * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(mbBf16 , all_bits); + + _tile_loadd(0,maBf16, 64); + _tile_loadd(1,mbBf16, 4); + + for(int j = 0; j < batchSizeA; j++) { + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatLibraryMatrix[j] + (3*i+1) * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatLibraryMatrix[j] + (3*i+1) * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(maBf16 + j * DIM * 2 , all_bits); + } + + high_bits = _mm512_srli_epi32(_mm512_loadu_si512(floatQueryMatrix + (3*i+1) * DIM * 4),16); + low_bits = _mm512_loadu_si512(floatQueryMatrix + (3*i+1) * DIM * 4 + 64); + all_bits= _mm512_mask_blend_epi16(0x55555555,low_bits , high_bits); + _mm512_store_si512(mbBf16 , all_bits); + + _tile_loadd(3,maBf16, 64); + _tile_loadd(4,mbBf16, 4); + _tile_dpbf16ps(2,0,1); + _tile_dpbf16ps(2,3,4); + break; + } + + + _tile_stored(2, results, batchSizeB*2*2); + _tile_zero(2); + + if (tailCount != 0) { + for (int k = 0; k < batchSizeA; k++) { + for (int l = 0; l < batchSizeB; l++) { + __m512 result_vec = _mm512_setzero_ps(); + for (int i = 0; i < tailCount; i += 16) { + __m512 lib_vec = _mm512_loadu_ps((float *)(floatLibraryMatrix[k]) + DIM * blockCount + i); + __m512 query_vec = _mm512_loadu_ps((float *)(floatQueryMatrix + DIM * blockCount + i)); + result_vec = _mm512_fmadd_ps(lib_vec, query_vec, result_vec); + } + results[k * batchSizeB + l] += _mm512_reduce_add_ps(result_vec); + } + } + } + + return 0; +} + +static float InnerProductBatchExtAMX(const void **pVect1v, const void *pVect2v, const void *qty_ptr, size_t nSize, size_t mSize, float * results_amx){ + + unsigned int dims= 
*(unsigned int*)qty_ptr; + char **floatLibraryMatrix = (char**) pVect1v; + char *floatQueryMatrix = (char*) pVect2v; + + + int batchSizeA = 16, batchSizeB = 16; + int batchCountA = (nSize - 1) / batchSizeA + 1; + int batchCountB = (mSize - 1) / batchSizeB + 1; + + int lastBatchSizeA = (nSize % batchSizeA == 0) ? batchSizeA : nSize % batchSizeA; + int lastBatchSizeB = (mSize % batchSizeB == 0) ? batchSizeB : mSize % batchSizeB; + + int offsetA = batchSizeA * dims * 4; + int offsetB = batchSizeB * dims * 4; + + float *results_ptr = results_amx; + + for (int i = 0; i < batchCountA; i++) { + int currentBatchSizeA = (i == batchCountA - 1) ? lastBatchSizeA : batchSizeA; + char **currentLibraryMatrixPtr = floatLibraryMatrix + i * 16; + + for (int j = 0; j < batchCountB; j++) { + int currentBatchSizeB = (j == batchCountB - 1) ? lastBatchSizeB : batchSizeB; + char *currentQueryMatrixPtr = floatQueryMatrix + j * offsetB; + + amx_inner_product_matrix_fp32(currentLibraryMatrixPtr, currentQueryMatrixPtr, dims, currentBatchSizeA, currentBatchSizeB, results_ptr); + + results_ptr += currentBatchSizeB * currentBatchSizeA; + } + } + + return 0; +} +static float +InnerProductDistanceBatchExtAMX(const void **pVect1v, const void *pVect2v, const void *qty_ptr, size_t nSize, size_t mSize, float * results_amx) { + + InnerProductBatchExtAMX(pVect1v, pVect2v, qty_ptr,nSize,mSize,results_amx); + for(int i=0;i InnerProductBatchExt = InnerProductBatchExtAMX; +static AMXDISTFUNC InnerProductDistanceBatchExt = InnerProductDistanceBatchExtAMX; +#endif + class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; +#ifdef USE_AMX + AMXDISTFUNC amxdistfunc_; +#endif + size_t data_size_; size_t dim_; @@ -369,6 +642,8 @@ class InnerProductSpace : public SpaceInterface { } #endif + + if (dim % 16 == 0) fstdistfunc_ = InnerProductDistanceSIMD16Ext; else if (dim % 4 == 0) @@ -377,6 +652,14 @@ class InnerProductSpace : public SpaceInterface { fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; else if (dim > 4) fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif +#if defined(USE_AMX) + if (AMXCapable()) { + InnerProductBatchExt = InnerProductBatchExtAMX; + InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMX; + } + + amxdistfunc_ = InnerProductDistanceBatchExt; #endif dim_ = dim; data_size_ = dim * sizeof(float); @@ -389,6 +672,11 @@ class InnerProductSpace : public SpaceInterface { DISTFUNC get_dist_func() { return fstdistfunc_; } +#if defined(USE_AMX) + AMXDISTFUNC get_amx_dist_func(){ + return amxdistfunc_; + } +#endif void *get_dist_func_param() { return &dim_; @@ -397,4 +685,110 @@ class InnerProductSpace : public SpaceInterface { ~InnerProductSpace() {} }; +static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { + float result[16] = {0.0f}; // 用于存储中间结果 + + uint16_t *x = (uint16_t *)a; + uint16_t *y = (uint16_t *)b; + __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 + + size_t dim = * (size_t*) qty_ptr ; + + size_t i = 0; + // 每次处理32个元素(16个__bf16元素在__m512bh寄存器中存储为32个uint16_t) + for (; i + 32 <= dim; i += 32) { + // 加载32个uint16_t到__m512i类型的临时寄存器 + __m512i temp_x = _mm512_loadu_si512(x + i); + __m512i temp_y = _mm512_loadu_si512(y + i); + + // 强制转换为__m512bh类型 + __m512bh v1_f16 = reinterpret_cast<__m512bh&>(temp_x); + __m512bh v2_f16 = reinterpret_cast<__m512bh&>(temp_y); + + // 计算BF16的点积,并将结果累加到vr_f32 + vr_f32 = _mm512_dpbf16_ps(vr_f32, v1_f16, v2_f16); + } + + // 将vr_f32寄存器的值存入result数组 + _mm512_storeu_ps(result, vr_f32); + + // 累加result数组的所有元素,获得最终的点积结果 + 
float dot_product = 0.0f; + for (int j = 0; j < 16; j++) { + dot_product += result[j]; + } + + // 处理剩余的元素(小于32的部分) +/* for (; i < dim; i++) { + float x_val = bf162float(x[i]); + float y_val = bf162float(y[i]); + dot_product += x_val * y_val; + } */ + //printf("%d %f ",dim,dot_product); + return 1-dot_product; +} +class Bf16InnerProductSpace : public hnswlib::SpaceInterface { + DISTFUNC fstdistfunc_; + size_t data_size_; + size_t dim_; + public: + Bf16InnerProductSpace(size_t dim) { + fstdistfunc_ = InnerProductDistanceBf16; +#if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) + #if defined(USE_AVX512) + if (AVX512Capable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + } else if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #elif defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; + InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + } + #endif + #if defined(USE_AVX) + if (AVXCapable()) { + InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; + InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; + } + #endif + + + + if (dim % 16 == 0) + fstdistfunc_ = InnerProductDistanceSIMD16Ext; + else if (dim % 4 == 0) + fstdistfunc_ = InnerProductDistanceSIMD4Ext; + else if (dim > 16) + fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; + else if (dim > 4) + fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; +#endif +#if defined(USE_AMX) + if (AMXCapable()) { + InnerProductBatchExt = InnerProductBatchExtAMX; + InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMX; + } + + amxdistfunc_ = InnerProductDistanceBatchExt; +#endif + dim_ = dim * 2; + data_size_ = dim * 2 * sizeof(uint16_t); + } + + size_t get_data_size() { + return data_size_; + } + + DISTFUNC get_dist_func() { + return fstdistfunc_; + } + void *get_dist_func_param() { + return &dim_; + } + ~Bf16InnerProductSpace() {} +}; } // namespace hnswlib diff --git a/hnswlib/space_l2.h b/hnswlib/space_l2.h index 834d19f7..f95ccab3 100644 --- a/hnswlib/space_l2.h +++ b/hnswlib/space_l2.h @@ -244,6 +244,10 @@ class L2Space : public SpaceInterface { DISTFUNC get_dist_func() { return fstdistfunc_; } +// #if defined(USE_AMX) +// AMXDISTFUNC get_amx_dist_func() { +// } +// #endif void *get_dist_func_param() { return &dim_; diff --git a/hnswlib/stop_condition.h b/hnswlib/stop_condition.h index acc80ebe..f94fb4d0 100644 --- a/hnswlib/stop_condition.h +++ b/hnswlib/stop_condition.h @@ -57,7 +57,10 @@ class MultiVectorL2Space : public BaseMultiVectorSpace { DISTFUNC get_dist_func() override { return fstdistfunc_; } - +// #if defined(USE_AMX) +// AMXDISTFUNC get_amx_dist_func() { +// } +// #endif void *get_dist_func_param() override { return &dim_; } From 7e6fb3095a66f06d4f1df1e097625eaca017c715 Mon Sep 17 00:00:00 2001 From: ruclz Date: Tue, 21 Jan 2025 07:12:58 +0000 Subject: [PATCH 02/14] AMX enable for FP32 innerproduct --- CMakeLists.txt | 2 +- hnswlib/space_ip.h | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cdf0c4f..a1b5d55f 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -72,7 +72,7 @@ if(AMXEnable) if(HAS_AMX_SUPPORT) message(STATUS "Compiler supports AMX") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP} -mamx-bf16") - #add_compile_definitions(USE_AMX=1) + add_compile_definitions(USE_AMX=1) else() message(STATUS "Compiler does 
NOT support AMX") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS_BACKUP}") diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index a6172698..1606b395 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -729,6 +729,9 @@ static float InnerProductDistanceBf16(const void* a, const void* b, const void * } class Bf16InnerProductSpace : public hnswlib::SpaceInterface { DISTFUNC fstdistfunc_; +#ifdef USE_AMX + AMXDISTFUNC amxdistfunc_; +#endif size_t data_size_; size_t dim_; public: From f2723ad1cf336fe64ef939927f51704352532705 Mon Sep 17 00:00:00 2001 From: ruclz Date: Tue, 21 Jan 2025 07:34:33 +0000 Subject: [PATCH 03/14] AMX enable for FP32 innerproduct --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index a1b5d55f..80c03afd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-O0 -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) + SET( CMAKE_CXX_FLAGS "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() From 59fa6513834d9a05bc94c943e1917d677d9e25f7 Mon Sep 17 00:00:00 2001 From: ruclz Date: Fri, 21 Feb 2025 11:00:35 +0800 Subject: [PATCH 04/14] Add bf16 --- CMakeLists.txt | 4 +- examples/cpp/example_mt_search_bf16.cpp | 128 ++++++-------- hnswlib/hnswlib.h | 2 +- hnswlib/space_ip.h | 217 ++++++++++++++++++++---- 4 files changed, 239 insertions(+), 112 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 80c03afd..f97b55ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -6,7 +6,7 @@ project(hnswlib include(GNUInstallDirs) include(CheckCXXCompilerFlag) -SET(AMXEnable true) +SET(AMXEnable false) add_library(hnswlib INTERFACE) add_library(hnswlib::hnswlib ALIAS hnswlib) @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-Ofast -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0" ) + SET( CMAKE_CXX_FLAGS "-O0 -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp index e00306c4..bf15dfc2 100644 --- a/examples/cpp/example_mt_search_bf16.cpp +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -195,30 +195,6 @@ class Int8InnerProductSpace : public hnswlib::SpaceInterface { } ~Int8InnerProductSpace() {} }; -class Bf16InnerProductSpace : public hnswlib::SpaceInterface { - DISTFUNC fstdistfunc_; - size_t data_size_; - size_t dim_; - public: - Bf16InnerProductSpace(size_t dim) { - fstdistfunc_ = vector_dot_product_bf16; - dim_ = dim * 2; - data_size_ = dim * 2 * sizeof(uint16_t); - } - - size_t get_data_size() { - return data_size_; - } - - DISTFUNC get_dist_func() { - return fstdistfunc_; - } - void *get_dist_func_param() { - return &dim_; - } - ~Bf16InnerProductSpace() {} -}; - void setThreadAffinity(std::thread::native_handle_type handle, size_t cpuId) { cpu_set_t cpuset; CPU_ZERO(&cpuset); @@ -431,27 +407,27 @@ int call_AMX_fp32(hnswlib::HierarchicalNSW* 
alg_hnsw,hnswlib::InnerProduc return 0; } -// int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,Bf16InnerProductSpace & space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ -// //init_onednn(); -// std::vector neighbors(max_elements); -// ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { -// std::priority_queue> result = alg_hnsw->searchKnnAMX_bf16(data + dim * row, 1); -// hnswlib::labeltype label = result.top().second; -// neighbors[row] = label; -// }); -// float correct = 0; -// for (int i = 0; i < max_elements; i++) { -// hnswlib::labeltype label = neighbors[i]; -// if (label == i) correct++; -// } -// float recall = correct / max_elements; -// std::cout << "Recall: " << recall << "\n"; -// return 0; -// } +int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::Bf16InnerProductSpace & space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ + //init_onednn(); + std::vector neighbors(max_elements); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); + hnswlib::labeltype label = result.top().second; + neighbors[row] = label; + }); + float correct = 0; + for (int i = 0; i < max_elements; i++) { + hnswlib::labeltype label = neighbors[i]; + if (label == i) correct++; + } + float recall = correct / max_elements; + std::cout << "Recall: " << recall << "\n"; + return 0; +} int main() { int true_dim=1024; int dim = true_dim/2; // Dimension of the elements - size_t max_elements = 100*1024; // Maximum number of elements, should be known beforehand + size_t max_elements = 10*1024; // Maximum number of elements, should be known beforehand int M = 32; // Tightly connected with internal dimensionality of the data size_t nq = max_elements; // strongly affects the memory consumption @@ -490,24 +466,24 @@ int main() { float tmp = (distrib_real(rng)); data_fp32[i] = tmp; query_fp32[i] = tmp; -/* uint32_t *int32_data =(uint32_t *) &tmp; - bf_data[i]=*int32_data >> 16; */ + uint32_t *int32_data =(uint32_t *) &tmp; + bf_data[i]=*int32_data >> 16; } hnswlib::InnerProductSpace space_fp32(true_dim); hnswlib::HierarchicalNSW* alg_hnsw_fp32 = new hnswlib::HierarchicalNSW(&space_fp32, max_elements, M, ef_construction); // Add data to index - ParallelFor_Build(0, max_elements, 43, [&](size_t row, size_t threadId) { + ParallelFor_Build(0, max_elements, num_threads, [&](size_t row, size_t threadId) { alg_hnsw_fp32->addPoint((void*)(data_fp32 + true_dim * row), row); }); // Query the elements for themselves and measure recall - // Bf16InnerProductSpace space_bf16(dim); - // hnswlib::HierarchicalNSW* alg_hnsw_bf16 = new hnswlib::HierarchicalNSW(&space_bf16, max_elements, M, ef_construction); - // ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { - // alg_hnsw_bf16->addPoint((void*)(data_bf16 + dim * row), row); - // }); + hnswlib::Bf16InnerProductSpace space_bf16(true_dim); + hnswlib::HierarchicalNSW* alg_hnsw_bf16 = new hnswlib::HierarchicalNSW(&space_bf16, max_elements, M, ef_construction); + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + alg_hnsw_bf16->addPoint((void*)(bf_data + true_dim * row), row); + }); @@ -515,15 +491,15 @@ int main() { start_scalar_bf16,end_scalar_bf16, start_AMX_fp32,end_AMX_fp32, start_AMX_bf16,end_AMX_bf16; - // if(def_enable_fp32){ - // std::cout << "Default FP32 search start." 
<<"\n"; - // start_scalar_fp32 = std::chrono::high_resolution_clock::now(); - // for(int i=0;i duration_scalar_fp32 = end_scalar_fp32 - start_scalar_fp32; std::chrono::duration duration_scalar_bf16 = end_scalar_bf16 - start_scalar_bf16; diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 00065b9d..95ebab1c 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -24,7 +24,7 @@ // #define USE_AMX // #endif -// #define USE_AMX +//#define USE_AMX #if defined(USE_AVX) || defined(USE_SSE) #ifdef _MSC_VER diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index 1606b395..91a1ed6c 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -595,6 +595,107 @@ static float InnerProductBatchExtAMX(const void **pVect1v, const void *pVect2v, return 0; } + +float amx_inner_product_matrix_bf16( char **floatLibraryMatrix, char *floatQueryMatrix, uint64_t dims,uint64_t batchSizeA, + uint64_t batchSizeB, float *results){ + int DIM=32; + int blockDim = 96; + int blockCount=((dims))/blockDim; + size_t tailCount=dims%DIM; + int tailBlock=dims%blockDim; + + thread_local char cfg[64]={0}; + thread_local bool init_mem=false; + + unsigned char ma1Bf16[1024] __attribute__((aligned(64))); + unsigned char ma2Bf16[1024] __attribute__((aligned(64))); + unsigned char ma3Bf16[1024] __attribute__((aligned(64))); + + if(!init_mem){ + cfg[0]=1; + cfg[16]=DIM*2; + cfg[48] = 16; // row->M + // matrix B need a layout rearragement + cfg[16+1*2] = batchSizeB*2*2; // col = N*4 + cfg[48+1] = DIM/2; // row = K/4 + + cfg[22]=DIM*2; + cfg[51] = 16; // row->M + // matrix B need a layout rearragement + cfg[24] = batchSizeB*2*2; // col = N*4 + cfg[52] = DIM/2; // row = K/4 + + cfg[26]= DIM*2; + cfg[53] = 16; // row->M + // matrix B need a layout rearragement + cfg[28] = batchSizeB*2*2; // col = N*4 + cfg[54] = DIM/2; // row = K/4 + + cfg[16+2*2] = (batchSizeB*4); // N*sizeof(int32) + cfg[48+2] = 16; + init_mem = true; + + _tile_loadconfig((void *)cfg); + } + //memset(maBf16,0,16*DIM*2); + + int i=0; + for(int i=0;i= DIM){ + for(int i=0;i InnerProductBatchExt = InnerProductBatchExtAMX; static AMXDISTFUNC InnerProductDistanceBatchExt = InnerProductDistanceBatchExtAMX; #endif @@ -685,7 +827,28 @@ class InnerProductSpace : public SpaceInterface { ~InnerProductSpace() {} }; +float bf162float(uint16_t data) { + int t = (data<<16); + auto a= *reinterpret_cast(&t); + return a; +} static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { + uint16_t *x = (uint16_t *)a; + uint16_t *y = (uint16_t *)b; + // __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 + + size_t dim = * (size_t*) qty_ptr; + + float dot_product = 0.0f; + + for (int i=0; i < dim; i++) { + float x_val = bf162float(x[i]); + float y_val = bf162float(y[i]); + dot_product += x_val * y_val; + } + return 1-dot_product; +} +static float InnerProductDistanceBf16AVX512(const void* a, const void* b, const void *qty_ptr) { float result[16] = {0.0f}; // 用于存储中间结果 uint16_t *x = (uint16_t *)a; @@ -719,13 +882,13 @@ static float InnerProductDistanceBf16(const void* a, const void* b, const void * } // 处理剩余的元素(小于32的部分) -/* for (; i < dim; i++) { + for (; i < dim; i++) { float x_val = bf162float(x[i]); float y_val = bf162float(y[i]); dot_product += x_val * y_val; - } */ + } //printf("%d %f ",dim,dot_product); - return 1-dot_product; + return 1 - dot_product; } class Bf16InnerProductSpace : public hnswlib::SpaceInterface { DISTFUNC fstdistfunc_; @@ -740,46 +903,28 @@ class Bf16InnerProductSpace : public hnswlib::SpaceInterface { #if defined(USE_AVX) || 
defined(USE_SSE) || defined(USE_AVX512) #if defined(USE_AVX512) if (AVX512Capable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX512; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX512; + InnerProductSIMD16Ext = InnerProductDistanceBf16AVX512; + InnerProductDistanceSIMD16Ext = InnerProductDistanceBf16AVX512; } else if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; - } - #elif defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductSIMD16ExtAVX; - InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtAVX; + InnerProductSIMD16Ext = InnerProductDistanceBf16; + InnerProductDistanceSIMD16Ext = InnerProductDistanceBf16; } + #else + InnerProductSIMD16Ext = InnerProductDistanceBf16; + InnerProductDistanceSIMD16Ext = InnerProductDistanceBf16; #endif - #if defined(USE_AVX) - if (AVXCapable()) { - InnerProductSIMD4Ext = InnerProductSIMD4ExtAVX; - InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtAVX; - } - #endif - - - - if (dim % 16 == 0) - fstdistfunc_ = InnerProductDistanceSIMD16Ext; - else if (dim % 4 == 0) - fstdistfunc_ = InnerProductDistanceSIMD4Ext; - else if (dim > 16) - fstdistfunc_ = InnerProductDistanceSIMD16ExtResiduals; - else if (dim > 4) - fstdistfunc_ = InnerProductDistanceSIMD4ExtResiduals; + fstdistfunc_=InnerProductDistanceSIMD16Ext; #endif #if defined(USE_AMX) if (AMXCapable()) { - InnerProductBatchExt = InnerProductBatchExtAMX; - InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMX; + InnerProductBatchExt = InnerProductBatchExtAMXBF16; + InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMXBF16; } amxdistfunc_ = InnerProductDistanceBatchExt; #endif - dim_ = dim * 2; - data_size_ = dim * 2 * sizeof(uint16_t); + dim_ = dim ; + data_size_ = dim * sizeof(uint16_t); } size_t get_data_size() { @@ -789,6 +934,12 @@ class Bf16InnerProductSpace : public hnswlib::SpaceInterface { DISTFUNC get_dist_func() { return fstdistfunc_; } + +#if defined(USE_AMX) + AMXDISTFUNC get_amx_dist_func(){ + return amxdistfunc_; + } +#endif void *get_dist_func_param() { return &dim_; } From d176f7a4027d4bd6e81c3a03c69aa6fddd0d5aed Mon Sep 17 00:00:00 2001 From: ruclz Date: Wed, 5 Mar 2025 11:05:25 +0800 Subject: [PATCH 05/14] Support ann-bench --- hnswlib/hnswalg.h | 31 ++++++-- hnswlib/hnswlib.h | 3 +- python_bindings/bindings.cpp | 145 ++++++++++++++++++++++------------- 3 files changed, 119 insertions(+), 60 deletions(-) diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index 8af1546e..af20813f 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -15,8 +15,8 @@ typedef unsigned int tableint; typedef unsigned int linklistsizeint; -thread_local void **mydata=NULL; -thread_local void *res=NULL; +//thread_local void **mydata=NULL; +//thread_local void *res=NULL; template class HierarchicalNSW : public AlgorithmInterface { public: @@ -333,6 +333,9 @@ class HierarchicalNSW : public AlgorithmInterface { std::priority_queue, std::vector>, CompareByFirst> candidate_set; dist_t lowerBound; + + void *mydata[maxM0_]; + float res[maxM0_]; if (bare_bone_search || (!isMarkedDeleted(ep_id) && ((!isIdAllowed) || (*isIdAllowed)(getExternalLabel(ep_id))))) { char* ep_data = getDataByInternalId(ep_id); @@ -388,6 +391,8 @@ class HierarchicalNSW : public AlgorithmInterface { #ifdef USE_AMX int count=0; size_t dim=(*(size_t *)dist_func_param_); + + for (size_t j = 1; j <= size; j++) { int candidate_id = *(data + j); if (!(visited_array[candidate_id] == 
visited_array_tag)) { @@ -467,6 +472,10 @@ class HierarchicalNSW : public AlgorithmInterface { } visited_list_pool_->releaseVisitedList(vl); +// #ifdef USE_AMX +// free(mydata); +// free(res); +// #endif return top_candidates; } @@ -1302,6 +1311,8 @@ class HierarchicalNSW : public AlgorithmInterface { searchKnn(const void *query_data, size_t k, BaseFilterFunctor* isIdAllowed = nullptr) const { + void *mydata[maxM0_]; + float res[maxM0_]; std::priority_queue> result; if (cur_element_count == 0) return result; @@ -1325,13 +1336,15 @@ class HierarchicalNSW : public AlgorithmInterface { enable_amx(); size_t dim=(size_t)(*(size_t *)dist_func_param_); - if(mydata==NULL){ - mydata=(void**)malloc(sizeof(dist_t*)*maxM0_); - res=(float*) malloc(maxM0_*sizeof(float)); - memset(res,0,maxM0_*sizeof(float)); + + //if(mydata==NULL){ + // mydata=(void**)malloc(sizeof(dist_t*)*maxM0_); + // res=(float*) malloc(maxM0_*sizeof(float)); + // memset(res,0,maxM0_*sizeof(float)); //printf("We are 1443lines\n"); - } + //} + //printf("we are here\n"); for (int i= 0; i { result.push(std::pair(rez.first, getExternalLabel(rez.second))); top_candidates.pop(); } +// #ifdef USE_AMX +// free(mydata); +// free(res); +// #endif return result; } diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 95ebab1c..17e2c274 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -24,7 +24,8 @@ // #define USE_AMX // #endif -//#define USE_AMX +#define USE_AMX +#define BF16_SUPPORT #if defined(USE_AVX) || defined(USE_SSE) #ifdef _MSC_VER diff --git a/python_bindings/bindings.cpp b/python_bindings/bindings.cpp index dd09e80a..782693bd 100644 --- a/python_bindings/bindings.cpp +++ b/python_bindings/bindings.cpp @@ -159,6 +159,7 @@ class Index { hnswlib::labeltype cur_l; hnswlib::HierarchicalNSW* appr_alg; hnswlib::SpaceInterface* l2space; + //hnswlib::SpaceInterface* bf16space; Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { @@ -166,9 +167,21 @@ class Index { if (space_name == "l2") { l2space = new hnswlib::L2Space(dim); } else if (space_name == "ip") { - l2space = new hnswlib::InnerProductSpace(dim); +#if defined(BF16_SUPPORT) + //exit(0); + l2space = new hnswlib::Bf16InnerProductSpace(dim); +#else + l2space = new hnswlib::InnerProductSpace(dim); +#endif + + } else if (space_name == "cosine") { +#if defined(BF16_SUPPORT) + //exit(0); + l2space = new hnswlib::Bf16InnerProductSpace(dim); +#else l2space = new hnswlib::InnerProductSpace(dim); +#endif normalize = true; } else { throw std::runtime_error("Space name must be one of l2, ip, or cosine."); @@ -184,6 +197,7 @@ class Index { ~Index() { delete l2space; + //delete bf16space; if (appr_alg) delete appr_alg; } @@ -238,70 +252,97 @@ class Index { } + void normalize_vector(float* data, float* norm_array) { float norm = 0.0f; for (int i = 0; i < dim; i++) norm += data[i] * data[i]; norm = 1.0f / (sqrtf(norm) + 1e-30f); - for (int i = 0; i < dim; i++) - norm_array[i] = data[i] * norm; - } - +#ifdef BF16_SUPPORT - void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) { - py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); - auto buffer = items.request(); - if (num_threads <= 0) - num_threads = num_threads_default; + uint16_t * bf_data=(uint16_t*)norm_array; + for(int i=0;i> 16; + } +#else + for (int i = 0; i < dim; i++){ + norm_array[i] = data[i] * norm; + } +#endif - size_t rows, features; - get_input_array_shapes(buffer, &rows, &features); + } - if (features 
!= dim) - throw std::runtime_error("Wrong dimensionality of the vectors"); + void addItems(py::object input, py::object ids_ = py::none(), int num_threads = -1, bool replace_deleted = false) { - // avoid using threads when the number of additions is small: - if (rows <= num_threads * 4) { - num_threads = 1; - } + py::array_t < dist_t, py::array::c_style | py::array::forcecast > items(input); + auto buffer = items.request(); + if (num_threads <= 0) + num_threads = num_threads_default; - std::vector ids = get_input_ids_and_check_shapes(ids_, rows); + size_t rows, features; + get_input_array_shapes(buffer, &rows, &features); - { - int start = 0; - if (!ep_added) { - size_t id = ids.size() ? ids.at(0) : (cur_l); - float* vector_data = (float*)items.data(0); - std::vector norm_array(dim); - if (normalize) { - normalize_vector(vector_data, norm_array.data()); - vector_data = norm_array.data(); - } - appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted); - start = 1; - ep_added = true; - } + if (features != dim) + throw std::runtime_error("Wrong dimensionality of the vectors"); - py::gil_scoped_release l; - if (normalize == false) { - ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { - size_t id = ids.size() ? ids.at(row) : (cur_l + row); - appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted); - }); - } else { - std::vector norm_array(num_threads * dim); - ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { - // normalize vector: - size_t start_idx = threadId * dim; - normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); + // avoid using threads when the number of additions is small: + if (rows <= num_threads * 4) { + num_threads = 1; + } - size_t id = ids.size() ? ids.at(row) : (cur_l + row); - appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted); - }); - } - cur_l += rows; - } - } + std::vector ids = get_input_ids_and_check_shapes(ids_, rows); + + { +/*#if defined(BF16_SUPPORT) + for(int i=0;i> 16; + } + } +#endif*/ + int start = 0; + if (!ep_added) { + size_t id = ids.size() ? ids.at(0) : (cur_l); + + float* vector_data = (float*)items.data(0); + std::vector norm_array(dim); + if (normalize) { + normalize_vector(vector_data, norm_array.data()); + vector_data = norm_array.data(); + } + + + appr_alg->addPoint((void*)vector_data, (size_t)id, replace_deleted); + start = 1; + ep_added = true; + } + + py::gil_scoped_release l; + if (normalize == false) { + ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { + size_t id = ids.size() ? ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)items.data(row), (size_t)id, replace_deleted); + }); + } else { + std::vector norm_array(num_threads * dim); + ParallelFor(start, rows, num_threads, [&](size_t row, size_t threadId) { + // normalize vector: + size_t start_idx = threadId * dim; + normalize_vector((float*)items.data(row), (norm_array.data() + start_idx)); + + size_t id = ids.size() ? 
ids.at(row) : (cur_l + row); + appr_alg->addPoint((void*)(norm_array.data() + start_idx), (size_t)id, replace_deleted); + }); + } + cur_l += rows; + } + } py::object getData(py::object ids_ = py::none(), std::string return_type = "numpy") { From ce089f612cb7d4250a155e2ca83bfca0f7b0f0ba Mon Sep 17 00:00:00 2001 From: ruclz Date: Fri, 28 Mar 2025 16:39:28 +0800 Subject: [PATCH 06/14] FP32 benchmark --- examples/cpp/example_mt_search_bf16.cpp | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp index bf15dfc2..231f13b7 100644 --- a/examples/cpp/example_mt_search_bf16.cpp +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -456,8 +456,8 @@ int main() { const char* def_fp32_env = std::getenv("FP32_DEF"); bool def_enable_fp32 = def_fp32_env ? std::stoi(def_fp32_env) : false; - const char* avx512_bf16_env = std::getenv("BF16_AVX512"); - bool avx512_enable_bf16 = avx512_bf16_env ? std::stoi(avx512_bf16_env) : false; + // const char* avx512_bf16_env = std::getenv("BF16_AVX512"); + // bool avx512_enable_bf16 = avx512_bf16_env ? std::stoi(avx512_bf16_env) : false; @@ -513,15 +513,15 @@ int main() { // } - // if(amx_enable_fp32){ - // std::cout << "FP32 with AMX search start." <<"\n"; - // start_AMX_fp32 = std::chrono::high_resolution_clock::now(); - // for(int i=0;i Date: Fri, 28 Mar 2025 16:41:19 +0800 Subject: [PATCH 07/14] FP32 benchmark --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f97b55ff..cf76ac64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-O0 -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) + SET( CMAKE_CXX_FLAGS "-Ofast -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() From a2ef65923aac753a9941c2ba6346f91c33b24162 Mon Sep 17 00:00:00 2001 From: ruclz Date: Tue, 1 Apr 2025 19:14:12 +0800 Subject: [PATCH 08/14] BF16 tail data calculation --- CMakeLists.txt | 2 +- examples/cpp/example_mt_search_bf16.cpp | 12 ++++---- hnswlib/space_ip.h | 41 ++++++++++++++----------- 3 files changed, 30 insertions(+), 25 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf76ac64..9f320f49 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-Ofast -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) + SET( CMAKE_CXX_FLAGS "-O -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp index 231f13b7..2f31364e 100644 --- a/examples/cpp/example_mt_search_bf16.cpp +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -229,7 +229,7 @@ inline void ParallelFor_Build(size_t start, size_t end, size_t numThreads, Funct for (size_t threadId = 0; threadId < numThreads; ++threadId) { threads.push_back(std::thread([&, 
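The benchmark gates each code path behind an environment variable read with std::getenv and parsed with std::stoi; std::stoi throws std::invalid_argument when the variable is set to an empty or non-numeric string. A small sketch of a more defensive flag reader together with the chrono timing pattern the example uses (env_flag and the loop placeholder are illustrative, not code from the patch):

#include <chrono>
#include <cstdlib>
#include <iostream>
#include <string>

// Returns false when the variable is unset or cannot be parsed as an integer.
static bool env_flag(const char* name) {
    const char* v = std::getenv(name);
    if (v == nullptr) return false;
    try { return std::stoi(v) != 0; } catch (...) { return false; }
}

int main() {
    if (env_flag("FP32_DEF")) {   // same variable the example reads
        auto t0 = std::chrono::steady_clock::now();
        // ... run the FP32 search loop here ...
        auto t1 = std::chrono::steady_clock::now();
        std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(t1 - t0).count()
                  << " ms\n";
    }
    return 0;
}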
threadId] { - setThreadAffinity(pthread_self(), threadId); + //setThreadAffinity(pthread_self(), threadId); while (true) { size_t id = current.fetch_add(1); @@ -358,7 +358,7 @@ int call_scalar(hnswlib::HierarchicalNSW* alg_hnsw,Int8InnerProductSpace int call_scalar_fp32(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::InnerProductSpace& space,float* data,int dim, size_t max_elements,int top_k,int num_threads){ std::vector neighbors(max_elements); ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { - std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, top_k); + std::priority_queue> result = alg_hnsw->searchKnn(data + dim * row, 1); hnswlib::labeltype label = result.top().second; neighbors[row] = label; }); @@ -425,14 +425,14 @@ int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::Bf16InnerPr return 0; } int main() { - int true_dim=1024; + int true_dim=30; int dim = true_dim/2; // Dimension of the elements size_t max_elements = 10*1024; // Maximum number of elements, should be known beforehand int M = 32; // Tightly connected with internal dimensionality of the data size_t nq = max_elements; // strongly affects the memory consumption int ef_construction = 200; // Controls index search speed/build speed tradeoff - int num_threads = 16; // Number of threads for operations with index + int num_threads = 1; // Number of threads for operations with index int top_k=1; @@ -473,7 +473,7 @@ int main() { hnswlib::InnerProductSpace space_fp32(true_dim); hnswlib::HierarchicalNSW* alg_hnsw_fp32 = new hnswlib::HierarchicalNSW(&space_fp32, max_elements, M, ef_construction); // Add data to index - ParallelFor_Build(0, max_elements, num_threads, [&](size_t row, size_t threadId) { + ParallelFor(0, max_elements, num_threads, [&](size_t row, size_t threadId) { alg_hnsw_fp32->addPoint((void*)(data_fp32 + true_dim * row), row); }); @@ -495,7 +495,7 @@ int main() { std::cout << "Default FP32 search start." 
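This patch also disables the setThreadAffinity() call inside ParallelFor_Build while benchmarking. The helper's body is not shown in these hunks; on Linux a pinning helper of that shape is usually a thin wrapper over pthread_setaffinity_np, roughly as sketched below under that assumption (set_thread_affinity is an illustrative name):

#define _GNU_SOURCE
#include <pthread.h>
#include <sched.h>

// Pin the given thread to a single logical CPU; returns 0 on success.
static int set_thread_affinity(pthread_t thread, int cpu_id) {
    cpu_set_t cpus;
    CPU_ZERO(&cpus);
    CPU_SET(cpu_id, &cpus);
    return pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpus);
}

Pinning keeps each worker on one core and stabilises timings; with it disabled, the scheduler is free to migrate threads, which is worth remembering when comparing benchmark numbers across these patches.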
<<"\n"; start_scalar_fp32 = std::chrono::high_resolution_clock::now(); for(int i=0;i(&t); + return a; +} + static float InnerProduct(const void *pVect1, const void *pVect2, const void *qty_ptr) { size_t qty = *((size_t *) qty_ptr); @@ -543,16 +551,14 @@ float amx_inner_product_matrix_fp32( char **floatLibraryMatrix, char *floatQuer _tile_stored(2, results, batchSizeB*2*2); _tile_zero(2); + // printf("tailCount=%d\n",tailCount); if (tailCount != 0) { + int32_t offset= dims/DIM*DIM; for (int k = 0; k < batchSizeA; k++) { for (int l = 0; l < batchSizeB; l++) { - __m512 result_vec = _mm512_setzero_ps(); - for (int i = 0; i < tailCount; i += 16) { - __m512 lib_vec = _mm512_loadu_ps((float *)(floatLibraryMatrix[k]) + DIM * blockCount + i); - __m512 query_vec = _mm512_loadu_ps((float *)(floatQueryMatrix + DIM * blockCount + i)); - result_vec = _mm512_fmadd_ps(lib_vec, query_vec, result_vec); - } - results[k * batchSizeB + l] += _mm512_reduce_add_ps(result_vec); + for (int m = 0; m < tailCount; m += 1) { + results[k * batchSizeB + l] += (*(float *)(floatLibraryMatrix[k] + 4*(offset+m))) * (*(float *)(floatQueryMatrix + 4*(offset+m))); + } } } } @@ -680,15 +686,17 @@ float amx_inner_product_matrix_bf16( char **floatLibraryMatrix, char *floatQuer _tile_zero(2); if (tailCount != 0) { + int32_t offset= dims/DIM*DIM; for (int k = 0; k < batchSizeA; k++) { for (int l = 0; l < batchSizeB; l++) { - __m512 result_vec = _mm512_setzero_ps(); - for (int i = 0; i < tailCount; i += 16) { - __m512 lib_vec = _mm512_loadu_ps((float *)(floatLibraryMatrix[k]) + DIM * blockCount + i); - __m512 query_vec = _mm512_loadu_ps((float *)(floatQueryMatrix + DIM * blockCount + i)); - result_vec = _mm512_fmadd_ps(lib_vec, query_vec, result_vec); + for (int m = 0; m < tailCount; m += 1) { + //blockDim*blockCount+tailBlock/DIM*DIM+i + + results[k * batchSizeB + l] += bf162float(*(uint16_t *)(floatLibraryMatrix[k] + 2*(offset+m))) * bf162float(*(uint16_t *)(floatQueryMatrix + 2*(offset+m))); + // __m512 lib_vec = _mm512_loadu_ps((float *)(floatLibraryMatrix[k] + 2*(DIM * blockCount + i))); + // __m512 query_vec = _mm512_loadu_ps((float *)(floatQueryMatrix + 2*(DIM * blockCount + i))); + // result_vec = _mm512_fmadd_ps(lib_vec, query_vec, result_vec); } - results[k * batchSizeB + l] += _mm512_reduce_add_ps(result_vec); } } } @@ -827,11 +835,7 @@ class InnerProductSpace : public SpaceInterface { ~InnerProductSpace() {} }; -float bf162float(uint16_t data) { - int t = (data<<16); - auto a= *reinterpret_cast(&t); - return a; -} + static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { uint16_t *x = (uint16_t *)a; uint16_t *y = (uint16_t *)b; @@ -946,3 +950,4 @@ class Bf16InnerProductSpace : public hnswlib::SpaceInterface { ~Bf16InnerProductSpace() {} }; } // namespace hnswlib + From f33f629deea8a7d2b11c4672c755f4df27473570 Mon Sep 17 00:00:00 2001 From: ruclz Date: Wed, 2 Apr 2025 11:00:29 +0800 Subject: [PATCH 09/14] Fix bf16 bug for tail data --- CMakeLists.txt | 2 +- examples/cpp/example_mt_search_bf16.cpp | 2 +- hnswlib/hnswalg.h | 41 +++--- hnswlib/space_ip.h | 177 ++++++++++++++---------- 4 files changed, 123 insertions(+), 99 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 9f320f49..cf76ac64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-O -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) + 
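The tail handling added here follows a standard split: the tile loops consume only whole 32-element groups (DIM is 32 in the bf16 kernel), so the last dims % 32 values are folded in with a plain scalar loop starting at offset = dims / 32 * 32. A standalone sketch of that idea for bf16 data, assuming the same bit-shift conversion as bf162float() (bf16_val and bf16_tail_dot are illustrative names):

#include <cstddef>
#include <cstdint>
#include <cstring>

static inline float bf16_val(uint16_t h) {       // same bit trick as bf162float()
    uint32_t b = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &b, sizeof(f));
    return f;
}

// Accumulate the elements that the full 32-wide blocks cannot cover.
static float bf16_tail_dot(const uint16_t* a, const uint16_t* b, std::size_t dims) {
    std::size_t offset = dims / 32 * 32;          // first element not covered by the tiles
    float acc = 0.0f;
    for (std::size_t m = offset; m < dims; ++m)
        acc += bf16_val(a[m]) * bf16_val(b[m]);
    return acc;                                   // caller adds this to the tile partial sums
}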
SET( CMAKE_CXX_FLAGS "-Ofast -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp index 2f31364e..75b8c52d 100644 --- a/examples/cpp/example_mt_search_bf16.cpp +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -425,7 +425,7 @@ int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::Bf16InnerPr return 0; } int main() { - int true_dim=30; + int true_dim=1024; int dim = true_dim/2; // Dimension of the elements size_t max_elements = 10*1024; // Maximum number of elements, should be known beforehand int M = 32; // Tightly connected with internal dimensionality of the data diff --git a/hnswlib/hnswalg.h b/hnswlib/hnswalg.h index af20813f..4bd7a86c 100644 --- a/hnswlib/hnswalg.h +++ b/hnswlib/hnswalg.h @@ -1336,30 +1336,23 @@ class HierarchicalNSW : public AlgorithmInterface { enable_amx(); size_t dim=(size_t)(*(size_t *)dist_func_param_); - - //if(mydata==NULL){ - // mydata=(void**)malloc(sizeof(dist_t*)*maxM0_); - // res=(float*) malloc(maxM0_*sizeof(float)); - // memset(res,0,maxM0_*sizeof(float)); - //printf("We are 1443lines\n"); - //} - - //printf("we are here\n"); - for (int i= 0; i 0){ + for (int i= 0; i InnerProductSIMD16Ext = InnerProductSIMD16ExtSSE; static DISTFUNC InnerProductSIMD4Ext = InnerProductSIMD4ExtSSE; static DISTFUNC InnerProductDistanceSIMD16Ext = InnerProductDistanceSIMD16ExtSSE; static DISTFUNC InnerProductDistanceSIMD4Ext = InnerProductDistanceSIMD4ExtSSE; + static float InnerProductDistanceSIMD16ExtResiduals(const void *pVect1v, const void *pVect2v, const void *qty_ptr) { @@ -392,7 +393,73 @@ InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, } #endif + #if defined(USE_AMX) +static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { + uint16_t *x = (uint16_t *)a; + uint16_t *y = (uint16_t *)b; + // __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 + + size_t dim = * (size_t*) qty_ptr; + + float dot_product = 0.0f; + + for (int i=0; i < dim; i++) { + float x_val = bf162float(x[i]); + float y_val = bf162float(y[i]); + dot_product += x_val * y_val; + } + return 1-dot_product; +} + + + +static float InnerProductDistanceBf16AVX512(const void* a, const void* b, const void *qty_ptr) { + float result[16] = {0.0f}; // 用于存储中间结果 + + uint16_t *x = (uint16_t *)a; + uint16_t *y = (uint16_t *)b; + __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 + + size_t dim = * (size_t*) qty_ptr ; + + size_t i = 0; + // 每次处理32个元素(16个__bf16元素在__m512bh寄存器中存储为32个uint16_t) + for (; i + 32 <= dim; i += 32) { + // 加载32个uint16_t到__m512i类型的临时寄存器 + __m512i temp_x = _mm512_loadu_si512(x + i); + __m512i temp_y = _mm512_loadu_si512(y + i); + + // 强制转换为__m512bh类型 + __m512bh v1_f16 = reinterpret_cast<__m512bh&>(temp_x); + __m512bh v2_f16 = reinterpret_cast<__m512bh&>(temp_y); + + // 计算BF16的点积,并将结果累加到vr_f32 + vr_f32 = _mm512_dpbf16_ps(vr_f32, v1_f16, v2_f16); + } + + // 将vr_f32寄存器的值存入result数组 + _mm512_storeu_ps(result, vr_f32); + + // 累加result数组的所有元素,获得最终的点积结果 + float dot_product = 0.0f; + for (int j = 0; j < 16; j++) { + dot_product += result[j]; + } + + // 处理剩余的元素(小于32的部分) + for (; i < dim; i++) { + float x_val = bf162float(x[i]); + float y_val = bf162float(y[i]); + dot_product += x_val * y_val; + } + //printf("%d %f ",dim,dot_product); + return dot_product; +} +static float 
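searchKnn() calls enable_amx() before taking the AMX branch and the space constructors test AMXCapable(); neither body appears in the hunks above. For context, on Linux 5.16 and later a process must request permission for the AMX tile state before executing any tile instruction, which is typically done as sketched below (the constants follow the kernel's arch_prctl() extended-state API; request_amx_permission is an illustrative name, not the patch's enable_amx):

#include <sys/syscall.h>
#include <unistd.h>

#ifndef ARCH_REQ_XCOMP_PERM
#define ARCH_REQ_XCOMP_PERM 0x1023   // arch_prctl: request extended-state permission
#endif
#define XFEATURE_XTILEDATA 18        // AMX tile data component

// Ask the kernel to allow this process to use the AMX tile registers.
// Call once before the first _tile_* intrinsic; returns true on success.
static bool request_amx_permission() {
    return syscall(SYS_arch_prctl, ARCH_REQ_XCOMP_PERM, XFEATURE_XTILEDATA) == 0;
}

If the request fails (older kernel, or no AMX on the CPU), the code should fall back to the AVX-512 or scalar distance functions.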
InnerProductDistanceBf16AVX512Ext(const void* a, const void* b, const void *qty_ptr){ + return 1.0f - InnerProductDistanceBf16AVX512(a, b, qty_ptr); +} + float amx_inner_product_matrix_fp32( char **floatLibraryMatrix, char *floatQueryMatrix, uint64_t dims,uint64_t batchSizeA, uint64_t batchSizeB, float *results){ int DIM=32; @@ -548,8 +615,8 @@ float amx_inner_product_matrix_fp32( char **floatLibraryMatrix, char *floatQuer } - _tile_stored(2, results, batchSizeB*2*2); - _tile_zero(2); + // _tile_stored(2, results, batchSizeB*2*2); + // _tile_zero(2); // printf("tailCount=%d\n",tailCount); if (tailCount != 0) { @@ -754,8 +821,28 @@ static float InnerProductDistanceBatchExtAMXBF16(const void **pVect1v, const voi } return 0; } -static AMXDISTFUNC InnerProductBatchExt = InnerProductBatchExtAMX; + +static float +InnerProductDistanceBatchExtAMXBF16Residuals(const void **pVect1v, const void *pVect2v, const void *qty_ptr, size_t nSize, size_t mSize, float * results_amx) { + size_t qty = *((size_t *) qty_ptr); + size_t qty32 = qty >> 5 << 5; + + InnerProductBatchExtAMXBF16(pVect1v, pVect2v, &qty32,nSize,mSize,results_amx); + + size_t qty_left = qty - qty32; + + float *pVect2 = (float *) pVect2v + qty32; + for(size_t i = 0; i < nSize; i++) { + float *pVect1 = (float *) pVect1v[i] + qty32; + results_amx[i] += InnerProductDistanceBf16AVX512(pVect1, pVect2, &qty_left); + } + for(size_t i = 0; i < nSize; i++) { + results_amx[i] = 1.0f - results_amx[i]; + } + return 0; +} static AMXDISTFUNC InnerProductDistanceBatchExt = InnerProductDistanceBatchExtAMX; +static DISTFUNC InnerProductDistanceBF16Ext = InnerProductDistanceBf16; #endif class InnerProductSpace : public SpaceInterface { @@ -805,11 +892,9 @@ class InnerProductSpace : public SpaceInterface { #endif #if defined(USE_AMX) if (AMXCapable()) { - InnerProductBatchExt = InnerProductBatchExtAMX; - InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMX; + amxdistfunc_=InnerProductDistanceBatchExtAMX; } - amxdistfunc_ = InnerProductDistanceBatchExt; #endif dim_ = dim; data_size_ = dim * sizeof(float); @@ -836,64 +921,6 @@ class InnerProductSpace : public SpaceInterface { }; -static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { - uint16_t *x = (uint16_t *)a; - uint16_t *y = (uint16_t *)b; - // __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 - - size_t dim = * (size_t*) qty_ptr; - - float dot_product = 0.0f; - - for (int i=0; i < dim; i++) { - float x_val = bf162float(x[i]); - float y_val = bf162float(y[i]); - dot_product += x_val * y_val; - } - return 1-dot_product; -} -static float InnerProductDistanceBf16AVX512(const void* a, const void* b, const void *qty_ptr) { - float result[16] = {0.0f}; // 用于存储中间结果 - - uint16_t *x = (uint16_t *)a; - uint16_t *y = (uint16_t *)b; - __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 - - size_t dim = * (size_t*) qty_ptr ; - - size_t i = 0; - // 每次处理32个元素(16个__bf16元素在__m512bh寄存器中存储为32个uint16_t) - for (; i + 32 <= dim; i += 32) { - // 加载32个uint16_t到__m512i类型的临时寄存器 - __m512i temp_x = _mm512_loadu_si512(x + i); - __m512i temp_y = _mm512_loadu_si512(y + i); - - // 强制转换为__m512bh类型 - __m512bh v1_f16 = reinterpret_cast<__m512bh&>(temp_x); - __m512bh v2_f16 = reinterpret_cast<__m512bh&>(temp_y); - - // 计算BF16的点积,并将结果累加到vr_f32 - vr_f32 = _mm512_dpbf16_ps(vr_f32, v1_f16, v2_f16); - } - - // 将vr_f32寄存器的值存入result数组 - _mm512_storeu_ps(result, vr_f32); - - // 累加result数组的所有元素,获得最终的点积结果 - float dot_product = 0.0f; - for (int j = 0; j < 16; j++) { - dot_product += result[j]; - } - - // 
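InnerProductDistanceBatchExtAMXBF16Residuals splits each vector into the largest 32-aligned prefix (qty >> 5 << 5), hands that prefix to the AMX batch kernel, finishes the remainder with the AVX-512 BF16 dot product, and only then converts the accumulated inner products to distances with 1 - dot. A minimal sketch of the splitting arithmetic; note that for bf16 data the tail pointers must advance in uint16_t elements, which is exactly what patch 11 further below corrects (Bf16Split and split_for_amx are illustrative names):

#include <cstddef>
#include <cstdint>

struct Bf16Split {
    std::size_t prefix_len;   // multiple of 32, handled by the AMX tile kernel
    std::size_t tail_len;     // 0..31 elements, handled by AVX-512 or scalar code
    const uint16_t* tail;     // first element of the tail
};

static Bf16Split split_for_amx(const uint16_t* v, std::size_t qty) {
    std::size_t qty32 = (qty >> 5) << 5;             // same as qty - qty % 32
    return Bf16Split{qty32, qty - qty32, v + qty32}; // pointer advances in uint16_t units
}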
处理剩余的元素(小于32的部分) - for (; i < dim; i++) { - float x_val = bf162float(x[i]); - float y_val = bf162float(y[i]); - dot_product += x_val * y_val; - } - //printf("%d %f ",dim,dot_product); - return 1 - dot_product; -} class Bf16InnerProductSpace : public hnswlib::SpaceInterface { DISTFUNC fstdistfunc_; #ifdef USE_AMX @@ -907,22 +934,26 @@ class Bf16InnerProductSpace : public hnswlib::SpaceInterface { #if defined(USE_AVX) || defined(USE_SSE) || defined(USE_AVX512) #if defined(USE_AVX512) if (AVX512Capable()) { - InnerProductSIMD16Ext = InnerProductDistanceBf16AVX512; - InnerProductDistanceSIMD16Ext = InnerProductDistanceBf16AVX512; + //InnerProductSIMD16Ext = InnerProductDistanceBf16AVX512; + InnerProductDistanceBF16Ext = InnerProductDistanceBf16AVX512Ext; } else if (AVXCapable()) { - InnerProductSIMD16Ext = InnerProductDistanceBf16; - InnerProductDistanceSIMD16Ext = InnerProductDistanceBf16; + //InnerProductSIMD16Ext = InnerProductDistanceBf16; + InnerProductDistanceBF16Ext = InnerProductDistanceBf16; } #else - InnerProductSIMD16Ext = InnerProductDistanceBf16; - InnerProductDistanceSIMD16Ext = InnerProductDistanceBf16; + //InnerProductSIMD16Ext = InnerProductDistanceBf16; + InnerProductDistanceBF16Ext = InnerProductDistanceBf16; #endif - fstdistfunc_=InnerProductDistanceSIMD16Ext; + fstdistfunc_=InnerProductDistanceBF16Ext; #endif #if defined(USE_AMX) if (AMXCapable()) { - InnerProductBatchExt = InnerProductBatchExtAMXBF16; - InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMXBF16; + if (dim%32!=0){ + InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMXBF16Residuals; + }else{ + InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMXBF16; + } + } amxdistfunc_ = InnerProductDistanceBatchExt; From 3bca8eb23f9d1d0d1bc510f8f305c39d8478905f Mon Sep 17 00:00:00 2001 From: ruclz Date: Tue, 8 Apr 2025 12:17:13 +0800 Subject: [PATCH 10/14] fix memory leak --- examples/cpp/example_mt_search_bf16.cpp | 2 +- hnswlib/space_ip.h | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp index 75b8c52d..a26874a0 100644 --- a/examples/cpp/example_mt_search_bf16.cpp +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -425,7 +425,7 @@ int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::Bf16InnerPr return 0; } int main() { - int true_dim=1024; + int true_dim=128; int dim = true_dim/2; // Dimension of the elements size_t max_elements = 10*1024; // Maximum number of elements, should be known beforehand int M = 32; // Tightly connected with internal dimensionality of the data diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index 96699902..b72ea0ba 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -670,7 +670,7 @@ static float InnerProductBatchExtAMX(const void **pVect1v, const void *pVect2v, } float amx_inner_product_matrix_bf16( char **floatLibraryMatrix, char *floatQueryMatrix, uint64_t dims,uint64_t batchSizeA, - uint64_t batchSizeB, float *results){ + uint64_t batchSizeB, float *results_ptr){ int DIM=32; int blockDim = 96; int blockCount=((dims))/blockDim; @@ -683,6 +683,8 @@ float amx_inner_product_matrix_bf16( char **floatLibraryMatrix, char *floatQuer unsigned char ma1Bf16[1024] __attribute__((aligned(64))); unsigned char ma2Bf16[1024] __attribute__((aligned(64))); unsigned char ma3Bf16[1024] __attribute__((aligned(64))); + + float results[16*16] __attribute__((aligned(64)))={0}; if(!init_mem){ cfg[0]=1; @@ -721,6 +723,7 @@ float amx_inner_product_matrix_bf16( char 
**floatLibraryMatrix, char *floatQuer for(int j=0;j Date: Tue, 8 Apr 2025 16:59:23 +0800 Subject: [PATCH 11/14] fix bf16 avx512 bug --- hnswlib/space_ip.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index b72ea0ba..ba435a7b 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -835,9 +835,9 @@ InnerProductDistanceBatchExtAMXBF16Residuals(const void **pVect1v, const void *p size_t qty_left = qty - qty32; - float *pVect2 = (float *) pVect2v + qty32; + uint16_t *pVect2 = (uint16_t *) pVect2v + qty32; for(size_t i = 0; i < nSize; i++) { - float *pVect1 = (float *) pVect1v[i] + qty32; + uint16_t *pVect1 = (uint16_t *) pVect1v[i] + qty32; results_amx[i] += InnerProductDistanceBf16AVX512(pVect1, pVect2, &qty_left); } for(size_t i = 0; i < nSize; i++) { From 64dc857b63be42e98db47ef6f257b8dde0f1db3e Mon Sep 17 00:00:00 2001 From: ruclz Date: Fri, 11 Apr 2025 10:01:18 +0800 Subject: [PATCH 12/14] fix fp32 bug --- CMakeLists.txt | 2 +- examples/cpp/example_mt_search_bf16.cpp | 2 +- hnswlib/space_ip.h | 61 +++++++++++++------------ 3 files changed, 34 insertions(+), 31 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf76ac64..f97b55ff 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-Ofast -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) + SET( CMAKE_CXX_FLAGS "-O0 -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() diff --git a/examples/cpp/example_mt_search_bf16.cpp b/examples/cpp/example_mt_search_bf16.cpp index a26874a0..75b8c52d 100644 --- a/examples/cpp/example_mt_search_bf16.cpp +++ b/examples/cpp/example_mt_search_bf16.cpp @@ -425,7 +425,7 @@ int call_AMX_bf16(hnswlib::HierarchicalNSW* alg_hnsw,hnswlib::Bf16InnerPr return 0; } int main() { - int true_dim=128; + int true_dim=1024; int dim = true_dim/2; // Dimension of the elements size_t max_elements = 10*1024; // Maximum number of elements, should be known beforehand int M = 32; // Tightly connected with internal dimensionality of the data diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index ba435a7b..0c7e3f5b 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -392,28 +392,6 @@ InnerProductDistanceSIMD4ExtResiduals(const void *pVect1v, const void *pVect2v, return 1.0f - (res + res_tail); } #endif - - -#if defined(USE_AMX) -static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { - uint16_t *x = (uint16_t *)a; - uint16_t *y = (uint16_t *)b; - // __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 - - size_t dim = * (size_t*) qty_ptr; - - float dot_product = 0.0f; - - for (int i=0; i < dim; i++) { - float x_val = bf162float(x[i]); - float y_val = bf162float(y[i]); - dot_product += x_val * y_val; - } - return 1-dot_product; -} - - - static float InnerProductDistanceBf16AVX512(const void* a, const void* b, const void *qty_ptr) { float result[16] = {0.0f}; // 用于存储中间结果 @@ -456,9 +434,32 @@ static float InnerProductDistanceBf16AVX512(const void* a, const void* b, const //printf("%d %f ",dim,dot_product); return dot_product; } + static float InnerProductDistanceBf16AVX512Ext(const void* a, const void* b, const void *qty_ptr){ return 1.0f - 
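The cfg[] byte array initialised in amx_inner_product_matrix_bf16 above follows the fixed 64-byte AMX tile-configuration layout consumed by _tile_loadconfig: byte 0 is the palette id, byte 1 the start row, bytes 16-47 hold a uint16 bytes-per-row value per tile, and bytes 48-63 a uint8 row count per tile. A typed sketch of that layout with a configuration in the spirit of the bf16 kernel, which uses tmm0-tmm2 (TileConfig and configure_tiles are illustrative names; compile with -mamx-tile, plus -mamx-bf16 for the dot-product kernels):

#include <cstdint>
#include <immintrin.h>

// 64-byte tile configuration as consumed by LDTILECFG / _tile_loadconfig.
struct alignas(64) TileConfig {
    uint8_t  palette_id;     // 1 selects the standard tile palette
    uint8_t  start_row;      // normally 0
    uint8_t  reserved[14];
    uint16_t colsb[16];      // bytes per row, one entry per tile register
    uint8_t  rows[16];       // row count, one entry per tile register
};

static void configure_tiles(void) {
    TileConfig cfg{};        // zero-initialise, then fill only the tiles in use
    cfg.palette_id = 1;
    for (int t = 0; t < 3; ++t) {  // tmm0..tmm2
        cfg.colsb[t] = 64;         // 64 bytes per row = 32 bf16 values or 16 floats
        cfg.rows[t]  = 16;         // maximum tile height
    }
    _tile_loadconfig(&cfg);
}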
InnerProductDistanceBf16AVX512(a, b, qty_ptr); } +static float InnerProductDistanceBf16(const void* a, const void* b, const void *qty_ptr) { + uint16_t *x = (uint16_t *)a; + uint16_t *y = (uint16_t *)b; + // __m512 vr_f32 = _mm512_setzero_ps(); // 初始化累积寄存器为0 + + size_t dim = * (size_t*) qty_ptr; + + float dot_product = 0.0f; + + for (int i=0; i < dim; i++) { + float x_val = bf162float(x[i]); + float y_val = bf162float(y[i]); + dot_product += x_val * y_val; + } + return 1-dot_product; +} +#if defined(USE_AMX) + + + + + float amx_inner_product_matrix_fp32( char **floatLibraryMatrix, char *floatQueryMatrix, uint64_t dims,uint64_t batchSizeA, uint64_t batchSizeB, float *results){ @@ -615,8 +616,8 @@ float amx_inner_product_matrix_fp32( char **floatLibraryMatrix, char *floatQuer } - // _tile_stored(2, results, batchSizeB*2*2); - // _tile_zero(2); + _tile_stored(2, results, batchSizeB*2*2); + _tile_zero(2); // printf("tailCount=%d\n",tailCount); if (tailCount != 0) { @@ -781,6 +782,7 @@ InnerProductDistanceBatchExtAMX(const void **pVect1v, const void *pVect2v, const InnerProductBatchExtAMX(pVect1v, pVect2v, qty_ptr,nSize,mSize,results_amx); for(int i=0;i InnerProductDistanceBatchExt = InnerProductDistanceBatchExtAMX; -static DISTFUNC InnerProductDistanceBF16Ext = InnerProductDistanceBf16; -#endif +static AMXDISTFUNC InnerProductDistanceBatchBf16Ext = InnerProductDistanceBatchExtAMXBF16; +#endif +static DISTFUNC InnerProductDistanceBF16Ext = InnerProductDistanceBf16; class InnerProductSpace : public SpaceInterface { DISTFUNC fstdistfunc_; #ifdef USE_AMX @@ -953,14 +956,14 @@ class Bf16InnerProductSpace : public hnswlib::SpaceInterface { #if defined(USE_AMX) if (AMXCapable()) { if (dim%32!=0){ - InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMXBF16Residuals; + InnerProductDistanceBatchBf16Ext=InnerProductDistanceBatchExtAMXBF16Residuals; }else{ - InnerProductDistanceBatchExt=InnerProductDistanceBatchExtAMXBF16; + InnerProductDistanceBatchBf16Ext=InnerProductDistanceBatchExtAMXBF16; } } - amxdistfunc_ = InnerProductDistanceBatchExt; + amxdistfunc_ = InnerProductDistanceBatchBf16Ext; #endif dim_ = dim ; data_size_ = dim * sizeof(uint16_t); From 05aaed95aaa9b353e07f9aba380228a1879eada5 Mon Sep 17 00:00:00 2001 From: ruclz Date: Mon, 14 Apr 2025 12:31:21 +0800 Subject: [PATCH 13/14] ofast --- CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index f97b55ff..cf76ac64 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -49,7 +49,7 @@ if(HNSWLIB_EXAMPLES) endif() endif() elseif (CMAKE_CXX_COMPILER_ID STREQUAL "GNU") - SET( CMAKE_CXX_FLAGS "-O0 -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) + SET( CMAKE_CXX_FLAGS "-Ofast -g -lrt -std=c++11 -DHAVE_CXX0X -march=native -fpic -w -fopenmp -ftree-vectorize -ftree-vectorizer-verbose=0 " ) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") SET( CMAKE_CXX_FLAGS "/O2 -DHAVE_CXX0X /W1 /openmp /EHsc" ) endif() From 3e4fe7d154d7b727a32c3160b6ec5b2e2d2613f9 Mon Sep 17 00:00:00 2001 From: ruclz Date: Tue, 27 May 2025 12:40:40 +0800 Subject: [PATCH 14/14] avx512bf16 opt --- examples/python/example_search.py | 2 +- hnswlib/hnswlib.h | 2 +- hnswlib/space_ip.h | 10 ++-------- 3 files changed, 4 insertions(+), 10 deletions(-) diff --git a/examples/python/example_search.py b/examples/python/example_search.py index 4581843b..875d6939 100644 --- a/examples/python/example_search.py +++ b/examples/python/example_search.py @@ -15,7 +15,7 @@ ids = 
np.arange(num_elements) # Declaring index -p = hnswlib.Index(space='l2', dim=dim) # possible options are l2, cosine or ip +p = hnswlib.Index(space='ip', dim=dim) # possible options are l2, cosine or ip # Initializing index - the maximum number of elements should be known beforehand p.init_index(max_elements=num_elements, ef_construction=200, M=16) diff --git a/hnswlib/hnswlib.h b/hnswlib/hnswlib.h index 17e2c274..456d6f9b 100644 --- a/hnswlib/hnswlib.h +++ b/hnswlib/hnswlib.h @@ -24,7 +24,7 @@ // #define USE_AMX // #endif -#define USE_AMX +// #define USE_AMX #define BF16_SUPPORT #if defined(USE_AVX) || defined(USE_SSE) diff --git a/hnswlib/space_ip.h b/hnswlib/space_ip.h index 0c7e3f5b..ec2e8be0 100644 --- a/hnswlib/space_ip.h +++ b/hnswlib/space_ip.h @@ -415,15 +415,9 @@ static float InnerProductDistanceBf16AVX512(const void* a, const void* b, const // 计算BF16的点积,并将结果累加到vr_f32 vr_f32 = _mm512_dpbf16_ps(vr_f32, v1_f16, v2_f16); } - - // 将vr_f32寄存器的值存入result数组 - _mm512_storeu_ps(result, vr_f32); - + // 累加result数组的所有元素,获得最终的点积结果 - float dot_product = 0.0f; - for (int j = 0; j < 16; j++) { - dot_product += result[j]; - } + float dot_product = _mm512_reduce_add_ps(vr_f32); // 处理剩余的元素(小于32的部分) for (; i < dim; i++) {
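The final patch replaces the store-to-array-and-sum reduction in InnerProductDistanceBf16AVX512 with a single _mm512_reduce_add_ps call. Both forms compute the horizontal sum of the 16 float lanes; results may differ in the last bits because the additions happen in a different order, and the intrinsic expands to a shuffle/add sequence rather than a single instruction. A side-by-side sketch for reference (hsum_store and hsum_reduce are illustrative names):

#include <immintrin.h>

// Horizontal sum of a 512-bit float accumulator, written two equivalent ways.
static float hsum_store(__m512 v) {
    float buf[16];
    _mm512_storeu_ps(buf, v);            // spill the 16 lanes to memory
    float s = 0.0f;
    for (int i = 0; i < 16; ++i) s += buf[i];
    return s;
}

static float hsum_reduce(__m512 v) {
    return _mm512_reduce_add_ps(v);      // AVX-512F helper; compilers emit a shuffle/add tree
}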