# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration types."""
from typing import Optional
from absl import logging
from tensorflow_model_analysis import constants
from tensorflow_model_analysis.proto import config_pb2
def verify_eval_config(eval_config: config_pb2.EvalConfig,
baseline_required: Optional[bool] = None):
"""Verifies eval config."""
if not eval_config.model_specs:
raise ValueError(
'At least one model_spec is required: eval_config=\n{}'.format(
eval_config))
model_specs_by_name = {}
baseline = None
for spec in eval_config.model_specs:
if spec.label_key and spec.label_keys:
raise ValueError('only one of label_key or label_keys should be used at '
'a time: model_spec=\n{}'.format(spec))
if spec.prediction_key and spec.prediction_keys:
raise ValueError(
'only one of prediction_key or prediction_keys should be used at '
'a time: model_spec=\n{}'.format(spec))
if spec.example_weight_key and spec.example_weight_keys:
raise ValueError(
'only one of example_weight_key or example_weight_keys should be '
'used at a time: model_spec=\n{}'.format(spec))
    if spec.name in model_specs_by_name:
raise ValueError(
'more than one model_spec found for model "{}": {}'.format(
spec.name, [spec, model_specs_by_name[spec.name]]))
model_specs_by_name[spec.name] = spec
if spec.is_baseline:
if baseline is not None:
raise ValueError('only one model_spec may be a baseline, found: '
'{} and {}'.format(spec, baseline))
baseline = spec
if len(model_specs_by_name) > 1 and '' in model_specs_by_name:
raise ValueError('A name is required for all ModelSpecs when multiple '
'models are used: eval_config=\n{}'.format(eval_config))
if baseline_required and not baseline:
raise ValueError(
'A baseline ModelSpec is required: eval_config=\n{}'.format(
eval_config))
# Raise exception if per_slice_thresholds has no slicing_specs.
for metric_spec in eval_config.metrics_specs:
for name, per_slice_thresholds in metric_spec.per_slice_thresholds.items():
for per_slice_threshold in per_slice_thresholds.thresholds:
if not per_slice_threshold.slicing_specs:
raise ValueError(
'slicing_specs must be set on per_slice_thresholds but found '
f'per_slice_threshold=\n{per_slice_threshold}\n'
f'for metric name {name} in metric_spec:\n{metric_spec}'
)
for metric_config in metric_spec.metrics:
for per_slice_threshold in metric_config.per_slice_thresholds:
if not per_slice_threshold.slicing_specs:
raise ValueError(
'slicing_specs must be set on per_slice_thresholds but found '
f'per_slice_threshold=\n{per_slice_threshold}\n'
              f'for metric config:\n{metric_config}'
)
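

# A minimal sketch of the checks enforced above (illustrative only; this
# helper is not called anywhere in this module).
def _verify_eval_config_example():
  """Shows configs that pass and fail verify_eval_config."""
  # A config with a single, unnamed ModelSpec is valid.
  eval_config = config_pb2.EvalConfig()
  eval_config.model_specs.add()
  verify_eval_config(eval_config)

  # An empty config (no model_specs) is rejected.
  try:
    verify_eval_config(config_pb2.EvalConfig())
  except ValueError:
    pass

  # With multiple ModelSpecs every spec needs a name and at most one may be
  # marked as the baseline; baseline_required=True additionally demands one.
  eval_config = config_pb2.EvalConfig()
  eval_config.model_specs.add(name=constants.CANDIDATE_KEY)
  eval_config.model_specs.add(name=constants.BASELINE_KEY, is_baseline=True)
  verify_eval_config(eval_config, baseline_required=True)

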
def update_eval_config_with_defaults(
eval_config: config_pb2.EvalConfig,
maybe_add_baseline: Optional[bool] = None,
maybe_remove_baseline: Optional[bool] = None,
has_baseline: Optional[bool] = False,
rubber_stamp: Optional[bool] = False) -> config_pb2.EvalConfig:
"""Returns a new config with default settings applied.
a) Add or remove a model_spec according to "has_baseline".
b) Fix the model names (model_spec.name) to tfma.CANDIDATE_KEY and
tfma.BASELINE_KEY.
c) Update the metrics_specs with the fixed model name.
Args:
eval_config: Original eval config.
maybe_add_baseline: DEPRECATED. True to add a baseline ModelSpec to the
config as a copy of the candidate ModelSpec that should already be
present. This is only applied if a single ModelSpec already exists in the
config and that spec doesn't have a name associated with it. When applied
the model specs will use the names tfma.CANDIDATE_KEY and
tfma.BASELINE_KEY. Only one of maybe_add_baseline or maybe_remove_baseline
should be used.
maybe_remove_baseline: DEPRECATED. True to remove a baseline ModelSpec from
the config if it already exists. Removal of the baseline also removes any
change thresholds. Only one of maybe_add_baseline or maybe_remove_baseline
should be used.
has_baseline: True to add a baseline ModelSpec to the config as a copy of
the candidate ModelSpec that should already be present. This is only
applied if a single ModelSpec already exists in the config and that spec
doesn't have a name associated with it. When applied the model specs will
use the names tfma.CANDIDATE_KEY and tfma.BASELINE_KEY. False to remove a
baseline ModelSpec from the config if it already exists. Removal of the
baseline also removes any change thresholds. Only one of has_baseline or
maybe_remove_baseline should be used.
rubber_stamp: True if this model is being rubber stamped. When a model is
rubber stamped diff thresholds will be ignored if an associated baseline
      model is not passed.

  Raises:
    RuntimeError: on missing baseline model for non-rubberstamp cases.
  """
if (not has_baseline and has_change_threshold(eval_config) and
not rubber_stamp):
# TODO(b/173657964): Raise an error instead of logging an error.
raise RuntimeError(
'There are change thresholds, but the baseline is missing. '
'This is allowed only when rubber stamping (first run).')
updated_config = config_pb2.EvalConfig()
updated_config.CopyFrom(eval_config)
# if user requests CIs but doesn't set method, use JACKKNIFE
if (eval_config.options.compute_confidence_intervals.value and
eval_config.options.confidence_intervals.method ==
config_pb2.ConfidenceIntervalOptions.UNKNOWN_CONFIDENCE_INTERVAL_METHOD):
updated_config.options.confidence_intervals.method = (
config_pb2.ConfidenceIntervalOptions.JACKKNIFE)
if maybe_add_baseline and maybe_remove_baseline:
raise ValueError('only one of maybe_add_baseline and maybe_remove_baseline '
'should be used')
if maybe_add_baseline or maybe_remove_baseline:
logging.warning(
""""maybe_add_baseline" and "maybe_remove_baseline" are deprecated,
please use "has_baseline" instead.""")
if has_baseline:
raise ValueError(
""""maybe_add_baseline" and "maybe_remove_baseline" are ignored if
"has_baseline" is set.""")
if has_baseline is not None:
if has_baseline:
maybe_add_baseline = True
else:
maybe_remove_baseline = True
  # Add a baseline ModelSpec as a copy of the candidate when one is expected.
if (maybe_add_baseline and len(updated_config.model_specs) == 1 and
not updated_config.model_specs[0].name):
baseline = updated_config.model_specs.add()
baseline.CopyFrom(updated_config.model_specs[0])
baseline.name = constants.BASELINE_KEY
baseline.is_baseline = True
updated_config.model_specs[0].name = constants.CANDIDATE_KEY
logging.info(
'Adding default baseline ModelSpec based on the candidate ModelSpec '
'provided. The candidate model will be called "%s" and the baseline '
'will be called "%s": updated_config=\n%s', constants.CANDIDATE_KEY,
constants.BASELINE_KEY, updated_config)
  # No baseline expected: drop any baseline ModelSpec and change thresholds.
if maybe_remove_baseline:
tmp_model_specs = []
for model_spec in updated_config.model_specs:
if not model_spec.is_baseline:
tmp_model_specs.append(model_spec)
del updated_config.model_specs[:]
updated_config.model_specs.extend(tmp_model_specs)
for metrics_spec in updated_config.metrics_specs:
for metric in metrics_spec.metrics:
if metric.threshold.ByteSize():
metric.threshold.ClearField('change_threshold')
for per_slice_threshold in metric.per_slice_thresholds:
if per_slice_threshold.threshold.ByteSize():
per_slice_threshold.threshold.ClearField('change_threshold')
for cross_slice_threshold in metric.cross_slice_thresholds:
if cross_slice_threshold.threshold.ByteSize():
cross_slice_threshold.threshold.ClearField('change_threshold')
for threshold in metrics_spec.thresholds.values():
if threshold.ByteSize():
threshold.ClearField('change_threshold')
for per_slice_thresholds in metrics_spec.per_slice_thresholds.values():
for per_slice_threshold in per_slice_thresholds.thresholds:
if per_slice_threshold.threshold.ByteSize():
per_slice_threshold.threshold.ClearField('change_threshold')
for cross_slice_thresholds in metrics_spec.cross_slice_thresholds.values(
):
for cross_slice_threshold in cross_slice_thresholds.thresholds:
if cross_slice_threshold.threshold.ByteSize():
cross_slice_threshold.threshold.ClearField('change_threshold')
logging.info(
'Request was made to ignore the baseline ModelSpec and any change '
'thresholds. This is likely because a baseline model was not provided: '
'updated_config=\n%s', updated_config)
if not updated_config.model_specs:
updated_config.model_specs.add()
model_names = []
for spec in updated_config.model_specs:
model_names.append(spec.name)
if len(model_names) == 1 and model_names[0]:
logging.info(
'ModelSpec name "%s" is being ignored and replaced by "" because a '
'single ModelSpec is being used', model_names[0])
updated_config.model_specs[0].name = ''
model_names = ['']
for spec in updated_config.metrics_specs:
if not spec.model_names:
spec.model_names.extend(model_names)
elif len(model_names) == 1:
del spec.model_names[:]
spec.model_names.append('')
return updated_config
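

# A minimal sketch of how has_baseline changes the result of
# update_eval_config_with_defaults (illustrative only; this helper is not
# called anywhere in this module).
def _update_eval_config_with_defaults_example():
  """Shows the baseline handling applied to a single-model config."""
  eval_config = config_pb2.EvalConfig()
  eval_config.model_specs.add()

  # With has_baseline=True the single unnamed candidate spec is copied into a
  # baseline spec, and the two specs are renamed to tfma.CANDIDATE_KEY and
  # tfma.BASELINE_KEY.
  updated = update_eval_config_with_defaults(eval_config, has_baseline=True)
  assert [s.name for s in updated.model_specs] == [
      constants.CANDIDATE_KEY, constants.BASELINE_KEY
  ]
  assert updated.model_specs[1].is_baseline

  # With has_baseline=False (the default) any baseline spec and all change
  # thresholds are instead removed from the returned copy.
  updated = update_eval_config_with_defaults(eval_config)
  assert len(updated.model_specs) == 1

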
def has_change_threshold(eval_config: config_pb2.EvalConfig) -> bool:
"""Checks whether the eval_config has any change thresholds.
Args:
eval_config: the TFMA eval_config.
Returns:
True when there are change thresholds otherwise False.
"""
for metrics_spec in eval_config.metrics_specs:
for metric in metrics_spec.metrics:
if metric.threshold.change_threshold.ByteSize():
return True
for per_slice_threshold in metric.per_slice_thresholds:
if per_slice_threshold.threshold.change_threshold.ByteSize():
return True
for cross_slice_threshold in metric.cross_slice_thresholds:
if cross_slice_threshold.threshold.change_threshold.ByteSize():
return True
for threshold in metrics_spec.thresholds.values():
if threshold.change_threshold.ByteSize():
return True
for per_slice_thresholds in metrics_spec.per_slice_thresholds.values():
for per_slice_threshold in per_slice_thresholds.thresholds:
if per_slice_threshold.threshold.change_threshold.ByteSize():
return True
for cross_slice_thresholds in metrics_spec.cross_slice_thresholds.values():
for cross_slice_threshold in cross_slice_thresholds.thresholds:
if cross_slice_threshold.threshold.change_threshold.ByteSize():
return True
return False
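

# A minimal sketch of what has_change_threshold looks for (illustrative only;
# this helper is not called anywhere in this module).
def _has_change_threshold_example():
  """Shows a config that has_change_threshold reports as True."""
  eval_config = config_pb2.EvalConfig()
  metric = eval_config.metrics_specs.add().metrics.add(class_name='AUC')
  assert not has_change_threshold(eval_config)

  # Setting a change_threshold (here: AUC may not drop by more than 1e-3
  # relative to the baseline model) makes the check return True.
  metric.threshold.change_threshold.direction = (
      config_pb2.MetricDirection.HIGHER_IS_BETTER)
  metric.threshold.change_threshold.absolute.value = -1e-3
  assert has_change_threshold(eval_config)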