How to implement PySpark StandardScaler on subset of columns?


I want to use PySpark's StandardScaler on 6 of the 10 columns in my dataframe. This will be part of a pipeline.

The inputCol parameter seems to expect a vector, which I can pass in after using VectorAssembler on all my features, but this scales all 10 features. I don’t want to scale the other 4 features because they are binary and I want unstandardized coefficients for them.

Am I supposed to use vector assembler on the 6 features, scale them, then use vector assembler again on this scaled features vector and the remaining 4 features? I would end up with a vector within a vector and I’m not sure this will work.

What’s the right way to do this? An example is appreciated.

Answer by Sreeram TP (best answer)

You can do this with VectorAssembler. The key is to assemble only the columns you want to scale, scale that vector, and then extract the scaled values back into individual columns. See the code below for a working example:

from pyspark.sql import SparkSession
from pyspark.ml.feature import StandardScaler, VectorAssembler
import pandas as pd
import random

# Use a SparkSession to create the Spark DataFrame (SparkContext has no createDataFrame)
spark = SparkSession.builder.getOrCreate()

# Build a small pandas DataFrame of random integers as sample data
df = pd.DataFrame()
df['a'] = random.sample(range(100), 10)
df['b'] = random.sample(range(100), 10)
df['c'] = random.sample(range(100), 10)
df['d'] = random.sample(range(100), 10)
df['e'] = random.sample(range(100), 10)

sdf = spark.createDataFrame(df)

sdf.show()

+---+---+---+---+---+
|  a|  b|  c|  d|  e|
+---+---+---+---+---+
| 51| 13|  6|  5| 26|
| 18| 29| 19| 81| 28|
| 34|  1| 36| 57| 87|
| 56| 86| 51| 52| 48|
| 36| 49| 33| 15| 54|
| 87| 53| 47| 89| 85|
|  7| 14| 55| 13| 98|
| 70| 50| 32| 39| 58|
| 80| 20| 25| 54| 37|
| 40| 33| 44| 83| 27|
+---+---+---+---+---+

cols_to_scale = ['c', 'd', 'e']
cols_to_keep_unscaled = ['a', 'b']

# Assemble only the columns to be scaled into a single vector column
assembler = VectorAssembler(inputCols=cols_to_scale, outputCol="features")
sdf_transformed = assembler.transform(sdf)

# Defaults: withStd=True, withMean=False (scale to unit standard deviation, no centering)
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")
scaler_model = scaler.fit(sdf_transformed.select("features"))
sdf_scaled = scaler_model.transform(sdf_transformed)

sdf_scaled.show()

+---+---+---+---+---+----------------+--------------------+
|  a|  b|  c|  d|  e|        features|      scaledFeatures|
+---+---+---+---+---+----------------+--------------------+
| 51| 13|  6|  5| 26|  [6.0,5.0,26.0]|[0.39358015146628...|
| 18| 29| 19| 81| 28|[19.0,81.0,28.0]|[1.24633714630991...|
| 34|  1| 36| 57| 87|[36.0,57.0,87.0]|[2.36148090879773...|
| 56| 86| 51| 52| 48|[51.0,52.0,48.0]|[3.34543128746345...|
| 36| 49| 33| 15| 54|[33.0,15.0,54.0]|[2.16469083306459...|
| 87| 53| 47| 89| 85|[47.0,89.0,85.0]|[3.08304451981926...|
|  7| 14| 55| 13| 98|[55.0,13.0,98.0]|[3.60781805510765...|
| 70| 50| 32| 39| 58|[32.0,39.0,58.0]|[2.09909414115354...|
| 80| 20| 25| 54| 37|[25.0,54.0,37.0]|[1.63991729777620...|
| 40| 33| 44| 83| 27|[44.0,83.0,27.0]|[2.88625444408612...|
+---+---+---+---+---+----------------+--------------------+
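Note that the scaled values above are not centered. With StandardScaler's default settings (withStd=True, withMean=False), each column is only divided by its sample standard deviation. As a quick sanity check, here is a minimal pure-Python sketch that reproduces the scaled values of column c from the table; the variable names are illustrative only.

```python
from statistics import stdev

# Column c from the sample data shown above
c = [6, 19, 36, 51, 33, 47, 55, 32, 25, 44]

# statistics.stdev is the sample standard deviation (n - 1 denominator)
c_std = stdev(c)

# No mean subtraction here, mirroring withMean=False
c_scaled = [x / c_std for x in c]

# Should agree with the first entry of scaledFeatures in the first row (about 0.3936)
print(c_scaled[0])
```

If centered values are wanted as well, StandardScaler can be constructed with withMean=True.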

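# Flatten each row into a tuple: the unscaled columns followed by the scaled values
def extract(row):
    return tuple(row[c] for c in cols_to_keep_unscaled) + tuple(row.scaledFeatures.toArray().tolist())

# Keep the unscaled columns and unpack the scaled vector back into named columns
sdf_scaled = sdf_scaled.select(*cols_to_keep_unscaled, "scaledFeatures").rdd \
        .map(extract).toDF(cols_to_keep_unscaled + cols_to_scale)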
  
sdf_scaled.show()


+---+---+------------------+-------------------+------------------+
|  a|  b|                 c|                  d|                 e|
+---+---+------------------+-------------------+------------------+
| 51| 13|0.3935801514662892|0.16399957083190683|0.9667572801316145|
| 18| 29| 1.246337146309916|  2.656793047476891|1.0411232247571234|
| 34|  1|2.3614809087977355| 1.8695951074837378|3.2349185912096337|
| 56| 86|3.3454312874634584| 1.7055955366518312|1.7847826710122114|
| 36| 49| 2.164690833064591|0.49199871249572047| 2.007880504888738|
| 87| 53| 3.083044519819266| 2.9191923608079415|3.1605526465841245|
|  7| 14|3.6078180551076513| 0.4263988841629578| 3.643931286649932|
| 70| 50|2.0990941411535426| 1.2791966524888734|2.1566123941397555|
| 80| 20| 1.639917297776205| 1.7711953649845937| 1.375769975571913|
| 40| 33|2.8862544440861213| 2.7223928758096534| 1.003940252444369|
+---+---+------------------+-------------------+------------------+