cuda : refactor into multiple files (#6269)

This commit is contained in:
slaren 2024-03-25 13:50:23 +01:00 committed by GitHub
parent ad3a0505e3
commit ae1f211ce2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
59 changed files with 9154 additions and 8987 deletions

View File

@ -12,6 +12,7 @@ Checks: >
-readability-implicit-bool-conversion,
-readability-magic-numbers,
-readability-uppercase-literal-suffix,
-readability-simplify-boolean-expr,
clang-analyzer-*,
-clang-analyzer-security.insecureAPI.DeprecatedOrUnsafeBufferHandling,
performance-*,

View File

@ -369,7 +369,9 @@ if (LLAMA_CUBLAS)
enable_language(CUDA)
set(GGML_HEADERS_CUDA ggml-cuda.h)
set(GGML_SOURCES_CUDA ggml-cuda.cu)
file(GLOB GGML_SOURCES_CUDA "ggml-cuda/*.cu")
list(APPEND GGML_SOURCES_CUDA "ggml-cuda.cu")
add_compile_definitions(GGML_USE_CUBLAS)
if (LLAMA_CUDA_FORCE_DMMV)
@ -519,7 +521,9 @@ if (LLAMA_HIPBLAS)
message(STATUS "HIP and hipBLAS found")
set(GGML_HEADERS_ROCM ggml-cuda.h)
set(GGML_SOURCES_ROCM ggml-cuda.cu)
file(GLOB GGML_SOURCES_ROCM "ggml-cuda/*.cu")
list(APPEND GGML_SOURCES_ROCM "ggml-cuda.cu")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
@ -543,7 +547,7 @@ if (LLAMA_HIPBLAS)
add_compile_definitions(GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
set_source_files_properties(${GGML_SOURCES_ROCM} PROPERTIES LANGUAGE CXX)
if (LLAMA_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")

View File

@ -398,6 +398,7 @@ ifdef LLAMA_CUBLAS
MK_CPPFLAGS += -DGGML_USE_CUBLAS -I$(CUDA_PATH)/include -I$(CUDA_PATH)/targets/$(UNAME_M)-linux/include
MK_LDFLAGS += -lcuda -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L$(CUDA_PATH)/lib64 -L/usr/lib64 -L$(CUDA_PATH)/targets/$(UNAME_M)-linux/lib -L/usr/lib/wsl/lib
OBJS += ggml-cuda.o
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
MK_NVCCFLAGS += -use_fast_math
ifdef LLAMA_FATAL_WARNINGS
MK_NVCCFLAGS += -Werror all-warnings
@ -458,12 +459,23 @@ endif # LLAMA_CUDA_NO_PEER_COPY
ifdef LLAMA_CUDA_CCBIN
MK_NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN)
endif
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-common.h
ifdef JETSON_EOL_MODULE_DETECT
define NVCC_COMPILE
$(NVCC) -I. -Icommon -D_XOPEN_SOURCE=600 -D_GNU_SOURCE -DNDEBUG -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I/usr/local/cuda/targets/aarch64-linux/include -std=c++11 -O3 $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
endef # NVCC_COMPILE
else
define NVCC_COMPILE
$(NVCC) $(NVCCFLAGS) $(CPPFLAGS) -Xcompiler "$(CUDA_CXXFLAGS)" -c $< -o $@
endef # NVCC_COMPILE
endif # JETSON_EOL_MODULE_DETECT
ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
$(NVCC_COMPILE)
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
$(NVCC_COMPILE)
endif # LLAMA_CUBLAS
ifdef LLAMA_CLBLAST
@ -510,7 +522,6 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
endif # LLAMA_VULKAN
ifdef LLAMA_HIPBLAS
ifeq ($(wildcard /opt/rocm),)
ROCM_PATH ?= /usr
GPU_TARGETS ?= $(shell $(shell which amdgpu-arch))
@ -539,8 +550,13 @@ ifdef LLAMA_CUDA_NO_PEER_COPY
HIPFLAGS += -DGGML_CUDA_NO_PEER_COPY
endif # LLAMA_CUDA_NO_PEER_COPY
OBJS += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
OBJS += $(patsubst %.cu,%.o,$(wildcard ggml-cuda/*.cu))
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml.h ggml-backend.h ggml-backend-impl.h ggml-common.h $(wildcard ggml-cuda/*.cuh)
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
ggml-cuda/%.o: ggml-cuda/%.cu ggml-cuda/%.cuh ggml.h ggml-common.h ggml-cuda/common.cuh
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
endif # LLAMA_HIPBLAS
ifdef LLAMA_METAL
@ -687,6 +703,7 @@ libllama.a: llama.o ggml.o $(OBJS) $(COMMON_DEPS)
clean:
rm -vrf *.o tests/*.o *.so *.a *.dll benchmark-matmult lookup-create lookup-merge lookup-stats common/build-info.cpp *.dot $(COV_TARGETS) $(BUILD_TARGETS) $(TEST_TARGETS)
rm -vrf ggml-cuda/*.o
find examples pocs -type f -name "*.o" -delete
#

File diff suppressed because it is too large Load Diff

47
ggml-cuda/acc.cu Normal file
View File

@ -0,0 +1,47 @@
#include "acc.cuh"
static __global__ void acc_f32(const float * x, const float * y, float * dst, const int ne,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, int offset) {
const int i = blockDim.x * blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
int src1_idx = i - offset;
int oz = src1_idx / nb2;
int oy = (src1_idx - (oz * nb2)) / nb1;
int ox = src1_idx % nb1;
if (src1_idx >= 0 && ox < ne10 && oy < ne11 && oz < ne12) {
dst[i] = x[i] + y[ox + oy * ne10 + oz * ne10 * ne11];
} else {
dst[i] = x[i];
}
}
static void acc_f32_cuda(const float * x, const float * y, float * dst, const int n_elements,
const int ne10, const int ne11, const int ne12,
const int nb1, const int nb2, const int offset, cudaStream_t stream) {
int num_blocks = (n_elements + CUDA_ACC_BLOCK_SIZE - 1) / CUDA_ACC_BLOCK_SIZE;
acc_f32<<<num_blocks, CUDA_ACC_BLOCK_SIZE, 0, stream>>>(x, y, dst, n_elements, ne10, ne11, ne12, nb1, nb2, offset);
}
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->ne[3] == 1); // just 3D tensors supported
int nb1 = dst->op_params[0] / 4; // 4 bytes of float32
int nb2 = dst->op_params[1] / 4; // 4 bytes of float32
// int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused
int offset = dst->op_params[3] / 4; // offset in bytes
acc_f32_cuda(src0_d, src1_d, dst_d, ggml_nelements(dst), src1->ne[0], src1->ne[1], src1->ne[2], nb1, nb2, offset, stream);
}

5
ggml-cuda/acc.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ACC_BLOCK_SIZE 256
void ggml_cuda_op_acc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

63
ggml-cuda/alibi.cu Normal file
View File

@ -0,0 +1,63 @@
#include "alibi.cuh"
static __global__ void alibi_f32(const float * x, float * dst, const int ncols, const int k_rows,
const int n_heads_log2_floor, const float m0, const float m1) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const int k = row/k_rows;
float m_k;
if (k < n_heads_log2_floor) {
m_k = powf(m0, k + 1);
} else {
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
}
dst[i] = col * m_k + x[i];
}
static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
const int k_rows, const int n_heads_log2_floor, const float m0,
const float m1, cudaStream_t stream) {
const dim3 block_dims(CUDA_ALIBI_BLOCK_SIZE, 1, 1);
const int num_blocks_x = (ncols + CUDA_ALIBI_BLOCK_SIZE - 1) / (CUDA_ALIBI_BLOCK_SIZE);
const dim3 block_nums(num_blocks_x, nrows, 1);
alibi_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, k_rows, n_heads_log2_floor, m0, m1);
}
void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
const int64_t nrows = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_head = ((int32_t *) dst->op_params)[1];
float max_bias;
memcpy(&max_bias, (int32_t *) dst->op_params + 2, sizeof(float));
//GGML_ASSERT(ne01 + n_past == ne00);
GGML_ASSERT(n_head == ne02);
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias) / n_heads_log2_floor);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_heads_log2_floor);
alibi_f32_cuda(src0_d, dst_d, ne00, nrows, ne01, n_heads_log2_floor, m0, m1, stream);
}

5
ggml-cuda/alibi.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ALIBI_BLOCK_SIZE 32
void ggml_cuda_op_alibi(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

34
ggml-cuda/arange.cu Normal file
View File

@ -0,0 +1,34 @@
#include "arange.cuh"
static __global__ void arange_f32(float * dst, const int ne0, const float start, const float step) {
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
dst[nidx] = start + step * nidx;
}
static void arange_f32_cuda(float * dst, const int ne0, const float start, const float step, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_ARANGE_BLOCK_SIZE - 1) / CUDA_ARANGE_BLOCK_SIZE;
arange_f32<<<num_blocks, CUDA_ARANGE_BLOCK_SIZE, 0, stream>>>(dst, ne0, start, step);
}
void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(dst->type == GGML_TYPE_F32);
float start;
float stop;
float step;
memcpy(&start, (float *)dst->op_params + 0, sizeof(float));
memcpy(&stop, (float *)dst->op_params + 1, sizeof(float));
memcpy(&step, (float *)dst->op_params + 2, sizeof(float));
int64_t steps = (int64_t)ceil((stop - start) / step);
GGML_ASSERT(ggml_nelements(dst) == steps);
arange_f32_cuda(dst_d, dst->ne[0], start, step, stream);
}

5
ggml-cuda/arange.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ARANGE_BLOCK_SIZE 256
void ggml_cuda_op_arange(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

77
ggml-cuda/argsort.cu Normal file
View File

@ -0,0 +1,77 @@
#include "argsort.cuh"
template<typename T>
static inline __device__ void ggml_cuda_swap(T & a, T & b) {
T tmp = a;
a = b;
b = tmp;
}
template<ggml_sort_order order>
static __global__ void k_argsort_f32_i32(const float * x, int * dst, const int ncols) {
// bitonic sort
int col = threadIdx.x;
int row = blockIdx.y;
if (col >= ncols) return;
const float * x_row = x + row * ncols;
int * dst_row = dst + row * ncols;
// initialize indices
if (col < ncols) {
dst_row[col] = col;
}
__syncthreads();
for (int k = 2; k <= ncols; k *= 2) {
for (int j = k / 2; j > 0; j /= 2) {
int ixj = col ^ j;
if (ixj > col) {
if ((col & k) == 0) {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
} else {
if (order == GGML_SORT_ORDER_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
ggml_cuda_swap(dst_row[col], dst_row[ixj]);
}
}
}
__syncthreads();
}
}
}
static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
// bitonic sort requires ncols to be power of 2
GGML_ASSERT((ncols & (ncols - 1)) == 0);
const dim3 block_dims(ncols, 1, 1);
const dim3 block_nums(1, nrows, 1);
if (order == GGML_SORT_ORDER_ASC) {
k_argsort_f32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else if (order == GGML_SORT_ORDER_DESC) {
k_argsort_f32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
} else {
GGML_ASSERT(false);
}
}
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
}

3
ggml-cuda/argsort.cuh Normal file
View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

236
ggml-cuda/binbcast.cu Normal file
View File

@ -0,0 +1,236 @@
#include "binbcast.cuh"
static __device__ __forceinline__ float op_repeat(const float a, const float b) {
return b;
GGML_UNUSED(a);
}
static __device__ __forceinline__ float op_add(const float a, const float b) {
return a + b;
}
static __device__ __forceinline__ float op_mul(const float a, const float b) {
return a * b;
}
static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b;
}
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s10,*/ int s11, int s12, int s13) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) {
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
}
template<float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst,
int ne0, int ne1, int ne2, int ne3,
int ne10, int ne11, int ne12, int ne13,
/*int s0, */ int s1, int s2, int s3,
/*int s10,*/ int s11, int s12, int s13) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0);
const int i2 = (i/(ne1*ne0)) % ne2;
const int i1 = (i/ne0) % ne1;
const int i0 = i % ne0;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
const int i11 = i1 % ne11;
const int i12 = i2 % ne12;
const int i13 = i3 % ne13;
const size_t i_src0 = i3*s3 + i2*s2 + i1*s1;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
const size_t i_dst = i_src0;
const src0_t * src0_row = src0 + i_src0;
const src1_t * src1_row = src1 + i_src1;
dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10;
dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]);
}
template<float (*bin_op)(const float, const float)>
struct bin_bcast_cuda {
template<typename src0_t, typename src1_t, typename dst_t>
void operator()(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst,
const src0_t * src0_dd, const src1_t * src1_dd, dst_t * dst_dd,
cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
int nr0 = ne10/ne0;
int nr1 = ne11/ne1;
int nr2 = ne12/ne2;
int nr3 = ne13/ne3;
int nr[4] = { nr0, nr1, nr2, nr3 };
// collapse dimensions until first broadcast dimension
int64_t cne0[] = {ne0, ne1, ne2, ne3};
int64_t cne1[] = {ne10, ne11, ne12, ne13};
size_t cnb0[] = {nb0, nb1, nb2, nb3};
size_t cnb1[] = {nb10, nb11, nb12, nb13};
auto collapse = [](int64_t cne[]) {
cne[0] *= cne[1];
cne[1] = cne[2];
cne[2] = cne[3];
cne[3] = 1;
};
auto collapse_nb = [](size_t cnb[], const int64_t cne[]) {
cnb[1] *= cne[1];
cnb[2] *= cne[2];
cnb[3] *= cne[3];
};
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
}
if (i > 0) {
collapse_nb(cnb0, cne0);
collapse_nb(cnb1, cne1);
collapse(cne0);
collapse(cne1);
}
}
{
int64_t ne0 = cne0[0];
int64_t ne1 = cne0[1];
int64_t ne2 = cne0[2];
int64_t ne3 = cne0[3];
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb0[0];
size_t nb1 = cnb0[1];
size_t nb2 = cnb0[2];
size_t nb3 = cnb0[3];
size_t nb10 = cnb1[0];
size_t nb11 = cnb1[1];
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
size_t s10 = nb10 / sizeof(src1_t);
size_t s11 = nb11 / sizeof(src1_t);
size_t s12 = nb12 / sizeof(src1_t);
size_t s13 = nb13 / sizeof(src1_t);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0/2LL, 1LL);
dim3 block_dims;
block_dims.x = std::min<unsigned int>(hne0, block_size);
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2*ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums(
(hne0 + block_dims.x - 1) / block_dims.x,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2*ne3 + block_dims.z - 1) / block_dims.z
);
if (block_nums.z > 65535) {
// this is the maximum number of blocks in z direction, fallback to 1D grid kernel
int block_num = (ne0*ne1*ne2*ne3 + block_size - 1) / block_size;
k_bin_bcast_unravel<bin_op><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s10, */ s11, s12, s13);
} else {
k_bin_bcast<bin_op><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne0, ne1, ne2, ne3,
ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s10, */ s11, s12, s13);
}
}
}
};
template<class op>
static void ggml_cuda_op_bin_bcast(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const void * src1_dd, void * dst_dd, cudaStream_t stream) {
GGML_ASSERT(src1->type == GGML_TYPE_F32);
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const float *)src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (half *) dst_dd, stream);
} else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) {
op()(src0, src1, dst, (const half *) src0_dd, (const float *)src1_dd, (float *)dst_dd, stream);
} else {
fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__,
ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ASSERT(false);
}
}
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_repeat>>(dst, dst->src[0], dst, nullptr, dst->src[0]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_add>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_mul>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_bin_bcast<bin_bcast_cuda<op_div>>(dst->src[0], dst->src[1], dst, dst->src[0]->data, dst->src[1]->data, dst->data, ctx.stream());
}

6
ggml-cuda/binbcast.cuh Normal file
View File

@ -0,0 +1,6 @@
#include "common.cuh"
void ggml_cuda_op_repeat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_add(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_mul(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_div(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

35
ggml-cuda/clamp.cu Normal file
View File

@ -0,0 +1,35 @@
#include "clamp.cuh"
static __global__ void clamp_f32(const float * x, float * dst, const float min, const float max, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] < min ? min : (x[i] > max ? max : x[i]);
}
static void clamp_f32_cuda(const float * x, float * dst, const float min, const float max, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_CLAMP_BLOCK_SIZE - 1) / CUDA_CLAMP_BLOCK_SIZE;
clamp_f32<<<num_blocks, CUDA_CLAMP_BLOCK_SIZE, 0, stream>>>(x, dst, min, max, k);
}
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
float min;
float max;
memcpy(&min, dst->op_params, sizeof(float));
memcpy(&max, (float *) dst->op_params + 1, sizeof(float));
clamp_f32_cuda(src0_d, dst_d, min, max, ggml_nelements(src0), stream);
CUDA_CHECK(cudaGetLastError());
}

5
ggml-cuda/clamp.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CLAMP_BLOCK_SIZE 256
void ggml_cuda_op_clamp(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

550
ggml-cuda/common.cuh Normal file
View File

@ -0,0 +1,550 @@
#pragma once
#include "../ggml.h"
#include "../ggml-cuda.h"
#include <memory>
#if defined(GGML_USE_HIPBLAS)
#define GGML_COMMON_DECL_HIP
#define GGML_COMMON_IMPL_HIP
#else
#define GGML_COMMON_DECL_CUDA
#define GGML_COMMON_IMPL_CUDA
#endif
#include "../ggml-common.h"
#include <cstdio>
#include <array>
#include <cassert>
#include <cfloat>
#include <string>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate
#define cublasDestroy hipblasDestroy
#define cublasGemmEx hipblasGemmEx
#define cublasGemmBatchedEx hipblasGemmBatchedEx
#define cublasGemmStridedBatchedEx hipblasGemmStridedBatchedEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDataType_t hipblasDatatype_t //deprecated, new hipblasDatatype not in 5.6
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled
#define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEventSynchronize hipEventSynchronize
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaHostRegister hipHostRegister
#define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister
#define cudaLaunchHostFunc hipLaunchHostFunc
#ifdef GGML_HIP_UMA
#define cudaMalloc hipMallocManaged
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size)
#else
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#endif
#define cudaMemcpy hipMemcpy
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyPeerAsync hipMemcpyPeerAsync
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaMemsetAsync hipMemsetAsync
#define cudaMemGetInfo hipMemGetInfo
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamDestroy hipStreamDestroy
#define cudaStreamFireAndForget hipStreamFireAndForget
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event, flags) hipStreamWaitEvent(stream, event, flags)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#define __trap abort
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_STATUS_NOT_INITIALIZED HIPBLAS_STATUS_NOT_INITIALIZED
#define CUBLAS_STATUS_ALLOC_FAILED HIPBLAS_STATUS_ALLOC_FAILED
#define CUBLAS_STATUS_INVALID_VALUE HIPBLAS_STATUS_INVALID_VALUE
#define CUBLAS_STATUS_ARCH_MISMATCH HIPBLAS_STATUS_ARCH_MISMATCH
#define CUBLAS_STATUS_MAPPING_ERROR HIPBLAS_STATUS_MAPPING_ERROR
#define CUBLAS_STATUS_EXECUTION_FAILED HIPBLAS_STATUS_EXECUTION_FAILED
#define CUBLAS_STATUS_INTERNAL_ERROR HIPBLAS_STATUS_INTERNAL_ERROR
#define CUBLAS_STATUS_NOT_SUPPORTED HIPBLAS_STATUS_NOT_SUPPORTED
#else
#include <cuda_runtime.h>
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
#define CUBLAS_TF32_TENSOR_OP_MATH CUBLAS_TENSOR_OP_MATH
#define CUBLAS_COMPUTE_16F CUDA_R_16F
#define CUBLAS_COMPUTE_32F CUDA_R_32F
#define cublasComputeType_t cudaDataType_t
#endif // CUDART_VERSION < 11020
#endif // defined(GGML_USE_HIPBLAS)
#define STRINGIZE_IMPL(...) #__VA_ARGS__
#define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__)
#define WARP_SIZE 32
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
// define this if you want to always fallback to MMQ kernels and not use cuBLAS for matrix multiplication
// on modern hardware, using cuBLAS is recommended as it utilizes F16 tensor cores which are very performant
// for large computational tasks. the drawback is that this requires some extra amount of VRAM:
// - 7B quantum model: +100-200 MB
// - 13B quantum model: +200-400 MB
//
//#define GGML_CUDA_FORCE_MMQ
// TODO: improve this to be correct for more hardware
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
#if !defined(GGML_CUDA_FORCE_MMQ)
#define CUDA_USE_TENSOR_CORES
#endif
#define MMVQ_MAX_BATCH_SIZE 8 // max batch size to use MMVQ kernels
#define MMQ_MAX_BATCH_SIZE 32 // max batch size to use MMQ kernels when tensor cores are available
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define GGML_CUDA_MAX_STREAMS 8
[[noreturn]]
void ggml_cuda_error(const char * stmt, const char * func, const char * file, int line, const char * msg);
#define CUDA_CHECK_GEN(err, success, error_fn) \
do { \
auto err_ = (err); \
if (err_ != (success)) { \
ggml_cuda_error(#err, __func__, __FILE__, __LINE__, error_fn(err_)); \
} \
} while (0)
#define CUDA_CHECK(err) CUDA_CHECK_GEN(err, cudaSuccess, cudaGetErrorString)
#if CUDART_VERSION >= 12000
static const char * cublas_get_error_str(const cublasStatus_t err) {
return cublasGetStatusString(err);
}
#else
static const char * cublas_get_error_str(const cublasStatus_t err) {
switch (err) {
case CUBLAS_STATUS_SUCCESS: return "CUBLAS_STATUS_SUCCESS";
case CUBLAS_STATUS_NOT_INITIALIZED: return "CUBLAS_STATUS_NOT_INITIALIZED";
case CUBLAS_STATUS_ALLOC_FAILED: return "CUBLAS_STATUS_ALLOC_FAILED";
case CUBLAS_STATUS_INVALID_VALUE: return "CUBLAS_STATUS_INVALID_VALUE";
case CUBLAS_STATUS_ARCH_MISMATCH: return "CUBLAS_STATUS_ARCH_MISMATCH";
case CUBLAS_STATUS_MAPPING_ERROR: return "CUBLAS_STATUS_MAPPING_ERROR";
case CUBLAS_STATUS_EXECUTION_FAILED: return "CUBLAS_STATUS_EXECUTION_FAILED";
case CUBLAS_STATUS_INTERNAL_ERROR: return "CUBLAS_STATUS_INTERNAL_ERROR";
case CUBLAS_STATUS_NOT_SUPPORTED: return "CUBLAS_STATUS_NOT_SUPPORTED";
default: return "unknown error";
}
}
#endif // CUDART_VERSION >= 12000
#define CUBLAS_CHECK(err) CUDA_CHECK_GEN(err, CUBLAS_STATUS_SUCCESS, cublas_get_error_str)
#if !defined(GGML_USE_HIPBLAS)
static const char * cu_get_error_str(CUresult err) {
const char * err_str;
cuGetErrorString(err, &err_str);
return err_str;
}
#define CU_CHECK(err) CUDA_CHECK_GEN(err, CUDA_SUCCESS, cu_get_error_str)
#endif
#if CUDART_VERSION >= 11100
#define GGML_CUDA_ASSUME(x) __builtin_assume(x)
#else
#define GGML_CUDA_ASSUME(x)
#endif // CUDART_VERSION >= 11100
#ifdef GGML_CUDA_F16
typedef half dfloat; // dequantize float
typedef half2 dfloat2;
#else
typedef float dfloat; // dequantize float
typedef float2 dfloat2;
#endif //GGML_CUDA_F16
[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
printf("%s:%d: ERROR: HIP kernel %s has no device code compatible with HIP arch %d.\n",
file_name, line, function_name, arch);
GGML_UNUSED(arch_list);
#else
printf("%s:%d: ERROR: CUDA kernel %s has no device code compatible with CUDA arch %d. ggml-cuda.cu was compiled for: %s\n",
file_name, line, function_name, arch, arch_list);
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
__trap();
GGML_UNUSED(no_device_code); // suppress unused function warning
}
#ifdef __CUDA_ARCH__
#define NO_DEVICE_CODE no_device_code(__FILE__, __LINE__, __FUNCTION__, __CUDA_ARCH__, STRINGIZE(__CUDA_ARCH_LIST__))
#else
#define NO_DEVICE_CODE //GGML_ASSERT(false && "NO_DEVICE_CODE not valid in host code.")
#endif // __CUDA_ARCH__
static __device__ __forceinline__ float warp_reduce_sum(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x += __shfl_xor_sync(0xffffffff, x, mask, 32);
}
return x;
}
static __device__ __forceinline__ float2 warp_reduce_sum(float2 a) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a.x += __shfl_xor_sync(0xffffffff, a.x, mask, 32);
a.y += __shfl_xor_sync(0xffffffff, a.y, mask, 32);
}
return a;
}
#ifdef GGML_CUDA_F16
static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32));
}
return a;
#else
GGML_UNUSED(a);
NO_DEVICE_CODE;
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
}
#endif // GGML_CUDA_F16
static __device__ __forceinline__ float warp_reduce_max(float x) {
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
x = fmaxf(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
}
return x;
}
//static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
//#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
//#pragma unroll
// for (int mask = 16; mask > 0; mask >>= 1) {
// x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
// }
// return x;
//#else
// GGML_UNUSED(x);
// NO_DEVICE_CODE;
//#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
//}
#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300
#if defined(__gfx1100__) || defined(__gfx1101__) || defined(__gfx1102__) || defined(__gfx1103__) || \
defined(__gfx1150__) || defined(__gfx1151__)
#define RDNA3
#endif
#if defined(__gfx1030__) || defined(__gfx1031__) || defined(__gfx1032__) || defined(__gfx1033__) || \
defined(__gfx1034__) || defined(__gfx1035__) || defined(__gfx1036__) || defined(__gfx1037__)
#define RDNA2
#endif
#ifndef __has_builtin
#define __has_builtin(x) 0
#endif
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
typedef uint8_t uint8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
#if __has_builtin(__builtin_elementwise_sub_sat)
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int &>(c);
#else
int8x4_t c;
int16_t tmp;
#pragma unroll
for (int i = 0; i < 4; i++) {
tmp = va[i] - vb[i];
if(tmp > std::numeric_limits<int8_t>::max()) tmp = std::numeric_limits<int8_t>::max();
if(tmp < std::numeric_limits<int8_t>::min()) tmp = std::numeric_limits<int8_t>::min();
c[i] = tmp;
}
return reinterpret_cast<int &>(c);
#endif // __has_builtin(__builtin_elementwise_sub_sat)
}
static __device__ __forceinline__ int __vsub4(const int a, const int b) {
return __vsubss4(a, b);
}
static __device__ __forceinline__ unsigned int __vcmpeq4(unsigned int a, unsigned int b) {
const uint8x4_t& va = reinterpret_cast<const uint8x4_t&>(a);
const uint8x4_t& vb = reinterpret_cast<const uint8x4_t&>(b);
unsigned int c;
uint8x4_t& vc = reinterpret_cast<uint8x4_t&>(c);
#pragma unroll
for (int i = 0; i < 4; ++i) {
vc[i] = va[i] == vb[i] ? 0xff : 0x00;
}
return c;
}
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(RDNA3)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;
int tmp2;
asm("\n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
v_add3_u32 %0, %1, %2, %0 \n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
v_add3_u32 %0, %1, %2, %0 \n \
"
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
: "v"(a), "v"(b)
);
#else
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
return c;
}
#endif // defined(GGML_USE_HIPBLAS)
// TODO: move to ggml-common.h
static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113};
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
//////////////////////
struct ggml_cuda_device_info {
int device_count;
struct cuda_device_info {
int cc; // compute capability
size_t smpb; // max. shared memory per block
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
size_t total_vram;
};
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
std::array<float, GGML_CUDA_MAX_DEVICES> default_tensor_split = {};
};
const ggml_cuda_device_info & ggml_cuda_info();
void ggml_cuda_set_device(int device);
int ggml_cuda_get_device();
struct ggml_cuda_pool {
virtual ~ggml_cuda_pool() = default;
virtual void * alloc(size_t size, size_t * actual_size) = 0;
virtual void free(void * ptr, size_t size) = 0;
};
template<typename T>
struct ggml_cuda_pool_alloc {
ggml_cuda_pool * pool = nullptr;
T * ptr = nullptr;
size_t actual_size = 0;
ggml_cuda_pool_alloc() = default;
explicit ggml_cuda_pool_alloc(ggml_cuda_pool & pool) : pool(&pool) {
}
ggml_cuda_pool_alloc(ggml_cuda_pool & pool, size_t size) : pool(&pool) {
alloc(size);
}
~ggml_cuda_pool_alloc() {
if (ptr != nullptr) {
pool->free(ptr, actual_size);
}
}
// size is in number of elements
T * alloc(size_t size) {
GGML_ASSERT(pool != nullptr);
GGML_ASSERT(ptr == nullptr);
ptr = (T *) pool->alloc(size * sizeof(T), &this->actual_size);
return ptr;
}
T * alloc(ggml_cuda_pool & pool, size_t size) {
this->pool = &pool;
return alloc(size);
}
T * get() {
return ptr;
}
ggml_cuda_pool_alloc(const ggml_cuda_pool_alloc &) = delete;
ggml_cuda_pool_alloc(ggml_cuda_pool_alloc &&) = delete;
ggml_cuda_pool_alloc& operator=(const ggml_cuda_pool_alloc &) = delete;
ggml_cuda_pool_alloc& operator=(ggml_cuda_pool_alloc &&) = delete;
};
// backend interface
struct ggml_tensor_extra_gpu {
void * data_device[GGML_CUDA_MAX_DEVICES]; // 1 pointer for each device for split tensors
cudaEvent_t events[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS]; // events for synchronizing multiple GPUs
};
struct ggml_backend_cuda_context {
int device;
std::string name;
cudaEvent_t copy_event = nullptr;
cudaStream_t streams[GGML_CUDA_MAX_DEVICES][GGML_CUDA_MAX_STREAMS] = { { nullptr } };
cublasHandle_t cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
explicit ggml_backend_cuda_context(int device) :
device(device),
name(GGML_CUDA_NAME + std::to_string(device)) {
}
~ggml_backend_cuda_context() {
if (copy_event != nullptr) {
CUDA_CHECK(cudaEventDestroy(copy_event));
}
for (int i = 0; i < GGML_CUDA_MAX_DEVICES; ++i) {
for (int j = 0; j < GGML_CUDA_MAX_STREAMS; ++j) {
if (streams[i][j] != nullptr) {
CUDA_CHECK(cudaStreamDestroy(streams[i][j]));
}
}
if (cublas_handles[i] != nullptr) {
CUBLAS_CHECK(cublasDestroy(cublas_handles[i]));
}
}
}
cudaStream_t stream(int device, int stream) {
if (streams[device][stream] == nullptr) {
ggml_cuda_set_device(device);
CUDA_CHECK(cudaStreamCreateWithFlags(&streams[device][stream], cudaStreamNonBlocking));
}
return streams[device][stream];
}
cudaStream_t stream() {
return stream(device, 0);
}
cublasHandle_t cublas_handle(int device) {
if (cublas_handles[device] == nullptr) {
ggml_cuda_set_device(device);
CUBLAS_CHECK(cublasCreate(&cublas_handles[device]));
CUBLAS_CHECK(cublasSetMathMode(cublas_handles[device], CUBLAS_TF32_TENSOR_OP_MATH));
}
return cublas_handles[device];
}
cublasHandle_t cublas_handle() {
return cublas_handle(device);
}
// pool
std::unique_ptr<ggml_cuda_pool> pools[GGML_CUDA_MAX_DEVICES];
static std::unique_ptr<ggml_cuda_pool> new_pool_for_device(int device);
ggml_cuda_pool & pool(int device) {
if (pools[device] == nullptr) {
pools[device] = new_pool_for_device(device);
}
return *pools[device];
}
ggml_cuda_pool & pool() {
return pool(device);
}
};

49
ggml-cuda/concat.cu Normal file
View File

@ -0,0 +1,49 @@
#include "concat.cuh"
static __global__ void concat_f32(const float * x,const float * y, float * dst, const int ne0, const int ne02) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (blockIdx.z < ne02) { // src0
int offset_src =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
dst[offset_dst] = x[offset_src];
} else {
int offset_src =
nidx +
blockIdx.y * ne0 +
(blockIdx.z - ne02) * ne0 * gridDim.y;
dst[offset_dst] = y[offset_src];
}
}
static void concat_f32_cuda(const float * x, const float * y, float * dst, const int ne0, int ne1, int ne2, int ne02, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_CONCAT_BLOCK_SIZE - 1) / CUDA_CONCAT_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2);
concat_f32<<<gridDim, CUDA_CONCAT_BLOCK_SIZE, 0, stream>>>(x, y, dst, ne0, ne02);
}
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
for (int i3 = 0; i3 < dst->ne[3]; i3++) {
concat_f32_cuda(src0_d + i3 * (src0->nb[3] / 4), src1_d + i3 * (src1->nb[3] / 4), dst_d + i3 * (dst->nb[3] / 4), dst->ne[0], dst->ne[1], dst->ne[2], src0->ne[2], stream);
}
}

5
ggml-cuda/concat.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_CONCAT_BLOCK_SIZE 256
void ggml_cuda_op_concat(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

783
ggml-cuda/convert.cu Normal file
View File

@ -0,0 +1,783 @@
#include "convert.cuh"
#include "dequantize.cuh"
#define CUDA_Q8_0_NE_ALIGN 2048
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (i >= k) {
return;
}
const int ib = i/qk; // block index
const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
dequantize_kernel(vx, ib, iqs, v);
y[iybs + iqs + 0] = v.x;
y[iybs + iqs + y_offset] = v.y;
}
template <bool need_check>
static __global__ void dequantize_block_q8_0_f16(const void * __restrict__ vx, half * __restrict__ y, const int k) {
#if __CUDA_ARCH__ >= CC_PASCAL
constexpr int nint = CUDA_Q8_0_NE_ALIGN/sizeof(int) + WARP_SIZE;
const int i0 = CUDA_Q8_0_NE_ALIGN*blockIdx.x;
const int * x0 = ((int *) vx) + blockIdx.x * nint;
half2 * y2 = (half2 *) (y + i0);
__shared__ int vals[nint];
#pragma unroll
for (int ix0 = 0; ix0 < nint; ix0 += WARP_SIZE) {
if (need_check && i0*sizeof(block_q8_0)/QK8_0 + sizeof(int)*(ix0 + threadIdx.x) >= k*sizeof(block_q8_0)/QK8_0) {
break;
}
const int ix = ix0 + threadIdx.x;
vals[ix] = x0[ix];
}
#pragma unroll
for (int iy = 0; iy < CUDA_Q8_0_NE_ALIGN; iy += 2*WARP_SIZE) {
if (need_check && i0 + iy + 2*threadIdx.x >= k) {
return;
}
const half * b0 = ((const half *) vals) + (sizeof(block_q8_0)/sizeof(half)) * ((iy + 2*threadIdx.x)/QK8_0);
const half d = *b0;
const char2 qs = ((const char2 *) (b0 + 1))[threadIdx.x % (QK8_0/2)];
y2[iy/2 + threadIdx.x] = __hmul2(make_half2(qs.x, qs.y), __half2half2(d));
}
#else
GGML_UNUSED(vx);
GGML_UNUSED(y);
GGML_UNUSED(k);
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ >= CC_PASCAL
}
template<typename dst_t>
static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
if (ib >= nb32) {
return;
}
dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_0 * x = (const block_q4_0 *)vx + ib;
const float d = __half2float(x->d);
const float dm = -8*d;
const uint8_t * q = x->qs + 4*il;
for (int l = 0; l < 4; ++l) {
y[l+ 0] = d * (q[l] & 0xF) + dm;
y[l+16] = d * (q[l] >> 4) + dm;
}
}
template<typename dst_t>
static __global__ void dequantize_block_q4_1(const void * __restrict__ vx, dst_t * __restrict__ yy, int nb32) {
const int i = blockIdx.x;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int ib = 8*i + ir;
if (ib >= nb32) {
return;
}
dst_t * y = yy + 256*i + 32*ir + 4*il;
const block_q4_1 * x = (const block_q4_1 *)vx + ib;
const float2 d = __half22float2(x->dm);
const uint8_t * q = x->qs + 4*il;
for (int l = 0; l < 4; ++l) {
y[l+ 0] = d.x * (q[l] & 0xF) + d.y;
y[l+16] = d.x * (q[l] >> 4) + d.y;
}
}
//================================== k-quants
template<typename dst_t>
static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_q2_K * x = (const block_q2_K *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;
const uint8_t q = x[i].qs[32*n + l];
dst_t * y = yy + i*QK_K + 128*n;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
#else
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is);
dst_t * y = yy + i*QK_K + 16*is + il;
float dall = __low2half(x[i].dm);
float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_q3_K * x = (const block_q3_K *) vx;
#if QK_K == 256
const int r = threadIdx.x/4;
const int tid = r/2;
const int is0 = r%2;
const int l0 = 16*is0 + 4*(threadIdx.x%4);
const int n = tid / 4;
const int j = tid - 4*n;
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int shift = 2*j;
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
(x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
float d_all = x[i].d;
float dl = d_all * (us - 32);
dst_t * y = yy + i*QK_K + 128*n + 32*j;
const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
#else
const int tid = threadIdx.x;
const int is = tid/16; // 0 or 1
const int il = tid%16; // 0...15
const int im = il/8; // 0...1
const int in = il%8; // 0...7
dst_t * y = yy + i*QK_K + 16*is + il;
const uint8_t q = x[i].qs[il] >> (2*is);
const uint8_t h = x[i].hmask[in] >> (2*is + im);
const float d = (float)x[i].d;
if (is == 0) {
y[ 0] = d * ((x[i].scales[0] & 0xF) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
y[32] = d * ((x[i].scales[1] & 0xF) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
} else {
y[ 0] = d * ((x[i].scales[0] >> 4) - 8) * ((int8_t)((q >> 0) & 3) - ((h >> 0) & 1 ? 0 : 4));
y[32] = d * ((x[i].scales[1] >> 4) - 8) * ((int8_t)((q >> 4) & 3) - ((h >> 4) & 1 ? 0 : 4));
}
#endif
}
#if QK_K == 256
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
} else {
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
#endif
template<typename dst_t>
static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x;
#if QK_K == 256
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int is = 2*il;
const int n = 4;
dst_t * y = yy + i*QK_K + 64*il + n*ir;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint8_t * q = x[i].qs + 32*il + n*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
#else
const int tid = threadIdx.x;
const uint8_t * q = x[i].qs;
dst_t * y = yy + i*QK_K;
const float d = (float)x[i].dm[0];
const float m = (float)x[i].dm[1];
y[tid+ 0] = d * (x[i].scales[0] & 0xF) * (q[tid] & 0xF) - m * (x[i].scales[0] >> 4);
y[tid+32] = d * (x[i].scales[1] & 0xF) * (q[tid] >> 4) - m * (x[i].scales[1] >> 4);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q5_K * x = (const block_q5_K *) vx;
const int i = blockIdx.x;
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int il = tid/16; // il is in 0...3
const int ir = tid%16; // ir is in 0...15
const int is = 2*il; // is is in 0...6
dst_t * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
uint8_t hm = 1 << (2*il);
y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
hm <<= 1;
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
#else
const int tid = threadIdx.x;
const uint8_t q = x[i].qs[tid];
const int im = tid/8; // 0...3
const int in = tid%8; // 0...7
const int is = tid/16; // 0 or 1
const uint8_t h = x[i].qh[in] >> im;
const float d = x[i].d;
dst_t * y = yy + i*QK_K + tid;
y[ 0] = d * x[i].scales[is+0] * ((q & 0xF) - ((h >> 0) & 1 ? 0 : 16));
y[32] = d * x[i].scales[is+2] * ((q >> 4) - ((h >> 4) & 1 ? 0 : 16));
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x;
#if QK_K == 256
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int ip = tid/32; // ip is 0 or 1
const int il = tid - 32*ip; // 0...32
const int is = 8*ip + il/16;
dst_t * y = yy + i*QK_K + 128*ip + il;
const float d = x[i].d;
const uint8_t * ql = x[i].ql + 64*ip + il;
const uint8_t qh = x[i].qh[32*ip + il];
const int8_t * sc = x[i].scales + is;
y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
#else
// assume 32 threads
const int tid = threadIdx.x;
const int ip = tid/16; // 0 or 1
const int il = tid - 16*ip; // 0...15
dst_t * y = yy + i*QK_K + 16*ip + il;
const float d = x[i].d;
const uint8_t ql = x[i].ql[16*ip + il];
const uint8_t qh = x[i].qh[il] >> (2*ip);
const int8_t * sc = x[i].scales;
y[ 0] = d * sc[ip+0] * ((int8_t)((ql & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = d * sc[ip+2] * ((int8_t)((ql >> 4) | (((qh >> 4) & 3) << 4)) - 32);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq2_xxs * x = (const block_iq2_xxs *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * aux8 = (const uint8_t *)q2;
const uint8_t * grid = (const uint8_t *)(iq2xxs_grid + aux8[il]);
const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq2_xs * x = (const block_iq2_xs *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint16_t * q2 = x[i].qs + 4*ib;
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq2_s * x = (const block_iq2_s *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
const float d = (float)x[i].d * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
#else
assert(false);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * q3 = x[i].qs + 8*ib;
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
const uint32_t aux32 = gas[0] | (gas[1] << 16);
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq3_s * x = (const block_iq3_s *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const uint8_t * qs = x[i].qs + 8*ib;
const uint8_t * grid1 = (const uint8_t *)(iq3s_grid + (qs[2*il+0] | ((x[i].qh[ib] << (8-2*il)) & 256)));
const uint8_t * grid2 = (const uint8_t *)(iq3s_grid + (qs[2*il+1] | ((x[i].qh[ib] << (7-2*il)) & 256)));
const float d = (float)x[i].d * (1 + 2*((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf));
const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq1_s * x = (const block_iq1_s *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
const float delta = x[i].qh[ib] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA;
const float d = (float)x[i].d * (2*((x[i].qh[ib] >> 12) & 7) + 1);
uint32_t grid32[2]; const int8_t * q = (const int8_t *)grid32;
grid32[0] = iq1s_grid_gpu[x[i].qs[4*ib+il] | (((x[i].qh[ib] >> 3*il) & 7) << 8)];
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) {
y[j] = d * (q[j] + delta);
}
#else
assert(false);
#endif
}
template<typename dst_t>
static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq4_nl * x = (const block_iq4_nl *) vx + i*(QK_K/QK4_NL);
const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[ib].qs + 4*il;
const float d = (float)x[ib].d;
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
#if QK_K != 64
template<typename dst_t>
static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq4_xs * x = (const block_iq4_xs *)vx;
const int tid = threadIdx.x;
const int il = tid/8; // 0...3
const int ib = tid%8; // 0...7
dst_t * y = yy + i*QK_K + 32*ib + 4*il;
const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
const float d = (float)x[i].d * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
for (int j = 0; j < 4; ++j) {
y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
}
}
#endif
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
static void dequantize_block_q8_0_f16_cuda(const void * __restrict__ vx, half * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_Q8_0_NE_ALIGN - 1) / CUDA_Q8_0_NE_ALIGN;
if (k % CUDA_Q8_0_NE_ALIGN == 0) {
const bool need_check = false;
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
} else {
const bool need_check = true;
dequantize_block_q8_0_f16<need_check><<<num_blocks, WARP_SIZE, 0, stream>>>(vx, y, k);
}
}
template<typename dst_t>
static void dequantize_row_q2_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q2_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q2_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
template<typename dst_t>
static void dequantize_row_q3_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q3_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q3_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
template<typename dst_t>
static void dequantize_row_q4_0_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_0<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
static void dequantize_row_q4_1_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb32 = k / 32;
const int nb = (k + 255) / 256;
dequantize_block_q4_1<<<nb, 32, 0, stream>>>(vx, y, nb32);
}
template<typename dst_t>
static void dequantize_row_q4_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_q4_K<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_q5_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q5_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q5_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
template<typename dst_t>
static void dequantize_row_q6_K_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
#if QK_K == 256
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
#else
dequantize_block_q6_K<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
template<typename dst_t>
static void dequantize_row_iq2_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq2_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq2_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_xxs<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq3_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq3_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq1_s_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = k / QK_K;
dequantize_block_iq1_s<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq4_nl_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
}
template<typename dst_t>
static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
const int nb = (k + QK_K - 1) / QK_K;
#if QK_K == 64
dequantize_block_iq4_nl<<<nb, 32, 0, stream>>>(vx, y);
#else
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
#endif
}
template <typename src_t, typename dst_t>
static __global__ void convert_unary(const void * __restrict__ vx, dst_t * __restrict__ y, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
const src_t * x = (src_t *) vx;
y[i] = x[i];
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
}
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
int id;
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
CUDA_CHECK(cudaGetDevice(&id));
if (ggml_cuda_info().devices[id].cc >= CC_PASCAL) {
return dequantize_block_q8_0_f16_cuda;
}
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
default:
return nullptr;
}
}
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
return dequantize_row_q4_0_cuda;
case GGML_TYPE_Q4_1:
return dequantize_row_q4_1_cuda;
case GGML_TYPE_Q5_0:
return dequantize_block_cuda<QK5_0, QR5_0, dequantize_q5_0>;
case GGML_TYPE_Q5_1:
return dequantize_block_cuda<QK5_1, QR5_1, dequantize_q5_1>;
case GGML_TYPE_Q8_0:
return dequantize_block_cuda<QK8_0, QR8_0, dequantize_q8_0>;
case GGML_TYPE_Q2_K:
return dequantize_row_q2_K_cuda;
case GGML_TYPE_Q3_K:
return dequantize_row_q3_K_cuda;
case GGML_TYPE_Q4_K:
return dequantize_row_q4_K_cuda;
case GGML_TYPE_Q5_K:
return dequantize_row_q5_K_cuda;
case GGML_TYPE_Q6_K:
return dequantize_row_q6_K_cuda;
case GGML_TYPE_IQ2_XXS:
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ2_S:
return dequantize_row_iq2_s_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_IQ1_S:
return dequantize_row_iq1_s_cuda;
case GGML_TYPE_IQ4_NL:
return dequantize_row_iq4_nl_cuda;
case GGML_TYPE_IQ4_XS:
return dequantize_row_iq4_xs_cuda;
case GGML_TYPE_IQ3_S:
return dequantize_row_iq3_s_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
default:
return nullptr;
}
}

13
ggml-cuda/convert.cuh Normal file
View File

@ -0,0 +1,13 @@
#include "common.cuh"
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
template<typename T>
using to_t_cuda_t = void (*)(const void * __restrict__ x, T * __restrict__ y, int k, cudaStream_t stream);
typedef to_t_cuda_t<float> to_fp32_cuda_t;
typedef to_t_cuda_t<half> to_fp16_cuda_t;
to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type);
to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type);

461
ggml-cuda/cpy.cu Normal file
View File

@ -0,0 +1,461 @@
#include "cpy.cuh"
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
static __device__ void cpy_1_f32_f32(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f32_f16(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
half * dsti = (half *) cdsti;
*dsti = __float2half(*xi);
}
static __device__ void cpy_1_f16_f16(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
half * dsti = (half *) cdsti;
*dsti = *xi;
}
static __device__ void cpy_1_f16_f32(const char * cxi, char * cdsti) {
const half * xi = (const half *) cxi;
float * dsti = (float *) cdsti;
*dsti = *xi;
}
template <cpy_kernel_t cpy_1>
static __global__ void cpy_f32_f16(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
// determine indices i03/i13, i02/i12, i01/i11, i00/i10 as a function of index i of flattened tensor
// then combine those indices with the corresponding byte offsets to get the total offsets
const int64_t i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13 * nb13;
cpy_1(cx + x_offset, cdst + dst_offset);
}
static __device__ void cpy_blck_f32_q8_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q8_0 * dsti = (block_q8_0 *) cdsti;
float amax = 0.0f; // absolute max
for (int j = 0; j < QK8_0; j++) {
const float v = xi[j];
amax = fmaxf(amax, fabsf(v));
}
const float d = amax / ((1 << 7) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK8_0; ++j) {
const float x0 = xi[j]*id;
dsti->qs[j] = roundf(x0);
}
}
static __device__ void cpy_blck_f32_q4_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_0 * dsti = (block_q4_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -8;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
for (int j = 0; j < QK4_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_0/2 + j]*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 8.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 8.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q4_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q4_1 * dsti = (block_q4_1 *) cdsti;
float vmin = FLT_MAX;
float vmax = -FLT_MAX;
for (int j = 0; j < QK4_1; ++j) {
const float v = xi[j];
if (v < vmin) vmin = v;
if (v > vmax) vmax = v;
}
const float d = (vmax - vmin) / ((1 << 4) - 1);
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = vmin;
for (int j = 0; j < QK4_1/2; ++j) {
const float x0 = (xi[0 + j] - vmin)*id;
const float x1 = (xi[QK4_1/2 + j] - vmin)*id;
const uint8_t xi0 = min(15, (int8_t)(x0 + 0.5f));
const uint8_t xi1 = min(15, (int8_t)(x1 + 0.5f));
dsti->qs[j] = xi0;
dsti->qs[j] |= xi1 << 4;
}
}
static __device__ void cpy_blck_f32_q5_0(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_0 * dsti = (block_q5_0 *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK5_0; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
const float d = vmax / -16;
const float id = d ? 1.0f/d : 0.0f;
dsti->d = d;
uint32_t qh = 0;
for (int j = 0; j < QK5_0/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK5_0/2 + j]*id;
const uint8_t xi0 = min(31, (int8_t)(x0 + 16.5f));
const uint8_t xi1 = min(31, (int8_t)(x1 + 16.5f));
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_0/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ void cpy_blck_f32_q5_1(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_q5_1 * dsti = (block_q5_1 *) cdsti;
float min = xi[0];
float max = xi[0];
for (int j = 1; j < QK5_1; ++j) {
const float v = xi[j];
min = v < min ? v : min;
max = v > max ? v : max;
}
const float d = (max - min) / 31;
const float id = d ? 1.0f/d : 0.0f;
dsti->dm.x = d;
dsti->dm.y = min;
uint32_t qh = 0;
for (int j = 0; j < QK5_1/2; ++j) {
const float x0 = (xi[0 + j] - min)*id;
const float x1 = (xi[QK5_1/2 + j] - min)*id;
const uint8_t xi0 = (uint8_t)(x0 + 0.5f);
const uint8_t xi1 = (uint8_t)(x1 + 0.5f);
dsti->qs[j] = (xi0 & 0xf) | ((xi1 & 0xf) << 4);
qh |= ((xi0 & 0x10u) >> 4) << (j + 0);
qh |= ((xi1 & 0x10u) >> 4) << (j + QK5_1/2);
}
memcpy(dsti->qh, &qh, sizeof(qh));
}
static __device__ __forceinline__ int best_index_int8(int n, const int8_t * val, float x) {
if (x <= val[0]) return 0;
if (x >= val[n-1]) return n-1;
int ml = 0, mu = n-1;
while (mu-ml > 1) {
int mav = (ml+mu)/2;
if (x < val[mav]) mu = mav; else ml = mav;
}
return x - val[mu-1] < val[mu] - x ? mu-1 : mu;
}
static __device__ void cpy_blck_f32_iq4_nl(const char * cxi, char * cdsti) {
const float * xi = (const float *) cxi;
block_iq4_nl * dsti = (block_iq4_nl *) cdsti;
float amax = 0.0f;
float vmax = 0.0f;
for (int j = 0; j < QK4_NL; ++j) {
const float v = xi[j];
if (amax < fabsf(v)) {
amax = fabsf(v);
vmax = v;
}
}
float d = vmax / kvalues_iq4nl[0];
const float id = d ? 1.0f/d : 0.0f;
float sumqx = 0, sumq2 = 0;
for (int j = 0; j < QK4_NL/2; ++j) {
const float x0 = xi[0 + j]*id;
const float x1 = xi[QK4_NL/2 + j]*id;
const uint8_t xi0 = best_index_int8(16, kvalues_iq4nl, x0);
const uint8_t xi1 = best_index_int8(16, kvalues_iq4nl, x1);
dsti->qs[j] = xi0 | (xi1 << 4);
const float v0 = kvalues_iq4nl[xi0];
const float v1 = kvalues_iq4nl[xi1];
const float w0 = xi[0 + j]*xi[0 + j];
const float w1 = xi[QK4_NL/2 + j]*xi[QK4_NL/2 + j];
sumqx += w0*v0*xi[j] + w1*v1*xi[QK4_NL/2 + j];
sumq2 += w0*v0*v0 + w1*v1*v1;
}
dsti->d = sumq2 > 0 ? sumqx/sumq2 : d;
}
template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int nb12, const int nb13) {
const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) {
return;
}
const int i03 = i/(ne00 * ne01 * ne02);
const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int i13 = i/(ne10 * ne11 * ne12);
const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset);
}
static void ggml_cpy_f16_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_f32_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f32_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0);
const int num_blocks = ne / QK8_0;
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0);
const int num_blocks = ne / QK4_0;
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0);
const int num_blocks = ne / QK4_1;
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0);
const int num_blocks = ne / QK5_0;
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0);
const int num_blocks = ne / QK5_1;
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0);
const int num_blocks = ne / QK4_NL;
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
static void ggml_cpy_f16_f16_cuda(
const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
cpy_f32_f16<cpy_1_f16_f16><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
}
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1) {
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2];
//GGML_ASSERT(src0->ne[3] == 1);
const int64_t nb00 = src0->nb[0];
const int64_t nb01 = src0->nb[1];
const int64_t nb02 = src0->nb[2];
const int64_t nb03 = src0->nb[3];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne12 = src1->ne[2];
//GGML_ASSERT(src1->ne[3] == 1);
const int64_t nb10 = src1->nb[0];
const int64_t nb11 = src1->nb[1];
const int64_t nb12 = src1->nb[2];
const int64_t nb13 = src1->nb[3];
cudaStream_t main_stream = ctx.stream();
char * src0_ddc = (char *) src0->data;
char * src1_ddc = (char *) src1->data;
if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f32_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f32_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q8_0) {
ggml_cpy_f32_q8_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_0) {
ggml_cpy_f32_q4_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q4_1) {
ggml_cpy_f32_q4_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_0) {
ggml_cpy_f32_q5_0_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_IQ4_NL) {
ggml_cpy_f32_iq4_nl_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_Q5_1) {
ggml_cpy_f32_q5_1_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16) {
ggml_cpy_f16_f16_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_f16_f32_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream);
} else {
fprintf(stderr, "%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));
GGML_ASSERT(false);
}
}
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
ggml_cuda_cpy(ctx, src0, dst);
}

7
ggml-cuda/cpy.cuh Normal file
View File

@ -0,0 +1,7 @@
#include "common.cuh"
#define CUDA_CPY_BLOCK_SIZE 32
void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1);
void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

103
ggml-cuda/dequantize.cuh Normal file
View File

@ -0,0 +1,103 @@
#include "common.cuh"
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
const dfloat d = x[ib].d;
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
#ifdef GGML_CUDA_F16
v = __hsub2(v, {8.0f, 8.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 8.0f) * d;
v.y = (v.y - 8.0f) * d;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
const int vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_0 * x = (const block_q5_0 *) vx;
const dfloat d = x[ib].d;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
#ifdef GGML_CUDA_F16
v = __hsub2(v, {16.0f, 16.0f});
v = __hmul2(v, {d, d});
#else
v.x = (v.x - 16.0f) * d;
v.y = (v.y - 16.0f) * d;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
const dfloat d = __low2half(x[ib].dm);
const dfloat m = __high2half(x[ib].dm);
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
v = __hadd2(v, {m, m});
#else
v.x = (v.x * d) + m;
v.y = (v.y * d) + m;
#endif // GGML_CUDA_F16
}
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q8_0 * x = (const block_q8_0 *) vx;
const dfloat d = x[ib].d;
v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1];
#ifdef GGML_CUDA_F16
v = __hmul2(v, {d, d});
#else
v.x *= d;
v.y *= d;
#endif // GGML_CUDA_F16
}

40
ggml-cuda/diagmask.cu Normal file
View File

@ -0,0 +1,40 @@
#include "diagmask.cuh"
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.y*blockIdx.y + threadIdx.y;
const int row = blockDim.x*blockIdx.x + threadIdx.x;
if (col >= ncols) {
return;
}
const int i = row*ncols + col;
//dst[i] = col > (n_past + row % rows_per_channel) ? -INFINITY : x[i];
//dst[i] = x[i] - (col > n_past + row % rows_per_channel) * INT_MAX; // equivalent within rounding error but slightly faster on GPU
dst[i] = x[i] - (col > n_past + row % rows_per_channel) * FLT_MAX;
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
const dim3 block_dims(1, CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1);
const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
const dim3 block_nums(nrows_x, block_num_x, 1);
diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
}
void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int nrows0 = ggml_nrows(src0);
const int n_past = ((int32_t *) dst->op_params)[0];
diag_mask_inf_f32_cuda(src0_d, dst_d, ne00, nrows0, ne01, n_past, stream);
}

5
ggml-cuda/diagmask.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_DIAG_MASK_INF_BLOCK_SIZE 32
void ggml_cuda_op_diag_mask_inf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

820
ggml-cuda/dmmv.cu Normal file
View File

@ -0,0 +1,820 @@
#include "dmmv.cuh"
#include "dequantize.cuh"
// dmmv = dequantize_mul_mat_vec
#ifndef GGML_CUDA_DMMV_X
#define GGML_CUDA_DMMV_X 32
#endif
#ifndef GGML_CUDA_MMV_Y
#define GGML_CUDA_MMV_Y 1
#endif
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q2_K * x = (const block_q2_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
#if QK_K == 256
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
aux[1] = a[1] & 0x0f0f0f0f;
aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
}
tmp += dall * sum1 - dmin * sum2;
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
const int offset = tid * K_QUANTS_PER_ITERATION;
uint32_t uaux[2];
const uint8_t * d = (const uint8_t *)uaux;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + offset;
const uint8_t * q = x[i].qs + offset;
const uint32_t * s = (const uint32_t *)x[i].scales;
uaux[0] = s[0] & 0x0f0f0f0f;
uaux[1] = (s[0] >> 4) & 0x0f0f0f0f;
const float2 dall = __half22float2(x[i].dm);
float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
const uint8_t ql = q[l];
sum1 += y[l+ 0] * d[0] * ((ql >> 0) & 3)
+ y[l+16] * d[1] * ((ql >> 2) & 3)
+ y[l+32] * d[2] * ((ql >> 4) & 3)
+ y[l+48] * d[3] * ((ql >> 6) & 3);
sum2 += y[l+0] * d[4] + y[l+16] * d[5] + y[l+32] * d[6] + y[l+48] * d[7];
}
tmp += dall.x * sum1 - dall.y * sum2;
}
#endif
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q3_K * x = (const block_q3_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
#if QK_K == 256
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop
const int step = 16/K_QUANTS_PER_ITERATION;
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0....15 or 0...7
const uint8_t m = 1 << (4*im);
const int l0 = n*in; // 0...15 or 0...14 in steps of 2
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;
uint16_t utmp[4];
const int8_t * s = (const int8_t *)utmp;
const uint16_t s_shift = 4*im;
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const uint8_t * h = x[i].hmask + l0;
const uint16_t * a = (const uint16_t *)x[i].scales;
utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
const float d = x[i].d;
float sum = 0;
for (int l = 0; l < n; ++l) {
sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
}
tmp += d * sum;
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15 or 0...7
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0....1 or 0...3
const int offset = tid * K_QUANTS_PER_ITERATION; // 0...15 or 0...14
const int in = offset/8; // 0 or 1
const int im = offset%8; // 0...7
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + offset;
const uint8_t * q = x[i].qs + offset;
const uint8_t * s = x[i].scales;
const float dall = (float)x[i].d;
float sum = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
const uint8_t hl = x[i].hmask[im+l] >> in;
const uint8_t ql = q[l];
sum += y[l+ 0] * dall * ((s[0] & 0xF) - 8) * ((int8_t)((ql >> 0) & 3) - ((hl >> 0) & 1 ? 0 : 4))
+ y[l+16] * dall * ((s[0] >> 4) - 8) * ((int8_t)((ql >> 2) & 3) - ((hl >> 2) & 1 ? 0 : 4))
+ y[l+32] * dall * ((s[1] & 0xF) - 8) * ((int8_t)((ql >> 4) & 3) - ((hl >> 4) & 1 ? 0 : 4))
+ y[l+48] * dall * ((s[1] >> 4) - 8) * ((int8_t)((ql >> 6) & 3) - ((hl >> 6) & 1 ? 0 : 4));
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q4_K * x = (const block_q4_K *)vx + ib0;
#if QK_K == 256
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1
const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4
const int il = tid/step; // 0...3
const int ir = tid - step*il; // 0...7 or 0...3
const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
#if K_QUANTS_PER_ITERATION == 2
uint32_t q32[4];
const uint8_t * q4 = (const uint8_t *)q32;
#else
uint16_t q16[4];
const uint8_t * q4 = (const uint8_t *)q16;
#endif
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
#if K_QUANTS_PER_ITERATION == 2
const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset);
const uint32_t * q2 = q1 + 16;
q32[0] = q1[0] & 0x0f0f0f0f;
q32[1] = q1[0] & 0xf0f0f0f0;
q32[2] = q2[0] & 0x0f0f0f0f;
q32[3] = q2[0] & 0xf0f0f0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 4; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4];
s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#else
const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset);
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[0] & 0xf0f0;
q16[2] = q2[0] & 0x0f0f;
q16[3] = q2[0] & 0xf0f0;
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < 2; ++l) {
s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2];
s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6];
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin;
#endif
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
const int step = tid * K_QUANTS_PER_ITERATION;
uint16_t aux16[2];
const uint8_t * s = (const uint8_t *)aux16;
float tmp = 0;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const uint8_t * q = x[i].qs + step;
const float * y = yy + i*QK_K + step;
const uint16_t * a = (const uint16_t *)x[i].scales;
aux16[0] = a[0] & 0x0f0f;
aux16[1] = (a[0] >> 4) & 0x0f0f;
const float d = (float)x[i].dm[0];
const float m = (float)x[i].dm[1];
float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * (d * s[0] * (q[j+ 0] & 0xF) - m * s[2])
+ y[j+16] * (d * s[0] * (q[j+16] & 0xF) - m * s[2])
+ y[j+32] * (d * s[1] * (q[j+ 0] >> 4) - m * s[3])
+ y[j+48] * (d * s[1] * (q[j+16] >> 4) - m * s[3]);
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) {
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q5_K * x = (const block_q5_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
#if QK_K == 256
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 2;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
const uint8_t hm1 = 1 << (2*im);
const uint8_t hm2 = hm1 << 4;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
uint16_t q16[8];
const uint8_t * q4 = (const uint8_t *)q16;
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
const uint8_t * qh = x[i].qh + l0;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const float dall = __low2half(x[i].dm);
const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
const uint16_t * q1 = (const uint16_t *)ql1;
const uint16_t * q2 = q1 + 32;
q16[0] = q1[0] & 0x0f0f;
q16[1] = q1[8] & 0x0f0f;
q16[2] = (q1[0] >> 4) & 0x0f0f;
q16[3] = (q1[8] >> 4) & 0x0f0f;
q16[4] = q2[0] & 0x0f0f;
q16[5] = q2[8] & 0x0f0f;
q16[6] = (q2[0] >> 4) & 0x0f0f;
q16[7] = (q2[8] >> 4) & 0x0f0f;
for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0))
+ y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0))
+ y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0))
+ y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0))
+ y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0));
smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3]
+ (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...15
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION);
const int step = tid * K_QUANTS_PER_ITERATION;
const int im = step/8;
const int in = step%8;
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const uint8_t * q = x[i].qs + step;
const int8_t * s = x[i].scales;
const float * y = yy + i*QK_K + step;
const float d = x[i].d;
float sum = 0.f;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
const uint8_t h = x[i].qh[in+j] >> im;
sum += y[j+ 0] * d * s[0] * ((q[j+ 0] & 0xF) - ((h >> 0) & 1 ? 0 : 16))
+ y[j+16] * d * s[1] * ((q[j+16] & 0xF) - ((h >> 2) & 1 ? 0 : 16))
+ y[j+32] * d * s[2] * ((q[j+ 0] >> 4) - ((h >> 4) & 1 ? 0 : 16))
+ y[j+48] * d * s[3] * ((q[j+16] >> 4) - ((h >> 6) & 1 ? 0 : 16));
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (threadIdx.x == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q6_K * x = (const block_q6_K *)vx + ib0;
#if QK_K == 256
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset;
const uint8_t * qh = x[i].qh + qh_offset;
const int8_t * s = x[i].scales + s_offset;
const float d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp += sum;
#endif
}
#else
const int tid = threadIdx.x/(2*K_QUANTS_PER_ITERATION); // 0...7
const int ix = threadIdx.x%(2*K_QUANTS_PER_ITERATION); // 0...3
const int step = tid * K_QUANTS_PER_ITERATION;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2*K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + step;
const uint8_t * ql = x[i].ql + step;
const uint8_t * qh = x[i].qh + step;
const int8_t * s = x[i].scales;
const float d = x[i+0].d;
float sum = 0;
for (int j = 0; j < K_QUANTS_PER_ITERATION; ++j) {
sum += y[j+ 0] * s[0] * d * ((int8_t)((ql[j+ 0] & 0xF) | ((qh[j] & 0x03) << 4)) - 32)
+ y[j+16] * s[1] * d * ((int8_t)((ql[j+16] & 0xF) | ((qh[j] & 0x0c) << 2)) - 32)
+ y[j+32] * s[2] * d * ((int8_t)((ql[j+ 0] >> 4) | ((qh[j] & 0x30) >> 0)) - 32)
+ y[j+48] * s[3] * d * ((int8_t)((ql[j+16] >> 4) | ((qh[j] & 0xc0) >> 2)) - 32);
}
tmp += sum;
}
#endif
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
dst[row] = tmp;
}
}
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
const half * x = (const half *) vx;
// automatic half -> float type cast if dfloat == float
v.x = x[ib + iqs + 0];
v.y = x[ib + iqs + 1];
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.x*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
}
const int tid = threadIdx.x;
const int iter_stride = 2*GGML_CUDA_DMMV_X;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2;
// partial sum for each thread
#ifdef GGML_CUDA_F16
half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
#else
float tmp = 0.0f;
#endif // GGML_CUDA_F16
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index
// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
// process 2 vals per j iter
// dequantize
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
dfloat2 v;
dequantize_kernel(vx, ib, iqs + j/qr, v);
// matrix multiplication
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
#ifdef GGML_CUDA_F16
tmp += __hmul2(v, {
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]
});
#else
tmp += v.x * y[iybs + iqs + j/qr + 0];
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
#endif // GGML_CUDA_F16
}
}
// sum up partial sums and write back result
tmp = warp_reduce_sum(tmp);
if (tid == 0) {
#ifdef GGML_CUDA_F16
dst[row] = tmp.x + tmp.y;
#else
dst[row] = tmp;
#endif // GGML_CUDA_F16
}
}
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
// the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK4_0, QR4_0, dequantize_q4_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK4_1, QR4_1, dequantize_q4_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK5_0, QR5_0, dequantize_q5_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK5_1, QR5_1, dequantize_q5_1>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<QK8_0, QR8_0, dequantize_q8_0>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q3_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q4_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const dim3 block_dims(32, 1, 1);
dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
}
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(block_num_y, 1, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
dequantize_mul_mat_vec<1, 1, convert_f16>
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
}
void ggml_cuda_op_dequantize_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
GGML_UNUSED(ctx);
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
GGML_ASSERT(src1->type == GGML_TYPE_F32);
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
#ifdef GGML_CUDA_F16
ggml_cuda_pool_alloc<half> src1_dfloat_a(ctx.pool());
half * src1_dfloat = nullptr; // dfloat == half
bool src1_convert_f16 =
src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
if (src1_convert_f16) {
src1_dfloat = src1_dfloat_a.alloc(ne00);
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type);
GGML_ASSERT(to_fp16_cuda != nullptr);
to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream);
}
#else
const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion
#endif // GGML_CUDA_F16
switch (src0->type) {
case GGML_TYPE_Q4_0:
dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_1:
dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_0:
dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_1:
dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q8_0:
dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q2_K:
dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q3_K:
dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q4_K:
dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q5_K:
dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_Q6_K:
dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream);
break;
case GGML_TYPE_F16:
convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
break;
default:
GGML_ASSERT(false);
break;
}
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddq_i);
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}

7
ggml-cuda/dmmv.cuh Normal file
View File

@ -0,0 +1,7 @@
#include "common.cuh"
void ggml_cuda_op_dequantize_mul_mat_vec(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);

178
ggml-cuda/getrows.cu Normal file
View File

@ -0,0 +1,178 @@
#include "getrows.cuh"
#include "dequantize.cuh"
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * src0, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
const int i00 = (blockIdx.x*blockDim.x + threadIdx.x)*2;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
if (i00 >= ne00) {
return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
dfloat2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = v.x;
dst_row[iybs + iqs + y_offset] = v.y;
}
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
const src0_t * src0, const int32_t * src1, dst_t * dst,
int64_t ne00, /*int64_t ne01, int64_t ne02, int64_t ne03,*/
/*int64_t ne10, int64_t ne11,*/ int64_t ne12, /*int64_t ne13,*/
/*size_t s0,*/ size_t s1, size_t s2, size_t s3,
/*size_t nb00,*/ size_t nb01, size_t nb02, size_t nb03,
size_t s10, size_t s11, size_t s12/*, size_t s13*/) {
const int i00 = blockIdx.x*blockDim.x + threadIdx.x;
const int i10 = blockDim.y*blockIdx.y + threadIdx.y;
const int i11 = (blockIdx.z*blockDim.z + threadIdx.z)/ne12;
const int i12 = (blockIdx.z*blockDim.z + threadIdx.z)%ne12;
if (i00 >= ne00) {
return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *)src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = src0_row[i00];
}
template<int qk, int qr, dequantize_kernel_t dq>
static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
// strides in elements
//const size_t s0 = nb0 / ggml_element_size(dst);
const size_t s1 = nb1 / ggml_element_size(dst);
const size_t s2 = nb2 / ggml_element_size(dst);
const size_t s3 = nb3 / ggml_element_size(dst);
const size_t s10 = nb10 / ggml_element_size(src1);
const size_t s11 = nb11 / ggml_element_size(src1);
const size_t s12 = nb12 / ggml_element_size(src1);
//const size_t s13 = nb13 / ggml_element_size(src1);
GGML_ASSERT(ne00 % 2 == 0);
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
GGML_UNUSED(dst);
}
template<typename src0_t>
static void get_rows_cuda_float(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
const src0_t * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
GGML_TENSOR_BINARY_OP_LOCALS
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_x = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(block_num_x, ne10, ne11*ne12);
// strides in elements
//const size_t s0 = nb0 / ggml_element_size(dst);
const size_t s1 = nb1 / ggml_element_size(dst);
const size_t s2 = nb2 / ggml_element_size(dst);
const size_t s3 = nb3 / ggml_element_size(dst);
const size_t s10 = nb10 / ggml_element_size(src1);
const size_t s11 = nb11 / ggml_element_size(src1);
const size_t s12 = nb12 / ggml_element_size(src1);
//const size_t s13 = nb13 / ggml_element_size(src1);
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
GGML_UNUSED(dst);
}
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(src1->nb[0] == ggml_type_size(src1->type));
GGML_ASSERT(dst->nb[0] == ggml_type_size(dst->type));
const int32_t * src1_i32 = (const int32_t *) src1_d;
switch (src0->type) {
case GGML_TYPE_F16:
get_rows_cuda_float(src0, src1, dst, (const half *)src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_F32:
get_rows_cuda_float(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_0:
get_rows_cuda<QK4_0, QR4_0, dequantize_q4_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q4_1:
get_rows_cuda<QK4_1, QR4_1, dequantize_q4_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_0:
get_rows_cuda<QK5_0, QR5_0, dequantize_q5_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q5_1:
get_rows_cuda<QK5_1, QR5_1, dequantize_q5_1>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
case GGML_TYPE_Q8_0:
get_rows_cuda<QK8_0, QR8_0, dequantize_q8_0>(src0, src1, dst, src0_d, src1_i32, dst_d, stream);
break;
default:
// TODO: k-quants
fprintf(stderr, "%s: unsupported type: %s\n", __func__, ggml_type_name(src0->type));
GGML_ASSERT(false);
break;
}
}

5
ggml-cuda/getrows.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_GET_ROWS_BLOCK_SIZE 256
void ggml_cuda_op_get_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

104
ggml-cuda/im2col.cu Normal file
View File

@ -0,0 +1,104 @@
#include "im2col.cuh"
template <typename T>
static __global__ void im2col_kernel(
const float * x, T * dst, int64_t batch_offset,
int64_t offset_delta, int64_t IC, int64_t IW, int64_t IH, int64_t OH, int64_t OW, int64_t KW, int64_t KH, int64_t pelements, int64_t CHW,
int s0, int s1, int p0, int p1, int d0, int d1) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= pelements) {
return;
}
const int64_t ksize = OW * (KH > 1 ? KW : 1);
const int64_t kx = i / ksize;
const int64_t kd = kx * ksize;
const int64_t ky = (i - kd) / OW;
const int64_t ix = i % OW;
const int64_t oh = blockIdx.y;
const int64_t batch = blockIdx.z / IC;
const int64_t ic = blockIdx.z % IC;
const int64_t iiw = ix * s0 + kx * d0 - p0;
const int64_t iih = oh * s1 + ky * d1 - p1;
const int64_t offset_dst =
((batch * OH + oh) * OW + ix) * CHW +
(ic * (KW * KH) + ky * KW + kx);
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = ic * offset_delta + batch * batch_offset;
dst[offset_dst] = x[offset_src + iih * IW + iiw];
}
}
template <typename T>
static void im2col_cuda(const float * x, T* dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
const int parallel_elements = OW * KW * KH;
const int num_blocks = (parallel_elements + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OH, batch * IC);
im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
}
static void im2col_cuda_f16(const float * x, half * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<half>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
}
static void im2col_cuda_f32(const float * x, float * dst,
int64_t IW, int64_t IH, int64_t OW, int64_t OH, int64_t KW, int64_t KH, int64_t IC,
int64_t batch, int64_t batch_offset, int64_t offset_delta,
int s0,int s1,int p0,int p1,int d0,int d1, cudaStream_t stream) {
im2col_cuda<float>(x, dst, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, offset_delta, s0, s1, p0, p1, d0, d1, stream);
}
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
const int32_t s0 = ((const int32_t*)(dst->op_params))[0];
const int32_t s1 = ((const int32_t*)(dst->op_params))[1];
const int32_t p0 = ((const int32_t*)(dst->op_params))[2];
const int32_t p1 = ((const int32_t*)(dst->op_params))[3];
const int32_t d0 = ((const int32_t*)(dst->op_params))[4];
const int32_t d1 = ((const int32_t*)(dst->op_params))[5];
const bool is_2D = ((const int32_t*)(dst->op_params))[6] == 1;
const int64_t IC = src1->ne[is_2D ? 2 : 1];
const int64_t IH = is_2D ? src1->ne[1] : 1;
const int64_t IW = src1->ne[0];
const int64_t KH = is_2D ? src0->ne[1] : 1;
const int64_t KW = src0->ne[0];
const int64_t OH = is_2D ? dst->ne[2] : 1;
const int64_t OW = dst->ne[1];
const size_t delta_offset = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32
const int64_t batch = src1->ne[3];
const size_t batch_offset = src1->nb[3] / 4; // nb is byte offset, src is type float32
if(dst->type == GGML_TYPE_F16) {
im2col_cuda_f16(src1_d, (half *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
} else {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, batch, batch_offset, delta_offset, s0, s1, p0, p1, d0, d1, stream);
}
}

5
ggml-cuda/im2col.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_IM2COL_BLOCK_SIZE 256
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

2265
ggml-cuda/mmq.cu Normal file

File diff suppressed because it is too large Load Diff

9
ggml-cuda/mmq.cuh Normal file
View File

@ -0,0 +1,9 @@
#include "common.cuh"
void ggml_cuda_op_mul_mat_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);
bool ggml_cuda_supports_mmq(enum ggml_type type);

395
ggml-cuda/mmvq.cu Normal file
View File

@ -0,0 +1,395 @@
#include "mmvq.cuh"
#include "vecdotq.cuh"
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
// partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
// x block quant index when casting the quants to int
const int kqs = vdr * (tid % (qi/vdr));
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp[j][i] += vec_dot_q_cuda(
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
}
}
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
}
}
}
__syncthreads();
if (threadIdx.y > 0) {
return;
}
// sum up partial sums and write back result
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
#pragma unroll
for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
}
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
}
}
}
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot>
static void mul_mat_vec_q_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
GGML_ASSERT(ncols_x % qk == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
int id;
CUDA_CHECK(cudaGetDevice(&id));
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
if (ggml_cuda_info().devices[id].cc < CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
rows_per_cuda_block = 1;
break;
case 2:
case 3:
case 4:
nwarps = 4;
rows_per_cuda_block = 2;
break;
case 5:
case 6:
case 7:
case 8:
nwarps = 2;
rows_per_cuda_block = 2;
break;
default:
GGML_ASSERT(false);
break;
}
}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
switch (ncols_y) {
case 1:
mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 2:
mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 3:
mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 4:
mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 5:
mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 6:
mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 7:
mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
case 8:
mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
default:
GGML_ASSERT(false);
break;
}
}
static void mul_mat_vec_q4_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q4_1_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q5_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q5_1_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q8_0_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q2_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q3_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q4_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q5_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq2_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq3_xxs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq1_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq4_nl_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq4_xs_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
static void mul_mat_vec_iq3_s_q8_1_cuda(
const void * vx, const void * vy, float * dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {
mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1>
(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
const int64_t row_diff = row_high - row_low;
const int64_t ne10 = src1->ne[0];
GGML_ASSERT(ne10 % QK8_1 == 0);
const int64_t ne0 = dst->ne[0];
int id;
CUDA_CHECK(cudaGetDevice(&id));
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the kernel writes into
const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;
switch (src0->type) {
case GGML_TYPE_Q4_0:
mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q4_1:
mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q5_0:
mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q5_1:
mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q8_0:
mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_XXS:
mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_XS:
mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ2_S:
mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ3_XXS:
mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ1_S:
mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_NL:
mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ4_XS:
mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
case GGML_TYPE_IQ3_S:
mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
break;
default:
GGML_ASSERT(false);
break;
}
GGML_UNUSED(src1);
GGML_UNUSED(dst);
GGML_UNUSED(src1_ddf_i);
GGML_UNUSED(src1_ncols);
GGML_UNUSED(src1_padded_row_size);
}

7
ggml-cuda/mmvq.cuh Normal file
View File

@ -0,0 +1,7 @@
#include "common.cuh"
void ggml_cuda_op_mul_mat_vec_q(
ggml_backend_cuda_context & ctx,
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream);

215
ggml-cuda/norm.cu Normal file
View File

@ -0,0 +1,215 @@
#include "norm.cuh"
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float2 mean_var = make_float2(0.f, 0.f);
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
mean_var.x += xi;
mean_var.y += xi * xi;
}
// sum up partial sums
mean_var = warp_reduce_sum(mean_var);
if (block_size > WARP_SIZE) {
__shared__ float2 s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = mean_var;
}
__syncthreads();
mean_var = s_sum[lane_id];
mean_var = warp_reduce_sum(mean_var);
}
const float mean = mean_var.x / ncols;
const float var = mean_var.y / ncols - mean * mean;
const float inv_std = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
}
}
template <int block_size>
static __global__ void group_norm_f32(const float * x, float * dst, const int group_size, const int ne_elements, const float eps) {
// blockIdx.x: num_groups idx
// threadIdx.x: block_size idx
int start = blockIdx.x * group_size;
int end = start + group_size;
start += threadIdx.x;
if (end >= ne_elements) {
end = ne_elements;
}
float tmp = 0.0f; // partial sum for thread in warp
for (int j = start; j < end; j += block_size) {
tmp += x[j];
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
float mean = tmp / group_size;
tmp = 0.0f;
for (int j = start; j < end; j += block_size) {
float xi = x[j] - mean;
dst[j] = xi;
tmp += xi * xi;
}
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
float variance = tmp / group_size;
float scale = rsqrtf(variance + eps);
for (int j = start; j < end; j += block_size) {
dst[j] *= scale;
}
}
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__shared__ float s_sum[32];
int warp_id = threadIdx.x / WARP_SIZE;
int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float mean = tmp / ncols;
const float scale = rsqrtf(mean + eps);
for (int col = tid; col < ncols; col += block_size) {
dst[row*ncols + col] = scale * x[row*ncols + col];
}
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
static const float eps = 1e-6f;
if (group_size < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
} else {
const dim3 block_dims(1024, 1, 1);
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
}
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
int num_groups = dst->op_params[0];
int group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups);
group_norm_f32_cuda(src0_d, dst_d, num_groups * src0->ne[3], group_size, ggml_nelements(src0), stream);
}
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
}

7
ggml-cuda/norm.cuh Normal file
View File

@ -0,0 +1,7 @@
#include "common.cuh"
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

49
ggml-cuda/pad.cu Normal file
View File

@ -0,0 +1,49 @@
#include "pad.cuh"
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIDx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
} else {
dst[offset_dst] = 0.0f;
}
}
static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}

5
ggml-cuda/pad.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_PAD_BLOCK_SIZE 256
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

94
ggml-cuda/pool2d.cu Normal file
View File

@ -0,0 +1,94 @@
#include "pool2d.cuh"
template <typename Ti, typename To>
static __global__ void pool2d_nchw_kernel(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const Ti* src, To* dst, const enum ggml_op_pool op) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= parallel_elements) {
return;
}
const int I_HW = ih * iw;
const int O_HW = oh * ow;
const int nc = idx / O_HW;
const int cur_oh = idx % O_HW / ow;
const int cur_ow = idx % O_HW % ow;
const Ti* i_ptr = src + nc * I_HW;
To* o_ptr = dst + nc * O_HW;
const int start_h = cur_oh * sh - ph;
const int bh = max(0, start_h);
const int eh = min(ih, start_h + kh);
const int start_w = cur_ow * sw - pw;
const int bw = max(0, start_w);
const int ew = min(iw, start_w + kw);
const To scale = 1. / (kh * kw);
To res = 0;
switch (op) {
case GGML_OP_POOL_AVG: res = 0; break;
case GGML_OP_POOL_MAX: res = -FLT_MAX; break;
default: assert(false);
}
for (int i = bh; i < eh; i += 1) {
for (int j = bw; j < ew; j += 1) {
#if __CUDA_ARCH__ >= 350
Ti cur = __ldg(i_ptr + i * iw + j);
#else
Ti cur = i_ptr[i * iw + j];
#endif
switch (op) {
case GGML_OP_POOL_AVG: res += cur * scale; break;
case GGML_OP_POOL_MAX: res = max(res, (To)cur); break;
default: assert(false);
}
}
}
o_ptr[cur_oh * ow + cur_ow] = res;
}
static void pool2d_nchw_kernel_f32_f32_cuda(
const int ih, const int iw, const int oh, const int ow,
const int kh, const int kw, const int sh, const int sw,
const int ph, const int pw, const int parallel_elements,
const float * src, float * dst, const enum ggml_op_pool op,
cudaStream_t stream) {
const int num_blocks = (parallel_elements + CUDA_POOL2D_BLOCK_SIZE - 1) / CUDA_POOL2D_BLOCK_SIZE;
dim3 block_nums(num_blocks);
pool2d_nchw_kernel<<<block_nums, CUDA_POOL2D_BLOCK_SIZE, 0, stream>>>(ih, iw, oh, ow, kh, kw, sh, sw, ph, pw, parallel_elements, src, dst, op);
}
void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *)dst->op_params;
enum ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
const int64_t IH = src0->ne[1];
const int64_t IW = src0->ne[0];
const int64_t N = dst->ne[3];
const int64_t OC = dst->ne[2];
const int64_t OH = dst->ne[1];
const int64_t OW = dst->ne[0];
const int parallel_elements = N * OC * OH * OW;
pool2d_nchw_kernel_f32_f32_cuda(IH, IW, OH, OW, k1, k0, s1, s0, p1, p0, parallel_elements, src0_d, dst_d, op, stream);
}

5
ggml-cuda/pool2d.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_POOL2D_BLOCK_SIZE 256
void ggml_cuda_op_pool2d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

45
ggml-cuda/quantize.cu Normal file
View File

@ -0,0 +1,45 @@
#include "quantize.cuh"
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int kx, const int kx_padded) {
const int ix = blockDim.x*blockIdx.x + threadIdx.x;
if (ix >= kx_padded) {
return;
}
const int iy = blockDim.y*blockIdx.y + threadIdx.y;
const int i_padded = iy*kx_padded + ix;
block_q8_1 * y = (block_q8_1 *) vy;
const int ib = i_padded / QK8_1; // block index
const int iqs = i_padded % QK8_1; // quant index
const float xi = ix < kx ? x[iy*kx + ix] : 0.0f;
float amax = fabsf(xi);
float sum = xi;
amax = warp_reduce_max(amax);
sum = warp_reduce_sum(sum);
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
if (iqs > 0) {
return;
}
reinterpret_cast<half&>(y[ib].ds.x) = d;
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream) {
const int block_num_x = (kx_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ky, 1);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}

5
ggml-cuda/quantize.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_QUANTIZE_BLOCK_SIZE 256
void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, const int ky, const int kx_padded, cudaStream_t stream);

308
ggml-cuda/rope.cu Normal file
View File

@ -0,0 +1,308 @@
#include "rope.cuh"
struct rope_corr_dims {
float v[4];
};
static __device__ float rope_yarn_ramp(const float low, const float high, const int i0) {
const float y = (i0 / 2 - low) / max(0.001f, high - low);
return 1.0f - min(1.0f, max(0.0f, y));
}
// YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn
// MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng.
static __device__ void rope_yarn(
float theta_extrap, float freq_scale, rope_corr_dims corr_dims, int64_t i0, float ext_factor, float mscale,
float * cos_theta, float * sin_theta
) {
// Get n-d rotational scaling corrected for extrapolation
float theta_interp = freq_scale * theta_extrap;
float theta = theta_interp;
if (ext_factor != 0.0f) {
float ramp_mix = rope_yarn_ramp(corr_dims.v[0], corr_dims.v[1], i0) * ext_factor;
theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// Get n-d magnitude scaling corrected for interpolation
mscale *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
*cos_theta = cosf(theta) * mscale;
*sin_theta = sinf(theta) * mscale;
}
// rope == RoPE == rotary positional embedding
template<typename T, bool has_pos>
static __global__ void rope(
const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
float ext_factor, float attn_factor, rope_corr_dims corr_dims
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*powf(freq_base, -float(col)/ncols);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, col, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + 1];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
template<typename T, bool has_pos>
static __global__ void rope_neox(
const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (col >= ncols) {
return;
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
const int ib = col / n_dims;
const int ic = col % n_dims;
if (ib > 0) {
const int i = row*ncols + ib*n_dims + ic;
dst[i + 0] = x[i + 0];
dst[i + 1] = x[i + 1];
return;
}
const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
const float x1 = x[i + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_glm_f32(
const float * x, float * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
int n_ctx
) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;
if (col >= half_n_dims) {
return;
}
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const int i2 = row/p_delta_rows;
const float col_theta_scale = powf(freq_base, -2.0f*col/ncols);
// FIXME: this is likely wrong
const int p = pos != nullptr ? pos[i2] : 0;
const float theta = min(p, n_ctx - 2)*freq_scale*col_theta_scale;
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
const float x0 = x[i + 0];
const float x1 = x[i + half_n_dims];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
const float block_theta = ((float)max(p - n_ctx - 2, 0))*col_theta_scale;
const float sin_block_theta = sinf(block_theta);
const float cos_block_theta = cosf(block_theta);
const float x2 = x[i + half_n_dims * 2];
const float x3 = x[i + half_n_dims * 3];
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
}
template<typename T>
static void rope_cuda(
const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
if (pos == nullptr) {
rope<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
} else {
rope<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
);
}
}
template<typename T>
static void rope_neox_cuda(
const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
theta_scale, inv_ndims
);
}
}
static void rope_glm_f32_cuda(
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, int n_ctx, cudaStream_t stream
) {
GGML_ASSERT(ncols % 4 == 0);
const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
const dim3 block_nums(num_blocks_x, nrows, 1);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, n_ctx);
}
static void rope_cuda_f16(
const half * x, half * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
rope_cuda<half>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
}
static void rope_cuda_f32(
const float * x, float * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
rope_cuda<float>(x, dst, ncols, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
}
static void rope_neox_cuda_f16(
const half * x, half * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream) {
rope_neox_cuda<half>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
}
static void rope_neox_cuda_f32(
const float * x, float * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
rope_neox_cuda<float>(x, dst, ncols, n_dims, nrows, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims, stream);
}
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16);
GGML_ASSERT(src0->type == dst->type);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne2 = dst->ne[2];
const int64_t nrows = ggml_nrows(src0);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
const int n_ctx = ((int32_t *) dst->op_params)[3];
const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
// RoPE alteration for extended context
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float));
memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float));
memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float));
memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float));
memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float));
const int32_t * pos = nullptr;
if ((mode & 1) == 0) {
GGML_ASSERT(src1->type == GGML_TYPE_I32);
GGML_ASSERT(src1->ne[0] == ne2);
pos = (const int32_t *) src1_d;
}
const bool is_neox = mode & 2;
const bool is_glm = mode & 4;
rope_corr_dims corr_dims;
ggml_rope_yarn_corr_dims(n_dims, n_orig_ctx, freq_base, beta_fast, beta_slow, corr_dims.v);
// compute
if (is_glm) {
GGML_ASSERT(false);
rope_glm_f32_cuda(src0_d, dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, stream);
} else if (is_neox) {
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
);
} else {
GGML_ASSERT(false);
}
} else {
if (src0->type == GGML_TYPE_F32) {
rope_cuda_f32(
(const float *)src0_d, (float *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_cuda_f16(
(const half *)src0_d, (half *)dst_d, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, stream
);
} else {
GGML_ASSERT(false);
}
}
}

5
ggml-cuda/rope.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_ROPE_BLOCK_SIZE 256
void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

32
ggml-cuda/scale.cu Normal file
View File

@ -0,0 +1,32 @@
#include "scale.cuh"
static __global__ void scale_f32(const float * x, float * dst, const float scale, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = scale * x[i];
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, k);
}
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
float scale;
memcpy(&scale, dst->op_params, sizeof(float));
scale_f32_cuda(src0_d, dst_d, scale, ggml_nelements(src0), stream);
CUDA_CHECK(cudaGetLastError());
}

5
ggml-cuda/scale.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_SCALE_BLOCK_SIZE 256
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

201
ggml-cuda/softmax.cu Normal file
View File

@ -0,0 +1,201 @@
#include "softmax.cuh"
template <bool vals_smem, int ncols_template, int block_size_template>
static __global__ void soft_max_f32(const float * x, const float * mask, const float * pos, float * dst, const int ncols_par, const int nrows_y, const float scale, const float max_bias, const float m0, const float m1, uint32_t n_head_log2) {
const int ncols = ncols_template == 0 ? ncols_par : ncols_template;
const int tid = threadIdx.x;
const int rowx = blockIdx.x;
const int rowy = rowx % nrows_y; // broadcast the mask in the row dimension
const int block_size = block_size_template == 0 ? blockDim.x : block_size_template;
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
float slope = 0.0f;
// ALiBi
if (max_bias > 0.0f) {
const int h = rowx/nrows_y; // head index
const float base = h < n_head_log2 ? m0 : m1;
const int exp = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
slope = powf(base, exp);
}
extern __shared__ float data_soft_max_f32[];
float * buf_iw = data_soft_max_f32; // shared memory buffer for inter-warp communication
// shared memory buffer to cache values between iterations:
float * vals = vals_smem ? buf_iw + WARP_SIZE : dst + rowx*ncols;
float max_val = -INFINITY;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const int ix = rowx*ncols + col;
const int iy = rowy*ncols + col;
const float val = x[ix]*scale + (mask ? mask[iy] : 0.0f) + (pos ? slope*pos[col] : 0.0f);
vals[col] = val;
max_val = max(max_val, val);
}
// find the max value in the block
max_val = warp_reduce_max(max_val);
if (block_size > WARP_SIZE) {
if (warp_id == 0) {
buf_iw[lane_id] = -INFINITY;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = max_val;
}
__syncthreads();
max_val = buf_iw[lane_id];
max_val = warp_reduce_max(max_val);
}
float tmp = 0.0f; // partial sum
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
break;
}
const float val = expf(vals[col] - max_val);
tmp += val;
vals[col] = val;
}
// find the sum of exps in the block
tmp = warp_reduce_sum(tmp);
if (block_size > WARP_SIZE) {
__syncthreads();
if (warp_id == 0) {
buf_iw[lane_id] = 0.0f;
}
__syncthreads();
if (lane_id == 0) {
buf_iw[warp_id] = tmp;
}
__syncthreads();
tmp = buf_iw[lane_id];
tmp = warp_reduce_sum(tmp);
}
const float inv_sum = 1.0f / tmp;
#pragma unroll
for (int col0 = 0; col0 < ncols; col0 += block_size) {
const int col = col0 + tid;
if (ncols_template == 0 && col >= ncols) {
return;
}
const int idst = rowx*ncols + col;
dst[idst] = vals[col] * inv_sum;
}
}
static void soft_max_f32_cuda(const float * x, const float * mask, const float * pos, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, const float max_bias, cudaStream_t stream) {
int nth = WARP_SIZE;
while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
const dim3 block_dims(nth, 1, 1);
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
const uint32_t n_head_kv = nrows_x/nrows_y;
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
if (shmem < ggml_cuda_info().devices[ggml_cuda_get_device()].smpb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 64:
soft_max_f32<true, 64, 64><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 128:
soft_max_f32<true, 128, 128><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 256:
soft_max_f32<true, 256, 256><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 512:
soft_max_f32<true, 512, 512><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 1024:
soft_max_f32<true, 1024, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 2048:
soft_max_f32<true, 2048, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
case 4096:
soft_max_f32<true, 4096, 1024><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
default:
soft_max_f32<true, 0, 0><<<block_nums, block_dims, shmem, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
break;
}
} else {
const size_t shmem_low = WARP_SIZE*sizeof(float);
soft_max_f32<false, 0, 0><<<block_nums, block_dims, shmem_low, stream>>>(x, mask, pos, dst, ncols_x, nrows_y, scale, max_bias, m0, m1, n_head_log2);
}
}
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src0_d = (const float *)src0->data;
const float * src1_d = src1 ? (const float *)src1->data : nullptr;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
const int64_t ne00 = src0->ne[0];
const int64_t nrows_x = ggml_nrows(src0);
const int64_t nrows_y = src0->ne[1];
float scale = 1.0f;
float max_bias = 0.0f;
memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) dst->op_params + 1, sizeof(float));
// positions tensor
float * src2_dd = nullptr;
ggml_tensor * src2 = dst->src[2];
const bool use_src2 = src2 != nullptr;
if (use_src2) {
src2_dd = (float *)src2->data;
}
soft_max_f32_cuda(src0_d, src1_d, src2_dd, dst_d, ne00, nrows_x, nrows_y, scale, max_bias, stream);
}

5
ggml-cuda/softmax.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_SOFT_MAX_BLOCK_SIZE 1024
void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

40
ggml-cuda/sumrows.cu Normal file
View File

@ -0,0 +1,40 @@
#include "sumrows.cuh"
static __global__ void k_sum_rows_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x;
const int col = threadIdx.x;
float sum = 0.0f;
for (int i = col; i < ncols; i += blockDim.x) {
sum += x[row * ncols + i];
}
sum = warp_reduce_sum(sum);
if (col == 0) {
dst[row] = sum;
}
}
static void sum_rows_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
const dim3 block_dims(WARP_SIZE, 1, 1);
const dim3 block_nums(nrows, 1, 1);
k_sum_rows_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols);
}
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
sum_rows_f32_cuda(src0_d, dst_d, ncols, nrows, stream);
}

3
ggml-cuda/sumrows.cuh Normal file
View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_sum_rows(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

47
ggml-cuda/tsembd.cu Normal file
View File

@ -0,0 +1,47 @@
#include "tsembd.cuh"
static __global__ void timestep_embedding_f32(const float * timesteps, float * dst, const int nb1, const int dim, const int max_period) {
// blockIDx.y: idx of timesteps->ne[0]
// blockIDx.x: idx of ((dim + 1) / 2) / BLOCK_SIZE
int i = blockIdx.y;
int j = threadIdx.x + blockIdx.x * blockDim.x;
float * embed_data = (float *)((char *)dst + i*nb1);
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
embed_data[dim] = 0.f;
}
int half = dim / 2;
if (j >= half) {
return;
}
float timestep = timesteps[i];
float freq = (float)expf(-logf(max_period) * j / half);
float arg = timestep * freq;
embed_data[j] = cosf(arg);
embed_data[j + half] = sinf(arg);
}
static void timestep_embedding_f32_cuda(const float * x, float * dst, const int ne00, const int nb1,
const int dim, const int max_period, cudaStream_t stream) {
int half_ceil = (dim + 1) / 2;
int num_blocks = (half_ceil + CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE - 1) / CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne00, 1);
timestep_embedding_f32<<<gridDim, CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE, 0, stream>>>(x, dst, nb1, dim, max_period);
}
void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int dim = dst->op_params[0];
const int max_period = dst->op_params[1];
timestep_embedding_f32_cuda(src0_d, dst_d, src0->ne[0], dst->nb[1], dim, max_period, stream);
}

5
ggml-cuda/tsembd.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
void ggml_cuda_op_timestep_embedding(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

240
ggml-cuda/unary.cu Normal file
View File

@ -0,0 +1,240 @@
#include "unary.cuh"
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
float xi = x[i];
dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
static __global__ void gelu_quick_f32(const float * x, float * dst, int k) {
const float GELU_QUICK_COEF = -1.702f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * (1.0f / (1.0f + expf(GELU_QUICK_COEF * x[i])));
}
static __global__ void silu_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void tanh_f32(const float * x, float * dst, int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = tanhf(x[i]);
}
static __global__ void relu_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fmaxf(x[i], 0);
}
static __global__ void hardsigmoid_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void hardswish_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * fminf(1.0f, fmaxf(0.0f, (x[i] + 3.0f) / 6.0f));
}
static __global__ void leaky_relu_f32(const float * x, float * dst, const int k, const float negative_slope) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = fmaxf(x[i], 0) + fminf(x[i], 0.0f) * negative_slope;
}
static __global__ void sqr_f32(const float * x, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] * x[i];
}
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void gelu_quick_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_quick_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void tanh_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_TANH_BLOCK_SIZE - 1) / CUDA_TANH_BLOCK_SIZE;
tanh_f32<<<num_blocks, CUDA_TANH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void relu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardsigmoid_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSIGMOID_BLOCK_SIZE - 1) / CUDA_HARDSIGMOID_BLOCK_SIZE;
hardsigmoid_f32<<<num_blocks, CUDA_HARDSIGMOID_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void hardswish_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_HARDSWISH_BLOCK_SIZE - 1) / CUDA_HARDSWISH_BLOCK_SIZE;
hardswish_f32<<<num_blocks, CUDA_HARDSWISH_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void leaky_relu_f32_cuda(const float * x, float * dst, const int k, const float negative_slope, cudaStream_t stream) {
const int num_blocks = (k + CUDA_RELU_BLOCK_SIZE - 1) / CUDA_RELU_BLOCK_SIZE;
leaky_relu_f32<<<num_blocks, CUDA_RELU_BLOCK_SIZE, 0, stream>>>(x, dst, k, negative_slope);
}
static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SQR_BLOCK_SIZE - 1) / CUDA_SQR_BLOCK_SIZE;
sqr_f32<<<num_blocks, CUDA_SQR_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
gelu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
silu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
gelu_quick_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
tanh_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
hardsigmoid_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
hardswish_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
float negative_slope;
memcpy(&negative_slope, dst->op_params, sizeof(float));
leaky_relu_f32_cuda(src0_d, dst_d, ggml_nelements(src0), negative_slope, stream);
}
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
sqr_f32_cuda(src0_d, dst_d, ggml_nelements(src0), stream);
}

27
ggml-cuda/unary.cuh Normal file
View File

@ -0,0 +1,27 @@
#include "common.cuh"
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_TANH_BLOCK_SIZE 256
#define CUDA_RELU_BLOCK_SIZE 256
#define CUDA_HARDSIGMOID_BLOCK_SIZE 256
#define CUDA_HARDSWISH_BLOCK_SIZE 256
#define CUDA_SQR_BLOCK_SIZE 256
void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardsigmoid(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_hardswish(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_leaky_relu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_sqr(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

48
ggml-cuda/upscale.cu Normal file
View File

@ -0,0 +1,48 @@
#include "upscale.cuh"
static __global__ void upscale_f32(const float * x, float * dst, const int ne00, const int ne00xne01, const int scale_factor) {
// blockIdx.z: idx of ne02*ne03
// blockIdx.y: idx of ne01*scale_factor aka ne1
// blockIDx.x: idx of ne00*scale_factor / BLOCK_SIZE
// ne00xne01: ne00 * ne01
int ne0 = ne00 * scale_factor;
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
return;
}
// operation
int i00 = nidx / scale_factor;
int i01 = blockIdx.y / scale_factor;
int offset_src =
i00 +
i01 * ne00 +
blockIdx.z * ne00xne01;
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
dst[offset_dst] = x[offset_src];
}
static void upscale_f32_cuda(const float * x, float * dst, const int ne00, const int ne01, const int ne02, const int ne03,
const int scale_factor, cudaStream_t stream) {
int ne0 = (ne00 * scale_factor);
int num_blocks = (ne0 + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
dim3 gridDim(num_blocks, (ne01 * scale_factor), ne02*ne03);
upscale_f32<<<gridDim, CUDA_UPSCALE_BLOCK_SIZE, 0, stream>>>(x, dst, ne00, ne00 * ne01, scale_factor);
}
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
const int scale_factor = dst->op_params[0];
upscale_f32_cuda(src0_d, dst_d, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], scale_factor, stream);
}

5
ggml-cuda/upscale.cuh Normal file
View File

@ -0,0 +1,5 @@
#include "common.cuh"
#define CUDA_UPSCALE_BLOCK_SIZE 256
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

1284
ggml-cuda/vecdotq.cuh Normal file

File diff suppressed because it is too large Load Diff