-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathprune_model.py
More file actions
261 lines (221 loc) · 11.3 KB
/
prune_model.py
File metadata and controls
261 lines (221 loc) · 11.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
''' AdapNet++: Self-Supervised Model Adaptation for Multimodal Semantic Segmentation
Copyright (C) 2018 Abhinav Valada, Rohit Mohan and Wolfram Burgard
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.'''
import importlib
import json
import numpy as np
import os
import pickle
import re
import tensorflow as tf
from matplotlib import pyplot as plt
from models.mapping import mapping
# ---------------------------------------------------------------------------
# Pruning configuration -- edit the constants below before running the script.
# ---------------------------------------------------------------------------
decide_threshold = False # True: only visualize the histogram of normalized l2 rank values (no pruning happens)
phase = 'decoder_conv1' #conv1, block1/unit_{1,2,3}, block2/unit_{1,2,3,4}, block3/unit_{1,2,3,4,5,6}, block4unit_{1,2,3}, easpp, upsample1, upsample2, decoder_conv1, decoder_conv2
# upsample1 is the first deconvolution layer w.r.t. the input side of the network, followed by upsample2 and upsample3
# decoder_conv1 is the 3x3 conv layer pair between the upsample1 and upsample2 layers
# decoder_conv2 is the 3x3 conv layer pair between the upsample2 and upsample3 layers
all_convs = False # True: run for all convolution layers under the given phase; pruning layer by layer is recommended rather than all at once
conv_name = 'conv1' # for a block phase: conv1, conv2, conv3; if the block is multi-phase use conv2_convolution (blocks 3 and 4) or conv2_convolution_1 (block 4)
# for decoder_convs: conv1 and conv2
# for easpp:
#   'conv256' 1x1
#   'conv70' 1x1 or 'conv7' 3x3 first atrous or 'conv247' 3x3 second atrous or 'conv71' 1x1 (rate=3 unit)
#   'conv80' 1x1 or 'conv8' 3x3 first atrous or 'conv248' 3x3 second atrous or 'conv81' 1x1 (rate=6 unit)
#   'conv90' 1x1 or 'conv9' 3x3 first atrous or 'conv249' 3x3 second atrous or 'conv91' 1x1 (rate=12 unit)
#   'conv10' final 1x1
save_nvidia_name = 'n1.p' # path to the pickled NVIDIA-criterion rank values
list_of_convs = ['conv256/', 'conv70/', 'conv80/', 'conv90/', 'conv10/', 'conv71/', 'conv81/', 'conv91/', 'conv7/', 'conv8/', 'conv9/', 'conv247/', 'conv248/', 'conv249/']
threshold = 0.2 # channels whose normalized rank value falls below this are removed
checkpoint_name = '/home/mohan/AdapNet_Training/checkpoint_final_cityuni/adapnet_sc-121999' # path to the original (unpruned) checkpoint
gpu_id = '1' # gpu id
try_zeros = False # True: parameters are not eliminated; their weights are set to 0 instead (shapes unchanged)
new_checkpoint_save = 'check/adapnet_sc-122999' # save path for the pruned or zeroed checkpoint
mask_load = None # path to an already-existing mask to update, if any
mask_save = 'temp.npy' # where to save the mask; used with conv3 of each block to handle the shortcut connection
Num_classes = 12 # Number of Classes
height = 384 # input height of the model
width = 768 # input width of the model
model_def='models/default.json' # current model definition; if try_zeros = True, evaluate using this model_def with the new checkpoint
new_model_def='models/1.json' # path for the new pruned model definition, written when try_zeros = False
def get_l2_norm(x):
    """Normalize a vector of per-channel rank values into [0, 1].

    Negative ranks are clamped to zero before normalization.  Operates on a
    copy -- the previous version zeroed negatives in the caller's array in
    place, silently corrupting ``rank_values``.

    Args:
        x: 1-D numpy array of per-channel rank values.

    Returns:
        numpy array of the same shape with values in [0, 1]; all zeros when
        every rank is non-positive (the old code produced 0/0 -> NaN there).
    """
    clamped = np.maximum(x, 0.0)   # copy; never mutate the input array
    l2 = np.sqrt(clamped ** 2)     # elementwise magnitude (== clamped here)
    peak = np.max(l2)
    if peak == 0:                  # all ranks <= 0: nothing to scale by
        return l2
    return l2 / peak
def get_mask_id(x):
    """Return the indices where ``x`` equals zero -- i.e. the channels that
    survive pruning -- as an int32 column vector of shape (k, 1)."""
    kept = np.nonzero(x == 0)[0]
    return kept.reshape(-1, 1).astype(np.int32)
# Load the per-channel rank values computed by the NVIDIA pruning criterion.
# Pickle data is binary, so open in 'rb', and use a context manager so the
# handle is closed (the original leaked an open text-mode file object).
with open(save_nvidia_name, 'rb') as f:
    rank_values = pickle.load(f)
compute = False  # flag: the current op matched the phase/conv selection below
trimmed = []     # list of [op_name, boolean prune-mask] pairs
# Walk every op recorded in the rank file and flag ('compute') the ones that
# belong to the layer(s) selected by `phase` / `conv_name` / `all_convs`;
# for each flagged op either show the rank histogram (decide_threshold) or
# build a boolean prune-mask of channels below `threshold`.
for op_name in rank_values:
    if phase == 'conv1':
        # Stem convolution only: exclude conv1 ops inside residual blocks.
        if phase+'/' in op_name and 'block' not in op_name:
            compute = True
    elif 'block' in phase:
        if phase+'/' in op_name:
            if all_convs:
                compute = True
            else:
                if conv_name+'/' in op_name:
                    compute = True
                elif len(conv_name.split('_')) == 2:
                    # multi-phase name such as 'conv2_convolution'
                    parts = conv_name.split('_')
                    if parts[0]+'/' in op_name and parts[1] in op_name:
                        compute = True
                else:
                    # longer names such as 'conv2_convolution_1'.
                    # BUG FIX: the original tested '_'.join(parts[:1]), which
                    # is just parts[0] again, so any 'conv2' op matched
                    # (including plain 'conv2_convolution'); match the full
                    # suffix ('convolution_1') instead.
                    parts = conv_name.split('_')
                    if parts[0]+'/' in op_name and '_'.join(parts[1:]) in op_name:
                        compute = True
    elif 'easpp' in phase:
        if all_convs:
            for some_conv in list_of_convs:
                if some_conv in op_name:
                    compute = True
        elif conv_name+'/' in op_name and conv_name+'/' in list_of_convs:
            compute = True
    elif 'decoder_conv1' in phase:
        # decoder_conv1 pair maps to checkpoint scopes conv89 / conv96
        if ('conv1' in conv_name or all_convs) and 'conv89/' in op_name:
            compute = True
        elif ('conv2' in conv_name or all_convs) and 'conv96/' in op_name:
            compute = True
    elif 'decoder_conv2' in phase:
        # decoder_conv2 pair maps to checkpoint scopes conv88 / conv95
        if ('conv1' in conv_name or all_convs) and 'conv88/' in op_name:
            compute = True
        elif ('conv2' in conv_name or all_convs) and 'conv95/' in op_name:
            compute = True
    elif 'upsample1' in phase and 'conv41/' in op_name:
        compute = True
    elif 'upsample2' in phase and 'conv16/' in op_name:
        compute = True
    if compute:
        print(op_name)  # parenthesized: valid in both Python 2 and 3
        l2 = get_l2_norm(rank_values[op_name])
        sorted_ = np.argsort(l2)
        norm_val_sorted = l2[sorted_]
        if decide_threshold:
            # Visualization-only mode: inspect the distribution to pick a
            # threshold; nothing is recorded for pruning.
            plt.hist(norm_val_sorted)
            plt.show()
        else:
            mask = l2<threshold  # True = channel to prune
            trimmed.append([op_name, mask])
        compute = False  # reset for the next op
# Pruning path (decide_threshold False): edit the checkpoint tensors and the
# json model definition, then write a new checkpoint.
if decide_threshold == False:
    with open(model_def) as f:
        model_definition = json.load(f)
    mask_id = {} # layer id -> indices of kept channels (shortcut handling)
    if mask_load is not None:
        mask_id = np.load(mask_load)[()] # resume from a previously saved mask dict
    reader=tf.train.NewCheckpointReader(checkpoint_name)
    weights_str = reader.debug_string() # full variable listing (debug aid; not used below)
    exclude_variables = {} # variable name -> edited numpy tensor (assigned manually, not restored)
    tensor_list = []
    mask_exist = False # becomes True once any shortcut mask is produced
    # For every selected layer either zero-out the pruned channels (try_zeros)
    # or physically delete them from every associated tensor and shrink the
    # model definition accordingly.
    for trim in trimmed: # trim = [op_name, boolean mask of channels to prune]
        if try_zeros:
            # Zeroing mode: tensor shapes stay intact; only null the pruned
            # output channels of the conv kernel.
            name = ('/').join(trim[0].split('/')[:-1])+'/weights'
            if name not in exclude_variables:
                tensor = reader.get_tensor(name)
            else:
                tensor = exclude_variables[name] # already edited earlier; keep edits
            mask = trim[1]
            # Layers feeding a concatenated ('split') tensor only cover one
            # half of it: pad the mask with zeros for the other half.
            if 'split' in mapping[trim[0]].keys() and mapping[trim[0]]['split'] == 2:
                mask = np.concatenate((np.zeros(model_definition['split'][mapping[trim[0]]['id']][0],dtype = mask.dtype), mask), -1)
            elif 'split' in mapping[trim[0]].keys() and mapping[trim[0]]['split'] == 1:
                mask = np.concatenate((mask, np.zeros(model_definition['split'][mapping[trim[0]]['id']][1],dtype = mask.dtype)), -1)
            tensor[:,:,:,mask] = 0.0 # output channels are the last kernel axis
            exclude_variables[name] = tensor
        else:
            # Pruning mode: remove the channels and record the surviving
            # channel counts in the json model definition.
            mask = trim[1]
            if 'split' in mapping[trim[0]].keys() and mapping[trim[0]]['split'] == 2:
                temp = mask.copy()
                mask = np.concatenate((np.zeros(model_definition['split'][mapping[trim[0]]['id']][0],dtype = mask.dtype), mask), -1)
                model_definition['split'][mapping[trim[0]]['id']][1] = np.sum(temp==0) # kept channels of this half
            elif 'split' in mapping[trim[0]].keys() and mapping[trim[0]]['split'] == 1:
                temp = mask.copy()
                mask = np.concatenate((mask, np.zeros(model_definition['split'][mapping[trim[0]]['id']][1],dtype = mask.dtype)), -1)
                model_definition['split'][mapping[trim[0]]['id']][0] = np.sum(temp==0) # kept channels of this half
            # The layer's batch-norm vectors must shrink along with the conv
            # kernel's output channels.
            if mapping[trim[0]]['BatchNorm']:
                stuffs = ['/weights','/BatchNorm/beta','/BatchNorm/gamma','/BatchNorm/moving_mean','/BatchNorm/moving_variance']
            else:
                stuffs = ['/weights']
            for stuff in stuffs:
                name = ('/').join(trim[0].split('/')[:-1])+stuff
                if name not in exclude_variables:
                    tensor = reader.get_tensor(name)
                else:
                    tensor = exclude_variables[name]
                # Output channels sit on axis 3 for conv kernels, axis 2 for
                # transposed-conv kernels, axis 0 for 1-D batch-norm vectors.
                if stuff == '/weights' and 'transpose' not in trim[0]:
                    tensor=np.delete(tensor,np.argwhere(mask==1),3)
                elif stuff == '/weights' and 'transpose' in trim[0]:
                    tensor=np.delete(tensor,np.argwhere(mask==1),2)
                else:
                    tensor=np.delete(tensor,np.argwhere(mask==1),0)
                exclude_variables[name] = tensor
                # Record the surviving channel count in the model definition.
                # NOTE(review): indentation was lost during extraction; this
                # update is assumed to sit inside the per-tensor loop (the
                # count is the same for every tensor of the layer) -- confirm
                # against the upstream file.
                if 'place' not in mapping[trim[0]].keys():
                    model_definition['params'][mapping[trim[0]]['id']] = tensor.shape[-1]
                else:
                    model_definition['params'][mapping[trim[0]]['id']][mapping[trim[0]]['place']] = tensor.shape[-1]
            # Shortcut connections need the kept-channel indices at run time.
            if 'mask' in mapping[trim[0]].keys():
                mask_id[mapping[trim[0]]['id']] = get_mask_id(mask)
                mask_exist = True
            # Downstream consumers ('next*' entries in the mapping) lose the
            # same channels on their input axis (axis 3 for transposed convs,
            # axis 2 otherwise).
            for key in mapping[trim[0]]:
                if 'next' in key:
                    name = ('/').join(mapping[trim[0]][key].split('/')[:-1])+'/weights'
                    if name not in exclude_variables:
                        tensor = reader.get_tensor(name)
                    else:
                        tensor = exclude_variables[name]
                    if 'transpose' in mapping[trim[0]][key]:
                        tensor=np.delete(tensor,np.argwhere(mask==1),3)
                    else:
                        tensor=np.delete(tensor,np.argwhere(mask==1),2)
                    exclude_variables[name] = tensor
if try_zeros == False:
with open(new_model_def, 'w') as f:
json.dump(model_definition, f)
model_def = new_model_def
if mask_exist:
with open(mask_save, 'w') as f:
np.save(f, mask_id)
else:
mask_save = None
    # ----- Rebuild the (pruned) graph and write out the new checkpoint. -----
    os.environ['CUDA_VISIBLE_DEVICES'] = gpu_id
    module = importlib.import_module('models.' + 'AdapNet_pp')
    model_func = getattr(module, 'AdapNet_pp')
    resnet_name = 'resnet_v2_50' # scope prefix must match the checkpoint's variable names
    with tf.variable_scope(resnet_name):
        # NOTE(review): indentation was lost during extraction; the graph is
        # assumed to be built inside this scope so that variable names line up
        # with the checkpoint -- confirm against the upstream file.
        model = model_func(num_classes=Num_classes, training=False, model_def=model_def, mask=mask_save)
        images_pl = tf.placeholder(tf.float32, [None, height, width, 3])
        model.build_graph(images_pl)
    config1 = tf.ConfigProto()
    config1.gpu_options.allow_growth = True # do not grab the whole GPU up front
    sess = tf.Session(config=config1)
    sess.run(tf.global_variables_initializer())
    all_variables = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
    # Partition variables: restored verbatim from the old checkpoint vs.
    # edited (pruned/zeroed) tensors assigned from exclude_variables.
    load_var = {}
    index_record = {} # variable name -> index into all_variables
    for i,variable in enumerate(all_variables):
        name_ = variable.name.split(':')[0]
        if name_ not in exclude_variables:
            load_var[name_] = variable
        else:
            index_record[name_] = i
    saver = tf.train.Saver(load_var)
    saver.restore(sess, checkpoint_name) # restore only the untouched variables
    # Overwrite each edited variable with its pruned/zeroed numpy tensor.
    for name_ in exclude_variables:
        print name_+':0'
        process = all_variables[index_record[name_]].assign(exclude_variables[name_])
        _=sess.run(process)
    saver = tf.train.Saver()
    saver.save(sess, new_checkpoint_save) # write the pruned checkpoint