|
| 1 | +import sys |
| 2 | + |
1 | 3 | import neural_compressor as inc |
2 | 4 | print("neural_compressor version {}".format(inc.__version__)) |
3 | 5 |
|
4 | | -import alexnet |
5 | | -import math |
6 | | -import yaml |
7 | | -import mnist_dataset |
| 6 | +import tensorflow as tf |
| 7 | +print("tensorflow {}".format(tf.__version__)) |
| 8 | + |
| 9 | +from neural_compressor.config import PostTrainingQuantConfig, AccuracyCriterion, TuningCriterion |
| 10 | +from neural_compressor.data import DataLoader |
8 | 11 | from neural_compressor.quantization import fit |
9 | | -from neural_compressor.config import PostTrainingQuantConfig, TuningCriterion, AccuracyCriterion |
| 12 | +from neural_compressor import Metric |
10 | 13 |
|
| 14 | +import mnist_dataset |
11 | 15 |
|
12 | | -def save_int8_frezon_pb(q_model, path): |
13 | | - from tensorflow.python.platform import gfile |
14 | | - f = gfile.GFile(path, 'wb') |
15 | | - f.write(q_model.graph.as_graph_def().SerializeToString()) |
16 | | - print("Save to {}".format(path)) |
17 | 16 |
|
class Dataset:
    """Indexable view over the test split returned by ``mnist_dataset.read_data()``.

    Implements ``__getitem__`` and ``__len__`` so an instance can be passed as
    the ``dataset`` argument of ``neural_compressor.data.DataLoader``.
    """

    def __init__(self):
        """Read the data once and keep only the test images and test labels."""
        # read_data() is unpacked as a 6-tuple; only the 4th and 6th elements
        # are used, so the remaining names are underscore-prefixed.
        (_x_train, _y_train, _label_train,
         x_test, _y_test, label_test) = mnist_dataset.read_data()

        self.test_images = x_test
        self.labels = label_test

    def __getitem__(self, index):
        """Return the ``(image, label)`` pair stored at ``index``."""
        return self.test_images[index], self.labels[index]

    def __len__(self):
        """Return the number of test images."""
        return len(self.test_images)
34 | 30 |
|
def auto_tune(input_graph_path, batch_size, *, max_trials=100, tolerable_loss=0.01):
    """Run static post-training quantization on a model with accuracy-driven tuning.

    Args:
        input_graph_path: Model passed straight through to ``fit(model=...)``;
            the script in this file passes the path of a frozen ``.pb`` file.
        batch_size: Batch size for the dataloader built over ``Dataset()``.
        max_trials: Forwarded to ``TuningCriterion(max_trials=...)``.
        tolerable_loss: Forwarded to ``AccuracyCriterion(tolerable_loss=...)``,
            which is configured with ``criterion='relative'`` and
            ``higher_is_better=True``.

    Returns:
        Whatever ``neural_compressor.quantization.fit`` returns.
        NOTE(review): this is presumably ``None`` when no tuned model meets the
        accuracy criterion -- confirm against the installed version.
    """
    dataset = Dataset()
    # One dataloader serves both calibration and evaluation below.
    dataloader = DataLoader(
        framework='tensorflow',
        dataset=dataset,
        batch_size=batch_size,
    )

    tuning_criterion = TuningCriterion(max_trials=max_trials)
    accuracy_criterion = AccuracyCriterion(
        higher_is_better=True,
        criterion='relative',
        tolerable_loss=tolerable_loss,
    )
    config = PostTrainingQuantConfig(
        approach="static",
        tuning_criterion=tuning_criterion,
        accuracy_criterion=accuracy_criterion,
    )

    # Top-1 accuracy is the metric the tuner evaluates candidates with.
    top1 = Metric(name="topk", k=1)

    q_model = fit(
        model=input_graph_path,
        conf=config,
        calib_dataloader=dataloader,
        eval_dataloader=dataloader,
        eval_metric=top1,
    )
    return q_model
52 | 53 |
|
53 | 54 |
|
# Batch size for the calibration/evaluation dataloader, and the input/output
# model files used by this script.
batch_size = 200
fp32_frozen_pb_file = "fp32_frozen.pb"
int8_pb_file = "alexnet_int8_model.pb"

q_model = auto_tune(fp32_frozen_pb_file, batch_size)

# NOTE(review): fit() presumably yields None when tuning finds no model that
# satisfies the accuracy criterion -- confirm against the installed version.
# Fail with a clear message instead of an AttributeError on `None.save`.
if q_model is None:
    sys.exit(
        "Quantization did not produce a model; nothing saved to {}".format(int8_pb_file)
    )
q_model.save(int8_pb_file)
0 commit comments