import numpy as np
import torch
from javad.main import from_pretrained, MODELINFO, load_checkpoint
from javad.utils import exact_div
from javad.utils import load_mel_filters, log_mel_spectrogram
from types import SimpleNamespace
from typing import List, Tuple, Dict, Union
import warnings
[docs]
class Pipeline:
def __init__(
self,
model_name: str = "balanced",
checkpoint: Union[str, None] = None,
mode: str = "gradual",
threshold: Union[float, None] = None,
device: Union[torch.device, str] = torch.device("cpu"),
) -> None:
"""
Initialize the stream pipeline for voice activity detection.
This class processes audio streams for voice activity detection using various models.
Args:
model_name (str, optional): Name of the model to use. Defaults to "balanced" (there are also "tiny" and "precise" options).
checkpoint (Union[str, None], optional): Path to a custom model checkpoint. If None, uses the default model.
mode (str, optional): Processing mode - "instant" or "gradual". Defaults to "gradual".
'instant' mode immediately returns latest predictions, although it may not be as accurate
as 'gradual' mode which maintains and updates predictions while chunks are moving across buffer.
threshold (Union[float, None], optional): Detection threshold. If None, uses model's default.
Defaults to None.
device (Union[torch.device, str], optional): Device to run computations on.
Defaults to torch.device("cpu").
Attributes:
mode (str): Processing mode.
config (SimpleNamespace): Configuration parameters including:
- model_name: Name of the model
- sample_rate: Audio sample rate
- fps: Frames per second
- window_size: Size of processing window in samples
- window_size_frames: Size of processing window in frames
- model_output_length: Model output length
- n_mels: Number of mel frequency bands
- hop: Hop length
- threshold: Detection threshold
- padding_size: Size of padding added to buffer to prevent inaccuracy in
spectrograms at the start of the buffer
- buffer_size: Size of audio buffer
flags (SimpleNamespace): Processing flags
audio_buffer (torch.Tensor): Buffer for audio processing
model: Neural network model for VAD
mel_filters: Mel-frequency filterbank
mean (float): Running mean for statistics
variance (float): Running variance for statistics
chunk_count (int): Counter for processed chunks
predictions_storage (dict): Storage for predictions
frames_tracker (list): Tracker for processed frames
predicted_intervals (dict): Storage for predicted intervals
detection_carry (int): Carryover detection counter
"""
self.__device = (
device if isinstance(device, torch.device) else torch.device(device)
)
# Initialize model
if checkpoint is not None:
cpt = load_checkpoint(checkpoint, is_asset=False)
model_name = cpt["model_name"]
self.__model = from_pretrained(checkpoint=checkpoint).to(self.__device)
else:
self.__model = from_pretrained(name=model_name).to(self.__device)
self.__model.eval()
self.mode = mode
modelinfo = MODELINFO[model_name]
fps = int(exact_div(modelinfo["sample_rate"], modelinfo["hop_length"]))
# with sample rate = 16000, hop length = 160 -> fps = 100
# means 1 second of audio is 100 spectrogram frames
self.config = SimpleNamespace(
model_name=model_name,
sample_rate=modelinfo["sample_rate"],
fps=fps,
window_size=int(modelinfo["input_length"] * modelinfo["sample_rate"]),
window_size_frames=int(modelinfo["input_length"] * fps),
model_output_length=modelinfo["output_length"],
n_mels=modelinfo["n_mels"],
hop=modelinfo["hop_length"],
threshold=(threshold or modelinfo["threshold"]),
)
self.flags = SimpleNamespace(input_padded=False)
self.config.padding_size = modelinfo["n_fft"] // modelinfo["hop_length"]
self.config.buffer_size = int(
modelinfo["input_length"] * modelinfo["sample_rate"]
+ self.config.padding_size * modelinfo["hop_length"]
)
self.audio_buffer = torch.zeros(
self.config.buffer_size, dtype=torch.float32
).to(device)
# Preload mel filters
self.preload_mel_filters(n_mels=self.config.n_mels)
self.mean = 0.0
self.variance = 0.0
self.chunk_count = -1
self.predictions_storage = {}
self.frames_tracker = []
self.predicted_intervals = {}
self.detection_carry = 0
[docs]
def reset(self):
"""Reset the pipeline to initial state."""
self.audio_buffer.zero_()
self.mean = 0.0
self.variance = 0.0
self.chunk_count = -1
self.predictions_storage = {}
self.frames_tracker = []
self.predicted_intervals = {}
self.detection_carry = 0
self.flags.input_padded = False
return self
@property
def device(self) -> torch.device:
return self.__device
@device.setter
def device(self, d: Union[torch.device, str]):
if isinstance(d, str):
d = torch.device(d)
self.__device = d
self.__model.to(self.device)
def to(self, device: Union[torch.device, str]) -> "Pipeline":
self.device = device
return self
[docs]
def preload_mel_filters(self, n_mels: int) -> torch.Tensor:
"""Load mel filter bank matrices for a given number of mel bins."""
if self.__device == torch.device("mps"):
self.mel_filters = (
load_mel_filters(n_mels=n_mels).to(torch.float32).to(self.__device)
)
else:
self.mel_filters = load_mel_filters(n_mels=n_mels).to(self.__device)
[docs]
def update_stats(self, spectrogram: torch.Tensor):
"""
Update running statistics (mean and standard deviation) of the spectrogram data.
This method uses Welford's online algorithm to compute running statistics
of streaming spectrogram data.
Args:
spectrogram : torch.Tensor
Input spectrogram tensor of shape (frequency_bins, time_frames)
Returns:
tuple
A tuple containing:
- mean (float): Updated running mean of the spectrogram
- std (torch.Tensor): Updated running standard deviation of the spectrogram
normalized by total number of frames and frequency bins
Notes:
The method tracks the total number of frames processed using self.frames_tracker
and updates statistics incrementally using Welford's method for numerical stability.
"""
frames_chunk = self.frames_tracker[-1]
frames_total = sum(self.frames_tracker)
spg = spectrogram[:, -frames_chunk:]
delta = spg.mean() - self.mean
self.mean += delta * frames_chunk / frames_total
# Update the variance using Welford's method
self.variance += torch.sum((spg - self.mean) ** 2)
return self.mean, torch.sqrt(self.variance / (frames_total * spg.shape[0]))
[docs]
def normalize_spectrogram(self, spectrogram: torch.Tensor) -> torch.Tensor:
"""Normalizes the spectrogram using running mean and standard deviation.
Args:
spectrogram (torch.Tensor): Input spectrogram tensor to be normalized.
Returns:
torch.Tensor: Normalized spectrogram tensor with zero mean and unit variance.
If standard deviation is 0, returns original spectrogram unchanged.
"""
mean, std = self.update_stats(spectrogram)
if std == 0.0:
return spectrogram
return (spectrogram - mean) / std
[docs]
def push(
self, chunk: Union[List, np.ndarray, torch.Tensor]
) -> Union[torch.Tensor, Dict]:
"""
Pushes a chunk of audio data through the model for prediction.
This method processes audio chunks for prediction by:
1. Converting input to torch tensor if needed
2. Padding the chunk if it's not divisible by hop length
3. Managing a rolling audio buffer
4. Computing log mel spectrogram
5. Normalizing the spectrogram
6. Running prediction
7. Tracking and aggregating predictions across chunks
Args:
chunk : Union[List, np.ndarray, torch.Tensor]
Audio chunk to process. Can be a list, numpy array or torch tensor.
Length must not exceed model's window_size.
Returns:
Union[torch.Tensor, Dict[int, torch.Tensor]]
If mode is "instant": Returns tensor of predictions for current chunk
If mode is "gradual": Returns dict mapping chunk numbers to mean predictions
across all passes that included that chunk
Raises:
ValueError
If chunk length is larger than model window size
If non-final chunk length is not divisible by hop length
"""
# convert chunk to torch.tensor
if isinstance(chunk, (list, np.ndarray)):
chunk = torch.tensor(chunk, dtype=torch.float32).to(self.device)
# if chunk is not divisible by hop, pad with zeroes
if len(chunk) % self.config.hop != 0:
if self.flags.input_padded:
raise ValueError(
f"All chunks except last one should have size divisible by hop length {self.config.hop}. Current size {len(chunk)}"
)
if not self.flags.input_padded:
self.flags.input_padded = True
chunk = torch.nn.functional.pad(
chunk, (0, self.config.hop - len(chunk) % self.config.hop)
)
if len(chunk) > self.config.window_size:
raise ValueError(
f"Chunk size {len(chunk)} cannot be larger than model input ({self.config.window_size})"
)
self.chunk_count += 1
self.predictions_storage[self.chunk_count] = []
self.frames_tracker.append(len(chunk) // self.config.hop)
# Shift buffer to the left by the chunk size and write new chunk
self.audio_buffer = torch.roll(self.audio_buffer, -len(chunk), dims=0)
self.audio_buffer[-len(chunk) :] = chunk
spectrogram = log_mel_spectrogram(
audio=self.audio_buffer,
n_mels=self.config.n_mels,
mel_filters=self.mel_filters,
device=self.device,
)
# trim padding if needed
if self.config.padding_size > 0:
spectrogram = spectrogram[:, self.config.padding_size :]
# normalize spectrogram with running mean and std
spectrogram = self.normalize_spectrogram(spectrogram)
# if we are through initial steps and not every element of the buffer is filled
# zero spectrogram elements that were produced with zeroes in buffer
frames_filled = sum(self.frames_tracker)
if frames_filled < spectrogram.shape[1]:
spectrogram[:, : spectrogram.shape[1] - frames_filled] = 0.0
# predict
with torch.no_grad():
predictions = self.__model(spectrogram.unsqueeze(0).unsqueeze(0)).flatten()
# update chunk tracker with predictions per chunk
accounted_frames = 0
dispose = []
for chunk_num in reversed(self.predictions_storage):
chunk_frames = self.frames_tracker[chunk_num]
start_idx = max(
self.config.padding_size,
len(predictions) - accounted_frames - chunk_frames,
)
end_idx = len(predictions) - accounted_frames
if end_idx < start_idx:
dispose.append(chunk_num)
continue
chunk_predictions = predictions[start_idx:end_idx]
# if chunk went over left side of the buffer so that
# predictions are less than chunk_frames, pad with NaNs
# then when computing mean over all prediction series,
# ignore NaNs
if len(chunk_predictions) < chunk_frames:
chunk_predictions = torch.nn.functional.pad(
chunk_predictions,
(0, chunk_frames - len(chunk_predictions)),
mode="constant",
value=float("nan"),
)
self.predictions_storage[chunk_num].append(chunk_predictions)
if self.mode == "instant":
del self.predictions_storage[chunk_num]
return chunk_predictions
accounted_frames += chunk_frames
# prepare output
mean_predictions = {}
for chunk_num in reversed(self.predictions_storage):
predictions = torch.stack(self.predictions_storage[chunk_num])
mean = torch.nanmean(predictions, dim=0)
mean_predictions[chunk_num] = mean
# delete chunks from tracker that are no longer needed
for chunk_num in dispose:
del self.predictions_storage[chunk_num]
return mean_predictions
update = push
logits = push
[docs]
def predict(
self, chunk: Union[List, np.ndarray, torch.Tensor]
) -> Union[torch.Tensor, Dict]:
"""
Predicts whether audio chunks contain speech based on model predictions.
Args:
chunk (Union[List, np.ndarray, torch.Tensor]): Input audio chunk to process.
Can be a list, numpy array or PyTorch tensor.
Returns:
Union[torch.Tensor, Dict]:
If mode is "instant": Returns boolean tensor where True indicates speech was detected
(predictions above threshold)
If mode is "gradual": Returns dictionary mapping chunk numbers to boolean tensors
indicating speech detection for each chunk
Raises:
ValueError: If input chunk has invalid format or dimensions
"""
predictions = self.push(chunk)
if self.mode == "instant":
return predictions > self.config.threshold
output = {}
for chunk_num in predictions:
output[chunk_num] = predictions[chunk_num] > self.config.threshold
return output
[docs]
def detect(
self, chunk: Union[List, np.ndarray, torch.Tensor], min_duration: float = 0.0
) -> bool:
"""
Detect speech presence in the provided audio chunk.
This method analyzes an audio chunk to detect speech segments and determines if any speech segment
exceeds the minimum duration threshold.
Args:
chunk (Union[List, np.ndarray, torch.Tensor]): Audio data chunk to analyze.
min_duration (float, optional): Minimum duration in seconds for a speech segment to be considered valid.
Defaults to 0.0. If 0.0, uses the pipeline's default minimum duration.
Returns:
bool: True if speech segments longer than min_duration are detected, False otherwise.
Notes:
- The method maintains a detection_carry state variable to handle speech segments that span multiple chunks
- Speech segments are identified by analyzing state changes in model predictions
- Duration is calculated based on the configured frames per second (fps)
"""
if self.mode == "gradual":
warnings.warn(
'"gradual" mode detected. Switching to "instant" mode for detection.'
)
self.mode = "instant"
predictions = self.predict(chunk)
# Find state changes
changes = torch.diff(
predictions.int(), prepend=torch.tensor([0], device=predictions.device)
)
# Get start positions of True sequences
starts = torch.nonzero(changes == 1).flatten()
# Get end positions of True sequences
ends = torch.nonzero(changes == -1).flatten()
if len(starts) == 0:
self.detection_carry = 0
return False
# Handle case where sequence ends with True
if len(ends) < len(starts):
ends = torch.cat(
[ends, torch.tensor([len(predictions)], device=predictions.device)]
)
# Calculate durations in frames
lengths = ends - starts
# Convert to time
durations = lengths.float() / self.config.fps
# if first duration starts at zero and self.detection_carry > 0
# add detection_carry
if starts[0] == 0 and self.detection_carry > 0:
durations[0] += self.detection_carry
# if last ends at the end of the chunk, store the carry
if ends[-1] == len(predictions):
self.detection_carry = durations[-1]
else:
self.detection_carry = 0
# check if any duration is > min_duration
if torch.any(durations > min_duration):
return True
return False
[docs]
def intervals(
self, chunk: Union[List, np.ndarray, torch.Tensor]
) -> Union[List[Tuple], Dict]:
"""
Process the chunk of data and return intervals based on predictions.
This method processes input data chunks and returns time intervals based on the prediction mode.
For 'instant' mode, it directly converts predictions from latest chunk to intervals.
For 'gradual' mode, it maintains and updates intervals across all chunks.
Args:
chunk (Union[List, numpy.ndarray, torch.Tensor]): The chunk of data to process.
Returns:
Union[List[Tuple], Dict]: The intervals based on predictions.
"""
predictions = self.predict(chunk)
if self.mode == "instant":
return self.predictions_to_intervals(self.chunk_count, predictions)
# else if self.mode == 'gradual'
for chunk_num in predictions:
ivs = self.predictions_to_intervals(chunk_num, predictions[chunk_num])
self.predicted_intervals[chunk_num] = ivs
output = {"last": None, "revised": None, "final": None}
last_chunk_num = max(self.predicted_intervals)
output["last"] = self.predicted_intervals[last_chunk_num]
revised_intervals = [
self.predicted_intervals[k] for k in predictions if k != last_chunk_num
]
revised_intervals_merged = self.merge_intervals(
[t for ival in revised_intervals for t in ival]
)
output["revised"] = revised_intervals_merged
final_intervals = [
self.predicted_intervals[k]
for k in self.predicted_intervals
if k not in predictions
]
final_intervals_merged = self.merge_intervals(
[t for ival in final_intervals for t in ival]
)
output["final"] = final_intervals_merged
return output
[docs]
@staticmethod
def merge_intervals(intervals: List[Tuple]) -> List:
"""
Merges adjacent intervals in a list of tuples.
This function takes a list of intervals (start, end) and merges any overlapping
intervals into a single interval. Two intervals are considered overlapping if
the start of one interval is within 0.01 of the end of another interval.
Args:
intervals (List[Tuple]): List of tuples where each tuple contains
start and end points of an interval.
Returns:
List: A new list containing merged intervals with no overlaps.
"""
if len(intervals) == 0:
return []
intervals.sort(key=lambda x: x[0])
merged = [intervals[0]]
for current in intervals:
previous = merged[-1]
if current[0] <= previous[1]:
previous = (previous[0], max(previous[1], current[1]))
merged[-1] = previous
else:
merged.append(current)
return merged
[docs]
def predictions_to_intervals(
self, chunk_num: int, predictions: torch.Tensor
) -> List[Tuple[float, float]]:
"""
Convert binary predictions tensor into a list of time intervals.
Args:
chunk_num (int): Index of the current chunk being processed
predictions (torch.Tensor): Binary tensor containing predictions (0s and 1s)
indicating presence/absence of target signal
Returns:
List[Tuple[float, float]]: List of time intervals (start_time, end_time) where
target signal is present. Times are in seconds relative to start of recording.
"""
start_frame = sum(self.frames_tracker[:chunk_num])
offset_time = start_frame / self.config.fps
scale = 1.0 / self.config.fps
if predictions.all():
return [(offset_time, offset_time + (len(predictions)) * scale)]
elif not predictions.any():
return []
changes = torch.diff(predictions.int())
change_points = torch.nonzero(changes).flatten()
if predictions[0]:
change_points = torch.cat(
[torch.tensor([0], device=predictions.device), change_points]
)
if predictions[-1]:
change_points = torch.cat(
[
change_points,
torch.tensor([len(predictions)], device=predictions.device),
]
)
pairs = change_points.reshape(-1, 2)
return [
(offset_time + pair[0].item() * scale, offset_time + pair[1].item() * scale)
for pair in pairs
]
def __repr__(self) -> str:
window_size_sec = self.config.window_size_frames / self.config.fps
return (
f"JaVAD [stream] Pipeline(\n"
f" model name: {self.config.model_name!r},\n"
f" mode : {self.mode!r},\n"
f" threshold : {self.config.threshold!r}\n"
f" device : {self.device!r},\n"
f" ══[INFO]═══════════\n"
f" output, timespan: {(window_size_sec/self.config.model_output_length)!r}s,\n"
f" max_chunk_size : {self.config.window_size_frames!r},\n"
f" window size : {window_size_sec!r}s,\n"
f")"
)