This is a PyTorch implementation of Toward Universal Text-to-Music Retrieval for multi-modal music representation learning. Check our demo
Toward Universal Text-to-Music Retrieval
SeungHeon Doh, Minz Won, Keunwoo Choi, Juhan Nam
To appear at ICASSP 2023
TL;DR
- We introduce effective design choices for universal text-to-music retrieval, assessing recent text-music representation learning frameworks with a carefully designed dataset and downstream tasks.
- Our proposed stochastic text representation achieved robust performance in tag-level, caption-level, and zero-shot query retrieval cases
- Contrastive models achieve better performance than triplet models in both retrieval and downstream tasks.
- Reproducible code, pre-trained models, the MSD-ECALS music-caption dataset, and the downstream benchmark are available online for future research.
The following results are based on MSD-ECALS dataset pre-training. Pre-trained models and configs can be found at Zenodo-Pretrained.
Tag-based retrieval is reported as ROC/PR on the 50-tag and 1,054-tag vocabularies; language-based retrieval is reported on 1,000 music-caption pairs (R@1, R@5, R@10, mAP, MedR).

| Model Type | Text Enc. | Text Rep. | 50 Tag ROC/PR | 1054 Tag ROC/PR | R@1 | R@5 | R@10 | mAP | MedR |
|---|---|---|---|---|---|---|---|---|---|
| Classification | Binary | Tag | 90.2/39.5 | 86.4/8.8 | 4.0 | 13.8 | 20.1 | 8.3 | 86 |
| Triplet | GloVe | Tag | 89.2/36.0 | 82.6/6.1 | 2.8 | 11.2 | 18.6 | 6.6 | 51.5 |
| Triplet | GloVe | Caption | 88.6/37.1 | 76.8/5.3 | 5.4 | 22.1 | 35.0 | 13.0 | 17.0 |
| Triplet | GloVe | Stochastic | 89.2/37.6 | 81.6/6.2 | 6.4 | 21.8 | 32.7 | 12.8 | 19.5 |
| Triplet | BERT | Tag | 86.9/30.2 | 81.7/5.1 | 1.6 | 6.2 | 12.0 | 3.9 | 68.0 |
| Triplet | BERT | Caption | 87.7/35.0 | 78.9/5.4 | 6.7 | 23.6 | 36.6 | 14.1 | 16.0 |
| Triplet | BERT | Stochastic | 88.4/35.0 | 83.6/6.3 | 6.6 | 25.1 | 39.4 | 14.6 | 16.0 |
| Contrastive | BERT | Tag | 90.6/40.2 | 86.4/8.8 | 2.5 | 13.7 | 22.5 | 7.4 | 47.0 |
| Contrastive | BERT | Caption | 87.0/32.5 | 77.6/5.1 | 6.8 | 25.4 | 38.4 | 15.3 | 17.0 |
| Contrastive | BERT | Stochastic | 89.8/38.0 | 84.8/7.7 | 10.2 | 29.8 | 42.8 | 18.7 | 13.0 |
Note:
- See our paper for more results on other benchmarks, including MTAT, MTG-Jamendo, FMA, GTZAN, Emotify, and KVT.
Install Python and PyTorch:
- python==3.8
- torch==1.12.1 (please install it according to your CUDA version)

Other requirements:
- pip install -e .

```bash
conda create -n YOUR_ENV_NAME python=3.8
conda activate YOUR_ENV_NAME
pip install -e .
```
Download the pre-trained models and configs:

```bash
wget https://zenodo.org/record/7322135/files/mtr.tar.gz
tar -zxvf mtr.tar.gz
```
Please refer to notebook/demo.ipynb for MSD test-set tag, sentence, and unseen-query retrieval. Below is the audio and text embedding extraction code.
```python
import numpy as np
import torch

from mtr.utils.demo_utils import get_model
from mtr.utils.audio_utils import load_audio, STR_CH_FIRST

framework = 'contrastive'
text_type = 'bert'
text_rep = "stochastic"

# load model
model, tokenizer, config = get_model(framework=framework, text_type=text_type, text_rep=text_rep)

def text_infer(query, model, tokenizer):
    text_input = tokenizer(query, return_tensors="pt")['input_ids']
    with torch.no_grad():
        text_embs = model.encode_bert_text(text_input, None)
    return text_embs

def audio_infer(audio_path, model, sr=16000, duration=9.91):
    audio, _ = load_audio(
        path=audio_path,
        ch_format=STR_CH_FIRST,
        sample_rate=sr,
        downmix_to_mono=True,
    )
    # split the track into non-overlapping chunks of `duration` seconds
    input_size = int(duration * sr)
    hop = int(len(audio) // input_size)
    audio = np.stack(
        [np.array(audio[i * input_size : (i + 1) * input_size]) for i in range(hop)]
    ).astype('float32')
    audio_tensor = torch.from_numpy(audio)
    with torch.no_grad():
        z_audio = model.encode_audio(audio_tensor)
    # average chunk-level embeddings into a single track-level embedding
    audio_embs = z_audio.mean(0).detach().cpu()
    return audio_embs

query = "fusion jazz with synth, bass, drums, saxophone"
audio_path = "your_audio"
text_embs = text_infer(query, model, tokenizer)
audio_embs = audio_infer(audio_path, model)
```
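Once both embeddings live in the joint space, retrieval is a nearest-neighbor search. The snippet below is a minimal sketch (not part of the repo) that reuses `text_embs` and `audio_infer` from above to rank a few candidate tracks by cosine similarity; the candidate paths are placeholders.

```python
import torch
import torch.nn.functional as F

# Hypothetical candidate tracks; replace with your own audio files.
candidate_paths = ["track_01.wav", "track_02.wav", "track_03.wav"]

# Stack track-level audio embeddings into a (num_tracks, dim) matrix.
audio_matrix = torch.stack([audio_infer(p, model) for p in candidate_paths])

# Cosine similarity between the text query and every candidate track.
scores = F.cosine_similarity(text_embs, audio_matrix, dim=-1)

# Rank tracks from most to least similar to the query.
for idx in scores.argsort(descending=True):
    print(f"{candidate_paths[int(idx)]}: {scores[idx].item():.3f}")
```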
From our empirical study, we find a strong association between the text representation used at training time and the text query type at test time. We therefore propose a stochastic text representation: during training, we select K words from an L-word caption, where K is sampled uniformly at random from the integers between 1 (tag length) and L (caption length). Unlike dropout, which determines the length by a fixed probability, stochastic sampling yields a dynamic input length.
```python
def text_load(self, tag_list):
    """
    input: tag_list = list of tags
    output: text = string of text
    """
    if self.text_rep == "caption":
        if self.split == "TRAIN":
            random.shuffle(tag_list)
        text = ", ".join(tag_list)
    elif self.text_rep == "tag":
        text = [random.choice(tag_list)]
    elif self.text_rep == "stochastic":
        # sample a random number of tags (1..L), then join them into a caption
        k = random.choice(range(1, len(tag_list) + 1))
        sampled_tag_list = random.sample(tag_list, k)
        text = ", ".join(sampled_tag_list)
    return text
```
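As a quick sanity check of the sampling behavior, here is a standalone sketch (not part of the repo) that mimics the stochastic branch: each call draws a different number of tags, so the same annotation can appear as anything from a single tag to the full caption.

```python
import random

tag_list = ["fusion jazz", "synth", "bass", "drums", "saxophone"]

for _ in range(3):
    k = random.choice(range(1, len(tag_list) + 1))  # dynamic length: 1 .. L
    sampled = random.sample(tag_list, k)            # sample tags without replacement
    print(", ".join(sampled))

# Example output (varies per run):
# drums
# synth, saxophone, fusion jazz
# bass, drums, synth, saxophone, fusion jazz
```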
Download the ECALS (Extended Cleaned tag and Artist-Level Stratified split) dataset & MSD audio: Link
```bash
cd mtr/{triplet or contrastive}
# train a pre-trained model
python train.py --text_type {bert,glove} --text_rep {tag,caption,stochastic} --data_dir {msd-subsets} --multiprocessing-distributed
# evaluation on the ECALS dataset (single and multi query)
python test.py --text_type {bert,glove} --text_rep {tag,caption,stochastic} --data_dir {msd-subsets}
```
Following the MoCo v3 repo, only multi-GPU DistributedDataParallel training is supported; single-GPU or DataParallel training is not supported. The code has been improved to better suit the multi-node setting.
Other pre-training settings are:
parser.add_argument("--duration", default=9.91, type=int)
parser.add_argument("--sr", default=16000, type=int)
parser.add_argument("--mel_dim", default=128, type=int)
parser.add_argument("--n_fft", default=1024, type=int)
parser.add_argument("--win_length", default=1024, type=int)
parser.add_argument("--frontend", default="cnn", type=str)
parser.add_argument("--mix_type", default="cf", type=str)
parser.add_argument("--audio_rep", default="mel", type=str)
parser.add_argument("--cos", default=True, type=bool)
parser.add_argument("--attention_nlayers", default=4, type=int)
parser.add_argument("--attention_ndim", default=256, type=int)
parser.add_argument("--temperature", default=0.2, type=float)
parser.add_argument("--mlp_dim", default=128, type=int) -> joint embedding dim
Download the downstream datasets and preprocessing code from GitHub; the data splits and metadata annotations are released on Zenodo.
The downstream datasets consist of MTAT, FMA, MTG-Jamendo, GTZAN, KVT, and Emotify.
```bash
cd mtr/transfer
# extract embeddings
python extractor.py --framework {classification,triplet,contrastive} --text_type {binary,glove,bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET
# evaluate zero-shot transfer
python eval_zs.py --framework {triplet,contrastive} --text_type {binary,glove,bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET
# train a shallow classifier
python train_probing.py --probe_type {linear,mlp} --framework {classification,triplet,contrastive} --text_type {binary,glove,bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET
# evaluate the shallow classifier
python eval_probing.py --probe_type {linear,mlp} --framework {classification,triplet,contrastive} --text_type {binary,glove,bert} --text_rep {tag,caption,stochastic} --eval_dataset $DATASET
```
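To illustrate what the probing stage amounts to, here is a minimal linear-probe sketch on pre-extracted embeddings using scikit-learn. The `.npy` file names are placeholders for whatever `extractor.py` produces in your setup; the actual `train_probing.py`/`eval_probing.py` scripts use their own training loop.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score

# Placeholder files holding frozen embeddings and labels for one downstream dataset.
X_train = np.load("gtzan_train_embs.npy")
y_train = np.load("gtzan_train_labels.npy")
X_test = np.load("gtzan_test_embs.npy")
y_test = np.load("gtzan_test_labels.npy")

# Linear probe: a single logistic-regression layer on top of frozen embeddings.
probe = LogisticRegression(max_iter=1000)
probe.fit(X_train, y_train)
print("accuracy:", accuracy_score(y_test, probe.predict(X_test)))
```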
This project is under the CC-BY-NC 4.0 license. See LICENSE for details.
We would like to thank MoCo v3 for its training code and jukemir-CodifiedLM for its evaluation protocol.
Please consider citing our paper in your publications if this project helps your research. The BibTeX reference is as follows.
```bibtex
@inproceedings{doh2023toward,
  title={Toward Universal Text-to-Music Retrieval},
  author={Doh, SeungHeon and Won, Minz and Choi, Keunwoo and Nam, Juhan},
  booktitle={ICASSP 2023-2023 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
  year={2023}
}
```