FunASR is similar to ESPNet, which applies Task
as the general interface ti achieve the training and inference of models. Each Task
is a class inherited from AbsTask
and its corresponding code can be seen in funasr/tasks/abs_task.py
. The main functions of AbsTask
are shown as follows:
class AbsTask(ABC):
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
pass
@classmethod
def build_preprocess_fn(cls, args, train):
(...)
@classmethod
def build_collate_fn(cls, args: argparse.Namespace):
(...)
@classmethod
def build_model(cls, args):
(...)
@classmethod
def main(cls, args):
(...)
- add_task_arguments:Add parameters required by a specified
Task
- build_preprocess_fn:定义如何处理对样本进行预处理 define how to preprocess samples
- build_collate_fn:define how to combine multiple samples into a
batch
- build_model:define the model
- main:training interface, starting training through
Task.main()
Next, we take the speech recognition as an example to introduce how to define a new Task
. For the corresponding code, please see ASRTask
in funasr/tasks/asr.py
. The procedure of defining a new Task
is actually the procedure of redefining the above functions according to the requirements of the specified Task
.
- add_task_arguments
@classmethod
def add_task_arguments(cls, parser: argparse.ArgumentParser):
group = parser.add_argument_group(description="Task related")
group.add_argument(
"--token_list",
type=str_or_none,
default=None,
help="A text mapping int-id to token",
)
(...)
For speech recognition tasks, specific parameters required include token_list
, etc. According to the specific requirements of different tasks, users can define corresponding parameters in this function.
- build_preprocess_fn
@classmethod
def build_preprocess_fn(cls, args, train):
if args.use_preprocessor:
retval = CommonPreprocessor(
train=train,
token_type=args.token_type,
token_list=args.token_list,
bpemodel=args.bpemodel,
non_linguistic_symbols=args.non_linguistic_symbols,
text_cleaner=args.cleaner,
...
)
else:
retval = None
return retval
This function defines how to preprocess samples. Specifically, the input of speech recognition tasks includes speech and text. For speech, functions such as (optional) adding noise and reverberation to the speech are supported. For text, functions such as (optional) processing text according to bpe and mapping text to tokenid
are supported. Users can choose the preprocessing operation that needs to be performed on the sample. For the detail implementation, please refer to CommonPreprocessor
.
- build_collate_fn
@classmethod
def build_collate_fn(cls, args, train):
return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
This function defines how to combine multiple samples into a batch
. For speech recognition tasks, padding
is employed to obtain equal-length data from different speech and text. Specifically, we set 0.0
as the default padding value for speech and -1
as the default padding value for text. Users can define different batch
operations here. For the detail implementation, please refer to CommonCollateFn
.
- build_model
@classmethod
def build_model(cls, args, train):
with open(args.token_list, encoding="utf-8") as f:
token_list = [line.rstrip() for line in f]
vocab_size = len(token_list)
frontend = frontend_class(**args.frontend_conf)
specaug = specaug_class(**args.specaug_conf)
normalize = normalize_class(**args.normalize_conf)
preencoder = preencoder_class(**args.preencoder_conf)
encoder = encoder_class(input_size=input_size, **args.encoder_conf)
postencoder = postencoder_class(input_size=encoder_output_size, **args.postencoder_conf)
decoder = decoder_class(vocab_size=vocab_size, encoder_output_size=encoder_output_size, **args.decoder_conf)
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **args.ctc_conf)
model = model_class(
vocab_size=vocab_size,
frontend=frontend,
specaug=specaug,
normalize=normalize,
preencoder=preencoder,
encoder=encoder,
postencoder=postencoder,
decoder=decoder,
ctc=ctc,
token_list=token_list,
**args.model_conf,
)
return model
This function defines the detail of the model. For different speech recognition models, the same speech recognition Task
can usually be shared and the remaining thing needed to be done is to define a specific model in this function. For example, a speech recognition model with a standard encoder-decoder structure has been shown above. Specifically, it first defines each module of the model, including encoder, decoder, etc. and then combine these modules together to generate a complete model. In FunASR, the model needs to inherit FunASRModel
and the corresponding code can be seen in funasr/train/abs_espnet_model.py
. The main function needed to be implemented is the forward
function.
Next, we take SANMEncoder
as an example to introduce how to use a custom encoder as a part of the model when defining the specified model and the corresponding code can be seen in funasr/models/encoder/sanm_encoder.py
. For a custom encoder, in addition to inheriting the common encoder class AbsEncoder
, it is also necessary to define the forward
function to achieve the forward computation of the encoder
. After defining the encoder
, it should also be registered in the Task
. The corresponding code example can be seen as below:
encoder_choices = ClassChoices(
"encoder",
classes=dict(
conformer=ConformerEncoder,
transformer=TransformerEncoder,
rnn=RNNEncoder,
sanm=SANMEncoder,
sanm_chunk_opt=SANMEncoderChunkOpt,
data2vec_encoder=Data2VecEncoder,
mfcca_enc=MFCCAEncoder,
),
type_check=AbsEncoder,
default="rnn",
)
In this code, sanm=SANMEncoder
takes the newly defined SANMEncoder
as an optional choice of the encoder
. Once the user specifies the encoder
as sanm
in the configuration file, the SANMEncoder
will be correspondingly employed as the encoder
module of the model.