IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed when using multi-column y input

40 views Asked by At

The input train_data_dict is a nested dictionary with a multi-column y-variable by design. All the columns in y are independent variables and should not be flattened.

I'm using optuna to find the optimal hyperparameters.

Error: IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

import optuna
# Define the search space for hyperparameters
def objective(trial, train_data_dict):
    rand_indices = []

    # Define hyperparameters to be optimized
    input_dim = trial.suggest_int('input_dim', 32, 256)
    latent_dim = trial.suggest_int('latent_dim', 32, 256)
    intermediate_dim = trial.suggest_int('intermediate_dim', 64, 512)
    dropout_rate = trial.suggest_float('dropout_rate', 0.1, 0.5)
    epsilon_std = trial.suggest_float('epsilon_std', 0.1, 2.0)
    l2_reg_vae = trial.suggest_float('l2_reg_vae', 1e-5, 1e-1)
    weight_factor = trial.suggest_float('weight_factor', 0.001, 0.1)
    epochs = trial.suggest_int('epochs', 10, 100)
    batch_size = trial.suggest_int('batch_size', 8, 128)
    n_clusters = trial.suggest_int('n_clusters', 2, 10)

    # K-fold Cross-validation
    kf = KFold(n_splits=5, shuffle=True, random_state=42)

    # Perform K-fold Cross-validation on the training set
    for omics_type, omics_dict in train_data_dict.items():
        for view_key, (data, y_var) in omics_dict.items():
            # Separate the inputs
            encoder_inputs = data.values

            # Ensure that y_input has the correct shape
            y_input = y_var.values

            # Convert encoder_inputs to DataFrame
            encoder_df = pd.DataFrame(encoder_inputs, columns=data.columns, index=data.index)

            # Store adjusted Rand indices for each fold
            fold_rand_indices = []

            # Perform K-fold Cross-validation
            for train_index, val_index in kf.split(encoder_df):
                x_train, x_val = encoder_df.iloc[train_index], encoder_df.iloc[val_index]
                y_train, y_val = y_input[train_index], y_input[val_index]

                # Create VAE, encoder, and decoder with current hyperparameters
                vae_model, _, _ = create_vae(
                    input_dim=x_train.shape[1],
                    latent_dim=latent_dim,
                    intermediate_dim=intermediate_dim,
                    dropout_rate=dropout_rate,
                    train_data_dict={view_key: (x_train, y_train)},
                    epsilon_std=epsilon_std,
                    reg_strength=l2_reg_vae,
                    weight_factor=weight_factor
                )
            
                # Get latent representations
                encoded_data = vae_model.predict([x_val])

                # Perform spectral clustering on the latent representations
                _, _, consensus_labels = supervised_multi_view_spectral_clustering([encoded_data], pd.DataFrame(y_val), n_clusters=n_clusters)

                # Calculate adjusted Rand index for each column of y_val
                fold_rand_indices.extend([adjusted_rand_score(y_val[:, col, np.newaxis], consensus_labels[:, col]) for col in range(y_val.shape[1])])

            # Take the average of the Rand indices for all columns in all folds
            fold_average_rand_index = np.mean(fold_rand_indices)
            rand_indices.append(fold_average_rand_index)

    # Compute the average Rand index over all folds and views
    average_rand_index = np.mean(rand_indices)

    # Return the average clustering accuracy over all folds as the objective value
    return average_rand_index

# Optimize hyperparameters using Optuna
study = optuna.create_study(direction='maximize')  # maximize the clustering accuracy

# Optimize the objective function by passing train_data_dict
study.optimize(lambda trial: objective(trial, train_data_dict), n_trials=50)

Traceback:

---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
Input In [113], in <cell line: 5>()
      2 study = optuna.create_study(direction='maximize')  # maximize the clustering accuracy
      4 # Optimize the objective function by passing train_data_dict
----> 5 study.optimize(lambda trial: objective(trial, train_data_dict), n_trials=50)

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/study.py:451, in Study.optimize(self, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
    348 def optimize(
    349     self,
    350     func: ObjectiveFuncType,
   (...)
    357     show_progress_bar: bool = False,
    358 ) -> None:
    359     """Optimize an objective function.
    360 
    361     Optimization is done by choosing a suitable set of hyperparameter values from a given
   (...)
    449             If nested invocation of this method occurs.
    450     """
--> 451     _optimize(
    452         study=self,
    453         func=func,
    454         n_trials=n_trials,
    455         timeout=timeout,
    456         n_jobs=n_jobs,
    457         catch=tuple(catch) if isinstance(catch, Iterable) else (catch,),
    458         callbacks=callbacks,
    459         gc_after_trial=gc_after_trial,
    460         show_progress_bar=show_progress_bar,
    461     )

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:66, in _optimize(study, func, n_trials, timeout, n_jobs, catch, callbacks, gc_after_trial, show_progress_bar)
     64 try:
     65     if n_jobs == 1:
---> 66         _optimize_sequential(
     67             study,
     68             func,
     69             n_trials,
     70             timeout,
     71             catch,
     72             callbacks,
     73             gc_after_trial,
     74             reseed_sampler_rng=False,
     75             time_start=None,
     76             progress_bar=progress_bar,
     77         )
     78     else:
     79         if n_jobs == -1:

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:163, in _optimize_sequential(study, func, n_trials, timeout, catch, callbacks, gc_after_trial, reseed_sampler_rng, time_start, progress_bar)
    160         break
    162 try:
--> 163     frozen_trial = _run_trial(study, func, catch)
    164 finally:
    165     # The following line mitigates memory problems that can be occurred in some
    166     # environments (e.g., services that use computing containers such as GitHub Actions).
    167     # Please refer to the following PR for further details:
    168     # https://github.com/optuna/optuna/pull/325.
    169     if gc_after_trial:

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:251, in _run_trial(study, func, catch)
    244         assert False, "Should not reach."
    246 if (
    247     frozen_trial.state == TrialState.FAIL
    248     and func_err is not None
    249     and not isinstance(func_err, catch)
    250 ):
--> 251     raise func_err
    252 return frozen_trial

File /scg/apps/software/jupyter/python_3.9/lib/python3.9/site-packages/optuna/study/_optimize.py:200, in _run_trial(study, func, catch)
    198 with get_heartbeat_thread(trial._trial_id, study._storage):
    199     try:
--> 200         value_or_values = func(trial)
    201     except exceptions.TrialPruned as e:
    202         # TODO(mamu): Handle multi-objective cases.
    203         state = TrialState.PRUNED

Input In [113], in <lambda>(trial)
      2 study = optuna.create_study(direction='maximize')  # maximize the clustering accuracy
      4 # Optimize the objective function by passing train_data_dict
----> 5 study.optimize(lambda trial: objective(trial, train_data_dict), n_trials=50)

Input In [111], in objective(trial, train_data_dict)
     56     _, _, consensus_labels = supervised_multi_view_spectral_clustering([encoded_data], pd.DataFrame(y_val), n_clusters=n_clusters)
     58     # Calculate adjusted Rand index for each column of y_val
---> 59     fold_rand_indices.extend([adjusted_rand_score(y_val[:, col, np.newaxis], consensus_labels[:, col]) for col in range(y_val.shape[1])])
     61 # Take the average of the Rand indices for all columns in all folds
     62 fold_average_rand_index = np.mean(fold_rand_indices)

Input In [111], in <listcomp>(.0)
     56     _, _, consensus_labels = supervised_multi_view_spectral_clustering([encoded_data], pd.DataFrame(y_val), n_clusters=n_clusters)
     58     # Calculate adjusted Rand index for each column of y_val
---> 59     fold_rand_indices.extend([adjusted_rand_score(y_val[:, col, np.newaxis], consensus_labels[:, col]) for col in range(y_val.shape[1])])
     61 # Take the average of the Rand indices for all columns in all folds
     62 fold_average_rand_index = np.mean(fold_rand_indices)

IndexError: too many indices for array: array is 1-dimensional, but 2 were indexed

Input: train_data_dict

{'transcriptomics': {'transcriptomics_df': (             gene_1    gene_2    gene_3    gene_4    gene_5    gene_6  \
   sample_1   0.081271  0.285246  0.238980  0.109389  0.088941  0.070102   
   sample_12  0.690177  0.010544  0.649123  0.077980  0.791253  0.173279   
   sample_24  0.762251  0.957225  0.222334  0.812339  0.198768  0.510893   
   sample_14  0.380729  0.799106  0.678803  0.489014  0.561917  0.064196   
   
                gene_7    gene_8    gene_9   gene_10  ...   gene_41   gene_42  \
   sample_1   0.607299  0.220408  1.000000  0.959517  ...  0.553930  0.510626   
   sample_12  0.444133  0.239353  0.463949  0.241307  ...  0.585237  0.526544   
   sample_24  0.063147  0.364753  0.320597  0.933455  ...  0.101972  0.805806   
   sample_14  0.143937  0.719831  0.257680  0.295210  ...  0.152115  0.124360   
   
               gene_43   gene_44   gene_45   gene_46   gene_47   gene_48  \
   sample_1   0.749899  0.060830  1.000000  0.159956  0.293132  0.472478   
   sample_12  0.221148  0.144944  0.706477  0.828765  0.392234  0.806104   
   sample_24  1.000000  0.609475  0.858978  0.613737  0.876306  0.064037   
   sample_14  0.650961  0.792480  0.870467  0.183074  1.000000  0.092793   
   
               gene_49   gene_50  
   sample_1   0.307162  0.121577  
   sample_12  0.131183  0.025633  
   sample_24  0.627487  0.221973  
   sample_14  0.898939  0.464103  
   
   [4 rows x 50 columns],
              Overall_Survival  Immune_Response
   sample_1                  0                1
   sample_12                 1                0
   sample_24                 0                1
   sample_14                 1                1),
  'mrna_deconv': (           mrna_cell_type_1  mrna_cell_type_2  mrna_cell_type_3  \
   sample_1           0.259693          0.701656          0.357655   
   sample_12          0.890338          0.846845          0.739739   
   sample_24          0.065137          0.109415          0.160966   
   sample_14          0.669036          0.985836          0.000000   
   
              mrna_cell_type_4  mrna_cell_type_5  mrna_cell_type_6  \
   sample_1           0.147154          0.248317          0.515681   
   sample_12          0.850631          0.443040          0.468561   
   sample_24          0.000000          0.262814          0.259445   
   sample_14          0.381551          0.928021          0.619240   
   
              mrna_cell_type_7  mrna_cell_type_8  mrna_cell_type_9  \
   sample_1           0.397258          0.060051          1.000000   
   sample_12          0.809078          0.354320          0.513866   
   sample_24          1.000000          0.259792          0.442548   
   sample_14          0.727470          0.707928          0.572115   
   
              mrna_cell_type_10  mrna_cell_type_11  mrna_cell_type_12  \
   sample_1            0.645953           0.080449           0.634015   
   sample_12           0.954660           0.441691           0.780304   
   sample_24           0.934251           0.628559           0.832609   
   sample_14           0.082558           0.494718           0.855966   
   
              mrna_cell_type_13  mrna_cell_type_14  mrna_cell_type_15  \
   sample_1            0.693858           0.810538           0.462534   
   sample_12           0.233477           0.802866           0.675646   
   sample_24           0.435969           0.422486           0.274812   
   sample_14           0.941884           0.718656           0.912517   
   
              mrna_cell_type_16  mrna_cell_type_17  mrna_cell_type_18  \
   sample_1            0.771869           0.055841           0.747881   
   sample_12           0.901962           0.571642           0.137232   
   sample_24           0.123668           0.787283           0.236869   
   sample_14           0.556080           0.608710           0.485642   
   
              mrna_cell_type_19  mrna_cell_type_20  
   sample_1            0.690401            0.00000  
   sample_12           0.804235            1.00000  
   sample_24           0.835204            0.43159  
   sample_14           1.000000            0.25220  ,
              Overall_Survival  Immune_Response
   sample_1                  0                1
   sample_12                 1                0
   sample_24                 0                1
   sample_14                 1                1)},
 'epigenomics': {'epigenomics_df': (           methyl_1  methyl_2  methyl_3  methyl_4  methyl_5  methyl_6  \
   sample_1   0.518432  0.562772  1.000000  0.574512  0.197689  0.013935   
   sample_12  0.893894  0.778101  0.238617  0.689478  1.000000  0.413729   
   sample_24  0.443299  0.763554  0.675737  0.918801  0.616425  0.625535   
   sample_14  0.849289  0.071325  0.461540  0.204233  0.683559  0.000000   
   
              methyl_7  methyl_8  methyl_9  methyl_10  ...  methyl_21  methyl_22  \
   sample_1   0.579390  0.170915  0.741125   0.345231  ...   0.634699   0.918234   
   sample_12  0.765269  0.414511  0.952738   0.835907  ...   1.000000   0.606318   
   sample_24  0.935773  0.348954  0.815997   0.765416  ...   0.589030   0.400325   
   sample_14  0.732785  0.431157  0.194957   0.018413  ...   0.379741   0.222882   
   
              methyl_23  methyl_24  methyl_25  methyl_26  methyl_27  methyl_28  \
   sample_1    0.942603   0.824346   0.332977   1.000000   0.184222   0.399519   
   sample_12   0.182254   0.293121   0.234942   0.003985   0.847842   0.635192   
   sample_24   0.637001   0.575127   0.223733   0.497607   0.893084   0.105603   
   sample_14   0.666715   0.573543   0.906313   0.116190   0.255035   0.488607   
   
              methyl_29  methyl_30  
   sample_1    0.561876   0.551404  
   sample_12   0.000847   0.431484  
   sample_24   0.708240   0.544635  
   sample_14   0.431328   0.646670  
   
   [4 rows x 30 columns],
              Overall_Survival  Immune_Response
   sample_1                  0                1
   sample_12                 1                0
   sample_24                 0                1
   sample_14                 1                1),
  'meth_deconv': (           meth_cell_type_1  meth_cell_type_2  meth_cell_type_3  \
   sample_1           0.790354          0.905251          0.784297   
   sample_12          0.465173          0.633696          0.640522   
   sample_24          0.607869          0.902261          0.207588   
   sample_14          0.757622          0.031540          0.933623   
   
              meth_cell_type_4  meth_cell_type_5  meth_cell_type_6  \
   sample_1           0.655590          0.843466          0.813060   
   sample_12          0.110660          0.892560          0.098858   
   sample_24          0.379688          0.767904          0.075993   
   sample_14          0.000000          0.576410          0.783901   
   
              meth_cell_type_7  meth_cell_type_8  meth_cell_type_9  \
   sample_1           0.310647          0.294450          0.275912   
   sample_12          1.000000          0.666482          0.000000   
   sample_24          0.483003          0.988328          0.743239   
   sample_14          0.458792          0.486113          0.341735   
   
              meth_cell_type_10  meth_cell_type_11  meth_cell_type_12  \
   sample_1            0.341193           0.542600           0.191132   
   sample_12           0.482228           0.412591           0.703000   
   sample_24           0.102253           0.696404           0.776406   
   sample_14           0.315953           0.146682           0.384930   
   
              meth_cell_type_13  meth_cell_type_14  meth_cell_type_15  \
   sample_1            0.713090           0.361978           0.414776   
   sample_12           0.144266           0.473406           0.000000   
   sample_24           0.261975           0.838333           0.593856   
   sample_14           0.292408           0.990236           0.387011   
   
              meth_cell_type_16  meth_cell_type_17  meth_cell_type_18  \
   sample_1            0.022376           0.629440           0.949370   
   sample_12           0.689384           0.072314           0.021034   
   sample_24           0.512479           0.579598           0.290853   
   sample_14           0.161483           0.229933           0.562100   
   
              meth_cell_type_19  meth_cell_type_20  
   sample_1            0.618516           0.161318  
   sample_12           0.381948           0.201290  
   sample_24           0.659075           0.658301  
   sample_14           0.701604           0.566443  ,
              Overall_Survival  Immune_Response
   sample_1                  0                1
   sample_12                 1                0
   sample_24                 0                1
   sample_14                 1                1)},
 'proteomics': {'proteomics_df': (           protein_1  protein_2  protein_3  protein_4  protein_5  protein_6  \
   sample_1    0.434099   0.568244   0.955054   0.515010   0.275183   1.000000   
   sample_12   0.000000   0.621718   0.685375   0.148019   0.000000   0.853397   
   sample_24   1.000000   0.066931   0.357629   0.229160   1.000000   0.273934   
   sample_14   0.321337   0.231934   0.583539   0.830817   0.846940   0.318258   
   
              protein_7  protein_8  protein_9  protein_10  protein_11  \
   sample_1    1.000000   0.556124   0.062139    0.325602    0.193152   
   sample_12   0.406860   0.000000   0.522786    0.020663    0.845124   
   sample_24   0.396358   0.929895   0.639372    0.208086    0.025799   
   sample_14   0.905107   0.502941   0.507171    1.000000    0.781555   
   
              protein_12  protein_13  protein_14  protein_15  protein_16  \
   sample_1     0.134788    0.514226    0.461197    0.077336    0.434350   
   sample_12    0.120330    0.852330    0.000000    0.949000    0.603237   
   sample_24    0.446697    0.161775    0.517867    0.876150    0.110444   
   sample_14    0.386148    0.486579    1.000000    0.014181    0.171252   
   
              protein_17  protein_18  protein_19  protein_20  
   sample_1     0.109298    0.530479    0.572973    0.292885  
   sample_12    0.207121    1.000000    0.271716    0.854000  
   sample_24    0.221096    0.738751    0.555530    0.332737  
   sample_14    0.964332    0.341878    0.820894    0.095262  ,
              Overall_Survival  Immune_Response
   sample_1                  0                1
   sample_12                 1                0
   sample_24                 0                1
   sample_14                 1                1)},
 'mutation': {'mutation_df': (           gene_1  gene_2  gene_3  gene_4  gene_5  gene_6  gene_7  gene_8  \
   sample_1        1       1       0       1       1       0       0       0   
   sample_12       1       1       0       0       0       0       0       1   
   sample_24       0       1       0       1       1       0       1       0   
   sample_14       0       1       0       1       0       0       1       1   
   
              gene_9  gene_10  gene_11  gene_12  gene_13  gene_14  gene_15  \
   sample_1        1        0        1        0        0        0        0   
   sample_12       1        0        1        1        0        1        0   
   sample_24       1        1        1        0        1        1        0   
   sample_14       0        1        0        0        1        1        0   
   
              gene_16  gene_17  gene_18  gene_19  gene_20  
   sample_1         0        1        1        1        1  
   sample_12        0        0        0        1        1  
   sample_24        0        0        0        1        0  
   sample_14        1        0        0        1        0  ,
              Overall_Survival  Immune_Response
   sample_1                  0                1
   sample_12                 1                0
   sample_24                 0                1
   sample_14                 1                1)}} 
0

There are 0 answers