LLMs are machine learning classifiers
Learn how to use LLMs like GPT for text classification. Explore prompting, fine-tuning, and when to choose LLMs over traditional machine learning classifiers.
When you hear "Large Language Models" or "LLMs", you probably think of ChatGPT writing essays or Claude assisting with code. However, LLMs are not just tools for free-form text generation - they can also excel at structured tasks like classification. In fact, LLMs can serve as powerful classifiers, categorizing text into predefined labels just as readily as they generate responses.
Traditionally, classification tasks have relied on structured machine learning approaches like decision trees, support vector machines, and neural networks. These methods work well with structured data, but real-world text data is often noisy, ambiguous, and context-dependent. This is where LLMs offer a compelling alternative. With their ability to understand language holistically, they can classify text with minimal manual feature engineering.
In this post, we'll explore the different ways to harness LLMs for classification. We’ll cover techniques ranging from simple prompting to few-shot learning and fine-tuning, along with strategies for ensuring reliable, production-ready results. Finally, we’ll examine how LLM-based classification compares to traditional approaches and when it makes sense to use LLMs over specialized encoder models like BERT.
By the end, you'll know exactly when (and when not) to use an LLM for text classification.
Table of contents
- What are machine learning classifiers?
- Traditional machine learning classifiers
- Common approaches to machine learning classification
- Encoders vs. decoders in text classification
- Encoder transformers: The industry standard for text classification
- How encoder-based transformers work
- Using encoders for text classification
- LLMs as machine learning classifiers
- How LLMs are trained
- Why LLMs can be effective for text classification
- The decoder architecture: Using LLMs for classification
- The challenge of using LLMs for classification
- Classification approaches with decoder LLMs
- Structured outputs with LLMs
- Fine-tuning LLMs for classification
- Fine-tuning your own classification LLM model
- When should you use LLMs as machine learning classifiers?
- Conclusion
What are machine learning classifiers?
Machine learning classifiers are at the core of many AI-driven systems, helping categorize data into predefined classes. From filtering spam emails to detecting fraudulent transactions, these models power a wide range of real-world applications.
Traditional classifiers excel at handling structured data, where clear numerical or categorical patterns exist. However, real-world text data is often noisy and unstructured, requiring more sophisticated methods to achieve reliable classification.

Traditional machine learning classifiers
Before we explore using LLMs for classification, let’s briefly examine traditional classifiers to understand their strengths and weaknesses. These models have been widely used for text classification and remain effective for many applications. However, they also come with significant limitations that make LLMs a compelling alternative in some cases.
Common approaches to machine learning classification
- Decision Trees: Create a hierarchical structure of if-then rules based on feature values. They are interpretable and computationally efficient but struggle with high-dimensional text data.
- Naive Bayes: A probabilistic approach based on Bayes' theorem. It is fast and effective for small datasets but assumes feature independence, which is problematic for text classification.
- Support Vector Machines (SVMs): Find optimal decision boundaries in high-dimensional space. SVMs perform well on text data but require significant feature engineering.
- K-Nearest Neighbors (KNN): Classifies text based on the most similar existing examples. While simple, KNN is slow for large datasets and struggles with high-dimensional text.
Each classifier has its own way of processing and understanding text, making different algorithms suitable for different tasks. Regardless of the algorithm, text preprocessing in traditional classification typically involves the following (a short scikit-learn sketch follows this list):
- Tokenization, removing stopwords, stemming/lemmatization.
- Converting the words into numerical vectors using techniques like:
- Bag of Words (BoW): Simple word frequency counts, ignoring order.
- TF-IDF: Term frequency-inverse document frequency, which balances local and global word importance.
- N-grams: Capturing sequences of N consecutive words.
- Manual feature engineering, such as word counts or frequency-based features for important domain terms, which is often still needed to reach good performance.
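To make this concrete, here is a minimal sketch of such a pipeline using scikit-learn: TF-IDF features with uni- and bigrams feeding a logistic regression classifier. The tiny toy dataset is purely illustrative.

# Minimal sketch of a traditional text-classification pipeline:
# TF-IDF features (with bigrams) + logistic regression via scikit-learn.
# The toy dataset below is purely illustrative.
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import make_pipeline

texts = [
    "The product exceeded my expectations!",
    "Fast delivery and great quality",
    "The screen broke in a week",
    "Terrible support, never buying again",
]
labels = [1, 1, 0, 0]  # 1 = positive, 0 = negative

# TfidfVectorizer handles tokenization, stopword removal, and n-grams,
# turning each review into a sparse weighted bag-of-words vector.
pipeline = make_pipeline(
    TfidfVectorizer(ngram_range=(1, 2), stop_words="english"),
    LogisticRegression(),
)
pipeline.fit(texts, labels)
print(pipeline.predict(["great quality, very happy with it"]))  # likely [1]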
Challenges with traditional classification methods
In many cases, manual feature engineering is required to boost performance. This might include creating domain-specific term frequency statistics or handcrafting linguistic features. However, these methods have notable drawbacks:
- They lose important contextual information, such as word order and meaning.
- They struggle with out-of-distribution data - if new words or phrases appear that weren’t seen in training, performance can drop significantly.
- Domain shifts can reduce accuracy. For instance, a sentiment classifier trained on movie reviews might struggle when applied to restaurant reviews because the vocabulary and context change.
In short, traditional classifiers lean heavily on handcrafted features such as TF-IDF weights and n-grams, and in doing so they often lose contextual meaning, making them less effective for nuanced classification tasks.
Encoders vs. decoders in text classification
Transformer models for processing text typically fall into two main categories: encoders and decoders.
- Encoders (e.g., BERT, DeBERTa): Convert text into fixed-length vector representations, capturing contextual information. These models are widely used for classification.
- Decoders (e.g., GPT, LLaMA): Generate text token by token. While primarily used for text generation, they can also be adapted for classification tasks.
Encoders have traditionally been preferred for classification because they extract representations, while decoders generate text, making their outputs less structured. However, modern decoders can be constrained to structured outputs, enabling them to serve as classifiers effectively.
Encoder transformers: The industry standard for text classification
Before LLMs, encoder-based transformers like BERT revolutionized text classification by learning deep contextual representations. Unlike traditional classifiers, BERT processes words in relation to their entire sentence, capturing context and meaning in a way that older methods cannot.
How encoder-based transformers work
Using encoders for text classification
To use encoders for text classification, we prepend a special [CLS] token to the tokenized input text. The encoder processes the sequence through multiple self-attention and feedforward layers. Through self-attention, each token can attend to every other token, letting the model capture long-range dependencies and contextual relationships. Across these layers, the [CLS] token aggregates information from the entire sequence, and its final representation (a dense vector) is passed through a simple feedforward linear layer for classification.

The self-attention mechanism's ability to process all tokens in parallel (unlike RNNs) also makes these models efficient to train on modern hardware. Almost no manual feature engineering is required: they work directly with raw text, which is tokenized by the model's own tokenizer.
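As a concrete illustration, here is how an already fine-tuned encoder classifier is typically used with the Hugging Face transformers library. The SST-2 DistilBERT checkpoint below is just an example; any encoder-based classification checkpoint works the same way.

# Sketch: sentiment classification with a fine-tuned encoder model.
# The checkpoint is an example SST-2 fine-tune of DistilBERT; swap in
# any encoder classification checkpoint you prefer.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

name = "distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForSequenceClassification.from_pretrained(name)

inputs = tokenizer("The product exceeded my expectations!", return_tensors="pt")
with torch.no_grad():
    # The classification head sits on top of the [CLS] representation
    logits = model(**inputs).logits
print(model.config.id2label[logits.argmax(-1).item()])  # e.g. "POSITIVE"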
LLMs as machine learning classifiers
When we talk about Large Language Models today, we're typically referring to decoder-based architectures like LLaMA, Qwen, Mistral, and Phi. These models are trained on massive text datasets containing trillions of tokens, giving them an impressive grasp of language patterns, context, and even domain-specific knowledge.
Despite their reputation for text generation, LLMs can also be adapted for structured tasks like classification. Their ability to understand nuanced language, interpret complex context, and generalize across different domains makes them powerful classifiers when used correctly. However, to fully leverage LLMs for classification, it’s important to understand how they are developed and trained.
How LLMs are trained
The development of LLMs generally follows a three-stage process:
Pretraining: Learning the fundamentals
This is the foundation stage, where the model learns to predict the next token in a sequence. Given a piece of text, the model predicts what comes next, token by token. Through this process, it develops an understanding of grammar, facts, and basic reasoning.
- This stage involves training on vast datasets, often scraped from the internet, books, and research papers.
- You can think of pretraining as compressing the internet into a single model—though the process is lossy, meaning the model doesn’t store raw data but instead learns generalizable patterns.
- However, after this stage, the model is essentially just a text completion engine—if you give it a question, it might simply generate more questions rather than providing useful answers.
Fine-tuning: From text generator to task-specific model
At this stage, the model is refined to perform useful tasks rather than just continuing text sequences. The core training objective remains next-token prediction, but the dataset is replaced with carefully curated human-labeled examples.
- Instead of learning from generic internet data, the model is exposed to instruction-following datasets, improving its ability to generate relevant and structured responses.
- This transforms the model from a text predictor into an assistant that follows commands, making it more useful for real-world applications.
- Human annotations are often used in this phase, though many organizations now rely on synthetically generated fine-tuning data to scale the process.
Alignment: Ensuring the model is safe and useful
In the final stage, the model is aligned with human values and preferences. Techniques such as Reinforcement Learning from Human Feedback (RLHF) and newer methods like DPO, GRPO, and ORPO are used to ensure that responses align with user expectations rather than simply maximizing token probability.
Why LLMs can be effective for text classification
Although LLMs are primarily trained for text generation, their deep language understanding makes them highly capable classifiers - if we can effectively constrain their outputs. Techniques like structured prompting, few-shot learning, and fine-tuning allow us to harness their reasoning capabilities while ensuring predictable outputs.
In the following sections, we’ll explore how to adapt LLMs for classification, turning them from general-purpose text generators into reliable, production-ready classifiers.
But first, let’s briefly examine decoder architectures, as understanding their structure will help explain how we can refine their behavior for classification.
The decoder architecture: Using LLMs for classification
The decoder architecture shares common components with encoders - attention mechanisms, feedforward networks, and layer normalization. However, it's specifically designed for generating text one token at a time. The final layer contains a language modeling head that outputs probabilities across the model's entire vocabulary, from which we sample the most probable next token.
A crucial distinction from encoders is the causal attention mask. This mask prevents the model from looking at future tokens during training, ensuring the model learns to generate text based only on previous context. This architectural choice is fundamental to how these models operate.

A standard decoder architecture
The image above shows the stacked layers culminating in the language modeling head, which provides probability distributions over the model's vocabulary. Text generation happens sequentially, with each generated token being appended to the input for the next prediction.
The challenge of using LLMs for classification
A straightforward approach to classification with LLMs might involve simple prompting, such as:
Classify the given customer review as either "negative" or "positive":

{Review}
This seems easy enough, but it comes with a major flaw: decoder-based LLMs are designed to generate freeform text, making their outputs unpredictable. A model might respond with a paragraph-long explanation instead of the expected "positive" or "negative," making integration into structured applications challenging.
Thankfully, there are effective ways to constrain and optimize LLM outputs for classification tasks. In this post, we’ll explore various techniques, including:
- Prompt engineering for structured outputs
- Few-shot learning to improve reliability
- Fine-tuning to align LLM behavior with classification objectives
We’ll also compare LLM-based classification to traditional encoder-based models and discuss when it makes sense to use one approach over the other. By the end, you’ll have a clear understanding of how to harness LLMs for robust, production-ready classification.
We've already seen how traditional classifiers work and where they fall short, so let's now turn to the concrete techniques for using decoder LLMs as classifiers.
Classification approaches with decoder LLMs
While these models excel at text generation, they can be effectively adapted for classification tasks. Let's explore them:
Basic prompting
For scenarios with limited or no training data, we can directly prompt the model to classify text into predefined categories. For example:
Classify the following text as either "positive" or "negative":

Customer review: "The product exceeded my expectations!"
Few-shot learning
Few-shot prompting enhances classification accuracy by providing examples in the prompt itself. The model learns from these examples to make better predictions. For instance:
Classify the following reviews as positive or negative:

Review: "The screen broke in a week"
Classification: negative

Review: "Fast delivery and great quality"
Classification: positive

Review: "The product exceeded my expectations!"
Classification:
While more few-shot examples generally improve performance, they also increase inference time.
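For completeness, here is one way to send such a few-shot prompt through a chat-style API. The OpenAI client and model name are used purely as an example; any instruction-tuned LLM can be prompted the same way.

# Sketch: few-shot sentiment classification through a chat-style API.
# The OpenAI client, model name, and examples are illustrative only.
from openai import OpenAI

client = OpenAI()

few_shot_prompt = """Classify the following reviews as positive or negative.

Review: "The screen broke in a week"
Classification: negative

Review: "Fast delivery and great quality"
Classification: positive

Review: "The product exceeded my expectations!"
Classification:"""

response = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": few_shot_prompt}],
    temperature=0,   # deterministic output for classification
    max_tokens=2,    # we only want a single label back
)
print(response.choices[0].message.content.strip())  # expected: "positive"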
While basic prompting and few-shot learning offer powerful ways to leverage LLMs for classification, their outputs can be inconsistent. This is where structured outputs come in. By enforcing a predefined format, we can make the outputs of prompted or few-shot LLMs more reliable and suitable for production systems.
Which brings us to ...
Structured outputs with LLMs
While prompting and few-shot learning are powerful, they're not always reliable for production systems. Because these models are inherently chatty, we can't trust them to always return the exact class name. Sometimes the model produces a paragraph explaining why it chose a certain category; other times it returns the result but not in the required format. That makes it tricky to integrate into production applications that expect an exact output.
To alleviate this problem, most LLM providers today support structured outputs. This ensures that LLM outputs conform to specific, predefined formats and schemas rather than free-form text. Enforcing structured output is crucial for classification tasks because it ensures:
- Consistency: The model always returns the classification label in the expected format.
- Reliability: We can trust the output to be a valid class label, making it easy to integrate into downstream systems.
- Efficiency: No need for complex parsing or post-processing of the model's output.
Modern LLM APIs support structured outputs that conform to specific schemas.
Structured outputs using OpenAI's API:
from enum import Enum

from openai import OpenAI
from pydantic import BaseModel

client = OpenAI()

class Sentiment(str, Enum):
    POSITIVE = "positive"
    NEGATIVE = "negative"
    NEUTRAL = "neutral"

class SentimentModel(BaseModel):
    sentiment: Sentiment
    reasoning: str  # Optional field for explanation

completion = client.beta.chat.completions.parse(
    model="gpt-4o",
    messages=[
        {"role": "system", "content": "Classify the below customer review."},
        {"role": "user", "content": "This furniture broke in one month of usage"},
    ],
    response_format=SentimentModel,
)

# .parsed is a SentimentModel instance, so we access fields as attributes
output = completion.choices[0].message.parsed
sentiment = output.sentiment
reasoning = output.reasoning
Frameworks like Outlines or xgrammar achieve structured outputs by guiding the LLM's generation process. They essentially restrict the model's vocabulary at each step, ensuring that the output conforms to the desired format. This is often done through techniques like grammar-based sampling or by modifying the model's logits (probability scores) to only allow valid tokens.
Structured output using Outlines with Llama-3.1-8B:
import outlines
from outlines.samplers import greedy

model = outlines.models.transformers("meta-llama/Llama-3.1-8B", device="cuda")

generator = outlines.generate.choice(
    model,
    ["positive", "negative", "neutral"],
    sampler=greedy(),
)

prompt = "This furniture broke in one month of usage"
label = generator(prompt)  # Returns "negative"
These frameworks work by masking unwanted tokens in the language modeling head, ensuring the model only generates tokens within our specified schema. This makes LLM-based classification reliable and production-ready.
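To make the token-masking idea concrete, here is a rough sketch of single-step constrained decoding: compute the next-token logits, mask everything except the first token of each allowed label, and pick the best remaining option. This is only an illustration of the principle (it assumes the labels start with distinct first tokens), not how Outlines or xgrammar are implemented internally; they track a grammar or schema across multiple decoding steps.

# Conceptual sketch of constrained decoding via logit masking.
# Illustrative only; assumes each label starts with a distinct first token.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

model_name = "Qwen/Qwen2.5-0.5B"  # any small causal LM works for this illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = 'Review: "This furniture broke in one month of usage"\nSentiment:'
inputs = tokenizer(prompt, return_tensors="pt")
with torch.no_grad():
    next_token_logits = model(**inputs).logits[0, -1]  # scores over the whole vocabulary

labels = ["positive", "negative", "neutral"]
# First token id of each label, with a leading space since it follows "Sentiment:"
label_token_ids = [tokenizer.encode(" " + label, add_special_tokens=False)[0] for label in labels]

# Mask out every token that is not an allowed label token
masked_logits = torch.full_like(next_token_logits, float("-inf"))
masked_logits[label_token_ids] = next_token_logits[label_token_ids]

best = masked_logits[label_token_ids].argmax().item()
print(labels[best])  # most probable allowed label, e.g. "negative"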
The combination of powerful language understanding from pretraining and structured output constraints makes modern LLMs excellent candidates for robust classification systems.
If you also want the reasoning field, as in the OpenAI example, you can use JSON decoding as well:
from enum import Enum

from pydantic import BaseModel
import outlines
from outlines.samplers import greedy

model = outlines.models.transformers("meta-llama/Llama-3.1-8B", device="cuda")

class Sentiment(str, Enum):
    POSITIVE = "positive"
    NEGATIVE = "negative"
    NEUTRAL = "neutral"

class SentimentModel(BaseModel):
    sentiment: Sentiment
    reasoning: str  # optional if you want the reasoning

prompt = "This furniture broke in one month of usage"
generator = outlines.generate.json(model, SentimentModel, sampler=greedy())
labels = generator(prompt)
While prompting with structured outputs can be effective, fine-tuning offers another powerful approach to adapt LLMs for classification, especially when you have labeled data available.
Fine-tuning LLMs for classification
While prompting works well, fine-tuning LLMs with labeled data can yield much stronger results. To do this, we slightly modify the final layer of the language modeling head. Instead of predicting over the entire vocabulary, we extract the token embedding with the most sequence information and pass it through a simple feedforward layer, similar to how encoders use the [CLS] token.
However, in decoders, we can't use a [CLS] token because they are causal - they only attend to past tokens, not future ones. This means the first token has the least context, while the last token's embedding holds the most information, making it the best choice for classification.
Padding tokens also require careful handling. In text generation, padding is ignored due to causal masking, but in classification, using the last token's embedding directly can lead to issues if it's a padding token. Special handling ensures the correct token is used.
Let's look at Hugging Face's implementation of LlamaForSequenceClassification, which demonstrates how this is handled for decoder models:
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
                sequence_lengths = sequence_lengths % input_ids.shape[-1]
                sequence_lengths = sequence_lengths.to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)

        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )
Here’s how it works:
- We start with the default decoder base model, LlamaModel.
- Add a simple linear layer on top: self.score = nn.Linear(config.hidden_size, self.num_labels)
- In the forward function, we get the hidden states from the base model - these are contextualized embeddings of each input token.
- We pass these hidden states through self.score to get logits of shape (batch_size, num_input_tokens, num_labels). Of these, we only need the logits of the last non-padding token, since it carries the most information about the whole sequence.
- The crucial part is handling padding tokens. Let's break down this code:
if input_ids is not None:
    # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
    sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    sequence_lengths = sequence_lengths.to(logits.device)
- torch.eq(input_ids, self.config.pad_token_id) compares each element in input_ids with the padding token ID, returning a boolean tensor.
- .int() converts the boolean values to integers (0s and 1s).
- .argmax(-1) finds the first occurrence of the padding token in each sequence.
- The -1 subtraction adjusts the index to point to the last non-padding token (this is important).
- The modulo operation (%) ensures the index stays within the valid range of the sequence length.
- After finding these indices, we get the logits for our entire batch:
pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
- This gives us a tensor of shape (batch_size, num_labels) containing the logits of each sequence’s last non-padding token. Finally, we calculate the cross-entropy loss using our labels to train the model.
To summarize, we take the base decoder model, pass our tokenized input into the decoder and get contextualized embeddings (hidden_states) from the last layer of decoder. We then pass those hidden states through a linear layer to get the logits for each of the input tokens. We find the last non-padding token in the input sequence and take the logits for it and calculate the final cross-entropy loss.
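To see the index arithmetic in action, here is a tiny toy example with made-up token ids, assuming right-padded batches like those produced by DataCollatorWithPadding.

# Toy walkthrough of the padding-index trick above (made-up token ids,
# pad_token_id = 0, right-padded batch of two sequences).
import torch

pad_token_id = 0
input_ids = torch.tensor([
    [5, 9, 3, 0, 0],   # 3 real tokens, last real index = 2
    [7, 2, 8, 4, 6],   # no padding at all
])

# argmax finds the FIRST pad position; subtracting 1 points at the last real token
sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
# For the fully unpadded row, argmax returns 0, so 0 - 1 = -1; the modulo
# maps -1 to seq_len - 1, i.e. the final token of that sequence.
sequence_lengths = sequence_lengths % input_ids.shape[-1]
print(sequence_lengths)  # tensor([2, 4])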
The complete working of the above process is also illustrated in the diagram below:

Fine-tuning your own classification LLM model
Let's walk through a practical, end-to-end tutorial on how to finetune an LLM for classification tasks. We'll use a sentiment analysis task as our example (to keep things simple), but you can adapt this code for any classification problem.
We will need a decent GPU to train these models. If you don't have one locally, you can rent GPUs cheaply from cloud providers like Jarvislabs.ai. For this tutorial, we used an A5000 instance from Jarvislabs.ai.
💡
If you want the complete code for this tutorial to try it out and take it for a spin, it's on my GitHub Repo.
💡
Loading the data
First, we need data to work with. For this tutorial, we'll use the IMDB movie reviews dataset from Hugging Face, which is perfect for binary sentiment classification. Of course, you can substitute this with any dataset that matches your specific use case.
from datasets import load_dataset

ds = load_dataset("imdb")

# get the labels in the dataset
label_list = ds["train"].features["label"].names
num_labels = len(label_list)

# split into training and validation dataset
train_ds, eval_ds = ds["train"], ds["test"]
The IMDB dataset is well-balanced, containing:
- 25,000 training samples
- 25,000 testing samples
- Binary labels: positive (1) and negative (0)
Processing the data
Now we need to transform our text data into a format our LLM can understand.
For this tutorial, we use a small model (Qwen/Qwen2.5-0.5B) for faster iterations and lower resource requirements. In a real-world scenario, you'd likely achieve better results with a larger model (such as Qwen2.5 7B or bigger) at the cost of increased training time and computational resources.
💡
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
Now let's tokenize our dataset, truncating longer sequences to 512 tokens. The tokenize_func below defines how a single example is tokenized, and we apply it to the whole dataset with the datasets map function, which runs it efficiently in batches.
def tokenize_func(example):
    return tokenizer(
        example["text"],
        add_special_tokens=True,
        truncation=True,
        max_length=512,
    )

train_ds = train_ds.map(tokenize_func, batched=True, desc="Tokenizing train data")
eval_ds = eval_ds.map(tokenize_func, batched=True, desc="Tokenizing eval data")
The tokenization process adds two new columns to our dataset:
- input_ids: The integer tokens representing our text
- attention_mask: A binary mask indicating which tokens are real (1) vs padding (0)
This processed format is what our LLM model expects as input. The input_ids will be used to get token embeddings, while the attention_mask ensures the model only pays attention to actual content and not padding.
{'text': 'I first saw "Breaking Glass" in 1980, and thought that it would be one of the "Movie Classics". This film is a great look into the music industry with a great cast of performers. This is one film that should be in the collection of everyone and any one that wants to get into the music industry. I can\'t wait for it to be available on DVD.',
 'label': 1,
 'input_ids': [40, 1156, 5485, 330, 60179, 20734, 1, 304, 220, 16, 24, 23, 15, 11, 323, 3381, 429, 432, 1035, 387, 825, 315, 279, 330, 19668, 72315, 3263, 1096, 4531, 374, 264, 2244, 1401, 1119, 279, 4627, 4958, 448, 264, 2244, 6311, 315, 44418, 13, 1096, 374, 825, 4531, 429, 1265, 387, 304, 279, 4426, 315, 5019, 323, 894, 825, 429, 6801, 311, 633, 1119, 279, 4627, 4958, 13, 358, 646, 944, 3783, 369, 432, 311, 387, 2500, 389, 18092, 13],
 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
Loading the model
For our classification task, we'll use Hugging Face’s AutoModelForSequenceClassification, which already incorporates the necessary architectural modifications. However, fine-tuning large models can be memory-intensive, requiring high-end GPUs.
To make fine-tuning more efficient and accessible, we use QLoRA (Quantized Low-Rank Adaptation) - a combination of two powerful techniques:
- LoRA (Low-Rank Adaptation), which fine-tunes small low-rank adapter matrices instead of the full model weights.
- 4-bit quantization, which reduces memory usage.
Together, these optimizations dramatically reduce VRAM requirements, making it possible to fine-tune even large LLMs (e.g., LLaMA-3, Qwen-2.5, Mistral) on consumer hardware.

At the heart of LoRA is a simple but powerful mathematical concept: instead of updating all model parameters, LoRA adds small, trainable "rank decomposition matrices" to existing model weights. These matrices allow the base model to remain frozen, while fine-tuning occurs only on the added layers.
Example: LoRA parameter reduction
Imagine a model layer with a 512 × 512 weight matrix (which has 262,144 parameters). Instead of fine-tuning the entire matrix, LoRA factorizes it into two much smaller matrices:
- A 512 × 8 matrix
- An 8 × 512 matrix
- Total trainable parameters: 8,192 instead of 262,144 (a 32× reduction!)
This low-rank adaptation preserves most of the original model’s knowledge while making fine-tuning dramatically cheaper and faster.
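Here is a minimal, self-contained sketch of that idea (illustrative only, not the peft implementation): a frozen 512 × 512 base weight plus trainable rank-8 factors, scaled by alpha/r.

# Minimal LoRA sketch (illustrative, not the peft implementation):
# a frozen base weight plus a trainable low-rank update, scaled by alpha / r.
import torch
import torch.nn as nn

class LoRALinear(nn.Module):
    def __init__(self, in_features=512, out_features=512, r=8, alpha=16):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)
        self.base.weight.requires_grad = False                 # base weights stay frozen
        self.lora_A = nn.Linear(in_features, r, bias=False)    # 512 x 8 factor
        self.lora_B = nn.Linear(r, out_features, bias=False)   # 8 x 512 factor
        nn.init.zeros_(self.lora_B.weight)                     # the update starts at zero
        self.scaling = alpha / r

    def forward(self, x):
        return self.base(x) + self.scaling * self.lora_B(self.lora_A(x))

layer = LoRALinear()
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
print(trainable)  # 8192 trainable vs. 262144 frozen base parameters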
The role of QLoRA: Adding 4-Bit quantization
LoRA alone reduces trainable parameters, but QLoRA takes it a step further by reducing memory usage.
How? It stores model weights in 4-bit precision instead of the usual 16-bit or 32-bit format.
- Reduces GPU memory footprint (critical for large models).
- Enables fine-tuning of 7B+ parameter models on a single GPU.
- Maintains near full-precision performance despite lower-bit storage.
While quantization may not seem significant for a 0.5B parameter model, it becomes essential when dealing with 7B, 14B, or 32B parameter models, making large-scale fine-tuning accessible without expensive hardware.
If this sounds complex, don’t worry - the Hugging Face peft and bitsandbytes libraries have made LoRA and QLoRA simple to integrate into training pipelines. As we'll see next, these tools allow you to fine-tune massive LLMs with just a few lines of code, leveraging both LoRA efficiency and 4-bit quantization.
import torch
from transformers import (
    BitsAndBytesConfig,
    AutoModelForSequenceClassification,
)
from peft import (
    LoraConfig,
    TaskType,
    get_peft_model,
    prepare_model_for_kbit_training,
)

# loading the model
model_kwargs = dict(
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2",
)
model_kwargs["quantization_config"] = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="bfloat16",
)

model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen2.5-0.5B", use_cache=False, num_labels=num_labels, **model_kwargs
)
model.config.pad_token_id = tokenizer.pad_token_id

lora_config = LoraConfig(
    r=64,
    lora_alpha=64,
    target_modules="all-linear",
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.SEQ_CLS,
    modules_to_save=["score"],
)

model = prepare_model_for_kbit_training(model)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
Let's go through the above code step-by-step:
First, we define some keyword arguments that we would like to pass to the model:
- torch_dtype: we can use bfloat16 on modern GPU architectures (Ampere, Hopper) since they support it; on older architectures (Pascal, Tesla) we need to use float16.
- attn_implementation: a transformer stacks many self-attention layers, and self-attention is compute-intensive, with quadratic time and memory complexity in the sequence length. FlashAttention speeds this up through kernel fusion; have a look at the FlashAttention paper for an in-depth treatment.
- Next we define the quantization config, which tells the library to load the model weights in 4-bit precision.
- We also set the padding token on the model, since many models don't define one by default and we need it to run the model in batches, as we learned earlier.
Then we define the LoRA configuration:
- r: the rank of LoRA's internal low-rank matrices. Higher values generally perform better at the cost of more trainable parameters.
- alpha: a scaling factor applied to the low-rank weight updates; the effective scaling during the forward pass is alpha/rank.
- target_modules: which layers LoRA is applied to. It is generally applied to the model's linear layers. You can choose a subset of them (for example, only the attention layers), but applying LoRA to all linear layers generally gives the best results.
- lora_dropout: randomly drops (sets to zero) activations in LoRA's low-rank branch during training, which has a regularizing effect.
- bias: controls how bias terms are handled during fine-tuning; none means no bias parameters are trained.
- task_type: as we are performing sequence classification, we set the task type to TaskType.SEQ_CLS.
- modules_to_save: specifies additional modules to be trained and saved alongside the LoRA layers. This is typically used for custom model heads that are randomly initialized for the fine-tuning task. Since our classification head score is such a newly initialized final layer, this parameter tells peft to fine-tune it and save it along with the LoRA layers.
If you want a very detailed report on what each parameter does and much more, I highly recommend checking out Sebastian Raschka's blog post.
Then we prepare the model for LoRA finetuning. At this point the model architecture looks like this:
PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Qwen2ForSequenceClassification(
      (model): Qwen2Model(
        (embed_tokens): Embedding(151936, 896)
        (layers): ModuleList(
          (0-23): 24 x Qwen2DecoderLayer(
            (self_attn): Qwen2Attention(
              (q_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=896, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (k_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=128, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=128, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (v_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=128, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=128, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (o_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=896, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
            )
            (mlp): Qwen2MLP(
              (gate_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=4864, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4864, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (up_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=896, out_features=4864, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=896, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=4864, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (down_proj): lora.Linear4bit(
                (base_layer): Linear4bit(in_features=4864, out_features=896, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4864, out_features=64, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=64, out_features=896, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
              (act_fn): SiLU()
            )
            (input_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
            (post_attention_layernorm): Qwen2RMSNorm((896,), eps=1e-06)
          )
        )
        (norm): Qwen2RMSNorm((896,), eps=1e-06)
        (rotary_emb): Qwen2RotaryEmbedding()
      )
      (score): ModulesToSaveWrapper(
        (original_module): Linear(in_features=896, out_features=2, bias=False)
        (modules_to_save): ModuleDict(
          (default): Linear(in_features=896, out_features=2, bias=False)
        )
      )
    )
  )
)
Note how the peft library has added the lora.Linear4bit layers, which combine the LoRA matrices with 4-bit quantization.
Furthermore, we can see how dramatically LoRA has reduced the number of trainable parameters by calling model.print_trainable_parameters(). The following is the output:
trainable params: 35,194,624 || all params: 529,229,184 || trainable%: 6.6502
Now we are almost ready to train our model. Before that, we should define a function that computes some useful metrics on our evaluation data. We will use Hugging Face's evaluate library, which provides ready-to-use metric implementations.
import numpy as np
import evaluate

# Load all metrics
accuracy = evaluate.load("accuracy")
f1_metric = evaluate.load("f1")
precision_metric = evaluate.load("precision")
recall_metric = evaluate.load("recall")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    # Compute all metrics
    accuracy_score = accuracy.compute(predictions=predictions, references=labels)["accuracy"]
    f1 = f1_metric.compute(predictions=predictions, references=labels, average="weighted")["f1"]
    precision = precision_metric.compute(predictions=predictions, references=labels, average="weighted")["precision"]
    recall = recall_metric.compute(predictions=predictions, references=labels, average="weighted")["recall"]

    return {
        "accuracy": accuracy_score,
        "f1": f1,
        "precision": precision,
        "recall": recall,
    }
Let's define the training arguments:
from transformers import Trainer, TrainingArguments, DataCollatorWithPadding

training_args = TrainingArguments(
    output_dir="artifacts",
    eval_strategy="epoch",
    save_strategy="no",
    logging_strategy="steps",
    logging_steps=1,
    learning_rate=1e-4,
    per_device_train_batch_size=32,
    per_device_eval_batch_size=32,
    gradient_accumulation_steps=1,
    lr_scheduler_type="cosine",
    warmup_ratio=0.05,
    bf16=torch.cuda.is_bf16_supported(),
    fp16=not torch.cuda.is_bf16_supported(),
    bf16_full_eval=torch.cuda.is_bf16_supported(),
    fp16_full_eval=not torch.cuda.is_bf16_supported(),
    report_to="wandb",
    gradient_checkpointing=True,
    group_by_length=True,
    torch_compile=True,
    max_grad_norm=1.0,
    weight_decay=0.01,
)

trainer = Trainer(
    model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=eval_ds,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
    data_collator=DataCollatorWithPadding(tokenizer),
)
We can now train our model:
trainer.train()
trainer.save_model("artifacts")
Let's look at some of the metrics:
We can see the loss decreasing smoothly, which is a sign that our model is training well. Let's look at some evaluation metrics:
We trained the model for three epochs, and it already achieves 95.2% accuracy. This could be improved even further by using a larger LLM (Qwen2.5 7B or a higher-parameter model); here we use the 0.5B model purely for demonstration.
Note that trainer.save_model will only save the adapters and not the complete model.
If you want to merge the adapters into the base model so it can be used in production, you can do so as follows:
from peft import PeftModel

# Merge LoRA adapters into the base model and save the result
adapter_path = "artifacts"        # directory where the adapters were saved (see trainer.save_model above)
merged_path = "artifacts/merged"  # where to write the merged, standalone model

peft_model = PeftModel.from_pretrained(model, adapter_path)
merged_model = peft_model.merge_and_unload()
merged_model.save_pretrained(merged_path, safe_serialization=True, max_shard_size="2GB")
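Once merged, the model loads like any other sequence-classification checkpoint. Here is a minimal inference sketch; the merged_path and the positive/negative label mapping follow this tutorial's setup.

# Sketch: running inference with the merged classifier.
# Paths, tokenizer, and label mapping follow this tutorial's setup; adjust as needed.
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

merged_path = "artifacts/merged"
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")
model = AutoModelForSequenceClassification.from_pretrained(merged_path)
model.eval()

review = "One of the best movies I have seen this year!"
inputs = tokenizer(review, return_tensors="pt", truncation=True, max_length=512)
with torch.no_grad():
    logits = model(**inputs).logits
print("positive" if logits.argmax(-1).item() == 1 else "negative")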
This is how easy it is to finetune an LLM for classification tasks!
When should you use LLMs as machine learning classifiers?
We've walked through the process of transforming LLMs into effective classifiers. Now let's address an important question: when does it make sense to use them?
LLMs are undoubtedly powerful, but they come with significant computational overhead. They require substantial hardware resources, have higher latency, and can be expensive to deploy and maintain. These factors make it important to carefully consider when they're the right tool for the job.
For most standard classification tasks, lighter encoder models like DeBERTa or ModernBERT are often the better choice. These models are efficient, fast to run, and much easier to deploy. They've proven themselves highly capable for a wide range of classification tasks while maintaining a smaller computational footprint.
However, LLMs as classifiers shine in scenarios that require both deep language understanding and a bit of reasoning. Consider a task where you need to classify whether a given solution to a math word problem is correct. This requires not just understanding the problem statement, but also following the solution steps, catching logical errors, and verifying if the approach makes sense. Traditional classifiers might struggle here, while LLMs can handle this reasoning chain naturally.
To give you a real-world example, in the recent LMSYS Chatbot Arena competition on Kaggle, participants had to predict which of two model responses was better for a given prompt. This required evaluating response coherence, factual accuracy, and reasoning quality - aspects where traditional classifiers often fall short. The winning solutions predominantly used LLMs as classifiers, significantly outperforming encoder-based approaches like DeBERTa.
Furthermore, in RLHF (Reinforcement Learning from Human Feedback) reward modeling, LLMs help evaluate responses by reasoning about their alignment with human preferences and logical coherence. This combination of understanding and reasoning makes them particularly effective for such sophisticated classification tasks.
The key is to match the tool to the task. While LLMs offer sophisticated reasoning capabilities, they should be deployed as classifiers primarily when the task demands these capabilities and the benefits outweigh the deployment costs. For many classification tasks, simpler and more efficient models will serve just as well, if not better.
Conclusion
We explored two approaches to using Large Language Models (LLMs) effectively for text classification, diving into both fine-tuning and prompting techniques. For fine-tuning, we learned how to modify an LLM's architecture by replacing its language modeling head with a classification head. For prompting, we covered both basic prompting and prompting with structured outputs, and saw how these methods can create reliable classification systems without any model training. We hope this helps you make informed decisions about when and how to use LLMs for your classification needs.