-
Notifications
You must be signed in to change notification settings - Fork 0
/
synthetic_data_generation.py
123 lines (92 loc) · 5.25 KB
/
synthetic_data_generation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
from pandas import json_normalize
import pandas as pd
import numpy as np
import copy
from sdmetrics.column_pairs import CorrelationSimilarity, ContingencySimilarity
from sdmetrics.single_table import NewRowSynthesis
from sdmetrics.single_column import MissingValueSimilarity, TVComplement
from sdv.metadata import SingleTableMetadata
from sdv.single_table import CopulaGANSynthesizer, TVAESynthesizer, CTGANSynthesizer
# Number of training epochs; passed as ``epochs=`` to the synthesizers built
# by generate_synth_C and generate_synth_R.
EPOCHS = 5000
def generate_synth_C(
    data: pd.DataFrame,
    generator=TVAESynthesizer,
    n_samples: int = 3000,
) -> tuple[pd.DataFrame, dict]:
    """Fit a synthesizer on ``data`` and sample a synthetic table from it.

    The rows of ``data`` are shuffled, single-table metadata is detected from
    the shuffled frame, and the ``saldo`` column is marked as categorical
    before the synthesizer is fitted. Quality is then estimated by drawing
    100 synthetic tables of ``len(data)`` rows each and averaging three
    sdmetrics scores over those draws.

    Args:
        data: Real table to learn from; it must contain a ``saldo`` column.
        generator: Synthesizer class, instantiated as
            ``generator(metadata, epochs=EPOCHS, enforce_min_max_values=True,
            enforce_rounding=False)``.
        n_samples: Number of rows in the returned synthetic table.

    Returns:
        A ``(synthetic_table, mean_scores)`` pair. ``mean_scores`` maps
        ``'corr_sim'`` (CorrelationSimilarity, Pearson), ``'TV'``
        (TVComplement on ``saldo``) and ``'new_row'`` (NewRowSynthesis) to
        the mean of the respective score over the 100 evaluation draws.
    """
    # Shuffle the rows (sample with frac=1) before detecting metadata and fitting.
    shuffled = data.sample(frac=1)

    table_meta = SingleTableMetadata()
    table_meta.detect_from_dataframe(shuffled)
    # Override whatever type was detected for 'saldo': it is treated as categorical.
    table_meta.update_column(column_name='saldo', sdtype='categorical')

    model = generator(
        table_meta,
        epochs=EPOCHS,
        enforce_min_max_values=True,
        enforce_rounding=False,
    )
    model.fit(shuffled)

    # Repeat the evaluation on independent draws so the reported scores are
    # averages rather than the outcome of a single sample.
    column_names = shuffled.columns[:].values
    history = {'corr_sim': [], 'TV': [], 'new_row': []}
    for _ in range(100):
        draw = model.sample(num_rows=len(shuffled))

        # NOTE(review): CorrelationSimilarity comes from sdmetrics.column_pairs,
        # yet it is handed every column of both frames here — confirm that this
        # scores what is intended.
        corr_score = CorrelationSimilarity.compute(
            real_data=shuffled[column_names],
            synthetic_data=draw[column_names],
            coefficient='Pearson',
        )
        history['corr_sim'].append(corr_score)

        tv_score = TVComplement.compute(
            real_data=shuffled['saldo'],
            synthetic_data=draw['saldo'],
        )
        history['TV'].append(tv_score)

        novelty_score = NewRowSynthesis.compute(
            real_data=shuffled,
            synthetic_data=draw,
            metadata=table_meta,
            numerical_match_tolerance=0.1,
            synthetic_sample_size=20,
        )
        history['new_row'].append(novelty_score)

    averaged = {metric: np.mean(scores) for metric, scores in history.items()}

    # The table handed back to the caller is a fresh draw of the requested size.
    synthetic = model.sample(num_rows=n_samples)
    return synthetic, averaged
def generate_synth_R(
    data: pd.DataFrame,
    generator=TVAESynthesizer,
    n_samples: int = 3000,
) -> tuple[pd.DataFrame, dict]:
    """Fit a synthesizer on ``data`` and sample a synthetic table from it.

    The rows of ``data`` are shuffled and single-table metadata is detected
    from the shuffled frame (no column types are overridden) before the
    synthesizer is fitted. Quality is then estimated by drawing 100 synthetic
    tables of ``len(data)`` rows each and averaging two sdmetrics scores over
    those draws.

    Args:
        data: Real table to learn from.
        generator: Synthesizer class, instantiated as
            ``generator(metadata, epochs=EPOCHS, enforce_min_max_values=True,
            enforce_rounding=False)``.
        n_samples: Number of rows in the returned synthetic table.

    Returns:
        A ``(synthetic_table, mean_scores)`` pair. ``mean_scores`` maps
        ``'corr_sim'`` (CorrelationSimilarity, Pearson) and ``'new_row'``
        (NewRowSynthesis) to the mean of the respective score over the 100
        evaluation draws.
    """
    # Shuffle the rows (sample with frac=1) before detecting metadata and fitting.
    shuffled = data.sample(frac=1)

    table_meta = SingleTableMetadata()
    table_meta.detect_from_dataframe(shuffled)

    model = generator(
        table_meta,
        epochs=EPOCHS,
        enforce_min_max_values=True,
        enforce_rounding=False,
    )
    model.fit(shuffled)

    # Repeat the evaluation on independent draws so the reported scores are
    # averages rather than the outcome of a single sample.
    column_names = shuffled.columns[:].values
    history = {'corr_sim': [], 'new_row': []}
    for _ in range(100):
        draw = model.sample(num_rows=len(shuffled))

        # NOTE(review): CorrelationSimilarity comes from sdmetrics.column_pairs,
        # yet it is handed every column of both frames here — confirm that this
        # scores what is intended.
        corr_score = CorrelationSimilarity.compute(
            real_data=shuffled[column_names],
            synthetic_data=draw[column_names],
            coefficient='Pearson',
        )
        history['corr_sim'].append(corr_score)

        novelty_score = NewRowSynthesis.compute(
            real_data=shuffled,
            synthetic_data=draw,
            metadata=table_meta,
            numerical_match_tolerance=0.1,
            synthetic_sample_size=20,
        )
        history['new_row'].append(novelty_score)

    averaged = {metric: np.mean(scores) for metric, scores in history.items()}

    # The table handed back to the caller is a fresh draw of the requested size.
    synthetic = model.sample(num_rows=n_samples)
    return synthetic, averaged