#include "base/kaldi-common.h"
#include "util/common-utils.h"
#include "feat/wave-reader.h"
#include "feat/feature-functions.h"
using namespace kaldi;
// Rearranges an interleaved spectrum buffer in place into "planar" form and
// rescales its even-indexed (real) entries.
//
// On entry, *spectrum is read as pairs (entry 2k, entry 2k+1) for
// k = 0 .. Dim()/2 - 1.  On exit:
//   - (*spectrum)(k)            holds entry 2k divided by Dim() for k == 0,
//                               and divided by Dim()/2 for every other k;
//   - (*spectrum)(Dim()/2 + k)  holds the negated entry 2k+1.
// If Dim() is odd, the last element is left unchanged.
//
// NOTE(review): Kaldi's SplitRadixRealFft appears to pack the real Nyquist
// term at index 1; if so, that term ends up (negated) in the "imaginary"
// slot of bin 0 here -- confirm against the srfft documentation.
void DeInterlaceAndNormalizeSpectrum( Vector<BaseFloat> *spectrum )
{
  const int32 dim = spectrum->Dim();
  const int32 half = dim / 2;

  // Work from a snapshot, because the output overwrites the input buffer.
  Vector<BaseFloat> packed(*spectrum);

  for ( int32 k = 0; k < half; k++ ) {
    BaseFloat real_part = packed(2 * k);
    if ( k == 0 ) {
      real_part /= dim;   // edge bin: scaled by the full length
    } else {
      real_part /= half;  // interior bins: scaled by half the length
    }
    (*spectrum)(k) = real_part;
    (*spectrum)(half + k) = -packed(2 * k + 1);
  }
}
// Converts a "planar" spectrum buffer back to interleaved form, in place.
//
// On entry, the first Dim()/2 entries of *spectrum are taken as one sequence
// and the next Dim()/2 entries as a second sequence.  On exit, entry 2k holds
// element k of the first sequence and entry 2k+1 holds element k of the
// second.  No scaling or sign change is applied.  If Dim() is odd, the last
// element is left unchanged.
void InterlaceSpectrum( Vector<BaseFloat> *spectrum )
{
  const int32 half = spectrum->Dim() / 2;

  // Snapshot the planar data; the interleaved result overwrites it.
  Vector<BaseFloat> planar(*spectrum);

  for ( int32 k = 0; k < half; k++ ) {
    (*spectrum)(2 * k) = planar(k);
    (*spectrum)(2 * k + 1) = planar(half + k);
  }
}
// Maps a frequency in Hz onto the mel scale used by this file:
//   mel = 1125 * ln(1 + hz / 700).
// Hz() is the inverse mapping.
BaseFloat Mel( BaseFloat hz )
{
  const BaseFloat log_argument = 1 + hz / 700;
  return 1125 * std::log(log_argument);
}
// Maps a mel-scale value back to a frequency in Hz:
//   hz = 700 * (exp(mel / 1125) - 1).
// Mel() is the inverse mapping.
BaseFloat Hz( BaseFloat mel )
{
  const BaseFloat growth = std::exp(mel / 1125);
  return 700 * (growth - 1);
}
// Builds a triangular mel filterbank as a matrix with one row per filter
// (mb_opts.num_bins rows) and one column per FFT bin
// (fe_opts.PaddedWindowSize() / 2 columns).
//
// The filter edges are num_bins + 2 points equally spaced on the mel scale
// between mb_opts.low_freq and the upper limit, each rounded to the nearest
// multiple of the FFT bin spacing (samp_freq / PaddedWindowSize()).  Filter i
// rises linearly from 0 at edge i to 1 at edge i+1, then falls to 0 at
// edge i+2.
//
// The upper limit is samp_freq / 2 when mb_opts.high_freq == 0, otherwise
// mb_opts.high_freq itself.
// NOTE(review): Kaldi documents negative high_freq as an offset from Nyquist;
// that convention is not implemented here -- confirm whether it is needed.
//
// A bin that sits exactly on a filter's centre always gets weight 1, even
// when rounding has collapsed the centre onto the right edge (which would
// otherwise evaluate 0/0 and put NaN into the matrix).
Matrix<BaseFloat> GetForwardMelFbank( FrameExtractionOptions fe_opts, MelBanksOptions mb_opts )
{
  const int32 num_fft = fe_opts.PaddedWindowSize();
  const int32 num_fft_bins = num_fft / 2;
  const int32 num_filter = mb_opts.num_bins;

  Matrix<BaseFloat> mel_fbank;
  mel_fbank.Resize(num_filter, num_fft_bins, kSetZero);

  const BaseFloat freq_resolution = fe_opts.samp_freq / num_fft;
  const BaseFloat low_freq_mel = Mel(mb_opts.low_freq);
  const BaseFloat high_freq_mel = (mb_opts.high_freq == 0) ? Mel(fe_opts.samp_freq / 2) : Mel(mb_opts.high_freq);
  const BaseFloat delta_freq_mel = (high_freq_mel - low_freq_mel) / (num_filter + 1);

  // Sliding triple of edge frequencies, snapped to the FFT grid.  It is
  // primed one step "behind" so the first loop iteration shifts it into place.
  BaseFloat left_freq;
  BaseFloat center_freq = round(Hz(0 * delta_freq_mel + low_freq_mel) / freq_resolution) * freq_resolution;
  BaseFloat right_freq = round(Hz(1 * delta_freq_mel + low_freq_mel) / freq_resolution) * freq_resolution;

  for ( int32 filter = 0; filter < num_filter; filter++ ) {
    left_freq = center_freq;
    center_freq = right_freq;
    right_freq = round(Hz(((BaseFloat) filter + 2) * delta_freq_mel + low_freq_mel) / freq_resolution) * freq_resolution;

    for ( int32 fft = 0; fft < num_fft_bins; fft++ ) {
      const BaseFloat freq = fft * freq_resolution;
      BaseFloat weight = 0;  // outside [left_freq, right_freq]
      if ( freq >= left_freq && freq <= right_freq ) {
        if ( freq < center_freq ) {
          // Rising slope; center_freq > freq >= left_freq, so the
          // denominator is strictly positive.
          weight = (freq - left_freq) / (center_freq - left_freq);
        } else if ( freq > center_freq ) {
          // Falling slope; right_freq >= freq > center_freq, so the
          // denominator is strictly positive.
          weight = (right_freq - freq) / (right_freq - center_freq);
        } else {
          // Exactly on the peak.  Handled separately so a degenerate
          // triangle (center_freq == right_freq) cannot produce 0/0.
          weight = 1;
        }
      }
      mel_fbank(filter, fft) = weight;
    }
  }
  return mel_fbank;
}
// Builds a crude "inverse" of the forward mel filterbank: a matrix with one
// row per FFT bin (fe_opts.PaddedWindowSize() / 2 rows) and one column per
// filter (mb_opts.num_bins columns).
//
// For each filter, a single entry is set: the row of the FFT bin nearest the
// filter's mel-spaced centre frequency receives 1 / (sum of that filter's row
// in forward_mel_filterbank).  All other entries are zero.
//
// The mel grid is derived the same way as in GetForwardMelFbank: the upper
// limit is samp_freq / 2 when mb_opts.high_freq == 0, otherwise
// mb_opts.high_freq.
//
// A filter is skipped (its column stays all-zero) when
//   - its centre rounds to a bin outside [0, PaddedWindowSize() / 2), which
//     can happen for the top filter when its centre is within half a bin of
//     the upper edge, or
//   - the forward filter's row sum is not strictly positive (empty or NaN),
//     since 1 / area would then be inf or NaN.
Matrix<BaseFloat> GetBackwardMelFbank( FrameExtractionOptions fe_opts, MelBanksOptions mb_opts, Matrix<BaseFloat> forward_mel_filterbank )
{
  const int32 num_fft = fe_opts.PaddedWindowSize();
  const int32 num_fft_bins = num_fft / 2;
  const int32 num_filter = mb_opts.num_bins;

  Matrix<BaseFloat> mel_fbank;
  mel_fbank.Resize(num_fft_bins, num_filter, kSetZero);

  const BaseFloat freq_resolution = fe_opts.samp_freq / num_fft;
  const BaseFloat low_freq_mel = Mel(mb_opts.low_freq);
  const BaseFloat high_freq_mel = (mb_opts.high_freq == 0) ? Mel(fe_opts.samp_freq / 2) : Mel(mb_opts.high_freq);
  const BaseFloat delta_freq_mel = (high_freq_mel - low_freq_mel) / (num_filter + 1);

  for ( int32 filter = 0; filter < num_filter; filter++ ) {
    // Index of the FFT bin closest to this filter's centre frequency.
    const int32 center_fft = round(Hz(((BaseFloat) filter + 1) * delta_freq_mel + low_freq_mel) / freq_resolution);
    const BaseFloat filter_area = forward_mel_filterbank.Row(filter).Sum();

    // Skip filters that cannot be inverted safely.  The negated comparison
    // also rejects a NaN area.
    const bool center_in_range = (center_fft >= 0 && center_fft < num_fft_bins);
    if ( !center_in_range || !(filter_area > 0) ) {
      continue;
    }
    mel_fbank(center_fft, filter) = 1 / filter_area;
  }
  return mel_fbank;
}
// Returns a square identity matrix whose side is
// fe_opts.PaddedWindowSize() / 2, i.e. a "filterbank" that passes every
// FFT bin through unchanged.
Matrix<BaseFloat> GetIdentityFbank( FrameExtractionOptions fe_opts )
{
  const int32 side = fe_opts.PaddedWindowSize() / 2;

  // Start from all zeros, then fill the main diagonal with ones.
  Matrix<BaseFloat> fbank(side, side, kSetZero);
  for ( int32 d = 0; d < side; ++d ) {
    fbank(d, d) = 1;
  }
  return fbank;
}
int main( int argc, char *argv[] )
{
try {
const char *usage = "...";
ParseOptions po(usage);
FrameExtractionOptions fe_opts;
fe_opts.Register(&po);
MelBanksOptions mb_opts;
mb_opts.Register(&po);
po.Read(argc, argv);
if ( po.NumArgs() != 2 ) {
po.PrintUsage();
exit(1);
}
std::string wav_rspecifier = po.GetArg(1);
std::string wav_wspecifier = po.GetArg(2);
KALDI_LOG<<"Opening files";
SequentialTableReader<WaveHolder> wav_reader(wav_rspecifier);
TableWriter<WaveHolder> wav_writer(wav_wspecifier);
KALDI_LOG<<"Setting up window functions";
// For extracting frames from time-domain signal
FeatureWindowFunction forward_window_function(fe_opts);
// Weights for overlap and add (nominator)
Vector<BaseFloat> backward_window(forward_window_function.window.Dim());
backward_window.CopyFromVec(forward_window_function.window);
// Weights for overlap and add (denominator)
Vector<BaseFloat> backward_window_squared(forward_window_function.window.Dim());
backward_window_squared.CopyFromVec(backward_window);
backward_window_squared.MulElements(backward_window);
KALDI_LOG<<"Setting up DFT and supporting objects";
Vector<BaseFloat> window;
SplitRadixRealFft<BaseFloat> srfft(fe_opts.PaddedWindowSize());
std::vector<BaseFloat> temp_buffer;
KALDI_LOG<<"Setting up filterbanks and supporting objects";
Matrix<BaseFloat> forward_filterbank = GetForwardMelFbank(fe_opts, mb_opts);
Matrix<BaseFloat> backward_filterbank = GetBackwardMelFbank(fe_opts, mb_opts, forward_filterbank);
// Matrix<BaseFloat> forward_filterbank = GetIdentityFbank(fe_opts);
// Matrix<BaseFloat> backward_filterbank = GetIdentityFbank(fe_opts);
Vector<BaseFloat> foo(forward_filterbank.NumCols());
foo.Add(1);
Vector<BaseFloat> bar(forward_filterbank.NumRows());
bar.AddMatVec(1, forward_filterbank, kNoTrans, foo, 1);
foo.SetZero();
foo.AddMatVec(1, backward_filterbank, kNoTrans, bar, 1);
KALDI_LOG<<"Makse sure we get back 1s at the bank center frequencies"<<foo;
for ( ; !wav_reader.Done(); wav_reader.Next() ) {
KALDI_LOG<<"Reading input wav";
std::string utt = wav_reader.Key();
const WaveData &wave_data_in = wav_reader.Value();
SubVector<BaseFloat> wave_in(wave_data_in.Data(), 0);
int32 num_frames = NumFrames(wave_in.Dim(), fe_opts);
KALDI_LOG<<"Setting up output wav objects";
int32 num_samples = num_frames * fe_opts.WindowShift()+fe_opts.WindowSize();
Vector<BaseFloat> wave_out_numerator(num_samples);
Vector<BaseFloat> wave_out_denominator(num_samples);
KALDI_LOG<<"Iterating over frames";
for(int32 r = 0; r < num_frames; r++) {
ExtractWindow(wave_in, r, fe_opts, forward_window_function, &window, NULL);
// Save the original signal for comparison
Vector<BaseFloat> old_window(window.Dim());
old_window.CopyFromVec(window);
srfft.Compute(window.Data(), true, &temp_buffer);
// Separate real and imaginary parts
DeInterlaceAndNormalizeSpectrum(&window);
bool apply_filterbank = true;
if(apply_filterbank) {
// Apply filterbank only to real part
SubVector<BaseFloat> sub_window(window.Data(), fe_opts.PaddedWindowSize()/2);
Vector<BaseFloat> feature_frame(forward_filterbank.NumRows());
feature_frame.AddMatVec(1., forward_filterbank, kNoTrans, sub_window, 1.);
// Reset the window and invert the filterbank
window.SetZero();
sub_window.AddMatVec(1., backward_filterbank, kNoTrans, feature_frame, 1);
}
// Restore window to interlaced form in preparation for inverse DFT
InterlaceSpectrum(&window);
srfft.Compute(window.Data(), false, &temp_buffer);
// Do overlap and add for this piece
SubVector<BaseFloat> local_wave_out_numerator(wave_out_numerator, r*fe_opts.WindowShift(), fe_opts.WindowSize());
window.MulElements(backward_window);
local_wave_out_numerator.AddVec(1, window);
SubVector<BaseFloat> local_wave_out_denominator(wave_out_denominator,r*fe_opts.WindowShift(), fe_opts.WindowSize());
local_wave_out_denominator.AddVec(1, backward_window_squared);
}
KALDI_LOG << "Finalizing overlap and add";
// Prevent divizion by (near) zero
for(int32 r = 0; r < wave_out_denominator.Dim(); r++) {
if (wave_out_denominator(r) < wave_out_denominator.Max()/100) {
wave_out_denominator(r) = wave_out_denominator.Max()/100;
}
}
// Normalize with window weights
wave_out_numerator.DivElements(wave_out_denominator);
KALDI_LOG << "Normalizing signal to prevent clipping and writing output wav";
wave_out_numerator.Scale( (std::numeric_limits<int16>::max() ) / std::max( wave_out_numerator.Max(), std::abs( wave_out_numerator.Min() ) ) );
Matrix<BaseFloat> wave_out_matrix(1, wave_out_numerator.Dim());
wave_out_matrix.CopyRowFromVec(wave_out_numerator, 0);
WaveData wave_data_out(wave_data_in.SampFreq(), wave_out_matrix);
wav_writer.Write(utt, wave_data_out);
}
return 0;
} catch ( const std::exception &e ) {
std::cerr << e.what();
return -1;
}
}
You received this message because you are subscribed to a topic in the Google Groups "kaldi-help" group.
To unsubscribe from this topic, visit https://groups.google.com/d/topic/kaldi-help/aPv8D7rMSVM/unsubscribe.
To unsubscribe from this group and all its topics, send an email to kaldi-help+...@googlegroups.com.