Baizhige/develop (#2)
* update debug directory

* Updates:
Baizhige authored Sep 21, 2024
1 parent 653c6c9 commit 228e952
Showing 4 changed files with 83 additions and 42 deletions.
13 changes: 6 additions & 7 deletions debug/debug_sheet_epoch_by_events.py
@@ -1,17 +1,16 @@
 from eegunity.unifieddataset import UnifiedDataset
 
 fail_list = ['ieeedata_nju_aad','tuh_eeg_epilepsy', 'bcic_iv_4', 'physionet_sleepdephemocog', 'bcic_iv_2a', 'physionet_mssvepdb', 'physionet_sleepedfx', 'other_ahajournals', 'physionet_eegmat', 'other_artifact_rejection', 'tuh_eeg_abnormal', 'physionet_auditoryeeg', 'tuh_eeg_seizure', 'bcic_iv_2b', 'physionet_motionartifact', 'physionet_capslpdb', 'figshare_shudb', 'tuh_eeg_slowing', 'openneuro_ds003516', 'mendeley_sin', 'kaggle_inria', 'other_migrainedb', 'zenodo_3618205', 'zenodo_sin', 'iscslp2024_chineseaad', 'zenodo_saa', 'openneuro_ds004015', 'physionet_hmcsleepstaging', 'zenodo_uhd', 'zenodo_4518754', 'other_openbmi', 'physionet_eegmmidb', 'figshare_largemi', 'figshare_stroke', 'tuh_eeg', 'tuh_eeg_events', 'zenodo_7778289', 'physionet_chbmit', 'bcic_iv_1', 'other_highgammadataset', 'zenodo_kul', 'other_seed', 'other_eegdenoisenet', 'zenodo_dtu', 'github_inabiyouni', 'bcic_iv_3', 'osf_8jpc5', 'bcic_iii_2', 'figshare_meng2019', 'tuh_eeg_artifact', 'physionet_ucddb', 'kaggle_graspandlift', 'bcic_iii_1']
-remain_list = ['ieee_icassp_competition_2024']
+remain_list = ['ieee_icassp_competition_2024','bcic_iii_1','physionet_eegmmidb']
 
 done_list = []
 
 for folder_name in remain_list:
     try:
-        unified_dataset = UnifiedDataset(domain_tag=folder_name, locator_path=f"./locator/events/{folder_name}_events.csv", is_unzip=False)
-        unified_dataset.eeg_batch.epoch_by_event(output_path=f"../EEGUnity_ouput/{folder_name}/epoch_by_events", seg_sec=3, resample=128)
-        '''
-        epoch_by_event(self, output_path: str, seg_sec: float, resample: int = None,
-                       exclude_bad=True, baseline=(0, 0.2), miss_bad_data=False):
-        '''
-
+        unified_dataset = UnifiedDataset(domain_tag=folder_name, locator_path=f"../locator/{folder_name}_events.csv", is_unzip=False)
+        unified_dataset.eeg_batch.epoch_by_event(output_path=f"../EEGUnity_ouput/{folder_name}/epoch_by_events", resample=128, tmin=-0.2, tmax=3, baseline=(None, 0))
     except Exception as e:
         print("fail===========")
         print(folder_name)
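
For reference, the updated `epoch_by_event` call forwards its trailing keyword arguments (`tmin`, `tmax`, `baseline`, …) straight to `mne.Epochs`. A minimal, self-contained sketch of the equivalent direct MNE call, on synthetic data rather than anything from the repo:

```python
import numpy as np
import mne

# Synthetic 4-channel recording at 128 Hz, 60 s long (illustrative only)
sfreq = 128
info = mne.create_info(ch_names=["C3", "C4", "O1", "O2"], sfreq=sfreq, ch_types="eeg")
raw = mne.io.RawArray(np.random.randn(4, 60 * sfreq) * 1e-5, info)

# Three fake stimulus onsets at 10 s, 25 s, and 40 s (MNE events: [sample, 0, code])
events = np.array([[10 * sfreq, 0, 1], [25 * sfreq, 0, 1], [40 * sfreq, 0, 2]])
event_id = {"left": 1, "right": 2}

# Equivalent of epoch_by_event(..., tmin=-0.2, tmax=3, baseline=(None, 0)):
# those keyword arguments pass through to mne.Epochs unchanged
epochs = mne.Epochs(raw, events, event_id, tmin=-0.2, tmax=3,
                    baseline=(None, 0), preload=True)
print(epochs)
```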
10 changes: 6 additions & 4 deletions debug/debug_sheet_extract_event.py
@@ -1,14 +1,16 @@
 from eegunity.unifieddataset import UnifiedDataset
 
-fail_list = ['ieeedata_nju_aad']
-remain_list = ['tuh_eeg_epilepsy', 'bcic_iv_4', 'physionet_sleepdephemocog', 'bcic_iv_2a', 'physionet_mssvepdb', 'physionet_sleepedfx', 'other_ahajournals', 'physionet_eegmat', 'other_artifact_rejection', 'tuh_eeg_abnormal', 'physionet_auditoryeeg', 'tuh_eeg_seizure', 'bcic_iv_2b', 'physionet_motionartifact', 'physionet_capslpdb', 'figshare_shudb', 'tuh_eeg_slowing', 'openneuro_ds003516', 'mendeley_sin', 'kaggle_inria', 'other_migrainedb', 'zenodo_3618205', 'zenodo_sin', 'iscslp2024_chineseaad', 'zenodo_saa', 'openneuro_ds004015', 'physionet_hmcsleepstaging', 'zenodo_uhd', 'zenodo_4518754', 'other_openbmi', 'physionet_eegmmidb', 'figshare_largemi', 'figshare_stroke', 'tuh_eeg', 'tuh_eeg_events', 'zenodo_7778289', 'physionet_chbmit', 'bcic_iv_1', 'other_highgammadataset', 'zenodo_kul', 'other_seed', 'other_eegdenoisenet', 'zenodo_dtu', 'github_inabiyouni', 'bcic_iv_3', 'osf_8jpc5', 'bcic_iii_2', 'figshare_meng2019', 'tuh_eeg_artifact', 'physionet_ucddb', 'kaggle_graspandlift', 'bcic_iii_1']
+fail_list = ['ieeedata_nju_aad','tuh_eeg_epilepsy', 'bcic_iv_4', 'physionet_sleepdephemocog', 'bcic_iv_2a', 'physionet_mssvepdb', 'physionet_sleepedfx', 'other_ahajournals', 'physionet_eegmat', 'other_artifact_rejection', 'tuh_eeg_abnormal', 'physionet_auditoryeeg', 'tuh_eeg_seizure', 'bcic_iv_2b', 'physionet_motionartifact', 'physionet_capslpdb', 'figshare_shudb', 'tuh_eeg_slowing', 'openneuro_ds003516', 'mendeley_sin', 'kaggle_inria', 'other_migrainedb', 'zenodo_3618205', 'zenodo_sin', 'iscslp2024_chineseaad', 'zenodo_saa', 'openneuro_ds004015', 'physionet_hmcsleepstaging', 'zenodo_uhd', 'zenodo_4518754', 'other_openbmi', 'physionet_eegmmidb', 'figshare_largemi', 'figshare_stroke', 'tuh_eeg', 'tuh_eeg_events', 'zenodo_7778289', 'physionet_chbmit', 'bcic_iv_1', 'other_highgammadataset', 'zenodo_kul', 'other_seed', 'other_eegdenoisenet', 'zenodo_dtu', 'github_inabiyouni', 'bcic_iv_3', 'osf_8jpc5', 'bcic_iii_2', 'figshare_meng2019', 'tuh_eeg_artifact', 'physionet_ucddb', 'kaggle_graspandlift', 'bcic_iii_1']
+remain_list = ['ieee_icassp_competition_2024']
+done_list = ['ieee_icassp_competition_2024']
 
 for folder_name in remain_list:
     try:
-        unified_dataset = UnifiedDataset(domain_tag=folder_name, locator_path=f"./locator/{folder_name}.csv", is_unzip=False)
+
+        unified_dataset = UnifiedDataset(domain_tag=folder_name, locator_path=f"../locator/{folder_name}.csv", is_unzip=False)
         unified_dataset.eeg_batch.get_events()
-        locator_save_path = f"./locator/{folder_name}_events.csv"
+        locator_save_path = f"../locator/{folder_name}_events.csv"
         unified_dataset.save_locator(locator_save_path)
     except Exception as e:
         print("fail===========")
98 changes: 68 additions & 30 deletions eegunity/module_eeg_batch/eeg_batch.py
@@ -62,11 +62,10 @@ def batch_process(self, con_func, app_func, is_patch, result_type=None):
                 result = None
 
             results.append(result)
-
         if result_type == "series":
             # Combine results into a DataFrame if app_func returns Series
             combined_results = pd.concat([res for res in results if res is not None], axis=1).T
-            combined_results.reset_index(drop=True, inplace=True)  # reset the index and drop the original index column
+            combined_results.reset_index(drop=True, inplace=True)
             return combined_results
         elif result_type == "value":
             # Collect results into a list if app_func returns values
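
The `pd.concat(..., axis=1).T` idiom above stacks the returned Series as columns and then transposes, giving one row per processed file; `reset_index(drop=True)` discards the leftover index. A standalone pandas sketch of the same pattern, with toy rows in place of real locator entries:

```python
import pandas as pd

# Each app_func call returns one locator row as a Series (or None when skipped)
results = [pd.Series({"File Path": "a.fif", "File Type": "standard_data"}),
           None,  # a skipped file
           pd.Series({"File Path": "b.fif", "File Type": "standard_data"})]

# concat along axis=1 stacks the Series as columns, so transpose to get rows back
combined = pd.concat([r for r in results if r is not None], axis=1).T
combined.reset_index(drop=True, inplace=True)  # fresh 0..n-1 index, old index dropped
print(combined)
```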
@@ -123,17 +122,16 @@ def sample_filter(
         """
         Filters the 'locator' dataframe based on the given criteria.
-        Parameters:
-        channel_number (tuple/list/array-like, optional): A tuple or list with (min, max) values to filter the
+        :param (tuple/list/array-like, optional) channel_number: A tuple or list with (min, max) values to filter the
             "Number of Channels" column. If None, this criterion is ignored. Default is None.
-        sampling_rate (tuple/list/array-like, optional): A tuple or list with (min, max) values to filter the
+        :param (tuple/list/array-like, optional) sampling_rate: A tuple or list with (min, max) values to filter the
             "Sampling Rate" column. If None, this criterion is ignored. Default is None.
-        duration (tuple/list/array-like, optional): A tuple or list with (min, max) values to filter the
+        :param (tuple/list/array-like, optional) duration: A tuple or list with (min, max) values to filter the
             "Duration" column. If None, this criterion is ignored. Default is None.
-        completeness_check (str, optional): A string that can be 'Completed', 'Unavailable', or 'Acceptable' to filter the
+        :param (str, optional) completeness_check: A string that can be 'Completed', 'Unavailable', or 'Acceptable' to filter the
             "Completeness Check" column. The check is case-insensitive. If None, this criterion is ignored. Default is None.
-        domain_tag (str, optional): A string to filter the "Domain Tag" column. If None, this criterion is ignored. Default is None.
-        file_type (str, optional): A string to filter the "File Type" column. If None, this criterion is ignored. Default is None.
+        :param (str, optional) domain_tag: A string to filter the "Domain Tag" column. If None, this criterion is ignored. Default is None.
+        :param (str, optional) file_type: A string to filter the "File Type" column. If None, this criterion is ignored. Default is None.
 
         Returns:
         None. The function updates the 'locator' dataframe in the shared attributes.
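
Going by the documented parameters, a call might look like the sketch below; the filter values and the `unified_dataset` instance are assumptions, not from this commit:

```python
# Keep recordings with 32-64 channels and a 100-1000 Hz sampling rate that
# passed the completeness check; the locator is filtered in place
unified_dataset.eeg_batch.sample_filter(
    channel_number=(32, 64),
    sampling_rate=(100, 1000),
    completeness_check='Completed',
)
```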
@@ -196,7 +194,7 @@ def save_as_other(self, output_path, domain_tag=None, format='fif'):
 
         # Check for valid format
         if format not in ['fif', 'csv']:
-            raise ValueError(f"Unsupported format: {format}. Only 'fif' and 'csv' are supported.")
+            raise ValueError(f"Unsupported format: {format}. Currently, only 'fif' and 'csv' are supported.")
 
         def con_func(row):
             return domain_tag is None or row['Domain Tag'] == domain_tag
@@ -207,11 +205,14 @@ def app_func(row):
 
             if format == 'fif':
                 # Saving as FIF format
-                new_file_path = os.path.join(output_path, f"{file_name}_raw.fif")
+                new_file_path = os.path.join(output_path, f"{file_name}.fif")
                 raw.save(new_file_path, overwrite=True)
+                row['File Path'] = new_file_path
+                row['File Type'] = "standard_data"
+
             elif format == 'csv':
                 # Saving as CSV format
-                new_file_path = os.path.join(output_path, f"{file_name}_raw.csv")
+                new_file_path = os.path.join(output_path, f"{file_name}.csv")
 
                 # Extract data and channel names from Raw object
                 data, times = raw.get_data(return_times=True)
@@ -221,18 +222,29 @@ def app_func(row):
                 df = pd.DataFrame(data.T, columns=channel_names)
                 df.insert(0, 'date', times)  # Add 'date' as the first column
 
+                # Extract events from raw data
+                events, event_id = extract_events(raw)
+
+                # Create an empty 'marker' column initialized with NaNs
+                df['marker'] = np.nan
+
+                # Map event onsets to timepoints and set the corresponding marker
+                for event in events:
+                    onset_sample = event[0]
+                    event_code = event[2]
+                    # Find the closest timestamp for the onset sample
+                    closest_time_idx = np.argmin(np.abs(times - raw.times[onset_sample]))
+                    df.at[closest_time_idx, 'marker'] = event_code  # Mark event code in the 'marker' column
                 # Save DataFrame to CSV
                 df.to_csv(new_file_path, index=False)
                 row['File Path'] = new_file_path
                 row['File Type'] = "csv_data"
 
-            row['File Path'] = new_file_path
-            row['File Type'] = "standard_data"
             return row
 
         copied_instance = copy.deepcopy(self)
         new_locator = self.batch_process(con_func, app_func, is_patch=False, result_type='series')
         copied_instance.set_shared_attr({'locator': new_locator})
         return copied_instance
 
     def process_mean_std(self, domain_mean=True):
         def get_mean_std(data: mne.io.Raw):
             """
@@ -688,11 +700,8 @@ def con_func(row):
         def app_func(row):
             try:
                 mne_raw = get_data_row(row)
-                # print(row)
                 events, event_id = extract_events(mne_raw)
-
                 row["event_id"] = str(event_id)
-                print(event_id)
                 event_id_num = {key: sum(events[:, 2] == val) for key, val in event_id.items()}
                 row["event_id_num"] = str(event_id_num)
                 return row
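
The `event_id_num` comprehension counts, for each named event, how many rows of the events array carry its code; the same expression in isolation, with a toy events array:

```python
import numpy as np

# The third column of an MNE events array holds the event codes
events = np.array([[100, 0, 1], [200, 0, 2], [300, 0, 1], [400, 0, 1]])
event_id = {"target": 1, "nontarget": 2}

# For each named event, count the rows carrying its code
event_id_num = {key: sum(events[:, 2] == val) for key, val in event_id.items()}
print(event_id_num)  # e.g. 3 targets and 1 nontarget
```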
@@ -709,8 +718,8 @@ def app_func(row):
                                    result_type='series')
         self.get_shared_attr()['locator'] = new_locator
 
-    def epoch_by_event(self, output_path: str, seg_sec: float, resample: int = None,
-                       exclude_bad=True, baseline=(0, 0.2), miss_bad_data=False):
+    def epoch_by_event(self, output_path: str, resample: int = None,
+                       exclude_bad=True, miss_bad_data=False, **epoch_params):
         """
         Batch process EEG data to create epochs based on events specified in event_id column.
@@ -719,10 +728,9 @@ def epoch_by_event(self, output_path: str, seg_sec: float, resample: int = None,
         output_path (str): Directory to save the processed epochs.
-        seg_sec (float): Length of each epoch in seconds.
         resample (int): Resample rate for the raw data. If None, no resampling is performed.
-        overlap (float): Fraction of overlap between consecutive epochs. Range is 0 to 1.
         exclude_bad (bool): Whether to exclude bad epochs. Uses simple heuristics to determine bad epochs.
-        baseline (tuple): Time interval to use for baseline correction. If (None, 0), uses the pre-stimulus interval.
         miss_bad_data (bool): Whether to skip files with processing errors.
+        **epoch_params: Additional parameters for mne.Epochs, excluding raw_data, events, event_id.
 
         Returns:
         None
@@ -731,8 +739,8 @@ def epoch_by_event(self, output_path: str, seg_sec: float, resample: int = None,
 
         def con_func(row):
             return True
 
-        def app_func(row, output_path: str, seg_sec: float, resample: int = None,
-                     exclude_bad=True, baseline=(None, 0), event_repeated="merge"):
+        def app_func(row, output_path: str, resample: int = None,
+                     exclude_bad=True, **epoch_params):
             try:
                 # Load raw data
                 raw_data = get_data_row(row)
@@ -749,9 +757,10 @@ def app_func(row, output_path: str, seg_sec: float, resample: int = None,
                 if not event_id_str or len(event_id) == 0:
                     print(f"No event_id found for file {row['File Path']}")
                     return None
-                # Create epochs
-                epochs = mne.Epochs(raw_data, events, event_id, tmin=0, tmax=seg_sec,
-                                    baseline=baseline, preload=True, event_repeated=event_repeated)
+
+                # Create epochs with the passed epoch_params
+                epochs = mne.Epochs(raw_data, events, event_id, **epoch_params)
+
                 # Exclude bad epochs
                 if exclude_bad:
                     epochs.drop_bad()
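
Because `epoch_params` is handed to `mne.Epochs` unchanged, amplitude-based rejection criteria can now ride along, and the `drop_bad()` call above will enforce them. A hedged sketch of such a call (the threshold, path, and `unified_dataset` instance are assumptions):

```python
# Rejection thresholds are ordinary mne.Epochs keyword arguments, so they can
# be supplied through **epoch_params and applied later by epochs.drop_bad()
unified_dataset.eeg_batch.epoch_by_event(
    output_path="../EEGUnity_ouput/demo/epoch_by_events",
    resample=128,
    tmin=-0.2, tmax=3,
    baseline=(None, 0),
    reject=dict(eeg=150e-6),  # drop epochs with peak-to-peak EEG above 150 µV
    preload=True,
)
```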
@@ -774,8 +783,8 @@ def app_func(row, output_path: str, seg_sec: float, resample: int = None,
 
         # Use batch_process to process data
         self.batch_process(con_func,
-                           app_func=lambda row: app_func(row, output_path, seg_sec=seg_sec, resample=resample,
-                                                         exclude_bad=exclude_bad, baseline=baseline),
+                           app_func=lambda row: app_func(row, output_path, resample=resample,
+                                                         exclude_bad=exclude_bad, **epoch_params),
                            is_patch=False,
                            result_type=None)
 
@@ -840,3 +849,32 @@ def app_func(row):
         results = self.batch_process(con_func, app_func, is_patch=False, result_type="value")
 
         self.set_column("Score", results)
+
+
+    def replace_paths(self, old_prefix, new_prefix):
+        """
+        Replace the given prefix of file paths in the dataset with a new one.
+
+        Parameters:
+        - old_prefix (str): The path prefix to be replaced.
+        - new_prefix (str): The prefix that replaces it.
+
+        Returns:
+        - A new instance with updated file paths.
+        """
+
+        def replace_func(row):
+            original_path = row['File Path']
+            new_path = original_path
+            # Replace the path prefix if the current path starts with it
+            if original_path.startswith(old_prefix):
+                new_path = original_path.replace(old_prefix, new_prefix, 1)  # Only replace the first occurrence
+            row['File Path'] = new_path
+            return row
+
+        copied_instance = copy.deepcopy(self)
+
+        # Process the dataset, applying the path replacement function to each row
+        updated_locator = self.batch_process(lambda row: True, replace_func, is_patch=False, result_type='series')
+
+        copied_instance.set_shared_attr({'locator': updated_locator})
+        return copied_instance
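
A plausible use of the new `replace_paths` helper, e.g. after a dataset moves to another drive; the paths and the `unified_dataset` instance are illustrative:

```python
# Returns a deep copy whose locator points at the new location;
# the original instance is left untouched
migrated = unified_dataset.eeg_batch.replace_paths(
    old_prefix="/data/eeg_datasets",
    new_prefix="/mnt/storage/eeg_datasets",
)
```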
4 changes: 3 additions & 1 deletion eegunity/module_eeg_parser/eeg_parser.py
@@ -50,8 +50,10 @@ def __init__(self, main_instance):
         else:
             raise ValueError("The provided 'datasets' path is not a valid directory.")
 
-    def _process_directory(self, datasets_path):
+    def _process_directory(self, datasets_path, use_relative_path=False):
         files_info = []
+        datasets_path = os.path.abspath(datasets_path) if not use_relative_path else os.path.relpath(datasets_path)
+
         for filepath in glob.glob(datasets_path + '/**/*', recursive=True):
             if os.path.isfile(filepath):
                 files_info.append([filepath, self.get_shared_attr()['domain_tag'], '', '', '', '', '', '', ''])
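
The new `use_relative_path` flag simply chooses between `os.path.abspath` and `os.path.relpath` before the glob; in isolation (the path is illustrative and need not exist):

```python
import os

datasets_path = "./datasets/demo"  # illustrative path

for use_relative_path in (False, True):
    resolved = (os.path.relpath(datasets_path) if use_relative_path
                else os.path.abspath(datasets_path))
    # abspath anchors locator entries to this machine; relpath keeps them portable
    print(use_relative_path, "->", resolved)
```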
