
Commit 459307e

Committed Apr 19, 2025
rename 4bits to nbits
1 parent cd9c02f · commit 459307e

File tree

5 files changed: 51 additions & 29 deletions

- onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py → matmul_nbits_quantizer.py (renamed)
- onnxruntime/python/tools/quantization/quantize.py
- onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py
- onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py
- onnxruntime/test/python/quantization/test_op_matmul_4bits.py
 

onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py renamed to onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py

Lines changed: 26 additions & 13 deletions
@@ -50,8 +50,8 @@ def __init__(
         quant_axes (dict[str, int], optional):
             op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
         customized_weight_config:
-            customized weight config for nodes if needed.
-            If both customized_weight_config and nodes_to_exclude are set, nodes_to_exclude overwrites customized_weight_config.
+            customized weight config for nodes if needed. It is dictionary with node name as key,
+            and the value is a dict of customized config.
         """
         self.algorithm = algorithm
         self.quant_format = quant_format
@@ -81,6 +81,9 @@ def __init__(
             Defaults to QuantFormat.QOperator.
         op_types_to_quantize (optional):
             set of operator types to quantize.
+        customized_weight_config:
+            customized weight config for nodes if needed. It is dictionary with node name as key,
+            and the value is a dict of customized config.
         """
         assert quant_format == QuantFormat.QOperator, "RTN only supports QOperator format"

@@ -220,6 +223,8 @@ def __init__(
         op_types_to_quantize (optional):
             set of operator types to quantize.
         quant_axes (dict[str, int], optional):
             op:axis, which axis to quantize for an op. Default {MatMul: 0, Gather: 1}
+        bits (int, optional):
+            number of bits per element after quantization. Default 4.
         """
         super().__init__(
             algorithm="DEFAULT",
@@ -661,8 +666,10 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
             scales_torch = scales_torch.contiguous()
             zero_points_torch = zero_points_torch.contiguous()
 
+        packed_size = (8 // self.config.bits)  # number of elements packed into one byte
+
         packed_torch = torch.zeros(
-            (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // 2),
+            (quant_weight_torch.shape[0], quant_weight_torch.shape[1] // packed_size),
             dtype=torch.uint8,
             device=quant_weight_torch.device,
         )
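A minimal standalone sketch (not from this commit) of the packing arithmetic this hunk generalizes: with packed_size = 8 // bits, two 4-bit elements share one byte while an 8-bit element occupies a full byte. The nibble order below (first element in the low bits) is an assumption for illustration only; the authoritative layout is whatever the MatMulNBits kernel expects.

import numpy as np

def pack_rowwise(quant: np.ndarray, bits: int) -> np.ndarray:
    # quant holds unsigned values already clipped to [0, 2**bits - 1]
    packed_size = 8 // bits                      # elements per byte: 2 for 4-bit, 1 for 8-bit
    assert quant.shape[1] % packed_size == 0
    q = quant.astype(np.uint8) & ((1 << bits) - 1)
    if packed_size == 1:                         # 8-bit: one element per byte, nothing to pack
        return q
    packed = np.zeros((quant.shape[0], quant.shape[1] // packed_size), dtype=np.uint8)
    for i in range(packed_size):                 # element i of each group lands at bit offset bits * i
        packed |= q[:, i::packed_size] << (bits * i)
    return packed

print(pack_rowwise(np.array([[1, 2, 3, 4]], dtype=np.uint8), 4))  # [[33 67]] == [[0x21 0x43]]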
@@ -674,12 +681,12 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
         zero_points = zero_points.reshape(-1)
         rows, cols = b_array_torch.shape
         block_size = self.config.block_size
-        blob_size = block_size // 2
+        blob_size = block_size // packed_size
         k_blocks = (rows + block_size - 1) // block_size
         packed_torch = packed_torch.reshape(cols, k_blocks, blob_size)
 
         b_quant = onnx.numpy_helper.from_array(packed_torch.cpu().numpy())
-        b_quant.name = b_pb.name + "_Q4"
+        b_quant.name = b_pb.name + "_Q" + str(self.config.bits)
         for input in bs_graph.input:
             if input.name == input_b:
                 bs_graph.input.remove(input)
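A quick illustrative check (not part of the commit) of how the same packed_size factor drives blob_size, the number of bytes each quantization block occupies:

for bits in (4, 8):
    block_size = 32
    packed_size = 8 // bits
    blob_size = block_size // packed_size
    print(f"{bits} bits: {blob_size} bytes per block of {block_size} weights")  # 4 -> 16, 8 -> 32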
@@ -702,18 +709,18 @@ def quantize(self, node: NodeProto, graph_stack: list[GraphProto]) -> list[NodeP
         kwargs["bits"] = self.config.bits
         kwargs["block_size"] = self.config.block_size
 
-        matmul_q4_node = onnx.helper.make_node(
+        matmul_q_node = onnx.helper.make_node(
             "MatMulNBits",
             inputs=input_names,
             outputs=[node.output[0]],
-            name=node.name + "_Q4" if node.name else "",
+            name=node.name + "_Q" + str(self.config.bits) if node.name else "",
             domain="com.microsoft",
             **kwargs,
         )
 
         logger.info(f"complete quantization of {node.name} ...")
 
-        return [matmul_q4_node]
+        return [matmul_q_node]
 
 
 def get_initializer(name, graph_path: list[GraphProto]) -> tuple[TensorProto, GraphProto]:
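For orientation, a hedged standalone sketch of the QOperator replacement this method emits: a com.microsoft MatMulNBits node whose attributes carry the quantization metadata. The attribute set shown (K, N, bits, block_size) and the tensor names are illustrative; the quantizer assembles the actual kwargs and initializer names elsewhere in this file.

import onnx

node = onnx.helper.make_node(
    "MatMulNBits",
    inputs=["A", "W_Q4", "W_scales", "W_zero_points"],  # hypothetical tensor names
    outputs=["Y"],
    name="matmul_0_Q4",          # the suffix now carries the bit width, e.g. _Q4 or _Q8
    domain="com.microsoft",
    K=4096, N=4096, bits=4, block_size=32,
)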
@@ -761,7 +768,7 @@ def qbits_block_quant(self, fp32weight: npt.ArrayLike) -> tuple[np.ndarray, np.n
                 packed, fp32weight, scales, zero_point, block_size, cols, rows, self.config.is_symmetric
             )
         else:
-            # QDQ format only support 4 bits quantization
+            assert qbits == 4, "QDQ format only support 4 bits quantization"
             packed = np.zeros((rows * cols + 1) // 2, dtype="uint8")
             zero_point = np.zeros((cols * k_blocks + 1) // 2, dtype="uint8")
             scales = np.zeros((k_blocks, cols), dtype=fp32weight.dtype)
@@ -798,6 +805,7 @@ def quantize_matmul(self, node: NodeProto, graph_stack: list[GraphProto]) -> lis
             b_quant = onnx.numpy_helper.from_array(packed, b_tensor.name + f"_Q{bits}")
             scales_tensor = onnx.numpy_helper.from_array(scales, b_tensor.name + "_scales")
         else:
+            print(f"quantize {b_tensor.name} to {bits} bits: {b_ndarray.shape=} {packed.tobytes().size=}")
             b_quant = onnx.helper.make_tensor(
                 b_tensor.name + f"_DQ_Q{bits}", qtype, b_ndarray.shape, packed.tobytes(), True
             )
@@ -1095,14 +1103,13 @@ def quantize_awq(self, model: ModelProto | str) -> ModelProto:
         return quantized_model
 
 
-# TODO(fajin): change class name
-class MatMul4BitsQuantizer:
+class MatMulNBitsQuantizer:
     """
     Target node:    QOperator node:         QDQ nodes:
     MatMul          MatMulNBits             DeQuantizeLinear -> MatMul
     Gather          GatherBlockQuantized    Gather, Gather, Gather (optional) -> DequantizeLinear
 
-    Perform 4b quantization of constant weights for target nodes.
+    Perform 4/8 bits quantization of constant weights for target nodes.
     If algo_config.quant_format is QOperator:
     - nodes are replaced by the corresponding QOperator nodes.
     - quantized weights are stored in the contrib ops.
@@ -1114,6 +1121,7 @@ class MatMul4BitsQuantizer:
     Note:
     - for quantized gather, the memory usage of "DequantizeLinear + Gather" is the same as the original Gather
     during runtime. Therefor it is not recommended.
+    - when a node is in nodes_to_exclude, and the node configuration in algo_config.customized_weight_config will be ignored.
     """
 
     def __init__(
@@ -1148,8 +1156,13 @@ def __init__(
                 quant_format=quant_format,
                 op_types_to_quantize=op_types_to_quantize,
                 quant_axes=quant_axes,
+                bits=4,  # default to 4 bits
             )
+
         self.algo_config = algo_config
+        if hasattr(self.algo_config, "bits"):
+            assert self.algo_config.bits in [4, 8], "Only support 4 or 8 bits quantization"
+
         if algo_config.algorithm == "HQQ":
             self.node_quantizer = HQQWeightOnlyQuantizer(self.algo_config)
         elif algo_config.algorithm == "DEFAULT":
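Taken together, a hedged usage sketch of the renamed entry point. The bits keyword on DefaultWeightOnlyQuantConfig is inferred from this diff and the file names are hypothetical; treat this as an illustration, not the library's documented example.

import onnx
from onnxruntime.quantization.matmul_nbits_quantizer import (
    DefaultWeightOnlyQuantConfig,
    MatMulNBitsQuantizer,
)

model = onnx.load("model.onnx")
config = DefaultWeightOnlyQuantConfig(block_size=32, is_symmetric=True, bits=8)
quant = MatMulNBitsQuantizer(model, algo_config=config)  # asserts bits in {4, 8}
quant.process()
quant.model.save_model_to_file("model_int8.onnx", True)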
@@ -1511,7 +1524,7 @@ def parse_args():
     else:
         raise ValueError(f"Unsupported quantization method: {args.quant_method}")
 
-    quant = MatMul4BitsQuantizer(
+    quant = MatMulNBitsQuantizer(
         model=model,
         accuracy_level=args.accuracy_level,
         nodes_to_exclude=args.nodes_to_exclude,

onnxruntime/python/tools/quantization/quantize.py

Lines changed: 2 additions & 2 deletions
@@ -929,11 +929,11 @@ def quantize(
         )
     else:
         # training package doesn't has quantize_matmul_4bits, avoid global import
-        from .matmul_4bits_quantizer import MatMul4BitsQuantizer, WeightOnlyQuantConfig
+        from .matmul_nbits_quantizer import MatMulNBitsQuantizer, WeightOnlyQuantConfig
 
         if isinstance(quant_config, WeightOnlyQuantConfig):
             model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load(model_input)
-            quant = MatMul4BitsQuantizer(model, algo_config=quant_config)
+            quant = MatMulNBitsQuantizer(model, algo_config=quant_config)
             quant.process()
             quant.model.save_model_to_file(model_output, True)
         else:
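A short hedged sketch of what this dispatch enables for callers of the generic quantize() API (paths are hypothetical; RTNWeightOnlyQuantConfig is one of the WeightOnlyQuantConfig subclasses exercised in the tests further down):

from onnxruntime.quantization import quantize
from onnxruntime.quantization.matmul_nbits_quantizer import RTNWeightOnlyQuantConfig

# quantize() sees a WeightOnlyQuantConfig and routes the model through MatMulNBitsQuantizer
quantize("model.onnx", "model_weight_only.onnx", RTNWeightOnlyQuantConfig())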

onnxruntime/python/tools/transformers/models/llama/convert_to_onnx.py

Lines changed: 8 additions & 3 deletions
@@ -32,7 +32,12 @@
 from transformers import AutoConfig, AutoModelForCausalLM
 
 from onnxruntime import quantization as ort_quantization
-from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
+
+from onnxruntime import __version__ as ort_version
+if version.parse(ort_version) < version.parse("1.22.0"):
+    from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
+else:
+    from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
 
 torch_export_onnx_opset_version = 14
 logger = logging.getLogger("")
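Spelled out as a self-contained snippet, the shim relies on packaging.version for the comparison (the real file presumably already imports it elsewhere); treat this as a hedged restatement of the hunk above rather than the file's exact contents:

from packaging import version

from onnxruntime import __version__ as ort_version

if version.parse(ort_version) < version.parse("1.22.0"):
    # older releases only ship the old module and class name
    from onnxruntime.quantization.matmul_4bits_quantizer import (
        MatMul4BitsQuantizer as MatMulNBitsQuantizer,
    )
else:
    from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer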
@@ -714,7 +719,7 @@ def get_args():
         required=False,
         default=32,
         type=int,
-        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
+        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
     )
 
     blockwise_group.add_argument(
@@ -1025,7 +1030,7 @@ def main():
     for fp_path, int4_path in zip(old_paths, new_paths, strict=False):
         if os.path.exists(fp_path):
             model = onnx.load_model(fp_path, load_external_data=True)
-            quant = MatMul4BitsQuantizer(
+            quant = MatMulNBitsQuantizer(
                 model=model,
                 block_size=args.block_size,
                 is_symmetric=True,

onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py

Lines changed: 7 additions & 3 deletions
@@ -16,7 +16,11 @@
 from onnx_model import OnnxModel
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer
+from onnxruntime import __version__ as ort_version
+if version.parse(ort_version) < version.parse("1.22.0"):
+    from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer as MatMulNBitsQuantizer
+else:
+    from onnxruntime.quantization.matmul_nbits_quantizer import MatMulNBitsQuantizer
 
 
 class ConvertPhi2ToONNX:
@@ -160,7 +164,7 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str):
             return
         else:
             assert self.precision == Precision.INT4
-            quant = MatMul4BitsQuantizer(
+            quant = MatMulNBitsQuantizer(
                 model=optimizer.model,
                 block_size=self.block_size,
                 is_symmetric=True,
@@ -351,7 +355,7 @@ def parse_arguments():
         required=False,
         default=16,
         type=int,
-        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_4bits_quantizer.py for details.",
+        help="Block size to quantize with. See https://github.com/microsoft/onnxruntime/blob/main/onnxruntime/python/tools/quantization/matmul_nbits_quantizer.py for details.",
     )
 
     parser.add_argument(

onnxruntime/test/python/quantization/test_op_matmul_4bits.py

Lines changed: 8 additions & 8 deletions
@@ -195,17 +195,17 @@ def quant_test(
         )
 
         # Quantize fp32 model to int4 model
-        from onnxruntime.quantization import matmul_4bits_quantizer
+        from onnxruntime.quantization import matmul_nbits_quantizer
 
         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
-        quant_config = matmul_4bits_quantizer.DefaultWeightOnlyQuantConfig(
+        quant_config = matmul_nbits_quantizer.DefaultWeightOnlyQuantConfig(
             block_size=block_size,
             is_symmetric=is_symmetric,
             quant_format=quant_format,
             op_types_to_quantize=op_types_to_quantize,
             quant_axes=quant_axes,
         )
-        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, algo_config=quant_config)
+        quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(model, algo_config=quant_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)

@@ -260,21 +260,21 @@ def quant_test_with_algo(
         )
 
         # Quantize fp32 model to int4 model
-        from onnxruntime.quantization import matmul_4bits_quantizer
+        from onnxruntime.quantization import matmul_nbits_quantizer
 
         algo_config = None
         if algorithm == "RTN":
             # test RTN algorithm
-            algo_config = matmul_4bits_quantizer.RTNWeightOnlyQuantConfig()
+            algo_config = matmul_nbits_quantizer.RTNWeightOnlyQuantConfig()
         elif algorithm == "GPTQ":
             # test GPTQ algorithm
-            algo_config = matmul_4bits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
+            algo_config = matmul_nbits_quantizer.GPTQWeightOnlyQuantConfig(calibration_data_reader=data_reader)
         elif algorithm == "HQQ":
             # test HQQ algorithm
-            algo_config = matmul_4bits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)
+            algo_config = matmul_nbits_quantizer.HQQWeightOnlyQuantConfig(block_size=block_size)
 
         model = quant_utils.load_model_with_shape_infer(Path(model_fp32_path))
-        quant = matmul_4bits_quantizer.MatMul4BitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
+        quant = matmul_nbits_quantizer.MatMulNBitsQuantizer(model, block_size, is_symmetric, algo_config=algo_config)
         quant.process()
         quant.model.save_model_to_file(model_int4_path, False)

0 commit comments