Skip to content

Commit a4976e3

Browse files
sushraja-msftgithub-actions[bot]
andauthoredApr 4, 2025··
Add support for uint8_t as data type for GatherBlockQuantized (#24239)
### Description This change adds support for GatherBlockQuantized to use uin8_t as data's type with the same semantics as MatMulNBits. Zero_Points and Gather Axis other than 0 are not yet supported, in order to keep the change scoped. ### Motivation and Context With the newer llama models like Phi4 trained with shared embeddings, the weights of the lm_head matrix and the embeddings table are exactly the same. These embeddings are huge, unquantized embeddings are 1.2GB in Phi4 mini instruct, at int4 quantization the weights are still 300MB. We can go a step further and have these two ops the lm_head matmulnbits and GatherBlockQuantized share the same weights, that would save 300MB on the model size. The two things that hinder that are the shape expectations for GatherBlockQuantized and the data type supported for data in GatherBlockQuantized. The shape can be solved via a simple reshape op, but the data type needs code changes and that is what this change does. Here is Phi4 modified with shared weights between lm_head and matmulnbits, this model is just 2.1GB on disk. <img width="164" alt="image" src="https://github.com/user-attachments/assets/8bdddbb9-5b44-4839-ab48-605bee53d66b" /> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 11fda2a commit a4976e3

File tree

6 files changed

+142
-16
lines changed

6 files changed

+142
-16
lines changed
 

‎docs/ContribOperators.md

+3-2
Original file line numberDiff line numberDiff line change
@@ -2039,10 +2039,11 @@ This version of the operator has been available since version 1 of the 'com.micr
20392039
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
20402040
`block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ..
20412041
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
2042-
If `zero_points` is not provided, 0 is the zero point.
2042+
If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8.
20432043
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
20442044
to dequantize the output.
20452045
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
2046+
5. For uint8 data, the `gather_axis` must be 0.
20462047

20472048
#### Version
20482049

@@ -2082,7 +2083,7 @@ This version of the operator has been available since version 1 of the 'com.micr
20822083
#### Type Constraints
20832084

20842085
<dl>
2085-
<dt><tt>T1</tt> : tensor(int4), tensor(uint4)</dt>
2086+
<dt><tt>T1</tt> : tensor(int4), tensor(uint4), tensor(uint8)</dt>
20862087
<dd>Constrain quantized types.</dd>
20872088
<dt><tt>T2</tt> : tensor(float), tensor(float16), tensor(bfloat16)</dt>
20882089
<dd>Constrain dequantized types.</dd>

‎docs/OperatorKernels.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -515,7 +515,7 @@ Do not modify directly.*
515515
|FusedConv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *in* Z:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
516516
|FusedGemm|*in* A:**T**<br> *in* B:**T**<br> *in* C:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
517517
|FusedMatMul|*in* A:**T**<br> *in* B:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
518-
|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4)<br/> **T2** = tensor(float), tensor(float16)<br/> **Tind** = tensor(int32), tensor(int64)|
518+
|GatherBlockQuantized|*in* data:**T1**<br> *in* indices:**Tind**<br> *in* scales:**T2**<br> *in* zero_points:**T1**<br> *out* output:**T2**|1+|**T1** = tensor(int4), tensor(uint4), tensor(uint8)<br/> **T2** = tensor(float), tensor(float16)<br/> **Tind** = tensor(int32), tensor(int64)|
519519
|GatherND|*in* data:**T**<br> *in* indices:**Tind**<br> *out* output:**T**|1+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(string), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)<br/> **Tind** = tensor(int32), tensor(int64)|
520520
|Gelu|*in* X:**T**<br> *out* Y:**T**|1+|**T** = tensor(float)|
521521
|GreedySearch|*in* input_ids:**I**<br> *in* max_length:**I**<br> *in* min_length:**I**<br> *in* repetition_penalty:**T**<br> *in* vocab_mask:**I**<br> *in* prefix_vocab_mask:**I**<br> *in* attention_mask:**I**<br> *out* sequences:**I**|1+|**T** = tensor(float)|

‎onnxruntime/contrib_ops/cpu/cpu_contrib_kernels.cc

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Fused
3838
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits);
3939
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits);
4040
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4);
41+
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized);
42+
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized);
4143
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized);
4244
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized);
4345
class ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized);
@@ -318,6 +320,8 @@ Status RegisterCpuContribKernels(KernelRegistry& kernel_registry) {
318320
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, float, MatMulNBits)>,
319321
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MLFloat16, MatMulNBits)>,
320322
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, MatMulBnb4)>,
323+
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int32_t, GatherBlockQuantized)>,
324+
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, uint8_t, int64_t, GatherBlockQuantized)>,
321325
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int32_t, GatherBlockQuantized)>,
322326
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, UInt4x2, int64_t, GatherBlockQuantized)>,
323327
BuildKernelCreateInfo<ONNX_OPERATOR_TWO_TYPED_KERNEL_CLASS_NAME(kCpuExecutionProvider, kMSDomain, 1, Int4x2, int32_t, GatherBlockQuantized)>,

‎onnxruntime/contrib_ops/cpu/quantization/gather_block_quantized.cc

+37-8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,21 @@
1616
namespace onnxruntime {
1717
namespace contrib {
1818

19+
namespace {
20+
template <typename T1>
21+
int32_t GetDataElement(const T1* data_ptr, int64_t data_idx) {
22+
return static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
23+
}
24+
25+
template <>
26+
int32_t GetDataElement<uint8_t>(const uint8_t* data_ptr, int64_t data_idx) {
27+
const uint8_t data_val_u8 = data_ptr[data_idx >> 1];
28+
// Weights are stored as (nibble2)(nibble1) in uint8_t.
29+
auto data_val = static_cast<int32_t>((data_idx & 1) ? ((data_val_u8 >> 4) & 0x0F) : (data_val_u8 & 0x0F));
30+
return data_val;
31+
}
32+
} // namespace
33+
1934
template <typename T1, typename Tind>
2035
class GatherBlockQuantized : public OpKernel {
2136
public:
@@ -98,6 +113,12 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
98113
for (int64_t i = p.gather_axis + 1; i < static_cast<int64_t>(data_rank); ++i)
99114
shape.push_back(data_shape[narrow<size_t>(i)]);
100115

116+
// When data is stored as uint8_t, each element has two int4 values.
117+
// The shape in the onnx model reflects that by having the last dimension be half the number of values.
118+
// Ex: For a true data size of 2000x3072, the onnx model would have data of shape 2000x1536.
119+
// However the outputs still need to be of size 2000x3072. Therefore we x2 the last dimension here.
120+
uint32_t components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
121+
shape[shape.size() - 1] = shape.back() * components;
101122
p.output_tensor = context->Output(0, TensorShape(std::move(shape)));
102123

103124
// validate quantization parameters
@@ -106,7 +127,7 @@ Status GatherBlockQuantized<T1, Tind>::PrepareForCompute(OpKernelContext* contex
106127
"data and scales must have the same rank.");
107128
for (size_t i = 0; i < data_shape.NumDimensions(); ++i) {
108129
ORT_RETURN_IF_NOT(i == static_cast<size_t>(p.quantize_axis)
109-
? (data_shape[i] + block_size_ - 1) / block_size_ == scales_shape[i]
130+
? (data_shape[i] * components + block_size_ - 1) / block_size_ == scales_shape[i]
110131
: data_shape[i] == scales_shape[i],
111132
"data and scales do not match shapes.");
112133
}
@@ -165,16 +186,22 @@ Status GatherBlockQuantized<T1, Tind>::CopyDataAndDequantize(const T1* data_ptr,
165186
int64_t output_idx = output_idx_base;
166187
int64_t data_idx = data_idx_base;
167188
for (int64_t i = 0; i < gather_block; ++i, ++output_idx, ++data_idx) {
168-
auto data_val = static_cast<int32_t>(data_ptr[data_idx >> 1].GetElem(narrow<size_t>(data_idx & 1)));
189+
auto data_val = GetDataElement(data_ptr, data_idx);
169190

170191
int64_t x = data_idx / quantize_full_block;
171192
int64_t y = data_idx % quantize_full_block / quantize_N;
172193
int64_t z = data_idx % quantize_N;
173194
int64_t scale_idx = x * scale_full_block + y / block_size_ * quantize_N + z;
174195
auto scale_val = static_cast<float>(scales_ptr[scale_idx]);
175-
auto zp_val = static_cast<int32_t>(zero_points_ptr
176-
? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
177-
: 0);
196+
int32_t zp_val;
197+
if constexpr (std::is_same_v<T1, uint8_t>) {
198+
// The default zero point for uint8 weights as stored by MatMulNBits op is 8.
199+
zp_val = 8;
200+
} else {
201+
zp_val = static_cast<int32_t>(zero_points_ptr
202+
? zero_points_ptr[scale_idx >> 1].GetElem(narrow<size_t>(scale_idx & 1))
203+
: 0);
204+
}
178205

179206
output_ptr[output_idx] = static_cast<T2>(static_cast<float>(data_val - zp_val) * scale_val);
180207
}
@@ -205,7 +232,7 @@ template <typename T1, typename Tind>
205232
Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
206233
Prepare p;
207234
ORT_RETURN_IF_ERROR(PrepareForCompute(context, p));
208-
235+
auto components = (std::is_same_v<T1, uint8_t>) ? 2 : 1;
209236
const auto& data_shape = p.data_tensor->Shape();
210237
// re-shape the data tensor to [gather_M, gather_axis_dim, gather_block]
211238
// re-shape the indices tensor to [gather_N]
@@ -215,7 +242,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
215242
// 2> block is picked from data based on value from indices: axis_i = indices[blk_i % gather_N],
216243
// 3> get the corresponding block in data tensor: data_blk = data[blk_i / gather_N, axis_i, :],
217244
// 4> pick the element from the block: value_i = data_blk[blk_ele_i]
218-
const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1);
245+
const int64_t gather_block = data_shape.SizeFromDimension(SafeInt<size_t>(p.gather_axis) + 1) * components;
219246
const int64_t gather_axis_dim = data_shape[narrow<size_t>(p.gather_axis)];
220247
const int64_t gather_M = data_shape.SizeToDimension(narrow<size_t>(p.gather_axis));
221248
const int64_t gather_N = p.indices_tensor->Shape().Size();
@@ -229,7 +256,7 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
229256
// data_i % (quantize_axis_dim * quantize_N) / quantize_N,
230257
// data_i % quantize_N)
231258
// 4> get scale index: (x, y / block_size_, z)
232-
const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)];
259+
const int64_t quantize_axis_dim = data_shape[narrow<size_t>(p.quantize_axis)] * components;
233260
const int64_t quantize_N = data_shape.SizeFromDimension(SafeInt<size_t>(p.quantize_axis) + 1);
234261

235262
concurrency::ThreadPool* tp = context->GetOperatorThreadPool();
@@ -273,6 +300,8 @@ Status GatherBlockQuantized<T1, Tind>::Compute(OpKernelContext* context) const {
273300
.TypeConstraint("Tind", DataTypeImpl::GetTensorType<Tind>()), \
274301
GatherBlockQuantized<T1, Tind>);
275302

303+
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int32_t);
304+
REGISTER_GATHERBLOCKQUANTIZED(uint8_t, int64_t);
276305
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int32_t);
277306
REGISTER_GATHERBLOCKQUANTIZED(UInt4x2, int64_t);
278307
REGISTER_GATHERBLOCKQUANTIZED(Int4x2, int32_t);

‎onnxruntime/core/graph/contrib_ops/contrib_defs.cc

+17-4
Original file line numberDiff line numberDiff line change
@@ -3571,10 +3571,11 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
35713571
1. Input `data` is a constant. It is quantized block-wise along attribute `quantize_axis` with block size specified by attribute `block_size`.
35723572
`block_size must` be a power of 2 and not smaller than 16, like 16, 32, 64, 128, ..
35733573
2. Input `data`'s scale and zero point are specified by input `scales` and `zero_points`. `scales` and `zero_points` are also constants.
3574-
If `zero_points` is not provided, 0 is the zero point.
3574+
If `zero_points` is not provided, 0 is the zero point except when data is uint8 type then the default zero point is 8.
35753575
3. During the op execution, `data` and `indices` are first used to generate the quantized output. Then, `scales` and `zero_points` are used
35763576
to dequantize the output.
35773577
4. The `output` and `scales` have the same type. The `data` and `zero_points` have the same type.
3578+
5. For uint8 data, the `gather_axis` must be 0.
35783579
)DOC";
35793580

35803581
ONNX_CONTRIB_OPERATOR_SCHEMA(GatherBlockQuantized)
@@ -3602,7 +3603,7 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
36023603
.Input(2, "scales", "quantization scale", "T2")
36033604
.Input(3, "zero_points", "quantization zero points", "T1", OpSchema::Optional)
36043605
.Output(0, "output", "Dequantized output tensor of rank q + (r - 1).", "T2")
3605-
.TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)"}, "Constrain quantized types.")
3606+
.TypeConstraint("T1", {"tensor(int4)", "tensor(uint4)", "tensor(uint8)"}, "Constrain quantized types.")
36063607
.TypeConstraint("T2", {"tensor(float)", "tensor(float16)", "tensor(bfloat16)"}, "Constrain dequantized types.")
36073608
.TypeConstraint("Tind", {"tensor(int32)", "tensor(int64)"}, "Constrain indices to integer types.")
36083609
.TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
@@ -3637,21 +3638,30 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
36373638
gather_axis = (gather_axis + r) % r;
36383639
quantize_axis = (quantize_axis + r) % r;
36393640

3641+
if ((ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) && gather_axis != 0) {
3642+
fail_shape_inference("gather_axis must be 0, for uint8 data");
3643+
}
3644+
36403645
if (scales_shape.dim_size() != r) {
36413646
fail_shape_inference("scales must have the same rank as data");
36423647
}
36433648

3649+
uint32_t components = ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8 ? 2 : 1;
36443650
for (int i = 0; i < r; ++i) {
36453651
if (!data_shape.dim(i).has_dim_value() ||
36463652
!scales_shape.dim(i).has_dim_value() ||
3647-
(i == quantize_axis && (data_shape.dim(i).dim_value() + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
3653+
(i == quantize_axis && (data_shape.dim(i).dim_value() * components + block_size - 1) / block_size != scales_shape.dim(i).dim_value()) ||
36483654
(i != quantize_axis && data_shape.dim(i).dim_value() != scales_shape.dim(i).dim_value())) {
36493655
fail_shape_inference("data shape and scales shape do not match");
36503656
}
36513657
}
36523658

36533659
// validate zero point shape
36543660
if (ctx.hasInput(3)) {
3661+
if (ctx.getInputType(0)->tensor_type().elem_type() == onnx::TensorProto_DataType_UINT8) {
3662+
fail_type_inference("zero_points are not supported for uint8_t data type");
3663+
}
3664+
36553665
if (!hasInputShape(ctx, 3)) {
36563666
fail_shape_inference("zero_points shape must be known");
36573667
}
@@ -3675,12 +3685,15 @@ GatherBlockQuantized is a Gather with data quantized. It is similar to Gather (h
36753685
ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape();
36763686
}
36773687
for (int i = 0; i < out_rank; ++i) {
3688+
// For uint8_t data type the last dimension needs to be expanded back to actual dimension,
3689+
// because the data 2 int4s are stored packed in a single uint8_t.
3690+
auto last_dimension_components = (i == out_rank - 1) ? components : 1;
36783691
*ctx.getOutputType(0)->mutable_tensor_type()->mutable_shape()->add_dim() =
36793692
(i < gather_axis)
36803693
? data_shape.dim(i)
36813694
: (i >= gather_axis && i < gather_axis + q)
36823695
? indices_shape.dim(i - gather_axis)
3683-
: data_shape.dim(i - q + 1);
3696+
: data_shape.dim(i - q + 1) * last_dimension_components;
36843697
}
36853698
});
36863699

0 commit comments

Comments
 (0)
Please sign in to comment.