@@ -50,8 +50,8 @@ def __init__(
quant_axes (dict[str, int], optional):
    op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
customized_weight_config:
-    customized weight config for nodes if needed.
-    If both customized_weight_config and nodes_to_exclude are set, nodes_to_exclude overwrites customized_weight_config.
+    customized weight config for nodes if needed. It is a dictionary with node name as key,
+    and the value is a dict of customized config.
"""
self.algorithm = algorithm
self.quant_format = quant_format
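
A minimal sketch of the customized_weight_config shape described in the docstring above. The node names and per-node override keys here are hypothetical; which keys are actually honored depends on the quantizer and algorithm in use.

```python
# Hypothetical example: map node name -> dict of per-node overrides.
customized_weight_config = {
    "/model/layers.0/attn/qkv_proj/MatMul": {"bits": 8},  # assumed override key
    "/model/embed_tokens/Gather": {"block_size": 64},     # assumed override key
}
```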
@@ -81,6 +81,9 @@ def __init__(
Defaults to QuantFormat.QOperator.
op_types_to_quantize (optional):
    set of operator types to quantize.
+ customized_weight_config:
+     customized weight config for nodes if needed. It is a dictionary with node name as key,
+     and the value is a dict of customized config.
"""
assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"
@@ -220,6 +223,8 @@ def __init__(
set of operator types to quantize.
quant_axes (dict[str, int], optional):
    op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
+ bits (int, optional):
+     number of bits per element after quantization. Default 4.
"""
super().__init__(
    algorithm="DEFAULT",
@@ -661,8 +666,10 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
scales_torch = scales_torch.contiguous()
zero_points_torch = zero_points_torch.contiguous()

+ packed_size = 8 // self.config.bits  # number of elements packed into one byte
+
packed_torch = torch.zeros(
-     (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
+     (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // packed_size),
    dtype=torch.uint8,
    device=quant_weight_torch.device,
)
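
The new packed_size factor drives both the packed tensor width above and blob_size in the next hunk; a quick sketch of the arithmetic (the block size of 32 is just an illustrative value):

```python
# How many quantized elements fit in one uint8 byte for the supported bit widths.
for bits in (4, 8):
    packed_size = 8 // bits              # 2 elements/byte for 4-bit, 1 for 8-bit
    block_size = 32                      # illustrative block size
    blob_size = block_size // packed_size
    print(f"{bits=} {packed_size=} {blob_size=}")  # 16 bytes/block vs 32 bytes/block
```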
@@ -674,12 +681,12 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
zero_points = zero_points.reshape(-1)
rows, cols = b_array_torch.shape
block_size = self.config.block_size
- blob_size = block_size // 2
+ blob_size = block_size // packed_size
k_blocks = (rows + block_size - 1) // block_size
packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)

b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
- b_quant.name = b_pb.name + "_Q4"
+ b_quant.name = b_pb.name + "_Q" + str(self.config.bits)
for input in bs_graph.input:
    if input.name == input_b:
        bs_graph.input.remove(input)
@@ -702,18 +709,18 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
kwargs["bits"] = self.config.bits
kwargs["block_size"] = self.config.block_size

- matmul_q4_node = onnx.helper.make_node(
+ matmul_q_node = onnx.helper.make_node(
    "MatMulNBits",
    inputs=input_names,
    outputs=[node.output[0]],
-     name=node.name + "_Q4" if node.name else "",
+     name=node.name + "_Q" + str(self.config.bits) if node.name else "",
    domain="com.microsoft",
    **kwargs,
)

logger.info(f"complete quantization of {node.name} ...")

- return [matmul_q4_node]
+ return [matmul_q_node]


def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
@@ -761,7 +768,7 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
    packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
)
else:
-     # QDQ format only support 4 bits quantization
+     assert qbits == 4, "QDQ format only supports 4 bits quantization"
    packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
    zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
    scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
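
For the QDQ path the buffers above are sized for 4-bit packing only, which is what the new assertion guards; a quick check of that sizing with illustrative shapes (values chosen here, not taken from the PR):

```python
# Illustrative check of the 4-bit QDQ buffer sizes computed above.
rows, cols, block_size = 300, 128, 32
k_blocks = (rows + block_size - 1) // block_size       # 10 blocks along K
packed_bytes = (rows * cols + 1) // 2                  # two 4-bit weights per byte
zero_point_bytes = (cols * k_blocks + 1) // 2          # packed 4-bit zero points
print(packed_bytes, zero_point_bytes)                  # 19200 640
```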
@@ -798,6 +805,7 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
    b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + f"_Q{bits}")
    scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales")
else:
+     logger.info(f"quantize {b_tensor.name} to {bits} bits: {b_ndarray.shape=} {packed.tobytes().size=}")
    b_quant = onnx.helper.make_tensor(
        b_tensor.name + f"_DQ_Q{bits}", qtype, b_ndarray.shape, packed.tobytes(), True
    )
@@ -1095,14 +1103,13 @@ def quantize_awq(self, model: ModelProto | str) -> ModelProto:
    return quantized_model


- # TODO(fajin): change class name
- class MatMul4BitsQuantizer:
+ class MatMulNBitsQuantizer:
    """
    Target node:        QOperator node:         QDQ nodes:
    MatMul              MatMulNBits             DeQuantizeLinear -> MatMul
    Gather              GatherBlockQuantized    Gather, Gather, Gather (optional) -> DequantizeLinear

-     Perform 4b quantization of constant weights for target nodes.
+     Perform 4/8 bits quantization of constant weights for target nodes.
    If algo_config.quant_format is QOperator:
      - nodes are replaced by the corresponding QOperator nodes.
      - quantized weights are stored in the contrib ops.
@@ -1114,6 +1121,7 @@ class MatMul4BitsQuantizer:
    Note:
      - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
        during runtime. Therefore it is not recommended.
+       - when a node is in nodes_to_exclude, the node's configuration in algo_config.customized_weight_config will be ignored.
    """

    def __init__(
@@ -1148,8 +1156,13 @@ def __init__(
    quant_format=quant_format,
    op_types_to_quantize=op_types_to_quantize,
    quant_axes=quant_axes,
+     bits=4,  # default to 4 bits
)
+
self.algo_config = algo_config
+ if hasattr(self.algo_config, "bits"):
+     assert self.algo_config.bits in [4, 8], "Only 4 or 8 bits quantization is supported"
+
if algo_config.algorithm == "HQQ":
    self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
elif algo_config.algorithm == "DEFAULT":
@@ -1511,7 +1524,7 @@ def parse_args():
else:
    raise ValueError(f"Unsupported quantization method: {args.quant_method}")

- quant = MatMul4BitsQuantizer(
+ quant = MatMulNBitsQuantizer(
    model=model,
    accuracy_level=args.accuracy_level,
    nodes_to_exclude=args.nodes_to_exclude,
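
To close, a rough usage sketch of the renamed class. The import path, the choice of DefaultWeightOnlyQuantConfig as the config carrying the new bits argument, and the file names are assumptions for illustration, not taken from this PR.

```python
# Sketch only: assumed import path and parameter values.
from onnxruntime.quantization.matmul_nbits_quantizer import (  # path assumed
    DefaultWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
)

algo_config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True, bits=8)
quant = MatMulNBitsQuantizer("model.onnx", algo_config=algo_config)
quant.process()
quant.model.save_model_to_file("model_int8.onnx", use_external_data_format=True)
```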