
Moshi-Finetune

Open In Colab

Moshi interface

Moshi-Finetune provides an easy way to fine-tune Moshi models using LoRA (Low-Rank Adaptation) for lightweight and efficient training. This guide walks you through installation, model downloading, dataset preparation, training, and inference. By following these steps, you'll be able to: transform stereo audio files into your very own transcribed dataset, fine-tune moshi weights on real conversations, and—best of all—chat with your customized moshi model!

📥 Installation

You can also follow along interactively in our Colab, see link at the top.

To get started, follow these steps:

1️⃣ Clone this repository

git clone git@github.com:kyutai-labs/moshi-finetune.git

2️⃣ Install all required dependencies:

We recommend using uv to manage the environment. It's about 10x faster than pip and has a bunch of other benefits too. Once you've installed uv, no explicit package installation is required: just prefix every command with uv run (e.g. train using uv run torchrun ...). This will automatically install the necessary dependencies based on pyproject.toml.
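
For example, the single-GPU training command from the Start training section below becomes:

uv run torchrun --nproc-per-node 1 -m train example/moshi_7B.yaml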

Installing without uv

If you prefer working with pip and handling the install manually, you will need at least Python 3.10. We still advise using a virtual environment, which can be created with Conda or virtualenv. Then, run:

cd moshi-finetune
pip install -e .

📥 Model configuration

The training setup is specified via a YAML configuration file. The example configuration files are located in the example directory.

We recommend fine-tuning one of the official moshi models. To achieve this, you can use the following section in your configuration file.

moshi_paths:
  hf_repo_id: "kyutai/moshiko-pytorch-bf16"

📚 Prepare dataset

The pipeline expects a dataset of stereo audio files: the left channel is used for the audio generated by Moshi, whereas the right channel is used for the user's input.

The files contained in the dataset should be specified in a .jsonl file, where each line has the form

{"path": "relative/path/to/file.wav", "duration": <duration in seconds>}

Each audio file should have an associated .json file that contains the transcript with timestamps. These JSONs can be generated automatically, see below.

For example, the following would be a valid directory structure for a dataset:

data/
├── mycooldataset.jsonl
└── data_stereo
    ├── a.json
    ├── a.wav
    ├── b.json
    ├── b.wav
    ├── c.json
    └── c.wav

where mycooldataset.jsonl contains:

{"path": "data_stereo/a.wav", "duration": 24.521950113378686}
{"path": "data_stereo/b.wav", "duration": 18.317074829931972}
{"path": "data_stereo/c.wav", "duration": 39.38641723356009}

The .jsonl file can be generated with the snippet below. This will include all the .wav files in a given directory.

import sphn
import json
from pathlib import Path

paths = [str(f) for f in Path("wav-dir").glob("*.wav")]
durations = sphn.durations(paths)

with open("data.jsonl", "w") as fobj:
    for p, d in zip(paths, durations):
        if d is None:
            continue
        json.dump({"path": p, "duration": d}, fobj)
        fobj.write("\n")

A sample dataset in this format can be found in the kyutai/DailyTalkContiguous repository. This 14 GB dataset can be downloaded using the following snippet:

from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "kyutai/DailyTalkContiguous",
    repo_type="dataset",
    local_dir="./daily-talk-contiguous"
)
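
To train on this dataset, point data.train_data in your YAML configuration at its .jsonl index file. A minimal sketch, where the index file name is a placeholder (use the .jsonl that ships with the download):

data:
  train_data: "./daily-talk-contiguous/<index-file>.jsonl"  # replace with the actual .jsonl in the downloaded folder
  shuffle: true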

If you want to annotate your own dataset and generate the .json transcripts for each audio file, you can use the annotate.py script:

python annotate.py {your jsonl file}

This script can also be run in a distributed manner with SLURM using e.g. --shards 64 --partition 'your-partition'.
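
For example, with data.jsonl standing in for your own index file and a placeholder partition name:

python annotate.py data.jsonl --shards 64 --partition 'your-partition'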

🏋️ Start training

Once your dataset is ready, start fine-tuning using the following steps.

📌 Recommended settings for quick training:

lora:
  enable: true
  rank: 128
  scaling: 2.

duration_sec: 100
batch_size: 16
max_steps: 2000

📌 Run training on a single GPU:

torchrun --nproc-per-node 1 -m train example/moshi_7B.yaml

Note that you should still use torchrun even if you're only using a single GPU.

📌 Run training on multiple GPUs (8):

torchrun --nproc-per-node 8 --master_port $RANDOM -m train example/moshi_7B.yaml

💡 Expected performance:

Using the above hyperparameters:

Setup    Avg tokens/sec   Peak allocated memory
1×H100   ≈12k             39.6 GB
8×H100   ≈10.7k           23.7 GB

If you encounter out-of-memory errors, try reducing the batch_size. If the issue persists, you can lower the duration_sec parameter, but be aware that this may negatively impact the user experience during inference, potentially causing the model to become silent more quickly.

⚙️ Customizing training configuration

The example configuration moshi-finetune/example/moshi_7B.yaml defines reasonable values for learning rate, weight decay, and so on, but you are advised to customize these settings for your use case.

🔧 Key training parameters

moshi_paths: Defines the model paths. Set .hf_repo_id to import the model from the Hugging Face Hub; for more information, take a look at Moshi loading.
run_dir: Directory where training checkpoints and logs are stored.
duration_sec: Maximum sequence length (in seconds) for training.
first_codebook_weight_multiplier: Loss weight multiplier for the first codebook. The first codebook carries the semantic token, so it is given more weight.
text_padding_weight: Loss weight for text padding tokens. Most of the text stream is padding, since Mimi runs at 12.5 Hz while the tokenized text takes far fewer tokens; decreasing the padding weight keeps the model from over-focusing on predicting paddings.
gradient_checkpointing: Whether to use gradient checkpointing per transformer layer to mitigate out-of-memory issues.
batch_size: Number of training examples per GPU.
max_steps: Total number of training steps. Defines how many iterations the training will run. Total tokens processed = max_steps × num_gpus × batch_size × duration_sec × 12.5 (steps per second) × 9 (tokens per step); see the worked example after this list.
optim.lr: Learning rate. Recommended starting value: 2e-6.
optim.weight_decay: Weight decay for regularization. Default: 0.1.
optim.pct_start: Percentage of total training steps used for learning-rate warm-up before decay. Equivalent to pct_start in PyTorch's OneCycleLR.
lora.rank: Size of the LoRA adapters. Recommended ≤128 for efficiency.
lora.ft_embed: Whether to fully fine-tune the embedding matrices while fine-tuning all other linear layers with LoRA.
seed: Random seed for initialization, data shuffling, and sampling (ensures reproducibility).
log_freq: Defines how often (in steps) training metrics are logged.
data.train_data: Path to the dataset used for training.
data.eval_data: (Optional) Path to an evaluation dataset for cross-validation at eval_freq intervals.
data.shuffle: Whether to shuffle training samples (recommended).
eval_freq: Number of steps between evaluations on the validation set.
no_eval: If False, enables periodic model evaluation during training.
ckpt_freq: Number of steps between saving model checkpoints.
full_finetuning: Set to True for full fine-tuning, or False to use LoRA for adaptation.
save_adapters: If True, saves only the LoRA adapters (works with Moshi inference). If False, merges LoRA into the base model (requires sufficient CPU/GPU memory).
wandb.key: API key for Weights & Biases (wandb) logging (optional).
wandb.project: Name of the wandb project where training logs will be stored.
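
For example, with the recommended quick-training settings above (max_steps 2000, batch_size 16, duration_sec 100) on a single GPU, this works out to 2000 × 1 × 16 × 100 × 12.5 × 9 = 360 million tokens.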

Training curves

Figure 1: Training curves over steps on the DailyTalk dataset, using a maximal learning rate of 4e-6.

🔮 Inference

1️⃣ Install Moshi for inference

Once your model is trained, you can use it in interactive mode using moshi. The package should already be in your environment if you followed the installation steps above. If not, you can install it using pip install "git+https://git@github.com/kyutai-labs/moshi.git#egg=moshi&subdirectory=moshi".

2️⃣ Run inference using the fine-tuned model

Let's say a checkpoint was saved under CHECKPOINT_DIR=$HOME/dailydialog_ft/checkpoints/checkpoint_000500.

If you trained using LoRA, you can run the Moshi web app using:

python -m moshi.server \
  --lora-weight=$CHECKPOINT_DIR/consolidated/lora.safetensors \
  --config-path=$CHECKPOINT_DIR/consolidated/config.json

This will run the fine-tuned model by applying your LoRA adapter on top of the base model's weights.

Otherwise (if you ran full fine-tuning, or didn't checkpoint only the LoRA adapters), you can run:

python -m moshi.server \
  --moshi-weight=$CHECKPOINT_DIR/consolidated/consolidated.safetensors \
  --config-path=$CHECKPOINT_DIR/consolidated/config.json

Here, consolidated.safetensors contains all of the new Moshi weights and doesn't reference any base model.

📊 Monitoring with Weights & Biases (W&B)

Explicit support for Weights & Biases is included to help you monitor and visualize your training runs. This integration lets you log various metrics and track experiments easily.

To use Weights & Biases with moshi-finetune, install wandb using pip install wandb and fill in the wandb: section of your YAML configuration; see example/moshi_7B.yaml.
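
A minimal sketch of that section, with placeholder values:

wandb:
  key: "your-wandb-api-key"  # your W&B API key
  project: "moshi-finetune"  # any project name you like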

Once the training starts, you can monitor the progress in real-time by visiting your wandb project dashboard. All metrics, including training loss, evaluation loss, learning rate, etc., will be logged and visualized.

For more details on how to use wandb, visit the Weights and Biases documentation.

Acknowledgments

This project uses code from mistral-finetune licensed under the Apache License 2.0.