---
language:
- en
- hi
license: llama2
tags:
- multilingual
- instruction-tuning
- llama2
datasets:
- ai4bharat/indic-instruct-data-v0.1
model-index:
- name: Airavata
  results:
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      name: AI2 Reasoning Challenge (25-Shot)
      type: ai2_arc
      config: ARC-Challenge
      split: test
      args:
        num_few_shot: 25
    metrics:
    - type: acc_norm
      value: 46.5
      name: normalized accuracy
    source:
      url: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=ai4bharat/Airavata
      name: Open LLM Leaderboard
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      name: HellaSwag (10-Shot)
      type: hellaswag
      split: validation
      args:
        num_few_shot: 10
    metrics:
    - type: acc_norm
      value: 69.26
      name: normalized accuracy
    source:
      url: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=ai4bharat/Airavata
      name: Open LLM Leaderboard
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      name: MMLU (5-Shot)
      type: cais/mmlu
      config: all
      split: test
      args:
        num_few_shot: 5
    metrics:
    - type: acc
      value: 43.9
      name: accuracy
    source:
      url: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=ai4bharat/Airavata
      name: Open LLM Leaderboard
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      name: TruthfulQA (0-shot)
      type: truthful_qa
      config: multiple_choice
      split: validation
      args:
        num_few_shot: 0
    metrics:
    - type: mc2
      value: 40.62
    source:
      url: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=ai4bharat/Airavata
      name: Open LLM Leaderboard
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      name: Winogrande (5-shot)
      type: winogrande
      config: winogrande_xl
      split: validation
      args:
        num_few_shot: 5
    metrics:
    - type: acc
      value: 68.82
      name: accuracy
    source:
      url: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=ai4bharat/Airavata
      name: Open LLM Leaderboard
  - task:
      type: text-generation
      name: Text Generation
    dataset:
      name: GSM8k (5-shot)
      type: gsm8k
      config: main
      split: test
      args:
        num_few_shot: 5
    metrics:
    - type: acc
      value: 4.02
      name: accuracy
    source:
      url: https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard?query=ai4bharat/Airavata
      name: Open LLM Leaderboard
---

# Airavata

This model is a 7B [OpenHathi](https://huggingface.co/sarvamai/OpenHathi-7B-Hi-v0.1-Base) model fine-tuned on the [IndicInstruct dataset](https://huggingface.co/datasets/ai4bharat/indic-instruct-data-v0.1),
a collection of instruction datasets (Anudesh, wikiHow, Flan v2, Dolly, Anthropic-HHH, OpenAssistant v1, and LMSYS-Chat).
See the corresponding Hugging Face dataset card for more details.

This model was trained as part of the work described in the technical report [Airavata: Introducing Hindi Instruction-tuned LLM](https://arxiv.org/abs/2401.15006).
The codebase used to train and evaluate this model can be found at [https://github.com/AI4Bharat/IndicInstruct](https://github.com/AI4Bharat/IndicInstruct).

## Usage

Clone [https://github.com/AI4Bharat/IndicInstruct](https://github.com/AI4Bharat/IndicInstruct) and install the required dependencies. Then download or clone this model to the same machine.

## Input Format

The model is trained to use a chat format similar to that of the [open-instruct code repository](https://github.com/allenai/open-instruct) (note the newlines):
```
<|user|>
Your message here!
<|assistant|>
```

For best results, format all inputs in this manner. **Make sure to include a newline after `<|assistant|>`; this can affect generation quality quite a bit.**
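
As a minimal sketch of the format above, a single-turn prompt can be assembled with a small helper. The name `format_single_turn` is hypothetical and used only for illustration; it is not part of the IndicInstruct codebase.

```python
def format_single_turn(user_message: str) -> str:
    """Build a single-turn prompt in the chat format described above.

    `format_single_turn` is a hypothetical helper name, used only for illustration.
    """
    # The trailing "\n" after <|assistant|> is intentional: the model card
    # notes that omitting it can hurt generation quality.
    return "<|user|>\n" + user_message + "\n<|assistant|>\n"


prompt = format_single_turn("Your message here!")
print(repr(prompt))  # '<|user|>\nYour message here!\n<|assistant|>\n'
```

Printing the `repr` makes the newlines visible, which is a quick way to confirm the prompt ends with `<|assistant|>` followed by a newline.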

## Hyperparameters

We fine-tune the OpenHathi base model on the aforementioned IndicInstruct dataset with LoRA. The hyperparameters for the LoRA fine-tuning are listed below:
- LoRA Rank: 16
- LoRA alpha: 32
- LoRA Dropout: 0.05
- LoRA Target Modules: ["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"]
- Epochs: 4
- Learning rate: 5e-4
- Batch Size: 128
- Floating Point Precision: bfloat16
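
For readers who want to reproduce a similar setup, the LoRA values listed above could be expressed as a `peft.LoraConfig`. This is only a sketch: using the `peft` library here is an assumption, and the IndicInstruct repository remains the authoritative source for the actual training code. The names `LORA_HPARAMS` and `build_lora_config` are hypothetical, and the training-loop settings (epochs, learning rate, batch size, precision) are not covered.

```python
# LoRA hyperparameters copied from the list above.
LORA_HPARAMS = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj", "k_proj", "down_proj", "gate_proj", "up_proj"],
}

# Standard LoRA scales the low-rank update by lora_alpha / r.
lora_scaling = LORA_HPARAMS["lora_alpha"] / LORA_HPARAMS["r"]


def build_lora_config():
    """Return a peft.LoraConfig with the values above (assumes `peft` is installed).

    Hypothetical helper; the real training setup lives in the IndicInstruct repo.
    """
    from peft import LoraConfig  # imported lazily so the rest runs without peft

    return LoraConfig(task_type="CAUSAL_LM", **LORA_HPARAMS)


print(lora_scaling)  # 2.0
```

With rank 16 and alpha 32, the adapter update is scaled by a factor of 2.0 before being added to the frozen weights.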

We recommend that readers check out [our official blog post](https://ai4bharat.github.io/airavata) for more details on the model training, ablations, and evaluation results.

## Example

```python3
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"


def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
    # Render a list of {"role", "content"} messages in the chat format above.
    formatted_text = ""
    for message in messages:
        if message["role"] == "system":
            formatted_text += "<|system|>\n" + message["content"] + "\n"
        elif message["role"] == "user":
            formatted_text += "<|user|>\n" + message["content"] + "\n"
        elif message["role"] == "assistant":
            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
        else:
            raise ValueError(
                "Tulu chat template only supports 'system', 'user' and 'assistant' roles. Invalid role: {}.".format(
                    message["role"]
                )
            )
    # End with the assistant tag and a newline so the model starts its reply.
    formatted_text += "<|assistant|>\n"
    formatted_text = bos + formatted_text if add_bos else formatted_text
    return formatted_text


def inference(input_prompts, model, tokenizer):
    # add_bos=False: the tokenizer is expected to add the BOS token itself.
    input_prompts = [
        create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
        for input_prompt in input_prompts
    ]

    encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
    encodings = encodings.to(device)

    with torch.inference_mode():
        # Pass the attention mask so padded positions are ignored in batched generation.
        outputs = model.generate(
            encodings.input_ids,
            attention_mask=encodings.attention_mask,
            do_sample=False,
            max_new_tokens=250,
        )

    output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)

    # Round-trip the prompts through the tokenizer so they match the decoded
    # text exactly, then strip the prompt prefix from each output.
    input_prompts = [
        tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
    ]
    output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
    return output_texts


model_name = "ai4bharat/Airavata"

# Left padding keeps the prompt adjacent to the generated tokens.
tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
tokenizer.pad_token = tokenizer.eos_token
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)

input_prompts = [
    "मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं।",
    "मैं अपने समय प्रबंधन कौशल को कैसे सुधार सकता हूँ? मुझे पांच बिंदु बताएं और उनका वर्णन करें।",
]
outputs = inference(input_prompts, model, tokenizer)
print(outputs)
```

## Citation

```bibtex
@article{gala2024airavata,
  title   = {Airavata: Introducing Hindi Instruction-tuned LLM},
  author  = {Jay Gala and Thanmay Jayakumar and Jaavid Aktar Husain and Aswanth Kumar M and Mohammed Safi Ur Rahman Khan and Diptesh Kanojia and Ratish Puduppully and Mitesh M. Khapra and Raj Dabre and Rudra Murthy and Anoop Kunchukuttan},
  year    = {2024},
  journal = {arXiv preprint arXiv: 2401.15006}
}
```

# [Open LLM Leaderboard Evaluation Results](https://huggingface.co/spaces/HuggingFaceH4/open_llm_leaderboard)
Detailed results can be found [here](https://huggingface.co/datasets/open-llm-leaderboard/details_ai4bharat__Airavata).

|             Metric              |Value|
|---------------------------------|----:|
|Avg.                             |45.52|
|AI2 Reasoning Challenge (25-Shot)|46.50|
|HellaSwag (10-Shot)              |69.26|
|MMLU (5-Shot)                    |43.90|
|TruthfulQA (0-shot)              |40.62|
|Winogrande (5-shot)              |68.82|
|GSM8k (5-shot)                   | 4.02|
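
The `Avg.` row is consistent with the arithmetic mean of the six benchmark scores. A short sketch that reproduces it from the values in the table (the variable names are illustrative):

```python
# Scores copied from the table above.
scores = {
    "AI2 Reasoning Challenge (25-Shot)": 46.50,
    "HellaSwag (10-Shot)": 69.26,
    "MMLU (5-Shot)": 43.90,
    "TruthfulQA (0-shot)": 40.62,
    "Winogrande (5-shot)": 68.82,
    "GSM8k (5-shot)": 4.02,
}

# Arithmetic mean of the six benchmarks, rounded to two decimals.
average = round(sum(scores.values()) / len(scores), 2)
print(average)  # 45.52
```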