Skip to content

add marlin moe wint4 #2629

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,121 changes: 1,121 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.cu

Large diffs are not rendered by default.

37 changes: 37 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_gemm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
#pragma once
#ifndef MARLIN_NAMESPACE_NAME
#define MARLIN_NAMESPACE_NAME marlin_moe_wna16
#endif

#include "paddle/phi/api/include/api.h"
#include "paddle/phi/core/enforce.h"

#include "moe/moe_wna16_marlin_utils/kernel.h"
#include "moe/moe_wna16_marlin_utils/types.h"

// Marlin WNA16 (weight-only 4/8-bit) grouped GEMM for MoE layers.
// Declaration only — the implementation lives in moe_wna16_marlin_gemm.cu.
//
// Parameters (names follow the Marlin/vLLM convention; semantics below are
// inferred from the signature — confirm against the .cu implementation):
//   a                      : activation matrix (size_m x size_k).
//   c_or_none              : optional pre-allocated output buffer.
//   b_q_weight             : quantized weight matrix (packed).
//   b_scales               : per-group dequantization scales.
//   global_scale_or_none   : optional global (per-tensor) scale.
//   b_zeros_or_none        : optional zero points for asymmetric quantization.
//   g_idx_or_none          : optional act-order group index (see is_k_full).
//   perm_or_none           : optional input permutation for act-order.
//   workspace              : scratch buffer used by the Marlin kernel.
//   sorted_token_ids       : tokens sorted by assigned expert.
//   expert_ids             : expert id per block of sorted tokens.
//   num_tokens_post_padded : token count after padding to moe_block_size.
//   topk_weights           : router weights, applied when mul_topk_weights.
//   moe_block_size         : token-block granularity of the MoE dispatch.
//   top_k                  : number of experts each token is routed to.
//   mul_topk_weights       : multiply outputs by topk_weights when true.
//   is_ep                  : expert-parallel mode flag.
//   b_q_type_str           : quantization type name (e.g. "uint4b8").
//   size_m / size_n / size_k : GEMM problem dimensions.
//   is_k_full              : whether the full K dimension is present
//                            (relevant with act-order / g_idx).
//   use_atomic_add         : accumulate partial results with atomics.
//   use_fp32_reduce        : reduce partials in fp32 for accuracy.
//   is_zp_float            : zero points stored as float when true.
// Returns: a vector containing the output tensor(s).
std::vector<paddle::Tensor> MoeWna16MarlinGemmApi(
    const paddle::Tensor& a,
    const paddle::optional<paddle::Tensor>& c_or_none,
    const paddle::Tensor& b_q_weight,
    const paddle::Tensor& b_scales,
    const paddle::optional<paddle::Tensor>& global_scale_or_none,
    const paddle::optional<paddle::Tensor>& b_zeros_or_none,
    const paddle::optional<paddle::Tensor>& g_idx_or_none,
    const paddle::optional<paddle::Tensor>& perm_or_none,
    const paddle::Tensor& workspace,
    const paddle::Tensor& sorted_token_ids,
    const paddle::Tensor& expert_ids,
    const paddle::Tensor& num_tokens_post_padded,
    const paddle::Tensor& topk_weights,
    int64_t moe_block_size,
    int64_t top_k,
    bool mul_topk_weights,
    bool is_ep,
    const std::string& b_q_type_str,
    int64_t size_m,
    int64_t size_n,
    int64_t size_k,
    bool is_k_full,
    bool use_atomic_add,
    bool use_fp32_reduce,
    bool is_zp_float);
63 changes: 63 additions & 0 deletions custom_ops/gpu_ops/moe/moe_wna16_marlin_utils/CUDAStream.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "glog/logging.h"
#include "paddle/phi/api/include/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/cuda_stream.h"

namespace MARLIN_NAMESPACE_NAME {

using DeviceIndex = int8_t;
using StreamId = int64_t;

// Thin value wrapper around a raw cudaStream_t, mimicking the small subset of
// the c10::cuda::CUDAStream interface the Marlin kernels rely on.
class CUDAStream {
 public:
  // Default-construct with a null stream handle. The original left
  // raw_stream_ uninitialized, so id()/raw_stream() on a default-constructed
  // object read an indeterminate pointer (undefined behavior).
  CUDAStream() = default;
  explicit CUDAStream(const cudaStream_t& stream) : raw_stream_(stream) {}

  // Opaque integral id derived from the handle's address; stable for the
  // lifetime of the underlying stream.
  StreamId id() const { return reinterpret_cast<StreamId>(raw_stream_); }

  // Implicit conversion so the wrapper can be passed directly to CUDA APIs.
  operator cudaStream_t() const { return raw_stream_; }

  const cudaStream_t& raw_stream() const { return raw_stream_; }

 private:
  cudaStream_t raw_stream_{nullptr};  // non-owning handle; never freed here
};

/**
 * Get the current CUDA stream for the given device, or for the current
 * device when device_index is -1 (the default). This resolves through
 * Paddle's context pool, so it reflects whatever stream Paddle considers
 * current for that place — usually the default stream.
 *
 * NOTE(review): DeviceIndex is int8_t while GetCurrentDeviceId() returns a
 * wider integer; devices are assumed to fit in int8_t — confirm this holds
 * on multi-device nodes.
 */
inline CUDAStream getCurrentCUDAStream(DeviceIndex device_index = -1) {
  if (device_index == -1) {
    device_index = phi::backends::gpu::GetCurrentDeviceId();
  }

  return CUDAStream(
      paddle::GetCurrentCUDAStream(phi::GPUPlace(device_index))->raw_stream());
}

// Fetch the computation stream associated with a communication group
// (looked up by ring id). Definition lives elsewhere — presumably in the
// accompanying .cu/.cc file; verify the lookup semantics there.
cudaStream_t GetCalcStreamFromGroup(int context_ring_id);

// Fetch the communication (e.g. NCCL) stream for the same group.
cudaStream_t GetCommStreamFromGroup(int context_ring_id);
} // namespace MARLIN_NAMESPACE_NAME
Loading