
Commit

implement fitting TOM inputs
AmandaWasserman committed Mar 13, 2024
1 parent f3fea9c commit 4dc41ce
Showing 4 changed files with 153 additions and 9 deletions.
64 changes: 62 additions & 2 deletions resspect/fit_lightcurves.py
@@ -1,5 +1,5 @@
# Copyright 2020 resspect software
# Author: Rupesh Durgesh and Emille Ishida
# Author: Rupesh Durgesh, Emille Ishida, and Amanda Wasserman
#
# created on 14 April 2022
#
@@ -32,12 +32,14 @@
from resspect.lightcurves_utils import get_resspect_header_data
from resspect.lightcurves_utils import read_plasticc_full_photometry_data
from resspect.lightcurves_utils import SNPCC_FEATURES_HEADER
from resspect.lightcurves_utils import TOM_FEATURES_HEADER
from resspect.lightcurves_utils import TOM_MALANCHEV_FEATURES_HEADER
from resspect.lightcurves_utils import SNPCC_MALANCHEV_FEATURES_HEADER
from resspect.lightcurves_utils import find_available_key_name_in_header
from resspect.lightcurves_utils import PLASTICC_TARGET_TYPES
from resspect.lightcurves_utils import PLASTICC_RESSPECT_FEATURES_HEADER

__all__ = ["fit_snpcc", "fit_plasticc"]
__all__ = ["fit_snpcc", "fit_plasticc", "fit_TOM"]


FEATURE_EXTRACTOR_MAPPING = {
@@ -238,6 +240,64 @@ def fit_plasticc(path_photo_file: str, path_header_file: str,
                    light_curve_data, plasticc_features_file)
    logging.info("Features have been saved to: %s", output_file)

def _TOM_sample_fit(
        id: str, dic: dict, feature_extractor: str):
    """
    Reads photometry for a single TOM object and performs the fit.

    Parameters
    ----------
    id
        SNID of the object to fit.
    dic
        Dictionary containing the photometry for all light curves, keyed by SNID.
    feature_extractor
        Function used for feature extraction.
        Options are 'bazin', 'bump', or 'malanchev'.
    """
    light_curve_data = FEATURE_EXTRACTOR_MAPPING[feature_extractor]()
    light_curve_data.photometry = pd.DataFrame(dic[id]['photometry'])
    light_curve_data.dataset_name = 'TOM'
    light_curve_data.filters = ['u', 'g', 'r', 'i', 'z', 'Y']

    light_curve_data.fit_all()

    return light_curve_data

def fit_TOM(data_dic: dict, features_file: str,
            number_of_processors: int = 1,
            feature_extractor: str = 'bazin'):
    """
    Perform fit to all objects from the TOM data.

    Parameters
    ----------
    data_dic: dict
        Dictionary containing the photometry for all light curves.
    features_file: str
        Path to output file where results should be stored.
    number_of_processors: int, default 1
        Number of cpu processes to use.
    feature_extractor: str, default bazin
        Function used for feature extraction.
        Options are 'bazin' or 'malanchev'.
    """
    if feature_extractor == 'bazin':
        header = TOM_FEATURES_HEADER
    elif feature_extractor == 'malanchev':
        header = TOM_MALANCHEV_FEATURES_HEADER
    else:
        raise ValueError("feature_extractor must be 'bazin' or 'malanchev'")

    multi_process = multiprocessing.Pool(number_of_processors)
    logging.info("Starting TOM %s fit...", feature_extractor)
    with open(features_file, 'w') as tom_features_file:
        tom_features_file.write(','.join(header) + '\n')

        for light_curve_data in multi_process.starmap(
                _TOM_sample_fit, zip(
                    data_dic, repeat(data_dic), repeat(feature_extractor))):
            if 'None' not in light_curve_data.features:
                write_features_to_output_file(
                    light_curve_data, tom_features_file)
    logging.info("Features have been saved to: %s", features_file)



def main():
    return None
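For context, a minimal usage sketch for the new fit_TOM entry point (not part of this commit). The input dictionary structure follows what _TOM_sample_fit expects (object ids as keys, each mapping to a dict with a 'photometry' table), but the photometry column names below are illustrative placeholders, since the exact schema expected by the feature extractors is not shown in this diff:

from resspect.fit_lightcurves import fit_TOM

# Hypothetical input: object ids map to a dict holding a 'photometry' table.
# Column names below are placeholders, not taken from this commit.
my_tom_data = {
    '1234567': {'photometry': {'mjd': [60350.1, 60353.2, 60356.0],
                               'band': ['g', 'r', 'g'],
                               'flux': [210.5, 350.2, 180.1],
                               'fluxerr': [12.1, 15.3, 11.8]}}
}

if __name__ == '__main__':
    # Guard needed because fit_TOM spawns a multiprocessing pool.
    fit_TOM(my_tom_data, features_file='tom_features.csv',
            number_of_processors=2, feature_extractor='bazin')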
31 changes: 31 additions & 0 deletions resspect/lightcurves_utils.py
@@ -79,6 +79,37 @@
    'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2'
]

TOM_FEATURES_HEADER = [
    'id', 'redshift', 'type', 'code', 'orig_sample',
    'uA', 'uB', 'ut0', 'utfall', 'utrise',
    'gA', 'gB', 'gt0', 'gtfall', 'gtrise', 'rA', 'rB',
    'rt0', 'rtfall', 'rtrise', 'iA', 'iB', 'it0', 'itfall',
    'itrise', 'zA', 'zB', 'zt0', 'ztfall', 'ztrise',
    'YA', 'YB', 'Yt0', 'Ytfall', 'Ytrise'
]

TOM_MALANCHEV_FEATURES_HEADER = [
    'id', 'redshift', 'type', 'code', 'orig_sample',
    'uanderson_darling_normal','uinter_percentile_range_5',
    'uchi2','ustetson_K','uweighted_mean','uduration', 'uotsu_mean_diff','uotsu_std_lower', 'uotsu_std_upper',
    'uotsu_lower_to_all_ratio', 'ulinear_fit_slope', 'ulinear_fit_slope_sigma','ulinear_fit_reduced_chi2',
    'ganderson_darling_normal','ginter_percentile_range_5',
    'gchi2','gstetson_K','gweighted_mean','gduration', 'gotsu_mean_diff','gotsu_std_lower', 'gotsu_std_upper',
    'gotsu_lower_to_all_ratio', 'glinear_fit_slope', 'glinear_fit_slope_sigma','glinear_fit_reduced_chi2',
    'randerson_darling_normal', 'rinter_percentile_range_5',
    'rchi2', 'rstetson_K', 'rweighted_mean','rduration', 'rotsu_mean_diff','rotsu_std_lower', 'rotsu_std_upper',
    'rotsu_lower_to_all_ratio', 'rlinear_fit_slope', 'rlinear_fit_slope_sigma','rlinear_fit_reduced_chi2',
    'ianderson_darling_normal','iinter_percentile_range_5',
    'ichi2', 'istetson_K', 'iweighted_mean','iduration', 'iotsu_mean_diff','iotsu_std_lower', 'iotsu_std_upper',
    'iotsu_lower_to_all_ratio', 'ilinear_fit_slope', 'ilinear_fit_slope_sigma','ilinear_fit_reduced_chi2',
    'zanderson_darling_normal','zinter_percentile_range_5',
    'zchi2', 'zstetson_K', 'zweighted_mean','zduration', 'zotsu_mean_diff','zotsu_std_lower', 'zotsu_std_upper',
    'zotsu_lower_to_all_ratio', 'zlinear_fit_slope', 'zlinear_fit_slope_sigma','zlinear_fit_reduced_chi2',
    'Yanderson_darling_normal','Yinter_percentile_range_5',
    'Ychi2','Ystetson_K','Yweighted_mean','Yduration', 'Yotsu_mean_diff','Yotsu_std_lower', 'Yotsu_std_upper',
    'Yotsu_lower_to_all_ratio', 'Ylinear_fit_slope', 'Ylinear_fit_slope_sigma','Ylinear_fit_reduced_chi2'
]

PLASTICC_RESSPECT_FEATURES_HEADER = [
    'id', 'redshift', 'type', 'code', 'orig_sample', 'uA', 'uB', 'ut0',
    'utfall', 'utrise', 'gA', 'gB', 'gt0', 'gtfall','gtrise', 'rA', 'rB',
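A short sketch of how these headers map onto the output file (assuming, as the comma-joined header line written in fit_TOM suggests, that write_features_to_output_file emits comma-separated rows; the file name is illustrative):

import pandas as pd

# The features file starts with TOM_FEATURES_HEADER as its first row:
# five metadata columns, then five Bazin parameters (A, B, t0, tfall, trise) per filter.
features = pd.read_csv('tom_features.csv')
print(features[['id', 'redshift', 'gA', 'gB', 'gt0', 'gtfall', 'gtrise']].head())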
8 changes: 1 addition & 7 deletions resspect/read_dash.py
@@ -15,17 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

#import logging
#import numpy as np
#import pandas as pd
import csv

__all__ = ["get_id_type"]




def get_id_type(file='/Users/arw/Desktop/spec_sims_binned_spaced/DASH_matches.txt'):
def get_id_type(file: str):
    # read text file
    data = open(file, "r")

59 changes: 59 additions & 0 deletions resspect/request_webservice_elasticc.py
@@ -0,0 +1,59 @@
# Copyright 2020 resspect software
# Author: Amanda Wasserman
#
# created on 12 March 2024
#
# Licensed GNU General Public License v3.0;
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.gnu.org/licenses/gpl-3.0.en.html
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import requests

class TomClient:
    def __init__(self, url='https://desc-tom.lbl.gov', username=None,
                 password=None, passwordfile=None, connect=True):
        self._url = url
        self._username = username
        self._password = password
        self._rqs = None
        if self._password is None:
            if passwordfile is None:
                raise RuntimeError('No password or passwordfile provided')
            with open(passwordfile) as ifp:
                self._password = ifp.readline().strip()

        if connect:
            self.connect()

    def connect(self):
        # requests.Session() is the documented class; requests.session() is a legacy alias.
        self._rqs = requests.Session()
        res = self._rqs.get(f'{self._url}/accounts/login/')
        if res.status_code != 200:
            raise RuntimeError(f'Failed to connect to {self._url}')
        res = self._rqs.post(f'{self._url}/accounts/login/',
                             data={'username': self._username,
                                   'password': self._password,
                                   'csrfmiddlewaretoken': self._rqs.cookies['csrftoken']})
        if res.status_code != 200:
            raise RuntimeError('Failed to log in.')
        if 'Please enter a correct' in res.text:
            raise RuntimeError('Failed to log in; check username and password.')
        self._rqs.headers.update({'X-CSRFToken': self._rqs.cookies['csrftoken']})

    def request(self, method="GET", page=None, **kwargs):
        # Send a request to {url}/{page} through the authenticated session.
        return self._rqs.request(method=method, url=f"{self._url}/{page}", **kwargs)

#def request(website: str):
# r = requests.get(website)
# status = r.status_code
# text = r.text
# return text
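A hypothetical usage sketch for TomClient (not part of this commit): the endpoint path and payload below are illustrative placeholders only, since the DESC TOM's actual API routes are not shown in this diff.

from resspect.request_webservice_elasticc import TomClient

# Log in with credentials read from a local file, then reuse the
# authenticated session for requests.
tom = TomClient(username='my_username', passwordfile='tom_password.txt')

# 'some/api/endpoint/' and the JSON payload are placeholders, not real TOM routes.
res = tom.request('POST', 'some/api/endpoint/',
                  json={'detected_since_mjd': 60350})
res.raise_for_status()
data_dic = res.json()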
