Skip to content

Commit

Permalink
Add VIX table to database.
Browse files Browse the repository at this point in the history
  • Loading branch information
Francisco Silva authored and Francisco Silva committed Nov 12, 2024
1 parent f323fa5 commit 89cffbe
Show file tree
Hide file tree
Showing 10 changed files with 185 additions and 69 deletions.
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ jobs:
- name: Lint with ruff
run: |
ruff check .
ruff format --check .
- name: Run tests
run: |
Expand Down
44 changes: 44 additions & 0 deletions notebooks/eda.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -3058,6 +3058,50 @@
"fig = px.bar(aapl_df, x=\"rdq\", y=\"roa_qoq\")\n",
"fig.show()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div><style>\n",
".dataframe > thead > tr,\n",
".dataframe > tbody > tr {\n",
" text-align: right;\n",
" white-space: pre-wrap;\n",
"}\n",
"</style>\n",
"<small>shape: (5, 5)</small><table border=\"1\" class=\"dataframe\"><thead><tr><th>date</th><th>close</th><th>adj_close</th><th>volume</th><th>tic</th></tr><tr><td>date</td><td>f64</td><td>f64</td><td>i64</td><td>str</td></tr></thead><tbody><tr><td>2005-01-03</td><td>1202.079956</td><td>1202.079956</td><td>1510800000</td><td>&quot;^GSPC&quot;</td></tr><tr><td>2005-01-04</td><td>1188.050049</td><td>1188.050049</td><td>1721000000</td><td>&quot;^GSPC&quot;</td></tr><tr><td>2005-01-05</td><td>1183.73999</td><td>1183.73999</td><td>1738900000</td><td>&quot;^GSPC&quot;</td></tr><tr><td>2005-01-06</td><td>1187.890015</td><td>1187.890015</td><td>1569100000</td><td>&quot;^GSPC&quot;</td></tr><tr><td>2005-01-07</td><td>1186.189941</td><td>1186.189941</td><td>1477900000</td><td>&quot;^GSPC&quot;</td></tr></tbody></table></div>"
],
"text/plain": [
"shape: (5, 5)\n",
"┌────────────┬─────────────┬─────────────┬────────────┬───────┐\n",
"│ date ┆ close ┆ adj_close ┆ volume ┆ tic │\n",
"│ --- ┆ --- ┆ --- ┆ --- ┆ --- │\n",
"│ date ┆ f64 ┆ f64 ┆ i64 ┆ str │\n",
"╞════════════╪═════════════╪═════════════╪════════════╪═══════╡\n",
"│ 2005-01-03 ┆ 1202.079956 ┆ 1202.079956 ┆ 1510800000 ┆ ^GSPC │\n",
"│ 2005-01-04 ┆ 1188.050049 ┆ 1188.050049 ┆ 1721000000 ┆ ^GSPC │\n",
"│ 2005-01-05 ┆ 1183.73999 ┆ 1183.73999 ┆ 1738900000 ┆ ^GSPC │\n",
"│ 2005-01-06 ┆ 1187.890015 ┆ 1187.890015 ┆ 1569100000 ┆ ^GSPC │\n",
"│ 2005-01-07 ┆ 1186.189941 ┆ 1186.189941 ┆ 1477900000 ┆ ^GSPC │\n",
"└────────────┴─────────────┴─────────────┴────────────┴───────┘"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from pipeline import Scraper\n",
"\n",
"vix = Scraper(\"^GSPC\", \"yfinance\").get_market_data(\"2005-01-01\")\n",
"vix.head()\n"
]
}
],
"metadata": {
Expand Down
6 changes: 4 additions & 2 deletions stocksense/config/db_config.yml
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,7 @@
'sp500':
- 'date'
- 'close'
- 'adj_close'
- 'volume'
- 'volume'
'vix':
- 'date'
- 'close'
2 changes: 2 additions & 0 deletions stocksense/database_handler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
insert_record,
update_data,
delete_data,
delete_table,
count_data,
fetch_record,
fetch_data,
Expand All @@ -19,6 +20,7 @@
"insert_record",
"update_data",
"delete_data",
"delete_table",
"count_data",
"fetch_record",
"fetch_data",
Expand Down
18 changes: 18 additions & 0 deletions stocksense/database_handler/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
insert_record,
update_data,
delete_data,
delete_table,
count_data,
fetch_data,
)
Expand Down Expand Up @@ -55,6 +56,11 @@ def insert_index_data(self, data: pl.DataFrame) -> None:
conn = self.db.get_connection()
insert_data(conn, "sp500", data)

def insert_vix_data(self, data: pl.DataFrame) -> None:
    """Insert VIX daily market data into the 'vix' table."""
    frame = convert_date_columns_to_str(data, ["date"])
    insert_data(self.db.get_connection(), "vix", frame)

def delete_stock(self, tic: str) -> None:
conn = self.db.get_connection()
delete_data(conn, "stock", {"tic": tic})
Expand Down Expand Up @@ -120,6 +126,18 @@ def fetch_index_data(self) -> pl.DataFrame:
df = convert_str_columns_to_date(df, ["date"])
return df

def fetch_vix_data(self) -> pl.DataFrame:
    """
    Fetch all rows of the 'vix' table.

    Returns an empty DataFrame when no data could be fetched;
    otherwise the 'date' column is converted back to date type.
    """
    data = fetch_data(self.db.get_connection(), "vix")
    if data is None:
        return pl.DataFrame()
    return convert_str_columns_to_date(data, ["date"])

def delete_table(self, table_name: str) -> bool:
    """Drop *table_name* from the database; True on success."""
    return delete_table(self.db.get_connection(), table_name)

def close(self):
    """Close the underlying database connection."""
    self.db.close()

Expand Down
27 changes: 27 additions & 0 deletions stocksense/database_handler/queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,33 @@ def delete_data(connection: Connection, table_name: str, condition: dict) -> Non
logger.error(f"Error deleting data from table {table_name}: {e}")


def delete_table(connection: Connection, table_name: str) -> bool:
    """
    Delete a table from the database.

    Parameters
    ----------
    connection : Connection
        Database connection object.
    table_name : str
        Name of table to delete. Must be a plain Python identifier;
        anything else is rejected (see note below).

    Returns
    -------
    bool
        True if successful, False otherwise.
    """
    # Table names cannot be bound as SQL parameters, so the name is
    # interpolated into the statement. Reject non-identifier names and
    # quote the identifier to prevent SQL injection through table_name.
    if not table_name.isidentifier():
        logger.error(f"Error deleting table {table_name}: invalid table name")
        return False
    try:
        cursor = connection.cursor()
        cursor.execute(f'DROP TABLE IF EXISTS "{table_name}"')
        connection.commit()
        logger.info(f"Successfully deleted table {table_name}")
        return True
    except Error as e:
        logger.error(f"Error deleting table {table_name}: {e}")
        return False


def count_data(connection: Connection, table_name: str, column: str) -> Optional[int]:
try:
sql = f"SELECT COUNT(DISTINCT {column}) FROM {table_name}"
Expand Down
7 changes: 6 additions & 1 deletion stocksense/database_handler/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,15 @@ def create_tables(connection: Connection | None) -> None:
CREATE TABLE IF NOT EXISTS sp500 (
date TEXT PRIMARY KEY,
close REAL,
adj_close REAL,
volume INTEGER
)
""",
"vix": """
CREATE TABLE IF NOT EXISTS vix (
date TEXT PRIMARY KEY,
close REAL
)
""",
}

try:
Expand Down
31 changes: 19 additions & 12 deletions stocksense/model/model_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,9 @@ def train(self, data: pl.DataFrame):
)
scale = self.get_dataset_imbalance_scale(train_df)
aux_cols = ["datadate", "rdq", "sector"]
train_df = train_df.select([c for c in train_df.columns if c not in aux_cols])
train_df = train_df.select(
[c for c in train_df.columns if c not in aux_cols]
)

ga = GeneticAlgorithm(
num_generations=50,
Expand Down Expand Up @@ -105,7 +107,7 @@ def train(self, data: pl.DataFrame):
"scale_pos_weight": scale,
"eval_metric": "logloss",
"nthread": -1,
"seed": self.seed
"seed": self.seed,
}

X_train = train_df.select(
Expand Down Expand Up @@ -134,17 +136,21 @@ def get_dataset_imbalance_scale(self, train_df):
Class imbalance scale.
"""
return int(
len(train_df.filter(
(pl.col(self.target_col) == 0) &
(pl.col("tdq").dt.year() >= self.train_start) &
(pl.col("tdq").dt.year() < self.train_start + self.train_window)
len(
train_df.filter(
(pl.col(self.target_col) == 0)
& (pl.col("tdq").dt.year() >= self.train_start)
& (pl.col("tdq").dt.year() < self.train_start + self.train_window)
)
)
) / len(train_df.filter(
(pl.col(self.target_col) == 1) &
(pl.col("tdq").dt.year() >= self.train_start) &
(pl.col("tdq").dt.year() < self.train_start + self.train_window)
/ len(
train_df.filter(
(pl.col(self.target_col) == 1)
& (pl.col("tdq").dt.year() >= self.train_start)
& (pl.col("tdq").dt.year() < self.train_start + self.train_window)
)
)
))
)

def score(self, data):
"""
Expand All @@ -167,7 +173,8 @@ def score(self, data):
).to_pandas()

model_path = model_path = (
Path("models/") / f"xgb_{self.last_trade_date}.pkl"
Path("models/") /
f"xgb_{self.last_trade_date}.pkl"
)
model = XGBoostModel().load_model(model_path)
model.predict_proba(test_df)
Expand Down
40 changes: 27 additions & 13 deletions stocksense/pipeline/etl.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class ETL:

def __init__(self, stocks: Optional[list[str]] = []):
self.db = DatabaseHandler()
self.db_fields = get_config("db")["schema"]
self.db_schema = get_config("db")["schema"]
self.base_date = get_config("scraping")["base_date"]
self.fin_source = "yfinance"
self.historical_data_path = DATA_PATH / "interim"
Expand Down Expand Up @@ -58,7 +58,7 @@ def set_index_listings(self) -> None:
pl.lit(1).alias("active"),
]
)
self.db.insert_stock(stock_df[self.db_fields["stock"]])
self.db.insert_stock(stock_df[self.db_schema["stock"]])

def update_index_listings(self) -> None:
"""
Expand Down Expand Up @@ -119,7 +119,7 @@ def _add_new_symbols(self, stock_df: pl.DataFrame, active_df: pl.DataFrame) -> N
pl.lit(1).alias("active"),
]
)
self.db.insert_stock(stock[self.db_fields["stock"]])
self.db.insert_stock(stock[self.db_schema["stock"]])

logger.info(f"added {tic} to S&P500 index")

Expand All @@ -137,6 +137,7 @@ def extract(self) -> None:
if self.is_empty():
raise ValueError("No stocks assigned for ETL process.")
self.extract_sp_500()
self.extract_vix()
self._extract_all_stocks()

def extract_sp_500(self) -> None:
Expand All @@ -147,13 +148,26 @@ def extract_sp_500(self) -> None:
try:
scraper = Scraper("^GSPC", self.fin_source)
data = scraper.get_market_data(self.base_date)
data = data.drop("tic")
self.db.insert_index_data(data)
self.db.insert_index_data(data[self.db_schema["sp500"]])
logger.info("inserted S&P500 market data")
except Exception:
logger.error("S&P500 data extraction FAILED")
return

def extract_vix(self) -> None:
    """
    Retrieve updated daily VIX data and insert it into the database.

    Best-effort: any failure is logged and swallowed so that a VIX
    outage does not abort the remainder of the ETL run.
    """
    logger.info("extracting VIX data")
    try:
        scraper = Scraper("^VIX", self.fin_source)
        data = scraper.get_market_data(self.base_date)
        # keep only the columns declared in the db schema for 'vix'
        self.db.insert_vix_data(data[self.db_schema["vix"]])
        logger.info("inserted VIX market data")
    except Exception as e:
        # include the exception so failures are diagnosable from logs
        logger.error(f"VIX data extraction FAILED: {e}")
        return

def _extract_all_stocks(self) -> None:
"""
Extract data for all assigned stocks.
Expand Down Expand Up @@ -257,7 +271,7 @@ def extract_fundamental_data(
start_date = fin_data["datadate"].max()
except Exception:
# no past data available for stock
fin_data = pl.DataFrame(schema=self.db_fields["financial"])
fin_data = pl.DataFrame(schema=self.db_schema["financial"])
start_date = dt.datetime.strptime(self.base_date, "%Y-%m-%d").date()
logger.warning(f"{tic}: no past financial data found ({last_update})")

Expand All @@ -268,7 +282,7 @@ def extract_fundamental_data(
return False

data = scraper.get_financial_data(start_date, end_date)
self.db.insert_financial_data(data[self.db_fields["financial"]])
self.db.insert_financial_data(data[self.db_schema["financial"]])
self.db.update_stock(tic, {"last_update": end_date})
logger.success(f"{tic}: updated financial data ({start_date}:{end_date})")
return True
Expand Down Expand Up @@ -305,7 +319,7 @@ def extract_market_data(self, tic: str, scraper: Scraper) -> bool:
return False

data = scraper.get_market_data(self.base_date)
self.db.insert_market_data(data[self.db_fields["market"]])
self.db.insert_market_data(data[self.db_schema["market"]])
self.db.update_stock(tic, {"last_update": end_date})
logger.success(f"{tic}: updated market data ({end_date})")
return True
Expand Down Expand Up @@ -341,7 +355,7 @@ def extract_insider_data(self, tic: str, scraper: Scraper) -> bool:
return False

data = scraper.get_stock_insider_data()
self.db.insert_insider_data(data[self.db_fields["insider"]])
self.db.insert_insider_data(data[self.db_schema["insider"]])
logger.success(f"{tic}: updated insider trading data ({end_date})")
return True
except Exception as e:
Expand Down Expand Up @@ -375,7 +389,7 @@ def _ingest_stock_list(self) -> None:
pl.col("spx_status").cast(pl.Int16),
pl.col("spx_status").cast(pl.Int16).alias("active"),
pl.lit(parsed_date).alias("last_update"),
)[self.db_fields["stock"]]
)[self.db_schema["stock"]]

self.db.insert_stock(index_df)

Expand Down Expand Up @@ -430,7 +444,7 @@ def _ingest_market_data(self, market_file: Path, tic: str) -> None:
"Volume": "volume",
}
)
market_df = market_df[self.db_fields["market"]]
market_df = market_df[self.db_schema["market"]]
self.db.insert_market_data(market_df)
except Exception:
logger.warning(f"market data file for {tic} is empty.")
Expand All @@ -455,7 +469,7 @@ def _ingest_insider_data(self, insider_file: Path, tic: str) -> None:
"Value": "value",
}
)
insider_df = insider_df[self.db_fields["insider"]]
insider_df = insider_df[self.db_schema["insider"]]
self.db.insert_insider_data(insider_df)
except Exception:
logger.warning(f"insider data file for {tic} is empty.")
Expand All @@ -471,7 +485,7 @@ def _ingest_financials_data(self, financials_file: Path, tic: str) -> None:
pl.col("rdq").str.to_date("%Y-%m-%d"),
pl.lit(tic).alias("tic"),
)
financials_df = financials_df[self.db_fields["financial"]]
financials_df = financials_df[self.db_schema["financial"]]
self.db.insert_financial_data(financials_df)
except Exception:
logger.warning(f"financials data file for {tic} is empty.")
Loading

0 comments on commit 89cffbe

Please sign in to comment.