Building an SMS Spam Classifier with PyTorch and Hugging Face
Introduction
I recently built a simple SMS spam classifier using PyTorch, scikit-learn, and the Hugging Face sms_spam dataset. The project trains a small feed-forward neural network (one hidden layer followed by a sigmoid output unit, so the last layer is effectively logistic regression on learned features) to classify each SMS message as spam or ham (not spam).
The motivation was to understand the end-to-end pipeline of an NLP binary classification task using my own custom model and code, rather than relying solely on pre-trained transformers.
Model Architecture and Workflow
The neural network (CircleModelV0) consists of:
- Input Layer: 8713 features, one per vocabulary term, produced by applying CountVectorizer to the SMS texts.
- Hidden Layer: 64 neurons with ReLU activation to introduce non-linearity.
- Output Layer: a single neuron with sigmoid activation that outputs the probability that a message is spam (above 0.5 means spam, otherwise ham).
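In equation form, the layers listed above compute the following for a count vector x with 8713 entries, and the layer shapes also fix the number of trainable parameters. This is a worked check using the stated 8713-feature vocabulary; the symbols W_1, b_1, W_2, b_2 are notation introduced here for the weight and bias of each nn.Linear layer.

```latex
\begin{aligned}
h &= \mathrm{ReLU}(W_1 x + b_1), && W_1 \in \mathbb{R}^{64 \times 8713},\ b_1 \in \mathbb{R}^{64} \\
\hat{y} &= \sigma(W_2 h + b_2), && W_2 \in \mathbb{R}^{1 \times 64},\ b_2 \in \mathbb{R} \\
N_{\text{params}} &= (8713 \cdot 64 + 64) + (64 \cdot 1 + 1) = 557{,}696 + 65 = 557{,}761
\end{aligned}
```

Almost all of the roughly 558k parameters sit in the first layer, so the vocabulary size drives the model size. In PyTorch, sum(p.numel() for p in model.parameters()) should report the same total for an 8713-feature input.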
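```python
import torch
import torch.nn as nn


class CircleModelV0(nn.Module):
    """Bag-of-words MLP: input_size -> 64 (ReLU) -> 1 (sigmoid)."""

    def __init__(self, input_size: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_size, 64),  # vocabulary counts -> 64 hidden units
            nn.ReLU(),                  # non-linearity
            nn.Linear(64, 1),           # hidden units -> a single logit
            nn.Sigmoid(),               # logit -> probability of spam
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)
```
Training Workflow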
1. Dataset Loading
The dataset, 5,574 SMS messages labeled as spam or ham, was loaded with Hugging Face's datasets library. The code loads the train split and later holds out part of it as a test set with train_test_split.
```python
from datasets import load_dataset

dataset = load_dataset("ucirvine/sms_spam", split="train")
texts = dataset["sms"]
labels = dataset["label"]  # 0 = ham, 1 = spam
```
2. Preprocessing with scikit-learn
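The SMS texts were converted into numerical vectors using CountVectorizer, resulting in an 8713-dimensional sparse feature space. Note that the vectorizer is fitted on all 5,574 messages before the train/test split, so its vocabulary also contains words that appear only in test messages; fitting it on the training texts alone would give a stricter evaluation (and would change the feature count).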
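To make the feature space concrete, the following stdlib-only sketch approximates what CountVectorizer() does with its default settings: lowercase each message, extract tokens of two or more word characters with the default token_pattern (?u)\b\w\w+\b, give each distinct token a column in sorted order, and count occurrences. The helper names are hypothetical and this is an illustration, not scikit-learn's implementation; the real class returns a sparse matrix rather than dense lists.

```python
import re
from collections import Counter

# Default token_pattern of scikit-learn's CountVectorizer: runs of 2+ word characters
TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")


def tokenize(text: str) -> list[str]:
    # CountVectorizer lowercases by default before tokenizing
    return TOKEN_RE.findall(text.lower())


def fit_vocabulary(texts: list[str]) -> dict[str, int]:
    # One column per distinct token, indexed in sorted (alphabetical) order
    terms = sorted({tok for text in texts for tok in tokenize(text)})
    return {term: index for index, term in enumerate(terms)}


def transform(texts: list[str], vocab: dict[str, int]) -> list[list[int]]:
    # Dense count rows; tokens missing from the fitted vocabulary are dropped
    rows = []
    for text in texts:
        row = [0] * len(vocab)
        for tok, count in Counter(tokenize(text)).items():
            if tok in vocab:
                row[vocab[tok]] = count
        rows.append(row)
    return rows


vocab = fit_vocabulary(["Free prize! Call now", "are you free tonight?"])
print(vocab)  # {'are': 0, 'call': 1, 'free': 2, 'now': 3, 'prize': 4, 'tonight': 5, 'you': 6}
print(transform(["free free money"], vocab))  # [[0, 0, 2, 0, 0, 0, 0]]
```

The dropped "money" token mirrors the real vectorizer, whose transform ignores words not seen during fitting. One consequence for SMS text: the default pattern also discards single-character tokens such as "u" or "2".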
```python
from sklearn.feature_extraction.text import CountVectorizer

vectorizer = CountVectorizer()
X = vectorizer.fit_transform(texts)  # sparse matrix: one row per message, one column per term
```
3. Data Preparation for PyTorch
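Before splitting, it helps to check the label distribution. The UCI description of this corpus reports 4,827 ham and 747 spam messages (about 13% spam), so a model that always answers "ham" already reaches roughly 86.6% accuracy; passing stratify=labels to train_test_split keeps that ratio the same in both splits. A small sketch with hypothetical helper names:

```python
from collections import Counter


def class_balance(labels):
    """Fraction of examples per label, keyed in sorted label order."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: count / total for label, count in sorted(counts.items())}


def majority_baseline(labels):
    """Accuracy of a classifier that always predicts the most common label."""
    return max(class_balance(labels).values())


# Usage with the labels loaded in step 1:
# print(class_balance(labels))      # fraction of ham (0) and spam (1)
# print(majority_baseline(labels))  # accuracy a useful model must beat
```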
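```python
import torch
from sklearn.model_selection import train_test_split

# 80/20 split; random_state fixes the shuffle so the split is reproducible,
# and stratify keeps the spam/ham ratio the same in both parts
X_train, X_test, y_train, y_test = train_test_split(
    X, labels, test_size=0.2, random_state=42, stratify=labels
)

# Densify the sparse count matrices into float32 tensors
X_train_t = torch.tensor(X_train.toarray(), dtype=torch.float32)
X_test_t = torch.tensor(X_test.toarray(), dtype=torch.float32)
# Labels as (N, 1) column vectors to match the model's output shape
y_train_t = torch.tensor(y_train, dtype=torch.float32).unsqueeze(1)
y_test_t = torch.tensor(y_test, dtype=torch.float32).unsqueeze(1)
```
4. Model Definition and Training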
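The loop below minimizes nn.BCELoss, which with its default reduction="mean" is the average binary cross-entropy over the N training messages, where p_i is the sigmoid output for message i and y_i is its label (1 = spam, 0 = ham):

```latex
\mathcal{L} = -\frac{1}{N} \sum_{i=1}^{N} \Big[ y_i \log p_i + (1 - y_i) \log (1 - p_i) \Big]
```

Each iteration feeds the entire training tensor through the model, so one epoch is a single gradient step and the 100 epochs amount to 100 Adam updates.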
```python
model = CircleModelV0(input_size=X_train_t.shape[1])
criterion = nn.BCELoss()  # binary cross-entropy on the sigmoid probabilities
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
    model.train()
    optimizer.zero_grad()
    preds = model(X_train_t)  # full batch: every training message at once
    loss = criterion(preds, y_train_t)
    loss.backward()
    optimizer.step()

    if epoch % 10 == 0:
        # Training accuracy at a 0.5 threshold, using this step's predictions
        acc = ((preds > 0.5).float() == y_train_t).float().mean().item()
        print(f"Epoch {epoch} | Loss: {loss.item():.4f} | Acc: {acc:.4f}")
```
5. Saving the Model and Vectorizer
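Before saving, it is worth scoring the held-out split: the training loop only reports training accuracy, and X_test_t and y_test_t are prepared but never used. Because spam is the minority class, precision and recall on spam say more than accuracy alone. A minimal stdlib sketch; binary_metrics is a hypothetical helper, not part of the original scripts:

```python
def binary_metrics(y_true, y_prob, threshold=0.5):
    """Accuracy plus precision, recall and F1 for the spam class (label 1)."""
    tp = fp = fn = tn = 0
    for y, p in zip(y_true, y_prob):
        pred = 1 if p > threshold else 0  # same 0.5 cut-off as the training loop
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1:
            fp += 1
        elif y == 1:
            fn += 1
        else:
            tn += 1
    total = tp + fp + fn + tn
    if total == 0:
        raise ValueError("no examples to score")
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": (tp + tn) / total, "precision": precision, "recall": recall, "f1": f1}


# Usage after training (requires the model and tensors from the steps above):
# model.eval()
# with torch.no_grad():
#     test_probs = model(X_test_t).squeeze(1).tolist()
# print(binary_metrics(y_test_t.squeeze(1).tolist(), test_probs))
```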
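```python
import pickle

# Saves the weights only (a state_dict), despite the "full_model" file name
torch.save(model.state_dict(), "full_model.pth")

# The fitted vectorizer must be saved too, so inference uses the same vocabulary
with open("vectorizer.pkl", "wb") as f:
    pickle.dump(vectorizer, f)
```
Running Inference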
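Inference depends on the pickled vectorizer: its transform only counts words seen during fitting and silently drops everything else, and a message made entirely of unseen words becomes an all-zero vector that the model maps to the same probability every time. A small diagnostic sketch (hypothetical helper oov_fraction, reusing CountVectorizer's default token pattern) reports how much of a message the model can actually see:

```python
import re

# CountVectorizer's default token_pattern; the vectorizer lowercases text before matching
TOKEN_RE = re.compile(r"(?u)\b\w\w+\b")


def oov_fraction(text: str, vocabulary: dict[str, int]) -> float:
    """Share of tokens in text that are missing from the fitted vocabulary.

    Returns 1.0 when the text yields no tokens at all, since the model then
    sees an all-zero vector.
    """
    tokens = TOKEN_RE.findall(text.lower())
    if not tokens:
        return 1.0
    unknown = sum(1 for tok in tokens if tok not in vocabulary)
    return unknown / len(tokens)


# Usage with the pickled vectorizer (vocabulary_ maps each term to its column):
# print(oov_fraction("you have won a free prize", vectorizer.vocabulary_))
```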
A separate inference.py script loads the saved model and vectorizer for easy prediction:
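```python
import pickle

import torch

from model import CircleModelV0  # the model class, saved in model.py

with open("vectorizer.pkl", "rb") as f:
    vectorizer = pickle.load(f)

# Derive the input size from the fitted vocabulary instead of hard-coding 8713
model = CircleModelV0(input_size=len(vectorizer.vocabulary_))
model.load_state_dict(torch.load("full_model.pth", map_location="cpu"))
model.eval()


def predict(text: str) -> str:
    # Vectorize with the training vocabulary; unseen words are ignored
    X = torch.tensor(vectorizer.transform([text]).toarray(), dtype=torch.float32)
    with torch.no_grad():
        prob = model(X).item()  # probability that the message is spam
    return "Spam" if prob > 0.5 else "Ham"


print(predict("you have won a free prize"))  # → Spam
print(predict("hey are you free tonight?"))  # → Ham
```
Project Repository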
The complete training and inference code is in this repository:
milliyin/sms-spam-model-train
Benefits of This Project
- Hands-on PyTorch Training: defined and trained a custom neural network in PyTorch instead of using a pre-built classifier.
- Clear NLP Workflow Understanding: learned how to turn raw text into model-ready tensors.
- Efficient Inference Pipeline: a standalone script reloads the trained weights and vectorizer for quick predictions.
- Utilized Public Datasets: loading the data from the Hugging Face Hub means anyone can fetch the same dataset and rerun the training.
Conclusion
This SMS spam classification project deepened my understanding of NLP preprocessing, PyTorch model training, and building a reusable inference pipeline. Such projects bridge the gap between theoretical knowledge and real-world applications, providing a strong foundation for building more advanced models in the future.