Skip to content

Wint4 #2628

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open

Wint4 #2628

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,088 changes: 1,088 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

// Include stdlib headers explicitly instead of relying on transitive
// includes from the paddle headers below (std::optional, std::string,
// int64_t are all used in the declaration).
#include <cstdint>
#include <optional>
#include <string>

#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/enforce.h"

#include "moe/moe_wna16_marlin_utils/kernel.h"
#include "moe/moe_wna16_marlin_utils/types.h"

// Marlin WNA16 grouped GEMM entry point for MoE layers.
//
// Computes C = A * dequant(B) per expert, where B is a quantized weight
// tensor (type named by `b_q_type_str`) with scales `b_scales` and optional
// zero points / group indices / activation permutation.
//
// Parameters (all tensor semantics to be confirmed against the .cu
// implementation, which is not visible here):
//   a                       input activations.
//   c_or_none               optional pre-allocated output tensor.
//   b_q_weight, b_scales    quantized weights and per-group scales.
//   global_scale_or_none    optional global dequant scale.
//   b_zeros_or_none         optional zero points (see also is_zp_float).
//   g_idx_or_none, perm_or_none  optional act-order group index / permutation.
//   workspace               scratch buffer required by the Marlin kernel.
//   sorted_token_ids, expert_ids, num_tokens_post_padded
//                           MoE routing metadata (token->expert mapping).
//   topk_weights            router weights; applied iff mul_topk_weights.
//   moe_block_size, top_k   routing block size and experts per token.
//   is_ep                   expert-parallel mode flag.
//   size_m, size_n, size_k  GEMM problem sizes.
//   is_k_full               whether K covers the full reduction dimension.
//   use_atomic_add, use_fp32_reduce
//                           accumulation strategy knobs.
//   is_zp_float             whether zero points are stored as floats.
// Returns the output tensor C.
paddle::Tensor moe_wna16_marlin_gemm(
    const paddle::Tensor& a,
    const std::optional<paddle::Tensor>& c_or_none,
    const paddle::Tensor& b_q_weight,
    const paddle::Tensor& b_scales,
    const std::optional<paddle::Tensor>& global_scale_or_none,
    const std::optional<paddle::Tensor>& b_zeros_or_none,
    const std::optional<paddle::Tensor>& g_idx_or_none,
    const std::optional<paddle::Tensor>& perm_or_none,
    const paddle::Tensor& workspace,
    const paddle::Tensor& sorted_token_ids,
    const paddle::Tensor& expert_ids,
    const paddle::Tensor& num_tokens_post_padded,
    const paddle::Tensor& topk_weights,
    int64_t moe_block_size,
    int64_t top_k,
    bool mul_topk_weights,
    bool is_ep,
    const std::string& b_q_type_str,
    int64_t size_m,
    int64_t size_n,
    int64_t size_k,
    bool is_k_full,
    bool use_atomic_add,
    bool use_fp32_reduce,
    bool is_zp_float);
75 changes: 75 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDADataType.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <unordered_map>
#include <variant>
#include "paddle/common/overloaded.h"
#include "paddle/phi/backends/gpu/cuda/cuda_helper.h"

#include "moe/moe_wna16_marlin_utils/ScalarType.h"

namespace MARLIN_NAMESPACE_NAME {

// Lifts a phi::DataType enumerator into a distinct empty type, so each
// data type can appear as its own alternative of the PhiDataType variant
// and be dispatched on via std::visit / overload resolution.
template <phi::DataType phi_data_type>
struct PhiDataTypeImpl {
  // The wrapped enumerator, recoverable at compile time from the type.
  static constexpr phi::DataType value = phi_data_type;
};

// Variant with one PhiDataTypeImpl<...> alternative per phi data type
// (expanded from PD_FOR_EACH_DATA_TYPE), plus UNDEFINED appended last so
// the macro's trailing comma stays well-formed.
// NOTE(review): assumes PD_FOR_EACH_DATA_TYPE invokes the callback as
// (cpp_type, DataType::NAME) -- confirm against paddle/phi's data type
// header.
using PhiDataType = std::variant<
#define MAKE_PHI_DATA_TYPE_CASE(_, phi_data_type) \
PhiDataTypeImpl<phi::phi_data_type>,
PD_FOR_EACH_DATA_TYPE(MAKE_PHI_DATA_TYPE_CASE)
PhiDataTypeImpl<phi::DataType::UNDEFINED>
#undef MAKE_PHI_DATA_TYPE_CASE
>;

// Converts a Marlin ScalarType into the matching PhiDataType variant
// alternative via a lazily-built lookup table; aborts (LOG(FATAL)) on an
// unmapped scalar type.
// NOTE(review): the map is keyed by ScalarType but initialized with
// phi::DataType enumerators, so this relies on an implicit
// phi::DataType -> ScalarType conversion (and a std::hash specialization
// for ScalarType) declared in ScalarType.h -- confirm both exist.
inline PhiDataType ScalarTypeToPhiDataType(
const MARLIN_NAMESPACE_NAME::ScalarType& scalar_type) {
// Function-local static: built once on first call, read-only afterwards.
static std::unordered_map<MARLIN_NAMESPACE_NAME::ScalarType, PhiDataType> map = {
#define MAKE_PHI_DATA_TYPE_CONVERT_CASE(_, phi_data_type) \
{phi::phi_data_type, PhiDataTypeImpl<phi::phi_data_type>{}},
PD_FOR_EACH_DATA_TYPE(MAKE_PHI_DATA_TYPE_CONVERT_CASE)
#undef MAKE_PHI_DATA_TYPE_CONVERT_CASE
{phi::DataType::UNDEFINED,
PhiDataTypeImpl<phi::DataType::UNDEFINED>{}},
};
const auto iter = map.find(scalar_type);
if (iter == map.end()) {
// Unmapped type is a programming error: abort with the numeric id.
LOG(FATAL) << "unsupported scalar type: " << static_cast<int>(scalar_type);
}
return iter->second;
}

// Maps a Marlin ScalarType to the corresponding cudaDataType_t by first
// converting to a PhiDataType variant and then visiting it. PSTRING and
// UNDEFINED have no CUDA counterpart and abort via LOG(FATAL).
inline cudaDataType_t ScalarTypeToCudaDataType(
    const MARLIN_NAMESPACE_NAME::ScalarType& scalar_type) {
  // ScalarTypeToPhiDataType and PhiDataTypeImpl are declared directly in
  // this namespace (there is no nested `detail` namespace here), so they
  // must be referenced unqualified.
  auto phi_data_type = ScalarTypeToPhiDataType(scalar_type);
  auto Converter = ::common::Overloaded{
      [](PhiDataTypeImpl<phi::DataType::PSTRING>) -> cudaDataType_t {
        LOG(FATAL) << "unsupported scalar type: pstring";
        // Unreachable (LOG(FATAL) aborts); return a value-initialized enum
        // instead of dereferencing a null pointer, which is UB.
        return cudaDataType_t{};
      },
      [](PhiDataTypeImpl<phi::DataType::UNDEFINED>) -> cudaDataType_t {
        LOG(FATAL) << "unsupported scalar type: undefined";
        return cudaDataType_t{};  // unreachable, see above
      },
      [](auto phi_data_type_impl) -> cudaDataType_t {
        // Generic case: recover the C++ type for the enumerator and let
        // phi's helper produce the CUDA data type.
        using T = std::decay_t<decltype(phi_data_type_impl)>;
        using CppT = typename phi::DataTypeToCppType<T::value>::type;
        return phi::backends::gpu::ToCudaDataType<CppT>();
      }};
  return std::visit(Converter, phi_data_type);
}

} // namespace MARLIN_NAMESPACE_NAME
63 changes: 63 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/cuda_stream.h"

namespace MARLIN_NAMESPACE_NAME {

using DeviceIndex = int8_t;
using StreamId = int64_t;

// Thin value wrapper around a raw cudaStream_t handle. Non-owning: it never
// creates or destroys the underlying stream.
class CUDAStream {
 public:
  // Default construction is intentionally unsupported; a raw stream must be
  // supplied explicitly.
  CUDAStream() { LOG(FATAL) << "CUDAStream::CUDAStream() is not implemented"; }
  explicit CUDAStream(const cudaStream_t& stream) : raw_stream_(stream) {}

  // Identifier derived from the raw handle's pointer value.
  StreamId id() const { return reinterpret_cast<StreamId>(raw_stream_); }

  // Implicit conversion so the wrapper can be passed straight to CUDA
  // runtime / kernel-launch APIs expecting a cudaStream_t.
  operator cudaStream_t() const { return raw_stream_; }

  const cudaStream_t& raw_stream() const { return raw_stream_; }

 private:
  // Brace-initialized so the handle is never indeterminate, even on the
  // (fatal) default-construction path above.
  cudaStream_t raw_stream_ = nullptr;
};

/**
 * Get the current CUDA stream for the passed CUDA device, or for the
 * current device if no device index is passed (device_index == -1).
 * This will usually be the device's current stream as tracked by paddle's
 * context pool.
 */
inline CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1) {
  if (device_index == -1) {
    // NOTE(review): GetCurrentDeviceId() returns an int while DeviceIndex
    // is int8_t, so this narrows -- fine while device ids stay small, but
    // worth confirming.
    device_index = phi::backends::gpu::GetCurrentDeviceId();
  }

  return CUDAStream(
      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream());
}

// NOTE(review): semantics inferred from the names -- these appear to return
// the computation stream and the communication stream associated with the
// process group identified by `context_ring_id`; confirm against the
// definitions (not in this header).
cudaStream_t GetCalcStreamFromGroup(int context_ring_id);

cudaStream_t GetCommStreamFromGroup(int context_ring_id);
} // namespace MARLIN_NAMESPACE_NAME
Loading