In [1]:
from pathlib import Path

import torch
import numpy as np
import scipy
import faiss
from tqdm.auto import tqdm

from data import setup_loader
from test_data import setup_loader as setup_test_loader
from utils import Attrs, default_opts, write_results, get_features_labels, group_features_by_label, balance_frames

In [2]:
def knn_single(data, reference_labels, k, faiss_index=None, reference=None):
    """Score one sample (a set of frames) against a labelled reference set with several k-NN votes.

    Exactly one of `faiss_index` and `reference` must be supplied; this is
    enforced with assertions.

    Parameters
    ----------
    data : 2-D array
        One row per frame of the sample being classified.
    reference_labels : 1-D integer ndarray
        Label of every reference vector. Labels must be >= 1 (bin 0 of every
        vote is asserted to be empty); the returned vectors have one entry per
        distinct label, entry ``i`` standing for label ``i + 1``.
    k : int
        Number of neighbours used by the k-NN variants.
    faiss_index : optional
        Object exposing ``search(data, k) -> (distances, indices)``. The
        distances are square-rooted here, i.e. they are assumed to be squared
        L2 distances as returned by ``faiss.IndexFlatL2``.
    reference : optional 2-D array
        Reference vectors, used for an exhaustive distance computation when no
        faiss index is given.

    Returns
    -------
    dict
        Maps variant name to a probability vector that sums to 1:

        - ``"knn_global"``: k smallest distances over all (frame, reference)
          pairs. Only available with `reference`; uniform for the faiss path.
        - ``"knn_frames_1nn"``: 1-NN per frame, then vote among the k frames
          with the smallest 1-NN distance.
        - ``"wtopk_frames_knn"``: k-NN per frame, votes weighted k, k-1, ..., 1
          by neighbour rank.
        - ``"nowtopk_frames_knn"``: k-NN per frame, unweighted votes.
        - ``"top1_frames_1nn"``: 1-NN per frame, unweighted votes.
    """
    nr_classes = len(np.unique(reference_labels))

    # Nearest neighbours per frame.
    if faiss_index is not None:
        assert reference is None
        squared_dists, nn_idxs_per_frame = faiss_index.search(data, k)
        first_nn_dist_per_frame = np.sqrt(squared_dists[:, 0])
        nn_idxs_flat = None
    else:
        assert reference is not None
        # Import the submodule explicitly: a bare `import scipy` does not
        # guarantee that `scipy.spatial` is loaded on every SciPy version.
        from scipy.spatial.distance import cdist

        euclidean = cdist(data, reference, metric="euclidean")
        nn_idxs_flat = np.argsort(euclidean, axis=None)
        nn_idxs_per_frame = np.argsort(euclidean, axis=1)
        first_nn_dist_per_frame = euclidean[np.arange(len(data)), nn_idxs_per_frame[:, 0]]

    # Global k-NN, independent of individual frames.
    if nn_idxs_flat is not None:
        # A flat index into the (frames x references) matrix modulo the number
        # of references yields the reference (column) index.
        knn_idxs_global = nn_idxs_flat[:k] % len(reference_labels)
        knn_labels_global = reference_labels[knn_idxs_global]
        counts = np.bincount(knn_labels_global, minlength=nr_classes + 1)
        assert counts[0] == 0
        # Normalise by the neighbours actually found so the result still sums
        # to 1 when fewer than k (frame, reference) pairs exist.
        probs_knn_global = counts[1:] / len(knn_labels_global)
    else:
        # The faiss path has no global distance matrix: fall back to uniform.
        probs_knn_global = np.ones(nr_classes) / nr_classes

    # Frame-wise 1-NN, then k-NN over the frames ranked by their 1-NN distance.
    nn_label_per_frame = reference_labels[nn_idxs_per_frame[:, 0]]
    closest_frames = np.argsort(first_nn_dist_per_frame, axis=None)[:k]
    counts = np.bincount(nn_label_per_frame[closest_frames], minlength=nr_classes + 1)
    assert counts[0] == 0
    probs_knn_frames_1nn = counts[1:] / min(k, len(data))

    # Frame-wise 1-NN, then the relative frequency of every label.
    counts = np.bincount(nn_label_per_frame, minlength=nr_classes + 1)
    assert counts[0] == 0
    probs_top1_frames_1nn = counts[1:] / len(data)

    # Frame-wise k-NN, then rank-weighted and unweighted label occurrences.
    knn_labels_per_frame = reference_labels[nn_idxs_per_frame[:, :k]]
    n_frames, n_neighbours = knn_labels_per_frame.shape
    # Rank weights k, k-1, ...; truncated when fewer than k neighbours exist.
    rank_weights = np.arange(k, 0, -1)[:n_neighbours]
    flat_labels = knn_labels_per_frame.ravel()
    score_w = np.bincount(flat_labels, weights=np.tile(rank_weights, n_frames), minlength=nr_classes + 1)
    score_now = np.bincount(flat_labels, minlength=nr_classes + 1)
    assert score_w[0] == 0
    assert score_now[0] == 0
    probs_wtopk_frames_knn = score_w[1:] / score_w.sum()
    probs_nowtopk_frames_knn = score_now[1:] / score_now.sum()

    probs = {
        "knn_global": probs_knn_global,
        "knn_frames_1nn": probs_knn_frames_1nn,
        "wtopk_frames_knn": probs_wtopk_frames_knn,
        "nowtopk_frames_knn": probs_nowtopk_frames_knn,
        "top1_frames_1nn": probs_top1_frames_1nn,
    }
    assert np.allclose([ps.sum() for ps in probs.values()], 1)
    return probs

In [3]:
def knn_all(x_train, y_train, x_val, modes=("faiss", "scipy")):
    """Run `knn_single` for every sample in `x_val` against the frames of `x_train`.

    Parameters
    ----------
    x_train : sequence of 2-D arrays
        Frames of each training sample; concatenated into the reference set.
    y_train : sequence
        One label per training sample; repeated for each of its frames.
    x_val : iterable of 2-D arrays
        Frames of each sample to classify.
    modes : iterable of str
        Back-ends to run: ``"faiss"`` (index search) and/or ``"scipy"``
        (exhaustive distance matrix). At least one must be present.

    Returns
    -------
    dict
        Maps back-end name to an `Attrs` tracker with two dicts keyed by
        voting variant: ``probs`` (one stacked row of class probabilities per
        sample) and ``preds`` (1-based arg-max label per sample).

    Notes
    -----
    Reads the module-level `opts` for ``knn_k``, ``max_seconds_per_sample``,
    ``frames_in_window``, ``frames_overlap``, ``shuffle`` and ``rng``.
    """
    # Build the reference set from the *argument*, not from a notebook global.
    reference = np.concatenate(x_train, axis=0).astype(np.float32)
    reference_labels = np.concatenate(
        [np.full((len(frames),), label) for frames, label in zip(x_train, y_train)], axis=0
    )

    trackers = {}
    if "faiss" in modes:
        trackers["faiss"] = Attrs()
    if "scipy" in modes:
        trackers["scipy"] = Attrs()
    assert len(trackers)

    for obj in trackers.values():
        obj.probs = {}
        obj.preds = {}

    # Only build the index when the faiss back-end is requested; its
    # dimensionality is taken from the data so it cannot drift out of sync.
    index = None
    if "faiss" in trackers:
        index = faiss.IndexFlatL2(reference.shape[1])
        index.add(reference)

    truncate = opts.max_seconds_per_sample > 0
    if truncate:
        # 16000 is presumably the audio sample rate in Hz - TODO confirm.
        max_overlapped_frames = int((opts.max_seconds_per_sample * 16000) / (opts.frames_in_window - opts.frames_overlap))
        print(f"{opts.max_seconds_per_sample=}, {max_overlapped_frames=}")

    for data in tqdm(x_val, unit="batch", unit_scale=True):
        if truncate and len(data) > max_overlapped_frames:
            if opts.shuffle:
                # Random subset of frames, drawn without replacement.
                data = opts.rng.choice(data, size=max_overlapped_frames, replace=False, p=None, axis=0, shuffle=True)
            else:
                data = data[:max_overlapped_frames]
        for name, tracker in trackers.items():
            if name == "faiss":
                probs = knn_single(data, reference_labels, k=opts.knn_k, faiss_index=index, reference=None)
            else:
                probs = knn_single(data, reference_labels, k=opts.knn_k, faiss_index=None, reference=reference)
            for variant, ps in probs.items():
                tracker.probs.setdefault(variant, []).append(ps)

    # Stack per-sample vectors. Iterating the tracker's own keys (instead of
    # the last loop's `probs`) keeps this safe when `x_val` is empty.
    for obj in trackers.values():
        for variant in list(obj.probs):
            obj.probs[variant] = np.vstack(obj.probs[variant])
            # Probability column i stands for label i + 1.
            obj.preds[variant] = obj.probs[variant].argmax(axis=1) + 1

    return trackers

def get_hits(x, y):
    """Return how many positions of `x` and `y` compare equal."""
    matches = x == y
    return np.sum(matches)

def eval_knn(x_train, y_train, x_val, y_val, **kwargs):
    """Classify `x_val` with `knn_all` and print per-variant accuracy against `y_val`.

    Extra keyword arguments are forwarded to `knn_all`. For every back-end the
    name is printed, followed by one line per voting variant with the number
    of correct predictions and the truncated integer percentage. The trackers
    produced by `knn_all` are returned unchanged.
    """
    trackers = knn_all(x_train, y_train, x_val, **kwargs)
    gt = np.array(y_val)
    total = len(gt)
    for name, obj in trackers.items():
        print(name)
        for variant, preds in obj.preds.items():
            hits = get_hits(preds, gt)
            percent = int((hits / total) * 100)
            print(hits, f"out of {total} ({percent}%) |", variant)
        print()
    return trackers

In [4]:
# Build unshuffled loaders for the training and dev directories and keep
# handles to their underlying datasets.
dl_train = setup_loader(data_dir="data/train", batch_size=16, shuffle=False)
ds_train = dl_train.dataset
dl_val = setup_loader(data_dir="data/dev", batch_size=16, shuffle=False)
ds_val = dl_val.dataset

In [5]:
# Experiment configuration.
# NOTE: `opts` is the same object as `default_opts` (no copy is made), so the
# assignments below also change the defaults imported from utils.
opts = default_opts
opts.vad_source = "mfcc0"
# Used by knn_all only when a sample is truncated: pick a random frame subset
# instead of the leading frames.
opts.shuffle = True
# Values <= 0 disable per-sample truncation in knn_all.
opts.max_seconds_per_sample = -1
# Number of neighbours passed to knn_single.
opts.knn_k = 3
# Keep nr_ceps equal to nr_banks (presumably cepstral coefficients vs. filter
# banks - confirm in utils).
opts.nr_ceps = opts.nr_banks

In [6]:
# Extract features and labels for both splits. Downstream code treats each
# element of x_* as the 2-D frame array of one sample and y_* as its label.
x_train, y_train = get_features_labels(dl_train, opts)
x_val, y_val = get_features_labels(dl_val, opts)

In [7]:
# `x_train_limited` is the training set actually handed to the k-NN code:
# the output of balance_frames when opts.balance_by_reduction is set,
# otherwise the unmodified training features.
if opts.balance_by_reduction:
    x_train_limited = balance_frames(x_train, opts)
else:
    x_train_limited = x_train

In [8]:
# Evaluate on the dev split with the faiss back-end only; prints the hit
# count and accuracy of every voting variant.
trackers = eval_knn(x_train_limited, y_train, x_val, y_val, modes=["faiss"])

  0%|          | 0.00/62.0 [00:00<?, ?batch/s]

faiss
2 out of 62 (3%) | knn_global
11 out of 62 (17%) | knn_frames_1nn
49 out of 62 (79%) | wtopk_frames_knn
50 out of 62 (80%) | nowtopk_frames_knn
45 out of 62 (72%) | top1_frames_1nn



In [37]:
# Score the evaluation directory: load it unshuffled and extract features
# together with the per-sample segment information.
dl_test = setup_test_loader(data_dir="data/eval", batch_size=16, shuffle=False)
x_test, segments = get_features_labels(dl_test, opts, segments=True)

# NOTE: this rebinds `trackers`, replacing the dev-set results computed earlier.
trackers = knn_all(x_train_limited, y_train, x_test, modes=["faiss"])

# Persist the rank-weighted k-NN probabilities of the faiss back-end.
write_results(Path("results_knn.txt"), trackers["faiss"].probs["wtopk_frames_knn"], segments, opts, is_score=False)

  0%|          | 0.00/736 [00:00<?, ?batch/s]