| |
| |
| |
| |
| |
| |
|
|
| import math |
| import time |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
|
|
| from .resample import downsample2, upsample2 |
| from .utils import capture_init |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
# Small constant added to the variance in ChannelwiseLayerNorm and
# GlobalLayerNorm so the normalization never divides by zero.
EPS = 1e-8
class Chomp1d(nn.Module):
    """Drop the last ``chomp_size`` time steps of a [M, H, Kpad] tensor.

    In this file it follows a causal depthwise conv that was padded by
    ``chomp_size`` on both sides, so trimming the tail restores the conv's
    input length.
    """

    def __init__(self, chomp_size):
        super(Chomp1d, self).__init__()
        self.chomp_size = chomp_size

    def forward(self, x):
        """
        Args:
            x: [M, H, Kpad]
        Returns:
            [M, H, K] with K = max(Kpad - chomp_size, 0)
        """
        # `x[..., :-0]` would yield an EMPTY tensor, so slice by an explicit
        # (clamped) end index instead; chomp_size == 0 keeps every step.
        end = max(x.shape[-1] - self.chomp_size, 0)
        return x[:, :, :end].contiguous()
|
|
def chose_norm(norm_type, channel_size):
    """Build the normalization layer selected by ``norm_type``.

    The returned layer is applied to tensors of shape (M, C, K): batch size M,
    channel size C and sequence length K.

    Args:
        norm_type: "gLN" selects GlobalLayerNorm, "cLN" selects
            ChannelwiseLayerNorm; every other value falls back to
            nn.BatchNorm1d.
        channel_size: number of channels C.
    """
    if norm_type == "cLN":
        return ChannelwiseLayerNorm(channel_size)
    if norm_type == "gLN":
        return GlobalLayerNorm(channel_size)
    # Anything unrecognized (e.g. "BN") becomes batch norm.
    return nn.BatchNorm1d(channel_size)
|
|
class ChannelwiseLayerNorm(nn.Module):
    """Channel-wise Layer Normalization (cLN).

    Each time step is normalized on its own across the channel axis (dim 1)
    using the biased variance, then scaled by ``gamma`` and shifted by
    ``beta`` (both of shape (1, channel_size, 1)).
    """

    def __init__(self, channel_size):
        super().__init__()
        param_shape = (1, channel_size, 1)
        self.gamma = nn.Parameter(torch.empty(param_shape))
        self.beta = nn.Parameter(torch.empty(param_shape))
        self.reset_parameters()

    def reset_parameters(self):
        """Restore the identity affine transform (gamma = 1, beta = 0)."""
        with torch.no_grad():
            self.gamma.fill_(1)
            self.beta.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K] -- batch size M, channel size N, length K.
        Returns:
            cLN_y: [M, N, K]
        """
        channel_mean = torch.mean(y, dim=1, keepdim=True)
        channel_var = torch.var(y, dim=1, keepdim=True, unbiased=False)
        scale = torch.pow(channel_var + EPS, 0.5)
        return self.gamma * (y - channel_mean) / scale + self.beta
|
|
class DepthwiseSeparableConv(nn.Module):
    """Depthwise conv -> [Chomp1d if causal] -> PReLU -> norm -> 1x1 conv.

    The depthwise stage uses one filter per input channel
    (``groups=in_channels``); the pointwise 1x1 conv then maps
    ``in_channels`` to ``out_channels``. Neither conv has a bias.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super().__init__()
        layers = [
            nn.Conv1d(in_channels, in_channels, kernel_size,
                      stride=stride, padding=padding,
                      dilation=dilation, groups=in_channels,
                      bias=False),
        ]
        if causal:
            # Trim the right-hand padding added by the depthwise conv.
            layers.append(Chomp1d(padding))
        layers.append(nn.PReLU())
        layers.append(chose_norm(norm_type, in_channels))
        layers.append(nn.Conv1d(in_channels, out_channels, 1, bias=False))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Args:
            x: [M, H, K]
        Returns:
            result: [M, B, K]
        """
        return self.net(x)
|
|
class TemporalBlock(nn.Module):
    """Residual block: 1x1 conv -> PReLU -> norm -> depthwise-separable conv.

    The 1x1 conv expands ``in_channels`` to ``out_channels``; the
    depthwise-separable conv maps back to ``in_channels`` so the result can
    be added to the block input.
    """

    def __init__(self, in_channels, out_channels, kernel_size,
                 stride, padding, dilation, norm_type="gLN", causal=False):
        super().__init__()
        expand = nn.Conv1d(in_channels, out_channels, 1, bias=False)
        activation = nn.PReLU()
        normalization = chose_norm(norm_type, out_channels)
        separable = DepthwiseSeparableConv(
            out_channels, in_channels, kernel_size, stride, padding,
            dilation, norm_type, causal)
        self.net = nn.Sequential(expand, activation, normalization, separable)

    def forward(self, x):
        """
        Args:
            x: [M, B, K]
        Returns:
            [M, B, K] -- block output plus the residual input.
        """
        return self.net(x) + x
| |
|
|
class GlobalLayerNorm(nn.Module):
    """Global Layer Normalization (gLN).

    Each example is normalized over both the channel and the time axis,
    then scaled by ``gamma`` and shifted by ``beta`` (both of shape
    (1, channel_size, 1)).
    """

    def __init__(self, channel_size):
        super().__init__()
        self.gamma = nn.Parameter(torch.empty(1, channel_size, 1))
        self.beta = nn.Parameter(torch.empty(1, channel_size, 1))
        self.reset_parameters()

    def reset_parameters(self):
        """Restore the identity affine transform (gamma = 1, beta = 0)."""
        with torch.no_grad():
            self.gamma.fill_(1)
            self.beta.zero_()

    def forward(self, y):
        """
        Args:
            y: [M, N, K] -- batch size M, channel size N, length K.
        Returns:
            gLN_y: [M, N, K]
        """
        mu = y.mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        centered = y - mu
        sigma2 = centered.pow(2).mean(dim=1, keepdim=True).mean(dim=2, keepdim=True)
        return self.gamma * centered / torch.pow(sigma2 + EPS, 0.5) + self.beta
|
|
class TemporalConvNet(nn.Module):
    """Dilated temporal convolutional network producing a mask.

    ``self.network`` is: ChannelwiseLayerNorm(N) -> 1x1 bottleneck conv
    (N -> B) -> R repeats of X TemporalBlocks with dilations
    2**0 .. 2**(X-1) -> 1x1 mask conv (B -> C*N).
    """

    def __init__(self, N=768, B=256, H=512, P=3, X=8, R=4, C=1, norm_type="gLN", causal=1,
                 mask_nonlinear='relu'):
        """
        Args:
            N: Number of filters in autoencoder
            B: Number of channels in bottleneck 1 x 1-conv block
            H: Number of channels in convolutional blocks
            P: Kernel size in convolutional blocks
            X: Number of convolutional blocks in each repeat
            R: Number of repeats
            C: Number of speakers
            norm_type: BN, gLN, cLN (anything other than gLN/cLN gives BN)
            causal: causal or non-causal
            mask_nonlinear: 'relu' or 'softmax'; checked in forward()
        """
        super().__init__()
        self.C = C
        self.mask_nonlinear = mask_nonlinear

        input_norm = ChannelwiseLayerNorm(N)
        bottleneck = nn.Conv1d(N, B, 1, bias=False)

        repeats = []
        for _ in range(R):
            stage = []
            for exponent in range(X):
                dilation = 2 ** exponent
                span = (P - 1) * dilation
                # Causal blocks pad the full span (the tail is chomped later);
                # non-causal blocks split it evenly between both sides.
                pad = span if causal else span // 2
                stage.append(TemporalBlock(B, H, P, stride=1,
                                           padding=pad,
                                           dilation=dilation,
                                           norm_type=norm_type,
                                           causal=causal))
            repeats.append(nn.Sequential(*stage))
        tcn = nn.Sequential(*repeats)

        mask_conv = nn.Conv1d(B, C * N, 1, bias=False)

        self.network = nn.Sequential(input_norm, bottleneck, tcn, mask_conv)

    def forward(self, mixture_w):
        """
        Keep this API same with TasNet
        Args:
            mixture_w: [M, N, K], M is batch size
        Returns:
            est_mask: [M, C, N, K], with dim 1 squeezed away when C == 1
            (i.e. [M, N, K]). 'softmax' normalizes across the C axis.
        """
        M, N, K = mixture_w.size()
        score = self.network(mixture_w).view(M, self.C, N, K)
        if self.mask_nonlinear == 'softmax':
            est_mask = F.softmax(score, dim=1)
        elif self.mask_nonlinear == 'relu':
            est_mask = F.relu(score)
        else:
            raise ValueError("Unsupported mask non-linear function")
        return est_mask.squeeze(1)
|
|
|
|
|
|
def rescale_conv(conv, reference):
    """Divide a conv layer's weight (and bias, if any) in place by
    ``sqrt(weight.std() / reference)``.

    The factor is computed once from the original weight, before any
    parameter is modified.
    """
    factor = (conv.weight.std().detach() / reference) ** 0.5
    for param in (conv.weight, conv.bias):
        if param is not None:
            param.data /= factor
|
|
|
|
def rescale_module(module, reference):
    """Apply ``rescale_conv`` to every Conv1d / ConvTranspose1d found by
    ``module.modules()`` (the module itself and all nested submodules)."""
    conv_types = (nn.Conv1d, nn.ConvTranspose1d)
    convs = (layer for layer in module.modules() if isinstance(layer, conv_types))
    for conv in convs:
        rescale_conv(conv, reference)
|
|
|
|
class Demucs(nn.Module):
    """
    Demucs speech enhancement model (variant with a TemporalConvNet bottleneck).

    The input is optionally divided by its standard deviation, zero-padded to
    a valid length, upsampled by ``resample``, encoded by ``depth`` strided
    conv layers, passed through ``self.separator`` (a ``TemporalConvNet``),
    decoded by mirrored transposed-conv layers with additive skip
    connections, downsampled, trimmed back to the input length and rescaled.

    Args:
        - chin (int): number of input channels.
        - chout (int): number of output channels.
        - hidden (int): number of initial hidden channels.
        - depth (int): number of layers.
        - kernel_size (int): kernel size for each layer.
        - stride (int): stride for each layer.
        - causal (bool): stored as ``self.causal``. NOTE(review): it is not
          passed to the separator, which is always built as
          ``TemporalConvNet(N=chout)`` with that class's own defaults. This
          variant contains no LSTM.
        - resample (int): amount of resampling to apply to the input/output.
          Can be one of 1, 2 or 4.
        - growth (float): number of channels is multiplied by this for every layer.
        - max_hidden (int): maximum number of channels. Can be useful to
          control the size/speed of the model.
        - normalize (bool): if true, normalize the input.
        - glu (bool): if true uses GLU instead of ReLU in 1x1 convolutions.
        - rescale (float): controls custom weight initialization.
          See https://arxiv.org/abs/1911.13254.
        - floor (float): stability flooring when normalizing.

    """
    @capture_init
    def __init__(self,
                 chin=1,
                 chout=1,
                 hidden=48,
                 depth=5,
                 kernel_size=8,
                 stride=4,
                 causal=True,
                 resample=4,
                 growth=2,
                 max_hidden=10_000,
                 normalize=True,
                 glu=True,
                 rescale=0.1,
                 floor=1e-3):

        super().__init__()
        if resample not in [1, 2, 4]:
            raise ValueError("Resample should be 1, 2 or 4.")

        self.chin = chin
        self.chout = chout
        self.hidden = hidden
        self.depth = depth
        self.kernel_size = kernel_size
        self.stride = stride
        self.causal = causal
        self.floor = floor
        self.resample = resample
        self.normalize = normalize

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        # One activation module instance is shared by every encoder and decoder layer.
        activation = nn.GLU(1) if glu else nn.ReLU()
        # GLU halves its input channels, so the preceding 1x1 conv doubles them.
        ch_scale = 2 if glu else 1

        for index in range(depth):
            encode = []
            encode += [
                nn.Conv1d(chin, hidden, kernel_size, stride),
                nn.ReLU(),
                nn.Conv1d(hidden, hidden * ch_scale, 1), activation,
            ]
            self.encoder.append(nn.Sequential(*encode))

            decode = []
            decode += [
                nn.Conv1d(hidden, ch_scale * hidden, 1), activation,
                nn.ConvTranspose1d(hidden, chout, kernel_size, stride),
            ]
            # The outermost decoder layer (index 0) produces the model output
            # and gets no trailing ReLU.
            if index > 0:
                decode.append(nn.ReLU())
            # Prepend so self.decoder runs from the deepest layer outwards.
            self.decoder.insert(0, nn.Sequential(*decode))
            chout = hidden
            chin = hidden
            hidden = min(int(growth * hidden), max_hidden)

        # After the loop `chout` equals the channel count of the deepest
        # encoder output, which is what the separator receives.
        self.separator = TemporalConvNet(N=chout)

        # rescale_module visits every submodule, so the separator's Conv1d
        # layers are rescaled as well.
        if rescale:
            rescale_module(self, reference=rescale)

    def valid_length(self, length):
        """
        Return a valid length, at least as long as ``length``, to use with the
        model so that there are no time steps left over in the convolutions,
        i.e. for all layers, (size of the input - kernel_size) % stride == 0.

        If the mixture has a valid length, the estimated sources
        will have exactly the same length.
        """
        length = math.ceil(length * self.resample)
        for idx in range(self.depth):
            length = math.ceil((length - self.kernel_size) / self.stride) + 1
            length = max(length, 1)
        for idx in range(self.depth):
            length = (length - 1) * self.stride + self.kernel_size
        length = int(math.ceil(length / self.resample))
        return int(length)

    @property
    def total_stride(self):
        """Input samples per bottleneck time step: stride ** depth // resample."""
        return self.stride ** self.depth // self.resample

    def forward(self, mix):
        """Enhance ``mix``.

        Args:
            mix: [B, T] (treated as a single channel) or [B, chin, T].
        Returns:
            Tensor whose last dimension is trimmed to T; its channel count is
            set by the outermost decoder layer (the constructor's ``chout``).
        """
        if mix.dim() == 2:
            mix = mix.unsqueeze(1)

        if self.normalize:
            mono = mix.mean(dim=1, keepdim=True)
            std = mono.std(dim=-1, keepdim=True)
            mix = mix / (self.floor + std)
        else:
            std = 1
        length = mix.shape[-1]
        x = mix
        # Right-pad with zeros up to valid_length so the strided convs leave
        # no samples over.
        x = F.pad(x, (0, self.valid_length(length) - length))
        if self.resample == 2:
            x = upsample2(x)
        elif self.resample == 4:
            x = upsample2(x)
            x = upsample2(x)
        skips = []
        for encode in self.encoder:
            x = encode(x)
            skips.append(x)
        # NOTE(review): the separator's output (a mask, ReLU'd by default)
        # replaces the bottleneck features instead of multiplying them --
        # confirm this is intended.
        x = self.separator(x)

        for decode in self.decoder:
            skip = skips.pop(-1)
            x = x + skip[..., :x.shape[-1]]
            x = decode(x)
        if self.resample == 2:
            x = downsample2(x)
        elif self.resample == 4:
            x = downsample2(x)
            x = downsample2(x)

        # Drop the padding added above and undo the input normalization.
        x = x[..., :length]
        return std * x
|
|
|
|
def fast_conv(conv, x):
    """
    Evaluate a Conv1d on a batch-1 input, using a single matrix multiply when
    either the kernel size is 1 or the input length equals the kernel size
    (a single output step). Anything else falls back to ``conv(x)``.

    The matmul shortcuts are only taken for plain convolutions (no groups,
    padding or dilation, not transposed; stride 1 for the 1x1 case) because
    they ignore those options.

    Args:
        conv: the convolution module.
        x: input tensor [1, chin, length].
    Returns:
        Tensor [1, chout, out_length], matching ``conv(x)``.
    """
    batch, _, length = x.shape
    assert batch == 1
    chout, chin, kernel = conv.weight.shape
    plain = (not conv.transposed and conv.groups == 1
             and conv.padding == (0,) and conv.dilation == (1,))
    if plain and kernel == 1 and conv.stride == (1,):
        weight = conv.weight.view(chout, chin)
        inp = x.reshape(chin, length)
    elif plain and length == kernel:
        # Exactly one output position, whatever the stride.
        weight = conv.weight.view(chout, chin * kernel)
        inp = x.reshape(chin * kernel, 1)
    else:
        return conv(x).view(batch, chout, -1)
    if conv.bias is None:
        out = torch.mm(weight, inp)
    else:
        out = torch.addmm(conv.bias.view(-1, 1), weight, inp)
    return out.view(batch, chout, -1)
|
|
|
|
class DemucsStreamer:
    """
    Streaming implementation for Demucs. It supports being fed with any amount
    of audio at a time. You will get back as much audio as possible at that
    point.

    Args:
        - demucs (Demucs): Demucs model.
        - dry (float): amount of dry (i.e. input) signal to keep. 0 is maximum
            noise removal, 1 just returns the input signal. Small values > 0
            allow to limit distortions.
        - num_frames (int): number of frames to process at once. Higher values
            will increase overall latency but improve the real time factor.
        - resample_lookahead (int): extra lookahead used for the resampling.
        - resample_buffer (int): size of the buffer of previous inputs/outputs
            kept for resampling.

    NOTE: this Demucs variant's bottleneck is ``demucs.separator`` (a
    ``TemporalConvNet``), not an LSTM. No state is carried for it between
    frames, so it only sees the bottleneck steps of the current frame and the
    streamed output can differ from the offline ``demucs(mix)`` result.
    """

    def __init__(self, demucs,
                 dry=0,
                 num_frames=1,
                 resample_lookahead=64,
                 resample_buffer=256):
        device = next(iter(demucs.parameters())).device
        self.demucs = demucs
        # Kept for backward compatibility; the separator has no recurrent
        # state, so this stays None.
        self.lstm_state = None
        self.conv_state = None
        self.dry = dry
        self.resample_lookahead = resample_lookahead
        self.resample_buffer = resample_buffer
        self.frame_length = demucs.valid_length(1) + demucs.total_stride * (num_frames - 1)
        self.total_length = self.frame_length + self.resample_lookahead
        self.stride = demucs.total_stride * num_frames
        self.resample_in = torch.zeros(demucs.chin, resample_buffer, device=device)
        # Holds previous model outputs, which have `chout` channels.
        self.resample_out = torch.zeros(demucs.chout, resample_buffer, device=device)

        self.frames = 0
        self.total_time = 0
        self.variance = 0
        self.pending = torch.zeros(demucs.chin, 0, device=device)

        bias = demucs.decoder[0][2].bias
        weight = demucs.decoder[0][2].weight
        _, _, kernel = weight.shape
        self._bias = bias.view(-1, 1).repeat(1, kernel).view(-1, 1)
        self._weight = weight.permute(1, 2, 0).contiguous()

    def reset_time_per_frame(self):
        """Reset the timing statistics used by ``time_per_frame``."""
        self.total_time = 0
        self.frames = 0

    @property
    def time_per_frame(self):
        """Average wall-clock seconds spent in ``feed`` per processed frame."""
        return self.total_time / self.frames

    def flush(self):
        """
        Flush remaining audio by padding it with zero. Call this
        when you have no more input and want to get back the last chunk of audio.
        """
        pending_length = self.pending.shape[1]
        padding = torch.zeros(self.demucs.chin, self.total_length, device=self.pending.device)
        out = self.feed(padding)
        return out[:, :pending_length]

    def feed(self, wav):
        """
        Apply the model to mix using true real time evaluation.
        Normalization is done online as is the resampling.

        Args:
            wav: [chin, T] tensor of new samples.
        Returns:
            [chout, T'] tensor with every output sample that can be produced so
            far (T' is a multiple of ``self.stride``, possibly 0).
        """
        begin = time.time()
        demucs = self.demucs
        resample_buffer = self.resample_buffer
        stride = self.stride
        resample = demucs.resample

        if wav.dim() != 2:
            raise ValueError("input wav should be two dimensional.")
        chin, _ = wav.shape
        if chin != demucs.chin:
            raise ValueError(f"Expected {demucs.chin} channels, got {chin}")

        self.pending = torch.cat([self.pending, wav], dim=1)
        outs = []
        while self.pending.shape[1] >= self.total_length:
            self.frames += 1
            frame = self.pending[:, :self.total_length]
            dry_signal = frame[:, :stride]
            if demucs.normalize:
                mono = frame.mean(0)
                variance = (mono**2).mean()
                # Running mean of the per-frame mono power.
                self.variance = variance / self.frames + (1 - 1 / self.frames) * self.variance
                frame = frame / (demucs.floor + math.sqrt(self.variance))
            frame = torch.cat([self.resample_in, frame], dim=-1)
            self.resample_in[:] = frame[:, stride - resample_buffer:stride]

            if resample == 4:
                frame = upsample2(upsample2(frame))
            elif resample == 2:
                frame = upsample2(frame)
            # Drop the resampling history, keep one upsampled frame.
            frame = frame[:, resample * resample_buffer:]
            frame = frame[:, :resample * self.frame_length]

            out, extra = self._separate_frame(frame)
            padded_out = torch.cat([self.resample_out, out, extra], 1)
            self.resample_out[:] = out[:, -resample_buffer:]
            if resample == 4:
                out = downsample2(downsample2(padded_out))
            elif resample == 2:
                out = downsample2(padded_out)
            else:
                out = padded_out

            out = out[:, resample_buffer // resample:]
            out = out[:, :stride]

            if demucs.normalize:
                out *= math.sqrt(self.variance)
            out = self.dry * dry_signal + (1 - self.dry) * out
            outs.append(out)
            self.pending = self.pending[:, stride:]

        self.total_time += time.time() - begin
        if outs:
            out = torch.cat(outs, 1)
        else:
            out = torch.zeros(demucs.chout, 0, device=wav.device)
        return out

    def _separate_frame(self, frame):
        """Run one upsampled frame through encoder, separator and decoder.

        Encoder outputs and decoder overlap tails from the previous call are
        reused through ``self.conv_state``.

        Returns:
            (out, extra): ``out`` is the decoder output for this frame;
            ``extra`` is the trailing part that the next frame will still add
            to (overlap-add), used by ``feed`` only as right context for
            downsampling.
        """
        demucs = self.demucs
        skips = []
        next_state = []
        first = self.conv_state is None
        stride = self.stride * demucs.resample
        x = frame[None]
        for idx, encode in enumerate(demucs.encoder):
            stride //= demucs.stride
            length = x.shape[2]
            if idx == demucs.depth - 1:
                # Deepest layer: computed from the new input only.
                x = fast_conv(encode[0], x)
                x = encode[1](x)
                x = fast_conv(encode[2], x)
                x = encode[3](x)
            else:
                if not first:
                    # Reuse cached outputs; only compute the missing steps.
                    prev = self.conv_state.pop(0)
                    prev = prev[..., stride:]
                    tgt = (length - demucs.kernel_size) // demucs.stride + 1
                    missing = tgt - prev.shape[-1]
                    offset = length - demucs.kernel_size - demucs.stride * (missing - 1)
                    x = x[..., offset:]
                x = encode[1](encode[0](x))
                x = fast_conv(encode[2], x)
                x = encode[3](x)
                if not first:
                    x = torch.cat([prev, x], -1)
                next_state.append(x)
            skips.append(x)

        # The model exposes `separator` (a TemporalConvNet), not `lstm`; it is
        # stateless, so it is applied to this frame's bottleneck steps only.
        x = demucs.separator(x)

        extra = None
        for idx, decode in enumerate(demucs.decoder):
            skip = skips.pop(-1)
            # Out-of-place so the separator output is never modified in place.
            x = x + skip[..., :x.shape[-1]]
            x = fast_conv(decode[0], x)
            x = decode[1](x)

            if extra is not None:
                skip = skip[..., x.shape[-1]:]
                extra += skip[..., :extra.shape[-1]]
                extra = decode[2](decode[1](decode[0](extra)))
            x = decode[2](x)
            # Overlap tail (bias removed) to be added to the next frame's head.
            next_state.append(x[..., -demucs.stride:] - decode[2].bias.view(-1, 1))
            if extra is None:
                extra = x[..., -demucs.stride:]
            else:
                extra[..., :demucs.stride] += next_state[-1]
            x = x[..., :-demucs.stride]

            if not first:
                prev = self.conv_state.pop(0)
                x[..., :demucs.stride] += prev
            if idx != demucs.depth - 1:
                x = decode[3](x)
                extra = decode[3](extra)
        self.conv_state = next_state
        return x[0], extra[0]
|
|
|
|
def test():
    """Benchmark DemucsStreamer and compare it with the offline Demucs pass.

    Reads ``--resample``, ``--hidden``, ``--device``, ``-t/--num_threads`` and
    ``-f/--num_frames`` from the command line. It runs 4 seconds of random
    16 kHz input through both paths and prints the total lag, the stride, the
    average time per frame, the relative L2 delta between the two outputs, and
    the real-time factor.
    """
    import argparse
    parser = argparse.ArgumentParser(
        "denoiser.demucs",
        description="Benchmark the streaming Demucs implementation, "
                    "as well as checking the delta with the offline implementation.")
    parser.add_argument("--resample", default=4, type=int)
    parser.add_argument("--hidden", default=48, type=int)
    parser.add_argument("--device", default="cpu")
    parser.add_argument("-t", "--num_threads", type=int)
    parser.add_argument("-f", "--num_frames", type=int, default=1)
    args = parser.parse_args()
    if args.num_threads:
        torch.set_num_threads(args.num_threads)
    sr = 16_000
    sr_ms = sr / 1000
    demucs = Demucs(hidden=args.hidden, resample=args.resample).to(args.device)
    x = torch.randn(1, sr * 4).to(args.device)
    # Offline reference: no autograd graph is needed for the comparison.
    with torch.no_grad():
        out = demucs(x[None])[0]
    streamer = DemucsStreamer(demucs, num_frames=args.num_frames)
    out_rt = []
    # The first chunk fills a whole window (total_length); later chunks
    # are one total_stride each.
    frame_size = streamer.total_length
    with torch.no_grad():
        while x.shape[1] > 0:
            out_rt.append(streamer.feed(x[:, :frame_size]))
            x = x[:, frame_size:]
            frame_size = streamer.demucs.total_stride
        out_rt.append(streamer.flush())
    out_rt = torch.cat(out_rt, 1)
    print(f"total lag: {streamer.total_length / sr_ms:.1f}ms, ", end='')
    print(f"stride: {streamer.stride / sr_ms:.1f}ms, ", end='')
    print(f"time per frame: {1000 * streamer.time_per_frame:.1f}ms, ", end='')
    print(f"delta: {torch.norm(out - out_rt) / torch.norm(out):.2%}, ", end='')
    print(f"RTF: {((1000 * streamer.time_per_frame) / (streamer.stride / sr_ms)):.1f}")
|
|
|
|
# Run the streaming-vs-offline benchmark when executed as a script.
if __name__ == "__main__":
    test()
|
|