The input train_data_dict is a nested dictionary with a multi-column y-variable by design.
All the columns in y are independent variables and should not be flattened.
I'm using optuna to find the optimal hyperparameters.
Error: IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed
import optuna
# Define the search space for hyperparameters
def objective(trial, train_data_dict):
rand_indices = []
# Define hyperparameters to be optimized
input_dim = trial.suggest_int('input_dim', 32, 256)
latent_dim = trial.suggest_int('latent_dim', 32, 256)
intermediate_dim = trial.suggest_int('intermediate_dim', 64, 512)
dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
epsilon_std = trial.suggest_float('epsilon_std', 0.1, 2.0)
l2_reg_vae = trial.suggest_float('l2_reg_vae', 1e-5, 1e-1)
weight_factor = trial.suggest_float('weight_factor', 0.001, 0.1)
epochs = trial.suggest_int('epochs', 10, 100)
batch_size = trial.suggest_int('batch_size', 8, 128)
n_clusters = trial.suggest_int('n_clusters', 2, 10)
# K-fold Cross-validation
kf = KFold(n_splits=5, shuffle=True, random_state=42)
# Perform K-fold Cross-validation on the training set
for omics_type, omics_dict in train_data_dict.items():
for view_key, (data, y_var) in omics_dict.items():
# Separate the inputs
encoder_inputs = data.values
# Ensure that y_input has the correct shape
y_input = y_var.values
# Convert encoder_inputs to DataFrame
encoder_df = pd.DataFrame(encoder_inputs, columns=data.columns, index=data.index)
# Store adjusted Rand indices for each fold
fold_rand_indices = []
# Perform K-fold Cross-validation
for train_index, val_index in kf.split(encoder_df):
x_train, x_val = encoder_df.iloc[train_index], encoder_df.iloc[val_index]
y_train, y_val = y_input[train_index], y_input[val_index]
# Create VAE, encoder, and decoder with current hyperparameters
vae_model, _, _ = create_vae(
input_dim=x_train.shape[1],
latent_dim=latent_dim,
intermediate_dim=intermediate_dim,
dropout_rate=dropout_rate,
train_data_dict={view_key: (x_train, y_train)},
epsilon_std=epsilon_std,
reg_strength=l2_reg_vae,
weight_factor=weight_factor
)
# Get latent representations
encoded_data = vae_model.predict([x_val])
# Perform spectral clustering on the latent representations
_, _, consensus_labels = supervised_multi_view_spectral_clustering([encoded_data], pd.DataFrame(y_val), n_clusters=n_clusters)
# Calculate adjusted Rand index for each column of y_val
fold_rand_indices.extend([adjusted_rand_score(y_val[:, col, np.newaxis], consensus_labels[:, col]) for col in range(y_val.shape[1])])
# Take the average of the Rand indices for all columns in all folds
fold_average_rand_index = np.mean(fold_rand_indices)
rand_indices.append(fold_average_rand_index)
# Compute the average Rand index over all folds and views
average_rand_index = np.mean(rand_indices)
# Return the average clustering accuracy over all folds as the objective value
return average_rand_index
# Optimize hyperparameters using Optuna
study = optuna.create_study(direction='maximize') # maximize the clustering accuracy
# Optimize the objective function by passing train_data_dict
study.optimize(lambda trial: objective(trial, train_data_dict), n_trials=50)
Traceback:
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
Input In [113], in <cell line: 5>()
2 study = optuna.create_study(direction='maximize') # maximize the clustering accuracy
4 # Optimize the objective function by passing train_data_dict
----> 5 study.optimize(lambda trial: objective(trial, train_data_dict), n_trials=50)
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/study.py:451, in Study.optimize(self, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
348 def optimize(
349 self,
350 func: ObjectiveFuncType,
(...)
357 show_progress_bar: bool = False,
358 ) -> None:
359 """Optimize an objective function.
360
361 Optimization is done by choosing a suitable set of hyperparameter values from a given
(...)
449 If nested invocation of this method occurs.
450 """
--> 451 _optimize(
452 study=self,
453 func=func,
454 n_trials=n_trials,
455 timeout=timeout,
456 n_jobs=n_jobs,
457 catch=tuple(catch) if isinstance(catch, Iterable) else (catch,),
458 callbacks=callbacks,
459 gc_after_trial=gc_after_trial,
460 show_progress_bar=show_progress_bar,
461 )
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:66, in _optimize(study, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
64 try:
65 if n_jobs == 1:
---> 66 _optimize_sequential(
67 study,
68 func,
69 n_trials,
70 timeout,
71 catch,
72 callbacks,
73 gc_after_trial,
74 reseed_sampler_rng=False,
75 time_start=None,
76 progress_bar=progress_bar,
77 )
78 else:
79 if n_jobs == -1:
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:163, in _optimize_sequential(study, func, n_trials, timeout, catch, callbacks, gc_after_trial, reseed_sampler_rng, time_start, progress_bar)
160 break
162 try:
--> 163 frozen_trial = _run_trial(study, func, catch)
164 finally:
165 # The following line mitigates memory problems that can be occurred in some
166 # environments (e.g., services that use computing containers such as GitHub Actions).
167 # Please refer to the following PR for further details:
168 # https://github.com/optuna/optuna/pull/325.
169 if gc_after_trial:
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:251, in _run_trial(study, func, catch)
244 assert False, "Should not reach."
246 if (
247 frozen_trial.state == TrialState.FAIL
248 and func_err is not None
249 and not isinstance(func_err, catch)
250 ):
--> 251 raise func_err
252 return frozen_trial
File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:200, in _run_trial(study, func, catch)
198 with get_heartbeat_thread(trial._trial_id, study._storage):
199 try:
--> 200 value_or_values = func(trial)
201 except exceptions.TrialPruned as e:
202 # TODO(mamu): Handle multi-objective cases.
203 state = TrialState.PRUNED
Input In [113], in <lambda>(trial)
2 study = optuna.create_study(direction='maximize') # maximize the clustering accuracy
4 # Optimize the objective function by passing train_data_dict
----> 5 study.optimize(lambda trial: objective(trial, train_data_dict), n_trials=50)
Input In [111], in objective(trial, train_data_dict)
56 _, _, consensus_labels = supervised_multi_view_spectral_clustering([encoded_data], pd.DataFrame(y_val), n_clusters=n_clusters)
58 # Calculate adjusted Rand index for each column of y_val
---> 59 fold_rand_indices.extend([adjusted_rand_score(y_val[:, col, np.newaxis], consensus_labels[:, col]) for col in range(y_val.shape[1])])
61 # Take the average of the Rand indices for all columns in all folds
62 fold_average_rand_index = np.mean(fold_rand_indices)
Input In [111], in <listcomp>(.0)
56 _, _, consensus_labels = supervised_multi_view_spectral_clustering([encoded_data], pd.DataFrame(y_val), n_clusters=n_clusters)
58 # Calculate adjusted Rand index for each column of y_val
---> 59 fold_rand_indices.extend([adjusted_rand_score(y_val[:, col, np.newaxis], consensus_labels[:, col]) for col in range(y_val.shape[1])])
61 # Take the average of the Rand indices for all columns in all folds
62 fold_average_rand_index = np.mean(fold_rand_indices)
IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed
Input:
train_data_dict
{'transcriptomics': {'transcriptomics_df': ( gene_1 gene_2 gene_3 gene_4 gene_5 gene_6 \
sample_1 0.081271 0.285246 0.238980 0.109389 0.088941 0.070102
sample_12 0.690177 0.010544 0.649123 0.077980 0.791253 0.173279
sample_24 0.762251 0.957225 0.222334 0.812339 0.198768 0.510893
sample_14 0.380729 0.799106 0.678803 0.489014 0.561917 0.064196
gene_7 gene_8 gene_9 gene_10 ... gene_41 gene_42 \
sample_1 0.607299 0.220408 1.000000 0.959517 ... 0.553930 0.510626
sample_12 0.444133 0.239353 0.463949 0.241307 ... 0.585237 0.526544
sample_24 0.063147 0.364753 0.320597 0.933455 ... 0.101972 0.805806
sample_14 0.143937 0.719831 0.257680 0.295210 ... 0.152115 0.124360
gene_43 gene_44 gene_45 gene_46 gene_47 gene_48 \
sample_1 0.749899 0.060830 1.000000 0.159956 0.293132 0.472478
sample_12 0.221148 0.144944 0.706477 0.828765 0.392234 0.806104
sample_24 1.000000 0.609475 0.858978 0.613737 0.876306 0.064037
sample_14 0.650961 0.792480 0.870467 0.183074 1.000000 0.092793
gene_49 gene_50
sample_1 0.307162 0.121577
sample_12 0.131183 0.025633
sample_24 0.627487 0.221973
sample_14 0.898939 0.464103
[4 rows x 50 columns],
Overall_Survival Immune_Response
sample_1 0 1
sample_12 1 0
sample_24 0 1
sample_14 1 1),
'mrna_deconv': ( mrna_cell_type_1 mrna_cell_type_2 mrna_cell_type_3 \
sample_1 0.259693 0.701656 0.357655
sample_12 0.890338 0.846845 0.739739
sample_24 0.065137 0.109415 0.160966
sample_14 0.669036 0.985836 0.000000
mrna_cell_type_4 mrna_cell_type_5 mrna_cell_type_6 \
sample_1 0.147154 0.248317 0.515681
sample_12 0.850631 0.443040 0.468561
sample_24 0.000000 0.262814 0.259445
sample_14 0.381551 0.928021 0.619240
mrna_cell_type_7 mrna_cell_type_8 mrna_cell_type_9 \
sample_1 0.397258 0.060051 1.000000
sample_12 0.809078 0.354320 0.513866
sample_24 1.000000 0.259792 0.442548
sample_14 0.727470 0.707928 0.572115
mrna_cell_type_10 mrna_cell_type_11 mrna_cell_type_12 \
sample_1 0.645953 0.080449 0.634015
sample_12 0.954660 0.441691 0.780304
sample_24 0.934251 0.628559 0.832609
sample_14 0.082558 0.494718 0.855966
mrna_cell_type_13 mrna_cell_type_14 mrna_cell_type_15 \
sample_1 0.693858 0.810538 0.462534
sample_12 0.233477 0.802866 0.675646
sample_24 0.435969 0.422486 0.274812
sample_14 0.941884 0.718656 0.912517
mrna_cell_type_16 mrna_cell_type_17 mrna_cell_type_18 \
sample_1 0.771869 0.055841 0.747881
sample_12 0.901962 0.571642 0.137232
sample_24 0.123668 0.787283 0.236869
sample_14 0.556080 0.608710 0.485642
mrna_cell_type_19 mrna_cell_type_20
sample_1 0.690401 0.00000
sample_12 0.804235 1.00000
sample_24 0.835204 0.43159
sample_14 1.000000 0.25220 ,
Overall_Survival Immune_Response
sample_1 0 1
sample_12 1 0
sample_24 0 1
sample_14 1 1)},
'epigenomics': {'epigenomics_df': ( methyl_1 methyl_2 methyl_3 methyl_4 methyl_5 methyl_6 \
sample_1 0.518432 0.562772 1.000000 0.574512 0.197689 0.013935
sample_12 0.893894 0.778101 0.238617 0.689478 1.000000 0.413729
sample_24 0.443299 0.763554 0.675737 0.918801 0.616425 0.625535
sample_14 0.849289 0.071325 0.461540 0.204233 0.683559 0.000000
methyl_7 methyl_8 methyl_9 methyl_10 ... methyl_21 methyl_22 \
sample_1 0.579390 0.170915 0.741125 0.345231 ... 0.634699 0.918234
sample_12 0.765269 0.414511 0.952738 0.835907 ... 1.000000 0.606318
sample_24 0.935773 0.348954 0.815997 0.765416 ... 0.589030 0.400325
sample_14 0.732785 0.431157 0.194957 0.018413 ... 0.379741 0.222882
methyl_23 methyl_24 methyl_25 methyl_26 methyl_27 methyl_28 \
sample_1 0.942603 0.824346 0.332977 1.000000 0.184222 0.399519
sample_12 0.182254 0.293121 0.234942 0.003985 0.847842 0.635192
sample_24 0.637001 0.575127 0.223733 0.497607 0.893084 0.105603
sample_14 0.666715 0.573543 0.906313 0.116190 0.255035 0.488607
methyl_29 methyl_30
sample_1 0.561876 0.551404
sample_12 0.000847 0.431484
sample_24 0.708240 0.544635
sample_14 0.431328 0.646670
[4 rows x 30 columns],
Overall_Survival Immune_Response
sample_1 0 1
sample_12 1 0
sample_24 0 1
sample_14 1 1),
'meth_deconv': ( meth_cell_type_1 meth_cell_type_2 meth_cell_type_3 \
sample_1 0.790354 0.905251 0.784297
sample_12 0.465173 0.633696 0.640522
sample_24 0.607869 0.902261 0.207588
sample_14 0.757622 0.031540 0.933623
meth_cell_type_4 meth_cell_type_5 meth_cell_type_6 \
sample_1 0.655590 0.843466 0.813060
sample_12 0.110660 0.892560 0.098858
sample_24 0.379688 0.767904 0.075993
sample_14 0.000000 0.576410 0.783901
meth_cell_type_7 meth_cell_type_8 meth_cell_type_9 \
sample_1 0.310647 0.294450 0.275912
sample_12 1.000000 0.666482 0.000000
sample_24 0.483003 0.988328 0.743239
sample_14 0.458792 0.486113 0.341735
meth_cell_type_10 meth_cell_type_11 meth_cell_type_12 \
sample_1 0.341193 0.542600 0.191132
sample_12 0.482228 0.412591 0.703000
sample_24 0.102253 0.696404 0.776406
sample_14 0.315953 0.146682 0.384930
meth_cell_type_13 meth_cell_type_14 meth_cell_type_15 \
sample_1 0.713090 0.361978 0.414776
sample_12 0.144266 0.473406 0.000000
sample_24 0.261975 0.838333 0.593856
sample_14 0.292408 0.990236 0.387011
meth_cell_type_16 meth_cell_type_17 meth_cell_type_18 \
sample_1 0.022376 0.629440 0.949370
sample_12 0.689384 0.072314 0.021034
sample_24 0.512479 0.579598 0.290853
sample_14 0.161483 0.229933 0.562100
meth_cell_type_19 meth_cell_type_20
sample_1 0.618516 0.161318
sample_12 0.381948 0.201290
sample_24 0.659075 0.658301
sample_14 0.701604 0.566443 ,
Overall_Survival Immune_Response
sample_1 0 1
sample_12 1 0
sample_24 0 1
sample_14 1 1)},
'proteomics': {'proteomics_df': ( protein_1 protein_2 protein_3 protein_4 protein_5 protein_6 \
sample_1 0.434099 0.568244 0.955054 0.515010 0.275183 1.000000
sample_12 0.000000 0.621718 0.685375 0.148019 0.000000 0.853397
sample_24 1.000000 0.066931 0.357629 0.229160 1.000000 0.273934
sample_14 0.321337 0.231934 0.583539 0.830817 0.846940 0.318258
protein_7 protein_8 protein_9 protein_10 protein_11 \
sample_1 1.000000 0.556124 0.062139 0.325602 0.193152
sample_12 0.406860 0.000000 0.522786 0.020663 0.845124
sample_24 0.396358 0.929895 0.639372 0.208086 0.025799
sample_14 0.905107 0.502941 0.507171 1.000000 0.781555
protein_12 protein_13 protein_14 protein_15 protein_16 \
sample_1 0.134788 0.514226 0.461197 0.077336 0.434350
sample_12 0.120330 0.852330 0.000000 0.949000 0.603237
sample_24 0.446697 0.161775 0.517867 0.876150 0.110444
sample_14 0.386148 0.486579 1.000000 0.014181 0.171252
protein_17 protein_18 protein_19 protein_20
sample_1 0.109298 0.530479 0.572973 0.292885
sample_12 0.207121 1.000000 0.271716 0.854000
sample_24 0.221096 0.738751 0.555530 0.332737
sample_14 0.964332 0.341878 0.820894 0.095262 ,
Overall_Survival Immune_Response
sample_1 0 1
sample_12 1 0
sample_24 0 1
sample_14 1 1)},
'mutation': {'mutation_df': ( gene_1 gene_2 gene_3 gene_4 gene_5 gene_6 gene_7 gene_8 \
sample_1 1 1 0 1 1 0 0 0
sample_12 1 1 0 0 0 0 0 1
sample_24 0 1 0 1 1 0 1 0
sample_14 0 1 0 1 0 0 1 1
gene_9 gene_10 gene_11 gene_12 gene_13 gene_14 gene_15 \
sample_1 1 0 1 0 0 0 0
sample_12 1 0 1 1 0 1 0
sample_24 1 1 1 0 1 1 0
sample_14 0 1 0 0 1 1 0
gene_16 gene_17 gene_18 gene_19 gene_20
sample_1 0 1 1 1 1
sample_12 0 0 0 1 1
sample_24 0 0 0 1 0
sample_14 1 0 0 1 0 ,
Overall_Survival Immune_Response
sample_1 0 1
sample_12 1 0
sample_24 0 1
sample_14 1 1)}}