Pandas rolling window: developing rule based on window values

I'm working on a neonatal project. In short, neonates are assigned a score based on the symptoms they show at a given time point, and based on how their scores change over time, we decide whether to increase the medicine dosage, keep it the same, or wean them off. We denote these 3 states numerically as +1 (increase), 0 (maintain), or -1 (wean). The rules for deciding what to do are as follows:

  • Increase dosage if sum of 3 consecutive scores >= 24 OR a single score is >= 12.
  • Maintain dose if neither the increase rule nor the decrease rule is met.
  • Lower dose if there's at least 48 hours without needing to increase dose, the sum of the 3 most recent scores is <18, AND no single score is >8.

With help from people here, we have code that accounts for increasing dosages and maintaining dosages. However, I'm struggling to write the rule to determine how to lower dosages. Here's a sample of code we have:

import pandas as pd

df = pd.DataFrame({
   'baby': ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B','B', 'B', 'B', 'B', 'B','B','B'],
   'dateandtime':  ['8/2/2009  5:00:00 PM', '7/19/2009  5:00:00 PM', '7/19/2009  5:00:00 PM', '7/17/2009  6:00:00 AM','7/17/2009  12:01:00 AM', '7/14/2009  12:01:00 AM', '7/19/2009  5:00:00 AM', '7/16/2009  9:00:00 PM','7/19/2009  9:00:00 AM', '7/14/2009  6:00:00 PM', '7/15/2009  3:04:00 PM', '7/20/2009  5:00:00 PM','7/16/2009  12:01:00 AM', '7/18/2009  1:00:00 PM', '7/16/2009  6:00:00 AM', '7/13/2009  9:00:00 PM','7/19/2009  1:00:00 AM','7/15/2009  12:04:00 AM'],
   'score':  [6, 3, 3, 5, 10, 14, 5, 4, 11, 4, 4, 6, 7, 4, 6, 12, 6, 6]
    })

df.dateandtime = pd.to_datetime(df['dateandtime']) # change column type for ease of indexing
df = df.set_index('dateandtime')
df.sort_index(inplace = True)
df = df[~df.index.duplicated()] #Remove any duplicated rows

#Calculate conditions
df['sum_3_scores'] = df.groupby('baby')['score'].rolling(3).sum().reset_index(0,drop=True)
df['max_1_score'] = df.groupby('baby')['score'].rolling(1).max().reset_index(0,drop=True)

# you don't need to calculate the 24hr mean: if the 48hr max is 8, the 24hr mean cannot exceed 8
#df['mean_24hr_score'] = df.groupby('baby')['score'].rolling('24h').mean().reset_index(0,drop=True)

#scoring logic
def score(data):
    if data['sum_3_scores'] >= 24 or data['max_1_score'] >= 12:
        return 1
    return 0

df['rule'] = df.apply(score, axis = 1)

df.reset_index().set_index(['baby','dateandtime']).sort_index()
print(df)

This produces a nice dataframe that has what I want (with the exception of the rule for decreasing dosages):

                    baby  score  sum_3_scores  max_1_score  rule
dateandtime                                                     
2009-07-13 21:00:00    B     12           NaN         12.0     1
2009-07-14 00:01:00    A     14           NaN         14.0     1
2009-07-14 18:00:00    B      4           NaN          4.0     0
2009-07-15 00:04:00    B      6          22.0          6.0     0
2009-07-15 15:04:00    B      4          14.0          4.0     0
2009-07-16 00:01:00    B      7          17.0          7.0     0
2009-07-16 06:00:00    B      6          17.0          6.0     0
2009-07-16 21:00:00    A      4           NaN          4.0     0
2009-07-17 00:01:00    A     10          28.0         10.0     1
2009-07-17 06:00:00    A      5          19.0          5.0     0
2009-07-18 13:00:00    B      4          17.0          4.0     0
2009-07-19 01:00:00    B      6          16.0          6.0     0
2009-07-19 05:00:00    A      5          20.0          5.0     0
2009-07-19 09:00:00    A     11          21.0         11.0     0
2009-07-19 17:00:00    A      3          19.0          3.0     0
2009-07-20 17:00:00    B      6          16.0          6.0     0
2009-08-02 17:00:00    A      6          20.0          6.0     0

What's an easy way to program the rule for lowering the dosages? I understand I can get the 48h window with df.groupby('baby')['score'].rolling('48h'), but it's not clear to me how to check the sum of only the 3 most recent scores in that window.

There is 1 answer

Answered by Dames

Your Setup:

import pandas as pd

df = pd.DataFrame({
   'baby': ['A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'A', 'B', 'B','B', 'B', 'B', 'B', 'B','B','B'],
   'dateandtime':  ['8/2/2009  5:00:00 PM', '7/19/2009  5:00:00 PM', '7/19/2009  5:00:00 PM', '7/17/2009  6:00:00 AM','7/17/2009  12:01:00 AM', '7/14/2009  12:01:00 AM', '7/19/2009  5:00:00 AM', '7/16/2009  9:00:00 PM','7/19/2009  9:00:00 AM', '7/14/2009  6:00:00 PM', '7/15/2009  3:04:00 PM', '7/20/2009  5:00:00 PM','7/16/2009  12:01:00 AM', '7/18/2009  1:00:00 PM', '7/16/2009  6:00:00 AM', '7/13/2009  9:00:00 PM','7/19/2009  1:00:00 AM','7/15/2009  12:04:00 AM'],
   'score':  [6, 3, 3, 5, 10, 14, 5, 4, 11, 4, 4, 6, 7, 4, 6, 12, 6, 6]
    })

df.dateandtime = pd.to_datetime(df['dateandtime']) # change column type for ease of indexing
df = df.set_index('dateandtime')
df = df[~df.index.duplicated()] #Remove any duplicated rows

I'm going to use .rolling() on a .groupby() three times. If you want to inspect max_last3, sum_last3 and last48h_any_critical manually, I recommend sorting by baby and dateandtime first:

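# option 1: sort by baby, then by time (easiest to inspect, matches the output below)
df = df.sort_values(by=['baby', 'dateandtime'])
# option 2: sort by time only (use this instead of option 1 if you prefer)
# df.sort_index(inplace=True)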

To get the sum of the last 3 values, first group by baby, then take rolling windows of size 3, and then sum each window. Important: if the first two values are e.g. 12 and 13, their sum is already >= 24, but no window of size 3 can be built yet. The value would therefore be NaN, and (NaN >= 24) == False. To allow incomplete windows, use min_periods=1.

sum_last3 = df.groupby('baby')['score'].rolling(3, min_periods=1).sum()
df['sum_last3'] = sum_last3.reset_index(level=0, drop=True)

df['sum_last3_critical'] = df['sum_last3'] >= 24
df['sum_last3_good'] = df['sum_last3'] < 18
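
The effect of min_periods can be checked in isolation. Below is a minimal sketch on a made-up score series (the values are hypothetical, not taken from the question's data), comparing the default behaviour with min_periods=1:

```python
import pandas as pd

# Hypothetical scores for one baby; the first two already sum to 25.
scores = pd.Series([12, 13, 2, 4])

strict = scores.rolling(3).sum()                  # needs 3 values per window
lenient = scores.rolling(3, min_periods=1).sum()  # accepts shorter windows

print(pd.DataFrame({'score': scores, 'strict': strict, 'lenient': lenient}))

# NaN compares as False, so the strict version misses the second row.
print((strict >= 24).tolist())   # [False, False, True, False]
print((lenient >= 24).tolist())  # [False, True, True, False]
```

Whether an incomplete window should be allowed to trigger an increase is a clinical decision; this only shows what each setting does.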

I'm still not sure whether you want to look at all scores, the last 3 scores, or only the very last score. This implementation detects a value >= 12 in the last 3 scores. Alternative solutions are at the end.

max_last3 = df.groupby('baby')['score'].rolling(3, min_periods=1).max()
df['max_last3'] = max_last3.reset_index(level=0, drop=True)

df['max_last3_ciritical'] = df['max_last3'] >= 12
df['max_last3_good'] = df['max_last3'] <= 8  # rule: no single score is > 8

Now you can build a critical column, which indicates whether the dose must be increased.

df['critical'] = df['sum_last3_critical'] | df['max_last3_ciritical']

Now you can take 48-hour time windows and get the maximum of the critical column within each window (1.0 if any value is True, 0.0 otherwise). Ideally you would use .any(), but rolling windows do not provide it. Since .max() returns a numeric value, convert back to boolean afterwards.

last48h_any_critical = df.groupby('baby').rolling('48h')['critical'].max().astype('bool')
df['last48h_good'] = ~last48h_any_critical.reset_index(level=0, drop=True)
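
To see how the time-based window behaves, here is a small sketch for a single baby with made-up timestamps and flags (all values are hypothetical). The flags are cast to float before rolling, which is an extra precaution and not part of the code above:

```python
import pandas as pd

# Hypothetical flags for a single baby, indexed by time.
times = pd.to_datetime([
    '2009-07-14 00:00', '2009-07-15 12:00',
    '2009-07-16 06:00', '2009-07-18 00:00',
])
critical = pd.Series([True, False, False, False], index=times)

# Rolling max over the preceding 48 hours acts like "any" on 0/1 values.
any_critical_48h = critical.astype(float).rolling('48h').max().astype(bool)
print(any_critical_48h.tolist())  # [True, True, False, False]
```

The second row is still flagged because the critical reading lies 36 hours back; the third row is clear because that reading is then 54 hours old.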

Now you can determine whether the baby is in good condition and the dose should be decreased.

df['good'] = df['last48h_good'] & df['sum_last3_good'] & df['max_last3_good']

To get an action value, just subtract the good column from the critical column.

df['action'] = df['critical'].astype(int) - df['good'].astype(int)
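
The subtraction works because critical and good cannot both be True for the same row under these thresholds. If a more explicit mapping is preferred, numpy.select is one possible alternative; the sketch below uses a hypothetical three-row frame only to compare the two approaches:

```python
import numpy as np
import pandas as pd

# Hypothetical flags; in the answer these are the 'critical' and 'good' columns.
flags = pd.DataFrame({
    'critical': [True, False, False],
    'good':     [False, False, True],
})

# Subtraction, as above: True - False -> 1, False - True -> -1.
by_subtraction = flags['critical'].astype(int) - flags['good'].astype(int)

# Explicit mapping: the first matching condition wins, otherwise 0.
by_select = np.select([flags['critical'], flags['good']], [1, -1], default=0)

print(by_subtraction.tolist())  # [1, 0, -1]
print(by_select.tolist())       # [1, 0, -1]
```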

The resulting DataFrame looks like this:

                    baby  score  sum_last3  sum_last3_critical  sum_last3_good  max_last3  max_last3_ciritical  max_last3_good  critical  last48h_good   good  action
dateandtime
2009-07-14 00:01:00    A     14       14.0               False            True       14.0                 True           False      True         False  False       1
2009-07-16 21:00:00    A      4       18.0               False           False       14.0                 True           False      True         False  False       1
2009-07-17 00:01:00    A     10       28.0                True           False       14.0                 True           False      True         False  False       1
2009-07-17 06:00:00    A      5       19.0               False           False       10.0                False           False     False         False  False       0
2009-07-19 05:00:00    A      5       20.0               False           False       10.0                False           False     False          True  False       0
2009-07-19 09:00:00    A     11       21.0               False           False       11.0                False           False     False          True  False       0
2009-07-19 17:00:00    A      3       19.0               False           False       11.0                False           False     False          True  False       0
2009-08-02 17:00:00    A      6       20.0               False           False       11.0                False           False     False          True  False       0
2009-07-13 21:00:00    B     12       12.0               False            True       12.0                 True           False      True         False  False       1
2009-07-14 18:00:00    B      4       16.0               False            True       12.0                 True           False      True         False  False       1
2009-07-15 00:04:00    B      6       22.0               False           False       12.0                 True           False      True         False  False       1
2009-07-15 15:04:00    B      4       14.0               False            True        6.0                False            True     False         False  False       0
2009-07-16 00:01:00    B      7       17.0               False            True        7.0                False            True     False         False  False       0
2009-07-16 06:00:00    B      6       17.0               False            True        7.0                False            True     False         False  False       0
2009-07-18 13:00:00    B      4       17.0               False            True        7.0                False            True     False          True   True      -1
2009-07-19 01:00:00    B      6       16.0               False            True        6.0                False            True     False          True   True      -1
2009-07-20 17:00:00    B      6       16.0               False            True        6.0                False            True     False          True   True      -1

Alternative Options

If you want to look at all previous values instead of only the last three, use expanding instead of rolling.

# ideally change name of max_last3 to something like max_alltime
max_last3 = df.groupby('baby')['score'].expanding().max()
df['max_last3'] = max_last3.reset_index(level=0, drop=True)

df['max_last3_ciritical'] = df['max_last3'] >= 12
df['max_last3_good'] = df['max_last3'] <= 8  # rule: no single score is > 8
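
The difference between the two window types shows up once an old high score leaves the 3-row window. A minimal sketch on made-up data (the frame and the max_alltime column name are hypothetical):

```python
import pandas as pd

# Hypothetical scores for two babies, already in time order within each baby.
demo = pd.DataFrame({
    'baby':  ['A', 'A', 'A', 'A', 'B', 'B'],
    'score': [12, 5, 4, 3, 6, 7],
})

rolling_max = demo.groupby('baby')['score'].rolling(3, min_periods=1).max()
expanding_max = demo.groupby('baby')['score'].expanding().max()

demo['max_last3'] = rolling_max.reset_index(level=0, drop=True)
demo['max_alltime'] = expanding_max.reset_index(level=0, drop=True)
print(demo)
```

For baby A's fourth row, max_last3 drops to 5 while max_alltime stays at 12, so with expanding a single early score of 12 or more would keep that baby flagged as critical from then on.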

And if you instead want to look at only the last value you can directly compare to score:

# ideally change name of max_last3_ciritical to something like last_ciritical
df['max_last3_ciritical'] = df['score'] >= 12
df['max_last3_good'] = df['score'] <= 8  # rule: no single score is > 8
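
For convenience, the steps of the main solution (last-3 windows) could be wrapped in one function. This is only a sketch: the name add_action and the demo data are made up, it assumes a unique DatetimeIndex plus 'baby' and 'score' columns, and it uses <= 8 for the "good" check to follow the question's wording ("no single score is >8"):

```python
import pandas as pd

def add_action(df):
    """Return a copy of df with 'critical', 'good' and 'action' columns."""
    out = df.sort_index().copy()

    # Rolling sum and max over the last 3 scores, per baby.
    scores = out.groupby('baby')['score']
    sum3 = scores.rolling(3, min_periods=1).sum().reset_index(level=0, drop=True)
    max3 = scores.rolling(3, min_periods=1).max().reset_index(level=0, drop=True)
    out['critical'] = (sum3 >= 24) | (max3 >= 12)

    # "Any critical in the last 48 hours", via a rolling max on 0/1 values.
    out['critical_num'] = out['critical'].astype(float)
    crit48 = out.groupby('baby')['critical_num'].rolling('48h').max()
    crit48 = crit48.reset_index(level=0, drop=True).astype(bool)
    out = out.drop(columns='critical_num')

    out['good'] = ~crit48 & (sum3 < 18) & (max3 <= 8)
    out['action'] = out['critical'].astype(int) - out['good'].astype(int)
    return out

# Hypothetical readings for one baby.
demo = pd.DataFrame(
    {'baby': ['B'] * 5, 'score': [12, 4, 6, 4, 4]},
    index=pd.to_datetime([
        '2009-07-13 21:00', '2009-07-14 18:00', '2009-07-15 00:04',
        '2009-07-15 15:04', '2009-07-18 13:00',
    ]),
)
actions = add_action(demo)['action'].tolist()
print(actions)  # [1, 1, 1, 0, -1]
```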