Source code for javad.main

import torch
import torch.nn as nn
import numpy as np
from typing import Optional, Union
from torch import Tensor
from javad.utils import load_checkpoint


MODELINFO = {
    "tiny": {
        "args": {
            "n_state": 256,
            "output_dim": 64,
            "seq": {
                "conv1": ["crb", 1, 128, (3, 7), 1, (1, 3)],
                "conv2": None,
                "conv3": ["dc", 128, 256, 3, 1, 1, (32, 32)],
            },
        },
        "threshold": 0.5,
        "input_length": 0.64,
        "output_length": 64,
        # audio params
        "sample_rate": 16000,
        "n_mels": 64,
        "hop_length": 160,
        "n_fft": 400,
    },
    "balanced": {
        "args": {},
        "threshold": 0.5,
        "input_length": 1.92,
        "output_length": 192,
        # audio params
        "sample_rate": 16000,
        "n_mels": 80,
        "hop_length": 160,
        "n_fft": 400,
    },
    "precise": {
        "args": {
            "n_layer": 2,
            "n_head": 4,
            "output_dim": 384,
            "seq": {
                "conv1": ["crb", 1, 128, (3, 7), (1, 2), (1, 3)],
                "conv2": ["dc", 128, 256, (3, 7), (1, 3), (1, 3), (80, 64)],
                "conv3": ["dc", 256, 512, 3, 1, 1, (40, 32)],
            },
        },
        "threshold": 0.5768,
        "input_length": 3.84,
        "output_length": 384,
        # audio params
        "sample_rate": 16000,
        "n_mels": 80,
        "hop_length": 160,
        "n_fft": 400,
    },
}


class DirectionalConv2d(nn.Module):
    def __init__(
        self,
        shape: tuple,
        in_channels: int,
        out_channels: int,
        kernel: Union[int, tuple],
        stride: Union[int, tuple] = 1,
        padding: Union[int, tuple] = 0,
        groups: int = 1,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.activation = nn.ReLU()
        self.alignment = DirectionalAlignment(out_channels, shape)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv(x)
        x = self.bn(x)
        x = self.activation(x)
        y = self.alignment(x.clone())
        x = x + y
        x = self.bn2(x)
        return x


class DirectionalAlignment(nn.Module):
    def __init__(self, planes: int, shape: tuple) -> None:
        super().__init__()
        self.weights = nn.Parameter(torch.Tensor(planes, shape[0], shape[0]))
        nn.init.kaiming_normal_(self.weights)
        self.bn = nn.BatchNorm2d(planes)
        self.dim = torch.sqrt(torch.tensor(shape[-1]))
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x @ x.transpose(-1, -2)
        x = x / self.dim
        x = x * self.weights
        x = torch.sum(x, dim=-1, keepdim=True)
        return x

    def extra_repr(self) -> str:
        return f"DirectionalAlignment: planes={self.weights.shape[0]}, shape={self.weights.shape[1:]}"


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding"""
    assert channels % 2 == 0
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)


class MultiHeadAttention(nn.Module):
    def __init__(self, n_state: int, n_head: int):
        super().__init__()
        self.n_head = n_head
        self.query = nn.Linear(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        q = self.query(x)

        if kv_cache is None or xa is None or self.key not in kv_cache:
            # hooks, if installed (i.e. kv_cache is not None), will prepend the cached kv tensors;
            # otherwise, perform key/value projections for self- or cross-attention as usual.
            k = self.key(x if xa is None else xa)
            v = self.value(x if xa is None else xa)
        else:
            # for cross-attention, calculate keys and values once and reuse in subsequent calls.
            k = kv_cache[self.key]
            v = kv_cache[self.value]

        wv, qk = self.qkv_attention(q, k, v, mask)
        return self.out(wv), qk

    def qkv_attention(
        self,
        q: Tensor,
        k: Tensor,
        v: Tensor,
        mask: Optional[Tensor] = None,
    ):
        n_batch, n_ctx, n_state = q.shape
        scale = (n_state // self.n_head) ** -0.25
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3) * scale
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 3, 1) * scale
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        qk = q @ k
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        w = torch.nn.functional.softmax(qk, dim=-1).to(q.dtype)
        return (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2), qk.detach()


class ResidualAttentionBlock(nn.Module):
    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()

        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = nn.LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head) if cross_attention else None
        )
        self.cross_attn_ln = nn.LayerNorm(n_state) if cross_attention else None

        n_mlp = n_state * 4
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = nn.LayerNorm(n_state)

    def forward(
        self,
        x: Tensor,
        xa: Optional[Tensor] = None,
        mask: Optional[Tensor] = None,
        kv_cache: Optional[dict] = None,
    ):
        x = x + self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)[0]
        if self.cross_attn:
            x = x + self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)[0]
        x = x + self.mlp(self.mlp_ln(x))
        return x


class JaVAD(nn.Module):
    def __init__(
        self,
        seq: dict = {
            "conv1": ["crb", 1, 128, (3, 7), (1, 3), (1, 3)],
            "conv2": None,
            "conv3": ["dc", 128, 256, 3, 1, 1, (40, 32)],
        },
        n_state: int = 320,
        n_head: int = 4,
        n_layer: int = 0,
        output_dim: int = 192,
        dropout: float = 0.4,
    ) -> None:
        super().__init__()
        channels_out = 1
        for k, v in seq.items():
            if v is None:
                setattr(self, k, None)
                continue
            elif v[0] == "crb":
                setattr(
                    self,
                    k,
                    nn.Sequential(
                        nn.Conv2d(
                            in_channels=v[1],
                            out_channels=v[2],
                            kernel_size=v[3],
                            stride=v[4],
                            padding=v[5],
                        ),
                        nn.ReLU(),
                        nn.BatchNorm2d(v[2]),
                    ),
                )
            elif v[0] == "dc":
                setattr(
                    self,
                    k,
                    DirectionalConv2d(
                        in_channels=v[1],
                        out_channels=v[2],
                        kernel=v[3],
                        stride=v[4],
                        padding=v[5],
                        shape=v[6],
                    ),
                )
            channels_out = v[2] if v is not None else channels_out
        self.pool = nn.AvgPool2d(2, 2)
        self.register_buffer("positional_embedding", sinusoids(channels_out, n_state))
        self.blocks = nn.ModuleList(
            [ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)]
        )
        self.ln_post = nn.LayerNorm(n_state)
        self.collapsing_weights = nn.Parameter(
            torch.nn.init.xavier_uniform_(torch.rand((n_state, 1)))
        )
        self.classification_head = nn.Sequential(
            nn.Linear(channels_out, n_state),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(n_state, output_dim),
        )

    def forward(self, x: torch.Tensor):
        # (bs, 1, 80, 192) or (bs, 1, 64, 64)
        if self.conv1:
            x = self.conv1(x)
        if self.conv2:
            x = self.conv2(x)
        x = self.pool(x)
        if self.conv3:
            x = self.conv3(x)
        x = self.pool(x)
        x = x.flatten(2)
        assert (
            x.shape[1:] == self.positional_embedding.shape
        ), f"incorrect audio shape: {x.shape}-vs-{self.positional_embedding.shape}"
        x = (x + self.positional_embedding).to(x.dtype)
        for block in self.blocks:
            x = block(x)
        x = self.ln_post(x)
        x = x @ self.collapsing_weights
        x = x.flatten(1)
        x = self.classification_head(x)
        return x


[docs] def initialize(name: str = "balanced") -> torch.nn.Module: """ Initializes a model with the given name. Args: name (str): The name of the model to initialize. Defaults to "balanced". Available options are "tiny", "balanced", and "precise". Returns: torch.nn.Module: The initialized model. """ if name not in MODELINFO: raise Exception(f"Unknown model name: {name}") mdl = JaVAD(**MODELINFO[name]["args"]) mdl.eval() return mdl
[docs] def from_pretrained( name: str = "balanced", checkpoint: Union[str, None] = None ) -> torch.nn.Module: """ Initializes and loads a pre-trained model. Args: name (str): The name of the model to initialize and load. Defaults to "balanced". Available options are "tiny" and "precise" checkpoint (str, optional): The path to a checkpoint file to load. Defaults to None. If not None, the model will be loaded from the checkpoint file. Returns: torch.nn.Module: The initialized and loaded model. """ if checkpoint is not None: cpt = load_checkpoint(checkpoint, is_asset=False) else: cpt = load_checkpoint(name) model_name = cpt["model_name"] state_dict = cpt["state_dict"] mdl = initialize(name=model_name) mdl.load_state_dict(state_dict=state_dict) return mdl