# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import multiprocessing
import shutil
from collections import OrderedDict
from pathlib import Path
from pprint import pprint
from typing import Dict
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import sox
from scipy.stats import expon
from tqdm import tqdm
from nemo.collections.asr.parts.utils.vad_utils import (
get_nonspeech_segments,
load_speech_overlap_segments_from_rttm,
plot_sample_from_rttm,
)
"""
This script analyzes a multi-speaker speech dataset and generates statistics.
The input directory </path/to/rttm_and_wav_directory> must contain the following files:
- rttm files (*.rttm)
- wav files (*.wav)
Usage:
python <NEMO_ROOT>/scripts/speaker_tasks/multispeaker_data_analysis.py \
</path/to/rttm_and_wav_directory> \
--session_dur 20 \
--silence_mean 0.2 \
--silence_var 100 \
--overlap_mean 0.15 \
--overlap_var 50 \
--num_workers 8 \
--num_samples 10 \
--output_dir <path/to/output_directory>
"""
def process_sample(sess_dict: Dict) -> Dict:
"""
    Process a single synthetic sample and compute its silence and overlap statistics.
Args:
sess_dict (dict): dictionary containing the following keys
rttm_file (str): path to the rttm file
session_dur (float): duration of the session (specified by argument)
precise (bool): whether to measure the precise duration of the session using sox
Returns:
results (dict): dictionary containing the following keys
session_dur (float): duration of the session
silence_len_list (list): list of silence durations of each silence occurrence
silence_dur (float): total silence duration in a session
silence_ratio (float): ratio of silence duration to session duration
overlap_len_list (list): list of overlap durations of each overlap occurrence
overlap_dur (float): total overlap duration
overlap_ratio (float): ratio of overlap duration to speech (non-silence) duration
"""
rttm_file = sess_dict["rttm_file"]
session_dur = sess_dict["session_dur"]
precise = sess_dict["precise"]
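    # When precise mode is requested (or no session duration was given), read the actual
    # duration from the paired wav file with sox; this is accurate but adds per-file cost.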
if precise or session_dur is None:
wav_file = rttm_file.parent / Path(rttm_file.stem + ".wav")
session_dur = sox.file_info.duration(str(wav_file))
speech_seg, overlap_seg = load_speech_overlap_segments_from_rttm(rttm_file)
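    # Silence is everything outside the merged speech segments; the overlap ratio is
    # measured against speech (non-silence) time rather than the full session duration.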
    speech_dur = sum([seg[1] - seg[0] for seg in speech_seg])
    silence_seg = get_nonspeech_segments(speech_seg, session_dur)
    silence_len_list = [seg[1] - seg[0] for seg in silence_seg]
    silence_dur = max(0, session_dur - speech_dur)
    silence_ratio = silence_dur / session_dur
    overlap_len_list = [seg[1] - seg[0] for seg in overlap_seg]
    overlap_dur = sum(overlap_len_list) if len(overlap_len_list) else 0
    overlap_ratio = overlap_dur / speech_dur
results = {
"session_dur": session_dur,
"silence_len_list": silence_len_list,
"silence_dur": silence_dur,
"silence_ratio": silence_ratio,
"overlap_len_list": overlap_len_list,
"overlap_dur": overlap_dur,
"overlap_ratio": overlap_ratio,
}
return results
def run_multispeaker_data_analysis(
input_dir,
session_dur=None,
silence_mean=None,
silence_var=None,
overlap_mean=None,
overlap_var=None,
precise=False,
save_path=None,
num_workers=1,
) -> Dict:
    """
    Analyze the multispeaker data and plot the distribution of silence and overlap durations.
    Args:
        input_dir (str): path to the directory containing the rttm files
        session_dur (float): duration of the session (specified by argument)
        silence_mean (float): mean of the silence duration distribution
        silence_var (float): variance of the silence duration distribution
        overlap_mean (float): mean of the overlap duration distribution
        overlap_var (float): variance of the overlap duration distribution
        precise (bool): whether to measure the precise duration of the session using sox
        save_path (str): path to save the plots
        num_workers (int): number of worker processes used to analyze the rttm files
    Returns:
        stats (dict): dictionary containing the statistics of the analyzed data
    """
    rttm_list = list(Path(input_dir).glob("*.rttm"))
    print(f"Found {len(rttm_list)} files to be processed")
    if len(rttm_list) == 0:
        raise ValueError(f"No rttm files found in {input_dir}")
silence_duration = 0.0
total_duration = 0.0
overlap_duration = 0.0
silence_ratio_all = []
overlap_ratio_all = []
silence_length_all = []
overlap_length_all = []
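    # Build one small work item per rttm file so the per-file analysis can be distributed
    # across worker processes when num_workers > 1.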
    queue = []
    for rttm_file in tqdm(rttm_list):
        queue.append({"rttm_file": rttm_file, "session_dur": session_dur, "precise": precise})
    if num_workers <= 1:
        results = [process_sample(sess_dict) for sess_dict in tqdm(queue)]
    else:
        with multiprocessing.Pool(processes=num_workers) as p:
            results = list(tqdm(p.imap(process_sample, queue), total=len(queue), desc="Processing", leave=True))
for item in results:
total_duration += item["session_dur"]
silence_duration += item["silence_dur"]
overlap_duration += item["overlap_dur"]
silence_length_all += item["silence_len_list"]
overlap_length_all += item["overlap_len_list"]
silence_ratio_all.append(item["silence_ratio"])
overlap_ratio_all.append(item["overlap_ratio"])
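    # Dataset-level means are computed from the pooled durations, while the variances are
    # taken over the per-session ratios collected above.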
actual_silence_mean = silence_duration / total_duration
actual_silence_var = np.var(silence_ratio_all)
actual_overlap_mean = overlap_duration / (total_duration - silence_duration)
actual_overlap_var = np.var(overlap_ratio_all)
stats = OrderedDict()
stats["total duration (hours)"] = f"{total_duration / 3600:.2f}"
stats["number of sessions"] = len(rttm_list)
stats["average session duration (seconds)"] = f"{total_duration / len(rttm_list):.2f}"
stats["actual silence ratio mean/var"] = f"{actual_silence_mean:.4f}/{actual_silence_var:.4f}"
stats["actual overlap ratio mean/var"] = f"{actual_overlap_mean:.4f}/{actual_overlap_var:.4f}"
stats["expected silence ratio mean/var"] = f"{silence_mean}/{silence_var}"
stats["expected overlap ratio mean/var"] = f"{overlap_mean}/{overlap_var}"
stats["save_path"] = save_path
print("-----------------------------------------------")
print(" Results ")
print("-----------------------------------------------")
for k, v in stats.items():
print(k, ": ", v)
print("-----------------------------------------------")
fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(14, 14))
fig.suptitle(
f"Average session={total_duration/len(rttm_list):.2f} seconds, num sessions={len(rttm_list)}, total={total_duration/3600:.2f} hours"
)
sns.histplot(silence_ratio_all, ax=ax1)
ax1.set_xlabel("Silence ratio in a session")
ax1.set_title(
f"Target silence mean={silence_mean}, var={silence_var}. \nActual silence ratio={actual_silence_mean:.4f}, var={actual_silence_var:.4f}"
)
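    # Fit an exponential distribution to the per-silence lengths; with floc=0 the fitted
    # scale parameter equals the sample mean of the silence lengths.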
_, scale = expon.fit(silence_length_all, floc=0)
sns.histplot(silence_length_all, ax=ax2)
ax2.set_xlabel("Per-silence length in seconds")
ax2.set_title(f"Per-silence length histogram, \nfitted exponential distribution with mean={scale:.4f}")
sns.histplot(overlap_ratio_all, ax=ax3)
ax3.set_title(
f"Target overlap mean={overlap_mean}, var={overlap_var}. \nActual ratio={actual_overlap_mean:.4f}, var={actual_overlap_var:.4f}"
)
ax3.set_xlabel("Overlap ratio in a session")
_, scale2 = expon.fit(overlap_length_all, floc=0)
sns.histplot(overlap_length_all, ax=ax4)
ax4.set_title(f"Per overlap length histogram, \nfitted exponential distribution with mean={scale2:.4f}")
ax4.set_xlabel("Duration in seconds")
if save_path:
fig.savefig(save_path)
print(f"Figure saved at: {save_path}")
return stats
def visualize_multispeaker_data(input_dir: str, output_dir: str, num_samples: int = 10) -> None:
"""
Visualize a set of randomly sampled data in the input directory
Args:
input_dir (str): Path to the input directory
output_dir (str): Path to the output directory
num_samples (int): Number of samples to visualize
"""
rttm_list = list(Path(input_dir).glob("*.rttm"))
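    # Pick up to num_samples rttm files at random (without replacement) for plotting.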
    idx_list = np.random.permutation(len(rttm_list))[:num_samples]
    print(f"Visualizing {len(idx_list)} random samples")
for idx in idx_list:
rttm_file = rttm_list[idx]
audio_file = rttm_file.parent / Path(rttm_file.stem + ".wav")
output_file = Path(output_dir) / Path(rttm_file.stem + ".png")
plot_sample_from_rttm(audio_file=audio_file, rttm_file=rttm_file, save_path=str(output_file), show=False)
print(f"Sample plots saved at: {output_dir}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("input_dir", default="", help="Input directory")
parser.add_argument("-sd", "--session_dur", default=None, type=float, help="Duration per session in seconds")
parser.add_argument("-sm", "--silence_mean", default=None, type=float, help="Expected silence ratio mean")
parser.add_argument("-sv", "--silence_var", default=None, type=float, help="Expected silence ratio variance")
parser.add_argument("-om", "--overlap_mean", default=None, type=float, help="Expected overlap ratio mean")
parser.add_argument("-ov", "--overlap_var", default=None, type=float, help="Expected overlap ratio variance")
parser.add_argument("-w", "--num_workers", default=1, type=int, help="Number of CPU workers to use")
parser.add_argument("-s", "--num_samples", default=10, type=int, help="Number of random samples to plot")
parser.add_argument("-o", "--output_dir", default="analysis/", type=str, help="Directory for saving output figure")
parser.add_argument(
"--precise", action="store_true", help="Set to get precise duration, with significant time cost"
)
args = parser.parse_args()
print("Running with params:")
pprint(vars(args))
output_dir = Path(args.output_dir)
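    # Recreate the output directory from scratch so figures from a previous run are not
    # mixed with the new analysis results.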
if output_dir.exists():
print(f"Removing existing output directory: {args.output_dir}")
shutil.rmtree(str(output_dir))
output_dir.mkdir(parents=True)
run_multispeaker_data_analysis(
input_dir=args.input_dir,
session_dur=args.session_dur,
silence_mean=args.silence_mean,
silence_var=args.silence_var,
overlap_mean=args.overlap_mean,
overlap_var=args.overlap_var,
precise=args.precise,
save_path=str(Path(args.output_dir, "statistics.png")),
num_workers=args.num_workers,
)
visualize_multispeaker_data(input_dir=args.input_dir, output_dir=args.output_dir, num_samples=args.num_samples)
print("The multispeaker data analysis has been completed.")
print(f"Please check the output directory: \n{args.output_dir}")