I am using the TF-Agents library for a contextual bandits use case.
In this use case, predictions are made multiple times a day (between 20k and 30k predictions per day, one per user), and training only happens on the predictions from 4 days ago, since the labels for the predictions take 3 days to observe.
The driver seems to replay only batch_size experiences (since the max_step length is 1 for contextual bandits). The replay buffer has the same constraint and only handles batch_size experiences.
I want to use a checkpointer to save all the predictions (the experience from the driver, which is stored in the replay buffer) from the past 4 days and, on any given day, train only on the oldest of the 4 saved days.
I am unsure how to do the following, and any help is greatly appreciated.
- How to run the driver and save the replay buffer using checkpoints for an entire day. A day contains, say, 3 prediction runs, and each run makes predictions on 30,000 observations (with a batch size of, say, 16), so I need multiple saves per day.
- How to save the replay buffers for the past 4 days (12 prediction runs) and retrieve only the first 3 prediction runs (replay buffer and driver run) to train on each day.
- How to configure the driver, replay buffer and checkpointer given the two points above.
 
                        
For the replay buffer (RB), I don't think there is any way to get this working without implementing your own RB class, which I wouldn't necessarily recommend. The most straightforward solution seems to be to take the memory inefficiency hit and keep two RBs with different values of `max_length`. The first is given to the driver to store episodes; `rb.as_dataset(single_deterministic_pass=True)` is then used to pull out the appropriate items and place them in the second RB, which is used for training. Of course, the only one you need to checkpoint is the first.

Note: I'm not sure off the top of my head exactly how `single_deterministic_pass` works, so you may want to check that to determine which portion of the returned dataset corresponds to the day you want to train on. I also suspect that the portion corresponding to the last day shifts over time because, if I remember correctly, the RB table that stores the experiences uses a cursor that starts overwriting from the beginning once it reaches the maximum length.

Neither RB needs to know how many prediction runs there are. Your own code should manage that logic, and you might want to keep track of how many predictions correspond to each day (perhaps in a pickle, if you want to persist it) so that you know which ones to pick.
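
The bookkeeping part could look roughly like the sketch below. It is not part of TF-Agents: `RunLedger`, `record_run` and `day_slice` are hypothetical names made up for illustration. It also assumes that the deterministic pass returns items in insertion order and that the first RB has not wrapped around yet, both of which you would need to verify.

```python
import pickle
from collections import OrderedDict
from pathlib import Path


class RunLedger:
    """Hypothetical helper: tracks how many experiences each run stored, per day."""

    def __init__(self):
        # day label -> list of per-run experience counts, oldest day first
        self.days = OrderedDict()

    def record_run(self, day, num_experiences):
        """Call once after each prediction run has been written to the RB."""
        self.days.setdefault(day, []).append(num_experiences)

    def day_slice(self, day):
        """Return (start, end) offsets of `day` within an insertion-ordered pass."""
        start = 0
        for label, counts in self.days.items():
            if label == day:
                return start, start + sum(counts)
            start += sum(counts)
        raise KeyError(day)

    def save(self, path):
        Path(path).write_bytes(pickle.dumps(self.days))

    @classmethod
    def load(cls, path):
        ledger = cls()
        ledger.days = pickle.loads(Path(path).read_bytes())
        return ledger


# Example: 4 stored days, 3 prediction runs per day, 30,000 predictions per run.
ledger = RunLedger()
for day in ["day1", "day2", "day3", "day4"]:
    for _ in range(3):
        ledger.record_run(day, 30_000)

print(ledger.day_slice("day1"))  # (0, 90000)
print(ledger.day_slice("day2"))  # (90000, 180000)
```

With the offsets for the oldest day in hand, you would skip `start` items of the deterministic dataset, take `end - start` items, and add those to the training RB. Saving the ledger next to the checkpoint keeps the two in sync across restarts.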