-
Notifications
You must be signed in to change notification settings - Fork 5.7k
/
Copy pathreduce_lib_size_util.py
128 lines (107 loc) · 3.55 KB
/
reduce_lib_size_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This script removes grad kernel registrations from the operator source files. Use it
when building with cmake ON_INFER=ON; it can greatly reduce the size of the inference library.
"""
import glob
import os
def is_balanced(content):
    """
    Tell whether ``content`` has at least one ``{`` and well-nested brackets.

    Only round and curly brackets are tracked; every other character
    (square brackets included) is ignored.

    Args:
        content (str): text to inspect.

    Returns:
        boolean: True if ``content`` contains a ``{`` and every ``(``/``{``
        is closed by the matching ``)``/``}`` in the right order.
    """
    # Text without any curly brace is never considered balanced.
    if '{' not in content:
        return False

    opener_for = {')': '(', '}': '{'}
    pending = []
    for ch in content:
        if ch == '(' or ch == '{':
            pending.append(ch)
            continue
        if ch not in opener_for:
            continue
        # A closer with nothing open, or closing the wrong opener, is invalid.
        if not pending:
            return False
        if pending.pop() != opener_for[ch]:
            return False
    # Anything still open means the text ended too early.
    return len(pending) == 0
def grad_kernel_definition(content, kernel_pattern, grad_pattern):
    """
    Collect the kernel registrations in ``content`` that mention ``grad_pattern``.

    From each occurrence of ``kernel_pattern`` the text is scanned forward to
    the shortest span that contains at least one ``{`` and whose ``(``/``)``
    and ``{``/``}`` brackets are correctly nested and closed. That span is one
    registration; it is kept when it contains ``grad_pattern``. The search then
    resumes right after the span.

    Scanning stops for the whole file as soon as a registration cannot be
    completed (a stray or mismatched closing bracket, or end of text before
    the brackets close); registrations found before that point are returned.

    Args:
        content(str): file content
        kernel_pattern(str): kernel pattern
        grad_pattern(str): grad pattern

    Returns:
        (list, int): grad kernel definitions in file and count.
    """
    opener_for = {')': '(', '}': '{'}
    results = []
    start = 0
    lens = len(content)
    while True:
        index = content.find(kernel_pattern, start)
        if index == -1:
            return results, len(results)

        # Track the brackets incrementally instead of re-checking every
        # growing prefix from scratch, which keeps the scan linear.
        stack = []
        seen_brace = False
        for pos in range(index, lens):
            c = content[pos]
            if c == '(' or c == '{':
                stack.append(c)
                if c == '{':
                    seen_brace = True
            elif c in opener_for:
                # Once a closer is stray or mismatched, no longer span
                # starting at ``index`` can ever balance, so give up here.
                if not stack or stack.pop() != opener_for[c]:
                    return results, len(results)
                # The stack can only become empty right after a closer.
                if seen_brace and not stack:
                    break
        else:
            # Reached the end of the text without closing the registration.
            return results, len(results)

        end = pos + 1
        definition = content[index:end]
        if grad_pattern in definition:
            results.append(definition)
        start = end
def remove_grad_kernels(dry_run=False):
    """
    Strip grad kernel registrations from the fluid operator sources.

    Every ``.cc`` and ``.cu`` file under ``../paddle/fluid/operators``
    (relative to this script's directory) is searched for
    ``PD_REGISTER_STRUCT_KERNEL`` registrations containing ``_grad,``.
    Each registration found is deleted from its own file.

    Args:
        dry_run(bool): whether just print. When True, each file path and
            matching registration is printed and no file is modified.

    Returns:
        int: number of kernel(grad) removed (or that would be removed when
        ``dry_run`` is True).
    """
    pd_kernel_pattern = 'PD_REGISTER_STRUCT_KERNEL'
    register_op_pd_kernel_count = 0

    tool_dir = os.path.dirname(os.path.abspath(__file__))
    all_op = []
    for source_glob in (
        '../paddle/fluid/operators/**/*.cc',
        '../paddle/fluid/operators/**/*.cu',
    ):
        all_op += glob.glob(os.path.join(tool_dir, source_glob), recursive=True)

    for op_file in all_op:
        with open(op_file, 'r', encoding='utf-8') as f:
            content = f.read()

        # Matches are scoped to the current file: carrying them over would
        # re-apply (and, in dry-run mode, re-print) registrations that belong
        # to previously processed files.
        pd_kernel, pd_kernel_count = grad_kernel_definition(
            content, pd_kernel_pattern, '_grad,'
        )
        register_op_pd_kernel_count += pd_kernel_count

        if dry_run:
            for to_remove in pd_kernel:
                print(op_file, to_remove)
            continue

        # Nothing to strip: leave the file untouched rather than rewriting it.
        if not pd_kernel:
            continue

        for to_remove in pd_kernel:
            content = content.replace(to_remove, '')
        with open(op_file, 'w', encoding='utf-8') as f:
            f.write(content)

    return register_op_pd_kernel_count