Moshi-Finetune provides an easy way to fine-tune Moshi models using LoRA (Low-Rank Adaptation) for lightweight and efficient training. This guide walks you through installation, model downloading, dataset preparation, training, and inference. By following these steps, you'll be able to: transform stereo audio files into your very own transcribed dataset, fine-tune moshi weights on real conversations, and—best of all—chat with your customized moshi model!
You can also follow along interactively in our Colab; see the link at the top.
To get started, follow these steps:
```bash
git clone [email protected]:kyutai-labs/moshi-finetune.git
```
We recommend using `uv` to manage the environment. It's about 10x faster than `pip` and has a bunch of other benefits too. Once you've installed `uv`, no explicit package installation is required: just prefix every command with `uv run` (e.g. train using `uv run torchrun ...`). This will automatically install the necessary dependencies based on `pyproject.toml`.
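For example, the single-GPU training command shown later in this guide can be launched directly as:

```bash
# uv installs the dependencies declared in pyproject.toml on first run,
# then launches training with the example configuration
uv run torchrun --nproc-per-node 1 -m train example/moshi_7B.yaml
```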
If you prefer working with `pip` and handling the install manually, you will need at least Python 3.10. We still advise using a virtual environment, which can be created with Conda or virtualenv.
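For example, with the standard `venv` module (a minimal sketch; a Conda environment works just as well):

```bash
# create and activate a fresh virtual environment (assumes Python >= 3.10)
python3 -m venv .venv
source .venv/bin/activate
```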
Then, run:

```bash
cd moshi-finetune
pip install -e .
```
The training setup is specified via a YAML configuration file. Example configuration files are located in the `example` directory.
We recommend fine-tuning one of the official moshi models. To do so, use the following section in your configuration file:

```yaml
moshi_paths:
  hf_repo_id: "kyutai/moshiko-pytorch-bf16"
```
The pipeline expects a dataset of stereo audio files: the left channel is used for the audio generated by moshi, while the right (second) channel is used for the user's input.
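For a quick sanity check of this layout, you can inspect a file with a generic audio library such as `soundfile` (used here purely for illustration; it is not required by the training pipeline):

```python
import soundfile as sf

# stereo files are returned as an array of shape (num_samples, num_channels)
audio, sample_rate = sf.read("data_stereo/a.wav")
assert audio.ndim == 2 and audio.shape[1] == 2, "expected a stereo file"

moshi_channel = audio[:, 0]  # left channel: Moshi's side of the conversation
user_channel = audio[:, 1]   # right channel: the user's input
```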
The files contained in the dataset should be listed in a `.jsonl` file, where each line has the form:

```
{"path": "relative/path/to/file.wav", "duration": <duration in seconds>}
```
Each audio file should have an associated `.json` file that contains the transcript with timestamps. These JSONs can be generated automatically; see below.
For example, the following would be a valid directory structure for a dataset:
```
data/
├── mycooldataset.jsonl
└── data_stereo
    ├── a.json
    ├── a.wav
    ├── b.json
    ├── b.wav
    ├── c.json
    └── c.wav
```
where `mycooldataset.jsonl` contains:
{"path": "data_stereo/a.wav", "duration": 24.521950113378686}
{"path": "data_stereo/b.wav", "duration": 18.317074829931972}
{"path": "data_stereo/c.wav", "duration": 39.38641723356009}
The `.jsonl` file can be generated with the snippet below, which will include all the `.wav` files in a given directory.
```python
import sphn
import json
from pathlib import Path

paths = [str(f) for f in Path("wav-dir").glob("*.wav")]
durations = sphn.durations(paths)
with open("data.jsonl", "w") as fobj:
    for p, d in zip(paths, durations):
        if d is None:
            continue
        json.dump({"path": p, "duration": d}, fobj)
        fobj.write("\n")
```
A sample dataset in this format can be found in the kyutai/DailyTalkContiguous repository. This 14 GB dataset can be downloaded using the following snippet:
```python
from huggingface_hub import snapshot_download

local_dir = snapshot_download(
    "kyutai/DailyTalkContiguous",
    repo_type="dataset",
    local_dir="./daily-talk-contiguous",
)
```
If you want to annotate your own dataset and generate the `.json` transcripts for each audio file, you can use the `annotate.py` script:

```bash
python annotate.py {your jsonl file}
```

This script can also be run in a distributed manner with SLURM by passing e.g. `--shards 64 --partition 'your-partition'`, as shown below.
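For example, using the `data.jsonl` file generated earlier (the shard count and partition name are placeholders to adapt to your cluster):

```bash
# split the annotation work into 64 SLURM jobs on the given partition
python annotate.py data.jsonl --shards 64 --partition 'your-partition'
```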
Once your dataset is ready, start fine-tuning using the following steps.
First, set the LoRA and training parameters in your configuration file, for example:

```yaml
lora:
  enable: true
  rank: 128
  scaling: 2.

duration_sec: 100
batch_size: 16
max_steps: 2000
```
Then launch training on a single GPU:

```bash
torchrun --nproc-per-node 1 -m train example/moshi_7B.yaml
```

Note that you should still use `torchrun` even if you're only using a single GPU.
To train on multiple GPUs on a single node (here, 8 GPUs):

```bash
torchrun --nproc-per-node 8 --master_port $RANDOM -m train example/moshi_7B.yaml
```
Using the above hyperparameters:

| GPUs | Avg Tokens/sec | Peak Allocated Memory |
|---|---|---|
| 1×H100 | ≈12k | 39.6 GB |
| 8×H100 | ≈10.7k | 23.7 GB |
If you encounter out-of-memory errors, try reducing the `batch_size`. If the issue persists, you can lower the `duration_sec` parameter, but be aware that this may negatively impact the user experience during inference, potentially causing the model to become silent more quickly.
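For illustration, a few memory-saving overrides in the training YAML might look like this (the values are illustrative, not tuned; `gradient_checkpointing` is described in the parameter table below):

```yaml
batch_size: 8                 # fewer examples per GPU
duration_sec: 50              # shorter training sequences
gradient_checkpointing: true  # recompute activations to save memory
```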
The example `moshi-finetune/example/moshi_7B.yaml` defines reasonable parameters for learning rate, weight decay, etc., but you are advised to customize these settings for your use case.
| Parameter | Description |
|---|---|
| `moshi_paths` | Defines all the paths: `.hf_repo_id` if the model is imported from the Hugging Face Hub (`.hf_repo_id` lets you change the default settings); for more information, take a look at Moshi loading. |
| `run_dir` | Directory where training checkpoints and logs are stored. |
| `duration_sec` | Maximum sequence length (in seconds) for training. |
| `first_codebook_weight_multiplier` | The first codebook carries the semantic token, so we put more weight on it. |
| `text_padding_weight` | Most of the text stream is padding, since Mimi audio runs at 12.5 Hz while the tokenized text takes up less space. Decrease the loss weight on padding to avoid the model over-focusing on predicting padding tokens. |
| `gradient_checkpointing` | Whether to use gradient checkpointing per transformer layer to mitigate out-of-memory issues. |
| `batch_size` | Number of training examples per GPU. |
| `max_steps` | Total number of training steps. Defines how many iterations the training will run. **Total tokens processed = max_steps × num_gpus × batch_size × duration_sec × 9 (tokens per step) × 12.5 (steps per second)**. |
| `optim.lr` | Learning rate. Recommended starting value: 2e-6. |
| `optim.weight_decay` | Weight decay for regularization. Default: 0.1. |
| `optim.pct_start` | Percentage of total training steps used for learning-rate warm-up before decay. Equivalent to `pct_start` in PyTorch's `OneCycleLR`. |
| `lora.rank` | Size of the LoRA adapters. Recommended ≤128 for efficiency. |
| `lora.ft_embed` | Whether to fully fine-tune the embedding matrices while fine-tuning all the other linear layers with LoRA. |
| `seed` | Random seed for initialization, data shuffling, and sampling (ensures reproducibility). |
| `log_freq` | Defines how often (in steps) training metrics are logged. |
| `data.train_data` | Path to the dataset used for training. |
| `data.eval_data` | (Optional) Path to an evaluation dataset used for cross-validation at `eval_freq` intervals. |
| `data.shuffle` | Whether to shuffle training samples (recommended). |
| `eval_freq` | Number of steps between evaluations on the validation set. |
| `no_eval` | If `False`, enables periodic model evaluation during training. |
| `ckpt_freq` | Number of steps between saving model checkpoints. |
| `full_finetuning` | Set to `True` for full fine-tuning, or `False` to use LoRA for adaptation. |
| `save_adapters` | If `True`, saves only the LoRA adapters (works with Moshi inference). If `False`, merges LoRA into the base model (requires sufficient CPU/GPU memory). |
| `wandb.key` | API key for Weights & Biases (wandb) logging (optional). |
| `wandb.project` | Name of the wandb project where training logs will be stored. |
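As a worked example of the token formula, using the hyperparameters shown earlier (a single GPU, `batch_size: 16`, `duration_sec: 100`, `max_steps: 2000`):

```python
# total tokens processed = max_steps * num_gpus * batch_size * duration_sec
#                          * 9 (tokens per step) * 12.5 (steps per second)
max_steps, num_gpus, batch_size, duration_sec = 2000, 1, 16, 100
total_tokens = max_steps * num_gpus * batch_size * duration_sec * 9 * 12.5
print(f"{total_tokens:,.0f}")  # 360,000,000 tokens
```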
Once your model is trained, you can use it in interactive mode with moshi. The package should already be in your environment if you installed the dependencies as described above. If not, you can install it using:

```bash
pip install "git+https://[email protected]/kyutai-labs/moshi.git#egg=moshi&subdirectory=moshi"
```
Let's say a checkpoint was saved under `CHECKPOINT_DIR=$HOME/dailydialog_ft/checkpoints/checkpoint_000500`.
If you trained using LoRA, you can run the Moshi web app using:

```bash
python -m moshi.server \
  --lora-weight=$CHECKPOINT_DIR/consolidated/lora.safetensors \
  --config-path=$CHECKPOINT_DIR/consolidated/config.json
```
This will run the fine-tuned model by applying your LoRA adapter on top of the base model's weights.

Otherwise (if you ran full fine-tuning, or didn't checkpoint the LoRA adapters only), you can run:

```bash
python -m moshi.server \
  --moshi-weight=$CHECKPOINT_DIR/consolidated/consolidated.safetensors \
  --config-path=$CHECKPOINT_DIR/consolidated/config.json
```
Here, `consolidated.safetensors` contains all of the new Moshi weights and doesn't reference any base model.
Explicit support for Weights & Biases is included to help you monitor and visualize your training runs. This integration allows you to log various metrics and track experiments easily.
To use Weights & Biases with `moshi-finetune`, install `wandb` using `pip install wandb` and fill in the `wandb:` section of your YAML configuration; see `example/moshi_7B.yaml`.
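A minimal sketch of that section, assuming only the two fields listed in the parameter table above (the project name is a placeholder):

```yaml
wandb:
  project: "my-moshi-finetune"  # placeholder project name
  key: "<your-wandb-api-key>"   # optional; omit if you authenticate some other way
```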
Once the training starts, you can monitor the progress in real-time by visiting your wandb project dashboard. All metrics, including training loss, evaluation loss, learning rate, etc., will be logged and visualized.
For more details on how to use wandb, visit the Weights and Biases documentation.
This project uses code from mistral-finetune licensed under the Apache License 2.0.