diff --git a/docs/source/en/model_doc/falcon_mamba.md b/docs/source/en/model_doc/falcon_mamba.md index fb6debfef921..ef346e89892e 100644 --- a/docs/source/en/model_doc/falcon_mamba.md +++ b/docs/source/en/model_doc/falcon_mamba.md @@ -14,95 +14,100 @@ rendered properly in your Markdown viewer. --> -# FalconMamba - -
-PyTorch +
+
+ PyTorch +
-## Overview - -The FalconMamba model was proposed by TII UAE (Technology Innovation Institute) in their release. - -The abstract from the paper is the following: - -*We present FalconMamba, a new base large language model based on the novel Mamba architecture. FalconMamba is trained on 5.8 trillion tokens with carefully selected data mixtures. As a pure Mamba-based model, FalconMamba surpasses leading open-weight models based on Transformers, such as Mistral 7B, Llama3 8B, and Falcon2 11B. It is on par with Gemma 7B and outperforms models with different architecture designs, such as RecurrentGemma 9B. Currently, FalconMamba is the best-performing Mamba model in the literature at this scale, surpassing both existing Mamba and hybrid Mamba-Transformer models. -Due to its architecture, FalconMamba is significantly faster at inference and requires substantially less memory for long sequence generation. Despite recent studies suggesting that hybrid Mamba-Transformer models outperform pure architecture designs, we argue and demonstrate that the pure Mamba design can achieve similar, even superior results compared to the hybrid design. We make the weights of our implementation of FalconMamba publicly available under a permissive license.* - -Tips: +# FalconMamba -- FalconMamba is mostly based on Mamba architecture, the same [tips and best practices](./mamba) would be relevant here. +[FalconMamba](https://huggingface.co/papers/2410.05355) is a 7B large language model, available as pretrained and instruction-tuned variants, based on the [Mamba](./mamba). This model implements a pure Mamba design that focuses on computational efficiency while maintaining strong performance. FalconMamba is significantly faster at inference and requires substantially less memory for long sequence generation. The models are pretrained on a diverse 5.8T token dataset including [RefinedWeb](https://huggingface.co/datasets/tiiuae/falcon-refinedweb), technical content, code, and mathematical data. -The model has been trained on approximtely 6T tokens consisting a mixture of many data sources such as RefineWeb, Cosmopedia and Math data. +You can find the official FalconMamba checkpoints in the [FalconMamba 7B](https://huggingface.co/collections/tiiuae/falconmamba-7b-66b9a580324dd1598b0f6d4a) collection. -For more details about the training procedure and the architecture, have a look at [the technical paper of FalconMamba]() (coming soon). +> [!TIP] +> Click on the FalconMamba models in the right sidebar for more examples of how to apply FalconMamba to different language tasks. -# Usage +The examples below demonstrate how to generate text with [`Pipeline`], [`AutoModel`], and from the command line. -Below we demonstrate how to use the model: + + -```python -from transformers import FalconMambaForCausalLM, AutoTokenizer +```py import torch - -tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") -model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b") - -input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"] - -out = model.generate(input_ids, max_new_tokens=10) -print(tokenizer.batch_decode(out)) +from transformers import pipeline + +pipeline = pipeline( + "text-generation", + model="tiiuae/falcon-mamba-7b-instruct", + torch_dtype=torch.bfloat16, + device=0 +) +pipeline( + "Explain the difference between transformers and SSMs", + max_length=100, + do_sample=True, + temperature=0.7 +) ``` -The architecture is also compatible with `torch.compile` for faster generation: + + -```python -from transformers import FalconMambaForCausalLM, AutoTokenizer +```py import torch +from transformers import AutoTokenizer, AutoModelForCausalLM -tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") -model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", torch_dtype=torch.bfloat16).to(0) -model = torch.compile(model) +tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b-instruct") +model = AutoModelForCausalLM.from_pretrained( + "tiiuae/falcon-mamba-7b-instruct", + torch_dtype=torch.bfloat16, + device_map="auto" +) -input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"] +input_ids = tokenizer("Explain the difference between transformers and SSMs", return_tensors="pt").to("cuda") -out = model.generate(input_ids, max_new_tokens=10) -print(tokenizer.batch_decode(out)) +output = model.generate(**input_ids, max_new_tokens=100, cache_implementation="static") +print(tokenizer.decode(output[0], skip_special_tokens=True)) ``` -If you have access to a GPU that is compatible with `bitsandbytes`, you can also quantize the model in 4-bit precision: + + -```python -from transformers import FalconMambaForCausalLM, AutoTokenizer, BitsAndBytesConfig -import torch - -tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") -quantization_config = BitsAndBytesConfig(load_in_4bit=True) -model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b", quantization_config=quantization_config) +```bash +transformers-cli chat --model_name_or_path tiiuae/falcon-mamba-7b-instruct --torch_dtype auto --device 0 +``` -input_ids = tokenizer("Hey how are you doing?", return_tensors= "pt")["input_ids"] + + -out = model.generate(input_ids, max_new_tokens=10) -print(tokenizer.batch_decode(out)) -``` +Quantization reduces the memory burden of large models by representing the weights in a lower precision. Refer to the [Quantization](../quantization/overview) overview for more available quantization backends. -You can also play with the instruction fine-tuned model: +The example below uses [bitsandbytes](../quantization/bitsandbytes) to quantize the weights to 4-bits. -```python -from transformers import FalconMambaForCausalLM, AutoTokenizer +```python import torch +from transformers import AutoTokenizer, FalconMambaForCausalLM, BitsAndBytesConfig -tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b-instruct") -model = FalconMambaForCausalLM.from_pretrained("tiiuae/falcon-mamba-7b-instruct") +quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=torch.bfloat16, + bnb_4bit_quant_type="nf4", + bnb_4bit_use_double_quant=True, +) -# We use the tokenizer's chat template to format each message - see https://huggingface.co/docs/transformers/main/en/chat_templating -messages = [ - {"role": "user", "content": "How many helicopters can a human eat in one sitting?"}, -] -input_ids = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True).input_ids - -outputs = model.generate(input_ids) -print(tokenizer.decode(outputs[0])) +tokenizer = AutoTokenizer.from_pretrained("tiiuae/falcon-mamba-7b") +model = FalconMambaForCausalLM.from_pretrained( + "tiiuae/falcon-mamba-7b", + torch_dtype=torch.bfloat16, + device_map="auto", + quantization_config=quantization_config, +) + +inputs = tokenizer("Explain the concept of state space models in simple terms", return_tensors="pt").to("cuda") +outputs = model.generate(**inputs, max_new_tokens=100) +print(tokenizer.decode(outputs[0], skip_special_tokens=True)) ``` ## FalconMambaConfig