Skip to content

Commit

Permalink
First commit
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Aug 15, 2023
1 parent 38f18b3 commit 4c8d3ac
Show file tree
Hide file tree
Showing 4 changed files with 100 additions and 9 deletions.
1 change: 0 additions & 1 deletion docs/source/conf.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import os
import sys
from types import ModuleType

sys.path.insert(0, os.path.abspath("."))
sys.path.insert(0, os.path.abspath("../.."))
Expand Down
34 changes: 26 additions & 8 deletions docs/source/examples/plot_fista.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,35 @@
"""
Lasso Regression with FISTA
===========================
Lasso Regression with FISTA Optimization
========================================
In this example, we'll implement Lasso regression using the FISTA algorithm,
leveraging `proxtorch` for the proximal operations.
"""
This script demonstrates how to implement Lasso regression using the
Fast Iterative Shrinkage-Thresholding Algorithm (FISTA). Lasso regression
is a method in linear regression that incorporates L1 regularization,
leading to a sparser solution where many coefficients are set to zero.
import numpy as np
import torch
By using FISTA, an accelerated gradient-based optimization technique, we
can achieve faster convergence in solving the Lasso problem compared to
standard gradient descent methods.
Key Highlights:
- Leverages the `proxtorch` library for efficient proximal operations.
- Creates synthetic data using `sklearn.datasets.make_regression`.
- Defines and runs the FISTA algorithm for Lasso regression.
- Visualizes the non-zero coefficients of the learned Lasso model.
Dependencies:
- `numpy`
- `torch`
- `matplotlib`
- `proxtorch`
- `sklearn.datasets`
"""
import matplotlib.pyplot as plt
from proxtorch.operators import L1Prox
import torch
from sklearn.datasets import make_regression

from proxtorch.operators import L1Prox

# Create synthetic data
X, y = make_regression(n_samples=100, n_features=20, noise=0.1)
X, y = torch.tensor(X, dtype=torch.float32), torch.tensor(y, dtype=torch.float32)
Expand Down
46 changes: 46 additions & 0 deletions docs/source/examples/plot_lasso_regression.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,49 @@
"""
Comparing Custom Lasso Regression with scikit-learn's Implementation
====================================================================
This script demonstrates how to implement and train a Lasso regression model using
PyTorch Lightning and compares it with scikit-learn's built-in Lasso regression.
Lasso regression is a linear regression variant that incorporates L1 regularization,
leading to sparse weight vectors. In other words, many weights become exactly zero,
allowing for simpler and more interpretable models.
In this example:
- A custom `LassoRegression` class is defined using PyTorch Lightning, which
incorporates the L1 regularization via a proximal gradient method.
- Synthetic data is generated where the ground truth weights are partly set to zero
to mimic sparse structures.
- Both the custom Lasso model and scikit-learn's Lasso model are trained on the
synthetic data.
- The models' performances are compared based on their mean squared error (MSE) on
a test set.
- The predicted values of both models are visualized against the true values for
a comparative look.
- Finally, the learned weights from both models are compared with the true weights
through bar plots.
By the end of this script, you should have insights into how Lasso regression can
be implemented in PyTorch Lightning and how its performance matches up against
traditional implementations in packages like scikit-learn.
Dependencies:
- `torch`
- `torch.nn`
- `torch.optim`
- `sklearn.model_selection`
- `sklearn.linear_model`
- `numpy`
- `matplotlib`
- `pytorch_lightning`
- `proxtorch`
"""
import torch
import torch.nn as nn
import torch.optim as optim
Expand Down
28 changes: 28 additions & 0 deletions docs/source/examples/plot_robust_pca.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,31 @@
"""
Robust Principal Component Analysis with PyTorch Lightning
===========================================================
This script demonstrates how to perform Robust Principal Component Analysis (RPCA) using
PyTorch Lightning. RPCA decomposes a matrix into two components:
1. A low-rank matrix that captures the global structure.
2. A sparse matrix that identifies the sparse errors.
The goal of RPCA is to find the best low-rank and sparse matrices that, when combined, closely
approximate the original matrix.
In this example:
- A custom `RobustPCA` class is defined using PyTorch Lightning, which learns the low-rank and sparse matrices.
- A `RandomMatrixDataset` class is designed to generate synthetic matrices composed of a true low-rank matrix and a true sparse matrix.
- The model is trained to approximate these matrices.
- The true and learned matrices are visualized for comparison.
By the end of this script, you will have a clear idea of how to implement and visualize RPCA using PyTorch Lightning.
Dependencies:
- `pytorch_lightning`
- `torch`
- `matplotlib`
- `proxtorch`
"""
import pytorch_lightning as pl
import torch
from torch.utils.data import DataLoader, Dataset
Expand Down

0 comments on commit 4c8d3ac

Please sign in to comment.